├── .gitignore
├── LICENCE
├── README.md
├── augment.png
├── examples
├── aeda_example.ipynb
├── eda_example.ipynb
├── fasttext_example.ipynb
├── mixup_example_using_IMDB_sentiment.ipynb
└── word2vec_example.ipynb
├── requirements.txt
├── setup.py
├── tests
├── test_translate.py
├── test_word2vec.py
└── test_wordnet.py
└── textaugment
├── __init__.py
├── aeda.py
├── constants.py
├── eda.py
├── mixup.py
├── translate.py
├── word2vec.py
└── wordnet.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Joseph Sefara
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # [TextAugment: Improving Short Text Classification through Global Augmentation Methods](https://arxiv.org/abs/1907.03752)
4 |
5 | [](https://github.com/dsfsi/textaugment/blob/master/LICENCE) [](https://github.com/dsfsi/textaugment/releases) [](https://pypi.python.org/pypi/textaugment) [](https://pypi.org/project/textaugment/) [](https://pypi.org/project/textaugment/) [](https://pypi.org/project/textaugment/) [](https://link.springer.com/chapter/10.1007%2F978-3-030-57321-8_21) [](https://arxiv.org/abs/1907.03752)
6 |
7 |
8 | ## You have just found TextAugment.
9 |
10 | TextAugment is a Python 3 library for augmenting text for natural language processing applications. TextAugment stands on the giant shoulders of [NLTK](https://www.nltk.org/), [Gensim v3.x](https://radimrehurek.com/gensim/), and [TextBlob](https://textblob.readthedocs.io/) and plays nicely with them.
11 |
12 | ## Acknowledgements
13 | Cite this [paper](https://link.springer.com/chapter/10.1007%2F978-3-030-57321-8_21) when using this library. [Arxiv Version](https://arxiv.org/abs/1907.03752)
14 |
15 | ```
16 | @inproceedings{marivate2020improving,
17 | title={Improving short text classification through global augmentation methods},
18 | author={Marivate, Vukosi and Sefara, Tshephisho},
19 | booktitle={International Cross-Domain Conference for Machine Learning and Knowledge Extraction},
20 | pages={385--399},
21 | year={2020},
22 | organization={Springer}
23 | }
24 | ```
25 |
26 | # Table of Contents
27 |
28 | - [Features](#Features)
29 | - [Citation Paper](#citation-paper)
30 | - [Requirements](#Requirements)
31 | - [Installation](#Installation)
32 | - [How to use](#How-to-use)
33 | - [Fasttext/Word2vec-based augmentation](#fasttextword2vec-based-augmentation)
34 | - [WordNet-based augmentation](#WordNet-based-augmentation)
35 | - [RTT-based augmentation](#RTT-based-augmentation)
36 | - [Easy data augmentation (EDA)](#eda-easy-data-augmentation-techniques-for-boosting-performance-on-text-classification-tasks)
37 | - [An easier data augmentation (AEDA)](#aeda-an-easier-data-augmentation-technique-for-text-classification)
38 | - [Mixup augmentation](#mixup-augmentation)
39 | - [Implementation](#Implementation)
40 | - [Acknowledgements](#Acknowledgements)
41 |
42 | ## Features
43 |
44 | - Generate synthetic data for improving model performance without manual effort
45 | - Simple, lightweight, easy-to-use library.
46 | - Plug and play with any machine learning framework (e.g. PyTorch, TensorFlow, Scikit-learn)
47 | - Supports textual data
48 |
49 | ## Citation Paper
50 |
51 | **[Improving short text classification through global augmentation methods](https://link.springer.com/chapter/10.1007%2F978-3-030-57321-8_21)**.
52 |
53 |
54 |
55 | 
56 |
57 | ### Requirements
58 |
59 | * Python 3
60 |
61 | The following software packages are dependencies and will be installed automatically.
62 |
63 | ```shell
64 | $ pip install numpy nltk gensim==3.8.3 textblob googletrans
65 |
66 | ```
67 | The following code downloads NLTK corpus for [wordnet](http://www.nltk.org/howto/wordnet.html).
68 | ```python
69 | nltk.download('wordnet')
70 | ```
71 | The following code downloads [NLTK tokenizer](https://www.nltk.org/_modules/nltk/tokenize/punkt.html). This tokenizer divides a text into a list of sentences by using an unsupervised algorithm to build a model for abbreviation words, collocations, and words that start sentences.
72 | ```python
73 | nltk.download('punkt')
74 | ```
75 | The following code downloads default [NLTK part-of-speech tagger](https://www.nltk.org/_modules/nltk/tag.html) model. A part-of-speech tagger processes a sequence of words, and attaches a part of speech tag to each word.
76 | ```python
77 | nltk.download('averaged_perceptron_tagger')
78 | ```
79 | Use gensim to load a pre-trained word2vec model, such as [Google News from Google Drive](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit).
80 | ```python
81 | import gensim
82 | model = gensim.models.KeyedVectors.load_word2vec_format('./GoogleNews-vectors-negative300.bin', binary=True)
83 | ```
84 | You can also use gensim to load Facebook's Fasttext [English](https://fasttext.cc/docs/en/english-vectors.html) and [Multilingual models](https://fasttext.cc/docs/en/crawl-vectors.html)
85 | ```python
86 | import gensim
87 | model = gensim.models.fasttext.load_facebook_model('./cc.en.300.bin.gz')
88 | ```
89 |
90 | Or train one from scratch using your own data or one of the following public datasets:
91 |
92 | - [Text8 Wiki](http://mattmahoney.net/dc/enwik9.zip)
93 |
94 | - [Dataset from "One Billion Word Language Modeling Benchmark"](http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz)
95 |
96 | ### Installation
97 |
98 | Install from pip [Recommended]
99 | ```sh
100 | $ pip install textaugment
101 | # or install the latest version directly from GitHub
102 | $ pip install git+https://github.com/dsfsi/textaugment.git
103 | ```
104 |
105 | Install from source
106 | ```sh
107 | $ git clone git@github.com:dsfsi/textaugment.git
108 | $ cd textaugment
109 | $ python setup.py install
110 | ```
111 |
112 | ### How to use
113 |
114 | There are four types of augmentations which can be used:
115 |
116 | - word2vec
117 |
118 | ```python
119 | from textaugment import Word2vec
120 | ```
121 | - fasttext
122 |
123 | ```python
124 | from textaugment import Fasttext
125 | ```
126 |
127 | - wordnet
128 | ```python
129 | from textaugment import Wordnet
130 | ```
131 | - translate (This will require internet access)
132 | ```python
133 | from textaugment import Translate
134 | ```
135 | #### Fasttext/Word2vec-based augmentation
136 |
137 | [See this notebook for an example](https://github.com/dsfsi/textaugment/blob/master/examples/word2vec_example.ipynb)
138 |
139 | **Basic example**
140 |
141 | ```python
142 | >>> from textaugment import Word2vec, Fasttext
143 | >>> t = Word2vec(model='path/to/gensim/model'or 'gensim model itself')
144 | >>> t.augment('The stories are good')
145 | The films are good
146 | >>> t = Fasttext(model='path/to/gensim/model'or 'gensim model itself')
147 | >>> t.augment('The stories are good')
148 | The films are good
149 | ```
150 | **Advanced example**
151 |
152 | ```python
153 | >>> runs = 1 # By default.
154 | >>> v = False # verbose mode to replace all the words. If enabled runs is not effective. Used in this paper (https://www.cs.cmu.edu/~diyiy/docs/emnlp_wang_2015.pdf)
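155 | >>> p = 0.5 # The probability of success of an individual trial. (0.1 < p < 1.0), default is 0.5. Used by the geometric distribution to select words from a sentence.
156 |
157 | >>> word = Word2vec(model='path/to/gensim/model'or'gensim model itself', runs=5, v=False, p=0.5)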
158 | >>> word.augment('The stories are good', top_n=10)
159 | The movies are excellent
160 | >>> fast = Fasttext(model='path/to/gensim/model'or'gensim model itself', runs=5, v=False, p=0.5)
161 | >>> fast.augment('The stories are good', top_n=10)
162 | The movies are excellent
163 | ```
164 | #### WordNet-based augmentation
165 | **Basic example**
166 | ```python
167 | >>> import nltk
168 | >>> nltk.download('punkt')
169 | >>> nltk.download('wordnet')
170 | >>> from textaugment import Wordnet
171 | >>> t = Wordnet()
172 | >>> t.augment('In the afternoon, John is going to town')
173 | In the afternoon, John is walking to town
174 | ```
175 | **Advanced example**
176 |
177 | ```python
178 | >>> v = True # enable verbs augmentation. By default is True.
179 | >>> n = False # enable nouns augmentation. By default is False.
180 | >>> runs = 1 # number of times to augment a sentence. By default is 1.
181 | >>> p = 0.5 # The probability of success of an individual trial. (0.1 < p < 1.0), default is 0.5.
182 |
183 | >>> t = Wordnet(v=False ,n=True, p=0.5)
184 | >>> t.augment('In the afternoon, John is going to town', top_n=10)
185 | In the afternoon, Joseph is going to town.
186 | ```
187 | #### RTT-based augmentation
188 | **Example**
189 | ```python
190 | >>> src = "en" # source language of the sentence
191 | >>> to = "fr" # target language
192 | >>> from textaugment import Translate
193 | >>> t = Translate(src="en", to="fr")
194 | >>> t.augment('In the afternoon, John is going to town')
195 | In the afternoon John goes to town
196 | ```
197 | # EDA: Easy data augmentation techniques for boosting performance on text classification tasks
198 | ## This is the implementation of EDA by Jason Wei and Kai Zou.
199 |
200 | https://www.aclweb.org/anthology/D19-1670.pdf
201 |
202 | [See this notebook for an example](https://github.com/dsfsi/textaugment/blob/master/examples/eda_example.ipynb)
203 |
204 | #### Synonym Replacement
205 | Randomly choose *n* words from the sentence that are not stop words. Replace each of these words with
206 | one of its synonyms chosen at random.
207 |
208 | **Basic example**
209 | ```python
210 | >>> from textaugment import EDA
211 | >>> t = EDA()
212 | >>> t.synonym_replacement("John is going to town", top_n=10)
213 | John is give out to town
214 | ```
215 |
216 | #### Random Deletion
217 | Randomly remove each word in the sentence with probability *p*.
218 |
219 | **Basic example**
220 | ```python
221 | >>> from textaugment import EDA
222 | >>> t = EDA()
223 | >>> t.random_deletion("John is going to town", p=0.2)
224 | is going to town
225 | ```
226 |
227 | #### Random Swap
228 | Randomly choose two words in the sentence and swap their positions. Do this n times.
229 |
230 | **Basic example**
231 | ```python
232 | >>> from textaugment import EDA
233 | >>> t = EDA()
234 | >>> t.random_swap("John is going to town")
235 | John town going to is
236 | ```
237 |
238 | #### Random Insertion
239 | Find a random synonym of a random word in the sentence that is not a stop word. Insert that synonym into a random position in the sentence. Do this *n* times.
240 |
241 | **Basic example**
242 | ```python
243 | >>> from textaugment import EDA
244 | >>> t = EDA()
245 | >>> t.random_insertion("John is going to town")
246 | John is going to make up town
247 | ```
248 |
249 | # AEDA: An easier data augmentation technique for text classification
250 |
251 | This is the implementation of AEDA by Karimi et al, a variant of EDA. It is based on the random insertion of punctuation marks.
252 |
253 | https://aclanthology.org/2021.findings-emnlp.234.pdf
254 |
255 | ## Implementation
256 | [See this notebook for an example](https://github.com/dsfsi/textaugment/blob/master/examples/aeda_example.ipynb)
257 |
258 | #### Random Insertion of Punctuation Marks
259 |
260 | **Basic example**
261 | ```python
262 | >>> from textaugment import AEDA
263 | >>> t = AEDA()
264 | >>> t.punct_insertion("John is going to town")
265 | ! John is going to town
266 | ```
267 |
268 | # Mixup augmentation
269 |
270 | This is the implementation of mixup augmentation by [Hongyi Zhang, Moustapha Cisse, Yann Dauphin, David Lopez-Paz](https://openreview.net/forum?id=r1Ddp1-Rb) adapted to NLP.
271 |
272 | Used in [Augmenting Data with Mixup for Sentence Classification: An Empirical Study](https://arxiv.org/abs/1905.08941).
273 |
274 | Mixup is a generic and straightforward data augmentation principle. In essence, mixup trains a neural network on convex combinations of pairs of examples and their labels. By doing so, mixup regularises the neural network to favour simple linear behaviour in-between training examples.
275 |
276 | ## Implementation
277 |
278 | [See this notebook for an example](https://github.com/dsfsi/textaugment/blob/master/examples/mixup_example_using_IMDB_sentiment.ipynb)
279 |
280 | ## Built with ❤ on
281 | * [Python](http://python.org/)
282 |
283 | ## Authors
284 | * [Joseph Sefara](https://za.linkedin.com/in/josephsefara) (http://www.speechtech.co.za)
285 | * [Vukosi Marivate](http://www.vima.co.za) (http://www.vima.co.za)
286 |
287 | ## Acknowledgements
288 | Cite this [paper](https://link.springer.com/chapter/10.1007%2F978-3-030-57321-8_21) when using this library. [Arxiv Version](https://arxiv.org/abs/1907.03752)
289 |
290 | ```
291 | @inproceedings{marivate2020improving,
292 | title={Improving short text classification through global augmentation methods},
293 | author={Marivate, Vukosi and Sefara, Tshephisho},
294 | booktitle={International Cross-Domain Conference for Machine Learning and Knowledge Extraction},
295 | pages={385--399},
296 | year={2020},
297 | organization={Springer}
298 | }
299 | ```
300 |
301 | ## Licence
302 | MIT licensed. See the bundled [LICENCE](https://github.com/dsfsi/textaugment/blob/master/LICENCE) file for more details.
303 |
--------------------------------------------------------------------------------
/augment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dsfsi/textaugment/02c63e07f0b4dcdf95d9700722509e1512963d6a/augment.png
--------------------------------------------------------------------------------
/examples/aeda_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# AEDA example"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "try:\n",
17 | " from textaugment import AEDA\n",
18 | "except ModuleNotFoundError:\n",
19 | " !pip install textaugment\n",
20 | " from textaugment import AEDA"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "t = AEDA(random_state=1)"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {},
35 | "source": [
36 | "## Punctuation Insertion\n",
37 | "1. Randomly select the amount of punctuation to be inserted, between 1 and 1/3 of the length of the sentence.\n",
38 | "2. Randomly select the punctuation to be inserted.\n",
39 | "3. Randomly select the position of the punctuation to be inserted.\n",
40 | "4. Insert the punctuation at the selected position."
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 3,
46 | "metadata": {},
47 | "outputs": [
48 | {
49 | "name": "stdout",
50 | "output_type": "stream",
51 | "text": [
52 | "! John is going to town\n"
53 | ]
54 | }
55 | ],
56 | "source": [
57 | "output = t.punct_insertion(\"John is going to town\")\n",
58 | "print(output)"
59 | ]
60 | },
61 | {
62 | "cell_type": "markdown",
63 | "metadata": {},
64 | "source": [
65 | "## Cite the paper\n",
66 | "```\n",
67 | "@article{marivate2019improving,\n",
68 | " title={Improving short text classification through global augmentation methods},\n",
69 | " author={Marivate, Vukosi and Sefara, Tshephisho},\n",
70 | " journal={arXiv preprint arXiv:1907.03752},\n",
71 | " year={2019}\n",
72 | "}\n```\n",
73 | "\n",
74 | "https://arxiv.org/abs/1907.03752"
75 | ]
76 | }
77 | ],
78 | "metadata": {
79 | "kernelspec": {
80 | "display_name": "Python 3",
81 | "language": "python",
82 | "name": "python3"
83 | },
84 | "language_info": {
85 | "codemirror_mode": {
86 | "name": "ipython",
87 | "version": 3
88 | },
89 | "file_extension": ".py",
90 | "mimetype": "text/x-python",
91 | "name": "python",
92 | "nbconvert_exporter": "python",
93 | "pygments_lexer": "ipython3",
94 | "version": "3.7.7"
95 | }
96 | },
97 | "nbformat": 4,
98 | "nbformat_minor": 4
99 | }
100 |
--------------------------------------------------------------------------------
/examples/eda_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# EDA example"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "try:\n",
17 | " from textaugment import EDA\n",
18 | "except ModuleNotFoundError:\n",
19 | " !pip install textaugment\n",
20 | " from textaugment import EDA"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "t = EDA(random_state=1)"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {},
35 | "source": [
36 | "## Synonym Replacement\n",
37 | "Randomly choose *n* words from the sentence that are not stop words. Replace each of these words with one of its synonyms chosen at random"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 3,
43 | "metadata": {},
44 | "outputs": [
45 | {
46 | "name": "stdout",
47 | "output_type": "stream",
48 | "text": [
49 | "John is choke to town\n"
50 | ]
51 | }
52 | ],
53 | "source": [
54 | "output = t.synonym_replacement(\"John is going to town\", top_n=10)\n",
55 | "print(output)"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "## Random Insertion\n",
63 | "Find a random synonym of a random word in the sentence that is not a stop word. Insert that synonym into a random position in the sentence. Do this *n* times."
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 4,
69 | "metadata": {},
70 | "outputs": [
71 | {
72 | "name": "stdout",
73 | "output_type": "stream",
74 | "text": [
75 | "John is going to lead town\n"
76 | ]
77 | }
78 | ],
79 | "source": [
80 | "output = t.random_insertion(\"John is going to town\")\n",
81 | "print(output)"
82 | ]
83 | },
84 | {
85 | "cell_type": "markdown",
86 | "metadata": {},
87 | "source": [
88 | "## Random Swap\n",
89 | "Randomly choose two words in the sentence and swap their positions. Do this *n* times."
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 5,
95 | "metadata": {},
96 | "outputs": [
97 | {
98 | "name": "stdout",
99 | "output_type": "stream",
100 | "text": [
101 | "John is to going town\n"
102 | ]
103 | }
104 | ],
105 | "source": [
106 | "output = t.random_swap(\"John is going to town\")\n",
107 | "print(output)"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "metadata": {},
113 | "source": [
114 | "## Random Deletion\n",
115 | "Randomly remove each word in the sentence with probability *p*."
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 6,
121 | "metadata": {},
122 | "outputs": [
123 | {
124 | "name": "stdout",
125 | "output_type": "stream",
126 | "text": [
127 | "John going to town\n"
128 | ]
129 | }
130 | ],
131 | "source": [
132 | "output = t.random_deletion(\"John is going to town\", p=0.2)\n",
133 | "print(output)"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "metadata": {},
139 | "source": [
140 | "## Cite the paper\n",
141 | "```\n",
142 | "@article{marivate2019improving,\n",
143 | " title={Improving short text classification through global augmentation methods},\n",
144 | " author={Marivate, Vukosi and Sefara, Tshephisho},\n",
145 | " journal={arXiv preprint arXiv:1907.03752},\n",
146 | " year={2019}\n",
147 | "}\n```\n",
148 | "\n",
149 | "https://arxiv.org/abs/1907.03752"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "metadata": {},
156 | "outputs": [],
157 | "source": []
158 | }
159 | ],
160 | "metadata": {
161 | "kernelspec": {
162 | "display_name": "Python 3",
163 | "language": "python",
164 | "name": "python3"
165 | },
166 | "language_info": {
167 | "codemirror_mode": {
168 | "name": "ipython",
169 | "version": 3
170 | },
171 | "file_extension": ".py",
172 | "mimetype": "text/x-python",
173 | "name": "python",
174 | "nbconvert_exporter": "python",
175 | "pygments_lexer": "ipython3",
176 | "version": "3.7.7"
177 | }
178 | },
179 | "nbformat": 4,
180 | "nbformat_minor": 4
181 | }
182 |
--------------------------------------------------------------------------------
/examples/fasttext_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Example for using Fasttext"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "# Import libraries\n",
17 | "try:\n",
18 | " import textaugment, gensim\n",
19 | "except ModuleNotFoundError:\n",
20 | " !pip -q install textaugment gensim\n",
21 | " import textaugment, gensim"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "# Load Fasttext Embeddings \n",
29 | "\n",
30 | "Fasttext has Pre-trained word vectors on English webcrawl and Wikipedia which you can find [here](https://fasttext.cc/docs/en/english-vectors.html) as well as Pre-trained models for 157 different languages which you can find [here](https://fasttext.cc/docs/en/crawl-vectors.html)"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 3,
36 | "metadata": {},
37 | "outputs": [
38 | {
39 | "name": "stdout",
40 | "output_type": "stream",
41 | "text": [
42 | "--2020-09-01 10:11:28-- https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz\n",
43 | "Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 104.22.75.142, 104.22.74.142, 172.67.9.4, ...\n",
44 | "Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|104.22.75.142|:443... connected.\n",
45 | "HTTP request sent, awaiting response... 200 OK\n",
46 | "Length: 4503593528 (4.2G) [application/octet-stream]\n",
47 | "Saving to: ‘cc.en.300.bin.gz’\n",
48 | "\n",
49 | "cc.en.300.bin.gz 100%[===================>] 4.19G 4.32MB/s in 9m 57s \n",
50 | "\n",
51 | "2020-09-01 10:21:26 (7.20 MB/s) - ‘cc.en.300.bin.gz’ saved [4503593528/4503593528]\n",
52 | "\n"
53 | ]
54 | }
55 | ],
56 | "source": [
57 | "# Download the FastText embeddings in the language of your choice\n",
58 | "!wget \"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz\""
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "# save path to your pre-trained model\n",
68 | "from gensim.test.utils import datapath\n",
69 | "pretrained_path = datapath('./cc.en.300.bin.gz')\n",
70 | "\n",
71 | "# load model\n",
72 | "model = gensim.models.fasttext.load_facebook_model(pretrained_path)"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "from textaugment import Word2vec\n",
82 | "t = Word2vec(model = model)\n",
83 | "output = t.augment('The stories are good', top_n=10)"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": null,
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "print(output)"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "metadata": {},
98 | "source": [
99 | "## Cite the paper\n",
100 | "```\n",
101 | "@article{marivate2019improving,\n",
102 | " title={Improving short text classification through global augmentation methods},\n",
103 | " author={Marivate, Vukosi and Sefara, Tshephisho},\n",
104 | " journal={arXiv preprint arXiv:1907.03752},\n",
105 | " year={2019}\n",
106 | "}\n```\n",
107 | "\n",
108 | "https://arxiv.org/abs/1907.03752\n"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": []
117 | }
118 | ],
119 | "metadata": {
120 | "kernelspec": {
121 | "display_name": "Python 3",
122 | "language": "python",
123 | "name": "python3"
124 | },
125 | "language_info": {
126 | "codemirror_mode": {
127 | "name": "ipython",
128 | "version": 3
129 | },
130 | "file_extension": ".py",
131 | "mimetype": "text/x-python",
132 | "name": "python",
133 | "nbconvert_exporter": "python",
134 | "pygments_lexer": "ipython3",
135 | "version": "3.7.7"
136 | }
137 | },
138 | "nbformat": 4,
139 | "nbformat_minor": 4
140 | }
141 |
--------------------------------------------------------------------------------
/examples/mixup_example_using_IMDB_sentiment.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "kMccmZPoWd_h"
8 | },
9 | "source": [
10 | "# Mixup augmentation for NLP\n",
11 | "\n",
12 | "Using IMDB sentiment classification dataset"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 1,
18 | "metadata": {
19 | "colab": {
20 | "base_uri": "https://localhost:8080/",
21 | "height": 527
22 | },
23 | "colab_type": "code",
24 | "id": "YhKEHbrxWd_n",
25 | "outputId": "368747f0-47d5-439f-f4b3-d4db6d6a2d18"
26 | },
27 | "outputs": [
28 | {
29 | "name": "stdout",
30 | "output_type": "stream",
31 | "text": [
32 | "Collecting textaugment\n",
33 | " Downloading https://files.pythonhosted.org/packages/d5/87/906c855827f99a65ab91b22afbfa91731bd4397b5e3ca344de571e5c7651/textaugment-1.3-py3-none-any.whl\n",
34 | "Requirement already satisfied: nltk in /usr/local/lib/python3.6/dist-packages (from textaugment) (3.2.5)\n",
35 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from textaugment) (1.18.4)\n",
36 | "Requirement already satisfied: textblob in /usr/local/lib/python3.6/dist-packages (from textaugment) (0.15.3)\n",
37 | "Requirement already satisfied: gensim in /usr/local/lib/python3.6/dist-packages (from textaugment) (3.6.0)\n",
38 | "Collecting googletrans\n",
39 | " Downloading https://files.pythonhosted.org/packages/fd/f0/a22d41d3846d1f46a4f20086141e0428ccc9c6d644aacbfd30990cf46886/googletrans-2.4.0.tar.gz\n",
40 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from nltk->textaugment) (1.12.0)\n",
41 | "Requirement already satisfied: scipy>=0.18.1 in /usr/local/lib/python3.6/dist-packages (from gensim->textaugment) (1.4.1)\n",
42 | "Requirement already satisfied: smart-open>=1.2.1 in /usr/local/lib/python3.6/dist-packages (from gensim->textaugment) (2.0.0)\n",
43 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from googletrans->textaugment) (2.23.0)\n",
44 | "Requirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from smart-open>=1.2.1->gensim->textaugment) (1.13.13)\n",
45 | "Requirement already satisfied: boto in /usr/local/lib/python3.6/dist-packages (from smart-open>=1.2.1->gensim->textaugment) (2.49.0)\n",
46 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->googletrans->textaugment) (2.9)\n",
47 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->googletrans->textaugment) (1.24.3)\n",
48 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->googletrans->textaugment) (3.0.4)\n",
49 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->googletrans->textaugment) (2020.4.5.1)\n",
50 | "Requirement already satisfied: botocore<1.17.0,>=1.16.13 in /usr/local/lib/python3.6/dist-packages (from boto3->smart-open>=1.2.1->gensim->textaugment) (1.16.13)\n",
51 | "Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->smart-open>=1.2.1->gensim->textaugment) (0.10.0)\n",
52 | "Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->smart-open>=1.2.1->gensim->textaugment) (0.3.3)\n",
53 | "Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.17.0,>=1.16.13->boto3->smart-open>=1.2.1->gensim->textaugment) (0.15.2)\n",
54 | "Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.6/dist-packages (from botocore<1.17.0,>=1.16.13->boto3->smart-open>=1.2.1->gensim->textaugment) (2.8.1)\n",
55 | "Building wheels for collected packages: googletrans\n",
56 | " Building wheel for googletrans (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
57 | " Created wheel for googletrans: filename=googletrans-2.4.0-cp36-none-any.whl size=15777 sha256=4de7ce4b52a5c57a680d9c96137d12291609a418bf5fdd1cf158003f747c7589\n",
58 | " Stored in directory: /root/.cache/pip/wheels/50/d6/e7/a8efd5f2427d5eb258070048718fa56ee5ac57fd6f53505f95\n",
59 | "Successfully built googletrans\n",
60 | "Installing collected packages: googletrans, textaugment\n",
61 | "Successfully installed googletrans-2.4.0 textaugment-1.3\n"
62 | ]
63 | }
64 | ],
65 | "source": [
66 | "# Import libraries\n",
67 | "try:\n",
68 | " import textaugment\n",
69 | "except ModuleNotFoundError:\n",
70 | " !pip install textaugment\n",
71 | " import textaugment\n",
72 | "\n",
73 | "import pandas as pd\n",
74 | "\n",
75 | "import tensorflow as tf\n",
76 | "from tensorflow.keras.preprocessing import sequence\n",
77 | "from tensorflow.keras.models import Sequential\n",
78 | "from tensorflow.keras.layers import Dense, Dropout, Activation\n",
79 | "from tensorflow.keras.layers import Embedding\n",
80 | "from tensorflow.keras.layers import Conv1D, GlobalMaxPooling1D\n",
81 | "from tensorflow.keras.datasets import imdb\n",
82 | "\n",
83 | "from textaugment import MIXUP\n",
84 | "%matplotlib inline"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 2,
90 | "metadata": {
91 | "colab": {
92 | "base_uri": "https://localhost:8080/",
93 | "height": 34
94 | },
95 | "colab_type": "code",
96 | "id": "JeMsxayIWd_r",
97 | "outputId": "814596bf-e5ca-47f1-c2ce-257e761e96c4"
98 | },
99 | "outputs": [
100 | {
101 | "data": {
102 | "text/plain": [
103 | "'2.2.0'"
104 | ]
105 | },
106 | "execution_count": 2,
107 | "metadata": {
108 | "tags": []
109 | },
110 | "output_type": "execute_result"
111 | }
112 | ],
113 | "source": [
114 | "tf.__version__"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 3,
120 | "metadata": {
121 | "colab": {
122 | "base_uri": "https://localhost:8080/",
123 | "height": 34
124 | },
125 | "colab_type": "code",
126 | "id": "_FbvA0uwRdEZ",
127 | "outputId": "8e912f45-8b7e-4ee7-a3ad-f342c3f090c7"
128 | },
129 | "outputs": [
130 | {
131 | "data": {
132 | "text/plain": [
133 | "'1.3'"
134 | ]
135 | },
136 | "execution_count": 3,
137 | "metadata": {
138 | "tags": []
139 | },
140 | "output_type": "execute_result"
141 | }
142 | ],
143 | "source": [
144 | "textaugment.__version__"
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "metadata": {
150 | "colab_type": "text",
151 | "id": "Oz8O8tISRdEg"
152 | },
153 | "source": [
154 | "## Initialize constant variables"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {
161 | "colab": {},
162 | "colab_type": "code",
163 | "id": "mg1AcYIWWd_w"
164 | },
165 | "outputs": [],
166 | "source": [
167 | "# set parameters:\n",
168 | "max_features = 5000\n",
169 | "maxlen = 400\n",
170 | "batch_size = 32\n",
171 | "embedding_dims = 50\n",
172 | "filters = 250\n",
173 | "kernel_size = 3\n",
174 | "hidden_dims = 250\n",
175 | "epochs = 10\n",
176 | "runs = 1"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": 5,
182 | "metadata": {
183 | "colab": {
184 | "base_uri": "https://localhost:8080/",
185 | "height": 153
186 | },
187 | "colab_type": "code",
188 | "id": "ZRuNNVstWd_0",
189 | "outputId": "bc4ce3b2-5a12-4600-d1a8-b466615018df"
190 | },
191 | "outputs": [
192 | {
193 | "name": "stdout",
194 | "output_type": "stream",
195 | "text": [
196 | "Loading data...\n",
197 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz\n",
198 | "17465344/17464789 [==============================] - 0s 0us/step\n",
199 | "25000 train sequences\n",
200 | "25000 test sequences\n",
201 | "Pad sequences (samples x time)\n",
202 | "x_train shape: (25000, 400)\n",
203 | "x_test shape: (25000, 400)\n"
204 | ]
205 | }
206 | ],
207 | "source": [
208 | "print('Loading data...')\n",
209 | "(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)\n",
210 | "print(len(x_train), 'train sequences')\n",
211 | "print(len(x_test), 'test sequences')\n",
212 | "\n",
213 | "print('Pad sequences (samples x time)')\n",
214 | "x_train = sequence.pad_sequences(x_train, maxlen=maxlen)\n",
215 | "x_test = sequence.pad_sequences(x_test, maxlen=maxlen)\n",
216 | "print('x_train shape:', x_train.shape)\n",
217 | "print('x_test shape:', x_test.shape)"
218 | ]
219 | },
220 | {
221 | "cell_type": "markdown",
222 | "metadata": {
223 | "colab_type": "text",
224 | "id": "Tx73Y-asRdEz"
225 | },
226 | "source": [
227 | "## Initialize mixup"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": null,
233 | "metadata": {
234 | "colab": {},
235 | "colab_type": "code",
236 | "id": "xvuxODUxRdE1"
237 | },
238 | "outputs": [],
239 | "source": [
240 | "mixup = MIXUP()\n",
241 | "generator, step = mixup.flow(x_train, y_train, batch_size=batch_size, runs=runs)"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 7,
247 | "metadata": {
248 | "colab": {
249 | "base_uri": "https://localhost:8080/",
250 | "height": 476
251 | },
252 | "colab_type": "code",
253 | "id": "6cm1o_fAWd_4",
254 | "outputId": "ea793754-100c-4c12-8acf-7798c096c399"
255 | },
256 | "outputs": [
257 | {
258 | "name": "stdout",
259 | "output_type": "stream",
260 | "text": [
261 | "Build model...\n",
262 | "Model: \"sequential\"\n",
263 | "_________________________________________________________________\n",
264 | "Layer (type) Output Shape Param # \n",
265 | "=================================================================\n",
266 | "embedding (Embedding) (None, 400, 50) 250000 \n",
267 | "_________________________________________________________________\n",
268 | "dropout (Dropout) (None, 400, 50) 0 \n",
269 | "_________________________________________________________________\n",
270 | "conv1d (Conv1D) (None, 398, 250) 37750 \n",
271 | "_________________________________________________________________\n",
272 | "global_max_pooling1d (Global (None, 250) 0 \n",
273 | "_________________________________________________________________\n",
274 | "dense (Dense) (None, 250) 62750 \n",
275 | "_________________________________________________________________\n",
276 | "dropout_1 (Dropout) (None, 250) 0 \n",
277 | "_________________________________________________________________\n",
278 | "activation (Activation) (None, 250) 0 \n",
279 | "_________________________________________________________________\n",
280 | "dense_1 (Dense) (None, 1) 251 \n",
281 | "_________________________________________________________________\n",
282 | "activation_1 (Activation) (None, 1) 0 \n",
283 | "=================================================================\n",
284 | "Total params: 350,751\n",
285 | "Trainable params: 350,751\n",
286 | "Non-trainable params: 0\n",
287 | "_________________________________________________________________\n"
288 | ]
289 | }
290 | ],
291 | "source": [
292 | "print('Build model...')\n",
293 | "model = Sequential()\n",
294 | "\n",
295 | "# we start off with an efficient embedding layer which maps\n",
296 | "# our vocab indices into embedding_dims dimensions\n",
297 | "model.add(Embedding(max_features,\n",
298 | " embedding_dims,\n",
299 | " input_length=maxlen))\n",
300 | "model.add(Dropout(0.2))\n",
301 | "\n",
302 | "# we add a Conv1D layer, which will learn `filters`\n",
303 | "# word-group filters of size kernel_size:\n",
304 | "model.add(Conv1D(filters,\n",
305 | " kernel_size,\n",
306 | " padding='valid',\n",
307 | " activation='relu',\n",
308 | " strides=1))\n",
309 | "# we use max pooling:\n",
310 | "model.add(GlobalMaxPooling1D())\n",
311 | "\n",
312 | "# We add a vanilla hidden layer:\n",
313 | "model.add(Dense(hidden_dims))\n",
314 | "model.add(Dropout(0.2))\n",
315 | "model.add(Activation('relu'))\n",
316 | "\n",
317 | "# We project onto a single unit output layer, and squash it with a sigmoid:\n",
318 | "model.add(Dense(1))\n",
319 | "model.add(Activation('sigmoid'))\n",
320 | "\n",
321 | "model.compile(loss='binary_crossentropy',\n",
322 | " optimizer='adam',\n",
323 | " metrics=['accuracy'])\n",
324 | "model.summary()"
325 | ]
326 | },
327 | {
328 | "cell_type": "markdown",
329 | "metadata": {
330 | "colab_type": "text",
331 | "id": "b5zRyuq8UKmR"
332 | },
333 | "source": [
334 | "## Train model using mixup augmentation"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": 8,
340 | "metadata": {
341 | "colab": {
342 | "base_uri": "https://localhost:8080/",
343 | "height": 357
344 | },
345 | "colab_type": "code",
346 | "id": "oGLSfzcUWeAB",
347 | "outputId": "81464964-8fd3-4249-b901-0e05cb664436"
348 | },
349 | "outputs": [
350 | {
351 | "name": "stdout",
352 | "output_type": "stream",
353 | "text": [
354 | "Epoch 1/10\n",
355 | "782/782 [==============================] - 8s 10ms/step - loss: 0.6867 - accuracy: 0.2859 - val_loss: 0.6408 - val_accuracy: 0.6537\n",
356 | "Epoch 2/10\n",
357 | "782/782 [==============================] - 8s 10ms/step - loss: 0.6655 - accuracy: 0.3081 - val_loss: 0.6140 - val_accuracy: 0.6620\n",
358 | "Epoch 3/10\n",
359 | "782/782 [==============================] - 8s 10ms/step - loss: 0.6443 - accuracy: 0.3267 - val_loss: 0.5688 - val_accuracy: 0.7233\n",
360 | "Epoch 4/10\n",
361 | "782/782 [==============================] - 8s 10ms/step - loss: 0.6250 - accuracy: 0.3287 - val_loss: 0.5167 - val_accuracy: 0.7434\n",
362 | "Epoch 5/10\n",
363 | "782/782 [==============================] - 8s 10ms/step - loss: 0.6140 - accuracy: 0.3337 - val_loss: 0.5154 - val_accuracy: 0.7534\n",
364 | "Epoch 6/10\n",
365 | "782/782 [==============================] - 8s 10ms/step - loss: 0.6029 - accuracy: 0.3338 - val_loss: 0.4763 - val_accuracy: 0.7765\n",
366 | "Epoch 7/10\n",
367 | "782/782 [==============================] - 8s 10ms/step - loss: 0.5976 - accuracy: 0.3314 - val_loss: 0.4659 - val_accuracy: 0.7810\n",
368 | "Epoch 8/10\n",
369 | "782/782 [==============================] - 8s 10ms/step - loss: 0.5857 - accuracy: 0.3423 - val_loss: 0.4551 - val_accuracy: 0.7873\n",
370 | "Epoch 9/10\n",
371 | "782/782 [==============================] - 8s 10ms/step - loss: 0.5800 - accuracy: 0.3488 - val_loss: 0.4502 - val_accuracy: 0.7927\n",
372 | "Epoch 10/10\n",
373 | "782/782 [==============================] - 8s 10ms/step - loss: 0.5793 - accuracy: 0.3402 - val_loss: 0.4653 - val_accuracy: 0.7927\n"
374 | ]
375 | }
376 | ],
377 | "source": [
378 | "h1 = model.fit(generator, steps_per_epoch=step,\n",
379 | " epochs=epochs,\n",
380 | " validation_data=(x_test, y_test))"
381 | ]
382 | },
383 | {
384 | "cell_type": "code",
385 | "execution_count": 9,
386 | "metadata": {
387 | "colab": {
388 | "base_uri": "https://localhost:8080/",
389 | "height": 298
390 | },
391 | "colab_type": "code",
392 | "id": "XKrXdkt8XeYo",
393 | "outputId": "0d463439-1718-4f90-bc24-b32f6dae7eda"
394 | },
395 | "outputs": [
396 | {
397 | "data": {
398 | "text/plain": [
399 | ""
400 | ]
401 | },
402 | "execution_count": 9,
403 | "metadata": {
404 | "tags": []
405 | },
406 | "output_type": "execute_result"
407 | },
408 | {
409 | "data": {
410 | "image/png": "\n",
411 | "text/plain": [
412 | ""
413 | ]
414 | },
415 | "metadata": {
416 | "needs_background": "light",
417 | "tags": []
418 | },
419 | "output_type": "display_data"
420 | }
421 | ],
422 | "source": [
423 | "pd.DataFrame(h1.history)[['loss','val_loss']].plot(title=\"With mixup\")"
424 | ]
425 | },
426 | {
427 | "cell_type": "code",
428 | "execution_count": 10,
429 | "metadata": {
430 | "colab": {
431 | "base_uri": "https://localhost:8080/",
432 | "height": 476
433 | },
434 | "colab_type": "code",
435 | "id": "Iiv7ahP8WeAF",
436 | "outputId": "0ad04311-b497-4830-dd50-a832daf583ac"
437 | },
438 | "outputs": [
439 | {
440 | "name": "stdout",
441 | "output_type": "stream",
442 | "text": [
443 | "Build model...\n",
444 | "Model: \"sequential_1\"\n",
445 | "_________________________________________________________________\n",
446 | "Layer (type) Output Shape Param # \n",
447 | "=================================================================\n",
448 | "embedding_1 (Embedding) (None, 400, 50) 250000 \n",
449 | "_________________________________________________________________\n",
450 | "dropout_2 (Dropout) (None, 400, 50) 0 \n",
451 | "_________________________________________________________________\n",
452 | "conv1d_1 (Conv1D) (None, 398, 250) 37750 \n",
453 | "_________________________________________________________________\n",
454 | "global_max_pooling1d_1 (Glob (None, 250) 0 \n",
455 | "_________________________________________________________________\n",
456 | "dense_2 (Dense) (None, 250) 62750 \n",
457 | "_________________________________________________________________\n",
458 | "dropout_3 (Dropout) (None, 250) 0 \n",
459 | "_________________________________________________________________\n",
460 | "activation_2 (Activation) (None, 250) 0 \n",
461 | "_________________________________________________________________\n",
462 | "dense_3 (Dense) (None, 1) 251 \n",
463 | "_________________________________________________________________\n",
464 | "activation_3 (Activation) (None, 1) 0 \n",
465 | "=================================================================\n",
466 | "Total params: 350,751\n",
467 | "Trainable params: 350,751\n",
468 | "Non-trainable params: 0\n",
469 | "_________________________________________________________________\n"
470 | ]
471 | }
472 | ],
473 | "source": [
474 | "print('Build model...')\n",
475 | "model2 = Sequential()\n",
476 | "\n",
477 | "# we start off with an efficient embedding layer which maps\n",
478 | "# our vocab indices into embedding_dims dimensions\n",
479 | "model2.add(Embedding(max_features,\n",
480 | " embedding_dims,\n",
481 | " input_length=maxlen))\n",
482 | "model2.add(Dropout(0.2))\n",
483 | "\n",
484 | "# we add a Conv1D layer, which will learn `filters`\n",
485 | "# word-group filters of size kernel_size:\n",
486 | "model2.add(Conv1D(filters,\n",
487 | " kernel_size,\n",
488 | " padding='valid',\n",
489 | " activation='relu',\n",
490 | " strides=1))\n",
491 | "# we use max pooling:\n",
492 | "model2.add(GlobalMaxPooling1D())\n",
493 | "\n",
494 | "# We add a vanilla hidden layer:\n",
495 | "model2.add(Dense(hidden_dims))\n",
496 | "model2.add(Dropout(0.2))\n",
497 | "model2.add(Activation('relu'))\n",
498 | "\n",
499 | "# We project onto a single unit output layer, and squash it with a sigmoid:\n",
500 | "model2.add(Dense(1))\n",
501 | "model2.add(Activation('sigmoid'))\n",
502 | "\n",
503 | "model2.compile(loss='binary_crossentropy',\n",
504 | " optimizer='adam',\n",
505 | " metrics=['accuracy'])\n",
506 | "model2.summary()"
507 | ]
508 | },
509 | {
510 | "cell_type": "code",
511 | "execution_count": 11,
512 | "metadata": {
513 | "colab": {
514 | "base_uri": "https://localhost:8080/",
515 | "height": 357
516 | },
517 | "colab_type": "code",
518 | "id": "ygNHmhGMWeAI",
519 | "outputId": "1592613d-52d2-409b-e210-cceddb7f5bbd"
520 | },
521 | "outputs": [
522 | {
523 | "name": "stdout",
524 | "output_type": "stream",
525 | "text": [
526 | "Epoch 1/10\n",
527 | "782/782 [==============================] - 8s 10ms/step - loss: 0.4057 - accuracy: 0.7964 - val_loss: 0.2819 - val_accuracy: 0.8825\n",
528 | "Epoch 2/10\n",
529 | "782/782 [==============================] - 8s 10ms/step - loss: 0.2260 - accuracy: 0.9100 - val_loss: 0.2540 - val_accuracy: 0.8957\n",
530 | "Epoch 3/10\n",
531 | "782/782 [==============================] - 8s 10ms/step - loss: 0.1579 - accuracy: 0.9409 - val_loss: 0.2806 - val_accuracy: 0.8874\n",
532 | "Epoch 4/10\n",
533 | "782/782 [==============================] - 8s 10ms/step - loss: 0.1056 - accuracy: 0.9625 - val_loss: 0.3103 - val_accuracy: 0.8897\n",
534 | "Epoch 5/10\n",
535 | "782/782 [==============================] - 8s 10ms/step - loss: 0.0732 - accuracy: 0.9730 - val_loss: 0.3593 - val_accuracy: 0.8838\n",
536 | "Epoch 6/10\n",
537 | "782/782 [==============================] - 8s 10ms/step - loss: 0.0539 - accuracy: 0.9808 - val_loss: 0.3938 - val_accuracy: 0.8884\n",
538 | "Epoch 7/10\n",
539 | "782/782 [==============================] - 8s 10ms/step - loss: 0.0419 - accuracy: 0.9854 - val_loss: 0.4444 - val_accuracy: 0.8817\n",
540 | "Epoch 8/10\n",
541 | "782/782 [==============================] - 8s 10ms/step - loss: 0.0340 - accuracy: 0.9876 - val_loss: 0.4842 - val_accuracy: 0.8870\n",
542 | "Epoch 9/10\n",
543 | "782/782 [==============================] - 8s 10ms/step - loss: 0.0388 - accuracy: 0.9857 - val_loss: 0.4686 - val_accuracy: 0.8863\n",
544 | "Epoch 10/10\n",
545 | "782/782 [==============================] - 8s 10ms/step - loss: 0.0314 - accuracy: 0.9887 - val_loss: 0.6685 - val_accuracy: 0.8559\n"
546 | ]
547 | }
548 | ],
549 | "source": [
550 | "h2 = model2.fit(x_train, y_train,\n",
551 | " batch_size=batch_size,\n",
552 | " epochs=epochs,\n",
553 | " validation_data=(x_test, y_test))"
554 | ]
555 | },
556 | {
557 | "cell_type": "code",
558 | "execution_count": 12,
559 | "metadata": {
560 | "colab": {
561 | "base_uri": "https://localhost:8080/",
562 | "height": 298
563 | },
564 | "colab_type": "code",
565 | "id": "DzJEhaPrWeAM",
566 | "outputId": "aec6c655-c5f8-434b-bb16-d1e1056adc03"
567 | },
568 | "outputs": [
569 | {
570 | "data": {
571 | "text/plain": [
572 | ""
573 | ]
574 | },
575 | "execution_count": 12,
576 | "metadata": {
577 | "tags": []
578 | },
579 | "output_type": "execute_result"
580 | },
581 | {
582 | "data": {
583 | "image/png": "\n",
584 | "text/plain": [
585 | ""
586 | ]
587 | },
588 | "metadata": {
589 | "needs_background": "light",
590 | "tags": []
591 | },
592 | "output_type": "display_data"
593 | }
594 | ],
595 | "source": [
596 | "pd.DataFrame(h2.history)[['loss','val_loss']].plot(title=\"Without mixup\")"
597 | ]
598 | },
599 | {
600 | "cell_type": "markdown",
601 | "metadata": {
602 | "colab_type": "text",
603 | "id": "M2HDERJbGr2a"
604 | },
605 | "source": [
606 | "# Comparison\n",
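607 | "Compare the two loss curves: with mixup the validation loss keeps decreasing for most of training, whereas without mixup it starts rising after the second epoch, i.e. the model overfits."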
608 | ]
609 | },
610 | {
611 | "cell_type": "markdown",
612 | "metadata": {
613 | "colab": {},
614 | "colab_type": "code",
615 | "id": "hqteWafKRdF1"
616 | },
617 | "source": [
618 | "## Cite the paper\n",
619 | "```\n",
620 | "@article{marivate2019improving,\n",
621 | " title={Improving short text classification through global augmentation methods},\n",
622 | " author={Marivate, Vukosi and Sefara, Tshephisho},\n",
623 | " journal={arXiv preprint arXiv:1907.03752},\n",
624 | " year={2019}\n",
625 | "}\n```\n",
626 | "\n",
627 | "https://arxiv.org/abs/1907.03752"
628 | ]
629 | },
630 | {
631 | "cell_type": "code",
632 | "execution_count": null,
633 | "metadata": {},
634 | "outputs": [],
635 | "source": []
636 | }
637 | ],
638 | "metadata": {
639 | "accelerator": "GPU",
640 | "colab": {
641 | "collapsed_sections": [],
642 | "name": "mixup_example_using_IMDB_sentiment.ipynb",
643 | "provenance": []
644 | },
645 | "kernelspec": {
646 | "display_name": "Python 3",
647 | "language": "python",
648 | "name": "python3"
649 | },
650 | "language_info": {
651 | "codemirror_mode": {
652 | "name": "ipython",
653 | "version": 3
654 | },
655 | "file_extension": ".py",
656 | "mimetype": "text/x-python",
657 | "name": "python",
658 | "nbconvert_exporter": "python",
659 | "pygments_lexer": "ipython3",
660 | "version": "3.7.7"
661 | }
662 | },
663 | "nbformat": 4,
664 | "nbformat_minor": 4
665 | }
666 |
--------------------------------------------------------------------------------
/examples/word2vec_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "JHDJLKDuJkcB"
8 | },
9 | "source": [
10 | "# Example for using word2vec"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {
17 | "colab": {},
18 | "colab_type": "code",
19 | "id": "9m8ChZsdAx41"
20 | },
21 | "outputs": [],
22 | "source": [
23 | "# Import libraries\n",
24 | "try:\n",
25 | " import textaugment, gensim\n",
26 | "except ModuleNotFoundError:\n",
27 | " !pip -q install textaugment gensim\n",
28 | " import textaugment, gensim"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 4,
34 | "metadata": {
35 | "colab": {
36 | "base_uri": "https://localhost:8080/",
37 | "height": 153
38 | },
39 | "colab_type": "code",
40 | "id": "ux6Bc4QSrYA8",
41 | "outputId": "9f2b8af1-3b22-455c-dd85-d1ac173a5317"
42 | },
43 | "outputs": [
44 | {
45 | "name": "stdout",
46 | "output_type": "stream",
47 | "text": [
48 | "[nltk_data] Downloading package wordnet to /root/nltk_data...\n",
49 | "[nltk_data] Unzipping corpora/wordnet.zip.\n",
50 | "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
51 | "[nltk_data] Unzipping tokenizers/punkt.zip.\n",
52 | "[nltk_data] Downloading package averaged_perceptron_tagger to\n",
53 | "[nltk_data] /root/nltk_data...\n",
54 | "[nltk_data] Unzipping taggers/averaged_perceptron_tagger.zip.\n"
55 | ]
56 | },
57 | {
58 | "data": {
59 | "text/plain": [
60 | "True"
61 | ]
62 | },
63 | "execution_count": 4,
64 | "metadata": {
65 | "tags": []
66 | },
67 | "output_type": "execute_result"
68 | }
69 | ],
70 | "source": [
71 | "# Import NLTK and download data\n",
72 | "import nltk\n",
73 | "nltk.download(['wordnet','punkt','averaged_perceptron_tagger'])"
74 | ]
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "metadata": {
79 | "colab_type": "text",
80 | "id": "8AUt-F5MtiuI"
81 | },
82 | "source": [
83 | "## Load Google Word2vec embeddings"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 5,
89 | "metadata": {
90 | "colab": {
91 | "base_uri": "https://localhost:8080/",
92 | "height": 204
93 | },
94 | "colab_type": "code",
95 | "id": "1xq4dJtSr4RM",
96 | "outputId": "1ff32743-04a9-4b8a-eda3-8dcf55e711ca"
97 | },
98 | "outputs": [
99 | {
100 | "name": "stdout",
101 | "output_type": "stream",
102 | "text": [
103 | "--2020-05-23 18:06:47-- https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz\n",
104 | "Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.178.197\n",
105 | "Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.178.197|:443... connected.\n",
106 | "HTTP request sent, awaiting response... 200 OK\n",
107 | "Length: 1647046227 (1.5G) [application/x-gzip]\n",
108 | "Saving to: ‘GoogleNews-vectors-negative300.bin.gz’\n",
109 | "\n",
110 | "GoogleNews-vectors- 100%[===================>] 1.53G 36.2MB/s in 44s \n",
111 | "\n",
112 | "2020-05-23 18:07:31 (35.7 MB/s) - ‘GoogleNews-vectors-negative300.bin.gz’ saved [1647046227/1647046227]\n",
113 | "\n"
114 | ]
115 | }
116 | ],
117 | "source": [
118 | "# Download Google Word2vec embeddings\n",
119 | "!wget \"https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz\""
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 8,
125 | "metadata": {
126 | "colab": {
127 | "base_uri": "https://localhost:8080/",
128 | "height": 71
129 | },
130 | "colab_type": "code",
131 | "id": "q2wxTNhwrjK-",
132 | "outputId": "e30ff6b7-96a3-4d59-c486-65def436cbd8"
133 | },
134 | "outputs": [
135 | {
136 | "name": "stderr",
137 | "output_type": "stream",
138 | "text": [
139 | "/usr/local/lib/python3.6/dist-packages/smart_open/smart_open_lib.py:253: UserWarning: This function is deprecated, use smart_open.open instead. See the migration notes for details: https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst#migrating-to-the-new-open-function\n",
140 | " 'See the migration notes for details: %s' % _MIGRATION_NOTES_URL\n"
141 | ]
142 | }
143 | ],
144 | "source": [
145 | "model = gensim.models.KeyedVectors.load_word2vec_format('./GoogleNews-vectors-negative300.bin.gz', binary=True)"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": 13,
151 | "metadata": {
152 | "colab": {
153 | "base_uri": "https://localhost:8080/",
154 | "height": 71
155 | },
156 | "colab_type": "code",
157 | "id": "3uHnRL77uATl",
158 | "outputId": "de09c7ff-47bc-4e21-d2eb-89fcadb4d2bd"
159 | },
160 | "outputs": [
161 | {
162 | "name": "stderr",
163 | "output_type": "stream",
164 | "text": [
165 | "/usr/local/lib/python3.6/dist-packages/gensim/matutils.py:737: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n",
166 | " if np.issubdtype(vec.dtype, np.int):\n"
167 | ]
168 | }
169 | ],
170 | "source": [
171 | "from textaugment import Word2vec\n",
172 | "t = Word2vec(model=model)\n",
173 | "output = t.augment('The stories are good', top_n=10)"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": 14,
179 | "metadata": {
180 | "colab": {
181 | "base_uri": "https://localhost:8080/",
182 | "height": 34
183 | },
184 | "colab_type": "code",
185 | "id": "BhVYt8V3uAwk",
186 | "outputId": "7c36d302-db66-4837-ff6b-ea1793a088d9"
187 | },
188 | "outputs": [
189 | {
190 | "name": "stdout",
191 | "output_type": "stream",
192 | "text": [
193 | "the stories are excellent\n"
194 | ]
195 | }
196 | ],
197 | "source": [
198 | "print(output)"
199 | ]
200 | },
201 | {
202 | "cell_type": "markdown",
203 | "metadata": {
204 | "colab": {},
205 | "colab_type": "code",
206 | "id": "IWoNJrZfy94n"
207 | },
208 | "source": [
209 | "## Cite the paper\n",
210 | "```\n",
211 | "@article{marivate2019improving,\n",
212 | " title={Improving short text classification through global augmentation methods},\n",
213 | " author={Marivate, Vukosi and Sefara, Tshephisho},\n",
214 | " journal={arXiv preprint arXiv:1907.03752},\n",
215 | " year={2019}\n",
216 | "}\n```\n",
217 | "\n",
218 | "https://arxiv.org/abs/1907.03752"
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": null,
224 | "metadata": {},
225 | "outputs": [],
226 | "source": []
227 | }
228 | ],
229 | "metadata": {
230 | "accelerator": "GPU",
231 | "colab": {
232 | "collapsed_sections": [],
233 | "name": "word2vec example.ipynb",
234 | "provenance": []
235 | },
236 | "kernelspec": {
237 | "display_name": "Python 3",
238 | "language": "python",
239 | "name": "python3"
240 | },
241 | "language_info": {
242 | "codemirror_mode": {
243 | "name": "ipython",
244 | "version": 3
245 | },
246 | "file_extension": ".py",
247 | "mimetype": "text/x-python",
248 | "name": "python",
249 | "nbconvert_exporter": "python",
250 | "pygments_lexer": "ipython3",
251 | "version": "3.7.7"
252 | }
253 | },
254 | "nbformat": 4,
255 | "nbformat_minor": 4
256 | }
257 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gensim>=4.0
2 | googletrans>=2
3 | nltk
4 | numpy
5 | textblob
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import setuptools
4 | import re
5 |
6 |
def find_version(fname):
    """Return the version string declared in the file named ``fname``.

    The file is scanned line by line for an assignment of the form
    ``__version__ = '<version>'`` (single or double quotes) that starts at
    the beginning of a line. The first such line decides the result.

    Args:
        fname: Path of the file to scan.

    Returns:
        The text between the quotes of the first matching line.

    Raises:
        RuntimeError: If no matching line exists or the captured version
            string is empty.
    """
    version = ''
    # Whitespace around ``=`` is optional so that reformatting the
    # assignment does not silently break version discovery.
    reg = re.compile(r'__version__\s*=\s*[\'"]([^\'"]*)[\'"]')
    # Decode explicitly as UTF-8 instead of relying on the platform locale.
    with open(fname, 'r', encoding='utf-8') as fp:
        for line in fp:
            m = reg.match(line)
            if m:
                version = m.group(1)
                break
    if not version:
        raise RuntimeError('Cannot find version information')
    return version
22 |
23 |
# Resolve the package version once, at import time, so it can be passed to
# setup() below. The path is relative, so it is resolved against the current
# working directory.
__version__ = find_version('textaugment/__init__.py')
25 |
26 |
def read(fname):
    """Return the complete text content of the file named ``fname``.

    The file is decoded as UTF-8 rather than with the platform's default
    encoding, so non-ASCII content reads the same on every system.

    Args:
        fname: Path of the file to read.

    Returns:
        The file content as a single string.
    """
    with open(fname, "r", encoding="utf-8") as fh:
        return fh.read()
31 |
32 |
# Package metadata and build configuration handed to setuptools.
setuptools.setup(
    name='textaugment',
    # Version string extracted from textaugment/__init__.py by find_version().
    version=__version__,
    # Include every discovered package except those matching "test*".
    packages=setuptools.find_packages(exclude=('test*', )),
    author='Joseph Sefara',
    author_email='sefaratj@gmail.com',
    license='MIT',
    keywords=['text augmentation', 'python', 'natural language processing', 'nlp'],
    url='https://github.com/dsfsi/textaugment',
    description='A library for augmenting text for natural language processing applications.',
    # The README doubles as the long description; it is declared as Markdown.
    long_description=read("README.md"),
    long_description_content_type="text/markdown",
    # Runtime dependencies; the same set is listed in requirements.txt.
    install_requires=['nltk', 'gensim>=4.0', 'textblob', 'numpy', 'googletrans>=2'],
    classifiers=[
        "Intended Audience :: Developers",
        "Natural Language :: English",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: Implementation :: PyPy",
        "Topic :: Text Processing :: Linguistic",
    ]
)
56 |
--------------------------------------------------------------------------------
/tests/test_translate.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import sys
3 | from textaugment.translate import Translate
4 | from textaugment import translate
5 |
6 |
class InputTestCase(unittest.TestCase):
    """Checks how ``Translate`` reacts to missing or invalid arguments."""

    def setUp(self):
        # A translator built with src="en" and to="es", shared by the tests.
        self.t = Translate(src="en", to="es")

    def test_geometric(self):
        """Bad constructor arguments and non-string input must raise."""
        # Constructing without any language arguments is rejected.
        with self.assertRaises(ValueError, msg="Parameters must be set"):
            Translate()

        with self.assertRaises(KeyError, msg="Value of parameters must be correct"):
            Translate(to=7, src="hello")  # Test parameter, type

        # augment() only accepts strings.
        with self.assertRaises(TypeError, msg="Only strings are allowed"):
            self.t.augment(45)

    def test_translate(self):
        """The module-level LANGUAGES object must be truthy (non-empty)."""
        self.assertTrue(translate.LANGUAGES, msg="Files exists")
24 |
25 |
class OutputTestCase(unittest.TestCase):
    """Checks the value returned by ``Translate.augment``."""

    def setUp(self):
        # Sample sentence plus the translator (src="en", to="es") under test.
        self.data = "He walks"
        self.t = Translate(src="en", to="es")

    def test_augment(self):
        """augment() is expected to return these inputs unchanged."""
        augmented_sentence = self.t.augment(self.data)
        self.assertEqual(augmented_sentence, self.data)

        digit = "4"
        augmented_digit = self.t.augment(digit)
        self.assertEqual(augmented_digit, digit)
36 |
37 |
class PlatformTestCase(unittest.TestCase):
    """Guards against running the suite on an unsupported interpreter."""

    def test_platform(self):
        """The interpreter's major version must be 3."""
        major_version = sys.version_info[0]
        self.assertEqual(major_version, 3, msg="Must be using Python 3")


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
46 |
47 |
--------------------------------------------------------------------------------
/tests/test_word2vec.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import sys
3 | from textaugment.word2vec import Word2vec
4 |
5 |
class InputTestCase(unittest.TestCase):
    """Checks how ``Word2vec`` reacts to invalid arguments and input.

    The tests need a word2vec model on disk. Its location defaults to the
    original developer path and can be overridden with the
    ``TEXTAUGMENT_TEST_MODEL`` environment variable. When no model exists at
    that location the tests are skipped instead of failing in ``setUp``.
    """

    def setUp(self):
        import os  # local import keeps this change self-contained

        self.path = os.environ.get(
            "TEXTAUGMENT_TEST_MODEL",
            "/home/tjs/dev/papu/models/gensim_cbow_sepedi",
        )
        # A sibling path that is expected not to exist.
        self.wrongpath = self.path + "-wrong"
        if not os.path.exists(self.path):
            self.skipTest(
                "word2vec test model not found at %r; "
                "set TEXTAUGMENT_TEST_MODEL to a model path" % self.path
            )
        self.w = Word2vec(model=self.path)

    def test_augment(self):
        """Wrong parameter types, a missing model and bad input must raise."""
        with self.assertRaises(TypeError, msg="Value for p should be float"):
            Word2vec(model=self.path, p="foo")

        with self.assertRaises(TypeError, msg="Value for runs should be integer"):
            Word2vec(model=self.path, runs="foo")

        with self.assertRaises(FileNotFoundError, msg="The model is not found"):
            Word2vec(model=self.wrongpath)

        with self.assertRaises(TypeError, msg="Input should not be lists"):
            self.w.augment(["hello"])

        with self.assertRaises(TypeError, msg="Input should not be numbers"):
            self.w.augment(45)
28 |
29 |
class OutputTestCase(unittest.TestCase):
    """Checks the value returned by ``Word2vec.augment``.

    The tests need a word2vec model on disk. Its location defaults to the
    original developer path and can be overridden with the
    ``TEXTAUGMENT_TEST_MODEL`` environment variable. When no model exists at
    that location the tests are skipped instead of failing in ``setUp``.
    """

    def setUp(self):
        import os  # local import keeps this change self-contained

        self.path = os.environ.get(
            "TEXTAUGMENT_TEST_MODEL",
            "/home/tjs/dev/papu/models/gensim_cbow_sepedi",
        )
        if not os.path.exists(self.path):
            self.skipTest(
                "word2vec test model not found at %r; "
                "set TEXTAUGMENT_TEST_MODEL to a model path" % self.path
            )
        self.w = Word2vec(model=self.path)
        self.data = "We are testing"

    def test_augment(self):
        """augment() returns a string and leaves the string "4" unchanged."""
        # The message describes the value actually checked: the output.
        self.assertIsInstance(self.w.augment(self.data), str, msg="Output must be a string")
        self.assertEqual(self.w.augment("4"), "4", msg="Input should not be numbers")
40 |
41 |
class PlatformTestCase(unittest.TestCase):
    """Guards against running the suite on an unsupported interpreter."""

    def test_platform(self):
        """The interpreter's major version must be 3."""
        major_version = sys.version_info[0]
        self.assertEqual(major_version, 3, msg="Must be using Python 3")


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
50 |
51 |
--------------------------------------------------------------------------------
/tests/test_wordnet.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import sys
3 | import numpy as np
4 | from textaugment.wordnet import Wordnet
5 |
6 |
class InputTestCase(unittest.TestCase):
    """Checks how ``Wordnet`` reacts to invalid arguments and input."""

    def setUp(self):
        # Token list used as sample data, and the augmenter built with p=0.8.
        self.data = ["I", "am", "testing"]
        self.p = 0.8
        self.w = Wordnet(p=self.p)

    def test_geometric(self):
        """geometric() must reject wrong argument sets and non-sequence data."""
        # Passing both ``p`` and ``data`` is expected to be rejected.
        with self.assertRaises(TypeError, msg="Receives one parameter"):
            self.w.geometric(p=self.p, data=self.data)

        # So is passing nothing at all.
        with self.assertRaises(TypeError, msg="Receives one parameter"):
            self.w.geometric()

        with self.assertRaises(IndexError, msg="Data must be set using; data='data string'"):
            self.w.geometric(data=0)

    def test_augment(self):
        """augment() must reject anything that is not a string."""
        token_list = self.data
        with self.assertRaises(TypeError, msg="Expect string not list"):
            self.w.augment(token_list)

        with self.assertRaises(TypeError, msg="Expect string not integer"):
            self.w.augment(data=0)
29 |
30 |
class OutputTestCase(unittest.TestCase):
    """Check the return types of Wordnet.augment and Wordnet.geometric."""

    def setUp(self):
        """Prepare sample inputs and an augmenter with p=0.8."""
        self.p = 0.8
        self.data = ["I", "am", "testing"]
        self.data2 = "известен още с псевдонимите"
        self.w = Wordnet(p=self.p)

    def test_augment(self):
        """augment() on the space-joined sample must return a str."""
        sentence = " ".join(self.data)
        augmented = self.w.augment(sentence)
        self.assertIsInstance(augmented, str)

    def test_geometric(self):
        """geometric() on the token list must return a numpy array."""
        selection = self.w.geometric(data=self.data)
        self.assertIsInstance(selection, np.ndarray)
44 |
45 |
class PlatformTestCase(unittest.TestCase):
    """Guard against running the suite on an unsupported interpreter."""

    def test_platform(self):
        """The running interpreter's major version must be 3."""
        major_version = sys.version_info.major
        self.assertEqual(major_version, 3, msg="Must be using Python 3")


if __name__ == "__main__":
    unittest.main()
54 |
--------------------------------------------------------------------------------
/textaugment/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .translate import Translate
3 | from .word2vec import Word2vec
4 | from .word2vec import Fasttext
5 | from .wordnet import Wordnet
6 | from .eda import EDA
7 | from .aeda import AEDA
8 | from .mixup import MIXUP
9 | from .constants import LANGUAGES
10 |
# Package metadata.
name = "textaugment"

__version__ = '2.0.0'
__licence__ = 'MIT'
__author__ = 'Joseph Sefara'
__url__ = 'https://github.com/dsfsi/textaugment/'

# Absolute path of the directory that contains this package.
PACKAGE_DIR = os.path.dirname(os.path.abspath(__file__))

# Names exported by ``from textaugment import *``. ``Fasttext`` is imported
# from ``.word2vec`` at the top of this module, so it is listed here as well.
__all__ = [
    'Translate',
    'Word2vec',
    'Fasttext',
    'Wordnet',
    'EDA',
    'AEDA',
    'MIXUP',
    'LANGUAGES'
]
29 |
--------------------------------------------------------------------------------
/textaugment/aeda.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # TextAugment: AEDA
3 | #
4 | # Copyright (C) 2023
5 | # Author: Juhwan Choi
6 | #
7 | # URL:
8 | # For license information, see LICENSE
9 | #
10 | """
11 | This module is an implementation of the original AEDA algorithm (2021) [1].
12 | """
13 | import random
14 |
15 |
class AEDA:
    """
    This class is an implementation of the original AEDA algorithm (2021) [1].

    [1] Karimi et al., 2021, November. AEDA: An Easier Data Augmentation Technique for Text Classification.
    In Findings of the Association for Computational Linguistics: EMNLP 2021 (pp. 2748-2754).
    https://aclanthology.org/2021.findings-emnlp.234.pdf

    Example usage: ::
        >>> from textaugment import AEDA
        >>> t = AEDA()
        >>> t.punct_insertion("John is going to town")
        ! John is going to town
    """

    # Punctuation marks used when the caller supplies none. Kept as a tuple so
    # the default can never be mutated and shared between instances.
    DEFAULT_PUNCTUATIONS = ('.', ';', '?', ':', '!', ',')

    @staticmethod
    def validate(**kwargs):
        """Validate input data

        :param kwargs: Values to check; only the ``sentence`` key is inspected.

        :raises TypeError: If ``sentence`` is given but is not a string, or is
            empty after stripping surrounding whitespace.
        """
        if 'sentence' in kwargs:
            sentence = kwargs['sentence']
            # Check the type first so non-strings raise TypeError instead of
            # AttributeError from calling .strip() on them.
            if not isinstance(sentence, str) or not sentence.strip():
                raise TypeError("sentence must be a valid sentence")

    def __init__(self, punctuations=None, random_state=1):
        """A method to initialize parameters

        :type punctuations: list
        :param punctuations: (optional) Punctuations to be inserted. Defaults to
            a fresh list of ``. ; ? : ! ,`` for every instance.
        :type random_state: int
        :param random_state: (optional) Seed passed to ``random.seed``

        :rtype: None
        :return: Constructor does not return.
        :raises TypeError: If ``random_state`` is not an int.
        """
        if punctuations is None:
            punctuations = list(self.DEFAULT_PUNCTUATIONS)
        self.punctuations = punctuations
        self.random_state = random_state
        if isinstance(self.random_state, int):
            # Seeds the module-level generator of ``random``.
            random.seed(self.random_state)
        else:
            raise TypeError("random_state must have type int")

    def punct_insertion(self, sentence: str):
        """Insert random punctuations to the sentence

        :type sentence: str
        :param sentence: Sentence

        :rtype: str
        :return: Augmented sentence
        :raises TypeError: If ``sentence`` is not a non-blank string.
        """
        self.validate(sentence=sentence)

        words = sentence.strip().split(' ')
        # The number of punctuations to insert is between 1 and a third of the
        # sentence length. The upper bound is clamped to 1 so sentences shorter
        # than three words do not hand random.randint an empty range.
        upper_bound = max(1, len(words) // 3)
        num_punctuations = random.randint(1, upper_bound)
        augmented_sentence = words.copy()

        # Insert random punctuations in random positions
        for _ in range(num_punctuations):
            punct = random.choice(self.punctuations)  # Select punctuation to be inserted
            pos = random.randint(0, len(augmented_sentence) - 1)  # Select position to insert punctuation
            augmented_sentence.insert(pos, punct)  # Insert punctuation before the word at pos

        return ' '.join(augmented_sentence)
82 |
--------------------------------------------------------------------------------
/textaugment/constants.py:
--------------------------------------------------------------------------------
# Mapping of language code -> language name.
# Note that the last two entries ('fil', 'he') use capitalised names while all
# other names are lower-case.
LANGUAGES = {
    'af': 'afrikaans',
    'sq': 'albanian',
    'am': 'amharic',
    'ar': 'arabic',
    'hy': 'armenian',
    'az': 'azerbaijani',
    'eu': 'basque',
    'be': 'belarusian',
    'bn': 'bengali',
    'bs': 'bosnian',
    'bg': 'bulgarian',
    'ca': 'catalan',
    'ceb': 'cebuano',
    'ny': 'chichewa',
    'zh-cn': 'chinese (simplified)',
    'zh-tw': 'chinese (traditional)',
    'co': 'corsican',
    'hr': 'croatian',
    'cs': 'czech',
    'da': 'danish',
    'nl': 'dutch',
    'en': 'english',
    'eo': 'esperanto',
    'et': 'estonian',
    'tl': 'filipino',
    'fi': 'finnish',
    'fr': 'french',
    'fy': 'frisian',
    'gl': 'galician',
    'ka': 'georgian',
    'de': 'german',
    'el': 'greek',
    'gu': 'gujarati',
    'ht': 'haitian creole',
    'ha': 'hausa',
    'haw': 'hawaiian',
    'iw': 'hebrew',
    'hi': 'hindi',
    'hmn': 'hmong',
    'hu': 'hungarian',
    'is': 'icelandic',
    'ig': 'igbo',
    'id': 'indonesian',
    'ga': 'irish',
    'it': 'italian',
    'ja': 'japanese',
    'jw': 'javanese',
    'kn': 'kannada',
    'kk': 'kazakh',
    'km': 'khmer',
    'ko': 'korean',
    'ku': 'kurdish (kurmanji)',
    'ky': 'kyrgyz',
    'lo': 'lao',
    'la': 'latin',
    'lv': 'latvian',
    'lt': 'lithuanian',
    'lb': 'luxembourgish',
    'mk': 'macedonian',
    'mg': 'malagasy',
    'ms': 'malay',
    'ml': 'malayalam',
    'mt': 'maltese',
    'mi': 'maori',
    'mr': 'marathi',
    'mn': 'mongolian',
    'my': 'myanmar (burmese)',
    'ne': 'nepali',
    'no': 'norwegian',
    'ps': 'pashto',
    'fa': 'persian',
    'pl': 'polish',
    'pt': 'portuguese',
    'pa': 'punjabi',
    'ro': 'romanian',
    'ru': 'russian',
    'sm': 'samoan',
    'gd': 'scots gaelic',
    'sr': 'serbian',
    'st': 'sesotho',
    'sn': 'shona',
    'sd': 'sindhi',
    'si': 'sinhala',
    'sk': 'slovak',
    'sl': 'slovenian',
    'so': 'somali',
    'es': 'spanish',
    'su': 'sundanese',
    'sw': 'swahili',
    'sv': 'swedish',
    'tg': 'tajik',
    'ta': 'tamil',
    'te': 'telugu',
    'th': 'thai',
    'tr': 'turkish',
    'uk': 'ukrainian',
    'ur': 'urdu',
    'uz': 'uzbek',
    'vi': 'vietnamese',
    'cy': 'welsh',
    'xh': 'xhosa',
    'yi': 'yiddish',
    'yo': 'yoruba',
    'zu': 'zulu',
    'fil': 'Filipino',
    'he': 'Hebrew'
}

# Inverse mapping: language name -> language code, built by reversing every
# (code, name) pair of LANGUAGES. Keys are case-sensitive, so 'filipino' ->
# 'tl' and 'Filipino' -> 'fil' are separate entries, as are 'hebrew' -> 'iw'
# and 'Hebrew' -> 'he'.
LANGCODES = dict(map(reversed, LANGUAGES.items()))
111 |
--------------------------------------------------------------------------------
/textaugment/eda.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # TextAugment: EDA
3 | #
4 | # Copyright (C) 2018-2023
5 | # Author: Joseph Sefara
6 | #
7 | # URL:
8 | # For license information, see LICENSE
9 | #
10 | """
11 | This module is an implementation of the original EDA algorithm (2019) [1].
12 | """
13 | import nltk
14 | from nltk.corpus import wordnet, stopwords
15 | import random
16 |
17 |
class EDA:
    """
    This class is an implementation of the original EDA algorithm (2019) [1].

    [1] Wei, J. and Zou, K., 2019, November. EDA: Easy Data Augmentation Techniques for Boosting Performance on
    Text Classification Tasks. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing
    and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP) (pp. 6383-6389).
    https://www.aclweb.org/anthology/D19-1670.pdf

    Example usage: ::
        >>> from textaugment import EDA
        >>> t = EDA()
        >>> t.synonym_replacement("John is going to town",top_n=3)
        John is give out to town
        >>> t.random_deletion("John is going to town", p=0.2)
        is going to town
        >>> t.random_swap("John is going to town")
        John town going to is
        >>> t.random_insertion("John is going to town")
        John is going to make up town
    """

    @staticmethod
    def _get_synonyms(word):
        """Generate synonyms of ``word`` from WordNet.

        :type word: str
        :param word: Word to look up

        :rtype: list
        :return: Synonyms (lower-cased, restricted to ASCII letters and
            spaces, without ``word`` itself) in shuffled order
        """
        allowed_chars = set(' qwertyuiopasdfghjklzxcvbnm')
        synonyms = set()
        for syn in wordnet.synsets(word):
            for lemma in syn.lemmas():
                synonym = lemma.name().replace("_", " ").replace("-", " ").lower()
                synonym = "".join(char for char in synonym if char in allowed_chars)
                synonyms.add(synonym)
        synonyms.discard(word)
        # Sort before shuffling so the result depends only on the random
        # state, not on set iteration order.
        synonyms = sorted(synonyms)
        random.shuffle(synonyms)
        return synonyms

    @staticmethod
    def swap_word(new_words):
        """Swap two randomly chosen words of ``new_words`` in place.

        A second, different index is drawn at most four times; when that draw
        budget is used up the list is returned unchanged.

        :type new_words: list
        :param new_words: Words of the sentence (must not be empty)

        :rtype: list
        :return: The same list object, possibly with two entries swapped
        """
        last_idx = len(new_words) - 1
        random_idx_1 = random.randint(0, last_idx)
        random_idx_2 = random_idx_1
        counter = 0
        while random_idx_2 == random_idx_1:
            random_idx_2 = random.randint(0, last_idx)
            counter += 1
            if counter > 3:
                return new_words
        new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
        return new_words

    @staticmethod
    def validate(**kwargs):
        """Validate input data

        :param kwargs: Values to check; the keys ``p``, ``sentence`` and ``n``
            are inspected when present.

        :raises TypeError: If ``p`` is outside [0, 1], if ``sentence`` is not
            a non-blank string, or if ``n`` is not an int.
        """
        if 'p' in kwargs:
            if kwargs['p'] > 1 or kwargs['p'] < 0:
                raise TypeError("p must be a fraction between 0 and 1")
        if 'sentence' in kwargs:
            sentence = kwargs['sentence']
            # Check the type first so non-strings raise TypeError instead of
            # AttributeError from calling .strip() on them.
            if not isinstance(sentence, str) or not sentence.strip():
                raise TypeError("sentence must be a valid sentence")
        if 'n' in kwargs:
            if not isinstance(kwargs['n'], int):
                raise TypeError("n must be a valid integer")

    def __init__(self, stop_words=None, random_state=1):
        """A method to initialize parameters

        :type random_state: int
        :param random_state: (optional) Seed passed to ``random.seed``
        :type stop_words: list
        :param stop_words: (optional) List of stopwords. The NLTK English
            stopword list is used when omitted.

        :rtype: None
        :return: Constructor does not return.
        :raises TypeError: If ``random_state`` is not an int.
        """
        self.stopwords = stopwords.words('english') if stop_words is None else stop_words
        self.sentence = None
        self.p = None
        self.n = None
        self.random_state = random_state
        if isinstance(self.random_state, int):
            # Seeds the module-level generator of ``random``.
            random.seed(self.random_state)
        else:
            raise TypeError("random_state must have type int")

    def add_word(self, new_words):
        """Insert a synonym of a random non-stopword at a random position.

        :type new_words: list
        :param new_words: Words of the sentence; modified in place

        :rtype: list
        :return: ``new_words``, unchanged when it contains only stopwords or
            when no synonym was found within 10 attempts
        """
        # The candidate list does not change between attempts, so build it once.
        candidates = [word for word in new_words if word not in self.stopwords]
        if not candidates:
            # Nothing to pick from; random.randint(0, -1) would raise ValueError.
            return new_words
        synonyms = []
        counter = 0
        while len(synonyms) < 1:
            random_word = candidates[random.randint(0, len(candidates) - 1)]
            synonyms = self._get_synonyms(random_word)
            counter += 1
            if counter >= 10:
                return new_words  # See Issue 14 for details
        random_synonym = synonyms[0]  # TODO
        random_idx = random.randint(0, len(new_words) - 1)
        new_words.insert(random_idx, random_synonym)
        return new_words

    def synonym_replacement(self, sentence: str, n: int = 1, top_n: int = None):
        """Replace n words in the sentence with synonyms from wordnet

        :type sentence: str
        :param sentence: Sentence
        :type n: int
        :param n: Number of repetitions to replace
        :type top_n: int
        :param top_n: top_n of synonyms to randomly choose from

        :rtype: str
        :return: Augmented sentence
        """
        self.validate(sentence=sentence, n=n)
        self.n = n
        self.sentence = sentence
        words = sentence.split()
        new_words = words.copy()
        # Distinct non-stopwords, sorted first so the shuffle depends only on
        # the random state.
        random_word_list = sorted({word for word in words if word not in self.stopwords})
        random.shuffle(random_word_list)
        replaced = 0
        for random_word in random_word_list:
            synonyms = self._get_synonyms(random_word)
            if synonyms:
                synonyms = synonyms[:top_n if top_n else len(synonyms)]  # use top n or all synonyms
                synonym = random.choice(synonyms)
                # Every occurrence of the chosen word is replaced.
                new_words = [synonym if word == random_word else word for word in new_words]
                replaced += 1
            if replaced >= self.n:
                break

        return ' '.join(new_words)

    def random_deletion(self, sentence: str, p: float = 0.1):
        """Randomly delete words from the sentence with probability p

        :type sentence: str
        :param sentence: Sentence
        :type p: float
        :param p: Probability between 0 and 1

        :rtype: str
        :return: Augmented sentence
        """
        self.validate(sentence=sentence, p=p)
        self.p = p
        self.sentence = sentence
        words = sentence.split()
        # A single word is never deleted.
        if len(words) == 1:
            return words[0]
        # One random draw per word; a word is kept when its draw exceeds p.
        new_words = [word for word in words if random.uniform(0, 1) > self.p]
        # if all words are deleted, just return a random word
        if not new_words:
            return random.choice(words)

        return " ".join(new_words)

    def random_swap(self, sentence: str, n: int = 1):
        """Randomly swap two words in the sentence n times

        :type sentence: str
        :param sentence: Sentence
        :type n: int
        :param n: Number of repetitions to swap

        :rtype: str
        :return: Augmented sentence
        """
        self.validate(sentence=sentence, n=n)
        self.n = n
        self.sentence = sentence
        new_words = sentence.split()
        for _ in range(self.n):
            new_words = self.swap_word(new_words)
        return " ".join(new_words)

    def random_insertion(self, sentence: str, n: int = 1):
        """Randomly insert n words into the sentence

        :type sentence: str
        :param sentence: Sentence
        :type n: int
        :param n: Number of words to insert

        :rtype: str
        :return: Augmented sentence
        """
        self.validate(sentence=sentence, n=n)
        self.n = n
        self.sentence = sentence
        new_words = sentence.split()
        for _ in range(self.n):
            new_words = self.add_word(new_words)
        return " ".join(new_words)
278 |
--------------------------------------------------------------------------------
/textaugment/mixup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # TextAugment: mixup
3 | #
4 | # Copyright (C) 2018-2023
5 | # Authors: Joseph Sefara, Vukosi Marivate
6 | #
7 | # URL:
8 | # For license information, see LICENSE
9 | import numpy as np
10 | import random
11 |
12 |
class MIXUP:
    """
    This class implements the mixup algorithm [1] for natural language processing.

    [1] Zhang, Hongyi, Moustapha Cisse, Yann N. Dauphin, and David Lopez-Paz. "mixup: Beyond empirical risk
    minimization." in International Conference on Learning Representations (2018).
    https://openreview.net/forum?id=r1Ddp1-Rb
    """

    @staticmethod
    def validate(**kwargs):
        """Validate input data

        Only the types are checked; nothing is converted here, because
        rebinding entries of ``kwargs`` would not be visible to the caller.

        :param kwargs: Values to check; the keys ``data``, ``labels``,
            ``batch_size``, ``shuffle`` and ``runs`` are inspected when present.

        :raises TypeError: If a supplied value has an unsupported type.
        """
        if 'data' in kwargs:
            if not isinstance(kwargs['data'], (list, np.ndarray)):
                raise TypeError("data must be numpy array. Found " + str(type(kwargs['data'])))
        if 'labels' in kwargs:
            if not isinstance(kwargs['labels'], (list, type(None), np.ndarray)):
                raise TypeError("labels must be numpy array. Found " + str(type(kwargs['labels'])))
        if 'batch_size' in kwargs:
            if not isinstance(kwargs['batch_size'], int):
                raise TypeError("batch_size must be a valid integer. Found " + str(type(kwargs['batch_size'])))
        if 'shuffle' in kwargs:
            if not isinstance(kwargs['shuffle'], bool):
                raise TypeError("shuffle must be a boolean. Found " + str(type(kwargs['shuffle'])))
        if 'runs' in kwargs:
            if not isinstance(kwargs['runs'], int):
                raise TypeError("runs must be a valid integer. Found " + str(type(kwargs['runs'])))

    def __init__(self, random_state=1, runs=1):
        """A method to initialize parameters

        :type random_state: int
        :param random_state: (optional) Seed for both ``random`` and ``numpy.random``
        :type runs: int
        :param runs: (optional) Number of mixup passes per batch

        :raises TypeError: If ``random_state`` is not an int.
        """
        self.random_state = random_state
        self.runs = runs
        if isinstance(self.random_state, int):
            random.seed(self.random_state)
            np.random.seed(self.random_state)
        else:
            raise TypeError("random_state must have type int")

    def mixup_data(self, x, y=None, alpha=0.2):
        """This method performs mixup. If runs = 1 it just does 1 mixup with whole batch, any n of runs
        creates many mixup matches.

        :type x: Numpy array
        :param x: Data array with at least two dimensions; the first axis is the batch
        :type y: Numpy array
        :param y: (optional) labels, indexed along the first axis like ``x``
        :type alpha: float
        :param alpha: Both shape parameters of the Beta distribution the
            mixing coefficients are drawn from

        :rtype: array or tuple
        :return: The mixed inputs of all runs concatenated along axis 0 when
            ``y`` is None, otherwise a tuple (mixed inputs, mixed labels)
        """
        if self.runs is None:
            self.runs = 1
        x = np.asarray(x)
        if y is not None:
            y = np.asarray(y)
        output_x = []
        output_y = []
        batch_size = x.shape[0]
        for _ in range(self.runs):
            # One mixing coefficient per sample and a random partner for each sample.
            lam_vector = np.random.beta(alpha, alpha, batch_size)
            index = np.random.permutation(batch_size)
            # Transposing puts the batch axis last so lam_vector broadcasts per sample.
            mixed_x = (x.T * lam_vector).T + (x[index, :].T * (1.0 - lam_vector)).T
            output_x.append(mixed_x)
            if y is not None:
                mixed_y = (y.T * lam_vector).T + (y[index].T * (1.0 - lam_vector)).T
                output_y.append(mixed_y)
        # Return only after all runs so that runs > 1 is honoured without labels too.
        if y is None:
            return np.concatenate(output_x, axis=0)
        return np.concatenate(output_x, axis=0), np.concatenate(output_y, axis=0)

    def flow(self, data, labels=None, batch_size=32, shuffle=True, runs=1):
        """This function implements the batch iterator and specifically calls mixup

        :param data: Input data. Numpy ndarray or list of lists.
        :param labels: Labels. Numpy ndarray or list of lists.
        :param batch_size: Int (default: 32).
        :param shuffle: Boolean (default: True).
        :param runs: Int (default: 1). Number of augmentations

        :rtype: tuple
        :return: (generator, number of batches per epoch). The generator loops
            forever and yields an X_data array, or a tuple of arrays
            (X_data array, labels array) when labels are given.
        :raises TypeError: If an argument has an unsupported type."""

        self.validate(data=data, labels=labels, batch_size=batch_size, shuffle=shuffle, runs=runs)

        # Lists are accepted by validate(); convert them so the index-array
        # lookups below work. Arrays are passed through without copying.
        data = np.asarray(data)
        if labels is not None:
            labels = np.asarray(labels)

        self.runs = runs

        num_batches_per_epoch = int((len(data) - 1) / batch_size) + 1

        def data_generator():
            data_size = len(data)
            shuffled_labels = None
            while True:
                # Shuffle the data at each epoch
                if shuffle:
                    shuffle_indices = np.random.permutation(np.arange(data_size))
                    shuffled_data = data[shuffle_indices]
                    if labels is not None:
                        shuffled_labels = labels[shuffle_indices]
                else:
                    shuffled_data = data
                    if labels is not None:
                        shuffled_labels = labels
                for batch_num in range(num_batches_per_epoch):
                    start_index = batch_num * batch_size
                    end_index = min((batch_num + 1) * batch_size, data_size)
                    X = shuffled_data[start_index: end_index]
                    if labels is None:
                        yield self.mixup_data(X, y=None)
                    else:
                        y = shuffled_labels[start_index: end_index]
                        yield self.mixup_data(X, y)

        return data_generator(), num_batches_per_epoch
129 |
--------------------------------------------------------------------------------
/textaugment/translate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Translation-based (back-translation) data augmentation
3 | #
4 | # Copyright (C) 2020
5 | # Author: Joseph Sefara
6 | # URL:
7 | # For license information, see LICENSE
8 |
9 | from .constants import LANGUAGES
10 | from textblob import TextBlob
11 | from textblob.translate import NotTranslated
12 | from googletrans import Translator
13 |
14 |
class Translate:
    """
    A set of functions used to augment data.
    Supported languages are:
        Language Name           Code
        Afrikaans               af
        Albanian                sq
        Arabic                  ar
        Azerbaijani             az
        Basque                  eu
        Bengali                 bn
        Belarusian              be
        Bulgarian               bg
        Catalan                 ca
        Chinese Simplified      zh-CN
        Chinese Traditional     zh-TW
        Croatian                hr
        Czech                   cs
        Danish                  da
        Dutch                   nl
        English                 en
        Esperanto               eo
        Estonian                et
        Filipino                tl
        Finnish                 fi
        French                  fr
        Galician                gl
        Georgian                ka
        German                  de
        Greek                   el
        Gujarati                gu
        Haitian Creole          ht
        Hebrew                  iw
        Hindi                   hi
        Hungarian               hu
        Icelandic               is
        Indonesian              id
        Irish                   ga
        Italian                 it
        Japanese                ja
        Kannada                 kn
        Korean                  ko
        Latin                   la
        Latvian                 lv
        Lithuanian              lt
        Macedonian              mk
        Malay                   ms
        Maltese                 mt
        Norwegian               no
        Persian                 fa
        Polish                  pl
        Portuguese              pt
        Romanian                ro
        Russian                 ru
        Serbian                 sr
        Slovak                  sk
        Slovenian               sl
        Spanish                 es
        Swahili                 sw
        Swedish                 sv
        Tamil                   ta
        Telugu                  te
        Thai                    th
        Turkish                 tr
        Ukrainian               uk
        Urdu                    ur
        Vietnamese              vi
        Welsh                   cy
        Yiddish                 yi

    Example usage: ::
        >>> from textaugment import Translate
        >>> t = Translate(src="en",to="es")
        >>> t.augment('I love school')
        i adore school
    """

    def __init__(self, **kwargs):

        """
        A method to initialize parameters

        :type src: str
        :param src: Source language of the text
        :type to: str
        :param to: Destination language to translate to. The language should be a family of the source language for
                   better results. The text will then be translated back to the source language.
        :rtype: None
        :return: Constructor does not return.
        :raises ValueError: If ``to`` or ``src`` is missing.
        :raises KeyError: If ``to`` or ``src`` is not a key of ``LANGUAGES``.
        """
        hl = LANGUAGES

        try:
            if "to" not in kwargs:
                raise ValueError("'to' missing")
            elif "src" not in kwargs:
                raise ValueError("'src' missing")
            if kwargs['to'] not in hl:
                raise KeyError("Value of to is not surpported. See help(Translate)")
            if kwargs['src'] not in hl:
                raise KeyError("Value of src is not surpported. See help(Translate)")
        except (ValueError, KeyError):
            # Print a usage hint, then let the original exception propagate.
            print("The values of the keys 'to' and 'src' are required. E.g Translate(src='en', to='es')")
            raise
        else:
            self.to = kwargs['to']
            self.src = kwargs['src']

    def augment(self, data):
        """
        A method to paraphrase a sentence.

        The lower-cased text is translated from ``src`` to ``to`` and back
        with TextBlob. If TextBlob raises ``NotTranslated``, the whole round
        trip is repeated with googletrans, starting again from the original
        lower-cased text.

        :type data: str
        :param data: sentence used for data augmentation
        :rtype: str
        :return: The augmented data
        :raises TypeError: If ``data`` is not a string.
        """
        if not isinstance(data, str):
            raise TypeError("DataType must be a string")
        text = data.lower()
        blob = TextBlob(text)
        try:
            blob = blob.translate(from_lang=self.src, to=self.to)
            blob = blob.translate(from_lang=self.to, to=self.src)
            result = str(blob)
        except NotTranslated:
            try:  # Switch to googletrans to do translation.
                translator = Translator()
                # Restart from the plain source string: ``blob`` may already
                # hold text in the ``to`` language if only the second TextBlob
                # step failed, and it is a TextBlob rather than a str.
                forward = translator.translate(text, dest=self.to, src=self.src).text
                result = translator.translate(forward, dest=self.src, src=self.to).text
            except Exception:
                print("Error Not translated.\n")
                raise

        return str(result).lower()
148 |
--------------------------------------------------------------------------------
/textaugment/word2vec.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Word2vec-based data augmentation
3 | #
4 | # Copyright (C) 2023
5 | # Author: Joseph Sefara
6 | # URL:
7 | # For license information, see LICENSE
8 |
9 | import gensim
10 | import numpy as np
11 | import random
12 |
13 |
14 | class Word2vec:
15 | """
16 | A set of functions used to augment data.
17 |
18 | Typical usage: ::
19 | >>> from textaugment import Word2vec
20 | >>> t = Word2vec(model='path/to/gensim/model'or 'gensim model itself')
21 | >>> t.augment('I love school', top_n=10)
22 | i adore school
23 | """
24 |
25 | def __init__(self, **kwargs):
26 | """
27 | A method to initialize a model on a given path.
28 | :type random_state: int, float, str, bytes, bytearray
29 | :param random_state: seed
30 | :type model: str or gensim.models.word2vec.Word2Vec or gensim.models.fasttext.FastText
31 | :param model: The path to the model or the model itself.
32 | :type runs: int, optional
33 | :param runs: The number of times to augment a sentence. By default is 1.
34 | :type v: bool or optional
35 | :param v: Replace all the words if true. If false randomly replace words.
36 | Used in a Paper (https://www.cs.cmu.edu/~diyiy/docs/emnlp_wang_2015.pdf)
37 | :type p: float, optional
38 | :param p: The probability of success of an individual trial. (0.1
>> from textaugment import Fasttext
162 | >>> t = Fasttext('path/to/gensim/model'or 'gensim model itself')
163 | >>> t.augment('I love school', top_n=10)
164 | i adore school
165 | """
166 | pass
167 |
--------------------------------------------------------------------------------
/textaugment/wordnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # WordNet-based data augmentation
3 | #
4 | # Copyright (C) 2023
5 | # Author: Joseph Sefara
6 | # URL:
7 | # For license information, see LICENSE
8 |
9 | import numpy as np
10 | import nltk
11 | from itertools import chain
12 | from nltk.corpus import wordnet
13 |
14 |
15 | class Wordnet:
16 | """
17 | A set of functions used to augment data.
18 |
19 | Typical usage: ::
20 | >>> import nltk
21 | >>> nltk.download('punkt')
22 | >>> nltk.download('wordnet')
23 | >>> nltk.download('averaged_perceptron_tagger')
24 | >>> from textaugment import Wordnet
25 | >>> t = Wordnet(v=True,n=True,p=0.5)
26 | >>> t.augment('I love school')
27 | i adore school
28 | """
29 |
30 | def __init__(self, **kwargs):
31 | """
32 | A method to initialize parameters
33 |
34 | :type random_state: int
35 | :param random_state: seed
36 | :type v: bool
37 | :param v: Verb, default is True
38 | :type n: bool
39 | :param n: Noun
40 | :type runs: int
41 | :param runs: Number of repetition on single text
42 | :type p: float, optional
43 | :param p: The probability of success of an individual trial. (0.1
= 1: # There are synonyms
126 | for word in words:
127 | synonyms1 = wordnet.synsets(word[1], wordnet.VERB, lang=lang) # Return verbs only
128 | synonyms = list(set(chain.from_iterable([syn.lemma_names(lang=lang) for syn in synonyms1])))
129 | synonyms_ = [] # Synonyms with no underscores goes here
130 | for w in synonyms:
131 | if '_' not in w:
132 | synonyms_.append(w) # Remove words with underscores
133 | if len(synonyms_) >= 1:
134 | synonyms_ = synonyms_[:top_n if top_n else len(synonyms_)] # use top n or all synonyms
135 | synonym = self.geometric(data=synonyms_).tolist()
136 | if synonym: # There is a synonym
137 | data[int(word[0])] = synonym[0].lower() # Take the first success
138 |
139 | if self.n:
140 | for loop in range(self.runs):
141 | words = [[i, x] for i, x, y in data_tokens if y[0] == 'N']
142 | words = [i for i in self.geometric(data=words)] # List of selected words
143 | if len(words) >= 1: # There are synonyms
144 | for word in words:
145 | synonyms1 = wordnet.synsets(word[1], wordnet.NOUN, lang=lang) # Return nouns only
146 | synonyms = list(set(chain.from_iterable([syn.lemma_names(lang=lang) for syn in synonyms1])))
147 | synonyms_ = [] # Synonyms with no underscores goes here
148 | for w in synonyms:
149 | if '_' not in w:
150 | synonyms_.append(w) # Remove words with underscores
151 | if len(synonyms_) >= 1:
152 | synonyms_ = synonyms_[:top_n if top_n else len(synonyms_)] # use top n or all synonyms
153 | synonym = self.geometric(data=synonyms_).tolist()
154 | if synonym: # There is a synonym
155 | data[int(word[0])] = synonym[0].lower() # Take the first success
156 |
157 | return " ".join(data)
158 |
def augment(self, data, lang="eng", top_n=10):
    """
    Data augmentation for text. Generate new dataset based on verb/nouns synonyms.

    :type data: str
    :param data: sentence used for data augmentation
    :type lang: str
    :param lang: choose lang
    :type top_n: int
    :param top_n: top_n of synonyms to randomly choose from

    :rtype: str
    :return: The augmented data
    :raises TypeError: If ``data`` or ``lang`` is not a string, or ``top_n`` is not an integer.
    """
    # Error handling
    if not isinstance(data, str):
        raise TypeError("Only strings are supported")
    if not isinstance(lang, str):
        raise TypeError("Only strings are supported")
    # bool is a subclass of int; it stays rejected, as the previous exact-type check did.
    if isinstance(top_n, bool) or not isinstance(top_n, int):
        raise TypeError("Only integers are supported")

    # The actual synonym substitution is delegated to replace().
    return self.replace(data, lang, top_n)
185 |
--------------------------------------------------------------------------------