├── .gitattributes ├── .readthedocs.yml ├── .travis.yml ├── LICENSE ├── LibMTL ├── __init__.py ├── _record.py ├── architecture │ ├── CGC.py │ ├── Cross_stitch.py │ ├── DSelect_k.py │ ├── HPS.py │ ├── LTB.py │ ├── MMoE.py │ ├── MTAN.py │ ├── PLE.py │ ├── __init__.py │ └── abstract_arch.py ├── config.py ├── loss.py ├── metrics.py ├── model │ ├── __init__.py │ ├── resnet.py │ └── resnet_dilated.py ├── trainer.py ├── utils.py └── weighting │ ├── Aligned_MTL.py │ ├── CAGrad.py │ ├── DB_MTL.py │ ├── DWA.py │ ├── EW.py │ ├── ExcessMTL.py │ ├── FAMO.py │ ├── FairGrad.py │ ├── GLS.py │ ├── GradDrop.py │ ├── GradNorm.py │ ├── GradVac.py │ ├── IMTL.py │ ├── MGDA.py │ ├── MoCo.py │ ├── MoDo.py │ ├── Nash_MTL.py │ ├── PCGrad.py │ ├── RLW.py │ ├── SDMGrad.py │ ├── STCH.py │ ├── UPGrad.py │ ├── UW.py │ ├── __init__.py │ └── abstract_weighting.py ├── README.md ├── docs ├── Makefile ├── README.md ├── _build │ ├── doctrees │ │ ├── README.doctree │ │ ├── autoapi_templates │ │ │ └── python │ │ │ │ └── module.doctree │ │ ├── docs │ │ │ ├── _autoapi │ │ │ │ └── LibMTL │ │ │ │ │ ├── _record │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── architecture │ │ │ │ │ ├── CGC │ │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── Cross_stitch │ │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── DSelect_k │ │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── HPS │ │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── MMoE │ │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── MTAN │ │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── PLE │ │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── abstract_arch │ │ │ │ │ │ └── index.doctree │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── config │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── index.doctree │ │ │ │ │ ├── loss │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── metrics │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── model │ │ │ │ │ ├── index.doctree │ │ │ │ │ ├── resnet │ │ │ │ │ │ └── index.doctree │ │ │ │ │ └── resnet_dilated │ │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── trainer │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── utils │ │ │ │ │ └── 
index.doctree │ │ │ │ │ └── weighting │ │ │ │ │ ├── CAGrad │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── DWA │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── EW │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── GLS │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── GradDrop │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── GradNorm │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── GradVac │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── IMTL │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── MGDA │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── PCGrad │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── RLW │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── UW │ │ │ │ │ └── index.doctree │ │ │ │ │ ├── abstract_weighting │ │ │ │ │ └── index.doctree │ │ │ │ │ └── index.doctree │ │ │ ├── develop │ │ │ │ ├── arch.doctree │ │ │ │ ├── dataset.doctree │ │ │ │ └── weighting.doctree │ │ │ ├── getting_started │ │ │ │ ├── installation.doctree │ │ │ │ ├── introduction.doctree │ │ │ │ └── quick_start.doctree │ │ │ └── user_guide │ │ │ │ ├── benchmark.doctree │ │ │ │ ├── benchmark │ │ │ │ ├── nyuv2.doctree │ │ │ │ └── office.doctree │ │ │ │ ├── framework.doctree │ │ │ │ └── mtl.doctree │ │ ├── environment.pickle │ │ └── index.doctree │ └── html │ │ ├── .buildinfo │ │ ├── README.html │ │ ├── _images │ │ ├── framework.png │ │ ├── multi_input.png │ │ └── rep_grad.png │ │ ├── _modules │ │ ├── LibMTL │ │ │ ├── architecture │ │ │ │ ├── CGC.html │ │ │ │ ├── Cross_stitch.html │ │ │ │ ├── DSelect_k.html │ │ │ │ ├── HPS.html │ │ │ │ ├── MMoE.html │ │ │ │ ├── MTAN.html │ │ │ │ ├── PLE.html │ │ │ │ └── abstract_arch.html │ │ │ ├── config.html │ │ │ ├── loss.html │ │ │ ├── metrics.html │ │ │ ├── model │ │ │ │ ├── resnet.html │ │ │ │ └── resnet_dilated.html │ │ │ ├── trainer.html │ │ │ ├── utils.html │ │ │ └── weighting │ │ │ │ ├── CAGrad.html │ │ │ │ ├── DWA.html │ │ │ │ ├── EW.html │ │ │ │ ├── GLS.html │ │ │ │ ├── GradDrop.html │ │ │ │ ├── GradNorm.html │ │ │ │ ├── GradVac.html │ │ │ │ ├── IMTL.html │ │ │ │ ├── MGDA.html │ │ │ │ ├── PCGrad.html │ │ │ │ ├── RLW.html │ │ │ │ 
├── UW.html │ │ │ │ └── abstract_weighting.html │ │ └── index.html │ │ ├── _sources │ │ ├── README.md.txt │ │ ├── autoapi_templates │ │ │ └── python │ │ │ │ └── module.rst.txt │ │ ├── docs │ │ │ ├── _autoapi │ │ │ │ └── LibMTL │ │ │ │ │ ├── _record │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── architecture │ │ │ │ │ ├── CGC │ │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── Cross_stitch │ │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── DSelect_k │ │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── HPS │ │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── MMoE │ │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── MTAN │ │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── PLE │ │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── abstract_arch │ │ │ │ │ │ └── index.rst.txt │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── config │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── index.rst.txt │ │ │ │ │ ├── loss │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── metrics │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── model │ │ │ │ │ ├── index.rst.txt │ │ │ │ │ ├── resnet │ │ │ │ │ │ └── index.rst.txt │ │ │ │ │ └── resnet_dilated │ │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── trainer │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── utils │ │ │ │ │ └── index.rst.txt │ │ │ │ │ └── weighting │ │ │ │ │ ├── CAGrad │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── DWA │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── EW │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── GLS │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── GradDrop │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── GradNorm │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── GradVac │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── IMTL │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── MGDA │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── PCGrad │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── RLW │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── UW │ │ │ │ │ └── index.rst.txt │ │ │ │ │ ├── abstract_weighting │ │ │ │ │ └── index.rst.txt │ │ │ │ │ └── index.rst.txt │ │ │ ├── develop │ │ │ │ ├── arch.md.txt │ │ │ │ ├── dataset.md.txt │ │ │ │ └── weighting.md.txt │ │ │ ├── getting_started │ │ │ │ ├── 
installation.md.txt │ │ │ │ ├── introduction.md.txt │ │ │ │ └── quick_start.md.txt │ │ │ └── user_guide │ │ │ │ ├── benchmark.rst.txt │ │ │ │ ├── benchmark │ │ │ │ ├── nyuv2.md.txt │ │ │ │ └── office.md.txt │ │ │ │ ├── framework.md.txt │ │ │ │ └── mtl.md.txt │ │ └── index.rst.txt │ │ ├── _static │ │ ├── basic.css │ │ ├── css │ │ │ ├── badge_only.css │ │ │ ├── fonts │ │ │ │ ├── Roboto-Slab-Bold.woff │ │ │ │ ├── Roboto-Slab-Bold.woff2 │ │ │ │ ├── Roboto-Slab-Regular.woff │ │ │ │ ├── Roboto-Slab-Regular.woff2 │ │ │ │ ├── fontawesome-webfont.eot │ │ │ │ ├── fontawesome-webfont.svg │ │ │ │ ├── fontawesome-webfont.ttf │ │ │ │ ├── fontawesome-webfont.woff │ │ │ │ ├── fontawesome-webfont.woff2 │ │ │ │ ├── lato-bold-italic.woff │ │ │ │ ├── lato-bold-italic.woff2 │ │ │ │ ├── lato-bold.woff │ │ │ │ ├── lato-bold.woff2 │ │ │ │ ├── lato-normal-italic.woff │ │ │ │ ├── lato-normal-italic.woff2 │ │ │ │ ├── lato-normal.woff │ │ │ │ └── lato-normal.woff2 │ │ │ └── theme.css │ │ ├── doctools.js │ │ ├── documentation_options.js │ │ ├── file.png │ │ ├── graphviz.css │ │ ├── jquery-3.5.1.js │ │ ├── jquery.js │ │ ├── js │ │ │ ├── badge_only.js │ │ │ ├── html5shiv-printshiv.min.js │ │ │ ├── html5shiv.min.js │ │ │ └── theme.js │ │ ├── language_data.js │ │ ├── minus.png │ │ ├── plus.png │ │ ├── pygments.css │ │ ├── searchtools.js │ │ ├── theme_overrides.css │ │ ├── underscore-1.13.1.js │ │ └── underscore.js │ │ ├── autoapi_templates │ │ └── python │ │ │ └── module.html │ │ ├── docs │ │ ├── _autoapi │ │ │ └── LibMTL │ │ │ │ ├── _record │ │ │ │ └── index.html │ │ │ │ ├── architecture │ │ │ │ ├── CGC │ │ │ │ │ └── index.html │ │ │ │ ├── Cross_stitch │ │ │ │ │ └── index.html │ │ │ │ ├── DSelect_k │ │ │ │ │ └── index.html │ │ │ │ ├── HPS │ │ │ │ │ └── index.html │ │ │ │ ├── MMoE │ │ │ │ │ └── index.html │ │ │ │ ├── MTAN │ │ │ │ │ └── index.html │ │ │ │ ├── PLE │ │ │ │ │ └── index.html │ │ │ │ ├── abstract_arch │ │ │ │ │ └── index.html │ │ │ │ └── index.html │ │ │ │ ├── config │ │ │ │ └── 
index.html │ │ │ │ ├── index.html │ │ │ │ ├── loss │ │ │ │ └── index.html │ │ │ │ ├── metrics │ │ │ │ └── index.html │ │ │ │ ├── model │ │ │ │ ├── index.html │ │ │ │ ├── resnet │ │ │ │ │ └── index.html │ │ │ │ └── resnet_dilated │ │ │ │ │ └── index.html │ │ │ │ ├── trainer │ │ │ │ └── index.html │ │ │ │ ├── utils │ │ │ │ └── index.html │ │ │ │ └── weighting │ │ │ │ ├── CAGrad │ │ │ │ └── index.html │ │ │ │ ├── DWA │ │ │ │ └── index.html │ │ │ │ ├── EW │ │ │ │ └── index.html │ │ │ │ ├── GLS │ │ │ │ └── index.html │ │ │ │ ├── GradDrop │ │ │ │ └── index.html │ │ │ │ ├── GradNorm │ │ │ │ └── index.html │ │ │ │ ├── GradVac │ │ │ │ └── index.html │ │ │ │ ├── IMTL │ │ │ │ └── index.html │ │ │ │ ├── MGDA │ │ │ │ └── index.html │ │ │ │ ├── PCGrad │ │ │ │ └── index.html │ │ │ │ ├── RLW │ │ │ │ └── index.html │ │ │ │ ├── UW │ │ │ │ └── index.html │ │ │ │ ├── abstract_weighting │ │ │ │ └── index.html │ │ │ │ └── index.html │ │ ├── develop │ │ │ ├── arch.html │ │ │ ├── dataset.html │ │ │ └── weighting.html │ │ ├── getting_started │ │ │ ├── installation.html │ │ │ ├── introduction.html │ │ │ └── quick_start.html │ │ └── user_guide │ │ │ ├── benchmark.html │ │ │ ├── benchmark │ │ │ ├── nyuv2.html │ │ │ └── office.html │ │ │ ├── framework.html │ │ │ └── mtl.html │ │ ├── genindex.html │ │ ├── index.html │ │ ├── objects.inv │ │ ├── py-modindex.html │ │ ├── search.html │ │ └── searchindex.js ├── _static │ └── theme_overrides.css ├── _templates │ └── footer.html ├── autoapi_templates │ └── python │ │ └── module.rst ├── conf.py ├── docs │ ├── _autoapi │ │ └── LibMTL │ │ │ ├── _record │ │ │ └── index.rst │ │ │ ├── architecture │ │ │ ├── CGC │ │ │ │ └── index.rst │ │ │ ├── Cross_stitch │ │ │ │ └── index.rst │ │ │ ├── DSelect_k │ │ │ │ └── index.rst │ │ │ ├── HPS │ │ │ │ └── index.rst │ │ │ ├── MMoE │ │ │ │ └── index.rst │ │ │ ├── MTAN │ │ │ │ └── index.rst │ │ │ ├── PLE │ │ │ │ └── index.rst │ │ │ ├── abstract_arch │ │ │ │ └── index.rst │ │ │ └── index.rst │ │ │ ├── config │ │ │ └── 
index.rst │ │ │ ├── index.rst │ │ │ ├── loss │ │ │ └── index.rst │ │ │ ├── metrics │ │ │ └── index.rst │ │ │ ├── model │ │ │ ├── index.rst │ │ │ ├── resnet │ │ │ │ └── index.rst │ │ │ └── resnet_dilated │ │ │ │ └── index.rst │ │ │ ├── trainer │ │ │ └── index.rst │ │ │ ├── utils │ │ │ └── index.rst │ │ │ └── weighting │ │ │ ├── CAGrad │ │ │ └── index.rst │ │ │ ├── DWA │ │ │ └── index.rst │ │ │ ├── EW │ │ │ └── index.rst │ │ │ ├── GLS │ │ │ └── index.rst │ │ │ ├── GradDrop │ │ │ └── index.rst │ │ │ ├── GradNorm │ │ │ └── index.rst │ │ │ ├── GradVac │ │ │ └── index.rst │ │ │ ├── IMTL │ │ │ └── index.rst │ │ │ ├── MGDA │ │ │ └── index.rst │ │ │ ├── PCGrad │ │ │ └── index.rst │ │ │ ├── RLW │ │ │ └── index.rst │ │ │ ├── UW │ │ │ └── index.rst │ │ │ ├── abstract_weighting │ │ │ └── index.rst │ │ │ └── index.rst │ ├── develop │ │ ├── arch.md │ │ ├── dataset.md │ │ └── weighting.md │ ├── getting_started │ │ ├── installation.md │ │ ├── introduction.md │ │ └── quick_start.md │ ├── images │ │ ├── framework.png │ │ ├── multi_input.png │ │ └── rep_grad.png │ ├── references.bib │ └── user_guide │ │ ├── benchmark.rst │ │ ├── benchmark │ │ ├── nyuv2.md │ │ └── office.md │ │ ├── framework.md │ │ └── mtl.md ├── index.rst └── requirements.txt ├── examples ├── README.md ├── cityscapes │ ├── README.md │ ├── create_dataset.py │ └── main.py ├── nyu │ ├── README.md │ ├── aspp.py │ ├── create_dataset.py │ ├── main.py │ ├── main_segnet.py │ ├── results │ │ ├── resnet_results.pdf │ │ └── segnet_results.pdf │ ├── segnet_mtan.py │ └── utils.py ├── office │ ├── README.md │ ├── create_dataset.py │ ├── data_txt │ │ ├── office-31 │ │ │ ├── amazon_test.txt │ │ │ ├── amazon_train.txt │ │ │ ├── amazon_val.txt │ │ │ ├── dslr_test.txt │ │ │ ├── dslr_train.txt │ │ │ ├── dslr_val.txt │ │ │ ├── webcam_test.txt │ │ │ ├── webcam_train.txt │ │ │ └── webcam_val.txt │ │ └── office-home │ │ │ ├── Art_test.txt │ │ │ ├── Art_train.txt │ │ │ ├── Art_val.txt │ │ │ ├── Clipart_test.txt │ │ │ ├── Clipart_train.txt │ │ 
│ ├── Clipart_val.txt │ │ │ ├── Product_test.txt │ │ │ ├── Product_train.txt │ │ │ ├── Product_val.txt │ │ │ ├── Real_World_test.txt │ │ │ ├── Real_World_train.txt │ │ │ └── Real_World_val.txt │ └── main.py ├── qm9 │ ├── README.md │ ├── main.py │ ├── random_split.t │ └── utils.py └── xtreme │ ├── README.md │ ├── create_dataset.py │ ├── main.py │ ├── processors │ ├── pawsx.py │ └── utils_sc.py │ ├── propocess_data │ ├── conll.py │ ├── conllu_to_conll.py │ ├── download_data.sh │ └── utils_preprocess.py │ └── utils.py ├── pyproject.toml ├── requirements.txt ├── setup.py └── tests ├── README.md ├── coverage.svg ├── htmlcov ├── coverage_html.js ├── d_0226cf0404de869f___init___py.html ├── d_0226cf0404de869f__record_py.html ├── d_0226cf0404de869f_config_py.html ├── d_0226cf0404de869f_loss_py.html ├── d_0226cf0404de869f_metrics_py.html ├── d_0226cf0404de869f_trainer_py.html ├── d_0226cf0404de869f_utils_py.html ├── d_2d6ae2ceef8f3ee1___init___py.html ├── d_2d6ae2ceef8f3ee1_resnet_dilated_py.html ├── d_2d6ae2ceef8f3ee1_resnet_py.html ├── d_997679bd2e75e2e9_CAGrad_py.html ├── d_997679bd2e75e2e9_DWA_py.html ├── d_997679bd2e75e2e9_EW_py.html ├── d_997679bd2e75e2e9_GLS_py.html ├── d_997679bd2e75e2e9_GradDrop_py.html ├── d_997679bd2e75e2e9_GradNorm_py.html ├── d_997679bd2e75e2e9_GradVac_py.html ├── d_997679bd2e75e2e9_IMTL_py.html ├── d_997679bd2e75e2e9_MGDA_py.html ├── d_997679bd2e75e2e9_Nash_MTL_py.html ├── d_997679bd2e75e2e9_PCGrad_py.html ├── d_997679bd2e75e2e9_RLW_py.html ├── d_997679bd2e75e2e9_UW_py.html ├── d_997679bd2e75e2e9___init___py.html ├── d_997679bd2e75e2e9_abstract_weighting_py.html ├── d_bd1c6167b4f7256d_CGC_py.html ├── d_bd1c6167b4f7256d_Cross_stitch_py.html ├── d_bd1c6167b4f7256d_DSelect_k_py.html ├── d_bd1c6167b4f7256d_HPS_py.html ├── d_bd1c6167b4f7256d_LTB_py.html ├── d_bd1c6167b4f7256d_MMoE_py.html ├── d_bd1c6167b4f7256d_MTAN_py.html ├── d_bd1c6167b4f7256d_PLE_py.html ├── d_bd1c6167b4f7256d___init___py.html ├── d_bd1c6167b4f7256d_abstract_arch_py.html ├── 
favicon_32.png ├── index.html ├── keybd_closed.png ├── keybd_open.png ├── status.json └── style.css ├── test_nyu.py ├── test_office31.py ├── test_office_home.py ├── test_pawsx.py └── test_qm9.py /.gitattributes: -------------------------------------------------------------------------------- 1 | tests/htmlcov/* linguist-vendored -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/conf.py 5 | 6 | build: 7 | os: ubuntu-22.04 8 | tools: 9 | python: "3.8" 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 median-research-group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LibMTL/__init__.py: -------------------------------------------------------------------------------- 1 | from . import architecture 2 | from . import model 3 | from . import weighting 4 | from .trainer import Trainer 5 | from . import config 6 | from . import loss 7 | from . import metrics 8 | #from .record import PerformanceMeter 9 | from . import utils -------------------------------------------------------------------------------- /LibMTL/architecture/CGC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.architecture.MMoE import MMoE 7 | 8 | 9 | class CGC(MMoE): 10 | r"""Customized Gate Control (CGC). 11 | 12 | This method is proposed in `Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations (ACM RecSys 2020 Best Paper) `_ \ 13 | and implemented by us. 14 | 15 | Args: 16 | img_size (list): The size of input data. For example, [3, 244, 244] denotes input images with size 3x224x224. 17 | num_experts (list): The numbers of experts shared by all the tasks and specific to each task, respectively. Each expert is an encoder network. 
18 | 19 | """ 20 | def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs): 21 | super(CGC, self).__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 22 | 23 | self.num_experts = {task: self.kwargs['num_experts'][tn+1] for tn, task in enumerate(self.task_name)} 24 | self.num_experts['share'] = self.kwargs['num_experts'][0] 25 | self.experts_specific = nn.ModuleDict({task: nn.ModuleList([encoder_class() for _ in range(self.num_experts[task])]) for task in self.task_name}) 26 | self.gate_specific = nn.ModuleDict({task: nn.Sequential(nn.Linear(self.input_size, 27 | self.num_experts['share']+self.num_experts[task]), 28 | nn.Softmax(dim=-1)) for task in self.task_name}) 29 | 30 | def forward(self, inputs, task_name=None): 31 | experts_shared_rep = torch.stack([e(inputs) for e in self.experts_shared]) 32 | out = {} 33 | for task in self.task_name: 34 | if task_name is not None and task != task_name: 35 | continue 36 | experts_specific_rep = torch.stack([e(inputs) for e in self.experts_specific[task]]) 37 | selector = self.gate_specific[task](torch.flatten(inputs, start_dim=1)) 38 | gate_rep = torch.einsum('ij..., ji -> j...', 39 | torch.cat([experts_shared_rep, experts_specific_rep], dim=0), 40 | selector) 41 | gate_rep = self._prepare_rep(gate_rep, task, same_rep=False) 42 | out[task] = self.decoders[task](gate_rep) 43 | return out 44 | 45 | -------------------------------------------------------------------------------- /LibMTL/architecture/Cross_stitch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.architecture.abstract_arch import AbsArchitecture 7 | 8 | class _transform_resnet_cross(nn.Module): 9 | def __init__(self, encoder_list, task_name, device): 10 | super(_transform_resnet_cross, self).__init__() 11 | 12 | self.task_name = 
task_name 13 | self.task_num = len(task_name) 14 | self.device = device 15 | self.resnet_conv = nn.ModuleDict({task: nn.Sequential(encoder_list[tn].conv1, encoder_list[tn].bn1, 16 | encoder_list[tn].relu, encoder_list[tn].maxpool) for tn, task in enumerate(self.task_name)}) 17 | self.resnet_layer = nn.ModuleDict({}) 18 | for i in range(4): 19 | self.resnet_layer[str(i)] = nn.ModuleList([]) 20 | for tn in range(self.task_num): 21 | encoder = encoder_list[tn] 22 | self.resnet_layer[str(i)].append(eval('encoder.layer'+str(i+1))) 23 | self.cross_unit = nn.Parameter(torch.ones(4, self.task_num, self.task_num)) 24 | 25 | def forward(self, inputs): 26 | s_rep = {task: self.resnet_conv[task](inputs) for task in self.task_name} 27 | ss_rep = {i: [0]*self.task_num for i in range(4)} 28 | for i in range(4): 29 | for tn, task in enumerate(self.task_name): 30 | if i == 0: 31 | ss_rep[i][tn] = self.resnet_layer[str(i)][tn](s_rep[task]) 32 | else: 33 | cross_rep = sum([self.cross_unit[i-1][tn][j]*ss_rep[i-1][j] for j in range(self.task_num)]) 34 | ss_rep[i][tn] = self.resnet_layer[str(i)][tn](cross_rep) 35 | return ss_rep[3] 36 | 37 | class Cross_stitch(AbsArchitecture): 38 | r"""Cross-stitch Networks (Cross_stitch). 39 | 40 | This method is proposed in `Cross-stitch Networks for Multi-task Learning (CVPR 2016) `_ \ 41 | and implemented by us. 42 | 43 | .. warning:: 44 | - :class:`Cross_stitch` does not work with multiple inputs MTL problem, i.e., ``multi_input`` must be ``False``. 45 | 46 | - :class:`Cross_stitch` is only supported by ResNet-based encoders. 
47 | 48 | """ 49 | def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs): 50 | super(Cross_stitch, self).__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 51 | 52 | if self.multi_input: 53 | raise ValueError('No support Cross Stitch for multiple inputs MTL problem') 54 | 55 | self.encoder = nn.ModuleList([self.encoder_class() for _ in range(self.task_num)]) 56 | self.encoder = _transform_resnet_cross(self.encoder, task_name, device) 57 | -------------------------------------------------------------------------------- /LibMTL/architecture/HPS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.architecture.abstract_arch import AbsArchitecture 7 | 8 | 9 | class HPS(AbsArchitecture): 10 | r"""Hard Parameter Sharing (HPS). 11 | 12 | This method is proposed in `Multitask Learning: A Knowledge-Based Source of Inductive Bias (ICML 1993) `_ \ 13 | and implemented by us. 14 | """ 15 | def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs): 16 | super(HPS, self).__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 17 | self.encoder = self.encoder_class() 18 | -------------------------------------------------------------------------------- /LibMTL/architecture/MMoE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.architecture.abstract_arch import AbsArchitecture 7 | 8 | class MMoE(AbsArchitecture): 9 | r"""Multi-gate Mixture-of-Experts (MMoE). 10 | 11 | This method is proposed in `Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts (KDD 2018) `_ \ 12 | and implemented by us. 
13 | 14 | Args: 15 | img_size (list): The size of input data. For example, [3, 244, 244] denotes input images with size 3x224x224. 16 | num_experts (int): The number of experts shared for all tasks. Each expert is an encoder network. 17 | 18 | """ 19 | def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs): 20 | super(MMoE, self).__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 21 | 22 | self.img_size = self.kwargs['img_size'] 23 | self.input_size = np.array(self.img_size, dtype=int).prod() 24 | self.num_experts = self.kwargs['num_experts'][0] 25 | self.experts_shared = nn.ModuleList([encoder_class() for _ in range(self.num_experts)]) 26 | self.gate_specific = nn.ModuleDict({task: nn.Sequential(nn.Linear(self.input_size, self.num_experts), 27 | nn.Softmax(dim=-1)) for task in self.task_name}) 28 | 29 | def forward(self, inputs, task_name=None): 30 | experts_shared_rep = torch.stack([e(inputs) for e in self.experts_shared]) 31 | out = {} 32 | for task in self.task_name: 33 | if task_name is not None and task != task_name: 34 | continue 35 | selector = self.gate_specific[task](torch.flatten(inputs, start_dim=1)) 36 | gate_rep = torch.einsum('ij..., ji -> j...', experts_shared_rep, selector) 37 | gate_rep = self._prepare_rep(gate_rep, task, same_rep=False) 38 | out[task] = self.decoders[task](gate_rep) 39 | return out 40 | 41 | def get_share_params(self): 42 | return self.experts_shared.parameters() 43 | 44 | def zero_grad_share_params(self): 45 | self.experts_shared.zero_grad(set_to_none=False) 46 | -------------------------------------------------------------------------------- /LibMTL/architecture/__init__.py: -------------------------------------------------------------------------------- 1 | from LibMTL.architecture.abstract_arch import AbsArchitecture 2 | from LibMTL.architecture.HPS import HPS 3 | from LibMTL.architecture.Cross_stitch import Cross_stitch 4 | from 
LibMTL.architecture.MMoE import MMoE 5 | from LibMTL.architecture.MTAN import MTAN 6 | from LibMTL.architecture.CGC import CGC 7 | from LibMTL.architecture.PLE import PLE 8 | from LibMTL.architecture.DSelect_k import DSelect_k 9 | from LibMTL.architecture.LTB import LTB 10 | 11 | __all__ = ['AbsArchitecture', 12 | 'HPS', 13 | 'Cross_stitch', 14 | 'MMoE', 15 | 'MTAN', 16 | 'CGC', 17 | 'PLE', 18 | 'DSelect_k', 19 | 'LTB'] -------------------------------------------------------------------------------- /LibMTL/architecture/abstract_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class AbsArchitecture(nn.Module): 8 | r"""An abstract class for MTL architectures. 9 | 10 | Args: 11 | task_name (list): A list of strings for all tasks. 12 | encoder_class (class): A neural network class. 13 | decoders (dict): A dictionary of name-decoder pairs of type (:class:`str`, :class:`torch.nn.Module`). 14 | rep_grad (bool): If ``True``, the gradient of the representation for each task can be computed. 15 | multi_input (bool): Is ``True`` if each task has its own input data, otherwise is ``False``. 16 | device (torch.device): The device where model and data will be allocated. 17 | kwargs (dict): A dictionary of hyperparameters of architectures. 
18 | 19 | """ 20 | def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs): 21 | super(AbsArchitecture, self).__init__() 22 | 23 | self.task_name = task_name 24 | self.task_num = len(task_name) 25 | self.encoder_class = encoder_class 26 | self.decoders = decoders 27 | self.rep_grad = rep_grad 28 | self.multi_input = multi_input 29 | self.device = device 30 | self.kwargs = kwargs 31 | 32 | if self.rep_grad: 33 | self.rep_tasks = {} 34 | self.rep = {} 35 | 36 | def forward(self, inputs, task_name=None): 37 | r""" 38 | 39 | Args: 40 | inputs (torch.Tensor): The input data. 41 | task_name (str, default=None): The task name corresponding to ``inputs`` if ``multi_input`` is ``True``. 42 | 43 | Returns: 44 | dict: A dictionary of name-prediction pairs of type (:class:`str`, :class:`torch.Tensor`). 45 | """ 46 | out = {} 47 | s_rep = self.encoder(inputs) 48 | same_rep = True if not isinstance(s_rep, list) and not self.multi_input else False 49 | for tn, task in enumerate(self.task_name): 50 | if task_name is not None and task != task_name: 51 | continue 52 | ss_rep = s_rep[tn] if isinstance(s_rep, list) else s_rep 53 | ss_rep = self._prepare_rep(ss_rep, task, same_rep) 54 | out[task] = self.decoders[task](ss_rep) 55 | return out 56 | 57 | def get_share_params(self): 58 | r"""Return the shared parameters of the model. 59 | """ 60 | return self.encoder.parameters() 61 | 62 | def zero_grad_share_params(self): 63 | r"""Set gradients of the shared parameters to zero. 
64 | """ 65 | self.encoder.zero_grad(set_to_none=False) 66 | 67 | def _prepare_rep(self, rep, task, same_rep=None): 68 | if self.rep_grad: 69 | if not same_rep: 70 | self.rep[task] = rep 71 | else: 72 | self.rep = rep 73 | self.rep_tasks[task] = rep.detach().clone() 74 | self.rep_tasks[task].requires_grad = True 75 | return self.rep_tasks[task] 76 | else: 77 | return rep 78 | -------------------------------------------------------------------------------- /LibMTL/loss.py: -------------------------------------------------------------------------------- 1 | import torch, time 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class AbsLoss(object): 7 | r"""An abstract class for loss functions. 8 | """ 9 | def __init__(self): 10 | self.record = [] 11 | self.bs = [] 12 | 13 | def compute_loss(self, pred, gt): 14 | r"""Calculate the loss. 15 | 16 | Args: 17 | pred (torch.Tensor): The prediction tensor. 18 | gt (torch.Tensor): The ground-truth tensor. 19 | 20 | Return: 21 | torch.Tensor: The loss. 22 | """ 23 | pass 24 | 25 | def _update_loss(self, pred, gt): 26 | loss = self.compute_loss(pred, gt) 27 | self.record.append(loss.item()) 28 | self.bs.append(pred.size()[0]) 29 | return loss 30 | 31 | def _average_loss(self): 32 | record = np.array(self.record) 33 | bs = np.array(self.bs) 34 | return (record*bs).sum()/bs.sum() 35 | 36 | def _reinit(self): 37 | self.record = [] 38 | self.bs = [] 39 | 40 | class CELoss(AbsLoss): 41 | r"""The cross-entropy loss function. 42 | """ 43 | def __init__(self): 44 | super(CELoss, self).__init__() 45 | 46 | self.loss_fn = nn.CrossEntropyLoss() 47 | 48 | def compute_loss(self, pred, gt): 49 | r""" 50 | """ 51 | loss = self.loss_fn(pred, gt) 52 | return loss 53 | 54 | class KLDivLoss(AbsLoss): 55 | r"""The Kullback-Leibler divergence loss function. 
56 | """ 57 | def __init__(self): 58 | super(KLDivLoss, self).__init__() 59 | 60 | self.loss_fn = nn.KLDivLoss() 61 | 62 | def compute_loss(self, pred, gt): 63 | r""" 64 | """ 65 | loss = self.loss_fn(pred, gt) 66 | return loss 67 | 68 | class L1Loss(AbsLoss): 69 | r"""The Mean Absolute Error (MAE) loss function. 70 | """ 71 | def __init__(self): 72 | super(L1Loss, self).__init__() 73 | 74 | self.loss_fn = nn.L1Loss() 75 | 76 | def compute_loss(self, pred, gt): 77 | r""" 78 | """ 79 | loss = self.loss_fn(pred, gt) 80 | return loss 81 | 82 | class MSELoss(AbsLoss): 83 | r"""The Mean Squared Error (MSE) loss function. 84 | """ 85 | def __init__(self): 86 | super(MSELoss, self).__init__() 87 | 88 | self.loss_fn = nn.MSELoss() 89 | 90 | def compute_loss(self, pred, gt): 91 | r""" 92 | """ 93 | loss = self.loss_fn(pred, gt) 94 | return loss -------------------------------------------------------------------------------- /LibMTL/metrics.py: -------------------------------------------------------------------------------- 1 | import torch, time 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class AbsMetric(object): 7 | r"""An abstract class for the performance metrics of a task. 8 | 9 | Attributes: 10 | record (list): A list of the metric scores in every iteration. 11 | bs (list): A list of the number of data in every iteration. 12 | """ 13 | def __init__(self): 14 | self.record = [] 15 | self.bs = [] 16 | 17 | @property 18 | def update_fun(self, pred, gt): 19 | r"""Calculate the metric scores in every iteration and update :attr:`record`. 20 | 21 | Args: 22 | pred (torch.Tensor): The prediction tensor. 23 | gt (torch.Tensor): The ground-truth tensor. 24 | """ 25 | pass 26 | 27 | @property 28 | def score_fun(self): 29 | r"""Calculate the final score (when an epoch ends). 30 | 31 | Return: 32 | list: A list of metric scores. 
33 | """ 34 | pass 35 | 36 | def reinit(self): 37 | r"""Reset :attr:`record` and :attr:`bs` (when an epoch ends). 38 | """ 39 | self.record = [] 40 | self.bs = [] 41 | 42 | # accuracy 43 | class AccMetric(AbsMetric): 44 | r"""Calculate the accuracy. 45 | """ 46 | def __init__(self): 47 | super(AccMetric, self).__init__() 48 | 49 | def update_fun(self, pred, gt): 50 | r""" 51 | """ 52 | pred = F.softmax(pred, dim=-1).max(-1)[1] 53 | self.record.append(gt.eq(pred).sum().item()) 54 | self.bs.append(pred.size()[0]) 55 | 56 | def score_fun(self): 57 | r""" 58 | """ 59 | return [(sum(self.record)/sum(self.bs))] 60 | 61 | 62 | # L1 Error 63 | class L1Metric(AbsMetric): 64 | r"""Calculate the Mean Absolute Error (MAE). 65 | """ 66 | def __init__(self): 67 | super(L1Metric, self).__init__() 68 | 69 | def update_fun(self, pred, gt): 70 | r""" 71 | """ 72 | abs_err = torch.abs(pred - gt) 73 | self.record.append(abs_err.item()) 74 | self.bs.append(pred.size()[0]) 75 | 76 | def score_fun(self): 77 | r""" 78 | """ 79 | records = np.array(self.record) 80 | batch_size = np.array(self.bs) 81 | return [(records*batch_size).sum()/(sum(batch_size))] 82 | -------------------------------------------------------------------------------- /LibMTL/model/__init__.py: -------------------------------------------------------------------------------- 1 | from LibMTL.model.resnet import resnet18 2 | from LibMTL.model.resnet import resnet34 3 | from LibMTL.model.resnet import resnet50 4 | from LibMTL.model.resnet import resnet101 5 | from LibMTL.model.resnet import resnet152 6 | from LibMTL.model.resnet import resnext50_32x4d 7 | from LibMTL.model.resnet import resnext101_32x8d 8 | from LibMTL.model.resnet import wide_resnet50_2 9 | from LibMTL.model.resnet import wide_resnet101_2 10 | from LibMTL.model.resnet_dilated import resnet_dilated 11 | 12 | __all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 13 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 14 | 'wide_resnet50_2', 
class ResnetDilated(nn.Module):
    """Turn a plain ResNet backbone into a dilated fully-convolutional
    feature extractor (the avg-pool and fc layers are discarded).

    Args:
        orig_resnet: A constructed torchvision-style ResNet instance.
        dilate_scale (int): Overall output stride of the backbone, 8 or 16.
    """

    def __init__(self, orig_resnet, dilate_scale=8):
        super(ResnetDilated, self).__init__()
        from functools import partial

        # Replace strides with dilations in the late stages so the spatial
        # resolution is only reduced by ``dilate_scale`` overall.
        if dilate_scale == 8:
            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))

        # keep the pre-defined ResNet stages, except AvgPool and FC
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu = orig_resnet.relu
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4
        self.feature_dim = orig_resnet.feature_dim

    def _nostride_dilate(self, m, dilate):
        # only convolution layers are touched
        if m.__class__.__name__.find('Conv') == -1:
            return
        if m.stride == (2, 2):
            # strided convolution: drop the stride, dilate 3x3 kernels at half rate
            m.stride = (1, 1)
            if m.kernel_size == (3, 3):
                m.dilation = (dilate // 2, dilate // 2)
                m.padding = (dilate // 2, dilate // 2)
        elif m.kernel_size == (3, 3):
            # other 3x3 convolutions: dilate at full rate
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x
def get_root_dir():
    r"""Return the root path of the project."""
    return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

def set_random_seed(seed):
    r"""Set the random seed of ``random``, ``numpy`` and ``torch`` for reproducibility.

    Args:
        seed (int, default=0): The random seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # make cuDNN deterministic (at some speed cost)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def set_device(gpu_id):
    r"""Set the device where model and data will be allocated.

    Args:
        gpu_id (str, default='0'): The id of gpu.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

def count_parameters(model):
    r'''Print the number of trainable and non-trainable parameters of a model.

    Args:
        model (torch.nn.Module): A neural network module.
    '''
    # idiomatic generator-expression sums instead of a manual counting loop
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print('='*40)
    print('Total Params:', trainable_params + non_trainable_params)
    print('Trainable Params:', trainable_params)
    print('Non-trainable Params:', non_trainable_params)

def count_improvement(base_result, new_result, weight):
    r"""Calculate the improvement between two results as

    .. math::
        \Delta_{\mathrm{p}}=100\%\times \frac{1}{T}\sum_{t=1}^T
        \frac{1}{M_t}\sum_{m=1}^{M_t}\frac{(-1)^{w_{t,m}}(B_{t,m}-N_{t,m})}{B_{t,m}}.

    Args:
        base_result (dict): A dictionary of scores of all metrics of all tasks.
        new_result (dict): The same structure with ``base_result``.
        weight (dict): The same structure with ``base_result`` while each element is binary
            integer representing whether higher or lower score is better.

    Returns:
        float: The improvement between ``new_result`` and ``base_result``.

    Examples::

        base_result = {'A': [96, 98], 'B': [0.2]}
        new_result = {'A': [93, 99], 'B': [0.5]}
        weight = {'A': [1, 0], 'B': [1]}

        print(count_improvement(base_result, new_result, weight))
    """
    # per-task mean of signed relative differences; weight==1 flips the sign
    # so that "lower is better" metrics count improvements positively
    per_task = [
        (((-1)**np.array(weight[task]))
         * (np.array(base_result[task]) - np.array(new_result[task]))
         / np.array(base_result[task])).mean()
        for task in base_result
    ]
    return sum(per_task) / len(per_task)
class Aligned_MTL(AbsWeighting):
    r"""Aligned-MTL.

    This method is proposed in *Independent Component Alignment for Multi-Task
    Learning* (CVPR 2023) and implemented by modifying from the official
    PyTorch implementation.

    """
    def __init__(self):
        super(Aligned_MTL, self).__init__()

    def backward(self, losses, **kwargs):

        grads = self._get_grads(losses, mode='backward')
        if self.rep_grad:
            per_grads, grads = grads[0], grads[1]

        M = torch.matmul(grads, grads.t())  # [num_tasks, num_tasks] Gram matrix
        # FIX: torch.symeig was removed in PyTorch 1.13; torch.linalg.eigh is
        # the drop-in replacement and also returns eigenvalues in ascending
        # order with the corresponding eigenvectors.
        lmbda, V = torch.linalg.eigh(M)
        # numerical-rank tolerance, as in numpy.linalg.matrix_rank
        tol = (
            torch.max(lmbda)
            * max(M.shape[-2:])
            * torch.finfo().eps
        )
        rank = sum(lmbda > tol)

        # keep the top-``rank`` eigenpairs, largest eigenvalue first
        order = torch.argsort(lmbda, dim=-1, descending=True)
        lmbda, V = lmbda[order][:rank], V[:, order][:, :rank]

        # whitening transform scaled by the smallest kept singular value
        sigma = torch.diag(1 / lmbda.sqrt())
        B = lmbda[-1].sqrt() * ((V @ sigma) @ V.t())
        alpha = B.sum(0)

        if self.rep_grad:
            self._backward_new_grads(alpha, per_grads=per_grads)
        else:
            self._backward_new_grads(alpha, grads=grads)
        return alpha.detach().cpu().numpy()
class CAGrad(AbsWeighting):
    r"""Conflict-Averse Gradient descent (CAGrad).

    This method is proposed in *Conflict-Averse Gradient Descent for
    Multi-task Learning* (NeurIPS 2021) and implemented by modifying from the
    official PyTorch implementation.

    Args:
        calpha (float, default=0.5): A hyperparameter that controls the convergence rate.
        rescale ({0, 1, 2}, default=1): The type of the gradient rescaling.

    .. warning::
        CAGrad is not supported by representation gradients, i.e., ``rep_grad`` must be ``False``.

    """
    def __init__(self):
        super(CAGrad, self).__init__()

    def backward(self, losses, **kwargs):
        from scipy.optimize import minimize

        calpha, rescale = kwargs['calpha'], kwargs['rescale']
        if self.rep_grad:
            raise ValueError('No support method CAGrad with representation gradients (rep_grad=True)')
        self._compute_grad_dim()
        grads = self._compute_grad(losses, mode='backward')

        GG = torch.matmul(grads, grads.t()).cpu()  # [num_tasks, num_tasks]
        g0_norm = (GG.mean() + 1e-8).sqrt()        # norm of the average gradient

        # solve the constrained min-max problem over the simplex with SLSQP
        w0 = np.ones(self.task_num) / self.task_num
        bounds = tuple((0, 1) for _ in w0)
        constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})
        A = GG.numpy()
        b = w0.copy()
        c = (calpha * g0_norm + 1e-8).item()

        def objfn(x):
            return (x.reshape(1, -1).dot(A).dot(b.reshape(-1, 1))
                    + c * np.sqrt(x.reshape(1, -1).dot(A).dot(x.reshape(-1, 1)) + 1e-8)).sum()

        w_cpu = minimize(objfn, w0, bounds=bounds, constraints=constraints).x
        ww = torch.Tensor(w_cpu).to(self.device)
        gw = (grads * ww.view(-1, 1)).sum(0)
        lmbda = c / (gw.norm() + 1e-8)
        g = grads.mean(0) + lmbda * gw
        if rescale == 0:
            new_grads = g
        elif rescale == 1:
            new_grads = g / (1 + calpha**2)
        elif rescale == 2:
            new_grads = g / (1 + calpha)
        else:
            raise ValueError('No support rescale type {}'.format(rescale))
        self._reset_grad(new_grads)
        return w_cpu
class DB_MTL(AbsWeighting):
    r"""DB-MTL gradient balancing (log-loss scale balancing plus
    exponential-moving-average gradient balancing)."""

    def __init__(self):
        super(DB_MTL, self).__init__()

    def init_param(self):
        self.step = 0
        self._compute_grad_dim()
        # EMA buffer of the per-task gradients of the log-losses
        self.grad_buffer = torch.zeros(self.task_num, self.grad_dim).to(self.device)

    def backward(self, losses, **kwargs):
        self.step += 1
        beta, beta_sigma = kwargs['DB_beta'], kwargs['DB_beta_sigma']

        batch_weight = np.ones(len(losses))
        if self.rep_grad:
            raise ValueError('No support method DB_MTL with representation gradients (rep_grad=True)')
        self._compute_grad_dim()
        batch_grads = self._compute_grad(torch.log(losses + 1e-8), mode='backward')  # [task_num, grad_dim]

        # EMA update with step-decayed momentum beta / step**beta_sigma
        momentum = beta / self.step**beta_sigma
        self.grad_buffer = batch_grads + momentum * (self.grad_buffer - batch_grads)

        grad_norms = self.grad_buffer.norm(dim=-1)
        # rescale every task gradient to the largest buffered norm
        alpha = grad_norms.max() / (grad_norms + 1e-8)
        new_grads = sum(alpha[i] * self.grad_buffer[i] for i in range(self.task_num))

        self._reset_grad(new_grads)
        return batch_weight

class DWA(AbsWeighting):
    r"""Dynamic Weight Average (DWA).

    This method is proposed in *End-To-End Multi-Task Learning With Attention*
    (CVPR 2019) and implemented by modifying from the official PyTorch
    implementation.

    Args:
        T (float, default=2.0): The softmax temperature.

    """
    def __init__(self):
        super(DWA, self).__init__()

    def backward(self, losses, **kwargs):
        T = kwargs['T']
        if self.epoch <= 1:
            # not enough loss history yet: fall back to equal weights
            batch_weight = torch.ones_like(losses).to(self.device)
        else:
            # ratio of the last two epochs' training losses per task
            ratio = self.train_loss_buffer[:, self.epoch - 1] / self.train_loss_buffer[:, self.epoch - 2]
            w_i = torch.Tensor(ratio).to(self.device)
            batch_weight = self.task_num * F.softmax(w_i / T, dim=-1)
        torch.mul(losses, batch_weight).sum().backward()
        return batch_weight.detach().cpu().numpy()

class EW(AbsWeighting):
    r"""Equal Weighting (EW).

    The loss weight for each task is always ``1 / T`` in every iteration,
    where ``T`` denotes the number of tasks.

    """
    def __init__(self):
        super(EW, self).__init__()

    def backward(self, losses, **kwargs):
        # every task shares a unit weight
        weights = torch.ones_like(losses).to(self.device)
        torch.mul(losses, weights).sum().backward()
        return np.ones(self.task_num)
class ExcessMTL(AbsWeighting):
    r"""ExcessMTL.

    This method is proposed in *Robust Multi-Task Learning with Excess Risks*
    (ICML 2024) and implemented by modifying from the official PyTorch
    implementation.

    """
    def __init__(self):
        super(ExcessMTL, self).__init__()

    def init_param(self):
        self.loss_weight = torch.tensor([1.0] * self.task_num, device=self.device, requires_grad=False)
        self.grad_sum = None      # running sum of squared gradients (AdaGrad-style)
        self.first_epoch = True   # w is normalized by its value at the first call

    def backward(self, losses, **kwargs):

        grads = self._get_grads(losses, mode='autograd')
        if self.rep_grad:
            per_grads, grads = grads[0], grads[1]
            # collapse per-sample representation gradients into one vector per task
            grads = torch.stack([torch.sum(g, dim=0) for g in per_grads])

        if self.grad_sum is None:
            self.grad_sum = torch.zeros_like(grads)

        w = torch.zeros(self.task_num, device=self.device)
        for tn in range(self.task_num):
            self.grad_sum[tn] += grads[tn]**2
            # whitened gradient norm g^T diag(1/h) g with h = sqrt(sum g^2)
            h = torch.sqrt(self.grad_sum[tn] + 1e-7)
            w[tn] = grads[tn] * (1 / h) @ grads[tn].t()

        if self.first_epoch:
            self.initial_w = w
            self.first_epoch = False
        else:
            w = w / self.initial_w
            robust_step_size = kwargs['robust_step_size']
            # multiplicative-weights update, renormalized to sum to task_num
            self.loss_weight = self.loss_weight * torch.exp(w * robust_step_size)
            self.loss_weight = self.loss_weight / self.loss_weight.sum() * self.task_num
            self.loss_weight = self.loss_weight.detach().clone()

        self.encoder.zero_grad()
        torch.mul(losses, self.loss_weight).sum().backward()

        return self.loss_weight.cpu().numpy()

class FAMO(AbsWeighting):
    r"""Fast Adaptive Multitask Optimization (FAMO).

    This method is proposed in *FAMO: Fast Adaptive Multitask Optimization*
    (NeurIPS 2023) and implemented by modifying from the official PyTorch
    implementation.

    Args:
        FAMO_w_lr (float, default=0.025): The learing rate of loss weights.
        FAMO_w_gamma (float, default=1e-3): The weight decay of loss weights.

    """
    def __init__(self):
        super().__init__()

    def init_param(self):
        self.step = 0
        self.min_losses = torch.zeros(self.task_num).to(self.device)
        # logits of the task weights, optimized with Adam
        self.w = torch.tensor([0.0] * self.task_num, device=self.device, requires_grad=True)
        self.w_opt = torch.optim.Adam([self.w], lr=0.0, weight_decay=0.0)

    def backward(self, losses, **kwargs):
        self.step += 1
        if self.step == 1:
            # the real hyperparameters are only known at the first call
            for param_group in self.w_opt.param_groups:
                param_group['lr'] = kwargs['FAMO_w_lr']
                param_group['weight_decay'] = kwargs['FAMO_w_gamma']

        self.prev_losses = losses
        z = F.softmax(self.w, -1)
        D = losses - self.min_losses + 1e-8
        c = (z / D).sum().detach()
        ((D.log() * z) / c).sum().backward()
        return None

    def update_w(self, curr_losses):
        # log-loss decrease since the last backward() call
        delta = (self.prev_losses - self.min_losses + 1e-8).log() - \
                (curr_losses - self.min_losses + 1e-8).log()
        with torch.enable_grad():
            d = torch.autograd.grad(F.softmax(self.w, -1),
                                    self.w,
                                    grad_outputs=delta.detach())[0]
        self.w_opt.zero_grad(set_to_none=False)
        self.w.grad = d
        self.w_opt.step()
r"""FairGrad. 11 | 12 | This method is proposed in `Fair Resource Allocation in Multi-Task Learning (ICML 2024) `_ \ 13 | and implemented by modifying from the `official PyTorch implementation `_. 14 | 15 | """ 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def backward(self, losses, **kwargs): 20 | alpha = kwargs['FairGrad_alpha'] 21 | 22 | if self.rep_grad: 23 | raise ValueError('No support method FairGrad with representation gradients (rep_grad=True)') 24 | else: 25 | self._compute_grad_dim() 26 | grads = self._compute_grad(losses, mode='autograd') 27 | 28 | GTG = torch.mm(grads, grads.t()) 29 | 30 | x_start = np.ones(self.task_num) / self.task_num 31 | A = GTG.data.cpu().numpy() 32 | 33 | def objfn(x): 34 | return np.dot(A, x) - np.power(1 / x, 1 / alpha) 35 | 36 | res = least_squares(objfn, x_start, bounds=(0, np.inf)) 37 | w_cpu = res.x 38 | ww = torch.Tensor(w_cpu).to(self.device) 39 | 40 | torch.sum(ww*losses).backward() 41 | 42 | return w_cpu 43 | -------------------------------------------------------------------------------- /LibMTL/weighting/GLS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting 7 | 8 | class GLS(AbsWeighting): 9 | r"""Geometric Loss Strategy (GLS). 10 | 11 | This method is proposed in `MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning (CVPR 2019 workshop) `_ \ 12 | and implemented by us. 
13 | 14 | """ 15 | def __init__(self): 16 | super(GLS, self).__init__() 17 | 18 | def backward(self, losses, **kwargs): 19 | loss = torch.pow(losses.prod(), 1./self.task_num) 20 | loss.backward() 21 | batch_weight = losses / (self.task_num * losses.prod()) 22 | return batch_weight.detach().cpu().numpy() -------------------------------------------------------------------------------- /LibMTL/weighting/GradDrop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.weighting.abstract_weighting import AbsWeighting 7 | 8 | 9 | class GradDrop(AbsWeighting): 10 | r"""Gradient Sign Dropout (GradDrop). 11 | 12 | This method is proposed in `Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout (NeurIPS 2020) `_ \ 13 | and implemented by us. 14 | 15 | Args: 16 | leak (float, default=0.0): The leak parameter for the weighting matrix. 17 | 18 | .. warning:: 19 | GradDrop is not supported by parameter gradients, i.e., ``rep_grad`` must be ``True``. 
class GradDrop(AbsWeighting):
    r"""Gradient Sign Dropout (GradDrop).

    This method is proposed in *Just Pick a Sign: Optimizing Deep Multitask
    Models with Gradient Sign Dropout* (NeurIPS 2020) and implemented by us.

    Args:
        leak (float, default=0.0): The leak parameter for the weighting matrix.

    .. warning::
        GradDrop is not supported by parameter gradients, i.e., ``rep_grad`` must be ``True``.

    """
    def __init__(self):
        super(GradDrop, self).__init__()

    def backward(self, losses, **kwargs):
        leak = kwargs['leak']
        if not self.rep_grad:
            raise ValueError('No support method GradDrop with parameter gradients (rep_grad=False)')
        per_grads = self._compute_grad(losses, mode='backward', rep_grad=True)

        if not isinstance(self.rep, dict):
            inputs = self.rep.unsqueeze(0).repeat_interleave(self.task_num, dim=0)
        else:
            try:
                inputs = torch.stack(list(self.rep.values()))
                per_grads = torch.stack(per_grads)
            except:
                raise ValueError('The representation dimensions of different tasks must be consistent')
        grads = (per_grads * inputs.sign()).sum(1)
        # per-coordinate probability of keeping the positive sign
        P = 0.5 * (1 + grads.sum(0) / (grads.abs().sum(0) + 1e-7))
        U = torch.rand_like(P)
        # mask keeps positive grads where P > U and negative grads where P < U
        M = P.gt(U).unsqueeze(0).repeat_interleave(self.task_num, dim=0) * grads.gt(0) + \
            P.lt(U).unsqueeze(0).repeat_interleave(self.task_num, dim=0) * grads.lt(0)
        M = M.unsqueeze(1).repeat_interleave(per_grads.size()[1], dim=1)
        transformed_grad = per_grads * (leak + (1 - leak) * M)

        if not isinstance(self.rep, dict):
            self.rep.backward(transformed_grad.sum(0))
        else:
            for tn, task in enumerate(self.task_name):
                self.rep[task].backward(transformed_grad[tn], retain_graph=True)
        return None

class GradNorm(AbsWeighting):
    r"""Gradient Normalization (GradNorm).

    This method is proposed in *GradNorm: Gradient Normalization for Adaptive
    Loss Balancing in Deep Multitask Networks* (ICML 2018) and implemented by us.

    Args:
        alpha (float, default=1.5): The strength of the restoring force which
            pulls tasks back to a common training rate.

    """
    def __init__(self):
        super(GradNorm, self).__init__()

    def init_param(self):
        # learnable per-task loss scaling
        self.loss_scale = nn.Parameter(torch.tensor([1.0] * self.task_num, device=self.device))

    def backward(self, losses, **kwargs):
        alpha = kwargs['alpha']
        if self.epoch < 1:
            # first epoch: plain equally-weighted sum (no loss history yet)
            torch.mul(losses, torch.ones_like(losses).to(self.device)).sum().backward()
            return np.ones(self.task_num)

        loss_scale = self.task_num * F.softmax(self.loss_scale, dim=-1)
        grads = self._get_grads(losses, mode='backward')
        if self.rep_grad:
            per_grads, grads = grads[0], grads[1]

        G_per_loss = torch.norm(loss_scale.unsqueeze(1) * grads, p=2, dim=-1)
        G = G_per_loss.mean(0)
        # relative inverse training rate per task
        L_i = torch.Tensor([losses[tn].item() / self.train_loss_buffer[tn, 0]
                            for tn in range(self.task_num)]).to(self.device)
        r_i = L_i / L_i.mean()
        constant_term = (G * (r_i**alpha)).detach()
        (G_per_loss - constant_term).abs().sum(0).backward()
        loss_weight = loss_scale.detach().clone()

        if self.rep_grad:
            self._backward_new_grads(loss_weight, per_grads=per_grads)
        else:
            self._backward_new_grads(loss_weight, grads=grads)
        return loss_weight.cpu().numpy()
class IMTL(AbsWeighting):
    r"""Impartial Multi-task Learning (IMTL).

    This method is proposed in *Towards Impartial Multi-task Learning*
    (ICLR 2021) and implemented by us.

    """
    def __init__(self):
        super(IMTL, self).__init__()

    def init_param(self):
        # learnable per-task loss scaling (log-space)
        self.loss_scale = nn.Parameter(torch.tensor([0.0] * self.task_num, device=self.device))

    def backward(self, losses, **kwargs):
        losses = self.loss_scale.exp() * losses - self.loss_scale
        grads = self._get_grads(losses, mode='backward')
        if self.rep_grad:
            per_grads, grads = grads[0], grads[1]

        grads_unit = grads / torch.norm(grads, p=2, dim=-1, keepdim=True)

        # differences of task-0 gradient (unit gradient) against the others
        D = grads[0:1].repeat(self.task_num - 1, 1) - grads[1:]
        U = grads_unit[0:1].repeat(self.task_num - 1, 1) - grads_unit[1:]

        # closed-form solution of the impartial weighting
        alpha = torch.matmul(torch.matmul(grads[0], U.t()), torch.inverse(torch.matmul(D, U.t())))
        alpha = torch.cat((1 - alpha.sum().unsqueeze(0), alpha), dim=0)

        if self.rep_grad:
            self._backward_new_grads(alpha, per_grads=per_grads)
        else:
            self._backward_new_grads(alpha, grads=grads)
        return alpha.detach().cpu().numpy()

class MoCo(AbsWeighting):
    r"""MoCo.

    This method is proposed in *Mitigating Gradient Bias in Multi-objective
    Learning: A Provably Convergent Approach* (ICLR 2023) and implemented
    based on the authors' shared code (Heshan Fernando: fernah@rpi.edu).

    Args:
        MoCo_beta (float, default=0.5): The learning rate of y.
        MoCo_beta_sigma (float, default=0.5): The decay rate of MoCo_beta.
        MoCo_gamma (float, default=0.1): The learning rate of lambd.
        MoCo_gamma_sigma (float, default=0.5): The decay rate of MoCo_gamma.
        MoCo_rho (float, default=0): The l2 regularization parameter of lambda's update.

    .. warning::
        MoCo is not supported by representation gradients, i.e., ``rep_grad`` must be ``False``.

    """
    def __init__(self):
        super(MoCo, self).__init__()

    def init_param(self):
        self._compute_grad_dim()
        self.step = 0
        # momentum buffer of per-task gradients and the simplex weights
        self.y = torch.zeros(self.task_num, self.grad_dim).to(self.device)
        self.lambd = (torch.ones([self.task_num, ]) / self.task_num).to(self.device)

    def backward(self, losses, **kwargs):
        self.step += 1
        beta, beta_sigma = kwargs['MoCo_beta'], kwargs['MoCo_beta_sigma']
        gamma, gamma_sigma = kwargs['MoCo_gamma'], kwargs['MoCo_gamma_sigma']
        rho = kwargs['MoCo_rho']

        if self.rep_grad:
            raise ValueError('No support method MoCo with representation gradients (rep_grad=True)')
        self._compute_grad_dim()
        grads = self._compute_grad(losses, mode='backward')

        with torch.no_grad():
            # normalize each task gradient and rescale it by its loss value
            for tn in range(self.task_num):
                grads[tn] = grads[tn] / (grads[tn].norm() + 1e-8) * losses[tn]
            # momentum tracking of the gradients and mirror-descent step on lambd
            self.y = self.y - (beta / self.step**beta_sigma) * (self.y - grads)
            self.lambd = F.softmax(
                self.lambd - (gamma / self.step**gamma_sigma)
                * (self.y @ self.y.t() + rho * torch.eye(self.task_num).to(self.device)) @ self.lambd,
                -1)
            new_grads = self.y.t() @ self.lambd

        self._reset_grad(new_grads)
        return self.lambd.detach().cpu().numpy()
class MoDo(AbsWeighting):
    r"""Multi-objective gradient with Double sampling (MoDo).

    This method is proposed in *Three-Way Trade-Off in Multi-Objective
    Learning: Optimization, Generalization and Conflict-Avoidance*
    (NeurIPS 2023; JMLR 2024) and implemented by modifying from the official
    PyTorch implementation.

    Args:
        MoDo_gamma (float, default=0.001): The learning rate of lambd.
        MoDo_rho (float, default=0.1): The l2 regularization parameter of lambda's update.

    """
    def __init__(self):
        super().__init__()

    def init_param(self):
        self.lambd = 1 / self.task_num * torch.ones([self.task_num, ]).to(self.device)

    def _projection2simplex(self, y):
        # Euclidean projection of y onto the probability simplex
        m = len(y)
        sorted_y = torch.sort(y, descending=True)[0]
        running_sum = 0.0
        tmax_f = (torch.sum(y) - 1.0) / m
        for i in range(m - 1):
            running_sum += sorted_y[i]
            tmax = (running_sum - 1) / (i + 1.0)
            if tmax > sorted_y[i + 1]:
                tmax_f = tmax
                break
        return torch.max(y - tmax_f, torch.zeros(m).to(y.device))

    def backward(self, losses, **kwargs):
        # losses: [3, num_tasks] in MoDo (three independently sampled mini-batches)
        assert self.rep_grad == False, "No support method MoDo with representation gradients (rep_grad=True)"

        MoDo_gamma, MoDo_rho = kwargs['MoDo_gamma'], kwargs['MoDo_rho']

        grads = torch.stack([self._get_grads(losses[i], mode='backward') for i in range(3)])

        # decoder grads accumulated over the three passes; only the 3rd
        # gradient updates the shared part, so average the decoder grads
        for task in list(self.decoders.keys()):
            for p in self.decoders[task].parameters():
                p.grad.data = p.grad.data / 3

        self.lambd = self._projection2simplex(
            self.lambd - MoDo_gamma * (grads[0] @ (torch.transpose(grads[1], 0, 1) @ self.lambd) + MoDo_rho * self.lambd))

        self._backward_new_grads(self.lambd, grads=grads[2])
        return self.lambd.detach().cpu().numpy()

class PCGrad(AbsWeighting):
    r"""Project Conflicting Gradients (PCGrad).

    This method is proposed in *Gradient Surgery for Multi-Task Learning*
    (NeurIPS 2020) and implemented by us.

    .. warning::
        PCGrad is not supported by representation gradients, i.e., ``rep_grad`` must be ``False``.

    """
    def __init__(self):
        super(PCGrad, self).__init__()

    def backward(self, losses, **kwargs):
        batch_weight = np.ones(len(losses))
        if self.rep_grad:
            raise ValueError('No support method PCGrad with representation gradients (rep_grad=True)')
        self._compute_grad_dim()
        grads = self._compute_grad(losses, mode='backward')  # [task_num, grad_dim]
        pc_grads = grads.clone()
        for tn_i in range(self.task_num):
            # project task i's gradient away from conflicting task gradients,
            # visiting the other tasks in random order
            task_index = list(range(self.task_num))
            random.shuffle(task_index)
            for tn_j in task_index:
                g_ij = torch.dot(pc_grads[tn_i], grads[tn_j])
                if g_ij < 0:
                    denom = grads[tn_j].norm().pow(2) + 1e-8
                    pc_grads[tn_i] -= g_ij * grads[tn_j] / denom
                    batch_weight[tn_j] -= (g_ij / denom).item()
        self._reset_grad(pc_grads.sum(0))
        return batch_weight
# --- LibMTL/weighting/RLW.py (continued: the `class RLW(AbsWeighting):` header
# and docstring opener are above this chunk) ---
#
# Docstring tail, preserved as comments:
#   This method is proposed in `Reasonable Effectiveness of Random Weighting:
#   A Litmus Test for Multi-Task Learning (TMLR 2022)`_ and implemented by us.

def __init__(self):
    super(RLW, self).__init__()

def backward(self, losses, **kwargs):
    # Draw fresh random weights every step: softmax of i.i.d. standard normals,
    # so the weights are positive and sum to 1; then move them to the model device.
    batch_weight = F.softmax(torch.randn(self.task_num), dim=-1).to(self.device)
    loss = torch.mul(losses, batch_weight).sum()
    loss.backward()
    return batch_weight.detach().cpu().numpy()
# --- LibMTL/weighting/SDMGrad.py (continued: the `class SDMGrad(AbsWeighting):`
# header and docstring opener are above this chunk) ---
#
# Stochastic Direction-oriented Multi-objective Gradient descent (SDMGrad),
# NeurIPS 2023.  Hyperparameters read from **kwargs in `backward`:
#   SDMGrad_lamda (float, default=0.3): the regularization hyperparameter.
#   SDMGrad_niter (int, default=20): the update iteration of loss weights.

def __init__(self):
    super().__init__()

def init_param(self):
    # Uniform initial loss weights over the tasks.
    self.w = 1/self.task_num*torch.ones(self.task_num).to(self.device)

def euclidean_proj_simplex(self, v, s=1):
    """Project a 1-D numpy array ``v`` onto the simplex ``{w : w >= 0, sum(w) = s}``.

    Uses the O(n log n) sorting-based algorithm; returns a float64 array.
    """
    assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s
    v = v.astype(np.float64)
    n, = v.shape
    # Already on the simplex: the projection is the identity.
    # Fix: np.all replaces np.alltrue, which was deprecated and removed in NumPy 2.0.
    if v.sum() == s and np.all(v >= 0):
        return v
    u = np.sort(v)[::-1]  # sorted in descending order
    cssv = np.cumsum(u)
    # Largest index where u[rho]*(rho+1) > cssv[rho]-s: support size of the projection.
    rho = np.nonzero(u * np.arange(1, n + 1) > (cssv - s))[0][-1]
    theta = float(cssv[rho] - s) / (rho + 1)
    w = (v - theta).clip(min=0)
    return w

def backward(self, losses, **kwargs):
    # losses: [3, num_tasks] in SDMGrad — three independent loss evaluations.
    assert self.rep_grad == False, "No support method SDMGrad with representation gradients (rep_grad=True)"

    SDMGrad_lamda, SDMGrad_niter = kwargs['SDMGrad_lamda'], kwargs['SDMGrad_niter']

    grads = []
    for i in range(3):
        grads.append(self._get_grads(losses[i], mode='backward'))

    # zeta drives the final update; the two xi batches give a cross-batch
    # (less biased) estimate of the gradient Gram matrix.
    zeta_grads, xi_grads1, xi_grads2 = grads
    GG = torch.mm(xi_grads1, xi_grads2.t())
    GG_diag = torch.diag(GG)
    # Cross-batch diagonal entries can be negative; clamp before the sqrt.
    GG_diag = torch.where(GG_diag < 0, torch.zeros_like(GG_diag), GG_diag)
    scale = torch.mean(torch.sqrt(GG_diag))
    GG = GG / (scale.pow(2) + 1e-8)
    Gg = torch.mean(GG, dim=1)

    # Projected-gradient inner loop on the loss weights w over the simplex.
    self.w.requires_grad = True
    optimizer = torch.optim.SGD([self.w], lr=5, momentum=0.5)  # lr follows the official implementation
    for i in range(SDMGrad_niter):
        optimizer.zero_grad()
        self.w.grad = torch.mv(GG, self.w.detach()) + SDMGrad_lamda * Gg
        optimizer.step()
        proj = self.euclidean_proj_simplex(self.w.data.cpu().numpy())
        self.w.data.copy_(torch.from_numpy(proj).data)
    self.w.requires_grad = False

    # Combine the w-weighted direction with the average-gradient anchor g0.
    g0 = torch.mean(zeta_grads, dim=0)
    gw = (zeta_grads * self.w.view(-1, 1)).sum(0)
    g = (gw + SDMGrad_lamda * g0) / (1 + SDMGrad_lamda)

    self._reset_grad(g)
    return None
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

class STCH(AbsWeighting):
    r"""Smooth Tchebycheff Scalarization (STCH).

    This method is proposed in `Smooth Tchebycheff Scalarization for
    Multi-Objective Optimization (ICML 2024)`_ and implemented by modifying
    from the official PyTorch implementation.

    Args:
        STCH_mu (float): Smoothing parameter of the log-sum-exp Tchebycheff term.
        STCH_warmup_epoch (int): Number of epochs trained with a plain sum of
            log-losses before the smooth Tchebycheff objective takes over.
    """
    def __init__(self):
        super(STCH, self).__init__()

    def init_param(self):
        self.step = 0
        # Reference (nadir-like) loss vector, estimated over the last warm-up epoch.
        self.nadir_vector = None
        self.average_loss = 0.0
        self.average_loss_count = 0

    def backward(self, losses, **kwargs):
        self.step += 1
        mu = kwargs['STCH_mu']
        warmup_epoch = kwargs['STCH_warmup_epoch']

        batch_weight = np.ones(len(losses))

        if self.epoch < warmup_epoch:
            # Warm-up phase: sum of log-losses (1e-20 guards against log(0)).
            loss = torch.mul(torch.log(losses+1e-20), torch.ones_like(losses).to(self.device)).sum()
            loss.backward()
            return batch_weight
        elif self.epoch == warmup_epoch:
            # Final warm-up epoch: keep training while accumulating the loss
            # statistics that define the nadir vector.
            loss = torch.mul(torch.log(losses+1e-20), torch.ones_like(losses).to(self.device)).sum()
            self.average_loss += losses.detach()
            self.average_loss_count += 1

            loss.backward()
            return batch_weight
        else:
            # Fix: use `is None` instead of `== None` — once nadir_vector holds
            # a tensor, `==` invokes tensor comparison semantics (PEP 8: compare
            # to None with `is`).
            if self.nadir_vector is None:
                self.nadir_vector = self.average_loss / self.average_loss_count
                print(self.nadir_vector)

            # Log-losses normalized by the nadir vector.
            losses = torch.log(losses/self.nadir_vector+1e-20)
            # Subtract the max for a numerically stable log-sum-exp (the dropped
            # constant does not affect gradients).
            max_term = torch.max(losses.data).detach()
            reg_losses = losses - max_term

            # Smooth Tchebycheff objective: mu * log(sum(exp(l/mu))), scaled by task count.
            loss = mu * torch.log(torch.sum(torch.exp(reg_losses/mu))) * self.task_num
            loss.backward()

            return batch_weight
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

class UW(AbsWeighting):
    r"""Uncertainty Weights (UW).

    Proposed in `Multi-Task Learning Using Uncertainty to Weigh Losses for
    Scene Geometry and Semantics (CVPR 2018)`_ and implemented by us.
    """
    def __init__(self):
        super(UW, self).__init__()

    def init_param(self):
        # One learnable log-variance-style scale per task, initialized to -0.5.
        self.loss_scale = nn.Parameter(torch.tensor([-0.5]*self.task_num, device=self.device))

    def backward(self, losses, **kwargs):
        # Total objective: sum_k( loss_k / (2 * exp(s_k)) + s_k / 2 ),
        # where s_k = loss_scale[k] plays the role of a per-task log-variance.
        scale = self.loss_scale
        weighted = losses / (2 * scale.exp()) + scale / 2
        weighted.sum().backward()
        # Report the effective per-task weights 1 / (2 * exp(s_k)).
        return (1 / (2 * torch.exp(scale))).detach().cpu().numpy()
from LibMTL.weighting.FairGrad import FairGrad 21 | from LibMTL.weighting.FAMO import FAMO 22 | from LibMTL.weighting.MoDo import MoDo 23 | from LibMTL.weighting.SDMGrad import SDMGrad 24 | from LibMTL.weighting.UPGrad import UPGrad 25 | 26 | __all__ = ['AbsWeighting', 27 | 'EW', 28 | 'GradNorm', 29 | 'MGDA', 30 | 'UW', 31 | 'DWA', 32 | 'GLS', 33 | 'GradDrop', 34 | 'PCGrad', 35 | 'GradVac', 36 | 'IMTL', 37 | 'CAGrad', 38 | 'Nash_MTL', 39 | 'RLW', 40 | 'MoCo', 41 | 'Aligned_MTL', 42 | 'DB_MTL', 43 | 'STCH', 44 | 'ExcessMTL', 45 | 'FairGrad', 46 | 'FAMO', 47 | 'MoDo', 48 | 'SDMGrad', 49 | 'UPGrad'] -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Documents of LibMTL 2 | 3 | This document is based on [Sphinx](http://sphinx-doc.org/) and uses [Read the Docs](https://readthedocs.org/) for deployment. Besides, it is implemented by referencing from https://github.com/rlworkgroup/garage/tree/master/docs and https://github.com/LibCity/Bigscity-LibCity-Docs. 
4 | 5 | 6 | ``` 7 | pip install -r requirements.txt 8 | make html 9 | ``` 10 | -------------------------------------------------------------------------------- /docs/_build/doctrees/README.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/README.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/autoapi_templates/python/module.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/autoapi_templates/python/module.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/_record/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/_record/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/CGC/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/CGC/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/Cross_stitch/index.doctree: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/Cross_stitch/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/DSelect_k/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/DSelect_k/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/HPS/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/HPS/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/MMoE/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/MMoE/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/MTAN/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/MTAN/index.doctree -------------------------------------------------------------------------------- 
/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/PLE/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/PLE/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/abstract_arch/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/abstract_arch/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/architecture/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/config/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/config/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/index.doctree 
-------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/loss/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/loss/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/metrics/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/metrics/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/model/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/model/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/model/resnet/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/model/resnet/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/model/resnet_dilated/index.doctree: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/model/resnet_dilated/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/trainer/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/trainer/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/utils/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/utils/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/CAGrad/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/CAGrad/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/DWA/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/DWA/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/EW/index.doctree: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/EW/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GLS/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GLS/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GradDrop/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GradDrop/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GradNorm/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GradNorm/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GradVac/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/GradVac/index.doctree -------------------------------------------------------------------------------- 
/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/IMTL/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/IMTL/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/MGDA/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/MGDA/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/PCGrad/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/PCGrad/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/RLW/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/RLW/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/UW/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/UW/index.doctree 
-------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/abstract_weighting/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/abstract_weighting/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/_autoapi/LibMTL/weighting/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/develop/arch.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/develop/arch.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/develop/dataset.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/develop/dataset.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/develop/weighting.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/develop/weighting.doctree 
-------------------------------------------------------------------------------- /docs/_build/doctrees/docs/getting_started/installation.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/getting_started/installation.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/getting_started/introduction.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/getting_started/introduction.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/getting_started/quick_start.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/getting_started/quick_start.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/user_guide/benchmark.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/user_guide/benchmark.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/user_guide/benchmark/nyuv2.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/user_guide/benchmark/nyuv2.doctree 
-------------------------------------------------------------------------------- /docs/_build/doctrees/docs/user_guide/benchmark/office.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/user_guide/benchmark/office.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/user_guide/framework.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/user_guide/framework.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/docs/user_guide/mtl.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/docs/user_guide/mtl.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/environment.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/environment.pickle -------------------------------------------------------------------------------- /docs/_build/doctrees/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/doctrees/index.doctree -------------------------------------------------------------------------------- /docs/_build/html/.buildinfo: 
-------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 3 | config: 979045b8d53d6dcd1826ddb3a4fcb161 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /docs/_build/html/_images/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_images/framework.png -------------------------------------------------------------------------------- /docs/_build/html/_images/multi_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_images/multi_input.png -------------------------------------------------------------------------------- /docs/_build/html/_images/rep_grad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_images/rep_grad.png -------------------------------------------------------------------------------- /docs/_build/html/_sources/README.md.txt: -------------------------------------------------------------------------------- 1 | # Documents of LibMTL 2 | 3 | This document is based on [Sphinx](http://sphinx-doc.org/) and uses [Read the Docs](https://readthedocs.org/) for deployment. Besides, it is implemented by referring to https://github.com/rlworkgroup/garage/tree/master/docs and https://github.com/LibCity/Bigscity-LibCity-Docs. 
-------------------------------------------------------------------------------- /docs/_build/html/_sources/autoapi_templates/python/module.rst.txt: -------------------------------------------------------------------------------- 1 | {% if not obj.display %} 2 | :orphan: 3 | 4 | {% endif %} 5 | :mod:`{{ obj.name }}` 6 | ======={{ "=" * obj.name|length }} 7 | 8 | .. py:module:: {{ obj.name }} 9 | 10 | {% if obj.docstring %} 11 | .. autoapi-nested-parse:: 12 | 13 | {{ obj.docstring|prepare_docstring|indent(3) }} 14 | 15 | {% endif %} 16 | 17 | {% block content %} 18 | {% if obj.all is not none %} 19 | {% set visible_children = obj.children|selectattr("short_name", "in", obj.all)|list %} 20 | {% elif obj.type is equalto("package") %} 21 | {% set visible_children = obj.children|selectattr("display")|list %} 22 | {% else %} 23 | {% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %} 24 | {% endif %} 25 | {% if visible_children %} 26 | {# {{ obj.type|title }} Contents 27 | {{ "-" * obj.type|length }}--------- #} 28 | 29 | {% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %} 30 | {% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %} 31 | {% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %} 32 | 33 | {# {% block classes scoped %} 34 | {% if visible_classes %} 35 | Classes 36 | ------- 37 | 38 | .. autoapisummary:: 39 | 40 | {% for klass in visible_classes %} 41 | {{ klass.id }} 42 | {% endfor %} 43 | 44 | 45 | {% endif %} 46 | {% endblock %} #} 47 | 48 | {# {% block functions scoped %} 49 | {% if visible_functions %} 50 | Functions 51 | --------- 52 | 53 | .. 
autoapisummary:: 54 | 55 | {% for function in visible_functions %} 56 | {{ function.id }} 57 | {% endfor %} 58 | 59 | 60 | {% endif %} 61 | {% endblock %} #} 62 | 63 | {% endif %} 64 | {% for obj_item in visible_children %} 65 | {{ obj_item.rendered|indent(0) }} 66 | {% endfor %} 67 | {% endif %} 68 | {% endblock %} 69 | 70 | {% block subpackages %} 71 | {% set visible_subpackages = obj.subpackages|selectattr("display")|list %} 72 | {% if visible_subpackages %} 73 | {# Subpackages 74 | ----------- #} 75 | .. toctree:: 76 | :titlesonly: 77 | :maxdepth: 1 78 | 79 | {% for subpackage in visible_subpackages %} 80 | {{ subpackage.short_name }}/index.rst 81 | {% endfor %} 82 | 83 | 84 | {% endif %} 85 | {% endblock %} 86 | {# {% block submodules %} 87 | {% set visible_submodules = obj.submodules|selectattr("display")|list %} 88 | {% if visible_submodules %} 89 | Submodules 90 | ---------- 91 | .. toctree:: 92 | :titlesonly: 93 | :maxdepth: 1 94 | 95 | {% for submodule in visible_submodules %} 96 | {{ submodule.short_name }}/index.rst 97 | {% endfor %} 98 | 99 | 100 | {% endif %} 101 | {% endblock %} #} 102 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/_record/index.rst.txt: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | :mod:`LibMTL._record` 4 | ===================== 5 | 6 | .. py:module:: LibMTL._record 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/CGC/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.CGC` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.architecture.CGC 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. 
py:class:: CGC(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.MMoE.MMoE` 14 | 15 | Customized Gate Control (CGC). 16 | 17 | This method is proposed in `Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations (ACM RecSys 2020 Best Paper) `_ \ 18 | and implemented by us. 19 | 20 | :param img_size: The size of input data. For example, [3, 224, 224] for input images with size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The numbers of experts shared for all tasks and specific to each task, respectively. Each expert is the encoder network. 23 | :type num_experts: list 24 | 25 | .. py:method:: forward(self, inputs, task_name=None) 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/Cross_stitch/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.Cross_stitch` 2 | ======================================= 3 | 4 | .. py:module:: LibMTL.architecture.Cross_stitch 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: Cross_stitch(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Cross-stitch Networks (Cross_stitch). 16 | 17 | This method is proposed in `Cross-stitch Networks for Multi-task Learning (CVPR 2016) `_ \ 18 | and implemented by us. 19 | 20 | .. warning:: 21 | - :class:`Cross_stitch` does not work with multiple inputs MTL problem, i.e., ``multi_input`` must be ``False``. 22 | 23 | - :class:`Cross_stitch` is only supported with ResNet-based encoder.
24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/DSelect_k/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.DSelect_k` 2 | ==================================== 3 | 4 | .. py:module:: LibMTL.architecture.DSelect_k 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: DSelect_k(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.MMoE.MMoE` 14 | 15 | DSelect-k. 16 | 17 | This method is proposed in `DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning (NeurIPS 2021) `_ \ 18 | and implemented by modifying from the `official TensorFlow implementation `_. 19 | 20 | :param img_size: The size of input data. For example, [3, 224, 224] for input images with size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The number of experts shared for all tasks. Each expert is the encoder network. 23 | :type num_experts: int 24 | :param num_nonzeros: The number of selected experts. 25 | :type num_nonzeros: int 26 | :param kgamma: A scaling parameter for the smooth-step function. 27 | :type kgamma: float, default=1.0 28 | 29 | .. py:method:: forward(self, inputs, task_name=None) 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/HPS/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.HPS` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.architecture.HPS 5 | 6 | 7 | 8 | 9 | 10 | 11 | ..
py:class:: HPS(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Hard Parameter Sharing (HPS). 16 | 17 | This method is proposed in `Multitask Learning: A Knowledge-Based Source of Inductive Bias (ICML 1993) `_ \ 18 | and implemented by us. 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/MMoE/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.MMoE` 2 | =============================== 3 | 4 | .. py:module:: LibMTL.architecture.MMoE 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: MMoE(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Multi-gate Mixture-of-Experts (MMoE). 16 | 17 | This method is proposed in `Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts (KDD 2018) `_ \ 18 | and implemented by us. 19 | 20 | :param img_size: The size of input data. For example, [3, 224, 224] for input images with size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The number of experts shared for all tasks. Each expert is the encoder network. 23 | :type num_experts: int 24 | 25 | .. py:method:: forward(self, inputs, task_name=None) 26 | 27 | 28 | .. py:method:: get_share_params(self) 29 | 30 | 31 | .. py:method:: zero_grad_share_params(self) 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/MTAN/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.MTAN` 2 | =============================== 3 | 4 | ..
py:module:: LibMTL.architecture.MTAN 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: MTAN(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Multi-Task Attention Network (MTAN). 16 | 17 | This method is proposed in `End-To-End Multi-Task Learning With Attention (CVPR 2019) `_ \ 18 | and implemented by modifying from the `official PyTorch implementation `_. 19 | 20 | .. warning:: 21 | :class:`MTAN` is only supported with ResNet-based encoder. 22 | 23 | 24 | .. py:method:: forward(self, inputs, task_name=None) 25 | 26 | 27 | .. py:method:: get_share_params(self) 28 | 29 | 30 | .. py:method:: zero_grad_share_params(self) 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/PLE/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.PLE` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.architecture.PLE 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: PLE(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Progressive Layered Extraction (PLE). 16 | 17 | This method is proposed in `Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations (ACM RecSys 2020 Best Paper) `_ \ 18 | and implemented by us. 19 | 20 | :param img_size: The size of input data. For example, [3, 224, 224] for input images with size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The numbers of experts shared for all tasks and specific to each task, respectively. Each expert is the encoder network. 23 | :type num_experts: list 24 | 25 | ..
warning:: 26 | - :class:`PLE` does not work with multiple inputs MTL problem, i.e., ``multi_input`` must be ``False``. 27 | - :class:`PLE` is only supported with ResNet-based encoder. 28 | 29 | 30 | .. py:method:: forward(self, inputs, task_name=None) 31 | 32 | 33 | .. py:method:: get_share_params(self) 34 | 35 | 36 | .. py:method:: zero_grad_share_params(self) 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/architecture/abstract_arch/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.abstract_arch` 2 | ======================================== 3 | 4 | .. py:module:: LibMTL.architecture.abstract_arch 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsArchitecture(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`torch.nn.Module` 14 | 15 | An abstract class for MTL architectures. 16 | 17 | :param task_name: A list of strings for all tasks. 18 | :type task_name: list 19 | :param encoder_class: A neural network class. 20 | :type encoder_class: class 21 | :param decoders: A dictionary of name-decoder pairs of type (:class:`str`, :class:`torch.nn.Module`). 22 | :type decoders: dict 23 | :param rep_grad: If ``True``, the gradient of the representation for each task can be computed. 24 | :type rep_grad: bool 25 | :param multi_input: Is ``True`` if each task has its own input data, ``False`` otherwise. 26 | :type multi_input: bool 27 | :param device: The device where model and data will be allocated. 28 | :type device: torch.device 29 | :param kwargs: A dictionary of hyperparameters of architecture methods. 30 | :type kwargs: dict 31 | 32 | .. py:method:: forward(self, inputs, task_name=None) 33 | 34 | :param inputs: The input data. 
35 | :type inputs: torch.Tensor 36 | :param task_name: The task name corresponding to ``inputs`` if ``multi_input`` is ``True``. 37 | :type task_name: str, default=None 38 | 39 | :returns: A dictionary of name-prediction pairs of type (:class:`str`, :class:`torch.Tensor`). 40 | :rtype: dict 41 | 42 | 43 | .. py:method:: get_share_params(self) 44 | 45 | Return the shared parameters of the model. 46 | 47 | 48 | 49 | .. py:method:: zero_grad_share_params(self) 50 | 51 | Set gradients of the shared parameters to zero. 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/config/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.config` 2 | ==================== 3 | 4 | .. py:module:: LibMTL.config 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:data:: LibMTL_args 12 | 13 | 14 | 15 | 16 | .. py:function:: prepare_args(params) 17 | 18 | Return the configuration of hyperparameters, optimizier, and learning rate scheduler. 19 | 20 | :param params: The command-line arguments. 21 | :type params: argparse.Namespace 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/loss/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.loss` 2 | ================== 3 | 4 | .. py:module:: LibMTL.loss 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsLoss 12 | 13 | Bases: :py:obj:`object` 14 | 15 | An abstract class for loss function. 16 | 17 | 18 | .. py:method:: compute_loss(self, pred, gt) 19 | :property: 20 | 21 | Calculate the loss. 22 | 23 | :param pred: The prediction tensor. 24 | :type pred: torch.Tensor 25 | :param gt: The ground-truth tensor. 26 | :type gt: torch.Tensor 27 | 28 | :returns: The loss. 29 | :rtype: torch.Tensor 30 | 31 | 32 | 33 | .. 
py:class:: CELoss 34 | 35 | Bases: :py:obj:`AbsLoss` 36 | 37 | The cross entropy loss function. 38 | 39 | 40 | .. py:method:: compute_loss(self, pred, gt) 41 | 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/metrics/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.metrics` 2 | ===================== 3 | 4 | .. py:module:: LibMTL.metrics 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsMetric 12 | 13 | Bases: :py:obj:`object` 14 | 15 | An abstract class for the performance metrics of a task. 16 | 17 | .. attribute:: record 18 | 19 | A list of the metric scores in every iteration. 20 | 21 | :type: list 22 | 23 | .. attribute:: bs 24 | 25 | A list of the number of data in every iteration. 26 | 27 | :type: list 28 | 29 | .. py:method:: update_fun(self, pred, gt) 30 | :property: 31 | 32 | Calculate the metric scores in every iteration and update :attr:`record`. 33 | 34 | :param pred: The prediction tensor. 35 | :type pred: torch.Tensor 36 | :param gt: The ground-truth tensor. 37 | :type gt: torch.Tensor 38 | 39 | 40 | .. py:method:: score_fun(self) 41 | :property: 42 | 43 | Calculate the final score (when an epoch ends). 44 | 45 | :returns: A list of metric scores. 46 | :rtype: list 47 | 48 | 49 | .. py:method:: reinit(self) 50 | 51 | Reset :attr:`record` and :attr:`bs` (when an epoch ends). 52 | 53 | 54 | 55 | 56 | .. py:class:: AccMetric 57 | 58 | Bases: :py:obj:`AbsMetric` 59 | 60 | Calculate the accuracy. 61 | 62 | 63 | .. py:method:: update_fun(self, pred, gt) 64 | 65 | 66 | 67 | 68 | ..
py:method:: score_fun(self) 69 | 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/model/resnet_dilated/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.model.resnet_dilated` 2 | ================================== 3 | 4 | .. py:module:: LibMTL.model.resnet_dilated 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: ResnetDilated(orig_resnet, dilate_scale=8) 12 | 13 | Bases: :py:obj:`torch.nn.Module` 14 | 15 | Base class for all neural network modules. 16 | 17 | Your models should also subclass this class. 18 | 19 | Modules can also contain other Modules, allowing to nest them in 20 | a tree structure. You can assign the submodules as regular attributes:: 21 | 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | 25 | class Model(nn.Module): 26 | def __init__(self): 27 | super(Model, self).__init__() 28 | self.conv1 = nn.Conv2d(1, 20, 5) 29 | self.conv2 = nn.Conv2d(20, 20, 5) 30 | 31 | def forward(self, x): 32 | x = F.relu(self.conv1(x)) 33 | return F.relu(self.conv2(x)) 34 | 35 | Submodules assigned in this way will be registered, and will have their 36 | parameters converted too when you call :meth:`to`, etc. 37 | 38 | .. py:method:: forward(self, x) 39 | 40 | Defines the computation performed at every call. 41 | 42 | Should be overridden by all subclasses. 43 | 44 | .. note:: 45 | Although the recipe for forward pass needs to be defined within 46 | this function, one should call the :class:`Module` instance afterwards 47 | instead of this since the former takes care of running the 48 | registered hooks while the latter silently ignores them. 49 | 50 | 51 | .. py:method:: forward_stage(self, x, stage) 52 | 53 | 54 | 55 | .. 
py:function:: resnet_dilated(basenet, pretrained=True, dilate_scale=8) 56 | 57 | Dilated Residual Network models from `"Dilated Residual Networks" `_ 58 | 59 | :param basenet: The type of ResNet. 60 | :type basenet: str 61 | :param pretrained: If True, returns a model pre-trained on ImageNet. 62 | :type pretrained: bool 63 | :param dilate_scale: The type of dilating process. 64 | :type dilate_scale: {8, 16}, default=8 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/utils/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.utils` 2 | =================== 3 | 4 | .. py:module:: LibMTL.utils 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:function:: set_random_seed(seed) 12 | 13 | Set the random seed for reproducibility. 14 | 15 | :param seed: The random seed. 16 | :type seed: int, default=0 17 | 18 | 19 | .. py:function:: set_device(gpu_id) 20 | 21 | Set the device where model and data will be allocated. 22 | 23 | :param gpu_id: The id of gpu. 24 | :type gpu_id: str, default='0' 25 | 26 | 27 | .. py:function:: count_parameters(model) 28 | 29 | Calculates the number of parameters for a model. 30 | 31 | :param model: A neural network module. 32 | :type model: torch.nn.Module 33 | 34 | 35 | .. py:function:: count_improvement(base_result, new_result, weight) 36 | 37 | Calculate the improvement between two results, 38 | 39 | .. math:: 40 | \Delta_{\mathrm{p}}=100\%\times \frac{1}{T}\sum_{t=1}^T 41 | \frac{1}{M_t}\sum_{m=1}^{M_t}\frac{(-1)^{w_{t,m}}(B_{t,m}-N_{t,m})}{N_{t,m}}. 42 | 43 | :param base_result: A dictionary of scores of all metrics of all tasks. 44 | :type base_result: dict 45 | :param new_result: The same structure as ``base_result``. 46 | :type new_result: dict 47 | :param weight: The same structure as ``base_result`` while each element is a binary integer representing whether higher or lower score is better.
48 | :type weight: dict 49 | 50 | :returns: The improvement between ``new_result`` and ``base_result``. 51 | :rtype: float 52 | 53 | Examples:: 54 | 55 | base_result = {'A': [96, 98], 'B': [0.2]} 56 | new_result = {'A': [93, 99], 'B': [0.5]} 57 | weight = {'A': [1, 0], 'B': [1]} 58 | 59 | print(count_improvement(base_result, new_result, weight)) 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/CAGrad/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.CAGrad` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.weighting.CAGrad 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: CAGrad 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Conflict-Averse Gradient descent (CAGrad). 16 | 17 | This method is proposed in `Conflict-Averse Gradient Descent for Multi-task learning (NeurIPS 2021) `_ \ 18 | and implemented by modifying from the `official PyTorch implementation `_. 19 | 20 | :param calpha: A hyperparameter that controls the convergence rate. 21 | :type calpha: float, default=0.5 22 | :param rescale: The type of gradient rescale. 23 | :type rescale: {0, 1, 2}, default=1 24 | 25 | .. warning:: 26 | CAGrad is not supported with representation gradients, i.e., ``rep_grad`` must be ``False``. 27 | 28 | 29 | .. py:method:: backward(self, losses, **kwargs) 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/DWA/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.DWA` 2 | =========================== 3 | 4 | .. py:module:: LibMTL.weighting.DWA 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. 
py:class:: DWA 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Dynamic Weight Average (DWA). 16 | 17 | This method is proposed in `End-To-End Multi-Task Learning With Attention (CVPR 2019) `_ \ 18 | and implemented by modifying from the `official PyTorch implementation `_. 19 | 20 | :param T: The softmax temperature. 21 | :type T: float, default=2.0 22 | 23 | .. py:method:: backward(self, losses, **kwargs) 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/EW/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.EW` 2 | ========================== 3 | 4 | .. py:module:: LibMTL.weighting.EW 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: EW 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Equally Weighting (EW). 16 | 17 | The loss weight for each task is always ``1 / T`` in every iteration, where ``T`` means the number of tasks. 18 | 19 | 20 | .. py:method:: backward(self, losses, **kwargs) 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/GLS/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GLS` 2 | =========================== 3 | 4 | .. py:module:: LibMTL.weighting.GLS 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GLS 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Geometric Loss Strategy (GLS). 16 | 17 | This method is proposed in `MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning (CVPR 2019 workshop) `_ \ 18 | and implemented by us. 19 | 20 | 21 | .. 
py:method:: backward(self, losses, **kwargs) 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/GradDrop/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GradDrop` 2 | ================================ 3 | 4 | .. py:module:: LibMTL.weighting.GradDrop 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GradDrop 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Gradient Sign Dropout (GradDrop). 16 | 17 | This method is proposed in `Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout (NeurIPS 2020) `_ \ 18 | and implemented by us. 19 | 20 | :param leak: The leak parameter for the weighting matrix. 21 | :type leak: float, default=0.0 22 | 23 | .. warning:: 24 | GradDrop is not supported with parameter gradients, i.e., ``rep_grad`` must be ``True``. 25 | 26 | 27 | .. py:method:: backward(self, losses, **kwargs) 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/GradNorm/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GradNorm` 2 | ================================ 3 | 4 | .. py:module:: LibMTL.weighting.GradNorm 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GradNorm 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Gradient Normalization (GradNorm). 16 | 17 | This method is proposed in `GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks (ICML 2018) `_ \ 18 | and implemented by us. 19 | 20 | :param alpha: The strength of the restoring force which pulls tasks back to a common training rate. 21 | :type alpha: float, default=1.5 22 | 23 | .. 
py:method:: init_param(self) 24 | 25 | 26 | .. py:method:: backward(self, losses, **kwargs) 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/GradVac/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GradVac` 2 | =============================== 3 | 4 | .. py:module:: LibMTL.weighting.GradVac 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GradVac 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Gradient Vaccine (GradVac). 16 | 17 | This method is proposed in `Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) `_ \ 18 | and implemented by us. 19 | 20 | :param beta: The exponential moving average (EMA) decay parameter. 21 | :type beta: float, default=0.5 22 | 23 | .. warning:: 24 | GradVac is not supported with representation gradients, i.e., ``rep_grad`` must be ``False``. 25 | 26 | 27 | .. py:method:: backward(self, losses, **kwargs) 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/IMTL/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.IMTL` 2 | ============================ 3 | 4 | .. py:module:: LibMTL.weighting.IMTL 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: IMTL 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Impartial Multi-task Learning (IMTL). 16 | 17 | This method is proposed in `Towards Impartial Multi-task Learning (ICLR 2021) `_ \ 18 | and implemented by us. 19 | 20 | 21 | .. py:method:: init_param(self) 22 | 23 | 24 | .. 
py:method:: backward(self, losses, **kwargs) 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/MGDA/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.MGDA` 2 | ============================ 3 | 4 | .. py:module:: LibMTL.weighting.MGDA 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: MGDA 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Multiple Gradient Descent Algorithm (MGDA). 16 | 17 | This method is proposed in `Multi-Task Learning as Multi-Objective Optimization (NeurIPS 2018) `_ \ 18 | and implemented by modifying from the `official PyTorch implementation `_. 19 | 20 | :param mgda_gn: The type of gradient normalization. 21 | :type mgda_gn: {'none', 'l2', 'loss', 'loss+'}, default='none' 22 | 23 | .. py:method:: backward(self, losses, **kwargs) 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/PCGrad/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.PCGrad` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.weighting.PCGrad 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: PCGrad 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Project Conflicting Gradients (PCGrad). 16 | 17 | This method is proposed in `Gradient Surgery for Multi-Task Learning (NeurIPS 2020) `_ \ 18 | and implemented by us. 19 | 20 | .. warning:: 21 | PCGrad is not supported with representation gradients, i.e., ``rep_grad`` must be ``False``. 22 | 23 | 24 | .. 
py:method:: backward(self, losses, **kwargs) 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/RLW/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.RLW` 2 | =========================== 3 | 4 | .. py:module:: LibMTL.weighting.RLW 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: RLW 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Random Loss Weighting (RLW). 16 | 17 | This method is proposed in `A Closer Look at Loss Weighting in Multi-Task Learning (arXiv:2111.10603) `_ \ 18 | and implemented by us. 19 | 20 | :param dist: The type of distribution where the loss weights are sampled from. 21 | :type dist: {'Uniform', 'Normal', 'Dirichlet', 'Bernoulli', 'constrained_Bernoulli'}, default='Normal' 22 | 23 | .. py:method:: backward(self, losses, **kwargs) 24 | 25 | :param losses: A list of loss of each task. 26 | :type losses: list 27 | :param kwargs: A dictionary of hyperparameters of weighting methods. 28 | :type kwargs: dict 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/UW/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.UW` 2 | ========================== 3 | 4 | .. py:module:: LibMTL.weighting.UW 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: UW 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Uncertainty Weights (UW). 16 | 17 | This method is proposed in `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (CVPR 2018) `_ \ 18 | and implemented by us. 19 | 20 | 21 | .. py:method:: init_param(self) 22 | 23 | 24 | ..
py:method:: backward(self, losses, **kwargs) 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/_autoapi/LibMTL/weighting/abstract_weighting/index.rst.txt: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.abstract_weighting` 2 | ========================================== 3 | 4 | .. py:module:: LibMTL.weighting.abstract_weighting 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsWeighting 12 | 13 | Bases: :py:obj:`torch.nn.Module` 14 | 15 | An abstract class for weighting strategies. 16 | 17 | 18 | .. py:method:: init_param(self) 19 | 20 | Define and initialize some trainable parameters required by specific weighting methods. 21 | 22 | 23 | 24 | .. py:method:: backward(self, losses, **kwargs) 25 | :property: 26 | 27 | :param losses: A list of loss of each task. 28 | :type losses: list 29 | :param kwargs: A dictionary of hyperparameters of weighting methods. 30 | :type kwargs: dict 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/develop/arch.md.txt: -------------------------------------------------------------------------------- 1 | ## Customize an Architecture 2 | 3 | Here we would like to introduce how to customize a new architecture with the support of ``LibMTL``. 4 | 5 | ### Create a New Architecture Class 6 | 7 | ```eval_rst 8 | Firstly, you need to create a new architecture class by inheriting class :class:`LibMTL.architecture.AbsArchitecture`. 
9 | ``` 10 | 11 | ```python 12 | from LibMTL.architecture import AbsArchitecture 13 | 14 | class NewArchitecture(AbsArchitecture): 15 | def __init__(self, task_name, encoder_class, decoders, rep_grad, 16 | multi_input, device, **kwargs): 17 | super(NewArchitecture, self).__init__(task_name, encoder_class, decoders, rep_grad, 18 | multi_input, device, **kwargs) 19 | ``` 20 | 21 | ### Rewrite Corresponding Methods 22 | 23 | ```eval_rst 24 | There are four important functions in :class:`LibMTL.architecture.AbsArchitecture`. We will introduce them in detail as follows. 25 | 26 | - :func:`forward`: The forward function and its input and output format can be found in :func:`LibMTL.architecture.AbsArchitecture.forward`. To rewrite this function, you need to consider the case of ``multi-input`` and ``multi-label`` (refer to `here <../user_guide/mtl.html#network-architecture>`_) and the case of ``rep-grad`` and ``param-grad`` (refer to `here <../user_guide/mtl.html#weighting-strategy>`_) if you would like to combine your architecture with more weighting strategies or apply your architecture to more datasets. 27 | - :func:`get_share_params`: This function is used to return the shared parameters of the model. It returns all the parameters of the encoder by default. You can rewrite it if necessary. 28 | - :func:`zero_grad_share_params`: This function is used to set gradients of the shared parameters to zero. It will set the gradients of all the encoder parameters to zero by default. You can rewrite it if necessary. 29 | - :func:`_prepare_rep`: This function is used to allow computing the gradients for representations. More details can be found `here <../../_modules/LibMTL/architecture/abstract_arch.html#AbsArchitecture>`_.
30 | ``` 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/develop/weighting.md.txt: -------------------------------------------------------------------------------- 1 | ## Customize a Weighting Strategy 2 | 3 | Here we would like to introduce how to customize a new weighting strategy with the support of ``LibMTL``. 4 | 5 | ### Create a New Weighting Class 6 | 7 | ```eval_rst 8 | Firstly, you need to create a new weighting class by inheriting class :class:`LibMTL.weighting.AbsWeighting`. 9 | ``` 10 | 11 | ```python 12 | from LibMTL.weighting import AbsWeighting 13 | 14 | class NewWeighting(AbsWeighting): 15 | def __init__(self): 16 | super(NewWeighting, self).__init__() 17 | ``` 18 | 19 | ### Rewrite Corresponding Methods 20 | 21 | ```eval_rst 22 | There are four important functions in :class:`LibMTL.weighting.AbsWeighting`. We will introduce them in detail as follows. 23 | 24 | - :func:`backward`: The main function of a weighting strategy; its input and output format can be found in :func:`LibMTL.weighting.AbsWeighting.backward`. To rewrite this function, you need to consider the case of ``multi-input`` and ``multi-label`` (refer to `here <../user_guide/mtl.html#network-architecture>`_) and the case of ``rep-grad`` and ``param-grad`` (refer to `here <../user_guide/mtl.html#weighting-strategy>`_) if you would like to combine your weighting method with more architectures or apply your method to more datasets. 25 | - :func:`init_param`: This function is used to define and initialize some trainable parameters. It does nothing by default and can be rewritten if necessary. 26 | - :func:`_get_grads`: This function is used to return the gradients of representations or shared parameters (covering the case of ``rep-grad`` and ``param-grad``).
27 | - :func:`_backward_new_grads`: This function is used to reset the gradients and make a backward (covering the case of ``rep-grad`` and ``param-grad``). 28 | 29 | The :func:`_get_grads` and :func:`_backward_new_grads` functions are very useful to rewrite the :func:`backward` function and you can find more details about them in `here <../../_modules/LibMTL/weighting/abstract_weighting.html#AbsWeighting>`_. 30 | ``` 31 | 32 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/getting_started/installation.md.txt: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | ### Dependencies 4 | 5 | To install ``LibMTL``, you need to setup the following libraries: 6 | 7 | - Python >= 3.7 8 | - PyTorch >= 1.8.0 9 | - torchvision >= 0.9.0 10 | - numpy >= 1.20 11 | 12 | ### User Installation 13 | 14 | #### Using PyPi 15 | 16 | The simplest way to install `LibMTL` is using `pip`. 17 | 18 | ```shell 19 | pip install -U LibMTL 20 | ``` 21 | 22 | #### Using Source Code 23 | 24 | If you prefer, you can clone the source code from the GitHub and run the setup.py file. 25 | 26 | ```shell 27 | git clone https://github.com/median-research-group/LibMTL.git 28 | cd LibMTL 29 | python setup.py install 30 | ``` 31 | 32 | 33 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/getting_started/quick_start.md.txt: -------------------------------------------------------------------------------- 1 | ## Quick Start 2 | 3 | We use NYUv2 dataset as an example to show how to use ``LibMTL``. More details and results are provided here. 4 | 5 | ### Download Dataset 6 | 7 | The NYUv2 dataset we used is pre-processed by [mtan](https://github.com/lorenmt/mtan). You can download this dataset [here](https://www.dropbox.com/sh/86nssgwm6hm3vkb/AACrnUQ4GxpdrBbLjb6n-mWNa?dl=0). 
The directory structure is as follows: 8 | 9 | ```shell 10 | */nyuv2/ 11 | ├── train 12 | │   ├── depth 13 | │   ├── image 14 | │   ├── label 15 | │   └── normal 16 | └── val 17 | ├── depth 18 | ├── image 19 | ├── label 20 | └── normal 21 | ``` 22 | 23 | The NYUv2 dataset is a multi-label dataset, which includes three tasks: 13-class semantic segmentation, depth estimation, and surface normal prediction. ``image`` contains the input images and ``label``, ``depth``, ``normal`` contains the labels for three tasks, respectively. We train the MTL model with the data in ``train`` and evaluate on ``val``. 24 | 25 | ### Run a Model 26 | 27 | The complete training code of NYUv2 dataset are provided [here](https://github.com/median-research-group/LibMTL/examples/nyu). The file ``train_nyu.py`` is the main file of training on NYUv2 dataset. 28 | 29 | You can find the command-line arguments by running the following command. 30 | 31 | ```shell 32 | python train_nyu.py -h 33 | ``` 34 | 35 | For instance, running the following command will start training a MTL model with EW and HPS on NYUv2 dataset. 
36 | 37 | ```shell 38 | python train_nyu.py --weighting EW --arch HPS --dataset_path */nyuv2 --gpu_id 0 --scheduler step 39 | ``` 40 | 41 | If everything works fine, you will see the following outputs which includes the training configurations and the number of model parameters: 42 | 43 | ``` 44 | ======================================== 45 | General Configuration: 46 | Wighting: EW 47 | Architecture: HPS 48 | Rep_Grad: False 49 | Multi_Input: False 50 | Seed: 0 51 | Device: cuda:0 52 | Optimizer Configuration: 53 | optim: adam 54 | lr: 0.0001 55 | weight_decay: 1e-05 56 | Scheduler Configuration: 57 | scheduler: step 58 | step_size: 100 59 | gamma: 0.5 60 | ======================================== 61 | Total Params: 71888721 62 | Trainable Params: 71888721 63 | Non-trainable Params: 0 64 | ======================================== 65 | ``` 66 | 67 | Next, the results will be printed in following format: 68 | 69 | ``` 70 | LOG FORMAT | segmentation_LOSS mIoU pixAcc | depth_LOSS abs_err rel_err | normal_LOSS mean median <11.25 <22.5 <30 | TIME 71 | Epoch: 0000 | TRAIN: 1.4417 0.2494 0.5717 | 1.4941 1.4941 0.5002 | 0.3383 43.1593 38.2601 0.0913 0.2639 0.3793 | Time: 81.6612 | TEST: 1.0898 0.3589 0.6676 | 0.7027 0.7027 0.2615 | 0.2143 32.8732 29.4323 0.1734 0.3878 0.5090 | Time: 11.9699 72 | Epoch: 0001 | TRAIN: 0.8958 0.4194 0.7201 | 0.7011 0.7011 0.2448 | 0.1993 31.5235 27.8404 0.1826 0.4060 0.5361 | Time: 82.2399 | TEST: 0.9980 0.4189 0.6868 | 0.6274 0.6274 0.2347 | 0.1991 31.0144 26.5077 0.2065 0.4332 0.5551 | Time: 12.0278 73 | ``` 74 | 75 | If the training process ends, the best result on ``val`` will be printed as follows: 76 | 77 | ``` 78 | Best Result: Epoch 65, result {'segmentation': [0.5377492904663086, 0.7544658184051514], 'depth': [0.38453552363844823, 0.1605487049810748], 'normal': [23.573742, 17.04381, 0.35038458555943763, 0.609274380451927, 0.7207172795833373]} 79 | ``` 80 | 81 | 82 | 83 | 
-------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/user_guide/benchmark.rst.txt: -------------------------------------------------------------------------------- 1 | Run a Benchmark 2 | =============== 3 | 4 | Here we will introduce some MTL benchmark datasets and show how to run models on them for a fair comparison. 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | benchmark/nyuv2 10 | benchmark/office 11 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/user_guide/benchmark/office.md.txt: -------------------------------------------------------------------------------- 1 | ## Office-31 and Office-Home 2 | 3 | ```eval_rst 4 | The Office-31 dataset :cite:`saenko2010adapting` consists of three domains: Amazon, DSLR, and Webcam, where each domain contains 31 object categories. It can be downloaded `here `_. This dataset contains 4,110 labeled images and we randomly split these samples with 60\% for training, 20\% for validation, and the rest 20\% for test. 5 | 6 | The Office-Home dataset :cite:`venkateswara2017deep` has four domains: Artistic images (abbreviated as Art), Clip art, Product images, and Real-world images. It can be downloaded `here `_. This dataset has 15,500 labeled images in total and each domain contains 65 classes. We divide the entire data in the same proportion as Office-31. 7 | 8 | For both datasets, we consider the multi-class classification problem on each domain as a task. Thus, the ``multi_input`` must be ``True`` for both office datasets. 9 | 10 | The training code is available in ``examples/office``. We used the ResNet-18 network pretrained on the ImageNet dataset followed by a fully connected layer as a shared encoder among tasks and a fully connected layer is applied as a task-specific output layer for each task. All the input images are resized to 3x224x224. 
11 | ``` 12 | 13 | ### Run a Model 14 | 15 | The script ``train_office.py`` is the main file for training and evaluating a MTL model on the Office-31 or Office-Home dataset. A set of command-line arguments is provided to allow users to adjust the training parameter configuration. 16 | 17 | Some important arguments are described as follows. 18 | 19 | ```eval_rst 20 | - ``weighting``: The weighting strategy. Refer to `here <../_autoapi/LibMTL/weighting/index.html>`_. 21 | - ``arch``: The MTL architecture. Refer to `here <../_autoapi/LibMTL/architecture/index.html>`_. 22 | - ``gpu_id``: The id of gpu. Default to '0'. 23 | - ``seed``: The random seed for reproducibility. Default to 0. 24 | - ``optim``: The type of the optimizer. We recommend to use 'adam' here. 25 | - ``dataset``: Training on Office-31 or Office-Home. Options: 'office-31', 'office-home'. 26 | - ``dataset_path``: The path of the Office-31 or Office-Home dataset. 27 | - ``bs``: The batch size of training, validation, and test data. Default to 64. 28 | ``` 29 | 30 | The complete command-line arguments and their descriptions can be found by running the following command. 31 | 32 | ```shell 33 | python train_office.py -h 34 | ``` 35 | 36 | If you understand those command-line arguments, you can train a MTL model by running a command like this. 37 | 38 | ```shell 39 | python train_office.py --weighting WEIGHTING --arch ARCH --dataset_path PATH --gpu_id GPU_ID --multi_input 40 | ``` 41 | 42 | ### References 43 | 44 | ```eval_rst 45 | .. 
bibliography:: 46 | :style: unsrt 47 | :filter: docname in docnames 48 | ``` 49 | 50 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/docs/user_guide/framework.md.txt: -------------------------------------------------------------------------------- 1 | ## Overall Framework 2 | 3 | ``LibMTL`` provides a unified running framework to train a MTL model with some kind of architectures or weighting strategies on a given dataset. The overall framework is shown below. There are five modules to support to run. We introduce them as follows. 4 | 5 | - **Config Module**: Responsible for all the configuration parameters involved in the running framework, including the parameters of optimizer and learning rate scheduler, the hyper-parameters of MTL model, training configuration like batch size, total epoch, random seed and so on. 6 | - **Dataloaders Module**: Responsible for data pre-processing and loading. 7 | - **Model Module**: Responsible for inheriting classes architecture and weighting and instantiating a MTL model. Note that the architecture and the weighting strategy determine the forward and backward processes of the MTL model, respectively. 8 | - **Losses Module**: Responsible for computing the loss for each task. 9 | - **Metrics Module**: Responsible for evaluating the MTL model and calculating the metric scores for each task. 10 | 11 | ```eval_rst 12 | .. figure:: ../images/framework.png 13 | :scale: 100% 14 | ``` 15 | 16 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | .. LibMTL documentation master file, created by 2 | sphinx-quickstart on Thu Nov 25 17:02:04 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 
5 | 6 | LibMTL: A PyTorch Library for Multi-Task Learning 7 | ================================================= 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :caption: Getting Started: 12 | 13 | docs/getting_started/introduction 14 | docs/getting_started/installation 15 | docs/getting_started/quick_start 16 | 17 | .. toctree:: 18 | :maxdepth: 1 19 | :caption: User Guide: 20 | 21 | docs/user_guide/mtl 22 | docs/user_guide/framework 23 | docs/user_guide/benchmark 24 | 25 | .. toctree:: 26 | :maxdepth: 1 27 | :caption: Developer Guide: 28 | 29 | docs/develop/dataset 30 | docs/develop/arch 31 | docs/develop/weighting 32 | 33 | .. toctree:: 34 | :maxdepth: 1 35 | :caption: API Reference: 36 | 37 | docs/_autoapi/LibMTL/index 38 | docs/_autoapi/LibMTL/loss/index 39 | docs/_autoapi/LibMTL/utils/index 40 | docs/_autoapi/LibMTL/model/index 41 | docs/_autoapi/LibMTL/config/index 42 | docs/_autoapi/LibMTL/metrics/index 43 | docs/_autoapi/LibMTL/weighting/index 44 | docs/_autoapi/LibMTL/architecture/index 45 | 46 | 47 | 48 | Indices and tables 49 | ================== 50 | 51 | * :ref:`genindex` 52 | * :ref:`modindex` 53 | * :ref:`search` 54 | -------------------------------------------------------------------------------- /docs/_build/html/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-style:normal;font-weight:400;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) 
format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#FontAwesome) format("svg")}.fa:before{font-family:FontAwesome;font-style:normal;font-weight:400;line-height:1}.fa:before,a .fa{text-decoration:inherit}.fa:before,a .fa,li .fa{display:inline-block}li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions 
.rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff -------------------------------------------------------------------------------- 
/docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold-italic.woff: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/lato-bold-italic.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/lato-bold.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/lato-normal.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/css/fonts/lato-normal.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '', 4 | LANGUAGE: 'None', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | LINK_SUFFIX: '.html', 9 | HAS_SOURCE: true, 10 | SOURCELINK_SUFFIX: '.txt', 11 | NAVIGATION_WITH_KEYS: false 12 | }; -------------------------------------------------------------------------------- /docs/_build/html/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/file.png -------------------------------------------------------------------------------- /docs/_build/html/_static/graphviz.css: 
-------------------------------------------------------------------------------- 1 | /* 2 | * graphviz.css 3 | * ~~~~~~~~~~~~ 4 | * 5 | * Sphinx stylesheet -- graphviz extension. 6 | * 7 | * :copyright: Copyright 2007-2021 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | 12 | img.graphviz { 13 | border: 0; 14 | max-width: 100%; 15 | } 16 | 17 | object.graphviz { 18 | max-width: 100%; 19 | } 20 | -------------------------------------------------------------------------------- /docs/_build/html/_static/js/badge_only.js: -------------------------------------------------------------------------------- 1 | !function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}}); -------------------------------------------------------------------------------- /docs/_build/html/_static/js/html5shiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3 | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var 
c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=t.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=t.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),t.elements=c+" "+a,j(b)}function f(a){var b=s[a[q]];return b||(b={},r++,a[q]=r,s[r]=b),b}function g(a,c,d){if(c||(c=b),l)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():p.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||o.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),l)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return t.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(t,b.frag)}function j(a){a||(a=b);var d=f(a);return!t.shivCSS||k||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),l||i(a,d),a}var k,l,m="3.7.3-pre",n=a.html5||{},o=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,p=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,q="_html5shiv",r=0,s={};!function(){try{var a=b.createElement("a");a.innerHTML="",k="hidden"in a,l=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof 
a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){k=!0,l=!0}}();var t={elements:n.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:m,shivCSS:n.shivCSS!==!1,supportsUnknownElements:l,shivMethods:n.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=t,j(b),"object"==typeof module&&module.exports&&(module.exports=t)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /docs/_build/html/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/minus.png -------------------------------------------------------------------------------- /docs/_build/html/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/_static/plus.png -------------------------------------------------------------------------------- /docs/_build/html/_static/theme_overrides.css: -------------------------------------------------------------------------------- 1 | /* override table width restrictions*/ 2 | @media screen and (min-width: 767px) { 3 | 4 | .wy-table-responsive table td { 5 | /* !important prevents the common CSS stylesheets from overriding 6 | this as on RTD they are loaded after this stylesheet */ 7 | white-space: normal !important; 8 | } 9 | 10 | .wy-table-responsive { 11 | overflow: visible !important; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /docs/_build/html/objects.inv: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/_build/html/objects.inv -------------------------------------------------------------------------------- /docs/_static/theme_overrides.css: -------------------------------------------------------------------------------- 1 | /* override table width restrictions*/ 2 | @media screen and (min-width: 767px) { 3 | 4 | .wy-table-responsive table td { 5 | /* !important prevents the common CSS stylesheets from overriding 6 | this as on RTD they are loaded after this stylesheet */ 7 | white-space: normal !important; 8 | } 9 | 10 | .wy-table-responsive { 11 | overflow: visible !important; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /docs/_templates/footer.html: -------------------------------------------------------------------------------- 1 | {% extends "!footer.html" %} 2 | {% block extrafooter %} 3 |

4 |

Made with ❤ at and  

5 | {{ super() }} 6 | {% endblock %} 7 | -------------------------------------------------------------------------------- /docs/autoapi_templates/python/module.rst: -------------------------------------------------------------------------------- 1 | {% if not obj.display %} 2 | :orphan: 3 | 4 | {% endif %} 5 | :mod:`{{ obj.name }}` 6 | ======={{ "=" * obj.name|length }} 7 | 8 | .. py:module:: {{ obj.name }} 9 | 10 | {% if obj.docstring %} 11 | .. autoapi-nested-parse:: 12 | 13 | {{ obj.docstring|prepare_docstring|indent(3) }} 14 | 15 | {% endif %} 16 | 17 | {% block content %} 18 | {% if obj.all is not none %} 19 | {% set visible_children = obj.children|selectattr("short_name", "in", obj.all)|list %} 20 | {% elif obj.type is equalto("package") %} 21 | {% set visible_children = obj.children|selectattr("display")|list %} 22 | {% else %} 23 | {% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %} 24 | {% endif %} 25 | {% if visible_children %} 26 | {# {{ obj.type|title }} Contents 27 | {{ "-" * obj.type|length }}--------- #} 28 | 29 | {% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %} 30 | {% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %} 31 | {% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %} 32 | 33 | {# {% block classes scoped %} 34 | {% if visible_classes %} 35 | Classes 36 | ------- 37 | 38 | .. autoapisummary:: 39 | 40 | {% for klass in visible_classes %} 41 | {{ klass.id }} 42 | {% endfor %} 43 | 44 | 45 | {% endif %} 46 | {% endblock %} #} 47 | 48 | {# {% block functions scoped %} 49 | {% if visible_functions %} 50 | Functions 51 | --------- 52 | 53 | .. 
autoapisummary:: 54 | 55 | {% for function in visible_functions %} 56 | {{ function.id }} 57 | {% endfor %} 58 | 59 | 60 | {% endif %} 61 | {% endblock %} #} 62 | 63 | {% endif %} 64 | {% for obj_item in visible_children %} 65 | {{ obj_item.rendered|indent(0) }} 66 | {% endfor %} 67 | {% endif %} 68 | {% endblock %} 69 | 70 | {% block subpackages %} 71 | {% set visible_subpackages = obj.subpackages|selectattr("display")|list %} 72 | {% if visible_subpackages %} 73 | {# Subpackages 74 | ----------- #} 75 | .. toctree:: 76 | :titlesonly: 77 | :maxdepth: 1 78 | 79 | {% for subpackage in visible_subpackages %} 80 | {{ subpackage.short_name }}/index.rst 81 | {% endfor %} 82 | 83 | 84 | {% endif %} 85 | {% endblock %} 86 | {# {% block submodules %} 87 | {% set visible_submodules = obj.submodules|selectattr("display")|list %} 88 | {% if visible_submodules %} 89 | Submodules 90 | ---------- 91 | .. toctree:: 92 | :titlesonly: 93 | :maxdepth: 1 94 | 95 | {% for submodule in visible_submodules %} 96 | {{ submodule.short_name }}/index.rst 97 | {% endfor %} 98 | 99 | 100 | {% endif %} 101 | {% endblock %} #} 102 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/_record/index.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | :mod:`LibMTL._record` 4 | ===================== 5 | 6 | .. py:module:: LibMTL._record 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/CGC/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.CGC` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.architecture.CGC 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. 
py:class:: CGC(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.MMoE.MMoE` 14 | 15 | Customized Gate Control (CGC). 16 | 17 | This method is proposed in `Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations (ACM RecSys 2020 Best Paper) `_ \ 18 | and implemented by us. 19 | 20 | :param img_size: The size of input data. For example, [3, 224, 224] for input images with size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The numbers of experts shared for all tasks and specific to each task, respectively. Each expert is the encoder network. 23 | :type num_experts: list 24 | 25 | .. py:method:: forward(self, inputs, task_name=None) 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/Cross_stitch/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.Cross_stitch` 2 | ======================================= 3 | 4 | .. py:module:: LibMTL.architecture.Cross_stitch 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: Cross_stitch(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Cross-stitch Networks (Cross_stitch). 16 | 17 | This method is proposed in `Cross-stitch Networks for Multi-task Learning (CVPR 2016) `_ \ 18 | and implemented by us. 19 | 20 | .. warning:: 21 | - :class:`Cross_stitch` does not work with multiple inputs MTL problem, i.e., ``multi_input`` must be ``False``. 22 | 23 | - :class:`Cross_stitch` is only supported with ResNet-based encoder.
24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/DSelect_k/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.DSelect_k` 2 | ==================================== 3 | 4 | .. py:module:: LibMTL.architecture.DSelect_k 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: DSelect_k(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.MMoE.MMoE` 14 | 15 | DSelect-k. 16 | 17 | This method is proposed in `DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning (NeurIPS 2021) `_ \ 18 | and implemented by modifying from the `official TensorFlow implementation `_. 19 | 20 | :param img_size: The size of input data. For example, [3, 224, 224] for input images with size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The number of experts shared for all tasks. Each expert is the encoder network. 23 | :type num_experts: int 24 | :param num_nonzeros: The number of selected experts. 25 | :type num_nonzeros: int 26 | :param kgamma: A scaling parameter for the smooth-step function. 27 | :type kgamma: float, default=1.0 28 | 29 | .. py:method:: forward(self, inputs, task_name=None) 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/HPS/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.HPS` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.architecture.HPS 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: HPS(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Hard Parameter Sharing (HPS).
16 | 17 | This method is proposed in `Multitask Learning: A Knowledge-Based Source of Inductive Bias (ICML 1993) `_ \ 18 | and implemented by us. 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/MMoE/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.MMoE` 2 | =============================== 3 | 4 | .. py:module:: LibMTL.architecture.MMoE 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: MMoE(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Multi-gate Mixture-of-Experts (MMoE). 16 | 17 | This method is proposed in `Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts (KDD 2018) `_ \ 18 | and implemented by us. 19 | 20 | :param img_size: The size of input data. For example, [3, 224, 224] for input images with size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The number of experts shared for all tasks. Each expert is the encoder network. 23 | :type num_experts: int 24 | 25 | .. py:method:: forward(self, inputs, task_name=None) 26 | 27 | 28 | .. py:method:: get_share_params(self) 29 | 30 | 31 | .. py:method:: zero_grad_share_params(self) 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/MTAN/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.MTAN` 2 | =============================== 3 | 4 | .. py:module:: LibMTL.architecture.MTAN 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: MTAN(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Multi-Task Attention Network (MTAN).
16 | 17 | This method is proposed in `End-To-End Multi-Task Learning With Attention (CVPR 2019) `_ \ 18 | and implemented by modifying from the `official PyTorch implementation `_. 19 | 20 | .. warning:: 21 | :class:`MTAN` is only supported with ResNet-based encoder. 22 | 23 | 24 | .. py:method:: forward(self, inputs, task_name=None) 25 | 26 | 27 | .. py:method:: get_share_params(self) 28 | 29 | 30 | .. py:method:: zero_grad_share_params(self) 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/PLE/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.PLE` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.architecture.PLE 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: PLE(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`LibMTL.architecture.abstract_arch.AbsArchitecture` 14 | 15 | Progressive Layered Extraction (PLE). 16 | 17 | This method is proposed in `Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations (ACM RecSys 2020 Best Paper) `_ \ 18 | and implemented by us. 19 | 20 | :param img_size: The size of input data. For example, [3, 224, 224] for input images with size 3x224x224. 21 | :type img_size: list 22 | :param num_experts: The numbers of experts shared for all tasks and specific to each task, respectively. Each expert is the encoder network. 23 | :type num_experts: list 24 | 25 | .. warning:: 26 | - :class:`PLE` does not work with multiple inputs MTL problem, i.e., ``multi_input`` must be ``False``. 27 | - :class:`PLE` is only supported with ResNet-based encoder. 28 | 29 | 30 | .. py:method:: forward(self, inputs, task_name=None) 31 | 32 | 33 | .. py:method:: get_share_params(self) 34 | 35 | 36 | ..
py:method:: zero_grad_share_params(self) 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/architecture/abstract_arch/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.architecture.abstract_arch` 2 | ======================================== 3 | 4 | .. py:module:: LibMTL.architecture.abstract_arch 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsArchitecture(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs) 12 | 13 | Bases: :py:obj:`torch.nn.Module` 14 | 15 | An abstract class for MTL architectures. 16 | 17 | :param task_name: A list of strings for all tasks. 18 | :type task_name: list 19 | :param encoder_class: A neural network class. 20 | :type encoder_class: class 21 | :param decoders: A dictionary of name-decoder pairs of type (:class:`str`, :class:`torch.nn.Module`). 22 | :type decoders: dict 23 | :param rep_grad: If ``True``, the gradient of the representation for each task can be computed. 24 | :type rep_grad: bool 25 | :param multi_input: Is ``True`` if each task has its own input data, ``False`` otherwise. 26 | :type multi_input: bool 27 | :param device: The device where model and data will be allocated. 28 | :type device: torch.device 29 | :param kwargs: A dictionary of hyperparameters of architecture methods. 30 | :type kwargs: dict 31 | 32 | .. py:method:: forward(self, inputs, task_name=None) 33 | 34 | :param inputs: The input data. 35 | :type inputs: torch.Tensor 36 | :param task_name: The task name corresponding to ``inputs`` if ``multi_input`` is ``True``. 37 | :type task_name: str, default=None 38 | 39 | :returns: A dictionary of name-prediction pairs of type (:class:`str`, :class:`torch.Tensor`). 40 | :rtype: dict 41 | 42 | 43 | .. py:method:: get_share_params(self) 44 | 45 | Return the shared parameters of the model. 46 | 47 | 48 | 49 | .. 
py:method:: zero_grad_share_params(self) 50 | 51 | Set gradients of the shared parameters to zero. 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/config/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.config` 2 | ==================== 3 | 4 | .. py:module:: LibMTL.config 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:data:: LibMTL_args 12 | 13 | 14 | 15 | 16 | .. py:function:: prepare_args(params) 17 | 18 | Return the configuration of hyperparameters, optimizer, and learning rate scheduler. 19 | 20 | :param params: The command-line arguments. 21 | :type params: argparse.Namespace 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/loss/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.loss` 2 | ================== 3 | 4 | .. py:module:: LibMTL.loss 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsLoss 12 | 13 | Bases: :py:obj:`object` 14 | 15 | An abstract class for loss function. 16 | 17 | 18 | .. py:method:: compute_loss(self, pred, gt) 19 | :property: 20 | 21 | Calculate the loss. 22 | 23 | :param pred: The prediction tensor. 24 | :type pred: torch.Tensor 25 | :param gt: The ground-truth tensor. 26 | :type gt: torch.Tensor 27 | 28 | :returns: The loss. 29 | :rtype: torch.Tensor 30 | 31 | 32 | 33 | .. py:class:: CELoss 34 | 35 | Bases: :py:obj:`AbsLoss` 36 | 37 | The cross entropy loss function. 38 | 39 | 40 | .. py:method:: compute_loss(self, pred, gt) 41 | 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/metrics/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.metrics` 2 | ===================== 3 | 4 | ..
py:module:: LibMTL.metrics 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsMetric 12 | 13 | Bases: :py:obj:`object` 14 | 15 | An abstract class for the performance metrics of a task. 16 | 17 | .. attribute:: record 18 | 19 | A list of the metric scores in every iteration. 20 | 21 | :type: list 22 | 23 | .. attribute:: bs 24 | 25 | A list of the number of data in every iteration. 26 | 27 | :type: list 28 | 29 | .. py:method:: update_fun(self, pred, gt) 30 | :property: 31 | 32 | Calculate the metric scores in every iteration and update :attr:`record`. 33 | 34 | :param pred: The prediction tensor. 35 | :type pred: torch.Tensor 36 | :param gt: The ground-truth tensor. 37 | :type gt: torch.Tensor 38 | 39 | 40 | .. py:method:: score_fun(self) 41 | :property: 42 | 43 | Calculate the final score (when an epoch ends). 44 | 45 | :returns: A list of metric scores. 46 | :rtype: list 47 | 48 | 49 | .. py:method:: reinit(self) 50 | 51 | Reset :attr:`record` and :attr:`bs` (when an epoch ends). 52 | 53 | 54 | 55 | 56 | .. py:class:: AccMetric 57 | 58 | Bases: :py:obj:`AbsMetric` 59 | 60 | Calculate the accuracy. 61 | 62 | 63 | .. py:method:: update_fun(self, pred, gt) 64 | 65 | 66 | 67 | 68 | .. py:method:: score_fun(self) 69 | 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/model/resnet_dilated/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.model.resnet_dilated` 2 | ================================== 3 | 4 | .. py:module:: LibMTL.model.resnet_dilated 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: ResnetDilated(orig_resnet, dilate_scale=8) 12 | 13 | Bases: :py:obj:`torch.nn.Module` 14 | 15 | Base class for all neural network modules. 16 | 17 | Your models should also subclass this class. 18 | 19 | Modules can also contain other Modules, allowing to nest them in 20 | a tree structure.
You can assign the submodules as regular attributes:: 21 | 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | 25 | class Model(nn.Module): 26 | def __init__(self): 27 | super(Model, self).__init__() 28 | self.conv1 = nn.Conv2d(1, 20, 5) 29 | self.conv2 = nn.Conv2d(20, 20, 5) 30 | 31 | def forward(self, x): 32 | x = F.relu(self.conv1(x)) 33 | return F.relu(self.conv2(x)) 34 | 35 | Submodules assigned in this way will be registered, and will have their 36 | parameters converted too when you call :meth:`to`, etc. 37 | 38 | .. py:method:: forward(self, x) 39 | 40 | Defines the computation performed at every call. 41 | 42 | Should be overridden by all subclasses. 43 | 44 | .. note:: 45 | Although the recipe for forward pass needs to be defined within 46 | this function, one should call the :class:`Module` instance afterwards 47 | instead of this since the former takes care of running the 48 | registered hooks while the latter silently ignores them. 49 | 50 | 51 | .. py:method:: forward_stage(self, x, stage) 52 | 53 | 54 | 55 | .. py:function:: resnet_dilated(basenet, pretrained=True, dilate_scale=8) 56 | 57 | Dilated Residual Network models from `"Dilated Residual Networks" `_ 58 | 59 | :param basenet: The type of ResNet. 60 | :type basenet: str 61 | :param pretrained: If True, returns a model pre-trained on ImageNet. 62 | :type pretrained: bool 63 | :param dilate_scale: The type of dilating process. 64 | :type dilate_scale: {8, 16}, default=8 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/utils/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.utils` 2 | =================== 3 | 4 | .. py:module:: LibMTL.utils 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:function:: set_random_seed(seed) 12 | 13 | Set the random seed for reproducibility. 14 | 15 | :param seed: The random seed. 
16 | :type seed: int, default=0 17 | 18 | 19 | .. py:function:: set_device(gpu_id) 20 | 21 | Set the device where model and data will be allocated. 22 | 23 | :param gpu_id: The id of gpu. 24 | :type gpu_id: str, default='0' 25 | 26 | 27 | .. py:function:: count_parameters(model) 28 | 29 | Calculates the number of parameters for a model. 30 | 31 | :param model: A neural network module. 32 | :type model: torch.nn.Module 33 | 34 | 35 | .. py:function:: count_improvement(base_result, new_result, weight) 36 | 37 | Calculate the improvement between two results, 38 | 39 | .. math:: 40 | \Delta_{\mathrm{p}}=100\%\times \frac{1}{T}\sum_{t=1}^T 41 | \frac{1}{M_t}\sum_{m=1}^{M_t}\frac{(-1)^{w_{t,m}}(B_{t,m}-N_{t,m})}{N_{t,m}}. 42 | 43 | :param base_result: A dictionary of scores of all metrics of all tasks. 44 | :type base_result: dict 45 | :param new_result: The same structure with ``base_result``. 46 | :type new_result: dict 47 | :param weight: The same structure with ``base_result`` while each elements is binary integer representing whether higher or lower score is better. 48 | :type weight: dict 49 | 50 | :returns: The improvement between ``new_result`` and ``base_result``. 51 | :rtype: float 52 | 53 | Examples:: 54 | 55 | base_result = {'A': [96, 98], 'B': [0.2]} 56 | new_result = {'A': [93, 99], 'B': [0.5]} 57 | weight = {'A': [1, 0], 'B': [1]} 58 | 59 | print(count_improvement(base_result, new_result, weight)) 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/CAGrad/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.CAGrad` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.weighting.CAGrad 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: CAGrad 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Conflict-Averse Gradient descent (CAGrad). 
16 | 17 | This method is proposed in `Conflict-Averse Gradient Descent for Multi-task learning (NeurIPS 2021) `_ \ 18 | and implemented by modifying from the `official PyTorch implementation `_. 19 | 20 | :param calpha: A hyperparameter that controls the convergence rate. 21 | :type calpha: float, default=0.5 22 | :param rescale: The type of gradient rescale. 23 | :type rescale: {0, 1, 2}, default=1 24 | 25 | .. warning:: 26 | CAGrad is not supported with representation gradients, i.e., ``rep_grad`` must be ``False``. 27 | 28 | 29 | .. py:method:: backward(self, losses, **kwargs) 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/DWA/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.DWA` 2 | =========================== 3 | 4 | .. py:module:: LibMTL.weighting.DWA 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: DWA 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Dynamic Weight Average (DWA). 16 | 17 | This method is proposed in `End-To-End Multi-Task Learning With Attention (CVPR 2019) `_ \ 18 | and implemented by modifying from the `official PyTorch implementation `_. 19 | 20 | :param T: The softmax temperature. 21 | :type T: float, default=2.0 22 | 23 | .. py:method:: backward(self, losses, **kwargs) 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/EW/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.EW` 2 | ========================== 3 | 4 | .. py:module:: LibMTL.weighting.EW 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: EW 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Equally Weighting (EW). 
16 | 17 | The loss weight for each task is always ``1 / T`` in every iteration, where ``T`` means the number of tasks. 18 | 19 | 20 | .. py:method:: backward(self, losses, **kwargs) 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/GLS/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GLS` 2 | =========================== 3 | 4 | .. py:module:: LibMTL.weighting.GLS 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GLS 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Geometric Loss Strategy (GLS). 16 | 17 | This method is proposed in `MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning (CVPR 2019 workshop) `_ \ 18 | and implemented by us. 19 | 20 | 21 | .. py:method:: backward(self, losses, **kwargs) 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/GradDrop/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GradDrop` 2 | ================================ 3 | 4 | .. py:module:: LibMTL.weighting.GradDrop 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GradDrop 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Gradient Sign Dropout (GradDrop). 16 | 17 | This method is proposed in `Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout (NeurIPS 2020) `_ \ 18 | and implemented by us. 19 | 20 | :param leak: The leak parameter for the weighting matrix. 21 | :type leak: float, default=0.0 22 | 23 | .. warning:: 24 | GradDrop is not supported with parameter gradients, i.e., ``rep_grad`` must be ``True``. 25 | 26 | 27 | .. 
py:method:: backward(self, losses, **kwargs) 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/GradNorm/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GradNorm` 2 | ================================ 3 | 4 | .. py:module:: LibMTL.weighting.GradNorm 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GradNorm 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Gradient Normalization (GradNorm). 16 | 17 | This method is proposed in `GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks (ICML 2018) `_ \ 18 | and implemented by us. 19 | 20 | :param alpha: The strength of the restoring force which pulls tasks back to a common training rate. 21 | :type alpha: float, default=1.5 22 | 23 | .. py:method:: init_param(self) 24 | 25 | 26 | .. py:method:: backward(self, losses, **kwargs) 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/GradVac/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.GradVac` 2 | =============================== 3 | 4 | .. py:module:: LibMTL.weighting.GradVac 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: GradVac 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Gradient Vaccine (GradVac). 16 | 17 | This method is proposed in `Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) `_ \ 18 | and implemented by us. 19 | 20 | :param beta: The exponential moving average (EMA) decay parameter. 21 | :type beta: float, default=0.5 22 | 23 | .. warning:: 24 | GradVac is not supported with representation gradients, i.e., ``rep_grad`` must be ``False``. 25 | 26 | 27 | .. 
py:method:: backward(self, losses, **kwargs) 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/IMTL/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.IMTL` 2 | ============================ 3 | 4 | .. py:module:: LibMTL.weighting.IMTL 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: IMTL 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Impartial Multi-task Learning (IMTL). 16 | 17 | This method is proposed in `Towards Impartial Multi-task Learning (ICLR 2021) `_ \ 18 | and implemented by us. 19 | 20 | 21 | .. py:method:: init_param(self) 22 | 23 | 24 | .. py:method:: backward(self, losses, **kwargs) 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/MGDA/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.MGDA` 2 | ============================ 3 | 4 | .. py:module:: LibMTL.weighting.MGDA 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: MGDA 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Multiple Gradient Descent Algorithm (MGDA). 16 | 17 | This method is proposed in `Multi-Task Learning as Multi-Objective Optimization (NeurIPS 2018) `_ \ 18 | and implemented by modifying from the `official PyTorch implementation `_. 19 | 20 | :param mgda_gn: The type of gradient normalization. 21 | :type mgda_gn: {'none', 'l2', 'loss', 'loss+'}, default='none' 22 | 23 | .. 
py:method:: backward(self, losses, **kwargs) 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/PCGrad/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.PCGrad` 2 | ============================== 3 | 4 | .. py:module:: LibMTL.weighting.PCGrad 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: PCGrad 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Project Conflicting Gradients (PCGrad). 16 | 17 | This method is proposed in `Gradient Surgery for Multi-Task Learning (NeurIPS 2020) `_ \ 18 | and implemented by us. 19 | 20 | .. warning:: 21 | PCGrad is not supported with representation gradients, i.e., ``rep_grad`` must be ``False``. 22 | 23 | 24 | .. py:method:: backward(self, losses, **kwargs) 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/RLW/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.RLW` 2 | =========================== 3 | 4 | .. py:module:: LibMTL.weighting.RLW 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: RLW 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Random Loss Weighting (RLW). 16 | 17 | This method is proposed in `A Closer Look at Loss Weighting in Multi-Task Learning (arXiv:2111.10603) `_ \ 18 | and implemented by us. 19 | 20 | :param dist: The type of distribution where the loss weights are sampled from. 21 | :type dist: {'Uniform', 'Normal', 'Dirichlet', 'Bernoulli', 'constrained_Bernoulli'}, default='Normal' 22 | 23 | .. py:method:: backward(self, losses, **kwargs) 24 | 25 | :param losses: A list of loss of each task. 26 | :type losses: list 27 | :param kwargs: A dictionary of hyperparameters of weighting methods.
28 | :type kwargs: dict 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/UW/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.UW` 2 | ========================== 3 | 4 | .. py:module:: LibMTL.weighting.UW 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: UW 12 | 13 | Bases: :py:obj:`LibMTL.weighting.abstract_weighting.AbsWeighting` 14 | 15 | Uncertainty Weights (UW). 16 | 17 | This method is proposed in `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (CVPR 2018) `_ \ 18 | and implemented by us. 19 | 20 | 21 | .. py:method:: init_param(self) 22 | 23 | 24 | .. py:method:: backward(self, losses, **kwargs) 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/docs/_autoapi/LibMTL/weighting/abstract_weighting/index.rst: -------------------------------------------------------------------------------- 1 | :mod:`LibMTL.weighting.abstract_weighting` 2 | ========================================== 3 | 4 | .. py:module:: LibMTL.weighting.abstract_weighting 5 | 6 | 7 | 8 | 9 | 10 | 11 | .. py:class:: AbsWeighting 12 | 13 | Bases: :py:obj:`torch.nn.Module` 14 | 15 | An abstract class for weighting strategies. 16 | 17 | 18 | .. py:method:: init_param(self) 19 | 20 | Define and initialize some trainable parameters required by specific weighting methods. 21 | 22 | 23 | 24 | .. py:method:: backward(self, losses, **kwargs) 25 | :property: 26 | 27 | :param losses: A list of loss of each task. 28 | :type losses: list 29 | :param kwargs: A dictionary of hyperparameters of weighting methods. 
30 | :type kwargs: dict 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /docs/docs/develop/arch.md: -------------------------------------------------------------------------------- 1 | ## Customize an Architecture 2 | 3 | Here we introduce how to customize a new architecture with the support of ``LibMTL``. 4 | 5 | ### Create a New Architecture Class 6 | 7 | ```eval_rst 8 | Firstly, you need to create a new architecture class by inheriting class :class:`LibMTL.architecture.AbsArchitecture`. 9 | ``` 10 | 11 | ```python 12 | from LibMTL.architecture import AbsArchitecture 13 | 14 | class NewArchitecture(AbsArchitecture): 15 | def __init__(self, task_name, encoder_class, decoders, rep_grad, 16 | multi_input, device, **kwargs): 17 | super(NewArchitecture, self).__init__(task_name, encoder_class, decoders, rep_grad, 18 | multi_input, device, **kwargs) 19 | ``` 20 | 21 | ### Rewrite Relevant Methods 22 | 23 | ```eval_rst 24 | There are four important functions in :class:`LibMTL.architecture.AbsArchitecture`. 25 | 26 | - :func:`forward`: The forward function and its input/output format can be found in :func:`LibMTL.architecture.AbsArchitecture.forward`. To rewrite this function, you need to consider the case of ``single-input`` and ``multi-input`` (refer to `here <../user_guide/mtl.html#network-architecture>`_) and the case of ``rep-grad`` and ``param-grad`` (refer to `here <../user_guide/mtl.html#weighting-strategy>`_) if you want to combine your architecture with more weighting strategies or apply your architecture to more datasets. 27 | - :func:`get_share_params`: This function is used to return the shared parameters of the model. It returns all the parameters of the encoder by default. You can rewrite it if necessary. 28 | - :func:`zero_grad_share_params`: This function is used to set gradients of the shared parameters to zero. It will set the gradients of all the encoder parameters to zero by default. 
You can rewrite it if necessary. 29 | - :func:`_prepare_rep`: This function is used to compute the gradients for representations. More details can be found `here <../../_modules/LibMTL/architecture/abstract_arch.html#AbsArchitecture>`_. 30 | ``` 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/docs/develop/weighting.md: -------------------------------------------------------------------------------- 1 | ## Customize a Weighting Strategy 2 | 3 | Here we introduce how to customize a new weighting strategy with the support of ``LibMTL``. 4 | 5 | ### Create a New Weighting Class 6 | 7 | ```eval_rst 8 | Firstly, you need to create a new weighting class by inheriting class :class:`LibMTL.weighting.AbsWeighting`. 9 | ``` 10 | 11 | ```python 12 | from LibMTL.weighting import AbsWeighting 13 | 14 | class NewWeighting(AbsWeighting): 15 | def __init__(self): 16 | super(NewWeighting, self).__init__() 17 | ``` 18 | 19 | ### Rewrite Relevant Methods 20 | 21 | ```eval_rst 22 | There are four important functions in :class:`LibMTL.weighting.AbsWeighting`. 23 | 24 | - :func:`backward`: It is the main function of a weighting strategy whose input and output formats can be found in :func:`LibMTL.weighting.AbsWeighting.backward`. To rewrite this function, you need to consider the case of ``single-input`` and ``multi-input`` (refer to `here <../user_guide/mtl.html#network-architecture>`_) and the case of ``rep-grad`` and ``param-grad`` (refer to `here <../user_guide/mtl.html#weighting-strategy>`_) if you want to combine your weighting method with more architectures or apply your method to more datasets. 25 | - :func:`init_param`: This function is used to define and initialize some trainable parameters. It does nothing by default and can be rewritten if necessary. 
26 | - :func:`_get_grads`: This function is used to return the gradients of representations or shared parameters (corresponding to the case of ``rep-grad`` and ``param-grad``, respectively). 27 | - :func:`_backward_new_grads`: This function is used to reset the gradients and make a backward pass (corresponding to the case of ``rep-grad`` and ``param-grad``, respectively). 28 | 29 | The :func:`_get_grads` and :func:`_backward_new_grads` functions are very useful to rewrite the :func:`backward` function and you can find more details `here <../../_modules/LibMTL/weighting/abstract_weighting.html#AbsWeighting>`_. 30 | ``` 31 | 32 | -------------------------------------------------------------------------------- /docs/docs/getting_started/installation.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | ### Dependencies 4 | 5 | To install ``LibMTL``, we recommend to use the following libraries: 6 | 7 | - Python == 3.8 8 | - torch == 1.8.1+cu111 9 | - torchvision == 0.9.1+cu111 10 | 11 | ### User Installation 12 | 13 | * Create a virtual environment 14 | 15 | ```shell 16 | conda create -n libmtl python=3.8 17 | conda activate libmtl 18 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html 19 | ``` 20 | 21 | * Clone the repository 22 | 23 | ```shell 24 | git clone https://github.com/median-research-group/LibMTL.git 25 | ``` 26 | 27 | * Install `LibMTL` 28 | 29 | ```shell 30 | cd LibMTL 31 | pip install -r requirements.txt 32 | pip install -e . 
33 | ``` 34 | -------------------------------------------------------------------------------- /docs/docs/images/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/docs/images/framework.png -------------------------------------------------------------------------------- /docs/docs/images/multi_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/docs/images/multi_input.png -------------------------------------------------------------------------------- /docs/docs/images/rep_grad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/docs/docs/images/rep_grad.png -------------------------------------------------------------------------------- /docs/docs/references.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{silberman2012indoor, 2 | title={Indoor segmentation and support inference from rgbd images}, 3 | author={Silberman, Nathan and Hoiem, Derek and Kohli, Pushmeet and Fergus, Rob}, 4 | booktitle={Proceedings of the 8th European Conference on Computer Vision}, 5 | pages={746--760}, 6 | year={2012}, 7 | } 8 | 9 | @inproceedings{ljd19, 10 | author = {Shikun Liu and 11 | Edward Johns and 12 | Andrew J. Davison}, 13 | title = {End-To-End Multi-Task Learning With Attention}, 14 | booktitle = {Proceedings of {IEEE} Conference on Computer Vision and Pattern Recognition}, 15 | pages = {1871--1880}, 16 | year = {2019} 17 | } 18 | 19 | @inproceedings{YuKF17, 20 | author = {Fisher Yu and 21 | Vladlen Koltun and 22 | Thomas A. 
Funkhouser}, 23 | title = {Dilated Residual Networks}, 24 | booktitle = {Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition}, 25 | pages = {636--644}, 26 | year = {2017}, 27 | } 28 | 29 | @inproceedings{ChenZPSA18, 30 | author = {Liang{-}Chieh Chen and 31 | Yukun Zhu and 32 | George Papandreou and 33 | Florian Schroff and 34 | Hartwig Adam}, 35 | title = {Encoder-Decoder with Atrous Separable Convolution for Semantic Image 36 | Segmentation}, 37 | booktitle = {Proceedings of the 14th European Conference on Computer Vision}, 38 | volume = {11211}, 39 | pages = {833--851}, 40 | year = {2018}, 41 | } 42 | 43 | @article{lin2021rlw, 44 | title={A Closer Look at Loss Weighting in Multi-Task Learning}, 45 | author={Lin, Baijiong and Ye, Feiyang and Zhang, Yu}, 46 | journal={arXiv preprint arXiv:2111.10603}, 47 | year={2021} 48 | } 49 | 50 | @inproceedings{saenko2010adapting, 51 | title={Adapting visual category models to new domains}, 52 | author={Saenko, Kate and Kulis, Brian and Fritz, Mario and Darrell, Trevor}, 53 | booktitle={Proceedings of the 6th European Conference on Computer Vision}, 54 | pages={213--226}, 55 | year={2010}, 56 | } 57 | 58 | @inproceedings{venkateswara2017deep, 59 | title={Deep hashing network for unsupervised domain adaptation}, 60 | author={Venkateswara, Hemanth and Eusebio, Jose and Chakraborty, Shayok and Panchanathan, Sethuraman}, 61 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 62 | pages={5018--5027}, 63 | year={2017} 64 | } 65 | 66 | @article{ZhangY21, 67 | author = {Yu Zhang and 68 | Qiang Yang}, 69 | title = {A Survey on Multi-Task Learning}, 70 | journal = {{IEEE} Transactions on Knowledge and Data Engineering}, 71 | year = {2021}, 72 | } 73 | 74 | @article{Vandenhende21, 75 | author = {Simon Vandenhende and Stamatios Georgoulis and Wouter Van Gansbeke and Marc Proesmans and Dengxin Dai and Luc Van Gool}, 76 | title = {Multi-Task Learning for Dense Prediction 
Tasks: A Survey}, 77 | journal = {{IEEE} Transactions on Pattern Analysis and Machine Intelligence }, 78 | year = {2021}, 79 | } 80 | 81 | @article{Michael20, 82 | title={Multi-Task Learning with Deep Neural Networks: A Survey}, 83 | author={Michael Crawshaw}, 84 | journal={arXiv preprint arXiv:2009.09796}, 85 | year={2020} 86 | } -------------------------------------------------------------------------------- /docs/docs/user_guide/benchmark.rst: -------------------------------------------------------------------------------- 1 | Run a Benchmark 2 | =============== 3 | 4 | Here we introduce some MTL benchmark datasets and show how to run models on them for a fair comparison. 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | benchmark/nyuv2 10 | benchmark/office 11 | -------------------------------------------------------------------------------- /docs/docs/user_guide/benchmark/office.md: -------------------------------------------------------------------------------- 1 | ## Office-31 and Office-Home 2 | 3 | ```eval_rst 4 | The Office-31 dataset :cite:`saenko2010adapting` consists of three classification tasks on three domains: Amazon, DSLR, and Webcam, where each task has 31 object categories. It can be downloaded `here <https://www.cc.gatech.edu/~judy/domainadapt/#datasets_code>`_. This dataset contains 4,110 labeled images and we randomly split these samples, with 60\% for training, 20\% for validation, and the rest 20\% for testing. 5 | 6 | The Office-Home dataset :cite:`venkateswara2017deep` has four classification tasks on four domains: Artistic images (abbreviated as Art), Clip art, Product images, and Real-world images. It can be downloaded `here <https://www.hemanthdv.org/officeHomeDataset.html>`_. This dataset has 15,500 labeled images in total and each domain contains 65 classes. We divide the entire data into the same proportion as the Office-31 dataset. 7 | 8 | Both datasets belong to the multi-input setting in MTL. Thus, the ``multi_input`` must be ``True`` for both of the two office datasets. 9 | 10 | The training codes are available in ``examples/office``. 
We use the ResNet-18 network pretrained on the ImageNet dataset followed by a fully connected layer as a shared encoder among tasks and a fully connected layer is applied as a task-specific output layer for each task. All the input images are resized to 3x224x224. 11 | ``` 12 | 13 | ### Run a Model 14 | 15 | The script ``train_office.py`` is the main file for training and evaluating a MTL model on the Office-31 or Office-Home dataset. A set of command-line arguments is provided to allow users to adjust the training parameter configuration. 16 | 17 | Some important arguments are described as follows. 18 | 19 | ```eval_rst 20 | - ``weighting``: The weighting strategy. Refer to `here <../_autoapi/LibMTL/weighting/index.html>`_. 21 | - ``arch``: The MTL architecture. Refer to `here <../_autoapi/LibMTL/architecture/index.html>`_. 22 | - ``gpu_id``: The id of gpu. The default value is '0'. 23 | - ``seed``: The random seed for reproducibility. The default value is 0. 24 | - ``optim``: The type of the optimizer. We recommend to use 'adam' here. 25 | - ``dataset``: Training on Office-31 or Office-Home. Options: 'office-31', 'office-home'. 26 | - ``dataset_path``: The path of the Office-31 or Office-Home dataset. 27 | - ``bs``: The batch size of training, validation, and test data. The default value is 64. 28 | ``` 29 | 30 | The complete command-line arguments and their descriptions can be found by running the following command. 31 | 32 | ```shell 33 | python main.py -h 34 | ``` 35 | 36 | If you understand those command-line arguments, you can train a MTL model by running a command like this. 37 | 38 | ```shell 39 | python main.py --weighting WEIGHTING --arch ARCH --dataset_path PATH --gpu_id GPU_ID --multi_input 40 | ``` 41 | 42 | ### References 43 | 44 | ```eval_rst 45 | .. 
bibliography:: 46 | :style: unsrt 47 | :filter: docname in docnames 48 | ``` 49 | -------------------------------------------------------------------------------- /docs/docs/user_guide/framework.md: -------------------------------------------------------------------------------- 1 | ## Overall Framework 2 | 3 | ``LibMTL`` provides a unified framework to train a MTL model with several architectures and weighting strategies on benchmark datasets. The overall framework consists of nine modules as introduced below. 4 | 5 | ```eval_rst 6 | - The Dataloader module is responsible for data pre-processing and loading. 7 | 8 | - The `LibMTL.loss <../_autoapi/LibMTL/loss/index.html>`_ module defines loss functions for each task. 9 | 10 | - The `LibMTL.metrics <../_autoapi/LibMTL/metrics/index.html>`_ module defines evaluation metrics for all the tasks. 11 | 12 | - The `LibMTL.config <../_autoapi/LibMTL/config/index.html>`_ module is responsible for all the configuration parameters involved in the training process, such as the corresponding MTL setting (i.e. the multi-input case or not), the potential hyper-parameters of loss weighting strategies and architectures, the training configuration (e.g., the batch size, the running epoch, the random seed, and the learning rate), and so on. This module adopts command-line arguments to enable users to conveniently set those configuration parameters. 13 | 14 | - The `LibMTL.Trainer <../_autoapi/LibMTL/trainer/index.html>`_ module provides a unified framework for the training process under different MTL settings and for different MTL approaches 15 | 16 | - The `LibMTL.utils <../_autoapi/LibMTL/utils/index.html>`_ module implements some useful functionalities for the training process such as calculating the total number of parameters in an MTL model. 17 | 18 | - The `LibMTL.architecture <../_autoapi/LibMTL/architecture/index.html>`_ module contains the implementations of various architectures in MTL. 
19 | 20 | - The `LibMTL.weighting <../_autoapi/LibMTL/weighting/index.html>`_ module contains the implementations of various loss weighting strategies in MTL. 21 | 22 | - The `LibMTL.model <../_autoapi/LibMTL/model/index.html>`_ module includes some popular backbone networks (e.g., ResNet). 23 | ``` 24 | 25 | ```eval_rst 26 | .. figure:: ../images/framework.png 27 | :scale: 50% 28 | ``` 29 | 30 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. LibMTL documentation master file, created by 2 | sphinx-quickstart on Thu Nov 25 17:02:04 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | LibMTL: A PyTorch Library for Multi-Task Learning 7 | ================================================= 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :caption: Getting Started: 12 | 13 | docs/getting_started/introduction 14 | docs/getting_started/installation 15 | docs/getting_started/quick_start 16 | 17 | .. toctree:: 18 | :maxdepth: 1 19 | :caption: User Guide: 20 | 21 | docs/user_guide/mtl 22 | docs/user_guide/framework 23 | docs/user_guide/benchmark 24 | 25 | .. toctree:: 26 | :maxdepth: 1 27 | :caption: Developer Guide: 28 | 29 | docs/develop/dataset 30 | docs/develop/arch 31 | docs/develop/weighting 32 | 33 | .. 
toctree:: 34 | :maxdepth: 1 35 | :caption: API Reference: 36 | 37 | docs/_autoapi/LibMTL/index 38 | docs/_autoapi/LibMTL/loss/index 39 | docs/_autoapi/LibMTL/utils/index 40 | docs/_autoapi/LibMTL/model/index 41 | docs/_autoapi/LibMTL/config/index 42 | docs/_autoapi/LibMTL/metrics/index 43 | docs/_autoapi/LibMTL/weighting/index 44 | docs/_autoapi/LibMTL/architecture/index 45 | 46 | 47 | 48 | Indices and tables 49 | ================== 50 | 51 | * :ref:`genindex` 52 | * :ref:`modindex` 53 | * :ref:`search` 54 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | Sphinx==7.1.2 2 | recommonmark==0.7.1 3 | sphinx-autoapi==3.1.1 4 | sphinx-autobuild==2021.3.14 5 | sphinx-markdown-tables==0.0.17 6 | sphinx-rtd-theme==2.0.0 7 | sphinxcontrib-applehelp==1.0.2 8 | sphinxcontrib-bibtex==2.4.1 9 | sphinxcontrib-devhelp==1.0.2 10 | sphinxcontrib-htmlhelp==2.0.0 11 | sphinxcontrib-jsmath==1.0.1 12 | sphinxcontrib-qthelp==1.0.3 13 | sphinxcontrib-serializinghtml==1.1.5 14 | sphinxcontrib-websupport==1.2.3 15 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | ## Run a Benchmark 2 | 3 | Here we introduce some MTL benchmark datasets and show how to run models on them for a fair comparison. 
4 | 5 | - [The NYUv2 Dataset](https://github.com/median-research-group/LibMTL/tree/main/examples/nyu) 6 | - [The Cityscapes Dataset](https://github.com/median-research-group/LibMTL/tree/main/examples/cityscapes) 7 | - [The Office-31 and Office-Home Datasets](https://github.com/median-research-group/LibMTL/tree/main/examples/office) 8 | - [The QM9 Dataset](https://github.com/median-research-group/LibMTL/tree/main/examples/qm9) 9 | - [The PAWS-X Dataset from XTREME Benchmark](https://github.com/median-research-group/LibMTL/tree/main/examples/xtreme) 10 | 11 | 12 | -------------------------------------------------------------------------------- /examples/cityscapes/create_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import Dataset 2 | 3 | import os 4 | import torch 5 | import torch.nn.functional as F 6 | import fnmatch 7 | import numpy as np 8 | import random 9 | import json 10 | 11 | 12 | class CityScapes(Dataset): 13 | 14 | def __init__(self, root, mode='train'): 15 | self.mode = mode 16 | self.root = os.path.expanduser(root) 17 | 18 | # read the data file 19 | if self.mode == 'train': 20 | data_len = len(fnmatch.filter(os.listdir(self.root + '/train/image'), '*.npy')) 21 | self.index_list = list(range(data_len)) 22 | self.data_path = self.root + '/train' 23 | elif self.mode == 'test': 24 | data_len = len(fnmatch.filter(os.listdir(self.root + '/val/image'), '*.npy')) 25 | self.index_list = list(range(data_len)) 26 | self.data_path = self.root + '/val' 27 | 28 | def __getitem__(self, i): 29 | index = self.index_list[i] 30 | # load data from the pre-processed npy files 31 | image = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/image/{:d}.npy'.format(index)), -1, 0)) 32 | semantic = torch.from_numpy(np.load(self.data_path + '/label/{:d}.npy'.format(index))) 33 | depth = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/depth/{:d}.npy'.format(index)), -1, 0)) 34 | 35 | 
return image.float(), {'segmentation': semantic.float(), 'depth': depth.float()} 36 | 37 | def __len__(self): 38 | return len(self.index_list) 39 | -------------------------------------------------------------------------------- /examples/nyu/aspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DeepLabHead(nn.Sequential): 7 | def __init__(self, in_channels, num_classes): 8 | super(DeepLabHead, self).__init__( 9 | ASPP(in_channels, [12, 24, 36]), 10 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 11 | nn.BatchNorm2d(256), 12 | nn.ReLU(), 13 | nn.Conv2d(256, num_classes, 1) 14 | ) 15 | 16 | 17 | class ASPPConv(nn.Sequential): 18 | def __init__(self, in_channels, out_channels, dilation): 19 | modules = [ 20 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 21 | nn.BatchNorm2d(out_channels), 22 | nn.ReLU() 23 | ] 24 | super(ASPPConv, self).__init__(*modules) 25 | 26 | 27 | class ASPPPooling(nn.Sequential): 28 | def __init__(self, in_channels, out_channels): 29 | super(ASPPPooling, self).__init__( 30 | nn.AdaptiveAvgPool2d(1), 31 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 32 | nn.BatchNorm2d(out_channels), 33 | nn.ReLU()) 34 | 35 | def forward(self, x): 36 | size = x.shape[-2:] 37 | x = super(ASPPPooling, self).forward(x) 38 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 39 | 40 | 41 | class ASPP(nn.Module): 42 | def __init__(self, in_channels, atrous_rates): 43 | super(ASPP, self).__init__() 44 | out_channels = 256 45 | modules = [] 46 | modules.append(nn.Sequential( 47 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 48 | nn.BatchNorm2d(out_channels), 49 | nn.ReLU())) 50 | 51 | rate1, rate2, rate3 = tuple(atrous_rates) 52 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 53 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 54 | 
modules.append(ASPPConv(in_channels, out_channels, rate3)) 55 | modules.append(ASPPPooling(in_channels, out_channels)) 56 | 57 | self.convs = nn.ModuleList(modules) 58 | 59 | self.project = nn.Sequential( 60 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 61 | nn.BatchNorm2d(out_channels), 62 | nn.ReLU(), 63 | nn.Dropout(0.5)) 64 | 65 | def forward(self, x): 66 | res = [] 67 | for conv in self.convs: 68 | res.append(conv(x)) 69 | res = torch.cat(res, dim=1) 70 | return self.project(res) 71 | -------------------------------------------------------------------------------- /examples/nyu/results/resnet_results.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/examples/nyu/results/resnet_results.pdf -------------------------------------------------------------------------------- /examples/nyu/results/segnet_results.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/examples/nyu/results/segnet_results.pdf -------------------------------------------------------------------------------- /examples/office/README.md: -------------------------------------------------------------------------------- 1 | ## Office-31 and Office-Home 2 | 3 | The Office-31 dataset [[1]](#1) consists of three classification tasks on three domains: Amazon, DSLR, and Webcam, where each task has 31 object categories. It can be download [here](https://www.cc.gatech.edu/~judy/domainadapt/#datasets_code). This dataset contains 4,110 labeled images and we randomly split these samples, with 60% for training, 20% for validation, and the rest 20% for testing. 
4 | 5 | The Office-Home dataset [[2]](#2) has four classification tasks on four domains: Artistic images (abbreviated as Art), Clip art, Product images, and Real-world images. It can be downloaded [here](https://www.hemanthdv.org/officeHomeDataset.html). This dataset has 15,500 labeled images in total and each domain contains 65 classes. We divide the entire data into the same proportion as the Office-31 dataset. 6 | 7 | Both datasets belong to the multi-input setting in MTL. Thus, the ``multi_input`` must be ``True`` for both of the two office datasets. 8 | 9 | We use the ResNet-18 network pretrained on the ImageNet dataset followed by a fully connected layer as a shared encoder among tasks and a fully connected layer is applied as a task-specific output layer for each task. All the input images are resized to 3x224x224. 10 | 11 | ### Run a Model 12 | 13 | The script ``main.py`` is the main file for training and evaluating an MTL model on the Office-31 or Office-Home dataset. A set of command-line arguments is provided to allow users to adjust the training parameter configuration. 14 | 15 | Some important arguments are described as follows. 16 | 17 | - ``weighting``: The weighting strategy. Refer to [here](../../LibMTL#supported-algorithms). 18 | - ``arch``: The MTL architecture. Refer to [here](../../LibMTL#supported-algorithms). 19 | - ``gpu_id``: The id of gpu. The default value is '0'. 20 | - ``seed``: The random seed for reproducibility. The default value is 0. 21 | - ``optim``: The type of the optimizer. We recommend using 'adam' here. 22 | - ``dataset``: Training on Office-31 or Office-Home. Options: 'office-31', 'office-home'. 23 | - ``dataset_path``: The path of the Office-31 or Office-Home dataset. 24 | - ``bs``: The batch size of training, validation, and test data. The default value is 64. 25 | 26 | The complete command-line arguments and their descriptions can be found by running the following command. 
27 | 28 | ```shell 29 | python main.py -h 30 | ``` 31 | 32 | If you understand those command-line arguments, you can train an MTL model by running a command like this. 33 | 34 | ```shell 35 | python main.py --weighting WEIGHTING --arch ARCH --dataset DATASET --dataset_path PATH --gpu_id GPU_ID --multi_input --save_path PATH --mode train 36 | ``` 37 | 38 | You can test the trained model by running the following command. 39 | 40 | ```shell 41 | python main.py --weighting WEIGHTING --arch ARCH --dataset DATASET --dataset_path PATH --gpu_id GPU_ID --multi_input --load_path PATH --mode test 42 | ``` 43 | 44 | ### References 45 | 46 | [1] Kate Saenko, Brian Kulis, Mario Fritz, and Trevor Darrell. Adapting Visual Category Models to New Domains. In *European Conference on Computer Vision*, 2010. 47 | 48 | [2] Hemanth Venkateswara, Jose Eusebio, Shayok Chakraborty, and Sethuraman Panchanathan. Deep Hashing Network for Unsupervised Domain Adaptation. In *IEEE Conference on Computer Vision and Pattern Recognition*, 2017. 
49 | -------------------------------------------------------------------------------- /examples/office/create_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, Dataset 2 | import os 3 | import torch 4 | import torch.nn.functional as F 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | from LibMTL.utils import get_root_dir 8 | 9 | class office_Dataset(Dataset): 10 | def __init__(self, dataset, root_path, task, mode): 11 | self.transform = transforms.Compose([ 12 | transforms.Resize((224, 224)), 13 | transforms.ToTensor(), 14 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]), 15 | ]) 16 | f = open(os.path.join(get_root_dir(), 'examples/office', 'data_txt/{}/{}_{}.txt'.format(dataset, task, mode)), 'r') 17 | self.img_list = f.readlines() 18 | f.close() 19 | self.root_path = root_path 20 | 21 | def __getitem__(self, i): 22 | img_path = self.img_list[i][:-1].split(' ')[0] 23 | y = int(self.img_list[i][:-1].split(' ')[1]) 24 | img = Image.open(os.path.join(self.root_path, img_path)).convert('RGB') 25 | return self.transform(img), y 26 | 27 | def __len__(self): 28 | return len(self.img_list) 29 | 30 | def office_dataloader(dataset, batchsize, root_path): 31 | if dataset == 'office-31': 32 | tasks = ['amazon', 'dslr', 'webcam'] 33 | elif dataset == 'office-home': 34 | tasks = ['Art', 'Clipart', 'Product', 'Real_World'] 35 | data_loader = {} 36 | iter_data_loader = {} 37 | for k, d in enumerate(tasks): 38 | data_loader[d] = {} 39 | iter_data_loader[d] = {} 40 | for mode in ['train', 'val', 'test']: 41 | shuffle = True if mode == 'train' else False 42 | drop_last = True if mode == 'train' else False 43 | txt_dataset = office_Dataset(dataset, root_path, d, mode) 44 | # print(d, mode, len(txt_dataset)) 45 | data_loader[d][mode] = DataLoader(txt_dataset, 46 | num_workers=2, 47 | pin_memory=True, 48 | batch_size=batchsize, 49 | 
shuffle=shuffle, 50 | drop_last=drop_last) 51 | iter_data_loader[d][mode] = iter(data_loader[d][mode]) 52 | return data_loader, iter_data_loader 53 | -------------------------------------------------------------------------------- /examples/qm9/README.md: -------------------------------------------------------------------------------- 1 | ## QM9 2 | 3 | The QM9 dataset [[1]](#1) consists of about 130K molecules with 19 regression targets. The training codes are mainly followed [[2]](#2) and modified from [pytorch_geometric](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/qm9_nn_conv.py). 4 | 5 | ### Run a Model 6 | 7 | The script ``main.py`` is the main file for training and evaluating a MTL model on the QM9 dataset. A set of command-line arguments is provided to allow users to adjust the training parameter configuration. 8 | 9 | Some important arguments are described as follows. 10 | 11 | - ``weighting``: The weighting strategy. Refer to [here](../../LibMTL#supported-algorithms). 12 | - ``arch``: The MTL architecture. Refer to [here](../../LibMTL#supported-algorithms). 13 | - ``gpu_id``: The id of gpu. The default value is '0'. 14 | - ``seed``: The random seed for reproducibility. The default value is 0. 15 | - ``optim``: The type of the optimizer. We recommend to use 'adam' here. 16 | - ``target``: The index of target tasks. 17 | - ``dataset_path``: The path of the QM9 dataset. 18 | - ``bs``: The batch size of training, validation, and test data. The default value is 128. 19 | 20 | The complete command-line arguments and their descriptions can be found by running the following command. 21 | 22 | ```shell 23 | python main.py -h 24 | ``` 25 | 26 | If you understand those command-line arguments, you can train a MTL model by running a command like this. 
27 | 28 | ```shell 29 | python main.py --weighting WEIGHTING --arch ARCH --dataset_path PATH --gpu_id GPU_ID --target TARGET --mode train --save_path PATH 30 | ``` 31 | 32 | You can test the trained MTL model by running the following command. 33 | 34 | ```shell 35 | python main.py --weighting WEIGHTING --arch ARCH --dataset_path PATH --gpu_id GPU_ID --target TARGET --mode test --load_path PATH 36 | ``` 37 | 38 | ### References 39 | 40 | [1] Zhenqin Wu, Bharath Ramsundar, Evan N. Feinberg, Joseph Gomes, Caleb Geniesse, Aneesh S. Pappu, Karl Leswing, and Vijay Pande. MoleculeNet: A Benchmark for Molecular Machine Learning. *Chemical Science*, 9(2):513-530, 2018. 41 | 42 | [2] Aviv Navon, Aviv Shamsian, Idan Achituve, Haggai Maron, Kenji Kawaguchi, Gal Chechik, and Ethan Fetaya. Multi-task Learning as a Bargaining Game. In *International Conference on Machine Learning*, 2022. 43 | -------------------------------------------------------------------------------- /examples/qm9/random_split.t: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/examples/qm9/random_split.t -------------------------------------------------------------------------------- /examples/qm9/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.metrics import AbsMetric 7 | from LibMTL.loss import AbsLoss 8 | 9 | class QM9Metric(AbsMetric): 10 | r"""Calculate the Mean Absolute Error (MAE). 
11 | """ 12 | def __init__(self, std, scale=1): 13 | super(QM9Metric, self).__init__() 14 | 15 | self.std = std 16 | self.scale = scale 17 | 18 | def update_fun(self, pred, gt): 19 | r""" 20 | """ 21 | abs_err = torch.abs(pred * (self.std).to(pred.device) - gt * (self.std).to(pred.device)).view(pred.size()[0], -1).sum(-1) 22 | self.record.append(abs_err.cpu().numpy()) 23 | 24 | def score_fun(self): 25 | r""" 26 | """ 27 | records = np.concatenate(self.record) 28 | return [records.mean()*self.scale] 29 | -------------------------------------------------------------------------------- /examples/xtreme/README.md: -------------------------------------------------------------------------------- 1 | ## PAWS-X from XTREME Benchmark 2 | 3 | The PAWS-X dataset is a multilingual sentence classification dataset from XTREME benchmark [[1]](#1). Following [[2]](#2), we use English (en), Mandarin (zh), German (de) and Spanish (es) to form a multi-input multi-task problem. Each language/task has about 49.4K, 2.0K, and 2.0K data samples for training, validation, and testing. The training settings are mainly followed [[2]](#2) and the codes are modified from [xtreme](https://github.com/google-research/xtreme). 4 | 5 | Run the following command to download the dataset, 6 | 7 | ```shell 8 | bash propocess_data/download_data.sh 9 | ``` 10 | 11 | ### Dependencies 12 | 13 | - networkx==1.11 14 | 15 | - transformers==4.6.1 16 | 17 | ### Run a Model 18 | 19 | The script ``main.py`` is the main file for training and evaluating a MTL model on the PAWS-X dataset. A set of command-line arguments is provided to allow users to adjust the training parameter configuration. 20 | 21 | Some important arguments are described as follows. 22 | 23 | - ``weighting``: The weighting strategy. Refer to [here](../../LibMTL#supported-algorithms). 24 | - ``arch``: The MTL architecture. Refer to [here](../../LibMTL#supported-algorithms). 25 | - ``gpu_id``: The id of gpu. The default value is '0'. 
26 | - ``seed``: The random seed for reproducibility. The default value is 0. 27 | - ``dataset_path``: The path of the PAWS-X dataset. 28 | - ``bs``: The batch size of training, validation, and test data. The default value is 32. 29 | 30 | The complete command-line arguments and their descriptions can be found by running the following command. 31 | 32 | ```shell 33 | python main.py -h 34 | ``` 35 | 36 | If you understand those command-line arguments, you can train a MTL model by running a command like this. 37 | 38 | ```shell 39 | python main.py --weighting WEIGHTING --arch ARCH --dataset_path PATH --gpu_id GPU_ID --multi_input --mode train --save_path PATH 40 | ``` 41 | 42 | You can test the trained MTL model by running the following command. 43 | 44 | ```shell 45 | python main.py --weighting WEIGHTING --arch ARCH --dataset_path PATH --gpu_id GPU_ID --multi_input --mode test --load_path PATH 46 | ``` 47 | 48 | ### References 49 | 50 | [1] Junjie Hu, Sebastian Ruder, Aditya Siddhant, Graham Neubig, Orhan Firat, and Melvin Johnson. XTREME: A Massively Multilingual Multi-task Benchmark for Evaluating Cross-lingual Generalisation. In *International Conference on Machine Learning*, 2020. 51 | 52 | [2] Baijiong Lin, Feiyang Ye, Yu Zhang, and Ivor Tsang. Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-task Learning. *Transactions on Machine Learning Research*, 2022. 
53 | -------------------------------------------------------------------------------- /examples/xtreme/propocess_data/conllu_to_conll.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from itertools import islice 3 | from pathlib import Path 4 | import argparse 5 | import sys, copy 6 | 7 | from conll import CoNLLReader 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser(description="""Convert conllu to conll format""") 11 | parser.add_argument('input', help="conllu file") 12 | parser.add_argument('output', help="target file", type=Path) 13 | parser.add_argument('--replace_subtokens_with_fused_forms', help="By default removes fused tokens", default=False, action="store_true") 14 | parser.add_argument('--remove_deprel_suffixes', help="Restrict deprels to the common universal subset, e.g. nmod:tmod becomes nmod", default=False, action="store_true") 15 | parser.add_argument('--remove_node_properties', help="space-separated list of node properties to remove: form, lemma, cpostag, postag, feats", choices=['form', 'lemma', 'cpostag','postag','feats'], metavar='prop', type=str, nargs='+') 16 | parser.add_argument('--lang', help="specify a language 2-letter code", default="default") 17 | parser.add_argument('--output_format', choices=['conll2006', 'conll2009', 'conllu'], default="conll2006") 18 | parser.add_argument('--remove_arabic_diacritics', help="remove Arabic short vowels", default=False, action="store_true") 19 | parser.add_argument('--print_comments',default=False,action="store_true") 20 | parser.add_argument('--print_fused_forms',default=False,action="store_true") 21 | 22 | args = parser.parse_args() 23 | 24 | if sys.version_info < (3,0): 25 | print("Sorry, requires Python 3.x.") #suggestion: install anaconda python 26 | sys.exit(1) 27 | 28 | POSRANKPRECEDENCEDICT = defaultdict(list) 29 | POSRANKPRECEDENCEDICT["default"] = "VERB NOUN PROPN PRON ADJ NUM ADV INTJ AUX ADP DET PART 
CCONJ SCONJ X PUNCT ".split(" ") 30 | # POSRANKPRECEDENCEDICT["de"] = "PROPN ADP DET ".split(" ") 31 | POSRANKPRECEDENCEDICT["es"] = "VERB AUX PRON ADP DET".split(" ") 32 | POSRANKPRECEDENCEDICT["fr"] = "VERB AUX PRON NOUN ADJ ADV ADP DET PART SCONJ CONJ".split(" ") 33 | POSRANKPRECEDENCEDICT["it"] = "VERB AUX ADV PRON ADP DET INTJ".split(" ") 34 | 35 | if args.lang in POSRANKPRECEDENCEDICT: 36 | current_pos_precedence_list = POSRANKPRECEDENCEDICT[args.lang] 37 | else: 38 | current_pos_precedence_list = POSRANKPRECEDENCEDICT["default"] 39 | 40 | cio = CoNLLReader() 41 | orig_treebank = cio.read_conll_u(args.input)#, args.keep_fused_forms, args.lang, POSRANKPRECEDENCEDICT) 42 | modif_treebank = copy.copy(orig_treebank) 43 | 44 | # As per Dec 2015 the args.lang variable is redundant once you have current_pos_precedence_list 45 | # We keep it for future modifications, i.e. any language-specific modules 46 | for s in modif_treebank: 47 | # print('sentence', s.get_sentence_as_string(printid=True)) 48 | s.filter_sentence_content(args.replace_subtokens_with_fused_forms, args.lang, current_pos_precedence_list,args.remove_node_properties,args.remove_deprel_suffixes,args.remove_arabic_diacritics) 49 | 50 | cio.write_conll(modif_treebank,args.output, args.output_format,print_fused_forms=args.print_fused_forms, print_comments=args.print_comments) 51 | 52 | if __name__ == "__main__": 53 | main() -------------------------------------------------------------------------------- /examples/xtreme/propocess_data/download_data.sh: -------------------------------------------------------------------------------- 1 | REPO=$PWD 2 | PRO=$REPO/propocess_data/ 3 | DIR=$REPO/data/ 4 | mkdir -p $DIR 5 | 6 | # download PAWS-X dataset 7 | function download_pawsx { 8 | cd $DIR 9 | wget https://storage.googleapis.com/paws/pawsx/x-final.tar.gz -q --show-progress 10 | tar xzf x-final.tar.gz -C $DIR/ 11 | python $PRO/utils_preprocess.py \ 12 | --data_dir $DIR/x-final \ 13 | --output_dir $DIR/pawsx/ \ 
14 | --task pawsx 15 | rm -rf x-final x-final.tar.gz 16 | echo "Successfully downloaded data at $DIR/pawsx" >> $DIR/download.log 17 | } 18 | 19 | download_pawsx -------------------------------------------------------------------------------- /examples/xtreme/utils.py: -------------------------------------------------------------------------------- 1 | import torch, math, warnings 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from LibMTL.loss import AbsLoss 7 | 8 | warnings.simplefilter('ignore', UserWarning) 9 | 10 | # pawsx 11 | class SCLoss(AbsLoss): 12 | def __init__(self, label_num): 13 | super(SCLoss, self).__init__() 14 | self.loss_fn = nn.CrossEntropyLoss() 15 | self.label_num = label_num 16 | 17 | def compute_loss(self, pred, gt): 18 | return self.loss_fn(pred.view(-1, self.label_num), gt.view(-1)) 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "LibMTL" 3 | version = "1.1.5" 4 | description = "A PyTorch Library for Multi-Task Learning" 5 | authors = ["Baijiong Lin "] 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cvxpy==1.6.5 2 | numpy==1.26.3 3 | qpsolvers==4.7.0 4 | scipy==1.13.1 5 | torch==2.3.0+cu121 6 | torchvision==0.18.0+cu121 7 | torch-scatter==2.1.2+pt23cu121 8 | torch_sparse==0.6.18+pt23cu121 9 | torch_geometric==2.6.1 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('README.md', 'r', encoding='utf-8') as f: 4 | long_description = f.read() 5 | 6 | setup( 7 | name='LibMTL', 8 | version='1.1.5', 9 | description='A 
PyTorch Library for Multi-Task Learning', 10 | author='Baijiong Lin', 11 | author_email='bj.lin.email@gmail.com', 12 | url='https://github.com/median-research-group/LibMTL', 13 | packages=find_packages(), 14 | license='MIT', 15 | platforms=["all"], 16 | classifiers=['Intended Audience :: Developers', 17 | 'Intended Audience :: Education', 18 | 'Intended Audience :: Science/Research', 19 | 'License :: OSI Approved :: MIT License', 20 | 'Programming Language :: Python :: 3.9', 21 | 'Programming Language :: Python :: 3.10', 22 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 23 | 'Topic :: Scientific/Engineering :: Mathematics', 24 | 'Topic :: Software Development :: Libraries',], 25 | long_description=long_description, 26 | long_description_content_type='text/markdown', 27 | install_requires=['torch>=2.3.0', 28 | 'torchvision>=0.18.0', 29 | 'numpy>=1.26'] 30 | ) 31 | 32 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | ## Test 2 | 3 | ```shell 4 | pip install coverage 5 | coverage erase 6 | coverage run -m --parallel-mode --source ./LibMTL pytest tests/test_nyu.py 7 | coverage run -m --parallel-mode --source ./LibMTL pytest tests/test_office31.py 8 | coverage run -m --parallel-mode --source ./LibMTL pytest tests/test_office_home.py 9 | coverage run -m --parallel-mode --source ./LibMTL pytest tests/test_qm9.py 10 | coverage run -m --parallel-mode --source ./LibMTL pytest tests/test_pawsx.py 11 | coverage combine 12 | coverage report 13 | rm -rf tests/htmlcov 14 | coverage html -d tests/htmlcov 15 | pip install coverage-badge 16 | rm tests/coverage.svg 17 | coverage-badge -o tests/coverage.svg 18 | ``` 19 | 20 | The HTML report is [here](https://htmlpreview.github.io/?https://github.com/median-research-group/LibMTL/blob/main/tests/htmlcov/index.html). 
21 | 22 | 23 | -------------------------------------------------------------------------------- /tests/coverage.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | coverage 17 | coverage 18 | 94% 19 | 94% 20 | 21 | 22 | -------------------------------------------------------------------------------- /tests/htmlcov/favicon_32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/tests/htmlcov/favicon_32.png -------------------------------------------------------------------------------- /tests/htmlcov/keybd_closed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/tests/htmlcov/keybd_closed.png -------------------------------------------------------------------------------- /tests/htmlcov/keybd_open.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/median-research-group/LibMTL/4336804847eaa5e0b924b743d76beec7ac3fdc97/tests/htmlcov/keybd_open.png --------------------------------------------------------------------------------