├── examples ├── fastai │ ├── requirements.txt │ └── train.py ├── keras │ ├── requirements.txt │ └── train.py ├── tensorflow │ ├── requirements.txt │ └── train.py ├── pytorch │ ├── requirements.txt │ └── train.py └── demo │ ├── demo.gif │ └── demo.py ├── reloading ├── __init__.py ├── test_reloading.py └── reloading.py ├── setup.cfg ├── .gitignore ├── MANIFEST ├── Makefile ├── LICENSE.txt ├── setup.py └── README.md /examples/fastai/requirements.txt: -------------------------------------------------------------------------------- 1 | fastai -------------------------------------------------------------------------------- /examples/keras/requirements.txt: -------------------------------------------------------------------------------- 1 | keras 2 | numpy -------------------------------------------------------------------------------- /examples/tensorflow/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow 2 | tqdm -------------------------------------------------------------------------------- /reloading/__init__.py: -------------------------------------------------------------------------------- 1 | from .reloading import reloading 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /examples/pytorch/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | tqdm 4 | -------------------------------------------------------------------------------- /examples/demo/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/julvo/reloading/HEAD/examples/demo/demo.gif 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.swp 3 | *.pyc 4 | dist 5 | build 6 | .DS_Store 7 | examples/FashionMNIST 8 | .vscode/ 9 | reloading.egg-info 10 | -------------------------------------------------------------------------------- /MANIFEST: -------------------------------------------------------------------------------- 1 | # file GENERATED by distutils, do NOT edit 2 | setup.cfg 3 | setup.py 4 | reloading/__init__.py 5 | reloading/reloading.py 6 | reloading/test_reloading.py 7 | -------------------------------------------------------------------------------- /examples/demo/demo.py: -------------------------------------------------------------------------------- 1 | import time 2 | import sys 3 | sys.path.insert(0, '../..') 4 | from reloading import reloading 5 | 6 | epochs = 10000 7 | loss = 100 8 | model = { 'weights': [0.2, 0.1, 0.4, 0.8, 0.1] } 9 | 10 | for i in reloading(range(epochs)): 11 | time.sleep(2) 12 | loss /= 2 13 | 14 | print('Epoch:', i, 'Loss:', loss) 15 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PWD := $(shell pwd) 2 | 3 | .PHONY: test 4 | test: test3.6 test3.7 test3.8 5 | 6 | test3.6: 7 | docker run -w /app -v $(PWD):/app python:3.6.10-alpine3.11 python -m unittest 8 | test3.7: 9 | docker run -w /app -v $(PWD):/app python:3.7.7-alpine3.11 python -m unittest 10 | test3.8: 11 | docker run -w /app -v $(PWD):/app python:3.8.3-alpine3.11 python -m unittest 12 | 13 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | Copyright (c) 2019 Julian Vossen 3 | Permission is hereby granted, free of charge, to any person 
obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | The above copyright notice and this permission notice shall be included in all 10 | copies or substantial portions of the Software. 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 17 | SOFTWARE. 
18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from os import path 3 | 4 | this_directory = path.abspath(path.dirname(__file__)) 5 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: 6 | long_description = f.read() 7 | 8 | setup( 9 | name='reloading', 10 | packages=['reloading'], 11 | version='1.1.2', 12 | license='MIT', 13 | description='Reloads source code of a running program without losing state', 14 | long_description=long_description, 15 | long_description_content_type='text/markdown', 16 | author='Julian Vossen', 17 | author_email='pypi@julianvossen.de', 18 | url='https://github.com/julvo/reloading', 19 | download_url='https://github.com/julvo/reloading/archive/v1.1.2.tar.gz', 20 | keywords=['reload', 'reloading', 'refresh', 'loop', 'decorator'], 21 | install_requires=[], 22 | classifiers=[ 23 | 'Development Status :: 3 - Alpha', 24 | 'Intended Audience :: Developers', 25 | 'Topic :: Utilities', 26 | 'License :: OSI Approved :: MIT License', 27 | 'Programming Language :: Python :: 3', 28 | 'Programming Language :: Python :: 3.4', 29 | 'Programming Language :: Python :: 3.5', 30 | 'Programming Language :: Python :: 3.6', 31 | 'Programming Language :: Python :: 3.7', 32 | 'Programming Language :: Python :: 3.8', 33 | 'Programming Language :: Python :: 3.9', 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /examples/pytorch/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '../..') 3 | from reloading import reloading 4 | 5 | from torch import nn 6 | from torch.optim import Adam 7 | import torch.nn.functional as F 8 | from torchvision.models import resnet18 9 | from torchvision.datasets import FashionMNIST 10 | from torchvision.transforms import 
ToTensor 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | 14 | 15 | dataset = FashionMNIST('.', download=True, transform=ToTensor()) 16 | dataloader = DataLoader(dataset, batch_size=8) 17 | 18 | model = resnet18(pretrained=True) 19 | model.fc = nn.Linear(model.fc.in_features, 10) 20 | 21 | optimiser = Adam(model.parameters()) 22 | 23 | for epoch in reloading(range(1000)): 24 | # Try to change the code inside this loop during the training and see how the 25 | # changes are applied without restarting the training 26 | 27 | model.train() 28 | losses = [] 29 | 30 | for images, targets in tqdm(dataloader): 31 | losses.append(1) 32 | 33 | optimiser.zero_grad() 34 | predictions = model(images.expand(8, 3, 28, 28)) 35 | loss = F.cross_entropy(predictions, targets) 36 | loss.backward() 37 | optimiser.step() 38 | losses.append(loss.item()) 39 | 40 | # Here would be your validation code 41 | 42 | print(f'Epoch {epoch} - Loss {sum(losses) / len(losses)}') 43 | 44 | 45 | -------------------------------------------------------------------------------- /examples/fastai/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '../..') 3 | from reloading import reloading 4 | 5 | from fastai.basic_train import LearnerCallback 6 | from fastai.vision import (URLs, untar_data, ImageDataBunch, 7 | cnn_learner, models, accuracy) 8 | 9 | 10 | @reloading 11 | def set_learning_rate(learner): 12 | # Change the learning rate below during the training 13 | learner.opt.opt.lr = 1e-3 14 | print('Set LR to', learner.opt.opt.lr) 15 | 16 | class LearningRateSetter(LearnerCallback): 17 | def on_epoch_begin(self, **kwargs): 18 | set_learning_rate(self.learn) 19 | 20 | 21 | @reloading 22 | def print_model_statistics(model): 23 | # Uncomment the following lines after during the training 24 | # to start printing statistics 25 | # 26 | # print('{: <28} {: <7} {: <7}'.format('NAME', ' MEAN', ' STDDEV')) 27 
| # for name, param in model.named_parameters(): 28 | # mean = param.mean().item() 29 | # std = param.std().item() 30 | # print('{: <28} {: 6.4f} {: 6.4f}'.format(name, mean, std)) 31 | pass 32 | 33 | class ModelStatsPrinter(LearnerCallback): 34 | def on_epoch_begin(self, **kwargs): 35 | print_model_statistics(self.learn.model) 36 | 37 | 38 | path = untar_data(URLs.MNIST_SAMPLE) 39 | data = ImageDataBunch.from_folder(path) 40 | learn = cnn_learner(data, models.resnet18, metrics=accuracy, 41 | callback_fns=[ModelStatsPrinter, LearningRateSetter]) 42 | learn.fit(10) 43 | -------------------------------------------------------------------------------- /examples/keras/train.py: -------------------------------------------------------------------------------- 1 | # Example taken from https://keras.io/getting-started/sequential-model-guide/#examples 2 | import sys 3 | sys.path.insert(0, '../..') 4 | from reloading import reloading 5 | 6 | import keras 7 | from keras import backend as K 8 | from keras.models import Sequential 9 | from keras.layers import Dense, Activation 10 | from keras.optimizers import SGD 11 | from keras.callbacks import Callback 12 | 13 | 14 | @reloading 15 | def set_learning_rate(model): 16 | # Change the below value during training and see how it updates 17 | K.set_value(model.optimizer.lr, 1e-3) 18 | print('Set LR to', K.get_value(model.optimizer.lr)) 19 | 20 | class LearningRateSetter(Callback): 21 | def on_epoch_begin(self, epoch, logs=None): 22 | set_learning_rate(self.model) 23 | 24 | 25 | # Generate dummy data 26 | import numpy as np 27 | x_train = np.random.random((10000, 20)) 28 | y_train = keras.utils.to_categorical(np.random.randint(10, size=(10000, 1)), num_classes=10) 29 | x_test = np.random.random((1000, 20)) 30 | y_test = keras.utils.to_categorical(np.random.randint(10, size=(1000, 1)), num_classes=10) 31 | 32 | model = Sequential() 33 | model.add(Dense(64, activation='relu', input_dim=20)) 34 | model.add(Dense(10, 
activation='softmax')) 35 | 36 | sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) 37 | model.compile(loss='categorical_crossentropy', 38 | optimizer=sgd, 39 | metrics=['accuracy']) 40 | 41 | model.fit(x_train, y_train, 42 | epochs=200, 43 | batch_size=128, 44 | callbacks=[LearningRateSetter()]) 45 | score = model.evaluate(x_test, y_test, batch_size=128) -------------------------------------------------------------------------------- /examples/tensorflow/train.py: -------------------------------------------------------------------------------- 1 | # Example from https://www.tensorflow.org/tutorials/quickstart/advanced 2 | 3 | from __future__ import absolute_import, division, print_function, unicode_literals 4 | 5 | import sys 6 | sys.path.insert(0, '../..') 7 | from reloading import reloading 8 | 9 | import tensorflow as tf 10 | from tensorflow.keras.layers import Dense, Flatten, Conv2D 11 | from tensorflow.keras import Model 12 | 13 | from tqdm import tqdm 14 | 15 | (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() 16 | x_train, x_test = x_train / 255.0, x_test / 255.0 17 | 18 | # Add a channels dimension 19 | x_train = x_train[..., tf.newaxis] 20 | x_test = x_test[..., tf.newaxis] 21 | 22 | train_ds = tf.data.Dataset.from_tensor_slices( 23 | (x_train, y_train)).shuffle(10000).batch(32) 24 | 25 | test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32) 26 | 27 | class MyModel(Model): 28 | def __init__(self): 29 | super(MyModel, self).__init__() 30 | self.conv1 = Conv2D(32, 3, activation='relu') 31 | self.flatten = Flatten() 32 | self.d1 = Dense(128, activation='relu') 33 | self.d2 = Dense(10) 34 | 35 | def call(self, x): 36 | x = self.conv1(x) 37 | x = self.flatten(x) 38 | x = self.d1(x) 39 | return self.d2(x) 40 | 41 | model = MyModel() 42 | 43 | loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 44 | 45 | optimizer = tf.keras.optimizers.Adam() 46 | 47 | train_loss = 
tf.keras.metrics.Mean(name='train_loss') 48 | train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') 49 | 50 | test_loss = tf.keras.metrics.Mean(name='test_loss') 51 | test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') 52 | 53 | @tf.function 54 | def train_step(images, labels): 55 | with tf.GradientTape() as tape: 56 | predictions = model(images, training=True) 57 | loss = loss_object(labels, predictions) 58 | gradients = tape.gradient(loss, model.trainable_variables) 59 | optimizer.apply_gradients(zip(gradients, model.trainable_variables)) 60 | 61 | train_loss(loss) 62 | train_accuracy(labels, predictions) 63 | 64 | @tf.function 65 | def test_step(images, labels): 66 | predictions = model(images, training=False) 67 | t_loss = loss_object(labels, predictions) 68 | 69 | test_loss(t_loss) 70 | test_accuracy(labels, predictions) 71 | 72 | EPOCHS = 5 73 | 74 | for epoch in reloading(range(EPOCHS)): 75 | # Try to change the source code inside this loop during the training to 76 | # see how the changes are applied without restarting the training. 77 | # You can use it e.g. to inspect the model or changing the learning rate. 
78 | 79 | train_loss.reset_states() 80 | train_accuracy.reset_states() 81 | test_loss.reset_states() 82 | test_accuracy.reset_states() 83 | 84 | for images, labels in tqdm(train_ds): 85 | train_step(images, labels) 86 | 87 | for test_images, test_labels in tqdm(test_ds): 88 | test_step(test_images, test_labels) 89 | 90 | template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}' 91 | print(template.format(epoch+1, 92 | train_loss.result(), 93 | train_accuracy.result()*100, 94 | test_loss.result(), 95 | test_accuracy.result()*100)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # reloading 2 | [![pypi badge](https://img.shields.io/pypi/v/reloading?color=%230c0)](https://pypi.org/project/reloading/) 3 | 4 | A Python utility to reload a loop body from source on each iteration without 5 | losing state 6 | 7 | Useful for editing source code during training of deep learning models. This lets 8 | you e.g. add logging, print statistics or save the model without restarting the 9 | training and, therefore, without losing the training progress. 10 | 11 | ![Demo](https://github.com/julvo/reloading/blob/master/examples/demo/demo.gif) 12 | 13 | ## Install 14 | ``` 15 | pip install reloading 16 | ``` 17 | 18 | ## Usage 19 | 20 | To reload the body of a `for` loop from source before each iteration, simply 21 | wrap the iterator with `reloading`, e.g. 22 | ```python 23 | from reloading import reloading 24 | 25 | for i in reloading(range(10)): 26 | # this code will be reloaded before each iteration 27 | print(i) 28 | 29 | ``` 30 | 31 | To reload a function from source before each execution, decorate the function 32 | definition with `@reloading`, e.g. 
33 | ```python 34 | from reloading import reloading 35 | 36 | @reloading 37 | def some_function(): 38 | # this code will be reloaded before each invocation 39 | pass 40 | ``` 41 | 42 | ## Additional Options 43 | 44 | Pass the keyword argument `every` to reload only on every n-th invocation or iteration. E.g. 45 | ```python 46 | for i in reloading(range(1000), every=10): 47 | # this code will only be reloaded before every 10th iteration 48 | # this can help to speed up tight loops 49 | pass 50 | 51 | @reloading(every=10) 52 | def some_function(): 53 | # this code will only be reloaded before every 10th invocation 54 | pass 55 | ``` 56 | 57 | Pass `forever=True` instead of an iterable to create an endless reloading loop, e.g. 58 | ```python 59 | for i in reloading(forever=True): 60 | # this code will loop forever and reload from source before each iteration 61 | pass 62 | ``` 63 | 64 | ## Examples 65 | 66 | Here are short snippets showing how to use reloading with your favourite library. 67 | For complete examples, check out the [examples folder](https://github.com/julvo/reloading/blob/master/examples). 68 | 69 | ### PyTorch 70 | ```python 71 | for epoch in reloading(range(NB_EPOCHS)): 72 | # the code inside this outer loop will be reloaded before each epoch 73 | 74 | for images, targets in dataloader: 75 | optimiser.zero_grad() 76 | predictions = model(images) 77 | loss = F.cross_entropy(predictions, targets) 78 | loss.backward() 79 | optimiser.step() 80 | ``` 81 | [Here](https://github.com/julvo/reloading/blob/master/examples/pytorch/train.py) is a full PyTorch example.
82 | 83 | ### fastai 84 | ```python 85 | @reloading 86 | def update_learner(learner): 87 | # this function will be reloaded from source before each epoch so that you 88 | # can make changes to the learner while the training is running 89 | pass 90 | 91 | class LearnerUpdater(LearnerCallback): 92 | def on_epoch_begin(self, **kwargs): 93 | update_learner(self.learn) 94 | 95 | path = untar_data(URLs.MNIST_SAMPLE) 96 | data = ImageDataBunch.from_folder(path) 97 | learn = cnn_learner(data, models.resnet18, metrics=accuracy, 98 | callback_fns=[LearnerUpdater]) 99 | learn.fit(10) 100 | ``` 101 | [Here](https://github.com/julvo/reloading/blob/master/examples/fastai/train.py) is a full fastai example. 102 | 103 | ### Keras 104 | ```python 105 | @reloading 106 | def update_model(model): 107 | # this function will be reloaded from source before each epoch so that you 108 | # can make changes to the model while the training is running using 109 | # K.set_value() 110 | pass 111 | 112 | class ModelUpdater(Callback): 113 | def on_epoch_begin(self, epoch, logs=None): 114 | update_model(self.model) 115 | 116 | model = Sequential() 117 | model.add(Dense(64, activation='relu', input_dim=20)) 118 | model.add(Dense(10, activation='softmax')) 119 | 120 | sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) 121 | model.compile(loss='categorical_crossentropy', 122 | optimizer=sgd, 123 | metrics=['accuracy']) 124 | 125 | model.fit(x_train, y_train, 126 | epochs=200, 127 | batch_size=128, 128 | callbacks=[ModelUpdater()]) 129 | ``` 130 | [Here](https://github.com/julvo/reloading/blob/master/examples/keras/train.py) is a full Keras example. 
131 | 132 | ### TensorFlow 133 | ```python 134 | for epoch in reloading(range(NB_EPOCHS)): 135 | # the code inside this outer loop will be reloaded from source 136 | # before each epoch so that you can change it during training 137 | 138 | train_loss.reset_states() 139 | train_accuracy.reset_states() 140 | test_loss.reset_states() 141 | test_accuracy.reset_states() 142 | 143 | for images, labels in tqdm(train_ds): 144 | train_step(images, labels) 145 | 146 | for test_images, test_labels in tqdm(test_ds): 147 | test_step(test_images, test_labels) 148 | ``` 149 | [Here](https://github.com/julvo/reloading/blob/master/examples/tensorflow/train.py) is a full TensorFlow example. 150 | 151 | ## Testing 152 | 153 | Make sure you have `python` and `python3` available in your path, then run: 154 | ``` 155 | $ python3 reloading/test_reloading.py 156 | ``` 157 | -------------------------------------------------------------------------------- /reloading/test_reloading.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import subprocess as sp 4 | import time 5 | 6 | from reloading import reloading 7 | 8 | SRC_FILE_NAME = "temporary_testing_file.py" 9 | 10 | TEST_CHANGING_SOURCE_LOOP_CONTENT = """ 11 | from reloading import reloading 12 | from time import sleep 13 | 14 | for epoch in reloading(range(10)): 15 | sleep(0.2) 16 | print('INITIAL_FILE_CONTENTS') 17 | """ 18 | 19 | TEST_CHANGING_LINE_NUMBER_OF_LOOP = """ 20 | from reloading import reloading 21 | from time import sleep 22 | 23 | pass 24 | 25 | for epoch in reloading(range(10)): 26 | sleep(0.2) 27 | print('INITIAL_FILE_CONTENTS') 28 | """ 29 | 30 | TEST_CHANGING_SOURCE_FN_CONTENT = """ 31 | from reloading import reloading 32 | from time import sleep 33 | 34 | @reloading 35 | def reload_this_fn(): 36 | print('INITIAL_FILE_CONTENTS') 37 | 38 | for epoch in reloading(range(10)): 39 | sleep(0.2) 40 | reload_this_fn() 41 | """ 42 | 43 | 
TEST_KEEP_LOCAL_VARIABLES_CONTENT = """ 44 | from reloading import reloading 45 | from time import sleep 46 | 47 | fpath = "DON'T CHANGE ME" 48 | for epoch in reloading(range(1)): 49 | assert fpath == "DON'T CHANGE ME" 50 | """ 51 | 52 | TEST_PERSIST_AFTER_LOOP = """ 53 | from reloading import reloading 54 | from time import sleep 55 | 56 | state = 'INIT' 57 | for epoch in reloading(range(1)): 58 | state = 'CHANGED' 59 | 60 | assert state == 'CHANGED' 61 | """ 62 | 63 | TEST_COMMENT_AFTER_LOOP_CONTENT = """ 64 | from reloading import reloading 65 | from time import sleep 66 | 67 | for epoch in reloading(range(10)): 68 | sleep(0.2) 69 | print('INITIAL_FILE_CONTENTS') 70 | 71 | # a comment here should not cause an error 72 | """ 73 | 74 | TEST_FORMAT_STR_IN_LOOP_CONTENT = """ 75 | from reloading import reloading 76 | from time import sleep 77 | 78 | for epoch in reloading(range(10)): 79 | sleep(0.2) 80 | file_contents = 'FILE_CONTENTS' 81 | print(f'INITIAL_{file_contents}') 82 | """ 83 | 84 | TEST_FUNCTION_AFTER = """ 85 | from reloading import reloading 86 | from time import sleep 87 | 88 | @reloading 89 | def some_func(a, b): 90 | sleep(0.2) 91 | print(a+b) 92 | 93 | for _ in range(10): 94 | some_func(2,1) 95 | """ 96 | 97 | 98 | def run_and_update_source(init_src, updated_src=None, update_after=0.5, bin="python3"): 99 | """Runs init_src in a subprocess and updates source to updated_src after 100 | update_after seconds. Returns the standard output of the subprocess and 101 | whether the subprocess produced an uncaught exception. 
102 | """ 103 | with open(SRC_FILE_NAME, "w") as f: 104 | f.write(init_src) 105 | 106 | cmd = [bin, SRC_FILE_NAME] 107 | with sp.Popen(cmd, stdout=sp.PIPE, stderr=sp.PIPE) as proc: 108 | if updated_src is not None: 109 | time.sleep(update_after) 110 | with open(SRC_FILE_NAME, "w") as f: 111 | f.write(updated_src) 112 | 113 | try: 114 | stdout, _ = proc.communicate(timeout=2) 115 | stdout = stdout.decode("utf-8") 116 | has_error = False 117 | except: 118 | stdout = "" 119 | has_error = True 120 | proc.terminate() 121 | 122 | if os.path.isfile(SRC_FILE_NAME): 123 | os.remove(SRC_FILE_NAME) 124 | 125 | return stdout, has_error 126 | 127 | 128 | class TestReloading(unittest.TestCase): 129 | def test_simple_looping(self): 130 | iters = 0 131 | for _ in reloading(range(10)): 132 | iters += 1 133 | 134 | def test_changing_source_loop(self): 135 | for bin in ["python", "python3"]: 136 | stdout, _ = run_and_update_source( 137 | init_src=TEST_CHANGING_SOURCE_LOOP_CONTENT, 138 | updated_src=TEST_CHANGING_SOURCE_LOOP_CONTENT.replace("INITIAL", "CHANGED").rstrip("\n"), 139 | bin=bin, 140 | ) 141 | 142 | self.assertTrue("INITIAL_FILE_CONTENTS" in stdout and "CHANGED_FILE_CONTENTS" in stdout) 143 | 144 | def test_changing_line_number_of_loop(self): 145 | for bin in ["python", "python3"]: 146 | stdout, _ = run_and_update_source( 147 | init_src=TEST_CHANGING_LINE_NUMBER_OF_LOOP, 148 | updated_src=( 149 | TEST_CHANGING_LINE_NUMBER_OF_LOOP 150 | .replace("pass", "pass\npass\n") 151 | .replace("INITIAL", "CHANGED") 152 | .rstrip("\n") 153 | ), 154 | bin=bin, 155 | ) 156 | 157 | self.assertTrue("INITIAL_FILE_CONTENTS" in stdout and "CHANGED_FILE_CONTENTS" in stdout) 158 | 159 | def test_comment_after_loop(self): 160 | for bin in ["python", "python3"]: 161 | stdout, _ = run_and_update_source( 162 | init_src=TEST_COMMENT_AFTER_LOOP_CONTENT, 163 | updated_src=TEST_COMMENT_AFTER_LOOP_CONTENT.replace("INITIAL", "CHANGED").rstrip("\n"), 164 | bin=bin, 165 | ) 166 | 167 | 
self.assertTrue("INITIAL_FILE_CONTENTS" in stdout and "CHANGED_FILE_CONTENTS" in stdout) 168 | 169 | def test_format_str_in_loop(self): 170 | stdout, _ = run_and_update_source( 171 | init_src=TEST_FORMAT_STR_IN_LOOP_CONTENT, 172 | updated_src=TEST_FORMAT_STR_IN_LOOP_CONTENT.replace("INITIAL", "CHANGED").rstrip("\n"), 173 | bin="python3", 174 | ) 175 | 176 | self.assertTrue("INITIAL_FILE_CONTENTS" in stdout and "CHANGED_FILE_CONTENTS" in stdout) 177 | 178 | def test_keep_local_variables(self): 179 | for bin in ["python", "python3"]: 180 | _, has_error = run_and_update_source(init_src=TEST_KEEP_LOCAL_VARIABLES_CONTENT, bin=bin) 181 | self.assertFalse(has_error) 182 | 183 | def test_persist_after_loop(self): 184 | for bin in ["python", "python3"]: 185 | _, has_error = run_and_update_source(init_src=TEST_PERSIST_AFTER_LOOP, bin=bin) 186 | self.assertFalse(has_error) 187 | 188 | def test_simple_function(self): 189 | @reloading 190 | def some_func(): 191 | return "result" 192 | 193 | self.assertTrue(some_func() == "result") 194 | 195 | def test_reloading_function(self): 196 | for bin in ["python", "python3"]: 197 | stdout, _ = run_and_update_source( 198 | init_src=TEST_FUNCTION_AFTER, 199 | updated_src=TEST_FUNCTION_AFTER.replace("a+b", "a-b"), 200 | bin=bin, 201 | ) 202 | self.assertTrue("3" in stdout and "1" in stdout) 203 | 204 | def test_changing_source_function(self): 205 | for bin in ["python", "python3"]: 206 | stdout, _ = run_and_update_source( 207 | init_src=TEST_CHANGING_SOURCE_FN_CONTENT, 208 | updated_src=TEST_CHANGING_SOURCE_FN_CONTENT.replace("INITIAL", "CHANGED").rstrip("\n"), 209 | bin=bin, 210 | ) 211 | 212 | self.assertTrue("INITIAL_FILE_CONTENTS" in stdout and "CHANGED_FILE_CONTENTS" in stdout) 213 | 214 | 215 | if __name__ == "__main__": 216 | unittest.main() 217 | -------------------------------------------------------------------------------- /reloading/reloading.py: -------------------------------------------------------------------------------- 1 
import inspect
import sys
import ast
import traceback
import types
from itertools import chain
from functools import partial, update_wrapper


