├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── __pycache__ ├── config.cpython-36.pyc ├── model.cpython-36.pyc └── utils.cpython-36.pyc ├── checkpoint └── params_train.npz ├── config.py ├── log └── train_log │ ├── events.out.tfevents.1508176022.NB-USLP01540 │ ├── events.out.tfevents.1508176069.NB-USLP01540 │ ├── events.out.tfevents.1508176830.NB-USLP01540 │ ├── events.out.tfevents.1508177959.NB-USLP01540 │ ├── events.out.tfevents.1508178144.NB-USLP01540 │ ├── events.out.tfevents.1508178400.NB-USLP01540 │ ├── events.out.tfevents.1508180216.NB-USLP01540 │ ├── events.out.tfevents.1508181221.NB-USLP01540 │ ├── events.out.tfevents.1508194995.NB-USLP01540 │ ├── events.out.tfevents.1508265230.NB-USLP01540 │ ├── events.out.tfevents.1508265646.NB-USLP01540 │ ├── events.out.tfevents.1508271339.NB-USLP01540 │ ├── events.out.tfevents.1508271851.NB-USLP01540 │ ├── events.out.tfevents.1508273479.NB-USLP01540 │ ├── events.out.tfevents.1508278081.NB-USLP01540 │ └── events.out.tfevents.1508280187.NB-USLP01540 ├── main.py ├── model.py ├── tensorlayer ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── activation.cpython-36.pyc │ ├── cost.cpython-36.pyc │ ├── files.cpython-36.pyc │ ├── iterate.cpython-36.pyc │ ├── layers.cpython-36.pyc │ ├── nlp.cpython-36.pyc │ ├── ops.cpython-36.pyc │ ├── prepro.cpython-36.pyc │ ├── rein.cpython-36.pyc │ ├── utils.cpython-36.pyc │ └── visualize.cpython-36.pyc ├── activation.py ├── cost.py ├── db.py ├── files.py ├── iterate.py ├── layers.py ├── nlp.py ├── ops.py ├── prepro.py ├── rein.py ├── utils.py └── visualize.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | samples/ 2 | log/ 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tensorflow implementation of the paper "Deep Laplacian Pyramid Networks for Fast and Accurate Super-Resolution" (CVPR 2017) 2 | 3 | This is a Tensorflow implementation using TensorLayer. 4 | Original paper and implementation using MatConvNet can be found on their [project webpage](http://vllab1.ucmerced.edu/~wlai24/LapSRN/). 5 | 6 | ### Environment 7 | The implementation is tested using python 3.6 and cuda 8.0. 8 | 9 | ### Download repository: 10 | 11 | $ git clone https://github.com/zjuela/LapSRN-tensorflow.git 12 | 13 | ### Train model 14 | Specify dataset path in config.py file and run: 15 | 16 | $ python main.py 17 | 18 | The pre-trained model is trained using the [NTIRE 2017](http://www.vision.ee.ethz.ch/ntire17/) challenge dataset.
19 | 20 | ### Test 21 | Run with your test image: 22 | 23 | $ python main.py -m test -f TESTIMAGE 24 | 25 | Results can be find in folder ./samples/ 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /checkpoint/params_train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/checkpoint/params_train.npz -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | import json 3 | 4 | config = edict() 5 | 6 | config.model = edict() 7 | config.model.result_path = "samples" 8 | config.model.checkpoint_path = "checkpoint" 9 | config.model.log_path = "log" 10 | config.model.scale = 4 11 | config.model.resblock_depth = 10 12 | 
config.model.recursive_depth = 1 13 | 14 | config.valid = edict() 15 | config.valid.hr_folder_path = '/media/zhehu/DATA/Public_Dataset/NTIRE2017/DIV2K_valid_HR/' 16 | config.valid.lr_folder_path = '/media/zhehu/DATA/Public_Dataset/NTIRE2017/DIV2K_valid_LR_bicubic/X4/' 17 | 18 | config.train = edict() 19 | config.train.hr_folder_path = '/media/zhehu/DATA/Public_Dataset/NTIRE2017/DIV2K_train_HR/' 20 | config.train.lr_folder_path = '/media/zhehu/DATA/Public_Dataset/NTIRE2017/DIV2K_train_LR_bicubic/X4/' 21 | config.train.batch_size = 4 # use large number if you have enough memory 22 | config.train.in_patch_size = 64 23 | config.train.out_patch_size = config.model.scale * config.train.in_patch_size 24 | config.train.batch_size_each_folder = 30 25 | config.train.log_write = False 26 | config.train.lr_init = 5*1.e-6 27 | config.train.lr_decay = 0.5 28 | config.train.decay_iter = 10 29 | config.train.beta1 = 0.90 30 | config.train.n_epoch = 300 31 | config.train.dump_intermediate_result = True 32 | 33 | def log_config(filename, cfg): 34 | with open(filename, 'w') as f: 35 | f.write("================================================\n") 36 | f.write(json.dumps(cfg, indent=4)) 37 | f.write("\n================================================\n") 38 | -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508176022.NB-USLP01540: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508176022.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508176069.NB-USLP01540: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508176069.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508176830.NB-USLP01540: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508176830.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508177959.NB-USLP01540: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508177959.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508178144.NB-USLP01540: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508178144.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508178400.NB-USLP01540: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508178400.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508180216.NB-USLP01540: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508180216.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508181221.NB-USLP01540: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508181221.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508194995.NB-USLP01540: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508194995.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508265230.NB-USLP01540: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508265230.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508265646.NB-USLP01540: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508265646.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508271339.NB-USLP01540: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508271339.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508271851.NB-USLP01540: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508271851.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508273479.NB-USLP01540: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508273479.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508278081.NB-USLP01540: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508278081.NB-USLP01540 -------------------------------------------------------------------------------- /log/train_log/events.out.tfevents.1508280187.NB-USLP01540: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/log/train_log/events.out.tfevents.1508280187.NB-USLP01540 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | import os, time, random 5 | import numpy as np 6 | import scipy 7 | 8 | import tensorflow as tf 9 | import tensorlayer as tl 10 | from model import * 11 | from utils import * 12 | from config import * 13 | 14 | ###====================== HYPER-PARAMETERS ===========================### 15 | batch_size = config.train.batch_size 16 | patch_size = config.train.in_patch_size 17 | ni = int(np.sqrt(config.train.batch_size)) 18 | 19 | 20 | 21 | def compute_charbonnier_loss(tensor1, tensor2, is_mean=True): 22 | epsilon = 1e-6 23 | if is_mean: 24 | loss = tf.reduce_mean(tf.reduce_mean(tf.sqrt(tf.square(tf.subtract(tensor1,tensor2))+epsilon), [1, 2, 3])) 25 | else: 26 | loss = tf.reduce_mean(tf.reduce_sum(tf.sqrt(tf.square(tf.subtract(tensor1,tensor2))+epsilon), [1, 2, 3])) 27 | 28 | return loss 29 | 30 | 31 | 32 | def load_file_list(): 33 | train_hr_file_list = [] 34 | train_lr_file_list = [] 35 | valid_hr_file_list = [] 36 | valid_lr_file_list = [] 37 | 38 | directory = config.train.hr_folder_path 39 | for filename in [y for y in os.listdir(directory) if os.path.isfile(os.path.join(directory,y))]: 40 | train_hr_file_list.append("%s%s"%(directory,filename)) 41 | 42 | directory = config.train.lr_folder_path 43 | for filename in [y for y in os.listdir(directory) if os.path.isfile(os.path.join(directory,y))]: 44 | train_lr_file_list.append("%s%s"%(directory,filename)) 45 | 46 | directory = config.valid.hr_folder_path 47 | for filename in [y for y in os.listdir(directory) if os.path.isfile(os.path.join(directory,y))]: 48 | valid_hr_file_list.append("%s%s"%(directory,filename)) 49 | 50 | directory = config.valid.lr_folder_path 51 | for filename in [y for y in os.listdir(directory) if os.path.isfile(os.path.join(directory,y))]: 52 | valid_lr_file_list.append("%s%s"%(directory,filename)) 53 | 54 | return sorted(train_hr_file_list),sorted(train_lr_file_list),sorted(valid_hr_file_list),sorted(valid_lr_file_list) 55 | 56 | 57 | 58 | 
def prepare_nn_data(hr_img_list, lr_img_list, idx_img=None): 59 | i = np.random.randint(len(hr_img_list)) if (idx_img is None) else idx_img 60 | 61 | input_image = get_imgs_fn(lr_img_list[i]) 62 | output_image = get_imgs_fn(hr_img_list[i]) 63 | scale = int(output_image.shape[0] / input_image.shape[0]) 64 | assert scale == config.model.scale 65 | 66 | out_patch_size = patch_size * scale 67 | input_batch = np.empty([batch_size,patch_size,patch_size,3]) 68 | output_batch = np.empty([batch_size,out_patch_size,out_patch_size,3]) 69 | 70 | for idx in range(batch_size): 71 | in_row_ind = random.randint(0,input_image.shape[0]-patch_size) 72 | in_col_ind = random.randint(0,input_image.shape[1]-patch_size) 73 | 74 | input_cropped = augment_imgs_fn(input_image[in_row_ind:in_row_ind+patch_size, 75 | in_col_ind:in_col_ind+patch_size]) 76 | input_cropped = normalize_imgs_fn(input_cropped) 77 | input_cropped = np.expand_dims(input_cropped,axis=0) 78 | input_batch[idx] = input_cropped 79 | 80 | out_row_ind = in_row_ind * scale 81 | out_col_ind = in_col_ind * scale 82 | output_cropped = output_image[out_row_ind:out_row_ind+out_patch_size, 83 | out_col_ind:out_col_ind+out_patch_size] 84 | output_cropped = normalize_imgs_fn(output_cropped) 85 | output_cropped = np.expand_dims(output_cropped,axis=0) 86 | output_batch[idx] = output_cropped 87 | 88 | return input_batch,output_batch 89 | 90 | 91 | 92 | def train(): 93 | save_dir = "%s/%s_train"%(config.model.result_path,tl.global_flag['mode']) 94 | checkpoint_dir = "%s"%(config.model.checkpoint_path) 95 | tl.files.exists_or_mkdir(save_dir) 96 | tl.files.exists_or_mkdir(checkpoint_dir) 97 | 98 | ###========================== DEFINE MODEL ============================### 99 | t_image = tf.placeholder('float32', [batch_size, patch_size, patch_size, 3], name='t_image_input') 100 | t_target_image = tf.placeholder('float32', [batch_size, patch_size*config.model.scale, patch_size*config.model.scale, 3], name='t_target_image') 101 | 
t_target_image_down = tf.image.resize_images(t_target_image, size=[patch_size*2, patch_size*2], method=0, align_corners=False) 102 | 103 | net_image2, net_grad2, net_image1, net_grad1 = LapSRN(t_image, is_train=True, reuse=False) 104 | net_image2.print_params(False) 105 | 106 | ## test inference 107 | net_image_test, net_grad_test, _, _ = LapSRN(t_image, is_train=False, reuse=True) 108 | 109 | ###========================== DEFINE TRAIN OPS ==========================### 110 | loss2 = compute_charbonnier_loss(net_image2.outputs, t_target_image, is_mean=True) 111 | loss1 = compute_charbonnier_loss(net_image1.outputs, t_target_image_down, is_mean=True) 112 | g_loss = loss1 + loss2 * 4 113 | g_vars = tl.layers.get_variables_with_name('LapSRN', True, True) 114 | 115 | with tf.variable_scope('learning_rate'): 116 | lr_v = tf.Variable(config.train.lr_init, trainable=False) 117 | 118 | g_optim = tf.train.AdamOptimizer(lr_v, beta1=config.train.beta1).minimize(g_loss, var_list=g_vars) 119 | 120 | ###========================== RESTORE MODEL =============================### 121 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) 122 | tl.layers.initialize_global_variables(sess) 123 | tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/params_{}.npz'.format(tl.global_flag['mode']), network=net_image2) 124 | 125 | ###========================== PRE-LOAD DATA ===========================### 126 | train_hr_list,train_lr_list,valid_hr_list,valid_lr_list = load_file_list() 127 | 128 | ###========================== INTERMEDIATE RESULT ===============================### 129 | sample_ind = 37 130 | sample_input_imgs,sample_output_imgs = prepare_nn_data(valid_hr_list,valid_lr_list,sample_ind) 131 | tl.vis.save_images(truncate_imgs_fn(sample_input_imgs), [ni, ni], save_dir+'/train_sample_input.png') 132 | tl.vis.save_images(truncate_imgs_fn(sample_output_imgs), [ni, ni], save_dir+'/train_sample_output.png') 133 | 134 | 
###========================== TRAINING ====================### 135 | sess.run(tf.assign(lr_v, config.train.lr_init)) 136 | print(" ** learning rate: %f" % config.train.lr_init) 137 | 138 | for epoch in range(config.train.n_epoch): 139 | ## update learning rate 140 | if epoch != 0 and (epoch % config.train.decay_iter == 0): 141 | lr_decay = config.train.lr_decay ** (epoch // config.train.decay_iter) 142 | lr = config.train.lr_init * lr_decay 143 | sess.run(tf.assign(lr_v, lr)) 144 | print(" ** learning rate: %f" % (lr)) 145 | 146 | epoch_time = time.time() 147 | total_g_loss, n_iter = 0, 0 148 | 149 | ## load image data: visit the training images in a fresh random order every epoch 150 | idx_list = np.random.permutation(len(train_hr_list)) 151 | for idx_file in range(len(idx_list)): 152 | step_time = time.time() 153 | batch_input_imgs,batch_output_imgs = prepare_nn_data(train_hr_list,train_lr_list,idx_list[idx_file]) # index through the permutation; passing idx_file directly ignored the shuffle 154 | errM, _ = sess.run([g_loss, g_optim], {t_image: batch_input_imgs, t_target_image: batch_output_imgs}) 155 | total_g_loss += errM 156 | n_iter += 1 157 | 158 | print("[*] Epoch: [%2d/%2d] time: %4.4fs, loss: %.8f" % (epoch, config.train.n_epoch, time.time() - epoch_time, total_g_loss/n_iter)) 159 | 160 | ## save model and evaluation on sample set 161 | if (epoch >= 0): 162 | tl.files.save_npz(net_image2.all_params, name=checkpoint_dir+'/params_{}.npz'.format(tl.global_flag['mode']), sess=sess) 163 | 164 | if config.train.dump_intermediate_result is True: 165 | sample_out, sample_grad_out = sess.run([net_image_test.outputs,net_grad_test.outputs], {t_image: sample_input_imgs})#; print('gen sub-image:', out.shape, out.min(), out.max()) 166 | tl.vis.save_images(truncate_imgs_fn(sample_out), [ni, ni], save_dir+'/train_predict_%d.png' % epoch) 167 | tl.vis.save_images(truncate_imgs_fn(np.abs(sample_grad_out)), [ni, ni], save_dir+'/train_grad_predict_%d.png' % epoch) 168 | 169 | 170 | 171 | def test(file): 172 | try: 173 | img = get_imgs_fn(file) 174 | except IOError: 175 | print('cannot open %s'%(file)) 176 | else: 177 |
checkpoint_dir = config.model.checkpoint_path 178 | save_dir = "%s/%s"%(config.model.result_path,tl.global_flag['mode']) 179 | input_image = normalize_imgs_fn(img) 180 | 181 | size = input_image.shape 182 | print('Input size: %s,%s,%s'%(size[0],size[1],size[2])) 183 | t_image = tf.placeholder('float32', [None,size[0],size[1],size[2]], name='input_image') 184 | net_g, _, _, _ = LapSRN(t_image, is_train=False, reuse=False) 185 | 186 | ###========================== RESTORE G =============================### 187 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) 188 | tl.layers.initialize_global_variables(sess) 189 | tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/params_train.npz', network=net_g) 190 | 191 | ###======================= TEST =============================### 192 | start_time = time.time() 193 | out = sess.run(net_g.outputs, {t_image: [input_image]}) 194 | print("took: %4.4fs" % (time.time() - start_time)) 195 | 196 | tl.files.exists_or_mkdir(save_dir) 197 | tl.vis.save_image(truncate_imgs_fn(out[0,:,:,:]), save_dir+'/test_out.png') 198 | tl.vis.save_image(input_image, save_dir+'/test_input.png') 199 | 200 | 201 | 202 | if __name__ == '__main__': 203 | import argparse 204 | parser = argparse.ArgumentParser() 205 | parser.add_argument('-m', '--mode', choices=['train','test'], default='train', help='select mode') 206 | parser.add_argument('-f','--file', help='input file') 207 | 208 | args = parser.parse_args() 209 | 210 | tl.global_flag['mode'] = args.mode 211 | if tl.global_flag['mode'] == 'train': 212 | train() 213 | elif tl.global_flag['mode'] == 'test': 214 | if (args.file is None): 215 | raise Exception("Please enter input file name for test mode") 216 | test(args.file) 217 | else: 218 | raise Exception("Unknow --mode") 219 | -------------------------------------------------------------------------------- /model.py: 
-------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | import numpy as np 4 | import tensorflow as tf 5 | import tensorlayer as tl 6 | from tensorlayer.layers import * 7 | 8 | from config import * 9 | 10 | 11 | 12 | def lrelu(x): 13 | return tf.maximum(x*0.2,x) 14 | 15 | 16 | 17 | def LapSRNSingleLevel(net_image, net_feature, reuse=False): 18 | with tf.variable_scope("Model_level", reuse=reuse): 19 | tl.layers.set_name_reuse(reuse) 20 | 21 | net_tmp = net_feature 22 | # recursive block 23 | for d in range(config.model.resblock_depth): 24 | net_tmp = PReluLayer(net_tmp, name='prelu_D%s'%(d)) 25 | net_tmp = Conv2dLayer(net_tmp,shape=[3,3,64,64],strides=[1,1,1,1], 26 | name='conv_D%s'%(d), W_init=tf.contrib.layers.xavier_initializer()) 27 | 28 | # for r in range(1,config.model.recursive_depth): 29 | # for d in range(config.model.resblock_depth): 30 | # net_tmp = PReluLayer(net_tmp, name='prelu_R%s_D%s'%(r,d)) 31 | # net_tmp = Conv2dLayer(net_tmp,shape=[3,3,64,64],strides=[1,1,1,1], 32 | # name='conv_R%s_D%s'%(r,d), W_init=tf.contrib.layers.xavier_initializer()) 33 | 34 | net_feature = ElementwiseLayer(layer=[net_feature,net_tmp],combine_fn=tf.add,name='add_feature') 35 | 36 | net_feature = PReluLayer(net_feature, name='prelu_feature') 37 | net_feature = Conv2dLayer(net_feature,shape=[3,3,64,256],strides=[1,1,1,1], 38 | name='upconv_feature', W_init=tf.contrib.layers.xavier_initializer()) 39 | net_feature = SubpixelConv2d(net_feature,scale=2,n_out_channel=64, 40 | name='subpixel_feature') 41 | 42 | # add image back 43 | gradient_level = Conv2dLayer(net_feature,shape=[3,3,64,3],strides=[1,1,1,1],act=lrelu, 44 | name='grad', W_init=tf.contrib.layers.xavier_initializer()) 45 | net_image = Conv2dLayer(net_image,shape=[3,3,3,12],strides=[1,1,1,1], 46 | name='upconv_image', W_init=tf.contrib.layers.xavier_initializer()) 47 | net_image = SubpixelConv2d(net_image,scale=2,n_out_channel=3, 48 | 
name='subpixel_image') 49 | net_image = ElementwiseLayer(layer=[gradient_level,net_image],combine_fn=tf.add,name='add_image') 50 | 51 | return net_image, net_feature, gradient_level 52 | 53 | 54 | 55 | def LapSRN(inputs, is_train=False, reuse=False): 56 | n_level = int(np.log2(config.model.scale)) 57 | assert n_level >= 1 58 | 59 | with tf.variable_scope("LapSRN", reuse=reuse) as vs: 60 | tl.layers.set_name_reuse(reuse) 61 | 62 | shapes = tf.shape(inputs) 63 | inputs_level = InputLayer(inputs, name='input_level') 64 | 65 | net_feature = Conv2dLayer(inputs_level, shape=[3,3,3,64], strides=[1,1,1,1], 66 | W_init=tf.contrib.layers.xavier_initializer(), 67 | name='init_conv') 68 | net_image = inputs_level 69 | 70 | # 2X for each level 71 | net_image1, net_feature1, net_gradient1 = LapSRNSingleLevel(net_image, net_feature, reuse=reuse) 72 | net_image2, net_feature2, net_gradient2 = LapSRNSingleLevel(net_image1, net_feature1, reuse=True) 73 | 74 | return net_image2, net_gradient2, net_image1, net_gradient1 -------------------------------------------------------------------------------- /tensorlayer/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Deep learning and Reinforcement learning library for Researchers and Engineers 3 | """ 4 | from __future__ import absolute_import 5 | 6 | 7 | try: 8 | install_instr = "Please make sure you install a recent enough version of TensorFlow." 9 | import tensorflow 10 | except ImportError: 11 | raise ImportError("__init__.py : Could not import TensorFlow." + install_instr) 12 | 13 | from . import activation 14 | from . import cost 15 | from . import files 16 | from . import iterate 17 | from . import layers 18 | from . import ops 19 | from . import utils 20 | from . import visualize 21 | from . import prepro 22 | from . import nlp 23 | from . 
import rein 24 | 25 | # alias 26 | act = activation 27 | vis = visualize 28 | 29 | __version__ = "1.5.0" 30 | 31 | global_flag = {} 32 | global_dict = {} 33 | -------------------------------------------------------------------------------- /tensorlayer/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/tensorlayer/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tensorlayer/__pycache__/activation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/tensorlayer/__pycache__/activation.cpython-36.pyc -------------------------------------------------------------------------------- /tensorlayer/__pycache__/cost.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/tensorlayer/__pycache__/cost.cpython-36.pyc -------------------------------------------------------------------------------- /tensorlayer/__pycache__/files.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/tensorlayer/__pycache__/files.cpython-36.pyc -------------------------------------------------------------------------------- /tensorlayer/__pycache__/iterate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/tensorlayer/__pycache__/iterate.cpython-36.pyc 
-------------------------------------------------------------------------------- /tensorlayer/__pycache__/layers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/tensorlayer/__pycache__/layers.cpython-36.pyc -------------------------------------------------------------------------------- /tensorlayer/__pycache__/nlp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/tensorlayer/__pycache__/nlp.cpython-36.pyc -------------------------------------------------------------------------------- /tensorlayer/__pycache__/ops.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/tensorlayer/__pycache__/ops.cpython-36.pyc -------------------------------------------------------------------------------- /tensorlayer/__pycache__/prepro.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/tensorlayer/__pycache__/prepro.cpython-36.pyc -------------------------------------------------------------------------------- /tensorlayer/__pycache__/rein.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/tensorlayer/__pycache__/rein.cpython-36.pyc -------------------------------------------------------------------------------- /tensorlayer/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/tensorlayer/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /tensorlayer/__pycache__/visualize.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjuela/LapSRN-tensorflow/3714cefcf154576b8f0e911833f2618ecd9ae0e0/tensorlayer/__pycache__/visualize.cpython-36.pyc -------------------------------------------------------------------------------- /tensorlayer/activation.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | 5 | 6 | import tensorflow as tf 7 | 8 | def identity(x, name=None): 9 | """The identity activation function, Shortcut is ``linear``. 10 | 11 | Parameters 12 | ---------- 13 | x : a tensor input 14 | input(s) 15 | 16 | 17 | Returns 18 | -------- 19 | A `Tensor` with the same type as `x`. 20 | """ 21 | return x 22 | 23 | # Shortcut 24 | linear = identity 25 | 26 | def ramp(x=None, v_min=0, v_max=1, name=None): 27 | """The ramp activation function. 28 | 29 | Parameters 30 | ---------- 31 | x : a tensor input 32 | input(s) 33 | v_min : float 34 | if input(s) smaller than v_min, change inputs to v_min 35 | v_max : float 36 | if input(s) greater than v_max, change inputs to v_max 37 | name : a string or None 38 | An optional name to attach to this activation function. 39 | 40 | 41 | Returns 42 | -------- 43 | A `Tensor` with the same type as `x`. 44 | """ 45 | return tf.clip_by_value(x, clip_value_min=v_min, clip_value_max=v_max, name=name) 46 | 47 | def leaky_relu(x=None, alpha=0.1, name="LeakyReLU"): 48 | """The LeakyReLU, Shortcut is ``lrelu``. 49 | 50 | Modified version of ReLU, introducing a nonzero gradient for negative 51 | input. 
52 | 53 | Parameters 54 | ---------- 55 | x : A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`, 56 | `int16`, or `int8`. 57 | alpha : `float`. slope. 58 | name : a string or None 59 | An optional name to attach to this activation function. 60 | 61 | Examples 62 | --------- 63 | >>> network = tl.layers.DenseLayer(network, n_units=100, name = 'dense_lrelu', 64 | ... act= lambda x : tl.act.lrelu(x, 0.2)) 65 | 66 | References 67 | ------------ 68 | - `Rectifier Nonlinearities Improve Neural Network Acoustic Models, Maas et al. (2013) `_ 69 | """ 70 | with tf.name_scope(name) as scope: 71 | # x = tf.nn.relu(x) 72 | # m_x = tf.nn.relu(-x) 73 | # x -= alpha * m_x 74 | x = tf.maximum(x, alpha * x) 75 | return x 76 | 77 | #Shortcut 78 | lrelu = leaky_relu 79 | 80 | def pixel_wise_softmax(output, name='pixel_wise_softmax'): 81 | """Return the softmax outputs of images, every pixels have multiple label, the sum of a pixel is 1. 82 | Usually be used for image segmentation. 83 | 84 | Parameters 85 | ------------ 86 | output : tensor 87 | - For 2d image, 4D tensor [batch_size, height, weight, channel], channel >= 2. 88 | - For 3d image, 5D tensor [batch_size, depth, height, weight, channel], channel >= 2. 
89 | 90 | Examples 91 | --------- 92 | >>> outputs = pixel_wise_softmax(network.outputs) 93 | >>> dice_loss = 1 - dice_coe(outputs, y_, epsilon=1e-5) 94 | 95 | References 96 | ----------- 97 | - `tf.reverse `_ 98 | """ 99 | with tf.name_scope(name) as scope: 100 | return tf.nn.softmax(output) 101 | ## old implementation 102 | # exp_map = tf.exp(output) 103 | # if output.get_shape().ndims == 4: # 2d image 104 | # evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, True])) 105 | # elif output.get_shape().ndims == 5: # 3d image 106 | # evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, False, True])) 107 | # else: 108 | # raise Exception("output parameters should be 2d or 3d image, not %s" % str(output._shape)) 109 | # return tf.div(exp_map, evidence) 110 | -------------------------------------------------------------------------------- /tensorlayer/cost.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | 5 | 6 | import tensorflow as tf 7 | import numbers 8 | from tensorflow.python.framework import ops 9 | from tensorflow.python.ops import standard_ops 10 | 11 | ## Cost Functions 12 | 13 | def cross_entropy(output, target, name=None): 14 | """It is a softmax cross-entropy operation, returns the TensorFlow expression of cross-entropy of two distributions, implement 15 | softmax internally. See ``tf.nn.sparse_softmax_cross_entropy_with_logits``. 16 | 17 | Parameters 18 | ---------- 19 | output : Tensorflow variable 20 | A distribution with shape: [batch_size, n_feature]. 21 | target : Tensorflow variable 22 | A batch of index with shape: [batch_size, ]. 23 | name : string 24 | Name of this loss. 25 | 26 | Examples 27 | -------- 28 | >>> ce = tl.cost.cross_entropy(y_logits, y_target_logits, 'my_loss') 29 | 30 | References 31 | ----------- 32 | - About cross-entropy: `wiki `_.\n 33 | - The code is borrowed from: `here `_. 
34 | """ 35 | try: # old 36 | return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output, targets=target)) 37 | except: # TF 1.0 38 | assert name is not None, "Please give a unique name to tl.cost.cross_entropy for TF1.0+" 39 | return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output, name=name)) 40 | 41 | def sigmoid_cross_entropy(output, target, name=None): 42 | """It is a sigmoid cross-entropy operation, see ``tf.nn.sigmoid_cross_entropy_with_logits``. 43 | """ 44 | try: # TF 1.0 45 | return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output, name=name)) 46 | except: 47 | return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=output, targets=target)) 48 | 49 | 50 | def binary_cross_entropy(output, target, epsilon=1e-8, name='bce_loss'): 51 | """Computes binary cross entropy given `output`. 52 | 53 | For brevity, let `x = output`, `z = target`. The binary cross entropy loss is 54 | 55 | loss(x, z) = - sum_i (x[i] * log(z[i]) + (1 - x[i]) * log(1 - z[i])) 56 | 57 | Parameters 58 | ---------- 59 | output : tensor of type `float32` or `float64`. 60 | target : tensor of the same type and shape as `output`. 61 | epsilon : float 62 | A small value to avoid output is zero. 63 | name : string 64 | An optional name to attach to this layer. 65 | 66 | References 67 | ----------- 68 | - `DRAW `_ 69 | """ 70 | # from tensorflow.python.framework import ops 71 | # with ops.op_scope([output, target], name, "bce_loss") as name: 72 | # output = ops.convert_to_tensor(output, name="preds") 73 | # target = ops.convert_to_tensor(targets, name="target") 74 | with tf.name_scope(name): 75 | return tf.reduce_mean(tf.reduce_sum(-(target * tf.log(output + epsilon) + 76 | (1. - target) * tf.log(1. - output + epsilon)), axis=1)) 77 | 78 | 79 | def mean_squared_error(output, target, is_mean=False): 80 | """Return the TensorFlow expression of mean-squre-error of two distributions. 
81 | 82 | Parameters 83 | ---------- 84 | output : 2D or 4D tensor. 85 | target : 2D or 4D tensor. 86 | is_mean : boolean, if True, use ``tf.reduce_mean`` to compute the loss of one data, otherwise, use ``tf.reduce_sum`` (default). 87 | 88 | References 89 | ------------ 90 | - `Wiki Mean Squared Error `_ 91 | """ 92 | with tf.name_scope("mean_squared_error_loss"): 93 | if output.get_shape().ndims == 2: # [batch_size, n_feature] 94 | if is_mean: 95 | mse = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(output, target), 1)) 96 | else: 97 | mse = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(output, target), 1)) 98 | elif output.get_shape().ndims == 4: # [batch_size, w, h, c] 99 | if is_mean: 100 | mse = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(output, target), [1, 2, 3])) 101 | else: 102 | mse = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(output, target), [1, 2, 3])) 103 | return mse 104 | 105 | def normalized_mean_square_error(output, target): 106 | """Return the TensorFlow expression of normalized mean-squre-error of two distributions. 107 | 108 | Parameters 109 | ---------- 110 | output : 2D or 4D tensor. 111 | target : 2D or 4D tensor. 112 | """ 113 | with tf.name_scope("mean_squared_error_loss"): 114 | if output.get_shape().ndims == 2: # [batch_size, n_feature] 115 | nmse_a = tf.sqrt(tf.reduce_sum(tf.squared_difference(output, target), axis=1)) 116 | nmse_b = tf.sqrt(tf.reduce_sum(tf.square(target), axis=1)) 117 | elif output.get_shape().ndims == 4: # [batch_size, w, h, c] 118 | nmse_a = tf.sqrt(tf.reduce_sum(tf.squared_difference(output, target), axis=[1,2,3])) 119 | nmse_b = tf.sqrt(tf.reduce_sum(tf.square(target), axis=[1,2,3])) 120 | nmse = tf.reduce_mean(nmse_a / nmse_b) 121 | return nmse 122 | 123 | 124 | def dice_coe(output, target, epsilon=1e-10): 125 | """Sørensen–Dice coefficient for comparing the similarity of two distributions, 126 | usually be used for binary image segmentation i.e. labels are binary. 
127 | The coefficient = [0, 1], 1 if totally match. 128 | 129 | Parameters 130 | ----------- 131 | output : tensor 132 | A distribution with shape: [batch_size, ....], (any dimensions). 133 | target : tensor 134 | A distribution with shape: [batch_size, ....], (any dimensions). 135 | epsilon : float 136 | An optional name to attach to this layer. 137 | 138 | Examples 139 | --------- 140 | >>> outputs = tl.act.pixel_wise_softmax(network.outputs) 141 | >>> dice_loss = 1 - tl.cost.dice_coe(outputs, y_, epsilon=1e-5) 142 | 143 | References 144 | ----------- 145 | - `wiki-dice `_ 146 | """ 147 | # inse = tf.reduce_sum( tf.mul(output, target) ) 148 | # l = tf.reduce_sum( tf.mul(output, output) ) 149 | # r = tf.reduce_sum( tf.mul(target, target) ) 150 | inse = tf.reduce_sum( output * target ) 151 | l = tf.reduce_sum( output * output ) 152 | r = tf.reduce_sum( target * target ) 153 | dice = 2 * (inse) / (l + r) 154 | if epsilon == 0: 155 | return dice 156 | else: 157 | return tf.clip_by_value(dice, 0, 1.0-epsilon) 158 | 159 | 160 | def dice_hard_coe(output, target, epsilon=1e-10): 161 | """Non-differentiable Sørensen–Dice coefficient for comparing the similarity of two distributions, 162 | usually be used for binary image segmentation i.e. labels are binary. 163 | The coefficient = [0, 1], 1 if totally match. 164 | 165 | Parameters 166 | ----------- 167 | output : tensor 168 | A distribution with shape: [batch_size, ....], (any dimensions). 169 | target : tensor 170 | A distribution with shape: [batch_size, ....], (any dimensions). 171 | epsilon : float 172 | An optional name to attach to this layer. 
173 | 174 | Examples 175 | --------- 176 | >>> outputs = pixel_wise_softmax(network.outputs) 177 | >>> dice_loss = 1 - dice_coe(outputs, y_, epsilon=1e-5) 178 | 179 | References 180 | ----------- 181 | - `wiki-dice `_ 182 | """ 183 | output = tf.cast(output > 0.5, dtype=tf.float32) 184 | target = tf.cast(target > 0.5, dtype=tf.float32) 185 | inse = tf.reduce_sum( output * target ) 186 | l = tf.reduce_sum( output * output ) 187 | r = tf.reduce_sum( target * target ) 188 | dice = 2 * (inse) / (l + r) 189 | if epsilon == 0: 190 | return dice 191 | else: 192 | return tf.clip_by_value(dice, 0, 1.0-epsilon) 193 | 194 | def iou_coe(output, target, threshold=0.5, epsilon=1e-10): 195 | """Non-differentiable Intersection over Union, usually be used for evaluating binary image segmentation. 196 | The coefficient = [0, 1], 1 means totally match. 197 | 198 | Parameters 199 | ----------- 200 | output : tensor 201 | A distribution with shape: [batch_size, ....], (any dimensions). 202 | target : tensor 203 | A distribution with shape: [batch_size, ....], (any dimensions). 204 | threshold : float 205 | The threshold value to be true. 206 | epsilon : float 207 | A small value to avoid zero denominator when both output and target output nothing. 208 | 209 | Examples 210 | --------- 211 | >>> outputs = tl.act.pixel_wise_softmax(network.outputs) 212 | >>> iou = tl.cost.iou_coe(outputs[:,:,:,0], y_[:,:,:,0]) 213 | 214 | Notes 215 | ------ 216 | - IOU cannot be used as training loss, people usually use dice coefficient for training, and IOU for evaluating. 
217 | """ 218 | pre = tf.cast(output > threshold, dtype=tf.float32) 219 | truth = tf.cast(target > threshold, dtype=tf.float32) 220 | intersection = tf.reduce_sum(pre * truth) 221 | union = tf.reduce_sum(tf.cast((pre + truth) > threshold, dtype=tf.float32)) 222 | return tf.reduce_sum(intersection) / (tf.reduce_sum(union) + epsilon) 223 | 224 | 225 | def cross_entropy_seq(logits, target_seqs, batch_size=None):#, batch_size=1, num_steps=None): 226 | """Returns the expression of cross-entropy of two sequences, implement 227 | softmax internally. Normally be used for Fixed Length RNN outputs. 228 | 229 | Parameters 230 | ---------- 231 | logits : Tensorflow variable 232 | 2D tensor, ``network.outputs``, [batch_size*n_steps (n_examples), number of output units] 233 | target_seqs : Tensorflow variable 234 | target : 2D tensor [batch_size, n_steps], if the number of step is dynamic, please use ``cross_entropy_seq_with_mask`` instead. 235 | batch_size : None or int. 236 | If not None, the return cost will be divided by batch_size. 
237 | 238 | Examples 239 | -------- 240 | >>> see PTB tutorial for more details 241 | >>> input_data = tf.placeholder(tf.int32, [batch_size, num_steps]) 242 | >>> targets = tf.placeholder(tf.int32, [batch_size, num_steps]) 243 | >>> cost = tl.cost.cross_entropy_seq(network.outputs, targets) 244 | """ 245 | try: # TF 1.0 246 | sequence_loss_by_example_fn = tf.contrib.legacy_seq2seq.sequence_loss_by_example 247 | except: 248 | sequence_loss_by_example_fn = tf.nn.seq2seq.sequence_loss_by_example 249 | 250 | loss = sequence_loss_by_example_fn( 251 | [logits], 252 | [tf.reshape(target_seqs, [-1])], 253 | [tf.ones_like(tf.reshape(target_seqs, [-1]), dtype=tf.float32)]) 254 | # [tf.ones([batch_size * num_steps])]) 255 | cost = tf.reduce_sum(loss) #/ batch_size 256 | if batch_size is not None: 257 | cost = cost / batch_size 258 | return cost 259 | 260 | 261 | def cross_entropy_seq_with_mask(logits, target_seqs, input_mask, return_details=False, name=None): 262 | """Returns the expression of cross-entropy of two sequences, implement 263 | softmax internally. Normally be used for Dynamic RNN outputs. 264 | 265 | Parameters 266 | ----------- 267 | logits : network identity outputs 268 | 2D tensor, ``network.outputs``, [batch_size, number of output units]. 269 | target_seqs : int of tensor, like word ID. 270 | [batch_size, ?] 271 | input_mask : the mask to compute loss 272 | The same size with target_seqs, normally 0 and 1. 273 | return_details : boolean 274 | - If False (default), only returns the loss. 275 | - If True, returns the loss, losses, weights and targets (reshape to one vetcor). 276 | 277 | Examples 278 | -------- 279 | - see Image Captioning Example. 
280 | """ 281 | targets = tf.reshape(target_seqs, [-1]) # to one vector 282 | weights = tf.to_float(tf.reshape(input_mask, [-1])) # to one vector like targets 283 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=targets, name=name) * weights 284 | #losses = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=targets, name=name)) # for TF1.0 and others 285 | 286 | try: ## TF1.0 287 | loss = tf.divide(tf.reduce_sum(losses), # loss from mask. reduce_sum before element-wise mul with mask !! 288 | tf.reduce_sum(weights), 289 | name="seq_loss_with_mask") 290 | except: ## TF0.12 291 | loss = tf.div(tf.reduce_sum(losses), # loss from mask. reduce_sum before element-wise mul with mask !! 292 | tf.reduce_sum(weights), 293 | name="seq_loss_with_mask") 294 | if return_details: 295 | return loss, losses, weights, targets 296 | else: 297 | return loss 298 | 299 | 300 | def cosine_similarity(v1, v2): 301 | """Cosine similarity [-1, 1], `wiki `_. 302 | 303 | Parameters 304 | ----------- 305 | v1, v2 : tensor of [batch_size, n_feature], with the same number of features. 306 | 307 | Returns 308 | ----------- 309 | a tensor of [batch_size, ] 310 | """ 311 | try: ## TF1.0 312 | cost = tf.reduce_sum(tf.multiply(v1, v2), 1) / (tf.sqrt(tf.reduce_sum(tf.multiply(v1, v1), 1)) * tf.sqrt(tf.reduce_sum(tf.multiply(v2, v2), 1))) 313 | except: ## TF0.12 314 | cost = tf.reduce_sum(tf.mul(v1, v2), reduction_indices=1) / (tf.sqrt(tf.reduce_sum(tf.mul(v1, v1), reduction_indices=1)) * tf.sqrt(tf.reduce_sum(tf.mul(v2, v2), reduction_indices=1))) 315 | return cost 316 | 317 | 318 | ## Regularization Functions 319 | def li_regularizer(scale, scope=None): 320 | """li regularization removes the neurons of previous layer, `i` represents `inputs`.\n 321 | Returns a function that can be used to apply group li regularization to weights.\n 322 | The implementation follows `TensorFlow contrib `_. 
323 | 324 | Parameters 325 | ---------- 326 | scale : float 327 | A scalar multiplier `Tensor`. 0.0 disables the regularizer. 328 | scope: An optional scope name for TF12+. 329 | 330 | Returns 331 | -------- 332 | A function with signature `li(weights, name=None)` that apply Li regularization. 333 | 334 | Raises 335 | ------ 336 | ValueError : if scale is outside of the range [0.0, 1.0] or if scale is not a float. 337 | """ 338 | import numbers 339 | from tensorflow.python.framework import ops 340 | from tensorflow.python.ops import standard_ops 341 | # from tensorflow.python.platform import tf_logging as logging 342 | 343 | if isinstance(scale, numbers.Integral): 344 | raise ValueError('scale cannot be an integer: %s' % scale) 345 | if isinstance(scale, numbers.Real): 346 | if scale < 0.: 347 | raise ValueError('Setting a scale less than 0 on a regularizer: %g' % 348 | scale) 349 | if scale >= 1.: 350 | raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % 351 | scale) 352 | if scale == 0.: 353 | logging.info('Scale of 0 disables regularizer.') 354 | return lambda _, name=None: None 355 | 356 | def li(weights, name=None): 357 | """Applies li regularization to weights.""" 358 | with tf.name_scope('li_regularizer') as scope: 359 | my_scale = ops.convert_to_tensor(scale, 360 | dtype=weights.dtype.base_dtype, 361 | name='scale') 362 | if tf.__version__ <= '0.12': 363 | standard_ops_fn = standard_ops.mul 364 | else: 365 | standard_ops_fn = standard_ops.multiply 366 | return standard_ops_fn( 367 | my_scale, 368 | standard_ops.reduce_sum(standard_ops.sqrt(standard_ops.reduce_sum(tf.square(weights), 1))), 369 | name=scope) 370 | return li 371 | 372 | 373 | 374 | def lo_regularizer(scale, scope=None): 375 | """lo regularization removes the neurons of current layer, `o` represents `outputs`\n 376 | Returns a function that can be used to apply group lo regularization to weights.\n 377 | The implementation follows `TensorFlow contrib `_. 
378 | 379 | Parameters 380 | ---------- 381 | scale : float 382 | A scalar multiplier `Tensor`. 0.0 disables the regularizer. 383 | scope: An optional scope name for TF12+. 384 | 385 | Returns 386 | ------- 387 | A function with signature `lo(weights, name=None)` that apply Lo regularization. 388 | 389 | Raises 390 | ------ 391 | ValueError : If scale is outside of the range [0.0, 1.0] or if scale is not a float. 392 | """ 393 | import numbers 394 | from tensorflow.python.framework import ops 395 | from tensorflow.python.ops import standard_ops 396 | # from tensorflow.python.platform import tf_logging as logging 397 | 398 | if isinstance(scale, numbers.Integral): 399 | raise ValueError('scale cannot be an integer: %s' % scale) 400 | if isinstance(scale, numbers.Real): 401 | if scale < 0.: 402 | raise ValueError('Setting a scale less than 0 on a regularizer: %g' % 403 | scale) 404 | if scale >= 1.: 405 | raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % 406 | scale) 407 | if scale == 0.: 408 | logging.info('Scale of 0 disables regularizer.') 409 | return lambda _, name=None: None 410 | 411 | def lo(weights, name='lo_regularizer'): 412 | """Applies group column regularization to weights.""" 413 | with tf.name_scope(name) as scope: 414 | my_scale = ops.convert_to_tensor(scale, 415 | dtype=weights.dtype.base_dtype, 416 | name='scale') 417 | if tf.__version__ <= '0.12': 418 | standard_ops_fn = standard_ops.mul 419 | else: 420 | standard_ops_fn = standard_ops.multiply 421 | return standard_ops_fn( 422 | my_scale, 423 | standard_ops.reduce_sum(standard_ops.sqrt(standard_ops.reduce_sum(tf.square(weights), 0))), 424 | name=scope) 425 | return lo 426 | 427 | def maxnorm_regularizer(scale=1.0, scope=None): 428 | """Max-norm regularization returns a function that can be used 429 | to apply max-norm regularization to weights. 430 | About max-norm: `wiki `_.\n 431 | The implementation follows `TensorFlow contrib `_. 
432 | 433 | Parameters 434 | ---------- 435 | scale : float 436 | A scalar multiplier `Tensor`. 0.0 disables the regularizer. 437 | scope: An optional scope name. 438 | 439 | Returns 440 | --------- 441 | A function with signature `mn(weights, name=None)` that apply Lo regularization. 442 | 443 | Raises 444 | -------- 445 | ValueError : If scale is outside of the range [0.0, 1.0] or if scale is not a float. 446 | """ 447 | import numbers 448 | from tensorflow.python.framework import ops 449 | from tensorflow.python.ops import standard_ops 450 | 451 | if isinstance(scale, numbers.Integral): 452 | raise ValueError('scale cannot be an integer: %s' % scale) 453 | if isinstance(scale, numbers.Real): 454 | if scale < 0.: 455 | raise ValueError('Setting a scale less than 0 on a regularizer: %g' % 456 | scale) 457 | # if scale >= 1.: 458 | # raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % 459 | # scale) 460 | if scale == 0.: 461 | logging.info('Scale of 0 disables regularizer.') 462 | return lambda _, name=None: None 463 | 464 | def mn(weights, name='max_regularizer'): 465 | """Applies max-norm regularization to weights.""" 466 | with tf.name_scope(name) as scope: 467 | my_scale = ops.convert_to_tensor(scale, 468 | dtype=weights.dtype.base_dtype, 469 | name='scale') 470 | if tf.__version__ <= '0.12': 471 | standard_ops_fn = standard_ops.mul 472 | else: 473 | standard_ops_fn = standard_ops.multiply 474 | return standard_ops_fn(my_scale, standard_ops.reduce_max(standard_ops.abs(weights)), name=scope) 475 | return mn 476 | 477 | def maxnorm_o_regularizer(scale, scope): 478 | """Max-norm output regularization removes the neurons of current layer.\n 479 | Returns a function that can be used to apply max-norm regularization to each column of weight matrix.\n 480 | The implementation follows `TensorFlow contrib `_. 481 | 482 | Parameters 483 | ---------- 484 | scale : float 485 | A scalar multiplier `Tensor`. 0.0 disables the regularizer. 
486 | scope: An optional scope name. 487 | 488 | Returns 489 | --------- 490 | A function with signature `mn_o(weights, name=None)` that apply Lo regularization. 491 | 492 | Raises 493 | --------- 494 | ValueError : If scale is outside of the range [0.0, 1.0] or if scale is not a float. 495 | """ 496 | import numbers 497 | from tensorflow.python.framework import ops 498 | from tensorflow.python.ops import standard_ops 499 | 500 | if isinstance(scale, numbers.Integral): 501 | raise ValueError('scale cannot be an integer: %s' % scale) 502 | if isinstance(scale, numbers.Real): 503 | if scale < 0.: 504 | raise ValueError('Setting a scale less than 0 on a regularizer: %g' % 505 | scale) 506 | # if scale >= 1.: 507 | # raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % 508 | # scale) 509 | if scale == 0.: 510 | logging.info('Scale of 0 disables regularizer.') 511 | return lambda _, name=None: None 512 | 513 | def mn_o(weights, name='maxnorm_o_regularizer'): 514 | """Applies max-norm regularization to weights.""" 515 | with tf.name_scope(name) as scope: 516 | my_scale = ops.convert_to_tensor(scale, 517 | dtype=weights.dtype.base_dtype, 518 | name='scale') 519 | if tf.__version__ <= '0.12': 520 | standard_ops_fn = standard_ops.mul 521 | else: 522 | standard_ops_fn = standard_ops.multiply 523 | return standard_ops_fn(my_scale, standard_ops.reduce_sum(standard_ops.reduce_max(standard_ops.abs(weights), 0)), name=scope) 524 | return mn_o 525 | 526 | def maxnorm_i_regularizer(scale, scope=None): 527 | """Max-norm input regularization removes the neurons of previous layer.\n 528 | Returns a function that can be used to apply max-norm regularization to each row of weight matrix.\n 529 | The implementation follows `TensorFlow contrib `_. 530 | 531 | Parameters 532 | ---------- 533 | scale : float 534 | A scalar multiplier `Tensor`. 0.0 disables the regularizer. 535 | scope: An optional scope name. 
536 | 537 | Returns 538 | --------- 539 | A function with signature `mn_i(weights, name=None)` that apply Lo regularization. 540 | 541 | Raises 542 | --------- 543 | ValueError : If scale is outside of the range [0.0, 1.0] or if scale is not a float. 544 | """ 545 | import numbers 546 | from tensorflow.python.framework import ops 547 | from tensorflow.python.ops import standard_ops 548 | 549 | if isinstance(scale, numbers.Integral): 550 | raise ValueError('scale cannot be an integer: %s' % scale) 551 | if isinstance(scale, numbers.Real): 552 | if scale < 0.: 553 | raise ValueError('Setting a scale less than 0 on a regularizer: %g' % 554 | scale) 555 | # if scale >= 1.: 556 | # raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % 557 | # scale) 558 | if scale == 0.: 559 | logging.info('Scale of 0 disables regularizer.') 560 | return lambda _, name=None: None 561 | 562 | def mn_i(weights, name='maxnorm_i_regularizer'): 563 | """Applies max-norm regularization to weights.""" 564 | with tf.name_scope(name) as scope: 565 | my_scale = ops.convert_to_tensor(scale, 566 | dtype=weights.dtype.base_dtype, 567 | name='scale') 568 | if tf.__version__ <= '0.12': 569 | standard_ops_fn = standard_ops.mul 570 | else: 571 | standard_ops_fn = standard_ops.multiply 572 | return standard_ops_fn(my_scale, standard_ops.reduce_sum(standard_ops.reduce_max(standard_ops.abs(weights), 1)), name=scope) 573 | return mn_i 574 | 575 | 576 | 577 | 578 | 579 | # 580 | -------------------------------------------------------------------------------- /tensorlayer/db.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | """ 4 | Experimental Database Management System. 
5 | 6 | Latest Version 7 | """ 8 | 9 | 10 | import tensorflow as tf 11 | import tensorlayer as tl 12 | import numpy as np 13 | import time 14 | import math 15 | 16 | 17 | import uuid 18 | 19 | import pymongo 20 | import gridfs 21 | import pickle 22 | from pymongo import MongoClient 23 | from datetime import datetime 24 | 25 | import inspect 26 | 27 | def AutoFill(func): 28 | def func_wrapper(self,*args,**kwargs): 29 | d=inspect.getcallargs(func,self,*args,**kwargs) 30 | d['args'].update({"studyID":self.studyID}) 31 | return func(**d) 32 | return func_wrapper 33 | 34 | 35 | 36 | 37 | 38 | 39 | class TensorDB(object): 40 | """TensorDB is a MongoDB based manager that help you to manage data, network topology, parameters and logging. 41 | 42 | Parameters 43 | ------------- 44 | ip : string, localhost or IP address. 45 | port : int, port number. 46 | db_name : string, database name. 47 | user_name : string, set to None if it donnot need authentication. 48 | password : string. 49 | 50 | Properties 51 | ------------ 52 | db : ``pymongo.MongoClient[db_name]``, xxxxxx 53 | datafs : ``gridfs.GridFS(self.db, collection="datafs")``, xxxxxxxxxx 54 | modelfs : ``gridfs.GridFS(self.db, collection="modelfs")``, 55 | paramsfs : ``gridfs.GridFS(self.db, collection="paramsfs")``, 56 | db.Params : Collection for 57 | db.TrainLog : Collection for 58 | db.ValidLog : Collection for 59 | db.TestLog : Collection for 60 | studyID : string, unique ID, if None random generate one. 61 | 62 | Dependencies 63 | ------------- 64 | 1 : MongoDB, as TensorDB is based on MongoDB, you need to install it in your 65 | local machine or remote machine. 66 | 2 : pip install pymongo, for MongoDB python API. 67 | 68 | Optional Tools 69 | ---------------- 70 | 1 : You may like to install MongoChef or Mongo Management Studo APP for 71 | visualizing or testing your MongoDB. 
72 | """ 73 | def __init__( 74 | self, 75 | ip = 'localhost', 76 | port = 27017, 77 | db_name = 'db_name', 78 | user_name = None, 79 | password = 'password', 80 | studyID=None 81 | ): 82 | ## connect mongodb 83 | client = MongoClient(ip, port) 84 | self.db = client[db_name] 85 | if user_name != None: 86 | self.db.authenticate(user_name, password) 87 | 88 | 89 | if studyID is None: 90 | self.studyID=str(uuid.uuid1()) 91 | else: 92 | self.studyID=studyID 93 | 94 | ## define file system (Buckets) 95 | self.datafs = gridfs.GridFS(self.db, collection="datafs") 96 | self.modelfs = gridfs.GridFS(self.db, collection="modelfs") 97 | self.paramsfs = gridfs.GridFS(self.db, collection="paramsfs") 98 | self.archfs=gridfs.GridFS(self.db,collection="ModelArchitecture") 99 | ## 100 | print("[TensorDB] Connect SUCCESS {}:{} {} {} {}".format(ip, port, db_name, user_name, studyID)) 101 | 102 | self.ip = ip 103 | self.port = port 104 | self.db_name = db_name 105 | self.user_name = user_name 106 | 107 | def __autofill(self,args): 108 | return args.update({'studyID':self.studyID}) 109 | 110 | def __serialization(self,ps): 111 | return pickle.dumps(ps, protocol=2) 112 | 113 | def __deserialization(self,ps): 114 | return pickle.loads(ps) 115 | 116 | def save_params(self, params=[], args={}):#, file_name='parameters'): 117 | """ Save parameters into MongoDB Buckets, and save the file ID into Params Collections. 118 | 119 | Parameters 120 | ---------- 121 | params : a list of parameters 122 | args : dictionary, item meta data. 123 | 124 | Returns 125 | --------- 126 | f_id : the Buckets ID of the parameters. 
127 | """ 128 | self.__autofill(args) 129 | s = time.time() 130 | f_id = self.paramsfs.put(self.__serialization(params))#, file_name=file_name) 131 | args.update({'f_id': f_id, 'time': datetime.utcnow()}) 132 | self.db.Params.insert_one(args) 133 | # print("[TensorDB] Save params: {} SUCCESS, took: {}s".format(file_name, round(time.time()-s, 2))) 134 | print("[TensorDB] Save params: SUCCESS, took: {}s".format(round(time.time()-s, 2))) 135 | return f_id 136 | 137 | @AutoFill 138 | def find_one_params(self, args={},sort=None): 139 | """ Find one parameter from MongoDB Buckets. 140 | 141 | Parameters 142 | ---------- 143 | args : dictionary, find items. 144 | 145 | Returns 146 | -------- 147 | params : the parameters, return False if nothing found. 148 | f_id : the Buckets ID of the parameters, return False if nothing found. 149 | """ 150 | 151 | s = time.time() 152 | # print(args) 153 | d = self.db.Params.find_one(filter=args,sort=sort) 154 | 155 | if d is not None: 156 | f_id = d['f_id'] 157 | else: 158 | print("[TensorDB] FAIL! Cannot find: {}".format(args)) 159 | return False, False 160 | try: 161 | params = self.__deserialization(self.paramsfs.get(f_id).read()) 162 | print("[TensorDB] Find one params SUCCESS, {} took: {}s".format(args, round(time.time()-s, 2))) 163 | return params, f_id 164 | except: 165 | return False, False 166 | 167 | @AutoFill 168 | def find_all_params(self, args={}): 169 | """ Find all parameter from MongoDB Buckets 170 | 171 | Parameters 172 | ---------- 173 | args : dictionary, find items 174 | 175 | Returns 176 | -------- 177 | params : the parameters, return False if nothing found. 
178 | 179 | """ 180 | 181 | s = time.time() 182 | pc = self.db.Params.find(args) 183 | 184 | if pc is not None: 185 | f_id_list = pc.distinct('f_id') 186 | params = [] 187 | for f_id in f_id_list: # you may have multiple Buckets files 188 | tmp = self.paramsfs.get(f_id).read() 189 | params.append(self.__deserialization(tmp)) 190 | else: 191 | print("[TensorDB] FAIL! Cannot find any: {}".format(args)) 192 | return False 193 | 194 | print("[TensorDB] Find all params SUCCESS, took: {}s".format(round(time.time()-s, 2))) 195 | return params 196 | 197 | @AutoFill 198 | def del_params(self, args={}): 199 | """ Delete params in MongoDB uckets. 200 | 201 | Parameters 202 | ----------- 203 | args : dictionary, find items to delete, leave it empty to delete all parameters. 204 | """ 205 | 206 | pc = self.db.Params.find(args) 207 | f_id_list = pc.distinct('f_id') 208 | # remove from Buckets 209 | for f in f_id_list: 210 | self.paramsfs.delete(f) 211 | # remove from Collections 212 | self.db.Params.remove(args) 213 | 214 | print("[TensorDB] Delete params SUCCESS: {}".format(args)) 215 | 216 | def _print_dict(self, args): 217 | # return " / ".join(str(key) + ": "+ str(value) for key, value in args.items()) 218 | 219 | string = '' 220 | for key, value in args.items(): 221 | if key is not '_id': 222 | string += str(key) + ": "+ str(value) + " / " 223 | return string 224 | 225 | ## =========================== LOG =================================== ## 226 | @AutoFill 227 | def train_log(self, args={}): 228 | """Save the training log. 229 | 230 | Parameters 231 | ----------- 232 | args : dictionary, items to save. 233 | 234 | Examples 235 | --------- 236 | >>> db.train_log(time=time.time(), {'loss': loss, 'acc': acc}) 237 | """ 238 | 239 | _result = self.db.TrainLog.insert_one(args) 240 | _log = self._print_dict(args) 241 | #print("[TensorDB] TrainLog: " +_log) 242 | return _result 243 | 244 | @AutoFill 245 | def del_train_log(self, args={}): 246 | """ Delete train log. 
247 | 248 | Parameters 249 | ----------- 250 | args : dictionary, find items to delete, leave it empty to delete all log. 251 | """ 252 | 253 | self.db.TrainLog.delete_many(args) 254 | print("[TensorDB] Delete TrainLog SUCCESS") 255 | 256 | @AutoFill 257 | def valid_log(self, args={}): 258 | """Save the validating log. 259 | 260 | Parameters 261 | ----------- 262 | args : dictionary, items to save. 263 | 264 | Examples 265 | --------- 266 | >>> db.valid_log(time=time.time(), {'loss': loss, 'acc': acc}) 267 | """ 268 | 269 | _result = self.db.ValidLog.insert_one(args) 270 | # _log = "".join(str(key) + ": " + str(value) for key, value in args.items()) 271 | _log = self._print_dict(args) 272 | print("[TensorDB] ValidLog: " +_log) 273 | return _result 274 | 275 | @AutoFill 276 | def del_valid_log(self, args={}): 277 | """ Delete validation log. 278 | 279 | Parameters 280 | ----------- 281 | args : dictionary, find items to delete, leave it empty to delete all log. 282 | """ 283 | self.db.ValidLog.delete_many(args) 284 | print("[TensorDB] Delete ValidLog SUCCESS") 285 | 286 | @AutoFill 287 | def test_log(self, args={}): 288 | """Save the testing log. 289 | 290 | Parameters 291 | ----------- 292 | args : dictionary, items to save. 293 | 294 | Examples 295 | --------- 296 | >>> db.test_log(time=time.time(), {'loss': loss, 'acc': acc}) 297 | """ 298 | 299 | _result = self.db.TestLog.insert_one(args) 300 | # _log = "".join(str(key) + str(value) for key, value in args.items()) 301 | _log = self._print_dict(args) 302 | print("[TensorDB] TestLog: " +_log) 303 | return _result 304 | 305 | @AutoFill 306 | def del_test_log(self, args={}): 307 | """ Delete test log. 308 | 309 | Parameters 310 | ----------- 311 | args : dictionary, find items to delete, leave it empty to delete all log. 
312 | """ 313 | 314 | self.db.TestLog.delete_many(args) 315 | print("[TensorDB] Delete TestLog SUCCESS") 316 | 317 | ## =========================== Network Architecture ================== ## 318 | @AutoFill 319 | def save_model_architecture(self,s,args={}): 320 | self.__autofill(args) 321 | fid=self.archfs.put(s,filename="modelarchitecture") 322 | args.update({"fid":fid}) 323 | self.db.march.insert_one(args) 324 | 325 | @AutoFill 326 | def load_model_architecture(self,args={}): 327 | 328 | d = self.db.march.find_one(args) 329 | if d is not None: 330 | fid = d['fid'] 331 | print(d) 332 | print(fid) 333 | # "print find" 334 | else: 335 | print("[TensorDB] FAIL! Cannot find: {}".format(args)) 336 | print ("no idtem") 337 | return False, False 338 | try: 339 | archs = self.archfs.get(fid).read() 340 | '''print("[TensorDB] Find one params SUCCESS, {} took: {}s".format(args, round(time.time()-s, 2)))''' 341 | return archs, fid 342 | except Exception as e: 343 | print("exception") 344 | print(e) 345 | return False, False 346 | 347 | @AutoFill 348 | def save_job(self, script=None, args={}): 349 | """Save the job. 350 | 351 | Parameters 352 | ----------- 353 | script : a script file name or None. 354 | args : dictionary, items to save. 355 | 356 | Examples 357 | --------- 358 | >>> # Save your job 359 | >>> db.save_job('your_script.py', {'job_id': 1, 'learning_rate': 0.01, 'n_units': 100}) 360 | >>> # Run your job 361 | >>> temp = db.find_one_job(args={'job_id': 1}) 362 | >>> print(temp['learning_rate']) 363 | ... 0.01 364 | >>> import _your_script 365 | ... 
running your script 366 | """ 367 | self.__autofill(args) 368 | if script is not None: 369 | _script = open(script, 'rb').read() 370 | args.update({'script': _script, 'script_name': script}) 371 | # _result = self.db.Job.insert_one(args) 372 | _result = self.db.Job.replace_one(args, args, upsert=True) 373 | _log = self._print_dict(args) 374 | print("[TensorDB] Save Job: script={}, args={}".format(script, args)) 375 | return _result 376 | 377 | @AutoFill 378 | def find_one_job(self, args={}): 379 | """ Find one job from MongoDB Job Collections. 380 | 381 | Parameters 382 | ---------- 383 | args : dictionary, find items. 384 | 385 | Returns 386 | -------- 387 | dictionary : contains all meta data and script. 388 | """ 389 | 390 | 391 | temp = self.db.Job.find_one(args) 392 | 393 | if temp is not None: 394 | if 'script_name' in temp.keys(): 395 | f = open('_' + temp['script_name'], 'wb') 396 | f.write(temp['script']) 397 | f.close() 398 | print("[TensorDB] Find Job: {}".format(args)) 399 | else: 400 | print("[TensorDB] FAIL! 
Cannot find any: {}".format(args)) 401 | return False 402 | 403 | return temp 404 | 405 | def push_job(self,margs, wargs,dargs,epoch): 406 | 407 | ms,mid=self.load_model_architecture(margs) 408 | weight,wid=self.find_one_params(wargs) 409 | args={"weight":wid,"model":mid,"dargs":dargs,"epoch":epoch,"time":datetime.utcnow(),"Running":False} 410 | self.__autofill(args) 411 | self.db.JOBS.insert_one(args) 412 | 413 | def peek_job(self): 414 | args={'Running':False} 415 | self.__autofill(args) 416 | m=self.db.JOBS.find_one(args) 417 | print(m) 418 | if m is None: 419 | return False 420 | 421 | s=self.paramsfs.get(m['weight']).read() 422 | w=self.__deserialization(s) 423 | 424 | ach=self.archfs.get(m['model']).read() 425 | 426 | return m['_id'], ach,w,m["dargs"],m['epoch'] 427 | 428 | def run_job(self,jid): 429 | self.db.JOBS.find_one_and_update({'_id':jid},{'$set': {'Running': True,"Since":datetime.utcnow()}}) 430 | 431 | def del_job(self,jid): 432 | self.db.JOBS.find_one_and_update({'_id':jid},{'$set': {'Running': True,"Finished":datetime.utcnow()}}) 433 | 434 | def __str__(self): 435 | _s = "[TensorDB] Info:\n" 436 | _t = _s + " " + str(self.db) 437 | return _t 438 | 439 | # def save_bulk_data(self, data=None, filename='filename'): 440 | # """ Put bulk data into TensorDB.datafs, return file ID. 441 | # When you have a very large data, you may like to save it into GridFS Buckets 442 | # instead of Collections, then when you want to load it, XXXX 443 | # 444 | # Parameters 445 | # ----------- 446 | # data : serialized data. 447 | # filename : string, GridFS Buckets. 
448 | # 449 | # References 450 | # ----------- 451 | # - MongoDB find, xxxxx 452 | # """ 453 | # s = time.time() 454 | # f_id = self.datafs.put(data, filename=filename) 455 | # print("[TensorDB] save_bulk_data: {} took: {}s".format(filename, round(time.time()-s, 2))) 456 | # return f_id 457 | # 458 | # def save_collection(self, data=None, collect_name='collect_name'): 459 | # """ Insert data into MongoDB Collections, return xx. 460 | # 461 | # Parameters 462 | # ----------- 463 | # data : serialized data. 464 | # collect_name : string, MongoDB collection name. 465 | # 466 | # References 467 | # ----------- 468 | # - MongoDB find, xxxxx 469 | # """ 470 | # s = time.time() 471 | # rl = self.db[collect_name].insert_many(data) 472 | # print("[TensorDB] save_collection: {} took: {}s".format(collect_name, round(time.time()-s, 2))) 473 | # return rl 474 | # 475 | # def find(self, args={}, collect_name='collect_name'): 476 | # """ Find data from MongoDB Collections. 477 | # 478 | # Parameters 479 | # ----------- 480 | # args : dictionary, arguments for finding. 481 | # collect_name : string, MongoDB collection name. 
482 | # 483 | # References 484 | # ----------- 485 | # - MongoDB find, xxxxx 486 | # """ 487 | # s = time.time() 488 | # 489 | # pc = self.db[collect_name].find(args) # pymongo.cursor.Cursor object 490 | # flist = pc.distinct('f_id') 491 | # fldict = {} 492 | # for f in flist: # you may have multiple Buckets files 493 | # # fldict[f] = pickle.loads(self.datafs.get(f).read()) 494 | # # s2 = time.time() 495 | # tmp = self.datafs.get(f).read() 496 | # # print(time.time()-s2) 497 | # fldict[f] = pickle.loads(tmp) 498 | # # print(time.time()-s2) 499 | # # exit() 500 | # # print(round(time.time()-s, 2)) 501 | # data = [fldict[x['f_id']][x['id']] for x in pc] 502 | # data = np.asarray(data) 503 | # print("[TensorDB] find: {} get: {} took: {}s".format(collect_name, pc.count(), round(time.time()-s, 2))) 504 | # return data 505 | 506 | 507 | 508 | class DBLogger: 509 | """ """ 510 | def __init__(self,db,model): 511 | self.db=db 512 | self.model=model 513 | 514 | def on_train_begin(self,logs={}): 515 | print("start") 516 | 517 | def on_train_end(self,logs={}): 518 | print("end") 519 | 520 | def on_epoch_begin(self,epoch,logs={}): 521 | self.epoch=epoch 522 | self.et=time.time() 523 | return 524 | 525 | def on_epoch_end(self, epoch, logs={}): 526 | self.et=time.time()-self.et 527 | print("ending") 528 | print(epoch) 529 | logs['epoch']=epoch 530 | logs['time']=datetime.utcnow() 531 | logs['stepTime']=self.et 532 | logs['acc']=np.asscalar(logs['acc']) 533 | print(logs) 534 | 535 | w=self.model.Params 536 | fid=self.db.save_params(w,logs) 537 | logs.update({'params':fid}) 538 | self.db.valid_log(logs) 539 | def on_batch_begin(self, batch,logs={}): 540 | self.t=time.time() 541 | self.losses = [] 542 | self.batch=batch 543 | 544 | def on_batch_end(self, batch, logs={}): 545 | self.t2=time.time()-self.t 546 | logs['acc']=np.asscalar(logs['acc']) 547 | #logs['loss']=np.asscalar(logs['loss']) 548 | logs['step_time']=self.t2 549 | logs['time']=datetime.utcnow() 550 | 
logs['epoch']=self.epoch 551 | logs['batch']=self.batch 552 | self.db.train_log(logs) 553 | -------------------------------------------------------------------------------- /tensorlayer/files.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | 5 | import tensorflow as tf 6 | import os 7 | import numpy as np 8 | import re 9 | import sys 10 | import tarfile 11 | import gzip 12 | import zipfile 13 | from . import visualize 14 | from . import nlp 15 | import pickle 16 | from six.moves import urllib 17 | from six.moves import cPickle 18 | from six.moves import zip 19 | from tensorflow.python.platform import gfile 20 | 21 | 22 | ## Load dataset functions 23 | def load_mnist_dataset(shape=(-1,784), path="data/mnist/"): 24 | """Automatically download MNIST dataset 25 | and return the training, validation and test set with 50000, 10000 and 10000 26 | digit images respectively. 27 | 28 | Parameters 29 | ---------- 30 | shape : tuple 31 | The shape of digit images, defaults to (-1,784) 32 | path : string 33 | Path to download data to, defaults to data/mnist/ 34 | 35 | Examples 36 | -------- 37 | >>> X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1,784)) 38 | >>> X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1)) 39 | """ 40 | # We first define functions for loading MNIST images and labels. 41 | # For convenience, they also download the requested files if needed. 42 | def load_mnist_images(path, filename): 43 | filepath = maybe_download_and_extract(filename, path, 'http://yann.lecun.com/exdb/mnist/') 44 | 45 | print(filepath) 46 | # Read the inputs in Yann LeCun's binary format. 
47 | with gzip.open(filepath, 'rb') as f: 48 | data = np.frombuffer(f.read(), np.uint8, offset=16) 49 | # The inputs are vectors now, we reshape them to monochrome 2D images, 50 | # following the shape convention: (examples, channels, rows, columns) 51 | data = data.reshape(shape) 52 | # The inputs come as bytes, we convert them to float32 in range [0,1]. 53 | # (Actually to range [0, 255/256], for compatibility to the version 54 | # provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.) 55 | return data / np.float32(256) 56 | 57 | def load_mnist_labels(path, filename): 58 | filepath = maybe_download_and_extract(filename, path, 'http://yann.lecun.com/exdb/mnist/') 59 | # Read the labels in Yann LeCun's binary format. 60 | with gzip.open(filepath, 'rb') as f: 61 | data = np.frombuffer(f.read(), np.uint8, offset=8) 62 | # The labels are vectors of integers now, that's exactly what we want. 63 | return data 64 | 65 | # Download and read the training and test set images and labels. 66 | print("Load or Download MNIST > {}".format(path)) 67 | X_train = load_mnist_images(path, 'train-images-idx3-ubyte.gz') 68 | y_train = load_mnist_labels(path, 'train-labels-idx1-ubyte.gz') 69 | X_test = load_mnist_images(path, 't10k-images-idx3-ubyte.gz') 70 | y_test = load_mnist_labels(path, 't10k-labels-idx1-ubyte.gz') 71 | 72 | # We reserve the last 10000 training examples for validation. 73 | X_train, X_val = X_train[:-10000], X_train[-10000:] 74 | y_train, y_val = y_train[:-10000], y_train[-10000:] 75 | 76 | # We just return all the arrays in order, as expected in main(). 77 | # (It doesn't matter how we do this as long as we can read them again.) 
78 | X_train = np.asarray(X_train, dtype=np.float32) 79 | y_train = np.asarray(y_train, dtype=np.int32) 80 | X_val = np.asarray(X_val, dtype=np.float32) 81 | y_val = np.asarray(y_val, dtype=np.int32) 82 | X_test = np.asarray(X_test, dtype=np.float32) 83 | y_test = np.asarray(y_test, dtype=np.int32) 84 | return X_train, y_train, X_val, y_val, X_test, y_test 85 | 86 | 87 | def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='data/cifar10/', plotable=False, second=3): 88 | """The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 89 | 6000 images per class. There are 50000 training images and 10000 test images. 90 | 91 | The dataset is divided into five training batches and one test batch, each with 92 | 10000 images. The test batch contains exactly 1000 randomly-selected images from 93 | each class. The training batches contain the remaining images in random order, 94 | but some training batches may contain more images from one class than another. 95 | Between them, the training batches contain exactly 5000 images from each class. 96 | 97 | Parameters 98 | ---------- 99 | shape : tupe 100 | The shape of digit images: e.g. (-1, 3, 32, 32) , (-1, 32, 32, 3) , (-1, 32*32*3) 101 | plotable : True, False 102 | Whether to plot some image examples. 103 | second : int 104 | If ``plotable`` is True, ``second`` is the display time. 105 | path : string 106 | Path to download data to, defaults to data/cifar10/ 107 | 108 | Examples 109 | -------- 110 | >>> X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=True) 111 | 112 | Notes 113 | ------ 114 | CIFAR-10 images can only be display without color change under uint8. 
115 | >>> X_train = np.asarray(X_train, dtype=np.uint8) 116 | >>> plt.ion() 117 | >>> fig = plt.figure(1232) 118 | >>> count = 1 119 | >>> for row in range(10): 120 | >>> for col in range(10): 121 | >>> a = fig.add_subplot(10, 10, count) 122 | >>> plt.imshow(X_train[count-1], interpolation='nearest') 123 | >>> plt.gca().xaxis.set_major_locator(plt.NullLocator()) # 不显示刻度(tick) 124 | >>> plt.gca().yaxis.set_major_locator(plt.NullLocator()) 125 | >>> count = count + 1 126 | >>> plt.draw() 127 | >>> plt.pause(3) 128 | 129 | References 130 | ---------- 131 | - `CIFAR website `_ 132 | - `Data download link `_ 133 | - `Code references `_ 134 | """ 135 | 136 | print("Load or Download cifar10 > {}".format(path)) 137 | 138 | #Helper function to unpickle the data 139 | def unpickle(file): 140 | fp = open(file, 'rb') 141 | if sys.version_info.major == 2: 142 | data = pickle.load(fp) 143 | elif sys.version_info.major == 3: 144 | data = pickle.load(fp, encoding='latin-1') 145 | fp.close() 146 | return data 147 | 148 | filename = 'cifar-10-python.tar.gz' 149 | url = 'https://www.cs.toronto.edu/~kriz/' 150 | #Download and uncompress file 151 | maybe_download_and_extract(filename, path, url, extract=True) 152 | 153 | #Unpickle file and fill in data 154 | X_train = None 155 | y_train = [] 156 | for i in range(1,6): 157 | data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "data_batch_{}".format(i))) 158 | if i == 1: 159 | X_train = data_dic['data'] 160 | else: 161 | X_train = np.vstack((X_train, data_dic['data'])) 162 | y_train += data_dic['labels'] 163 | 164 | test_data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "test_batch")) 165 | X_test = test_data_dic['data'] 166 | y_test = np.array(test_data_dic['labels']) 167 | 168 | if shape == (-1, 3, 32, 32): 169 | X_test = X_test.reshape(shape) 170 | X_train = X_train.reshape(shape) 171 | elif shape == (-1, 32, 32, 3): 172 | X_test = X_test.reshape(shape, order='F') 173 | X_train = X_train.reshape(shape, 
order='F') 174 | X_test = np.transpose(X_test, (0, 2, 1, 3)) 175 | X_train = np.transpose(X_train, (0, 2, 1, 3)) 176 | else: 177 | X_test = X_test.reshape(shape) 178 | X_train = X_train.reshape(shape) 179 | 180 | y_train = np.array(y_train) 181 | 182 | if plotable == True: 183 | print('\nCIFAR-10') 184 | import matplotlib.pyplot as plt 185 | fig = plt.figure(1) 186 | 187 | print('Shape of a training image: X_train[0]',X_train[0].shape) 188 | 189 | plt.ion() # interactive mode 190 | count = 1 191 | for row in range(10): 192 | for col in range(10): 193 | a = fig.add_subplot(10, 10, count) 194 | if shape == (-1, 3, 32, 32): 195 | # plt.imshow(X_train[count-1], interpolation='nearest') 196 | plt.imshow(np.transpose(X_train[count-1], (1, 2, 0)), interpolation='nearest') 197 | # plt.imshow(np.transpose(X_train[count-1], (2, 1, 0)), interpolation='nearest') 198 | elif shape == (-1, 32, 32, 3): 199 | plt.imshow(X_train[count-1], interpolation='nearest') 200 | # plt.imshow(np.transpose(X_train[count-1], (1, 0, 2)), interpolation='nearest') 201 | else: 202 | raise Exception("Do not support the given 'shape' to plot the image examples") 203 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) # 不显示刻度(tick) 204 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 205 | count = count + 1 206 | plt.draw() # interactive mode 207 | plt.pause(3) # interactive mode 208 | 209 | print("X_train:",X_train.shape) 210 | print("y_train:",y_train.shape) 211 | print("X_test:",X_test.shape) 212 | print("y_test:",y_test.shape) 213 | 214 | X_train = np.asarray(X_train, dtype=np.float32) 215 | X_test = np.asarray(X_test, dtype=np.float32) 216 | y_train = np.asarray(y_train, dtype=np.int32) 217 | y_test = np.asarray(y_test, dtype=np.int32) 218 | 219 | return X_train, y_train, X_test, y_test 220 | 221 | 222 | def load_ptb_dataset(path='data/ptb/'): 223 | """Penn TreeBank (PTB) dataset is used in many LANGUAGE MODELING papers, 224 | including "Empirical Evaluation and Combination of Advanced 
Language 225 | Modeling Techniques", "Recurrent Neural Network Regularization". 226 | 227 | It consists of 929k training words, 73k validation words, and 82k test 228 | words. It has 10k words in its vocabulary. 229 | 230 | In "Recurrent Neural Network Regularization", they trained regularized LSTMs 231 | of two sizes; these are denoted the medium LSTM and large LSTM. Both LSTMs 232 | have two layers and are unrolled for 35 steps. They initialize the hidden 233 | states to zero. They then use the final hidden states of the current 234 | minibatch as the initial hidden state of the subsequent minibatch 235 | (successive minibatches sequentially traverse the training set). 236 | The size of each minibatch is 20. 237 | 238 | The medium LSTM has 650 units per layer and its parameters are initialized 239 | uniformly in [−0.05, 0.05]. They apply 50% dropout on the non-recurrent 240 | connections. They train the LSTM for 39 epochs with a learning rate of 1, 241 | and after 6 epochs they decrease it by a factor of 1.2 after each epoch. 242 | They clip the norm of the gradients (normalized by minibatch size) at 5. 243 | 244 | The large LSTM has 1500 units per layer and its parameters are initialized 245 | uniformly in [−0.04, 0.04]. We apply 65% dropout on the non-recurrent 246 | connections. They train the model for 55 epochs with a learning rate of 1; 247 | after 14 epochs they start to reduce the learning rate by a factor of 1.15 248 | after each epoch. They clip the norm of the gradients (normalized by 249 | minibatch size) at 10. 
250 | 251 | Parameters 252 | ---------- 253 | path : : string 254 | Path to download data to, defaults to data/ptb/ 255 | 256 | Returns 257 | -------- 258 | train_data, valid_data, test_data, vocabulary size 259 | 260 | Examples 261 | -------- 262 | >>> train_data, valid_data, test_data, vocab_size = tl.files.load_ptb_dataset() 263 | 264 | Code References 265 | --------------- 266 | - ``tensorflow.models.rnn.ptb import reader`` 267 | 268 | Download Links 269 | --------------- 270 | - `Manual download `_ 271 | """ 272 | print("Load or Download Penn TreeBank (PTB) dataset > {}".format(path)) 273 | 274 | #Maybe dowload and uncompress tar, or load exsisting files 275 | filename = 'simple-examples.tgz' 276 | url = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/' 277 | maybe_download_and_extract(filename, path, url, extract=True) 278 | 279 | data_path = os.path.join(path, 'simple-examples', 'data') 280 | train_path = os.path.join(data_path, "ptb.train.txt") 281 | valid_path = os.path.join(data_path, "ptb.valid.txt") 282 | test_path = os.path.join(data_path, "ptb.test.txt") 283 | 284 | word_to_id = nlp.build_vocab(nlp.read_words(train_path)) 285 | 286 | train_data = nlp.words_to_word_ids(nlp.read_words(train_path), word_to_id) 287 | valid_data = nlp.words_to_word_ids(nlp.read_words(valid_path), word_to_id) 288 | test_data = nlp.words_to_word_ids(nlp.read_words(test_path), word_to_id) 289 | vocabulary = len(word_to_id) 290 | 291 | # print(nlp.read_words(train_path)) # ... 'according', 'to', 'mr.', '', ''] 292 | # print(train_data) # ... 214, 5, 23, 1, 2] 293 | # print(word_to_id) # ... 'beyond': 1295, 'anti-nuclear': 9599, 'trouble': 1520, '': 2 ... } 294 | # print(vocabulary) # 10000 295 | # exit() 296 | return train_data, valid_data, test_data, vocabulary 297 | 298 | 299 | def load_matt_mahoney_text8_dataset(path='data/mm_test8/'): 300 | """Download a text file from Matt Mahoney's website 301 | if not present, and make sure it's the right size. 
302 | Extract the first file enclosed in a zip file as a list of words. 303 | This dataset can be used for Word Embedding. 304 | 305 | Parameters 306 | ---------- 307 | path : : string 308 | Path to download data to, defaults to data/mm_test8/ 309 | 310 | Returns 311 | -------- 312 | word_list : a list 313 | a list of string (word).\n 314 | e.g. [.... 'their', 'families', 'who', 'were', 'expelled', 'from', 'jerusalem', ...] 315 | 316 | Examples 317 | -------- 318 | >>> words = tl.files.load_matt_mahoney_text8_dataset() 319 | >>> print('Data size', len(words)) 320 | """ 321 | 322 | print("Load or Download matt_mahoney_text8 Dataset> {}".format(path)) 323 | 324 | filename = 'text8.zip' 325 | url = 'http://mattmahoney.net/dc/' 326 | maybe_download_and_extract(filename, path, url, expected_bytes=31344016) 327 | 328 | with zipfile.ZipFile(os.path.join(path, filename)) as f: 329 | word_list = f.read(f.namelist()[0]).split() 330 | 331 | return word_list 332 | 333 | 334 | def load_imdb_dataset(path='data/imdb/', nb_words=None, skip_top=0, 335 | maxlen=None, test_split=0.2, seed=113, 336 | start_char=1, oov_char=2, index_from=3): 337 | """Load IMDB dataset 338 | 339 | Parameters 340 | ---------- 341 | path : : string 342 | Path to download data to, defaults to data/imdb/ 343 | 344 | Examples 345 | -------- 346 | >>> X_train, y_train, X_test, y_test = tl.files.load_imbd_dataset( 347 | ... nb_words=20000, test_split=0.2) 348 | >>> print('X_train.shape', X_train.shape) 349 | ... (20000,) [[1, 62, 74, ... 1033, 507, 27],[1, 60, 33, ... 13, 1053, 7]..] 350 | >>> print('y_train.shape', y_train.shape) 351 | ... (20000,) [1 0 0 ..., 1 0 1] 352 | 353 | References 354 | ----------- 355 | - `Modified from keras. 
`_ 356 | """ 357 | 358 | filename = "imdb.pkl" 359 | url = 'https://s3.amazonaws.com/text-datasets/' 360 | maybe_download_and_extract(filename, path, url) 361 | 362 | if filename.endswith(".gz"): 363 | f = gzip.open(os.path.join(path, filename), 'rb') 364 | else: 365 | f = open(os.path.join(path, filename), 'rb') 366 | 367 | X, labels = cPickle.load(f) 368 | f.close() 369 | 370 | np.random.seed(seed) 371 | np.random.shuffle(X) 372 | np.random.seed(seed) 373 | np.random.shuffle(labels) 374 | 375 | if start_char is not None: 376 | X = [[start_char] + [w + index_from for w in x] for x in X] 377 | elif index_from: 378 | X = [[w + index_from for w in x] for x in X] 379 | 380 | if maxlen: 381 | new_X = [] 382 | new_labels = [] 383 | for x, y in zip(X, labels): 384 | if len(x) < maxlen: 385 | new_X.append(x) 386 | new_labels.append(y) 387 | X = new_X 388 | labels = new_labels 389 | if not X: 390 | raise Exception('After filtering for sequences shorter than maxlen=' + 391 | str(maxlen) + ', no sequence was kept. ' 392 | 'Increase maxlen.') 393 | if not nb_words: 394 | nb_words = max([max(x) for x in X]) 395 | 396 | # by convention, use 2 as OOV word 397 | # reserve 'index_from' (=3 by default) characters: 0 (padding), 1 (start), 2 (OOV) 398 | if oov_char is not None: 399 | X = [[oov_char if (w >= nb_words or w < skip_top) else w for w in x] for x in X] 400 | else: 401 | nX = [] 402 | for x in X: 403 | nx = [] 404 | for w in x: 405 | if (w >= nb_words or w < skip_top): 406 | nx.append(w) 407 | nX.append(nx) 408 | X = nX 409 | 410 | X_train = np.array(X[:int(len(X) * (1 - test_split))]) 411 | y_train = np.array(labels[:int(len(X) * (1 - test_split))]) 412 | 413 | X_test = np.array(X[int(len(X) * (1 - test_split)):]) 414 | y_test = np.array(labels[int(len(X) * (1 - test_split)):]) 415 | 416 | return X_train, y_train, X_test, y_test 417 | 418 | def load_nietzsche_dataset(path='data/nietzsche/'): 419 | """Load Nietzsche dataset. 420 | Returns a string. 
421 | 422 | Parameters 423 | ---------- 424 | path : string 425 | Path to download data to, defaults to data/nietzsche/ 426 | 427 | Examples 428 | -------- 429 | >>> see tutorial_generate_text.py 430 | >>> words = tl.files.load_nietzsche_dataset() 431 | >>> words = basic_clean_str(words) 432 | >>> words = words.split() 433 | """ 434 | print("Load or Download nietzsche dataset > {}".format(path)) 435 | 436 | filename = "nietzsche.txt" 437 | url = 'https://s3.amazonaws.com/text-datasets/' 438 | filepath = maybe_download_and_extract(filename, path, url) 439 | 440 | with open(filepath, "r") as f: 441 | words = f.read() 442 | return words 443 | 444 | def load_wmt_en_fr_dataset(path='data/wmt_en_fr/'): 445 | """It will download English-to-French translation data from the WMT'15 446 | Website (10^9-French-English corpus), and the 2013 news test from 447 | the same site as development set. 448 | Returns the directories of training data and test data. 449 | 450 | Parameters 451 | ---------- 452 | path : string 453 | Path to download data to, defaults to data/wmt_en_fr/ 454 | 455 | References 456 | ---------- 457 | - Code modified from /tensorflow/models/rnn/translation/data_utils.py 458 | 459 | Notes 460 | ----- 461 | Usually, it will take a long time to download this dataset. 462 | """ 463 | # URLs for WMT data. 
464 | _WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/" 465 | _WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/" 466 | 467 | def gunzip_file(gz_path, new_path): 468 | """Unzips from gz_path into new_path.""" 469 | print("Unpacking %s to %s" % (gz_path, new_path)) 470 | with gzip.open(gz_path, "rb") as gz_file: 471 | with open(new_path, "wb") as new_file: 472 | for line in gz_file: 473 | new_file.write(line) 474 | 475 | def get_wmt_enfr_train_set(path): 476 | """Download the WMT en-fr training corpus to directory unless it's there.""" 477 | filename = "training-giga-fren.tar" 478 | maybe_download_and_extract(filename, path, _WMT_ENFR_TRAIN_URL, extract=True) 479 | train_path = os.path.join(path, "giga-fren.release2.fixed") 480 | gunzip_file(train_path + ".fr.gz", train_path + ".fr") 481 | gunzip_file(train_path + ".en.gz", train_path + ".en") 482 | return train_path 483 | 484 | def get_wmt_enfr_dev_set(path): 485 | """Download the WMT en-fr training corpus to directory unless it's there.""" 486 | filename = "dev-v2.tgz" 487 | dev_file = maybe_download_and_extract(filename, path, _WMT_ENFR_DEV_URL, extract=False) 488 | dev_name = "newstest2013" 489 | dev_path = os.path.join(path, "newstest2013") 490 | if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")): 491 | print("Extracting tgz file %s" % dev_file) 492 | with tarfile.open(dev_file, "r:gz") as dev_tar: 493 | fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr") 494 | en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en") 495 | fr_dev_file.name = dev_name + ".fr" # Extract without "dev/" prefix. 
496 | en_dev_file.name = dev_name + ".en" 497 | dev_tar.extract(fr_dev_file, path) 498 | dev_tar.extract(en_dev_file, path) 499 | return dev_path 500 | 501 | print("Load or Download WMT English-to-French translation > {}".format(path)) 502 | 503 | train_path = get_wmt_enfr_train_set(path) 504 | dev_path = get_wmt_enfr_dev_set(path) 505 | 506 | return train_path, dev_path 507 | 508 | 509 | ## Load and save network 510 | def save_npz(save_list=[], name='model.npz', sess=None): 511 | """Input parameters and the file name, save parameters into .npz file. Use tl.utils.load_npz() to restore. 512 | 513 | Parameters 514 | ---------- 515 | save_list : a list 516 | Parameters want to be saved. 517 | name : a string or None 518 | The name of the .npz file. 519 | sess : None or Session 520 | 521 | Examples 522 | -------- 523 | >>> tl.files.save_npz(network.all_params, name='model_test.npz', sess=sess) 524 | ... File saved to: model_test.npz 525 | >>> load_params = tl.files.load_npz(name='model_test.npz') 526 | ... Loading param0, (784, 800) 527 | ... Loading param1, (800,) 528 | ... Loading param2, (800, 800) 529 | ... Loading param3, (800,) 530 | ... Loading param4, (800, 10) 531 | ... 
Loading param5, (10,) 532 | >>> put parameters into a TensorLayer network, please see assign_params() 533 | 534 | Notes 535 | ----- 536 | If you got session issues, you can change the value.eval() to value.eval(session=sess) 537 | 538 | References 539 | ---------- 540 | - `Saving dictionary using numpy `_ 541 | """ 542 | ## save params into a list 543 | save_list_var = [] 544 | if sess: 545 | save_list_var = sess.run(save_list) 546 | else: 547 | try: 548 | for k, value in enumerate(save_list): 549 | save_list_var.append(value.eval()) 550 | except: 551 | print(" Fail to save model, Hint: pass the session into this function, save_npz(network.all_params, name='model.npz', sess=sess)") 552 | np.savez(name, params=save_list_var) 553 | save_list_var = None 554 | del save_list_var 555 | print("[*] %s saved" % name) 556 | 557 | ## save params into a dictionary 558 | # rename_dict = {} 559 | # for k, value in enumerate(save_dict): 560 | # rename_dict.update({'param'+str(k) : value.eval()}) 561 | # np.savez(name, **rename_dict) 562 | # print('Model is saved to: %s' % name) 563 | 564 | def save_npz_dict(save_list=[], name='model.npz', sess=None): 565 | """Input parameters and the file name, save parameters as a dictionary into .npz file. Use tl.utils.load_npz_dict() to restore. 566 | 567 | Parameters 568 | ---------- 569 | save_list : a list 570 | Parameters want to be saved. 571 | name : a string or None 572 | The name of the .npz file. 573 | sess : None or Session 574 | 575 | Notes 576 | ----- 577 | This function tries to avoid a potential broadcasting error raised by numpy. 
578 | 579 | """ 580 | ## save params into a list 581 | save_list_var = [] 582 | if sess: 583 | save_list_var = sess.run(save_list) 584 | else: 585 | try: 586 | for k, value in enumerate(save_list): 587 | save_list_var.append(value.eval()) 588 | except: 589 | print(" Fail to save model, Hint: pass the session into this function, save_npz_dict(network.all_params, name='model.npz', sess=sess)") 590 | save_var_dict = {str(idx):val for idx, val in enumerate(save_list_var)} 591 | np.savez(name, **save_var_dict) 592 | save_list_var = None 593 | save_var_dict = None 594 | del save_list_var 595 | del save_var_dict 596 | print("[*] %s saved" % name) 597 | 598 | def load_npz(path='', name='model.npz'): 599 | """Load the parameters of a Model saved by tl.files.save_npz(). 600 | 601 | Parameters 602 | ---------- 603 | path : a string 604 | Folder path to .npz file. 605 | name : a string or None 606 | The name of the .npz file. 607 | 608 | Returns 609 | -------- 610 | params : list 611 | A list of parameters in order. 612 | 613 | Examples 614 | -------- 615 | - See save_npz and assign_params 616 | 617 | References 618 | ---------- 619 | - `Saving dictionary using numpy `_ 620 | """ 621 | ## if save_npz save params into a dictionary 622 | # d = np.load( path+name ) 623 | # params = [] 624 | # print('Load Model') 625 | # for key, val in sorted( d.items() ): 626 | # params.append(val) 627 | # print('Loading %s, %s' % (key, str(val.shape))) 628 | # return params 629 | ## if save_npz save params into a list 630 | d = np.load( path+name ) 631 | # for val in sorted( d.items() ): 632 | # params = val 633 | # return params 634 | return d['params'] 635 | # print(d.items()[0][1]['params']) 636 | # exit() 637 | # return d.items()[0][1]['params'] 638 | 639 | def load_npz_dict(path='', name='model.npz'): 640 | """Load the parameters of a Model saved by tl.files.save_npz_dict(). 641 | 642 | Parameters 643 | ---------- 644 | path : a string 645 | Folder path to .npz file. 
646 | name : a string or None 647 | The name of the .npz file. 648 | 649 | Returns 650 | -------- 651 | params : list 652 | A list of parameters in order. 653 | """ 654 | d = np.load( path+name ) 655 | saved_list_var = [val[1] for val in sorted(d.items(), key=lambda tup: int(tup[0]))] 656 | return saved_list_var 657 | 658 | def assign_params(sess, params, network): 659 | """Assign the given parameters to the TensorLayer network. 660 | 661 | Parameters 662 | ---------- 663 | sess : TensorFlow Session. Automatically run when sess is not None. 664 | params : a list 665 | A list of parameters in order. 666 | network : a :class:`Layer` class 667 | The network to be assigned 668 | 669 | Returns 670 | -------- 671 | ops : list 672 | A list of tf ops in order that assign params. Support sess.run(ops) manually. 673 | 674 | Examples 675 | -------- 676 | >>> Save your network as follow: 677 | >>> tl.files.save_npz(network.all_params, name='model_test.npz') 678 | >>> network.print_params() 679 | ... 680 | ... Next time, load and assign your network as follow: 681 | >>> tl.layers.initialize_global_variables(sess) 682 | >>> load_params = tl.files.load_npz(name='model_test.npz') 683 | >>> tl.files.assign_params(sess, load_params, network) 684 | >>> network.print_params() 685 | 686 | References 687 | ---------- 688 | - `Assign value to a TensorFlow variable `_ 689 | """ 690 | ops = [] 691 | for idx, param in enumerate(params): 692 | ops.append(network.all_params[idx].assign(param)) 693 | if sess is not None: 694 | sess.run(ops) 695 | return ops 696 | 697 | def load_and_assign_npz(sess=None, name=None, network=None): 698 | """Load model from npz and assign to a network. 699 | 700 | Parameters 701 | ------------- 702 | sess : TensorFlow Session 703 | name : string 704 | Model path. 705 | network : a :class:`Layer` class 706 | The network to be assigned 707 | 708 | Returns 709 | -------- 710 | Returns False if faild to model is not exist. 
711 | 712 | Examples 713 | --------- 714 | >>> tl.files.load_and_assign_npz(sess=sess, name='net.npz', network=net) 715 | """ 716 | assert network is not None 717 | assert sess is not None 718 | if not os.path.exists(name): 719 | print("[!] Load {} failed!".format(name)) 720 | return False 721 | else: 722 | params = load_npz(name=name) 723 | assign_params(sess, params, network) 724 | print("[*] Load {} SUCCESS!".format(name)) 725 | return network 726 | 727 | # Load and save variables 728 | def save_any_to_npy(save_dict={}, name='file.npy'): 729 | """Save variables to .npy file. 730 | 731 | Examples 732 | --------- 733 | >>> tl.files.save_any_to_npy(save_dict={'data': ['a','b']}, name='test.npy') 734 | >>> data = tl.files.load_npy_to_any(name='test.npy') 735 | >>> print(data) 736 | ... {'data': ['a','b']} 737 | """ 738 | np.save(name, save_dict) 739 | 740 | def load_npy_to_any(path='', name='file.npy'): 741 | """Load .npy file. 742 | 743 | Examples 744 | --------- 745 | - see save_any_to_npy() 746 | """ 747 | file_path = os.path.join(path, name) 748 | try: 749 | npy = np.load(file_path).item() 750 | except: 751 | npy = np.load(file_path) 752 | finally: 753 | try: 754 | return npy 755 | except: 756 | print("[!] Fail to load %s" % file_path) 757 | exit() 758 | 759 | 760 | # Visualizing npz files 761 | def npz_to_W_pdf(path=None, regx='w1pre_[0-9]+\.(npz)'): 762 | """Convert the first weight matrix of .npz file to .pdf by using tl.visualize.W(). 763 | 764 | Parameters 765 | ---------- 766 | path : a string or None 767 | A folder path to npz files. 768 | regx : a string 769 | Regx for the file name. 770 | 771 | Examples 772 | -------- 773 | >>> Convert the first weight matrix of w1_pre...npz file to w1_pre...pdf. 
774 | >>> tl.files.npz_to_W_pdf(path='/Users/.../npz_file/', regx='w1pre_[0-9]+\.(npz)') 775 | """ 776 | file_list = load_file_list(path=path, regx=regx) 777 | for f in file_list: 778 | W = load_npz(path, f)[0] 779 | print("%s --> %s" % (f, f.split('.')[0]+'.pdf')) 780 | visualize.W(W, second=10, saveable=True, name=f.split('.')[0], fig_idx=2012) 781 | 782 | 783 | ## Helper functions 784 | def load_file_list(path=None, regx='\.npz', printable=True): 785 | """Return a file list in a folder by given a path and regular expression. 786 | 787 | Parameters 788 | ---------- 789 | path : a string or None 790 | A folder path. 791 | regx : a string 792 | The regx of file name. 793 | printable : boolean, whether to print the files infomation. 794 | 795 | Examples 796 | ---------- 797 | >>> file_list = tl.files.load_file_list(path=None, regx='w1pre_[0-9]+\.(npz)') 798 | """ 799 | if path == False: 800 | path = os.getcwd() 801 | file_list = os.listdir(path) 802 | return_list = [] 803 | for idx, f in enumerate(file_list): 804 | if re.search(regx, f): 805 | return_list.append(f) 806 | # return_list.sort() 807 | if printable: 808 | print('Match file list = %s' % return_list) 809 | print('Number of files = %d' % len(return_list)) 810 | return return_list 811 | 812 | def load_folder_list(path=""): 813 | """Return a folder list in a folder by given a folder path. 814 | 815 | Parameters 816 | ---------- 817 | path : a string or None 818 | A folder path. 819 | """ 820 | return [os.path.join(path,o) for o in os.listdir(path) if os.path.isdir(os.path.join(path,o))] 821 | 822 | def exists_or_mkdir(path, verbose=True): 823 | """Check a folder by given name, if not exist, create the folder and return False, 824 | if directory exists, return True. 825 | 826 | Parameters 827 | ---------- 828 | path : a string 829 | A folder path. 
830 | verbose : boolean 831 | If True, prints results, deaults is True 832 | 833 | Returns 834 | -------- 835 | True if folder exist, otherwise, returns False and create the folder 836 | 837 | Examples 838 | -------- 839 | >>> tl.files.exists_or_mkdir("checkpoints/train") 840 | """ 841 | if not os.path.exists(path): 842 | if verbose: 843 | print("[*] creates %s ..." % path) 844 | os.makedirs(path) 845 | return False 846 | else: 847 | if verbose: 848 | print("[!] %s exists ..." % path) 849 | return True 850 | 851 | def maybe_download_and_extract(filename, working_directory, url_source, extract=False, expected_bytes=None): 852 | """Checks if file exists in working_directory otherwise tries to dowload the file, 853 | and optionally also tries to extract the file if format is ".zip" or ".tar" 854 | 855 | Parameters 856 | ---------- 857 | filename : string 858 | The name of the (to be) dowloaded file. 859 | working_directory : string 860 | A folder path to search for the file in and dowload the file to 861 | url : string 862 | The URL to download the file from 863 | extract : bool, defaults to False 864 | If True, tries to uncompress the dowloaded file is ".tar.gz/.tar.bz2" or ".zip" file 865 | expected_bytes : int/None 866 | If set tries to verify that the downloaded file is of the specified size, otherwise raises an Exception, 867 | defaults to None which corresponds to no check being performed 868 | Returns 869 | ---------- 870 | filepath to dowloaded (uncompressed) file 871 | 872 | Examples 873 | -------- 874 | >>> down_file = tl.files.maybe_download_and_extract(filename = 'train-images-idx3-ubyte.gz', 875 | working_directory = 'data/', 876 | url_source = 'http://yann.lecun.com/exdb/mnist/') 877 | >>> tl.files.maybe_download_and_extract(filename = 'ADEChallengeData2016.zip', 878 | working_directory = 'data/', 879 | url_source = 'http://sceneparsing.csail.mit.edu/data/', 880 | extract=True) 881 | """ 882 | # We first define a download function, supporting both Python 
2 and 3. 883 | def _download(filename, working_directory, url_source): 884 | def _dlProgress(count, blockSize, totalSize): 885 | if(totalSize != 0): 886 | percent = float(count * blockSize) / float(totalSize) * 100.0 887 | sys.stdout.write("\r" "Downloading " + filename + "...%d%%" % percent) 888 | sys.stdout.flush() 889 | if sys.version_info[0] == 2: 890 | from urllib import urlretrieve 891 | else: 892 | from urllib.request import urlretrieve 893 | filepath = os.path.join(working_directory, filename) 894 | urlretrieve(url_source+filename, filepath, reporthook=_dlProgress) 895 | 896 | exists_or_mkdir(working_directory, verbose=False) 897 | filepath = os.path.join(working_directory, filename) 898 | 899 | if not os.path.exists(filepath): 900 | _download(filename, working_directory, url_source) 901 | print() 902 | statinfo = os.stat(filepath) 903 | print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') 904 | if(not(expected_bytes is None) and (expected_bytes != statinfo.st_size)): 905 | raise Exception('Failed to verify ' + filename + '. Can you get to it with a browser?') 906 | if(extract): 907 | if tarfile.is_tarfile(filepath): 908 | print('Trying to extract tar file') 909 | tarfile.open(filepath, 'r').extractall(working_directory) 910 | print('... Success!') 911 | elif zipfile.is_zipfile(filepath): 912 | print('Trying to extract zip file') 913 | with zipfile.ZipFile(filepath) as zf: 914 | zf.extractall(working_directory) 915 | print('... Success!') 916 | else: 917 | print("Unknown compression_format only .tar.gz/.tar.bz2/.tar and .zip supported") 918 | return filepath 919 | -------------------------------------------------------------------------------- /tensorlayer/iterate.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | 5 | 6 | import numpy as np 7 | from six.moves import xrange 8 | 9 | def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False): 10 | """Generate a generator that input a group of example in numpy.array and 11 | their labels, return the examples and labels by the given batchsize. 12 | 13 | Parameters 14 | ---------- 15 | inputs : numpy.array 16 | (X) The input features, every row is a example. 17 | targets : numpy.array 18 | (y) The labels of inputs, every row is a example. 19 | batch_size : int 20 | The batch size. 21 | shuffle : boolean 22 | Indicating whether to use a shuffling queue, shuffle the dataset before return. 23 | 24 | Hints 25 | ------- 26 | - If you have two inputs, e.g. X1 (1000, 100) and X2 (1000, 80), you can ``np.hstack((X1, X2)) 27 | into (1000, 180) and feed into ``inputs``, then you can split a batch of X1 and X2. 28 | 29 | Examples 30 | -------- 31 | >>> X = np.asarray([['a','a'], ['b','b'], ['c','c'], ['d','d'], ['e','e'], ['f','f']]) 32 | >>> y = np.asarray([0,1,2,3,4,5]) 33 | >>> for batch in tl.iterate.minibatches(inputs=X, targets=y, batch_size=2, shuffle=False): 34 | >>> print(batch) 35 | ... (array([['a', 'a'], 36 | ... ['b', 'b']], 37 | ... dtype='>> X = np.asarray([['a','a'], ['b','b'], ['c','c'], ['d','d'], ['e','e'], ['f','f']]) 64 | >>> y = np.asarray([0, 1, 2, 3, 4, 5]) 65 | >>> for batch in tl.iterate.seq_minibatches(inputs=X, targets=y, batch_size=2, seq_length=2, stride=1): 66 | >>> print(batch) 67 | ... (array([['a', 'a'], 68 | ... ['b', 'b'], 69 | ... ['b', 'b'], 70 | ... ['c', 'c']], 71 | ... 
dtype='>> return_last = True 82 | >>> num_steps = 2 83 | >>> X = np.asarray([['a','a'], ['b','b'], ['c','c'], ['d','d'], ['e','e'], ['f','f']]) 84 | >>> Y = np.asarray([0,1,2,3,4,5]) 85 | >>> for batch in tl.iterate.seq_minibatches(inputs=X, targets=Y, batch_size=2, seq_length=num_steps, stride=1): 86 | >>> x, y = batch 87 | >>> if return_last: 88 | >>> tmp_y = y.reshape((-1, num_steps) + y.shape[1:]) 89 | >>> y = tmp_y[:, -1] 90 | >>> print(x, y) 91 | ... [['a' 'a'] 92 | ... ['b' 'b'] 93 | ... ['b' 'b'] 94 | ... ['c' 'c']] [1 2] 95 | ... [['c' 'c'] 96 | ... ['d' 'd'] 97 | ... ['d' 'd'] 98 | ... ['e' 'e']] [3 4] 99 | """ 100 | assert len(inputs) == len(targets) 101 | n_loads = (batch_size * stride) + (seq_length - stride) 102 | for start_idx in range(0, len(inputs) - n_loads + 1, (batch_size * stride)): 103 | seq_inputs = np.zeros((batch_size, seq_length) + inputs.shape[1:], 104 | dtype=inputs.dtype) 105 | seq_targets = np.zeros((batch_size, seq_length) + targets.shape[1:], 106 | dtype=targets.dtype) 107 | for b_idx in xrange(batch_size): 108 | start_seq_idx = start_idx + (b_idx * stride) 109 | end_seq_idx = start_seq_idx + seq_length 110 | seq_inputs[b_idx] = inputs[start_seq_idx:end_seq_idx] 111 | seq_targets[b_idx] = targets[start_seq_idx:end_seq_idx] 112 | flatten_inputs = seq_inputs.reshape((-1,) + inputs.shape[1:]) 113 | flatten_targets = seq_targets.reshape((-1,) + targets.shape[1:]) 114 | yield flatten_inputs, flatten_targets 115 | 116 | def seq_minibatches2(inputs, targets, batch_size, num_steps): 117 | """Generate a generator that iterates on two list of words. Yields (Returns) the source contexts and 118 | the target context by the given batch_size and num_steps (sequence_length), 119 | see ``PTB tutorial``. In TensorFlow's tutorial, this generates the batch_size pointers into the raw 120 | PTB data, and allows minibatch iteration along these pointers. 121 | 122 | - Hint, if the input data are images, you can modify the code as follow. 123 | 124 | .. 
code-block:: python 125 | 126 | from 127 | data = np.zeros([batch_size, batch_len) 128 | to 129 | data = np.zeros([batch_size, batch_len, inputs.shape[1], inputs.shape[2], inputs.shape[3]]) 130 | 131 | Parameters 132 | ---------- 133 | inputs : a list 134 | the context in list format; note that context usually be 135 | represented by splitting by space, and then convert to unique 136 | word IDs. 137 | targets : a list 138 | the context in list format; note that context usually be 139 | represented by splitting by space, and then convert to unique 140 | word IDs. 141 | batch_size : int 142 | the batch size. 143 | num_steps : int 144 | the number of unrolls. i.e. sequence_length 145 | 146 | Yields 147 | ------ 148 | Pairs of the batched data, each a matrix of shape [batch_size, num_steps]. 149 | 150 | Raises 151 | ------ 152 | ValueError : if batch_size or num_steps are too high. 153 | 154 | Examples 155 | -------- 156 | >>> X = [i for i in range(20)] 157 | >>> Y = [i for i in range(20,40)] 158 | >>> for batch in tl.iterate.seq_minibatches2(X, Y, batch_size=2, num_steps=3): 159 | ... x, y = batch 160 | ... print(x, y) 161 | ... 162 | ... [[ 0. 1. 2.] 163 | ... [ 10. 11. 12.]] 164 | ... [[ 20. 21. 22.] 165 | ... [ 30. 31. 32.]] 166 | ... 167 | ... [[ 3. 4. 5.] 168 | ... [ 13. 14. 15.]] 169 | ... [[ 23. 24. 25.] 170 | ... [ 33. 34. 35.]] 171 | ... 172 | ... [[ 6. 7. 8.] 173 | ... [ 16. 17. 18.]] 174 | ... [[ 26. 27. 28.] 175 | ... [ 36. 37. 
38.]] 176 | 177 | Code References 178 | --------------- 179 | - ``tensorflow/models/rnn/ptb/reader.py`` 180 | """ 181 | assert len(inputs) == len(targets) 182 | data_len = len(inputs) 183 | batch_len = data_len // batch_size 184 | # data = np.zeros([batch_size, batch_len]) 185 | data = np.zeros((batch_size, batch_len) + inputs.shape[1:], 186 | dtype=inputs.dtype) 187 | data2 = np.zeros([batch_size, batch_len]) 188 | 189 | for i in range(batch_size): 190 | data[i] = inputs[batch_len * i:batch_len * (i + 1)] 191 | data2[i] = targets[batch_len * i:batch_len * (i + 1)] 192 | 193 | epoch_size = (batch_len - 1) // num_steps 194 | 195 | if epoch_size == 0: 196 | raise ValueError("epoch_size == 0, decrease batch_size or num_steps") 197 | 198 | for i in range(epoch_size): 199 | x = data[:, i*num_steps:(i+1)*num_steps] 200 | x2 = data2[:, i*num_steps:(i+1)*num_steps] 201 | yield (x, x2) 202 | 203 | 204 | def ptb_iterator(raw_data, batch_size, num_steps): 205 | """ 206 | Generate a generator that iterates on a list of words, see PTB tutorial. Yields (Returns) the source contexts and 207 | the target context by the given batch_size and num_steps (sequence_length).\n 208 | see ``PTB tutorial``. 209 | 210 | e.g. x = [0, 1, 2] y = [1, 2, 3] , when batch_size = 1, num_steps = 3, 211 | raw_data = [i for i in range(100)] 212 | 213 | In TensorFlow's tutorial, this generates batch_size pointers into the raw 214 | PTB data, and allows minibatch iteration along these pointers. 215 | 216 | Parameters 217 | ---------- 218 | raw_data : a list 219 | the context in list format; note that context usually be 220 | represented by splitting by space, and then convert to unique 221 | word IDs. 222 | batch_size : int 223 | the batch size. 224 | num_steps : int 225 | the number of unrolls. i.e. sequence_length 226 | 227 | Yields 228 | ------ 229 | Pairs of the batched data, each a matrix of shape [batch_size, num_steps]. 
230 | The second element of the tuple is the same data time-shifted to the 231 | right by one. 232 | 233 | Raises 234 | ------ 235 | ValueError : if batch_size or num_steps are too high. 236 | 237 | Examples 238 | -------- 239 | >>> train_data = [i for i in range(20)] 240 | >>> for batch in tl.iterate.ptb_iterator(train_data, batch_size=2, num_steps=3): 241 | >>> x, y = batch 242 | >>> print(x, y) 243 | ... [[ 0 1 2] <---x 1st subset/ iteration 244 | ... [10 11 12]] 245 | ... [[ 1 2 3] <---y 246 | ... [11 12 13]] 247 | ... 248 | ... [[ 3 4 5] <--- 1st batch input 2nd subset/ iteration 249 | ... [13 14 15]] <--- 2nd batch input 250 | ... [[ 4 5 6] <--- 1st batch target 251 | ... [14 15 16]] <--- 2nd batch target 252 | ... 253 | ... [[ 6 7 8] 3rd subset/ iteration 254 | ... [16 17 18]] 255 | ... [[ 7 8 9] 256 | ... [17 18 19]] 257 | 258 | Code References 259 | ---------------- 260 | - ``tensorflow/models/rnn/ptb/reader.py`` 261 | """ 262 | raw_data = np.array(raw_data, dtype=np.int32) 263 | 264 | data_len = len(raw_data) 265 | batch_len = data_len // batch_size 266 | data = np.zeros([batch_size, batch_len], dtype=np.int32) 267 | for i in range(batch_size): 268 | data[i] = raw_data[batch_len * i:batch_len * (i + 1)] 269 | 270 | epoch_size = (batch_len - 1) // num_steps 271 | 272 | if epoch_size == 0: 273 | raise ValueError("epoch_size == 0, decrease batch_size or num_steps") 274 | 275 | for i in range(epoch_size): 276 | x = data[:, i*num_steps:(i+1)*num_steps] 277 | y = data[:, i*num_steps+1:(i+1)*num_steps+1] 278 | yield (x, y) 279 | 280 | 281 | 282 | # def minibatches_for_sequence2D(inputs, targets, batch_size, sequence_length, stride=1): 283 | # """ 284 | # Input a group of example in 2D numpy.array and their labels. 285 | # Return the examples and labels by the given batchsize, sequence_length. 286 | # Use for RNN. 287 | # 288 | # Parameters 289 | # ---------- 290 | # inputs : numpy.array 291 | # (X) The input features, every row is a example. 
292 | # targets : numpy.array 293 | # (y) The labels of inputs, every row is a example. 294 | # batchsize : int 295 | # The batch size must be a multiple of sequence_length: int(batch_size % sequence_length) == 0 296 | # sequence_length : int 297 | # The sequence length 298 | # stride : int 299 | # The stride step 300 | # 301 | # Examples 302 | # -------- 303 | # >>> sequence_length = 2 304 | # >>> batch_size = 4 305 | # >>> stride = 1 306 | # >>> X_train = np.asarray([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[13,14,15],[16,17,18],[19,20,21],[22,23,24]]) 307 | # >>> y_train = np.asarray(['0','1','2','3','4','5','6','7']) 308 | # >>> print('X_train = %s' % X_train) 309 | # >>> print('y_train = %s' % y_train) 310 | # >>> for batch in minibatches_for_sequence2D(X_train, y_train, batch_size=batch_size, sequence_length=sequence_length, stride=stride): 311 | # >>> inputs, targets = batch 312 | # >>> print(inputs) 313 | # >>> print(targets) 314 | # ... [[ 1. 2. 3.] 315 | # ... [ 4. 5. 6.] 316 | # ... [ 4. 5. 6.] 317 | # ... [ 7. 8. 9.]] 318 | # ... [1 2] 319 | # ... [[ 4. 5. 6.] 320 | # ... [ 7. 8. 9.] 321 | # ... [ 7. 8. 9.] 322 | # ... [ 10. 11. 12.]] 323 | # ... [2 3] 324 | # ... ... 325 | # ... [[ 16. 17. 18.] 326 | # ... [ 19. 20. 21.] 327 | # ... [ 19. 20. 21.] 328 | # ... [ 22. 23. 24.]] 329 | # ... 
[6 7] 330 | # """ 331 | # print('len(targets)=%d batch_size=%d sequence_length=%d stride=%d' % (len(targets), batch_size, sequence_length, stride)) 332 | # assert len(inputs) == len(targets), '1 feature vector have 1 target vector/value' #* sequence_length 333 | # # assert int(batch_size % sequence_length) == 0, 'batch_size % sequence_length must == 0\ 334 | # # batch_size is number of examples rather than number of targets' 335 | # 336 | # # print(inputs.shape, len(inputs), len(inputs[0])) 337 | # 338 | # n_targets = int(batch_size/sequence_length) 339 | # # n_targets = int(np.ceil(batch_size/sequence_length)) 340 | # X = np.empty(shape=(0,len(inputs[0])), dtype=np.float32) 341 | # y = np.zeros(shape=(1, n_targets), dtype=np.int32) 342 | # 343 | # for idx in range(sequence_length, len(inputs), stride): # go through all example during 1 epoch 344 | # for n in range(n_targets): # for num of target 345 | # X = np.concatenate((X, inputs[idx-sequence_length+n:idx+n])) 346 | # y[0][n] = targets[idx-1+n] 347 | # # y = np.vstack((y, targets[idx-1+n])) 348 | # yield X, y[0] 349 | # X = np.empty(shape=(0,len(inputs[0]))) 350 | # # y = np.empty(shape=(1,0)) 351 | # 352 | # 353 | # def minibatches_for_sequence4D(inputs, targets, batch_size, sequence_length, stride=1): # 354 | # """ 355 | # Input a group of example in 4D numpy.array and their labels. 356 | # Return the examples and labels by the given batchsize, sequence_length. 357 | # Use for RNN. 358 | # 359 | # Parameters 360 | # ---------- 361 | # inputs : numpy.array 362 | # (X) The input features, every row is a example. 363 | # targets : numpy.array 364 | # (y) The labels of inputs, every row is a example. 
365 | # batchsize : int 366 | # The batch size must be a multiple of sequence_length: int(batch_size % sequence_length) == 0 367 | # sequence_length : int 368 | # The sequence length 369 | # stride : int 370 | # The stride step 371 | # 372 | # Examples 373 | # -------- 374 | # >>> sequence_length = 2 375 | # >>> batch_size = 2 376 | # >>> stride = 1 377 | # >>> X_train = np.asarray([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[13,14,15],[16,17,18],[19,20,21],[22,23,24]]) 378 | # >>> y_train = np.asarray(['0','1','2','3','4','5','6','7']) 379 | # >>> X_train = np.expand_dims(X_train, axis=1) 380 | # >>> X_train = np.expand_dims(X_train, axis=3) 381 | # >>> for batch in minibatches_for_sequence4D(X_train, y_train, batch_size=batch_size, sequence_length=sequence_length, stride=stride): 382 | # >>> inputs, targets = batch 383 | # >>> print(inputs) 384 | # >>> print(targets) 385 | # ... [[[[ 1.] 386 | # ... [ 2.] 387 | # ... [ 3.]]] 388 | # ... [[[ 4.] 389 | # ... [ 5.] 390 | # ... [ 6.]]]] 391 | # ... [1] 392 | # ... [[[[ 4.] 393 | # ... [ 5.] 394 | # ... [ 6.]]] 395 | # ... [[[ 7.] 396 | # ... [ 8.] 397 | # ... [ 9.]]]] 398 | # ... [2] 399 | # ... ... 400 | # ... [[[[ 19.] 401 | # ... [ 20.] 402 | # ... [ 21.]]] 403 | # ... [[[ 22.] 404 | # ... [ 23.] 405 | # ... [ 24.]]]] 406 | # ... 
[7] 407 | # """ 408 | # print('len(targets)=%d batch_size=%d sequence_length=%d stride=%d' % (len(targets), batch_size, sequence_length, stride)) 409 | # assert len(inputs) == len(targets), '1 feature vector have 1 target vector/value' #* sequence_length 410 | # # assert int(batch_size % sequence_length) == 0, 'in LSTM, batch_size % sequence_length must == 0\ 411 | # # batch_size is number of X_train rather than number of targets' 412 | # assert stride >= 1, 'stride must be >=1, at least move 1 step for each iternation' 413 | # 414 | # n_example, n_channels, width, height = inputs.shape 415 | # print('n_example=%d n_channels=%d width=%d height=%d' % (n_example, n_channels, width, height)) 416 | # 417 | # n_targets = int(np.ceil(batch_size/sequence_length)) # 实际为 batchsize/sequence_length + 1 418 | # print(n_targets) 419 | # X = np.zeros(shape=(batch_size, n_channels, width, height), dtype=np.float32) 420 | # # X = np.zeros(shape=(n_targets, sequence_length, n_channels, width, height), dtype=np.float32) 421 | # y = np.zeros(shape=(1,n_targets), dtype=np.int32) 422 | # # y = np.empty(shape=(0,1), dtype=np.float32) 423 | # # time.sleep(2) 424 | # for idx in range(sequence_length, n_example-n_targets+2, stride): # go through all example during 1 epoch 425 | # for n in range(n_targets): # for num of target 426 | # # print(idx+n, inputs[idx-sequence_length+n : idx+n].shape) 427 | # X[n*sequence_length : (n+1)*sequence_length] = inputs[idx+n-sequence_length : idx+n] 428 | # # X[n] = inputs[idx-sequence_length+n:idx+n] 429 | # y[0][n] = targets[idx+n-1] 430 | # # y = np.vstack((y, targets[idx-1+n])) 431 | # # y = targets[idx: idx+n_targets] 432 | # yield X, y[0] 433 | -------------------------------------------------------------------------------- /tensorlayer/nlp.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | 5 | 6 | 7 | import tensorflow as tf 8 | import os 9 | from sys import platform as _platform 10 | import collections 11 | import random 12 | import numpy as np 13 | import warnings 14 | from six.moves import xrange 15 | from tensorflow.python.platform import gfile 16 | import re 17 | 18 | ## Iteration functions 19 | def generate_skip_gram_batch(data, batch_size, num_skips, skip_window, data_index=0): 20 | """Generate a training batch for the Skip-Gram model. 21 | 22 | Parameters 23 | ---------- 24 | data : a list 25 | To present context. 26 | batch_size : an int 27 | Batch size to return. 28 | num_skips : an int 29 | How many times to reuse an input to generate a label. 30 | skip_window : an int 31 | How many words to consider left and right. 32 | data_index : an int 33 | Index of the context location. 34 | without using yield, this code use data_index to instead. 35 | 36 | Returns 37 | -------- 38 | batch : a list 39 | Inputs 40 | labels : a list 41 | Labels 42 | data_index : an int 43 | Index of the context location. 44 | 45 | Examples 46 | -------- 47 | >>> Setting num_skips=2, skip_window=1, use the right and left words. 48 | >>> In the same way, num_skips=4, skip_window=2 means use the nearby 4 words. 49 | 50 | >>> data = [1,2,3,4,5,6,7,8,9,10,11] 51 | >>> batch, labels, data_index = tl.nlp.generate_skip_gram_batch(data=data, batch_size=8, num_skips=2, skip_window=1, data_index=0) 52 | >>> print(batch) 53 | ... [2 2 3 3 4 4 5 5] 54 | >>> print(labels) 55 | ... [[3] 56 | ... [1] 57 | ... [4] 58 | ... [2] 59 | ... [5] 60 | ... [3] 61 | ... [4] 62 | ... [6]] 63 | 64 | References 65 | ----------- 66 | - `TensorFlow word2vec tutorial `_ 67 | """ 68 | # global data_index # you can put data_index outside the function, then 69 | # modify the global data_index in the function without return it. 70 | # note: without using yield, this code use data_index to instead. 
71 | assert batch_size % num_skips == 0 72 | assert num_skips <= 2 * skip_window 73 | batch = np.ndarray(shape=(batch_size), dtype=np.int32) 74 | labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) 75 | span = 2 * skip_window + 1 # [ skip_window target skip_window ] 76 | buffer = collections.deque(maxlen=span) 77 | for _ in range(span): 78 | buffer.append(data[data_index]) 79 | data_index = (data_index + 1) % len(data) 80 | for i in range(batch_size // num_skips): 81 | target = skip_window # target label at the center of the buffer 82 | targets_to_avoid = [ skip_window ] 83 | for j in range(num_skips): 84 | while target in targets_to_avoid: 85 | target = random.randint(0, span - 1) 86 | targets_to_avoid.append(target) 87 | batch[i * num_skips + j] = buffer[skip_window] 88 | labels[i * num_skips + j, 0] = buffer[target] 89 | buffer.append(data[data_index]) 90 | data_index = (data_index + 1) % len(data) 91 | return batch, labels, data_index 92 | 93 | 94 | ## Sampling functions 95 | def sample(a=[], temperature=1.0): 96 | """Sample an index from a probability array. 97 | 98 | Parameters 99 | ---------- 100 | a : a list 101 | List of probabilities. 102 | temperature : float or None 103 | The higher the more uniform.\n 104 | When a = [0.1, 0.2, 0.7],\n 105 | temperature = 0.7, the distribution will be sharpen [ 0.05048273 0.13588945 0.81362782]\n 106 | temperature = 1.0, the distribution will be the same [0.1 0.2 0.7]\n 107 | temperature = 1.5, the distribution will be filtered [ 0.16008435 0.25411807 0.58579758]\n 108 | If None, it will be ``np.argmax(a)`` 109 | 110 | Notes 111 | ------ 112 | No matter what is the temperature and input list, the sum of all probabilities will be one. 113 | Even if input list = [1, 100, 200], the sum of all probabilities will still be one. 114 | 115 | For large vocabulary_size, choice a higher temperature to avoid error. 
116 | """ 117 | b = np.copy(a) 118 | try: 119 | if temperature == 1: 120 | return np.argmax(np.random.multinomial(1, a, 1)) 121 | if temperature is None: 122 | return np.argmax(a) 123 | else: 124 | a = np.log(a) / temperature 125 | a = np.exp(a) / np.sum(np.exp(a)) 126 | return np.argmax(np.random.multinomial(1, a, 1)) 127 | except: 128 | # np.set_printoptions(threshold=np.nan) 129 | # print(a) 130 | # print(np.sum(a)) 131 | # print(np.max(a)) 132 | # print(np.min(a)) 133 | # exit() 134 | message = "For large vocabulary_size, choice a higher temperature\ 135 | to avoid log error. Hint : use ``sample_top``. " 136 | warnings.warn(message, Warning) 137 | # print(a) 138 | # print(b) 139 | return np.argmax(np.random.multinomial(1, b, 1)) 140 | 141 | def sample_top(a=[], top_k=10): 142 | """Sample from ``top_k`` probabilities. 143 | 144 | Parameters 145 | ---------- 146 | a : a list 147 | List of probabilities. 148 | top_k : int 149 | Number of candidates to be considered. 150 | """ 151 | idx = np.argpartition(a, -top_k)[-top_k:] 152 | probs = a[idx] 153 | # print("new", probs) 154 | probs = probs / np.sum(probs) 155 | choice = np.random.choice(idx, p=probs) 156 | return choice 157 | ## old implementation 158 | # a = np.array(a) 159 | # idx = np.argsort(a)[::-1] 160 | # idx = idx[:top_k] 161 | # # a = a[idx] 162 | # probs = a[idx] 163 | # print("prev", probs) 164 | # # probs = probs / np.sum(probs) 165 | # # choice = np.random.choice(idx, p=probs) 166 | # # return choice 167 | 168 | 169 | ## Vector representations of words (Advanced) UNDOCUMENT 170 | class SimpleVocabulary(object): 171 | """Simple vocabulary wrapper, see create_vocab(). 172 | 173 | Parameters 174 | ------------ 175 | vocab : A dictionary of word to word_id. 176 | unk_id : Id of the special 'unknown' word. 
177 | """ 178 | 179 | def __init__(self, vocab, unk_id): 180 | """Initializes the vocabulary.""" 181 | 182 | 183 | self._vocab = vocab 184 | self._unk_id = unk_id 185 | 186 | def word_to_id(self, word): 187 | """Returns the integer id of a word string.""" 188 | if word in self._vocab: 189 | return self._vocab[word] 190 | else: 191 | return self._unk_id 192 | 193 | class Vocabulary(object): 194 | """Create Vocabulary class from a given vocabulary and its id-word, word-id convert, 195 | see create_vocab() and ``tutorial_tfrecord3.py``. 196 | 197 | Parameters 198 | ----------- 199 | vocab_file : File containing the vocabulary, where the words are the first 200 | whitespace-separated token on each line (other tokens are ignored) and 201 | the word ids are the corresponding line numbers. 202 | start_word : Special word denoting sentence start. 203 | end_word : Special word denoting sentence end. 204 | unk_word : Special word denoting unknown words. 205 | 206 | Properties 207 | ------------ 208 | vocab : a dictionary from word to id. 209 | reverse_vocab : a list from id to word. 210 | start_id : int of start id 211 | end_id : int of end id 212 | unk_id : int of unk id 213 | pad_id : int of padding id 214 | 215 | Vocab_files 216 | ------------- 217 | >>> Look as follow, includes `start_word` , `end_word` but no `unk_word` . 218 | >>> a 969108 219 | >>> 586368 220 | >>> 586368 221 | >>> . 
440479 222 | >>> on 213612 223 | >>> of 202290 224 | >>> the 196219 225 | >>> in 182598 226 | >>> with 152984 227 | >>> and 139109 228 | >>> is 97322 229 | """ 230 | 231 | def __init__(self, 232 | vocab_file, 233 | start_word="", 234 | end_word="", 235 | unk_word="", 236 | pad_word=""): 237 | if not tf.gfile.Exists(vocab_file): 238 | tf.logging.fatal("Vocab file %s not found.", vocab_file) 239 | tf.logging.info("Initializing vocabulary from file: %s", vocab_file) 240 | 241 | with tf.gfile.GFile(vocab_file, mode="r") as f: 242 | reverse_vocab = list(f.readlines()) 243 | reverse_vocab = [line.split()[0] for line in reverse_vocab] 244 | assert start_word in reverse_vocab 245 | assert end_word in reverse_vocab 246 | if unk_word not in reverse_vocab: 247 | reverse_vocab.append(unk_word) 248 | vocab = dict([(x, y) for (y, x) in enumerate(reverse_vocab)]) 249 | 250 | print(" [TL] Vocabulary from %s : %s %s %s" % (vocab_file, start_word, end_word, unk_word)) 251 | print(" vocabulary with %d words (includes start_word, end_word, unk_word)" % len(vocab)) 252 | # tf.logging.info(" vocabulary with %d words" % len(vocab)) 253 | 254 | self.vocab = vocab # vocab[word] = id 255 | self.reverse_vocab = reverse_vocab # reverse_vocab[id] = word 256 | 257 | # Save special word ids. 
258 | self.start_id = vocab[start_word] 259 | self.end_id = vocab[end_word] 260 | self.unk_id = vocab[unk_word] 261 | self.pad_id = vocab[pad_word] 262 | print(" start_id: %d" % self.start_id) 263 | print(" end_id: %d" % self.end_id) 264 | print(" unk_id: %d" % self.unk_id) 265 | print(" pad_id: %d" % self.pad_id) 266 | 267 | def word_to_id(self, word): 268 | """Returns the integer word id of a word string.""" 269 | if word in self.vocab: 270 | return self.vocab[word] 271 | else: 272 | return self.unk_id 273 | 274 | def id_to_word(self, word_id): 275 | """Returns the word string of an integer word id.""" 276 | if word_id >= len(self.reverse_vocab): 277 | return self.reverse_vocab[self.unk_id] 278 | else: 279 | return self.reverse_vocab[word_id] 280 | 281 | def process_sentence(sentence, start_word="", end_word=""): 282 | """Converts a sentence string into a list of string words, add start_word and end_word, 283 | see ``create_vocab()`` and ``tutorial_tfrecord3.py``. 284 | 285 | Parameter 286 | --------- 287 | sentence : a sentence in string. 288 | start_word : a string or None, if None, non start word will be appended. 289 | end_word : a string or None, if None, non end word will be appended. 290 | 291 | Returns 292 | --------- 293 | A list of strings; the processed caption. 294 | 295 | Examples 296 | ----------- 297 | >>> c = "how are you?" 298 | >>> c = tl.nlp.process_sentence(c) 299 | >>> print(c) 300 | ... ['', 'how', 'are', 'you', '?', ''] 301 | 302 | Notes 303 | ------- 304 | - You have to install the following package. 
305 | - `Installing NLTK `_ 306 | - `Installing NLTK data `_ 307 | """ 308 | try: 309 | import nltk 310 | except: 311 | raise Exception("Hint : NLTK is required.") 312 | if start_word is not None: 313 | process_sentence = [start_word] 314 | else: 315 | process_sentence = [] 316 | process_sentence.extend(nltk.tokenize.word_tokenize(sentence.lower())) 317 | if end_word is not None: 318 | process_sentence.append(end_word) 319 | return process_sentence 320 | 321 | def create_vocab(sentences, word_counts_output_file, min_word_count=1): 322 | """Creates the vocabulary of word to word_id, see create_vocab() and ``tutorial_tfrecord3.py``. 323 | 324 | The vocabulary is saved to disk in a text file of word counts. The id of each 325 | word in the file is its corresponding 0-based line number. 326 | 327 | Parameters 328 | ------------ 329 | sentences : a list of lists of strings. 330 | word_counts_output_file : A string 331 | The file name. 332 | min_word_count : a int 333 | Minimum number of occurrences for a word. 334 | 335 | Returns 336 | -------- 337 | - tl.nlp.SimpleVocabulary object. 338 | 339 | Mores 340 | ----- 341 | - ``tl.nlp.build_vocab()`` 342 | 343 | Examples 344 | -------- 345 | >>> captions = ["one two , three", "four five five"] 346 | >>> processed_capts = [] 347 | >>> for c in captions: 348 | >>> c = tl.nlp.process_sentence(c, start_word="", end_word="") 349 | >>> processed_capts.append(c) 350 | >>> print(processed_capts) 351 | ...[['', 'one', 'two', ',', 'three', ''], ['', 'four', 'five', 'five', '']] 352 | 353 | >>> tl.nlp.create_vocab(processed_capts, word_counts_output_file='vocab.txt', min_word_count=1) 354 | ... [TL] Creating vocabulary. 355 | ... Total words: 8 356 | ... Words in vocabulary: 8 357 | ... Wrote vocabulary file: vocab.txt 358 | >>> vocab = tl.nlp.Vocabulary('vocab.txt', start_word="", end_word="", unk_word="") 359 | ... INFO:tensorflow:Initializing vocabulary from file: vocab.txt 360 | ... [TL] Vocabulary from vocab.txt : 361 | ... 
vocabulary with 10 words (includes start_word, end_word, unk_word) 362 | ... start_id: 2 363 | ... end_id: 3 364 | ... unk_id: 9 365 | ... pad_id: 0 366 | """ 367 | from collections import Counter 368 | print(" [TL] Creating vocabulary.") 369 | counter = Counter() 370 | for c in sentences: 371 | counter.update(c) 372 | # print('c',c) 373 | print(" Total words: %d" % len(counter)) 374 | 375 | # Filter uncommon words and sort by descending count. 376 | word_counts = [x for x in counter.items() if x[1] >= min_word_count] 377 | word_counts.sort(key=lambda x: x[1], reverse=True) 378 | word_counts = [("", 0)] + word_counts # 1st id should be reserved for padding 379 | # print(word_counts) 380 | print(" Words in vocabulary: %d" % len(word_counts)) 381 | 382 | # Write out the word counts file. 383 | with tf.gfile.FastGFile(word_counts_output_file, "w") as f: 384 | f.write("\n".join(["%s %d" % (w, c) for w, c in word_counts])) 385 | print(" Wrote vocabulary file: %s" % word_counts_output_file) 386 | 387 | # Create the vocabulary dictionary. 388 | reverse_vocab = [x[0] for x in word_counts] 389 | unk_id = len(reverse_vocab) 390 | vocab_dict = dict([(x, y) for (y, x) in enumerate(reverse_vocab)]) 391 | vocab = SimpleVocabulary(vocab_dict, unk_id) 392 | 393 | return vocab 394 | 395 | 396 | ## Vector representations of words 397 | def simple_read_words(filename="nietzsche.txt"): 398 | """Read context from file without any preprocessing. 399 | 400 | Parameters 401 | ---------- 402 | filename : a string 403 | A file path (like .txt file) 404 | 405 | Returns 406 | -------- 407 | The context in a string 408 | """ 409 | with open("nietzsche.txt", "r") as f: 410 | words = f.read() 411 | return words 412 | 413 | def read_words(filename="nietzsche.txt", replace = ['\n', '']): 414 | """File to list format context. Note that, this script can not handle punctuations. 415 | For customized read_words method, see ``tutorial_generate_text.py``. 
416 | 417 | Parameters 418 | ---------- 419 | filename : a string 420 | A file path (like .txt file), 421 | replace : a list 422 | [original string, target string], to disable replace use ['', ''] 423 | 424 | Returns 425 | -------- 426 | The context in a list, split by space by default, and use ``''`` to represent ``'\n'``, 427 | e.g. ``[... 'how', 'useful', 'it', "'s" ... ]``. 428 | 429 | Code References 430 | --------------- 431 | - `tensorflow.models.rnn.ptb.reader `_ 432 | """ 433 | with tf.gfile.GFile(filename, "r") as f: 434 | try: # python 3.4 or older 435 | context_list = f.read().replace(*replace).split() 436 | except: # python 3.5 437 | f.seek(0) 438 | replace = [x.encode('utf-8') for x in replace] 439 | context_list = f.read().replace(*replace).split() 440 | return context_list 441 | 442 | def read_analogies_file(eval_file='questions-words.txt', word2id={}): 443 | """Reads through an analogy question file, return its id format. 444 | 445 | Parameters 446 | ---------- 447 | eval_data : a string 448 | The file name. 449 | word2id : a dictionary 450 | Mapping words to unique IDs. 451 | 452 | Returns 453 | -------- 454 | analogy_questions : a [n, 4] numpy array containing the analogy question's 455 | word ids. 456 | questions_skipped: questions skipped due to unknown words. 457 | 458 | Examples 459 | --------- 460 | >>> eval_file should be in this format : 461 | >>> : capital-common-countries 462 | >>> Athens Greece Baghdad Iraq 463 | >>> Athens Greece Bangkok Thailand 464 | >>> Athens Greece Beijing China 465 | >>> Athens Greece Berlin Germany 466 | >>> Athens Greece Bern Switzerland 467 | >>> Athens Greece Cairo Egypt 468 | >>> Athens Greece Canberra Australia 469 | >>> Athens Greece Hanoi Vietnam 470 | >>> Athens Greece Havana Cuba 471 | ... 
472 | 473 | >>> words = tl.files.load_matt_mahoney_text8_dataset() 474 | >>> data, count, dictionary, reverse_dictionary = \ 475 | tl.nlp.build_words_dataset(words, vocabulary_size, True) 476 | >>> analogy_questions = tl.nlp.read_analogies_file( \ 477 | eval_file='questions-words.txt', word2id=dictionary) 478 | >>> print(analogy_questions) 479 | ... [[ 3068 1248 7161 1581] 480 | ... [ 3068 1248 28683 5642] 481 | ... [ 3068 1248 3878 486] 482 | ... ..., 483 | ... [ 1216 4309 19982 25506] 484 | ... [ 1216 4309 3194 8650] 485 | ... [ 1216 4309 140 312]] 486 | """ 487 | questions = [] 488 | questions_skipped = 0 489 | with open(eval_file, "rb") as analogy_f: 490 | for line in analogy_f: 491 | if line.startswith(b":"): # Skip comments. 492 | continue 493 | words = line.strip().lower().split(b" ") # lowercase 494 | ids = [word2id.get(w.strip()) for w in words] 495 | if None in ids or len(ids) != 4: 496 | questions_skipped += 1 497 | else: 498 | questions.append(np.array(ids)) 499 | print("Eval analogy file: ", eval_file) 500 | print("Questions: ", len(questions)) 501 | print("Skipped: ", questions_skipped) 502 | analogy_questions = np.array(questions, dtype=np.int32) 503 | return analogy_questions 504 | 505 | def build_vocab(data): 506 | """Build vocabulary. 507 | Given the context in list format. 508 | Return the vocabulary, which is a dictionary for word to id. 509 | e.g. {'campbell': 2587, 'atlantic': 2247, 'aoun': 6746 .... } 510 | 511 | Parameters 512 | ---------- 513 | data : a list of string 514 | the context in list format 515 | 516 | Returns 517 | -------- 518 | word_to_id : a dictionary 519 | mapping words to unique IDs. e.g. {'campbell': 2587, 'atlantic': 2247, 'aoun': 6746 .... 
} 520 | 521 | Code References 522 | --------------- 523 | - `tensorflow.models.rnn.ptb.reader `_ 524 | 525 | Examples 526 | -------- 527 | >>> data_path = os.getcwd() + '/simple-examples/data' 528 | >>> train_path = os.path.join(data_path, "ptb.train.txt") 529 | >>> word_to_id = build_vocab(read_txt_words(train_path)) 530 | """ 531 | # data = _read_words(filename) 532 | counter = collections.Counter(data) 533 | # print('counter', counter) # dictionary for the occurrence number of each word, e.g. 'banknote': 1, 'photography': 1, 'kia': 1 534 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) 535 | # print('count_pairs',count_pairs) # convert dictionary to list of tuple, e.g. ('ssangyong', 1), ('swapo', 1), ('wachter', 1) 536 | words, _ = list(zip(*count_pairs)) 537 | word_to_id = dict(zip(words, range(len(words)))) 538 | # print(words) # list of words 539 | # print(word_to_id) # dictionary for word to id, e.g. 'campbell': 2587, 'atlantic': 2247, 'aoun': 6746 540 | return word_to_id 541 | 542 | def build_reverse_dictionary(word_to_id): 543 | """Given a dictionary for converting word to integer id. 544 | Returns a reverse dictionary for converting a id to word. 545 | 546 | Parameters 547 | ---------- 548 | word_to_id : dictionary 549 | mapping words to unique ids 550 | 551 | Returns 552 | -------- 553 | reverse_dictionary : a dictionary 554 | mapping ids to words 555 | """ 556 | reverse_dictionary = dict(zip(word_to_id.values(), word_to_id.keys())) 557 | return reverse_dictionary 558 | 559 | def build_words_dataset(words=[], vocabulary_size=50000, printable=True, unk_key = 'UNK'): 560 | """Build the words dictionary and replace rare words with 'UNK' token. 561 | The most common word has the smallest integer id. 562 | 563 | Parameters 564 | ---------- 565 | words : a list of string or byte 566 | The context in list format. You may need to do preprocessing on the words, 567 | such as lower case, remove marks etc. 
568 | vocabulary_size : an int 569 | The maximum vocabulary size, limiting the vocabulary size. 570 | Then the script replaces rare words with 'UNK' token. 571 | printable : boolean 572 | Whether to print the read vocabulary size of the given words. 573 | unk_key : a string 574 | Unknown words = unk_key 575 | 576 | Returns 577 | -------- 578 | data : a list of integer 579 | The context in a list of ids 580 | count : a list of tuple and list 581 | count[0] is a list : the number of rare words\n 582 | count[1:] are tuples : the number of occurrence of each word\n 583 | e.g. [['UNK', 418391], (b'the', 1061396), (b'of', 593677), (b'and', 416629), (b'one', 411764)] 584 | dictionary : a dictionary 585 | word_to_id, mapping words to unique IDs. 586 | reverse_dictionary : a dictionary 587 | id_to_word, mapping id to unique word. 588 | 589 | Examples 590 | -------- 591 | >>> words = tl.files.load_matt_mahoney_text8_dataset() 592 | >>> vocabulary_size = 50000 593 | >>> data, count, dictionary, reverse_dictionary = tl.nlp.build_words_dataset(words, vocabulary_size) 594 | 595 | Code References 596 | ----------------- 597 | - `tensorflow/examples/tutorials/word2vec/word2vec_basic.py `_ 598 | """ 599 | import collections 600 | count = [[unk_key, -1]] 601 | count.extend(collections.Counter(words).most_common(vocabulary_size - 1)) 602 | dictionary = dict() 603 | for word, _ in count: 604 | dictionary[word] = len(dictionary) 605 | data = list() 606 | unk_count = 0 607 | for word in words: 608 | if word in dictionary: 609 | index = dictionary[word] 610 | else: 611 | index = 0 # dictionary['UNK'] 612 | unk_count += 1 613 | data.append(index) 614 | count[0][1] = unk_count 615 | reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) 616 | if printable: 617 | print('Real vocabulary size %d' % len(collections.Counter(words).keys())) 618 | print('Limited vocabulary size {}'.format(vocabulary_size)) 619 | assert len(collections.Counter(words).keys()) >= vocabulary_size , \ 
620 | "the limited vocabulary_size must be less than or equal to the read vocabulary_size" 621 | return data, count, dictionary, reverse_dictionary 622 | 623 | def words_to_word_ids(data=[], word_to_id={}, unk_key = 'UNK'): 624 | """Given a context (words) in list format and the vocabulary, 625 | Returns a list of IDs to represent the context. 626 | 627 | Parameters 628 | ---------- 629 | data : a list of string or byte 630 | the context in list format 631 | word_to_id : a dictionary 632 | mapping words to unique IDs. 633 | unk_key : a string 634 | Unknown words = unk_key 635 | 636 | Returns 637 | -------- 638 | A list of IDs to represent the context. 639 | 640 | Examples 641 | -------- 642 | >>> words = tl.files.load_matt_mahoney_text8_dataset() 643 | >>> vocabulary_size = 50000 644 | >>> data, count, dictionary, reverse_dictionary = \ 645 | ... tl.nlp.build_words_dataset(words, vocabulary_size, True) 646 | >>> context = [b'hello', b'how', b'are', b'you'] 647 | >>> ids = tl.nlp.words_to_word_ids(words, dictionary) 648 | >>> context = tl.nlp.word_ids_to_words(ids, reverse_dictionary) 649 | >>> print(ids) 650 | ... [6434, 311, 26, 207] 651 | >>> print(context) 652 | ... 
[b'hello', b'how', b'are', b'you'] 653 | 654 | Code References 655 | --------------- 656 | - `tensorflow.models.rnn.ptb.reader `_ 657 | """ 658 | # if isinstance(data[0], six.string_types): 659 | # print(type(data[0])) 660 | # # exit() 661 | # print(data[0]) 662 | # print(word_to_id) 663 | # return [word_to_id[str(word)] for word in data] 664 | # else: 665 | 666 | word_ids = [] 667 | for word in data: 668 | if word_to_id.get(word) is not None: 669 | word_ids.append(word_to_id[word]) 670 | else: 671 | word_ids.append(word_to_id[unk_key]) 672 | return word_ids 673 | # return [word_to_id[word] for word in data] # this one 674 | 675 | # if isinstance(data[0], str): 676 | # # print('is a string object') 677 | # return [word_to_id[word] for word in data] 678 | # else:#if isinstance(s, bytes): 679 | # # print('is a unicode object') 680 | # # print(data[0]) 681 | # return [word_to_id[str(word)] f 682 | 683 | def word_ids_to_words(data, id_to_word): 684 | """Given a context (ids) in list format and the vocabulary, 685 | Returns a list of words to represent the context. 686 | 687 | Parameters 688 | ---------- 689 | data : a list of integer 690 | the context in list format 691 | id_to_word : a dictionary 692 | mapping id to unique word. 693 | 694 | Returns 695 | -------- 696 | A list of string or byte to represent the context. 697 | 698 | Examples 699 | --------- 700 | >>> see words_to_word_ids 701 | """ 702 | return [id_to_word[i] for i in data] 703 | 704 | def save_vocab(count=[], name='vocab.txt'): 705 | """Save the vocabulary to a file so the model can be reloaded. 706 | 707 | Parameters 708 | ---------- 709 | count : a list of tuple and list 710 | count[0] is a list : the number of rare words\n 711 | count[1:] are tuples : the number of occurrence of each word\n 712 | e.g. 
[['UNK', 418391], (b'the', 1061396), (b'of', 593677), (b'and', 416629), (b'one', 411764)] 713 | 714 | Examples 715 | --------- 716 | >>> words = tl.files.load_matt_mahoney_text8_dataset() 717 | >>> vocabulary_size = 50000 718 | >>> data, count, dictionary, reverse_dictionary = \ 719 | ... tl.nlp.build_words_dataset(words, vocabulary_size, True) 720 | >>> tl.nlp.save_vocab(count, name='vocab_text8.txt') 721 | >>> vocab_text8.txt 722 | ... UNK 418391 723 | ... the 1061396 724 | ... of 593677 725 | ... and 416629 726 | ... one 411764 727 | ... in 372201 728 | ... a 325873 729 | ... to 316376 730 | """ 731 | pwd = os.getcwd() 732 | vocabulary_size = len(count) 733 | with open(os.path.join(pwd, name), "w") as f: 734 | for i in xrange(vocabulary_size): 735 | f.write("%s %d\n" % (tf.compat.as_text(count[i][0]), count[i][1])) 736 | print("%d vocab saved to %s in %s" % (vocabulary_size, name, pwd)) 737 | 738 | ## Functions for translation 739 | def basic_tokenizer(sentence, _WORD_SPLIT=re.compile(b"([.,!?\"':;)(])")): 740 | """Very basic tokenizer: split the sentence into a list of tokens. 741 | 742 | Parameters 743 | ----------- 744 | sentence : tensorflow.python.platform.gfile.GFile Object 745 | _WORD_SPLIT : regular expression for word spliting. 746 | 747 | 748 | Examples 749 | -------- 750 | >>> see create_vocabulary 751 | >>> from tensorflow.python.platform import gfile 752 | >>> train_path = "wmt/giga-fren.release2" 753 | >>> with gfile.GFile(train_path + ".en", mode="rb") as f: 754 | >>> for line in f: 755 | >>> tokens = tl.nlp.basic_tokenizer(line) 756 | >>> print(tokens) 757 | >>> exit() 758 | ... [b'Changing', b'Lives', b'|', b'Changing', b'Society', b'|', b'How', 759 | ... b'It', b'Works', b'|', b'Technology', b'Drives', b'Change', b'Home', 760 | ... b'|', b'Concepts', b'|', b'Teachers', b'|', b'Search', b'|', b'Overview', 761 | ... b'|', b'Credits', b'|', b'HHCC', b'Web', b'|', b'Reference', b'|', 762 | ... 
b'Feedback', b'Virtual', b'Museum', b'of', b'Canada', b'Home', b'Page'] 763 | 764 | References 765 | ---------- 766 | - Code from ``/tensorflow/models/rnn/translation/data_utils.py`` 767 | """ 768 | words = [] 769 | sentence = tf.compat.as_bytes(sentence) 770 | for space_separated_fragment in sentence.strip().split(): 771 | words.extend(re.split(_WORD_SPLIT, space_separated_fragment)) 772 | return [w for w in words if w] 773 | 774 | def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size, 775 | tokenizer=None, normalize_digits=True, 776 | _DIGIT_RE=re.compile(br"\d"), 777 | _START_VOCAB=[b"_PAD", b"_GO", b"_EOS", b"_UNK"]): 778 | """Create vocabulary file (if it does not exist yet) from data file. 779 | 780 | Data file is assumed to contain one sentence per line. Each sentence is 781 | tokenized and digits are normalized (if normalize_digits is set). 782 | Vocabulary contains the most-frequent tokens up to max_vocabulary_size. 783 | We write it to vocabulary_path in a one-token-per-line format, so that later 784 | token in the first line gets id=0, second line gets id=1, and so on. 785 | 786 | Parameters 787 | ----------- 788 | vocabulary_path : path where the vocabulary will be created. 789 | data_path : data file that will be used to create vocabulary. 790 | max_vocabulary_size : limit on the size of the created vocabulary. 791 | tokenizer : a function to use to tokenize each data sentence. 792 | if None, basic_tokenizer will be used. 793 | normalize_digits : Boolean 794 | if true, all digits are replaced by 0s. 
795 | 796 | References 797 | ---------- 798 | - Code from ``/tensorflow/models/rnn/translation/data_utils.py`` 799 | """ 800 | if not gfile.Exists(vocabulary_path): 801 | print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path)) 802 | vocab = {} 803 | with gfile.GFile(data_path, mode="rb") as f: 804 | counter = 0 805 | for line in f: 806 | counter += 1 807 | if counter % 100000 == 0: 808 | print(" processing line %d" % counter) 809 | tokens = tokenizer(line) if tokenizer else basic_tokenizer(line) 810 | for w in tokens: 811 | word = re.sub(_DIGIT_RE, b"0", w) if normalize_digits else w 812 | if word in vocab: 813 | vocab[word] += 1 814 | else: 815 | vocab[word] = 1 816 | vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True) 817 | if len(vocab_list) > max_vocabulary_size: 818 | vocab_list = vocab_list[:max_vocabulary_size] 819 | with gfile.GFile(vocabulary_path, mode="wb") as vocab_file: 820 | for w in vocab_list: 821 | vocab_file.write(w + b"\n") 822 | else: 823 | print("Vocabulary %s from data %s exists" % (vocabulary_path, data_path)) 824 | 825 | def initialize_vocabulary(vocabulary_path): 826 | """Initialize vocabulary from file, return the word_to_id (dictionary) 827 | and id_to_word (list). 828 | 829 | We assume the vocabulary is stored one-item-per-line, so a file:\n 830 | dog\n 831 | cat\n 832 | will result in a vocabulary {"dog": 0, "cat": 1}, and this function will 833 | also return the reversed-vocabulary ["dog", "cat"]. 834 | 835 | Parameters 836 | ----------- 837 | vocabulary_path : path to the file containing the vocabulary. 838 | 839 | Returns 840 | -------- 841 | vocab : a dictionary 842 | Word to id. A dictionary mapping string to integers. 843 | rev_vocab : a list 844 | Id to word. The reversed vocabulary (a list, which reverses the vocabulary mapping). 845 | 846 | Examples 847 | --------- 848 | >>> Assume 'test' contains 849 | ... dog 850 | ... cat 851 | ... 
bird 852 | >>> vocab, rev_vocab = tl.nlp.initialize_vocabulary("test") 853 | >>> print(vocab) 854 | >>> {b'cat': 1, b'dog': 0, b'bird': 2} 855 | >>> print(rev_vocab) 856 | >>> [b'dog', b'cat', b'bird'] 857 | 858 | Raises 859 | ------- 860 | ValueError : if the provided vocabulary_path does not exist. 861 | """ 862 | if gfile.Exists(vocabulary_path): 863 | rev_vocab = [] 864 | with gfile.GFile(vocabulary_path, mode="rb") as f: 865 | rev_vocab.extend(f.readlines()) 866 | rev_vocab = [tf.compat.as_bytes(line.strip()) for line in rev_vocab] 867 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)]) 868 | return vocab, rev_vocab 869 | else: 870 | raise ValueError("Vocabulary file %s not found.", vocabulary_path) 871 | 872 | def sentence_to_token_ids(sentence, vocabulary, 873 | tokenizer=None, normalize_digits=True, 874 | UNK_ID=3, _DIGIT_RE=re.compile(br"\d")): 875 | """Convert a string to list of integers representing token-ids. 876 | 877 | For example, a sentence "I have a dog" may become tokenized into 878 | ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2, 879 | "a": 4, "dog": 7"} this function will return [1, 2, 4, 7]. 880 | 881 | Parameters 882 | ----------- 883 | sentence : tensorflow.python.platform.gfile.GFile Object 884 | The sentence in bytes format to convert to token-ids.\n 885 | see basic_tokenizer(), data_to_token_ids() 886 | vocabulary : a dictionary mapping tokens to integers. 887 | tokenizer : a function to use to tokenize each sentence; 888 | If None, basic_tokenizer will be used. 889 | normalize_digits : Boolean 890 | If true, all digits are replaced by 0s. 891 | 892 | Returns 893 | -------- 894 | A list of integers, the token-ids for the sentence. 895 | """ 896 | 897 | if tokenizer: 898 | words = tokenizer(sentence) 899 | else: 900 | words = basic_tokenizer(sentence) 901 | if not normalize_digits: 902 | return [vocabulary.get(w, UNK_ID) for w in words] 903 | # Normalize digits by 0 before looking words up in the vocabulary. 
904 | return [vocabulary.get(re.sub(_DIGIT_RE, b"0", w), UNK_ID) for w in words] 905 | 906 | def data_to_token_ids(data_path, target_path, vocabulary_path, 907 | tokenizer=None, normalize_digits=True, 908 | UNK_ID=3, _DIGIT_RE=re.compile(br"\d")): 909 | """Tokenize data file and turn into token-ids using given vocabulary file. 910 | 911 | This function loads data line-by-line from data_path, calls the above 912 | sentence_to_token_ids, and saves the result to target_path. See comment 913 | for sentence_to_token_ids on the details of token-ids format. 914 | 915 | Parameters 916 | ----------- 917 | data_path : path to the data file in one-sentence-per-line format. 918 | target_path : path where the file with token-ids will be created. 919 | vocabulary_path : path to the vocabulary file. 920 | tokenizer : a function to use to tokenize each sentence; 921 | if None, basic_tokenizer will be used. 922 | normalize_digits : Boolean; if true, all digits are replaced by 0s. 923 | 924 | References 925 | ---------- 926 | - Code from ``/tensorflow/models/rnn/translation/data_utils.py`` 927 | """ 928 | if not gfile.Exists(target_path): 929 | print("Tokenizing data in %s" % data_path) 930 | vocab, _ = initialize_vocabulary(vocabulary_path) 931 | with gfile.GFile(data_path, mode="rb") as data_file: 932 | with gfile.GFile(target_path, mode="w") as tokens_file: 933 | counter = 0 934 | for line in data_file: 935 | counter += 1 936 | if counter % 100000 == 0: 937 | print(" tokenizing line %d" % counter) 938 | token_ids = sentence_to_token_ids(line, vocab, tokenizer, 939 | normalize_digits, UNK_ID=UNK_ID, 940 | _DIGIT_RE=_DIGIT_RE) 941 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n") 942 | else: 943 | print("Target path %s exists" % target_path) 944 | -------------------------------------------------------------------------------- /tensorlayer/ops.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | 5 | 6 | 7 | import tensorflow as tf 8 | import os 9 | import sys 10 | from sys import platform as _platform 11 | 12 | 13 | def exit_tf(sess=None): 14 | """Close tensorboard and nvidia-process if available 15 | 16 | Parameters 17 | ---------- 18 | sess : a session instance of TensorFlow 19 | TensorFlow session 20 | """ 21 | text = "[tl] Close tensorboard and nvidia-process if available" 22 | sess.close() 23 | # import time 24 | # time.sleep(2) 25 | if _platform == "linux" or _platform == "linux2": 26 | print('linux: %s' % text) 27 | os.system('nvidia-smi') 28 | os.system('fuser 6006/tcp -k') # kill tensorboard 6006 29 | os.system("nvidia-smi | grep python |awk '{print $3}'|xargs kill") # kill all nvidia-smi python process 30 | elif _platform == "darwin": 31 | print('OS X: %s' % text) 32 | os.system("lsof -i tcp:6006 | grep -v PID | awk '{print $2}' | xargs kill") # kill tensorboard 6006 33 | elif _platform == "win32": 34 | print('Windows: %s' % text) 35 | else: 36 | print(_platform) 37 | exit() 38 | 39 | def clear_all(printable=True): 40 | """Clears all the placeholder variables of keep prob, 41 | including keeping probabilities of all dropout, denoising, dropconnect etc. 42 | 43 | Parameters 44 | ---------- 45 | printable : boolean 46 | If True, print all deleted variables. 
47 | """ 48 | print('clear all .....................................') 49 | gl = globals().copy() 50 | for var in gl: 51 | if var[0] == '_': continue 52 | if 'func' in str(globals()[var]): continue 53 | if 'module' in str(globals()[var]): continue 54 | if 'class' in str(globals()[var]): continue 55 | 56 | if printable: 57 | print(" clear_all ------- %s" % str(globals()[var])) 58 | 59 | del globals()[var] 60 | 61 | # def clear_all2(vars, printable=True): 62 | # """ 63 | # The :function:`clear_all()` Clears all the placeholder variables of keep prob, 64 | # including keeping probabilities of all dropout, denoising, dropconnect 65 | # Parameters 66 | # ---------- 67 | # printable : if True, print all deleted variables. 68 | # """ 69 | # print('clear all .....................................') 70 | # for var in vars: 71 | # if var[0] == '_': continue 72 | # if 'func' in str(var): continue 73 | # if 'module' in str(var): continue 74 | # if 'class' in str(var): continue 75 | # 76 | # if printable: 77 | # print(" clear_all ------- %s" % str(var)) 78 | # 79 | # del var 80 | 81 | def set_gpu_fraction(sess=None, gpu_fraction=0.3): 82 | """Set the GPU memory fraction for the application. 83 | 84 | Parameters 85 | ---------- 86 | sess : a session instance of TensorFlow 87 | TensorFlow session 88 | gpu_fraction : a float 89 | Fraction of GPU memory, (0 ~ 1] 90 | 91 | References 92 | ---------- 93 | - `TensorFlow using GPU `_ 94 | """ 95 | print(" tensorlayer: GPU MEM Fraction %f" % gpu_fraction) 96 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction) 97 | sess = tf.Session(config = tf.ConfigProto(gpu_options = gpu_options)) 98 | return sess 99 | 100 | 101 | 102 | 103 | 104 | def disable_print(): 105 | """Disable console output, ``suppress_stdout`` is recommended. 
106 | 107 | Examples 108 | --------- 109 | >>> print("You can see me") 110 | >>> tl.ops.disable_print() 111 | >>> print(" You can't see me") 112 | >>> tl.ops.enable_print() 113 | >>> print("You can see me") 114 | """ 115 | # sys.stdout = os.devnull # this one kill the process 116 | sys.stdout = None 117 | sys.stderr = os.devnull 118 | 119 | def enable_print(): 120 | """Enable console output, ``suppress_stdout`` is recommended. 121 | 122 | Examples 123 | -------- 124 | - see tl.ops.disable_print() 125 | """ 126 | sys.stdout = sys.__stdout__ 127 | sys.stderr = sys.__stderr__ 128 | 129 | 130 | # class temporary_disable_print: 131 | # """Temporarily disable console output. 132 | # 133 | # Examples 134 | # --------- 135 | # >>> print("You can see me") 136 | # >>> with tl.ops.temporary_disable_print() as t: 137 | # >>> print("You can't see me") 138 | # >>> print("You can see me") 139 | # """ 140 | # def __init__(self): 141 | # pass 142 | # def __enter__(self): 143 | # sys.stdout = None 144 | # sys.stderr = os.devnull 145 | # def __exit__(self, type, value, traceback): 146 | # sys.stdout = sys.__stdout__ 147 | # sys.stderr = sys.__stderr__ 148 | # return isinstance(value, TypeError) 149 | 150 | 151 | from contextlib import contextmanager 152 | @contextmanager 153 | def suppress_stdout(): 154 | """Temporarily disable console output. 155 | 156 | Examples 157 | --------- 158 | >>> print("You can see me") 159 | >>> with tl.ops.suppress_stdout(): 160 | >>> print("You can't see me") 161 | >>> print("You can see me") 162 | 163 | References 164 | ----------- 165 | - `stackoverflow `_ 166 | """ 167 | with open(os.devnull, "w") as devnull: 168 | old_stdout = sys.stdout 169 | sys.stdout = devnull 170 | try: 171 | yield 172 | finally: 173 | sys.stdout = old_stdout 174 | 175 | 176 | 177 | def get_site_packages_directory(): 178 | """Print and return the site-packages directory. 
179 | 180 | Examples 181 | --------- 182 | >>> loc = tl.ops.get_site_packages_directory() 183 | """ 184 | import site 185 | try: 186 | loc = site.getsitepackages() 187 | print(" tl.ops : site-packages in ", loc) 188 | return loc 189 | except: 190 | print(" tl.ops : Cannot find package dir from virtual environment") 191 | return False 192 | 193 | 194 | 195 | def empty_trash(): 196 | """Empty trash folder. 197 | 198 | """ 199 | text = "[tl] Empty the trash" 200 | if _platform == "linux" or _platform == "linux2": 201 | print('linux: %s' % text) 202 | os.system("rm -rf ~/.local/share/Trash/*") 203 | elif _platform == "darwin": 204 | print('OS X: %s' % text) 205 | os.system("sudo rm -rf ~/.Trash/*") 206 | elif _platform == "win32": 207 | print('Windows: %s' % text) 208 | try: 209 | os.system("rd /s c:\$Recycle.Bin") # Windows 7 or Server 2008 210 | except: 211 | pass 212 | try: 213 | os.system("rd /s c:\recycler") # Windows XP, Vista, or Server 2003 214 | except: 215 | pass 216 | else: 217 | print(_platform) 218 | 219 | # 220 | -------------------------------------------------------------------------------- /tensorlayer/rein.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | from six.moves import xrange 9 | 10 | def discount_episode_rewards(rewards=[], gamma=0.99, mode=0): 11 | """ Take 1D float array of rewards and compute discounted rewards for an 12 | episode. When encount a non-zero value, consider as the end a of an episode. 13 | 14 | Parameters 15 | ---------- 16 | rewards : numpy list 17 | a list of rewards 18 | gamma : float 19 | discounted factor 20 | mode : int 21 | if mode == 0, reset the discount process when encount a non-zero reward (Ping-pong game). 22 | if mode == 1, would not reset the discount process. 
23 | 24 | Examples 25 | ---------- 26 | >>> rewards = np.asarray([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1]) 27 | >>> gamma = 0.9 28 | >>> discount_rewards = tl.rein.discount_episode_rewards(rewards, gamma) 29 | >>> print(discount_rewards) 30 | ... [ 0.72899997 0.81 0.89999998 1. 0.72899997 0.81 31 | ... 0.89999998 1. 0.72899997 0.81 0.89999998 1. ] 32 | >>> discount_rewards = tl.rein.discount_episode_rewards(rewards, gamma, mode=1) 33 | >>> print(discount_rewards) 34 | ... [ 1.52110755 1.69011939 1.87791049 2.08656716 1.20729685 1.34144104 35 | ... 1.49048996 1.65610003 0.72899997 0.81 0.89999998 1. ] 36 | """ 37 | discounted_r = np.zeros_like(rewards, dtype=np.float32) 38 | running_add = 0 39 | for t in reversed(xrange(0, rewards.size)): 40 | if mode == 0: 41 | if rewards[t] != 0: running_add = 0 42 | 43 | running_add = running_add * gamma + rewards[t] 44 | discounted_r[t] = running_add 45 | return discounted_r 46 | 47 | 48 | def cross_entropy_reward_loss(logits, actions, rewards, name=None): 49 | """ Calculate the loss for Policy Gradient Network. 50 | 51 | Parameters 52 | ---------- 53 | logits : tensor 54 | The network outputs without softmax. This function implements softmax 55 | inside. 56 | actions : tensor/ placeholder 57 | The agent actions. 58 | rewards : tensor/ placeholder 59 | The rewards. 
60 | 61 | Examples 62 | ---------- 63 | >>> states_batch_pl = tf.placeholder(tf.float32, shape=[None, D]) # observation for training 64 | >>> network = tl.layers.InputLayer(states_batch_pl, name='input_layer') 65 | >>> network = tl.layers.DenseLayer(network, n_units=H, act = tf.nn.relu, name='relu1') 66 | >>> network = tl.layers.DenseLayer(network, n_units=3, act = tl.activation.identity, name='output_layer') 67 | >>> probs = network.outputs 68 | >>> sampling_prob = tf.nn.softmax(probs) 69 | >>> actions_batch_pl = tf.placeholder(tf.int32, shape=[None]) 70 | >>> discount_rewards_batch_pl = tf.placeholder(tf.float32, shape=[None]) 71 | >>> loss = cross_entropy_reward_loss(probs, actions_batch_pl, discount_rewards_batch_pl) 72 | >>> train_op = tf.train.RMSPropOptimizer(learning_rate, decay_rate).minimize(loss) 73 | """ 74 | 75 | try: # TF 1.0 76 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=actions, logits=logits, name=name) 77 | except: 78 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, targets=actions) 79 | # cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, actions) 80 | 81 | try: ## TF1.0 82 | loss = tf.reduce_sum(tf.multiply(cross_entropy, rewards)) 83 | except: ## TF0.12 84 | loss = tf.reduce_sum(tf.mul(cross_entropy, rewards)) # element-wise mul 85 | return loss 86 | -------------------------------------------------------------------------------- /tensorlayer/utils.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | import tensorflow as tf 4 | import tensorlayer as tl 5 | from . 
import iterate 6 | import numpy as np 7 | import time 8 | import math 9 | import random 10 | 11 | 12 | def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_size=100, 13 | n_epoch=100, print_freq=5, X_val=None, y_val=None, eval_train=True, 14 | tensorboard=False, tensorboard_epoch_freq=5, tensorboard_weight_histograms=True, tensorboard_graph_vis=True): 15 | """Traing a given non time-series network by the given cost function, training data, batch_size, n_epoch etc. 16 | 17 | Parameters 18 | ---------- 19 | sess : TensorFlow session 20 | sess = tf.InteractiveSession() 21 | network : a TensorLayer layer 22 | the network will be trained 23 | train_op : a TensorFlow optimizer 24 | like tf.train.AdamOptimizer 25 | X_train : numpy array 26 | the input of training data 27 | y_train : numpy array 28 | the target of training data 29 | x : placeholder 30 | for inputs 31 | y_ : placeholder 32 | for targets 33 | acc : the TensorFlow expression of accuracy (or other metric) or None 34 | if None, would not display the metric 35 | batch_size : int 36 | batch size for training and evaluating 37 | n_epoch : int 38 | the number of training epochs 39 | print_freq : int 40 | display the training information every ``print_freq`` epochs 41 | X_val : numpy array or None 42 | the input of validation data 43 | y_val : numpy array or None 44 | the target of validation data 45 | eval_train : boolean 46 | if X_val and y_val are not None, it refects whether to evaluate the training data 47 | tensorboard : boolean 48 | if True summary data will be stored to the log/ direcory for visualization with tensorboard. 49 | See also detailed tensorboard_X settings for specific configurations of features. 
(default False) 50 | Also runs tl.layers.initialize_global_variables(sess) internally in fit() to setup the summary nodes, see Note: 51 | tensorboard_epoch_freq : int 52 | how many epochs between storing tensorboard checkpoint for visualization to log/ directory (default 5) 53 | tensorboard_weight_histograms : boolean 54 | if True updates tensorboard data in the logs/ directory for visulaization 55 | of the weight histograms every tensorboard_epoch_freq epoch (default True) 56 | tensorboard_graph_vis : boolean 57 | if True stores the graph in the tensorboard summaries saved to log/ (default True) 58 | 59 | Examples 60 | -------- 61 | >>> see tutorial_mnist_simple.py 62 | >>> tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_, 63 | ... acc=acc, batch_size=500, n_epoch=200, print_freq=5, 64 | ... X_val=X_val, y_val=y_val, eval_train=False) 65 | >>> tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_, 66 | ... acc=acc, batch_size=500, n_epoch=200, print_freq=5, 67 | ... X_val=X_val, y_val=y_val, eval_train=False, 68 | ... tensorboard=True, tensorboard_weight_histograms=True, tensorboard_graph_vis=True) 69 | 70 | Note 71 | -------- 72 | If tensorboard=True, the global_variables_initializer will be run inside the fit function 73 | in order to initalize the automatically generated summary nodes used for tensorboard visualization, 74 | thus tf.global_variables_initializer().run() before the fit() call will be undefined. 
75 | """ 76 | assert X_train.shape[0] >= batch_size, "Number of training examples should be bigger than the batch size" 77 | 78 | if(tensorboard): 79 | print("Setting up tensorboard ...") 80 | #Set up tensorboard summaries and saver 81 | tl.files.exists_or_mkdir('logs/') 82 | 83 | #Only write summaries for more recent TensorFlow versions 84 | if hasattr(tf, 'summary') and hasattr(tf.summary, 'FileWriter'): 85 | if tensorboard_graph_vis: 86 | train_writer = tf.summary.FileWriter('logs/train',sess.graph) 87 | val_writer = tf.summary.FileWriter('logs/validation',sess.graph) 88 | else: 89 | train_writer = tf.summary.FileWriter('logs/train') 90 | val_writer = tf.summary.FileWriter('logs/validation') 91 | 92 | #Set up summary nodes 93 | if(tensorboard_weight_histograms): 94 | for param in network.all_params: 95 | if hasattr(tf, 'summary') and hasattr(tf.summary, 'histogram'): 96 | print('Param name ', param.name) 97 | tf.summary.histogram(param.name, param) 98 | 99 | if hasattr(tf, 'summary') and hasattr(tf.summary, 'histogram'): 100 | tf.summary.scalar('cost', cost) 101 | 102 | merged = tf.summary.merge_all() 103 | 104 | #Initalize all variables and summaries 105 | tl.layers.initialize_global_variables(sess) 106 | print("Finished! 
use $tensorboard --logdir=logs/ to start server") 107 | 108 | print("Start training the network ...") 109 | start_time_begin = time.time() 110 | tensorboard_train_index, tensorboard_val_index = 0, 0 111 | for epoch in range(n_epoch): 112 | start_time = time.time() 113 | loss_ep = 0; n_step = 0 114 | for X_train_a, y_train_a in iterate.minibatches(X_train, y_train, 115 | batch_size, shuffle=True): 116 | feed_dict = {x: X_train_a, y_: y_train_a} 117 | feed_dict.update( network.all_drop ) # enable noise layers 118 | loss, _ = sess.run([cost, train_op], feed_dict=feed_dict) 119 | loss_ep += loss 120 | n_step += 1 121 | loss_ep = loss_ep/ n_step 122 | 123 | if tensorboard and hasattr(tf, 'summary'): 124 | if epoch+1 == 1 or (epoch+1) % tensorboard_epoch_freq == 0: 125 | for X_train_a, y_train_a in iterate.minibatches( 126 | X_train, y_train, batch_size, shuffle=True): 127 | dp_dict = dict_to_one( network.all_drop ) # disable noise layers 128 | feed_dict = {x: X_train_a, y_: y_train_a} 129 | feed_dict.update(dp_dict) 130 | result = sess.run(merged, feed_dict=feed_dict) 131 | train_writer.add_summary(result, tensorboard_train_index) 132 | tensorboard_train_index += 1 133 | if (X_val is not None) and (y_val is not None): 134 | for X_val_a, y_val_a in iterate.minibatches( 135 | X_val, y_val, batch_size, shuffle=True): 136 | dp_dict = dict_to_one( network.all_drop ) # disable noise layers 137 | feed_dict = {x: X_val_a, y_: y_val_a} 138 | feed_dict.update(dp_dict) 139 | result = sess.run(merged, feed_dict=feed_dict) 140 | val_writer.add_summary(result, tensorboard_val_index) 141 | tensorboard_val_index += 1 142 | 143 | if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: 144 | if (X_val is not None) and (y_val is not None): 145 | print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time)) 146 | if eval_train is True: 147 | train_loss, train_acc, n_batch = 0, 0, 0 148 | for X_train_a, y_train_a in iterate.minibatches( 149 | X_train, y_train, batch_size, 
shuffle=True): 150 | dp_dict = dict_to_one( network.all_drop ) # disable noise layers 151 | feed_dict = {x: X_train_a, y_: y_train_a} 152 | feed_dict.update(dp_dict) 153 | if acc is not None: 154 | err, ac = sess.run([cost, acc], feed_dict=feed_dict) 155 | train_acc += ac 156 | else: 157 | err = sess.run(cost, feed_dict=feed_dict) 158 | train_loss += err; n_batch += 1 159 | print(" train loss: %f" % (train_loss/ n_batch)) 160 | if acc is not None: 161 | print(" train acc: %f" % (train_acc/ n_batch)) 162 | val_loss, val_acc, n_batch = 0, 0, 0 163 | for X_val_a, y_val_a in iterate.minibatches( 164 | X_val, y_val, batch_size, shuffle=True): 165 | dp_dict = dict_to_one( network.all_drop ) # disable noise layers 166 | feed_dict = {x: X_val_a, y_: y_val_a} 167 | feed_dict.update(dp_dict) 168 | if acc is not None: 169 | err, ac = sess.run([cost, acc], feed_dict=feed_dict) 170 | val_acc += ac 171 | else: 172 | err = sess.run(cost, feed_dict=feed_dict) 173 | val_loss += err; n_batch += 1 174 | print(" val loss: %f" % (val_loss/ n_batch)) 175 | if acc is not None: 176 | print(" val acc: %f" % (val_acc/ n_batch)) 177 | else: 178 | print("Epoch %d of %d took %fs, loss %f" % (epoch + 1, n_epoch, time.time() - start_time, loss_ep)) 179 | print("Total training time: %fs" % (time.time() - start_time_begin)) 180 | 181 | 182 | def test(sess, network, acc, X_test, y_test, x, y_, batch_size, cost=None): 183 | """ 184 | Test a given non time-series network by the given test data and metric. 
185 | 186 | Parameters 187 | ---------- 188 | sess : TensorFlow session 189 | sess = tf.InteractiveSession() 190 | network : a TensorLayer layer 191 | the network will be trained 192 | acc : the TensorFlow expression of accuracy (or other metric) or None 193 | if None, would not display the metric 194 | X_test : numpy array 195 | the input of test data 196 | y_test : numpy array 197 | the target of test data 198 | x : placeholder 199 | for inputs 200 | y_ : placeholder 201 | for targets 202 | batch_size : int or None 203 | batch size for testing, when dataset is large, we should use minibatche for testing. 204 | when dataset is small, we can set it to None. 205 | cost : the TensorFlow expression of cost or None 206 | if None, would not display the cost 207 | 208 | Examples 209 | -------- 210 | >>> see tutorial_mnist_simple.py 211 | >>> tl.utils.test(sess, network, acc, X_test, y_test, x, y_, batch_size=None, cost=cost) 212 | """ 213 | print('Start testing the network ...') 214 | if batch_size is None: 215 | dp_dict = dict_to_one( network.all_drop ) 216 | feed_dict = {x: X_test, y_: y_test} 217 | feed_dict.update(dp_dict) 218 | if cost is not None: 219 | print(" test loss: %f" % sess.run(cost, feed_dict=feed_dict)) 220 | print(" test acc: %f" % sess.run(acc, feed_dict=feed_dict)) 221 | # print(" test acc: %f" % np.mean(y_test == sess.run(y_op, 222 | # feed_dict=feed_dict))) 223 | else: 224 | test_loss, test_acc, n_batch = 0, 0, 0 225 | for X_test_a, y_test_a in iterate.minibatches( 226 | X_test, y_test, batch_size, shuffle=True): 227 | dp_dict = dict_to_one( network.all_drop ) # disable noise layers 228 | feed_dict = {x: X_test_a, y_: y_test_a} 229 | feed_dict.update(dp_dict) 230 | if cost is not None: 231 | err, ac = sess.run([cost, acc], feed_dict=feed_dict) 232 | test_loss += err 233 | else: 234 | ac = sess.run(acc, feed_dict=feed_dict) 235 | test_acc += ac; n_batch += 1 236 | if cost is not None: 237 | print(" test loss: %f" % (test_loss/ n_batch)) 238 | print(" 
test acc: %f" % (test_acc/ n_batch)) 239 | 240 | 241 | def predict(sess, network, X, x, y_op, batch_size=None): 242 | """ 243 | Return the predict results of given non time-series network. 244 | 245 | Parameters 246 | ---------- 247 | sess : TensorFlow session 248 | sess = tf.InteractiveSession() 249 | network : a TensorLayer layer 250 | the network will be trained 251 | X : numpy array 252 | the input 253 | x : placeholder 254 | for inputs 255 | y_op : placeholder 256 | the argmax expression of softmax outputs 257 | batch_size : int or None 258 | batch size for prediction, when dataset is large, we should use minibatche for prediction. 259 | when dataset is small, we can set it to None. 260 | 261 | Examples 262 | -------- 263 | >>> see tutorial_mnist_simple.py 264 | >>> y = network.outputs 265 | >>> y_op = tf.argmax(tf.nn.softmax(y), 1) 266 | >>> print(tl.utils.predict(sess, network, X_test, x, y_op)) 267 | """ 268 | if batch_size is None: 269 | dp_dict = dict_to_one( network.all_drop ) # disable noise layers 270 | feed_dict = {x: X,} 271 | feed_dict.update(dp_dict) 272 | return sess.run(y_op, feed_dict=feed_dict) 273 | else: 274 | result = None 275 | for X_a, _ in iterate.minibatches( 276 | X, X, batch_size, shuffle=False): 277 | dp_dict = dict_to_one( network.all_drop ) 278 | feed_dict = {x: X_a, } 279 | feed_dict.update(dp_dict) 280 | result_a = sess.run(y_op, feed_dict=feed_dict) 281 | if result is None: 282 | result = result_a 283 | else: 284 | result = np.hstack((result, result_a)) 285 | return result 286 | 287 | 288 | ## Evaluation 289 | def evaluation(y_test=None, y_predict=None, n_classes=None): 290 | """ 291 | Input the predicted results, targets results and 292 | the number of class, return the confusion matrix, F1-score of each class, 293 | accuracy and macro F1-score. 
294 | 295 | Parameters 296 | ---------- 297 | y_test : numpy.array or list 298 | target results 299 | y_predict : numpy.array or list 300 | predicted results 301 | n_classes : int 302 | number of classes 303 | 304 | Examples 305 | -------- 306 | >>> c_mat, f1, acc, f1_macro = evaluation(y_test, y_predict, n_classes) 307 | """ 308 | from sklearn.metrics import confusion_matrix, f1_score, accuracy_score 309 | c_mat = confusion_matrix(y_test, y_predict, labels = [x for x in range(n_classes)]) 310 | f1 = f1_score(y_test, y_predict, average = None, labels = [x for x in range(n_classes)]) 311 | f1_macro = f1_score(y_test, y_predict, average='macro') 312 | acc = accuracy_score(y_test, y_predict) 313 | print('confusion matrix: \n',c_mat) 314 | print('f1-score:',f1) 315 | print('f1-score(macro):',f1_macro) # same output with > f1_score(y_true, y_pred, average='macro') 316 | print('accuracy-score:', acc) 317 | return c_mat, f1, acc, f1_macro 318 | 319 | def dict_to_one(dp_dict={}): 320 | """ 321 | Input a dictionary, return a dictionary that all items are set to one, 322 | use for disable dropout, dropconnect layer and so on. 323 | 324 | Parameters 325 | ---------- 326 | dp_dict : dictionary 327 | keeping probabilities 328 | 329 | Examples 330 | -------- 331 | >>> dp_dict = dict_to_one( network.all_drop ) 332 | >>> dp_dict = dict_to_one( network.all_drop ) 333 | >>> feed_dict.update(dp_dict) 334 | """ 335 | return {x: 1 for x in dp_dict} 336 | 337 | def flatten_list(list_of_list=[[],[]]): 338 | """ 339 | Input a list of list, return a list that all items are in a list. 340 | 341 | Parameters 342 | ---------- 343 | list_of_list : a list of list 344 | 345 | Examples 346 | -------- 347 | >>> tl.utils.flatten_list([[1, 2, 3],[4, 5],[6]]) 348 | ... 
[1, 2, 3, 4, 5, 6] 349 | """ 350 | return sum(list_of_list, []) 351 | 352 | 353 | def class_balancing_oversample(X_train=None, y_train=None, printable=True): 354 | """Input the features and labels, return the features and labels after oversampling. 355 | 356 | Parameters 357 | ---------- 358 | X_train : numpy.array 359 | Features, each row is an example 360 | y_train : numpy.array 361 | Labels 362 | 363 | Examples 364 | -------- 365 | - One X 366 | >>> X_train, y_train = class_balancing_oversample(X_train, y_train, printable=True) 367 | 368 | - Two X 369 | >>> X, y = tl.utils.class_balancing_oversample(X_train=np.hstack((X1, X2)), y_train=y, printable=False) 370 | >>> X1 = X[:, 0:5] 371 | >>> X2 = X[:, 5:] 372 | """ 373 | # ======== Classes balancing 374 | if printable: 375 | print("Classes balancing for training examples...") 376 | from collections import Counter 377 | c = Counter(y_train) 378 | if printable: 379 | print('the occurrence number of each stage: %s' % c.most_common()) 380 | print('the least stage is Label %s have %s instances' % c.most_common()[-1]) 381 | print('the most stage is Label %s have %s instances' % c.most_common(1)[0]) 382 | most_num = c.most_common(1)[0][1] 383 | if printable: 384 | print('most num is %d, all classes tend to be this num' % most_num) 385 | 386 | locations = {} 387 | number = {} 388 | 389 | for lab, num in c.most_common(): # find the index from y_train 390 | number[lab] = num 391 | locations[lab] = np.where(np.array(y_train)==lab)[0] 392 | if printable: 393 | print('convert list(np.array) to dict format') 394 | X = {} # convert list to dict 395 | for lab, num in number.items(): 396 | X[lab] = X_train[locations[lab]] 397 | 398 | # oversampling 399 | if printable: 400 | print('start oversampling') 401 | for key in X: 402 | temp = X[key] 403 | while True: 404 | if len(X[key]) >= most_num: 405 | break 406 | X[key] = np.vstack((X[key], temp)) 407 | if printable: 408 | print('first features of label 0 >', len(X[0][0])) 409 | 
print('the occurrence num of each stage after oversampling') 410 | for key in X: 411 | print(key, len(X[key])) 412 | if printable: 413 | print('make each stage have same num of instances') 414 | for key in X: 415 | X[key] = X[key][0:most_num,:] 416 | print(key, len(X[key])) 417 | 418 | # convert dict to list 419 | if printable: 420 | print('convert from dict to list format') 421 | y_train = [] 422 | X_train = np.empty(shape=(0,len(X[0][0]))) 423 | for key in X: 424 | X_train = np.vstack( (X_train, X[key] ) ) 425 | y_train.extend([key for i in range(len(X[key]))]) 426 | # print(len(X_train), len(y_train)) 427 | c = Counter(y_train) 428 | if printable: 429 | print('the occurrence number of each stage after oversampling: %s' % c.most_common()) 430 | # ================ End of Classes balancing 431 | return X_train, y_train 432 | 433 | ## Random 434 | def get_random_int(min=0, max=10, number=5, seed=None): 435 | """Return a list of random integer by the given range and quantity. 436 | 437 | Examples 438 | --------- 439 | >>> r = get_random_int(min=0, max=10, number=5) 440 | ... 
[10, 2, 3, 3, 7] 441 | """ 442 | rnd = random.Random() 443 | if seed: 444 | rnd = random.Random(seed) 445 | # return [random.randint(min,max) for p in range(0, number)] 446 | return [rnd.randint(min,max) for p in range(0, number)] 447 | 448 | # 449 | # def class_balancing_sequence_4D(X_train, y_train, sequence_length, model='downsampling' ,printable=True): 450 | # ''' 输入、输出都是sequence format 451 | # oversampling or downsampling 452 | # ''' 453 | # n_features = X_train.shape[2] 454 | # # ======== Classes balancing for sequence 455 | # if printable: 456 | # print("Classes balancing for 4D sequence training examples...") 457 | # from collections import Counter 458 | # c = Counter(y_train) # Counter({2: 454, 4: 267, 3: 124, 1: 57, 0: 48}) 459 | # if printable: 460 | # print('the occurrence number of each stage: %s' % c.most_common()) 461 | # print('the least Label %s have %s instances' % c.most_common()[-1]) 462 | # print('the most Label %s have %s instances' % c.most_common(1)[0]) 463 | # # print(c.most_common()) # [(2, 454), (4, 267), (3, 124), (1, 57), (0, 48)] 464 | # most_num = c.most_common(1)[0][1] 465 | # less_num = c.most_common()[-1][1] 466 | # 467 | # locations = {} 468 | # number = {} 469 | # for lab, num in c.most_common(): 470 | # number[lab] = num 471 | # locations[lab] = np.where(np.array(y_train)==lab)[0] 472 | # # print(locations) 473 | # # print(number) 474 | # if printable: 475 | # print(' convert list to dict') 476 | # X = {} # convert list to dict 477 | # ### a sequence 478 | # for lab, _ in number.items(): 479 | # X[lab] = np.empty(shape=(0,1,n_features,1)) # 4D 480 | # for lab, _ in number.items(): 481 | # #X[lab] = X_train[locations[lab] 482 | # for l in locations[lab]: 483 | # X[lab] = np.vstack((X[lab], X_train[l*sequence_length : (l+1)*(sequence_length)])) 484 | # # X[lab] = X_train[locations[lab]*sequence_length : locations[lab]*(sequence_length+1)] # a sequence 485 | # # print(X) 486 | # 487 | # if model=='oversampling': 488 | # if 
printable: 489 | # print(' oversampling -- most num is %d, all classes tend to be this num\nshuffle applied' % most_num) 490 | # for key in X: 491 | # temp = X[key] 492 | # while True: 493 | # if len(X[key]) >= most_num * sequence_length: # sequence 494 | # break 495 | # X[key] = np.vstack((X[key], temp)) 496 | # # print(key, len(X[key])) 497 | # if printable: 498 | # print(' make each stage have same num of instances') 499 | # for key in X: 500 | # X[key] = X[key][0:most_num*sequence_length,:] # sequence 501 | # if printable: 502 | # print(key, len(X[key])) 503 | # elif model=='downsampling': 504 | # import random 505 | # if printable: 506 | # print(' downsampling -- less num is %d, all classes tend to be this num by randomly choice without replacement\nshuffle applied' % less_num) 507 | # for key in X: 508 | # # print(key, len(X[key]))#, len(X[key])/sequence_length) 509 | # s_idx = [ i for i in range(int(len(X[key])/sequence_length))] 510 | # s_idx = np.asarray(s_idx)*sequence_length # start index of sequnce in X[key] 511 | # # print('s_idx',s_idx) 512 | # r_idx = np.random.choice(s_idx, less_num, replace=False) # random choice less_num of s_idx 513 | # # print('r_idx',r_idx) 514 | # temp = X[key] 515 | # X[key] = np.empty(shape=(0,1,n_features,1)) # 4D 516 | # for idx in r_idx: 517 | # X[key] = np.vstack((X[key], temp[idx:idx+sequence_length])) 518 | # # print(key, X[key]) 519 | # # np.random.choice(l, len(l), replace=False) 520 | # else: 521 | # raise Exception(' model should be oversampling or downsampling') 522 | # 523 | # # convert dict to list 524 | # if printable: 525 | # print(' convert dict to list') 526 | # y_train = [] 527 | # # X_train = np.empty(shape=(0,len(X[0][0]))) 528 | # # X_train = np.empty(shape=(0,len(X[1][0]))) # 2D 529 | # X_train = np.empty(shape=(0,1,n_features,1)) # 4D 530 | # l_key = list(X.keys()) # shuffle 531 | # random.shuffle(l_key) # shuffle 532 | # # for key in X: # no shuffle 533 | # for key in l_key: # shuffle 534 | # X_train 
# X_train = np.vstack( (X_train, X[key] ) )
# # print(len(X[key]))
# y_train.extend([key for i in range(int(len(X[key])/sequence_length))])
# # print(X_train,y_train, type(X_train), type(y_train))
# # ================ End of Classes balancing for sequence
# # print(X_train.shape, len(y_train))
# return X_train, np.asarray(y_train)

# --------------------------------------------------------------------------------
# /tensorlayer/visualize.py
# --------------------------------------------------------------------------------
#! /usr/bin/python
# -*- coding: utf8 -*-

import matplotlib
# Select matplotlib's non-interactive 'Agg' backend before pyplot is imported.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os

## Save images
import scipy.misc


def save_image(image, image_path):
    """Save one image.

    Parameters
    -----------
    image : numpy array [w, h, c]
        The image to write.
    image_path : string
        Destination file path.
    """
    scipy.misc.imsave(image_path, image)


def save_images(images, size, image_path):
    """Save multiple images tiled into one single image.

    Images are placed on a ``size[0]`` x ``size[1]`` grid, filling each row
    left to right before moving to the next row.

    Parameters
    -----------
    images : numpy array [batch, h, w, c], or [batch, h, w] for greyscale
        The images to tile. Single-channel images are broadcast to RGB.
    size : list of two int
        Row and column number of the grid.
        Number of images should be equal or less than size[0] * size[1].
    image_path : string
        Destination file path.

    Raises
    ------
    AssertionError
        If there are more images than grid cells.

    Examples
    ---------
    >>> images = np.random.rand(64, 100, 100, 3)
    >>> tl.visualize.save_images(images, [8, 8], 'temp.png')
    """
    def merge(images, size):
        # The canvas is always 3-channel; 1-channel tiles broadcast into it.
        h, w = images.shape[1], images.shape[2]
        img = np.zeros((h * size[0], w * size[1], 3))
        for idx, image in enumerate(images):
            i = idx % size[1]   # grid column
            j = idx // size[1]  # grid row
            img[j * h:j * h + h, i * w:i * w + w, :] = image
        return img

    def imsave(images, size, path):
        return scipy.misc.imsave(path, merge(images, size))

    images = np.asarray(images)
    if images.ndim == 3:
        # [batch, h, w] -> [batch, h, w, 1] so each tile broadcasts into RGB.
        images = images[:, :, :, np.newaxis]
    assert len(images) <= size[0] * size[1], "number of images should be equal or less than size[0] * size[1] {}".format(len(images))
    return imsave(images, size, image_path)


def W(W=None, second=10, saveable=True, shape=[28,28], name='mnist', fig_idx=2396512):
    """Visualize every column of the weight matrix as a group of greyscale images.

    Each column is L2-normalised, reshaped to ``shape`` and drawn in its own
    sub-plot on a roughly square grid.

    Parameters
    ----------
    W : numpy.array
        The weight matrix, one unit per column; ``W.shape[0]`` must equal
        ``shape[0] * shape[1]``.
    second : int
        The display second(s) for the image(s), if saveable is False.
    saveable : boolean
        Save the figure to ``name + '.pdf'`` if True, otherwise plot it.
    shape : a list with 2 int
        The shape of a feature image, MNIST is [28, 28].
    name : a string
        A name to save the image, if saveable is True.
    fig_idx : int
        matplotlib figure index.

    Examples
    --------
    >>> tl.visualize.W(network.all_params[0].eval(), second=10, saveable=True, name='weight_of_1st_layer', fig_idx=2012)
    """
    if saveable is False:
        plt.ion()
    fig = plt.figure(fig_idx)  # show all feature images
    n_units = W.shape[1]

    # Roughly square grid: about sqrt(n_units) rows, e.g. 25 hidden units -> 5 x 5.
    # float() keeps the ceiling correct under Python 2 integer division.
    num_r = int(np.sqrt(n_units))
    num_c = int(np.ceil(n_units / float(num_r)))
    # num_r * num_c >= n_units, so every unit gets its own sub-plot.
    for count in range(1, n_units + 1):
        fig.add_subplot(num_r, num_c, count)
        column = W[:, count - 1]
        norm = np.sqrt((column ** 2).sum())
        # L2-normalise so all tiles share a comparable scale; an all-zero
        # column is drawn as-is instead of turning into NaNs.
        feature = column / norm if norm > 0 else column
        plt.imshow(np.reshape(feature, (shape[0], shape[1])),
                   cmap='gray', interpolation="nearest")
        plt.gca().xaxis.set_major_locator(plt.NullLocator())  # disable ticks
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
    if saveable:
        plt.savefig(name + '.pdf', format='pdf')
    else:
        plt.draw()
        plt.pause(second)


def frame(I=None, second=5, saveable=True, name='frame', cmap=None, fig_idx=12836):
    """Display a frame(image). Make sure OpenAI Gym render() is disabled before using it.

    Parameters
    ----------
    I : numpy.array
        The image; a trailing single channel, e.g. (10, 10, 1), is dropped.
    second : int
        The display second(s) for the image(s), if saveable is False.
    saveable : boolean
        Save the figure to ``name + '.pdf'`` if True, otherwise plot it.
    name : a string
        Figure title, and the file name if saveable is True.
    cmap : None or string
        'gray' for greyscale, None for default, etc.
    fig_idx : int
        matplotlib figure index.

    Examples
    --------
    >>> env = gym.make("Pong-v0")
    >>> observation = env.reset()
    >>> tl.visualize.frame(observation)
    """
    if saveable is False:
        plt.ion()
    plt.figure(fig_idx)

    if I.ndim == 3 and I.shape[-1] == 1:  # (10,10,1) --> (10,10)
        I = I[:, :, 0]

    plt.imshow(I, cmap)
    plt.title(name)

    if saveable:
        plt.savefig(name + '.pdf', format='pdf')
    else:
        plt.draw()
        plt.pause(second)


def CNN2d(CNN=None, second=10, saveable=True, name='cnn', fig_idx=3119362):
    """Display a group of RGB or Greyscale CNN masks.

    Parameters
    ----------
    CNN : numpy.array
        The masks, shaped (height, width, n_color, n_mask) with n_color 1 or 3,
        e.g. 64 5x5 RGB masks are (5, 5, 3, 64).
    second : int
        The display second(s) for the image(s), if saveable is False.
    saveable : boolean
        Save the figure to ``name + '.pdf'`` if True, otherwise plot it.
    name : a string
        A name to save the image, if saveable is True.
    fig_idx : int
        matplotlib figure index.

    Raises
    ------
    Exception
        If n_color is neither 1 nor 3.

    Examples
    --------
    >>> tl.visualize.CNN2d(network.all_params[0].eval(), second=10, saveable=True, name='cnn1_mnist', fig_idx=2012)
    """
    n_mask = CNN.shape[3]
    n_row = CNN.shape[0]
    n_col = CNN.shape[1]
    n_color = CNN.shape[2]
    if n_color not in (1, 3):
        raise Exception("Unknown n_color")
    row = int(np.sqrt(n_mask))
    col = int(np.ceil(n_mask / float(row)))
    if saveable is False:
        plt.ion()  # active mode
    fig = plt.figure(fig_idx)
    # row * col >= n_mask, so every mask gets its own sub-plot.
    for count in range(1, n_mask + 1):
        # NOTE: the grid is `col` sub-plot rows by `row` sub-plot columns.
        fig.add_subplot(col, row, count)
        mask = CNN[:, :, :, count - 1]
        if n_color == 1:
            mask = np.reshape(mask, (n_row, n_col))
        else:
            mask = np.reshape(mask, (n_row, n_col, n_color))
        plt.imshow(mask, cmap='gray', interpolation="nearest")
        plt.gca().xaxis.set_major_locator(plt.NullLocator())  # disable ticks
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
    if saveable:
        plt.savefig(name + '.pdf', format='pdf')
    else:
        plt.draw()
        plt.pause(second)


def images2d(images=None, second=10, saveable=True, name='images', dtype=None,
             fig_idx=3119362):
    """Display a group of RGB or Greyscale images.

    Parameters
    ----------
    images : numpy.array
        The images, shaped (n_images, height, width, n_color) with n_color 1 or 3.
    second : int
        The display second(s) for the image(s), if saveable is False.
    saveable : boolean
        Save the figure to ``name + '.pdf'`` if True, otherwise plot it.
    name : a string
        A name to save the image, if saveable is True.
    dtype : None or numpy data type
        The data type for displaying the images (None keeps the input dtype).
    fig_idx : int
        matplotlib figure index.

    Raises
    ------
    Exception
        If n_color is neither 1 nor 3.

    Examples
    --------
    >>> X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False)
    >>> tl.visualize.images2d(X_train[0:100,:,:,:], second=10, saveable=False, name='cifar10', dtype=np.uint8, fig_idx=20212)
    """
    # `is not None` rather than truthiness: np.dtype('uint8') has len() 0 and
    # would otherwise be treated as "no dtype given".
    if dtype is not None:
        images = np.asarray(images, dtype=dtype)
    n_mask = images.shape[0]
    n_row = images.shape[1]
    n_col = images.shape[2]
    n_color = images.shape[3]
    if n_color not in (1, 3):
        raise Exception("Unknown n_color")
    row = int(np.sqrt(n_mask))
    col = int(np.ceil(n_mask / float(row)))
    if saveable is False:
        plt.ion()  # active mode
    fig = plt.figure(fig_idx)
    # row * col >= n_mask, so every image gets its own sub-plot.
    for count in range(1, n_mask + 1):
        # NOTE: the grid is `col` sub-plot rows by `row` sub-plot columns.
        fig.add_subplot(col, row, count)
        if n_color == 1:
            picture = np.reshape(images[count - 1, :, :], (n_row, n_col))
        else:
            picture = images[count - 1, :, :]
        plt.imshow(picture, cmap='gray', interpolation="nearest")
        plt.gca().xaxis.set_major_locator(plt.NullLocator())  # disable ticks
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
    if saveable:
        plt.savefig(name + '.pdf', format='pdf')
    else:
        plt.draw()
        plt.pause(second)


def tsne_embedding(embeddings, reverse_dictionary, plot_only=500,
                   second=5, saveable=False, name='tsne', fig_idx=9862):
    """Visualize the embeddings by using t-SNE.

    The first ``plot_only`` rows of ``embeddings`` are projected to 2-D with
    scikit-learn's TSNE and drawn as a labelled scatter plot. If scikit-learn
    cannot be imported, a message is printed and nothing is plotted.

    Parameters
    ----------
    embeddings : a matrix
        The embedding matrix, one row per word id.
    reverse_dictionary : a dictionary
        id_to_word, mapping id to unique word.
    plot_only : int
        The number of rows (ids 0..plot_only-1) to plot, clipped to the number
        of rows in ``embeddings``. Presumably ids are frequency-ranked, so these
        are the most common words.
    second : int
        The display second(s) for the image(s), if saveable is False.
    saveable : boolean
        Save the figure to ``name + '.pdf'`` if True, otherwise plot it.
    name : a string
        A name to save the image, if saveable is True.
    fig_idx : int
        matplotlib figure index.

    Examples
    --------
    >>> see 'tutorial_word2vec_basic.py'
    >>> final_embeddings = normalized_embeddings.eval()
    >>> tl.visualize.tsne_embedding(final_embeddings, reverse_dictionary,
    ...                   plot_only=500, second=5, saveable=False, name='tsne')
    """
    def plot_with_labels(low_dim_embs, labels, figsize=(18, 18), second=5,
                         saveable=True, name='tsne', fig_idx=9862):
        assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings"
        if saveable is False:
            plt.ion()
        # A single figure identified by fig_idx (figsize in inches), cleared so
        # repeated calls do not scatter on top of an earlier plot.
        plt.figure(fig_idx, figsize=figsize)
        plt.clf()
        for i, label in enumerate(labels):
            x, y = low_dim_embs[i, :]
            plt.scatter(x, y)
            plt.annotate(label,
                         xy=(x, y),
                         xytext=(5, 2),
                         textcoords='offset points',
                         ha='right',
                         va='bottom')
        if saveable:
            plt.savefig(name + '.pdf', format='pdf')
        else:
            plt.draw()
            plt.pause(second)

    try:
        from sklearn.manifold import TSNE
    except ImportError:
        print("Please install sklearn and matplotlib to visualize embeddings.")
        return

    try:
        tsne = TSNE(perplexity=30, n_components=2, init='pca', max_iter=5000)
    except TypeError:
        # scikit-learn < 1.5 names this parameter `n_iter`.
        tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
    plot_only = min(plot_only, len(embeddings))
    low_dim_embs = tsne.fit_transform(embeddings[:plot_only, :])
    labels = [reverse_dictionary[i] for i in range(plot_only)]
    plot_with_labels(low_dim_embs, labels, second=second, saveable=saveable,
                     name=name, fig_idx=fig_idx)


#
# --------------------------------------------------------------------------------
# /utils.py
# --------------------------------------------------------------------------------
import scipy
# `import scipy` alone does not load the `scipy.misc` submodule on older SciPy
# releases, so import it explicitly for get_imgs_fn.
import scipy.misc
import numpy as np


def get_imgs_fn(file_name):
    """Read the image at ``file_name`` as an RGB array via ``scipy.misc.imread``."""
    return scipy.misc.imread(file_name, mode='RGB')


def augment_imgs_fn(x, add_noise=True):
    """Augment an image array with small additive noise.

    Parameters
    ----------
    x : numpy array
        Input image(s).
    add_noise : bool
        When True (default), add uniform noise in ``[0, 0.1 * x.std())`` to
        every element; when False, return ``x`` unchanged.
    """
    if not add_noise:
        return x
    return x + 0.1 * x.std() * np.random.random(x.shape)


def normalize_imgs_fn(x):
    """Map pixel values linearly from [0, 255] to [-1, 1]."""
    x = x * (2. / 255.) - 1.
    return x


def truncate_imgs_fn(x):
    """Clip values to the range [-1, 1].

    Unlike ``np.clip``, NaN entries end up as -1, because both comparisons
    below are False for NaN.
    """
    x = np.where(x > -1., x, -1.)
    x = np.where(x < 1., x, 1.)
    return x
# --------------------------------------------------------------------------------