├── .gitignore
├── LICENSE
├── LibMTL.egg-info
│   ├── PKG-INFO
│   ├── SOURCES.txt
│   ├── dependency_links.txt
│   ├── requires.txt
│   └── top_level.txt
├── LibMTL
│   ├── __init__.py
│   ├── __pycache__  (compiled .pyc files, contents omitted)
│   ├── _record.py
│   ├── architecture
│   │   ├── CGC.py
│   │   ├── Cross_stitch.py
│   │   ├── DSelect_k.py
│   │   ├── HPS.py
│   │   ├── LTB.py
│   │   ├── MMoE.py
│   │   ├── MTAN.py
│   │   ├── PLE.py
│   │   ├── __init__.py
│   │   ├── __pycache__  (compiled .pyc files, contents omitted)
│   │   └── abstract_arch.py
│   ├── config.py
│   ├── loss.py
│   ├── metrics.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── __pycache__  (compiled .pyc files, contents omitted)
│   │   ├── resnet.py
│   │   └── resnet_dilated.py
│   ├── trainer.py
│   ├── utils.py
│   └── weighting
│       ├── AMTL.py
│       ├── Aligned_MTL.py
│       ├── Arithmetic.py
│       ├── CAGrad.py
│       ├── DWA.py
│       ├── EW.py
│       ├── GLS.py
│       ├── GeMTL.py
│       ├── GradDrop.py
│       ├── GradNorm.py
│       ├── GradVac.py
│       ├── IMTL.py
│       ├── IMTL_G.py
│       ├── IMTL_L.py
│       ├── LSBwD.py
│       ├── LSBwoD.py
│       ├── MGDA.py
│       ├── MoCo.py
│       ├── Nash_MTL.py
│       ├── PCGrad.py
│       ├── RLW.py
│       ├── SI.py
│       ├── SI_naive.py
│       ├── UW.py
│       ├── __init__.py
│       ├── __pycache__  (compiled .pyc files, contents omitted)
│       └── abstract_weighting.py
├── README.md
├── docs
│   ├── Makefile
│   ├── README.md
│   ├── _build  (generated Sphinx output, contents omitted)
│   ├── _static
│   │   └── theme_overrides.css
│   ├── _templates
│   │   └── footer.html
│   ├── autoapi_templates
│   │   └── python
│   │       └── module.rst
│   ├── conf.py
│   ├── docs
│   │   ├── _autoapi  (generated AutoAPI index.rst stubs, contents omitted)
│   │   ├── develop
│   │   │   ├── arch.md
│   │   │   ├── dataset.md
│   │   │   └── weighting.md
│   │   ├── getting_started
│   │   │   ├── installation.md
│   │   │   ├── introduction.md
│   │   │   └── quick_start.md
│   │   ├── images
│   │   │   ├── framework.png
│   │   │   ├── multi_input.png
│   │   │   └── rep_grad.png
│   │   ├── references.bib
│   │   └── user_guide
│   │       ├── benchmark.rst
│   │       ├── benchmark
│   │       │   ├── nyuv2.md
│   │       │   └── office.md
│   │       ├── framework.md
│   │       └── mtl.md
│   ├── index.rst
│   └── requirements.txt
├── examples
│   ├── README.md
│   ├── nyusp
│   │   ├── README.md
│   │   ├── __pycache__  (compiled .pyc files, contents omitted)
│   │   ├── aspp.py
│   │   ├── create_dataset.py
│   │   ├── data_split.json
│   │   ├── main.py
│   │   ├── main_segnet.py
│   │   ├── run.sh
│   │   ├── segnet_mtan.py
│   │   └── utils.py
│   └── office
│       ├── README.md
│       ├── __pycache__  (compiled .pyc files, contents omitted)
│       ├── create_dataset.py
│       ├── data_txt
│       │   ├── office-31
│       │   │   ├── amazon_test.txt
│       │   │   ├── amazon_train.txt
│       │   │   ├── amazon_val.txt
│       │   │   ├── dslr_test.txt
│       │   │   ├── dslr_train.txt
│       │   │   ├── dslr_val.txt
│       │   │   ├── webcam_test.txt
│       │   │   ├── webcam_train.txt
│       │   │   └── webcam_val.txt
│       │   └── office-home
│       │       ├── Art_test.txt
│       │       ├── Art_train.txt
│       │       ├── Art_val.txt
│       │       ├── Clipart_test.txt
│       │       ├── Clipart_train.txt
│       │       ├── Clipart_val.txt
│       │       ├── Product_test.txt
│       │       ├── Product_train.txt
│       │       ├── Product_val.txt
│       │       ├── Real_World_test.txt
│       │       ├── Real_World_train.txt
│       │       └── Real_World_val.txt
│       ├── main.py
│       └── run.sh
├── requirements.txt
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
# *.pth
--------------------------------------------------------------------------------
/LibMTL.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
Metadata-Version: 2.1
Name: LibMTL
Version: 1.1.5
Summary: A PyTorch Library for Multi-Task Learning
Home-page: https://github.com/median-research-group/LibMTL
Author: Baijiong Lin
Author-email: linbj@mail.sustech.edu.cn
License: MIT
Platform: all
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Software Development :: Libraries
Description-Content-Type: text/markdown
License-File: LICENSE

# GeMTL

This is an implementation of GeMTL, which exploits the generalized mean for per-task loss aggregation in multi-task learning.
Our code is mainly based on [LibMTL](https://github.com/median-research-group/LibMTL?tab=readme-ov-file).

## Getting started

1. Create a virtual environment

```shell
conda create -n gemtl python=3.8
conda activate gemtl
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
```

2. Clone this repository

3. Install `LibMTL`

```shell
cd GeMTL
pip install -e .
```

## Requirements

- Python >= 3.8
- PyTorch >= 1.8.1

```shell
pip install -r requirements.txt
```

## Dataset

You can download the datasets from the following links.
- [NYUv2](https://github.com/lorenmt/mtan)
- [Office-Home](https://www.hemanthdv.org/officeHomeDataset.html)

## Run

Training and testing scripts are in `./examples/{nyusp, office}/run.sh`.
You can reproduce the results by running the following command, where `{dataset}` is either `nyusp` or `office`.

```shell
bash ./examples/{dataset}/run.sh
```

## Reference

Our implementation is developed on the following repositories. Thanks to the contributors!
- [LibMTL](https://github.com/median-research-group/LibMTL?tab=readme-ov-file)
- [CAGrad](https://github.com/Cranial-XIX/CAGrad)
- [mtan](https://github.com/lorenmt/mtan)

## License

This repository is released under the [MIT](./LICENSE) license.
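For reference, here is a minimal sketch of the generalized-mean (GeM) aggregation named above. The `gem_loss` helper and the sample values are illustrative and not part of the package; the handling of `p` mirrors `LibMTL/weighting/GeMTL.py` later in this listing.

```python
import torch

def gem_loss(losses: torch.Tensor, p: float, eps: float = 0.1) -> torch.Tensor:
    """Generalized-mean aggregation of per-task losses.

    For |p| below a small threshold the generalized mean is replaced by its
    p -> 0 limit, the geometric mean; p = 1 recovers the arithmetic mean.
    """
    T = losses.numel()
    if abs(p) < eps:
        return torch.pow(losses.prod(), 1.0 / T)                 # geometric mean
    return torch.pow(torch.pow(losses, p).sum() / T, 1.0 / p)    # generalized mean

losses = torch.tensor([0.5, 2.0])
for p in (-0.5, 0.0, 0.5, 1.0):
    print(p, gem_loss(losses, p).item())
# 0.8889, 1.0, 1.125, 1.25: the aggregate is monotonically increasing in p
# and always sits between min(losses) and max(losses).
```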
--------------------------------------------------------------------------------
/LibMTL.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
LICENSE
README.md
setup.py
LibMTL/__init__.py
LibMTL/_record.py
LibMTL/config.py
LibMTL/loss.py
LibMTL/metrics.py
LibMTL/trainer.py
LibMTL/utils.py
LibMTL.egg-info/PKG-INFO
LibMTL.egg-info/SOURCES.txt
LibMTL.egg-info/dependency_links.txt
LibMTL.egg-info/requires.txt
LibMTL.egg-info/top_level.txt
LibMTL/architecture/CGC.py
LibMTL/architecture/Cross_stitch.py
LibMTL/architecture/DSelect_k.py
LibMTL/architecture/HPS.py
LibMTL/architecture/LTB.py
LibMTL/architecture/MMoE.py
LibMTL/architecture/MTAN.py
LibMTL/architecture/PLE.py
LibMTL/architecture/__init__.py
LibMTL/architecture/abstract_arch.py
LibMTL/model/__init__.py
LibMTL/model/resnet.py
LibMTL/model/resnet_dilated.py
LibMTL/weighting/AMTL.py
LibMTL/weighting/Aligned_MTL.py
LibMTL/weighting/Arithmetic.py
LibMTL/weighting/CAGrad.py
LibMTL/weighting/DWA.py
LibMTL/weighting/EW.py
LibMTL/weighting/GLS.py
LibMTL/weighting/GeMTL.py
LibMTL/weighting/GradDrop.py
LibMTL/weighting/GradNorm.py
LibMTL/weighting/GradVac.py
LibMTL/weighting/IMTL.py
LibMTL/weighting/IMTL_G.py
LibMTL/weighting/IMTL_L.py
LibMTL/weighting/LSBwD.py
LibMTL/weighting/LSBwoD.py
LibMTL/weighting/MGDA.py
LibMTL/weighting/MoCo.py
LibMTL/weighting/Nash_MTL.py
LibMTL/weighting/PCGrad.py
LibMTL/weighting/RLW.py
LibMTL/weighting/SI.py
LibMTL/weighting/SI_naive.py
LibMTL/weighting/UW.py
LibMTL/weighting/__init__.py
LibMTL/weighting/abstract_weighting.py
--------------------------------------------------------------------------------
/LibMTL.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/LibMTL.egg-info/requires.txt:
--------------------------------------------------------------------------------
torch>=1.8.0
torchvision>=0.9.0
numpy>=1.20
--------------------------------------------------------------------------------
/LibMTL.egg-info/top_level.txt:
--------------------------------------------------------------------------------
LibMTL
--------------------------------------------------------------------------------
/LibMTL/__init__.py:
--------------------------------------------------------------------------------
from . import architecture
from . import model
from . import weighting
from .trainer import Trainer
from . import config
from . import loss
from . import metrics
#from .record import PerformanceMeter
from . import utils
--------------------------------------------------------------------------------
/LibMTL/architecture/CGC.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.architecture.MMoE import MMoE


class CGC(MMoE):
    r"""Customized Gate Control (CGC).

    This method is proposed in `Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations (ACM RecSys 2020 Best Paper) `_ \
    and implemented by us.

    Args:
        img_size (list): The size of the input data. For example, [3, 224, 224] denotes input images of size 3x224x224.
        num_experts (list): The numbers of experts shared by all the tasks and specific to each task, respectively. Each expert is an encoder network.

    """
    def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs):
        super(CGC, self).__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)

        self.num_experts = {task: self.kwargs['num_experts'][tn+1] for tn, task in enumerate(self.task_name)}
        self.num_experts['share'] = self.kwargs['num_experts'][0]
        self.experts_specific = nn.ModuleDict({task: nn.ModuleList([encoder_class() for _ in range(self.num_experts[task])]) for task in self.task_name})
        self.gate_specific = nn.ModuleDict({task: nn.Sequential(nn.Linear(self.input_size,
                                                                          self.num_experts['share']+self.num_experts[task]),
                                                                nn.Softmax(dim=-1)) for task in self.task_name})

    def forward(self, inputs, task_name=None):
        experts_shared_rep = torch.stack([e(inputs) for e in self.experts_shared])
        out = {}
        for task in self.task_name:
            if task_name is not None and task != task_name:
                continue
            experts_specific_rep = torch.stack([e(inputs) for e in self.experts_specific[task]])
            selector = self.gate_specific[task](torch.flatten(inputs, start_dim=1))
            gate_rep = torch.einsum('ij..., ji -> j...',
                                    torch.cat([experts_shared_rep, experts_specific_rep], dim=0),
                                    selector)
            gate_rep = self._prepare_rep(gate_rep, task, same_rep=False)
            out[task] = self.decoders[task](gate_rep)
        return out
--------------------------------------------------------------------------------
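The `torch.einsum('ij..., ji -> j...')` call is the heart of the MMoE/CGC gating. A standalone shape check (illustrative, not part of the package) may make it easier to read:

```python
import torch

E, B, D = 3, 4, 8                       # experts, batch size, feature dim
expert_reps = torch.randn(E, B, D)      # stacked expert outputs, as in CGC.forward
selector = torch.softmax(torch.randn(B, E), dim=-1)  # per-sample gate weights

# 'ij..., ji -> j...': for each sample j, sum the expert outputs i weighted
# by that sample's gate weight selector[j, i].
gate_rep = torch.einsum('ij..., ji -> j...', expert_reps, selector)

# Equivalent explicit form:
manual = (expert_reps * selector.t().unsqueeze(-1)).sum(dim=0)
print(gate_rep.shape)                    # torch.Size([4, 8])
print(torch.allclose(gate_rep, manual))  # True
```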
/LibMTL/architecture/Cross_stitch.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.architecture.abstract_arch import AbsArchitecture

class _transform_resnet_cross(nn.Module):
    def __init__(self, encoder_list, task_name, device):
        super(_transform_resnet_cross, self).__init__()

        self.task_name = task_name
        self.task_num = len(task_name)
        self.device = device
        self.resnet_conv = nn.ModuleDict({task: nn.Sequential(encoder_list[tn].conv1, encoder_list[tn].bn1,
                                                              encoder_list[tn].relu, encoder_list[tn].maxpool) for tn, task in enumerate(self.task_name)})
        self.resnet_layer = nn.ModuleDict({})
        for i in range(4):
            self.resnet_layer[str(i)] = nn.ModuleList([])
            for tn in range(self.task_num):
                encoder = encoder_list[tn]
                # fetch encoder.layer1 ... encoder.layer4
                self.resnet_layer[str(i)].append(getattr(encoder, 'layer' + str(i+1)))
        self.cross_unit = nn.Parameter(torch.ones(4, self.task_num, self.task_num))

    def forward(self, inputs):
        s_rep = {task: self.resnet_conv[task](inputs) for task in self.task_name}
        ss_rep = {i: [0]*self.task_num for i in range(4)}
        for i in range(4):
            for tn, task in enumerate(self.task_name):
                if i == 0:
                    ss_rep[i][tn] = self.resnet_layer[str(i)][tn](s_rep[task])
                else:
                    cross_rep = sum([self.cross_unit[i-1][tn][j]*ss_rep[i-1][j] for j in range(self.task_num)])
                    ss_rep[i][tn] = self.resnet_layer[str(i)][tn](cross_rep)
        return ss_rep[3]

class Cross_stitch(AbsArchitecture):
    r"""Cross-stitch Networks (Cross_stitch).

    This method is proposed in `Cross-stitch Networks for Multi-task Learning (CVPR 2016) `_ \
    and implemented by us.

    .. warning::
        - :class:`Cross_stitch` does not work with the multi-input MTL problem, i.e., ``multi_input`` must be ``False``.

        - :class:`Cross_stitch` is only supported by ResNet-based encoders.

    """
    def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs):
        super(Cross_stitch, self).__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)

        if self.multi_input:
            raise ValueError('No support Cross Stitch for multiple inputs MTL problem')

        self.encoder = nn.ModuleList([self.encoder_class() for _ in range(self.task_num)])
        self.encoder = _transform_resnet_cross(self.encoder, task_name, device)
--------------------------------------------------------------------------------
/LibMTL/architecture/HPS.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.architecture.abstract_arch import AbsArchitecture


class HPS(AbsArchitecture):
    r"""Hard Parameter Sharing (HPS).

    This method is proposed in `Multitask Learning: A Knowledge-Based Source of Inductive Bias (ICML 1993) `_ \
    and implemented by us.
    """
    def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs):
        super(HPS, self).__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)
        self.encoder = self.encoder_class()
--------------------------------------------------------------------------------
/LibMTL/architecture/MMoE.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.architecture.abstract_arch import AbsArchitecture

class MMoE(AbsArchitecture):
    r"""Multi-gate Mixture-of-Experts (MMoE).

    This method is proposed in `Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts (KDD 2018) `_ \
    and implemented by us.

    Args:
        img_size (list): The size of the input data. For example, [3, 224, 224] denotes input images of size 3x224x224.
        num_experts (int): The number of experts shared by all tasks. Each expert is an encoder network.

    """
    def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs):
        super(MMoE, self).__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)

        self.img_size = self.kwargs['img_size']
        self.input_size = np.array(self.img_size, dtype=int).prod()
        self.num_experts = self.kwargs['num_experts'][0]
        self.experts_shared = nn.ModuleList([encoder_class() for _ in range(self.num_experts)])
        self.gate_specific = nn.ModuleDict({task: nn.Sequential(nn.Linear(self.input_size, self.num_experts),
                                                                nn.Softmax(dim=-1)) for task in self.task_name})

    def forward(self, inputs, task_name=None):
        experts_shared_rep = torch.stack([e(inputs) for e in self.experts_shared])
        out = {}
        for task in self.task_name:
            if task_name is not None and task != task_name:
                continue
            selector = self.gate_specific[task](torch.flatten(inputs, start_dim=1))
            gate_rep = torch.einsum('ij..., ji -> j...', experts_shared_rep, selector)
            gate_rep = self._prepare_rep(gate_rep, task, same_rep=False)
            out[task] = self.decoders[task](gate_rep)
        return out

    def get_share_params(self):
        return self.experts_shared.parameters()

    def zero_grad_share_params(self):
        self.experts_shared.zero_grad()
--------------------------------------------------------------------------------
/LibMTL/architecture/__init__.py:
--------------------------------------------------------------------------------
from LibMTL.architecture.abstract_arch import AbsArchitecture
from LibMTL.architecture.HPS import HPS
from LibMTL.architecture.Cross_stitch import Cross_stitch
from LibMTL.architecture.MMoE import MMoE
from LibMTL.architecture.MTAN import MTAN
from LibMTL.architecture.CGC import CGC
from LibMTL.architecture.PLE import PLE
from LibMTL.architecture.DSelect_k import DSelect_k
from LibMTL.architecture.LTB import LTB

__all__ = ['AbsArchitecture',
           'HPS',
           'Cross_stitch',
           'MMoE',
           'MTAN',
           'CGC',
           'PLE',
           'DSelect_k',
           'LTB']
--------------------------------------------------------------------------------
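To make the architecture contract concrete, here is a small self-contained sketch that instantiates `HPS` directly with a toy encoder and heads. The toy modules and task names are invented for illustration; in the examples this wiring is normally done by the `Trainer`.

```python
import torch
import torch.nn as nn
from LibMTL.architecture import HPS

# Toy two-task setup: a shared linear encoder and one linear head per task.
tasks = ['segmentation', 'depth']
encoder_class = lambda: nn.Linear(16, 32)
decoders = nn.ModuleDict({task: nn.Linear(32, 1) for task in tasks})

arch = HPS(task_name=tasks, encoder_class=encoder_class, decoders=decoders,
           rep_grad=False, multi_input=False, device=torch.device('cpu'))

out = arch(torch.randn(4, 16))   # forward through the shared encoder and all heads
print({k: v.shape for k, v in out.items()})
# {'segmentation': torch.Size([4, 1]), 'depth': torch.Size([4, 1])}
```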
/LibMTL/architecture/abstract_arch.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class AbsArchitecture(nn.Module):
    r"""An abstract class for MTL architectures.

    Args:
        task_name (list): A list of strings for all tasks.
        encoder_class (class): A neural network class.
        decoders (dict): A dictionary of name-decoder pairs of type (:class:`str`, :class:`torch.nn.Module`).
        rep_grad (bool): If ``True``, the gradient of the representation for each task can be computed.
        multi_input (bool): Is ``True`` if each task has its own input data, otherwise is ``False``.
        device (torch.device): The device where model and data will be allocated.
        kwargs (dict): A dictionary of hyperparameters of architectures.

    """
    def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs):
        super(AbsArchitecture, self).__init__()

        self.task_name = task_name
        self.task_num = len(task_name)
        self.encoder_class = encoder_class
        self.decoders = decoders
        self.rep_grad = rep_grad
        self.multi_input = multi_input
        self.device = device
        self.kwargs = kwargs

        if self.rep_grad:
            self.rep_tasks = {}
            self.rep = {}

    def forward(self, inputs, task_name=None):
        r"""

        Args:
            inputs (torch.Tensor): The input data.
            task_name (str, default=None): The task name corresponding to ``inputs`` if ``multi_input`` is ``True``.

        Returns:
            dict: A dictionary of name-prediction pairs of type (:class:`str`, :class:`torch.Tensor`).
        """
        out = {}
        s_rep = self.encoder(inputs)
        same_rep = True if not isinstance(s_rep, list) and not self.multi_input else False
        for tn, task in enumerate(self.task_name):
            if task_name is not None and task != task_name:
                continue
            ss_rep = s_rep[tn] if isinstance(s_rep, list) else s_rep
            ss_rep = self._prepare_rep(ss_rep, task, same_rep)
            out[task] = self.decoders[task](ss_rep)
        return out

    def get_share_params(self):
        r"""Return the shared parameters of the model.
        """
        return self.encoder.parameters()

    def zero_grad_share_params(self):
        r"""Set gradients of the shared parameters to zero.
        """
        self.encoder.zero_grad()

    def _prepare_rep(self, rep, task, same_rep=None):
        if self.rep_grad:
            if not same_rep:
                self.rep[task] = rep
            else:
                self.rep = rep
            self.rep_tasks[task] = rep.detach().clone()
            self.rep_tasks[task].requires_grad = True
            return self.rep_tasks[task]
        else:
            return rep
--------------------------------------------------------------------------------
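The `_prepare_rep` branch above is what makes `rep_grad=True` work: each task's decoder runs on a detached copy of the shared representation, so per-task gradients can be read off at the representation instead of at the shared parameters. A minimal standalone re-enactment (illustrative, not part of the package):

```python
import torch

rep = torch.randn(4, 8, requires_grad=True)   # stands in for the encoder output

# What _prepare_rep does per task: detach a copy that records its own grad.
rep_tasks = {}
for task in ('A', 'B'):
    rep_tasks[task] = rep.detach().clone()
    rep_tasks[task].requires_grad = True

loss_A = rep_tasks['A'].pow(2).sum()          # toy per-task heads/losses
loss_B = rep_tasks['B'].sin().sum()
loss_A.backward()
loss_B.backward()

# Each detached copy now holds that task's gradient w.r.t. the representation;
# a weighting method can combine them before sending one gradient back.
combined = 0.5 * rep_tasks['A'].grad + 0.5 * rep_tasks['B'].grad
rep.backward(combined)                        # single backward into the encoder
print(rep.grad.shape)                         # torch.Size([4, 8])
```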
/LibMTL/loss.py:
--------------------------------------------------------------------------------
import torch, time
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class AbsLoss(object):
    r"""An abstract class for loss functions.
    """
    def __init__(self):
        self.record = []
        self.bs = []

    def compute_loss(self, pred, gt):
        r"""Calculate the loss.

        Args:
            pred (torch.Tensor): The prediction tensor.
            gt (torch.Tensor): The ground-truth tensor.

        Return:
            torch.Tensor: The loss.
        """
        pass

    def _update_loss(self, pred, gt):
        loss = self.compute_loss(pred, gt)
        self.record.append(loss.item())
        self.bs.append(pred.size()[0])
        return loss

    def _average_loss(self):
        record = np.array(self.record)
        bs = np.array(self.bs)
        return (record*bs).sum()/bs.sum()

    def _reinit(self):
        self.record = []
        self.bs = []

class CELoss(AbsLoss):
    r"""The cross-entropy loss function.
    """
    def __init__(self):
        super(CELoss, self).__init__()

        self.loss_fn = nn.CrossEntropyLoss()

    def compute_loss(self, pred, gt):
        r"""
        """
        loss = self.loss_fn(pred, gt)
        return loss

class KLDivLoss(AbsLoss):
    r"""The Kullback-Leibler divergence loss function.
    """
    def __init__(self):
        super(KLDivLoss, self).__init__()

        self.loss_fn = nn.KLDivLoss()

    def compute_loss(self, pred, gt):
        r"""
        """
        loss = self.loss_fn(pred, gt)
        return loss

class L1Loss(AbsLoss):
    r"""The Mean Absolute Error (MAE) loss function.
    """
    def __init__(self):
        super(L1Loss, self).__init__()

        self.loss_fn = nn.L1Loss()

    def compute_loss(self, pred, gt):
        r"""
        """
        loss = self.loss_fn(pred, gt)
        return loss

class MSELoss(AbsLoss):
    r"""The Mean Squared Error (MSE) loss function.
    """
    def __init__(self):
        super(MSELoss, self).__init__()

        self.loss_fn = nn.MSELoss()

    def compute_loss(self, pred, gt):
        r"""
        """
        loss = self.loss_fn(pred, gt)
        return loss
--------------------------------------------------------------------------------
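One caveat worth flagging for `KLDivLoss` (a usage note, not in the source): `nn.KLDivLoss` expects the prediction as log-probabilities and the target as probabilities, unlike the other losses here, so a caller would typically do something like:

```python
import torch
import torch.nn.functional as F
from LibMTL.loss import KLDivLoss

loss_fn = KLDivLoss()
logits = torch.randn(4, 10)                          # raw model outputs
target = torch.softmax(torch.randn(4, 10), dim=-1)   # a probability distribution

# nn.KLDivLoss wants log-probabilities on the prediction side:
loss = loss_fn.compute_loss(F.log_softmax(logits, dim=-1), target)
print(loss.item())
```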
/LibMTL/metrics.py:
--------------------------------------------------------------------------------
import torch, time
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class AbsMetric(object):
    r"""An abstract class for the performance metrics of a task.

    Attributes:
        record (list): A list of the metric scores in every iteration.
        bs (list): A list of the number of data in every iteration.
    """
    def __init__(self):
        self.record = []
        self.bs = []

    def update_fun(self, pred, gt):
        r"""Calculate the metric scores in every iteration and update :attr:`record`.

        Args:
            pred (torch.Tensor): The prediction tensor.
            gt (torch.Tensor): The ground-truth tensor.
        """
        pass

    def score_fun(self):
        r"""Calculate the final score (when an epoch ends).

        Return:
            list: A list of metric scores.
        """
        pass

    def reinit(self):
        r"""Reset :attr:`record` and :attr:`bs` (when an epoch ends).
        """
        self.record = []
        self.bs = []

# accuracy
class AccMetric(AbsMetric):
    r"""Calculate the accuracy.
    """
    def __init__(self):
        super(AccMetric, self).__init__()

    def update_fun(self, pred, gt):
        r"""
        """
        pred = F.softmax(pred, dim=-1).max(-1)[1]
        self.record.append(gt.eq(pred).sum().item())
        self.bs.append(pred.size()[0])

    def score_fun(self):
        r"""
        """
        return [(sum(self.record)/sum(self.bs))]


# L1 Error
class L1Metric(AbsMetric):
    r"""Calculate the Mean Absolute Error (MAE).
    """
    def __init__(self):
        super(L1Metric, self).__init__()

    def update_fun(self, pred, gt):
        r"""
        """
        abs_err = torch.abs(pred - gt)
        # record the per-batch mean absolute error; calling .item() on the raw
        # elementwise tensor would fail for non-scalar batches
        self.record.append(abs_err.mean().item())
        self.bs.append(pred.size()[0])

    def score_fun(self):
        r"""
        """
        records = np.array(self.record)
        batch_size = np.array(self.bs)
        return [(records*batch_size).sum()/(sum(batch_size))]
--------------------------------------------------------------------------------
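A quick illustration (ours, not part of the package) of the record/score lifecycle these metric classes implement:

```python
import torch
from LibMTL.metrics import AccMetric

metric = AccMetric()

# Two "iterations" of logits and labels.
for _ in range(2):
    logits = torch.randn(8, 5)
    labels = torch.randint(0, 5, (8,))
    metric.update_fun(logits, labels)   # appends per-batch correct counts

print(metric.score_fun())               # [overall accuracy over both batches]
metric.reinit()                         # reset at epoch end
```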
/LibMTL/model/__init__.py:
--------------------------------------------------------------------------------
from LibMTL.model.resnet import resnet18
from LibMTL.model.resnet import resnet34
from LibMTL.model.resnet import resnet50
from LibMTL.model.resnet import resnet101
from LibMTL.model.resnet import resnet152
from LibMTL.model.resnet import resnext50_32x4d
from LibMTL.model.resnet import resnext101_32x8d
from LibMTL.model.resnet import wide_resnet50_2
from LibMTL.model.resnet import wide_resnet101_2
from LibMTL.model.resnet_dilated import resnet_dilated

__all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2', 'resnet_dilated']
--------------------------------------------------------------------------------
/LibMTL/model/resnet_dilated.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import LibMTL.model.resnet as resnet

class ResnetDilated(nn.Module):
    def __init__(self, orig_resnet, dilate_scale=8):
        super(ResnetDilated, self).__init__()
        from functools import partial

        if dilate_scale == 8:
            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))

        # take the pre-defined ResNet, except its AvgPool and FC layers
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu = orig_resnet.relu

        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

        self.feature_dim = orig_resnet.feature_dim

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate//2, dilate//2)
                    m.padding = (dilate//2, dilate//2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

def resnet_dilated(basenet, pretrained=True, dilate_scale=8):
    r"""Dilated Residual Network models from `"Dilated Residual Networks" `_

    Args:
        basenet (str): The type of ResNet.
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
        dilate_scale ({8, 16}, default=8): The type of dilating process.
    """
    return ResnetDilated(resnet.__dict__[basenet](pretrained=pretrained), dilate_scale=dilate_scale)
--------------------------------------------------------------------------------
/LibMTL/utils.py:
--------------------------------------------------------------------------------
import random, torch, os
import numpy as np
import torch.nn as nn

def get_root_dir():
    r"""Return the root path of the project."""
    return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

def set_random_seed(seed):
    r"""Set the random seed for reproducibility.

    Args:
        seed (int, default=0): The random seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def set_device(gpu_id):
    r"""Set the device where model and data will be allocated.

    Args:
        gpu_id (str, default='0'): The id of the gpu.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

def count_parameters(model):
    r'''Calculate the number of parameters of a model.

    Args:
        model (torch.nn.Module): A neural network module.
    '''
    trainable_params = 0
    non_trainable_params = 0
    for p in model.parameters():
        if p.requires_grad:
            trainable_params += p.numel()
        else:
            non_trainable_params += p.numel()
    print('='*40)
    print('Total Params:', trainable_params + non_trainable_params)
    print('Trainable Params:', trainable_params)
    print('Non-trainable Params:', non_trainable_params)

def count_improvement(base_result, new_result, weight):
    r"""Calculate the improvement between two results as

    .. math::
        \Delta_{\mathrm{p}}=100\%\times \frac{1}{T}\sum_{t=1}^T
        \frac{1}{M_t}\sum_{m=1}^{M_t}\frac{(-1)^{w_{t,m}}(B_{t,m}-N_{t,m})}{B_{t,m}}.

    Args:
        base_result (dict): A dictionary of scores of all metrics of all tasks.
        new_result (dict): The same structure as ``base_result``.
        weight (dict): The same structure as ``base_result``, where each element is a binary integer: 1 if a higher score is better for that metric and 0 if a lower score is better.

    Returns:
        float: The improvement of ``new_result`` over ``base_result``.

    Examples::

        base_result = {'A': [96, 98], 'B': [0.2]}
        new_result = {'A': [93, 99], 'B': [0.5]}
        weight = {'A': [1, 0], 'B': [1]}

        print(count_improvement(base_result, new_result, weight))
    """
    improvement = 0
    count = 0
    for task in list(base_result.keys()):
        improvement += (((-1)**np.array(weight[task]))*\
                        (np.array(base_result[task])-np.array(new_result[task]))/\
                        np.array(base_result[task])).mean()
        count += 1
    return improvement/count
--------------------------------------------------------------------------------
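Working through the docstring example of `count_improvement` by hand (the printed value below is our own arithmetic, following the code):

```python
from LibMTL.utils import count_improvement

base_result = {'A': [96, 98], 'B': [0.2]}
new_result = {'A': [93, 99], 'B': [0.5]}
weight = {'A': [1, 0], 'B': [1]}

# Task A: mean of (-1)**1*(96-93)/96 and (-1)**0*(98-99)/98  -> about -0.0207
# Task B: (-1)**1*(0.2-0.5)/0.2                              ->         1.5
print(count_improvement(base_result, new_result, weight))    # about  0.7396
```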
/LibMTL/weighting/AMTL.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

class AMTL(AbsWeighting):
    r"""Achievement-based Multi-task Learning (AMTL).

    This method is proposed in `Achievement-based Training Progress Balancing for Multi-Task Learning (ICCV 2023) `_ \
    and implemented by us.

    """
    def __init__(self):
        super(AMTL, self).__init__()

    def backward(self, losses, **kwargs):
        # Load hyperparameters
        if not hasattr(self, 'potentials'):
            self.potentials = eval(kwargs['potentials_' + kwargs['dataset_str']])
            assert isinstance(self.potentials, list), 'TypeError: type of potentials should be List.'
            self.potentials = torch.Tensor(self.potentials).to(self.device) * 1.05 # multiply a slight margin
            self.focusing_factor = kwargs['focusing_factor']

        # Equal weights at the first epoch (no validation results yet)
        if not hasattr(self, 'val_results'):
            weight = torch.ones_like(losses).to(self.device)

        else:
            # Given validation results (self.model.val_results),
            def get_achievement(cur_results):
                cur_achievement = [1] * self.task_num
                for tn, task in enumerate(self.task_name):
                    for i, weight in enumerate(self.task_dict[task]['weight']): # i: metric number
                        cur_achievement[tn] *= cur_results[task][i] ** (2*weight-1) # (1,0) -> (1,-1)
                    cur_achievement[tn] = cur_achievement[tn] ** (1/len(self.task_dict[task]['weight']))
                return torch.Tensor(cur_achievement).unsqueeze(1).to(self.device)

            cur_achievement = get_achievement(self.val_results)
            weight = torch.pow(1 - cur_achievement / self.potentials, self.focusing_factor)
            weight = torch.softmax(weight, dim=0)

        # loss = torch.pow( torch.pow(losses, weight).prod(), 1./self.task_num) # GM
        loss = torch.mul(losses, weight).sum() # AM
        # p = 2
        # loss = torch.pow( (torch.mul(torch.pow(losses, p), weight)).sum() / self.task_num, 1/p) # QM
        loss.backward()
        batch_weight = losses / (self.task_num * losses.prod())
        return batch_weight.detach().cpu().numpy()
--------------------------------------------------------------------------------
/LibMTL/weighting/Aligned_MTL.py:
--------------------------------------------------------------------------------
import torch, copy
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

class Aligned_MTL(AbsWeighting):
    r"""Aligned-MTL.

    This method is proposed in `Independent Component Alignment for Multi-Task Learning (CVPR 2023) `_ \
    and implemented by modifying from the `official PyTorch implementation `_.

    """
    def __init__(self):
        super(Aligned_MTL, self).__init__()

    def backward(self, losses, **kwargs):

        grads = self._get_grads(losses, mode='backward')
        if self.rep_grad:
            per_grads, grads = grads[0], grads[1]

        M = torch.matmul(grads, grads.t()) # [num_tasks, num_tasks]
        lmbda, V = torch.symeig(M, eigenvectors=True)
        tol = (
            torch.max(lmbda)
            * max(M.shape[-2:])
            * torch.finfo().eps
        )
        rank = sum(lmbda > tol)

        order = torch.argsort(lmbda, dim=-1, descending=True)
        lmbda, V = lmbda[order][:rank], V[:, order][:, :rank]

        sigma = torch.diag(1 / lmbda.sqrt())
        B = lmbda[-1].sqrt() * ((V @ sigma) @ V.t())
        alpha = B.sum(0)

        if self.rep_grad:
            self._backward_new_grads(alpha, per_grads=per_grads)
        else:
            self._backward_new_grads(alpha, grads=grads)
        return alpha.detach().cpu().numpy()
--------------------------------------------------------------------------------
/LibMTL/weighting/Arithmetic.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

class Arithmetic(AbsWeighting):
    r"""Arithmetic Mean (AM) weighting.

    The total loss is the arithmetic mean of the task losses, i.e., the loss weight for each task is always ``1 / T``, where ``T`` denotes the number of tasks.

    """
    def __init__(self):
        super(Arithmetic, self).__init__()

    def backward(self, losses, **kwargs):
        loss = torch.mul(losses, torch.ones_like(losses).to(self.device)).sum() / self.task_num
        loss.backward()
        return np.ones(self.task_num)
--------------------------------------------------------------------------------
/LibMTL/weighting/CAGrad.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

from scipy.optimize import minimize

class CAGrad(AbsWeighting):
    r"""Conflict-Averse Gradient descent (CAGrad).

    This method is proposed in `Conflict-Averse Gradient Descent for Multi-task learning (NeurIPS 2021) `_ \
    and implemented by modifying from the `official PyTorch implementation `_.

    Args:
        calpha (float, default=0.5): A hyperparameter that controls the convergence rate.
        rescale ({0, 1, 2}, default=1): The type of the gradient rescaling.

    .. warning::
        CAGrad does not support representation gradients, i.e., ``rep_grad`` must be ``False``.

    """
    def __init__(self):
        super(CAGrad, self).__init__()

    def backward(self, losses, **kwargs):
        calpha, rescale = kwargs['calpha'], kwargs['rescale']
        if self.rep_grad:
            raise ValueError('No support method CAGrad with representation gradients (rep_grad=True)')
            # per_grads = self._compute_grad(losses, mode='backward', rep_grad=True)
            # grads = per_grads.reshape(self.task_num, self.rep.size()[0], -1).sum(1)
        else:
            self._compute_grad_dim()
            grads = self._compute_grad(losses, mode='backward')

        GG = torch.matmul(grads, grads.t()).cpu() # [num_tasks, num_tasks]
        g0_norm = (GG.mean()+1e-8).sqrt() # norm of the average gradient

        x_start = np.ones(self.task_num) / self.task_num
        bnds = tuple((0,1) for x in x_start)
        cons = ({'type':'eq','fun':lambda x: 1-sum(x)})
        A = GG.numpy()
        b = x_start.copy()
        c = (calpha*g0_norm+1e-8).item()
        def objfn(x):
            return (x.reshape(1,-1).dot(A).dot(b.reshape(-1,1))+c*np.sqrt(x.reshape(1,-1).dot(A).dot(x.reshape(-1,1))+1e-8)).sum()
        res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
        w_cpu = res.x
        ww = torch.Tensor(w_cpu).to(self.device)
        gw = (grads * ww.view(-1, 1)).sum(0)
        gw_norm = gw.norm()
        lmbda = c / (gw_norm+1e-8)
        g = grads.mean(0) + lmbda * gw
        if rescale == 0:
            new_grads = g
        elif rescale == 1:
            new_grads = g / (1+calpha**2)
        elif rescale == 2:
            new_grads = g / (1 + calpha)
        else:
            raise ValueError('No support rescale type {}'.format(rescale))
        self._reset_grad(new_grads)
        return w_cpu
--------------------------------------------------------------------------------
/LibMTL/weighting/DWA.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

class DWA(AbsWeighting):
    r"""Dynamic Weight Average (DWA).

    This method is proposed in `End-To-End Multi-Task Learning With Attention (CVPR 2019) `_ \
    and implemented by modifying from the `official PyTorch implementation `_.

    Args:
        T (float, default=2.0): The softmax temperature.

    """
    def __init__(self):
        super(DWA, self).__init__()

    def backward(self, losses, **kwargs):
        T = kwargs['T']
        if self.epoch > 1:
            w_i = torch.Tensor(self.train_loss_buffer[:,self.epoch-1]/self.train_loss_buffer[:,self.epoch-2]).to(self.device)
            batch_weight = self.task_num*F.softmax(w_i/T, dim=-1)
        else:
            batch_weight = torch.ones_like(losses).to(self.device)
        loss = torch.mul(losses, batch_weight).sum()
        loss.backward()
        return batch_weight.detach().cpu().numpy()
--------------------------------------------------------------------------------
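A quick numeric trace (ours, not part of the package) of the DWA weight update for two tasks with temperature `T = 2`: the task whose loss is shrinking more slowly gets the larger weight.

```python
import torch
import torch.nn.functional as F

T = 2.0
task_num = 2
# Average training losses of the two previous epochs (one entry per task).
prev, prev2 = torch.tensor([0.8, 0.5]), torch.tensor([1.0, 0.5])

w_i = prev / prev2                       # descent rate per task: [0.8, 1.0]
batch_weight = task_num * F.softmax(w_i / T, dim=-1)
print(batch_weight)                      # tensor([0.9500, 1.0500]) approximately
```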
12 | 
13 |     """
14 |     def __init__(self):
15 |         super(EW, self).__init__()
16 | 
17 |     def backward(self, losses, **kwargs):
18 |         loss = torch.mul(losses, torch.ones_like(losses).to(self.device)).sum()
19 |         loss.backward()
20 |         return np.ones(self.task_num)
--------------------------------------------------------------------------------
/LibMTL/weighting/GLS.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | import numpy as np
 5 | 
 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting
 7 | 
 8 | class GLS(AbsWeighting):
 9 |     r"""Geometric Loss Strategy (GLS).
10 | 
11 |     This method is proposed in `MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning (CVPR 2019 workshop) `_ \
12 |     and implemented by us.
13 | 
14 |     """
15 |     def __init__(self):
16 |         super(GLS, self).__init__()
17 | 
18 |     def backward(self, losses, **kwargs):
19 |         loss = torch.pow(losses.prod(), 1./self.task_num)
20 |         loss.backward()
21 |         batch_weight = loss.detach() / (self.task_num * losses) # implicit weight d(loss)/d(l_i) = GM / (T * l_i)
22 |         return batch_weight.detach().cpu().numpy()
--------------------------------------------------------------------------------
/LibMTL/weighting/GeMTL.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | import numpy as np
 5 | 
 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting
 7 | 
 8 | class GeMTL(AbsWeighting):
 9 |     r"""Generalized Mean loss (GeMTL).
10 | 
11 |     This method is proposed in `Achievement-based Training Progress Balancing for Multi-Task Learning (ICCV 2023) `_ \
12 |     and implemented by us.
13 | 
14 |     """
15 |     def __init__(self):
16 |         super(GeMTL, self).__init__()
17 | 
18 |     def backward(self, losses, **kwargs):
19 |         # initialize the GeM exponent once; checking 'p' itself so it is not reset every iteration
20 |         if not hasattr(self, 'p'):
21 |             self.p = -0.5
22 | 
23 |         # anneal p linearly from -0.5 towards 0.5 over the course of training
24 |         self.p += 1 / (self.train_batch * self.epochs)
25 | 
26 |         if abs(self.p) < 0.1:
27 |             loss = torch.pow(losses.prod(), 1./self.task_num) # GM
28 |             batch_weight = loss.detach() / (self.task_num * losses) # implicit weight d(loss)/d(l_i)
29 |         else:
30 |             loss = torch.pow(torch.pow(losses, self.p).sum() / self.task_num, 1/self.p) # GeM
31 |             batch_weight = torch.pow(losses, self.p-1) * torch.pow(loss.detach(), 1-self.p) / self.task_num # implicit weight d(loss)/d(l_i)
32 | 
33 |         loss.backward()
34 |         return batch_weight.detach().cpu().numpy()
--------------------------------------------------------------------------------
/LibMTL/weighting/GradDrop.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | import numpy as np
 5 | 
 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting
 7 | 
 8 | 
 9 | class GradDrop(AbsWeighting):
10 |     r"""Gradient Sign Dropout (GradDrop).
11 | 
12 |     This method is proposed in `Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout (NeurIPS 2020) `_ \
13 |     and implemented by us.
14 | 
15 |     Args:
16 |         leak (float, default=0.0): The leak parameter for the weighting matrix.
17 | 
18 |     .. warning::
19 |             GradDrop is not supported by parameter gradients, i.e., ``rep_grad`` must be ``True``.
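    Example (an illustrative sketch; ``leak`` is the only hyperparameter read from
    ``**kwargs`` in ``backward``)::

        weight_args = {'leak': 0.0}
        # leak=0.0 applies the sampled sign mask strictly;
        # leak=1.0 recovers the unmasked per-task gradients.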
20 | 21 | """ 22 | def __init__(self): 23 | super(GradDrop, self).__init__() 24 | 25 | def backward(self, losses, **kwargs): 26 | leak = kwargs['leak'] 27 | if self.rep_grad: 28 | per_grads = self._compute_grad(losses, mode='backward', rep_grad=True) 29 | else: 30 | raise ValueError('No support method GradDrop with parameter gradients (rep_grad=False)') 31 | 32 | if not isinstance(self.rep, dict): 33 | inputs = self.rep.unsqueeze(0).repeat_interleave(self.task_num, dim=0) 34 | else: 35 | try: 36 | inputs = torch.stack(list(self.rep.values())) 37 | per_grads = torch.stack(per_grads) 38 | except: 39 | raise ValueError('The representation dimensions of different tasks must be consistent') 40 | grads = (per_grads*inputs.sign()).sum(1) 41 | P = 0.5 * (1 + grads.sum(0) / (grads.abs().sum(0)+1e-7)) 42 | U = torch.rand_like(P) 43 | M = P.gt(U).unsqueeze(0).repeat_interleave(self.task_num, dim=0)*grads.gt(0) + \ 44 | P.lt(U).unsqueeze(0).repeat_interleave(self.task_num, dim=0)*grads.lt(0) 45 | M = M.unsqueeze(1).repeat_interleave(per_grads.size()[1], dim=1) 46 | transformed_grad = (per_grads*(leak+(1-leak)*M)) 47 | 48 | if not isinstance(self.rep, dict): 49 | self.rep.backward(transformed_grad.sum(0)) 50 | else: 51 | for tn, task in enumerate(self.task_name): 52 | self.rep[task].backward(transformed_grad[tn], retain_graph=True) 53 | return None -------------------------------------------------------------------------------- /LibMTL/weighting/GradNorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting 7 | 8 | class GradNorm(AbsWeighting): 9 | r"""Gradient Normalization (GradNorm). 10 | 11 | This method is proposed in `GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks (ICML 2018) `_ \ 12 | and implemented by us. 13 | 14 | Args: 15 | alpha (float, default=1.5): The strength of the restoring force which pulls tasks back to a common training rate. 
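    Example (an illustrative sketch; ``alpha`` is read from ``**kwargs`` in ``backward``)::

        weight_args = {'alpha': 1.5}
        # a larger alpha exerts a stronger restoring force, pulling tasks
        # with diverging loss ratios back to a common training rate.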
16 | 17 | """ 18 | def __init__(self): 19 | super(GradNorm, self).__init__() 20 | 21 | def init_param(self): 22 | self.loss_scale = nn.Parameter(torch.tensor([1.0]*self.task_num, device=self.device)) 23 | 24 | def backward(self, losses, **kwargs): 25 | alpha = kwargs['alpha'] 26 | if self.epoch >= 1: 27 | loss_scale = self.task_num * F.softmax(self.loss_scale, dim=-1) 28 | grads = self._get_grads(losses, mode='backward') 29 | if self.rep_grad: 30 | per_grads, grads = grads[0], grads[1] 31 | 32 | G_per_loss = torch.norm(loss_scale.unsqueeze(1)*grads, p=2, dim=-1) 33 | G = G_per_loss.mean(0) 34 | L_i = torch.Tensor([losses[tn].item()/self.train_loss_buffer[tn, 0] for tn in range(self.task_num)]).to(self.device) 35 | r_i = L_i/L_i.mean() 36 | constant_term = (G*(r_i**alpha)).detach() 37 | L_grad = (G_per_loss-constant_term).abs().sum(0) 38 | L_grad.backward() 39 | loss_weight = loss_scale.detach().clone() 40 | 41 | if self.rep_grad: 42 | self._backward_new_grads(loss_weight, per_grads=per_grads) 43 | else: 44 | self._backward_new_grads(loss_weight, grads=grads) 45 | return loss_weight.cpu().numpy() 46 | else: 47 | loss = torch.mul(losses, torch.ones_like(losses).to(self.device)).sum() 48 | loss.backward() 49 | return np.ones(self.task_num) 50 | -------------------------------------------------------------------------------- /LibMTL/weighting/IMTL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting 7 | 8 | 9 | class IMTL(AbsWeighting): 10 | r"""Impartial Multi-task Learning (IMTL). 11 | 12 | This method is proposed in `Towards Impartial Multi-task Learning (ICLR 2021) `_ \ 13 | and implemented by us. 14 | 15 | """ 16 | def __init__(self): 17 | super(IMTL, self).__init__() 18 | 19 | def init_param(self): 20 | self.loss_scale = nn.Parameter(torch.tensor([0.0]*self.task_num, device=self.device)) 21 | 22 | def backward(self, losses, **kwargs): 23 | losses = self.loss_scale.exp()*losses - self.loss_scale 24 | grads = self._get_grads(losses, mode='backward') 25 | if self.rep_grad: 26 | per_grads, grads = grads[0], grads[1] 27 | 28 | grads_unit = grads/torch.norm(grads, p=2, dim=-1, keepdim=True) 29 | 30 | D = grads[0:1].repeat(self.task_num-1, 1) - grads[1:] 31 | U = grads_unit[0:1].repeat(self.task_num-1, 1) - grads_unit[1:] 32 | 33 | alpha = torch.matmul(torch.matmul(grads[0], U.t()), torch.inverse(torch.matmul(D, U.t()))) 34 | alpha = torch.cat((1-alpha.sum().unsqueeze(0), alpha), dim=0) 35 | 36 | if self.rep_grad: 37 | self._backward_new_grads(alpha, per_grads=per_grads) 38 | else: 39 | self._backward_new_grads(alpha, grads=grads) 40 | return alpha.detach().cpu().numpy() 41 | -------------------------------------------------------------------------------- /LibMTL/weighting/IMTL_G.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting 7 | 8 | 9 | class IMTL_G(AbsWeighting): 10 | r"""Impartial Multi-task Learning (IMTL). 11 | 12 | This method is proposed in `Towards Impartial Multi-task Learning (ICLR 2021) `_ \ 13 | and implemented by us. 
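    This variant (IMTL-G) keeps only the gradient balance of IMTL: the learnable
    loss-scale transform is left disabled in ``backward``, and the weights are
    obtained in closed form as

    .. math:: \boldsymbol{\alpha}_{2:T} = \mathbf{g}_1 U^\top (D U^\top)^{-1},
              \qquad \alpha_1 = 1 - \textstyle\sum_{t=2}^{T} \alpha_t,

    where the rows of :math:`D` and :math:`U` are the differences between the
    first task's raw (resp. unit-norm) gradient and those of the remaining tasks.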
14 | 
15 |     """
16 |     def __init__(self):
17 |         super(IMTL_G, self).__init__()
18 | 
19 |     def init_param(self):
20 |         self.loss_scale = nn.Parameter(torch.tensor([0.0]*self.task_num, device=self.device))
21 | 
22 |     def backward(self, losses, **kwargs):
23 |         # IMTL-G keeps only the gradient balance of IMTL, so the loss-scale transform stays disabled:
24 |         # losses = self.loss_scale.exp()*losses - self.loss_scale
25 |         grads = self._get_grads(losses, mode='backward')
26 |         if self.rep_grad:
27 |             per_grads, grads = grads[0], grads[1]
28 | 
29 |         grads_unit = grads/torch.norm(grads, p=2, dim=-1, keepdim=True)
30 | 
31 |         D = grads[0:1].repeat(self.task_num-1, 1) - grads[1:]
32 |         U = grads_unit[0:1].repeat(self.task_num-1, 1) - grads_unit[1:]
33 | 
34 |         alpha = torch.matmul(torch.matmul(grads[0], U.t()), torch.inverse(torch.matmul(D, U.t())))
35 |         alpha = torch.cat((1-alpha.sum().unsqueeze(0), alpha), dim=0)
36 | 
37 |         if self.rep_grad:
38 |             self._backward_new_grads(alpha, per_grads=per_grads)
39 |         else:
40 |             self._backward_new_grads(alpha, grads=grads)
41 |         return alpha.detach().cpu().numpy()
42 | 
--------------------------------------------------------------------------------
/LibMTL/weighting/IMTL_L.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | import numpy as np
 5 | 
 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting
 7 | 
 8 | 
 9 | class IMTL_L(AbsWeighting):
10 |     r"""Impartial Multi-task Learning - Loss balance (IMTL-L).
11 | 
12 |     This method is proposed in `Towards Impartial Multi-task Learning (ICLR 2021) `_ \
13 |     and implemented by us; this variant keeps only the learnable loss-scale transform of IMTL.
14 | 
15 |     """
16 |     def __init__(self):
17 |         super(IMTL_L, self).__init__()
18 | 
19 |     def init_param(self):
20 |         self.loss_scale = nn.Parameter(torch.tensor([0.0]*self.task_num, device=self.device))
21 | 
22 |     def backward(self, losses, **kwargs):
23 |         losses = self.loss_scale.exp()*losses - self.loss_scale
24 |         loss = torch.sum(losses)
25 |         loss.backward()
26 |         return self.loss_scale.exp().detach().cpu().numpy()
27 | 
--------------------------------------------------------------------------------
/LibMTL/weighting/LSBwD.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | import numpy as np
 5 | 
 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting
 7 | 
 8 | '''
 9 | From https://github.com/hw-ch0/IPMTL/blob/35009698edfcbe2893c04a1738505e60a62be7c5/im2im_pred/utils.py
10 | 
11 | if index==0:
12 |     # w_semantic, w_depth, w_normal = 1/3, 1/3, 1/3
13 |     weights[index,:] = 1/3, 1/3, 1/3
14 | else:
15 |     loss_prev = weights[index-1,0]*avg_cost[index-1,0] + weights[index-1,1]*avg_cost[index-1,3] + weights[index-1,2]*avg_cost[index-1,6]
16 |     weights[index,:] = (loss_prev/avg_cost[index-1,0])/3, (loss_prev/avg_cost[index-1,3])/3, (loss_prev/avg_cost[index-1,6])/3
17 |     if not index==1:
18 |         loss_prev2 = weights[index-2,0]*avg_cost[index-2,0] + weights[index-2,1]*avg_cost[index-2,3] + weights[index-2,2]*avg_cost[index-2,6]
19 |         difficulties[index,0] = (avg_cost[index-1,0]/avg_cost[index-2,0]) / (loss_prev/loss_prev2)
20 |         difficulties[index,1] = (avg_cost[index-1,3]/avg_cost[index-2,3]) / (loss_prev/loss_prev2)
21 |         difficulties[index,2] = (avg_cost[index-1,6]/avg_cost[index-2,6]) / (loss_prev/loss_prev2)
22 | '''
23 | 
24 | 
25 | class LSBwD(AbsWeighting):
26 |     r"""Loss Scale Balancing with Difficulty (LSBwD).
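    Once per period (one training epoch here), the weight of each task is reset to
    ``w_i = avg_loss / (T * avg_loss_i)``, i.e., inversely proportional to the
    task's average loss over the period; the weights are additionally modulated by
    per-task difficulty terms (the "wD" in the class name), whose exponent ``beta``
    is annealed over the epochs, and rescaled by ``alpha`` so that the modulation
    keeps the total weight mass at ``T``.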
27 | 
28 |     """
29 |     def __init__(self):
30 |         super(LSBwD, self).__init__()
31 | 
32 |     def backward(self, losses, **kwargs):
33 |         if not hasattr(self, 'prev_weight'):
34 |             self.prev_weight = torch.ones_like(losses).detach() / self.task_num
35 |             self.loss_cache = 0
36 |             self.losses_cache = 0
37 |             self.iter = 0
38 |             self.beta = - (1.0 / self.epochs)
39 | 
40 |         if not hasattr(self, 'prev2_weight'):
41 |             loss = torch.mul(losses, self.prev_weight).sum()
42 |         else:
43 |             loss = torch.mul(losses, self.alpha * self.difficulties * self.prev_weight).sum()
44 | 
45 |         self.loss_cache += loss.detach() / self.train_batch
46 |         self.losses_cache += losses.detach() / self.train_batch
47 |         self.iter += 1
48 |         if (self.iter+1) % self.train_batch==0: # epoch == period
49 |             if (self.iter+1) > self.train_batch:
50 |                 self.prev2_weight = self.prev_weight
51 |                 temp_prev_weight = self.loss_cache / (self.losses_cache * self.task_num)
52 |                 self.beta += 1.0 / self.epochs
53 |                 self.difficulties = torch.pow((temp_prev_weight/self.prev2_weight) / (self.losses_cache/self.losses_cache_prev2), self.beta)
54 |                 self.alpha = self.task_num / sum(self.difficulties)
55 |             self.prev_weight = self.loss_cache / (self.losses_cache * self.task_num)
56 |             self.losses_cache_prev2 = self.losses_cache # cache this period's average losses for the next difficulty update
57 |             self.loss_cache = 0
58 |             self.losses_cache = 0
59 | 
60 |         loss.backward()
61 |         return self.prev_weight.detach().cpu().numpy()
--------------------------------------------------------------------------------
/LibMTL/weighting/LSBwoD.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | import numpy as np
 5 | 
 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting
 7 | 
 8 | '''
 9 | From https://github.com/hw-ch0/IPMTL/blob/35009698edfcbe2893c04a1738505e60a62be7c5/im2im_pred/utils.py
10 | 
11 | if index==0:
12 |     # w_semantic, w_depth, w_normal = 1/3, 1/3, 1/3
13 |     weights[index,:] = 1/3, 1/3, 1/3
14 | else:
15 |     loss_prev = weights[index-1,0]*avg_cost[index-1,0] + weights[index-1,1]*avg_cost[index-1,3] + weights[index-1,2]*avg_cost[index-1,6]
16 |     weights[index,:] = (loss_prev/avg_cost[index-1,0])/3, (loss_prev/avg_cost[index-1,3])/3, (loss_prev/avg_cost[index-1,6])/3
17 |     if not index==1:
18 |         loss_prev2 = weights[index-2,0]*avg_cost[index-2,0] + weights[index-2,1]*avg_cost[index-2,3] + weights[index-2,2]*avg_cost[index-2,6]
19 |         difficulties[index,0] = (avg_cost[index-1,0]/avg_cost[index-2,0]) / (loss_prev/loss_prev2)
20 |         difficulties[index,1] = (avg_cost[index-1,3]/avg_cost[index-2,3]) / (loss_prev/loss_prev2)
21 |         difficulties[index,2] = (avg_cost[index-1,6]/avg_cost[index-2,6]) / (loss_prev/loss_prev2)
22 | '''
23 | 
24 | class LSBwoD(AbsWeighting):
25 |     r"""Loss Scale Balancing without Difficulty (LSBwoD).
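    The plain variant of ``LSBwD`` above: the same period-wise reweighting
    ``w_i = avg_loss / (T * avg_loss_i)`` is applied, but without the
    difficulty-based modulation.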
26 | 27 | """ 28 | def __init__(self): 29 | super(LSBwoD, self).__init__() 30 | 31 | def backward(self, losses, **kwargs): 32 | if not hasattr(self, 'prev_weight'): 33 | self.prev_weight = torch.ones_like(losses).detach() / self.task_num 34 | self.loss_cache = 0 35 | self.losses_cache = 0 36 | self.iter = 0 37 | 38 | loss = torch.mul(losses, self.prev_weight).sum() 39 | self.loss_cache += loss.detach() / self.train_batch 40 | self.losses_cache += losses.detach() / self.train_batch 41 | self.iter += 1 42 | if (self.iter+1) % self.train_batch==0: 43 | self.prev_weight = self.loss_cache / (self.losses_cache * self.task_num) 44 | self.loss_cache = 0 45 | self.losses_cache = 0 46 | 47 | loss.backward() 48 | return self.prev_weight.detach().cpu().numpy() -------------------------------------------------------------------------------- /LibMTL/weighting/MoCo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting 7 | 8 | class MoCo(AbsWeighting): 9 | r"""MoCo. 10 | 11 | This method is proposed in `Mitigating Gradient Bias in Multi-objective Learning: A Provably Convergent Approach (ICLR 2023) `_ \ 12 | and implemented based on the author' sharing code (Heshan Fernando: fernah@rpi.edu). 13 | 14 | Args: 15 | MoCo_beta (float, default=0.5): The learning rate of y. 16 | MoCo_beta_sigma (float, default=0.5): The decay rate of MoCo_beta. 17 | MoCo_gamma (float, default=0.1): The learning rate of lambd. 18 | MoCo_gamma_sigma (float, default=0.5): The decay rate of MoCo_gamma. 19 | MoCo_rho (float, default=0): The \ell_2 regularization parameter of lambda's update. 20 | 21 | .. warning:: 22 | MoCo is not supported by representation gradients, i.e., ``rep_grad`` must be ``False``. 
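    Example (an illustrative sketch; these are exactly the keyword arguments read
    from ``**kwargs`` in ``backward``, shown with their documented defaults)::

        weight_args = {'MoCo_beta': 0.5, 'MoCo_beta_sigma': 0.5,
                       'MoCo_gamma': 0.1, 'MoCo_gamma_sigma': 0.5,
                       'MoCo_rho': 0}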
23 | 24 | """ 25 | def __init__(self): 26 | super(MoCo, self).__init__() 27 | 28 | def init_param(self): 29 | self._compute_grad_dim() 30 | self.step = 0 31 | self.y = torch.zeros(self.task_num, self.grad_dim).to(self.device) 32 | self.lambd = (torch.ones([self.task_num, ]) / self.task_num).to(self.device) 33 | 34 | def backward(self, losses, **kwargs): 35 | self.step += 1 36 | beta, beta_sigma = kwargs['MoCo_beta'], kwargs['MoCo_beta_sigma'] 37 | gamma, gamma_sigma = kwargs['MoCo_gamma'], kwargs['MoCo_gamma_sigma'] 38 | rho = kwargs['MoCo_rho'] 39 | 40 | if self.rep_grad: 41 | raise ValueError('No support method MoCo with representation gradients (rep_grad=True)') 42 | else: 43 | self._compute_grad_dim() 44 | grads = self._compute_grad(losses, mode='backward') 45 | 46 | with torch.no_grad(): 47 | for tn in range(self.task_num): 48 | grads[tn] = grads[tn]/(grads[tn].norm()+1e-8)*losses[tn] 49 | self.y = self.y - (beta/self.step**beta_sigma) * (self.y - grads) 50 | self.lambd = F.softmax(self.lambd - (gamma/self.step**gamma_sigma) * (self.y@self.y.t()+rho*torch.eye(self.task_num).to(self.device))@self.lambd, -1) 51 | new_grads = self.y.t()@self.lambd 52 | 53 | self._reset_grad(new_grads) 54 | return self.lambd.detach().cpu().numpy() 55 | -------------------------------------------------------------------------------- /LibMTL/weighting/PCGrad.py: -------------------------------------------------------------------------------- 1 | import torch, random 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting 7 | 8 | class PCGrad(AbsWeighting): 9 | r"""Project Conflicting Gradients (PCGrad). 10 | 11 | This method is proposed in `Gradient Surgery for Multi-Task Learning (NeurIPS 2020) `_ \ 12 | and implemented by us. 13 | 14 | .. warning:: 15 | PCGrad is not supported by representation gradients, i.e., ``rep_grad`` must be ``False``. 16 | 17 | """ 18 | def __init__(self): 19 | super(PCGrad, self).__init__() 20 | 21 | def backward(self, losses, **kwargs): 22 | batch_weight = np.ones(len(losses)) 23 | if self.rep_grad: 24 | raise ValueError('No support method PCGrad with representation gradients (rep_grad=True)') 25 | else: 26 | self._compute_grad_dim() 27 | grads = self._compute_grad(losses, mode='backward') # [task_num, grad_dim] 28 | pc_grads = grads.clone() 29 | for tn_i in range(self.task_num): 30 | task_index = list(range(self.task_num)) 31 | random.shuffle(task_index) 32 | for tn_j in task_index: 33 | g_ij = torch.dot(pc_grads[tn_i], grads[tn_j]) 34 | if g_ij < 0: 35 | pc_grads[tn_i] -= g_ij * grads[tn_j] / (grads[tn_j].norm().pow(2)+1e-8) 36 | batch_weight[tn_j] -= (g_ij/(grads[tn_j].norm().pow(2)+1e-8)).item() 37 | new_grads = pc_grads.sum(0) 38 | self._reset_grad(new_grads) 39 | return batch_weight 40 | -------------------------------------------------------------------------------- /LibMTL/weighting/RLW.py: -------------------------------------------------------------------------------- 1 | import torch, random 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting 7 | 8 | 9 | class RLW(AbsWeighting): 10 | r"""Random Loss Weighting (RLW). 11 | 12 | This method is proposed in `Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning (TMLR 2022) `_ \ 13 | and implemented by us. 
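    Example (a minimal sketch of the sampling step performed in ``backward``)::

        # a fresh weight vector is drawn at every iteration: a standard
        # normal sample pushed through a softmax
        batch_weight = F.softmax(torch.randn(task_num), dim=-1)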
14 | 15 | """ 16 | def __init__(self): 17 | super(RLW, self).__init__() 18 | 19 | def backward(self, losses, **kwargs): 20 | batch_weight = F.softmax(torch.randn(self.task_num), dim=-1).to(self.device) 21 | loss = torch.mul(losses, batch_weight).sum() 22 | loss.backward() 23 | return batch_weight.detach().cpu().numpy() 24 | -------------------------------------------------------------------------------- /LibMTL/weighting/SI.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting 7 | 8 | class SI(AbsWeighting): 9 | r"""Scale Invariant (SI). 10 | 11 | """ 12 | def __init__(self): 13 | super(SI, self).__init__() 14 | 15 | def backward(self, losses, **kwargs): 16 | loss = torch.log(losses).sum() 17 | loss.backward() 18 | return np.ones(self.task_num) -------------------------------------------------------------------------------- /LibMTL/weighting/SI_naive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting 7 | 8 | class SI_naive(AbsWeighting): 9 | r"""Scale Invariant (SI) - naive version. 10 | 11 | """ 12 | def __init__(self): 13 | super(SI_naive, self).__init__() 14 | 15 | def backward(self, losses, **kwargs): 16 | loss = torch.mul(losses, (1/losses.detach())).sum() 17 | loss.backward() 18 | return np.ones(self.task_num) -------------------------------------------------------------------------------- /LibMTL/weighting/UW.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting 7 | 8 | class UW(AbsWeighting): 9 | r"""Uncertainty Weights (UW). 10 | 11 | This method is proposed in `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (CVPR 2018) `_ \ 12 | and implemented by us. 
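    With ``loss_scale`` :math:`= s_t = \log \sigma_t^2`, the optimized objective is

    .. math:: \sum_t \frac{\mathcal{L}_t}{2 e^{s_t}} + \frac{s_t}{2},

    so the effective loss weight of task :math:`t` is :math:`1 / (2 \sigma_t^2)`,
    which is also the value returned by ``backward``.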
13 | 14 | """ 15 | def __init__(self): 16 | super(UW, self).__init__() 17 | 18 | def init_param(self): 19 | self.loss_scale = nn.Parameter(torch.tensor([-0.5]*self.task_num, device=self.device)) 20 | 21 | def backward(self, losses, **kwargs): 22 | loss = (losses/(2*self.loss_scale.exp())+self.loss_scale/2).sum() 23 | loss.backward() 24 | return (1/(2*torch.exp(self.loss_scale))).detach().cpu().numpy() 25 | -------------------------------------------------------------------------------- /LibMTL/weighting/__init__.py: -------------------------------------------------------------------------------- 1 | from LibMTL.weighting.abstract_weighting import AbsWeighting 2 | from LibMTL.weighting.EW import EW 3 | from LibMTL.weighting.GradNorm import GradNorm 4 | from LibMTL.weighting.MGDA import MGDA 5 | from LibMTL.weighting.UW import UW 6 | from LibMTL.weighting.DWA import DWA 7 | from LibMTL.weighting.GLS import GLS 8 | from LibMTL.weighting.Arithmetic import Arithmetic 9 | from LibMTL.weighting.GradDrop import GradDrop 10 | from LibMTL.weighting.PCGrad import PCGrad 11 | from LibMTL.weighting.GradVac import GradVac 12 | from LibMTL.weighting.IMTL import IMTL 13 | from LibMTL.weighting.IMTL_L import IMTL_L 14 | from LibMTL.weighting.IMTL_G import IMTL_G 15 | from LibMTL.weighting.LSBwD import LSBwD 16 | from LibMTL.weighting.LSBwoD import LSBwoD 17 | from LibMTL.weighting.CAGrad import CAGrad 18 | from LibMTL.weighting.Nash_MTL import Nash_MTL 19 | from LibMTL.weighting.RLW import RLW 20 | from LibMTL.weighting.MoCo import MoCo 21 | from LibMTL.weighting.Aligned_MTL import Aligned_MTL 22 | from LibMTL.weighting.SI import SI 23 | from LibMTL.weighting.SI_naive import SI_naive 24 | from LibMTL.weighting.AMTL import AMTL 25 | from LibMTL.weighting.GeMTL import GeMTL 26 | 27 | __all__ = ['AbsWeighting', 28 | 'EW', 29 | 'GradNorm', 30 | 'MGDA', 31 | 'UW', 32 | 'DWA', 33 | 'GLS', 34 | 'Arithmetic', 35 | 'GradDrop', 36 | 'PCGrad', 37 | 'GradVac', 38 | 'IMTL', 39 | 'IMTL_L', 40 | 'IMTL_G', 41 | 'LSBwD', 42 | 'LSBwoD', 43 | 'CAGrad', 44 | 'Nash_MTL', 45 | 'RLW', 46 | 'MoCo', 47 | 'Aligned_MTL', 48 | 'SI', 49 | 'SI_naive', 50 | 'AMTL', 51 | 'GeMTL', 52 | ] -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/AMTL.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/AMTL.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/AMTL_GeM.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/AMTL_GeM.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/AMTL_GeM_anti.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/AMTL_GeM_anti.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/AMTL_GeM_curri.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/AMTL_GeM_curri.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/AMTL_SI.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/AMTL_SI.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/Aligned_MTL.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/Aligned_MTL.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/Arithmetic.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/Arithmetic.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/CAGrad.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/CAGrad.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/DWA.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/DWA.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/EW.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/EW.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GLS.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GLS.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GeMTL.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GeMTL.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GeM_anti.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GeM_anti.cpython-38.pyc 
-------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GeM_curri.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GeM_curri.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GeMopt0.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GeMopt0.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GeMopt1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GeMopt1.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GeMopt10.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GeMopt10.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GeMopt11.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GeMopt11.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GeMopt2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GeMopt2.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GeMopt3.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GeMopt3.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GeMopt4.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GeMopt4.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GradDrop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GradDrop.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GradNorm.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GradNorm.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/GradVac.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/GradVac.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/IMTL.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/IMTL.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/IMTL_G.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/IMTL_G.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/IMTL_L.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/IMTL_L.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/LSBwD.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/LSBwD.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/LSBwoD.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/LSBwoD.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/MBMTL_AM.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/MBMTL_AM.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/MBMTL_AM_10ep.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/MBMTL_AM_10ep.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/MGDA.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/MGDA.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/MoCo.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/MoCo.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/Nash_MTL.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/Nash_MTL.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/PCGrad.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/PCGrad.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/RLW.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/RLW.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/SI.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/SI.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/SI_naive.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/SI_naive.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/UW.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/UW.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /LibMTL/weighting/__pycache__/abstract_weighting.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/LibMTL/weighting/__pycache__/abstract_weighting.cpython-38.pyc 
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Documents of LibMTL
2 | 
3 | This document is based on [Sphinx](http://sphinx-doc.org/) and uses [Read the Docs](https://readthedocs.org/) for deployment. Besides, it is implemented by referring to https://github.com/rlworkgroup/garage/tree/master/docs and https://github.com/LibCity/Bigscity-LibCity-Docs.
4 | 
--------------------------------------------------------------------------------
/docs/_build/doctrees/README.doctree:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/README.doctree
--------------------------------------------------------------------------------
/docs/_build/doctrees/autoapi_templates/python/module.doctree:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/autoapi_templates/python/module.doctree
--------------------------------------------------------------------------------
/docs/_build/doctrees/docs/_autoapi/LibMTL/_record/index.doctree:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/_record/index.doctree
--------------------------------------------------------------------------------
/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/CGC/index.doctree:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/CGC/index.doctree
--------------------------------------------------------------------------------
/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/Cross_stitch/index.doctree:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/Cross_stitch/index.doctree
--------------------------------------------------------------------------------
/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/DSelect_k/index.doctree:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/DSelect_k/index.doctree
--------------------------------------------------------------------------------
/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/HPS/index.doctree:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/HPS/index.doctree
--------------------------------------------------------------------------------
/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/MMoE/index.doctree:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/MMoE/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/MTAN/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/MTAN/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/PLE/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/PLE/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/abstract_arch/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/abstract_arch/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/config/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/config/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/loss/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/loss/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/metrics/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/metrics/index.doctree 
-------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/model/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/model/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/model/resnet/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/model/resnet/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/model/resnet_dilated/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/model/resnet_dilated/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/trainer/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/trainer/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/utils/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/utils/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/CAGrad/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/CAGrad/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/DWA/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/DWA/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/EW/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/EW/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GLS/index.doctree: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GLS/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GradDrop/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GradDrop/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GradNorm/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GradNorm/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GradVac/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GradVac/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/IMTL/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/IMTL/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/MGDA/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/MGDA/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/PCGrad/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/PCGrad/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/RLW/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/RLW/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/UW/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/UW/index.doctree 
-------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/abstract_weighting/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/abstract_weighting/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/develop/arch.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/develop/arch.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/develop/dataset.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/develop/dataset.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/develop/weighting.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/develop/weighting.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/getting_started/installation.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/getting_started/installation.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/getting_started/introduction.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/getting_started/introduction.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/getting_started/quick_start.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/getting_started/quick_start.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/user_guide/benchmark.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/user_guide/benchmark.doctree 
-------------------------------------------------------------------------------- /docs/_build/doctrees/docs/user_guide/benchmark/nyuv2.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/user_guide/benchmark/nyuv2.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/user_guide/benchmark/office.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/user_guide/benchmark/office.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/user_guide/framework.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/user_guide/framework.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/user_guide/mtl.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/docs/user_guide/mtl.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/environment.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/environment.pickle -------------------------------------------------------------------------------- /docs/_build/doctrees/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/doctrees/index.doctree -------------------------------------------------------------------------------- /docs/_build/html/.buildinfo: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 
3 | config: 979045b8d53d6dcd1826ddb3a4fcb161 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /docs/_build/html/_images/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_images/framework.png -------------------------------------------------------------------------------- /docs/_build/html/_images/multi_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_images/multi_input.png -------------------------------------------------------------------------------- /docs/_build/html/_images/rep_grad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_images/rep_grad.png -------------------------------------------------------------------------------- /docs/_build/html/_sources/README.md.txt: -------------------------------------------------------------------------------- 1 | # Documents of LibMTL 2 | 3 | This document is based on [Sphinx](http://sphinx-doc.org/) and uses [Read the Docs](https://readthedocs.org/) for deployment. Besides, it is implemented by referring to https://github.com/rlworkgroup/garage/tree/master/docs and https://github.com/LibCity/Bigscity-LibCity-Docs. -------------------------------------------------------------------------------- /docs/_build/html/_sources/autoapi_templates/python/module.rst.txt: -------------------------------------------------------------------------------- 1 | {% if not obj.display %} 2 | :orphan: 3 | 4 | {% endif %} 5 | :mod:`{{ obj.name }}` 6 | ======={{ "=" * obj.name|length }} 7 | 8 | .. py:module:: {{ obj.name }} 9 | 10 | {% if obj.docstring %} 11 | .. autoapi-nested-parse:: 12 | 13 | {{ obj.docstring|prepare_docstring|indent(3) }} 14 | 15 | {% endif %} 16 | 17 | {% block content %} 18 | {% if obj.all is not none %} 19 | {% set visible_children = obj.children|selectattr("short_name", "in", obj.all)|list %} 20 | {% elif obj.type is equalto("package") %} 21 | {% set visible_children = obj.children|selectattr("display")|list %} 22 | {% else %} 23 | {% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %} 24 | {% endif %} 25 | {% if visible_children %} 26 | {# {{ obj.type|title }} Contents 27 | {{ "-" * obj.type|length }}--------- #} 28 | 29 | {% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %} 30 | {% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %} 31 | {% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %} 32 | 33 | {# {% block classes scoped %} 34 | {% if visible_classes %} 35 | Classes 36 | ------- 37 | 38 | .. autoapisummary:: 39 | 40 | {% for klass in visible_classes %} 41 | {{ klass.id }} 42 | {% endfor %} 43 | 44 | 45 | {% endif %} 46 | {% endblock %} #} 47 | 48 | {# {% block functions scoped %} 49 | {% if visible_functions %} 50 | Functions 51 | --------- 52 | 53 | .. 
autoapisummary:: 54 | 55 | {% for function in visible_functions %} 56 | {{ function.id }} 57 | {% endfor %} 58 | 59 | 60 | {% endif %} 61 | {% endblock %} #} 62 | 63 | {% endif %} 64 | {% for obj_item in visible_children %} 65 | {{ obj_item.rendered|indent(0) }} 66 | {% endfor %} 67 | {% endif %} 68 | {% endblock %} 69 | 70 | {% block subpackages %} 71 | {% set visible_subpackages = obj.subpackages|selectattr("display")|list %} 72 | {% if visible_subpackages %} 73 | {# Subpackages 74 | ----------- #} 75 | .. toctree:: 76 | :titlesonly: 77 | :maxdepth: 1 78 | 79 | {% for subpackage in visible_subpackages %} 80 | {{ subpackage.short_name }}/index.rst 81 | {% endfor %} 82 | 83 | 84 | {% endif %} 85 | {% endblock %} 86 | {# {% block submodules %} 87 | {% set visible_submodules = obj.submodules|selectattr("display")|list %} 88 | {% if visible_submodules %} 89 | Submodules 90 | ---------- 91 | .. toctree:: 92 | :titlesonly: 93 | :maxdepth: 1 94 | 95 | {% for submodule in visible_submodules %} 96 | {{ submodule.short_name }}/index.rst 97 | {% endfor %} 98 | 99 | 100 | {% endif %} 101 | {% endblock %} #} 102 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/_record/index.rst.txt: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | :mod:`LibMTL._record` 4 | ===================== 5 | 6 | .. py:module:: LibMTL._record 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/CGC/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.CGC` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.architecture.CGC 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: CGC(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.MMoE.MMoE` 14 | 15 | Customized Gate Control (CGC). 16 | 17 | This method is proposed in `Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations (ACM RecSys 2020 Best Paper) `_ \ 18 | and implemented by us. 19 | 20 | :param img_size: The size of input data. For example, [3, 224, 224] for input images with size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The numbers of experts shared for all tasks and specific to each task, respectively. Each expert is the encoder network. 23 | :type num_experts: list 24 | 25 | .. py:method:: forward(self, inputs, task_name=None) 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/Cross_stitch/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.Cross_stitch` 2 | ======================================= 3 | 4 | .. py:module:: LibMTL.architecture.Cross_stitch 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: Cross_stitch(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Cross-stitch Networks (Cross_stitch). 16 | 17 | This method is proposed in `Cross-stitch Networks for Multi-task Learning (CVPR 2016) `_ \ 18 | and implemented by us. 19 | 20 | ..
warning:: 21 | - :class:`Cross_stitch` does not work with the multi-input MTL setting, i.e., ``multi_input`` must be ``False``. 22 | 23 | - :class:`Cross_stitch` is only supported with a ResNet-based encoder. 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/DSelect_k/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.DSelect_k` 2 | ==================================== 3 | 4 | .. py:module:: LibMTL.architecture.DSelect_k 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: DSelect_k(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.MMoE.MMoE` 14 | 15 | DSelect-k. 16 | 17 | This method is proposed in `DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning (NeurIPS 2021) `_ \ 18 | and implemented by modifying from the `official TensorFlow implementation `_. 19 | 20 | :param img_size: The size of input data. For example, [3, 224, 224] for input images with size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The number of experts shared for all tasks. Each expert is the encoder network. 23 | :type num_experts: int 24 | :param num_nonzeros: The number of selected experts. 25 | :type num_nonzeros: int 26 | :param kgamma: A scaling parameter for the smooth-step function. 27 | :type kgamma: float, default=1.0 28 | 29 | .. py:method:: forward(self, inputs, task_name=None) 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/HPS/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.HPS` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.architecture.HPS 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: HPS(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Hard Parameter Sharing (HPS). 16 | 17 | This method is proposed in `Multitask Learning: A Knowledge-Based Source of Inductive Bias (ICML 1993) `_ \ 18 | and implemented by us. 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/MMoE/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.MMoE` 2 | =============================== 3 | 4 | .. py:module:: LibMTL.architecture.MMoE 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: MMoE(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Multi-gate Mixture-of-Experts (MMoE). 16 | 17 | This method is proposed in `Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts (KDD 2018) `_ \ 18 | and implemented by us. 19 | 20 | :param img_size: The size of input data. For example, [3, 224, 224] for input images with size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The number of experts shared for all tasks. Each expert is the encoder network. 23 | :type num_experts: int 24 | 25 | .. py:method:: forward(self, inputs, task_name=None) 26 | 27 | 28 | ..
py:method:: get_share_params(self) 29 | 30 | 31 | .. py:method:: zero_grad_share_params(self) 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/MTAN/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.MTAN` 2 | =============================== 3 | 4 | .. py:module:: LibMTL.architecture.MTAN 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: MTAN(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Multi-Task Attention Network (MTAN). 16 | 17 | This method is proposed in `End-To-End Multi-Task Learning With Attention (CVPR 2019) `_ \ 18 | and implemented by modifying from the `official PyTorch implementation `_. 19 | 20 | .. warning:: 21 | :class:`MTAN` is only supported with a ResNet-based encoder. 22 | 23 | 24 | .. py:method:: forward(self, inputs, task_name=None) 25 | 26 | 27 | .. py:method:: get_share_params(self) 28 | 29 | 30 | .. py:method:: zero_grad_share_params(self) 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/PLE/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.PLE` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.architecture.PLE 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: PLE(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Progressive Layered Extraction (PLE). 16 | 17 | This method is proposed in `Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations (ACM RecSys 2020 Best Paper) `_ \ 18 | and implemented by us. 19 | 20 | :param img_size: The size of input data. For example, [3, 224, 224] for input images with size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The numbers of experts shared for all tasks and specific to each task, respectively. Each expert is the encoder network. 23 | :type num_experts: list 24 | 25 | .. warning:: 26 | - :class:`PLE` does not work with the multi-input MTL setting, i.e., ``multi_input`` must be ``False``. 27 | - :class:`PLE` is only supported with a ResNet-based encoder. 28 | 29 | 30 | .. py:method:: forward(self, inputs, task_name=None) 31 | 32 | 33 | .. py:method:: get_share_params(self) 34 | 35 | 36 | .. py:method:: zero_grad_share_params(self) 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/abstract_arch/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.abstract_arch` 2 | ======================================== 3 | 4 | .. py:module:: LibMTL.architecture.abstract_arch 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsArchitecture(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`torch.nn.Module` 14 | 15 | An abstract class for MTL architectures. 16 | 17 | :param task_name: A list of strings for all tasks. 18 | :type task_name: list 19 | :param encoder_class: A neural network class.
20 | :type encoder_class: class 21 | :param decoders: A dictionary of name-decoder pairs of type (:class:`str`, :class:`torch.nn.Module`). 22 | :type decoders: dict 23 | :param rep_grad: If ``True``, the gradient of the representation for each task can be computed. 24 | :type rep_grad: bool 25 | :param multi_input: Is ``True`` if each task has its own input data, ``False`` otherwise. 26 | :type multi_input: bool 27 | :param device: The device where model and data will be allocated. 28 | :type device: torch.device 29 | :param kwargs: A dictionary of hyperparameters of architecture methods. 30 | :type kwargs: dict 31 | 32 | .. py:method:: forward(self, inputs, task_name=None) 33 | 34 | :param inputs: The input data. 35 | :type inputs: torch.Tensor 36 | :param task_name: The task name corresponding to ``inputs`` if ``multi_input`` is ``True``. 37 | :type task_name: str, default=None 38 | 39 | :returns: A dictionary of name-prediction pairs of type (:class:`str`, :class:`torch.Tensor`). 40 | :rtype: dict 41 | 42 | 43 | .. py:method:: get_share_params(self) 44 | 45 | Return the shared parameters of the model. 46 | 47 | 48 | 49 | .. py:method:: zero_grad_share_params(self) 50 | 51 | Set gradients of the shared parameters to zero. 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/config/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.config` 2 | ==================== 3 | 4 | .. py:module:: LibMTL.config 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:data:: LibMTL_args 12 | 13 | 14 | 15 | 16 | .. py:function:: prepare_args(params) 17 | 18 | Return the configuration of hyperparameters, optimizer, and learning rate scheduler. 19 | 20 | :param params: The command-line arguments. 21 | :type params: argparse.Namespace 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/loss/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.loss` 2 | ================== 3 | 4 | .. py:module:: LibMTL.loss 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsLoss 12 | 13 | Bases: :py:obj:`object` 14 | 15 | An abstract class for loss functions. 16 | 17 | 18 | .. py:method:: compute_loss(self, pred, gt) 19 | :property: 20 | 21 | Calculate the loss. 22 | 23 | :param pred: The prediction tensor. 24 | :type pred: torch.Tensor 25 | :param gt: The ground-truth tensor. 26 | :type gt: torch.Tensor 27 | 28 | :returns: The loss. 29 | :rtype: torch.Tensor 30 | 31 | 32 | 33 | .. py:class:: CELoss 34 | 35 | Bases: :py:obj:`AbsLoss` 36 | 37 | The cross-entropy loss function. 38 | 39 | 40 | .. py:method:: compute_loss(self, pred, gt) 41 | 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/metrics/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.metrics` 2 | ===================== 3 | 4 | .. py:module:: LibMTL.metrics 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsMetric 12 | 13 | Bases: :py:obj:`object` 14 | 15 | An abstract class for the performance metrics of a task. 16 | 17 | .. attribute:: record 18 | 19 | A list of the metric scores in every iteration. 20 | 21 | :type: list 22 | 23 | ..
attribute:: bs 24 | 25 | A list of the number of data samples in every iteration. 26 | 27 | :type: list 28 | 29 | .. py:method:: update_fun(self, pred, gt) 30 | :property: 31 | 32 | Calculate the metric scores in every iteration and update :attr:`record`. 33 | 34 | :param pred: The prediction tensor. 35 | :type pred: torch.Tensor 36 | :param gt: The ground-truth tensor. 37 | :type gt: torch.Tensor 38 | 39 | 40 | .. py:method:: score_fun(self) 41 | :property: 42 | 43 | Calculate the final score (when an epoch ends). 44 | 45 | :returns: A list of metric scores. 46 | :rtype: list 47 | 48 | 49 | .. py:method:: reinit(self) 50 | 51 | Reset :attr:`record` and :attr:`bs` (when an epoch ends). 52 | 53 | 54 | 55 | 56 | .. py:class:: AccMetric 57 | 58 | Bases: :py:obj:`AbsMetric` 59 | 60 | Calculate the accuracy. 61 | 62 | 63 | .. py:method:: update_fun(self, pred, gt) 64 | 65 | 66 | 67 | 68 | .. py:method:: score_fun(self) 69 | 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/model/resnet_dilated/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.model.resnet_dilated` 2 | ================================== 3 | 4 | .. py:module:: LibMTL.model.resnet_dilated 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: ResnetDilated(orig_resnet, dilate_scale=8) 12 | 13 | Bases: :py:obj:`torch.nn.Module` 14 | 15 | Base class for all neural network modules. 16 | 17 | Your models should also subclass this class. 18 | 19 | Modules can also contain other Modules, allowing to nest them in 20 | a tree structure. You can assign the submodules as regular attributes:: 21 | 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | 25 | class Model(nn.Module): 26 | def __init__(self): 27 | super(Model, self).__init__() 28 | self.conv1 = nn.Conv2d(1, 20, 5) 29 | self.conv2 = nn.Conv2d(20, 20, 5) 30 | 31 | def forward(self, x): 32 | x = F.relu(self.conv1(x)) 33 | return F.relu(self.conv2(x)) 34 | 35 | Submodules assigned in this way will be registered, and will have their 36 | parameters converted too when you call :meth:`to`, etc. 37 | 38 | .. py:method:: forward(self, x) 39 | 40 | Defines the computation performed at every call. 41 | 42 | Should be overridden by all subclasses. 43 | 44 | .. note:: 45 | Although the recipe for forward pass needs to be defined within 46 | this function, one should call the :class:`Module` instance afterwards 47 | instead of this since the former takes care of running the 48 | registered hooks while the latter silently ignores them. 49 | 50 | 51 | .. py:method:: forward_stage(self, x, stage) 52 | 53 | 54 | 55 | .. py:function:: resnet_dilated(basenet, pretrained=True, dilate_scale=8) 56 | 57 | Dilated Residual Network models from `"Dilated Residual Networks" `_ 58 | 59 | :param basenet: The type of ResNet. 60 | :type basenet: str 61 | :param pretrained: If True, returns a model pre-trained on ImageNet. 62 | :type pretrained: bool 63 | :param dilate_scale: The type of dilating process. 64 | :type dilate_scale: {8, 16}, default=8 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/utils/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.utils` 2 | =================== 3 | 4 | .. py:module:: LibMTL.utils 5 | 6 | 7 | 8 | 9 | 10 | 11 | ..
py:function:: set_random_seed(seed) 12 | 13 | Set the random seed for reproducibility. 14 | 15 | :param seed: The random seed. 16 | :type seed: int, default=0 17 | 18 | 19 | .. py:function:: set_device(gpu_id) 20 | 21 | Set the device where model and data will be allocated. 22 | 23 | :param gpu_id: The id of gpu. 24 | :type gpu_id: str, default='0' 25 | 26 | 27 | .. py:function:: count_parameters(model) 28 | 29 | Calculate the number of parameters of a model. 30 | 31 | :param model: A neural network module. 32 | :type model: torch.nn.Module 33 | 34 | 35 | .. py:function:: count_improvement(base_result, new_result, weight) 36 | 37 | Calculate the improvement between two results, 38 | 39 | .. math:: 40 | \Delta_{\mathrm{p}}=100\%\times \frac{1}{T}\sum_{t=1}^T 41 | \frac{1}{M_t}\sum_{m=1}^{M_t}\frac{(-1)^{w_{t,m}}(B_{t,m}-N_{t,m})}{N_{t,m}}. 42 | 43 | :param base_result: A dictionary of scores of all metrics of all tasks. 44 | :type base_result: dict 45 | :param new_result: The same structure as ``base_result``. 46 | :type new_result: dict 47 | :param weight: The same structure as ``base_result``, where each element is a binary integer indicating whether a higher (1) or lower (0) score is better for that metric. 48 | :type weight: dict 49 | 50 | :returns: The improvement between ``new_result`` and ``base_result``. 51 | :rtype: float 52 | 53 | Examples:: 54 | 55 | base_result = {'A': [96, 98], 'B': [0.2]} 56 | new_result = {'A': [93, 99], 'B': [0.5]} 57 | weight = {'A': [1, 0], 'B': [1]} 58 | 59 | print(count_improvement(base_result, new_result, weight)) 60 | 61 | Working through the formula by hand: task ``A`` contributes ((-3/93) + (-1/99))/2, roughly -0.021, while task ``B`` contributes (0.5-0.2)/0.5 = 0.6, so the improvement works out to about 0.289, i.e., an average improvement of roughly 28.9%. 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/CAGrad/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.CAGrad` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.weighting.CAGrad 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: CAGrad 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Conflict-Averse Gradient Descent (CAGrad). 16 | 17 | This method is proposed in `Conflict-Averse Gradient Descent for Multi-task learning (NeurIPS 2021) `_ \ 18 | and implemented by modifying from the `official PyTorch implementation `_. 19 | 20 | :param calpha: A hyperparameter that controls the convergence rate. 21 | :type calpha: float, default=0.5 22 | :param rescale: The type of gradient rescale. 23 | :type rescale: {0, 1, 2}, default=1 24 | 25 | .. warning:: 26 | CAGrad is not supported with representation gradients, i.e., ``rep_grad`` must be ``False``. 27 | 28 | 29 | .. py:method:: backward(self, losses, **kwargs) 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/DWA/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.DWA` 2 | =========================== 3 | 4 | .. py:module:: LibMTL.weighting.DWA 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: DWA 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Dynamic Weight Average (DWA). 16 | 17 | This method is proposed in `End-To-End Multi-Task Learning With Attention (CVPR 2019) `_ \ 18 | and implemented by modifying from the `official PyTorch implementation `_. 19 | 20 | :param T: The softmax temperature. 21 | :type T: float, default=2.0 22 | 23 | ..
py:method:: backward(self, losses, **kwargs) 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/EW/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.EW` 2 | ========================== 3 | 4 | .. py:module:: LibMTL.weighting.EW 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: EW 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Equally Weighting (EW). 16 | 17 | The loss weight for each task is always ``1 / T`` in every iteration, where ``T`` means the number of tasks. 18 | 19 | 20 | .. py:method:: backward(self, losses, **kwargs) 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/GLS/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GLS` 2 | =========================== 3 | 4 | .. py:module:: LibMTL.weighting.GLS 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GLS 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Geometric Loss Strategy (GLS). 16 | 17 | This method is proposed in `MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning (CVPR 2019 workshop) `_ \ 18 | and implemented by us. 19 | 20 | 21 | .. py:method:: backward(self, losses, **kwargs) 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/GradDrop/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GradDrop` 2 | ================================ 3 | 4 | .. py:module:: LibMTL.weighting.GradDrop 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GradDrop 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Gradient Sign Dropout (GradDrop). 16 | 17 | This method is proposed in `Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout (NeurIPS 2020) `_ \ 18 | and implemented by us. 19 | 20 | :param leak: The leak parameter for the weighting matrix. 21 | :type leak: float, default=0.0 22 | 23 | .. warning:: 24 | GradDrop is not supported with parameter gradients, i.e., ``rep_grad`` must be ``True``. 25 | 26 | 27 | .. py:method:: backward(self, losses, **kwargs) 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/GradNorm/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GradNorm` 2 | ================================ 3 | 4 | .. py:module:: LibMTL.weighting.GradNorm 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GradNorm 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Gradient Normalization (GradNorm). 16 | 17 | This method is proposed in `GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks (ICML 2018) `_ \ 18 | and implemented by us. 19 | 20 | :param alpha: The strength of the restoring force which pulls tasks back to a common training rate. 21 | :type alpha: float, default=1.5 22 | 23 | .. py:method:: init_param(self) 24 | 25 | 26 | .. 
py:method:: backward(self, losses, **kwargs) 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/GradVac/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GradVac` 2 | =============================== 3 | 4 | .. py:module:: LibMTL.weighting.GradVac 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GradVac 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Gradient Vaccine (GradVac). 16 | 17 | This method is proposed in `Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) `_ \ 18 | and implemented by us. 19 | 20 | :param beta: The exponential moving average (EMA) decay parameter. 21 | :type beta: float, default=0.5 22 | 23 | .. warning:: 24 | GradVac is not supported with representation gradients, i.e., ``rep_grad`` must be ``False``. 25 | 26 | 27 | .. py:method:: backward(self, losses, **kwargs) 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/IMTL/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.IMTL` 2 | ============================ 3 | 4 | .. py:module:: LibMTL.weighting.IMTL 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: IMTL 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Impartial Multi-task Learning (IMTL). 16 | 17 | This method is proposed in `Towards Impartial Multi-task Learning (ICLR 2021) `_ \ 18 | and implemented by us. 19 | 20 | 21 | .. py:method:: init_param(self) 22 | 23 | 24 | .. py:method:: backward(self, losses, **kwargs) 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/MGDA/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.MGDA` 2 | ============================ 3 | 4 | .. py:module:: LibMTL.weighting.MGDA 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: MGDA 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Multiple Gradient Descent Algorithm (MGDA). 16 | 17 | This method is proposed in `Multi-Task Learning as Multi-Objective Optimization (NeurIPS 2018) `_ \ 18 | and implemented by modifying from the `official PyTorch implementation `_. 19 | 20 | :param mgda_gn: The type of gradient normalization. 21 | :type mgda_gn: {'none', 'l2', 'loss', 'loss+'}, default='none' 22 | 23 | .. py:method:: backward(self, losses, **kwargs) 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/PCGrad/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.PCGrad` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.weighting.PCGrad 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: PCGrad 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Project Conflicting Gradients (PCGrad). 16 | 17 | This method is proposed in `Gradient Surgery for Multi-Task Learning (NeurIPS 2020) `_ \ 18 | and implemented by us. 19 | 20 | .. 
warning:: 21 | PCGrad is not supported with representation gradients, i.e., ``rep_grad`` must be ``False``. 22 | 23 | 24 | .. py:method:: backward(self, losses, **kwargs) 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/RLW/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.RLW` 2 | =========================== 3 | 4 | .. py:module:: LibMTL.weighting.RLW 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: RLW 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Random Loss Weighting (RLW). 16 | 17 | This method is proposed in `A Closer Look at Loss Weighting in Multi-Task Learning (arXiv:2111.10603) `_ \ 18 | and implemented by us. 19 | 20 | :param dist: The type of distribution that the loss weights are sampled from. 21 | :type dist: {'Uniform', 'Normal', 'Dirichlet', 'Bernoulli', 'constrained_Bernoulli'}, default='Normal' 22 | 23 | .. py:method:: backward(self, losses, **kwargs) 24 | 25 | :param losses: A list of the losses of all tasks. 26 | :type losses: list 27 | :param kwargs: A dictionary of hyperparameters of weighting methods. 28 | :type kwargs: dict 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/UW/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.UW` 2 | ========================== 3 | 4 | .. py:module:: LibMTL.weighting.UW 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: UW 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Uncertainty Weights (UW). 16 | 17 | This method is proposed in `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (CVPR 2018) `_ \ 18 | and implemented by us. 19 | 20 | 21 | .. py:method:: init_param(self) 22 | 23 | 24 | .. py:method:: backward(self, losses, **kwargs) 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/abstract_weighting/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.abstract_weighting` 2 | ========================================== 3 | 4 | .. py:module:: LibMTL.weighting.abstract_weighting 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsWeighting 12 | 13 | Bases: :py:obj:`torch.nn.Module` 14 | 15 | An abstract class for weighting strategies. 16 | 17 | 18 | .. py:method:: init_param(self) 19 | 20 | Define and initialize some trainable parameters required by specific weighting methods. 21 | 22 | 23 | 24 | .. py:method:: backward(self, losses, **kwargs) 25 | :property: 26 | 27 | :param losses: A list of the losses of all tasks. 28 | :type losses: list 29 | :param kwargs: A dictionary of hyperparameters of weighting methods. 30 | :type kwargs: dict 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/develop/arch.md.txt: -------------------------------------------------------------------------------- 1 | ## Customize an Architecture 2 | 3 | Here we introduce how to customize a new architecture with the support of ``LibMTL``. 4 | 5 | ### Create a New Architecture Class 6 | 7 | ```eval_rst 8 | Firstly, you need to create a new architecture class by inheriting class :class:`LibMTL.architecture.AbsArchitecture`. 9 | ``` 10 | 11 | ```python 12 | from LibMTL.architecture import AbsArchitecture 13 | 14 | class NewArchitecture(AbsArchitecture): 15 | def __init__(self, task_name, encoder_class, decoders, rep_grad, 16 | multi_input, device, **kwargs): 17 | super(NewArchitecture, self).__init__(task_name, encoder_class, decoders, rep_grad, 18 | multi_input, device, **kwargs) 19 | ``` 20 | 21 | ### Rewrite Corresponding Methods 22 | 23 | ```eval_rst 24 | There are four important functions in :class:`LibMTL.architecture.AbsArchitecture`. We will introduce them in detail as follows. 25 | 26 | - :func:`forward`: The forward function, whose input and output format can be found in :func:`LibMTL.architecture.AbsArchitecture.forward`. To rewrite this function, you need to consider the case of ``multi-input`` and ``multi-label`` (refer to `here <../user_guide/mtl.html#network-architecture>`_) and the case of ``rep-grad`` and ``param-grad`` (refer to `here <../user_guide/mtl.html#weighting-strategy>`_) if you would like to combine your architecture with more weighting strategies or apply your architecture to more datasets. 27 | - :func:`get_share_params`: This function is used to return the shared parameters of the model. It returns all the parameters of the encoder by default. You can rewrite it if necessary. 28 | - :func:`zero_grad_share_params`: This function is used to set gradients of the shared parameters to zero. It will set the gradients of all the encoder parameters to zero by default. You can rewrite it if necessary. 29 | - :func:`_prepare_rep`: This function is used to enable computing the gradients of representations. More details can be found `here <../../_modules/LibMTL/architecture/abstract_arch.html#AbsArchitecture>`_. 30 | ```
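31 | 32 | For concreteness, here is a minimal sketch of a complete custom architecture. It is an illustration rather than library code: it assumes the base class builds the shared encoder as ``self.encoder`` from ``encoder_class`` (as the built-in architectures do), so double-check attribute names against ``abstract_arch.py`` before relying on them. 33 | 34 | ```python 35 | from LibMTL.architecture import AbsArchitecture 36 | 37 | class NewArchitecture(AbsArchitecture): 38 | def forward(self, inputs, task_name=None): 39 | # Encode once; the representation is shared by all task decoders. 40 | rep = self.encoder(inputs) 41 | if self.multi_input: 42 | # Each task has its own input, so decode only the named task. 43 | return {task_name: self.decoders[task_name](rep)} 44 | # Single-input case: decode every task from the same representation. 45 | return {task: self.decoders[task](rep) for task in self.task_name} 46 | ``` 47 | 48 |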
-------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/develop/weighting.md.txt: -------------------------------------------------------------------------------- 1 | ## Customize a Weighting Strategy 2 | 3 | Here we introduce how to customize a new weighting strategy with the support of ``LibMTL``. 4 | 5 | ### Create a New Weighting Class 6 | 7 | ```eval_rst 8 | Firstly, you need to create a new weighting class by inheriting class :class:`LibMTL.weighting.AbsWeighting`. 9 | ``` 10 | 11 | ```python 12 | from LibMTL.weighting import AbsWeighting 13 | 14 | class NewWeighting(AbsWeighting): 15 | def __init__(self): 16 | super(NewWeighting, self).__init__() 17 | ``` 18 | 19 | ### Rewrite Corresponding Methods 20 | 21 | ```eval_rst 22 | There are four important functions in :class:`LibMTL.weighting.AbsWeighting`. We will introduce them in detail as follows. 23 | 24 | - :func:`backward`: The main function of a weighting strategy, whose input and output format can be found in :func:`LibMTL.weighting.AbsWeighting.backward`. To rewrite this function, you need to consider the case of ``multi-input`` and ``multi-label`` (refer to `here <../user_guide/mtl.html#network-architecture>`_) and the case of ``rep-grad`` and ``param-grad`` (refer to `here <../user_guide/mtl.html#weighting-strategy>`_) if you would like to combine your weighting method with more architectures or apply your method to more datasets. 25 | - :func:`init_param`: This function is used to define and initialize some trainable parameters. It does nothing by default and can be rewritten if necessary. 26 | - :func:`_get_grads`: This function is used to return the gradients of representations or shared parameters (covering the case of ``rep-grad`` and ``param-grad``). 27 | - :func:`_backward_new_grads`: This function is used to reset the gradients and perform a backward pass (covering the case of ``rep-grad`` and ``param-grad``). 28 | 29 | The :func:`_get_grads` and :func:`_backward_new_grads` functions are very useful for rewriting the :func:`backward` function; you can find more details about them `here <../../_modules/LibMTL/weighting/abstract_weighting.html#AbsWeighting>`_. 30 | ```
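31 | 32 | For concreteness, here is a minimal sketch of a complete custom strategy. The weighting rule is a toy example invented for illustration (not a published method), and it assumes ``losses`` arrives as a stacked tensor of per-task losses, as in the built-in strategies; if your version passes a plain list, stack it with ``torch.stack`` first. 33 | 34 | ```python 35 | import torch 36 | 37 | from LibMTL.weighting import AbsWeighting 38 | 39 | class NewWeighting(AbsWeighting): 40 | def backward(self, losses, **kwargs): 41 | # Toy rule: softmax over the negated (detached) losses, so tasks 42 | # with currently large losses receive smaller weights. 43 | weights = torch.softmax(-losses.detach(), dim=-1) 44 | # Weighted sum of the task losses, then a single backward pass. 45 | torch.mul(losses, weights).sum().backward() 46 | ``` 47 | 48 |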
-------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/getting_started/installation.md.txt: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | ### Dependencies 4 | 5 | To install ``LibMTL``, you need to set up the following libraries: 6 | 7 | - Python >= 3.7 8 | - PyTorch >= 1.8.0 9 | - torchvision >= 0.9.0 10 | - numpy >= 1.20 11 | 12 | ### User Installation 13 | 14 | #### Using PyPi 15 | 16 | The simplest way to install `LibMTL` is using `pip`. 17 | 18 | ```shell 19 | pip install -U LibMTL 20 | ``` 21 | 22 | #### Using Source Code 23 | 24 | If you prefer, you can clone the source code from GitHub and run the ``setup.py`` file. 25 | 26 | ```shell 27 | git clone https://github.com/median-research-group/LibMTL.git 28 | cd LibMTL 29 | python setup.py install 30 | ``` 31 | 32 | 33 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/user_guide/benchmark.rst.txt: -------------------------------------------------------------------------------- 1 | Run a Benchmark 2 | =============== 3 | 4 | Here we introduce some MTL benchmark datasets and show how to run models on them for a fair comparison. 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | benchmark/nyuv2 10 | benchmark/office 11 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/user_guide/benchmark/office.md.txt: -------------------------------------------------------------------------------- 1 | ## Office-31 and Office-Home 2 | 3 | ```eval_rst 4 | The Office-31 dataset :cite:`saenko2010adapting` consists of three domains: Amazon, DSLR, and Webcam, where each domain contains 31 object categories. It can be downloaded `here `_. This dataset contains 4,110 labeled images and we randomly split these samples into 60\% for training, 20\% for validation, and the remaining 20\% for testing. 5 | 6 | The Office-Home dataset :cite:`venkateswara2017deep` has four domains: Artistic images (abbreviated as Art), Clip art, Product images, and Real-world images. It can be downloaded `here `_. This dataset has 15,500 labeled images in total and each domain contains 65 classes. We split the entire dataset in the same proportions as Office-31. 7 | 8 | For both datasets, we consider the multi-class classification problem on each domain as a task. Thus, ``multi_input`` must be ``True`` for both Office datasets. 9 | 10 | The training code is available in ``examples/office``. We use a ResNet-18 network pretrained on ImageNet, followed by a fully connected layer, as the encoder shared among tasks, and a separate fully connected layer as the task-specific output layer for each task. All the input images are resized to 3x224x224. 11 | ``` 12 | 13 | ### Run a Model 14 | 15 | The script ``train_office.py`` is the main file for training and evaluating an MTL model on the Office-31 or Office-Home dataset. A set of command-line arguments is provided to allow users to adjust the training configuration. 16 | 17 | Some important arguments are described as follows. 18 | 19 | ```eval_rst 20 | - ``weighting``: The weighting strategy. Refer to `here <../_autoapi/LibMTL/weighting/index.html>`_. 21 | - ``arch``: The MTL architecture. Refer to `here <../_autoapi/LibMTL/architecture/index.html>`_. 22 | - ``gpu_id``: The id of gpu. Default to '0'. 23 | - ``seed``: The random seed for reproducibility. Default to 0. 24 | - ``optim``: The type of the optimizer. We recommend using 'adam' here. 25 | - ``dataset``: Training on Office-31 or Office-Home. Options: 'office-31', 'office-home'. 26 | - ``dataset_path``: The path of the Office-31 or Office-Home dataset. 27 | - ``bs``: The batch size of training, validation, and test data. Default to 64. 28 | ``` 29 | 30 | The complete command-line arguments and their descriptions can be found by running the following command. 31 | 32 | ```shell 33 | python train_office.py -h 34 | ``` 35 | 36 | Once you are familiar with these command-line arguments, you can train an MTL model by running a command like this. 37 | 38 | ```shell 39 | python train_office.py --weighting WEIGHTING --arch ARCH --dataset_path PATH --gpu_id GPU_ID --multi_input 40 | ``` 41 | 42 | ### References 43 | 44 | ```eval_rst 45 | .. bibliography:: 46 | :style: unsrt 47 | :filter: docname in docnames 48 | ``` 49 | 50 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/user_guide/framework.md.txt: -------------------------------------------------------------------------------- 1 | ## Overall Framework 2 | 3 | ``LibMTL`` provides a unified running framework to train an MTL model with a chosen architecture and weighting strategy on a given dataset. The overall framework is shown below. Five modules support this process; we introduce them as follows. 4 | 5 | - **Config Module**: Responsible for all the configuration parameters involved in the running framework, including the optimizer and learning rate scheduler parameters, the hyperparameters of the MTL model, and training configuration such as batch size, total epochs, and random seed. 6 | - **Dataloaders Module**: Responsible for data pre-processing and loading. 7 | - **Model Module**: Responsible for inheriting the architecture and weighting classes and instantiating an MTL model. Note that the architecture and the weighting strategy determine the forward and backward processes of the MTL model, respectively. 8 | - **Losses Module**: Responsible for computing the loss for each task. 9 | - **Metrics Module**: Responsible for evaluating the MTL model and calculating the metric scores for each task. 10 | 11 | ```eval_rst 12 | .. figure:: ../images/framework.png 13 | :scale: 100% 14 | ```
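15 | 16 | As a rough sketch of how these modules meet in code, the snippet below instantiates and trains a model. The argument names follow the quick-start example but should be treated as assumptions here, not a complete specification; ``task_dict`` (which bundles each task's loss and metric objects, i.e., the Losses and Metrics Modules), ``encoder_class``, ``decoders``, and the dataloaders are placeholders that you define for your own dataset. 17 | 18 | ```python 19 | from LibMTL import Trainer 20 | from LibMTL.architecture import HPS 21 | from LibMTL.weighting import EW 22 | 23 | # Model Module: an architecture (forward process) plus a weighting strategy 24 | # (backward process). Config Module: optimizer/scheduler settings as dicts. 25 | model = Trainer(task_dict=task_dict, weighting=EW, architecture=HPS, 26 | encoder_class=encoder_class, decoders=decoders, 27 | rep_grad=False, multi_input=False, 28 | optim_param={'optim': 'adam', 'lr': 1e-3}, 29 | scheduler_param={'scheduler': 'step'}) 30 | # The Dataloaders, Losses, and Metrics Modules are exercised during training. 31 | model.train(train_dataloaders=train_dataloaders, 32 | test_dataloaders=test_dataloaders, epochs=100) 33 | ``` 34 | 35 |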
-------------------------------------------------------------------------------- /docs/_build/html/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | .. LibMTL documentation master file, created by 2 | sphinx-quickstart on Thu Nov 25 17:02:04 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive.
5 | 6 | LibMTL: A PyTorch Library for Multi-Task Learning 7 | ================================================= 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :caption: Getting Started: 12 | 13 | docs/getting_started/introduction 14 | docs/getting_started/installation 15 | docs/getting_started/quick_start 16 | 17 | .. toctree:: 18 | :maxdepth: 1 19 | :caption: User Guide: 20 | 21 | docs/user_guide/mtl 22 | docs/user_guide/framework 23 | docs/user_guide/benchmark 24 | 25 | .. toctree:: 26 | :maxdepth: 1 27 | :caption: Developer Guide: 28 | 29 | docs/develop/dataset 30 | docs/develop/arch 31 | docs/develop/weighting 32 | 33 | .. toctree:: 34 | :maxdepth: 1 35 | :caption: API Reference: 36 | 37 | docs/_autoapi/LibMTL/index 38 | docs/_autoapi/LibMTL/loss/index 39 | docs/_autoapi/LibMTL/utils/index 40 | docs/_autoapi/LibMTL/model/index 41 | docs/_autoapi/LibMTL/config/index 42 | docs/_autoapi/LibMTL/metrics/index 43 | docs/_autoapi/LibMTL/weighting/index 44 | docs/_autoapi/LibMTL/architecture/index 45 | 46 | 47 | 48 | Indices and tables 49 | ================== 50 | 51 | * :ref:`genindex` 52 | * :ref:`modindex` 53 | * :ref:`search` 54 | -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/fontawesome-webfont.ttf 
-------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/lato-bold-italic.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/lato-bold.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/lato-normal.woff -------------------------------------------------------------------------------- 
/docs/_build/html/_static/css/fonts/lato-normal.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/css/fonts/lato-normal.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '', 4 | LANGUAGE: 'None', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | LINK_SUFFIX: '.html', 9 | HAS_SOURCE: true, 10 | SOURCELINK_SUFFIX: '.txt', 11 | NAVIGATION_WITH_KEYS: false 12 | }; -------------------------------------------------------------------------------- /docs/_build/html/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/file.png -------------------------------------------------------------------------------- /docs/_build/html/_static/graphviz.css: -------------------------------------------------------------------------------- 1 | /* 2 | * graphviz.css 3 | * ~~~~~~~~~~~~ 4 | * 5 | * Sphinx stylesheet -- graphviz extension. 6 | * 7 | * :copyright: Copyright 2007-2021 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | 12 | img.graphviz { 13 | border: 0; 14 | max-width: 100%; 15 | } 16 | 17 | object.graphviz { 18 | max-width: 100%; 19 | } 20 | -------------------------------------------------------------------------------- /docs/_build/html/_static/js/badge_only.js: -------------------------------------------------------------------------------- 1 | !function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}}); -------------------------------------------------------------------------------- /docs/_build/html/_static/js/html5shiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3 | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=t.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=t.elements;"string"!=typeof c&&(c=c.join(" 
")),"string"!=typeof a&&(a=a.join(" ")),t.elements=c+" "+a,j(b)}function f(a){var b=s[a[q]];return b||(b={},r++,a[q]=r,s[r]=b),b}function g(a,c,d){if(c||(c=b),l)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():p.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||o.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),l)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return t.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(t,b.frag)}function j(a){a||(a=b);var d=f(a);return!t.shivCSS||k||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),l||i(a,d),a}var k,l,m="3.7.3-pre",n=a.html5||{},o=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,p=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,q="_html5shiv",r=0,s={};!function(){try{var a=b.createElement("a");a.innerHTML="",k="hidden"in a,l=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){k=!0,l=!0}}();var t={elements:n.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:m,shivCSS:n.shivCSS!==!1,supportsUnknownElements:l,shivMethods:n.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=t,j(b),"object"==typeof module&&module.exports&&(module.exports=t)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /docs/_build/html/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/minus.png -------------------------------------------------------------------------------- /docs/_build/html/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/_static/plus.png -------------------------------------------------------------------------------- /docs/_build/html/_static/theme_overrides.css: -------------------------------------------------------------------------------- 1 | /* override table width restrictions*/ 2 | @media screen and (min-width: 767px) { 3 | 4 | .wy-table-responsive table td { 5 | /* !important prevents the common CSS stylesheets from overriding 6 | this as on RTD they are loaded after this stylesheet */ 7 | white-space: normal !important; 8 | } 9 | 10 | .wy-table-responsive { 11 | overflow: visible 
!important; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /docs/_build/html/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/_build/html/objects.inv -------------------------------------------------------------------------------- /docs/_static/theme_overrides.css: -------------------------------------------------------------------------------- 1 | /* override table width restrictions*/ 2 | @media screen and (min-width: 767px) { 3 | 4 | .wy-table-responsive table td { 5 | /* !important prevents the common CSS stylesheets from overriding 6 | this as on RTD they are loaded after this stylesheet */ 7 | white-space: normal !important; 8 | } 9 | 10 | .wy-table-responsive { 11 | overflow: visible !important; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /docs/_templates/footer.html: -------------------------------------------------------------------------------- 1 | {% extends "!footer.html" %} 2 | {% block extrafooter %} 3 |

4 | Made with ❤ at and
5 | {{ super() }} 6 | {% endblock %} 7 | -------------------------------------------------------------------------------- /docs/autoapi_templates/python/module.rst: -------------------------------------------------------------------------------- 1 | {% if not obj.display %} 2 | :orphan: 3 | 4 | {% endif %} 5 | :mod:`{{ obj.name }}` 6 | ======={{ "=" * obj.name|length }} 7 | 8 | .. py:module:: {{ obj.name }} 9 | 10 | {% if obj.docstring %} 11 | .. autoapi-nested-parse:: 12 | 13 | {{ obj.docstring|prepare_docstring|indent(3) }} 14 | 15 | {% endif %} 16 | 17 | {% block content %} 18 | {% if obj.all is not none %} 19 | {% set visible_children = obj.children|selectattr("short_name", "in", obj.all)|list %} 20 | {% elif obj.type is equalto("package") %} 21 | {% set visible_children = obj.children|selectattr("display")|list %} 22 | {% else %} 23 | {% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %} 24 | {% endif %} 25 | {% if visible_children %} 26 | {# {{ obj.type|title }} Contents 27 | {{ "-" * obj.type|length }}--------- #} 28 | 29 | {% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %} 30 | {% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %} 31 | {% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %} 32 | 33 | {# {% block classes scoped %} 34 | {% if visible_classes %} 35 | Classes 36 | ------- 37 | 38 | .. autoapisummary:: 39 | 40 | {% for klass in visible_classes %} 41 | {{ klass.id }} 42 | {% endfor %} 43 | 44 | 45 | {% endif %} 46 | {% endblock %} #} 47 | 48 | {# {% block functions scoped %} 49 | {% if visible_functions %} 50 | Functions 51 | --------- 52 | 53 | .. autoapisummary:: 54 | 55 | {% for function in visible_functions %} 56 | {{ function.id }} 57 | {% endfor %} 58 | 59 | 60 | {% endif %} 61 | {% endblock %} #} 62 | 63 | {% endif %} 64 | {% for obj_item in visible_children %} 65 | {{ obj_item.rendered|indent(0) }} 66 | {% endfor %} 67 | {% endif %} 68 | {% endblock %} 69 | 70 | {% block subpackages %} 71 | {% set visible_subpackages = obj.subpackages|selectattr("display")|list %} 72 | {% if visible_subpackages %} 73 | {# Subpackages 74 | ----------- #} 75 | .. toctree:: 76 | :titlesonly: 77 | :maxdepth: 1 78 | 79 | {% for subpackage in visible_subpackages %} 80 | {{ subpackage.short_name }}/index.rst 81 | {% endfor %} 82 | 83 | 84 | {% endif %} 85 | {% endblock %} 86 | {# {% block submodules %} 87 | {% set visible_submodules = obj.submodules|selectattr("display")|list %} 88 | {% if visible_submodules %} 89 | Submodules 90 | ---------- 91 | .. toctree:: 92 | :titlesonly: 93 | :maxdepth: 1 94 | 95 | {% for submodule in visible_submodules %} 96 | {{ submodule.short_name }}/index.rst 97 | {% endfor %} 98 | 99 | 100 | {% endif %} 101 | {% endblock %} #} 102 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/_record/index.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | :mod:`LibMTL._record` 4 | ===================== 5 | 6 | .. py:module:: LibMTL._record 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/CGC/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.CGC` 2 | ============================== 3 | 4 | .. 
py:module:: LibMTL.architecture.CGC 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: CGC(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.MMoE.MMoE` 14 | 15 | Customized Gate Control (CGC). 16 | 17 | This method is proposed in `Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations (ACM RecSys 2020 Best Paper) `_ \ 18 | and implemented by us. 19 | 20 | :param img_size: The size of the input data. For example, [3, 224, 224] for input images of size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The numbers of experts shared by all tasks and specific to each task, respectively. Each expert is an encoder network. 23 | :type num_experts: list 24 | 25 | .. py:method:: forward(self, inputs, task_name=None) 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/Cross_stitch/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.Cross_stitch` 2 | ======================================= 3 | 4 | .. py:module:: LibMTL.architecture.Cross_stitch 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: Cross_stitch(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Cross-stitch Networks (Cross_stitch). 16 | 17 | This method is proposed in `Cross-stitch Networks for Multi-task Learning (CVPR 2016) `_ \ 18 | and implemented by us. 19 | 20 | .. warning:: 21 | - :class:`Cross_stitch` does not work with the multi-input MTL problem, i.e., ``multi_input`` must be ``False``. 22 | 23 | - :class:`Cross_stitch` is only supported with a ResNet-based encoder. 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/DSelect_k/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.DSelect_k` 2 | ==================================== 3 | 4 | .. py:module:: LibMTL.architecture.DSelect_k 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: DSelect_k(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.MMoE.MMoE` 14 | 15 | DSelect-k. 16 | 17 | This method is proposed in `DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning (NeurIPS 2021) `_ \ 18 | and implemented by adapting the `official TensorFlow implementation `_. 19 | 20 | :param img_size: The size of the input data. For example, [3, 224, 224] for input images of size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The number of experts shared by all tasks. Each expert is an encoder network. 23 | :type num_experts: int 24 | :param num_nonzeros: The number of selected experts. 25 | :type num_nonzeros: int 26 | :param kgamma: A scaling parameter for the smooth-step function. 27 | :type kgamma: float, default=1.0 28 | 29 | .. py:method:: forward(self, inputs, task_name=None) 30 | 31 | 32 | 33 | 34 |
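CGC and DSelect-k above both extend :class:`LibMTL.architecture.MMoE.MMoE`, so a minimal sketch of that family shows how the documented constructor signature is used in practice. This is a hypothetical example: the encoder class, decoders, and task names are placeholders, and the released package may expect slightly different kwarg types than the ones documented here.

```python
import torch
import torch.nn as nn
from LibMTL.architecture import MMoE

# Placeholder encoder; each expert is one instance of this class.
class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
                                  nn.AdaptiveAvgPool2d(1), nn.Flatten())

    def forward(self, x):
        return self.body(x)

task_name = ['task_A', 'task_B']  # placeholder task names
decoders = nn.ModuleDict({t: nn.Linear(64, 10) for t in task_name})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# img_size and num_experts are forwarded through **kwargs, as documented above.
arch = MMoE(task_name, ToyEncoder, decoders, rep_grad=False, multi_input=False,
            device=device, img_size=[3, 224, 224], num_experts=2).to(device)

preds = arch(torch.randn(4, 3, 224, 224).to(device))  # dict: task name -> prediction
```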
-------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/HPS/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.HPS` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.architecture.HPS 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: HPS(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Hard Parameter Sharing (HPS). 16 | 17 | This method is proposed in `Multitask Learning: A Knowledge-Based Source of Inductive Bias (ICML 1993) `_ \ 18 | and implemented by us. 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/MMoE/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.MMoE` 2 | =============================== 3 | 4 | .. py:module:: LibMTL.architecture.MMoE 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: MMoE(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Multi-gate Mixture-of-Experts (MMoE). 16 | 17 | This method is proposed in `Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts (KDD 2018) `_ \ 18 | and implemented by us. 19 | 20 | :param img_size: The size of the input data. For example, [3, 224, 224] for input images of size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The number of experts shared by all tasks. Each expert is an encoder network. 23 | :type num_experts: int 24 | 25 | .. py:method:: forward(self, inputs, task_name=None) 26 | 27 | 28 | .. py:method:: get_share_params(self) 29 | 30 | 31 | .. py:method:: zero_grad_share_params(self) 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/MTAN/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.MTAN` 2 | =============================== 3 | 4 | .. py:module:: LibMTL.architecture.MTAN 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: MTAN(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Multi-Task Attention Network (MTAN). 16 | 17 | This method is proposed in `End-To-End Multi-Task Learning With Attention (CVPR 2019) `_ \ 18 | and implemented by adapting the `official PyTorch implementation `_. 19 | 20 | .. warning:: 21 | :class:`MTAN` is only supported with a ResNet-based encoder. 22 | 23 | 24 | .. py:method:: forward(self, inputs, task_name=None) 25 | 26 | 27 | .. py:method:: get_share_params(self) 28 | 29 | 30 | .. py:method:: zero_grad_share_params(self) 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/PLE/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.PLE` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.architecture.PLE 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: PLE(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Progressive Layered Extraction (PLE). 16 | 17 | This method is proposed in `Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations (ACM RecSys 2020 Best Paper) `_ \ 18 | and implemented by us.
19 | 20 | :param img_size: The size of the input data. For example, [3, 224, 224] for input images of size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The numbers of experts shared by all tasks and specific to each task, respectively. Each expert is an encoder network. 23 | :type num_experts: list 24 | 25 | .. warning:: 26 | - :class:`PLE` does not work with the multi-input MTL problem, i.e., ``multi_input`` must be ``False``. 27 | - :class:`PLE` is only supported with a ResNet-based encoder. 28 | 29 | 30 | .. py:method:: forward(self, inputs, task_name=None) 31 | 32 | 33 | .. py:method:: get_share_params(self) 34 | 35 | 36 | .. py:method:: zero_grad_share_params(self) 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/abstract_arch/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.abstract_arch` 2 | ======================================== 3 | 4 | .. py:module:: LibMTL.architecture.abstract_arch 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsArchitecture(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`torch.nn.Module` 14 | 15 | An abstract class for MTL architectures. 16 | 17 | :param task_name: A list of strings for all tasks. 18 | :type task_name: list 19 | :param encoder_class: A neural network class. 20 | :type encoder_class: class 21 | :param decoders: A dictionary of name-decoder pairs of type (:class:`str`, :class:`torch.nn.Module`). 22 | :type decoders: dict 23 | :param rep_grad: If ``True``, the gradient of the representation for each task can be computed. 24 | :type rep_grad: bool 25 | :param multi_input: Is ``True`` if each task has its own input data, ``False`` otherwise. 26 | :type multi_input: bool 27 | :param device: The device where model and data will be allocated. 28 | :type device: torch.device 29 | :param kwargs: A dictionary of hyperparameters of architecture methods. 30 | :type kwargs: dict 31 | 32 | .. py:method:: forward(self, inputs, task_name=None) 33 | 34 | :param inputs: The input data. 35 | :type inputs: torch.Tensor 36 | :param task_name: The task name corresponding to ``inputs`` if ``multi_input`` is ``True``. 37 | :type task_name: str, default=None 38 | 39 | :returns: A dictionary of name-prediction pairs of type (:class:`str`, :class:`torch.Tensor`). 40 | :rtype: dict 41 | 42 | 43 | .. py:method:: get_share_params(self) 44 | 45 | Return the shared parameters of the model. 46 | 47 | 48 | 49 | .. py:method:: zero_grad_share_params(self) 50 | 51 | Set gradients of the shared parameters to zero. 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/config/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.config` 2 | ==================== 3 | 4 | .. py:module:: LibMTL.config 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:data:: LibMTL_args 12 | 13 | 14 | 15 | 16 | .. py:function:: prepare_args(params) 17 | 18 | Return the configuration of hyperparameters, optimizer, and learning rate scheduler. 19 | 20 | :param params: The command-line arguments. 21 | :type params: argparse.Namespace 22 | 23 | 24 | 25 |
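A usage sketch tying the two objects above together — hypothetical wiring based only on the descriptions on this page, so the exact return values may differ in the released package:

```python
from LibMTL.config import LibMTL_args, prepare_args

# Parse the library's command-line arguments into an argparse.Namespace ...
params = LibMTL_args.parse_args()

# ... then derive the hyperparameter, optimizer, and learning-rate-scheduler
# configuration from the parsed arguments.
kwargs, optim_param, scheduler_param = prepare_args(params)
```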
-------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/loss/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.loss` 2 | ================== 3 | 4 | .. py:module:: LibMTL.loss 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsLoss 12 | 13 | Bases: :py:obj:`object` 14 | 15 | An abstract class for loss functions. 16 | 17 | 18 | .. py:method:: compute_loss(self, pred, gt) 19 | :property: 20 | 21 | Calculate the loss. 22 | 23 | :param pred: The prediction tensor. 24 | :type pred: torch.Tensor 25 | :param gt: The ground-truth tensor. 26 | :type gt: torch.Tensor 27 | 28 | :returns: The loss. 29 | :rtype: torch.Tensor 30 | 31 | 32 | 33 | .. py:class:: CELoss 34 | 35 | Bases: :py:obj:`AbsLoss` 36 | 37 | The cross entropy loss function. 38 | 39 | 40 | .. py:method:: compute_loss(self, pred, gt) 41 | 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/metrics/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.metrics` 2 | ===================== 3 | 4 | .. py:module:: LibMTL.metrics 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsMetric 12 | 13 | Bases: :py:obj:`object` 14 | 15 | An abstract class for the performance metrics of a task. 16 | 17 | .. attribute:: record 18 | 19 | A list of the metric scores in every iteration. 20 | 21 | :type: list 22 | 23 | .. attribute:: bs 24 | 25 | A list of the number of samples in every iteration. 26 | 27 | :type: list 28 | 29 | .. py:method:: update_fun(self, pred, gt) 30 | :property: 31 | 32 | Calculate the metric scores in every iteration and update :attr:`record`. 33 | 34 | :param pred: The prediction tensor. 35 | :type pred: torch.Tensor 36 | :param gt: The ground-truth tensor. 37 | :type gt: torch.Tensor 38 | 39 | 40 | .. py:method:: score_fun(self) 41 | :property: 42 | 43 | Calculate the final score (when an epoch ends). 44 | 45 | :returns: A list of metric scores. 46 | :rtype: list 47 | 48 | 49 | .. py:method:: reinit(self) 50 | 51 | Reset :attr:`record` and :attr:`bs` (when an epoch ends). 52 | 53 | 54 | 55 | 56 | .. py:class:: AccMetric 57 | 58 | Bases: :py:obj:`AbsMetric` 59 | 60 | Calculate the accuracy. 61 | 62 | 63 | .. py:method:: update_fun(self, pred, gt) 64 | 65 | 66 | 67 | 68 | .. py:method:: score_fun(self) 69 | 70 | 71 | 72 | 73 | 74 | 75 |
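To illustrate the two interfaces documented above, here is a hypothetical L1-style loss/metric pair written against the ``AbsLoss`` and ``AbsMetric`` contracts. Method names and attributes follow the docs; this is an untested sketch, not the library's own implementation.

```python
import torch
from LibMTL.loss import AbsLoss
from LibMTL.metrics import AbsMetric

class L1Loss(AbsLoss):
    def compute_loss(self, pred, gt):
        # Mean absolute error between prediction and ground truth.
        return torch.abs(pred - gt).mean()

class L1Metric(AbsMetric):
    def update_fun(self, pred, gt):
        # Record the per-batch score and the batch size.
        self.record.append(torch.abs(pred - gt).mean().item())
        self.bs.append(pred.size(0))

    def score_fun(self):
        # Batch-size-weighted average over the epoch, returned as a list.
        records = torch.tensor(self.record)
        bs = torch.tensor(self.bs, dtype=torch.float)
        return [((records * bs).sum() / bs.sum()).item()]
```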
-------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/model/resnet_dilated/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.model.resnet_dilated` 2 | ================================== 3 | 4 | .. py:module:: LibMTL.model.resnet_dilated 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: ResnetDilated(orig_resnet, dilate_scale=8) 12 | 13 | Bases: :py:obj:`torch.nn.Module` 14 | 15 | Base class for all neural network modules. 16 | 17 | Your models should also subclass this class. 18 | 19 | Modules can also contain other Modules, allowing to nest them in 20 | a tree structure. You can assign the submodules as regular attributes:: 21 | 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | 25 | class Model(nn.Module): 26 | def __init__(self): 27 | super(Model, self).__init__() 28 | self.conv1 = nn.Conv2d(1, 20, 5) 29 | self.conv2 = nn.Conv2d(20, 20, 5) 30 | 31 | def forward(self, x): 32 | x = F.relu(self.conv1(x)) 33 | return F.relu(self.conv2(x)) 34 | 35 | Submodules assigned in this way will be registered, and will have their 36 | parameters converted too when you call :meth:`to`, etc. 37 | 38 | .. py:method:: forward(self, x) 39 | 40 | Defines the computation performed at every call. 41 | 42 | Should be overridden by all subclasses. 43 | 44 | .. note:: 45 | Although the recipe for forward pass needs to be defined within 46 | this function, one should call the :class:`Module` instance afterwards 47 | instead of this since the former takes care of running the 48 | registered hooks while the latter silently ignores them. 49 | 50 | 51 | .. py:method:: forward_stage(self, x, stage) 52 | 53 | 54 | 55 | .. py:function:: resnet_dilated(basenet, pretrained=True, dilate_scale=8) 56 | 57 | Dilated Residual Network models from `"Dilated Residual Networks" `_ 58 | 59 | :param basenet: The type of ResNet. 60 | :type basenet: str 61 | :param pretrained: If True, returns a model pre-trained on ImageNet. 62 | :type pretrained: bool 63 | :param dilate_scale: The dilation scale. 64 | :type dilate_scale: {8, 16}, default=8 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/utils/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.utils` 2 | =================== 3 | 4 | .. py:module:: LibMTL.utils 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:function:: set_random_seed(seed) 12 | 13 | Set the random seed for reproducibility. 14 | 15 | :param seed: The random seed. 16 | :type seed: int, default=0 17 | 18 | 19 | .. py:function:: set_device(gpu_id) 20 | 21 | Set the device where model and data will be allocated. 22 | 23 | :param gpu_id: The ID of the GPU. 24 | :type gpu_id: str, default='0' 25 | 26 | 27 | .. py:function:: count_parameters(model) 28 | 29 | Calculate the number of parameters for a model. 30 | 31 | :param model: A neural network module. 32 | :type model: torch.nn.Module 33 | 34 | 35 | .. py:function:: count_improvement(base_result, new_result, weight) 36 | 37 | Calculate the improvement between two results, 38 | 39 | .. math:: 40 | \Delta_{\mathrm{p}}=100\%\times \frac{1}{T}\sum_{t=1}^T 41 | \frac{1}{M_t}\sum_{m=1}^{M_t}\frac{(-1)^{w_{t,m}}(B_{t,m}-N_{t,m})}{N_{t,m}}. 42 | 43 | :param base_result: A dictionary of scores of all metrics of all tasks. 44 | :type base_result: dict 45 | :param new_result: The same structure as ``base_result``. 46 | :type new_result: dict 47 | :param weight: The same structure as ``base_result``, where each element is a binary integer representing whether a higher or lower score is better. 48 | :type weight: dict 49 | 50 | :returns: The improvement between ``new_result`` and ``base_result``. 51 | :rtype: float 52 | 53 | Examples:: 54 | 55 | base_result = {'A': [96, 98], 'B': [0.2]} 56 | new_result = {'A': [93, 99], 'B': [0.5]} 57 | weight = {'A': [1, 0], 'B': [1]} 58 | 59 | print(count_improvement(base_result, new_result, weight)) 60 | 61 | 62 | 63 |
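Working the example above through the formula: task A contributes ((-1)^1*(96-93)/93 + (-1)^0*(98-99)/99)/2 ≈ (-0.0323 - 0.0101)/2 ≈ -0.0212, and task B contributes (-1)^1*(0.2-0.5)/0.5 = 0.6, so Δp ≈ 100% × (-0.0212 + 0.6)/2 ≈ 28.9% — the new result is better overall despite the drop on task A's first metric. (Whether the returned float carries the 100% factor is an implementation detail the docstring leaves open.)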
-------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/CAGrad/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.CAGrad` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.weighting.CAGrad 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: CAGrad 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Conflict-Averse Gradient Descent (CAGrad). 16 | 17 | This method is proposed in `Conflict-Averse Gradient Descent for Multi-task learning (NeurIPS 2021) `_ \ 18 | and implemented by adapting the `official PyTorch implementation `_. 19 | 20 | :param calpha: A hyperparameter that controls the convergence rate. 21 | :type calpha: float, default=0.5 22 | :param rescale: The type of gradient rescaling. 23 | :type rescale: {0, 1, 2}, default=1 24 | 25 | .. warning:: 26 | CAGrad is not supported with representation gradients, i.e., ``rep_grad`` must be ``False``. 27 | 28 | 29 | .. py:method:: backward(self, losses, **kwargs) 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/DWA/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.DWA` 2 | =========================== 3 | 4 | .. py:module:: LibMTL.weighting.DWA 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: DWA 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Dynamic Weight Average (DWA). 16 | 17 | This method is proposed in `End-To-End Multi-Task Learning With Attention (CVPR 2019) `_ \ 18 | and implemented by adapting the `official PyTorch implementation `_. 19 | 20 | :param T: The softmax temperature. 21 | :type T: float, default=2.0 22 | 23 | .. py:method:: backward(self, losses, **kwargs) 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/EW/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.EW` 2 | ========================== 3 | 4 | .. py:module:: LibMTL.weighting.EW 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: EW 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Equally Weighting (EW). 16 | 17 | The loss weight for each task is always ``1 / T`` in every iteration, where ``T`` denotes the number of tasks. 18 | 19 | 20 | .. py:method:: backward(self, losses, **kwargs) 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/GLS/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GLS` 2 | =========================== 3 | 4 | .. py:module:: LibMTL.weighting.GLS 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GLS 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Geometric Loss Strategy (GLS).
16 | 17 | This method is proposed in `MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning (CVPR 2019 workshop) `_ \ 18 | and implemented by us. 19 | 20 | 21 | .. py:method:: backward(self, losses, **kwargs) 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/GradDrop/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GradDrop` 2 | ================================ 3 | 4 | .. py:module:: LibMTL.weighting.GradDrop 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GradDrop 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Gradient Sign Dropout (GradDrop). 16 | 17 | This method is proposed in `Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout (NeurIPS 2020) `_ \ 18 | and implemented by us. 19 | 20 | :param leak: The leak parameter for the weighting matrix. 21 | :type leak: float, default=0.0 22 | 23 | .. warning:: 24 | GradDrop is not supported with parameter gradients, i.e., ``rep_grad`` must be ``True``. 25 | 26 | 27 | .. py:method:: backward(self, losses, **kwargs) 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/GradNorm/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GradNorm` 2 | ================================ 3 | 4 | .. py:module:: LibMTL.weighting.GradNorm 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GradNorm 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Gradient Normalization (GradNorm). 16 | 17 | This method is proposed in `GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks (ICML 2018) `_ \ 18 | and implemented by us. 19 | 20 | :param alpha: The strength of the restoring force which pulls tasks back to a common training rate. 21 | :type alpha: float, default=1.5 22 | 23 | .. py:method:: init_param(self) 24 | 25 | 26 | .. py:method:: backward(self, losses, **kwargs) 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/GradVac/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GradVac` 2 | =============================== 3 | 4 | .. py:module:: LibMTL.weighting.GradVac 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GradVac 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Gradient Vaccine (GradVac). 16 | 17 | This method is proposed in `Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) `_ \ 18 | and implemented by us. 19 | 20 | :param beta: The exponential moving average (EMA) decay parameter. 21 | :type beta: float, default=0.5 22 | 23 | .. warning:: 24 | GradVac is not supported with representation gradients, i.e., ``rep_grad`` must be ``False``. 25 | 26 | 27 | .. py:method:: backward(self, losses, **kwargs) 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/IMTL/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.IMTL` 2 | ============================ 3 | 4 | .. 
py:module:: LibMTL.weighting.IMTL 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: IMTL 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Impartial Multi-task Learning (IMTL). 16 | 17 | This method is proposed in `Towards Impartial Multi-task Learning (ICLR 2021) `_ \ 18 | and implemented by us. 19 | 20 | 21 | .. py:method:: init_param(self) 22 | 23 | 24 | .. py:method:: backward(self, losses, **kwargs) 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/MGDA/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.MGDA` 2 | ============================ 3 | 4 | .. py:module:: LibMTL.weighting.MGDA 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: MGDA 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Multiple Gradient Descent Algorithm (MGDA). 16 | 17 | This method is proposed in `Multi-Task Learning as Multi-Objective Optimization (NeurIPS 2018) `_ \ 18 | and implemented by adapting the `official PyTorch implementation `_. 19 | 20 | :param mgda_gn: The type of gradient normalization. 21 | :type mgda_gn: {'none', 'l2', 'loss', 'loss+'}, default='none' 22 | 23 | .. py:method:: backward(self, losses, **kwargs) 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/PCGrad/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.PCGrad` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.weighting.PCGrad 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: PCGrad 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Projecting Conflicting Gradients (PCGrad). 16 | 17 | This method is proposed in `Gradient Surgery for Multi-Task Learning (NeurIPS 2020) `_ \ 18 | and implemented by us. 19 | 20 | .. warning:: 21 | PCGrad is not supported with representation gradients, i.e., ``rep_grad`` must be ``False``. 22 | 23 | 24 | .. py:method:: backward(self, losses, **kwargs) 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/RLW/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.RLW` 2 | =========================== 3 | 4 | .. py:module:: LibMTL.weighting.RLW 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: RLW 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Random Loss Weighting (RLW). 16 | 17 | This method is proposed in `A Closer Look at Loss Weighting in Multi-Task Learning (arXiv:2111.10603) `_ \ 18 | and implemented by us. 19 | 20 | :param dist: The type of distribution that the loss weights are sampled from. 21 | :type dist: {'Uniform', 'Normal', 'Dirichlet', 'Bernoulli', 'constrained_Bernoulli'}, default='Normal' 22 | 23 | .. py:method:: backward(self, losses, **kwargs) 24 | 25 | :param losses: A list of losses, one for each task. 26 | :type losses: list 27 | :param kwargs: A dictionary of hyperparameters of weighting methods.
28 | :type kwargs: dict 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/UW/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.UW` 2 | ========================== 3 | 4 | .. py:module:: LibMTL.weighting.UW 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: UW 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Uncertainty Weights (UW). 16 | 17 | This method is proposed in `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (CVPR 2018) `_ \ 18 | and implemented by us. 19 | 20 | 21 | .. py:method:: init_param(self) 22 | 23 | 24 | .. py:method:: backward(self, losses, **kwargs) 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/abstract_weighting/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.abstract_weighting` 2 | ========================================== 3 | 4 | .. py:module:: LibMTL.weighting.abstract_weighting 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsWeighting 12 | 13 | Bases: :py:obj:`torch.nn.Module` 14 | 15 | An abstract class for weighting strategies. 16 | 17 | 18 | .. py:method:: init_param(self) 19 | 20 | Define and initialize some trainable parameters required by specific weighting methods. 21 | 22 | 23 | 24 | .. py:method:: backward(self, losses, **kwargs) 25 | :property: 26 | 27 | :param losses: A list of losses, one for each task. 28 | :type losses: list 29 | :param kwargs: A dictionary of hyperparameters of weighting methods. 30 | :type kwargs: dict 31 | 32 | 33 | 34 | 35 |
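The ``AbsWeighting`` contract above is small enough to show end to end. Below is a hypothetical random-weighting strategy in the spirit of RLW, written against the documented :func:`backward` interface — an untested sketch, not the library's own implementation:

```python
import torch
from LibMTL.weighting import AbsWeighting

class RandomWeighting(AbsWeighting):
    """Hypothetical strategy: draw fresh random loss weights every iteration."""
    def backward(self, losses, **kwargs):
        # Sample weights from a softmax over standard-normal draws, then
        # backpropagate the weighted sum of the task losses.
        weights = torch.softmax(torch.randn(len(losses)), dim=-1).to(losses[0].device)
        total_loss = sum(w * l for w, l in zip(weights, losses))
        total_loss.backward()
```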
-------------------------------------------------------------------------------- /docs/docs/develop/arch.md: -------------------------------------------------------------------------------- 1 | ## Customize an Architecture 2 | 3 | Here we introduce how to customize a new architecture with the support of ``LibMTL``. 4 | 5 | ### Create a New Architecture Class 6 | 7 | ```eval_rst 8 | Firstly, you need to create a new architecture class by inheriting from :class:`LibMTL.architecture.AbsArchitecture`. 9 | ``` 10 | 11 | ```python 12 | from LibMTL.architecture import AbsArchitecture 13 | 14 | class NewArchitecture(AbsArchitecture): 15 | def __init__(self, task_name, encoder_class, decoders, rep_grad, 16 | multi_input, device, **kwargs): 17 | super(NewArchitecture, self).__init__(task_name, encoder_class, decoders, rep_grad, 18 | multi_input, device, **kwargs) 19 | ``` 20 | 21 | ### Rewrite Relevant Methods 22 | 23 | ```eval_rst 24 | There are four important functions in :class:`LibMTL.architecture.AbsArchitecture`. 25 | 26 | - :func:`forward`: The forward function and its input/output format can be found in :func:`LibMTL.architecture.AbsArchitecture.forward`. To rewrite this function, you need to consider the case of ``single-input`` and ``multi-input`` (refer to `here <../user_guide/mtl.html#network-architecture>`_) and the case of ``rep-grad`` and ``param-grad`` (refer to `here <../user_guide/mtl.html#weighting-strategy>`_) if you want to combine your architecture with more weighting strategies or apply your architecture to more datasets. 27 | - :func:`get_share_params`: This function is used to return the shared parameters of the model. It returns all the parameters of the encoder by default. You can rewrite it if necessary. 28 | - :func:`zero_grad_share_params`: This function is used to set gradients of the shared parameters to zero. It will set the gradients of all the encoder parameters to zero by default. You can rewrite it if necessary. 29 | - :func:`_prepare_rep`: This function is used to compute the gradients for representations. More details can be found `here <../../_modules/LibMTL/architecture/abstract_arch.html#AbsArchitecture>`_. 30 | ``` 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/docs/develop/weighting.md: -------------------------------------------------------------------------------- 1 | ## Customize a Weighting Strategy 2 | 3 | Here we introduce how to customize a new weighting strategy with the support of ``LibMTL``. 4 | 5 | ### Create a New Weighting Class 6 | 7 | ```eval_rst 8 | Firstly, you need to create a new weighting class by inheriting from :class:`LibMTL.weighting.AbsWeighting`. 9 | ``` 10 | 11 | ```python 12 | from LibMTL.weighting import AbsWeighting 13 | 14 | class NewWeighting(AbsWeighting): 15 | def __init__(self): 16 | super(NewWeighting, self).__init__() 17 | ``` 18 | 19 | ### Rewrite Relevant Methods 20 | 21 | ```eval_rst 22 | There are four important functions in :class:`LibMTL.weighting.AbsWeighting`. 23 | 24 | - :func:`backward`: It is the main function of a weighting strategy, whose input and output formats can be found in :func:`LibMTL.weighting.AbsWeighting.backward`. To rewrite this function, you need to consider the case of ``single-input`` and ``multi-input`` (refer to `here <../user_guide/mtl.html#network-architecture>`_) and the case of ``rep-grad`` and ``param-grad`` (refer to `here <../user_guide/mtl.html#weighting-strategy>`_) if you want to combine your weighting method with more architectures or apply your method to more datasets. 25 | - :func:`init_param`: This function is used to define and initialize some trainable parameters. It does nothing by default and can be rewritten if necessary. 26 | - :func:`_get_grads`: This function is used to return the gradients of representations or shared parameters (corresponding to the case of ``rep-grad`` and ``param-grad``, respectively). 27 | - :func:`_backward_new_grads`: This function is used to reset the gradients and make a backward pass (corresponding to the case of ``rep-grad`` and ``param-grad``, respectively). 28 | 29 | The :func:`_get_grads` and :func:`_backward_new_grads` functions are very useful when rewriting the :func:`backward` function, and you can find more details `here <../../_modules/LibMTL/weighting/abstract_weighting.html#AbsWeighting>`_. 30 | ``` 31 | 32 | -------------------------------------------------------------------------------- /docs/docs/getting_started/installation.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | ### Dependencies 4 | 5 | To install ``LibMTL``, you need to set up the following libraries: 6 | 7 | - Python >= 3.7 8 | - torch >= 1.8.0 9 | - torchvision >= 0.9.0 10 | - numpy >= 1.20 11 | 12 | ### User Installation 13 | 14 | * Create a virtual environment 15 | 16 | ```shell 17 | conda create -n libmtl python=3.8 18 | conda activate libmtl 19 | pip install torch==1.8.0 torchvision==0.9.0 numpy==1.20 20 | ``` 21 | 22 | * Clone the repository 23 | 24 | ```shell 25 | git clone https://github.com/median-research-group/LibMTL.git 26 | ``` 27 | 28 | * Install `LibMTL` 29 | 30 | ```shell 31 | pip install -e . 32 | ```
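To confirm the editable install succeeded, you can optionally check that the package imports cleanly (a quick sanity check, not part of the original instructions):

```shell
python -c "import LibMTL"
```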
-------------------------------------------------------------------------------- /docs/docs/images/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/docs/images/framework.png -------------------------------------------------------------------------------- /docs/docs/images/multi_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/docs/images/multi_input.png -------------------------------------------------------------------------------- /docs/docs/images/rep_grad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/docs/docs/images/rep_grad.png -------------------------------------------------------------------------------- /docs/docs/references.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{silberman2012indoor, 2 | title={Indoor segmentation and support inference from rgbd images}, 3 | author={Silberman, Nathan and Hoiem, Derek and Kohli, Pushmeet and Fergus, Rob}, 4 | booktitle={Proceedings of the 12th European Conference on Computer Vision}, 5 | pages={746--760}, 6 | year={2012}, 7 | } 8 | 9 | @inproceedings{ljd19, 10 | author = {Shikun Liu and 11 | Edward Johns and 12 | Andrew J. Davison}, 13 | title = {End-To-End Multi-Task Learning With Attention}, 14 | booktitle = {Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition}, 15 | pages = {1871--1880}, 16 | year = {2019} 17 | } 18 | 19 | @inproceedings{YuKF17, 20 | author = {Fisher Yu and 21 | Vladlen Koltun and 22 | Thomas A.
Funkhouser}, 23 | title = {Dilated Residual Networks}, 24 | booktitle = {Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition}, 25 | pages = {636--644}, 26 | year = {2017}, 27 | } 28 | 29 | @inproceedings{ChenZPSA18, 30 | author = {Liang{-}Chieh Chen and 31 | Yukun Zhu and 32 | George Papandreou and 33 | Florian Schroff and 34 | Hartwig Adam}, 35 | title = {Encoder-Decoder with Atrous Separable Convolution for Semantic Image 36 | Segmentation}, 37 | booktitle = {Proceedings of the 15th European Conference on Computer Vision}, 38 | volume = {11211}, 39 | pages = {833--851}, 40 | year = {2018}, 41 | } 42 | 43 | @article{lin2021rlw, 44 | title={A Closer Look at Loss Weighting in Multi-Task Learning}, 45 | author={Lin, Baijiong and Ye, Feiyang and Zhang, Yu}, 46 | journal={arXiv preprint arXiv:2111.10603}, 47 | year={2021} 48 | } 49 | 50 | @inproceedings{saenko2010adapting, 51 | title={Adapting visual category models to new domains}, 52 | author={Saenko, Kate and Kulis, Brian and Fritz, Mario and Darrell, Trevor}, 53 | booktitle={Proceedings of the 11th European Conference on Computer Vision}, 54 | pages={213--226}, 55 | year={2010}, 56 | } 57 | 58 | @inproceedings{venkateswara2017deep, 59 | title={Deep hashing network for unsupervised domain adaptation}, 60 | author={Venkateswara, Hemanth and Eusebio, Jose and Chakraborty, Shayok and Panchanathan, Sethuraman}, 61 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 62 | pages={5018--5027}, 63 | year={2017} 64 | } 65 | 66 | @article{ZhangY21, 67 | author = {Yu Zhang and 68 | Qiang Yang}, 69 | title = {A Survey on Multi-Task Learning}, 70 | journal = {{IEEE} Transactions on Knowledge and Data Engineering}, 71 | year = {2021}, 72 | } 73 | 74 | @article{Vandenhende21, 75 | author = {Simon Vandenhende and Stamatios Georgoulis and Wouter Van Gansbeke and Marc Proesmans and Dengxin Dai and Luc Van Gool}, 76 | title = {Multi-Task Learning for Dense Prediction Tasks: A Survey}, 77 | journal = {{IEEE} Transactions on Pattern Analysis and Machine Intelligence}, 78 | year = {2021}, 79 | } 80 | 81 | @article{Michael20, 82 | title={Multi-Task Learning with Deep Neural Networks: A Survey}, 83 | author={Michael Crawshaw}, 84 | journal={arXiv preprint arXiv:2009.09796}, 85 | year={2020} 86 | } -------------------------------------------------------------------------------- /docs/docs/user_guide/benchmark.rst: -------------------------------------------------------------------------------- 1 | Run a Benchmark 2 | =============== 3 | 4 | Here we introduce some MTL benchmark datasets and show how to run models on them for a fair comparison. 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | benchmark/nyuv2 10 | benchmark/office 11 | -------------------------------------------------------------------------------- /docs/docs/user_guide/benchmark/office.md: -------------------------------------------------------------------------------- 1 | ## Office-31 and Office-Home 2 | 3 | ```eval_rst 4 | The Office-31 dataset :cite:`saenko2010adapting` consists of three classification tasks on three domains: Amazon, DSLR, and Webcam, where each task has 31 object categories. It can be downloaded `here `_. This dataset contains 4,110 labeled images and we randomly split these samples, with 60\% for training, 20\% for validation, and the remaining 20\% for testing.
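A split with these proportions can be reproduced with ``torch.utils.data.random_split`` (a hypothetical sketch, where ``dataset`` stands for any :class:`torch.utils.data.Dataset` over the images; the repository's actual split logic lives in ``examples/office/create_dataset.py``)::

    import torch
    from torch.utils.data import random_split

    n = len(dataset)                        # 4,110 labeled images for Office-31
    n_train, n_val = int(0.6 * n), int(0.2 * n)
    train_set, val_set, test_set = random_split(
        dataset, [n_train, n_val, n - n_train - n_val],
        generator=torch.Generator().manual_seed(0))  # fixed seed for reproducibility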
5 | 6 | The Office-Home dataset :cite:`venkateswara2017deep` has four classification tasks on four domains: Artistic images (abbreviated as Art), Clip art, Product images, and Real-world images. It can be downloaded `here `_. This dataset has 15,500 labeled images in total and each domain contains 65 classes. We divide the entire dataset into the same proportions as the Office-31 dataset. 7 | 8 | Both datasets belong to the multi-input setting in MTL. Thus, ``multi_input`` must be ``True`` for both office datasets. 9 | 10 | The training code is available in ``examples/office``. We use a ResNet-18 network pretrained on the ImageNet dataset, followed by a fully connected layer, as the shared encoder among tasks, and a fully connected layer is applied as the task-specific output layer for each task. All the input images are resized to 3x224x224. 11 | ``` 12 | 13 | ### Run a Model 14 | 15 | The script ``train_office.py`` is the main file for training and evaluating an MTL model on the Office-31 or Office-Home dataset. A set of command-line arguments is provided to allow users to adjust the training parameter configuration. 16 | 17 | Some important arguments are described as follows. 18 | 19 | ```eval_rst 20 | - ``weighting``: The weighting strategy. Refer to `here <../_autoapi/LibMTL/weighting/index.html>`_. 21 | - ``arch``: The MTL architecture. Refer to `here <../_autoapi/LibMTL/architecture/index.html>`_. 22 | - ``gpu_id``: The ID of the GPU. The default value is '0'. 23 | - ``seed``: The random seed for reproducibility. The default value is 0. 24 | - ``optim``: The type of the optimizer. We recommend using 'adam' here. 25 | - ``dataset``: Training on Office-31 or Office-Home. Options: 'office-31', 'office-home'. 26 | - ``dataset_path``: The path of the Office-31 or Office-Home dataset. 27 | - ``bs``: The batch size of training, validation, and test data. The default value is 64. 28 | ``` 29 | 30 | The complete command-line arguments and their descriptions can be found by running the following command. 31 | 32 | ```shell 33 | python train_office.py -h 34 | ``` 35 | 36 | Once you understand these command-line arguments, you can train an MTL model by running a command like this. 37 | 38 | ```shell 39 | python train_office.py --weighting WEIGHTING --arch ARCH --dataset_path PATH --gpu_id GPU_ID --multi_input 40 | ``` 41 | 42 | ### References 43 | 44 | ```eval_rst 45 | .. bibliography:: 46 | :style: unsrt 47 | :filter: docname in docnames 48 | ``` 49 | 50 |
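For concreteness, one full invocation that combines only the arguments documented above might look like this (the dataset path is a placeholder):

```shell
python train_office.py --weighting EW --arch HPS \
    --dataset office-31 --dataset_path /path/to/office-31 \
    --optim adam --bs 64 --seed 0 --gpu_id 0 --multi_input
```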
-------------------------------------------------------------------------------- /docs/docs/user_guide/framework.md: -------------------------------------------------------------------------------- 1 | ## Overall Framework 2 | 3 | ``LibMTL`` provides a unified framework to train an MTL model with several architectures and weighting strategies on benchmark datasets. The overall framework consists of nine modules as introduced below. 4 | 5 | ```eval_rst 6 | - The Dataloader module is responsible for data pre-processing and loading. 7 | 8 | - The `LibMTL.loss <../_autoapi/LibMTL/loss/index.html>`_ module defines loss functions for each task. 9 | 10 | - The `LibMTL.metrics <../_autoapi/LibMTL/metrics/index.html>`_ module defines evaluation metrics for all the tasks. 11 | 12 | - The `LibMTL.config <../_autoapi/LibMTL/config/index.html>`_ module is responsible for all the configuration parameters involved in the training process, such as the corresponding MTL setting (i.e., the multi-input case or not), the potential hyper-parameters of loss weighting strategies and architectures, the training configuration (e.g., the batch size, the running epoch, the random seed, and the learning rate), and so on. This module adopts command-line arguments to enable users to conveniently set those configuration parameters. 13 | 14 | - The `LibMTL.Trainer <../_autoapi/LibMTL/trainer/index.html>`_ module provides a unified framework for the training process under different MTL settings and for different MTL approaches. 15 | 16 | - The `LibMTL.utils <../_autoapi/LibMTL/utils/index.html>`_ module implements some useful functionalities for the training process, such as calculating the total number of parameters in an MTL model. 17 | 18 | - The `LibMTL.architecture <../_autoapi/LibMTL/architecture/index.html>`_ module contains the implementations of various architectures in MTL. 19 | 20 | - The `LibMTL.weighting <../_autoapi/LibMTL/weighting/index.html>`_ module contains the implementations of various loss weighting strategies in MTL. 21 | 22 | - The `LibMTL.model <../_autoapi/LibMTL/model/index.html>`_ module includes some popular backbone networks (e.g., ResNet). 23 | ``` 24 | 25 | ```eval_rst 26 | .. figure:: ../images/framework.png 27 | :scale: 50% 28 | ``` 29 | 30 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. LibMTL documentation master file, created by 2 | sphinx-quickstart on Thu Nov 25 17:02:04 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | LibMTL: A PyTorch Library for Multi-Task Learning 7 | ================================================= 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :caption: Getting Started: 12 | 13 | docs/getting_started/introduction 14 | docs/getting_started/installation 15 | docs/getting_started/quick_start 16 | 17 | .. toctree:: 18 | :maxdepth: 1 19 | :caption: User Guide: 20 | 21 | docs/user_guide/mtl 22 | docs/user_guide/framework 23 | docs/user_guide/benchmark 24 | 25 | .. toctree:: 26 | :maxdepth: 1 27 | :caption: Developer Guide: 28 | 29 | docs/develop/dataset 30 | docs/develop/arch 31 | docs/develop/weighting 32 | 33 | ..
toctree:: 34 | :maxdepth: 1 35 | :caption: API Reference: 36 | 37 | docs/_autoapi/LibMTL/index 38 | docs/_autoapi/LibMTL/loss/index 39 | docs/_autoapi/LibMTL/utils/index 40 | docs/_autoapi/LibMTL/model/index 41 | docs/_autoapi/LibMTL/config/index 42 | docs/_autoapi/LibMTL/metrics/index 43 | docs/_autoapi/LibMTL/weighting/index 44 | docs/_autoapi/LibMTL/architecture/index 45 | 46 | 47 | 48 | Indices and tables 49 | ================== 50 | 51 | * :ref:`genindex` 52 | * :ref:`modindex` 53 | * :ref:`search` 54 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | Sphinx==5.1.1 2 | recommonmark==0.7.1 3 | sphinx-autoapi==1.8.4 4 | sphinx-autobuild==2021.3.14 5 | sphinx-markdown-tables==0.0.17 6 | sphinx-rtd-theme==1.0.0 7 | sphinxcontrib-applehelp==1.0.2 8 | sphinxcontrib-bibtex==2.4.1 9 | sphinxcontrib-devhelp==1.0.2 10 | sphinxcontrib-htmlhelp==2.0.0 11 | sphinxcontrib-jsmath==1.0.1 12 | sphinxcontrib-qthelp==1.0.3 13 | sphinxcontrib-serializinghtml==1.1.5 14 | sphinxcontrib-websupport==1.2.3 15 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | ## Run a Benchmark 2 | 3 | Here we introduce some MTL benchmark datasets and show how to run models on them for a fair comparison. 4 | 5 | - [The NYUv2 Dataset](https://github.com/median-research-group/LibMTL/tree/main/examples/nyu) 6 | - [The Office-31 and Office-Home Datasets](https://github.com/median-research-group/LibMTL/tree/main/examples/office) 7 | - [The QM9 Dataset](https://github.com/median-research-group/LibMTL/tree/main/examples/qm9) 8 | - [The PAWS-X Dataset from XTREME Benchmark](https://github.com/median-research-group/LibMTL/tree/main/examples/xtreme) 9 | 10 | 11 | -------------------------------------------------------------------------------- /examples/nyusp/__pycache__/aspp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/examples/nyusp/__pycache__/aspp.cpython-38.pyc -------------------------------------------------------------------------------- /examples/nyusp/__pycache__/create_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/examples/nyusp/__pycache__/create_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /examples/nyusp/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/examples/nyusp/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /examples/nyusp/aspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DeepLabHead(nn.Sequential): 7 | def __init__(self, in_channels, num_classes): 8 | super(DeepLabHead, self).__init__( 9 | ASPP(in_channels, [12, 24, 36]), 10 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 11 | 
nn.BatchNorm2d(256), 12 | nn.ReLU(), 13 | nn.Conv2d(256, num_classes, 1) 14 | ) 15 | 16 | 17 | class ASPPConv(nn.Sequential): 18 | def __init__(self, in_channels, out_channels, dilation): 19 | modules = [ 20 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 21 | nn.BatchNorm2d(out_channels), 22 | nn.ReLU() 23 | ] 24 | super(ASPPConv, self).__init__(*modules) 25 | 26 | 27 | class ASPPPooling(nn.Sequential): 28 | def __init__(self, in_channels, out_channels): 29 | super(ASPPPooling, self).__init__( 30 | nn.AdaptiveAvgPool2d(1), 31 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 32 | nn.BatchNorm2d(out_channels), 33 | nn.ReLU()) 34 | 35 | def forward(self, x): 36 | size = x.shape[-2:] 37 | x = super(ASPPPooling, self).forward(x) 38 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 39 | 40 | 41 | class ASPP(nn.Module): 42 | def __init__(self, in_channels, atrous_rates): 43 | super(ASPP, self).__init__() 44 | out_channels = 256 45 | modules = [] 46 | modules.append(nn.Sequential( 47 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 48 | nn.BatchNorm2d(out_channels), 49 | nn.ReLU())) 50 | 51 | rate1, rate2, rate3 = tuple(atrous_rates) 52 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 53 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 54 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 55 | modules.append(ASPPPooling(in_channels, out_channels)) 56 | 57 | self.convs = nn.ModuleList(modules) 58 | 59 | self.project = nn.Sequential( 60 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 61 | nn.BatchNorm2d(out_channels), 62 | nn.ReLU(), 63 | nn.Dropout(0.5)) 64 | 65 | def forward(self, x): 66 | res = [] 67 | for conv in self.convs: 68 | res.append(conv(x)) 69 | res = torch.cat(res, dim=1) 70 | return self.project(res) 71 | -------------------------------------------------------------------------------- /examples/nyusp/run.sh: -------------------------------------------------------------------------------- 1 | mkdir -p logs 2 | 3 | GPU=0 4 | seed=0 5 | 6 | 7 | weighting=GeMTL 8 | # Arithmetic, GLS, UW, DWA, RLW, GradNorm, SI, IMTL_L, LSBwD, LSBwoD, AMTL, GeMTL 9 | 10 | arch=HPS 11 | 12 | python main.py \ 13 | --weighting ${weighting} \ 14 | --arch ${arch} \ 15 | --dataset_path /dataset/nyuv2 \ 16 | --gpu_id ${GPU} \ 17 | --seed ${seed} \ 18 | --scheduler step \ 19 | --mode train 20 | 21 | #>> logs/${arch}_${weighting}_seed${seed}.txt -------------------------------------------------------------------------------- /examples/office/__pycache__/create_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvlab-yonsei/Multi-Task-Learning/88a5e91af67375454cdfdb8dec285cc3ef0bbcfe/examples/office/__pycache__/create_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /examples/office/create_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, Dataset 2 | import os 3 | import torch 4 | import torch.nn.functional as F 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | from LibMTL.utils import get_root_dir 8 | 9 | class office_Dataset(Dataset): 10 | def __init__(self, dataset, root_path, task, mode): 11 | self.transform = transforms.Compose([ 12 | transforms.Resize((224, 224)), 13 | transforms.ToTensor(), 14 | transforms.Normalize(mean=[0.485, 
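# ImageNet channel means/stds -- the standard normalization when fine-tuning
# ImageNet-pretrained backbones such as the ResNets used elsewhere in LibMTL.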
0.456, 0.406], std=[0.229, 0.224, 0.225]), 15 | ]) 16 | with open(os.path.join(get_root_dir(), 'examples/office', 'data_txt/{}/{}_{}.txt'.format(dataset, task, mode)), 'r') as f: 17 | self.img_list = f.readlines() 18 | 19 | self.root_path = root_path 20 | 21 | def __getitem__(self, i): 22 | img_path = self.img_list[i].strip().split(' ')[0] # each line: '<image path> <class index>' 23 | y = int(self.img_list[i].strip().split(' ')[1]) 24 | img = Image.open(os.path.join(self.root_path, img_path)).convert('RGB') 25 | return self.transform(img), y 26 | 27 | def __len__(self): 28 | return len(self.img_list) 29 | 30 | def office_dataloader(dataset, batchsize, root_path): 31 | if dataset == 'office-31': 32 | tasks = ['amazon', 'dslr', 'webcam'] 33 | elif dataset == 'office-home': 34 | tasks = ['Art', 'Clipart', 'Product', 'Real_World'] 35 | data_loader = {} 36 | iter_data_loader = {} # iterators over each loader, for manual batch-by-batch sampling 37 | for k, d in enumerate(tasks): 38 | data_loader[d] = {} 39 | iter_data_loader[d] = {} 40 | for mode in ['train', 'val', 'test']: 41 | shuffle = (mode == 'train') # shuffle and drop the last incomplete batch only during training 42 | drop_last = (mode == 'train') 43 | txt_dataset = office_Dataset(dataset, root_path, d, mode) 44 | # print(d, mode, len(txt_dataset)) 45 | data_loader[d][mode] = DataLoader(txt_dataset, 46 | num_workers=2, 47 | pin_memory=True, 48 | batch_size=batchsize, 49 | shuffle=shuffle, 50 | drop_last=drop_last) 51 | iter_data_loader[d][mode] = iter(data_loader[d][mode]) 52 | return data_loader, iter_data_loader 53 | -------------------------------------------------------------------------------- /examples/office/run.sh: -------------------------------------------------------------------------------- 1 | mkdir -p logs 2 | 3 | GPU=0 4 | seed=0 5 | 6 | weighting=GeMTL 7 | # Arithmetic, GLS, UW, DWA, RLW, GradNorm, SI, IMTL_L, LSBwD, LSBwoD, AMTL, GeMTL 8 | 9 | arch=HPS 10 | 11 | python main.py \ 12 | --weighting ${weighting} \ 13 | --arch ${arch} \ 14 | --dataset office-home \ 15 | --dataset_path /dataset/Office-Home \ 16 | --gpu_id ${GPU} \ 17 | --seed ${seed} \ 18 | --scheduler step \ 19 | --mode train 20 | 21 | #>> logs/${arch}_${weighting}_seed${seed}.txt 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | clarabel==0.6.0 2 | cvxpy==1.4.1 3 | ecos==2.0.12 4 | numpy==1.24.4 5 | osqp==0.6.3 6 | Pillow==10.1.0 7 | pybind11==2.11.1 8 | qdldl==0.1.7.post0 9 | scipy==1.10.1 10 | scs==3.2.3 11 | torch==1.8.1+cu111 12 | torchvision==0.9.1+cu111 13 | typing_extensions==4.8.0 14 | torch_geometric==2.2.0 15 | torch_sparse==0.6.10 16 | torch_scatter==2.0.8 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('README.md', 'r', encoding='utf-8') as f: 4 | long_description = f.read() 5 | 6 | setup( 7 | name='LibMTL', 8 | version='1.1.5', 9 | description='A PyTorch Library for Multi-Task Learning', 10 | author='Baijiong Lin', 11 | author_email='linbj@mail.sustech.edu.cn', 12 | url='https://github.com/median-research-group/LibMTL', 13 | packages=find_packages(), 14 | license='MIT', 15 | platforms=["all"], 16 | classifiers=['Intended Audience :: Developers', 17 | 'Intended Audience :: Education', 18 | 'Intended Audience :: Science/Research', 19 | 'License :: OSI Approved :: MIT License', 20 | 'Programming Language :: Python :: 3.7', 21 | 'Programming Language :: Python :: 
3.8', 22 | 'Programming Language :: Python :: 3.9', 23 | 'Programming Language :: Python :: 3.10', 24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 25 | 'Topic :: Scientific/Engineering :: Mathematics', 26 | 'Topic :: Software Development :: Libraries',], 27 | long_description=long_description, 28 | long_description_content_type='text/markdown', 29 | install_requires=['torch>=1.8.0', 30 | 'torchvision>=0.9.0', 31 | 'numpy>=1.20'] 32 | ) 33 | 34 | --------------------------------------------------------------------------------
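For reference, here is a minimal sketch of how the pieces above fit together: the office dataloaders from examples/office/create_dataset.py feeding LibMTL's Trainer. It follows LibMTL's documented quick-start API (Trainer, CELoss, AccMetric, EW, HPS, resnet18), but the Encoder class, the linear decoders, and all hyper-parameter values are illustrative assumptions for this example, not code from this repository:

import torch
import torch.nn as nn
from LibMTL import Trainer
from LibMTL.loss import CELoss
from LibMTL.metrics import AccMetric
from LibMTL.weighting import EW
from LibMTL.architecture import HPS
from LibMTL.model import resnet18
from create_dataset import office_dataloader  # the office example's loader factory

# One classification task per Office-31 domain; 'weight': [1] marks a metric
# where higher is better.
task_dict = {task: {'metrics': ['Acc'],
                    'metrics_fn': AccMetric(),
                    'loss_fn': CELoss(),
                    'weight': [1]}
             for task in ['amazon', 'dslr', 'webcam']}

data_loader, _ = office_dataloader('office-31', batchsize=64,
                                   root_path='/dataset/office-31')
train_loaders = {t: data_loader[t]['train'] for t in task_dict}
test_loaders = {t: data_loader[t]['test'] for t in task_dict}

# Hypothetical shared encoder: LibMTL's ResNet-18 trunk followed by global
# average pooling, giving a 512-d representation per image.
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.resnet = resnet18(pretrained=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        return torch.flatten(self.avgpool(self.resnet(x)), 1)

# One linear head per task; Office-31 has 31 classes.
decoders = nn.ModuleDict({t: nn.Linear(512, 31) for t in task_dict})

trainer = Trainer(task_dict=task_dict,
                  weighting=EW,         # equal loss weighting
                  architecture=HPS,     # hard parameter sharing
                  encoder_class=Encoder,
                  decoders=decoders,
                  rep_grad=False,
                  multi_input=True,     # each task has its own input stream
                  optim_param={'optim': 'adam', 'lr': 1e-4, 'weight_decay': 1e-5},
                  scheduler_param={'scheduler': 'step', 'step_size': 100, 'gamma': 0.5})
trainer.train(train_loaders, test_loaders, epochs=100)

Swapping the weighting or architecture (e.g., GeMTL or MTAN, as in the run.sh scripts above) only changes the classes passed to Trainer; the task dict, loaders, encoder, and decoders stay the same.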