├── mnft
│   ├── requirements.txt
│   ├── .DS_Store
│   ├── __init__.py
│   ├── utils.py
│   └── callback.py
├── verify.py
├── setup.py
├── data.py
├── train.py
├── README.md
├── model.py
└── .gitignore

--------------------------------------------------------------------------------
/mnft/requirements.txt:
--------------------------------------------------------------------------------
torch
--------------------------------------------------------------------------------
/mnft/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shafu0x/NFT-Pytorch-Callback/main/mnft/.DS_Store
--------------------------------------------------------------------------------
/mnft/__init__.py:
--------------------------------------------------------------------------------
"""
mnft.

Model NFT PyTorch Lightning callback.
"""
from .callback import NftCallback

__version__ = "0.1.0"
__author__ = 'Sharif'
__credits__ = 'Algovera'
--------------------------------------------------------------------------------
/verify.py:
--------------------------------------------------------------------------------
from mnft.utils import verify

PATH = "/Users/shafu/NFT-Pytorch-Callback/xxxx-model"
OWNER = "0x34e619ef675d6161868cc16cf929f860f88242f7"
LOSS = -0.781
EPOCH = 2
DATE = "2022-09-05 12:29:05.084334"

HASH = "4535573218e5945ca8c36d8dab34f07b613f11faeb35d5304ede253445107150"

assert verify(HASH, PATH, OWNER, LOSS, EPOCH, DATE)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

setup(
    name='mnft',
    version='0.1.0',
    description='Model NFT Pytorch Lightning Callback',
    url='https://github.com/SharifElfouly/NFT-Pytorch-Callback',
    author='Sharif Elfouly',
    author_email='selfouly@gmail.com',
    license='BSD 2-clause',
    packages=['mnft'],
    install_requires=[
        'pytorch-lightning',
        'termcolor',
        'torch',
    ],

    classifiers=[
        'Programming Language :: Python :: 3',
    ],
)
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import os

from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import random_split

def prepare_data():
    # standard MNIST transforms: tensor conversion + normalization
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))])

    # small subsets keep the demo fast
    mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
    mnist_train = [mnist_train[i] for i in range(2200)]

    mnist_train, mnist_val = random_split(mnist_train, [2000, 200])

    mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
    mnist_test = [mnist_test[i] for i in range(3000, 4000)]

    return mnist_train, mnist_val, mnist_test
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from model import Model
from data import prepare_data
from mnft.callback import NftCallback

VERBOSE = False
EPOCHS = 10
BATCH_SIZE = 64
OWNER = "0x34e619ef675d6161868cc16cf929f860f88242f7"

train, val, test = prepare_data()

train_loader = DataLoader(train, batch_size=BATCH_SIZE)
val_loader = DataLoader(val, batch_size=BATCH_SIZE)
test_loader = DataLoader(test, batch_size=BATCH_SIZE)

model = Model()

trainer = pl.Trainer(max_epochs=EPOCHS,
                     callbacks=[NftCallback(OWNER)],
                     enable_progress_bar=VERBOSE)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# save the full model so it can later be loaded and verified against its hash
torch.save(model, "test-model")
--------------------------------------------------------------------------------
/mnft/utils.py:
--------------------------------------------------------------------------------
from hashlib import sha256
from datetime import datetime

import torch

def sha(string): return sha256(string.encode()).hexdigest()

def get_model_size(model):
    "Return the model size (parameters + buffers) in MB."
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / 1024**2
    return size_all_mb

def hash_model_weights(model):
    # we could use `named_parameters()` here as well; note that torch truncates
    # the string form of large tensors, so this hash does not cover every
    # individual weight value
    return sha(str(list(model.parameters())))

def hash_train_data(data):
    # TODO
    # hashing the whole data probably does not make sense
    pass

def hash_training(model, owner, loss, epoch, date=None):
    # default to None rather than `date=datetime.now()`: a default argument is
    # evaluated once at import time, which would give every call the same date
    if date is None:
        date = datetime.now()

    model_hash = hash_model_weights(model)
    size = str(get_model_size(model))
    date = str(date)
    loss = '{:.3f}'.format(loss)
    epoch = str(epoch)

    s = model_hash + owner + size + date + loss + epoch
    return sha(s)

def verify(model_hash, model_path, owner, loss, epoch, date):
    "Verify that `model_hash` matches the model at `model_path` and its metadata."
    # note: newer torch versions may require `weights_only=False` to unpickle a full model
    model = torch.load(model_path)
    return model_hash == hash_training(model, owner, loss, epoch, date)
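
# --- Illustrative sketch (not part of the original file): a quick, self-contained
# smoke test of the hashing scheme above. The tiny Linear model, owner address
# and metadata values are made up for the example.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(4, 2)
    owner = "0x0000000000000000000000000000000000000000"
    date = datetime.now()

    h = hash_training(model, owner, loss=0.123, epoch=0, date=date)

    # the hash is only reproducible with the exact same weights and metadata,
    # so record loss, epoch and date alongside the saved model
    assert h == hash_training(model, owner, loss=0.123, epoch=0, date=date)
    print("training hash:", h)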
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# NFT-Pytorch-Callback

Generously supported by the Algovera Community. See more details [here](https://forum.algovera.ai/t/proposal-pytorch-nft-checkpoint/100).

A custom PyTorch Lightning checkpoint callback that, each epoch, hashes the current network weights, some metadata (model size, date, loss, epoch) and your ETH address, proving who did the network training. (The scheme could be turned into some kind of standard later.)

This NFT could represent a tradable license of some sort, for example.

### MINT

Take your hash and mint your model NFT at `m-nft.com`

### Installation

```
pip3 install mnft
```
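
### Usage

A minimal sketch of how the callback is attached, mirroring `train.py` in this repo; `OWNER` is a placeholder for your own ETH address, and your `LightningModule` must log a `"loss"` metric during validation, because the callback reads `trainer.callback_metrics["loss"]`:

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from model import Model
from data import prepare_data
from mnft.callback import NftCallback

OWNER = "0xYourEthAddress"  # the ETH address the training hashes are bound to

train, val, _ = prepare_data()
trainer = pl.Trainer(max_epochs=10, callbacks=[NftCallback(OWNER)])
trainer.fit(model=Model(),
            train_dataloaders=DataLoader(train, batch_size=64),
            val_dataloaders=DataLoader(val, batch_size=64))
# prints one hash per epoch, then a summary with the lowest-loss epoch
```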

### Mint your NFT

Visit [nft-model.vercel.app](https://nft-model.vercel.app/) to mint your Model NFT.

### Description

Determining who actually trained a machine learning model is currently very hard: looking at open-source model zoos today, it is nearly impossible to tell who trained which model. The Pytorch NFT checkpoint solves this problem.

A custom NFT is generated each epoch, which proves who generated which network weights. This could be used as a badge of honor or turned into a tradable license.
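
Concretely, each per-epoch hash is a SHA-256 digest of the weight hash concatenated with the training metadata, as implemented in `mnft/utils.py`. A rough, self-contained sketch of the scheme (the model and metadata values here are made-up examples):

```python
from hashlib import sha256
import torch.nn as nn

def sha(s):
    return sha256(s.encode()).hexdigest()

model = nn.Linear(4, 2)  # stand-in for your trained network
weights_hash = sha(str(list(model.parameters())))

# concatenation order in hash_training(): weight hash, owner, size (MB), date, loss, epoch
payload = weights_hash + "0xYourEthAddress" + "3.8e-05" + "2022-09-05 12:29:05.084334" + "-0.781" + "2"
print(sha(payload))  # the hash you would mint for this epoch
```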

You can also watch me presenting the idea in the video below (timestamp 1:34).

*(The original README embeds a YouTube video of the presentation here.)*
--------------------------------------------------------------------------------
/mnft/callback.py:
--------------------------------------------------------------------------------
from pytorch_lightning.callbacks import Callback
from termcolor import cprint

from .utils import hash_training, verify as verify_hash

class NftCallback(Callback):
    def __init__(self, owner):
        self.owner = owner
        self.epochs = 0
        self.hashes = []

    def on_validation_epoch_end(self, trainer, pl_module):
        # requires the LightningModule to log a "loss" metric during validation
        loss = float(trainer.callback_metrics["loss"])

        h = hash_training(trainer.model, self.owner, loss, self.epochs)
        d = {
            "epoch": self.epochs,
            "loss": loss,
            "hash": h
        }
        self.hashes.append(d)

        self.print_hash(d)

        self.epochs += 1

    def on_train_end(self, trainer, pl_module):
        self.print_hashes(self.hashes)

        cprint("Mint Your Model Training NFT now! Visit www.m-nft.com",
               "red",
               attrs=["bold", "blink"])

    def print_hashes(self, entries):
        print()
        cprint("Summary", "green", attrs=["bold"])
        for entry in entries:
            self.print_hash(entry)

        print()

        cprint("Lowest Loss", "green", attrs=["bold"])
        lowest_loss = min(self.hashes, key=lambda d: d['loss'])
        self.print_hash(lowest_loss)
        print()

    def print_hash(self, entry):
        e = entry["epoch"]
        l = '{:.3f}'.format(round(entry["loss"], 3))
        h = entry["hash"]
        print(f"epoch {e}: loss {l} - hash {h}")

    @staticmethod
    def verify(model_hash, model_path, owner, loss, epoch, date):
        return verify_hash(model_hash, model_path, owner, loss, epoch, date)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import functional as F
import pytorch_lightning as pl


# define the LightningModule
class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()

        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, height, width = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # layer 1 (b, 1*28*28) -> (b, 128)
        x = self.layer_1(x)
        x = torch.relu(x)

        # layer 2 (b, 128) -> (b, 256)
        x = self.layer_2(x)
        x = torch.relu(x)

        # layer 3 (b, 256) -> (b, 10)
        x = self.layer_3(x)

        # log-probabilities over labels: nll_loss in cross_entropy_loss expects
        # log-probabilities, so log_softmax is used rather than softmax
        x = torch.log_softmax(x, dim=1)

        return x

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)

        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        return {'val_loss': loss}

    def test_step(self, test_batch, batch_idx):
        x, y = test_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        return {'test_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        # logged as "loss" because NftCallback reads trainer.callback_metrics["loss"]
        self.log("loss", avg_loss)
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters())
        lr_scheduler = {'scheduler': torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95),
                        'name': 'expo_lr'}
        return [optimizer], [lr_scheduler]
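
# --- Illustrative sketch (not part of the original file): a quick smoke test of
# the forward pass on a fake batch of MNIST-shaped inputs.
if __name__ == "__main__":
    m = Model()
    x = torch.randn(8, 1, 28, 28)
    out = m(x)
    assert out.shape == (8, 10)
    # forward returns log-probabilities, so exponentiating gives a distribution
    print("per-sample probability mass:", out.exp().sum(dim=1)[0].item())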
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
test-model
.vim/*

MNIST/*

lightning_logs/*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
--------------------------------------------------------------------------------