--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other info into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, when collaborating on projects with platform-specific dependencies or dependencies
90 | # lacking cross-platform support, pipenv may install dependencies that don't work, or fail to
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | # PyTorch model weights
132 | *.pt
133 | weights/
133 |
--------------------------------------------------------------------------------
/modules/flash_attn.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import torch
4 | from einops import rearrange
5 | from torch import einsum
6 | from torch.autograd.function import Function
7 |
8 |
9 | EPSILON = 1e-6
10 | exists = lambda val: val is not None
11 | default = lambda val, d: val if exists(val) else d
12 |
13 | class FlashAttentionFunction(Function):
14 | @staticmethod
15 | @torch.no_grad()
16 | def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
17 | """Algorithm 2 in the paper"""
18 |
19 | device = q.device
20 | max_neg_value = -torch.finfo(q.dtype).max
21 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
22 |
23 | o = torch.zeros_like(q)
24 | all_row_sums = torch.zeros((*q.shape[:-1], 1), device=device)
25 | all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device=device)
26 |
27 | scale = q.shape[-1] ** -0.5
28 |
29 | if not exists(mask):
30 | mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
31 | else:
32 | mask = rearrange(mask, "b n -> b 1 1 n")
33 | mask = mask.split(q_bucket_size, dim=-1)
34 |
35 | row_splits = zip(
36 | q.split(q_bucket_size, dim=-2),
37 | o.split(q_bucket_size, dim=-2),
38 | mask,
39 | all_row_sums.split(q_bucket_size, dim=-2),
40 | all_row_maxes.split(q_bucket_size, dim=-2),
41 | )
42 |
43 | for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
44 | q_start_index = ind * q_bucket_size - qk_len_diff
45 |
46 | col_splits = zip(
47 | k.split(k_bucket_size, dim=-2),
48 | v.split(k_bucket_size, dim=-2),
49 | )
50 |
51 | for k_ind, (kc, vc) in enumerate(col_splits):
52 | k_start_index = k_ind * k_bucket_size
53 |
54 | attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
55 |
56 | if exists(row_mask):
57 | attn_weights.masked_fill_(~row_mask, max_neg_value)
58 |
59 |                 if causal and q_start_index < (k_start_index + k_bucket_size - 1):  # block may contain future keys
60 | causal_mask = torch.ones(
61 | (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
62 | ).triu(q_start_index - k_start_index + 1)
63 | attn_weights.masked_fill_(causal_mask, max_neg_value)
64 |
65 |                 block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)  # per-block row maxima
66 |                 attn_weights -= block_row_maxes  # shift for numerical stability before exp
67 |                 exp_weights = torch.exp(attn_weights)
68 |
69 | if exists(row_mask):
70 | exp_weights.masked_fill_(~row_mask, 0.0)
71 |
72 | block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
73 | min=EPSILON
74 | )
75 |
76 |                 new_row_maxes = torch.maximum(block_row_maxes, row_maxes)  # running row max for the online softmax
77 |
78 | exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc)
79 |
80 |                 exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)  # rescales previous accumulators
81 |                 exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)  # rescales this block
82 |
83 | new_row_sums = (
84 | exp_row_max_diff * row_sums
85 | + exp_block_row_max_diff * block_row_sums
86 | )
87 |
88 | oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
89 | (exp_block_row_max_diff / new_row_sums) * exp_values
90 | )
91 |
92 | row_maxes.copy_(new_row_maxes)
93 | row_sums.copy_(new_row_sums)
94 |
95 |         lse = all_row_sums.log() + all_row_maxes  # log-sum-exp, saved for the backward pass
96 |
97 | ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
98 | ctx.save_for_backward(q, k, v, o, lse)
99 |
100 | return o
101 |
102 | @staticmethod
103 | @torch.no_grad()
104 | def backward(ctx, do):
105 | """Algorithm 4 in the paper"""
106 |
107 | causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
108 | q, k, v, o, lse = ctx.saved_tensors
109 |
110 | device = q.device
111 |
112 | max_neg_value = -torch.finfo(q.dtype).max
113 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
114 |
115 | dq = torch.zeros_like(q)
116 | dk = torch.zeros_like(k)
117 | dv = torch.zeros_like(v)
118 |
119 | row_splits = zip(
120 | q.split(q_bucket_size, dim=-2),
121 | o.split(q_bucket_size, dim=-2),
122 | do.split(q_bucket_size, dim=-2),
123 | mask,
124 | lse.split(q_bucket_size, dim=-2),
125 | dq.split(q_bucket_size, dim=-2),
126 | )
127 |
128 | for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
129 | q_start_index = ind * q_bucket_size - qk_len_diff
130 |
131 | col_splits = zip(
132 | k.split(k_bucket_size, dim=-2),
133 | v.split(k_bucket_size, dim=-2),
134 | dk.split(k_bucket_size, dim=-2),
135 | dv.split(k_bucket_size, dim=-2),
136 | )
137 |
138 | for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
139 | k_start_index = k_ind * k_bucket_size
140 |
141 | attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
142 |
143 | if causal and q_start_index < (k_start_index + k_bucket_size - 1):
144 | causal_mask = torch.ones(
145 | (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
146 | ).triu(q_start_index - k_start_index + 1)
147 | attn_weights.masked_fill_(causal_mask, max_neg_value)
148 |
149 | p = torch.exp(attn_weights - lsec)
150 |
151 | if exists(row_mask):
152 | p.masked_fill_(~row_mask, 0.0)
153 |
154 | dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
155 | dp = einsum("... i d, ... j d -> ... i j", doc, vc)
156 |
157 |                 D = (doc * oc).sum(dim=-1, keepdims=True)  # the D_i term from Algorithm 4
158 |                 ds = p * scale * (dp - D)  # gradient of the pre-softmax attention scores
159 |
160 | dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
161 | dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)
162 |
163 | dqc.add_(dq_chunk)
164 | dkc.add_(dk_chunk)
165 | dvc.add_(dv_chunk)
166 |
167 | return dq, dk, dv, None, None, None, None
168 |
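169 | # Usage sketch (an illustrative addition, not part of the original module):
170 | # drive the autograd Function directly through .apply(). The shapes, bucket
171 | # sizes, and causal flag below are assumptions chosen for demonstration.
172 | if __name__ == "__main__":
173 |     q = torch.randn(2, 8, 1024, 64, requires_grad=True)
174 |     k = torch.randn(2, 8, 1024, 64, requires_grad=True)
175 |     v = torch.randn(2, 8, 1024, 64, requires_grad=True)
176 |     out = FlashAttentionFunction.apply(q, k, v, None, True, 512, 512)
177 |     out.sum().backward()  # exercises the custom backward (Algorithm 4)
178 |     print(out.shape, q.grad.shape)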
--------------------------------------------------------------------------------
/modules/safechecker.py:
--------------------------------------------------------------------------------
1 | # this code is adapted from the script contributed by anon from /h/
2 | # modified, from https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/safe.py
3 |
4 | import io
5 | import pickle
6 | import collections
7 | import sys
8 | import traceback
9 |
10 | import torch
11 | import numpy
12 | import _codecs
13 | import zipfile
14 | import re
15 |
16 |
17 | # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
18 | TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
19 |
20 |
21 | def encode(*args):
22 | out = _codecs.encode(*args)
23 | return out
24 |
25 |
26 | class RestrictedUnpickler(pickle.Unpickler):
27 | extra_handler = None
28 |
29 | def persistent_load(self, saved_id):
30 | assert saved_id[0] == 'storage'
31 |         return TypedStorage()  # empty stand-in; tensor data itself is never unpickled
32 |
33 | def find_class(self, module, name):
34 | if self.extra_handler is not None:
35 | res = self.extra_handler(module, name)
36 | if res is not None:
37 | return res
38 |
39 | if module == 'collections' and name == 'OrderedDict':
40 | return getattr(collections, name)
41 | if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
42 | return getattr(torch._utils, name)
43 | if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
44 | return getattr(torch, name)
45 | if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
46 | return getattr(torch.nn.modules.container, name)
47 | if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
48 | return getattr(numpy.core.multiarray, name)
49 | if module == 'numpy' and name in ['dtype', 'ndarray']:
50 | return getattr(numpy, name)
51 | if module == '_codecs' and name == 'encode':
52 | return encode
53 | if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
54 | import pytorch_lightning.callbacks
55 | return pytorch_lightning.callbacks.model_checkpoint
56 | if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
57 | import pytorch_lightning.callbacks.model_checkpoint
58 | return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
59 | if module == "__builtin__" and name == 'set':
60 | return set
61 |
62 | # Forbid everything else.
63 | raise Exception(f"global '{module}/{name}' is forbidden")
64 |
65 |
62 | # Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
63 | allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
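64 | 
65 | # Usage sketch (an illustrative addition, not part of the original file):
66 | # run the restricted unpickler over the data.pkl entry of a torch zip
67 | # checkpoint. The checkpoint path is an assumption; a real loader would also
68 | # check every archive entry name against allowed_zip_names_re first.
69 | def restricted_load_sketch(checkpoint_path="model.ckpt"):
70 |     with zipfile.ZipFile(checkpoint_path) as z:
71 |         pkl_name = next(n for n in z.namelist() if n.endswith("data.pkl"))
72 |         # persistent_load() returns empty TypedStorage stand-ins, so this
73 |         # vets the pickle's globals without materializing any weights
74 |         return RestrictedUnpickler(io.BytesIO(z.read(pkl_name))).load()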