'
13 |
14 | hr_faded: '
'
15 | hr_shaded: '
'
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3"
2 | services:
3 | fastai: &fastai
4 | restart: unless-stopped
5 | working_dir: /data
6 | image: fastai/codespaces
7 | logging:
8 | driver: json-file
9 | options:
10 | max-size: 50m
11 | stdin_open: true
12 | tty: true
13 | volumes:
14 | - .:/data/
15 |
16 | notebook:
17 | <<: *fastai
18 | command: bash -c "pip install -e . && jupyter notebook --allow-root --no-browser --ip=0.0.0.0 --port=8080 --NotebookApp.token='' --NotebookApp.password=''"
19 | ports:
20 | - "8080:8080"
21 |
22 | watcher:
23 | <<: *fastai
24 | command: watchmedo shell-command --command nbdev_build_docs --pattern *.ipynb --recursive --drop
25 | network_mode: host # for GitHub Codespaces https://github.com/features/codespaces/
26 |
27 | jekyll:
28 | <<: *fastai
29 | ports:
30 | - "4000:4000"
31 | command: >
32 | bash -c "cp -r docs_src docs
33 | && pip install .
34 | && nbdev_build_docs && cd docs
35 | && bundle i
36 | && chmod -R u+rwx . && bundle exec jekyll serve --host 0.0.0.0"
37 |
--------------------------------------------------------------------------------
/docs/licenses/LICENSE:
--------------------------------------------------------------------------------
1 | /* This license pertains to the docs template, except for the Navgoco jQuery component. */
2 |
3 | The MIT License (MIT)
4 |
5 | Original theme: Copyright (c) 2016 Tom Johnson
6 | Modifications: Copyright (c) 2017 onwards fast.ai, Inc
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 |
--------------------------------------------------------------------------------
/source_nbs/02_special_tokens.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# default_exp special_tokens\n",
10 | "%load_ext autoreload\n",
11 | "%autoreload 2"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "# Special Tokens\n"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "# export\n",
28 | "BOS_TOKEN = '[unused98]'\n",
29 | "EOS_TOKEN = '[unused99]'\n",
30 | "CLS_TOKEN = '[CLS]'\n",
31 | "SPACE_TOKEN = '[unused1]'\n",
32 | "UNK_TOKEN = '[UNK]'\n",
33 | "SPECIAL_TOKENS = [BOS_TOKEN, EOS_TOKEN, CLS_TOKEN, SPACE_TOKEN, UNK_TOKEN]\n",
34 | "TRAIN = 'train'\n",
35 | "EVAL = 'eval'\n",
36 | "PREDICT = 'infer'\n",
37 | "MODAL_LIST = ['image', 'others']"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": []
46 | }
47 | ],
48 | "metadata": {
49 | "kernelspec": {
50 | "display_name": "Python 3",
51 | "language": "python",
52 | "name": "python3"
53 | }
54 | },
55 | "nbformat": 4,
56 | "nbformat_minor": 2
57 | }
58 |
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | repository: JayYip/m3tl
2 | output: web
3 | topnav_title: m3tl
4 | site_title: m3tl
5 | company_name: jayyip
6 | description: BERT for Multi-task Learning
7 | # Set to false to disable KaTeX math
8 | use_math: true
9 | # Add Google analytics id if you have one and want to use it here
10 | google_analytics:
11 | # See http://nbdev.fast.ai/search for help with adding Search
12 | google_search:
13 |
14 | host: 127.0.0.1
15 | # the preview server used. Leave as is.
16 | port: 4000
17 | # the port where the preview is rendered.
18 |
19 | exclude:
20 | - .idea/
21 | - .gitignore
22 | - vendor
23 |
24 | # (duplicate 'exclude' key removed — a second 'exclude' silently overrides the list above, dropping .idea/ and .gitignore; 'vendor' is already listed there)
25 |
26 | highlighter: rouge
27 | markdown: kramdown
28 | kramdown:
29 | input: GFM
30 | auto_ids: true
31 | hard_wrap: false
32 | syntax_highlighter: rouge
33 |
34 | collections:
35 | tooltips:
36 | output: false
37 |
38 | defaults:
39 | -
40 | scope:
41 | path: ""
42 | type: "pages"
43 | values:
44 | layout: "page"
45 | comments: true
46 | search: true
47 | sidebar: home_sidebar
48 | topnav: topnav
49 | -
50 | scope:
51 | path: ""
52 | type: "tooltips"
53 | values:
54 | layout: "page"
55 | comments: true
56 | search: true
57 | tooltip: true
58 |
59 | sidebars:
60 | - home_sidebar
61 |
62 | theme: jekyll-theme-cayman
63 | baseurl: /m3tl/
--------------------------------------------------------------------------------
/docs/feed.xml:
--------------------------------------------------------------------------------
1 | ---
2 | search: exclude
3 | layout: none
4 | ---
5 |
6 |
7 |
8 |
9 | {{ site.title | xml_escape }}
10 | {{ site.description | xml_escape }}
11 | {{ site.url }}/
12 |
13 | {{ site.time | date_to_rfc822 }}
14 | {{ site.time | date_to_rfc822 }}
15 | Jekyll v{{ jekyll.version }}
16 | {% for post in site.posts limit:10 %}
17 | -
18 |
{{ post.title | xml_escape }}
19 | {{ post.content | xml_escape }}
20 | {{ post.date | date_to_rfc822 }}
21 | {{ post.url | prepend: site.url }}
22 | {{ post.url | prepend: site.url }}
23 | {% for tag in post.tags %}
24 | {{ tag | xml_escape }}
25 | {% endfor %}
26 | {% for tag in page.tags %}
27 | {{ tag | xml_escape }}
28 | {% endfor %}
29 |
30 | {% endfor %}
31 |
32 |
33 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 |
import codecs
from setuptools import setup, find_packages
from pkg_resources import parse_version
from configparser import ConfigParser
import setuptools

# nbdev-style layout: release settings live in settings.ini, not here.
assert parse_version(setuptools.__version__) >= parse_version('36.2')

# note: all settings are in settings.ini; edit there, not here
config = ConfigParser(delimiters=['='])
config.read('settings.ini')
cfg = config['DEFAULT']

# Long description comes straight from the README so PyPI mirrors GitHub.
with codecs.open('README.md', 'r', 'utf8') as reader:
    long_description = reader.read()

# One requirement per line in requirements.txt.
with codecs.open('requirements.txt', 'r', 'utf8') as reader:
    install_requires = [line.strip() for line in reader.readlines()]

setup(
    name='m3tl',
    version=cfg['version'],
    packages=find_packages(),
    url='https://github.com/JayYip/m3tl',
    license='MIT',
    author='Jay Yip',
    author_email='junpang.yip@gmail.com',
    description='BERT for Multi-task Learning',
    long_description_content_type='text/markdown',
    long_description=long_description,
    python_requires='>=3.5.0',
    install_requires=install_requires,
    # setuptools requires a list here; a tuple raises a deprecation
    # warning (and a hard error on recent setuptools versions).
    classifiers=[
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.5",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Intended Audience :: Developers",
    ],
)
43 |
--------------------------------------------------------------------------------
/docs/_includes/links.html:
--------------------------------------------------------------------------------
1 | {% comment %}Get links from each sidebar, as listed in the _config.yml file under sidebars{% endcomment %}
2 |
3 | {% for sidebar in site.sidebars %}
4 | {% for entry in site.data.sidebars[sidebar].entries %}
5 | {% for folder in entry.folders %}
6 | {% for folderitem in folder.folderitems %}
7 | {% if folderitem.url contains "html#" %}
8 | [{{folderitem.url | remove: "/" }}]: {{folderitem.url | remove: "/"}}
9 | {% else %}
10 | [{{folderitem.url | remove: "/" | remove: ".html"}}]: {{folderitem.url | remove: "/"}}
11 | {% endif %}
12 | {% for subfolders in folderitem.subfolders %}
13 | {% for subfolderitem in subfolders.subfolderitems %}
14 | [{{subfolderitem.url | remove: "/" | remove: ".html"}}]: {{subfolderitem.url | remove: "/"}}
15 | {% endfor %}
16 | {% endfor %}
17 | {% endfor %}
18 | {% endfor %}
19 | {% endfor %}
20 | {% endfor %}
21 |
22 |
23 | {% comment %} Get links from topnav {% endcomment %}
24 |
25 | {% for entry in site.data.topnav.topnav %}
26 | {% for item in entry.items %}
27 | {% if item.external_url == null %}
28 | [{{item.url | remove: "/" | remove: ".html"}}]: {{item.url | remove: "/"}}
29 | {% endif %}
30 | {% endfor %}
31 | {% endfor %}
32 |
33 | {% comment %}Get links from topnav dropdowns {% endcomment %}
34 |
35 | {% for entry in site.data.topnav.topnav_dropdowns %}
36 | {% for folder in entry.folders %}
37 | {% for folderitem in folder.folderitems %}
38 | {% if folderitem.external_url == null %}
39 | [{{folderitem.url | remove: "/" | remove: ".html"}}]: {{folderitem.url | remove: "/"}}
40 | {% endif %}
41 | {% endfor %}
42 | {% endfor %}
43 | {% endfor %}
44 |
45 |
--------------------------------------------------------------------------------
/docs/_includes/head_print.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
{% if page.homepage == true %} {{site.homepage_title}} {% elsif page.title %}{{ page.title }}{% endif %} | {{ site.site_title }}
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
23 |
24 |
29 |
--------------------------------------------------------------------------------
/docs/css/modern-business.css:
--------------------------------------------------------------------------------
1 | /*!
2 | * Start Bootstrap - Modern Business HTML Template (http://startbootstrap.com)
3 | * Code licensed under the Apache License v2.0.
4 | * For details, see http://www.apache.org/licenses/LICENSE-2.0.
5 | */
6 |
7 | /* Global Styles */
8 |
9 | html,
10 | body {
11 | height: 100%;
12 | }
13 |
14 | .img-portfolio {
15 | margin-bottom: 30px;
16 | }
17 |
18 | .img-hover:hover {
19 | opacity: 0.8;
20 | }
21 |
22 | /* Home Page Carousel */
23 |
24 | header.carousel {
25 | height: 50%;
26 | }
27 |
28 | header.carousel .item,
29 | header.carousel .item.active,
30 | header.carousel .carousel-inner {
31 | height: 100%;
32 | }
33 |
34 | header.carousel .fill {
35 | width: 100%;
36 | height: 100%;
37 | background-position: center;
38 | background-size: cover;
39 | }
40 |
41 | /* 404 Page Styles */
42 |
43 | .error-404 {
44 | font-size: 100px;
45 | }
46 |
47 | /* Pricing Page Styles */
48 |
49 | .price {
50 | display: block;
51 | font-size: 50px;
52 | line-height: 50px;
53 | }
54 |
55 | .price sup {
56 | top: -20px;
57 | left: 2px;
58 | font-size: 20px;
59 | }
60 |
61 | .period {
62 | display: block;
63 | font-style: italic;
64 | }
65 |
66 | /* Footer Styles */
67 |
68 | footer {
69 | margin: 50px 0;
70 | }
71 |
72 | /* Responsive Styles */
73 |
74 | @media(max-width:991px) {
75 | .client-img,
76 | .img-related {
77 | margin-bottom: 30px;
78 | }
79 | }
80 |
81 | @media(max-width:767px) {
82 | .img-portfolio {
83 | margin-bottom: 15px;
84 | }
85 |
86 | header.carousel .carousel {
87 | height: 70%;
88 | }
89 | }
90 |
--------------------------------------------------------------------------------
/docs/licenses/LICENSE-BSD-NAVGOCO.txt:
--------------------------------------------------------------------------------
1 | /* This license pertains to the Navgoco jQuery component used for the sidebar. */
2 |
3 | Copyright (c) 2013, Christodoulos Tsoulloftas, http://www.komposta.net
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without modification,
7 | are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice,
10 | this list of conditions and the following disclaimer.
11 | * Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 | * Neither the name of the
nor the names of its
15 | contributors may be used to endorse or promote products derived from this
16 | software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
21 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
22 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
23 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
26 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
27 | OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/docs/sidebar.json:
--------------------------------------------------------------------------------
1 | {
2 | "m3tl": {
3 | "Overview": "/",
4 | "Tutorial": "tutorial.html",
5 | "": {
6 | "Problem Types": {
7 | "Classification": "1_problem_type_cls.html",
8 | "Multi-Label Classification": "2_problem_type_multi_cls.html",
9 | "Sequence Labeling": "3_problem_type_seq_tag.html",
10 | "Masked Language Model": "4_problem_type_masklm.html",
11 | "NSP+MLM(Deprecated)": "6_problem_type_pretrain.html",
12 | "Regression": "7_problem_type_regression.html",
13 | "Vector Fitting": "8_problem_type_vector_fit.html",
14 | "Pre-masked Masked Language Model": "9_problem_type_premask_mlm.html",
15 | "Contrastive Learning": "10_problem_type_contrast_learning.html"
16 | },
17 | "MTL Models": {
18 | "MTLBase": "15-00_mtl_model_base.html",
19 | "MMoE": "15-01_mtl_model_mmoe.html"
20 | },
21 | "Loss Combination Strategy": {
22 | "LossCombinationStrategyBase": "16-00_loss_combination_strategy.html"
23 | },
24 | "Embedding Layers": {
25 | "EmbeddingLayerBase": "18-00_embedding_layer.html"
26 | }
27 | },
28 | "Params": "params.html",
29 | "Run Bert Multitask Learning": "run_bert_multitask.html",
30 | "Multitask Learning Model": "model_fn.html",
31 | "Utils": "utils.html",
32 | "Special Tokens": "special_tokens.html",
33 | "Bert Utils": "bert_utils.html",
34 | "Create Bert Features": "create_bert_features.html",
35 | "Preprocessing Decorator": "preproc_decorator.html",
36 | "Read and Write TFRecord": "read_write_tfrecord.html",
37 | "Function to Create Datasets": "input_fn.html",
38 | "Pre-defined Problems": "predefined_problems_test.html",
39 | "Body Modeling": "modeling.html",
40 | "Top Models": "top.html",
41 | "Test Base": "test_base.html"
42 | }
43 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | tmp/
107 | **/.DS_Store
108 |
109 | **/Boson*
110 |
111 | data/ctb8.0
112 | data/ontonotes
113 | .vscode
114 | .idea
115 | data/
116 | *_experiments*
117 | chinese_L-12_H-768_A-12/
118 | models/
119 | test.py
120 | cov.xml
--------------------------------------------------------------------------------
/docs/_layouts/page.html:
--------------------------------------------------------------------------------
1 | ---
2 | layout: default
3 | ---
4 |
5 |
12 |
13 | {% if page.simple_map == true %}
14 |
15 |
20 |
21 | {% include custom/{{page.map_name}}.html %}
22 |
23 | {% elsif page.complex_map == true %}
24 |
25 |
30 |
31 | {% include custom/{{page.map_name}}.html %}
32 |
33 | {% endif %}
34 |
35 |
36 |
37 | {% if page.summary %}
38 |
{{page.summary}}
39 | {% endif %}
40 |
41 | {% unless page.toc == false %}
42 | {% include toc.html %}
43 | {% endunless %}
44 |
45 |
46 | {% if site.github_editme_path %}
47 |
48 |
Edit me
49 |
50 | {% endif %}
51 |
52 | {{content}}
53 |
54 |
65 |
66 |
67 |
68 | {{site.data.alerts.hr_shaded}}
69 |
70 | {% include footer.html %}
71 |
--------------------------------------------------------------------------------
/docs/js/customscripts.js:
--------------------------------------------------------------------------------
// Make the sidebar element match the height of the main nav bar.
$('#mysidebar').height($(".nav").height());


$( document ).ready(function() {

    //this script says, if the height of the viewport is greater than 800px, then insert affix class, which makes the nav bar float in a fixed
    // position as your scroll. if you have a lot of nav items, this height may not work for you.
    var h = $(window).height();
    //console.log (h);
    if (h > 800) {
        $( "#mysidebar" ).attr("class", "nav affix");
    }
    // activate tooltips. although this is a bootstrap js function, it must be activated this way in your theme.
    $('[data-toggle="tooltip"]').tooltip({
        placement : 'top'
    });

    /**
     * AnchorJS
     * Auto-generates clickable anchor links on h2-h5 headings.
     */
    anchors.add('h2,h3,h4,h5');

});
24 |
// needed for nav tabs on pages. See Formatting > Nav tabs for more details.
// script from http://stackoverflow.com/questions/10523433/how-do-i-keep-the-current-tab-active-with-twitter-bootstrap-after-a-page-reload
$(function() {
    var json, tabsState;

    // Persist the active tab of each tab group (keyed by its <ul> id)
    // into localStorage whenever a tab is shown.
    $('a[data-toggle="pill"], a[data-toggle="tab"]').on('shown.bs.tab', function(e) {
        var href, json, parentId, tabsState;

        tabsState = localStorage.getItem("tabs-state");
        json = JSON.parse(tabsState || "{}");
        parentId = $(e.target).parents("ul.nav.nav-pills, ul.nav.nav-tabs").attr("id");
        href = $(e.target).attr('href');
        json[parentId] = href;

        return localStorage.setItem("tabs-state", JSON.stringify(json));
    });

    tabsState = localStorage.getItem("tabs-state");
    json = JSON.parse(tabsState || "{}");

    // Restore the saved tab per container. The href (e.g. "#tab1") must be
    // quoted inside the attribute selector: an unquoted '#' makes the
    // selector invalid and jQuery throws a selector syntax error.
    $.each(json, function(containerId, href) {
        return $("#" + containerId + " a[href='" + href + "']").tab('show');
    });

    // Fall back to the first tab for groups with no saved state.
    $("ul.nav.nav-pills, ul.nav.nav-tabs").each(function() {
        var $this = $(this);
        if (!json[$this.attr("id")]) {
            return $this.find("a[data-toggle=tab]:first, a[data-toggle=pill]:first").tab("show");
        }
    });
});
55 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # M3TL
2 |
3 |
4 |
5 | **M**ulti-**M**odal **M**ulti-**T**ask **L**earning
6 |
7 | ## Install
8 |
9 | ```
10 | pip install m3tl
11 | ```
12 |
13 | ## What is it
14 |
15 | This is a project that uses transformers (based on Hugging Face Transformers) as the base model to do **multi-modal multi-task learning**.
16 |
17 | ## Why do I need this
18 |
19 | Multi-task learning (MTL) is gaining more and more attention, especially in the deep learning era. It is widely used in NLP, CV, recommendation, etc. However, MTL usually involves complicated data preprocessing, task management and task interaction. Other open-source projects, like TencentNLP and PyText, support MTL, but only in a naive way, and it is not straightforward to implement complicated MTL algorithms with them. In this project, we try to make writing an MTL model as easy as writing a single-task model, and we further extend MTL to multi-modal multi-task learning. To do so, we expose the following MTL-related programmable modules to the user:
20 |
21 | - problem sampling strategy
22 | - loss combination strategy
23 | - gradient surgery
24 | - model after base model(transformers)
25 |
26 | Apart from the programmable modules, we also provide various built-in SOTA MTL algorithms.
27 |
28 | In a word, you can use this project to:
29 |
30 | - implement complicated MTL algorithm
31 | - do SOTA MTL without diving into details
32 | - do multi-modal learning
33 |
34 | And since we use transformers as base model, you get all the benefits that you can get from transformers!
35 |
36 | ## What type of problems are supported?
37 |
38 | ```
39 | params = Params()
40 | for problem_type in params.list_available_problem_types():
41 | print('`{problem_type}`: {desc}'.format(
42 | desc=params.problem_type_desc[problem_type], problem_type=problem_type))
43 |
44 | ```
45 |
46 | `cls`: Classification
47 | `multi_cls`: Multi-Label Classification
48 | `seq_tag`: Sequence Labeling
49 | `masklm`: Masked Language Model
50 | `pretrain`: NSP+MLM(Deprecated)
51 | `regression`: Regression
52 | `vector_fit`: Vector Fitting
53 | `premask_mlm`: Pre-masked Masked Language Model
54 | `contrastive_learning`: Contrastive Learning
55 |
56 |
57 |
58 | ## Get Started
59 |
60 | Please see tutorials.
61 |
62 |
--------------------------------------------------------------------------------
/docs/images/colab.svg:
--------------------------------------------------------------------------------
1 | Open in Colab Open in Colab
2 |
--------------------------------------------------------------------------------
/m3tl/problem_types/utils.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/12_0_problem_type_utils.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['empty_tensor_handling_loss', 'nan_loss_handling', 'create_dummy_if_empty', 'BaseTop', 'pad_to_shape']
4 |
5 | # Cell
6 | from typing import Dict, Tuple
7 |
8 | import tensorflow as tf
9 | from ..base_params import BaseParams
10 |
11 |
def empty_tensor_handling_loss(labels, logits, loss_fn):
    # Mean of `loss_fn(labels, logits)` over the batch, guarding against
    # empty label tensors: in a multi-problem batch a given problem may
    # have no examples, and reducing over an empty tensor would yield NaN.
    # Returns 0.0 when `labels` has zero elements, is rank-0, or has a
    # zero-length first dimension.
    # NOTE(review): when traced inside a tf.function these Python `if`s on
    # tensors rely on AutoGraph conversion — confirm callers trace this.
    if tf.equal(tf.size(labels), 0):
        return 0.0
    if tf.equal(tf.size(tf.shape(labels)), 0):
        return 0.0
    if tf.equal(tf.shape(labels)[0], 0):
        return 0.0
    else:
        return tf.reduce_mean(loss_fn(
            labels, logits, from_logits=True))
22 |
23 |
@tf.function
def nan_loss_handling(loss):
    # Replace a NaN loss with 0.0 so one degenerate problem/batch does not
    # poison the combined multi-task gradient; otherwise pass loss through.
    if tf.math.is_nan(loss):
        return 0.0
    else:
        return loss
30 |
31 |
@tf.function
def create_dummy_if_empty(inp_tensor: tf.Tensor) -> tf.Tensor:
    # If the batch dimension is 0 (no examples for this problem in the
    # current batch), return a single all-zeros row with the same trailing
    # shape and dtype so downstream layers still receive a valid tensor;
    # otherwise return the input unchanged.
    shape_tensor = tf.shape(inp_tensor)
    if tf.equal(shape_tensor[0], 0):
        data_type = inp_tensor.dtype
        dummy_shape_first_dim = tf.convert_to_tensor([1], dtype=tf.int32)
        # Keep every dimension except the first, which becomes 1.
        dummy_shape = tf.concat(
            [dummy_shape_first_dim, shape_tensor[1:]], axis=0)
        dummy_tensor = tf.zeros(dummy_shape, dtype=data_type)
        return dummy_tensor
    else:
        return inp_tensor
44 |
45 |
class BaseTop(tf.keras.Model):
    # Common base for per-problem "top" (head) models: stores the shared
    # params object and the problem name; subclasses implement `call`.
    def __init__(self, params: BaseParams, problem_name: str) -> None:
        super(BaseTop, self).__init__(name=problem_name)
        self.params = params              # shared BaseParams across problems
        self.problem_name = problem_name  # used to name losses/metrics/labels

    def call(self, inputs: Tuple[Dict], mode: str):
        # inputs is a (features, hidden_features) tuple; subclasses override.
        raise NotImplementedError
54 |
def pad_to_shape(from_tensor: tf.Tensor, to_tensor: tf.Tensor, axis=1) -> tf.Tensor:
    """Right-pad `from_tensor` with zeros along `axis` to match `to_tensor`.

    Label tensors can come out shorter than the inputs when
    tf.data.experimental.bucket_by_sequence_length batches several
    problems together; this evens the lengths back up.
    """
    target_len = tf.shape(input=to_tensor)[axis]
    current_len = tf.shape(input=from_tensor)[axis]
    missing = target_len - current_len

    # One [before, after] pair per dimension; only `axis` is padded,
    # and only on the trailing side.
    paddings = [[0, 0]] * len(from_tensor.shape)
    paddings[axis] = [0, missing]
    return tf.pad(tensor=from_tensor, paddings=paddings)
66 |
--------------------------------------------------------------------------------
/m3tl/problem_types/vector_fit.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/12_8_problem_type_vector_fit.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['cosine_wrapper', 'VectorFit', 'vector_fit_get_or_make_label_encoder_fn', 'vector_fit_label_handling_fn']
4 |
5 | # Cell
6 | from typing import Dict, Tuple
7 |
8 | import numpy as np
9 | import tensorflow as tf
10 | from ..base_params import BaseParams
11 | from .utils import (empty_tensor_handling_loss,
12 | nan_loss_handling)
13 | from ..special_tokens import PREDICT
14 | from ..utils import get_phase
15 |
16 |
17 | # Cell
def cosine_wrapper(labels, logits, from_logits=True):
    # Adapter giving cosine similarity the (labels, logits, from_logits)
    # signature expected by empty_tensor_handling_loss; `from_logits` is
    # accepted for interface compatibility and intentionally ignored.
    return tf.keras.losses.cosine_similarity(labels, logits)
20 |
21 |
class VectorFit(tf.keras.Model):
    """Problem head that regresses the pooled hidden state onto a target vector.

    Trained with cosine-similarity loss; logs '{problem_name}_cos_sim'
    (the negated loss) as a metric.
    """

    def __init__(self, params: BaseParams, problem_name: str) -> None:
        super(VectorFit, self).__init__(name=problem_name)
        self.params = params
        self.problem_name = problem_name
        # num_classes is the dimensionality of the target vector, recorded
        # earlier by vector_fit_get_or_make_label_encoder_fn.
        self.num_classes = self.params.get_problem_info(problem=problem_name, info_name='num_classes')
        self.dense = tf.keras.layers.Dense(self.num_classes)

    def call(self, inputs: Tuple[Dict]):
        mode = get_phase()
        feature, hidden_feature = inputs
        pooled_hidden = hidden_feature['pooled']

        logits = self.dense(pooled_hidden)
        if mode != PREDICT:
            # this is actually a vector (the fit target), not class ids
            label = feature['{}_label_ids'.format(self.problem_name)]

            loss = empty_tensor_handling_loss(label, logits, cosine_wrapper)
            loss = nan_loss_handling(loss)
            self.add_loss(loss)

            # cosine_similarity is negative for similar vectors, so negate
            # the loss to report a similarity that increases as fit improves.
            self.add_metric(tf.math.negative(
                loss), name='{}_cos_sim'.format(self.problem_name), aggregation='mean')
        return logits
47 |
48 | # Cell
def vector_fit_get_or_make_label_encoder_fn(params: BaseParams, problem, mode, label_list, *args, **kwargs):
    """Record the target-vector dimensionality for `problem`; returns no encoder.

    The "labels" for vector fitting are the target vectors themselves, so
    their last dimension is the number of units the dense head needs.
    """
    if not label_list:
        return None
    params.set_problem_info(problem=problem, info_name='num_classes',
                            info=np.array(label_list).shape[-1])
    return None
55 |
56 |
57 | # Cell
def vector_fit_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):
    """Convert the raw target vector to a float32 numpy array.

    Returns (label_id, label_mask); vector fitting needs no mask, so the
    second element is always None.
    """
    as_float32 = np.array(target, dtype='float32')
    return as_float32, None
