├── .gitignore ├── .travis.yml ├── LICENSE ├── MANIFEST ├── README.rst ├── docs ├── Makefile ├── conf.py ├── example.text_classification.rst ├── examples.rst ├── index.rst ├── make.bat ├── stanza.cluster.rst ├── stanza.corenlp.rst ├── stanza.ml.rst ├── stanza.monitoring.rst ├── stanza.research.rst ├── stanza.text.rst └── stanza.util.rst ├── examples ├── convert_to_rst.sh └── text_classification.ipynb ├── requirements.txt ├── setup.cfg ├── setup.py ├── stanza ├── __init__.py ├── cluster │ ├── __init__.py │ └── pick_gpu.py ├── ml │ ├── __init__.py │ ├── embeddings.py │ └── tensorflow_utils.py ├── monitoring │ ├── __init__.py │ ├── crc32c.py │ ├── experiment.py │ ├── progress.py │ ├── summary.py │ └── trigger.py ├── nlp │ ├── CoreNLP_pb2.py │ ├── __init__.py │ ├── corenlp.py │ ├── data.py │ └── protobuf_json.py ├── research │ ├── __init__.py │ ├── bleu.py │ ├── codalab.py │ ├── config.py │ ├── evaluate.py │ ├── instance.py │ ├── iterators.py │ ├── learner.py │ ├── logfile.py │ ├── metrics.py │ ├── mockfs.py │ ├── output.py │ ├── pick_gpu.py │ ├── progress.py │ ├── quickstart │ ├── rng.py │ ├── summary.py │ ├── summary_basic.py │ └── templates │ │ ├── README.rst │ │ ├── baseline.py │ │ ├── coveragerc │ │ ├── datasets.py │ │ ├── dependencies │ │ ├── error_analysis.py │ │ ├── fasttests │ │ ├── gitignore │ │ ├── learners.py │ │ ├── metrics.py │ │ ├── run_experiment.py │ │ ├── setup.cfg │ │ ├── tests │ │ ├── third-party │ │ └── tensorflow │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ └── core │ │ │ ├── __init__.py │ │ │ ├── framework │ │ │ ├── __init__.py │ │ │ ├── attr_value_pb2.py │ │ │ ├── function_pb2.py │ │ │ ├── graph_pb2.py │ │ │ ├── op_def_pb2.py │ │ │ ├── summary_pb2.py │ │ │ ├── tensor_pb2.py │ │ │ ├── tensor_shape_pb2.py │ │ │ └── types_pb2.py │ │ │ └── util │ │ │ ├── __init__.py │ │ │ └── event_pb2.py │ │ └── wrapper.py ├── text │ ├── __init__.py │ ├── dataset.py │ ├── utils.py │ └── vocab.py ├── unstable │ └── __init__.py └── util │ ├── __init__.py │ ├── postgres.py │ ├── resource.py │ └── unicode.py └── test ├── __init__.py ├── slow_tests ├── __init__.py └── text │ ├── __init__.py │ ├── test_glove.py │ └── test_senna.py └── unit_tests ├── README.md ├── __init__.py ├── ml ├── __init__.py ├── test_embeddings.py └── test_tensorflow_utils.py ├── monitoring ├── __init__.py ├── test_summary.py └── test_trigger.py ├── nlp ├── __init__.py ├── document.pb └── test_data.py └── text ├── __init__.py ├── test_dataset.py └── test_vocab.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | .idea/ 6 | .ropeproject/ 7 | data/ 8 | 9 | _*/ 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *,cover 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | 59 | # Sphinx documentation 60 | docs/_build/ 61 | 62 | # PyBuilder 63 | target/ 64 | 65 | .ipynb_checkpoints 66 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | dist: trusty 3 | language: python 4 | python: 5 | - 2.7 6 | notifications: 7 | email: false 8 | # Setup anaconda 9 | before_install: 10 | - sudo apt-get update 11 | # We do this conditionally because it saves us some downloading if the 12 | # version is the same. 13 | - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then 14 | wget https://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh; 15 | else 16 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 17 | fi 18 | - bash miniconda.sh -b -p ${HOME}/miniconda 19 | - export PATH="${HOME}/miniconda/bin:$PATH" 20 | - conda config --set always_yes yes --set changeps1 no 21 | - conda update -q conda 22 | # Useful for debugging any issues with conda 23 | - conda info -a 24 | # Install packages 25 | install: 26 | - which pip 27 | - pip --version 28 | - which python 29 | - python --version 30 | - conda install --yes python=$TRAVIS_PYTHON_VERSION numpy scipy matplotlib scikit-learn 31 | - pip install requests nose pypng pyhocon ConfigArgParse==0.10.0 mock pytest python-Levenshtein 32 | - python -c "import requests" 33 | - pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0rc0-cp27-none-linux_x86_64.whl 34 | - python -c "import tensorflow" 35 | - pip install -e . 36 | # env: 37 | # - THEANO_FLAGS='floatX=float32' 38 | # Run test 39 | script: python setup.py test 40 | # only integrate the master branch 41 | branches: 42 | only: 43 | - master 44 | - feature 45 | - develop 46 | # web hooks (eg. 
gitter) 47 | notifications: 48 | slack: 49 | rooms: 50 | - stanfordnlp:aAWRsBOxoolXY8vxXEJ8NkTV#stanza 51 | on_success: never 52 | on_failure: always 53 | 54 | # webhooks: 55 | # on_success: change # options: [always|never|change] default: always 56 | # on_failure: always # options: [always|never|change] default: always 57 | # on_start: false # default: false 58 | email: 59 | on_success: never # [always|never|change] # default: change 60 | on_failure: change #[always|never|change] # default: always 61 | -------------------------------------------------------------------------------- /MANIFEST: -------------------------------------------------------------------------------- 1 | # file GENERATED by distutils, do NOT edit 2 | setup.cfg 3 | setup.py 4 | stanza/__init__.py 5 | stanza/monitoring/__init__.py 6 | stanza/monitoring/crc32c.py 7 | stanza/monitoring/experiment.py 8 | stanza/monitoring/progress.py 9 | stanza/monitoring/summary.py 10 | stanza/monitoring/trigger.py 11 | stanza/text/__init__.py 12 | stanza/text/dataset.py 13 | stanza/text/utils.py 14 | stanza/text/vocab.py 15 | stanza/util/__init__.py 16 | stanza/util/postgres.py 17 | stanza/util/resource.py 18 | stanza/util/unicode.py 19 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Stanza 2 | ====== 3 | 4 | |Master Build Status| |Documentation Status| 5 | 6 | Stanza is the Stanford NLP group’s shared repository for Python 7 | infrastructure. The goal of Stanza is not to replace your modeling tools 8 | of choice, but to offer implementations for common patterns useful for 9 | machine learning experiments. 10 | 11 | Usage 12 | ----- 13 | 14 | You can install the package as follows: 15 | 16 | :: 17 | 18 | git clone git@github.com:stanfordnlp/stanza.git 19 | cd stanza 20 | pip install -e . 21 | 22 | To use the package, import it in your python code. An example would be: 23 | 24 | :: 25 | 26 | from stanza.text.vocab import Vocab 27 | v = Vocab('UNK') 28 | 29 | To use the Python client for the CoreNLP server, `first launch your CoreNLP Java server `__. Then, in your Python program: 30 | 31 | :: 32 | 33 | from stanza.nlp.corenlp import CoreNLPClient 34 | client = CoreNLPClient(server='http://localhost:9000', default_annotators=['ssplit', 'tokenize', 'lemma', 'pos', 'ner']) 35 | annotated = client.annotate('This is an example document. Here is a second sentence') 36 | for sentence in annotated.sentences: 37 | print('sentence', sentence) 38 | for token in sentence: 39 | print(token.word, token.lemma, token.pos, token.ner) 40 | 41 | Please see the documentation for more use cases. 42 | 43 | Documentation 44 | ------------- 45 | 46 | Documentation is hosted on Read the Docs at 47 | http://stanza.readthedocs.org/en/latest/. Stanza is still in early 48 | development. Interfaces and code organization will probably change 49 | substantially over the next few months. 50 | 51 | Development Guide 52 | ----------------- 53 | 54 | To request or discuss additional functionality, please open a GitHub 55 | issue. We greatly appreciate pull requests! 56 | 57 | Tests 58 | ~~~~~ 59 | 60 | Stanza has unit tests, doctests, and longer, integration tests. We ask that all 61 | contributors run the unit tests and doctests before submitting pull requests: 62 | 63 | .. 
code:: python 64 | 65 | python setup.py test 66 | 67 | Doctests are the easiest way to write a test for new functionality, and serve 68 | as helpful examples for how to use your code. See 69 | `progress.py `__ for a simple example of a easily 70 | testable module, or `summary.py `__ for a more 71 | involved setup involving a mocked filesystem. 72 | 73 | Adding a new module 74 | ~~~~~~~~~~~~~~~~~~~ 75 | 76 | If you are adding a new module, please remember to add it to 77 | ``setup.py`` as well as a corresponding ``.rst`` file in the ``docs`` 78 | directory. 79 | 80 | Documentation 81 | ~~~~~~~~~~~~~ 82 | 83 | Documentation is generated via 84 | `Sphinx `__ using inline comments. 85 | This means that the docstring in Python double both as interactive 86 | documentation and standalone documentation. This also means that you 87 | must format your docstring in RST. RST is very similar to Markdown. 88 | There are many tutorials on the exact syntax, essentially you only need 89 | to know the function parameter syntax which can be found 90 | `here `__. 91 | You can, of course, look at documentations for existing modules for 92 | guidance as well. A good place to start is the ``text.dataset`` package. 93 | 94 | To set up your environment such that you can generate docs locally: 95 | 96 | :: 97 | 98 | pip install sphinx sphinx-autobuild 99 | 100 | If you introduced a new module, please auto-generate the docs: 101 | 102 | :: 103 | 104 | sphinx-apidoc -F -o docs stanza 105 | cd docs && make 106 | open _build/html/index.html 107 | 108 | You most likely need to manually edit the `rst` file corresponding to your new module. 109 | 110 | Our docs are `hosted on Readthedocs `__. If you'd like admin access to the Readthedocs project, please contact Victor or Will. 111 | 112 | Road Map 113 | -------- 114 | 115 | - common objects used in NLP 116 | 117 | - [x] a Vocabulary object mapping from strings to integers/vectors 118 | 119 | - tools for running experiments on the NLP cluster 120 | 121 | - [ ] a function for querying GPU device stats (to aid in selecting 122 | a GPU on the cluster) 123 | - [ ] a tool for plotting training curves from multiple jobs 124 | - [ ] a tool for interacting with an already running job via edits 125 | to a text file 126 | 127 | - [x] an API for calling CoreNLP 128 | 129 | For Stanford NLP members 130 | ------------------------ 131 | 132 | Stanza is not meant to include every research project the group 133 | undertakes. If you have a standalone project that you would like to 134 | share with other people in the group, you can: 135 | 136 | - request your own private repo under the `stanfordnlp GitHub 137 | account `__. 138 | - share your code on `CodaLab `__. 139 | - For targeted questions, ask on `Stanford NLP 140 | Overflow `__ (use the ``stanza`` 141 | tag). 142 | 143 | Using `git subtree` 144 | ~~~~~~~~~~~~~~~~~~~ 145 | 146 | That said, it can be useful to add functionality to Stanza while you work in a 147 | separate repo on a project that depends on Stanza. Since Stanza is under active 148 | development, you will want to version-control the Stanza code that your code 149 | uses. Probably the most effective way of accomplishing this is by using 150 | ``git subtree``. 151 | 152 | ``git subtree`` includes the source tree of another repo (in 153 | this case, Stanza) as a directory within your repo (your cutting-edge 154 | research), and keeps track of some metadata that allows you to keep that 155 | directory in sync with the original Stanza code. 
The main advantage of ``git 156 | subtree`` is that you can modify the Stanza code locally, merge in updates, and 157 | push your changes back to the Stanza repo to share them with the group. (``git 158 | submodule`` doesn't allow this.) 159 | 160 | It has some downsides to be aware of: 161 | 162 | - You have a copy of all of Stanza as part of your repo. For small projects, 163 | this could increase your repo size dramatically. (Note: you can keep the 164 | history of your repo from growing at the same rate as Stanza's by using 165 | squashed commits; it's only the size of the source tree that unavoidably 166 | bloats your project.) 167 | - Your repo's history will contain a merge commit every time you update Stanza 168 | from upstream. This can look ugly, especially in graphical viewers. 169 | 170 | Still, ``subtree`` can be configured to be fairly easy to use, and the consensus 171 | seems to be that it is superior to ``submodule`` (``__). 172 | 173 | Here's one way to configure ``subtree`` so that you can include Stanza in 174 | your repo and contribute your changes back to the master repo: 175 | 176 | :: 177 | 178 | # Add Stanza as a remote repo 179 | git remote add stanza http://@github.com/stanfordnlp/stanza.git 180 | # Import the contents of the repo as a subtree 181 | git subtree add --prefix third-party/stanza stanza develop --squash 182 | # Put a symlink to the actual module somewhere where your code needs it 183 | ln -s third-party/stanza/stanza stanza 184 | # Add aliases for the two things you'll need to do with the subtree 185 | git config alias.stanza-update 'subtree pull --prefix third-party/stanza stanza develop --squash' 186 | git config alias.stanza-push 'subtree push --prefix third-party/stanza stanza develop' 187 | 188 | After this, you can use the aliases to push and pull Stanza like so: 189 | 190 | :: 191 | 192 | git stanza-update 193 | git stanza-push 194 | 195 | I [@futurulus] highly recommend a `topic branch/rebase workflow `__, 196 | which will keep your history fairly clean besides those pesky subtree merge 197 | commits: 198 | 199 | :: 200 | 201 | # Create a topic branch 202 | git checkout -b fix-stanza 203 | # 204 | 205 | git checkout master 206 | # Update Stanza on master, should go smoothly because master doesn't 207 | # have any of your changes yet 208 | git stanza-update 209 | 210 | # Go back and replay your fixes on top of master changes 211 | git checkout fix-stanza 212 | git rebase master 213 | # You might need to resolve merge conflicts here 214 | 215 | # Add your rebased changes to master and push 216 | git checkout master 217 | git merge --ff-only fix-stanza 218 | git stanza-push 219 | # Done! 220 | git branch -d fix-stanza 221 | 222 | .. |Master Build Status| image:: https://travis-ci.org/stanfordnlp/stanza.svg?branch=master 223 | :target: https://travis-ci.org/stanfordnlp/stanza 224 | .. |Documentation Status| image:: https://readthedocs.org/projects/stanza/badge/?version=latest 225 | :target: http://stanza.readthedocs.org/en/latest/?badge=latest 226 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 
5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext 23 | 24 | help: 25 | @echo "Please use \`make ' where is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " applehelp to make an Apple Help Book" 34 | @echo " devhelp to make HTML files and a Devhelp project" 35 | @echo " epub to make an epub" 36 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 37 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 38 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 39 | @echo " text to make text files" 40 | @echo " man to make manual pages" 41 | @echo " texinfo to make Texinfo files" 42 | @echo " info to make Texinfo files and run them through makeinfo" 43 | @echo " gettext to make PO message catalogs" 44 | @echo " changes to make an overview of all changed/added/deprecated items" 45 | @echo " xml to make Docutils-native XML files" 46 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 47 | @echo " linkcheck to check all external links for integrity" 48 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 49 | @echo " coverage to run coverage check of the documentation (if enabled)" 50 | 51 | clean: 52 | rm -rf $(BUILDDIR)/* 53 | 54 | html: 55 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 56 | @echo 57 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 58 | 59 | dirhtml: 60 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 61 | @echo 62 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 63 | 64 | singlehtml: 65 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 66 | @echo 67 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 68 | 69 | pickle: 70 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 71 | @echo 72 | @echo "Build finished; now you can process the pickle files." 73 | 74 | json: 75 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 76 | @echo 77 | @echo "Build finished; now you can process the JSON files." 
78 | 79 | htmlhelp: 80 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 81 | @echo 82 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 83 | ".hhp project file in $(BUILDDIR)/htmlhelp." 84 | 85 | qthelp: 86 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 87 | @echo 88 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 89 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 90 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/stanza.qhcp" 91 | @echo "To view the help file:" 92 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/stanza.qhc" 93 | 94 | applehelp: 95 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 96 | @echo 97 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 98 | @echo "N.B. You won't be able to view it unless you put it in" \ 99 | "~/Library/Documentation/Help or install it in your application" \ 100 | "bundle." 101 | 102 | devhelp: 103 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 104 | @echo 105 | @echo "Build finished." 106 | @echo "To view the help file:" 107 | @echo "# mkdir -p $$HOME/.local/share/devhelp/stanza" 108 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/stanza" 109 | @echo "# devhelp" 110 | 111 | epub: 112 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 113 | @echo 114 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 115 | 116 | latex: 117 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 118 | @echo 119 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 120 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 121 | "(use \`make latexpdf' here to do that automatically)." 122 | 123 | latexpdf: 124 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 125 | @echo "Running LaTeX files through pdflatex..." 126 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 127 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 128 | 129 | latexpdfja: 130 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 131 | @echo "Running LaTeX files through platex and dvipdfmx..." 132 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 133 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 134 | 135 | text: 136 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 137 | @echo 138 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 139 | 140 | man: 141 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 142 | @echo 143 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 144 | 145 | texinfo: 146 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 147 | @echo 148 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 149 | @echo "Run \`make' in that directory to run these through makeinfo" \ 150 | "(use \`make info' here to do that automatically)." 151 | 152 | info: 153 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 154 | @echo "Running Texinfo files through makeinfo..." 155 | make -C $(BUILDDIR)/texinfo info 156 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 157 | 158 | gettext: 159 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 160 | @echo 161 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 162 | 163 | changes: 164 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 165 | @echo 166 | @echo "The overview file is in $(BUILDDIR)/changes." 
167 | 168 | linkcheck: 169 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 170 | @echo 171 | @echo "Link check complete; look for any errors in the above output " \ 172 | "or in $(BUILDDIR)/linkcheck/output.txt." 173 | 174 | doctest: 175 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 176 | @echo "Testing of doctests in the sources finished, look at the " \ 177 | "results in $(BUILDDIR)/doctest/output.txt." 178 | 179 | coverage: 180 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 181 | @echo "Testing of coverage in the sources finished, look at the " \ 182 | "results in $(BUILDDIR)/coverage/python.txt." 183 | 184 | xml: 185 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 186 | @echo 187 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 188 | 189 | pseudoxml: 190 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 191 | @echo 192 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 193 | -------------------------------------------------------------------------------- /docs/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | Text Classification 8 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. stanza documentation master file, created by 2 | sphinx-quickstart on Fri Apr 1 22:03:24 2016. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | 7 | Documentation Quicknav 8 | ====================== 9 | 10 | .. toctree:: 11 | :maxdepth: 4 12 | 13 | Text 14 | CoreNLP Client 15 | Machine learning utilities 16 | Monitoring utilities 17 | Cluster management 18 | Misc 19 | Research (development code) 20 | 21 | 22 | Examples 23 | ======== 24 | 25 | .. toctree:: 26 | :maxdepth: 2 27 | 28 | Examples 29 | 30 | 31 | .. include:: ../README.rst 32 | 33 | 34 | Indices and tables 35 | ================== 36 | 37 | * :ref:`genindex` 38 | * :ref:`modindex` 39 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% . 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. 
man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. xml to make Docutils-native XML files 37 | echo. pseudoxml to make pseudoxml-XML files for display purposes 38 | echo. linkcheck to check all external links for integrity 39 | echo. doctest to run all doctests embedded in the documentation if enabled 40 | echo. coverage to run coverage check of the documentation if enabled 41 | goto end 42 | ) 43 | 44 | if "%1" == "clean" ( 45 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 46 | del /q /s %BUILDDIR%\* 47 | goto end 48 | ) 49 | 50 | 51 | REM Check if sphinx-build is available and fallback to Python version if any 52 | %SPHINXBUILD% 2> nul 53 | if errorlevel 9009 goto sphinx_python 54 | goto sphinx_ok 55 | 56 | :sphinx_python 57 | 58 | set SPHINXBUILD=python -m sphinx.__init__ 59 | %SPHINXBUILD% 2> nul 60 | if errorlevel 9009 ( 61 | echo. 62 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 63 | echo.installed, then set the SPHINXBUILD environment variable to point 64 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 65 | echo.may add the Sphinx directory to PATH. 66 | echo. 67 | echo.If you don't have Sphinx installed, grab it from 68 | echo.http://sphinx-doc.org/ 69 | exit /b 1 70 | ) 71 | 72 | :sphinx_ok 73 | 74 | 75 | if "%1" == "html" ( 76 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 77 | if errorlevel 1 exit /b 1 78 | echo. 79 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 80 | goto end 81 | ) 82 | 83 | if "%1" == "dirhtml" ( 84 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 85 | if errorlevel 1 exit /b 1 86 | echo. 87 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 88 | goto end 89 | ) 90 | 91 | if "%1" == "singlehtml" ( 92 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 93 | if errorlevel 1 exit /b 1 94 | echo. 95 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 96 | goto end 97 | ) 98 | 99 | if "%1" == "pickle" ( 100 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 101 | if errorlevel 1 exit /b 1 102 | echo. 103 | echo.Build finished; now you can process the pickle files. 104 | goto end 105 | ) 106 | 107 | if "%1" == "json" ( 108 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 109 | if errorlevel 1 exit /b 1 110 | echo. 111 | echo.Build finished; now you can process the JSON files. 112 | goto end 113 | ) 114 | 115 | if "%1" == "htmlhelp" ( 116 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 117 | if errorlevel 1 exit /b 1 118 | echo. 119 | echo.Build finished; now you can run HTML Help Workshop with the ^ 120 | .hhp project file in %BUILDDIR%/htmlhelp. 121 | goto end 122 | ) 123 | 124 | if "%1" == "qthelp" ( 125 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 129 | .qhcp project file in %BUILDDIR%/qthelp, like this: 130 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\stanza.qhcp 131 | echo.To view the help file: 132 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\stanza.ghc 133 | goto end 134 | ) 135 | 136 | if "%1" == "devhelp" ( 137 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 138 | if errorlevel 1 exit /b 1 139 | echo. 140 | echo.Build finished. 
141 | goto end 142 | ) 143 | 144 | if "%1" == "epub" ( 145 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 146 | if errorlevel 1 exit /b 1 147 | echo. 148 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 149 | goto end 150 | ) 151 | 152 | if "%1" == "latex" ( 153 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 154 | if errorlevel 1 exit /b 1 155 | echo. 156 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 157 | goto end 158 | ) 159 | 160 | if "%1" == "latexpdf" ( 161 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 162 | cd %BUILDDIR%/latex 163 | make all-pdf 164 | cd %~dp0 165 | echo. 166 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 167 | goto end 168 | ) 169 | 170 | if "%1" == "latexpdfja" ( 171 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 172 | cd %BUILDDIR%/latex 173 | make all-pdf-ja 174 | cd %~dp0 175 | echo. 176 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 177 | goto end 178 | ) 179 | 180 | if "%1" == "text" ( 181 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 182 | if errorlevel 1 exit /b 1 183 | echo. 184 | echo.Build finished. The text files are in %BUILDDIR%/text. 185 | goto end 186 | ) 187 | 188 | if "%1" == "man" ( 189 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 190 | if errorlevel 1 exit /b 1 191 | echo. 192 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 193 | goto end 194 | ) 195 | 196 | if "%1" == "texinfo" ( 197 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 198 | if errorlevel 1 exit /b 1 199 | echo. 200 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 201 | goto end 202 | ) 203 | 204 | if "%1" == "gettext" ( 205 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 206 | if errorlevel 1 exit /b 1 207 | echo. 208 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 209 | goto end 210 | ) 211 | 212 | if "%1" == "changes" ( 213 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 214 | if errorlevel 1 exit /b 1 215 | echo. 216 | echo.The overview file is in %BUILDDIR%/changes. 217 | goto end 218 | ) 219 | 220 | if "%1" == "linkcheck" ( 221 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 222 | if errorlevel 1 exit /b 1 223 | echo. 224 | echo.Link check complete; look for any errors in the above output ^ 225 | or in %BUILDDIR%/linkcheck/output.txt. 226 | goto end 227 | ) 228 | 229 | if "%1" == "doctest" ( 230 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 231 | if errorlevel 1 exit /b 1 232 | echo. 233 | echo.Testing of doctests in the sources finished, look at the ^ 234 | results in %BUILDDIR%/doctest/output.txt. 235 | goto end 236 | ) 237 | 238 | if "%1" == "coverage" ( 239 | %SPHINXBUILD% -b coverage %ALLSPHINXOPTS% %BUILDDIR%/coverage 240 | if errorlevel 1 exit /b 1 241 | echo. 242 | echo.Testing of coverage in the sources finished, look at the ^ 243 | results in %BUILDDIR%/coverage/python.txt. 244 | goto end 245 | ) 246 | 247 | if "%1" == "xml" ( 248 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 249 | if errorlevel 1 exit /b 1 250 | echo. 251 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 252 | goto end 253 | ) 254 | 255 | if "%1" == "pseudoxml" ( 256 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 257 | if errorlevel 1 exit /b 1 258 | echo. 259 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 
260 | goto end 261 | ) 262 | 263 | :end 264 | -------------------------------------------------------------------------------- /docs/stanza.cluster.rst: -------------------------------------------------------------------------------- 1 | stanza.cluster package 2 | ====================== 3 | 4 | stanza.cluster.pick_gpu module 5 | ------------------------------ 6 | 7 | .. automodule:: stanza.cluster.pick_gpu 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/stanza.corenlp.rst: -------------------------------------------------------------------------------- 1 | stanza.corenlp package 2 | ====================== 3 | 4 | stanza.corenlp.client module 5 | ---------------------------- 6 | 7 | .. automodule:: stanza.corenlp.client 8 | :members: 9 | :special-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/stanza.ml.rst: -------------------------------------------------------------------------------- 1 | stanza.ml package 2 | ================= 3 | 4 | stanza.ml.tensorflow_utils module 5 | --------------------------------- 6 | 7 | .. automodule:: stanza.ml.tensorflow_utils 8 | :members: 9 | :special-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/stanza.monitoring.rst: -------------------------------------------------------------------------------- 1 | stanza.monitoring package 2 | ========================= 3 | 4 | stanza.monitoring.trigger module 5 | -------------------------------- 6 | 7 | .. automodule:: stanza.monitoring.trigger 8 | :members: 9 | :special-members: 10 | :show-inheritance: 11 | 12 | stanza.monitoring.experiment module 13 | ----------------------------------- 14 | 15 | .. automodule:: stanza.monitoring.experiment 16 | :members: 17 | :special-members: 18 | :show-inheritance: 19 | 20 | stanza.monitoring.progress module 21 | --------------------------------- 22 | 23 | .. automodule:: stanza.monitoring.progress 24 | :members: 25 | :special-members: 26 | :show-inheritance: 27 | 28 | stanza.monitoring.summary module 29 | -------------------------------- 30 | 31 | .. automodule:: stanza.monitoring.summary 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | -------------------------------------------------------------------------------- /docs/stanza.research.rst: -------------------------------------------------------------------------------- 1 | stanza.research package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | stanza.research.codalab module 8 | ------------------------------ 9 | 10 | .. automodule:: stanza.research.codalab 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | stanza.research.config module 16 | ----------------------------- 17 | 18 | .. automodule:: stanza.research.config 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | stanza.research.evaluate module 24 | ------------------------------- 25 | 26 | .. automodule:: stanza.research.evaluate 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | stanza.research.instance module 32 | ------------------------------- 33 | 34 | .. automodule:: stanza.research.instance 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | stanza.research.iterators module 40 | -------------------------------- 41 | 42 | .. 
automodule:: stanza.research.iterators 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | stanza.research.learner module 48 | ------------------------------ 49 | 50 | .. automodule:: stanza.research.learner 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | stanza.research.logfile module 56 | ------------------------------ 57 | 58 | .. automodule:: stanza.research.logfile 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | stanza.research.metrics module 64 | ------------------------------ 65 | 66 | .. automodule:: stanza.research.metrics 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | stanza.research.mockfs module 72 | ----------------------------- 73 | 74 | .. automodule:: stanza.research.mockfs 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | stanza.research.output module 80 | ----------------------------- 81 | 82 | .. automodule:: stanza.research.output 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | stanza.research.rng module 88 | -------------------------- 89 | 90 | .. automodule:: stanza.research.rng 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | stanza.research.summary_basic module 96 | ------------------------------------ 97 | 98 | .. automodule:: stanza.research.summary_basic 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | 104 | Module contents 105 | --------------- 106 | 107 | .. automodule:: stanza.research 108 | :members: 109 | :undoc-members: 110 | :show-inheritance: 111 | -------------------------------------------------------------------------------- /docs/stanza.text.rst: -------------------------------------------------------------------------------- 1 | stanza.text package 2 | =================== 3 | 4 | 5 | stanza.text.dataset module 6 | -------------------------- 7 | 8 | .. automodule:: stanza.text.dataset 9 | :members: 10 | :special-members: 11 | :show-inheritance: 12 | 13 | stanza.text.vocab module 14 | ------------------------ 15 | 16 | .. automodule:: stanza.text.vocab 17 | :members: 18 | :special-members: 19 | :show-inheritance: 20 | -------------------------------------------------------------------------------- /docs/stanza.util.rst: -------------------------------------------------------------------------------- 1 | stanza.util package 2 | =================== 3 | 4 | stanza.util.resource module 5 | --------------------------- 6 | 7 | .. 
automodule:: stanza.util.resource 8 | :members: 9 | :special-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /examples/convert_to_rst.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | set -e 3 | 4 | jupyter nbconvert *.ipynb --to rst 5 | for f in *.rst; do 6 | echo 'moving ${f}' 7 | mv $f ../docs/example.${f} 8 | done 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | requests 2 | nose 3 | pypng 4 | pyhocon 5 | ConfigArgParse==0.10.0 6 | mock 7 | pytest 8 | python-Levenshtein 9 | google 10 | protobuf 11 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor, wmonroe4, kelvinguu' 2 | 3 | from setuptools import setup, Command, find_packages 4 | 5 | 6 | class UnitTest2(Command): 7 | user_options = [] 8 | 9 | def initialize_options(self): 10 | pass 11 | 12 | def finalize_options(self): 13 | pass 14 | 15 | def run(self): 16 | import subprocess 17 | errno = subprocess.call(['python2', '-m', 'pytest', '--doctest-modules', 18 | '--ignore=stanza/research/pick_gpu.py', 19 | '--ignore=stanza/research/progress.py', 20 | '--ignore=stanza/research/summary.py', 21 | '--ignore=stanza/research/templates/third-party', 22 | 'stanza', 'test/unit_tests']) 23 | raise SystemExit(errno) 24 | 25 | class UnitTest3(Command): 26 | user_options = [] 27 | 28 | def initialize_options(self): 29 | pass 30 | 31 | def finalize_options(self): 32 | pass 33 | 34 | def run(self): 35 | import subprocess 36 | errno = subprocess.call(['python3', '-m', 'pytest', '--doctest-modules', 37 | '--ignore=stanza/research/pick_gpu.py', 38 | '--ignore=stanza/research/progress.py', 39 | '--ignore=stanza/research/summary.py', 40 | '--ignore=stanza/research/templates/third-party', 41 | 'stanza', 'test/unit_tests']) 42 | raise SystemExit(errno) 43 | 44 | class SlowTest(Command): 45 | user_options = [] 46 | 47 | def initialize_options(self): 48 | pass 49 | 50 | def finalize_options(self): 51 | pass 52 | 53 | def run(self): 54 | import subprocess 55 | errno = subprocess.call(['py.test', '--doctest-modules', 'test/slow_tests']) 56 | raise SystemExit(errno) 57 | 58 | 59 | class AllTest(Command): 60 | user_options = [] 61 | 62 | def initialize_options(self): 63 | pass 64 | 65 | def finalize_options(self): 66 | pass 67 | 68 | def run(self): 69 | import subprocess 70 | errno = subprocess.call(['py.test', '--doctest-modules']) 71 | raise SystemExit(errno) 72 | 73 | 74 | setup( 75 | name='stanza', 76 | version='0.3', 77 | packages=find_packages(exclude=['docs', 'test']), 78 | url='https://github.com/stanfordnlp/stanza', 79 | license='MIT', 80 | author='Stanford NLP', 81 | author_email='victor@victorzhong.com', 82 | description='NLP library for Python', 83 | cmdclass={'test': UnitTest2, 'test3' : UnitTest3, 'slow_test': SlowTest, 'all_test': AllTest}, 84 | download_url='https://github.com/stanfordnlp/stanza/tarball/0.1', 85 | keywords=['nlp', 'neural networks', 'machine learning'], 86 | ) 87 | 
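The ``cmdclass`` mapping above exposes four test commands through setuptools. Assuming the package has been installed in development mode (``pip install -e .``, as described in the README), they are invoked as:

    python setup.py test       # unit tests and doctests under Python 2
    python setup.py test3      # the same suite under Python 3
    python setup.py slow_test  # longer integration tests in test/slow_tests
    python setup.py all_test   # every doctest and test module

Each command shells out to pytest via ``subprocess`` and exits with its return code.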
-------------------------------------------------------------------------------- /stanza/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | SOURCE_DIR = os.path.dirname(os.path.abspath(__file__)) 4 | DATA_DIR = os.path.join(SOURCE_DIR, 'data') 5 | if not os.path.isdir(DATA_DIR): 6 | try: 7 | os.makedirs(DATA_DIR) 8 | except Exception as e: 9 | sys.stderr.write("Could not create data directory at {}.\n{}".format(DATA_DIR, e)) 10 | -------------------------------------------------------------------------------- /stanza/cluster/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'wmonroe4' 2 | -------------------------------------------------------------------------------- /stanza/cluster/pick_gpu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | ''' 3 | Print the name of a device to use, either 'cpu' or 'gpu0', 'gpu1',... 4 | The least-used GPU with usage under the constant threshold will be chosen; 5 | ties are broken randomly. 6 | 7 | Can be called from the shell, with no arguments: 8 | 9 | $ python pick_gpu.py 10 | gpu0 11 | 12 | Warning: This is hacky and brittle, and can break if nvidia-smi changes 13 | in the way it formats its output. 14 | ''' 15 | __author__ = 'sbowman@stanford.edu, wmonroe4@stanford.edu' 16 | 17 | import subprocess 18 | import sys 19 | import random 20 | from collections import namedtuple 21 | 22 | 23 | USAGE_THRESHOLD = 0.8 24 | 25 | Usage = namedtuple('Usage', 'fan,mem,cpu') 26 | 27 | 28 | def best_gpu(max_usage=USAGE_THRESHOLD, verbose=False): 29 | ''' 30 | Return the name of a device to use, either 'cpu' or 'gpu0', 'gpu1',... 31 | The least-used GPU with usage under the constant threshold will be chosen; 32 | ties are broken randomly. 
33 | ''' 34 | try: 35 | proc = subprocess.Popen("nvidia-smi", stdout=subprocess.PIPE, 36 | stderr=subprocess.PIPE) 37 | output, error = proc.communicate() 38 | if error: 39 | raise Exception(error) 40 | except Exception, e: 41 | sys.stderr.write("Couldn't run nvidia-smi to find best GPU, using CPU: %s\n" % str(e)) 42 | sys.stderr.write("(This is normal if you have no GPU or haven't configured CUDA.)\n") 43 | return "cpu" 44 | 45 | usages = parse_output(output) 46 | 47 | pct_usage = [max(u.mem, cpu_backoff(u)) for u in usages] 48 | max_usage = min(max_usage, min(pct_usage)) 49 | 50 | open_gpus = [index for index, usage in enumerate(usages) 51 | if max(usage.mem, cpu_backoff(usage)) <= max_usage] 52 | if verbose: 53 | print('Best GPUs:') 54 | for index in open_gpus: 55 | print('%d: %s fan, %s mem, %s cpu' % 56 | (index, format_percent(usages[index].fan), 57 | format_percent(usages[index].mem), 58 | format_percent(usages[index].cpu))) 59 | 60 | if open_gpus: 61 | result = "gpu" + str(random.choice(open_gpus)) 62 | else: 63 | result = "cpu" 64 | 65 | if verbose: 66 | print('Chosen: ' + result) 67 | return result 68 | 69 | 70 | def parse_output(output): 71 | start = output.index('|===') 72 | end = output.index('\n ') 73 | lines = output[start:end].split('\n')[2::3] 74 | fields = [line.split() for line in lines] 75 | 76 | fan_fields = [line[1] for line in fields] 77 | mem_used_fields = [line[8] for line in fields] 78 | total_mem_fields = [line[10] for line in fields] 79 | cpu_fields = [line[12] for line in fields] 80 | 81 | fan_amts = [parse_percent(f) for f in fan_fields] 82 | mem_used_amts = [parse_bytes(f) for f in mem_used_fields] 83 | total_mem_amts = [parse_bytes(f) for f in total_mem_fields] 84 | cpu_amts = [parse_percent(f) for f in cpu_fields] 85 | 86 | pct_mem_used = [(float(usage_amt) / float(total) 87 | if None not in (usage_amt, total) 88 | else None) 89 | for (usage_amt, total) in zip(mem_used_amts, total_mem_amts)] 90 | 91 | return [Usage(fan, mem, cpu) 92 | for (fan, mem, cpu) in zip(fan_amts, pct_mem_used, cpu_amts)] 93 | 94 | 95 | def parse_percent(field): 96 | try: 97 | if field.endswith('%'): 98 | return float(field[:-1]) 99 | else: 100 | return float(field) 101 | except ValueError: 102 | return None 103 | 104 | 105 | def parse_bytes(field): 106 | ''' 107 | >>> parse_bytes('24B') 108 | 24.0 109 | >>> parse_bytes('4MiB') 110 | 4194304.0 111 | ''' 112 | if field[-1] in 'bB': 113 | field = field[:-1] 114 | 115 | try: 116 | for i, prefix in enumerate('KMGTPEZ'): 117 | if field.endswith(prefix + 'i'): 118 | factor = 2 ** (10 * (i + 1)) 119 | return float(field[:-2]) * factor 120 | 121 | return float(field) 122 | except ValueError: 123 | return None 124 | 125 | 126 | def cpu_backoff(u): 127 | if u.cpu is not None: 128 | return u.cpu 129 | elif u.fan is not None: 130 | return u.fan 131 | else: 132 | return 0.0 133 | 134 | 135 | def format_percent(p): 136 | if p is None: 137 | return 'N/A' 138 | else: 139 | return '%f%%' % p 140 | 141 | 142 | def bind_theano(device=None, max_usage=USAGE_THRESHOLD, verbose=True): 143 | ''' 144 | Initialize Theano to use a certain device. If `device` is None (the 145 | default), use the device returned by calling `best_gpu` 146 | with the same parameters. 147 | 148 | This needs to be called *before* importing Theano. Currently (Dec 2015) 149 | Theano has no way of switching devices after it is bound (which happens 150 | on import). 
151 | ''' 152 | if device is None: 153 | device = best_gpu(max_usage, verbose=verbose) 154 | if device and device != 'cpu': 155 | import unittest 156 | try: 157 | import theano.sandbox.cuda 158 | theano.sandbox.cuda.use(device) 159 | except (ImportError, unittest.case.SkipTest): 160 | import theano.gpuarray 161 | theano.gpuarray.use(device.replace('gpu', 'cuda')) 162 | 163 | 164 | 165 | __all__ = [ 166 | 'best_gpu', 167 | 'bind_theano', 168 | ] 169 | 170 | 171 | if __name__ == '__main__': 172 | print(best_gpu()) 173 | -------------------------------------------------------------------------------- /stanza/ml/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/stanza-old/920c55d8eaa1e7105971059c66eb448a74c100d6/stanza/ml/__init__.py -------------------------------------------------------------------------------- /stanza/ml/embeddings.py: -------------------------------------------------------------------------------- 1 | from collections import Mapping 2 | from contextlib import contextmanager 3 | import logging 4 | import numpy as np 5 | from stanza.text import Vocab 6 | 7 | 8 | __author__ = 'kelvinguu' 9 | 10 | 11 | class Embeddings(Mapping): 12 | """A map from strings to vectors. 13 | 14 | Vectors are stored as a numpy array. 15 | Vectors are saved/loaded from disk using numpy.load, which is roughly 3-4 times faster 16 | than reading a text file. 17 | """ 18 | 19 | def __init__(self, array, vocab): 20 | """Create embeddings object. 21 | 22 | :param (np.array) array: has shape (vocab_size, embed_dim) 23 | :param (Vocab) vocab: a Vocab object 24 | """ 25 | assert len(array.shape) == 2 26 | assert array.shape[0] == len(vocab) # entries line up 27 | 28 | self.array = array 29 | self.vocab = vocab 30 | 31 | def __getitem__(self, w): 32 | idx = self.vocab.word2index(w) 33 | return self.array[idx] 34 | 35 | def __iter__(self): 36 | return iter(self.vocab) 37 | 38 | def __len__(self): 39 | return len(self.vocab) 40 | 41 | def __contains__(self, item): 42 | return item in self.vocab 43 | 44 | def subset(self, words): 45 | sub_vocab = self.vocab.subset(words) 46 | idxs = [self.vocab[w] for w in sub_vocab] 47 | sub_array = self.array[idxs] 48 | return self.__class__(sub_array, sub_vocab) 49 | 50 | def inner_products(self, vec): 51 | """Get the inner product of a vector with every embedding. 52 | 53 | :param (np.array) vector: the query vector 54 | 55 | :return (list[tuple[str, float]]): a map of embeddings to inner products 56 | """ 57 | products = self.array.dot(vec) 58 | return self._word_to_score(np.arange(len(products)), products) 59 | 60 | def _word_to_score(self, ids, scores): 61 | """Return a map from each word to its score. 62 | 63 | :param (np.array) ids: a vector of word ids 64 | :param (np.array) scores: a vector of scores 65 | 66 | :return (dict[unicode, float]): a map from each word (unicode) to its score (float) 67 | """ 68 | # should be 1-D vectors 69 | assert len(ids.shape) == 1 70 | assert ids.shape == scores.shape 71 | 72 | w2s = {} 73 | for i in range(len(ids)): 74 | w2s[self.vocab.index2word(ids[i])] = scores[i] 75 | return w2s 76 | 77 | def k_nearest(self, vec, k): 78 | """Get the k nearest neighbors of a vector (in terms of highest inner products). 
79 | 80 | :param (np.array) vec: query vector 81 | :param (int) k: number of top neighbors to return 82 | 83 | :return (list[tuple[str, float]]): a list of (word, score) pairs, in descending order 84 | """ 85 | nbr_score_pairs = self.inner_products(vec) 86 | return sorted(nbr_score_pairs.items(), key=lambda x: x[1], reverse=True)[:k] 87 | 88 | def _init_lsh_forest(self): 89 | """Construct an LSH forest for nearest neighbor search.""" 90 | import sklearn.neighbors 91 | lshf = sklearn.neighbors.LSHForest() 92 | lshf.fit(self.array) 93 | return lshf 94 | 95 | def k_nearest_approx(self, vec, k): 96 | """Get the k nearest neighbors of a vector (in terms of cosine similarity). 97 | 98 | :param (np.array) vec: query vector 99 | :param (int) k: number of top neighbors to return 100 | 101 | :return (list[tuple[str, float]]): a list of (word, cosine similarity) pairs, in descending order 102 | """ 103 | if not hasattr(self, 'lshf'): 104 | self.lshf = self._init_lsh_forest() 105 | 106 | # TODO(kelvin): make this inner product score, to be consistent with k_nearest 107 | distances, neighbors = self.lshf.kneighbors([vec], n_neighbors=k, return_distance=True) 108 | scores = np.subtract(1, distances) 109 | nbr_score_pairs = self._word_to_score(np.squeeze(neighbors), np.squeeze(scores)) 110 | 111 | return sorted(nbr_score_pairs.items(), key=lambda x: x[1], reverse=True) 112 | 113 | def to_dict(self): 114 | """Convert to dictionary. 115 | 116 | :return (dict): A dict mapping from strings to vectors. 117 | """ 118 | d = {} 119 | for word, idx in self.vocab.iteritems(): 120 | d[word] = self.array[idx].tolist() 121 | return d 122 | 123 | @classmethod 124 | def from_dict(cls, d, unk): 125 | assert unk in d 126 | vocab = Vocab(unk) 127 | vocab.update(d) 128 | vecs = [] 129 | for i in range(len(vocab)): 130 | word = vocab.index2word(i) 131 | vec = d[word] 132 | vecs.append(vec) 133 | array = np.array(vecs) 134 | return cls(array, vocab) 135 | 136 | def to_files(self, array_file, vocab_file): 137 | """Write the embedding matrix and the vocab to files. 138 | 139 | :param (file) array_file: file to write array to 140 | :param (file) vocab_file: file to write vocab to 141 | """ 142 | logging.info('Writing array...') 143 | np.save(array_file, self.array) 144 | logging.info('Writing vocab...') 145 | self.vocab.to_file(vocab_file) 146 | 147 | @classmethod 148 | def from_files(cls, array_file, vocab_file): 149 | """Load the embedding matrix and the vocab from files. 150 | 151 | :param (file) array_file: file to read array from 152 | :param (file) vocab_file: file to read vocab from 153 | 154 | :return (Embeddings): an Embeddings object 155 | """ 156 | logging.info('Loading array...') 157 | array = np.load(array_file) 158 | logging.info('Loading vocab...') 159 | vocab = Vocab.from_file(vocab_file) 160 | return cls(array, vocab) 161 | 162 | @staticmethod 163 | @contextmanager 164 | def _path_prefix_to_files(path_prefix, mode): 165 | array_path = path_prefix + '.npy' 166 | vocab_path = path_prefix + '.vocab' 167 | with open(array_path, mode) as array_file, open(vocab_path, mode) as vocab_file: 168 | yield array_file, vocab_file 169 | 170 | def to_file_path(self, path_prefix): 171 | """Write the embedding matrix and the vocab to .npy and .vocab. 
172 | 173 | :param (str) path_prefix: path prefix of the saved files 174 | """ 175 | with self._path_prefix_to_files(path_prefix, 'w') as (array_file, vocab_file): 176 | self.to_files(array_file, vocab_file) 177 | 178 | @classmethod 179 | def from_file_path(cls, path_prefix): 180 | """Load the embedding matrix and the vocab from .npy and .vocab. 181 | 182 | :param (str) path_prefix: path prefix of the saved files 183 | """ 184 | with cls._path_prefix_to_files(path_prefix, 'r') as (array_file, vocab_file): 185 | return cls.from_files(array_file, vocab_file) 186 | -------------------------------------------------------------------------------- /stanza/ml/tensorflow_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | __author__ = 'kelvinguu' 4 | 5 | 6 | def labels_to_onehots(labels, num_classes): 7 | """Convert a vector of integer class labels to a matrix of one-hot target vectors. 8 | 9 | :param labels: a vector of integer labels, 0 to num_classes. Has shape (batch_size,). 10 | :param num_classes: the total number of classes 11 | :return: has shape (batch_size, num_classes) 12 | """ 13 | batch_size = labels.get_shape().as_list()[0] 14 | 15 | with tf.name_scope("one_hot"): 16 | labels = tf.expand_dims(labels, 1) 17 | indices = tf.expand_dims(tf.range(0, batch_size, 1), 1) 18 | sparse_ptrs = tf.concat(1, [indices, labels], name="ptrs") 19 | onehots = tf.sparse_to_dense(sparse_ptrs, [batch_size, num_classes], 20 | 1.0, 0.0) 21 | return onehots -------------------------------------------------------------------------------- /stanza/monitoring/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor' 2 | -------------------------------------------------------------------------------- /stanza/monitoring/crc32c.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2007 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | 19 | 20 | 21 | """Implementation of CRC-32C checksumming. 22 | 23 | See http://en.wikipedia.org/wiki/Cyclic_redundancy_check for details on CRC-32C. 24 | 25 | This code is a manual python translation of c code generated by 26 | pycrc 0.7.1 (http://www.tty1.net/pycrc/). 
Command line used: 27 | './pycrc.py --model=crc-32c --generate c --algorithm=table-driven' 28 | """ 29 | 30 | 31 | 32 | 33 | import array 34 | 35 | CRC_TABLE = ( 36 | 0x00000000L, 0xf26b8303L, 0xe13b70f7L, 0x1350f3f4L, 37 | 0xc79a971fL, 0x35f1141cL, 0x26a1e7e8L, 0xd4ca64ebL, 38 | 0x8ad958cfL, 0x78b2dbccL, 0x6be22838L, 0x9989ab3bL, 39 | 0x4d43cfd0L, 0xbf284cd3L, 0xac78bf27L, 0x5e133c24L, 40 | 0x105ec76fL, 0xe235446cL, 0xf165b798L, 0x030e349bL, 41 | 0xd7c45070L, 0x25afd373L, 0x36ff2087L, 0xc494a384L, 42 | 0x9a879fa0L, 0x68ec1ca3L, 0x7bbcef57L, 0x89d76c54L, 43 | 0x5d1d08bfL, 0xaf768bbcL, 0xbc267848L, 0x4e4dfb4bL, 44 | 0x20bd8edeL, 0xd2d60dddL, 0xc186fe29L, 0x33ed7d2aL, 45 | 0xe72719c1L, 0x154c9ac2L, 0x061c6936L, 0xf477ea35L, 46 | 0xaa64d611L, 0x580f5512L, 0x4b5fa6e6L, 0xb93425e5L, 47 | 0x6dfe410eL, 0x9f95c20dL, 0x8cc531f9L, 0x7eaeb2faL, 48 | 0x30e349b1L, 0xc288cab2L, 0xd1d83946L, 0x23b3ba45L, 49 | 0xf779deaeL, 0x05125dadL, 0x1642ae59L, 0xe4292d5aL, 50 | 0xba3a117eL, 0x4851927dL, 0x5b016189L, 0xa96ae28aL, 51 | 0x7da08661L, 0x8fcb0562L, 0x9c9bf696L, 0x6ef07595L, 52 | 0x417b1dbcL, 0xb3109ebfL, 0xa0406d4bL, 0x522bee48L, 53 | 0x86e18aa3L, 0x748a09a0L, 0x67dafa54L, 0x95b17957L, 54 | 0xcba24573L, 0x39c9c670L, 0x2a993584L, 0xd8f2b687L, 55 | 0x0c38d26cL, 0xfe53516fL, 0xed03a29bL, 0x1f682198L, 56 | 0x5125dad3L, 0xa34e59d0L, 0xb01eaa24L, 0x42752927L, 57 | 0x96bf4dccL, 0x64d4cecfL, 0x77843d3bL, 0x85efbe38L, 58 | 0xdbfc821cL, 0x2997011fL, 0x3ac7f2ebL, 0xc8ac71e8L, 59 | 0x1c661503L, 0xee0d9600L, 0xfd5d65f4L, 0x0f36e6f7L, 60 | 0x61c69362L, 0x93ad1061L, 0x80fde395L, 0x72966096L, 61 | 0xa65c047dL, 0x5437877eL, 0x4767748aL, 0xb50cf789L, 62 | 0xeb1fcbadL, 0x197448aeL, 0x0a24bb5aL, 0xf84f3859L, 63 | 0x2c855cb2L, 0xdeeedfb1L, 0xcdbe2c45L, 0x3fd5af46L, 64 | 0x7198540dL, 0x83f3d70eL, 0x90a324faL, 0x62c8a7f9L, 65 | 0xb602c312L, 0x44694011L, 0x5739b3e5L, 0xa55230e6L, 66 | 0xfb410cc2L, 0x092a8fc1L, 0x1a7a7c35L, 0xe811ff36L, 67 | 0x3cdb9bddL, 0xceb018deL, 0xdde0eb2aL, 0x2f8b6829L, 68 | 0x82f63b78L, 0x709db87bL, 0x63cd4b8fL, 0x91a6c88cL, 69 | 0x456cac67L, 0xb7072f64L, 0xa457dc90L, 0x563c5f93L, 70 | 0x082f63b7L, 0xfa44e0b4L, 0xe9141340L, 0x1b7f9043L, 71 | 0xcfb5f4a8L, 0x3dde77abL, 0x2e8e845fL, 0xdce5075cL, 72 | 0x92a8fc17L, 0x60c37f14L, 0x73938ce0L, 0x81f80fe3L, 73 | 0x55326b08L, 0xa759e80bL, 0xb4091bffL, 0x466298fcL, 74 | 0x1871a4d8L, 0xea1a27dbL, 0xf94ad42fL, 0x0b21572cL, 75 | 0xdfeb33c7L, 0x2d80b0c4L, 0x3ed04330L, 0xccbbc033L, 76 | 0xa24bb5a6L, 0x502036a5L, 0x4370c551L, 0xb11b4652L, 77 | 0x65d122b9L, 0x97baa1baL, 0x84ea524eL, 0x7681d14dL, 78 | 0x2892ed69L, 0xdaf96e6aL, 0xc9a99d9eL, 0x3bc21e9dL, 79 | 0xef087a76L, 0x1d63f975L, 0x0e330a81L, 0xfc588982L, 80 | 0xb21572c9L, 0x407ef1caL, 0x532e023eL, 0xa145813dL, 81 | 0x758fe5d6L, 0x87e466d5L, 0x94b49521L, 0x66df1622L, 82 | 0x38cc2a06L, 0xcaa7a905L, 0xd9f75af1L, 0x2b9cd9f2L, 83 | 0xff56bd19L, 0x0d3d3e1aL, 0x1e6dcdeeL, 0xec064eedL, 84 | 0xc38d26c4L, 0x31e6a5c7L, 0x22b65633L, 0xd0ddd530L, 85 | 0x0417b1dbL, 0xf67c32d8L, 0xe52cc12cL, 0x1747422fL, 86 | 0x49547e0bL, 0xbb3ffd08L, 0xa86f0efcL, 0x5a048dffL, 87 | 0x8ecee914L, 0x7ca56a17L, 0x6ff599e3L, 0x9d9e1ae0L, 88 | 0xd3d3e1abL, 0x21b862a8L, 0x32e8915cL, 0xc083125fL, 89 | 0x144976b4L, 0xe622f5b7L, 0xf5720643L, 0x07198540L, 90 | 0x590ab964L, 0xab613a67L, 0xb831c993L, 0x4a5a4a90L, 91 | 0x9e902e7bL, 0x6cfbad78L, 0x7fab5e8cL, 0x8dc0dd8fL, 92 | 0xe330a81aL, 0x115b2b19L, 0x020bd8edL, 0xf0605beeL, 93 | 0x24aa3f05L, 0xd6c1bc06L, 0xc5914ff2L, 0x37faccf1L, 94 | 0x69e9f0d5L, 0x9b8273d6L, 0x88d28022L, 0x7ab90321L, 95 | 0xae7367caL, 
0x5c18e4c9L, 0x4f48173dL, 0xbd23943eL, 96 | 0xf36e6f75L, 0x0105ec76L, 0x12551f82L, 0xe03e9c81L, 97 | 0x34f4f86aL, 0xc69f7b69L, 0xd5cf889dL, 0x27a40b9eL, 98 | 0x79b737baL, 0x8bdcb4b9L, 0x988c474dL, 0x6ae7c44eL, 99 | 0xbe2da0a5L, 0x4c4623a6L, 0x5f16d052L, 0xad7d5351L, 100 | ) 101 | 102 | 103 | 104 | CRC_INIT = 0xffffffffL 105 | 106 | 107 | def crc_update(crc, data): 108 | """Update CRC-32C checksum with data. 109 | 110 | Args: 111 | crc: 32-bit checksum to update as long. 112 | data: byte array, string or iterable over bytes. 113 | 114 | Returns: 115 | 32-bit updated CRC-32C as long. 116 | """ 117 | 118 | if type(data) != array.array or data.itemsize != 1: 119 | buf = array.array("B", data) 120 | else: 121 | buf = data 122 | 123 | for b in buf: 124 | table_index = (crc ^ b) & 0xff 125 | crc = (CRC_TABLE[table_index] ^ (crc >> 8)) & 0xffffffffL 126 | return crc & 0xffffffffL 127 | 128 | 129 | def crc_finalize(crc): 130 | """Finalize CRC-32C checksum. 131 | 132 | This function should be called as last step of crc calculation. 133 | 134 | Args: 135 | crc: 32-bit checksum as long. 136 | 137 | Returns: 138 | finalized 32-bit checksum as long 139 | """ 140 | return crc ^ 0xffffffffL 141 | 142 | 143 | def crc(data): 144 | """Compute CRC-32C checksum of the data. 145 | 146 | Args: 147 | data: byte array, string or iterable over bytes. 148 | 149 | Returns: 150 | 32-bit CRC-32C checksum of data as long. 151 | """ 152 | return crc_finalize(crc_update(CRC_INIT, data)) 153 | -------------------------------------------------------------------------------- /stanza/monitoring/experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities useful for experiment setup 3 | """ 4 | __author__ = 'victor' 5 | import json 6 | 7 | 8 | class AttrDict(dict): 9 | """ 10 | A dictionary object which keys can be referenced via attributes. 11 | 12 | Example: 13 | 14 | .. code-block:: python 15 | 16 | d = AttriDict(foo=1, bar='cool') 17 | print(d) 18 | print(d['foo']) 19 | print(d.foo) 20 | print(d.bar) 21 | """ 22 | 23 | def __init__(self, *args, **kwargs): 24 | super(AttrDict, self).__init__(*args, **kwargs) 25 | self.__dict__ = self 26 | 27 | def save(self, fname): 28 | """ Saves the dictionary in json format 29 | :param fname: file to save to 30 | """ 31 | with open(fname, 'wb') as f: 32 | json.dump(self, f) 33 | 34 | @classmethod 35 | def load(cls, fname): 36 | """ Loads the dictionary from json file 37 | :param fname: file to load from 38 | :return: loaded dictionary 39 | """ 40 | with open(fname) as f: 41 | return Config(**json.load(f)) 42 | -------------------------------------------------------------------------------- /stanza/monitoring/progress.py: -------------------------------------------------------------------------------- 1 | """A module for periodically displaying progress on a hierarchy of tasks 2 | and estimating time to completion. 3 | 4 | >>> import progress, datetime 5 | >>> progress.set_resolution(datetime.datetime.resolution) # show all messages, don't sample 6 | >>> progress.start_task('Repetition', 2) 7 | >>> for rep in range(2): # doctest: +ELLIPSIS 8 | ... progress.progress(rep) 9 | ... progress.start_task('Example', 3) 10 | ... for ex in range(3): 11 | ... progress.progress(ex) 12 | ... progress.end_task() 13 | ... 14 | Repetition 0 of 2 (~0% done, ETA unknown on ...) 15 | Repetition 0 of 2, Example 0 of 3 (~0% done, ETA unknown on ...) 16 | Repetition 0 of 2, Example 1 of 3 (~17% done, ETA ...) 
17 | Repetition 0 of 2, Example 2 of 3 (~33% done, ETA ...) 18 | Repetition 0 of 2, Example 3 of 3 (~50% done, ETA ...) 19 | Repetition 1 of 2 (~50% done, ETA ...) 20 | Repetition 1 of 2, Example 0 of 3 (~50% done, ETA ...) 21 | Repetition 1 of 2, Example 1 of 3 (~67% done, ETA ...) 22 | Repetition 1 of 2, Example 2 of 3 (~83% done, ETA ...) 23 | Repetition 1 of 2, Example 3 of 3 (~100% done, ETA ...) 24 | >>> progress.end_task() # doctest: +ELLIPSIS 25 | Repetition 2 of 2 (~100% done, ETA ...) 26 | """ 27 | 28 | __author__ = 'wmonroe4' 29 | 30 | 31 | import datetime 32 | import doctest 33 | from collections import namedtuple 34 | 35 | 36 | class ProgressMonitor(object): 37 | ''' 38 | Keeps track of a hierarchy of tasks and displays percent completion 39 | and estimated completion time. 40 | ''' 41 | def __init__(self, resolution=datetime.datetime.resolution): 42 | ''' 43 | Create a `ProgressMonitor` object. 44 | 45 | :param datetime.datetime resolution: The minimum interval at which 46 | progress updates are shown. The default is to show all updates. 47 | This setting can be modified after creation by assigning to 48 | the `resolution` field of a `ProgressMonitor` object. 49 | (Note that the global `progress.*` functions override this to 50 | show updates every minute by default. This can be reset by 51 | calling `progress.set_resolution(datetime.datetime.resolution)`.) 52 | ''' 53 | self.task_stack = [] 54 | self.last_report = datetime.datetime.min 55 | self.resolution = resolution 56 | self.start_time = datetime.datetime.now() 57 | 58 | def start_task(self, name, size): 59 | ''' 60 | Add a task to the stack. If, for example, `name` is `'Iteration'` and 61 | `size` is 10, progress on that task will be shown as 62 | 63 | ..., Iteration
<p>
of 10, ... 64 | 65 | :param str name: A descriptive name for the type of subtask that is 66 | being completed. 67 | :param int size: The total number of subtasks to complete. 68 | ''' 69 | if len(self.task_stack) == 0: 70 | self.start_time = datetime.datetime.now() 71 | self.task_stack.append(Task(name, size, 0)) 72 | 73 | def progress(self, p): 74 | ''' 75 | Update the current progress on the task at the top of the stack. 76 | 77 | :param int p: The current subtask number, between 0 and `size` 78 | (passed to `start_task`), inclusive. 79 | ''' 80 | self.task_stack[-1] = self.task_stack[-1]._replace(progress=p) 81 | self.progress_report() 82 | 83 | def end_task(self): 84 | ''' 85 | Remove the current task from the stack. 86 | ''' 87 | self.progress(self.task_stack[-1].size) 88 | self.task_stack.pop() 89 | 90 | def progress_report(self, force=False): 91 | ''' 92 | Print the current progress. 93 | 94 | :param bool force: If `True`, print the report regardless of the 95 | elapsed time since the last progress report. 96 | ''' 97 | now = datetime.datetime.now() 98 | if (len(self.task_stack) > 1 or self.task_stack[0] > 0) and \ 99 | now - self.last_report < self.resolution and not force: 100 | return 101 | 102 | stack_printout = ', '.join('%s %s of %s' % (t.name, t.progress, t.size) 103 | for t in self.task_stack) 104 | 105 | frac_done = self.fraction_done() 106 | if frac_done == 0.0: 107 | now_str = now.strftime('%c') 108 | eta_str = 'unknown on %s' % now_str 109 | else: 110 | elapsed = (now - self.start_time) 111 | estimated_length = elapsed.total_seconds() / frac_done 112 | eta = self.start_time + datetime.timedelta(seconds=estimated_length) 113 | eta_str = eta.strftime('%c') 114 | 115 | print '%s (~%d%% done, ETA %s)' % (stack_printout, 116 | round(frac_done * 100.0), 117 | eta_str) 118 | self.last_report = datetime.datetime.now() 119 | 120 | def fraction_done(self, start=0.0, finish=1.0, stack=None): 121 | ''' 122 | :return float: The estimated fraction of the overall task hierarchy 123 | that has been finished. A number in the range [0.0, 1.0]. 124 | ''' 125 | if stack is None: 126 | stack = self.task_stack 127 | 128 | if len(stack) == 0: 129 | return start 130 | elif stack[0].size == 0: 131 | # Avoid divide by zero 132 | return finish 133 | else: 134 | top_fraction = stack[0].progress * 1.0 / stack[0].size 135 | next_top_fraction = (stack[0].progress + 1.0) / stack[0].size 136 | inner_start = start + top_fraction * (finish - start) 137 | inner_finish = start + next_top_fraction * (finish - start) 138 | return self.fraction_done(inner_start, inner_finish, stack[1:]) 139 | 140 | 141 | Task = namedtuple('Task', ('name', 'size', 'progress')) 142 | 143 | _global_t = ProgressMonitor(resolution=datetime.timedelta(minutes=1)) 144 | 145 | 146 | def start_task(name, size): 147 | ''' 148 | Call `start_task` on a global `ProgressMonitor`. 149 | ''' 150 | _global_t.start_task(name, size) 151 | 152 | 153 | def progress(p): 154 | ''' 155 | Call `progress` on a global `ProgressMonitor`. 156 | ''' 157 | _global_t.progress(p) 158 | 159 | 160 | def end_task(): 161 | ''' 162 | Call `end_task` on a global `ProgressMonitor`. 163 | ''' 164 | _global_t.end_task() 165 | 166 | 167 | def set_resolution(res): 168 | ''' 169 | Change the resolution on the global `ProgressMonitor`. 170 | See `ProgressMonitor.__init__`. 
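    For instance, a minimal sketch (the 30-second interval is an arbitrary example):

        import datetime
        set_resolution(datetime.timedelta(seconds=30))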
171 | ''' 172 | _global_t.resolution = res 173 | 174 | 175 | __all__ = [ 176 | 'ProgressMonitor', 177 | 'start_task', 178 | 'progress', 179 | 'end_task', 180 | 'set_resolution', 181 | ] 182 | 183 | 184 | if __name__ == '__main__': 185 | doctest.testmod() 186 | -------------------------------------------------------------------------------- /stanza/monitoring/trigger.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | __author__ = 'victor, kelvinguu' 4 | import numpy as np 5 | 6 | 7 | class Trigger(object): 8 | """ 9 | A generic Trigger object that performs some action on some event 10 | """ 11 | pass 12 | 13 | 14 | class StatefulTriggerMixin(object): 15 | """ 16 | A mix-in denoting triggers with memory 17 | """ 18 | 19 | def reset(self): 20 | """ 21 | reset the Trigger to its initial state (eg. by clear its memory) 22 | """ 23 | raise NotImplementedError() 24 | 25 | 26 | class MetricTrigger(object): 27 | """ 28 | An abstract class denoting triggers that are based on some metric 29 | """ 30 | 31 | def __call__(self, *args, **kwargs): 32 | raise NotImplementedError() 33 | 34 | 35 | class ThresholdTrigger(MetricTrigger): 36 | """ 37 | Triggers when the variable crosses the min threshold or the max threshold. 38 | """ 39 | 40 | def __init__(self, min_threshold=-float('inf'), max_threshold=float('inf')): 41 | """ 42 | :param min_threshold: if the variable crosses this threshold then the trigger returns True. 43 | :param max_threshold: if the variable crosses this threshold then the trigger returns True. 44 | """ 45 | super(MetricTrigger, self).__init__() 46 | self.min = min_threshold 47 | self.max = max_threshold 48 | 49 | def __call__(self, new_value): 50 | """ 51 | :return: whether the value exceeds the predefined thresholds. 52 | """ 53 | return new_value > self.max or new_value < self.min 54 | 55 | 56 | class PatienceTrigger(MetricTrigger, StatefulTriggerMixin): 57 | """ 58 | Triggers when N time steps has elapsed since the best value 59 | for the variable was encountered. N is denoted by the patience parameter. 60 | """ 61 | 62 | def __init__(self, patience): 63 | """ 64 | :param patience: how many consecutive suboptimal values to tolerate before triggering. 65 | """ 66 | super(PatienceTrigger, self).__init__() 67 | self.patience = patience 68 | self.best_so_far = -float('inf') 69 | self.time_since_best = 0 70 | 71 | def __call__(self, new_value): 72 | """ 73 | :param new_value: value for this iteration. 74 | :return: True if `self.patience` consecutive suboptimal values have been seen. 75 | """ 76 | if new_value > self.best_so_far: 77 | self.best_so_far = new_value 78 | self.time_since_best = 0 79 | return False 80 | self.time_since_best += 1 81 | return self.time_since_best > self.patience 82 | 83 | def reset(self): 84 | """ 85 | reset the Trigger to its initial state (eg. by clear its memory) 86 | """ 87 | self.best_so_far = -float('inf') 88 | self.time_since_best = 0 89 | 90 | 91 | class SlopeTrigger(MetricTrigger, StatefulTriggerMixin): 92 | """ 93 | Triggers when the slope of the values in the most recent time window 94 | falls within the specified range (inclusive). 95 | 96 | The slope is approximated with a least squares fit on the data points 97 | in the window. 98 | 99 | Data points passed to the slope trigger are assumed to each be one 100 | unit apart on the x axis. 101 | """ 102 | def __init__(self, range, window_size=10): 103 | """ 104 | :param range: a tuple of minimum and maximum range to tolerate. 
105 | :param window_size: how many points to use to estimate the slope 106 | """ 107 | self.range = range 108 | self.window_size = window_size 109 | self.vals = deque(maxlen=window_size) 110 | 111 | def __call__(self, new_value): 112 | """ 113 | :param new_value: value for this time step 114 | :return: True if the value falls within the predefined range. 115 | """ 116 | self.vals.append(new_value) 117 | 118 | # not enough points to robustly estimate slope 119 | if len(self.vals) < self.window_size: 120 | return False 121 | 122 | return self.range[0] <= self.slope() <= self.range[1] 123 | 124 | def slope(self): 125 | """ 126 | :return: the esitmated slope for points in the current window 127 | """ 128 | x = range(self.window_size) 129 | y = self.vals 130 | slope, bias = np.polyfit(x, y, 1) 131 | return slope 132 | 133 | def reset(self): 134 | """ 135 | reset the Trigger to its initial state (eg. by clear its memory) 136 | """ 137 | self.vals = deque(maxlen=self.window_size) 138 | -------------------------------------------------------------------------------- /stanza/nlp/__init__.py: -------------------------------------------------------------------------------- 1 | from stanza.nlp.data import * 2 | from stanza.nlp.corenlp import * -------------------------------------------------------------------------------- /stanza/nlp/data.py: -------------------------------------------------------------------------------- 1 | from abc import abstractproperty 2 | from collections import Sequence 3 | 4 | __author__ = 'kelvinguu' 5 | 6 | class Document(Sequence): 7 | """A sequence of Sentence objects.""" 8 | pass 9 | 10 | 11 | class Sentence(Sequence): 12 | """A sequence of Token objects.""" 13 | pass 14 | 15 | 16 | class Token(object): 17 | @abstractproperty 18 | def word(self): 19 | pass 20 | 21 | 22 | class Entity(object): 23 | """An 'entity' in a information extraction sense. Each entity has 24 | a type, a token sequence in a sentence and an optional canonical 25 | link (if coreference is present). """ 26 | 27 | @abstractproperty 28 | def sentence(self): 29 | """Returns the referring sentence""" 30 | pass 31 | 32 | @abstractproperty 33 | def head_token(self): 34 | """Returns the start token.""" 35 | pass 36 | 37 | @abstractproperty 38 | def token_span(self): 39 | """Returns the index of the end token.""" 40 | pass 41 | 42 | @abstractproperty 43 | def character_span(self): 44 | """Returns the index of the end character.""" 45 | pass 46 | 47 | @abstractproperty 48 | def type(self): 49 | """Returns the type of the string""" 50 | pass 51 | 52 | @abstractproperty 53 | def gloss(self): 54 | """Returns the exact string of the entity""" 55 | pass 56 | 57 | @abstractproperty 58 | def canonical_entity(self): 59 | """Returns the exact string of the canonical reference""" 60 | pass 61 | -------------------------------------------------------------------------------- /stanza/nlp/protobuf_json.py: -------------------------------------------------------------------------------- 1 | # JSON serialization support for Google's protobuf Messages 2 | # Copyright (c) 2009, Paul Dovbush 3 | # All rights reserved. 4 | # http://code.google.com/p/protobuf-json/ 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are 8 | # met: 9 | # 10 | # * Redistributions of source code must retain the above copyright 11 | # notice, this list of conditions and the following disclaimer. 
12 | # * Redistributions in binary form must reproduce the above 13 | # copyright notice, this list of conditions and the following disclaimer 14 | # in the documentation and/or other materials provided with the 15 | # distribution. 16 | # * Neither the name of nor the names of its 17 | # contributors may be used to endorse or promote products derived from 18 | # this software without specific prior written permission. 19 | # 20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | ''' 33 | Provide serialization and de-serialization of Google's protobuf Messages into/from JSON format. 34 | ''' 35 | 36 | # groups are deprecated and not supported; 37 | # Note that preservation of unknown fields is currently not available for Python (c) google docs 38 | # extensions is not supported from 0.0.5 (due to gpb2.3 changes) 39 | 40 | __version__='0.0.6' 41 | __author__='Paul Dovbush ' 42 | 43 | 44 | import six 45 | from functools import partial 46 | from google.protobuf.descriptor import FieldDescriptor as FD 47 | 48 | class ParseError(Exception): pass 49 | 50 | 51 | def json2pb(pb, js, useFieldNumber=False): 52 | ''' convert JSON string to google.protobuf.descriptor instance ''' 53 | for field in pb.DESCRIPTOR.fields: 54 | if useFieldNumber: 55 | key = field.number 56 | else: 57 | key = field.name 58 | if key not in js: 59 | continue 60 | if field.type == FD.TYPE_MESSAGE: 61 | pass 62 | elif field.type in _js2ftype: 63 | ftype = _js2ftype[field.type] 64 | else: 65 | raise ParseError("Field %s.%s of type '%d' is not supported" % (pb.__class__.__name__, field.name, field.type, )) 66 | value = js[key] 67 | if field.label == FD.LABEL_REPEATED: 68 | pb_value = getattr(pb, field.name, None) 69 | for v in value: 70 | if field.type == FD.TYPE_MESSAGE: 71 | json2pb(pb_value.add(), v, useFieldNumber=useFieldNumber) 72 | else: 73 | pb_value.append(ftype(v)) 74 | else: 75 | if field.type == FD.TYPE_MESSAGE: 76 | json2pb(getattr(pb, field.name, None), value, useFieldNumber=useFieldNumber) 77 | else: 78 | setattr(pb, field.name, ftype(value)) 79 | return pb 80 | 81 | 82 | 83 | def pb2json(pb, useFieldNumber=False): 84 | ''' convert google.protobuf.descriptor instance to JSON string ''' 85 | js = {} 86 | # fields = pb.DESCRIPTOR.fields #all fields 87 | fields = pb.ListFields() #only filled (including extensions) 88 | for field,value in fields: 89 | if useFieldNumber: 90 | key = field.number 91 | else: 92 | key = field.name 93 | if field.type == FD.TYPE_MESSAGE: 94 | ftype = partial(pb2json, useFieldNumber=useFieldNumber) 95 | elif field.type in _ftype2js: 96 | ftype = _ftype2js[field.type] 97 | else: 98 | raise ParseError("Field %s.%s of type '%d' is not supported" % (pb.__class__.__name__, field.name, field.type, )) 99 | if field.label == 
FD.LABEL_REPEATED: 100 | js_value = [] 101 | for v in value: 102 | js_value.append(ftype(v)) 103 | else: 104 | js_value = ftype(value) 105 | js[key] = js_value 106 | return js 107 | 108 | 109 | _ftype2js = { 110 | FD.TYPE_DOUBLE: float, 111 | FD.TYPE_FLOAT: float, 112 | FD.TYPE_INT64: long if six.PY2 else int, 113 | FD.TYPE_UINT64: long if six.PY2 else int, 114 | FD.TYPE_INT32: int, 115 | FD.TYPE_FIXED64: float, 116 | FD.TYPE_FIXED32: float, 117 | FD.TYPE_BOOL: bool, 118 | FD.TYPE_STRING: unicode if six.PY2 else str, 119 | #FD.TYPE_MESSAGE: pb2json, #handled specially 120 | FD.TYPE_BYTES: lambda x: x.encode('string_escape'), 121 | FD.TYPE_UINT32: int, 122 | FD.TYPE_ENUM: int, 123 | FD.TYPE_SFIXED32: float, 124 | FD.TYPE_SFIXED64: float, 125 | FD.TYPE_SINT32: int, 126 | FD.TYPE_SINT64: long if six.PY2 else int, 127 | } 128 | 129 | _js2ftype = { 130 | FD.TYPE_DOUBLE: float, 131 | FD.TYPE_FLOAT: float, 132 | FD.TYPE_INT64: long if six.PY2 else int, 133 | FD.TYPE_UINT64: long if six.PY2 else int, 134 | FD.TYPE_INT32: int, 135 | FD.TYPE_FIXED64: float, 136 | FD.TYPE_FIXED32: float, 137 | FD.TYPE_BOOL: bool, 138 | FD.TYPE_STRING: unicode if six.PY2 else str, 139 | # FD.TYPE_MESSAGE: json2pb, #handled specially 140 | FD.TYPE_BYTES: lambda x: x.decode('string_escape'), 141 | FD.TYPE_UINT32: int, 142 | FD.TYPE_ENUM: int, 143 | FD.TYPE_SFIXED32: float, 144 | FD.TYPE_SFIXED64: float, 145 | FD.TYPE_SINT32: int, 146 | FD.TYPE_SINT64: long if six.PY2 else int, 147 | } 148 | -------------------------------------------------------------------------------- /stanza/research/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/stanza-old/920c55d8eaa1e7105971059c66eb448a74c100d6/stanza/research/__init__.py -------------------------------------------------------------------------------- /stanza/research/bleu.py: -------------------------------------------------------------------------------- 1 | ''' 2 | An implementation of unsmoothed corpus-level BLEU. 3 | ''' 4 | 5 | __author__ = 'wmonroe4' 6 | 7 | from math import log, exp 8 | from collections import Counter 9 | 10 | 11 | def corpus_bleu(reference_groups, predictions): 12 | nums = [0] * 4 13 | denoms = [0] * 4 14 | prediction_len = 0 15 | reference_len = 0 16 | 17 | for refs, pred in zip(reference_groups, predictions): 18 | for n in range(1, 5): 19 | correct, total = modified_ngram_precision(refs, pred, n) 20 | nums[n - 1] += correct 21 | denoms[n - 1] += total 22 | 23 | prediction_len += len(pred) 24 | reference_len += closest_length(refs, pred) 25 | 26 | if prediction_len <= 0: 27 | brevity_penalty = 0.0 if reference_len > 0 else 1.0 28 | else: 29 | brevity_penalty = min(1.0, exp(1.0 - reference_len * 1.0 / prediction_len)) 30 | fracs = [(num, denom) for num, denom in zip(nums, denoms) if denom != 0] 31 | if not fracs or 0 in [num for num, denom in fracs]: 32 | return 0.0 33 | else: 34 | weight = 1.0 / len(fracs) 35 | geom_mean = exp(sum(weight * log(num * 1.0 / denom) for num, denom in fracs)) 36 | return brevity_penalty * geom_mean 37 | 38 | 39 | def modified_ngram_precision(references, pred, n): 40 | ''' 41 | Borrowed from the ntlk BLEU implementation: 42 | http://www.nltk.org/_modules/nltk/translate/bleu_score.html 43 | 44 | >>> modified_ngram_precision([['the', 'fat', 'cat', 'the', 'rat']], 45 | ... ['the', 'the', 'the', 'the', 'the'], 1) 46 | (2, 5) 47 | >>> modified_ngram_precision([['the', 'fat', 'cat', 'the', 'rat']], 48 | ... 
['the', 'fat', 'the', 'rat'], 2) 49 | (2, 3) 50 | ''' 51 | counts = Counter(iter_ngrams(pred, n)) 52 | max_counts = {} 53 | for reference in references: 54 | reference_counts = Counter(iter_ngrams(reference, n)) 55 | for ngram in counts: 56 | max_counts[ngram] = max(max_counts.get(ngram, 0), 57 | reference_counts[ngram]) 58 | 59 | clipped_counts = {ngram: min(count, max_counts[ngram]) 60 | for ngram, count in counts.items()} 61 | 62 | numerator = sum(clipped_counts.values()) 63 | denominator = sum(counts.values()) 64 | return numerator, denominator 65 | 66 | 67 | def iter_ngrams(s, n): 68 | return (tuple(s[i:i + n]) for i in range(len(s) - n + 1)) 69 | 70 | 71 | def closest_length(refs, pred): 72 | ''' 73 | >>> closest_length(['1234', '12345', '1'], '123') 74 | 4 75 | >>> closest_length(['123', '12345', '1'], '12') 76 | 1 77 | ''' 78 | smallest_diff = float('inf') 79 | closest_length = float('inf') 80 | for ref in refs: 81 | diff = abs(len(ref) - len(pred)) 82 | if diff < smallest_diff or (diff == smallest_diff and len(ref) < closest_length): 83 | smallest_diff = diff 84 | closest_length = len(ref) 85 | return closest_length 86 | -------------------------------------------------------------------------------- /stanza/research/codalab.py: -------------------------------------------------------------------------------- 1 | """Tools for working with CodaLab.""" 2 | import os 3 | import tempfile 4 | import subprocess 5 | import cPickle as pickle 6 | from os.path import abspath 7 | from os.path import dirname 8 | 9 | import matplotlib.image as mpimg 10 | import json 11 | import sys 12 | import platform 13 | from contextlib import contextmanager 14 | import shutil 15 | 16 | __author__ = 'kelvinguu' 17 | 18 | 19 | # need to be specified by user 20 | worksheet = None 21 | site = None 22 | 23 | # http://stackoverflow.com/questions/18421757/live-output-from-subprocess-command 24 | def shell(cmd, verbose=False, debug=False): 25 | if verbose: 26 | print cmd 27 | 28 | if debug: 29 | return # don't actually execute command 30 | 31 | output = [] 32 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True) 33 | 34 | for c in iter(lambda: process.stdout.read(1), ''): 35 | output.append(c) 36 | if verbose: 37 | sys.stdout.write(c) 38 | sys.stdout.flush() 39 | 40 | status = process.wait() 41 | if status != 0: 42 | raise RuntimeError('Error, exit code: {}'.format(status)) 43 | 44 | # TODO: make sure we get all output 45 | return ''.join(output) 46 | 47 | 48 | def get_uuids(): 49 | """List all bundle UUIDs in the worksheet.""" 50 | result = shell('cl ls -w {} -u'.format(worksheet)) 51 | uuids = result.split('\n') 52 | uuids = uuids[1:-1] # trim non uuids 53 | return uuids 54 | 55 | 56 | @contextmanager 57 | def open_file(uuid, path): 58 | """Get the raw file content within a particular bundle at a particular path. 59 | 60 | Path have no leading slash. 
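    A minimal sketch of typical use (the bundle UUID and path are hypothetical):

        with open_file('0x1a2b3c', 'stdout') as f:
            print f.read()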
61 | """ 62 | # create temporary file just so we can get an unused file path 63 | f = tempfile.NamedTemporaryFile() 64 | f.close() # close and delete right away 65 | fname = f.name 66 | 67 | # download file to temporary path 68 | cmd ='cl down -o {} -w {} {}/{}'.format(fname, worksheet, uuid, path) 69 | try: 70 | shell(cmd) 71 | except RuntimeError: 72 | try: 73 | os.remove(fname) # if file exists, remove it 74 | except OSError: 75 | pass 76 | raise IOError('Failed to open file {}/{}'.format(uuid, path)) 77 | 78 | f = open(fname) 79 | yield f 80 | f.close() 81 | os.remove(fname) # delete temp file 82 | 83 | 84 | class Bundle(object): 85 | def __init__(self, uuid): 86 | self.uuid = uuid 87 | 88 | def __getattr__(self, item): 89 | """ 90 | Load attributes: history, meta on demand 91 | """ 92 | if item == 'history': 93 | try: 94 | with open_file(self.uuid, 'history.cpkl') as f: 95 | value = pickle.load(f) 96 | except IOError: 97 | value = {} 98 | 99 | elif item == 'meta': 100 | try: 101 | with open_file(self.uuid, 'meta.json') as f: 102 | value = json.load(f) 103 | except IOError: 104 | value = {} 105 | 106 | # load codalab info 107 | fields = ('uuid', 'name', 'bundle_type', 'state', 'time', 'remote') 108 | cmd = 'cl info -w {} -f {} {}'.format(worksheet, ','.join(fields), self.uuid) 109 | result = shell(cmd) 110 | info = dict(zip(fields, result.split())) 111 | value.update(info) 112 | 113 | elif item in ('stderr', 'stdout'): 114 | with open_file(self.uuid, item) as f: 115 | value = f.read() 116 | 117 | else: 118 | raise AttributeError(item) 119 | 120 | self.__setattr__(item, value) 121 | return value 122 | 123 | def __repr__(self): 124 | return self.uuid 125 | 126 | def load_img(self, img_path): 127 | """ 128 | Return an image object that can be immediately plotted with matplotlib 129 | """ 130 | with open_file(self.uuid, img_path) as f: 131 | return mpimg.imread(f) 132 | 133 | 134 | def download_logs(bundle, log_dir): 135 | if bundle.meta['bundle_type'] != 'run' or bundle.meta['state'] == 'queued': 136 | print 'Skipped {}\n'.format(bundle.uuid) 137 | return 138 | 139 | if isinstance(bundle, str): 140 | bundle = Bundle(bundle) 141 | 142 | uuid = bundle.uuid 143 | name = bundle.meta['name'] 144 | log_path = os.path.join(log_dir, '{}_{}'.format(name, uuid)) 145 | 146 | cmd ='cl down -o {} -w {} {}/logs'.format(log_path, worksheet, uuid) 147 | 148 | print uuid 149 | try: 150 | shell(cmd, verbose=True) 151 | except RuntimeError: 152 | print 'Failed to download', bundle.uuid 153 | print 154 | 155 | 156 | def report(render, uuids=None, reverse=True, limit=None): 157 | if uuids is None: 158 | uuids = get_uuids() 159 | 160 | if reverse: 161 | uuids = uuids[::-1] 162 | 163 | if limit is not None: 164 | uuids = uuids[:limit] 165 | 166 | for uuid in uuids: 167 | bundle = Bundle(uuid) 168 | try: 169 | render(bundle) 170 | except Exception: 171 | print 'Failed to render', bundle.uuid 172 | 173 | 174 | def monitor_jobs(logdir, uuids=None, reverse=True, limit=None): 175 | if os.path.exists(logdir): 176 | delete = raw_input('Overwrite existing logdir? 
({})'.format(logdir)) 177 | if delete == 'y': 178 | shutil.rmtree(logdir) 179 | os.makedirs(logdir) 180 | else: 181 | os.makedirs(logdir) 182 | print 'Using logdir:', logdir 183 | 184 | report(lambda bd: download_logs(bd, logdir), uuids, reverse, limit) 185 | 186 | 187 | def tensorboard(logdir): 188 | print 'Run this in bash:' 189 | shell('tensorboard --logdir={}'.format(logdir), verbose=True, debug=True) 190 | print '\nGo to TensorBoard: http://localhost:6006/' 191 | 192 | 193 | def add_to_sys_path(path): 194 | """Add a path to the system PATH.""" 195 | sys.path.insert(0, path) 196 | 197 | 198 | def configure_matplotlib(): 199 | """Set Matplotlib backend to 'Agg', which is necessary on CodaLab docker image.""" 200 | import warnings 201 | import matplotlib 202 | with warnings.catch_warnings(): 203 | warnings.simplefilter('ignore') 204 | matplotlib.use('Agg') # needed when running from server 205 | 206 | 207 | def in_codalab(): 208 | """Check if we are running inside CodaLab Docker container or not.""" 209 | # TODO: below is a total hack. If the OS is not a Mac, we assume we're on CodaLab. 210 | return platform.system() != 'Darwin' 211 | 212 | 213 | def launch_job(job_name, cmd=None, 214 | code_dir=None, excludes='*.ipynb .git .ipynb_checkpoints', dependencies=tuple(), 215 | queue='john', image='codalab/python', memory='18g', 216 | debug=False, tail=False): 217 | """Launch a job on CodaLab (optionally upload code that the job depends on). 218 | 219 | Args: 220 | job_name: name of the job 221 | cmd: command to execute 222 | code_dir: path to code folder. If None, no code is uploaded. 223 | excludes: file types to exclude from the upload 224 | dependencies: list of other bundles that we depend on 225 | debug: if True, prints SSH commands, but does not execute them 226 | tail: show the streaming output returned by CodaLab once it launches the job 227 | """ 228 | print 'Remember to set up SSH tunnel and LOG IN through the command line before calling this.' 
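    # A hypothetical call, for illustration only (worksheet, job name and command are made up):
    #   codalab.worksheet = 'my-worksheet'
    #   codalab.launch_job('my-job', cmd='python train.py', code_dir='.',
    #                      dependencies=('code',), tail=True)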
229 | 230 | def execute(cmd): 231 | return shell(cmd, verbose=True, debug=debug) 232 | 233 | if code_dir: 234 | execute('cl up -n code -w {} {} -x {}'.format(worksheet, code_dir, excludes)) 235 | 236 | options = '-v -n {} -w {} --request-queue {} --request-docker-image {} --request-memory {}'.format( 237 | job_name, worksheet, queue, image, memory) 238 | dep_str = ' '.join(['{0}:{0}'.format(dep) for dep in dependencies]) 239 | cmd = "cl run {} {} '{}'".format(options, dep_str, cmd) 240 | if tail: 241 | cmd += ' -t' 242 | execute(cmd) -------------------------------------------------------------------------------- /stanza/research/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import configargparse 3 | import os 4 | import sys 5 | import json 6 | import logfile 7 | import traceback 8 | import StringIO 9 | import contextlib 10 | import __builtin__ 11 | from pyhocon import ConfigFactory 12 | 13 | 14 | class ArgumentParser(configargparse.Parser): 15 | def convert_setting_to_command_line_arg(self, action, key, value): 16 | args = [] 17 | if action is None: 18 | command_line_key = \ 19 | self.get_command_line_key_for_unknown_config_file_setting(key) 20 | else: 21 | command_line_key = action.option_strings[-1] 22 | 23 | if isinstance(action, argparse._StoreTrueAction): 24 | if value is True: 25 | args.append(command_line_key) 26 | elif isinstance(action, argparse._StoreFalseAction): 27 | if value is False: 28 | args.append(command_line_key) 29 | elif isinstance(action, argparse._StoreConstAction): 30 | if value == action.const: 31 | args.append(command_line_key) 32 | elif isinstance(action, argparse._CountAction): 33 | for _ in range(value): 34 | args.append(command_line_key) 35 | elif action is not None and value == action.default: 36 | pass 37 | elif isinstance(value, list): 38 | args.append(command_line_key) 39 | args.extend([str(e) for e in value]) 40 | else: 41 | args.append(command_line_key) 42 | args.append(str(value)) 43 | return args 44 | 45 | 46 | class HoconConfigFileParser(object): 47 | def parse(self, stream): 48 | try: 49 | basedir = os.path.dirname(stream.name) 50 | except AttributeError: 51 | basedir = os.getcwd() 52 | return dict(ConfigFactory.parse_string(stream.read(), basedir=basedir)) 53 | 54 | def serialize(self, items): 55 | return json.dumps(items, sort_keys=True, indent=2, separators=(',', ': ')) 56 | 57 | def get_syntax_description(self): 58 | return ('Config files should use HOCON syntax. HOCON is a superset of ' 59 | 'JSON; for more, see ' 60 | '.') 61 | 62 | 63 | _options_parser = ArgumentParser(conflict_handler='resolve', add_help=False, 64 | config_file_parser=HoconConfigFileParser(), 65 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 66 | _options_parser.add_argument('--run_dir', '-R', type=str, default=None, 67 | help='The directory in which to write log files, parameters, etc. ' 68 | 'Will be created if it does not exist. If None, output files ' 69 | 'will not be written.') 70 | _options_parser.add_argument('--config', '-C', default=None, is_config_file=True, 71 | help='Path to a JSON or HOCON file containing option settings. ' 72 | 'Can be loaded from the config.json of a previous run to rerun ' 73 | 'an experiment. If None, only options given as command line ' 74 | 'arguments will be used.') 75 | _options_parser.add_argument('--overwrite', '-O', action='store_true', 76 | help='If True, allow overwriting the contents of the run directory. 
' 77 | 'Otherwise, an error will be raised if the run directory ' 78 | 'contains a config.json to prevent accidental overwriting. ') 79 | 80 | 81 | def get_options_parser(): 82 | return _options_parser 83 | 84 | 85 | _options = None 86 | 87 | 88 | def options(allow_partial=False, read=False): 89 | ''' 90 | Get the object containing the values of the parsed command line options. 91 | 92 | :param bool allow_partial: If `True`, ignore unrecognized arguments and allow 93 | the options to be re-parsed next time `options` is called. This 94 | also suppresses overwrite checking (the check is performed the first 95 | time `options` is called with `allow_partial=False`). 96 | :param bool read: If `True`, do not create or overwrite a `config.json` 97 | file, and do not check whether such file already exists. Use for scripts 98 | that read from the run directory rather than/in addition to writing to it. 99 | 100 | :return argparse.Namespace: An object storing the values of the options specified 101 | to the parser returned by `get_options_parser()`. 102 | ''' 103 | global _options 104 | 105 | if allow_partial: 106 | opts, extras = _options_parser.parse_known_args() 107 | if opts.run_dir: 108 | mkdirp(opts.run_dir) 109 | return opts 110 | 111 | if _options is None: 112 | # Add back in the help option (only show help and quit once arguments are finalized) 113 | _options_parser.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS, 114 | help='show this help message and exit') 115 | _options = _options_parser.parse_args() 116 | if _options.run_dir: 117 | mkdirp(_options.run_dir, overwrite=_options.overwrite or read) 118 | 119 | if not read: 120 | options_dump = vars(_options) 121 | # People should be able to rerun an experiment with -C config.json safely. 122 | # Don't include the overwrite option, since using a config from an experiment 123 | # done with -O should still require passing -O for it to be overwritten again. 124 | del options_dump['overwrite'] 125 | # And don't write the name of the other config file in this new one! It's 126 | # probably harmless (config file interpretation can't be chained with the 127 | # config option), but still confusing. 128 | del options_dump['config'] 129 | dump_pretty(options_dump, 'config.json') 130 | return _options 131 | 132 | 133 | class OverwriteError(Exception): 134 | pass 135 | 136 | 137 | def mkdirp(dirname, overwrite=True): 138 | ''' 139 | Create a directory at the path given by `dirname`, if it doesn't 140 | already exist. If `overwrite` is False, raise an error when trying 141 | to create a directory that already has a config.json file in it. 142 | Otherwise do nothing if the directory already exists. (Note that an 143 | existing directory without a config.json will not raise an error 144 | regardless.) 145 | 146 | http://stackoverflow.com/a/14364249/4481448 147 | ''' 148 | try: 149 | os.makedirs(dirname) 150 | except OSError: 151 | if not os.path.isdir(dirname): 152 | raise 153 | config_path = os.path.join(dirname, 'config.json') 154 | if not overwrite and os.path.lexists(config_path): 155 | raise OverwriteError('%s exists and already contains a config.json. To allow ' 156 | 'overwriting, pass the -O/--overwrite option.' 
% dirname) 157 | 158 | 159 | def get_file_path(filename): 160 | opts = options(allow_partial=True) 161 | if not opts.run_dir: 162 | return None 163 | return os.path.join(opts.run_dir, filename) 164 | 165 | 166 | def open(filename, *args, **kwargs): 167 | file_path = get_file_path(filename) 168 | if not file_path: 169 | # create a dummy file because we don't have a run dir 170 | return contextlib.closing(StringIO.StringIO()) 171 | return __builtin__.open(file_path, *args, **kwargs) 172 | 173 | 174 | def boolean(arg): 175 | """Convert a string to a bool treating 'false' and 'no' as False.""" 176 | if arg in ('true', 'True', 'yes', '1', 1): 177 | return True 178 | elif arg in ('false', 'False', 'no', '0', 0): 179 | return False 180 | else: 181 | raise argparse.ArgumentTypeError( 182 | 'could not interpret "%s" as true or false' % (arg,)) 183 | 184 | 185 | def redirect_output(): 186 | outfile = get_file_path('stdout.log') 187 | if outfile is None: 188 | return 189 | logfile.log_stdout_to(outfile) 190 | logfile.log_stderr_to(get_file_path('stderr.log')) 191 | 192 | 193 | def dump(data, filename, lines=False, *args, **kwargs): 194 | try: 195 | with open(filename, 'w') as outfile: 196 | if lines: 197 | for item in data: 198 | json.dump(item, outfile, *args, **kwargs) 199 | outfile.write('\n') 200 | else: 201 | json.dump(data, outfile, *args, **kwargs) 202 | except IOError: 203 | traceback.print_exc() 204 | print >>sys.stderr, 'Unable to write %s' % filename 205 | except TypeError: 206 | traceback.print_exc() 207 | print >>sys.stderr, 'Unable to write %s' % filename 208 | 209 | 210 | def dump_pretty(data, filename): 211 | dump(data, filename, 212 | sort_keys=True, indent=2, separators=(',', ': ')) 213 | -------------------------------------------------------------------------------- /stanza/research/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import warnings 3 | 4 | from . import config 5 | 6 | 7 | def evaluate(learner, eval_data, metrics, metric_names=None, split_id=None, 8 | write_data=False): 9 | ''' 10 | Evaluate `learner` on the instances in `eval_data` according to each 11 | metric in `metric`, and return a dictionary summarizing the values of 12 | the metrics. 13 | 14 | Dump the predictions, scores, and metric summaries in JSON format 15 | to "{predictions|scores|results}.`split_id`.json" in the run directory. 16 | 17 | :param learner: The model to be evaluated. 18 | :type learner: learner.Learner 19 | 20 | :param eval_data: The data to use to evaluate the model. 21 | :type eval_data: list(instance.Instance) 22 | 23 | :param metrics: An iterable of functions that defines the standard by 24 | which predictions are evaluated. 25 | :type metrics: Iterable(function(eval_data: list(instance.Instance), 26 | predictions: list(output_type), 27 | scores: list(float)) -> list(float)) 28 | 29 | :param bool write_data: If `True`, write out the instances in `eval_data` 30 | as JSON, one per line, to the file `data..jsons`. 31 | ''' 32 | if metric_names is None: 33 | metric_names = [ 34 | (metric.__name__ if hasattr(metric, '__name__') 35 | else ('m%d' % i)) 36 | for i, metric in enumerate(metrics) 37 | ] 38 | 39 | split_prefix = split_id + '.' 
if split_id else '' 40 | 41 | if write_data: 42 | config.dump([inst.__dict__ for inst in eval_data], 43 | 'data.%sjsons' % split_prefix, 44 | default=json_default, lines=True) 45 | 46 | results = {split_prefix + 'num_params': learner.num_params} 47 | 48 | predictions, scores = learner.predict_and_score(eval_data) 49 | config.dump(predictions, 'predictions.%sjsons' % split_prefix, lines=True) 50 | config.dump(scores, 'scores.%sjsons' % split_prefix, lines=True) 51 | 52 | for metric, metric_name in zip(metrics, metric_names): 53 | prefix = split_prefix + (metric_name + '.' if metric_name else '') 54 | 55 | inst_outputs = metric(eval_data, predictions, scores, learner) 56 | if metric_name in ['data', 'predictions', 'scores']: 57 | warnings.warn('not outputting metric scores for metric "%s" because it would shadow ' 58 | 'another results file') 59 | else: 60 | config.dump(inst_outputs, '%s.%sjsons' % (metric_name, split_prefix), lines=True) 61 | 62 | mean = np.mean(inst_outputs) 63 | gmean = np.exp(np.log(inst_outputs).mean()) 64 | sum = np.sum(inst_outputs) 65 | std = np.std(inst_outputs) 66 | 67 | results.update({ 68 | prefix + 'mean': mean, 69 | prefix + 'gmean': gmean, 70 | prefix + 'sum': sum, 71 | prefix + 'std': std, 72 | # prefix + 'ci_lower': ci_lower, 73 | # prefix + 'ci_upper': ci_upper, 74 | }) 75 | 76 | config.dump_pretty(results, 'results.%sjson' % split_prefix) 77 | 78 | return results 79 | 80 | 81 | def json_default(o): 82 | import numpy as np 83 | if isinstance(o, np.ndarray): 84 | return o.tolist() 85 | else: 86 | return o.__dict__ 87 | -------------------------------------------------------------------------------- /stanza/research/instance.py: -------------------------------------------------------------------------------- 1 | class Instance(object): 2 | ''' 3 | Represents an individual data point in a training or testing set, for a classifier 4 | trained to predict `output` given `input`. 5 | 6 | `alt_inputs` and `alt_outputs` are optional lists of alternative stimuli and 7 | predictions, respectively, for use in pragmatic settings. `annotated_input` and 8 | `annotated_output` can be used for versions of the input and output that have 9 | been augmented with additional data, e.g. parse trees, logical forms, POS tags. 10 | 11 | `source` is the original object from which this instance is constructed. 12 | ''' 13 | def __init__(self, 14 | input, output=None, 15 | annotated_input=None, annotated_output=None, 16 | alt_inputs=None, alt_outputs=None, 17 | source=None): 18 | self.source = source 19 | self.input, self.output = input, output 20 | self.annotated_input, self.annotated_output = annotated_input, annotated_output 21 | self.alt_inputs, self.alt_outputs = alt_inputs, alt_outputs 22 | 23 | def stripped(self, include_annotated=True): 24 | ''' 25 | Return a version of this instance with all information removed that could be used 26 | to "cheat" at test time: the true output and its annotated version, and the 27 | reference to the full source. 28 | 29 | If `include_annotated` is true, `annotated_input` will also be included (but not 30 | `annotated_output` in either case). 31 | ''' 32 | return Instance(self.input, 33 | annotated_input=self.annotated_input if include_annotated else None, 34 | alt_inputs=self.alt_inputs, alt_outputs=self.alt_outputs) 35 | 36 | def inverted(self): 37 | ''' 38 | Return a version of this instance with inputs replaced by outputs and vice versa. 
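        A hypothetical sketch of the round trip:

        >>> Instance('le chat', 'the cat').inverted()
        Instance('the cat', 'le chat')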
39 | ''' 40 | return Instance(input=self.output, output=self.input, 41 | annotated_input=self.annotated_output, 42 | annotated_output=self.annotated_input, 43 | alt_inputs=self.alt_outputs, 44 | alt_outputs=self.alt_inputs, 45 | source=self.source) 46 | 47 | def __repr__(self): 48 | return 'Instance(%s, %s)' % (repr(self.input), repr(self.output)) 49 | -------------------------------------------------------------------------------- /stanza/research/iterators.py: -------------------------------------------------------------------------------- 1 | from itertools import islice, imap, chain 2 | 3 | 4 | def iter_batches(iterable, batch_size): 5 | ''' 6 | Given a sequence or iterable, yield batches from that iterable until it 7 | runs out. Note that this function returns a generator, and also each 8 | batch will be a generator. 9 | 10 | :param iterable: The sequence or iterable to split into batches 11 | :param int batch_size: The number of elements of `iterable` to iterate over 12 | in each batch 13 | 14 | >>> batches = iter_batches('abcdefghijkl', batch_size=5) 15 | >>> list(next(batches)) 16 | ['a', 'b', 'c', 'd', 'e'] 17 | >>> list(next(batches)) 18 | ['f', 'g', 'h', 'i', 'j'] 19 | >>> list(next(batches)) 20 | ['k', 'l'] 21 | >>> list(next(batches)) 22 | Traceback (most recent call last): 23 | ... 24 | StopIteration 25 | 26 | Warning: It is important to iterate completely over each batch before 27 | requesting the next, or batch sizes will be truncated to 1. For example, 28 | making a list of all batches before asking for the contents of each 29 | will not work: 30 | 31 | >>> batches = list(iter_batches('abcdefghijkl', batch_size=5)) 32 | >>> len(batches) 33 | 12 34 | >>> list(batches[0]) 35 | ['a'] 36 | 37 | However, making a list of each individual batch as it is received will 38 | produce expected behavior (as shown in the first example). 39 | ''' 40 | # http://stackoverflow.com/a/8290514/4481448 41 | sourceiter = iter(iterable) 42 | while True: 43 | batchiter = islice(sourceiter, batch_size) 44 | yield chain([batchiter.next()], batchiter) 45 | 46 | 47 | def gen_batches(iterable, batch_size): 48 | ''' 49 | Returns a generator object that yields batches from `iterable`. 50 | See `iter_batches` for more details and caveats. 51 | 52 | Note that `iter_batches` returns an iterator, which never supports `len()`, 53 | `gen_batches` returns an iterable which supports `len()` if and only if 54 | `iterable` does. This *may* be an iterator, but could be a `SizedGenerator` 55 | object. To obtain an iterator (for example, to use the `next()` function), 56 | call `iter()` on this iterable. 57 | 58 | >>> batches = gen_batches('abcdefghijkl', batch_size=5) 59 | >>> len(batches) 60 | 3 61 | >>> for batch in batches: 62 | ... print(list(batch)) 63 | ['a', 'b', 'c', 'd', 'e'] 64 | ['f', 'g', 'h', 'i', 'j'] 65 | ['k', 'l'] 66 | ''' 67 | def batches_thunk(): 68 | return iter_batches(iterable, batch_size) 69 | 70 | try: 71 | length = len(iterable) 72 | except TypeError: 73 | return batches_thunk() 74 | 75 | num_batches = (length - 1) // batch_size + 1 76 | return SizedGenerator(batches_thunk, length=num_batches) 77 | 78 | 79 | class SizedGenerator(object): 80 | ''' 81 | A class that wraps a generator to support len(). 
82 | 83 | Usage: 84 | 85 | >>> func = lambda: (x ** 2 for x in range(5)) 86 | >>> gen = SizedGenerator(func, length=5) 87 | >>> len(gen) 88 | 5 89 | >>> list(gen) 90 | [0, 1, 4, 9, 16] 91 | 92 | `length=None` can be passed to have the length be inferred (a O(n) 93 | running time operation): 94 | 95 | >>> func = lambda: (x ** 2 for x in range(5)) 96 | >>> gen = SizedGenerator(func, length=None) 97 | >>> len(gen) 98 | 5 99 | >>> list(gen) 100 | [0, 1, 4, 9, 16] 101 | 102 | Caller is responsible for assuring that provided length, if any, 103 | matches the actual length of the sequence: 104 | 105 | >>> func = lambda: (x ** 2 for x in range(8)) 106 | >>> gen = SizedGenerator(func, length=10) 107 | >>> len(gen) 108 | 10 109 | >>> len(list(gen)) 110 | 8 111 | 112 | Note that this class has the following caveats: 113 | 114 | * `func` must be a callable that can be called with no arguments. 115 | * If length=None is passed to the constructor, the sequence yielded by `func()` 116 | will be enumerated once during the construction of this object. This means 117 | `func()` must yield sequences of the same length when called multiple times. 118 | Also, assuming you plan to enumerate the sequence again to use it, this can 119 | double the time spent going through the sequence! 120 | 121 | The last requirement is because in general it is not possible to predict the 122 | length of a generator sequence, so we actually observe the output for one 123 | run-through and assume the length will stay the same on a second run-through. 124 | ''' 125 | def __init__(self, func, length): 126 | self.func = func 127 | if length is None: 128 | length = sum(1 for _ in func()) 129 | self.length = length 130 | 131 | def __len__(self): 132 | return self.length 133 | 134 | def __iter__(self): 135 | return iter(self.func()) 136 | 137 | 138 | def sized_imap(func, iterable, strict=False): 139 | ''' 140 | Return an iterable whose elements are the result of applying the callable `func` 141 | to each element of `iterable`. If `iterable` has a `len()`, then the iterable returned 142 | by this function will have the same `len()`. Otherwise calling `len()` on the 143 | returned iterable will raise `TypeError`. 144 | 145 | :param func: The function to apply to each element of `iterable`. 146 | :param iterable: An iterable whose objects will be mapped. 147 | :param bool strict: If `True` and `iterable` does not support `len()`, raise an exception 148 | immediately instead of returning an iterable that does not support `len()`. 149 | ''' 150 | try: 151 | length = len(iterable) 152 | except TypeError: 153 | if strict: 154 | raise 155 | else: 156 | return imap(func, iterable) 157 | return SizedGenerator(lambda: imap(func, iterable), length=length) 158 | -------------------------------------------------------------------------------- /stanza/research/learner.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | 3 | from . import evaluate, output 4 | 5 | 6 | class Learner(object): 7 | def __init__(self): 8 | self._using_default_separate = False 9 | self._using_default_combined = False 10 | 11 | def train(self, training_instances, validation_instances=None, metrics=None): 12 | ''' 13 | Fit a model on training data. 14 | 15 | :param training_instances: The data to use to train the model. 16 | Instances should have at least the `input` and `output` fields 17 | populated. 
18 | :type training_instances: list(instance.Instance) 19 | 20 | :param validation_instances: The data to use to validate the model. 21 | Good practice says this should be held out (separate from the 22 | training set), but this API does not require that to be the case. 23 | :type validation_instances: list(instance.Instance) 24 | 25 | :param metrics: Functions like those found in the `metrics` module 26 | to use in validation. (These are not necessarily the objective function 27 | for training; subclasses define their own training objectives.) 28 | :type metrics: list(function) 29 | 30 | :returns: None 31 | ''' 32 | raise NotImplementedError 33 | 34 | def validate(self, validation_instances, metrics, iteration=None): 35 | ''' 36 | Evaluate this model on `validation_instances` during training and 37 | output a report. 38 | 39 | :param validation_instances: The data to use to validate the model. 40 | :type validation_instances: list(instance.Instance) 41 | 42 | :param metrics: Functions like those found in the `metrics` module 43 | for quantifying the performance of the learner. 44 | :type metrics: list(function) 45 | 46 | :param iteration: A label (anything with a sensible `str()` conversion) 47 | identifying the current iteration in output. 48 | ''' 49 | if not validation_instances or not metrics: 50 | return {} 51 | split_id = 'val%s' % iteration if iteration is not None else 'val' 52 | train_results = evaluate.evaluate(self, validation_instances, 53 | metrics=metrics, split_id=split_id) 54 | output.output_results(train_results, split_id) 55 | return train_results 56 | 57 | def predict(self, eval_instances, random=False, verbosity=0): 58 | ''' 59 | Return most likely predictions for each testing instance in 60 | `eval_instances`. 61 | 62 | :param eval_instances: The data to use to evaluate the model. 63 | Instances should have at least the `input` field populated. 64 | The `output` field need not be populated; subclasses should 65 | ignore it if it is present. 66 | :param random: If `True`, sample from the probability distribution 67 | defined by the classifier rather than output the most likely 68 | prediction. 69 | :param verbosity: The level of diagnostic output, relative to the 70 | global --verbosity option. Used to adjust output when models 71 | are composed of multiple sub-models. 72 | :type eval_instances: list(instance.Instance) 73 | 74 | :returns: list(output_type) 75 | ''' 76 | if hasattr(self, '_using_default_combined') and self._using_default_combined: 77 | raise NotImplementedError 78 | 79 | self._using_default_separate = True 80 | return self.predict_and_score(eval_instances, random=random, verbosity=verbosity)[0] 81 | 82 | def score(self, eval_instances, verbosity=0): 83 | ''' 84 | Return scores (negative log likelihoods) assigned to each testing 85 | instance in `eval_instances`. 86 | 87 | :param eval_instances: The data to use to evaluate the model. 88 | Instances should have at least the `input` and `output` fields 89 | populated. `output` is needed to define which score is to 90 | be returned. 91 | :param verbosity: The level of diagnostic output, relative to the 92 | global --verbosity option. Used to adjust output when models 93 | are composed of multiple sub-models. 
94 | :type eval_instances: list(instance.Instance) 95 | 96 | :returns: list(float) 97 | ''' 98 | if hasattr(self, '_using_default_combined') and self._using_default_combined: 99 | raise NotImplementedError 100 | 101 | self._using_default_separate = True 102 | return self.predict_and_score(eval_instances, verbosity=verbosity)[1] 103 | 104 | def predict_and_score(self, eval_instances, random=False, verbosity=0): 105 | ''' 106 | Return most likely outputs and scores for the particular set of 107 | outputs given in `eval_instances`, as a tuple. Return value should 108 | be equivalent to the default implementation of 109 | 110 | return (self.predict(eval_instances), self.score(eval_instances)) 111 | 112 | but subclasses can override this to combine the two calls and reduce 113 | duplicated work. Either the two separate methods or this one (or all 114 | of them) should be overridden. 115 | 116 | :param eval_instances: The data to use to evaluate the model. 117 | Instances should have at least the `input` and `output` fields 118 | populated. `output` is needed to define which score is to 119 | be returned. 120 | :param random: If `True`, sample from the probability distribution 121 | defined by the classifier rather than output the most likely 122 | prediction. 123 | :param verbosity: The level of diagnostic output, relative to the 124 | global --verbosity option. Used to adjust output when models 125 | are composed of multiple sub-models. 126 | :type eval_instances: list(instance.Instance) 127 | 128 | :returns: tuple(list(output_type), list(float)) 129 | ''' 130 | if hasattr(self, '_using_default_separate') and self._using_default_separate: 131 | raise NotImplementedError 132 | 133 | self._using_default_combined = True 134 | return (self.predict(eval_instances, random=random, verbosity=verbosity), 135 | self.score(eval_instances, verbosity=verbosity)) 136 | 137 | def dump(self, outfile): 138 | ''' 139 | Serialize the model for this learner and write it to a file. 140 | Serialized models can be loaded back in with `load`. 141 | 142 | By default, pickle the entire object. This may not be very efficient 143 | or reliable for long-term storage; consider overriding this (and `load`) 144 | to serialize only the necessary parameters. Alternatively, you can 145 | define __getstate__ and __setstate__ for subclasses to influence how 146 | the model is pickled (see https://docs.python.org/2/library/pickle.html). 147 | 148 | :param file outfile: A file-like object where the serialized model will 149 | be written. 150 | ''' 151 | pickle.dump(self, outfile) 152 | 153 | def load(self, infile): 154 | ''' 155 | Deserialize a model from a stored file. 156 | 157 | By default, unpickle an entire object. If `dump` is overridden to 158 | use a different storage format, `load` should be as well. 159 | 160 | :param file outfile: A file-like object from which to retrieve the 161 | serialized model. 162 | ''' 163 | model = pickle.load(infile) 164 | self.__dict__.update(model.__dict__) 165 | -------------------------------------------------------------------------------- /stanza/research/logfile.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | class Tee(object): 5 | ''' 6 | A file-like object that duplicates output to two other file-like 7 | objects. 
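    A minimal usage sketch (the filename 'run.log' is illustrative; this is
    exactly what `log_stdout_to` below does for stdout):

        outputlog = open('run.log', 'w')
        sys.stdout = Tee(sys.stdout, outputlog)
        print('written to both the terminal and run.log')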
8 | 9 | Thanks to Akkana Peck for the implementation: 10 | 11 | http://shallowsky.com/blog/programming/python-tee.html 12 | ''' 13 | def __init__(self, _fd1, _fd2): 14 | self.fd1 = _fd1 15 | self.fd2 = _fd2 16 | 17 | def __del__(self): 18 | if sys is not None and self.fd1 != sys.stdout and self.fd1 != sys.stderr: 19 | self.fd1.close() 20 | if sys is not None and self.fd2 != sys.stdout and self.fd2 != sys.stderr: 21 | self.fd2.close() 22 | 23 | def write(self, text): 24 | self.fd1.write(text) 25 | self.fd2.write(text) 26 | 27 | def flush(self): 28 | self.fd1.flush() 29 | self.fd2.flush() 30 | 31 | 32 | tees = [] 33 | 34 | 35 | def log_stdout_to(logfilename): 36 | stdoutsav = sys.stdout 37 | outputlog = open(logfilename, "w") 38 | sys.stdout = Tee(stdoutsav, outputlog) 39 | tees.append(sys.stdout) 40 | 41 | 42 | def log_stderr_to(logfilename): 43 | stderrsav = sys.stderr 44 | outputlog = open(logfilename, "w") 45 | sys.stderr = Tee(stderrsav, outputlog) 46 | tees.append(sys.stderr) 47 | -------------------------------------------------------------------------------- /stanza/research/mockfs.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import mock 3 | import os 4 | import StringIO 5 | 6 | 7 | def yields(thing): 8 | yield thing 9 | 10 | 11 | class MockOpen(object): 12 | def __init__(self, test_dir): 13 | self.files = {} 14 | self.old_open = open 15 | self.test_dir = test_dir 16 | 17 | def __call__(self, filename, mode, *args, **kwargs): 18 | if filename.startswith(self.test_dir): 19 | if filename not in self.files or mode in ('w', 'w+'): 20 | self.files[filename] = StringIO.StringIO() 21 | fakefile = self.files[filename] 22 | if mode in ('r', 'r+'): 23 | fakefile.seek(0) 24 | else: 25 | fakefile.seek(0, os.SEEK_END) 26 | return contextlib.contextmanager(yields)(fakefile) 27 | else: 28 | return self.old_open(filename, *args, **kwargs) 29 | 30 | 31 | def patcher(module, test_dir): 32 | mo = MockOpen(test_dir) 33 | return mock.patch(module + '.open', mo) 34 | -------------------------------------------------------------------------------- /stanza/research/output.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | def output_results(results, split_id='results', output_stream=None): 5 | ''' 6 | Log `results` readably to `output_stream`, with a header 7 | containing `split_id`. 8 | 9 | :param results: a dictionary of summary statistics from an evaluation 10 | :type results: dict(str -> object) 11 | 12 | :param str split_id: an identifier for the source of `results` (e.g. 
'dev') 13 | 14 | :param file output_stream: the file-like object to which to log the results 15 | (default: stdout) 16 | :type split_id: str 17 | ''' 18 | if output_stream is None: 19 | output_stream = sys.stdout 20 | 21 | output_stream.write('----- %s -----\n' % split_id) 22 | for name in sorted(results.keys()): 23 | output_stream.write('%s: %s\n' % (name, repr(results[name]))) 24 | 25 | output_stream.flush() 26 | -------------------------------------------------------------------------------- /stanza/research/pick_gpu.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | import warnings 4 | 5 | warnings.warn('pick_gpu has been moved from stanza.research to stanza.cluster; the module in research is deprecated.') 6 | 7 | sys.modules[__name__] = importlib.import_module('...cluster.pick_gpu', __name__) 8 | -------------------------------------------------------------------------------- /stanza/research/progress.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | import warnings 4 | 5 | warnings.warn('progress has been moved from stanza.research to stanza.monitoring; the module in research is deprecated.') 6 | 7 | sys.modules[__name__] = importlib.import_module('...monitoring.progress', __name__) 8 | -------------------------------------------------------------------------------- /stanza/research/quickstart: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -o errexit 4 | 5 | if [ "$#" == "0" ]; then 6 | target_dir="." 7 | else 8 | target_dir="$1" 9 | fi 10 | 11 | script_dir="$(dirname "$(readlink -e $0)")" 12 | 13 | if [ -e "$script_dir"/templates ]; then 14 | templates_dir="$script_dir"/templates 15 | else 16 | python -c 'import stanza' 17 | stanza_dir="$(dirname "$(python -c 'import stanza; print(stanza.__file__)')")" 18 | templates_dir="$stanza_dir"/research/templates 19 | fi 20 | 21 | 22 | mkdir -p "$target_dir" 23 | echo "Copying template files" 24 | cp -nr "$templates_dir"/* "$target_dir"/ 25 | 26 | cd "$target_dir" 27 | mv -n coveragerc .coveragerc 28 | mv -n gitignore .gitignore 29 | 30 | git init 31 | git add . 32 | git commit -am 'Initial commit' 33 | 34 | # Setup subtree 35 | if [[ ! "$SUBTREE" =~ ^[Nn0Ff] ]]; then 36 | echo "Setting up Stanza subtree" 37 | # Add Stanza as a remote repo 38 | git remote add stanza https://github.com/stanfordnlp/stanza.git 39 | # Import the contents of the repo as a subtree 40 | git subtree add --prefix third-party/stanza stanza develop --squash 41 | # Put a symlink to the actual module somewhere where your code needs it 42 | ln -s third-party/stanza/stanza stanza 43 | # Add aliases for the two things you'll need to do with the subtree 44 | git config alias.stanza-update 'subtree pull --prefix third-party/stanza stanza develop --squash' 45 | git config alias.stanza-push 'subtree push --prefix third-party/stanza stanza develop' 46 | else 47 | echo "Skipping Stanza setup" 48 | fi 49 | 50 | echo "Quickstart project created at $target_dir" 51 | -------------------------------------------------------------------------------- /stanza/research/rng.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from . 
import config 4 | 5 | parser = config.get_options_parser() 6 | parser.add_argument('--random_seed', default='DefaultRandomSeed', 7 | help='A string for initializing the random number generator, ' 8 | 'for reproducible experiments. The string will be hashed ' 9 | "and the hash used as the seed to numpy's RandomState.") 10 | 11 | _random_state = None 12 | 13 | 14 | def get_rng(): 15 | global _random_state 16 | if _random_state is None: 17 | options, _ = parser.parse_known_args() 18 | _random_state = np.random.RandomState(np.uint32(hash(options.random_seed))) 19 | 20 | return _random_state 21 | -------------------------------------------------------------------------------- /stanza/research/summary.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | import warnings 4 | 5 | warnings.warn('summary has been moved from stanza.research to stanza.monitoring; the module in research is deprecated.') 6 | 7 | sys.modules[__name__] = importlib.import_module('...monitoring.summary', __name__) 8 | -------------------------------------------------------------------------------- /stanza/research/summary_basic.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | __author__ = 'kelvinguu' 4 | 5 | class TensorBoardLogger(object): 6 | """Log scalars to event files that can then be read by TensorBoard. 7 | 8 | This object keeps its own TF Graph, and creates a Variable on the fly 9 | for every metric you want to log. 10 | 11 | This can be easily extended to log other kinds of summary events. 12 | 13 | @wmonroe has a version that doesn't rely so heavily on the TF library. 14 | See summary.py 15 | """ 16 | 17 | def __init__(self, log_dir): 18 | self.g = tf.Graph() 19 | self.summaries = {} 20 | self.sess = tf.Session(graph=self.g) 21 | self.summ_writer = tf.train.SummaryWriter(log_dir, flush_secs=5) 22 | 23 | def log_proto(self, proto, step_num): 24 | """Log a Summary protobuf to the event file. 25 | 26 | :param proto: a Summary protobuf 27 | :param step_num: the iteration number at which this value was logged 28 | """ 29 | self.summ_writer.add_summary(proto, step_num) 30 | return proto 31 | 32 | def log(self, key, val, step_num): 33 | """Directly log a scalar value to the event file. 34 | 35 | :param string key: a name for the value 36 | :param val: a float 37 | :param step_num: the iteration number at which this value was logged 38 | """ 39 | try: 40 | ph, summ = self.summaries[key] 41 | except KeyError: 42 | # if we haven't defined a variable for this key, define one 43 | with self.g.as_default(): 44 | ph = tf.placeholder(tf.float32, (), name=key) # scalar 45 | summ = tf.scalar_summary(key, ph) 46 | self.summaries[key] = (ph, summ) 47 | 48 | summary_str = self.sess.run(summ, {ph: val}) 49 | self.summ_writer.add_summary(summary_str, step_num) 50 | return val 51 | -------------------------------------------------------------------------------- /stanza/research/templates/README.rst: -------------------------------------------------------------------------------- 1 | Stanza project README 2 | ===================== 3 | 4 | This is the default README file for a new project with Stanza. 5 | 6 | Getting Started 7 | ~~~~~~~~~~~~~~~ 8 | 9 | You'll want to: 10 | 11 | * define your first learner (baseline.py); and 12 | * add functions to load your datasets (datasets.py). 
13 | 14 | You can also: 15 | 16 | * define new metrics, if the ones in stanza.research.metrics aren't adequate (metrics.py); and 17 | * tweak the method names on the wrapper class for trained models (wrapper.py). 18 | 19 | After that, you can start running experiments: 20 | 21 | ./run_experiment.py --run_dir runs/baseline 22 | -------------------------------------------------------------------------------- /stanza/research/templates/baseline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | 4 | from stanza.monitoring import progress 5 | from stanza.research import config 6 | from stanza.research.learner import Learner 7 | 8 | 9 | # TODO: name of model 10 | class BaselineLearner(Learner): 11 | def __init__(self): 12 | self.get_options() 13 | # TODO: initialize parameters 14 | 15 | def train(self, training_instances, validation_instances='ignored', metrics='ignored'): 16 | # TODO: train model 17 | pass 18 | 19 | @property 20 | def num_params(self): 21 | total = 0 22 | # TODO: count parameters 23 | return total 24 | 25 | def predict_and_score(self, eval_instances, random='ignored', verbosity=4): 26 | eval_instances = list(eval_instances) 27 | predictions = [] 28 | scores = [] 29 | 30 | if verbosity >= 1: 31 | progress.start_task('Eval instance', len(eval_instances)) 32 | 33 | for i, inst in enumerate(eval_instances): 34 | if verbosity >= 1: 35 | progress.progress(i) 36 | 37 | pred = '' # TODO: make prediction 38 | score = -float('inf') # TODO: score gold output 39 | predictions.append(pred) 40 | scores.append(score) 41 | 42 | if verbosity >= 1: 43 | progress.end_task() 44 | 45 | return predictions, scores 46 | 47 | def get_options(self): 48 | if not hasattr(self, 'options'): 49 | options = config.options() 50 | self.options = argparse.Namespace(**options.__dict__) 51 | -------------------------------------------------------------------------------- /stanza/research/templates/coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | # Add problematic libraries here 3 | # You'll also need to change setup.cfg 4 | omit = *third-party* 5 | -------------------------------------------------------------------------------- /stanza/research/templates/datasets.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | from stanza.research.instance import Instance 4 | from stanza.research.rng import get_rng 5 | 6 | 7 | rng = get_rng() 8 | 9 | 10 | # TODO: replace these silly datasets with real data 11 | def foobar_train(): 12 | return [Instance(input='foo', output='bar') for _ in range(1000)] 13 | 14 | 15 | def foobar_dev(): 16 | return [Instance(input='foo', output='bar') for _ in range(100)] 17 | 18 | 19 | def foobar_test(): 20 | return [Instance(input='foo', output='bar') for _ in range(100)] 21 | 22 | 23 | DataSource = namedtuple('DataSource', ['train_data', 'test_data']) 24 | 25 | SOURCES = { 26 | 'foobar_dev': DataSource(foobar_train, foobar_dev), 27 | 'foobar_test': DataSource(foobar_train, foobar_test), 28 | } 29 | -------------------------------------------------------------------------------- /stanza/research/templates/dependencies: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | echo 'Installing dependencies...' 4 | pip install ConfigArgParse 'pyhocon==0.3.18' pypng 'Protobuf>=3.0.0b2' python-Levenshtein 5 | 6 | if [ ! 
-e tensorflow ]; then 7 | echo 'Checking for tensorboard protos...' 8 | ( python -c 'import tensorflow.core.util.event_pb2' >/dev/null 2>&1 ) || ( 9 | echo "It looks like you don't have TensorFlow installed, so I'm putting a" 10 | echo "symlink at ./tensorflow/ to just the bare minimum you need. If you" 11 | echo "decide to install Tensorflow in the future, you can remove it." 12 | ln -s third-party/tensorflow tensorflow 13 | ) 14 | fi 15 | 16 | # TODO: 17 | # pip/conda install python libraries 18 | # download datasets 19 | 20 | echo 'Installing testing modules (optional)...' 21 | pip install nose nose-exclude coverage mock 22 | -------------------------------------------------------------------------------- /stanza/research/templates/error_analysis.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import Levenshtein as lev 4 | import numpy as np 5 | import os 6 | import warnings 7 | from collections import namedtuple 8 | 9 | from stanza.util.unicode import uprint 10 | from stanza.research import config 11 | 12 | 13 | parser = config.get_options_parser() 14 | parser.add_argument('--max_examples', type=int, default=100, 15 | help='The maximum number of examples to display in error analysis.') 16 | parser.add_argument('--html', type=config.boolean, default=False, 17 | help='If true, output errors in HTML.') 18 | 19 | Output = namedtuple('Output', 'config,results,data,scores,predictions') 20 | 21 | 22 | COLORS = ['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'white'] 23 | HTML = ['Black', 'DarkRed', 'DarkGreen', 'Olive', 'Blue', 'Purple', 'DarkCyan', 'White'] 24 | 25 | 26 | def wrap_color_html(s, color): 27 | code = COLORS.index(color) 28 | if code == -1: 29 | html_color = color 30 | else: 31 | html_color = HTML[code] 32 | return '%s' % (html_color, s) 33 | 34 | 35 | def wrap_color_shell(s, color): 36 | code = COLORS.index(color) 37 | if code == -1: 38 | raise ValueError('unrecognized color: ' + color) 39 | return '\033[1;3%dm%s\033[0m' % (code, s) 40 | 41 | 42 | def highlight(text, positions, color, html=False): 43 | chars = [] 44 | wrap_color = wrap_color_html if html else wrap_color_shell 45 | for i, c in enumerate(text): 46 | if i in positions: 47 | chars.append(wrap_color(c, color)) 48 | else: 49 | chars.append(c) 50 | return u''.join(chars) 51 | 52 | 53 | def print_error_analysis(): 54 | options = config.options(read=True) 55 | output = get_output(options.run_dir, 'eval') 56 | errors = [(inst['input'], pred, inst['output']) 57 | for inst, pred in zip(output.data, output.predictions) 58 | if inst['output'] != pred] 59 | if 0 < options.max_examples < len(errors): 60 | indices = np.random.choice(np.arange(len(errors)), size=options.max_examples, replace=False) 61 | else: 62 | indices = range(len(errors)) 63 | 64 | if options.html: 65 | print('') 66 | print('Error analysis') 67 | for i in indices: 68 | inp, pred, gold = [unicode(s).strip() for s in errors[i]] 69 | editops = lev.editops(gold, pred) 70 | print_visualization(inp, pred, gold, editops, html=options.html) 71 | if options.html: 72 | print('') 73 | 74 | 75 | def print_visualization(input_seq, pred_output_seq, 76 | gold_output_seq, editops, html=False): 77 | gold_highlights = [] 78 | pred_highlights = [] 79 | for optype, gold_idx, pred_idx in editops: 80 | gold_highlights.append(gold_idx) 81 | pred_highlights.append(pred_idx) 82 | 83 | input_seq = highlight(input_seq, pred_highlights, 'cyan', html=html) 84 | pred_output_seq = 
highlight(pred_output_seq, pred_highlights, 'red', html=html) 85 | gold_output_seq = highlight(gold_output_seq, gold_highlights, 'yellow', html=html) 86 | 87 | if html: 88 | print('') 89 | br = u'<br>' 90 | else: 91 | br = u'' 92 | uprint(input_seq + br) 93 | uprint(pred_output_seq + br) 94 | uprint(gold_output_seq) 95 | if html: 96 | print('
') 97 | print('') 98 | 99 | 100 | def get_output(run_dir, split): 101 | config_dict = load_dict(os.path.join(run_dir, 'config.json')) 102 | 103 | results = {} 104 | for filename in glob.glob(os.path.join(run_dir, 'results.*.json')): 105 | results.update(load_dict(filename)) 106 | 107 | data = load_dataset(os.path.join(run_dir, 'data.%s.jsons' % split)) 108 | scores = load_dataset(os.path.join(run_dir, 'scores.%s.jsons' % split)) 109 | predictions = load_dataset(os.path.join(run_dir, 'predictions.%s.jsons' % split)) 110 | return Output(config_dict, results, data, scores, predictions) 111 | 112 | 113 | def load_dict(filename): 114 | try: 115 | with open(filename) as infile: 116 | return json.load(infile) 117 | except IOError, e: 118 | warnings.warn(str(e)) 119 | return {'error.message.value': str(e)} 120 | 121 | 122 | def load_dataset(filename, transform_func=(lambda x: x)): 123 | try: 124 | dataset = [] 125 | with open(filename) as infile: 126 | for line in infile: 127 | js = json.loads(line.strip()) 128 | dataset.append(transform_func(js)) 129 | return dataset 130 | except IOError, e: 131 | warnings.warn(str(e)) 132 | return [{'error': str(e)}] 133 | 134 | 135 | if __name__ == '__main__': 136 | print_error_analysis() 137 | -------------------------------------------------------------------------------- /stanza/research/templates/fasttests: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Run all tests using nose. 4 | # Dependencies: 5 | # $ pip install nose nose-exclude coverage 6 | # Make sure you have setup.cfg copied over, or this won't by default run 7 | # coverage or doctest. 8 | 9 | nosetests --exclude-dir=itest --exclude-dir=third-party/stanza/test/slow_tests 10 | -------------------------------------------------------------------------------- /stanza/research/templates/gitignore: -------------------------------------------------------------------------------- 1 | # Temporary files 2 | *.pyc 3 | *~ 4 | 5 | # Local scripts 6 | activate 7 | 8 | # Output directories 9 | runs/ 10 | -------------------------------------------------------------------------------- /stanza/research/templates/learners.py: -------------------------------------------------------------------------------- 1 | # TODO: import all learner models 2 | from baseline import BaselineLearner 3 | 4 | 5 | def new(key): 6 | ''' 7 | Construct a new learner with the class named by `key`. A list 8 | of available learners is in the dictionary `LEARNERS`. 
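    A minimal usage sketch, assuming the default registry defined below
    (where 'Baseline' maps to BaselineLearner, and `train_data` stands for
    whatever list of Instance objects the project loads):

        learner = new('Baseline')
        learner.train(train_data)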
9 | ''' 10 | return LEARNERS[key]() 11 | 12 | 13 | LEARNERS = { 14 | 'Baseline': BaselineLearner, 15 | } 16 | -------------------------------------------------------------------------------- /stanza/research/templates/metrics.py: -------------------------------------------------------------------------------- 1 | from stanza.research.metrics import * 2 | 3 | 4 | # TODO: define new metrics 5 | 6 | 7 | METRICS = { 8 | name: globals()[name] 9 | for name in dir() 10 | if (name not in ['np'] 11 | and not name.startswith('_')) 12 | } 13 | -------------------------------------------------------------------------------- /stanza/research/templates/run_experiment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from stanza.research import config 3 | config.redirect_output() 4 | 5 | import datetime 6 | 7 | from stanza.monitoring import progress 8 | from stanza.research import evaluate, output 9 | 10 | import metrics 11 | import learners 12 | import datasets 13 | 14 | parser = config.get_options_parser() 15 | parser.add_argument('--learner', default='Baseline', choices=learners.LEARNERS.keys(), 16 | help='The name of the model to use in the experiment.') 17 | parser.add_argument('--load', metavar='MODEL_FILE', default=None, 18 | help='If provided, skip training and instead load a pretrained model ' 19 | 'from the specified path. If None or an empty string, train a ' 20 | 'new model.') 21 | parser.add_argument('--train_size', type=int, default=None, 22 | help='The number of examples to use in training. This number should ' 23 | '*include* examples held out for validation. If None, use the ' 24 | 'whole training set.') 25 | parser.add_argument('--validation_size', type=int, default=0, 26 | help='The number of examples to hold out from the training set for ' 27 | 'monitoring generalization error.') 28 | parser.add_argument('--test_size', type=int, default=None, 29 | help='The number of examples to use in testing. ' 30 | 'If None, use the whole dev/test set.') 31 | parser.add_argument('--data_source', default='foobar_dev', choices=datasets.SOURCES.keys(), 32 | help='The type of data to use.') 33 | parser.add_argument('--metrics', default=['accuracy', 'perplexity', 'log_likelihood_bits'], 34 | choices=metrics.METRICS.keys(), 35 | help='The evaluation metrics to report for the experiment.') 36 | parser.add_argument('--output_train_data', type=config.boolean, default=False, 37 | help='If True, write out the training dataset (after cutting down to ' 38 | '`train_size`) as a JSON-lines file in the output directory.') 39 | parser.add_argument('--output_test_data', type=config.boolean, default=False, 40 | help='If True, write out the evaluation dataset (after cutting down to ' 41 | '`test_size`) as a JSON-lines file in the output directory.') 42 | parser.add_argument('--progress_tick', type=int, default=10, 43 | help='The number of seconds between logging progress updates.') 44 | 45 | 46 | def main(): 47 | options = config.options() 48 | 49 | progress.set_resolution(datetime.timedelta(seconds=options.progress_tick)) 50 | 51 | train_data = datasets.SOURCES[options.data_source].train_data()[:options.train_size] 52 | if options.validation_size: 53 | assert options.validation_size < len(train_data), \ 54 | ('No training data after validation split! 
(%d <= %d)' % 55 | (len(train_data), options.validation_size)) 56 | validation_data = train_data[-options.validation_size:] 57 | train_data = train_data[:-options.validation_size] 58 | else: 59 | validation_data = None 60 | test_data = datasets.SOURCES[options.data_source].test_data()[:options.test_size] 61 | 62 | learner = learners.new(options.learner) 63 | 64 | m = [metrics.METRICS[m] for m in options.metrics] 65 | 66 | if options.load: 67 | with open(options.load, 'rb') as infile: 68 | learner.load(infile) 69 | else: 70 | learner.train(train_data, validation_data, metrics=m) 71 | model_path = config.get_file_path('model.pkl') 72 | if model_path: 73 | with open(model_path, 'wb') as outfile: 74 | learner.dump(outfile) 75 | 76 | train_results = evaluate.evaluate(learner, train_data, metrics=m, split_id='train', 77 | write_data=options.output_train_data) 78 | output.output_results(train_results, 'train') 79 | 80 | test_results = evaluate.evaluate(learner, test_data, metrics=m, split_id='eval', 81 | write_data=options.output_test_data) 82 | output.output_results(test_results, 'eval') 83 | 84 | 85 | if __name__ == '__main__': 86 | main() 87 | -------------------------------------------------------------------------------- /stanza/research/templates/setup.cfg: -------------------------------------------------------------------------------- 1 | [nosetests] 2 | with-doctest=1 3 | with-coverage=1 4 | cover-package=. 5 | cover-html=1 6 | cover-html-dir=coverage_report 7 | # Add problematic libraries here 8 | # You'll also need to change .coveragerc 9 | exclude-dir=third-party 10 | -------------------------------------------------------------------------------- /stanza/research/templates/tests: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Run all tests using nose. 4 | # Dependencies: 5 | # $ pip install nose nose-exclude coverage 6 | # Make sure you have setup.cfg copied over, or this won't by default run 7 | # coverage or doctest. 8 | 9 | nosetests 10 | -------------------------------------------------------------------------------- /stanza/research/templates/third-party/tensorflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/stanza-old/920c55d8eaa1e7105971059c66eb448a74c100d6/stanza/research/templates/third-party/tensorflow/__init__.py -------------------------------------------------------------------------------- /stanza/research/templates/third-party/tensorflow/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/stanza-old/920c55d8eaa1e7105971059c66eb448a74c100d6/stanza/research/templates/third-party/tensorflow/core/__init__.py -------------------------------------------------------------------------------- /stanza/research/templates/third-party/tensorflow/core/framework/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/stanza-old/920c55d8eaa1e7105971059c66eb448a74c100d6/stanza/research/templates/third-party/tensorflow/core/framework/__init__.py -------------------------------------------------------------------------------- /stanza/research/templates/third-party/tensorflow/core/framework/graph_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: tensorflow/core/framework/graph.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from tensorflow.core.framework import attr_value_pb2 as tensorflow_dot_core_dot_framework_dot_attr__value__pb2 17 | from tensorflow.core.framework import function_pb2 as tensorflow_dot_core_dot_framework_dot_function__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='tensorflow/core/framework/graph.proto', 22 | package='tensorflow', 23 | syntax='proto3', 24 | serialized_pb=_b('\n%tensorflow/core/framework/graph.proto\x12\ntensorflow\x1a*tensorflow/core/framework/attr_value.proto\x1a(tensorflow/core/framework/function.proto\"^\n\x08GraphDef\x12!\n\x04node\x18\x01 \x03(\x0b\x32\x13.tensorflow.NodeDef\x12/\n\x07library\x18\x02 \x01(\x0b\x32\x1e.tensorflow.FunctionDefLibrary\"\xb3\x01\n\x07NodeDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02op\x18\x02 \x01(\t\x12\r\n\x05input\x18\x03 \x03(\t\x12\x0e\n\x06\x64\x65vice\x18\x04 \x01(\t\x12+\n\x04\x61ttr\x18\x05 \x03(\x0b\x32\x1d.tensorflow.NodeDef.AttrEntry\x1a\x42\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tensorflow.AttrValue:\x02\x38\x01\x62\x06proto3') 25 | , 26 | dependencies=[tensorflow_dot_core_dot_framework_dot_attr__value__pb2.DESCRIPTOR,tensorflow_dot_core_dot_framework_dot_function__pb2.DESCRIPTOR,]) 27 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 28 | 29 | 30 | 31 | 32 | _GRAPHDEF = _descriptor.Descriptor( 33 | name='GraphDef', 34 | full_name='tensorflow.GraphDef', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='node', full_name='tensorflow.GraphDef.node', index=0, 41 | number=1, type=11, cpp_type=10, label=3, 42 | has_default_value=False, default_value=[], 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None), 46 | _descriptor.FieldDescriptor( 47 | name='library', full_name='tensorflow.GraphDef.library', index=1, 48 | number=2, type=11, cpp_type=10, label=1, 49 | has_default_value=False, default_value=None, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None), 53 | ], 54 | extensions=[ 55 | ], 56 | nested_types=[], 57 | enum_types=[ 58 | ], 59 | options=None, 60 | is_extendable=False, 61 | syntax='proto3', 62 | extension_ranges=[], 63 | oneofs=[ 64 | ], 65 | serialized_start=139, 66 | serialized_end=233, 67 | ) 68 | 69 | 70 | _NODEDEF_ATTRENTRY = _descriptor.Descriptor( 71 | name='AttrEntry', 72 | full_name='tensorflow.NodeDef.AttrEntry', 73 | filename=None, 74 | file=DESCRIPTOR, 75 | containing_type=None, 76 | fields=[ 77 | _descriptor.FieldDescriptor( 78 | name='key', full_name='tensorflow.NodeDef.AttrEntry.key', index=0, 79 | number=1, type=9, cpp_type=9, label=1, 80 | has_default_value=False, default_value=_b("").decode('utf-8'), 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None), 84 | _descriptor.FieldDescriptor( 85 | name='value', 
full_name='tensorflow.NodeDef.AttrEntry.value', index=1, 86 | number=2, type=11, cpp_type=10, label=1, 87 | has_default_value=False, default_value=None, 88 | message_type=None, enum_type=None, containing_type=None, 89 | is_extension=False, extension_scope=None, 90 | options=None), 91 | ], 92 | extensions=[ 93 | ], 94 | nested_types=[], 95 | enum_types=[ 96 | ], 97 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 98 | is_extendable=False, 99 | syntax='proto3', 100 | extension_ranges=[], 101 | oneofs=[ 102 | ], 103 | serialized_start=349, 104 | serialized_end=415, 105 | ) 106 | 107 | _NODEDEF = _descriptor.Descriptor( 108 | name='NodeDef', 109 | full_name='tensorflow.NodeDef', 110 | filename=None, 111 | file=DESCRIPTOR, 112 | containing_type=None, 113 | fields=[ 114 | _descriptor.FieldDescriptor( 115 | name='name', full_name='tensorflow.NodeDef.name', index=0, 116 | number=1, type=9, cpp_type=9, label=1, 117 | has_default_value=False, default_value=_b("").decode('utf-8'), 118 | message_type=None, enum_type=None, containing_type=None, 119 | is_extension=False, extension_scope=None, 120 | options=None), 121 | _descriptor.FieldDescriptor( 122 | name='op', full_name='tensorflow.NodeDef.op', index=1, 123 | number=2, type=9, cpp_type=9, label=1, 124 | has_default_value=False, default_value=_b("").decode('utf-8'), 125 | message_type=None, enum_type=None, containing_type=None, 126 | is_extension=False, extension_scope=None, 127 | options=None), 128 | _descriptor.FieldDescriptor( 129 | name='input', full_name='tensorflow.NodeDef.input', index=2, 130 | number=3, type=9, cpp_type=9, label=3, 131 | has_default_value=False, default_value=[], 132 | message_type=None, enum_type=None, containing_type=None, 133 | is_extension=False, extension_scope=None, 134 | options=None), 135 | _descriptor.FieldDescriptor( 136 | name='device', full_name='tensorflow.NodeDef.device', index=3, 137 | number=4, type=9, cpp_type=9, label=1, 138 | has_default_value=False, default_value=_b("").decode('utf-8'), 139 | message_type=None, enum_type=None, containing_type=None, 140 | is_extension=False, extension_scope=None, 141 | options=None), 142 | _descriptor.FieldDescriptor( 143 | name='attr', full_name='tensorflow.NodeDef.attr', index=4, 144 | number=5, type=11, cpp_type=10, label=3, 145 | has_default_value=False, default_value=[], 146 | message_type=None, enum_type=None, containing_type=None, 147 | is_extension=False, extension_scope=None, 148 | options=None), 149 | ], 150 | extensions=[ 151 | ], 152 | nested_types=[_NODEDEF_ATTRENTRY, ], 153 | enum_types=[ 154 | ], 155 | options=None, 156 | is_extendable=False, 157 | syntax='proto3', 158 | extension_ranges=[], 159 | oneofs=[ 160 | ], 161 | serialized_start=236, 162 | serialized_end=415, 163 | ) 164 | 165 | _GRAPHDEF.fields_by_name['node'].message_type = _NODEDEF 166 | _GRAPHDEF.fields_by_name['library'].message_type = tensorflow_dot_core_dot_framework_dot_function__pb2._FUNCTIONDEFLIBRARY 167 | _NODEDEF_ATTRENTRY.fields_by_name['value'].message_type = tensorflow_dot_core_dot_framework_dot_attr__value__pb2._ATTRVALUE 168 | _NODEDEF_ATTRENTRY.containing_type = _NODEDEF 169 | _NODEDEF.fields_by_name['attr'].message_type = _NODEDEF_ATTRENTRY 170 | DESCRIPTOR.message_types_by_name['GraphDef'] = _GRAPHDEF 171 | DESCRIPTOR.message_types_by_name['NodeDef'] = _NODEDEF 172 | 173 | GraphDef = _reflection.GeneratedProtocolMessageType('GraphDef', (_message.Message,), dict( 174 | DESCRIPTOR = _GRAPHDEF, 175 | __module__ = 
'tensorflow.core.framework.graph_pb2' 176 | # @@protoc_insertion_point(class_scope:tensorflow.GraphDef) 177 | )) 178 | _sym_db.RegisterMessage(GraphDef) 179 | 180 | NodeDef = _reflection.GeneratedProtocolMessageType('NodeDef', (_message.Message,), dict( 181 | 182 | AttrEntry = _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), dict( 183 | DESCRIPTOR = _NODEDEF_ATTRENTRY, 184 | __module__ = 'tensorflow.core.framework.graph_pb2' 185 | # @@protoc_insertion_point(class_scope:tensorflow.NodeDef.AttrEntry) 186 | )) 187 | , 188 | DESCRIPTOR = _NODEDEF, 189 | __module__ = 'tensorflow.core.framework.graph_pb2' 190 | # @@protoc_insertion_point(class_scope:tensorflow.NodeDef) 191 | )) 192 | _sym_db.RegisterMessage(NodeDef) 193 | _sym_db.RegisterMessage(NodeDef.AttrEntry) 194 | 195 | 196 | _NODEDEF_ATTRENTRY.has_options = True 197 | _NODEDEF_ATTRENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 198 | # @@protoc_insertion_point(module_scope) 199 | -------------------------------------------------------------------------------- /stanza/research/templates/third-party/tensorflow/core/framework/tensor_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: tensorflow/core/framework/tensor.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from tensorflow.core.framework import tensor_shape_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__shape__pb2 17 | from tensorflow.core.framework import types_pb2 as tensorflow_dot_core_dot_framework_dot_types__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='tensorflow/core/framework/tensor.proto', 22 | package='tensorflow', 23 | syntax='proto3', 24 | serialized_pb=_b('\n&tensorflow/core/framework/tensor.proto\x12\ntensorflow\x1a,tensorflow/core/framework/tensor_shape.proto\x1a%tensorflow/core/framework/types.proto\"\xb5\x02\n\x0bTensorProto\x12#\n\x05\x64type\x18\x01 \x01(\x0e\x32\x14.tensorflow.DataType\x12\x32\n\x0ctensor_shape\x18\x02 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProto\x12\x16\n\x0eversion_number\x18\x03 \x01(\x05\x12\x16\n\x0etensor_content\x18\x04 \x01(\x0c\x12\x15\n\tfloat_val\x18\x05 \x03(\x02\x42\x02\x10\x01\x12\x16\n\ndouble_val\x18\x06 \x03(\x01\x42\x02\x10\x01\x12\x13\n\x07int_val\x18\x07 \x03(\x05\x42\x02\x10\x01\x12\x12\n\nstring_val\x18\x08 \x03(\x0c\x12\x18\n\x0cscomplex_val\x18\t \x03(\x02\x42\x02\x10\x01\x12\x15\n\tint64_val\x18\n \x03(\x03\x42\x02\x10\x01\x12\x14\n\x08\x62ool_val\x18\x0b \x03(\x08\x42\x02\x10\x01\x62\x06proto3') 25 | , 26 | dependencies=[tensorflow_dot_core_dot_framework_dot_tensor__shape__pb2.DESCRIPTOR,tensorflow_dot_core_dot_framework_dot_types__pb2.DESCRIPTOR,]) 27 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 28 | 29 | 30 | 31 | 32 | _TENSORPROTO = _descriptor.Descriptor( 33 | name='TensorProto', 34 | full_name='tensorflow.TensorProto', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='dtype', 
full_name='tensorflow.TensorProto.dtype', index=0, 41 | number=1, type=14, cpp_type=8, label=1, 42 | has_default_value=False, default_value=0, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None), 46 | _descriptor.FieldDescriptor( 47 | name='tensor_shape', full_name='tensorflow.TensorProto.tensor_shape', index=1, 48 | number=2, type=11, cpp_type=10, label=1, 49 | has_default_value=False, default_value=None, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None), 53 | _descriptor.FieldDescriptor( 54 | name='version_number', full_name='tensorflow.TensorProto.version_number', index=2, 55 | number=3, type=5, cpp_type=1, label=1, 56 | has_default_value=False, default_value=0, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | options=None), 60 | _descriptor.FieldDescriptor( 61 | name='tensor_content', full_name='tensorflow.TensorProto.tensor_content', index=3, 62 | number=4, type=12, cpp_type=9, label=1, 63 | has_default_value=False, default_value=_b(""), 64 | message_type=None, enum_type=None, containing_type=None, 65 | is_extension=False, extension_scope=None, 66 | options=None), 67 | _descriptor.FieldDescriptor( 68 | name='float_val', full_name='tensorflow.TensorProto.float_val', index=4, 69 | number=5, type=2, cpp_type=6, label=3, 70 | has_default_value=False, default_value=[], 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 74 | _descriptor.FieldDescriptor( 75 | name='double_val', full_name='tensorflow.TensorProto.double_val', index=5, 76 | number=6, type=1, cpp_type=5, label=3, 77 | has_default_value=False, default_value=[], 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 81 | _descriptor.FieldDescriptor( 82 | name='int_val', full_name='tensorflow.TensorProto.int_val', index=6, 83 | number=7, type=5, cpp_type=1, label=3, 84 | has_default_value=False, default_value=[], 85 | message_type=None, enum_type=None, containing_type=None, 86 | is_extension=False, extension_scope=None, 87 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 88 | _descriptor.FieldDescriptor( 89 | name='string_val', full_name='tensorflow.TensorProto.string_val', index=7, 90 | number=8, type=12, cpp_type=9, label=3, 91 | has_default_value=False, default_value=[], 92 | message_type=None, enum_type=None, containing_type=None, 93 | is_extension=False, extension_scope=None, 94 | options=None), 95 | _descriptor.FieldDescriptor( 96 | name='scomplex_val', full_name='tensorflow.TensorProto.scomplex_val', index=8, 97 | number=9, type=2, cpp_type=6, label=3, 98 | has_default_value=False, default_value=[], 99 | message_type=None, enum_type=None, containing_type=None, 100 | is_extension=False, extension_scope=None, 101 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 102 | _descriptor.FieldDescriptor( 103 | name='int64_val', full_name='tensorflow.TensorProto.int64_val', index=9, 104 | number=10, type=3, cpp_type=2, label=3, 105 | has_default_value=False, default_value=[], 106 | message_type=None, enum_type=None, containing_type=None, 107 | is_extension=False, 
extension_scope=None, 108 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 109 | _descriptor.FieldDescriptor( 110 | name='bool_val', full_name='tensorflow.TensorProto.bool_val', index=10, 111 | number=11, type=8, cpp_type=7, label=3, 112 | has_default_value=False, default_value=[], 113 | message_type=None, enum_type=None, containing_type=None, 114 | is_extension=False, extension_scope=None, 115 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 116 | ], 117 | extensions=[ 118 | ], 119 | nested_types=[], 120 | enum_types=[ 121 | ], 122 | options=None, 123 | is_extendable=False, 124 | syntax='proto3', 125 | extension_ranges=[], 126 | oneofs=[ 127 | ], 128 | serialized_start=140, 129 | serialized_end=449, 130 | ) 131 | 132 | _TENSORPROTO.fields_by_name['dtype'].enum_type = tensorflow_dot_core_dot_framework_dot_types__pb2._DATATYPE 133 | _TENSORPROTO.fields_by_name['tensor_shape'].message_type = tensorflow_dot_core_dot_framework_dot_tensor__shape__pb2._TENSORSHAPEPROTO 134 | DESCRIPTOR.message_types_by_name['TensorProto'] = _TENSORPROTO 135 | 136 | TensorProto = _reflection.GeneratedProtocolMessageType('TensorProto', (_message.Message,), dict( 137 | DESCRIPTOR = _TENSORPROTO, 138 | __module__ = 'tensorflow.core.framework.tensor_pb2' 139 | # @@protoc_insertion_point(class_scope:tensorflow.TensorProto) 140 | )) 141 | _sym_db.RegisterMessage(TensorProto) 142 | 143 | 144 | _TENSORPROTO.fields_by_name['float_val'].has_options = True 145 | _TENSORPROTO.fields_by_name['float_val']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 146 | _TENSORPROTO.fields_by_name['double_val'].has_options = True 147 | _TENSORPROTO.fields_by_name['double_val']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 148 | _TENSORPROTO.fields_by_name['int_val'].has_options = True 149 | _TENSORPROTO.fields_by_name['int_val']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 150 | _TENSORPROTO.fields_by_name['scomplex_val'].has_options = True 151 | _TENSORPROTO.fields_by_name['scomplex_val']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 152 | _TENSORPROTO.fields_by_name['int64_val'].has_options = True 153 | _TENSORPROTO.fields_by_name['int64_val']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 154 | _TENSORPROTO.fields_by_name['bool_val'].has_options = True 155 | _TENSORPROTO.fields_by_name['bool_val']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 156 | # @@protoc_insertion_point(module_scope) 157 | -------------------------------------------------------------------------------- /stanza/research/templates/third-party/tensorflow/core/framework/tensor_shape_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: tensorflow/core/framework/tensor_shape.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='tensorflow/core/framework/tensor_shape.proto', 20 | package='tensorflow', 21 | syntax='proto3', 22 | serialized_pb=_b('\n,tensorflow/core/framework/tensor_shape.proto\x12\ntensorflow\"d\n\x10TensorShapeProto\x12-\n\x03\x64im\x18\x02 \x03(\x0b\x32 .tensorflow.TensorShapeProto.Dim\x1a!\n\x03\x44im\x12\x0c\n\x04size\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\tb\x06proto3') 23 | ) 24 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 25 | 26 | 27 | 28 | 29 | _TENSORSHAPEPROTO_DIM = _descriptor.Descriptor( 30 | name='Dim', 31 | full_name='tensorflow.TensorShapeProto.Dim', 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | _descriptor.FieldDescriptor( 37 | name='size', full_name='tensorflow.TensorShapeProto.Dim.size', index=0, 38 | number=1, type=3, cpp_type=2, label=1, 39 | has_default_value=False, default_value=0, 40 | message_type=None, enum_type=None, containing_type=None, 41 | is_extension=False, extension_scope=None, 42 | options=None), 43 | _descriptor.FieldDescriptor( 44 | name='name', full_name='tensorflow.TensorShapeProto.Dim.name', index=1, 45 | number=2, type=9, cpp_type=9, label=1, 46 | has_default_value=False, default_value=_b("").decode('utf-8'), 47 | message_type=None, enum_type=None, containing_type=None, 48 | is_extension=False, extension_scope=None, 49 | options=None), 50 | ], 51 | extensions=[ 52 | ], 53 | nested_types=[], 54 | enum_types=[ 55 | ], 56 | options=None, 57 | is_extendable=False, 58 | syntax='proto3', 59 | extension_ranges=[], 60 | oneofs=[ 61 | ], 62 | serialized_start=127, 63 | serialized_end=160, 64 | ) 65 | 66 | _TENSORSHAPEPROTO = _descriptor.Descriptor( 67 | name='TensorShapeProto', 68 | full_name='tensorflow.TensorShapeProto', 69 | filename=None, 70 | file=DESCRIPTOR, 71 | containing_type=None, 72 | fields=[ 73 | _descriptor.FieldDescriptor( 74 | name='dim', full_name='tensorflow.TensorShapeProto.dim', index=0, 75 | number=2, type=11, cpp_type=10, label=3, 76 | has_default_value=False, default_value=[], 77 | message_type=None, enum_type=None, containing_type=None, 78 | is_extension=False, extension_scope=None, 79 | options=None), 80 | ], 81 | extensions=[ 82 | ], 83 | nested_types=[_TENSORSHAPEPROTO_DIM, ], 84 | enum_types=[ 85 | ], 86 | options=None, 87 | is_extendable=False, 88 | syntax='proto3', 89 | extension_ranges=[], 90 | oneofs=[ 91 | ], 92 | serialized_start=60, 93 | serialized_end=160, 94 | ) 95 | 96 | _TENSORSHAPEPROTO_DIM.containing_type = _TENSORSHAPEPROTO 97 | _TENSORSHAPEPROTO.fields_by_name['dim'].message_type = _TENSORSHAPEPROTO_DIM 98 | DESCRIPTOR.message_types_by_name['TensorShapeProto'] = _TENSORSHAPEPROTO 99 | 100 | TensorShapeProto = _reflection.GeneratedProtocolMessageType('TensorShapeProto', (_message.Message,), dict( 101 | 102 | Dim = _reflection.GeneratedProtocolMessageType('Dim', (_message.Message,), dict( 103 | DESCRIPTOR = _TENSORSHAPEPROTO_DIM, 104 | __module__ = 
'tensorflow.core.framework.tensor_shape_pb2' 105 | # @@protoc_insertion_point(class_scope:tensorflow.TensorShapeProto.Dim) 106 | )) 107 | , 108 | DESCRIPTOR = _TENSORSHAPEPROTO, 109 | __module__ = 'tensorflow.core.framework.tensor_shape_pb2' 110 | # @@protoc_insertion_point(class_scope:tensorflow.TensorShapeProto) 111 | )) 112 | _sym_db.RegisterMessage(TensorShapeProto) 113 | _sym_db.RegisterMessage(TensorShapeProto.Dim) 114 | 115 | 116 | # @@protoc_insertion_point(module_scope) 117 | -------------------------------------------------------------------------------- /stanza/research/templates/third-party/tensorflow/core/framework/types_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: tensorflow/core/framework/types.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='tensorflow/core/framework/types.proto', 21 | package='tensorflow', 22 | syntax='proto3', 23 | serialized_pb=_b('\n%tensorflow/core/framework/types.proto\x12\ntensorflow*\xec\x03\n\x08\x44\x61taType\x12\x0e\n\nDT_INVALID\x10\x00\x12\x0c\n\x08\x44T_FLOAT\x10\x01\x12\r\n\tDT_DOUBLE\x10\x02\x12\x0c\n\x08\x44T_INT32\x10\x03\x12\x0c\n\x08\x44T_UINT8\x10\x04\x12\x0c\n\x08\x44T_INT16\x10\x05\x12\x0b\n\x07\x44T_INT8\x10\x06\x12\r\n\tDT_STRING\x10\x07\x12\x10\n\x0c\x44T_COMPLEX64\x10\x08\x12\x0c\n\x08\x44T_INT64\x10\t\x12\x0b\n\x07\x44T_BOOL\x10\n\x12\x0c\n\x08\x44T_QINT8\x10\x0b\x12\r\n\tDT_QUINT8\x10\x0c\x12\r\n\tDT_QINT32\x10\r\x12\x0f\n\x0b\x44T_BFLOAT16\x10\x0e\x12\x10\n\x0c\x44T_FLOAT_REF\x10\x65\x12\x11\n\rDT_DOUBLE_REF\x10\x66\x12\x10\n\x0c\x44T_INT32_REF\x10g\x12\x10\n\x0c\x44T_UINT8_REF\x10h\x12\x10\n\x0c\x44T_INT16_REF\x10i\x12\x0f\n\x0b\x44T_INT8_REF\x10j\x12\x11\n\rDT_STRING_REF\x10k\x12\x14\n\x10\x44T_COMPLEX64_REF\x10l\x12\x10\n\x0c\x44T_INT64_REF\x10m\x12\x0f\n\x0b\x44T_BOOL_REF\x10n\x12\x10\n\x0c\x44T_QINT8_REF\x10o\x12\x11\n\rDT_QUINT8_REF\x10p\x12\x11\n\rDT_QINT32_REF\x10q\x12\x13\n\x0f\x44T_BFLOAT16_REF\x10rb\x06proto3') 24 | ) 25 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 26 | 27 | _DATATYPE = _descriptor.EnumDescriptor( 28 | name='DataType', 29 | full_name='tensorflow.DataType', 30 | filename=None, 31 | file=DESCRIPTOR, 32 | values=[ 33 | _descriptor.EnumValueDescriptor( 34 | name='DT_INVALID', index=0, number=0, 35 | options=None, 36 | type=None), 37 | _descriptor.EnumValueDescriptor( 38 | name='DT_FLOAT', index=1, number=1, 39 | options=None, 40 | type=None), 41 | _descriptor.EnumValueDescriptor( 42 | name='DT_DOUBLE', index=2, number=2, 43 | options=None, 44 | type=None), 45 | _descriptor.EnumValueDescriptor( 46 | name='DT_INT32', index=3, number=3, 47 | options=None, 48 | type=None), 49 | _descriptor.EnumValueDescriptor( 50 | name='DT_UINT8', index=4, number=4, 51 | options=None, 52 | type=None), 53 | _descriptor.EnumValueDescriptor( 54 | name='DT_INT16', index=5, number=5, 55 | options=None, 56 | type=None), 57 | _descriptor.EnumValueDescriptor( 
58 | name='DT_INT8', index=6, number=6, 59 | options=None, 60 | type=None), 61 | _descriptor.EnumValueDescriptor( 62 | name='DT_STRING', index=7, number=7, 63 | options=None, 64 | type=None), 65 | _descriptor.EnumValueDescriptor( 66 | name='DT_COMPLEX64', index=8, number=8, 67 | options=None, 68 | type=None), 69 | _descriptor.EnumValueDescriptor( 70 | name='DT_INT64', index=9, number=9, 71 | options=None, 72 | type=None), 73 | _descriptor.EnumValueDescriptor( 74 | name='DT_BOOL', index=10, number=10, 75 | options=None, 76 | type=None), 77 | _descriptor.EnumValueDescriptor( 78 | name='DT_QINT8', index=11, number=11, 79 | options=None, 80 | type=None), 81 | _descriptor.EnumValueDescriptor( 82 | name='DT_QUINT8', index=12, number=12, 83 | options=None, 84 | type=None), 85 | _descriptor.EnumValueDescriptor( 86 | name='DT_QINT32', index=13, number=13, 87 | options=None, 88 | type=None), 89 | _descriptor.EnumValueDescriptor( 90 | name='DT_BFLOAT16', index=14, number=14, 91 | options=None, 92 | type=None), 93 | _descriptor.EnumValueDescriptor( 94 | name='DT_FLOAT_REF', index=15, number=101, 95 | options=None, 96 | type=None), 97 | _descriptor.EnumValueDescriptor( 98 | name='DT_DOUBLE_REF', index=16, number=102, 99 | options=None, 100 | type=None), 101 | _descriptor.EnumValueDescriptor( 102 | name='DT_INT32_REF', index=17, number=103, 103 | options=None, 104 | type=None), 105 | _descriptor.EnumValueDescriptor( 106 | name='DT_UINT8_REF', index=18, number=104, 107 | options=None, 108 | type=None), 109 | _descriptor.EnumValueDescriptor( 110 | name='DT_INT16_REF', index=19, number=105, 111 | options=None, 112 | type=None), 113 | _descriptor.EnumValueDescriptor( 114 | name='DT_INT8_REF', index=20, number=106, 115 | options=None, 116 | type=None), 117 | _descriptor.EnumValueDescriptor( 118 | name='DT_STRING_REF', index=21, number=107, 119 | options=None, 120 | type=None), 121 | _descriptor.EnumValueDescriptor( 122 | name='DT_COMPLEX64_REF', index=22, number=108, 123 | options=None, 124 | type=None), 125 | _descriptor.EnumValueDescriptor( 126 | name='DT_INT64_REF', index=23, number=109, 127 | options=None, 128 | type=None), 129 | _descriptor.EnumValueDescriptor( 130 | name='DT_BOOL_REF', index=24, number=110, 131 | options=None, 132 | type=None), 133 | _descriptor.EnumValueDescriptor( 134 | name='DT_QINT8_REF', index=25, number=111, 135 | options=None, 136 | type=None), 137 | _descriptor.EnumValueDescriptor( 138 | name='DT_QUINT8_REF', index=26, number=112, 139 | options=None, 140 | type=None), 141 | _descriptor.EnumValueDescriptor( 142 | name='DT_QINT32_REF', index=27, number=113, 143 | options=None, 144 | type=None), 145 | _descriptor.EnumValueDescriptor( 146 | name='DT_BFLOAT16_REF', index=28, number=114, 147 | options=None, 148 | type=None), 149 | ], 150 | containing_type=None, 151 | options=None, 152 | serialized_start=54, 153 | serialized_end=546, 154 | ) 155 | _sym_db.RegisterEnumDescriptor(_DATATYPE) 156 | 157 | DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE) 158 | DT_INVALID = 0 159 | DT_FLOAT = 1 160 | DT_DOUBLE = 2 161 | DT_INT32 = 3 162 | DT_UINT8 = 4 163 | DT_INT16 = 5 164 | DT_INT8 = 6 165 | DT_STRING = 7 166 | DT_COMPLEX64 = 8 167 | DT_INT64 = 9 168 | DT_BOOL = 10 169 | DT_QINT8 = 11 170 | DT_QUINT8 = 12 171 | DT_QINT32 = 13 172 | DT_BFLOAT16 = 14 173 | DT_FLOAT_REF = 101 174 | DT_DOUBLE_REF = 102 175 | DT_INT32_REF = 103 176 | DT_UINT8_REF = 104 177 | DT_INT16_REF = 105 178 | DT_INT8_REF = 106 179 | DT_STRING_REF = 107 180 | DT_COMPLEX64_REF = 108 181 | DT_INT64_REF = 109 182 | 
DT_BOOL_REF = 110 183 | DT_QINT8_REF = 111 184 | DT_QUINT8_REF = 112 185 | DT_QINT32_REF = 113 186 | DT_BFLOAT16_REF = 114 187 | 188 | 189 | DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE 190 | 191 | 192 | # @@protoc_insertion_point(module_scope) 193 | -------------------------------------------------------------------------------- /stanza/research/templates/third-party/tensorflow/core/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/stanza-old/920c55d8eaa1e7105971059c66eb448a74c100d6/stanza/research/templates/third-party/tensorflow/core/util/__init__.py -------------------------------------------------------------------------------- /stanza/research/templates/third-party/tensorflow/core/util/event_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: tensorflow/core/util/event.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from tensorflow.core.framework import graph_pb2 as tensorflow_dot_core_dot_framework_dot_graph__pb2 17 | from tensorflow.core.framework import summary_pb2 as tensorflow_dot_core_dot_framework_dot_summary__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='tensorflow/core/util/event.proto', 22 | package='tensorflow', 23 | syntax='proto3', 24 | serialized_pb=_b('\n tensorflow/core/util/event.proto\x12\ntensorflow\x1a%tensorflow/core/framework/graph.proto\x1a\'tensorflow/core/framework/summary.proto\"\x9b\x01\n\x05\x45vent\x12\x11\n\twall_time\x18\x01 \x01(\x01\x12\x0c\n\x04step\x18\x02 \x01(\x03\x12\x16\n\x0c\x66ile_version\x18\x03 \x01(\tH\x00\x12)\n\tgraph_def\x18\x04 \x01(\x0b\x32\x14.tensorflow.GraphDefH\x00\x12&\n\x07summary\x18\x05 \x01(\x0b\x32\x13.tensorflow.SummaryH\x00\x42\x06\n\x04whatb\x06proto3') 25 | , 26 | dependencies=[tensorflow_dot_core_dot_framework_dot_graph__pb2.DESCRIPTOR,tensorflow_dot_core_dot_framework_dot_summary__pb2.DESCRIPTOR,]) 27 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 28 | 29 | 30 | 31 | 32 | _EVENT = _descriptor.Descriptor( 33 | name='Event', 34 | full_name='tensorflow.Event', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='wall_time', full_name='tensorflow.Event.wall_time', index=0, 41 | number=1, type=1, cpp_type=5, label=1, 42 | has_default_value=False, default_value=0, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None), 46 | _descriptor.FieldDescriptor( 47 | name='step', full_name='tensorflow.Event.step', index=1, 48 | number=2, type=3, cpp_type=2, label=1, 49 | has_default_value=False, default_value=0, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None), 53 | _descriptor.FieldDescriptor( 54 | name='file_version', full_name='tensorflow.Event.file_version', index=2, 55 | number=3, type=9, cpp_type=9, label=1, 56 | has_default_value=False, 
default_value=_b("").decode('utf-8'), 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | options=None), 60 | _descriptor.FieldDescriptor( 61 | name='graph_def', full_name='tensorflow.Event.graph_def', index=3, 62 | number=4, type=11, cpp_type=10, label=1, 63 | has_default_value=False, default_value=None, 64 | message_type=None, enum_type=None, containing_type=None, 65 | is_extension=False, extension_scope=None, 66 | options=None), 67 | _descriptor.FieldDescriptor( 68 | name='summary', full_name='tensorflow.Event.summary', index=4, 69 | number=5, type=11, cpp_type=10, label=1, 70 | has_default_value=False, default_value=None, 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | options=None), 74 | ], 75 | extensions=[ 76 | ], 77 | nested_types=[], 78 | enum_types=[ 79 | ], 80 | options=None, 81 | is_extendable=False, 82 | syntax='proto3', 83 | extension_ranges=[], 84 | oneofs=[ 85 | _descriptor.OneofDescriptor( 86 | name='what', full_name='tensorflow.Event.what', 87 | index=0, containing_type=None, fields=[]), 88 | ], 89 | serialized_start=129, 90 | serialized_end=284, 91 | ) 92 | 93 | _EVENT.fields_by_name['graph_def'].message_type = tensorflow_dot_core_dot_framework_dot_graph__pb2._GRAPHDEF 94 | _EVENT.fields_by_name['summary'].message_type = tensorflow_dot_core_dot_framework_dot_summary__pb2._SUMMARY 95 | _EVENT.oneofs_by_name['what'].fields.append( 96 | _EVENT.fields_by_name['file_version']) 97 | _EVENT.fields_by_name['file_version'].containing_oneof = _EVENT.oneofs_by_name['what'] 98 | _EVENT.oneofs_by_name['what'].fields.append( 99 | _EVENT.fields_by_name['graph_def']) 100 | _EVENT.fields_by_name['graph_def'].containing_oneof = _EVENT.oneofs_by_name['what'] 101 | _EVENT.oneofs_by_name['what'].fields.append( 102 | _EVENT.fields_by_name['summary']) 103 | _EVENT.fields_by_name['summary'].containing_oneof = _EVENT.oneofs_by_name['what'] 104 | DESCRIPTOR.message_types_by_name['Event'] = _EVENT 105 | 106 | Event = _reflection.GeneratedProtocolMessageType('Event', (_message.Message,), dict( 107 | DESCRIPTOR = _EVENT, 108 | __module__ = 'tensorflow.core.util.event_pb2' 109 | # @@protoc_insertion_point(class_scope:tensorflow.Event) 110 | )) 111 | _sym_db.RegisterMessage(Event) 112 | 113 | 114 | # @@protoc_insertion_point(module_scope) 115 | -------------------------------------------------------------------------------- /stanza/research/templates/wrapper.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | 3 | from stanza.research.instance import Instance 4 | 5 | DEFAULT_MODEL = 'models/__INSERT_MODEL_HERE__/model.pkl' 6 | 7 | 8 | class Wrapper(object): 9 | ''' 10 | A wrapper class for pickled Learners. 11 | ''' 12 | def __init__(self, picklefile=None): 13 | ''' 14 | :param file picklefile: An open file-like object from which to 15 | load the model. Can be produced either from a normal experiment 16 | run or a quickpickle.py run. If `None`, try to load the default 17 | quickpickle file (this is less future-proof than the normal 18 | experiment-produced pickle files). 
19 | ''' 20 | if picklefile is None: 21 | with open(DEFAULT_MODEL, 'rb') as infile: 22 | self.model = pickle.load(infile) 23 | else: 24 | self.model = pickle.load(picklefile) 25 | self.model.options.verbosity = 0 26 | 27 | def process(self, input): 28 | return self.process_all([input])[0] 29 | 30 | def process_all(self, inputs): 31 | insts = [Instance(i) for i in inputs] 32 | return self.model.predict(insts, verbosity=0) 33 | -------------------------------------------------------------------------------- /stanza/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools to facilitate NLP and ML experiments. 3 | """ 4 | 5 | from stanza.text.dataset import Dataset 6 | from stanza.text.vocab import Vocab, SennaVocab, GloveVocab 7 | from stanza.text.utils import to_unicode -------------------------------------------------------------------------------- /stanza/text/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset module for managing text datasets. 3 | """ 4 | __author__ = 'victor' 5 | from collections import OrderedDict 6 | import random 7 | import numpy as np 8 | 9 | 10 | class InvalidFieldsException(Exception): 11 | pass 12 | 13 | 14 | class Dataset(object): 15 | """ 16 | Generic Dataset object that encapsulates a list of instances. 17 | 18 | The dataset stores the instances in an ordered dictionary of fields. 19 | Each field maps to a list, the ith element of the list for field 'foo' corresponds to the attribute 'foo' for the ith instance in the dataset. 20 | 21 | The dataset object supports indexing, iterating, slicing (eg. for iterating over batches), shuffling, 22 | conversion to/from CONLL format, among others. 23 | 24 | Example: 25 | 26 | .. code-block:: python 27 | 28 | d = Dataset({'Name': ['Alice', 'Bob', 'Carol', 'David', 'Ellen'], 'SSN': [1, 23, 45, 56, 7890]}) 29 | print(d) # Dataset(Name, SSN) 30 | print(d[2]) # OrderedDict([('SSN', 45), ('Name', 'Carol')]) 31 | print(d[1:3]) # OrderedDict([('SSN', [23, 45]), ('Name', ['Bob', 'Carol'])]) 32 | 33 | for e in d: 34 | print(e) # OrderedDict([('SSN', 1), ('Name', 'Alice')]) ... 35 | """ 36 | 37 | def __init__(self, fields): 38 | """ 39 | :param fields: An ordered dictionary in which a key is the name of an attribute and a value is a list of the values of the instances in the dataset. 40 | 41 | :return: A Dataset object 42 | """ 43 | self.fields = OrderedDict(fields) 44 | length = None 45 | length_field = None 46 | for name, d in fields.items(): 47 | if length is None: 48 | length = len(d) 49 | length_field = name 50 | else: 51 | if len(d) != length: 52 | raise InvalidFieldsException('field {} has length {} but field {} has length {}'.format(length_field, length, name, len(d))) 53 | 54 | def __len__(self): 55 | """ 56 | :return: The number of instances in the dataset. 57 | """ 58 | if len(self.fields) == 0: 59 | return 0 60 | return len(self.fields.values()[0]) 61 | 62 | def __repr__(self): 63 | return "{}({})".format(self.__class__.__name__, ', '.join(self.fields.keys())) 64 | 65 | @classmethod 66 | def load_conll(cls, fname): 67 | """ 68 | The CONLL file must have a tab delimited header, for example:: 69 | 70 | # description tags 71 | Alice 72 | Hello t1 73 | my t2 74 | name t3 75 | is t4 76 | alice t5 77 | 78 | Bob 79 | I'm t1 80 | bob t2 81 | 82 | Here, the fields are `description` and `tags`. 
The first instance has the label `Alice` and the 83 | description `['Hello', 'my', 'name', 'is', 'alice']` and the tags `['t1', 't2', 't3', 't4', 't5']`. 84 | The second instance has the label `Bob` and the description `["I'm", 'bob']` and the tags `['t1', 't2']`. 85 | 86 | :param fname: The CONLL formatted file from which to load the dataset 87 | 88 | :return: loaded Dataset instance 89 | """ 90 | def process_cache(cache, fields): 91 | cache = [l.split() for l in cache if l] 92 | if not cache: 93 | return None 94 | fields['label'].append(cache[0][0]) 95 | instance = {k: [] for k in fields if k != 'label'} 96 | for l in cache[1:]: 97 | for i, k in enumerate(fields): 98 | if k != 'label': 99 | instance[k].append(None if l[i] == '-' else l[i]) 100 | for k, v in instance.items(): 101 | fields[k].append(v) 102 | 103 | cache = [] 104 | 105 | with open(fname) as f: 106 | header = f.next().strip().split('\t') 107 | header[0] = header[0].lstrip('# ') 108 | fields = OrderedDict([(head, []) for head in header]) 109 | fields['label'] = [] 110 | for line in f: 111 | line = line.strip() 112 | if line: 113 | cache.append(line) 114 | else: 115 | # met empty line, process cache 116 | process_cache(cache, fields) 117 | cache = [] 118 | if cache: 119 | process_cache(cache, fields) 120 | return cls(fields) 121 | 122 | def write_conll(self, fname): 123 | """ 124 | Serializes the dataset in CONLL format to fname 125 | """ 126 | if 'label' not in self.fields: 127 | raise InvalidFieldsException("dataset is not in CONLL format: missing label field") 128 | 129 | def instance_to_conll(inst): 130 | tab = [v for k, v in inst.items() if k != 'label'] 131 | return '{}\n{}'.format(inst['label'], '\n'.join(['\t'.join(['-' if e is None else str(e) for e in row]) for row in zip(*tab)])) 132 | 133 | with open(fname, 'wb') as f: 134 | f.write('# {}'.format('\t'.join([k for k in self.fields if k != 'label']))) 135 | for i, d in enumerate(self): 136 | f.write('\n{}'.format(instance_to_conll(d))) 137 | if i != len(self) - 1: 138 | f.write('\n') 139 | 140 | def convert(self, converters, in_place=False): 141 | """ 142 | Applies transformations to the dataset. 143 | 144 | :param converters: A dictionary specifying the function to apply to each field. If a field is missing from the dictionary, then it will not be transformed. 145 | 146 | :param in_place: Whether to perform the transformation in place or create a new dataset instance 147 | 148 | :return: the transformed dataset instance 149 | """ 150 | dataset = self if in_place else self.__class__(OrderedDict([(name, data[:]) for name, data in self.fields.items()])) 151 | for name, convert in converters.items(): 152 | if name not in self.fields.keys(): 153 | raise InvalidFieldsException('Converter specified for non-existent field {}'.format(name)) 154 | for i, d in enumerate(dataset.fields[name]): 155 | dataset.fields[name][i] = convert(d) 156 | return dataset 157 | 158 | def shuffle(self): 159 | """ 160 | Re-indexes the dataset in random order 161 | 162 | :return: the shuffled dataset instance 163 | """ 164 | order = range(len(self)) 165 | random.shuffle(order) 166 | for name, data in self.fields.items(): 167 | reindexed = [] 168 | for _, i in enumerate(order): 169 | reindexed.append(data[i]) 170 | self.fields[name] = reindexed 171 | return self 172 | 173 | def __getitem__(self, item): 174 | """ 175 | :param item: An integer index or a slice (eg. 2, 1:, 1:5) 176 | 177 | :return: an ordered dictionary of the instance(s) at index/indices `item`. 
178 | """ 179 | return OrderedDict([(name, data[item]) for name, data in self.fields.items()]) 180 | 181 | def __setitem__(self, key, value): 182 | """ 183 | :param key: An integer index or a slice (eg. 2, 1:, 1:5) 184 | 185 | :param value: Sets the instances at index/indices `key` to the instances(s) `value` 186 | """ 187 | for name, data in self.fields.items(): 188 | if name not in value: 189 | raise InvalidFieldsException('field {} is missing in input data: {}'.format(name, value)) 190 | data[key] = value[name] 191 | 192 | def __iter__(self): 193 | """ 194 | :return: A iterator over the instances in the dataset 195 | """ 196 | for i in xrange(len(self)): 197 | yield self[i] 198 | 199 | def copy(self, keep_fields=None): 200 | """ 201 | :param keep_fields: if specified, then only the given fields will be kept 202 | :return: A deep copy of the dataset (each instance is copied). 203 | """ 204 | keep_fields = self.fields.keys() or keep_fields 205 | return self.__class__(OrderedDict([(name, data[:]) for name, data in self.fields.items() if name in keep_fields])) 206 | 207 | @classmethod 208 | def pad(cls, sequences, padding, pad_len=None): 209 | """ 210 | Pads a list of sequences such that they form a matrix. 211 | 212 | :param sequences: a list of sequences of varying lengths. 213 | :param padding: the value of padded cells. 214 | :param pad_len: the length of the maximum padded sequence. 215 | """ 216 | max_len = max([len(s) for s in sequences]) 217 | pad_len = pad_len or max_len 218 | assert pad_len >= max_len, 'pad_len {} must be greater or equal to the longest sequence {}'.format(pad_len, max_len) 219 | for i, s in enumerate(sequences): 220 | sequences[i] = [padding] * (pad_len - len(s)) + s 221 | return np.array(sequences) 222 | -------------------------------------------------------------------------------- /stanza/text/utils.py: -------------------------------------------------------------------------------- 1 | import six 2 | __author__ = ['kelvinguu'] 3 | 4 | 5 | def to_unicode(s): 6 | """Return the object as unicode (only matters for Python 2.x). 7 | 8 | If s is already Unicode, return s as is. 9 | Otherwise, assume that s is UTF-8 encoded, and convert to Unicode. 10 | 11 | :param (basestring) s: a str, unicode or other basestring object 12 | :return (unicode): the object as unicode 13 | """ 14 | if not isinstance(s, six.string_types): 15 | raise ValueError("{} must be str or unicode.".format(s)) 16 | if not isinstance(s, six.text_type): 17 | s = six.text_type(s, 'utf-8') 18 | return s 19 | -------------------------------------------------------------------------------- /stanza/unstable/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | import warnings 4 | 5 | warnings.warn('stanza.unstable has been renamed to stanza.research; the name unstable is deprecated.') 6 | 7 | sys.modules[__name__] = importlib.import_module('..research', __name__) 8 | -------------------------------------------------------------------------------- /stanza/util/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor' 2 | -------------------------------------------------------------------------------- /stanza/util/postgres.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities to use when interfacing with Postgres. 3 | - These utilities support the workflow wherein you store annoated 4 | sentences in a database. 
5 | """ 6 | __author__ = 'arunchaganty' 7 | import os 8 | import stanza 9 | import requests 10 | import logging 11 | 12 | def unescape_sql(inp): 13 | """ 14 | :param inp: an input string to be unescaped 15 | :return: return the unescaped version of the string. 16 | """ 17 | if inp.startswith('"') and inp.endswith('"'): 18 | inp = inp[1:-1] 19 | return inp.replace('""','"').replace('\\\\','\\') 20 | 21 | def parse_psql_array(inp): 22 | """ 23 | :param inp: a string encoding an array 24 | :return: the array of elements as represented by the input 25 | """ 26 | inp = unescape_sql(inp) 27 | # Strip '{' and '}' 28 | if inp.startswith("{") and inp.endswith("}"): 29 | inp = inp[1:-1] 30 | 31 | lst = [] 32 | elem = "" 33 | in_quotes, escaped = False, False 34 | 35 | for ch in inp: 36 | if escaped: 37 | elem += ch 38 | escaped = False 39 | elif ch == '"': 40 | in_quotes = not in_quotes 41 | escaped = False 42 | elif ch == '\\': 43 | escaped = True 44 | else: 45 | if in_quotes: 46 | elem += ch 47 | elif ch == ',': 48 | lst.append(elem) 49 | elem = "" 50 | else: 51 | elem += ch 52 | escaped = False 53 | if len(elem) > 0: 54 | lst.append(elem) 55 | return lst 56 | 57 | def test_parse_psql_array(): 58 | """ 59 | test case for parse_psql_array 60 | """ 61 | inp = '{Bond,was,set,at,$,"1,500",each,.}' 62 | lst = ["Bond", "was", "set", "at", "$", "1,500", "each","."] 63 | lst_ = parse_psql_array(inp) 64 | assert all([x == y for (x,y) in zip(lst, lst_)]) 65 | 66 | def escape_sql(inp): 67 | """ 68 | :param inp: an input string to be escaped 69 | :return: return the escaped version of the string. 70 | """ 71 | return '"' + inp.replace('"','""').replace('\\','\\\\') + '"' 72 | 73 | def to_psql_array(inp): 74 | """ 75 | :param inp: an array to be encoded. 76 | :return: a string encoding the array 77 | """ 78 | return "{" + ",".join(map(escape_sql, inp)) + "}" 79 | 80 | def test_to_psql_array(): 81 | """ 82 | Test for to_psql_array 83 | """ 84 | inp = ["Bond", "was", "set", "at", "$", "1,500", "each","."] 85 | out = '{"Bond","was","set","at","$","1,500","each","."}' 86 | out_ = to_psql_array(inp) 87 | assert out == out_ 88 | 89 | -------------------------------------------------------------------------------- /stanza/util/resource.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor' 2 | import os 3 | import stanza 4 | import requests 5 | import logging 6 | 7 | 8 | def get_from_url(url): 9 | """ 10 | :param url: url to download from 11 | :return: return the content at the url 12 | """ 13 | return requests.get(url).content 14 | 15 | 16 | def get_data_or_download(dir_name, file_name, url='', size='unknown'): 17 | """Returns the data. if the data hasn't been downloaded, then first download the data. 18 | 19 | :param dir_name: directory to look in 20 | :param file_name: file name to retrieve 21 | :param url: if the file is not found, then download it from this url 22 | :param size: the expected size 23 | :return: path to the requested file 24 | """ 25 | dname = os.path.join(stanza.DATA_DIR, dir_name) 26 | fname = os.path.join(dname, file_name) 27 | if not os.path.isdir(dname): 28 | assert url, 'Could not locate data {}, and url was not specified. Cannot retrieve data.'.format(dname) 29 | os.makedirs(dname) 30 | if not os.path.isfile(fname): 31 | assert url, 'Could not locate data {}, and url was not specified. Cannot retrieve data.'.format(fname) 32 | logging.warn('downloading from {}. This file could potentially be *very* large! 
Actual size ({})'.format(url, size)) 33 | with open(fname, 'wb') as f: 34 | f.write(get_from_url(url)) 35 | return fname 36 | -------------------------------------------------------------------------------- /stanza/util/unicode.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import six 3 | 4 | 5 | def uprint(s): 6 | if six.PY2: 7 | print(s.encode('utf-8')) 8 | else: 9 | print(s) 10 | 11 | 12 | def urepr(s): 13 | if six.PY2: 14 | return repr(s).decode('unicode_escape') 15 | else: 16 | return repr(s) 17 | 18 | 19 | def uopen(filename, *args, **kwargs): 20 | return codecs.open(filename, *args, encoding='utf-8', **kwargs) 21 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor' 2 | -------------------------------------------------------------------------------- /test/slow_tests/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor' 2 | -------------------------------------------------------------------------------- /test/slow_tests/text/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor' 2 | -------------------------------------------------------------------------------- /test/slow_tests/text/test_glove.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor' 2 | 3 | import numpy as np 4 | from unittest import TestCase 5 | from stanza.text.vocab import GloveVocab 6 | 7 | 8 | class TestGlove(TestCase): 9 | 10 | def test_get_embeddings(self): 11 | v = GloveVocab() 12 | v.add("!") 13 | e_exclamation = np.array([float(e) for e in """ 14 | -0.58402 0.39031 0.65282 -0.3403 0.19493 -0.83489 0.11929 -0.57291 -0.56844 0.72989 -0.56975 0.53436 -0.38034 0.22471 15 | 0.98031 -0.2966 0.126 0.55222 -0.62737 -0.082242 -0.085359 0.31515 0.96077 0.31986 0.87878 -1.5189 -1.7831 0.35639 16 | 0.9674 -1.5497 2.335 0.8494 -1.2371 1.0623 -1.4267 -0.49056 0.85465 -1.2878 0.60204 -0.35963 0.28586 -0.052162 17 | -0.50818 -0.63459 0.33889 0.28416 -0.2034 -1.2338 0.46715 0.78858 18 | """.split() if e]) 19 | E = v.get_embeddings(corpus='wikipedia_gigaword', n_dim=50) 20 | self.assertTrue(np.allclose(e_exclamation, E[v["!"]])) 21 | 22 | -------------------------------------------------------------------------------- /test/slow_tests/text/test_senna.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor' 2 | 3 | import numpy as np 4 | from unittest import TestCase 5 | from stanza.text.vocab import SennaVocab 6 | 7 | 8 | class TestSenna(TestCase): 9 | 10 | def test_get_embeddings(self): 11 | v = SennaVocab() 12 | v.add("!") 13 | E = v.get_embeddings() 14 | e_exclamation = np.array([float(e) for e in """ 15 | -1.03682 1.77856 -0.693547 1.5948 1.5799 0.859243 1.15221 -0.976317 0.745304 -0.494589 0.308086 0.25239 16 | -0.1976 1.26203 0.813864 -0.940734 -0.215163 0.11645 0.525697 1.95766 0.394232 1.27717 0.710788 -0.389351 17 | 0.161775 -0.106038 1.14148 0.607948 0.189781 -1.06022 0.280702 0.0251156 -0.198067 2.33027 0.408584 18 | 0.350751 -0.351293 1.77318 -0.723457 -0.13806 -1.47247 0.541779 -2.57005 -0.227714 -0.817816 -0.552209 19 | 0.360149 -0.10278 -0.36428 -0.64853 20 | """.split()]) 21 | self.assertTrue(np.allclose(e_exclamation, E[v["!"]])) 22 | 
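# A minimal usage sketch distilled from the two slow tests above (illustrative
# only; the word list is made up, and the corpus/dimension arguments are the
# same ones the GloVe test uses):
#
#     from stanza.text.vocab import GloveVocab
#
#     v = GloveVocab()
#     v.update(['the', 'quick', 'brown', 'fox'])
#     E = v.get_embeddings(corpus='wikipedia_gigaword', n_dim=50)
#     fox_vector = E[v['fox']]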
-------------------------------------------------------------------------------- /test/unit_tests/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/stanza-old/920c55d8eaa1e7105971059c66eb448a74c100d6/test/unit_tests/README.md -------------------------------------------------------------------------------- /test/unit_tests/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor' 2 | -------------------------------------------------------------------------------- /test/unit_tests/ml/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/stanza-old/920c55d8eaa1e7105971059c66eb448a74c100d6/test/unit_tests/ml/__init__.py -------------------------------------------------------------------------------- /test/unit_tests/ml/test_embeddings.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from stanza.ml.embeddings import Embeddings 4 | from stanza.text import Vocab 5 | import numpy as np 6 | from numpy.testing import assert_approx_equal 7 | 8 | 9 | @pytest.fixture 10 | def embeddings(): 11 | v = Vocab('unk') 12 | v.update('what a show'.split()) 13 | array = np.reshape(np.arange(12), (4, 3)) 14 | return Embeddings(array, v) 15 | 16 | 17 | @pytest.fixture 18 | def dict_embeddings(): 19 | return {'a': [6, 7, 8], 20 | 'show': [9, 10, 11], 21 | 'unk': [0, 1, 2], 22 | 'what': [3, 4, 5]} 23 | 24 | 25 | def test_to_dict(embeddings, dict_embeddings): 26 | d = embeddings.to_dict() 27 | assert d == dict_embeddings 28 | 29 | 30 | def test_from_dict(embeddings, dict_embeddings): 31 | emb = Embeddings.from_dict(dict_embeddings, 'unk') 32 | assert emb.to_dict() == dict_embeddings 33 | 34 | 35 | def test_get_item(embeddings): 36 | assert embeddings['what'].tolist() == [3, 4, 5] 37 | 38 | 39 | def test_inner_products(embeddings): 40 | query = np.array([3, 2, 1]) 41 | scores = embeddings.inner_products(query) 42 | correct = { 43 | 'a': 18 + 14 + 8, 44 | 'show': 27 + 20 + 11, 45 | 'unk': 2 + 2, 46 | 'what': 9 + 8 + 5, 47 | } 48 | assert scores == correct 49 | 50 | knn = embeddings.k_nearest(query, 3) 51 | assert knn == [('show', 58), ('a', 40), ('what', 22)] 52 | 53 | 54 | def test_k_nearest_approx(embeddings): 55 | # Code for calculating the correct cosine similarities. 
56 | # for i in range(len(array)): 57 | # print 1-scipy.spatial.distance.cosine(array[i,:], query) 58 | 59 | query = np.array([3, 2, 1]) 60 | knn = embeddings.k_nearest_approx(query, 3) 61 | correct = [('show', 0.89199106528525429), ('a', 0.87579576196887721), ('what', 0.83152184062029977)] 62 | assert len(knn) == len(correct) 63 | for (w1, s1), (w2, s2) in zip(knn, correct): 64 | assert w1 == w2 65 | assert_approx_equal(s1, s2) 66 | 67 | 68 | def test_subset(embeddings): 69 | sub = embeddings.subset(['a', 'what']) 70 | assert sub.to_dict() == {'a': [6, 7, 8], 'unk': [0, 1, 2], 'what': [3, 4, 5]} 71 | -------------------------------------------------------------------------------- /test/unit_tests/ml/test_tensorflow_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from stanza.ml.tensorflow_utils import labels_to_onehots 4 | from unittest import TestCase 5 | 6 | __author__ = 'kelvinguu' 7 | 8 | 9 | class TestTensorFlowUtils(TestCase): 10 | 11 | @classmethod 12 | def setUpClass(cls): 13 | cls.sess = tf.InteractiveSession() 14 | 15 | def test_labels_to_onehots(self): 16 | labels_list = [0, 1, 2, 3, 1] 17 | 18 | labels = tf.constant(labels_list, dtype=tf.int32) 19 | onehots = labels_to_onehots(labels, 5) 20 | result = onehots.eval() 21 | 22 | correct = np.zeros((5, 5)) 23 | for i in range(5): 24 | correct[i, labels_list[i]] = 1 25 | 26 | # self.assertTrue(False) 27 | self.assertTrue(np.all(result == correct)) -------------------------------------------------------------------------------- /test/unit_tests/monitoring/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor' 2 | -------------------------------------------------------------------------------- /test/unit_tests/monitoring/test_summary.py: -------------------------------------------------------------------------------- 1 | __author__ = 'wmonroe4' 2 | 3 | import numpy as np 4 | from unittest import TestCase 5 | import stanza.monitoring.summary as summary 6 | from stanza.research.mockfs import patcher 7 | 8 | 9 | class TestLargeFloats(TestCase): 10 | ''' 11 | Proto serialization breaks if floats exceed the maximum of a float64. 12 | Make sure summary.py converts these to inf to avoid crashes. 
13 | ''' 14 | def test_large_hist(self): 15 | fs = patcher('stanza.monitoring.summary', '/test') 16 | open = fs.start() 17 | 18 | writer = summary.SummaryWriter('/test/large_hist.tfevents') 19 | writer.log_histogram(1, 'bighist', np.array(1.0e39)) 20 | writer.flush() 21 | with open('/test/large_hist.tfevents', 'r') as infile: 22 | events = list(summary.read_events(infile)) 23 | 24 | self.assertEqual(len(events), 1) 25 | self.assertEqual(len(events[0].summary.value), 1) 26 | self.assertTrue(events[0].summary.value[0].HasField('histo'), 27 | events[0].summary.value[0]) 28 | 29 | fs.stop() 30 | 31 | def test_large_scalar(self): 32 | fs = patcher('stanza.monitoring.summary', '/test') 33 | open = fs.start() 34 | 35 | writer = summary.SummaryWriter('/test/large_scalar.tfevents') 36 | writer.log_scalar(1, 'bigvalue', 1.0e39) 37 | writer.flush() 38 | with open('/test/large_scalar.tfevents', 'r') as infile: 39 | events = list(summary.read_events(infile)) 40 | 41 | self.assertEqual(len(events), 1) 42 | self.assertEqual(len(events[0].summary.value), 1) 43 | self.assertTrue(np.isinf(events[0].summary.value[0].simple_value)) 44 | 45 | fs.stop() 46 | -------------------------------------------------------------------------------- /test/unit_tests/monitoring/test_trigger.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor, kelvinguu' 2 | 3 | from unittest import TestCase 4 | from stanza.monitoring.trigger import ThresholdTrigger, SlopeTrigger, PatienceTrigger 5 | 6 | 7 | class TestEarlyStopping(TestCase): 8 | 9 | def test_threshold(self): 10 | e = ThresholdTrigger(min_threshold=-10, max_threshold=2) 11 | for val in xrange(-10, 3): 12 | self.assertFalse(e(val)) 13 | self.assertTrue(e(-10.1)) 14 | self.assertTrue(e(2.1)) 15 | 16 | def test_patience(self): 17 | e = PatienceTrigger(patience=3) 18 | self.assertFalse(e(10)) 19 | self.assertFalse(e(9)) 20 | self.assertFalse(e(8)) 21 | self.assertFalse(e(11)) 22 | self.assertFalse(e(10)) 23 | self.assertFalse(e(1)) 24 | self.assertFalse(e(10)) 25 | self.assertTrue(e(10)) 26 | 27 | def test_slope_threshold(self): 28 | e = SlopeTrigger(range=(-1, 1), window_size=2) 29 | self.assertFalse(e(1)) # not enough points 30 | self.assertTrue(e(2)) # slope = 1 31 | self.assertFalse(e(4)) # slope 2 > 1 32 | self.assertFalse(e(2)) # slope -2 < -1 33 | self.assertTrue(e(2)) # slope = 0 34 | -------------------------------------------------------------------------------- /test/unit_tests/nlp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/stanza-old/920c55d8eaa1e7105971059c66eb448a74c100d6/test/unit_tests/nlp/__init__.py -------------------------------------------------------------------------------- /test/unit_tests/nlp/document.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/stanza-old/920c55d8eaa1e7105971059c66eb448a74c100d6/test/unit_tests/nlp/document.pb -------------------------------------------------------------------------------- /test/unit_tests/nlp/test_data.py: -------------------------------------------------------------------------------- 1 | # 2 | # pylint: disable=no-self-use, redefined-outer-name 3 | 4 | import copy 5 | import json 6 | 7 | import pytest 8 | 9 | import stanza.nlp.CoreNLP_pb2 as proto 10 | 11 | from stanza.nlp.corenlp import AnnotatedDocument, AnnotatedToken, AnnotatedSentence 12 | 13 | 14 | @pytest.fixture 15 | def 
json_dict(): 16 | """What CoreNLP would return for 'Belgian swimmers beat the United States. Really?'""" 17 | return json.loads('{"text": "Belgian swimmers beat the United States. Really?", "sentence": [{"characterOffsetBegin": 0, "hasRelationAnnotations": false, "hasNumerizedTokensAnnotation": false, "tokenOffsetBegin": 0, "token": [{"before": "", "value": "Belgian", "hasXmlContext": false, "endChar": 7, "beginChar": 0, "after": " ", "originalText": "Belgian", "word": "Belgian"}, {"before": " ", "value": "swimmers", "hasXmlContext": false, "endChar": 16, "beginChar": 8, "after": " ", "originalText": "swimmers", "word": "swimmers"}, {"before": " ", "value": "beat", "hasXmlContext": false, "endChar": 21, "beginChar": 17, "after": " ", "originalText": "beat", "word": "beat"}, {"before": " ", "value": "the", "hasXmlContext": false, "endChar": 25, "beginChar": 22, "after": " ", "originalText": "the", "word": "the"}, {"before": " ", "value": "United", "hasXmlContext": false, "endChar": 32, "beginChar": 26, "after": " ", "originalText": "United", "word": "United"}, {"before": " ", "value": "States", "hasXmlContext": false, "endChar": 39, "beginChar": 33, "after": "", "originalText": "States", "word": "States"}, {"before": "", "value": ".", "hasXmlContext": false, "endChar": 40, "beginChar": 39, "after": " ", "originalText": ".", "word": "."}], "tokenOffsetEnd": 7, "sentenceIndex": 0, "characterOffsetEnd": 40}, {"characterOffsetBegin": 41, "hasRelationAnnotations": false, "hasNumerizedTokensAnnotation": false, "tokenOffsetBegin": 7, "token": [{"before": " ", "value": "Really", "hasXmlContext": false, "endChar": 47, "beginChar": 41, "after": "", "originalText": "Really", "word": "Really"}, {"before": "", "value": "?", "hasXmlContext": false, "endChar": 48, "beginChar": 47, "after": "", "originalText": "?", "word": "?"}], "tokenOffsetEnd": 9, "sentenceIndex": 1, "characterOffsetEnd": 48}]}') 18 | 19 | @pytest.fixture 20 | def document_pb(): 21 | """What CoreNLP would return for: 22 | "Barack Hussein Obama is an American politician who is the 44th 23 | and current President of the United States. He is the first 24 | African American to hold the office and the first president born 25 | outside the continental United States. Born in Honolulu, Hawaii, 26 | Obama is a graduate of Columbia University and Harvard Law 27 | School, where he was president of the Harvard Law Review." 28 | """ 29 | doc = proto.Document() 30 | with open("test/unit_tests/nlp/document.pb", "rb") as f: 31 | doc.ParseFromString(f.read()) 32 | return doc 33 | 34 | class TestAnnotatedToken(object): 35 | #def test_json_to_pb(self, json_dict): 36 | # token_dict = json_dict['sentences'][0]['tokens'][0] 37 | # token = AnnotatedToken.json_to_pb(token_dict) 38 | # assert token.after == u' ' 39 | # assert token.before == u'' 40 | # assert token.beginChar == 0 41 | # assert token.endChar == 7 42 | # assert token.originalText == u'Belgian' 43 | # assert token.word == u'Belgian' 44 | 45 | def test_parse_pb(self, document_pb): 46 | token_pb = document_pb.sentence[1].token[3] 47 | token = AnnotatedToken.from_pb(token_pb) 48 | assert token.after == u' ' 49 | assert token.before == u' ' 50 | assert token.character_span == (117, 122) 51 | assert token.originalText == u'first' 52 | assert token.word == u'first' 53 | assert token.lemma == u'first' 54 | assert token.ner == u'ORDINAL' 55 | assert token.pos == u'JJ' 56 | 57 | class TestAnnotatedSentence(object): 58 | #def test_json_to_pb(self, json_dict): 59 | # orig_text = 'Really?' 
60 | # sent_dict = json_dict['sentences'][1] 61 | # sent = AnnotatedSentence.from_json(sent_dict) 62 | # assert sent.text == orig_text 63 | # assert sent[1].word == u'?' 64 | 65 | def test_parse_pb(self, document_pb): 66 | sentence_pb = document_pb.sentence[0] 67 | sentence = AnnotatedSentence.from_pb(sentence_pb) 68 | assert sentence.text == u"Barack Hussein Obama is an American politician who is the 44th and current President of the United States." 69 | assert len(sentence) == 19 70 | assert sentence[1].word == "Hussein" 71 | assert sentence[1].ner == "PERSON" 72 | 73 | def test_depparse(self, document_pb): 74 | sentence_pb = document_pb.sentence[0] 75 | sentence = AnnotatedSentence.from_pb(sentence_pb) 76 | dp = sentence.depparse() 77 | assert dp.roots == [6] # politician 78 | assert (2, 'nsubj') in dp.children(6) # Obama is child of politician 79 | assert (3, 'cop') in dp.children(6) # 'is' is ia copula 80 | assert (0, 'compound') in dp.children(2) # 'Barack' is part of the compount that is Obama. 81 | 82 | def test_depparse_json(self, document_pb): 83 | sentence_pb = document_pb.sentence[0] 84 | sentence = AnnotatedSentence.from_pb(sentence_pb) 85 | dp = sentence.depparse() 86 | edges = dp.to_json() 87 | # politician is root 88 | assert any((edge['dep'] == 'root' and edge['dependent'] == 7 and edge['dependentgloss'] == 'politician') for edge in edges) 89 | # Obama is child of politician 90 | assert any((edge['governer'] == 7 and edge['dep'] == 'nsubj' and edge['dependent'] == 3 and edge['dependentgloss'] == 'Obama') for edge in edges) 91 | # 'is' is ia copula 92 | assert any((edge['governer'] == 7 and edge['dep'] == 'cop' and edge['dependent'] == 4 and edge['dependentgloss'] == 'is') for edge in edges) 93 | # 'Barack' is part of the compount that is Obama. 94 | assert any((edge['governer'] == 3 and edge['dep'] == 'compound' and edge['dependent'] == 1 and edge['dependentgloss'] == 'Barack') for edge in edges) 95 | 96 | def test_from_tokens(self): 97 | text = "This is a test." 98 | tokens = "This is a test .".split() 99 | sentence = AnnotatedSentence.from_tokens(text, tokens) 100 | assert sentence.text == text 101 | assert len(sentence) == 5 102 | assert sentence[1].word == "is" 103 | 104 | class TestAnnotatedDocument(object): 105 | #def test_json_to_pb(self, json_dict): 106 | # orig_text = 'Belgian swimmers beat the United States. Really?' 107 | # doc = AnnotatedDocument.from_json(json_dict) 108 | # assert doc.text == orig_text 109 | # assert doc[1].text == 'Really?' 
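# A compact sketch of the usage pattern the tests below exercise (illustrative
# only; variable names are arbitrary):
#
#     doc = AnnotatedDocument.from_json(json_dict)   # or AnnotatedDocument.from_pb(document_pb)
#     doc.text                                       # full text of the document
#     doc[0][1].word                                 # sentence 0, token 1
#     doc.to_json()                                  # round-trips back to the CoreNLP JSON dict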
110 | 111 | def test_json(self, json_dict): 112 | doc = AnnotatedDocument.from_json(json_dict) 113 | new_json = doc.to_json() 114 | assert json_dict == new_json 115 | 116 | def test_eq(self, json_dict): 117 | # exact copy 118 | json_dict1 = copy.deepcopy(json_dict) 119 | 120 | # same as json_dict, but 'Belgian' is no longer capitalized 121 | json_dict2 = copy.deepcopy(json_dict) 122 | first_token_json = json_dict2['sentence'][0]['token'][0] 123 | first_token_json[u'originalText'] = 'belgian' 124 | first_token_json[u'word'] = 'belgian' 125 | 126 | doc = AnnotatedDocument.from_json(json_dict) 127 | doc1 = AnnotatedDocument.from_json(json_dict1) 128 | doc2 = AnnotatedDocument.from_json(json_dict2) 129 | 130 | assert doc == doc1 131 | assert doc != doc2 132 | 133 | @pytest.fixture 134 | def doc(self, json_dict): 135 | return AnnotatedDocument.from_json(json_dict) 136 | 137 | def test_properties(self, doc): 138 | assert doc[0][1].word == u'swimmers' 139 | assert doc[0][2].character_span == (17, 21) 140 | assert doc[0].document == doc 141 | 142 | def test_parse_pb(self, document_pb): 143 | document = AnnotatedDocument.from_pb(document_pb) 144 | assert document.text == u"Barack Hussein Obama is an American politician who is the 44th and current President of the United States. He is the first African American to hold the office and the first president born outside the continental United States. Born in Honolulu, Hawaii, Obama is a graduate of Columbia University and Harvard Law School, where he was president of the Harvard Law Review." 145 | assert len(document) == 3 146 | assert document[0][1].word == "Hussein" 147 | assert document[0][1].ner == "PERSON" 148 | 149 | def test_mentions(self, document_pb): 150 | document = AnnotatedDocument.from_pb(document_pb) 151 | mentions = document.mentions 152 | assert len(mentions) == 17 153 | -------------------------------------------------------------------------------- /test/unit_tests/text/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor' 2 | -------------------------------------------------------------------------------- /test/unit_tests/text/test_dataset.py: -------------------------------------------------------------------------------- 1 | __author__ = 'victor' 2 | 3 | from unittest import TestCase 4 | from stanza.text.dataset import Dataset, InvalidFieldsException 5 | from tempfile import NamedTemporaryFile 6 | import random 7 | from collections import OrderedDict 8 | import os 9 | 10 | 11 | class TestDataset(TestCase): 12 | 13 | CONLL = """# description\ttags 14 | Alice 15 | Hello\tt1 16 | my\tt2 17 | name\tt3 18 | is\tt4 19 | alice\tt5 20 | 21 | Bob 22 | I'm\tt1 23 | bob\tt2""" 24 | 25 | CONLL_MOCK = OrderedDict([ 26 | ('description', [ 27 | ['Hello', 'my', 'name', 'is', 'alice'], 28 | ["I'm", 'bob'], 29 | ]), 30 | ('tags', [ 31 | ['t1', 't2', 't3', 't4', 't5'], 32 | ['t1', 't2'], 33 | ]), 34 | ('label', ['Alice', 'Bob']), 35 | ]) 36 | 37 | MOCK = OrderedDict([('Name', ['Alice', 'Bob', 'Carol']), ('SSN', ['123', None, '7890'])]) 38 | 39 | def setUp(self): 40 | random.seed(1) 41 | self.mock = Dataset(OrderedDict([(name, d[:]) for name, d in self.MOCK.items()])) 42 | self.conll = Dataset(OrderedDict([(name, d[:]) for name, d in self.CONLL_MOCK.items()])) 43 | 44 | def test_init(self): 45 | self.assertRaises(InvalidFieldsException, lambda: Dataset({'name': ['alice', 'bob'], 'ssn': ['1']})) 46 | 47 | def test_length(self): 48 | self.assertEqual(0, len(Dataset({}))) 49 | 
self.assertEqual(2, len(Dataset({'name': ['foo', 'bar']}))) 50 | 51 | def test_load_conll(self): 52 | with NamedTemporaryFile() as f: 53 | f.write(self.CONLL) 54 | f.flush() 55 | d = Dataset.load_conll(f.name) 56 | self.assertDictEqual(self.CONLL_MOCK, d.fields) 57 | 58 | def test_write_conll(self): 59 | f = NamedTemporaryFile(delete=False) 60 | f.close() 61 | d = Dataset(self.CONLL_MOCK) 62 | d.write_conll(f.name) 63 | with open(f.name) as fin: 64 | self.assertEqual(self.CONLL, fin.read()) 65 | os.remove(f.name) 66 | 67 | def test_convert_new(self): 68 | d = self.mock 69 | dd = d.convert({'Name': str.lower}, in_place=False) 70 | 71 | # made a copy 72 | self.assertIsNot(d, dd) 73 | 74 | # doesn't change original 75 | self.assertDictEqual(self.MOCK, d.fields) 76 | 77 | # changes copy 78 | self.assertDictEqual({'Name': ['alice', 'bob', 'carol'], 'SSN': ['123', None, '7890']}, dd.fields) 79 | 80 | def test_convert_in_place(self): 81 | d = self.mock 82 | dd = d.convert({'Name': str.lower}, in_place=True) 83 | 84 | # did not make a copy 85 | self.assertIs(d, dd) 86 | 87 | # changes original 88 | self.assertDictEqual({'Name': ['alice', 'bob', 'carol'], 'SSN': ['123', None, '7890']}, d.fields) 89 | 90 | def test_shuffle(self): 91 | d = self.mock 92 | dd = d.shuffle() 93 | self.assertIs(d, dd) 94 | 95 | # this relies on random seed 96 | self.assertDictEqual({'Name': ['Carol', 'Bob', 'Alice'], 'SSN': ['7890', None, '123']}, d.fields) 97 | 98 | def test_getitem(self): 99 | d = self.mock 100 | self.assertRaises(IndexError, lambda: d.__getitem__(10)) 101 | self.assertDictEqual({'Name': 'Bob', 'SSN': None}, d[1]) 102 | self.assertDictEqual({'Name': 'Alice', 'SSN': '123'}, d[0]) 103 | self.assertDictEqual({'Name': 'Carol', 'SSN': '7890'}, d[-1]) 104 | 105 | self.assertDictEqual({'Name': ['Bob', 'Carol'], 'SSN': [None, '7890']}, d[1:]) 106 | self.assertDictEqual({'Name': ['Alice', 'Bob'], 'SSN': ['123', None]}, d[:2]) 107 | self.assertDictEqual({'Name': ['Bob'], 'SSN': [None]}, d[1:2]) 108 | 109 | def test_setitem(self): 110 | d = self.mock 111 | self.assertRaises(InvalidFieldsException, lambda: d.__setitem__(1, 'foo')) 112 | self.assertRaises(IndexError, lambda: d.__setitem__(10, {'Name': 'Victor', 'SSN': 123})) 113 | d[1] = {'Name': 'Victor', 'SSN': 123} 114 | self.assertDictEqual({'Name': ['Alice', 'Victor', 'Carol'], 'SSN': ['123', 123, '7890']}, d.fields) 115 | 116 | def test_copy(self): 117 | d = self.mock 118 | dd = d.copy() 119 | self.assertIsNot(d, dd) 120 | self.assertDictEqual(self.MOCK, d.fields) 121 | self.assertDictEqual(self.MOCK, dd.fields) 122 | 123 | for name in d.fields.keys(): 124 | self.assertIsNot(d.fields[name], dd.fields[name]) 125 | -------------------------------------------------------------------------------- /test/unit_tests/text/test_vocab.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | __author__ = 'victor, kelvinguu' 4 | 5 | from collections import Counter 6 | from unittest import TestCase 7 | from stanza.text.vocab import Vocab, SennaVocab, GloveVocab 8 | 9 | 10 | # new tests are written in the lighter-weight pytest format 11 | @pytest.fixture 12 | def vocab(): 13 | v = Vocab('unk') 14 | v.update('zero one two two three three three'.split()) 15 | return v 16 | 17 | 18 | def test_eq(vocab): 19 | v = Vocab('unk') 20 | v.update('zero one two two three three three'.split()) 21 | assert v == vocab 22 | v.add('zero', count=10) 23 | assert v == vocab # equality doesn't depend on count 24 | v.add('four') 25 | assert v 
!= vocab 26 | 27 | 28 | def test_subset(vocab): 29 | v = vocab.subset(['zero', 'three', 'two']) 30 | correct = {'unk': 0, 'zero': 1, 'three': 2, 'two': 3} 31 | assert dict(v) == correct 32 | assert v._counts == Counter({'unk': 0, 'zero': 1, 'three': 3, 'two': 2}) 33 | 34 | 35 | class TestVocab(TestCase): 36 | 37 | def setUp(self): 38 | self.Vocab = Vocab 39 | 40 | def test_unk(self): 41 | unk = '**UNK**' 42 | v = self.Vocab(unk=unk) 43 | self.assertEqual(len(v), 1) 44 | self.assertIn(unk, v) 45 | self.assertEqual(v[unk], 0) 46 | self.assertEqual(v.count(unk), 0) 47 | 48 | def test_add(self): 49 | v = self.Vocab('**UNK**') 50 | v.add('hi') 51 | self.assertIn('hi', v) 52 | self.assertEqual(len(v), 2) 53 | self.assertEqual(v.count('hi'), 1) 54 | self.assertEqual(v['hi'], 1) 55 | 56 | def test_sent2index(self): 57 | v = self.Vocab(unk='unk') 58 | words = ['i', 'like', 'pie'] 59 | v.update(words) 60 | self.assertEqual(v.words2indices(words), [1, 2, 3]) 61 | self.assertEqual(v.words2indices(['i', 'said']), [1, 0]) 62 | 63 | def test_index2sent(self): 64 | v = self.Vocab(unk='unk') 65 | v.update(['i', 'like', 'pie']) 66 | words = v.indices2words([1, 2, 3, 0]) 67 | self.assertEqual(words, ['i', 'like', 'pie', 'unk']) 68 | 69 | def test_prune_rares(self): 70 | v = self.Vocab(unk='unk') 71 | v.update(['hi'] * 3 + ['bye'] * 5) 72 | self.assertEqual({'hi': 3, 'bye': 5, 'unk': 0}, dict(v._counts)) 73 | p = v.prune_rares(cutoff=4) 74 | self.assertEqual({'bye': 5, 'unk': 0}, dict(p._counts)) 75 | 76 | def test_sort_by_decreasing_count(self): 77 | v = self.Vocab(unk='unk') 78 | v.update('some words words for for for you you you you'.split()) 79 | s = v.sort_by_decreasing_count() 80 | self.assertEqual(['unk', 'you', 'for', 'words', 'some'], list(iter(s))) 81 | self.assertEqual({'unk': 0, 'you': 4, 'for': 3, 'words': 2, 'some': 1}, dict(s._counts)) 82 | 83 | def test_from_file(self): 84 | lines = ['unk\t10\n', 'cat\t4\n', 'bear\t6'] 85 | vocab = self.Vocab.from_file(lines) 86 | self.assertEqual(vocab._counts, Counter({'unk': 10, 'cat': 4, 'bear': 6})) 87 | self.assertEqual(dict(vocab), {'unk': 0, 'cat': 1, 'bear': 2}) 88 | 89 | 90 | class TestFrozen: 91 | 92 | def test_words2indices(self): 93 | v = Vocab('unk') 94 | words = ['i', 'like', 'pie'] 95 | v.update(words) 96 | v = v.freeze() 97 | assert v.words2indices(words) == [1, 2, 3] 98 | assert v.words2indices(['i', 'said']) == [1, 0] 99 | 100 | def test_indices2words(self): 101 | v = Vocab(unk='unk') 102 | v.update(['i', 'like', 'pie']) 103 | words = v.indices2words([1, 2, 3, 0]) 104 | assert words == ['i', 'like', 'pie', 'unk'] 105 | 106 | 107 | class TestSenna(TestVocab): 108 | 109 | def setUp(self): 110 | self.Vocab = SennaVocab 111 | 112 | 113 | class TestGlove(TestVocab): 114 | 115 | def setUp(self): 116 | self.Vocab = GloveVocab 117 | --------------------------------------------------------------------------------
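Taken together, stanza.text.dataset and stanza.text.vocab (exercised by the unit tests above) cover a small preprocessing pipeline: build or load a Dataset, grow a Vocab over one of its fields, index the words, and pad the index sequences into a matrix. The sketch below is illustrative only; the field names and toy data are made up, and it simply chains the calls shown in the modules and tests above.

    from collections import OrderedDict

    from stanza.text.dataset import Dataset
    from stanza.text.vocab import Vocab

    # A toy two-instance dataset (field names are arbitrary).
    d = Dataset(OrderedDict([
        ('description', [['Hello', 'my', 'name', 'is', 'alice'], ["I'm", 'bob']]),
        ('label', ['Alice', 'Bob']),
    ]))

    # Grow a vocabulary over the description field.
    v = Vocab('unk')
    for example in d:
        v.update(example['description'])

    # Map words to indices and pad the sequences into a single matrix.
    indexed = [v.words2indices(example['description']) for example in d]
    matrix = Dataset.pad(indexed, padding=v['unk'])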