# Calling `reloading` without an iterable (e.g. `reloading(every=2)`) is
# assumed to be decorator usage and returns a partial. If somebody iterates
# over that partial by mistake, the stock error
# "TypeError: 'functools.partial' object is not iterable" is not descriptive,
# so this subclass overrides __iter__ to raise a clearer message.
class no_iter_partial(partial):
    def __iter__(self):
        raise TypeError("Nothing to iterate over. Please pass an iterable to reloading.")


def reloading(fn_or_seq=None, every=1, forever=None):
    """Wraps a loop iterator or decorates a function to reload the source code
    before every loop iteration or function invocation.

    When wrapped around the outermost iterator in a `for` loop, e.g.
    `for i in reloading(range(10))`, causes the loop body to reload from source
    before every iteration while keeping the state.
    When used as a function decorator, the decorated function is reloaded from
    source before each execution.

    Pass the integer keyword argument `every` to reload the source code
    only every n-th iteration/invocation.

    Args:
        fn_or_seq (function | iterable): A function or loop iterator which should
            be reloaded from source before each invocation or iteration,
            respectively. An empty iterable yields a loop that runs zero times.
        every (int, Optional): After how many iterations/invocations to reload
        forever (bool, Optional): Pass `forever=True` instead of an iterator to
            create an endless loop

    Returns:
        For a function: a wrapper that reloads it before calls. For an iterable
        (or `forever=True`): an empty list, because the loop body has already
        been executed inside `reloading`. Otherwise: a decorator that accepts
        the function to wrap.
    """
    # Compare against None rather than relying on truthiness: an empty
    # iterable such as `[]` or `range(0)` is still a loop (running zero
    # times), not a request for a decorator.
    if fn_or_seq is not None:
        if isinstance(fn_or_seq, types.FunctionType):
            return _reloading_function(fn_or_seq, every=every)
        return _reloading_loop(fn_or_seq, every=every)
    if forever:
        # iter(callable, sentinel): int() always returns 0, never the sentinel 1
        return _reloading_loop(iter(int, 1), every=every)

    # return this function with the keyword arguments partialed in,
    # so that the return value can be used as a decorator
    decorator = update_wrapper(no_iter_partial(reloading, every=every), reloading)
    return decorator


