'
13 |
14 | hr_faded: '
'
15 | hr_shaded: '
'
--------------------------------------------------------------------------------
/archive_nbs/single_profile.py:
--------------------------------------------------------------------------------
# Profile a single fastai `Flip` batch transform on TPU and dump the
# torch_xla metrics report afterwards (to spot ops that fall back to CPU).
import torch_xla.debug.metrics as met

import fastai_xla_extensions.core
from fastai2.vision.all import *
from my_timesaver_utils.profiling import *

# PETS images; BASE_PATH makes `path.ls()` print paths relative to it.
path = untar_data(URLs.PETS)/'images'
Path.BASE_PATH = path; path.ls()
print(f'running on default_device() & cuda is {torch.cuda.is_available()}')

# Load one sample image and resize it to 200x200.
img = PILImage.create(path/'Abyssinian_1.jpg')
resize = Resize(size=200)
img2 = resize(img,split_idx=0)  # split_idx=0: training-split behavior of the transform

# Channels-first float tensor scaled to [0, 1].
timg2 = TensorImage(array(img2)).permute(2,0,1).float()/255.

# Expand the single image into a batch of `bs` copies on `device`
# (expand creates a view, so no per-image copy is made here).
def batch_ex(bs, device): return TensorImage(timg2[None].to(device).expand(bs, *timg2.shape))

b768_img = batch_ex(768, default_device()); (b768_img.shape, b768_img.device)

flip_tfm = Flip(p=1.0)  # p=1.0 so every call actually flips
# run without profile
run_with_profile = True
# Monkeypatch grid_sample (the op Flip ultimately calls) so its calls are timed.
F.grid_sample = profile_call(F.grid_sample) if run_with_profile else F.grid_sample

@profile_call
def mtest(b_img):
    #set_trace()
    new_b_img = flip_tfm(b_img)
    return new_b_img

clear_prof_data()
print("--- 10 image tensor loops:")
for i in range(10):
    print("--- ---------------------------------")
    new_b768_img = mtest(b768_img)
    print("--- ")
# NOTE(review): indentation was lost in the source dump — print_prof_data()
# is assumed to run once after the loop (summary), not per iteration; confirm.
print_prof_data()

print(met.metrics_report())
--------------------------------------------------------------------------------
/docs/licenses/LICENSE:
--------------------------------------------------------------------------------
1 | /* This license pertains to the docs template, except for the Navgoco jQuery component. */
2 |
3 | The MIT License (MIT)
4 |
5 | Original theme: Copyright (c) 2016 Tom Johnson
6 | Modifications: Copyright (c) 2017 onwards fast.ai, Inc
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 |
--------------------------------------------------------------------------------
/docs/feed.xml:
--------------------------------------------------------------------------------
1 | ---
2 | search: exclude
3 | layout: none
4 | ---
5 |
6 |
7 |
8 |
9 | {{ site.title | xml_escape }}
10 | {{ site.description | xml_escape }}
11 | {{ site.url }}/
12 |
13 | {{ site.time | date_to_rfc822 }}
14 | {{ site.time | date_to_rfc822 }}
15 | Jekyll v{{ jekyll.version }}
16 | {% for post in site.posts limit:10 %}
17 | -
18 |
{{ post.title | xml_escape }}
19 | {{ post.content | xml_escape }}
20 | {{ post.date | date_to_rfc822 }}
21 | {{ post.url | prepend: site.url }}
22 | {{ post.url | prepend: site.url }}
23 | {% for tag in post.tags %}
24 | {{ tag | xml_escape }}
25 | {% endfor %}
26 | {% for tag in page.tags %}
27 | {{ tag | xml_escape }}
28 | {% endfor %}
29 |
30 | {% endfor %}
31 |
32 |
33 |
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | repository: butchland/fastai_xla_extensions
2 | output: web
3 | topnav_title: fastai_xla_extensions
4 | site_title: fastai_xla_extensions
5 | company_name: Butch Landingin
6 | description: A library to extend fastai to run on TPUs using pytorch-xla
7 | # Set to false to disable KaTeX math
8 | use_math: true
9 | # Add Google analytics id if you have one and want to use it here
10 | google_analytics:
11 | # See http://nbdev.fast.ai/search for help with adding Search
12 | google_search:
13 |
14 | host: 127.0.0.1
15 | # the preview server used. Leave as is.
16 | port: 4000
17 | # the port where the preview is rendered.
18 |
19 | exclude:
20 |   - .idea/
21 |   - .gitignore
22 |   - vendor
23 |
24 | # note: the duplicate `exclude: [vendor]` key was merged into the list above
25 | # (YAML duplicate keys are last-wins, so it silently discarded the first list)
26 | highlighter: rouge
27 | markdown: kramdown
28 | kramdown:
29 | input: GFM
30 | auto_ids: true
31 | hard_wrap: false
32 | syntax_highlighter: rouge
33 |
34 | collections:
35 | tooltips:
36 | output: false
37 |
38 | defaults:
39 | -
40 | scope:
41 | path: ""
42 | type: "pages"
43 | values:
44 | layout: "page"
45 | comments: true
46 | search: true
47 | sidebar: home_sidebar
48 | topnav: topnav
49 | -
50 | scope:
51 | path: ""
52 | type: "tooltips"
53 | values:
54 | layout: "page"
55 | comments: true
56 | search: true
57 | tooltip: true
58 |
59 | sidebars:
60 | - home_sidebar
61 | permalink: pretty
62 |
63 | theme: jekyll-theme-cayman
64 | baseurl: /fastai_xla_extensions/
--------------------------------------------------------------------------------
/docs/_data/sidebars/home_sidebar.yml:
--------------------------------------------------------------------------------
1 |
2 | #################################################
3 | ### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ###
4 | #################################################
5 | # Instead edit ../../sidebar.json
6 | entries:
7 | - folders:
8 | - folderitems:
9 | - output: web,pdf
10 | title: Overview
11 | url: /
12 | - output: web,pdf
13 | title: Core XLA extensions
14 | url: core
15 | - output: web,pdf
16 | title: Utils
17 | url: utils
18 | - output: web,pdf
19 | title: CIFAR Loader
20 | url: cifar_loader
21 | - output: web,pdf
22 | title: Miscellaneous Utilities
23 | url: misc_utils
24 | - output: web,pdf
25 | title: 'Multi Core XLA Base '
26 | url: multi_core.base
27 | - output: web,pdf
28 | title: Torch Compatible Utilities
29 | url: multi_core.torch_compat
30 | - output: web,pdf
31 | title: Multi Core XLA Learner extensions
32 | url: multi_core.learner
33 | - output: web,pdf
34 | title: Multi Core Callback XLA Extensions
35 | url: multi_core.callback
36 | - output: web,pdf
37 | title: Multi Core LR Find XLA Extensions
38 | url: multi_core.lr_find
39 | - output: web,pdf
40 | title: 'Multi Core XLA Inference '
41 | url: multi_core.inference
42 | - output: web,pdf
43 | title: Development Setup
44 | url: dev_setup
45 | output: web
46 | title: fastai_xla_extensions
47 | output: web
48 | title: Sidebar
49 |
--------------------------------------------------------------------------------
/fastai_xla_extensions/utils.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01_utils.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['xla_imported', 'print_aten_ops']
4 |
5 | # Internal Cell
6 | try:
7 | import torch_xla
8 | except ImportError:
9 | pass
10 |
11 | # Cell
12 | import sys
13 |
def xla_imported():
    """Report whether the optional `torch_xla` module was imported successfully."""
    loaded_modules = sys.modules
    return 'torch_xla' in loaded_modules
17 |
18 | # Internal Cell
19 | if xla_imported():
20 | import torch_xla.debug.metrics as met
21 |
22 | # Cell
def print_aten_ops():
    "print out xla aten operations (from xla debug metrics report `torch_xla.debug.metrics`)"
    # Delegate the parsing to a pure helper so it can be tested without XLA.
    for line in aten_lowering_lines(met.metrics_report()):
        print(line)

def aten_lowering_lines(report):
    """Extract the lines of an XLA metrics report that flag ``aten::`` ops.

    A metric line mentioning ``aten::`` marks an op that has no XLA lowering
    (it fell back to CPU).  For each such line this emits a
    ``needs lowering: <line>`` marker followed by the report line immediately
    after it (its detail line), preserving report order.

    Parameters
    ----------
    report : str
        Text produced by ``torch_xla.debug.metrics.metrics_report()``.

    Returns
    -------
    list of str
        Lines ready to print; empty when no ``aten::`` entries appear.

    Notes
    -----
    The original guard was ``if out.find("aten::"):`` — ``str.find`` returns
    -1 (truthy) on a miss and 0 (falsy) only when the report *starts* with
    ``aten::``, i.e. it skipped exactly the cases it should have printed.
    A substring membership test fixes that; the unused ``Capturing`` /
    ``StringIO`` machinery was dead code and has been removed.
    """
    out = []
    emit_next = False  # the line right after an aten:: hit carries its details
    for line in report.split("\n"):
        if emit_next:
            emit_next = False
            out.append(line)
        if "aten::" in line:
            out.append("needs lowering: " + line)
            emit_next = True
    return out
--------------------------------------------------------------------------------
/docs/_includes/links.html:
--------------------------------------------------------------------------------
1 | {% comment %}Get links from each sidebar, as listed in the _config.yml file under sidebars{% endcomment %}
2 |
3 | {% for sidebar in site.sidebars %}
4 | {% for entry in site.data.sidebars[sidebar].entries %}
5 | {% for folder in entry.folders %}
6 | {% for folderitem in folder.folderitems %}
7 | {% if folderitem.url contains "html#" %}
8 | [{{folderitem.url | remove: "/" }}]: {{folderitem.url | remove: "/"}}
9 | {% else %}
10 | [{{folderitem.url | remove: "/" | remove: ".html"}}]: {{folderitem.url | remove: "/"}}
11 | {% endif %}
12 | {% for subfolders in folderitem.subfolders %}
13 | {% for subfolderitem in subfolders.subfolderitems %}
14 | [{{subfolderitem.url | remove: "/" | remove: ".html"}}]: {{subfolderitem.url | remove: "/"}}
15 | {% endfor %}
16 | {% endfor %}
17 | {% endfor %}
18 | {% endfor %}
19 | {% endfor %}
20 | {% endfor %}
21 |
22 |
23 | {% comment %} Get links from topnav {% endcomment %}
24 |
25 | {% for entry in site.data.topnav.topnav %}
26 | {% for item in entry.items %}
27 | {% if item.external_url == null %}
28 | [{{item.url | remove: "/" | remove: ".html"}}]: {{item.url | remove: "/"}}
29 | {% endif %}
30 | {% endfor %}
31 | {% endfor %}
32 |
33 | {% comment %}Get links from topnav dropdowns {% endcomment %}
34 |
35 | {% for entry in site.data.topnav.topnav_dropdowns %}
36 | {% for folder in entry.folders %}
37 | {% for folderitem in folder.folderitems %}
38 | {% if folderitem.external_url == null %}
39 | [{{folderitem.url | remove: "/" | remove: ".html"}}]: {{folderitem.url | remove: "/"}}
40 | {% endif %}
41 | {% endfor %}
42 | {% endfor %}
43 | {% endfor %}
44 |
45 |
--------------------------------------------------------------------------------
/docs/css/modern-business.css:
--------------------------------------------------------------------------------
1 | /*!
2 | * Start Bootstrap - Modern Business HTML Template (http://startbootstrap.com)
3 | * Code licensed under the Apache License v2.0.
4 | * For details, see http://www.apache.org/licenses/LICENSE-2.0.
5 | */
6 |
7 | /* Global Styles */
8 |
9 | html,
10 | body {
11 | height: 100%;
12 | }
13 |
14 | .img-portfolio {
15 | margin-bottom: 30px;
16 | }
17 |
18 | .img-hover:hover {
19 | opacity: 0.8;
20 | }
21 |
22 | /* Home Page Carousel */
23 |
24 | header.carousel {
25 | height: 50%;
26 | }
27 |
28 | header.carousel .item,
29 | header.carousel .item.active,
30 | header.carousel .carousel-inner {
31 | height: 100%;
32 | }
33 |
34 | header.carousel .fill {
35 | width: 100%;
36 | height: 100%;
37 | background-position: center;
38 | background-size: cover;
39 | }
40 |
41 | /* 404 Page Styles */
42 |
43 | .error-404 {
44 | font-size: 100px;
45 | }
46 |
47 | /* Pricing Page Styles */
48 |
49 | .price {
50 | display: block;
51 | font-size: 50px;
52 | line-height: 50px;
53 | }
54 |
55 | .price sup {
56 | top: -20px;
57 | left: 2px;
58 | font-size: 20px;
59 | }
60 |
61 | .period {
62 | display: block;
63 | font-style: italic;
64 | }
65 |
66 | /* Footer Styles */
67 |
68 | footer {
69 | margin: 50px 0;
70 | }
71 |
72 | /* Responsive Styles */
73 |
74 | @media(max-width:991px) {
75 | .client-img,
76 | .img-related {
77 | margin-bottom: 30px;
78 | }
79 | }
80 |
81 | @media(max-width:767px) {
82 | .img-portfolio {
83 | margin-bottom: 15px;
84 | }
85 |
86 | header.carousel .carousel {
87 | height: 70%;
88 | }
89 | }
90 |
--------------------------------------------------------------------------------
/docs/_includes/head_print.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
{% if page.homepage == true %} {{site.homepage_title}} {% elsif page.title %}{{ page.title }}{% endif %} | {{ site.site_title }}
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
23 |
24 |
29 |
--------------------------------------------------------------------------------
/docs/licenses/LICENSE-BSD-NAVGOCO.txt:
--------------------------------------------------------------------------------
1 | /* This license pertains to the Navgoco jQuery component used for the sidebar. */
2 |
3 | Copyright (c) 2013, Christodoulos Tsoulloftas, http://www.komposta.net
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without modification,
7 | are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice,
10 | this list of conditions and the following disclaimer.
11 | * Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 | * Neither the name of the
nor the names of its
15 | contributors may be used to endorse or promote products derived from this
16 | software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
21 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
22 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
23 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
26 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
27 | OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/docs/_layouts/page.html:
--------------------------------------------------------------------------------
1 | ---
2 | layout: default
3 | ---
4 |
5 |
9 |
10 | {% if page.simple_map == true %}
11 |
12 |
17 |
18 | {% include custom/{{page.map_name}}.html %}
19 |
20 | {% elsif page.complex_map == true %}
21 |
22 |
27 |
28 | {% include custom/{{page.map_name}}.html %}
29 |
30 | {% endif %}
31 |
32 |
33 |
34 | {% if page.summary %}
35 |
{{page.summary}}
36 | {% endif %}
37 |
38 | {% unless page.toc == false %}
39 | {% include toc.html %}
40 | {% endunless %}
41 |
42 |
43 | {% if site.github_editme_path %}
44 |
45 |
Edit me
46 |
47 | {% endif %}
48 |
49 | {{content}}
50 |
51 |
62 |
63 |
64 |
65 | {{site.data.alerts.hr_shaded}}
66 |
67 | {% include footer.html %}
68 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""Package setup for fastai_xla_extensions.

