├── .DS_Store ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── apex ├── .gitignore ├── .nojekyll ├── LICENSE ├── README.md ├── apex.patch ├── apex │ ├── RNN │ │ ├── README.md │ │ ├── RNNBackend.py │ │ ├── __init__.py │ │ ├── cells.py │ │ └── models.py │ ├── __init__.py │ ├── amp │ │ ├── README.md │ │ ├── __init__.py │ │ ├── __version__.py │ │ ├── _amp_state.py │ │ ├── _initialize.py │ │ ├── _process_optimizer.py │ │ ├── amp.py │ │ ├── compat.py │ │ ├── frontend.py │ │ ├── handle.py │ │ ├── lists │ │ │ ├── __init__.py │ │ │ ├── functional_overrides.py │ │ │ ├── tensor_overrides.py │ │ │ └── torch_overrides.py │ │ ├── opt.py │ │ ├── rnn_compat.py │ │ ├── scaler.py │ │ ├── utils.py │ │ └── wrap.py │ ├── fp16_utils │ │ ├── README.md │ │ ├── __init__.py │ │ ├── fp16_optimizer.py │ │ ├── fp16util.py │ │ └── loss_scaler.py │ ├── multi_tensor_apply │ │ ├── __init__.py │ │ └── multi_tensor_apply.py │ ├── normalization │ │ ├── __init__.py │ │ └── fused_layer_norm.py │ ├── optimizers │ │ ├── __init__.py │ │ ├── fp16_optimizer.py │ │ └── fused_adam.py │ ├── parallel │ │ ├── LARC.py │ │ ├── README.md │ │ ├── __init__.py │ │ ├── distributed.py │ │ ├── multiproc.py │ │ ├── optimized_sync_batchnorm.py │ │ ├── optimized_sync_batchnorm_kernel.py │ │ ├── sync_batchnorm.py │ │ └── sync_batchnorm_kernel.py │ └── reparameterization │ │ ├── README.md │ │ ├── __init__.py │ │ ├── reparameterization.py │ │ └── weight_norm.py ├── csrc │ ├── amp_C_frontend.cpp │ ├── flatten_unflatten.cpp │ ├── fused_adam_cuda.cpp │ ├── fused_adam_cuda_kernel.cu │ ├── layer_norm_cuda.cpp │ ├── layer_norm_cuda_kernel.cu │ ├── multi_tensor_apply.cuh │ ├── multi_tensor_axpby_kernel.cu │ ├── multi_tensor_l2norm_kernel.cu │ ├── multi_tensor_lamb_stage_1.cu │ ├── multi_tensor_lamb_stage_2.cu │ ├── multi_tensor_scale_kernel.cu │ ├── syncbn.cpp │ ├── type_shim.h │ └── welford.cu ├── docs │ ├── Makefile │ └── source │ │ ├── _static │ │ ├── css │ │ │ └── pytorch_theme.css │ │ └── img │ │ │ └── nv-pytorch2.png │ │ ├── _templates │ │ └── layout.html │ │ ├── advanced.rst │ │ ├── amp.rst │ │ ├── conf.py │ │ ├── fp16_utils.rst │ │ ├── index.rst │ │ ├── layernorm.rst │ │ ├── optimizers.rst │ │ └── parallel.rst ├── examples │ ├── README.md │ ├── dcgan │ │ └── README.md │ ├── docker │ │ ├── Dockerfile │ │ └── README.md │ ├── imagenet │ │ ├── README.md │ │ └── main_amp.py │ └── simple │ │ └── distributed │ │ ├── README.md │ │ ├── distributed_data_parallel.py │ │ └── run.sh ├── setup.py └── tests │ ├── L0 │ ├── run_amp │ │ ├── __init__.py │ │ ├── test_add_param_group.py │ │ ├── test_basic_casts.py │ │ ├── test_cache.py │ │ ├── test_multi_tensor_axpby.py │ │ ├── test_multi_tensor_l2norm.py │ │ ├── test_multi_tensor_scale.py │ │ ├── test_multiple_models_optimizers_losses.py │ │ ├── test_promotion.py │ │ ├── test_rnn.py │ │ └── utils.py │ ├── run_fp16util │ │ ├── __init__.py │ │ └── test_fp16util.py │ ├── run_fused_layer_norm │ │ └── test_fused_layer_norm.py │ ├── run_mixed_adam │ │ ├── __init__.py │ │ ├── test_fp16_optimizer.py │ │ └── test_mixed_adam.py │ └── run_test.py │ ├── L1 │ ├── common │ │ ├── compare.py │ │ ├── main_amp.py │ │ └── run_test.sh │ ├── cross_product │ │ └── run.sh │ └── cross_product_distributed │ │ └── run.sh │ ├── distributed │ ├── DDP │ │ ├── ddp_race_condition_test.py │ │ └── run_race_test.sh │ ├── amp_master_params │ │ ├── amp_master_params.py │ │ ├── compare.py │ │ └── run.sh │ └── synced_batchnorm │ │ ├── single_gpu_unit_test.py │ │ ├── test_groups.py │ │ ├── two_gpu_unit_test.py │ │ └── unit_test.sh │ └── 
docker_extension_builds │ └── run.sh ├── jukebox ├── Interacting_with_Jukebox.ipynb ├── __init__.py ├── align.py ├── data │ ├── __init__.py │ ├── artist_genre_processor.py │ ├── data_processor.py │ ├── files_dataset.py │ ├── ids │ │ ├── v2_artist_ids.txt │ │ ├── v2_genre_ids.txt │ │ ├── v3_artist_ids.txt │ │ └── v3_genre_ids.txt │ ├── labels.py │ └── text_processor.py ├── hparams.py ├── lyricdict.py ├── make_models.py ├── prior │ ├── __init__.py │ ├── autoregressive.py │ ├── conditioners.py │ └── prior.py ├── sample.py ├── save_html.py ├── tests │ └── test_sample.py ├── train.py ├── transformer │ ├── __init__.py │ ├── factored_attention.py │ ├── ops.py │ └── transformer.py ├── utils │ ├── __init__.py │ ├── audio_utils.py │ ├── checkpoint.py │ ├── dist_adapter.py │ ├── dist_utils.py │ ├── ema.py │ ├── fp16.py │ ├── io.py │ ├── logger.py │ ├── remote_utils.py │ ├── sample_utils.py │ └── torch_utils.py └── vqvae │ ├── __init__.py │ ├── bottleneck.py │ ├── encdec.py │ ├── resnet.py │ └── vqvae.py ├── requirements.txt ├── setup.py └── tensorboardX ├── .codecov.yml ├── .flake8 ├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature-requests-or-general-questions.md ├── .gitignore ├── .travis.yml ├── HISTORY.rst ├── LICENSE ├── MANIFEST.in ├── README.md ├── compile.sh ├── docs ├── Makefile ├── conf.py ├── index.rst ├── tensorboard.rst ├── tutorial.rst ├── tutorial_zh.rst └── utils.rst ├── examples ├── RUN_AFTER_PIP_INSTALL ├── __init__.py ├── chainer │ ├── extension_logger │ │ ├── net.py │ │ ├── train_dcgan.py │ │ ├── updater.py │ │ ├── visualize.py │ │ └── writetensorboard.py │ └── plain_logger │ │ ├── data.py │ │ ├── net.py │ │ └── train_vae.py ├── demo.py ├── demo_beholder.py ├── demo_caffe2.py ├── demo_custom_scalars.py ├── demo_embedding.py ├── demo_graph.py ├── demo_hparams.py ├── demo_matplotlib.py ├── demo_multiple_embedding.py ├── demo_nvidia_smi.py ├── demo_onnx.py ├── demo_purge.py └── tensorboardX ├── screenshots ├── Demo.gif ├── audio.png ├── distribution.png ├── embedding.png ├── graph.png ├── histogram.png ├── image.png ├── scalar.png └── text.png ├── setup.cfg ├── setup.py ├── tensorboardX.patch ├── tensorboardX ├── __init__.py ├── beholder │ ├── __init__.py │ ├── beholder.py │ ├── file_system_tools.py │ ├── shared_config.py │ └── video_writing.py ├── caffe2_graph.py ├── crc32c.py ├── embedding.py ├── event_file_writer.py ├── onnx_graph.py ├── proto │ ├── __init__.py │ ├── api.proto │ ├── api_pb2.py │ ├── attr_value.proto │ ├── attr_value_pb2.py │ ├── event.proto │ ├── event_pb2.py │ ├── graph.proto │ ├── graph_pb2.py │ ├── layout.proto │ ├── layout_pb2.py │ ├── node_def.proto │ ├── node_def_pb2.py │ ├── plugin_hparams.proto │ ├── plugin_hparams_pb2.py │ ├── plugin_mesh.proto │ ├── plugin_mesh_pb2.py │ ├── plugin_pr_curve.proto │ ├── plugin_pr_curve_pb2.py │ ├── plugin_text.proto │ ├── plugin_text_pb2.py │ ├── resource_handle.proto │ ├── resource_handle_pb2.py │ ├── step_stats.proto │ ├── step_stats_pb2.py │ ├── summary.proto │ ├── summary_pb2.py │ ├── tensor.proto │ ├── tensor_pb2.py │ ├── tensor_shape.proto │ ├── tensor_shape_pb2.py │ ├── types.proto │ ├── types_pb2.py │ ├── versions.proto │ └── versions_pb2.py ├── proto_graph.py ├── pytorch_graph.py ├── record_writer.py ├── summary.py ├── torchvis.py ├── utils.py ├── visdom_writer.py ├── writer.py └── x2num.py └── tests ├── __init__.py ├── event_file_writer_test.py ├── expect ├── caffe_mnist.expect ├── caffe_overfeat.expect ├── test_caffe2.test_simple_cnnmodel.expect ├── test_caffe2.test_simple_model.expect ├── 
test_pr_curve.test_pr_purve.expect ├── test_pr_curve.test_pr_purve_raw.expect ├── test_summary.test_audio.expect ├── test_summary.test_custom_scalars.expect ├── test_summary.test_float32_image.expect ├── test_summary.test_histogram_auto.expect ├── test_summary.test_histogram_doane.expect ├── test_summary.test_histogram_fd.expect ├── test_summary.test_hparams.expect ├── test_summary.test_image_with_3_channel_batched.expect ├── test_summary.test_image_with_boxes.expect ├── test_summary.test_image_with_four_channel.expect ├── test_summary.test_image_with_four_channel_batched.expect ├── test_summary.test_image_with_one_channel.expect ├── test_summary.test_image_with_one_channel_batched.expect ├── test_summary.test_image_without_channel.expect ├── test_summary.test_mesh.expect ├── test_summary.test_text.expect ├── test_summary.test_uint8_image.expect └── test_summary.test_video.expect ├── expect_reader.py ├── record_writer_test.py ├── test_beholder.py ├── test_caffe2.py ├── test_chainer_np.py ├── test_crc32c.py ├── test_embedding.py ├── test_figure.py ├── test_numpy.py ├── test_onnx_graph.py ├── test_pr_curve.py ├── test_pytorch_graph.py ├── test_pytorch_np.py ├── test_record_writer.py ├── test_summary.py ├── test_summary_writer.py ├── test_test.py ├── test_utils.py ├── test_visdom.py └── test_writer.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Global 2 | .DS_Store 3 | .idea 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 
96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Noncommercial Use License 2 | 3 | Software Copyright (c) 2020 OpenAI 4 | 5 | We don’t claim ownership of the content you create with Jukebox. 6 | We only ask that you use Jukebox responsibly and clearly indicate your content was created using OpenAI’s Jukebox. 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 9 | documentation files (the "Software"), to deal in the Software, including without limitation the rights to use, copy, 10 | modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the 11 | Software is furnished to do so, subject to the following conditions: 12 | 13 | No portion of the Software, nor any content created with the Software, may be used for commercial purposes. 14 | 15 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 16 | 17 | The above copyright notice and this permission notice need not be included with content created by the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE 20 | WARRANTIES OF MERCHANTABILITY,FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 21 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 22 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include jukebox *.py 2 | recursive-include jukebox *.txt 3 | -------------------------------------------------------------------------------- /apex/.gitignore: -------------------------------------------------------------------------------- 1 | apex.egg-info 2 | dist 3 | build 4 | docs/build 5 | *~ -------------------------------------------------------------------------------- /apex/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/apex/.nojekyll -------------------------------------------------------------------------------- /apex/LICENSE: -------------------------------------------------------------------------------- 1 | All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. 
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /apex/apex.patch: -------------------------------------------------------------------------------- 1 | diff --git a/csrc/fused_adam_cuda_kernel.cu b/csrc/fused_adam_cuda_kernel.cu 2 | index 34f7aa2..95581d1 100644 3 | --- a/csrc/fused_adam_cuda_kernel.cu 4 | +++ b/csrc/fused_adam_cuda_kernel.cu 5 | @@ -19,8 +19,8 @@ typedef enum{ 6 | 7 | template 8 | __global__ void adam_cuda_kernel( 9 | - T* __restrict__ p, 10 | - GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed 11 | + GRAD_T* __restrict__ p, 12 | + T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed 13 | T* __restrict__ m, 14 | T* __restrict__ v, 15 | const GRAD_T * __restrict__ g, 16 | @@ -50,7 +50,7 @@ __global__ void adam_cuda_kernel( 17 | else // Mode 1 18 | denom = sqrtf(v[j]) + eps; 19 | float update = (m[j]/denom) + (decay*p[j]); 20 | - p[j] = p[j] - (step_size*update); 21 | + p[j] = (GRAD_T) (p[j] - (step_size*update)); 22 | if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j]; 23 | } 24 | } 25 | @@ -93,14 +93,14 @@ void fused_adam_cuda( 26 | 27 | if (g.scalar_type() == at::ScalarType::Half) { 28 | //all other values should be fp32 for half gradients 29 | - AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); 30 | +// AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); 31 | //dispatch is done on the gradient type 32 | using namespace at; // prevents "toString is undefined" errors 33 | DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", 34 | using accscalar_t = at::acc_type; 35 | adam_cuda_kernel<<>>( 36 | - p.data(), 37 | - p_copy.numel() ? p_copy.data() : NULL, 38 | + p.data(), 39 | + NULL, //don't output p_copy for fp32, it's wasted write 40 | m.data(), 41 | v.data(), 42 | g.data(), 43 | -------------------------------------------------------------------------------- /apex/apex/RNN/README.md: -------------------------------------------------------------------------------- 1 | Under construction... 
2 | -------------------------------------------------------------------------------- /apex/apex/RNN/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import LSTM, GRU, ReLU, Tanh, mLSTM 2 | 3 | __all__ = ['models'] 4 | -------------------------------------------------------------------------------- /apex/apex/RNN/cells.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .RNNBackend import RNNCell 6 | 7 | from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend 8 | 9 | import math 10 | 11 | 12 | class mLSTMRNNCell(RNNCell): 13 | """ 14 | mLSTMRNNCell 15 | """ 16 | 17 | def __init__(self, input_size, hidden_size, bias = False, output_size = None): 18 | gate_multiplier = 4 19 | super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states = 2, bias = bias, output_size = output_size) 20 | 21 | self.w_mih = nn.Parameter(torch.Tensor(self.output_size, self.input_size)) 22 | self.w_mhh = nn.Parameter(torch.Tensor(self.output_size, self.output_size)) 23 | 24 | self.reset_parameters() 25 | 26 | def forward(self, input): 27 | """ 28 | mLSTMRNNCell.forward() 29 | """ 30 | #if not inited or bsz has changed this will create hidden states 31 | self.init_hidden(input.size()[0]) 32 | 33 | hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden 34 | 35 | self.hidden = list( 36 | self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_mih, self.w_mhh, 37 | b_ih=self.b_ih, b_hh=self.b_hh) 38 | ) 39 | 40 | if self.output_size != self.hidden_size: 41 | self.hidden[0] = F.linear(self.hidden[0], self.w_ho) 42 | return tuple(self.hidden) 43 | 44 | 45 | def new_like(self, new_input_size=None): 46 | if new_input_size is None: 47 | new_input_size = self.input_size 48 | 49 | return type(self)( 50 | new_input_size, 51 | self.hidden_size, 52 | self.bias, 53 | self.output_size) 54 | 55 | def mLSTMCell(input, hidden, w_ih, w_hh, w_mih, w_mhh, b_ih=None, b_hh=None): 56 | """ 57 | mLSTMCell 58 | """ 59 | 60 | if input.is_cuda: 61 | igates = F.linear(input, w_ih) 62 | m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh) 63 | hgates = F.linear(m, w_hh) 64 | 65 | state = fusedBackend.LSTMFused.apply 66 | return state(igates, hgates, hidden[1], b_ih, b_hh) 67 | 68 | hx, cx = hidden 69 | 70 | m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh) 71 | gates = F.linear(input, w_ih, b_ih) + F.linear(m, w_hh, b_hh) 72 | 73 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 74 | 75 | ingate = F.sigmoid(ingate) 76 | forgetgate = F.sigmoid(forgetgate) 77 | cellgate = F.tanh(cellgate) 78 | outgate = F.sigmoid(outgate) 79 | 80 | cy = (forgetgate * cx) + (ingate * cellgate) 81 | hy = outgate * F.tanh(cy) 82 | 83 | return hy, cy 84 | 85 | -------------------------------------------------------------------------------- /apex/apex/RNN/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch.nn._functions.rnn import LSTMCell, RNNReLUCell, RNNTanhCell, GRUCell 4 | 5 | from .RNNBackend import bidirectionalRNN, stackedRNN, RNNCell 6 | from .cells import mLSTMRNNCell, mLSTMCell 7 | 8 | def toRNNBackend(inputRNN, num_layers, bidirectional=False, dropout = 0): 9 | """ 10 | :class:`toRNNBackend` 11 | """ 12 | 13 | if bidirectional: 14 | return bidirectionalRNN(inputRNN, num_layers, dropout = dropout) 15 
| else: 16 | return stackedRNN(inputRNN, num_layers, dropout = dropout) 17 | 18 | 19 | def LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): 20 | """ 21 | :class:`LSTM` 22 | """ 23 | inputRNN = RNNCell(4, input_size, hidden_size, LSTMCell, 2, bias, output_size) 24 | return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) 25 | 26 | def GRU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): 27 | """ 28 | :class:`GRU` 29 | """ 30 | inputRNN = RNNCell(3, input_size, hidden_size, GRUCell, 1, bias, output_size) 31 | return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) 32 | 33 | def ReLU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): 34 | """ 35 | :class:`ReLU` 36 | """ 37 | inputRNN = RNNCell(1, input_size, hidden_size, RNNReLUCell, 1, bias, output_size) 38 | return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) 39 | 40 | def Tanh(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): 41 | """ 42 | :class:`Tanh` 43 | """ 44 | inputRNN = RNNCell(1, input_size, hidden_size, RNNTanhCell, 1, bias, output_size) 45 | return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) 46 | 47 | def mLSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): 48 | """ 49 | :class:`mLSTM` 50 | """ 51 | inputRNN = mLSTMRNNCell(input_size, hidden_size, bias=bias, output_size=output_size) 52 | return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) 53 | 54 | 55 | -------------------------------------------------------------------------------- /apex/apex/__init__.py: -------------------------------------------------------------------------------- 1 | from . import parallel 2 | from . import amp 3 | from . import fp16_utils 4 | 5 | # For optimizers and normalization there is no Python fallback. 6 | # Absence of cuda backend is a hard error. 7 | # I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda 8 | # to be triggered lazily, because if someone has installed with --cpp_ext and --cuda_ext 9 | # so they expect those backends to be available, but for some reason they actually aren't 10 | # available (for example because they built improperly in a way that isn't revealed until 11 | # load time) the error message is timely and visible. 12 | from . import optimizers 13 | from . import normalization 14 | -------------------------------------------------------------------------------- /apex/apex/amp/README.md: -------------------------------------------------------------------------------- 1 | # amp: Automatic Mixed Precision 2 | 3 | ## Annotating User Functions 4 | 5 | Nearly all PyTorch user code needs nothing more than the two steps 6 | above to use amp. After all, custom layers are built out of simpler 7 | PyTorch components, and amp already can see those. 8 | 9 | However, any custom C++ or CUDA code is outside of amp's (default) 10 | view of things. 
For example, suppose I implemented a new recurrent 11 | cell called a "forgetful recurrent unit" that calls directly into a 12 | CUDA backend: 13 | 14 | ```python 15 | from backend import FRUBackend 16 | 17 | def fru(input, hidden, weight, bias): 18 | # call to CUDA code 19 | FRUBackend(input, hidden, weight, bias) 20 | ``` 21 | 22 | In this case, it is possible to get a runtime type mismatch. For 23 | example, you might have `input` in fp16, and `weight` in fp32, and amp 24 | doesn't have the visibility to insert an appropriate cast. 25 | 26 | amp exposes two ways to handle "invisible" backend code: function 27 | annotations and explicit registration. 28 | 29 | #### Function annotation 30 | 31 | The first way to handle backend code is a set of function annotations: 32 | 33 | - `@amp.half_function` 34 | - `@amp.float_function` 35 | - `@amp.promote_function` 36 | 37 | These correspond to: 38 | 39 | - Cast all arguments to fp16 40 | - Cast all arguments to fp32 41 | - If there are any type mismatches, cast everything to the widest type 42 | 43 | In our example, we believe that the FRU unit is fp16-safe and will get 44 | performance gains from casting its arguments to fp16, so we write: 45 | 46 | ```python 47 | @amp.half_function 48 | def fru(input, hidden, weight, bias): 49 | #... 50 | ``` 51 | 52 | #### Explicit registration 53 | 54 | The other way to handle backend code is with explicit function 55 | registration: 56 | 57 | - `amp.register_half_function(module, function_name)` 58 | - `amp.register_float_function(module, function_name)` 59 | - `amp.register_promote_function(module, function_name)` 60 | 61 | When using this API, `module` is the containing class or module for 62 | the function, and `function_name` is the _string_ name of the 63 | function. Note that the function must be registered before the call to 64 | `amp.initialize()`. 65 | 66 | For our FRU unit, we can register the backend function directly: 67 | 68 | ```python 69 | import backend 70 | 71 | amp.register_half_function(backend, 'FRUBackend') 72 | ``` 73 | -------------------------------------------------------------------------------- /apex/apex/amp/__init__.py: -------------------------------------------------------------------------------- 1 | from .amp import init, half_function, float_function, promote_function,\ 2 | register_half_function, register_float_function, register_promote_function 3 | from .handle import scale_loss, disable_casts 4 | from .frontend import initialize 5 | from ._amp_state import master_params, _amp_state 6 | -------------------------------------------------------------------------------- /apex/apex/amp/__version__.py: -------------------------------------------------------------------------------- 1 | VERSION = (0, 1, 0) 2 | __version__ = '.'.join(map(str, VERSION)) 3 | -------------------------------------------------------------------------------- /apex/apex/amp/_amp_state.py: -------------------------------------------------------------------------------- 1 | # This is a "header object" that allows different amp modules to communicate. 2 | # I'm a C++ guy, not a python guy. I decided this approach because it seemed most C++-like. 
3 | # But apparently it's ok: 4 | # http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm 5 | import os 6 | import torch 7 | 8 | TORCH_MAJOR = int(torch.__version__.split('.')[0]) 9 | TORCH_MINOR = int(torch.__version__.split('.')[1]) 10 | 11 | if TORCH_MAJOR == 0: 12 | import collections.abc as container_abcs 13 | else: 14 | from torch._six import container_abcs 15 | 16 | 17 | class AmpState(object): 18 | def __init__(self): 19 | self.hard_override=False 20 | self.allow_incoming_model_not_fp32 = False 21 | self.verbosity=1 22 | 23 | 24 | # Attribute stash. Could also just stash things as global module attributes. 25 | _amp_state = AmpState() 26 | 27 | 28 | def warn_or_err(msg): 29 | if _amp_state.hard_override: 30 | print("Warning: " + msg) 31 | else: 32 | raise RuntimeError(msg) 33 | # I'm not sure if allowing hard_override is a good idea. 34 | # + " If you're sure you know what you're doing, supply " + 35 | # "hard_override=True to amp.initialize.") 36 | 37 | 38 | distributed = False 39 | if 'WORLD_SIZE' in os.environ: 40 | distributed = int(os.environ['WORLD_SIZE']) > 1 41 | 42 | 43 | def maybe_print(msg, rank0=False): 44 | if _amp_state.verbosity > 0: 45 | if rank0: 46 | if distributed: 47 | if torch.distributed.get_rank() == 0: 48 | print(msg) 49 | else: 50 | print(msg) 51 | else: 52 | print(msg) 53 | 54 | 55 | # def iter_params(param_groups): 56 | # for group in param_groups: 57 | # for p in group['params']: 58 | # yield p 59 | 60 | 61 | def master_params(optimizer): 62 | """ 63 | Generator expression that iterates over the params owned by ``optimizer``. 64 | 65 | Args: 66 | optimizer: An optimizer previously returned from ``amp.initialize``. 67 | """ 68 | for group in optimizer.param_groups: 69 | for p in group['params']: 70 | yield p 71 | -------------------------------------------------------------------------------- /apex/apex/amp/compat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # True for post-0.4, when Variables/Tensors merged. 4 | def variable_is_tensor(): 5 | v = torch.autograd.Variable() 6 | return isinstance(v, torch.Tensor) 7 | 8 | def tensor_is_variable(): 9 | x = torch.Tensor() 10 | return type(x) == torch.autograd.Variable 11 | 12 | # False for post-0.4 13 | def tensor_is_float_tensor(): 14 | x = torch.Tensor() 15 | return type(x) == torch.FloatTensor 16 | 17 | # Akin to `torch.is_tensor`, but returns True for Variable 18 | # objects in pre-0.4. 19 | def is_tensor_like(x): 20 | return torch.is_tensor(x) or isinstance(x, torch.autograd.Variable) 21 | 22 | # Wraps `torch.is_floating_point` if present, otherwise checks 23 | # the suffix of `x.type()`. 
24 | def is_floating_point(x): 25 | if hasattr(torch, 'is_floating_point'): 26 | return torch.is_floating_point(x) 27 | try: 28 | torch_type = x.type() 29 | return torch_type.endswith('FloatTensor') or \ 30 | torch_type.endswith('HalfTensor') or \ 31 | torch_type.endswith('DoubleTensor') 32 | except AttributeError: 33 | return False 34 | 35 | def scalar_python_val(x): 36 | if hasattr(x, 'item'): 37 | return x.item() 38 | else: 39 | if isinstance(x, torch.autograd.Variable): 40 | return x.data[0] 41 | else: 42 | return x[0] 43 | -------------------------------------------------------------------------------- /apex/apex/amp/lists/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/apex/apex/amp/lists/__init__.py -------------------------------------------------------------------------------- /apex/apex/amp/lists/functional_overrides.py: -------------------------------------------------------------------------------- 1 | 2 | # TODO: think about the following two. They do weird things. 3 | # - torch.nn.utils.clip_grad (but it should always be fp32 anyway) 4 | # - torch.nn.utils.weight_norm 5 | 6 | # Notes: 7 | # F.instance_norm uses batch_norm internally. Which correctly handles 8 | # fp16 in/out with fp32 weights. So we shouldn't do anything for 9 | # either of these. 10 | # F.normalize calls `input.norm()` internally, so it's redundant, but 11 | # kept here in case impl. changes. 12 | # F.cosine_similarity is same: calls `x.norm()` internally. 13 | 14 | import torch.nn.functional 15 | 16 | MODULE = torch.nn.functional 17 | 18 | FP16_FUNCS = [ 19 | 'conv1d', 20 | 'conv2d', 21 | 'conv3d', 22 | 'conv_transpose1d', 23 | 'conv_transpose2d', 24 | 'conv_transpose3d', 25 | 'conv_tbc', # Undocumented / maybe new? 26 | 'linear', 27 | ] 28 | 29 | FP32_FUNCS = [ 30 | 31 | # Interpolation/Upsampling 32 | 'interpolate', 33 | 34 | # Pointwise 35 | 'softplus', 36 | 'softmin', 37 | 'log_softmax', 38 | 'softmax', 39 | 40 | # Normalization 41 | 'layer_norm', 42 | 'group_norm', 43 | 'local_response_norm', 44 | 'normalize', 45 | 'cosine_similarity', 46 | 47 | # Loss functions 48 | # TODO: which of these can be fp16? 49 | 'poisson_nll_loss', 50 | 'cosine_embedding_loss', 51 | 'cross_entropy', 52 | 'hinge_embedding_loss', 53 | 'kl_div', 54 | 'l1_loss', 55 | 'mse_loss', 56 | 'margin_ranking_loss', 57 | 'multilabel_margin_loss', 58 | 'multilabel_soft_margin_loss', 59 | 'multi_margin_loss', 60 | 'nll_loss', 61 | 'binary_cross_entropy_with_logits', 62 | 'smooth_l1_loss', 63 | 'soft_margin_loss', 64 | 'triplet_margin_loss' 65 | ] 66 | 67 | BANNED_FUNCS = [ 68 | ('binary_cross_entropy', 69 | ("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` " 70 | "It requires that the output of the previous function be already a FloatTensor. \n\n" 71 | "Most models have a Sigmoid right before BCELoss. 
In that case, you can use\n" 72 | " torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer " 73 | "that is compatible with amp.\nAnother option is to add\n" 74 | " amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n" 75 | "If you _really_ know what you are doing, you can disable this warning by passing " 76 | "allow_banned=True to `amp.init()`.")) 77 | ] 78 | -------------------------------------------------------------------------------- /apex/apex/amp/lists/tensor_overrides.py: -------------------------------------------------------------------------------- 1 | from .. import compat 2 | from . import torch_overrides 3 | 4 | import importlib 5 | 6 | import torch 7 | 8 | if compat.variable_is_tensor() and not compat.tensor_is_variable(): 9 | MODULE = torch.Tensor 10 | else: 11 | MODULE = torch.autograd.Variable 12 | 13 | 14 | FP16_FUNCS = [ 15 | '__matmul__', 16 | ] 17 | 18 | FP32_FUNCS = [ 19 | '__ipow__', 20 | '__pow__', 21 | '__rpow__', 22 | 23 | # Cast to fp32 before transfer to CPU 24 | 'cpu', 25 | ] 26 | 27 | CASTS = [ 28 | '__add__', 29 | '__div__', 30 | '__eq__', 31 | '__ge__', 32 | '__gt__', 33 | '__iadd__', 34 | '__idiv__', 35 | '__imul__', 36 | '__isub__', 37 | '__itruediv__', 38 | '__le__', 39 | '__lt__', 40 | '__mul__', 41 | '__ne__', 42 | '__radd__', 43 | '__rdiv__', 44 | '__rmul__', 45 | '__rsub__', 46 | '__rtruediv__', 47 | '__sub__', 48 | '__truediv__', 49 | ] 50 | 51 | # None of these, but here to make code cleaner. 52 | SEQUENCE_CASTS = [] 53 | 54 | # We need to grab all the methods from torch_overrides and add them to 55 | # the Tensor lists as well, as almost all methods are duplicated 56 | # between `torch` and `torch.Tensor` (and check with `hasattr`, 57 | # because a few random ones aren't defined on Tensor) 58 | _self_mod = importlib.import_module(__name__) 59 | for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']: 60 | lst = getattr(_self_mod, attrname) 61 | for fn in getattr(torch_overrides, attrname): 62 | if hasattr(MODULE, fn): 63 | lst.append(fn) 64 | -------------------------------------------------------------------------------- /apex/apex/amp/lists/torch_overrides.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .. import utils 4 | 5 | MODULE = torch 6 | 7 | FP16_FUNCS = [ 8 | # Low level functions wrapped by torch.nn layers. 9 | # The wrapper layers contain the weights which are then passed in as a parameter 10 | # to these functions. 11 | 'conv1d', 12 | 'conv2d', 13 | 'conv3d', 14 | 'conv_transpose1d', 15 | 'conv_transpose2d', 16 | 'conv_transpose3d', 17 | 'conv_tbc', 18 | 'prelu', 19 | 20 | # BLAS 21 | 'addmm', 22 | 'addmv', 23 | 'addr', 24 | 'matmul', 25 | 'mm', 26 | 'mv', 27 | ] 28 | 29 | FP32_FUNCS = [ 30 | # Pointwise 31 | 'acos', 32 | 'asin', 33 | 'cosh', 34 | 'erfinv', 35 | 'exp', 36 | 'expm1', 37 | 'log', 38 | 'log10', 39 | 'log2', 40 | 'reciprocal', 41 | 'rsqrt', 42 | 'sinh', 43 | 'tan', 44 | 45 | # Other math 46 | 'pow', 47 | 48 | # Reduction 49 | 'cumprod', 50 | 'cumsum', 51 | 'dist', 52 | 'mean', 53 | 'norm', 54 | 'prod', 55 | 'std', 56 | 'sum', 57 | 'var', 58 | 59 | # Misc 60 | 'renorm' 61 | ] 62 | 63 | # Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We 64 | # check the CUDA version -- if at least 9.1, then put the bmm 65 | # functions on the fp16 list. Otherwise, put them on the fp32 list. 
66 | _bmms = ['addbmm', 67 | 'baddbmm', 68 | 'bmm'] 69 | if utils.get_cuda_version() >= (9, 1, 0): 70 | FP16_FUNCS.extend(_bmms) 71 | else: 72 | FP32_FUNCS.extend(_bmms) 73 | 74 | # Multi-tensor fns that may need type promotion 75 | CASTS = [ 76 | # Multi-tensor math 77 | 'addcdiv', 78 | 'addcmul', 79 | 'atan2', 80 | 'cross', 81 | 'bilinear', 82 | 83 | # Element-wise _or_ tensor-wise math 84 | 'add', 85 | 'div', 86 | 'mul', 87 | 88 | # Comparison 89 | 'eq', 90 | 'equal', 91 | 'ge', 92 | 'gt', 93 | 'le', 94 | 'lt', 95 | 'ne' 96 | ] 97 | 98 | # Functions that take sequence arguments. We need to inspect the whole 99 | # sequence and cast to the widest type. 100 | SEQUENCE_CASTS = [ 101 | 'cat', 102 | 'stack' 103 | ] 104 | -------------------------------------------------------------------------------- /apex/apex/amp/rnn_compat.py: -------------------------------------------------------------------------------- 1 | from . import utils, wrap 2 | 3 | import torch 4 | _VF = torch._C._VariableFunctions 5 | RNN_NAMES = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm'] 6 | 7 | def _gen_VF_wrapper(name): 8 | def wrapper(*args, **kwargs): 9 | return getattr(_VF, name)(*args, **kwargs) 10 | return wrapper 11 | 12 | # Some python magic to generate an object that has the rnn cell functions 13 | # defined on it, all of which call into corresponding _VF version. 14 | # Intended to patch torch.nn.modules.rnn._VF (aka, the ref named "_VF" 15 | # imported at module scope within torch.nn.modules.rnn). This should 16 | # not affect third-party importers of _VF.py. 17 | class VariableFunctionsShim(object): 18 | def __init__(self): 19 | for name in RNN_NAMES: 20 | for suffix in ['', '_cell']: 21 | fn_name = name + suffix 22 | setattr(self, fn_name, _gen_VF_wrapper(fn_name)) 23 | 24 | def has_old_rnns(): 25 | try: 26 | torch.nn.backends.thnn.backend.LSTMCell 27 | return True 28 | except: 29 | return False 30 | 31 | def whitelist_rnn_cells(handle, verbose): 32 | # Different module + function names in old/new RNN cases 33 | if has_old_rnns(): 34 | fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell'] 35 | mod = torch.nn.backends.thnn.backend 36 | else: 37 | fn_names = [x + '_cell' for x in RNN_NAMES] 38 | mod = torch.nn.modules.rnn._VF 39 | assert isinstance(mod, VariableFunctionsShim) 40 | 41 | # Insert casts on cell functions 42 | for fn in fn_names: 43 | wrap.cached_cast(mod, fn, utils.maybe_half, handle, 44 | try_caching=True, verbose=verbose) 45 | 46 | if has_old_rnns(): 47 | # Special handling of `backward` for fused gru / lstm: 48 | # The `backward` method calls Tensor.sum() (blacklist) internally, 49 | # and then the resulting grad_input has the wrong type. 50 | # TODO: where else is this a problem? 51 | for rnn_type in ['GRUFused', 'LSTMFused']: 52 | mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type) 53 | wrap.disable_casts(mod, 'backward', handle) 54 | -------------------------------------------------------------------------------- /apex/apex/fp16_utils/README.md: -------------------------------------------------------------------------------- 1 | fp16_optimizer.py contains `FP16_Optimizer`, a Python class designed to wrap an existing Pytorch optimizer and automatically enable master parameters and loss scaling in a manner transparent to the user. To use `FP16_Optimizer`, only two lines of one's Python model need to change. 
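As a minimal sketch of that two-line change (the model, `loader`, and hyperparameters below are illustrative stand-ins, not part of the API):

```python
import torch
from apex.fp16_utils import FP16_Optimizer, network_to_half

# network_to_half casts the model to fp16 (keeping batchnorm in fp32).
model = network_to_half(torch.nn.Linear(512, 512)).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)  # change 1: wrap the optimizer

for inputs, targets in loader:  # `loader` is a hypothetical iterable of fp32 CUDA tensors
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs).float(), targets)
    optimizer.backward(loss)  # change 2: replaces loss.backward(), applying loss scaling
    optimizer.step()  # steps the fp32 master params, then copies them back to the fp16 model
```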
2 | 3 | #### [FP16_Optimizer API documentation](https://nvidia.github.io/apex/fp16_utils.html#automatic-management-of-master-params-loss-scaling) 4 | 5 | #### [Simple examples with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple) 6 | 7 | #### [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) 8 | 9 | #### [word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model) 10 | 11 | 12 | fp16util.py contains a number of utilities to manually manage master parameters and loss scaling, if the user chooses. 13 | 14 | #### [Manual management documentation](https://nvidia.github.io/apex/fp16_utils.html#manual-master-parameter-management) 15 | 16 | The [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) and [word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model) directories also contain `main.py` files that demonstrate manual management of master parameters and static loss scaling. These examples illustrate what sort of operations `FP16_Optimizer` is performing automatically. 17 | -------------------------------------------------------------------------------- /apex/apex/fp16_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .fp16util import ( 2 | BN_convert_float, 3 | network_to_half, 4 | prep_param_lists, 5 | model_grads_to_master_grads, 6 | master_params_to_model_params, 7 | tofp16, 8 | to_python_float, 9 | clip_grad_norm, 10 | convert_module, 11 | convert_network, 12 | FP16Model, 13 | ) 14 | 15 | from .fp16_optimizer import FP16_Optimizer 16 | from .loss_scaler import LossScaler, DynamicLossScaler 17 | -------------------------------------------------------------------------------- /apex/apex/multi_tensor_apply/__init__.py: -------------------------------------------------------------------------------- 1 | from .multi_tensor_apply import MultiTensorApply 2 | 3 | multi_tensor_applier = MultiTensorApply(2048*32) 4 | 5 | -------------------------------------------------------------------------------- /apex/apex/multi_tensor_apply/multi_tensor_apply.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class MultiTensorApply(object): 4 | available = False 5 | warned = False 6 | 7 | def __init__(self, chunk_size): 8 | try: 9 | import amp_C 10 | MultiTensorApply.available = True 11 | self.chunk_size = chunk_size 12 | except ImportError as err: 13 | MultiTensorApply.available = False 14 | MultiTensorApply.import_err = err 15 | 16 | def check_avail(self): 17 | if MultiTensorApply.available == False: 18 | raise RuntimeError( 19 | "Attempted to call MultiTensorApply method, but MultiTensorApply " 20 | "is not available, possibly because Apex was installed without " 21 | "--cpp_ext --cuda_ext. 
"Original import error message:", 22 | MultiTensorApply.import_err) 23 | 24 | def __call__(self, op, noop_flag_buffer, tensor_lists, *args): 25 | self.check_avail() 26 | 27 | return op(self.chunk_size, 28 | noop_flag_buffer, 29 | tensor_lists, 30 | *args) 31 | -------------------------------------------------------------------------------- /apex/apex/normalization/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_layer_norm import FusedLayerNorm 2 | -------------------------------------------------------------------------------- /apex/apex/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_adam import FusedAdam 2 | from .fp16_optimizer import FP16_Optimizer 3 | -------------------------------------------------------------------------------- /apex/apex/parallel/README.md: -------------------------------------------------------------------------------- 1 | ## Distributed Data Parallel 2 | 3 | distributed.py contains the source code for `apex.parallel.DistributedDataParallel`, a module wrapper that enables multi-process multi-GPU data parallel training optimized for NVIDIA's NCCL communication library. 4 | 5 | `apex.parallel.DistributedDataParallel` achieves high performance by overlapping communication with 6 | computation in the backward pass and bucketing smaller transfers to reduce the total number of 7 | transfers required. 8 | 9 | multiproc.py contains the source code for `apex.parallel.multiproc`, a launch utility that places one process on each of the node's available GPUs. 10 | 11 | #### [API Documentation](https://nvidia.github.io/apex/parallel.html) 12 | 13 | #### [Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/distributed) 14 | 15 | #### [Imagenet example with Mixed Precision](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) 16 | 17 | #### [Simple example with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple/distributed_apex) 18 | 19 | ### Synchronized Batch Normalization 20 | 21 | `apex.parallel.SyncBatchNorm` has an API similar to `torch.nn.BatchNorm*N*d`. 22 | It reduces stats on the first (channel) dimension of the Tensor and accepts 23 | arbitrary spatial dimensions. 24 | 25 | #### Installation 26 | 27 | Apex provides two sync BN implementations: 28 | 29 | 1. A Python-only implementation, which is the default 30 | when installing with `python setup.py install`. 31 | It uses PyTorch primitive operations and the distributed communication package from 32 | `torch.distributed`. 33 | 34 | - _The Python-only implementation requires the input tensor to be of the same data type as the 35 | layer_ 36 | 37 | 2. An implementation with custom kernels through a CUDA/C++ extension, with 38 | improved performance. We are experimenting with Welford and Kahan summation for the reduction, 39 | hoping to get better accuracy. 40 | To use the kernel implementation, users need to install Apex with the CUDA extension 41 | enabled: `python setup.py install --cuda_ext`. 42 | 43 | - _The custom kernel implementation supports fp16 input with an fp32 layer, as cuDNN does. 44 | This is required to run the imagenet example in fp16._ 45 | 46 | - _Currently the kernel implementation only supports GPU._ 47 | 48 | #### HowTo 49 | 50 | 1. Users can use `apex.parallel.SyncBatchNorm` by building their module with 51 | the layer explicitly. 
52 | 53 | ``` 54 | import torch 55 | import apex 56 | input_t = torch.randn(3, 5, 20).cuda() 57 | sbn = apex.parallel.SyncBatchNorm(5).cuda() 58 | output_t = sbn(input_t) 59 | ``` 60 | 61 | 2. Users can also take a constructed `torch.nn.Module` and replace all its `torch.nn.BatchNorm*N*d` modules with `apex.parallel.SyncBatchNorm` through the utility function `apex.parallel.convert_syncbn_model`. 62 | 63 | ``` 64 | # model is an instance of torch.nn.Module 65 | import apex 66 | sync_bn_model = apex.parallel.convert_syncbn_model(model) 67 | ``` 68 | -------------------------------------------------------------------------------- /apex/apex/parallel/multiproc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import subprocess 4 | 5 | def docstring_hack(): 6 | """ 7 | Multiproc file which will launch a set of processes locally for multi-gpu 8 | usage: python -m apex.parallel.multiproc main.py ... 9 | """ 10 | pass 11 | 12 | argslist = list(sys.argv)[1:] 13 | world_size = torch.cuda.device_count() 14 | 15 | if '--world-size' in argslist: 16 | world_size = int(argslist[argslist.index('--world-size')+1]) 17 | else: 18 | argslist.append('--world-size') 19 | argslist.append(str(world_size)) 20 | 21 | workers = [] 22 | 23 | for i in range(world_size): 24 | if '--rank' in argslist: 25 | argslist[argslist.index('--rank')+1] = str(i) 26 | else: 27 | argslist.append('--rank') 28 | argslist.append(str(i)) 29 | stdout = None if i == 0 else open("GPU_"+str(i)+".log", "w") 30 | print(argslist) 31 | p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) 32 | workers.append(p) 33 | 34 | for p in workers: 35 | p.wait() 36 | -------------------------------------------------------------------------------- /apex/apex/reparameterization/README.md: -------------------------------------------------------------------------------- 1 | Under construction...
2 | -------------------------------------------------------------------------------- /apex/csrc/amp_C_frontend.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | void multi_tensor_scale_cuda( 4 | int chunk_size, 5 | at::Tensor noop_flag, 6 | std::vector<std::vector<at::Tensor>> tensor_lists, 7 | float scale); 8 | 9 | void multi_tensor_axpby_cuda( 10 | int chunk_size, 11 | at::Tensor noop_flag, 12 | std::vector<std::vector<at::Tensor>> tensor_lists, 13 | float a, 14 | float b, 15 | int arg_to_check); 16 | 17 | std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda( 18 | int chunk_size, 19 | at::Tensor noop_flag, 20 | std::vector<std::vector<at::Tensor>> tensor_lists, 21 | at::optional<bool> per_tensor_python); 22 | 23 | void multi_tensor_lamb_stage1_cuda( 24 | int chunk_size, 25 | at::Tensor noop_flag, 26 | std::vector<std::vector<at::Tensor>> tensor_lists, 27 | at::Tensor per_tensor_decay, 28 | const int step, 29 | const float beta1, 30 | const float beta2, 31 | const float epsilon, 32 | const float global_grad_norm, 33 | const float max_global_grad_norm); 34 | 35 | void multi_tensor_lamb_stage2_cuda( 36 | int chunk_size, 37 | at::Tensor noop_flag, 38 | std::vector<std::vector<at::Tensor>> tensor_lists, 39 | at::Tensor per_tensor_param_norm, 40 | at::Tensor per_tensor_update_norm, 41 | const float step_size); 42 | 43 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 44 | m.def("multi_tensor_scale", &multi_tensor_scale_cuda, 45 | "Fused overflow check + scale for a list of contiguous tensors"); 46 | m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda, 47 | "out = a*x + b*y for a list of contiguous tensors"); 48 | m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, 49 | "Computes L2 norm for a list of contiguous tensors"); 50 | m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda, 51 | "Computes update part of LAMB optimizer"); 52 | m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda, 53 | "Completes application of gradient to parameters for LAMB optimizer"); 54 | } 55 | -------------------------------------------------------------------------------- /apex/csrc/flatten_unflatten.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <torch/csrc/utils/tensor_flatten.h> 3 | // https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h 4 | 5 | at::Tensor flatten(std::vector<at::Tensor> tensors) 6 | { 7 | return torch::utils::flatten_dense_tensors(tensors); 8 | } 9 | 10 | std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors) 11 | { 12 | return torch::utils::unflatten_dense_tensors(flat, tensors); 13 | } 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("flatten", &flatten, "Flatten dense tensors"); 17 | m.def("unflatten", &unflatten, "Unflatten dense tensors"); 18 | } 19 | -------------------------------------------------------------------------------- /apex/csrc/fused_adam_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | // CUDA forward declaration 4 | void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); 5 | 6 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 7 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 8 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 9 | 10 | // C++ interface 11 | void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, 
float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { 12 | CHECK_INPUT(p) 13 | if (p_copy.numel() > 0) CHECK_INPUT(p_copy); 14 | CHECK_INPUT(m); 15 | CHECK_INPUT(v); 16 | CHECK_INPUT(g); 17 | int64_t num_elem = p.numel(); 18 | AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal"); 19 | AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal"); 20 | AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal"); 21 | AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty"); 22 | 23 | fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay); 24 | } 25 | 26 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 27 | m.def("adam", &adam, "Adam optimized CUDA implementation."); 28 | } 29 | -------------------------------------------------------------------------------- /apex/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = NVIDIAAPEX 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | gh-pages: 16 | git checkout gh-pages 17 | rm -rf build 18 | rm -rf source 19 | git checkout master -- . 20 | make html 21 | rm -rf ../_modules ../_sources ../_static 22 | mv -fv build/html/* ../ 23 | rm -rf build 24 | git add -A 25 | git commit -m "Generated gh-pages for `git log master -1 --pretty=short --abbrev-commit`" && git push origin gh-pages ; git checkout master 26 | 27 | .PHONY: help Makefile 28 | 29 | # Catch-all target: route all unknown targets to Sphinx using the new 30 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
31 | %: Makefile 32 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 33 | -------------------------------------------------------------------------------- /apex/docs/source/_static/css/pytorch_theme.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 3 | } 4 | 5 | /* Default header fonts are ugly */ 6 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { 7 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 8 | } 9 | 10 | /* Use white for docs background */ 11 | .wy-side-nav-search { 12 | background-color: #fff; 13 | } 14 | 15 | .wy-nav-content-wrap, .wy-menu li.current > a { 16 | background-color: #fff; 17 | } 18 | 19 | @media screen and (min-width: 1400px) { 20 | .wy-nav-content-wrap { 21 | background-color: rgba(0, 0, 0, 0.0470588); 22 | } 23 | 24 | .wy-nav-content { 25 | background-color: #fff; 26 | } 27 | } 28 | 29 | /* Fixes for mobile */ 30 | .wy-nav-top { 31 | background-color: #fff; 32 | background-image: url('../img/apex.jpg'); 33 | background-repeat: no-repeat; 34 | background-position: center; 35 | padding: 0; 36 | margin: 0.4045em 0.809em; 37 | color: #333; 38 | } 39 | 40 | .wy-nav-top > a { 41 | display: none; 42 | } 43 | 44 | @media screen and (max-width: 768px) { 45 | .wy-side-nav-search>a img.logo { 46 | height: 60px; 47 | } 48 | } 49 | 50 | /* This is needed to ensure that logo above search scales properly */ 51 | .wy-side-nav-search a { 52 | display: block; 53 | } 54 | 55 | /* This ensures that multiple constructors will remain in separate lines. */ 56 | .rst-content dl:not(.docutils) dt { 57 | display: table; 58 | } 59 | 60 | /* Use our red for literals (it's very similar to the original color) */ 61 | .rst-content tt.literal, .rst-content tt.literal, .rst-content code.literal { 62 | color: #F05732; 63 | } 64 | 65 | .rst-content tt.xref, a .rst-content tt, .rst-content tt.xref, 66 | .rst-content code.xref, a .rst-content tt, a .rst-content code { 67 | color: #404040; 68 | } 69 | 70 | /* Change link colors (except for the menu) */ 71 | 72 | a { 73 | color: #F05732; 74 | } 75 | 76 | a:hover { 77 | color: #F05732; 78 | } 79 | 80 | 81 | a:visited { 82 | color: #D44D2C; 83 | } 84 | 85 | .wy-menu a { 86 | color: #b3b3b3; 87 | } 88 | 89 | .wy-menu a:hover { 90 | color: #b3b3b3; 91 | } 92 | 93 | /* Default footer text is quite big */ 94 | footer { 95 | font-size: 80%; 96 | } 97 | 98 | footer .rst-footer-buttons { 99 | font-size: 125%; /* revert footer settings - 1/80% = 125% */ 100 | } 101 | 102 | footer p { 103 | font-size: 100%; 104 | } 105 | 106 | /* For hidden headers that appear in TOC tree */ 107 | /* see http://stackoverflow.com/a/32363545/3343043 */ 108 | .rst-content .hidden-section { 109 | display: none; 110 | } 111 | 112 | nav .hidden-section { 113 | display: inherit; 114 | } 115 | 116 | .wy-side-nav-search>div.version { 117 | color: #000; 118 | } 119 | -------------------------------------------------------------------------------- /apex/docs/source/_static/img/nv-pytorch2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/apex/docs/source/_static/img/nv-pytorch2.png -------------------------------------------------------------------------------- /apex/docs/source/_templates/layout.html: 
-------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | {% block sidebartitle %} {{ super() }} 3 | 4 | 32 | {% endblock %} 33 | 34 | {% block footer %} {{ super() }} 35 | 36 | 51 | {% endblock %} 52 | -------------------------------------------------------------------------------- /apex/docs/source/fp16_utils.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | apex.fp16_utils 5 | =================================== 6 | 7 | This submodule contains utilities designed to streamline the mixed precision training recipe 8 | presented by NVIDIA `on Parallel Forall`_ and in GTC 2018 Sessions 9 | `Training Neural Networks with Mixed Precision: Theory and Practice`_ and 10 | `Training Neural Networks with Mixed Precision: Real Examples`_. 11 | For Pytorch users, Real Examples in particular is recommended. 12 | 13 | Full runnable Python scripts demonstrating ``apex.fp16_utils`` 14 | can be found on the Github page: 15 | 16 | | `Simple FP16_Optimizer demos`_ 17 | | 18 | | `Distributed Mixed Precision Training with imagenet`_ 19 | | 20 | | `Mixed Precision Training with word_language_model`_ 21 | | 22 | | 23 | 24 | .. _`on Parallel Forall`: 25 | https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/ 26 | .. _`Training Neural Networks with Mixed Precision: Theory and Practice`: 27 | http://on-demand.gputechconf.com/gtc/2018/video/S8923/ 28 | .. _`Training Neural Networks with Mixed Precision: Real Examples`: 29 | http://on-demand.gputechconf.com/gtc/2018/video/S81012/ 30 | .. _`Simple FP16_Optimizer demos`: 31 | https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple 32 | .. _`Distributed Mixed Precision Training with imagenet`: 33 | https://github.com/NVIDIA/apex/tree/master/examples/imagenet 34 | .. _`Mixed Precision Training with word_language_model`: 35 | https://github.com/NVIDIA/apex/tree/master/examples/word_language_model 36 | 37 | .. automodule:: apex.fp16_utils 38 | .. currentmodule:: apex.fp16_utils 39 | 40 | Automatic management of master params + loss scaling 41 | ---------------------------------------------------- 42 | 43 | .. autoclass:: FP16_Optimizer 44 | :members: 45 | 46 | .. autoclass:: LossScaler 47 | :members: 48 | 49 | .. autoclass:: DynamicLossScaler 50 | :members: 51 | 52 | Manual master parameter management 53 | ---------------------------------- 54 | 55 | .. autofunction:: prep_param_lists 56 | 57 | .. autofunction:: master_params_to_model_params 58 | 59 | .. autofunction:: model_grads_to_master_grads 60 | -------------------------------------------------------------------------------- /apex/docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. PyTorch documentation master file, created by 2 | sphinx-quickstart on Fri Dec 23 13:31:47 2016. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | :github_url: https://github.com/nvidia/apex 7 | 8 | Apex (A PyTorch Extension) 9 | =================================== 10 | 11 | This site contains the API documentation for Apex (https://github.com/nvidia/apex), 12 | a Pytorch extension with NVIDIA-maintained utilities to streamline mixed precision and distributed training. Some of the code here will be included in upstream Pytorch eventually. 
The intention of Apex is to make up-to-date utilities available to users as quickly as possible. 13 | 14 | Installation instructions can be found here: https://github.com/NVIDIA/apex#quick-start. 15 | 16 | .. toctree:: 17 | :maxdepth: 1 18 | :caption: AMP: Automatic Mixed Precision 19 | 20 | amp 21 | 22 | .. toctree:: 23 | :maxdepth: 1 24 | :caption: Distributed Training 25 | 26 | parallel 27 | 28 | .. toctree:: 29 | :maxdepth: 1 30 | :caption: Fused Optimizers 31 | 32 | optimizers 33 | 34 | .. toctree:: 35 | :maxdepth: 1 36 | :caption: Fused Layer Norm 37 | 38 | layernorm 39 | 40 | .. .. toctree:: 41 | :maxdepth: 1 42 | :caption: Deprecated mixed precision API 43 | fp16_utils 44 | 45 | .. reparameterization 46 | .. RNN 47 | 48 | Indices and tables 49 | ================== 50 | 51 | * :ref:`genindex` 52 | * :ref:`modindex` 53 | -------------------------------------------------------------------------------- /apex/docs/source/layernorm.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | apex.normalization.fused_layer_norm 5 | =================================== 6 | 7 | .. automodule:: apex.normalization 8 | .. currentmodule:: apex.normalization 9 | 10 | .. FusedLayerNorm 11 | ---------- 12 | 13 | .. autoclass:: FusedLayerNorm 14 | :members: 15 | -------------------------------------------------------------------------------- /apex/docs/source/optimizers.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | apex.optimizers 5 | =================================== 6 | 7 | .. automodule:: apex.optimizers 8 | .. currentmodule:: apex.optimizers 9 | 10 | .. FusedAdam 11 | ---------- 12 | 13 | .. autoclass:: FusedAdam 14 | :members: 15 | -------------------------------------------------------------------------------- /apex/docs/source/parallel.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | apex.parallel 5 | =================================== 6 | 7 | .. automodule:: apex.parallel 8 | .. currentmodule:: apex.parallel 9 | 10 | .. DistributedDataParallel 11 | ---------- 12 | 13 | .. autoclass:: DistributedDataParallel 14 | :members: 15 | 16 | .. autoclass:: Reducer 17 | :members: 18 | 19 | .. autoclass:: SyncBatchNorm 20 | :members: 21 | 22 | Utility functions 23 | ---------------------------------- 24 | 25 | .. autofunction:: convert_syncbn_model 26 | -------------------------------------------------------------------------------- /apex/examples/README.md: -------------------------------------------------------------------------------- 1 | This directory contains examples illustrating Apex mixed precision and distributed tools. 2 | 3 | **Note for users of the pre-unification API**: 4 | `deprecated_api` contains examples illustrating the old (pre-unified) APIs. These APIs will be removed soon, and users are strongly encouraged to switch. The separate mixed precision tools called `Amp` and `FP16_Optimizer` in the old API are exposed via different flags/optimization levels in the new API. 5 | -------------------------------------------------------------------------------- /apex/examples/dcgan/README.md: -------------------------------------------------------------------------------- 1 | Under construction...
2 | -------------------------------------------------------------------------------- /apex/examples/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Base image must at least have pytorch and CUDA installed. 2 | ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:19.03-py3 3 | FROM $BASE_IMAGE 4 | ARG BASE_IMAGE 5 | RUN echo "Installing Apex on top of ${BASE_IMAGE}" 6 | # make sure we don't overwrite some existing directory called "apex" 7 | WORKDIR /tmp/unique_for_apex 8 | # uninstall Apex if present, twice to make absolutely sure :) 9 | RUN pip uninstall -y apex || : 10 | RUN pip uninstall -y apex || : 11 | # SHA is something the user can touch to force recreation of this Docker layer, 12 | # and therefore force cloning of the latest version of Apex 13 | RUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git 14 | WORKDIR /tmp/unique_for_apex/apex 15 | RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 16 | WORKDIR /workspace 17 | -------------------------------------------------------------------------------- /apex/examples/docker/README.md: -------------------------------------------------------------------------------- 1 | ## Option 1: Create a new container with Apex 2 | 3 | **Dockerfile** installs the latest Apex on top of an existing image. Run 4 | ``` 5 | docker build -t new_image_with_apex . 6 | ``` 7 | By default, **Dockerfile** uses NVIDIA's Pytorch container as the base image, 8 | which requires an NVIDIA GPU Cloud (NGC) account. If you don't have an NGC account, you can sign up for free by following the instructions [here](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html#generating-api-key). 9 | 10 | Alternatively, you can supply your own base image via the `BASE_IMAGE` build-arg. 11 | `BASE_IMAGE` must have Pytorch and Cuda installed. For example, any 12 | `-devel` image for Pytorch 1.0 and later from the 13 | [official Pytorch Dockerhub](https://hub.docker.com/r/pytorch/pytorch) may be used: 14 | ``` 15 | docker build --build-arg BASE_IMAGE=pytorch/pytorch:nightly-devel-cuda10.0-cudnn7 -t new_image_with_apex . 16 | ``` 17 | 18 | If you want to rebuild your image, and force the latest Apex to be cloned and installed, make any small change to the `SHA` variable in **Dockerfile**. 19 | 20 | **Warning:** 21 | Currently, the non-`-devel` images on Pytorch Dockerhub do not contain the Cuda compiler `nvcc`. Therefore, 22 | images whose name does not contain `-devel` are not eligible candidates for `BASE_IMAGE`. 23 | 24 | ### Running your Apex container 25 | 26 | Like any Cuda-enabled Pytorch container, a container with Apex should be run via [nvidia-docker](https://github.com/NVIDIA/nvidia-docker), for example: 27 | ``` 28 | docker run --runtime=nvidia -it --rm --ipc=host new_image_with_apex 29 | ``` 30 | 31 | ## Option 2: Install Apex in a running container 32 | 33 | Instead of building a new container, it is also a viable option to `git clone https://github.com/NVIDIA/apex.git` on bare metal, mount the Apex repo into your container at launch by running, for example, 34 | ``` 35 | docker run --runtime=nvidia -it --rm --ipc=host -v /bare/metal/apex:/apex/in/container <your pytorch image> 36 | ``` 37 | then go to /apex/in/container within the running container and run 38 | ``` 39 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
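# (the --cpp_ext and --cuda_ext flags build Apex's compiled extensions;
# a plain "pip install -v --no-cache-dir ." performs a Python-only install)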
40 | ``` 41 | -------------------------------------------------------------------------------- /apex/examples/simple/distributed/README.md: -------------------------------------------------------------------------------- 1 | **distributed_data_parallel.py** and **run.sh** show an example using Amp with 2 | [apex.parallel.DistributedDataParallel](https://nvidia.github.io/apex/parallel.html) or 3 | [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#distributeddataparallel) 4 | and the Pytorch multiprocess launcher script, 5 | [torch.distributed.launch](https://pytorch.org/docs/master/distributed.html#launch-utility). 6 | The use of `Amp` with DistributedDataParallel does not need to change from ordinary 7 | single-process use. The only gotcha is that wrapping your model with `DistributedDataParallel` must 8 | come after the call to `amp.initialize`. Test via 9 | ```bash 10 | bash run.sh 11 | ``` 12 | 13 | **This is intended purely as an instructional example, not a performance showcase.** 14 | -------------------------------------------------------------------------------- /apex/examples/simple/distributed/distributed_data_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | from apex import amp 5 | # FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead) 6 | from apex.parallel import DistributedDataParallel 7 | 8 | parser = argparse.ArgumentParser() 9 | # FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied 10 | # automatically by torch.distributed.launch. 11 | parser.add_argument("--local_rank", default=0, type=int) 12 | args = parser.parse_args() 13 | 14 | # FOR DISTRIBUTED: If we are running under torch.distributed.launch, 15 | # the 'WORLD_SIZE' environment variable will also be set automatically. 16 | args.distributed = False 17 | if 'WORLD_SIZE' in os.environ: 18 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 19 | 20 | if args.distributed: 21 | # FOR DISTRIBUTED: Set the device according to local_rank. 22 | torch.cuda.set_device(args.local_rank) 23 | 24 | # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide 25 | # environment variables, and requires that you use init_method=`env://`. 26 | torch.distributed.init_process_group(backend='nccl', 27 | init_method='env://') 28 | 29 | torch.backends.cudnn.benchmark = True 30 | 31 | N, D_in, D_out = 64, 1024, 16 32 | 33 | # Each process receives its own batch of "fake input data" and "fake target data." 34 | # The "training loop" in each process just uses this fake batch over and over. 35 | # https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic 36 | # example of distributed data sampling for both training and validation. 37 | x = torch.randn(N, D_in, device='cuda') 38 | y = torch.randn(N, D_out, device='cuda') 39 | 40 | model = torch.nn.Linear(D_in, D_out).cuda() 41 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) 42 | 43 | model, optimizer = amp.initialize(model, optimizer, opt_level="O1") 44 | 45 | if args.distributed: 46 | # FOR DISTRIBUTED: After amp.initialize, wrap the model with 47 | # apex.parallel.DistributedDataParallel. 
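    # (The ordering matters: DDP broadcasts parameters and registers its
    # allreduce hooks when it is constructed, and amp.initialize may patch or
    # cast the model's parameters depending on the opt_level, so it must run first.)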
48 | model = DistributedDataParallel(model) 49 | # torch.nn.parallel.DistributedDataParallel is also fine, with some added args: 50 | # model = torch.nn.parallel.DistributedDataParallel(model, 51 | # device_ids=[args.local_rank], 52 | # output_device=args.local_rank) 53 | 54 | loss_fn = torch.nn.MSELoss() 55 | 56 | for t in range(500): 57 | optimizer.zero_grad() 58 | y_pred = model(x) 59 | loss = loss_fn(y_pred, y) 60 | with amp.scale_loss(loss, optimizer) as scaled_loss: 61 | scaled_loss.backward() 62 | optimizer.step() 63 | 64 | if args.local_rank == 0: 65 | print("final loss = ", loss) 66 | -------------------------------------------------------------------------------- /apex/examples/simple/distributed/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -m torch.distributed.launch --nproc_per_node=2 distributed_data_parallel.py 3 | -------------------------------------------------------------------------------- /apex/tests/L0/run_amp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/apex/tests/L0/run_amp/__init__.py -------------------------------------------------------------------------------- /apex/tests/L0/run_amp/test_multi_tensor_l2norm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import functools as ft 4 | import itertools as it 5 | 6 | from apex import amp 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | 11 | from utils import common_init, HALF, FLOAT,\ 12 | ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT 13 | 14 | try: 15 | import amp_C 16 | from amp_C import multi_tensor_l2norm 17 | from apex.multi_tensor_apply import MultiTensorApply 18 | disabled = False 19 | except ImportError as err: 20 | print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err) 21 | disabled = True 22 | 23 | 24 | class TestMultiTensorL2Norm(unittest.TestCase): 25 | 26 | def setUp(self): 27 | common_init(self) 28 | self.val = 4.0 29 | self.overflow_buf = torch.cuda.IntTensor(1).zero_() 30 | 31 | def tearDown(self): 32 | pass 33 | 34 | # The tensor creation here is written for convenience, not speed. 
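    # Each call builds 2*repeat_tensors input chunks and compares the fused
    # multi-tensor L2 norm against torch's norm of a single equivalent
    # flattened tensor filled with the same value.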
35 | def l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type, per_tensor): 36 | self.overflow_buf.zero_() 37 | a = torch.cuda.FloatTensor(sizea).fill_(self.val) 38 | b = torch.cuda.FloatTensor(sizeb).fill_(self.val) 39 | 40 | in_list = [] 41 | for i in range(repeat_tensors): 42 | in_list += [a.clone().to(in_type), b.clone().to(in_type)] 43 | 44 | if per_tensor: 45 | norm, norm_per_tensor = applier(multi_tensor_l2norm, self.overflow_buf, [in_list], True) 46 | normab = torch.cat((a.norm().view(1), b.norm().view(1))) 47 | norm_per_tensor = norm_per_tensor.view(-1, 2) 48 | else: 49 | norm, _ = applier(multi_tensor_l2norm, self.overflow_buf, [in_list], True) 50 | 51 | reference = torch.cuda.FloatTensor((sizea + sizeb)*repeat_tensors).fill_(self.val).norm() 52 | 53 | self.assertTrue(torch.allclose(norm, reference)) 54 | if per_tensor: 55 | self.assertTrue(torch.allclose(norm_per_tensor, normab)) 56 | self.assertTrue(self.overflow_buf.item() == 0) 57 | 58 | @unittest.skipIf(disabled, "amp_C is unavailable") 59 | def test_fuzz(self): 60 | input_size_pairs = ( 61 | (7777*77, 555*555), 62 | (777, 555), 63 | (555, 2048*32+1), 64 | (2048*32+1, 555), 65 | (555, 2048*32), 66 | (2048*32, 555), 67 | (33333, 555), 68 | (555, 33333)) 69 | appliers = ( 70 | MultiTensorApply(2048*32), 71 | MultiTensorApply(333), 72 | MultiTensorApply(33333)) 73 | repeat_tensors = ( 74 | 1, 75 | 55) 76 | 77 | for sizea, sizeb in input_size_pairs: 78 | for applier in appliers: 79 | for repeat in repeat_tensors: 80 | for in_type in (torch.float32, torch.float16): 81 | for per_tensor in (False, True): 82 | self.l2norm(sizea, sizeb, applier, repeat, in_type, per_tensor) 83 | 84 | 85 | 86 | if __name__ == '__main__': 87 | unittest.main() 88 | -------------------------------------------------------------------------------- /apex/tests/L0/run_amp/test_promotion.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import itertools as it 4 | 5 | from apex import amp 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from utils import common_init, HALF, FLOAT, DTYPES 11 | 12 | class TestPromotion(unittest.TestCase): 13 | def setUp(self): 14 | self.handle = amp.init(enabled=True) 15 | common_init(self) 16 | 17 | def tearDown(self): 18 | self.handle._deactivate() 19 | 20 | def run_binary_promote_test(self, fns, input_shape, x_inplace=False): 21 | type_pairs = it.product(DTYPES, DTYPES) 22 | for fn, (xtype, ytype) in it.product(fns, type_pairs): 23 | x = torch.randn(input_shape, dtype=xtype).requires_grad_() 24 | x_leaf = x 25 | if x_inplace: 26 | # We need a non-leaf to call in place on 27 | x = x.clone() 28 | y = torch.randn(input_shape, dtype=ytype) 29 | out = fn(x, y) 30 | if x_inplace: 31 | # In place: always match xtype 32 | self.assertEqual(out.type(), x.type()) 33 | else: 34 | # Out of place: match widest type 35 | if xtype == torch.float or ytype == torch.float: 36 | self.assertEqual(out.type(), FLOAT) 37 | else: 38 | self.assertEqual(out.type(), HALF) 39 | out.float().sum().backward() 40 | self.assertEqual(x_leaf.grad.dtype, xtype) 41 | 42 | def test_atan2_matches_widest(self): 43 | fns = [lambda x, y : torch.atan2(x, y), 44 | lambda x, y : x.atan2(y)] 45 | self.run_binary_promote_test(fns, (self.b,)) 46 | 47 | def test_mul_matches_widest(self): 48 | fns = [lambda x, y : torch.mul(x, y), 49 | lambda x, y: x.mul(y)] 50 | self.run_binary_promote_test(fns, (self.b,)) 51 | 52 | def test_cat_matches_widest(self): 53 | shape = self.b 
54 | ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)] 55 | x_float = torch.randn(shape) 56 | out = torch.cat(ys + [x_float]) 57 | self.assertEqual(out.type(), FLOAT) 58 | x_half = torch.randn(shape, dtype=torch.half) 59 | out = torch.cat(ys + [x_half]) 60 | self.assertEqual(out.type(), HALF) 61 | 62 | def test_inplace_exp_is_error_for_half(self): 63 | xs = torch.randn(self.b) 64 | xs.exp_() 65 | self.assertEqual(xs.type(), FLOAT) 66 | xs = torch.randn(self.b, dtype=torch.half) 67 | with self.assertRaises(NotImplementedError): 68 | xs.exp_() 69 | 70 | def test_inplace_add_matches_self(self): 71 | fn = lambda x, y: x.add_(y) 72 | self.run_binary_promote_test([fn], (self.b,), x_inplace=True) 73 | 74 | if __name__ == '__main__': 75 | unittest.main() 76 | -------------------------------------------------------------------------------- /apex/tests/L0/run_amp/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | HALF = 'torch.cuda.HalfTensor' 4 | FLOAT = 'torch.cuda.FloatTensor' 5 | 6 | DTYPES = [torch.half, torch.float] 7 | 8 | ALWAYS_HALF = {torch.float: HALF, 9 | torch.half: HALF} 10 | ALWAYS_FLOAT = {torch.float: FLOAT, 11 | torch.half: FLOAT} 12 | MATCH_INPUT = {torch.float: FLOAT, 13 | torch.half: HALF} 14 | 15 | def common_init(test_case): 16 | test_case.h = 64 17 | test_case.b = 16 18 | test_case.c = 16 19 | test_case.k = 3 20 | test_case.t = 10 21 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 22 | -------------------------------------------------------------------------------- /apex/tests/L0/run_fp16util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/apex/tests/L0/run_fp16util/__init__.py -------------------------------------------------------------------------------- /apex/tests/L0/run_fp16util/test_fp16util.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from apex.fp16_utils import FP16Model 7 | 8 | 9 | class DummyBlock(nn.Module): 10 | def __init__(self): 11 | super(DummyBlock, self).__init__() 12 | 13 | self.conv = nn.Conv2d(10, 10, 2) 14 | self.bn = nn.BatchNorm2d(10, affine=True) 15 | 16 | def forward(self, x): 17 | return self.conv(self.bn(x)) 18 | 19 | 20 | class DummyNet(nn.Module): 21 | def __init__(self): 22 | super(DummyNet, self).__init__() 23 | 24 | self.conv1 = nn.Conv2d(3, 10, 2) 25 | self.bn1 = nn.BatchNorm2d(10, affine=False) 26 | self.db1 = DummyBlock() 27 | self.db2 = DummyBlock() 28 | 29 | def forward(self, x): 30 | out = x 31 | out = self.conv1(out) 32 | out = self.bn1(out) 33 | out = self.db1(out) 34 | out = self.db2(out) 35 | return out 36 | 37 | 38 | class DummyNetWrapper(nn.Module): 39 | def __init__(self): 40 | super(DummyNetWrapper, self).__init__() 41 | 42 | self.bn = nn.BatchNorm2d(3, affine=True) 43 | self.dn = DummyNet() 44 | 45 | def forward(self, x): 46 | return self.dn(self.bn(x)) 47 | 48 | 49 | class TestFP16Model(unittest.TestCase): 50 | def setUp(self): 51 | self.N = 64 52 | self.C_in = 3 53 | self.H_in = 16 54 | self.W_in = 32 55 | self.in_tensor = torch.randn((self.N, self.C_in, self.H_in, self.W_in)).cuda() 56 | self.orig_model = DummyNetWrapper().cuda() 57 | self.fp16_model = FP16Model(self.orig_model) 58 | 59 | def test_params_and_buffers(self): 60 | exempted_modules = [ 61 | self.fp16_model.network.bn, 62 | 
self.fp16_model.network.dn.db1.bn, 63 | self.fp16_model.network.dn.db2.bn, 64 | ] 65 | for m in self.fp16_model.modules(): 66 | expected_dtype = torch.float if (m in exempted_modules) else torch.half 67 | for p in m.parameters(recurse=False): 68 | assert p.dtype == expected_dtype 69 | for b in m.buffers(recurse=False): 70 | assert b.dtype in (expected_dtype, torch.int64) 71 | 72 | def test_output_is_half(self): 73 | out_tensor = self.fp16_model(self.in_tensor) 74 | assert out_tensor.dtype == torch.half 75 | 76 | -------------------------------------------------------------------------------- /apex/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import random 4 | 5 | import torch 6 | import apex 7 | 8 | 9 | class TestFusedLayerNorm(unittest.TestCase): 10 | def setUp(self): 11 | self.module = apex.normalization.FusedLayerNorm(normalized_shape=[32, 64], elementwise_affine=False) 12 | self.input_ = torch.randn(16, 32, 64) 13 | torch.cuda.manual_seed(42) 14 | 15 | def forward_cpu(self, input_): 16 | self.module.cpu() 17 | return self.module(input_.cpu()) 18 | 19 | def forward_cuda(self, input_): 20 | self.module.cuda() 21 | return self.module(input_.cuda()) 22 | 23 | def test_forward_cuda(self): 24 | out_ = self.forward_cuda(self.input_) 25 | assert out_.is_cuda == True 26 | 27 | def test_forward_cpu(self): 28 | out_ = self.forward_cpu(self.input_) 29 | assert out_.is_cuda == False 30 | 31 | def test_same_output(self): 32 | out_cpu = self.forward_cpu(self.input_) 33 | out_cuda = self.forward_cuda(self.input_) 34 | torch.testing.assert_allclose(out_cpu, out_cuda.cpu()) 35 | 36 | 37 | class TestFusedLayerNormElemWise(TestFusedLayerNorm): 38 | def setUp(self): 39 | self.module = apex.normalization.FusedLayerNorm(normalized_shape=[32, 64], elementwise_affine=True) 40 | self.input_ = torch.randn(16, 32, 64) 41 | torch.cuda.manual_seed(42) -------------------------------------------------------------------------------- /apex/tests/L0/run_mixed_adam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/apex/tests/L0/run_mixed_adam/__init__.py -------------------------------------------------------------------------------- /apex/tests/L0/run_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import sys 3 | 4 | test_dirs = ["run_amp", "run_fp16util", "run_mixed_adam", "run_fused_layer_norm"] 5 | 6 | runner = unittest.TextTestRunner(verbosity=2) 7 | 8 | errcode = 0 9 | 10 | for test_dir in test_dirs: 11 | suite = unittest.TestLoader().discover(test_dir) 12 | 13 | print("\nExecuting tests from " + test_dir) 14 | 15 | result = runner.run(suite) 16 | 17 | if not result.wasSuccessful(): 18 | errcode = 1 19 | 20 | sys.exit(errcode) 21 | -------------------------------------------------------------------------------- /apex/tests/L1/common/compare.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | parser = argparse.ArgumentParser(description='Compare') 5 | parser.add_argument('--opt-level', type=str) 6 | parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) 7 | parser.add_argument('--loss-scale', type=str, default=None) 8 | parser.add_argument('--fused-adam', action='store_true') 9 | 
parser.add_argument('--use_baseline', action='store_true') 10 | args = parser.parse_args() 11 | 12 | base_file = str(args.opt_level) + "_" +\ 13 | str(args.loss_scale) + "_" +\ 14 | str(args.keep_batchnorm_fp32) + "_" +\ 15 | str(args.fused_adam) 16 | 17 | file_e = "True_" + base_file 18 | file_p = "False_" + base_file 19 | if args.use_baseline: 20 | file_b = "baselines/True_" + base_file 21 | 22 | dict_e = torch.load(file_e) 23 | dict_p = torch.load(file_p) 24 | if args.use_baseline: 25 | dict_b = torch.load(file_b) 26 | 27 | torch.set_printoptions(precision=10) 28 | 29 | print(file_e) 30 | print(file_p) 31 | if args.use_baseline: 32 | print(file_b) 33 | 34 | # ugly duplication here... 35 | if not args.use_baseline: 36 | for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])): 37 | assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p) 38 | 39 | loss_e = dict_e["Loss"][n] 40 | loss_p = dict_p["Loss"][n] 41 | assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p) 42 | print("{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format( 43 | i_e, 44 | loss_e, 45 | loss_p, 46 | dict_e["Speed"][n], 47 | dict_p["Speed"][n])) 48 | else: 49 | for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])): 50 | assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p) 51 | 52 | loss_e = dict_e["Loss"][n] 53 | loss_p = dict_p["Loss"][n] 54 | loss_b = dict_b["Loss"][n] 55 | assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p) 56 | assert loss_e == loss_b, "Iteration {}, loss_e = {}, loss_b = {}".format(i_e, loss_e, loss_b) 57 | print("{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format( 58 | i_e, 59 | loss_b, 60 | loss_e, 61 | loss_p, 62 | dict_b["Speed"][n], 63 | dict_e["Speed"][n], 64 | dict_p["Speed"][n])) 65 | -------------------------------------------------------------------------------- /apex/tests/L1/cross_product/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/" 4 | # DATADIR="/opt/home/apex/examples/imagenet/" 5 | cp ../common/* . 6 | bash run_test.sh single_gpu $1 $DATADIR yes 7 | -------------------------------------------------------------------------------- /apex/tests/L1/cross_product_distributed/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cp ../common/* . 
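# run_test.sh is among the files copied from ../common; its first argument
# selects the distributed (multi-process) code path.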
4 | bash run_test.sh distributed $1 5 | -------------------------------------------------------------------------------- /apex/tests/distributed/DDP/ddp_race_condition_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.nn import Parameter 4 | from torch.nn import Module 5 | from apex.parallel import DistributedDataParallel as DDP 6 | import argparse 7 | import os 8 | 9 | 10 | parser = argparse.ArgumentParser(description='allreduce hook example') 11 | parser.add_argument("--local_rank", default=0, type=int) 12 | args = parser.parse_args() 13 | 14 | args.distributed = False 15 | if 'WORLD_SIZE' in os.environ: 16 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 17 | 18 | if args.distributed: 19 | args.gpu = args.local_rank % torch.cuda.device_count() 20 | torch.cuda.set_device(args.gpu) 21 | torch.distributed.init_process_group(backend='nccl', 22 | init_method='env://') 23 | args.world_size = torch.distributed.get_world_size() 24 | 25 | torch.set_printoptions(precision=10) 26 | torch.manual_seed(args.local_rank) 27 | 28 | class Model(Module): 29 | def __init__(self): 30 | super(Model, self).__init__() 31 | self.a = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(1.0)) 32 | self.b = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(2.0)) 33 | def forward(self, input): 34 | return (input*self.a)*self.b 35 | 36 | model = Model() 37 | # model = DDP(model, message_size=1, gradient_predivide_factor=8.0) 38 | model = DDP(model, delay_allreduce=True) 39 | # model = DDP(model, message_size=1, allreduce_trigger_params=[model.b]) 40 | 41 | x = torch.cuda.FloatTensor(4096*4096) 42 | 43 | passed = True 44 | torch.cuda.cudart().cudaProfilerStart() 45 | for i in range(10): 46 | x.fill_(i + args.local_rank) # fill x with new values every iteration for sanity 47 | model.zero_grad() 48 | out = model(x) 49 | loss = out.sum() 50 | # torch.cuda.nvtx.range_push("backward") 51 | loss.backward() 52 | # torch.cuda.nvtx.range_pop() 53 | 54 | # torch.cuda.nvtx.range_push("synchronize() + info") 55 | # torch.cuda.synchronize() 56 | print("i = {}".format(i)) 57 | def info(name, param, val): 58 | expected = val*4096*4096*(2.*i+1)/2. 
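        # 'val' is the fill value of the *other* parameter: loss = sum((x*a)*b),
        # so d(loss)/d(a) = x*b and d(loss)/d(b) = x*a elementwise. Rank r fills
        # x with (i + r); with 2 ranks the allreduce-averaged gradient is
        # val*(i + (i+1))/2 = val*(2*i+1)/2 per element, and summing over the
        # 4096*4096 elements gives the expression above.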
59 | actual = param.grad.data.sum().item() 60 | print(name+": grad.data_ptr() = {}, expected sum {}, got {}".format( 61 | param.grad.data_ptr(), expected, actual)) 62 | return (expected == actual) 63 | if not info("model.a", model.module.a, 2.): passed = False 64 | if not info("model.b", model.module.b, 1.): passed = False 65 | # torch.cuda.nvtx.range_pop() 66 | torch.cuda.cudart().cudaProfilerStop() 67 | 68 | print("passed = ", passed) 69 | -------------------------------------------------------------------------------- /apex/tests/distributed/DDP/run_race_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 ddp_race_condition_test.py 4 | -------------------------------------------------------------------------------- /apex/tests/distributed/amp_master_params/amp_master_params.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | from apex import amp 5 | # FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead) 6 | from apex.parallel import DistributedDataParallel 7 | 8 | parser = argparse.ArgumentParser() 9 | # FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied 10 | # automatically by torch.distributed.launch. 11 | parser.add_argument("--local_rank", default=0, type=int) 12 | args = parser.parse_args() 13 | 14 | # FOR DISTRIBUTED: If we are running under torch.distributed.launch, 15 | # the 'WORLD_SIZE' environment variable will also be set automatically. 16 | args.distributed = False 17 | if 'WORLD_SIZE' in os.environ: 18 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 19 | 20 | if args.distributed: 21 | # FOR DISTRIBUTED: Set the device according to local_rank. 22 | torch.cuda.set_device(args.local_rank) 23 | 24 | # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide 25 | # environment variables, and requires that you use init_method=`env://`. 26 | torch.distributed.init_process_group(backend='nccl', 27 | init_method='env://') 28 | 29 | torch.manual_seed(torch.distributed.get_rank()) 30 | 31 | torch.backends.cudnn.benchmark = True 32 | 33 | N, D_in, D_out = 64, 1024, 16 34 | 35 | # Each process receives its own batch of "fake input data" and "fake target data." 36 | # The "training loop" in each process just uses this fake batch over and over. 37 | # https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic 38 | # example of distributed data sampling for both training and validation. 39 | x = torch.randn(N, D_in, device='cuda') 40 | y = torch.randn(N, D_out, device='cuda') 41 | 42 | model = torch.nn.Linear(D_in, D_out).cuda() 43 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) 44 | 45 | model, optimizer = amp.initialize(model, optimizer, opt_level="O2") 46 | 47 | if args.distributed: 48 | # FOR DISTRIBUTED: After amp.initialize, wrap the model with 49 | # apex.parallel.DistributedDataParallel. 
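    # (With opt_level O2 the optimizer steps on fp32 master weights, while DDP
    # allreduces the half-precision gradients that backward produces on the
    # model's parameters; compare.py below checks that both stay identical
    # across ranks.)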
50 | model = DistributedDataParallel(model) 51 | # torch.nn.parallel.DistributedDataParallel is also fine, with some added args: 52 | # model = torch.nn.parallel.DistributedDataParallel(model, 53 | # device_ids=[args.local_rank], 54 | # output_device=args.local_rank) 55 | 56 | loss_fn = torch.nn.MSELoss() 57 | 58 | for t in range(500): 59 | optimizer.zero_grad() 60 | y_pred = model(x) 61 | loss = loss_fn(y_pred, y) 62 | with amp.scale_loss(loss, optimizer) as scaled_loss: 63 | scaled_loss.backward() 64 | optimizer.step() 65 | 66 | if args.local_rank == 0: 67 | print("final loss = ", loss) 68 | 69 | torch.save(list(model.parameters()), "rank{}model.pth".format(torch.distributed.get_rank())) 70 | torch.save(list(amp.master_params(optimizer)), "rank{}master.pth".format(torch.distributed.get_rank())) 71 | -------------------------------------------------------------------------------- /apex/tests/distributed/amp_master_params/compare.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | model_params_rank0 = torch.load("rank0model.pth", 4 | map_location = lambda storage, loc: storage.cuda(0)) 5 | model_params_rank1 = torch.load("rank1model.pth", 6 | map_location = lambda storage, loc: storage.cuda(0)) 7 | master_params_rank0 = torch.load("rank0master.pth", 8 | map_location = lambda storage, loc: storage.cuda(0)) 9 | master_params_rank1 = torch.load("rank1master.pth", 10 | map_location = lambda storage, loc: storage.cuda(0)) 11 | 12 | for model_rank0, model_rank1, master_rank0, master_rank1 in zip( 13 | model_params_rank0, 14 | model_params_rank1, 15 | master_params_rank0, 16 | master_params_rank1): 17 | assert torch.allclose(model_rank0, model_rank1), "Model param mismatch" 18 | assert torch.allclose(master_rank0, master_rank1), "Master param mismatch" 19 | # Some debugging/investigation assistance code: 20 | # maxval, maxind = torch.max(((torch.abs(model_rank0).float())/torch.abs(master_rank0)).view(-1), 0) 21 | # offending_val_half = model_rank0.view(-1)[maxind.item()] 22 | # offending_val_float = master_rank0.view(-1)[maxind.item()] 23 | # print(maxval.item(), maxind.item(), offending_val_half.item(), offending_val_float.item(), 24 | # offending_val_float.half().item()) 25 | # rtol needs to be > 2^-11 because of denormals... 
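    # (fp16 carries a 10-bit mantissa, so normal values round with relative
    # error up to 2^-11; denormals at the bottom of the range lose mantissa
    # bits and can be off by more, hence the looser rtol of 5e-3.)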
26 | assert torch.allclose(model_rank0, master_rank0.half(), rtol=.005), "Model-master mismatch" 27 | 28 | print("OK: Model and master params match across ranks.") 29 | -------------------------------------------------------------------------------- /apex/tests/distributed/amp_master_params/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py 3 | 4 | python compare.py 5 | -------------------------------------------------------------------------------- /apex/tests/distributed/synced_batchnorm/unit_test.sh: -------------------------------------------------------------------------------- 1 | python single_gpu_unit_test.py 2 | python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py 3 | python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp64 4 | #beware, you need a system with at least 4 gpus to test group_size<world_size -------------------------------------------------------------------------------- /jukebox/data/text_processor.py: -------------------------------------------------------------------------------- 13 | self.vocab['<unk>'] = 0 14 | self.n_vocab = len(vocab) + 1 15 | self.tokens = {v: k for k, v in self.vocab.items()} 16 | self.tokens[0] = '' # '<unk>' became '' 17 | self.not_vocab = not_vocab 18 | 19 | def clean(self, text): 20 | text = unidecode(text) # Convert to ascii 21 | text = text.replace('\\', '\n') 22 | text = self.not_vocab.sub('', text) # Remove non vocab 23 | return text 24 | 25 | def tokenise(self, text): 26 | return [self.vocab[char] for char in text] 27 | 28 | def textise(self, tokens): 29 | return ''.join([self.tokens[token] for token in tokens]) 30 | 31 | def characterise(self, tokens): 32 | return [self.tokens[token] for token in tokens] 33 | -------------------------------------------------------------------------------- /jukebox/prior/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/jukebox/prior/__init__.py -------------------------------------------------------------------------------- /jukebox/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/jukebox/transformer/__init__.py -------------------------------------------------------------------------------- /jukebox/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/jukebox/utils/__init__.py -------------------------------------------------------------------------------- /jukebox/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Simple gradient checkpointing.
Works with distributed data parallel. 2 | import torch as t 3 | 4 | def checkpoint(func, inputs, params, flag): 5 | if flag: 6 | args = inputs + tuple(params) 7 | return CheckpointFunction.apply(func, len(inputs), *args) 8 | else: 9 | return func(*inputs) 10 | 11 | class CheckpointFunction(t.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, run_function, length, *args): 14 | ctx.run_function = run_function 15 | ctx.input_tensors = list(args[:length]) 16 | ctx.input_params = list(args[length:]) 17 | with t.no_grad(): 18 | output_tensors = ctx.run_function(*ctx.input_tensors) 19 | return output_tensors 20 | 21 | @staticmethod 22 | def backward(ctx, *output_grads): 23 | for i in range(len(ctx.input_tensors)): 24 | temp = ctx.input_tensors[i] 25 | ctx.input_tensors[i] = temp.detach() 26 | ctx.input_tensors[i].requires_grad = temp.requires_grad 27 | with t.enable_grad(): 28 | output_tensors = ctx.run_function(*ctx.input_tensors) 29 | input_grads = t.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True) 30 | del ctx.input_tensors 31 | del output_tensors 32 | return (None, None) + input_grads 33 | -------------------------------------------------------------------------------- /jukebox/utils/dist_adapter.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | from enum import Enum 3 | 4 | class ReduceOp(Enum): 5 | SUM = 0 6 | PRODUCT = 1 7 | MIN = 2 8 | MAX = 3 9 | 10 | def ToDistOp(self): 11 | return { 12 | self.SUM: dist.ReduceOp.SUM, 13 | self.PRODUCT: dist.ReduceOp.PRODUCT, 14 | self.MIN: dist.ReduceOp.MIN, 15 | self.MAX: dist.ReduceOp.MAX 16 | }[self] 17 | 18 | def is_available(): 19 | return dist.is_available() 20 | 21 | def get_rank(): 22 | if is_available(): 23 | return _get_rank() 24 | else: 25 | return 0 26 | 27 | def get_world_size(): 28 | if is_available(): 29 | return _get_world_size() 30 | else: 31 | return 1 32 | 33 | def barrier(): 34 | if is_available(): 35 | return _barrier() 36 | #else: do nothing 37 | 38 | def all_gather(tensor_list, tensor): 39 | if is_available(): 40 | return _all_gather(tensor_list, tensor) 41 | else: 42 | tensor_list[0] = tensor 43 | 44 | def all_reduce(tensor, op=ReduceOp.SUM): 45 | if is_available(): 46 | return _all_reduce(tensor, op) 47 | #else: do nothing 48 | 49 | def reduce(tensor, dst, op=ReduceOp.SUM): 50 | if is_available(): 51 | return _reduce(tensor, dst, op) 52 | #else: do nothing 53 | 54 | def broadcast(tensor, src): 55 | if is_available(): 56 | return _broadcast(tensor, src) 57 | #else: do nothing 58 | 59 | def init_process_group(backend, init_method): 60 | if is_available(): 61 | return _init_process_group(backend, init_method) 62 | #else: do nothing 63 | 64 | def _get_rank(): 65 | return dist.get_rank() 66 | 67 | def _barrier(): 68 | return dist.barrier() 69 | 70 | def _get_world_size(): 71 | return dist.get_world_size() 72 | 73 | def _all_gather(tensor_list, tensor): 74 | return dist.all_gather(tensor_list, tensor) 75 | 76 | def _all_reduce(tensor, op): 77 | return dist.all_reduce(tensor, op.ToDistOp()) 78 | 79 | def _reduce(tensor, dst, op): 80 | return dist.reduce(tensor, dst, op.ToDistOp()) 81 | 82 | def _broadcast(tensor, src): 83 | return dist.broadcast(tensor, src) 84 | 85 | def _init_process_group(backend, init_method): 86 | return dist.init_process_group(backend, init_method) -------------------------------------------------------------------------------- /jukebox/utils/remote_utils.py:
-------------------------------------------------------------------------------- 1 | import sys 2 | import subprocess 3 | 4 | def download(remote_path, local_path, async_download=False): 5 | args = ['wget', '-O', local_path, remote_path] 6 | print("Running ", " ".join(args)) 7 | if async_download: 8 | subprocess.Popen(args) 9 | else: 10 | subprocess.call(args) 11 | 12 | # GCE 13 | def gs_download(gs_path, local_path, async_download=False): 14 | args = ['gsutil', 15 | '-o', 'GSUtil:parallel_thread_count=1', 16 | '-o', 'GSUtil:sliced_object_download_max_components=8', 17 | 'cp', gs_path, local_path] 18 | if async_download: 19 | subprocess.Popen(args) 20 | else: 21 | subprocess.call(args) 22 | 23 | 24 | def gs_upload(local_path, gs_path, async_upload=False): 25 | # NOTE: Download and upload have different -o flags. 26 | # We also use -n to prevent clobbering checkpoints by mistake. 27 | assert not local_path.startswith("gs://") 28 | assert gs_path.startswith("gs://") 29 | args = ['gsutil', 30 | '-o', 'GSUtil:parallel_composite_upload_threshold=150M', 31 | 'cp', '-n', local_path, gs_path] 32 | if async_upload: 33 | subprocess.Popen(args) 34 | else: 35 | subprocess.call(args) 36 | 37 | def ls(regex): 38 | outputs = subprocess.check_output(['gsutil', 'ls', regex]).decode(sys.stdout.encoding) 39 | outputs = outputs.split('\n') 40 | outputs = [output for output in outputs if output != ''] 41 | return outputs 42 | 43 | -------------------------------------------------------------------------------- /jukebox/utils/sample_utils.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | 3 | def split_batch(obj, n_samples, split_size): 4 | n_passes = (n_samples + split_size - 1) // split_size 5 | if isinstance(obj, t.Tensor): 6 | return t.split(obj, split_size, dim=0) 7 | elif isinstance(obj, list): 8 | return list(zip(*[t.split(item, split_size, dim=0) for item in obj])) 9 | elif obj is None: 10 | return [None] * n_passes 11 | else: 12 | raise TypeError('Unknown input type') 13 | 14 | # Break total_length into hops/windows of size n_ctx separated by hop_length (a short worked example follows torch_utils.py below) 15 | def get_starts(total_length, n_ctx, hop_length): 16 | starts = [] 17 | for start in range(0, total_length - n_ctx + hop_length, hop_length): 18 | if start + n_ctx >= total_length: 19 | # The last hop could be smaller; shift it back so the window still spans n_ctx to maximise context 20 | start = total_length - n_ctx 21 | starts.append(start) 22 | return starts 23 | -------------------------------------------------------------------------------- /jukebox/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch as t 3 | 4 | def freeze_model(model): 5 | model.eval() 6 | for params in model.parameters(): 7 | params.requires_grad = False 8 | 9 | 10 | def unfreeze_model(model): 11 | model.train() 12 | for params in model.parameters(): 13 | params.requires_grad = True 14 | 15 | def zero_grad(model): 16 | for p in model.parameters(): 17 | if p.requires_grad and p.grad is not None: 18 | p.grad = None 19 | 20 | def empty_cache(): 21 | gc.collect() 22 | t.cuda.empty_cache() 23 | 24 | def assert_shape(x, exp_shape): 25 | assert x.shape == exp_shape, f"Expected {exp_shape} got {x.shape}" 26 | 27 | def count_parameters(model): 28 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 29 | 30 | def count_state(model): 31 | return sum(s.numel() for s in model.state_dict().values()) 32 | 33 | 
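For intuition about the windowing logic in `get_starts` above, here is a minimal sanity check (the argument values are purely illustrative):
```
from jukebox.utils.sample_utils import get_starts

# 21 tokens, windows of n_ctx=8, hop_length=4: the final window is shifted
# back so it still spans n_ctx tokens and ends exactly at total_length.
print(get_starts(21, 8, 4))  # [0, 4, 8, 12, 13]
```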
-------------------------------------------------------------------------------- /jukebox/vqvae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/jukebox/vqvae/__init__.py -------------------------------------------------------------------------------- /jukebox/vqvae/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import jukebox.utils.dist_adapter as dist 4 | from jukebox.utils.checkpoint import checkpoint 5 | 6 | class ResConvBlock(nn.Module): 7 | def __init__(self, n_in, n_state): 8 | super().__init__() 9 | self.model = nn.Sequential( 10 | nn.ReLU(), 11 | nn.Conv2d(n_in, n_state, 3, 1, 1), 12 | nn.ReLU(), 13 | nn.Conv2d(n_state, n_in, 1, 1, 0), 14 | ) 15 | 16 | def forward(self, x): 17 | return x + self.model(x) 18 | 19 | class Resnet(nn.Module): 20 | def __init__(self, n_in, n_depth, m_conv=1.0): 21 | super().__init__() 22 | self.model = nn.Sequential(*[ResConvBlock(n_in, int(m_conv * n_in)) for _ in range(n_depth)]) 23 | 24 | def forward(self, x): 25 | return self.model(x) 26 | 27 | class ResConv1DBlock(nn.Module): 28 | def __init__(self, n_in, n_state, dilation=1, zero_out=False, res_scale=1.0): 29 | super().__init__() 30 | padding = dilation 31 | self.model = nn.Sequential( 32 | nn.ReLU(), 33 | nn.Conv1d(n_in, n_state, 3, 1, padding, dilation), 34 | nn.ReLU(), 35 | nn.Conv1d(n_state, n_in, 1, 1, 0), 36 | ) 37 | if zero_out: 38 | out = self.model[-1] 39 | nn.init.zeros_(out.weight) 40 | nn.init.zeros_(out.bias) 41 | self.res_scale = res_scale 42 | 43 | def forward(self, x): 44 | return x + self.res_scale * self.model(x) 45 | 46 | class Resnet1D(nn.Module): 47 | def __init__(self, n_in, n_depth, m_conv=1.0, dilation_growth_rate=1, dilation_cycle=None, zero_out=False, res_scale=False, reverse_dilation=False, checkpoint_res=False): 48 | super().__init__() 49 | def _get_depth(depth): 50 | if dilation_cycle is None: 51 | return depth 52 | else: 53 | return depth % dilation_cycle 54 | blocks = [ResConv1DBlock(n_in, int(m_conv * n_in), 55 | dilation=dilation_growth_rate ** _get_depth(depth), 56 | zero_out=zero_out, 57 | res_scale=1.0 if not res_scale else 1.0 / math.sqrt(n_depth)) 58 | for depth in range(n_depth)] 59 | if reverse_dilation: 60 | blocks = blocks[::-1] 61 | self.checkpoint_res = checkpoint_res 62 | if self.checkpoint_res == 1: 63 | if dist.get_rank() == 0: 64 | print("Checkpointing convs") 65 | self.blocks = nn.ModuleList(blocks) 66 | else: 67 | self.model = nn.Sequential(*blocks) 68 | 69 | def forward(self, x): 70 | if self.checkpoint_res == 1: 71 | for block in self.blocks: 72 | x = checkpoint(block, (x, ), block.parameters(), True) 73 | return x 74 | else: 75 | return self.model(x) 76 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fire==0.1.3 2 | tqdm==4.45.0 3 | soundfile==0.10.3.post1 4 | unidecode==1.1.1 5 | numba==0.48.0 6 | librosa==0.7.2 7 | mpi4py>=3.0.0 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="jukebox", 8 | py_modules=["jukebox"], 9 | version="1.0", 10 | 
description="", 11 | author="OpenAI", 12 | packages=find_packages(), 13 | install_requires=[ 14 | str(r) 15 | for r in pkg_resources.parse_requirements( 16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 17 | ) 18 | ], 19 | include_package_data=True 20 | ) 21 | -------------------------------------------------------------------------------- /tensorboardX/.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: # measuring the overall project coverage 4 | default: # context, you can create multiple ones with custom titles 5 | enabled: yes 6 | patch: 7 | default: 8 | enabled: no 9 | -------------------------------------------------------------------------------- /tensorboardX/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E305,E402,E721,E741,F401,F403,F405,F821,F841,F999 4 | exclude = tensorboardX/proto -------------------------------------------------------------------------------- /tensorboardX/.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create bug report 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **Minimal runnable code to reproduce the behavior** 14 | ``` 15 | from tensorboardX import SummaryWriter 16 | ... 17 | ``` 18 | 19 | **Expected behavior** 20 | A clear and concise description of what you expected to happen. 21 | 22 | **Screenshots** 23 | If applicable, add screenshots to help explain your problem. 24 | 25 | **Environment** 26 | What is the result of 27 | `pip list|grep -E "torch|proto|tensor"` 28 | If the version is too old, please try to update first. 29 | 30 | 31 | **Python environment** 32 | Which version of python are you using? Did you use Andconda or Virtualenv? 33 | 34 | **Additional context** 35 | Add any other context about the problem here. 36 | -------------------------------------------------------------------------------- /tensorboardX/.github/ISSUE_TEMPLATE/feature-requests-or-general-questions.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature requests or General questions 3 | about: Feature requests or general questions 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /tensorboardX/.gitignore: -------------------------------------------------------------------------------- 1 | proto_src/ 2 | protoc-*.zip 3 | protoc/ 4 | __pycache__ 5 | docs/_* 6 | build 7 | dist 8 | *.egg-info 9 | runs/* 10 | *.pyc 11 | -------------------------------------------------------------------------------- /tensorboardX/.travis.yml: -------------------------------------------------------------------------------- 1 | dist: xenial 2 | language: python 3 | python: 4 | # We don't actually use the Travis Python, but this keeps it organized. 
5 | - "2.7" 6 | - "3.6" 7 | 8 | env: 9 | - PYTORCH_VER="torch" 10 | - PYTORCH_VER="torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html" 11 | 12 | matrix: 13 | allow_failures: 14 | - env: PYTORCH_VER="torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html" 15 | 16 | install: 17 | - export MPLBACKEND=Agg 18 | - export CODECOV_TOKEN="26239910-fe4e-463d-aa3d-e662e9bf39ef" 19 | 20 | - sudo apt-get update 21 | # We do this conditionally because it saves us some downloading if the 22 | # version is the same. 23 | - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then 24 | wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh; 25 | else 26 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 27 | fi 28 | - bash miniconda.sh -b -p $HOME/miniconda 29 | - export PATH="$HOME/miniconda/bin:$PATH" 30 | - export BOTO_CONFIG=/dev/null # https://github.com/travis-ci/travis-ci/issues/7940 31 | - export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python 32 | - hash -r 33 | - conda config --set always_yes yes --set changeps1 no 34 | - conda update -q conda 35 | # Useful for debugging any issues with conda 36 | - conda info -a 37 | 38 | # Replace dep1 dep2 ... with your dependencies 39 | - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION 40 | - source activate test-environment 41 | - which python 42 | - pip install future 43 | - pip install chainer -q 44 | - pip install torchvision==0.2.1 -q 45 | - pip uninstall torch -y 46 | - pip install $PYTORCH_VER 47 | - pip install moviepy==0.2.3.2 -q 48 | - pip install matplotlib -q 49 | - pip install requests -q 50 | - pip install codecov 51 | - pip install onnx 52 | - pip install boto3 53 | - pip install moto 54 | - pip install visdom 55 | - pip install tb-nightly 56 | - pip install crc32c 57 | - pip install protobuf==3.8.0 58 | - conda install ffmpeg 59 | - conda list 60 | - python -c "import imageio; imageio.plugins.ffmpeg.download()" 61 | - pip install --upgrade pytest-cov flake8 62 | - python setup.py install 63 | 64 | script: 65 | - visdom & 66 | - sleep 5 67 | - python -c "import visdom; v = visdom.Visdom()" 68 | - py.test --cov=tensorboardX tests/ 69 | - python examples/demo.py 70 | - python examples/demo_graph.py 71 | - python examples/demo_embedding.py 72 | - python examples/demo_custom_scalars.py 73 | - python examples/demo_multiple_embedding.py 74 | - python examples/demo_purge.py 75 | - python examples/demo_matplotlib.py 76 | - pip uninstall -y tensorboardX 77 | - pip install tensorboardX 78 | - pytest 79 | 80 | after_success: 81 | - codecov 82 | -------------------------------------------------------------------------------- /tensorboardX/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Tzu-Wei Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tensorboardX/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include HISTORY.rst 2 | include LICENSE 3 | include compile.sh 4 | recursive-include tensorboardX/proto * 5 | recursive-exclude test * 6 | recursive-exclude examples * 7 | recursive-include tensorboardX/beholder * -------------------------------------------------------------------------------- /tensorboardX/compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Exit on error 4 | # set -e 5 | 6 | DESIRED_PROTO_VERSION="3.6.1" 7 | 8 | # Call protoc directly; if its version is not the desired one, download the desired version. 9 | 10 | 11 | if [ -f "protoc/bin/protoc" ]; then 12 | PROTOC_BIN="protoc/bin/protoc" 13 | else 14 | PROTOC_BIN=`which protoc` 15 | fi 16 | 17 | echo "using" $PROTOC_BIN 18 | 19 | CURRENT_PROTOC_VER=`${PROTOC_BIN} --version` 20 | if [ -z ${PROTOC_BIN} ] || [[ "$CURRENT_PROTOC_VER" != "libprotoc "$DESIRED_PROTO_VERSION ]]; then 21 | # Download and use the desired version of protoc. 22 | if [ "$(uname)" == "Darwin" ]; then 23 | PROTOC_ZIP="protoc-"$DESIRED_PROTO_VERSION"-osx-x86_64.zip" 24 | else 25 | PROTOC_ZIP="protoc-"$DESIRED_PROTO_VERSION"-linux-x86_64.zip" 26 | fi 27 | WGET_BIN=`which wget` 28 | if [[ ! -z ${WGET_BIN} ]]; then 29 | ${WGET_BIN} https://github.com/protocolbuffers/protobuf/releases/download/v"$DESIRED_PROTO_VERSION"/${PROTOC_ZIP} 30 | rm -rf protoc 31 | python -c "import zipfile; zipfile.ZipFile('"${PROTOC_ZIP}"','r').extractall('protoc')" 32 | PROTOC_BIN=protoc/bin/protoc 33 | chmod +x ${PROTOC_BIN} 34 | fi 35 | fi 36 | 37 | # Regenerate 38 | if [[ ! -z ${PROTOC_BIN} ]]; then 39 | # Delete all existing Python protobuf (*_pb2.py) output 40 | rm -rf tensorboardX/proto/*pb2*.py 41 | ${PROTOC_BIN} tensorboardX/proto/*.proto --python_out=. 42 | 43 | echo "Done generating tensorboardX/proto/*pb2*.py" 44 | else 45 | echo "protoc not installed so can't regenerate tensorboardX/proto/*pb2*.py, using precompiled version." 46 | fi 47 | 48 | -------------------------------------------------------------------------------- /tensorboardX/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = tensorboardX 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /tensorboardX/docs/index.rst: -------------------------------------------------------------------------------- 1 | .. tensorboardX documentation master file, created by 2 | sphinx-quickstart on Wed Aug 9 01:38:01 2017. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to tensorboardX's documentation! 7 | =============================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | tensorboard 14 | utils 15 | tutorial 16 | tutorial_zh 17 | 18 | Indices and tables 19 | ================== 20 | 21 | * :ref:`genindex` 22 | * :ref:`modindex` 23 | * :ref:`search` 24 | -------------------------------------------------------------------------------- /tensorboardX/docs/tensorboard.rst: -------------------------------------------------------------------------------- 1 | tensorboardX 2 | =================================== 3 | .. automodule:: tensorboardX 4 | 5 | .. autoclass:: SummaryWriter 6 | :members: 7 | 8 | .. automethod:: __init__ 9 | 10 | .. autoclass:: TorchVis 11 | :members: 12 | 13 | .. automethod:: __init__ -------------------------------------------------------------------------------- /tensorboardX/docs/utils.rst: -------------------------------------------------------------------------------- 1 | Helper functions 2 | =================================== 3 | .. autofunction:: tensorboardX.utils.figure_to_image -------------------------------------------------------------------------------- /tensorboardX/examples/RUN_AFTER_PIP_INSTALL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/tensorboardX/examples/RUN_AFTER_PIP_INSTALL -------------------------------------------------------------------------------- /tensorboardX/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/tensorboardX/examples/__init__.py -------------------------------------------------------------------------------- /tensorboardX/examples/chainer/extension_logger/updater.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import chainer 6 | import chainer.functions as F 7 | from chainer import Variable 8 | 9 | 10 | class DCGANUpdater(chainer.training.StandardUpdater): 11 | 12 | def __init__(self, *args, **kwargs): 13 | self.gen, self.dis = kwargs.pop('models') 14 | super(DCGANUpdater, self).__init__(*args, **kwargs) 15 | 16 | def loss_dis(self, dis, y_fake, y_real): 17 | batchsize = len(y_fake) 18 | L1 = F.sum(F.softplus(-y_real)) / batchsize 19 | L2 = F.sum(F.softplus(y_fake)) / batchsize 20 | loss = L1 + L2 21 | chainer.report({'loss': loss}, dis) 22 | return loss 23 | 24 | def loss_gen(self, gen, y_fake): 25 | batchsize = len(y_fake) 26 | loss = F.sum(F.softplus(-y_fake)) / batchsize 27 | chainer.report({'loss': loss}, gen) 28 | return loss 29 | 30 | def update_core(self): 31 | gen_optimizer = self.get_optimizer('gen') 32 | dis_optimizer = self.get_optimizer('dis') 33 | 34 | batch = self.get_iterator('main').next() 35 | x_real = 
Variable(self.converter(batch, self.device)) / 255. 36 | xp = chainer.cuda.get_array_module(x_real.data) 37 | 38 | gen, dis = self.gen, self.dis 39 | batchsize = len(batch) 40 | 41 | y_real = dis(x_real) 42 | 43 | z = Variable(xp.asarray(gen.make_hidden(batchsize))) 44 | x_fake = gen(z) 45 | y_fake = dis(x_fake) 46 | 47 | dis_optimizer.update(self.loss_dis, dis, y_fake, y_real) 48 | gen_optimizer.update(self.loss_gen, gen, y_fake) 49 | -------------------------------------------------------------------------------- /tensorboardX/examples/chainer/extension_logger/visualize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | 5 | import numpy as np 6 | from PIL import Image 7 | 8 | import chainer 9 | import chainer.cuda 10 | from chainer import Variable 11 | 12 | 13 | def out_generated_image(gen, dis, rows, cols, seed, dst, writer): 14 | @chainer.training.make_extension() 15 | def make_image(trainer): 16 | np.random.seed(seed) 17 | n_images = rows * cols 18 | xp = gen.xp 19 | z = Variable(xp.asarray(gen.make_hidden(n_images))) 20 | with chainer.using_config('train', False): 21 | x = gen(z) 22 | writer.add_image('img', x, trainer.updater.iteration) 23 | 24 | return make_image 25 | -------------------------------------------------------------------------------- /tensorboardX/examples/chainer/plain_logger/data.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import os 3 | 4 | import numpy as np 5 | import six 6 | from six.moves.urllib import request 7 | 8 | parent = 'http://yann.lecun.com/exdb/mnist' 9 | train_images = 'train-images-idx3-ubyte.gz' 10 | train_labels = 'train-labels-idx1-ubyte.gz' 11 | test_images = 't10k-images-idx3-ubyte.gz' 12 | test_labels = 't10k-labels-idx1-ubyte.gz' 13 | num_train = 60000 14 | num_test = 10000 15 | dim = 784 16 | 17 | 18 | def load_mnist(images, labels, num): 19 | data = np.zeros(num * dim, dtype=np.uint8).reshape((num, dim)) 20 | target = np.zeros(num, dtype=np.uint8).reshape((num, )) 21 | 22 | with gzip.open(images, 'rb') as f_images,\ 23 | gzip.open(labels, 'rb') as f_labels: 24 | f_images.read(16) 25 | f_labels.read(8) 26 | for i in six.moves.range(num): 27 | target[i] = ord(f_labels.read(1)) 28 | for j in six.moves.range(dim): 29 | data[i, j] = ord(f_images.read(1)) 30 | 31 | return data, target 32 | 33 | 34 | def download_mnist_data(): 35 | print('Downloading {:s}...'.format(train_images)) 36 | request.urlretrieve('{:s}/{:s}'.format(parent, train_images), train_images) 37 | print('Done') 38 | print('Downloading {:s}...'.format(train_labels)) 39 | request.urlretrieve('{:s}/{:s}'.format(parent, train_labels), train_labels) 40 | print('Done') 41 | print('Downloading {:s}...'.format(test_images)) 42 | request.urlretrieve('{:s}/{:s}'.format(parent, test_images), test_images) 43 | print('Done') 44 | print('Downloading {:s}...'.format(test_labels)) 45 | request.urlretrieve('{:s}/{:s}'.format(parent, test_labels), test_labels) 46 | print('Done') 47 | 48 | print('Converting training data...') 49 | data_train, target_train = load_mnist(train_images, train_labels, 50 | num_train) 51 | print('Done') 52 | print('Converting test data...') 53 | data_test, target_test = load_mnist(test_images, test_labels, num_test) 54 | mnist = {'data': np.append(data_train, data_test, axis=0), 55 | 'target': np.append(target_train, target_test, axis=0)} 56 | print('Done') 57 | print('Save output...') 58 | with open('mnist.pkl', 'wb') as output: 59 
| six.moves.cPickle.dump(mnist, output, -1) 60 | print('Done') 61 | print('Convert completed') 62 | 63 | 64 | def load_mnist_data(): 65 | if not os.path.exists('mnist.pkl'): 66 | download_mnist_data() 67 | with open('mnist.pkl', 'rb') as mnist_pickle: 68 | mnist = six.moves.cPickle.load(mnist_pickle) 69 | return mnist 70 | -------------------------------------------------------------------------------- /tensorboardX/examples/chainer/plain_logger/net.py: -------------------------------------------------------------------------------- 1 | import six 2 | 3 | import chainer 4 | import chainer.functions as F 5 | from chainer.functions.loss.vae import gaussian_kl_divergence 6 | import chainer.links as L 7 | 8 | 9 | class VAE(chainer.Chain): 10 | """Variational AutoEncoder""" 11 | 12 | def __init__(self, n_in, n_latent, n_h): 13 | super(VAE, self).__init__() 14 | with self.init_scope(): 15 | # encoder 16 | self.le1 = L.Linear(n_in, n_h) 17 | self.le2_mu = L.Linear(n_h, n_latent) 18 | self.le2_ln_var = L.Linear(n_h, n_latent) 19 | # decoder 20 | self.ld1 = L.Linear(n_latent, n_h) 21 | self.ld2 = L.Linear(n_h, n_in) 22 | 23 | def __call__(self, x, sigmoid=True): 24 | """AutoEncoder""" 25 | return self.decode(self.encode(x)[0], sigmoid) 26 | 27 | def encode(self, x): 28 | h1 = F.tanh(self.le1(x)) 29 | mu = self.le2_mu(h1) 30 | ln_var = self.le2_ln_var(h1) # log(sigma**2) 31 | return mu, ln_var 32 | 33 | def decode(self, z, sigmoid=True): 34 | h1 = F.tanh(self.ld1(z)) 35 | h2 = self.ld2(h1) 36 | if sigmoid: 37 | return F.sigmoid(h2) 38 | else: 39 | return h2 40 | 41 | def get_loss_func(self, C=1.0, k=1): 42 | """Get loss function of VAE. 43 | 44 | The loss value is equal to ELBO (Evidence Lower Bound) 45 | multiplied by -1. 46 | 47 | Args: 48 | C (int): Usually this is 1.0. Can be changed to control the 49 | second term of ELBO bound, which works as regularization. 50 | k (int): Number of Monte Carlo samples used in encoded vector. 51 | """ 52 | def lf(x): 53 | mu, ln_var = self.encode(x) 54 | batchsize = len(mu.data) 55 | # reconstruction loss 56 | rec_loss = 0 57 | for l in six.moves.range(k): 58 | z = F.gaussian(mu, ln_var) 59 | rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) \ 60 | / (k * batchsize) 61 | self.rec_loss = rec_loss 62 | self.loss = self.rec_loss + \ 63 | C * gaussian_kl_divergence(mu, ln_var) / batchsize 64 | return self.loss 65 | return lf 66 | -------------------------------------------------------------------------------- /tensorboardX/examples/demo_beholder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Simple MNIST classifier to demonstrate features of Beholder. 16 | 17 | Based on tensorflow/examples/tutorials/mnist/mnist_with_summaries.py. 
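Note that this tensorboardX port does not train a real model: it streams
randomly generated arrays through the Beholder writer so the TensorBoard
plugin can be exercised without TensorFlow. A minimal sketch of the calls
made in beholder_pytorch() below, assuming `arrays` is a list of
tensor_and_name tuples as constructed there:

    from tensorboardX.beholder import Beholder
    beholder = Beholder(logdir=LOG_DIRECTORY)
    beholder.update(trainable=arrays, arrays=arrays,
                    frame=np.random.randn(128, 128))

Run the demo, then point TensorBoard at LOG_DIRECTORY to view the frames.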
18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import numpy as np 25 | import tensorboardX.beholder as beholder_lib 26 | import time 27 | 28 | from collections import namedtuple 29 | 30 | 31 | LOG_DIRECTORY = '/tmp/beholder-demo' 32 | tensor_and_name = namedtuple('tensor_and_name', 'tensor, name') 33 | 34 | 35 | def beholder_pytorch(): 36 | for i in range(1000): 37 | fake_param = [tensor_and_name(np.random.randn(128, 768, 3), 'test' + str(i)) 38 | for i in range(5)] 39 | arrays = [tensor_and_name(np.random.randn(128, 768, 3), 'test' + str(i)) 40 | for i in range(5)] 41 | beholder = beholder_lib.Beholder(logdir=LOG_DIRECTORY) 42 | beholder.update( 43 | trainable=fake_param, 44 | arrays=arrays, 45 | frame=np.random.randn(128, 128), 46 | ) 47 | time.sleep(0.1) 48 | print(i) 49 | 50 | 51 | if __name__ == '__main__': 52 | import os 53 | if not os.path.exists(LOG_DIRECTORY): 54 | os.makedirs(LOG_DIRECTORY) 55 | print(LOG_DIRECTORY) 56 | beholder_pytorch() 57 | -------------------------------------------------------------------------------- /tensorboardX/examples/demo_custom_scalars.py: -------------------------------------------------------------------------------- 1 | from numpy.random import rand 2 | from tensorboardX import SummaryWriter 3 | import time 4 | 5 | 6 | with SummaryWriter() as writer: 7 | for n_iter in range(100): 8 | writer.add_scalar('twse/0050', rand(), n_iter) 9 | writer.add_scalar('twse/2330', rand(), n_iter) 10 | t = rand() 11 | writer.add_scalar('dow/aaa', t, n_iter) 12 | writer.add_scalar('dow/bbb', t - 1, n_iter) 13 | writer.add_scalar('dow/ccc', t + 1, n_iter) 14 | writer.add_scalar('nasdaq/aaa', rand(), n_iter) 15 | writer.add_scalar('nasdaq/bbb', rand(), n_iter) 16 | writer.add_scalar('nasdaq/ccc', rand(), n_iter) 17 | 18 | layout = {'Taiwan': {'twse': ['Multiline', ['twse/0050', 'twse/2330']]}, 19 | 'USA': {'dow': ['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']], 20 | 'nasdaq': ['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]}} 21 | writer.add_custom_scalars(layout) 22 | # writer.add_custom_scalars(layout) second call has no effect 23 | 24 | time.sleep(1) 25 | 26 | with SummaryWriter() as writer: 27 | for n_iter in range(100): 28 | writer.add_scalar('twse/0050', rand(), n_iter) 29 | writer.add_scalar('twse/2330', rand(), n_iter) 30 | 31 | writer.add_custom_scalars_multilinechart(['twse/0050', 'twse/2330']) 32 | 33 | time.sleep(1) 34 | 35 | with SummaryWriter() as writer: 36 | for n_iter in range(100): 37 | t = rand() 38 | writer.add_scalar('dow/aaa', t, n_iter) 39 | writer.add_scalar('dow/bbb', t - 1, n_iter) 40 | writer.add_scalar('dow/ccc', t + 1, n_iter) 41 | 42 | writer.add_custom_scalars_marginchart(['dow/aaa', 'dow/bbb', 'dow/ccc']) 43 | -------------------------------------------------------------------------------- /tensorboardX/examples/demo_hparams.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | import time 3 | import random 4 | 5 | 6 | hparam = {'lr': [0.1, 0.01, 0.001], 7 | 'bsize': [1, 2, 4], 8 | 'n_hidden': [100, 200]} 9 | 10 | metrics = {'accuracy', 'loss'} 11 | 12 | def train(lr, bsize, n_hidden): 13 | x = random.random() 14 | return x, x*5 15 | 16 | with SummaryWriter() as w: 17 | for lr in hparam['lr']: 18 | for bsize in hparam['bsize']: 19 | for n_hidden in hparam['n_hidden']: 20 | accu, loss = train(lr, bsize, n_hidden) 21 | 22 | w.add_hparams({'lr': lr, 'bsize': 
bsize, 'n_hidden': n_hidden}, 23 | {'accuracy': accu, 'loss': loss}) 24 | 25 | -------------------------------------------------------------------------------- /tensorboardX/examples/demo_matplotlib.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | plt.switch_backend('agg') 3 | 4 | fig = plt.figure() 5 | 6 | c1 = plt.Circle((0.2, 0.5), 0.2, color='r') 7 | c2 = plt.Circle((0.8, 0.5), 0.2, color='r') 8 | 9 | ax = plt.gca() 10 | ax.add_patch(c1) 11 | ax.add_patch(c2) 12 | plt.axis('scaled') 13 | 14 | 15 | from tensorboardX import SummaryWriter 16 | writer = SummaryWriter() 17 | writer.add_figure('matplotlib', fig) 18 | writer.close() 19 | -------------------------------------------------------------------------------- /tensorboardX/examples/demo_multiple_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from tensorboardX import SummaryWriter 4 | 5 | 6 | def main(): 7 | degrees = np.linspace(0, 3600 * math.pi / 180.0, 3600) 8 | degrees = degrees.reshape(3600, 1) 9 | labels = ["%d" % (i) for i in range(0, 3600)] 10 | 11 | with SummaryWriter() as writer: 12 | # Maybe make a bunch of data that's always shifted in some 13 | # way, and that will be hard for PCA to turn into a sphere? 14 | 15 | for epoch in range(0, 16): 16 | shift = epoch * 2 * math.pi / 16.0 17 | mat = np.concatenate([ 18 | np.sin(shift + degrees * 2 * math.pi / 180.0), 19 | np.sin(shift + degrees * 3 * math.pi / 180.0), 20 | np.sin(shift + degrees * 5 * math.pi / 180.0), 21 | np.sin(shift + degrees * 7 * math.pi / 180.0), 22 | np.sin(shift + degrees * 11 * math.pi / 180.0) 23 | ], axis=1) 24 | writer.add_embedding( 25 | mat=mat, 26 | metadata=labels, 27 | tag="sin", 28 | global_step=epoch) 29 | 30 | mat = np.concatenate([ 31 | np.cos(shift + degrees * 2 * math.pi / 180.0), 32 | np.cos(shift + degrees * 3 * math.pi / 180.0), 33 | np.cos(shift + degrees * 5 * math.pi / 180.0), 34 | np.cos(shift + degrees * 7 * math.pi / 180.0), 35 | np.cos(shift + degrees * 11 * math.pi / 180.0) 36 | ], axis=1) 37 | writer.add_embedding( 38 | mat=mat, 39 | metadata=labels, 40 | tag="cos", 41 | global_step=epoch) 42 | 43 | mat = np.concatenate([ 44 | np.tan(shift + degrees * 2 * math.pi / 180.0), 45 | np.tan(shift + degrees * 3 * math.pi / 180.0), 46 | np.tan(shift + degrees * 5 * math.pi / 180.0), 47 | np.tan(shift + degrees * 7 * math.pi / 180.0), 48 | np.tan(shift + degrees * 11 * math.pi / 180.0) 49 | ], axis=1) 50 | writer.add_embedding( 51 | mat=mat, 52 | metadata=labels, 53 | tag="tan", 54 | global_step=epoch) 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | 60 | # tensorboard --logdir runs 61 | # Under "Projection, you should see 62 | # 48 tensor found named 63 | # cos:cos-00000 to cos:cos-00016 64 | # sin:sin-00000 to sin:sin-00016 65 | # tan:tan-00000 to tan:tan-00016 66 | -------------------------------------------------------------------------------- /tensorboardX/examples/demo_nvidia_smi.py: -------------------------------------------------------------------------------- 1 | """ 2 | write gpu and (gpu) memory usage of nvidia cards as scalar 3 | """ 4 | from tensorboardX import SummaryWriter 5 | import time 6 | import torch 7 | try: 8 | import nvidia_smi 9 | nvidia_smi.nvmlInit() 10 | handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0) # gpu0 11 | except ImportError: 12 | print('This demo needs nvidia-ml-py or nvidia-ml-py3') 13 | exit() 14 | 15 | 16 | with 
SummaryWriter() as writer:
17 |     x = []
18 |     for n_iter in range(50):
19 |         x.append(torch.Tensor(1000, 1000).cuda())
20 |         res = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
21 |         writer.add_scalar('nv/gpu', res.gpu, n_iter)
22 |         res = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
23 |         writer.add_scalar('nv/gpu_mem', res.used, n_iter)
24 |         time.sleep(0.1)
25 | 
-------------------------------------------------------------------------------- /tensorboardX/examples/demo_onnx.py: --------------------------------------------------------------------------------
1 | from tensorboardX import SummaryWriter
2 | 
3 | import subprocess
4 | zoo_address = 'https://onnxzoo.blob.core.windows.net/models/opset_8/mnist/mnist.tar.gz'
5 | 
6 | res = subprocess.call(['wget', '-nc', zoo_address])
7 | assert res == 0, 'cannot download example onnx model from the zoo'
8 | res = subprocess.call(['tar', 'xf', 'mnist.tar.gz', '-C', 'examples/', 'mnist/model.onnx'])
9 | assert res == 0, 'cannot extract the example onnx model'
10 | 
11 | 
12 | with SummaryWriter() as w:
13 |     w.add_onnx_graph('examples/mnist/model.onnx')
14 |     # w.add_onnx_graph('/Users/dexter/Downloads/resnet50/model.onnx')
-------------------------------------------------------------------------------- /tensorboardX/examples/demo_purge.py: --------------------------------------------------------------------------------
1 | from time import sleep
2 | from tensorboardX import SummaryWriter
3 | 
4 | with SummaryWriter(logdir='runs/purge') as w:
5 |     for i in range(100):
6 |         w.add_scalar('purgetest', i, i)
7 | 
8 | sleep(1.0)
9 | 
10 | with SummaryWriter(logdir='runs/purge', purge_step=42) as w:
11 |     # events 42 to 99 (inclusive) are removed
12 |     for i in range(42, 100):
13 |         w.add_scalar('purgetest', 42, i)
-------------------------------------------------------------------------------- /tensorboardX/examples/tensorboardX: --------------------------------------------------------------------------------
1 | ../tensorboardX/
-------------------------------------------------------------------------------- /tensorboardX/screenshots/Demo.gif: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/tensorboardX/screenshots/Demo.gif
-------------------------------------------------------------------------------- /tensorboardX/screenshots/audio.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/tensorboardX/screenshots/audio.png
-------------------------------------------------------------------------------- /tensorboardX/screenshots/distribution.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/tensorboardX/screenshots/distribution.png
-------------------------------------------------------------------------------- /tensorboardX/screenshots/embedding.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/tensorboardX/screenshots/embedding.png
-------------------------------------------------------------------------------- /tensorboardX/screenshots/graph.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/tensorboardX/screenshots/graph.png -------------------------------------------------------------------------------- /tensorboardX/screenshots/histogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/tensorboardX/screenshots/histogram.png -------------------------------------------------------------------------------- /tensorboardX/screenshots/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/tensorboardX/screenshots/image.png -------------------------------------------------------------------------------- /tensorboardX/screenshots/scalar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/tensorboardX/screenshots/scalar.png -------------------------------------------------------------------------------- /tensorboardX/screenshots/text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/tensorboardX/screenshots/text.png -------------------------------------------------------------------------------- /tensorboardX/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_file = LICENSE 3 | 4 | [bdist_wheel] 5 | universal = 1 6 | -------------------------------------------------------------------------------- /tensorboardX/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import subprocess 5 | import os 6 | from setuptools import setup, find_packages 7 | from setuptools.command.develop import develop 8 | from setuptools.command.install import install 9 | 10 | # Dynamically compile protos 11 | def compileProtoBuf(): 12 | res = subprocess.call(['bash', './compile.sh']) 13 | assert res == 0, 'cannot compile protobuf' 14 | 15 | class PostDevelopCommand(develop): 16 | """Post-installation for development mode.""" 17 | def run(self): 18 | compileProtoBuf() 19 | develop.run(self) 20 | 21 | 22 | class PostInstallCommand(install): 23 | """Post-installation for installation mode.""" 24 | def run(self): 25 | compileProtoBuf() 26 | import os 27 | os.system("pip install protobuf numpy six") 28 | install.run(self) 29 | 30 | with open('HISTORY.rst') as history_file: 31 | history = history_file.read() 32 | 33 | preparing_PyPI_package = False 34 | version_git = version = '1.8' 35 | 36 | if not preparing_PyPI_package: 37 | if os.path.exists('.git'): 38 | sha = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() 39 | version_git = version_git + '+' + sha[:7] 40 | 41 | with open('tensorboardX/__init__.py', 'a') as f: 42 | f.write('\n__version__ = "{}"\n'.format(version_git)) 43 | 44 | requirements = [ 45 | 'numpy', 46 | 'protobuf >= 3.6.1', 47 | 'six', 48 | ] 49 | 50 | test_requirements = [ 51 | 'pytest', 52 | 'matplotlib', 53 | 'crc32c', 54 | ] 55 | 56 | setup( 57 | name='tensorboardX', 58 | version=version_git, 59 | description='TensorBoardX lets you watch Tensors Flow without Tensorflow', 60 | long_description=history, 61 | 
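# `history` above is the contents of HISTORY.rst read near the top of this
# file; PyPI displays it as the long project description.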
author='Tzu-Wei Huang', 62 | author_email='huang.dexter@gmail.com', 63 | url='https://github.com/lanpa/tensorboardX', 64 | packages=['tensorboardX'], 65 | include_package_data=True, 66 | install_requires=requirements, 67 | license='MIT license', 68 | zip_safe=False, 69 | classifiers=[ 70 | 'Development Status :: 2 - Pre-Alpha', 71 | 'Intended Audience :: Developers', 72 | 'License :: OSI Approved :: MIT License', 73 | 'Natural Language :: English', 74 | 'Programming Language :: Python :: 2', 75 | 'Programming Language :: Python :: 2.7', 76 | 'Programming Language :: Python :: 3', 77 | 'Programming Language :: Python :: 3.4', 78 | 'Programming Language :: Python :: 3.5', 79 | 'Programming Language :: Python :: 3.6', 80 | ], 81 | cmdclass={ 82 | 'develop': PostDevelopCommand, 83 | 'install': PostInstallCommand, 84 | }, 85 | test_suite='tests', 86 | tests_require=test_requirements 87 | ) 88 | 89 | 90 | # checklist: update History.rst readme.md 91 | # change preparing_PyPI_package to True 92 | # remove __version__ = "1.old" in __init__.py 93 | # commit 94 | # add tag 95 | # python setup.py sdist bdist_wheel --universal 96 | # twine upload dist/* 97 | # push commit -------------------------------------------------------------------------------- /tensorboardX/tensorboardX.patch: -------------------------------------------------------------------------------- 1 | diff --git a/tensorboardX/summary.py b/tensorboardX/summary.py 2 | index 27d99ea..f5bf234 100644 3 | --- a/tensorboardX/summary.py 4 | +++ b/tensorboardX/summary.py 5 | @@ -373,36 +373,24 @@ def make_video(tensor, fps): 6 | 7 | def audio(tag, tensor, sample_rate=44100): 8 | tensor = make_np(tensor) 9 | - tensor = tensor.squeeze() 10 | if abs(tensor).max() > 1: 11 | print('warning: audio amplitude out of range, auto clipped.') 12 | tensor = tensor.clip(-1, 1) 13 | - assert(tensor.ndim == 1), 'input tensor should be 1 dimensional.' 14 | - 15 | - tensor_list = [int(32767.0 * x) for x in tensor] 16 | + assert(tensor.ndim == 2), 'input tensor should be 2 dimensional.' 
17 | + length_frames, num_channels = tensor.shape 18 | + assert num_channels == 1 or num_channels == 2, f'Expected 1/2 channels, got {num_channels}' 19 | + import soundfile 20 | import io 21 | - import wave 22 | - import struct 23 | - fio = io.BytesIO() 24 | - Wave_write = wave.open(fio, 'wb') 25 | - Wave_write.setnchannels(1) 26 | - Wave_write.setsampwidth(2) 27 | - Wave_write.setframerate(sample_rate) 28 | - tensor_enc = b'' 29 | - tensor_enc += struct.pack("<" + "h" * len(tensor_list), *tensor_list) 30 | - 31 | - Wave_write.writeframes(tensor_enc) 32 | - Wave_write.close() 33 | - audio_string = fio.getvalue() 34 | - fio.close() 35 | + with io.BytesIO() as fio: 36 | + soundfile.write(fio, tensor, samplerate=sample_rate, format='wav') 37 | + audio_string = fio.getvalue() 38 | audio = Summary.Audio(sample_rate=sample_rate, 39 | - num_channels=1, 40 | - length_frames=len(tensor_list), 41 | + num_channels=num_channels, 42 | + length_frames=length_frames, 43 | encoded_audio_string=audio_string, 44 | content_type='audio/wav') 45 | return Summary(value=[Summary.Value(tag=tag, audio=audio)]) 46 | 47 | - 48 | def custom_scalars(layout): 49 | categoriesnames = layout.keys() 50 | categories = [] 51 | diff --git a/tensorboardX/writer.py b/tensorboardX/writer.py 52 | index 06337a7..58d57a1 100644 53 | --- a/tensorboardX/writer.py 54 | +++ b/tensorboardX/writer.py 55 | @@ -716,7 +716,7 @@ class SummaryWriter(object): 56 | sample_rate (int): sample rate in Hz 57 | walltime (float): Optional override default walltime (time.time()) of event 58 | Shape: 59 | - snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1]. 60 | + snd_tensor: :math:`(L, c)`. The values should lie between [-1, 1]. 61 | """ 62 | if self._check_caffe2_blob(snd_tensor): 63 | snd_tensor = workspace.FetchBlob(snd_tensor) 64 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/__init__.py: -------------------------------------------------------------------------------- 1 | """A module for visualization with tensorboard 2 | """ 3 | 4 | from .record_writer import RecordWriter 5 | from .torchvis import TorchVis 6 | from .writer import FileWriter, SummaryWriter 7 | 8 | __version__ = "1.8" # will be overwritten if run setup.py 9 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/beholder/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .beholder import Beholder 16 | from .beholder import BeholderHook 17 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/beholder/file_system_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 | 
19 | import pickle
20 | 
21 | # import tensorflow as tf
22 | # from google.protobuf import message
23 | 
24 | 
25 | def write_file(contents, path, mode='wb'):
26 |     with open(path, mode) as new_file:
27 |         new_file.write(contents)
28 | 
29 | 
30 | def write_pickle(obj, path):
31 |     with open(path, 'wb') as new_file:
32 |         pickle.dump(obj, new_file)
33 | 
34 | 
35 | def read_pickle(path, default=None):
36 |     # Honor the `default` argument: return it when the file is missing or
37 |     # cannot be unpickled, instead of raising.
38 |     try:
39 |         with open(path, 'rb') as pickle_file:
40 |             result = pickle.load(pickle_file)
41 |     except (IOError, OSError, EOFError, pickle.UnpicklingError):
42 |         result = default
43 |     return result
44 | 
-------------------------------------------------------------------------------- /tensorboardX/tensorboardX/beholder/shared_config.py: --------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
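# These constants are shared between the Beholder writer (beholder.py) and
# the TensorBoard plugin that reads its output. The *_FILENAME values name
# the files dropped into the log directory (written and read with the pickle
# helpers in file_system_tools.py above), and DEFAULT_CONFIG below seeds the
# visualization settings.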
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | PLUGIN_NAME = 'beholder' 20 | TAG_NAME = 'beholder-frame' 21 | SUMMARY_FILENAME = 'frame.summary' 22 | CONFIG_FILENAME = 'config.pkl' 23 | SECTION_INFO_FILENAME = 'section-info.pkl' 24 | SUMMARY_COLLECTION_KEY_NAME = 'summaries_beholder' 25 | 26 | DEFAULT_CONFIG = { 27 | 'values': 'trainable_variables', 28 | 'mode': 'variance', 29 | 'scaling': 'layer', 30 | 'window_size': 15, 31 | 'FPS': 10, 32 | 'is_recording': False, 33 | 'show_all': False, 34 | 'colormap': 'magma' 35 | } 36 | 37 | SECTION_HEIGHT = 128 38 | IMAGE_WIDTH = 512 + 256 39 | 40 | TB_WHITE = 245 41 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/onnx_graph.py: -------------------------------------------------------------------------------- 1 | from .proto.graph_pb2 import GraphDef 2 | from .proto.node_def_pb2 import NodeDef 3 | from .proto.versions_pb2 import VersionDef 4 | from .proto.attr_value_pb2 import AttrValue 5 | from .proto.tensor_shape_pb2 import TensorShapeProto 6 | 7 | 8 | def load_onnx_graph(fname): 9 | import onnx 10 | m = onnx.load(fname) 11 | g = m.graph 12 | return parse(g) 13 | 14 | 15 | def parse(graph): 16 | nodes_proto = [] 17 | nodes = [] 18 | import itertools 19 | for node in itertools.chain(graph.input, graph.output): 20 | nodes_proto.append(node) 21 | 22 | for node in nodes_proto: 23 | print(node.name) 24 | shapeproto = TensorShapeProto( 25 | dim=[TensorShapeProto.Dim(size=d.dim_value) for d in node.type.tensor_type.shape.dim]) 26 | nodes.append(NodeDef( 27 | name=node.name.encode(encoding='utf_8'), 28 | op='Variable', 29 | input=[], 30 | attr={ 31 | 'dtype': AttrValue(type=node.type.tensor_type.elem_type), 32 | 'shape': AttrValue(shape=shapeproto), 33 | }) 34 | ) 35 | 36 | for node in graph.node: 37 | attr = [] 38 | for s in node.attribute: 39 | attr.append(' = '.join([str(f[1]) for f in s.ListFields()])) 40 | attr = ', '.join(attr).encode(encoding='utf_8') 41 | print(node.output[0]) 42 | nodes.append(NodeDef( 43 | name=node.output[0].encode(encoding='utf_8'), 44 | op=node.op_type, 45 | input=node.input, 46 | attr={'parameters': AttrValue(s=attr)}, 47 | )) 48 | 49 | # two pass token replacement, appends opname to object id 50 | mapping = {} 51 | for node in nodes: 52 | mapping[node.name] = node.op + '_' + node.name 53 | 54 | return GraphDef(node=nodes, versions=VersionDef(producer=22)) 55 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/proto/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/jukebox/08efbbc1d4ed1a3cef96e08a931944c8b4d63bb3/tensorboardX/tensorboardX/proto/__init__.py -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/proto/attr_value.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboardX; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "AttrValueProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboardX/proto/tensor.proto"; 10 | import "tensorboardX/proto/tensor_shape.proto"; 11 | import "tensorboardX/proto/types.proto"; 12 | 13 | // Protocol buffer representing the value for an attr used to configure an Op. 
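// For example, an attr declared as "int" with value 3 is carried as an
// AttrValue with the `i` field set to 3; an attr declared as "list(int)"
// fills the repeated `i` inside `list` instead.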
14 | // Comment indicates the corresponding attr type. Only the field matching the 15 | // attr type may be filled. 16 | message AttrValue { 17 | // LINT.IfChange 18 | message ListValue { 19 | repeated bytes s = 2; // "list(string)" 20 | repeated int64 i = 3 [packed = true]; // "list(int)" 21 | repeated float f = 4 [packed = true]; // "list(float)" 22 | repeated bool b = 5 [packed = true]; // "list(bool)" 23 | repeated DataType type = 6 [packed = true]; // "list(type)" 24 | repeated TensorShapeProto shape = 7; // "list(shape)" 25 | repeated TensorProto tensor = 8; // "list(tensor)" 26 | repeated NameAttrList func = 9; // "list(attr)" 27 | } 28 | // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) 29 | 30 | oneof value { 31 | bytes s = 2; // "string" 32 | int64 i = 3; // "int" 33 | float f = 4; // "float" 34 | bool b = 5; // "bool" 35 | DataType type = 6; // "type" 36 | TensorShapeProto shape = 7; // "shape" 37 | TensorProto tensor = 8; // "tensor" 38 | ListValue list = 1; // any "list(...)" 39 | 40 | // "func" represents a function. func.name is a function's name or 41 | // a primitive op's name. func.attr.first is the name of an attr 42 | // defined for that function. func.attr.second is the value for 43 | // that attr in the instantiation. 44 | NameAttrList func = 10; 45 | 46 | // This is a placeholder only used in nodes defined inside a 47 | // function. It indicates the attr value will be supplied when 48 | // the function is instantiated. For example, let us suppose a 49 | // node "N" in function "FN". "N" has an attr "A" with value 50 | // placeholder = "foo". When FN is instantiated with attr "foo" 51 | // set to "bar", the instantiated node N's attr A will have been 52 | // given the value "bar". 53 | string placeholder = 9; 54 | } 55 | } 56 | 57 | // A list of attr names and their values. The whole list is attached 58 | // with a string name. E.g., MatMul[T=float]. 59 | message NameAttrList { 60 | string name = 1; 61 | map attr = 2; 62 | } 63 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/proto/event.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboardX; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "EventProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.util"; 8 | 9 | import "tensorboardX/proto/summary.proto"; 10 | 11 | // Protocol buffer representing an event that happened during 12 | // the execution of a Brain model. 13 | message Event { 14 | // Timestamp of the event. 15 | double wall_time = 1; 16 | 17 | // Global step of the event. 18 | int64 step = 2; 19 | 20 | oneof what { 21 | // An event file was started, with the specified version. 22 | // This is use to identify the contents of the record IO files 23 | // easily. Current version is "brain.Event:2". All versions 24 | // start with "brain.Event:". 25 | string file_version = 3; 26 | // An encoded version of a GraphDef. 27 | bytes graph_def = 4; 28 | // A summary was generated. 29 | Summary summary = 5; 30 | // The user output a log message. Not all messages are logged, only ones 31 | // generated via the Python tensorboard_logging module. 32 | LogMessage log_message = 6; 33 | // The state of the session which can be used for restarting after crashes. 34 | SessionLog session_log = 7; 35 | // The metadata returned by running a session.run() call. 
36 | TaggedRunMetadata tagged_run_metadata = 8; 37 | // An encoded version of a MetaGraphDef. 38 | bytes meta_graph_def = 9; 39 | } 40 | } 41 | 42 | // Protocol buffer used for logging messages to the events file. 43 | message LogMessage { 44 | enum Level { 45 | UNKNOWN = 0; 46 | DEBUG = 10; 47 | INFO = 20; 48 | WARN = 30; 49 | ERROR = 40; 50 | FATAL = 50; 51 | } 52 | Level level = 1; 53 | string message = 2; 54 | } 55 | 56 | // Protocol buffer used for logging session state. 57 | message SessionLog { 58 | enum SessionStatus { 59 | STATUS_UNSPECIFIED = 0; 60 | START = 1; 61 | STOP = 2; 62 | CHECKPOINT = 3; 63 | } 64 | 65 | SessionStatus status = 1; 66 | // This checkpoint_path contains both the path and filename. 67 | string checkpoint_path = 2; 68 | string msg = 3; 69 | } 70 | 71 | // For logging the metadata output for a single session.run() call. 72 | message TaggedRunMetadata { 73 | // Tag name associated with this metadata. 74 | string tag = 1; 75 | // Byte-encoded version of the `RunMetadata` proto in order to allow lazy 76 | // deserialization. 77 | bytes run_metadata = 2; 78 | } 79 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/proto/graph.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboardX; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "GraphProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboardX/proto/node_def.proto"; 10 | //import "tensorflow/core/framework/function.proto"; 11 | import "tensorboardX/proto/versions.proto"; 12 | 13 | // Represents the graph of operations 14 | message GraphDef { 15 | repeated NodeDef node = 1; 16 | 17 | // Compatibility versions of the graph. See core/public/version.h for version 18 | // history. The GraphDef version is distinct from the TensorFlow version, and 19 | // each release of TensorFlow will support a range of GraphDef versions. 20 | VersionDef versions = 4; 21 | 22 | // Deprecated single version field; use versions above instead. Since all 23 | // GraphDef changes before "versions" was introduced were forward 24 | // compatible, this field is entirely ignored. 25 | int32 version = 3 [deprecated = true]; 26 | 27 | // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. 28 | // 29 | // "library" provides user-defined functions. 30 | // 31 | // Naming: 32 | // * library.function.name are in a flat namespace. 33 | // NOTE: We may need to change it to be hierarchical to support 34 | // different orgs. E.g., 35 | // { "/google/nn", { ... }}, 36 | // { "/google/vision", { ... }} 37 | // { "/org_foo/module_bar", { ... }} 38 | // map named_lib; 39 | // * If node[i].op is the name of one function in "library", 40 | // node[i] is deemed as a function call. Otherwise, node[i].op 41 | // must be a primitive operation supported by the runtime. 42 | // 43 | // 44 | // Function call semantics: 45 | // 46 | // * The callee may start execution as soon as some of its inputs 47 | // are ready. The caller may want to use Tuple() mechanism to 48 | // ensure all inputs are ready in the same time. 49 | // 50 | // * The consumer of return values may start executing as soon as 51 | // the return values the consumer depends on are ready. The 52 | // consumer may want to use Tuple() mechanism to ensure the 53 | // consumer does not start until all return values of the callee 54 | // function are ready. 
55 | //FunctionDefLibrary library = 2; 56 | }; 57 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/proto/node_def.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboardX; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "NodeProto"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboardX/proto/attr_value.proto"; 10 | 11 | message NodeDef { 12 | // The name given to this operator. Used for naming inputs, 13 | // logging, visualization, etc. Unique within a single GraphDef. 14 | // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". 15 | string name = 1; 16 | 17 | // The operation name. There may be custom parameters in attrs. 18 | // Op names starting with an underscore are reserved for internal use. 19 | string op = 2; 20 | 21 | // Each input is "node:src_output" with "node" being a string name and 22 | // "src_output" indicating which output tensor to use from "node". If 23 | // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs 24 | // may optionally be followed by control inputs that have the format 25 | // "^node". 26 | repeated string input = 3; 27 | 28 | // A (possibly partial) specification for the device on which this 29 | // node should be placed. 30 | // The expected syntax for this string is as follows: 31 | // 32 | // DEVICE_SPEC ::= PARTIAL_SPEC 33 | // 34 | // PARTIAL_SPEC ::= ("/" CONSTRAINT) * 35 | // CONSTRAINT ::= ("job:" JOB_NAME) 36 | // | ("replica:" [1-9][0-9]*) 37 | // | ("task:" [1-9][0-9]*) 38 | // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") ) 39 | // 40 | // Valid values for this string include: 41 | // * "/job:worker/replica:0/task:1/gpu:3" (full specification) 42 | // * "/job:worker/gpu:3" (partial specification) 43 | // * "" (no specification) 44 | // 45 | // If the constraints do not resolve to a single device (or if this 46 | // field is empty or not present), the runtime will attempt to 47 | // choose a device automatically. 48 | string device = 4; 49 | 50 | // Operation-specific graph-construction-time configuration. 51 | // Note that this should include all attrs defined in the 52 | // corresponding OpDef, including those with a value matching 53 | // the default -- this allows the default to change and makes 54 | // NodeDefs easier to interpret on their own. However, if 55 | // an attr with a default is not specified in this list, the 56 | // default will be used. 57 | // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and 58 | // one of the names from the corresponding OpDef's attr field). 59 | // The values must have a type matching the corresponding OpDef 60 | // attr's type field. 61 | // TODO(josh11b): Add some examples here showing best practices. 62 | map attr = 5; 63 | }; 64 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/proto/plugin_mesh.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboardX.mesh; 4 | 5 | // A MeshPluginData encapsulates information on which plugins are able to make 6 | // use of a certain summary value. 7 | message MeshPluginData { 8 | enum ContentType { 9 | UNDEFINED = 0; 10 | VERTEX = 1; 11 | FACE = 2; // Triangle face. 12 | COLOR = 3; 13 | } 14 | 15 | // Version `0` is the only supported version. 
16 | int32 version = 1; 17 | 18 | // The name of the mesh summary this particular summary belongs to. 19 | string name = 2; 20 | 21 | // Type of data in the summary. 22 | ContentType content_type = 3; 23 | 24 | // JSON-serialized dictionary of ThreeJS classes configuration. 25 | string json_config = 5; 26 | 27 | // Shape of underlying data. Cache it here for performance reasons. 28 | repeated int32 shape = 6; 29 | } 30 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/proto/plugin_pr_curve.proto: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | syntax = "proto3"; 17 | 18 | package tensorboardX; 19 | 20 | message PrCurvePluginData { 21 | // Version `0` is the only supported version. 22 | int32 version = 1; 23 | 24 | uint32 num_thresholds = 2; 25 | } 26 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/proto/plugin_pr_curve_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: tensorboardX/proto/plugin_pr_curve.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor.FileDescriptor( 18 | name='tensorboardX/proto/plugin_pr_curve.proto', 19 | package='tensorboardX', 20 | syntax='proto3', 21 | serialized_options=None, 22 | serialized_pb=_b('\n(tensorboardX/proto/plugin_pr_curve.proto\x12\x0ctensorboardX\"<\n\x11PrCurvePluginData\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x16\n\x0enum_thresholds\x18\x02 \x01(\rb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _PRCURVEPLUGINDATA = _descriptor.Descriptor( 29 | name='PrCurvePluginData', 30 | full_name='tensorboardX.PrCurvePluginData', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='version', full_name='tensorboardX.PrCurvePluginData.version', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | serialized_options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='num_thresholds', full_name='tensorboardX.PrCurvePluginData.num_thresholds', index=1, 44 | number=2, type=13, cpp_type=3, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | serialized_options=None, file=DESCRIPTOR), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | serialized_options=None, 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=58, 62 | serialized_end=118, 63 | ) 64 | 65 | DESCRIPTOR.message_types_by_name['PrCurvePluginData'] = _PRCURVEPLUGINDATA 66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 67 | 68 | PrCurvePluginData = _reflection.GeneratedProtocolMessageType('PrCurvePluginData', (_message.Message,), dict( 69 | DESCRIPTOR = _PRCURVEPLUGINDATA, 70 | __module__ = 'tensorboardX.proto.plugin_pr_curve_pb2' 71 | # @@protoc_insertion_point(class_scope:tensorboardX.PrCurvePluginData) 72 | )) 73 | _sym_db.RegisterMessage(PrCurvePluginData) 74 | 75 | 76 | # @@protoc_insertion_point(module_scope) 77 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/proto/plugin_text.proto: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | syntax = "proto3"; 17 | 18 | package tensorboardX; 19 | 20 | // Text summaries created by the `tensorboard.plugins.text.summary` 21 | // module will include `SummaryMetadata` whose `plugin_data` field has 22 | // as `content` a binary string that is the encoding of an 23 | // `TextPluginData` proto. 24 | message TextPluginData { 25 | // Version `0` is the only supported version. 26 | int32 version = 1; 27 | } 28 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/proto/plugin_text_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: tensorboardX/proto/plugin_text.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor.FileDescriptor( 18 | name='tensorboardX/proto/plugin_text.proto', 19 | package='tensorboardX', 20 | syntax='proto3', 21 | serialized_options=None, 22 | serialized_pb=_b('\n$tensorboardX/proto/plugin_text.proto\x12\x0ctensorboardX\"!\n\x0eTextPluginData\x12\x0f\n\x07version\x18\x01 \x01(\x05\x62\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _TEXTPLUGINDATA = _descriptor.Descriptor( 29 | name='TextPluginData', 30 | full_name='tensorboardX.TextPluginData', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='version', full_name='tensorboardX.TextPluginData.version', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | serialized_options=None, file=DESCRIPTOR), 42 | ], 43 | extensions=[ 44 | ], 45 | nested_types=[], 46 | enum_types=[ 47 | ], 48 | serialized_options=None, 49 | is_extendable=False, 50 | syntax='proto3', 51 | extension_ranges=[], 52 | oneofs=[ 53 | ], 54 | serialized_start=54, 55 | serialized_end=87, 56 | ) 57 | 58 | DESCRIPTOR.message_types_by_name['TextPluginData'] = _TEXTPLUGINDATA 59 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 60 | 61 | TextPluginData = _reflection.GeneratedProtocolMessageType('TextPluginData', (_message.Message,), dict( 62 | DESCRIPTOR = _TEXTPLUGINDATA, 63 | __module__ = 'tensorboardX.proto.plugin_text_pb2' 64 | # @@protoc_insertion_point(class_scope:tensorboardX.TextPluginData) 65 | )) 66 | _sym_db.RegisterMessage(TextPluginData) 67 | 68 | 69 | # @@protoc_insertion_point(module_scope) 70 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/proto/resource_handle.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboardX; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "ResourceHandle"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // Protocol buffer representing a handle to a tensorflow resource. 
Handles are 10 | // not valid across executions, but can be serialized back and forth from within 11 | // a single run. 12 | message ResourceHandleProto { 13 | // Unique name for the device containing the resource. 14 | string device = 1; 15 | 16 | // Container in which this resource is placed. 17 | string container = 2; 18 | 19 | // Unique name of this resource. 20 | string name = 3; 21 | 22 | // Hash code for the type of the resource. Is only valid in the same device 23 | // and in the same execution. 24 | uint64 hash_code = 4; 25 | 26 | // For debug-only, the name of the type pointed to by this handle, if 27 | // available. 28 | string maybe_type_name = 5; 29 | }; 30 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/proto/tensor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboardX; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TensorProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboardX/proto/resource_handle.proto"; 10 | import "tensorboardX/proto/tensor_shape.proto"; 11 | import "tensorboardX/proto/types.proto"; 12 | 13 | // Protocol buffer representing a tensor. 14 | message TensorProto { 15 | DataType dtype = 1; 16 | 17 | // Shape of the tensor. TODO(touts): sort out the 0-rank issues. 18 | TensorShapeProto tensor_shape = 2; 19 | 20 | // Only one of the representations below is set, one of "tensor_contents" and 21 | // the "xxx_val" attributes. We are not using oneof because as oneofs cannot 22 | // contain repeated fields it would require another extra set of messages. 23 | 24 | // Version number. 25 | // 26 | // In version 0, if the "repeated xxx" representations contain only one 27 | // element, that element is repeated to fill the shape. This makes it easy 28 | // to represent a constant Tensor with a single value. 29 | int32 version_number = 3; 30 | 31 | // Serialized raw tensor content from either Tensor::AsProtoTensorContent or 32 | // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation 33 | // can be used for all tensor types. The purpose of this representation is to 34 | // reduce serialization overhead during RPC call by avoiding serialization of 35 | // many repeated small items. 36 | bytes tensor_content = 4; 37 | 38 | // Type specific representations that make it easy to create tensor protos in 39 | // all languages. Only the representation corresponding to "dtype" can 40 | // be set. The values hold the flattened representation of the tensor in 41 | // row major order. 42 | 43 | // DT_HALF. Note that since protobuf has no int16 type, we'll have some 44 | // pointless zero padding for each value here. 45 | repeated int32 half_val = 13 [packed = true]; 46 | 47 | // DT_FLOAT. 48 | repeated float float_val = 5 [packed = true]; 49 | 50 | // DT_DOUBLE. 51 | repeated double double_val = 6 [packed = true]; 52 | 53 | // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. 54 | repeated int32 int_val = 7 [packed = true]; 55 | 56 | // DT_STRING 57 | repeated bytes string_val = 8; 58 | 59 | // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real 60 | // and imaginary parts of i-th single precision complex. 
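// For example, the single-precision complex vector [1+2i, 3+4i] would be
// flattened to scomplex_val = [1, 2, 3, 4].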
62 |   repeated float scomplex_val = 9 [packed = true];
63 | 
64 |   // DT_INT64
65 |   repeated int64 int64_val = 10 [packed = true];
66 | 
67 |   // DT_BOOL
68 |   repeated bool bool_val = 11 [packed = true];
69 | 
70 |   // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are the real
71 |   // and imaginary parts of the i-th double-precision complex value.
72 |   repeated double dcomplex_val = 12 [packed = true];
73 | 
74 |   // DT_RESOURCE
75 |   repeated ResourceHandleProto resource_handle_val = 14;
76 | };
77 | 
--------------------------------------------------------------------------------
/tensorboardX/tensorboardX/proto/tensor_shape.proto:
--------------------------------------------------------------------------------
1 | // Protocol buffer representing the shape of tensors.
2 | 
3 | syntax = "proto3";
4 | option cc_enable_arenas = true;
5 | option java_outer_classname = "TensorShapeProtos";
6 | option java_multiple_files = true;
7 | option java_package = "org.tensorflow.framework";
8 | 
9 | package tensorboardX;
10 | 
11 | // Dimensions of a tensor.
12 | message TensorShapeProto {
13 |   // One dimension of the tensor.
14 |   message Dim {
15 |     // Size of the tensor in that dimension.
16 |     // This value must be >= -1, with -1 reserved to mean an "unknown"
17 |     // dimension. Certain wrappers that work with TensorShapeProto may fail
18 |     // at runtime when deserializing a TensorShapeProto containing a dim
19 |     // value of -1.
20 |     int64 size = 1;
21 | 
22 |     // Optional name of the tensor dimension.
23 |     string name = 2;
24 |   };
25 | 
26 |   // Dimensions of the tensor, such as {"input", 30}, {"output", 40}
27 |   // for a 30 x 40 2D tensor. If an entry has size -1, this
28 |   // corresponds to a dimension of unknown size. The names are
29 |   // optional.
30 |   //
31 |   // The order of entries in "dim" matters: it indicates the layout of the
32 |   // values in the tensor's in-memory representation.
33 |   //
34 |   // The first entry in "dim" is the outermost dimension used to lay out the
35 |   // values; the last entry is the innermost dimension. This matches the
36 |   // in-memory layout of RowMajor Eigen tensors.
37 |   //
38 |   // If "dim.size()" > 0, "unknown_rank" must be false.
39 |   repeated Dim dim = 2;
40 | 
41 |   // If true, the number of dimensions in the shape is unknown.
42 |   //
43 |   // If true, "dim.size()" must be 0.
44 |   bool unknown_rank = 3;
45 | };
46 | 
--------------------------------------------------------------------------------
/tensorboardX/tensorboardX/proto/types.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 | 
3 | package tensorboardX;
4 | option cc_enable_arenas = true;
5 | option java_outer_classname = "TypesProtos";
6 | option java_multiple_files = true;
7 | option java_package = "org.tensorflow.framework";
8 | 
9 | // LINT.IfChange
10 | enum DataType {
11 |   // Not a legal value for DataType. Used to indicate a DataType field
12 |   // has not been set.
13 |   DT_INVALID = 0;
14 | 
15 |   // Data types that all computation devices are expected to be
16 |   // capable of supporting.
17 |   DT_FLOAT = 1;
18 |   DT_DOUBLE = 2;
19 |   DT_INT32 = 3;
20 |   DT_UINT8 = 4;
21 |   DT_INT16 = 5;
22 |   DT_INT8 = 6;
23 |   DT_STRING = 7;
24 |   DT_COMPLEX64 = 8;   // Single-precision complex
25 |   DT_INT64 = 9;
26 |   DT_BOOL = 10;
27 |   DT_QINT8 = 11;      // Quantized int8
28 |   DT_QUINT8 = 12;     // Quantized uint8
29 |   DT_QINT32 = 13;     // Quantized int32
30 |   DT_BFLOAT16 = 14;   // Float32 truncated to 16 bits. Only for cast ops.
31 |   DT_QINT16 = 15;     // Quantized int16
32 |   DT_QUINT16 = 16;    // Quantized uint16
33 |   DT_UINT16 = 17;
34 |   DT_COMPLEX128 = 18; // Double-precision complex
35 |   DT_HALF = 19;
36 |   DT_RESOURCE = 20;
37 | 
38 |   // TODO(josh11b): DT_GENERIC_PROTO = ??;
39 |   // TODO(jeff,josh11b): DT_UINT64? DT_UINT32?
40 | 
41 |   // Do not use! These are only for parameters. Every enum above
42 |   // should have a corresponding value below (verified by types_test).
43 |   DT_FLOAT_REF = 101;
44 |   DT_DOUBLE_REF = 102;
45 |   DT_INT32_REF = 103;
46 |   DT_UINT8_REF = 104;
47 |   DT_INT16_REF = 105;
48 |   DT_INT8_REF = 106;
49 |   DT_STRING_REF = 107;
50 |   DT_COMPLEX64_REF = 108;
51 |   DT_INT64_REF = 109;
52 |   DT_BOOL_REF = 110;
53 |   DT_QINT8_REF = 111;
54 |   DT_QUINT8_REF = 112;
55 |   DT_QINT32_REF = 113;
56 |   DT_BFLOAT16_REF = 114;
57 |   DT_QINT16_REF = 115;
58 |   DT_QUINT16_REF = 116;
59 |   DT_UINT16_REF = 117;
60 |   DT_COMPLEX128_REF = 118;
61 |   DT_HALF_REF = 119;
62 |   DT_RESOURCE_REF = 120;
63 | }
64 | // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go)
65 | 
--------------------------------------------------------------------------------
/tensorboardX/tensorboardX/proto/versions.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 | 
3 | package tensorboardX;
4 | option cc_enable_arenas = true;
5 | option java_outer_classname = "VersionsProtos";
6 | option java_multiple_files = true;
7 | option java_package = "org.tensorflow.framework";
8 | 
9 | // Version information for a piece of serialized data.
10 | //
11 | // There are different types of versions for each type of data
12 | // (GraphDef, etc.), but they all have the same common shape
13 | // described here.
14 | //
15 | // Each consumer has "consumer" and "min_producer" versions (specified
16 | // elsewhere). A consumer is allowed to consume this data if
17 | //
18 | //   producer >= min_producer
19 | //   consumer >= min_consumer
20 | //   consumer not in bad_consumers
21 | //
22 | message VersionDef {
23 |   // The version of the code that produced this data.
24 |   int32 producer = 1;
25 | 
26 |   // Any consumer below this version is not allowed to consume this data.
27 |   int32 min_consumer = 2;
28 | 
29 |   // Specific consumer versions which are disallowed (e.g. due to bugs).
30 |   repeated int32 bad_consumers = 3;
31 | };
32 | 
--------------------------------------------------------------------------------
/tensorboardX/tensorboardX/proto_graph.py:
--------------------------------------------------------------------------------
1 | from .proto.graph_pb2 import GraphDef
2 | from .proto.node_def_pb2 import NodeDef
3 | from .proto.versions_pb2 import VersionDef
4 | from .proto.attr_value_pb2 import AttrValue
5 | from .proto.tensor_shape_pb2 import TensorShapeProto
6 | 
7 | 
8 | def attr_value_proto(dtype, shape, s):
9 |     """Creates a dict of AttrValue objects matching
10 |     https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto,
11 |     specifically designed for a NodeDef. The values have been
12 |     reverse-engineered from standard TensorBoard logged data.
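
    Example (illustrative call; the argument values shown are hypothetical,
    not from the original docs)::

        >>> attrs = attr_value_proto(dtype=None, shape=(1, 64), s='MaxPool')
        >>> sorted(attrs.keys())
        ['_output_shapes', 'attr']
        >>> attrs['attr'].s
        b'MaxPool'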
13 | """ 14 | attr = {} 15 | if s is not None: 16 | attr['attr'] = AttrValue(s=s.encode(encoding='utf_8')) 17 | if shape is not None: 18 | shapeproto = tensor_shape_proto(shape) 19 | attr['_output_shapes'] = AttrValue(list=AttrValue.ListValue(shape=[shapeproto])) 20 | return attr 21 | 22 | 23 | def tensor_shape_proto(outputsize): 24 | """Creates an object matching 25 | https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto 26 | """ 27 | return TensorShapeProto(dim=[TensorShapeProto.Dim(size=d) for d in outputsize]) 28 | 29 | 30 | def node_proto(name, 31 | op='UnSpecified', 32 | input=None, 33 | dtype=None, 34 | shape=None, # type: tuple 35 | outputsize=None, 36 | attributes='' 37 | ): 38 | """Creates an object matching 39 | https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto 40 | """ 41 | if input is None: 42 | input = [] 43 | if not isinstance(input, list): 44 | input = [input] 45 | return NodeDef( 46 | name=name.encode(encoding='utf_8'), 47 | op=op, 48 | input=input, 49 | attr=attr_value_proto(dtype, outputsize, attributes) 50 | ) 51 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/torchvis.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import gc 7 | import six 8 | import time 9 | 10 | from functools import wraps 11 | from .writer import SummaryWriter 12 | from .visdom_writer import VisdomWriter 13 | 14 | 15 | # Supports both TensorBoard and Visdom (no embedding or graph visualization with Visdom) 16 | vis_formats = {'tensorboard': SummaryWriter, 'visdom': VisdomWriter} 17 | 18 | 19 | class TorchVis: 20 | def __init__(self, *args, **init_kwargs): 21 | """ 22 | Args: 23 | args (list of strings): The name of the visualization target(s). 24 | Accepted targets are 'tensorboard' and 'visdom'. 25 | init_kwargs: Additional keyword parameters for the visdom writer (For example, server IP). 26 | See https://github.com/facebookresearch/visdom/blob/master/README.md#visdom-arguments-python-only 27 | for more. 
28 | """ 29 | self.subscribers = {} 30 | self.register(*args, **init_kwargs) 31 | 32 | def register(self, *args, **init_kwargs): 33 | # Sets tensorboard as the default visualization format if not specified 34 | formats = ['tensorboard'] if not args else args 35 | for format in formats: 36 | if self.subscribers.get(format) is None and format in vis_formats.keys(): 37 | self.subscribers[format] = vis_formats[format](**init_kwargs.get(format, {})) 38 | 39 | def unregister(self, *args): 40 | for format in args: 41 | self.subscribers[format].close() 42 | del self.subscribers[format] 43 | gc.collect() 44 | 45 | def __getattr__(self, attr): 46 | for _, subscriber in six.iteritems(self.subscribers): 47 | def wrapper(*args, **kwargs): 48 | for _, subscriber in six.iteritems(self.subscribers): 49 | if hasattr(subscriber, attr): 50 | getattr(subscriber, attr)(*args, **kwargs) 51 | return wrapper 52 | raise AttributeError 53 | 54 | # Handle writer management (open/close) for the user 55 | def __del__(self): 56 | for _, subscriber in six.iteritems(self.subscribers): 57 | subscriber.close() 58 | -------------------------------------------------------------------------------- /tensorboardX/tensorboardX/x2num.py: -------------------------------------------------------------------------------- 1 | # DO NOT alter/distruct/free input object ! 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import logging 7 | import numpy as np 8 | import six 9 | 10 | 11 | def check_nan(array): 12 | tmp = np.sum(array) 13 | if np.isnan(tmp) or np.isinf(tmp): 14 | logging.warning('NaN or Inf found in input tensor.') 15 | return array 16 | 17 | 18 | def make_np(x): 19 | if isinstance(x, list): 20 | return check_nan(np.array(x)) 21 | if isinstance(x, np.ndarray): 22 | return check_nan(x) 23 | if isinstance(x, six.string_types): # Caffe2 will pass name of blob(s) to fetch 24 | return check_nan(prepare_caffe2(x)) 25 | if np.isscalar(x): 26 | return check_nan(np.array([x])) 27 | if 'torch' in str(type(x)): 28 | return check_nan(prepare_pytorch(x)) 29 | if 'chainer' in str(type(x)): 30 | return check_nan(prepare_chainer(x)) 31 | if 'mxnet' in str(type(x)): 32 | return check_nan(prepare_mxnet(x)) 33 | raise NotImplementedError( 34 | 'Got {}, but expected numpy array or torch tensor.'.format(type(x))) 35 | 36 | 37 | def prepare_pytorch(x): 38 | import torch 39 | if isinstance(x, torch.autograd.Variable): 40 | x = x.data 41 | x = x.cpu().numpy() 42 | return x 43 | 44 | 45 | def prepare_theano(x): 46 | import theano 47 | pass 48 | 49 | 50 | def prepare_caffe2(x): 51 | from caffe2.python import workspace 52 | x = workspace.FetchBlob(x) 53 | return x 54 | 55 | 56 | def prepare_mxnet(x): 57 | x = x.asnumpy() 58 | return x 59 | 60 | 61 | def prepare_chainer(x): 62 | import chainer 63 | x = chainer.cuda.to_cpu(x.data) 64 | return x 65 | -------------------------------------------------------------------------------- /tensorboardX/tests/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tensorboardX.proto 3 | -------------------------------------------------------------------------------- /tensorboardX/tests/expect/test_pr_curve.test_pr_purve.expect: -------------------------------------------------------------------------------- 1 | value { 2 | tag: "tag" 3 | tensor { 4 | dtype: DT_FLOAT 5 | tensor_shape { 6 | dim { 7 | size: 6 8 | } 9 | dim { 10 | size: 1 11 | } 12 | } 13 | float_val: 57.0 
14 | float_val: 43.0 15 | float_val: 0.0 16 | float_val: 0.0 17 | float_val: 0.57 18 | float_val: 1.0 19 | } 20 | metadata { 21 | plugin_data { 22 | plugin_name: "pr_curves" 23 | content: "\020\001" 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /tensorboardX/tests/expect/test_pr_curve.test_pr_purve_raw.expect: -------------------------------------------------------------------------------- 1 | value { 2 | tag: "prcurve with raw data" 3 | tensor { 4 | dtype: DT_FLOAT 5 | tensor_shape { 6 | dim { 7 | size: 6 8 | } 9 | dim { 10 | size: 5 11 | } 12 | } 13 | float_val: 75.0 14 | float_val: 64.0 15 | float_val: 21.0 16 | float_val: 5.0 17 | float_val: 0.0 18 | float_val: 150.0 19 | float_val: 105.0 20 | float_val: 18.0 21 | float_val: 0.0 22 | float_val: 0.0 23 | float_val: 0.0 24 | float_val: 45.0 25 | float_val: 132.0 26 | float_val: 150.0 27 | float_val: 150.0 28 | float_val: 0.0 29 | float_val: 11.0 30 | float_val: 54.0 31 | float_val: 70.0 32 | float_val: 75.0 33 | float_val: 0.3333333 34 | float_val: 0.3786982 35 | float_val: 0.5384616 36 | float_val: 1.0 37 | float_val: 0.0 38 | float_val: 1.0 39 | float_val: 0.8533334 40 | float_val: 0.28 41 | float_val: 0.0666667 42 | float_val: 0.0 43 | } 44 | metadata { 45 | plugin_data { 46 | plugin_name: "pr_curves" 47 | content: "\020\001" 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /tensorboardX/tests/expect/test_summary.test_audio.expect: -------------------------------------------------------------------------------- 1 | value { 2 | tag: "dummy" 3 | audio { 4 | sample_rate: 44100.0 5 | num_channels: 1 6 | length_frames: 42 7 | encoded_audio_string: "RIFFx\000\000\000WAVEfmt \020\000\000\000\001\000\001\000D\254\000\000\210X\001\000\002\000\020\000dataT\000\000\000\000\000\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177\377\177" 8 | content_type: "audio/wav" 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /tensorboardX/tests/expect/test_summary.test_custom_scalars.expect: -------------------------------------------------------------------------------- 1 | value { 2 | tag: "custom_scalars__config__" 3 | tensor { 4 | dtype: DT_STRING 5 | tensor_shape { 6 | } 7 | string_val: "\022(\n\006Taiwan\022\036\n\004twse\022\026\n\ttwse/0050\n\ttwse/2330\022]\n\003USA\022$\n\003dow\032\035\n\033\n\007dow/aaa\022\007dow/bbb\032\007dow/ccc\0220\n\006nasdaq\032&\n$\n\nnasdaq/aaa\022\nnasdaq/bbb\032\nnasdaq/ccc" 8 | } 9 | metadata { 10 | plugin_data { 11 | plugin_name: "custom_scalars" 12 | } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /tensorboardX/tests/expect/test_summary.test_float32_image.expect: -------------------------------------------------------------------------------- 1 | value { 2 | tag: "dummy" 3 | image { 4 | height: 32 5 | width: 32 6 | colorspace: 3 7 | encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000 \000\000\000 \010\002\000\000\000\374\030\355\243\000\000\000DIDATx\234cd``\370OK\300\370\340\301\003\232Z\3002j\301\360\267\200QAA\201\266\026\214\346\203Q\013\006\277\005\243\371\200 
\030\372\2210\204b\311\233\305/\344G\000\334\236\021Uu\005R\000\377\007\244\224\342\013||\007\2655\330BfP\215\337S`>:{_l\020\335\242\tX6-\000\032r\007G\316\000\2561\226\201\244\252/\005V\357\026\271\003\033\0149\000\232\270\003+\260\301\220\003\240y\000T\221\324V\250_v\320\000\000\000\000IEND\256B`\202" 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /tensorboardX/tests/expect/test_summary.test_image_with_four_channel.expect: -------------------------------------------------------------------------------- 1 | value { 2 | tag: "dummy" 3 | image { 4 | height: 8 5 | width: 8 6 | colorspace: 4 7 | encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\010\000\000\000\010\010\006\000\000\000\304\017\276\213\000\000\000\036IDATx\234cd8\320\340\360\037\017`\371\361\343\307\217\037\204\024\0204a\260+\000\000\240\302\373\327\246\231O\'\000\000\000\000IEND\256B`\202" 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /tensorboardX/tests/expect/test_summary.test_image_with_four_channel_batched.expect: -------------------------------------------------------------------------------- 1 | value { 2 | tag: "dummy" 3 | image { 4 | height: 8 5 | width: 16 6 | colorspace: 4 7 | encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\020\000\000\000\010\010\006\000\000\000\360v\177\227\000\000\000-IDATx\234cd8\320\340\360\037\017`ggg\307\'\317\362\343\307\217\037?\360(\370\001\305x\r\300g\003!0j\000\025\014\000\000\356b\366\370\366\336\316\301\000\000\000\000IEND\256B`\202" 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /tensorboardX/tests/expect/test_summary.test_image_with_one_channel.expect: -------------------------------------------------------------------------------- 1 | value { 2 | tag: "dummy" 3 | image { 4 | height: 8 5 | width: 8 6 | colorspace: 3 7 | encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\010\000\000\000\010\010\002\000\000\000Km)\334\000\000\000\031IDATx\234cd``\370\217\r0\376\370\361\003\253\004\313\240\224\000\000;\267\273\313%\020=\255\000\000\000\000IEND\256B`\202" 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /tensorboardX/tests/expect/test_summary.test_image_with_one_channel_batched.expect: -------------------------------------------------------------------------------- 1 | value { 2 | tag: "dummy" 3 | image { 4 | height: 8 5 | width: 16 6 | colorspace: 3 7 | encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\020\000\000\000\010\010\002\000\000\000\177\024\350\300\000\000\000(IDATx\234cd``\370\217\r\034?~\034\2538\313\217\037?~\374\370\201)\201U\020\252\001\253\304\250\006$\000\000\230\346y\315\204l;t\000\000\000\000IEND\256B`\202" 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /tensorboardX/tests/expect/test_summary.test_image_without_channel.expect: -------------------------------------------------------------------------------- 1 | value { 2 | tag: "dummy" 3 | image { 4 | height: 8 5 | width: 8 6 | colorspace: 3 7 | encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\010\000\000\000\010\010\002\000\000\000Km)\334\000\000\000\031IDATx\234cd``\370\217\r0\376\370\361\003\253\004\313\240\224\000\000;\267\273\313%\020=\255\000\000\000\000IEND\256B`\202" 8 | } 9 | } 10 | 
-------------------------------------------------------------------------------- /tensorboardX/tests/expect/test_summary.test_mesh.expect: -------------------------------------------------------------------------------- 1 | value { 2 | tag: "my_mesh_1" 3 | tensor { 4 | dtype: DT_FLOAT 5 | tensor_shape { 6 | dim { 7 | size: 1 8 | } 9 | dim { 10 | size: 4 11 | } 12 | dim { 13 | size: 3 14 | } 15 | } 16 | float_val: 1.0 17 | float_val: 1.0 18 | float_val: 1.0 19 | float_val: -1.0 20 | float_val: -1.0 21 | float_val: 1.0 22 | float_val: 1.0 23 | float_val: -1.0 24 | float_val: -1.0 25 | float_val: -1.0 26 | float_val: 1.0 27 | float_val: -1.0 28 | } 29 | metadata { 30 | plugin_data { 31 | plugin_name: "mesh" 32 | content: "\022\007my_mesh\030\001*\004null2\003\001\004\003" 33 | } 34 | } 35 | } 36 | value { 37 | tag: "my_mesh_2" 38 | tensor { 39 | dtype: DT_FLOAT 40 | tensor_shape { 41 | dim { 42 | size: 1 43 | } 44 | dim { 45 | size: 4 46 | } 47 | dim { 48 | size: 3 49 | } 50 | } 51 | float_val: 0.0 52 | float_val: 2.0 53 | float_val: 3.0 54 | float_val: 0.0 55 | float_val: 3.0 56 | float_val: 1.0 57 | float_val: 0.0 58 | float_val: 1.0 59 | float_val: 2.0 60 | float_val: 1.0 61 | float_val: 3.0 62 | float_val: 2.0 63 | } 64 | metadata { 65 | plugin_data { 66 | plugin_name: "mesh" 67 | content: "\022\007my_mesh\030\002*\004null2\003\001\004\003" 68 | } 69 | } 70 | } 71 | value { 72 | tag: "my_mesh_3" 73 | tensor { 74 | dtype: DT_FLOAT 75 | tensor_shape { 76 | dim { 77 | size: 1 78 | } 79 | dim { 80 | size: 4 81 | } 82 | dim { 83 | size: 3 84 | } 85 | } 86 | float_val: 255.0 87 | float_val: 0.0 88 | float_val: 0.0 89 | float_val: 0.0 90 | float_val: 255.0 91 | float_val: 0.0 92 | float_val: 0.0 93 | float_val: 0.0 94 | float_val: 255.0 95 | float_val: 255.0 96 | float_val: 0.0 97 | float_val: 255.0 98 | } 99 | metadata { 100 | plugin_data { 101 | plugin_name: "mesh" 102 | content: "\022\007my_mesh\030\003*\004null2\003\001\004\003" 103 | } 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /tensorboardX/tests/expect/test_summary.test_text.expect: -------------------------------------------------------------------------------- 1 | value { 2 | tag: "dummy/text_summary" 3 | tensor { 4 | dtype: DT_STRING 5 | tensor_shape { 6 | dim { 7 | size: 1 8 | } 9 | } 10 | string_val: "text 123" 11 | } 12 | metadata { 13 | plugin_data { 14 | plugin_name: "text" 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /tensorboardX/tests/expect/test_summary.test_uint8_image.expect: -------------------------------------------------------------------------------- 1 | value { 2 | tag: "dummy" 3 | image { 4 | height: 32 5 | width: 32 6 | colorspace: 3 7 | encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000 \000\000\000 \010\002\000\000\000\374\030\355\243\000\000\000CIDATx\234cd```\244)PPP\240\251\371,\243\026\014\177\013\030\037