├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── historys ├── mini-5-way-1-shot-test.png ├── mini-5-way-1-shot.png ├── mini-5-way-5-shot-test.png ├── mini-5-way-5-shot.png ├── miniimagenet-5-way-1-shot-acc-test.txt ├── miniimagenet-5-way-1-shot-acc.txt ├── miniimagenet-5-way-1-shot-loss-test.txt ├── miniimagenet-5-way-1-shot-train.txt ├── miniimagenet-5-way-1-shot.png ├── miniimagenet-5-way-5-shot-acc-test.txt ├── miniimagenet-5-way-5-shot-acc.txt ├── miniimagenet-5-way-5-shot-loss-test.txt ├── miniimagenet-5-way-5-shot-train.txt ├── miniimagenet-5-way-5-shot.png ├── omn-20-way-1-shot-test.png ├── omn-20-way-1-shot.png ├── omn-5-way-1-shot-test.png ├── omn-5-way-1-shot.png ├── omniglot-20-way-1-shot-acc-test.txt ├── omniglot-20-way-1-shot-acc.txt ├── omniglot-20-way-1-shot-loss-test.txt ├── omniglot-20-way-1-shot-train.txt ├── omniglot-20-way-1-shot.png ├── omniglot-5-way-1-shot-acc-test.txt ├── omniglot-5-way-1-shot-acc.txt ├── omniglot-5-way-1-shot-loss-test.txt ├── omniglot-5-way-1-shot-train.txt └── omniglot-5-way-1-shot.png ├── outputs ├── avg_to_nn.png ├── eval_sine_1.png ├── eval_sine_2.png ├── eval_sine_3.png ├── maml_test_1.png ├── maml_test_2.png ├── maml_train.png ├── meta_sine_1.png ├── meta_sine_2.png ├── meta_sine_3.png ├── meta_test_1.png ├── meta_test_2.png ├── meta_test_3.png ├── meta_to_avg.png ├── meta_train.png ├── meta_train_1.png ├── miniimagenet 5-way-1-shot.png ├── miniimagenet_5w1s.png ├── miniimagenet_5w5s.png ├── model.png ├── omniglot_20w1s.png ├── omniglot_5w1s.png └── sinemodel_loss.png └── scripts ├── image_classification ├── history_vis.py ├── image_preprocess.py ├── main.py ├── meta_learner.py └── task_generator.py ├── reinforcement_learning └── maml-rl-easy │ ├── Navigation2DEnv.py │ ├── __pycache__ │ ├── episode.cpython-36.pyc │ ├── policy.cpython-36.pyc │ ├── sampler.cpython-36.pyc │ └── subproc_vec_env.cpython-36.pyc │ ├── envs_test │ ├── __pycache__ │ │ └── sampler.cpython-36.pyc │ ├── maze_test.py │ └── navigation_test.py │ ├── episode.py │ ├── gym_test │ ├── cartpole_nn.py │ ├── cartpole_policy_gradient.py │ ├── gym_env_test.py │ ├── model │ │ ├── CartPole-v0-nn.h5 │ │ ├── MountainCar-v0-dqn.h5 │ │ └── MountainCar-v0-q-learning.pickle │ ├── mountain_car_pg.py │ ├── mountain_dqn.py │ └── mountaincar_q_learning.py │ ├── main.py │ ├── maze.py │ ├── maze_policy_gradient.py │ ├── meta_learner.py │ ├── policy.py │ ├── sampler.py │ └── subproc_vec_env.py └── sine_fitting ├── sine_model.py ├── sinusoid_generator.py ├── train_sine_model.py └── tri_sine_fitting.py /.gitignore: -------------------------------------------------------------------------------- 1 | /.vscode 2 | /test_folder 3 | /.history 4 | .gz 5 | .python-version 6 | /dataset 7 | */__pycache__/ 8 | /weights 9 | /logs 10 | 11 | scripts/sine_fitting/__pycache__/ 12 | scripts/image_classification/__pycache__/ 13 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "scripts/reinforcement_learning/maml-rl-tf2"] 2 | path = scripts/reinforcement_learning/maml-rl-tf2 3 | url = git@github.com:HilbertXu/maml-rl-tf2.git 4 | [submodule "scripts/reinforcement_learning/maml-rl-pytorch"] 5 | path = scripts/reinforcement_learning/maml-rl-pytorch 6 | url = git@github.com:HilbertXu/pytorch-maml-rl.git 7 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 HilbertXu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MAML-Tensorflow 2 | TensorFlow r2.1 reimplementation of Model-Agnostic Meta-Learning from this paper: 3 | 4 | [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400) 5 | 6 | The reinforcement learning part is adapted from [MoritzTaylor/maml-rl-tf2](https://github.com/MoritzTaylor/maml-rl-tf2) and [tristandeleu/pytorch-maml-rl](https://github.com/tristandeleu/pytorch-maml-rl). Sincere thanks for their fantastic work! 7 | 8 | ## Project Requirements 9 | 10 | 1. Python 3.x 11 | 2. TensorFlow r2.1 12 | 3. numpy 13 | 4. matplotlib 14 | 5. ... 15 | 6. All scripts in the `image_classification` folder are tested on Python 3.6.5 16 | 17 | ## MiniImagenet Dataset 18 | 19 | I wrote a task generator that randomly samples from the whole dataset to build a training batch at every training step, so it does not consume too much GPU memory. 20 | 21 | For 5-way 1-shot tasks on MiniImagenet, one training step takes around 1.6 s on a GTX 1070 (with fast_weights updated once per task) and allocates about 1.3 GB of GPU memory. 22 | 23 | For 5-way 5-shot tasks on MiniImagenet, one training step takes around 2.2 s. 24 | 25 | If you set `--update_steps` to a value greater than 1, each training step will take longer. 26 | 27 | 28 | 29 | 1. Download the [MiniImagenet](https://drive.google.com/open?id=1HkgrkAwukzEZA0TpO7010PkAOREb2Nuk) dataset and the split files `train.csv, test.csv, val.csv` from [here](https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet) 30 | 31 | 2. Put the MiniImagenet dataset in your project folder like this 32 | 33 | ``` 34 | dataset/miniimagenet 35 | |-- images 36 | |--- n0153282900000005.jpg 37 | |--- n0153282900000006.jpg 38 | ... 39 | |-- test.csv 40 | |-- train.csv 41 | |-- val.csv 42 | ``` 43 | 44 | 3. Run the Python script to resize and split the whole dataset 45 | 46 | ``` 47 | cd scripts/image_classification 48 | python image_preprocess.py --dataset=miniimagenet 49 | ``` 50 | 51 | 4. Modify the path to the dataset in `scripts/image_classification/task_generator.py` 52 | 53 | ```python 54 | if self.dataset == 'miniimagenet': 55 | ...
56 | META_TRAIN_DIR = '../../dataset/miniImagenet/train' 57 | META_VAL_DIR = '../../dataset/miniImagenet/test' 58 | ... 59 | 60 | if self.dataset == 'omniglot': 61 | ... 62 | DATA_FOLDER = '../../dataset/omniglot' 63 | ... 64 | ``` 65 | 66 | 67 | 68 | 5. Run the main Python script 69 | 70 | ``` 71 | cd scripts/image_classification 72 | # For 5-way 1-shot on miniimagenet 73 | python main.py --dataset=miniimagenet --mode=train --n_way=5 --k_shot=1 --k_query=15 74 | # For 5-way 5-shot on miniimagenet 75 | python main.py --dataset=miniimagenet --mode=train --n_way=5 --k_shot=5 --k_query=15 76 | ``` 77 | 78 | 79 | 80 | ## Omniglot Dataset 81 | 82 | Training on the Omniglot dataset consumes less compute and time: 83 | 84 | For 5-way 1-shot, one training step takes about 0.3 s 85 | 86 | For 20-way 1-shot, one training step takes about 0.7 s 87 | 88 | 1. Download the Omniglot dataset from [here](https://github.com/brendenlake/omniglot) and extract the contents of `python/images_background.zip` and `python/images_evaluation.zip` into `dataset/omniglot`, so that it looks like this: 89 | 90 | ``` 91 | dataset/omniglot 92 | |-- Alphabet_of_the_Magi 93 | |-- Angelic 94 | ... 95 | ``` 96 | 97 | 2. Run the Python script to resize the images 98 | 99 | ``` 100 | cd scripts/image_classification 101 | python image_preprocess.py --dataset=omniglot 102 | ``` 103 | 104 | 3. Run the main Python script 105 | 106 | ``` 107 | cd scripts/image_classification 108 | # For 5-way 1-shot on Omniglot 109 | python main.py --dataset=omniglot --mode=train --n_way=5 --k_shot=1 --k_query=1 --inner_lr=0.1 110 | # For 20-way 1-shot on Omniglot 111 | python main.py --dataset=omniglot --mode=train --n_way=20 --k_shot=1 --k_query=1 --inner_lr=0.1 112 | ``` 113 | 114 | 115 | 116 | ## References 117 | 118 | This project is, for the most part, a reproduction of the original implementation [cbfinn/maml_rl](https://github.com/cbfinn/maml_rl/) in TensorFlow 2. The experiments are based on the paper 120 | 121 | > Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep 121 | > networks.
_International Conference on Machine Learning (ICML)_, 2017 [[ArXiv](https://arxiv.org/abs/1703.03400)] 122 | 123 | If you want to cite this paper 124 | 125 | ``` 126 | @article{DBLP:journals/corr/FinnAL17, 127 | author = {Chelsea Finn and Pieter Abbeel and Sergey Levine}, 128 | title = {Model-{A}gnostic {M}eta-{L}earning for {F}ast {A}daptation of {D}eep {N}etworks}, 129 | journal = {International Conference on Machine Learning (ICML)}, 130 | year = {2017}, 131 | url = {http://arxiv.org/abs/1703.03400} 132 | } 133 | ``` -------------------------------------------------------------------------------- /historys/mini-5-way-1-shot-test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/historys/mini-5-way-1-shot-test.png -------------------------------------------------------------------------------- /historys/mini-5-way-1-shot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/historys/mini-5-way-1-shot.png -------------------------------------------------------------------------------- /historys/mini-5-way-5-shot-test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/historys/mini-5-way-5-shot-test.png -------------------------------------------------------------------------------- /historys/mini-5-way-5-shot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/historys/mini-5-way-5-shot.png -------------------------------------------------------------------------------- /historys/miniimagenet-5-way-1-shot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/historys/miniimagenet-5-way-1-shot.png -------------------------------------------------------------------------------- /historys/miniimagenet-5-way-5-shot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/historys/miniimagenet-5-way-5-shot.png -------------------------------------------------------------------------------- /historys/omn-20-way-1-shot-test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/historys/omn-20-way-1-shot-test.png -------------------------------------------------------------------------------- /historys/omn-20-way-1-shot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/historys/omn-20-way-1-shot.png -------------------------------------------------------------------------------- /historys/omn-5-way-1-shot-test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/historys/omn-5-way-1-shot-test.png 
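The README above describes MAML's two-level training loop: an inner loop that adapts per-task "fast weights" from a small support set, and an outer loop that updates the shared initialization from the query-set losses. The repository's actual implementation lives in `scripts/image_classification/meta_learner.py` and `main.py`, which are not reproduced in full here; the snippet below is only a minimal, self-contained sketch of the same idea in TensorFlow 2, under assumptions that are not taken from this repo (toy sine-fitting tasks instead of image classification, a tiny dense network, a single inner gradient step, and illustrative names such as `sample_sine_task`, `forward`, and `meta_step`).

```python
# Minimal MAML sketch (illustrative only, not the repository's code).
import numpy as np
import tensorflow as tf

def sample_sine_task(k=10):
    """One task: y = A*sin(x + phase); returns (support, query) sets of K points each."""
    A, phase = np.random.uniform(0.1, 5.0), np.random.uniform(0.0, np.pi)
    def draw(n):
        x = np.random.uniform(-5.0, 5.0, size=(n, 1)).astype(np.float32)
        return x, (A * np.sin(x + phase)).astype(np.float32)
    return draw(k), draw(k)

def forward(params, x):
    # Tiny 2-layer MLP applied with an explicit parameter list, so the
    # inner-loop "fast weights" can be plain tensors rather than Variables.
    x = tf.convert_to_tensor(x, tf.float32)
    w1, b1, w2, b2 = params
    h = tf.nn.relu(x @ w1 + b1)
    return h @ w2 + b2

def mse(pred, y):
    return tf.reduce_mean(tf.square(pred - y))

# Meta-parameters: the shared initialization that the outer loop updates.
params = [tf.Variable(tf.random.normal([1, 40], stddev=0.1)),
          tf.Variable(tf.zeros([40])),
          tf.Variable(tf.random.normal([40, 1], stddev=0.1)),
          tf.Variable(tf.zeros([1]))]
meta_opt = tf.keras.optimizers.Adam(1e-3)
inner_lr = 0.01

def meta_step(task_batch):
    with tf.GradientTape() as outer_tape:
        query_losses = []
        for (xs, ys), (xq, yq) in task_batch:
            # Inner loop: one gradient step on the support set -> fast weights.
            with tf.GradientTape() as inner_tape:
                support_loss = mse(forward(params, xs), ys)
            grads = inner_tape.gradient(support_loss, params)
            fast_weights = [p - inner_lr * g for p, g in zip(params, grads)]
            # Outer objective: loss of the adapted weights on the query set.
            query_losses.append(mse(forward(fast_weights, xq), yq))
        meta_loss = tf.reduce_mean(tf.stack(query_losses))
    meta_grads = outer_tape.gradient(meta_loss, params)
    meta_opt.apply_gradients(zip(meta_grads, params))
    return meta_loss

task_batch = [sample_sine_task() for _ in range(4)]  # one meta-batch of 4 tasks
print(float(meta_step(task_batch)))
```

The same structure carries over to the image-classification scripts: `--n_way` and `--k_shot` control how the support set is sampled, `--update_steps` controls how many inner-loop steps produce the fast weights, and the outer update is what one "training step" timed in the README performs.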
-------------------------------------------------------------------------------- /historys/omn-5-way-1-shot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/historys/omn-5-way-1-shot.png -------------------------------------------------------------------------------- /historys/omniglot-20-way-1-shot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/historys/omniglot-20-way-1-shot.png -------------------------------------------------------------------------------- /historys/omniglot-5-way-1-shot-acc-test.txt: -------------------------------------------------------------------------------- 1 | [0.8, 0.2, 0.2, 0.2] 2 | [0.6, 0.6, 0.6, 0.6] 3 | [0.6, 0.4, 0.8, 1.0] 4 | [1.0, 0.6, 0.4, 0.4] 5 | [0.6, 0.8, 0.8, 0.8] 6 | [1.0, 0.6, 0.8, 0.4] 7 | [1.0, 1.0, 0.8, 0.8] 8 | [0.6, 0.6, 0.6, 0.6] 9 | [0.4, 1.0, 0.6, 0.6] 10 | [0.6, 0.8, 0.4, 0.4] 11 | [0.4, 0.6, 0.8, 0.4] 12 | [0.6, 1.0, 0.8, 1.0] 13 | [1.0, 1.0, 0.8, 0.2] 14 | [0.6, 0.6, 0.6, 0.8] 15 | [0.8, 0.8, 0.6, 0.8] 16 | [0.2, 0.8, 0.6, 1.0] 17 | [0.6, 0.8, 0.8, 0.4] 18 | [0.6, 0.8, 0.6, 0.4] 19 | [0.8, 0.6, 0.8, 0.8] 20 | [0.6, 1.0, 0.4, 0.4] 21 | [0.4, 0.6, 1.0, 0.4] 22 | [0.8, 0.4, 0.8, 0.4] 23 | [0.2, 0.6, 0.6, 0.4] 24 | [0.8, 1.0, 1.0, 0.8] 25 | [1.0, 1.0, 0.4, 0.6] 26 | [0.4, 0.8, 0.8, 1.0] 27 | [0.8, 1.0, 0.6, 1.0] 28 | [0.4, 1.0, 0.8, 0.6] 29 | [0.6, 0.6, 1.0, 0.2] 30 | [1.0, 0.2, 0.4, 0.8] 31 | [0.6, 0.4, 0.6, 0.8] 32 | [0.6, 0.4, 0.4, 0.4] 33 | [0.6, 0.4, 0.6, 0.6] 34 | [1.0, 0.6, 0.2, 0.8] 35 | [1.0, 0.8, 0.6, 0.6] 36 | [0.8, 0.8, 0.8, 0.4] 37 | [0.6, 0.8, 0.4, 0.4] 38 | [0.8, 0.4, 0.6, 0.6] 39 | [1.0, 0.8, 1.0, 0.6] 40 | [0.8, 0.8, 0.0, 0.8] 41 | [0.8, 0.6, 0.6, 0.8] 42 | [0.8, 0.8, 0.8, 0.2] 43 | [0.8, 0.4, 0.8, 0.4] 44 | [1.0, 0.8, 0.4, 0.4] 45 | [0.8, 1.0, 0.8, 1.0] 46 | [0.4, 0.8, 0.6, 0.6] 47 | [1.0, 0.6, 0.6, 0.4] 48 | [1.0, 0.4, 1.0, 0.4] 49 | [0.8, 0.8, 0.4, 0.6] 50 | [0.8, 0.6, 1.0, 0.6] 51 | [0.8, 0.8, 0.6, 0.8] 52 | [0.8, 0.6, 0.6, 0.4] 53 | [0.8, 0.6, 0.8, 1.0] 54 | [1.0, 0.8, 0.6, 0.8] 55 | [0.8, 1.0, 0.8, 0.6] 56 | [0.8, 0.4, 0.4, 1.0] 57 | [0.8, 0.6, 0.8, 0.4] 58 | [0.6, 1.0, 0.6, 0.6] 59 | [0.8, 0.6, 0.6, 0.8] 60 | [0.6, 0.8, 0.6, 1.0] 61 | [0.6, 0.8, 1.0, 0.8] 62 | [0.8, 0.8, 0.4, 0.6] 63 | [0.8, 0.4, 0.4, 0.6] 64 | [0.8, 0.4, 0.6, 0.8] 65 | [0.2, 0.6, 0.8, 0.8] 66 | [1.0, 0.6, 0.4, 0.8] 67 | [0.4, 1.0, 0.8, 0.8] 68 | [1.0, 0.4, 0.6, 0.6] 69 | [0.6, 0.4, 0.6, 0.6] 70 | [0.2, 0.8, 0.8, 0.8] 71 | [1.0, 1.0, 0.6, 0.6] 72 | [1.0, 1.0, 0.8, 0.6] 73 | [0.6, 0.4, 0.8, 0.8] 74 | [1.0, 0.8, 0.4, 0.4] 75 | [0.0, 0.8, 1.0, 0.6] 76 | [0.8, 0.8, 1.0, 0.8] 77 | [1.0, 0.8, 0.8, 0.4] 78 | [1.0, 0.6, 0.8, 0.8] 79 | [0.6, 0.8, 0.6, 0.8] 80 | [0.8, 1.0, 0.8, 0.6] 81 | [0.6, 1.0, 0.6, 0.8] 82 | [0.6, 1.0, 0.6, 0.8] 83 | [0.8, 0.8, 0.4, 0.4] 84 | [1.0, 1.0, 0.6, 0.4] 85 | [0.6, 0.2, 0.4, 0.6] 86 | [1.0, 1.0, 0.8, 1.0] 87 | [0.6, 0.6, 0.6, 0.8] 88 | [1.0, 0.6, 0.2, 0.6] 89 | [0.6, 0.8, 1.0, 0.6] 90 | [1.0, 0.6, 1.0, 0.6] 91 | [0.6, 1.0, 1.0, 1.0] 92 | [0.2, 0.4, 0.8, 0.8] 93 | [1.0, 0.6, 0.8, 0.8] 94 | [0.8, 0.8, 0.8, 0.4] 95 | [1.0, 0.8, 0.8, 0.8] 96 | [0.8, 0.8, 0.4, 0.4] 97 | [0.8, 0.6, 0.6, 0.8] 98 | [1.0, 0.8, 0.6, 1.0] 99 | [0.8, 0.8, 0.2, 1.0] 100 | [0.8, 0.8, 0.6, 0.8] 101 | [1.0, 0.4, 1.0, 0.8] 102 | [1.0, 0.8, 0.6, 0.6] 103 | [1.0, 0.6, 0.8, 0.4] 104 | [0.8, 
0.8, 0.4, 0.6] 105 | [0.8, 0.8, 0.8, 0.6] 106 | [1.0, 0.8, 0.6, 0.8] 107 | [0.8, 0.6, 1.0, 0.6] 108 | [1.0, 0.8, 0.8, 0.4] 109 | [1.0, 0.8, 0.6, 1.0] 110 | [1.0, 0.8, 0.8, 1.0] 111 | [0.8, 0.6, 1.0, 0.6] 112 | [0.4, 0.8, 0.8, 0.4] 113 | [0.6, 0.4, 0.8, 0.6] 114 | [0.6, 1.0, 0.6, 0.8] 115 | [0.8, 1.0, 1.0, 0.6] 116 | [1.0, 1.0, 0.6, 0.6] 117 | [0.6, 0.8, 0.8, 0.6] 118 | [1.0, 0.8, 0.8, 0.6] 119 | [0.8, 1.0, 0.8, 0.0] 120 | [0.8, 0.4, 0.8, 1.0] 121 | [1.0, 1.0, 0.8, 0.8] 122 | [0.6, 0.8, 1.0, 0.4] 123 | [1.0, 1.0, 0.6, 0.8] 124 | [1.0, 0.6, 1.0, 1.0] 125 | [1.0, 0.8, 0.8, 1.0] 126 | [1.0, 0.8, 0.4, 0.8] 127 | [0.8, 1.0, 0.8, 0.6] 128 | [0.8, 0.6, 0.6, 0.6] 129 | [0.8, 0.6, 1.0, 1.0] 130 | [0.8, 1.0, 0.4, 1.0] 131 | [0.8, 0.4, 1.0, 0.8] 132 | [0.6, 1.0, 0.6, 1.0] 133 | [1.0, 1.0, 0.6, 0.6] 134 | [1.0, 0.4, 0.6, 0.6] 135 | [0.8, 1.0, 0.6, 0.6] 136 | [1.0, 0.4, 0.8, 0.6] 137 | [0.8, 1.0, 1.0, 0.8] 138 | [0.6, 0.8, 0.8, 0.6] 139 | [0.6, 1.0, 0.2, 0.6] 140 | [0.8, 0.6, 0.8, 0.6] 141 | [1.0, 1.0, 0.8, 0.4] 142 | [0.8, 0.8, 0.6, 0.6] 143 | [1.0, 0.8, 0.4, 0.6] 144 | [0.8, 0.8, 1.0, 1.0] 145 | [0.6, 0.8, 1.0, 0.4] 146 | [0.8, 1.0, 1.0, 0.4] 147 | [0.8, 0.8, 0.4, 0.8] 148 | [1.0, 0.8, 0.6, 0.8] 149 | [1.0, 0.6, 0.8, 0.8] 150 | [1.0, 0.8, 0.8, 0.4] 151 | [1.0, 1.0, 0.8, 0.8] 152 | [0.8, 0.6, 0.4, 1.0] 153 | [0.8, 0.6, 0.4, 0.4] 154 | [1.0, 0.8, 0.6, 0.4] 155 | [1.0, 0.4, 0.6, 1.0] 156 | [0.8, 0.6, 0.6, 0.8] 157 | [0.8, 0.8, 0.8, 0.6] 158 | [0.6, 0.8, 1.0, 0.6] 159 | [0.6, 1.0, 0.8, 0.6] 160 | [1.0, 0.4, 1.0, 0.6] 161 | [1.0, 0.6, 0.6, 0.6] 162 | [1.0, 0.8, 1.0, 0.4] 163 | [1.0, 1.0, 0.6, 0.8] 164 | [1.0, 1.0, 0.2, 0.6] 165 | [0.8, 0.8, 0.6, 0.8] 166 | [0.8, 0.6, 0.6, 0.8] 167 | [0.8, 0.6, 0.8, 0.4] 168 | [1.0, 1.0, 1.0, 0.2] 169 | [0.6, 1.0, 0.6, 1.0] 170 | [0.8, 1.0, 0.6, 0.6] 171 | [0.6, 0.8, 0.6, 0.8] 172 | [0.8, 1.0, 0.8, 0.4] 173 | [0.8, 1.0, 0.8, 1.0] 174 | [0.4, 0.6, 0.8, 0.8] 175 | [0.6, 0.6, 0.8, 0.6] 176 | [0.2, 1.0, 1.0, 0.8] 177 | [1.0, 0.6, 0.8, 0.8] 178 | [0.8, 1.0, 1.0, 0.8] 179 | [1.0, 0.8, 0.8, 1.0] 180 | [0.8, 0.6, 1.0, 0.8] 181 | [0.8, 1.0, 0.6, 0.4] 182 | [0.4, 0.8, 0.6, 1.0] 183 | [1.0, 0.8, 1.0, 0.6] 184 | [0.6, 0.6, 0.6, 0.6] 185 | [0.8, 1.0, 1.0, 0.6] 186 | [1.0, 1.0, 1.0, 0.6] 187 | [1.0, 0.8, 0.6, 1.0] 188 | [0.8, 0.6, 1.0, 1.0] 189 | [0.2, 0.4, 0.6, 0.6] 190 | [0.8, 1.0, 0.8, 1.0] 191 | [1.0, 0.8, 0.8, 1.0] 192 | [1.0, 0.8, 0.8, 0.8] 193 | [1.0, 0.6, 0.6, 0.6] 194 | [0.6, 0.8, 0.4, 0.8] 195 | [0.6, 0.6, 0.8, 0.8] 196 | [1.0, 0.4, 1.0, 0.6] 197 | [0.6, 0.8, 0.8, 0.4] 198 | [1.0, 0.8, 1.0, 0.6] 199 | [0.4, 0.4, 0.8, 0.6] 200 | [1.0, 0.8, 0.6, 0.8] 201 | [1.0, 1.0, 1.0, 0.8] 202 | [1.0, 1.0, 0.8, 1.0] 203 | [0.6, 0.6, 0.0, 0.6] 204 | [0.4, 0.8, 0.6, 0.4] 205 | [0.6, 0.6, 0.6, 0.8] 206 | [0.6, 0.8, 0.6, 0.6] 207 | [1.0, 0.8, 0.4, 1.0] 208 | [0.8, 0.4, 0.8, 0.4] 209 | [0.8, 1.0, 1.0, 0.4] 210 | [1.0, 1.0, 1.0, 0.4] 211 | [0.6, 0.4, 1.0, 0.4] 212 | [0.6, 0.8, 1.0, 1.0] 213 | [1.0, 0.4, 1.0, 0.6] 214 | [0.6, 1.0, 1.0, 0.4] 215 | [0.4, 0.8, 1.0, 1.0] 216 | [0.6, 0.6, 0.6, 0.8] 217 | [1.0, 0.8, 0.8, 0.8] 218 | [1.0, 0.6, 0.8, 1.0] 219 | [0.8, 0.8, 0.6, 0.8] 220 | [1.0, 0.6, 0.6, 0.8] 221 | [0.8, 0.8, 0.8, 0.8] 222 | [1.0, 0.6, 0.8, 1.0] 223 | [1.0, 0.8, 0.8, 0.6] 224 | [0.8, 0.6, 0.6, 0.8] 225 | [1.0, 1.0, 0.4, 0.8] 226 | [1.0, 1.0, 0.8, 0.6] 227 | [1.0, 0.8, 0.8, 1.0] 228 | [1.0, 0.8, 0.8, 0.8] 229 | [0.8, 0.8, 0.8, 0.6] 230 | [0.8, 1.0, 0.2, 0.8] 231 | [1.0, 1.0, 0.8, 0.8] 232 | [0.6, 0.8, 1.0, 1.0] 233 | [1.0, 0.6, 0.6, 0.2] 234 | [0.8, 0.6, 0.4, 1.0] 235 | [0.8, 0.8, 0.6, 1.0] 
236 | [1.0, 0.4, 0.4, 0.4] 237 | [0.8, 1.0, 0.8, 0.4] 238 | [1.0, 1.0, 0.6, 1.0] 239 | [0.8, 0.8, 0.8, 1.0] 240 | [0.6, 0.6, 0.6, 0.8] 241 | [0.8, 1.0, 1.0, 0.6] 242 | [0.6, 0.6, 1.0, 0.6] 243 | [0.6, 0.6, 0.8, 0.8] 244 | [0.8, 0.8, 1.0, 0.4] 245 | [0.8, 0.6, 0.4, 0.6] 246 | [0.8, 1.0, 0.4, 1.0] 247 | [0.8, 0.8, 0.6, 1.0] 248 | [1.0, 0.8, 1.0, 1.0] 249 | [0.8, 0.8, 1.0, 0.8] 250 | [1.0, 1.0, 0.6, 0.8] 251 | [0.8, 0.6, 0.6, 0.8] 252 | [1.0, 0.6, 0.6, 0.8] 253 | [0.8, 1.0, 0.8, 0.6] 254 | [1.0, 0.4, 1.0, 0.8] 255 | [0.8, 1.0, 0.6, 0.8] 256 | [1.0, 0.8, 1.0, 1.0] 257 | [0.8, 0.8, 0.8, 1.0] 258 | [0.4, 0.8, 1.0, 0.8] 259 | [0.6, 0.8, 0.8, 1.0] 260 | [0.8, 0.8, 0.6, 1.0] 261 | [0.6, 1.0, 0.6, 0.6] 262 | [0.6, 0.6, 1.0, 1.0] 263 | [1.0, 0.8, 0.8, 0.6] 264 | [0.8, 0.8, 0.8, 1.0] 265 | [1.0, 0.8, 0.8, 0.8] 266 | [0.8, 1.0, 0.8, 0.2] 267 | [0.8, 1.0, 0.8, 1.0] 268 | [0.8, 0.6, 0.6, 0.8] 269 | [1.0, 0.8, 1.0, 0.4] 270 | [0.8, 1.0, 0.6, 0.8] 271 | [0.6, 0.6, 0.8, 0.8] 272 | [1.0, 1.0, 1.0, 1.0] 273 | [1.0, 0.6, 1.0, 1.0] 274 | [1.0, 0.8, 1.0, 0.8] 275 | [1.0, 0.8, 0.8, 0.8] 276 | [1.0, 0.6, 1.0, 0.6] 277 | [0.8, 1.0, 1.0, 0.8] 278 | [1.0, 0.8, 0.8, 0.4] 279 | [0.8, 0.6, 0.6, 0.8] 280 | [1.0, 1.0, 0.8, 1.0] 281 | [1.0, 0.4, 1.0, 0.8] 282 | [0.8, 0.8, 0.8, 1.0] 283 | [0.6, 0.8, 0.8, 1.0] 284 | [0.8, 0.8, 0.8, 0.6] 285 | [0.6, 1.0, 1.0, 1.0] 286 | [0.6, 1.0, 0.6, 0.8] 287 | [1.0, 1.0, 1.0, 1.0] 288 | [0.8, 0.8, 0.6, 1.0] 289 | [1.0, 0.6, 1.0, 0.4] 290 | [0.8, 1.0, 1.0, 0.8] 291 | [1.0, 0.8, 0.8, 0.8] 292 | [1.0, 0.6, 1.0, 0.6] 293 | [1.0, 1.0, 0.6, 0.6] 294 | [0.8, 0.8, 0.8, 0.8] 295 | [1.0, 0.8, 0.4, 0.8] 296 | [1.0, 1.0, 1.0, 0.6] 297 | [1.0, 1.0, 0.8, 0.4] 298 | [0.6, 1.0, 1.0, 1.0] 299 | [0.8, 1.0, 1.0, 0.6] 300 | [1.0, 0.6, 1.0, 0.6] 301 | [0.6, 0.4, 0.8, 0.8] 302 | [1.0, 0.6, 1.0, 0.8] 303 | [0.6, 0.6, 0.6, 0.6] 304 | [1.0, 1.0, 1.0, 0.8] 305 | [1.0, 1.0, 0.8, 1.0] 306 | [0.8, 1.0, 0.8, 0.8] 307 | [0.8, 0.6, 1.0, 1.0] 308 | [1.0, 0.6, 1.0, 0.8] 309 | [1.0, 0.6, 0.6, 0.6] 310 | [1.0, 0.4, 0.6, 0.4] 311 | [1.0, 1.0, 1.0, 0.8] 312 | [0.8, 1.0, 0.8, 0.8] 313 | [1.0, 1.0, 1.0, 0.6] 314 | [0.8, 0.8, 0.8, 1.0] 315 | [1.0, 0.8, 0.6, 0.8] 316 | [0.6, 0.8, 0.6, 1.0] 317 | [1.0, 1.0, 0.8, 0.8] 318 | [1.0, 0.8, 1.0, 0.8] 319 | [1.0, 0.8, 0.8, 1.0] 320 | [1.0, 0.8, 1.0, 0.6] 321 | [0.8, 0.8, 1.0, 0.6] 322 | [1.0, 1.0, 0.8, 1.0] 323 | [0.8, 0.6, 0.4, 1.0] 324 | [1.0, 0.8, 0.6, 0.8] 325 | [1.0, 0.8, 0.8, 0.6] 326 | [0.6, 1.0, 0.8, 0.8] 327 | [1.0, 1.0, 0.6, 0.6] 328 | [1.0, 0.8, 0.6, 0.4] 329 | [1.0, 0.6, 0.8, 0.8] 330 | [1.0, 1.0, 0.8, 0.6] 331 | [1.0, 0.6, 1.0, 0.6] 332 | [1.0, 1.0, 1.0, 0.6] 333 | [0.8, 0.8, 1.0, 1.0] 334 | [1.0, 1.0, 0.8, 0.8] 335 | [1.0, 0.6, 0.8, 0.8] 336 | [1.0, 0.6, 0.6, 1.0] 337 | [1.0, 0.4, 1.0, 0.6] 338 | [0.8, 1.0, 0.8, 0.8] 339 | [0.8, 0.6, 1.0, 0.4] 340 | [0.6, 0.8, 1.0, 1.0] 341 | [0.8, 0.6, 0.4, 1.0] 342 | [1.0, 0.8, 0.6, 0.8] 343 | [1.0, 1.0, 0.8, 1.0] 344 | [1.0, 0.8, 0.8, 0.6] 345 | [1.0, 1.0, 0.6, 0.6] 346 | [0.8, 1.0, 1.0, 1.0] 347 | [1.0, 1.0, 1.0, 1.0] 348 | [1.0, 1.0, 0.8, 1.0] 349 | [1.0, 0.4, 0.6, 1.0] 350 | [0.8, 1.0, 0.8, 1.0] 351 | [0.8, 0.6, 0.6, 0.4] 352 | [1.0, 1.0, 0.8, 0.8] 353 | [1.0, 0.8, 0.4, 0.6] 354 | [0.8, 0.6, 1.0, 1.0] 355 | [0.8, 1.0, 0.8, 0.8] 356 | [1.0, 0.8, 1.0, 0.6] 357 | [1.0, 1.0, 1.0, 1.0] 358 | [0.8, 0.8, 0.8, 0.8] 359 | [1.0, 1.0, 0.8, 1.0] 360 | [0.8, 1.0, 0.8, 0.8] 361 | [1.0, 1.0, 1.0, 1.0] 362 | [1.0, 1.0, 1.0, 1.0] 363 | [1.0, 0.8, 1.0, 1.0] 364 | [0.8, 1.0, 1.0, 0.8] 365 | [0.8, 0.8, 1.0, 0.8] 366 | [0.6, 1.0, 0.8, 0.8] 367 | [0.8, 0.8, 
0.4, 1.0] 368 | [1.0, 1.0, 1.0, 1.0] 369 | [1.0, 0.8, 0.6, 0.6] 370 | [1.0, 0.8, 1.0, 1.0] 371 | [1.0, 0.8, 1.0, 0.6] 372 | [1.0, 0.8, 0.8, 1.0] 373 | [1.0, 1.0, 0.8, 1.0] 374 | [1.0, 1.0, 1.0, 0.4] 375 | [0.8, 0.8, 0.8, 1.0] 376 | [1.0, 1.0, 0.8, 0.6] 377 | [1.0, 0.8, 1.0, 0.6] 378 | [1.0, 0.8, 1.0, 1.0] 379 | [1.0, 1.0, 1.0, 0.6] 380 | [0.8, 1.0, 0.8, 0.8] 381 | [1.0, 1.0, 0.6, 0.6] 382 | [1.0, 0.8, 0.8, 1.0] 383 | [1.0, 1.0, 1.0, 0.4] 384 | [1.0, 0.8, 1.0, 0.6] 385 | [0.8, 1.0, 0.6, 1.0] 386 | [0.6, 0.8, 1.0, 1.0] 387 | [0.8, 1.0, 0.8, 1.0] 388 | [1.0, 1.0, 0.8, 0.8] 389 | [1.0, 1.0, 1.0, 0.8] 390 | [0.2, 1.0, 0.6, 1.0] 391 | [1.0, 0.8, 0.6, 1.0] 392 | [0.8, 0.8, 0.8, 1.0] 393 | [1.0, 1.0, 1.0, 1.0] 394 | [1.0, 0.4, 1.0, 1.0] 395 | [0.8, 1.0, 0.8, 0.8] 396 | [1.0, 1.0, 1.0, 1.0] 397 | [0.6, 1.0, 0.6, 1.0] 398 | [0.8, 0.8, 1.0, 0.8] 399 | [1.0, 1.0, 0.8, 0.6] 400 | [1.0, 0.8, 0.8, 0.6] 401 | [0.8, 0.6, 0.6, 1.0] 402 | [1.0, 0.8, 1.0, 1.0] 403 | [0.6, 1.0, 0.8, 0.8] 404 | [0.8, 0.8, 0.8, 0.8] 405 | [0.8, 1.0, 0.8, 0.6] 406 | [1.0, 0.8, 0.8, 0.8] 407 | [1.0, 0.4, 0.8, 1.0] 408 | [0.8, 0.4, 1.0, 1.0] 409 | [1.0, 0.8, 0.6, 0.8] 410 | [1.0, 1.0, 1.0, 1.0] 411 | [1.0, 0.8, 0.8, 0.8] 412 | [0.6, 1.0, 1.0, 1.0] 413 | [1.0, 0.4, 1.0, 0.8] 414 | [0.8, 0.6, 0.8, 0.8] 415 | [1.0, 1.0, 0.6, 1.0] 416 | [1.0, 1.0, 0.6, 0.8] 417 | [0.8, 1.0, 0.6, 1.0] 418 | [0.8, 0.8, 1.0, 0.6] 419 | [1.0, 1.0, 0.6, 0.8] 420 | [0.8, 0.8, 1.0, 1.0] 421 | [1.0, 1.0, 0.6, 1.0] 422 | [0.8, 0.6, 1.0, 0.8] 423 | [1.0, 0.6, 0.8, 0.6] 424 | [1.0, 0.6, 0.6, 1.0] 425 | [0.8, 0.6, 0.8, 0.6] 426 | [1.0, 0.6, 0.6, 0.8] 427 | [0.8, 0.8, 0.4, 1.0] 428 | [0.8, 0.8, 0.6, 0.6] 429 | [0.8, 1.0, 1.0, 1.0] 430 | [0.6, 0.8, 0.6, 0.4] 431 | [1.0, 1.0, 1.0, 0.8] 432 | [1.0, 0.8, 0.4, 1.0] 433 | [1.0, 1.0, 0.4, 0.6] 434 | [0.8, 1.0, 0.8, 0.6] 435 | [1.0, 0.8, 1.0, 0.8] 436 | [0.6, 0.6, 1.0, 0.6] 437 | [1.0, 0.4, 1.0, 1.0] 438 | [0.8, 0.6, 1.0, 0.8] 439 | [0.8, 0.8, 0.8, 0.8] 440 | [0.8, 0.8, 0.6, 1.0] 441 | [0.6, 1.0, 1.0, 0.8] 442 | [1.0, 1.0, 1.0, 1.0] 443 | [0.8, 0.8, 1.0, 0.6] 444 | [0.8, 1.0, 0.8, 1.0] 445 | [0.8, 0.8, 0.8, 0.6] 446 | [1.0, 1.0, 0.8, 0.4] 447 | [1.0, 0.8, 0.6, 1.0] 448 | [1.0, 0.8, 1.0, 0.6] 449 | [0.8, 1.0, 0.8, 0.8] 450 | [0.8, 1.0, 0.8, 1.0] 451 | [1.0, 1.0, 0.8, 0.8] 452 | [1.0, 0.8, 0.6, 0.6] 453 | [0.6, 0.6, 1.0, 0.8] 454 | [1.0, 1.0, 0.8, 0.8] 455 | [0.8, 1.0, 1.0, 0.8] 456 | [0.8, 1.0, 0.8, 0.8] 457 | [0.8, 1.0, 0.6, 0.8] 458 | [0.6, 0.8, 1.0, 1.0] 459 | [1.0, 0.8, 0.6, 1.0] 460 | [0.6, 0.8, 1.0, 0.6] 461 | [1.0, 1.0, 0.8, 0.6] 462 | [1.0, 0.8, 0.8, 1.0] 463 | [1.0, 1.0, 1.0, 1.0] 464 | [1.0, 0.6, 0.8, 1.0] 465 | [1.0, 1.0, 1.0, 1.0] 466 | [0.8, 0.8, 0.8, 0.6] 467 | [0.8, 0.8, 0.8, 0.4] 468 | [0.8, 0.6, 0.8, 1.0] 469 | [1.0, 0.8, 0.6, 0.8] 470 | [1.0, 1.0, 0.8, 0.8] 471 | [1.0, 0.8, 1.0, 1.0] 472 | [0.8, 1.0, 1.0, 0.4] 473 | [0.8, 0.6, 0.8, 0.8] 474 | [1.0, 1.0, 1.0, 0.8] 475 | [0.8, 0.8, 1.0, 1.0] 476 | [1.0, 1.0, 0.8, 0.8] 477 | [1.0, 0.8, 0.8, 0.6] 478 | [0.8, 1.0, 0.8, 0.8] 479 | [1.0, 0.8, 1.0, 0.6] 480 | [1.0, 0.8, 1.0, 0.8] 481 | [0.8, 0.8, 0.8, 1.0] 482 | [1.0, 0.6, 0.6, 1.0] 483 | [1.0, 0.6, 0.8, 1.0] 484 | [0.8, 1.0, 1.0, 1.0] 485 | [1.0, 1.0, 1.0, 0.4] 486 | [0.4, 0.6, 0.8, 0.8] 487 | [1.0, 0.6, 1.0, 0.8] 488 | [0.8, 0.8, 0.8, 0.8] 489 | [0.8, 0.8, 0.8, 0.8] 490 | [1.0, 1.0, 0.8, 0.4] 491 | [0.8, 1.0, 0.6, 0.8] 492 | [0.8, 0.8, 1.0, 1.0] 493 | [1.0, 0.8, 0.8, 1.0] 494 | [0.8, 0.8, 1.0, 1.0] 495 | [0.8, 0.8, 0.8, 0.6] 496 | [1.0, 1.0, 1.0, 0.8] 497 | [0.8, 1.0, 0.4, 1.0] 498 | [0.6, 0.8, 1.0, 0.8] 499 | 
[1.0, 1.0, 0.8, 0.6] 500 | [1.0, 0.8, 1.0, 1.0] 501 | [0.8, 1.0, 0.6, 0.6] 502 | [0.8, 0.8, 0.8, 0.6] 503 | [0.8, 1.0, 0.8, 0.8] 504 | [1.0, 0.8, 0.8, 0.8] 505 | [1.0, 0.8, 1.0, 0.4] 506 | [1.0, 0.8, 1.0, 1.0] 507 | [1.0, 1.0, 1.0, 0.6] 508 | [1.0, 0.8, 1.0, 1.0] 509 | [0.8, 0.8, 0.8, 1.0] 510 | [1.0, 0.8, 0.8, 0.8] 511 | [0.8, 0.8, 1.0, 0.8] 512 | [0.8, 1.0, 1.0, 0.8] 513 | [0.8, 1.0, 0.6, 0.4] 514 | [0.8, 0.8, 0.8, 0.8] 515 | [0.8, 0.8, 1.0, 1.0] 516 | [1.0, 1.0, 1.0, 0.4] 517 | [1.0, 1.0, 0.6, 0.8] 518 | [1.0, 1.0, 0.8, 0.8] 519 | [0.8, 0.8, 0.8, 0.8] 520 | [1.0, 1.0, 0.8, 1.0] 521 | [0.8, 0.6, 0.8, 0.6] 522 | [0.8, 0.8, 0.8, 0.8] 523 | [1.0, 0.8, 0.6, 0.8] 524 | [0.8, 0.6, 1.0, 0.8] 525 | [0.6, 1.0, 0.6, 0.8] 526 | [1.0, 0.8, 1.0, 1.0] 527 | [1.0, 0.6, 0.8, 1.0] 528 | [0.8, 0.8, 0.6, 1.0] 529 | [0.8, 0.8, 1.0, 1.0] 530 | [0.8, 0.6, 1.0, 0.6] 531 | [1.0, 1.0, 1.0, 0.8] 532 | [0.8, 1.0, 0.8, 1.0] 533 | [1.0, 0.8, 0.6, 0.6] 534 | [0.8, 1.0, 1.0, 0.6] 535 | [0.8, 0.8, 0.8, 0.8] 536 | [0.6, 1.0, 0.6, 0.6] 537 | [1.0, 0.8, 0.8, 1.0] 538 | [0.8, 1.0, 0.6, 1.0] 539 | [0.8, 1.0, 0.8, 0.8] 540 | [0.8, 0.8, 0.4, 0.8] 541 | [1.0, 1.0, 0.8, 1.0] 542 | [0.8, 1.0, 0.8, 0.8] 543 | [0.8, 1.0, 0.8, 0.4] 544 | [0.6, 0.8, 1.0, 0.8] 545 | [1.0, 1.0, 0.8, 1.0] 546 | [1.0, 1.0, 1.0, 0.8] 547 | [1.0, 1.0, 1.0, 0.8] 548 | [0.6, 0.8, 0.4, 0.4] 549 | [1.0, 0.8, 0.8, 1.0] 550 | [1.0, 1.0, 1.0, 0.8] 551 | [0.8, 1.0, 1.0, 1.0] 552 | [0.8, 1.0, 0.8, 1.0] 553 | [1.0, 0.8, 1.0, 0.6] 554 | [0.8, 1.0, 0.8, 1.0] 555 | [1.0, 0.8, 0.4, 0.6] 556 | [1.0, 1.0, 1.0, 0.8] 557 | [1.0, 0.8, 0.8, 0.8] 558 | [1.0, 1.0, 1.0, 1.0] 559 | [0.6, 1.0, 0.8, 0.6] 560 | [1.0, 0.6, 0.6, 0.8] 561 | [1.0, 1.0, 0.4, 1.0] 562 | [1.0, 0.8, 0.8, 0.6] 563 | [0.8, 0.8, 0.6, 0.8] 564 | [1.0, 1.0, 1.0, 0.8] 565 | [1.0, 0.6, 0.6, 1.0] 566 | [1.0, 1.0, 1.0, 0.8] 567 | [1.0, 0.8, 1.0, 1.0] 568 | [1.0, 1.0, 0.8, 0.8] 569 | [1.0, 1.0, 1.0, 1.0] 570 | [1.0, 0.8, 0.8, 0.8] 571 | [1.0, 0.8, 0.4, 0.6] 572 | [1.0, 1.0, 0.8, 0.8] 573 | [1.0, 1.0, 0.6, 1.0] 574 | [0.8, 0.8, 1.0, 0.6] 575 | [0.8, 1.0, 0.4, 0.8] 576 | [1.0, 0.6, 1.0, 1.0] 577 | [1.0, 0.8, 1.0, 1.0] 578 | [1.0, 1.0, 1.0, 0.6] 579 | [0.8, 0.8, 1.0, 0.8] 580 | [1.0, 0.8, 0.6, 0.8] 581 | [1.0, 1.0, 0.8, 0.8] 582 | [1.0, 0.8, 1.0, 1.0] 583 | [1.0, 0.8, 0.8, 0.6] 584 | [0.8, 1.0, 1.0, 0.8] 585 | [0.8, 1.0, 0.4, 1.0] 586 | [1.0, 0.8, 1.0, 1.0] 587 | [1.0, 0.8, 1.0, 0.8] 588 | [1.0, 0.8, 0.4, 0.8] 589 | [0.8, 0.6, 0.6, 1.0] 590 | [1.0, 0.8, 0.8, 0.6] 591 | [0.6, 0.6, 0.8, 0.6] 592 | [1.0, 1.0, 1.0, 0.8] 593 | [0.8, 0.6, 0.8, 0.8] 594 | [1.0, 1.0, 0.8, 0.8] 595 | [1.0, 1.0, 1.0, 1.0] 596 | [1.0, 0.4, 1.0, 0.6] 597 | [0.8, 0.8, 0.8, 1.0] 598 | [1.0, 1.0, 1.0, 1.0] 599 | [0.8, 1.0, 1.0, 0.6] 600 | [0.6, 0.6, 0.6, 0.6] 601 | [0.8, 0.8, 0.8, 0.8] 602 | [1.0, 1.0, 0.8, 1.0] 603 | [1.0, 1.0, 0.6, 0.8] 604 | [0.8, 0.8, 0.8, 0.8] 605 | [1.0, 1.0, 1.0, 0.8] 606 | [1.0, 1.0, 0.4, 0.8] 607 | [1.0, 0.4, 0.6, 0.8] 608 | [0.6, 1.0, 0.6, 0.4] 609 | [0.8, 0.6, 0.8, 0.8] 610 | [1.0, 0.8, 1.0, 1.0] 611 | [1.0, 0.8, 1.0, 0.8] 612 | [0.8, 1.0, 1.0, 1.0] 613 | [1.0, 0.8, 1.0, 0.6] 614 | [1.0, 0.8, 1.0, 0.8] 615 | [0.8, 1.0, 0.4, 0.8] 616 | [1.0, 1.0, 1.0, 0.8] 617 | [0.8, 1.0, 0.8, 0.6] 618 | [0.8, 0.8, 0.8, 0.8] 619 | [0.8, 0.8, 1.0, 0.6] 620 | [0.8, 1.0, 1.0, 0.8] 621 | [1.0, 0.8, 1.0, 0.8] 622 | [0.8, 1.0, 0.8, 0.8] 623 | [1.0, 1.0, 1.0, 1.0] 624 | [1.0, 0.6, 0.8, 0.8] 625 | [1.0, 0.8, 0.8, 0.6] 626 | [1.0, 0.8, 0.6, 0.8] 627 | [1.0, 1.0, 0.6, 1.0] 628 | [0.8, 1.0, 0.8, 0.6] 629 | [1.0, 1.0, 0.8, 0.6] 630 | [1.0, 1.0, 0.6, 
0.4] 631 | [1.0, 1.0, 1.0, 0.6] 632 | [1.0, 0.8, 0.8, 0.8] 633 | [1.0, 0.6, 0.6, 0.8] 634 | [0.8, 0.8, 1.0, 0.4] 635 | [0.8, 0.8, 1.0, 1.0] 636 | [1.0, 0.8, 1.0, 0.8] 637 | [0.8, 0.8, 0.8, 0.6] 638 | [1.0, 0.6, 1.0, 1.0] 639 | [1.0, 0.8, 0.8, 1.0] 640 | [1.0, 1.0, 1.0, 1.0] 641 | [1.0, 1.0, 1.0, 0.6] 642 | [1.0, 0.8, 0.6, 0.8] 643 | [0.8, 0.8, 0.8, 0.8] 644 | [0.8, 1.0, 1.0, 0.6] 645 | [0.8, 0.8, 1.0, 0.8] 646 | [1.0, 0.8, 0.8, 0.6] 647 | [1.0, 0.8, 0.8, 1.0] 648 | [1.0, 1.0, 0.8, 0.8] 649 | [1.0, 1.0, 0.8, 0.8] 650 | [0.8, 1.0, 0.6, 0.8] 651 | [1.0, 0.8, 1.0, 0.8] 652 | [1.0, 1.0, 1.0, 0.6] 653 | [1.0, 0.8, 0.8, 0.6] 654 | [1.0, 1.0, 0.8, 1.0] 655 | [1.0, 1.0, 0.8, 0.6] 656 | [1.0, 1.0, 0.8, 1.0] 657 | [1.0, 1.0, 0.6, 1.0] 658 | [0.6, 1.0, 0.8, 1.0] 659 | [1.0, 1.0, 0.8, 0.6] 660 | [1.0, 1.0, 0.6, 1.0] 661 | [1.0, 1.0, 1.0, 0.8] 662 | [1.0, 1.0, 0.8, 0.6] 663 | [1.0, 0.8, 0.6, 0.8] 664 | [0.6, 1.0, 1.0, 1.0] 665 | [1.0, 0.8, 0.8, 0.8] 666 | [1.0, 0.8, 1.0, 0.8] 667 | [1.0, 1.0, 0.6, 0.8] 668 | [1.0, 0.8, 1.0, 1.0] 669 | [1.0, 0.8, 1.0, 0.8] 670 | [1.0, 1.0, 0.6, 0.8] 671 | [1.0, 1.0, 0.8, 0.8] 672 | [1.0, 0.8, 1.0, 1.0] 673 | [0.8, 0.6, 0.8, 0.8] 674 | [0.8, 1.0, 1.0, 0.8] 675 | [1.0, 1.0, 1.0, 1.0] 676 | [1.0, 0.8, 0.8, 0.8] 677 | [0.8, 1.0, 0.8, 0.8] 678 | [1.0, 1.0, 1.0, 0.8] 679 | [0.6, 1.0, 1.0, 1.0] 680 | [1.0, 0.8, 0.6, 0.8] 681 | [1.0, 0.8, 1.0, 0.6] 682 | [1.0, 1.0, 0.8, 0.8] 683 | [0.6, 0.4, 0.4, 0.8] 684 | [0.8, 1.0, 0.8, 0.8] 685 | [1.0, 0.8, 0.8, 0.4] 686 | [1.0, 0.8, 1.0, 1.0] 687 | [1.0, 0.8, 0.8, 0.8] 688 | [0.8, 0.8, 0.8, 0.6] 689 | [1.0, 0.8, 0.8, 0.8] 690 | [0.8, 1.0, 0.6, 1.0] 691 | [1.0, 1.0, 0.8, 0.8] 692 | [1.0, 0.8, 0.8, 0.8] 693 | [1.0, 0.8, 0.8, 0.8] 694 | [0.8, 1.0, 0.8, 1.0] 695 | [1.0, 1.0, 1.0, 0.6] 696 | [1.0, 0.6, 0.8, 1.0] 697 | [1.0, 0.8, 1.0, 1.0] 698 | [0.8, 1.0, 0.8, 1.0] 699 | [1.0, 1.0, 1.0, 1.0] 700 | [1.0, 0.6, 0.6, 0.6] 701 | [0.8, 1.0, 0.8, 1.0] 702 | [1.0, 1.0, 0.8, 0.6] 703 | [0.8, 1.0, 0.8, 0.8] 704 | [1.0, 0.8, 0.8, 0.8] 705 | [1.0, 0.8, 1.0, 1.0] 706 | [1.0, 1.0, 1.0, 0.6] 707 | [0.6, 1.0, 1.0, 0.6] 708 | [1.0, 1.0, 1.0, 0.8] 709 | [0.6, 1.0, 0.8, 1.0] 710 | [1.0, 1.0, 0.8, 0.4] 711 | [1.0, 1.0, 1.0, 0.8] 712 | [1.0, 1.0, 0.8, 0.8] 713 | [0.6, 0.4, 0.8, 0.8] 714 | [0.6, 1.0, 0.8, 1.0] 715 | [1.0, 1.0, 1.0, 0.8] 716 | [1.0, 1.0, 1.0, 1.0] 717 | [0.8, 0.8, 1.0, 1.0] 718 | [1.0, 1.0, 0.6, 1.0] 719 | [0.8, 0.6, 0.8, 0.8] 720 | [0.8, 0.8, 0.8, 0.8] 721 | [0.8, 1.0, 0.8, 1.0] 722 | [0.8, 1.0, 1.0, 0.8] 723 | [0.8, 1.0, 1.0, 1.0] 724 | [1.0, 0.8, 1.0, 1.0] 725 | [0.8, 1.0, 0.6, 1.0] 726 | [0.6, 1.0, 1.0, 1.0] 727 | [1.0, 0.8, 0.6, 0.8] 728 | [1.0, 0.8, 1.0, 0.8] 729 | [0.8, 0.8, 0.6, 0.6] 730 | [0.8, 0.8, 1.0, 0.6] 731 | [0.8, 0.8, 0.8, 0.6] 732 | [0.6, 1.0, 0.6, 1.0] 733 | [0.8, 0.8, 0.6, 1.0] 734 | [0.8, 1.0, 1.0, 0.8] 735 | [1.0, 1.0, 0.8, 1.0] 736 | [0.8, 1.0, 1.0, 0.8] 737 | [1.0, 1.0, 0.8, 1.0] 738 | [0.6, 0.8, 0.6, 0.6] 739 | [0.8, 1.0, 0.8, 1.0] 740 | [1.0, 0.8, 0.8, 0.6] 741 | [1.0, 0.8, 1.0, 1.0] 742 | [0.8, 0.8, 1.0, 0.6] 743 | [1.0, 1.0, 1.0, 1.0] 744 | [1.0, 1.0, 1.0, 1.0] 745 | [0.8, 0.8, 1.0, 0.8] 746 | [1.0, 1.0, 1.0, 0.6] 747 | [1.0, 0.8, 0.8, 1.0] 748 | [0.4, 0.4, 1.0, 0.6] 749 | [1.0, 1.0, 0.8, 1.0] 750 | [1.0, 1.0, 0.8, 0.8] 751 | [0.8, 0.4, 0.8, 0.6] 752 | [0.8, 0.8, 0.8, 0.6] 753 | [0.6, 1.0, 1.0, 0.6] 754 | [0.8, 1.0, 0.8, 0.8] 755 | [1.0, 0.4, 1.0, 1.0] 756 | [1.0, 0.6, 0.8, 1.0] 757 | [1.0, 0.8, 1.0, 1.0] 758 | [1.0, 1.0, 0.8, 0.8] 759 | [1.0, 1.0, 0.6, 0.4] 760 | [1.0, 0.8, 1.0, 0.8] 761 | [0.8, 0.6, 0.6, 1.0] 762 | [0.6, 
0.8, 0.8, 0.8] 763 | [0.8, 1.0, 0.8, 1.0] 764 | [0.8, 0.8, 0.8, 1.0] 765 | [1.0, 1.0, 0.8, 0.8] 766 | [1.0, 1.0, 0.8, 0.8] 767 | [0.8, 1.0, 0.8, 0.6] 768 | [0.6, 0.8, 0.6, 0.8] 769 | [1.0, 0.8, 0.8, 0.8] 770 | [1.0, 1.0, 1.0, 0.8] 771 | [1.0, 1.0, 1.0, 1.0] 772 | [0.8, 1.0, 0.6, 0.8] 773 | [1.0, 1.0, 1.0, 1.0] 774 | [1.0, 1.0, 0.8, 0.8] 775 | [0.8, 0.8, 0.8, 1.0] 776 | [1.0, 1.0, 1.0, 1.0] 777 | [1.0, 0.8, 0.8, 0.8] 778 | [1.0, 1.0, 0.8, 0.8] 779 | [1.0, 1.0, 1.0, 0.8] 780 | [0.8, 0.8, 1.0, 0.8] 781 | [0.6, 0.8, 0.8, 0.8] 782 | [0.8, 1.0, 1.0, 1.0] 783 | [0.8, 1.0, 0.8, 0.8] 784 | [0.8, 1.0, 1.0, 0.4] 785 | [1.0, 1.0, 0.6, 0.4] 786 | [0.8, 1.0, 0.8, 1.0] 787 | [1.0, 1.0, 1.0, 1.0] 788 | [1.0, 1.0, 0.6, 0.8] 789 | [0.8, 0.6, 0.8, 0.8] 790 | [1.0, 0.8, 0.8, 1.0] 791 | [1.0, 1.0, 1.0, 1.0] 792 | [0.8, 1.0, 0.8, 1.0] 793 | [1.0, 0.6, 1.0, 0.6] 794 | [0.8, 1.0, 0.6, 0.8] 795 | [1.0, 1.0, 0.8, 0.8] 796 | [1.0, 0.8, 1.0, 0.8] 797 | [1.0, 0.6, 1.0, 1.0] 798 | [1.0, 1.0, 1.0, 0.8] 799 | [1.0, 1.0, 1.0, 0.8] 800 | [1.0, 0.8, 1.0, 0.8] 801 | [1.0, 0.8, 1.0, 1.0] 802 | [0.8, 1.0, 1.0, 0.6] 803 | [0.8, 1.0, 0.8, 0.6] 804 | [1.0, 1.0, 0.8, 0.6] 805 | [0.8, 1.0, 0.8, 0.8] 806 | [1.0, 1.0, 0.6, 1.0] 807 | [1.0, 1.0, 1.0, 1.0] 808 | [1.0, 0.8, 1.0, 0.6] 809 | [0.8, 0.6, 0.6, 1.0] 810 | [1.0, 0.4, 1.0, 0.6] 811 | [1.0, 1.0, 1.0, 1.0] 812 | [1.0, 1.0, 1.0, 0.8] 813 | [0.8, 0.8, 0.8, 1.0] 814 | [0.8, 1.0, 0.8, 0.8] 815 | [0.8, 1.0, 0.6, 1.0] 816 | [1.0, 0.8, 1.0, 0.4] 817 | [1.0, 1.0, 0.8, 0.8] 818 | [1.0, 0.8, 0.8, 0.6] 819 | [1.0, 0.8, 0.8, 1.0] 820 | [0.8, 1.0, 0.8, 1.0] 821 | [1.0, 1.0, 1.0, 1.0] 822 | [0.8, 0.8, 0.8, 1.0] 823 | [1.0, 1.0, 0.8, 0.6] 824 | [1.0, 1.0, 1.0, 1.0] 825 | [1.0, 0.6, 0.8, 0.8] 826 | [1.0, 1.0, 0.8, 0.8] 827 | [1.0, 1.0, 0.8, 0.6] 828 | [1.0, 1.0, 1.0, 1.0] 829 | [1.0, 0.8, 0.8, 0.6] 830 | [1.0, 1.0, 0.4, 1.0] 831 | [1.0, 1.0, 0.6, 1.0] 832 | [1.0, 1.0, 0.8, 0.8] 833 | [0.8, 0.6, 0.8, 1.0] 834 | [1.0, 0.6, 1.0, 0.6] 835 | [0.8, 0.8, 1.0, 0.6] 836 | [1.0, 0.6, 1.0, 1.0] 837 | [1.0, 1.0, 1.0, 1.0] 838 | [1.0, 1.0, 0.8, 1.0] 839 | [1.0, 0.8, 1.0, 1.0] 840 | [1.0, 0.6, 1.0, 1.0] 841 | [1.0, 1.0, 1.0, 0.6] 842 | [1.0, 1.0, 1.0, 0.8] 843 | [0.8, 0.6, 0.8, 0.6] 844 | [1.0, 0.8, 1.0, 0.6] 845 | [0.8, 1.0, 0.6, 0.8] 846 | [1.0, 0.4, 1.0, 1.0] 847 | [0.8, 1.0, 0.4, 0.8] 848 | [1.0, 1.0, 1.0, 1.0] 849 | [1.0, 0.8, 1.0, 0.6] 850 | [1.0, 1.0, 0.6, 0.6] 851 | [1.0, 0.8, 0.6, 1.0] 852 | [1.0, 1.0, 0.8, 0.4] 853 | [1.0, 1.0, 1.0, 0.8] 854 | [1.0, 0.8, 1.0, 0.6] 855 | [1.0, 0.6, 0.8, 0.8] 856 | [1.0, 0.8, 0.8, 1.0] 857 | [0.8, 1.0, 1.0, 1.0] 858 | [1.0, 0.8, 0.6, 1.0] 859 | [1.0, 0.8, 1.0, 0.6] 860 | [1.0, 1.0, 1.0, 1.0] 861 | [1.0, 0.8, 0.8, 0.8] 862 | [1.0, 1.0, 0.6, 0.6] 863 | [0.8, 0.8, 0.8, 0.6] 864 | [1.0, 0.6, 1.0, 1.0] 865 | [1.0, 1.0, 0.8, 0.8] 866 | [1.0, 1.0, 0.8, 0.6] 867 | [1.0, 0.8, 0.8, 0.6] 868 | [1.0, 1.0, 1.0, 1.0] 869 | [1.0, 1.0, 0.8, 0.8] 870 | [1.0, 0.8, 1.0, 0.8] 871 | [0.8, 0.8, 1.0, 1.0] 872 | [0.8, 1.0, 1.0, 0.6] 873 | [1.0, 0.8, 0.8, 1.0] 874 | [1.0, 0.6, 1.0, 0.8] 875 | [0.8, 1.0, 0.6, 0.8] 876 | [1.0, 0.8, 0.8, 0.8] 877 | [1.0, 0.8, 1.0, 0.8] 878 | [0.8, 0.8, 0.6, 1.0] 879 | [1.0, 0.8, 0.6, 0.8] 880 | [1.0, 0.8, 0.8, 0.8] 881 | [1.0, 1.0, 0.6, 0.6] 882 | [1.0, 1.0, 1.0, 0.8] 883 | [0.8, 1.0, 0.8, 0.8] 884 | [1.0, 1.0, 0.6, 0.8] 885 | [1.0, 1.0, 0.8, 0.4] 886 | [0.8, 1.0, 0.4, 0.8] 887 | [0.8, 1.0, 0.6, 1.0] 888 | [1.0, 0.8, 1.0, 0.8] 889 | [1.0, 0.8, 1.0, 1.0] 890 | [1.0, 0.8, 0.8, 0.6] 891 | [0.8, 0.8, 1.0, 0.8] 892 | [0.8, 0.8, 1.0, 0.6] 893 | [1.0, 0.6, 0.6, 0.6] 
894 | [1.0, 1.0, 1.0, 0.6] 895 | [0.8, 1.0, 0.6, 1.0] 896 | [0.8, 1.0, 1.0, 0.8] 897 | [1.0, 0.8, 0.8, 0.2] 898 | [1.0, 1.0, 0.2, 1.0] 899 | [0.8, 0.6, 1.0, 1.0] 900 | [1.0, 0.8, 0.8, 1.0] 901 | [0.8, 1.0, 0.8, 0.8] 902 | [0.8, 1.0, 0.6, 1.0] 903 | [0.8, 1.0, 1.0, 0.8] 904 | [1.0, 0.8, 0.6, 1.0] 905 | [1.0, 1.0, 1.0, 1.0] 906 | [0.8, 0.6, 0.6, 0.8] 907 | [0.8, 0.8, 0.6, 1.0] 908 | [1.0, 0.8, 1.0, 0.8] 909 | [0.6, 0.8, 1.0, 1.0] 910 | [1.0, 1.0, 1.0, 0.8] 911 | [1.0, 1.0, 1.0, 0.6] 912 | [1.0, 1.0, 1.0, 1.0] 913 | [1.0, 1.0, 0.0, 0.6] 914 | [0.8, 1.0, 0.8, 1.0] 915 | [1.0, 1.0, 0.8, 1.0] 916 | [1.0, 0.8, 0.8, 1.0] 917 | [0.8, 0.8, 0.8, 0.8] 918 | [0.6, 1.0, 0.6, 0.6] 919 | [1.0, 1.0, 0.8, 0.8] 920 | [1.0, 0.8, 1.0, 0.6] 921 | [0.8, 0.8, 0.8, 1.0] 922 | [1.0, 0.8, 0.8, 0.6] 923 | [0.8, 0.8, 0.8, 0.8] 924 | [1.0, 0.6, 0.6, 0.6] 925 | [0.6, 0.8, 0.8, 1.0] 926 | [1.0, 1.0, 0.8, 0.8] 927 | [1.0, 0.8, 0.8, 0.8] 928 | [0.6, 0.6, 0.8, 0.8] 929 | [1.0, 0.8, 0.4, 1.0] 930 | [0.8, 0.6, 0.6, 0.8] 931 | [0.6, 1.0, 0.8, 0.4] 932 | [1.0, 0.8, 1.0, 0.6] 933 | [1.0, 1.0, 0.6, 1.0] 934 | [0.8, 1.0, 1.0, 1.0] 935 | [1.0, 0.8, 1.0, 1.0] 936 | [0.8, 0.8, 1.0, 1.0] 937 | [1.0, 0.8, 0.8, 1.0] 938 | [0.8, 1.0, 0.8, 0.8] 939 | [0.8, 0.8, 1.0, 0.8] 940 | [1.0, 1.0, 0.6, 1.0] 941 | [1.0, 0.8, 1.0, 0.6] 942 | [0.8, 1.0, 0.6, 0.8] 943 | [0.8, 1.0, 0.8, 1.0] 944 | [0.8, 0.8, 0.8, 0.8] 945 | [1.0, 0.8, 1.0, 0.8] 946 | [0.8, 0.8, 0.6, 0.8] 947 | [1.0, 1.0, 1.0, 1.0] 948 | [1.0, 1.0, 0.8, 0.6] 949 | [1.0, 1.0, 0.8, 0.8] 950 | [1.0, 1.0, 0.4, 0.8] 951 | [1.0, 1.0, 0.8, 0.8] 952 | [1.0, 1.0, 0.6, 0.8] 953 | [1.0, 0.6, 0.8, 0.8] 954 | [0.8, 1.0, 1.0, 1.0] 955 | [1.0, 1.0, 1.0, 1.0] 956 | [1.0, 0.8, 0.8, 0.4] 957 | [0.8, 0.8, 0.8, 1.0] 958 | [0.8, 0.8, 1.0, 1.0] 959 | [0.8, 1.0, 0.4, 0.8] 960 | [1.0, 1.0, 1.0, 1.0] 961 | [1.0, 0.6, 1.0, 0.8] 962 | [1.0, 1.0, 1.0, 1.0] 963 | [1.0, 1.0, 0.8, 1.0] 964 | [1.0, 1.0, 1.0, 1.0] 965 | [1.0, 0.8, 0.8, 1.0] 966 | [1.0, 1.0, 0.8, 0.4] 967 | [1.0, 0.8, 1.0, 1.0] 968 | [1.0, 1.0, 0.8, 0.6] 969 | [1.0, 0.8, 0.6, 0.8] 970 | [1.0, 0.8, 0.8, 0.8] 971 | [1.0, 0.8, 1.0, 0.8] 972 | [1.0, 1.0, 1.0, 1.0] 973 | [0.8, 1.0, 1.0, 1.0] 974 | [1.0, 0.8, 0.8, 1.0] 975 | [1.0, 0.8, 1.0, 0.8] 976 | [1.0, 1.0, 0.8, 0.8] 977 | [1.0, 0.8, 0.6, 1.0] 978 | [1.0, 0.8, 1.0, 0.8] 979 | [1.0, 1.0, 0.8, 0.8] 980 | [1.0, 0.8, 0.8, 1.0] 981 | [1.0, 0.8, 0.8, 0.8] 982 | [1.0, 0.8, 0.8, 1.0] 983 | [1.0, 0.8, 1.0, 0.8] 984 | [1.0, 0.8, 1.0, 0.8] 985 | [1.0, 0.6, 0.8, 0.8] 986 | [1.0, 1.0, 0.8, 1.0] 987 | [1.0, 0.8, 0.6, 0.4] 988 | [1.0, 0.8, 1.0, 1.0] 989 | [1.0, 0.6, 0.8, 0.8] 990 | [1.0, 0.6, 1.0, 1.0] 991 | [0.8, 1.0, 0.8, 1.0] 992 | [1.0, 1.0, 0.8, 0.8] 993 | [1.0, 1.0, 0.8, 1.0] 994 | [1.0, 1.0, 0.8, 1.0] 995 | [1.0, 0.4, 1.0, 0.0] 996 | [1.0, 0.8, 0.6, 1.0] 997 | [1.0, 1.0, 0.4, 1.0] 998 | [1.0, 1.0, 0.8, 1.0] 999 | [1.0, 0.8, 1.0, 0.6] 1000 | [0.6, 0.8, 1.0, 0.8] 1001 | [1.0, 1.0, 1.0, 0.6] 1002 | [1.0, 0.6, 1.0, 1.0] 1003 | [1.0, 1.0, 0.2, 0.6] 1004 | [1.0, 1.0, 1.0, 0.8] 1005 | [1.0, 0.8, 0.8, 1.0] 1006 | [1.0, 1.0, 0.8, 1.0] 1007 | [0.6, 0.8, 1.0, 0.6] 1008 | [1.0, 0.8, 0.8, 0.6] 1009 | [0.8, 0.4, 0.6, 0.6] 1010 | [0.8, 1.0, 1.0, 1.0] 1011 | [1.0, 0.8, 0.8, 0.8] 1012 | [1.0, 1.0, 1.0, 1.0] 1013 | [0.8, 0.8, 0.8, 0.4] 1014 | [1.0, 1.0, 0.8, 0.8] 1015 | [1.0, 1.0, 0.8, 1.0] 1016 | [1.0, 1.0, 1.0, 1.0] 1017 | [1.0, 0.8, 0.8, 0.8] 1018 | [0.8, 0.8, 1.0, 1.0] 1019 | [1.0, 1.0, 0.6, 1.0] 1020 | [0.8, 1.0, 1.0, 1.0] 1021 | [0.8, 1.0, 1.0, 1.0] 1022 | [1.0, 1.0, 1.0, 0.8] 1023 | [1.0, 1.0, 1.0, 1.0] 1024 | [0.6, 0.8, 
1.0, 1.0] 1025 | [1.0, 1.0, 1.0, 0.6] 1026 | [0.8, 1.0, 1.0, 1.0] 1027 | [1.0, 0.8, 1.0, 1.0] 1028 | [1.0, 1.0, 0.8, 0.8] 1029 | [1.0, 1.0, 1.0, 1.0] 1030 | [1.0, 0.6, 0.6, 1.0] 1031 | [1.0, 0.8, 1.0, 1.0] 1032 | [1.0, 1.0, 0.8, 0.8] 1033 | [1.0, 1.0, 1.0, 0.8] 1034 | [1.0, 1.0, 1.0, 1.0] 1035 | [0.8, 1.0, 0.8, 0.8] 1036 | [0.8, 0.8, 1.0, 0.8] 1037 | [1.0, 0.8, 0.4, 0.6] 1038 | [1.0, 1.0, 1.0, 1.0] 1039 | [0.8, 0.8, 1.0, 1.0] 1040 | [0.8, 1.0, 1.0, 0.6] 1041 | [1.0, 1.0, 1.0, 1.0] 1042 | [1.0, 1.0, 0.8, 0.8] 1043 | [1.0, 1.0, 1.0, 0.8] 1044 | [1.0, 0.8, 0.8, 1.0] 1045 | [1.0, 0.6, 1.0, 0.4] 1046 | [0.8, 1.0, 0.8, 1.0] 1047 | [0.6, 0.8, 1.0, 0.6] 1048 | [1.0, 0.8, 1.0, 0.6] 1049 | [0.8, 1.0, 1.0, 0.4] 1050 | [1.0, 1.0, 0.6, 1.0] 1051 | [1.0, 1.0, 1.0, 1.0] 1052 | [1.0, 0.8, 0.6, 0.6] 1053 | [1.0, 1.0, 0.8, 1.0] 1054 | [1.0, 0.8, 1.0, 0.8] 1055 | [1.0, 1.0, 1.0, 0.4] 1056 | [1.0, 0.8, 1.0, 0.6] 1057 | [1.0, 0.6, 1.0, 0.6] 1058 | [1.0, 1.0, 0.6, 0.8] 1059 | [0.8, 0.8, 0.6, 0.6] 1060 | [1.0, 0.6, 1.0, 0.8] 1061 | [0.8, 1.0, 0.4, 0.8] 1062 | [1.0, 0.8, 0.8, 0.8] 1063 | [1.0, 0.8, 0.8, 0.6] 1064 | [1.0, 1.0, 1.0, 1.0] 1065 | [1.0, 1.0, 0.8, 0.8] 1066 | [0.8, 1.0, 1.0, 1.0] 1067 | [1.0, 1.0, 1.0, 1.0] 1068 | [1.0, 0.6, 0.8, 1.0] 1069 | [0.8, 0.6, 1.0, 0.8] 1070 | [1.0, 0.8, 0.8, 1.0] 1071 | [0.6, 0.8, 1.0, 0.8] 1072 | [1.0, 0.8, 0.8, 0.8] 1073 | [0.8, 0.8, 0.8, 0.8] 1074 | [0.8, 1.0, 1.0, 0.6] 1075 | [0.8, 0.8, 1.0, 0.6] 1076 | [1.0, 1.0, 1.0, 0.4] 1077 | [1.0, 1.0, 1.0, 0.6] 1078 | [0.8, 1.0, 0.8, 0.8] 1079 | [1.0, 0.8, 1.0, 0.8] 1080 | [0.8, 0.8, 0.4, 0.8] 1081 | [0.8, 1.0, 1.0, 0.4] 1082 | [0.8, 1.0, 1.0, 1.0] 1083 | [1.0, 0.6, 0.8, 0.8] 1084 | [1.0, 0.6, 1.0, 1.0] 1085 | [1.0, 0.8, 0.8, 1.0] 1086 | [1.0, 0.8, 1.0, 1.0] 1087 | [1.0, 1.0, 1.0, 0.6] 1088 | [1.0, 0.8, 1.0, 1.0] 1089 | [1.0, 1.0, 0.8, 0.8] 1090 | [1.0, 0.8, 1.0, 1.0] 1091 | [1.0, 0.4, 0.8, 1.0] 1092 | [0.6, 1.0, 0.8, 0.8] 1093 | [1.0, 0.4, 0.8, 0.8] 1094 | [0.8, 1.0, 1.0, 1.0] 1095 | [0.8, 1.0, 1.0, 0.8] 1096 | [1.0, 1.0, 1.0, 0.8] 1097 | [1.0, 0.6, 0.8, 0.8] 1098 | [0.8, 0.8, 1.0, 1.0] 1099 | [1.0, 1.0, 0.8, 1.0] 1100 | [1.0, 1.0, 0.6, 0.8] 1101 | [0.6, 1.0, 1.0, 0.8] 1102 | [1.0, 0.8, 1.0, 1.0] 1103 | [0.6, 0.6, 0.4, 1.0] 1104 | [0.8, 0.8, 0.8, 0.6] 1105 | [1.0, 1.0, 1.0, 1.0] 1106 | [1.0, 0.8, 0.8, 0.6] 1107 | [1.0, 0.8, 1.0, 1.0] 1108 | [1.0, 1.0, 1.0, 1.0] 1109 | [1.0, 0.6, 0.8, 0.8] 1110 | [1.0, 1.0, 0.8, 0.8] 1111 | [1.0, 1.0, 1.0, 1.0] 1112 | [1.0, 1.0, 0.8, 1.0] 1113 | [0.8, 1.0, 0.8, 1.0] 1114 | [0.4, 0.8, 1.0, 0.6] 1115 | [0.2, 0.8, 1.0, 0.4] 1116 | [1.0, 0.6, 1.0, 0.8] 1117 | [1.0, 1.0, 1.0, 0.8] 1118 | [0.8, 0.8, 1.0, 0.8] 1119 | [1.0, 1.0, 0.8, 1.0] 1120 | [1.0, 1.0, 0.6, 0.4] 1121 | [1.0, 0.4, 1.0, 1.0] 1122 | [0.6, 0.6, 1.0, 0.6] 1123 | [0.6, 0.8, 1.0, 1.0] 1124 | [1.0, 1.0, 0.6, 0.8] 1125 | [0.8, 1.0, 1.0, 0.8] 1126 | [1.0, 0.8, 1.0, 1.0] 1127 | [0.8, 1.0, 0.8, 1.0] 1128 | [0.6, 1.0, 0.6, 0.6] 1129 | [0.6, 1.0, 0.4, 1.0] 1130 | [1.0, 0.8, 1.0, 1.0] 1131 | [0.6, 1.0, 0.8, 0.4] 1132 | [1.0, 0.8, 0.6, 1.0] 1133 | [1.0, 1.0, 1.0, 0.8] 1134 | [1.0, 1.0, 0.6, 0.8] 1135 | [0.8, 1.0, 0.8, 1.0] 1136 | [1.0, 1.0, 0.4, 1.0] 1137 | [1.0, 1.0, 0.6, 1.0] 1138 | [1.0, 0.8, 1.0, 0.6] 1139 | [1.0, 1.0, 0.8, 0.6] 1140 | [1.0, 0.8, 1.0, 1.0] 1141 | [0.8, 0.8, 0.8, 1.0] 1142 | [0.8, 1.0, 1.0, 0.6] 1143 | [1.0, 0.8, 0.8, 0.8] 1144 | [0.8, 1.0, 1.0, 1.0] 1145 | [1.0, 0.6, 0.8, 0.6] 1146 | [1.0, 1.0, 0.8, 1.0] 1147 | [1.0, 1.0, 0.8, 1.0] 1148 | [1.0, 0.8, 1.0, 0.8] 1149 | [1.0, 0.8, 0.6, 0.6] 1150 | [1.0, 1.0, 0.8, 1.0] 1151 | [1.0, 
0.8, 1.0, 0.6] 1152 | [1.0, 0.8, 1.0, 1.0] 1153 | [1.0, 1.0, 0.6, 1.0] 1154 | [1.0, 1.0, 0.8, 1.0] 1155 | [1.0, 1.0, 0.8, 0.4] 1156 | [0.8, 0.8, 0.8, 1.0] 1157 | [0.8, 0.8, 0.4, 1.0] 1158 | [0.8, 1.0, 1.0, 1.0] 1159 | [0.6, 1.0, 1.0, 1.0] 1160 | [1.0, 0.8, 1.0, 0.6] 1161 | [1.0, 0.6, 0.6, 0.8] 1162 | [0.6, 0.8, 1.0, 0.6] 1163 | [0.8, 0.4, 1.0, 0.6] 1164 | [1.0, 0.8, 0.8, 0.8] 1165 | [1.0, 0.8, 0.8, 0.8] 1166 | [1.0, 0.8, 0.2, 0.8] 1167 | [1.0, 0.8, 1.0, 0.8] 1168 | [1.0, 1.0, 0.6, 0.4] 1169 | [1.0, 0.6, 1.0, 0.8] 1170 | [0.8, 0.6, 0.8, 1.0] 1171 | [0.6, 1.0, 1.0, 0.8] 1172 | [0.8, 1.0, 0.2, 0.8] 1173 | [0.8, 1.0, 0.8, 1.0] 1174 | [1.0, 0.6, 1.0, 0.6] 1175 | [1.0, 1.0, 1.0, 0.8] 1176 | [0.6, 0.8, 0.8, 0.4] 1177 | [1.0, 0.8, 1.0, 0.8] 1178 | [1.0, 0.8, 0.6, 1.0] 1179 | [0.8, 0.8, 0.6, 0.8] 1180 | [0.8, 0.8, 0.4, 1.0] 1181 | [1.0, 0.8, 0.6, 1.0] 1182 | [1.0, 0.8, 0.4, 0.8] 1183 | [1.0, 1.0, 0.4, 1.0] 1184 | [1.0, 0.8, 0.8, 1.0] 1185 | [1.0, 0.8, 1.0, 0.8] 1186 | [0.4, 1.0, 1.0, 1.0] 1187 | [1.0, 1.0, 0.8, 1.0] 1188 | [1.0, 1.0, 0.8, 1.0] 1189 | [0.8, 1.0, 1.0, 1.0] 1190 | [0.8, 0.8, 0.4, 0.8] 1191 | [0.8, 1.0, 1.0, 0.6] 1192 | [1.0, 0.8, 0.8, 1.0] 1193 | [1.0, 0.8, 1.0, 0.8] 1194 | [1.0, 0.8, 0.8, 0.8] 1195 | [1.0, 0.6, 0.6, 0.8] 1196 | [0.8, 1.0, 1.0, 0.8] 1197 | [1.0, 1.0, 0.8, 0.8] 1198 | [1.0, 0.8, 1.0, 1.0] 1199 | [1.0, 0.6, 0.8, 0.8] 1200 | [1.0, 0.8, 0.8, 1.0] 1201 | -------------------------------------------------------------------------------- /historys/omniglot-5-way-1-shot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/historys/omniglot-5-way-1-shot.png -------------------------------------------------------------------------------- /outputs/avg_to_nn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/avg_to_nn.png -------------------------------------------------------------------------------- /outputs/eval_sine_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/eval_sine_1.png -------------------------------------------------------------------------------- /outputs/eval_sine_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/eval_sine_2.png -------------------------------------------------------------------------------- /outputs/eval_sine_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/eval_sine_3.png -------------------------------------------------------------------------------- /outputs/maml_test_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/maml_test_1.png -------------------------------------------------------------------------------- /outputs/maml_test_2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/maml_test_2.png -------------------------------------------------------------------------------- /outputs/maml_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/maml_train.png -------------------------------------------------------------------------------- /outputs/meta_sine_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/meta_sine_1.png -------------------------------------------------------------------------------- /outputs/meta_sine_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/meta_sine_2.png -------------------------------------------------------------------------------- /outputs/meta_sine_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/meta_sine_3.png -------------------------------------------------------------------------------- /outputs/meta_test_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/meta_test_1.png -------------------------------------------------------------------------------- /outputs/meta_test_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/meta_test_2.png -------------------------------------------------------------------------------- /outputs/meta_test_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/meta_test_3.png -------------------------------------------------------------------------------- /outputs/meta_to_avg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/meta_to_avg.png -------------------------------------------------------------------------------- /outputs/meta_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/meta_train.png -------------------------------------------------------------------------------- /outputs/meta_train_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/meta_train_1.png -------------------------------------------------------------------------------- /outputs/miniimagenet 5-way-1-shot.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/miniimagenet 5-way-1-shot.png -------------------------------------------------------------------------------- /outputs/miniimagenet_5w1s.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/miniimagenet_5w1s.png -------------------------------------------------------------------------------- /outputs/miniimagenet_5w5s.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/miniimagenet_5w5s.png -------------------------------------------------------------------------------- /outputs/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/model.png -------------------------------------------------------------------------------- /outputs/omniglot_20w1s.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/omniglot_20w1s.png -------------------------------------------------------------------------------- /outputs/omniglot_5w1s.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/omniglot_5w1s.png -------------------------------------------------------------------------------- /outputs/sinemodel_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/outputs/sinemodel_loss.png -------------------------------------------------------------------------------- /scripts/image_classification/history_vis.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Date: 6th Mar 2020 3 | Author: HilbertXu 4 | Abstract: Code for visualizing the training history and smoothing the curves 5 | ''' 6 | import os 7 | import sys 8 | import argparse 9 | import numpy as np 10 | import scipy.signal as signal 11 | import matplotlib.pyplot as plt 12 | 13 | plt.rcParams['font.sans-serif']=['SimSun'] # use a font that can render CJK labels 14 | plt.rcParams['axes.unicode_minus']=False # render minus signs correctly with this font 15 | 16 | 17 | def read_file(file_name): 18 | ''' 19 | :param file_name: History file to be read 20 | :return A list 21 | ''' 22 | if args.mode == 'train': 23 | file_data = [] 24 | with open(file_name, 'r') as f: 25 | for line in f: 26 | data = line[:-1] 27 | data = float(data) 28 | data = round(data, 2) 29 | file_data.append(data) 30 | return file_data 31 | elif args.mode == 'test': 32 | file_data = [] 33 | with open(file_name, 'r') as f: 34 | for line in f: 35 | data = line[:-1] 36 | file_data.append(data) 37 | return file_data 38 | 39 | def data_preprocess(data): 40 | _data = [] 41 | for line in data: 42 | line = line[1:-1] 43 | line = line.split(',') 44 | line = [float(num) for num in line] 45 | line = sorted(line) 46 | # line = line[1:] 47 | # line = np.mean(line) 48 | line = max(line) 49 | print (line) 50 | _data.append(line) 51 | return _data 52 | 53 | 54 | def smooth(data): 55 |
# tmp = scipy.signal.savgol_filter(data, 53, 3) 56 | tmp = signal.savgol_filter(data, 49, 3) 57 | return tmp 58 | 59 | def plot_figure(loss, smooth_loss, acc, smooth_acc): 60 | fig = plt.figure(dpi=128, figsize=(10,6)) 61 | plt.plot(loss, color='coral', alpha=0.2, label='training loss') 62 | plt.plot(smooth_loss,color='coral', label='smoothed training loss') 63 | plt.plot(acc, color='royalblue', alpha=0.2, label='training accuracy') 64 | plt.plot(smooth_acc, color='royalblue', label='smoothed training accuracy') 65 | plt.legend(loc='upper right') 66 | plt.title('{} dataset {}-way {}-shot few-shot classification: {} curves'.format(dataset, n_way, k_shot, 'training')) 67 | plt.xlabel('meta batches', fontsize=16) 68 | plt.ylabel('', fontsize=16) 69 | # plt.tick_params(axis='both', which='major', labelsize=16) 70 | plt.show() 71 | 72 | if __name__ == '__main__': 73 | argparse = argparse.ArgumentParser() 74 | # Dataset options 75 | argparse.add_argument('--dataset', type=str, help='Dataset miniimagenet or omniglot', default='miniimagenet') 76 | # Task options 77 | argparse.add_argument('--mode', type=str, help='Train process or test process', default='train') 78 | argparse.add_argument('--n_way', type=int, help='N-way', default=5) 79 | argparse.add_argument('--k_shot', type=int, help='K-shot', default=1) 80 | argparse.add_argument('--his_dir', type=str, help='Path to the training history directory', default='../../historys') 81 | # Generate args 82 | args = argparse.parse_args() 83 | 84 | dataset = args.dataset 85 | n_way = args.n_way 86 | k_shot = args.k_shot 87 | os.chdir(args.his_dir) 88 | if args.mode == 'train': 89 | loss = read_file('{}-{}-way-{}-shot-train.txt'.format(dataset, n_way, k_shot)) 90 | acc = read_file('{}-{}-way-{}-shot-acc.txt'.format(dataset, n_way, k_shot)) 91 | # calculate mean and std of the last 1000 iterations 92 | acc_mean = np.mean(acc[-1000:]) 93 | acc_std = np.std(acc[-1000:]) 94 | print (acc_mean, acc_std) 95 | 96 | elif args.mode == 'test': 97 | loss = read_file('{}-{}-way-{}-shot-loss-test.txt'.format(dataset, n_way, k_shot)) 98 | acc = read_file('{}-{}-way-{}-shot-acc-test.txt'.format(dataset, n_way, k_shot)) 99 | # pre-process 100 | loss = data_preprocess(loss) 101 | acc = data_preprocess(acc) 102 | # calculate mean and std of the last 200 iterations 103 | 104 | acc_mean = np.mean(acc[-200:]) 105 | acc_std = np.std(acc[-200:]) 106 | print (acc_mean, acc_std) 107 | 108 | 109 | 110 | 111 | smooth_loss = smooth(loss) 112 | smooth_acc = smooth(acc) 113 | 114 | plot_figure(loss, smooth_loss, acc, smooth_acc) -------------------------------------------------------------------------------- /scripts/image_classification/image_preprocess.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Date: 1st Feb 2020 3 | Author: HilbertXu 4 | Abstract: Code for preprocessing dataset images 5 | ''' 6 | from __future__ import print_function 7 | import csv 8 | import glob 9 | import os 10 | import sys 11 | import random 12 | import argparse 13 | import numpy as np 14 | from tqdm import tqdm 15 | from tqdm._tqdm import trange 16 | from PIL import Image 17 | import tensorflow as tf 18 | import cv2 19 | 20 | class ImageProc: 21 | def __init__(self, dataset): 22 | if dataset == 'miniimagenet': 23 | print ('Processing MiniImagenet dataset') 24 | self.path_to_image = '../../dataset/miniImagenet/' 25 | all_images = glob.glob(self.path_to_image + '/images/*') 26 | # Resize images 27 | with tqdm(total=len(all_images)) as pbar: 28 | for i, image_file in enumerate(all_images): 29 | img =
Image.open(image_file) 30 | img = img.resize((84,84), resample=Image.LANCZOS) 31 | img.save(image_file) 32 | if i % 500 == 0 and i > 0: 33 | pbar.set_description('{} images processed'.format(i)) 34 | pbar.update(500) 35 | # self.set_dir() 36 | elif dataset == 'omniglot': 37 | print ('Processing Omniglot dataset') 38 | self.root = '../../dataset/omniglot' 39 | character_folders = [ 40 | os.path.join(self.root, family, character) \ 41 | for family in os.listdir(self.root) \ 42 | if os.path.isdir(os.path.join(self.root, family)) \ 43 | for character in os.listdir(os.path.join(self.root, family)) 44 | ] 45 | for character in character_folders: 46 | print ('Currently processing {}'.format(character)) 47 | images = os.listdir(character) 48 | for image in images: 49 | image_file = os.path.join(character, image) 50 | img = Image.open(image_file) 51 | img = img.resize((28,28), resample=Image.LANCZOS) 52 | img.save(image_file) 53 | 54 | def set_dir(self): 55 | os.chdir(self.path_to_image) 56 | for datatype in ['train', 'test', 'val']: 57 | if os.path.exists(datatype) is False: 58 | print ('create /{} directories'.format(datatype)) 59 | os.system('mkdir {}'.format(datatype)) 60 | else: 61 | print ('Directories /{} already exist'.format(datatype)) 62 | count = len(open(datatype + '.csv', 'r').readlines()) 63 | with open(datatype + '.csv', 'r') as csvfile: 64 | print ('Reading {}.csv, {} images in total'.format(datatype, count-1)) 65 | reader = csv.reader(csvfile, delimiter=',') 66 | last_label = '' 67 | with tqdm(total=count) as pbar: 68 | for i, row in enumerate(reader): 69 | if i == 0: # skip the headers 70 | continue 71 | image_name = row[0] 72 | label = row[1] 73 | # Set up a folder for a new class 74 | if label != last_label: 75 | label_dir = datatype + '/' + label + '/' 76 | os.system('mkdir -p {}'.format(label_dir)) 77 | last_label = label 78 | os.system('mv images/' + image_name+ ' ' + label_dir) 79 | 80 | if i % 400 == 0 and i > 0: 81 | pbar.set_description('{} {} images moved'.format(datatype, i)) 82 | pbar.update(500) 83 | 84 | if __name__ == '__main__': 85 | argparse = argparse.ArgumentParser() 86 | # Dataset options 87 | argparse.add_argument('--dataset', type=str, help='Dataset to be processed', default='miniimagenet') 88 | # Generate args 89 | args = argparse.parse_args() 90 | proc = ImageProc(args.dataset) 91 | -------------------------------------------------------------------------------- /scripts/image_classification/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Date: Feb 11st 2020 3 | Author: Hilbert XU 4 | Abstract: Training process and functions 5 | """ 6 | # -*- coding: UTF-8 -*- 7 | import os 8 | import cv2 9 | import sys 10 | import random 11 | import datetime 12 | import numpy as np 13 | import argparse 14 | import tensorflow as tf 15 | import time 16 | import matplotlib.pyplot as plt 17 | from task_generator import TaskGenerator 18 | from meta_learner import MetaLearner 19 | 20 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 21 | 22 | def write_histogram(model, writer, step): 23 | ''' 24 | :param model: A model 25 | :param writer: tf.summary writer 26 | :param step: Current training step 27 | ''' 28 | with writer.as_default(): 29 | for idx, layer in enumerate(model.layers): 30 | if 'conv' in layer.name or 'dense' in layer.name: 31 | tf.summary.histogram(layer.name+':kernel', layer.kernel, step=step) 32 | tf.summary.histogram(layer.name+':bias', layer.bias, step=step) 33 | if 'batch_normalization' in layer.name: 34 | 
tf.summary.histogram(layer.name+':gamma', layer.gamma, step=step) 35 | tf.summary.histogram(layer.name+':beta', layer.beta, step=step) 36 | 37 | 38 | def write_gradient(grads, writer, step, with_bn=True): 39 | ''' 40 | :param grads: Gradients on query set 41 | :param writer: tf.summary writer 42 | :param step: Current training step 43 | ''' 44 | if with_bn: 45 | name = [ 46 | 'conv_0:kernel_grad', 'conv_0:bias_grad', 'batch_normalization_1:gamma_grad', 'batch_normalization_1:beta_grad', 47 | 'conv_1:kernel_grad', 'conv_1:bias_grad', 'batch_normalization_2:gamma_grad', 'batch_normalization_2:beta_grad', 48 | 'conv_2:kernel_grad', 'conv_2:bias_grad', 'batch_normalization_3:gamma_grad', 'batch_normalization_3:beta_grad', 49 | 'conv_3:kernel_grad', 'conv_3:bias_grad', 'batch_normalization_4:gamma_grad', 'batch_normalization_4:beta_grad', 50 | 'dense:kernel_grad', 'dense:bias_grad' 51 | ] 52 | with writer.as_default(): 53 | for idx, grad in enumerate(grads): 54 | tf.summary.histogram(name[idx], grad, step=step) 55 | elif with_bn is False: 56 | name = [ 57 | 'conv_0:kernel_grad', 'conv_0:bias_grad', 'conv_1:kernel_grad', 'conv_1:bias_grad', 58 | 'conv_2:kernel_grad', 'conv_2:bias_grad', 'conv_3:kernel_grad', 'conv_3:bias_grad', 59 | 'dense:kernel_grad', 'dense:bias_grad' 60 | ] 61 | with writer.as_default(): 62 | for idx, grad in enumerate(grads): 63 | tf.summary.histogram(name[idx], grad, step=step) 64 | 65 | 66 | def restore_model(model, weights_dir): 67 | ''' 68 | :param model: Model to be restored 69 | :param weights_dir: Path to weights 70 | 71 | :return: model with trained weights 72 | ''' 73 | print ('Relod weights from: {}'.format(weights_dir)) 74 | ckpt = tf.train.Checkpoint(maml_model=model) 75 | latest_weights = tf.train.latest_checkpoint(weights_dir) 76 | ckpt.restore(latest_weights) 77 | return model 78 | 79 | def copy_model(model, x): 80 | ''' 81 | :param model: model to be copied 82 | :param x: a set of data, used to build the copied model 83 | 84 | :return copied model 85 | ''' 86 | copied_model = MetaLearner() 87 | copied_model(x) 88 | copied_model.set_weights(model.get_weights()) 89 | return copied_model 90 | 91 | def loss_fn(y, pred_y): 92 | ''' 93 | :param pred_y: Prediction output of model 94 | :param y: Ground truth 95 | 96 | :return loss value: 97 | ''' 98 | return tf.reduce_mean(tf.losses.categorical_crossentropy(y, pred_y)) 99 | 100 | def accuracy_fn(y, pred_y): 101 | ''' 102 | :param pred_y: Prediction output of model 103 | :param y: Ground truth 104 | 105 | :return accuracy value: 106 | ''' 107 | accuracy = tf.keras.metrics.Accuracy() 108 | _ = accuracy.update_state(tf.argmax(pred_y, axis=1), tf.argmax(y, axis=1)) 109 | return accuracy.result() 110 | 111 | def compute_loss(model, x, y, loss_fn=loss_fn): 112 | ''' 113 | :param model: A neural net 114 | :param x: Train data 115 | :param y: Groud truth 116 | :param loss_fn: Loss function used to compute loss value 117 | 118 | :return Loss value 119 | ''' 120 | _, pred_y = model(x) 121 | loss = loss_fn(y, pred_y) 122 | return loss, pred_y 123 | 124 | def compute_gradients(model, x, y, loss_fn=loss_fn): 125 | ''' 126 | :param model: Neural network 127 | :param x: Input tensor 128 | :param y: Ground truth of input tensor 129 | :param loss_fn: loss function 130 | 131 | :return Gradient tensor 132 | ''' 133 | with tf.GradientTape() as tape: 134 | _, pred = model(x) 135 | loss = loss_fn(y, pred) 136 | grads = tape.gradient(loss, model.trainable_variables) 137 | return grads 138 | 139 | def apply_gradients(optimizer, gradients, 
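
For reference, the `tf.summary` pattern used by `write_histogram()` and `write_gradient()` above boils down to: create a file writer once, then log named tensors inside `with writer.as_default():` at a given step. A minimal sketch with a hypothetical log directory and placeholder tag names:

```python
# Sketch only: the tf.summary logging pattern used above.
# '/tmp/maml_demo_logs' and the tag names are placeholders.
import tensorflow as tf

writer = tf.summary.create_file_writer('/tmp/maml_demo_logs')
values = tf.random.normal([1000])

with writer.as_default():
    tf.summary.histogram('demo:kernel', values, step=0)
    tf.summary.scalar('demo:loss', tf.reduce_mean(values), step=0)
writer.flush()
```
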
variables): 140 | ''' 141 | :param optimizer: optimizer, Adam for task-level update, SGD for meta level update 142 | :param gradients: gradients 143 | :param variables: trainable variables of model 144 | 145 | :return None 146 | ''' 147 | optimizer.apply_gradients(zip(gradients, variables)) 148 | 149 | def regular_train_step(model, x, y, optimizer): 150 | gradients = compute_gradients(model, x, y, loss_fn=loss_fn) 151 | apply_gradients(optimizer, gradients, model.trainable_variables) 152 | return model 153 | 154 | def maml_train(model, batch_generator): 155 | # Set parameters 156 | visual = args.visual 157 | n_way = args.n_way 158 | k_shot = args.k_shot 159 | total_batches = args.total_batches 160 | meta_batchsz = args.meta_batchsz 161 | update_steps = args.update_steps 162 | update_steps_test = args.update_steps_test 163 | test_steps = args.test_steps 164 | ckpt_steps = args.ckpt_steps 165 | print_steps = args.print_steps 166 | inner_lr = args.inner_lr 167 | meta_lr = args.meta_lr 168 | ckpt_dir = args.ckpt_dir + args.dataset+'/{}way{}shot/'.format(n_way, k_shot) 169 | print ('Start training process of {}-way {}-shot {}-query problem'.format(args.n_way, args.k_shot, args.k_query)) 170 | print ('{} steps, inner_lr: {}, meta_lr:{}, meta_batchsz:{}'.format(total_batches, inner_lr, meta_lr, meta_batchsz)) 171 | 172 | # Initialize Tensorboard writer 173 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 174 | log_dir = args.log_dir + args.dataset +'/{}way{}shot/'.format(n_way, k_shot) + current_time 175 | summary_writer = tf.summary.create_file_writer(log_dir) 176 | 177 | # Meta optimizer for update model parameters 178 | meta_optimizer = tf.keras.optimizers.Adam(learning_rate=args.meta_lr, name='meta_optimizer') 179 | 180 | # Initialize Checkpoint handle 181 | checkpoint = tf.train.Checkpoint(maml_model=model) 182 | losses = [] 183 | accs = [] 184 | test_losses = [] 185 | test_accs = [] 186 | 187 | test_min_losses = [] 188 | test_max_accs = [] 189 | 190 | def _maml_finetune_step(test_set): 191 | # Set up recorders for test batch 192 | batch_loss = [0 for _ in range(meta_batchsz)] 193 | batch_acc = [0 for _ in range(meta_batchsz)] 194 | # Set up copied models 195 | copied_model = MetaLearner.hard_copy(model, args) 196 | for idx, task in enumerate(test_set): 197 | # Slice task to support set and query set 198 | support_x, support_y, query_x, query_y = task 199 | # Update fast weights several times 200 | for i in range(update_steps_test): 201 | # Set up inner gradient tape, watch the copied_model.inner_weights 202 | with tf.GradientTape(watch_accessed_variables=False) as inner_tape: 203 | # we only want inner tape watch the fast weights in each update steps 204 | inner_tape.watch(copied_model.inner_weights) 205 | inner_loss, _ = compute_loss(copied_model, support_x, support_y) 206 | inner_grads = inner_tape.gradient(inner_loss, copied_model.inner_weights) 207 | copied_model = MetaLearner.meta_update(copied_model, args, alpha=inner_lr, grads=inner_grads) 208 | # Compute task loss & accuracy on the query set 209 | task_loss, task_pred = compute_loss(copied_model, query_x, query_y, loss_fn=loss_fn) 210 | task_acc = accuracy_fn(query_y, task_pred) 211 | batch_loss[idx] += task_loss 212 | batch_acc[idx] += task_acc 213 | 214 | # Delete copied_model for saving memory 215 | del copied_model 216 | 217 | return batch_loss, batch_acc 218 | 219 | # Define the maml train step 220 | def _maml_train_step(batch_set): 221 | # Set up recorders for every batch 222 | batch_loss = [0 for _ in 
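
The inner tapes in `_maml_finetune_step` (and in `_maml_train_step` below) are created with `watch_accessed_variables=False` plus an explicit `tape.watch(copied_model.inner_weights)` because the fast weights are plain `tf.Tensor` objects rather than `tf.Variable`s, so the tape would not track them automatically. A minimal sketch of that mechanism on toy tensors:

```python
# Sketch only: plain tensors must be watched explicitly to get gradients.
import tensorflow as tf

w = tf.constant([2.0, -1.0])   # stands in for a fast weight (tensor, not Variable)
x = tf.constant([1.0, 3.0])

with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(w)                          # without this, the gradient below is None
    loss = tf.reduce_sum((w * x) ** 2)

print(tape.gradient(loss, w).numpy())      # -> [  4. -18.]
```
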
range(meta_batchsz)] 223 | batch_acc = [0 for _ in range(meta_batchsz)] 224 | # Set up outer gradient tape, only watch model.trainable_variables 225 | # Because GradientTape only auto record tranable_variables of model 226 | # But the copied_model.inner_weights is tf.Tensor, so they won't be automatically watched 227 | with tf.GradientTape() as outer_tape: 228 | # Use the average loss over all tasks in one batch to compute gradients 229 | for idx, task in enumerate(batch_set): 230 | # Set up copied model 231 | copied_model = model 232 | # Slice task to support set and query set 233 | support_x, support_y, query_x, query_y = task 234 | if visual: 235 | with summary_writer.as_default(): 236 | tf.summary.image('Support Images', support_x, max_outputs=5, step=step) 237 | tf.summary.image('Query Images', query_x, max_outputs=5, step=step) 238 | # Update fast weights several times 239 | for i in range(update_steps): 240 | # Set up inner gradient tape, watch the copied_model.inner_weights 241 | with tf.GradientTape(watch_accessed_variables=False) as inner_tape: 242 | # we only want inner tape watch the fast weights in each update steps 243 | inner_tape.watch(copied_model.inner_weights) 244 | inner_loss, _ = compute_loss(copied_model, support_x, support_y) 245 | inner_grads = inner_tape.gradient(inner_loss, copied_model.inner_weights) 246 | copied_model = MetaLearner.meta_update(copied_model, args, alpha=inner_lr, grads=inner_grads) 247 | # Compute task loss & accuracy on the query set 248 | task_loss, task_pred = compute_loss(copied_model, query_x, query_y, loss_fn=loss_fn) 249 | task_acc = accuracy_fn(query_y, task_pred) 250 | batch_loss[idx] += task_loss 251 | batch_acc[idx] += task_acc 252 | # Compute mean loss of the whole batch 253 | mean_loss = tf.reduce_mean(batch_loss) 254 | # Compute second order gradients 255 | outer_grads = outer_tape.gradient(mean_loss, model.trainable_variables) 256 | apply_gradients(meta_optimizer, outer_grads, model.trainable_variables) 257 | if visual: 258 | # Write gradients histogram 259 | write_gradient(outer_grads, summary_writer, step) 260 | # Return reslut of one maml train step 261 | return batch_loss, batch_acc 262 | 263 | # Main loop 264 | start = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 265 | print ('Start at {}'.format(start)) 266 | # For each epoch update model total_batches times 267 | start = time.time() 268 | for step in range(total_batches+1): 269 | # Get a batch data 270 | batch_set = batch_generator.train_batch() 271 | # batch_generator.print_label_map() 272 | # Run maml train step 273 | batch_loss, batch_acc = _maml_train_step(batch_set) 274 | if visual: 275 | # Write histogram 276 | write_histogram(model, summary_writer, step) 277 | # Record Loss 278 | losses.append(tf.reduce_mean(batch_loss).numpy()) 279 | accs.append(tf.reduce_mean(batch_acc).numpy()) 280 | # Write to Tensorboard 281 | with summary_writer.as_default(): 282 | tf.summary.scalar('query loss', tf.reduce_mean(batch_loss), step=step) 283 | tf.summary.scalar('query accuracy', tf.reduce_mean(batch_acc), step=step) 284 | 285 | # Print train result 286 | if step % print_steps == 0 and step > 0: 287 | batch_loss = [loss.numpy() for loss in batch_loss] 288 | batch_acc = [acc.numpy() for acc in batch_acc] 289 | print ('[STEP. 
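
Stripped of TensorFlow details, `_maml_train_step` implements the usual bi-level MAML update: an inner SGD step per task on the support loss, then one meta step on the mean query loss. A toy NumPy sketch on synthetic 1-D regression tasks (first-order for brevity; the code above keeps the second-order term by differentiating through the inner update with nested tapes):

```python
# Sketch only: MAML's loop structure on toy 1-D linear regression tasks.
import numpy as np

rng = np.random.default_rng(0)
alpha, beta, theta = 0.01, 0.001, 0.0          # inner lr, meta lr, meta parameter

def sample_task():
    slope = rng.uniform(-2.0, 2.0)             # each task is a different slope
    x_s, x_q = rng.normal(size=5), rng.normal(size=15)
    return (x_s, slope * x_s), (x_q, slope * x_q)

def grad(theta, x, y):                         # d/dtheta of mean (theta*x - y)^2
    return np.mean(2.0 * (theta * x - y) * x)

for step in range(1000):
    outer_grads = []
    for _ in range(4):                         # meta_batchsz = 4 tasks per batch
        (xs, ys), (xq, yq) = sample_task()
        fast_theta = theta - alpha * grad(theta, xs, ys)    # inner update
        outer_grads.append(grad(fast_theta, xq, yq))        # query-set gradient
    theta -= beta * np.mean(outer_grads)       # meta update
```
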
{}] Task Losses: {}; Task Accuracies: {}; Time to run {} Steps: {}'.format(step, batch_loss, batch_acc, print_steps, time.time()-start)) 290 | start = time.time() 291 | # Uncomment to see the sampled folders of each task 292 | # train_ds.print_label_map() 293 | 294 | # Save checkpoint 295 | if step % ckpt_steps == 0 and step > 0: 296 | checkpoint.save(ckpt_dir+'maml_model.ckpt') 297 | 298 | # Evaluating model 299 | if step % test_steps == 0 and step > 0: 300 | test_set = batch_generator.test_batch() 301 | batch_generator.print_label_map() 302 | test_loss, test_acc = _maml_finetune_step(test_set) 303 | with summary_writer.as_default(): 304 | tf.summary.scalar('test loss', tf.reduce_mean(test_loss), step=step) 305 | tf.summary.scalar('test accuracy', tf.reduce_mean(test_acc), step=step) 306 | # Tensor to list 307 | test_loss = [loss.numpy() for loss in test_loss] 308 | test_acc = [acc.numpy() for acc in test_acc] 309 | # Record test history 310 | test_losses.append(test_loss) 311 | test_accs.append(test_acc) 312 | print ('Test Losses: {}, Test Accuracys: {}'.format(test_loss, test_acc)) 313 | print ('=====================================================================') 314 | # Meta train step 315 | 316 | # Record training history 317 | os.chdir(args.his_dir) 318 | losses_plot, = plt.plot(losses, label = "Train Acccuracy", color='coral') 319 | accs_plot, = plt.plot(accs,'--',label = "Train Loss", color='royalblue') 320 | # accs_plot = plt.plot(accs, '--',color='blue') 321 | plt.legend([losses_plot, accs_plot], ['Train Loss', 'Train Accuracy']) 322 | plt.title('{} {}-Way {}-Shot MAML Training Process'.format(args.dataset, n_way, k_shot)) 323 | plt.savefig('{}-{}-way-{}-shot.png'.format(args.dataset, n_way, k_shot)) 324 | 325 | train_hist = '{}-{}-way{}-shot-train.txt'.format(args.dataset, n_way,k_shot) 326 | acc_hist = '{}-{}-way{}-shot-acc.txt'.format(args.dataset, n_way,k_shot) 327 | test_acc_hist = '{}-{}-way{}-shot-acc-test.txt'.format(args.dataset, n_way,k_shot) 328 | test_loss_hist = '{}-{}-way{}-shot-loss-test.txt'.format(args.dataset, n_way,k_shot) 329 | 330 | # Save History 331 | f = open(train_hist, 'w') 332 | for i in range(len(losses)): 333 | f.write(str(losses[i]) + '\n') 334 | f.close() 335 | 336 | f = open(acc_hist, 'w') 337 | for i in range(len(accs)): 338 | f.write(str(accs[i]) + '\n') 339 | f.close() 340 | 341 | f = open(test_acc_hist, 'w') 342 | for i in range(len(test_accs)): 343 | f.write(str(test_accs[i]) + '\n') 344 | f.close() 345 | 346 | f = open(test_loss_hist, 'w') 347 | for i in range(len(test_losses)): 348 | f.write(str(test_losses[i]) + '\n') 349 | f.close() 350 | 351 | return model 352 | 353 | def eval_model(model, batch_generator, num_steps=None): 354 | if num_steps is None: 355 | num_steps = (0, 1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100) 356 | # Generate a batch data 357 | batch_set = batch_generator.test_batch() 358 | # Print the label map of each task 359 | batch_generator.print_label_map() 360 | # Use a copy of current model 361 | copied_model = model 362 | # Initialize optimizer 363 | optimizer = tf.keras.optimizers.SGD(learning_rate=args.inner_lr) 364 | 365 | task_losses = [0, 0, 0, 0] 366 | task_accs = [0, 0, 0, 0] 367 | 368 | loss_res = [[] for _ in range(len(batch_set))] 369 | acc_res = [[] for _ in range(len(batch_set))] 370 | 371 | # Record test result 372 | if 0 in num_steps: 373 | for idx, task in enumerate(batch_set): 374 | support_x, support_y, query_x, query_y = task 375 | loss, pred = compute_loss(model, query_x, query_y) 376 | acc = 
accuracy_fn(query_y, pred) 377 | task_losses[idx] += loss.numpy() 378 | task_accs[idx] += acc.numpy() 379 | loss_res[idx].append((0, loss.numpy())) 380 | acc_res[idx].append((0, acc.numpy())) 381 | print ('Before any update steps, test result:') 382 | print ('Task losses: {}'.format(task_losses)) 383 | print ('Task accuracies: {}'.format(task_accs)) 384 | # Test for each task 385 | for idx, task in enumerate(batch_set): 386 | print ('========== Task {} =========='.format(idx+1)) 387 | support_x, support_y, query_x, query_y = task 388 | for step in range(1, np.max(num_steps)+1): 389 | with tf.GradientTape() as tape: 390 | #regular_train_step(model, support_x, support_y, optimizer) 391 | loss, pred = compute_loss(model, support_x, support_y) 392 | grads = tape.gradient(loss, model.trainable_variables) 393 | optimizer.apply_gradients(zip(grads, model.trainable_variables)) 394 | # Test on query set 395 | qry_loss, qry_pred = compute_loss(model, query_x, query_y) 396 | qry_acc = accuracy_fn(query_y, qry_pred) 397 | # Record result 398 | if step in num_steps: 399 | loss_res[idx].append((step, qry_loss.numpy())) 400 | acc_res[idx].append((step, qry_acc.numpy())) 401 | print ('After {} steps update'.format(step)) 402 | print ('Task losses: {}'.format(qry_loss.numpy())) 403 | print ('Task accs: {}'.format(qry_acc.numpy())) 404 | print ('---------------------------------') 405 | 406 | for idx in range(len(batch_set)): 407 | l_x=[] 408 | l_y=[] 409 | a_x = [] 410 | a_y=[] 411 | # plt.subplot(2, 2, idx+1) 412 | plt.figure() 413 | for j in range(len(num_steps)): 414 | l_x.append(loss_res[idx][j][0]) 415 | l_y.append(loss_res[idx][j][1]) 416 | a_x.append(acc_res[idx][j][0]) 417 | a_y.append(acc_res[idx][j][1]) 418 | plt.plot(l_x, l_y, 'x', color='coral') 419 | plt.plot(a_x, a_y, '*', color='royalblue') 420 | # plt.annotate('Loss After 1 Fine Tune Step: %.2f'%l_y[1], xy=(l_x[1], l_y[1]), xytext=(l_x[1]-0.2, l_y[1]-0.2)) 421 | # plt.annotate('Accuracy After 1 Fine Tune Step: %.2f'%a_y[1], xy=(a_x[1], a_y[1]), xytext=(a_x[1]-0.2, a_y[1]-0.2)) 422 | plt.plot(l_x, l_y, linestyle='--', color='coral') 423 | plt.plot(a_x, a_y, linestyle='--', color='royalblue') 424 | plt.xlabel('Fine Tune Step', fontsize=12) 425 | plt.fill_between(a_x, [a+0.1 for a in a_y], [a-0.1 for a in a_y], facecolor='royalblue', alpha=0.3) 426 | legend=['Fine Tune Points','Fine Tune Points','Loss', 'Accuracy'] 427 | plt.legend(legend) 428 | plt.title('Task {} Fine Tuning Process'.format(idx+1)) 429 | plt.show() 430 | 431 | 432 | 433 | if __name__ == '__main__': 434 | argparse = argparse.ArgumentParser() 435 | argparse.add_argument('--mode', type=str, help='train or test', default='train') 436 | # Dataset options 437 | argparse.add_argument('--dataset', type=str, help='Dataset used to train model', default='miniimagenet') 438 | argparse.add_argument('--visual', type=bool, help='Set True to visualize the batch data', default=True) 439 | # Task options 440 | argparse.add_argument('--n_way', type=int, help='Number of classes used in classification (e.g. 
5-way classification)', default=5) 441 | argparse.add_argument('--k_shot', type=int, help='Number of images in support set', default=1) 442 | argparse.add_argument('--k_query', type=int, help='Number of images in query set(For Omniglot, equal to k_shot)', default=15) 443 | # Model options 444 | argparse.add_argument('--num_filters', type=int, help='Number of filters in the convolution layers (32 for MiniImagenet, 64 for Ominiglot)', default=32) 445 | argparse.add_argument('--with_bn', type=bool, help='Turn True to add BatchNormalization layers in neural net', default=True) 446 | # Training options 447 | argparse.add_argument('--meta_batchsz', type=int, help='Number of tasks in one batch', default=4) 448 | argparse.add_argument('--update_steps', type=int, help='Number of inner gradient updates for each task', default=5) 449 | argparse.add_argument('--update_steps_test', type=int, help='Number of inner gradient updates for each task while testing', default=10) 450 | argparse.add_argument('--inner_lr', type=float, help='Learning rate of inner update steps, the step size alpha in the algorithm', default=1e-2) # 0.1 for ominiglot 451 | argparse.add_argument('--meta_lr', type=float, help='Learning rate of meta update steps, the step size beta in the algorithm', default=1e-3) 452 | argparse.add_argument('--total_batches', type=int, help='Total update steps for each epoch', default=40000) 453 | # Log options 454 | argparse.add_argument('--ckpt_steps', type=int, help='Number of steps for recording checkpoints', default=5000) 455 | argparse.add_argument('--test_steps', type=int, help='Number of steps for evaluating model', default=5) 456 | argparse.add_argument('--print_steps', type=int, help='Number of steps for prints result in the console', default=1) 457 | argparse.add_argument('--log_dir', type=str, help='Path to the log directory', default='../../logs/') 458 | argparse.add_argument('--ckpt_dir', type=str, help='Path to the checkpoint directory', default='../../weights/') 459 | argparse.add_argument('--his_dir', type=str, help='Path to the training history directory', default='../../historys/') 460 | # Generate args 461 | args = argparse.parse_args() 462 | 463 | print ('Initialize model with 4 Conv({} filters) Blocks and 1 Dense Layer'.format(args.num_filters)) 464 | model = MetaLearner(args=args) 465 | print ('Build model') 466 | model = MetaLearner.initialize(model) 467 | model.summary() 468 | tf.keras.utils.plot_model(model, to_file='../model.png',show_shapes=True,show_layer_names=True,dpi=128) 469 | # Initialize task generator 470 | batch_generator = TaskGenerator(args) 471 | 472 | if args.mode == 'train': 473 | # model = restore_model(model, '../../weights/{}/{}way{}shot'.format(args.dataset, args.n_way, args.k_shot)) 474 | model = maml_train(model, batch_generator) 475 | elif args.mode == 'test': 476 | model = restore_model(model, '../../weights/{}/{}way{}shot'.format(args.dataset, args.n_way, args.k_shot)) 477 | if args.dataset == 'miniimagenet': 478 | eval_model(model, batch_generator, num_steps=(0, 1, 10, 100, 200, 300, 400, 500, 600)) 479 | elif args.dataset == 'omniglot': 480 | eval_model(model, batch_generator, num_steps=(0, 1, 5, 100, 200, 300, 400, 500, 600)) 481 | -------------------------------------------------------------------------------- /scripts/image_classification/meta_learner.py: -------------------------------------------------------------------------------- 1 | """ 2 | Date: Feb 11st 2020 3 | Author: Hilbert XU 4 | Abstract: MetaLeaner model 5 | """ 6 | from 
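
For orientation, typical invocations of `main.py` built from the argparse defaults above might look like the following (paths and hyper-parameters are only examples):

```
cd scripts/image_classification
# train 5-way 1-shot on MiniImagenet
python main.py --dataset=miniimagenet --mode=train --n_way=5 --k_shot=1 --meta_batchsz=4 --update_steps=5
# evaluate a restored checkpoint
python main.py --dataset=miniimagenet --mode=test --n_way=5 --k_shot=1
```
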
task_generator import TaskGenerator 7 | 8 | import tensorflow as tf 9 | import tensorflow.keras as keras 10 | import tensorflow.keras.backend as keras_backend 11 | import os 12 | import numpy as np 13 | import cv2 14 | 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '/gpu:0' 16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 17 | 18 | def loss_fn(y, pred_y): 19 | ''' 20 | :param pred_y: Prediction output of model 21 | :param y: Ground truth 22 | 23 | :return loss value: 24 | ''' 25 | return tf.reduce_mean(tf.losses.categorical_crossentropy(y, pred_y)) 26 | 27 | class MetaLearner(tf.keras.models.Model): 28 | """ 29 | Meta Learner 30 | """ 31 | def __init__(self, args=None, bn=None): 32 | """ 33 | :param filters: Number of filters in conv layers 34 | :param img_size: Size of input image, [84, 84, 3] for miniimagenet 35 | :param n_way: Number of classes 36 | :param name: Name of model 37 | """ 38 | super(MetaLearner, self).__init__() 39 | # for miniimagener dataset set conv2d kernel size=[32, 3, 3] 40 | # for ominiglot dataset set conv2d kernel size=[64, 3, 3] 41 | if args is not None: 42 | if args.dataset == 'miniimagenet': 43 | self.filters = 32 44 | self.ip_size = (1, 84, 84, 3) 45 | self.op_channel = args.n_way 46 | self.with_bn = args.with_bn 47 | self.training = True if args.mode is 'train' else False 48 | if args.dataset == 'omniglot': 49 | self.filters = 64 50 | self.ip_size = (1, 28, 28, 1) 51 | self.op_channel = args.n_way 52 | self.with_bn = args.with_bn 53 | self.training = True if args.mode is 'train' else False 54 | else: 55 | self.filters = 32 56 | self.ip_size = (1, 84, 84, 3) 57 | self.op_channel = 5 58 | self.training = True 59 | if bn is not None: 60 | self.with_bn = bn 61 | else: 62 | self.with_bn = False 63 | 64 | if self.with_bn is True: 65 | # Build model layers 66 | self.conv_1 = tf.keras.layers.Conv2D(filters=self.filters, kernel_size=(3,3), strides=(1,1), padding='SAME', kernel_initializer='glorot_normal') 67 | self.bn_1 = tf.keras.layers.BatchNormalization(axis=-1) 68 | self.max_pool_1 = tf.keras.layers.MaxPool2D(pool_size=(2,2)) 69 | 70 | self.conv_2 = tf.keras.layers.Conv2D(filters=self.filters, kernel_size=(3,3), strides=(1,1), padding='SAME', kernel_initializer='glorot_normal') 71 | self.bn_2 = tf.keras.layers.BatchNormalization(axis=-1) 72 | self.max_pool_2 = tf.keras.layers.MaxPool2D(pool_size=(2,2)) 73 | 74 | self.conv_3 = tf.keras.layers.Conv2D(filters=self.filters, kernel_size=(3,3), strides=(1,1), padding='SAME', kernel_initializer='glorot_normal') 75 | self.bn_3 = tf.keras.layers.BatchNormalization(axis=-1) 76 | self.max_pool_3 = tf.keras.layers.MaxPool2D(pool_size=(2,2)) 77 | 78 | self.conv_4 = tf.keras.layers.Conv2D(filters=self.filters, kernel_size=(3,3), strides=(1,1), padding='SAME', kernel_initializer='glorot_normal') 79 | self.bn_4 = tf.keras.layers.BatchNormalization(axis=-1) 80 | self.max_pool_4 = tf.keras.layers.MaxPool2D(pool_size=(2,2)) 81 | 82 | self.fc = tf.keras.layers.Flatten() 83 | self.out = tf.keras.layers.Dense(self.op_channel) 84 | 85 | elif self.with_bn is False: 86 | self.conv_1 = tf.keras.layers.Conv2D(filters=self.filters, kernel_size=(3,3), strides=(1,1), padding='SAME', kernel_initializer='glorot_normal') 87 | self.max_pool_1 = tf.keras.layers.MaxPool2D(pool_size=(2,2)) 88 | 89 | self.conv_2 = tf.keras.layers.Conv2D(filters=self.filters, kernel_size=(3,3), strides=(1,1), padding='SAME', kernel_initializer='glorot_normal') 90 | self.max_pool_2 = tf.keras.layers.MaxPool2D(pool_size=(2,2)) 91 | 92 | self.conv_3 = 
tf.keras.layers.Conv2D(filters=self.filters, kernel_size=(3,3), strides=(1,1), padding='SAME', kernel_initializer='glorot_normal') 93 | self.max_pool_3 = tf.keras.layers.MaxPool2D(pool_size=(2,2)) 94 | 95 | self.conv_4 = tf.keras.layers.Conv2D(filters=self.filters, kernel_size=(3,3), strides=(1,1), padding='SAME', kernel_initializer='glorot_normal') 96 | self.max_pool_4 = tf.keras.layers.MaxPool2D(pool_size=(2,2)) 97 | 98 | self.fc = tf.keras.layers.Flatten() 99 | self.out = tf.keras.layers.Dense(self.op_channel) 100 | 101 | @property 102 | def inner_weights(self): 103 | ''' 104 | :return model weights 105 | ''' 106 | if self.with_bn is True: 107 | weights = [ 108 | self.conv_1.kernel, self.conv_1.bias, 109 | self.bn_1.gamma, self.bn_1.beta, 110 | self.conv_2.kernel, self.conv_2.bias, 111 | self.bn_2.gamma, self.bn_2.beta, 112 | self.conv_3.kernel, self.conv_3.bias, 113 | self.bn_3.gamma, self.bn_3.beta, 114 | self.conv_4.kernel, self.conv_4.bias, 115 | self.bn_4.gamma, self.bn_4.beta, 116 | self.out.kernel, self.out.bias 117 | ] 118 | elif self.with_bn is False: 119 | weights = [ 120 | self.conv_1.kernel, self.conv_1.bias, 121 | self.conv_2.kernel, self.conv_2.bias, 122 | self.conv_3.kernel, self.conv_3.bias, 123 | self.conv_4.kernel, self.conv_4.bias, 124 | self.out.kernel, self.out.bias 125 | ] 126 | return weights 127 | 128 | @classmethod 129 | def initialize(cls, model): 130 | ''' 131 | :return initialized model 132 | ''' 133 | ip_size = model.ip_size 134 | model.build(ip_size) 135 | return model 136 | 137 | @classmethod 138 | def hard_copy(cls, model, args): 139 | copied_model = cls(args) 140 | copied_model.build(model.ip_size) 141 | 142 | if copied_model.with_bn is True: 143 | copied_model.conv_1.kernel = model.conv_1.kernel 144 | copied_model.conv_1.bias = model.conv_1.bias 145 | copied_model.bn_1.gamma = model.bn_1.gamma 146 | copied_model.bn_1.beta = model.bn_1.beta 147 | # copied_model.max_pool_1 = model.max_pool_1 148 | 149 | copied_model.conv_2.kernel = model.conv_2.kernel 150 | copied_model.conv_2.bias = model.conv_2.bias 151 | copied_model.bn_2.gamma = model.bn_2.gamma 152 | copied_model.bn_2.beta = model.bn_2.beta 153 | # copied_model.max_pool_2 = model.max_pool_2 154 | 155 | copied_model.conv_3.kernel = model.conv_3.kernel 156 | copied_model.conv_3.bias = model.conv_3.bias 157 | copied_model.bn_3.gamma = model.bn_3.gamma 158 | copied_model.bn_3.beta = model.bn_3.beta 159 | # copied_model.max_pool_3 = model.max_pool_3 160 | 161 | copied_model.conv_4.kernel = model.conv_4.kernel 162 | copied_model.conv_4.bias = model.conv_4.bias 163 | copied_model.bn_4.gamma = model.bn_4.gamma 164 | copied_model.bn_4.beta = model.bn_4.beta 165 | # copied_model.max_pool_4 = model.max_pool_4 166 | 167 | copied_model.out.kernel = model.out.kernel 168 | copied_model.out.bias = model.out.bias 169 | 170 | elif copied_model.with_bn is False: 171 | copied_model.conv_1.kernel = model.conv_1.kernel 172 | copied_model.conv_1.bias = model.conv_1.bias 173 | # copied_model.max_pool_1 = model.max_pool_1 174 | 175 | copied_model.conv_2.kernel = model.conv_2.kernel 176 | copied_model.conv_2.bias = model.conv_2.bias 177 | # copied_model.max_pool_2 = model.max_pool_2 178 | 179 | copied_model.conv_3.kernel = model.conv_3.kernel 180 | copied_model.conv_3.bias = model.conv_3.bias 181 | # copied_model.max_pool_3 = model.max_pool_3 182 | 183 | copied_model.conv_4.kernel = model.conv_4.kernel 184 | copied_model.conv_4.bias = model.conv_4.bias 185 | # copied_model.max_pool_4 = model.max_pool_4 186 | 187 | 
copied_model.out.kernel = model.out.kernel 188 | copied_model.out.bias = model.out.bias 189 | 190 | return copied_model 191 | 192 | 193 | @classmethod 194 | def meta_update(cls, model, args, alpha=0.01, grads=None): 195 | ''' 196 | :param cls: Class MetaLearner 197 | :param model: Model to be copied 198 | :param alpah: The inner learning rate when update the fast weights 199 | :param grads: Gradients to generate fast weights 200 | 201 | :return model with fast weights 202 | ''' 203 | # Make a hard copy of target model 204 | # If with bn layers, call like copied_model = cls(bn=True) 205 | copied_model = cls(args) 206 | ''' 207 | !!!!!!!!!!! 208 | IMPORTANT 209 | !!!!!!!!!!! 210 | Must call copied_model.build(input_shape) to build up model weights before calling copied_model(x) 211 | If not, when we call copied_model(x) tf will reinitialize the model weights and overwrite the fast weights 212 | At the same time, GradientTape will fail to record it and the gradients will return Nones 213 | ''' 214 | copied_model.build(model.ip_size) 215 | 216 | if copied_model.with_bn is True: 217 | copied_model.conv_1.kernel = model.conv_1.kernel 218 | copied_model.conv_1.bias = model.conv_1.bias 219 | copied_model.bn_1.gamma = model.bn_1.gamma 220 | copied_model.bn_1.beta = model.bn_1.beta 221 | # copied_model.max_pool_1 = model.max_pool_1 222 | 223 | copied_model.conv_2.kernel = model.conv_2.kernel 224 | copied_model.conv_2.bias = model.conv_2.bias 225 | copied_model.bn_2.gamma = model.bn_2.gamma 226 | copied_model.bn_2.beta = model.bn_2.beta 227 | # copied_model.max_pool_2 = model.max_pool_2 228 | 229 | copied_model.conv_3.kernel = model.conv_3.kernel 230 | copied_model.conv_3.bias = model.conv_3.bias 231 | copied_model.bn_3.gamma = model.bn_3.gamma 232 | copied_model.bn_3.beta = model.bn_3.beta 233 | # copied_model.max_pool_3 = model.max_pool_3 234 | 235 | copied_model.conv_4.kernel = model.conv_4.kernel 236 | copied_model.conv_4.bias = model.conv_4.bias 237 | copied_model.bn_4.gamma = model.bn_4.gamma 238 | copied_model.bn_4.beta = model.bn_4.beta 239 | # copied_model.max_pool_4 = model.max_pool_4 240 | 241 | copied_model.out.kernel = model.out.kernel 242 | copied_model.out.bias = model.out.bias 243 | 244 | # if call with gradients, apply it by using SGD 245 | # Manually apply Gradient descent as the task-level optimizer 246 | if grads is not None: 247 | copied_model.conv_1.kernel = copied_model.conv_1.kernel - alpha * grads[0] 248 | copied_model.conv_1.bias = copied_model.conv_1.bias - alpha * grads[1] 249 | copied_model.bn_1.gamma = copied_model.bn_1.gamma - alpha * grads[2] 250 | copied_model.bn_1.beta = copied_model.bn_1.beta - alpha * grads[3] 251 | 252 | copied_model.conv_2.kernel = copied_model.conv_2.kernel - alpha * grads[4] 253 | copied_model.conv_2.bias = copied_model.conv_2.bias - alpha * grads[5] 254 | copied_model.bn_2.gamma = copied_model.bn_2.gamma - alpha * grads[6] 255 | copied_model.bn_2.beta = copied_model.bn_2.beta - alpha * grads[7] 256 | 257 | copied_model.conv_3.kernel = copied_model.conv_3.kernel - alpha * grads[8] 258 | copied_model.conv_3.bias = copied_model.conv_3.bias - alpha * grads[9] 259 | copied_model.bn_3.gamma = copied_model.bn_3.gamma - alpha * grads[10] 260 | copied_model.bn_3.beta = copied_model.bn_3.beta - alpha * grads[11] 261 | 262 | copied_model.conv_4.kernel = copied_model.conv_4.kernel - alpha * grads[12] 263 | copied_model.conv_4.bias = copied_model.conv_4.bias - alpha * grads[13] 264 | copied_model.bn_4.gamma = copied_model.bn_4.gamma - alpha * grads[14] 265 
| copied_model.bn_4.beta = copied_model.bn_4.beta - alpha * grads[15] 266 | 267 | copied_model.out.kernel = copied_model.out.kernel - alpha * grads[16] 268 | copied_model.out.bias = copied_model.out.bias - alpha * grads[17] 269 | 270 | elif copied_model.with_bn is False: 271 | copied_model.conv_1.kernel = model.conv_1.kernel 272 | copied_model.conv_1.bias = model.conv_1.bias 273 | # copied_model.max_pool_1 = model.max_pool_1 274 | 275 | copied_model.conv_2.kernel = model.conv_2.kernel 276 | copied_model.conv_2.bias = model.conv_2.bias 277 | # copied_model.max_pool_2 = model.max_pool_2 278 | 279 | copied_model.conv_3.kernel = model.conv_3.kernel 280 | copied_model.conv_3.bias = model.conv_3.bias 281 | # copied_model.max_pool_3 = model.max_pool_3 282 | 283 | copied_model.conv_4.kernel = model.conv_4.kernel 284 | copied_model.conv_4.bias = model.conv_4.bias 285 | # copied_model.max_pool_4 = model.max_pool_4 286 | 287 | copied_model.out.kernel = model.out.kernel 288 | copied_model.out.bias = model.out.bias 289 | 290 | # if call with gradients, apply it by using SGD 291 | # Manually apply Gradient descent as the task-level optimizer 292 | if grads is not None: 293 | copied_model.conv_1.kernel = copied_model.conv_1.kernel - alpha * grads[0] 294 | copied_model.conv_1.bias = copied_model.conv_1.bias - alpha * grads[1] 295 | 296 | copied_model.conv_2.kernel = copied_model.conv_2.kernel - alpha * grads[2] 297 | copied_model.conv_2.bias = copied_model.conv_2.bias - alpha * grads[3] 298 | 299 | copied_model.conv_3.kernel = copied_model.conv_3.kernel - alpha * grads[4] 300 | copied_model.conv_3.bias = copied_model.conv_3.bias - alpha * grads[5] 301 | 302 | copied_model.conv_4.kernel = copied_model.conv_4.kernel - alpha * grads[6] 303 | copied_model.conv_4.bias = copied_model.conv_4.bias - alpha * grads[7] 304 | 305 | copied_model.out.kernel = copied_model.out.kernel - alpha * grads[8] 306 | copied_model.out.bias = copied_model.out.bias - alpha * grads[9] 307 | 308 | return copied_model 309 | 310 | def call(self, x): 311 | ''' 312 | @TODO Change network to conv-relu-bn-maxpool 313 | ''' 314 | if self.with_bn is True: 315 | # Conv block #1 316 | x = self.max_pool_1(tf.keras.activations.relu(self.bn_1(self.conv_1(x), training=self.training))) 317 | # Conv block #2 318 | x = self.max_pool_2(tf.keras.activations.relu(self.bn_2(self.conv_2(x), training=self.training))) 319 | # Conv block #3 320 | x = self.max_pool_3(tf.keras.activations.relu(self.bn_3(self.conv_3(x), training=self.training))) 321 | # Conv block #4 322 | x = self.max_pool_4(tf.keras.activations.relu(self.bn_4(self.conv_4(x), training=self.training))) 323 | 324 | elif self.with_bn is False: 325 | # Conv block #1 326 | x = self.max_pool_1(tf.keras.activations.relu(self.conv_1(x))) 327 | # Conv block #2 328 | x = self.max_pool_2(tf.keras.activations.relu(self.conv_2(x))) 329 | # Conv block #3 330 | x = self.max_pool_3(tf.keras.activations.relu(self.conv_3(x))) 331 | # Conv block #4 332 | x = self.max_pool_4(tf.keras.activations.relu(self.conv_4(x))) 333 | 334 | # Fully Connect Layer 335 | x = self.fc(x) 336 | # Logits Output 337 | logits = self.out(x) 338 | # Prediction 339 | pred = tf.keras.activations.softmax(logits) 340 | 341 | return logits, pred 342 | 343 | -------------------------------------------------------------------------------- /scripts/image_classification/task_generator.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Date: 14th Feb 2020 3 | Author: HilbertXu 4 | Abstract: Code for 
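
The copy-then-assign pattern in `meta_update` (and the build-before-call warning in its docstring) exists so that the fast weights remain differentiable functions of the original variables: `w_fast = w - alpha * dL_support/dw` is recorded on the outer tape, which is what gives MAML its second-order term. A toy scalar sketch of the same idea, independent of the repository's model:

```python
# Sketch only: gradients flow through the inner update back to the original variable.
import tensorflow as tf

w, alpha = tf.Variable(1.5), 0.01
x_s, y_s = tf.constant(2.0), tf.constant(1.0)          # toy support example
x_q, y_q = tf.constant(3.0), tf.constant(2.0)          # toy query example

with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as inner_tape:
        support_loss = (w * x_s - y_s) ** 2
    g = inner_tape.gradient(support_loss, w)
    w_fast = w - alpha * g                             # plain tensor, like the fast weights above
    query_loss = (w_fast * x_q - y_q) ** 2

print(outer_tape.gradient(query_loss, w).numpy())      # includes d(w_fast)/dw
```
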
generating meta-train tasks using miniimagenet and ominiglot dataset 5 | Meta learning is different from general supervised learning 6 | The basic training element in training process is TASK(N-way K-shot) 7 | A batch contains several tasks 8 | tasks: containing N-way K-shot for meta-train, N-way N-query for meta-test 9 | ''' 10 | 11 | from __future__ import print_function 12 | import argparse 13 | import csv 14 | import glob 15 | import os 16 | import sys 17 | import random 18 | import numpy as np 19 | from tqdm import tqdm 20 | from tqdm._tqdm import trange 21 | from PIL import Image 22 | import tensorflow as tf 23 | import cv2 24 | import time 25 | 26 | 27 | 28 | class TaskGenerator: 29 | def __init__(self, args=None): 30 | ''' 31 | :param mode: train or test 32 | :param n_way: a train task contains images from different N classes 33 | :param k_shot: k images used for meta-train 34 | :param k_query: k images used for meta-test 35 | :param meta_batchsz: the number of tasks in a batch 36 | :param total_batch_num: the number of batches 37 | ''' 38 | if args is not None: 39 | self.dataset = args.dataset 40 | self.mode = args.mode 41 | self.meta_batchsz = args.meta_batchsz 42 | self.n_way = args.n_way 43 | self.spt_num = args.k_shot 44 | self.qry_num = args.k_query 45 | self.dim_output = self.n_way 46 | else: 47 | self.dataset = 'omniglot' 48 | self.mode = 'test' 49 | self.meta_batchsz = 4 50 | self.n_way = 5 51 | self.spt_num = 1 52 | self.qry_num = 15 53 | self.img_size = 84 54 | self.img_channel = 3 55 | self.dim_output = self.n_way 56 | # For example: 57 | # 5-way 1-shot 15-query for MiniImagenet 58 | if self.dataset == 'miniimagenet': 59 | self.img_size = 84 60 | self.img_channel = 3 61 | META_TRAIN_DIR = '../../dataset/miniImagenet/train' 62 | META_VAL_DIR = '../../dataset/miniImagenet/test' 63 | # Set sample folders 64 | self.metatrain_folders = [os.path.join(META_TRAIN_DIR, label) \ 65 | for label in os.listdir(META_TRAIN_DIR) \ 66 | if os.path.isdir(os.path.join(META_TRAIN_DIR, label)) 67 | ] 68 | self.metaval_folders = [os.path.join(META_VAL_DIR, label) \ 69 | for label in os.listdir(META_VAL_DIR) \ 70 | if os.path.isdir(os.path.join(META_VAL_DIR, label)) 71 | ] 72 | 73 | if self.dataset == 'omniglot': 74 | self.img_size = 28 75 | self.img_channel = 1 76 | if self.spt_num != self.qry_num: 77 | # For Omniglot dataset set k_query = k_shot 78 | self.qry_num = self.spt_num 79 | DATA_FOLDER = '../../dataset/omniglot' 80 | character_folders = [ 81 | os.path.join(DATA_FOLDER, family, character) \ 82 | for family in os.listdir(DATA_FOLDER) \ 83 | if os.path.isdir(os.path.join(DATA_FOLDER, family)) \ 84 | for character in os.listdir(os.path.join(DATA_FOLDER, family)) 85 | ] 86 | # Shuffle dataset 87 | random.seed(9314) 88 | random.shuffle(character_folders) 89 | # Slice dataset to train set and test set 90 | # Use 1400 Alphabets as train set, the rest as test set 91 | self.metatrain_folders = character_folders[:1400] 92 | self.metaval_folders = character_folders[1400:] 93 | 94 | # Record the relationship between image label and the folder name in each task 95 | self.label_map = [] 96 | 97 | def print_label_map(self): 98 | print ('[TEST] Label map of current Batch') 99 | if self.dataset == 'miniimagenet': 100 | if len(self.label_map) > 0: 101 | for i, task in enumerate(self.label_map): 102 | print ('========= Task {} =========='.format(i+1)) 103 | for i, ref in enumerate(task): 104 | path = ref[0] 105 | label = path.split('/')[-1] 106 | print ('map {} --> {}\t'.format(label, ref[1]), end='') 
107 | if i == 4: 108 | print ('') 109 | print ('========== END ==========') 110 | self.label_map = [] 111 | elif len(self.label_map) == 0: 112 | print ('ERROR! print_label_map() function must be called after generating a batch dataset') 113 | elif self.dataset == 'omniglot': 114 | if len(self.label_map) > 0: 115 | for i, task in enumerate(self.label_map): 116 | print ('========= Task {} =========='.format(i+1)) 117 | for i, ref in enumerate(task): 118 | path = ref[0] 119 | label = path.split('/')[-2] +'/'+ path.split('/')[-1] 120 | print ('map {} --> {}\t'.format(label, ref[1]), end='') 121 | if i == 4: 122 | print ('') 123 | print ('========== END ==========') 124 | self.label_map = [] 125 | elif len(self.label_map) == 0: 126 | print ('ERROR! print_label_map() function must be called after generating a batch dataset') 127 | 128 | 129 | def shuffle_set(self, set_x, set_y): 130 | # Shuffle 131 | set_seed = random.randint(0, 100) 132 | random.seed(set_seed) 133 | random.shuffle(set_x) 134 | random.seed(set_seed) 135 | random.shuffle(set_y) 136 | return set_x, set_y 137 | 138 | def read_images(self, image_file): 139 | if self.dataset == 'omniglot': 140 | # For Omniglot dataset image size:[28, 28, 1] 141 | return np.reshape(cv2.cvtColor(cv2.imread(image_file), cv2.COLOR_BGR2GRAY).astype(np.float32)/255, (self.img_size, self.img_size, self.img_channel)) 142 | if self.dataset == 'miniimagenet': 143 | # For Omniglot dataset image size:[84, 84, 3] 144 | return np.reshape(cv2.imread(image_file).astype(np.float32)/255, (self.img_size, self.img_size, self.img_channel)) 145 | 146 | def convert_to_tensor(self, np_objects): 147 | return [tf.convert_to_tensor(obj) for obj in np_objects] 148 | 149 | def generate_set(self, folder_list, shuffle=False): 150 | k_shot = self.spt_num 151 | k_query = self.qry_num 152 | set_sampler = lambda x: np.random.choice(x, k_shot+k_query, False) 153 | label_map = [] 154 | images_with_labels = [] 155 | # sample images for support set and query set 156 | # images_with_labels: size [5, 16] 5 classes with 16 images & labels per class 157 | for i, elem in enumerate(folder_list): 158 | folder = elem[0] 159 | label = elem[1] 160 | label_map.append((folder, label)) 161 | image_with_label = [(os.path.join(folder, image), label) \ 162 | for image in set_sampler(os.listdir(folder))] 163 | images_with_labels.append(image_with_label) 164 | self.label_map.append(label_map) 165 | if shuffle == True: 166 | for i, elem in enumerate(images_with_labels): 167 | random.shuffle(elem) 168 | 169 | # Function for slicing the dataset 170 | # support set & query set 171 | def _slice_set(ds): 172 | spt_x = list() 173 | spt_y = list() 174 | qry_x = list() 175 | qry_y = list() 176 | # 此处是从每类的k_shot+k_query张图片中抽取k_shot张作为support set, 其余作为query set 177 | # 并且按照图片路径读取图片,对label进行one hot编码 178 | # 将support set和query set整体转化为张量 179 | for i, class_elem in enumerate(ds): 180 | spt_elem = random.sample(class_elem, self.spt_num) 181 | qry_elem = [elem for elem in class_elem if elem not in spt_elem] 182 | spt_elem = list(zip(*spt_elem)) 183 | qry_elem = list(zip(*qry_elem)) 184 | spt_x.extend([self.read_images(img) for img in spt_elem[0]]) 185 | spt_y.extend([tf.one_hot(label, self.n_way) for label in spt_elem[1]]) 186 | qry_x.extend([self.read_images(img) for img in qry_elem[0]]) 187 | qry_y.extend([tf.one_hot(label, self.n_way) for label in qry_elem[1]]) 188 | 189 | # Shuffle datasets 190 | spt_x, spt_y = self.shuffle_set(spt_x, spt_y) 191 | qry_x, qry_y = self.shuffle_set(qry_x, qry_y) 192 | # convert to 
tensor 193 | spt_x, spt_y = self.convert_to_tensor((np.array(spt_x), np.array(spt_y))) 194 | qry_x, qry_y = self.convert_to_tensor((np.array(qry_x), np.array(qry_y))) 195 | return spt_x, spt_y, qry_x, qry_y 196 | return _slice_set(images_with_labels) 197 | 198 | def train_batch(self): 199 | ''' 200 | :return: a batch of support set tensor and query set tensor 201 | 202 | ''' 203 | folders = self.metatrain_folders 204 | # Shuffle root folder in order to prevent repeat 205 | batch_set = [] 206 | self.label_map = [] 207 | # Generate batch dataset 208 | # batch_spt_set: [meta_batchsz, n_way * k_shot, image_size] & [meta_batchsz, n_way * k_shot, n_way] 209 | # batch_qry_set: [meta_batchsz, n_way * k_query, image_size] & [meta_batchsz, n_way * k_query, n_way] 210 | for i in range(self.meta_batchsz): 211 | sampled_folders_idx = np.array(np.random.choice(len(folders), self.n_way, False)) 212 | np.random.shuffle(sampled_folders_idx) 213 | sampled_folders = np.array(folders)[sampled_folders_idx].tolist() 214 | folder_with_label = [] 215 | # for i, folder in enumerate(sampled_folders): 216 | # elem = (folder, i) 217 | # folder_with_label.append(elem) 218 | labels = np.arange(self.n_way) 219 | np.random.shuffle(labels) 220 | labels = labels.tolist() 221 | folder_with_label = list(zip(sampled_folders, labels)) 222 | support_x, support_y, query_x, query_y = self.generate_set(folder_with_label) 223 | batch_set.append((support_x, support_y, query_x, query_y)) 224 | # return [meta_batchsz * (support_x, support_y, query_x, query_y)] 225 | return batch_set 226 | 227 | def test_batch(self): 228 | ''' 229 | :return: a batch of support set tensor and query set tensor 230 | 231 | ''' 232 | folders = self.metaval_folders 233 | print ('Sample test batch from {} classes'.format(len(folders))) 234 | # Shuffle root folder in order to prevent repeat 235 | batch_set = [] 236 | self.label_map = [] 237 | # Generate batch dataset 238 | # batch_spt_set: [meta_batchsz, n_way * k_shot, image_size] & [meta_batchsz, n_way * k_shot, n_way] 239 | # batch_qry_set: [meta_batchsz, n_way * k_query, image_size] & [meta_batchsz, n_way * k_query, n_way] 240 | for i in range(self.meta_batchsz): 241 | sampled_folders_idx = np.array(np.random.choice(len(folders), self.n_way, False)) 242 | np.random.shuffle(sampled_folders_idx) 243 | sampled_folders = np.array(folders)[sampled_folders_idx].tolist() 244 | folder_with_label = [] 245 | # for i, folder in enumerate(sampled_folders): 246 | # elem = (folder, i) 247 | # folder_with_label.append(elem) 248 | labels = np.arange(self.n_way) 249 | np.random.shuffle(labels) 250 | labels = labels.tolist() 251 | folder_with_label = list(zip(sampled_folders, labels)) 252 | support_x, support_y, query_x, query_y = self.generate_set(folder_with_label) 253 | batch_set.append((support_x, support_y, query_x, query_y)) 254 | # return [meta_batchsz * (support_x, support_y, query_x, query_y)] 255 | return batch_set 256 | 257 | if __name__ == '__main__': 258 | tasks = TaskGenerator() 259 | tasks.mode = 'train' 260 | for i in range(20): 261 | batch_set = tasks.train_batch() 262 | tasks.print_label_map() 263 | print (len(batch_set)) 264 | time.sleep(5) 265 | 266 | ''' 267 | @TODO 268 | change to np.random.choice 269 | And find out the reason why so many repeat 270 | ''' -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/Navigation2DEnv.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 
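
To make the shapes produced by `TaskGenerator.train_batch()` concrete, here is a synthetic sketch (random arrays instead of MiniImagenet images) of one batch for the default 5-way 1-shot, 15-query setting:

```python
# Sketch only: the tensor shapes of one meta-batch, using random data.
import numpy as np
import tensorflow as tf

n_way, k_shot, k_query, meta_batchsz = 5, 1, 15, 4
img_shape = (84, 84, 3)

def fake_task():
    support_x = tf.convert_to_tensor(np.random.rand(n_way * k_shot, *img_shape).astype(np.float32))
    support_y = tf.one_hot(np.repeat(np.arange(n_way), k_shot), n_way)
    query_x = tf.convert_to_tensor(np.random.rand(n_way * k_query, *img_shape).astype(np.float32))
    query_y = tf.one_hot(np.repeat(np.arange(n_way), k_query), n_way)
    return support_x, support_y, query_x, query_y

batch_set = [fake_task() for _ in range(meta_batchsz)]
# shapes: (5, 84, 84, 3), (5, 5), (75, 84, 84, 3), (75, 5)
print([tuple(t.shape) for t in batch_set[0]])
```
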
import numpy as np 3 | from gym import spaces 4 | from gym.utils import seeding 5 | from gym.envs.classic_control import rendering 6 | 7 | 8 | class Navigation2DEnv(gym.Env): 9 | """2D navigation problems, as described in [1]. The code is adapted from 10 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/maml_examples/point_env_randgoal.py 11 | 12 | At each time step, the 2D agent takes an action (its velocity, clipped in 13 | [-0.1, 0.1]), and receives a penalty equal to its L2 distance to the goal 14 | position (ie. the reward is `-distance`). The 2D navigation tasks are 15 | generated by sampling goal positions from the uniform distribution 16 | on [-0.5, 0.5]^2. 17 | 18 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 19 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 20 | (https://arxiv.org/abs/1703.03400) 21 | """ 22 | def __init__(self, task={}): 23 | super(Navigation2DEnv, self).__init__() 24 | 25 | self.observation_space = spaces.Box(low=-np.inf, high=np.inf, 26 | shape=(2,), dtype=np.float32) 27 | self.action_space = spaces.Box(low=-0.1, high=0.1, 28 | shape=(2,), dtype=np.float32) 29 | 30 | self.viewer = None 31 | 32 | self._task = task 33 | self._goal = task.get('goal', np.zeros(2, dtype=np.float32)) 34 | self._state = np.zeros(2, dtype=np.float32) 35 | self.seed() 36 | 37 | def seed(self, seed=None): 38 | self.np_random, seed = seeding.np_random(seed) 39 | return [seed] 40 | 41 | def sample_tasks(self, num_tasks): 42 | goals = self.np_random.uniform(-0.5, 0.5, size=(num_tasks, 2)) 43 | tasks = [{'goal': goal} for goal in goals] 44 | return tasks 45 | 46 | def reset_task(self, task): 47 | self._task = task 48 | self._goal = task['goal'] 49 | 50 | def reset(self, env=True): 51 | self._state = np.zeros(2, dtype=np.float32) 52 | return self._state 53 | 54 | def step(self, action): 55 | action = np.clip(action, -0.1, 0.1) 56 | assert self.action_space.contains(action), f"Action {action} not in the action space" 57 | self._state = self._state + action 58 | 59 | x = self._state[0] - self._goal[0] 60 | y = self._state[1] - self._goal[1] 61 | reward = -np.sqrt(x ** 2 + y ** 2) 62 | done = bool((np.abs(x) < 0.01) and (np.abs(y) < 0.01)) 63 | 64 | return self._state, reward, done, self._task 65 | 66 | def render(self, mode='rgb_array'): 67 | screen_width = 500 68 | screen_height = 500 69 | 70 | print (self._state) 71 | 72 | if self.viewer is None: 73 | self._last_state = (0,0) 74 | 75 | # Initialize Viewer 76 | self.viewer = rendering.Viewer(screen_width, screen_height) 77 | ''' 78 | The coordinate system in rendering looks like: 79 | ============================================== 80 | (0,y) 81 | 82 | (0,0) (x,0) 83 | ============================================== 84 | ''' 85 | # Create coordinate system 86 | self.x_axis = rendering.Line((0,250),(500,250)) 87 | self.y_axis = rendering.Line((250,0), (250,500)) 88 | self.x_axis.set_color(0,0,0) 89 | self.y_axis.set_color(0,0,0) 90 | self.viewer.add_geom(self.x_axis) 91 | self.viewer.add_geom(self.y_axis) 92 | 93 | # Create goal point 94 | self.goal_point = rendering.make_circle(2) 95 | self.goal_trans = rendering.Transform(translation=(int((self._goal[0]+0.5)*500), int((self._goal[1]+0.5)*500))) 96 | self.goal_point.add_attr(self.goal_trans) 97 | self.goal_point.set_color(0.25, 0.42, 0.88) 98 | self.viewer.add_geom(self.goal_point) 99 | 100 | # Create current state 101 | self.state_point = rendering.make_circle(2) 102 | self.state_trans = 
rendering.Transform(translation=(int((self._state[0]+0.5)*500), int((self._state[1]+0.5)*500))) 103 | self.state_point.add_attr(self.state_trans) 104 | self.state_point.set_color(1, 0.49, 0.31) 105 | self.viewer.add_geom(self.state_point) 106 | 107 | # Update rendering for each step 108 | cur_pos = self._state 109 | last_pos = self._last_state 110 | self.state_trans.set_translation(int((self._state[0]+0.5)*500), int((self._state[1]+0.5)*500)) 111 | self.trace = rendering.Line(((last_pos[0]+0.5)*500, (last_pos[1]+0.5)*500), ((cur_pos[0]+0.5)*500, (cur_pos[1]+0.5)*500)) 112 | self.trace.set_color(1, 0.49, 0.31) 113 | self.viewer.add_geom(self.trace) 114 | self._last_state = self._state 115 | 116 | return self.viewer.render(return_rgb_array = mode == 'rgb_array') 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/__pycache__/episode.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/scripts/reinforcement_learning/maml-rl-easy/__pycache__/episode.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/__pycache__/policy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/scripts/reinforcement_learning/maml-rl-easy/__pycache__/policy.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/__pycache__/sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/scripts/reinforcement_learning/maml-rl-easy/__pycache__/sampler.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/__pycache__/subproc_vec_env.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/scripts/reinforcement_learning/maml-rl-easy/__pycache__/subproc_vec_env.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/envs_test/__pycache__/sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/scripts/reinforcement_learning/maml-rl-easy/envs_test/__pycache__/sampler.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/envs_test/maze_test.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import time 3 | import random 4 | import numpy as np 5 | 6 | env = gym.make('Maze-v0') 7 | 8 | state = env.reset() 9 | 10 | action_list = ['up', 'down', 'left', 'right'] 11 | 12 | all_pos = [[i, j] for i in range(8) for j in range(8)] 13 | all_possible_traps = [[i, j] for i in range(1, 7) for j in range(1,7)] 14 | goal_index = np.random.randint(0, 64) 15 | 
goals = all_pos[goal_index] 16 | trap_index = np.random.randint(0, 36, 2) 17 | traps = [all_possible_traps[trap_index[0]], all_possible_traps[trap_index[1]]] 18 | 19 | task = {'goal': goals, 'traps':traps} 20 | 21 | env.reset_task(task) 22 | 23 | score = 0 24 | 25 | # Without any policy 26 | while True: 27 | time.sleep(1) 28 | env.render() 29 | action = np.random.randint(0, 4, 1)[0] 30 | print (action_list[action]) 31 | state, reward, done, _ = env.step(action) 32 | score += reward 33 | print ('reward: ', reward, 'done: ', done) 34 | print ('=======================') 35 | if done: # 游戏结束 36 | print('score: ', score) # 打印分数 37 | break 38 | env.close() 39 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/envs_test/navigation_test.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import time 3 | import random 4 | import numpy as np 5 | 6 | env = gym.make('Navigation2D-v0') 7 | 8 | state = env.reset() 9 | 10 | goals = np.random.uniform(-0.5, 0.5, size=(2,)) 11 | task = {'goal': goals} 12 | 13 | env.reset_task(task) 14 | 15 | score = 0 16 | 17 | # Without any policy 18 | while True: 19 | time.sleep(1) 20 | env.render() 21 | action = np.random.uniform(-0.1, 0.1, size=(2,)) 22 | state, reward, done, _ = env.step(action) 23 | score += reward 24 | if done: # 游戏结束 25 | print('score: ', score) # 打印分数 26 | break 27 | env.close() 28 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/episode.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import tensorflow as tf 4 | 5 | """ 6 | The code is taken from and perhaps will be changed in the future 7 | https://github.com/tristandeleu/pytorch-maml-rl/blob/master/maml_rl/episode.py 8 | """ 9 | 10 | class BatchEpisodes(object): 11 | def __init__(self, batch_size, gamma=0.95): 12 | self.batch_size = batch_size 13 | self.gamma = gamma 14 | 15 | self._observations_list = [[] for _ in range(batch_size)] 16 | self._actions_list = [[] for _ in range(batch_size)] 17 | self._rewards_list = [[] for _ in range(batch_size)] 18 | self._mask_list = [] 19 | 20 | self._observations = None 21 | self._actions = None 22 | self._rewards = None 23 | self._returns = None 24 | self._mask = None 25 | 26 | @property 27 | def observations(self): 28 | if self._observations is None: 29 | observation_shape = self._observations_list[0][0].shape 30 | observations = np.zeros((len(self), self.batch_size) + observation_shape, dtype=np.float32) 31 | for i in range(self.batch_size): 32 | length = len(self._observations_list[i]) 33 | observations[:length, i] = np.stack(self._observations_list[i], axis=0) 34 | self._observations = observations 35 | return self._observations 36 | 37 | @property 38 | def actions(self): 39 | if self._actions is None: 40 | action_shape = self._actions_list[0][0].shape 41 | actions = np.zeros((len(self), self.batch_size) 42 | + action_shape, dtype=self._actions_list[0][0].dtype) 43 | for i in range(self.batch_size): 44 | length = len(self._actions_list[i]) 45 | actions[:length, i] = np.stack(self._actions_list[i], axis=0) 46 | self._actions = actions 47 | return self._actions 48 | 49 | @property 50 | def rewards(self): 51 | if self._rewards is None: 52 | rewards = np.zeros((len(self), self.batch_size), dtype=np.float32) 53 | for i in range(self.batch_size): 54 | length = len(self._rewards_list[i]) 
55 | rewards[:length, i] = np.stack(self._rewards_list[i], axis=0) 56 | self._rewards = rewards 57 | return self._rewards 58 | 59 | @property 60 | def returns(self): 61 | if self._returns is None: 62 | return_ = np.zeros(self.batch_size, dtype=np.float32) 63 | returns = np.zeros((len(self), self.batch_size), dtype=np.float32) 64 | rewards = self.rewards 65 | mask = self.mask 66 | for i in range(len(self) - 1, -1, -1): 67 | return_ = self.gamma * return_ + rewards[i] * mask[i] 68 | returns[i] = return_ 69 | self._returns = returns 70 | return self._returns 71 | 72 | @property 73 | def mask(self): 74 | if self._mask is None: 75 | mask = np.zeros((len(self), self.batch_size), dtype=np.float32) 76 | for i in range(self.batch_size): 77 | length = len(self._actions_list[i]) 78 | mask[:length, i] = 1.0 79 | self._mask = mask 80 | return self._mask 81 | 82 | def gae(self, values, tau=1.0): 83 | # Add an additional 0 at the end of values for 84 | # the estimation at the end of the episode 85 | values = tf.squeeze(values, axis=2) 86 | values = tf.pad(values * self.mask, [[0, 1], [0, 0]]) 87 | 88 | deltas = self.rewards + self.gamma * values[1:] - values[:-1] 89 | advantages = tf.TensorArray(tf.float32, *deltas.shape) 90 | gae = tf.zeros_like(deltas[0], dtype=tf.float32) 91 | 92 | for i in range(len(self) - 1, -1, -1): 93 | gae = gae * self.gamma * tau + deltas[i] 94 | advantages = advantages.write(i, gae) 95 | advantages = advantages.stack() 96 | # tf.reshape(advantages, shape=(1, advantages.shape[-1])) 97 | return advantages 98 | 99 | def append(self, observations, actions, rewards, batch_ids): 100 | for observation, action, reward, batch_id in zip( 101 | observations, actions, rewards, batch_ids): 102 | if batch_id is None: 103 | continue 104 | self._observations_list[batch_id].append(observation.astype(np.float32)) 105 | self._actions_list[batch_id].append(action.astype(action.dtype)) 106 | self._rewards_list[batch_id].append(reward.astype(np.float32)) 107 | 108 | def __len__(self): 109 | return max(map(len, self._rewards_list)) 110 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/gym_test/cartpole_nn.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import random 3 | import time 4 | import numpy as np 5 | import tensorflow as tf 6 | from tensorflow.keras.layers import Dense 7 | 8 | ''' 9 | CartPole-V0 10 | ========================================================= 11 | State: car speed, car position, bar speed, bar position 12 | Action: 0-move left, 1-move right 13 | Reward: 1 per step 14 | Done: 200 round in total 15 | ''' 16 | env = gym.make('CartPole-v0') 17 | 18 | STATE_DIM = 4 19 | ACTION_DIM = 2 20 | 21 | model = tf.keras.Sequential() 22 | model.add(Dense(64, input_shape=(STATE_DIM,), activation='relu')) 23 | model.add(Dense(20, activation='relu')) 24 | model.add(Dense(ACTION_DIM, activation='linear')) 25 | 26 | model.summary() 27 | 28 | def generate_episode_data(): 29 | x, y, score = [], [], 0 30 | state = env.reset() 31 | 32 | while True: 33 | action = random.randrange(0, 2) 34 | x.append(state) 35 | y.append([1, 0] if action==0 else [0,1]) 36 | state, reward, done, _ = env.step(action) 37 | score += reward 38 | 39 | if done: 40 | break 41 | 42 | return x, y, score 43 | 44 | 45 | def generate_training_data(min_score=100): 46 | ''' 47 | Generate N episodes, use episodes with score>100 as training data 48 | ''' 49 | data_x, data_y, scores = [], [], [] 50 | 51 
| for i in range(10000): 52 | x, y, score = generate_episode_data() 53 | 54 | if score > min_score: 55 | data_x.extend(x) 56 | data_y.extend(y) 57 | scores.append(score) 58 | 59 | print ('dataset size: {}, max score: {}'.format(len(data_x), max(scores))) 60 | return np.array(data_x), np.array(data_y) 61 | 62 | 63 | data_x, data_y = generate_training_data() 64 | print(data_x.shape) 65 | model.compile(loss='mse', optimizer='adam') 66 | model.fit(data_x, data_y, epochs=5) 67 | model.save('./model/CartPole-v0-nn.h5') 68 | 69 | for i in range(5): 70 | state = env.reset() 71 | scores = 0 72 | while True: 73 | time.sleep(0.1) 74 | env.render() 75 | action = np.argmax(model.predict(np.array([state]))[0]) 76 | state, reward, done, _ = env.step(action) 77 | scores += reward 78 | if done: 79 | print('CartPole Using NN, final score: {}'.format(scores)) 80 | break 81 | 82 | env.close() 83 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/gym_test/cartpole_policy_gradient.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import scipy.misc 3 | import gym 4 | ''' 5 | CartPole-V0 6 | ========================================================= 7 | State: car speed, car position, bar speed, bar position 8 | Action: 0-move left, 1-move right 9 | Reward: 1 per step 10 | Done: 200 round in total 11 | ''' 12 | env = gym.make('CartPole-v0') 13 | 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import tensorflow as tf 17 | 18 | STATE_DIM = 4 19 | ACTION_DIM = 2 20 | 21 | # Set the precsion of keras otherwise the sum of probability given by softmax will not be 1 22 | tf.keras.backend.set_floatx('float64') 23 | 24 | class PGModel(tf.keras.models.Model): 25 | def __init__(self, input_dim, output_dim): 26 | super(PGModel, self).__init__() 27 | self.input_dim = input_dim 28 | self.output_dim = output_dim 29 | 30 | # Initialize layers 31 | self.dense_1 = tf.keras.layers.Dense(128, input_shape=(None,self.input_dim), activation='relu') 32 | # tf.keras.layers.Dropout(0.1) 33 | self.all_act = tf.keras.layers.Dense(self.output_dim) 34 | 35 | def call(self, state): 36 | x = self.dense_1(state) 37 | x = self.all_act(x) 38 | self.logits = x 39 | output = tf.keras.activations.softmax(x) 40 | #output = tf.nn.softmax(x) 41 | return output, self.logits 42 | 43 | class PolicyGradient(object): 44 | def __init__( 45 | self, 46 | lr = 0.001, 47 | state_dim=STATE_DIM, 48 | action_dim=ACTION_DIM, 49 | reward_decay=0.95 50 | ): 51 | # Learning rate 52 | self.lr = lr 53 | # Dimension of state space 54 | self.state_dim = state_dim 55 | # Dimension of action space 56 | self.action_dim = action_dim 57 | # reward decay rate 58 | self.reward_decay = reward_decay 59 | # Observation, actions, reward of an episode 60 | self.ep_obs, self.ep_acts, self.ep_rs = [], [], [] 61 | # Policy Net 62 | self.model = PGModel(STATE_DIM, ACTION_DIM) 63 | # Optimizer 64 | self.optimizer = tf.keras.optimizers.Adam(self.lr) 65 | 66 | def loss_func(self, predict, actions, ep_rs_norm): 67 | actions = tf.one_hot(self.ep_acts, depth=self.action_dim) 68 | neg_log_prob = tf.nn.softmax_cross_entropy_with_logits(logits=predict, labels=actions) 69 | loss = tf.reduce_mean(neg_log_prob*ep_rs_norm) 70 | return loss 71 | 72 | def store_transition(self, s, a, r): 73 | self.ep_obs.append(s) 74 | self.ep_acts.append(a) 75 | self.ep_rs.append(r) 76 | 77 | def choose_action(self, state): 78 | prob_dist, _ = 
self.model(np.array([state])) 79 | action = np.random.choice(len(prob_dist[0]), p=prob_dist[0]) 80 | return action 81 | 82 | def discount_and_norm_reward(self): 83 | out = np.zeros_like(self.ep_rs) 84 | dis_reward = 0 85 | 86 | # Calculate reward with discount 87 | for i in reversed(range(len(self.ep_rs))): 88 | dis_reward = dis_reward + self.reward_decay * self.ep_rs[i] 89 | out[i] = dis_reward 90 | # Normalization 91 | out -= np.mean(out) 92 | out /= np.std(out) 93 | return out 94 | 95 | def train_op(self): 96 | discounted_reward = self.discount_and_norm_reward() 97 | 98 | with tf.GradientTape() as tape: 99 | prob_dist, logits = self.model(np.vstack(self.ep_obs)) 100 | loss = self.loss_func(logits, self.ep_acts, discounted_reward) 101 | grads = tape.gradient(loss, self.model.trainable_variables) 102 | self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) 103 | self.ep_obs, self.ep_acts, self.ep_rs = [], [], [] 104 | 105 | 106 | 107 | 108 | # Use default parameters 109 | agent_pg = PolicyGradient() 110 | # Make gym environment 111 | env = gym.make('CartPole-v0') 112 | env.seed(1) 113 | env = env.unwrapped 114 | DISPLAY_REWARD_THRESHOLD = 160 # renders environment if total episode reward is greater then this threshold 115 | RENDER = False # rendering wastes time 116 | 117 | print(env.action_space) 118 | print(env.observation_space) 119 | print(env.observation_space.high) 120 | print(env.observation_space.low) 121 | 122 | ''' 123 | 此处训练时遇到了Policy Gradient的一个主要问题,就是沿 124 | ''' 125 | 126 | for ep_idx in range(2000): 127 | observation = env.reset() 128 | while True: 129 | if RENDER: 130 | env.render() 131 | # Choose action with current policy 132 | action = agent_pg.choose_action(observation) 133 | # Execute the action 134 | _obs, reward, done, info = env.step(action) 135 | # Store ob, action, reward 136 | agent_pg.store_transition(_obs, action, reward) 137 | # Update observation 138 | observation = _obs 139 | 140 | if done: 141 | ep_rs_sum = sum(agent_pg.ep_rs) 142 | 143 | if 'running_reward' not in globals(): 144 | running_reward = ep_rs_sum 145 | else: 146 | running_reward = running_reward * 0.99 + ep_rs_sum * 0.01 147 | if running_reward > DISPLAY_REWARD_THRESHOLD: 148 | RENDER = True 149 | print ('Episode: {} Reward: {}'.format(ep_idx, int(running_reward))) 150 | 151 | # Update parameters of policy using the policy gradient 152 | agent_pg.train_op() 153 | 154 | break 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/gym_test/gym_env_test.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import time 3 | import random 4 | 5 | env = gym.make('MountainCar-v0') 6 | 7 | state = env.reset() 8 | 9 | score = 0 10 | 11 | while True: 12 | time.sleep(0.1) 13 | env.render() 14 | action = random.randint(0, 2) 15 | state, reward, done, _ = env.step(action) 16 | score += reward 17 | if done: # 游戏结束 18 | print('score: ', score) # 打印分数 19 | break 20 | env.close() -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/gym_test/model/CartPole-v0-nn.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/scripts/reinforcement_learning/maml-rl-easy/gym_test/model/CartPole-v0-nn.h5 -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/gym_test/model/MountainCar-v0-dqn.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/scripts/reinforcement_learning/maml-rl-easy/gym_test/model/MountainCar-v0-dqn.h5 -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/gym_test/model/MountainCar-v0-q-learning.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilbertXu/MAML-Tensorflow/ee8e9365ac0590706bc773687d74c9e91a9a432c/scripts/reinforcement_learning/maml-rl-easy/gym_test/model/MountainCar-v0-q-learning.pickle -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/gym_test/mountain_car_pg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import scipy.misc 3 | import gym 4 | ''' 5 | MountainCar-V0 6 | ======================================== 7 | State: Position, Speed 8 | Action: 0-left 1-hold 2-right 9 | Reward: -1 per round 10 | Done: 200 round in total/Reach the peak 11 | ''' 12 | 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import tensorflow as tf 16 | 17 | STATE_DIM = 2 18 | ACTION_DIM = 3 19 | 20 | # Set the precsion of keras otherwise the sum of probability given by softmax will not be 1 21 | tf.keras.backend.set_floatx('float64') 22 | 23 | class PGModel(tf.keras.models.Model): 24 | def __init__(self, input_dim, output_dim): 25 | super(PGModel, self).__init__() 26 | self.input_dim = input_dim 27 | self.output_dim = output_dim 28 | 29 | # Initialize layers 30 | self.dense_1 = tf.keras.layers.Dense(128, input_shape=(None,self.input_dim), activation='relu') 31 | # tf.keras.layers.Dropout(0.1) 32 | self.all_act = tf.keras.layers.Dense(self.output_dim) 33 | 34 | def call(self, state): 35 | x = self.dense_1(state) 36 | x = self.all_act(x) 37 | self.logits = x 38 | output = tf.keras.activations.softmax(x) 39 | #output = tf.nn.softmax(x) 40 | return output, self.logits 41 | 42 | class PolicyGradient(object): 43 | def __init__( 44 | self, 45 | lr = 0.001, 46 | state_dim=STATE_DIM, 47 | action_dim=ACTION_DIM, 48 | reward_decay=0.95 49 | ): 50 | # Learning rate 51 | self.lr = lr 52 | # Dimension of state space 53 | self.state_dim = state_dim 54 | # Dimension of action space 55 | self.action_dim = action_dim 56 | # reward decay rate 57 | self.reward_decay = reward_decay 58 | # Observation, actions, reward of an episode 59 | self.ep_obs, self.ep_acts, self.ep_rs = [], [], [] 60 | # Policy Net 61 | self.model = PGModel(STATE_DIM, ACTION_DIM) 62 | # Optimizer 63 | self.optimizer = tf.keras.optimizers.Adam(self.lr) 64 | 65 | def loss_func(self, predict, actions, ep_rs_norm): 66 | actions = tf.one_hot(self.ep_acts, depth=self.action_dim) 67 | neg_log_prob = tf.nn.softmax_cross_entropy_with_logits(logits=predict, labels=actions) 68 | loss = tf.reduce_mean(neg_log_prob*ep_rs_norm) 69 | return loss 70 | 71 | def store_transition(self, s, a, r): 72 | self.ep_obs.append(s) 73 | self.ep_acts.append(a) 74 | 
self.ep_rs.append(r) 75 | 76 | def choose_action(self, state): 77 | prob_dist, _ = self.model(np.array([state])) 78 | action = np.random.choice(len(prob_dist[0]), p=prob_dist[0]) 79 | return action 80 | 81 | def discount_and_norm_reward(self): 82 | out = np.zeros_like(self.ep_rs) 83 | dis_reward = 0 84 | 85 | # Calculate reward with discount 86 | for i in reversed(range(len(self.ep_rs))): 87 | dis_reward = dis_reward + self.reward_decay * self.ep_rs[i] 88 | out[i] = dis_reward 89 | # Normalization 90 | out -= np.mean(out) 91 | out /= np.std(out) 92 | return out 93 | 94 | def train_op(self): 95 | discounted_reward = self.discount_and_norm_reward() 96 | 97 | with tf.GradientTape() as tape: 98 | prob_dist, logits = self.model(np.vstack(self.ep_obs)) 99 | loss = self.loss_func(logits, self.ep_acts, discounted_reward) 100 | grads = tape.gradient(loss, self.model.trainable_variables) 101 | self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) 102 | self.ep_obs, self.ep_acts, self.ep_rs = [], [], [] 103 | 104 | 105 | 106 | 107 | # Use default parameters 108 | agent_pg = PolicyGradient() 109 | # Make gym environment 110 | env = gym.make('MountainCar-v0') 111 | env.seed(1) 112 | env = env.unwrapped 113 | DISPLAY_REWARD_THRESHOLD = -100 # renders environment if total episode reward is greater then this threshold 114 | RENDER = False # rendering wastes time 115 | 116 | print(env.action_space) 117 | print(env.observation_space) 118 | print(env.observation_space.high) 119 | print(env.observation_space.low) 120 | 121 | ''' 122 | 此处训练时遇到了Policy Gradient的一个主要问题 123 | ''' 124 | 125 | for ep_idx in range(2000): 126 | observation = env.reset() 127 | while True: 128 | if RENDER: 129 | env.render() 130 | # Choose action with current policy 131 | action = agent_pg.choose_action(observation) 132 | # Execute the action 133 | _obs, reward, done, info = env.step(action) 134 | # Store ob, action, reward 135 | agent_pg.store_transition(_obs, action, reward) 136 | # Update observation 137 | observation = _obs 138 | 139 | if done: 140 | ep_rs_sum = sum(agent_pg.ep_rs) 141 | 142 | if 'running_reward' not in globals(): 143 | running_reward = ep_rs_sum 144 | else: 145 | running_reward = running_reward * 0.99 + ep_rs_sum * 0.01 146 | if running_reward > DISPLAY_REWARD_THRESHOLD: 147 | RENDER = True 148 | print ('Episode: {} Reward: {}'.format(ep_idx, int(running_reward))) 149 | 150 | # Update parameters of policy using the policy gradient 151 | agent_pg.train_op() 152 | 153 | break 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/gym_test/mountain_dqn.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import random 3 | import time 4 | import numpy as np 5 | import pickle 6 | import tensorflow as tf 7 | from collections import deque 8 | 9 | ''' 10 | Replace Q-Table with deep neural network 11 | ''' 12 | 13 | class DQN(object): 14 | def __init__(self): 15 | self.step = 0 16 | self.update_freq = 200 17 | # Size of training set 18 | self.replay_size = 2000 19 | self.replay_queue = deque(maxlen=self.replay_size) 20 | self.model = self.create_model() 21 | self.target_model = self.create_model() 22 | 23 | def create_model(self): 24 | ''' 25 | Create a neural network with hidden size 100 26 | ''' 27 | STATE_DIM = 2 28 | ACTION_DIM = 3 29 | model = tf.keras.Sequential() 30 | 
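# small fully connected net: 2-D state in, one Q-value per action (3 actions) out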
model.add(tf.keras.layers.Dense(100, input_shape=(STATE_DIM,), activation='relu')) 31 | model.add(tf.keras.layers.Dense(ACTION_DIM, activation='linear')) 32 | model.compile(loss='mse', optimizer='adam') 33 | 34 | return model 35 | 36 | def act(self, s, epsilon=0.1): 37 | ''' 38 | Predict actions using neural network 39 | ''' 40 | # Introduce randomness firstly 41 | if np.random.uniform() < epsilon - self.step*0.0002: 42 | return np.random.choice([0, 1, 2]) 43 | return np.argmax(self.model.predict(np.array([s]))[0]) 44 | 45 | def save_model(self, file_path='./model/MountainCar-v0-dqn.h5'): 46 | print ('Model saved') 47 | self.model.save(file_path) 48 | 49 | def remember(self, s, a, next_s, reward): 50 | ''' 51 | For goal[0] = 0.5 52 | if next_s[0] > 0.4 give extra reward to boost the train process 53 | ''' 54 | if next_s[0] > 0.4: 55 | reward += 1 56 | self.replay_queue.append((s, a, next_s, reward)) 57 | 58 | def train(self, batch_size=64, lr=1, factor=0.95): 59 | if len(self.replay_queue) < self.replay_size: 60 | return 61 | self.step += 1 62 | 63 | # Every update_freq, update the weights of self.model to self.target_model 64 | if self.step % self.update_freq == 0: 65 | self.target_model.set_weights(self.model.get_weights()) 66 | 67 | replay_batch = random.sample(self.replay_queue, batch_size) 68 | s_batch = np.array([replay[0] for replay in replay_batch]) 69 | next_s_batch = np.array([replay[2] for replay in replay_batch]) 70 | 71 | Q = self.model.predict(s_batch) 72 | Q_next = self.target_model.predict(next_s_batch) 73 | 74 | # Update Q value in training set 75 | for i, replay in enumerate(replay_batch): 76 | _, a, _, reward = replay 77 | Q[i][a] = (1 - lr) * Q[i][a] + lr * (reward + factor * np.amax(Q_next[i])) 78 | 79 | # Input data to neural network 80 | self.model.fit(s_batch, Q, verbose=0) 81 | 82 | 83 | env = gym.make('MountainCar-v0') 84 | episodes = 1000 85 | score_list = [] 86 | agent = DQN() 87 | 88 | for i in range(episodes): 89 | s = env.reset() 90 | score = 0 91 | while True: 92 | a = agent.act(s) 93 | next_s, reward, done, _ = env.step(a) 94 | agent.remember(s, a, next_s, reward) 95 | agent.train() 96 | score += reward 97 | s = next_s 98 | 99 | if done: 100 | score_list.append(score) 101 | print ('Episode: {}, score: {}, currently max score: {}'.format(i, score, max(score_list))) 102 | 103 | break 104 | if np.mean(score_list[-10:]) > -180: 105 | agent.save_model() 106 | break 107 | env.close() 108 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/gym_test/mountaincar_q_learning.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import random 3 | import time 4 | import numpy as np 5 | import pickle 6 | from collections import defaultdict 7 | 8 | ''' 9 | MountainCar-v0 10 | ======================================== 11 | State: Position, Speed 12 | Action: 0-left 1-hold 2-right 13 | Reward: -1 per round 14 | Done: 200 round in total/Reach the peak 15 | ''' 16 | env = gym.make('MountainCar-v0') 17 | 18 | ''' 19 | Q-learning update rule 20 | ======================================================================= 21 | Q[s][a] = (1 - lr) * Q[s][a] + lr * (reward + factor * max(Q[next_s])) 22 | ======================================================================= 23 | s, a, next_s: current state, action, next state 24 | reward: reward for acting actions 25 | Q[s][a]: the quality of action a in state s 26 | max(Q[next_s]): the maximum quality of all 
actions in next state next_s 27 | lr: learning rate, bigger lr for less history experience 28 | factor: discount factor, bigger factor for more history experience 29 | ''' 30 | 31 | # Initialize Q-Table 32 | Q = defaultdict(lambda: [0, 0, 0]) 33 | 34 | 35 | def transform_state(state): 36 | ''' 37 | transform continous State(position, speed) to discrete state(40x40) 38 | ''' 39 | pos, v = state 40 | pos_low, v_low = env.observation_space.low 41 | pos_high, v_high = env.observation_space.high 42 | 43 | pos_int = 40 * (pos - pos_low) / (pos_high - pos_low) 44 | v_int = 40 * (v - v_low) / (v_high - v_low) 45 | 46 | return int(pos_int), int(v_int) 47 | 48 | 49 | lr, factor = 0.7, 0.55 50 | episodes = 10000 51 | score_list = [] 52 | 53 | for i in range(episodes): 54 | s = transform_state(env.reset()) 55 | score = 0 56 | while True: 57 | a = np.argmax(Q[s]) 58 | # Introduce more randomness 59 | if np.random.random() > i*3 / episodes: 60 | a = np.random.choice([0, 1, 2]) 61 | # Apply actions 62 | next_s, reward, done, _ = env.step(a) 63 | next_s = transform_state(next_s) 64 | # Update Q-Table 65 | Q[s][a] = (1-lr)*Q[s][a] + lr*(reward + factor*max(Q[next_s])) 66 | score += reward 67 | 68 | s = next_s 69 | 70 | if done: 71 | score_list.append(score) 72 | print ('Episode: {}, score: {}, currently max score: {}'.format(i, score, max(score_list))) 73 | break 74 | 75 | env.close() 76 | 77 | with open('./model/MountainCar-v0-q-learning.pickle', 'wb') as f: 78 | pickle.dump(dict(Q), f) 79 | print('model saved') 80 | 81 | s = env.reset() 82 | score = 0 83 | while True: 84 | env.render() 85 | time.sleep(0.01) 86 | # transform_state函数 与 训练时的一致 87 | s = transform_state(s) 88 | a = np.argmax(Q[s]) if s in Q else 0 89 | s, reward, done, _ = env.step(a) 90 | score += reward 91 | if done: 92 | print('score:', score) 93 | break 94 | env.close() 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/main.py: -------------------------------------------------------------------------------- 1 | from policy import PolicyGradientModel, clone_policy 2 | from sampler import BatchSampler 3 | import multiprocessing as mp 4 | 5 | 6 | sampler = BatchSampler('Maze-v0', 7 | batch_size=20, 8 | num_workers=mp.cpu_count() - 1) 9 | 10 | print (sampler.envs.observation_space.shape) 11 | print (sampler.envs.action_space.shape) 12 | 13 | 14 | tasks = sampler.sample_tasks(num_tasks=40) -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/maze.py: -------------------------------------------------------------------------------- 1 | import math 2 | import gym 3 | from gym import spaces, logger 4 | from gym.utils import seeding 5 | import numpy as np 6 | from gym.envs.classic_control import rendering 7 | 8 | 9 | class MazeEnv(gym.Env): 10 | def __init__(self, task={}): 11 | super(MazeEnv, self).__init__() 12 | # 0-up 1-down 2-left 3-right 13 | self.action_space = [0, 1, 2, 3] 14 | self.action_dim = len(self.action_space) 15 | self.observation_space = spaces.Discrete(2) 16 | self.action_space = spaces.Discrete(4) 17 | self.all_pos = [[i, j] for i in range(8) for j in range(8)] 18 | self.all_possible_traps = [[i, j] for i in range(1, 7) for j in range(1,7)] 19 | trap_index = np.random.randint(0, 36, 2) 20 | traps = [self.all_possible_traps[trap_index[0]], self.all_possible_traps[trap_index[1]]] 21 | self.all_possible_goal = [x for x in self.all_pos if x not in traps] 22 | 
goal_index = np.random.randint(0, len(self.all_possible_goal)) 23 | goal = self.all_possible_goal[goal_index] 24 | 25 | self.viewer = None 26 | 27 | # Set task and traps 28 | self._task = task 29 | self._trap = task.get('traps', traps) 30 | self._goal = task.get('goal', goal) 31 | self._state = np.zeros(2, dtype=np.int32) 32 | self.seed() 33 | 34 | def seed(self, seed=None): 35 | self.np_random, seed = seeding.np_random(seed) 36 | return [seed] 37 | 38 | def _out_of_maze(self, action): 39 | if action == 0: # up 40 | if self._state[1] + 1 > 7: 41 | return True 42 | else: 43 | return False 44 | if action == 1: # Down 45 | if self._state[1] - 1 < 0: 46 | return True 47 | else: 48 | return False 49 | if action == 2: # Left 50 | if self._state[0] - 1 < 0: 51 | return True 52 | else: 53 | return False 54 | if action == 3: # Right 55 | if self._state[0] + 1 > 7: 56 | return True 57 | else: 58 | return False 59 | 60 | def reset_task(self, task): 61 | self._task = task 62 | self._goal = task['goal'] 63 | self._trap = task['traps'] 64 | 65 | def reset(self, env=True): 66 | self._state = np.zeros(2, dtype=np.int32) 67 | return self._state 68 | 69 | def sample_task(self, num_tasks): 70 | # NOTE: minimal sketch, draws a random goal and two random traps per task (same scheme as __init__ and maze_test.py) 71 | return [{'goal': self.all_possible_goal[np.random.randint(0, len(self.all_possible_goal))], 'traps': [self.all_possible_traps[i] for i in np.random.randint(0, len(self.all_possible_traps), 2)]} for _ in range(num_tasks)] 72 | def step(self, action): 73 | assert self.action_space.contains(action), "%r (%s) invalid" % (action, type(action)) 74 | 75 | if not self._out_of_maze(action): 76 | if action == 0: # Up 77 | next_state = [self._state[0], self._state[1]+1] 78 | elif action == 1: # Down 79 | next_state = [self._state[0], self._state[1]-1] 80 | elif action == 2: # Left 81 | next_state = [self._state[0]-1, self._state[1]] 82 | elif action == 3: # Right 83 | next_state = [self._state[0]+1, self._state[1]] 84 | 85 | print ('next_state: {}, goal: {}, traps: {}'.format(next_state, self._goal, self._trap)) 86 | 87 | if next_state == self._goal: 88 | print ('Congratz!') 89 | reward = 0 90 | done = True 91 | elif next_state in self._trap: 92 | reward = -16 - (abs(next_state[0] - self._goal[0]) + abs(next_state[1] - self._goal[1])) 93 | print ('ooooops!
There is a trap here, reward: ', reward) 94 | done = False 95 | else: 96 | reward = -(abs(next_state[0] - self._goal[0]) + abs(next_state[1] - self._goal[1])) 97 | done = False 98 | 99 | self._state = next_state 100 | return np.array(self._state), reward, done, {} 101 | else: 102 | print("hold current position ({} {})".format(self._state[0], self._state[1])) 103 | self._state = self._state 104 | reward = -(abs(self._state[0] - self._goal[0]) + abs(self._state[1] - self._goal[1])) 105 | done = False 106 | return np.array(self._state), reward, done, {} 107 | 108 | def render(self, mode='rgb_array'): 109 | width = 880 110 | height = 880 111 | 112 | if self.viewer is None: 113 | # Initialize Viewer 114 | self.viewer = rendering.Viewer(width, height) 115 | line_1 = rendering.Line((110,0), (110,880)) 116 | line_1.set_color(0,0,0) 117 | line_2 = rendering.Line((220,0), (220,880)) 118 | line_2.set_color(0,0,0) 119 | line_3 = rendering.Line((330,0), (330,880)) 120 | line_3.set_color(0,0,0) 121 | line_4 = rendering.Line((440,0), (440,880)) 122 | line_4.set_color(0,0,0) 123 | line_5 = rendering.Line((550,0), (550,880)) 124 | line_5.set_color(0,0,0) 125 | line_6 = rendering.Line((660,0), (660,880)) 126 | line_6.set_color(0,0,0) 127 | line_7 = rendering.Line((770,0), (770,880)) 128 | line_7.set_color(0,0,0) 129 | line_8 = rendering.Line((0,110), (880,110)) 130 | line_8.set_color(0,0,0) 131 | line_9 = rendering.Line((0,220), (880,220)) 132 | line_9.set_color(0,0,0) 133 | line_10 = rendering.Line((0,330), (880,330)) 134 | line_10.set_color(0,0,0) 135 | line_11 = rendering.Line((0,440), (880,440)) 136 | line_11.set_color(0,0,0) 137 | line_12 = rendering.Line((0,550), (880,550)) 138 | line_12.set_color(0,0,0) 139 | line_13 = rendering.Line((0,660), (880,660)) 140 | line_13.set_color(0,0,0) 141 | line_14 = rendering.Line((0,770), (880,770)) 142 | line_14.set_color(0,0,0) 143 | 144 | self.viewer.add_geom(line_1) 145 | self.viewer.add_geom(line_2) 146 | self.viewer.add_geom(line_3) 147 | self.viewer.add_geom(line_4) 148 | self.viewer.add_geom(line_5) 149 | self.viewer.add_geom(line_6) 150 | self.viewer.add_geom(line_7) 151 | self.viewer.add_geom(line_8) 152 | self.viewer.add_geom(line_9) 153 | self.viewer.add_geom(line_10) 154 | self.viewer.add_geom(line_11) 155 | self.viewer.add_geom(line_12) 156 | self.viewer.add_geom(line_13) 157 | self.viewer.add_geom(line_14) 158 | 159 | # Create goal point 160 | goal_point = rendering.make_circle(20) 161 | goal_trans = rendering.Transform(translation=(self._goal[0]*110+55, self._goal[1]*110+55)) 162 | goal_point.add_attr(goal_trans) 163 | goal_point.set_color(1, 0.84, 0) 164 | self.viewer.add_geom(goal_point) 165 | 166 | # Create trap points 167 | trap_point_1 = rendering.make_circle(20) 168 | trap_trans_1 = rendering.Transform(translation=(self._trap[0][0]*110+55, self._trap[0][1]*110+55)) 169 | trap_point_1.add_attr(trap_trans_1) 170 | trap_point_1.set_color(0,0,0) 171 | self.viewer.add_geom(trap_point_1) 172 | 173 | trap_point_2 = rendering.make_circle(20) 174 | trap_trans_2 = rendering.Transform(translation=(self._trap[1][0]*110+55, self._trap[1][1]*110+55)) 175 | trap_point_2.add_attr(trap_trans_2) 176 | trap_point_2.set_color(0,0,0) 177 | self.viewer.add_geom(trap_point_2) 178 | 179 | # Create current state 180 | self.state_point = rendering.make_circle(20) 181 | self.state_trans = rendering.Transform(translation=(self._state[0]*110+55, self._state[1]*110+55)) 182 | self.state_point.add_attr(self.state_trans) 183 | self.state_point.set_color(0.25, 0.42, 0.88) 184 
| self.viewer.add_geom(self.state_point) 185 | # Create current state 186 | self.state_trans.set_translation(self._state[0]*110+55, self._state[1]*110+55) 187 | self.state_point.set_color(0.25, 0.42, 0.88) 188 | self.viewer.add_geom(self.state_point) 189 | return self.viewer.render(return_rgb_array = mode == 'rgb_array') 190 | 191 | 192 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/maze_policy_gradient.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import scipy.misc 3 | import gym 4 | env = gym.make('Maze-v0') 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | STATE_DIM = 2 11 | ACTION_DIM = 4 12 | 13 | # Set the precsion of keras otherwise the sum of probability given by softmax will not be 1 14 | tf.keras.backend.set_floatx('float64') 15 | 16 | class PGModel(tf.keras.models.Model): 17 | def __init__(self, input_dim, output_dim): 18 | super(PGModel, self).__init__() 19 | self.input_dim = input_dim 20 | self.output_dim = output_dim 21 | 22 | # Initialize layers 23 | self.dense_1 = tf.keras.layers.Dense(128, input_shape=(None,self.input_dim), activation='relu') 24 | # tf.keras.layers.Dropout(0.1) 25 | self.all_act = tf.keras.layers.Dense(self.output_dim) 26 | 27 | def call(self, state): 28 | x = self.dense_1(state) 29 | x = self.all_act(x) 30 | self.logits = x 31 | output = tf.keras.activations.softmax(x) 32 | #output = tf.nn.softmax(x) 33 | return output, self.logits 34 | 35 | class PolicyGradient(object): 36 | def __init__( 37 | self, 38 | lr = 0.001, 39 | state_dim=STATE_DIM, 40 | action_dim=ACTION_DIM, 41 | reward_decay=0.95 42 | ): 43 | # Learning rate 44 | self.lr = lr 45 | # Dimension of state space 46 | self.state_dim = state_dim 47 | # Dimension of action space 48 | self.action_dim = action_dim 49 | # reward decay rate 50 | self.reward_decay = reward_decay 51 | # Observation, actions, reward of an episode 52 | self.ep_obs, self.ep_acts, self.ep_rs = [], [], [] 53 | # Policy Net 54 | self.model = PGModel(STATE_DIM, ACTION_DIM) 55 | # Optimizer 56 | self.optimizer = tf.keras.optimizers.Adam(self.lr) 57 | 58 | def loss_func(self, predict, actions, ep_rs_norm): 59 | actions = tf.one_hot(self.ep_acts, depth=self.action_dim) 60 | neg_log_prob = tf.nn.softmax_cross_entropy_with_logits(logits=predict, labels=actions) 61 | loss = tf.reduce_mean(neg_log_prob*ep_rs_norm) 62 | return loss 63 | 64 | def store_transition(self, s, a, r): 65 | self.ep_obs.append(s) 66 | self.ep_acts.append(a) 67 | self.ep_rs.append(r) 68 | 69 | def choose_action(self, state): 70 | prob_dist, _ = self.model(np.array([state])) 71 | action = np.random.choice(len(prob_dist[0]), p=prob_dist[0]) 72 | return action 73 | 74 | def discount_and_norm_reward(self): 75 | out = np.zeros_like(self.ep_rs) 76 | dis_reward = 0 77 | 78 | # Calculate reward with discount 79 | for i in reversed(range(len(self.ep_rs))): 80 | dis_reward = dis_reward + self.reward_decay * self.ep_rs[i] 81 | out[i] = dis_reward 82 | # Normalization 83 | out -= np.mean(out) 84 | out /= np.std(out) 85 | return out 86 | 87 | def train_op(self): 88 | discounted_reward = self.discount_and_norm_reward() 89 | 90 | with tf.GradientTape() as tape: 91 | prob_dist, logits = self.model(np.vstack(self.ep_obs)) 92 | loss = self.loss_func(logits, self.ep_acts, discounted_reward) 93 | grads = tape.gradient(loss, self.model.trainable_variables) 94 | 
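# vanilla REINFORCE step: gradient of mean(-log pi(a|s) * normalized discounted return)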
self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) 95 | self.ep_obs, self.ep_acts, self.ep_rs = [], [], [] 96 | 97 | def save_model(self, file_path='./model/Maze-v0-PG.h5'): 98 | print ('Model saved') 99 | self.model.save(file_path) 100 | 101 | # Use default parameters 102 | agent_pg = PolicyGradient() 103 | # Make gym environment 104 | env = gym.make('Maze-v0') 105 | env.seed(1) 106 | env = env.unwrapped 107 | DISPLAY_REWARD_THRESHOLD = -120 # renders environment if total episode reward is greater then this threshold 108 | RENDER = False # rendering wastes time 109 | 110 | print(env.action_space) 111 | print(env.observation_space) 112 | 113 | ''' 114 | 此处训练时遇到了Policy Gradient的一个主要问题,就是沿 115 | ''' 116 | 117 | for ep_idx in range(2000): 118 | observation = env.reset() 119 | while True: 120 | if RENDER: 121 | env.render() 122 | # Choose action with current policy 123 | action = agent_pg.choose_action(observation) 124 | # Execute the action 125 | _obs, reward, done, info = env.step(action) 126 | # Store ob, action, reward 127 | agent_pg.store_transition(_obs, action, reward) 128 | # Update observation 129 | observation = _obs 130 | 131 | if done: 132 | ep_rs_sum = sum(agent_pg.ep_rs) 133 | 134 | if 'running_reward' not in globals(): 135 | running_reward = ep_rs_sum 136 | else: 137 | running_reward = running_reward * 0.99 + ep_rs_sum * 0.01 138 | if running_reward > DISPLAY_REWARD_THRESHOLD: 139 | RENDER = True 140 | print ('Episode: {} Reward: {}'.format(ep_idx, int(running_reward))) 141 | 142 | # Update parameters of policy using the policy gradient 143 | agent_pg.train_op() 144 | 145 | break 146 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/meta_learner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from collections import OrderedDict 4 | 5 | class BaseMetaLearner(object): 6 | 7 | def inner_loss(self, episodes, params=None): 8 | raise NotImplementedError 9 | 10 | def surrogate_loss(self, episodes, old_pis=None): 11 | raise NotImplementedError 12 | 13 | def adapt(self, episodes, first_order=False): 14 | raise NotImplementedError 15 | 16 | def step(self, episodes): 17 | raise NotImplementedError 18 | 19 | class MetaPolicyGradient(BaseMetaLearner): 20 | def __init__( 21 | self, 22 | policy, 23 | sampler, 24 | optimizer, 25 | gamma=0.95, 26 | inner_lr=0.01, 27 | outer_lr=0.001, 28 | ): 29 | self.policy = policy 30 | self.sampler = sampler 31 | self.optimizer = optimizer 32 | self.gamma = gamma 33 | self.inner_lr = inner_lr 34 | self.outer_lr = outer_lr 35 | 36 | self.ep_obs = [] 37 | self.ep_rs = [] 38 | self.ep_acts = [] 39 | 40 | def inner_loss(self, episodes, params=None): 41 | return 0 42 | 43 | 44 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/policy.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from collections import OrderedDict 4 | 5 | ''' 6 | PolicyGradientModel for maze env 7 | ''' 8 | 9 | def clone_policy(policy, params=None, with_name=False): 10 | if params is None: 11 | params = policy.get_trainable_variables 12 | 13 | assert isinstance(policy, PolicyGradientModel) 14 | 15 | cloned_policy = PolicyGradientModel( 16 | input_dim=policy.input_dim, 17 | output_dim=policy.output_dim, 18 | hidden_size=policy.hidden_size, 
19 | name=policy.name 20 | ) 21 | print (cloned_policy.name) 22 | 23 | if with_name: 24 | cloned_policy.set_params_with_name(params) 25 | else: 26 | cloned_policy.set_params(params) 27 | 28 | return cloned_policy 29 | 30 | class PolicyGradientModel(tf.keras.Model): 31 | def __init__(self, input_dim=2, output_dim=4, hidden_size=(100,), name=None): 32 | super(PolicyGradientModel, self).__init__( 33 | name=name 34 | ) 35 | 36 | self.input_dim = input_dim 37 | self.output_dim = output_dim 38 | self.hidden_size = hidden_size 39 | self.nonlinearity = tf.nn.relu 40 | self.all_param = OrderedDict() 41 | 42 | layer_sizes = (self.input_dim,)+self.hidden_size 43 | self.num_layer = len(self.hidden_size) + 1 44 | kernel_init = tf.keras.initializers.glorot_uniform() 45 | bias_init = tf.zeros_initializer() 46 | 47 | for i in range(1, self.num_layer): 48 | with tf.name_scope('layer_{}'.format(i)): 49 | kernel = tf.Variable( 50 | initial_value=kernel_init(shape=(layer_sizes[i-1], layer_sizes[i]), dtype='float32'), 51 | name='kernel', 52 | trainable=True 53 | ) 54 | self.all_param[kernel.name] = kernel 55 | bias = tf.Variable( 56 | initial_value=bias_init(shape=(layer_sizes[i],), dtype='float32'), 57 | name='bias', 58 | trainable=True 59 | ) 60 | self.all_param[bias.name] = bias 61 | 62 | with tf.name_scope('prob_dist'): 63 | kernel = tf.Variable( 64 | initial_value=kernel_init(shape=(layer_sizes[-1], self.output_dim), dtype='float32'), 65 | name='kernel', 66 | trainable=True 67 | ) 68 | self.all_param[kernel.name] = kernel 69 | bias = tf.Variable( 70 | initial_value=bias_init(shape=(self.output_dim,), dtype='float32'), 71 | name='bias', 72 | trainable=True 73 | ) 74 | self.all_param[bias.name] = bias 75 | 76 | @property 77 | def get_trainable_variables(self): 78 | return list(self.trainable_variables) 79 | 80 | def set_params_with_name(self, var_list): 81 | old_var_list = self.get_trainable_variables 82 | for (name, var), old_var in zip(var_list.items(), old_var_list): 83 | old_var.assign(var) 84 | 85 | def set_params(self, var_list): 86 | old_var_list = self.get_trainable_variables 87 | for var, old_var in zip(var_list, old_var_list): 88 | old_var.assign(var) 89 | 90 | def update_params(self, grads, step_size=0.01): 91 | updated_params = OrderedDict() 92 | params_with_name = [(x.name, x) for x in self.get_trainable_variables] 93 | for (name, param), grad in zip(params_with_name, grads): 94 | updated_params[name] = tf.subtract(param, tf.multiply(step_size, grad)) 95 | 96 | return updated_params 97 | 98 | def forward(self, x, params=None): 99 | if params is None: 100 | params = self.get_trainable_variables 101 | params_dict = OrderedDict((v.name, v) for v in params) 102 | else: 103 | params_dict = params 104 | 105 | x = tf.convert_to_tensor(x) 106 | for i in range(1, self.num_layer): 107 | layer_name = self.name + 'layer_{}/'.format(i) 108 | kernel = params_dict[layer_name+'kernel:0'] 109 | bias = params_dict[layer_name+'bias:0'] 110 | x = tf.matmul(x, kernel) 111 | x = tf.add(x, bias) 112 | x = self.nonlinearity(x) 113 | 114 | kernel = params_dict[self.name + 'prob_dist/kernel:0'] 115 | bias = params_dict[self.name + 'prob_dist/bias:0'] 116 | x = tf.matmul(x, kernel) 117 | x = tf.add(x, bias) 118 | 119 | return x 120 | 121 | def __call__(self, x, params=None): 122 | return self.forward(x, params) 123 | 124 | 125 | if __name__ == '__main__': 126 | with tf.name_scope('Policy') as scope: 127 | policy = PolicyGradientModel(name=scope) 128 | #print(policy.all_param) 129 | 
print(type(policy.get_trainable_variables)) 130 | print(policy.name) 131 | print(policy.get_trainable_variables) 132 | cloned_policy = clone_policy(policy, policy.all_param, with_name=True) 133 | print (cloned_policy.get_trainable_variables) 134 | 135 | 136 | 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/sampler.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | 3 | import gym 4 | import tensorflow as tf 5 | 6 | 7 | from subproc_vec_env import SubprocVecEnv 8 | from episode import BatchEpisodes 9 | 10 | 11 | """ 12 | The code is taken from and perhaps will be changed in the future 13 | https://github.com/tristandeleu/pytorch-maml-rl/blob/master/maml_rl/sampler.py 14 | """ 15 | 16 | 17 | def make_env(env_name): 18 | def _make_env(): 19 | return gym.make(env_name) 20 | return _make_env 21 | 22 | 23 | class BatchSampler(object): 24 | def __init__(self, env_name, batch_size, num_workers=mp.cpu_count() - 1): 25 | self.env_name = env_name 26 | self.batch_size = batch_size 27 | self.num_workers = num_workers 28 | 29 | self.queue = mp.Queue() 30 | self.envs = SubprocVecEnv([make_env(env_name) for _ in range(num_workers)], 31 | queue=self.queue) 32 | self._env = gym.make(env_name) 33 | 34 | def sample(self, policy, params=None, gamma=0.95): 35 | episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma) 36 | for i in range(self.batch_size): 37 | self.queue.put(i) 38 | for _ in range(self.num_workers): 39 | self.queue.put(None) 40 | observations, batch_ids = self.envs.reset() 41 | dones = [False] 42 | while (not all(dones)) or (not self.queue.empty()): 43 | observations_tensor = observations 44 | actions_tensor = policy(observations_tensor, params=params).sample() 45 | with tf.device('/CPU:0'): 46 | actions = actions_tensor.numpy() 47 | new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(actions) 48 | episodes.append(observations, actions, rewards, batch_ids) 49 | observations, batch_ids = new_observations, new_batch_ids 50 | 51 | return episodes 52 | 53 | def reset_task(self, task): 54 | tasks = [task for _ in range(self.num_workers)] 55 | reset = self.envs.reset_task(tasks) 56 | return all(reset) 57 | 58 | def sample_tasks(self, num_tasks): 59 | tasks = self._env.unwrapped.sample_tasks(num_tasks) 60 | return tasks 61 | -------------------------------------------------------------------------------- /scripts/reinforcement_learning/maml-rl-easy/subproc_vec_env.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import sys 3 | 4 | import gym 5 | import numpy as np 6 | 7 | is_py2 = (sys.version[0] == '2') 8 | if is_py2: 9 | import Queue as queue 10 | else: 11 | import queue as queue 12 | 13 | class EnvWorker(mp.Process): 14 | def __init__(self, remote, env_fn, queue, lock): 15 | super(EnvWorker, self).__init__() 16 | self.remote = remote 17 | self.env = env_fn() 18 | self.queue = queue 19 | self.lock = lock 20 | self.task_id = None 21 | self.done = False 22 | 23 | def empty_step(self): 24 | observation = np.zeros(self.env.observation_space.shape, 25 | dtype=np.float32) 26 | reward, done = 0.0, True 27 | return observation, reward, done, {} 28 | 29 | def try_reset(self): 30 | with self.lock: 31 | try: 32 | self.task_id = self.queue.get(True) 33 | self.done = (self.task_id is None) 34 | except queue.Empty: 35 | self.done = True 36 | 
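# when the task queue is exhausted the worker is done; hand back an all-zero dummy observation instead of resetting the env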
observation = (np.zeros(self.env.observation_space.shape, 37 | dtype=np.float32) if self.done else self.env.reset()) 38 | return observation 39 | 40 | def run(self): 41 | while True: 42 | command, data = self.remote.recv() 43 | if command == 'step': 44 | observation, reward, done, info = (self.empty_step() 45 | if self.done else self.env.step(data)) 46 | if done and (not self.done): 47 | observation = self.try_reset() 48 | self.remote.send((observation, reward, done, self.task_id, info)) 49 | elif command == 'reset': 50 | observation = self.try_reset() 51 | self.remote.send((observation, self.task_id)) 52 | elif command == 'reset_task': 53 | self.env.unwrapped.reset_task(data) 54 | self.remote.send(True) 55 | elif command == 'close': 56 | self.remote.close() 57 | break 58 | elif command == 'get_spaces': 59 | self.remote.send((self.env.observation_space, 60 | self.env.action_space)) 61 | else: 62 | raise NotImplementedError() 63 | 64 | class SubprocVecEnv(gym.Env): 65 | def __init__(self, env_factory, queue): 66 | self.lock = mp.Lock() 67 | self.remotes, self.work_remotes = zip(*[mp.Pipe() for _ in env_factory]) 68 | self.workers = [EnvWorker(remote, env_fn, queue, self.lock) 69 | for (remote, env_fn) in zip(self.work_remotes, env_factory)] 70 | for worker in self.workers: 71 | worker.daemon = True 72 | worker.start() 73 | for remote in self.work_remotes: 74 | remote.close() 75 | self.waiting = False 76 | self.closed = False 77 | 78 | self.remotes[0].send(('get_spaces', None)) 79 | observation_space, action_space = self.remotes[0].recv() 80 | self.observation_space = observation_space 81 | self.action_space = action_space 82 | 83 | def step(self, actions): 84 | self.step_async(actions) 85 | return self.step_wait() 86 | 87 | def step_async(self, actions): 88 | for remote, action in zip(self.remotes, actions): 89 | remote.send(('step', action)) 90 | self.waiting = True 91 | 92 | def step_wait(self): 93 | results = [remote.recv() for remote in self.remotes] 94 | self.waiting = False 95 | observations, rewards, dones, task_ids, infos = zip(*results) 96 | return np.stack(observations), np.stack(rewards), np.stack(dones), task_ids, infos 97 | 98 | def reset(self): 99 | for remote in self.remotes: 100 | remote.send(('reset', None)) 101 | results = [remote.recv() for remote in self.remotes] 102 | observations, task_ids = zip(*results) 103 | return np.stack(observations), task_ids 104 | 105 | def reset_task(self, tasks): 106 | for remote, task in zip(self.remotes, tasks): 107 | remote.send(('reset_task', task)) 108 | return np.stack([remote.recv() for remote in self.remotes]) 109 | 110 | def close(self): 111 | if self.closed: 112 | return 113 | if self.waiting: 114 | for remote in self.remotes: 115 | remote.recv() 116 | for remote in self.remotes: 117 | remote.send(('close', None)) 118 | for worker in self.workers: 119 | worker.join() 120 | self.closed = True 121 | -------------------------------------------------------------------------------- /scripts/sine_fitting/sine_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | from tensorflow.keras.utils import plot_model 3 | import tensorflow.keras as keras 4 | 5 | class SineModel(keras.Model): 6 | def __init__(self): 7 | super().__init__() 8 | self.hidden1 = keras.layers.Dense(40, input_shape=(1,)) 9 | self.hidden2 = keras.layers.Dense(40) 10 | self.out = keras.layers.Dense(1) 11 | 12 | def forward(self, x): 13 | x = keras.activations.relu(self.hidden1(x)) 14 | x = 
keras.activations.relu(self.hidden2(x)) 15 | x = self.out(x) 16 | return x 17 | 18 | -------------------------------------------------------------------------------- /scripts/sine_fitting/sinusoid_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | class SinusoidGenerator(): 5 | ''' 6 | Sinusoid Generator. 7 | 8 | p(T) is continuous, where the amplitude varies within [0.1, 5.0] 9 | and the phase varies within [0, π]. 10 | 11 | This abstraction is the basically the same defined at: 12 | https://towardsdatascience.com/paper-repro-deep-metalearning-using-maml-and-reptile-fd1df1cc81b0 13 | ''' 14 | def __init__(self, K=10, amplitude=None, phase=None): 15 | ''' 16 | Args: 17 | K: batch size. Number of values sampled at every batch. 18 | amplitude: Sine wave amplitude. If None is uniformly sampled from 19 | the [0.1, 5.0] interval. 20 | pahse: Sine wave phase. If None is uniformly sampled from the [0, π] 21 | interval. 22 | ''' 23 | self.K = K 24 | self.amplitude = amplitude if amplitude else np.random.uniform(0.1, 5.0) 25 | self.phase = phase if amplitude else np.random.uniform(0, np.pi) 26 | self.sampled_points = None 27 | self.x = self._sample_x() 28 | 29 | def _sample_x(self): 30 | return np.random.uniform(-5, 5, self.K) 31 | 32 | def f(self, x): 33 | '''Sinewave function.''' 34 | return self.amplitude * np.sin(x - self.phase) 35 | 36 | def batch(self, x = None, force_new=False): 37 | '''Returns a batch of size K. 38 | 39 | It also changes the sape of `x` to add a batch dimension to it. 40 | 41 | Args: 42 | x: Batch data, if given `y` is generated based on this data. 43 | Usually it is None. If None `self.x` is used. 44 | force_new: Instead of using `x` argument the batch data is 45 | uniformly sampled. 46 | 47 | ''' 48 | if x is None: 49 | if force_new: 50 | x = self._sample_x() 51 | else: 52 | x = self.x 53 | y = self.f(x) 54 | return x[:, None], y[:, None] 55 | 56 | def equally_spaced_samples(self, K=None): 57 | '''Returns `K` equally spaced samples.''' 58 | if K is None: 59 | K = self.K 60 | return self.batch(x=np.linspace(-5, 5, K)) 61 | 62 | 63 | def plot(data, *args, **kwargs): 64 | '''Plot helper.''' 65 | x, y = data 66 | return plt.plot(x, y, *args, **kwargs) 67 | 68 | 69 | 70 | def generate_dataset(K, train_size=20000, test_size=10): 71 | '''Generate train and test dataset. 72 | 73 | A dataset is composed of SinusoidGenerators that are able to provide 74 | a batch (`K`) elements at a time. 
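Each SinusoidGenerator plays the role of one task, i.e. one randomly sampled (amplitude, phase) pair.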
75 | ''' 76 | def _generate_dataset(size): 77 | return [SinusoidGenerator(K=K) for _ in range(size)] 78 | return _generate_dataset(train_size), _generate_dataset(test_size) 79 | -------------------------------------------------------------------------------- /scripts/sine_fitting/train_sine_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import os 3 | import sys 4 | import random 5 | import numpy as np 6 | import tensorflow as tf 7 | import time 8 | import matplotlib.pyplot as plt 9 | from sinusoid_generator import SinusoidGenerator, generate_dataset 10 | from sine_model import SineModel 11 | 12 | plt.rcParams['font.sans-serif']=['SimSun'] #用来正常显示中文标签 13 | plt.rcParams['axes.unicode_minus']=False #用来正常显示负号 14 | 15 | tf.keras.backend.set_floatx('float64') 16 | 17 | 18 | 19 | def np_to_tensor(list_of_numpy_objs): 20 | return (tf.convert_to_tensor(obj) for obj in list_of_numpy_objs) 21 | 22 | def copy_model(model, x): 23 | copied_model = SineModel() 24 | copied_model.forward(x) 25 | copied_model.set_weights(model.get_weights()) 26 | return copied_model 27 | 28 | def loss_fn(y, pred_y): 29 | return tf.reduce_mean(tf.keras.metrics.mean_squared_error(y, pred_y)) 30 | 31 | def compute_loss(model, x, y , loss_fn=loss_fn): 32 | logits = model.forward(x) 33 | mse = loss_fn(logits, y) 34 | return mse, logits 35 | 36 | def compute_gradients(model, x, y, loss_fn=loss_fn): 37 | with tf.GradientTape() as tape: 38 | loss, logits = compute_loss(model, x, y, loss_fn) 39 | return tape.gradient(loss, model.trainable_variables), loss 40 | 41 | def apply_gradients(optimizer, gradients, variables): 42 | optimizer.apply_gradients(zip(gradients, variables)) 43 | 44 | def train_step(x, y, model, optimizer): 45 | tensor_x, tensor_y = np_to_tensor((x, y)) 46 | gradients, loss = compute_gradients(model, tensor_x, tensor_y) 47 | apply_gradients(optimizer, gradients, model.trainable_variables) 48 | return loss 49 | 50 | def regular_train(model, train_ds, epochs=1, lr=0.001, log_steps=1000): 51 | optimizer = tf.keras.optimizers.Adam(learning_rate=lr) 52 | for eopch in range(epochs): 53 | losses = [] 54 | total_loss = 0 55 | start = time.time() 56 | for i, sinusoid_generator in enumerate(train_ds): 57 | x, y = sinusoid_generator.batch() 58 | loss = train_step(x, y, model, optimizer) 59 | total_loss += loss 60 | curr_loss = total_loss / (i + 1.0) 61 | losses.append(curr_loss) 62 | 63 | if i % log_steps == 0 and i > 0: 64 | print('Step {}: loss = {}, Time to run {} steps = {:.2f} seconds'.format( 65 | i, curr_loss, log_steps, time.time() - start)) 66 | start = time.time() 67 | plt.plot(losses) 68 | plt.title('联合训练模型损失函数值随训练迭代次数的变化曲线') 69 | plt.show() 70 | return model 71 | 72 | def maml_train(model, train_ds, epochs=1, lr_inner=0.01, batch_size=1, log_steps=1000): 73 | optimizer = tf.keras.optimizers.Adam() 74 | for epoch in range(epochs): 75 | total_loss = 0 76 | losses = [] 77 | start = time.time() 78 | # 打乱生成的训练集 79 | for i, ds in enumerate(random.sample(train_ds, len(train_ds))): 80 | x, y = np_to_tensor(ds.batch()) 81 | model.forward(x) 82 | with tf.GradientTape() as test_tape: 83 | with tf.GradientTape() as train_tape: 84 | train_loss, _ = compute_loss(model, x, y) 85 | gradients = train_tape.gradient(train_loss, model.trainable_variables) 86 | k=0 87 | model_copy = copy_model(model, x) 88 | for j in range(len(model_copy.layers)): 89 | model_copy.layers[j].kernel = tf.subtract(model.layers[j].kernel, 90 | tf.multiply(lr_inner, gradients[k])) 91 
| model_copy.layers[j].bias = tf.subtract(model.layers[j].bias, 92 | tf.multiply(lr_inner, gradients[k+1])) 93 | k+=2 94 | test_loss, logits = compute_loss(model_copy, x, y) 95 | gradients = test_tape.gradient(test_loss, model.trainable_variables) 96 | optimizer.apply_gradients(zip(gradients, model.trainable_variables)) 97 | 98 | # Logs 99 | total_loss += test_loss 100 | loss = total_loss / (i+1.0) 101 | losses.append(loss) 102 | 103 | if i % log_steps == 0 and i > 0: 104 | print('Step {}: loss = {}, Time to run {} steps = {}'.format(i, loss, log_steps, time.time() - start)) 105 | start = time.time() 106 | plt.plot(losses) 107 | plt.title('MAML模型损失函数值随训练迭代次数的变化曲线') 108 | plt.show() 109 | return model 110 | 111 | def plot_model_comparison_to_average(model, ds, model_name='模型的预测值', K=5): 112 | '''Compare model to average. 113 | 114 | Computes mean of training sine waves actual `y` and compare to 115 | the model's prediction to a new sine wave, the intuition is that 116 | these two plots should be similar. 117 | ''' 118 | sinu_generator = SinusoidGenerator(K=K) 119 | 120 | # calculate average prediction 121 | avg_pred = [] 122 | for i, sinusoid_generator in enumerate(ds): 123 | x, y = sinusoid_generator.equally_spaced_samples() 124 | avg_pred.append(y) 125 | 126 | x, _ = sinu_generator.equally_spaced_samples() 127 | avg_plot, = plt.plot(x, np.mean(avg_pred, axis=0), '--') 128 | 129 | # calculate model prediction 130 | model_pred = model.forward(tf.convert_to_tensor(x)) 131 | model_plot, = plt.plot(x, model_pred.numpy()) 132 | 133 | # plot 134 | plt.legend([avg_plot, model_plot], ['训练数据的平均值', model_name]) 135 | plt.title(model_name + '与训练数据均值的比较曲线') 136 | plt.show() 137 | 138 | 139 | def eval_sine_test(model, optimizer, x, y, x_test, y_test, num_steps=(0, 1, 10)): 140 | '''Evaluate how the model fits to the curve training for `fits` steps. 141 | 142 | Args: 143 | model: Model evaluated. 144 | optimizer: Optimizer to be for training. 145 | x: Data used for training. 146 | y: Targets used for training. 147 | x_test: Data used for evaluation. 148 | y_test: Targets used for evaluation. 149 | num_steps: Number of steps to log. 150 | ''' 151 | fit_res = [] 152 | 153 | tensor_x_test, tensor_y_test = np_to_tensor((x_test, y_test)) 154 | 155 | # If 0 in fits we log the loss before any training 156 | if 0 in num_steps: 157 | loss, logits = compute_loss(model, tensor_x_test, tensor_y_test) 158 | fit_res.append((0, logits, loss)) 159 | 160 | for step in range(1, np.max(num_steps) + 1): 161 | train_step(x, y, model, optimizer) 162 | loss, logits = compute_loss(model, tensor_x_test, tensor_y_test) 163 | if step in num_steps: 164 | fit_res.append( 165 | ( 166 | step, 167 | logits, 168 | loss 169 | ) 170 | ) 171 | return fit_res 172 | 173 | 174 | def eval_sinewave_for_test(model, sinusoid_generator=None, num_steps=(0, 1, 10), lr=0.01, plot=True, name=None): 175 | '''Evaluates how the sinewave addapts at dataset. 176 | 177 | The idea is to use the pretrained model as a weight initializer and 178 | try to fit the model on this new dataset. 179 | 180 | Args: 181 | model: Already trained model. 182 | sinusoid_generator: A sinusoidGenerator instance. 183 | num_steps: Number of training steps to be logged. 184 | lr: Learning rate used for training on the test data. 185 | plot: If plot is True than it plots how the curves are fitted along 186 | `num_steps`. 187 | 188 | Returns: 189 | The fit results. A list containing the loss, logits and step. For 190 | every step at `num_steps`. 
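Each entry is a `(step, logits, loss)` tuple, as produced by `eval_sine_test`.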
    '''

    if sinusoid_generator is None:
        sinusoid_generator = SinusoidGenerator(K=5)

    # generate equally spaced samples for plotting
    x_test, y_test = sinusoid_generator.equally_spaced_samples(100)

    # batch used for training
    x, y = sinusoid_generator.batch()

    # copy model so we can use the same model multiple times
    copied_model = copy_model(model, x)

    # use SGD for this part of training as described in the paper
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr)

    # run training and log fit results
    fit_res = eval_sine_test(copied_model, optimizer, x, y, x_test, y_test, num_steps)

    # plot
    train, = plt.plot(x, y, '^')
    ground_truth, = plt.plot(x_test, y_test)
    plots = [train, ground_truth]
    legend = ['采样点', '真实曲线']
    for n, res, loss in fit_res:
        cur, = plt.plot(x_test, res[:, 0], '--')
        plots.append(cur)
        legend.append(f'{n} 次梯度迭代')
    plt.legend(plots, legend)
    plt.ylim(-5, 5)
    plt.xlim(-6, 6)
    plt.title(name)
    if plot:
        plt.show()

    return fit_res

def compare_maml_and_neural_net(maml, neural_net, sinusoid_generator, num_steps=list(range(10)),
                                intermediate_plot=True, marker='x', linestyle='--', figure_name=None):
    '''Compare the loss of a MAML model and a regular neural net.

    Fits both models to a new task (a new sine wave) and then plots
    the loss of both models along `num_steps` iterations.

    Args:
        maml: An already trained MAML model.
        neural_net: An already trained neural net.
        num_steps: Number of steps to be logged.
        intermediate_plot: If True, plots intermediate plots from
            `eval_sinewave_for_test`.
        marker: Marker used for plotting.
        linestyle: Line style used for plotting.
    '''
    if intermediate_plot:
        print('MAML')
    fit_maml = eval_sinewave_for_test(maml, sinusoid_generator, plot=intermediate_plot, name=figure_name)
    if intermediate_plot:
        print('Neural Net')
    fit_neural_net = eval_sinewave_for_test(neural_net, sinusoid_generator, plot=intermediate_plot, name=figure_name)

    fit_res = {'MAML模型': fit_maml, '联合训练模型': fit_neural_net}

    legend = []
    for name in fit_res:
        x = []
        y = []
        for n, _, loss in fit_res[name]:
            x.append(n)
            y.append(loss)
        plt.plot(x, y, marker=marker, linestyle=linestyle)
        plt.xticks(num_steps)
        legend.append(name)
    plt.title('损失函数值随训练迭代次数的变化曲线')
    plt.legend(legend)
    plt.show()


if __name__ == '__main__':
    model = SineModel()
    train_ds, test_ds = generate_dataset(K=5)
    name = '联合训练模型 K=5 lr=0.001'
    neural_model = regular_train(model, train_ds)
    plot_model_comparison_to_average(neural_model, train_ds, model_name='联合训练模型的预测值')
    for index in np.random.randint(0, len(test_ds), size=3):
        eval_sinewave_for_test(neural_model, test_ds[index], name=name)


    # model = SineModel()
    name = 'MAML模型 K=5, Alpha=0.01 Beta=0.001'
    maml_model = maml_train(model, train_ds)
    plot_model_comparison_to_average(maml_model, train_ds, model_name='MAML模型的预测值')
    for index in np.random.randint(0, len(test_ds), size=3):
        eval_sinewave_for_test(maml_model, test_ds[index], name=name)

    # for _ in range(3):
    #     index = np.random.choice(range(len(test_ds)))
    #     compare_maml_and_neural_net(maml_model, neural_model, test_ds[index])
--------------------------------------------------------------------------------
/scripts/sine_fitting/tri_sine_fitting.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import pretty_errors
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as keras_backend


# Other dependencies
import random
import sys
import time

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Reproducibility
np.random.seed(102)
tf.keras.backend.set_floatx('float64')


print('Python version: ', sys.version)
print('TensorFlow version: ', tf.__version__)

device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
    raise SystemError('GPU device not found')
print('GPU found at: {}'.format(device_name))

class SinusoidSurfaceGenerator():
    def __init__(self, batchsz=10, x_amplitude=None, x_phase=None, y_ampltude=None, y_phase=None):
        """
        Generate a function of the form: z = A1*Sin(x - P1) + A2*Cos(y - P2)
        :param x_amplitude: A1
        :param x_phase: P1
        :param y_amplitude: A2
        :param y_phase: P2
        :param batchsz: number of points in a batch

        """
        self.batchsz = batchsz
        self.x_amplitude = x_amplitude if x_amplitude else np.random.uniform(0.1, 5.0)
        self.x_phase = x_phase if x_phase else np.random.uniform(0, np.pi)
        self.y_ampltude = y_ampltude if y_ampltude else np.random.uniform(0.1, 5.0)
        self.y_phase = y_phase if y_phase else np.random.uniform(0, np.pi)
        self.x_vector = self._sample_x()
        self.y_vector = self._sample_y()

    def _sample_x(self):
        return np.random.uniform(-5, 5, self.batchsz)

    def _sample_y(self):
        return np.random.uniform(-5, 5, self.batchsz)

    def f(self, x, y):
        '''
        Sine surface function
        '''
        return self.x_amplitude * np.sin(x - self.x_phase) + self.y_ampltude * np.cos(y - self.y_phase)

    def batch(self, x_vector=None, y_vector=None, force_new=False):
        '''
        Generate a batch of size batchsz with z = f(x, y)
        :param x_vector, y_vector: coordinate arrays used to build the (x, y) point grid
        :param force_new: if True, resample the point arrays; otherwise use self.x_vector & self.y_vector
        :return: (x, y) point array and z vector
        '''
        points = []
        values = []
        if x_vector is None:
            if force_new:
                x_vector = self._sample_x()
            else:
                x_vector = self.x_vector
        if y_vector is None:
            if force_new:
                y_vector = self._sample_y()
            else:
                y_vector = self.y_vector
        for x in x_vector:
            for y in y_vector:
                point = [x, y]
                points.append(point)
                z = self.f(x, y)
                values.append(z)
        return np.array(points), values

    def equally_spaced_samples(self, K=None):
        if K is None:
            K = self.batchsz
        return self.batch(x_vector=np.linspace(-5, 5, K), y_vector=np.linspace(-5, 5, K))

    def plot_figure(self, x=None, y=None, z=None):
        '''
        Plot the 3D surface

        '''
        fig = plt.figure()
        axl = plt.gca(projection='3d')
        print('using sine function: z = {}*Sin(x - {}) + {}*Cos(y - {})'.format(self.x_amplitude, self.x_phase, self.y_ampltude, self.y_phase))
        if x is None and y is None and z is None:
            x_vector = np.linspace(-5, 5, 100)
            y_vector = np.linspace(-5, 5, 100)
            # Test batch generator
            # _, _, x_vector, y_vector = self.batch(x_vector=np.linspace(-5,5,self.batchsz), y_vector=np.linspace(-5,5,self.batchsz))
            x, y = np.meshgrid(x_vector, y_vector)
            z = self.f(x, y)
            axl.plot_surface(x, y, z, cmap='Reds')
            plt.show()
        else:
            x, y = np.meshgrid(x, y)
            axl.plot_surface(x, y, z, cmap='Reds')
            plt.show()

def generate_dataset(batchsz, train_size=20000, test_size=10):
    '''
    Generate datasets of size train_size and test_size.
    A dataset is composed of SinusoidSurfaceGenerators that are able to provide
    a batch of `batchsz` points at a time.
    '''
    def _generate_dataset(size):
        return [SinusoidSurfaceGenerator(batchsz=batchsz) for _ in range(size)]
    return _generate_dataset(train_size), _generate_dataset(test_size)

class SineModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.hidden1 = keras.layers.Dense(600, input_shape=(2,))
        self.hidden2 = keras.layers.Dense(800)
        self.hidden3 = keras.layers.Dense(1000)
        self.out = keras.layers.Dense(1)

    def forward(self, x):
        x = keras.activations.relu(self.hidden1(x))
        x = keras.activations.relu(self.hidden2(x))
        x = keras.activations.relu(self.hidden3(x))  # hidden3 was defined but previously skipped in the forward pass
        x = self.out(x)
        return x

def copy_model(model, x):
    '''Copy model weights to a new model.

    Args:
        model: model to be copied.
        x: An input example. This is used to run
            a forward pass in order to add the weights of the graph
            as variables.
    Returns:
        A copy of the model.
    '''
    copied_model = SineModel()

    # If we don't run this step the weights are not "initialized"
    # and the gradients will not be computed.
    copied_model.forward(tf.convert_to_tensor(x))

    copied_model.set_weights(model.get_weights())
    return copied_model

def loss_function(pred_y, y):
    return keras_backend.mean(keras.losses.mean_squared_error(y, pred_y))

def np_to_tensor(list_of_numpy_objs):
    return (tf.convert_to_tensor(obj) for obj in list_of_numpy_objs)


def compute_loss(model, x, y, loss_fn=loss_function):
    logits = model.forward(x)
    mse = loss_fn(logits, y)
    return mse, logits


def compute_gradients(model, x, y, loss_fn=loss_function):
    with tf.GradientTape() as tape:
        loss, _ = compute_loss(model, x, y, loss_fn)
    return tape.gradient(loss, model.trainable_variables), loss


def apply_gradients(optimizer, gradients, variables):
    optimizer.apply_gradients(zip(gradients, variables))


def train_batch(x, y, model, optimizer):
    tensor_x, tensor_y = np_to_tensor((x, y))
    gradients, loss = compute_gradients(model, tensor_x, tensor_y)
    apply_gradients(optimizer, gradients, model.trainable_variables)
    return loss

def train_model(dataset, epochs=1, lr=0.01, log_steps=1000):
    model = SineModel()
    # optimizer = keras.optimizers.Adam(learning_rate=lr)
    optimizer = keras.optimizers.SGD(learning_rate=lr)
    for epoch in range(epochs):
        losses = []
        total_loss = 0
        start = time.time()
        for i, sinusoid_generator in enumerate(dataset):
            x, y = sinusoid_generator.batch()
            loss = train_batch(x, y, model, optimizer)
            total_loss += loss
            curr_loss = total_loss / (i + 1.0)
            losses.append(curr_loss)

            if i % log_steps == 0 and i > 0:
                print('Step {}: loss = {}, Time to run {} steps = {:.2f} seconds'.format(
                    i, curr_loss, log_steps, time.time() - start))
                start = time.time()
    plt.plot(losses)
    plt.title('Loss Vs Time steps')
    plt.show()
    return model

def eval_sine_test(model, optimizer, x, y, x_test, y_test, num_steps=(0, 1, 10)):
    '''Evaluate how the model fits the curve while training for `num_steps` steps.

    Args:
        model: Model to be evaluated.
        optimizer: Optimizer to be used for training.
        x: Data used for training.
        y: Targets used for training.
        x_test: Data used for evaluation.
        y_test: Targets used for evaluation.
        num_steps: Number of steps to log.
    '''
    fit_res = []

    tensor_x_test, tensor_y_test = np_to_tensor((x_test, y_test))

    # If 0 is in num_steps we log the loss before any training
    if 0 in num_steps:
        loss, logits = compute_loss(model, tensor_x_test, tensor_y_test)
        fit_res.append((0, logits, loss))

    for step in range(1, np.max(num_steps) + 1):
        train_batch(x, y, model, optimizer)
        loss, logits = compute_loss(model, tensor_x_test, tensor_y_test)
        if step in num_steps:
            fit_res.append(
                (
                    step,
                    logits,
                    loss
                )
            )
    return fit_res


def eval_sinewave_for_test(model, sinusoid_generator=None, num_steps=(0, 1, 10), lr=0.01, plot=True):
    '''Evaluates how the model adapts to a new sine-surface dataset.

    The idea is to use the pretrained model as a weight initializer and
    try to fit the model on this new dataset.

    Args:
        model: Already trained model.
        sinusoid_generator: A SinusoidSurfaceGenerator instance.
        num_steps: Number of training steps to be logged.
        lr: Learning rate used for training on the test data.
        plot: If plot is True then it plots how the curves are fitted along
            `num_steps`.

    Returns:
        The fit results: a list containing a (step, logits, loss) tuple for
        every step in `num_steps`.
    '''

    if sinusoid_generator is None:
        sinusoid_generator = SinusoidSurfaceGenerator(batchsz=10)

    # generate equally spaced samples for plotting
    x_test, y_test = sinusoid_generator.equally_spaced_samples(100)

    # batch used for training
    x, y = sinusoid_generator.batch()

    # copy model so we can use the same model multiple times
    copied_model = copy_model(model, x)

    # use SGD for this part of training as described in the paper
    optimizer = keras.optimizers.SGD(learning_rate=lr)

    # run training and log fit results
    fit_res = eval_sine_test(copied_model, optimizer, x, y, x_test, y_test, num_steps)

    # plot
    train, = plt.plot(x, y, '^')
    ground_truth, = plt.plot(x_test, y_test)
    plots = [train, ground_truth]
    legend = ['Training Points', 'True Function']
    for n, res, loss in fit_res:
        cur, = plt.plot(x_test, res[:, 0], '--')
        plots.append(cur)
        legend.append(f'After {n} Steps')
    plt.legend(plots, legend)
    plt.ylim(-5, 5)
    plt.xlim(-6, 6)
    if plot:
        plt.show()

    return fit_res

if __name__ == '__main__':
    Generator = SinusoidSurfaceGenerator(batchsz=10)
    Generator.plot_figure()
    # Generator.equally_spaced_samples()
    # # Generator.plot_figure()
    # x, y = Generator.batch()
    # print(x, y)
    # tensor_x, tensor_y = np_to_tensor((x, y))
    # # print(tensor_x, tensor_y)
    # model = SineModel()

    # # print(output)
    # mse, logits = compute_loss(model, tensor_x, tensor_y)
    # print(mse)
    # print(logits)
    # Generator.plot_figure()

    neural_model = SineModel()
    train_ds, test_ds = generate_dataset(batchsz=40)
    neural_model = train_model(train_ds)


    x_test, y_test = Generator.equally_spaced_samples(20)
    # print(x_test.T[0])

    tensor_x, tensor_y = np_to_tensor((x_test, y_test))
    output = neural_model.forward(tensor_x)
    print(output)
    x_test = x_test.T
    x = x_test[0]
    y = x_test[1]
    z = output.numpy()
    fig = plt.figure()
    axl = plt.gca(projection='3d')
    x, y = np.meshgrid(x, y)
    axl.plot_surface(x, y, z, cmap='Blues')
    plt.show()
--------------------------------------------------------------------------------
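
For reference, the meta-training loop in `maml_train` of `train_sine_model.py` follows the MAML update from the paper cited in the README: an inner step `theta_i' = theta - alpha * grad L_train(theta)` per task, then an outer step `theta <- theta - beta * grad L_test(theta_i')`. The sketch below is only a minimal first-order (FOMAML-style) variant of that loop, not the repository's second-order implementation; it reuses the `SineModel`, `loss_fn`, and `np_to_tensor` helpers defined in `train_sine_model.py`, and the function name `fomaml_step` is purely illustrative.

```python
import tensorflow as tf

def fomaml_step(model, ds, lr_inner=0.01, meta_optimizer=None):
    """One first-order MAML step on a single sine task `ds` (a SinusoidGenerator).

    Assumes `model` is a SineModel and `loss_fn` / `np_to_tensor` come from
    train_sine_model.py above.
    """
    meta_optimizer = meta_optimizer or tf.keras.optimizers.Adam()
    x, y = np_to_tensor(ds.batch())
    model.forward(x)  # build the weights before taking gradients

    # Inner step: theta' = theta - alpha * grad L(theta); the gradient is treated
    # as a constant (first-order approximation), unlike maml_train's nested tapes.
    with tf.GradientTape() as tape:
        inner_loss = loss_fn(y, model.forward(x))
    inner_grads = tape.gradient(inner_loss, model.trainable_variables)
    original_weights = [v.numpy() for v in model.trainable_variables]
    for v, g in zip(model.trainable_variables, inner_grads):
        v.assign(v - lr_inner * g)

    # Outer step: gradient of the post-adaptation loss, applied to the original weights.
    with tf.GradientTape() as tape:
        outer_loss = loss_fn(y, model.forward(x))
    outer_grads = tape.gradient(outer_loss, model.trainable_variables)
    for v, w in zip(model.trainable_variables, original_weights):
        v.assign(w)  # restore theta before the meta-update
    meta_optimizer.apply_gradients(zip(outer_grads, model.trainable_variables))
    return outer_loss
```

Because the inner gradient is not differentiated through, this variant is cheaper than the kernel/bias rebinding used in `maml_train`, at the cost of dropping the second-order term.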