All project metadata lives in settings.ini (nbdev convention) — edit there,
not here.
"""
from pkg_resources import parse_version
from configparser import ConfigParser
import setuptools

# nbdev packaging relies on setuptools features introduced in 36.2.
assert parse_version(setuptools.__version__)>=parse_version('36.2')

# note: all settings are in settings.ini; edit there, not here
config = ConfigParser(delimiters=['='])
config.read('settings.ini')
cfg = config['DEFAULT']

# Keys forwarded verbatim into setup(); `expected` is validated up front so a
# missing setting fails fast with a clear message instead of a KeyError later.
cfg_keys = 'version description keywords author author_email'.split()
expected = cfg_keys + "lib_name user branch license status min_python audience language".split()
for o in expected: assert o in cfg, "missing expected setting: {}".format(o)
setup_cfg = {o:cfg[o] for o in cfg_keys}

# Map the short license id from settings.ini to (license name, trove classifier).
licenses = {
    'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'),
}
# Indexed by the integer `status` setting (development maturity).
statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha',
    '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ]
py_versions = '2.0 2.1 2.2 2.3 2.4 2.5 2.6 2.7 3.0 3.1 3.2 3.3 3.4 3.5 3.6 3.7 3.8'.split()

requirements = cfg.get('requirements','').split()
lic = licenses[cfg['license']]
min_python = cfg['min_python']

# Read the long description with an explicit encoding and close the handle
# promptly (the original `open('README.md').read()` leaked the file object).
with open('README.md', encoding='utf-8') as readme_file:
    long_description = readme_file.read()

setuptools.setup(
    name = cfg['lib_name'],
    license = lic[0],
    classifiers = [
        'Development Status :: ' + statuses[int(cfg['status'])],
        'Intended Audience :: ' + cfg['audience'].title(),
        'License :: ' + lic[1],
        'Natural Language :: ' + cfg['language'].title(),
    # Advertise every Python version from min_python onward.
    ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]],
    url = cfg['git_url'],
    packages = setuptools.find_packages(),
    include_package_data = True,
    install_requires = requirements,
    dependency_links = cfg.get('dep_links','').split(),
    python_requires = '>=' + cfg['min_python'],
    long_description = long_description,
    long_description_content_type = 'text/markdown',
    zip_safe = False,
    entry_points = { 'console_scripts': cfg.get('console_scripts','').split() },
    **setup_cfg)
--------------------------------------------------------------------------------
/docs/js/customscripts.js:
--------------------------------------------------------------------------------
// Match the sidebar's height to the top nav's so the two columns line up.
$('#mysidebar').height($(".nav").height());


$( document ).ready(function() {

    //this script says, if the height of the viewport is greater than 800px, then insert affix class, which makes the nav bar float in a fixed
    // position as your scroll. if you have a lot of nav items, this height may not work for you.
    var h = $(window).height();
    //console.log (h);
    if (h > 800) {
        // Bootstrap's "affix" pins the sidebar while the page scrolls.
        $( "#mysidebar" ).attr("class", "nav affix");
    }
    // activate tooltips. although this is a bootstrap js function, it must be activated this way in your theme.
    $('[data-toggle="tooltip"]').tooltip({
        placement : 'top'
    });

    /**
     * AnchorJS — add clickable anchor links to all h2–h5 headings.
     */
    anchors.add('h2,h3,h4,h5');

});
24 |
// needed for nav tabs on pages. See Formatting > Nav tabs for more details.
// script from http://stackoverflow.com/questions/10523433/how-do-i-keep-the-current-tab-active-with-twitter-bootstrap-after-a-page-reload
$(function() {
    var json, tabsState;
    // Persist the tab a user activates, keyed by the id of the enclosing
    // ul.nav container, so it survives a page reload.
    $('a[data-toggle="pill"], a[data-toggle="tab"]').on('shown.bs.tab', function(e) {
        var href, json, parentId, tabsState;

        tabsState = localStorage.getItem("tabs-state");
        json = JSON.parse(tabsState || "{}");
        parentId = $(e.target).parents("ul.nav.nav-pills, ul.nav.nav-tabs").attr("id");
        href = $(e.target).attr('href');
        json[parentId] = href;

        return localStorage.setItem("tabs-state", JSON.stringify(json));
    });

    tabsState = localStorage.getItem("tabs-state");
    json = JSON.parse(tabsState || "{}");

    // Restore each container's saved tab. The href must be QUOTED inside the
    // attribute selector: saved hrefs like "#tab-1" contain characters that
    // are invalid in an unquoted jQuery selector and throw
    // "Syntax error, unrecognized expression" (silently breaking restore).
    $.each(json, function(containerId, href) {
        return $("#" + containerId + " a[href='" + href + "']").tab('show');
    });

    // Containers with no saved state fall back to their first tab.
    $("ul.nav.nav-pills, ul.nav.nav-tabs").each(function() {
        var $this = $(this);
        if (!json[$this.attr("id")]) {
            return $this.find("a[data-toggle=tab]:first, a[data-toggle=pill]:first").tab("show");
        }
    });
});
55 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.bak
2 | .gitattributes
3 | .last_checked
4 | .gitconfig
5 | *.bak
6 | *.log
7 | *~
8 | ~*
9 | _tmp*
10 | tmp*
11 | tags
12 |
13 | # Byte-compiled / optimized / DLL files
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | env/
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | .hypothesis/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # SageMath parsed files
92 | *.sage.py
93 |
94 | # dotenv
95 | .env
96 |
97 | # virtualenv
98 | .venv
99 | venv/
100 | ENV/
101 |
102 | # Spyder project settings
103 | .spyderproject
104 | .spyproject
105 |
106 | # Rope project settings
107 | .ropeproject
108 |
109 | # mkdocs documentation
110 | /site
111 |
112 | # mypy
113 | .mypy_cache/
114 |
115 | .vscode
116 | *.swp
117 |
118 | # osx generated files
119 | .DS_Store
120 | .DS_Store?
121 | .Trashes
122 | ehthumbs.db
123 | Thumbs.db
124 | .idea
125 |
126 | # pytest
127 | .pytest_cache
128 |
129 | # tools/trust-doc-nbs
130 | docs_src/.last_checked
131 |
132 | # symlinks to fastai
133 | docs_src/fastai
134 | tools/fastai
135 |
136 | # link checker
137 | checklink/cookies.txt
138 |
139 | # .gitconfig is now autogenerated
140 | .gitconfig
141 | # misc
142 | nbs/models
143 |
144 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | ## How to get started
4 |
5 | Before anything else, please install the git hooks that run automatic scripts during each commit and merge to strip the notebooks of superfluous metadata (and avoid merge conflicts). After cloning the repository, run the following command inside it:
6 | ```
7 | nbdev_install_git_hooks
8 | ```
9 |
10 | ## Did you find a bug?
11 |
12 | * Ensure the bug was not already reported by searching on GitHub under Issues.
13 | * If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring.
14 | * Be sure to add the complete error messages.
15 |
16 | #### Did you write a patch that fixes a bug?
17 |
18 | * Open a new GitHub pull request with the patch.
19 | * Ensure that your PR includes a test that fails without your patch, and passes with it.
20 | * Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable.
21 |
22 | ## PR submission guidelines
23 |
24 | * Keep each PR focused. While it's more convenient, do not combine several unrelated fixes together. Create as many branches as needed to keep each PR focused.
25 | * Do not mix style changes/fixes with "functional" changes. It's very difficult to review such PRs, and they will most likely be rejected.
26 | * Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can.
27 | * Do not turn an already submitted PR into your development playground. If after you submitted PR, you discovered that more work is needed - close the PR, do the required work and then submit a new PR. Otherwise each of your commits requires attention from maintainers of the project.
28 | * If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exception case where you realize it'll take many many commits to complete the requests, then it's probably best to close the PR, do the work and then submit it again. Use common sense where you'd choose one way over another.
29 |
30 | ## Do you want to contribute to the documentation?
31 |
32 | * Docs are automatically created from the notebooks in the nbs folder.
33 |
34 |
--------------------------------------------------------------------------------
/docs/css/theme-green.css:
--------------------------------------------------------------------------------
1 | .summary {
2 | color: #808080;
3 | border-left: 5px solid #E50E51;
4 | font-size:16px;
5 | }
6 |
7 |
8 | h3 {color: #E50E51; }
9 | h4 {color: #808080; }
10 |
11 | .nav-tabs > li.active > a, .nav-tabs > li.active > a:hover, .nav-tabs > li.active > a:focus {
12 | background-color: #248ec2;
13 | color: white;
14 | }
15 |
16 | .nav > li.active > a {
17 | background-color: #72ac4a;
18 | }
19 |
20 | .nav > li > a:hover {
21 | background-color: #72ac4a;
22 | }
23 |
24 | div.navbar-collapse .dropdown-menu > li > a:hover {
25 | background-color: #72ac4a;
26 | }
27 |
28 | .navbar-inverse .navbar-nav>li>a, .navbar-inverse .navbar-brand {
29 | color: white;
30 | }
31 |
32 | .navbar-inverse .navbar-nav>li>a:hover, a.fa.fa-home.fa-lg.navbar-brand:hover {
33 | color: #f0f0f0;
34 | }
35 |
36 | .nav li.thirdlevel > a {
37 | background-color: #FAFAFA !important;
38 | color: #72ac4a;
39 | font-weight: bold;
40 | }
41 |
42 | a[data-toggle="tooltip"] {
43 | color: #649345;
44 | font-style: italic;
45 | cursor: default;
46 | }
47 |
48 | .navbar-inverse {
49 | background-color: #72ac4a;
50 | border-color: #5b893c;
51 | }
52 |
53 | .navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open > a:hover, .navbar-inverse .navbar-nav > .open > a:focus {
54 | color: #5b893c;
55 | }
56 |
57 | .navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open > a:hover, .navbar-inverse .navbar-nav > .open > a:focus {
58 | background-color: #5b893c;
59 | color: #ffffff;
60 | }
61 |
62 | /* not sure if using this ...*/
63 | .navbar-inverse .navbar-collapse, .navbar-inverse .navbar-form {
64 | border-color: #72ac4a !important;
65 | }
66 |
67 | .btn-primary {
68 | color: #ffffff;
69 | background-color: #5b893c;
70 | border-color: #5b893c;
71 | }
72 |
73 | .btn-primary:hover,
74 | .btn-primary:focus,
75 | .btn-primary:active,
76 | .btn-primary.active,
77 | .open .dropdown-toggle.btn-primary {
78 | background-color: #72ac4a;
79 | border-color: #5b893c;
80 | }
81 |
82 | .printTitle {
83 | color: #5b893c !important;
84 | }
85 |
86 | body.print h1 {color: #5b893c !important; font-size:28px;}
87 | body.print h2 {color: #595959 !important; font-size:24px;}
88 | body.print h3 {color: #E50E51 !important; font-size:14px;}
89 | body.print h4 {color: #679DCE !important; font-size:14px; font-style: italic;}
90 |
91 | .anchorjs-link:hover {
92 | color: #4f7233;
93 | }
94 |
95 | div.sidebarTitle {
96 | color: #E50E51;
97 | }
98 |
99 | li.sidebarTitle {
100 | margin-top:20px;
101 | font-weight:normal;
102 | font-size:130%;
103 | color: #ED1951;
104 | margin-bottom:10px;
105 | margin-left: 5px;
106 | }
107 |
108 | .navbar-inverse .navbar-toggle:focus, .navbar-inverse .navbar-toggle:hover {
109 | background-color: #E50E51;
110 | }
111 |
--------------------------------------------------------------------------------
/docs/_includes/sidebar.html:
--------------------------------------------------------------------------------
1 | {% assign sidebar = site.data.sidebars[page.sidebar].entries %}
2 | {% assign pageurl = page.url | remove: ".html" %}
3 |
4 |
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
/docs/css/theme-blue.css:
--------------------------------------------------------------------------------
1 | .summary {
2 | color: #808080;
3 | border-left: 5px solid #ED1951;
4 | font-size:16px;
5 | }
6 |
7 |
8 | h3 {color: #000000; }
9 | h4 {color: #000000; }
10 |
11 | .nav-tabs > li.active > a, .nav-tabs > li.active > a:hover, .nav-tabs > li.active > a:focus {
12 | background-color: #248ec2;
13 | color: white;
14 | }
15 |
16 | .nav > li.active > a {
17 | background-color: #347DBE;
18 | }
19 |
20 | .nav > li > a:hover {
21 | background-color: #248ec2;
22 | }
23 |
24 | div.navbar-collapse .dropdown-menu > li > a:hover {
25 | background-color: #347DBE;
26 | }
27 |
28 | .nav li.thirdlevel > a {
29 | background-color: #FAFAFA !important;
30 | color: #248EC2;
31 | font-weight: bold;
32 | }
33 |
34 | a[data-toggle="tooltip"] {
35 | color: #649345;
36 | font-style: italic;
37 | cursor: default;
38 | }
39 |
40 | .navbar-inverse {
41 | background-color: #347DBE;
42 | border-color: #015CAE;
43 | }
44 | .navbar-inverse .navbar-nav>li>a, .navbar-inverse .navbar-brand {
45 | color: white;
46 | }
47 |
48 | .navbar-inverse .navbar-nav>li>a:hover, a.fa.fa-home.fa-lg.navbar-brand:hover {
49 | color: #f0f0f0;
50 | }
51 |
52 | a.navbar-brand:hover {
53 | color: #f0f0f0;
54 | }
55 |
56 | .navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open > a:hover, .navbar-inverse .navbar-nav > .open > a:focus {
57 | color: #015CAE;
58 | }
59 |
60 | .navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open > a:hover, .navbar-inverse .navbar-nav > .open > a:focus {
61 | background-color: #015CAE;
62 | color: #ffffff;
63 | }
64 |
65 | .navbar-inverse .navbar-collapse, .navbar-inverse .navbar-form {
66 | border-color: #248ec2 !important;
67 | }
68 |
69 | .btn-primary {
70 | color: #ffffff;
71 | background-color: #347DBE;
72 | border-color: #347DBE;
73 | }
74 |
75 | .navbar-inverse .navbar-nav > .active > a, .navbar-inverse .navbar-nav > .active > a:hover, .navbar-inverse .navbar-nav > .active > a:focus {
76 | background-color: #347DBE;
77 | }
78 |
79 | .btn-primary:hover,
80 | .btn-primary:focus,
81 | .btn-primary:active,
82 | .btn-primary.active,
83 | .open .dropdown-toggle.btn-primary {
84 | background-color: #248ec2;
85 | border-color: #347DBE;
86 | }
87 |
88 | .printTitle {
89 | color: #015CAE !important;
90 | }
91 |
92 | body.print h1 {color: #015CAE !important; font-size:28px !important;}
93 | body.print h2 {color: #595959 !important; font-size:20px !important;}
94 | body.print h3 {color: #E50E51 !important; font-size:14px !important;}
95 | body.print h4 {color: #679DCE !important; font-size:14px; font-style: italic !important;}
96 |
97 | .anchorjs-link:hover {
98 | color: #216f9b;
99 | }
100 |
101 | div.sidebarTitle {
102 | color: #015CAE;
103 | }
104 |
105 | li.sidebarTitle {
106 | margin-top:20px;
107 | font-weight:normal;
108 | font-size:130%;
109 | color: #ED1951;
110 | margin-bottom:10px;
111 | margin-left: 5px;
112 |
113 | }
114 |
115 | .navbar-inverse .navbar-toggle:focus, .navbar-inverse .navbar-toggle:hover {
116 | background-color: #015CAE;
117 | }
118 |
119 | .navbar-inverse .navbar-toggle {
120 | border-color: #015CAE;
121 | }
122 |
--------------------------------------------------------------------------------
/docs/_includes/topnav.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
13 |
14 |
15 |
16 | Nav
17 |
18 |
19 | {% assign topnav = site.data[page.topnav] %}
20 | {% assign topnav_dropdowns = site.data[page.topnav].topnav_dropdowns %}
21 |
22 | {% for entry in topnav.topnav %}
23 | {% for item in entry.items %}
24 | {% if item.external_url %}
25 | {{item.title}}
26 | {% elsif page.url contains item.url %}
27 | {{item.title}}
28 | {% else %}
29 | {{item.title}}
30 | {% endif %}
31 | {% endfor %}
32 | {% endfor %}
33 |
34 |
35 | {% for entry in topnav_dropdowns %}
36 | {% for folder in entry.folders %}
37 |
38 | {{ folder.title }}
39 |
50 |
51 | {% endfor %}
52 | {% endfor %}
53 | {% if site.google_search %}
54 |
55 | {% include search_google_custom.html %}
56 |
57 | {% endif %}
58 |
59 |
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/docs/js/jquery.navgoco.min.js:
--------------------------------------------------------------------------------
1 | /*
2 | * jQuery Navgoco Menus Plugin v0.2.1 (2014-04-11)
3 | * https://github.com/tefra/navgoco
4 | *
5 | * Copyright (c) 2014 Chris T (@tefra)
6 | * BSD - https://github.com/tefra/navgoco/blob/master/LICENSE-BSD
7 | */
8 | !function(a){"use strict";var b=function(b,c,d){return this.el=b,this.$el=a(b),this.options=c,this.uuid=this.$el.attr("id")?this.$el.attr("id"):d,this.state={},this.init(),this};b.prototype={init:function(){var b=this;b._load(),b.$el.find("ul").each(function(c){var d=a(this);d.attr("data-index",c),b.options.save&&b.state.hasOwnProperty(c)?(d.parent().addClass(b.options.openClass),d.show()):d.parent().hasClass(b.options.openClass)?(d.show(),b.state[c]=1):d.hide()});var c=a(" ").prepend(b.options.caretHtml),d=b.$el.find("li > a");b._trigger(c,!1),b._trigger(d,!0),b.$el.find("li:has(ul) > a").prepend(c)},_trigger:function(b,c){var d=this;b.on("click",function(b){b.stopPropagation();var e=c?a(this).next():a(this).parent().next(),f=!1;if(c){var g=a(this).attr("href");f=void 0===g||""===g||"#"===g}if(e=e.length>0?e:!1,d.options.onClickBefore.call(this,b,e),!c||e&&f)b.preventDefault(),d._toggle(e,e.is(":hidden")),d._save();else if(d.options.accordion){var h=d.state=d._parents(a(this));d.$el.find("ul").filter(":visible").each(function(){var b=a(this),c=b.attr("data-index");h.hasOwnProperty(c)||d._toggle(b,!1)}),d._save()}d.options.onClickAfter.call(this,b,e)})},_toggle:function(b,c){var d=this,e=b.attr("data-index"),f=b.parent();if(d.options.onToggleBefore.call(this,b,c),c){if(f.addClass(d.options.openClass),b.slideDown(d.options.slide),d.state[e]=1,d.options.accordion){var g=d.state=d._parents(b);g[e]=d.state[e]=1,d.$el.find("ul").filter(":visible").each(function(){var b=a(this),c=b.attr("data-index");g.hasOwnProperty(c)||d._toggle(b,!1)})}}else f.removeClass(d.options.openClass),b.slideUp(d.options.slide),d.state[e]=0;d.options.onToggleAfter.call(this,b,c)},_parents:function(b,c){var d={},e=b.parent(),f=e.parents("ul");return f.each(function(){var b=a(this),e=b.attr("data-index");return e?void(d[e]=c?b:1):!1}),d},_save:function(){if(this.options.save){var b={};for(var d in 
this.state)1===this.state[d]&&(b[d]=1);c[this.uuid]=this.state=b,a.cookie(this.options.cookie.name,JSON.stringify(c),this.options.cookie)}},_load:function(){if(this.options.save){if(null===c){var b=a.cookie(this.options.cookie.name);c=b?JSON.parse(b):{}}this.state=c.hasOwnProperty(this.uuid)?c[this.uuid]:{}}},toggle:function(b){var c=this,d=arguments.length;if(1>=d)c.$el.find("ul").each(function(){var d=a(this);c._toggle(d,b)});else{var e,f={},g=Array.prototype.slice.call(arguments,1);d--;for(var h=0;d>h;h++){e=g[h];var i=c.$el.find('ul[data-index="'+e+'"]').first();if(i&&(f[e]=i,b)){var j=c._parents(i,!0);for(var k in j)f.hasOwnProperty(k)||(f[k]=j[k])}}for(e in f)c._toggle(f[e],b)}c._save()},destroy:function(){a.removeData(this.$el),this.$el.find("li:has(ul) > a").unbind("click"),this.$el.find("li:has(ul) > a > span").unbind("click")}},a.fn.navgoco=function(c){if("string"==typeof c&&"_"!==c.charAt(0)&&"init"!==c)var d=!0,e=Array.prototype.slice.call(arguments,1);else c=a.extend({},a.fn.navgoco.defaults,c||{}),a.cookie||(c.save=!1);return this.each(function(f){var g=a(this),h=g.data("navgoco");h||(h=new b(this,d?a.fn.navgoco.defaults:c,f),g.data("navgoco",h)),d&&h[c].apply(h,e)})};var c=null;a.fn.navgoco.defaults={caretHtml:"",accordion:!1,openClass:"open",save:!0,cookie:{name:"navgoco",expires:!1,path:"/"},slide:{duration:400,easing:"swing"},onClickBefore:a.noop,onClickAfter:a.noop,onToggleBefore:a.noop,onToggleAfter:a.noop}}(jQuery);
--------------------------------------------------------------------------------
/docs/misc_utils.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Miscellaneous Utilities
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 |
10 |
11 | nb_path: "nbs/02b_misc_utils.ipynb"
12 | ---
13 |
22 |
23 |
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
_BaseOptimizer patches
42 |
43 |
44 |
45 | {% raw %}
46 |
47 |
48 |
49 |
50 | {% endraw %}
51 |
52 | {% raw %}
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
_BaseOptimizer.__getstate__()
64 |
65 |
Pickling opt state should include param_groups and defaults
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
_BaseOptimizer.__setstate__(data )
76 |
77 |
Pickling opt state should include param_groups and defaults
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 | {% endraw %}
88 |
89 |
90 |
91 |
Patch the fastai.optimizer._BaseOptimizer __getstate__ and __setstate__ methods which are used in pickling fastai optimizers.
92 |
This should fix the bug where running the learner on multiple TPU cores on XLA triggers an error in which the method _fetch_gradients(optimizer) fails in the statement for param_group in optimizer.__getstate__()['param_groups']: in the torch_xla.core.xla_model module.
93 |
The patch modifies the copy constructor to include the param_groups and defaults.
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
--------------------------------------------------------------------------------
/docs/_includes/initialize_shuffle.html:
--------------------------------------------------------------------------------
1 |
7 |
8 |
100 |
101 |
102 |
103 |
114 |
115 |
129 |
130 |
131 |
--------------------------------------------------------------------------------
/docs/css/printstyles.css:
--------------------------------------------------------------------------------
1 |
2 | /*body.print .container {max-width: 650px;}*/
3 |
4 | body {
5 | font-size:14px;
6 | }
7 | .nav ul li a {border-top:0px; background-color:transparent; color: #808080; }
8 | #navig a[href] {color: #595959 !important;}
9 | table .table {max-width:650px;}
10 |
11 | #navig li.sectionHead {font-weight: bold; font-size: 18px; color: #595959 !important; }
12 | #navig li {font-weight: normal; }
13 |
14 | #navig a[href]::after { content: leader(".") target-counter(attr(href), page); }
15 |
16 | a[href]::after {
17 | content: " (page " target-counter(attr(href), page) ")"
18 | }
19 |
20 | a[href^="http:"]::after, a[href^="https:"]::after {
21 | content: "";
22 | }
23 |
24 | a[href] {
25 | color: blue !important;
26 | }
27 | a[href*="mailto"]::after, a[data-toggle="tooltip"]::after, a[href].noCrossRef::after {
28 | content: "";
29 | }
30 |
31 |
32 | @page {
33 | margin: 60pt 90pt 60pt 90pt;
34 | font-family: sans-serif;
35 | font-style:none;
36 | color: gray;
37 |
38 | }
39 |
40 | .printTitle {
41 | line-height:30pt;
42 | font-size:27pt;
43 | font-weight: bold;
44 | letter-spacing: -.5px;
45 | margin-bottom:25px;
46 | }
47 |
48 | .printSubtitle {
49 | font-size: 19pt;
50 | color: #cccccc !important;
51 | font-family: "Grotesque MT Light";
52 | line-height: 22pt;
53 | letter-spacing: -.5px;
54 | margin-bottom:20px;
55 | }
56 | .printTitleArea hr {
57 | color: #999999 !important;
58 | height: 2px;
59 | width: 100%;
60 | }
61 |
62 | .printTitleImage {
63 | max-width:300px;
64 | margin-bottom:200px;
65 | }
66 |
67 |
68 | .printTitleImage {
69 | max-width: 250px;
70 | }
71 |
72 | #navig {
73 | /*page-break-before: always;*/
74 | }
75 |
76 | .copyrightBoilerplate {
77 | page-break-before:always;
78 | font-size:14px;
79 | }
80 |
81 | .lastGeneratedDate {
82 | font-style: italic;
83 | font-size:14px;
84 | color: gray;
85 | }
86 |
87 | .alert a {
88 | text-decoration: none !important;
89 | }
90 |
91 |
92 | body.title { page: title }
93 |
94 | @page title {
95 | @top-left {
96 | content: " ";
97 | }
98 | @top-right {
99 | content: " "
100 | }
101 | @bottom-right {
102 | content: " ";
103 | }
104 | @bottom-left {
105 | content: " ";
106 | }
107 | }
108 |
109 | body.frontmatter { page: frontmatter }
110 | body.frontmatter {counter-reset: page 1}
111 |
112 |
113 | @page frontmatter {
114 | @top-left {
115 | content: prince-script(guideName);
116 | }
117 | @top-right {
118 | content: prince-script(datestamp);
119 | }
120 | @bottom-right {
121 | content: counter(page, lower-roman);
122 | }
123 | @bottom-left {
124 | content: "youremail@domain.com"; }
125 | }
126 |
127 | body.first_page {counter-reset: page 1}
128 |
129 | h1 { string-set: doctitle content() }
130 |
131 | @page {
132 | @top-left {
133 | content: string(doctitle);
134 | font-size: 11px;
135 | font-style: italic;
136 | }
137 | @top-right {
138 | content: prince-script(datestamp);
139 | font-size: 11px;
140 | }
141 |
142 | @bottom-right {
143 | content: "Page " counter(page);
144 | font-size: 11px;
145 | }
146 | @bottom-left {
147 | content: prince-script(guideName);
148 | font-size: 11px;
149 | }
150 | }
151 | .alert {
152 | background-color: #fafafa !important;
153 | border-color: #dedede !important;
154 | color: black;
155 | }
156 |
157 | pre {
158 | background-color: #fafafa;
159 | }
160 |
--------------------------------------------------------------------------------
/docs/css/syntax.css:
--------------------------------------------------------------------------------
1 | .highlight { background: #ffffff; }
2 | .highlight .c { color: #999988; font-style: italic } /* Comment */
3 | .highlight .err { color: #a61717; background-color: #e3d2d2 } /* Error */
4 | .highlight .k { font-weight: bold } /* Keyword */
5 | .highlight .o { font-weight: bold } /* Operator */
6 | .highlight .cm { color: #999988; font-style: italic } /* Comment.Multiline */
7 | .highlight .cp { color: #999999; font-weight: bold } /* Comment.Preproc */
8 | .highlight .c1 { color: #999988; font-style: italic } /* Comment.Single */
9 | .highlight .cs { color: #999999; font-weight: bold; font-style: italic } /* Comment.Special */
10 | .highlight .gd { color: #000000; background-color: #ffdddd } /* Generic.Deleted */
11 | .highlight .gd .x { color: #000000; background-color: #ffaaaa } /* Generic.Deleted.Specific */
12 | .highlight .ge { font-style: italic } /* Generic.Emph */
13 | .highlight .gr { color: #aa0000 } /* Generic.Error */
14 | .highlight .gh { color: #999999 } /* Generic.Heading */
15 | .highlight .gi { color: #000000; background-color: #ddffdd } /* Generic.Inserted */
16 | .highlight .gi .x { color: #000000; background-color: #aaffaa } /* Generic.Inserted.Specific */
17 | .highlight .go { color: #888888 } /* Generic.Output */
18 | .highlight .gp { color: #555555 } /* Generic.Prompt */
19 | .highlight .gs { font-weight: bold } /* Generic.Strong */
20 | .highlight .gu { color: #aaaaaa } /* Generic.Subheading */
21 | .highlight .gt { color: #aa0000 } /* Generic.Traceback */
22 | .highlight .kc { font-weight: bold } /* Keyword.Constant */
23 | .highlight .kd { font-weight: bold } /* Keyword.Declaration */
24 | .highlight .kp { font-weight: bold } /* Keyword.Pseudo */
25 | .highlight .kr { font-weight: bold } /* Keyword.Reserved */
26 | .highlight .kt { color: #445588; font-weight: bold } /* Keyword.Type */
27 | .highlight .m { color: #009999 } /* Literal.Number */
28 | .highlight .s { color: #d14 } /* Literal.String */
29 | .highlight .na { color: #008080 } /* Name.Attribute */
30 | .highlight .nb { color: #0086B3 } /* Name.Builtin */
31 | .highlight .nc { color: #445588; font-weight: bold } /* Name.Class */
32 | .highlight .no { color: #008080 } /* Name.Constant */
33 | .highlight .ni { color: #800080 } /* Name.Entity */
34 | .highlight .ne { color: #990000; font-weight: bold } /* Name.Exception */
35 | .highlight .nf { color: #990000; font-weight: bold } /* Name.Function */
36 | .highlight .nn { color: #555555 } /* Name.Namespace */
37 | .highlight .nt { color: #000080 } /* Name.Tag */
38 | .highlight .nv { color: #008080 } /* Name.Variable */
39 | .highlight .ow { font-weight: bold } /* Operator.Word */
40 | .highlight .w { color: #bbbbbb } /* Text.Whitespace */
41 | .highlight .mf { color: #009999 } /* Literal.Number.Float */
42 | .highlight .mh { color: #009999 } /* Literal.Number.Hex */
43 | .highlight .mi { color: #009999 } /* Literal.Number.Integer */
44 | .highlight .mo { color: #009999 } /* Literal.Number.Oct */
45 | .highlight .sb { color: #d14 } /* Literal.String.Backtick */
46 | .highlight .sc { color: #d14 } /* Literal.String.Char */
47 | .highlight .sd { color: #d14 } /* Literal.String.Doc */
48 | .highlight .s2 { color: #d14 } /* Literal.String.Double */
49 | .highlight .se { color: #d14 } /* Literal.String.Escape */
50 | .highlight .sh { color: #d14 } /* Literal.String.Heredoc */
51 | .highlight .si { color: #d14 } /* Literal.String.Interpol */
52 | .highlight .sx { color: #d14 } /* Literal.String.Other */
53 | .highlight .sr { color: #009926 } /* Literal.String.Regex */
54 | .highlight .s1 { color: #d14 } /* Literal.String.Single */
55 | .highlight .ss { color: #990073 } /* Literal.String.Symbol */
56 | .highlight .bp { color: #999999 } /* Name.Builtin.Pseudo */
57 | .highlight .vc { color: #008080 } /* Name.Variable.Class */
58 | .highlight .vg { color: #008080 } /* Name.Variable.Global */
59 | .highlight .vi { color: #008080 } /* Name.Variable.Instance */
60 | .highlight .il { color: #009999 } /* Literal.Number.Integer.Long */
--------------------------------------------------------------------------------
/docs/js/toc.js:
--------------------------------------------------------------------------------
1 | // https://github.com/ghiculescu/jekyll-table-of-contents
2 | // this library modified by fastai to:
3 | // - update the location.href with the correct anchor when a toc item is clicked on
4 | (function($){
5 | $.fn.toc = function(options) {
6 | var defaults = {
7 | noBackToTopLinks: false,
8 | title: '',
9 | minimumHeaders: 3,
10 | headers: 'h1, h2, h3, h4',
11 | listType: 'ol', // values: [ol|ul]
12 | showEffect: 'show', // values: [show|slideDown|fadeIn|none]
13 | showSpeed: 'slow' // set to 0 to deactivate effect
14 | },
15 | settings = $.extend(defaults, options);
16 |
17 | var headers = $(settings.headers).filter(function() {
18 | // get all headers with an ID
19 | var previousSiblingName = $(this).prev().attr( "name" );
20 | if (!this.id && previousSiblingName) {
21 | this.id = $(this).attr( "id", previousSiblingName.replace(/\./g, "-") );
22 | }
23 | return this.id;
24 | }), output = $(this);
25 | if (!headers.length || headers.length < settings.minimumHeaders || !output.length) {
26 | return;
27 | }
28 |
29 | if (0 === settings.showSpeed) {
30 | settings.showEffect = 'none';
31 | }
32 |
33 | var render = {
34 | show: function() { output.hide().html(html).show(settings.showSpeed); },
35 | slideDown: function() { output.hide().html(html).slideDown(settings.showSpeed); },
36 | fadeIn: function() { output.hide().html(html).fadeIn(settings.showSpeed); },
37 | none: function() { output.html(html); }
38 | };
39 |
40 | var get_level = function(ele) { return parseInt(ele.nodeName.replace("H", ""), 10); }
41 | var highest_level = headers.map(function(_, ele) { return get_level(ele); }).get().sort()[0];
42 | //var return_to_top = ' ';
43 | // other nice icons that can be used instead: glyphicon-upload glyphicon-hand-up glyphicon-chevron-up glyphicon-menu-up glyphicon-triangle-top
44 | var level = get_level(headers[0]),
45 | this_level,
46 | html = settings.title + " <"+settings.listType+">";
47 | headers.on('click', function() {
48 | if (!settings.noBackToTopLinks) {
49 | var pos = $(window).scrollTop();
50 | window.location.hash = this.id;
51 | $(window).scrollTop(pos);
52 | }
53 | })
54 | .addClass('clickable-header')
55 | .each(function(_, header) {
56 | base_url = window.location.href;
57 | base_url = base_url.replace(/#.*$/, "");
58 | this_level = get_level(header);
59 | //if (!settings.noBackToTopLinks && this_level > 1) {
60 | // $(header).addClass('top-level-header').before(return_to_top);
61 | //}
62 | txt = header.textContent.split('¶')[0].split(/\[(test|source)\]/)[0];
63 | if (!txt) {return;}
64 | if (this_level === level) // same level as before; same indenting
65 | html += "" + txt + " ";
66 | else if (this_level <= level){ // higher level than before; end parent ol
67 | for(i = this_level; i < level; i++) {
68 | html += " "+settings.listType+">"
69 | }
70 | html += "" + txt + " ";
71 | }
72 | else if (this_level > level) { // lower level than before; expand the previous to contain a ol
73 | for(i = this_level; i > level; i--) {
74 | html += "<"+settings.listType+">"+((i-level == 2) ? "" : " ")
75 | }
76 | html += "" + txt + " ";
77 | }
78 | level = this_level; // update for the next one
79 | });
80 | html += ""+settings.listType+">";
81 | if (!settings.noBackToTopLinks) {
82 | $(document).on('click', '.back-to-top', function() {
83 | $(window).scrollTop(0);
84 | window.location.hash = '';
85 | });
86 | }
87 |
88 | render[settings.showEffect]();
89 | };
90 | })(jQuery);
91 |
--------------------------------------------------------------------------------
/docs/_layouts/default.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | {% include head.html %}
5 |
41 |
46 |
57 | {% if page.datatable == true %}
58 |
59 |
60 |
61 |
66 |
76 | {% endif %}
77 |
78 |
79 |
80 | {% include topnav.html %}
81 |
82 |
83 |
84 |
85 |
86 | {% assign content_col_size = "col-md-12" %}
87 | {% unless page.hide_sidebar %}
88 |
89 |
92 | {% assign content_col_size = "col-md-9" %}
93 | {% endunless %}
94 |
95 |
96 |
97 | {{content}}
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 | {% if site.google_analytics %}
108 | {% include google_analytics.html %}
109 | {% endif %}
110 |
111 |
--------------------------------------------------------------------------------
/fastai_xla_extensions/multi_core/callback.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03c_multi_core.callback.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['maybe_item']
4 |
5 | # Internal Cell
6 | from ..utils import xla_imported
7 | from ..misc_utils import *
8 | from .base import *
9 | # from fastai_xla_extensions.multi_core.learner import *
10 |
11 | # Internal Cell
12 | try:
13 | import torch_xla
14 | except:
15 | pass
16 |
17 | # Internal Cell
18 | if xla_imported():
19 | import torch_xla.core.xla_model as xm
20 | import torch_xla.distributed.xla_multiprocessing as xmp
21 |
22 | # Internal Cell
23 | # from fastai.vision.all import *
24 |
25 |
26 | # Cell
27 | import torch
28 | from fastcore.xtras import is_listy
def maybe_item(o):
    """Recursively extract plain scalar values from tensors (pulling them
    off the gpu/tpu onto the cpu), descending into list-like containers
    and dicts; any other value passes through unchanged."""
    if isinstance(o, torch.Tensor):
        return o.item()
    if is_listy(o):
        # rebuild with the same container class so e.g. tuples stay tuples
        return o.__class__([maybe_item(x) for x in o])
    if isinstance(o, dict):
        return {key: maybe_item(val) for key, val in o.items()}
    # plain scalar or arbitrary object: keep as-is
    return o
42 |
43 |
44 | # Cell
45 | from fastai.learner import Recorder
46 | from fastcore.basics import patch
47 |
@patch
def get_extra_attrs(self:Recorder):
    "Collect the Recorder state attrs (e.g. lrs, iters, losses, values) into a dict suitable for pickling"
    # maybe_item strips any tensors down to plain scalars so the dict pickles cleanly
    return {attr: maybe_item(getattr(self, attr))
            for attr in self._stateattrs if hasattr(self, attr)}
58 |
59 |
60 | # Cell
61 | import pickle
62 | from fastai.learner import Recorder
63 | from fastcore.basics import patch
64 |
@patch
def dump_attrs(self:Recorder, fn='_rec_attr.pkl'):
    'Pickle the Recorder state attrs to file `fn`'
    with open(fn, 'wb') as f:
        pickle.dump(self.get_extra_attrs(), f)
71 |
72 |
73 | # Cell
74 | import pickle
75 | from fastai.learner import Recorder
76 | from fastcore.basics import patch
77 | from pathlib import Path
78 |
@patch
def reload_attrs(self:Recorder, fn='_rec_attr.pkl'):
    'Restore state attrs previously dumped to file `fn`, deleting the file afterwards'
    fn = Path(fn) if isinstance(fn, str) else fn
    if not fn.is_file():
        return  # nothing was dumped; silently skip
    with open(fn, 'rb') as f:
        attrs = pickle.load(f)
    for name, value in attrs.items():
        setattr(self, name, value)
    # one-shot handoff file: remove once consumed
    fn.unlink()
91 |
92 | # Cell
93 | from fastai.learner import Recorder
94 | from fastcore.basics import patch
95 |
@patch
def after_fit(self: Recorder):
    'After fit, dump the extra state attrs to a file (rank-0 xla spawned process only)'
    if not getattr(self.learn, 'inner_xla', False):
        return  # not running inside an xla spawn
    if self.learn.xla_rank != 0:
        return  # only the master process writes the handoff file
    self.dump_attrs()
101 |
102 |
103 | # Cell
104 | from fastai.callback.schedule import ParamScheduler
105 | from fastcore.basics import patch
106 | from pathlib import Path
107 | import pickle
108 |
@patch
def dump_hps(self:ParamScheduler, fn='_paramsched_hps.pkl'):
    'Pickle the scheduler hyperparameter history `hps` to file `fn`'
    if not hasattr(self, 'hps'):
        return  # scheduler recorded nothing; skip

    target = Path(fn) if isinstance(fn, str) else fn
    # strip tensors down to plain scalars so the pickle loads anywhere
    payload = maybe_item(self.hps)
    with open(target, 'wb') as f:
        pickle.dump(payload, f)
121 |
122 |
123 | # Cell
124 | from fastai.learner import Recorder
125 | from fastcore.basics import patch
126 | from pathlib import Path
127 |
@patch
def reload_hps(self:Recorder, fn='_paramsched_hps.pkl'):
    'Load hyperparameters saved by ParamScheduler into the recorder, deleting the file afterwards'
    fn = Path(fn) if isinstance(fn, str) else fn
    if not fn.is_file():
        return  # nothing was dumped; silently skip
    with open(fn, 'rb') as f:
        self.hps = pickle.load(f)
    # one-shot handoff file: remove once consumed
    fn.unlink()
139 |
140 | # Cell
141 | from fastai.callback.schedule import ParamScheduler
142 | from fastcore.basics import patch
143 |
@patch
def after_fit(self:ParamScheduler):
    "After fit, copy `hps` onto the recorder and (rank-0 xla process only) dump them to file"
    if not hasattr(self, 'hps'):
        return  # scheduler recorded nothing

    if hasattr(self.learn, 'recorder'):
        self.recorder.hps = self.hps

    # only the master process inside an xla spawn writes the handoff file
    if getattr(self.learn, 'inner_xla', False) and self.learn.xla_rank == 0:
        self.dump_hps()
155 |
--------------------------------------------------------------------------------
/docs/_includes/head.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | {{ page.title }} | {{ site.site_title }}
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | {% if site.use_math %}
25 |
26 |
27 |
28 |
39 | {% endif %}
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
56 |
57 |
58 |
59 |
60 | {% if site.twitter_username %}
61 |
62 |
63 |
64 | {% endif %}
65 |
66 | {% if page.summary %}
67 |
68 | {% else %}
69 |
70 | {% endif %}
71 |
72 | {% if page.image %}
73 |
74 |
75 | {% else %}
76 |
77 |
78 | {% endif %}
79 |
80 |
81 |
82 |
83 |
--------------------------------------------------------------------------------
/docs/utils.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Utils
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 |
10 |
11 | nb_path: "nbs/01_utils.ipynb"
12 | ---
13 |
22 |
23 |
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
Utilities used by other modules
42 |
43 |
44 |
45 |
46 |
47 | {% raw %}
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
xla_imported()
59 |
60 |
Check whether the torch_xla module has been successfully imported
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 | {% endraw %}
71 |
72 | {% raw %}
73 |
74 |
75 |
76 |
77 | {% endraw %}
78 |
79 |
80 |
81 |
xla_imported is a utility function used to check whether the torch_xla module has been successfully imported.
82 |
83 |
84 |
85 |
86 | {% raw %}
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
print_aten_ops()
98 |
99 |
print out xla aten operations (from xla debug metrics report torch_xla.debug.metrics)
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 | {% endraw %}
110 |
111 | {% raw %}
112 |
113 |
114 |
115 |
116 | {% endraw %}
117 |
118 |
119 |
120 |
One of the problems we have hit testing different models and transforms is that sometimes it is slower on TPUs compared to running on CPUs, but this happens because we hit operations on Pytorch XLA that are only handled by the CPU and not by the accelerator.
121 |
print_aten_ops directly calls some PyTorch XLA metrics functions that write their report to stdout, so the only way to get that information is to capture the output.
122 |
123 |
124 |
125 |
126 | {% raw %}
127 |
128 |
142 | {% endraw %}
143 |
144 |
145 |
146 |
147 |
--------------------------------------------------------------------------------
/docs/js/jekyll-search.js:
--------------------------------------------------------------------------------
1 | !function e(t,n,r){function s(o,u){if(!n[o]){if(!t[o]){var a="function"==typeof require&&require;if(!u&&a)return a(o,!0);if(i)return i(o,!0);throw new Error("Cannot find module '"+o+"'")}var f=n[o]={exports:{}};t[o][0].call(f.exports,function(e){var n=t[o][1][e];return s(n?n:e)},f,f.exports,e,t,n,r)}return n[o].exports}for(var i="function"==typeof require&&require,o=0;o=0}var self=this;self.matches=function(string,crit){return"string"!=typeof string?!1:(string=string.trim(),doMatch(string,crit))}}module.exports=new LiteralSearchStrategy},{}],4:[function(require,module){module.exports=function(){function findMatches(store,crit,strategy){for(var data=store.get(),i=0;i{title} ',noResultsText:"No results found",limit:10,fuzzy:!1};self.init=function(_opt){validateOptions(_opt),assignOptions(_opt),isJSON(opt.dataSource)?initWithJSON(opt.dataSource):initWithURL(opt.dataSource)}}var Searcher=require("./Searcher"),Templater=require("./Templater"),Store=require("./Store"),JSONLoader=require("./JSONLoader"),searcher=new Searcher,templater=new Templater,store=new Store,jsonLoader=new JSONLoader;window.SimpleJekyllSearch=new SimpleJekyllSearch}(window,document)},{"./JSONLoader":1,"./Searcher":4,"./Store":5,"./Templater":6}]},{},[7]);
2 |
--------------------------------------------------------------------------------
/nbs/02b_misc_utils.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "#default_exp misc_utils"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "# Miscellaneous Utilities"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | " "
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [
31 | {
32 | "name": "stdout",
33 | "output_type": "stream",
34 | "text": [
35 | "\u001b[K |████████████████████████████████| 133.6MB 100kB/s \n",
36 | "\u001b[K |████████████████████████████████| 61kB 3.3MB/s \n",
37 | "\u001b[?25h"
38 | ]
39 | }
40 | ],
41 | "source": [
42 | "#hide\n",
43 | "#colab\n",
44 | "!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {},
51 | "outputs": [
52 | {
53 | "name": "stdout",
54 | "output_type": "stream",
55 | "text": [
56 | "Updating fastai...\n",
57 | "Done.\n"
58 | ]
59 | }
60 | ],
61 | "source": [
62 | "#hide\n",
63 | "#colab\n",
64 | "!curl -s https://course19.fast.ai/setup/colab | bash"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {},
71 | "outputs": [
72 | {
73 | "name": "stdout",
74 | "output_type": "stream",
75 | "text": [
76 | "\u001b[K |████████████████████████████████| 194kB 6.5MB/s \n",
77 | "\u001b[K |████████████████████████████████| 61kB 7.4MB/s \n",
78 | "\u001b[?25h"
79 | ]
80 | }
81 | ],
82 | "source": [
83 | "#hide\n",
84 | "#colab\n",
85 | "# !pip install -Uqq git+https://github.com/fastai/fastai.git \n",
86 | "!pip install -Uqq fastai==2.3.0"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "#hide\n",
96 | "#colab\n",
97 | "!pip install -qqq nbdev"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {},
104 | "outputs": [
105 | {
106 | "name": "stdout",
107 | "output_type": "stream",
108 | "text": [
109 | "torch==1.7.0+cu101\n",
110 | "torch-xla==1.7\n",
111 | "torchsummary==1.5.1\n",
112 | "torchtext==0.3.1\n",
113 | "torchvision==0.8.1+cu101\n",
114 | "fastai==2.2.5\n",
115 | "fastcore==1.3.19\n",
116 | "fastdtw==0.3.4\n",
117 | "fastprogress==1.0.0\n",
118 | "fastrlock==0.5\n"
119 | ]
120 | }
121 | ],
122 | "source": [
123 | "#hide\n",
124 | "!pip freeze | grep torch\n",
125 | "!pip freeze | grep fast\n",
126 | "!pip freeze | grep nbdev"
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "metadata": {},
132 | "source": [
133 | "## `_BaseOptimizer` patches\n"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": null,
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "#hide\n",
143 | "#colab\n",
144 | "from nbdev.showdoc import *"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "metadata": {},
151 | "outputs": [],
152 | "source": [
153 | "#export\n",
154 | "from fastai.optimizer import _BaseOptimizer\n",
155 | "from fastcore.basics import patch\n",
156 | "\n",
157 | "@patch\n",
158 | "def __getstate__(self:_BaseOptimizer):\n",
159 | " \"Pickling opt state should include `param_groups` and `defaults` \"\n",
160 | " d = {\n",
161 | " 'state': self.state_dict(),\n",
162 | " 'param_groups': self.param_groups,\n",
163 | " }\n",
164 | " if hasattr(self,'defaults'):\n",
165 | " d['defaults'] = self.defaults\n",
166 | " return d\n",
167 | "\n",
168 | "@patch\n",
169 | "def __setstate__(self:_BaseOptimizer, data):\n",
170 | " \"Pickling opt state should include `param_groups` and `defaults` \"\n",
171 | "\n",
172 | " if 'defaults' in data:\n",
173 | " self.defaults = data['defaults']\n",
174 | " self.load_state_dict(data['state'])\n",
175 | " self.param_groups = data['param_groups']"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "#hide_input\n",
185 | "#colab\n",
186 | "show_doc(_BaseOptimizer.__getstate__)\n",
187 | "show_doc(_BaseOptimizer.__setstate__)"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "metadata": {},
193 | "source": [
194 | "\n",
195 | "Patch the `fastai.optimizer._BaseOptimizer` `__getstate__` and `__setstate__` methods which are used in pickling fastai optimizers. \n",
196 | "\n",
197 | "This should fix the bug where running the learner on multiple TPU cores on XLA triggers an error in which the method `_fetch_gradients(optimizer)` fails in the statement `for param_group in optimizer.__getstate__()['param_groups']:` in the `torch_xla.core.xla_model` module. \n",
198 | "\n",
199 | "The patch modifies the copy constructor to include the param_groups and defaults."
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "metadata": {},
206 | "outputs": [],
207 | "source": [
208 | "#hide\n",
209 | "#TODO add tests for pickling optimizers"
210 | ]
211 | }
212 | ],
213 | "metadata": {
214 | "kernelspec": {
215 | "display_name": "Python 3 (ipykernel)",
216 | "language": "python",
217 | "name": "python3"
218 | }
219 | },
220 | "nbformat": 4,
221 | "nbformat_minor": 4
222 | }
223 |
--------------------------------------------------------------------------------
/nbs/01_utils.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "#default_exp utils"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "# Utils"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | " \n"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
30 | "> Utilities used by other modules"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [
38 | {
39 | "name": "stdout",
40 | "output_type": "stream",
41 | "text": [
42 | "\u001b[K |████████████████████████████████| 133.6MB 77kB/s \n",
43 | "\u001b[K |████████████████████████████████| 61kB 3.1MB/s \n",
44 | "\u001b[?25h"
45 | ]
46 | }
47 | ],
48 | "source": [
49 | "#hide\n",
50 | "#colab\n",
51 | "!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "name": "stderr",
61 | "output_type": "stream",
62 | "text": [
63 | "WARNING:root:Waiting for TPU to be start up with version pytorch-1.7...\n",
64 | "WARNING:root:Waiting for TPU to be start up with version pytorch-1.7...\n",
65 | "WARNING:root:TPU has started up successfully with version pytorch-1.7\n"
66 | ]
67 | }
68 | ],
69 | "source": [
70 | "#exporti\n",
71 | "try:\n",
72 | " import torch_xla\n",
73 | "except ImportError:\n",
74 | " pass"
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {},
80 | "source": []
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "#export\n",
89 | "import sys\n",
90 | "\n",
91 | "def xla_imported():\n",
92 | " \"Check whether the `torch_xla` module has been successfully imported\"\n",
93 | " return 'torch_xla' in sys.modules"
94 | ]
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "metadata": {},
99 | "source": [
100 | "`xla_imported` is a utility method that is used to check if the `torch_xla` module has been successfully imported."
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "#hide\n",
110 | "# fake out xla modules on environments not configured for TPU \n",
111 | "if not xla_imported():\n",
112 | " from types import SimpleNamespace\n",
113 | " def fake_metrics_report(*args,**kwargs):\n",
114 | " return \"\"\n",
115 | " met = SimpleNamespace(\n",
116 | " metrics_report = fake_metrics_report\n",
117 | " )\n",
118 | " "
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "#exporti\n",
128 | "if xla_imported():\n",
129 | " import torch_xla.debug.metrics as met"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "#export\n",
139 | "def print_aten_ops():\n",
140 | " \"print out xla aten operations (from xla debug metrics report `torch_xla.debug.metrics`)\"\n",
141 | " # import torch_xla.debug.metrics as met\n",
142 | " from io import StringIO\n",
143 | " import sys\n",
144 | "\n",
145 | " class Capturing(list):\n",
146 | " def __enter__(self):\n",
147 | " self._stdout = sys.stdout\n",
148 | " sys.stdout = self._stringio = StringIO()\n",
149 | " return self\n",
150 | " def __exit__(self, *args):\n",
151 | " self.extend(self._stringio.getvalue().splitlines())\n",
152 | " del self._stringio # free up some memory\n",
153 | " sys.stdout = self._stdout\n",
154 | "\n",
155 | " out = met.metrics_report()\n",
156 | " if out.find(\"aten::\"):\n",
157 | " print_now=False\n",
158 | " lines = out.split(\"\\n\")\n",
159 | " for l in lines:\n",
160 | " if print_now:\n",
161 | " print_now=False\n",
162 | " print(l)\n",
163 | " if l.find(\"aten::\")>-1:\n",
164 | " print(\"needs lowering:\", l)\n",
165 | " print_now=True"
166 | ]
167 | },
168 | {
169 | "cell_type": "markdown",
170 | "metadata": {},
171 | "source": [
172 | "\n",
173 | "One of the problems we have hit testing different models and transforms is that sometimes it is slower on TPUs compared to running on CPUs, but this happens because we hit operations on Pytorch XLA that are only handled by the CPU and not by the accelerator. \n",
174 | "\n",
175 | "`print_aten_ops` calls directly some pytorch metrics which outputs to `stdout`, so the only way to get that info is capture it."
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "#colab\n",
185 | "#test that torch_xla has been imported on colab\n",
186 | "assert xla_imported()"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": null,
192 | "metadata": {},
193 | "outputs": [],
194 | "source": [
195 | "#hide\n",
196 | "#TODO: Add example usage for print_aten_ops"
197 | ]
198 | }
199 | ],
200 | "metadata": {
201 | "kernelspec": {
202 | "display_name": "Python 3 (ipykernel)",
203 | "language": "python",
204 | "name": "python3"
205 | }
206 | },
207 | "nbformat": 4,
208 | "nbformat_minor": 4
209 | }
210 |
--------------------------------------------------------------------------------
/fastai_xla_extensions/core.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00_core.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['XLAOptimProxy', 'DeviceMoverTransform', 'isAffineCoordTfm', 'isDeviceMoverTransform', 'has_affinecoord_tfm',
4 | 'has_devicemover_tfm', 'get_last_affinecoord_tfm_idx', 'insert_batch_tfm', 'XLAOptCallback']
5 |
6 | # Internal Cell
7 | from .utils import xla_imported
8 |
9 | # Internal Cell
10 | try:
11 | import torch_xla
12 | except ImportError:
13 | pass
14 |
15 | # Internal Cell
16 | if xla_imported():
17 | import torch_xla.core.xla_model as xm
18 |
19 | from fastcore.foundation import GetAttr, patch
20 | from fastcore.transform import Transform,DisplayedTransform
21 | from fastcore.basics import store_attr
22 | from torch import Tensor
23 | import torch
24 | from fastai.vision.augment import AffineCoordTfm, RandomResizedCropGPU
25 | from fastai.data.core import DataLoaders
26 | from fastai.data.load import DataLoader
27 | from fastai.learner import Learner
28 | from fastai.callback.core import Callback, TrainEvalCallback
29 | from fastai.learner import Recorder
30 |
31 | # Cell
32 |
class XLAOptimProxy(GetAttr):
    "Optimizer wrapper that routes `opt.step` through the Pytorch XLA sync method `xm.optimizer_step`"
    _default = 'opt'

    def __init__(self, opt, barrier):
        self.opt = opt
        self._barrier = barrier

    def step(self):
        # delegate stepping to XLA so the step is synchronized with the XLA device
        xm.optimizer_step(self.opt, barrier=self._barrier)

    @property
    def barrier(self):
        "Whether `xm.optimizer_step` is asked to insert a barrier"
        return self._barrier

    @barrier.setter
    def barrier(self, value):
        self._barrier = value
47 |
48 | # Cell
class DeviceMoverTransform(DisplayedTransform):
    "Move tensors onto `device_to` when encoding and back onto `device_from` when decoding"
    def __init__(self, device_to, device_from=torch.device('cpu')):
        store_attr('device_to,device_from')

    def encodes(self, o: Tensor):
        # forward pass: relocate the batch to the target device
        return o.to(self.device_to)

    def decodes(self, o: Tensor):
        # reverse pass: bring the batch back (cpu by default)
        return o.to(self.device_from)
57 |
58 | # Cell
59 |
def isAffineCoordTfm(o: Transform):
    "True if `o` is an `AffineCoordTfm` or a `RandomResizedCropGPU`"
    return isinstance(o, (AffineCoordTfm, RandomResizedCropGPU))
63 |
def isDeviceMoverTransform(o: Transform):
    "True if `o` is a `DeviceMoverTransform`"
    return isinstance(o, DeviceMoverTransform)
67 |
def has_affinecoord_tfm(dls: DataLoaders) -> bool:
    "Whether the train dataloader's `after_batch` pipeline contains an `AffineCoordTfm`"
    train_dl = dls.train
    if not hasattr(train_dl, 'after_batch'): return False
    pipeline = train_dl.after_batch
    if not hasattr(pipeline, 'fs'): return False
    return len(pipeline.fs.argwhere(isAffineCoordTfm)) > 0
def has_devicemover_tfm(dl: DataLoader) -> bool:
    "Whether `dl`'s `after_batch` pipeline contains a `DeviceMoverTransform`"
    pipeline = getattr(dl, 'after_batch', None)
    if pipeline is None or not hasattr(pipeline, 'fs'): return False
    return len(pipeline.fs.argwhere(isDeviceMoverTransform)) > 0
80 |
def get_last_affinecoord_tfm_idx(dl: DataLoader) -> int:
    "Index of the last `AffineCoordTfm` in `dl`'s batch transforms; -1 when there is none"
    matches = dl.after_batch.fs.argwhere(isAffineCoordTfm)
    return matches[-1] if len(matches) > 0 else -1
85 |
86 | # Cell
def insert_batch_tfm(dl: DataLoader, batch_tfm: Transform, idx: int):
    "Insert `batch_tfm` into `dl`'s batch-transform pipeline at position `idx`"
    dl.after_batch.fs.insert(idx, batch_tfm)
90 |
91 | # Cell
@patch
def setup_input_device_mover(self: Learner, new_device):
    "Configure `dls` so affine-coord batch tfms run on cpu and a `DeviceMoverTransform` moves batches to `new_device`"
    dls = self.dls
    # no affine-coord transforms: the dataloaders can move batches directly
    if not has_affinecoord_tfm(dls):
        dls.device = new_device
        return
    # keep batches on cpu; the mover transform relocates them after the affine tfms
    dls.device = None
    if has_devicemover_tfm(dls.train):
        return  # a device mover is already installed
    mover = DeviceMoverTransform(new_device)
    for dl in dls.loaders:
        if has_devicemover_tfm(dl): continue
        last_idx = get_last_affinecoord_tfm_idx(dl)
        if last_idx != -1:
            insert_batch_tfm(dl, mover, last_idx + 1)
107 |
108 | # Cell
109 |
class XLAOptCallback(Callback):
    'Callback to replace `opt.step` with `xm.optimizer_step(opt)` as required to run on TPU'
    run_after,run_before = TrainEvalCallback,Recorder

    def __init__(self, barrier=True):
        self._barrier = barrier

    def before_fit(self):
        'wrap the optimizer in an `XLAOptimProxy` and detach `dls.device` when a mover transform is in play'
        # when the batch tfms include an AffineCoordTfm, the DeviceMoverTransform in
        # after_batch moves the batch, so dls must not also move it to a device
        if has_affinecoord_tfm(self.dls):
            self.dls.device = None
        current_opt = self.learn.opt
        if current_opt is not None and not isinstance(current_opt, XLAOptimProxy):
            self.learn.opt = XLAOptimProxy(current_opt, barrier=self._barrier)

    def after_fit(self):
        'unwrap the proxy and restore the original optimizer'
        if isinstance(self.learn.opt, XLAOptimProxy):
            self.learn.opt = self.learn.opt.opt

    @property
    def barrier(self):
        "Whether the proxied `xm.optimizer_step` inserts a barrier"
        return self._barrier

    @barrier.setter
    def barrier(self, value):
        self._barrier = value
138 |
139 | # Cell
140 |
@patch
def to_xla(self: Learner, new_device=None):
    "Prepare this learner for single TPU core training"
    self.add_cb(XLAOptCallback())
    # default to this process's XLA device when none is supplied
    device = xm.xla_device() if new_device is None else new_device
    self.model.to(device)
    self.setup_input_device_mover(device)
    self.opt = None  # discard any optimizer built for the previous device
    return self
151 |
152 | # Cell
153 |
@patch
def detach_xla(self: Learner):
    "Undo the single TPU core setup: drop the XLA callback and move model and dls back to cpu"
    self.remove_cb(XLAOptCallback)
    cpu = torch.device('cpu')
    self.dls.device = cpu
    self.model = self.model.to(cpu)
    self.opt = None  # discard the optimizer built for the XLA device
    return self
--------------------------------------------------------------------------------
/fastai_xla_extensions/_nbdev.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED BY NBDEV! DO NOT EDIT!
2 |
3 | __all__ = ["index", "modules", "custom_doc_links", "git_url"]
4 |
# Maps each exported symbol to the notebook that defines it
# (consumed by nbdev to build documentation and source links).
index = {"XLAOptimProxy": "00_core.ipynb",
         "DeviceMoverTransform": "00_core.ipynb",
         "isAffineCoordTfm": "00_core.ipynb",
         "isDeviceMoverTransform": "00_core.ipynb",
         "has_affinecoord_tfm": "00_core.ipynb",
         "has_devicemover_tfm": "00_core.ipynb",
         "get_last_affinecoord_tfm_idx": "00_core.ipynb",
         "insert_batch_tfm": "00_core.ipynb",
         "Learner.setup_input_device_mover": "00_core.ipynb",
         "XLAOptCallback": "00_core.ipynb",
         "Learner.to_xla": "00_core.ipynb",
         "Learner.detach_xla": "00_core.ipynb",
         "xla_imported": "01_utils.ipynb",
         "print_aten_ops": "01_utils.ipynb",
         "download_torch_dsets": "02_cifar_loader.ipynb",
         "load_torch_items": "02_cifar_loader.ipynb",
         "load_classes": "02_cifar_loader.ipynb",
         "CifarNP2ImageTransform": "02_cifar_loader.ipynb",
         "Int2TensorTransform": "02_cifar_loader.ipynb",
         "CifarImageTransform": "02_cifar_loader.ipynb",
         "CifarImage2FloatTransform": "02_cifar_loader.ipynb",
         "make_torch_tfms": "02_cifar_loader.ipynb",
         "CifarTupleTransform": "02_cifar_loader.ipynb",
         "TupleTorchDS": "02_cifar_loader.ipynb",
         "make_cifar_item_tfm": "02_cifar_loader.ipynb",
         "i2t_tfm": "02_cifar_loader.ipynb",
         "cfnp2img_tfm": "02_cifar_loader.ipynb",
         "cfimg_tfm": "02_cifar_loader.ipynb",
         "cfimg2float_tfm": "02_cifar_loader.ipynb",
         "make_cifar_tls": "02_cifar_loader.ipynb",
         "make_cifar_dl": "02_cifar_loader.ipynb",
         "make_fastai_cifar_dls": "02_cifar_loader.ipynb",
         "revert_tensor": "03_multi_core.base.ipynb",
         "recast2tensor": "03_multi_core.base.ipynb",
         "round_to_multiple": "03_multi_core.base.ipynb",
         "TPUDistributedDL": "03_multi_core.base.ipynb",
         "th_data.DataLoader.__setattr__": "03_multi_core.base.ipynb",
         "after_batch": "03_multi_core.base.ipynb",
         "bs": "03_multi_core.base.ipynb",
         "device": "03_multi_core.base.ipynb",
         "th_data.DataLoader.to": "03_multi_core.base.ipynb",
         "th_data.DataLoader.set_distributed_sampler": "03_multi_core.base.ipynb",
         "build_distributed_dataloaders": "03_multi_core.base.ipynb",
         "make_fastai_dataloaders": "03_multi_core.base.ipynb",
         "wrap_parallel_loader": "03_multi_core.base.ipynb",
         "XLATrainingCallback": "03_multi_core.base.ipynb",
         "pack_metric": "03_multi_core.base.ipynb",
         "make_tensor": "03_multi_core.base.ipynb",
         "pack_metrics": "03_multi_core.base.ipynb",
         "restore_metrics": "03_multi_core.base.ipynb",
         "SyncedAvgSmoothLoss": "03_multi_core.base.ipynb",
         "SyncRecorderCallback": "03_multi_core.base.ipynb",
         "xm_save": "03_multi_core.base.ipynb",
         "Learner.save": "03_multi_core.base.ipynb",
         "Learner.to_multi_xla": "03_multi_core.base.ipynb",
         "do_one_loop": "03_multi_core.base.ipynb",
         "TfmdTorchDS": "03a_multi_core.torch_compat.ipynb",
         "to_list": "03a_multi_core.torch_compat.ipynb",
         "has_setup": "03a_multi_core.torch_compat.ipynb",
         "run_setups": "03a_multi_core.torch_compat.ipynb",
         "TorchDatasetBuilder": "03a_multi_core.torch_compat.ipynb",
         "VocabularyMapper": "03a_multi_core.torch_compat.ipynb",
         "make_torch_dataloaders": "03a_multi_core.torch_compat.ipynb",
         "FileNamePatternLabeller": "03a_multi_core.torch_compat.ipynb",
         "master_cbs": "03b_multi_core.learner.ipynb",
         "Learner.add_master_cb": "03b_multi_core.learner.ipynb",
         "Learner.add_master_cbs": "03b_multi_core.learner.ipynb",
         "Learner.grab_master_cbs": "03b_multi_core.learner.ipynb",
         "Learner.remove_master_cb": "03b_multi_core.learner.ipynb",
         "Learner.remove_master_cbs": "03b_multi_core.learner.ipynb",
         "make_xla_child_learner": "03b_multi_core.learner.ipynb",
         "setup_fit_cbs": "03b_multi_core.learner.ipynb",
         "xla_run_method": "03b_multi_core.learner.ipynb",
         "Learner.pack_learner_args": "03b_multi_core.learner.ipynb",
         "Learner.reload_child_model": "03b_multi_core.learner.ipynb",
         "Learner.delete_tmp_files": "03b_multi_core.learner.ipynb",
         "Learner.pre_xla_fit": "03b_multi_core.learner.ipynb",
         "Learner.post_xla_fit": "03b_multi_core.learner.ipynb",
         "tmp_files": "03b_multi_core.learner.ipynb",
         "prep_fit_args": "03b_multi_core.learner.ipynb",
         "Learner.xla_fit": "03b_multi_core.learner.ipynb",
         "Learner.xla_fit_one_cycle": "03b_multi_core.learner.ipynb",
         "Learner.xla_fit_flat_cos": "03b_multi_core.learner.ipynb",
         "prep_fit_sgdr_args": "03b_multi_core.learner.ipynb",
         "Learner.xla_fit_sgdr": "03b_multi_core.learner.ipynb",
         "prep_finetune_args": "03b_multi_core.learner.ipynb",
         "Learner.xla_fine_tune": "03b_multi_core.learner.ipynb",
         "maybe_item": "03c_multi_core.callback.ipynb",
         "Recorder.get_extra_attrs": "03c_multi_core.callback.ipynb",
         "Recorder.dump_attrs": "03c_multi_core.callback.ipynb",
         "Recorder.reload_attrs": "03c_multi_core.callback.ipynb",
         "Recorder.after_fit": "03c_multi_core.callback.ipynb",
         "ParamScheduler.dump_hps": "03c_multi_core.callback.ipynb",
         "Recorder.reload_hps": "03c_multi_core.callback.ipynb",
         "ParamScheduler.after_fit": "03c_multi_core.callback.ipynb",
         "SkipValidationCallback": "03d_multi_core.lr_find.ipynb",
         "SyncedCancelCallback": "03d_multi_core.lr_find.ipynb",
         "XLALRFinder": "03d_multi_core.lr_find.ipynb",
         "Learner.get_suggested_lrs": "03d_multi_core.lr_find.ipynb",
         "xla_run_lr_find": "03d_multi_core.lr_find.ipynb",
         "Learner.xla_lr_find": "03d_multi_core.lr_find.ipynb",
         "Learner.inner_get_preds": "03e_multi_core.inference.ipynb",
         "setup_inference_args": "03e_multi_core.inference.ipynb",
         "save_pred_results": "03e_multi_core.inference.ipynb",
         "xla_run_inference": "03e_multi_core.inference.ipynb",
         "reload_pred_results": "03e_multi_core.inference.ipynb",
         "Learner.pre_xla_inference": "03e_multi_core.inference.ipynb",
         "Learner.post_xla_inference": "03e_multi_core.inference.ipynb",
         "prep_inference_args": "03e_multi_core.inference.ipynb",
         "Learner.xla_get_preds": "03e_multi_core.inference.ipynb"}

# Generated python modules (relative to the package root) built from the notebooks.
# NOTE(review): "misc_utils.py" appears here but has no entries in `index` -- confirm
# whether its notebook exports anything.
modules = ["core.py",
           "utils.py",
           "cifar_loader.py",
           "misc_utils.py",
           "multi_core/base.py",
           "multi_core/torch_compat.py",
           "multi_core/learner.py",
           "multi_core/callback.py",
           "multi_core/lr_find.py",
           "multi_core/inference.py"]

# Base url of the generated documentation site.
doc_url = "https://butchland.github.io/fastai_xla_extensions/"

# Base url of the project source tree on github.
git_url = "https://github.com/butchland/fastai_xla_extensions/tree/master/"
130 |
def custom_doc_links(name):
    "nbdev hook for per-symbol custom documentation links; this project defines none"
    return None
132 |
--------------------------------------------------------------------------------
/archive_nbs/CollaboratoryFilteringGPU.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from fastai2.tabular.all import *\n",
10 | "from fastai2.collab import *"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 3,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "path = untar_data(URLs.ML_100k)\n",
20 | "ratings = pd.read_csv(path/'u.data', delimiter='\\t', header=None,\n",
21 | " usecols=(0,1,2), names=['user','movie','rating'])\n",
22 | "movies = pd.read_csv(path/'u.item', delimiter='|', encoding='latin-1',\n",
23 | " usecols=(0,1), names=('movie','title'), header=None)"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 4,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "ratings = ratings.merge(movies)"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 5,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "dls = CollabDataLoaders.from_df(ratings, item_name='title', bs=64)"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 6,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "learn = collab_learner(dls, n_factors=50, y_range=(0, 5.5))"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 7,
56 | "metadata": {},
57 | "outputs": [
58 | {
59 | "data": {
60 | "text/html": [
61 | "\n",
62 | " \n",
63 | " \n",
64 | " epoch \n",
65 | " train_loss \n",
66 | " valid_loss \n",
67 | " time \n",
68 | " \n",
69 | " \n",
70 | " \n",
71 | " \n",
72 | " 0 \n",
73 | " 0.943764 \n",
74 | " 0.945623 \n",
75 | " 00:08 \n",
76 | " \n",
77 | " \n",
78 | " 1 \n",
79 | " 0.850762 \n",
80 | " 0.866062 \n",
81 | " 00:08 \n",
82 | " \n",
83 | " \n",
84 | " 2 \n",
85 | " 0.735538 \n",
86 | " 0.828996 \n",
87 | " 00:08 \n",
88 | " \n",
89 | " \n",
90 | " 3 \n",
91 | " 0.578858 \n",
92 | " 0.816693 \n",
93 | " 00:08 \n",
94 | " \n",
95 | " \n",
96 | " 4 \n",
97 | " 0.469008 \n",
98 | " 0.817968 \n",
99 | " 00:08 \n",
100 | " \n",
101 | " \n",
102 | "
"
103 | ],
104 | "text/plain": [
105 | ""
106 | ]
107 | },
108 | "metadata": {},
109 | "output_type": "display_data"
110 | }
111 | ],
112 | "source": [
113 | "learn.fit_one_cycle(5, 5e-3, wd=0.1)"
114 | ]
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "metadata": {},
119 | "source": [
120 | "### Interpretation"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 10,
126 | "metadata": {},
127 | "outputs": [
128 | {
129 | "data": {
130 | "text/plain": [
131 | "array(['Star Wars (1977)', 'Contact (1997)', 'Fargo (1996)',\n",
132 | " 'Return of the Jedi (1983)', 'Liar Liar (1997)',\n",
133 | " 'English Patient, The (1996)', 'Scream (1996)', 'Toy Story (1995)',\n",
134 | " 'Air Force One (1997)', 'Independence Day (ID4) (1996)'],\n",
135 | " dtype=object)"
136 | ]
137 | },
138 | "execution_count": 10,
139 | "metadata": {},
140 | "output_type": "execute_result"
141 | }
142 | ],
143 | "source": [
144 | "g = ratings.groupby(\"title\")['rating'].count()\n",
145 | "top_movies = g.sort_values(ascending=False).index.values[:1000]\n",
146 | "top_movies[:10]"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 11,
152 | "metadata": {},
153 | "outputs": [
154 | {
155 | "data": {
156 | "text/plain": [
157 | "torch.Size([1000])"
158 | ]
159 | },
160 | "execution_count": 11,
161 | "metadata": {},
162 | "output_type": "execute_result"
163 | }
164 | ],
165 | "source": [
166 | "movie_bias = learn.model.bias(top_movies, is_item=True)\n",
167 | "movie_bias.shape"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": 16,
173 | "metadata": {},
174 | "outputs": [],
175 | "source": [
176 | "mean_ratings = ratings.groupby(\"title\")['rating'].mean()\n",
177 | "movie_ratings = [(b, i, mean_ratings.loc[i]) for i,b in zip(top_movies,movie_bias)]"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": 17,
183 | "metadata": {},
184 | "outputs": [
185 | {
186 | "data": {
187 | "text/plain": [
188 | "[(tensor(-0.3798),\n",
189 | " 'Children of the Corn: The Gathering (1996)',\n",
190 | " 1.3157894736842106),\n",
191 | " (tensor(-0.2619), 'Mortal Kombat: Annihilation (1997)', 1.9534883720930232),\n",
192 | " (tensor(-0.2598), 'Bio-Dome (1996)', 1.903225806451613),\n",
193 | " (tensor(-0.2545), \"McHale's Navy (1997)\", 2.1884057971014492),\n",
194 | " (tensor(-0.2445), 'Showgirls (1995)', 1.9565217391304348),\n",
195 | " (tensor(-0.2393), 'Free Willy 3: The Rescue (1997)', 1.7407407407407407),\n",
196 | " (tensor(-0.2383), 'Leave It to Beaver (1997)', 1.8409090909090908),\n",
197 | " (tensor(-0.2355), 'Crow: City of Angels, The (1996)', 1.9487179487179487),\n",
198 | " (tensor(-0.2269),\n",
199 | " 'Lawnmower Man 2: Beyond Cyberspace (1996)',\n",
200 | " 1.7142857142857142),\n",
201 | " (tensor(-0.2221), 'Beautician and the Beast, The (1997)', 2.313953488372093),\n",
202 | " (tensor(-0.2202), 'Barb Wire (1996)', 1.9333333333333333),\n",
203 | " (tensor(-0.2080), 'Cable Guy, The (1996)', 2.339622641509434),\n",
204 | " (tensor(-0.2038), 'Beverly Hills Ninja (1997)', 2.3125),\n",
205 | " (tensor(-0.2036), 'Striptease (1996)', 2.2388059701492535),\n",
206 | " (tensor(-0.1995), 'Island of Dr. Moreau, The (1996)', 2.1578947368421053)]"
207 | ]
208 | },
209 | "execution_count": 17,
210 | "metadata": {},
211 | "output_type": "execute_result"
212 | }
213 | ],
214 | "source": [
215 | "item0 = lambda o:o[0]\n",
216 | "sorted(movie_ratings, key=item0)[:15]"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": null,
222 | "metadata": {},
223 | "outputs": [],
224 | "source": []
225 | }
226 | ],
227 | "metadata": {
228 | "kernelspec": {
229 | "display_name": "Python 3",
230 | "language": "python",
231 | "name": "python3"
232 | },
233 | "language_info": {
234 | "codemirror_mode": {
235 | "name": "ipython",
236 | "version": 3
237 | },
238 | "file_extension": ".py",
239 | "mimetype": "text/x-python",
240 | "name": "python",
241 | "nbconvert_exporter": "python",
242 | "pygments_lexer": "ipython3",
243 | "version": "3.7.6"
244 | }
245 | },
246 | "nbformat": 4,
247 | "nbformat_minor": 4
248 | }
249 |
--------------------------------------------------------------------------------
/fastai_xla_extensions/cifar_loader.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/02_cifar_loader.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['download_torch_dsets', 'load_torch_items', 'load_classes', 'CifarNP2ImageTransform', 'Int2TensorTransform',
4 | 'CifarImageTransform', 'CifarImage2FloatTransform', 'make_torch_tfms', 'CifarTupleTransform', 'TupleTorchDS',
5 | 'make_cifar_item_tfm', 'i2t_tfm', 'cfnp2img_tfm', 'cfimg_tfm', 'cfimg2float_tfm', 'make_cifar_tls',
6 | 'make_cifar_dl', 'make_fastai_cifar_dls']
7 |
8 | # Cell
def download_torch_dsets(path, torch_dset):
    """Download the train and test splits of a torchvision dataset.

    Arguments:
        path (pathlib.Path): root directory to download the dataset into
        torch_dset: a torchvision dataset constructor (e.g. CIFAR10)

    Returns:
        (train_dataset, test_dataset) tuple
    """
    def build(is_train):
        # download=True makes torchvision fetch the archive if it is not cached
        return torch_dset(root=path, train=is_train, download=True)

    return build(True), build(False)
27 |
28 | # Internal Cell
29 | import numpy as np
30 | import torchvision.datasets.utils as tv_utils
31 | import torchvision.datasets.cifar as cifar_dsets
32 | import pickle
33 |
34 | # Cell
def load_torch_items(downloaded_list, path, check=False):
    """Load CIFAR batch files into a (data, targets) pair.

    Adapted from torchvision's CIFAR10 dataset code:
    https://pytorch.org/docs/stable/_modules/torchvision/datasets/cifar.html#CIFAR10

    Arguments:
        downloaded_list: list of (file_name, checksum) pairs, see CIFAR10.train_list or CIFAR10.test_list
        path (pathlib.Path): the root path where the dataset was downloaded
        check (bool, optional): whether to verify file integrity before loading (default: False)

    Returns:
        data (np.ndarray): images with shape (N, 32, 32, 3) in HWC layout
        targets (list): integer labels
    """
    data, targets = [], []
    base = path/cifar_dsets.CIFAR10.base_folder
    # each batch file is a latin1-pickled dict of image rows and labels
    for file_name, checksum in downloaded_list:
        file_path = base/file_name
        if check and not tv_utils.check_integrity(file_path, checksum):
            raise RuntimeError(
                f'Data checksum failed for file:{file_path} checksum:{checksum}')
        with open(file_path, 'rb') as f:
            entry = pickle.load(f, encoding='latin1')
        data.append(entry['data'])
        # some batch files store targets under 'labels', others under 'fine_labels'
        label_key = 'labels' if 'labels' in entry else 'fine_labels'
        targets.extend(entry[label_key])

    stacked = np.vstack(data).reshape(-1, 3, 32, 32)
    return stacked.transpose((0, 2, 3, 1)), targets  # CHW -> HWC
66 |
67 | # Cell
68 | # TODO: incorporate list of classes into dataloaders vocab and decodes
69 |
70 | from fastcore.foundation import L
def load_classes(path):
    """Load the class names used to map category indices to target labels"""
    base_folder = cifar_dsets.CIFAR10.base_folder
    meta = cifar_dsets.CIFAR10.meta
    file_path = path/base_folder/meta['filename']
    if not tv_utils.check_integrity(file_path, meta['md5']):
        raise RuntimeError('Dataset metadata file not found or corrupted.' +
                           ' You can use download=True to download it')
    # the metadata file is a latin1-pickled dict; the class list lives under meta['key']
    with open(file_path, 'rb') as infile:
        meta_data = pickle.load(infile, encoding='latin1')
    return L(meta_data[meta['key']])
85 |
86 | # Internal Cell
87 | from fastcore.transform import Transform
88 | import torchvision.transforms.functional as TVF
89 | import torch
90 | import PIL
91 |
92 | # Cell
class CifarNP2ImageTransform(Transform):
    # Convert a raw CIFAR ndarray (HWC layout, per `load_torch_items`) into a PIL image.
    # NOTE(review): the `-> None` return annotation looks odd as a type hint, but fastcore
    # Transforms appear to use return annotations for output casting, so it is presumably
    # deliberate -- confirm before changing.
    def encodes(self, o:np.ndarray) -> None:
        return PIL.Image.fromarray(o)
96 |
97 | # Cell
class Int2TensorTransform(Transform):
    # Wrap an integer label into a torch tensor.
    def encodes(self, o: int) -> None:
        return torch.tensor(o)
101 |
102 | # Cell
class CifarImageTransform(Transform):
    # Convert a PIL image to a torch tensor via torchvision's `to_tensor`.
    def encodes(self, o: PIL.Image) -> None:
        return TVF.to_tensor(o)
106 |
107 | # Cell
class CifarImage2FloatTransform(Transform):
    # Cast an image tensor to float and divide by 255 (in place on the float copy).
    # NOTE(review): torchvision's `to_tensor` already returns floats scaled to [0,1];
    # since `make_cifar_item_tfm` applies this after `CifarImageTransform`, the
    # division by 255 may happen twice -- verify the intended pixel range.
    def encodes(self, o: torch.Tensor) -> None:
        return o.float().div_(255.)
111 |
112 | # Internal Cell
113 | import torchvision as thv
114 |
115 | # Cell
def make_torch_tfms():
    "Build the standard torchvision train/test transform pipelines for CIFAR"
    # channel statistics used to normalize both splits identically
    norm = thv.transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
    train_tfms = thv.transforms.Compose([
        thv.transforms.RandomCrop(32, padding=4),   # pad-and-crop augmentation
        thv.transforms.RandomHorizontalFlip(),
        thv.transforms.ToTensor(),
        norm,
    ])
    test_tfms = thv.transforms.Compose([
        thv.transforms.ToTensor(),
        norm,
    ])
    return train_tfms, test_tfms
130 |
131 | # Internal Cell
132 | from fastcore.transform import ItemTransform
133 | from fastcore.basics import store_attr
134 |
135 | # Cell
class CifarTupleTransform(ItemTransform):
    "Apply `x_tfm` to the input and `y_tfm` to the target of an (x, y) item"
    def __init__(self, x_tfm, y_tfm):
        store_attr()

    def encodes(self, xy):
        x, y = xy[0], xy[1]
        return [self.x_tfm(x), self.y_tfm(y)]
141 |
142 | # Internal Cell
143 | import torch.utils.data as th_data
144 | from torch.utils.data import Dataset
145 |
146 | # Cell
147 | # TODO: Use TupleTorchDS to create torch dataloaders
148 |
class TupleTorchDS(th_data.Dataset):
    "A torch `Dataset` over (x, y) items with optional per-element transforms"
    def __init__(self, items, x_tfm=None, y_tfm=None):
        store_attr()

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        x, y = self.items[index]
        if self.x_tfm is not None: x = self.x_tfm(x)
        if self.y_tfm is not None: y = self.y_tfm(y)
        return (x, y)
161 |
162 | # Internal Cell
163 | from fastcore.transform import Pipeline
164 |
165 | # Cell
# Shared, stateless transform instances reused by `make_cifar_item_tfm`.
i2t_tfm = Int2TensorTransform() # int -> torch.tensor
cfnp2img_tfm = CifarNP2ImageTransform() # ndarray -> PIL.Image
cfimg_tfm = CifarImageTransform() # PIL.Image -> torch.tensor
cfimg2float_tfm = CifarImage2FloatTransform() # int tensor -> float, divided by 255
170 |
def make_cifar_item_tfm(th_img_tfms=None):
    "Build the per-item transform: an image pipeline for x and int->tensor for y"
    # always start by turning the raw ndarray into a PIL image
    img_tfms = [cfnp2img_tfm]
    if th_img_tfms is None:
        img_tfms += [cfimg_tfm, cfimg2float_tfm]
    else:
        # assumes th_img_tfms includes ToTensor (PIL.Image -> tensor, divided by 255)
        img_tfms += [th_img_tfms]
    return CifarTupleTransform(x_tfm=Pipeline(img_tfms), y_tfm=i2t_tfm)
180 |
181 | # Internal Cell
182 | from fastai.data.core import TfmdDL
183 | from fastai.data.core import TfmdLists
184 | from fastcore.foundation import L
185 |
186 | # Cell
def make_cifar_tls(file_list, path, item_tfm, check=True):
    "Create a `TfmdLists` of (data, target) pairs from downloaded CIFAR batch files"
    data, targets = load_torch_items(file_list, path, check=check)
    pairs = L(data, targets).zip()
    return TfmdLists(pairs, [item_tfm])
192 |
193 | # Cell
def make_cifar_dl(file_list, path, th_img_tfms=None, check=True, bs=64, **kwargs):
    "Create a `TfmdDL` over the CIFAR files in `file_list`"
    tls = make_cifar_tls(file_list, path, make_cifar_item_tfm(th_img_tfms), check=check)
    return TfmdDL(tls, bs=bs, **kwargs)
199 |
200 | # Internal Cell
201 | from fastai.data.core import DataLoaders
202 |
203 | # Cell
def make_fastai_cifar_dls(path, bs=64, check=True, device=None, **kwargs):
    "Build train/test DataLoaders for CIFAR10 (train shuffled, test not)."
    # NOTE(review): **kwargs is accepted but never forwarded anywhere — confirm intent.
    train_tfm, test_tfm = make_torch_tfms()
    specs = [
        (cifar_dsets.CIFAR10.train_list, train_tfm, True),
        (cifar_dsets.CIFAR10.test_list, test_tfm, False),
    ]
    loaders = [
        make_cifar_dl(files, path, tfm, check=check, bs=bs, shuffle=shuf)
        for files, tfm, shuf in specs
    ]
    return DataLoaders(*loaders, device=device)
--------------------------------------------------------------------------------
/docs/Gemfile.lock:
--------------------------------------------------------------------------------
1 | GEM
2 | remote: https://rubygems.org/
3 | specs:
4 | activesupport (6.0.3.2)
5 | concurrent-ruby (~> 1.0, >= 1.0.2)
6 | i18n (>= 0.7, < 2)
7 | minitest (~> 5.1)
8 | tzinfo (~> 1.1)
9 | zeitwerk (~> 2.2, >= 2.2.2)
10 | addressable (2.8.0)
11 | public_suffix (>= 2.0.2, < 5.0)
12 | coffee-script (2.4.1)
13 | coffee-script-source
14 | execjs
15 | coffee-script-source (1.11.1)
16 | colorator (1.1.0)
17 | commonmarker (0.17.13)
18 | ruby-enum (~> 0.5)
19 | concurrent-ruby (1.1.6)
20 | dnsruby (1.61.4)
21 | simpleidn (~> 0.1)
22 | em-websocket (0.5.1)
23 | eventmachine (>= 0.12.9)
24 | http_parser.rb (~> 0.6.0)
25 | ethon (0.12.0)
26 | ffi (>= 1.3.0)
27 | eventmachine (1.2.7)
28 | execjs (2.7.0)
29 | faraday (1.0.1)
30 | multipart-post (>= 1.2, < 3)
31 | ffi (1.13.1)
32 | forwardable-extended (2.6.0)
33 | gemoji (3.0.1)
34 | github-pages (207)
35 | github-pages-health-check (= 1.16.1)
36 | jekyll (= 3.9.0)
37 | jekyll-avatar (= 0.7.0)
38 | jekyll-coffeescript (= 1.1.1)
39 | jekyll-commonmark-ghpages (= 0.1.6)
40 | jekyll-default-layout (= 0.1.4)
41 | jekyll-feed (= 0.13.0)
42 | jekyll-gist (= 1.5.0)
43 | jekyll-github-metadata (= 2.13.0)
44 | jekyll-mentions (= 1.5.1)
45 | jekyll-optional-front-matter (= 0.3.2)
46 | jekyll-paginate (= 1.1.0)
47 | jekyll-readme-index (= 0.3.0)
48 | jekyll-redirect-from (= 0.15.0)
49 | jekyll-relative-links (= 0.6.1)
50 | jekyll-remote-theme (= 0.4.1)
51 | jekyll-sass-converter (= 1.5.2)
52 | jekyll-seo-tag (= 2.6.1)
53 | jekyll-sitemap (= 1.4.0)
54 | jekyll-swiss (= 1.0.0)
55 | jekyll-theme-architect (= 0.1.1)
56 | jekyll-theme-cayman (= 0.1.1)
57 | jekyll-theme-dinky (= 0.1.1)
58 | jekyll-theme-hacker (= 0.1.1)
59 | jekyll-theme-leap-day (= 0.1.1)
60 | jekyll-theme-merlot (= 0.1.1)
61 | jekyll-theme-midnight (= 0.1.1)
62 | jekyll-theme-minimal (= 0.1.1)
63 | jekyll-theme-modernist (= 0.1.1)
64 | jekyll-theme-primer (= 0.5.4)
65 | jekyll-theme-slate (= 0.1.1)
66 | jekyll-theme-tactile (= 0.1.1)
67 | jekyll-theme-time-machine (= 0.1.1)
68 | jekyll-titles-from-headings (= 0.5.3)
69 | jemoji (= 0.11.1)
70 | kramdown (= 2.3.0)
71 | kramdown-parser-gfm (= 1.1.0)
72 | liquid (= 4.0.3)
73 | mercenary (~> 0.3)
74 | minima (= 2.5.1)
75 | nokogiri (>= 1.10.4, < 2.0)
76 | rouge (= 3.19.0)
77 | terminal-table (~> 1.4)
78 | github-pages-health-check (1.16.1)
79 | addressable (~> 2.3)
80 | dnsruby (~> 1.60)
81 | octokit (~> 4.0)
82 | public_suffix (~> 3.0)
83 | typhoeus (~> 1.3)
84 | html-pipeline (2.13.0)
85 | activesupport (>= 2)
86 | nokogiri (>= 1.4)
87 | http_parser.rb (0.6.0)
88 | i18n (0.9.5)
89 | concurrent-ruby (~> 1.0)
90 | jekyll (3.9.0)
91 | addressable (~> 2.4)
92 | colorator (~> 1.0)
93 | em-websocket (~> 0.5)
94 | i18n (~> 0.7)
95 | jekyll-sass-converter (~> 1.0)
96 | jekyll-watch (~> 2.0)
97 | kramdown (>= 1.17, < 3)
98 | liquid (~> 4.0)
99 | mercenary (~> 0.3.3)
100 | pathutil (~> 0.9)
101 | rouge (>= 1.7, < 4)
102 | safe_yaml (~> 1.0)
103 | jekyll-avatar (0.7.0)
104 | jekyll (>= 3.0, < 5.0)
105 | jekyll-coffeescript (1.1.1)
106 | coffee-script (~> 2.2)
107 | coffee-script-source (~> 1.11.1)
108 | jekyll-commonmark (1.3.1)
109 | commonmarker (~> 0.14)
110 | jekyll (>= 3.7, < 5.0)
111 | jekyll-commonmark-ghpages (0.1.6)
112 | commonmarker (~> 0.17.6)
113 | jekyll-commonmark (~> 1.2)
114 | rouge (>= 2.0, < 4.0)
115 | jekyll-default-layout (0.1.4)
116 | jekyll (~> 3.0)
117 | jekyll-feed (0.13.0)
118 | jekyll (>= 3.7, < 5.0)
119 | jekyll-gist (1.5.0)
120 | octokit (~> 4.2)
121 | jekyll-github-metadata (2.13.0)
122 | jekyll (>= 3.4, < 5.0)
123 | octokit (~> 4.0, != 4.4.0)
124 | jekyll-mentions (1.5.1)
125 | html-pipeline (~> 2.3)
126 | jekyll (>= 3.7, < 5.0)
127 | jekyll-optional-front-matter (0.3.2)
128 | jekyll (>= 3.0, < 5.0)
129 | jekyll-paginate (1.1.0)
130 | jekyll-readme-index (0.3.0)
131 | jekyll (>= 3.0, < 5.0)
132 | jekyll-redirect-from (0.15.0)
133 | jekyll (>= 3.3, < 5.0)
134 | jekyll-relative-links (0.6.1)
135 | jekyll (>= 3.3, < 5.0)
136 | jekyll-remote-theme (0.4.1)
137 | addressable (~> 2.0)
138 | jekyll (>= 3.5, < 5.0)
139 | rubyzip (>= 1.3.0)
140 | jekyll-sass-converter (1.5.2)
141 | sass (~> 3.4)
142 | jekyll-seo-tag (2.6.1)
143 | jekyll (>= 3.3, < 5.0)
144 | jekyll-sitemap (1.4.0)
145 | jekyll (>= 3.7, < 5.0)
146 | jekyll-swiss (1.0.0)
147 | jekyll-theme-architect (0.1.1)
148 | jekyll (~> 3.5)
149 | jekyll-seo-tag (~> 2.0)
150 | jekyll-theme-cayman (0.1.1)
151 | jekyll (~> 3.5)
152 | jekyll-seo-tag (~> 2.0)
153 | jekyll-theme-dinky (0.1.1)
154 | jekyll (~> 3.5)
155 | jekyll-seo-tag (~> 2.0)
156 | jekyll-theme-hacker (0.1.1)
157 | jekyll (~> 3.5)
158 | jekyll-seo-tag (~> 2.0)
159 | jekyll-theme-leap-day (0.1.1)
160 | jekyll (~> 3.5)
161 | jekyll-seo-tag (~> 2.0)
162 | jekyll-theme-merlot (0.1.1)
163 | jekyll (~> 3.5)
164 | jekyll-seo-tag (~> 2.0)
165 | jekyll-theme-midnight (0.1.1)
166 | jekyll (~> 3.5)
167 | jekyll-seo-tag (~> 2.0)
168 | jekyll-theme-minimal (0.1.1)
169 | jekyll (~> 3.5)
170 | jekyll-seo-tag (~> 2.0)
171 | jekyll-theme-modernist (0.1.1)
172 | jekyll (~> 3.5)
173 | jekyll-seo-tag (~> 2.0)
174 | jekyll-theme-primer (0.5.4)
175 | jekyll (> 3.5, < 5.0)
176 | jekyll-github-metadata (~> 2.9)
177 | jekyll-seo-tag (~> 2.0)
178 | jekyll-theme-slate (0.1.1)
179 | jekyll (~> 3.5)
180 | jekyll-seo-tag (~> 2.0)
181 | jekyll-theme-tactile (0.1.1)
182 | jekyll (~> 3.5)
183 | jekyll-seo-tag (~> 2.0)
184 | jekyll-theme-time-machine (0.1.1)
185 | jekyll (~> 3.5)
186 | jekyll-seo-tag (~> 2.0)
187 | jekyll-titles-from-headings (0.5.3)
188 | jekyll (>= 3.3, < 5.0)
189 | jekyll-watch (2.2.1)
190 | listen (~> 3.0)
191 | jemoji (0.11.1)
192 | gemoji (~> 3.0)
193 | html-pipeline (~> 2.2)
194 | jekyll (>= 3.0, < 5.0)
195 | kramdown (2.3.0)
196 | rexml
197 | kramdown-parser-gfm (1.1.0)
198 | kramdown (~> 2.0)
199 | liquid (4.0.3)
200 | listen (3.2.1)
201 | rb-fsevent (~> 0.10, >= 0.10.3)
202 | rb-inotify (~> 0.9, >= 0.9.10)
203 | mercenary (0.3.6)
204 | mini_portile2 (2.8.0)
205 | minima (2.5.1)
206 | jekyll (>= 3.5, < 5.0)
207 | jekyll-feed (~> 0.9)
208 | jekyll-seo-tag (~> 2.1)
209 | minitest (5.14.1)
210 | multipart-post (2.1.1)
211 | nokogiri (1.13.6)
212 | mini_portile2 (~> 2.8.0)
213 | racc (~> 1.4)
214 | octokit (4.18.0)
215 | faraday (>= 0.9)
216 | sawyer (~> 0.8.0, >= 0.5.3)
217 | pathutil (0.16.2)
218 | forwardable-extended (~> 2.6)
219 | public_suffix (3.1.1)
220 | racc (1.6.0)
221 | rb-fsevent (0.10.4)
222 | rb-inotify (0.10.1)
223 | ffi (~> 1.0)
224 | rexml (3.2.5)
225 | rouge (3.19.0)
226 | ruby-enum (0.8.0)
227 | i18n
228 | rubyzip (2.3.0)
229 | safe_yaml (1.0.5)
230 | sass (3.7.4)
231 | sass-listen (~> 4.0.0)
232 | sass-listen (4.0.0)
233 | rb-fsevent (~> 0.9, >= 0.9.4)
234 | rb-inotify (~> 0.9, >= 0.9.7)
235 | sawyer (0.8.2)
236 | addressable (>= 2.3.5)
237 | faraday (> 0.8, < 2.0)
238 | simpleidn (0.1.1)
239 | unf (~> 0.1.4)
240 | terminal-table (1.8.0)
241 | unicode-display_width (~> 1.1, >= 1.1.1)
242 | thread_safe (0.3.6)
243 | typhoeus (1.4.0)
244 | ethon (>= 0.9.0)
245 | tzinfo (1.2.10)
246 | thread_safe (~> 0.1)
247 | unf (0.1.4)
248 | unf_ext
249 | unf_ext (0.0.7.7)
250 | unicode-display_width (1.7.0)
251 | zeitwerk (2.4.0)
252 |
253 | PLATFORMS
254 | ruby
255 |
256 | DEPENDENCIES
257 | github-pages
258 | jekyll (~> 3.9)
259 |
260 | BUNDLED WITH
261 | 2.1.4
262 |
--------------------------------------------------------------------------------
/archive_nbs/TabularTrainingGPU.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from fastai2.tabular.all import *"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "path = untar_data(URLs.ADULT_SAMPLE)\n",
19 | "df = pd.read_csv(path/'adult.csv')"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 3,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, y_names=\"salary\",\n",
29 | " cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race'],\n",
30 | " cont_names = ['age', 'fnlwgt', 'education-num'],\n",
31 | " procs = [Categorify, FillMissing, Normalize])"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 4,
37 | "metadata": {},
38 | "outputs": [
39 | {
40 | "data": {
41 | "text/html": [
42 | "\n",
43 | " \n",
44 | " \n",
45 | " epoch \n",
46 | " train_loss \n",
47 | " valid_loss \n",
48 | " accuracy \n",
49 | " time \n",
50 | " \n",
51 | " \n",
52 | " \n",
53 | " \n",
54 | " 0 \n",
55 | " 0.369606 \n",
56 | " 0.363110 \n",
57 | " 0.827856 \n",
58 | " 00:04 \n",
59 | " \n",
60 | " \n",
61 | " 1 \n",
62 | " 0.367178 \n",
63 | " 0.364534 \n",
64 | " 0.828010 \n",
65 | " 00:04 \n",
66 | " \n",
67 | " \n",
68 | " 2 \n",
69 | " 0.348837 \n",
70 | " 0.367921 \n",
71 | " 0.827088 \n",
72 | " 00:04 \n",
73 | " \n",
74 | " \n",
75 | " 3 \n",
76 | " 0.355421 \n",
77 | " 0.364824 \n",
78 | " 0.830313 \n",
79 | " 00:04 \n",
80 | " \n",
81 | " \n",
82 | " 4 \n",
83 | " 0.352137 \n",
84 | " 0.357688 \n",
85 | " 0.834459 \n",
86 | " 00:04 \n",
87 | " \n",
88 | " \n",
89 | " 5 \n",
90 | " 0.344102 \n",
91 | " 0.360992 \n",
92 | " 0.836456 \n",
93 | " 00:04 \n",
94 | " \n",
95 | " \n",
96 | " 6 \n",
97 | " 0.341461 \n",
98 | " 0.357141 \n",
99 | " 0.838298 \n",
100 | " 00:04 \n",
101 | " \n",
102 | " \n",
103 | " 7 \n",
104 | " 0.330545 \n",
105 | " 0.355759 \n",
106 | " 0.837377 \n",
107 | " 00:04 \n",
108 | " \n",
109 | " \n",
110 | " 8 \n",
111 | " 0.343082 \n",
112 | " 0.357329 \n",
113 | " 0.836609 \n",
114 | " 00:04 \n",
115 | " \n",
116 | " \n",
117 | " 9 \n",
118 | " 0.327605 \n",
119 | " 0.357438 \n",
120 | " 0.835688 \n",
121 | " 00:04 \n",
122 | " \n",
123 | " \n",
124 | " 10 \n",
125 | " 0.320859 \n",
126 | " 0.357013 \n",
127 | " 0.836302 \n",
128 | " 00:04 \n",
129 | " \n",
130 | " \n",
131 | " 11 \n",
132 | " 0.320642 \n",
133 | " 0.360160 \n",
134 | " 0.833692 \n",
135 | " 00:04 \n",
136 | " \n",
137 | " \n",
138 | " 12 \n",
139 | " 0.330805 \n",
140 | " 0.360279 \n",
141 | " 0.836456 \n",
142 | " 00:04 \n",
143 | " \n",
144 | " \n",
145 | " 13 \n",
146 | " 0.321315 \n",
147 | " 0.359108 \n",
148 | " 0.835534 \n",
149 | " 00:04 \n",
150 | " \n",
151 | " \n",
152 | " 14 \n",
153 | " 0.330493 \n",
154 | " 0.358782 \n",
155 | " 0.836302 \n",
156 | " 00:04 \n",
157 | " \n",
158 | " \n",
159 | "
"
160 | ],
161 | "text/plain": [
162 | ""
163 | ]
164 | },
165 | "metadata": {},
166 | "output_type": "display_data"
167 | }
168 | ],
169 | "source": [
170 | "learn = tabular_learner(dls, metrics=accuracy)\n",
171 | "learn.fit_one_cycle(15)"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 5,
177 | "metadata": {},
178 | "outputs": [
179 | {
180 | "data": {
181 | "text/html": [],
182 | "text/plain": [
183 | ""
184 | ]
185 | },
186 | "metadata": {},
187 | "output_type": "display_data"
188 | },
189 | {
190 | "data": {
191 | "text/plain": [
192 | "( workclass education marital-status occupation relationship race \\\n",
193 | " 0 5.0 8.0 3.0 0.0 6.0 5.0 \n",
194 | " \n",
195 | " education-num_na age fnlwgt education-num salary \n",
196 | " 0 1.0 0.762793 -0.837396 0.752865 1.0 ,\n",
197 | " tensor(1),\n",
198 | " tensor([0.4151, 0.5849]))"
199 | ]
200 | },
201 | "execution_count": 5,
202 | "metadata": {},
203 | "output_type": "execute_result"
204 | }
205 | ],
206 | "source": [
207 | "learn.predict(df.iloc[0])\n"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": 6,
213 | "metadata": {},
214 | "outputs": [
215 | {
216 | "data": {
217 | "text/html": [],
218 | "text/plain": [
219 | ""
220 | ]
221 | },
222 | "metadata": {},
223 | "output_type": "display_data"
224 | },
225 | {
226 | "data": {
227 | "text/plain": [
228 | "(tensor([[4.1512e-01, 5.8488e-01],\n",
229 | " [1.9511e-01, 8.0489e-01],\n",
230 | " [9.9947e-01, 5.3267e-04],\n",
231 | " ...,\n",
232 | " [5.7765e-01, 4.2235e-01],\n",
233 | " [6.7786e-01, 3.2214e-01],\n",
234 | " [6.7229e-01, 3.2771e-01]]),\n",
235 | " None)"
236 | ]
237 | },
238 | "execution_count": 6,
239 | "metadata": {},
240 | "output_type": "execute_result"
241 | }
242 | ],
243 | "source": [
244 | "test_df = df.copy()\n",
245 | "test_df.drop(['salary'], axis=1, inplace=True)\n",
246 | "dl = learn.dls.test_dl(test_df)\n",
247 | "###\n",
248 | "learn.get_preds(dl=dl)"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": []
257 | }
258 | ],
259 | "metadata": {
260 | "kernelspec": {
261 | "display_name": "Python 3",
262 | "language": "python",
263 | "name": "python3"
264 | },
265 | "language_info": {
266 | "codemirror_mode": {
267 | "name": "ipython",
268 | "version": 3
269 | },
270 | "file_extension": ".py",
271 | "mimetype": "text/x-python",
272 | "name": "python",
273 | "nbconvert_exporter": "python",
274 | "pygments_lexer": "ipython3",
275 | "version": "3.7.6"
276 | }
277 | },
278 | "nbformat": 4,
279 | "nbformat_minor": 4
280 | }
281 |
--------------------------------------------------------------------------------
/fastai_xla_extensions/multi_core/torch_compat.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03a_multi_core.torch_compat.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['TfmdTorchDS', 'to_list', 'has_setup', 'run_setups', 'TorchDatasetBuilder', 'VocabularyMapper',
4 | 'make_torch_dataloaders', 'FileNamePatternLabeller']
5 |
6 | # Internal Cell
7 | from ..utils import xla_imported
8 | from .base import *
9 | from ..misc_utils import *
10 |
11 |
12 | # Internal Cell
13 | try:
14 | import torch_xla
15 | except ImportError:
16 | pass
17 |
18 | # Internal Cell
19 | if xla_imported():
20 | import torch_xla.core.xla_model as xm
21 | import torch_xla.distributed.xla_multiprocessing as xmp
22 |
23 | # Internal Cell
24 | from fastcore.basics import patch_to
25 | import torch
26 | import torch.utils.data as th_data
27 | from fastcore.foundation import L
28 | from pathlib import Path
29 | from fastcore.transform import Pipeline
30 | from fastai.data.core import DataLoaders
31 | from pathlib import Path
32 | from fastai.torch_core import find_bs, TensorBase
33 | from fastai.torch_core import TensorBase
34 | from fastcore.xtras import is_listy
35 | import torch.utils.hooks
36 | import torch.utils.data.distributed as th_distrib
37 |
38 | # Cell
class TfmdTorchDS(th_data.Dataset):
    "A torch dataset wrapping raw items; x_tfm/y_tfm each map a *whole item* to its x or y value."
    def __init__(self, items, x_tfm=None, y_tfm=None):
        self.items = items
        self.x_tfm = x_tfm  # item -> x (identity when None)
        self.y_tfm = y_tfm  # item -> y (identity when None)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        raw = self.items[index]
        # both transforms receive the full item; a missing tfm passes it through
        x = raw if self.x_tfm is None else self.x_tfm(raw)
        y = raw if self.y_tfm is None else self.y_tfm(raw)
        return (x, y)
54 |
55 | # Internal Cell
56 | import torchvision as thv
57 | from operator import itemgetter
58 | from fastcore.imports import noop
59 |
60 | # Cell
def to_list(o):
    "Coerce `o` to a list: None -> [], listy values unchanged, anything else -> [o]."
    if o is None:
        return []
    return o if is_listy(o) else [o]
64 |
def has_setup(tfms):
    """Return the index of the *last* tfm in `tfms` that defines `setup`, or -1 if none do."""
    with_setup = L(tfms).attrgot('setup', None).argwhere(noop)  # indexes of tfms exposing `setup`
    if not with_setup:
        return -1
    return with_setup[-1]
69 |
def run_setups(tfms, items):
    """Call `setup(items)` on every tfm that has one, feeding transformed items
    forward until the last tfm with a setup has been reached."""
    last_setup = has_setup(tfms)
    if last_setup == -1:
        # nothing to set up
        return
    for i, tfm in enumerate(tfms):
        if hasattr(tfm, 'setup'):
            tfm.setup(items)
        if i < last_setup:
            # transform items so later setups see the data they'd receive at runtime
            items = [tfm(item) for item in items]
82 |
83 | # Cell
class TorchDatasetBuilder:
    """Build torch-compatible train and test datasets with transforms.

    source: passed to `get_items` (or used directly as the item list when
        `get_items` is None).
    get_items: callable producing the items from `source`, or None.
    splitter: callable returning (train_idxs, test_idxs) over the items.
    x_tfms / y_tfms: tfms shared by train and test x / applied to y.
    x_type_tfms: x tfms applied first (e.g. type conversion).
    x_train_tfms / x_test_tfms: split-specific x tfms.
    do_setup: when True, run tfm `setup`s once before building datasets.
    """
    def __init__(self, source, get_items, splitter,
                 x_tfms, y_tfms,
                 x_type_tfms=None,
                 x_train_tfms=None, x_test_tfms=None,
                 do_setup=False):
        self.source = source
        self.get_items = get_items
        self.splitter = splitter
        self.do_setup = do_setup
        self.x_tfms = to_list(x_tfms)
        self.y_tfms = to_list(y_tfms)
        self.x_type_tfms = to_list(x_type_tfms)
        self.x_train_tfms = to_list(x_train_tfms)
        self.x_test_tfms = to_list(x_test_tfms)

    @staticmethod
    def _select(items, idxs):
        "Pick `items` at `idxs`, always returning a sequence."
        # BUG FIX: operator.itemgetter(*idxs) returns the bare item (not a
        # tuple) when idxs has length 1, which broke downstream len()/indexing
        # for 1-element splits.
        if len(idxs) == 1:
            return (items[idxs[0]],)
        return itemgetter(*idxs)(items)

    def setup(self, items, do_setup=None, setup_x=False):
        "Run tfm setups on `items`; x tfms are only set up when `setup_x` is True."
        self.do_setup = do_setup if do_setup is not None else self.do_setup
        if self.do_setup:
            all_x_tfms = [*self.x_type_tfms, *self.x_train_tfms, *self.x_tfms]
            if setup_x:
                run_setups(all_x_tfms, items)
            run_setups(self.y_tfms, items)
            self.do_setup = False  # setups are one-shot

    def get_datasets(self, do_setup=None):
        "Split the items and return (train_ds, test_ds) as TfmdTorchDS instances."
        self.do_setup = do_setup if do_setup is not None else self.do_setup

        items = self.get_items(self.source) if self.get_items is not None else self.source

        train_idxs, test_idxs = self.splitter(items)
        train_items = self._select(items, train_idxs)
        test_items = self._select(items, test_idxs)

        # NOTE(review): setups run on *train* items only, and x tfm setups are
        # skipped (setup_x defaults to False) — confirm this is intended.
        self.setup(train_items)

        train_x_tfm = thv.transforms.Compose(
            [*self.x_type_tfms, *self.x_train_tfms, *self.x_tfms])
        test_x_tfm = thv.transforms.Compose(
            [*self.x_type_tfms, *self.x_test_tfms, *self.x_tfms])
        y_tfm = thv.transforms.Compose(self.y_tfms)
        train_ds = TfmdTorchDS(train_items, x_tfm=train_x_tfm, y_tfm=y_tfm)
        test_ds = TfmdTorchDS(test_items, x_tfm=test_x_tfm, y_tfm=y_tfm)
        return train_ds, test_ds
128 |
129 | # Cell
130 | from fastai.data.transforms import CategoryMap
131 |
class VocabularyMapper:
    """A simplified version of the fastai Categorize transform: maps labels to
    integer index tensors via a vocabulary."""
    def __init__(self, vocab=None):
        self.vocab = vocab  # object exposing an `o2i` label->index mapping, or None
        self.c = 0          # number of classes (filled in by `setup`)

    def setup(self, items):
        "Build the vocabulary from `items` and record the class count."
        self.vocab = CategoryMap(items)
        self.c = len(self.vocab)

    def __call__(self, o):
        "Map label `o` to an index tensor; pass through unchanged when no vocab is set."
        if self.vocab is None:
            return o
        try:
            idx = self.vocab.o2i[o]
        except KeyError as e:
            raise KeyError(f"Label '{o}' was not included in the training dataset") from e
        return torch.tensor(idx)
146 |
147 | # Cell
def make_torch_dataloaders(train_dataset, test_dataset,
                           rank,
                           world_size,
                           bs,
                           num_workers=4,
                           distrib=True,
                           sync_valid=False):
    "Make torch-based (optionally distributed) dataloaders from torch compatible datasets."
    def _loader(ds, sampler=None, shuffle=False):
        # every loader shares bs/num_workers and drops the incomplete final batch
        if sampler is not None:
            return th_data.DataLoader(ds, batch_size=bs, sampler=sampler,
                                      num_workers=num_workers, drop_last=True)
        return th_data.DataLoader(ds, batch_size=bs, shuffle=shuffle,
                                  num_workers=num_workers, drop_last=True)

    def _dist_sampler(ds, shuffle):
        return th_distrib.DistributedSampler(ds, num_replicas=world_size,
                                             rank=rank, shuffle=shuffle)

    if distrib:
        train_loader = _loader(train_dataset,
                               sampler=_dist_sampler(train_dataset, True))
        if sync_valid:
            # validation is sharded across ranks like training
            test_loader = _loader(test_dataset,
                                  sampler=_dist_sampler(test_dataset, False))
        else:
            # every rank iterates the full validation set
            test_loader = _loader(test_dataset, shuffle=False)
    else:
        train_loader = _loader(train_dataset, shuffle=True)
        test_loader = _loader(test_dataset, shuffle=False)

    return DataLoaders(train_loader, test_loader, device=None)
209 |
210 | # Internal Cell
211 | import re
212 |
213 | # Cell
class FileNamePatternLabeller:
    "Delayed-compilation version of fastai's RegexLabeller, applied to a path's file name."
    def __init__(self, pat_str, match=False):
        self.pat_str = pat_str  # regex source, compiled lazily on first call
        self.match = match      # True: anchor with re.match; False: re.search
        self.matcher = None
        self.pat = None

    def __call__(self, f):
        "Extract group(1) of the pattern from the file name of `f` (str or Path)."
        path = Path(f) if isinstance(f, str) else f
        name = path.name
        if self.pat is None:
            # lazy compile — presumably so construction stays cheap until first
            # use in a worker process; TODO confirm rationale
            self.pat = re.compile(self.pat_str)
            self.matcher = self.pat.match if self.match else self.pat.search
        res = self.matcher(name)
        assert res, f'Failed to find "{self.pat}" in {name}'
        return res.group(1)
--------------------------------------------------------------------------------
/fastai_xla_extensions/multi_core/inference.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03e_multi_core.inference.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['setup_inference_args', 'save_pred_results', 'xla_run_inference', 'reload_pred_results',
4 | 'prep_inference_args']
5 |
6 | # Cell
7 | try:
8 | import torch_xla
9 | except ImportError:
10 | pass
11 |
12 | # Cell
13 |
14 | from fastai.vision.all import *
15 | from ..utils import xla_imported
16 | from ..misc_utils import *
17 | from ..core import XLAOptCallback
18 | from .base import *
19 | from .learner import *
20 | from .callback import *
21 |
22 | # Cell
23 |
24 | if xla_imported():
25 | import torch_xla.core.xla_model as xm
26 | import torch_xla.distributed.parallel_loader as pl
27 | import torch_xla.distributed.xla_multiprocessing as xmp
28 |
29 | # Cell
30 | from fastai.vision.all import *
31 |
32 |
33 | # Cell
34 | from fastai.learner import _ConstantFunc
35 | # from fastcore.basics import patch
36 | # from fastai.learner import Learner
37 |
@patch
def inner_get_preds(self:Learner, ds_idx=1, dl=None, with_input=False, with_decoded=False, with_loss=False, act=None,
                inner=False, reorder=True, cbs=None, **kwargs):
    """Per-rank `get_preds` used inside an `xmp.spawn` child process.

    Runs validation on this rank's shard of the dataloader and returns the
    gathered result tensors, reordered back into this rank's original sample
    order. Returns None when `self.xla_rank` is absent (i.e. not running in a
    spawned XLA process).
    """
    xla_rank = getattr(self,'xla_rank',None)
    if xla_rank is None:
        return  # not inside a spawned XLA process: nothing to do

    if dl is None:
        dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
    else:
        try:
            len(dl)
        except TypeError as e:
            raise TypeError("`dl` is something other than a single `DataLoader` object") from e
        if not isinstance(dl, TPUDistributedDL):
            # wrap a caller-supplied dl so each rank only sees its own shard
            world_size = kwargs.pop('world_size', xm.xrt_world_size())
            seed = kwargs.pop('dl_seed', 42)
            dl = TPUDistributedDL(dl, xla_rank, world_size=world_size, seed=seed)

    if reorder and hasattr(dl, 'get_idxs'):
        # freeze the underlying dl's index order so it can be undone afterwards
        idxs = dl.dl.get_idxs()
        dl = dl.new(get_idxs=_ConstantFunc(idxs))
    rank_idxs = dl.get_idxs()
    rank_idxs_len = len(rank_idxs)

    # give save_preds/save_targs a per-rank directory so ranks don't collide
    save_preds = kwargs.pop('save_preds', None)
    if save_preds is not None:
        if isinstance(save_preds, (str, Path)):
            kwargs['save_preds'] = Path(str(save_preds) + str(xla_rank))  # add rank suffix
        kwargs['save_preds'].mkdir(parents=True, exist_ok=True)
    save_targs = kwargs.pop('save_targs', None)
    if save_targs is not None:
        # BUG FIX: this branch previously tested `isinstance(save_preds, Path)`
        # (copy-paste), so a Path-valued save_targs was silently dropped.
        if isinstance(save_targs, (str, Path)):
            kwargs['save_targs'] = Path(str(save_targs) + str(xla_rank))
        kwargs['save_targs'].mkdir(parents=True, exist_ok=True)

    cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss, **kwargs)
    ctx_mgrs = self.validation_context(cbs=L(cbs)+[cb], inner=inner)
    if with_loss:
        ctx_mgrs.append(self.loss_not_reduced())

    with ContextManagers(ctx_mgrs):
        self._do_epoch_validate(dl=dl)

    if act is None:
        act = getattr(self.loss_func, 'activation', noop)

    res = cb.all_tensors()

    pred_i = 1 if with_input else 0
    if res[pred_i] is not None:
        if act != noop:
            # compute the activation on the TPU device and detach after
            tmp_pred = res[pred_i].to(xm.xla_device())
            res[pred_i] = self.to_detach(act(tmp_pred))

        if with_decoded:
            res.insert(pred_i+2, getattr(self.loss_func, 'decodes', noop)(res[pred_i]))

    if reorder and hasattr(dl, 'get_idxs'):
        # rank idxs are global sample positions; shift to rank-local and undo the shuffle
        t_idxs = tensor(rank_idxs) - tensor(xla_rank * rank_idxs_len)  # broadcast
        sorted_idxs = t_idxs.argsort()
        res = nested_reorder(res, sorted_idxs)

    # NOTE: removed an unreachable `self._end_cleanup()` that followed this
    # return in the original (same dead line as fastai's get_preds of that era).
    return tuple(res)
112 |
113 |
114 | # Cell
115 |
def setup_inference_args(rank, inference_args):
    "Split `master_cbs` (rank-0-only callbacks) out of the inference kwargs."
    master_cbs = inference_args.pop('master_cbs', None)
    if master_cbs is None:
        master_cbs = []
    return inference_args, master_cbs
119 |
120 |
121 | # Cell
122 |
123 | import pickle
def save_pred_results(rank, results):
    "Pickle this rank's prediction results to `preds{rank}.pkl` in the cwd."
    out_path = Path(f'preds{rank}.pkl')
    out_path.write_bytes(pickle.dumps(results))
129 |
130 | # Cell
131 |
def xla_run_inference(rank, learner_args, add_args, inference_args, ctrl_args):
    """Spawned (per-TPU-core) entry point: rebuild the learner, run inference on
    this rank's shard, and persist the results to disk for the parent process.

    Invoked via `xmp.spawn` from `xla_get_preds`; `rank` is this process's
    core index.
    """
    sync_valid = True  # validation must be sharded identically across ranks
    learner = make_xla_child_learner(rank, sync_valid, learner_args, add_args, ctrl_args)
    pred_args, master_cbs = setup_inference_args(rank, inference_args)

    if rank == 0 and len(master_cbs) > 0:
        # master-only callbacks run on rank 0 exclusively
        learner.add_cbs(master_cbs)

    # learner.synced_cancel.before_fit()

    if rank == 0:
        learner.sync_recorder.orig_logger = learner.logger

    results = learner.inner_get_preds(**pred_args)
    # barrier: wait for all ranks to finish inference before writing results
    xm.rendezvous('xla_run_inference')

    save_pred_results(rank, results)
    xm.mark_step()  # flush pending XLA ops before the process exits
150 |
151 |
152 | # Cell
153 | from fastcore.foundation import L
154 |
def reload_pred_results(num_files, n_samples):
    """Load and merge the per-rank `preds{rank}.pkl` files written by
    `save_pred_results`, then delete them.

    num_files: number of rank files to read (one per spawned core).
    n_samples: number of samples to keep per result component (distributed
        sharding may produce extras — each component is trimmed to this count).
    Returns a list with one merged entry per result component: tensors are
    concatenated across ranks, other values are collected as-is.
    Raises RuntimeError if any rank's file is missing.
    """
    all_preds = []
    for rank in range(num_files):
        fn = Path(f'preds{rank}.pkl')
        if not fn.is_file():
            # BUG FIX: previously raised the non-existent name `RuntimeException`,
            # which itself produced a NameError instead of the intended error.
            raise RuntimeError(f'Missing preds file for rank {rank}')
        with open(fn, 'rb') as f:
            all_preds.append(pickle.load(f))

    # all files were read successfully: remove them
    for rank in range(num_files):
        Path(f'preds{rank}.pkl').unlink()

    n_items = len(all_preds[0])  # number of result components per rank

    all_res = []
    for i in range(n_items):
        items = [preds[i] for preds in all_preds]  # i-th component from every rank
        if isinstance(items[0], torch.Tensor):
            merged = torch.cat(tuple(items))
        else:
            # non-tensor components (lists, scalars, ...) are collected per rank
            merged = items
        all_res.append(merged)

    # trim sharding overshoot: keep only the first n_samples entries per component
    return [pred[:n_samples] for pred in all_res]
191 |
192 |
193 |
194 | # Cell
195 |
@patch
def pre_xla_inference(self:Learner):
    "Detach the progress callback before spawning; record whether to restore it later."
    had_progress = 'progress' in L(self.cbs).attrgot('name')
    if had_progress:
        self.remove_cbs(ProgressCallback)
    return {'use_progress': had_progress}
205 |
206 | # Cell
207 |
@patch
def post_xla_inference(self:Learner, ctrl_args):
    "Undo `pre_xla_inference`: re-attach ProgressCallback if removed, refresh recorder."
    restore_progress = ctrl_args['use_progress']
    if restore_progress:
        self.add_cbs(ProgressCallback)
    self.recorder.reload_attrs()
213 |
214 | # Cell
215 |
def prep_inference_args(**kwargs):
    "Collect the inference keyword arguments into a plain dict."
    return dict(kwargs)
218 |
219 | # Cell
220 |
221 | #export
222 |
@patch
@delegates(Learner.get_preds, but='num_cores,start_method,master_cbs')
def xla_get_preds(self:Learner, ds_idx=1, dl=None,
                  with_input=False, with_decoded=False,
                  with_loss=False, act=None, inner=False,
                  reorder=True, cbs=None, num_cores=8,
                  start_method='fork', master_cbs=None,**kwargs):
    """Multi-core TPU version of `Learner.get_preds`.

    Spawns `num_cores` XLA processes that each run `xla_run_inference`, then
    reloads and merges the per-rank pickled results.

    num_cores: number of TPU cores / processes to spawn.
    start_method: multiprocessing start method passed to `xmp.spawn`.
    master_cbs: callbacks added only on rank 0.
    Remaining arguments are forwarded to `inner_get_preds` on each rank.
    """
    # detach the progress callback before spawning (see pre_xla_inference)
    ctrl_args = self.pre_xla_inference()
    learner_args, add_args = self.pack_learner_args()

    inference_args = prep_inference_args(ds_idx=ds_idx, dl=dl,
                            with_input=with_input, with_decoded=with_decoded,
                            with_loss=with_loss,
                            act=act, inner=inner,
                            reorder=reorder,
                            cbs=cbs, master_cbs=master_cbs, **kwargs)
    # total number of samples expected back; per-rank results are trimmed
    # to this count by reload_pred_results
    if dl:
        n_results = len(dl.dataset)
    else:
        n_results = len(self.dls.loaders[ds_idx].dataset)

    xmp.spawn(xla_run_inference,
              args=(learner_args, add_args, inference_args, ctrl_args),
              nprocs=num_cores,
              start_method=start_method)

    # each rank wrote preds{rank}.pkl; merge them and restore learner state
    all_results = reload_pred_results(num_cores, n_results)
    self.post_xla_inference(ctrl_args)
    return all_results
252 |
--------------------------------------------------------------------------------
/nbs/99_dev_setup.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "#default_exp dev_setup"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "# Development Setup"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "[ \n",
24 | "](https://)"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {},
30 | "source": [
31 | "> Notes about `fastai_xla_extensions` development setup \n",
32 | "\n",
33 | "\n",
34 | "The `fastai_xla_extensions` project uses `nbdev` to build the package from Jupyter notebooks.\n",
35 | "\n",
36 | "In order to contribute or update the package, the development environment\n",
37 | "must be setup to run on Colab, which provides access to TPUs.\n",
38 | "\n",
39 | "The following steps are based on [instructions from `nbdev` documentation](https://nbdev.fast.ai/tutorial_colab.html#Setting-up-your-Google-Drive-and-Git-Configuration) on running nbdev on Colab."
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {},
45 | "source": [
46 | "## Google Drive setup\n",
47 | "\n",
48 | "Connect your Colab instance to your google drive"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "#colab\n",
58 | "from google.colab import drive\n",
59 | "drive.mount('/content/drive')\n",
60 | "%cd /content/drive"
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {},
66 | "source": [
67 | "## Git setup\n",
68 | "\n",
69 | "1. Fork the main fastai_xla_extensions github repo on Github\n",
70 | "\n",
71 | "2. Clone the forked fastai_xla_extensions github repo\n",
72 | " - in the clone command below, replace the _butchland_ user id with your own github user id\n"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "#colab\n",
82 | "%cd /content/drive/MyDrive\n",
83 | "!git clone https://github.com/butchland/fastai_xla_extensions.git"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": null,
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "#colab\n",
93 | "%cd /content\n",
94 | "!ln -s /content/drive/MyDrive/fastai_xla_extensions fastai_xla_extensions"
95 | ]
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "metadata": {},
100 | "source": [
101 | "## Install Pytorch XLA"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "metadata": {},
108 | "outputs": [
109 | {
110 | "name": "stdout",
111 | "output_type": "stream",
112 | "text": [
113 | "\u001b[K |████████████████████████████████| 133.6MB 77kB/s \n",
114 | "\u001b[K |████████████████████████████████| 61kB 3.1MB/s \n",
115 | "\u001b[?25h"
116 | ]
117 | }
118 | ],
119 | "source": [
120 | "#colab\n",
121 | "!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {},
127 | "source": [
128 | "## Install fastai and other dependencies\n",
129 | "\n"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "#colab\n",
139 | "!pip install -Uqq fastai==2.3.0 \n",
140 | "!pip install -qqq nbdev"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "metadata": {},
146 | "source": [
147 | "## Import *fastai_xla_extensions* modules"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "#colab\n",
157 | "%cd /content/fastai_xla_extensions\n",
158 | "from fastai_xla_extensions.all import *\n",
159 | "%cd /content"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {},
165 | "source": [
166 | "## Run nbdev commands "
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "#colab\n",
176 | "%cd /content/fastai_xla_extensions\n",
177 | "!nbdev_clean_nbs && nbdev_build_lib && nbdev_build_docs"
178 | ]
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "metadata": {},
183 | "source": [
184 | "## Run git commands \n"
185 | ]
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "metadata": {},
190 | "source": [
191 | "### Specify Github Credentials\n",
192 | "\n",
193 |     "Your `github_id` should contain the Github ID you used when you forked the `fastai_xla_extensions` repo; the `github_repo` name is filled in automatically by a later cell. \n",
194 | "\n"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "#colab\n",
204 | "#@title Enter your github info {display-mode: \"form\"}\n",
205 | "github_id = \"\" #@param {type: \"string\"}\n",
206 | "user_email = \"\" #@param {type: \"string\"}\n",
207 | "real_name = \"\" #@param {type: \"string\"}\n"
208 | ]
209 | },
210 | {
211 | "cell_type": "markdown",
212 | "metadata": {},
213 | "source": [
214 | "\n",
215 | "Check that the github ID, repo, email and name have been filled out\n"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": null,
221 | "metadata": {},
222 | "outputs": [],
223 | "source": [
224 | "#colab\n",
225 | "github_repo = 'fastai_xla_extensions'\n",
226 | "if github_id == \"\" or github_repo == \"\" or user_email == \"\" or real_name == \"\":\n",
227 | " print(\"Rerun your notebook by pressing Cmd/Ctrl-F9 or menu Runtime/Run all\")\n",
228 | " raise RuntimeError(\"Please enter your Github ID and Repo as well as your user email and name\")"
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": null,
234 | "metadata": {},
235 | "outputs": [],
236 | "source": [
237 | "#colab\n",
238 | "!git config --global user.name \"{real_name}\"\n",
239 | "!git config --global user.email \"{user_email}\""
240 | ]
241 | },
242 | {
243 | "cell_type": "code",
244 | "execution_count": null,
245 | "metadata": {},
246 | "outputs": [],
247 | "source": [
248 | "#colab\n",
249 | "from pathlib import Path\n",
250 | "if not (Path('/content')/github_repo).is_dir():\n",
251 | " print('You might have entered the wrong github credentials')\n",
252 | " raise RuntimeError(f'Could not download your github repo https://github.com/{github_id}/{github_repo}.git')\n",
253 | "%cd /content/{github_repo}"
254 | ]
255 | },
256 | {
257 | "cell_type": "markdown",
258 | "metadata": {},
259 | "source": [
260 | "Run git commands to add your changes to your local repo."
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": null,
266 | "metadata": {},
267 | "outputs": [],
268 | "source": [
269 | "#colab\n",
270 | "# !git status\n",
271 |     "# !git add --all\n",
272 | "# !git commit -m \"\""
273 | ]
274 | },
275 | {
276 | "cell_type": "markdown",
277 | "metadata": {},
278 | "source": [
279 | "### Enter Github Password \n",
280 | "Please enter your github password as requested\n"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": null,
286 | "metadata": {},
287 | "outputs": [],
288 | "source": [
289 | "#colab\n",
290 | "#@title {display-mode: \"form\"}\n",
291 | "print('Please enter your password.')\n",
292 | "import getpass\n",
293 | "github_password = getpass.getpass()\n"
294 | ]
295 | },
296 | {
297 | "cell_type": "markdown",
298 | "metadata": {},
299 | "source": [
300 | "## Push changes to repo"
301 | ]
302 | },
303 | {
304 | "cell_type": "code",
305 | "execution_count": null,
306 | "metadata": {},
307 | "outputs": [],
308 | "source": [
309 | "#colab\n",
310 | "#@title {display-mode: \"form\"}\n",
311 | "!git config credential.helper store\n",
312 | "!echo \"https://{github_id}:{github_password}@github.com\" > /root/.git-credentials\n",
313 | "!git push\n",
314 | "!rm -f /root/.git-credentials\n"
315 | ]
316 | }
317 | ],
318 | "metadata": {
319 | "kernelspec": {
320 | "display_name": "Python 3 (ipykernel)",
321 | "language": "python",
322 | "name": "python3"
323 | }
324 | },
325 | "nbformat": 4,
326 | "nbformat_minor": 4
327 | }
328 |
--------------------------------------------------------------------------------
/fastai_xla_extensions/multi_core/lr_find.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03d_multi_core.lr_find.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['SkipValidationCallback', 'SyncedCancelCallback', 'XLALRFinder', 'xla_run_lr_find']
4 |
5 | # Internal Cell
6 | from ..utils import xla_imported
7 | from .base import *
8 | from .callback import *
9 | from .learner import *
10 | from ..misc_utils import *
11 | from ..core import *
12 |
13 |
14 | # Internal Cell
15 | try:
16 | import torch_xla
17 | except:
18 | pass
19 |
20 |
21 | # Internal Cell
22 | if xla_imported():
23 | import torch_xla.core.xla_model as xm
24 | import torch_xla.distributed.xla_multiprocessing as xmp
25 | import torch_xla.distributed.parallel_loader as pl
26 |
27 | # Cell
28 |
29 | from fastai.callback.core import Callback
30 | from fastai.learner import CancelValidException
class SkipValidationCallback(Callback):
    "Skip the validation phase entirely while the LR finder is running."
    # order=-9 fires this callback ahead of XLATrainingCallback, so the
    # CancelValidException is raised before wrap_parallel_loader would run
    # in before_validate; run_valid=False keeps fastai from scheduling this
    # callback itself during validation.
    order = -9
    run_valid = False

    def before_validate(self):
        # Abort validation before any per-device loader wrapping happens.
        raise CancelValidException()

    def after_cancel_validate(self):
        # Flush the pending XLA graph so the step state stays consistent.
        xm.mark_step()
40 |
41 |
42 | # Cell
43 |
44 | from fastcore.basics import patch
45 | # uncomment for notebook2html
46 | # import torch_xla.distributed.parallel_loader as pl
47 | from ..utils import xla_imported
48 |
# Give XLA's per-device loader a `close` method so callbacks can shut down
# its background feeder queues when training is cancelled early
# (used by SyncedCancelCallback.after_batch via `self.dl.close()`).
if xla_imported():
    @patch
    def close(self:pl.PerDeviceLoader):
        'close data loader queues on xla parallel loader'
        self._loader.close() # delegate to the owning ParallelLoader's close
else:
    # NOTE(review): when torch_xla is absent this defines a plain module-level
    # `close` (it patches nothing); it appears to exist only so the module can
    # still be imported/exported without XLA installed — confirm intent.
    def close(self):
        pass
57 |
58 | # Cell
59 |
60 | import torch
61 | from fastai.callback.core import Callback
62 | from fastai.learner import CancelFitException
63 |
class SyncedCancelCallback(Callback):
    """A Callback to cancel training in sync
    (closing data loaders queues across all ranks)"""
    order = 199 # after all other callbacks

    def before_fit(self):
        # Runs only inside spawned XLA worker processes; `inner_xla` is set on
        # the learner by the multi-core spawning machinery.
        if not getattr(self.learn,'inner_xla',False):
            return # skip if not spawned

        # Device-resident flag tensors: 0 = keep training, 1 = cancel requested.
        # Kept on the XLA device so they can participate in all_reduce below.
        self.zero = torch.zeros(1).to(self.xla_training.pdevice)
        self.one = torch.ones(1).to(self.xla_training.pdevice)
        self.sync_cancel_fit = self.zero

    def after_batch(self):
        if not getattr(self.learn,'inner_xla',False):
            return # skip if not spawned

        # Sum the cancel flags across all ranks, so a cancel requested by any
        # one rank becomes visible to every rank at the same batch boundary.
        cancel_fit = xm.all_reduce(xm.REDUCE_SUM, self.sync_cancel_fit)

        if cancel_fit > self.zero: # a rank triggered a cancel
            self.dl.close() # close per device loader
            raise CancelFitException()

    def trigger_cancel_fit(self):
        # Called by other callbacks (e.g. XLALRFinder) to request a synced
        # stop; takes effect at the next after_batch on every rank.
        self.sync_cancel_fit = self.one
89 |
90 | # Cell
91 |
92 | from fastai.callback.schedule import ParamScheduler, SchedExp
93 | from fastcore.xtras import is_listy
94 | from fastcore.imports import noop
class XLALRFinder(ParamScheduler):
    "Training with exponentially growing learning rate"
    def __init__(self, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True):
        # A listy start_lr/end_lr means one exponential schedule per parameter
        # group; otherwise one schedule drives all groups.
        if is_listy(start_lr):
            self.scheds = {'lr': [SchedExp(s, e) for (s,e) in zip(start_lr,end_lr)]}
        else: self.scheds = {'lr': SchedExp(start_lr, end_lr)}
        self.num_it,self.stop_div = num_it,stop_div
        # Once True, this rank stops scheduling/recording and merely waits for
        # the synced cancel to end the fit.
        self.skip_batch = False



    def before_fit(self):
        super().before_fit()
        # no need to save orig weights
        # since learner instances are transient on spawned procs
        # self.learn.save('_tmp')
        self.best_loss = float('inf')
        self.skip_batch = False

    def before_epoch(self):
        # dont report losses while running lrfind (override sync_recorder)
        if not xm.is_master_ordinal():
            return
        if hasattr(self.learn, 'sync_recorder'):
            self.learn.logger = noop
            self.learn.sync_recorder._sync_stats_log = noop

    def before_batch(self):
        if self.skip_batch:
            return
        # Advance the exponential LR schedule by the fraction of target
        # iterations completed so far.
        self._update_val(self.train_iter/self.num_it)

    def after_batch(self):
        if self.skip_batch:
            return
        super().after_batch()
        smooth_loss = self.smooth_loss.item() # move xla tensor to cpu
        if smooth_loss < self.best_loss:
            self.best_loss = smooth_loss

        # handle continuation of batch iteration until all batches exhausted
        # Divergence check: stop once loss blows up to 4x the best seen.
        if smooth_loss > 4*self.best_loss and self.stop_div:
            # print(f'xla {xm.get_ordinal()}: stop stats collection due to loss')
            self.skip_batch = True
            # self.copy_losses_and_lrs()
            self.synced_cancel.trigger_cancel_fit()
            return


        # Normal completion: requested number of iterations has been reached.
        if self.train_iter >= self.num_it:
            # print(f'xla {xm.get_ordinal()}: stop stats collection due to num_iter')
            # return and stop updating losses
            self.skip_batch = True
            # self.copy_losses_and_lrs()
            self.synced_cancel.trigger_cancel_fit()
            return

    # def copy_losses_and_lrs(self):
    #     if xm.is_master_ordinal():
    #         losses = [loss.item() for loss in self.recorder.losses]
    #         iters = self.recorder.iters[:]
    #         values = self.recorder.values[:]

    #         self.plot_data = {'lrs': self.recorder.lrs[:],
    #                     'losses': losses,
    #                     'iters': iters,
    #                     'values': values}
    #         if hasattr(self,'hps'):
    #             self.plot_data['hps'] = {**self.hps}

    # def after_fit(self):
    #     super().after_fit()
    #     # no need to load old weights since these will be transient
    #     # self.learn.opt.zero_grad() #Need to zero the gradients of the model before detaching the optimizer for future fits
    #     # tmp_f = self.path/self.model_dir/'_tmp.pth'
    #     # if tmp_f.exists():
    #     #     self.learn.load('_tmp', with_opt=True)
    #     #     os.remove(tmp_f)
    #     if not self.skip_batch:
    #         self.copy_losses_and_lrs()
    #     if xm.is_master_ordinal():
    #         with open('_plt_loss.pkl','wb') as f:
    #             pickle.dump(self.plot_data,f)
178 |
179 |
180 | # Cell
181 | from fastai.learner import Learner
182 | from fastai.callback.schedule import SuggestedLRs
183 | from fastcore.basics import patch
184 | from fastai.torch_core import tensor
@patch
def get_suggested_lrs(self:Learner, num_it):
    "Derive `SuggestedLRs` (min-loss LR / 10, steepest-slope LR) from recorder history."
    # Drop the first 10% of iterations (warm-up noise) and the final 5 points
    # (typically diverging) before analysing the curve.
    start = num_it//10
    lr_hist = tensor(self.recorder.lrs[start:-5])
    loss_hist = tensor(self.recorder.losses[start:-5])
    if len(loss_hist) == 0: return
    # LR at the lowest recorded loss, scaled down by 10 as a safe default.
    lr_min = lr_hist[loss_hist.argmin()].item()
    # Slope of loss w.r.t. log(lr); the steepest descent marks an aggressive pick.
    slopes = (loss_hist[1:]-loss_hist[:-1]) / (lr_hist[1:].log()-lr_hist[:-1].log())
    lr_steep = lr_hist[slopes.argmin()].item()
    return SuggestedLRs(lr_min/10.,lr_steep)
194 |
195 |
196 | # Cell
197 |
def xla_run_lr_find(rank, learner_args, add_args, lr_find_args, ctrl_args):
    "Spawned-process entry point: rebuild a child learner on this rank and run the LR finder."
    # Make every rank start the sweep together.
    xm.rendezvous('start_xla_run_lr_find')
    learn = make_xla_child_learner(rank, True, learner_args, add_args, ctrl_args)

    # Enough epochs to cover `num_it` scheduler steps over the training loader.
    iters = lr_find_args['num_it']
    epochs = iters//len(learn.dls.train) + 1

    # Rebuild the optimizer from scratch so no stale state leaks into the sweep.
    learn.opt = None
    learn.create_opt()

    cbs = [XLALRFinder(**lr_find_args),
           SkipValidationCallback(),
           SyncedCancelCallback()]
    with learn.no_logging():
        learn.fit(epochs, cbs=cbs)
215 |
216 |
217 |
218 | # Cell
219 |
220 | from pathlib import Path
221 | from fastai.learner import Learner
222 | from fastcore.basics import patch
223 | from fastcore.meta import delegates
224 |
@patch
@delegates(Learner.lr_find)
def xla_lr_find(self:Learner, num_cores=8, start_method='fork', **kwargs):
    "Run the LR finder across `num_cores` XLA devices, then optionally plot and suggest LRs."
    # Clear any stale plot data left behind by a previous run.
    stale = Path('_plt_loss.pkl')
    if stale.is_file():
        stale.unlink()

    # These two flags control local behavior only and must not reach XLALRFinder.
    show_plot = kwargs.pop('show_plot', True)
    suggestions = kwargs.pop('suggestions',True)

    # Defaults mirror fastai's lr_find; caller kwargs take precedence.
    defaults = dict(start_lr=1e-7, end_lr=10., num_it=100, stop_div=True)
    lr_find_args = {**defaults, **kwargs}

    # Snapshot/pack learner state, fan out to the XLA workers, then restore.
    ctrl_args = self.pre_xla_fit()
    learner_args, add_args = self.pack_learner_args()
    xmp.spawn(xla_run_lr_find,
              args=(learner_args, add_args, lr_find_args, ctrl_args),
              nprocs=num_cores,
              start_method=start_method)
    self.post_xla_fit(ctrl_args)

    if show_plot:
        self.recorder.plot_lr_find()
    if suggestions:
        return self.get_suggested_lrs(lr_find_args['num_it'])
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------