def unique_name(used):
    """Return a name that differs from every name in `used`.

    Appending "0" to the longest existing name produces a string longer than
    all of them, so it cannot collide with any.
    """
    return max(used, key=len) + "0"


def format_itervars(ast_node):
    """Formats an `ast_node` of loop iteration variables as string, e.g. 'a, b'

    The string is used as the left-hand side of an assignment, so it must
    unpack exactly like the original `for` target: nested tuples/lists are
    parenthesised, `*name` targets keep their star, and a one-element tuple
    gets a trailing comma (`for (a,) in ...` -> 'a,').
    """
    # ast.Starred does not exist on Python 2, where starred targets can't occur
    starred_type = getattr(ast, "Starred", ())

    # handle the case that there only is a single loop var
    if isinstance(ast_node, ast.Name):
        return ast_node.id

    names = []
    for child in ast_node.elts:
        if isinstance(child, ast.Name):
            names.append(child.id)
        elif isinstance(child, (ast.Tuple, ast.List)):
            # if its another tuple, like "a, (b, c)", recurse
            names.append("({})".format(format_itervars(child)))
        elif isinstance(child, starred_type) and isinstance(child.value, ast.Name):
            names.append("*" + child.value.id)
        # other targets (attributes, subscripts) are not supported and skipped

    formatted = ", ".join(names)
    if len(names) == 1:
        # without the trailing comma "a = value" would bind the whole item
        # instead of unpacking a one-element sequence
        formatted += ","
    return formatted