62 |
63 |
--------------------------------------------------------------------------------
/m3tl/problem_types/regression.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/12_7_problem_type_regression.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['mse_wrapper', 'Regression', 'regression_get_or_make_label_encoder_fn', 'regression_label_handling_fn']
4 |
5 | # Cell
6 | from typing import Dict, List, Tuple
7 |
8 | import numpy as np
9 | import tensorflow as tf
10 | from ..base_params import BaseParams
11 | from .utils import (empty_tensor_handling_loss,
12 | nan_loss_handling)
13 | from ..special_tokens import PREDICT, TRAIN
14 | from ..utils import get_phase, variable_summaries
15 |
16 |
17 | # Cell
18 |
def mse_wrapper(labels, logits, from_logits=True):
    # Adapter giving MSE the (labels, logits, from_logits) signature
    # expected by empty_tensor_handling_loss; `from_logits` is accepted
    # for interface compatibility and intentionally ignored.
    return tf.keras.losses.mean_squared_error(labels, logits)
21 |
22 |
class Regression(tf.keras.Model):
    """Problem head for scalar regression on the pooled hidden state.

    Single-unit dense projection trained with MSE; logs
    '{problem_name}_neg_mse' (the negated loss) as a metric.
    """

    def __init__(self, params: BaseParams, problem_name: str) -> None:
        super(Regression, self).__init__(name=problem_name)
        self.params = params
        self.problem_name = problem_name
        # Regression always predicts a single scalar per example.
        self.num_classes = 1
        self.dense = tf.keras.layers.Dense(self.num_classes)

    def call(self, inputs: Tuple[Dict]):
        mode = get_phase()
        feature, hidden_feature = inputs
        pooled_hidden = hidden_feature['pooled']

        logits = self.dense(pooled_hidden)
        # Optional per-weight TensorBoard summaries when detail_log is set.
        if self.params.detail_log:
            for weight_variable in self.weights:
                variable_summaries(weight_variable, self.problem_name)

        if mode != PREDICT:
            # this is actually a float (the regression target), not class ids
            label = feature['{}_label_ids'.format(self.problem_name)]

            loss = empty_tensor_handling_loss(label, logits, mse_wrapper)
            loss = nan_loss_handling(loss)
            self.add_loss(loss)

            # Negate so the reported metric increases as the fit improves.
            self.add_metric(tf.math.negative(
                loss), name='{}_neg_mse'.format(self.problem_name), aggregation='mean')
        return logits
52 |
53 | # Cell
def regression_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs):
    """Regression needs no label encoder; just register num_classes at train time.

    Returns None in every mode (there is nothing to encode).
    """
    if mode == TRAIN:
        # A regression head always has exactly one output unit.
        params.set_problem_info(problem=problem, info_name='num_classes', info=1)
    return None
59 |
60 |
61 | # Cell
def regression_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):
    """Cast the raw regression target to a plain Python float.

    Returns (label_id, label_mask); regression needs no mask, so the
    second element is always None.
    """
    return float(target), None
66 |
67 |
--------------------------------------------------------------------------------
/m3tl/mtl_model/base.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/15-00_mtl_model_base.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['MTLBase', 'BasicMTL']
4 |
5 | # Cell
6 | from copy import copy
7 | from typing import Dict, Tuple
8 |
9 | import tensorflow as tf
10 | from ..utils import dispatch_features, get_phase
11 |
12 |
class MTLBase(tf.keras.Model):
    """Base class for multi-task models: routes per-problem features out of a
    shared batch.

    Subclasses implement `call`; `extract_feature` does the per-problem
    dispatch and is shared by all of them.
    """

    def __init__(self, params, name: str, *args, **kwargs):
        # Pass `name` as a keyword: tf.keras.Model interprets positional
        # arguments as the functional-API (inputs, outputs) pair, so a
        # positional string here can be misinterpreted. This also matches
        # how every other model class in this project calls the base init.
        super(MTLBase, self).__init__(*args, name=name, **kwargs)
        self.params = params
        # Every problem this model can extract, plus the pseudo-target
        # 'all' which denotes the whole (un-dispatched) batch.
        self.available_extract_target = copy(self.params.problem_list)
        self.available_extract_target.append('all')
        self.problem_list = self.params.problem_list

    def extract_feature(self, extract_problem: str, feature_dict: dict, hidden_feature_dict: dict):
        """Return the (features, hidden_features) pair for `extract_problem`.

        Raises ValueError when the problem is not in
        `available_extract_target`.
        """
        mode = get_phase()
        if extract_problem not in self.available_extract_target:
            raise ValueError('Tried to extract feature {0}, available extract problem: {1}'.format(
                extract_problem, self.available_extract_target))

        # Already dispatched per problem: return the stored entry directly.
        if extract_problem in feature_dict and extract_problem in hidden_feature_dict:
            return feature_dict[extract_problem], hidden_feature_dict[extract_problem]

        # Shared batch stored under 'all': dispatch records for this
        # problem based on their loss multiplier.
        if 'all' in feature_dict and 'all' in hidden_feature_dict:
            return dispatch_features(
                features=feature_dict['all'], hidden_feature=hidden_feature_dict['all'],
                problem=extract_problem, mode=mode)
        return dispatch_features(
            features=feature_dict, hidden_feature=hidden_feature_dict,
            problem=extract_problem, mode=mode)

    def call(self, inputs: Tuple[Dict[str, tf.Tensor]]):
        # Subclasses decide how per-problem features are combined.
        raise NotImplementedError
43 |
44 |
45 | # Cell
class BasicMTL(MTLBase):
    """Simplest MTL model: split the shared features per problem.

    No extra computation is added on top of the base model; every
    available target (each problem plus 'all') just gets its slice of
    the shared features via `extract_feature`.
    """

    def __init__(self, params, name: str, *args, **kwargs):
        super().__init__(params, name, *args, **kwargs)

    def call(self, inputs: Tuple[Dict[str, tf.Tensor]]):
        mode = get_phase()
        features, hidden_features = inputs
        out_features = {}
        out_hidden_features = {}
        for target in self.available_extract_target:
            extracted = self.extract_feature(
                extract_problem=target,
                feature_dict=features,
                hidden_feature_dict=hidden_features,
            )
            out_features[target], out_hidden_features[target] = extracted
        return out_features, out_hidden_features
