├── tests ├── __init__.py ├── test_super_resolution.py ├── test_tokenizer.py ├── test_image_prompts.py ├── conftest.py ├── test_dalle.py └── test_vae.py ├── requirements-test.txt ├── pics ├── rainbow-full.png ├── rainbow-cherry-pick.png ├── rainbow-super-resolution.png ├── anime-girl-super-resolution.png └── russian-temple-image-prompt.png ├── requirements.txt ├── rudalle_paddle ├── packages │ ├── einops │ │ ├── _torch_specific.py │ │ ├── __init__.py │ │ ├── layers │ │ │ ├── chainer.py │ │ │ ├── torch.py │ │ │ ├── paddle.py │ │ │ ├── keras.py │ │ │ ├── gluon.py │ │ │ ├── tensorflow.py │ │ │ ├── __init__.py │ │ │ └── _weighted_einsum.py │ │ ├── parsing.py │ │ ├── _backends.py │ │ └── einops.py │ ├── __init__.py │ └── transformers │ │ └── __init__.py ├── __init__.py ├── vae │ ├── vqgan.gumbelf8-sber.config.yml │ ├── __init__.py │ └── model.py ├── utils.py ├── ruclip │ ├── __init__.py │ ├── processor.py │ └── model.py ├── realesrgan │ ├── __init__.py │ ├── model.py │ ├── utils.py │ ├── rrdbnet_arch.py │ └── arch_util.py ├── dalle │ ├── utils.py │ ├── fp16.py │ ├── image_attention.py │ ├── __init__.py │ ├── model.py │ └── transformer.py ├── image_prompts.py ├── tokenizer.py ├── pipelines.py └── future.py ├── setup.cfg ├── .gitlab-ci.yml ├── download.py ├── .pre-commit-config.yaml ├── convert.py ├── setup.py ├── .gitignore ├── README.md └── LICENSE.txt /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | pytest 3 | pytest-cov 4 | pre-commit 5 | -------------------------------------------------------------------------------- /pics/rainbow-full.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AgentMaker/ru-dalle-paddle/HEAD/pics/rainbow-full.png -------------------------------------------------------------------------------- /pics/rainbow-cherry-pick.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AgentMaker/ru-dalle-paddle/HEAD/pics/rainbow-cherry-pick.png -------------------------------------------------------------------------------- /pics/rainbow-super-resolution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AgentMaker/ru-dalle-paddle/HEAD/pics/rainbow-super-resolution.png -------------------------------------------------------------------------------- /pics/anime-girl-super-resolution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AgentMaker/ru-dalle-paddle/HEAD/pics/anime-girl-super-resolution.png -------------------------------------------------------------------------------- /pics/russian-temple-image-prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AgentMaker/ru-dalle-paddle/HEAD/pics/russian-temple-image-prompt.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | more_itertools~=8.10.0 2 | youtokentome~=1.0.6 3 | omegaconf>=2.0.0 4 | paddlepaddle-gpu~=2.2.0 5 | paddleclip 6 | huggingface-hub~=0.1.2 7 | matplotlib 
-------------------------------------------------------------------------------- /rudalle_paddle/packages/einops/_torch_specific.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Specialization of einops for torch. 4 | Unfortunately, torch's jit scripting mechanism 5 | """ 6 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | from os.path import dirname, abspath 4 | 5 | dirname = dirname(abspath(__file__)) 6 | sys.path.insert(0, dirname) 7 | 8 | __all__ = [] 9 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pep8] 2 | max-line-length = 120 3 | exclude = .tox,*migrations*,.json 4 | 5 | [flake8] 6 | max-line-length = 120 7 | exclude = .tox,*migrations*,.json 8 | 9 | [autopep8-wrapper] 10 | exclude = .tox,*migrations*,.json 11 | 12 | [check-docstring-first] 13 | exclude = .tox,*migrations*,.json 14 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/einops/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = 'Alex Rogozhnikov' 3 | __version__ = '0.3.2' 4 | 5 | 6 | from .einops import rearrange, reduce, repeat, parse_shape, asnumpy, EinopsError 7 | 8 | 9 | __all__ = ['rearrange', 'reduce', 'repeat', 'parse_shape', 'asnumpy', 'EinopsError'] 10 | -------------------------------------------------------------------------------- /tests/test_super_resolution.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from rudalle_paddle.pipelines import super_resolution 3 | 4 | 5 | def test_super_resolution(sample_image, realesrgan): 6 | img = sample_image.copy() 7 | img = img.resize((32, 32)) 8 | sr_img = super_resolution([img], realesrgan)[0] 9 | assert sr_img.size[0] == 32*2 10 | assert sr_img.size[1] == 32*2 11 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | stages: 2 | - test 3 | 4 | all_branch_test: 5 | stage: test 6 | tags: 7 | - docker 8 | image: python:3.9 9 | script: 10 | - apt-get update 11 | - apt-get install ffmpeg libsm6 libxext6 -y 12 | - pip install cython 13 | - pip install -r requirements-test.txt --no-cache-dir 14 | - pip install codecov 15 | - pytest --cov=rudalle_paddle tests/ 16 | - bash <(curl -s https://codecov.io/bash) -t $CODECOV_TOKEN 17 | except: 18 | - tags 19 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from rudalle_paddle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip 3 | 4 | get_rudalle_model('Malevich-paddle', True, device='cuda', cache_dir='pretrained_models') 5 | get_tokenizer(cache_dir='pretrained_models') 6 | get_vae(name='vqgan.gumbelf8-sber.paddle', pretrained=True, cache_dir='pretrained_models') 7 | get_realesrgan('x2-paddle', cache_dir='pretrained_models') 8 | get_realesrgan('x4-paddle', cache_dir='pretrained_models') 9 | 
get_realesrgan('x8-paddle', cache_dir='pretrained_models') 10 | get_ruclip('ruclip-vit-base-patch32-v5-paddle', cache_dir='pretrained_models') 11 | -------------------------------------------------------------------------------- /rudalle_paddle/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .future import * # noqa 3 | from .packages import * # noqa 4 | from .vae import get_vae 5 | from .dalle import get_rudalle_model 6 | from .tokenizer import get_tokenizer 7 | from .realesrgan import get_realesrgan 8 | from .ruclip import get_ruclip 9 | from . import vae, dalle, tokenizer, realesrgan, pipelines, ruclip, image_prompts 10 | 11 | 12 | __all__ = [ 13 | 'get_vae', 14 | 'get_rudalle_model', 15 | 'get_tokenizer', 16 | 'get_realesrgan', 17 | 'get_ruclip', 18 | 'vae', 19 | 'dalle', 20 | 'ruclip', 21 | 'tokenizer', 22 | 'realesrgan', 23 | 'pipelines', 24 | 'image_prompts', 25 | ] 26 | 27 | __version__ = '0.0.1-rc1' 28 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pytest 3 | 4 | 5 | @pytest.mark.parametrize('text, text_seq_length, bpe_dropout', [ 6 | ('hello, how are you?', 128, 0.1), 7 | ('hello, how are you?', 128, 0.5), 8 | ('hello, how are you?', 128, 1.0), 9 | ('hello ... how are you ?', 256, 1.0), 10 | ('a person standing at a table with bottles of win', 64, 0.5), 11 | ('привет как дела???', 76, 0.0), 12 | ('клип на русском языке :)', 76, 0.1), 13 | ]) 14 | def test_encode_decode_text_yttm(yttm_tokenizer, text, text_seq_length, bpe_dropout): 15 | tokens = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length, bpe_dropout=bpe_dropout) 16 | decoded_text = yttm_tokenizer.decode_text(tokens) 17 | assert text == decoded_text 18 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.2.3 4 | hooks: 5 | - id: check-docstring-first 6 | stages: 7 | - commit 8 | - push 9 | - id: check-merge-conflict 10 | stages: 11 | - push 12 | - id: double-quote-string-fixer 13 | stages: 14 | - commit 15 | - push 16 | - id: fix-encoding-pragma 17 | stages: 18 | - commit 19 | - push 20 | - id: flake8 21 | args: ['--config=setup.cfg'] 22 | stages: 23 | - commit 24 | - push 25 | - repo: https://github.com/pre-commit/mirrors-autopep8 26 | rev: v1.4.4 27 | hooks: 28 | - id: autopep8 29 | stages: 30 | - commit 31 | - push 32 | -------------------------------------------------------------------------------- /rudalle_paddle/vae/vqgan.gumbelf8-sber.config.yml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.vqgan.GumbelVQ 4 | params: 5 | kl_weight: 1.0e-08 6 | embed_dim: 256 7 | n_embed: 8192 8 | monitor: val/rec_loss 9 | temperature_scheduler_config: 10 | target: taming.lr_scheduler.LambdaWarmUpCosineScheduler 11 | params: 12 | warm_up_steps: 0 13 | max_decay_steps: 1000001 14 | lr_start: 0.9 15 | lr_max: 0.9 16 | lr_min: 1.0e-06 17 | ddconfig: 18 | double_z: false 19 | z_channels: 256 20 | resolution: 256 21 | in_channels: 3 22 | out_ch: 3 23 | ch: 128 24 | ch_mult: 25 | - 1 26 | - 1 27 | - 2 28 | - 4 29 | num_res_blocks: 2 30 | 
attn_resolutions: 31 | - 32 32 | dropout: 0.0 33 | lossconfig: 34 | target: taming.modules.losses.vqperceptual.DummyLoss 35 | -------------------------------------------------------------------------------- /tests/test_image_prompts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pytest 3 | 4 | from rudalle_paddle.image_prompts import ImagePrompts 5 | 6 | 7 | @pytest.mark.parametrize('borders, crop_first', [ 8 | ({'up': 4, 'right': 0, 'left': 0, 'down': 0}, False), 9 | ({'up': 4, 'right': 0, 'left': 0, 'down': 0}, True), 10 | ({'up': 4, 'right': 3, 'left': 3, 'down': 3}, False) 11 | ]) 12 | def test_image_prompts(sample_image, vae, borders, crop_first): 13 | img = sample_image.copy() 14 | img = img.resize((256, 256)) 15 | image_prompt = ImagePrompts(img, borders, vae, crop_first=crop_first) 16 | if crop_first: 17 | assert image_prompt.image_prompts.shape[1] == borders['up'] * 32 18 | assert len(image_prompt.image_prompts_idx) == borders['up'] * 32 19 | else: 20 | assert image_prompt.image_prompts.shape[1] == 32 * 32 21 | assert len(image_prompt.image_prompts_idx) == (borders['up'] + borders['down']) * 32 \ 22 | + (borders['left'] + borders['right']) * (32 - borders['up'] - borders['down']) 23 | -------------------------------------------------------------------------------- /rudalle_paddle/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import random 4 | 5 | import paddle 6 | import numpy as np 7 | 8 | from PIL import Image 9 | 10 | 11 | def seed_everything(seed): 12 | random.seed(seed) 13 | os.environ['PYTHONHASHSEED'] = str(seed) 14 | np.random.seed(seed) 15 | paddle.seed(seed) 16 | 17 | 18 | def paddle_tensors_to_pil_list(input_images): 19 | out_images = [] 20 | for in_image in input_images: 21 | in_image = np.clip(in_image.cpu().detach().numpy() * 255, 0, 255) 22 | out_image = Image.fromarray(in_image.transpose([1, 2, 0]).astype(np.uint8)) 23 | out_images.append(out_image) 24 | return out_images 25 | 26 | 27 | def pil_list_to_paddle_tensors(pil_images): 28 | result = [] 29 | for pil_image in pil_images: 30 | image = np.array(pil_image, dtype=np.uint8) 31 | image = paddle.to_tensor(image).astype(paddle.int64) 32 | image = image.transpose([2, 0, 1]).unsqueeze(0) 33 | result.append(image) 34 | return paddle.concat(result, axis=0) 35 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import io 3 | from os.path import abspath, dirname 4 | 5 | import PIL 6 | import pytest 7 | import requests 8 | 9 | from rudalle_paddle import get_tokenizer, get_rudalle_model, get_vae, get_realesrgan 10 | 11 | 12 | TEST_ROOT = dirname(abspath(__file__)) 13 | 14 | 15 | @pytest.fixture(scope='module') 16 | def realesrgan(): 17 | realesrgan = get_realesrgan('x2-paddle', device='cpu') 18 | yield realesrgan 19 | 20 | 21 | @pytest.fixture(scope='module') 22 | def vae(): 23 | vae = get_vae(pretrained=False) 24 | yield vae 25 | 26 | 27 | @pytest.fixture(scope='module') 28 | def yttm_tokenizer(): 29 | tokenizer = get_tokenizer() 30 | yield tokenizer 31 | 32 | 33 | @pytest.fixture(scope='module') 34 | def sample_image(): 35 | url = 'https://cdn.kqed.org/wp-content/uploads/sites/12/2013/12/rudolph.png' 36 | resp = requests.get(url) 37 | resp.raise_for_status() 38 | image = 
PIL.Image.open(io.BytesIO(resp.content)) 39 | yield image 40 | 41 | 42 | @pytest.fixture(scope='module') 43 | def small_dalle(): 44 | model = get_rudalle_model('small', pretrained=False, fp16=False, device='cpu') 45 | return model 46 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/einops/layers/chainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import chainer 3 | 4 | from . import RearrangeMixin, ReduceMixin 5 | from ._weighted_einsum import WeightedEinsumMixin 6 | 7 | __author__ = 'Alex Rogozhnikov' 8 | 9 | 10 | class Rearrange(RearrangeMixin, chainer.Link): 11 | def __call__(self, x): 12 | return self._apply_recipe(x) 13 | 14 | 15 | class Reduce(ReduceMixin, chainer.Link): 16 | def __call__(self, x): 17 | return self._apply_recipe(x) 18 | 19 | 20 | class WeightedEinsum(WeightedEinsumMixin, chainer.Link): 21 | def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): 22 | uniform = chainer.variable.initializers.Uniform 23 | with self.init_scope(): 24 | self.weight = chainer.variable.Parameter(uniform(weight_bound), weight_shape) 25 | if bias_shape is not None: 26 | self.bias = chainer.variable.Parameter(uniform(bias_bound), bias_shape) 27 | else: 28 | self.bias = None 29 | 30 | def __call__(self, input): 31 | result = chainer.functions.einsum(self.einsum_pattern, input, self.weight) 32 | if self.bias is not None: 33 | result = result + self.bias 34 | return result 35 | -------------------------------------------------------------------------------- /rudalle_paddle/ruclip/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | from .model import CLIPModel 5 | from huggingface_hub import hf_hub_url, cached_download 6 | 7 | from .processor import RuCLIPProcessor 8 | 9 | MODELS = { 10 | 'ruclip-vit-base-patch32-v5': dict( 11 | repo_id='sberbank-ai/ru-clip', 12 | filenames=[ 13 | 'bpe.model', 'config.json', 'pytorch_model.bin' 14 | ] 15 | ), 16 | 'ruclip-vit-base-patch32-v5-paddle': dict( 17 | repo_id='HighCWu/rudalle-paddle-utils', 18 | filenames=[ 19 | 'bpe.model', 'config.json', 'ruclip_paddle.pdparams' 20 | ] 21 | ), 22 | } 23 | 24 | 25 | def get_ruclip(name, cache_dir='/tmp/rudalle'): 26 | assert name in MODELS 27 | config = MODELS[name] 28 | repo_id = config['repo_id'] 29 | cache_dir = os.path.join(cache_dir, name) 30 | for filename in config['filenames']: 31 | config_file_url = hf_hub_url(repo_id=repo_id, filename=f'{name}/{filename}') 32 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename) 33 | ruclip = CLIPModel.from_pretrained(cache_dir) 34 | ruclip_processor = RuCLIPProcessor.from_pretrained(cache_dir) 35 | print('ruclip --> ready') 36 | return ruclip, ruclip_processor 37 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/einops/layers/torch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | 4 | from . 
import RearrangeMixin, ReduceMixin 5 | from ._weighted_einsum import WeightedEinsumMixin 6 | 7 | __author__ = 'Alex Rogozhnikov' 8 | 9 | 10 | class Rearrange(RearrangeMixin, torch.nn.Module): 11 | def forward(self, input): 12 | return self._apply_recipe(input) 13 | 14 | 15 | class Reduce(ReduceMixin, torch.nn.Module): 16 | def forward(self, input): 17 | return self._apply_recipe(input) 18 | 19 | 20 | class WeightedEinsum(WeightedEinsumMixin, torch.nn.Module): 21 | def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): 22 | self.weight = torch.nn.Parameter(torch.zeros(weight_shape).uniform_(-weight_bound, weight_bound), 23 | requires_grad=True) 24 | if bias_shape is not None: 25 | self.bias = torch.nn.Parameter(torch.zeros(bias_shape).uniform_(-bias_bound, bias_bound), 26 | requires_grad=True) 27 | else: 28 | self.bias = None 29 | 30 | def forward(self, input): 31 | result = torch.einsum(self.einsum_pattern, input, self.weight) 32 | if self.bias is not None: 33 | result += self.bias 34 | return result 35 | -------------------------------------------------------------------------------- /tests/test_dalle.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import paddle 3 | import pytest 4 | 5 | from .test_vae import preprocess 6 | 7 | 8 | @pytest.mark.parametrize('text', [ 9 | 'мальчик играет с оленем', 10 | ]) 11 | def test_forward_step_and_criterion(text, sample_image, yttm_tokenizer, vae, small_dalle): 12 | bs = 4 13 | text_seq_length = small_dalle.get_param('text_seq_length') 14 | total_seq_length = small_dalle.get_param('total_seq_length') 15 | device = small_dalle.get_param('device') 16 | 17 | img = sample_image.copy() 18 | img = preprocess(img, target_image_size=256) 19 | images = img.tile([bs, 1, 1, 1]).to(device) 20 | 21 | text = text.lower().strip() 22 | text_input_ids = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length) 23 | text_input_ids = text_input_ids.unsqueeze(0).tile([bs, 1]).to(device) 24 | 25 | attention_mask = paddle.tril(paddle.ones((bs, 1, total_seq_length, total_seq_length))).to(device) 26 | with paddle.no_grad(): 27 | image_input_ids = vae.get_codebook_indices(images) 28 | input_ids = paddle.concat((text_input_ids, image_input_ids), axis=1).to(device) 29 | loss, loss_values = small_dalle.forward(input_ids, attention_mask, return_loss=True) 30 | assert type(loss.detach().item()) == float 31 | assert type(loss_values) == dict 32 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/einops/layers/paddle.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import paddle 3 | 4 | from . 
import RearrangeMixin, ReduceMixin 5 | from ._weighted_einsum import WeightedEinsumMixin 6 | 7 | __author__ = 'Wu Hecong' 8 | 9 | 10 | class Rearrange(RearrangeMixin, paddle.nn.Layer): 11 | def forward(self, input): 12 | return self._apply_recipe(input) 13 | 14 | 15 | class Reduce(ReduceMixin, paddle.nn.Layer): 16 | def forward(self, input): 17 | return self._apply_recipe(input) 18 | 19 | 20 | class WeightedEinsum(WeightedEinsumMixin, paddle.nn.Layer): 21 | def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): 22 | self.weight = self.create_parameter(weight_shape, 23 | default_initializer=paddle.initializer.Uniform(-weight_bound, weight_bound)) 24 | if bias_shape is not None: 25 | self.bias = self.create_parameter(bias_shape, 26 | default_initializer=paddle.initializer.Uniform(-bias_bound, bias_bound)) 27 | else: 28 | self.bias = None 29 | 30 | def forward(self, input): 31 | result = paddle.einsum(self.einsum_pattern, input, self.weight) 32 | if self.bias is not None: 33 | result += self.bias 34 | return result 35 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/einops/layers/keras.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tensorflow.keras.layers import Layer 3 | 4 | from .._backends import UnknownSize 5 | from . import RearrangeMixin, ReduceMixin 6 | 7 | __author__ = 'Alex Rogozhnikov' 8 | 9 | 10 | def _compute_output_shape(self, input_shape): 11 | input_shape = tuple(UnknownSize() if d is None else int(d) for d in input_shape) 12 | init_shapes, reduced_axes, axes_reordering, added_axes, final_shape = \ 13 | self.recipe().reconstruct_from_shape(input_shape) 14 | final_shape = tuple(None if isinstance(d, UnknownSize) else int(d) for d in final_shape) 15 | return final_shape 16 | 17 | 18 | class Rearrange(RearrangeMixin, Layer): 19 | def compute_output_shape(self, input_shape): 20 | return _compute_output_shape(self, input_shape) 21 | 22 | def call(self, inputs): 23 | return self._apply_recipe(inputs) 24 | 25 | def get_config(self): 26 | return {'pattern': self.pattern, **self.axes_lengths} 27 | 28 | 29 | class Reduce(ReduceMixin, Layer): 30 | def compute_output_shape(self, input_shape): 31 | return _compute_output_shape(self, input_shape) 32 | 33 | def call(self, inputs): 34 | return self._apply_recipe(inputs) 35 | 36 | def get_config(self): 37 | return {'pattern': self.pattern, 'reduction': self.reduction, **self.axes_lengths} 38 | 39 | 40 | keras_custom_objects = {Rearrange.__name__: Rearrange, Reduce.__name__: Reduce} 41 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/einops/layers/gluon.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import mxnet 3 | 4 | from . 
import RearrangeMixin, ReduceMixin 5 | from ._weighted_einsum import WeightedEinsumMixin 6 | 7 | __author__ = 'Alex Rogozhnikov' 8 | 9 | 10 | class Rearrange(RearrangeMixin, mxnet.gluon.HybridBlock): 11 | def hybrid_forward(self, F, x): 12 | return self._apply_recipe(x) 13 | 14 | 15 | class Reduce(ReduceMixin, mxnet.gluon.HybridBlock): 16 | def hybrid_forward(self, F, x): 17 | return self._apply_recipe(x) 18 | 19 | 20 | class WeightedEinsum(WeightedEinsumMixin, mxnet.gluon.HybridBlock): 21 | def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): 22 | with self.name_scope(): 23 | 24 | self.weight = self.params.get(name='weight', shape=weight_shape, 25 | init=mxnet.initializer.Uniform(weight_bound), 26 | ) 27 | if bias_shape is not None: 28 | self.bias = self.params.get(name='bias', shape=bias_shape, 29 | init=mxnet.initializer.Uniform(bias_bound), 30 | ) 31 | else: 32 | self.bias = None 33 | 34 | def hybrid_forward(self, F, x, *args, **kwargs): 35 | result = mxnet.np.einsum(self.einsum_pattern, x, self.weight.data()) 36 | if self.bias is not None: 37 | result += self.bias.data() 38 | return result 39 | -------------------------------------------------------------------------------- /rudalle_paddle/vae/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from os.path import dirname, abspath, join 3 | 4 | import paddle 5 | from huggingface_hub import hf_hub_url, cached_download 6 | from omegaconf import OmegaConf 7 | 8 | from .model import VQGanGumbelVAE 9 | 10 | 11 | MODELS = { 12 | 'vqgan.gumbelf8-sber': dict( 13 | repo_id='shonenkov/rudalle-utils', 14 | filename='vqgan.gumbelf8-sber.model.ckpt', 15 | ), 16 | 'vqgan.gumbelf8-sber.paddle': dict( 17 | repo_id='HighCWu/rudalle-paddle-utils', 18 | filename='vqgan.gumbelf8-sber.model.pdckpt', 19 | ), 20 | } 21 | 22 | 23 | def get_vae(name='vqgan.gumbelf8-sber.paddle', pretrained=True, cache_dir='/tmp/rudalle'): 24 | # TODO 25 | config = OmegaConf.load(join(dirname(abspath(__file__)), 'vqgan.gumbelf8-sber.config.yml')) 26 | vae = VQGanGumbelVAE(config) 27 | config = MODELS[name] 28 | if pretrained: 29 | cache_dir = join(cache_dir, name) 30 | config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename']) 31 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename']) 32 | if config['filename'][-5:] == '.ckpt': 33 | VQGanGumbelVAE.convert(join(cache_dir, config['filename'])) 34 | config['filename'] = config['filename'][:-5] + '.pdckpt' 35 | checkpoint = paddle.load(join(cache_dir, config['filename'])) 36 | vae.model.set_state_dict(checkpoint['state_dict']) 37 | print('vae --> ready') 38 | return vae 39 | -------------------------------------------------------------------------------- /rudalle_paddle/realesrgan/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import paddle 4 | from huggingface_hub import hf_hub_url, cached_download 5 | 6 | from .model import RealESRGAN 7 | 8 | 9 | MODELS = { 10 | 'x2': dict( 11 | scale=2, 12 | repo_id='shonenkov/rudalle-utils', 13 | filename='RealESRGAN_x2.pth', 14 | ), 15 | 'x4': dict( 16 | scale=4, 17 | repo_id='shonenkov/rudalle-utils', 18 | filename='RealESRGAN_x4.pth', 19 | ), 20 | 'x8': dict( 21 | scale=8, 22 | repo_id='shonenkov/rudalle-utils', 23 | filename='RealESRGAN_x8.pth', 24 | ), 25 | 'x2-paddle': dict( 26 | scale=2, 27 | repo_id='HighCWu/rudalle-paddle-utils', 28 
| filename='RealESRGAN_x2.pdparams', 29 | ), 30 | 'x4-paddle': dict( 31 | scale=4, 32 | repo_id='HighCWu/rudalle-paddle-utils', 33 | filename='RealESRGAN_x4.pdparams', 34 | ), 35 | 'x8-paddle': dict( 36 | scale=8, 37 | repo_id='HighCWu/rudalle-paddle-utils', 38 | filename='RealESRGAN_x8.pdparams', 39 | ), 40 | } 41 | 42 | 43 | def get_realesrgan(name, device='cpu', cache_dir='/tmp/rudalle'): 44 | assert name in MODELS 45 | paddle.set_device(device) 46 | config = MODELS[name] 47 | model = RealESRGAN(device, config['scale']) 48 | cache_dir = os.path.join(cache_dir, name) 49 | config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename']) 50 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename']) 51 | model.load_weights(os.path.join(cache_dir, config['filename'])) 52 | print(f'{name} --> ready') 53 | return model 54 | -------------------------------------------------------------------------------- /rudalle_paddle/dalle/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import paddle 3 | 4 | 5 | def exists(val): 6 | return val is not None 7 | 8 | 9 | def is_empty(t): 10 | return all([s == 0 for s in t.shape]) 11 | 12 | 13 | def ensure_divisibility(numerator, denominator): 14 | """Ensure that numerator is divisible by the denominator.""" 15 | assert numerator % denominator == 0, '{} is not divisible by {}'.format( 16 | numerator, denominator) 17 | 18 | 19 | def divide(numerator, denominator): 20 | """Ensure that numerator is divisible by the denominator and return 21 | the division value.""" 22 | ensure_divisibility(numerator, denominator) 23 | return numerator // denominator 24 | 25 | 26 | def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): 27 | """ 28 | Split a tensor along its last dimension. 29 | Arguments: 30 | tensor: input tensor. 31 | num_partitions: number of partitions to split the tensor 32 | contiguous_split_chunks: If True, make each chunk contiguous 33 | in memory. 34 | """ 35 | # Get the size and dimension. 36 | last_dim = tensor.dim() - 1 37 | # Split. 38 | tensor_list = paddle.split(tensor, num_partitions, axis=last_dim) 39 | # Note: paddle.split does not create contiguous tensors by default. 40 | if contiguous_split_chunks: 41 | return tuple(chunk for chunk in tensor_list) 42 | return tensor_list 43 | 44 | 45 | def init_method_normal(std=0.02): 46 | """Init method based on normal distribution. 47 | 48 | This is only used for embeddings. The transformer has its 49 | own initializer. 
50 | """ 51 | def init_(tensor): 52 | return paddle.nn.initializer.Normal(mean=0.0, std=std)(tensor) 53 | return init_ 54 | -------------------------------------------------------------------------------- /tests/test_vae.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pytest 3 | import paddle 4 | import paddle.vision.transforms as T 5 | import paddle.vision.transforms.functional as TF 6 | 7 | 8 | @pytest.mark.parametrize('target_image_size', [128, 192, 256]) 9 | def test_decode_vae(vae, sample_image, target_image_size): 10 | img = sample_image.copy() 11 | img = preprocess(img, target_image_size=target_image_size) 12 | with paddle.no_grad(): 13 | img_seq = vae.get_codebook_indices(img) 14 | out_img = vae.decode(img_seq) 15 | assert list(out_img.shape) == [1, 3, target_image_size, target_image_size] 16 | 17 | 18 | @pytest.mark.parametrize('target_image_size', [128, 192, 256]) 19 | def test_reconstruct_vae(vae, sample_image, target_image_size): 20 | img = sample_image.copy() 21 | with paddle.no_grad(): 22 | x_vqgan = preprocess(img, target_image_size=target_image_size) 23 | output = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), vae.model) 24 | assert list(output.shape) == [1, 3, target_image_size, target_image_size] 25 | 26 | 27 | def preprocess(img, target_image_size=256): 28 | s = min(img.size) 29 | if s < target_image_size: 30 | raise ValueError(f'min dim for image {s} < {target_image_size}') 31 | r = target_image_size / s 32 | s = (round(r * img.size[1]), round(r * img.size[0])) 33 | img = TF.resize(img, s, interpolation='lanczos') 34 | img = TF.center_crop(img, output_size=2 * [target_image_size]) 35 | img = paddle.unsqueeze(T.ToTensor()(img), 0) 36 | return img 37 | 38 | 39 | def preprocess_vqgan(x): 40 | x = 2.*x - 1. 
41 | return x 42 | 43 | 44 | def reconstruct_with_vqgan(x, model): 45 | z, _, [_, _, _] = model.encode(x) 46 | print(f'VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}') 47 | xrec = model.decode(z) 48 | return xrec 49 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import glob 4 | import shutil 5 | 6 | from rudalle_paddle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip 7 | 8 | target_dir = 'rudalle-paddle-utils' 9 | os.makedirs(target_dir, exist_ok=True) 10 | 11 | get_rudalle_model('Malevich', True, device='cuda', cache_dir='/tmp/rudalle') 12 | get_tokenizer(cache_dir='/tmp/rudalle') 13 | get_vae(name='vqgan.gumbelf8-sber', pretrained=True, cache_dir='/tmp/rudalle') 14 | get_realesrgan('x2', cache_dir='/tmp/rudalle') 15 | get_realesrgan('x4', cache_dir='/tmp/rudalle') 16 | get_realesrgan('x8', cache_dir='/tmp/rudalle') 17 | 18 | files = glob.glob('/tmp/rudalle/**/*.pkl', recursive=True) + \ 19 | glob.glob('/tmp/rudalle/**/*.json', recursive=True) + \ 20 | glob.glob('/tmp/rudalle/**/*.model', recursive=True) + \ 21 | glob.glob('/tmp/rudalle/**/*.pdckpt', recursive=True) + \ 22 | glob.glob('/tmp/rudalle/**/*.pdparams', recursive=True) 23 | 24 | for path in files: 25 | name = os.path.basename(path) 26 | shutil.move(path, os.path.join(target_dir, name)) 27 | 28 | target_dir = os.path.join(target_dir, 'ruclip-vit-base-patch32-v5-paddle') 29 | os.makedirs(target_dir, exist_ok=True) 30 | 31 | get_ruclip('ruclip-vit-base-patch32-v5', cache_dir='/tmp') 32 | 33 | files = glob.glob('/tmp/ruclip-vit-base-patch32-v5/**/*.pkl', recursive=True) + \ 34 | glob.glob('/tmp/ruclip-vit-base-patch32-v5/**/*.json', recursive=True) + \ 35 | glob.glob('/tmp/ruclip-vit-base-patch32-v5/**/*.model', recursive=True) + \ 36 | glob.glob('/tmp/ruclip-vit-base-patch32-v5/**/*.pdckpt', recursive=True) + \ 37 | glob.glob('/tmp/ruclip-vit-base-patch32-v5/**/*.pdparams', recursive=True) 38 | 39 | for path in files: 40 | name = os.path.basename(path) 41 | shutil.move(path, os.path.join(target_dir, name)) 42 | -------------------------------------------------------------------------------- /rudalle_paddle/dalle/fp16.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import paddle 3 | from paddle import nn 4 | 5 | 6 | def conversion_helper(val, conversion): 7 | """Apply conversion to val. 
Recursively apply conversion if `val` is a nested tuple/list structure.""" 8 | if not isinstance(val, (tuple, list)): 9 | return conversion(val) 10 | rtn = [conversion_helper(v, conversion) for v in val] 11 | if isinstance(val, tuple): 12 | rtn = tuple(rtn) 13 | return rtn 14 | 15 | 16 | def fp32_to_fp16(val): 17 | """Convert fp32 `val` to fp16""" 18 | def half_conversion(val): 19 | if val.dtype == paddle.float32: 20 | val = val.astype(paddle.float16) 21 | return val 22 | return conversion_helper(val, half_conversion) 23 | 24 | 25 | def fp16_to_fp32(val): 26 | """Convert fp16 `val` to fp32""" 27 | def float_conversion(val): 28 | if isinstance(val, paddle.Tensor) and val.dtype == paddle.float16: 29 | val = val.astype(paddle.float32) 30 | return val 31 | return conversion_helper(val, float_conversion) 32 | 33 | 34 | class FP16Module(nn.Layer): 35 | def __init__(self, module): 36 | super(FP16Module, self).__init__() 37 | self.add_sublayer('module', module.to(module.device, dtype='float16')) 38 | 39 | def forward(self, *inputs, **kwargs): 40 | return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) 41 | 42 | def state_dict(self, destination=None, include_sublayers=True): 43 | return self.module.state_dict(destination, include_sublayers) 44 | 45 | def set_state_dict(self, state_dict, use_structured_name=True): 46 | self.module.set_state_dict(state_dict, use_structured_name) 47 | 48 | def get_param(self, item): 49 | return self.module.get_param(item) 50 | 51 | def to(self, device, *args, **kwargs): 52 | self.module = self.module.to(device) 53 | return super().to(device, *args, **kwargs) 54 | -------------------------------------------------------------------------------- /rudalle_paddle/dalle/image_attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import paddle 4 | 5 | 6 | def _init_mask(text_tokens, image_tokens_per_dim): 7 | attn_size = text_tokens + image_tokens_per_dim**2 8 | mask = paddle.tril(paddle.ones([attn_size, attn_size])) 9 | return mask 10 | 11 | 12 | def get_row_mask(text_tokens=256, image_tokens_per_dim=32): 13 | mask = _init_mask(text_tokens, image_tokens_per_dim) 14 | step = image_tokens_per_dim + 1 15 | for col in range(text_tokens, mask.shape[1]): 16 | if col + step >= mask.shape[0] or col > mask.shape[1]: 17 | continue # paddle not support index >= tensor shape 18 | mask[col + step:, col] = 0.0 19 | return mask 20 | 21 | 22 | def get_col_mask(text_tokens=256, image_tokens_per_dim=32): 23 | mask = _init_mask(text_tokens, image_tokens_per_dim) 24 | step = image_tokens_per_dim - 1 25 | for col in range(text_tokens, mask.shape[1]): 26 | for i in range(1, mask.shape[0], step+1): 27 | if col + i >= mask.shape[0] or col + i + step > mask.shape[0] or col > mask.shape[1]: 28 | continue # paddle not support index >= tensor shape 29 | mask[col + i: col + i + step, col] = 0.0 30 | return mask 31 | 32 | 33 | def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11): 34 | mask = _init_mask(text_tokens, image_tokens_per_dim) 35 | shift = kernel // 2 36 | for pos in range(text_tokens, mask.shape[1]): 37 | if pos + 1 < mask.shape[0] and pos < mask.shape[1]: 38 | mask[pos+1:, pos] = 0.0 39 | img = paddle.zeros([image_tokens_per_dim, image_tokens_per_dim]) 40 | pixel_id = pos - text_tokens 41 | row = pixel_id // image_tokens_per_dim 42 | col = pixel_id % image_tokens_per_dim 43 | for r in range(-shift, shift+1): 44 | for c in range(-shift, shift+1): 45 | c_abs = (c + col) % 
image_tokens_per_dim 46 | r_abs = (r + row) % image_tokens_per_dim 47 | img[r_abs, c_abs] = 0.2 48 | cell_id = r_abs * image_tokens_per_dim + c_abs 49 | if text_tokens + cell_id > pos: 50 | mask[text_tokens + cell_id, pos] = 1.0 51 | 52 | img[row, col] = 1.0 53 | return mask 54 | -------------------------------------------------------------------------------- /rudalle_paddle/image_prompts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import paddle 3 | import numpy as np 4 | 5 | 6 | class ImagePrompts: 7 | 8 | def __init__(self, pil_image, borders, vae, device='cpu', crop_first=False): 9 | """ 10 | Args: 11 | pil_image (PIL.Image): image in PIL format 12 | borders (dict[str, int]): borders that we crop from pil_image 13 | example: {'up': 4, 'right': 0, 'left': 0, 'down': 0} (1 unit equals 8 pixels) 14 | vae (VQGanGumbelVAE): VQGAN model for image encoding 15 | device (str): cpu or cuda 16 | crop_first (bool): if True, crop the image before VQGAN encoding 17 | """ 18 | self.device = device 19 | img = self._preprocess_img(pil_image) 20 | self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first) 21 | 22 | def _preprocess_img(self, pil_img): 23 | img = paddle.to_tensor(np.array(pil_img.convert('RGB')).transpose(2, 0, 1)) / 255. 24 | img = img.unsqueeze(0).to(self.device, dtype=paddle.float32) 25 | img = (2 * img) - 1 26 | return img 27 | 28 | @staticmethod 29 | def _get_image_prompts(img, borders, vae, crop_first): 30 | if crop_first: 31 | assert borders['right'] + borders['left'] + borders['down'] == 0 32 | up_border = borders['up'] * 8 33 | _, _, [_, _, vqg_img] = vae.model.encode(img[:, :, :up_border, :]) 34 | else: 35 | _, _, [_, _, vqg_img] = vae.model.encode(img) 36 | 37 | bs, vqg_img_w, vqg_img_h = vqg_img.shape 38 | mask = paddle.zeros([vqg_img_w, vqg_img_h]) 39 | if borders['up'] != 0: 40 | mask[:borders['up'], :] = 1. 41 | if borders['down'] != 0: 42 | mask[-borders['down']:, :] = 1. 43 | if borders['right'] != 0: 44 | mask[:, :borders['right']] = 1. 45 | if borders['left'] != 0: 46 | mask[:, -borders['left']:] = 1. 
47 | mask = mask.reshape((-1,)).astype(paddle.bool) 48 | 49 | image_prompts = vqg_img.reshape((bs, -1)) 50 | image_prompts_idx = np.arange(vqg_img_w * vqg_img_h) 51 | image_prompts_idx = set(image_prompts_idx[mask]) 52 | 53 | return image_prompts_idx, image_prompts 54 | -------------------------------------------------------------------------------- /rudalle_paddle/tokenizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from os.path import join 3 | 4 | import paddle 5 | import numpy as np 6 | import youtokentome as yttm 7 | from huggingface_hub import hf_hub_url, cached_download 8 | 9 | 10 | def get_tokenizer(path=None, cache_dir='/tmp/rudalle'): 11 | # TODO docstring 12 | if path is None: 13 | repo_id = 'HighCWu/rudalle-paddle-utils' 14 | filename = 'bpe.model' 15 | cache_dir = join(cache_dir, 'tokenizer') 16 | config_file_url = hf_hub_url(repo_id=repo_id, filename=filename) 17 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename) 18 | path = join(cache_dir, filename) 19 | tokenizer = YTTMTokenizerWrapper(yttm.BPE(model=path)) 20 | print('tokenizer --> ready') 21 | return tokenizer 22 | 23 | 24 | class YTTMTokenizerWrapper: 25 | eos_id = 3 26 | bos_id = 2 27 | unk_id = 1 28 | pad_id = 0 29 | 30 | def __init__(self, tokenizer): 31 | self.tokenizer = tokenizer 32 | 33 | def __len__(self): 34 | return self.vocab_size() 35 | 36 | def get_pad_token_id(self): 37 | # TODO docstring 38 | return self.tokenizer.subword_to_id('<PAD>') 39 | 40 | def vocab_size(self): 41 | # TODO docstring 42 | return self.tokenizer.vocab_size() 43 | 44 | def encode_text(self, text, text_seq_length, bpe_dropout=0.0): 45 | # TODO docstring 46 | tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=bpe_dropout)[0] 47 | tokens = [self.bos_id] + tokens + [self.eos_id] 48 | return self.prepare_tokens(tokens, text_seq_length) 49 | 50 | def decode_text(self, encoded): 51 | # TODO docstring 52 | return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[ 53 | self.eos_id, self.bos_id, self.unk_id, self.pad_id 54 | ])[0] 55 | 56 | @staticmethod 57 | def prepare_tokens(tokens, text_seq_length): 58 | # TODO docstring 59 | empty_positions = text_seq_length - len(tokens) 60 | if empty_positions > 0: 61 | tokens = np.hstack((tokens, np.zeros(empty_positions))) # position tokens after text 62 | if len(tokens) > text_seq_length: 63 | tokens = tokens[:text_seq_length] 64 | return paddle.to_tensor(tokens).astype(paddle.int64) 65 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import re 4 | from setuptools import setup 5 | 6 | 7 | def read(filename): 8 | with open(os.path.join(os.path.dirname(__file__), filename)) as f: 9 | file_content = f.read() 10 | return file_content 11 | 12 | 13 | def get_requirements(): 14 | requirements = [] 15 | for requirement in read('requirements.txt').splitlines(): 16 | if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+'): 17 | parsed_requires = re.findall(r'#egg=([\w\d\.]+)-([\d\.]+)$', requirement) 18 | if parsed_requires: 19 | package, version = parsed_requires[0] 20 | requirements.append(f'{package}=={version}') 21 | else: 22 | print('WARNING! 
For correct matching dependency links need to specify package name and version' 23 | 'such as #egg=-') 24 | else: 25 | requirements.append(requirement) 26 | return requirements 27 | 28 | 29 | def get_links(): 30 | return [ 31 | requirement for requirement in read('requirements.txt').splitlines() 32 | if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+') 33 | ] 34 | 35 | 36 | def get_version(): 37 | """ Get version from the package without actually importing it. """ 38 | init = read('rudalle_paddle/__init__.py') 39 | for line in init.split('\n'): 40 | if line.startswith('__version__'): 41 | return eval(line.split('=')[1]) 42 | 43 | 44 | setup( 45 | name='rudalle_paddle', 46 | version=get_version(), 47 | author='SberAI, SberDevices', 48 | author_email='', 49 | description='', 50 | packages=[ 51 | 'rudalle_paddle', 52 | 'rudalle_paddle/dalle', 53 | 'rudalle_paddle/realesrgan', 54 | 'rudalle_paddle/ruclip', 55 | 'rudalle_paddle/vae', 56 | 'rudalle_paddle/packages', 57 | 'rudalle_paddle/packages/einops', 58 | 'rudalle_paddle/packages/einops/layers', 59 | 'rudalle_paddle/packages/taming/modules/diffusionmodules', 60 | 'rudalle_paddle/packages/transformers', 61 | ], 62 | package_data={'rudalle_paddle/vae': ['*.yml']}, 63 | install_requires=get_requirements(), 64 | dependency_links=get_links(), 65 | long_description=read('README.md'), 66 | long_description_content_type='text/markdown', 67 | ) 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/einops/layers/tensorflow.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import tensorflow as tf 3 | from tensorflow.keras.layers import Layer 4 | 5 | from .._backends import UnknownSize 6 | from . import RearrangeMixin, ReduceMixin 7 | from ._weighted_einsum import WeightedEinsumMixin 8 | 9 | __author__ = 'Alex Rogozhnikov' 10 | 11 | 12 | class Rearrange(RearrangeMixin, Layer): 13 | def compute_output_shape(self, input_shape): 14 | input_shape = tuple(UnknownSize() if d.value is None else int(d) for d in input_shape) 15 | init_shapes, reduced_axes, axes_reordering, final_shape = self.recipe().reconstruct_from_shape(input_shape) 16 | final_shape = tuple(None if isinstance(d, UnknownSize) else int(d) for d in final_shape) 17 | return final_shape 18 | 19 | def call(self, inputs): 20 | return self._apply_recipe(inputs) 21 | 22 | def get_config(self): 23 | return {'pattern': self.pattern, **self.axes_lengths} 24 | 25 | 26 | class Reduce(ReduceMixin, Layer): 27 | def compute_output_shape(self, input_shape): 28 | input_shape = tuple(UnknownSize() if d.value is None else int(d) for d in input_shape) 29 | init_shapes, reduced_axes, axes_reordering, final_shape = self.recipe().reconstruct_from_shape(input_shape) 30 | final_shape = tuple(None if isinstance(d, UnknownSize) else int(d) for d in final_shape) 31 | return final_shape 32 | 33 | def call(self, inputs): 34 | return self._apply_recipe(inputs) 35 | 36 | def get_config(self): 37 | return {'pattern': self.pattern, 'reduction': self.reduction, **self.axes_lengths} 38 | 39 | 40 | class WeightedEinsum(WeightedEinsumMixin, Layer): 41 | def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): 42 | self.weight = tf.Variable(tf.random_uniform_initializer(-weight_bound, weight_bound)(shape=weight_shape), 43 | trainable=True) 44 | if bias_shape is not None: 45 | self.bias = tf.Variable(tf.random_uniform_initializer(-bias_bound, bias_bound)(shape=bias_shape), 46 | trainable=True) 47 | else: 48 | self.bias = None 49 | 50 | def build(self, input_shape): 51 | pass 52 | 53 | def call(self, inputs): 54 | result = tf.einsum(self.einsum_pattern, inputs, self.weight) 55 | if self.bias is not None: 56 | result = result + self.bias 57 | return result 58 | 59 | def get_config(self): 60 | return {'pattern': self.pattern, 61 | 'weight_shape': self.weight_shape, 62 | 'bias_shape': self.bias_shape, 63 | **self.axes_lengths} 64 | -------------------------------------------------------------------------------- /rudalle_paddle/ruclip/processor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import json 4 | import paddle 5 | import youtokentome as yttm 6 | import paddle.vision.transforms as T 7 | 8 | 9 | class 
RuCLIPProcessor: 10 | eos_id = 3 11 | bos_id = 2 12 | unk_id = 1 13 | pad_id = 0 14 | 15 | def __init__(self, tokenizer_path, image_size=224, text_seq_length=76, mean=None, std=None): 16 | 17 | self.tokenizer = yttm.BPE(tokenizer_path) 18 | self.mean = mean or [0.485, 0.456, 0.406] 19 | self.std = std or [0.229, 0.224, 0.225] 20 | self.image_transform = T.Compose([ 21 | lambda img: img.convert('RGB') if img.mode != 'RGB' else img, 22 | T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)), 23 | T.ToTensor(), 24 | T.Normalize(mean=self.mean, std=self.std) 25 | ]) 26 | self.text_seq_length = text_seq_length 27 | self.image_size = image_size 28 | 29 | def encode_text(self, text): 30 | text = text.lower() 31 | tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0] 32 | tokens = [self.bos_id] + tokens + [self.eos_id] 33 | tokens = tokens[:self.text_seq_length] 34 | mask = [1] * len(tokens) 35 | return paddle.to_tensor(tokens).astype(paddle.int64), paddle.to_tensor(mask).astype(paddle.int64) 36 | 37 | def decode_text(self, encoded): 38 | return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[ 39 | self.eos_id, self.bos_id, self.unk_id, self.pad_id 40 | ])[0] 41 | 42 | def __call__(self, text=None, images=None, **kwargs): 43 | inputs = {} 44 | if text is not None: 45 | input_ids, masks = [], [] 46 | texts = [text] if isinstance(text, str) else text 47 | for text in texts: 48 | tokens, mask = self.encode_text(text) 49 | input_ids.append(tokens) 50 | masks.append(mask) 51 | inputs['input_ids'] = paddle.pad_sequence(input_ids, batch_first=True) 52 | inputs['attention_mask'] = paddle.pad_sequence(masks, batch_first=True) 53 | if images is not None: 54 | pixel_values = [] 55 | for i, image in enumerate(images): 56 | pixel_values.append(self.image_transform(image)) 57 | inputs['pixel_values'] = paddle.pad_sequence(pixel_values, batch_first=True) 58 | return inputs 59 | 60 | @classmethod 61 | def from_pretrained(cls, folder): 62 | tokenizer_path = os.path.join(folder, 'bpe.model') 63 | config = json.load(open(os.path.join(folder, 'config.json'))) 64 | image_size = config['vision_config']['image_size'] 65 | text_seq_length = config['text_config']['max_position_embeddings'] - 1 66 | mean, std = config.get('mean'), config.get('std') 67 | return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std) 68 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/einops/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = 'Alex Rogozhnikov' 3 | 4 | import functools 5 | 6 | from ..einops import TransformRecipe, _prepare_transformation_recipe 7 | from .. import EinopsError 8 | 9 | 10 | class RearrangeMixin: 11 | """ 12 | Rearrange layer behaves identically to einops.rearrange operation. 13 | 14 | :param pattern: str, rearrangement pattern 15 | :param axes_lengths: any additional specification of dimensions 16 | 17 | See einops.rearrange for source_examples. 
18 | """ 19 | 20 | def __init__(self, pattern, **axes_lengths): 21 | super().__init__() 22 | self.pattern = pattern 23 | self.axes_lengths = axes_lengths 24 | self._recipe = self.recipe() # checking parameters 25 | 26 | def __repr__(self): 27 | params = repr(self.pattern) 28 | for axis, length in self.axes_lengths.items(): 29 | params += ', {}={}'.format(axis, length) 30 | return '{}({})'.format(self.__class__.__name__, params) 31 | 32 | @functools.lru_cache(maxsize=1024) 33 | def recipe(self) -> TransformRecipe: 34 | try: 35 | hashable_lengths = tuple(sorted(self.axes_lengths.items())) 36 | return _prepare_transformation_recipe(self.pattern, operation='rearrange', axes_lengths=hashable_lengths) 37 | except EinopsError as e: 38 | raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e)) 39 | 40 | def _apply_recipe(self, x): 41 | return self._recipe.apply(x) 42 | 43 | 44 | class ReduceMixin: 45 | """ 46 | Reduce layer behaves identically to einops.reduce operation. 47 | 48 | :param pattern: str, rearrangement pattern 49 | :param reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive 50 | :param axes_lengths: any additional specification of dimensions 51 | 52 | See einops.reduce for source_examples. 53 | """ 54 | 55 | def __init__(self, pattern, reduction, **axes_lengths): 56 | super().__init__() 57 | self.pattern = pattern 58 | self.reduction = reduction 59 | self.axes_lengths = axes_lengths 60 | self._recipe = self.recipe() # checking parameters 61 | 62 | def __repr__(self): 63 | params = '{!r}, {!r}'.format(self.pattern, self.reduction) 64 | for axis, length in self.axes_lengths.items(): 65 | params += ', {}={}'.format(axis, length) 66 | return '{}({})'.format(self.__class__.__name__, params) 67 | 68 | @functools.lru_cache(maxsize=1024) 69 | def recipe(self) -> TransformRecipe: 70 | try: 71 | hashable_lengths = tuple(sorted(self.axes_lengths.items())) 72 | return _prepare_transformation_recipe(self.pattern, operation=self.reduction, axes_lengths=hashable_lengths) 73 | except EinopsError as e: 74 | raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e)) 75 | 76 | def _apply_recipe(self, x): 77 | return self._recipe.apply(x) 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ruDALL-E PaddlePaddle 2 | 3 | ruDALL-E in PaddlePaddle. 4 | 5 | Install: 6 | 7 | ``` 8 | pip install rudalle_paddle==0.0.1rc1 9 | ``` 10 | 11 | Run with free v100 on [AI Studio](https://aistudio.baidu.com/aistudio/projectdetail/2684828). 
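If you want to pre-fetch the converted Paddle checkpoints before generating, the snippet below is a minimal sketch that mirrors `download.py` in this repo; the `*-paddle` names resolve to pre-converted weights hosted at `HighCWu/rudalle-paddle-utils` and are cached under `./pretrained_models`:

```python
from rudalle_paddle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip

# download and cache the pre-converted Paddle weights under ./pretrained_models
get_rudalle_model('Malevich-paddle', True, device='cuda', cache_dir='pretrained_models')  # use device='cpu' if no GPU is available
get_tokenizer(cache_dir='pretrained_models')
get_vae(name='vqgan.gumbelf8-sber.paddle', pretrained=True, cache_dir='pretrained_models')
get_realesrgan('x2-paddle', cache_dir='pretrained_models')
get_ruclip('ruclip-vit-base-patch32-v5-paddle', cache_dir='pretrained_models')
```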
12 | 13 | Original Pytorch version Readme: 14 | 15 | # ruDALL-E 16 | ### Generate images from texts 17 | [![Apache license](https://img.shields.io/badge/License-Apache-blue.svg)](https://www.apache.org/licenses/LICENSE-2.0) 18 | [![Coverage Status](https://codecov.io/gh/sberbank-ai/ru-dalle/branch/master/graphs/badge.svg)](https://codecov.io/gh/sberbank-ai/ru-dalle) 19 | [![pipeline](https://gitlab.com/shonenkov/ru-dalle/badges/master/pipeline.svg)](https://gitlab.com/shonenkov/ru-dalle/-/pipelines) 20 | [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master) 21 | 22 | ### 🤗 HF Models: 23 | [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) 24 | 25 | ### Minimal Example: 26 | 27 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1wGE-046et27oHvNlBNPH07qrEQNE04PQ?usp=sharing) 28 | [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/shonenkov/rudalle-example-generation) 29 | 30 | ### generation by ruDALLE: 31 | ```python 32 | from rudalle_paddle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip 33 | from rudalle_paddle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip 34 | from rudalle_paddle.utils import seed_everything 35 | 36 | # prepare models 37 | device = 'cuda' 38 | dalle = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device) 39 | realesrgan = get_realesrgan('x4', device=device) 40 | tokenizer = get_tokenizer() 41 | vae = get_vae().to(device) 42 | ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5') 43 | ruclip = ruclip.to(device) 44 | 45 | text = 'изображение радуги на фоне ночного города' 46 | 47 | seed_everything(42) 48 | pil_images = [] 49 | scores = [] 50 | for top_k, top_p, images_num in [ 51 | (2048, 0.995, 3), 52 | (1536, 0.99, 3), 53 | (1024, 0.99, 3), 54 | (1024, 0.98, 3), 55 | (512, 0.97, 3), 56 | (384, 0.96, 3), 57 | (256, 0.95, 3), 58 | (128, 0.95, 3), 59 | ]: 60 | _pil_images, _scores = generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, top_p=top_p) 61 | pil_images += _pil_images 62 | scores += _scores 63 | 64 | show(pil_images, 6) 65 | ``` 66 | ![](./pics/rainbow-full.png) 67 | ### auto cherry-pick by ruCLIP: 68 | ```python 69 | top_images, clip_scores = cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device=device, count=6) 70 | show(top_images, 3) 71 | ``` 72 | ![](./pics/rainbow-cherry-pick.png) 73 | ### super resolution: 74 | ```python 75 | sr_images = super_resolution(top_images, realesrgan) 76 | show(sr_images, 3) 77 | ``` 78 | ![](./pics/rainbow-super-resolution.png) 79 | 80 | ```python 81 | text, seed = 'красивая тян из аниме', 6955 82 | ``` 83 | ![](./pics/anime-girl-super-resolution.png) 84 | 85 | 86 | ### Image Prompt 87 | see `jupyters/ruDALLE-image-prompts-A100.ipynb` 88 | ```python 89 | text, seed = 'Храм Василия Блаженного', 42 90 | skyes = [red_sky, sunny_sky, cloudy_sky, night_sky] 91 | ``` 92 | ![](./pics/russian-temple-image-prompt.png) 93 | 94 | 95 | ### 🚀 Contributors 🚀 96 | 97 | - [@neverix](https://www.kaggle.com/neverix) thanks a lot for contributing for speed up of inference 98 | - [@oriBetelgeuse](https://github.com/oriBetelgeuse) thanks a lot for easy API of generation using image prompt 99 | 
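### Image prompt sketch

A rough, non-authoritative sketch of how the Image Prompt example above can be wired up with this package. The `ImagePrompts` constructor, the `borders` dict and the `crop_first` flag come from `rudalle_paddle/image_prompts.py`; `dalle`, `vae`, `tokenizer` and `device` are reused from the generation example above; `sky_image` is a placeholder PIL image, and passing the prompt to `generate_images` via an `image_prompts` keyword is an assumption carried over from the upstream ruDALL-E API — check `rudalle_paddle/pipelines.py` for the exact signature.

```python
from rudalle_paddle.image_prompts import ImagePrompts
from rudalle_paddle.pipelines import generate_images

text = 'Храм Василия Блаженного'
borders = {'up': 4, 'left': 0, 'right': 0, 'down': 0}  # units of 8 px: keep only the top 32 px as the prompt
sky_image = sky_image.resize((256, 256))  # `sky_image` is a placeholder PIL image (e.g. a sunny or night sky)
image_prompt = ImagePrompts(sky_image, borders, vae, device=device, crop_first=True)

pil_images, scores = generate_images(
    text, tokenizer, dalle, vae,
    top_k=1024, top_p=0.99, images_num=3,
    image_prompts=image_prompt,  # assumed keyword, mirroring the upstream ruDALL-E pipelines API
)
```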
-------------------------------------------------------------------------------- /rudalle_paddle/realesrgan/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Source: https://github.com/boomb0om/Real-ESRGAN-colab 3 | 4 | import paddle 5 | import numpy as np 6 | from PIL import Image 7 | 8 | from .rrdbnet_arch import RRDBNet 9 | from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, unpad_image 10 | 11 | 12 | class RealESRGAN: 13 | def __init__(self, device, scale=4): 14 | self.device = device 15 | self.scale = scale 16 | self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale) 17 | 18 | def load_weights(self, model_path): 19 | if model_path[-4:] == '.pth': 20 | model_path = RealESRGAN.convert(model_path) 21 | loadnet = paddle.load(model_path) 22 | if 'params' in loadnet: 23 | self.model.set_state_dict(loadnet['params']) 24 | elif 'params_ema' in loadnet: 25 | self.model.set_state_dict(loadnet['params_ema']) 26 | else: 27 | self.model.set_state_dict(loadnet) 28 | self.model.eval() 29 | self.model.to(self.device) 30 | 31 | def predict(self, lr_image, batch_size=4, patches_size=192, 32 | padding=24, pad_size=15): 33 | scale = self.scale 34 | device = self.device 35 | lr_image = np.array(lr_image) 36 | lr_image = pad_reflect(lr_image, pad_size) 37 | 38 | patches, p_shape = split_image_into_overlapping_patches(lr_image, patch_size=patches_size, 39 | padding_size=padding) 40 | img = paddle.to_tensor(patches / 255.).astype(paddle.float32).transpose((0, 3, 1, 2)).to(device).detach() 41 | 42 | with paddle.no_grad(): 43 | res = self.model(img[0:batch_size]) 44 | for i in range(batch_size, img.shape[0], batch_size): 45 | res = paddle.concat((res, self.model(img[i:i + batch_size])), 0) 46 | 47 | sr_image = res.transpose((0, 2, 3, 1)).cpu().clip_(0, 1) 48 | np_sr_image = sr_image.numpy() 49 | 50 | padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,) 51 | scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,) 52 | np_sr_image = stich_together(np_sr_image, padded_image_shape=padded_size_scaled, 53 | target_shape=scaled_image_shape, padding_size=padding * scale) 54 | sr_img = (np_sr_image * 255).astype(np.uint8) 55 | sr_img = unpad_image(sr_img, pad_size * scale) 56 | sr_img = Image.fromarray(sr_img) 57 | 58 | return sr_img 59 | 60 | @staticmethod 61 | def convert(model_path): 62 | import os 63 | import torch 64 | torch_weights = model_path 65 | target_model_path = model_path[:-4] + '.pdparams' 66 | if os.path.exists(target_model_path): 67 | return target_model_path 68 | _state_dict = torch.load(torch_weights, map_location='cpu') 69 | if 'params' in _state_dict: 70 | state_dict = _state_dict['params'] 71 | elif 'params_ema' in _state_dict: 72 | state_dict = _state_dict['params_ema'] 73 | else: 74 | state_dict = _state_dict 75 | 76 | paddle_state_dict = {} 77 | for name, param in state_dict.items(): 78 | if param.ndim == 0: 79 | param = param.unsqueeze(0) 80 | param = param.cpu().detach().numpy() 81 | paddle_state_dict[name] = param 82 | 83 | if 'params' in _state_dict: 84 | paddle_state_dict = {'params': paddle_state_dict} 85 | elif 'params_ema' in _state_dict: 86 | paddle_state_dict = {'params_ema': paddle_state_dict} 87 | 88 | paddle.save(paddle_state_dict, target_model_path) 89 | 90 | return target_model_path 91 | -------------------------------------------------------------------------------- 
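A minimal standalone usage sketch for the `RealESRGAN` wrapper defined above. The weight path here is hypothetical; in the package the weights are normally obtained through `get_realesrgan`, and `load_weights` converts a PyTorch `.pth` checkpoint to `.pdparams` on the fly via `convert`:

```python
from PIL import Image
from rudalle_paddle.realesrgan.model import RealESRGAN

model = RealESRGAN(device='cuda', scale=4)
model.load_weights('pretrained_models/x4-paddle/RealESRGAN_x4.pdparams')  # hypothetical local path
sr_img = model.predict(Image.open('input.png').convert('RGB'))  # predict() accepts a PIL image or ndarray
sr_img.save('input_x4.png')  # result is a PIL image upscaled by the chosen scale factor
```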
/rudalle_paddle/dalle/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | import pickle 5 | import paddle 6 | from huggingface_hub import hf_hub_url, cached_download 7 | 8 | from .model import DalleModel 9 | from .fp16 import FP16Module 10 | 11 | 12 | MODELS = { 13 | 'Malevich': dict( 14 | description='◼️ Malevich is 1.3 billion params model from the family GPT3-like, ' 15 | 'that uses Russian language and text+image multi-modality.', 16 | model_params=dict( 17 | num_layers=24, 18 | hidden_size=2048, 19 | num_attention_heads=16, 20 | embedding_dropout_prob=0.1, 21 | output_dropout_prob=0.1, 22 | attention_dropout_prob=0.1, 23 | image_tokens_per_dim=32, 24 | text_seq_length=128, 25 | use_masks=True, 26 | cogview_sandwich_layernorm=True, 27 | cogview_pb_relax=True, 28 | vocab_size=16384+128, 29 | image_vocab_size=8192, 30 | ), 31 | repo_id='sberbank-ai/rudalle-Malevich', 32 | filename='pytorch_model.bin', 33 | full_description='', # TODO 34 | ), 35 | 'Malevich-paddle': dict( 36 | description='◼️ Malevich is 1.3 billion params model from the family GPT3-like, ' 37 | 'that uses Russian language and text+image multi-modality.', 38 | model_params=dict( 39 | num_layers=24, 40 | hidden_size=2048, 41 | num_attention_heads=16, 42 | embedding_dropout_prob=0.1, 43 | output_dropout_prob=0.1, 44 | attention_dropout_prob=0.1, 45 | image_tokens_per_dim=32, 46 | text_seq_length=128, 47 | use_masks=True, 48 | cogview_sandwich_layernorm=True, 49 | cogview_pb_relax=True, 50 | vocab_size=16384+128, 51 | image_vocab_size=8192, 52 | ), 53 | repo_id='HighCWu/rudalle-paddle-utils', 54 | filename='rudalle_paddle.pkl', 55 | full_description='', # TODO 56 | ), 57 | 'small': dict( 58 | description='', 59 | model_params=dict( 60 | num_layers=12, 61 | hidden_size=768, 62 | num_attention_heads=12, 63 | embedding_dropout_prob=0.1, 64 | output_dropout_prob=0.1, 65 | attention_dropout_prob=0.1, 66 | image_tokens_per_dim=32, 67 | text_seq_length=128, 68 | use_masks=True, 69 | cogview_sandwich_layernorm=True, 70 | cogview_pb_relax=True, 71 | vocab_size=16384+128, 72 | image_vocab_size=8192, 73 | ), 74 | repo_id='', 75 | filename='', 76 | full_description='', # TODO 77 | ), 78 | } 79 | 80 | 81 | def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir='/tmp/rudalle'): 82 | # TODO docstring 83 | assert name in MODELS 84 | 85 | paddle.set_device(device) 86 | 87 | if fp16 and device == 'cpu': 88 | print('Warning! Using both fp16 and cpu doesnt support. 
You can use cuda device or turn off fp16.') 89 | 90 | config = MODELS[name] 91 | model = DalleModel(device=device, fp16=fp16, **config['model_params']) 92 | if pretrained: 93 | cache_dir = os.path.join(cache_dir, name) 94 | config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename']) 95 | cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename']) 96 | if config['filename'] == 'pytorch_model.bin': 97 | DalleModel.convert(cache_dir) 98 | config['filename'] = 'rudalle_paddle.pkl' 99 | with open(os.path.join(cache_dir, config['filename']), 'rb') as f: 100 | # paddle.load could not load large paddle.save file 101 | checkpoint = pickle.load(f) 102 | checkpoint = {k: v.astype('float32') for k, v in checkpoint.items()} 103 | model.set_state_dict(checkpoint) 104 | if fp16: 105 | model = FP16Module(model) 106 | model.eval() 107 | model.to(device) 108 | if config['description'] and pretrained: 109 | print(config['description']) 110 | return model 111 | -------------------------------------------------------------------------------- /rudalle_paddle/realesrgan/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | 5 | def pad_reflect(image, pad_size): 6 | imsize = image.shape 7 | height, width = imsize[:2] 8 | new_img = np.zeros([height + pad_size * 2, width + pad_size * 2, imsize[2]]).astype(np.uint8) 9 | new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image 10 | new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0) # top 11 | new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0) # bottom 12 | new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size * 2, :], axis=1) # left 13 | new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size * 2:-pad_size, :], axis=1) # right 14 | return new_img 15 | 16 | 17 | def unpad_image(image, pad_size): 18 | return image[pad_size:-pad_size, pad_size:-pad_size, :] 19 | 20 | 21 | def pad_patch(image_patch, padding_size, channel_last=True): 22 | """ Pads image_patch with with padding_size edge values. """ 23 | if channel_last: 24 | return np.pad( 25 | image_patch, 26 | ((padding_size, padding_size), (padding_size, padding_size), (0, 0)), 27 | 'edge', 28 | ) 29 | else: 30 | return np.pad( 31 | image_patch, 32 | ((0, 0), (padding_size, padding_size), (padding_size, padding_size)), 33 | 'edge', 34 | ) 35 | 36 | 37 | def unpad_patches(image_patches, padding_size): 38 | return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :] 39 | 40 | 41 | def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2): 42 | """ Splits the image into partially overlapping patches. 43 | The patches overlap by padding_size pixels. 44 | Pads the image twice: 45 | - first to have a size multiple of the patch size, 46 | - then to have equal padding at the borders. 47 | Args: 48 | image_array: numpy array of the input image. 49 | patch_size: size of the patches from the original image (without padding). 50 | padding_size: size of the overlapping area. 
51 | """ 52 | xmax, ymax, _ = image_array.shape 53 | x_remainder = xmax % patch_size 54 | y_remainder = ymax % patch_size 55 | 56 | # modulo here is to avoid extending of patch_size instead of 0 57 | x_extend = (patch_size - x_remainder) % patch_size 58 | y_extend = (patch_size - y_remainder) % patch_size 59 | 60 | # make sure the image is divisible into regular patches 61 | extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge') 62 | 63 | # add padding around the image to simplify computations 64 | padded_image = pad_patch(extended_image, padding_size, channel_last=True) 65 | 66 | xmax, ymax, _ = padded_image.shape 67 | patches = [] 68 | 69 | x_lefts = range(padding_size, xmax - padding_size, patch_size) 70 | y_tops = range(padding_size, ymax - padding_size, patch_size) 71 | 72 | for x in x_lefts: 73 | for y in y_tops: 74 | x_left = x - padding_size 75 | y_top = y - padding_size 76 | x_right = x + patch_size + padding_size 77 | y_bottom = y + patch_size + padding_size 78 | patch = padded_image[x_left:x_right, y_top:y_bottom, :] 79 | patches.append(patch) 80 | 81 | return np.array(patches), padded_image.shape 82 | 83 | 84 | def stich_together(patches, padded_image_shape, target_shape, padding_size=4): 85 | """ Reconstruct the image from overlapping patches. 86 | After scaling, shapes and padding should be scaled too. 87 | Args: 88 | patches: patches obtained with split_image_into_overlapping_patches 89 | padded_image_shape: shape of the padded image contructed in split_image_into_overlapping_patches 90 | target_shape: shape of the final image 91 | padding_size: size of the overlapping area. 92 | """ 93 | 94 | xmax, ymax, _ = padded_image_shape 95 | patches = unpad_patches(patches, padding_size) 96 | patch_size = patches.shape[1] 97 | n_patches_per_row = ymax // patch_size 98 | 99 | complete_image = np.zeros((xmax, ymax, 3)) 100 | 101 | row = -1 102 | col = 0 103 | for i in range(len(patches)): 104 | if i % n_patches_per_row == 0: 105 | row += 1 106 | col = 0 107 | complete_image[ 108 | row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, : 109 | ] = patches[i] 110 | col += 1 111 | return complete_image[0: target_shape[0], 0: target_shape[1], :] 112 | -------------------------------------------------------------------------------- /rudalle_paddle/pipelines.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import paddle 3 | import transformers 4 | import more_itertools 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from tqdm.auto import tqdm 8 | 9 | from . 
import utils 10 | 11 | 12 | def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image_prompts=None, temperature=1.0, bs=8, 13 | seed=None, use_cache=True): 14 | # TODO docstring 15 | if seed is not None: 16 | utils.seed_everything(seed) 17 | 18 | vocab_size = dalle.get_param('vocab_size') 19 | text_seq_length = dalle.get_param('text_seq_length') 20 | image_seq_length = dalle.get_param('image_seq_length') 21 | total_seq_length = dalle.get_param('total_seq_length') 22 | device = dalle.get_param('device') 23 | 24 | text = text.lower().strip() 25 | input_ids = tokenizer.encode_text(text, text_seq_length=text_seq_length) 26 | pil_images, scores = [], [] 27 | for chunk in more_itertools.chunked(range(images_num), bs): 28 | chunk_bs = len(chunk) 29 | with paddle.no_grad(): 30 | attention_mask = paddle.tril(paddle.ones((chunk_bs, 1, total_seq_length, total_seq_length)).to(device)) 31 | out = input_ids.unsqueeze(0).tile([chunk_bs, 1]).to(device) 32 | has_cache = False 33 | sample_scores = [] 34 | if image_prompts is not None: 35 | prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts 36 | prompts = prompts.tile([images_num, 1]) 37 | if use_cache: 38 | use_cache = False 39 | print('Warning: use_cache changed to False') 40 | for idx in tqdm(range(out.shape[1], total_seq_length)): 41 | idx -= text_seq_length 42 | if image_prompts is not None and idx in prompts_idx: 43 | out = paddle.concat((out, prompts[:, idx].unsqueeze(1)), axis=-1) 44 | else: 45 | logits, has_cache = dalle(out, attention_mask, 46 | has_cache=has_cache, use_cache=use_cache, return_loss=False) 47 | logits = logits[:, -1, vocab_size:] 48 | logits /= temperature 49 | filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 50 | probs = paddle.nn.functional.softmax(filtered_logits, axis=-1) 51 | probs = paddle.nn.functional.softmax(logits, axis=-1) 52 | sample = paddle.multinomial(probs, 1) 53 | sample_scores.append(probs[paddle.arange(probs.shape[0]), sample.swapaxes(0, 1)]) 54 | out = paddle.concat((out, sample), axis=-1) 55 | codebooks = out[:, -image_seq_length:] 56 | images = vae.decode(codebooks) 57 | pil_images += utils.paddle_tensors_to_pil_list(images) 58 | scores += paddle.concat(sample_scores).sum(0).detach().cpu().numpy().tolist() 59 | return pil_images, scores 60 | 61 | 62 | def super_resolution(pil_images, realesrgan): 63 | result = [] 64 | for pil_image in pil_images: 65 | with paddle.no_grad(): 66 | sr_image = realesrgan.predict(np.array(pil_image)) 67 | result.append(sr_image) 68 | return result 69 | 70 | 71 | def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu', count=4): 72 | with paddle.no_grad(): 73 | inputs = ruclip_processor(text=text, images=pil_images) 74 | for key in inputs.keys(): 75 | inputs[key] = inputs[key].to(device) 76 | outputs = ruclip(**inputs) 77 | sims = paddle.nn.functional.softmax(outputs.logits_per_image.reshape([-1]), axis=0) 78 | items = [] 79 | for index, sim in enumerate(sims.cpu().numpy()): 80 | items.append({'img_index': index, 'cosine': sim}) 81 | items = sorted(items, key=lambda x: x['cosine'], reverse=True)[:count] 82 | top_pil_images = [pil_images[x['img_index']] for x in items] 83 | top_scores = [x['cosine'] for x in items] 84 | return top_pil_images, top_scores 85 | 86 | 87 | def show(pil_images, nrow=4): 88 | imgs = utils.pil_list_to_paddle_tensors(pil_images) 89 | imgs = paddle.nn.functional.pad(imgs.astype(paddle.int64), [1, 1, 1, 1], value=0) 90 | imgs = 
paddle.concat(paddle.concat(imgs.split(imgs.shape[0]//nrow, 0), 2).split(nrow, 0), 3) 91 | if not isinstance(imgs, list): 92 | imgs = [imgs.cpu()] 93 | fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(14, 14)) 94 | for i, img in enumerate(imgs): 95 | img = img.detach().numpy().astype(np.uint8)[0].transpose([1, 2, 0]) 96 | axs[0, i].imshow(img) 97 | axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) 98 | fix.show() 99 | plt.show() 100 | -------------------------------------------------------------------------------- /rudalle_paddle/vae/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from math import sqrt, log 3 | 4 | import paddle 5 | import paddle.nn as nn 6 | import paddle.nn.functional as F 7 | from paddle import einsum 8 | from einops import rearrange 9 | from taming.modules.diffusionmodules.model import Encoder, Decoder 10 | 11 | 12 | class VQGanGumbelVAE(paddle.nn.Layer): 13 | 14 | def __init__(self, config): 15 | super().__init__() 16 | model = GumbelVQ( 17 | ddconfig=config.model.params.ddconfig, 18 | n_embed=config.model.params.n_embed, 19 | embed_dim=config.model.params.embed_dim, 20 | kl_weight=config.model.params.kl_weight, 21 | ) 22 | self.model = model 23 | self.num_layers = int(log(config.model.params.ddconfig.attn_resolutions[0]) / log(2)) 24 | self.image_size = 256 25 | self.num_tokens = config.model.params.n_embed 26 | 27 | @paddle.no_grad() 28 | def get_codebook_indices(self, img): 29 | img = (2 * img) - 1 30 | _, _, [_, _, indices] = self.model.encode(img) 31 | return rearrange(indices, 'b h w -> b (h w)') 32 | 33 | def decode(self, img_seq): 34 | b, n = img_seq.shape 35 | one_hot_indices = paddle.nn.functional.one_hot(img_seq, num_classes=self.num_tokens).astype(paddle.float32) 36 | z = paddle.matmul(one_hot_indices, self.model.quantize.embed.weight) 37 | z = rearrange(z, 'b (h w) c -> b c h w', h=int(sqrt(n))) 38 | img = self.model.decode(z) 39 | img = (img.clip(-1., 1.) + 1) * 0.5 40 | return img 41 | 42 | @staticmethod 43 | def convert(model_path): 44 | import os 45 | import torch 46 | torch_weights = model_path 47 | target_model_path = model_path[:-5] + '.pdckpt' 48 | if os.path.exists(target_model_path): 49 | return target_model_path 50 | state_dict = torch.load(torch_weights, map_location='cpu')['state_dict'] 51 | 52 | paddle_state_dict = {} 53 | for name, param in state_dict.items(): 54 | if 'dense' in name and param.ndim == 2: 55 | param = param.transpose(1, 0) 56 | if param.ndim == 0: 57 | param = param.unsqueeze(0) 58 | param = param.cpu().detach().numpy() 59 | paddle_state_dict[name] = param 60 | 61 | paddle.save({'state_dict': paddle_state_dict}, target_model_path) 62 | 63 | return model_path 64 | 65 | 66 | class GumbelQuantize(nn.Layer): 67 | """ 68 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) 69 | Gumbel Softmax trick quantizer 70 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 
2016 71 | https://arxiv.org/abs/1611.01144 72 | """ 73 | 74 | def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, 75 | kl_weight=5e-4, temp_init=1.0, use_vqinterface=True): 76 | super().__init__() 77 | self.embedding_dim = embedding_dim 78 | self.n_embed = n_embed 79 | self.straight_through = straight_through 80 | self.temperature = temp_init 81 | self.kl_weight = kl_weight 82 | self.proj = nn.Conv2D(num_hiddens, n_embed, 1) 83 | self.embed = nn.Embedding(self.n_embed, self.embedding_dim) 84 | self.use_vqinterface = use_vqinterface 85 | 86 | def forward(self, z, temp=None, return_logits=False): 87 | hard = self.straight_through if self.training else True 88 | temp = self.temperature if temp is None else temp 89 | logits = self.proj(z) 90 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, axis=1, hard=hard) 91 | z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) 92 | # + kl divergence to the prior loss 93 | qy = F.softmax(logits, axis=1) 94 | diff = self.kl_weight * paddle.sum(qy * paddle.log(qy * self.n_embed + 1e-10), axis=1).mean() 95 | ind = soft_one_hot.argmax(axis=1) 96 | if self.use_vqinterface: 97 | if return_logits: 98 | return z_q, diff, (None, None, ind), logits 99 | return z_q, diff, (None, None, ind) 100 | return z_q, diff, ind 101 | 102 | 103 | class GumbelVQ(nn.Layer): 104 | 105 | def __init__(self, ddconfig, n_embed, embed_dim, kl_weight=1e-8): 106 | super().__init__() 107 | z_channels = ddconfig['z_channels'] 108 | self.encoder = Encoder(**ddconfig) 109 | self.decoder = Decoder(**ddconfig) 110 | self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed, kl_weight=kl_weight, temp_init=1.0) 111 | self.quant_conv = paddle.nn.Conv2D(ddconfig['z_channels'], embed_dim, 1) 112 | self.post_quant_conv = paddle.nn.Conv2D(embed_dim, ddconfig['z_channels'], 1) 113 | 114 | def encode(self, x): 115 | h = self.encoder(x) 116 | h = self.quant_conv(h) 117 | quant, emb_loss, info = self.quantize(h) 118 | return quant, emb_loss, info 119 | 120 | def decode(self, quant): 121 | quant = self.post_quant_conv(quant) 122 | dec = self.decoder(quant) 123 | return dec 124 | -------------------------------------------------------------------------------- /rudalle_paddle/realesrgan/rrdbnet_arch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import paddle 3 | from paddle import nn as nn 4 | from paddle.nn import functional as F 5 | 6 | from .arch_util import default_init_weights, make_layer, pixel_unshuffle 7 | 8 | 9 | class ResidualDenseBlock(nn.Layer): 10 | """Residual Dense Block. 11 | Used in RRDB block in ESRGAN. 12 | Args: 13 | num_feat (int): Channel number of intermediate features. 14 | num_grow_ch (int): Channels for each growth. 
15 | """ 16 | 17 | def __init__(self, num_feat=64, num_grow_ch=32): 18 | super(ResidualDenseBlock, self).__init__() 19 | self.conv1 = nn.Conv2D(num_feat, num_grow_ch, 3, 1, 1) 20 | self.conv2 = nn.Conv2D(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 21 | self.conv3 = nn.Conv2D(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 22 | self.conv4 = nn.Conv2D(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 23 | self.conv5 = nn.Conv2D(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 24 | 25 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, ) 26 | 27 | # initialization 28 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 29 | 30 | def forward(self, x): 31 | x1 = self.lrelu(self.conv1(x)) 32 | x2 = self.lrelu(self.conv2(paddle.concat((x, x1), 1))) 33 | x3 = self.lrelu(self.conv3(paddle.concat((x, x1, x2), 1))) 34 | x4 = self.lrelu(self.conv4(paddle.concat((x, x1, x2, x3), 1))) 35 | x5 = self.conv5(paddle.concat((x, x1, x2, x3, x4), 1)) 36 | # Emperically, we use 0.2 to scale the residual for better performance 37 | return x5 * 0.2 + x 38 | 39 | 40 | class RRDB(nn.Layer): 41 | """Residual in Residual Dense Block. 42 | Used in RRDB-Net in ESRGAN. 43 | Args: 44 | num_feat (int): Channel number of intermediate features. 45 | num_grow_ch (int): Channels for each growth. 46 | """ 47 | 48 | def __init__(self, num_feat, num_grow_ch=32): 49 | super(RRDB, self).__init__() 50 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 51 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 52 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 53 | 54 | def forward(self, x): 55 | out = self.rdb1(x) 56 | out = self.rdb2(out) 57 | out = self.rdb3(out) 58 | # Emperically, we use 0.2 to scale the residual for better performance 59 | return out * 0.2 + x 60 | 61 | 62 | class RRDBNet(nn.Layer): 63 | """Networks consisting of Residual in Residual Dense Block, which is used 64 | in ESRGAN. 65 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 66 | We extend ESRGAN for scale x2 and scale x1. 67 | Note: This is one option for scale 1, scale 2 in RRDBNet. 68 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size 69 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 70 | Args: 71 | num_in_ch (int): Channel number of inputs. 72 | num_out_ch (int): Channel number of outputs. 73 | num_feat (int): Channel number of intermediate features. 74 | Default: 64 75 | num_block (int): Block number in the trunk network. Defaults: 23 76 | num_grow_ch (int): Channels for each growth. Default: 32. 
77 | """ 78 | 79 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 80 | super(RRDBNet, self).__init__() 81 | self.scale = scale 82 | if scale == 2: 83 | num_in_ch = num_in_ch * 4 84 | elif scale == 1: 85 | num_in_ch = num_in_ch * 16 86 | self.conv_first = nn.Conv2D(num_in_ch, num_feat, 3, 1, 1) 87 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 88 | self.conv_body = nn.Conv2D(num_feat, num_feat, 3, 1, 1) 89 | # upsample 90 | self.conv_up1 = nn.Conv2D(num_feat, num_feat, 3, 1, 1) 91 | self.conv_up2 = nn.Conv2D(num_feat, num_feat, 3, 1, 1) 92 | if scale == 8: 93 | self.conv_up3 = nn.Conv2D(num_feat, num_feat, 3, 1, 1) 94 | self.conv_hr = nn.Conv2D(num_feat, num_feat, 3, 1, 1) 95 | self.conv_last = nn.Conv2D(num_feat, num_out_ch, 3, 1, 1) 96 | 97 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, ) 98 | 99 | def forward(self, x): 100 | if self.scale == 2: 101 | feat = pixel_unshuffle(x, scale=2) 102 | elif self.scale == 1: 103 | feat = pixel_unshuffle(x, scale=4) 104 | else: 105 | feat = x 106 | feat = self.conv_first(feat) 107 | body_feat = self.conv_body(self.body(feat)) 108 | feat = feat + body_feat 109 | # upsample 110 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 111 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 112 | if self.scale == 8: 113 | feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest'))) 114 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 115 | return out 116 | -------------------------------------------------------------------------------- /rudalle_paddle/ruclip/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import json 4 | import paddle 5 | 6 | from clip import CLIP 7 | 8 | 9 | class CLIPModel(CLIP): 10 | def encode_text(self, text): 11 | x = self.token_embedding(text) 12 | if x.shape[1] != self.context_length: 13 | x = paddle.concat([ 14 | x, 15 | paddle.zeros( 16 | [x.shape[0], self.context_length - x.shape[1], x.shape[2]], 17 | dtype=x.dtype 18 | ) 19 | ], 1) 20 | x = x + self.positional_embedding 21 | x = self.transformer(x) 22 | x = self.ln_final(x)[:, :text.shape[1]] 23 | 24 | select = [] 25 | index = zip( 26 | paddle.arange(x.shape[0]).numpy(), 27 | text.argmax(axis=-1).numpy() 28 | ) 29 | for i, j in index: 30 | select.append(x[int(i), int(j)]) 31 | 32 | x = paddle.stack(select) @ self.text_projection 33 | 34 | return x 35 | 36 | def forward(self, **kwargs): 37 | logits_per_image, logits_per_text = super(CLIPModel, self).forward( 38 | kwargs.get('pixel_values'), kwargs.get('input_ids')) 39 | outputs = type('LamdaCls', (), { 40 | 'logits_per_image': logits_per_image, 41 | 'logits_per_text': logits_per_text 42 | }) 43 | return outputs 44 | 45 | @classmethod 46 | def from_pretrained(cls, folder): 47 | with open(os.path.join(folder, 'config.json'), 'r', encoding='utf-8') as f: 48 | src_conf = json.load(f) 49 | dst_conf = { 50 | 'embed_dim': src_conf['projection_dim'], 51 | # vision 52 | 'image_resolution': src_conf['vision_config']['image_size'], 53 | 'vision_layers': src_conf['vision_config']['num_hidden_layers'], 54 | 'vision_width': src_conf['vision_config']['hidden_size'], 55 | 'vision_patch_size': src_conf['vision_config_dict']['patch_size'], 56 | # text 57 | 'context_length': src_conf['text_config']['max_position_embeddings'], 58 | 'vocab_size': 
src_conf['text_config']['vocab_size'], 59 | 'transformer_width': src_conf['text_config']['hidden_size'], 60 | 'transformer_heads': src_conf['text_config']['num_attention_heads'], 61 | 'transformer_layers': src_conf['text_config']['num_hidden_layers'], 62 | } 63 | obj = cls(**dst_conf) 64 | paddle_weights = os.path.join(folder, 'ruclip_paddle.pdparams') 65 | if not os.path.exists(paddle_weights): 66 | cls.convert(folder) 67 | obj.set_state_dict(paddle.load(paddle_weights)) 68 | return obj 69 | 70 | @staticmethod 71 | def convert(folder): 72 | import os 73 | import torch 74 | torch_weights = os.path.join(folder, 'pytorch_model.bin') 75 | target_model_path = os.path.join(folder, 'ruclip_paddle.pdparams') 76 | if os.path.exists(target_model_path): 77 | return 78 | state_dict = torch.load(torch_weights, map_location='cpu') 79 | 80 | name_pairs = [ 81 | ('text_model.embeddings.position_embedding.weight', 'positional_embedding', False), 82 | ('visual_projection.weight', 'visual.proj', True), 83 | ('text_projection.weight', 'text_projection', True), 84 | ('text_model.embeddings.token_embedding.weight', 'token_embedding.weight', False), 85 | ('logit_scale', 'logit_scale', False), 86 | ('vision_model.embeddings.class_embedding', 'visual.class_embedding', False), 87 | ('vision_model.embeddings.patch_embedding.weight', 'visual.conv1.weight', False), 88 | ('vision_model.embeddings.position_embedding.weight', 'visual.positional_embedding', False), 89 | ('vision_model.pre_layrnorm', 'visual.ln_pre', False), 90 | ('vision_model.encoder.layers', 'visual.transformer.resblocks', True), 91 | ('text_model.encoder.layers', 'transformer.resblocks', True), 92 | ('self_attn.k_proj', 'attn.k_proj', True), 93 | ('self_attn.v_proj', 'attn.v_proj', True), 94 | ('self_attn.q_proj', 'attn.q_proj', True), 95 | ('self_attn.out_proj', 'attn.out_proj', True), 96 | ('layer_norm1', 'ln_1', False), 97 | ('layer_norm2', 'ln_2', False), 98 | ('mlp.fc1', 'mlp.c_fc', True), 99 | ('mlp.fc2', 'mlp.c_proj', True), 100 | ('vision_model.post_layernorm', 'visual.ln_post', False), 101 | ('text_model.final_layer_norm', 'ln_final', False) 102 | ] 103 | exclude_names = [ 104 | 'text_model.embeddings.position_ids', 105 | 'vision_model.embeddings.position_ids' 106 | ] 107 | paddle_state_dict = {} 108 | for name, param in state_dict.items(): 109 | is_pair = False 110 | no_need_transpose = True 111 | if name in exclude_names: 112 | continue 113 | for pre_name, post_name, do_transpose in name_pairs: 114 | if pre_name in name: 115 | is_pair = True 116 | name = name.replace(pre_name, post_name) 117 | no_need_transpose = not do_transpose if no_need_transpose else False 118 | assert is_pair, f'Weight of {name} need to be converted.' 
119 | if not no_need_transpose and param.ndim == 2: 120 | param = param.transpose(1, 0) 121 | if param.ndim == 0: 122 | param = param.unsqueeze(0) 123 | param = param.cpu().detach().numpy() 124 | paddle_state_dict[name] = param 125 | 126 | paddle.save(paddle_state_dict, target_model_path) 127 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABC 3 | 4 | import paddle 5 | 6 | 7 | class LogitsWarper(ABC): 8 | """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" 9 | 10 | def __call__(self, input_ids, scores): 11 | """Paddle method for warping logits.""" 12 | raise NotImplementedError( 13 | f'{self.__class__} is an abstract class. Only classes inheriting this class can be called.' 14 | ) 15 | 16 | 17 | class TopPLogitsWarper(LogitsWarper): 18 | """ 19 | :class:`transformers.LogitsWarper` that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= 20 | prob_cut_off. 21 | Args: 22 | top_p (:obj:`float`): 23 | If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are 24 | kept for generation. 25 | filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`): 26 | All filtered values will be set to this float value. 27 | min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1): 28 | Minimum number of tokens that cannot be filtered. 29 | """ 30 | 31 | def __init__(self, top_p: float, filter_value: float = -float('Inf'), min_tokens_to_keep: int = 1): 32 | top_p = float(top_p) 33 | if top_p < 0 or top_p > 1.0: 34 | raise ValueError(f'`top_p` has to be a float > 0 and < 1, but is {top_p}') 35 | 36 | self.top_p = top_p 37 | self.filter_value = filter_value 38 | self.min_tokens_to_keep = min_tokens_to_keep 39 | 40 | def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor) -> paddle.Tensor: 41 | sorted_logits = paddle.sort(scores, descending=True) 42 | sorted_indices = paddle.argsort(scores, descending=True) 43 | cumulative_probs = sorted_logits.softmax(axis=-1).cumsum(axis=-1) 44 | 45 | # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) 46 | sorted_indices_to_remove = cumulative_probs > self.top_p 47 | if self.min_tokens_to_keep > 1: 48 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 49 | sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0 50 | # Shift the indices to the right to keep also the first token above the threshold 51 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 52 | sorted_indices_to_remove[..., 0] = 0 53 | 54 | # scatter sorted tensors to original indexing 55 | # unknow why package transformers use dim as `1` 56 | indices_to_remove = sorted_indices_to_remove.scatter_by_axis(-1, sorted_indices, sorted_indices_to_remove) 57 | scores = scores.masked_fill(indices_to_remove, self.filter_value) 58 | return scores 59 | 60 | 61 | class TopKLogitsWarper(LogitsWarper): 62 | r""" 63 | :class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements. 64 | Args: 65 | top_k (:obj:`int`): 66 | The number of highest probability vocabulary tokens to keep for top-k-filtering. 
67 | filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`): 68 | All filtered values will be set to this float value. 69 | min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1): 70 | Minimum number of tokens that cannot be filtered. 71 | """ 72 | 73 | def __init__(self, top_k: int, filter_value: float = -float('Inf'), min_tokens_to_keep: int = 1): 74 | if not isinstance(top_k, int) or top_k <= 0: 75 | raise ValueError(f'`top_k` has to be a strictly positive integer, but is {top_k}') 76 | 77 | self.top_k = top_k 78 | self.filter_value = filter_value 79 | self.min_tokens_to_keep = min_tokens_to_keep 80 | 81 | def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor) -> paddle.Tensor: 82 | top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check 83 | # Remove all tokens with a probability less than the last token of the top-k 84 | indices_to_remove = scores < paddle.topk(scores, top_k)[0][..., -1, None] 85 | scores = scores.masked_fill(indices_to_remove, self.filter_value) 86 | return scores 87 | 88 | 89 | def top_k_top_p_filtering( 90 | logits: paddle.Tensor, 91 | top_k: int = 0, 92 | top_p: float = 1.0, 93 | filter_value: float = -float('Inf'), 94 | min_tokens_to_keep: int = 1, 95 | ) -> paddle.Tensor: 96 | """ 97 | Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 98 | Args: 99 | logits: logits distribution shape (batch size, vocabulary size) 100 | top_k (:obj:`int`, `optional`, defaults to 0): 101 | If > 0, only keep the top k tokens with highest probability (top-k filtering) 102 | top_p (:obj:`float`, `optional`, defaults to 1.0): 103 | If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus 104 | filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 105 | min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1): 106 | Minimumber of tokens we keep per batch example in the output. 107 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 108 | """ 109 | if top_k > 0: 110 | logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( 111 | None, logits 112 | ) 113 | 114 | if 0 <= top_p <= 1.0: 115 | logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits) 116 | 117 | return logits 118 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/einops/parsing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import keyword 3 | import warnings 4 | from typing import List 5 | 6 | _ellipsis = '…' # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated 7 | 8 | 9 | class EinopsError(RuntimeError): 10 | """ Runtime error thrown by einops """ 11 | pass 12 | 13 | 14 | class AnonymousAxis(object): 15 | """Important thing: all instances of this class are not equal to each other """ 16 | 17 | def __init__(self, value: str): 18 | self.value = int(value) 19 | if self.value <= 1: 20 | if self.value == 1: 21 | raise EinopsError('No need to create anonymous axis of length 1. 
Report this as an issue') 22 | else: 23 | raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value)) 24 | 25 | def __repr__(self): 26 | return '{}-axis'.format(str(self.value)) 27 | 28 | 29 | class ParsedExpression: 30 | """ 31 | non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)') 32 | and keeps some information important for downstream 33 | """ 34 | 35 | def __init__(self, expression): 36 | self.has_ellipsis = False 37 | self.has_ellipsis_parenthesized = None 38 | self.identifiers = set() 39 | # that's axes like 2, 3 or 5. Axes with size 1 are exceptional and replaced with empty composition 40 | self.has_non_unitary_anonymous_axes = False 41 | # composition keeps structure of composite axes, see how different corner cases are handled in tests 42 | self.composition = [] 43 | if '.' in expression: 44 | if '...' not in expression: 45 | raise EinopsError('Expression may contain dots only inside ellipsis (...)') 46 | if str.count(expression, '...') != 1 or str.count(expression, '.') != 3: 47 | raise EinopsError( 48 | 'Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ') 49 | expression = expression.replace('...', _ellipsis) 50 | self.has_ellipsis = True 51 | 52 | bracket_group = None 53 | 54 | def add_axis_name(x): 55 | if x is not None: 56 | if x in self.identifiers: 57 | raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x)) 58 | if x == _ellipsis: 59 | self.identifiers.add(_ellipsis) 60 | if bracket_group is None: 61 | self.composition.append(_ellipsis) 62 | self.has_ellipsis_parenthesized = False 63 | else: 64 | bracket_group.append(_ellipsis) 65 | self.has_ellipsis_parenthesized = True 66 | else: 67 | is_number = str.isdecimal(x) 68 | if is_number and int(x) == 1: 69 | # handling the case of anonymous axis of length 1 70 | if bracket_group is None: 71 | self.composition.append([]) 72 | else: 73 | pass # no need to think about 1s inside parenthesis 74 | return 75 | is_axis_name, reason = self.check_axis_name(x, return_reason=True) 76 | if not (is_number or is_axis_name): 77 | raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason)) 78 | if is_number: 79 | x = AnonymousAxis(x) 80 | self.identifiers.add(x) 81 | if is_number: 82 | self.has_non_unitary_anonymous_axes = True 83 | if bracket_group is None: 84 | self.composition.append([x]) 85 | else: 86 | bracket_group.append(x) 87 | 88 | current_identifier = None 89 | for char in expression: 90 | if char in '() ': 91 | add_axis_name(current_identifier) 92 | current_identifier = None 93 | if char == '(': 94 | if bracket_group is not None: 95 | raise EinopsError('Axis composition is one-level (brackets inside brackets not allowed)') 96 | bracket_group = [] 97 | elif char == ')': 98 | if bracket_group is None: 99 | raise EinopsError('Brackets are not balanced') 100 | self.composition.append(bracket_group) 101 | bracket_group = None 102 | elif str.isalnum(char) or char in ['_', _ellipsis]: 103 | if current_identifier is None: 104 | current_identifier = char 105 | else: 106 | current_identifier += char 107 | else: 108 | raise EinopsError("Unknown character '{}'".format(char)) 109 | 110 | if bracket_group is not None: 111 | raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression)) 112 | add_axis_name(current_identifier) 113 | 114 | def flat_axes_order(self) -> List: 115 | result = [] 116 | for composed_axis in self.composition: 117 | assert isinstance(composed_axis, list), 
'does not work with ellipsis' 118 | for axis in composed_axis: 119 | result.append(axis) 120 | return result 121 | 122 | def has_composed_axes(self) -> bool: 123 | # this will ignore 1 inside brackets 124 | for axes in self.composition: 125 | if isinstance(axes, list) and len(axes) > 1: 126 | return True 127 | return False 128 | 129 | @staticmethod 130 | def check_axis_name(name: str, return_reason=False): 131 | """ 132 | Valid axes names are python identifiers except keywords, 133 | and additionally should not start or end with underscore 134 | """ 135 | if not str.isidentifier(name): 136 | result = False, 'not a valid python identifier' 137 | elif name[0] == '_' or name[-1] == '_': 138 | result = False, 'axis name should should not start or end with underscore' 139 | else: 140 | if keyword.iskeyword(name): 141 | warnings.warn('It is discouraged to use axes names that are keywords: {}'.format(name), RuntimeWarning) 142 | if name in ['axis']: 143 | warnings.warn("It is discouraged to use 'axis' as an axis name " 144 | 'and will raise an error in future', FutureWarning) 145 | result = True, None 146 | if return_reason: 147 | return result 148 | else: 149 | return result[0] 150 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/einops/layers/_weighted_einsum.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from einops import EinopsError 3 | from einops.parsing import ParsedExpression 4 | import warnings 5 | import string 6 | from ..einops import _product 7 | 8 | 9 | def _report_axes(axes: set, report_message: str): 10 | if len(axes) > 0: 11 | raise EinopsError(report_message.format(axes)) 12 | 13 | 14 | class WeightedEinsumMixin: 15 | def __init__(self, pattern, weight_shape, bias_shape=None, **axes_lengths): 16 | """ 17 | WeightedEinsum - Einstein summation with second argument being weight tensor. 18 | NB: it is an experimental API. RFC https://github.com/arogozhnikov/einops/issues/71 19 | 20 | Imagine taking einsum with two arguments, one of each input, and one - tensor with weights 21 | >>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight) 22 | 23 | This layer manages weights for you after a minor tweaking 24 | >>> WeightedEinsum('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out') 25 | But otherwise it is the same einsum. 26 | 27 | Simple linear layer with bias term (you have one like that in your framework) 28 | >>> WeightedEinsum('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20) 29 | Channel-wise multiplication (like one used in normalizations) 30 | >>> WeightedEinsum('t b c -> t b c', weight_shape='c', c=128) 31 | Separate dense layer within each head, no connection between different heads 32 | >>> WeightedEinsum('t b head cin -> t b head cout', weight_shape='head cin cout', ...) 33 | 34 | ... ah yes, you need to specify all dimensions of weight shape/bias shape in parameters. 35 | 36 | Good use cases: 37 | - when channel dimension is not last, use WeightedEinsum, not transposition 38 | - when need only within-group connections to reduce number of weights and computations 39 | - perfect as a part of sequential models 40 | 41 | Uniform He initialization is applied to weight tensor. 
42 | 43 | Parameters 44 | :param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output 45 | :param weight_shape: axes of weight. Tensor od this shape is created, stored, and optimized in a layer 46 | :param bias_shape: axes of bias added to output. 47 | :param axes_lengths: dimensions of weight tensor 48 | """ 49 | super().__init__() 50 | warnings.warn('WeightedEinsum is experimental feature. API can change in unpredictable and enjoyable ways', 51 | FutureWarning) 52 | self.pattern = pattern 53 | self.weight_shape = weight_shape 54 | self.bias_shape = bias_shape 55 | self.axes_lengths = axes_lengths 56 | 57 | left, right = pattern.split('->') 58 | left = ParsedExpression(left) 59 | right = ParsedExpression(right) 60 | weight = ParsedExpression(weight_shape) 61 | _report_axes( 62 | set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}), 63 | 'Unrecognized identifiers on the right side of WeightedEinsum {}' 64 | ) 65 | 66 | if left.has_ellipsis or right.has_ellipsis or weight.has_ellipsis: 67 | raise EinopsError('Ellipsis is not supported in WeightedEinsum (right now)') 68 | if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]): 69 | raise EinopsError('Anonymous axes (numbers) are not allowed in WeightedEinsum') 70 | if '(' in weight_shape or ')' in weight_shape: 71 | raise EinopsError('Parenthesis is not allowed in weight shape') 72 | # TODO implement this 73 | if '(' in pattern or ')' in pattern: 74 | raise EinopsError('Axis composition/decomposition are not yet supported in einsum') 75 | for axis in weight.identifiers: 76 | if axis not in axes_lengths: 77 | raise EinopsError('Dimension {} of weight should be specified'.format(axis)) 78 | _report_axes( 79 | set.difference(set(axes_lengths), {*left.identifiers, *weight.identifiers}), 80 | 'Axes {} are not used in pattern', 81 | ) 82 | _report_axes( 83 | set.difference(weight.identifiers, {*left.identifiers, *right.identifiers}), 84 | 'Weight axes {} are redundant' 85 | ) 86 | if len(weight.identifiers) == 0: 87 | warnings.warn('WeightedEinsum: weight has no dimensions (means multiplication by a number)') 88 | 89 | _weight_shape = [axes_lengths[axis] for axis, in weight.composition] 90 | # single output element is a combination of fan_in input elements 91 | _fan_in = _product([axes_lengths[axis] for axis, in weight.composition if axis not in right.identifiers]) 92 | if bias_shape is not None: 93 | if not isinstance(bias_shape, str): 94 | raise EinopsError('bias shape should be string specifying which axes bias depends on') 95 | bias = ParsedExpression(bias_shape) 96 | _report_axes( 97 | set.difference(bias.identifiers, right.identifiers), 98 | 'Bias axes {} not present in output' 99 | ) 100 | _report_axes( 101 | set.difference(bias.identifiers, set(axes_lengths)), 102 | 'Sizes not provided for bias axes {}', 103 | ) 104 | 105 | _bias_shape = [] 106 | for axes in right.composition: 107 | for axis in axes: 108 | if axis in bias.identifiers: 109 | _bias_shape.append(axes_lengths[axis]) 110 | else: 111 | _bias_shape.append(1) 112 | else: 113 | _bias_shape = None 114 | # _bias_input_size = None 115 | 116 | weight_bound = (3 / _fan_in) ** 0.5 117 | bias_bound = (1 / _fan_in) ** 0.5 118 | self._create_parameters(_weight_shape, weight_bound, _bias_shape, bias_bound) 119 | 120 | # rewrite einsum expression with single-letter latin identifiers so that each expression is 121 | mapping2letters = {*left.identifiers, *right.identifiers, *weight.identifiers} 122 | 
mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapping2letters)} 123 | 124 | def write_flat(axes: list): 125 | return ''.join(mapping2letters[axis] for axis in axes) 126 | 127 | self.einsum_pattern = '{},{}->{}'.format( 128 | write_flat(left.flat_axes_order()), 129 | write_flat(weight.flat_axes_order()), 130 | write_flat(right.flat_axes_order()), 131 | ) 132 | 133 | def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): 134 | """ Shape and implementations """ 135 | raise NotImplementedError('Should be defined in framework implementations') 136 | 137 | def __repr__(self): 138 | params = repr(self.pattern) 139 | params += ', ' + self.weight_shape 140 | for axis, length in self.axes_lengths.items(): 141 | params += ', {}={}'.format(axis, length) 142 | return '{}({})'.format(self.__class__.__name__, params) 143 | -------------------------------------------------------------------------------- /rudalle_paddle/future.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # add future ops for paddle 3 | import warnings 4 | 5 | import paddle 6 | import paddle.nn.functional as F 7 | 8 | from paddle import Tensor 9 | 10 | 11 | set_device = paddle.set_device 12 | paddle.set_device = lambda device: set_device(device.replace('cuda', 'gpu')) 13 | 14 | 15 | def to(self, place, dtype=None): 16 | if isinstance(place, str): 17 | if place == 'cpu': 18 | place = paddle.CPUPlace() 19 | elif place == 'cuda': 20 | place = paddle.CUDAPlace(0) 21 | elif 'cuda:' in place: 22 | place = paddle.CUDAPlace(int(place.split(':')[1])) 23 | out = self 24 | if isinstance(dtype, str): 25 | dtype = getattr(paddle, dtype) 26 | if dtype is not None and self.dtype != dtype: 27 | out = self.astype(dtype) 28 | if self.place._equals(place): 29 | return out 30 | out = paddle.to_tensor(out, place=place, stop_gradient=self.stop_gradient) 31 | if self.grad is not None: 32 | grad = self.grad.to(place, dtype) 33 | out._set_grad_ivar(grad) 34 | return out 35 | 36 | 37 | paddle.Tensor.to = to 38 | paddle.Tensor.cpu = lambda self: to(self, 'cpu') 39 | paddle.Tensor.cuda = lambda self: to(self, 'cuda') 40 | 41 | 42 | _layer_to = paddle.nn.Layer.to 43 | 44 | 45 | def layer_to(self, device=None, dtype=None, blocking=None): 46 | if isinstance(device, str): 47 | if device == 'cpu': 48 | place = paddle.CPUPlace() 49 | elif device == 'cuda': 50 | place = paddle.CUDAPlace(0) 51 | elif 'cuda:' in device: 52 | place = paddle.CUDAPlace(int(device.split(':')[1])) 53 | if isinstance(dtype, paddle.dtype): 54 | dtype = str(dtype).split('.')[-1] 55 | if self.parameters()[0].place._equals(place): 56 | device = None 57 | if self.parameters()[0].dtype == dtype: 58 | dtype = None 59 | if device is None and dtype is None: 60 | return self 61 | _layer_to(self, device, dtype, blocking) 62 | return self 63 | 64 | 65 | paddle.nn.Layer.to = layer_to 66 | 67 | 68 | def swapaxes(self, a, b): 69 | dims = list(range(self.ndim)) 70 | dims[a], dims[b] = dims[b], dims[a] 71 | 72 | return self.transpose(dims) 73 | 74 | 75 | paddle.Tensor.swapaxes = swapaxes 76 | 77 | 78 | def masked_fill(self, mask, value): 79 | y = paddle.full(self.shape, value, self.dtype) 80 | return paddle.where(mask, y, self) 81 | 82 | 83 | paddle.Tensor.masked_fill = masked_fill 84 | 85 | 86 | def pad_sequence(sequences, batch_first=False, padding_value=0): 87 | max_t = max([seq.shape[0] for seq in sequences]) 88 | _sequences = [] 89 | for seq in sequences: 90 | if max_t > seq.shape[0]: 
91 | pad_num = max_t - seq.shape[0] 92 | pads = paddle.to_tensor([padding_value] * pad_num, dtype=seq.dtype, place=seq.place) 93 | pads = pads.reshape([-1, *list(range(len(seq.shape[1:])))]) 94 | pads = pads.expand([1, *seq.shape[1:]]) 95 | seq = paddle.stack([seq, pads], 0 if batch_first else 1) 96 | else: 97 | seq = seq.unsqueeze(0 if batch_first else 1) 98 | _sequences.append(seq) 99 | 100 | return paddle.concat(_sequences, 0 if batch_first else 1) 101 | 102 | 103 | paddle.pad_sequence = pad_sequence 104 | 105 | 106 | def exponential_(self): 107 | eps = 1e-10 108 | U = paddle.rand(self.shape) 109 | out = -paddle.log(U + eps) + eps 110 | self[:] = out 111 | return self 112 | 113 | 114 | paddle.Tensor.exponential_ = exponential_ 115 | 116 | 117 | paddle.Tensor.softmax = lambda self, *args, **kwargs: F.softmax(self, *args, **kwargs) 118 | 119 | 120 | def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, axis: int = -1) -> Tensor: 121 | r""" 122 | Samples from the Gumbel-Softmax distribution (`Link 1`_ `Link 2`_) and optionally discretizes. 123 | 124 | Args: 125 | logits: `[..., num_features]` unnormalized log probabilities 126 | tau: non-negative scalar temperature 127 | hard: if ``True``, the returned samples will be discretized as one-hot vectors, 128 | but will be differentiated as if it is the soft sample in autograd 129 | axis (int): A dimension along which softmax will be computed. Default: -1. 130 | 131 | Returns: 132 | Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution. 133 | If ``hard=True``, the returned samples will be one-hot, otherwise they will 134 | be probability distributions that sum to 1 across `dim`. 135 | 136 | .. note:: 137 | This function is here for legacy reasons, may be removed from nn.Functional in the future. 138 | 139 | .. note:: 140 | The main trick for `hard` is to do `y_hard - y_soft.detach() + y_soft` 141 | 142 | It achieves two things: 143 | - makes the output value exactly one-hot 144 | (since we add then subtract y_soft value) 145 | - makes the gradient equal to y_soft gradient 146 | (since we strip all other gradients) 147 | 148 | Examples:: 149 | >>> logits = paddle.randn(20, 32) 150 | >>> # Sample soft categorical using reparametrization trick: 151 | >>> F.gumbel_softmax(logits, tau=1, hard=False) 152 | >>> # Sample hard categorical using "Straight-through" trick: 153 | >>> F.gumbel_softmax(logits, tau=1, hard=True) 154 | 155 | .. _Link 1: 156 | https://arxiv.org/abs/1611.00712 157 | .. _Link 2: 158 | https://arxiv.org/abs/1611.01144 159 | """ 160 | if eps != 1e-10: 161 | warnings.warn('`eps` parameter is deprecated and has no effect.') 162 | 163 | gumbels = ( 164 | -paddle.empty_like(logits, dtype=logits.dtype).exponential_().log() 165 | ) # ~Gumbel(0,1) 166 | gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) 167 | y_soft = gumbels.softmax(axis) 168 | 169 | if hard: 170 | # Straight through. 171 | index = y_soft.argmax(axis, keepdim=True) 172 | if axis < 0: 173 | axis = index.ndim + axis 174 | y_hard = F.one_hot(index, y_soft.shape[axis]).swapaxes(axis, -1)[..., 0] 175 | ret = y_hard - y_soft.detach() + y_soft 176 | else: 177 | # Reparametrization trick. 
178 | ret = y_soft 179 | return ret 180 | 181 | 182 | F.gumbel_softmax = gumbel_softmax 183 | 184 | 185 | def scatter_by_axis(self, axis, index, src): 186 | shapes = index.shape 187 | # Create additional indices 188 | grid_range = [paddle.arange(s) for s in shapes] 189 | grids = paddle.meshgrid(*grid_range) 190 | grids[axis] = index 191 | # Create final indices 192 | final_index = paddle.stack(grids, -1) 193 | # Get scatter-added tensor 194 | is_bool = src.dtype == paddle.bool 195 | if is_bool: 196 | src = src.astype(paddle.int64) 197 | scatter = paddle.scatter_nd(final_index, src, self.shape) 198 | if is_bool: 199 | scatter = scatter > 0 200 | 201 | return scatter 202 | 203 | 204 | paddle.Tensor.scatter_by_axis = scatter_by_axis 205 | 206 | pd_max_native = paddle.Tensor.max 207 | 208 | 209 | def pd_max(self, *args, **kwargs): 210 | return pd_max_native(self.to(self.place, 'float32'), *args, **kwargs).to(self.place, self.dtype) 211 | 212 | 213 | paddle.Tensor.max = pd_max 214 | 215 | # not support fp16 yet 216 | 217 | 218 | def layer_norm_forward(self, input): 219 | return F.layer_norm( 220 | input.to(input.place, 'float32'), 221 | normalized_shape=self._normalized_shape, 222 | weight=self.weight.to(input.place, 'float32'), 223 | bias=self.bias.to(input.place, 'float32'), 224 | epsilon=self._epsilon).to(input.place, input.dtype) 225 | 226 | 227 | paddle.nn.LayerNorm.forward = layer_norm_forward 228 | -------------------------------------------------------------------------------- /rudalle_paddle/realesrgan/arch_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | import paddle 4 | from paddle import nn as nn 5 | from paddle.nn import functional as F 6 | from paddle.nn import initializer as init 7 | from paddle.nn.layer.norm import _BatchNormBase 8 | 9 | 10 | @paddle.no_grad() 11 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): 12 | """Initialize network weights. 13 | Args: 14 | module_list (list[nn.Layer] | nn.Layer): Modules to be initialized. 15 | scale (float): Scale initialized weights, especially for residual 16 | blocks. Default: 1. 17 | bias_fill (float): The value to fill bias. Default: 0 18 | kwargs (dict): Other arguments for initialization function. 19 | """ 20 | if not isinstance(module_list, list): 21 | module_list = [module_list] 22 | for module in module_list: 23 | for m in module.sublayers(): 24 | if isinstance(m, nn.Conv2D): 25 | init.KaimingNormal(**kwargs)(m.weight) 26 | m.weight.set_value(m.weight * scale) 27 | if m.bias is not None: 28 | init.Constant(bias_fill)(m.bias) 29 | elif isinstance(m, nn.Linear): 30 | init.KaimingNormal(**kwargs)(m.weight) 31 | m.weight.set_value(m.weight * scale) 32 | if m.bias is not None: 33 | init.Constant(bias_fill)(m.bias) 34 | elif isinstance(m, _BatchNormBase): 35 | init.Constant(1)(m.weight) 36 | if m.bias is not None: 37 | init.Constant(bias_fill)(m.bias) 38 | 39 | 40 | def make_layer(basic_block, num_basic_block, **kwarg): 41 | """Make layers by stacking the same blocks. 42 | Args: 43 | basic_block (nn.module): nn.module class for basic block. 44 | num_basic_block (int): number of blocks. 45 | Returns: 46 | nn.Sequential: Stacked blocks in nn.Sequential. 47 | """ 48 | layers = [] 49 | for _ in range(num_basic_block): 50 | layers.append(basic_block(**kwarg)) 51 | return nn.Sequential(*layers) 52 | 53 | 54 | class ResidualBlockNoBN(nn.Layer): 55 | """Residual block without BN. 
56 |     It has a style of:
57 |         ---Conv-ReLU-Conv-+-
58 |          |________________|
59 |     Args:
60 |         num_feat (int): Channel number of intermediate features.
61 |             Default: 64.
62 |         res_scale (float): Residual scale. Default: 1.
63 |         pypaddle_init (bool): If set to True, use paddle default init,
64 |             otherwise, use default_init_weights. Default: False.
65 |     """
66 | 
67 |     def __init__(self, num_feat=64, res_scale=1, pypaddle_init=False):
68 |         super(ResidualBlockNoBN, self).__init__()
69 |         self.res_scale = res_scale
70 |         self.conv1 = nn.Conv2D(num_feat, num_feat, 3, 1, 1)  # bias is on by default (paddle's Conv2D has no `bias` kwarg)
71 |         self.conv2 = nn.Conv2D(num_feat, num_feat, 3, 1, 1)
72 |         self.relu = nn.ReLU()
73 | 
74 |         if not pypaddle_init:
75 |             default_init_weights([self.conv1, self.conv2], 0.1)
76 | 
77 |     def forward(self, x):
78 |         identity = x
79 |         out = self.conv2(self.relu(self.conv1(x)))
80 |         return identity + out * self.res_scale
81 | 
82 | 
83 | class Upsample(nn.Sequential):
84 |     """Upsample module.
85 |     Args:
86 |         scale (int): Scale factor. Supported scales: 2^n and 3.
87 |         num_feat (int): Channel number of intermediate features.
88 |     """
89 | 
90 |     def __init__(self, scale, num_feat):
91 |         m = []
92 |         if (scale & (scale - 1)) == 0:  # scale = 2^n
93 |             for _ in range(int(math.log(scale, 2))):
94 |                 m.append(nn.Conv2D(num_feat, 4 * num_feat, 3, 1, 1))
95 |                 m.append(nn.PixelShuffle(2))
96 |         elif scale == 3:
97 |             m.append(nn.Conv2D(num_feat, 9 * num_feat, 3, 1, 1))
98 |             m.append(nn.PixelShuffle(3))
99 |         else:
100 |             raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
101 |         super(Upsample, self).__init__(*m)
102 | 
103 | 
104 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
105 |     """Warp an image or feature map with optical flow.
106 |     Args:
107 |         x (Tensor): Tensor with size (n, c, h, w).
108 |         flow (Tensor): Tensor with size (n, h, w, 2); flow values are in pixels (not normalized).
109 |         interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
110 |         padding_mode (str): 'zeros' or 'border' or 'reflection'.
111 |             Default: 'zeros'.
112 |         align_corners (bool): In the original PyTorch implementation, the default
113 |             changed from align_corners=True (before torch 1.3) to align_corners=False
114 |             (torch 1.3 and later). Here, we use True as the default.
115 |     Returns:
116 |         Tensor: Warped image or feature map.
117 |     """
118 |     assert x.shape[-2:] == flow.shape[1:3]
119 |     _, _, h, w = x.shape
120 |     # create mesh grid
121 |     grid_y, grid_x = paddle.meshgrid(paddle.arange(0, h).astype(x.dtype), paddle.arange(0, w).astype(x.dtype))
122 |     grid = paddle.stack((grid_x, grid_y), 2).astype(paddle.float32)  # W(x), H(y), 2
123 |     grid.stop_gradient = True
124 | 
125 |     vgrid = grid + flow
126 |     # scale grid to [-1,1]
127 |     vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
128 |     vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
129 |     vgrid_scaled = paddle.stack((vgrid_x, vgrid_y), axis=3)
130 |     output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
131 | 
132 |     # TODO, what if align_corners=False
133 |     return output
134 | 
135 | 
136 | def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
137 |     """Resize a flow according to ratio or shape.
138 |     Args:
139 |         flow (Tensor): Precomputed flow. shape [N, 2, H, W].
140 |         size_type (str): 'ratio' or 'shape'.
141 |         sizes (list[int | float]): the ratio for resizing or the final output
142 |             shape.
143 |             1) The order of ratio should be [ratio_h, ratio_w].
For 144 | downsampling, the ratio should be smaller than 1.0 (i.e., ratio 145 | < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., 146 | ratio > 1.0). 147 | 2) The order of output_size should be [out_h, out_w]. 148 | interp_mode (str): The mode of interpolation for resizing. 149 | Default: 'bilinear'. 150 | align_corners (bool): Whether align corners. Default: False. 151 | Returns: 152 | Tensor: Resized flow. 153 | """ 154 | _, _, flow_h, flow_w = flow.shape 155 | if size_type == 'ratio': 156 | output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) 157 | elif size_type == 'shape': 158 | output_h, output_w = sizes[0], sizes[1] 159 | else: 160 | raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.') 161 | 162 | input_flow = flow.clone() 163 | ratio_h = output_h / flow_h 164 | ratio_w = output_w / flow_w 165 | input_flow[:, 0, :, :] *= ratio_w 166 | input_flow[:, 1, :, :] *= ratio_h 167 | resized_flow = F.interpolate( 168 | input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners) 169 | return resized_flow 170 | 171 | 172 | # TODO: may write a cpp file 173 | def pixel_unshuffle(x, scale): 174 | """ Pixel unshuffle. 175 | Args: 176 | x (Tensor): Input feature with shape (b, c, hh, hw). 177 | scale (int): Downsample ratio. 178 | Returns: 179 | Tensor: the pixel unshuffled feature. 180 | """ 181 | b, c, hh, hw = x.shape 182 | out_channel = c * (scale**2) 183 | assert hh % scale == 0 and hw % scale == 0 184 | h = hh // scale 185 | w = hw // scale 186 | x_view = x.reshape([b, c, h, scale, w, scale]) 187 | return x_view.transpose([0, 1, 3, 5, 2, 4]).reshape([b, out_channel, h, w]) 188 | -------------------------------------------------------------------------------- /rudalle_paddle/dalle/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import paddle 3 | import paddle.nn.functional as F 4 | from einops import rearrange 5 | 6 | from .utils import init_method_normal 7 | from .image_attention import get_conv_mask, get_row_mask, get_col_mask 8 | 9 | from .transformer import DalleTransformer 10 | 11 | 12 | class DalleModel(paddle.nn.Layer): 13 | def __init__(self, 14 | device, 15 | num_layers, 16 | vocab_size, 17 | hidden_size, 18 | num_attention_heads, 19 | embedding_dropout_prob, 20 | attention_dropout_prob, 21 | output_dropout_prob, 22 | text_seq_length=128, 23 | image_tokens_per_dim=32, 24 | image_vocab_size=16384, 25 | loss_img_weight=7, 26 | fp16=False, 27 | use_masks=True, 28 | cogview_sandwich_layernorm=False, 29 | cogview_pb_relax=False): 30 | 31 | super(DalleModel, self).__init__() 32 | self.device = device 33 | self.fp16 = fp16 34 | self.image_tokens_per_dim = image_tokens_per_dim 35 | self.image_seq_length = image_tokens_per_dim ** 2 36 | self.text_seq_length = text_seq_length 37 | self.total_seq_length = self.text_seq_length + self.image_seq_length 38 | self.total_vocab_size = vocab_size + image_vocab_size 39 | self.vocab_size = vocab_size 40 | self.loss_img_weight = loss_img_weight 41 | 42 | # TODO "to" 43 | mask_map = self.prepare_image_masks(num_layers, text_seq_length, image_tokens_per_dim) 44 | if use_masks: 45 | self._mask_map = mask_map 46 | else: 47 | self._mask_map = [] 48 | 49 | init_method = init_method_normal(std=0.02) 50 | 51 | self.text_embeddings = paddle.nn.Embedding(vocab_size, hidden_size) 52 | self.image_embeddings = paddle.nn.Embedding(image_vocab_size, hidden_size) 53 | 54 | # Position embedding (serial). 
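        # Illustrative note (not part of the original code): with the default
        # image_tokens_per_dim=32, a flat image position p in [0, 1024) is embedded as
        #     image_row_embeddings(p // 32) + image_col_embeddings(p % 32),
        # which is exactly what get_image_pos_embeddings() below computes.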
55 | self.text_pos_embeddings = paddle.nn.Embedding(text_seq_length + 1, hidden_size) 56 | self.image_row_embeddings = paddle.nn.Embedding(image_tokens_per_dim, hidden_size) 57 | self.image_col_embeddings = paddle.nn.Embedding(image_tokens_per_dim, hidden_size) 58 | init_method(self.text_pos_embeddings.weight) 59 | init_method(self.image_row_embeddings.weight) 60 | init_method(self.image_col_embeddings.weight) 61 | 62 | self.to_logits = paddle.nn.Sequential( 63 | paddle.nn.LayerNorm(hidden_size), 64 | paddle.nn.Linear(hidden_size, self.total_vocab_size), 65 | ) 66 | 67 | # Embeddings dropout 68 | self.embedding_dropout = paddle.nn.Dropout(embedding_dropout_prob) 69 | 70 | # Transformer 71 | self.transformer = DalleTransformer( 72 | num_layers, 73 | hidden_size, 74 | num_attention_heads, 75 | attention_dropout_prob, 76 | output_dropout_prob, 77 | cogview_sandwich_layernorm=cogview_sandwich_layernorm, 78 | cogview_pb_relax=cogview_pb_relax, 79 | ) 80 | self.transformer._mask_map = self._mask_map 81 | 82 | def get_param(self, item): 83 | return getattr(self, item) 84 | 85 | def prepare_image_masks(self, num_layers, text_seq_length, image_tokens_per_dim): 86 | row_mask = get_row_mask(text_seq_length, image_tokens_per_dim).to(self.device) 87 | col_mask = get_col_mask(text_seq_length, image_tokens_per_dim).to(self.device) 88 | conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim).to(self.device) 89 | # if self.fp16: 90 | # row_mask = row_mask.astype(paddle.float16) 91 | # col_mask = col_mask.astype(paddle.float16) 92 | # conv_mask = conv_mask.astype(paddle.float16) 93 | self.register_buffer('row_mask', row_mask) 94 | self.register_buffer('col_mask', col_mask) 95 | self.register_buffer('conv_mask', conv_mask) 96 | mask_map = [] 97 | for i in range(num_layers): 98 | if ((i - 1) % 4 == 0): 99 | mask_map.append(col_mask) 100 | elif i != num_layers - 1: 101 | mask_map.append(row_mask) 102 | else: 103 | mask_map.append(conv_mask) 104 | return mask_map 105 | 106 | def get_image_pos_embeddings(self, image_input_ids, past_length=0): 107 | input_shape = image_input_ids.shape 108 | row_ids = paddle.arange(past_length, input_shape[-1] + past_length, 109 | dtype=paddle.int64).to(self.device) // self.image_tokens_per_dim 110 | row_ids = row_ids.unsqueeze(0).reshape([-1, input_shape[-1]]) 111 | col_ids = paddle.arange(past_length, input_shape[-1] + past_length, 112 | dtype=paddle.int64).to(self.device) % self.image_tokens_per_dim 113 | col_ids = col_ids.unsqueeze(0).reshape([-1, input_shape[-1]]) 114 | return self.image_row_embeddings(row_ids) + self.image_col_embeddings(col_ids) 115 | 116 | def forward( 117 | self, 118 | input_ids, 119 | attention_mask, 120 | return_loss=False, 121 | has_cache=False, 122 | use_cache=False, 123 | ): 124 | text = input_ids[:, :self.text_seq_length] 125 | text_range = paddle.arange(self.text_seq_length) 126 | text_range += (self.vocab_size - self.text_seq_length) 127 | text_range = text_range.to(self.device) 128 | text = paddle.where(text == 0, text_range, text) 129 | # some hardcode :) 130 | text = F.pad(text[:, None, :], (1, 0), value=2, data_format='NCL')[:, 0] 131 | text_embeddings = self.text_embeddings(text) + \ 132 | self.text_pos_embeddings(paddle.arange(text.shape[1]).to(self.device)) 133 | 134 | image_input_ids = None 135 | 136 | if input_ids.shape[1] > self.text_seq_length: 137 | image_input_ids = input_ids[:, self.text_seq_length:] 138 | image_embeddings = self.image_embeddings(image_input_ids) + \ 139 | self.get_image_pos_embeddings(image_input_ids, 
past_length=0) 140 | embeddings = paddle.concat((text_embeddings, image_embeddings), axis=1) 141 | else: 142 | embeddings = text_embeddings 143 | # some hardcode :) 144 | if embeddings.shape[1] > self.total_seq_length: 145 | embeddings = embeddings[:, :-1] 146 | 147 | alpha = 0.1 148 | embeddings = embeddings * alpha + embeddings.detach() * (1-alpha) 149 | 150 | attention_mask = attention_mask[:, :, :embeddings.shape[1], :embeddings.shape[1]] 151 | transformer_output, present_has_cache = self.transformer( 152 | embeddings, attention_mask, has_cache=has_cache, use_cache=use_cache) 153 | 154 | logits = self.to_logits(transformer_output) 155 | if return_loss is False: 156 | return logits, present_has_cache 157 | 158 | labels = paddle.concat((text[:, 1:], image_input_ids), axis=1).astype(paddle.int64) 159 | logits = rearrange(logits, 'b n c -> b c n') 160 | 161 | text_logits = logits[:, :self.vocab_size, :self.text_seq_length].astype(paddle.float32) 162 | image_logits = logits[:, self.vocab_size:, self.text_seq_length:].astype(paddle.float32) 163 | 164 | loss_text = F.cross_entropy( 165 | text_logits.swapaxes(-1, -2), 166 | labels[:, :self.text_seq_length]) 167 | loss_img = F.cross_entropy( 168 | image_logits.swapaxes(-1, -2), 169 | labels[:, self.text_seq_length:]) 170 | 171 | loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1) 172 | return loss, { 173 | 'text': loss_text.detach().astype(paddle.float32), 174 | 'image': loss_img.detach().astype(paddle.float32) 175 | } 176 | 177 | def to(self, device, dtype=None, **kwargs): 178 | self.device = device 179 | self._mask_map = [mask.to(device, dtype) for mask in self._mask_map] 180 | self.transformer._mask_map = [mask.to(device, dtype) for mask in self.transformer._mask_map] 181 | return super().to(device, dtype, **kwargs) 182 | 183 | @staticmethod 184 | def convert(folder): 185 | import os 186 | import torch 187 | import pickle 188 | torch_weights = os.path.join(folder, 'pytorch_model.bin') 189 | target_model_path = os.path.join(folder, 'rudalle_paddle.pkl') 190 | if os.path.exists(target_model_path): 191 | return 192 | state_dict = torch.load(torch_weights, map_location='cpu') 193 | 194 | paddle_state_dict = {} 195 | for name, param in state_dict.items(): 196 | if param.ndim == 2 and '_mask' not in name and '_embeddings' not in name: 197 | param = param.transpose(1, 0) 198 | if param.ndim == 0: 199 | param = param.unsqueeze(0) 200 | param = param.cpu().detach().numpy() 201 | paddle_state_dict[name] = param 202 | 203 | with open(target_model_path, 'wb') as f: 204 | pickle.dump(paddle_state_dict, f, protocol=4) 205 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2020] [sberbank-ai] 190 | Copyright [2021] [Wu Hecong] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /rudalle_paddle/dalle/transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | 4 | import paddle 5 | from paddle.nn import LayerNorm 6 | 7 | from .utils import divide, split_tensor_along_last_dim 8 | 9 | 10 | def gelu_impl(x): 11 | """OpenAI's gelu implementation.""" 12 | return 0.5 * x * (1.0 + paddle.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) 13 | 14 | 15 | def gelu(x): 16 | # return gelu_impl(x) 17 | return paddle.nn.functional.gelu(x, approximate=True) 18 | 19 | 20 | class DalleTransformer(paddle.nn.Layer): 21 | """ 22 | This module takes input from embedding layer and it's output can 23 | be used directly by a logit layer. It consists of L (num-layers) 24 | blocks of: 25 | layer norm 26 | self attention 27 | residual connection 28 | layer norm 29 | mlp 30 | residual connection 31 | followed by a final layer norm. 32 | 33 | Arguments: 34 | num_layers: Number of transformer layers. 35 | hidden_size: The hidden size of the self attention. 36 | num_attention_heads: number of attention head in the self 37 | attention. 38 | attention_dropout_prob: dropout probability of the attention 39 | score in self attention. 40 | output_dropout_prob: dropout probability for the outputs 41 | after self attention and final output. 42 | layernorm_epsilon: epsilon used in layernorm to avoid 43 | division by zero. 44 | """ 45 | _mask_map = [] 46 | 47 | def __init__(self, num_layers, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, 48 | layernorm_epsilon=1.0e-5, cogview_sandwich_layernorm=False, cogview_pb_relax=False): 49 | super(DalleTransformer, self).__init__() 50 | 51 | # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf 52 | self.cogview_pb_relax = cogview_pb_relax 53 | 54 | # Transformer layers. 55 | self.layers = paddle.nn.LayerList([ 56 | DalleTransformerLayer( 57 | hidden_size, 58 | num_attention_heads, 59 | attention_dropout_prob, 60 | output_dropout_prob, 61 | layernorm_epsilon, 62 | cogview_sandwich_layernorm=cogview_sandwich_layernorm, 63 | cogview_pb_relax=cogview_pb_relax, 64 | ) for _ in range(num_layers) 65 | ]) 66 | 67 | # Final layer norm before output. 68 | self.final_layernorm = LayerNorm(hidden_size, epsilon=layernorm_epsilon) 69 | 70 | def forward(self, hidden_states, attention_mask, has_cache, use_cache): 71 | for i, layer in enumerate(self.layers): 72 | mask = attention_mask 73 | if len(self._mask_map): 74 | layer_mask = self._mask_map[i][:mask.shape[2], :mask.shape[3]] 75 | mask = paddle.multiply(attention_mask, layer_mask) 76 | hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache) 77 | output = self.final_layernorm(hidden_states) 78 | return output, present_has_cache 79 | 80 | 81 | class DalleTransformerLayer(paddle.nn.Layer): 82 | """ 83 | A single layer transformer. 
84 | 
85 |     We use the following notation:
86 |         h: hidden size
87 |         n: number of attention heads
88 |         b: batch size
89 |         s: sequence length
90 |     Transformer layer takes input with size [b, s, h] and returns an
91 |     output of the same size.
92 | 
93 |     Arguments:
94 |         hidden_size: The hidden size of the self attention.
95 |         num_attention_heads: number of attention heads in the self
96 |             attention.
97 |         attention_dropout_prob: dropout probability of the attention
98 |             score in self attention.
99 |         output_dropout_prob: dropout probability for the outputs
100 |             after self attention and final output.
101 |         layernorm_epsilon: epsilon used in layernorm to avoid
102 |             division by zero.
103 |     """
104 | 
105 |     def __init__(self,
106 |                  hidden_size,
107 |                  num_attention_heads,
108 |                  attention_dropout_prob,
109 |                  output_dropout_prob,
110 |                  layernorm_epsilon,
111 |                  cogview_sandwich_layernorm=False,
112 |                  cogview_pb_relax=False):
113 |         super(DalleTransformerLayer, self).__init__()
114 | 
115 |         # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
116 |         self.cogview_sandwich_layernorm = cogview_sandwich_layernorm
117 |         self.cogview_pb_relax = cogview_pb_relax
118 | 
119 |         # Layernorm on the input data.
120 |         self.input_layernorm = LayerNorm(hidden_size, epsilon=layernorm_epsilon)
121 | 
122 |         if self.cogview_sandwich_layernorm:
123 |             self.before_first_addition_layernorm = LayerNorm(hidden_size, epsilon=layernorm_epsilon)
124 |             self.before_second_addition_layernorm = LayerNorm(hidden_size, epsilon=layernorm_epsilon)
125 | 
126 |         # Self attention.
127 |         self.attention = DalleSelfAttention(
128 |             hidden_size,
129 |             num_attention_heads,
130 |             attention_dropout_prob,
131 |             output_dropout_prob,
132 |             cogview_pb_relax=cogview_pb_relax
133 |         )
134 | 
135 |         # Layernorm after the self attention.
136 |         self.post_attention_layernorm = LayerNorm(hidden_size, epsilon=layernorm_epsilon)
137 | 
138 |         # MLP
139 |         self.mlp = DalleMLP(hidden_size, output_dropout_prob)
140 | 
141 |     def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
142 |         # hidden_states: [b, s, h]
143 |         # ltor_mask: [1, 1, s, s]
144 | 
145 |         # Layer norm at the beginning of the transformer layer.
146 |         layernorm_output = self.input_layernorm(hidden_states)
147 | 
148 |         # Self attention.
149 |         attention_output, att_has_cache = self.attention(
150 |             layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache)
151 | 
152 |         if self.cogview_sandwich_layernorm:
153 |             attention_output = self.before_first_addition_layernorm(attention_output)
154 | 
155 |         # Residual connection.
156 |         layernorm_input = hidden_states + attention_output
157 | 
158 |         # Layer norm post the self attention.
159 |         layernorm_output = self.post_attention_layernorm(layernorm_input)
160 | 
161 |         # MLP.
162 |         mlp_output, mlp_has_cache = self.mlp(
163 |             layernorm_output, has_cache=has_cache, use_cache=use_cache)
164 | 
165 |         if self.cogview_sandwich_layernorm:
166 |             mlp_output = self.before_second_addition_layernorm(mlp_output)
167 | 
168 |         # Second residual connection.
169 |         output = layernorm_input + mlp_output
170 | 
171 |         return output, att_has_cache and mlp_has_cache
172 | 
173 | 
174 | class DalleSelfAttention(paddle.nn.Layer):
175 |     """
176 |     Self-attention layer takes input with size [b, s, h] where b is
177 |     the batch size, s is the sequence length, and h is the hidden size
178 |     and creates output of the same size.
179 |     Arguments:
180 |         hidden_size: total hidden size of the layer (h).
181 |         num_attention_heads: number of attention heads (n).
Note that we 182 | require n to be divisible by number of GPUs 183 | used to parallelize the model. Also, we 184 | require hidden size to be divisible by n. 185 | attention_dropout_prob: dropout probability for the attention scores. 186 | output_dropout_prob: dropout probability for the output. 187 | We use the following notation: 188 | h: hidden_size 189 | n: num_attention_heads 190 | p: number of partitions 191 | np: n/p 192 | hp: h/p 193 | hn: h/n 194 | b: batch size 195 | s: sequence length 196 | """ 197 | 198 | def __init__(self, hidden_size, num_attention_heads, 199 | attention_dropout_prob, output_dropout_prob, cogview_pb_relax=False): 200 | super(DalleSelfAttention, self).__init__() 201 | 202 | # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf 203 | self.cogview_pb_relax = cogview_pb_relax 204 | 205 | self.hidden_size = hidden_size 206 | self.num_attention_heads = num_attention_heads 207 | self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads) 208 | 209 | self.query_key_value = paddle.nn.Linear(hidden_size, 3*hidden_size) 210 | self.attention_dropout = paddle.nn.Dropout(attention_dropout_prob) 211 | 212 | # Output. 213 | self.dense = paddle.nn.Linear(hidden_size, hidden_size) 214 | self.output_dropout = paddle.nn.Dropout(output_dropout_prob) 215 | 216 | # Cache 217 | self.past_key = None 218 | self.past_value = None 219 | self.past_output = None 220 | 221 | def _transpose_for_scores(self, tensor): 222 | """ Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """ 223 | new_tensor_shape = tensor.shape[:-1] + [self.num_attention_heads, self.hidden_size_per_attention_head] 224 | tensor = tensor.reshape(new_tensor_shape) 225 | return tensor.transpose([0, 2, 1, 3]) 226 | 227 | def _calculate_attention_scores(self, query_layer, key_layer, ltor_mask): 228 | key_t = key_layer.swapaxes(-1, -2) 229 | if self.cogview_pb_relax: 230 | attention_scores = paddle.matmul( 231 | query_layer / math.sqrt(self.hidden_size_per_attention_head), 232 | key_t 233 | ) 234 | else: 235 | attention_scores = paddle.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head) 236 | ltor_mask = ltor_mask[:, :, -attention_scores.shape[-2]:] 237 | attention_scores = paddle.multiply(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask) 238 | if self.cogview_pb_relax: 239 | # normalize attention scores. Should not affect resulting softmax value 240 | alpha = 32 241 | attention_scores_scaled = attention_scores / alpha 242 | attention_scores_scaled_maxes = attention_scores_scaled.detach().reshape( 243 | [attention_scores.shape[0], attention_scores.shape[1], -1] 244 | ).max(axis=-1) # max per head per sample 245 | attention_scores_scaled_maxes = attention_scores_scaled_maxes.unsqueeze(-1).unsqueeze(-1).expand( 246 | [-1, -1, attention_scores.shape[2], attention_scores.shape[3]] 247 | ) # expand to [b, np, s, s] 248 | attention_scores = (attention_scores_scaled - attention_scores_scaled_maxes) * alpha 249 | return attention_scores 250 | 251 | def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,): 252 | # hidden_states: [b, s, h] 253 | # ltor_mask: [1, 1, s, s] 254 | # Attention heads. 
[b, s, hp] 255 | if has_cache and use_cache: 256 | mixed_x_layer = self.query_key_value(hidden_states[:, -1:, :]) 257 | else: 258 | mixed_x_layer = self.query_key_value(hidden_states) 259 | 260 | (mixed_query_layer, 261 | mixed_key_layer, 262 | mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) 263 | 264 | query_layer = self._transpose_for_scores(mixed_query_layer) 265 | key_layer = self._transpose_for_scores(mixed_key_layer) 266 | value_layer = self._transpose_for_scores(mixed_value_layer) 267 | 268 | # Can be simplified, but I didn't for readability's sake 269 | if use_cache and has_cache: 270 | key_layer = paddle.concat((self.past_key, key_layer), axis=-2) 271 | value_layer = paddle.concat((self.past_value, value_layer), axis=-2) 272 | attention_scores = self._calculate_attention_scores( 273 | query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask 274 | ) 275 | else: 276 | attention_scores = self._calculate_attention_scores( 277 | query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask 278 | ) 279 | 280 | if use_cache: 281 | self.past_key = key_layer 282 | self.past_value = value_layer 283 | else: 284 | self.past_key = None 285 | self.past_value = None 286 | self.past_output = None 287 | has_cache = False 288 | 289 | if use_cache and has_cache: 290 | attention_scores = attention_scores[..., -1:, :] 291 | 292 | # Attention probabilities. [b, np, s, s] 293 | attention_probs = paddle.nn.Softmax(axis=-1)(attention_scores) 294 | 295 | # This is actually dropping out entire tokens to attend to, which might 296 | # seem a bit unusual, but is taken from the original Transformer paper. 297 | attention_probs = self.attention_dropout(attention_probs) 298 | 299 | # Context layer. 300 | # [b, np, s, hn] 301 | context_layer = paddle.matmul(attention_probs, value_layer) 302 | 303 | # [b, s, np, hn] 304 | context_layer = context_layer.transpose([0, 2, 1, 3]) 305 | 306 | new_context_layer_shape = context_layer.shape[:-2] + [self.hidden_size, ] 307 | # [b, s, hp] 308 | context_layer = context_layer.reshape(new_context_layer_shape) 309 | 310 | # Output. [b, s, h] 311 | output = self.dense(context_layer) 312 | 313 | if use_cache: 314 | # Can be simplified, but I didn't for readability's sake 315 | if has_cache: 316 | output = paddle.concat((self.past_output, output), axis=-2) 317 | self.past_output = output 318 | else: 319 | self.past_output = output 320 | has_cache = True 321 | 322 | output = self.output_dropout(output) 323 | return output, has_cache 324 | 325 | 326 | class DalleMLP(paddle.nn.Layer): 327 | """ 328 | MLP will take the input with h hidden state, project it to 4*h 329 | hidden dimension, perform gelu transformation, and project the 330 | state back into h hidden dimension. At the end, dropout is also 331 | applied. 332 | Arguments: 333 | hidden_size: The hidden size of the self attention. 334 | output_dropout_prob: dropout probability for the outputs 335 | after self attention and final output. 336 | """ 337 | 338 | def __init__(self, hidden_size, output_dropout_prob): 339 | super(DalleMLP, self).__init__() 340 | # Project to 4h. 341 | self.dense_h_to_4h = paddle.nn.Linear(hidden_size, 4*hidden_size) 342 | # Project back to h. 
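        # (Illustrative note, not in the original code: together with forward() below this
        #  block computes dropout(dense_4h_to_h(gelu(dense_h_to_4h(x)))),
        #  i.e. h -> 4h -> gelu -> h -> dropout.)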
343 | self.dense_4h_to_h = paddle.nn.Linear(4*hidden_size, hidden_size) 344 | self.dropout = paddle.nn.Dropout(output_dropout_prob) 345 | # MLP cache 346 | self.past_x = None 347 | 348 | def forward(self, hidden_states, has_cache=False, use_cache=False): 349 | if has_cache and use_cache: 350 | hidden_states = hidden_states[:, -1:] 351 | 352 | # [b, s, 4hp] 353 | x = self.dense_h_to_4h(hidden_states) 354 | x = gelu(x) 355 | # [b, s, h] 356 | x = self.dense_4h_to_h(x) 357 | if use_cache: 358 | # Can be simplified, but I didn't for readability's sake 359 | if has_cache: 360 | x = paddle.concat((self.past_x, x), axis=-2) 361 | self.past_x = x 362 | else: 363 | self.past_x = x 364 | 365 | has_cache = True 366 | else: 367 | self.past_x = None 368 | has_cache = False 369 | output = self.dropout(x) 370 | 371 | return output, has_cache 372 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/einops/_backends.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Backends in `einops` are organized to meet the following requirements 4 | - backends are not imported unless those are actually needed, because 5 | - backends may not be installed 6 | - importing all available backends will drive to significant memory footprint 7 | - backends may by present but installed with errors (but never used), 8 | importing may drive to crashes 9 | - backend should be either symbolic or imperative (tensorflow is for both, but that causes problems) 10 | - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined 11 | - if backend can't (temporarily) provide symbols for shape dimensions, UnknownSize objects are used 12 | """ 13 | 14 | import sys 15 | import warnings 16 | 17 | __author__ = 'Alex Rogozhnikov' 18 | 19 | _backends = {} 20 | _debug_importing = False 21 | 22 | 23 | def get_backend(tensor) -> 'AbstractBackend': 24 | """ 25 | Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor. 26 | If needed, imports package and creates backend 27 | """ 28 | for framework_name, backend in _backends.items(): 29 | if backend.is_appropriate_type(tensor): 30 | return backend 31 | 32 | # Find backend subclasses recursively 33 | backend_subclasses = [] 34 | backends = AbstractBackend.__subclasses__() 35 | while backends: 36 | backend = backends.pop() 37 | backends += backend.__subclasses__() 38 | backend_subclasses.append(backend) 39 | 40 | for BackendSubclass in backend_subclasses: 41 | if _debug_importing: 42 | print('Testing for subclass of ', BackendSubclass) 43 | if BackendSubclass.framework_name not in _backends: 44 | # check that module was already imported. Otherwise it can't be imported 45 | if BackendSubclass.framework_name in sys.modules: 46 | if _debug_importing: 47 | print('Imported backend for ', BackendSubclass.framework_name) 48 | backend = BackendSubclass() 49 | _backends[backend.framework_name] = backend 50 | if backend.is_appropriate_type(tensor): 51 | return backend 52 | 53 | raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor))) 54 | 55 | 56 | class AbstractBackend: 57 | """ Base backend class, major part of methods are only for debugging purposes. 
""" 58 | framework_name = None 59 | 60 | def is_appropriate_type(self, tensor): 61 | """ helper method should recognize tensors it can handle """ 62 | raise NotImplementedError() 63 | 64 | def from_numpy(self, x): 65 | raise NotImplementedError("framework doesn't support imperative execution") 66 | 67 | def to_numpy(self, x): 68 | raise NotImplementedError("framework doesn't support imperative execution") 69 | 70 | def create_symbol(self, shape): 71 | raise NotImplementedError("framework doesn't support symbolic computations") 72 | 73 | def eval_symbol(self, symbol, input_dict): 74 | raise NotImplementedError("framework doesn't support symbolic computations") 75 | 76 | def arange(self, start, stop): 77 | # supplementary method used only in testing, so should implement CPU version 78 | raise NotImplementedError("framework doesn't implement arange") 79 | 80 | def shape(self, x): 81 | """shape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)""" 82 | return x.shape 83 | 84 | def reshape(self, x, shape): 85 | return x.reshape(shape) 86 | 87 | def transpose(self, x, axes): 88 | return x.transpose(axes) 89 | 90 | def reduce(self, x, operation, axes): 91 | return getattr(x, operation)(axis=axes) 92 | 93 | def stack_on_zeroth_dimension(self, tensors: list): 94 | raise NotImplementedError() 95 | 96 | def add_axis(self, x, new_position): 97 | raise NotImplementedError() 98 | 99 | def add_axes(self, x, n_axes, pos2len): 100 | repeats = [1] * n_axes 101 | for axis_position, axis_length in pos2len.items(): 102 | x = self.add_axis(x, axis_position) 103 | repeats[axis_position] = axis_length 104 | return self.tile(x, tuple(repeats)) 105 | 106 | def tile(self, x, repeats): 107 | """repeats is a number of """ 108 | raise NotImplementedError() 109 | 110 | def is_float_type(self, x): 111 | # some backends (torch) can't compute average for non-floating types. 
112 | # Decided to drop average for all backends if type is not floating 113 | raise NotImplementedError() 114 | 115 | def layers(self): 116 | raise NotImplementedError('backend does not provide layers') 117 | 118 | def __repr__(self): 119 | return ''.format(self.framework_name) 120 | 121 | 122 | class UnknownSize: 123 | """ pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements """ 124 | 125 | def __floordiv__(self, other): 126 | return self 127 | 128 | def __eq__(self, other): 129 | return True # we don't know actual size 130 | 131 | def __mul__(self, other): 132 | return self 133 | 134 | def __rmul__(self, other): 135 | return self 136 | 137 | def __hash__(self): 138 | return None.__hash__() 139 | 140 | 141 | class NumpyBackend(AbstractBackend): 142 | framework_name = 'numpy' 143 | 144 | def __init__(self): 145 | import numpy 146 | self.np = numpy 147 | 148 | def is_appropriate_type(self, tensor): 149 | return isinstance(tensor, self.np.ndarray) 150 | 151 | def from_numpy(self, x): 152 | return x 153 | 154 | def to_numpy(self, x): 155 | return x 156 | 157 | def arange(self, start, stop): 158 | return self.np.arange(start, stop) 159 | 160 | def stack_on_zeroth_dimension(self, tensors: list): 161 | return self.np.stack(tensors) 162 | 163 | def tile(self, x, repeats): 164 | return self.np.tile(x, repeats) 165 | 166 | def is_float_type(self, x): 167 | return x.dtype in ('float16', 'float32', 'float64', 'float128') 168 | 169 | def add_axis(self, x, new_position): 170 | return self.np.expand_dims(x, new_position) 171 | 172 | 173 | class JaxBackend(NumpyBackend): 174 | framework_name = 'jax' 175 | 176 | def __init__(self): 177 | super(JaxBackend, self).__init__() 178 | self.onp = self.np 179 | 180 | import jax.numpy 181 | self.np = jax.numpy 182 | 183 | def from_numpy(self, x): 184 | return self.np.asarray(x) 185 | 186 | def to_numpy(self, x): 187 | return self.onp.asarray(x) 188 | 189 | 190 | class GluonBackend(AbstractBackend): 191 | framework_name = 'mxnet.ndarray' 192 | 193 | def __init__(self): 194 | import mxnet 195 | self.mx = mxnet 196 | 197 | def is_appropriate_type(self, tensor): 198 | return isinstance(tensor, self.mx.nd.NDArray) 199 | 200 | def from_numpy(self, x): 201 | if len(x.shape) == 0: 202 | x = x[None] # poor support of scalars in mxnet, otherwise mxnet can't attach gradients 203 | var = self.mx.nd.array(x, dtype=x.dtype) 204 | var.attach_grad() 205 | return var 206 | 207 | def to_numpy(self, x): 208 | return self.mx.nd.NDArray.asnumpy(x) 209 | 210 | def reshape(self, x, shape): 211 | if len(shape) == 0: 212 | return x # poor support of scalars in mxnet 213 | return x.reshape(shape) 214 | 215 | def arange(self, start, stop): 216 | return self.mx.nd.arange(start, stop) 217 | 218 | def stack_on_zeroth_dimension(self, tensors: list): 219 | return self.mx.nd.stack(*tensors) 220 | 221 | def tile(self, x, repeats): 222 | return self.mx.nd.tile(x, repeats) 223 | 224 | def add_axis(self, x, new_position): 225 | return self.mx.nd.expand_dims(x, new_position) 226 | 227 | def is_float_type(self, x): 228 | return 'float' in str(x.dtype) 229 | 230 | def layers(self): 231 | from .layers import gluon 232 | return gluon 233 | 234 | 235 | class MXNetBackend(AbstractBackend): 236 | framework_name = 'mxnet.symbol' 237 | 238 | def __init__(self): 239 | import mxnet 240 | self.mx = mxnet 241 | 242 | def is_appropriate_type(self, tensor): 243 | return isinstance(tensor, self.mx.symbol.Symbol) 244 | 245 | def create_symbol(self, shape, dtype='float32'): 246 | # 
mxnet accepts zeros as undefined dimensions 247 | shape = tuple(0 if d is None else d for d in shape) 248 | var = self.mx.symbol.Variable('input', shape=shape, dtype=dtype) 249 | return var 250 | 251 | def eval_symbol(self, symbol, input_dict): 252 | args = {var.name: self.mx.nd.array(val) for var, val in input_dict} 253 | ex = symbol.bind(ctx=self.mx.cpu(), args=args) 254 | ex.forward() 255 | return ex.outputs[0].asnumpy() 256 | 257 | def shape(self, x): 258 | # mxnet has problems with shape inference - it does not provide shape symbols 259 | # shape_array seems to be impossible to use in shape inference 260 | # infer_shape_partial returns empty tuple if was not able to infer shape 261 | # reductions such as sum can't return scalars, but return 1-element vectors 262 | shape = x.infer_shape_partial()[1][0] 263 | if len(shape) == 0: 264 | warnings.warn('mxnet inferred shape to be (), which probably means it could not be inferred') 265 | shape = tuple(UnknownSize() if d == 0 else d for d in shape) 266 | return shape 267 | 268 | def reshape(self, x, shape): 269 | if len(shape) == 0: 270 | return x # poor support of scalars in mxnet 271 | if any(isinstance(dimension, UnknownSize) for dimension in shape): 272 | from einops import EinopsError 273 | raise EinopsError("Mxnet couldn't infer all dimensions statically, please provide those with axes_lengths") 274 | return x.reshape(shape) 275 | 276 | def arange(self, start, stop): 277 | return self.mx.symbol.arange(start, stop) 278 | 279 | def stack_on_zeroth_dimension(self, tensors: list): 280 | return self.mx.symbol.stack(*tensors) 281 | 282 | def tile(self, x, repeats): 283 | return self.mx.symbol.tile(x, repeats) 284 | 285 | def add_axis(self, x, new_position): 286 | return self.mx.symbol.expand_dims(x, new_position) 287 | 288 | def is_float_type(self, x): 289 | return 'float' in str(x.infer_type()[1][0]) 290 | 291 | def layers(self): 292 | from .layers import gluon 293 | return gluon 294 | 295 | 296 | class TorchBackend(AbstractBackend): 297 | framework_name = 'torch' 298 | 299 | def __init__(self): 300 | import torch 301 | self.torch = torch 302 | 303 | def is_appropriate_type(self, tensor): 304 | return isinstance(tensor, self.torch.Tensor) 305 | 306 | def from_numpy(self, x): 307 | variable = self.torch.from_numpy(x) 308 | if self.is_float_type(variable): 309 | # attach grad only to floating types 310 | variable.requires_grad = True 311 | return variable 312 | 313 | def to_numpy(self, x): 314 | return x.detach().cpu().numpy() 315 | 316 | def arange(self, start, stop): 317 | return self.torch.arange(start, stop, dtype=self.torch.int64) 318 | 319 | def reduce(self, x, operation, reduced_axes): 320 | for axis in sorted(reduced_axes, reverse=True): 321 | if operation == 'min': 322 | x, _ = x.min(dim=axis) 323 | elif operation == 'max': 324 | x, _ = x.max(dim=axis) 325 | elif operation in ['sum', 'mean', 'prod']: 326 | x = getattr(x, operation)(dim=axis) 327 | else: 328 | raise NotImplementedError('Unknown reduction ', operation) 329 | return x 330 | 331 | def transpose(self, x, axes): 332 | return x.permute(axes) 333 | 334 | def stack_on_zeroth_dimension(self, tensors: list): 335 | return self.torch.stack(tensors) 336 | 337 | def tile(self, x, repeats): 338 | return x.repeat(repeats) 339 | 340 | def add_axis(self, x, new_position): 341 | return self.torch.unsqueeze(x, new_position) 342 | 343 | def is_float_type(self, x): 344 | return x.dtype in [self.torch.float16, self.torch.float32, self.torch.float64] 345 | 346 | def layers(self): 347 | from 
.layers import torch 348 | return torch 349 | 350 | 351 | class CupyBackend(AbstractBackend): 352 | framework_name = 'cupy' 353 | 354 | def __init__(self): 355 | import cupy 356 | self.cupy = cupy 357 | 358 | def is_appropriate_type(self, tensor): 359 | return isinstance(tensor, self.cupy.ndarray) 360 | 361 | def from_numpy(self, x): 362 | return self.cupy.asarray(x) 363 | 364 | def to_numpy(self, x): 365 | return self.cupy.asnumpy(x) 366 | 367 | def arange(self, start, stop): 368 | return self.cupy.arange(start, stop) 369 | 370 | def stack_on_zeroth_dimension(self, tensors: list): 371 | return self.cupy.stack(tensors) 372 | 373 | def tile(self, x, repeats): 374 | return self.cupy.tile(x, repeats) 375 | 376 | def add_axis(self, x, new_position): 377 | return self.cupy.expand_dims(x, new_position) 378 | 379 | def is_float_type(self, x): 380 | return x.dtype in ('float16', 'float32', 'float64', 'float128') 381 | 382 | 383 | class ChainerBackend(AbstractBackend): 384 | framework_name = 'chainer' 385 | 386 | def __init__(self): 387 | import chainer 388 | import numpy 389 | self.numpy = numpy 390 | self.chainer = chainer 391 | 392 | def is_appropriate_type(self, tensor): 393 | return isinstance(tensor, self.chainer.Variable) 394 | 395 | def from_numpy(self, x): 396 | return self.chainer.Variable(x.astype('float32')) 397 | 398 | def to_numpy(self, x): 399 | if isinstance(x, self.chainer.Variable): 400 | x = x.data 401 | return x 402 | 403 | def arange(self, start, stop): 404 | return self.numpy.arange(start, stop) 405 | 406 | def reduce(self, x, operation, axes): 407 | return getattr(self.chainer.functions, operation)(x, axis=axes) 408 | 409 | def stack_on_zeroth_dimension(self, tensors: list): 410 | return self.chainer.functions.stack(tensors) 411 | 412 | def tile(self, x, repeats): 413 | return self.chainer.functions.tile(x, repeats) 414 | 415 | def add_axis(self, x, new_position): 416 | return self.chainer.functions.expand_dims(x, new_position) 417 | 418 | def is_float_type(self, x): 419 | return x.dtype in ('float16', 'float32', 'float64', 'float128') 420 | 421 | def layers(self): 422 | from .layers import chainer 423 | return chainer 424 | 425 | 426 | class HashableTuple: 427 | """Overcomes non-hashability of symbolic elements""" 428 | 429 | def __init__(self, elements: tuple): 430 | self.elements = elements 431 | 432 | def __iter__(self): 433 | for x in self.elements: 434 | yield x 435 | 436 | def __len__(self): 437 | return len(self.elements) 438 | 439 | def __getitem__(self, item): 440 | return self.elements[item] 441 | 442 | 443 | class TensorflowBackend(AbstractBackend): 444 | framework_name = 'tensorflow' 445 | 446 | def __init__(self): 447 | import tensorflow 448 | self.tf = tensorflow 449 | 450 | def is_appropriate_type(self, tensor): 451 | return isinstance(tensor, (self.tf.Tensor, self.tf.Variable)) 452 | 453 | def from_numpy(self, x): 454 | assert self.tf.executing_eagerly() 455 | return self.tf.convert_to_tensor(x) 456 | 457 | def to_numpy(self, x): 458 | assert self.tf.executing_eagerly() 459 | return x.numpy() 460 | 461 | def arange(self, start, stop): 462 | return self.tf.range(start, stop) 463 | 464 | def shape(self, x): 465 | if self.tf.executing_eagerly(): 466 | return tuple(UnknownSize() if d is None else int(d) for d in x.shape) 467 | else: 468 | static_shape = x.shape.as_list() 469 | tf_shape = self.tf.shape(x) 470 | # use the static shape where known, otherwise use the TF shape components 471 | shape = tuple([s or tf_shape[dim] for dim, s in enumerate(static_shape)]) 
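            # e.g. (illustrative): a placeholder with static shape [None, 128] becomes
            # (tf.shape(x)[0], 128) here, i.e. a dynamic tensor where the size is unknown
            # and a plain Python int where it is known.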
472 | try: 473 | hash(shape) 474 | return shape 475 | except Exception as _: 476 | _ = _ 477 | # unhashable symbols in shape. Wrap tuple to be hashable. 478 | return HashableTuple(shape) 479 | 480 | def reduce(self, x, operation, axes): 481 | return getattr(self.tf, 'reduce_' + operation)(x, axis=axes) 482 | 483 | def reshape(self, x, shape): 484 | return self.tf.reshape(x, shape) 485 | 486 | def transpose(self, x, axes): 487 | return self.tf.transpose(x, axes) 488 | 489 | def stack_on_zeroth_dimension(self, tensors: list): 490 | return self.tf.stack(tensors) 491 | 492 | def tile(self, x, repeats): 493 | return self.tf.tile(x, repeats) 494 | 495 | def add_axis(self, x, new_position): 496 | return self.tf.expand_dims(x, new_position) 497 | 498 | def is_float_type(self, x): 499 | return x.dtype in ('float16', 'float32', 'float64', 'float128') 500 | 501 | def layers(self): 502 | from .layers import tensorflow 503 | return tensorflow 504 | 505 | 506 | class KerasBackend(AbstractBackend): 507 | framework_name = 'tensorflow.keras' 508 | 509 | def __init__(self): 510 | import tensorflow as tf 511 | self.tf = tf 512 | self.keras = tf.keras 513 | self.K = tf.keras.backend 514 | 515 | def is_appropriate_type(self, tensor): 516 | return self.tf.is_tensor(tensor) and self.K.is_keras_tensor(tensor) 517 | 518 | def create_symbol(self, shape): 519 | return self.keras.Input(batch_shape=shape) 520 | 521 | def eval_symbol(self, symbol, input_dict): 522 | (variable, value), = input_dict 523 | model = self.keras.models.Model(variable, symbol) 524 | return model.predict_on_batch(value) 525 | 526 | def arange(self, start, stop): 527 | return self.K.arange(start, stop) 528 | 529 | def shape(self, x): 530 | shape = self.K.shape(x) # tf tensor 531 | return HashableTuple(tuple(shape)) 532 | 533 | def reduce(self, x, operation, axes): 534 | return getattr(self.K, operation)(x, axis=axes) 535 | 536 | def reshape(self, x, shape): 537 | return self.K.reshape(x, shape) 538 | 539 | def transpose(self, x, axes): 540 | return self.K.permute_dimensions(x, axes) 541 | 542 | def stack_on_zeroth_dimension(self, tensors: list): 543 | return self.K.stack(tensors) 544 | 545 | def tile(self, x, repeats): 546 | return self.K.tile(x, repeats) 547 | 548 | def add_axis(self, x, new_position): 549 | return self.K.expand_dims(x, new_position) 550 | 551 | def is_float_type(self, x): 552 | return 'float' in self.K.dtype(x) 553 | 554 | def layers(self): 555 | from .layers import keras 556 | return keras 557 | 558 | 559 | class PaddleBackend(AbstractBackend): 560 | framework_name = 'paddle' 561 | 562 | def __init__(self): 563 | import paddle 564 | self.paddle = paddle 565 | 566 | def is_appropriate_type(self, tensor): 567 | return isinstance(tensor, self.paddle.Tensor) 568 | 569 | def from_numpy(self, x): 570 | variable = self.paddle.to_tensor(x) 571 | if self.is_float_type(variable): 572 | # attach grad only to floating types 573 | variable.stop_gradient = False 574 | return variable 575 | 576 | def to_numpy(self, x): 577 | return x.detach().numpy() 578 | 579 | def arange(self, start, stop): 580 | return self.paddle.arange(start, stop, dtype=self.paddle.int64) 581 | 582 | def shape(self, x): 583 | shape = x.shape 584 | return HashableTuple(tuple(shape)) 585 | 586 | def reduce(self, x, operation, reduced_axes): 587 | for axis in sorted(reduced_axes, reverse=True): 588 | if operation == 'min': 589 | x = x.min(axis=axis) 590 | elif operation == 'max': 591 | x = x.max(axis=axis) 592 | elif operation in ['sum', 'mean', 'prod']: 593 | x = 
getattr(x, operation)(axis=axis) 594 | else: 595 | raise NotImplementedError('Unknown reduction ', operation) 596 | return x 597 | 598 | def transpose(self, x, axes): 599 | return x.transpose(axes) 600 | 601 | def stack_on_zeroth_dimension(self, tensors: list): 602 | return self.paddle.stack(tensors) 603 | 604 | def tile(self, x, repeats): 605 | return x.tile(repeats) 606 | 607 | def add_axis(self, x, new_position): 608 | return self.paddle.unsqueeze(x, new_position) 609 | 610 | def is_float_type(self, x): 611 | return x.dtype in [self.paddle.float16, self.paddle.float32, self.paddle.float64] 612 | 613 | def layers(self): 614 | from .layers import paddle 615 | return paddle 616 | -------------------------------------------------------------------------------- /rudalle_paddle/packages/einops/einops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import functools 3 | import itertools 4 | import math 5 | from collections import OrderedDict 6 | from typing import Tuple, List, Dict, Union, Callable 7 | 8 | from ._backends import get_backend 9 | from .parsing import ParsedExpression, _ellipsis, AnonymousAxis 10 | 11 | ReductionCallable = Callable[['tensor', Tuple[int]], 'tensor'] # noqa 12 | Reduction = Union[str, ReductionCallable] 13 | 14 | 15 | class EinopsError(RuntimeError): 16 | """ Runtime error thrown by einops """ 17 | pass 18 | 19 | 20 | _reductions = ('min', 'max', 'sum', 'mean', 'prod') 21 | 22 | 23 | def _product(sequence): 24 | """ minimalistic product that works both with numbers and symbols. Supports empty lists """ 25 | result = 1 26 | for element in sequence: 27 | result *= element 28 | return result 29 | 30 | 31 | def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: Tuple[int], backend): 32 | reduced_axes = tuple(reduced_axes) 33 | if callable(reduction_type): 34 | # custom callable 35 | return reduction_type(tensor, reduced_axes) 36 | else: 37 | # one of built-in operations 38 | if len(reduced_axes) == 0: 39 | return tensor 40 | assert reduction_type in _reductions 41 | if reduction_type == 'mean': 42 | if not backend.is_float_type(tensor): 43 | raise NotImplementedError('reduce_mean is not available for non-floating tensors') 44 | return backend.reduce(tensor, reduction_type, reduced_axes) 45 | 46 | 47 | def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes): 48 | # TODO this method can be optimized 49 | assert len(axes_reordering) + len(reduced_axes) == len(init_shapes) 50 | # joining consecutive axes that will be reduced 51 | # possibly we can skip this if all backends can optimize this (not sure) 52 | reduced_axes = tuple(sorted(reduced_axes)) 53 | for i in range(len(reduced_axes) - 1)[::-1]: 54 | if reduced_axes[i] + 1 == reduced_axes[i + 1]: 55 | removed_axis = reduced_axes[i + 1] 56 | removed_length = init_shapes[removed_axis] 57 | init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] 58 | init_shapes[removed_axis - 1] *= removed_length 59 | reduced_axes = reduced_axes[:i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:]) 60 | 61 | # removing axes that are moved together during reshape 62 | def build_mapping(): 63 | init_to_final = {} 64 | for axis in range(len(init_shapes)): 65 | if axis in reduced_axes: 66 | init_to_final[axis] = None 67 | else: 68 | after_reduction = sum(x is not None for x in init_to_final.values()) 69 | init_to_final[axis] = list(axes_reordering).index(after_reduction) 70 | return init_to_final 71 | 72 | 
init_axis_to_final_axis = build_mapping() 73 | 74 | for init_axis in range(len(init_shapes) - 1)[::-1]: 75 | if init_axis_to_final_axis[init_axis] is None: 76 | continue 77 | if init_axis_to_final_axis[init_axis + 1] is None: 78 | continue 79 | if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]: 80 | removed_axis = init_axis + 1 81 | removed_length = init_shapes[removed_axis] 82 | removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis)) 83 | 84 | reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes) 85 | init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] 86 | init_shapes[removed_axis - 1] *= removed_length 87 | old_reordering = axes_reordering 88 | axes_reordering = [] 89 | for axis in old_reordering: 90 | if axis == removed_axis_after_reduction: 91 | pass 92 | elif axis < removed_axis_after_reduction: 93 | axes_reordering.append(axis) 94 | else: 95 | axes_reordering.append(axis - 1) 96 | init_axis_to_final_axis = build_mapping() 97 | 98 | return init_shapes, reduced_axes, axes_reordering, final_shapes 99 | 100 | 101 | class TransformRecipe: 102 | """ 103 | Recipe describes actual computation pathway. 104 | Recipe can be applied to a tensor or variable. 105 | """ 106 | 107 | # structure is non-mutable. In future, this can be non-mutable dataclass (python 3.7+) 108 | 109 | def __init__(self, 110 | # list of expressions (or just sizes) for elementary axes as they appear in left expression. 111 | # this is what (after computing unknown parts) will be a shape after first transposition. 112 | # If ellipsis is present, it forms one dimension here (in the right position). 113 | elementary_axes_lengths: List, 114 | # each dimension in input can help to reconstruct length of one elementary axis 115 | # or verify one of dimensions. Each element points to element of elementary_axes_lengths 116 | input_composite_axes: List[Tuple[List[int], List[int]]], 117 | # indices of axes to be squashed 118 | reduced_elementary_axes: Tuple[int], 119 | # in which order should axes be reshuffled after reduction 120 | axes_permutation: Tuple[int], 121 | # at which positions which of elementary axes should appear 122 | added_axes: Dict[int, int], 123 | # ids of axes as they appear in result, again pointers to elementary_axes_lengths, 124 | # only used to infer result dimensions 125 | output_composite_axes: List[List[int]], 126 | reduction_type: Reduction = 'rearrange', 127 | # positions of ellipsis in lhs and rhs of expression 128 | ellipsis_position_in_lhs: int = math.inf, 129 | ): 130 | self.elementary_axes_lengths = elementary_axes_lengths 131 | self.input_composite_axes = input_composite_axes 132 | self.output_composite_axes = output_composite_axes 133 | # self.final_axes_grouping_flat = list(itertools.chain(*output_composite_axes)) 134 | self.axes_permutation = axes_permutation 135 | self.added_axes = added_axes 136 | self.reduction_type = reduction_type 137 | # This is redundant information, but more convenient during to use in reconstruction 138 | self.reduced_elementary_axes = reduced_elementary_axes 139 | self.ellipsis_position_in_lhs = ellipsis_position_in_lhs 140 | 141 | @functools.lru_cache(maxsize=1024) 142 | def reconstruct_from_shape(self, shape, optimize=False): 143 | """ 144 | Reconstruct all actual parameters using shape. 
145 | Shape is a tuple that may contain integers, shape symbols (tf, keras, theano) and UnknownSize (keras, mxnet) 146 | known axes can be integers or symbols, but not Nones. 147 | """ 148 | axes_lengths = list(self.elementary_axes_lengths) 149 | if self.ellipsis_position_in_lhs != math.inf: 150 | if len(shape) < len(self.input_composite_axes) - 1: 151 | raise EinopsError('Expected at least {} dimensions, got {}'.format( 152 | len(self.input_composite_axes) - 1, len(shape))) 153 | else: 154 | if len(shape) != len(self.input_composite_axes): 155 | raise EinopsError('Expected {} dimensions, got {}'.format(len(self.input_composite_axes), len(shape))) 156 | for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composite_axes): 157 | before_ellipsis = input_axis 158 | after_ellipsis = input_axis + len(shape) - len(self.input_composite_axes) 159 | if input_axis == self.ellipsis_position_in_lhs: 160 | assert len(known_axes) == 0 and len(unknown_axes) == 1 161 | unknown_axis, = unknown_axes 162 | ellipsis_shape = shape[before_ellipsis:after_ellipsis + 1] 163 | if any(d is None for d in ellipsis_shape): 164 | raise EinopsError("Couldn't infer shape for one or more axes represented by ellipsis") 165 | axes_lengths[unknown_axis] = _product(ellipsis_shape) 166 | else: 167 | if input_axis < self.ellipsis_position_in_lhs: 168 | length = shape[before_ellipsis] 169 | else: 170 | length = shape[after_ellipsis] 171 | known_product = 1 172 | for axis in known_axes: 173 | known_product *= axes_lengths[axis] 174 | 175 | if len(unknown_axes) == 0: 176 | if isinstance(length, int) and isinstance(known_product, int) and length != known_product: 177 | raise EinopsError('Shape mismatch, {} != {}'.format(length, known_product)) 178 | else: 179 | if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0: 180 | raise EinopsError("Shape mismatch, can't divide axis of length {} in chunks of {}".format( 181 | length, known_product)) 182 | 183 | unknown_axis, = unknown_axes 184 | axes_lengths[unknown_axis] = length // known_product 185 | 186 | # at this point all axes_lengths are computed (either have values or variables, but not Nones) 187 | 188 | # TODO more readable expression. 
and confirm we don't want to deal with ellipsis 189 | init_shapes = axes_lengths[:len(axes_lengths) - len(self.added_axes)] 190 | # reduced_axes_lengths = [dim for i, dim in enumerate(axes_lengths) if i not in self.reduced_elementary_axes] 191 | final_shapes = [] 192 | for output_axis, grouping in enumerate(self.output_composite_axes): 193 | if grouping == _ellipsis: 194 | final_shapes.extend(ellipsis_shape) 195 | else: 196 | lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping] 197 | final_shapes.append(_product(lengths)) 198 | reduced_axes = self.reduced_elementary_axes 199 | axes_reordering = self.axes_permutation 200 | added_axes = {pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items()} 201 | if optimize: 202 | assert len(self.added_axes) == 0 203 | return _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes) 204 | else: 205 | return init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes 206 | 207 | def apply(self, tensor): 208 | backend = get_backend(tensor) 209 | init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = self.reconstruct_from_shape( 210 | backend.shape(tensor)) 211 | tensor = backend.reshape(tensor, init_shapes) 212 | tensor = _reduce_axes(tensor, reduction_type=self.reduction_type, reduced_axes=reduced_axes, backend=backend) 213 | tensor = backend.transpose(tensor, axes_reordering) 214 | if len(added_axes) > 0: 215 | tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes) 216 | return backend.reshape(tensor, final_shapes) 217 | 218 | 219 | @functools.lru_cache(256) 220 | def _prepare_transformation_recipe(pattern: str, 221 | operation: Reduction, 222 | axes_lengths: Tuple[Tuple]) -> TransformRecipe: 223 | """ Perform initial parsing of pattern and provided supplementary info 224 | axes_lengths is a tuple of tuples (axis_name, axis_length) 225 | """ 226 | left, rght = pattern.split('->') 227 | left = ParsedExpression(left) 228 | rght = ParsedExpression(rght) 229 | 230 | # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction 231 | if not left.has_ellipsis and rght.has_ellipsis: 232 | raise EinopsError('Ellipsis found in left side, but not right side of a pattern {}'.format(pattern)) 233 | if left.has_ellipsis and left.has_ellipsis_parenthesized: 234 | raise EinopsError('Ellipsis is parenthesis in the left side is not allowed: {}'.format(pattern)) 235 | if operation == 'rearrange': 236 | difference = set.symmetric_difference(left.identifiers, rght.identifiers) 237 | if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes: 238 | raise EinopsError('Non-unitary anonymous axes are not supported in rearrange (exception is length 1)') 239 | if len(difference) > 0: 240 | raise EinopsError('Identifiers only on one side of expression (should be on both): {}'.format(difference)) 241 | elif operation == 'repeat': 242 | difference = set.difference(left.identifiers, rght.identifiers) 243 | if len(difference) > 0: 244 | raise EinopsError('Unexpected identifiers on the left side of repeat: {}'.format(difference)) 245 | axes_without_size = set.difference({ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)}, 246 | {*left.identifiers, *(ax for ax, _ in axes_lengths)}) 247 | if len(axes_without_size) > 0: 248 | raise EinopsError('Specify sizes for new axes in repeat: {}'.format(axes_without_size)) 249 | elif operation in _reductions or 
callable(operation): 250 | difference = set.difference(rght.identifiers, left.identifiers) 251 | if len(difference) > 0: 252 | raise EinopsError('Unexpected identifiers on the right side of reduce {}: {}'.format(operation, difference)) 253 | else: 254 | raise EinopsError('Unknown reduction {}. Expect one of {}.'.format(operation, _reductions)) 255 | 256 | # parsing all dimensions to find out lengths 257 | axis_name2known_length = OrderedDict() 258 | for composite_axis in left.composition: 259 | for axis_name in composite_axis: 260 | if isinstance(axis_name, AnonymousAxis): 261 | axis_name2known_length[axis_name] = axis_name.value 262 | else: 263 | axis_name2known_length[axis_name] = None 264 | 265 | # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point 266 | 267 | repeat_axes_names = [] 268 | for axis_name in rght.identifiers: 269 | if axis_name not in axis_name2known_length: 270 | if isinstance(axis_name, AnonymousAxis): 271 | axis_name2known_length[axis_name] = axis_name.value 272 | else: 273 | axis_name2known_length[axis_name] = None 274 | repeat_axes_names.append(axis_name) 275 | 276 | axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)} 277 | reduced_axes = [position for axis, position in axis_name2position.items() if axis not in rght.identifiers] 278 | reduced_axes = tuple(sorted(reduced_axes)) 279 | 280 | for elementary_axis, axis_length in axes_lengths: 281 | if not ParsedExpression.check_axis_name(elementary_axis): 282 | raise EinopsError('Invalid name for an axis', elementary_axis) 283 | if elementary_axis not in axis_name2known_length: 284 | raise EinopsError('Axis {} is not used in transform'.format(elementary_axis)) 285 | axis_name2known_length[elementary_axis] = axis_length 286 | 287 | input_axes_known_unknown = [] 288 | # some of shapes will be inferred later - all information is prepared for faster inference 289 | for composite_axis in left.composition: 290 | known = {axis for axis in composite_axis if axis_name2known_length[axis] is not None} 291 | unknown = {axis for axis in composite_axis if axis_name2known_length[axis] is None} 292 | if len(unknown) > 1: 293 | raise EinopsError('Could not infer sizes for {}'.format(unknown)) 294 | assert len(unknown) + len(known) == len(composite_axis) 295 | input_axes_known_unknown.append( 296 | ([axis_name2position[axis] for axis in known], 297 | [axis_name2position[axis] for axis in unknown]) 298 | ) 299 | 300 | axis_position_after_reduction = {} 301 | for axis_name in itertools.chain(*left.composition): 302 | if axis_name in rght.identifiers: 303 | axis_position_after_reduction[axis_name] = len(axis_position_after_reduction) 304 | 305 | result_axes_grouping = [] 306 | for composite_axis in rght.composition: 307 | if composite_axis == _ellipsis: 308 | result_axes_grouping.append(_ellipsis) 309 | else: 310 | result_axes_grouping.append([axis_name2position[axis] for axis in composite_axis]) 311 | 312 | ordered_axis_right = list(itertools.chain(*rght.composition)) 313 | axes_permutation = tuple( 314 | axis_position_after_reduction[axis] for axis in ordered_axis_right if axis in left.identifiers) 315 | added_axes = {i: axis_name2position[axis_name] for i, axis_name in enumerate(ordered_axis_right) 316 | if axis_name not in left.identifiers} 317 | 318 | ellipsis_left = math.inf if _ellipsis not in left.composition else left.composition.index(_ellipsis) 319 | 320 | return TransformRecipe( 321 | elementary_axes_lengths=list(axis_name2known_length.values()), 322 | 
input_composite_axes=input_axes_known_unknown, 323 | reduced_elementary_axes=reduced_axes, 324 | axes_permutation=axes_permutation, 325 | added_axes=added_axes, 326 | output_composite_axes=result_axes_grouping, 327 | reduction_type=operation, 328 | ellipsis_position_in_lhs=ellipsis_left, 329 | ) 330 | 331 | 332 | def reduce(tensor, pattern: str, reduction: Reduction, **axes_lengths: int): 333 | """ 334 | einops.reduce provides combination of reordering and reduction using reader-friendly notation. 335 | 336 | Examples for reduce operation: 337 | 338 | ```python 339 | >>> x = np.random.randn(100, 32, 64) 340 | 341 | # perform max-reduction on the first axis 342 | >>> y = reduce(x, 't b c -> b c', 'max') 343 | 344 | # same as previous, but with clearer axes meaning 345 | >>> y = reduce(x, 'time batch channel -> batch channel', 'max') 346 | 347 | >>> x = np.random.randn(10, 20, 30, 40) 348 | 349 | # 2d max-pooling with kernel size = 2 * 2 for image processing 350 | >>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2) 351 | 352 | # if one wants to go back to the original height and width, depth-to-space trick can be applied 353 | >>> y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2) 354 | >>> assert parse_shape(x, 'b _ h w') == parse_shape(y2, 'b _ h w') 355 | 356 | # Adaptive 2d max-pooling to 3 * 4 grid 357 | >>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape 358 | (10, 20, 3, 4) 359 | 360 | # Global average pooling 361 | >>> reduce(x, 'b c h w -> b c', 'mean').shape 362 | (10, 20) 363 | 364 | # Subtracting mean over batch for each channel 365 | >>> y = x - reduce(x, 'b c h w -> () c () ()', 'mean') 366 | 367 | # Subtracting per-image mean for each channel 368 | >>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean') 369 | 370 | ``` 371 | 372 | Parameters: 373 | tensor: tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch, mxnet.ndarray). 374 | list of tensors is also accepted, those should be of the same type and shape 375 | pattern: string, reduction pattern 376 | reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive 377 | alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided. 378 | This allows using various reductions, examples: np.max, tf.reduce_logsumexp, torch.var, etc. 379 | axes_lengths: any additional specifications for dimensions 380 | 381 | Returns: 382 | tensor of the same type as input 383 | """ 384 | try: 385 | hashable_axes_lengths = tuple(sorted(axes_lengths.items())) 386 | recipe = _prepare_transformation_recipe(pattern, reduction, axes_lengths=hashable_axes_lengths) 387 | return recipe.apply(tensor) 388 | except EinopsError as e: 389 | message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern) 390 | if not isinstance(tensor, list): 391 | message += '\n Input tensor shape: {}. '.format(get_backend(tensor).shape(tensor)) 392 | else: 393 | message += '\n Input is list. ' 394 | message += 'Additional info: {}.'.format(axes_lengths) 395 | raise EinopsError(message + '\n {}'.format(e)) 396 | 397 | 398 | def rearrange(tensor, pattern: str, **axes_lengths): 399 | """ 400 | einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors. 401 | This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, 402 | stack, concatenate and other operations. 
403 | 404 | Examples for rearrange operation: 405 | 406 | ```python 407 | # suppose we have a set of 32 images in "h w c" format (height-width-channel) 408 | >>> images = [np.random.randn(30, 40, 3) for _ in range(32)] 409 | 410 | # stack along first (batch) axis, output is a single array 411 | >>> rearrange(images, 'b h w c -> b h w c').shape 412 | (32, 30, 40, 3) 413 | 414 | # concatenate images along height (vertical axis), 960 = 32 * 30 415 | >>> rearrange(images, 'b h w c -> (b h) w c').shape 416 | (960, 40, 3) 417 | 418 | # concatenated images along horizontal axis, 1280 = 32 * 40 419 | >>> rearrange(images, 'b h w c -> h (b w) c').shape 420 | (30, 1280, 3) 421 | 422 | # reordered axes to "b c h w" format for deep learning 423 | >>> rearrange(images, 'b h w c -> b c h w').shape 424 | (32, 3, 30, 40) 425 | 426 | # flattened each image into a vector, 3600 = 30 * 40 * 3 427 | >>> rearrange(images, 'b h w c -> b (c h w)').shape 428 | (32, 3600) 429 | 430 | # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2 431 | >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape 432 | (128, 15, 20, 3) 433 | 434 | # space-to-depth operation 435 | >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape 436 | (32, 15, 20, 12) 437 | 438 | ``` 439 | 440 | When composing axes, C-order enumeration used (consecutive elements have different last axis) 441 | Find more examples in einops tutorial. 442 | 443 | Parameters: 444 | tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch, mxnet.ndarray). 445 | list of tensors is also accepted, those should be of the same type and shape 446 | pattern: string, rearrangement pattern 447 | axes_lengths: any additional specifications for dimensions 448 | 449 | Returns: 450 | tensor of the same type as input. If possible, a view to the original tensor is returned. 451 | 452 | """ 453 | if isinstance(tensor, list): 454 | if len(tensor) == 0: 455 | raise TypeError("Rearrange can't be applied to an empty list") 456 | tensor = get_backend(tensor[0]).stack_on_zeroth_dimension(tensor) 457 | return reduce(tensor, pattern, reduction='rearrange', **axes_lengths) 458 | 459 | 460 | def repeat(tensor, pattern: str, **axes_lengths): 461 | """ 462 | einops.repeat allows reordering elements and repeating them in arbitrary combinations. 463 | This operation includes functionality of repeat, tile, broadcast functions. 464 | 465 | Examples for repeat operation: 466 | 467 | ```python 468 | # a grayscale image (of shape height x width) 469 | >>> image = np.random.randn(30, 40) 470 | 471 | # change it to RGB format by repeating in each channel 472 | >>> repeat(image, 'h w -> h w c', c=3).shape 473 | (30, 40, 3) 474 | 475 | # repeat image 2 times along height (vertical axis) 476 | >>> repeat(image, 'h w -> (repeat h) w', repeat=2).shape 477 | (60, 40) 478 | 479 | # repeat image 3 times along width 480 | >>> repeat(image, 'h w -> h (repeat w)', repeat=3).shape 481 | (30, 120) 482 | 483 | # convert each pixel to a small square 2x2. 
Upsample image by 2x 484 | >>> repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape 485 | (60, 80) 486 | 487 | # pixelate image first by downsampling by 2x, then upsampling 488 | >>> downsampled = reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2) 489 | >>> repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape 490 | (30, 40) 491 | 492 | ``` 493 | 494 | When composing axes, C-order enumeration used (consecutive elements have different last axis) 495 | Find more examples in einops tutorial. 496 | 497 | Parameters: 498 | tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch, mxnet.ndarray). 499 | list of tensors is also accepted, those should be of the same type and shape 500 | pattern: string, rearrangement pattern 501 | axes_lengths: any additional specifications for dimensions 502 | 503 | Returns: 504 | Tensor of the same type as input. If possible, a view to the original tensor is returned. 505 | 506 | """ 507 | return reduce(tensor, pattern, reduction='repeat', **axes_lengths) 508 | 509 | 510 | def parse_shape(x, pattern: str): 511 | """ 512 | Parse a tensor shape to dictionary mapping axes names to their lengths. 513 | 514 | ```python 515 | # Use underscore to skip the dimension in parsing. 516 | >>> x = np.zeros([2, 3, 5, 7]) 517 | >>> parse_shape(x, 'batch _ h w') 518 | {'batch': 2, 'h': 5, 'w': 7} 519 | 520 | # `parse_shape` output can be used to specify axes_lengths for other operations: 521 | >>> y = np.zeros([700]) 522 | >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape 523 | (2, 10, 5, 7) 524 | 525 | ``` 526 | 527 | For symbolic frameworks may return symbols, not integers. 528 | 529 | Parameters: 530 | x: tensor of any of supported frameworks 531 | pattern: str, space separated names for axes, underscore means skip axis 532 | 533 | Returns: 534 | dict, maps axes names to their lengths 535 | """ 536 | names = [elementary_axis for elementary_axis in pattern.split(' ') if len(elementary_axis) > 0] 537 | shape = get_backend(x).shape(x) 538 | if len(shape) != len(names): 539 | raise RuntimeError("Can't parse shape with different number of dimensions: {pattern} {shape}".format( 540 | pattern=pattern, shape=shape)) 541 | result = {} 542 | for axis_name, axis_length in zip(names, shape): 543 | if axis_name != '_': 544 | result[axis_name] = axis_length 545 | return result 546 | 547 | 548 | # this one is probably not needed in the public API 549 | def _enumerate_directions(x): 550 | """ 551 | For an n-dimensional tensor, returns tensors to enumerate each axis. 552 | ```python 553 | x = np.zeros([2, 3, 4]) # or any other tensor 554 | i, j, k = _enumerate_directions(x) 555 | result = i + 2 * j + 3 * k 556 | ``` 557 | 558 | `result[i, j, k] = i + 2 * j + 3 * k`, and also has the same shape as result 559 | Works very similarly to numpy.ogrid (open indexing grid) 560 | """ 561 | backend = get_backend(x) 562 | shape = backend.shape(x) 563 | result = [] 564 | for axis_id, axis_length in enumerate(shape): 565 | shape = [1] * len(shape) 566 | shape[axis_id] = axis_length 567 | result.append(backend.reshape(backend.arange(0, axis_length), shape)) 568 | return result 569 | 570 | 571 | def asnumpy(tensor): 572 | """ 573 | Convert a tensor of an imperative framework (i.e. numpy/cupy/torch/gluon/etc.) 
to `numpy.ndarray` 574 | 575 | Parameters: 576 | tensor: tensor of any known imperative framework 577 | 578 | Returns: 579 | `numpy.ndarray`, converted to numpy 580 | """ 581 | return get_backend(tensor).to_numpy(tensor) 582 | --------------------------------------------------------------------------------
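Note: the vendored einops package listed above lets rudalle_paddle apply einops-style tensor operations to paddle tensors through the PaddleBackend defined in _backends.py. The snippet below is a minimal usage sketch, not a file from the repository; it assumes paddlepaddle is installed and that the package is importable as rudalle_paddle.packages.einops, matching the module path shown in this listing.

# Minimal usage sketch (assumption: paddlepaddle is available and the vendored
# einops shown above is importable as rudalle_paddle.packages.einops).
import paddle
from rudalle_paddle.packages.einops import rearrange, reduce, repeat, parse_shape, asnumpy

x = paddle.randn([8, 3, 32, 32])  # a batch of 8 three-channel 32x32 feature maps

# rearrange: axes permutation handled via the backend's transpose/reshape hooks
y = rearrange(x, 'b c h w -> b h w c')
assert tuple(y.shape) == (8, 32, 32, 3)

# reduce: global average pooling; PaddleBackend.reduce applies 'mean' axis by axis
pooled = reduce(x, 'b c h w -> b c', 'mean')
assert tuple(pooled.shape) == (8, 3)

# repeat: 2x nearest-neighbour upsampling through the new h2/w2 axes
up = repeat(x, 'b c h w -> b c (h h2) (w w2)', h2=2, w2=2)
assert tuple(up.shape) == (8, 3, 64, 64)

# parse_shape returns plain integers for paddle tensors; '_' skips an axis
assert parse_shape(x, 'b _ h w') == {'b': 8, 'h': 32, 'w': 32}

# asnumpy routes through PaddleBackend.to_numpy (tensor.detach().numpy())
assert asnumpy(x).shape == (8, 3, 32, 32)

Backend selection is automatic: get_backend (used inside reduce/rearrange) picks PaddleBackend because its is_appropriate_type check matches paddle.Tensor, so no explicit backend argument is needed.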