├── README.md ├── dqn ├── __init__.py ├── cart_pole.py └── logs │ └── dqn │ └── 20191214-195334 │ └── events.out.tfevents.1576382014.Vivians-MacBook-Pro-2.local.25490.32728120.v2 ├── em ├── data │ ├── labeled.csv │ └── unlabeled.csv ├── em.py ├── semi-supervised.png └── unsupervised.png ├── mcmc_gibbs ├── data │ ├── img.png │ └── img_noisy.png ├── image_denoising.py └── output │ ├── denoise_image.png │ ├── log_energy │ └── log_energy.png ├── policy_gradient ├── model │ ├── __init__.py │ ├── baseline_net.py │ ├── main.py │ ├── policy_gradient.py │ └── policy_net.py └── results │ ├── rewards.npy │ ├── rewards.png │ └── videos │ ├── openaigym.episode_batch.0.2394.stats.json │ ├── openaigym.manifest.0.2394.manifest.json │ ├── openaigym.video.0.2394.video000000.meta.json │ └── openaigym.video.0.2394.video000000.mp4 └── search ├── __init__.py ├── a_star.py ├── dijkstra.py ├── dp.py └── tree_search.py /README.md: -------------------------------------------------------------------------------- 1 | # dqn 2 | 3 | #### source code for tutorial : https://towardsdatascience.com/deep-reinforcement-learning-build-a-deep-q-network-dqn-to-play-cartpole-with-tensorflow-2-and-gym-8e105744b998 4 | 5 | 6 | # mcmc_gibbs 7 | 8 | #### source code for tutorial: https://towardsdatascience.com/image-denoising-with-gibbs-sampling-mcmc-concepts-and-code-implementation-11d42a90e153 9 | 10 | 11 | # em 12 | 13 | #### source code for tutorial: https://towardsdatascience.com/implement-expectation-maximization-em-algorithm-in-python-from-scratch-f1278d1b9137 14 | 15 | # policy_gradient 16 | 17 | #### source code for tutorial : https://towardsdatascience.com/policy-gradient-reinforce-algorithm-with-baseline-e95ace11c1c4 18 | 19 | # search 20 | 21 | #### source code for tutorial : https://siwei-xu.medium.com/search-algorithms-concepts-and-implementation-1073594aeda6 22 | -------------------------------------------------------------------------------- /dqn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VXU1230/Medium-Tutorials/bc059d131e020c9121f4c258a007f539fa364c85/dqn/__init__.py -------------------------------------------------------------------------------- /dqn/cart_pole.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import gym 4 | import os 5 | import datetime 6 | from statistics import mean 7 | from gym import wrappers 8 | 9 | 10 | class MyModel(tf.keras.Model): 11 | def __init__(self, num_states, hidden_units, num_actions): 12 | super(MyModel, self).__init__() 13 | self.input_layer = tf.keras.layers.InputLayer(input_shape=(num_states,)) 14 | self.hidden_layers = [] 15 | for i in hidden_units: 16 | self.hidden_layers.append(tf.keras.layers.Dense( 17 | i, activation='tanh', kernel_initializer='RandomNormal')) 18 | self.output_layer = tf.keras.layers.Dense( 19 | num_actions, activation='linear', kernel_initializer='RandomNormal') 20 | 21 | @tf.function 22 | def call(self, inputs): 23 | z = self.input_layer(inputs) 24 | for layer in self.hidden_layers: 25 | z = layer(z) 26 | output = self.output_layer(z) 27 | return output 28 | 29 | 30 | class DQN: 31 | def __init__(self, num_states, num_actions, hidden_units, gamma, max_experiences, min_experiences, batch_size, lr): 32 | self.num_actions = num_actions 33 | self.batch_size = batch_size 34 | self.optimizer = tf.optimizers.Adam(lr) 35 | self.gamma = gamma 36 | self.model = MyModel(num_states, 
hidden_units, num_actions) 37 | self.experience = {'s': [], 'a': [], 'r': [], 's2': [], 'done': []} 38 | self.max_experiences = max_experiences 39 | self.min_experiences = min_experiences 40 | 41 | def predict(self, inputs): 42 | return self.model(np.atleast_2d(inputs.astype('float32'))) 43 | 44 | def train(self, TargetNet): 45 | if len(self.experience['s']) < self.min_experiences: 46 | return 0 47 | ids = np.random.randint(low=0, high=len(self.experience['s']), size=self.batch_size) 48 | states = np.asarray([self.experience['s'][i] for i in ids]) 49 | actions = np.asarray([self.experience['a'][i] for i in ids]) 50 | rewards = np.asarray([self.experience['r'][i] for i in ids]) 51 | states_next = np.asarray([self.experience['s2'][i] for i in ids]) 52 | dones = np.asarray([self.experience['done'][i] for i in ids]) 53 | value_next = np.max(TargetNet.predict(states_next), axis=1) 54 | actual_values = np.where(dones, rewards, rewards+self.gamma*value_next) 55 | 56 | with tf.GradientTape() as tape: 57 | selected_action_values = tf.math.reduce_sum( 58 | self.predict(states) * tf.one_hot(actions, self.num_actions), axis=1) 59 | loss = tf.math.reduce_mean(tf.square(actual_values - selected_action_values)) 60 | variables = self.model.trainable_variables 61 | gradients = tape.gradient(loss, variables) 62 | self.optimizer.apply_gradients(zip(gradients, variables)) 63 | return loss 64 | 65 | def get_action(self, states, epsilon): 66 | if np.random.random() < epsilon: 67 | return np.random.choice(self.num_actions) 68 | else: 69 | return np.argmax(self.predict(np.atleast_2d(states))[0]) 70 | 71 | def add_experience(self, exp): 72 | if len(self.experience['s']) >= self.max_experiences: 73 | for key in self.experience.keys(): 74 | self.experience[key].pop(0) 75 | for key, value in exp.items(): 76 | self.experience[key].append(value) 77 | 78 | def copy_weights(self, TrainNet): 79 | variables1 = self.model.trainable_variables 80 | variables2 = TrainNet.model.trainable_variables 81 | for v1, v2 in zip(variables1, variables2): 82 | v1.assign(v2.numpy()) 83 | 84 | 85 | def play_game(env, TrainNet, TargetNet, epsilon, copy_step): 86 | rewards = 0 87 | iter = 0 88 | done = False 89 | observations, _ = env.reset() 90 | losses = list() 91 | while not done: 92 | action = TrainNet.get_action(observations, epsilon) 93 | prev_observations = observations 94 | observations, reward, done, _, _ = env.step(action) 95 | rewards += reward 96 | if done: 97 | reward = -200 98 | env.reset() 99 | 100 | exp = {'s': prev_observations, 'a': action, 'r': reward, 's2': observations, 'done': done} 101 | TrainNet.add_experience(exp) 102 | loss = TrainNet.train(TargetNet) 103 | if isinstance(loss, int): 104 | losses.append(loss) 105 | else: 106 | losses.append(loss.numpy()) 107 | iter += 1 108 | if iter % copy_step == 0: 109 | TargetNet.copy_weights(TrainNet) 110 | return rewards, mean(losses) 111 | 112 | def make_video(env, TrainNet): 113 | env = wrappers.Monitor(env, os.path.join(os.getcwd(), "videos"), force=True) 114 | rewards = 0 115 | steps = 0 116 | done = False 117 | observation = env.reset() 118 | while not done: 119 | env.render() 120 | action = TrainNet.get_action(observation, 0) 121 | observation, reward, done, _, _= env.step(action) 122 | steps += 1 123 | rewards += reward 124 | print("Testing steps: {} rewards {}: ".format(steps, rewards)) 125 | 126 | 127 | def main(): 128 | env = gym.make('CartPole-v0') 129 | gamma = 0.99 130 | copy_step = 25 131 | num_states = len(env.observation_space.sample()) 132 | num_actions = 
env.action_space.n 133 | hidden_units = [200, 200] 134 | max_experiences = 10000 135 | min_experiences = 100 136 | batch_size = 32 137 | lr = 1e-2 138 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 139 | log_dir = 'logs/dqn/' + current_time 140 | summary_writer = tf.summary.create_file_writer(log_dir) 141 | 142 | TrainNet = DQN(num_states, num_actions, hidden_units, gamma, max_experiences, min_experiences, batch_size, lr) 143 | TargetNet = DQN(num_states, num_actions, hidden_units, gamma, max_experiences, min_experiences, batch_size, lr) 144 | N = 50000 145 | total_rewards = np.empty(N) 146 | epsilon = 0.99 147 | decay = 0.9999 148 | min_epsilon = 0.1 149 | for n in range(N): 150 | epsilon = max(min_epsilon, epsilon * decay) 151 | total_reward, losses = play_game(env, TrainNet, TargetNet, epsilon, copy_step) 152 | total_rewards[n] = total_reward 153 | avg_rewards = total_rewards[max(0, n - 100):(n + 1)].mean() 154 | with summary_writer.as_default(): 155 | tf.summary.scalar('episode reward', total_reward, step=n) 156 | tf.summary.scalar('running avg reward(100)', avg_rewards, step=n) 157 | tf.summary.scalar('average loss)', losses, step=n) 158 | if n % 100 == 0: 159 | print("episode:", n, "episode reward:", total_reward, "eps:", epsilon, "avg reward (last 100):", avg_rewards, 160 | "episode loss: ", losses) 161 | print("avg reward for last 100 episodes:", avg_rewards) 162 | make_video(env, TrainNet) 163 | env.close() 164 | 165 | 166 | if __name__ == '__main__': 167 | for i in range(3): 168 | main() 169 | -------------------------------------------------------------------------------- /dqn/logs/dqn/20191214-195334/events.out.tfevents.1576382014.Vivians-MacBook-Pro-2.local.25490.32728120.v2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VXU1230/Medium-Tutorials/bc059d131e020c9121f4c258a007f539fa364c85/dqn/logs/dqn/20191214-195334/events.out.tfevents.1576382014.Vivians-MacBook-Pro-2.local.25490.32728120.v2 -------------------------------------------------------------------------------- /em/data/labeled.csv: -------------------------------------------------------------------------------- 1 | x1,x2,y 2 | -0.813,-1.099,0 3 | -0.414,-1.369,0 4 | -0.463,0.195,0 5 | -1.72,-2.537,0 6 | -0.188,-0.25,0 7 | -1.711,-0.508,0 8 | -0.976,-0.324,0 9 | -1.821,-3.03,0 10 | -0.307,-0.687,0 11 | -1.042,-1.173,0 12 | -1.393,-1.287,0 13 | -0.105,-1.505,0 14 | -0.696,-1.016,0 15 | -0.888,-1.385,0 16 | -0.925,-2.182,0 17 | -0.989,-0.878,0 18 | -0.401,-0.565,0 19 | -1.634,-2.338,0 20 | -2.043,-2.684,0 21 | -1.279,-1.949,0 22 | 0.884,1.261,1 23 | 1.706,1.997,1 24 | 1.225,0.943,1 25 | 1.047,0.01,1 26 | 1.461,1.554,1 27 | -1.802,-2.659,0 28 | -0.039,0.524,1 29 | 0.882,0.981,1 30 | -0.113,1.11,1 31 | 0.7,1.427,1 32 | -0.473,0.283,0 33 | 1.457,1.637,1 34 | -0.698,0.526,1 35 | 2.282,1.329,1 36 | 0.117,0.869,1 37 | 0.522,0.098,1 38 | 2.24,0.605,1 39 | 2.612,1.165,1 40 | 1.479,1.327,1 41 | 0.484,0.67,1 42 | -0.371,-0.666,0 43 | -1.398,-0.995,0 44 | -1.879,-0.916,0 45 | -1.106,-0.256,0 46 | -0.596,-0.875,0 47 | -0.705,-0.133,0 48 | 2.211,1.69,1 49 | -0.156,-0.216,0 50 | 0.958,0.706,1 51 | 0.405,-0.482,0 52 | -0.804,-2.273,0 53 | -1.205,-1.077,0 54 | -1.772,-2.231,0 55 | -1.026,-0.909,0 56 | -1.219,-0.6,0 57 | -0.994,0.228,0 58 | -0.811,-0.487,0 59 | -0.973,-0.65,0 60 | -0.939,0.63,0 61 | -1.114,-0.835,0 62 | 1.376,1.174,1 63 | 1.723,1.302,1 64 | 0.346,1.175,1 65 | 0.406,0.782,1 66 | 2.415,1.917,1 67 | -0.325,0.542,1 68 | 2.255,0.75,1 69 | 
0.428,0.398,1 70 | 1.018,0.573,1 71 | 1.136,0.484,1 72 | -0.549,0.55,1 73 | -0.843,-1.964,0 74 | 2.029,1.108,1 75 | 2.041,1.705,1 76 | 2.399,1.613,1 77 | 0.762,1.628,1 78 | 1.592,1.084,1 79 | 0.707,1.923,1 80 | 1.447,1.185,1 81 | 1.942,0.688,1 82 | 1.282,0.94,1 83 | -0.024,0.326,1 84 | -1.398,-1.954,0 85 | 1.183,1.328,1 86 | 3.092,1.03,1 87 | 1.426,1.489,1 88 | 0.704,1.07,1 89 | 0.017,0.532,1 90 | -0.12,1.159,1 91 | 0.511,0.696,1 92 | 0.428,0.298,1 93 | 0.627,1.605,1 94 | 2.317,1.126,1 95 | -1.771,-2.436,0 96 | 0.628,0.751,1 97 | -0.114,0.483,1 98 | 1.095,-0.305,1 99 | 0.491,0.604,1 100 | 1.872,1.139,1 101 | -0.174,0.628,1 102 | -------------------------------------------------------------------------------- /em/data/unlabeled.csv: -------------------------------------------------------------------------------- 1 | x1,x2 2 | -0.187,0.747 3 | 2.824,0.377 4 | 0.713,0.766 5 | 1.635,1.846 6 | 2.711,1.996 7 | 0.456,0.592 8 | 0.564,0.969 9 | 1.86,1.812 10 | -0.05,0.755 11 | 1.68,1.087 12 | 0.902,1.35 13 | 0.708,0.993 14 | 1.19,1.726 15 | -0.143,1.287 16 | -1.932,-2.018 17 | 1.297,0.519 18 | -0.173,1.022 19 | 0.822,2.264 20 | 0.4,1.207 21 | 2.204,0.826 22 | 0.525,0.744 23 | 0.701,0.942 24 | 1.556,0.623 25 | 1.144,1.195 26 | 0.202,1.772 27 | 1.431,0.981 28 | 2.18,0.096 29 | 2.359,0.991 30 | 2.292,0.98 31 | 0.422,1.892 32 | 0.557,1.4 33 | 0.951,1.225 34 | 0.135,1.648 35 | 0.856,1.218 36 | 1.898,1.818 37 | 0.676,0.222 38 | -0.062,-0.903 39 | 0.061,0.629 40 | 0.562,0.91 41 | 1.682,1.427 42 | -0.185,-0.712 43 | -1.309,-1.89 44 | -1.095,-2.306 45 | 0.369,0.521 46 | -1.228,-1.346 47 | -1.746,-1.868 48 | 1.526,0.949 49 | -1.014,0.144 50 | -1.398,-0.09 51 | -0.418,-0.362 52 | -0.993,-2.258 53 | -0.525,-0.897 54 | -0.265,-0.324 55 | -2.048,-1.228 56 | -0.565,-1.016 57 | 0.949,0.06 58 | -0.847,-1.409 59 | -0.215,-0.161 60 | 0.339,-0.545 61 | 1.863,1.799 62 | 0.596,0.921 63 | 0.734,0.138 64 | -0.365,0.603 65 | 1.902,1.448 66 | 1.155,0.739 67 | 1.969,1.346 68 | 1.859,1.83 69 | 1.321,0.187 70 | -1.434,-1.908 71 | 0.849,0.338 72 | 0.193,0.211 73 | 0.79,1.18 74 | 1.44,1.859 75 | 0.181,0.771 76 | 1.464,1.867 77 | 0.518,1.005 78 | 1.622,0.824 79 | 1.282,0.83 80 | -1.371,-2.468 81 | -1.332,-1.805 82 | -0.301,-0.815 83 | -0.546,-1.417 84 | -0.992,2.02 85 | -1.525,-0.49 86 | -1.656,-2.223 87 | -0.823,-0.781 88 | -0.256,-0.738 89 | -0.483,-0.096 90 | -0.801,-0.536 91 | -0.765,-0.702 92 | -1.423,-1.843 93 | -0.6,-0.686 94 | -1.76,-0.851 95 | -0.596,-0.018 96 | -1.583,-1.234 97 | -1.333,-1.814 98 | 0.903,-0.055 99 | 0.221,0.926 100 | -1.727,-1.149 101 | -1.212,-0.82 102 | 0.628,1.473 103 | -0.21,0.831 104 | 1.767,0.546 105 | 2.252,-0.016 106 | 1.331,1.413 107 | -0.534,0.596 108 | -1.514,-1.854 109 | 1.02,0.593 110 | 0.444,0.718 111 | 0.809,1.081 112 | 0.48,1.697 113 | -1.145,-1.29 114 | 1.488,1.332 115 | 0.532,1.222 116 | -0.211,0.54 117 | -0.213,0.563 118 | 0.7,1.417 119 | -0.318,1.073 120 | 0.416,-0.271 121 | 0.475,1.54 122 | -0.098,0.134 123 | 2.731,2.092 124 | 1.831,0.673 125 | -1.058,-0.724 126 | 0.196,-0.251 127 | 1.883,1.68 128 | 0.915,0.766 129 | -0.064,0.622 130 | 0.599,0.956 131 | 0.328,1.476 132 | 1.432,1.173 133 | 0.815,1.585 134 | 1.813,1.094 135 | 0.934,1.862 136 | 0.731,0.112 137 | 2.698,1.891 138 | 1.517,1.288 139 | -0.436,1.024 140 | 2.016,1.251 141 | 0.531,1.912 142 | 0.385,1.515 143 | 1.832,1.708 144 | 0.35,0.448 145 | -0.534,-0.59 146 | 0.987,1.691 147 | 1.905,1.115 148 | 2.044,0.804 149 | 0.257,0.425 150 | 1.913,0.963 151 | 1.005,1.79 152 | 2.155,0.738 153 | -0.717,0.655 154 | 1.24,1.415 155 | 
0.904,0.987 156 | 0.644,1.457 157 | 0.311,1.232 158 | 1.09,1.071 159 | 0.49,0.537 160 | 0.464,0.396 161 | -0.44,0.404 162 | 0.317,0.599 163 | -1.512,-1.483 164 | 2.253,1.14 165 | 1.298,0.767 166 | 0.365,0.946 167 | 1.074,1.699 168 | 0.874,1.647 169 | 0.743,0.495 170 | 0.716,1.309 171 | 1.677,2.278 172 | 1.302,1.047 173 | 1.84,0.67 174 | 2.557,1.43 175 | 0.485,1.303 176 | 0.417,1.028 177 | 1.299,0.636 178 | 0.32,1.361 179 | -0.505,0.809 180 | -0.716,0.118 181 | 1.764,0.774 182 | 0.886,0.556 183 | 1.011,1.212 184 | 1.114,1.082 185 | -0.077,0.362 186 | 2.052,1.399 187 | -0.392,-0.942 188 | 0.846,1.435 189 | 1.86,1.226 190 | 0.485,0.576 191 | 1.625,1.602 192 | 2.123,1.615 193 | 1.736,1.802 194 | 2.182,1.88 195 | 1.04,1.446 196 | 0.797,1.464 197 | 1.508,1.515 198 | 1.413,0.956 199 | -0.261,0.488 200 | 2.655,0.794 201 | 2.404,1.603 202 | 1.223,0.906 203 | -1.076,-0.568 204 | -1.247,-1.7 205 | -0.943,-3.35 206 | -1.369,-1.248 207 | 0.805,1.492 208 | -1.025,-0.535 209 | -1.025,-0.811 210 | 2.824,1.692 211 | -0.628,0.254 212 | -0.68,0.102 213 | -0.899,-0.888 214 | -0.09,-0.335 215 | -1.153,-2.056 216 | -0.906,0.589 217 | -1.356,-1.408 218 | -2.144,-2.256 219 | -0.731,-1.422 220 | -1.554,-0.996 221 | 0.873,-0.017 222 | -1.726,-1.457 223 | -0.801,-1.487 224 | -0.9,-0.676 225 | -0.549,-0.801 226 | -2.208,-2.131 227 | -0.722,0.887 228 | 0.316,0.802 229 | -1.184,-1.174 230 | -1.294,-0.485 231 | 0.897,1.042 232 | -0.057,0.7 233 | -1.673,-2.627 234 | -0.453,0.104 235 | -1.967,0.724 236 | -1.228,-0.739 237 | 1.181,0.867 238 | -0.077,-0.078 239 | -1.136,-0.145 240 | -1.452,-1.159 241 | -0.867,-0.003 242 | 0.473,1.032 243 | -1.149,-1.386 244 | -0.519,-0.996 245 | -1.689,-1.03 246 | -2.116,-3.295 247 | 1.142,1.096 248 | -0.406,-0.732 249 | -1.319,-0.805 250 | -1.805,-1.108 251 | -1.43,-1.093 252 | -1.203,-0.852 253 | 0.557,0.468 254 | -0.467,-0.35 255 | -1.272,-1.706 256 | -1.547,-1.951 257 | -0.141,-0.031 258 | -1.221,-1.262 259 | -0.997,-0.241 260 | -0.525,-0.267 261 | -1.846,-0.852 262 | 0.573,1.606 263 | 0.949,1.158 264 | 0.95,0.372 265 | -0.632,-0.288 266 | 0.517,0.618 267 | -0.792,0.3 268 | 1.523,1.123 269 | -0.617,-0.871 270 | -0.455,1.15 271 | 0.177,0.946 272 | 1.815,1.801 273 | 1.567,0.571 274 | -0.05,1.035 275 | 2.458,0.396 276 | 1.679,2.544 277 | 1.638,0.78 278 | 1.085,-0.005 279 | 0.602,1.271 280 | 0.52,0.838 281 | 0.968,1.592 282 | 0.725,0.593 283 | 0.762,1.427 284 | 0.44,0.277 285 | 2.114,1.402 286 | 0.362,1.106 287 | 1.683,1.319 288 | -0.187,1.08 289 | 0.255,1.266 290 | 0.583,0.744 291 | 2.302,1.591 292 | -1.093,-0.59 293 | 0.358,0.688 294 | -0.402,1.448 295 | -0.505,0.675 296 | 0.142,1.181 297 | -0.37,-1.204 298 | 1.638,0.859 299 | -0.995,-1.496 300 | -0.672,-1.435 301 | 2.149,1.668 302 | -0.17,-0.714 303 | 1.74,0.555 304 | -1.157,-0.149 305 | -0.444,1.656 306 | -1.757,-2.124 307 | -1.64,-0.88 308 | -1.565,-2.378 309 | -1.144,-2.218 310 | -1.332,0.566 311 | -0.086,-0.159 312 | -1.232,-0.174 313 | -1.571,-2.116 314 | -1.013,-0.964 315 | -0.561,-0.247 316 | 0.424,0.388 317 | -1.303,0.208 318 | -1.463,-1.419 319 | -1.473,-1.4 320 | 0.126,0.746 321 | -1.373,-0.308 322 | -1.491,-1.873 323 | -1.517,-1.528 324 | -1.034,-1.423 325 | -0.05,0.631 326 | -1.751,-1.859 327 | -1.412,-2.62 328 | -1.457,-1.817 329 | -1.36,-1.796 330 | -1.829,-0.932 331 | -1.375,-0.263 332 | -1.2,-0.516 333 | -1.579,-1.882 334 | -2.221,-2.028 335 | 0.177,0.517 336 | -0.026,0.05 337 | -1.277,-1.589 338 | -0.028,0.59 339 | 1.339,1.907 340 | -1.186,-2.24 341 | -0.276,-0.796 342 | 2.61,1.938 343 | 0.231,0.503 344 | 2.025,1.198 345 
| 1.345,2.075 346 | -1.041,-0.351 347 | 1.498,0.643 348 | 0.3,1.63 349 | 0.922,2.148 350 | -0.72,0.096 351 | 1.161,0.065 352 | 0.71,1.236 353 | 1.482,1.754 354 | 1.32,0.768 355 | 0.387,0.948 356 | 0.029,1.267 357 | 1.372,1.34 358 | 0.613,0.798 359 | 0.588,1.863 360 | 0.066,0.718 361 | 0.712,0.859 362 | 1.504,0.691 363 | -0.578,0.993 364 | 1.138,1.312 365 | 0.482,1.175 366 | 2.28,0.876 367 | 0.952,0.93 368 | 1.784,0.999 369 | -0.378,1.339 370 | 1.773,1.544 371 | 2.106,1.412 372 | 1.581,2.059 373 | 0.313,-0.877 374 | 1.766,1.451 375 | 0.133,1.401 376 | 1.557,1.294 377 | 0.745,0.729 378 | 2.622,0.956 379 | 1.729,1.471 380 | 0.388,0.15 381 | 0.395,1.335 382 | 1.245,0.719 383 | 1.875,1.164 384 | -0.414,0.881 385 | 0.449,-0.437 386 | 0.14,1.216 387 | 0.337,0.172 388 | 0.615,1.084 389 | 2.374,1.646 390 | 1.759,1.211 391 | -2.189,-2.729 392 | 1.403,1.553 393 | 1.562,0.634 394 | 1.383,0.603 395 | -0.108,0.823 396 | 0.965,0.858 397 | 0.335,0.233 398 | 1.57,1.74 399 | 0.656,0.786 400 | 2.095,1.218 401 | 0.121,1.461 402 | -2.117,-1.86 403 | -0.151,-0.207 404 | -1.223,-0.679 405 | 2.183,0.861 406 | -1.344,-1.733 407 | -1.428,-1.878 408 | -0.647,-0.555 409 | -1.183,-1.435 410 | -2.447,-2.718 411 | -2.538,-2.814 412 | -0.267,-1.388 413 | -0.482,-0.622 414 | -0.589,-0.096 415 | -1.579,-2.151 416 | 0.768,1.32 417 | -1.015,-0.637 418 | -1.096,-0.395 419 | -0.35,-0.477 420 | -1.27,-1.415 421 | -0.088,0.014 422 | -1.33,-1.197 423 | -0.251,1.059 424 | 0.963,0.533 425 | 2.156,0.179 426 | 0.474,0.65 427 | 0.728,1.192 428 | 1.491,0.775 429 | 1.665,1.348 430 | 1.112,1.111 431 | 1.117,1.748 432 | 0.946,1.123 433 | 1.125,0.322 434 | 1.429,0.433 435 | 1.538,1.181 436 | 0.788,0.977 437 | 0.662,0.363 438 | 2.211,1.22 439 | 1.572,1.83 440 | 0.819,1.313 441 | 2.091,1.612 442 | 1.328,0.059 443 | -0.253,0.63 444 | 1.622,0.915 445 | 1.459,0.501 446 | 0.659,1.407 447 | -0.824,-1.023 448 | 0.934,1.499 449 | -0.124,1.555 450 | -1.125,-1.285 451 | -0.812,0.423 452 | 0.277,1.525 453 | 0.486,0.723 454 | 1.097,1.433 455 | -0.132,1.094 456 | 2.736,3.031 457 | 0.335,1.617 458 | 0.481,0.575 459 | -0.922,-1.13 460 | -0.033,0.403 461 | -0.047,0.642 462 | -0.302,0.994 463 | -1.451,-1.536 464 | -0.235,0.01 465 | -1.148,-1.385 466 | -0.688,-0.79 467 | -1.543,-0.796 468 | -1.246,-1.885 469 | -1.419,-1.068 470 | -0.893,-0.547 471 | -1.063,-0.412 472 | -1.46,-1.945 473 | -0.568,0.669 474 | -1.191,-1.417 475 | -0.273,0.426 476 | -1.272,-2.131 477 | 0.639,0.822 478 | 1.592,1.268 479 | 1.752,0.766 480 | -1.884,-1.501 481 | -0.628,-0.872 482 | -0.873,-0.585 483 | 0.152,1.247 484 | 0.965,1.075 485 | 0.685,0.52 486 | 0.092,0.453 487 | -0.281,1.242 488 | 0.514,0.313 489 | 1.887,-0.15 490 | 2.024,2.278 491 | 1.319,1.058 492 | -1.326,-1.57 493 | 1.599,0.195 494 | 1.533,0.877 495 | 2.597,1.451 496 | -0.808,-1.18 497 | 2.427,0.966 498 | 3.09,1.039 499 | 1.188,0.988 500 | 1.442,1.582 501 | 0.991,0.895 502 | -0.492,-1.016 503 | -0.346,-0.149 504 | -1.258,-1.696 505 | 0.664,0.115 506 | -0.77,-0.425 507 | -0.573,-0.764 508 | 2.717,0.698 509 | -1.315,0.775 510 | -0.632,-0.596 511 | -0.42,-0.638 512 | -1.741,-2.162 513 | -0.879,0.223 514 | 0.844,0.903 515 | -0.644,-0.193 516 | 0.052,0.132 517 | -1.606,-1.841 518 | -1.278,-1.433 519 | 0.519,0.506 520 | -1.132,-2.419 521 | 0.061,-0.626 522 | -0.88,-0.608 523 | -0.104,0.209 524 | -0.304,0.33 525 | -0.266,-0.915 526 | -0.586,0.296 527 | -0.701,-0.893 528 | -1.455,-2.094 529 | -0.747,-0.18 530 | -1.546,-1.986 531 | -0.986,-1.269 532 | 0.371,1.277 533 | -0.949,0.588 534 | -0.478,-1.063 535 | -0.86,-1.85 536 | 
-1.382,-1.101 537 | -0.439,-1.38 538 | -0.738,0.488 539 | -1.037,0.035 540 | -1.801,-2.948 541 | -0.883,-1.312 542 | -1.557,-1.667 543 | -0.635,-1.203 544 | -1.38,-2.095 545 | -1.28,-3.374 546 | -0.904,-0.662 547 | -0.971,-0.517 548 | -0.995,-0.317 549 | -1.497,-1.981 550 | -0.666,-0.906 551 | 0.939,-0.703 552 | -1.417,-2.091 553 | -0.875,-0.78 554 | -0.483,-1.61 555 | -0.978,-0.968 556 | -0.434,-1.16 557 | 0.709,1.391 558 | -1.199,-0.639 559 | 2.929,2.054 560 | -2.103,-2.37 561 | 0.13,0.445 562 | 0.251,0.135 563 | 0.94,0.682 564 | 0.896,0.759 565 | 2.363,2.292 566 | 0.408,1.248 567 | -0.867,0.748 568 | 1.132,1.015 569 | 3.357,1.514 570 | 1.594,0.901 571 | 0.649,1.4 572 | 1.845,0.504 573 | 1.412,1.303 574 | 1.26,0.992 575 | 0.597,1.616 576 | 1.203,1.656 577 | 1.434,1.123 578 | 1.214,0.894 579 | 1.507,1.715 580 | -0.893,-1.775 581 | 1.255,0.756 582 | -1.857,-2.568 583 | -0.941,-1.65 584 | -0.694,-0.028 585 | -1.019,-2.401 586 | -1.362,0.044 587 | 0.66,0.966 588 | -1.132,-0.621 589 | -1.804,-3.099 590 | -1.259,-1.85 591 | -0.331,-0.146 592 | -2.46,-1.629 593 | -2.219,-3.086 594 | -0.01,1.385 595 | 0.279,1.024 596 | -0.899,-2.38 597 | -1.099,0.044 598 | -1.125,-1.059 599 | -1.545,-1.876 600 | 1.56,-0.08 601 | -1.011,-0.538 602 | 2.289,1.461 603 | 1.292,0.574 604 | 1.33,0.237 605 | 1.373,1.072 606 | 3.504,1.451 607 | 1.162,1.155 608 | 0.409,0.183 609 | 1.639,1.575 610 | 0.977,1.383 611 | 0.88,0.817 612 | 2.303,1.029 613 | 0.758,0.582 614 | -0.517,1.118 615 | 1.027,0.948 616 | 1.343,2.522 617 | 0.495,0.666 618 | 0.774,1.255 619 | 1.146,2.38 620 | 1.926,1.233 621 | 1.835,0.598 622 | 0.275,0.876 623 | 2.78,1.489 624 | -2.113,-0.64 625 | 0.399,1.308 626 | 0.55,1.079 627 | 1.009,1.123 628 | 1.369,0.75 629 | 0.811,1.17 630 | 2.184,1.259 631 | 1.711,0.985 632 | 1.299,0.519 633 | -0.193,0.286 634 | 1.1,0.21 635 | -1.236,0.333 636 | 1.473,0.945 637 | 1.608,0.467 638 | 0.854,0.833 639 | 0.68,0.356 640 | 1.173,0.949 641 | -0.127,0.994 642 | 1.887,1.799 643 | 1.408,1.674 644 | 0.41,0.763 645 | 0.965,0.989 646 | 0.136,0.924 647 | 0.783,0.846 648 | 2.097,1.709 649 | 0.235,0.372 650 | 2.349,1.381 651 | 0.622,1.759 652 | 1.203,1.503 653 | 0.178,0.599 654 | 1.816,-0.02 655 | -1.327,-1.7 656 | -1.536,-1.919 657 | 1.621,1.83 658 | 1.621,1.758 659 | -0.488,0.567 660 | 1.702,1.015 661 | 1.467,1.201 662 | -1.695,-1.85 663 | -0.247,-1.407 664 | -1.555,-0.458 665 | 0.43,0.735 666 | 1.749,1.207 667 | -0.736,-0.667 668 | -0.663,-0.476 669 | -0.815,-0.309 670 | -1.349,-2.17 671 | -0.783,-0.214 672 | -1.244,-1.805 673 | -0.934,-1.16 674 | 1.07,1.595 675 | -1.869,-1.44 676 | 0.723,0.682 677 | -1.331,-0.797 678 | -0.701,-0.068 679 | 0.108,0.39 680 | -0.748,-1.421 681 | -1.117,-1.077 682 | -1.156,1.58 683 | 2.501,0.475 684 | 1.308,1.919 685 | 0.632,1.019 686 | 0.759,-0.137 687 | 1.981,0.196 688 | 0.603,0.28 689 | 1.974,1.396 690 | 0.911,0.216 691 | 1.574,1.979 692 | 1.963,1.105 693 | 1.171,0.156 694 | -1.109,0.107 695 | 1.608,0.742 696 | 1.68,1.614 697 | -1.823,-1.957 698 | 0.372,0.993 699 | 1.127,1.173 700 | -0.097,1.186 701 | -0.837,-0.159 702 | -2.063,-1.063 703 | -2.816,-3.01 704 | -0.034,0.076 705 | 1.619,0.815 706 | -0.573,-0.563 707 | -1.972,-1.294 708 | -1.044,0.091 709 | -2.025,-2.531 710 | -0.96,-0.143 711 | 1.791,1.299 712 | -1.495,-1.622 713 | -0.625,-0.154 714 | -1.157,-1.827 715 | -0.264,0.184 716 | -1.191,-0.686 717 | -0.493,0.165 718 | 0.27,1.736 719 | -1.005,-1.302 720 | -1.032,-0.804 721 | -1.136,-0.399 722 | 1.278,0.209 723 | 0.924,1.491 724 | 0.678,0.713 725 | -0.854,0.315 726 | 1.35,1.82 727 | 
-0.918,-0.665 728 | 1.364,0.983 729 | 2.12,1.773 730 | -0.09,1.438 731 | -0.221,1.302 732 | 1.194,0.565 733 | 2.271,0.601 734 | 1.003,1.173 735 | 1.015,1.307 736 | 1.569,0.63 737 | 0.293,1.237 738 | 1.215,1.312 739 | -0.239,0.523 740 | 2.541,0.918 741 | 1.957,1.479 742 | -1.528,-2.415 743 | -0.658,-1.396 744 | -1.17,-1.296 745 | -1.529,-1.008 746 | -0.622,0.159 747 | -0.184,-0.552 748 | 1.287,1.015 749 | -0.648,-0.544 750 | -0.43,-0.491 751 | -0.89,-0.826 752 | -1.28,-1.255 753 | -1.021,-0.451 754 | 0.058,0.777 755 | -1.755,-1.789 756 | -2.189,-2.485 757 | 0.215,0.82 758 | -1.256,-1.002 759 | -1.542,-0.778 760 | -1.563,-0.823 761 | 0.56,0.907 762 | -0.728,-0.427 763 | -0.854,-2.31 764 | -1.438,-0.973 765 | -1.901,-2.122 766 | -0.389,-0.516 767 | -1.68,-1.135 768 | -1.441,-0.932 769 | -1.422,-0.674 770 | 0.862,0.554 771 | -1.544,-0.293 772 | -0.755,-1.438 773 | -1.601,-1.193 774 | -1.132,-1.228 775 | -0.307,0.556 776 | -1.379,-2.346 777 | -0.246,-0.322 778 | 1.881,0.658 779 | -0.797,0.373 780 | 0.499,1.758 781 | -2.761,-1.969 782 | 1.233,0.75 783 | 1.79,0.867 784 | 0.14,0.386 785 | -0.455,0.848 786 | 0.087,0.204 787 | 2.892,1.44 788 | 0.675,1.3 789 | 0.629,0.418 790 | -0.536,-1.65 791 | 1.051,1.179 792 | 1.323,1.201 793 | 1.23,1.102 794 | 0.795,1.104 795 | -1.284,-0.779 796 | 0.274,1.04 797 | 0.811,1.666 798 | 1.008,0.165 799 | -1.735,-2.124 800 | 2.096,1.774 801 | 1.916,0.591 802 | -1.91,-1.69 803 | 0.202,-1.016 804 | -1.533,-1.946 805 | -1.373,-1.949 806 | -0.875,-1.036 807 | -0.59,0.495 808 | -0.718,-1.228 809 | 0.974,0.645 810 | -0.984,-2.241 811 | -0.765,-0.984 812 | -2.067,-2.419 813 | -0.87,-0.432 814 | -1.781,-0.88 815 | -0.162,-0.921 816 | 0.613,0.801 817 | -1.413,-0.235 818 | 0.764,0.912 819 | -0.835,-0.589 820 | -1.778,-0.351 821 | -1.172,-1.493 822 | 1.094,1.05 823 | 1.551,1.165 824 | 0.69,0.116 825 | 0.362,0.216 826 | 0.615,2.801 827 | 1.277,0.774 828 | 1.653,1.452 829 | 0.226,0.879 830 | 0.614,0.698 831 | 0.513,0.617 832 | -1.182,-2.494 833 | 1.404,0.131 834 | -0.965,-1.452 835 | 1.252,1.475 836 | 2.224,0.871 837 | 1.48,1.982 838 | 1.065,0.533 839 | 0.022,0.35 840 | 0.228,1.417 841 | 0.241,-0.711 842 | 1.942,1.558 843 | 0.475,0.03 844 | 0.975,0.739 845 | 0.296,1.167 846 | 2.33,0.836 847 | 1.124,1.745 848 | 0.721,0.364 849 | 1.668,1.05 850 | 1.433,1.491 851 | 0.837,1.277 852 | 0.599,0.208 853 | 0.38,0.043 854 | 0.971,1.013 855 | -0.216,0.497 856 | 1.616,1.062 857 | 1.324,1.473 858 | -1.28,-1.395 859 | -0.115,1.095 860 | 1.363,0.111 861 | 0.558,1.38 862 | -0.749,-1.522 863 | -1.061,-1.162 864 | 0.953,1.051 865 | 0.558,0.852 866 | -1.36,-0.957 867 | -0.747,-0.668 868 | -1.356,-0.68 869 | -0.638,0.037 870 | -0.322,-1.495 871 | -1.201,-1.053 872 | -1.582,-1.742 873 | -0.165,-0.442 874 | -0.525,-1.638 875 | -0.262,-1.038 876 | -0.784,-0.359 877 | -0.258,-1.022 878 | -0.961,0.196 879 | 3.293,1.852 880 | -0.767,-0.298 881 | -1.317,-1.127 882 | 0.635,0.751 883 | 1.738,1.334 884 | 1.585,0.813 885 | 1.249,0.842 886 | 0.399,-0.543 887 | -1.389,-0.999 888 | 0.282,-0.171 889 | 1.925,0.965 890 | 0.439,0.994 891 | 1.601,0.288 892 | 2.182,1.264 893 | -0.252,1.215 894 | 1.678,0.954 895 | 0.926,1.055 896 | 2.238,1.245 897 | 0.651,0.256 898 | 0.321,0.351 899 | 0.971,1.513 900 | 1.959,0.991 901 | -1.315,-0.866 902 | -1.534,-0.217 903 | 0.21,-0.317 904 | -0.242,-1.212 905 | -0.106,0.082 906 | -1.306,-0.853 907 | -0.865,-1.598 908 | -0.567,0.204 909 | -1.556,-0.597 910 | -0.301,-0.132 911 | -1.235,-1.504 912 | -1.432,-0.355 913 | -1.132,-0.25 914 | 0.496,1.581 915 | -1.687,-0.707 916 | 0.38,-0.064 
917 | 0.184,0.697 918 | 1.107,0.63 919 | -2.474,-1.682 920 | -0.591,-0.522 921 | 0.037,0.58 922 | 1.439,0.841 923 | 1.665,1.175 924 | 1.7,1.064 925 | -1.764,-3.05 926 | -1.322,0.206 927 | 0.953,0.835 928 | -1.416,-1.99 929 | 0.635,1.251 930 | -1.046,-2.126 931 | 1.968,0.815 932 | -1.033,-1.129 933 | 1.676,1.902 934 | 0.545,1.384 935 | 0.059,1.02 936 | 2.251,1.96 937 | 1.25,1.167 938 | 0.877,0.305 939 | 0.295,0.153 940 | 2.48,1.292 941 | 0.644,1.163 942 | 0.839,0.973 943 | -0.228,-0.993 944 | -1.553,-1.077 945 | 0.198,0.817 946 | -1.721,-2.312 947 | -0.058,-0.525 948 | -0.958,-0.936 949 | -1.459,-2.109 950 | -1.433,-2.247 951 | -0.612,-1.666 952 | -0.92,-1.239 953 | -1.406,-1.601 954 | -1.499,-1.942 955 | -0.475,-0.312 956 | 1.657,0.74 957 | -0.376,-1.333 958 | 0.346,0.993 959 | -0.951,-1.032 960 | -0.772,-1.798 961 | -1.012,-1.324 962 | 2.08,0.463 963 | 1.374,1.335 964 | -0.335,0.135 965 | 2.563,0.461 966 | 0.49,1.365 967 | 1.075,0.911 968 | -0.039,1.759 969 | 1.041,0.831 970 | 0.631,0.969 971 | 2.209,1.25 972 | 1.385,0.087 973 | -0.094,0.373 974 | -1.891,-1.577 975 | -2.52,-1.332 976 | 0.785,1.642 977 | 0.969,1.151 978 | 1.109,0.503 979 | 2.11,-0.01 980 | -0.803,-1.465 981 | -1.459,-0.778 982 | -0.848,-0.06 983 | -1.841,-0.479 984 | -1.618,-0.005 985 | 2.465,1.516 986 | 0.062,1.159 987 | -1.513,0.699 988 | -1.784,-2.442 989 | -0.351,-1.648 990 | 0.172,1.615 991 | -1.466,-1.103 992 | -0.79,-1.465 993 | 0.774,1.171 994 | 0.506,1.064 995 | -0.163,0.938 996 | -1.114,-0.61 997 | 1.99,1.367 998 | -1.557,-1.412 999 | -1.033,-0.726 1000 | -0.195,0.531 1001 | -2.534,-1.977 1002 | -------------------------------------------------------------------------------- /em/em.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from scipy import stats 4 | from scipy.special import logsumexp 5 | from sklearn.mixture import GaussianMixture 6 | from matplotlib import pyplot as plt 7 | 8 | 9 | def GMM_sklearn(x, weights=None, means=None, covariances=None): 10 | model = GaussianMixture(n_components=2, 11 | covariance_type='full', 12 | tol=0.01, 13 | max_iter=1000, 14 | weights_init=weights, 15 | means_init=means, 16 | precisions_init=covariances) 17 | model.fit(x) 18 | print("\nscikit learn:\n\tphi: %s\n\tmu_0: %s\n\tmu_1: %s\n\tsigma_0: %s\n\tsigma_1: %s" 19 | % (model.weights_[1], model.means_[0, :], model.means_[1, :], model.covariances_[0, :], model.covariances_[1, :])) 20 | return model.predict(x), model.predict_proba(x)[:,1] 21 | 22 | 23 | def get_random_psd(n): 24 | x = np.random.normal(0, 1, size=(n, n)) 25 | return np.dot(x, x.transpose()) 26 | 27 | 28 | def initialize_random_params(): 29 | params = {'phi': np.random.uniform(0, 1), 30 | 'mu0': np.random.normal(0, 1, size=(2,)), 31 | 'mu1': np.random.normal(0, 1, size=(2,)), 32 | 'sigma0': get_random_psd(2), 33 | 'sigma1': get_random_psd(2)} 34 | return params 35 | 36 | 37 | def learn_params(x_labeled, y_labeled): 38 | n = x_labeled.shape[0] 39 | phi = x_labeled[y_labeled == 1].shape[0] / n 40 | mu0 = np.sum(x_labeled[y_labeled == 0], axis=0) / x_labeled[y_labeled == 0].shape[0] 41 | mu1 = np.sum(x_labeled[y_labeled == 1], axis=0) / x_labeled[y_labeled == 1].shape[0] 42 | sigma0 = np.cov(x_labeled[y_labeled == 0].T, bias= True) 43 | sigma1 = np.cov(x_labeled[y_labeled == 1].T, bias=True) 44 | return {'phi': phi, 'mu0': mu0, 'mu1': mu1, 'sigma0': sigma0, 'sigma1': sigma1} 45 | 46 | 47 | def e_step(x, params): 48 | np.log([stats.multivariate_normal(params["mu0"], 
params["sigma0"]).pdf(x), 49 | stats.multivariate_normal(params["mu1"], params["sigma1"]).pdf(x)]) 50 | log_p_y_x = np.log([1-params["phi"], params["phi"]])[np.newaxis, ...] + \ 51 | np.log([stats.multivariate_normal(params["mu0"], params["sigma0"]).pdf(x), 52 | stats.multivariate_normal(params["mu1"], params["sigma1"]).pdf(x)]).T 53 | log_p_y_x_norm = logsumexp(log_p_y_x, axis=1) 54 | return log_p_y_x_norm, np.exp(log_p_y_x - log_p_y_x_norm[..., np.newaxis]) 55 | 56 | 57 | def m_step(x, params): 58 | total_count = x.shape[0] 59 | _, heuristics = e_step(x, params) 60 | heuristic0 = heuristics[:, 0] 61 | heuristic1 = heuristics[:, 1] 62 | sum_heuristic1 = np.sum(heuristic1) 63 | sum_heuristic0 = np.sum(heuristic0) 64 | phi = (sum_heuristic1/total_count) 65 | mu0 = (heuristic0[..., np.newaxis].T.dot(x)/sum_heuristic0).flatten() 66 | mu1 = (heuristic1[..., np.newaxis].T.dot(x)/sum_heuristic1).flatten() 67 | diff0 = x - mu0 68 | sigma0 = diff0.T.dot(diff0 * heuristic0[..., np.newaxis]) / sum_heuristic0 69 | diff1 = x - mu1 70 | sigma1 = diff1.T.dot(diff1 * heuristic1[..., np.newaxis]) / sum_heuristic1 71 | params = {'phi': phi, 'mu0': mu0, 'mu1': mu1, 'sigma0': sigma0, 'sigma1': sigma1} 72 | return params 73 | 74 | 75 | def get_avg_log_likelihood(x, params): 76 | loglikelihood, _ = e_step(x, params) 77 | return np.mean(loglikelihood) 78 | 79 | 80 | def run_em(x, params): 81 | avg_loglikelihoods = [] 82 | while True: 83 | avg_loglikelihood = get_avg_log_likelihood(x, params) 84 | avg_loglikelihoods.append(avg_loglikelihood) 85 | if len(avg_loglikelihoods) > 2 and abs(avg_loglikelihoods[-1] - avg_loglikelihoods[-2]) < 0.0001: 86 | break 87 | params = m_step(x_unlabeled, params) 88 | print("\tphi: %s\n\tmu_0: %s\n\tmu_1: %s\n\tsigma_0: %s\n\tsigma_1: %s" 89 | % (params['phi'], params['mu0'], params['mu1'], params['sigma0'], params['sigma1'])) 90 | _, posterior = e_step(x_unlabeled, params) 91 | forecasts = np.argmax(posterior, axis=1) 92 | return forecasts, posterior, avg_loglikelihoods 93 | 94 | 95 | # def unsupervised_gmm(x_unlabeled): 96 | # params = initialize_random_params() 97 | # weights = [1 - params["phi"], params["phi"]] 98 | # means = [params["mu0"], params["mu1"]] 99 | # covariances = [params["sigma0"], params["sigma1"]] 100 | # sklearn_forecasts, posterior_sklearn = GMM_sklearn(x_unlabeled, weights, means, covariances) 101 | # forecasts, posterior, loglikelihoods = run_em(x_unlabeled, params) 102 | # print("total steps: ", len(loglikelihoods)) 103 | # plt.plot(loglikelihoods) 104 | # plt.title("unsupervised log likelihoods") 105 | # plt.savefig("unsupervised.png") 106 | # plt.close() 107 | # return pd.DataFrame({'forecasts': forecasts, 'posterior': posterior[:,1], 108 | # 'sklearn_forecasts': sklearn_forecasts, 109 | # 'posterior_sklearn': posterior_sklearn}) 110 | 111 | 112 | # def semi_supervised_gmm(x_unlabeled): 113 | # data_labeled = pd.read_csv("data/labeled.csv") 114 | # x_labeled = data_labeled[["x1", "x2"]].values 115 | # y_labeled = data_labeled["y"].values 116 | # params = learn_params(x_labeled, y_labeled) 117 | # weights = [1 - params["phi"], params["phi"]] 118 | # means = [params["mu0"], params["mu1"]] 119 | # covariances = [params["sigma0"], params["sigma1"]] 120 | # sklearn_forecasts, posterior_sklearn = GMM_sklearn(x_unlabeled, weights, means, covariances) 121 | # forecasts, posterior, loglikelihoods = run_em(x_unlabeled, params) 122 | # print("total steps: ", len(loglikelihoods)) 123 | # plt.plot(loglikelihoods) 124 | # plt.title("semi-supervised log likelihoods") 
125 | #     plt.savefig("semi-supervised.png") 126 | #     return pd.DataFrame({'forecasts': forecasts, 'posterior': posterior[:, 1], 127 | #                          'sklearn_forecasts': sklearn_forecasts, 128 | #                          'posterior_sklearn': posterior_sklearn}) 129 | 130 | 131 | 132 | if __name__ == '__main__': 133 |     data_unlabeled = pd.read_csv("data/unlabeled.csv") 134 |     x_unlabeled = data_unlabeled[["x1", "x2"]].values 135 | 136 |     # Unsupervised learning 137 |     print("unsupervised: ") 138 |     random_params = initialize_random_params() 139 |     unsupervised_forecasts, unsupervised_posterior, unsupervised_loglikelihoods = run_em(x_unlabeled, random_params) 140 |     print("total steps: ", len(unsupervised_loglikelihoods)) 141 |     plt.plot(unsupervised_loglikelihoods) 142 |     plt.title("unsupervised log likelihoods") 143 |     plt.savefig("unsupervised.png") 144 |     plt.close() 145 | 146 |     # Semi-supervised learning 147 |     print("\nsemi-supervised: ") 148 |     data_labeled = pd.read_csv("data/labeled.csv") 149 |     x_labeled = data_labeled[["x1", "x2"]].values 150 |     y_labeled = data_labeled["y"].values 151 |     learned_params = learn_params(x_labeled, y_labeled) 152 |     semisupervised_forecasts, semisupervised_posterior, semisupervised_loglikelihoods = run_em(x_unlabeled, learned_params) 153 |     print("total steps: ", len(semisupervised_loglikelihoods)) 154 |     plt.plot(semisupervised_loglikelihoods) 155 |     plt.title("semi-supervised log likelihoods") 156 |     plt.savefig("semi-supervised.png") 157 | 158 |     # Compare the forecasts with the scikit-learn API 159 |     learned_params = learn_params(x_labeled, y_labeled) 160 |     weights = [1 - learned_params["phi"], learned_params["phi"]] 161 |     means = [learned_params["mu0"], learned_params["mu1"]] 162 |     covariances = [learned_params["sigma0"], learned_params["sigma1"]] 163 |     sklearn_forecasts, posterior_sklearn = GMM_sklearn(x_unlabeled, weights, means, covariances) 164 | 165 |     output_df = pd.DataFrame({'semisupervised_forecasts': semisupervised_forecasts, 166 |                               'semisupervised_posterior': semisupervised_posterior[:, 1], 167 |                               'sklearn_forecasts': sklearn_forecasts, 168 |                               'posterior_sklearn': posterior_sklearn}) 169 | 170 |     print("\n%s%% of forecasts matched."
% (output_df[output_df["semisupervised_forecasts"] == output_df["sklearn_forecasts"]].shape[0] /output_df.shape[0] * 100)) 171 | 172 | -------------------------------------------------------------------------------- /em/semi-supervised.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VXU1230/Medium-Tutorials/bc059d131e020c9121f4c258a007f539fa364c85/em/semi-supervised.png -------------------------------------------------------------------------------- /em/unsupervised.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VXU1230/Medium-Tutorials/bc059d131e020c9121f4c258a007f539fa364c85/em/unsupervised.png -------------------------------------------------------------------------------- /mcmc_gibbs/data/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VXU1230/Medium-Tutorials/bc059d131e020c9121f4c258a007f539fa364c85/mcmc_gibbs/data/img.png -------------------------------------------------------------------------------- /mcmc_gibbs/data/img_noisy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VXU1230/Medium-Tutorials/bc059d131e020c9121f4c258a007f539fa364c85/mcmc_gibbs/data/img_noisy.png -------------------------------------------------------------------------------- /mcmc_gibbs/image_denoising.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def load_image(filename): 7 | my_img = plt.imread(filename) 8 | img_gray = np.dot(my_img[..., :3], [0.2989, 0.5870, 0.1140]) 9 | img_gray = np.where(img_gray > 0.5, 1, -1) 10 | img_padded = np.zeros([img_gray.shape[0] + 2, img_gray.shape[1] + 2]) 11 | img_padded[1:-1, 1:-1] = img_gray 12 | return img_padded 13 | 14 | 15 | def sample_y(i, j, Y, X): 16 | markov_blanket = [Y[i - 1, j], Y[i, j - 1], Y[i, j + 1], Y[i + 1, j], X[i, j]] 17 | w = ITA * markov_blanket[-1] + BETA * sum(markov_blanket[:4]) 18 | prob = 1 / (1 + math.exp(-2*w)) 19 | return (np.random.rand() < prob) * 2 - 1 20 | 21 | 22 | def get_posterior(filename, burn_in_steps, total_samples, logfile): 23 | X = load_image(filename) 24 | posterior = np.zeros(X.shape) 25 | print(X.shape) 26 | Y = np.random.choice([1, -1], size=X.shape) 27 | energy_list = list() 28 | for step in range(burn_in_steps + total_samples): 29 | for i in range(1, Y.shape[0]-1): 30 | for j in range(1, Y.shape[1]-1): 31 | y = sample_y(i, j, Y, X) 32 | Y[i, j] = y 33 | if y == 1 and step >= burn_in_steps: 34 | posterior[i, j] += 1 35 | energy = -np.sum(np.multiply(Y, X))*ITA-(np.sum(np.multiply(Y[:-1], Y[1:]))+np.sum(np.multiply(Y[:, :-1], Y[:, 1:])))*BETA 36 | if step < burn_in_steps: 37 | energy_list.append(str(step) + "\t" + str(energy) + "\tB") 38 | else: 39 | energy_list.append(str(step) + "\t" + str(energy) + "\tS") 40 | posterior = posterior / total_samples 41 | 42 | file = open(logfile, 'w') 43 | for element in energy_list: 44 | file.writelines(element) 45 | file.write('\n') 46 | file.close() 47 | return posterior 48 | 49 | 50 | def denoise_image(filename, burn_in_steps, total_samples, logfile): 51 | posterior = get_posterior(filename, burn_in_steps, total_samples, logfile=logfile) 52 | denoised = np.zeros(posterior.shape, dtype=np.float64) 53 | denoised[posterior > 0.5] = 1 54 | return denoised[1:-1, 1:-1] 55 | 56 | 57 | 
def plot_energy(filename): 58 | x = np.genfromtxt(filename, dtype=None, encoding='utf8') 59 | its, energies, phases = zip(*x) 60 | its = np.asarray(its) 61 | energies = np.asarray(energies) 62 | phases = np.asarray(phases) 63 | burn_mask = (phases == 'B') 64 | samp_mask = (phases == 'S') 65 | assert np.sum(burn_mask) + np.sum(samp_mask) == len(x), 'Found bad phase' 66 | its_burn, energies_burn = its[burn_mask], energies[burn_mask] 67 | its_samp, energies_samp = its[samp_mask], energies[samp_mask] 68 | p1, = plt.plot(its_burn, energies_burn, 'r') 69 | p2, = plt.plot(its_samp, energies_samp, 'b') 70 | plt.title("energy") 71 | plt.xlabel('iteration number') 72 | plt.ylabel('energy') 73 | plt.legend([p1, p2], ['burn in', 'sampling']) 74 | plt.savefig('%s.png' % filename) 75 | plt.close() 76 | 77 | 78 | def save_image(denoised_image): 79 | plt.imshow(denoised_image, cmap='gray') 80 | plt.title("denoised image") 81 | plt.savefig('output/denoise_image.png') 82 | plt.close() 83 | 84 | 85 | if __name__ == '__main__': 86 | ITA = 1 87 | BETA = 1 88 | total_samples = 1000 89 | burn_in_steps = 100 90 | logfile = "output/log_energy" 91 | denoised_img = denoise_image("data/img_noisy.png", burn_in_steps=burn_in_steps, 92 | total_samples=total_samples, logfile=logfile) 93 | plot_energy(logfile) 94 | save_image(denoised_img) -------------------------------------------------------------------------------- /mcmc_gibbs/output/denoise_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VXU1230/Medium-Tutorials/bc059d131e020c9121f4c258a007f539fa364c85/mcmc_gibbs/output/denoise_image.png -------------------------------------------------------------------------------- /mcmc_gibbs/output/log_energy: -------------------------------------------------------------------------------- 1 | 0 -423328.0 B 2 | 1 -513834.0 B 3 | 2 -532120.0 B 4 | 3 -539532.0 B 5 | 4 -543492.0 B 6 | 5 -546038.0 B 7 | 6 -547468.0 B 8 | 7 -548422.0 B 9 | 8 -549058.0 B 10 | 9 -549536.0 B 11 | 10 -549942.0 B 12 | 11 -550386.0 B 13 | 12 -550700.0 B 14 | 13 -550740.0 B 15 | 14 -551036.0 B 16 | 15 -551058.0 B 17 | 16 -551150.0 B 18 | 17 -551322.0 B 19 | 18 -551390.0 B 20 | 19 -551420.0 B 21 | 20 -551522.0 B 22 | 21 -551534.0 B 23 | 22 -551634.0 B 24 | 23 -551642.0 B 25 | 24 -551734.0 B 26 | 25 -551762.0 B 27 | 26 -551792.0 B 28 | 27 -551922.0 B 29 | 28 -551854.0 B 30 | 29 -551948.0 B 31 | 30 -552022.0 B 32 | 31 -552084.0 B 33 | 32 -552042.0 B 34 | 33 -552058.0 B 35 | 34 -552212.0 B 36 | 35 -552318.0 B 37 | 36 -552164.0 B 38 | 37 -552188.0 B 39 | 38 -552208.0 B 40 | 39 -552160.0 B 41 | 40 -552212.0 B 42 | 41 -552042.0 B 43 | 42 -552250.0 B 44 | 43 -552092.0 B 45 | 44 -552356.0 B 46 | 45 -552238.0 B 47 | 46 -552164.0 B 48 | 47 -552302.0 B 49 | 48 -552368.0 B 50 | 49 -552374.0 B 51 | 50 -552406.0 B 52 | 51 -552418.0 B 53 | 52 -552402.0 B 54 | 53 -552354.0 B 55 | 54 -552340.0 B 56 | 55 -552412.0 B 57 | 56 -552358.0 B 58 | 57 -552412.0 B 59 | 58 -552398.0 B 60 | 59 -552420.0 B 61 | 60 -552368.0 B 62 | 61 -552396.0 B 63 | 62 -552444.0 B 64 | 63 -552424.0 B 65 | 64 -552408.0 B 66 | 65 -552394.0 B 67 | 66 -552388.0 B 68 | 67 -552428.0 B 69 | 68 -552452.0 B 70 | 69 -552508.0 B 71 | 70 -552448.0 B 72 | 71 -552532.0 B 73 | 72 -552508.0 B 74 | 73 -552584.0 B 75 | 74 -552444.0 B 76 | 75 -552408.0 B 77 | 76 -552526.0 B 78 | 77 -552510.0 B 79 | 78 -552500.0 B 80 | 79 -552542.0 B 81 | 80 -552608.0 B 82 | 81 -552460.0 B 83 | 82 -552528.0 B 84 | 83 -552418.0 B 85 | 84 -552480.0 B 86 | 85 
-552370.0 B 87 | 86 -552486.0 B 88 | 87 -552522.0 B 89 | 88 -552464.0 B 90 | 89 -552418.0 B 91 | 90 -552524.0 B 92 | 91 -552472.0 B 93 | 92 -552464.0 B 94 | 93 -552492.0 B 95 | 94 -552522.0 B 96 | 95 -552452.0 B 97 | 96 -552598.0 B 98 | 97 -552706.0 B 99 | 98 -552704.0 B 100 | 99 -552454.0 B 101 | 100 -552594.0 S 102 | 101 -552560.0 S 103 | 102 -552526.0 S 104 | 103 -552594.0 S 105 | 104 -552598.0 S 106 | 105 -552526.0 S 107 | 106 -552590.0 S 108 | 107 -552492.0 S 109 | 108 -552558.0 S 110 | 109 -552458.0 S 111 | 110 -552438.0 S 112 | 111 -552540.0 S 113 | 112 -552434.0 S 114 | 113 -552376.0 S 115 | 114 -552388.0 S 116 | 115 -552490.0 S 117 | 116 -552402.0 S 118 | 117 -552540.0 S 119 | 118 -552530.0 S 120 | 119 -552486.0 S 121 | 120 -552604.0 S 122 | 121 -552560.0 S 123 | 122 -552564.0 S 124 | 123 -552454.0 S 125 | 124 -552582.0 S 126 | 125 -552656.0 S 127 | 126 -552482.0 S 128 | 127 -552448.0 S 129 | 128 -552556.0 S 130 | 129 -552544.0 S 131 | 130 -552438.0 S 132 | 131 -552592.0 S 133 | 132 -552564.0 S 134 | 133 -552536.0 S 135 | 134 -552542.0 S 136 | 135 -552544.0 S 137 | 136 -552420.0 S 138 | 137 -552480.0 S 139 | 138 -552434.0 S 140 | 139 -552478.0 S 141 | 140 -552472.0 S 142 | 141 -552524.0 S 143 | 142 -552352.0 S 144 | 143 -552558.0 S 145 | 144 -552556.0 S 146 | 145 -552554.0 S 147 | 146 -552590.0 S 148 | 147 -552606.0 S 149 | 148 -552660.0 S 150 | 149 -552562.0 S 151 | 150 -552626.0 S 152 | 151 -552466.0 S 153 | 152 -552656.0 S 154 | 153 -552596.0 S 155 | 154 -552616.0 S 156 | 155 -552596.0 S 157 | 156 -552480.0 S 158 | 157 -552532.0 S 159 | 158 -552566.0 S 160 | 159 -552504.0 S 161 | 160 -552612.0 S 162 | 161 -552570.0 S 163 | 162 -552552.0 S 164 | 163 -552536.0 S 165 | 164 -552586.0 S 166 | 165 -552704.0 S 167 | 166 -552542.0 S 168 | 167 -552422.0 S 169 | 168 -552454.0 S 170 | 169 -552528.0 S 171 | 170 -552668.0 S 172 | 171 -552672.0 S 173 | 172 -552570.0 S 174 | 173 -552496.0 S 175 | 174 -552602.0 S 176 | 175 -552612.0 S 177 | 176 -552704.0 S 178 | 177 -552576.0 S 179 | 178 -552534.0 S 180 | 179 -552436.0 S 181 | 180 -552670.0 S 182 | 181 -552508.0 S 183 | 182 -552682.0 S 184 | 183 -552674.0 S 185 | 184 -552502.0 S 186 | 185 -552480.0 S 187 | 186 -552536.0 S 188 | 187 -552546.0 S 189 | 188 -552626.0 S 190 | 189 -552686.0 S 191 | 190 -552624.0 S 192 | 191 -552552.0 S 193 | 192 -552592.0 S 194 | 193 -552590.0 S 195 | 194 -552602.0 S 196 | 195 -552612.0 S 197 | 196 -552550.0 S 198 | 197 -552568.0 S 199 | 198 -552744.0 S 200 | 199 -552640.0 S 201 | 200 -552550.0 S 202 | 201 -552568.0 S 203 | 202 -552602.0 S 204 | 203 -552448.0 S 205 | 204 -552454.0 S 206 | 205 -552622.0 S 207 | 206 -552624.0 S 208 | 207 -552554.0 S 209 | 208 -552614.0 S 210 | 209 -552520.0 S 211 | 210 -552536.0 S 212 | 211 -552672.0 S 213 | 212 -552522.0 S 214 | 213 -552576.0 S 215 | 214 -552626.0 S 216 | 215 -552516.0 S 217 | 216 -552558.0 S 218 | 217 -552588.0 S 219 | 218 -552532.0 S 220 | 219 -552626.0 S 221 | 220 -552540.0 S 222 | 221 -552510.0 S 223 | 222 -552598.0 S 224 | 223 -552600.0 S 225 | 224 -552506.0 S 226 | 225 -552650.0 S 227 | 226 -552666.0 S 228 | 227 -552566.0 S 229 | 228 -552530.0 S 230 | 229 -552660.0 S 231 | 230 -552510.0 S 232 | 231 -552488.0 S 233 | 232 -552404.0 S 234 | 233 -552446.0 S 235 | 234 -552492.0 S 236 | 235 -552532.0 S 237 | 236 -552532.0 S 238 | 237 -552522.0 S 239 | 238 -552512.0 S 240 | 239 -552434.0 S 241 | 240 -552512.0 S 242 | 241 -552532.0 S 243 | 242 -552538.0 S 244 | 243 -552512.0 S 245 | 244 -552538.0 S 246 | 245 -552474.0 S 247 | 246 -552504.0 S 248 | 247 -552606.0 S 249 | 
248 -552580.0 S 250 | 249 -552648.0 S 251 | 250 -552644.0 S 252 | 251 -552584.0 S 253 | 252 -552640.0 S 254 | 253 -552714.0 S 255 | 254 -552564.0 S 256 | 255 -552528.0 S 257 | 256 -552530.0 S 258 | 257 -552572.0 S 259 | 258 -552490.0 S 260 | 259 -552582.0 S 261 | 260 -552624.0 S 262 | 261 -552498.0 S 263 | 262 -552640.0 S 264 | 263 -552600.0 S 265 | 264 -552656.0 S 266 | 265 -552458.0 S 267 | 266 -552570.0 S 268 | 267 -552610.0 S 269 | 268 -552454.0 S 270 | 269 -552582.0 S 271 | 270 -552638.0 S 272 | 271 -552588.0 S 273 | 272 -552544.0 S 274 | 273 -552520.0 S 275 | 274 -552522.0 S 276 | 275 -552552.0 S 277 | 276 -552610.0 S 278 | 277 -552398.0 S 279 | 278 -552434.0 S 280 | 279 -552482.0 S 281 | 280 -552560.0 S 282 | 281 -552586.0 S 283 | 282 -552516.0 S 284 | 283 -552634.0 S 285 | 284 -552642.0 S 286 | 285 -552540.0 S 287 | 286 -552616.0 S 288 | 287 -552544.0 S 289 | 288 -552526.0 S 290 | 289 -552544.0 S 291 | 290 -552646.0 S 292 | 291 -552522.0 S 293 | 292 -552630.0 S 294 | 293 -552528.0 S 295 | 294 -552428.0 S 296 | 295 -552528.0 S 297 | 296 -552548.0 S 298 | 297 -552496.0 S 299 | 298 -552478.0 S 300 | 299 -552610.0 S 301 | 300 -552526.0 S 302 | 301 -552578.0 S 303 | 302 -552484.0 S 304 | 303 -552556.0 S 305 | 304 -552570.0 S 306 | 305 -552558.0 S 307 | 306 -552462.0 S 308 | 307 -552502.0 S 309 | 308 -552508.0 S 310 | 309 -552522.0 S 311 | 310 -552338.0 S 312 | 311 -552552.0 S 313 | 312 -552550.0 S 314 | 313 -552456.0 S 315 | 314 -552570.0 S 316 | 315 -552656.0 S 317 | 316 -552576.0 S 318 | 317 -552546.0 S 319 | 318 -552548.0 S 320 | 319 -552644.0 S 321 | 320 -552620.0 S 322 | 321 -552566.0 S 323 | 322 -552404.0 S 324 | 323 -552458.0 S 325 | 324 -552538.0 S 326 | 325 -552510.0 S 327 | 326 -552522.0 S 328 | 327 -552392.0 S 329 | 328 -552560.0 S 330 | 329 -552622.0 S 331 | 330 -552640.0 S 332 | 331 -552410.0 S 333 | 332 -552638.0 S 334 | 333 -552504.0 S 335 | 334 -552586.0 S 336 | 335 -552498.0 S 337 | 336 -552428.0 S 338 | 337 -552414.0 S 339 | 338 -552560.0 S 340 | 339 -552520.0 S 341 | 340 -552630.0 S 342 | 341 -552556.0 S 343 | 342 -552570.0 S 344 | 343 -552592.0 S 345 | 344 -552614.0 S 346 | 345 -552644.0 S 347 | 346 -552614.0 S 348 | 347 -552574.0 S 349 | 348 -552550.0 S 350 | 349 -552568.0 S 351 | 350 -552506.0 S 352 | 351 -552458.0 S 353 | 352 -552494.0 S 354 | 353 -552560.0 S 355 | 354 -552616.0 S 356 | 355 -552572.0 S 357 | 356 -552564.0 S 358 | 357 -552654.0 S 359 | 358 -552484.0 S 360 | 359 -552520.0 S 361 | 360 -552544.0 S 362 | 361 -552666.0 S 363 | 362 -552638.0 S 364 | 363 -552542.0 S 365 | 364 -552486.0 S 366 | 365 -552654.0 S 367 | 366 -552474.0 S 368 | 367 -552562.0 S 369 | 368 -552648.0 S 370 | 369 -552536.0 S 371 | 370 -552436.0 S 372 | 371 -552572.0 S 373 | 372 -552672.0 S 374 | 373 -552576.0 S 375 | 374 -552514.0 S 376 | 375 -552328.0 S 377 | 376 -552470.0 S 378 | 377 -552478.0 S 379 | 378 -552450.0 S 380 | 379 -552522.0 S 381 | 380 -552444.0 S 382 | 381 -552602.0 S 383 | 382 -552558.0 S 384 | 383 -552494.0 S 385 | 384 -552616.0 S 386 | 385 -552654.0 S 387 | 386 -552544.0 S 388 | 387 -552594.0 S 389 | 388 -552600.0 S 390 | 389 -552654.0 S 391 | 390 -552584.0 S 392 | 391 -552566.0 S 393 | 392 -552538.0 S 394 | 393 -552526.0 S 395 | 394 -552662.0 S 396 | 395 -552578.0 S 397 | 396 -552556.0 S 398 | 397 -552492.0 S 399 | 398 -552616.0 S 400 | 399 -552484.0 S 401 | 400 -552440.0 S 402 | 401 -552620.0 S 403 | 402 -552634.0 S 404 | 403 -552494.0 S 405 | 404 -552474.0 S 406 | 405 -552542.0 S 407 | 406 -552676.0 S 408 | 407 -552520.0 S 409 | 408 -552466.0 S 410 | 409 
-552580.0 S [ ... iterations 410 through 1099 of the Gibbs-sampler log-energy trace omitted for brevity; every record keeps the same "<iteration> <log energy> S" form and the values stay near -552,600 for the remainder of the run; the full trace is plotted in output/log_energy.png ... ]
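The log_energy file above stores one energy value per Gibbs-sampling iteration; the corresponding plot is saved to output/log_energy.png. As a rough sketch of how such a trace can be computed for an Ising-style denoising model (the function name, the -1/+1 pixel encoding and the coupling/eta weights below are illustrative assumptions, not the exact code in image_denoising.py):

import numpy as np

def log_energy(x, y, coupling=1.0, eta=2.0):
    # x: current denoised image, y: observed noisy image, pixels encoded as -1/+1
    # smoothness term: neighbouring pixels that agree raise the score
    horizontal = np.sum(x[:, :-1] * x[:, 1:])
    vertical = np.sum(x[:-1, :] * x[1:, :])
    # data term: the denoised image should stay close to the noisy observation
    data = np.sum(x * y)
    return coupling * (horizontal + vertical) + eta * data

# appending one such value per sweep produces a trace like the output/log_energy file above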
-------------------------------------------------------------------------------- /mcmc_gibbs/output/log_energy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VXU1230/Medium-Tutorials/bc059d131e020c9121f4c258a007f539fa364c85/mcmc_gibbs/output/log_energy.png -------------------------------------------------------------------------------- /policy_gradient/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VXU1230/Medium-Tutorials/bc059d131e020c9121f4c258a007f539fa364c85/policy_gradient/model/__init__.py -------------------------------------------------------------------------------- /policy_gradient/model/baseline_net.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | from tensorflow.keras import layers 4 | 5 | 6 | class BaselineNet(): 7 | def __init__(self, input_size, output_size): 8 | self.model = keras.Sequential( 9 | layers=[ 10 | keras.Input(shape=(input_size,)), 11 | layers.Dense(64, activation="relu", name="relu_layer"), 12 | layers.Dense(output_size, activation="linear", name="linear_layer") 13 | ], 14 | name="baseline") 15 | self.optimizer = tf.keras.optimizers.Adam(learning_rate=3e-2) 16 | 17 | def forward(self, observations): 18 | output = tf.squeeze(self.model(observations)) 19 | return output 20 | 21 | def update(self, observations, target): 22 | with tf.GradientTape() as tape: 23 | predictions = self.forward(observations) 24 | loss = tf.keras.losses.mean_squared_error(y_true=target, y_pred=predictions) 25 | grads = tape.gradient(loss, self.model.trainable_weights) 26 | self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights)) 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /policy_gradient/model/main.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from model.policy_gradient import PolicyGradient 4 | 5 | 6 | if __name__ == '__main__': 7 | env = gym.make("CartPole-v0") 8 | model = PolicyGradient(env) 9 | model.train() 10 | model.make_video()
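baseline_net.py above is the value-function baseline for REINFORCE with baseline: it is regressed onto the Monte-Carlo returns, and its predictions are subtracted from those returns to form advantages (see policy_gradient.py below). A minimal usage sketch; the observation/return batch is synthetic and the import assumes the code is run from the policy_gradient directory, as main.py does:

import numpy as np
from model.baseline_net import BaselineNet

baseline = BaselineNet(input_size=4, output_size=1)            # CartPole-v0 has 4 state features
observations = np.random.randn(32, 4).astype(np.float32)       # made-up batch of states
returns = (np.random.rand(32) * 200.0).astype(np.float32)      # made-up Monte-Carlo returns

baseline.update(observations=observations, target=returns)     # one regression step towards the returns
advantages = returns - baseline.forward(observations).numpy()  # advantage = return - predicted value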
-------------------------------------------------------------------------------- /policy_gradient/model/policy_gradient.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | import tensorflow as tf 5 | from gym import wrappers 6 | import matplotlib.pyplot as plt 7 | 8 | from model.baseline_net import BaselineNet 9 | from model.policy_net import PolicyNet 10 | 11 | 12 | def export_plot(ys, ylabel, title, filename): 13 | plt.figure() 14 | plt.plot(range(len(ys)), ys) 15 | plt.xlabel("Training Episode") 16 | plt.ylabel(ylabel) 17 | plt.title(title) 18 | plt.savefig(filename) 19 | plt.close() 20 | 21 | 22 | class PolicyGradient(object): 23 | def __init__(self, env, num_iterations=300, batch_size=2000, max_ep_len=200, output_path="../results/"): 24 | self.output_path = output_path 25 | if not os.path.exists(output_path): 26 | os.makedirs(output_path) 27 | self.env = env 28 | self.observation_dim = self.env.observation_space.shape[0] 29 | self.action_dim = self.env.action_space.n 30 | self.gamma = 0.99 31 | self.num_iterations = num_iterations 32 | self.batch_size = batch_size 33 | self.max_ep_len = max_ep_len 34 | self.optimizer = tf.keras.optimizers.Adam(learning_rate=3e-2) 35 | self.policy_net = PolicyNet(input_size=self.observation_dim, output_size=self.action_dim) 36 | self.baseline_net = BaselineNet(input_size=self.observation_dim, output_size=1) 37 | 38 | def play_games(self, env=None, num_episodes = None): 39 | episode = 0 40 | episode_rewards = [] 41 | paths = [] 42 | t = 0 43 | if not env: 44 | env = self.env 45 | 46 | while (num_episodes or t < self.batch_size): 47 | state = env.reset() 48 | states, actions, rewards = [], [], [] 49 | episode_reward = 0 50 | 51 | for step in range(self.max_ep_len): 52 | states.append(state) 53 | action = self.policy_net.sampel_action(np.atleast_2d(state))[0] 54 | state, reward, done, _ = env.step(action) 55 | actions.append(action) 56 | rewards.append(reward) 57 | episode_reward += reward 58 | t += 1 59 | 60 | if (done or step == self.max_ep_len-1): 61 | episode_rewards.append(episode_reward) 62 | break 63 | if (not num_episodes) and t == self.batch_size: 64 | break 65 | 66 | path = {"observation": np.array(states), 67 | "reward": np.array(rewards), 68 | "action": np.array(actions)} 69 | paths.append(path) 70 | episode += 1 71 | if num_episodes and episode >= num_episodes: 72 | break 73 | return paths, episode_rewards 74 | 75 | def get_returns(self, paths): 76 | all_returns = [] 77 | for path in paths: 78 | rewards = path["reward"] 79 | returns = [] 80 | reversed_rewards = np.flip(rewards,0) 81 | g_t = 0 82 | for r in reversed_rewards: 83 | g_t = r + self.gamma*g_t 84 | returns.insert(0, g_t) 85 | all_returns.append(returns) 86 | returns = np.concatenate(all_returns) 87 | return returns 88 | 89 | def get_advantage(self, returns, observations): 90 | values = self.baseline_net.forward(observations).numpy() 91 | advantages = returns - values 92 | advantages = (advantages-np.mean(advantages)) / np.sqrt(np.sum(advantages**2)) 93 | return advantages 94 | 95 | def update_policy(self, observations, actions, advantages): 96 | observations = tf.convert_to_tensor(observations) 97 | actions = tf.convert_to_tensor(actions) 98 | advantages = tf.convert_to_tensor(advantages) 99 | with tf.GradientTape() as tape: 100 | log_prob = self.policy_net.action_distribution(observations).log_prob(actions) 101 | loss = -tf.math.reduce_mean(log_prob * tf.cast(advantages, tf.float32)) 102 | grads = 
tape.gradient(loss, self.policy_net.model.trainable_weights) 103 | self.optimizer.apply_gradients(zip(grads, self.policy_net.model.trainable_weights)) 104 | 105 | def train(self): 106 | all_total_rewards = [] 107 | averaged_total_rewards = [] 108 | for t in range(self.num_iterations): 109 | paths, total_rewards = self.play_games() 110 | all_total_rewards.extend(total_rewards) 111 | observations = np.concatenate([path["observation"] for path in paths]) 112 | actions = np.concatenate([path["action"] for path in paths]) 113 | returns = self.get_returns(paths) 114 | advantages = self.get_advantage(returns, observations) 115 | self.baseline_net.update(observations=observations, target=returns) 116 | self.update_policy(observations, actions, advantages) 117 | avg_reward = np.mean(total_rewards) 118 | averaged_total_rewards.append(avg_reward) 119 | print("Average reward for batch {}: {:04.2f}".format(t,avg_reward)) 120 | print("Training complete") 121 | np.save(self.output_path+ "rewards.npy", averaged_total_rewards) 122 | export_plot(averaged_total_rewards, "Reward", "CartPole-v0", self.output_path + "rewards.png") 123 | 124 | def eval(self, env, num_episodes=1): 125 | paths, rewards = self.play_games(env, num_episodes) 126 | avg_reward = np.mean(rewards) 127 | print("Average eval reward: {:04.2f}".format(avg_reward)) 128 | return avg_reward 129 | 130 | def make_video(self): 131 | env = wrappers.Monitor(self.env, self.output_path+"videos", force=True) 132 | self.eval(env=env, num_episodes=1) 133 | -------------------------------------------------------------------------------- /policy_gradient/model/policy_net.py: -------------------------------------------------------------------------------- 1 | import tensorflow_probability as tfp 2 | from tensorflow import keras 3 | from tensorflow.keras import layers 4 | 5 | 6 | class PolicyNet(): 7 | def __init__(self, input_size, output_size): 8 | self.model = keras.Sequential( 9 | layers=[ 10 | keras.Input(shape=(input_size,)), 11 | layers.Dense(64, activation="relu", name="relu_layer"), 12 | layers.Dense(output_size, activation="linear", name="linear_layer") 13 | ], 14 | name="policy") 15 | 16 | def action_distribution(self, observations): 17 | logits = self.model(observations) 18 | return tfp.distributions.Categorical(logits=logits) 19 | 20 | def sampel_action(self, observations): 21 | sampled_actions = self.action_distribution(observations).sample().numpy() 22 | return sampled_actions 23 | 24 | 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /policy_gradient/results/rewards.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VXU1230/Medium-Tutorials/bc059d131e020c9121f4c258a007f539fa364c85/policy_gradient/results/rewards.npy -------------------------------------------------------------------------------- /policy_gradient/results/rewards.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VXU1230/Medium-Tutorials/bc059d131e020c9121f4c258a007f539fa364c85/policy_gradient/results/rewards.png -------------------------------------------------------------------------------- /policy_gradient/results/videos/openaigym.episode_batch.0.2394.stats.json: -------------------------------------------------------------------------------- 1 | {"initial_reset_timestamp": 1616129004.328471, "timestamps": [1616129009.9160478], "episode_lengths": [200], "episode_rewards": [200.0], 
"episode_types": ["t"]} -------------------------------------------------------------------------------- /policy_gradient/results/videos/openaigym.manifest.0.2394.manifest.json: -------------------------------------------------------------------------------- 1 | {"stats": "openaigym.episode_batch.0.2394.stats.json", "videos": [["openaigym.video.0.2394.video000000.mp4", "openaigym.video.0.2394.video000000.meta.json"]], "env_info": {"gym_version": "0.17.3", "env_id": "CartPole-v0"}} -------------------------------------------------------------------------------- /policy_gradient/results/videos/openaigym.video.0.2394.video000000.meta.json: -------------------------------------------------------------------------------- 1 | {"episode_id": 0, "content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.3.1 Copyright (c) 2000-2020 the FFmpeg developers\\nbuilt with Apple clang version 12.0.0 (clang-1200.0.32.29)\\nconfiguration: --prefix=/usr/local/Cellar/ffmpeg/4.3.1_9 --enable-shared --enable-pthreads --enable-version3 --enable-avresample --cc=clang --host-cflags= --host-ldflags= --enable-ffplay --enable-gnutls --enable-gpl --enable-libaom --enable-libbluray --enable-libdav1d --enable-libmp3lame --enable-libopus --enable-librav1e --enable-librubberband --enable-libsnappy --enable-libsrt --enable-libtesseract --enable-libtheora --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx264 --enable-libx265 --enable-libxml2 --enable-libxvid --enable-lzma --enable-libfontconfig --enable-libfreetype --enable-frei0r --enable-libass --enable-libopencore-amrnb --enable-libopencore-amrwb --enable-libopenjpeg --enable-librtmp --enable-libspeex --enable-libsoxr --enable-libzmq --enable-libzimg --disable-libjack --disable-indev=jack --enable-videotoolbox\\nlibavutil 56. 51.100 / 56. 51.100\\nlibavcodec 58. 91.100 / 58. 91.100\\nlibavformat 58. 45.100 / 58. 45.100\\nlibavdevice 58. 10.100 / 58. 10.100\\nlibavfilter 7. 85.100 / 7. 85.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 7.100 / 5. 7.100\\nlibswresample 3. 7.100 / 3. 7.100\\nlibpostproc 55. 7.100 / 55. 
7.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/Users/VivianXU/Projects/Medium-Tutorials/policy_gradient/results/videos/openaigym.video.0.2394.video000000.mp4"]}} -------------------------------------------------------------------------------- /policy_gradient/results/videos/openaigym.video.0.2394.video000000.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VXU1230/Medium-Tutorials/bc059d131e020c9121f4c258a007f539fa364c85/policy_gradient/results/videos/openaigym.video.0.2394.video000000.mp4 -------------------------------------------------------------------------------- /search/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VXU1230/Medium-Tutorials/bc059d131e020c9121f4c258a007f539fa364c85/search/__init__.py -------------------------------------------------------------------------------- /search/a_star.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | 3 | 4 | def dijkstra(grid): 5 | if grid[0][0] == 1 or grid[-1][-1] == 1: 6 | return -1 7 | n = len(grid) 8 | heap = [] 9 | 10 | seen = set() 11 | heapq.heappush(heap, (1, 1, 0, 0)) 12 | cache = {(0, 0): 1} 13 | while heap: 14 | (cost, dist, x, y) = heapq.heappop(heap) 15 | print("({},{})".format(x, y), end=" ") 16 | if x == n - 1 and y == n - 1: 17 | return dist 18 | if (x,y) not in seen: 19 | seen.add((x, y)) 20 | for (nei_x, nei_y) in get_nei(x, y): 21 | if 0 <= nei_x < n and 0 <= nei_y < n and grid[nei_x][nei_y] == 0: 22 | new_dist = dist + 1 23 | if new_dist < cache.get((nei_x, nei_y), float('inf')): 24 | cache[(nei_x, nei_y)] = new_dist 25 | heapq.heappush(heap, (new_dist, new_dist, nei_x, nei_y)) 26 | return -1 27 | 28 | 29 | def a_star(grid): 30 | def get_heuristic(current, target): 31 | return max(abs(current[0] - target[0]), abs(current[1] - target[1])) 32 | 33 | if grid[0][0] == 1 or grid[-1][-1] == 1: 34 | return -1 35 | 36 | n = len(grid) 37 | target = (n - 1, n - 1) 38 | heap = [] 39 | seen = set() 40 | heuristic = get_heuristic((0, 0), target) 41 | heapq.heappush(heap, (heuristic + 1, 1, 0, 0)) 42 | cache = {(0, 0): heuristic + 1} 43 | while heap: 44 | (cost, dist, x, y) = heapq.heappop(heap) 45 | print("({},{})".format(x, y), end=" ") 46 | if x == n - 1 and y == n - 1: 47 | return dist 48 | if (x, y) not in seen: 49 | seen.add((x, y)) 50 | for (nei_x, nei_y) in get_nei(x, y): 51 | if 0 <= nei_x < n and 0 <= nei_y < n and grid[nei_x][nei_y] == 0: 52 | new_dist = dist + 1 53 | heuristic = get_heuristic((nei_x, nei_y), target) 54 | if new_dist + heuristic < cache.get((nei_x, nei_y), float('inf')): 55 | cache[(nei_x, nei_y)] = new_dist + heuristic 56 | heapq.heappush(heap, (new_dist + heuristic, new_dist, nei_x, nei_y)) 57 | 58 | return -1 59 | 60 | 61 | def get_nei(x, y): 62 | return [(x + 1, y), (x - 1, y), (x + 1, y - 1), (x - 1, y - 1), (x + 1, y + 1), (x - 1, y + 1), (x, y + 1), 63 | (x, y - 1)] 64 | 65 | 66 | if __name__ == "__main__": 67 | grid = [[0,0,0,0,1,1,1,1,0], 68 | [0,1,1,0,0,0,0,1,0], 69 | [0,0,1,0,0,0,0,0,0], 70 | [1,1,0,0,1,0,0,1,1], 71 | [0,0,1,1,1,0,1,0,1], 72 | [0,1,0,1,0,0,0,0,0], 73 | [0,0,0,1,0,1,0,0,0], 74 | [0,1,0,1,1,0,0,0,0], 75 | [0,0,0,0,0,1,0,1,0]] 76 | 77 | print("Dijkstra path: ") 
78 | dist_d = dijkstra(grid) 79 | print("\nshortest distance: {}\n".format(dist_d)) 80 | print("A* path: ") 81 | dist_a = a_star(grid) 82 | print("\nshortest distance: {}".format(dist_a)) 83 | 84 | # https://leetcode.com/problems/shortest-path-in-binary-matrix/ -------------------------------------------------------------------------------- /search/dijkstra.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | 3 | 4 | def dijkstra(grid): 5 | seen = set() 6 | heap = [] 7 | dist = grid[0][0] 8 | heapq.heappush(heap, (dist, 0, 0)) 9 | row_len = len(grid) 10 | col_len = len(grid[0]) 11 | min_dist = float('inf') 12 | while heap: 13 | dist, row, col = heapq.heappop(heap) 14 | if row == row_len - 1 and col == col_len - 1: 15 | min_dist = min(min_dist, dist) 16 | break 17 | if (row, col) not in seen: 18 | seen.add((row, col)) 19 | if row+1 < row_len: 20 | heapq.heappush(heap, (dist+grid[row+1][col], row+1, col)) 21 | if col+1 < col_len : 22 | heapq.heappush(heap, (dist+grid[row][col+1], row, col+1)) 23 | return min_dist 24 | 25 | 26 | if __name__ == "__main__": 27 | grid = [[1,3,1],[1,5,1],[4,2,1]] 28 | print("shortest distance: ", dijkstra(grid)) 29 | 30 | # https://leetcode.com/problems/minimum-path-sum/ 31 | 32 | -------------------------------------------------------------------------------- /search/dp.py: -------------------------------------------------------------------------------- 1 | def path_search(row_len, col_len): 2 | dp = [[0 for _ in range(col_len)] for _ in range(row_len)] 3 | for r in reversed(range(row_len)): 4 | for c in reversed(range(col_len)): 5 | if r == row_len-1 and c == col_len-1: 6 | dp[r][c] = 1 7 | else: 8 | if r+1 < row_len: 9 | dp[r][c] += dp[r+1][c] 10 | if c+1 < col_len: 11 | dp[r][c] += dp[r][c+1] 12 | return dp[0][0] 13 | 14 | 15 | if __name__ == "__main__": 16 | m = 3 17 | n = 7 18 | print("Number of unique paths: ", path_search(m, n)) 19 | 20 | # https://leetcode.com/problems/unique-paths/ 21 | -------------------------------------------------------------------------------- /search/tree_search.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | 4 | class Node: 5 | def __init__(self, val=None): 6 | self.val = val 7 | self.left = None 8 | self.right = None 9 | 10 | 11 | def get_root(): 12 | values = iter([1, 6, 4, 7, None, None, 9, 8, None, None, 13 | 10, None, None, 5, None, None, 2, 3, None, None, 11, None, None]) 14 | 15 | def tree_recur(itr): 16 | val = next(itr) 17 | if val is not None: 18 | node = Node(val) 19 | node.left = tree_recur(itr) 20 | node.right = tree_recur(itr) 21 | return node 22 | 23 | return tree_recur(values) 24 | 25 | 26 | def dfs(): 27 | root = get_root() 28 | res = float("inf") 29 | 30 | def dfs_search(node, depth): 31 | if node is not None: 32 | val = node.val 33 | print(val, end=" ") 34 | if val >= 10: 35 | nonlocal res 36 | res = min(res, depth) 37 | else: 38 | dfs_search(node.left, depth+1) 39 | dfs_search(node.right, depth + 1) 40 | 41 | dfs_search(root, 0) 42 | if res < float("inf"): 43 | return res 44 | return -1 45 | 46 | 47 | def bfs(): 48 | root = get_root() 49 | queue = collections.deque() 50 | queue.appendleft((root, 0)) 51 | res = -1 52 | 53 | while queue: 54 | node, depth = queue.pop() 55 | print(node.val, end=" ") 56 | if node.val >= 10: 57 | res = depth 58 | break 59 | 60 | if node.left: 61 | queue.appendleft((node.left, depth+1)) 62 | if node.right: 63 | queue.appendleft((node.right, depth+1)) 64 | return 
res 65 | 66 | 67 | def iddfs(): 68 | root = get_root() 69 | res = float("inf") 70 | 71 | def iddfs_search(node, depth, limit): 72 | if depth <= limit and node is not None: 73 | val = node.val 74 | print(val, end=" ") 75 | if val >= 10: 76 | nonlocal res 77 | res = min(res, depth) 78 | else: 79 | iddfs_search(node.left, depth + 1, limit) 80 | iddfs_search(node.right, depth + 1, limit) 81 | 82 | for limit in range(1, 5): 83 | print("\nmax depth: ", limit) 84 | iddfs_search(root, 0, limit) 85 | if res < float("inf"): 86 | return res 87 | return -1 88 | 89 | 90 | if __name__ == "__main__": 91 | print("\nBFS") 92 | print("\nshortest depth: ", bfs()) 93 | print("\nDFS") 94 | print("\nshortest depth: ", dfs()) 95 | print("\nIDDFS", end="") 96 | print("\nshortest depth: ", iddfs()) --------------------------------------------------------------------------------
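tree_search.py above runs BFS, DFS and iterative-deepening DFS (IDDFS) on the same hard-coded tree. The generic pattern behind iddfs() is sketched below; the is_goal and children callables are illustrative placeholders, not part of the repository:

def iddfs_generic(start, is_goal, children, max_limit=50):
    # depth-limited DFS: return the depth at which a goal is found, or None
    def dls(node, depth, limit):
        if is_goal(node):
            return depth
        if depth == limit:
            return None
        for child in children(node):
            found = dls(child, depth + 1, limit)
            if found is not None:
                return found
        return None

    # restart the depth-limited search with a growing cap, as iddfs() does
    for limit in range(max_limit + 1):
        result = dls(start, 0, limit)
        if result is not None:
            return result          # depth of the shallowest goal found
    return -1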