def load_file(path):
    """Return the text of the file at `path` with a newline appended.

    Retries while the read yields an empty string, because while saving, an
    editor may leave the file empty for a moment.
    """
    src = ""
    while src == "":
        with open(path, "r") as f:
            src = f.read()
    return src + "\n"


def parse_file_until_successful(path):
    """Parse the file at `path` into an AST, prompting the user on syntax errors.

    On a SyntaxError the traceback is shown, the user is asked to fix the file,
    and the file is re-read and re-parsed until parsing succeeds.
    """
    source = load_file(path)
    while True:
        try:
            tree = ast.parse(source)
            return tree
        except SyntaxError:
            handle_exception(path)
            source = load_file(path)


def _is_reloading_call(node):
    """Return True if `node` is a call like `reloading(...)` or `mod.reloading(...)`.

    Other calls (e.g. `d.items()`, `f()()`) return False instead of raising,
    so unrelated loops in the same file don't break the lookup.
    """
    if not isinstance(node, ast.Call):
        return False
    func = node.func
    if isinstance(func, ast.Name):
        return func.id == "reloading"
    if isinstance(func, ast.Attribute):
        return func.attr == "reloading"
    return False


def isolate_loop_body_and_get_itervars(tree, lineno, loop_id):
    """Modifies tree inplace as unclear how to create ast.Module.
    Returns itervars

    Finds the `for ... in reloading(...)` loop that matches `loop_id` (when
    given) or starts on `lineno`, replaces `tree.body` with that loop's body
    and returns `(loop target node, loop id)`.

    Raises:
        LookupError: if no loop or more than one loop matches.
    """
    candidate_nodes = []
    for node in ast.walk(tree):
        if (
            isinstance(node, ast.For)
            and _is_reloading_call(node.iter)
            and (
                (loop_id is not None and loop_id == get_loop_id(node))
                or getattr(node, "lineno", None) == lineno
            )
        ):
            candidate_nodes.append(node)

    if len(candidate_nodes) > 1:
        raise LookupError(
            "The reloading loop is ambigious. Use `reloading` only once per line and make sure that the code in that line is unique within the source file."
        )

    if len(candidate_nodes) < 1:
        raise LookupError(
            "Could not locate reloading loop. Please make sure the code in the line that uses `reloading` doesn't change between reloads."
        )

    loop_node = candidate_nodes[0]
    tree.body = loop_node.body
    return loop_node.target, get_loop_id(loop_node)


