├── adamod
│   ├── __init__.py
│   └── adamod.py
├── demos
│   ├── cifar100
│   │   ├── requirement.txt.txt
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── densenet.py
│   │   │   └── resnet.py
│   │   ├── pretrained
│   │   │   ├── resnet-sgd-lr0.05-momentum0.9
│   │   │   ├── densenet-sgd-lr0.05-momentum0.9
│   │   │   ├── resnet-adam-lr0.001-betas0.9-0.999
│   │   │   ├── densenet-adam-lr0.001-betas0.9-0.999
│   │   │   ├── resnet-adamod-lr0.001-betas0.9-0.999-0.999
│   │   │   └── densenet-adamod-lr0.001-betas0.9-0.999-0.999
│   │   ├── README.md
│   │   ├── visualization.py
│   │   └── main.py
│   ├── README.md
│   └── nmt
│       ├── lr_scheduler
│       │   ├── __init__.py
│       │   └── cold_start_scheduler.py
│       └── README.md
├── img
│   └── Loss.bmp
├── release.sh
├── setup.py
├── README.md
└── LICENSE
/adamod/__init__.py:
--------------------------------------------------------------------------------
1 | from .adamod import AdaMod
--------------------------------------------------------------------------------
/demos/cifar100/requirement.txt.txt:
--------------------------------------------------------------------------------
1 | torch>=1.1.0
2 | torchvision>=0.3.0
--------------------------------------------------------------------------------
/img/Loss.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/AdaMod/HEAD/img/Loss.bmp
--------------------------------------------------------------------------------
/demos/cifar100/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .densenet import *
3 |
--------------------------------------------------------------------------------
/release.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python setup.py sdist bdist_wheel
3 | python -m twine upload dist/*
4 |
5 |
--------------------------------------------------------------------------------
/demos/cifar100/pretrained/resnet-sgd-lr0.05-momentum0.9:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/AdaMod/HEAD/demos/cifar100/pretrained/resnet-sgd-lr0.05-momentum0.9
--------------------------------------------------------------------------------
/demos/cifar100/pretrained/densenet-sgd-lr0.05-momentum0.9:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/AdaMod/HEAD/demos/cifar100/pretrained/densenet-sgd-lr0.05-momentum0.9
--------------------------------------------------------------------------------
/demos/cifar100/pretrained/resnet-adam-lr0.001-betas0.9-0.999:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/AdaMod/HEAD/demos/cifar100/pretrained/resnet-adam-lr0.001-betas0.9-0.999
--------------------------------------------------------------------------------
/demos/cifar100/pretrained/densenet-adam-lr0.001-betas0.9-0.999:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/AdaMod/HEAD/demos/cifar100/pretrained/densenet-adam-lr0.001-betas0.9-0.999
--------------------------------------------------------------------------------
/demos/cifar100/pretrained/resnet-adamod-lr0.001-betas0.9-0.999-0.999:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/AdaMod/HEAD/demos/cifar100/pretrained/resnet-adamod-lr0.001-betas0.9-0.999-0.999
--------------------------------------------------------------------------------
/demos/cifar100/pretrained/densenet-adamod-lr0.001-betas0.9-0.999-0.999:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/AdaMod/HEAD/demos/cifar100/pretrained/densenet-adamod-lr0.001-betas0.9-0.999-0.999
--------------------------------------------------------------------------------
/demos/README.md:
--------------------------------------------------------------------------------
1 | # Demos
2 |
3 | Here we provide some demos of using AdaMod on several benchmark tasks. The purpose of these demos is to give an example of how to use AdaMod in your research, and also to illustrate its robust performance.
4 |
5 | In short, AdaMod restricts the adaptive learning rates with adaptive and momental upper bounds. In this way, it **smooths out unexpected large learning rates and stabilizes the training of deep neural networks** (see the sketch below).
6 |
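For intuition, here is a minimal, self-contained sketch of the bounded update performed in [adamod.py](../adamod/adamod.py). The toy tensors stand in for a single parameter; the real optimizer keeps these buffers in its per-parameter state.

```python
import math
import torch

lr, beta1, beta2, beta3, eps, step = 1e-3, 0.9, 0.999, 0.999, 1e-8, 1

param = torch.randn(4)
grad = torch.randn(4)
exp_avg = torch.zeros(4)     # first moment, as in Adam
exp_avg_sq = torch.zeros(4)  # second moment, as in Adam
exp_avg_lr = torch.zeros(4)  # running average of the element-wise learning rates

exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

denom = exp_avg_sq.sqrt().add_(eps)
bias1, bias2 = 1 - beta1 ** step, 1 - beta2 ** step
step_size = torch.full_like(denom, lr * math.sqrt(bias2) / bias1).div_(denom)

# Momental bound: clip the current learning rates by their exponential moving average.
exp_avg_lr.mul_(beta3).add_(step_size, alpha=1 - beta3)
step_size = torch.min(step_size, exp_avg_lr)

param.add_(-step_size * exp_avg)
```
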
7 | In the NMT examples, you can observe that AdaMod achieves both faster convergence and stronger performance than vanilla Adam when training Transformer-based models, even **without warmup**. The other auxiliary examples demonstrate the versatility of our method.
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | __VERSION__ = '0.0.3'
4 |
5 | setup(name='adamod',
6 | version=__VERSION__,
7 |       description='AdaMod optimization algorithm, built on PyTorch.',
8 | long_description=open("README.md", encoding='utf-8').read(),
9 | long_description_content_type="text/markdown",
10 | keywords=['machine learning', 'deep learning'],
11 | classifiers=[
12 | 'Intended Audience :: Science/Research',
13 | 'Development Status :: 3 - Alpha',
14 | 'License :: OSI Approved :: Apache Software License',
15 | 'Programming Language :: Python :: 3.6',
16 | 'Programming Language :: Python :: 3.7',
17 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
18 | ],
19 | url='https://github.com/karrynest/AdaMod',
20 | author='Jianbang Ding',
21 | author_email='jianbangding@pku.edu.cn',
22 | license='Apache',
23 | packages=['adamod'],
24 | install_requires=[
25 | 'torch>=0.4.0',
26 | ],
27 | zip_safe=False,
28 | python_requires='>=3.6.0')
29 |
--------------------------------------------------------------------------------
/demos/nmt/lr_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | import importlib
9 | import os
10 |
11 | from .fairseq_lr_scheduler import FairseqLRScheduler
12 |
13 |
14 | LR_SCHEDULER_REGISTRY = {}
15 |
16 |
17 | def build_lr_scheduler(args, optimizer):
18 | return LR_SCHEDULER_REGISTRY[args.lr_scheduler](args, optimizer)
19 |
20 |
21 | def register_lr_scheduler(name):
22 | """Decorator to register a new LR scheduler."""
23 |
24 | def register_lr_scheduler_cls(cls):
25 | if name in LR_SCHEDULER_REGISTRY:
26 | raise ValueError('Cannot register duplicate LR scheduler ({})'.format(name))
27 | if not issubclass(cls, FairseqLRScheduler):
28 | raise ValueError('LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(name, cls.__name__))
29 | LR_SCHEDULER_REGISTRY[name] = cls
30 | return cls
31 |
32 | return register_lr_scheduler_cls
33 |
34 |
35 | # automatically import any Python files in the optim/lr_scheduler/ directory
36 | for file in os.listdir(os.path.dirname(__file__)):
37 | if file.endswith('.py') and not file.startswith('_'):
38 | module = file[:file.find('.py')]
39 | importlib.import_module('fairseq.optim.lr_scheduler.' + module)
40 |
--------------------------------------------------------------------------------
/demos/cifar100/README.md:
--------------------------------------------------------------------------------
1 | # Examples on CIFAR-100
2 |
3 | In this example, we test AdaMod on the standard CIFAR-100 image classification dataset, comparing it with SGD and Adam. The implementation is largely based on [this project](https://github.com/kuangliu/pytorch-cifar).
4 |
5 | Tested with PyTorch 1.1.0.
6 |
7 | ## Settings
8 |
9 | We provide the results produced by AdaMod with its default settings and by the baseline optimizers with their best hyperparameters. The best hyperparameters are listed below to ease reproduction:
10 |
11 | **ResNet-34/DenseNet-121:**
12 |
13 | | optimizer | lr | momentum | beta1 | beta2 | beta3 |
14 | | :---: | :---: | :---: | :---: | :---: | :---: |
15 | | SGD | 0.05 | 0.9 | | | |
16 | | Adam | 0.001 | | 0.9 | 0.999 | |
17 | | AdaMod (def.) | 0.001 | | 0.9 | 0.999 | 0.999 |
18 |
19 | For better performance, we apply a weight decay of `5e-4` to all optimizers (decoupled weight decay for the adaptive methods); the sketch below shows the corresponding constructions.
20 |
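For reference, these settings correspond to the following optimizer constructions (a sketch mirroring `create_optimizer` in [main.py](./main.py); the `nn.Linear` module is only a stand-in for ResNet-34 / DenseNet-121):

```python
import torch.nn as nn
import torch.optim as optim
from adamod import AdaMod

model = nn.Linear(10, 10)  # stand-in for ResNet-34 / DenseNet-121

sgd = optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
adam = optim.AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=5e-4)
adamod = AdaMod(model.parameters(), lr=0.001, betas=(0.9, 0.999), beta3=0.999, weight_decay=5e-4)
```
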
21 | ## Running by Yourself
22 |
23 | You may also run the experiments and visualize the results yourself. The following example trains DenseNet-121 using AdaMod with a learning rate of 0.001 and a smoothing coefficient (i.e. **beta3**) of 0.999.
24 |
25 | ```bash
26 | python main.py --model=densenet --optim=adamod --lr=0.001 --beta3=0.999
27 | ```
28 |
29 | The checkpoints will be saved in the `checkpoint` folder and the data points of the learning curve will be saved in the `curve` folder.
30 |
31 | ## Visualization
32 |
33 | You can run [visualization.py](./visualization.py) directly to visualize the performance of AdaMod, or call its `plot` function yourself as shown below.
34 |
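For example, from this directory (with the `pretrained` folder present) the pretrained DenseNet-121 curves can be plotted just as in `main()` of visualization.py:

```python
from visualization import plot

# Switch model='ResNet' or curve_type='train' as needed.
plot(use_pretrained=True, model='DenseNet', optimizers=['SGD', 'Adam', 'AdaMod'], curve_type='test')
```
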
35 | ## Acknowledgement
36 | The procedure for searching the best settings of the baseline optimizers follows Luo et al. (2019). [Adaptive Gradient Methods with Dynamic Bound of Learning Rate](https://openreview.net/forum?id=Bkg3g2R9FX). In *Proc. of ICLR 2019*.
37 |
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/demos/cifar100/visualization.py:
--------------------------------------------------------------------------------
1 | import os
2 | #%matplotlib notebook
3 | import matplotlib.pyplot as plt
4 | import torch
5 | import numpy as np
6 |
7 | LABELS = ['SGD','Adam', 'AdaMod']
8 |
9 | def get_folder_path(use_pretrained=True):
10 | if use_pretrained:
11 | path = 'pretrained'
12 | else:
13 | path = 'curve'
14 | return path
15 |
16 | def get_curve_data(use_pretrained=True, model='ResNet'):
17 | folder_path = get_folder_path(use_pretrained)
18 | filenames = [name for name in os.listdir(folder_path) if name.startswith(model.lower())]
19 | paths = [os.path.join(folder_path, name) for name in filenames]
20 | keys = [name.split('-')[1] for name in filenames]
21 | return {key: torch.load(fp) for key, fp in zip(keys, paths)}
22 |
23 |
24 | def plot(use_pretrained=True, model='ResNet', optimizers=None, curve_type='train'):
25 | assert model in ['ResNet', 'DenseNet'], 'Invalid model name: {}'.format(model)
26 | assert curve_type in ['train', 'test'], 'Invalid curve type: {}'.format(curve_type)
27 | assert all(_ in LABELS for _ in optimizers), 'Invalid optimizer'
28 |
29 | curve_data = get_curve_data(use_pretrained, model=model)
30 |
31 | plt.figure()
32 | plt.title('{} Accuracy for {} on CIFAR-100'.format(curve_type.capitalize(), model))
33 | plt.xlabel('Epoch')
34 | plt.ylabel('{} Accuracy %'.format(curve_type.capitalize()))
35 | if curve_type == 'train':
36 | plt.ylim(80, 101)
37 | else:
38 | plt.ylim(50, 81)
39 |
40 | for optim in optimizers:
41 | accuracies = np.array(curve_data[optim.lower()]['{}_acc'.format(curve_type)])
42 | plt.plot(accuracies, label=optim)
43 |
44 | plt.grid(ls='--')
45 | plt.legend()
46 |     plt.savefig('cifar100-{}-{}.png'.format(model, curve_type.capitalize()))
47 |     plt.show()
48 |
49 | def main():
50 | # plot(use_pretrained=True, model='ResNet', optimizers=LABELS, curve_type='train')
51 | # plot(use_pretrained=True, model='ResNet', optimizers=LABELS, curve_type='test')
52 |
53 | plot(use_pretrained=True, model='DenseNet', optimizers=LABELS, curve_type='train')
54 | plot(use_pretrained=True, model='DenseNet', optimizers=LABELS, curve_type='test')
55 |
56 | if __name__ == '__main__':
57 | main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaMod
2 |
3 | AdaMod is an optimizer that exerts adaptive, momental upper bounds on the individual learning rates, preventing them from becoming undesirably larger than what the historical statistics suggest. This avoids the non-convergence issue and leads to better performance. Strong empirical results on many deep learning applications demonstrate the effectiveness of the proposed method, especially on complex networks such as DenseNet and Transformer.
4 |
5 | Based on Ding et al. (2023). [An Adaptive Learning Method for Solving the Extreme Learning Rate Problem of Transformer.](https://link.springer.com/chapter/10.1007/978-3-031-44693-1_29)
6 |
7 |
8 |
9 | ## Installation
10 |
11 | AdaMod requires Python 3.6.0 or later.
12 |
13 | ### Installing via pip
14 |
15 | The preferred way to install AdaMod is via `pip` with a virtual environment.
16 | Just run
17 | ```bash
18 | pip install adamod
19 | ```
20 | in your Python environment and you are ready to go!
21 |
22 | ### Using source code
23 |
24 | As AdaMod is implemented in a single Python file of roughly 100 lines, an alternative is to directly download
25 | [adamod.py](./adamod/adamod.py) and copy it into your project.
26 |
27 | ## Usage
28 |
29 | You can use AdaMod just like any other PyTorch optimizer.
30 |
31 | ```python3
32 | optimizer = adamod.AdaMod(model.parameters(), lr=1e-3, beta3=0.999)
33 | ```
34 | As described in the paper, AdaMod smooths out unexpected large learning rates throughout the training process. The `beta3` parameter is the smoothing coefficient for the actual learning rates, which controls the averaging range. In common cases, a `beta3` in `{0.999, 0.9999}` achieves relatively good and stable results. See the paper for more details.
35 |
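A minimal end-to-end sketch (the model, data, and loss below are placeholders, not part of this repository):

```python
import torch
import torch.nn as nn
from adamod import AdaMod

model = nn.Linear(20, 2)           # placeholder model
criterion = nn.CrossEntropyLoss()  # placeholder loss
optimizer = AdaMod(model.parameters(), lr=1e-3, beta3=0.999)

for _ in range(100):
    inputs = torch.randn(32, 20)          # placeholder batch
    targets = torch.randint(0, 2, (32,))
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
```
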
36 | ## Citation
37 |
38 | If you use AdaMod in your research, please cite the final version: [An Adaptive Learning Method for Solving the Extreme Learning Rate Problem of Transformer](https://link.springer.com/chapter/10.1007/978-3-031-44693-1_29). Thanks!
39 | ```
40 | @inproceedings{DBLP:conf/nlpcc/DingRL23,
41 | author = {Jianbang Ding and Xuancheng Ren and Ruixuan Luo},
42 | title = {An Adaptive Learning Method for Solving the Extreme Learning Rate Problem of Transformer},
43 | booktitle = {{NLPCC} {(1)}},
44 | series = {Lecture Notes in Computer Science},
45 | volume = {14302},
46 | pages = {361--372},
47 | publisher = {Springer},
48 | year = {2023}
49 | }
50 | ```
51 |
52 | The arXiv version is available as an alternative:
53 | ```
54 | @article{ding2019adaptive,
55 | title={An Adaptive and Momental Bound Method for Stochastic Learning},
56 | author={Jianbang Ding and Xuancheng Ren and Ruixuan Luo and Xu Sun},
57 | journal={arXiv preprint arXiv:1910.12249},
58 | year={2019}
59 | }
60 | ```
61 |
62 | ## Demo
63 |
64 | For the full list of demos, please refer to [this page](./demos).
65 |
66 | ## Contributors
67 |
68 | [@dingjianbang](https://github.com/karrynest)
69 | [@luoruixuan](https://github.com/luoruixuan)
70 |
71 |
72 |
73 |
74 |
75 |
76 |
--------------------------------------------------------------------------------
/demos/nmt/lr_scheduler/cold_start_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | from . import FairseqLRScheduler, register_lr_scheduler
9 |
10 |
11 | @register_lr_scheduler('cold_start')
12 | class ColdStartSchedule(FairseqLRScheduler):
13 | """Decay the LR based on the inverse square root of the update number.
14 |
15 | We also support a warmup phase where we linearly increase the learning rate
16 | from some initial learning rate (``--warmup-init-lr``) until the configured
17 | learning rate (``--lr``). Thereafter we decay proportional to the number of
18 | updates, with a decay factor set to align with the configured learning rate.
19 |
20 | During warmup::
21 |
22 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
23 | lr = lrs[update_num]
24 |
25 | After warmup::
26 |
27 | decay_factor = args.lr * sqrt(args.warmup_updates)
28 | lr = decay_factor / sqrt(update_num)
29 | """
30 |
31 | def __init__(self, args, optimizer):
32 | super().__init__(args, optimizer)
33 | if len(args.lr) > 1:
34 | raise ValueError(
35 |                 'Cannot use a fixed learning rate schedule with cold_start.'
36 | ' Consider --lr-scheduler=fixed instead.'
37 | )
38 | warmup_end_lr = args.lr[0]
39 | if args.warmup_init_lr < 0:
40 | args.warmup_init_lr = warmup_end_lr
41 |
42 |         # step size of a linear warmup (unused; this schedule holds the LR constant during the cold start)
43 | self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
44 |
45 | # then, decay prop. to the inverse square root of the update number
46 | self.decay_factor = warmup_end_lr * args.warmup_updates**0.5
47 |
48 | # initial learning rate
49 | self.lr = args.warmup_init_lr
50 | self.optimizer.set_lr(self.lr)
51 |
52 | @staticmethod
53 | def add_args(parser):
54 | """Add arguments to the parser for this LR scheduler."""
55 | # fmt: off
56 | parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N',
57 | help='warmup the learning rate linearly for the first N updates')
58 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
59 | help='initial learning rate during warmup phase; default is args.lr')
60 | # fmt: on
61 |
62 | def step(self, epoch, val_loss=None):
63 | """Update the learning rate at the end of the given epoch."""
64 | super().step(epoch, val_loss)
65 | # we don't change the learning rate at epoch boundaries
66 | return self.optimizer.get_lr()
67 |
68 | def step_update(self, num_updates):
69 | """Update the learning rate after each update."""
70 | if num_updates < self.args.warmup_updates:
71 | self.lr = self.args.lr[0]
72 | else:
73 | self.lr = self.decay_factor * num_updates**-0.5
74 | self.optimizer.set_lr(self.lr)
75 | return self.lr
76 |
--------------------------------------------------------------------------------
/demos/nmt/README.md:
--------------------------------------------------------------------------------
1 | # Examples on NMT
2 |
3 | In this example, we test AdaMod on the IWSLT'14 De-En and WMT'14 En-De datasets, comparing it with Adam. The implementation is largely based on the [fairseq repo](https://github.com/pytorch/fairseq/tree/master/examples/translation).
4 |
5 | Tested with PyTorch 1.1.0.
6 |
7 | ## Settings
8 |
9 | ### IWSLT'14 German to English (Transformer-small)
10 | After downloading and preprocessing the data (please refer to the original [repo](https://github.com/pytorch/fairseq/tree/master/examples/translation)), we train a Transformer-small model on this data:
11 | ```bash
12 | CUDA_VISIBLE_DEVICES=0 fairseq-train \
13 | data-bin/iwslt14.tokenized.de-en \
14 | --arch transformer_iwslt_de_en \
15 | --optimizer adamod --adam-betas '(0.9, 0.98)' --beta3 0.9999 \
16 | --lr 5e-4 --lr-scheduler cold_start --warmup-updates 4000 \
17 | --dropout 0.3 --weight-decay 0.0001 \
18 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
19 | --max-tokens 4000 --max-update 50000
20 | ```
21 | Note that for a fair comparison with Adam, we still keep `--warmup-updates 4000` in this setting. In fact, on the IWSLT'14 De-En dataset, AdaMod does not depend on any `--lr-scheduler`: a constant learning rate (e.g. 5e-4) achieves a higher BLEU score of up to `35.1`. Moreover, if you further use the `--update-freq` option for delayed updates, the state-of-the-art result of `35.6` can be achieved. See the example below.
22 |
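For instance, a constant-learning-rate run can be sketched with fairseq's `fixed` scheduler in place of `cold_start` (an illustrative invocation rather than a tuned recipe; add `--update-freq` on top of it if you want delayed updates):

```bash
CUDA_VISIBLE_DEVICES=0 fairseq-train \
    data-bin/iwslt14.tokenized.de-en \
    --arch transformer_iwslt_de_en \
    --optimizer adamod --adam-betas '(0.9, 0.98)' --beta3 0.9999 \
    --lr 5e-4 --lr-scheduler fixed \
    --dropout 0.3 --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4000 --max-update 50000
```
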
23 | Then average the 10 latest checkpoints:
24 | ```bash
25 | python scripts/average_checkpoints.py --inputs checkpoints/transformer \
26 | --num-epoch-checkpoints 10 --output checkpoints/transformer/model.pt
27 | ```
28 |
29 | Finally, evaluate the trained model:
30 | ```bash
31 | fairseq-generate data-bin/iwslt14.tokenized.de-en \
32 | --path checkpoints/transformer/model.pt \
33 | --batch-size 128 --beam 5 --remove-bpe
34 | ```
35 |
36 | ### WMT'14 English to German (Transformer-base/big)
37 | Similarly, after processing the data (following [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762)), you can train a new model on this data.
38 |
39 | For Transformer-base:
40 | ```bash
41 | fairseq-train data-bin/wmt14_en_de \
42 | --arch transformer_wmt_en_de \
43 | --optimizer adamod --adam-betas '(0.9, 0.98)' --beta3 0.999 --clip-norm 0.0 \
44 | --lr-scheduler cold_start --warmup-updates 4000 \
45 | --lr 0.0005 --min-lr 1e-09 \
46 | --dropout 0.1 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
47 | --max-tokens 3584 --max-epoch 50 \
48 | --fp16
49 | ```
50 | Note that the `--fp16` flag requires CUDA 9.1 or greater and a Volta GPU.
51 |
52 | For Transformer-big:
53 | ```bash
54 | fairseq-train data-bin/wmt14_en_de \
55 | --arch transformer_vaswani_wmt_en_de_big \
56 | --optimizer adamod --adam-betas '(0.9, 0.98)' --beta3 0.9999 --clip-norm 0.0 \
57 | --lr-scheduler cold_start --warmup-updates 4000 \
58 | --lr 0.0005 --min-lr 1e-09 \
59 | --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
60 | --max-tokens 3584 --max-epoch 50 \
61 | --fp16
62 | ```
63 |
64 | Then average the 10 latest checkpoints:
65 | ```bash
66 | python scripts/average_checkpoints.py --inputs checkpoints/transformer \
67 | --num-epoch-checkpoints 10 --output checkpoints/transformer/model.pt
68 | ```
69 |
70 | Finally, evaluate the trained model:
71 | ```bash
72 | python generate.py data-bin/wmt14_en_de \
73 | --path checkpoints/transformer/model.pt \
74 | --batch-size 128 --beam 4 --remove-bpe --lenpen 0.6
75 | ```
76 |
--------------------------------------------------------------------------------
/demos/cifar100/models/densenet.py:
--------------------------------------------------------------------------------
1 | """
2 | .. Densely Connected Convolutional Networks:
3 | https://arxiv.org/abs/1608.06993
4 | """
5 | import math
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class Bottleneck(nn.Module):
13 | def __init__(self, in_planes, growth_rate):
14 | super(Bottleneck, self).__init__()
15 | self.bn1 = nn.BatchNorm2d(in_planes)
16 | self.conv1 = nn.Conv2d(in_planes, 4 * growth_rate, kernel_size=1, bias=False)
17 | self.bn2 = nn.BatchNorm2d(4 * growth_rate)
18 | self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
19 |
20 | def forward(self, x):
21 | out = self.conv1(F.relu(self.bn1(x)))
22 | out = self.conv2(F.relu(self.bn2(out)))
23 | out = torch.cat([out, x], 1)
24 | return out
25 |
26 |
27 | class Transition(nn.Module):
28 | def __init__(self, in_planes, out_planes):
29 | super(Transition, self).__init__()
30 | self.bn = nn.BatchNorm2d(in_planes)
31 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)
32 |
33 | def forward(self, x):
34 | out = self.conv(F.relu(self.bn(x)))
35 | out = F.avg_pool2d(out, 2)
36 | return out
37 |
38 |
39 | class DenseNet(nn.Module):
40 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=100):
41 | super(DenseNet, self).__init__()
42 | self.growth_rate = growth_rate
43 |
44 | num_planes = 2 * growth_rate
45 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)
46 |
47 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
48 | num_planes += nblocks[0] * growth_rate
49 | out_planes = int(math.floor(num_planes * reduction))
50 | self.trans1 = Transition(num_planes, out_planes)
51 | num_planes = out_planes
52 |
53 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
54 | num_planes += nblocks[1] * growth_rate
55 | out_planes = int(math.floor(num_planes * reduction))
56 | self.trans2 = Transition(num_planes, out_planes)
57 | num_planes = out_planes
58 |
59 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
60 | num_planes += nblocks[2] * growth_rate
61 | out_planes = int(math.floor(num_planes * reduction))
62 | self.trans3 = Transition(num_planes, out_planes)
63 | num_planes = out_planes
64 |
65 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
66 | num_planes += nblocks[3] * growth_rate
67 |
68 | self.bn = nn.BatchNorm2d(num_planes)
69 | self.linear = nn.Linear(num_planes, num_classes)
70 |
71 | def _make_dense_layers(self, block, in_planes, nblock):
72 | layers = []
73 | for i in range(nblock):
74 | layers.append(block(in_planes, self.growth_rate))
75 | in_planes += self.growth_rate
76 | return nn.Sequential(*layers)
77 |
78 | def forward(self, x):
79 | out = self.conv1(x)
80 | out = self.trans1(self.dense1(out))
81 | out = self.trans2(self.dense2(out))
82 | out = self.trans3(self.dense3(out))
83 | out = self.dense4(out)
84 | out = F.avg_pool2d(F.relu(self.bn(out)), 4)
85 | out = out.view(out.size(0), -1)
86 | out = self.linear(out)
87 | return out
88 |
89 |
90 | def DenseNet121():
91 | return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32)
92 |
93 |
94 | def DenseNet169():
95 | return DenseNet(Bottleneck, [6, 12, 32, 32], growth_rate=32)
96 |
97 |
98 | def DenseNet201():
99 | return DenseNet(Bottleneck, [6, 12, 48, 32], growth_rate=32)
100 |
101 |
102 | def DenseNet161():
103 | return DenseNet(Bottleneck, [6, 12, 36, 24], growth_rate=48)
104 |
105 |
106 | def densenet_cifar():
107 | return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=12)
108 |
109 |
110 | def test():
111 | net = densenet_cifar()
112 | x = torch.randn(1, 3, 32, 32)
113 | y = net(x)
114 | print(y)
115 |
116 | # test()
117 |
--------------------------------------------------------------------------------
/demos/cifar100/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | .. Deep Residual Learning for Image Recognition:
3 | https://arxiv.org/abs/1512.03385
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class BasicBlock(nn.Module):
11 | expansion = 1
12 |
13 | def __init__(self, in_planes, planes, stride=1):
14 | super(BasicBlock, self).__init__()
15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1,
16 | bias=False)
17 | self.bn1 = nn.BatchNorm2d(planes)
18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
19 | self.bn2 = nn.BatchNorm2d(planes)
20 |
21 | self.shortcut = nn.Sequential()
22 | if stride != 1 or in_planes != self.expansion * planes:
23 | self.shortcut = nn.Sequential(
24 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride,
25 | bias=False),
26 | nn.BatchNorm2d(self.expansion * planes)
27 | )
28 |
29 | def forward(self, x):
30 | out = F.relu(self.bn1(self.conv1(x)))
31 | out = self.bn2(self.conv2(out))
32 | out += self.shortcut(x)
33 | out = F.relu(out)
34 | return out
35 |
36 |
37 | class Bottleneck(nn.Module):
38 | expansion = 4
39 |
40 | def __init__(self, in_planes, planes, stride=1):
41 | super(Bottleneck, self).__init__()
42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
43 | self.bn1 = nn.BatchNorm2d(planes)
44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
45 | self.bn2 = nn.BatchNorm2d(planes)
46 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
47 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
48 |
49 | self.shortcut = nn.Sequential()
50 | if stride != 1 or in_planes != self.expansion * planes:
51 | self.shortcut = nn.Sequential(
52 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride,
53 | bias=False),
54 | nn.BatchNorm2d(self.expansion * planes)
55 | )
56 |
57 | def forward(self, x):
58 | out = F.relu(self.bn1(self.conv1(x)))
59 | out = F.relu(self.bn2(self.conv2(out)))
60 | out = self.bn3(self.conv3(out))
61 | out += self.shortcut(x)
62 | out = F.relu(out)
63 | return out
64 |
65 |
66 | class ResNet(nn.Module):
67 | def __init__(self, block, num_blocks, num_classes=100):
68 | super(ResNet, self).__init__()
69 | self.in_planes = 64
70 |
71 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
72 | self.bn1 = nn.BatchNorm2d(64)
73 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
74 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
75 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
76 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
77 | self.linear = nn.Linear(512 * block.expansion, num_classes)
78 |
79 | def _make_layer(self, block, planes, num_blocks, stride):
80 | strides = [stride] + [1] * (num_blocks - 1)
81 | layers = []
82 | for stride in strides:
83 | layers.append(block(self.in_planes, planes, stride))
84 | self.in_planes = planes * block.expansion
85 | return nn.Sequential(*layers)
86 |
87 | def forward(self, x):
88 | out = F.relu(self.bn1(self.conv1(x)))
89 | out = self.layer1(out)
90 | out = self.layer2(out)
91 | out = self.layer3(out)
92 | out = self.layer4(out)
93 | out = F.avg_pool2d(out, 4)
94 | out = out.view(out.size(0), -1)
95 | out = self.linear(out)
96 | return out
97 |
98 |
99 | def ResNet18():
100 | return ResNet(BasicBlock, [2, 2, 2, 2])
101 |
102 |
103 | def ResNet34():
104 | return ResNet(BasicBlock, [3, 4, 6, 3])
105 |
106 |
107 | def ResNet50():
108 | return ResNet(Bottleneck, [3, 4, 6, 3])
109 |
110 |
111 | def ResNet101():
112 | return ResNet(Bottleneck, [3, 4, 23, 3])
113 |
114 |
115 | def ResNet152():
116 | return ResNet(Bottleneck, [3, 8, 36, 3])
117 |
118 |
119 | def test():
120 | net = ResNet18()
121 | y = net(torch.randn(1, 3, 32, 32))
122 | print(y.size())
123 |
124 | # test()
125 |
--------------------------------------------------------------------------------
/adamod/adamod.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim import Optimizer
4 |
5 | class AdaMod(Optimizer):
6 | """Implements AdaMod algorithm with Decoupled Weight Decay (arxiv.org/abs/1711.05101)
7 |     It has been proposed in `An Adaptive and Momental Bound Method for Stochastic Learning`_.
8 | Arguments:
9 | params (iterable): iterable of parameters to optimize or dicts defining
10 | parameter groups
11 | lr (float, optional): learning rate (default: 1e-3)
12 | betas (Tuple[float, float], optional): coefficients used for computing
13 | running averages of gradient and its square (default: (0.9, 0.999))
14 |         beta3 (float, optional): smoothing coefficient for adaptive learning rates (default: 0.999)
15 | eps (float, optional): term added to the denominator to improve
16 | numerical stability (default: 1e-8)
17 |         weight_decay (float, optional): decoupled weight decay coefficient (default: 0)
18 | """
19 |
20 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), beta3=0.999,
21 | eps=1e-8, weight_decay=0):
22 | if not 0.0 <= lr:
23 | raise ValueError("Invalid learning rate: {}".format(lr))
24 | if not 0.0 <= eps:
25 | raise ValueError("Invalid epsilon value: {}".format(eps))
26 | if not 0.0 <= betas[0] < 1.0:
27 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
28 | if not 0.0 <= betas[1] < 1.0:
29 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
30 | if not 0.0 <= beta3 < 1.0:
31 | raise ValueError("Invalid beta3 parameter: {}".format(beta3))
32 | defaults = dict(lr=lr, betas=betas, beta3=beta3, eps=eps,
33 | weight_decay=weight_decay)
34 | super(AdaMod, self).__init__(params, defaults)
35 |
36 | def __setstate__(self, state):
37 | super(AdaMod, self).__setstate__(state)
38 |
39 | def step(self, closure=None):
40 | """Performs a single optimization step.
41 | Arguments:
42 | closure (callable, optional): A closure that reevaluates the model
43 | and returns the loss.
44 | """
45 | loss = None
46 | if closure is not None:
47 | loss = closure()
48 |
49 | for group in self.param_groups:
50 | for p in group['params']:
51 | if p.grad is None:
52 | continue
53 | grad = p.grad.data
54 | if grad.is_sparse:
55 | raise RuntimeError(
56 | 'AdaMod does not support sparse gradients')
57 |
58 | state = self.state[p]
59 |
60 | # State initialization
61 | if len(state) == 0:
62 | state['step'] = 0
63 | # Exponential moving average of gradient values
64 | state['exp_avg'] = torch.zeros_like(p.data)
65 | # Exponential moving average of squared gradient values
66 | state['exp_avg_sq'] = torch.zeros_like(p.data)
67 | # Exponential moving average of actual learning rates
68 | state['exp_avg_lr'] = torch.zeros_like(p.data)
69 |
70 | exp_avg, exp_avg_sq, exp_avg_lr = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_lr']
71 | beta1, beta2 = group['betas']
72 |
73 | state['step'] += 1
74 |
75 | # Decay the first and second moment running average coefficient
76 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
77 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
78 |
79 | denom = exp_avg_sq.sqrt().add_(group['eps'])
80 |
81 | bias_correction1 = 1 - beta1 ** state['step']
82 | bias_correction2 = 1 - beta2 ** state['step']
83 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
84 |
85 | if group['weight_decay'] != 0:
86 | p.data.add_(-group['weight_decay'] * group['lr'], p.data)
87 |
88 | # Applies momental bounds on actual learning rates
89 | step_size = torch.full_like(denom, step_size)
90 | step_size.div_(denom)
91 | exp_avg_lr.mul_(group['beta3']).add_(1 - group['beta3'], step_size)
92 | step_size = torch.min(step_size, exp_avg_lr)
93 | step_size.mul_(exp_avg)
94 |
95 | p.data.add_(-step_size)
96 |
97 | return loss
98 |
--------------------------------------------------------------------------------
/demos/cifar100/main.py:
--------------------------------------------------------------------------------
1 | """Train CIFAR100 with PyTorch."""
2 | from __future__ import print_function
3 |
4 | import torch
5 | import torch.optim as optim
6 | import torch.backends.cudnn as cudnn
7 | import torchvision
8 | import torchvision.transforms as transforms
9 |
10 | import os
11 | import argparse
12 |
13 | from models import *
14 | from adamod import AdaMod
15 |
16 | def get_parser():
17 | parser = argparse.ArgumentParser(description='PyTorch CIFAR100 Training')
18 | parser.add_argument('--model', default='resnet', type=str, help='model',
19 | choices=['resnet', 'densenet'])
20 | parser.add_argument('--optim', default='adamod', type=str, help='optimizer',
21 | choices=['sgd', 'adam', 'adamod'])
22 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
23 | parser.add_argument('--beta3', default=0.999, type=float,
24 |                         help='smoothing coefficient term of AdaMod')
25 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum term')
26 | parser.add_argument('--beta1', default=0.9, type=float, help='Adam coefficients beta_1')
27 | parser.add_argument('--beta2', default=0.999, type=float, help='Adam coefficients beta_2')
28 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
29 | parser.add_argument('--weight_decay', default=5e-4, type=float,
30 | help='weight decay for optimizers')
31 | return parser
32 |
33 |
34 | def build_dataset():
35 | print('==> Preparing data..')
36 | transform_train = transforms.Compose([
37 | transforms.RandomCrop(32, padding=4),
38 | transforms.RandomHorizontalFlip(),
39 | transforms.ToTensor(),
40 | transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
41 | ])
42 |
43 | transform_test = transforms.Compose([
44 | transforms.ToTensor(),
45 | transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
46 | ])
47 |
48 | trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True,
49 | transform=transform_train)
50 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True,
51 | num_workers=2)
52 |
53 | testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True,
54 | transform=transform_test)
55 | test_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
56 |
57 | # classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
58 |
59 | return train_loader, test_loader
60 |
61 |
62 | def get_ckpt_name(dataset='cifar100', model='resnet', optimizer='adamod', lr=0.1, momentum=0.9,
63 | beta1=0.9, beta2=0.999, beta3=0.999):
64 | name = {
65 | 'sgd': 'lr{}-momentum{}'.format(lr, momentum),
66 | 'adam': 'lr{}-betas{}-{}'.format(lr, beta1, beta2),
67 | 'adamod': 'lr{}-betas{}-{}-{}'.format(lr, beta1, beta2, beta3),
68 | }[optimizer]
69 | return '{}-{}-{}'.format(model, optimizer, name)
70 |
71 |
72 | def load_checkpoint(ckpt_name):
73 | print('==> Resuming from checkpoint..')
74 | path = os.path.join('checkpoint', ckpt_name)
75 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
76 | assert os.path.exists(path), 'Error: checkpoint {} not found'.format(ckpt_name)
77 |     return torch.load(path)
78 |
79 |
80 | def build_model(args, device, ckpt=None):
81 | print('==> Building model..')
82 | net = {
83 | 'resnet': ResNet34,
84 | 'densenet': DenseNet121,
85 | }[args.model]()
86 | net = net.to(device)
87 | if device == 'cuda':
88 | net = torch.nn.DataParallel(net)
89 | cudnn.benchmark = True
90 |
91 | if ckpt:
92 | net.load_state_dict(ckpt['net'])
93 |
94 | return net
95 |
96 |
97 | def create_optimizer(args, model_params):
98 | if args.optim == 'sgd':
99 | return optim.SGD(model_params, args.lr, momentum=args.momentum,
100 | weight_decay=args.weight_decay)
101 | elif args.optim == 'adam':
102 | return optim.AdamW(model_params, args.lr, betas=(args.beta1, args.beta2),
103 | weight_decay=args.weight_decay)
104 | elif args.optim == 'adamod':
105 | return AdaMod(model_params, args.lr, betas=(args.beta1, args.beta2),
106 | beta3=args.beta3, weight_decay=args.weight_decay)
107 |
108 | def train(net, epoch, device, data_loader, optimizer, criterion):
109 | print('\nEpoch: %d' % epoch)
110 | net.train()
111 | train_loss = 0
112 | correct = 0
113 | total = 0
114 | for batch_idx, (inputs, targets) in enumerate(data_loader):
115 | inputs, targets = inputs.to(device), targets.to(device)
116 | optimizer.zero_grad()
117 | outputs = net(inputs)
118 | loss = criterion(outputs, targets)
119 | loss.backward()
120 | optimizer.step()
121 |
122 | train_loss += loss.item()
123 | _, predicted = outputs.max(1)
124 | total += targets.size(0)
125 | correct += predicted.eq(targets).sum().item()
126 |
127 | accuracy = 100. * correct / total
128 | print('train acc %.3f' % accuracy)
129 |
130 | return accuracy
131 |
132 |
133 | def test(net, device, data_loader, criterion):
134 | net.eval()
135 | test_loss = 0
136 | correct = 0
137 | total = 0
138 | with torch.no_grad():
139 | for batch_idx, (inputs, targets) in enumerate(data_loader):
140 | inputs, targets = inputs.to(device), targets.to(device)
141 | outputs = net(inputs)
142 | loss = criterion(outputs, targets)
143 |
144 | test_loss += loss.item()
145 | _, predicted = outputs.max(1)
146 | total += targets.size(0)
147 | correct += predicted.eq(targets).sum().item()
148 |
149 | accuracy = 100. * correct / total
150 | print(' test acc %.3f' % accuracy)
151 |
152 | return accuracy
153 |
154 |
155 | def main():
156 | parser = get_parser()
157 | args = parser.parse_args()
158 |
159 | train_loader, test_loader = build_dataset()
160 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
161 |
162 | ckpt_name = get_ckpt_name(model=args.model, optimizer=args.optim, lr=args.lr,
163 | momentum=args.momentum, beta1=args.beta1, beta2=args.beta2, beta3=args.beta3)
164 | if args.resume:
165 | ckpt = load_checkpoint(ckpt_name)
166 | best_acc = ckpt['acc']
167 | start_epoch = ckpt['epoch']
168 | else:
169 | ckpt = None
170 | best_acc = 0
171 | start_epoch = -1
172 |
173 | net = build_model(args, device, ckpt=ckpt)
174 | criterion = nn.CrossEntropyLoss()
175 | optimizer = create_optimizer(args, net.parameters())
176 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [150, 225], gamma=0.1,
177 | last_epoch=start_epoch)
178 | train_accuracies = []
179 | test_accuracies = []
180 |
181 | for epoch in range(start_epoch + 1, 300):
182 | scheduler.step()
183 | train_acc = train(net, epoch, device, train_loader, optimizer, criterion)
184 | test_acc = test(net, device, test_loader, criterion)
185 |
186 | # Save checkpoint.
187 | if test_acc > best_acc:
188 | print('Saving..')
189 | state = {
190 | 'net': net.state_dict(),
191 | 'acc': test_acc,
192 | 'epoch': epoch,
193 | }
194 | if not os.path.isdir('checkpoint'):
195 | os.mkdir('checkpoint')
196 | torch.save(state, os.path.join('checkpoint', ckpt_name))
197 | best_acc = test_acc
198 |
199 | train_accuracies.append(train_acc)
200 | test_accuracies.append(test_acc)
201 | if not os.path.isdir('curve'):
202 | os.mkdir('curve')
203 | torch.save({'train_acc': train_accuracies, 'test_acc': test_accuracies},
204 | os.path.join('curve', ckpt_name))
205 |
206 |
207 | if __name__ == '__main__':
208 | main()
209 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------