├── .github ├── ISSUE_TEMPLATE │ ├── Help-wanted Issue.md │ └── feature-bug-issue.md └── workflows │ └── ci.yml ├── .gitignore ├── .pylintrc ├── ACKNOWLEDGMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTOR_LICENSE_AGREEMENT.md ├── LICENSE ├── README.md ├── hubconf.py ├── official ├── assets │ ├── cat.jpg │ ├── cat_det_out.jpg │ ├── cat_seg_out.jpg │ ├── dcgan.png │ ├── imagenet_class_info.json │ ├── norway_sample_2687.png │ ├── norway_sampling.mp4 │ ├── norway_segmentation.png │ ├── test_000009.png │ ├── test_000010.png │ ├── test_depth.png │ ├── test_sample_255.png │ ├── test_sampling.mp4 │ └── total.png ├── multimodal │ ├── __init__.py │ ├── big_sleep │ │ ├── README.md │ │ ├── __init__.py │ │ ├── big_sleep.py │ │ ├── biggan.py │ │ ├── ema.py │ │ ├── resample.py │ │ └── spectral_norm.py │ ├── clip │ │ ├── README.md │ │ ├── __init__.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── functional.py │ │ ├── inference_utils.py │ │ ├── models.py │ │ └── simple_tokenizer.py │ ├── dalle │ │ ├── README.md │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── dalle.py │ │ ├── functional.py │ │ ├── generate.py │ │ ├── pretrained.py │ │ ├── tokenizer.py │ │ ├── transformer.py │ │ └── vae │ │ │ ├── __init__.py │ │ │ ├── base_vae.py │ │ │ ├── openai_dvae.py │ │ │ ├── openaidvae │ │ │ ├── __init__.py │ │ │ ├── decoder.py │ │ │ ├── encoder.py │ │ │ └── utils.py │ │ │ └── vqgan_vae.py │ └── taming_transformer │ │ ├── README.md │ │ ├── __init__.py │ │ ├── cond_transformer.py │ │ ├── data │ │ └── drin_images │ │ │ ├── n01795545 │ │ │ └── ILSVRC2012_val_00023344.JPEG │ │ │ ├── n01819313 │ │ │ └── ILSVRC2012_val_00003068.JPEG │ │ │ ├── n01820546 │ │ │ ├── ILSVRC2012_val_00034784.JPEG │ │ │ └── ILSVRC2012_val_00047491.JPEG │ │ │ ├── n01828970 │ │ │ ├── ILSVRC2012_val_00001336.JPEG │ │ │ ├── ILSVRC2012_val_00008236.JPEG │ │ │ └── ILSVRC2012_val_00046802.JPEG │ │ │ ├── n01843065 │ │ │ └── ILSVRC2012_val_00022439.JPEG │ │ │ ├── n01847000 │ │ │ └── ILSVRC2012_val_00022364.JPEG │ │ │ ├── n02085782 │ │ │ └── ILSVRC2012_val_00012298.JPEG │ │ │ ├── n02086646 │ │ │ └── ILSVRC2012_val_00011473.JPEG │ │ │ ├── n02088466 │ │ │ └── ILSVRC2012_val_00013651.JPEG │ │ │ ├── n02089973 │ │ │ └── ILSVRC2012_val_00000028.JPEG │ │ │ ├── n02093256 │ │ │ └── ILSVRC2012_val_00046547.JPEG │ │ │ ├── n02096294 │ │ │ └── ILSVRC2012_val_00042133.JPEG │ │ │ ├── n02099601 │ │ │ └── ILSVRC2012_val_00005697.JPEG │ │ │ ├── n02099712 │ │ │ └── ILSVRC2012_val_00023471.JPEG │ │ │ ├── n02100877 │ │ │ └── ILSVRC2012_val_00039863.JPEG │ │ │ ├── n02101006 │ │ │ ├── ILSVRC2012_val_00032333.JPEG │ │ │ └── ILSVRC2012_val_00047325.JPEG │ │ │ ├── n02101556 │ │ │ └── ILSVRC2012_val_00030540.JPEG │ │ │ ├── n02102318 │ │ │ └── ILSVRC2012_val_00024691.JPEG │ │ │ ├── n02105505 │ │ │ └── ILSVRC2012_val_00031252.JPEG │ │ │ ├── n02110627 │ │ │ └── ILSVRC2012_val_00008310.JPEG │ │ │ └── n02111889 │ │ │ └── ILSVRC2012_val_00042625.JPEG │ │ ├── diffusion_modules.py │ │ ├── functional.py │ │ ├── inference_utils.py │ │ ├── mingpt.py │ │ ├── quantize.py │ │ └── vqgan.py ├── nlp │ ├── __init__.py │ └── bert │ │ ├── README.md │ │ ├── __init__.py │ │ ├── config_args.py │ │ ├── glue_data │ │ └── MRPC │ │ │ ├── dev.tsv │ │ │ ├── dev_ids.tsv │ │ │ ├── msr_paraphrase_test.txt │ │ │ ├── msr_paraphrase_train.txt │ │ │ ├── test.tsv │ │ │ └── train.tsv │ │ ├── model.py │ │ ├── mrpc_dataset.py │ │ ├── test.py │ │ ├── tokenization.py │ │ └── train.py ├── quantization │ ├── README.md │ ├── __init__.py │ ├── calibration.py │ ├── finetune.py │ ├── inference.py │ ├── models │ │ ├── __init__.py │ │ ├── 
mobilenet_v2.py │ │ ├── resnet.py │ │ └── shufflenet.py │ ├── param_config.py │ ├── test.py │ └── train.py └── vision │ ├── __init__.py │ ├── classification │ ├── README.md │ ├── __init__.py │ ├── dump.py │ ├── resnet │ │ ├── README.md │ │ ├── __init__.py │ │ ├── inference.py │ │ ├── model.py │ │ ├── test.py │ │ └── train.py │ └── shufflenet │ │ ├── README.md │ │ ├── __init__.py │ │ ├── inference.py │ │ ├── model.py │ │ ├── test.py │ │ └── train.py │ ├── detection │ ├── README.md │ ├── __init__.py │ ├── configs │ │ ├── __init__.py │ │ ├── atss_res101_coco_3x_800size.py │ │ ├── atss_res18_coco_3x_800size.py │ │ ├── atss_res34_coco_3x_800size.py │ │ ├── atss_res50_coco_3x_800size.py │ │ ├── atss_resx101_coco_2x_800size.py │ │ ├── faster_rcnn_res101_coco_3x_800size.py │ │ ├── faster_rcnn_res18_coco_3x_800size.py │ │ ├── faster_rcnn_res34_coco_3x_800size.py │ │ ├── faster_rcnn_res50_coco_3x_800size.py │ │ ├── faster_rcnn_resx101_coco_2x_800size.py │ │ ├── fcos_res101_coco_3x_800size.py │ │ ├── fcos_res18_coco_3x_800size.py │ │ ├── fcos_res34_coco_3x_800size.py │ │ ├── fcos_res50_coco_3x_800size.py │ │ ├── fcos_resx101_coco_2x_800size.py │ │ ├── freeanchor_res101_coco_3x_800size.py │ │ ├── freeanchor_res18_coco_3x_800size.py │ │ ├── freeanchor_res34_coco_3x_800size.py │ │ ├── freeanchor_res50_coco_3x_800size.py │ │ ├── freeanchor_resx101_coco_2x_800size.py │ │ ├── retinanet_res101_coco_3x_800size.py │ │ ├── retinanet_res18_coco_3x_800size.py │ │ ├── retinanet_res34_coco_3x_800size.py │ │ ├── retinanet_res50_coco_3x_800size.py │ │ └── retinanet_resx101_coco_2x_800size.py │ ├── layers │ │ ├── __init__.py │ │ ├── basic │ │ │ ├── __init__.py │ │ │ ├── functional.py │ │ │ ├── nn.py │ │ │ └── norm.py │ │ └── det │ │ │ ├── __init__.py │ │ │ ├── anchor.py │ │ │ ├── box_head.py │ │ │ ├── box_utils.py │ │ │ ├── fpn.py │ │ │ ├── loss.py │ │ │ ├── matcher.py │ │ │ ├── point_head.py │ │ │ ├── pooler.py │ │ │ ├── rcnn.py │ │ │ ├── rpn.py │ │ │ └── sampling.py │ ├── models │ │ ├── __init__.py │ │ ├── atss.py │ │ ├── faster_rcnn.py │ │ ├── fcos.py │ │ ├── freeanchor.py │ │ └── retinanet.py │ └── tools │ │ ├── data_mapper.py │ │ ├── inference.py │ │ ├── nms.py │ │ ├── test.py │ │ ├── test_in_table.py │ │ ├── test_random.py │ │ ├── train.py │ │ ├── train_random.py │ │ └── utils.py │ ├── gan │ ├── README.md │ ├── megengine_mimicry │ │ ├── __init__.py │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ ├── data_utils.py │ │ │ └── image_loader.py │ │ ├── metrics │ │ │ ├── __init__.py │ │ │ ├── compute_fid.py │ │ │ ├── compute_is.py │ │ │ ├── compute_kid.py │ │ │ ├── compute_metrics.py │ │ │ ├── fid │ │ │ │ ├── __init__.py │ │ │ │ └── fid_utils.py │ │ │ ├── inception_model │ │ │ │ ├── __init__.py │ │ │ │ └── inception_utils.py │ │ │ ├── inception_score │ │ │ │ ├── __init__.py │ │ │ │ └── inception_score_utils.py │ │ │ ├── kid │ │ │ │ ├── __init__.py │ │ │ │ └── kid_utils.py │ │ │ └── utils.py │ │ ├── nets │ │ │ ├── __init__.py │ │ │ ├── basemodel.py │ │ │ ├── blocks.py │ │ │ ├── dcgan │ │ │ │ ├── __init__.py │ │ │ │ ├── dcgan_base.py │ │ │ │ └── dcgan_cifar.py │ │ │ ├── gan.py │ │ │ ├── losses.py │ │ │ └── wgan │ │ │ │ ├── __init__.py │ │ │ │ ├── wgan_base.py │ │ │ │ └── wgan_cifar.py │ │ ├── training │ │ │ ├── __init__.py │ │ │ ├── logger.py │ │ │ ├── metric_log.py │ │ │ ├── scheduler.py │ │ │ └── trainer.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ └── vis.py │ ├── requirements.txt │ ├── train_dcgan.py │ └── train_wgan.py │ ├── keypoints │ ├── README.md │ ├── config.py │ ├── dataset.py │ ├── 
inference.py │ ├── models │ │ ├── __init__.py │ │ └── simplebaseline.py │ ├── test.py │ ├── train.py │ └── transforms.py │ └── segmentation │ ├── README.md │ ├── configs │ ├── __init__.py │ ├── deeplabv3plus_res101_cityscapes_768size.py │ └── deeplabv3plus_res101_voc_512size.py │ ├── models │ ├── __init__.py │ └── deeplabv3plus.py │ └── tools │ ├── inference.py │ ├── test.py │ ├── train.py │ └── utils.py ├── requirements.txt ├── requires-style.txt ├── run_format_check.sh └── setup.cfg /.github/ISSUE_TEMPLATE/Help-wanted Issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Help-wanted Issue 3 | about: 请使用此模板提出help-wanted任务 4 | title: Help-wanted Issue 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 背景 11 | 12 | 13 | 14 | ## 任务描述 15 | 16 | 17 | 18 | ## 目标 19 | 20 | 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-bug-issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature/Bug Issue 3 | about: 请使用此模型提出您的建议/问题 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | ## 环境 12 | 1.系统环境: 13 | 2.MegEngine版本: 14 | 3.python版本: 15 | 4.模型名称: 16 | 17 | ## 复现步骤 18 | 1. 19 | 2. 20 | 3. 21 | 22 | ## 请提供关键的代码片段便于追查问题 23 | 24 | 25 | 26 | ## 请提供完整的日志及报错信息 27 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | # This is a basic workflow to help you get started with Actions 2 | 3 | name: CI 4 | 5 | # Controls when the action will run. Triggers the workflow on push or pull request 6 | # events but only for the master branch 7 | on: 8 | push: 9 | pull_request: 10 | 11 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 12 | jobs: 13 | # This workflow contains a single job called "build" 14 | build: 15 | # The type of runner that the job will run on 16 | runs-on: ubuntu-latest 17 | strategy: 18 | matrix: 19 | python-version: [3.6, 3.7, 3.8] 20 | 21 | # Steps represent a sequence of tasks that will be executed as part of the job 22 | steps: 23 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 24 | - uses: actions/checkout@v2 25 | 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v1 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install -r requirements.txt 35 | 36 | # Runs a set of commands using the runners shell 37 | - name: Format check 38 | run: ./run_format_check.sh 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *log*/ 2 | *.jpg 3 | *.png 4 | 5 | # compilation and distribution 6 | __pycache__ 7 | _ext 8 | *.pyc 9 | *.so 10 | build/ 11 | dist/ 12 | wheels/ 13 | 14 | # pytorch/python/numpy formats 15 | *.pth 16 | *.pkl 17 | *.npy 18 | 19 | # ipython/jupyter notebooks 20 | *.ipynb 21 | **/.ipynb_checkpoints/ 22 | 23 | # Editor temporaries 24 | *.swn 25 | *.swo 26 | *.swp 27 | *~ 28 | 29 | # pycharm editor settings 30 | .idea 31 | 32 | # vscode editor settings 33 | .vscode 34 | 35 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: 
-------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. 6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to a positive environment for our community include: 10 | 11 | * Using welcoming and inclusive language 12 | * Being respectful of differing viewpoints and experiences 13 | * Gracefully accepting constructive criticism 14 | * Focusing on what is best for the community 15 | * Showing empathy towards other community members 16 | 17 | Examples of unacceptable behavior include: 18 | 19 | * The use of sexualized language or imagery, and sexual attention or advances of any kind 20 | * Trolling, insulting or derogatory comments, and personal or political attacks 21 | * Public or private harassment 22 | * Publishing others’ private information, such as a physical or email address, without their explicit permission 23 | * Other conduct which could reasonably be considered inappropriate in a professional setting 24 | 25 | All MegEngine forums and spaces are meant for professional interactions, and any behavior which could reasonably be considered inappropriate in a professional setting is unacceptable. 26 | 27 | ## Our Responsibilities 28 | 29 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 30 | 31 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 32 | 33 | ## Scope 34 | 35 | This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 36 | 37 | 38 | ## Enforcement 39 | 40 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at megengine@megvii.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 41 | 42 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 
43 | 44 | ## Attribution 45 | 46 | This Code of Conduct is updated from the Contributor Covenant, version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 47 | 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MegEngine Models 2 | 3 | ![](https://github.com/MegEngine/Models/workflows/CI/badge.svg) 4 | 5 | 本仓库包含了采用[MegEngine](https://github.com/megengine/megengine)实现的各种主流深度学习模型。 6 | 7 | [official](./official)目录下提供了各种经典的图像分类、目标检测、图像分割以及自然语言模型的官方实现。每个模型同时提供了模型定义、推理以及训练的代码。 8 | 9 | 官方会一直维护[official](./official)下的代码,保持适配MegEngine的最新API,提供最优的模型实现。同时,提供高质量的学习文档,帮助新手学习如何在MegEngine下训练自己的模型。 10 | 11 | ## 综述 12 | 13 | 对于每个模型,我们提供了至少四个脚本文件:模型定义(`model.py`)、模型推理(`inference.py`)、模型训练(`train.py`)、模型测试(`test.py`)。 14 | 15 | 每个模型目录下都对应有一个`README`,介绍了模型的详细信息,并详细描述了训练和测试的流程。例如 [ResNet README](./official/vision/classification/resnet/README.md)。 16 | 17 | 另外,`official`下定义的模型可以通过`megengine.hub`来直接加载,例如: 18 | 19 | ```bash 20 | import megengine.hub 21 | 22 | # 只加载网络结构 23 | resnet18 = megengine.hub.load("megengine/models", "resnet18") 24 | # 加载网络结构和预训练权重 25 | resnet18 = megengine.hub.load("megengine/models", "resnet18", pretrained=True) 26 | ``` 27 | 28 | 更多可以通过`megengine.hub`接口加载的模型见[hubconf.py](./hubconf.py)。 29 | 30 | ## 安装和环境配置 31 | 32 | 在开始运行本仓库下的代码之前,用户需要通过以下步骤来配置本地环境: 33 | 34 | 1. 克隆仓库 35 | 36 | ```bash 37 | git clone https://github.com/MegEngine/Models.git 38 | ``` 39 | 40 | 2. 安装依赖包 41 | 42 | ```bash 43 | pip3 install --user -r requirements.txt 44 | ``` 45 | 46 | 3. 添加目录到python环境变量中 47 | 48 | ```bash 49 | export PYTHONPATH=/path/to/models:$PYTHONPATH 50 | ``` 51 | 52 | 53 | ## 官方模型介绍 54 | 55 | ### 图像分类 56 | 57 | 图像分类是计算机视觉的基础任务。许多计算机视觉的其它任务(例如物体检测)都使用了基于图像分类的预训练模型。因此,我们提供了各种在ImageNet上预训练好的分类模型, 58 | 具体实现模型参考[这里](./official/vision/classification). 59 | 60 | ### 目标检测 61 | 62 | 目标检测同样是计算机视觉中的常见任务,我们提供了多个经典的目标检测模型,具体模型的实现可以参考[这里](./official/vision/detection). 63 | 64 | ### 图像分割 65 | 66 | 语意分割也是计算机视觉中的一项基础任务,为此我们也提供了经典的语义分割模型,具体可以参考[这里](./official/vision/segmentation/). 67 | 68 | ### 人体关节点检测 69 | 70 | 我们提供了人体关节点检测的经典模型和高精度模型,具体的实现可以参考[这里](./official/vision/keypoints). 
71 | 72 | ### 自然语言处理 73 | 74 | 我们同样支持一些常见的自然语言处理模型,模型的权重来自Google的pre-trained models, 用户可以直接使用`megengine.hub`轻松的调用预训练的bert模型。 75 | 76 | 另外,我们在[bert](./official/nlp/bert)中还提供了更加方便的脚本, 可以通过任务名直接获取到对应字典, 配置, 与预训练模型。 77 | 78 | ### 多模态 79 | 80 | 多模态学习拥有令人着迷的魅力,其有着丰富有趣的现实应用。我们支持了一些经典的多模态模型,模型的权重来源于官方预训练模型,用户可以参考仓库下的教程轻松体验多模态的奇妙。 81 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | from official.multimodal.big_sleep import BigGAN, Imagine, biggan_128, biggan_256, biggan_512 2 | from official.multimodal.clip.inference_utils import ClipInferenceUtils 3 | from official.multimodal.clip.models import ( 4 | rn50, 5 | rn50x4, 6 | rn50x16, 7 | rn50x64, 8 | rn101, 9 | vit_b_16, 10 | vit_b_32, 11 | vit_l_14, 12 | vit_l_14_336px, 13 | ) 14 | from official.multimodal.dalle import ( 15 | Generator, 16 | OpenAIDiscreteVAE, 17 | OpenAIDiscreteVAEDecoder, 18 | OpenAIDiscreteVAEEncoder, 19 | VQGanVAE, 20 | coco_512_16_16d_16h_80tsl, 21 | openai_discrete_VAE_decoder, 22 | openai_discrete_VAE_encoder, 23 | vqgan_vae_1024, 24 | ) 25 | from official.multimodal.taming_transformer import ( 26 | ConditionalSampler, 27 | FastSampler, 28 | Reconstruction, 29 | celebahq_transformer, 30 | drin_transformer, 31 | s_flckr_transformer, 32 | vqgan_gumbel_f8, 33 | vqgan_imagenet_f16_1024, 34 | vqgan_imagenet_f16_16384, 35 | ) 36 | from official.nlp.bert.model import ( 37 | cased_L_12_H_768_A_12, 38 | cased_L_24_H_1024_A_16, 39 | chinese_L_12_H_768_A_12, 40 | multi_cased_L_12_H_768_A_12, 41 | uncased_L_12_H_768_A_12, 42 | uncased_L_24_H_1024_A_16, 43 | wwm_cased_L_24_H_1024_A_16, 44 | wwm_uncased_L_24_H_1024_A_16, 45 | ) 46 | from official.quantization.models import quantized_resnet18 47 | from official.vision.classification.resnet.model import ( 48 | BasicBlock, 49 | Bottleneck, 50 | ResNet, 51 | resnet18, 52 | resnet34, 53 | resnet50, 54 | resnet101, 55 | resnet152, 56 | resnext50_32x4d, 57 | resnext101_32x8d, 58 | ) 59 | from official.vision.classification.shufflenet.model import ( 60 | shufflenet_v2_x0_5, 61 | shufflenet_v2_x1_0, 62 | shufflenet_v2_x1_5, 63 | shufflenet_v2_x2_0, 64 | ) 65 | from official.vision.detection.configs import ( 66 | atss_res18_coco_3x_800size, 67 | atss_res34_coco_3x_800size, 68 | atss_res50_coco_3x_800size, 69 | atss_res101_coco_3x_800size, 70 | atss_resx101_coco_2x_800size, 71 | faster_rcnn_res18_coco_3x_800size, 72 | faster_rcnn_res34_coco_3x_800size, 73 | faster_rcnn_res50_coco_3x_800size, 74 | faster_rcnn_res101_coco_3x_800size, 75 | faster_rcnn_resx101_coco_2x_800size, 76 | fcos_res18_coco_3x_800size, 77 | fcos_res34_coco_3x_800size, 78 | fcos_res50_coco_3x_800size, 79 | fcos_res101_coco_3x_800size, 80 | fcos_resx101_coco_2x_800size, 81 | freeanchor_res18_coco_3x_800size, 82 | freeanchor_res34_coco_3x_800size, 83 | freeanchor_res50_coco_3x_800size, 84 | freeanchor_res101_coco_3x_800size, 85 | freeanchor_resx101_coco_2x_800size, 86 | retinanet_res18_coco_3x_800size, 87 | retinanet_res34_coco_3x_800size, 88 | retinanet_res50_coco_3x_800size, 89 | retinanet_res101_coco_3x_800size, 90 | retinanet_resx101_coco_2x_800size, 91 | ) 92 | from official.vision.detection.models import ATSS, FCOS, FasterRCNN, FreeAnchor, RetinaNet 93 | from official.vision.detection.tools.utils import DetEvaluator 94 | from official.vision.keypoints.inference import KeypointEvaluator 95 | from official.vision.keypoints.models import ( 96 | simplebaseline_res50, 97 | simplebaseline_res101, 98 | 
simplebaseline_res152, 99 | ) 100 | from official.vision.segmentation.configs import ( 101 | deeplabv3plus_res101_cityscapes_768size, 102 | deeplabv3plus_res101_voc_512size, 103 | ) 104 | from official.vision.segmentation.models import DeepLabV3Plus 105 | -------------------------------------------------------------------------------- /official/assets/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/assets/cat.jpg -------------------------------------------------------------------------------- /official/assets/cat_det_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/assets/cat_det_out.jpg -------------------------------------------------------------------------------- /official/assets/cat_seg_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/assets/cat_seg_out.jpg -------------------------------------------------------------------------------- /official/assets/dcgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/assets/dcgan.png -------------------------------------------------------------------------------- /official/assets/norway_sample_2687.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/assets/norway_sample_2687.png -------------------------------------------------------------------------------- /official/assets/norway_sampling.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/assets/norway_sampling.mp4 -------------------------------------------------------------------------------- /official/assets/norway_segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/assets/norway_segmentation.png -------------------------------------------------------------------------------- /official/assets/test_000009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/assets/test_000009.png -------------------------------------------------------------------------------- /official/assets/test_000010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/assets/test_000010.png -------------------------------------------------------------------------------- /official/assets/test_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/assets/test_depth.png -------------------------------------------------------------------------------- 
/official/assets/test_sample_255.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/assets/test_sample_255.png -------------------------------------------------------------------------------- /official/assets/test_sampling.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/assets/test_sampling.mp4 -------------------------------------------------------------------------------- /official/assets/total.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/assets/total.png -------------------------------------------------------------------------------- /official/multimodal/__init__.py: -------------------------------------------------------------------------------- 1 | from .dalle.dalle import DALLE 2 | -------------------------------------------------------------------------------- /official/multimodal/big_sleep/README.md: -------------------------------------------------------------------------------- 1 | # Big Sleep 2 | 3 | 此仓库包含MegEngine实现的多模态模型`Big Sleep`,其将`CLIP`与`BigGAN`的生成器相结合,用户可以轻松使用一行文本构想图像! 4 | 5 | ## 使用方法 6 | 7 | 请使用GPU设备,否则生成过程可能会过长。 8 | 9 | 使用`hub`加载 10 | 11 | ```python 12 | from megengine import hub 13 | modelhub = hub.import_module(repo_info='megengine/models', git_host='github.com') 14 | 15 | dream = modelhub.Imagine( 16 | # 需要进行构想的文本 17 | text = "fire in the sky", 18 | # 传入参考图像用于稍微引导生成 19 | img = None, 20 | # 生成图像尺寸大小 21 | image_size=512, 22 | # 迭代过程的学习率 23 | lr = 5e-2, 24 | # 保存图像的间隔 25 | save_every = 25, 26 | # 是否保存迭代过程中的所有图像,否则图像将会重写到一张图片上 27 | save_progress = True, 28 | # 惩罚关键词 29 | text_min = None, 30 | # 梯度累积的步数 31 | gradient_accumulate_every: int = 1, 32 | epochs: int = 20, 33 | iterations: int = 1050, 34 | # 是否将迭代过程中的所有图像保存为mp4视频文件 35 | animate: bool = False, 36 | # 保存mp4的帧率 37 | fps: int = 15, 38 | # BIgSleep中采样方式 39 | bilinear: bool = False, 40 | # 固定随机种子 41 | seed: Optional[int] = None, 42 | # 限制最大类别数量 43 | max_classes: Optional[int] = None, 44 | # 用于可微topk 45 | class_temperature: float = 2., 46 | # 保存文件时是否加上日期前缀 47 | save_date_time: bool = False, 48 | # 是否保存得分最高的图像 49 | save_best: bool = True, 50 | # 实验性采样 51 | experimental_resample: bool = False, 52 | ema_decay: float = 0.99, 53 | num_cutouts: int = 128, 54 | center_bias: bool = False, 55 | clip_type: str = 'RN50', 56 | root: str = 'BigSleep', 57 | ) 58 | 59 | # 开始迭代生成图像 60 | dream() 61 | ``` 62 | 63 | 本地加载 64 | 65 | ```python 66 | from official.multimodal.big_sleep import Imagine 67 | 68 | dream = Imagine( 69 | text = "fire in the sky", 70 | lr = 5e-2, 71 | save_every = 25, 72 | save_progress = True, 73 | image_size=512 74 | ) 75 | 76 | # 开始迭代生成图像 77 | dream() 78 | ``` 79 | 80 | ### 参考 81 | 82 | [lucidrains/big-sleep](https://github.com/lucidrains/big-sleep) 83 | -------------------------------------------------------------------------------- /official/multimodal/big_sleep/__init__.py: -------------------------------------------------------------------------------- 1 | from .big_sleep import Imagine 2 | from .biggan import BigGAN, biggan_128, biggan_256, biggan_512 3 | -------------------------------------------------------------------------------- /official/multimodal/big_sleep/ema.py: 
-------------------------------------------------------------------------------- 1 | # Exponential Moving Average (from https://gist.github.com/crowsonkb/76b94d5238272722290734bf4725d204) # noqa: E501 2 | from copy import deepcopy 3 | 4 | import megengine as mge 5 | import megengine.functional as F 6 | import megengine.module as M 7 | 8 | 9 | class EMA(M.Module): 10 | def __init__(self, model: M.Module, decay: float): 11 | super(EMA, self).__init__() 12 | self.model = model 13 | self.decay = decay 14 | self.accum = mge.tensor(1.) 15 | 16 | self._biased = deepcopy(model) 17 | self.average = deepcopy(model) 18 | for param in self._biased.parameters(): 19 | param.set_value(param.detach() * 0) 20 | for param in self.average.parameters(): 21 | param.set_value(param.detach() * 0) 22 | self.update() 23 | 24 | def update(self): 25 | if not self.training: 26 | raise RuntimeError('Update should only be called during training') 27 | 28 | self.accum *= self.decay 29 | 30 | model_params = dict(self.model.named_parameters()) 31 | biased_params = dict(self._biased.named_parameters()) 32 | average_params = dict(self.average.named_parameters()) 33 | assert model_params.keys() == biased_params.keys() == average_params.keys( 34 | ), 'Model parameter keys incompatible with EMA stored parameter keys' 35 | 36 | for name, param in model_params.items(): 37 | biased_params[name].set_value( 38 | F.mul(biased_params[name], self.decay)) 39 | biased_params[name].set_value( 40 | F.add(biased_params[name], (1 - self.decay) * param)) 41 | average_params[name].set_value(biased_params[name]) 42 | average_params[name].set_value( 43 | F.div(average_params[name], 1 - self.accum)) 44 | 45 | model_buffers = dict(self.model.named_buffers()) 46 | biased_buffers = dict(self._biased.named_buffers()) 47 | average_buffers = dict(self.average.named_buffers()) 48 | assert model_buffers.keys() == biased_buffers.keys() == average_buffers.keys() 49 | 50 | for name, buffer in model_buffers.items(): 51 | biased_buffers[name].set_value(buffer) 52 | average_buffers[name].set_value(buffer) 53 | 54 | def forward(self, *args, **kwargs): 55 | if self.training: 56 | return self.model(*args, **kwargs) 57 | return self.average(*args, **kwargs) 58 | -------------------------------------------------------------------------------- /official/multimodal/big_sleep/resample.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import update_wrapper 3 | 4 | import numpy as np 5 | 6 | import megengine as mge 7 | import megengine.functional as F 8 | 9 | 10 | def sinc(x): 11 | return F.where(x != 0, F.sin(math.pi * x) / (math.pi * x), F.ones_like(x)) 12 | 13 | 14 | def lanczos(x, a): 15 | cond = F.logical_and(-a < x, x < a) 16 | out = F.where(cond, sinc(x) * sinc(x / a), F.zeros_like(x)) 17 | return out / F.sum(out) 18 | 19 | 20 | def ramp(ratio, width): 21 | n = math.ceil(width / ratio + 1) 22 | out = np.zeros(n) 23 | cur = 0 24 | for i in range(out.shape[0]): 25 | out[i] = cur 26 | cur += ratio 27 | out = np.concatenate([np.flip(-out[1:], axis=0), out])[1:-1] 28 | return mge.tensor(out, dtype='float32') 29 | 30 | 31 | def odd(fn): 32 | return update_wrapper(lambda x: F.sin(x) * fn(F.abs(x)), fn) 33 | 34 | 35 | def _to_linear_srgb(input): 36 | cond = input <= 0.04045 37 | a = input / 12.92 38 | b = ((input + 0.055) / 1.055)**2.4 39 | return F.where(cond, a, b) 40 | 41 | 42 | def _to_nonlinear_srgb(input): 43 | cond = input <= 0.0031308 44 | a = 12.92 * input 45 | b = 1.055 * input**(1 / 2.4) - 
0.055 46 | return F.where(cond, a, b) 47 | 48 | 49 | to_linear_srgb = odd(_to_linear_srgb) 50 | to_nonlinear_srgb = odd(_to_nonlinear_srgb) 51 | 52 | 53 | def resample(input, size, align_corners=True, is_srgb=False): # pylint: disable=unused-argument 54 | n, c, h, w = input.shape 55 | dh, dw = size 56 | 57 | if is_srgb: 58 | input = to_linear_srgb(input) 59 | 60 | input = input.reshape(n * c, 1, h, w) 61 | 62 | if dh < h: 63 | kernel_h = lanczos( 64 | ramp(dh / h, 3), 3).to(input.device).astype(input.dtype) 65 | pad_h = (kernel_h.shape[0] - 1) // 2 66 | input = F.pad( 67 | input, [(0, 0), (0, 0), (pad_h, pad_h), (0, 0)], 'reflect') 68 | input = F.conv2d(input, kernel_h[None, None, :, None]) 69 | 70 | if dw < w: 71 | kernel_w = lanczos( 72 | ramp(dw / w, 3), 3).to(input.device).astype(input.dtype) 73 | pad_w = (kernel_w.shape[0] - 1) // 2 74 | input = F.pad(input, [(0, 0), (0, 0), (0, 0), 75 | (pad_w, pad_w)], 'reflect') 76 | input = F.conv2d(input, kernel_w[None, None, None, :]) 77 | 78 | input = input.reshape(n, c, h, w) 79 | # NOTE: can not set align_corners when specify mode with `bicubic` in megengine 80 | input = F.nn.interpolate(input, size, mode='bicubic', 81 | align_corners=None) 82 | 83 | if is_srgb: 84 | input = to_nonlinear_srgb(input) 85 | 86 | return input 87 | -------------------------------------------------------------------------------- /official/multimodal/clip/README.md: -------------------------------------------------------------------------------- 1 | # CLIP 2 | 3 | 此仓库包含MegEngine实现的多模态模型`CLIP`,但不包含训练及测试代码。 4 | 5 | `models.py`中实现了CLIP的不同配置:`RN50`, `RN101`, `RN50x4`, `RN50x16`, `RN50x64`, `ViT-B-32`, `ViT-B-16`, `ViT-L-14`和`ViT-L-14-336px`。 6 | 7 | 在ImageNet V2 matched-frequency数据集上,以float16的精度达成了一下的零样本分类准确度 8 | 9 | | 模型 | TOP-1 |TOP-5 | 10 | | -------------- | -------|------| 11 | | RN50 | 53.55% |81.53%| 12 | | RN101 | 56.21% |83.77%| 13 | | RN50x4 | 59.77% |85.90%| 14 | | RN50x16 | 64.14% |88.39%| 15 | | RN50x64 | 66.90% |90.46%| 16 | | ViT-B-32 | 56.48% |83.57%| 17 | | ViT-B-16 | 62.24% |87.72%| 18 | | ViT-L-14 | 69.72% |90.89%| 19 | | ViT-L-14-336px | 70.72% |91.68%| 20 | 21 | ## 零样本(zero-shot)分类 22 | 23 | 用户可以使用以下模板使用`CLIP`进行零样本图像分类。 24 | 25 | ### 加载网络 26 | 27 | ```python 28 | import megengine as mge 29 | from megengine import hub 30 | modelhub = hub.import_module(repo_info='megengine/models', git_host='github.com') 31 | 32 | # 加载网络结构及预训练模型 33 | # 方式一 34 | clip = hub.load("megengine/models", "rn50", pretrained=True) 35 | clip.eval() 36 | 37 | # 将网络部分权重转换为float16, 仅限GPU 38 | clip.convert_weights('float16') 39 | 40 | # 方式二 41 | # 查看所有可用模型 42 | print(CLIP.available_models()) 43 | 44 | # 直接使用 from_pretrained 方法加载模型即可 45 | clip = CLIP.from_pretrained(model_name='RN50', dtype='float16') 46 | 47 | # 查看网络配置信息 48 | clip.model_config() 49 | 50 | # 使用float32的精度推理 51 | clip.convert_weigths('float32') 52 | ``` 53 | 54 | ### 数据处理 55 | 56 | ```python 57 | import cv2 58 | from megengine.data.transform import CenterCrop, Compose, Normalize, Resize 59 | 60 | #数据处理 61 | image_resolution = clip.image_resolution # clip需要固定输入图片的大小 62 | transfroms = Compose([ 63 | Resize(image_resolution, interpolation=cv2.INTER_CUBIC), 64 | CenterCrop(image_resolution), 65 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 66 | ]) 67 | 68 | ``` 69 | 70 | 数据处理构建完毕后需要用户手动构建`Dataloader`。 71 | 72 | ### 构建文本模板和类别 73 | 74 | `CLIP`需要一些文本模板/提示来描述某一张图片,比如:`a photo of {}.`,`a photo of many {}.`等,大括号中可以填入各种类别名称。这样为每一个类别都生成n句话,再使用文本编码器和图片编码器的输出向量做相似度计算,得分高者则认为其为该类的概率更高。 75 | 
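下面是一个示意性的最小草图,展示上述"用模板为每个类别生成 n 句描述 → 文本编码 → 与图像特征做相似度"的核心流程;其中 `encode_text`、`encode_image` 的接口名以及 `tokenize` 的调用方式均为示意假设,实际请以仓库内 `ClipInferenceUtils` 的实现为准。

```python
import megengine.functional as F
from official.multimodal.clip import tokenize

# 示意草图:encode_text / encode_image 为假设的接口名,实际以仓库实现为准
def zeroshot_weight_for_class(clip, class_name, templates):
    texts = [t.format(class_name) for t in templates]          # 为该类别生成 n 句描述
    text_feat = clip.encode_text(tokenize(texts))               # 文本编码器输出, 形状 (n, dim)
    text_feat = text_feat / F.sqrt(F.sum(text_feat ** 2, axis=-1, keepdims=True))
    class_weight = F.mean(text_feat, axis=0)                    # 对 n 个模板取平均
    return class_weight / F.sqrt(F.sum(class_weight ** 2))      # 归一化得到该类别的权重向量

def zeroshot_logits(clip, image, class_weights):
    img_feat = clip.encode_image(image)                         # 图片编码器输出, 形状 (batch, dim)
    img_feat = img_feat / F.sqrt(F.sum(img_feat ** 2, axis=-1, keepdims=True))
    # 与每个类别的权重向量做点积, 得分最高者即为预测类别
    return 100. * (img_feat @ F.stack(class_weights, axis=1))   # (batch, num_classes)
```

实际使用时无需手写上述流程,直接调用下文介绍的内置推理工具即可。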
76 | `CLIP`中内置了imagenet的80个文本模板,这里使用内置的CLIP推理工具,使用方法如下。 77 | 78 | ```python 79 | utils = modelhub.ClipInferenceUtils 80 | ``` 81 | 82 | 随后调用如下方法即可得到对应的文本模板。 83 | 84 | ```python 85 | imagenet_templates = utils.generate_imagenet_templates() 86 | ``` 87 | 88 | 对于不同的数据集可以采用不同的文本模板,其格式如下: 89 | 90 | ```python 91 | templates: List[str] = [ 92 | 'a bad photo of a {}.', 93 | 'a photo of many {}.', 94 | ... 95 | ] 96 | ``` 97 | 98 | 同时我们需要各个类别的名称,可通过调用以下代码得到imagenet的1000个类别。 99 | 100 | ```python 101 | imagenet_classes = utils.generate_imagenet_classes() 102 | ``` 103 | 104 | 对于不同的数据集需要使用对应的类别名称,其格式如下: 105 | 106 | ```python 107 | classes:List[str] = [ 108 | 'tench', 109 | 'goldfish', 110 | ... 111 | ] 112 | ``` 113 | 114 | ### 生成零样本分类权重 115 | 116 | 使用下列代码生成权重。 117 | 118 | ```python 119 | zeroshot_wieghts = utils.generate_zeroshot_classifier_weight(clip, imagenet_classes, imagenet_templates) 120 | ``` 121 | 122 | ### 预测 123 | 124 | 传入模型、dataloader和零样本权重即可进行预测 125 | 126 | ```python 127 | top1, top5 = utils.predict(clip, loader, zeroshot_wieghts, logit_scale=100.) 128 | print(f"Top-1 accuracy: {top1:.2f}") 129 | print(f"Top-5 accuracy: {top5:.2f}") 130 | ``` 131 | 132 | 如果你只想预测一张图片,使用`predict_once`方法即可 133 | 134 | ```python 135 | logits = utils.predict_once(clip, image, zeroshot_wieghts, logit_scale=100.) 136 | ``` 137 | 138 | ## 参考 139 | 140 | [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) 141 | 142 | [openai/CLIP](https://github.com/openai/CLIP) 143 | -------------------------------------------------------------------------------- /official/multimodal/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference_utils import ClipInferenceUtils 2 | from .models import CLIP 3 | from .simple_tokenizer import SimpleTokenizer, tokenize 4 | -------------------------------------------------------------------------------- /official/multimodal/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/Models/78882f9cbaa037ad701f47d47bb80b66ad95ce87/official/multimodal/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /official/multimodal/dalle/README.md: -------------------------------------------------------------------------------- 1 | # DALLE 2 | 3 | 此仓库包含MegEngine实现的多模态模型DALLE以及文生图代码,但不包含训练代码。 4 | 5 | ## 图像重建 6 | 7 | 对于给定的大小为256x256的归一化四维输入,可以使用如下方式进行重建: 8 | 9 | ```python 10 | from official.multimodal.dalle.vae import OpenAIDiscreteVAE 11 | from official.multimodal.big_sleep.big_sleep import save_images 12 | 13 | 14 | vae = OpenAIDiscreteVAE(True) 15 | 16 | img_seq = vae.get_codebook_indices(input) 17 | 18 | reconstructed_image = vae.decode(img_seq) 19 | 20 | save_images(reconstructed_image, './image.png') 21 | 22 | ``` 23 | 24 | 25 | 26 | ## 文生图 27 | 28 | 可以使用以下代码体验文生图的功能,需要先下载[dalle_new_variety.bpe](https://data.megengine.org.cn/research/multimodality/dalle_new_variety.bpe)文件 29 | 30 | ```python 31 | from official.multimodal.dalle import coco_512_16_16d_16h_80tsl 32 | from official.multimodal.dalle import Generator 33 | 34 | dalle = coco_512_16_16d_16h_80tsl() 35 | 36 | generator = Generator( 37 | dalle, 38 | texts = ['A tower has a clock on it on a day with a blue sky'], 39 | num_images=64, 40 | batch_size=4, 41 | bpe_path = './dalle_new_variety.bpe', 42 | root='./dalle' 43 | ) 44 | 45 | generator() 46 | ``` 47 | 48 | 
生成结果如下所示: 49 | 50 | ![res](../../assets/total.png) 51 | 52 | 53 | ## 参考 54 | 55 | [DALLE-pytorch](https://github.com/lucidrains/DALLE-pytorch) 56 | 57 | [DALLE-pytorch-discussions](https://github.com/lucidrains/DALLE-pytorch/discussions/335) 58 | -------------------------------------------------------------------------------- /official/multimodal/dalle/__init__.py: -------------------------------------------------------------------------------- 1 | from .dalle import DALLE 2 | from .generate import Generator 3 | from .pretrained import coco_512_16_16d_16h_80tsl 4 | from .vae import ( 5 | OpenAIDiscreteVAE, 6 | OpenAIDiscreteVAEDecoder, 7 | OpenAIDiscreteVAEEncoder, 8 | VQGanVAE, 9 | openai_discrete_VAE_decoder, 10 | openai_discrete_VAE_encoder 11 | ) 12 | from .vae.vqgan_vae import vqgan_vae_1024 13 | -------------------------------------------------------------------------------- /official/multimodal/dalle/pretrained.py: -------------------------------------------------------------------------------- 1 | from megengine import hub 2 | 3 | from .dalle import DALLE 4 | from .vae.vqgan_vae import vqgan_vae_1024 5 | 6 | 7 | @hub.pretrained( 8 | "https://data.megengine.org.cn/research/multimodality/dalle_coco_512_16_16d_16h_80tsl.pkl" 9 | ) 10 | def coco_512_16_16d_16h_80tsl(): 11 | vae = vqgan_vae_1024(False) 12 | model = DALLE( 13 | num_text_tokens=8192, 14 | text_seq_len=80, 15 | embed_dim=512, 16 | vae=vae, 17 | num_heads=16, 18 | head_dim=64, 19 | stable=False, 20 | depths=16, 21 | attention_types=['row', 'row', 'column', 'row', 'row', 'row', 'column', 'full'] 22 | ) 23 | return model 24 | -------------------------------------------------------------------------------- /official/multimodal/dalle/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import youtokentome as yttm 3 | 4 | import megengine.functional as F 5 | from megengine import Tensor 6 | 7 | from ..clip.simple_tokenizer import SimpleTokenizer # pylint: disable=unused-import # noqa: F401 8 | 9 | 10 | class YttmTokenizer: 11 | def __init__(self, bpe_path: str): 12 | if not os.path.exists(bpe_path): 13 | raise ValueError(f'BPE json path {bpe_path} does not exist') 14 | 15 | tokenizer = yttm.BPE(model=bpe_path) 16 | self.tokenizer = tokenizer 17 | self.vocab_size = tokenizer.vocab_size() 18 | 19 | def decode(self, tokens, pad_tokens=(0, )): 20 | if isinstance(tokens, Tensor): 21 | tokens = tokens.tolist() 22 | 23 | return self.tokenizer.decode(tokens, ignore_ids=pad_tokens) 24 | 25 | def encode(self, texts): 26 | encoded = self.tokenizer.encode(texts, output_type=yttm.OutputType.ID) 27 | return list(map(Tensor, encoded)) 28 | 29 | def tokenize(self, texts, context_length=256, truncate_text=False): 30 | if isinstance(texts, str): 31 | texts = [texts] 32 | 33 | all_tokens = self.encode(texts) 34 | 35 | result = F.zeros((len(all_tokens), context_length), dtype='int32') 36 | for i, tokens in enumerate(all_tokens): 37 | if len(tokens) > context_length: 38 | if truncate_text: 39 | tokens = tokens[:context_length] 40 | else: 41 | raise RuntimeError( 42 | f"Input {texts[i]} is too long for context length {context_length}") 43 | result[i, :len(tokens)] = Tensor(tokens) 44 | 45 | return result 46 | -------------------------------------------------------------------------------- /official/multimodal/dalle/vae/__init__.py: -------------------------------------------------------------------------------- 1 | from .openai_dvae import DiscreteVAE as OpenAIDiscreteVAE 2 | from 
.openaidvae import ( 3 | OpenAIDiscreteVAEDecoder, 4 | OpenAIDiscreteVAEEncoder, 5 | map_pixels, 6 | openai_discrete_VAE_decoder, 7 | openai_discrete_VAE_encoder, 8 | unmap_pixels 9 | ) 10 | from .vqgan_vae import VQGanVAE 11 | -------------------------------------------------------------------------------- /official/multimodal/dalle/vae/base_vae.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import megengine.module as M 4 | 5 | 6 | class BaseVAE(M.Module): 7 | def __init__( 8 | self, 9 | num_layers: int, 10 | num_tokens: int, 11 | image_size: int, 12 | channels: int = 3, 13 | ): 14 | super(BaseVAE, self).__init__() 15 | 16 | self.channels = channels 17 | self.num_layers = num_layers 18 | self.num_tokens = num_tokens 19 | self.image_size = image_size 20 | 21 | @abstractmethod 22 | def get_codebook_indices(self, inputs): 23 | pass 24 | 25 | @abstractmethod 26 | def decode(self, inputs): 27 | pass 28 | 29 | def forward(self, inputs): 30 | raise NotImplementedError() 31 | -------------------------------------------------------------------------------- /official/multimodal/dalle/vae/openai_dvae.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import megengine.functional as F 4 | 5 | from .base_vae import BaseVAE 6 | from .openaidvae import openai_discrete_VAE_decoder, openai_discrete_VAE_encoder 7 | from .openaidvae.utils import map_pixels, unmap_pixels 8 | 9 | 10 | class DiscreteVAE(BaseVAE): 11 | def __init__( 12 | self, 13 | pretrained: bool = True 14 | ): 15 | super(DiscreteVAE, self).__init__( 16 | num_layers=3, 17 | num_tokens=8192, 18 | image_size=256, 19 | ) 20 | 21 | self.encoder = openai_discrete_VAE_encoder(pretrained=pretrained) 22 | self.decoder = openai_discrete_VAE_decoder(pretrained=pretrained) 23 | 24 | def get_codebook_indices(self, img): 25 | img = map_pixels(img) 26 | z_logits = self.encoder.blocks(img) 27 | z = F.argmax(z_logits, axis=1) 28 | z = F.flatten(z, 1) 29 | return z 30 | 31 | def decode(self, img_seq): 32 | b, n, = img_seq.shape 33 | L = int(math.sqrt(n)) 34 | img_seq = img_seq.reshape(b, L, L) 35 | 36 | z = F.one_hot(img_seq, num_classes=self.num_tokens) 37 | 38 | z = z.transpose(0, 3, 1, 2).astype('float32') 39 | x_stats = self.decoder(z).astype('float32') 40 | x_rec = unmap_pixels(F.sigmoid(x_stats[:, :3])) 41 | return x_rec 42 | 43 | def forward(self, inputs): 44 | raise NotImplementedError("Do not call forward method!") 45 | -------------------------------------------------------------------------------- /official/multimodal/dalle/vae/openaidvae/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder import Decoder as OpenAIDiscreteVAEDecoder 2 | from .decoder import openai_discrete_VAE_decoder 3 | from .encoder import Encoder as OpenAIDiscreteVAEEncoder 4 | from .encoder import openai_discrete_VAE_encoder 5 | from .utils import map_pixels, unmap_pixels 6 | -------------------------------------------------------------------------------- /official/multimodal/dalle/vae/openaidvae/decoder.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from functools import partial 3 | 4 | import megengine.module as M 5 | from megengine import hub 6 | 7 | from .utils import Upsample 8 | 9 | 10 | class DecoderBlock(M.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | layers, 
16 | ) -> None: 17 | super(DecoderBlock, self).__init__() 18 | assert out_channels % 4 == 0, "The output channel must be devided into 4" 19 | self.post_gain = 1 / (layers ** 2) 20 | hid_ch = out_channels // 4 21 | self.id_path = M.Conv2d( 22 | in_channels, out_channels, 1) if in_channels != out_channels else M.Identity() 23 | self.res_path = M.Sequential(OrderedDict([ 24 | ("relu1", M.ReLU()), 25 | ('conv_1', M.Conv2d(in_channels, hid_ch, 1)), 26 | ("relu2", M.ReLU()), 27 | ('conv_2', M.Conv2d(hid_ch, hid_ch, 3, padding=1)), 28 | ("relu3", M.ReLU()), 29 | ('conv_3', M.Conv2d(hid_ch, hid_ch, 3, padding=1)), 30 | ("relu4", M.ReLU()), 31 | ('conv_4', M.Conv2d(hid_ch, out_channels, 3, padding=1)), 32 | ])) 33 | 34 | def forward(self, x): 35 | return self.id_path(x) + self.post_gain * self.res_path(x) 36 | 37 | 38 | class Decoder(M.Module): 39 | def __init__(self, n_init=128, n_hid=256, n_blk_per_group=2, out_ch=3, vocab_size=8192): 40 | super(Decoder, self).__init__() 41 | group_count = 4 42 | n_layers = group_count * n_blk_per_group 43 | blk_range = range(n_blk_per_group) 44 | make_blk = partial(DecoderBlock, layers=n_layers) 45 | self.vocab_size = vocab_size 46 | self.blocks = M.Sequential(OrderedDict([ 47 | ('input', M.Conv2d(vocab_size, n_init, 1)), 48 | ('group_1', M.Sequential(OrderedDict([ 49 | *[(f'block_{i + 1}', make_blk(n_init if i == 0 else 8 50 | * n_hid, 8 * n_hid)) for i in blk_range], 51 | ('upsample', Upsample(scale_factor=2, mode='nearest')), 52 | ]))), 53 | ('group_2', M.Sequential(OrderedDict([ 54 | *[(f'block_{i + 1}', make_blk(8 * n_hid if i 55 | == 0 else 4 * n_hid, 4 * n_hid)) for i in blk_range], 56 | ('upsample', Upsample(scale_factor=2, mode='nearest')), 57 | ]))), 58 | ('group_3', M.Sequential(OrderedDict([ 59 | *[(f'block_{i + 1}', make_blk(4 * n_hid if i 60 | == 0 else 2 * n_hid, 2 * n_hid)) for i in blk_range], 61 | ('upsample', Upsample(scale_factor=2, mode='nearest')), 62 | ]))), 63 | ('group_4', M.Sequential(OrderedDict([ 64 | *[(f'block_{i + 1}', make_blk(2 * n_hid if i 65 | == 0 else 1 * n_hid, 1 * n_hid)) for i in blk_range], 66 | ]))), 67 | ('output', M.Sequential(OrderedDict([ 68 | ('relu', M.ReLU()), 69 | ('conv', M.Conv2d(1 * n_hid, 2 * out_ch, 1)), 70 | ]))), 71 | ])) 72 | 73 | def forward(self, x): 74 | if x.ndim != 4: 75 | raise ValueError("The input must be 4-dim") 76 | if x.shape[1] != self.vocab_size: 77 | raise ValueError( 78 | "The input must be the same shape as the vocab") 79 | # if x.dtype != "float32": 80 | # raise ValueError("The input must be float32") 81 | return self.blocks(x) 82 | 83 | 84 | @hub.pretrained( 85 | "https://data.megengine.org.cn/research/multimodality/dalle_openai_dvae_decoder.pkl" 86 | ) 87 | def openai_discrete_VAE_decoder(**kwargs): 88 | return Decoder(**kwargs) 89 | -------------------------------------------------------------------------------- /official/multimodal/dalle/vae/openaidvae/encoder.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from functools import partial 3 | 4 | import megengine.module as M 5 | from megengine import hub 6 | 7 | 8 | class EncoderBlock(M.Module): 9 | def __init__(self, n_in, n_out, layers): 10 | super(EncoderBlock, self).__init__() 11 | n_hid = n_out // 4 12 | self.pre_gain = 1 / (layers ** 2) 13 | self.id_path = M.Conv2d( 14 | n_in, n_out, 1) if n_in != n_out else M.Identity() 15 | self.res_path = M.Sequential(OrderedDict([ 16 | ("relu1", M.ReLU()), 17 | ('conv_1', M.Conv2d(n_in, n_hid, 3, padding=1)), 18 | 
("relu2", M.ReLU()), 19 | ('conv_2', M.Conv2d(n_hid, n_hid, 3, padding=1)), 20 | ("relu3", M.ReLU()), 21 | ('conv_3', M.Conv2d(n_hid, n_hid, 3, padding=1)), 22 | ("relu4", M.ReLU()), 23 | ('conv_4', M.Conv2d(n_hid, n_out, 1)), 24 | ])) 25 | 26 | def forward(self, x): 27 | return self.id_path(x) + self.pre_gain * self.res_path(x) 28 | 29 | 30 | class Encoder(M.Module): 31 | def __init__(self, input_channel=3, n_hid=256, n_blk_per_group=2, vocab_size=8192): 32 | super(Encoder, self).__init__() 33 | group_count = 4 34 | n_layers = group_count * n_blk_per_group 35 | blk_range = range(n_blk_per_group) 36 | make_blk = partial(EncoderBlock, layers=n_layers) 37 | self.input_channel = input_channel 38 | self.vocab_size = vocab_size 39 | self.blocks = M.Sequential(OrderedDict([ 40 | ('input', M.Conv2d(input_channel, n_hid, 7, padding=3)), 41 | ('group_1', M.Sequential(OrderedDict([ 42 | *[(f'block_{i + 1}', make_blk(n_hid, n_hid)) 43 | for i in blk_range], 44 | ('pool', M.MaxPool2d(kernel_size=2, stride=2)), 45 | ]))), 46 | ('group_2', M.Sequential(OrderedDict([ 47 | *[(f'block_{i + 1}', make_blk(n_hid if i 48 | == 0 else 2 * n_hid, 2 * n_hid)) for i in blk_range], 49 | ('pool', M.MaxPool2d(kernel_size=2, stride=2)), 50 | ]))), 51 | ('group_3', M.Sequential(OrderedDict([ 52 | *[(f'block_{i + 1}', make_blk(2 * n_hid if i 53 | == 0 else 4 * n_hid, 4 * n_hid)) for i in blk_range], 54 | ('pool', M.MaxPool2d(kernel_size=2, stride=2)), 55 | ]))), 56 | ('group_4', M.Sequential(OrderedDict([ 57 | *[(f'block_{i + 1}', make_blk(4 * n_hid if i 58 | == 0 else 8 * n_hid, 8 * n_hid)) for i in blk_range], 59 | ]))), 60 | ('output', M.Sequential(OrderedDict([ 61 | ('relu', M.ReLU()), 62 | ('conv', M.Conv2d(8 * n_hid, self.vocab_size, 1)), 63 | ]))), 64 | ])) 65 | 66 | def forward(self, x): 67 | if x.ndim != 4: 68 | raise ValueError("Input must be 4D tensor") 69 | if x.shape[1] != self.input_channel: 70 | raise ValueError( 71 | f"Input channel must be {self.input_channel}") 72 | return self.blocks(x) 73 | 74 | 75 | @hub.pretrained( 76 | "https://data.megengine.org.cn/research/multimodality/dalle_openai_dvae_encoder.pkl" 77 | ) 78 | def openai_discrete_VAE_encoder(**kwargs): 79 | return Encoder(**kwargs) 80 | -------------------------------------------------------------------------------- /official/multimodal/dalle/vae/openaidvae/utils.py: -------------------------------------------------------------------------------- 1 | import megengine.functional as F 2 | import megengine.module as M 3 | 4 | logit_laplace_eps: float = 0.1 5 | 6 | 7 | def map_pixels(x): 8 | if x.ndim != 4: 9 | raise ValueError('input must be 4D') 10 | return (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps 11 | 12 | 13 | def unmap_pixels(x): 14 | if x.ndim != 4: 15 | raise ValueError('input must be 4D') 16 | return F.clip((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1) 17 | 18 | 19 | class Upsample(M.Module): 20 | def __init__(self, scale_factor, mode): 21 | super().__init__() 22 | self.scale_factor = scale_factor 23 | self.mode = mode 24 | 25 | def forward(self, inputs): 26 | return F.nn.interpolate(inputs, scale_factor=self.scale_factor, mode=self.mode) 27 | -------------------------------------------------------------------------------- /official/multimodal/dalle/vae/vqgan_vae.py: -------------------------------------------------------------------------------- 1 | from math import log, sqrt 2 | from typing import Union 3 | 4 | import megengine.functional as F 5 | 6 | from ...taming_transformer.vqgan import GumbelVQ, VQModel, 
vqgan_imagenet_f16_1024 7 | from .base_vae import BaseVAE 8 | 9 | 10 | class VQGanVAE(BaseVAE): 11 | def __init__(self, model: Union[VQModel, GumbelVQ]): 12 | image_size = model.in_resolution 13 | num_layers = int(log(image_size / model.attn_resolution[0]) / log(2)) 14 | channels = model.in_channel 15 | num_tokens = model.quantize.num_embeddings 16 | 17 | super(VQGanVAE, self).__init__( 18 | num_layers, 19 | num_tokens, 20 | image_size, 21 | channels 22 | ) 23 | self.model = model 24 | 25 | self.is_gumbel = isinstance(model, GumbelVQ) 26 | 27 | def get_codebook_indices(self, img): 28 | b = img.shape[0] 29 | img = (2 * img) - 1 30 | _, _, [_, _, indices] = self.model.encode(img) 31 | if self.is_gumbel: 32 | return F.flatten(indices, 1) 33 | return indices.reshape(b, -1) 34 | 35 | def decode(self, img_seq): 36 | b, n = img_seq.shape 37 | one_hot_indices = F.one_hot(img_seq, num_classes=self.num_tokens).astype('float32') 38 | z = one_hot_indices @ self.model.quantize.embedding.weight 39 | 40 | c = z.shape[-1] 41 | z = z.reshape(b, int(sqrt(n)), -1, c).transpose(0, 3, 1, 2) 42 | img = self.model.decode(z) 43 | 44 | img = (F.clip(img, -1., 1.) + 1) * 0.5 45 | return img 46 | 47 | def forward(self): 48 | raise NotImplementedError() 49 | 50 | 51 | def vqgan_vae_1024(pretrained=True): 52 | vae = vqgan_imagenet_f16_1024(pretrained=pretrained) 53 | model = VQGanVAE(vae) 54 | return model 55 | -------------------------------------------------------------------------------- /official/multimodal/taming_transformer/README.md: -------------------------------------------------------------------------------- 1 | # Taming Transformer 2 | 3 | 此仓库包含MegEngine实现的`taming_transformer`模型代码及推理代码,但不包含训练代码。`taming_transformer`通过`VQGAN`将卷积的高效性和`Transformer`极强的表达能力相结合,拥有强大的图像重建和高分辨率图像合成能力。 4 | 5 | ## 图像重建 6 | 7 | 我们可以使用`VQGAN`来测试图像重建,`VQGAN`的结构参考与`Diffusion Model`,并且使用GAN的方式进行训练。其主要拥有两种不同的模型——`VQModel`和`GumbelVQ`,主要区别在于模型中的`quantize离散化`部分,`VQModel`使用`VQVAE`中的离散化方法,`GumbelVQ`则使用`Gumbel Softmax`进行离散化。 8 | 9 | 我们可以很方便的使用如下代码进行图像重建。 10 | 11 | ```python 12 | from official.multimodal.taming_transformer import Reconstruction 13 | 14 | # 加载模型及权重 15 | model = vqgan_imagenet_f16_16384(pretrained=True) 16 | 17 | # 传入模型 18 | rec = Reconstruction(model) 19 | 20 | image_path: str = ... 21 | # 传入图片路径和保存路径 22 | reconstructed_image = rec(image_path, file_name='reconstructed_image.png') 23 | ``` 24 | 25 | ## 从分割图采样 26 | 27 | `taming_transformer`可以利用分割图作为引导,逐步的从噪声中进行采样。可以使用如下代码进行采样。 28 | 29 | ```python 30 | from official.multimodal.taming_transformer import s_flckr_transformer 31 | # 加载模型及权重 32 | model = s_flckr_transformer(pretrained=True) 33 | 34 | sampler = ConditionalSampler( 35 | model, 36 | temperature=1.0, 37 | top_k=100, 38 | update_every=50, # 多少次采样保存一次图片 39 | scale_factor=1.0, # 对输入图片进行缩放 40 | animate=True, # 保存采样过程为mp4 41 | root='test', # 根目录,用于保存采样过程中的文件和视频 42 | seed=2022, # 固定随机种子 43 | kernal_size=16, # 每次采样的窗口大小,越大效果越好 44 | fps=15, # 保存视频的帧率 45 | segmentation_save=True # 为分割图使用专门的保存方式,保证每次推理保存的分割图色彩一致 46 | ) 47 | 48 | # 可以在official/multimodal/taming_transformer/data目录下找到更多图片 49 | segmentation_path: str = r"official/multimodal/taming_transformer/data/sflckr_segmentations/norway/25735082181_999927fe5a_b.png" 50 | # 传入分割图地址 51 | sampler.sample_segmentation(segmentation_path, name='norway') 52 | ``` 53 | 54 | 分割图如下所示: 55 | ![segmentation](../../assets/norway_segmentation.png) 56 | 57 | 采样结果如下所示: 58 | ![result](../../assets/norway_sample_2687.png) 59 | 多次运行即可获得更多样的结果 60 | 61 | 采样过程: 62 | 63 |