--------------------------------------------------------------------------------
/docs/css/theme-green.css:
--------------------------------------------------------------------------------
1 | .summary {
2 | color: #808080;
3 | border-left: 5px solid #E50E51;
4 | font-size:16px;
5 | }
6 |
7 |
8 | h3 {color: #E50E51; }
9 | h4 {color: #808080; }
10 |
11 | .nav-tabs > li.active > a, .nav-tabs > li.active > a:hover, .nav-tabs > li.active > a:focus {
12 | background-color: #248ec2;
13 | color: white;
14 | }
15 |
16 | .nav > li.active > a {
17 | background-color: #72ac4a;
18 | }
19 |
20 | .nav > li > a:hover {
21 | background-color: #72ac4a;
22 | }
23 |
24 | div.navbar-collapse .dropdown-menu > li > a:hover {
25 | background-color: #72ac4a;
26 | }
27 |
28 | .navbar-inverse .navbar-nav>li>a, .navbar-inverse .navbar-brand {
29 | color: white;
30 | }
31 |
32 | .navbar-inverse .navbar-nav>li>a:hover, a.fa.fa-home.fa-lg.navbar-brand:hover {
33 | color: #f0f0f0;
34 | }
35 |
36 | .nav li.thirdlevel > a {
37 | background-color: #FAFAFA !important;
38 | color: #72ac4a;
39 | font-weight: bold;
40 | }
41 |
42 | a[data-toggle="tooltip"] {
43 | color: #649345;
44 | font-style: italic;
45 | cursor: default;
46 | }
47 |
48 | .navbar-inverse {
49 | background-color: #72ac4a;
50 | border-color: #5b893c;
51 | }
52 |
53 | .navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open > a:hover, .navbar-inverse .navbar-nav > .open > a:focus {
54 | color: #5b893c;
55 | }
56 |
57 | .navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open > a:hover, .navbar-inverse .navbar-nav > .open > a:focus {
58 | background-color: #5b893c;
59 | color: #ffffff;
60 | }
61 |
62 | /* not sure if using this ...*/
63 | .navbar-inverse .navbar-collapse, .navbar-inverse .navbar-form {
64 | border-color: #72ac4a !important;
65 | }
66 |
67 | .btn-primary {
68 | color: #ffffff;
69 | background-color: #5b893c;
70 | border-color: #5b893c;
71 | }
72 |
73 | .btn-primary:hover,
74 | .btn-primary:focus,
75 | .btn-primary:active,
76 | .btn-primary.active,
77 | .open .dropdown-toggle.btn-primary {
78 | background-color: #72ac4a;
79 | border-color: #5b893c;
80 | }
81 |
82 | .printTitle {
83 | color: #5b893c !important;
84 | }
85 |
86 | body.print h1 {color: #5b893c !important; font-size:28px;}
87 | body.print h2 {color: #595959 !important; font-size:24px;}
88 | body.print h3 {color: #E50E51 !important; font-size:14px;}
89 | body.print h4 {color: #679DCE !important; font-size:14px; font-style: italic;}
90 |
91 | .anchorjs-link:hover {
92 | color: #4f7233;
93 | }
94 |
95 | div.sidebarTitle {
96 | color: #E50E51;
97 | }
98 |
99 | li.sidebarTitle {
100 | margin-top:20px;
101 | font-weight:normal;
102 | font-size:130%;
103 | color: #ED1951;
104 | margin-bottom:10px;
105 | margin-left: 5px;
106 | }
107 |
108 | .navbar-inverse .navbar-toggle:focus, .navbar-inverse .navbar-toggle:hover {
109 | background-color: #E50E51;
110 | }
111 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | ## How to get started
4 |
5 | This repository is created using `nbdev`. Please install `nbdev` with following command.
6 |
7 | ```sh
8 | pip install nbdev
9 | ```
10 |
11 | Before anything else, please install the git hooks that run automatic scripts during each commit and merge to strip the notebooks of superfluous metadata (and avoid merge conflicts). After cloning the repository, run the following command inside it:
12 | ```
13 | nbdev_install_git_hooks
14 | ```
15 |
All library code and documentation are generated from Jupyter notebooks in `source_nbs`. Any change should be made in the notebooks and then exported to the library and docs.
17 |
18 | ## Did you find a bug?
19 |
20 | * Ensure the bug was not already reported by searching under Issues.
21 | * If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring.
22 | * Be sure to add the complete error messages.
23 |
24 | #### Did you write a patch that fixes a bug?
25 |
26 | * Open a new merge request with the patch.
27 | * Ensure that your MR includes a test that fails without your patch, and pass with it.
28 | * Ensure the MR description clearly describes the problem and solution. Include the relevant issue number if applicable.
29 |
30 | ## MR submission guidelines
31 |
32 | * Keep each MR focused. While it's more convenient, do not combine several unrelated fixes together. Create as many branches as needing to keep each MR focused.
33 | * Do not mix style changes/fixes with "functional" changes. It's very difficult to review such MRs and it most likely get rejected.
34 | * Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can.
35 | * Do not turn an already submitted MR into your development playground. If after you submitted MR, you discovered that more work is needed - close the MR, do the required work and then submit a new MR. Otherwise each of your commits requires attention from maintainers of the project.
36 | * If, however, you submitted a MR and received a request for changes, you should proceed with commits inside that MR, so that the maintainer can see the incremental fixes and won't need to review the whole MR again. In the exception case where you realize it'll take many many commits to complete the requests, then it's probably best to close the MR, do the work and then submit it again. Use common sense where you'd choose one way over another.
37 |
38 | ## Do you want to contribute to the documentation?
39 |
40 | * Docs are automatically created from the notebooks in the source_nbs folder.
41 |
42 | ## Things to check before commit
43 |
44 | - `make nbbuild`: Build lib and docs
- `make check`: Make sure nbs are readable and clean
46 | - `make commit`: Make sure nbs are readable and clean and then run tests. SLOW to run
47 |
48 |
--------------------------------------------------------------------------------
/m3tl/loss_strategy/base.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/16-00_loss_combination_strategy.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['LossCombinationStrategyBase', 'SumLossCombination']
4 |
5 | # Cell
6 | from collections import deque
7 | from typing import Dict, List
8 |
9 | import tensorflow as tf
10 | from ..utils import get_phase
11 | from tensorflow.python.util.nest import (flatten,
12 | flatten_with_joined_string_paths)
13 |
14 |
class LossCombinationStrategyBase(tf.keras.Model):
    """Base class for strategies that combine per-problem losses.

    Subclasses implement `call` to turn the current nested loss dict
    (plus metrics and Keras training history) into the list of losses
    actually applied to the model.
    """

    def __init__(self, params, name: str, *args, **kwargs):
        # Pass `name` as a keyword: tf.keras.Model/Layer's first positional
        # parameter is `trainable`, so passing the name positionally would
        # silently bind the name string to `trainable` and leave the model
        # with an auto-generated name.
        super(LossCombinationStrategyBase, self).__init__(*args, name=name, **kwargs)
        self.params = params
        self.problem_list = self.params.problem_list
        # rolling windows of the most recent loss/metric dicts (max 100)
        self.hist_loss_dict = deque(maxlen=100)
        self.hist_metric_dict = deque(maxlen=100)

    def extract_loss_metric_dict_from_history(self,
                                              history: tf.keras.callbacks.History,
                                              structure: dict,
                                              prefix='val_') -> dict:
        """Pull metrics matching `structure` out of a Keras History.

        Args:
            history: History returned by `model.fit`.
            structure: nested dict whose leaf paths name the metrics wanted.
            prefix: 'val_' to read validation metrics, None for training ones.

        Returns:
            A dict shaped like `structure` with metric value lists as leaves.

        Raises:
            ValueError: if prefix is neither 'val_' nor falsy.
        """
        history: Dict[str, float] = history.history

        # metrics from the validation set are prefixed with 'val_'
        if prefix:
            if prefix != 'val_':
                raise ValueError('prefix should either be "val_" or None')
            history = {k.replace(prefix, ''): v for k, v in history.items() if k.startswith(prefix)}

        # flatten the reference structure to get the ordered metric paths
        structure_path = [p for p, _ in flatten_with_joined_string_paths(structure)]
        # pick the matching history entries and pack them back into the
        # structure's nesting
        flat_history = [history[p] for p in structure_path]
        history = tf.nest.pack_sequence_as(structure=structure, flat_sequence=flat_history)

        return history

    def get_all_losses(self, current_loss_dict: dict) -> List[tf.Tensor]:
        """Return every loss tensor in the (possibly nested) dict."""
        return flatten(current_loss_dict)

    def get_problem_loss(self, current_loss_dict: dict, problem: str) -> List[tf.Tensor]:
        """Return losses whose joined key path contains `problem`."""
        flatten_loss_with_path = flatten_with_joined_string_paths(current_loss_dict)
        return [v for p, v in flatten_loss_with_path if problem in p]

    def call(self,
             current_loss_dict: dict,
             current_metric_dict: dict,
             history: tf.keras.callbacks.History):
        # subclasses decide how losses are combined
        raise NotImplementedError
57 |
58 |
59 | # Cell
class SumLossCombination(LossCombinationStrategyBase):
    """Trivial combination strategy: hand back every loss unchanged.

    Returning the flat list of all losses lets the framework apply each
    one, which is equivalent to optimizing their sum.
    """

    def __init__(self, params, name: str, *args, **kwargs):
        super().__init__(params, name, *args, **kwargs)

    def call(self,
             current_loss_dict: dict,
             current_metric_dict: dict,
             history: tf.keras.callbacks.History):
        # phase lookup kept for parity with other strategies
        mode = get_phase()
        # flatten the nested loss dict into a single list of loss tensors
        return self.get_all_losses(current_loss_dict)
--------------------------------------------------------------------------------
/docs/_includes/sidebar.html:
--------------------------------------------------------------------------------
1 | {% assign sidebar = site.data.sidebars[page.sidebar].entries %}
2 | {% assign pageurl = page.url | remove: ".html" %}
3 |
4 |
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
/docs/css/theme-blue.css:
--------------------------------------------------------------------------------
1 | .summary {
2 | color: #808080;
3 | border-left: 5px solid #ED1951;
4 | font-size:16px;
5 | }
6 |
7 |
8 | h3 {color: #000000; }
9 | h4 {color: #000000; }
10 |
11 | .nav-tabs > li.active > a, .nav-tabs > li.active > a:hover, .nav-tabs > li.active > a:focus {
12 | background-color: #248ec2;
13 | color: white;
14 | }
15 |
16 | .nav > li.active > a {
17 | background-color: #347DBE;
18 | }
19 |
20 | .nav > li > a:hover {
21 | background-color: #248ec2;
22 | }
23 |
24 | div.navbar-collapse .dropdown-menu > li > a:hover {
25 | background-color: #347DBE;
26 | }
27 |
28 | .nav li.thirdlevel > a {
29 | background-color: #FAFAFA !important;
30 | color: #248EC2;
31 | font-weight: bold;
32 | }
33 |
34 | a[data-toggle="tooltip"] {
35 | color: #649345;
36 | font-style: italic;
37 | cursor: default;
38 | }
39 |
40 | .navbar-inverse {
41 | background-color: #347DBE;
42 | border-color: #015CAE;
43 | }
44 | .navbar-inverse .navbar-nav>li>a, .navbar-inverse .navbar-brand {
45 | color: white;
46 | }
47 |
48 | .navbar-inverse .navbar-nav>li>a:hover, a.fa.fa-home.fa-lg.navbar-brand:hover {
49 | color: #f0f0f0;
50 | }
51 |
52 | a.navbar-brand:hover {
53 | color: #f0f0f0;
54 | }
55 |
56 | .navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open > a:hover, .navbar-inverse .navbar-nav > .open > a:focus {
57 | color: #015CAE;
58 | }
59 |
60 | .navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open > a:hover, .navbar-inverse .navbar-nav > .open > a:focus {
61 | background-color: #015CAE;
62 | color: #ffffff;
63 | }
64 |
65 | .navbar-inverse .navbar-collapse, .navbar-inverse .navbar-form {
66 | border-color: #248ec2 !important;
67 | }
68 |
69 | .btn-primary {
70 | color: #ffffff;
71 | background-color: #347DBE;
72 | border-color: #347DBE;
73 | }
74 |
75 | .navbar-inverse .navbar-nav > .active > a, .navbar-inverse .navbar-nav > .active > a:hover, .navbar-inverse .navbar-nav > .active > a:focus {
76 | background-color: #347DBE;
77 | }
78 |
79 | .btn-primary:hover,
80 | .btn-primary:focus,
81 | .btn-primary:active,
82 | .btn-primary.active,
83 | .open .dropdown-toggle.btn-primary {
84 | background-color: #248ec2;
85 | border-color: #347DBE;
86 | }
87 |
88 | .printTitle {
89 | color: #015CAE !important;
90 | }
91 |
92 | body.print h1 {color: #015CAE !important; font-size:28px !important;}
93 | body.print h2 {color: #595959 !important; font-size:20px !important;}
94 | body.print h3 {color: #E50E51 !important; font-size:14px !important;}
95 | body.print h4 {color: #679DCE !important; font-size:14px; font-style: italic !important;}
96 |
97 | .anchorjs-link:hover {
98 | color: #216f9b;
99 | }
100 |
101 | div.sidebarTitle {
102 | color: #015CAE;
103 | }
104 |
105 | li.sidebarTitle {
106 | margin-top:20px;
107 | font-weight:normal;
108 | font-size:130%;
109 | color: #ED1951;
110 | margin-bottom:10px;
111 | margin-left: 5px;
112 |
113 | }
114 |
115 | .navbar-inverse .navbar-toggle:focus, .navbar-inverse .navbar-toggle:hover {
116 | background-color: #015CAE;
117 | }
118 |
119 | .navbar-inverse .navbar-toggle {
120 | border-color: #015CAE;
121 | }
122 |
--------------------------------------------------------------------------------
/docs/_includes/topnav.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
13 |
14 |
15 |
16 | Nav
17 |
18 |
19 | {% assign topnav = site.data[page.topnav] %}
20 | {% assign topnav_dropdowns = site.data[page.topnav].topnav_dropdowns %}
21 |
22 | {% for entry in topnav.topnav %}
23 | {% for item in entry.items %}
24 | {% if item.external_url %}
25 | {{item.title}}
26 | {% elsif page.url contains item.url %}
27 | {{item.title}}
28 | {% else %}
29 | {{item.title}}
30 | {% endif %}
31 | {% endfor %}
32 | {% endfor %}
33 |
34 |
35 | {% for entry in topnav_dropdowns %}
36 | {% for folder in entry.folders %}
37 |
38 | {{ folder.title }}
39 |
50 |
51 | {% endfor %}
52 | {% endfor %}
53 | {% if site.google_search %}
54 |
55 | {% include search_google_custom.html %}
56 |
57 | {% endif %}
58 |
59 |
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/docs/js/jquery.navgoco.min.js:
--------------------------------------------------------------------------------
1 | /*
2 | * jQuery Navgoco Menus Plugin v0.2.1 (2014-04-11)
3 | * https://github.com/tefra/navgoco
4 | *
5 | * Copyright (c) 2014 Chris T (@tefra)
6 | * BSD - https://github.com/tefra/navgoco/blob/master/LICENSE-BSD
7 | */
8 | !function(a){"use strict";var b=function(b,c,d){return this.el=b,this.$el=a(b),this.options=c,this.uuid=this.$el.attr("id")?this.$el.attr("id"):d,this.state={},this.init(),this};b.prototype={init:function(){var b=this;b._load(),b.$el.find("ul").each(function(c){var d=a(this);d.attr("data-index",c),b.options.save&&b.state.hasOwnProperty(c)?(d.parent().addClass(b.options.openClass),d.show()):d.parent().hasClass(b.options.openClass)?(d.show(),b.state[c]=1):d.hide()});var c=a(" ").prepend(b.options.caretHtml),d=b.$el.find("li > a");b._trigger(c,!1),b._trigger(d,!0),b.$el.find("li:has(ul) > a").prepend(c)},_trigger:function(b,c){var d=this;b.on("click",function(b){b.stopPropagation();var e=c?a(this).next():a(this).parent().next(),f=!1;if(c){var g=a(this).attr("href");f=void 0===g||""===g||"#"===g}if(e=e.length>0?e:!1,d.options.onClickBefore.call(this,b,e),!c||e&&f)b.preventDefault(),d._toggle(e,e.is(":hidden")),d._save();else if(d.options.accordion){var h=d.state=d._parents(a(this));d.$el.find("ul").filter(":visible").each(function(){var b=a(this),c=b.attr("data-index");h.hasOwnProperty(c)||d._toggle(b,!1)}),d._save()}d.options.onClickAfter.call(this,b,e)})},_toggle:function(b,c){var d=this,e=b.attr("data-index"),f=b.parent();if(d.options.onToggleBefore.call(this,b,c),c){if(f.addClass(d.options.openClass),b.slideDown(d.options.slide),d.state[e]=1,d.options.accordion){var g=d.state=d._parents(b);g[e]=d.state[e]=1,d.$el.find("ul").filter(":visible").each(function(){var b=a(this),c=b.attr("data-index");g.hasOwnProperty(c)||d._toggle(b,!1)})}}else f.removeClass(d.options.openClass),b.slideUp(d.options.slide),d.state[e]=0;d.options.onToggleAfter.call(this,b,c)},_parents:function(b,c){var d={},e=b.parent(),f=e.parents("ul");return f.each(function(){var b=a(this),e=b.attr("data-index");return e?void(d[e]=c?b:1):!1}),d},_save:function(){if(this.options.save){var b={};for(var d in 
this.state)1===this.state[d]&&(b[d]=1);c[this.uuid]=this.state=b,a.cookie(this.options.cookie.name,JSON.stringify(c),this.options.cookie)}},_load:function(){if(this.options.save){if(null===c){var b=a.cookie(this.options.cookie.name);c=b?JSON.parse(b):{}}this.state=c.hasOwnProperty(this.uuid)?c[this.uuid]:{}}},toggle:function(b){var c=this,d=arguments.length;if(1>=d)c.$el.find("ul").each(function(){var d=a(this);c._toggle(d,b)});else{var e,f={},g=Array.prototype.slice.call(arguments,1);d--;for(var h=0;d>h;h++){e=g[h];var i=c.$el.find('ul[data-index="'+e+'"]').first();if(i&&(f[e]=i,b)){var j=c._parents(i,!0);for(var k in j)f.hasOwnProperty(k)||(f[k]=j[k])}}for(e in f)c._toggle(f[e],b)}c._save()},destroy:function(){a.removeData(this.$el),this.$el.find("li:has(ul) > a").unbind("click"),this.$el.find("li:has(ul) > a > span").unbind("click")}},a.fn.navgoco=function(c){if("string"==typeof c&&"_"!==c.charAt(0)&&"init"!==c)var d=!0,e=Array.prototype.slice.call(arguments,1);else c=a.extend({},a.fn.navgoco.defaults,c||{}),a.cookie||(c.save=!1);return this.each(function(f){var g=a(this),h=g.data("navgoco");h||(h=new b(this,d?a.fn.navgoco.defaults:c,f),g.data("navgoco",h)),d&&h[c].apply(h,e)})};var c=null;a.fn.navgoco.defaults={caretHtml:"",accordion:!1,openClass:"open",save:!0,cookie:{name:"navgoco",expires:!1,path:"/"},slide:{duration:400,easing:"swing"},onClickBefore:a.noop,onClickAfter:a.noop,onToggleBefore:a.noop,onToggleAfter:a.noop}}(jQuery);
--------------------------------------------------------------------------------
/source_nbs/12_0_problem_type_utils.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# default_exp problem_types.utils\n",
10 | "%load_ext autoreload\n",
11 | "%autoreload 2\n",
12 | "import os\n",
13 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {},
19 | "source": [
20 | "# Problem Type Utils\n",
21 | "\n",
22 | "Utils to create problem types."
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "# export\n",
32 | "from typing import Dict, Tuple\n",
33 | "\n",
34 | "import tensorflow as tf\n",
35 | "from m3tl.base_params import BaseParams\n",
36 | "\n",
37 | "\n",
38 | "def empty_tensor_handling_loss(labels, logits, loss_fn):\n",
39 | " if tf.equal(tf.size(labels), 0):\n",
40 | " return 0.0\n",
41 | " if tf.equal(tf.size(tf.shape(labels)), 0):\n",
42 | " return 0.0\n",
43 | " if tf.equal(tf.shape(labels)[0], 0):\n",
44 | " return 0.0\n",
45 | " else:\n",
46 | " return tf.reduce_mean(loss_fn(\n",
47 | " labels, logits, from_logits=True))\n",
48 | "\n",
49 | "\n",
50 | "@tf.function\n",
51 | "def nan_loss_handling(loss):\n",
52 | " if tf.math.is_nan(loss):\n",
53 | " return 0.0\n",
54 | " else:\n",
55 | " return loss\n",
56 | "\n",
57 | "\n",
58 | "@tf.function\n",
59 | "def create_dummy_if_empty(inp_tensor: tf.Tensor) -> tf.Tensor:\n",
60 | " shape_tensor = tf.shape(inp_tensor)\n",
61 | " if tf.equal(shape_tensor[0], 0):\n",
62 | " data_type = inp_tensor.dtype\n",
63 | " dummy_shape_first_dim = tf.convert_to_tensor([1], dtype=tf.int32)\n",
64 | " dummy_shape = tf.concat(\n",
65 | " [dummy_shape_first_dim, shape_tensor[1:]], axis=0)\n",
66 | " dummy_tensor = tf.zeros(dummy_shape, dtype=data_type)\n",
67 | " return dummy_tensor\n",
68 | " else:\n",
69 | " return inp_tensor\n",
70 | "\n",
71 | "\n",
72 | "class BaseTop(tf.keras.Model):\n",
73 | " def __init__(self, params: BaseParams, problem_name: str) -> None:\n",
74 | " super(BaseTop, self).__init__(name=problem_name)\n",
75 | " self.params = params\n",
76 | " self.problem_name = problem_name\n",
77 | "\n",
78 | " def call(self, inputs: Tuple[Dict], mode: str):\n",
79 | " raise NotImplementedError\n",
80 | " \n",
81 | "def pad_to_shape(from_tensor: tf.Tensor, to_tensor: tf.Tensor, axis=1) -> tf.Tensor:\n",
82 | " # sometimes the length of labels dose not equal to length of inputs\n",
83 | " # that's caused by tf.data.experimental.bucket_by_sequence_length in multi problem scenario\n",
84 | " pad_len = tf.shape(input=to_tensor)[\n",
85 | " axis] - tf.shape(input=from_tensor)[axis]\n",
86 | "\n",
87 | " # top, bottom, left, right\n",
88 | " pad_tensor = [[0, 0] for _ in range(len(from_tensor.shape))]\n",
89 | " pad_tensor[axis] = [0, pad_len]\n",
90 | " from_tensor = tf.pad(tensor=from_tensor, paddings=pad_tensor)\n",
91 | " return from_tensor\n"
92 | ]
93 | }
94 | ],
95 | "metadata": {},
96 | "nbformat": 4,
97 | "nbformat_minor": 2
98 | }
99 |
--------------------------------------------------------------------------------
/baseline.md:
--------------------------------------------------------------------------------
1 | # city_cws
2 | |experiment |city_cws_Accuracy|city_cws_F1 Score|city_cws_Precision|city_cws_Recall|
3 | |-----------|----------------:|----------------:|-----------------:|--------------:|
4 | |single_task| 0.9820| 0.9783| 0.9782| 0.9783|
5 | |multitask | 0.9759| 0.9724| 0.9727| 0.9721|
6 |
7 | # boson_ner
8 | |experiment |boson_ner_Accuracy|boson_ner_F1 Score|boson_ner_Precision|boson_ner_Recall|
9 | |-----------|-----------------:|-----------------:|------------------:|---------------:|
10 | |single_task| 0.9676| 0.8316| 0.8232| 0.8402|
11 | |multitask | 0.9667| 0.8268| 0.8209| 0.8328|
12 |
13 | # CWS
14 | |experiment |CWS_Accuracy|CWS_F1 Score|CWS_Precision|CWS_Recall|
15 | |-----------|-----------:|-----------:|------------:|---------:|
16 | |single_task| 0.9660| 0.9590| 0.9592| 0.9587|
17 | |multitask | 0.9625| 0.9541| 0.9540| 0.9543|
18 |
19 | # ontonotes_ner
20 | |experiment |ontonotes_ner_Accuracy|ontonotes_ner_F1 Score|ontonotes_ner_Precision|ontonotes_ner_Recall|
21 | |-----------|---------------------:|---------------------:|----------------------:|-------------------:|
22 | |single_task| 0.9752| 0.8490| 0.8365| 0.8619|
23 | |multitask | 0.9747| 0.8403| 0.8190| 0.8627|
24 |
25 | # pku_cws
26 | |experiment |pku_cws_Accuracy|pku_cws_F1 Score|pku_cws_Precision|pku_cws_Recall|
27 | |-----------|---------------:|---------------:|----------------:|-------------:|
28 | |single_task| 0.9612| 0.9529| 0.9567| 0.9491|
29 | |multitask | 0.9575| 0.9479| 0.9512| 0.9447|
30 |
31 | # msr_cws
32 | |experiment |msr_cws_Accuracy|msr_cws_F1 Score|msr_cws_Precision|msr_cws_Recall|
33 | |-----------|---------------:|---------------:|----------------:|-------------:|
34 | |single_task| 0.9724| 0.9650| 0.9615| 0.9685|
35 | |multitask | 0.9665| 0.9572| 0.9535| 0.9610|
36 |
37 | # ctb_pos
38 | |experiment |ctb_pos_Accuracy|ctb_pos_Accuracy Per Sequence|
39 | |-----------|---------------:|----------------------------:|
40 | |single_task| 0.9617| 0.5470|
41 | |multitask | 0.9606| 0.5345|
42 |
43 | # weibo_ner
44 | |experiment |weibo_ner_Accuracy|weibo_ner_F1 Score|weibo_ner_Precision|weibo_ner_Recall|
45 | |-----------|-----------------:|-----------------:|------------------:|---------------:|
46 | |single_task| 0.9806| 0.6748| 0.7020| 0.6495|
47 | |multitask | 0.9804| 0.6729| 0.6777| 0.6682|
48 |
49 | # msra_ner
50 | |experiment |msra_ner_Accuracy|msra_ner_F1 Score|msra_ner_Precision|msra_ner_Recall|
51 | |-----------|----------------:|----------------:|-----------------:|--------------:|
52 | |single_task| 0.9947| 0.9651| 0.9665| 0.9638|
53 | |multitask | 0.9959| 0.9705| 0.9717| 0.9692|
54 |
55 |
56 | # ctb_cws
57 | |experiment |ctb_cws_Accuracy|ctb_cws_F1 Score|ctb_cws_Precision|ctb_cws_Recall|
58 | |-----------|---------------:|---------------:|----------------:|-------------:|
59 | |single_task| 0.9569| 0.9501| 0.9511| 0.9491|
60 | |multitask | 0.9572| 0.9491| 0.9494| 0.9489|
61 |
62 |
--------------------------------------------------------------------------------
/m3tl/mtl_model/mmoe.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/15-01_mtl_model_mmoe.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['MMoE']
4 |
5 | # Cell
6 | from typing import Dict, Tuple
7 |
8 | import tensorflow as tf
9 | from ..base_params import BaseParams
10 | from .base import MTLBase
11 | from ..utils import get_phase
12 |
13 |
class MMoE(MTLBase):
    """Multi-gate Mixture-of-Experts MTL model.

    A shared bank of expert transforms is applied to the base model's
    sequence output; each problem owns a softmax gate that mixes the
    experts into problem-specific hidden features.

    Reads from params (with defaults): num_experts=8, num_experts_units=128.
    """

    def __init__(self, params: BaseParams, name:str):
        super(MMoE, self).__init__(params, name)
        # number of shared experts and the width of each expert's output
        self.num_experts = self.params.get('num_experts', 8)
        self.num_experts_units = self.params.get('num_experts_units', 128)
        self.problem_list = self.params.problem_list
        # one softmax gate per problem, producing mixing weights over experts
        self.gate_dict = {
            problem: tf.keras.layers.Dense(self.num_experts, activation='softmax') for problem in self.problem_list
        }

    def build(self, input_shape):
        # input_shape mirrors the call() inputs: (features, hidden_features)
        features_input_shape, hidden_feature_input_shape = input_shape
        pooled_shape = hidden_feature_input_shape['all']['pooled']
        # applied to the last (hidden) dim of the sequence output:
        # (hidden_size, num_experts_units, num_experts)
        self.experts_kernel = self.add_weight(
            name='experts_kernel',
            shape=(pooled_shape[-1], self.num_experts_units, self.num_experts)
        )
        # add leading dims to support broadcasting over batch and seq_len
        self.experts_bias = self.add_weight(
            name='experts_bias',
            shape=(1, 1, self.num_experts_units, self.num_experts)
        )

    def call(self, inputs: Tuple[Dict[str, tf.Tensor]]):
        # phase fetched but not used below; kept for parity with other models
        mode = get_phase()
        features, hidden_features = inputs
        all_features, all_hidden_features = self.extract_feature('all', feature_dict=features, hidden_feature_dict=hidden_features)

        # get seq outputs
        # [batch_size, seq_len, hidden_size]
        seq_hidden = all_hidden_features['seq']
        # apply all experts at once via tensordot over the hidden dim
        # [batch_size, seq_len, num_expert_units, num_experts]
        experts_outputs = tf.tensordot(seq_hidden, self.experts_kernel, axes=[2, 0]) + self.experts_bias

        # NOTE(review): 'pooled' is taken as sequence position 0 —
        # presumably the [CLS] token; confirm against the base model.
        experts_output_dict = {
            'pooled': experts_outputs[:, 0, :, :],
            'seq': experts_outputs
        }

        # per problem gating
        # we can save a little bit of computation by extracting per-problem features first
        features_per_problem, hidden_features_per_problem = {}, {}
        for problem, gate_net in self.gate_dict.items():
            # expert outputs restricted to this problem's records
            features_per_problem[problem], problem_experts_output = self.extract_feature(
                extract_problem=problem, feature_dict=all_features, hidden_feature_dict=experts_output_dict
            )
            # raw hidden features for the same records, used as gate input
            _, problem_hidden_features = self.extract_feature(
                extract_problem=problem, feature_dict=all_features, hidden_feature_dict=all_hidden_features
            )

            # apply gating
            # [problem_batch_size, seq_len, 1, num_experts]
            experts_weight = gate_net(problem_hidden_features['seq'])
            experts_weight = tf.expand_dims(experts_weight, axis=2)
            # [problem_batch_size, seq_len, num_expert_units, num_experts]
            expert_output_per_problem = problem_experts_output['seq']

            # weighted mix over the expert axis; note reduce_mean (not sum),
            # so the result is scaled by 1/num_experts relative to the
            # weighted-sum form of MMoE
            # [problem_batch_size, seq_len, num_expert_units]
            gated_experts_output = tf.reduce_mean(experts_weight*expert_output_per_problem, axis=-1)
            hidden_features_per_problem[problem] = {
                'pooled': gated_experts_output[:, 0, :],
                'seq': gated_experts_output
            }

        return features_per_problem, hidden_features_per_problem
79 |
--------------------------------------------------------------------------------
/docs/_includes/initialize_shuffle.html:
--------------------------------------------------------------------------------
1 |
7 |
8 |
100 |
101 |
102 |
103 |
114 |
115 |
129 |
130 |
131 |
--------------------------------------------------------------------------------
/source_nbs/index.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "source": [
7 | "#hide\n",
8 | "from m3tl.params import Params"
9 | ],
10 | "outputs": [],
11 | "metadata": {}
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "source": [
16 | "\n",
17 | "# M3TL\n"
18 | ],
19 | "metadata": {}
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "source": [
24 | "**M**ulti-**M**odal **M**ulti-**T**ask **L**earning\n",
25 | "\n",
26 | "## Install\n",
27 | "\n",
28 | "```\n",
29 | "MASKED\n",
30 | "```\n",
31 | "\n",
32 | "## What is it\n",
33 | "\n",
34 | "This is a project that uses transformers(based on huggingface transformers) as base model to do **multi-modal multi-task learning**. \n",
35 | "\n",
36 | "## Why do I need this\n",
37 | "\n",
38 | "Multi-task learning(MTL) is gaining more and more attention, especially in deep learning era. It is widely used in NLP, CV, recommendation, etc. However, MTL usually involves complicated data preprocessing, task managing and task interaction. Other open-source projects, like TencentNLP and PyText, supports MTL but in a naive way and it's not straightforward to implement complicated MTL algorithm. In this project, we try to make writing MTL model as easy as single task learning model and further extend MTL to multi-modal multi-task learning. To do so, we expose following MTL related programable module to user:\n",
39 | "\n",
40 | "- problem sampling strategy\n",
41 | "- loss combination strategy\n",
42 | "- gradient surgery\n",
43 | "- model after base model(transformers)\n",
44 | "\n",
45 | "Apart from programable modules, we also provide various built-in SOTA MTL algorithms.\n",
46 | "\n",
47 | "In a word, you can use this project to:\n",
48 | "\n",
49 | "- implement complicated MTL algorithm\n",
50 | "- do SOTA MTL without diving into details\n",
51 | "- do multi-modal learning\n",
52 | "\n",
53 | "And since we use transformers as base model, you get all the benefits that you can get from transformers!\n",
54 | "\n",
55 | "## What type of problems are supported?"
56 | ],
57 | "metadata": {}
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "source": [
63 | "params = Params()\n",
64 | "for problem_type in params.list_available_problem_types():\n",
65 | " print('`{problem_type}`: {desc}'.format(\n",
66 | " desc=params.problem_type_desc[problem_type], problem_type=problem_type))\n"
67 | ],
68 | "outputs": [
69 | {
70 | "output_type": "stream",
71 | "name": "stdout",
72 | "text": [
73 | "`cls`: Classification\n",
74 | "`multi_cls`: Multi-Label Classification\n",
75 | "`seq_tag`: Sequence Labeling\n",
76 | "`masklm`: Masked Language Model\n",
77 | "`pretrain`: NSP+MLM(Deprecated)\n",
78 | "`regression`: Regression\n",
79 | "`vector_fit`: Vector Fitting\n",
80 | "`premask_mlm`: Pre-masked Masked Language Model\n",
81 | "`contrastive_learning`: Contrastive Learning\n"
82 | ]
83 | }
84 | ],
85 | "metadata": {}
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "source": [
90 | "\n",
91 | "## Get Started\n",
92 | "\n",
93 | "Please see tutorials.\n"
94 | ],
95 | "metadata": {}
96 | }
97 | ],
98 | "metadata": {
99 | "kernelspec": {
100 | "display_name": "Python 3",
101 | "language": "python",
102 | "name": "python3"
103 | }
104 | },
105 | "nbformat": 4,
106 | "nbformat_minor": 2
107 | }
--------------------------------------------------------------------------------
/docs/css/printstyles.css:
--------------------------------------------------------------------------------
1 |
2 | /*body.print .container {max-width: 650px;}*/
3 |
4 | body {
5 | font-size:14px;
6 | }
7 | .nav ul li a {border-top:0px; background-color:transparent; color: #808080; }
8 | #navig a[href] {color: #595959 !important;}
9 | table .table {max-width:650px;}
10 |
11 | #navig li.sectionHead {font-weight: bold; font-size: 18px; color: #595959 !important; }
12 | #navig li {font-weight: normal; }
13 |
14 | #navig a[href]::after { content: leader(".") target-counter(attr(href), page); }
15 |
16 | a[href]::after {
17 | content: " (page " target-counter(attr(href), page) ")"
18 | }
19 |
20 | a[href^="http:"]::after, a[href^="https:"]::after {
21 | content: "";
22 | }
23 |
24 | a[href] {
25 | color: blue !important;
26 | }
27 | a[href*="mailto"]::after, a[data-toggle="tooltip"]::after, a[href].noCrossRef::after {
28 | content: "";
29 | }
30 |
31 |
32 | @page {
33 | margin: 60pt 90pt 60pt 90pt;
34 | font-family: sans-serif;
35 | font-style:none;
36 | color: gray;
37 |
38 | }
39 |
40 | .printTitle {
41 | line-height:30pt;
42 | font-size:27pt;
43 | font-weight: bold;
44 | letter-spacing: -.5px;
45 | margin-bottom:25px;
46 | }
47 |
48 | .printSubtitle {
49 | font-size: 19pt;
50 | color: #cccccc !important;
51 | font-family: "Grotesque MT Light";
52 | line-height: 22pt;
53 | letter-spacing: -.5px;
54 | margin-bottom:20px;
55 | }
56 | .printTitleArea hr {
57 | color: #999999 !important;
58 | height: 2px;
59 | width: 100%;
60 | }
61 |
62 | .printTitleImage {
63 | max-width:300px;
64 | margin-bottom:200px;
65 | }
66 |
67 |
68 | .printTitleImage {
69 | max-width: 250px;
70 | }
71 |
72 | #navig {
73 | /*page-break-before: always;*/
74 | }
75 |
76 | .copyrightBoilerplate {
77 | page-break-before:always;
78 | font-size:14px;
79 | }
80 |
81 | .lastGeneratedDate {
82 | font-style: italic;
83 | font-size:14px;
84 | color: gray;
85 | }
86 |
87 | .alert a {
88 | text-decoration: none !important;
89 | }
90 |
91 |
92 | body.title { page: title }
93 |
94 | @page title {
95 | @top-left {
96 | content: " ";
97 | }
98 | @top-right {
99 | content: " "
100 | }
101 | @bottom-right {
102 | content: " ";
103 | }
104 | @bottom-left {
105 | content: " ";
106 | }
107 | }
108 |
109 | body.frontmatter { page: frontmatter }
110 | body.frontmatter {counter-reset: page 1}
111 |
112 |
113 | @page frontmatter {
114 | @top-left {
115 | content: prince-script(guideName);
116 | }
117 | @top-right {
118 | content: prince-script(datestamp);
119 | }
120 | @bottom-right {
121 | content: counter(page, lower-roman);
122 | }
123 | @bottom-left {
124 | content: "youremail@domain.com"; }
125 | }
126 |
127 | body.first_page {counter-reset: page 1}
128 |
129 | h1 { string-set: doctitle content() }
130 |
131 | @page {
132 | @top-left {
133 | content: string(doctitle);
134 | font-size: 11px;
135 | font-style: italic;
136 | }
137 | @top-right {
138 | content: prince-script(datestamp);
139 | font-size: 11px;
140 | }
141 |
142 | @bottom-right {
143 | content: "Page " counter(page);
144 | font-size: 11px;
145 | }
146 | @bottom-left {
147 | content: prince-script(guideName);
148 | font-size: 11px;
149 | }
150 | }
151 | .alert {
152 | background-color: #fafafa !important;
153 | border-color: #dedede !important;
154 | color: black;
155 | }
156 |
157 | pre {
158 | background-color: #fafafa;
159 | }
160 |
--------------------------------------------------------------------------------
/docs/_data/sidebars/home_sidebar.yml:
--------------------------------------------------------------------------------
1 |
2 | #################################################
3 | ### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ###
4 | #################################################
5 | # Instead edit ../../sidebar.json
6 | entries:
7 | - folders:
8 | - folderitems:
9 | - output: web,pdf
10 | title: Overview
11 | url: /
12 | - output: web,pdf
13 | title: Tutorial
14 | url: tutorial.html
15 | subfolders:
16 | - output: web
17 | subfolderitems:
18 | - output: web,pdf
19 | title: Classification
20 | url: 1_problem_type_cls.html
21 | - output: web,pdf
22 | title: Multi-Label Classification
23 | url: 2_problem_type_multi_cls.html
24 | - output: web,pdf
25 | title: Sequence Labeling
26 | url: 3_problem_type_seq_tag.html
27 | - output: web,pdf
28 | title: Masked Language Model
29 | url: 4_problem_type_masklm.html
30 | - output: web,pdf
31 | title: NSP+MLM(Deprecated)
32 | url: 6_problem_type_pretrain.html
33 | - output: web,pdf
34 | title: Regression
35 | url: 7_problem_type_regression.html
36 | - output: web,pdf
37 | title: Vector Fitting
38 | url: 8_problem_type_vector_fit.html
39 | - output: web,pdf
40 | title: Pre-masked Masked Language Model
41 | url: 9_problem_type_premask_mlm.html
42 | - output: web,pdf
43 | title: Contrastive Learning
44 | url: 10_problem_type_contrast_learning.html
45 | title: Problem Types
46 | - output: web
47 | subfolderitems:
48 | - output: web,pdf
49 | title: MTLBase
50 | url: 15-00_mtl_model_base.html
51 | - output: web,pdf
52 | title: MMoE
53 | url: 15-01_mtl_model_mmoe.html
54 | title: MTL Models
55 | - output: web
56 | subfolderitems:
57 | - output: web,pdf
58 | title: LossCombinationStrategyBase
59 | url: 16-00_loss_combination_strategy.html
60 | title: Loss Combination Strategy
61 | - output: web
62 | subfolderitems:
63 | - output: web,pdf
64 | title: EmbeddingLayerBase
65 | url: 18-00_embedding_layer.html
66 | title: Embedding Layers
67 | - output: web,pdf
68 | title: Params
69 | url: params.html
70 | - output: web,pdf
71 | title: Run Bert Multitask Learning
72 | url: run_bert_multitask.html
73 | - output: web,pdf
74 | title: Multitask Learning Model
75 | url: model_fn.html
76 | - output: web,pdf
77 | title: Utils
78 | url: utils.html
79 | - output: web,pdf
80 | title: Special Tokens
81 | url: special_tokens.html
82 | - output: web,pdf
83 | title: Bert Utils
84 | url: bert_utils.html
85 | - output: web,pdf
86 | title: Create Bert Features
87 | url: create_bert_features.html
88 | - output: web,pdf
89 | title: Preprocessing Decorator
90 | url: preproc_decorator.html
91 | - output: web,pdf
92 | title: Read and Write TFRecord
93 | url: read_write_tfrecord.html
94 | - output: web,pdf
95 | title: Function to Create Datasets
96 | url: input_fn.html
97 | - output: web,pdf
98 | title: Pre-defined Problems
99 | url: predefined_problems_test.html
100 | - output: web,pdf
101 | title: Body Modeling
102 | url: modeling.html
103 | - output: web,pdf
104 | title: Top Models
105 | url: top.html
106 | - output: web,pdf
107 | title: Test Base
108 | url: test_base.html
109 | output: web
110 | title: m3tl
111 | output: web
112 | title: Sidebar
113 |
--------------------------------------------------------------------------------
/docs/0_base_params.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Params
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 |
10 |
11 | nb_path: "source_nbs/00_0_base_params.ipynb"
12 | ---
13 |
22 |
23 |
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 | {% raw %}
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
BaseParams()
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | {% endraw %}
55 |
56 | {% raw %}
57 |
58 |
78 | {% endraw %}
79 |
80 | {% raw %}
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
StaticBatchParams() :: BaseParams
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 | {% endraw %}
103 |
104 | {% raw %}
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
DynamicBatchSizeParams() :: BaseParams
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 | {% endraw %}
127 |
128 | {% raw %}
129 |
130 |
131 |
132 |
133 | {% endraw %}
134 |
135 |
136 |
137 |
138 |
--------------------------------------------------------------------------------
/m3tl/problem_types/cls.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/12_1_problem_type_cls.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['Classification', 'cls_get_or_make_label_encoder_fn', 'cls_label_handling_fn']
4 |
5 | # Cell
6 | from functools import partial
7 | from typing import List
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 | from ..base_params import BaseParams
12 | from .utils import (empty_tensor_handling_loss,
13 | nan_loss_handling)
14 | from ..special_tokens import PREDICT, TRAIN
15 | from ..utils import (LabelEncoder, get_label_encoder_save_path, get_phase,
16 | need_make_label_encoder, variable_summaries)
17 |
18 |
19 | # Cell
20 |
class Classification(tf.keras.layers.Layer):
    """Classification top layer.

    Projects the pooled hidden feature to ``num_classes`` logits, adds a
    label-smoothed categorical cross-entropy loss plus a sparse accuracy
    metric during training/eval, and returns softmax probabilities.
    """

    def __init__(self, params: BaseParams, problem_name: str) -> None:
        super(Classification, self).__init__(name=problem_name)
        self.params = params
        self.problem_name = problem_name
        self.num_classes = self.params.get_problem_info(
            problem=problem_name, info_name='num_classes')
        self.dense = tf.keras.layers.Dense(self.num_classes, activation=None)
        self.metric_fn = tf.keras.metrics.SparseCategoricalAccuracy(
            name='{}_acc'.format(self.problem_name))
        # keras Dropout takes a drop rate; params stores a keep probability
        self.dropout = tf.keras.layers.Dropout(1 - params.dropout_keep_prob)

    def call(self, inputs):
        phase = get_phase()
        is_training = (phase == TRAIN)
        feature_dict, hidden_dict = inputs
        pooled = hidden_dict['pooled']

        # labels are only present outside of predict mode
        label_ids = None
        if phase != PREDICT:
            label_ids = feature_dict['{}_label_ids'.format(self.problem_name)]

        pooled = self.dropout(pooled, is_training)
        class_logits = self.dense(pooled)

        if self.params.detail_log:
            for weight_var in self.weights:
                variable_summaries(weight_var, self.problem_name)

        if phase != PREDICT:
            # one-hot so that label_smoothing can be applied
            one_hot = tf.one_hot(label_ids, depth=self.num_classes)
            smoothed_ce = partial(
                tf.keras.losses.categorical_crossentropy,
                from_logits=True,
                label_smoothing=self.params.label_smoothing)
            batch_loss = empty_tensor_handling_loss(
                one_hot, class_logits, smoothed_ce)
            self.add_loss(nan_loss_handling(batch_loss))
            self.add_metric(self.metric_fn(label_ids, class_logits))

        return tf.nn.softmax(
            class_logits, name='%s_predict' % self.problem_name)
67 |
68 | # Cell
def cls_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs) -> LabelEncoder:
    """Create (and persist) or load the ``LabelEncoder`` for a classification problem.

    Args:
        params: params object; used to resolve the encoder save path and to
            record ``num_classes`` for the problem.
        problem: problem name the encoder belongs to.
        mode: run mode; together with ``overwrite`` decides whether a new
            encoder is fitted or a saved one is loaded.
        label_list: raw labels to fit the encoder on when (re)creating it.
        **kwargs: ``overwrite`` (bool, default False) — refit and overwrite an
            existing saved encoder.

    Returns:
        A fitted ``LabelEncoder``.
    """
    le_path = get_label_encoder_save_path(params=params, problem=problem)
    label_encoder = LabelEncoder()
    # robustness fix: `kwargs['overwrite']` raised KeyError when the caller
    # did not pass it; default to not overwriting an existing encoder.
    overwrite = kwargs.get('overwrite', False)
    if need_make_label_encoder(mode=mode, le_path=le_path, overwrite=overwrite):
        # fit and save label encoder, then record the class count on params
        label_encoder.fit(label_list)
        label_encoder.dump(le_path)
        params.set_problem_info(problem=problem, info_name='num_classes',
                                info=len(label_encoder.encode_dict))
    else:
        # reuse the previously fitted encoder
        label_encoder.load(le_path)

    return label_encoder
82 |
83 | # Cell
def cls_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):
    """Encode a single classification target into an int32 label id.

    Returns a ``(label_id, None)`` tuple; the second element is unused for
    classification problems.
    """
    encoded = label_encoder.transform([target]).tolist()
    return np.int32(encoded[0]), None
88 |
89 |
--------------------------------------------------------------------------------
/docs/css/syntax.css:
--------------------------------------------------------------------------------
1 | .highlight { background: #ffffff; }
2 | .highlight .c { color: #999988; font-style: italic } /* Comment */
3 | .highlight .err { color: #a61717; background-color: #e3d2d2 } /* Error */
4 | .highlight .k { font-weight: bold } /* Keyword */
5 | .highlight .o { font-weight: bold } /* Operator */
6 | .highlight .cm { color: #999988; font-style: italic } /* Comment.Multiline */
7 | .highlight .cp { color: #999999; font-weight: bold } /* Comment.Preproc */
8 | .highlight .c1 { color: #999988; font-style: italic } /* Comment.Single */
9 | .highlight .cs { color: #999999; font-weight: bold; font-style: italic } /* Comment.Special */
10 | .highlight .gd { color: #000000; background-color: #ffdddd } /* Generic.Deleted */
11 | .highlight .gd .x { color: #000000; background-color: #ffaaaa } /* Generic.Deleted.Specific */
12 | .highlight .ge { font-style: italic } /* Generic.Emph */
13 | .highlight .gr { color: #aa0000 } /* Generic.Error */
14 | .highlight .gh { color: #999999 } /* Generic.Heading */
15 | .highlight .gi { color: #000000; background-color: #ddffdd } /* Generic.Inserted */
16 | .highlight .gi .x { color: #000000; background-color: #aaffaa } /* Generic.Inserted.Specific */
17 | .highlight .go { color: #888888 } /* Generic.Output */
18 | .highlight .gp { color: #555555 } /* Generic.Prompt */
19 | .highlight .gs { font-weight: bold } /* Generic.Strong */
20 | .highlight .gu { color: #aaaaaa } /* Generic.Subheading */
21 | .highlight .gt { color: #aa0000 } /* Generic.Traceback */
22 | .highlight .kc { font-weight: bold } /* Keyword.Constant */
23 | .highlight .kd { font-weight: bold } /* Keyword.Declaration */
24 | .highlight .kp { font-weight: bold } /* Keyword.Pseudo */
25 | .highlight .kr { font-weight: bold } /* Keyword.Reserved */
26 | .highlight .kt { color: #445588; font-weight: bold } /* Keyword.Type */
27 | .highlight .m { color: #009999 } /* Literal.Number */
28 | .highlight .s { color: #d14 } /* Literal.String */
29 | .highlight .na { color: #008080 } /* Name.Attribute */
30 | .highlight .nb { color: #0086B3 } /* Name.Builtin */
31 | .highlight .nc { color: #445588; font-weight: bold } /* Name.Class */
32 | .highlight .no { color: #008080 } /* Name.Constant */
33 | .highlight .ni { color: #800080 } /* Name.Entity */
34 | .highlight .ne { color: #990000; font-weight: bold } /* Name.Exception */
35 | .highlight .nf { color: #990000; font-weight: bold } /* Name.Function */
36 | .highlight .nn { color: #555555 } /* Name.Namespace */
37 | .highlight .nt { color: #000080 } /* Name.Tag */
38 | .highlight .nv { color: #008080 } /* Name.Variable */
39 | .highlight .ow { font-weight: bold } /* Operator.Word */
40 | .highlight .w { color: #bbbbbb } /* Text.Whitespace */
41 | .highlight .mf { color: #009999 } /* Literal.Number.Float */
42 | .highlight .mh { color: #009999 } /* Literal.Number.Hex */
43 | .highlight .mi { color: #009999 } /* Literal.Number.Integer */
44 | .highlight .mo { color: #009999 } /* Literal.Number.Oct */
45 | .highlight .sb { color: #d14 } /* Literal.String.Backtick */
46 | .highlight .sc { color: #d14 } /* Literal.String.Char */
47 | .highlight .sd { color: #d14 } /* Literal.String.Doc */
48 | .highlight .s2 { color: #d14 } /* Literal.String.Double */
49 | .highlight .se { color: #d14 } /* Literal.String.Escape */
50 | .highlight .sh { color: #d14 } /* Literal.String.Heredoc */
51 | .highlight .si { color: #d14 } /* Literal.String.Interpol */
52 | .highlight .sx { color: #d14 } /* Literal.String.Other */
53 | .highlight .sr { color: #009926 } /* Literal.String.Regex */
54 | .highlight .s1 { color: #d14 } /* Literal.String.Single */
55 | .highlight .ss { color: #990073 } /* Literal.String.Symbol */
56 | .highlight .bp { color: #999999 } /* Name.Builtin.Pseudo */
57 | .highlight .vc { color: #008080 } /* Name.Variable.Class */
58 | .highlight .vg { color: #008080 } /* Name.Variable.Global */
59 | .highlight .vi { color: #008080 } /* Name.Variable.Instance */
60 | .highlight .il { color: #009999 } /* Literal.Number.Integer.Long */
--------------------------------------------------------------------------------
/docs/js/toc.js:
--------------------------------------------------------------------------------
1 | // https://github.com/ghiculescu/jekyll-table-of-contents
2 | // this library modified by fastai to:
3 | // - update the location.href with the correct anchor when a toc item is clicked on
// NOTE(review): the HTML fragments concatenated into `html` below appear to
// have had their tags stripped by whatever extracted this file (e.g.
// `html += "" + txt + " "` was presumably `html += "<li><a ...>" + txt + "</a></li>"`).
// Code is kept byte-identical here; confirm against the upstream
// jekyll-table-of-contents source before relying on the generated markup.
(function($){
  // jQuery plugin: builds a table of contents from the page's headers into
  // the selected element.
  $.fn.toc = function(options) {
    var defaults = {
      noBackToTopLinks: false,
      title: '',
      minimumHeaders: 3, // render nothing for fewer headers than this
      headers: 'h1, h2, h3, h4', // selector for headers to index
      listType: 'ol', // values: [ol|ul]
      showEffect: 'show', // values: [show|slideDown|fadeIn|none]
      showSpeed: 'slow' // set to 0 to deactivate effect
    },
    settings = $.extend(defaults, options);

    var headers = $(settings.headers).filter(function() {
      // get all headers with an ID; if a header lacks one, adopt the `name`
      // attribute of its previous sibling (dots replaced to form a valid id)
      var previousSiblingName = $(this).prev().attr( "name" );
      if (!this.id && previousSiblingName) {
        this.id = $(this).attr( "id", previousSiblingName.replace(/\./g, "-") );
      }
      return this.id;
    }), output = $(this);
    // bail out silently when there is nothing (or too little) to index
    if (!headers.length || headers.length < settings.minimumHeaders || !output.length) {
      return;
    }

    if (0 === settings.showSpeed) {
      settings.showEffect = 'none';
    }

    // strategies for revealing the generated TOC inside the output element
    var render = {
      show: function() { output.hide().html(html).show(settings.showSpeed); },
      slideDown: function() { output.hide().html(html).slideDown(settings.showSpeed); },
      fadeIn: function() { output.hide().html(html).fadeIn(settings.showSpeed); },
      none: function() { output.html(html); }
    };

    // numeric header level from the tag name, e.g. "H2" -> 2
    var get_level = function(ele) { return parseInt(ele.nodeName.replace("H", ""), 10); }
    var highest_level = headers.map(function(_, ele) { return get_level(ele); }).get().sort()[0];
    //var return_to_top = ' ';
    // other nice icons that can be used instead: glyphicon-upload glyphicon-hand-up glyphicon-chevron-up glyphicon-menu-up glyphicon-triangle-top
    var level = get_level(headers[0]),
      this_level,
      html = settings.title + " <"+settings.listType+">";
    // fastai modification: clicking a header sets location.hash to its id
    // while restoring the scroll position so the page does not jump
    headers.on('click', function() {
      if (!settings.noBackToTopLinks) {
        var pos = $(window).scrollTop();
        window.location.hash = this.id;
        $(window).scrollTop(pos);
      }
    })
    .addClass('clickable-header')
    .each(function(_, header) {
      base_url = window.location.href; // NOTE(review): implicit globals (base_url, txt, i)
      base_url = base_url.replace(/#.*$/, "");
      this_level = get_level(header);
      //if (!settings.noBackToTopLinks && this_level > 1) {
      //  $(header).addClass('top-level-header').before(return_to_top);
      //}
      // strip pilcrow anchors and fastai "[test]"/"[source]" suffixes from the text
      txt = header.textContent.split('¶')[0].split(/\[(test|source)\]/)[0];
      if (!txt) {return;}
      if (this_level === level) // same level as before; same indenting
        html += "" + txt + " ";
      else if (this_level <= level){ // higher level than before; end parent ol
        for(i = this_level; i < level; i++) {
          html += " "+settings.listType+">"
        }
        html += "" + txt + " ";
      }
      else if (this_level > level) { // lower level than before; expand the previous to contain a ol
        for(i = this_level; i > level; i--) {
          html += "<"+settings.listType+">"+((i-level == 2) ? "" : " ")
        }
        html += "" + txt + " ";
      }
      level = this_level; // update for the next one
    });
    html += ""+settings.listType+">";
    if (!settings.noBackToTopLinks) {
      $(document).on('click', '.back-to-top', function() {
        $(window).scrollTop(0);
        window.location.hash = '';
      });
    }

    render[settings.showEffect]();
  };
})(jQuery);
91 |
--------------------------------------------------------------------------------
/docs/_layouts/default.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | {% include head.html %}
5 |
41 |
46 |
57 | {% if page.datatable == true %}
58 |
59 |
60 |
61 |
66 |
76 | {% endif %}
77 |
78 |
79 |
80 | {% include topnav.html %}
81 |
82 |
83 |
84 |
85 |
86 | {% assign content_col_size = "col-md-12" %}
87 | {% unless page.hide_sidebar %}
88 |
89 |
92 | {% assign content_col_size = "col-md-9" %}
93 | {% endunless %}
94 |
95 |
96 |
97 | {{content}}
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 | {% if site.google_analytics %}
108 | {% include google_analytics.html %}
109 | {% endif %}
110 |
111 |
--------------------------------------------------------------------------------
/m3tl/problem_types/multi_cls.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/12_2_problem_type_multi_cls.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['MultiLabelClassification', 'multi_cls_get_or_make_label_encoder_fn', 'multi_cls_label_handling_fn']
4 |
5 | # Cell
6 | import pickle
7 | from typing import List
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 | from ..base_params import BaseParams
12 | from .utils import (empty_tensor_handling_loss,
13 | nan_loss_handling)
14 | from ..special_tokens import PREDICT, TRAIN
15 | from ..utils import (LabelEncoder, get_label_encoder_save_path, get_phase,
16 | need_make_label_encoder, variable_summaries)
17 | from sklearn.preprocessing import MultiLabelBinarizer
18 |
19 |
20 | # Cell
21 |
class MultiLabelClassification(tf.keras.Model):
    """Multi-label classification top model.

    Projects the pooled hidden feature to one logit per class, adds a
    positively-weighted sigmoid cross-entropy loss during training/eval,
    and returns per-class sigmoid probabilities.
    """

    def __init__(self, params: BaseParams, problem_name: str) -> None:
        super(MultiLabelClassification, self).__init__(name=problem_name)
        self.params = params
        self.problem_name = problem_name
        self.dense = tf.keras.layers.Dense(
            self.params.get_problem_info(problem=problem_name, info_name='num_classes'))
        # keras Dropout takes a drop rate; params stores a keep probability
        self.dropout = tf.keras.layers.Dropout(1 - self.params.dropout_keep_prob)
        # NOTE: a macro-F1 metric (tfa.metrics.F1Score with
        # params.multi_cls_threshold) is intentionally disabled here.

    def call(self, inputs):
        phase = get_phase()
        is_training = (phase == TRAIN)
        feature_dict, hidden_dict = inputs
        pooled = hidden_dict['pooled']

        # labels are only present outside of predict mode
        label_ids = None
        if phase != PREDICT:
            label_ids = feature_dict['{}_label_ids'.format(self.problem_name)]

        pooled = self.dropout(pooled, is_training)
        logits = self.dense(pooled)

        if self.params.detail_log:
            for weight_var in self.weights:
                variable_summaries(weight_var, self.problem_name)

        if phase != PREDICT:
            float_labels = tf.cast(label_ids, tf.float32)
            # weighted loss: up-weight positive labels
            pos_weight = self.params.multi_cls_positive_weight

            def _weighted_ce(y_true, y_pred, from_logits=True):
                return tf.nn.weighted_cross_entropy_with_logits(
                    y_true, y_pred, pos_weight=pos_weight,
                    name='{}_loss'.format(self.problem_name))

            batch_loss = empty_tensor_handling_loss(
                float_labels, logits, _weighted_ce)
            self.add_loss(nan_loss_handling(batch_loss))

        return tf.nn.sigmoid(
            logits, name='%s_predict' % self.problem_name)
72 |
73 | # Cell
def multi_cls_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs) -> LabelEncoder:
    """Create (and pickle) or unpickle the ``MultiLabelBinarizer`` for a
    multi-label classification problem.

    Args:
        params: params object; used to resolve the encoder save path and to
            record ``num_classes`` for the problem.
        problem: problem name the encoder belongs to.
        mode: run mode; together with ``overwrite`` decides whether a new
            binarizer is fitted or a saved one is loaded.
        label_list: label collections to fit the binarizer on when (re)creating it.
        **kwargs: ``overwrite`` (bool, default False) — refit and overwrite an
            existing saved encoder.

    Returns:
        A fitted ``MultiLabelBinarizer``.
    """
    le_path = get_label_encoder_save_path(params=params, problem=problem)
    # robustness fix: `kwargs['overwrite']` raised KeyError when the caller
    # did not pass it; default to not overwriting an existing encoder.
    overwrite = kwargs.get('overwrite', False)

    if need_make_label_encoder(mode=mode, le_path=le_path, overwrite=overwrite):
        # fit and save label encoder
        label_encoder = MultiLabelBinarizer()
        label_encoder.fit(label_list)
        # resource fix: previously the file objects passed to pickle were
        # never closed; close them deterministically with a context manager.
        with open(le_path, 'wb') as le_file:
            pickle.dump(label_encoder, le_file)
        params.set_problem_info(problem=problem, info_name='num_classes',
                                info=label_encoder.classes_.shape[0])
    else:
        with open(le_path, 'rb') as le_file:
            label_encoder = pickle.load(le_file)

    return label_encoder
88 |
89 | # Cell
def multi_cls_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):
    """Binarize a multi-label target into an int32 indicator vector.

    Returns a ``(label_id, None)`` tuple; the second element is unused for
    multi-label classification problems.
    """
    binarized = label_encoder.transform([target])
    return np.int32(binarized[0]), None
94 |
95 |
--------------------------------------------------------------------------------
/docs/_includes/head.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | {{ page.title }} | {{ site.site_title }}
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | {% if site.use_math %}
25 |
26 |
27 |
28 |
39 | {% endif %}
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
56 |
57 |
58 |
59 |
60 | {% if site.twitter_username %}
61 |
62 |
63 |
64 | {% endif %}
65 |
66 | {% if page.summary %}
67 |
68 | {% else %}
69 |
70 | {% endif %}
71 |
72 | {% if page.image %}
73 |
74 |
75 | {% else %}
76 |
77 |
78 | {% endif %}
79 |
80 |
81 |
82 |
83 |
84 |
--------------------------------------------------------------------------------
/docs/test_base.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Test Base
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 |
10 |
11 | nb_path: "source_nbs/99_test_base.ipynb"
12 | ---
13 |
22 |
23 |
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 | {% raw %}
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
TestBase()
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | {% endraw %}
55 |
56 | {% raw %}
57 |
58 |
59 |
60 |
61 | {% endraw %}
62 |
63 | {% raw %}
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
PysparkTestBase() :: TestBase
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 | {% endraw %}
86 |
87 | {% raw %}
88 |
89 |
90 |
91 |
92 | {% endraw %}
93 |
94 | {% raw %}
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
create_dummy_features_hidden_features(batch_size =1 , hidden_dim =768 , sample_features :dict=None , problem :str=None )
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 | {% endraw %}
117 |
118 | {% raw %}
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
test_top_layer(top_class , problem :str, params :Params , sample_features :dict, hidden_dim :int, test_batch_size_list :list=None , **kwargs )
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 | {% endraw %}
141 |
142 | {% raw %}
143 |
144 |
145 |
146 |
147 | {% endraw %}
148 |
149 |
150 |
151 |
152 |
--------------------------------------------------------------------------------
/m3tl/modeling.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/11_modeling.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['MultiModalBertModel']
4 |
5 | # Cell
6 | # nbdev_comment from __future__ import absolute_import, division, print_function
7 |
8 | import json
9 |
10 | import tensorflow as tf
11 | import transformers
12 | from loguru import logger
13 | from .params import Params
14 | from .utils import (get_embedding_table_from_model,
15 | get_shape_list, load_transformer_model)
16 | from .embedding_layer.base import DefaultMultimodalEmbedding
17 |
18 |
class MultiModalBertModel(tf.keras.Model):
    """Wraps a huggingface transformer so it can consume multimodal inputs.

    Raw inputs are first converted to dense embeddings by the embedding layer
    configured in ``params.embedding_layer['model']``; the transformer is then
    called with ``inputs_embeds`` instead of token ids. Intermediate tensors
    (pooled output, per-layer hidden states, input mask, etc.) are stored on
    ``self`` and exposed through the ``get_*`` accessors below.
    """

    def __init__(self, params: Params, use_one_hot_embeddings=False):
        # NOTE(review): `use_one_hot_embeddings` is stored but not consumed in
        # this class's visible code — presumably read elsewhere; confirm.
        super(MultiModalBertModel, self).__init__()
        self.params = params
        if self.params.init_weight_from_huggingface:
            # load pretrained weights by model name
            self.bert_model = load_transformer_model(
                self.params.transformer_model_name, self.params.transformer_model_loading)
        else:
            # build from a config object instead of pretrained weights
            self.bert_model = load_transformer_model(
                self.params.bert_config, self.params.transformer_model_loading)
        # run one dummy forward pass so the transformer's variables get built
        self.bert_model(tf.convert_to_tensor(
            transformers.file_utils.DUMMY_INPUTS))
        self.use_one_hot_embeddings = use_one_hot_embeddings

        # multimodal input dense: reuse the transformer's own input embedding
        # table inside the configurable multimodal embedding layer
        self.embedding_layer = self.bert_model.get_input_embeddings()
        self.multimoda_embedding = self.params.embedding_layer['model'](
            params=self.params, embedding_layer=self.embedding_layer)

    @tf.function
    def call(self, inputs, training=False):
        """Run the multimodal embedding then the transformer.

        Returns:
            Tuple ``(emb_inputs, outputs)`` where ``outputs`` is a dict of
            tensors (last hidden state, pooler output, stacked encoder layers,
            input mask, token type ids, embedding output/table).
        """
        emb_inputs, embedding_tup = self.multimoda_embedding(inputs, training)
        self.embedding_output = embedding_tup.word_embedding
        self.model_input_mask = embedding_tup.res_input_mask
        self.model_token_type_ids = embedding_tup.res_segment_ids

        # feed pre-computed embeddings; input_ids/position_ids must be None
        outputs = self.bert_model(
            {'input_ids': None,
             'inputs_embeds': self.embedding_output,
             'attention_mask': self.model_input_mask,
             'token_type_ids': self.model_token_type_ids,
             'position_ids': None},
            training=training
        )
        self.sequence_output = outputs.last_hidden_state
        if 'pooler_output' in outputs:
            self.pooled_output = outputs.pooler_output
        else:
            # no pooled output, use mean of token embedding
            self.pooled_output = tf.reduce_mean(
                outputs.last_hidden_state, axis=1)
            outputs['pooler_output'] = self.pooled_output
        # stack per-layer hidden states along axis 1
        self.all_encoder_layers = tf.stack(outputs.hidden_states, axis=1)
        # drop the bulky per-layer fields from the returned dict; they are
        # re-exposed below in the project's own naming scheme
        outputs = {k: v for k, v in outputs.items() if k not in (
            'hidden_states', 'attentions')}
        outputs['model_input_mask'] = self.model_input_mask
        outputs['model_token_type_ids'] = self.model_token_type_ids
        outputs['all_encoder_layers'] = self.all_encoder_layers
        outputs['embedding_output'] = self.embedding_output
        outputs['embedding_table'] = self.embedding_layer.weights[0]
        return emb_inputs, outputs

    def get_pooled_output(self):
        # pooler output (or mean-pooled last hidden state) from the last call
        return self.pooled_output

    def get_sequence_output(self):
        """Gets final hidden layer of encoder.

        Returns:
          float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
          to the final hidden of the transformer encoder.
        """
        return self.sequence_output

    def get_all_encoder_layers(self):
        # all encoder hidden states stacked along axis 1 from the last call
        return self.all_encoder_layers

    def get_embedding_output(self):
        """Gets output of the embedding lookup (i.e., input to the transformer).

        Returns:
          float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
          to the output of the embedding layer, after summing the word
          embeddings with the positional embeddings and the token type embeddings,
          then performing layer normalization. This is the input to the transformer.
        """
        return self.embedding_output

    def get_embedding_table(self):
        # word embedding table resolved from the wrapped transformer model
        return get_embedding_table_from_model(self.bert_model)

    def get_input_mask(self):
        # attention mask produced by the multimodal embedding in the last call
        return self.model_input_mask

    def get_token_type_ids(self):
        # token type ids produced by the multimodal embedding in the last call
        return self.model_token_type_ids
105 |
--------------------------------------------------------------------------------
/docs/bert_utils.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Bert Utils
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 |
10 |
11 | nb_path: "source_nbs/03_bert_utils.ipynb"
12 | ---
13 |
22 |
23 |
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 | {% raw %}
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
truncate_seq_pair(tokens_a , tokens_b , target , max_length , rng =None , is_seq =False )
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | {% endraw %}
55 |
56 | {% raw %}
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
punc_augument(raw_inputs , params )
68 |
69 |
This code is dedicated in memory of a special time.
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 | {% endraw %}
80 |
81 | {% raw %}
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
create_instances_from_document(all_documents , document_index , max_seq_length , short_seq_prob , masked_lm_prob , max_predictions_per_seq , vocab_words , rng )
93 |
94 |
Creates TrainingInstance s for a single document.
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 | {% endraw %}
105 |
106 | {% raw %}
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
118 |
119 |
Creates the predictions for the masked LM objective.
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 | {% endraw %}
130 |
131 | {% raw %}
132 |
133 |
134 |
135 |
136 | {% endraw %}
137 |
138 |
139 |
140 |
141 |
--------------------------------------------------------------------------------
/m3tl/params.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/00_1_params.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['Params']
4 |
5 | # Cell
6 |
7 | from .base_params import BaseParams
8 | from .embedding_layer.base import (DefaultMultimodalEmbedding,
9 | DuplicateAugMultimodalEmbedding)
10 | from .loss_strategy.base import SumLossCombination
11 | from .mtl_model.mmoe import MMoE
12 | from .problem_types import cls as problem_type_cls
13 | from .problem_types import (contrastive_learning, masklm, multi_cls,
14 | premask_mlm, pretrain, regression, seq_tag,
15 | vector_fit)
16 |
17 |
class Params(BaseParams):
    """Default parameter object with every built-in problem type, MTL model,
    loss-combination strategy and embedding layer pre-registered."""

    def __init__(self):
        super().__init__()

        # (problem_type, top_layer, label_handling_fn,
        #  get_or_make_label_encoder_fn, description)
        predefined_problem_types = [
            ('cls', problem_type_cls.Classification,
             problem_type_cls.cls_label_handling_fn,
             problem_type_cls.cls_get_or_make_label_encoder_fn,
             'Classification'),
            ('multi_cls', multi_cls.MultiLabelClassification,
             multi_cls.multi_cls_label_handling_fn,
             multi_cls.multi_cls_get_or_make_label_encoder_fn,
             'Multi-Label Classification'),
            ('seq_tag', seq_tag.SequenceLabel,
             seq_tag.seq_tag_label_handling_fn,
             seq_tag.seq_tag_get_or_make_label_encoder_fn,
             'Sequence Labeling'),
            ('masklm', masklm.MaskLM,
             masklm.masklm_label_handling_fn,
             masklm.masklm_get_or_make_label_encoder_fn,
             'Masked Language Model'),
            ('pretrain', pretrain.PreTrain,
             pretrain.pretrain_label_handling_fn,
             pretrain.pretrain_get_or_make_label_encoder_fn,
             'NSP+MLM(Deprecated)'),
            ('regression', regression.Regression,
             regression.regression_label_handling_fn,
             regression.regression_get_or_make_label_encoder_fn,
             'Regression'),
            ('vector_fit', vector_fit.VectorFit,
             vector_fit.vector_fit_label_handling_fn,
             vector_fit.vector_fit_get_or_make_label_encoder_fn,
             'Vector Fitting'),
            ('premask_mlm', premask_mlm.PreMaskMLM,
             premask_mlm.premask_mlm_label_handling_fn,
             premask_mlm.premask_mlm_get_or_make_label_encoder_fn,
             'Pre-masked Masked Language Model'),
            ('contrastive_learning', contrastive_learning.ContrastiveLearning,
             contrastive_learning.contrastive_learning_label_handling_fn,
             contrastive_learning.contrastive_learning_get_or_make_label_encoder_fn,
             'Contrastive Learning'),
        ]
        # register pre-defined problem types
        for (problem_type, top_layer, label_handling_fn,
             get_or_make_label_encoder_fn, description) in predefined_problem_types:
            self.register_problem_type(
                problem_type=problem_type,
                top_layer=top_layer,
                label_handling_fn=label_handling_fn,
                get_or_make_label_encoder_fn=get_or_make_label_encoder_fn,
                description=description)

        # built-in MTL model, loss combination strategy and embedding layers
        self.register_mtl_model(
            'mmoe', MMoE, include_top=False, extra_info='MMoE')
        self.register_loss_combination_strategy('sum', SumLossCombination)
        self.register_embedding_layer(
            'duplicate_data_augmentation_embedding', DuplicateAugMultimodalEmbedding)
        self.register_embedding_layer(
            'default_embedding', DefaultMultimodalEmbedding)

        # defaults
        self.assign_loss_combination_strategy('sum')
        self.assign_data_sampling_strategy()
        self.assign_embedding_layer('default_embedding')
84 |
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: M3TL
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 |
10 |
11 | nb_path: "source_nbs/index.ipynb"
12 | ---
13 |
22 |
23 |
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 |
33 |
34 |
Multi-Modal Multi-Task Learning
35 |
Install
36 |
MASKED
37 |
What is it This is a project that uses transformers(based on huggingface transformers) as base model to do multi-modal multi-task learning .
38 |
Why do I need this Multi-task learning (MTL) is gaining more and more attention, especially in the deep learning era. It is widely used in NLP, CV, recommendation, etc. However, MTL usually involves complicated data preprocessing, task management and task interaction. Other open-source projects, like TencentNLP and PyText, support MTL but in a naive way, and it's not straightforward to implement complicated MTL algorithms. In this project, we try to make writing an MTL model as easy as a single-task learning model and further extend MTL to multi-modal multi-task learning. To do so, we expose the following MTL-related programmable modules to the user:
39 |
40 | problem sampling strategy
41 | loss combination strategy
42 | gradient surgery
43 | model after base model(transformers)
44 |
45 |
Apart from programmable modules, we also provide various built-in SOTA MTL algorithms.
46 |
In a word, you can use this project to:
47 |
48 | implement complicated MTL algorithm
49 | do SOTA MTL without diving into details
50 | do multi-modal learning
51 |
52 |
And since we use transformers as base model, you get all the benefits that you can get from transformers!
53 |
What type of problems are supported?
54 |
55 |
56 |
57 | {% raw %}
58 |
59 |
60 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
`cls`: Classification
81 | `multi_cls`: Multi-Label Classification
82 | `seq_tag`: Sequence Labeling
83 | `masklm`: Masked Language Model
84 | `pretrain`: NSP+MLM(Deprecated)
85 | `regression`: Regression
86 | `vector_fit`: Vector Fitting
87 | `premask_mlm`: Pre-masked Masked Language Model
88 | `contrastive_learning`: Contrastive Learning
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 | {% endraw %}
98 |
99 |
100 |
101 |
Get Started Please see tutorials.
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
--------------------------------------------------------------------------------
/m3tl/input_fn.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/07_input_fn.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['element_length_func', 'train_eval_input_fn', 'predict_input_fn']
4 |
5 | # Cell
6 | from typing import List, Union, Dict
7 | import json
8 | from loguru import logger
9 |
10 | import tensorflow as tf
11 |
12 | from .params import Params
13 | from .read_write_tfrecord import read_tfrecord, write_tfrecord
14 | from .special_tokens import PREDICT, TRAIN
15 | from .utils import infer_shape_and_type_from_dict, get_is_pyspark
16 | from .preproc_decorator import preprocessing_fn
17 |
18 |
19 | # Cell
20 |
def element_length_func(yield_dict: Dict[str, tf.Tensor]):  # pragma: no cover
    """Length measure used for sequence-length bucketing.

    Sums the first-dimension sizes of every tensor whose key contains
    'input_ids' (one per modality/problem chunk).
    """
    lengths = [tf.shape(tensor)[0]
               for key, tensor in yield_dict.items()
               if 'input_ids' in key]
    return tf.reduce_sum(lengths)
26 |
27 |
def train_eval_input_fn(params: Params, mode=TRAIN) -> tf.data.Dataset:
    '''
    This function will write and read tf record for training
    and evaluation.

    Arguments:
        params {Params} -- Params objects

    Keyword Arguments:
        mode {str} -- ModeKeys (default: {TRAIN})

    Returns:
        tf Dataset -- Tensorflow dataset (None when running under pyspark)
    '''
    write_tfrecord(params=params)

    # reading with pyspark is not supported
    if get_is_pyspark():
        return

    dataset_dict = read_tfrecord(params=params, mode=mode)

    # make sure the order is correct
    dataset_dict_keys = list(dataset_dict.keys())
    dataset_list = [dataset_dict[key] for key in dataset_dict_keys]
    sample_prob_dict = params.calculate_data_sampling_prob()
    weight_list = [
        sample_prob_dict[key]
        for key in dataset_dict_keys
    ]

    logger.info('sampling weights: ')
    # NOTE(review): the weights actually passed to sample_from_datasets come
    # from `sample_prob_dict`, but this logs `params.problem_sampling_weight_dict`.
    # Presumably calculate_data_sampling_prob keeps the two in sync — confirm,
    # otherwise the log may not reflect the real sampling weights.
    logger.info(json.dumps(params.problem_sampling_weight_dict, indent=4))
    # for problem_chunk_name, weight in params.problem_sampling_weight_dict.items():
    #     logger.info('{0}: {1}'.format(problem_chunk_name, weight))

    # interleave the per-problem datasets according to the sampling weights
    dataset = tf.data.experimental.sample_from_datasets(
        datasets=dataset_list, weights=weight_list)
    options = tf.data.Options()
    # shard by data (not by file) so distribution strategies split evenly
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    dataset = dataset.with_options(options)

    if mode == TRAIN:
        dataset = dataset.shuffle(params.shuffle_buffer)

    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    if params.dynamic_padding:
        # bucket examples of similar length to reduce padding waste
        dataset = dataset.apply(
            tf.data.experimental.bucket_by_sequence_length(
                element_length_func=element_length_func,
                bucket_batch_sizes=params.bucket_batch_sizes,
                bucket_boundaries=params.bucket_boundaries
            ))
    else:
        # materialize one example eagerly just to infer the padded shapes
        first_example = next(dataset.as_numpy_iterator())
        output_shapes, _ = infer_shape_and_type_from_dict(first_example)

        if mode == TRAIN:
            dataset = dataset.padded_batch(params.batch_size, output_shapes)
        else:
            # larger batches at eval time (no gradient state to keep)
            dataset = dataset.padded_batch(params.batch_size*2, output_shapes)

    return dataset
91 |
92 |
93 | # Cell
def predict_input_fn(input_file_or_list: Union[str, List[str]],
                     params: Params,
                     mode=PREDICT,
                     labels_in_input=False) -> tf.data.Dataset:
    '''Input function that takes a file path or list of string and
    convert it to tf.dataset

    Example:
        predict_fn = lambda: predict_input_fn('test.txt', params)
        pred = estimator.predict(predict_fn)

    Arguments:
        input_file_or_list {str or list} -- file path or list of string
        params {Params} -- Params object

    Keyword Arguments:
        mode {str} -- ModeKeys (default: {PREDICT})
        labels_in_input {bool} -- whether inputs contain labels (unused here)

    Returns:
        tf dataset -- tf dataset
    '''

    # if is string, treat it as path to file.
    # Read all lines up front and close the handle: a raw file object would
    # leak the handle and be exhausted after the first pass, while
    # tf.data.Dataset.from_generator may invoke `gen` multiple times.
    if isinstance(input_file_or_list, str):
        with open(input_file_or_list, 'r', encoding='utf8') as f:
            inputs = f.readlines()
    else:
        inputs = input_file_or_list

    # ugly wrapping: reuse the preprocessing decorator to turn the raw
    # inputs into feature dicts
    def gen():
        @preprocessing_fn
        def gen_wrapper(params, mode):
            return inputs
        return gen_wrapper(params, mode)

    # peek at one example to infer tensor shapes and dtypes
    first_dict = next(gen())

    output_shapes, output_type = infer_shape_and_type_from_dict(first_dict)
    dataset = tf.data.Dataset.from_generator(
        gen, output_types=output_type, output_shapes=output_shapes)

    dataset = dataset.padded_batch(
        params.batch_size,
        output_shapes
    )
    # dataset = dataset.batch(config.batch_size*2)

    return dataset
142 |
--------------------------------------------------------------------------------
/docs/js/jekyll-search.js:
--------------------------------------------------------------------------------
1 | !function e(t,n,r){function s(o,u){if(!n[o]){if(!t[o]){var a="function"==typeof require&&require;if(!u&&a)return a(o,!0);if(i)return i(o,!0);throw new Error("Cannot find module '"+o+"'")}var f=n[o]={exports:{}};t[o][0].call(f.exports,function(e){var n=t[o][1][e];return s(n?n:e)},f,f.exports,e,t,n,r)}return n[o].exports}for(var i="function"==typeof require&&require,o=0;o=0}var self=this;self.matches=function(string,crit){return"string"!=typeof string?!1:(string=string.trim(),doMatch(string,crit))}}module.exports=new LiteralSearchStrategy},{}],4:[function(require,module){module.exports=function(){function findMatches(store,crit,strategy){for(var data=store.get(),i=0;i{title} ',noResultsText:"No results found",limit:10,fuzzy:!1};self.init=function(_opt){validateOptions(_opt),assignOptions(_opt),isJSON(opt.dataSource)?initWithJSON(opt.dataSource):initWithURL(opt.dataSource)}}var Searcher=require("./Searcher"),Templater=require("./Templater"),Store=require("./Store"),JSONLoader=require("./JSONLoader"),searcher=new Searcher,templater=new Templater,store=new Store,jsonLoader=new JSONLoader;window.SimpleJekyllSearch=new SimpleJekyllSearch}(window,document)},{"./JSONLoader":1,"./Searcher":4,"./Store":5,"./Templater":6}]},{},[7]);
2 |
--------------------------------------------------------------------------------
/m3tl/problem_types/pretrain.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/12_6_problem_type_pretrain.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['PreTrain', 'pretrain_get_or_make_label_encoder_fn', 'pretrain_label_handling_fn']
4 |
5 | # Cell
6 | from typing import Dict, List, Tuple
7 |
8 | from loguru import logger
9 | import tensorflow as tf
10 | import transformers
11 | from ..base_params import BaseParams
12 | from .utils import empty_tensor_handling_loss
13 | from ..special_tokens import PREDICT
14 | from ..utils import (LabelEncoder, gather_indexes, get_phase,
15 | variable_summaries)
16 | from transformers import TFSharedEmbeddings
17 |
18 |
19 | # Cell
20 |
class PreTrain(tf.keras.Model):
    """NSP + MLM pretraining top layer (marked deprecated in Params registry).

    Computes next-sentence-prediction logits from the pooled hidden feature
    and masked-LM logits from the sequence hidden feature, optionally tying
    the MLM output projection to the input word embeddings.
    """

    def __init__(self, params: BaseParams, problem_name: str, input_embeddings: tf.Tensor=None, share_embedding=True):
        super(PreTrain, self).__init__(name=problem_name)
        self.problem_name = problem_name
        self.params = params
        # next-sentence-prediction head from huggingface transformers
        self.nsp = transformers.models.bert.modeling_tf_bert.TFBertNSPHead(
            self.params.bert_config)

        if share_embedding is False:
            self.vocab_size = self.params.bert_config.vocab_size
            self.share_embedding = False
        else:
            # NOTE(review): input_embeddings must be provided when
            # share_embedding is True, otherwise this raises AttributeError.
            word_embedding_weight = input_embeddings.word_embeddings
            self.vocab_size = word_embedding_weight.shape[0]
            embedding_size = word_embedding_weight.shape[-1]
            # tying is only valid when hidden size matches embedding size
            share_valid = (self.params.bert_config.hidden_size ==
                           embedding_size)
            if not share_valid and self.params.share_embedding:
                logger.warning(
                    'Share embedding is enabled but hidden_size != embedding_size')
            self.share_embedding = self.params.share_embedding & share_valid

        if self.share_embedding:
            # tie the MLM output projection to the input word embeddings
            self.share_embedding_layer = TFSharedEmbeddings(
                vocab_size=word_embedding_weight.shape[0], hidden_size=word_embedding_weight.shape[1])
            self.share_embedding_layer.build([1])
            self.share_embedding_layer.weight = word_embedding_weight
        else:
            # independent projection to vocab size
            self.share_embedding_layer = tf.keras.layers.Dense(self.vocab_size)

    def call(self,
             inputs: Tuple[Dict[str, Dict[str, tf.Tensor]], Dict[str, Dict[str, tf.Tensor]]]) -> Tuple[tf.Tensor, tf.Tensor]:
        """Return (sigmoid NSP probabilities, softmax MLM probabilities).

        Outside PREDICT mode the summed NSP + MLM loss is registered via
        ``add_loss``.
        """
        mode = get_phase()
        features, hidden_features = inputs

        # compute logits
        nsp_logits = self.nsp(hidden_features['pooled'])

        # masking is done inside the model
        seq_hidden_feature = hidden_features['seq']
        if mode != PREDICT:
            positions = features['masked_lm_positions']

            # gather_indexes will flatten the seq hidden_states, we need to reshape
            # back to 3d tensor
            input_tensor = gather_indexes(seq_hidden_feature, positions)
            shape_tensor = tf.shape(positions)
            shape_list = tf.concat(
                [shape_tensor, [seq_hidden_feature.shape.as_list()[-1]]], axis=0)
            input_tensor = tf.reshape(input_tensor, shape=shape_list)
            # set_shape to determine rank statically
            input_tensor.set_shape(
                [None, None, seq_hidden_feature.shape.as_list()[-1]])
        else:
            # at predict time score every position
            input_tensor = seq_hidden_feature
        if self.share_embedding:
            # 'linear' mode projects hidden states onto the tied embedding matrix
            mlm_logits = self.share_embedding_layer(
                input_tensor, mode='linear')
        else:
            mlm_logits = self.share_embedding_layer(input_tensor)

        if self.params.detail_log:
            # emit per-variable summaries for debugging
            for weight_variable in self.weights:
                variable_summaries(weight_variable, self.problem_name)

        if mode != PREDICT:
            nsp_labels = features['next_sentence_label_ids']
            mlm_labels = features['masked_lm_ids']
            mlm_labels.set_shape([None, None])
            # compute loss
            nsp_loss = empty_tensor_handling_loss(
                nsp_labels, nsp_logits,
                tf.keras.losses.sparse_categorical_crossentropy)
            mlm_loss_layer = transformers.modeling_tf_utils.TFMaskedLanguageModelingLoss()
            # mlm_loss = tf.reduce_mean(
            #     mlm_loss_layer.compute_loss(mlm_labels, mlm_logits))

            # add a useless from_logits argument to match the function signature of keras losses.
            def loss_fn_wrapper(labels, logits, from_logits=True):
                return mlm_loss_layer.compute_loss(labels, logits)
            mlm_loss = empty_tensor_handling_loss(
                mlm_labels,
                mlm_logits,
                loss_fn_wrapper
            )
            loss = nsp_loss + mlm_loss
            self.add_loss(loss)

        return (tf.sigmoid(nsp_logits), tf.nn.softmax(mlm_logits))
110 |
111 | # Cell
def pretrain_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs) -> None:
    """Pretrain uses no label encoder: record num_classes=1 and return None.

    (Annotation corrected: this function always returns None, not a
    LabelEncoder; labels are produced elsewhere, see pretrain_label_handling_fn.)
    """
    params.set_problem_info(problem=problem, info_name='num_classes', info=1)
    return None
115 |
116 | # Cell
def pretrain_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):
    # Nothing to encode here: NSP/MLM labels are generated by the input
    # pipeline, so both the label and the mask slot are None.
    return None, None
119 |
120 |
--------------------------------------------------------------------------------
/m3tl/problem_types/masklm.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/12_4_problem_type_masklm.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['MaskLM', 'masklm_get_or_make_label_encoder_fn', 'masklm_label_handling_fn']
4 |
5 | # Cell
6 | import pickle
7 | from typing import List
8 |
9 | import tensorflow as tf
10 | from loguru import logger
11 | from ..base_params import BaseParams
12 | from .utils import (empty_tensor_handling_loss,
13 | nan_loss_handling, pad_to_shape)
14 | from ..special_tokens import PREDICT
15 | from ..utils import (LabelEncoder, gather_indexes,
16 | get_label_encoder_save_path, get_phase,
17 | load_transformer_tokenizer, need_make_label_encoder)
18 | from transformers import TFSharedEmbeddings
19 |
20 |
21 | # Cell
class MaskLM(tf.keras.Model):
    """Multimodal MLM top layer.

    Projects the hidden states at masked positions onto the vocabulary,
    optionally tying the projection to the input word embeddings, and adds
    the masked-LM loss outside PREDICT mode.
    """

    def __init__(self, params: BaseParams, problem_name: str, input_embeddings: tf.keras.layers.Layer=None, share_embedding=True) -> None:
        super(MaskLM, self).__init__(name=problem_name)
        self.params = params
        self.problem_name = problem_name

        if share_embedding is False:
            self.vocab_size = self.params.bert_config.vocab_size
            self.share_embedding = False
        else:
            # NOTE(review): input_embeddings is indexed like a weight tensor
            # below (shape[0], shape[-1]), so despite the Layer annotation it
            # appears to be the embedding matrix itself — confirm with callers.
            self.vocab_size = input_embeddings.shape[0]
            embedding_size = input_embeddings.shape[-1]
            # tying is only valid when hidden size matches embedding size
            share_valid = (self.params.bert_config.hidden_size ==
                           embedding_size)
            if not share_valid and self.params.share_embedding:
                logger.warning(
                    'Share embedding is enabled but hidden_size != embedding_size')
            self.share_embedding = self.params.share_embedding & share_valid

        if self.share_embedding:
            # tie output projection to the input embeddings
            self.share_embedding_layer = TFSharedEmbeddings(
                vocab_size=self.vocab_size, hidden_size=input_embeddings.shape[1])
            self.share_embedding_layer.build([1])
            self.share_embedding_layer.weight = input_embeddings
        else:
            self.share_embedding_layer = tf.keras.layers.Dense(self.vocab_size)

    def call(self, inputs):
        """Return softmax vocab distributions; adds the MLM loss when training/evaluating."""
        mode = get_phase()
        features, hidden_features = inputs

        # masking is done inside the model
        seq_hidden_feature = hidden_features['seq']
        if mode != PREDICT:
            positions = features['masked_lm_positions']

            # gather_indexes will flatten the seq hidden_states, we need to reshape
            # back to 3d tensor
            input_tensor = gather_indexes(seq_hidden_feature, positions)
            shape_tensor = tf.shape(positions)
            shape_list = tf.concat([shape_tensor, [seq_hidden_feature.shape.as_list()[-1]]], axis=0)
            input_tensor = tf.reshape(input_tensor, shape=shape_list)
            # set_shape to determine rank statically
            input_tensor.set_shape(
                [None, None, seq_hidden_feature.shape.as_list()[-1]])
        else:
            # at predict time score every position
            input_tensor = seq_hidden_feature
        if self.share_embedding:
            # 'linear' mode projects onto the tied embedding matrix
            mlm_logits = self.share_embedding_layer(
                input_tensor, mode='linear')
        else:
            mlm_logits = self.share_embedding_layer(input_tensor)
        if mode != PREDICT:
            mlm_labels = features['masked_lm_ids']
            mlm_labels.set_shape([None, None])
            # pad labels along the sequence axis to match the logits
            mlm_labels = pad_to_shape(from_tensor=mlm_labels, to_tensor=mlm_logits, axis=1)
            # compute loss
            mlm_loss = empty_tensor_handling_loss(
                mlm_labels,
                mlm_logits,
                tf.keras.losses.sparse_categorical_crossentropy
            )
            # guard against NaN loss (e.g. empty label tensors)
            loss = nan_loss_handling(mlm_loss)
            self.add_loss(loss)

        return tf.nn.softmax(mlm_logits)
91 |
92 | # Cell
def masklm_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs) -> LabelEncoder:
    """Create (or load) the tokenizer that doubles as the MLM label encoder.

    Arguments:
        params {BaseParams} -- params object
        problem {str} -- problem name
        mode {str} -- TRAIN/EVAL/PREDICT mode string
        label_list {List[str]} -- unused for masklm (labels come from the tokenizer vocab)

    Keyword Arguments (via kwargs):
        overwrite -- required; whether to rebuild an existing encoder

    Returns:
        the transformer tokenizer acting as label encoder
    """
    le_path = get_label_encoder_save_path(params=params, problem=problem)

    if need_make_label_encoder(mode=mode, le_path=le_path, overwrite=kwargs['overwrite']):
        # fit and save label encoder; use context managers so the file
        # handles are always closed (the originals leaked open handles)
        label_encoder = load_transformer_tokenizer(
            params.transformer_tokenizer_name, params.transformer_tokenizer_loading)
        with open(le_path, 'wb') as f:
            pickle.dump(label_encoder, f)
        try:
            params.set_problem_info(
                problem=problem, info_name='num_classes', info=len(label_encoder.vocab))
        except AttributeError:
            # models like xlnet's vocab size can only be retrieved from config instead of tokenizer
            params.set_problem_info(
                problem=problem, info_name='num_classes', info=params.bert_config.vocab_size)
    else:
        # models like xlnet's vocab size can only be retrieved from config instead of tokenizer
        params.set_problem_info(
            problem=problem, info_name='num_classes', info=params.bert_config.vocab_size)
        with open(le_path, 'rb') as f:
            label_encoder = pickle.load(f)

    return label_encoder
112 |
113 | # Cell
def masklm_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):
    # masklm is a special case since it modifies inputs
    # for more standard implementation of masklm, please see premask_mlm
    # nothing to encode here: labels are created by the masking step itself
    return None, None
118 |
119 |
--------------------------------------------------------------------------------
/docs/1_problem_type_cls.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Classification(cls)
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 |
10 |
11 | nb_path: "source_nbs/12_1_problem_type_cls.ipynb"
12 | ---
13 |
22 |
23 |
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 |
33 |
34 |
Classification. By default this problem will use [CLS] token embedding.
35 |
Example: m3tl.predefined_problems.get_weibo_fake_cls_fn.
36 |
37 |
38 |
39 |
40 |
41 |
42 |
Imports and utils
43 |
44 |
45 |
46 | {% raw %}
47 |
48 |
49 |
50 |
51 | {% endraw %}
52 |
53 |
59 | {% raw %}
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
Classification(*args, **kwargs) :: Layer
71 |
72 |
Classification Top Layer
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 | {% endraw %}
83 |
84 | {% raw %}
85 |
86 |
87 |
88 |
89 | {% endraw %}
90 |
91 |
92 |
93 |
Get or make label encoder function
94 |
95 |
96 |
97 | {% raw %}
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
cls_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs)
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 | {% endraw %}
120 |
121 | {% raw %}
122 |
123 |
124 |
125 |
126 | {% endraw %}
127 |
128 |
129 |
130 |
Label handling function
131 |
132 |
133 |
134 | {% raw %}
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
cls_label_handling_fn[source] cls_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs)
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 | {% endraw %}
157 |
158 | {% raw %}
159 |
160 |
161 |
162 |
163 | {% endraw %}
164 |
165 |
166 |
167 |
168 |
--------------------------------------------------------------------------------
/m3tl/problem_types/contrastive_learning.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/12_10_problem_type_contrast_learning.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['SimCSE', 'get_contrastive_learning_model', 'ContrastiveLearning',
4 | 'contrastive_learning_get_or_make_label_encoder_fn', 'contrastive_learning_label_handling_fn']
5 |
6 | # Cell
7 | from typing import List
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 | from loguru import logger
12 | from ..base_params import BaseParams
13 | from .utils import (empty_tensor_handling_loss,
14 | nan_loss_handling, pad_to_shape)
15 | from ..special_tokens import PREDICT
16 | from ..utils import (LabelEncoder, get_label_encoder_save_path, get_phase,
17 | need_make_label_encoder)
18 |
19 |
20 | # Cell
21 | # export
class SimCSE(tf.keras.Model):
    """SimCSE-style contrastive learning head.

    Requires the duplicate-data-augmentation embedding: each example appears
    twice in the batch, and the two halves of the batch are compared via an
    in-batch cosine-similarity matrix whose diagonal entries are positives.
    """

    def __init__(self, params: BaseParams, problem_name: str) -> None:
        super(SimCSE, self).__init__(name='simcse')
        self.params = params
        self.problem_name = problem_name
        # NOTE(review): this dropout layer is never used in call(); presumably
        # the augmentation dropout happens in the embedding layer — confirm.
        self.dropout = tf.keras.layers.Dropout(self.params.dropout)
        # pooling strategy: 'pooled' uses the model's pooled output,
        # 'mean_pool' averages the sequence hidden states
        self.pooler = self.params.get('simcse_pooler', 'pooled')
        self.metric_fn = tf.keras.metrics.CategoricalAccuracy(name='{}_acc'.format(problem_name))
        # (fixed local-name typo: was `availabel_pooler`)
        available_pooler = ['pooled', 'mean_pool']
        assert self.pooler in available_pooler, \
            'available params.simcse_pooler: {}, got: {}'.format(
                available_pooler, self.pooler)
        if self.params.embedding_layer['name'] != 'duplicate_data_augmentation_embedding':
            raise ValueError(
                'SimCSE requires duplicate_data_augmentation_embedding. Fix it with `params.assign_embedding_layer(\'duplicate_data_augmentation_embedding\')`')

    def call(self, inputs):
        """Add contrastive loss/accuracy outside PREDICT; return the pooled embedding."""
        features, hidden_features = inputs
        phase = get_phase()

        if phase != PREDICT:
            # created pool embedding
            if self.pooler == 'pooled':
                all_pooled_embedding = hidden_features['pooled']
            else:
                all_pooled_embedding = tf.reduce_mean(
                    hidden_features['seq'], axis=1)

            # shape (batch_size, hidden_dim); the two halves are the two
            # augmented views of the same examples
            pooled_rep1_embedding, pooled_rep2_embedding = tf.split(
                all_pooled_embedding, 2)

            # calculate similarity
            pooled_rep1_embedding = tf.math.l2_normalize(
                pooled_rep1_embedding, axis=1)
            pooled_rep2_embedding = tf.math.l2_normalize(
                pooled_rep2_embedding, axis=1)
            # shape (batch_size, batch_size); diagonal pairs are positives
            similarity = tf.matmul(pooled_rep1_embedding,
                                   pooled_rep2_embedding, transpose_b=True)
            labels = tf.eye(tf.shape(similarity)[0])

            # shape (batch_size*batch_size)
            similarity = tf.reshape(similarity, shape=(-1, 1))
            labels = tf.reshape(labels, shape=(-1, 1))

            # make compatible with binary crossentropy
            similarity = tf.concat([1-similarity, similarity], axis=1)
            labels = tf.concat([1-labels, labels], axis=1)
            loss = tf.keras.losses.binary_crossentropy(labels, similarity)
            loss = tf.reduce_mean(loss)
            self.add_loss(loss)
            acc = self.metric_fn(labels, similarity)
            self.add_metric(acc)
        return inputs[1]['pooled']
78 |
79 |
80 | # Cell
def get_contrastive_learning_model(params: BaseParams, problem_name: str, model_name: str) -> tf.keras.Model:
    """Resolve a contrastive-learning model by name.

    Currently only 'simcse' is implemented; any other name falls back to
    SimCSE with a warning.
    """
    if model_name != 'simcse':
        logger.warning(
            '{} not match any contrastive learning model, using SimCSE'.format(model_name))
    return SimCSE(params=params, problem_name=problem_name)
88 |
89 |
90 | # Cell
91 |
class ContrastiveLearning(tf.keras.Model):
    """Thin dispatcher top layer for contrastive-learning problems.

    Resolves the concrete model (e.g. SimCSE) from
    ``params.contrastive_learning_model_name`` at construction time and
    delegates every call to it.
    """

    def __init__(self, params: BaseParams, problem_name: str) -> None:
        super(ContrastiveLearning, self).__init__(name=problem_name)
        self.params = params
        self.problem_name = problem_name
        # name of the underlying algorithm, e.g. 'simcse'
        self.contrastive_learning_model_name = self.params.contrastive_learning_model_name
        self.contrastive_learning_model = get_contrastive_learning_model(
            params=self.params, problem_name=problem_name, model_name=self.contrastive_learning_model_name)

    def call(self, inputs):
        # delegate directly; loss/metrics are added by the wrapped model
        return self.contrastive_learning_model(inputs)
103 |
104 |
105 | # Cell
def contrastive_learning_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs) -> LabelEncoder:
    """Fit a new LabelEncoder (recording num_classes) or load a saved one.

    Arguments:
        params {BaseParams} -- params object
        problem {str} -- problem name
        mode {str} -- TRAIN/EVAL/PREDICT mode string
        label_list {List[str]} -- raw labels to fit the encoder on

    Keyword Arguments (via kwargs):
        overwrite -- required; accessed as kwargs['overwrite'], so callers
            must always pass it — NOTE(review): confirm this is intentional.

    Returns:
        LabelEncoder -- fitted or loaded encoder
    """

    le_path = get_label_encoder_save_path(params=params, problem=problem)
    label_encoder = LabelEncoder()
    if need_make_label_encoder(mode=mode, le_path=le_path, overwrite=kwargs['overwrite']):
        # fit and save label encoder
        label_encoder.fit(label_list)
        label_encoder.dump(le_path)
        params.set_problem_info(
            problem=problem, info_name='num_classes', info=len(label_encoder.encode_dict))
    else:
        label_encoder.load(le_path)

    return label_encoder
120 |
121 |
122 | # Cell
def contrastive_learning_label_handling_fn(target: str, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):
    """Encode a single contrastive-learning label.

    Arguments:
        target {str} -- raw label value
        label_encoder -- fitted LabelEncoder (required)
        tokenizer, decoding_length -- unused; kept for the shared
            label_handling_fn signature

    Returns:
        tuple -- (np.int32 label id, None). The previous `-> dict`
        annotation was incorrect and has been removed.
    """

    label_id = label_encoder.transform([target]).tolist()[0]
    # cast to a fixed-width integer for serialization
    label_id = np.int32(label_id)
    return label_id, None
128 |
--------------------------------------------------------------------------------
/tests/test_nbs.py:
--------------------------------------------------------------------------------
1 | # We need a "test_" file with "test_" functions to make it easy to run with pytest
2 |
3 |
4 | # couple of example "test_" functions
5 |
6 | # import nbdev.test
7 | # def test_run():
8 | # print('running nbdev.test.test_nb("20_models.ipynb") ...')
9 | # nbdev.test.test_nb('20_models.ipynb')
10 |
11 | # import os
12 | # def test_run():
13 | # print('running nbdev_test_nbs...')
14 | # os.system('nbdev_test_nbs')
15 |
16 |
17 | # set-up a "before test" callback handler that will modify the notebook before it is run
18 | import os
19 | from pathlib import Path
20 | import time
21 | import glob
22 | import nbformat
23 | import random
24 | from nbdev.imports import Config, parallel
25 | from nbdev.export import read_nb, find_default_export, is_export, split_flags_and_code
26 | from nbdev.test import get_all_flags, NoExportPreprocessor
27 |
28 |
def before_test(nb):
    "callback that will import modules and run cells that are not exported"
    # module this notebook exports to (nbdev `default_exp`)
    default_export = find_default_export(nb['cells'])
    exports = [is_export(c, default_export) for c in nb['cells']]
    imports = ''
    # exclude exported, notebook2script calls etc
    things_to_exclude = ['notebook2script']
    cells = [(i, c, e) for i, (c, e) in enumerate(
        zip(nb['cells'], exports)) if c['cell_type'] == 'code']
    for i, c, e in cells:
        if e:
            # if it's exported to the library, don't run as test
            c['cell_type'] = 'exclude'
            # but we might still need to run import statements
            for line in split_flags_and_code(c):
                if 'import' in line:
                    imports += f'{line}\n'
                # NOTE(review): this `continue` is a no-op (last statement of
                # the loop body); possibly a `break` or outer-loop skip was meant
                continue
        for thing_to_exclude in things_to_exclude:  # TODO: is this too coarse? maybe just exclude specific lines?
            if thing_to_exclude in c['source']:
                c['cell_type'] = 'exclude'
                # NOTE(review): likewise a no-op `continue`
                continue

    # run all collected import statements before anything else
    nb['cells'].insert(0, nbformat.v4.new_code_cell(imports))

    # import everything from modules written to by this notebook
    for export in {export[0] for export in exports if export}:
        export_parts = export.split('.')
        b = export_parts.pop()
        export_parts.insert(0, Config().lib_name)
        a = '.'.join(export_parts)
        src = f"""
from {a} import {b}
for o in dir({b}):
    exec(f'from {a}.{b} import {{o}}')"""
        nb['cells'].insert(0, nbformat.v4.new_code_cell(src))
    return nb
66 |
67 | # uncomment to see current nbdev behaviour
68 | # i.e. use a before test callback that does nothing
69 | # def before_test(nb): return nb
70 |
71 | # If nbdev.test.test_nb knew to call our "before test" callback, the rest of this script could be just the following 3 lines
72 | # def test_run():
73 | # from nbdev.cli import nbdev_test_nbs
74 | # nbdev_test_nbs.__wrapped__()
75 |
76 | # until it does ... we need to duplicate a few chunks of nbdev
77 |
78 |
def _test_nb(fn, flags=None):
    "Run the tests in notebook `fn`, honouring the given space-separated `flags`."
    flags = [] if flags is None else flags
    os.environ["IN_TEST"] = '1'
    try:
        # Patch the notebook before execution -- the only change vs. stock nbdev.
        nb = before_test(read_nb(fn))
        # Skip the whole notebook if it carries any flag we were not given.
        if any(f not in flags for f in get_all_flags(nb['cells'])):
            return
        preprocessor = NoExportPreprocessor(flags, timeout=600, kernel_name='python3')
        preprocessor.preprocess(nbformat.from_dict(nb))
    finally:
        os.environ.pop("IN_TEST")
95 |
96 |
def _test_one(fname, flags=None, verbose=True):
    """Test a single notebook `fname`; returns (passed: bool, elapsed_seconds: float).

    On the transient "Kernel died before replying to kernel_info" startup race,
    sleeps a random sub-second interval and retries.
    """
    print(f"testing: {fname}")
    start = time.time()
    try:
        _test_nb(fname, flags=flags)
        return True, time.time()-start
    except Exception as e:
        if "Kernel died before replying to kernel_info" in str(e):
            # Transient kernel-startup race: back off briefly and retry.
            # BUG FIX: the retry's result was previously discarded, so a
            # successful retry still printed an error and returned False;
            # `verbose` was also dropped. Return the retry's outcome instead.
            time.sleep(random.random())
            return _test_one(fname, flags=flags, verbose=verbose)
        if verbose:
            print(f'Error in {fname}:\n{e}')
        return False, time.time()-start
110 |
111 |
def nbdev_test_nbs(fname=None, flags=None, n_workers=None, verbose=True, timing=False):
    """Test in parallel the notebooks matching `fname`, passing along `flags`.

    fname: A notebook name or glob to convert (str); None means every notebook
           in the project's nbs folder.
    flags: Space separated list of flags (str).
    n_workers: Number of workers to use (int).
    verbose: Print errors along the way (bool).
    timing: Print per-notebook timings to spot slow ones (bool).

    Raises Exception listing the failing notebooks if any test fails.
    """
    if flags is not None:
        flags = flags.split(' ')
    if fname is None:
        # BUG FIX: the previous filter checked the *full path* with
        # f.startswith('_'), which is never true, so underscore-prefixed
        # notebooks were not skipped. Check the basename, matching the
        # intent of the original Path-based code.
        files = [f for f in glob.glob(os.path.join(
            Config().nbs_path, '*.ipynb')) if not os.path.basename(f).startswith('_')]
    else:
        files = glob.glob(fname)
    files = [Path(f).absolute() for f in sorted(files)]
    # A single notebook runs in-process; parallelism would only add overhead.
    if len(files) == 1 and n_workers is None:
        n_workers = 0
    # make sure we are inside the notebook folder of the project
    os.chdir(Config().nbs_path)
    results = parallel(_test_one, files, flags=flags,
                       verbose=verbose, n_workers=n_workers)
    passed, times = [r[0] for r in results], [r[1] for r in results]
    if all(passed):
        print("All tests are passing!")
    else:
        msg = "The following notebooks failed:\n"
        raise Exception(
            msg + '\n'.join([f.name for p, f in zip(passed, files) if not p]))
    if timing:
        for i, t in sorted(enumerate(times), key=lambda o: o[1], reverse=True):
            print(f"Notebook {files[i].name} took {int(t)} seconds")
147 |
148 |
def test_run():
    # pytest entry point: runs every notebook through our patched
    # nbdev_test_nbs above, so the before_test callback is applied.
    # now we can "nbdev_test_nbs" and have our "before test" callback called
    # nbdev_test_nbs('00_core.ipynb') # Use this line to test a single notebook
    nbdev_test_nbs()
153 |
--------------------------------------------------------------------------------
/m3tl/problem_types/seq2seq_text.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/12_5_problem_type_seq2seq_text.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['Seq2Seq', 'seq2seq_text_get_or_make_label_encoder_fn', 'pad_wrapper', 'seq2seq_text_label_handling_fn']
4 |
5 | # Cell
6 | import pickle
7 | from typing import Dict, List, Tuple
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 | from ..base_params import BaseParams
12 | from ..utils import (LabelEncoder, get_label_encoder_save_path,
13 | load_transformer_tokenizer, need_make_label_encoder)
14 |
15 |
16 | # Cell
17 |
class Seq2Seq(tf.keras.Model):
    """Seq2seq text-generation head (decoder on top of shared encoder states).

    NOTE(review): __init__ is entirely commented out and ends in
    NotImplementedError, so this class cannot currently be instantiated.
    `call` references attributes (self.problem_name, self.params,
    self.decoder, self.metric_fn) that only the commented-out constructor
    would have created.
    """
    def __init__(self, params: BaseParams, problem_name: str, input_embeddings: tf.keras.layers.Layer):
        super(Seq2Seq, self).__init__(name=problem_name)
        # self.params = params
        # self.problem_name = problem_name
        # # if self.params.init_weight_from_huggingface:
        # #     self.decoder = load_transformer_model(
        # #         self.params.transformer_decoder_model_name,
        # #         self.params.transformer_decoder_model_loading)
        # # else:
        # #     self.decoder = load_transformer_model(
        # #         self.params.bert_decoder_config, self.params.transformer_decoder_model_loading)

        # # TODO: better implementation
        # logging.warning(
        #     'Seq2Seq model is not well supported yet. Bugs are expected.')
        # config = self.params.bert_decoder_config
        # # some hacky approach to share embeddings from encoder to decoder
        # word_embedding_weight = input_embeddings.word_embeddings
        # self.vocab_size = word_embedding_weight.shape[0]
        # self.share_embedding_layer = TFSharedEmbeddings(
        #     vocab_size=word_embedding_weight.shape[0], hidden_size=word_embedding_weight.shape[1])
        # self.share_embedding_layer.build([1])
        # self.share_embedding_layer.weight = word_embedding_weight
        # # self.decoder = TFBartDecoder(
        # #     config=config, embed_tokens=self.share_embedding_layer)
        # self.decoder = TFBartDecoderForConditionalGeneration(
        #     config=config, embedding_layer=self.share_embedding_layer)
        # self.decoder.set_bos_id(self.params.bos_id)
        # self.decoder.set_eos_id(self.params.eos_id)

        # self.metric_fn = tf.keras.metrics.SparseCategoricalAccuracy(
        #     name='{}_acc'.format(self.problem_name))
        raise NotImplementedError

    def _seq2seq_label_shift_right(self, labels: tf.Tensor, eos_id: int) -> tf.Tensor:
        """Drop the first (BOS) column and append an EOS column so labels
        align with the decoder's next-token predictions."""
        batch_eos_ids = tf.fill([tf.shape(labels)[0], 1], eos_id)
        # cast to int64 to match the label dtype expected by tf.concat below
        batch_eos_ids = tf.cast(batch_eos_ids, dtype=tf.int64)
        decoder_lable = labels[:, 1:]  # (sic: variable misspells "label")
        decoder_lable = tf.concat([decoder_lable, batch_eos_ids], axis=1)
        return decoder_lable

    def call(self,
             inputs: Tuple[Dict[str, Dict[str, tf.Tensor]], Dict[str, Dict[str, tf.Tensor]]],
             mode: str):
        """Decode against encoder hidden states.

        PREDICT mode: autoregressive generation (returns generated ids).
        Other modes: teacher forcing; registers loss + accuracy and returns
        the decoder logits.
        """
        features, hidden_features = inputs
        encoder_mask = features['model_input_mask']

        # Labels and their mask only exist outside of prediction.
        if mode == tf.estimator.ModeKeys.PREDICT:
            input_ids = None
            decoder_padding_mask = None
        else:
            input_ids = features['%s_label_ids' % self.problem_name]
            decoder_padding_mask = features['{}_mask'.format(
                self.problem_name)]

        if mode == tf.estimator.ModeKeys.PREDICT:
            return self.decoder.generate(eos_token_id=self.params.eos_id, encoder_hidden_states=hidden_features['seq'])
        else:
            decoder_output = self.decoder(input_ids=input_ids,
                                          encoder_hidden_states=hidden_features['seq'],
                                          encoder_padding_mask=encoder_mask,
                                          decoder_padding_mask=decoder_padding_mask,
                                          decode_max_length=self.params.decode_max_seq_len,
                                          mode=mode)
            loss = decoder_output.loss
            logits = decoder_output.logits
            self.add_loss(loss)
            # Shift labels right so position t is scored against token t+1.
            decoder_label = self._seq2seq_label_shift_right(
                features['%s_label_ids' % self.problem_name], eos_id=self.params.eos_id)
            acc = self.metric_fn(decoder_label, logits)
            self.add_metric(acc)
            return logits
91 |
92 | # Cell
def seq2seq_text_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs) -> LabelEncoder:
    """Get (or create and persist) the decoder tokenizer used as this
    problem's label encoder.

    When a new encoder is needed (per `need_make_label_encoder`, honouring
    kwargs['overwrite']), loads the decoder tokenizer, pickles it to
    `le_path` and records the problem's num_classes; otherwise unpickles the
    previously saved encoder.

    NOTE(review): unpickling executes arbitrary code; only safe because
    `le_path` is produced by this very function.
    """
    le_path = get_label_encoder_save_path(params=params, problem=problem)
    if need_make_label_encoder(mode=mode, le_path=le_path, overwrite=kwargs['overwrite']):
        # fit and save label encoder
        label_encoder = load_transformer_tokenizer(
            params.transformer_decoder_tokenizer_name, params.transformer_decoder_tokenizer_loading)
        # BUG FIX: use context managers -- the bare open() handles were never
        # closed, leaking file descriptors.
        with open(le_path, 'wb') as f:
            pickle.dump(label_encoder, f)
        params.set_problem_info(problem=problem, info_name='num_classes',
                                info=len(label_encoder.encode_dict))
    else:
        with open(le_path, 'rb') as f:
            label_encoder = pickle.load(f)

    return label_encoder
106 |
107 | # Cell
def pad_wrapper(inp, target_len=90, pad_value=0):
    """Truncate or right-pad list `inp` to exactly `target_len` elements.

    Generalized: `pad_value` (default 0, matching previous behavior) is the
    element used for padding. Returns a list of length `target_len`.
    """
    if len(inp) >= target_len:
        return inp[:target_len]
    return inp + [pad_value] * (target_len - len(inp))
113 |
def seq2seq_text_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):
    "Wrap `target` tokens with BOS/EOS, encode via `label_encoder`, and pad ids and mask to `decoding_length`."
    wrapped = [label_encoder.bos_token] + target + [label_encoder.eos_token]
    encoded = label_encoder(
        wrapped, add_special_tokens=False, is_split_into_words=True)
    label_id = pad_wrapper(encoded['input_ids'], decoding_length)
    label_mask = pad_wrapper(encoded['attention_mask'], decoding_length)
    return label_id, label_mask
124 |
125 |
--------------------------------------------------------------------------------
/m3tl/problem_types/premask_mlm.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/12_9_problem_type_premask_mlm.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['PreMaskMLM', 'premask_mlm_get_or_make_label_encoder_fn', 'premask_mlm_label_handling_fn']
4 |
5 | # Cell
6 | import numpy as np
7 | import tensorflow as tf
8 | from loguru import logger
9 | from ..base_params import BaseParams
10 | from .utils import (empty_tensor_handling_loss,
11 | nan_loss_handling, pad_to_shape)
12 | from ..special_tokens import PREDICT
13 | from ..utils import gather_indexes, get_phase, load_transformer_tokenizer
14 | from transformers import TFSharedEmbeddings
15 |
16 |
17 | # Cell
18 |
class PreMaskMLM(tf.keras.Model):
    """Masked-LM head for inputs that are masked *before* entering the model.

    Optionally ties the output projection to the encoder's input embeddings
    (when `share_embedding` is requested and the hidden size matches);
    otherwise uses a fresh Dense projection to vocab size.
    """
    def __init__(self, params: BaseParams, problem_name: str, input_embeddings: tf.Tensor=None, share_embedding=False) -> None:
        super(PreMaskMLM, self).__init__(name=problem_name)
        self.params = params
        self.problem_name = problem_name

        # same as masklm
        if share_embedding is False:
            self.vocab_size = self.params.bert_config.vocab_size
            self.share_embedding = False
        else:
            # NOTE(review): assumes input_embeddings is (vocab, embed_dim) --
            # confirm against the encoder's word embedding table.
            self.vocab_size = input_embeddings.shape[0]
            embedding_size = input_embeddings.shape[-1]
            # Weight tying only works when decoder hidden == embedding dim.
            share_valid = (self.params.bert_config.hidden_size ==
                           embedding_size)
            if not share_valid and self.params.share_embedding:
                logger.warning(
                    'Share embedding is enabled but hidden_size != embedding_size')
            self.share_embedding = self.params.share_embedding & share_valid

        if self.share_embedding:
            self.share_embedding_layer = TFSharedEmbeddings(
                vocab_size=self.vocab_size, hidden_size=input_embeddings.shape[1])
            self.share_embedding_layer.build([1])
            # Overwrite the freshly-built weight with the encoder's table.
            self.share_embedding_layer.weight = input_embeddings
        else:
            self.share_embedding_layer = tf.keras.layers.Dense(self.vocab_size)

    def call(self, inputs):
        """Project hidden states at masked positions to vocab probabilities.

        inputs: (features dict, hidden_features dict); reads
        '<problem>_masked_lm_positions' / '<problem>_masked_lm_ids' outside
        of prediction. Registers the MLM loss; returns softmax probabilities.
        """
        mode = get_phase()
        features, hidden_features = inputs

        # masking is done inside the model
        seq_hidden_feature = hidden_features['seq']
        if mode != PREDICT:
            positions = features['{}_masked_lm_positions'.format(self.problem_name)]

            # gather_indexes will flatten the seq hidden_states, we need to reshape
            # back to 3d tensor
            input_tensor = gather_indexes(seq_hidden_feature, positions)
            shape_tensor = tf.shape(positions)
            shape_list = tf.concat([shape_tensor, [seq_hidden_feature.shape.as_list()[-1]]], axis=0)
            input_tensor = tf.reshape(input_tensor, shape=shape_list)
            # set_shape to determine rank (static shape info for downstream layers)
            input_tensor.set_shape(
                [None, None, seq_hidden_feature.shape.as_list()[-1]])
        else:
            # At predict time, score every position rather than gathered ones.
            input_tensor = seq_hidden_feature
        if self.share_embedding:
            # 'linear' mode uses the tied embedding matrix as output projection.
            mlm_logits = self.share_embedding_layer(
                input_tensor, mode='linear')
        else:
            mlm_logits = self.share_embedding_layer(input_tensor)
        if mode != PREDICT:
            mlm_labels = features['{}_masked_lm_ids'.format(self.problem_name)]
            mlm_labels.set_shape([None, None])
            # Align label length with logits along the positions axis.
            mlm_labels = pad_to_shape(from_tensor=mlm_labels, to_tensor=mlm_logits, axis=1)
            # compute loss
            mlm_loss = empty_tensor_handling_loss(
                mlm_labels,
                mlm_logits,
                tf.keras.losses.sparse_categorical_crossentropy
            )
            loss = nan_loss_handling(mlm_loss)
            self.add_loss(loss)

        return tf.nn.softmax(mlm_logits)
86 |
87 | # Cell
def premask_mlm_get_or_make_label_encoder_fn(params: BaseParams, problem, mode, label_list, *args, **kwargs):
    "Record vocab size as this problem's class count and return the shared transformer tokenizer as its label encoder."
    params.set_problem_info(problem=problem, info_name='num_classes',
                            info=params.bert_config.vocab_size)
    return load_transformer_tokenizer(tokenizer_name=params.transformer_tokenizer_name,
                                      load_module_name=params.transformer_tokenizer_loading)
92 |
93 |
94 | # Cell
def premask_mlm_label_handling_fn(target: str, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs) -> dict:
    """Build masked-LM features for an input that was masked upstream.

    target: the label tokens for the masked positions (already word-split).
    Required kwargs: 'modal_type', 'problem', and 'tokenized_inputs' (the
    tokenizer output for the full, already-masked input sequence).

    Returns {} for non-text modalities; otherwise a dict with
    '<problem>_masked_lm_positions' (list of int, length 20),
    '<problem>_masked_lm_ids' and '<problem>_masked_lm_weights'
    (int32 numpy arrays of length 20).
    """
    # FIX: removed the unused `modal_name = kwargs['modal_name']` binding --
    # it was never read and forced callers to always supply the key.
    modal_type = kwargs['modal_type']
    problem = kwargs['problem']
    # TODO(review): hard-coded prediction budget; presumably should come from
    # params -- confirm before generalizing.
    max_predictions_per_seq = 20

    # Only text inputs carry mask-LM labels.
    if modal_type != 'text':
        return {}

    tokenized_dict = kwargs['tokenized_inputs']

    # Encode the label tokens, padding/truncating to the prediction budget.
    mask_lm_dict = tokenizer(target,
                             truncation=True,
                             is_split_into_words=True,
                             padding='max_length',
                             max_length=max_predictions_per_seq,
                             return_special_tokens_mask=False,
                             add_special_tokens=False,)

    # Locate [MASK] positions in the already-masked input sequence.
    mask_token_id = tokenizer(
        '[MASK]', add_special_tokens=False, is_split_into_words=False)['input_ids'][0]
    masked_lm_positions = [i for i, input_id in enumerate(
        tokenized_dict['input_ids']) if input_id == mask_token_id]
    # pad masked_lm_positions to max_predictions_per_seq
    if len(masked_lm_positions) < max_predictions_per_seq:
        masked_lm_positions = masked_lm_positions + \
            [0 for _ in range(max_predictions_per_seq -
                              len(masked_lm_positions))]
    masked_lm_positions = masked_lm_positions[:max_predictions_per_seq]
    masked_lm_ids = np.array(mask_lm_dict['input_ids'], dtype='int32')
    masked_lm_weights = np.array(mask_lm_dict['attention_mask'], dtype='int32')
    return {'{}_masked_lm_positions'.format(problem): masked_lm_positions,
            '{}_masked_lm_ids'.format(problem): masked_lm_ids,
            '{}_masked_lm_weights'.format(problem): masked_lm_weights}
132 |
--------------------------------------------------------------------------------
/docs/4_problem_type_masklm.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Masked Language Model(masklm)
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 |
10 |
11 | nb_path: "source_nbs/12_4_problem_type_masklm.ipynb"
12 | ---
13 |
22 |
23 |
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 | {% raw %}
33 |
34 |
56 | {% endraw %}
57 |
58 |
59 |
60 |
Imports and utils
61 |
62 |
63 |
64 | {% raw %}
65 |
66 |
67 |
68 |
69 | {% endraw %}
70 |
71 |
77 | {% raw %}
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
MaskLM(*args , **kwargs ) :: Model
89 |
90 |
Multimodal MLM top layer.
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 | {% endraw %}
101 |
102 | {% raw %}
103 |
104 |
105 |
106 |
107 | {% endraw %}
108 |
109 |
110 |
111 |
Get or make label encoder function
112 |
113 |
114 |
115 | {% raw %}
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
masklm_get_or_make_label_encoder_fn(params :BaseParams , problem :str, mode :str, label_list :List[str], *args , **kwargs )
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 | {% endraw %}
138 |
139 | {% raw %}
140 |
141 |
142 |
143 |
144 | {% endraw %}
145 |
146 |
147 |
148 |
Label handling function
149 |
150 |
151 |
152 | {% raw %}
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
masklm_label_handling_fn[source] masklm_label_handling_fn(target , label_encoder =None , tokenizer =None , decoding_length =None , *args , **kwargs )
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 | {% endraw %}
175 |
176 | {% raw %}
177 |
178 |
179 |
180 |
181 | {% endraw %}
182 |
183 |
184 |
185 |
186 |
--------------------------------------------------------------------------------