def get_loop_id(ast_node):
    """Generates a unique identifier for an `ast_node` of type ast.For to find the loop in the changed source file
    """
    return ast.dump(ast_node.target) + "__" + ast.dump(ast_node.iter)


def get_loop_code(loop_frame_info, loop_id):
    """Compile the current body of the reloading loop described by `loop_frame_info`.

    Keeps prompting the user (via `handle_exception`) until the loop can be
    located. Returns `(compiled body, formatted loop target, loop id)`.
    """
    fpath = loop_frame_info[1]
    while True:
        tree = parse_file_until_successful(fpath)
        try:
            itervars, found_loop_id = isolate_loop_body_and_get_itervars(tree, lineno=loop_frame_info[2], loop_id=loop_id)
            return compile(tree, filename="", mode="exec"), format_itervars(itervars), found_loop_id
        except LookupError:
            handle_exception(fpath)


def handle_exception(fpath):
    """Print the current traceback and block until the user presses return.

    Code compiled by this module uses an empty filename, so `File ""` in the
    traceback is replaced with `fpath` to point at the real source file.
    """
    exc = traceback.format_exc()
    exc = exc.replace('File ""', 'File "{}"'.format(fpath))
    sys.stderr.write(exc + "\n")
    print("Edit {} and press return to continue".format(fpath))
    sys.stdin.readline()


