├── IRL_market_model-paper-multiple_stocks-with_next_day return_and_SP500.ipynb ├── README.md ├── SSRN-id3174498.pdf ├── dja_cap.csv ├── spx_holdings_and_spx_closeprice.csv └── tensorflow_optimization_examples.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Inverse_Reinforcement_Learning_for_Stocks 2 | In this project, we will 3 | 4 | 1. Explore and estimate a model of market returns based on inverse reinforcement learning (IRL) of a market-optimal portfolio 5 | 2. Investigate how the choice of different signals affects model estimation and trading strategies 6 | 3. Compare simple IRL-based and UL-based trading strategies 7 | 8 | by implementing the model of Halperin and Feldshteyn (2012) for the DJIA and S&P 500. 9 | 10 | Data and Jupyter notebook files are included so that others can reproduce the results and make further enhancements. 11 | 12 | **dja_cap.csv** - contains DJIA stock prices 13 | 14 | **spx_holdings_and_spx_closeprice.csv** - contains S&P 500 stock prices 15 | 16 | **SSRN-id3174498.pdf** - the paper that describes the IRL model 17 | 18 | **IRL_market_model-paper-multiple_stocks-with_next_day return_and_SP500.ipynb** - the Jupyter notebook implementing the IRL market model 19 | 20 | **tensorflow_optimization_examples.ipynb** - examples of using TensorFlow optimizers, including minimizing a negative log-likelihood function (maximum-likelihood estimation) 21 | 22 | -------------------------------------------------------------------------------- /SSRN-id3174498.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chandc/Inverse_Reinforcement_Learning_for_Stocks/09cbb47a6685c37d23ebabdb15764c5e7fe6a86d/SSRN-id3174498.pdf -------------------------------------------------------------------------------- /tensorflow_optimization_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | {
9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/danielchan/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", 13 | " from ._conv import register_converters as _register_converters\n" 14 | ] 15 | }, 16 | { 17 | "name": "stdout", 18 | "output_type": "stream", 19 | "text": [ 20 | "WARNING:tensorflow:From /home/danielchan/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/tf_should_use.py:118: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.\n", 21 | "Instructions for updating:\n", 22 | "Use `tf.global_variables_initializer` instead.\n", 23 | "starting at x: 2.0 log(x)^2: 0.480453\n", 24 | "step 0 x: 1.6534264 log(x)^2: 0.25285786\n", 25 | "step 1 x: 1.3493005 log(x)^2: 0.08975195\n", 26 | "step 2 x: 1.1272697 log(x)^2: 0.014351694\n", 27 | "step 3 x: 1.0209966 log(x)^2: 0.0004317743\n", 28 | "step 4 x: 1.0006447 log(x)^2: 4.1534943e-07\n", 29 | "step 5 x: 1.0000006 log(x)^2: 3.5527118e-13\n", 30 | "step 6 x: 1.0 log(x)^2: 0.0\n", 31 | "step 7 x: 1.0 log(x)^2: 0.0\n", 32 | "step 8 x: 1.0 log(x)^2: 0.0\n", 33 | "step 9 x: 1.0 log(x)^2: 0.0\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "import tensorflow as tf\n", 39 | "\n", 40 | "x = tf.Variable(2, name='x', dtype=tf.float32)\n", 41 | "log_x = tf.log(x)\n", 42 | "log_x_squared = tf.square(log_x)\n", 43 | "\n", 44 | "optimizer = tf.train.GradientDescentOptimizer(0.5)\n", 45 | "train = optimizer.minimize(log_x_squared)\n", 46 | "\n", 47 | "init = tf.initialize_all_variables()\n", 48 | "\n", 49 | "def optimize():\n", 50 | " with tf.Session() as session:\n", 51 | " session.run(init)\n", 52 | " print(\"starting at\", \"x:\", session.run(x), \"log(x)^2:\", session.run(log_x_squared))\n", 53 | " for step in range(10): \n", 54 | " 
session.run(train)\n", 55 | " print(\"step\", step, \"x:\", session.run(x), \"log(x)^2:\", session.run(log_x_squared))\n", 56 | " \n", 57 | "\n", 58 | "optimize()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 23, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "0 [1.8902434] [0.75676066] [793.9243]\n", 71 | "100 [1.7902387] [0.85675937] [552.0268]\n", 72 | "200 [1.6902341] [0.9567581] [361.5271]\n", 73 | "300 [1.5902294] [1.0567602] [217.04721]\n", 74 | "400 [1.4902247] [1.1567649] [113.450935]\n", 75 | "500 [1.39022] [1.2567695] [45.84207]\n", 76 | "600 [1.2902154] [1.3567742] [9.563324]\n", 77 | "700 [1.2009802] [1.4430666] [0.0404439]\n", 78 | "800 [1.1951108] [1.4289846] [0.0381165]\n", 79 | "900 [1.1893662] [1.4152687] [0.03590535]\n", 80 | "1000 [1.1837461] [1.4019138] [0.03380606]\n", 81 | "1100 [1.17825] [1.388914] [0.03181415]\n", 82 | "1200 [1.1728771] [1.3762645] [0.0299254]\n", 83 | "1300 [1.1676264] [1.3639578] [0.02813538]\n", 84 | "1400 [1.1624974] [1.35199] [0.02644018]\n", 85 | "1500 [1.1574893] [1.340355] [0.02483577]\n", 86 | "1600 [1.1526014] [1.3290471] [0.02331822]\n", 87 | "1700 [1.1478317] [1.3180588] [0.02188351]\n", 88 | "1800 [1.1431798] [1.3073858] [0.02052809]\n", 89 | "1900 [1.1386445] [1.2970217] [0.01924835]\n", 90 | "2000 [1.1342242] [1.28696] [0.01804069]\n", 91 | "2100 [1.1299179] [1.2771952] [0.01690177]\n", 92 | "2200 [1.1257238] [1.2677207] [0.01582825]\n", 93 | "2300 [1.1216404] [1.2585298] [0.01481687]\n", 94 | "2400 [1.1176664] [1.2496169] [0.01386463]\n", 95 | "2500 [1.1138003] [1.2409765] [0.0129686]\n", 96 | "2600 [1.1100402] [1.2326012] [0.01212582]\n", 97 | "2700 [1.1063843] [1.2244855] [0.01133355]\n", 98 | "2800 [1.102831] [1.2166232] [0.01058919]\n", 99 | "2900 [1.0993788] [1.2090085] [0.00989019]\n", 100 | "3000 [1.0960255] [1.2016344] [0.00923403]\n", 101 | "3100 [1.0927693] [1.1944958] [0.00861846]\n", 102 | "3200 
[1.0896087] [1.1875869] [0.00804126]\n", 103 | "3300 [1.0865414] [1.180901] [0.00750023]\n", 104 | "3400 [1.0835658] [1.174433] [0.00699336]\n", 105 | "3500 [1.08068] [1.1681769] [0.00651873]\n", 106 | "3600 [1.077882] [1.162127] [0.00607445]\n", 107 | "3700 [1.07517] [1.156278] [0.0056588]\n", 108 | "3800 [1.0725418] [1.1506239] [0.00527004]\n", 109 | "3900 [1.0699958] [1.1451594] [0.00490661]\n", 110 | "4000 [1.0675296] [1.1398786] [0.00456696]\n", 111 | "4100 [1.0651416] [1.1347771] [0.0042497]\n", 112 | "4200 [1.06283] [1.1298497] [0.00395347]\n", 113 | "4300 [1.0605928] [1.125091] [0.00367695]\n", 114 | "4400 [1.0584278] [1.1204952] [0.00341891]\n", 115 | "4500 [1.0563335] [1.1160585] [0.00317822]\n", 116 | "4600 [1.054308] [1.1117759] [0.00295379]\n", 117 | "4700 [1.0523497] [1.1076429] [0.00274461]\n", 118 | "4800 [1.0504564] [1.1036544] [0.00254968]\n", 119 | "4900 [1.0486261] [1.0998057] [0.00236807]\n", 120 | "5000 [1.0468574] [1.0960927] [0.00219894]\n", 121 | "5100 [1.0451484] [1.092511] [0.00204147]\n", 122 | "5200 [1.0434977] [1.0890571] [0.00189493]\n", 123 | "5300 [1.0419033] [1.085726] [0.00175856]\n", 124 | "5400 [1.0403638] [1.0825144] [0.00163172]\n", 125 | "5500 [1.0388776] [1.0794185] [0.00151377]\n", 126 | "5600 [1.0374427] [1.0764338] [0.0014041]\n", 127 | "5700 [1.0360577] [1.0735567] [0.00130215]\n", 128 | "5800 [1.0347213] [1.0707842] [0.00120742]\n", 129 | "5900 [1.0334316] [1.0681121] [0.00111939]\n", 130 | "6000 [1.0321876] [1.0655377] [0.00103764]\n", 131 | "6100 [1.0309876] [1.0630571] [0.00096171]\n", 132 | "6200 [1.0298302] [1.0606676] [0.00089122]\n", 133 | "6300 [1.0287142] [1.0583658] [0.00082578]\n", 134 | "6400 [1.027638] [1.0561485] [0.00076504]\n", 135 | "6500 [1.0266004] [1.0540129] [0.00070867]\n", 136 | "6600 [1.0256001] [1.0519564] [0.00065638]\n", 137 | "6700 [1.0246364] [1.0499767] [0.00060789]\n", 138 | "6800 [1.0237072] [1.0480698] [0.0005629]\n", 139 | "6900 [1.0228118] [1.0462341] [0.00052119]\n", 140 | "7000 
[1.0219492] [1.0444669] [0.00048252]\n", 141 | "7100 [1.0211182] [1.042766] [0.00044668]\n", 142 | "7200 [1.0203176] [1.0411282] [0.00041345]\n", 143 | "7300 [1.019546] [1.0395514] [0.00038264]\n", 144 | "7400 [1.0188031] [1.0380342] [0.00035411]\n", 145 | "7500 [1.0180877] [1.0365742] [0.00032768]\n", 146 | "7600 [1.0173987] [1.035169] [0.00030319]\n", 147 | "7700 [1.0167352] [1.0338168] [0.00028051]\n", 148 | "7800 [1.0160965] [1.0325158] [0.0002595]\n", 149 | "7900 [1.0154815] [1.031264] [0.00024005]\n", 150 | "8000 [1.0148895] [1.0300598] [0.00022205]\n", 151 | "8100 [1.0143198] [1.0289016] [0.00020538]\n", 152 | "8200 [1.0137712] [1.0277867] [0.00018994]\n", 153 | "8300 [1.0132432] [1.0267144] [0.00017566]\n", 154 | "8400 [1.0127351] [1.0256829] [0.00016244]\n", 155 | "8500 [1.012246] [1.0246906] [0.0001502]\n", 156 | "8600 [1.0117754] [1.0237362] [0.00013888]\n", 157 | "8700 [1.0113225] [1.0228182] [0.0001284]\n", 158 | "8800 [1.0108868] [1.0219353] [0.00011871]\n", 159 | "8900 [1.0104674] [1.0210861] [0.00010974]\n", 160 | "9000 [1.010064] [1.0202694] [0.00010144]\n", 161 | "9100 [1.0096759] [1.0194839] [9.377145e-05]\n", 162 | "9200 [1.0093026] [1.0187289] [8.667611e-05]\n", 163 | "9300 [1.0089433] [1.0180023] [8.011e-05]\n", 164 | "9400 [1.0085979] [1.0173038] [7.4040094e-05]\n", 165 | "9500 [1.0082656] [1.0166326] [6.842942e-05]\n", 166 | "9600 [1.0079461] [1.015987] [6.3240834e-05]\n", 167 | "9700 [1.0076388] [1.0153663] [5.8443857e-05]\n", 168 | "9800 [1.0073432] [1.0147696] [5.400819e-05]\n", 169 | "9900 [1.0070587] [1.0141954] [4.9904953e-05]\n", 170 | "10000 [1.0067854] [1.0136437] [4.611414e-05]\n", 171 | "10100 [1.0065224] [1.0131134] [4.261008e-05]\n", 172 | "10200 [1.0062697] [1.0126038] [3.937232e-05]\n", 173 | "10300 [1.0060265] [1.0121135] [3.637677e-05]\n", 174 | "10400 [1.0057927] [1.0116421] [3.360929e-05]\n", 175 | "10500 [1.005568] [1.0111892] [3.105209e-05]\n", 176 | "10600 [1.0053519] [1.0107538] [2.8687871e-05]\n", 177 | "10700 
[1.0051438] [1.0103344] [2.6499838e-05]\n", 178 | "10800 [1.0049441] [1.0099323] [2.4482675e-05]\n", 179 | "10900 [1.004752] [1.0095456] [2.261781e-05]\n", 180 | "11000 [1.004567] [1.009173] [2.0890568e-05]\n", 181 | "11100 [1.0043893] [1.0088155] [1.929696e-05]\n", 182 | "11200 [1.0042188] [1.0084724] [1.782707e-05]\n", 183 | "11300 [1.0040548] [1.0081422] [1.6467564e-05]\n", 184 | "11400 [1.0038973] [1.0078254] [1.5213036e-05]\n", 185 | "11500 [1.0037454] [1.00752] [1.4050856e-05]\n", 186 | "11600 [1.0035998] [1.0072268] [1.2978757e-05]\n", 187 | "11700 [1.0034593] [1.0069445] [1.1986116e-05]\n", 188 | "11800 [1.0033249] [1.0066742] [1.1072562e-05]\n", 189 | "11900 [1.0031949] [1.006413] [1.0224141e-05]\n", 190 | "12000 [1.0030704] [1.0061625] [9.442154e-06]\n", 191 | "12100 [1.0029511] [1.0059228] [8.723185e-06]\n", 192 | "12200 [1.0028362] [1.0056918] [8.056742e-06]\n", 193 | "12300 [1.0027258] [1.00547] [7.441969e-06]\n", 194 | "12400 [1.0026191] [1.0052557] [6.8709373e-06]\n", 195 | "12500 [1.0025172] [1.0050509] [6.3466805e-06]\n", 196 | "12600 [1.0024195] [1.0048547] [5.8633987e-06]\n", 197 | "12700 [1.0023247] [1.0046641] [5.4128777e-06]\n", 198 | "12800 [1.0022341] [1.0044822] [4.999202e-06]\n", 199 | "12900 [1.0021468] [1.0043069] [4.616289e-06]\n", 200 | "13000 [1.0020635] [1.0041395] [4.2648508e-06]\n", 201 | "13100 [1.0019825] [1.0039768] [3.936489e-06]\n", 202 | "13200 [1.0019053] [1.003822] [3.636073e-06]\n", 203 | "13300 [1.0018317] [1.003674] [3.360407e-06]\n", 204 | "13400 [1.0017604] [1.0035309] [3.1038267e-06]\n", 205 | "13500 [1.0016907] [1.003391] [2.8630761e-06]\n", 206 | "13600 [1.001625] [1.0032591] [2.6447346e-06]\n", 207 | "13700 [1.0015608] [1.0031304] [2.4402632e-06]\n", 208 | "13800 [1.001501] [1.0030102] [2.256446e-06]\n", 209 | "13900 [1.0014416] [1.0028911] [2.0816165e-06]\n", 210 | "14000 [1.0013849] [1.0027771] [1.9208283e-06]\n", 211 | "14100 [1.0013313] [1.0026697] [1.775189e-06]\n", 212 | "14200 [1.0012786] [1.0025641] 
[1.637545e-06]\n", 213 | "14300 [1.0012292] [1.0024649] [1.513479e-06]\n", 214 | "14400 [1.0011816] [1.0023694] [1.398458e-06]\n", 215 | "14500 [1.0011342] [1.0022742] [1.2883646e-06]\n", 216 | "14600 [1.0010892] [1.002184] [1.1883353e-06]\n", 217 | "14700 [1.0010475] [1.0021003] [1.0990813e-06]\n", 218 | "14800 [1.0010062] [1.0020174] [1.0139854e-06]\n", 219 | "14900 [1.0009669] [1.0019386] [9.3636345e-07]\n", 220 | "15000 [1.0009305] [1.0018657] [8.673742e-07]\n", 221 | "15100 [1.0008949] [1.0017942] [8.021324e-07]\n", 222 | "15200 [1.0008593] [1.0017227] [7.3952384e-07]\n", 223 | "15300 [1.0008241] [1.0016521] [6.8016664e-07]\n", 224 | "15400 [1.000792] [1.0015879] [6.28342e-07]\n", 225 | "15500 [1.0007619] [1.0015274] [5.8140125e-07]\n", 226 | "15600 [1.0007322] [1.0014678] [5.369111e-07]\n", 227 | "15700 [1.000703] [1.0014093] [4.949954e-07]\n", 228 | "15800 [1.0006748] [1.0013529] [4.5616588e-07]\n", 229 | "15900 [1.0006491] [1.0013012] [4.2201157e-07]\n", 230 | "16000 [1.0006253] [1.0012535] [3.9156765e-07]\n", 231 | "16100 [1.0006015] [1.0012058] [3.6240687e-07]\n", 232 | "16200 [1.0005777] [1.0011581] [3.3429208e-07]\n", 233 | "16300 [1.000554] [1.0011104] [3.0733827e-07]\n", 234 | "16400 [1.0005314] [1.0010654] [2.8293618e-07]\n", 235 | "16500 [1.0005107] [1.0010235] [2.611707e-07]\n", 236 | "16600 [1.0004911] [1.0009845] [2.416314e-07]\n", 237 | "16700 [1.0004733] [1.0009488] [2.2438654e-07]\n", 238 | "16800 [1.0004555] [1.000913] [2.077988e-07]\n", 239 | "16900 [1.0004376] [1.0008773] [1.918727e-07]\n", 240 | "17000 [1.0004199] [1.0008415] [1.7651848e-07]\n", 241 | "17100 [1.0004022] [1.0008062] [1.6201476e-07]\n", 242 | "17200 [1.000386] [1.0007738] [1.4923592e-07]\n", 243 | "17300 [1.0003712] [1.0007441] [1.3804276e-07]\n", 244 | "17400 [1.0003588] [1.0007192] [1.289564e-07]\n", 245 | "17500 [1.0003469] [1.0006953] [1.2054358e-07]\n", 246 | "17600 [1.000335] [1.0006715] [1.12414966e-07]\n", 247 | "17700 [1.000323] [1.0006477] [1.0457058e-07]\n", 248 | 
"17800 [1.0003113] [1.0006238] [9.702207e-08]\n", 249 | "17900 [1.0002993] [1.0006] [8.974327e-08]\n", 250 | "18000 [1.0002874] [1.0005761] [8.2748684e-08]\n", 251 | "18100 [1.0002756] [1.0005523] [7.605286e-08]\n", 252 | "18200 [1.0002637] [1.0005285] [6.9623866e-08]\n", 253 | "18300 [1.0002518] [1.0005046] [6.34791e-08]\n", 254 | "18400 [1.0002403] [1.0004817] [5.7847497e-08]\n", 255 | "18500 [1.0002297] [1.0004603] [5.2839212e-08]\n", 256 | "18600 [1.0002193] [1.0004398] [4.8227378e-08]\n", 257 | "18700 [1.0002115] [1.0004238] [4.4773795e-08]\n", 258 | "18800 [1.0002035] [1.0004079] [4.149923e-08]\n", 259 | "18900 [1.0001959] [1.0003926] [3.8412633e-08]\n", 260 | "19000 [1.0001895] [1.0003798] [3.5977617e-08]\n", 261 | "19100 [1.0001832] [1.0003672] [3.3622438e-08]\n", 262 | "19200 [1.0001769] [1.0003545] [3.13471e-08]\n", 263 | "19300 [1.0001707] [1.0003421] [2.9192277e-08]\n", 264 | "19400 [1.0001647] [1.0003302] [2.7192812e-08]\n", 265 | "19500 [1.0001588] [1.0003183] [2.5248767e-08]\n", 266 | "19600 [1.0001528] [1.0003064] [2.3407038e-08]\n", 267 | "19700 [1.000147] [1.0002944] [2.1627342e-08]\n", 268 | "19800 [1.000141] [1.0002825] [1.991067e-08]\n", 269 | "19900 [1.0001351] [1.0002706] [1.8265053e-08]\n" 270 | ] 271 | } 272 | ], 273 | "source": [ 274 | "import tensorflow as tf\n", 275 | "\n", 276 | "# initialize arrays\n", 277 | "x1_data = tf.Variable(initial_value=tf.random_uniform([1], -3, 3),name='x1')\n", 278 | "x2_data = tf.Variable(initial_value=tf.random_uniform([1], -3, 3), name='x2')\n", 279 | "\n", 280 | "# The Rosenbrock problem is defined as y = (1 - x1)^2 + 100 * (x2 - x1^2)^2, \n", 281 | "# giving the optimal solution on x1 = x2 = 1\n", 282 | "\n", 283 | "# Loss function\n", 284 | "y = tf.add(tf.pow(tf.subtract(1.0, x1_data), 2.0), \n", 285 | " tf.multiply(100.0, tf.pow(tf.subtract(x2_data,tf.pow(x1_data, 2.0)), 2.0)), 'y')\n", 286 | "\n", 287 | "#opt = tf.train.GradientDescentOptimizer(0.0035)\n", 288 | "#train = opt.minimize(y)\n", 289 | 
"\n", 290 | "opt = tf.train.GradientDescentOptimizer(1e-3)\n", 291 | "grads_and_vars = opt.compute_gradients(y, [x1_data, x2_data])\n", 292 | "clipped_grads_and_vars = [(tf.clip_by_value(g, -1., 1.), v) for g, v in grads_and_vars]\n", 293 | "train = opt.apply_gradients(clipped_grads_and_vars)\n", 294 | "\n", 295 | "\n", 296 | "sess = tf.Session()\n", 297 | "\n", 298 | "init = tf.initialize_all_variables()\n", 299 | "sess.run(init)\n", 300 | "\n", 301 | "for step in range(20000):\n", 302 | " sess.run(train)\n", 303 | " if step % 100 == 0:\n", 304 | " print(step, sess.run(x1_data), sess.run(x2_data), sess.run(y))" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 8, 310 | "metadata": {}, 311 | "outputs": [ 312 | { 313 | "name": "stdout", 314 | "output_type": "stream", 315 | "text": [ 316 | "y = -0.9999969x + 0.9999908\n", 317 | "Loss: 5.6999738e-11\n" 318 | ] 319 | } 320 | ], 321 | "source": [ 322 | "# Tensorflow #1 Example\n", 323 | "# Tensorflow example of Gradient Descent\n", 324 | "# on a linear equation (y = mx + b)\n", 325 | "#\n", 326 | "# https://github.com/FFY00/DeepLearning-Studies\n", 327 | "\n", 328 | "\n", 329 | "import tensorflow as tf\n", 330 | "\n", 331 | "m = tf.Variable([.3], dtype=tf.float32)\n", 332 | "b = tf.Variable([-.3], dtype=tf.float32)\n", 333 | "x = tf.placeholder(tf.float32)\n", 334 | "linear_model = m * x + b # y = mx + b\n", 335 | "\n", 336 | "y = tf.placeholder(tf.float32)\n", 337 | "squared_deltas = tf.square(linear_model - y) # Also known as r^2\n", 338 | "loss = tf.reduce_sum(squared_deltas)\n", 339 | "\n", 340 | "# If you decrease the learning rate, you have to increase the loop range value\n", 341 | "optimizer = tf.train.GradientDescentOptimizer(0.01)\n", 342 | "train = optimizer.minimize(loss)\n", 343 | "\n", 344 | "init = tf.global_variables_initializer()\n", 345 | "sess = tf.Session()\n", 346 | "sess.run(init)\n", 347 | "\n", 348 | "x_set = [1, 2, 3, 4]\n", 349 | "y_set = [0, -1, -2, -3]\n", 350 | 
"\n", 351 | "for i in range(1000):\n", 352 | " sess.run(train, {x: x_set, y: y_set})\n", 353 | "\n", 354 | "m_value, b_value, loss = sess.run([m, b, loss], {x: x_set, y: y_set})\n", 355 | "print ( \"y = {}x + {}\".format(repr(m_value[0]), repr(b_value[0])) )\n", 356 | "print (\"Loss: \", loss)\n" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 115, 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "name": "stdout", 366 | "output_type": "stream", 367 | "text": [ 368 | "Tensor(\"add_4:0\", shape=(1,), dtype=float32)\n", 369 | "0 [-1.3972076] [1.4620501] [29.77023]\n", 370 | "100 [-1.297203] [1.5620548] [6.733525]\n", 371 | "200 [-1.2421428] [1.5511112] [5.033916]\n", 372 | "300 [-1.2004808] [1.4496552] [4.849343]\n", 373 | "400 [-1.1575612] [1.3486443] [4.6626334]\n", 374 | "500 [-1.1131102] [1.2479031] [4.4731355]\n", 375 | "600 [-1.0668845] [1.1473306] [4.2802706]\n", 376 | "700 [-1.0186123] [1.0468733] [4.08345]\n", 377 | "800 [-0.96796685] [0.94649905] [3.881993]\n", 378 | "900 [-0.914538] [0.84618735] [3.6750746]\n", 379 | "1000 [-0.8577915] [0.74592423] [3.4616263]\n", 380 | "1100 [-0.79700214] [0.6456998] [3.2402153]\n", 381 | "1200 [-0.7311409] [0.54550666] [3.0088165]\n", 382 | "1300 [-0.6586538] [0.4453393] [2.7643905]\n", 383 | "1400 [-0.5769843] [0.34519356] [2.5019658]\n", 384 | "1500 [-0.4813365] [0.245066] [2.2122633]\n", 385 | "1600 [-0.36041316] [0.14495385] [1.873393]\n", 386 | "1700 [-0.17624307] [0.04485497] [1.4025733]\n", 387 | "1800 [0.02671082] [7.182778e-05] [0.94733304]\n", 388 | "1900 [0.20114698] [0.03732746] [0.63914746]\n", 389 | "2000 [0.35112444] [0.11945886] [0.42250603]\n", 390 | "2100 [0.45071608] [0.19991261] [0.30275762]\n", 391 | "2200 [0.5220393] [0.26977426] [0.2292031]\n", 392 | "2300 [0.5778115] [0.3314967] [0.17880456]\n", 393 | "2400 [0.62366366] [0.38689512] [0.1420539]\n", 394 | "2500 [0.66258687] [0.43721467] [0.11417403]\n", 395 | "2600 [0.69636065] [0.48332563] [0.09245047]\n", 396 | 
"2700 [0.7261298] [0.5258551] [0.07520352]\n", 397 | "2800 [0.75267375] [0.5652669] [0.06132674]\n", 398 | "2900 [0.77654666] [0.6019126] [0.05005506]\n", 399 | "3000 [0.79815716] [0.6360645] [0.0408386]\n", 400 | "3100 [0.817814] [0.667938] [0.03326948]\n", 401 | "3200 [0.83575743] [0.69770527] [0.02703727]\n", 402 | "3300 [0.85217714] [0.7255071] [0.02190043]\n", 403 | "3400 [0.8672261] [0.75146] [0.01766748]\n", 404 | "3500 [0.88103] [0.77566266] [0.01418424]\n", 405 | "3600 [0.8936933] [0.79819953] [0.01132494]\n", 406 | "3700 [0.90530443] [0.8191448] [0.00898585]\n", 407 | "3800 [0.91593945] [0.8385649] [0.00708063]\n", 408 | "3900 [0.9256645] [0.8565209] [0.00553691]\n", 409 | "4000 [0.9345385] [0.8730702] [0.00429374]\n", 410 | "4100 [0.94261473] [0.88826805] [0.00329954]\n", 411 | "4200 [0.9499418] [0.90216833] [0.00251071]\n", 412 | "4300 [0.9565646] [0.9148251] [0.00189027]\n", 413 | "4400 [0.9625265] [0.9262932] [0.00140695]\n", 414 | "4500 [0.9678684] [0.93662894] [0.00103441]\n", 415 | "4600 [0.97262955] [0.94588923] [0.00075056]\n", 416 | "4700 [0.9768492] [0.9541342] [0.00053696]\n", 417 | "4800 [0.98056537] [0.9614246] [0.00037841]\n", 418 | "4900 [0.98381585] [0.96782404] [0.00026241]\n", 419 | "5000 [0.986638] [0.9733969] [0.00017887]\n", 420 | "5100 [0.9890682] [0.9782091] [0.00011972]\n", 421 | "5200 [0.991143] [0.9823263] [7.859172e-05]\n", 422 | "5300 [0.9928978] [0.9858156] [5.053389e-05]\n", 423 | "5400 [0.9943671] [0.9887418] [3.1787866e-05]\n", 424 | "5500 [0.99558425] [0.9911691] [1.9534327e-05]\n", 425 | "5600 [0.99658096] [0.99315894] [1.1711346e-05]\n", 426 | "5700 [0.9973869] [0.9947697] [6.8403315e-06]\n", 427 | "5800 [0.9980305] [0.9960565] [3.885958e-06]\n", 428 | "5900 [0.99853724] [0.99707025] [2.1436515e-06]\n", 429 | "6000 [0.99893034] [0.9978574] [1.1461286e-06]\n", 430 | "6100 [0.9992306] [0.99845856] [5.930153e-07]\n", 431 | "6200 [0.9994561] [0.99891007] [2.9638736e-07]\n", 432 | "6300 [0.99962234] [0.9992433] 
[1.4286348e-07]\n", 433 | "6400 [0.9997429] [0.9994846] [6.6259425e-08]\n", 434 | "6500 [0.9998285] [0.99965614] [2.9497217e-08]\n", 435 | "6600 [0.99988806] [0.9997756] [1.2558786e-08]\n", 436 | "6700 [0.9999286] [0.9998567] [5.1216063e-09]\n", 437 | "6800 [0.99995553] [0.9999106] [1.9998796e-09]\n", 438 | "6900 [0.9999729] [0.99994576] [7.3550055e-10]\n", 439 | "7000 [0.9999841] [0.999968] [2.5469046e-10]\n", 440 | "7100 [0.9999908] [0.99998146] [8.567724e-11]\n", 441 | "7200 [0.99999493] [0.9999898] [2.6023628e-11]\n", 442 | "7300 [0.9999972] [0.99999434] [8.203216e-12]\n", 443 | "7400 [0.9999983] [0.9999966] [3.3431036e-12]\n", 444 | "7500 [0.999999] [0.999998] [1.0267343e-12]\n", 445 | "7600 [0.99999934] [0.99999875] [7.851497e-13]\n", 446 | "7700 [0.9999995] [0.99999905] [2.2737366e-13]\n", 447 | "7800 [0.9999995] [0.99999905] [2.2737366e-13]\n", 448 | "7900 [0.9999995] [0.99999905] [2.2737366e-13]\n", 449 | "8000 [0.9999995] [0.99999905] [2.2737366e-13]\n", 450 | "8100 [0.9999995] [0.99999905] [2.2737366e-13]\n", 451 | "8200 [0.9999996] [0.9999992] [1.5951683e-12]\n", 452 | "8300 [0.99999994] [0.99999964] [5.687894e-12]\n", 453 | "8400 [0.99999994] [0.99999994] [3.5882408e-13]\n", 454 | "8500 [0.99999976] [1.] [2.279421e-11]\n", 455 | "8600 [0.99999976] [1.] [2.279421e-11]\n", 456 | "8700 [1.000001] [0.9999994] [6.2760813e-10]\n", 457 | "8800 [1.0000024] [0.9999979] [4.704148e-09]\n", 458 | "8900 [0.99999726] [1.0000007] [3.8501327e-09]\n", 459 | "9000 [0.9999971] [1.0000004] [3.851145e-09]\n", 460 | "9100 [1.0000196] [0.99999017] [2.3984967e-07]\n", 461 | "9200 [0.9999384] [1.0000275] [2.2742536e-06]\n", 462 | "9300 [0.9999974] [1.0000008] [3.7031214e-09]\n", 463 | "9400 [0.99999714] [0.99999994] [3.2823664e-09]\n", 464 | "9500 [0.9999988] [1.] [5.698553e-10]\n", 465 | "9600 [1.0000024] [0.99999636] [7.0688344e-09]\n", 466 | "9700 [1.] [0.9999994] [3.5527137e-11]\n", 467 | "9800 [1.] 
[0.9999994] [3.5527137e-11]\n", 468 | "9900 [1.0000035] [0.99999523] [1.3660056e-08]\n", 469 | "10000 [1.] [0.9999997] [8.881783e-12]\n", 470 | "10100 [0.9999801] [1.0000116] [2.6437857e-07]\n", 471 | "10200 [0.9999991] [0.99999976] [2.4096283e-10]\n", 472 | "10300 [0.99999446] [1.0000021] [1.7698015e-08]\n", 473 | "10400 [0.9999632] [1.0000256] [9.862447e-07]\n", 474 | "10500 [0.9999989] [1.0000004] [6.278497e-10]\n", 475 | "10600 [1.0000035] [0.9999962] [1.1522744e-08]\n", 476 | "10700 [1.000026] [0.99997693] [5.6380924e-07]\n", 477 | "10800 [0.99999875] [1.0000007] [1.0375382e-09]\n", 478 | "10900 [1.000007] [0.999991] [5.32581e-08]\n", 479 | "11000 [0.99999905] [1.0000006] [6.578062e-10]\n", 480 | "11100 [1.0000039] [0.99999434] [1.8322254e-08]\n", 481 | "11200 [0.99999666] [1.0000026] [8.657025e-09]\n", 482 | "11300 [0.99999875] [0.99999994] [5.987779e-10]\n", 483 | "11400 [1.0000218] [0.99997824] [4.280127e-07]\n", 484 | "11500 [1.0000147] [0.99998635] [1.8489962e-07]\n", 485 | "11600 [0.9999999] [0.99999905] [5.1173288e-11]\n", 486 | "11700 [1.000018] [0.999981] [3.0298997e-07]\n", 487 | "11800 [1.0000155] [0.999986] [2.0275372e-07]\n", 488 | "11900 [1.0000012] [0.9999983] [1.6928681e-09]\n", 489 | "12000 [1.0000186] [0.99997836] [3.4644017e-07]\n", 490 | "12100 [1.0000005] [0.9999993] [2.7876013e-10]\n", 491 | "12200 [0.99996454] [1.0000314] [1.0474097e-06]\n", 492 | "12300 [1.0000154] [0.99998534] [2.0652267e-07]\n", 493 | "12400 [0.9999978] [1.0000017] [3.701107e-09]\n", 494 | "12500 [0.99999887] [1.0000001] [5.697167e-10]\n", 495 | "12600 [1.0000043] [0.99999297] [2.4405665e-08]\n", 496 | "12700 [1.00001] [0.99999005] [8.998713e-08]\n", 497 | "12800 [1.0000013] [0.99999803] [2.1081235e-09]\n", 498 | "12900 [1.0000126] [0.9999846] [1.654049e-07]\n", 499 | "13000 [0.9999992] [1.0000005] [4.112941e-10]\n", 500 | "13100 [0.99999744] [0.9999992] [1.952035e-09]\n", 501 | "13200 [1.0000002] [0.99999946] [1.0273027e-10]\n", 502 | "13300 [1.0001369] [0.9998655] 
[1.6684082e-05]\n", 503 | "13400 [0.99999964] [0.9999993] [1.2789769e-13]\n", 504 | "13500 [0.9999998] [0.9999998] [3.2294167e-12]\n", 505 | "13600 [0.9999998] [0.9999998] [3.2294167e-12]\n", 506 | "13700 [0.9999998] [0.9999998] [3.2294167e-12]\n", 507 | "13800 [0.9999998] [0.9999998] [3.2294167e-12]\n", 508 | "13900 [1.000013] [0.999952] [5.4731555e-07]\n", 509 | "14000 [0.99999666] [0.9999926] [6.230039e-11]\n", 510 | "14100 [0.9999991] [0.9999981] [2.2204458e-12]\n", 511 | "14200 [0.9999994] [0.9999988] [3.5527137e-13]\n", 512 | "14300 [0.9999994] [0.9999988] [3.5527137e-13]\n", 513 | "14400 [0.9999994] [0.9999988] [3.5527137e-13]\n", 514 | "14500 [0.9999994] [0.9999988] [3.5527137e-13]\n", 515 | "14600 [0.9999994] [0.9999988] [3.5527137e-13]\n", 516 | "14700 [0.9999994] [0.9999988] [3.5527137e-13]\n", 517 | "14800 [0.99999976] [0.9999995] [5.6843415e-14]\n", 518 | "14900 [0.99999994] [0.9999998] [3.5882408e-13]\n", 519 | "15000 [0.9999999] [0.9999999] [1.4352962e-12]\n", 520 | "15100 [0.99999994] [0.9999998] [3.5882408e-13]\n", 521 | "15200 [0.99999994] [0.9999998] [3.5882408e-13]\n", 522 | "15300 [0.9999935] [1.0000063] [3.7337177e-08]\n", 523 | "15400 [0.9999972] [0.99999857] [1.7486776e-09]\n", 524 | "15500 [0.9999998] [0.9999995] [1.4530598e-12]\n", 525 | "15600 [0.9999998] [0.9999995] [1.4530598e-12]\n", 526 | "15700 [0.9999999] [0.9999997] [3.6948222e-13]\n", 527 | "15800 [1.0000002] [0.9999994] [1.1516477e-10]\n", 528 | "15900 [0.9999389] [1.0000495] [2.9504954e-06]\n", 529 | "16000 [0.9999991] [0.9999988] [3.6326497e-11]\n", 530 | "16100 [0.99999964] [0.99999976] [2.2865264e-11]\n", 531 | "16200 [0.9999999] [0.99999976] [1.4210854e-14]\n", 532 | "16300 [1.] [0.9999998] [3.1974423e-12]\n", 533 | "16400 [1.] [0.9999998] [3.1974423e-12]\n", 534 | "16500 [1.] [0.9999998] [3.1974423e-12]\n", 535 | "16600 [1.] [0.9999998] [3.1974423e-12]\n", 536 | "16700 [1.] 
[0.9999998] [3.1974423e-12]\n", 537 | "16800 [1.0002725] [0.9997275] [6.693048e-05]\n", 538 | "16900 [0.9999901] [0.99998087] [1.4088641e-10]\n", 539 | "17000 [0.99999684] [0.9999933] [2.2769342e-11]\n", 540 | "17100 [0.999999] [0.99999785] [2.4478197e-12]\n", 541 | "17200 [0.99999964] [0.9999996] [9.009681e-12]\n", 542 | "17300 [0.9999999] [0.9999996] [3.2116532e-12]\n", 543 | "17400 [0.9999999] [0.9999996] [3.2116532e-12]\n", 544 | "17500 [0.9999999] [0.9999997] [3.6948222e-13]\n", 545 | "17600 [0.9999998] [0.99999976] [1.4530598e-12]\n", 546 | "17700 [0.9999999] [0.9999997] [3.6948222e-13]\n", 547 | "17800 [0.9999999] [0.9999997] [3.6948222e-13]\n", 548 | "17900 [0.9999998] [0.99999976] [1.4530598e-12]\n", 549 | "18000 [0.9999999] [0.9999997] [3.6948222e-13]\n", 550 | "18100 [0.9999999] [0.9999997] [3.6948222e-13]\n", 551 | "18200 [0.9999998] [0.99999976] [1.4530598e-12]\n", 552 | "18300 [0.9999999] [0.9999997] [3.6948222e-13]\n", 553 | "18400 [0.9999999] [0.9999997] [3.6948222e-13]\n", 554 | "18500 [0.9999995] [1.0000001] [1.153353e-10]\n", 555 | "18600 [1.0000275] [0.99996483] [8.151099e-07]\n", 556 | "18700 [0.9999988] [0.9999985] [8.1357136e-11]\n" 557 | ] 558 | }, 559 | { 560 | "name": "stdout", 561 | "output_type": "stream", 562 | "text": [ 563 | "18800 [0.9999998] [0.9999998] [3.2294167e-12]\n", 564 | "18900 [0.9999998] [0.9999998] [3.2294167e-12]\n", 565 | "19000 [0.9999998] [0.9999998] [3.2294167e-12]\n", 566 | "19100 [0.9999998] [0.9999998] [3.2294167e-12]\n", 567 | "19200 [0.9999998] [0.9999998] [3.2294167e-12]\n", 568 | "19300 [0.9999998] [0.9999998] [3.2294167e-12]\n", 569 | "19400 [1.000078] [0.99989605] [6.7596447e-06]\n", 570 | "19500 [0.99999285] [0.9999901] [2.0495603e-09]\n", 571 | "19600 [0.99999845] [0.99999684] [2.756906e-12]\n", 572 | "19700 [0.99999964] [0.999999] [9.009681e-12]\n", 573 | "19800 [0.9999997] [0.99999934] [8.8817835e-14]\n", 574 | "19900 [0.9999997] [0.99999934] [8.8817835e-14]\n" 575 | ] 576 | } 577 | ], 578 | "source": [ 
579 | "import tensorflow as tf\n", 580 | "\n", 581 | "# initialize arrays\n", 582 | "x1_data = tf.Variable(initial_value=tf.random_uniform([1], -3, 3),name='x1')\n", 583 | "x2_data = tf.Variable(initial_value=tf.random_uniform([1], -3, 3), name='x2')\n", 584 | "\n", 585 | "# The Rosenbrock problem is defined as y = (1 - x1)^2 + 100 * (x2 - x1^2)^2, \n", 586 | "# giving the optimal solution on x1 = x2 = 1\n", 587 | "\n", 588 | "# Loss function\n", 589 | "#y = tf.add(tf.pow(tf.subtract(1.0, x1_data), 2.0), \n", 590 | "# tf.multiply(100.0, tf.pow(tf.subtract(x2_data,tf.pow(x1_data, 2.0)), 2.0)), 'y')\n", 591 | "\n", 592 | "y = (1.0-x1_data)**2 + 100.*(x2_data-x1_data**2)**2\n", 593 | "\n", 594 | "print (y)\n", 595 | "\n", 596 | "#opt = tf.train.GradientDescentOptimizer(0.0035)\n", 597 | "#train = opt.minimize(y)\n", 598 | "\n", 599 | "opt = tf.train.GradientDescentOptimizer(1e-3)\n", 600 | "opt = tf.train.AdamOptimizer(1e-3)\n", 601 | "\n", 602 | "\n", 603 | "grads_and_vars = opt.compute_gradients(y, [x1_data, x2_data])\n", 604 | "clipped_grads_and_vars = [(tf.clip_by_value(g, -1., 1.), v) for g, v in grads_and_vars]\n", 605 | "train = opt.apply_gradients(clipped_grads_and_vars)\n", 606 | "\n", 607 | "\n", 608 | "sess = tf.Session()\n", 609 | "\n", 610 | "init = tf.initialize_all_variables()\n", 611 | "sess.run(init)\n", 612 | "\n", 613 | "for step in range(20000):\n", 614 | " sess.run(train)\n", 615 | " if step % 100 == 0:\n", 616 | " print(step, sess.run(x1_data), sess.run(x2_data), sess.run(y))" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": 30, 622 | "metadata": {}, 623 | "outputs": [ 624 | { 625 | "name": "stdout", 626 | "output_type": "stream", 627 | "text": [ 628 | "10\n" 629 | ] 630 | } 631 | ], 632 | "source": [ 633 | "import tensorflow as tf\n", 634 | "\n", 635 | "i = tf.constant(0)\n", 636 | "c = lambda i: tf.less(i, 10)\n", 637 | "b = lambda i: tf.add(i, 1)\n", 638 | "r = tf.while_loop(c, b, [i])\n", 639 | "\n", 640 | "with 
tf.Session() as sess:\n", 641 | " p = sess.run(r)\n", 642 | " print (p)\n" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": 27, 648 | "metadata": {}, 649 | "outputs": [ 650 | { 651 | "name": "stdout", 652 | "output_type": "stream", 653 | "text": [ 654 | "[1. 2.]\n", 655 | "[[1. 0.6]\n", 656 | " [0.6 2. ]]\n" 657 | ] 658 | } 659 | ], 660 | "source": [ 661 | "# Let mean vector and co-variance be:\n", 662 | "mu = [1., 2] \n", 663 | "cov = [[ 1, 3/5],[ 3/5, 2]]\n", 664 | "\n", 665 | "#Multivariate Normal distribution\n", 666 | "gaussian = tf.contrib.distributions.MultivariateNormalFullCovariance(\n", 667 | " loc=mu,\n", 668 | " covariance_matrix=cov)\n", 669 | "\n", 670 | "# Generate a mesh grid to plot the distributions\n", 671 | "X, Y = tf.meshgrid(tf.range(-3, 3, 0.1), tf.range(-3, 3, 0.1))\n", 672 | "idx = tf.concat([tf.reshape(X, [-1, 1]), tf.reshape(Y,[-1,1])], axis =1)\n", 673 | "prob = tf.reshape(gaussian.prob(idx), tf.shape(X))\n", 674 | "\n", 675 | "with tf.Session() as sess:\n", 676 | " p = sess.run(prob)\n", 677 | " m, c = sess.run([gaussian.mean(), gaussian.covariance()])\n", 678 | " print (m)\n", 679 | " print (c)\n", 680 | " # m is [1., 2.]\n", 681 | " # c is [[1, 0.6], [0.6, 2]]" 682 | ] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "execution_count": 108, 687 | "metadata": {}, 688 | "outputs": [ 689 | { 690 | "name": "stdout", 691 | "output_type": "stream", 692 | "text": [ 693 | "0.1248408\n" 694 | ] 695 | } 696 | ], 697 | "source": [ 698 | "import tensorflow as tf\n", 699 | "import numpy as np\n", 700 | "tfd = tf.contrib.distributions\n", 701 | "\n", 702 | "x = tf.range(0, 1., 3,)\n", 703 | "\n", 704 | "mvn = tfd.MultivariateNormalFullCovariance(\n", 705 | " loc=[0.0, 0.0, 0.0],\n", 706 | " covariance_matrix=[[0.1, 0.0, 0.0], \n", 707 | " [0.0, 0.1, 0.0], \n", 708 | " [0.0, 0.0, 0.1]] )\n", 709 | "\n", 710 | "\n", 711 | "with tf.Session() as sess:\n", 712 | "\n", 713 | " #p = sess.run(mvn.prob([0.31580481, 0.70949184, 
0.19824165]))\n", 714 | " p = sess.run(mvn.prob([0., 0.33333333, 0.66666667]))\n", 715 | "\n", 716 | " print (p)\n", 717 | " " 718 | ] 719 | }, 720 | { 721 | "cell_type": "code", 722 | "execution_count": 133, 723 | "metadata": {}, 724 | "outputs": [ 725 | { 726 | "name": "stdout", 727 | "output_type": "stream", 728 | "text": [ 729 | "-11.92646\n", 730 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 731 | "[0.31622776 0.31622776 0.31622776 0.31622776 0.31622776 0.31622776\n", 732 | " 0.31622776 0.31622776 0.31622776 0.31622776]\n" 733 | ] 734 | } 735 | ], 736 | "source": [ 737 | "import tensorflow as tf\n", 738 | "import numpy as np\n", 739 | "tfd = tf.contrib.distributions\n", 740 | "\n", 741 | "N=10\n", 742 | "\n", 743 | "x = np.linspace(0, 1, N, endpoint=False, dtype='float32')\n", 744 | "\n", 745 | "Iden = np.identity(N)\n", 746 | "I_cov = np.array(0.1*Iden, dtype='float32')\n", 747 | "xx = x.reshape(-1,N)\n", 748 | "\n", 749 | "\n", 750 | "mvn = tfd.MultivariateNormalFullCovariance(\n", 751 | " #loc=[0.0, 0.0, 0.0],\n", 752 | " covariance_matrix=I_cov )\n", 753 | "\n", 754 | "\n", 755 | "with tf.Session() as sess:\n", 756 | "\n", 757 | " #p = sess.run(mvn.prob([0.31580481, 0.70949184, 0.19824165]))\n", 758 | " p = sess.run(mvn.log_prob(x))\n", 759 | "\n", 760 | " print (p)\n", 761 | " print ( sess.run(mvn.mean()))\n", 762 | " print ( sess.run(mvn.stddev()))" 763 | ] 764 | }, 765 | { 766 | "cell_type": "code", 767 | "execution_count": 135, 768 | "metadata": {}, 769 | "outputs": [ 770 | { 771 | "name": "stdout", 772 | "output_type": "stream", 773 | "text": [ 774 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 
0.]\n", 775 | "[0.31622776 0.31622776 0.31622776 0.31622776 0.31622776 0.31622776\n", 776 | " 0.31622776 0.31622776 0.31622776 0.31622776]\n", 777 | "-11.92646\n" 778 | ] 779 | } 780 | ], 781 | "source": [ 782 | "import tensorflow as tf\n", 783 | "import numpy as np\n", 784 | "tfd = tf.contrib.distributions\n", 785 | "\n", 786 | "#\n", 787 | "# NOTE(review): flagged as 'missing a factor somewhere, do not use this', yet the log_prob printed here (-11.92646) matches the full-covariance cell above; re-verify before relying on it\n", 788 | "#\n", 789 | "\n", 790 | "N=10\n", 791 | "\n", 792 | "x = np.linspace(0, 1, N, endpoint=False, dtype='float32')\n", 793 | "\n", 794 | "#x = [[0., 0.33333333, 0.66666667]]\n", 795 | "\n", 796 | "mvn = tfd.MultivariateNormalDiag(\n", 797 | " loc=N*[0.0],\n", 798 | " scale_identity_multiplier=0.31622776 )\n", 799 | "\n", 800 | "\n", 801 | "with tf.Session() as sess:\n", 802 | "\n", 803 | " #p = sess.run(mvn.prob([0.31580481, 0.70949184, 0.19824165]))\n", 804 | " p = sess.run(mvn.log_prob(x))\n", 805 | " print ( sess.run(mvn.mean()))\n", 806 | " print ( sess.run(mvn.stddev()))\n", 807 | " print (p)" 808 | ] 809 | }, 810 | { 811 | "cell_type": "code", 812 | "execution_count": 149, 813 | "metadata": {}, 814 | "outputs": [ 815 | { 816 | "name": "stdout", 817 | "output_type": "stream", 818 | "text": [ 819 | "[1431.8148]\n" 820 | ] 821 | } 822 | ], 823 | "source": [ 824 | "# Turn this into a helper function\n", 825 | "\n", 826 | "import tensorflow as tf\n", 827 | "import numpy as np\n", 828 | "tfd = tf.contrib.distributions\n", 829 | "\n", 830 | "def negMVN_diag(sig_dia,N,x): \n", 831 | "# N is the number of stocks\n", 832 | "# x has shape (time, stock); it is reshaped to (-1, N) below\n", 833 | " Iden = np.identity(N) # every variance equals sig_dia; covariances are zero\n", 834 | " I_cov = np.array(sig_dia*Iden, dtype='float32')\n", 835 | " x = x.reshape(-1,N)\n", 836 | " mvn = tfd.MultivariateNormalFullCovariance(\n", 837 | " covariance_matrix=I_cov )\n", 838 | "\n", 839 | "\n", 840 | " with tf.Session() as sess:\n", 841 | " p = sess.run(-mvn.log_prob(x))\n", 842 | "\n", 843 | " return(p) # return negative log
value\n", 844 | "\n", 845 | "N = 1000\n", 846 | "x = np.linspace(0, 1, N, endpoint=False, dtype='float32')\n", 847 | "\n", 848 | "print (negMVN_diag(0.1,N,x))" 849 | ] 850 | }, 851 | { 852 | "cell_type": "code", 853 | "execution_count": 172, 854 | "metadata": {}, 855 | "outputs": [ 856 | { 857 | "name": "stdout", 858 | "output_type": "stream", 859 | "text": [ 860 | "Tensor(\"add_7:0\", shape=(1,), dtype=float32)\n", 861 | "0 [-2.671099] [-1.4884] [10.805868]\n", 862 | "100 [-2.5711062] [-1.4884] [10.181693]\n", 863 | "200 [-2.4711134] [-1.4884] [9.577515]\n", 864 | "300 [-2.3711207] [-1.4884] [8.993334]\n", 865 | "400 [-2.271128] [-1.4884] [8.42915]\n", 866 | "500 [-2.1711352] [-1.4884] [7.884963]\n", 867 | "600 [-2.0711424] [-1.4884] [7.360773]\n", 868 | "700 [-1.9711462] [-1.4884] [6.856563]\n", 869 | "800 [-1.8711416] [-1.4884] [6.3723116]\n", 870 | "900 [-1.7711369] [-1.4884] [5.9080625]\n", 871 | "1000 [-1.6711322] [-1.4884] [5.4638147]\n", 872 | "1100 [-1.5711275] [-1.4884] [5.039569]\n", 873 | "1200 [-1.4711229] [-1.4884] [4.6353245]\n", 874 | "1300 [-1.3711182] [-1.4884] [4.2510824]\n", 875 | "1400 [-1.2711135] [-1.4884] [3.8868427]\n", 876 | "1500 [-1.1711088] [-1.4884] [3.5426044]\n", 877 | "1600 [-1.0711042] [-1.4884] [3.218368]\n", 878 | "1700 [-0.9711012] [-1.4884] [2.914139]\n", 879 | "1800 [-0.8711025] [-1.4884] [2.6299224]\n", 880 | "1900 [-0.7711038] [-1.4884] [2.365705]\n", 881 | "2000 [-0.6711051] [-1.4884] [2.1214871]\n", 882 | "2100 [-0.5711064] [-1.4884] [1.897269]\n", 883 | "2200 [-0.47110766] [-1.4884] [1.6930501]\n", 884 | "2300 [-0.37110895] [-1.4884] [1.508831]\n", 885 | "2400 [-0.27111024] [-1.4884] [1.3446112]\n", 886 | "2500 [-0.17111035] [-1.4884] [1.2003891]\n", 887 | "2600 [-0.07111014] [-1.4884] [1.0761667]\n", 888 | "2700 [0.02844376] [-1.4884] [0.97236526]\n", 889 | "2800 [0.11596418] [-1.4884] [0.89748347]\n", 890 | "2900 [0.1885992] [-1.4884] [0.8469703]\n", 891 | "3000 [0.24915928] [-1.4884] [0.8129211]\n", 892 | "3100 
[0.29962212] [-1.4884] [0.79015124]\n", 893 | "3200 [0.34149143] [-1.4884] [0.7751249]\n", 894 | "3300 [0.3759862] [-1.4884] [0.7653794]\n", 895 | "3400 [0.40414405] [-1.4884] [0.75918835]\n", 896 | "3500 [0.4268772] [-1.4884] [0.75534695]\n", 897 | "3600 [0.44500196] [-1.4884] [0.75302476]\n", 898 | "3700 [0.4592536] [-1.4884] [0.7516603]\n", 899 | "3800 [0.4702923] [-1.4884] [0.7508825]\n", 900 | "3900 [0.47870535] [-1.4884] [0.7504535]\n", 901 | "4000 [0.48500764] [-1.4884] [0.75022477]\n", 902 | "4100 [0.48964304] [-1.4884] [0.7501073]\n", 903 | "4200 [0.49298722] [-1.4884] [0.7500492]\n", 904 | "4300 [0.49535093] [-1.4884] [0.7500216]\n", 905 | "4400 [0.49698606] [-1.4884] [0.75000906]\n", 906 | "4500 [0.49809164] [-1.4884] [0.75000364]\n", 907 | "4600 [0.49882153] [-1.4884] [0.75000143]\n", 908 | "4700 [0.499291] [-1.4884] [0.7500005]\n", 909 | "4800 [0.49958515] [-1.4884] [0.7500001]\n", 910 | "4900 [0.49976426] [-1.4884] [0.75000006]\n", 911 | "5000 [0.49987015] [-1.4884] [0.75]\n", 912 | "5100 [0.49993077] [-1.4884] [0.75]\n", 913 | "5200 [0.49996436] [-1.4884] [0.75]\n", 914 | "5300 [0.49998233] [-1.4884] [0.75]\n", 915 | "5400 [0.49999148] [-1.4884] [0.75]\n", 916 | "5500 [0.49999592] [-1.4884] [0.75]\n", 917 | "5600 [0.49999836] [-1.4884] [0.75]\n", 918 | "5700 [0.49999836] [-1.4884] [0.75]\n", 919 | "5800 [0.49999836] [-1.4884] [0.75]\n", 920 | "5900 [0.4999984] [-1.4884] [0.75]\n", 921 | "6000 [0.49999842] [-1.4884] [0.75]\n", 922 | "6100 [0.49999854] [-1.4884] [0.75]\n", 923 | "6200 [0.49999863] [-1.4884] [0.75]\n", 924 | "6300 [0.49999866] [-1.4884] [0.75]\n", 925 | "6400 [0.49999875] [-1.4884] [0.75]\n", 926 | "6500 [0.49999878] [-1.4884] [0.75]\n", 927 | "6600 [0.49999887] [-1.4884] [0.75]\n", 928 | "6700 [0.4999989] [-1.4884] [0.75]\n", 929 | "6800 [0.499999] [-1.4884] [0.75]\n", 930 | "6900 [0.49999902] [-1.4884] [0.75]\n", 931 | "7000 [0.4999991] [-1.4884] [0.75]\n", 932 | "7100 [0.4999991] [-1.4884] [0.75]\n", 933 | "7200 [0.49999914] 
[-1.4884] [0.75]\n", 934 | "7300 [0.49999923] [-1.4884] [0.75]\n", 935 | "7400 [0.49999923] [-1.4884] [0.75]\n", 936 | "7500 [0.49999925] [-1.4884] [0.75]\n", 937 | "7600 [0.49999934] [-1.4884] [0.75]\n", 938 | "7700 [0.49999934] [-1.4884] [0.75]\n", 939 | "7800 [0.49999937] [-1.4884] [0.75]\n", 940 | "7900 [0.49999937] [-1.4884] [0.75]\n", 941 | "8000 [0.49999946] [-1.4884] [0.75]\n", 942 | "8100 [0.49999946] [-1.4884] [0.75]\n", 943 | "8200 [0.4999995] [-1.4884] [0.75]\n", 944 | "8300 [0.4999995] [-1.4884] [0.75]\n", 945 | "8400 [0.49999958] [-1.4884] [0.75]\n", 946 | "8500 [0.49999958] [-1.4884] [0.75]\n", 947 | "8600 [0.49999958] [-1.4884] [0.75]\n", 948 | "8700 [0.4999996] [-1.4884] [0.75]\n", 949 | "8800 [0.4999996] [-1.4884] [0.75]\n", 950 | "8900 [0.4999996] [-1.4884] [0.75]\n", 951 | "9000 [0.4999997] [-1.4884] [0.75]\n", 952 | "9100 [0.4999997] [-1.4884] [0.75]\n", 953 | "9200 [0.4999997] [-1.4884] [0.75]\n", 954 | "9300 [0.4999997] [-1.4884] [0.75]\n", 955 | "9400 [0.49999973] [-1.4884] [0.75]\n", 956 | "9500 [0.49999973] [-1.4884] [0.75]\n", 957 | "9600 [0.49999973] [-1.4884] [0.75]\n", 958 | "9700 [0.49999973] [-1.4884] [0.75]\n", 959 | "9800 [0.49999982] [-1.4884] [0.75]\n", 960 | "9900 [0.49999982] [-1.4884] [0.75]\n", 961 | "10000 [0.49999982] [-1.4884] [0.75]\n", 962 | "10100 [0.49999982] [-1.4884] [0.75]\n", 963 | "10200 [0.49999982] [-1.4884] [0.75]\n", 964 | "10300 [0.49999982] [-1.4884] [0.75]\n", 965 | "10400 [0.49999985] [-1.4884] [0.75]\n", 966 | "10500 [0.49999985] [-1.4884] [0.75]\n", 967 | "10600 [0.49999985] [-1.4884] [0.75]\n", 968 | "10700 [0.49999985] [-1.4884] [0.75]\n", 969 | "10800 [0.49999985] [-1.4884] [0.75]\n", 970 | "10900 [0.49999985] [-1.4884] [0.75]\n", 971 | "11000 [0.49999985] [-1.4884] [0.75]\n", 972 | "11100 [0.49999985] [-1.4884] [0.75]\n", 973 | "11200 [0.49999994] [-1.4884] [0.75]\n", 974 | "11300 [0.49999994] [-1.4884] [0.75]\n", 975 | "11400 [0.49999994] [-1.4884] [0.75]\n", 976 | "11500 [0.49999994] [-1.4884] 
[0.75]\n", 977 | "11600 [0.49999994] [-1.4884] [0.75]\n", 978 | "11700 [0.49999994] [-1.4884] [0.75]\n", 979 | "11800 [0.49999994] [-1.4884] [0.75]\n", 980 | "11900 [0.49999994] [-1.4884] [0.75]\n", 981 | "12000 [0.49999994] [-1.4884] [0.75]\n", 982 | "12100 [0.49999994] [-1.4884] [0.75]\n", 983 | "12200 [0.49999994] [-1.4884] [0.75]\n", 984 | "12300 [0.49999994] [-1.4884] [0.75]\n", 985 | "12400 [0.49999994] [-1.4884] [0.75]\n", 986 | "12500 [0.49999994] [-1.4884] [0.75]\n", 987 | "12600 [0.49999997] [-1.4884] [0.75]\n", 988 | "12700 [0.49999997] [-1.4884] [0.75]\n", 989 | "12800 [0.49999997] [-1.4884] [0.75]\n", 990 | "12900 [0.49999997] [-1.4884] [0.75]\n", 991 | "13000 [0.49999997] [-1.4884] [0.75]\n", 992 | "13100 [0.49999997] [-1.4884] [0.75]\n", 993 | "13200 [0.49999997] [-1.4884] [0.75]\n", 994 | "13300 [0.49999997] [-1.4884] [0.75]\n", 995 | "13400 [0.49999997] [-1.4884] [0.75]\n", 996 | "13500 [0.49999997] [-1.4884] [0.75]\n", 997 | "13600 [0.49999997] [-1.4884] [0.75]\n", 998 | "13700 [0.49999997] [-1.4884] [0.75]\n", 999 | "13800 [0.49999997] [-1.4884] [0.75]\n", 1000 | "13900 [0.49999997] [-1.4884] [0.75]\n", 1001 | "14000 [0.49999997] [-1.4884] [0.75]\n", 1002 | "14100 [0.49999997] [-1.4884] [0.75]\n", 1003 | "14200 [0.49999997] [-1.4884] [0.75]\n", 1004 | "14300 [0.49999997] [-1.4884] [0.75]\n", 1005 | "14400 [0.49999997] [-1.4884] [0.75]\n", 1006 | "14500 [0.49999997] [-1.4884] [0.75]\n", 1007 | "14600 [0.49999997] [-1.4884] [0.75]\n", 1008 | "14700 [0.49999997] [-1.4884] [0.75]\n", 1009 | "14800 [0.49999997] [-1.4884] [0.75]\n", 1010 | "14900 [0.49999997] [-1.4884] [0.75]\n", 1011 | "15000 [0.49999997] [-1.4884] [0.75]\n", 1012 | "15100 [0.49999997] [-1.4884] [0.75]\n", 1013 | "15200 [0.49999997] [-1.4884] [0.75]\n", 1014 | "15300 [0.49999997] [-1.4884] [0.75]\n", 1015 | "15400 [0.49999997] [-1.4884] [0.75]\n", 1016 | "15500 [0.49999997] [-1.4884] [0.75]\n", 1017 | "15600 [0.49999997] [-1.4884] [0.75]\n", 1018 | "15700 [0.49999997] [-1.4884] 
[0.75]\n", 1019 | "15800 [0.49999997] [-1.4884] [0.75]\n", 1020 | "15900 [0.49999997] [-1.4884] [0.75]\n", 1021 | "16000 [0.49999997] [-1.4884] [0.75]\n", 1022 | "16100 [0.49999997] [-1.4884] [0.75]\n", 1023 | "16200 [0.49999997] [-1.4884] [0.75]\n", 1024 | "16300 [0.49999997] [-1.4884] [0.75]\n", 1025 | "16400 [0.49999997] [-1.4884] [0.75]\n", 1026 | "16500 [0.49999997] [-1.4884] [0.75]\n", 1027 | "16600 [0.49999997] [-1.4884] [0.75]\n", 1028 | "16700 [0.49999997] [-1.4884] [0.75]\n", 1029 | "16800 [0.49999997] [-1.4884] [0.75]\n", 1030 | "16900 [0.49999997] [-1.4884] [0.75]\n", 1031 | "17000 [0.49999997] [-1.4884] [0.75]\n", 1032 | "17100 [0.49999997] [-1.4884] [0.75]\n", 1033 | "17200 [0.49999997] [-1.4884] [0.75]\n", 1034 | "17300 [0.49999997] [-1.4884] [0.75]\n", 1035 | "17400 [0.49999997] [-1.4884] [0.75]\n", 1036 | "17500 [0.49999997] [-1.4884] [0.75]\n", 1037 | "17600 [0.49999997] [-1.4884] [0.75]\n", 1038 | "17700 [0.49999997] [-1.4884] [0.75]\n", 1039 | "17800 [0.49999997] [-1.4884] [0.75]\n", 1040 | "17900 [0.49999997] [-1.4884] [0.75]\n", 1041 | "18000 [0.49999997] [-1.4884] [0.75]\n", 1042 | "18100 [0.49999997] [-1.4884] [0.75]\n", 1043 | "18200 [0.49999997] [-1.4884] [0.75]\n", 1044 | "18300 [0.49999997] [-1.4884] [0.75]\n", 1045 | "18400 [0.49999997] [-1.4884] [0.75]\n", 1046 | "18500 [0.49999997] [-1.4884] [0.75]\n", 1047 | "18600 [0.49999997] [-1.4884] [0.75]\n", 1048 | "18700 [0.49999997] [-1.4884] [0.75]\n", 1049 | "18800 [0.49999997] [-1.4884] [0.75]\n", 1050 | "18900 [0.49999997] [-1.4884] [0.75]\n", 1051 | "19000 [0.49999997] [-1.4884] [0.75]\n", 1052 | "19100 [0.49999997] [-1.4884] [0.75]\n", 1053 | "19200 [0.49999997] [-1.4884] [0.75]\n", 1054 | "19300 [0.49999997] [-1.4884] [0.75]\n", 1055 | "19400 [0.49999997] [-1.4884] [0.75]\n", 1056 | "19500 [0.49999997] [-1.4884] [0.75]\n", 1057 | "19600 [0.49999997] [-1.4884] [0.75]\n", 1058 | "19700 [0.49999997] [-1.4884] [0.75]\n", 1059 | "19800 [0.49999997] [-1.4884] [0.75]\n", 1060 | "19900 
[0.49999997] [-1.4884] [0.75]\n" 1061 | ] 1062 | } 1063 | ], 1064 | "source": [ 1065 | "import tensorflow as tf\n", 1066 | "\n", 1067 | "# initialize arrays\n", 1068 | "x1_data = tf.Variable(initial_value=tf.random_uniform([1], -3, 3),name='x1')\n", 1069 | "x2_data = tf.Variable(initial_value=tf.random_uniform([1], -3, 3), name='x2')\n", 1070 | "\n", 1071 | "# The Rosenbrock problem (commented out below) is y = (1 - x1)^2 + 100 * (x2 - x1^2)^2, with optimum x1 = x2 = 1.\n", 1072 | "# This cell instead minimizes y = (1 - x1)^2 + x1 over x1 only (optimum x1 = 0.5, y = 0.75); x2 keeps its initial value.\n", 1073 | "\n", 1074 | "# Loss function\n", 1075 | "#y = tf.add(tf.pow(tf.subtract(1.0, x1_data), 2.0), \n", 1076 | "# tf.multiply(100.0, tf.pow(tf.subtract(x2_data,tf.pow(x1_data, 2.0)), 2.0)), 'y')\n", 1077 | "\n", 1078 | "#y = (1.0-x1_data)**2 + 100.*(x2_data-x1_data**2)**2\n", 1079 | "y = (1.0-x1_data)**2 + x1_data \n", 1080 | "\n", 1081 | "print (y)\n", 1082 | "\n", 1083 | "#opt = tf.train.GradientDescentOptimizer(0.0035)\n", 1084 | "#train = opt.minimize(y)\n", 1085 | "\n", 1086 | "opt = tf.train.GradientDescentOptimizer(1e-3)\n", 1087 | "opt = tf.train.AdamOptimizer(1e-3)\n", 1088 | "\n", 1089 | "\n", 1090 | "grads_and_vars = opt.compute_gradients(y, [x1_data])\n", 1091 | "clipped_grads_and_vars = [(tf.clip_by_value(g, -1., 1.), v) for g, v in grads_and_vars]\n", 1092 | "train = opt.apply_gradients(clipped_grads_and_vars)\n", 1093 | "\n", 1094 | "\n", 1095 | "sess = tf.Session()\n", 1096 | "\n", 1097 | "init = tf.initialize_all_variables()\n", 1098 | "sess.run(init)\n", 1099 | "\n", 1100 | "for step in range(20000):\n", 1101 | " sess.run(train)\n", 1102 | " if step % 100 == 0:\n", 1103 | " print(step, sess.run(x1_data), sess.run(x2_data), sess.run(y))" 1104 | ] 1105 | }, 1106 | { 1107 | "cell_type": "code", 1108 | "execution_count": null, 1109 | "metadata": {}, 1110 | "outputs": [], 1111 | "source": [] 1112 | } 1113 | ], 1114 | "metadata": { 1115 | "kernelspec": { 1116 | "display_name": "Python 3", 1117 | "language": "python", 1118 | "name": "python3" 1119 | }, 1120 |
"language_info": { 1121 | "codemirror_mode": { 1122 | "name": "ipython", 1123 | "version": 3 1124 | }, 1125 | "file_extension": ".py", 1126 | "mimetype": "text/x-python", 1127 | "name": "python", 1128 | "nbconvert_exporter": "python", 1129 | "pygments_lexer": "ipython3", 1130 | "version": "3.6.8" 1131 | } 1132 | }, 1133 | "nbformat": 4, 1134 | "nbformat_minor": 2 1135 | } 1136 | --------------------------------------------------------------------------------