210 |
211 |
212 |
217 |
218 |
219 |
220 |
221 |
222 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
--------------------------------------------------------------------------------
/docs/searchindex.js:
--------------------------------------------------------------------------------
1 | Search.setIndex({docnames:["Criterion","Data","Lr_scheduler","Metric","Model","Optim","Vocab","index","notes/intro"],envversion:{"sphinx.domains.c":2,"sphinx.domains.changeset":1,"sphinx.domains.citation":1,"sphinx.domains.cpp":3,"sphinx.domains.index":1,"sphinx.domains.javascript":2,"sphinx.domains.math":2,"sphinx.domains.python":2,"sphinx.domains.rst":2,"sphinx.domains.std":2,"sphinx.ext.intersphinx":1,"sphinx.ext.todo":2,"sphinx.ext.viewcode":1,sphinx:56},filenames:["Criterion.rst","Data.rst","Lr_scheduler.rst","Metric.rst","Model.rst","Optim.rst","Vocab.rst","index.rst","notes/intro.md"],objects:{"lightning_asr.criterion":{joint_ctc_cross_entropy:[0,0,0,"-"],label_smoothed_cross_entropy:[0,0,0,"-"]},"lightning_asr.criterion.joint_ctc_cross_entropy":{JointCTCCrossEntropyLoss:[0,1,1,""]},"lightning_asr.criterion.label_smoothed_cross_entropy":{LabelSmoothedCrossEntropyLoss:[0,1,1,""]},"lightning_asr.data":{data_loader:[1,0,0,"-"],dataset:[1,0,0,"-"]},"lightning_asr.data.data_loader":{AudioDataLoader:[1,1,1,""],BucketingSampler:[1,1,1,""]},"lightning_asr.data.dataset":{AudioDataset:[1,1,1,""],FBankDataset:[1,1,1,""],MFCCDataset:[1,1,1,""],MelSpectrogramDataset:[1,1,1,""],SpectrogramDataset:[1,1,1,""]},"lightning_asr.data.librispeech":{preprocess:[1,0,0,"-"]},"lightning_asr.data.librispeech.preprocess":{collect_transcripts:[1,2,1,""],generate_manifest_file:[1,2,1,""],prepare_tokenizer:[1,2,1,""]},"lightning_asr.metric":{ErrorRate:[3,1,1,""],WordErrorRate:[3,1,1,""]},"lightning_asr.metric.WordErrorRate":{metric:[3,3,1,""]},"lightning_asr.optim":{adamp:[5,0,0,"-"],optimizer:[5,0,0,"-"],radam:[5,0,0,"-"]},"lightning_asr.optim.adamp":{AdamP:[5,1,1,""]},"lightning_asr.optim.adamp.AdamP":{step:[5,3,1,""]},"lightning_asr.optim.lr_scheduler":{lr_scheduler:[2,0,0,"-"],transformer_lr_scheduler:[2,0,0,"-"],tri_stage_lr_scheduler:[2,0,0,"-"]},"lightning_asr.optim.lr_scheduler.lr_scheduler":{LearningRateScheduler:[2,1,1,""]},"lightning_asr.optim.lr_scheduler.transformer_lr_sc
heduler":{TransformerLRScheduler:[2,1,1,""]},"lightning_asr.optim.lr_scheduler.tri_stage_lr_scheduler":{TriStageLRScheduler:[2,1,1,""]},"lightning_asr.optim.optimizer":{Optimizer:[5,1,1,""]},"lightning_asr.optim.radam":{RAdam:[5,1,1,""]},"lightning_asr.optim.radam.RAdam":{step:[5,3,1,""]},"lightning_asr.vocabs":{librispeech:[6,0,0,"-"],vocab:[6,0,0,"-"]},"lightning_asr.vocabs.librispeech":{LibriSpeechVocabulary:[6,1,1,""]},"lightning_asr.vocabs.vocab":{Vocabulary:[6,1,1,""]},lightning_asr:{metric:[3,0,0,"-"]}},objnames:{"0":["py","module","Python module"],"1":["py","class","Python class"],"2":["py","function","Python function"],"3":["py","method","Python method"]},objtypes:{"0":"py:module","1":"py:class","2":"py:function","3":"py:method"},terms:{"001":5,"03762":2,"08779":2,"16000":1,"1706":2,"1904":2,"2020":5,"999":5,"abstract":8,"class":[0,1,2,3,5,6],"default":0,"float":[0,1],"function":[0,5],"import":8,"int":[0,1,5,6],"new":8,"return":[0,5],"short":1,"true":[5,8],But:8,For:8,The:8,abil:8,abs:2,adam:5,adapt:5,after:[2,3],all:8,ani:8,apach:5,appli:1,applic:8,apply_spec_aug:1,appreci:8,area:1,arg:[1,6],arxiv:2,asr:8,assum:1,attent:8,audio:1,audio_path:1,audiodataload:1,audiodataset:1,augment:1,author:7,avail:8,bank:1,batch:1,batch_siz:1,befor:8,below:8,beta:5,between:[1,3],beyond:5,bin:8,blank:0,blank_id:0,bool:1,bucketingsampl:1,bug:8,build:8,cach:8,calcuat:3,calcul:0,callabl:5,can:8,checkout:8,classfic:0,classif:0,classs:5,clip:5,clone:8,closur:5,clovaai:5,cluster:8,coeffici:1,collabor:8,collect:1,collect_transcript:1,com:[5,8],command:8,complex:8,complic:8,comput:3,conda:8,confid:0,configur:8,conform:8,consist:8,contact:8,contribut:7,control:8,convert:6,copi:5,copyright:5,core:8,corp:5,correct:8,correspond:8,could:8,cpp_ext:8,creat:8,criterion:7,cross:0,cross_entropy_weight:0,ctc:8,ctc_weight:0,cuda:8,cuda_ext:8,current:8,data:7,data_load:1,data_sourc:1,dataset:[6,8],dataset_download:8,dataset_path:[1,8],decai:2,decay_step:2,decod:8,defin:3,degenerated_to_sgd:5,de
lta:5,depend:8,detail:8,develop:8,dim:0,dimens:0,dir:8,directli:[0,1,2,3,6],discuss:8,distanc:3,distribut:0,docstr:8,document:8,down:5,dynam:8,easi:8,ecognit:8,edit:3,emploi:2,encod:[0,8],entropi:0,env:8,environ:8,eos:1,eos_id:1,eps:5,error:3,errorr:3,especi:8,exactli:8,exampl:8,exponeti:2,extrem:8,fals:[1,5],familiar:8,fast:8,faster:8,fbankdataset:1,featur:8,feedback:8,feel:8,field:5,file:[1,8],filter:1,final_lr:2,final_lr_scal:2,fix:8,flag:1,follow:8,forg:8,fourier:1,frame:1,frame_length:1,frame_shift:1,framework:8,free:8,freq:1,freq_mask_num:1,freq_mask_para:1,from:[0,2,5],gcc:8,gener:[1,8],generate_manifest_fil:1,git:8,github:[5,8],given:5,global:8,gmail:8,gpu:8,grad:5,gradient:5,ground:0,guidelin:8,has:0,have:8,here:8,hierarch:8,high:8,higher:8,hold:2,hold_step:2,hop:1,hope:8,how:[1,8],http:[2,5,8],hydra:8,hyper:1,identif:[0,1],ighthn:8,ignor:0,ignore_index:0,implement:[2,8],improv:8,increas:2,index:[0,7],indic:1,init_lr:2,init_lr_scal:2,input:0,instal:7,instanti:5,intefac:[2,3],integ:0,introduc:8,introduct:7,invari:5,inverse_squre_root:2,issu:8,iter:[2,8],joint:8,joint_ctc_cross_entropi:0,jointctccrossentropyloss:0,keep:2,kim:8,kind:8,kospeech:5,kwarg:[1,6],label:6,label_smooth:0,label_smoothed_cross_entropi:0,labelsmoothedcrossentropyloss:0,lasr:8,learn:[2,5],learningrateschedul:2,length:1,librari:8,librispeech:8,librispeechvocabulari:6,librosa:8,licens:5,lightn:8,lightning_asr:[0,1,2,3,5,6],lightweight:8,like:8,limit:1,linearli:2,list:1,liyuanlucasliu:5,load:8,logarithm:0,logit:0,loss:5,lr_schedul:[2,5],lstm:8,main:8,major:8,make:[1,8],mani:1,manifest:1,mask:1,match:1,max_grad_norm:5,mean:[0,8],mel:1,melspectrogramdataset:1,method:0,metric:7,mfc:1,mfcc:1,mfccdataset:1,might:8,mit:[5,8],mix:8,model:[0,5,6,7,8],model_path:6,modifi:5,modul:[7,8],momentum:5,most:5,multi:8,naver:5,need:8,nesterov:5,none:[0,5],norm:5,num_class:0,num_mel:1,number:[0,1],numpi:8,nvidia:8,object:5,onc:8,one:[1,2,3,6],onli:8,open:8,optim:[2,7],option:[0,5,8],order:1,org:2,otherwis:5,pa
ge:7,paper:5,param:5,paramet:[0,1,3,5,6],part:1,path:[1,6],pdf:2,peak_lr:2,peech:8,pep:8,perform:[5,8],pip:8,pleas:8,point:0,precis:8,prepar:1,prepare_token:1,privid:0,probabl:0,problem:8,proce:8,project:8,provid:[2,3,5,8],python:8,pytorch:8,question:8,rate:[1,2,3,5],ratio:0,recognit:8,recommend:8,reduct:0,reevalu:5,refer:[5,8],report:8,request:8,research:8,retain:1,run:8,sampl:1,sample_r:1,scalabl:8,scale:5,schedul:[5,7],scheduler_period:5,search:7,see:8,sentenc:3,sentencepic:1,sentencepiec:[6,8],separ:3,setuptool:8,sgd:5,sh951011:8,shape:0,should:5,similar:2,similarli:1,simpli:8,simplifi:8,singl:5,size:[1,6],slow:5,slowdown:5,small:8,solv:8,sooftwar:8,soohwan:8,sos:1,sos_id:1,sourc:[0,1,2,3,5,6],space:3,spec:1,specifi:5,spectrogram:1,spectrogramdataset:1,start:2,step:5,stft:1,str:[0,1,6],string:[3,6],structur:8,sub:[1,2,3,6],sum:0,support:8,sure:8,target:0,tensor:0,thei:1,thi:[1,2,3,5,6,8],thing:8,those:8,three:2,thruth:0,time:1,time_mask_num:1,timestep:5,togeth:1,token:[1,3],torch:[0,5],torchaudio:8,total_step:2,tpu:8,train_transcript:1,transcript:1,transform:1,transformer_lr_schedul:2,transformerlrschedul:2,transript:1,tri_stag:2,tri_stage_lr_schedul:2,tristagelrschedul:2,troubleshoot:7,two:3,under:8,unit:3,unless:5,updat:5,upgrad:8,use:[1,2,3,6,8],used:5,user:8,using:8,utomat:8,valu:[0,5],varianc:5,version:8,virtual:8,visit:8,vocab:[3,7],vocab_s:[1,6],vocabulari:6,warmup:2,warmup_step:2,wd_ratio:5,websit:8,weight:[0,5],weight_decai:5,when:[0,5],whether:1,which:0,who:8,window:1,word:[0,3],worderrorr:3,wrapper:[5,8],you:8,your:8},titles:["Criterion","Data","LR Scheduler","Metric","Model","Optim","Vocabs","Welcome to Lightning ASR\u2019s 
documentation!","Introduction"],titleterms:{activ:4,adamp:5,apex:8,asr:7,attent:4,author:8,bit:8,code:8,contribut:8,convolut:4,criterion:0,crossentropi:0,ctc:0,data:1,dataset:1,decod:4,document:7,embed:4,encod:4,feed:4,forward:4,from:8,get:[7,8],indic:7,instal:8,introduct:8,joint:0,label:0,librari:7,librispeech:[1,6],licens:8,lightn:[1,7],loader:1,loss:0,metric:3,model:4,modul:[1,4],optim:5,preprocess:1,prerequisit:8,radam:5,recogn:8,refer:7,schedul:2,smooth:0,sourc:8,speech:8,stage:2,start:[7,8],style:8,tabl:7,train:8,transform:2,tri:2,troubleshoot:8,vocab:6,welcom:7}})
--------------------------------------------------------------------------------
/docs/source/Criterion.rst:
--------------------------------------------------------------------------------
1 | Criterion
2 | =====================================================
3 |
4 | Joint CTC-CrossEntropy Loss
5 | --------------------------------------------
6 | .. automodule:: lightning_asr.criterion.joint_ctc_cross_entropy
7 | :members:
8 |
9 | Label Smoothed CrossEntropy Loss
10 | --------------------------------------------
11 | .. automodule:: lightning_asr.criterion.label_smoothed_cross_entropy
12 | :members:
--------------------------------------------------------------------------------
/docs/source/Data.rst:
--------------------------------------------------------------------------------
1 | Data
2 | =====================================================
3 |
4 | Data Loader
5 | --------------------------------------------
6 | .. automodule:: lightning_asr.data.data_loader
7 | :members:
8 |
9 | Dataset
10 | --------------------------------------------
11 | .. automodule:: lightning_asr.data.dataset
12 | :members:
13 |
14 | Librispeech Preprocess
15 | --------------------------------------------
16 | .. automodule:: lightning_asr.data.librispeech.preprocess
17 | :members:
18 |
19 | Lightning Data Module
20 | --------------------------------------------
21 | .. automodule:: lightning_asr.data.librispeech.lit_data_module
22 | :members:
--------------------------------------------------------------------------------
/docs/source/Lr_scheduler.rst:
--------------------------------------------------------------------------------
1 | LR Scheduler
2 | =====================================================
3 |
4 | LR Scheduler
5 | --------------------------------------------
6 | .. automodule:: lightning_asr.optim.lr_scheduler.lr_scheduler
7 | :members:
8 |
9 | Transformer LR Scheduler
10 | --------------------------------------------
11 | .. automodule:: lightning_asr.optim.lr_scheduler.transformer_lr_scheduler
12 | :members:
13 |
14 | Tri-Stage LR Scheduler
15 | --------------------------------------------
16 | .. automodule:: lightning_asr.optim.lr_scheduler.tri_stage_lr_scheduler
17 | :members:
--------------------------------------------------------------------------------
/docs/source/Metric.rst:
--------------------------------------------------------------------------------
1 | Metric
2 | =====================================================
3 |
4 | Metric
5 | --------------------------------------------
6 | .. automodule:: lightning_asr.metric
7 | :members:
--------------------------------------------------------------------------------
/docs/source/Model.rst:
--------------------------------------------------------------------------------
1 | Model
2 | =====================================================
3 |
4 | Activation
5 | --------------------------------------------
6 | .. automodule:: lightning_asr.model.activation
7 | :members:
8 |
9 | Attention
10 | --------------------------------------------
11 | .. automodule:: lightning_asr.model.attention
12 | :members:
13 |
14 | Convolution
15 | --------------------------------------------
16 | .. automodule:: lightning_asr.model.convolution
17 | :members:
18 |
19 | Decoder
20 | --------------------------------------------
21 | .. automodule:: lightning_asr.model.decoder
22 | :members:
23 |
24 | Embedding
25 | --------------------------------------------
26 | .. automodule:: lightning_asr.model.embedding
27 | :members:
28 |
29 | Encoder
30 | --------------------------------------------
31 | .. automodule:: lightning_asr.model.encoder
32 | :members:
33 |
34 | Feed Forward
35 | --------------------------------------------
36 | .. automodule:: lightning_asr.model.feed_forward
37 | :members:
38 |
39 | Modules
40 | --------------------------------------------
41 | .. automodule:: lightning_asr.model.modules
42 | :members:
43 |
44 | Model
45 | --------------------------------------------
46 | .. automodule:: lightning_asr.model.model
47 | :members:
--------------------------------------------------------------------------------
/docs/source/Optim.rst:
--------------------------------------------------------------------------------
1 | Optim
2 | =====================================================
3 |
4 | Optimizer
5 | --------------------------------------------
6 | .. automodule:: lightning_asr.optim.optimizer
7 | :members:
8 |
9 | AdamP
10 | --------------------------------------------
11 | .. automodule:: lightning_asr.optim.adamp
12 | :members:
13 |
14 | RAdam
15 | --------------------------------------------
16 | .. automodule:: lightning_asr.optim.radam
17 | :members:
--------------------------------------------------------------------------------
/docs/source/Vocab.rst:
--------------------------------------------------------------------------------
1 | Vocabs
2 | =====================================================
3 |
4 | Vocab
5 | --------------------------------------------
6 | .. automodule:: lightning_asr.vocabs.vocab
7 | :members:
8 |
9 | LibriSpeech Vocab
10 | --------------------------------------------
11 | .. automodule:: lightning_asr.vocabs.librispeech
12 | :members:
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#

import os
import sys

# Paths relative to the directory sphinx-build is started from.
sys.path.append(os.path.abspath('.'))
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../..'))

# Paths relative to this file (not the current working directory), so that
# autodoc can import the package regardless of where the build is launched.
conf_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(conf_dir, '../../')))
sys.path.append(os.path.abspath(os.path.join(conf_dir, '../../../')))

import sphinx_rtd_theme  # noqa: F401  -- the theme is enabled via `extensions` / `html_theme` below


# -- Project information -----------------------------------------------------

project = 'lightning_asr'
copyright = '2021, Soohwan Kim'
author = 'Soohwan Kim'

# The full version, including alpha/beta/rc tags
release = 'latest'


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones. Each extension is listed once.
extensions = [
    'sphinx.ext.autosummary',
    'sphinx.ext.doctest',
    'sphinx.ext.intersphinx',
    'sphinx.ext.todo',
    'sphinx.ext.coverage',
    # NOTE(review): both mathjax and imgmath are enabled; set `html_math_renderer`
    # if an explicit choice of HTML math renderer is wanted.
    'sphinx.ext.mathjax',
    'sphinx.ext.ifconfig',
    'sphinx.ext.napoleon',
    'sphinx_rtd_theme',
    'sphinx.ext.autodoc',
    'sphinx.ext.imgmath',
    'sphinx.ext.viewcode',
    'sphinx.ext.githubpages',
    # Markdown support (used for notes/intro.md).
    'recommonmark',
]

napoleon_use_ivar = True

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# imgmath renders formulas through xelatex; '--no-pdf' makes xelatex emit
# an intermediate DVI-like file that is then converted to SVG.
imgmath_image_format = 'svg'
imgmath_latex = 'xelatex'
imgmath_latex_args = ['--no-pdf']

# The suffix(es) of source filenames. Markdown files are handled by the
# `recommonmark` extension above, so no custom `source_parsers` are needed.
source_suffix = ['.rst', '.md']

# The master toctree document.
master_doc = 'index'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
# ('en' rather than None: None is deprecated since Sphinx 5.)
language = 'en'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}


# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'lightning-asr.doc'


# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',

    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'lightning-asr.tex', 'Lightning ASR Documentation',
     'Soohwan Kim', 'manual'),
]
163 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. lasr documentation master file, created by
2 | sphinx-quickstart on Tue Apr 20 22:40:08 2021.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to Lightning ASR's documentation!
7 | ============================================
8 |
9 | .. toctree::
10 | :maxdepth: 1
11 | :caption: GETTING STARTED
12 |
13 | notes/intro
14 |
15 | .. toctree::
16 | :maxdepth: 1
17 | :caption: LIBRARY REFERENCE
18 |
19 | Criterion
20 | Data
21 | Lr_scheduler
22 | Metric
23 | Model
24 | Optim
25 | Vocab
26 |
27 |
28 | Indices and tables
29 | ==================
30 |
31 | * :ref:`genindex`
32 | * :ref:`modindex`
33 | * :ref:`search`
34 |
--------------------------------------------------------------------------------
/docs/source/notes/intro.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 |
3 | [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) is a lightweight [PyTorch](https://github.com/pytorch/pytorch) wrapper for high-performance AI research. PyTorch makes it easy to build complex AI models, but once research gets complicated and things like multi-GPU training, 16-bit precision and TPU training get mixed in, users are likely to introduce bugs. PyTorch Lightning solves exactly this problem: it structures your PyTorch code so that the details of training are abstracted away, which makes AI research scalable and fast to iterate on. This project is an example of an automatic speech recognition (ASR) project implemented with PyTorch Lightning. In this project, I trained a model consisting of a Conformer encoder and an LSTM decoder with joint CTC-attention. The name **lasr** stands for **l**ightning **a**utomatic **s**peech **r**ecognition. I hope this project can serve as a guideline for speech recognition researchers.
4 |
5 | ## Installation
6 |
7 | This project recommends Python 3.7 or higher.
8 | I recommend creating a new virtual environment for this project (using virtualenv or conda).
9 |
10 |
11 | ### Prerequisites
12 |
13 | * NumPy: `pip install numpy` (Refer [here](https://github.com/numpy/numpy) for problems installing NumPy).
14 | * PyTorch: Refer to the [PyTorch website](http://pytorch.org/) to install the version matching your environment.
15 | * librosa: `conda install -c conda-forge librosa` (Refer [here](https://github.com/librosa/librosa) for problems installing librosa)
16 | * torchaudio: `pip install torchaudio==0.6.0` (Refer [here](https://github.com/pytorch/audio) for problems installing torchaudio)
17 | * sentencepiece: `pip install sentencepiece` (Refer [here](https://github.com/google/sentencepiece) for problems installing sentencepiece)
18 | * pytorch-lightning: `pip install pytorch-lightning` (Refer [here](https://github.com/PyTorchLightning/pytorch-lightning) for problems installing pytorch-lightning)
19 | * hydra: `pip install hydra-core --upgrade` (Refer [here](https://github.com/facebookresearch/hydra) for problems installing hydra)
20 |
21 | ### Install from source
22 | Currently we only support installation from source code using setuptools. Check out the source code and run the
23 | following commands:
24 | ```
25 | pip install -e .
26 | ```
27 |
28 | ### Install Apex (for 16-bit training)
29 |
30 | For faster training install NVIDIA's apex library:
31 |
32 | ```
33 | $ git clone https://github.com/NVIDIA/apex
34 | $ cd apex
35 |
36 | # ------------------------
37 | # OPTIONAL: on your cluster you might need to load CUDA 10 or 9
38 | # depending on how you installed PyTorch
39 |
40 | # see available modules
41 | module avail
42 |
43 | # load correct CUDA before install
44 | module load cuda-10.0
45 | # ------------------------
46 |
47 | # make sure you've loaded a gcc version > 4.0 and < 7.0
48 | module load gcc-6.1.0
49 |
50 | $ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
51 | ```
52 |
53 | ## Get Started
54 |
55 | I use [Hydra](https://github.com/facebookresearch/hydra) to control all the training configurations. If you are not familiar with Hydra, I recommend visiting the [Hydra website](https://hydra.cc/). In short, Hydra is an open-source framework that simplifies the development of research applications by providing the ability to create a hierarchical configuration dynamically.
56 |
57 | ### Training Speech Recognizer
58 |
59 | You can train on the LibriSpeech dataset as follows:
60 | ```
61 | $ python ./bin/main.py --dataset_path $DATASET_PATH --dataset_download True
62 | ```
63 |
64 | ## Troubleshooting and Contributing
65 | If you have any questions, bug reports, and feature requests, please [open an issue](https://github.com/sooftware/lasr/issues) on Github.
66 |
67 | I appreciate any kind of feedback or contribution. Feel free to proceed directly with small contributions such as bug fixes or documentation improvements. For major contributions and new features, please discuss them with the collaborators in the corresponding issues.
68 |
69 | ### Code Style
70 | I follow [PEP-8](https://www.python.org/dev/peps/pep-0008/) for code style. Especially the style of docstrings is important to generate documentation.
71 |
72 | ### License
73 | This project is licensed under the MIT License - see the [LICENSE](https://github.com/sooftware/lasr/blob/master/LICENSE) file for details.
74 |
75 | ## Author
76 |
77 | * Soohwan Kim [@sooftware](https://github.com/sooftware)
78 | * Contacts: sh951011@gmail.com
79 |
--------------------------------------------------------------------------------
/lightning_asr/criterion/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from lightning_asr.criterion.joint_ctc_cross_entropy import JointCTCCrossEntropyLoss
24 | from lightning_asr.criterion.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyLoss
25 |
--------------------------------------------------------------------------------
/lightning_asr/criterion/joint_ctc_cross_entropy.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch.nn as nn
24 | from typing import Tuple
25 | from torch import Tensor
26 |
27 |
class JointCTCCrossEntropyLoss(nn.Module):
    """
    Provides Joint CTC-CrossEntropy Loss function.

    The combined loss is ``cross_entropy_weight * cross_entropy_loss + ctc_weight * ctc_loss``.

    Args:
        num_classes (int): the number of classes (stored on the module; not used by the computation)
        ignore_index (int): target index that is ignored by the cross entropy term
        dim (int): dimension of calculation loss (stored on the module; not used by the computation)
        reduction (str): reduction method passed to both losses [sum, mean] (default: mean)
        ctc_weight (float): weight of ctc loss (default: 0.3)
        cross_entropy_weight (float): weight of cross entropy loss (default: 0.7)
        blank_id (int, optional): index of the CTC blank label. ``None`` falls back to 0,
            which is the default blank index of ``torch.nn.CTCLoss``.

    Inputs: encoder_log_probs, decoder_log_probs, output_lengths, targets, target_lengths
        - **encoder_log_probs** (torch.Tensor): log-probabilities fed to ``nn.CTCLoss``
          (which expects shape ``(T, N, C)``)
        - **decoder_log_probs** (torch.Tensor): decoder scores fed to ``nn.CrossEntropyLoss``;
          must have one row per element of ``targets.view(-1)``
        - **output_lengths** (torch.Tensor): valid length of each encoder output sequence
        - **targets** (torch.Tensor): ground-truth label indices, shared by both losses
        - **target_lengths** (torch.Tensor): valid length of each target sequence

    Returns: loss, ctc_loss, cross_entropy_loss
        - **loss** (torch.Tensor): weighted sum of both losses
        - **ctc_loss** (torch.Tensor): unweighted CTC loss
        - **cross_entropy_loss** (torch.Tensor): unweighted cross entropy loss
    """
    def __init__(
            self,
            num_classes: int,
            ignore_index: int,
            dim: int = -1,
            reduction='mean',
            ctc_weight: float = 0.3,
            cross_entropy_weight: float = 0.7,
            blank_id: int = None,
    ) -> None:
        super(JointCTCCrossEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.dim = dim
        self.ignore_index = ignore_index
        self.reduction = reduction.lower()
        self.ctc_weight = ctc_weight
        self.cross_entropy_weight = cross_entropy_weight
        # nn.CTCLoss(blank=None) builds fine but fails on the first forward call,
        # so resolve the default to CTCLoss's own default blank index.
        self.blank_id = 0 if blank_id is None else blank_id
        self.ctc_loss = nn.CTCLoss(blank=self.blank_id, reduction=self.reduction, zero_infinity=True)
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction=self.reduction, ignore_index=self.ignore_index)

    def forward(
            self,
            encoder_log_probs: Tensor,
            decoder_log_probs: Tensor,
            output_lengths: Tensor,
            targets: Tensor,
            target_lengths: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        ctc_loss = self.ctc_loss(encoder_log_probs, targets, output_lengths, target_lengths)
        # Targets are flattened so each decoder row is scored against exactly one label.
        cross_entropy_loss = self.cross_entropy_loss(decoder_log_probs, targets.contiguous().view(-1))
        loss = cross_entropy_loss * self.cross_entropy_weight + ctc_loss * self.ctc_weight
        return loss, ctc_loss, cross_entropy_loss
73 |
--------------------------------------------------------------------------------
/lightning_asr/criterion/label_smoothed_cross_entropy.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch
24 | import torch.nn as nn
25 | import torch.nn.functional as F
26 | from torch import Tensor
27 |
28 |
class LabelSmoothedCrossEntropyLoss(nn.Module):
    """
    Label smoothed cross entropy loss function.

    The target class receives weight ``1.0 - smoothing`` and every other class
    receives ``smoothing / (num_classes - 1)``. Positions whose target equals
    ``ignore_index`` contribute nothing to the loss.

    Args:
        num_classes (int): the number of classes
        ignore_index (int): target index that is ignored when calculating loss
            (it may lie outside ``[0, num_classes)``, e.g. -100)
        smoothing (float): ratio of smoothing (confidence = 1.0 - smoothing).
            With ``smoothing <= 0`` the plain ``F.cross_entropy`` is used.
        dim (int): class dimension of ``logits`` used by the smoothing branch (default: -1).
            The non-smoothing branch uses ``F.cross_entropy``, whose class dimension is 1;
            the two agree for 2-D ``(N, C)`` inputs.
        reduction (str): reduction method [sum, mean] (default: mean).
            ``mean`` averages the per-token loss over non-ignored tokens, the same
            normalization ``F.cross_entropy(..., reduction='mean')`` uses.

    Inputs: logits, targets
        - **logits** (torch.Tensor): log-probabilities from the model. The smoothing branch
          multiplies them directly (no log_softmax), so they must already be in log space.
        - **targets** (torch.Tensor): ground-truth class indices, shaped like ``logits``
          without the class dimension

    Returns: loss
        - **loss** (torch.Tensor): scalar loss reduced according to ``reduction``

    Raises:
        ValueError: if ``reduction`` is neither ``sum`` nor ``mean``.
    """
    def __init__(
            self,
            num_classes: int,
            ignore_index: int,
            smoothing: float = 0.1,
            dim: int = -1,
            reduction='mean',
    ) -> None:
        super(LabelSmoothedCrossEntropyLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.num_classes = num_classes
        self.dim = dim
        self.ignore_index = ignore_index
        self.reduction = reduction.lower()

        if self.reduction == 'sum':
            self.reduction_method = torch.sum
        elif self.reduction == 'mean':
            self.reduction_method = torch.mean
        else:
            raise ValueError("Unsupported reduction method {0}".format(reduction))

    def forward(self, logits: Tensor, targets: Tensor) -> Tensor:
        if self.smoothing > 0.0:
            ignore_mask = targets == self.ignore_index

            with torch.no_grad():
                # Ignored positions may hold an out-of-range index (e.g. -100); point them at
                # class 0 so scatter_ stays in bounds. Their rows are zeroed right after.
                safe_targets = targets.masked_fill(ignore_mask, 0)
                label_smoothed = torch.full_like(logits, self.smoothing / (self.num_classes - 1))
                label_smoothed.scatter_(self.dim, safe_targets.unsqueeze(self.dim), self.confidence)
                label_smoothed.masked_fill_(ignore_mask.unsqueeze(self.dim), 0.0)

            # Per-token loss: sum over the class dimension.
            token_losses = torch.sum(-label_smoothed * logits, dim=self.dim)
            total = token_losses.sum()

            if self.reduction == 'sum':
                return total

            # Average over non-ignored tokens, consistent with the F.cross_entropy branch below.
            # clamp avoids 0/0 when every token is ignored (the loss is then 0).
            num_tokens = (~ignore_mask).sum().clamp(min=1)
            return total / num_tokens

        return F.cross_entropy(logits, targets, ignore_index=self.ignore_index, reduction=self.reduction)
--------------------------------------------------------------------------------
/lightning_asr/data/data_loader.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch
24 | import numpy as np
25 | from torch.utils.data import DataLoader
26 | from torch.utils.data.sampler import Sampler
27 |
28 |
29 | def _collate_fn(batch, pad_id: int = 0):
30 | """ functions that pad to the maximum sequence length """
31 | def seq_length_(p):
32 | return len(p[0])
33 |
34 | def target_length_(p):
35 | return len(p[1])
36 |
37 | # sort by sequence length for rnn.pack_padded_sequence()
38 | batch = sorted(batch, key=lambda sample: sample[0].size(0), reverse=True)
39 |
40 | seq_lengths = [len(s[0]) for s in batch]
41 | target_lengths = [len(s[1]) - 1 for s in batch]
42 |
43 | max_seq_sample = max(batch, key=seq_length_)[0]
44 | max_target_sample = max(batch, key=target_length_)[1]
45 |
46 | max_seq_size = max_seq_sample.size(0)
47 | max_target_size = len(max_target_sample)
48 |
49 | feat_size = max_seq_sample.size(1)
50 | batch_size = len(batch)
51 |
52 | seqs = torch.zeros(batch_size, max_seq_size, feat_size)
53 |
54 | targets = torch.zeros(batch_size, max_target_size).to(torch.long)
55 | targets.fill_(pad_id)
56 |
57 | for x in range(batch_size):
58 | sample = batch[x]
59 | tensor = sample[0]
60 | target = sample[1]
61 | seq_length = tensor.size(0)
62 |
63 | seqs[x].narrow(0, 0, seq_length).copy_(tensor)
64 | targets[x].narrow(0, 0, len(target)).copy_(torch.LongTensor(target))
65 |
66 | seq_lengths = torch.IntTensor(seq_lengths)
67 | target_lengths = torch.IntTensor(target_lengths)
68 |
69 | return seqs, targets, seq_lengths, target_lengths
70 |
71 |
class AudioDataLoader(DataLoader):
    """
    Audio Data Loader: a ``DataLoader`` whose default ``collate_fn`` is ``_collate_fn``.

    Args:
        dataset (torch.utils.data.Dataset): dataset to load samples from
        num_workers (int): number of loader worker processes
        batch_sampler (torch.utils.data.sampler.Sampler): sampler yielding lists of indices
        **kwargs: forwarded to ``torch.utils.data.DataLoader``. A non-None ``collate_fn``
            given here is honored; otherwise ``_collate_fn`` is used.
    """
    def __init__(
            self,
            dataset: torch.utils.data.Dataset,
            num_workers: int,
            batch_sampler: torch.utils.data.sampler.Sampler,
            **kwargs,
    ) -> None:
        # Hand collate_fn to DataLoader's constructor instead of overwriting the attribute
        # afterwards, which silently discarded any collate_fn the caller passed in kwargs.
        if kwargs.get('collate_fn') is None:
            kwargs['collate_fn'] = _collate_fn
        super(AudioDataLoader, self).__init__(
            dataset=dataset,
            num_workers=num_workers,
            batch_sampler=batch_sampler,
            **kwargs,
        )
88 |
89 |
class BucketingSampler(Sampler):
    """
    Samples batches assuming they are in order of size to batch similarly sized samples together.

    The indices ``0 .. len(data_source) - 1`` are cut into consecutive bins of ``batch_size``
    indices, and each bin is yielded as one batch. Every time a bin is yielded, its indices
    are shuffled in place with ``np.random.shuffle``.

    Args:
        data_source: dataset (anything supporting ``len``) to sample from
        batch_size (int): number of indices per bin (default: 32)
        drop_last (bool): if True, discard the final bin when it holds fewer than
            ``batch_size`` indices (default: False)
    """
    def __init__(self, data_source, batch_size: int = 32, drop_last: bool = False) -> None:
        super(BucketingSampler, self).__init__(data_source)
        self.batch_size = batch_size
        self.data_source = data_source
        ids = list(range(0, len(data_source)))
        self.bins = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)]
        self.drop_last = drop_last
        # drop_last used to be stored but never applied; honor it by removing a short tail bin
        # up front so that __iter__, __len__ and shuffle() all agree.
        if drop_last and self.bins and len(self.bins[-1]) < batch_size:
            self.bins.pop()

    def __iter__(self):
        for ids in self.bins:
            np.random.shuffle(ids)
            yield ids

    def __len__(self):
        return len(self.bins)

    def shuffle(self, epoch):
        """Shuffle the order of the bins in place. ``epoch`` is accepted but unused."""
        np.random.shuffle(self.bins)
110 |
--------------------------------------------------------------------------------
/lightning_asr/data/librispeech/preprocess.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import os
24 | import sentencepiece as spm
25 |
# LibriSpeech split folder names read by collect_transcripts(); its result keeps this order.
LIBRI_SPEECH_DATASETS = [
    'train-960',
    'dev-clean', 'dev-other',
    'test-clean', 'test-other',
]
33 |
34 |
def collect_transcripts(dataset_path, librispeech_dir: str = 'LibriSpeech'):
    """
    Collect ``"<audio path>|<transcript>"`` entries for every LibriSpeech split.

    For each split in ``LIBRI_SPEECH_DATASETS`` this walks
    ``<dataset_path>/<split>/<dir1>/<dir2>/`` and reads every file whose name ends in
    ``txt``. Each line is split on whitespace: the first token is the audio file stem, and
    the remaining tokens, re-joined with single spaces, are the transcript.

    Args:
        dataset_path: root directory that contains the split folders
        librispeech_dir (str): prefix used when building the audio paths in the entries

    Returns:
        list of lists: one list of ``"path.flac|transcript"`` strings per split, in the
        order of ``LIBRI_SPEECH_DATASETS``.
    """
    collection = []

    for split in LIBRI_SPEECH_DATASETS:
        entries = []
        split_dir = os.path.join(dataset_path, split)

        for outer in os.listdir(split_dir):
            outer_dir = os.path.join(split_dir, outer)
            for inner in os.listdir(outer_dir):
                inner_dir = os.path.join(outer_dir, inner)
                for name in os.listdir(inner_dir):
                    if not name.endswith('txt'):
                        continue
                    with open(os.path.join(inner_dir, name)) as handle:
                        for raw_line in handle:
                            tokens = raw_line.split()
                            stem_path = os.path.join(librispeech_dir, split, outer, inner, tokens[0])
                            text = " ".join(tokens[1:])
                            entries.append(f"{stem_path}.flac|{text}")

        collection.append(entries)

    return collection
60 |
61 |
def prepare_tokenizer(train_transcripts, vocab_size):
    """
    Train a SentencePiece unigram tokenizer from transcript entries.

    The text after the last ``|`` of each entry is written, one per line, to
    ``spm_input.txt`` in the current working directory. SentencePiece is then trained on
    that file with model prefix ``tokenizer`` and an empty ``--user_defined_symbols``.

    Args:
        train_transcripts: iterable of ``"path|transcript"`` strings
        vocab_size: vocabulary size passed to the SentencePiece trainer
    """
    input_file = 'spm_input.txt'
    model_name = 'tokenizer'
    model_type = 'unigram'

    with open(input_file, 'w') as f:
        f.writelines('{}\n'.format(entry.split('|')[-1]) for entry in train_transcripts)

    options = [
        f"--input={input_file}",
        f"--model_prefix={model_name}",
        f"--vocab_size={vocab_size}",
        f"--model_type={model_type}",
        "--user_defined_symbols=",
    ]
    spm.SentencePieceTrainer.Train(" ".join(options))
75 |
76 |
def generate_manifest_file(dataset_path: str, part: str, transcripts: list):
    """
    Write a tab-separated manifest for one dataset split.

    Loads ``tokenizer.model`` from the current working directory and writes
    ``<dataset_path>/<part>.txt`` with one line per entry:
    ``audio_path<TAB>space-joined pieces<TAB>space-joined ids``.

    Args:
        dataset_path (str): directory the manifest is written into
        part (str): manifest file stem (e.g. a split name)
        transcripts (list): ``"audio_path|transcript"`` strings, each containing exactly one ``|``
    """
    sp = spm.SentencePieceProcessor()
    sp.Load("tokenizer.model")

    manifest_path = f"{dataset_path}/{part}.txt"
    with open(manifest_path, 'w') as manifest:
        for entry in transcripts:
            audio_path, sentence = entry.split('|')
            pieces = " ".join(sp.EncodeAsPieces(sentence))
            ids = " ".join(map(str, sp.EncodeAsIds(sentence)))
            manifest.write(f"{audio_path}\t{pieces}\t{ids}\n")
89 |
--------------------------------------------------------------------------------
/lightning_asr/hydra_configs/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from lightning_asr.hydra_configs.data import DataConfigs
24 | from lightning_asr.hydra_configs.model import ConformerLSTMConfigs
25 | from lightning_asr.hydra_configs.trainer import (
26 | TrainerGPUConfigs,
27 | TrainerTPUConfigs,
28 | )
29 | from lightning_asr.hydra_configs.lr_scheduler import (
30 | ReduceLROnPlateauLRSchedulerConfigs,
31 | TransformerLRSchedulerConfigs,
32 | TriStageLRSchedulerConfigs,
33 | )
34 | from lightning_asr.hydra_configs.audio import (
35 | SpectrogramConfigs,
36 | MelSpectrogramConfigs,
37 | MFCCConfigs,
38 | FBankConfigs,
39 | )
40 |
--------------------------------------------------------------------------------
/lightning_asr/hydra_configs/audio.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from dataclasses import dataclass
24 |
25 |
@dataclass
class AudioConfigs:
    """Base audio settings: sampling rate and frame geometry."""

    sample_rate: int = 16000    # presumably Hz — confirm with the feature extractor
    frame_length: float = 25.0  # presumably milliseconds — TODO confirm
    frame_shift: float = 10.0   # presumably milliseconds — TODO confirm
31 |
32 |
@dataclass
class SpecAugmentConfigs:
    """SpecAugment switch and masking settings."""

    apply_spec_augment: bool = True
    # Masking parameters; their exact meaning (e.g. maximum mask width) is defined by
    # the augmentation code, not here — confirm there.
    freq_mask_para: int = 27
    freq_mask_num: int = 2
    time_mask_num: int = 4
39 |
40 |
@dataclass
class SpectrogramConfigs(AudioConfigs, SpecAugmentConfigs):
    """Linear spectrogram feature settings (inherits audio and SpecAugment fields)."""

    num_mels: int = 161  # feature dimension used for this method — confirm with the extractor
    feature_extract_method: str = "spectrogram"
45 |
46 |
@dataclass
class MelSpectrogramConfigs(AudioConfigs, SpecAugmentConfigs):
    """Mel-spectrogram feature settings (inherits audio and SpecAugment fields)."""

    num_mels: int = 80  # presumably the number of mel bins — confirm with the extractor
    feature_extract_method: str = "melspectrogram"
51 |
52 |
@dataclass
class FBankConfigs(AudioConfigs, SpecAugmentConfigs):
    """Filter-bank feature settings (inherits audio and SpecAugment fields)."""

    num_mels: int = 80  # presumably the number of filter-bank channels — confirm
    feature_extract_method: str = "fbank"
57 |
58 |
@dataclass
class MFCCConfigs(AudioConfigs, SpecAugmentConfigs):
    """MFCC feature settings (inherits audio and SpecAugment fields)."""

    num_mels: int = 40  # feature dimension for MFCCs — confirm with the extractor
    feature_extract_method: str = "mfcc"
63 |
--------------------------------------------------------------------------------
/lightning_asr/hydra_configs/data.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from dataclasses import dataclass
24 |
25 |
@dataclass
class DataConfigs:
    """Dataset location, download switch and tokenizer vocabulary size."""

    dataset_path: str = "../../../librispeech"  # relative path; resolved against the working directory
    dataset_download: bool = True
    vocab_size: int = 5000  # presumably the SentencePiece vocab size — confirm
--------------------------------------------------------------------------------
/lightning_asr/hydra_configs/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from dataclasses import dataclass
24 |
25 |
@dataclass
class LRSchedulerConfigs:
    """Base learning-rate scheduler settings shared by the scheduler configs."""

    lr: float = 1e-04
29 |
30 |
@dataclass
class ReduceLROnPlateauLRSchedulerConfigs(LRSchedulerConfigs):
    """Settings for the ``"reduce_lr_on_plateau"`` scheduler (after the inherited ``lr``)."""

    lr_patience: int = 1
    scheduler: str = "reduce_lr_on_plateau"  # scheduler identifier string
    lr_factor: float = 0.3
36 |
37 |
@dataclass
class TriStageLRSchedulerConfigs(LRSchedulerConfigs):
    """Settings for the ``"tri_stage"`` scheduler (after the inherited ``lr``)."""

    init_lr: float = 1e-10
    peak_lr: float = 1e-04
    final_lr: float = 1e-07
    init_lr_scale: float = 0.01
    final_lr_scale: float = 0.05
    # Step counts; how they map onto the schedule's phases is defined by the
    # scheduler implementation, not here.
    warmup_steps: int = 10000
    decay_steps: int = 150000
    scheduler: str = "tri_stage"  # scheduler identifier string
48 |
49 |
@dataclass
class TransformerLRSchedulerConfigs(LRSchedulerConfigs):
    """Settings for the ``"transformer"`` scheduler (after the inherited ``lr``)."""

    peak_lr: float = 1e-04
    final_lr: float = 1e-07
    final_lr_scale: float = 0.05
    # Step counts consumed by the scheduler implementation — semantics defined there.
    warmup_steps: int = 10000
    decay_steps: int = 150000
    scheduler: str = "transformer"  # scheduler identifier string
58 |
--------------------------------------------------------------------------------
/lightning_asr/hydra_configs/model.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from dataclasses import dataclass
24 |
25 |
@dataclass
class ConformerLSTMConfigs:
    """Hyper-parameters for the Conformer encoder / LSTM decoder model."""

    # Network size
    encoder_dim: int = 256
    num_encoder_layers: int = 17
    num_decoder_layers: int = 2
    num_attention_heads: int = 8
    feed_forward_expansion_factor: int = 4
    conv_expansion_factor: int = 2
    # Dropout probabilities (per field names)
    input_dropout_p: float = 0.1
    feed_forward_dropout_p: float = 0.1
    attention_dropout_p: float = 0.1
    conv_dropout_p: float = 0.1
    decoder_dropout_p: float = 0.1
    # Architecture details
    conv_kernel_size: int = 31
    half_step_residual: bool = True
    # Decoding / training behaviour
    max_length: int = 128
    teacher_forcing_ratio: float = 1.0
    # Presumably the weights of the joint cross-entropy / CTC loss — confirm with the model code
    cross_entropy_weight: float = 0.7
    ctc_weight: float = 0.3
    joint_ctc_attention: bool = True
    rnn_type: str = "lstm"
    optimizer: str = "adam"
48 |
--------------------------------------------------------------------------------
/lightning_asr/hydra_configs/trainer.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from dataclasses import dataclass
24 |
25 |
@dataclass
class BaseTrainerConfigs:
    """Training-loop settings shared by the GPU and TPU trainer configs."""

    seed: int = 1
    # The following look like PyTorch Lightning Trainer options — confirm how they are consumed
    accelerator: str = "dp"
    precision: int = 16
    accumulate_grad_batches: int = 4
    amp_backend: str = "apex"
    num_workers: int = 4
    batch_size: int = 32
    check_val_every_n_epoch: int = 1
    gradient_clip_val: float = 5.0
    use_tensorboard: bool = True
    max_epochs: int = 20
    auto_scale_batch_size: str = "binsearch"
40 |
41 |
@dataclass
class TrainerGPUConfigs(BaseTrainerConfigs):
    """Trainer settings for GPU runs (CUDA on, TPU off)."""

    use_cuda: bool = True
    use_tpu: bool = False
    auto_select_gpus: bool = True
47 |
48 |
@dataclass
class TrainerTPUConfigs(BaseTrainerConfigs):
    """Trainer settings for TPU runs (CUDA off, TPU on)."""

    use_cuda: bool = False
    use_tpu: bool = True
    tpu_cores: int = 8
54 |
--------------------------------------------------------------------------------
/lightning_asr/metric.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import Levenshtein as Lev
24 |
25 |
class ErrorRate(object):
    """
    Base class for accumulated edit-distance error rates.

    Each call adds one batch's total distance and total reference length to running
    totals and returns the cumulative rate ``total_dist / total_length``. Subclasses
    provide ``metric(s1, s2) -> (distance, length)``.

    Note:
        Do not use this class directly, use one of the sub classes.
    """

    def __init__(self, vocab) -> None:
        self.total_dist = 0.0
        self.total_length = 0.0
        self.vocab = vocab

    def __call__(self, targets, y_hats):
        """Add this batch to the running totals and return the cumulative error rate."""
        batch_dist, batch_length = self._get_distance(targets, y_hats)
        self.total_dist += batch_dist
        self.total_length += batch_length
        return self.total_dist / self.total_length

    def _get_distance(self, targets, y_hats):
        """
        Sum ``metric`` over aligned (target, prediction) pairs.

        Each label sequence is turned into text with ``vocab.label_to_string`` first.

        Args:
            targets: iterable of ground-truth label sequences
            y_hats: iterable of predicted label sequences (zipped with ``targets``)

        Returns: total_dist, total_length
            - **total_dist**: summed distance over all pairs
            - **total_length**: summed reference length over all pairs
        """
        total_dist, total_length = 0, 0
        for target, y_hat in zip(targets, y_hats):
            reference = self.vocab.label_to_string(target)
            hypothesis = self.vocab.label_to_string(y_hat)
            pair_dist, pair_length = self.metric(reference, hypothesis)
            total_dist, total_length = total_dist + pair_dist, total_length + pair_length
        return total_dist, total_length

    def metric(self, *args, **kwargs):
        raise NotImplementedError
74 |
75 |
class WordErrorRate(ErrorRate):
    """Word error rate: edit distance over whitespace-separated words."""

    def __init__(self, vocab) -> None:
        super().__init__(vocab)

    def metric(self, s1, s2):
        """
        Computes the Unit Error Rate, defined as the edit distance between the
        two provided sentences after tokenizing to words.

        Arguments:
            s1 (string): space-separated reference sentence
            s2 (string): space-separated hypothesis sentence

        Returns: dist, length
            - **dist**: word-level edit distance between s1 and s2
            - **length**: number of words in s1
        """
        ref_words = s1.split()
        hyp_words = s2.split()

        # The Levenshtein package only compares strings, so encode every distinct word
        # as a single character and compare the encoded strings.
        word_to_code = {word: chr(code) for code, word in enumerate(set(ref_words + hyp_words))}

        ref_encoded = ''.join(word_to_code[w] for w in ref_words)
        hyp_encoded = ''.join(word_to_code[w] for w in hyp_words)

        return Lev.distance(ref_encoded, hyp_encoded), len(ref_words)
102 |
103 |
class CharacterErrorRate(ErrorRate):
    """
    Computes the Character Error Rate, defined as the edit distance between the
    two provided sentences after tokenizing to characters.
    """
    def __init__(self, vocab):
        super().__init__(vocab)

    def metric(self, s1: str, s2: str):
        """
        Character-level edit distance, ignoring spaces and '_' characters.

        Args:
            s1 (string): space-separated reference sentence
            s2 (string): space-separated hypothesis sentence

        Returns: dist, length
            - **dist**: edit distance between the stripped strings
            - **length**: character count of the stripped reference
        """
        # '_' is treated as a subword-unit marker and, like spaces, is not counted.
        reference = s1.replace(' ', '').replace('_', '')
        hypothesis = s2.replace(' ', '').replace('_', '')

        return Lev.distance(hypothesis, reference), len(reference)
135 |
--------------------------------------------------------------------------------
/lightning_asr/model/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from lightning_asr.model.model import ConformerLSTMModel
24 |
--------------------------------------------------------------------------------
/lightning_asr/model/activation.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch.nn as nn
24 | from torch import Tensor
25 |
26 |
class Swish(nn.Module):
    """
    Swish activation, ``x * sigmoid(x)``: a smooth, non-monotonic function that consistently
    matches or outperforms ReLU on deep networks applied to a variety of challenging domains
    such as Image classification and Machine translation.
    """
    def __init__(self):
        super().__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        gate = inputs.sigmoid()
        return inputs * gate
37 |
38 |
class GLU(nn.Module):
    """
    Gated Linear Unit, first introduced for natural language processing in the paper
    “Language Modeling with Gated Convolutional Networks”.

    Splits the input into two chunks along ``dim`` and returns
    ``first_chunk * sigmoid(second_chunk)``.
    """
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, inputs: Tensor) -> Tensor:
        values, gate = inputs.chunk(2, dim=self.dim)
        gated = gate.sigmoid()
        return values * gated
51 |
--------------------------------------------------------------------------------
/lightning_asr/model/convolution.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch
24 | import torch.nn as nn
25 | from torch import Tensor
26 | from typing import Tuple
27 |
28 | from lightning_asr.model.activation import Swish, GLU
29 | from lightning_asr.model.modules import LayerNorm, Transpose
30 |
31 |
class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in literature as depthwise convolution.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution; must be a multiple of in_channels
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector

    Returns: outputs
        - **outputs** (batch, out_channels, time'): Tensor produced by depthwise 1-D convolution, where
          time' is determined by kernel_size, stride and padding.

    Raises:
        AssertionError: if out_channels is not a multiple of in_channels.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
        # groups=in_channels gives every input channel its own set of filters: that is what makes it depthwise.
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)
74 |
75 |
class PointwiseConv1d(nn.Module):
    """
    Pointwise (kernel size 1) 1-D convolution. It mixes channels independently at every time step and
    is commonly used to change the channel dimension.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector

    Returns: outputs
        - **outputs** (batch, out_channels, time'): Tensor produced by pointwise 1-D convolution.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = True,
    ) -> None:
        super(PointwiseConv1d, self).__init__()
        pointwise_kernel = 1
        self.conv = nn.Conv1d(in_channels, out_channels, pointwise_kernel, stride=stride, padding=padding, bias=bias)

    def forward(self, inputs: Tensor) -> Tensor:
        outputs = self.conv(inputs)
        return outputs
114 |
115 |
class ConformerConvModule(nn.Module):
    """
    Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
    This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
    to aid training deep models.

    Layer order: LayerNorm -> transpose to (batch, dim, time) -> pointwise conv (dim -> dim * expansion_factor)
    -> GLU over the channel axis -> depthwise conv with 'SAME' padding -> BatchNorm1d -> Swish
    -> pointwise conv -> Dropout; the result is transposed back to (batch, time, dim).

    Args:
        in_channels (int): Number of channels in the input; equals the feature size ``dim`` of the input
        kernel_size (int, optional): Size of the depthwise convolving kernel; must be odd so that
            'SAME' padding of (kernel_size - 1) // 2 keeps the time length. Default: 31
        expansion_factor (int, optional): Channel expansion of the first pointwise convolution. Only 2 is
            supported, because the GLU halves the channels and the depthwise conv expects ``in_channels``.
            Default: 2
        dropout_p (float, optional): probability of dropout. Default: 0.1

    Inputs: inputs
        inputs (batch, time, dim): Tensor contains input sequences

    Outputs: outputs
        outputs (batch, time, dim): Tensor produced by the convolution module.

    Raises:
        AssertionError: if kernel_size is even or expansion_factor != 2.
    """
    def __init__(
            self,
            in_channels: int,
            kernel_size: int = 31,
            expansion_factor: int = 2,
            dropout_p: float = 0.1,
    ) -> None:
        super(ConformerConvModule, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"

        self.sequential = nn.Sequential(
            LayerNorm(in_channels),
            Transpose(shape=(1, 2)),
            PointwiseConv1d(in_channels, in_channels * expansion_factor, stride=1, padding=0, bias=True),
            GLU(dim=1),
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
            nn.BatchNorm1d(in_channels),
            Swish(),
            PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        # The Sequential starts with a (1, 2) transpose into channel-first layout; undo it on the way out.
        return self.sequential(inputs).transpose(1, 2)
158 |
159 |
class Conv2dSubampling(nn.Module):
    """
    Convolutional 2D subsampling (to roughly 1/4 length).

    Two unpadded ``Conv2d(kernel_size=3, stride=2)`` layers, each followed by ReLU, run over the
    (time, dim) plane of the input. Both the time axis and the feature axis therefore shrink by
    about a factor of 4.

    Args:
        in_channels (int): Number of channels in the input image. ``forward`` adds a single channel
            axis with ``unsqueeze(1)``, so this is expected to be 1.
        out_channels (int): Number of channels produced by the convolution

    Inputs: inputs, input_lengths
        - **inputs** (batch, time, dim): Tensor containing sequence of inputs
        - **input_lengths** (batch): integer tensor holding the valid length of each sequence

    Returns: outputs, output_lengths
        - **outputs** (batch, time', out_channels * dim'): Tensor produced by the convolution
        - **output_lengths** (batch): valid length of each sequence after subsampling
    """
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super(Conv2dSubampling, self).__init__()
        self.sequential = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2),
            nn.ReLU(),
        )

    def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
        outputs = self.sequential(inputs.unsqueeze(1))
        batch_size, channels, subsampled_lengths, subsampled_dim = outputs.size()

        # (batch, channels, time', dim') -> (batch, time', channels * dim')
        outputs = outputs.transpose(1, 2)
        outputs = outputs.contiguous().view(batch_size, subsampled_lengths, channels * subsampled_dim)

        # One unpadded conv with kernel 3 / stride 2 maps a length L to (L - 1) // 2; it is applied twice.
        # The shortcut (L >> 2) - 1 is not equivalent: it is one frame short whenever L % 4 == 3.
        output_lengths = (((input_lengths - 1) >> 1) - 1) >> 1

        return outputs, output_lengths
195 |
--------------------------------------------------------------------------------
/lightning_asr/model/decoder.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch
24 | import torch.nn as nn
25 | import random
26 | import torch_xla.core.xla_model as xm
27 | from torch import Tensor, LongTensor
28 | from typing import Tuple, Optional
29 |
30 | from lightning_asr.model.attention import MultiHeadAttention
31 | from lightning_asr.model.modules import View
32 |
33 |
class DecoderRNN(nn.Module):
    """
    Converts higher level features (from encoder) into output utterances
    by specifying a probability distribution over sequences of characters.

    Args:
        num_classes (int): number of classification
        max_length (int, optional): number of greedy decoding steps when no targets are given (default: 128)
        hidden_state_dim (int, optional): the number of features in the decoder hidden state `h` (default: 1024)
        pad_id (int, optional): index of the pad symbol (default: 0)
        sos_id (int, optional): index of the start of sentence symbol (default: 1)
        eos_id (int, optional): index of the end of sentence symbol (default: 2)
        num_heads (int, optional): number of attention heads. (default: 4)
        num_layers (int, optional): number of recurrent layers (default: 2)
        rnn_type (str, optional): type of RNN cell, one of 'lstm', 'gru', 'rnn' (default: lstm)
        dropout_p (float, optional): dropout probability of decoder (default: 0.3)
        use_tpu (bool, optional): stored as ``self.use_tpu``. Token inputs are always moved to the device
            of the decoder's embedding weights, which covers CPU, CUDA and TPU alike (default: False)
    """

    supported_rnns = {
        'lstm': nn.LSTM,
        'gru': nn.GRU,
        'rnn': nn.RNN,
    }

    def __init__(
            self,
            num_classes: int,
            max_length: int = 128,
            hidden_state_dim: int = 1024,
            pad_id: int = 0,
            sos_id: int = 1,
            eos_id: int = 2,
            num_heads: int = 4,
            num_layers: int = 2,
            rnn_type: str = 'lstm',
            dropout_p: float = 0.3,
            use_tpu: bool = False,
    ) -> None:
        super(DecoderRNN, self).__init__()
        self.hidden_state_dim = hidden_state_dim
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.max_length = max_length
        self.eos_id = eos_id
        self.sos_id = sos_id
        self.pad_id = pad_id
        self.use_tpu = use_tpu
        self.embedding = nn.Embedding(num_classes, hidden_state_dim)
        self.input_dropout = nn.Dropout(dropout_p)
        self.rnn = self.supported_rnns[rnn_type.lower()](
            input_size=hidden_state_dim,
            hidden_size=hidden_state_dim,
            num_layers=num_layers,
            bias=True,
            batch_first=True,
            dropout=dropout_p,
            bidirectional=False,
        )
        self.attention = MultiHeadAttention(hidden_state_dim, num_heads=num_heads)
        self.fc = nn.Sequential(
            nn.Linear(hidden_state_dim << 1, hidden_state_dim),
            nn.Tanh(),
            View(shape=(-1, self.hidden_state_dim), contiguous=True),
            nn.Linear(hidden_state_dim, num_classes),
        )

    def forward_step(
            self,
            input_var: Tensor,
            hidden_states: Optional[Tensor],
            encoder_outputs: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Run the decoder over ``input_var`` in one batched RNN pass.

        Args:
            input_var (torch.LongTensor): token ids of size ``(batch, steps)``
            hidden_states: previous RNN state, or None to start fresh (an ``(h, c)`` tuple for nn.LSTM)
            encoder_outputs (torch.FloatTensor): encoder features used as attention keys and values.
                NOTE(review): the attention context must have ``hidden_state_dim`` features for the
                concatenation below; confirm against MultiHeadAttention.

        Returns:
            step_outputs: log-probabilities of size ``(batch, steps, num_classes)``, squeezed to
                ``(batch, num_classes)`` when ``steps == 1``
            hidden_states: RNN state after the last step
            attn: attention weights returned by ``self.attention``
        """
        batch_size, output_lengths = input_var.size(0), input_var.size(1)

        # Put the token ids on whatever device the embedding table lives on (CPU, CUDA or XLA).
        # Unconditionally calling .cuda() would break a CPU-resident model on a CUDA-capable machine.
        input_var = input_var.to(self.embedding.weight.device)

        embedded = self.embedding(input_var)
        embedded = self.input_dropout(embedded)

        if self.training:
            self.rnn.flatten_parameters()

        outputs, hidden_states = self.rnn(embedded, hidden_states)
        context, attn = self.attention(outputs, encoder_outputs, encoder_outputs)

        outputs = torch.cat((outputs, context), dim=2)

        step_outputs = self.fc(outputs.view(-1, self.hidden_state_dim << 1)).log_softmax(dim=-1)
        step_outputs = step_outputs.view(batch_size, output_lengths, -1).squeeze(1)

        return step_outputs, hidden_states, attn

    def forward(
            self,
            targets: Optional[Tensor] = None,
            encoder_outputs: Tensor = None,
            teacher_forcing_ratio: float = 1.0,
    ) -> Tensor:
        """
        Forward propagate a `encoder_outputs` for training.

        Args:
            targets (torch.LongTensor): A target sequence passed to decoder. `IntTensor` of size ``(batch, seq_length)``
            encoder_outputs (torch.FloatTensor): A output sequence of encoder. `FloatTensor` of size
                ``(batch, seq_length, dimension)``
            teacher_forcing_ratio (float): probability of feeding the ground-truth targets instead of
                the model's own greedy predictions

        Returns:
            * predicted_log_probs (torch.FloatTensor): Log probability of model predictions, of size
              ``(batch, steps, num_classes)``.
        """
        hidden_states = None

        targets, _, max_length = self._validate_args(targets, encoder_outputs, teacher_forcing_ratio)
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        if use_teacher_forcing:
            # All target tokens are decoded in a single batched pass.
            step_outputs, hidden_states, _ = self.forward_step(targets, hidden_states, encoder_outputs)
            # forward_step squeezes the step axis for a single-token target; restore it so the result is
            # always (batch, steps, num_classes).
            if step_outputs.dim() == 2:
                step_outputs = step_outputs.unsqueeze(1)
            return step_outputs

        predicted_log_probs = list()
        input_var = targets[:, 0].unsqueeze(1)

        # Greedy decoding: feed back the arg-max token of each step.
        for _ in range(max_length):
            step_outputs, hidden_states, _ = self.forward_step(input_var, hidden_states, encoder_outputs)
            predicted_log_probs.append(step_outputs)
            input_var = step_outputs.topk(1)[1]

        return torch.stack(predicted_log_probs, dim=1)

    def _validate_args(
            self,
            targets: Optional[Tensor] = None,
            encoder_outputs: Tensor = None,
            teacher_forcing_ratio: float = 1.0,
    ) -> Tuple[Tensor, int, int]:
        """
        Validate arguments and fill in defaults for inference.

        Returns:
            targets: the given targets, or a ``(batch, 1)`` long tensor filled with ``sos_id`` when none
                were given (placed on the device of ``encoder_outputs``)
            batch_size: ``encoder_outputs.size(0)``
            max_length: ``self.max_length`` at inference, otherwise ``targets.size(1) - 1``

        Raises:
            AssertionError: if ``encoder_outputs`` is None
            ValueError: if no targets are given while ``teacher_forcing_ratio > 0``
        """
        assert encoder_outputs is not None
        batch_size = encoder_outputs.size(0)

        if targets is None:  # inference
            if teacher_forcing_ratio > 0:
                raise ValueError("Teacher forcing has to be disabled (set 0) when no targets is provided.")

            targets = torch.full(
                (batch_size, 1), self.sos_id, dtype=torch.long, device=encoder_outputs.device,
            )
            max_length = self.max_length

        else:
            max_length = targets.size(1) - 1  # minus the start of sequence symbol

        return targets, batch_size, max_length
197 |
--------------------------------------------------------------------------------
/lightning_asr/model/embedding.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import math
24 | import torch
25 | import torch.nn as nn
26 | from torch import Tensor
27 |
28 |
class PositionalEncoding(nn.Module):
    """
    Positional Encoding proposed in "Attention Is All You Need".
    Since transformer contains no recurrence and no convolution, in order for the model to make
    use of the order of the sequence, we must add some positional information.

    "Attention Is All You Need" use sine and cosine functions of different frequencies:
        PE_(pos, 2i)    =  sin(pos / power(10000, 2i / d_model))
        PE_(pos, 2i+1)  =  cos(pos / power(10000, 2i / d_model))

    Args:
        d_model (int): size of each encoding vector. Odd values are allowed; the last column is then a
            sine with no matching cosine column. Default: 512
        max_len (int): number of positions precomputed into the ``pe`` buffer. Default: 10000

    Inputs: length
        - **length** (int): number of leading positions to return

    Returns:
        Tensor of size ``(1, length, d_model)``, a slice of the precomputed table.
    """
    def __init__(self, d_model: int = 512, max_len: int = 10000) -> None:
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model, requires_grad=False)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd (cosine) column than even (sine) column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, length: int) -> Tensor:
        return self.pe[:, :length]
--------------------------------------------------------------------------------
/lightning_asr/model/feed_forward.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch
24 | import torch.nn as nn
25 | from torch import Tensor
26 |
27 | from lightning_asr.model.activation import Swish
28 | from lightning_asr.model.modules import LayerNorm, Linear
29 |
30 |
class FeedForwardModule(nn.Module):
    """
    Conformer feed-forward module using a pre-norm residual unit layout. Layer normalization is applied
    to the input before the first linear layer. Swish activation and dropout regularize the network.

    Layer order: LayerNorm -> Linear(encoder_dim -> encoder_dim * expansion_factor) -> Swish -> Dropout
    -> Linear(encoder_dim * expansion_factor -> encoder_dim) -> Dropout.

    Args:
        encoder_dim (int): Dimension of model encoder
        expansion_factor (int): Expansion factor of the hidden layer.
        dropout_p (float): Ratio of dropout

    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor contains input sequences

    Outputs: outputs
        - **outputs** (batch, time, dim): Tensor produced by the feed forward module.
    """
    def __init__(
            self,
            encoder_dim: int = 512,
            expansion_factor: int = 4,
            dropout_p: float = 0.1,
    ) -> None:
        super(FeedForwardModule, self).__init__()
        hidden_dim = encoder_dim * expansion_factor
        self.sequential = nn.Sequential(
            LayerNorm(encoder_dim),
            Linear(encoder_dim, hidden_dim, bias=True),
            Swish(),
            nn.Dropout(p=dropout_p),
            Linear(hidden_dim, encoder_dim, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        outputs = self.sequential(inputs)
        return outputs
66 |
--------------------------------------------------------------------------------
/lightning_asr/model/modules.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch
24 | import torch.nn as nn
25 | import torch.nn.init as init
26 | from torch import Tensor
27 |
28 |
class ResidualConnectionModule(nn.Module):
    """
    Scaled residual connection wrapped around an arbitrary module:

        outputs = module(inputs) * module_factor + inputs * input_factor

    Args:
        module (nn.Module): the wrapped branch
        module_factor (float): scale applied to the branch output (default: 1.0)
        input_factor (float): scale applied to the shortcut input (default: 1.0)
    """
    def __init__(self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0):
        super(ResidualConnectionModule, self).__init__()
        self.module = module
        self.module_factor = module_factor
        self.input_factor = input_factor

    def forward(self, inputs: Tensor) -> Tensor:
        branch = self.module(inputs) * self.module_factor
        shortcut = inputs * self.input_factor
        return branch + shortcut
42 |
43 |
class LayerNorm(nn.Module):
    """
    Layer normalization over the last dimension with a learnable scale ``gamma`` and shift ``beta``:

        output = gamma * (z - mean) / (std + eps) + beta

    This is a hand-written implementation, not a wrapper around ``torch.nn.LayerNorm``, and its numerics
    differ from it. It divides by the standard deviation from ``Tensor.std`` (unbiased by default) plus
    ``eps``, whereas ``torch.nn.LayerNorm`` divides by ``sqrt(biased variance + eps)``. The formula is kept
    unchanged so that existing checkpoints behave identically.

    Args:
        dim (int): size of the last (normalized) dimension
        eps (float): added to the standard deviation to avoid division by zero (default: 1e-6)
    """
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, z: Tensor) -> Tensor:
        mean = z.mean(dim=-1, keepdim=True)
        # Unbiased std: with dim == 1 this is NaN, so the module expects dim > 1.
        std = z.std(dim=-1, keepdim=True)
        output = (z - mean) / (std + self.eps)
        output = self.gamma * output + self.beta
        return output
58 |
59 |
class Linear(nn.Module):
    """
    ``torch.nn.Linear`` with Xavier-uniform weight initialization and a zero-initialized bias.

    Args:
        in_features (int): size of each input sample
        out_features (int): size of each output sample
        bias (bool): whether the layer has a bias term (default: True)
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        layer = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(layer.weight)
        if bias:
            init.zeros_(layer.bias)
        self.linear = layer

    def forward(self, x: Tensor) -> Tensor:
        projected = self.linear(x)
        return projected
74 |
75 |
class View(nn.Module):
    """
    ``Tensor.view`` packaged as a module so it can be used inside ``nn.Sequential``.

    Args:
        shape (tuple): target shape passed to ``view``
        contiguous (bool): call ``contiguous()`` first, needed for inputs such as transposed tensors
            (default: False)
    """
    def __init__(self, shape: tuple, contiguous: bool = False):
        super(View, self).__init__()
        self.shape = shape
        self.contiguous = contiguous

    def forward(self, x: Tensor) -> Tensor:
        source = x.contiguous() if self.contiguous else x
        return source.view(*self.shape)
88 |
89 |
class Transpose(nn.Module):
    """
    ``Tensor.transpose`` packaged as a module so it can be used inside ``nn.Sequential``.

    Args:
        shape (tuple): the two dimensions to swap
    """
    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return torch.transpose(x, *self.shape)
98 |
--------------------------------------------------------------------------------
/lightning_asr/optim/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from lightning_asr.optim.adamp import AdamP
24 | from lightning_asr.optim.radam import RAdam
25 |
--------------------------------------------------------------------------------
/lightning_asr/optim/adamp.py:
--------------------------------------------------------------------------------
1 | # AdamP
2 | # Copyright (c) 2020-present NAVER Corp.
3 | # MIT license
4 |
5 | import torch
6 | from torch.optim.optimizer import Optimizer
7 | import math
8 |
9 |
class AdamP(Optimizer):
    """
    Paper: "AdamP: Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights"
    Copied from https://github.com/clovaai/AdamP/
    Copyright (c) 2020 Naver Corp.
    MIT License

    Args:
        params: iterable of parameters (or parameter-group dicts) to optimize
        lr (float): learning rate (default: 1e-3)
        betas (Tuple[float, float]): decay rates of the running averages of the gradient and of its
            square (default: (0.9, 0.999))
        eps (float): term added for numerical stability (default: 1e-8)
        weight_decay (float): decoupled weight decay, applied as ``p *= 1 - lr * weight_decay * ratio``
            (default: 0)
        delta (float): cosine-similarity threshold used to decide whether to project the update (default: 0.1)
        wd_ratio (float): weight-decay multiplier for parameters whose update was projected (default: 0.1)
        nesterov (bool): use a Nesterov-style momentum update (default: False)
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
        super(AdamP, self).__init__(params, defaults)

    def _channel_view(self, x):
        # One row per slice along the first axis.
        return x.view(x.size(0), -1)

    def _layer_view(self, x):
        # The whole tensor as a single row.
        return x.view(1, -1)

    def _cosine_similarity(self, x, y, eps, view_func):
        """Row-wise absolute cosine similarity of ``x`` and ``y`` after reshaping both with ``view_func``."""
        x = view_func(x)
        y = view_func(y)

        x_norm = x.norm(dim=1).add_(eps)
        y_norm = y.norm(dim=1).add_(eps)
        dot = (x * y).sum(dim=1)

        return dot.abs() / x_norm / y_norm

    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
        """
        Project ``perturb`` onto the tangent space of ``p`` when the gradient is nearly orthogonal to ``p``.

        The per-channel view is tried first, then the whole-layer view; the first view that passes the
        threshold is used.

        Returns:
            (perturb, wd): the possibly projected update, and ``wd_ratio`` if a projection happened
            (otherwise 1)
        """
        wd = 1
        expand_size = [-1] + [1] * (len(p.shape) - 1)
        for view_func in [self._channel_view, self._layer_view]:

            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)

            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
                # Remove the component of the update along the (normalized) weight direction.
                p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
                wd = wd_ratio

                return perturb, wd

        return perturb, wd

    def step(self, closure=None):
        """
        Perform a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns the loss

        Returns:
            the value returned by ``closure``, or None
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                beta1, beta2 = group['betas']
                nesterov = group['nesterov']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                # Adam
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Keyword forms: the positional-scalar overloads add_(Number, Tensor) and
                # addcmul_(Number, Tensor, Tensor) are deprecated in PyTorch.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                step_size = group['lr'] / bias_correction1

                if nesterov:
                    perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
                else:
                    perturb = exp_avg / denom

                # Projection (only for parameters with more than one dimension)
                wd_ratio = 1
                if len(p.shape) > 1:
                    perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'],
                                                         group['eps'])

                # Weight decay (decoupled: shrinks the weights directly)
                if group['weight_decay'] > 0:
                    p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio)

                # Step
                p.data.add_(perturb, alpha=-step_size)

        return loss
109 |
--------------------------------------------------------------------------------
/lightning_asr/optim/lr_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from lightning_asr.optim.lr_scheduler.tri_stage_lr_scheduler import TriStageLRScheduler
24 | from lightning_asr.optim.lr_scheduler.transformer_lr_scheduler import TransformerLRScheduler
25 |
--------------------------------------------------------------------------------
/lightning_asr/optim/lr_scheduler/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
class LearningRateScheduler(object):
    """
    Base interface for learning rate schedulers that write the learning rate directly into an
    optimizer's ``param_groups``.

    Note:
        Do not use this class directly; use one of the subclasses, which implement ``step``.

    Args:
        optimizer: optimizer whose ``param_groups`` entries are dicts with an ``'lr'`` key
        init_lr: initial learning rate, stored as ``self.init_lr``
    """
    def __init__(self, optimizer, init_lr):
        self.optimizer = optimizer
        self.init_lr = init_lr

    def step(self, *args, **kwargs):
        """Advance the schedule; must be overridden by subclasses."""
        raise NotImplementedError

    @staticmethod
    def set_lr(optimizer, lr):
        """Assign ``lr`` to every parameter group of ``optimizer``."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def get_lr(self):
        """Return the learning rate of the first parameter group, or None when there are no groups."""
        return next((param_group['lr'] for param_group in self.optimizer.param_groups), None)
--------------------------------------------------------------------------------
/lightning_asr/optim/lr_scheduler/transformer_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import math
24 |
25 | from lightning_asr.optim.lr_scheduler.lr_scheduler import LearningRateScheduler
26 |
27 |
class TransformerLRScheduler(LearningRateScheduler):
    """
    Learning rate scheduler with linear warmup followed by exponential decay.

    The warmup follows https://arxiv.org/abs/1706.03762, but the decay here is exponential
    (the paper uses inverse-square-root decay). ``update_step`` counts calls to ``step`` and is
    incremented *before* the LR is computed, so the first call uses ``update_step == 1``:

    - warmup (``update_step < warmup_steps``): LR = ``update_step * peak_lr / warmup_steps``
    - decay (the next ``decay_steps`` values of ``update_step``):
      LR = ``peak_lr * exp(-decay_factor * (update_step - warmup_steps))``, which approaches
      ``peak_lr * final_lr_scale`` at the end of the stage
    - final: LR = ``final_lr`` (a separate value, so the LR jumps if
      ``final_lr != peak_lr * final_lr_scale``)

    Args:
        optimizer: optimizer whose ``param_groups`` learning rates are overwritten on every step
        peak_lr (float): learning rate at the start of the decay stage
        final_lr (float): learning rate used after the decay stage
        final_lr_scale (float): target ratio of the decayed LR to ``peak_lr``; must be > 0
        warmup_steps (int): length of the warmup stage (0 skips warmup)
        decay_steps (int): length of the decay stage (0 skips decay)
    """
    def __init__(self, optimizer, peak_lr, final_lr, final_lr_scale, warmup_steps, decay_steps):
        assert isinstance(warmup_steps, int), "warmup_steps should be integer type"
        assert isinstance(decay_steps, int), "decay_steps should be integer type"

        super(TransformerLRScheduler, self).__init__(optimizer, 0.0)
        self.final_lr = final_lr
        self.peak_lr = peak_lr
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps

        # Zero-length stages would otherwise raise ZeroDivisionError; _decide_stage never enters them.
        self.warmup_rate = self.peak_lr / self.warmup_steps if self.warmup_steps != 0 else 0
        self.decay_factor = -math.log(final_lr_scale) / self.decay_steps if self.decay_steps != 0 else 0

        self.lr = self.init_lr
        self.update_step = 0

    def _decide_stage(self):
        """Return ``(stage, steps_in_stage)`` for the current ``update_step`` (stage 2 has no step count)."""
        if self.update_step < self.warmup_steps:
            return 0, self.update_step

        if self.warmup_steps <= self.update_step < self.warmup_steps + self.decay_steps:
            return 1, self.update_step - self.warmup_steps

        return 2, None

    def step(self):
        """Advance one update, write the new LR into the optimizer and return it."""
        self.update_step += 1
        stage, steps_in_stage = self._decide_stage()

        if stage == 0:
            self.lr = self.update_step * self.warmup_rate
        elif stage == 1:
            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
        elif stage == 2:
            self.lr = self.final_lr
        else:
            raise ValueError("Undefined stage")

        self.set_lr(self.optimizer, self.lr)

        return self.lr
71 |
--------------------------------------------------------------------------------
/lightning_asr/optim/lr_scheduler/tri_stage_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import math
24 |
25 | from lightning_asr.optim.lr_scheduler.lr_scheduler import LearningRateScheduler
26 |
27 |
class TriStageLRScheduler(LearningRateScheduler):
    """
    Implement the tri-stage learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf

    Similar to the inverse square root scheduler, but the LR is scheduled in three stages followed
    by a constant tail. ``update_step`` counts calls to ``step``; the LR for a call is computed
    *before* the counter is incremented, so the first call uses ``update_step == 0``:

    - warmup stage (``update_step < warmup_steps``): starting from ``init_lr * init_lr_scale``,
      the LR increases linearly towards ``peak_lr``
    - hold stage: the LR is kept at ``peak_lr`` for ``total_steps // 2 - warmup_steps`` steps
    - decay stage: the LR decays exponentially from ``peak_lr`` to ``peak_lr * final_lr_scale``
      over ``total_steps // 2`` steps (the last step of the stage reaches that value)
    - afterwards the LR is kept at ``final_lr``

    Args:
        optimizer: optimizer whose ``param_groups`` learning rates are overwritten on every step
        init_lr (float): base initial learning rate, multiplied by ``init_lr_scale``
        peak_lr (float): learning rate at the end of warmup and during the hold stage
        final_lr (float): learning rate used after the decay stage
        init_lr_scale (float): multiplier applied to ``init_lr`` to get the warmup start LR
        final_lr_scale (float): target ratio of the decayed LR to ``peak_lr``; must be > 0
        warmup_steps (int): length of the warmup stage (0 skips warmup)
        total_steps (int): warmup + hold span its first half, decay spans its second half
    """
    def __init__(self, optimizer, init_lr, peak_lr, final_lr, init_lr_scale, final_lr_scale, warmup_steps, total_steps):
        assert isinstance(warmup_steps, int), "warmup_steps should be integer type"
        assert isinstance(total_steps, int), "total_steps should be integer type"

        super(TriStageLRScheduler, self).__init__(optimizer, init_lr)
        self.init_lr *= init_lr_scale
        self.final_lr = final_lr
        self.peak_lr = peak_lr
        self.warmup_steps = warmup_steps
        # NOTE(review): if warmup_steps > total_steps // 2, hold_steps is negative and the decay
        # stage is entered part-way through (below peak_lr) — confirm such configs are not used.
        self.hold_steps = total_steps // 2 - warmup_steps
        self.decay_steps = total_steps // 2

        # Zero-length stages would otherwise raise ZeroDivisionError; the rates then have no effect.
        self.warmup_rate = (self.peak_lr - self.init_lr) / self.warmup_steps if self.warmup_steps != 0 else 0
        self.decay_factor = -math.log(final_lr_scale) / self.decay_steps if self.decay_steps != 0 else 0

        self.lr = self.init_lr
        self.update_step = 0

    def _decide_stage(self):
        """Return ``(stage, steps_in_stage)`` for the current ``update_step``."""
        if self.update_step < self.warmup_steps:
            return 0, self.update_step

        offset = self.warmup_steps

        if self.update_step < offset + self.hold_steps:
            return 1, self.update_step - offset

        offset += self.hold_steps

        # Inclusive bound: the last decay step lands exactly on peak_lr * final_lr_scale.
        if self.update_step <= offset + self.decay_steps:
            # decay stage
            return 2, self.update_step - offset

        offset += self.decay_steps

        return 3, self.update_step - offset

    def step(self):
        """Compute the LR for the current update, write it into the optimizer, advance and return it."""
        stage, steps_in_stage = self._decide_stage()

        if stage == 0:
            self.lr = self.init_lr + self.warmup_rate * steps_in_stage
        elif stage == 1:
            self.lr = self.peak_lr
        elif stage == 2:
            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
        elif stage == 3:
            self.lr = self.final_lr
        else:
            raise ValueError(f"Undefined stage: {stage}")

        self.set_lr(self.optimizer, self.lr)
        self.update_step += 1

        return self.lr
99 |
--------------------------------------------------------------------------------
/lightning_asr/optim/optimizer.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch
24 |
25 |
class Optimizer(object):
    """
    Wrapper around a ``torch.optim.Optimizer`` that adds learning rate scheduling and
    gradient norm clipping.

    Args:
        optim (torch.optim.Optimizer): wrapped optimizer, already constructed with the parameters
            to optimize, e.g. torch.optim.Adam, torch.optim.SGD
        scheduler (lightning_asr.optim.lr_scheduler, optional): learning rate scheduler stepped
            after every optimizer step
        scheduler_period (int, optional): number of scheduler steps after which the scheduler is detached
        max_grad_norm (int or float, optional): gradient norm clipping threshold; clipping is off when <= 0
    """
    def __init__(self, optim, scheduler=None, scheduler_period=None, max_grad_norm=0):
        self.optimizer = optim
        self.scheduler = scheduler
        self.scheduler_period = scheduler_period
        self.max_grad_norm = max_grad_norm
        self.count = 0

    def step(self, model):
        """Optionally clip ``model``'s gradients, step the optimizer, then advance the scheduler."""
        clip_enabled = self.max_grad_norm > 0
        if clip_enabled:
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        if self.scheduler is None:
            return

        self.update()
        self.count += 1

        # Once the scheduler has run for its configured period, detach it.
        if self.count == self.scheduler_period:
            self.scheduler = None
            self.scheduler_period = 0
            self.count = 0

    def set_scheduler(self, scheduler, scheduler_period):
        """Attach a new scheduler and restart its step counter."""
        self.scheduler = scheduler
        self.scheduler_period = scheduler_period
        self.count = 0

    def update(self):
        """Step the scheduler unless it is ReduceLROnPlateau (which needs a metric and is skipped here)."""
        if not isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step()

    def zero_grad(self):
        """Clear gradients through the wrapped optimizer."""
        self.optimizer.zero_grad()

    def get_lr(self):
        """Return the learning rate of the first parameter group, or ``None`` when there are no groups."""
        try:
            first_group = next(iter(self.optimizer.param_groups))
        except StopIteration:
            return None
        return first_group['lr']

    def set_lr(self, lr):
        """Assign ``lr`` to every parameter group of the wrapped optimizer."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
80 |
--------------------------------------------------------------------------------
/lightning_asr/optim/radam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, LiyuanLucasLiu. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import torch
17 | from torch.optim.optimizer import Optimizer
18 |
19 |
class RAdam(Optimizer):
    """
    Paper: "On the Variance of the Adaptive Learning Rate and Beyond"
    Refer to https://github.com/LiyuanLucasLiu/RAdam
    Copyright (c) LiyuanLucasLiu
    Apache 2.0 License

    Args:
        params: iterable of parameters, or of dicts defining parameter groups
        lr (float): learning rate (default: 1e-3)
        betas (Tuple[float, float]): coefficients of the running averages of the gradient and
            its square (default: (0.9, 0.999))
        eps (float): term added to the denominator for numerical stability (default: 1e-8)
        weight_decay (float): decay applied as ``p -= weight_decay * lr * p`` (default: 0)
        degenerated_to_sgd (bool): while the rectification term is not used (``N_sma < 5``), take an
            un-normalized momentum step instead of skipping the update (default: True)
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if eps < 0.0:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        # Groups with their own betas get a private cache, since cached values depend on the betas.
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """
        Perform a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns the loss

        Returns:
            the value returned by ``closure``, or ``None`` when no closure is given

        Raises:
            RuntimeError: if a parameter has a sparse gradient
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # The update is computed in fp32 and copied back into p at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Keyword forms replace the (scalar, tensor) overloads deprecated since PyTorch 1.5.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                # N_sma and step_size are cached per step count in a 10-slot buffer on the param group.
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        step_size = -1
                    buffered[2] = step_size

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)

        return loss
118 |
--------------------------------------------------------------------------------
/lightning_asr/utilities.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | import logging
23 | import torch
24 | import platform
25 | from omegaconf import DictConfig, OmegaConf
26 | from pytorch_lightning.loggers import TensorBoardLogger
27 |
28 |
def _check_environment(use_cuda: bool, logger) -> int:
    """
    Log a summary of the execution environment:
    OS, processor, CUDA availability/version and PyTorch version.

    Args:
        use_cuda (bool): whether the configuration asks for CUDA
        logger: object with an ``info`` method that receives the summary lines

    Returns:
        int: ``torch.cuda.device_count()``.
        NOTE(review): the count is returned even when ``use_cuda`` is False — confirm callers expect this.
    """
    gpu_selected = use_cuda and torch.cuda.is_available()

    logger.info(f"Operating System : {platform.system()} {platform.release()}")
    logger.info(f"Processor : {platform.processor()}")

    device_count = torch.cuda.device_count()

    if gpu_selected:
        for index in range(device_count):
            logger.info(f"device : {torch.cuda.get_device_name(index)}")

    logger.info(f"CUDA is available : {torch.cuda.is_available()}")
    if gpu_selected:
        logger.info(f"CUDA version : {torch.version.cuda}")
    logger.info(f"PyTorch version : {torch.__version__}")

    return device_count
54 |
55 |
def parse_configs(configs: DictConfig):
    """
    Validate device options, log the configuration and environment, and build the trainer logger.

    Args:
        configs (DictConfig): configuration; ``use_cuda``, ``use_tpu`` and ``use_tensorboard`` are read

    Returns:
        tuple: ``(logger, num_devices)`` where ``logger`` is a ``TensorBoardLogger`` when
        ``configs.use_tensorboard`` is set and ``True`` otherwise, and ``num_devices`` is the value
        returned by ``_check_environment``.

    Raises:
        ValueError: if both ``configs.use_cuda`` and ``configs.use_tpu`` are set.
    """
    # Fail fast, before logging or constructing a TensorBoard logger for an invalid configuration.
    if configs.use_cuda and configs.use_tpu:
        raise ValueError("configs.use_cuda and configs.use_tpu both are True, Please choose between GPU and TPU.")

    console_logger = logging.getLogger(__name__)
    console_logger.info(OmegaConf.to_yaml(configs))
    num_devices = _check_environment(configs.use_cuda, console_logger)

    if configs.use_tensorboard:
        trainer_logger = TensorBoardLogger("tensorboard", name="Lightning Automatic Speech Recognition")
    else:
        # Presumably ``True`` asks the trainer for its default logger — confirm with the caller.
        trainer_logger = True

    return trainer_logger, num_devices
70 |
--------------------------------------------------------------------------------
/lightning_asr/vocabs/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from lightning_asr.vocabs.librispeech import LibriSpeechVocabulary
24 |
--------------------------------------------------------------------------------
/lightning_asr/vocabs/librispeech.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from lightning_asr.vocabs.vocab import Vocabulary
24 |
25 |
class LibriSpeechVocabulary(Vocabulary):
    """
    Converts label to string for librispeech dataset.

    Args:
        model_path (str): path of sentencepiece model
        vocab_size (int): size of vocab

    Raises:
        ImportError: if the ``sentencepiece`` package is not installed
    """
    def __init__(self, model_path: str, vocab_size: int):
        super(LibriSpeechVocabulary, self).__init__()
        try:
            import sentencepiece as spm
        except ImportError as err:
            raise ImportError("Please install sentencepiece: `pip install sentencepiece`") from err

        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_path)
        # NOTE(review): every special-token lookup below passes an empty string, so all four ids
        # resolve to the same piece id (presumably the unknown-token id). The intended piece names
        # look like they were lost — confirm against the tokenizer training script.
        self.pad_id = self.sp.PieceToId("")
        self.sos_id = self.sp.PieceToId("")
        self.eos_id = self.sp.PieceToId("")
        self.blank_id = self.sp.PieceToId("")
        self.vocab_size = vocab_size

    def label_to_string(self, labels):
        """
        Decode label ids into text with the sentencepiece model.

        Args:
            labels: 1-D or 2-D container of ids whose elements support ``.item()``
                (presumably a torch tensor — confirm with callers)

        Returns:
            str for 1-D input; list of str (one per row) for 2-D input

        Raises:
            ValueError: if ``labels`` has any other number of dimensions
        """
        if len(labels.shape) == 1:
            return self.sp.DecodeIds([l.item() for l in labels])

        elif len(labels.shape) == 2:
            # Convert elements to Python ints exactly like the 1-D branch; sentencepiece is not
            # guaranteed to accept 0-d tensor elements.
            return [self.sp.DecodeIds([l.item() for l in label]) for label in labels]
        else:
            raise ValueError("Unsupported label's shape")
62 |
--------------------------------------------------------------------------------
/lightning_asr/vocabs/vocab.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
class Vocabulary(object):
    """
    Base class for vocabularies that convert label ids back into text.

    Subclasses are expected to assign the special-token ids and ``vocab_size``
    and to implement ``label_to_string``.

    Note:
        Do not use this class directly, use one of the sub classes.
    """
    def __init__(self, *args, **kwargs):
        # Special-token ids start unset; subclasses assign the real values.
        for attr_name in ('sos_id', 'eos_id', 'pad_id', 'blank_id'):
            setattr(self, attr_name, None)
        self.vocab_size = None

    def __len__(self):
        """Return ``vocab_size`` (``len()`` fails until a subclass sets it to an int)."""
        return self.vocab_size

    def label_to_string(self, labels):
        """Decode ``labels`` into text; must be overridden by subclasses."""
        raise NotImplementedError
41 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from distutils.core import setup
24 |
# distutils ignores setuptools-only keywords such as ``install_requires`` and ``python_requires``
# and was removed in Python 3.12 (PEP 632), so build with setuptools instead.
from setuptools import find_packages, setup

setup(
    name='lightning_asr',
    # NOTE(review): 'latest' is not a PEP 440 version; recent setuptools/pip may warn about or reject it.
    version='latest',
    description='Modular and extensible speech recognition library leveraging pytorch-lightning and hydra',
    author='Soohwan Kim',
    author_email='kaki.ai@tunib.ai',
    url='https://github.com/sooftware/lightning_asr',
    # The original call listed no packages; ship lightning_asr and all of its subpackages.
    packages=find_packages(include=('lightning_asr', 'lightning_asr.*')),
    install_requires=[
        'torch>=1.4.0',
        'python-Levenshtein',
        'numpy',
        'pandas',
        'astropy',
        'sentencepiece',
        'pytorch-lightning',
        'hydra-core',
        'wget',
    ],
    keywords=['asr', 'speech_recognition', 'pytorch-lightning'],
    python_requires='>=3.7',
)
46 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
# Install torchaudio and librosa from their respective conda channels.
conda install --channel pytorch torchaudio
conda install --channel conda-forge librosa
3 |
--------------------------------------------------------------------------------
/test/test_transformer_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch.nn as nn
24 | import matplotlib.pyplot as plt
25 | from torch import optim
26 |
27 | from lightning_asr.optim.lr_scheduler import TransformerLRScheduler
28 | from lightning_asr.optim.optimizer import Optimizer
29 |
30 |
class Model(nn.Module):
    """Minimal module that only exists to own parameters for the optimizer under test."""

    def __init__(self):
        super().__init__()
        self.projection = nn.Linear(in_features=10, out_features=10)

    def forward(self):
        """No-op forward pass; the scheduler test never calls the model."""
        return None
38 |
39 |
INIT_LR = 1e-10
PEAK_LR = 1e-04
FINAL_LR = 1e-07
FINAL_LR_SCALE = 0.1
WARMUP_STEPS = 4000
MAX_GRAD_NORM = 400
TOTAL_STEPS = 120000


def main():
    """Step TransformerLRScheduler for TOTAL_STEPS updates and plot the resulting learning rates."""
    model = Model()

    base_optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    scheduler = TransformerLRScheduler(
        optimizer=base_optimizer,
        peak_lr=PEAK_LR,
        final_lr=FINAL_LR,
        final_lr_scale=FINAL_LR_SCALE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=TOTAL_STEPS - WARMUP_STEPS,
    )
    optimizer = Optimizer(base_optimizer, scheduler, TOTAL_STEPS, MAX_GRAD_NORM)
    lr_processes = list()

    # No backward pass is run, so parameters have no gradients; only the LR schedule is observed.
    for _ in range(TOTAL_STEPS):
        optimizer.step(model)
        lr_processes.append(optimizer.get_lr())

    plt.title('Test Transformer lr scheduler')
    plt.plot(lr_processes)
    plt.xlabel('timestep', fontsize='large')
    plt.ylabel('lr', fontsize='large')
    plt.show()


# Guarded so importing this module (e.g. during pytest collection) does not run the
# long loop or open a plot window.
if __name__ == "__main__":
    main()
72 |
--------------------------------------------------------------------------------
/test/test_tri_stage_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Soohwan Kim
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch.nn as nn
24 | import matplotlib.pyplot as plt
25 | from torch import optim
26 |
27 | from lightning_asr.optim.lr_scheduler import TriStageLRScheduler
28 | from lightning_asr.optim.optimizer import Optimizer
29 |
30 |
class Model(nn.Module):
    """Minimal module that only exists to own parameters for the optimizer under test."""

    def __init__(self):
        super().__init__()
        self.projection = nn.Linear(in_features=10, out_features=10)

    def forward(self):
        """No-op forward pass; the scheduler test never calls the model."""
        return None
38 |
39 |
INIT_LR = 1e-10
PEAK_LR = 1e-04
FINAL_LR = 1e-07
INIT_LR_SCALE = 0.01
FINAL_LR_SCALE = 0.001
WARMUP_STEPS = 4000
MAX_GRAD_NORM = 400
TOTAL_STEPS = 32000


def main():
    """Step TriStageLRScheduler for TOTAL_STEPS updates and plot the resulting learning rates."""
    model = Model()

    base_optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    scheduler = TriStageLRScheduler(
        optimizer=base_optimizer,
        init_lr=INIT_LR,
        peak_lr=PEAK_LR,
        final_lr=FINAL_LR,
        init_lr_scale=INIT_LR_SCALE,
        final_lr_scale=FINAL_LR_SCALE,
        warmup_steps=WARMUP_STEPS,
        total_steps=TOTAL_STEPS,
    )
    optimizer = Optimizer(base_optimizer, scheduler, TOTAL_STEPS, MAX_GRAD_NORM)
    lr_processes = list()

    # No backward pass is run, so parameters have no gradients; only the LR schedule is observed.
    for _ in range(TOTAL_STEPS):
        optimizer.step(model)
        lr_processes.append(optimizer.get_lr())

    plt.title('Test Tri-stage lr scheduler')
    plt.plot(lr_processes)
    plt.xlabel('timestep', fontsize='large')
    plt.ylabel('lr', fontsize='large')
    plt.show()


# Guarded so importing this module (e.g. during pytest collection) does not run the
# long loop or open a plot window.
if __name__ == "__main__":
    main()
74 |
--------------------------------------------------------------------------------