def _reloading_loop(seq, every=1):
    """Run the caller's loop body once per item of `seq`, reloading it from source.

    The body of the `for` statement that called `reloading` is compiled from
    the caller's file and executed with the caller's globals and locals.
    Because the whole loop runs here, an empty list is returned so that the
    caller's own `for` statement performs no further iterations.
    """
    # stack[0] is this function, stack[1] is `reloading`, stack[2] is the loop
    loop_frame_info = inspect.stack()[2]
    fpath = loop_frame_info[1]

    caller_globals = loop_frame_info[0].f_globals
    caller_locals = loop_frame_info[0].f_locals

    # create a unique name in the caller namespace that we can safely write
    # the values of the iteration variables into
    unique = unique_name(chain(caller_locals.keys(), caller_globals.keys()))
    loop_id = None

    for i, itervar_values in enumerate(seq):
        if i % every == 0:
            compiled_body, itervars, loop_id = get_loop_code(loop_frame_info, loop_id=loop_id)

        # bind the loop target(s), e.g. executes "a, b = <unique>"
        caller_locals[unique] = itervar_values
        exec(itervars + " = " + unique, caller_globals, caller_locals)
        try:
            # run main loop body
            exec(compiled_body, caller_globals, caller_locals)
        except Exception:
            handle_exception(fpath)

    return []


