├── mnft
│   ├── requirements.txt
│   ├── __init__.py
│   ├── utils.py
│   └── callback.py
├── verify.py
├── setup.py
├── data.py
├── train.py
├── README.md
├── model.py
└── .gitignore
/mnft/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | pytorch-lightning
4 | termcolor
5 |
--------------------------------------------------------------------------------
/mnft/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | pyexample.
3 |
4 | An example python library.
5 | """
6 | from .callback import NftCallback
7 |
8 | __version__ = "0.1.0"
9 | __author__ = 'Sharif'
10 | __credits__ = 'Algovera'
11 |
--------------------------------------------------------------------------------
/verify.py:
--------------------------------------------------------------------------------
1 | from mnft.utils import verify
2 |
3 | # These values were printed by a previous training run; verification succeeds
4 | # only if every field (including the exact date string) matches that run.
5 | PATH = "/Users/shafu/NFT-Pytorch-Callback/xxxx-model"
6 | OWNER = "0x34e619ef675d6161868cc16cf929f860f88242f7"
7 | LOSS = -0.781
8 | EPOCH = 2
9 | DATE = "2022-09-05 12:29:05.084334"
10 |
11 | HASH = "4535573218e5945ca8c36d8dab34f07b613f11faeb35d5304ede253445107150"
12 |
13 | assert verify(HASH, PATH, OWNER, LOSS, EPOCH, DATE)
14 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name='mnft',
5 | version='0.1.0',
6 | description='Model NFT Pytorch Lightning Callback',
7 | url='https://github.com/SharifElfouly/NFT-Pytorch-Callback',
8 | author='Sharif Elfouly',
9 | author_email='selfouly@gmail.com',
10 | license='BSD 2-clause',
11 | packages=['mnft'],
12 | install_requires=[
13 | 'pytorch-lightning',
14 | 'termcolor',
15 | 'torch',
16 | ],
17 |
18 | classifiers=[
19 | 'Programming Language :: Python :: 3',
20 |         'Programming Language :: Python :: 3.7',
21 |         'Programming Language :: Python :: 3.8',
22 | ],
23 | )
24 |
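25 | # For local development, install in editable mode instead: pip install -e .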
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torchvision import datasets, transforms
3 | from torch.utils.data import TensorDataset, DataLoader, random_split
4 | from torchvision.datasets import MNIST
5 |
6 | def prepare_data():
7 | # transforms for images
8 | transform=transforms.Compose([transforms.ToTensor(),
9 | transforms.Normalize((0.1307,), (0.3081,))])
10 |
11 |     # download MNIST and keep only a small subset so the example trains quickly
12 | mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
13 | mnist_train = [mnist_train[i] for i in range(2200)]
14 |
15 | mnist_train, mnist_val = random_split(mnist_train, [2000, 200])
16 |
17 | mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
18 | mnist_test = [mnist_test[i] for i in range(3000,4000)]
19 |
20 | return mnist_train, mnist_val, mnist_test
21 |
22 |
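23 | if __name__ == "__main__":
24 |     # Illustrative check of the split sizes (downloads MNIST on first run).
25 |     train, val, test = prepare_data()
26 |     print(len(train), len(val), len(test))  # 2000 200 1000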
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | from torch.utils.data import DataLoader
4 |
5 | from model import Model
6 | from data import prepare_data
7 | from mnft.callback import NftCallback
8 |
9 | VERBOSE = False
10 | EPOCHS = 10
11 | BATCH_SIZE = 64
12 | OWNER = "0x34e619ef675d6161868cc16cf929f860f88242f7"
13 |
14 | train, val, test = prepare_data()
15 |
16 | train_loader = DataLoader(train, batch_size=BATCH_SIZE)
17 | val_loader = DataLoader(val, batch_size=BATCH_SIZE)
18 | test_loader = DataLoader(test, batch_size=BATCH_SIZE)
19 |
20 | model = Model()
21 |
22 | trainer = pl.Trainer(max_epochs=EPOCHS, callbacks=[NftCallback(OWNER)], enable_progress_bar=VERBOSE)
23 | trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)
24 |
25 | # save the full model so it can be reloaded and re-hashed for verification
26 | torch.save(model, "test-model")
27 |
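28 | # Each epoch's hash is printed during training. To prove the training later,
29 | # keep the printed hash and check it with mnft.utils.verify (see verify.py).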
--------------------------------------------------------------------------------
/mnft/utils.py:
--------------------------------------------------------------------------------
1 | from hashlib import sha256
2 | from datetime import datetime
3 |
4 | import torch
5 |
6 | def sha(string): return sha256(string.encode()).hexdigest()
7 |
8 | def get_model_size(model):
9 | "return model size in mb"
10 | param_size = 0
11 | for param in model.parameters():
12 | param_size += param.nelement() * param.element_size()
13 | buffer_size = 0
14 | for buffer in model.buffers():
15 | buffer_size += buffer.nelement() * buffer.element_size()
16 |
17 | size_all_mb = (param_size + buffer_size) / 1024**2
18 | return size_all_mb
19 |
20 | def hash_model_weights(model):
21 | # we could use `named_parameters()` here as well
22 | return sha(str(list(model.parameters())))
23 |
24 | def hash_train_data(data):
25 | # TODO
26 | # hashing the whole data probably does not make sense
27 | pass
28 |
29 | def hash_training(model, owner, loss, epoch, date=None):
30 |     model_hash = hash_model_weights(model)
31 |     size = str(get_model_size(model))
32 |     date = str(date if date is not None else datetime.now())  # stamp at call time, not import time
33 | loss = '{:.3f}'.format(loss)
34 | epoch = str(epoch)
35 |
36 | s = model_hash + owner + size + date + loss + epoch
37 | return sha(s)
38 |
39 | def verify(model_hash, model_path, owner, loss, epoch, date):
40 | "verify if `model_hash` is correct"
41 | model = torch.load(model_path)
42 | return model_hash == hash_training(model, owner, loss, epoch, date)
43 |
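44 | # A minimal round-trip sketch (illustrative values; `demo-model` is a throwaway
45 | # path, and a plain nn.Linear stands in for a trained LightningModule):
46 | if __name__ == "__main__":
47 |     model = torch.nn.Linear(4, 2)
48 |     date = str(datetime.now())
49 |     h = hash_training(model, "0xOWNER", 0.5, 0, date)
50 |     torch.save(model, "demo-model")
51 |     assert verify(h, "demo-model", "0xOWNER", 0.5, 0, date)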
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NFT-Pytorch-Callback
2 |
3 | Generously supported by the Algovera Community. See more details [here](https://forum.algovera.ai/t/proposal-pytorch-nft-checkpoint/100).
4 |
5 | A custom PyTorch NFT checkpoint that, every N epochs, hashes the current network weights, some metadata (dataset, accuracy, etc.), and your ETH address, proving who did the network training. This could be turned into some kind of standard later.
6 |
7 | This NFT could, for example, represent a tradable license of some sort.
8 |
9 | ### Mint
10 |
11 | Take your hash and mint your model NFT at `m-nft.com`.
12 |
13 | ### Installation
14 |
15 | ```
16 | pip3 install mnft
17 | ```
18 |
19 | ### Mint your NFT
20 |
21 | Visit [nft-model.vercel.app](https://nft-model.vercel.app/) to mint your Model NFT.
22 |
23 | ### Description
24 |
25 | Determining who actually trained a machine learning model is currently very hard. Looking at open-source model zoos today, it is nearly impossible to tell who trained which model. The PyTorch NFT checkpoint solves this problem.
26 |
27 | A custom NFT is generated each epoch, which proves who generated which network weights. This could be used as a badge of honor or turned into a tradable license.
28 |
29 | You can also watch me presenting the idea in the video below (timestamp 1:34).
30 |
31 | [video]
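32 |
33 | ### Usage
34 |
35 | A minimal sketch of how the callback plugs into a PyTorch Lightning run (mirroring `train.py`; the owner address below is a placeholder):
36 |
37 | ```python
38 | import pytorch_lightning as pl
39 | from mnft import NftCallback
40 |
41 | # model, train_loader, val_loader are built as in train.py and data.py
42 | trainer = pl.Trainer(max_epochs=10, callbacks=[NftCallback("0xYourEthAddress")])
43 | trainer.fit(model, train_loader, val_loader)
44 | ```
45 |
46 | After each validation epoch the callback hashes the current weights together with the owner address, loss, and epoch number; any of the printed hashes can be minted.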
--------------------------------------------------------------------------------
/mnft/callback.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.callbacks import Callback
2 | from termcolor import cprint
3 |
4 | from . import utils
5 |
6 |
7 | class NftCallback(Callback):
8 |     def __init__(self, owner):
9 |         self.owner = owner
10 |         self.epochs = 0
11 |         self.hashes = []
12 |
13 |     def on_validation_epoch_end(self, trainer, pl_module):
14 |         loss = float(trainer.callback_metrics["loss"])
15 |
16 |         h = utils.hash_training(trainer.model, self.owner, loss, self.epochs)
17 |         record = {
18 |             "epoch": self.epochs,
19 |             "loss": loss,
20 |             "hash": h,
21 |         }
22 |         self.hashes.append(record)
23 |         self.print_hash(record)
24 |
25 |         self.epochs += 1
26 |
27 |     def on_train_end(self, trainer, pl_module):
28 |         self.print_hashes(self.hashes)
29 |
30 |         cprint("Mint Your Model Training NFT now! Visit www.m-nft.com",
31 |                "red",
32 |                attrs=["bold", "blink"])
33 |
34 |     def print_hashes(self, records):
35 |         print()
36 |         cprint("Summary", "green", attrs=["bold"])
37 |         for record in records:
38 |             self.print_hash(record)
39 |         print()
40 |
41 |         cprint("Lowest Loss", "green", attrs=["bold"])
42 |         self.print_hash(min(records, key=lambda r: r["loss"]))
43 |         print()
44 |
45 |     def print_hash(self, record):
46 |         loss = '{:.3f}'.format(record["loss"])
47 |         print(f"epoch {record['epoch']}: loss {loss} - hash {record['hash']}")
48 |
49 |     @staticmethod
50 |     def verify(model_hash, model_path, owner, loss, epoch, date):
51 |         return utils.verify(model_hash, model_path, owner, loss, epoch, date)
52 |
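53 | # Illustrative end-of-training console output (losses and hashes are made up):
54 | #
55 | #   Summary
56 | #   epoch 0: loss 0.412 - hash 4535...7150
57 | #   epoch 1: loss 0.287 - hash 91ab...03fe
58 | #
59 | #   Lowest Loss
60 | #   epoch 1: loss 0.287 - hash 91ab...03fe
61 | #
62 | #   Mint Your Model Training NFT now! Visit www.m-nft.com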
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch import optim, nn, utils, Tensor
3 | from torchvision.datasets import MNIST
4 | from torchvision.transforms import ToTensor
5 | import pytorch_lightning as pl
6 | import torch
7 | from torch.nn import functional as F
8 |
9 |
10 | # define the LightningModule
11 | class Model(pl.LightningModule):
12 | def __init__(self):
13 | super(Model, self).__init__()
14 |
15 | # mnist images are (1, 28, 28) (channels, width, height)
16 | self.layer_1 = torch.nn.Linear(28 * 28, 128)
17 | self.layer_2 = torch.nn.Linear(128, 256)
18 | self.layer_3 = torch.nn.Linear(256, 10)
19 |
20 | def forward(self, x):
21 | batch_size, channels, width, height = x.size()
22 |
23 | # (b, 1, 28, 28) -> (b, 1*28*28)
24 | x = x.view(batch_size, -1)
25 |
26 | # layer 1 (b, 1*28*28) -> (b, 128)
27 | x = self.layer_1(x)
28 | x = torch.relu(x)
29 |
30 | # layer 2 (b, 128) -> (b, 256)
31 | x = self.layer_2(x)
32 | x = torch.relu(x)
33 |
34 | # layer 3 (b, 256) -> (b, 10)
35 | x = self.layer_3(x)
36 |
37 |         # log-probabilities over labels (nll_loss expects log-probabilities)
38 |         x = torch.log_softmax(x, dim=1)
39 |
40 | return x
41 |
42 | def cross_entropy_loss(self, logits, labels):
43 | return F.nll_loss(logits, labels)
44 |
45 | def training_step(self, train_batch, batch_idx):
46 | x, y = train_batch
47 | logits = self.forward(x)
48 | loss = self.cross_entropy_loss(logits, y)
49 |
50 |         self.log('train_loss', loss)
51 |         return loss
52 |
53 | def validation_step(self, val_batch, batch_idx):
54 | # print("VAL STEP")
55 | x, y = val_batch
56 | logits = self.forward(x)
57 | loss = self.cross_entropy_loss(logits, y)
58 | return {'val_loss': loss}
59 |
60 | def test_step(self, val_batch, batch_idx):
61 | x, y = val_batch
62 | logits = self.forward(x)
63 | loss = self.cross_entropy_loss(logits, y)
64 | return {'test_loss': loss}
65 |
66 | def validation_epoch_end(self, outputs):
67 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
68 | tensorboard_logs = {'val_loss': avg_loss}
69 | self.log("loss", avg_loss)
70 | return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
71 |
72 | def test_epoch_end(self, outputs):
73 | avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
74 | tensorboard_logs = {'test_loss': avg_loss}
75 | return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}
76 |
77 | def configure_optimizers(self):
78 | optimizer = torch.optim.Adam(self.parameters())
79 | lr_scheduler = {'scheduler': torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = 0.95),
80 | 'name': 'expo_lr'}
81 | return [optimizer], [lr_scheduler]
82 |
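83 | if __name__ == "__main__":
84 |     # Quick smoke test (illustrative): a random batch of four 28x28 "images"
85 |     # should come out as log-probabilities over the 10 digit classes.
86 |     out = Model()(torch.randn(4, 1, 28, 28))
87 |     print(out.shape)         # torch.Size([4, 10])
88 |     print(out.exp().sum(1))  # rows sum to ~1, since forward returns log_softmax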
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | test-model
2 | .vim/*
3 |
4 | MNIST/*
5 |
6 | lightning_logs/*
7 |
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 | cover/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | .pybuilder/
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 |
92 | # pyenv
93 | # For a library or package, you might want to ignore these files since the code is
94 | # intended to run in multiple environments; otherwise, check them in:
95 | # .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # poetry
105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109 | #poetry.lock
110 |
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | #pdm.lock
114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115 | # in version control.
116 | # https://pdm.fming.dev/#use-with-ide
117 | .pdm.toml
118 |
119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120 | __pypackages__/
121 |
122 | # Celery stuff
123 | celerybeat-schedule
124 | celerybeat.pid
125 |
126 | # SageMath parsed files
127 | *.sage.py
128 |
129 | # Environments
130 | .env
131 | .venv
132 | env/
133 | venv/
134 | ENV/
135 | env.bak/
136 | venv.bak/
137 |
138 | # Spyder project settings
139 | .spyderproject
140 | .spyproject
141 |
142 | # Rope project settings
143 | .ropeproject
144 |
145 | # mkdocs documentation
146 | /site
147 |
148 | # mypy
149 | .mypy_cache/
150 | .dmypy.json
151 | dmypy.json
152 |
153 | # Pyre type checker
154 | .pyre/
155 |
156 | # pytype static type analyzer
157 | .pytype/
158 |
159 | # Cython debug symbols
160 | cython_debug/
161 |
162 | # PyCharm
163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165 | # and can be added to the global gitignore or merged into this file. For a more nuclear
166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167 | #.idea/
168 |
--------------------------------------------------------------------------------