├── hugdatafast ├── __init__.py ├── transform.py └── fastai.py ├── docs ├── requirements.txt ├── Makefile ├── source │ ├── index.rst │ ├── conf.py │ └── start.rst └── make.bat ├── .gitignore ├── README.md ├── setup.py ├── LICENSE └── tests └── hf_nlp_extension_test.ipynb /hugdatafast/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.4.5" 2 | from .fastai import * 3 | from .transform import * -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | Sphinx 2 | sphinx-autoapi 3 | sphinxcontrib-napoleon 4 | sphinx-copybutton 5 | sphinx_rtd_theme>=0.5.0 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | # Distribution / packaging 4 | .Python 5 | build/ 6 | develop-eggs/ 7 | dist/ 8 | downloads/ 9 | eggs/ 10 | .eggs/ 11 | lib/ 12 | lib64/ 13 | parts/ 14 | sdist/ 15 | var/ 16 | wheels/ 17 | *.egg-info/ 18 | .installed.cfg 19 | *.egg 20 | MANIFEST 21 | 22 | # PyInstaller 23 | # Usually these files are written by a python script from a template 24 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 25 | *.manifest 26 | *.spec 27 | 28 | # Installer logs 29 | pip-log.txt 30 | pip-delete-this-directory.txt 31 | 32 | # editor 33 | .vscode/* -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 
6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | hugdatafast 2 | ============= 3 | 4 | This package provides an elegant bridge between fastai and huggingface/datasets, plus some handy data transforms 5 | for NLPers. 6 | 7 | Author: Richard Wang 8 | 9 | Twitter: `Richard Wang `_ (You can follow it to get occasional news about the package, or to see my recent research.) 10 | 11 | Installation 12 | --------------- 13 | 14 | :: 15 | 16 | pip install hugdatafast 17 | 18 | This will also install the latest ``fastai`` and ``datasets``. 19 | 20 | .. toctree:: 21 | :maxdepth: 2 22 | :caption: About 23 | 24 | start 25 | 26 | .. toctree:: 27 | :maxdepth: 4 28 | :caption: API reference 29 | :glob: 30 | 31 | autoapi/hugdatafast/fastai/* 32 | autoapi/hugdatafast/transform/* 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # hugdatafast 2 | The elegant integration of huggingface/datasets and fastai, and some handy data transformations for huggingface/datasets.
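One of those handy transforms, `LMTransform`, combines tokenized examples into fixed-length language-model samples whose target is the input shifted by one token. Below is a minimal pure-Python sketch of that chunking idea — `lm_chunks` is a hypothetical helper for illustration, not this package's API; the real class extends `CombineTransform` and runs through `datasets`' `map`, and unlike this sketch it can keep an incomplete final sample unless `drop_last=True`:

```python
def lm_chunks(tokens, max_len):
    """Greedily slice a flat token stream into (x, y) pairs, where y is x shifted by one."""
    window = max_len + 1        # one extra token is needed to build the shifted target
    buf, pairs = [], []
    for tok in tokens:
        buf.append(tok)
        if len(buf) == window:  # a full window accumulated -> emit one LM sample
            pairs.append((buf[:-1], buf[1:]))
            buf = []            # start the next sample from scratch (no overlap)
    return pairs                # leftover tokens in buf are dropped in this sketch

print(lm_chunks(list(range(10)), max_len=4))
# -> [([0, 1, 2, 3], [1, 2, 3, 4]), ([5, 6, 7, 8], [6, 7, 8, 9])]
```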
3 | 4 | 🎓 **Documentation** : https://hugdatafast.readthedocs.io/en/latest/ 5 | 6 | # Install 7 | `pip install hugdatafast` 8 | 9 | # Future Plan 10 | - I would like to merge this library into fastai and huggingface/datasets respectively, but I may not have time for it. You're welcome to open PRs contributing this library's features to those two libraries. 11 | 12 | - The implementation of `CombineTransform` works but might be too complex to extend; I hope HuggingFace or someone else comes up with some better ideas. 13 | 14 | - Currently, it is designed to work with the datasets part of huggingface/datasets; I may or may not also integrate the metrics part. 15 | 16 | # Quick Intro 17 | ![hugdatafast_fastai](https://user-images.githubusercontent.com/17963619/92091020-be672f00-ee02-11ea-84c0-d54b4855ff4b.png) 18 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from hugdatafast.__init__ import __version__ 3 | 4 | with open("README.md", "r") as fh: 5 | long_description = fh.read() 6 | 7 | REQUIRED_PKGS = [ 8 | 'fastai>=2.0.8', 9 | 'fastcore>=1.0.1', # change of store_attr api 10 | 'datasets', 11 | ] 12 | 13 | setuptools.setup( 14 | name="hugdatafast", 15 | version=__version__, 16 | author="Richard Wang", 17 | author_email="richardyy1188@gmail.com", 18 | description="The elegant bridge between huggingface/datasets and fastai", 19 | long_description=long_description, 20 | long_description_content_type="text/markdown", 21 | url="https://github.com/richarddwang/hugdatafast", 22 | license='Apache 2.0', 23 | packages=setuptools.find_packages(), 24 | classifiers=[ 25 | "Development Status :: 4 - Beta", 26 | "Intended Audience :: Developers", 27 | "Intended Audience :: Science/Research", 28 | "Programming Language :: Python :: 3", 29 | "License :: OSI Approved :: Apache Software License", 30 | "Operating System :: OS Independent", 31 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 32 | ], 33 | python_requires='>=3.6', 34 | install_requires=REQUIRED_PKGS, 35 | keywords='datasets machine learning datasets metrics fastai huggingface', 36 | ) -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder.
2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'hugdatafast' 21 | copyright = '2020, Richard Wang' 22 | author = 'Richard Wang' 23 | 24 | # The full version, including alpha/beta/rc tags 25 | #release = 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | 'autoapi.extension', 35 | #'sphinx.ext.napoleon', # google style docstring 36 | 'sphinxcontrib.napoleon', 37 | 'sphinx.ext.viewcode', 38 | 'sphinx_copybutton', 39 | 'sphinx_rtd_theme', 40 | ] 41 | 42 | autoapi_dirs = ['../../hugdatafast'] 43 | autoapi_add_toctree_entry = False 44 | autoapi_options = ['show-module-summary'] 45 | autoapi_python_class_content = 'both' 46 | 47 | # Add any paths that contain templates here, relative to this directory. 48 | templates_path = ['_templates'] 49 | 50 | # List of patterns, relative to source directory, that match files and 51 | # directories to ignore when looking for source files. 52 | # This pattern also affects html_static_path and html_extra_path. 
53 | exclude_patterns = [] 54 | 55 | # https://github.com/readthedocs/readthedocs.org/issues/2569 56 | master_doc = 'index' 57 | # -- Options for HTML output ------------------------------------------------- 58 | 59 | # The theme to use for HTML and HTML Help pages. See the documentation for 60 | # a list of builtin themes. 61 | # 62 | html_theme = 'sphinx_rtd_theme' 63 | 64 | # Add any paths that contain custom static files (such as style sheets) here, 65 | # relative to this directory. They are copied after the builtin static files, 66 | # so a file named "default.css" will overwrite the builtin "default.css". 67 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /docs/source/start.rst: -------------------------------------------------------------------------------- 1 | ================== 2 | Get Started 3 | ================== 4 | 5 | ----------------- 6 | Base use case 7 | ----------------- 8 | 9 | :: 10 | 11 | >>> from datasets import load_dataset 12 | >>> from hugdatafast import * 13 | 14 | .. note:: 15 | This will also implicitly do ``from fastai.text.all import *`` 16 | 17 | Can you turn your data pipeline into only 3 lines? 18 | 19 | :: 20 | 21 | >>> datasets = load_dataset('glue', 'cola') 22 | -> {'train': datasets.Dataset, 'validation': datasets.Dataset, 'test': datasets.Dataset} 23 | >>> tokenized_datasets = datasets.map(simple_tokenize_func({'sentence':'text_idxs'}, hf_tokenizer)) 24 | >>> dls = HF_Datasets(tokenized_datasets, cols=['text_idxs', 'label'], hf_toker=hf_tokenizer).dataloaders(bs=64) 25 | 26 | Now you can enjoy 27 | 28 | 1. :func:`show_batch` of fastai \n 29 | Inspect your processed data and quickly check if there is anything wrong with your data processing.
30 | 31 | :: 32 | 33 | >>> dls.show_batch(max_n=2) 34 | text_idxs label 35 | -------------------------------------------------------------------------------------------------------------------------------------- 36 | 0 everybody who has ever , worked in any office which contained any type ##writer which had ever been used to type any 1 37 | letters which had to be signed by any administrator who ever worked in any department like mine will know what i mean . 38 | -------------------------------------------------------------------------------------------------------------------------------------- 39 | 1 playing with matches is ; lots of fun , but doing , so and empty ##ing gasoline from one can to another at the same 1 40 | time is a sport best reserved for arson ##s . [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] 41 | 42 | 2. Train a model on the data using fastai, and show the predictions 43 | 44 | :: 45 | 46 | >>> learn = Learner(dls, your_model, loss_func=CrossEntropyLossFlat()) 47 | >>> learn.fit(3) 48 | >>> learn.show_results() 49 | text_idxs label label_ 50 | ----------------------------------------------------------------------------------------------------- 51 | 0 [CLS] scientists at the south hanoi institute of technology have succeeded in raising 1 1 52 | one dog with five legs , another with a cow ' s liver , and a third with no head . [SEP] 53 | ----------------------------------------------------------------------------------------------------- 54 | 1 [CLS] as a teacher , you have to deal simultaneously with the administration ' s pressure 0 1 55 | on you to succeed , and the children ' s to be a nice guy . [SEP] [PAD] [PAD] 56 | 57 | 3. Use it as normal dataloaders if you don't use fastai. 58 | 59 | :: 60 | 61 | >>> train_dataloader, val_dataloader, test_dataloader = dls[0], dls[1], dls[2] 62 | >>> for b in train_dataloader: break 63 | 64 | ------------------ 65 | Other use cases 66 | ------------------ 67 | 68 | 1. 
Use your own dataset? 69 | 70 | * `datasets.Dataset s from local structured files (csv, json, ...) `_ 71 | 72 | * `datasets.Dataset s from custom loading script `_ 73 | 74 | 2. Need to combine examples to generate a new example? (e.g. a traditional language model) 75 | 76 | :: 77 | 78 | >>> lm_datasets = LMTransform(datasets, max_len=20, text_col='text_idxs').map() 79 | >>> hf_tokenizer.decode(lm_datasets['validation'][-1]['x_text']) 80 | . john talked to bill about himself 81 | >>> hf_tokenizer.decode(lm_datasets['validation'][-1]['y_text']) 82 | john talked to bill about himself. 83 | 84 | If you want to implement your own logic to combine examples, try to extend :class:`CombineTransform`. 85 | 86 | ---------------------------- 87 | ``hugdatafast`` in practice 88 | ---------------------------- 89 | 90 | Here you can see how ``hugdatafast`` is used in real situations. You are also welcome to share how you use 91 | ``hugdatafast`` in your project; contact me via GitHub or Twitter to put your project link here. 92 | 93 | * `electra_pytorch `_ : Pretrain ELECTRA and fine-tune it on the GLUE benchmark -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity.
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /hugdatafast/transform.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pyarrow as pa 3 | import datasets 4 | from fastai.text.all import * 5 | 6 | @patch 7 | def cache_directory(self: datasets.arrow_dataset.Dataset): 8 | return os.path.abspath(os.path.dirname(self.cache_files[0]['filename'])) 9 | 10 | @patch 11 | def my_map(self: datasets.arrow_dataset.Dataset, *args, **kwargs): 12 | """ 13 | The same as :meth:`datasets.Dataset.map`, but it can add the cache directory and ``.arrow`` to ``cache_file_name`` automatically for us.
14 | 15 | Example: 16 | >>> dataset.my_map(a_func, cache_file_name='processed') 17 | # cache file path becomes "/processed.arrow" 18 | """ 19 | cache_file_name = kwargs.pop('cache_file_name', None) 20 | if cache_file_name is not None: 21 | if not cache_file_name.endswith('.arrow'): cache_file_name += '.arrow' 22 | if '/' not in cache_file_name: cache_file_name = os.path.join(self.cache_directory(), cache_file_name) 23 | return self.map(*args, cache_file_name=cache_file_name, **kwargs) 24 | 25 | @patch 26 | def my_map(self: datasets.dataset_dict.DatasetDict, *args, **kwargs): 27 | """ 28 | The same as :meth:`datasets.DatasetDict.map`, but it can infer cache file names for us. 29 | 30 | Example: 31 | >>> datasets.my_map(a_func, cache_file_names='processed_{split}') 32 | # cache file paths : "/processed_train.arrow", "/processed_validation.arrow", "/processed_test.arrow" 33 | """ 34 | # cache file names 35 | cache_file_names = kwargs.pop('cache_file_names', None) 36 | self._check_values_type() 37 | if cache_file_names is None: cache_file_names = {k: None for k in self} 38 | if isinstance(cache_file_names, str): cache_file_names = {k: cache_file_names.format(split=k) for k in self} 39 | # split specific kwargs 40 | fn_kwargs = kwargs.pop('fn_kwargs', None) 41 | if fn_kwargs is None: fn_kwargs = {} 42 | _fn_kwargs = {split_name:{} for split_name in self.keys()} 43 | for k,v in fn_kwargs.items(): 44 | if k in _fn_kwargs and isinstance(v, dict): # kwargs for a specific split 45 | _fn_kwargs[k] = v 46 | else: # generic kwargs for all splits 47 | for split in _fn_kwargs: _fn_kwargs[split][k] = v 48 | 49 | # pass 50 | return datasets.dataset_dict.DatasetDict({k: dataset.my_map(*args, 51 | cache_file_name=cache_file_names[k], 52 | fn_kwargs=_fn_kwargs[k], 53 | **kwargs) for k, dataset in self.items()}) 54 | 55 | class SimpleTokenize(): 56 | def __init__(self, cols, hf_toker): 57 | if isinstance(cols, list): cols = {c:c for c in cols} 58 | elif isinstance(cols, str): cols = 
{cols:cols} 59 | assert isinstance(cols, dict) 60 | self.cols = cols 61 | self.hf_toker = hf_toker 62 | def __call__(self, example): 63 | for in_col, out_col in self.cols.items(): 64 | example[out_col] = self.hf_toker.convert_tokens_to_ids(self.hf_toker.tokenize(example[in_col])) 65 | return example 66 | 67 | class CombineTransform(): 68 | """ 69 | Base class for transforms that combine multiple original samples into a new sample. 70 | """ 71 | def __init__(self, hf_dset, in_cols, out_cols, drop_last=False): 72 | """ 73 | Args: 74 | hf_dset (:class:`Dataset` or :class:`DatasetDict`): The Hugging Face dataset(s) to do the transformation 75 | in_cols (`List[str]`): names of input columns that are used to produce samples 76 | out_cols (`List[str]`): names of output columns to put combined samples. 77 | drop_last (`Optional[bool]`, default: `False`): whether to drop the last accumulated sample. 78 | """ 79 | # Always handle the case of multiple datasets for coding convenience 80 | if isinstance(hf_dset, datasets.arrow_dataset.Dataset): self.dsets = {'Single': hf_dset}; self.single=True 81 | else: self.dsets = hf_dset; self.single=False 82 | 83 | # check column names 84 | self.in_cols, self.out_cols = in_cols, out_cols 85 | for col in out_cols: assert col not in self.in_cols, f"New column name can't be the same as any original column name. '{col}'" 86 | 87 | # batched map needs the dataset in Python format 88 | for dset in self.dsets.values(): dset.set_format(type=None, columns=in_cols) 89 | 90 | # dealing with last sample 91 | self.last_idx = len(hf_dset) - 1 92 | self.drop_last = drop_last 93 | 94 | def __call__(self, b, indices): 95 | # On the first batch, `datasets.Dataset.map` first tests with several samples, which affects our internal state, so we need to reinitialize.
96 | if 0 in indices: 97 | self.reset_states() 98 | 99 | self.new_b = { c:[] for c in self.out_cols } 100 | values = [ b[c] for c in self.in_cols ] 101 | for z in zip(*values): 102 | self.accumulate(*z) 103 | 104 | # If last batch, decide whether to commit the last incomplete example 105 | if not self.drop_last and self.last_idx in indices: 106 | try: self.commit_example(self.create_example()) 107 | except: pass # assume the failure means nothing can be created 108 | 109 | return self.new_b 110 | 111 | def commit_example(self, example): 112 | if example is None: return 113 | for col,val in example.items(): 114 | self.new_b[col].append(val) 115 | 116 | def reset_states(self): 117 | """ 118 | Child class should implement this method.\n 119 | Reset all containers and flags to their initial values. 120 | """ 121 | raise NotImplementedError 122 | 123 | def accumulate(self, *args): 124 | """ 125 | Child class should implement this method.\n 126 | Given an example, call ``self.commit_example(self.create_example())`` when a new combined sample is ready. 127 | Args: 128 | args : values of :data:`in_cols` (passed to :func:`__init__`) of an example 129 | """ 130 | raise NotImplementedError 131 | 132 | def create_example(self): 133 | """ 134 | Child class should implement this method.\n 135 | Use internal states stored in the child class instance to create a combined example (dict).\n 136 | When nothing can be created, return ``None`` or raise an exception to signal it. 137 | """ 138 | raise NotImplementedError 139 | 140 | def map(self, batch_size=1000, cache_file_name=None, **kwargs): 141 | """ 142 | Args: 143 | batch_size(int): See :class:`datasets.Dataset.map`, shouldn't be None here 144 | cache_file_name: The same as the one of :func:`my_map` 145 | kwargs: passed to :class:`datasets.Dataset.map` 146 | """ 147 | 148 | # check 149 | assert 'remove_columns' not in kwargs, "Aggregation type transform will only leave output columns for output dataset."
150 | 151 | # infer cache_file_name s 152 | if isinstance(cache_file_name, dict): cache_names = cache_file_name 153 | else: cache_names = { k:cache_file_name for k in self.dsets.keys() } 154 | for k, dset in self.dsets.items(): 155 | if cache_names[k] is None: continue 156 | if not cache_names[k].endswith('.arrow'): cache_names[k] += '.arrow' 157 | if '{split}' in cache_names[k]: cache_names[k] = cache_names[k].format(split=k) 158 | if '/' not in cache_names[k]: cache_names[k] = os.path.join(dset.cache_directory(), cache_names[k]) 159 | 160 | # map every dataset 161 | mapped_dsets = {} 162 | for k, dset in self.dsets.items(): 163 | self.last_idx = len(dset) - 1 164 | mapped_dset = dset.map(function=self, 165 | batched=True, batch_size=batch_size, 166 | with_indices=True, 167 | num_proc=1, 168 | cache_file_name=cache_names[k], 169 | remove_columns=self.in_cols, # Because output columns have fewer rows (combined) than the original columns 170 | **kwargs) 171 | mapped_dset.set_format(None, columns=self.out_cols) 172 | mapped_dsets[k] = mapped_dset 173 | 174 | if self.single: return mapped_dsets['Single'] 175 | else: return datasets.DatasetDict(mapped_dsets) 176 | 177 | @delegates(CombineTransform, but=["inp_cols", "out_cols", "init_attrs"]) 178 | class LMTransform(CombineTransform): 179 | """ 180 | Transform any dataset that has tokenized text into a dataset for (autoregressive) language modeling. 181 | !! Caution: This spans the context window across examples, so make sure the texts in the dataset's examples are consecutive or related, 182 | or that you know what you are doing.
183 |     """
184 |     def __init__(self, tokenized_hf_dset, max_len, text_col, x_text_col='x_text', y_text_col='y_text', **kwargs):
185 |         """
186 |         Args:
187 |             tokenized_hf_dset (:class:`Dataset` or :class:`DatasetDict`): tokenized Hugging Face dataset(s) to do LM transform
188 |             max_len (int): the length of a sentence
189 |             text_col (str): the name of the column of tokenized_hf_dset that contains tokenized text (ids)
190 |             x_text_col (str): the name of the output column for the input token ids
191 |             y_text_col (str): the name of the output column for the (shifted) target token ids
192 |             kwargs: passed to :class:`CombineTransform`
193 | 
194 |         Example:
195 |             >>> lm_dataset = LMTransform(tokenized_cola['validation'], max_len=20, text_col='text_idxs').map()
196 |             >>> lm_dataset[0]
197 |             {'x_text': [ 1996, 11279, 8469, 1996, 9478, 3154, 1997, 1996, 5749, 1012,
198 |               1996, 15871, 2081, 1996, 8164, 7683, 2058, 1996, 4139, 3240],
199 |              'y_text': [11279, 8469, 1996, 9478, 3154, 1997, 1996, 5749, 1012, 1996,
200 |               15871, 2081, 1996, 8164, 7683, 2058, 1996, 4139, 3240, 1012]}
201 |         """
202 |         if isinstance(text_col, str): text_col = {text_col:['x_text','y_text']}
203 |         assert isinstance(text_col, dict)
204 |         self.text_col, (self.x_text_col, self.y_text_col) = next(iter(text_col.items()))
205 |         self._max_len = max_len + 1
206 |         self.reset_states()
207 |         super().__init__(tokenized_hf_dset, in_cols=[self.text_col], out_cols=[x_text_col, y_text_col], **kwargs)
208 | 
209 |     def reset_states(self):
210 |         self.new_text = []
211 |         self.residual_len = self._max_len
212 | 
213 |     def create_example(self):
214 |         # after reading all the data, the accumulated new_text might be shorter than two tokens
215 |         if len(self.new_text) >= 2:
216 |             example = {self.x_text_col:self.new_text[:-1], self.y_text_col:self.new_text[1:]}
217 |         else:
218 |             example = None # mark "don't commit this"
219 |         # reset accumulators
220 |         self.reset_states()
221 | 
222 |         return example
223 | 
224 |     def accumulate(self, text): # *inp_cols
225 |         usable_len = len(text)
226 |         cursor = 0
227 |         while usable_len != 0:
228 |             use_len = min(usable_len, self.residual_len)
229 |             self.new_text += text[cursor:cursor+use_len]
230 |             self.residual_len -= use_len
231 |             usable_len -= use_len
232 |             cursor += use_len
233 |             if self.residual_len == 0:
234 |                 self.commit_example(self.create_example())
235 | 
236 | @delegates(CombineTransform, but=["inp_cols", "out_cols", "init_attrs"])
237 | class ELECTRADataTransform(CombineTransform):
238 |     "Process any text corpus for ELECTRA's use"
239 |     def __init__(self, hf_dset, is_docs, text_col, max_length, hf_toker, delimiter='\n', **kwargs):
240 |         """
241 |         Args:
242 |             hf_dset (:class:`Dataset` or :class:`DatasetDict`): **untokenized** Hugging Face dataset(s) to do the transform
243 |             is_docs (bool): whether each sample of this dataset is a doc
244 |             text_col (str): the name of the column of the dataset that contains text
245 |             max_length (int): max length of each sentence
246 |             hf_toker (:class:`transformers.PreTrainedTokenizer`): Hugging Face tokenizer
247 |             delimiter (str): the delimiter used to segment sentences in the input text
248 |             kwargs: passed to :class:`CombineTransform`
249 |         """
250 |         self.is_docs = is_docs
251 |         self.in_col = text_col
252 |         self._max_length = max_length
253 |         self.cls_idx, self.sep_idx = hf_toker.cls_token_id, hf_toker.sep_token_id
254 |         self.hf_toker = hf_toker
255 |         self.delimiter = delimiter
256 |         self.reset_states()
257 |         super().__init__(hf_dset, in_cols=[self.in_col], out_cols=['input_ids','sentA_lenth'], **kwargs)
258 | 
259 |     """
260 |     These three main functions adapt the official source code that creates the pretraining dataset to CombineTransform.
261 |     """
262 |     def reset_states(self):
263 |         self._current_sentences = []
264 |         self._current_length = 0
265 |         self._target_length = self._max_length
266 | 
267 |     def accumulate(self, text):
268 |         sentences = text.split(self.delimiter)
269 |         for sentence in sentences:
270 |             if not sentence: continue # skip empty
271 |             tokids = self.hf_toker.convert_tokens_to_ids(self.hf_toker.tokenize(sentence))
272 |             self.add_line(tokids)
273 |         # end of doc
274 |         if self.is_docs and self._current_length > 0:
275 |             self.commit_example(self.create_example())
276 | 
277 |     def create_example(self):
278 |         input_ids, sentA_lenth = self._create_example() # this resets _current_sentences and _current_length at the end
279 |         return {'input_ids': input_ids, 'sentA_lenth':sentA_lenth}
280 |     # ...................................................
281 | 
282 |     def add_line(self, tokids):
283 |         """Adds a line of text to the current example being built."""
284 |         self._current_sentences.append(tokids)
285 |         self._current_length += len(tokids)
286 |         if self._current_length >= self._target_length:
287 |             self.commit_example(self.create_example())
288 | 
289 |     def _create_example(self):
290 |         """Creates a pre-training example from the current list of sentences."""
291 |         # small chance to only have one segment as in classification tasks
292 |         if random.random() < 0.1:
293 |             first_segment_target_length = 100000
294 |         else:
295 |             # -3 due to not yet having [CLS]/[SEP] tokens in the input text
296 |             first_segment_target_length = (self._target_length - 3) // 2
297 | 
298 |         first_segment = []
299 |         second_segment = []
300 |         for sentence in self._current_sentences:
301 |             # the sentence goes to the first segment if (1) the first segment is
302 |             # empty, (2) the sentence doesn't put the first segment over length or
303 |             # (3) 50% of the time when it does put the first segment over length
304 |             if (len(first_segment) == 0 or
305 |                 len(first_segment) + len(sentence) < first_segment_target_length or
306 | 
(len(second_segment) == 0 and
307 |                  len(first_segment) < first_segment_target_length and
308 |                  random.random() < 0.5)):
309 |                 first_segment += sentence
310 |             else:
311 |                 second_segment += sentence
312 | 
313 |         # trim to max_length while accounting for not-yet-added [CLS]/[SEP] tokens
314 |         first_segment = first_segment[:self._max_length - 2]
315 |         second_segment = second_segment[:max(0, self._max_length -
316 |                                               len(first_segment) - 3)]
317 | 
318 |         # prepare to start building the next example
319 |         self._current_sentences = []
320 |         self._current_length = 0
321 |         ## small chance for a random-length instead of max_length example
322 |         if random.random() < 0.05:
323 |             self._target_length = random.randint(5, self._max_length)
324 |         else:
325 |             self._target_length = self._max_length
326 | 
327 |         return self._make_example(first_segment, second_segment)
328 | 
329 |     def _make_example(self, first_segment, second_segment):
330 |         """Converts two "segments" of text into a pre-training example."""
331 |         input_ids = [self.cls_idx] + first_segment + [self.sep_idx]
332 |         sentA_lenth = len(input_ids)
333 |         if second_segment:
334 |             input_ids += second_segment + [self.sep_idx]
335 |         return input_ids, sentA_lenth
336 | 
337 |     def __getstate__(self):
338 |         "specify attributes you don't want pickled here; remember to copy so as not to modify the original instance"
339 |         state = self.__dict__.copy()
340 |         state['hf_toker'] = None
341 |         return state
--------------------------------------------------------------------------------
/hugdatafast/fastai.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from pathlib import Path
3 | import json
4 | from tqdm import tqdm
5 | from torch.nn.utils.rnn import pad_sequence
6 | import datasets
7 | from fastai.text.all import *
8 | 
9 | 
10 | @delegates()
11 | class MySortedDL(TfmdDL):
12 |     "A :class:`DataLoader` that does smart batching and dynamic padding.
Different from :class:`SortedDL`, it automatically pads every attribute of the samples, is able to filter samples, and can be cached so it sorts/filters only the first time."
13 | 
14 |     def __init__(self, dataset, srtkey_fc=None, filter_fc=False, pad_idx=None, cache_file=None, **kwargs):
15 |         """
16 |         Args:
17 |             dataset (HF_Dataset): Actually any object that implements ``__len__`` and ``__getitem__`` and returns a tuple as a sample.
18 |             srtkey_fc (``*args->int``, optional): Gets the key for descending sort from a sample.\n
19 |                 - If ``None``, sort by the length of the first element of a sample.
20 |                 - If ``False``, do not sort.
21 |             filter_fc (``*args->bool``, optional): Return ``True`` to keep the sample.
22 |             pad_idx (``int``, optional): pad each attribute of the samples to its max length within the batch.\n
23 |                 - If ``List[int]``, specify pad_idx for each attribute of a sample. e.g. if a sample is a tuple (masked_inputs, labels), `pad_idx=[0, -100]` pads masked_inputs with 0 and labels with -100.
24 |                 - If ``False``, do no padding.
25 |                 - If ``None``, try ``dataset.pad_idx``; do no padding if there is no such attribute.
26 |             cache_file (``str``, optional): Path of a json file to cache the info for sorting and filtering.
27 |             kwargs: keyword arguments for `TfmdDL` or `DataLoader`
28 | 
29 |         Example:
30 |             >>> samples = [ (torch.tensor([1]), torch.tensor([7,8]), torch.tensor(1)),
31 |             ...             (torch.tensor([2,3]), torch.tensor([9,10,11]), torch.tensor(2)),
32 |             ...             (torch.tensor([4,5,6]), torch.tensor([11,12,13,14]), torch.tensor(3)), ]
33 |             ... dl = MySortedDL(samples,
34 |             ...                 srtkey_fc=lambda *args: len(args[0]),
35 |             ...                 filter_fc=lambda x1,x2,y: y<3,
36 |             ...                 pad_idx=-1,
37 |             ...                 cache_file='/tmp/cache.json', # calls after this will load the cache
38 |             ...                 bs=999, # other parameters go to `TfmdDL` and `DataLoader`
39 |             ...                 )
40 |             ... 
dl.one_batch()
41 |             (tensor([[ 2,  3],
42 |                      [ 1, -1]]),
43 |              tensor([[ 9, 10, 11],
44 |                      [ 7,  8, -1]]),
45 |              tensor([2, 1]))
46 |         """
47 |         # Defaults
48 |         if srtkey_fc is None: srtkey_fc = lambda *x: len(x[0])
49 |         if pad_idx is None: pad_idx = getattr(dataset, 'pad_idx', False)
50 |         if isinstance(pad_idx, int): pad_idxs = [pad_idx] * len(dataset[0])
51 |         else: pad_idxs = pad_idx if isinstance(pad_idx, (list, tuple)) else []
52 |         cache_file = Path(cache_file) if cache_file else None
53 |         idmap = list(range(len(dataset)))
54 | 
55 |         # Save attributes
56 |         super().__init__(dataset, **kwargs)
57 |         store_attr('pad_idx,pad_idxs,srtkey_fc,filter_fc,cache_file,idmap', self)
58 | 
59 |         # Prepare records for sorting / filtered samples
60 |         if srtkey_fc or filter_fc:
61 |             if cache_file and cache_file.exists():
62 |                 # load cache and check
63 |                 with cache_file.open(mode='r') as f: cache = json.load(f)
64 |                 idmap, srtkeys = cache['idmap'], cache['srtkeys']
65 |                 if srtkey_fc:
66 |                     assert srtkeys, "srtkey_fc is passed, but it seems you didn't sort samples when creating the cache."
67 |                     self.srtkeys = srtkeys
68 |                 if filter_fc:
69 |                     assert idmap, "filter_fc is passed, but it seems you didn't filter samples when creating the cache."
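The per-attribute dynamic padding that `MySortedDL.create_batch` performs can be sketched in pure Python (a simplified, torch-free stand-in for `torch.nn.utils.rnn.pad_sequence`; `pad_batch` is an invented helper name for illustration):

```python
# Simplified sketch of dynamic padding: each attribute (column) of the
# samples is padded to the longest sequence in the batch.
def pad_batch(samples, pad_idxs):
    cols = list(zip(*samples))  # group the i-th attribute of every sample
    padded = []
    for i, col in enumerate(cols):
        longest = max(len(seq) for seq in col)
        padded.append([list(seq) + [pad_idxs[i]] * (longest - len(seq)) for seq in col])
    return tuple(padded)

batch = pad_batch([([1], [7, 8]), ([2, 3], [9, 10, 11])], pad_idxs=[-1, -1])
print(batch)  # ([[1, -1], [2, 3]], [[7, 8, -1], [9, 10, 11]])
```

The real implementation additionally stacks (rather than pads) attributes that are scalars, and pads each attribute with its own `pad_idx`.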
70 |                 self.idmap = idmap
71 |             else:
72 |                 # overwrite idmap if filter, get sorting keys if sort
73 |                 idmap = []; srtkeys = []
74 |                 for i in tqdm(range_of(dataset), leave=False):
75 |                     sample = self.do_item(i)
76 |                     if filter_fc and not filter_fc(*sample): continue
77 |                     if filter_fc: idmap.append(i)
78 |                     if srtkey_fc: srtkeys.append(srtkey_fc(*sample))
79 |                 if filter_fc: self.idmap = idmap
80 |                 if srtkey_fc: self.srtkeys = srtkeys
81 |                 # save to cache
82 |                 if cache_file:
83 |                     try:
84 |                         with cache_file.open(mode='w+') as f: json.dump({'idmap':idmap,'srtkeys':srtkeys}, f)
85 |                     except: os.remove(str(cache_file))
86 |             # an info for sorting
87 |             if srtkey_fc: self.idx_max = np.argmax(self.srtkeys)
88 |             # update number of samples
89 |             if filter_fc: self.n = len(self.idmap)
90 | 
91 |     def create_item(self, i): return self.dataset[self.idmap[i]]
92 | 
93 |     def create_batch(self, samples):
94 |         if self.pad_idx is False: return super().create_batch(samples)
95 |         return tuple( pad_sequence(attr, batch_first=True, padding_value=self.pad_idxs[i]) if attr[0].shape and isinstance(self.pad_idxs[i], int) else torch.stack(attr) for i, attr in enumerate(zip(*samples)))
96 | 
97 |     def get_idxs(self):
98 |         idxs = super().get_idxs()
99 |         if self.shuffle: return idxs
100 |         if self.srtkey_fc: return sorted(idxs, key=lambda i: self.srtkeys[i], reverse=True)
101 |         return idxs
102 | 
103 |     def shuffle_fn(self,idxs):
104 |         if not self.srtkey_fc: return super().shuffle_fn(idxs)
105 |         idxs = np.random.permutation(self.n)
106 |         idx_max = np.where(idxs==self.idx_max)[0][0]
107 |         idxs[0],idxs[idx_max] = idxs[idx_max],idxs[0]
108 |         sz = self.bs*50
109 |         chunks = [idxs[i:i+sz] for i in range(0, len(idxs), sz)]
110 |         chunks = [sorted(s, key=lambda i: self.srtkeys[i], reverse=True) for s in chunks]
111 |         sort_idx = np.concatenate(chunks)
112 | 
113 |         sz = self.bs
114 |         batches = [sort_idx[i:i+sz] for i in range(0, len(sort_idx), sz)]
115 |         sort_idx = np.concatenate(np.random.permutation(batches[1:-1]))
if len(batches) > 2 else np.array([], dtype=int)
116 |         sort_idx = np.concatenate((batches[0], sort_idx) if len(batches)==1 else (batches[0], sort_idx, batches[-1]))
117 |         return iter(sort_idx)
118 | 
119 |     @delegates(TfmdDL.new)
120 |     def new(self, dataset=None, **kwargs):
121 |         if 'get_idxs' in kwargs: # when Learner.get_preds is called, a dataloader that has `get_idxs` will be cloned, so we need to prevent sorting again
122 |             kwargs['cache_file'] = self.cache_file
123 |         # We don't use filter_fc here because we shouldn't skip validating certain samples in the dev/test set.
124 |         return super().new(dataset=dataset, pad_idx=self.pad_idx, srtkey_fc=self.srtkey_fc, filter_fc=False, **kwargs)
125 | 
126 | # =========================
127 | #  Titled primitives
128 | # =========================
129 | 
130 | class _Int(int, ShowPrint):
131 |     def __new__(cls, *args, **kwargs):
132 |         item = super().__new__(cls, *args)
133 |         for n,v in kwargs.items(): setattr(item, n, v)
134 |         return item
135 | 
136 | class _Float(float, ShowPrint):
137 |     def __new__(cls, *args, **kwargs):
138 |         item = super().__new__(cls, *args)
139 |         for n,v in kwargs.items(): setattr(item, n, v)
140 |         return item
141 | 
142 | class _Str(str, ShowPrint):
143 |     def __new__(cls, *args, **kwargs):
144 |         item = super().__new__(cls, *args)
145 |         for n,v in kwargs.items(): setattr(item, n, v)
146 |         return item
147 | 
148 | class _Tuple(fastuple, ShowPrint):
149 |     def __new__(cls, *args, **kwargs):
150 |         item = super().__new__(cls, *args)
151 |         for n,v in kwargs.items(): setattr(item, n, v)
152 |         return item
153 | 
154 | class _L(L, ShowPrint):
155 |     def __new__(cls, *args, **kwargs):
156 |         item = super().__new__(cls, *args)
157 |         for n,v in kwargs.items(): setattr(item, n, v)
158 |         return item
159 | 
160 | # only change "label" to "title"
161 | def _show_title(o, ax=None, ctx=None, title=None, color='black', **kwargs):
162 |     "Set title of `ax` to `o`, or print `o` if `ax` is `None`"
163 |     ax = ifnone(ax,ctx)
164 |     if ax is None: print(o)
165 |     elif
hasattr(ax, 'set_title'):
166 |         t = ax.title.get_text()
167 |         if len(t) > 0: o = t+'\n'+str(o)
168 |         ax.set_title(o, color=color)
169 |     elif isinstance(ax, pd.Series):
170 |         while title in ax: title += '_'
171 |         ax = pd.concat([ax, pd.Series({title: o})])  # Series.append was removed in pandas 2.0
172 |     return ax
173 | 
174 | class _ShowTitle:
175 |     def show(self, ctx=None, **kwargs):
176 |         kwargs['title'] = kwargs.pop('title', getattr(self, 'title', self.default_title))
177 |         return _show_title(str(self), ctx=ctx, **kwargs)
178 | 
179 | # Python prioritizes the earlier base classes when resolving methods (MRO)
180 | 
181 | class _TitledInt(_ShowTitle, _Int): default_title = 'int'
182 | 
183 | class _TitledFloat(_ShowTitle, _Float): default_title = 'float'
184 | 
185 | # I created this, but it just prints bool like int; haven't found a way to solve it
186 | class _TitledBool(_ShowTitle, _Int): # python says bool can't be a base class
187 |     default_title = 'bool'
188 | 
189 | class _TitledStr(_ShowTitle, _Str):
190 |     default_title = 'text'
191 |     def truncate(self, n):
192 |         "Truncate self to `n`"
193 |         words = self.split(' ')[:n]
194 |         return _TitledStr(' '.join(words), title=getattr(self, 'title', 'text'))
195 | 
196 | class _TitledTuple(_ShowTitle, _Tuple): default_title = 'list'
197 | 
198 | class _Category(_ShowTitle, _Str): default_title = 'label'
199 | 
200 | class _MultiCategory(_ShowTitle, _L):
201 |     default_title = 'labels'
202 |     def show(self, ctx=None, sep=';', color='black', **kwargs):
203 |         kwargs['title'] = kwargs.pop('title', getattr(self, 'title', self.default_title))
204 |         return _show_title(sep.join(self.map(str)), ctx=ctx, color=color, **kwargs)
205 | 
206 | """ Caution !!
207 | These two functions are imperfect.
208 | But they cope with the multiple-input-columns problem (n_inp > 1), which causes no df printing but just sequential printing.
209 | They will be a problem when you are doing a non-text problem with n_inp > 1 (multiple input columns),
210 | which shouldn't be the case for huggingface/datasets users.
211 | I hope fastai comes up with a good solution to the show_batch multiple-inputs problem, for text and non-text alike.
212 | """
213 | @typedispatch
214 | def show_batch(x:tuple, y, samples, ctxs=None, max_n=9, **kwargs):
215 |     if ctxs is None: ctxs = get_empty_df(min(len(samples), max_n))
216 |     ctxs = show_batch[object](x, y, samples, max_n=max_n, ctxs=ctxs, **kwargs)
217 |     display_df(pd.DataFrame(ctxs))
218 |     return ctxs
219 | 
220 | @typedispatch
221 | def show_results(x: tuple, y, samples, outs, ctxs=None, max_n=10, trunc_at=150, **kwargs):
222 |     if ctxs is None: ctxs = get_empty_df(min(len(samples), max_n))
223 |     ctxs = show_results[object](x, y, samples, outs, ctxs=ctxs, max_n=max_n, **kwargs)
224 |     display_df(pd.DataFrame(ctxs))
225 |     return ctxs
226 | 
227 | class HF_Dataset():
228 |     """A wrapper for :class:`datasets.Dataset`. It behaves like the original :class:`datasets.Dataset`,
229 |     but also functions as a :class:`fastai.data.core.Datasets` that provides samples and decodes."""
230 | 
231 |     def __init__(self, hf_dset, cols=None, hf_toker=None, neat_show=False, n_inp=1):
232 |         """
233 |         Args:
234 |             hf_dset (:class:`datasets.Dataset`): Preprocessed Hugging Face dataset to be wrapped.
235 |             cols (dict, optional): columns of :class:`datasets.Dataset` to be used to construct samples, and (optionally) the semantic tensor type for each of those columns to decode.\n
236 |                 - cols(``Dict[Fastai Semantic Tensor]``): encode/decode the column (key) with the semantic tensor type (value). If {value} is ``noop``, the semantic tensor of the column defaults to `TensorTuple`.
237 |                 - cols(``list[str]``): specify only the columns and take the default semantic tensor type for each.\n
238 |                     - if length is 1, regard the 1st element as `TensorText`
239 |                     - if length is 2, regard the 1st element as `TensorText`, the 2nd element as `TensorCategory`
240 |                     - otherwise, regard all elements as `TensorTuple`
241 |                 - cols(None): pass :data:`hf_dset.column_names` (list[str]) as cols.
242 |             hf_toker (:class:`transformers.PreTrainedTokenizer`, optional): Hugging Face tokenizer, used in decoding and to provide ``pad_idx`` for dynamic padding
243 |             neat_show (bool, optional): Show the original sentence instead of tokens joined by spaces.
244 |             n_inp (int, optional): take the first ``n_inp`` columns of ``cols`` as x, and the rest as y.
245 | 
246 |         Example:
247 |             >>> tokenized_cola_train_set[0]
248 |             {'sentence': "Our friends won't buy this analysis, let alone the next one we propose.",
249 |              'label': 1,
250 |              'idx': 0,
251 |              'text_idxs': [ 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012]}
252 |             >>> hf_dset = HF_Dataset(tokenized_cola_train_set, cols=['text_idxs', 'label'], hf_toker=tokenizer_electra_small_fast)
253 |             >>> len(hf_dset), hf_dset[0]
254 |             8551, (TensorText([ 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012]), TensorCategory(1))
255 |             >>> hf_dset.decode(hf_dset[0])
256 |             ("our friends won ' t buy this analysis , let alone the next one we propose .", '1')
257 |             # The wrapped dataset "is" also the original huggingface dataset
258 |             >>> hf_dset.column_names == tokenized_cola_train_set.column_names
259 |             True
260 |             # Manually specify `cols` with a dict; here it is equivalent to the above and, additionally, decodes samples neatly.
261 |             >>> neat_hf_dset = HF_Dataset(tokenized_cola_train_set, {'text_idxs':TensorText, 'label':TensorCategory}, hf_toker=tokenizer_electra_small_fast, neat_show=True)
262 |             >>> neat_hf_dset.decode(neat_hf_dset[0])
263 |             ("our friends won't buy this analysis, let alone the next one we propose.", '1')
264 |             # Note: the original set will be set to PyTorch format, with the columns specified in `cols`
265 |             >>> tokenized_cola_train_set[0]
266 |             {'label': tensor(1),
267 |              'text_idxs': tensor([ 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012])}
268 |         """
269 | 
270 |         # some default settings for the tensor types used in decoding
271 |         if cols is None: cols = hf_dset.column_names
272 |         if isinstance(cols, list):
273 |             if n_inp==1:
274 |                 if len(cols)==1: cols = {cols[0]: TensorText}
275 |                 elif len(cols)==2: cols = {cols[0]: TensorText, cols[1]: TensorCategory}
276 |             else: cols = { c: noop for c in cols }
277 |         assert isinstance(cols, dict)
278 | 
279 |         # make the dataset output pytorch tensors
280 |         hf_dset.set_format( type='torch', columns=list(cols.keys()) )
281 | 
282 |         # store attributes
283 |         if hf_toker is not None: self.pad_idx = hf_toker.pad_token_id # hf_toker is optional
284 |         self.hf_dset = hf_dset
285 |         store_attr("cols,n_inp,hf_toker,neat_show", self)
286 | 
287 |     def __getitem__(self, idx):
288 |         sample = self.hf_dset[idx]
289 |         return tuple( tensor_cls(sample[col]) for col, tensor_cls in self.cols.items() )
290 | 
291 |     def __len__(self): return len(self.hf_dset)
292 | 
293 |     @property
294 |     def col_names(self): return list(self.cols.keys())
295 | 
296 |     def decode(self, o, full=True): # `full` is to mimic `Dataset.decode`
297 |         if len(self.col_names) != len(o): return tuple( self._decode(o_) for o_ in o )
298 |         return tuple( self._decode(o_, self.col_names[i]) for i, o_ in enumerate(o) )
299 | 
300 |     def _decode_title(self, d, title_cls, title):
301 |         if title: return title_cls(d, title=title)
302 |         else: return title_cls(d)
303 | 
304 |     @typedispatch
305 |     def _decode(self,
t:torch.Tensor, title): 306 | if t.shape: title_cls = _TitledTuple 307 | elif isinstance(t.item(),bool): title_cls = _TitledBool # bool is also int, so check whether is bool first 308 | elif isinstance(t.item(),float): title_cls = _TitledFloat 309 | elif isinstance(t.item(),int): title_cls = _TitledInt 310 | return self._decode_title(t.tolist(), title_cls , title) 311 | 312 | @typedispatch 313 | def _decode(self, t:TensorText, title): 314 | assert self.hf_toker, "You should give a huggingface tokenizer if you want to show batch." 315 | if self.neat_show: text = self.hf_toker.decode([idx for idx in t if idx != self.hf_toker.pad_token_id]) 316 | else: text = ' '.join(self.hf_toker.convert_ids_to_tokens(t)) 317 | return self._decode_title(text, _TitledStr, title) 318 | 319 | @typedispatch 320 | def _decode(self, t:LMTensorText, title): return self._decode[TensorText](self, t, title) 321 | 322 | @typedispatch 323 | def _decode(self, t:TensorCategory, title): return self._decode_title(t.item(), _Category, title) 324 | 325 | @typedispatch 326 | def _decode(self, t:TensorMultiCategory, title): return self._decode_title(t.tolist(), _MultiCategory, title) 327 | 328 | def __getattr__(self, name): 329 | "If not defined, let the datasets.Dataset in it act for us." 
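The fall-through lookup in `HF_Dataset.__getattr__` below is the classic wrapper-delegation pattern: `__getattr__` is only invoked when normal attribute lookup fails, so unknown attributes are forwarded to the wrapped object. A minimal standalone sketch (`Wrapper` is an invented name for illustration):

```python
# Minimal sketch of the delegation pattern used by HF_Dataset.__getattr__:
# attributes not found on the wrapper are looked up on the wrapped object.
class Wrapper:
    def __init__(self, inner):
        self.inner = inner
    def __getattr__(self, name):
        # called only when normal lookup fails, so no recursion for 'inner'
        if hasattr(self.inner, name):
            return getattr(self.inner, name)
        raise AttributeError(f"Neither 'Wrapper' nor the wrapped object has '{name}'")

w = Wrapper([1, 2, 2, 3])
print(w.count(2))  # 2  -- list.count reached through the wrapper
```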
330 |         if name in HF_Dataset.__dict__: return HF_Dataset.__dict__[name]
331 |         elif name in self.__dict__: return self.__dict__[name]
332 |         elif hasattr(self.hf_dset, name): return getattr(self.hf_dset, name)
333 |         raise AttributeError(f"Neither 'HF_Dataset' object nor 'datasets.Dataset' object has a '{name}' attribute")
334 | 
335 | class HF_Datasets(FilteredBase):
336 |     """Functions as a :class:`fastai.data.core.Datasets` to create :class:`fastai.data.core.Dataloaders` from a group of :class:`datasets.Dataset`s"""
337 | 
338 |     _dl_type,_dbunch_type = MySortedDL,DataLoaders
339 | 
340 |     @delegates(HF_Dataset.__init__)
341 |     def __init__(self, hf_dsets: dict, test_with_y=False, **kwargs):
342 |         """
343 |         Args:
344 |             hf_dsets (`Dict[datasets.Dataset]`): Preprocessed Hugging Face Datasets; {key} is the split name, {value} is a :class:`datasets.Dataset`; the order will become the order in :class:`fastai.data.core.Dataloaders`.
345 |             test_with_y (bool, optional): Whether the test set comes with y (answers) rather than fake y (e.g. all -1 labels).
346 |                 If ``False``, tell only the test set to construct samples from the first ``n_inp`` columns (do not output fake y),
347 |                 and all datasets passed in ``hf_dsets`` whose names start with "test" will be regarded as test sets.
348 |             kwargs: Passed to :class:`HF_Dataset`. Be sure to pass the arguments that :class:`HF_Dataset` needs!
349 |         """
350 |         cols, n_inp = kwargs.pop('cols', None), kwargs.get('n_inp', 1)
351 |         self.hf_dsets = {}
352 |         for split, dset in hf_dsets.items():
353 |             if cols is None: cols = dset.column_names
354 |             if split.startswith('test') and not test_with_y:
355 |                 if isinstance(cols, list): _cols = cols[:n_inp]
356 |                 else: _cols = { k:v for _, (k,v) in zip(range(n_inp),cols.items()) }
357 |             else: _cols = cols
358 |             self.hf_dsets[split] = HF_Dataset(dset, cols=_cols, **kwargs)
359 | 
360 |     def subset(self, i): return list(self.hf_dsets.values())[i]
361 |     def __getitem__(self, split): return self.hf_dsets[split]
362 |     @property
363 |     def n_subsets(self): return len(self.hf_dsets)
364 |     @property
365 |     def cache_dir(self): return Path(next(iter(self.hf_dsets.values())).cache_files[0]['filename']).parent
366 | 
367 |     @delegates(FilteredBase.dataloaders)
368 |     def dataloaders(self, device='cpu', cache_dir=None, cache_name=None, dl_kwargs=None, **kwargs):
369 |         """
370 |         Args:
371 |             device (str): the device output batches will be on. Because a batch is loaded as a test when creating :class:`fastai.data.core.Dataloaders`, it is suggested to keep the default value cpu (to avoid always leaving a batch of tensors on cuda:0) and then call ``dls.to(device)`` when you want.
372 |             cache_dir (str, optional): directory to store the caches of :class:`MySortedDL`. If ``None``, use the cache directory of the first :class:`datasets.Dataset` in the ``hf_dsets`` passed to :meth:`HF_Datasets.__init__`.
373 |             cache_name (str, optional): format string that includes one param "{split}", which will be replaced with the name of the split, used as the cache file name under `cache_dir` for each split. If ``None``, tell :class:`MySortedDL` not to do caching.
374 |             dl_kwargs (list[dict], optional): the i-th item is additional kwargs to be passed to the initialization of the dataloader for the i-th split
375 |             kwargs: Passed to :func:`fastai.data.core.FilteredBase.dataloaders`
376 | 
377 |         Example:
378 |             >>> tokenized_cola
379 |             {'train': datasets.Dataset, 'validation': datasets.Dataset, 'test': datasets.Dataset}
380 |             >>> tokenized_cola['test'][0]
381 |             {'sentence': 'Bill whistled past the house.',
382 |              'label': -1, # Fake label. True labels are not open to the public.
383 |              'idx': 0,
384 |              'text_idxs': [3021, 26265, 2627, 1996, 2160, 1012]}
385 |             >>> dls = HF_Datasets(tokenized_cola,
386 |             ...                   cols=['text_idxs', 'label'], hf_toker=hf_tokenizer, # args for HF_Dataset
387 |             ...                   ).dataloaders(bs=32, cache_name="dl_cached_for_{split}") # args for MySortedDL
388 |             >>> dls.show_batch(max_n=2)
389 |                 text_idxs                                                                                                                          label
390 |             ---------------------------------------------------------------------------------------------------------------------------------------------
391 |             0   everybody who has ever, worked in any office which contained any typewriter which had ever been used to type any letters which had   1
392 |                 to be signed by any administrator who ever worked in any department like mine will know what i mean.
393 |             ---------------------------------------------------------------------------------------------------------------------------------------------
394 |             1   playing with matches is ; lots of fun, but doing, so and emptying gasoline from one can to another at the same time is a sport best  1
395 |                 reserved for arsons.
396 |             # the test set won't produce labels because of `test_with_y=False`
397 |             >>> dls[-1].show_batch(max_n=2)
398 |                text_idxs
399 |             ------------------------------------------------------------------------------------------
400 |             0  cultural commissioner megan smith said that the five ` ` soundscape'' pieces would ` `
401 |                give a festive air to park square, they're fun and interesting''.
402 | ------------------------------------------------------------------------------------------ 403 | 1 wendy is eager to sail around the world and bruce is eager to climb kilimanjaro, but 404 | neither of them can because money is too tight. 405 | """ 406 | if dl_kwargs is None: dl_kwargs = [{} for _ in range(len(self.hf_dsets))] 407 | elif isinstance(dl_kwargs, dict): 408 | dl_kwargs = [ dl_kwargs[split] if split in dl_kwargs else {} for split in self.hf_dsets] 409 | # infer cache file names for each dataloader if needed 410 | dl_type = kwargs.pop('dl_type', self._dl_type) 411 | if dl_type==MySortedDL and cache_name: 412 | assert "{split}" in cache_name, "`cache_name` should be a string with '{split}' in it to be formatted." 413 | cache_dir = Path(cache_dir) if cache_dir else self.cache_dir 414 | cache_dir.mkdir(exist_ok=True) 415 | if not cache_name.endswith('.json'): cache_name += '.json' 416 | for i, split in enumerate(self.hf_dsets): 417 | filled_cache_name = dl_kwargs[i].pop('cache_name', cache_name.format(split=split)) 418 | if 'cache_file' not in dl_kwargs[i]: 419 | dl_kwargs[i]['cache_file'] = cache_dir/filled_cache_name 420 | # change default to not drop last 421 | kwargs['drop_last'] = kwargs.pop('drop_last', False) 422 | # when corpus like glue/ax has only testset, set it to non-train setting 423 | if list(self.hf_dsets.keys())[0].startswith('test'): 424 | kwargs['shuffle_train'] = False 425 | kwargs['drop_last'] = False 426 | return super().dataloaders(dl_kwargs=dl_kwargs, device=device, **kwargs) -------------------------------------------------------------------------------- /tests/hf_nlp_extension_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "\n", 12 | "from fastai.callback.core import *\n", 13 | "\n", 14 | "from IPython.core.debugger import 
set_trace as bk\n", 15 | "import os\n", 16 | "from pathlib import Path\n", 17 | "from functools import partial\n", 18 | "import torch\n", 19 | "import datasets\n", 20 | "from transformers import ElectraTokenizerFast\n", 21 | "hf_tokenizer = ElectraTokenizerFast.from_pretrained(\"google/electra-small-generator\")\n", 22 | "from hugdatafast import *" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "# 1.Basics" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "## Simple tokenization & infer cache name\n", 37 | "`cols`(`Dict[str]`): tokenize the every column named key into column named its value \n", 38 | "`cols`(`List[str]`): specify the name of columns to be tokenized, replace the original columns' data with tokenized one\n", 39 | "\n", 40 | "Here, we tokenized \"sentence\" into a new column named \"text_idxs\", the \"sentence\" column still exist." 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": { 47 | "tags": [] 48 | }, 49 | "outputs": [ 50 | { 51 | "output_type": "stream", 52 | "name": "stderr", 53 | "text": "Reusing dataset glue (/home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)\n" 54 | }, 55 | { 56 | "output_type": "display_data", 57 | "data": { 58 | "text/plain": "HBox(children=(FloatProgress(value=0.0, description='#1', max=4275.0, style=ProgressStyle(description_width='i…", 59 | "application/vnd.jupyter.widget-view+json": { 60 | "version_major": 2, 61 | "version_minor": 0, 62 | "model_id": "664fbd6cc00b44848cf79d99445e02d1" 63 | } 64 | }, 65 | "metadata": {} 66 | }, 67 | { 68 | "output_type": "display_data", 69 | "data": { 70 | "text/plain": "HBox(children=(FloatProgress(value=0.0, description='#0', max=4276.0, style=ProgressStyle(description_width='i…", 71 | "application/vnd.jupyter.widget-view+json": { 72 | "version_major": 2, 73 | "version_minor": 
0, 74 | "model_id": "c8bda3d5f98e4c7db9a2a9f024b2ab0f" 75 | } 76 | }, 77 | "metadata": {} 78 | }, 79 | { 80 | "output_type": "stream", 81 | "name": "stdout", 82 | "text": "\n\n" 83 | }, 84 | { 85 | "output_type": "display_data", 86 | "data": { 87 | "text/plain": "HBox(children=(FloatProgress(value=0.0, description='#0', max=522.0, style=ProgressStyle(description_width='in…", 88 | "application/vnd.jupyter.widget-view+json": { 89 | "version_major": 2, 90 | "version_minor": 0, 91 | "model_id": "cd3364dce60a473ea3191cafd6835085" 92 | } 93 | }, 94 | "metadata": {} 95 | }, 96 | { 97 | "output_type": "display_data", 98 | "data": { 99 | "text/plain": "HBox(children=(FloatProgress(value=0.0, description='#1', max=521.0, style=ProgressStyle(description_width='in…", 100 | "application/vnd.jupyter.widget-view+json": { 101 | "version_major": 2, 102 | "version_minor": 0, 103 | "model_id": "1dcdbce0511b4c2a80bdf46561c3edc8" 104 | } 105 | }, 106 | "metadata": {} 107 | }, 108 | { 109 | "output_type": "stream", 110 | "name": "stdout", 111 | "text": "\n\n" 112 | }, 113 | { 114 | "output_type": "display_data", 115 | "data": { 116 | "text/plain": "HBox(children=(FloatProgress(value=0.0, description='#0', max=532.0, style=ProgressStyle(description_width='in…", 117 | "application/vnd.jupyter.widget-view+json": { 118 | "version_major": 2, 119 | "version_minor": 0, 120 | "model_id": "d738bb0ebfb441f4ad8cf7a7418ff38a" 121 | } 122 | }, 123 | "metadata": {} 124 | }, 125 | { 126 | "output_type": "display_data", 127 | "data": { 128 | "text/plain": "HBox(children=(FloatProgress(value=0.0, description='#1', max=531.0, style=ProgressStyle(description_width='in…", 129 | "application/vnd.jupyter.widget-view+json": { 130 | "version_major": 2, 131 | "version_minor": 0, 132 | "model_id": "3c3213746dbd42b9bf7995b3d12a84f3" 133 | } 134 | }, 135 | "metadata": {} 136 | }, 137 | { 138 | "output_type": "stream", 139 | "name": "stdout", 140 | "text": "\n\n{'idx': 0, 'label': 1, 'sentence': \"Our friends won't 
buy this analysis, let alone the next one we propose.\", 'text_idxs': [2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012]}\n\n/home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/tokenized_train_00000_of_00002.arrow\n/home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/tokenized_validation_00000_of_00002.arrow\n/home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/tokenized_test_00000_of_00002.arrow\n" 141 | } 142 | ], 143 | "source": [ 144 | "\n", 145 | "cola = datasets.load_dataset('glue', 'cola')\n", 146 | "tokenized_cola = cola.my_map(SimpleTokenize({'sentence':'text_idxs'}, hf_tokenizer),\n", 147 | " cache_file_names='tokenized_{split}', num_proc=2)\n", 148 | "print(tokenized_cola['train'][0])\n", 149 | "print()\n", 150 | "for dset in tokenized_cola.values(): print(dset.cache_files[0]['filename'])" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "## Create fastai `Dataloaders` and `show_batch`\n", 158 | "\n", 159 | "`cols`: **specify the columns whose values form an output sample, in order**, and the semantic type of each column to encode/decode, with one of the following signatures (see doc).\n", 160 | "\n", 161 | "Here, `['text_idxs', 'label']` is equal to `{'text_idxs': TensorText, 'label': TensorCategory}`\n", 162 | "\n", 163 | "The progress bars come from sorting samples by length; see `MySortedDL`" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 3, 169 | "metadata": { 170 | "tags": [] 171 | }, 172 | "outputs": [ 173 | { 174 | "output_type": "stream", 175 | "name": "stderr", 176 | "text": "Set __getitem__(key) output type to torch for ['text_idxs', 'label'] columns (when key is int or slice) and don't output other (un-formatted) 
columns.\nSet __getitem__(key) output type to torch for ['text_idxs', 'label'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to torch for ['text_idxs'] columns (when key is int or slice) and don't output other (un-formatted) columns.\n 98%|█████████▊| 1037/1063 [00:02<00:00, 515.45it/s]" 177 | }, 178 | { 179 | "output_type": "display_data", 180 | "data": { 181 | "text/plain": "", 182 | "text/html": "\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
text_idxslabel
0everybody who has ever, worked in any office which contained any typewriter which had ever been used to type any letters which had to be signed by any administrator who ever worked in any department like mine will know what i mean.1
1ron wanted to wear a tuxedo to the party, but wear a tuxedo to the party caspar couldn't decide whether to.0
" 183 | }, 184 | "metadata": {} 185 | } 186 | ], 187 | "source": [ 188 | "cola_dsets = HF_Datasets(tokenized_cola, cols=['text_idxs', 'label'], hf_toker=hf_tokenizer, neat_show=True)\n", 189 | "cola_dls = cola_dsets.dataloaders(bs=32)\n", 190 | "cola_dls.show_batch(max_n=2) # show at most two rows" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "metadata": {}, 196 | "source": [ 197 | "You can either specify `neat_show=False` (which is default), to show real data which is tokenized and with pad " 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 4, 203 | "metadata": { 204 | "tags": [] 205 | }, 206 | "outputs": [ 207 | { 208 | "output_type": "stream", 209 | "name": "stderr", 210 | "text": "Set __getitem__(key) output type to torch for ['text_idxs', 'label'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to torch for ['text_idxs', 'label'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to torch for ['text_idxs'] columns (when key is int or slice) and don't output other (un-formatted) columns.\n 98%|█████████▊| 1040/1063 [00:02<00:00, 516.29it/s]" 211 | }, 212 | { 213 | "output_type": "display_data", 214 | "data": { 215 | "text/plain": "", 216 | "text/html": "\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
text_idxslabel
0everybody who has ever , worked in any office which contained any type ##writer which had ever been used to type any letters which had to be signed by any administrator who ever worked in any department like mine will know what i mean .1
1will put a picture of bill on your desk before tomorrow , this girl in the red coat will put a picture of bill on your desk before tomorrow . [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]0
" 217 | }, 218 | "metadata": {} 219 | } 220 | ], 221 | "source": [ 222 | "cola_dsets = HF_Datasets(tokenized_cola, cols={'text_idxs': TensorText, 'label': TensorCategory}, hf_toker=hf_tokenizer)\n", 223 | "cola_dls = cola_dsets.dataloaders(bs=32)\n", 224 | "cola_dls.show_batch(max_n=2)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": {}, 230 | "source": [ 231 | "`test_with_label` is `False` by default, so in the test set a sample is formed from only the first `n_inp` columns specified, i.e. the x part.\n", 232 | "\n", 233 | "This lets you apply the same setting to all splits when the test set comes with no y or a fake y." 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "## Multiple columns (> 2) in a sample\n", 241 | "Some points to notice:\n", 242 | "- the columns shown are titled and ordered according to the `cols` specified in `HF_Datasets`\n", 243 | "- sequences are auto-padded to the max length in the batch, for all columns\n", 244 | "- if a fastai semantic tensor type is not specified, it looks at the dtype and shape of the tensor and decides how to decode it automatically" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 5, 250 | "metadata": { 251 | "tags": [] 252 | }, 253 | "outputs": [ 254 | { 255 | "output_type": "stream", 256 | "name": "stderr", 257 | "text": "https://raw.githubusercontent.com/huggingface/datasets/master/datasets/super_glue/super_glue.py not found in cache or force_download set to True, downloading to /home/yisiang/.cache/huggingface/datasets/tmp4vauussu\n" 258 | }, 259 | { 260 | "output_type": "display_data", 261 | "data": { 262 | "text/plain": "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=10474.0, style=ProgressStyle(descriptio…", 263 | "application/vnd.jupyter.widget-view+json": { 264 | "version_major": 2, 265 | "version_minor": 0, 266 | "model_id": "16f232473c554b5da97721aa9feef900" 267 | } 268 | }, 269 | "metadata": {} 270 | }, 271 | { 272 | 
"output_type": "stream", 273 | "name": "stdout", 274 | "text": "\nstoring https://raw.githubusercontent.com/huggingface/datasets/master/datasets/super_glue/super_glue.py in cache at /home/yisiang/.cache/huggingface/datasets/17727f4c5312e09bd16ee8581466c4f74b1802efd416965b4cfd523c12fad94d.ab72d3ffcbe0d0e93a4595f2a810b3988c20d7836ae0bdb5ff4bdccf6bd92a36.py\ncreating metadata file for /home/yisiang/.cache/huggingface/datasets/17727f4c5312e09bd16ee8581466c4f74b1802efd416965b4cfd523c12fad94d.ab72d3ffcbe0d0e93a4595f2a810b3988c20d7836ae0bdb5ff4bdccf6bd92a36.py\nhttps://raw.githubusercontent.com/huggingface/datasets/master/datasets/super_glue/dataset_infos.json not found in cache or force_download set to True, downloading to /home/yisiang/.cache/huggingface/datasets/tmpfw9s15jy\n" 275 | }, 276 | { 277 | "output_type": "display_data", 278 | "data": { 279 | "text/plain": "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=9253.0, style=ProgressStyle(description…", 280 | "application/vnd.jupyter.widget-view+json": { 281 | "version_major": 2, 282 | "version_minor": 0, 283 | "model_id": "108f0522a50f4a5f9fbeb97db165fd46" 284 | } 285 | }, 286 | "metadata": {} 287 | }, 288 | { 289 | "output_type": "stream", 290 | "name": "stderr", 291 | "text": "storing https://raw.githubusercontent.com/huggingface/datasets/master/datasets/super_glue/dataset_infos.json in cache at /home/yisiang/.cache/huggingface/datasets/be2c3836d8078b3465c52eebc3e437eeb18adabce99af20c98422a53acc7d3d4.9fa45241690c27df567c8014a4bf461a4ba1e82bd4358961888c6bf59769c3b5\ncreating metadata file for /home/yisiang/.cache/huggingface/datasets/be2c3836d8078b3465c52eebc3e437eeb18adabce99af20c98422a53acc7d3d4.9fa45241690c27df567c8014a4bf461a4ba1e82bd4358961888c6bf59769c3b5\nChecking /home/yisiang/.cache/huggingface/datasets/17727f4c5312e09bd16ee8581466c4f74b1802efd416965b4cfd523c12fad94d.ab72d3ffcbe0d0e93a4595f2a810b3988c20d7836ae0bdb5ff4bdccf6bd92a36.py for additional imports.\nFound main folder for 
dataset https://raw.githubusercontent.com/huggingface/datasets/master/datasets/super_glue/super_glue.py at /home/yisiang/.cache/huggingface/modules/datasets_modules/datasets/super_glue\nFound specific version folder for dataset https://raw.githubusercontent.com/huggingface/datasets/master/datasets/super_glue/super_glue.py at /home/yisiang/.cache/huggingface/modules/datasets_modules/datasets/super_glue/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de\nFound script file from https://raw.githubusercontent.com/huggingface/datasets/master/datasets/super_glue/super_glue.py to /home/yisiang/.cache/huggingface/modules/datasets_modules/datasets/super_glue/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de/super_glue.py\nFound dataset infos file from https://raw.githubusercontent.com/huggingface/datasets/master/datasets/super_glue/dataset_infos.json to /home/yisiang/.cache/huggingface/modules/datasets_modules/datasets/super_glue/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de/dataset_infos.json\nFound metadata file for dataset https://raw.githubusercontent.com/huggingface/datasets/master/datasets/super_glue/super_glue.py at /home/yisiang/.cache/huggingface/modules/datasets_modules/datasets/super_glue/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de/super_glue.json\n\nLoading Dataset Infos from /home/yisiang/.cache/huggingface/modules/datasets_modules/datasets/super_glue/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de\nOverwrite dataset info from restored data version.\nLoading Dataset info from /home/yisiang/.cache/huggingface/datasets/super_glue/wsc.fixed/1.0.2/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de\nReusing dataset super_glue (/home/yisiang/.cache/huggingface/datasets/super_glue/wsc.fixed/1.0.2/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de)\nConstructing Dataset for split train, validation, test, from 
/home/yisiang/.cache/huggingface/datasets/super_glue/wsc.fixed/1.0.2/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de\n100%|██████████| 3/3 [00:00<00:00, 7.93it/s]\nTesting the mapped function outputs\nTesting finished, running the mapping function on the dataset\nLoading cached processed dataset at /home/yisiang/.cache/huggingface/datasets/super_glue/wsc.fixed/1.0.2/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de/cache-4502eb5c5e2717ab.arrow\nTesting the mapped function outputs\nTesting finished, running the mapping function on the dataset\nLoading cached processed dataset at /home/yisiang/.cache/huggingface/datasets/super_glue/wsc.fixed/1.0.2/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de/cache-813b86f371807b5c.arrow\nTesting the mapped function outputs\nTesting finished, running the mapping function on the dataset\nLoading cached processed dataset at /home/yisiang/.cache/huggingface/datasets/super_glue/wsc.fixed/1.0.2/41d9edb3935257e1da4b7ce54cd90df0e8bb255a15e46cfe5cbc7e1c04f177de/cache-d3606f75f39a34a0.arrow\nSet __getitem__(key) output type to torch for ['text', 'span1_index', 'span1_text', 'span2_index', 'span2_text', 'label'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to torch for ['text', 'span1_index', 'span1_text', 'span2_index', 'span2_text', 'label'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to torch for ['text'] columns (when key is int or slice) and don't output other (un-formatted) columns.\n{'idx': 0, 'label': 0, 'span1_index': 0, 'span1_text': 'Mark', 'span2_index': 13, 'span2_text': 'He', 'text': 'Mark told Pete many lies about himself, which Pete included in his book. 
He should have been more skeptical.'}\n" 292 | }, 293 | { 294 | "output_type": "display_data", 295 | "data": { 296 | "text/plain": "", 297 | "text/html": "\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
textspan1_indexspan1_textspan2_indexspan2_textlabel
0mark told pete many lies about himself , which pete included in his book . he should have been more skeptical . [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]0mark [PAD] [PAD]13he0
1the mothers of arthur and celeste have come to the town to fetch them . they are very happy to have them back , but they sc ##old them just the same because they ran away . [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]1mothers [PAD] [PAD]25them0
2mark was close to mr . singer ' s heels . he heard him calling for the captain , promising him , in the jar ##gon everyone talked that night , that not one thing should be damaged on the ship except only the ammunition , but the captain and all his crew had best stay in the cabin until the work was over4mr . singer8he0
" 298 | }, 299 | "metadata": {} 300 | } 301 | ], 302 | "source": [ 303 | "wsc = datasets.load_dataset('super_glue', 'wsc.fixed')\n", 304 | "print(wsc['train'][0])\n", 305 | "tokenized_wsc = wsc.my_map(simple_tokenize_func(['text', 'span1_text', 'span2_text'], hf_tokenizer))\n", 306 | "wsc_dsets = HF_Datasets(tokenized_wsc, cols={'text': TensorText, 'span1_index': noop, 'span1_text':TensorText, 'span2_index': noop, 'span2_text': TensorText, 'label': lambda t: t.bool()}, # convert label (int) to bool, just to test its ability to show tensor(bool)\n", 307 | "hf_toker=hf_tokenizer)\n", 308 | "dls = wsc_dsets.dataloaders(bs=3, srtkey_fc=False, shuffle_train=False) # don't sort samples, don't shuffle trainset\n", 309 | "#bk()\n", 310 | "dls.show_batch()" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": {}, 316 | "source": [ 317 | "# 2. Aggregate Dataset\n", 318 | "A sample in the transformed dataset is aggregated/accumulated from multiple original samples.\n", 319 | "\n", 320 | "- Besides the provided `LMTransform`, you can implement your own logic: create a class that inherits from `AggregateTransform` and implements the `accumulate` and `create_example` methods\n", 321 | "\n", 322 | "- Note that you should pass **tokenized** dataset(s)" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": {}, 328 | "source": [ 329 | "## Make dataset(s) for a (traditional) language model" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 6, 335 | "metadata": { 336 | "tags": [] 337 | }, 338 | "outputs": [ 339 | { 340 | "output_type": "stream", 341 | "name": "stderr", 342 | "text": "Set __getitem__(key) output type to python objects for ['text_idxs'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nTesting the mapped function outputs\nTesting finished, running the mapping function on the dataset\nCaching processed dataset at 
/home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-e6acb88170f61c6d.arrow\n" 343 | }, 344 | { 345 | "output_type": "display_data", 346 | "data": { 347 | "text/plain": "HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))", 348 | "application/vnd.jupyter.widget-view+json": { 349 | "version_major": 2, 350 | "version_minor": 0, 351 | "model_id": "c2a6dfaea1344f0281695366de90d62a" 352 | } 353 | }, 354 | "metadata": {} 355 | }, 356 | { 357 | "output_type": "stream", 358 | "name": "stdout", 359 | "text": "\nDone writing 481 examples in 157576 bytes /home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/tmppum5q2sm.\nSet __getitem__(key) output type to python objects for [] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to python objects for ['x_text', 'y_text'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nOriginal dataset:\nnum of samples: 1043\nsecond to last sentence: John arranged for himself to get the prize.\n last sentence: John talked to Bill about himself.\nLM dataset:\nnum of samples: 481\nlast text (x): . 
john talked to bill about himself\nlast text (y): john talked to bill about himself.\n" 360 | } 361 | ], 362 | "source": [ 363 | "cola_val = tokenized_cola['validation']\n", 364 | "#bk()\n", 365 | "lm_cola_val = LMTransform(cola_val, max_len=20, text_col='text_idxs').map()\n", 366 | "\n", 367 | "print('Original dataset:')\n", 368 | "print('num of samples:', len(cola['validation']))\n", 369 | "print('second to last sentence:', cola['validation'][-2]['sentence'])\n", 370 | "print(' last sentence:', cola['validation'][-1]['sentence'])\n", 371 | "print('LM dataset:')\n", 372 | "print('num of samples:', len(lm_cola_val))\n", 373 | "assert len(lm_cola_val) == 481\n", 374 | "print('last text (x):', hf_tokenizer.decode(lm_cola_val[-1]['x_text']))\n", 375 | "print('last text (y):', hf_tokenizer.decode(lm_cola_val[-1]['y_text']))" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 13, 381 | "metadata": { 382 | "tags": [] 383 | }, 384 | "outputs": [ 385 | { 386 | "output_type": "stream", 387 | "name": "stderr", 388 | "text": "Set __getitem__(key) output type to python objects for ['text_idxs'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to python objects for ['text_idxs'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to python objects for ['text_idxs'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nTesting the mapped function outputs\nTesting finished, running the mapping function on the dataset\nCaching processed dataset at /home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-2509c56d4d553502.arrow\n" 389 | }, 390 | { 391 | "output_type": "display_data", 392 | "data": { 393 | "text/plain": "HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))", 394 | 
"application/vnd.jupyter.widget-view+json": { 395 | "version_major": 2, 396 | "version_minor": 0, 397 | "model_id": "3c634eed05df4c81b7185ed05fc4b585" 398 | } 399 | }, 400 | "metadata": {} 401 | }, 402 | { 403 | "output_type": "stream", 404 | "name": "stdout", 405 | "text": "\nDone writing 1564 examples in 1263672 bytes /home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/tmp6hips0lj.\nSet __getitem__(key) output type to python objects for [] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to python objects for ['x_text', 'y_text'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nTesting the mapped function outputs\nTesting finished, running the mapping function on the dataset\nCaching processed dataset at /home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-b25d35199b6c7232.arrow\n" 406 | }, 407 | { 408 | "output_type": "display_data", 409 | "data": { 410 | "text/plain": "HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))", 411 | "application/vnd.jupyter.widget-view+json": { 412 | "version_major": 2, 413 | "version_minor": 0, 414 | "model_id": "83998357277546c59d3bc3204272927f" 415 | } 416 | }, 417 | "metadata": {} 418 | }, 419 | { 420 | "output_type": "stream", 421 | "name": "stderr", 422 | "text": "Done writing 198 examples in 159840 bytes /home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/tmplm69lznx.\nSet __getitem__(key) output type to python objects for [] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to python objects for ['x_text', 'y_text'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nTesting the mapped function 
outputs\nTesting finished, running the mapping function on the dataset\nCaching processed dataset at /home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-888ff239e1f71e93.arrow\n\n" 423 | }, 424 | { 425 | "output_type": "display_data", 426 | "data": { 427 | "text/plain": "HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))", 428 | "application/vnd.jupyter.widget-view+json": { 429 | "version_major": 2, 430 | "version_minor": 0, 431 | "model_id": "b5adc15037124d3e89a390bd83d812ff" 432 | } 433 | }, 434 | "metadata": {} 435 | }, 436 | { 437 | "output_type": "stream", 438 | "name": "stderr", 439 | "text": "Done writing 200 examples in 161056 bytes /home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/tmpigpvusur.\nSet __getitem__(key) output type to python objects for [] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to python objects for ['x_text', 'y_text'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to torch for ['x_text', 'y_text'] columns (when key is int or slice) and don't output other (un-formatted) columns.\n\n" 440 | }, 441 | { 442 | "output_type": "display_data", 443 | "data": { 444 | "text/plain": "", 445 | "text/html": "\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
x_texty_text
0the sailors rode the breeze clear of the rocks . the weights made the rope stretch over the pull ##ey . the mechanical doll wr ##ig ##gled itself loose . if you had eaten more , you would want less . as you eat the most , you want thesailors rode the breeze clear of the rocks . the weights made the rope stretch over the pull ##ey . the mechanical doll wr ##ig ##gled itself loose . if you had eaten more , you would want less . as you eat the most , you want the least
1. the more you would want , the less you would eat . i demand that the more john eat , the more he pays . mary listen ##s to the grateful dead , she gets depressed . the ang ##rier mary got , the more she looked at picturesthe more you would want , the less you would eat . i demand that the more john eat , the more he pays . mary listen ##s to the grateful dead , she gets depressed . the ang ##rier mary got , the more she looked at pictures .
" 446 | }, 447 | "metadata": {} 448 | } 449 | ], 450 | "source": [ 451 | "lm_cola = LMTransform(tokenized_cola, max_len=50, text_col='text_idxs').map()\n", 452 | "# test single dataset\n", 453 | "lm_ds = HF_Dataset(lm_cola['validation'], cols={'x_text':LMTensorText, 'y_text':TensorText},hf_toker=hf_tokenizer)\n", 454 | "lm_dl = MySortedDL(lm_ds, srtkey_fc=False)\n", 455 | "lm_dl.show_batch(max_n=2)" 456 | ] 457 | }, 458 | { 459 | "cell_type": "markdown", 460 | "metadata": {}, 461 | "source": [ 462 | "## Test ELECTRA data creating" 463 | ] 464 | }, 465 | { 466 | "cell_type": "code", 467 | "execution_count": 14, 468 | "metadata": { 469 | "tags": [] 470 | }, 471 | "outputs": [ 472 | { 473 | "output_type": "stream", 474 | "name": "stderr", 475 | "text": "Set __getitem__(key) output type to python objects for ['sentence'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nTesting the mapped function outputs\nTesting finished, running the mapping function on the dataset\nCaching processed dataset at /home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-7db22ab214e91040.arrow\n" 476 | }, 477 | { 478 | "output_type": "display_data", 479 | "data": { 480 | "text/plain": "HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))", 481 | "application/vnd.jupyter.widget-view+json": { 482 | "version_major": 2, 483 | "version_minor": 0, 484 | "model_id": "bbec5b14896f4efc93a66ef705397304" 485 | } 486 | }, 487 | "metadata": {} 488 | }, 489 | { 490 | "output_type": "stream", 491 | "name": "stderr", 492 | "text": "Done writing 78 examples in 78488 bytes /home/yisiang/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/tmpi_20vusq.\n\nSet __getitem__(key) output type to python objects for [] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to 
python objects for ['input_ids', 'sentA_lenth'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to torch for ['input_ids', 'sentA_lenth'] columns (when key is int or slice) and don't output other (un-formatted) columns.\n" 493 | }, 494 | { 495 | "output_type": "display_data", 496 | "data": { 497 | "text/plain": "", 498 | "text/html": "\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
input_idssentA_lenth
0[CLS] no writer , nor any playwright , meets in vienna . that you will marry any student is not certain . felicia kicked the ball off the bench . i sent the package halfway around the world . sam gave the ball out of the basket . sam offered the ball out of the basket . park square has a fest ##ive air . [SEP] the worker will have a job . no one can forgive that comment to you . we launched the rocket to the moon , but it blew up before it got there . sarah promised catherine her old car , but then gave it to her son instead . i lent the book part ##way to tony . the farmer loaded [SEP]66
1[CLS] i borrowed fred ' s diagram of a snake ' s eye because steve ' s had been stolen . jerry attempted to blow up the pentagon . so fast did he run that nobody could catch him . bill bought a red house , and max bought one too . who always drinks milk ? the book which inspired them was very long . [SEP] the book what inspired them was very long . i know the person whose mother died . the person whose mother ' s dog we were all fond of . i wonder whose mother died . i wonder whose mother ' s dog died . i wonder to whom they dedicated the building . give me the phone number of [SEP]67
" 499 | }, 500 | "metadata": {} 501 | } 502 | ], 503 | "source": [ 504 | "proc_dset = ELECTRADataTransform(cola['validation'], is_docs=False, text_col='sentence', max_length=128, hf_toker=hf_tokenizer).map()\n", 505 | "e_dsets = HF_Datasets({'train':proc_dset}, cols={'input_ids':TensorText,'sentA_lenth':noop}, hf_toker=hf_tokenizer)\n", 506 | "e_dls = e_dsets.dataloaders(srtkey_fc=False)\n", 507 | "e_dls.show_batch(max_n=2)" 508 | ] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": {}, 513 | "source": [ 514 | "# 3. Test filtering feature\n", 515 | "Note that the filter won't be applied to splits other than train, because the validation/test sets are for fair comparison, and you can't take out samples at will." 516 | ] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "execution_count": 9, 521 | "metadata": { 522 | "tags": [] 523 | }, 524 | "outputs": [ 525 | { 526 | "output_type": "stream", 527 | "name": "stdout", 528 | "text": "{'train': 26, 'validation': 2, 'test': 6}\n" 529 | } 530 | ], 531 | "source": [ 532 | "l = 23\n", 533 | "num = {}\n", 534 | "for split in tokenized_cola:\n", 535 | " num[split] = reduce(lambda sum, sample: sum+(1 if len(sample['text_idxs'])==l else 0), \n", 536 | " tokenized_cola[split], 0)\n", 537 | "print(num)" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": 10, 543 | "metadata": { 544 | "tags": [] 545 | }, 546 | "outputs": [ 547 | { 548 | "output_type": "stream", 549 | "name": "stderr", 550 | "text": "Set __getitem__(key) output type to torch for ['text_idxs', 'label'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to torch for ['text_idxs', 'label'] columns (when key is int or slice) and don't output other (un-formatted) columns.\nSet __getitem__(key) output type to torch for ['text_idxs'] columns (when key is int or slice) and don't output other (un-formatted) columns.\n 99%|█████████▉| 1054/1063 [00:02<00:00, 433.55it/s]Test 
passed\n" 551 | } 552 | ], 553 | "source": [ 554 | "ccola_dsets = HF_Datasets(tokenized_cola, cols=['text_idxs', 'label'], hf_toker=hf_tokenizer)\n", 555 | "ccola_dls = ccola_dsets.dataloaders(filter_fc=lambda text_idxs, label: len(text_idxs)!=l,)\n", 556 | "\n", 557 | "for i, split in enumerate(tokenized_cola):\n", 558 | " if split == 'train':\n", 559 | " assert ccola_dls[i].n == len(tokenized_cola[split])-num[split],f\"{split}: filtered: {ccola_dls[i].n}, unfiltered: {len(tokenized_cola[split])}, should be filtered: {num[split]}\"\n", 560 | " else:\n", 561 | " assert ccola_dls[i].n == len(tokenized_cola[split]), f\"{split}: accidentally filtered: {ccola_dls[i].n}, unfiltered: {len(tokenized_cola[split])}\"\n", 562 | "print(\"Test passed\")" 563 | ] 564 | }, 565 | { 566 | "cell_type": "markdown", 567 | "metadata": {}, 568 | "source": [ 569 | "# 4. Cache dataloader\n", 570 | "If sorting or filtering is applied, the dataloader needs to create some records internally; to do this only once, we can cache the records. \n", 571 | "\n", 572 | "If `cache_dir` is not specified, it will be the cache_dir of `dsets` passed to `HF_Datasets`." 
573 | ] 574 | }, 575 | { 576 | "cell_type": "code", 577 | "execution_count": 11, 578 | "metadata": { 579 | "tags": [] 580 | }, 581 | "outputs": [ 582 | { 583 | "output_type": "stream", 584 | "name": "stderr", 585 | "text": "99%|█████████▉| 1054/1063 [00:02<00:00, 420.38it/s]" 586 | } 587 | ], 588 | "source": [ 589 | "for f in ['/tmp/cached_train.json','/tmp/cached_val.json', '/tmp/cached_test.json']:\n", 590 | " if Path(f).exists(): os.remove(f)\n", 591 | "\n", 592 | "ccola_dls = ccola_dsets.dataloaders(cache_dir='/tmp', cache_name='cached_{split}.json')" 593 | ] 594 | }, 595 | { 596 | "cell_type": "markdown", 597 | "metadata": {}, 598 | "source": [ 599 | "This time we load the caches, so it should be fast and the progress bars shouldn't appear." 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": 12, 605 | "metadata": {}, 606 | "outputs": [], 607 | "source": [ 608 | "ccola_dls = ccola_dsets.dataloaders(cache_dir='/tmp', cache_name='cached_{split}.json')" 609 | ] 610 | } 611 | ], 612 | "metadata": { 613 | "language_info": { 614 | "codemirror_mode": { 615 | "name": "ipython", 616 | "version": 3 617 | }, 618 | "file_extension": ".py", 619 | "mimetype": "text/x-python", 620 | "name": "python", 621 | "nbconvert_exporter": "python", 622 | "pygments_lexer": "ipython3", 623 | "version": "3.7.7-final" 624 | }, 625 | "orig_nbformat": 2, 626 | "kernelspec": { 627 | "name": "python3", 628 | "display_name": "Python 3" 629 | } 630 | }, 631 | "nbformat": 4, 632 | "nbformat_minor": 2 633 | } --------------------------------------------------------------------------------