def get_decorator_name_or_none(dec_node):
    """Return the leading name of a decorator AST node, or None.

    `@name` and `@name(...)` give "name"; `@mod.attr` and `@mod.attr(...)`
    give "mod". Any other shape (deeper attribute chains, calls of calls,
    subscripts) gives None instead of raising.
    """
    target = dec_node.func if isinstance(dec_node, ast.Call) else dec_node
    if isinstance(target, ast.Name):
        return target.id
    if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name):
        return target.value.id
    return None


def strip_reloading_decorator(func):
    """Remove the 'reloading' decorator and all decorators before it"""
    decorator_names = [get_decorator_name_or_none(dec) for dec in func.decorator_list]
    reloading_idx = decorator_names.index("reloading")
    func.decorator_list = func.decorator_list[reloading_idx + 1:]


def isolate_function_def(funcname, tree):
    """Strip everything but the function definition from the ast in-place.
    Also strips the reloading decorator from the function definition

    Returns True if a `reloading`-decorated definition named `funcname` was
    found, False otherwise (the tree is then left unchanged).
    """
    for node in ast.walk(tree):
        if (
            isinstance(node, ast.FunctionDef)
            and node.name == funcname
            and "reloading" in [
                get_decorator_name_or_none(dec)
                for dec in node.decorator_list
            ]
        ):
            strip_reloading_decorator(node)
            tree.body = [node]
            return True
    return False


def get_function_def_code(fpath, fn):
    """Compile the current definition of `fn` from `fpath`, or return None if not found."""
    tree = parse_file_until_successful(fpath)
    found = isolate_function_def(fn.__name__, tree)
    if not found:
        return None
    compiled = compile(tree, filename="", mode="exec")
    return compiled


def get_reloaded_function(caller_globals, caller_locals, fpath, fn):
    """Re-execute the source definition of `fn` and return the new function object.

    Returns None when the decorated definition cannot be found in `fpath`.
    """
    code = get_function_def_code(fpath, fn)
    if code is None:
        return None
    # need to copy locals, otherwise the exec will overwrite the decorated with the undecorated new version
    # this became a need after removing the reloading decorator from the newly defined version
    caller_locals_copy = caller_locals.copy()
    exec(code, caller_globals, caller_locals_copy)
    func = caller_locals_copy[fn.__name__]
    return func


def _reloading_function(fn, every=1):
    """Return a wrapper around `fn` that re-executes its definition from source.

    Every `every`-th call re-parses the caller's file and re-defines `fn`. If
    the definition can't be located, the previously loaded version is used;
    before any successful reload that is `fn` itself. If a call raises, the
    traceback is shown, the user is asked to edit the file, and the call is
    retried with a freshly reloaded version.
    """
    stack = inspect.stack()
    # stack[0] is this function, stack[1] is `reloading`, stack[2] is the
    # code that applied the decorator
    frame, fpath = stack[2][:2]
    caller_locals = frame.f_locals
    caller_globals = frame.f_globals

    # crutch to use dict as python2 doesn't support nonlocal
    state = {
        # start from the original function so that a definition which can't
        # be found in source (e.g. `f = reloading(f)` without decorator
        # syntax) is still callable instead of being `None`
        "func": fn,
        "reloads": 0,
    }

    def wrapped(*args, **kwargs):
        if state["reloads"] % every == 0:
            state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn) or state["func"]
        state["reloads"] += 1
        while True:
            try:
                result = state["func"](*args, **kwargs)
                return result
            except Exception:
                handle_exception(fpath)
                state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn) or state["func"]

    # preserve the wrapped function's name, docstring and module
    update_wrapper(wrapped, fn)
    caller_locals[fn.__name__] = wrapped
    return wrapped