├── .circleci └── config.yml ├── .github ├── ISSUE_TEMPLATE.md └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── MANIFEST.in ├── README.md ├── docs ├── Makefile ├── make.bat └── source │ ├── _static │ ├── graph_all.svg │ ├── graph_partitioned.svg │ ├── graph_unpartitioned.svg │ ├── icon_bw.svg │ ├── icon_color.svg │ ├── logo_bw.svg │ └── logo_color.svg │ ├── batch_preparation.rst │ ├── conf.py │ ├── configuration_file.rst │ ├── data_model.rst │ ├── distributed_training.rst │ ├── downstream_tasks.rst │ ├── dynamic_relations.rst │ ├── evaluation.rst │ ├── faq_troubleshooting.rst │ ├── featurized_entities.rst │ ├── index.rst │ ├── input_output.rst │ ├── loss_optimization.rst │ ├── pretrained_embeddings.rst │ ├── related.rst │ └── scoring.rst ├── ifbpy.py ├── setup.cfg ├── setup.py ├── test ├── __init__.py ├── test_batching.py ├── test_bucket_scheduling.py ├── test_checkpoint_manager.py ├── test_distributed.py ├── test_edgelist.py ├── test_entitylist.py ├── test_functional.py ├── test_graph_storages.py ├── test_losses.py ├── test_model.py ├── test_optimizers.py ├── test_schema.py ├── test_stats.py ├── test_train.py └── test_util.py └── torchbiggraph ├── VERSION.txt ├── __init__.py ├── async_adagrad.py ├── batching.py ├── bucket_scheduling.py ├── checkpoint_manager.py ├── checkpoint_storage.py ├── config.py ├── converters ├── __init__.py ├── dictionary.py ├── export_to_tsv.py ├── import_from_parquet.py ├── import_from_tsv.py ├── importers.py └── utils.py ├── distributed.py ├── edgelist.py ├── entitylist.py ├── eval.py ├── examples ├── LICENSE.txt ├── __init__.py ├── configs │ ├── fb15k_config_cpu.py │ ├── fb15k_config_gpu.py │ └── livejournal_config.py ├── fb15k.py └── livejournal.py ├── filtered_eval.py ├── graph_storages.py ├── losses.py ├── model.py ├── operators.py ├── parameter_sharing.py ├── partitionserver.py ├── plugin.py ├── regularizers.py ├── row_adagrad.py ├── rpc.py ├── schema.py ├── stats.py ├── tensorlist.py ├── train.py ├── train_cpu.py ├── train_gpu.py ├── types.py ├── util.cpp └── util.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Python CircleCI 2.0 configuration file 2 | # 3 | # Check https://circleci.com/docs/2.0/language-python/ for more details 4 | # 5 | version: 2 6 | jobs: 7 | build: 8 | docker: 9 | - image: circleci/python:3.6 10 | working_directory: ~/PyTorch-BigGraph 11 | steps: 12 | - restore_cache: 13 | keys: 14 | - source-v1-{{ .Branch }}-{{ .Revision }} 15 | - source-v1-{{ .Branch }}- 16 | - source-v1- 17 | - checkout 18 | - save_cache: 19 | key: source-v1-{{ .Branch }}-{{ .Revision }} 20 | paths: 21 | - .git 22 | - restore_cache: 23 | keys: 24 | - pip-cache-v1-{{ arch }}-{{ .Branch }}- 25 | - pip-cache-v1-{{ arch }}- 26 | - pip-cache-v1- 27 | - run: 28 | name: Install PyTorch-BigGraph 29 | command: | 30 | python3 -m venv ~/venv 31 | . ~/venv/bin/activate 32 | pip install pip --upgrade --upgrade-strategy eager 33 | pip install dataclasses numpy typing_extensions 34 | pip install torch --no-index --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --no-deps --upgrade --upgrade-strategy eager --progress-bar off 35 | pip install . --upgrade --upgrade-strategy eager --progress-bar off 36 | pip freeze > ~/pip_freeze.txt 37 | - save_cache: 38 | key: pip-cache-v1-{{ arch }}-{{ .Branch }}-{{ checksum "~/pip_freeze.txt" }} 39 | paths: 40 | - ~/.cache/pip 41 | - run: 42 | name: Test all the things! 
43 | command: | 44 | . ~/venv/bin/activate 45 | CIRCLECI_TEST=1 python3 setup.py test 46 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Steps to reproduce 2 | 3 | 1. _____ 4 | 2. _____ 5 | 3. _____ 6 | 7 | ## Observed Results 8 | 9 | * What happened? This could be a description, log output, etc. 10 | 11 | ## Expected Results 12 | 13 | * What did you expect to happen? 14 | 15 | ## Relevant Code 16 | 17 | ``` 18 | // TODO(you): code here to reproduce the problem 19 | ``` 20 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Types of changes 2 | 3 | 4 | 5 | - [ ] Docs change / refactoring / dependency upgrade 6 | - [ ] Bug fix (non-breaking change which fixes an issue) 7 | - [ ] New feature (non-breaking change which adds functionality) 8 | - [ ] Breaking change (fix or feature that would cause existing functionality to change) 9 | 10 | ## Motivation and Context / Related issue 11 | 12 | 13 | 14 | 15 | 16 | ## How Has This Been Tested (if it applies) 17 | 18 | 19 | 20 | ## Checklist 21 | 22 | 23 | 24 | 25 | - [ ] The documentation is up-to-date with the changes I made. 26 | - [ ] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**). 27 | - [ ] All tests passed, and additional code has been covered with new tests. 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | build/ 3 | data/ 4 | dist/ 5 | docs/build/ 6 | .DS_Store 7 | *.egg 8 | *.egg-info/ 9 | .idea/ 10 | model/ 11 | __pycache__/ 12 | *.py[cod] 13 | *.so 14 | venv/ 15 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## main 4 | 5 | Initial version 6 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | 78 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to torchbiggraph 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 
5 | 6 | ## Our Development Process 7 | 8 | This project's source-of-truth is the version in Facebook's internal codebase, 9 | which is continuously synced with the GitHub mirror using 10 | [ShipIt](https://github.com/facebook/fbshipit). Pull requests on GitHub are 11 | copied over using ImportIt (a companion tool for ShipIt). 12 | 13 | ## Pull Requests 14 | 15 | We actively welcome your pull requests. 16 | 17 | 1. Fork the repo and create your branch from `main`. 18 | 2. If you've added code that should be tested, add tests. 19 | 3. If you've changed APIs, update the documentation. 20 | 4. Ensure the test suite passes. 21 | 5. Make sure your code lints. 22 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 23 | 24 | ## Contributor License Agreement ("CLA") 25 | 26 | In order to accept your pull request, we need you to submit a CLA. You only need 27 | to do this once to work on any of Facebook's open source projects. 28 | 29 | Complete your CLA here: 30 | 31 | ## Issues 32 | 33 | We use GitHub issues to track public bugs. Please ensure your description is 34 | clear and has sufficient instructions to be able to reproduce the issue. 35 | 36 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 37 | disclosure of security bugs. In those cases, please go through the process 38 | outlined on that page and do not file a public issue. 39 | 40 | ## Coding Style 41 | 42 | This project adheres to the [PEP 8](https://www.python.org/dev/peps/pep-0008/) 43 | style guidelines. It is linted using [Flake8](https://pypi.org/project/flake8/). 44 | Additional conventions from the [Black](https://black.readthedocs.io/en/stable/) 45 | formatter are sometimes adopted. 46 | 47 | ## License 48 | 49 | By contributing to torchbiggraph, you agree that your contributions will be 50 | licensed under the LICENSE.txt file in the root directory of this source tree. 51 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For torchbiggraph software 4 | 5 | Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | global-include LICENSE* 2 | include *.md 3 | graft docs 4 | prune docs/build 5 | global-exclude .* 6 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/_static/graph_all.svg:
--------------------------------------------------------------------------------
[SVG image: the example graph with all its edges shown; partition labels 0, 1, 2]
--------------------------------------------------------------------------------
/docs/source/_static/graph_partitioned.svg:
--------------------------------------------------------------------------------
[SVG image: the example graph, partitioned; partition labels 0, 1, 2]
--------------------------------------------------------------------------------
/docs/source/_static/graph_unpartitioned.svg:
--------------------------------------------------------------------------------
[SVG image: the example graph, unpartitioned]
--------------------------------------------------------------------------------
/docs/source/_static/icon_bw.svg:
--------------------------------------------------------------------------------
[SVG image: project icon, black and white]
--------------------------------------------------------------------------------
/docs/source/_static/icon_color.svg:
--------------------------------------------------------------------------------
[SVG image: project icon, color]
--------------------------------------------------------------------------------
/docs/source/_static/logo_bw.svg:
--------------------------------------------------------------------------------
[SVG image: project logo, black and white]
--------------------------------------------------------------------------------
/docs/source/_static/logo_color.svg:
--------------------------------------------------------------------------------
[SVG image: project logo, color]
--------------------------------------------------------------------------------
/docs/source/batch_preparation.rst:
--------------------------------------------------------------------------------
1 | .. _batch-preparation:
2 |
3 | Batch preparation
4 | =================
5 |
6 | This section presents how the training data is prepared and organized in batches
7 | before the loss is :ref:`calculated ` and :ref:`optimized `
8 | on each of them.
9 |
10 | Training proceeds by iterating over the edges, through various nested loops. The
11 | outermost one walks through so-called **epochs**. Each epoch is independent and
12 | essentially equivalent to every other one. Their goal is to repeat the inner loop
13 | until convergence. Each epoch visits all the edges exactly once. The number of
14 | epochs is specified in the ``num_epochs`` configuration parameter.
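The configuration keys that drive this nested iteration, each explained in the
rest of this section, could appear together in a config file as follows (an
illustrative excerpt with made-up values; a complete config also needs entity,
relation and path settings)::

    def get_torchbiggraph_config():
        return dict(
            # ... entity, relation and I/O settings omitted ...
            num_epochs=10,  # outermost loop: full passes over all edges
            num_edge_chunks=1,  # how many chunks each bucket is divided into
            bucket_order="inside_out",  # order in which buckets are visited
            workers=8,  # number of "Hogwild!" worker processes
            batch_size=1000,  # edges per batch within each worker
        )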
15 |
16 | The edges are partitioned into **edge sets** (one for each directory of the ``edge_paths``
17 | configuration key) and, within each epoch, the edge sets are traversed in order.
18 |
19 | When iterating over an edge set, each of its buckets is first divided into
20 | equally sized **chunks**: each chunk spans a contiguous interval of edges (in the
21 | order they are stored in the files) and the number of chunks can be tweaked
22 | using the ``num_edge_chunks`` configuration key. The training first operates
23 | on all the first chunks of all buckets, then on all of their second chunks,
24 | and so on.
25 |
26 | Next, the algorithm iterates over the **buckets**. The order in which buckets are
27 | processed depends on the value of the ``bucket_order`` configuration key. In
28 | addition to a random permutation, there are methods that try to have successive
29 | buckets share a common partition: this allows that partition to be reused,
30 | thus allowing it to be kept in memory rather than being unloaded and another one
31 | getting loaded in its place. (In :ref:`distributed mode `,
32 | the various trainer processes operate on the buckets at the same time, thus the
33 | iteration is managed differently.)
34 |
35 | Once the trainer has fixed a given chunk and a certain bucket, its edges are
36 | finally loaded from disk. When
37 | :ref:`evaluating during training `, a subset of these
38 | edges is withheld (this subset is the same for all epochs). The remaining edges
39 | are immediately uniformly shuffled and then split into equal parts. These parts
40 | are distributed among a pool of **processes**, so that the training can proceed
41 | in parallel on all of them at the same time. These subprocesses are "Hogwild!"
42 | workers, which do not synchronize their computations or memory accesses. The
43 | number of such workers is determined by the ``workers`` parameter.
44 |
45 | The way each worker trains on its set of edges depends on whether
46 | :ref:`dynamic relations ` are in use. The simplest scenario is when
47 | they are in use: the edges are split into contiguous **batches** (each one having
48 | the size specified in the ``batch_size`` configuration key, except possibly the last
49 | one, which could be smaller). Training is then performed on each batch before moving
50 | on to the next one.
51 |
52 | When dynamic relations are not in use, however, the loss can only be computed on
53 | a set of edges that are all of the same type. Thus the worker first randomly
54 | samples a relation type, with probability proportional to the number of edges
55 | of that type that are left in the pool. It then takes the first ``batch_size`` edges of
56 | that type (or fewer, if not enough of them are left), removes them from the pool and
57 | performs training on them.
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Configuration file for the Sphinx documentation builder.
5 | #
6 | # This file does only contain a selection of the most common options. For a
7 | # full list see the documentation:
8 | # http://www.sphinx-doc.org/en/master/config
9 |
10 | # -- Path setup --------------------------------------------------------------
11 |
12 | # If extensions (or modules to document with autodoc) are in another directory,
13 | # add these directories to sys.path here.
If the directory is relative to the 14 | # documentation root, use os.path.abspath to make it absolute, like shown here. 15 | # 16 | # import os 17 | # import sys 18 | # sys.path.insert(0, os.path.abspath('.')) 19 | 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = "PyTorch-BigGraph" 24 | copyright = "2019, Facebook AI Research" 25 | author = "Facebook AI Research" 26 | 27 | # The short X.Y version 28 | version = "" 29 | # The full version, including alpha/beta/rc tags 30 | release = "1.dev" 31 | 32 | 33 | # -- General configuration --------------------------------------------------- 34 | 35 | # If your documentation needs a minimal Sphinx version, state it here. 36 | # 37 | # needs_sphinx = '1.0' 38 | 39 | # Add any Sphinx extension module names here, as strings. They can be 40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 41 | # ones. 42 | extensions = [ 43 | "sphinx.ext.autodoc", 44 | "sphinx.ext.intersphinx", 45 | "sphinx.ext.todo", 46 | "sphinx.ext.mathjax", 47 | "sphinx.ext.githubpages", 48 | ] 49 | 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ["_templates"] 52 | 53 | # The suffix(es) of source filenames. 54 | # You can specify multiple suffix as a list of string: 55 | # 56 | # source_suffix = ['.rst', '.md'] 57 | source_suffix = ".rst" 58 | 59 | # The master toctree document. 60 | master_doc = "index" 61 | 62 | # The language for content autogenerated by Sphinx. Refer to documentation 63 | # for a list of supported languages. 64 | # 65 | # This is also used if you do content translation via gettext catalogs. 66 | # Usually you set "language" from the command line for these cases. 67 | language = None 68 | 69 | # List of patterns, relative to source directory, that match files and 70 | # directories to ignore when looking for source files. 71 | # This pattern also affects html_static_path and html_extra_path. 72 | exclude_patterns = [] 73 | 74 | # The name of the Pygments (syntax highlighting) style to use. 75 | pygments_style = None 76 | 77 | 78 | # -- Options for HTML output ------------------------------------------------- 79 | 80 | # The theme to use for HTML and HTML Help pages. See the documentation for 81 | # a list of builtin themes. 82 | # 83 | html_theme = "alabaster" 84 | 85 | # Theme options are theme-specific and customize the look and feel of a theme 86 | # further. For a list of options available for each theme, see the 87 | # documentation. 88 | # 89 | # html_theme_options = {} 90 | 91 | # Add any paths that contain custom static files (such as style sheets) here, 92 | # relative to this directory. They are copied after the builtin static files, 93 | # so a file named "default.css" will overwrite the builtin "default.css". 94 | html_static_path = ["_static"] 95 | 96 | # Custom sidebar templates, must be a dictionary that maps document names 97 | # to template names. 98 | # 99 | # The default sidebars (for documents that don't match any pattern) are 100 | # defined by theme itself. Builtin themes are using these templates by 101 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 102 | # 'searchbox.html']``. 103 | # 104 | # html_sidebars = {} 105 | 106 | 107 | # -- Options for HTMLHelp output --------------------------------------------- 108 | 109 | # Output file base name for HTML help builder. 
110 | htmlhelp_basename = "torchbiggraphdoc" 111 | 112 | 113 | # -- Options for LaTeX output ------------------------------------------------ 114 | 115 | latex_elements = { 116 | # The paper size ('letterpaper' or 'a4paper'). 117 | # 118 | # 'papersize': 'letterpaper', 119 | # The font size ('10pt', '11pt' or '12pt'). 120 | # 121 | # 'pointsize': '10pt', 122 | # Additional stuff for the LaTeX preamble. 123 | # 124 | # 'preamble': '', 125 | # Latex figure (float) alignment 126 | # 127 | # 'figure_align': 'htbp', 128 | } 129 | 130 | # Grouping the document tree into LaTeX files. List of tuples 131 | # (source start file, target name, title, 132 | # author, documentclass [howto, manual, or own class]). 133 | latex_documents = [ 134 | ( 135 | master_doc, 136 | "torchbiggraph.tex", 137 | "PyTorch-BigGraph Documentation", 138 | "Facebook AI Research", 139 | "manual", 140 | ) 141 | ] 142 | 143 | 144 | # -- Options for manual page output ------------------------------------------ 145 | 146 | # One entry per manual page. List of tuples 147 | # (source start file, name, description, authors, manual section). 148 | man_pages = [ 149 | (master_doc, "torchbiggraph", "PyTorch-BigGraph Documentation", [author], 1) 150 | ] 151 | 152 | 153 | # -- Options for Texinfo output ---------------------------------------------- 154 | 155 | # Grouping the document tree into Texinfo files. List of tuples 156 | # (source start file, target name, title, author, 157 | # dir menu entry, description, category) 158 | texinfo_documents = [ 159 | ( 160 | master_doc, 161 | "torchbiggraph", 162 | "PyTorch-BigGraph Documentation", 163 | author, 164 | "torchbiggraph", 165 | "One line description of project.", 166 | "Miscellaneous", 167 | ) 168 | ] 169 | 170 | 171 | # -- Options for Epub output ------------------------------------------------- 172 | 173 | # Bibliographic Dublin Core info. 174 | epub_title = project 175 | 176 | # The unique identifier of the text. This can be a ISBN number 177 | # or the project homepage. 178 | # 179 | # epub_identifier = '' 180 | 181 | # A unique identification for the text. 182 | # 183 | # epub_uid = '' 184 | 185 | # A list of files that should not be packed into the epub file. 186 | epub_exclude_files = ["search.html"] 187 | 188 | 189 | # -- Extension configuration ------------------------------------------------- 190 | 191 | # -- Options for intersphinx extension --------------------------------------- 192 | 193 | # Example configuration for intersphinx: refer to the Python standard library. 194 | intersphinx_mapping = { 195 | "numpy": ("http://docs.scipy.org/doc/numpy/", None), 196 | "python": ("https://docs.python.org/", None), 197 | "torch": ("https://pytorch.org/docs/master/", None), 198 | } 199 | 200 | # -- Options for todo extension ---------------------------------------------- 201 | 202 | # If true, `todo` and `todoList` produce output, else they produce nothing. 203 | todo_include_todos = True 204 | -------------------------------------------------------------------------------- /docs/source/data_model.rst: -------------------------------------------------------------------------------- 1 | .. _data-model: 2 | 3 | Data model 4 | ========== 5 | 6 | PBG operates on directed multi-relation multigraphs, whose vertices are called **entities**. 7 | Each **edge** connects a source to a destination entity, which are respectively called its 8 | **left-** and **right-hand side** (shortened to **LHS** and **RHS**). Multiple edges between 9 | the same pair of entities are allowed. 
Loops, i.e., edges whose left- and right-hand sides
10 | are the same, are allowed as well.
11 |
12 | Each entity is of a certain **entity type** (one and only one type per entity).
13 | Thus, the types partition all the entities into disjoint groups. Similarly, each
14 | edge also belongs to exactly one **relation type**. All edges of a given
15 | relation type must have all their left-hand side entities of the same entity
16 | type and, similarly, all their right-hand side entities of the same entity type
17 | (possibly a different entity type than the left-hand side one). This property
18 | means that each relation type has a left-hand side entity type and a right-hand
19 | side entity type.
20 |
21 | .. figure:: _static/graph_unpartitioned.svg
22 | :figwidth: 100 %
23 | :width: 60 %
24 | :align: center
25 | :alt: a graph with three entity types and three relation types
26 |
27 | In this graph, there are 14 entities: 5 of the red entity type, 6 of the
28 | yellow entity type and 3 of the blue entity type; there are also 12 edges:
29 | 6 of the orange relation type (between red and yellow entities), 3 of the
30 | purple relation type (between red and blue entities) and 3 of the green relation
31 | type (between yellow and blue entities).
32 |
33 | In order for PBG to operate on large-scale graphs, the graph is broken
34 | up into small pieces, on which training can happen in a distributed manner. This
35 | is first achieved by further splitting the entities of each type into a certain
36 | number of subsets, called **partitions**. Then, for each relation type, its
37 | edges are divided into **buckets**: for each pair of partitions (one from the
38 | left- and one from the right-hand side entity types for that relation type)
39 | a bucket is created, which contains the edges of that type whose left- and
40 | right-hand side entities are in those partitions.
41 |
42 | .. figure:: _static/graph_partitioned.svg
43 | :figwidth: 100 %
44 | :width: 60 %
45 | :align: center
46 | :alt: the graph from before, partitioned and with only one bucket visible
47 |
48 | This graph shows a possible partition of the entities, with red having 3
49 | partitions, yellow having 3, and blue having only one (hence blue is
50 | unpartitioned). The edges displayed are those of the orange bucket between
51 | partition 2 of the red entities and partition 1 of the yellow entities.
52 |
53 | .. note::
54 | For technical reasons, all entity types that appear
55 | on the left-hand side of some relation type must currently be divided into the same
56 | number of partitions (except unpartitioned entities). The same must hold for
57 | all entity types that appear on the right-hand side. In numpy-speak, this means
58 | that the number of partitions of all entities must be broadcastable to the same value.
59 |
60 | An entity is identified by its type, its partition and its index within the
61 | partition (indices must be contiguous, meaning that if there are :math:`N`
62 | entities in a type's partition, their indices lie in the half-open interval :math:`[0, N)`).
63 | An edge is identified by its type, its bucket (i.e., the partitions
64 | of its left- and right-hand side entity types) and the indices of its left- and
65 | right-hand side entities in their respective partitions. An edge doesn't have
66 | to specify its left- and right-hand side entity types, because they are implicit
67 | in the edge's relation type.
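To make this identification scheme concrete, here is a minimal sketch (a
hypothetical helper, not part of PBG's API) of how an edge stored in a bucket
resolves to its two endpoint entities; the formal description follows below::

    from typing import List, NamedTuple, Tuple

    class Entity(NamedTuple):
        type: str  # entity type name
        partition: int  # partition the entity belongs to
        index: int  # index within that partition

    # Hypothetical relation list: relation type index -> (lhs type, rhs type).
    RELATIONS: List[Tuple[str, str]] = [
        ("red", "yellow"), ("red", "blue"), ("yellow", "blue"),
    ]

    def resolve_edge(
        bucket: Tuple[int, int], edge: Tuple[int, int, int]
    ) -> Tuple[Entity, Entity]:
        i, j = bucket  # left- and right-hand side partitions
        x, r, y = edge  # lhs index, relation type index, rhs index
        e1, e2 = RELATIONS[r]  # entity types are implicit in the relation type
        return Entity(e1, i, x), Entity(e2, j, y)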
68 |
69 | Formally, each bucket can be identified by a pair of integers :math:`(i, j)`, where :math:`i` and :math:`j` are
70 | respectively the left- and right-hand side partitions. Inside that bucket, each edge can be identified by a triplet
71 | of integers :math:`(x, r, y)`, with :math:`x` and :math:`y` representing respectively the left- and right-hand side
72 | entities and :math:`r` representing the relation type. This edge is "interpreted" by first looking up relation type
73 | :math:`r` in the configuration, and finding out that it can only have entities of type :math:`e_1` on its left-hand side
74 | and of type :math:`e_2` on its right-hand side. One can then determine the left-hand side entity, which is given by
75 | :math:`(e_1, i, x)` (its type, its partition and its index within the partition), and, similarly, the right-hand side one,
76 | which is :math:`(e_2, j, y)`.
77 |
--------------------------------------------------------------------------------
/docs/source/dynamic_relations.rst:
--------------------------------------------------------------------------------
1 | .. _dynamic-relations:
2 |
3 | Dynamic relations
4 | -----------------
5 |
6 | .. caution:: This is an advanced topic!
7 |
8 | Enabling the ``dynamic_relations`` flag in the configuration activates an alternative mode to be
9 | used for graphs with a large number of relations (more than ~100 relations). In dynamic relation mode,
10 | PBG runs with several modifications to its "standard" operation in order to support the large number of relations.
11 | The differences are:
12 |
13 | - The *number* of relations isn't provided in the config but is instead found in the input data, namely in the entity
14 | path, inside a :file:`dynamic_rel_count.txt` file. The settings of the relations, however, are still provided in the
15 | config file. This happens by providing a single relation config which will act as a "template" for all the others, by
16 | being duplicated an appropriate number of times. One can think of this as the one relation in the config being
17 | "broadcasted" to the size of the relation list found in the :file:`dynamic_rel_count.txt` file.
18 |
19 | - The batches of positive edges that are passed from the training loop into the model contain edges for multiple relation
20 | types at the same time (instead of each batch coming entirely from the same relation type). This introduces some performance challenges
21 | in how the operators are applied to the embeddings, as instead of a single operator with a single set of parameters
22 | applied to all edges, there might be a different one for each edge. The previous property ensures that all the operators
23 | are of the same type, so only their parameters might differ from one row to another. To account for this, the operators
24 | for dynamic relations are implemented differently, with a single operator object containing the parameters for all
25 | relation types. This implementation detail should be transparent as to how the operators are applied to the embeddings,
26 | but might come up when retrieving the parameters at the end of training.
27 |
28 | - With non-dynamic relations, the operator is applied to the embedding of the right-hand side entity of the edge, whereas
29 | the embedding of the left-hand side entity is left unchanged. In a given batch, denote the :math:`i`-th positive edge
30 | by :math:`(x_i, r, y_i)` (:math:`x_i` and :math:`y_i` being the left- and right-hand side entities, :math:`r` being the
31 | relation type).
For each of the positive edges, denote its :math:`j`-th negative sample by :math:`(x_i, r, y'_{i,j})`.
32 | Due to :ref:`same-batch negative sampling ` it may occur that the same right-hand side
33 | entity is used as a negative for several positives, that is, that :math:`y'_{i_1,j_1} = y'_{i_2,j_2}` for
34 | :math:`i_1 \neq i_2`. However, since it's the same relation type :math:`r` for all negatives, all the right-hand side
35 | entities will be transformed in the same way (i.e., passed through :math:`r`'s operator) no matter what positive edge
36 | they are a negative for. We need to apply the operator of :math:`r` to all of them, hence the total number of operator
37 | evaluations is equal to the number of positives and negatives.
38 |
39 | In the case of dynamic relations the batch contains edges of the form :math:`(x_i, r_i, y_i)`, with possibly a different
40 | :math:`r_i` for each :math:`i`. If negative sampling and operator application worked the same, it might end up being
41 | necessary to transform each right-hand side entity multiple times in several ways, once for each different relation
42 | type of the edges the entity is a negative for. This would multiply the number of required operations by a significant
43 | factor and cause a noticeable performance hit.
44 |
45 | To counter this, operators are applied differently in the case of dynamic relations. They are applied to *either* the
46 | left- *or* the right-hand side (never both at the same time), and a different set of parameters is used in each of
47 | these two cases. On an input edge :math:`(x_i, r_i, y_i)` both ways of applying the operators are performed (separately).
48 | For the negatives of the form :math:`(x'_{i,j}, r_i, y_i)` (i.e., with the left-hand side entity changed), the operator
49 | is only applied to the right-hand side. Symmetrically, on :math:`(x_i, r_i, y'_{i,j})`, the operator is only applied to
50 | the left-hand side. This means that the operator is only ever applied to the entities of the original positive input
51 | edge, not to the entities of the negatives. Thus the number of operator evaluations is equal to the number of input
52 | edges in the batch.
53 |
54 | One could imagine it as if, for each edge of a certain relation type, a reversed edge were added to the graph, of a
55 | symmetric relation type. For each of these edges, the operator is only applied to the right-hand side, just like with
56 | standard relations. However, when sampling negatives, only the left-hand side entities are replaced, whereas the
57 | right-hand ones are kept unchanged.
58 |
59 | For more insight about this, see also the "reciprocal predicates" described in `this paper <https://arxiv.org/pdf/1806.07297.pdf>`_.
60 |
--------------------------------------------------------------------------------
/docs/source/evaluation.rst:
--------------------------------------------------------------------------------
1 | Evaluation
2 | ==========
3 |
4 | During training, the average loss is reported for each edge bucket at each pass.
5 | Evaluation metrics can be computed on held-out data during or after training to
6 | measure the quality of trained embeddings.
7 |
8 | Offline evaluation
9 | ------------------
10 |
11 | The ``torchbiggraph_eval`` command will perform an offline evaluation of trained PBG embeddings on a validation dataset.
12 | This dataset should contain held-out data not included in the training dataset. It is invoked in the same
13 | way as the training command and takes the same arguments.
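One convenient way to manage the configs for training and evaluation, expanded
on just below, is a single config module that switches its edge paths based on
an environment variable; a minimal sketch (the variable name and all paths are
hypothetical)::

    import os

    def get_torchbiggraph_config():
        # PBG_EDGE_ROLE is a made-up variable chosen for this sketch.
        role = os.environ.get("PBG_EDGE_ROLE", "train")
        return dict(
            entity_path="data/example",
            edge_paths=["data/example/train" if role == "train" else "data/example/valid"],
            checkpoint_path="model/example",
            entities={"all": {"num_partitions": 1}},
            relations=[
                {"name": "all_edges", "lhs": "all", "rhs": "all", "operator": "none"}
            ],
            dimension=100,
        )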
14 | 15 | It is generally advisable to have two versions of the config file, one for training and one for evaluation, with the same 16 | parameters except for the edge paths, in order to evaluate a separate (and often smaller) set of edges. (It's also possible 17 | to use a single config file and have it produce different output based on environment variables or other context). 18 | Training-specific config parameters (e.g., the learning rate, loss function, ...) will be ignored during evaluation. 19 | 20 | The metrics are first reported on each bucket, and a global average is computed at the end. 21 | (If multiple edge paths are in use, metrics are computed separately for each of them but still ultimately averaged). 22 | 23 | Many metrics are statistics based on the "ranks" of the edges of the validation set. 24 | The rank of a positive edge is determined by the rank of its score against the scores of 25 | :ref:`a certain number of negative edges `. A rank of 1 is the "best" 26 | outcome as it means that the positive edge had a higher score than all the negatives. Higher 27 | values are "worse" as they indicate that the positive didn't stand out. 28 | 29 | It may happen that some of the negative samples used in the rank computation are in fact 30 | other positive samples, which are expected to have a high score and may thus cause adverse effects on the rank. 31 | This effect is especially visible on smaller graphs, in particular when all other entities are used to construct the negatives. 32 | To fix it, and to match what is typically done in the literature, 33 | a so-called "filtered" rank is used in the FB15k demo script (and there only), where positive 34 | samples are filtered out when computing the rank of an edge. It is hard to scale this technique 35 | to large graphs, and thus it is not enabled globally. However, filtering is less important 36 | on large graphs as it's less likely to see a training edge among the sampled negatives. 37 | 38 | The metrics are: 39 | 40 | - **Mean Rank**: the average of the ranks of all positives (lower is better, best is 1). 41 | - **Mean Reciprocal Rank (MRR)**: the average of the *reciprocal* of the ranks of all positives (higher is better, best is 1). 42 | - **Hits@1**: the fraction of positives that rank better than all their negatives, i.e., have a rank of 1 (higher is better, best is 1). 43 | - **Hits@10**: the fraction of positives that rank in the top 10 among their negatives (higher is better, best is 1). 44 | - **Hits@50**: the fraction of positives that rank in the top 50 among their negatives (higher is better, best is 1). 45 | - **Area Under the Curve (AUC)**: an estimation of the probability that a randomly chosen positive scores higher than a 46 | randomly chosen negative (*any* negative, not only the negatives constructed by corrupting that positive). 47 | 48 | .. _evaluation-during-training: 49 | 50 | Evaluation during training 51 | -------------------------- 52 | 53 | Offline evaluation is a slow process that is intended to be run after training is complete 54 | to evaluate the final model on a held-out set of edges constructed by the user. However, it's 55 | useful to be able to monitor overfitting as training progresses. PBG offers this functionality, 56 | by calculating the same metrics as the offline evaluation before and after each pass on a 57 | small set of training edges. These stats are printed to the logs. 58 | 59 | The metrics are computed on a set of edges that is held out automatically from the training set. 
To be more explicit:
60 | using this feature means that training happens on *fewer* edges, as some are excluded and reserved for this evaluation.
61 | The holdout fraction is controlled by the ``eval_fraction`` config parameter (setting it to zero thus disables this
62 | feature). The evaluations before and after each training iteration happen on the same set of edges, and are thus comparable.
63 | Moreover, the evaluations for the same edge chunk, edge path and bucket at different epochs also use the same set of edges.
64 |
65 | Evaluation metrics are computed both before and after training each edge bucket because this provides insight into
66 | whether the partitioned training is working. If the partitioned training is converging, then the gap between the "before"
67 | and "after" statistics should go to zero over time. On the other hand, if the partitioned training is causing the model to
68 | overfit on each edge bucket (thus decreasing performance for other edge buckets) then there will be a persistent gap between
69 | the "before" and "after" statistics.
70 |
71 | It's possible to use different batch sizes for :ref:`same-batch ` and
72 | :ref:`uniform negative sampling ` by tuning the ``eval_num_batch_negs`` and the
73 | ``eval_num_uniform_negs`` config parameters.
74 |
--------------------------------------------------------------------------------
/docs/source/faq_troubleshooting.rst:
--------------------------------------------------------------------------------
1 | FAQ & Troubleshooting
2 | =====================
3 |
4 | Frequently Asked Questions
5 | --------------------------
6 |
7 | Undirected graphs
8 | ^^^^^^^^^^^^^^^^^
9 |
10 | Edges in PBG's :ref:`data model ` are always interpreted as directed.
11 | To operate on undirected data, it is often enough to replace each undirected edge
12 | ``a <-> b`` with two directed edges ``a -> b`` and ``b -> a``. In fact, even with data
13 | that is already directed, it may be beneficial to artificially add a "reversed"
14 | edge ``b ~> a`` for each original edge ``a -> b``, of a different relation type.
15 | This is automatically done by PBG in the :ref:`dynamic relations mode `.
16 |
17 | Common issues
18 | -------------
19 |
20 | Bus error
21 | ^^^^^^^^^
22 |
23 | Training might occasionally fail with a ``Bus error (core dumped)`` message, and
24 | no traceback. This is often caused by the inability to allocate enough *shared*
25 | memory, that is, memory that can be simultaneously accessed by multiple processes
26 | (this is needed to perform training in parallel). Such an error is produced by
27 | the kernel and it's hard to detect it in advance or catch it.
28 |
29 | This may occur when running PBG inside a Docker container, as by default the
30 | shared memory limit for them is rather small. `This PyTorch issue `_
31 | may provide some insight into how to address that. If this occurs on a Linux machine,
32 | it may be fixed by increasing the size of the ``tmpfs`` mount on ``/dev/shm`` or
33 | on ``/var/run/shm``.
34 |
35 | It may also just be that the machine ran out of physical memory because the data
36 | is too large. This is exactly the scenario PBG was designed for, and can be fixed
37 | by increasing the number of partitions of the data.
--------------------------------------------------------------------------------
/docs/source/featurized_entities.rst:
--------------------------------------------------------------------------------
1 | .. _featurized-entities:
2 |
3 | Featurized entities
4 | ===================
5 |
6 | .. caution:: This is an advanced feature, which is still under development and hasn't fully stabilized yet.
7 |
8 | In normal operation PBG considers each entity atomic and distinct from all others, and as such it learns an independent embedding
9 | for each of them, with no correlation other than the one acquired during training. However, it is common practice to represent
10 | some types of data as collections of underlying "features", each of which has its own learned embedding. The embedding of an entity
11 | will be implicitly derived from the embeddings of its features. Sharing a feature will induce a correlation between the embeddings
12 | of two entities.
13 |
14 | For example, entities that represent text documents could have their words as features, i.e., an embedding is learned
15 | for each word and the embedding of a document is the average of the embeddings of the words it contains.
16 |
17 | PBG provides this capability. Featurized mode is activated on a per-entity-type basis by enabling the
18 | ``featurized`` flag on its config. As this feature isn't finalized yet, the tooling around it isn't up to par
19 | with non-featurized entities, in particular for converting featurized edgelists to the PBG format.
20 | Practitioners will have to implement their own converters, based on the format described below.
21 | Contributions of converters to and from standard formats are welcome.
22 |
23 | The following changes occur in the training process when featurized entities are enabled:
24 |
25 | - The count stored in the :file:`entity_count_{type}_{part}.txt` file refers to the total number of different *features*
26 | that are encountered in the edge files, rather than to the number of different sets of features.
27 |
28 | - Each edge file :file:`edges_{lhs}_{rhs}.h5` must contain a few more datasets. If any edge in it has a featurized
29 | entity on the left-hand side then it must contain two one-dimensional datasets of integers: ``lhsd_data``, which
30 | contains the flattened concatenation of the lists of features of all left-hand side entities of the edges in the file,
31 | and ``lhsd_offsets``, which contains the "cutpoints" of ``lhsd_data`` where the feature list of one entity ends and
32 | the one for the next entity starts.
33 |
34 | Thus the *entries* of ``lhsd_data`` are feature identifiers, while the *entries* of ``lhsd_offsets`` are *indices* of
35 | ``lhsd_data``. Each pair of consecutive entries of ``lhsd_offsets`` represents a half-open interval of ``lhsd_data``,
36 | thus the first entry of ``lhsd_offsets`` should be 0, the last entry should be the size of ``lhsd_data``, and entries
37 | should be in non-decreasing order. If the edge file contains :math:`N` edges, then ``lhsd_offsets`` must contain
38 | :math:`N + 1` entries.
39 |
40 | * If the left-hand side entity of edge :math:`i` is featurized, then its features will be the values of ``lhsd_data``
41 | between positions ``lhsd_offsets``:math:`[i]` (inclusive) and ``lhsd_offsets``:math:`[i+1]` (exclusive).
42 | The :math:`i`-th entry of the ``lhs`` dataset, on the other hand, can be any value, as it will be ignored.
43 |
44 | * If the left-hand side entity of edge :math:`i` is *not* featurized, then the index of the entity will be in
45 | ``lhs``:math:`[i]`, just as usual. In that case its set of features should be empty, that is, one should have
46 | ``lhsd_offsets``:math:`[i]` equal to ``lhsd_offsets``:math:`[i+1]`.
47 |
48 | If any right-hand side entity is featurized, the same must hold for datasets ``rhsd_offsets`` and ``rhsd_data``.
49 |
50 | - Entities are represented as "bags of features". That is, their embeddings will be the average of the embeddings of their
51 | features, similarly to how text documents can be represented as the average of the embeddings of the words they contain.
52 |
53 | - The only form of :ref:`negative sampling ` supported for featurized entities is the
54 | :ref:`same-batch mode `. Neither the :ref:`all negatives ` nor the
55 | :ref:`uniformly-sampled negatives mode ` is supported. Observe that this means that
56 | uniform sampling of negatives must be disabled globally.
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. torchbiggraph documentation master file, created by
2 | sphinx-quickstart on Tue Jan 22 11:14:33 2019.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to PyTorch-BigGraph's documentation!
7 | ============================================
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 | :caption: Contents:
12 |
13 | data_model
14 | scoring
15 | input_output
16 | batch_preparation
17 | distributed_training
18 | loss_optimization
19 | evaluation
20 | dynamic_relations
21 | featurized_entities
22 | configuration_file
23 | faq_troubleshooting
24 | related
25 | downstream_tasks
26 | pretrained_embeddings
27 |
28 |
29 | Indices and tables
30 | ==================
31 |
32 | * :ref:`genindex`
33 | * :ref:`modindex`
34 | * :ref:`search`
35 |
36 |
37 | Legal
38 | =====
39 |
40 | * `Terms of Use `_
41 | * `Privacy Policy `_
42 |
--------------------------------------------------------------------------------
/docs/source/input_output.rst:
--------------------------------------------------------------------------------
1 | .. _io-format:
2 |
3 | I/O format
4 | ==========
5 |
6 | Entity and relation types
7 | -------------------------
8 |
9 | The list of entity types (each identified by a string), plus some information
10 | about each of them, is given in the ``entities`` dictionary in the configuration file.
11 | The list of relation types (each identified by its index in that list), plus
12 | some data like what their left- and right-hand side entity types are, is in the
13 | ``relations`` key of the configuration file.
14 |
15 | Entities
16 | --------
17 |
18 | The only information that needs to be provided about entities is how many there
19 | are in each entity type's partition. This is done by putting a file named :file:`entity_count_{type}_{part}.txt` for each entity type identified
20 | by ``type`` and each partition ``part``. These files must contain a single
21 | integer (as text), which is the number of entities in that partition. The directory where all these
22 | files reside must be specified as the ``entity_path`` key of the configuration file.
23 |
24 | It is possible to provide an initial value for the embeddings, by specifying a
25 | value for the ``init_path`` configuration key, which is the name of a directory that
26 | contains files in a format similar to the output format detailed in
27 | :ref:`output-format` (possibly without the optimizer state dicts).
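For concreteness, here is a minimal sketch of how the entity count files
described above could be produced (entity types, counts and the directory name
are all hypothetical)::

    import os

    entity_path = "data/example"  # must match the config's entity_path key
    counts = {"user": [1000, 950], "item": [4000]}  # per-type, per-partition counts

    os.makedirs(entity_path, exist_ok=True)
    for entity_type, partition_counts in counts.items():
        for part, num_entities in enumerate(partition_counts):
            file_name = f"entity_count_{entity_type}_{part}.txt"
            with open(os.path.join(entity_path, file_name), "wt") as f:
                f.write(f"{num_entities}\n")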
28 |
29 | If no initial value is provided, it will be auto-generated, with each dimension
30 | sampled from the centered normal distribution whose standard deviation can be
31 | configured using the ``init_scale`` configuration key. For performance reasons
32 | the samples of all the entities of a certain type will not be independent.
33 |
34 | Edges
35 | -----
36 |
37 | For each bucket there must be a file that stores all the edges that fall in that
38 | bucket, of all relation types. This means that such a file is only identified by
39 | two integers, the partitions of its left- and right-hand side entities. It must
40 | be named :file:`edges_{lhs}_{rhs}.h5` (where ``lhs`` and ``rhs`` are the above
41 | integers) and must be an `HDF5 `_ file
42 | containing three one-dimensional datasets of the same length, called ``rel``,
43 | ``lhs`` and ``rhs``. The elements in the :math:`i`-th positions in each of them
44 | define the :math:`i`-th edge: ``rel`` identifies the relation type (and thus the
45 | left- and right-hand side entity types), ``lhs`` and ``rhs`` give the indices
46 | of the left- and right-hand side entities within their respective partitions.
47 |
48 | To ease future updates to this format, each file must contain the format version
49 | in the ``format_version`` attribute of the top-level group. The current version is 1.
50 |
51 | If an entity type is unpartitioned (that is, all its entities belong to the
52 | same partition), then the edges incident to these entities must still be
53 | uniformly spread across all buckets.
54 |
55 | These files, for all buckets, must be stored in the same directory, which must
56 | be passed as the ``edge_paths`` configuration key. That key can actually contain
57 | a list of paths, each pointing to a directory of the format described above: in
58 | that case the graph will contain the union of all their edges.
59 |
60 | .. _output-format:
61 |
62 | Checkpoint
63 | ----------
64 |
65 | The training's checkpoints are also its output, and they are written to the directory
66 | given as the ``checkpoint_path`` parameter in the configuration. Checkpoints are identified
67 | by successive positive integers, starting from 1, and all the files belonging to
68 | a certain checkpoint have an extra component :file:`.v{version}` between their name and extension
69 | (e.g., :file:`{something}.v42.h5` for version 42).
70 |
71 | The latest complete checkpoint version is stored in an additional file in the same directory, called
72 | :file:`checkpoint_version.txt`, which contains a single integer number, the current version.
73 |
74 | Each checkpoint contains a JSON dump of the config that was used to produce it, stored in the :file:`config.json` file.
75 |
76 | After a new checkpoint version is saved, the previous one will automatically be
77 | deleted. In order to periodically preserve some of these versions, set the
78 | ``checkpoint_preservation_interval`` config flag to the desired period (expressed
79 | in number of epochs).
80 |
81 | Model parameters
82 | ^^^^^^^^^^^^^^^^
83 |
84 | The model parameters are stored in a file named :file:`model.h5`, which is an HDF5 file containing
85 | one dataset for each parameter, all of which are located within the ``model`` group. Currently, the
86 | parameters that are provided are:
87 |
88 | - :samp:`model/relations/{idx}/operator/{side}/{param}` with the parameters of each relation's operator.
89 | - :samp:`model/entities/{type}/global_embedding` with the per-entity-type global embedding.
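As a quick illustration, the datasets and their attributes (including the
``state_dict_key`` attribute described next) can be listed with ``h5py``; a
minimal sketch (the checkpoint path is hypothetical)::

    import h5py

    with h5py.File("model/example/model.v7.h5", "r") as f:

        def show(name, obj):
            # Print each parameter dataset with its shape and state dict key.
            if isinstance(obj, h5py.Dataset):
                print(name, obj.shape, obj.attrs.get("state_dict_key"))

        f["model"].visititems(show)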
90 |
91 | Each of these parameter datasets also contains, in its ``state_dict_key`` attribute, the key under
92 | which it was stored in the model state dict. An additional dataset may exist, ``optimizer/state_dict``, which contains the binary blob
93 | (obtained through :func:`torch.save`) of the state dict of the model's optimizer.
94 |
95 | Finally, the top-level group of the file contains a few attributes with additional metadata. This mainly
96 | includes the format version, a JSON dump of the config and some information about the iteration that produced
97 | the checkpoint.
98 |
99 | Embeddings
100 | ^^^^^^^^^^
101 |
102 | In addition, for each entity type and each of its partitions, there is a file
103 | :file:`embeddings_{type}_{part}.h5` (where ``type`` is the type's name and ``part``
104 | is the 0-based index of the partition), which is an HDF5 file with two datasets.
105 | One two-dimensional dataset, called ``embeddings``, contains the embeddings of
106 | the entities, with the first dimension being the number of entities and the
107 | second being the dimension of the embedding.
108 |
109 | Just like for the model parameters file, the optimizer state dict and additional metadata are also included.
110 |
--------------------------------------------------------------------------------
/docs/source/pretrained_embeddings.rst:
--------------------------------------------------------------------------------
1 | Pre-trained embeddings
2 | ======================
3 |
4 | For demonstration purposes and to save users time, we provide pre-trained embeddings for
5 | some common public datasets.
6 |
7 | .. _wiki-data:
8 |
9 | Wikidata
10 | --------
11 |
12 | `Wikidata `_ is a well-known knowledge base, which includes the discontinued Freebase
13 | knowledge base.
14 |
15 | We used the so-called "truthy" dump from 2019-03-06, in the RDF NTriples format. (The original file is no longer
16 | available on the Wikidata website.) We used as entities all the distinct strings that appeared as either source or
17 | target nodes in this dump: this means that entities include URLs of Wikidata entities (in the form :samp:``),
18 | plain quoted strings (e.g., :samp:`"{Foo}"`), strings with language annotation (e.g., :samp:`"{Bar}"@{fr}`), dates and times, and possibly more.
19 | Similarly, we used as relation types all the distinct strings that appeared as properties. We then filtered out entities and relation types that
20 | appeared fewer than 5 times in the data dump.
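That filtering step amounts to a simple counting pass over the dump. The
snippet below is a rough sketch rather than the exact pipeline that was used:
the file name is a placeholder and the line parsing is naive (real NTriples
data has corner cases that it ignores)::

    from collections import Counter

    entity_counts: Counter = Counter()
    relation_counts: Counter = Counter()
    with open("wikidata-truthy.nt", "rt") as f:  # placeholder file name
        for line in f:
            # Naive NTriples split: subject, property, object (the object may
            # itself contain spaces, hence maxsplit=2).
            subject, prop, obj = line.rstrip(" .\n").split(" ", 2)
            entity_counts[subject] += 1
            entity_counts[obj] += 1
            relation_counts[prop] += 1

    entities = [e for e, c in entity_counts.items() if c >= 5]
    relation_types = [r for r, c in relation_counts.items() if c >= 5]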
21 | 22 | The embeddings were trained with the following configuration:: 23 | 24 | def get_torchbiggraph_config(): 25 | 26 | config = dict( 27 | # I/O data 28 | entity_path='data/wikidata', 29 | edge_paths=[], 30 | checkpoint_path='model/wikidata', 31 | 32 | # Graph structure 33 | entities={ 34 | 'all': {'num_partitions': 1}, 35 | }, 36 | relations=[{ 37 | 'name': 'all_edges', 38 | 'lhs': 'all', 39 | 'rhs': 'all', 40 | 'operator': 'translation', 41 | }], 42 | dynamic_relations=True, 43 | 44 | # Scoring model 45 | dimension=200, 46 | global_emb=False, 47 | comparator='dot', 48 | 49 | # Training 50 | num_epochs=4, 51 | num_edge_chunks=10, 52 | batch_size=10000, 53 | num_batch_negs=500, 54 | num_uniform_negs=500, 55 | loss_fn='softmax', 56 | lr=0.1, 57 | relation_lr=0.01, 58 | 59 | # Evaluation during training 60 | eval_fraction=0.001, 61 | eval_num_batch_negs=10000, 62 | eval_num_uniform_negs=0, 63 | 64 | # Misc 65 | verbose=1, 66 | ) 67 | 68 | return config 69 | 70 | The output embeddings are available in various formats: 71 | 72 | - `wikidata_translation_v1.tsv.gz `_ (36GiB), 73 | a gzipped TSV (tab-separated value) file in an old format produced by ``torchbiggraph_export_to_tsv`` 74 | (see :ref:`here ` for how to parse it). 75 | - `wikidata_translation_v1_names.json.gz `_ (378MiB), 76 | a gzipped JSON-encoded list of all the keys in the first column of the TSV file. 77 | - `wikidata_translation_v1_vectors.npy.gz `_ (39.9GiB), 78 | a gzipped serialized NumPy array with the 200-dimension vectors, one for each line of the TSV file. 79 | -------------------------------------------------------------------------------- /docs/source/related.rst: -------------------------------------------------------------------------------- 1 | Related works 2 | ============= 3 | 4 | PBG was designed with the expertise gained from many previous works in the knowledge 5 | base completion literature, integrating what has been shown to work well over time. 6 | In the sections below, we describe the models that inspired some of the operators and features of PBG. 7 | 8 | TransE 9 | ------ 10 | 11 | TransE_ is a popular model in knowledge base completion due to its simplicity: 12 | when two embeddings are compared to calculate the score of an edge between them, 13 | the right-hand side one is first translated by a vector :math:`v_r` (of the same 14 | dimension as the embeddings) that is specific to the relation type. Contrary to 15 | PBG, TransE aims at giving lower scores to entities that are nearby, hence the 16 | score of a triple :math:`(x, r, y)` is computed as: 17 | 18 | .. math:: 19 | s(x, r, y) = d(\theta_x + v_r - \theta_y) 20 | 21 | where :math:`d` is a dissimilarity function such as the :math:`L_1` or :math:`L_2` 22 | norm. 23 | 24 | PBG can be configured to operate like TransE by using the ``translation`` operator 25 | and by introducing a new :ref:`comparator ` based on the desired 26 | dissimilarity function. However, contrary to the dot product or the cosine distance, 27 | the comparison between all pairs of vectors from two sets using the :math:`L_1` and 28 | :math:`L_2` norms cannot be expressed as a matrix multiplication and thus would be 29 | challenging to implement as efficiently. One could consider using the cosine distance 30 | instead of the :math:`L_2` norm since the former measures (the cosine of) the angle 31 | between two vectors which, when small, is approximately their :math:`L_2` distance. 32 | 33 | RESCAL 34 | ------ 35 | 36 | RESCAL_ is a restriction on the Tucker factorization. 
Relations are represented
37 | as matrices :math:`M_r`, and the score of a triple :math:`(x, r, y)` is computed as:
38 |
39 | .. math::
40 | s(x, r, y) = \theta^{\top}_x M_r \theta_y
41 |
42 | This corresponds to PBG's ``linear`` operator. The original paper suggests using
43 | weight decay on the parameters of the model. Such regularization is currently not available
44 | in PBG, which instead relies on early stopping and on capping the maximum
45 | norm of the embeddings, which scales more easily.
46 |
47 | The need for weight decay stems from each relation having a lot of parameters,
48 | which could lead them to overfit and not perform well on, for example, FB15k.
49 | RESCAL should only be considered for models with a large number of edges for
50 | each relation type, where overfitting is not an issue.
51 |
52 | DistMult
53 | --------
54 |
55 | DistMult_ is a special case of RESCAL, in which relations are limited to diagonal
56 | matrices represented as vectors :math:`v_r`. The score of a triple :math:`(x, r, y)`
57 | is thus:
58 |
59 | .. math::
60 | s(x, r, y) = \langle \theta_x, v_r, \theta_y \rangle = \sum_{d=1}^D \theta_{x, d} v_{r, d} \theta_{y, d}
61 |
62 | This is the ``diagonal`` operator in PBG. Notice that, with the same embedding
63 | space on the left- and right-hand side, this operator is limited to representing
64 | symmetric relations. This restriction, however, leads to less over-fitting and good
65 | performance on `several benchmarks `_.
66 |
67 | ComplEx
68 | -------
69 |
70 | ComplEx_ is similar to DistMult, but uses embeddings in :math:`\mathbb{C}` and
71 | represents the right-hand side embeddings as complex conjugates of the left-hand side ones.
72 | This makes it possible to represent non-symmetric relations. The score of a triple :math:`(x, r, y)`
73 | is computed as:
74 |
75 | .. math::
76 | s(x, r, y) = \operatorname{Re}(\langle \theta_x, v_r, \overline{\theta_y} \rangle)
77 |
78 | The ``complex_diagonal`` operator in PBG interprets a :math:`D`-dimensional real
79 | embedding as a :math:`D/2`-dimensional complex one, with the first :math:`D/2`
80 | values representing the real part and the remaining ones the imaginary part.
81 | As shown in the original paper, the ComplEx score can then be written as a dot
82 | product in :math:`\mathbb{R}^{D}`, and can hence be replicated in PBG using the ``dot`` comparator.
83 |
84 | Reciprocal Relations
85 | --------------------
86 |
87 | Two papers (`[1] `_ and `[2] `_)
88 | simultaneously suggested explicitly training on reciprocal relations, i.e., adding, for each triple
89 | :math:`(x, r, y)` in the training set, another one :math:`(y, r', x)`. This can
90 | be done implicitly in PBG with :ref:`dynamic relations `. Jointly with
91 | the ``complex_diagonal`` operator, this allows PBG to reproduce state-of-the-art results on ``FB15K``.
92 |
93 | .. _TransE: http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf
94 | .. _RESCAL: http://www.icml-2011.org/papers/438_icmlpaper.pdf
95 | .. _DistMult: https://arxiv.org/pdf/1412.6575v4.pdf
96 | .. _ComplEx: http://proceedings.mlr.press/v48/trouillon16.pdf
97 |
--------------------------------------------------------------------------------
/ifbpy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import warnings 10 | 11 | from libfb.py.ipython_par import launch_ipython 12 | 13 | 14 | warnings.simplefilter("ignore", category=DeprecationWarning) 15 | warnings.simplefilter("ignore", category=ResourceWarning) 16 | 17 | 18 | launch_ipython() 19 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = torchbiggraph 3 | version = file: torchbiggraph/VERSION.txt 4 | url = https://github.com/facebookresearch/PyTorch-BigGraph 5 | project_urls = 6 | Source = https://github.com/facebookresearch/PyTorch-BigGraph 7 | Bug Reports = https://github.com/facebookresearch/PyTorch-BigGraph/issues 8 | Documentation = https://torchbiggraph.readthedocs.io/ 9 | author = Facebook AI Research 10 | classifiers = 11 | Development Status :: 5 - Production/Stable 12 | Environment :: Console 13 | Intended Audience :: Science/Research 14 | License :: OSI Approved :: BSD License 15 | Operating System :: OS Independent 16 | Programming Language :: Python 17 | Programming Language :: Python :: 3 18 | Programming Language :: Python :: 3.6 19 | Programming Language :: Python :: 3.7 20 | Programming Language :: Python :: 3 :: Only 21 | Topic :: Scientific/Engineering :: Artificial Intelligence 22 | # Already provided as a classifier. 23 | # license = BSD License 24 | license_files = 25 | LICENSE.txt 26 | torchbiggraph/examples/LICENSE.txt 27 | description = A distributed system to learn embeddings of large graphs 28 | long_description = file: README.md 29 | long_description_content_type = text/markdown 30 | keywords = 31 | machine-learning 32 | knowledge-base 33 | graph-embedding 34 | link-prediction 35 | test_suite = test 36 | 37 | [options] 38 | setup_requires = 39 | setuptools >= 39.2 40 | install_requires = 41 | attrs >= 18.2 42 | h5py >= 2.8 43 | numpy 44 | setuptools 45 | torch >= 1 46 | tqdm 47 | python_requires = >=3.6, <4 48 | packages = find: 49 | 50 | [options.extras_require] 51 | docs = Sphinx 52 | parquet = parquet 53 | 54 | [options.entry_points] 55 | console_scripts = 56 | torchbiggraph_config = torchbiggraph.config:main 57 | torchbiggraph_eval = torchbiggraph.eval:main 58 | torchbiggraph_example_fb15k = torchbiggraph.examples.fb15k:main 59 | torchbiggraph_example_livejournal = torchbiggraph.examples.livejournal:main 60 | torchbiggraph_export_to_tsv = torchbiggraph.converters.export_to_tsv:main 61 | torchbiggraph_import_from_tsv = torchbiggraph.converters.import_from_tsv:main 62 | torchbiggraph_partitionserver = torchbiggraph.partitionserver:main 63 | torchbiggraph_train = torchbiggraph.train:main 64 | torchbiggraph_import_from_parquet = torchbiggraph.converters.import_from_parquet:main [parquet] 65 | 66 | 67 | [options.packages.find] 68 | exclude = 69 | docs 70 | test 71 | 72 | [options.package_data] 73 | torchbiggraph = 74 | VERSION.txt 75 | torchbiggraph.examples = 76 | configs/*.py 77 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 
5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import os 10 | 11 | from setuptools import setup 12 | from torch.utils import cpp_extension 13 | 14 | 15 | if __name__ == "__main__": 16 | if int(os.getenv("PBG_INSTALL_CPP", 0)) == 0: 17 | setup() 18 | else: 19 | setup( 20 | ext_modules=[ 21 | cpp_extension.CppExtension( 22 | "torchbiggraph._C", ["torchbiggraph/util.cpp"] 23 | ) 24 | ], 25 | cmdclass={"build_ext": cpp_extension.BuildExtension}, 26 | ) 27 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/PyTorch-BigGraph/6fccd3f572f530469fe97ee63669524c686bcc2a/test/__init__.py -------------------------------------------------------------------------------- /test/test_bucket_scheduling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import random 10 | from itertools import product 11 | from unittest import main, TestCase 12 | 13 | from torchbiggraph.bucket_scheduling import create_ordered_buckets 14 | from torchbiggraph.config import BucketOrder 15 | 16 | 17 | class TestCreateOrderedBuckets(TestCase): 18 | def test_valid(self) -> None: 19 | """Ensure every method produces a valid order (contain all pairs once). 20 | 21 | Even if it may not be the intended order. 22 | 23 | """ 24 | orders = [ 25 | BucketOrder.RANDOM, 26 | BucketOrder.AFFINITY, 27 | BucketOrder.INSIDE_OUT, 28 | BucketOrder.OUTSIDE_IN, 29 | ] 30 | shapes = [(4, 4), (3, 5), (6, 1), (1, 6), (1, 1)] 31 | generator = random.Random() 32 | 33 | for order in orders: 34 | for nparts_lhs, nparts_rhs in shapes: 35 | seed = random.getrandbits(32) 36 | with self.subTest( 37 | order=order, shape=(nparts_lhs, nparts_rhs), seed=seed 38 | ): 39 | generator.seed(seed) 40 | actual_buckets = create_ordered_buckets( 41 | nparts_lhs=nparts_lhs, 42 | nparts_rhs=nparts_rhs, 43 | order=order, 44 | generator=generator, 45 | ) 46 | 47 | self.assertCountEqual( 48 | actual_buckets, product(range(nparts_lhs), range(nparts_rhs)) 49 | ) 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /test/test_checkpoint_manager.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 
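# TwoWayMapping translates between a "private" naming scheme and a "public" one
# (in PBG, between model state dict keys and checkpoint dataset names); these
# tests exercise both directions, including the rejection of malformed keys.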
8 | 9 | import json 10 | from unittest import main, TestCase 11 | 12 | from torchbiggraph.checkpoint_manager import ConfigMetadataProvider, TwoWayMapping 13 | from torchbiggraph.config import ConfigSchema, EntitySchema, RelationSchema 14 | 15 | 16 | class TestTwoWayMapping(TestCase): 17 | def test_one_field(self) -> None: 18 | m = TwoWayMapping("foo.bar.{field}", "{field}/ham/eggs", fields=["field"]) 19 | self.assertEqual(m.private_to_public.map("foo.bar.baz"), "baz/ham/eggs") 20 | self.assertEqual(m.public_to_private.map("spam/ham/eggs"), "foo.bar.spam") 21 | with self.assertRaises(ValueError): 22 | m.private_to_public.map("f00.b4r.b4z") 23 | with self.assertRaises(ValueError): 24 | m.private_to_public.map("foo.bar") 25 | with self.assertRaises(ValueError): 26 | m.private_to_public.map("foo.bar.") 27 | with self.assertRaises(ValueError): 28 | m.private_to_public.map("foo.bar.baz.2") 29 | with self.assertRaises(ValueError): 30 | m.private_to_public.map("2.foo.bar.baz") 31 | with self.assertRaises(ValueError): 32 | m.public_to_private.map("sp4m/h4m/3gg5") 33 | with self.assertRaises(ValueError): 34 | m.public_to_private.map("ham/eggs") 35 | with self.assertRaises(ValueError): 36 | m.public_to_private.map("/ham/eggs") 37 | with self.assertRaises(ValueError): 38 | m.public_to_private.map("2/spam/ham/eggs") 39 | with self.assertRaises(ValueError): 40 | m.public_to_private.map("spam/ham/eggs/2") 41 | 42 | def test_many_field(self) -> None: 43 | m = TwoWayMapping( 44 | "fo{field1}.{field2}ar.b{field3}z", 45 | "sp{field3}m/{field2}am/egg{field1}", 46 | fields=["field1", "field2", "field3"], 47 | ) 48 | self.assertEqual(m.private_to_public.map("foo.bar.baz"), "spam/bam/eggo") 49 | self.assertEqual(m.public_to_private.map("spam/ham/eggs"), "fos.har.baz") 50 | 51 | 52 | class TestConfigMetadataProvider(TestCase): 53 | def test_basic(self) -> None: 54 | config = ConfigSchema( 55 | entities={"e": EntitySchema(num_partitions=1)}, 56 | relations=[RelationSchema(name="r", lhs="e", rhs="e")], 57 | dimension=1, 58 | entity_path="foo", 59 | edge_paths=["bar"], 60 | checkpoint_path="baz", 61 | ) 62 | metadata = ConfigMetadataProvider(config).get_checkpoint_metadata() 63 | self.assertIsInstance(metadata, dict) 64 | self.assertCountEqual(metadata.keys(), ["config/json"]) 65 | self.assertEqual( 66 | config, ConfigSchema.from_dict(json.loads(metadata["config/json"])) 67 | ) 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /test/test_distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 
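# ProcessRanks lays out the distributed ranks in contiguous blocks: trainers
# first, then parameter servers, then parameter clients, then the lock server,
# and finally any partition servers (-1 meaning one per trainer).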
8 | 9 | from unittest import main, TestCase 10 | 11 | from torchbiggraph.distributed import ProcessRanks 12 | 13 | 14 | class TestProcessRanks(TestCase): 15 | def test_implicit_partition_servers(self) -> None: 16 | ranks = ProcessRanks.from_num_invocations(3, -1) 17 | self.assertEqual(ranks.trainers, [0, 1, 2]) 18 | self.assertEqual(ranks.parameter_servers, [3, 4, 5]) 19 | self.assertEqual(ranks.parameter_clients, [6, 7, 8]) 20 | self.assertEqual(ranks.lock_server, 9) 21 | self.assertEqual(ranks.partition_servers, [10, 11, 12]) 22 | 23 | def test_no_partition_servers(self) -> None: 24 | ranks = ProcessRanks.from_num_invocations(4, 0) 25 | self.assertEqual(ranks.trainers, [0, 1, 2, 3]) 26 | self.assertEqual(ranks.parameter_servers, [4, 5, 6, 7]) 27 | self.assertEqual(ranks.parameter_clients, [8, 9, 10, 11]) 28 | self.assertEqual(ranks.lock_server, 12) 29 | self.assertEqual(ranks.partition_servers, []) 30 | 31 | def test_explicit_partition_servers(self) -> None: 32 | ranks = ProcessRanks.from_num_invocations(5, 3) 33 | self.assertEqual(ranks.trainers, [0, 1, 2, 3, 4]) 34 | self.assertEqual(ranks.parameter_servers, [5, 6, 7, 8, 9]) 35 | self.assertEqual(ranks.parameter_clients, [10, 11, 12, 13, 14]) 36 | self.assertEqual(ranks.lock_server, 15) 37 | self.assertEqual(ranks.partition_servers, [16, 17, 18]) 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /test/test_entitylist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 
8 | 9 | from typing import Sequence 10 | from unittest import main, TestCase 11 | 12 | import torch 13 | from torchbiggraph.entitylist import EntityList 14 | from torchbiggraph.tensorlist import TensorList 15 | 16 | 17 | def tensor_list_from_lists(lists: Sequence[Sequence[int]]) -> TensorList: 18 | offsets = torch.tensor([0] + [len(l) for l in lists], dtype=torch.long).cumsum( 19 | dim=0 20 | ) 21 | data = torch.cat([torch.tensor(l, dtype=torch.long) for l in lists], dim=0) 22 | return TensorList(offsets, data) 23 | 24 | 25 | class TestEntityList(TestCase): 26 | def test_empty(self) -> None: 27 | self.assertEqual( 28 | EntityList.empty(), 29 | EntityList(torch.empty((0,), dtype=torch.long), TensorList.empty()), 30 | ) 31 | 32 | def test_from_tensor(self) -> None: 33 | self.assertEqual( 34 | EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long)), 35 | EntityList( 36 | torch.tensor([3, 4], dtype=torch.long), TensorList.empty(num_tensors=2) 37 | ), 38 | ) 39 | 40 | def test_from_tensor_list(self) -> None: 41 | tensor_list = tensor_list_from_lists([[3, 4], [0, 2]]) 42 | self.assertEqual( 43 | EntityList.from_tensor_list(tensor_list), 44 | EntityList(torch.full((2,), -1, dtype=torch.long), tensor_list), 45 | ) 46 | 47 | def test_cat(self) -> None: 48 | tensor_1 = torch.tensor([2, 3], dtype=torch.long) 49 | tensor_2 = torch.tensor([0, 1], dtype=torch.long) 50 | tensor_sum = torch.tensor([2, 3, 0, 1], dtype=torch.long) 51 | tensor_list_1 = tensor_list_from_lists([[3, 4], [0]]) 52 | tensor_list_2 = tensor_list_from_lists([[1, 2, 0], []]) 53 | tensor_list_sum = tensor_list_from_lists([[3, 4], [0], [1, 2, 0], []]) 54 | self.assertEqual( 55 | EntityList.cat( 56 | [ 57 | EntityList(tensor_1, tensor_list_1), 58 | EntityList(tensor_2, tensor_list_2), 59 | ] 60 | ), 61 | EntityList(tensor_sum, tensor_list_sum), 62 | ) 63 | 64 | def test_constructor_checks(self) -> None: 65 | with self.assertRaises(ValueError): 66 | EntityList( 67 | torch.tensor([3, 4, 0], dtype=torch.long), 68 | tensor_list_from_lists([[2, 1]]), 69 | ) 70 | 71 | def test_to_tensor(self) -> None: 72 | self.assertTrue( 73 | torch.equal( 74 | EntityList( 75 | torch.tensor([2, 3], dtype=torch.long), 76 | tensor_list_from_lists([[], []]), 77 | ).to_tensor(), 78 | torch.tensor([2, 3], dtype=torch.long), 79 | ) 80 | ) 81 | 82 | def test_to_tensor_list(self) -> None: 83 | self.assertEqual( 84 | EntityList( 85 | torch.tensor([-1, -1], dtype=torch.long), 86 | tensor_list_from_lists([[3, 4], [0]]), 87 | ).to_tensor_list(), 88 | tensor_list_from_lists([[3, 4], [0]]), 89 | ) 90 | 91 | def test_equal(self) -> None: 92 | el = EntityList( 93 | torch.tensor([3, 4], dtype=torch.long), 94 | tensor_list_from_lists([[], [2, 1, 0]]), 95 | ) 96 | self.assertEqual(el, el) 97 | self.assertNotEqual( 98 | el, 99 | EntityList( 100 | torch.tensor([4, 2], dtype=torch.long), 101 | tensor_list_from_lists([[], [2, 1, 0]]), 102 | ), 103 | ) 104 | self.assertNotEqual( 105 | el, 106 | EntityList( 107 | torch.tensor([3, 4], dtype=torch.long), 108 | tensor_list_from_lists([[3], [2, 0]]), 109 | ), 110 | ) 111 | 112 | def test_len(self) -> None: 113 | self.assertEqual( 114 | len( 115 | EntityList( 116 | torch.tensor([3, 4], dtype=torch.long), 117 | tensor_list_from_lists([[], [2, 1, 0]]), 118 | ) 119 | ), 120 | 2, 121 | ) 122 | 123 | def test_getitem_int(self) -> None: 124 | self.assertEqual( 125 | EntityList( 126 | torch.tensor([3, 4, 1, 0], dtype=torch.long), 127 | tensor_list_from_lists([[2, 1], [0], [], [3, 4, 5]]), 128 | )[-3], 129 | EntityList( 130 | 
torch.tensor([4], dtype=torch.long), tensor_list_from_lists([[0]]) 131 | ), 132 | ) 133 | 134 | def test_getitem_slice(self) -> None: 135 | self.assertEqual( 136 | EntityList( 137 | torch.tensor([3, 4, 1, 0], dtype=torch.long), 138 | tensor_list_from_lists([[2, 1], [0], [], [3, 4, 5]]), 139 | )[1:3], 140 | EntityList( 141 | torch.tensor([4, 1], dtype=torch.long), 142 | tensor_list_from_lists([[0], []]), 143 | ), 144 | ) 145 | 146 | def test_getitem_longtensor(self) -> None: 147 | self.assertEqual( 148 | EntityList( 149 | torch.tensor([3, 4, 1, 0], dtype=torch.long), 150 | tensor_list_from_lists([[2, 1], [0], [], [3, 4, 5]]), 151 | )[torch.tensor([2, 0])], 152 | EntityList( 153 | torch.tensor([1, 3], dtype=torch.long), 154 | tensor_list_from_lists([[], [2, 1]]), 155 | ), 156 | ) 157 | 158 | 159 | if __name__ == "__main__": 160 | main() 161 | -------------------------------------------------------------------------------- /test/test_graph_storages.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import tempfile 10 | from unittest import main, TestCase 11 | 12 | import h5py 13 | import numpy as np 14 | import torch 15 | from torchbiggraph.graph_storages import FileEdgeAppender 16 | from torchbiggraph.tensorlist import TensorList 17 | 18 | 19 | class TestFileEdgeAppender(TestCase): 20 | def test_tensors(self) -> None: 21 | with tempfile.NamedTemporaryFile() as bf: 22 | with h5py.File(bf.name, "w") as hf, FileEdgeAppender(hf) as buffered_hf: 23 | buffered_hf.append_tensor( 24 | "foo", torch.tensor([1, 2, 3], dtype=torch.long) 25 | ) 26 | buffered_hf.append_tensor( 27 | "bar", torch.tensor([10, 11], dtype=torch.long) 28 | ) 29 | buffered_hf.append_tensor("foo", torch.tensor([4], dtype=torch.long)) 30 | buffered_hf.append_tensor("foo", torch.tensor([], dtype=torch.long)) 31 | buffered_hf.append_tensor( 32 | "bar", torch.arange(12, 1_000_000, dtype=torch.long) 33 | ) 34 | buffered_hf.append_tensor("foo", torch.tensor([5, 6], dtype=torch.long)) 35 | 36 | with h5py.File(bf.name, "r") as hf: 37 | np.testing.assert_equal( 38 | hf["foo"], np.array([1, 2, 3, 4, 5, 6], dtype=np.int64) 39 | ) 40 | np.testing.assert_equal( 41 | hf["bar"], np.arange(10, 1_000_000, dtype=np.int64) 42 | ) 43 | 44 | def test_tensor_list(self) -> None: 45 | with tempfile.NamedTemporaryFile() as bf: 46 | with h5py.File(bf.name, "w") as hf, FileEdgeAppender(hf) as buffered_hf: 47 | buffered_hf.append_tensor_list( 48 | "foo", 49 | TensorList( 50 | torch.tensor([0, 3, 5], dtype=torch.long), 51 | torch.tensor([1, 2, 3, 4, 5], dtype=torch.long), 52 | ), 53 | ) 54 | buffered_hf.append_tensor_list( 55 | "bar", 56 | TensorList( 57 | torch.tensor([0, 1_000_000], dtype=torch.long), 58 | torch.arange(1_000_000, dtype=torch.long), 59 | ), 60 | ) 61 | buffered_hf.append_tensor_list( 62 | "foo", 63 | TensorList( 64 | torch.tensor([0, 1, 1, 3], dtype=torch.long), 65 | torch.tensor([6, 7, 8], dtype=torch.long), 66 | ), 67 | ) 68 | 69 | with h5py.File(bf.name, "r") as hf: 70 | np.testing.assert_equal( 71 | hf["foo_offsets"], np.array([0, 3, 5, 6, 6, 8], dtype=np.int64) 72 | ) 73 | np.testing.assert_equal( 74 | hf["foo_data"], np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int64) 75 | ) 76 | np.testing.assert_equal( 77 | hf["bar_offsets"], 
np.array([0, 1_000_000], dtype=np.int64) 78 | ) 79 | np.testing.assert_equal( 80 | hf["bar_data"], np.arange(1_000_000, dtype=np.int64) 81 | ) 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /test/test_optimizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | # In order to keep values visually aligned in matrix form we use double spaces 10 | # and exceed line length. Tell flake8 to tolerate that. Ideally we'd want to 11 | # disable only those two checks but there doesn't seem to be a way to do so. 12 | # flake8: noqa 13 | 14 | import logging 15 | import os 16 | import unittest 17 | from unittest import main, TestCase 18 | 19 | import torch 20 | import torch.multiprocessing as mp 21 | import torch.nn as nn 22 | from torch.optim import Adagrad 23 | from torchbiggraph.async_adagrad import AsyncAdagrad 24 | from torchbiggraph.row_adagrad import RowAdagrad 25 | 26 | 27 | logger = logging.getLogger("torchbiggraph") 28 | 29 | 30 | class TensorTestCase(TestCase): 31 | def assertTensorEqual(self, actual, expected): 32 | if not isinstance(actual, (torch.FloatTensor, torch.cuda.FloatTensor)): 33 | self.fail("Expected FloatTensor, got %s" % type(actual)) 34 | if actual.size() != expected.size(): 35 | self.fail( 36 | "Expected tensor of size %s, got %s" % (expected.size(), actual.size()) 37 | ) 38 | if not torch.allclose( 39 | actual, expected, rtol=0.00005, atol=0.00005, equal_nan=True 40 | ): 41 | self.fail("Expected\n%r\ngot\n%r" % (expected, actual)) 42 | 43 | 44 | def do_optim(model, optimizer, N, rank): 45 | torch.random.manual_seed(rank) 46 | for i in range(N): 47 | optimizer.zero_grad() 48 | NE = model.weight.shape[0] 49 | inputs = (torch.rand(10) * NE).long() 50 | L = model(inputs).sum() 51 | L.backward() 52 | # print(next(model.parameters()).grad) 53 | optimizer.step() 54 | 55 | 56 | class TestOptimizers(TensorTestCase): 57 | def _stress_optimizer(self, model, optimizer, num_processes=1, iterations=100): 58 | logger.info("_stress_optimizer begin") 59 | processes = [] 60 | for rank in range(num_processes): 61 | p = mp.get_context("spawn").Process( 62 | name=f"Process-{rank}", 63 | target=do_optim, 64 | args=(model, optimizer, iterations, rank), 65 | ) 66 | p.start() 67 | self.addCleanup(p.terminate) 68 | processes.append(p) 69 | 70 | for p in processes: 71 | p.join() 72 | 73 | logger.info("_stress_optimizer complete") 74 | 75 | # def testHogwildStability_Adagrad(self): 76 | # NE = 10000 77 | # model = nn.Embedding(NE, 100) 78 | # optimizer = Adagrad(model.parameters()) 79 | # num_processes = mp.cpu_count() // 2 + 1 80 | # self._stress_optimizer(model, optimizer, num_processes) 81 | 82 | # # This fails for Adagrad because it's not stable 83 | # self.assertLess(model.weight.abs().max(), 1000) 84 | 85 | @unittest.skipIf(os.environ.get("CIRCLECI_TEST") == "1", "Hangs in CircleCI") 86 | def testHogwildStability_AsyncAdagrad(self): 87 | NE = 10000 88 | model = nn.Embedding(NE, 100) 89 | optimizer = AsyncAdagrad(model.parameters()) 90 | num_processes = mp.cpu_count() // 2 + 1 91 | self._stress_optimizer( 92 | model, optimizer, num_processes=num_processes, iterations=50 93 | ) 94 | 95 | 
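        # Plain Adagrad is not stable under HOGWILD updates (see the
        # commented-out test above); AsyncAdagrad should keep the embeddings
        # bounded instead.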
self.assertLess(model.weight.abs().max(), 1000)
96 |
97 | @unittest.skipIf(os.environ.get("CIRCLECI_TEST") == "1", "Hangs in CircleCI")
98 | def testHogwildStability_RowAdagrad(self):
99 | NE = 10000
100 | model = nn.Embedding(NE, 100)
101 | optimizer = RowAdagrad(model.parameters())
102 | num_processes = mp.cpu_count() // 2 + 1
103 | self._stress_optimizer(
104 | model, optimizer, num_processes=num_processes, iterations=50
105 | )
106 |
107 | # This fails for Adagrad because it's not stable
108 | self.assertLess(model.weight.abs().max(), 1000)
109 |
110 | def _assert_testAccuracy_AsyncAdagrad(self, sparse):
111 | # testing that Adagrad = AsyncAdagrad with 1 process
112 | NE = 10000
113 | golden_model = nn.Embedding(NE, 100, sparse=sparse)
114 | test_model = nn.Embedding(NE, 100, sparse=sparse)
115 | test_model.load_state_dict(golden_model.state_dict())
116 |
117 | golden_optimizer = Adagrad(golden_model.parameters())
118 | self._stress_optimizer(golden_model, golden_optimizer, num_processes=1)
119 |
120 | test_optimizer = AsyncAdagrad(test_model.parameters())
121 | self._stress_optimizer(test_model, test_optimizer, num_processes=1)
122 |
123 | # With a single process, AsyncAdagrad should match Adagrad exactly
124 | self.assertTensorEqual(golden_model.weight, test_model.weight)
125 |
126 | def testAccuracy_AsyncAdagrad_sparse_true(self):
127 | self._assert_testAccuracy_AsyncAdagrad(sparse=True)
128 |
129 | def testAccuracy_AsyncAdagrad_sparse_false(self):
130 | self._assert_testAccuracy_AsyncAdagrad(sparse=False)
131 |
--------------------------------------------------------------------------------
/test/test_stats.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE.txt file in the root directory of this source tree.
8 |
9 | from unittest import main, TestCase
10 |
11 | from torchbiggraph.stats import Stats
12 |
13 |
14 | class TestStats(TestCase):
15 | def test_sum(self) -> None:
16 | a = Stats(my_int_metric=1, my_float_metric=0.1, count=1)
17 | b = Stats(my_int_metric=2, my_float_metric=0.0, count=2)
18 | c = Stats(my_int_metric=0, my_float_metric=0.2, count=2)
19 | self.assertEqual(
20 | Stats.sum([a, b, c]),
21 | Stats(my_int_metric=3, my_float_metric=0.30000000000000004, count=5),
22 | )
23 |
24 | def test_average(self) -> None:
25 | total = Stats(my_int_metric=9, my_float_metric=1.2, count=3)
26 | self.assertEqual(
27 | total.average(),
28 | Stats(my_int_metric=3, my_float_metric=0.39999999999999997, count=3),
29 | )
30 |
31 | def test_str(self) -> None:
32 | self.assertEqual(
33 | str(Stats(my_int_metric=1, my_float_metric=0.2, count=3)),
34 | "my_int_metric: 1 , my_float_metric: 0.2 , count: 3",
35 | )
36 |
37 |
38 | if __name__ == "__main__":
39 | main()
40 |
--------------------------------------------------------------------------------
/test/test_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE.txt file in the root directory of this source tree.
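# IterationManager enumerates (epoch, edge path, edge chunk) index triples;
# these tests cover full runs, resumption from a saved iteration index, and
# schedules being changed mid-iteration.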
8 | 9 | from itertools import product 10 | from unittest import main, TestCase 11 | 12 | from torchbiggraph.train_cpu import IterationManager 13 | 14 | 15 | class TestIterationManager(TestCase): 16 | def test_full(self) -> None: 17 | im = IterationManager( 18 | num_epochs=2, edge_paths=["A", "B", "C"], num_edge_chunks=4 19 | ) 20 | self.assertEqual(list(im), list(product(range(2), range(3), range(4)))) 21 | 22 | def test_partial(self) -> None: 23 | im = IterationManager( 24 | num_epochs=2, 25 | edge_paths=["A", "B", "C"], 26 | num_edge_chunks=4, 27 | iteration_idx=(1 * 3 + 2) * 4 + 3, 28 | ) 29 | self.assertEqual(list(im), [(1, 2, 3)]) 30 | 31 | def test_tampering(self) -> None: 32 | im = IterationManager( 33 | num_epochs=2, edge_paths=["A", "B", "C"], num_edge_chunks=4 34 | ) 35 | it = iter(im) 36 | self.assertEqual(next(it), (0, 0, 0)) 37 | im.iteration_idx = (0 * 3 + 1) * 4 + 1 38 | # When calling next it gets incremented. 39 | self.assertEqual(next(it), (0, 1, 2)) 40 | im.edge_paths = ["foo", "bar"] 41 | im.num_edge_chunks = 2 42 | self.assertEqual(next(it), (1, 1, 1)) 43 | im.iteration_idx = 100 44 | with self.assertRaises(StopIteration): 45 | next(it) 46 | 47 | def test_properties(self) -> None: 48 | im = IterationManager( 49 | num_epochs=2, 50 | edge_paths=["A", "B", "C"], 51 | num_edge_chunks=4, 52 | iteration_idx=(0 * 3 + 1) * 4 + 2, 53 | ) 54 | self.assertEqual(im.epoch_idx, 0) 55 | self.assertEqual(im.edge_path_idx, 1) 56 | self.assertEqual(im.edge_path, "B") 57 | self.assertEqual(im.edge_chunk_idx, 2) 58 | self.assertEqual( 59 | im.get_checkpoint_metadata(), 60 | { 61 | "iteration/num_epochs": 2, 62 | "iteration/epoch_idx": 0, 63 | "iteration/num_edge_paths": 3, 64 | "iteration/edge_path_idx": 1, 65 | "iteration/edge_path": "B", 66 | "iteration/num_edge_chunks": 4, 67 | "iteration/edge_chunk_idx": 2, 68 | }, 69 | ) 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /test/test_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | from unittest import main, TestCase 10 | 11 | import torch 12 | from torchbiggraph.util import ( 13 | match_shape, 14 | round_up_to_nearest_multiple, 15 | split_almost_equally, 16 | ) 17 | 18 | 19 | class TestSplitAlmostEqually(TestCase): 20 | def test_exact(self) -> None: 21 | self.assertEqual( 22 | list(split_almost_equally(24, num_parts=4)), 23 | [slice(0, 6), slice(6, 12), slice(12, 18), slice(18, 24)], 24 | ) 25 | 26 | def test_more(self) -> None: 27 | self.assertEqual( 28 | list(split_almost_equally(25, num_parts=4)), 29 | [slice(0, 7), slice(7, 14), slice(14, 21), slice(21, 25)], 30 | ) 31 | 32 | def test_fewer(self) -> None: 33 | self.assertEqual( 34 | list(split_almost_equally(23, num_parts=4)), 35 | [slice(0, 6), slice(6, 12), slice(12, 18), slice(18, 23)], 36 | ) 37 | 38 | def test_so_few_that_last_slice_would_underflow(self) -> None: 39 | # All slices have the same size, which is the ratio size/num_parts 40 | # rounded up. This however may cause earlier slices to get so many 41 | # elements that later ones end up being empty. We need to be careful 42 | # about not returning negative slices in that case. 
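        # Worked example: split_almost_equally(5, num_parts=4) uses a per-slice
        # size of ceil(5 / 4) = 2, so the raw slice ends would be 2, 4, 6, 8;
        # clamping them to the total size 5 yields [0:2], [2:4], [4:5] and the
        # empty slice [5:5].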
43 | self.assertEqual( 44 | list(split_almost_equally(5, num_parts=4)), 45 | [slice(0, 2), slice(2, 4), slice(4, 5), slice(5, 5)], 46 | ) 47 | self.assertEqual( 48 | list(split_almost_equally(6, num_parts=5)), 49 | [slice(0, 2), slice(2, 4), slice(4, 6), slice(6, 6), slice(6, 6)], 50 | ) 51 | 52 | 53 | class TestRoundUpToNearestMultiple(TestCase): 54 | def test_exact(self) -> None: 55 | self.assertEqual(round_up_to_nearest_multiple(24, 4), 24) 56 | 57 | def test_more(self) -> None: 58 | self.assertEqual(round_up_to_nearest_multiple(25, 4), 28) 59 | 60 | def test_fewer(self) -> None: 61 | self.assertEqual(round_up_to_nearest_multiple(23, 4), 24) 62 | 63 | 64 | class TestMatchShape(TestCase): 65 | def test_zero_dimensions(self) -> None: 66 | t = torch.zeros(()) 67 | self.assertIsNone(match_shape(t)) 68 | self.assertIsNone(match_shape(t, ...)) 69 | with self.assertRaises(TypeError): 70 | match_shape(t, 0) 71 | with self.assertRaises(TypeError): 72 | match_shape(t, 1) 73 | with self.assertRaises(TypeError): 74 | match_shape(t, -1) 75 | 76 | def test_one_dimension(self) -> None: 77 | t = torch.zeros((3,)) 78 | self.assertIsNone(match_shape(t, 3)) 79 | self.assertIsNone(match_shape(t, ...)) 80 | self.assertIsNone(match_shape(t, 3, ...)) 81 | self.assertIsNone(match_shape(t, ..., 3)) 82 | self.assertEqual(match_shape(t, -1), 3) 83 | with self.assertRaises(TypeError): 84 | match_shape(t) 85 | with self.assertRaises(TypeError): 86 | match_shape(t, 3, 1) 87 | with self.assertRaises(TypeError): 88 | match_shape(t, 3, ..., 3) 89 | 90 | def test_many_dimension(self) -> None: 91 | t = torch.zeros((3, 4, 5)) 92 | self.assertIsNone(match_shape(t, 3, 4, 5)) 93 | self.assertIsNone(match_shape(t, ...)) 94 | self.assertIsNone(match_shape(t, ..., 5)) 95 | self.assertIsNone(match_shape(t, 3, ..., 5)) 96 | self.assertIsNone(match_shape(t, 3, 4, 5, ...)) 97 | self.assertEqual(match_shape(t, -1, 4, 5), 3) 98 | self.assertEqual(match_shape(t, -1, ...), 3) 99 | self.assertEqual(match_shape(t, -1, 4, ...), 3) 100 | self.assertEqual(match_shape(t, -1, ..., 5), 3) 101 | self.assertEqual(match_shape(t, -1, 4, -1), (3, 5)) 102 | self.assertEqual(match_shape(t, ..., -1, -1), (4, 5)) 103 | self.assertEqual(match_shape(t, -1, -1, -1), (3, 4, 5)) 104 | self.assertEqual(match_shape(t, -1, -1, ..., -1), (3, 4, 5)) 105 | with self.assertRaises(TypeError): 106 | match_shape(t) 107 | with self.assertRaises(TypeError): 108 | match_shape(t, 3) 109 | with self.assertRaises(TypeError): 110 | match_shape(t, 3, 4) 111 | with self.assertRaises(TypeError): 112 | match_shape(t, 5, 4, 3) 113 | with self.assertRaises(TypeError): 114 | match_shape(t, 3, 4, 5, 6) 115 | with self.assertRaises(TypeError): 116 | match_shape(t, 3, 4, ..., 4, 5) 117 | 118 | def test_bad_args(self) -> None: 119 | t = torch.empty((0,)) 120 | with self.assertRaises(RuntimeError): 121 | match_shape(t, ..., ...) 
122 | with self.assertRaises(RuntimeError): 123 | match_shape(t, "foo") 124 | with self.assertRaises(AttributeError): 125 | match_shape(None) 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /torchbiggraph/VERSION.txt: -------------------------------------------------------------------------------- 1 | 1.0.1.dev 2 | -------------------------------------------------------------------------------- /torchbiggraph/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import pkg_resources 4 | 5 | 6 | __version__ = ( 7 | pkg_resources.resource_string("torchbiggraph", "VERSION.txt") 8 | .decode("utf-8") 9 | .strip() 10 | ) 11 | -------------------------------------------------------------------------------- /torchbiggraph/async_adagrad.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import torch 10 | from torch.optim import Adagrad 11 | 12 | 13 | class AsyncAdagrad(Adagrad): 14 | """Variant of Adagrad that is more robust to asynchronous (HOGWILD) updates. 15 | 16 | c.f. torch.optim.Adagrad 17 | """ 18 | 19 | @torch.no_grad() 20 | def step(self, closure=None): 21 | """Performs a single optimization step. 22 | 23 | Arguments: 24 | closure (callable, optional): A closure that reevaluates the model 25 | and returns the loss. 26 | """ 27 | loss = None 28 | if closure is not None: 29 | with torch.enable_grad(): 30 | loss = closure() 31 | 32 | for group in self.param_groups: 33 | for p in group["params"]: 34 | if p.grad is None: 35 | continue 36 | 37 | grad = p.grad 38 | state = self.state[p] 39 | 40 | state["step"] += 1 41 | 42 | if group["weight_decay"] != 0: 43 | if p.grad.is_sparse: 44 | raise RuntimeError( 45 | "weight_decay option is not compatible with sparse gradients" 46 | ) 47 | grad = grad.add(p, alpha=group["weight_decay"]) 48 | 49 | clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"]) 50 | 51 | if grad.is_sparse: 52 | grad = ( 53 | grad.coalesce() 54 | ) # the update is non-linear so indices must be unique 55 | grad_indices = grad._indices() 56 | grad_values = grad._values() 57 | size = grad.size() 58 | 59 | def make_sparse(values): 60 | constructor = grad.new 61 | if grad_indices.dim() == 0 or values.dim() == 0: 62 | return constructor().resize_as_(grad) 63 | return constructor(grad_indices, values, size) 64 | 65 | # multiple HOGWILD processes may perform unsynchronized 66 | # updates to G. Update a local copy of G independently from 67 | # the shared-memory copy, to guarantee that 68 | # local_G >= grad^2 69 | local_G = state["sum"].sparse_mask(grad)._values() 70 | delta_G = grad_values.pow(2) 71 | state["sum"].add_(make_sparse(delta_G)) 72 | local_G += delta_G 73 | std_values = local_G.sqrt_().add_(group["eps"]) 74 | p.add_(make_sparse(grad_values / std_values), alpha=-clr) 75 | else: 76 | # multiple HOGWILD processes may perform unsynchronized 77 | # updates to G. 
Update a local copy of G independently from 78 | # the shared-memory copy, to guarantee that 79 | # local_G >= grad^2 80 | local_G = state["sum"].clone() 81 | delta_G = grad * grad 82 | state["sum"].add_(delta_G) 83 | local_G += delta_G 84 | std = local_G.sqrt().add_(group["eps"]) 85 | p.addcdiv_(grad, std, value=-clr) 86 | 87 | return loss 88 | -------------------------------------------------------------------------------- /torchbiggraph/batching.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import time 10 | from abc import ABC, abstractmethod 11 | from typing import Any, Callable, Dict, Iterable, List, Optional 12 | 13 | import torch 14 | from torchbiggraph.edgelist import EdgeList 15 | from torchbiggraph.losses import AbstractLossFunction 16 | from torchbiggraph.model import MultiRelationEmbedder, override_model, Scores 17 | from torchbiggraph.stats import Stats 18 | from torchbiggraph.types import LongTensorType 19 | 20 | 21 | def group_by_relation_type(edges: EdgeList) -> List[EdgeList]: 22 | """Split the edge list in groups that have the same relation type.""" 23 | if len(edges) == 0: 24 | return [] 25 | if edges.has_scalar_relation_type(): 26 | return [edges] 27 | 28 | # FIXME Is PyTorch's sort stable? Won't this risk messing up the random shuffle? 29 | sorted_rel, order = edges.rel.sort() 30 | delta = sorted_rel[1:] - sorted_rel[:-1] 31 | cutpoints = (delta.nonzero().flatten() + 1).tolist() 32 | 33 | result: List[EdgeList] = [] 34 | for start, end in zip([0] + cutpoints, cutpoints + [len(edges)]): 35 | rel_type = sorted_rel[start] 36 | edges_for_rel_type = edges[order[start:end]] 37 | result.append( 38 | EdgeList( 39 | edges_for_rel_type.lhs, 40 | edges_for_rel_type.rhs, 41 | rel_type, 42 | edges_for_rel_type.weight, 43 | ) 44 | ) 45 | return result 46 | 47 | 48 | def batch_edges_mix_relation_types( 49 | edges: EdgeList, *, batch_size: int 50 | ) -> Iterable[EdgeList]: 51 | """Split the edges in batches that can contain multiple relation types 52 | 53 | The output preserves the input's order. Batches are all of the same size, 54 | except possibly the last one. 55 | """ 56 | for offset in range(0, len(edges), batch_size): 57 | yield edges[offset : offset + batch_size] 58 | 59 | 60 | def batch_edges_group_by_relation_type( 61 | edges: EdgeList, *, batch_size: int 62 | ) -> Iterable[EdgeList]: 63 | """Split the edges in batches that each contain a single relation type 64 | 65 | Batches are all of the same size, except possibly the last one for each 66 | relation type. 
67 | """ 68 | edge_groups = group_by_relation_type(edges) 69 | num_edges_left_per_group = torch.tensor( 70 | [len(edges) for edges in edge_groups], dtype=torch.long 71 | ) 72 | 73 | while num_edges_left_per_group.sum() > 0: 74 | idx = int(torch.multinomial(num_edges_left_per_group.float(), 1)) 75 | edge_group = edge_groups[idx] 76 | offset = len(edge_group) - int(num_edges_left_per_group[idx]) 77 | sub_edges = edge_group[offset : offset + batch_size] 78 | yield sub_edges 79 | num_edges_left_per_group[idx] -= len(sub_edges) 80 | 81 | 82 | def call(f: Callable[[], Stats]) -> Stats: 83 | """Helper to be able to do pool.map(call, [partial(f, foo=42)]) 84 | 85 | Using pool.starmap(f, [(42,)]) is shorter, but it doesn't support keyword 86 | arguments. It appears going through partial is the only way to do that. 87 | """ 88 | return f() 89 | 90 | 91 | def process_in_batches( 92 | batch_size: int, 93 | model: MultiRelationEmbedder, 94 | batch_processor: "AbstractBatchProcessor", 95 | edges: EdgeList, 96 | indices: Optional[LongTensorType] = None, 97 | delay: float = 0.0, 98 | ) -> Stats: 99 | """Split lhs, rhs and rel in batches, process them and sum the stats 100 | 101 | If indices is not None, only operate on x[indices] for x = lhs, rhs and rel. 102 | If delay is positive, wait for that many seconds before starting. 103 | """ 104 | if indices is not None: 105 | edges = edges[indices] 106 | 107 | time.sleep(delay) 108 | 109 | # FIXME: it's not really safe to do partial batches if num_batch_negs != 0 110 | # because partial batches will produce incorrect results, and if the 111 | # dataset per thread is very small then every batch may be partial. I don't 112 | # know of a perfect solution for this that doesn't introduce other biases... 113 | 114 | all_stats = [] 115 | 116 | if model.num_dynamic_rels > 0: 117 | batcher = batch_edges_mix_relation_types 118 | else: 119 | batcher = batch_edges_group_by_relation_type 120 | 121 | for batch_edges in batcher(edges, batch_size=batch_size): 122 | all_stats.append(batch_processor.process_one_batch(model, batch_edges)) 123 | 124 | stats = Stats.sum(all_stats) 125 | return stats 126 | 127 | 128 | class AbstractBatchProcessor(ABC): 129 | def __init__( 130 | self, 131 | loss_fn: AbstractLossFunction, 132 | relation_weights: List[float], 133 | overrides: Optional[Dict[str, Any]] = None, 134 | ): 135 | self.loss_fn = loss_fn 136 | self.relation_weights = relation_weights 137 | self.overrides = overrides 138 | 139 | def calc_loss(self, scores: Scores, batch_edges: EdgeList): 140 | 141 | lhs_loss = self.loss_fn(scores.lhs_pos, scores.lhs_neg, batch_edges.weight) 142 | rhs_loss = self.loss_fn(scores.rhs_pos, scores.rhs_neg, batch_edges.weight) 143 | relation = ( 144 | batch_edges.get_relation_type_as_scalar() 145 | if batch_edges.has_scalar_relation_type() 146 | else 0 147 | ) 148 | loss = self.relation_weights[relation] * (lhs_loss + rhs_loss) 149 | 150 | return loss 151 | 152 | @abstractmethod 153 | def _process_one_batch( 154 | self, model: MultiRelationEmbedder, batch_edges: EdgeList 155 | ) -> Stats: 156 | pass 157 | 158 | def process_one_batch( 159 | self, model: MultiRelationEmbedder, batch_edges: EdgeList 160 | ) -> Stats: 161 | if self.overrides is not None: 162 | with override_model(model, **self.overrides): 163 | return self._process_one_batch(model, batch_edges) 164 | else: 165 | return self._process_one_batch(model, batch_edges) 166 | -------------------------------------------------------------------------------- /torchbiggraph/converters/__init__.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | -------------------------------------------------------------------------------- /torchbiggraph/converters/dictionary.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import math 10 | from typing import Dict, List, Tuple 11 | 12 | 13 | class Dictionary: 14 | def __init__(self, ix_to_word: List[str], *, num_parts: int = 1) -> None: 15 | self.ix_to_word: List[str] = ix_to_word 16 | self.word_to_ix: Dict[str, int] = {w: i for i, w in enumerate(ix_to_word)} 17 | self.num_parts = num_parts 18 | 19 | def get_id(self, word: str) -> int: 20 | return self.word_to_ix[word] 21 | 22 | def size(self) -> int: 23 | return len(self.ix_to_word) 24 | 25 | def get_list(self) -> List[str]: 26 | return self.ix_to_word 27 | 28 | def part_start(self, part: int) -> int: 29 | return math.ceil(part / self.num_parts * self.size()) 30 | 31 | def part_end(self, part: int) -> int: 32 | return self.part_start(part + 1) 33 | 34 | def part_size(self, part: int) -> int: 35 | if not 0 <= part < self.num_parts: 36 | raise ValueError(f"{part} not in [0, {self.num_parts})") 37 | return self.part_end(part) - self.part_start(part) 38 | 39 | def get_partition(self, word: str) -> Tuple[int, int]: 40 | idx = self.get_id(word) 41 | part = math.floor(idx / self.size() * self.num_parts) 42 | assert self.part_start(part) <= idx < self.part_end(part) 43 | return part, idx - self.part_start(part) 44 | 45 | def get_part_list(self, part: int) -> List[str]: 46 | if not 0 <= part < self.num_parts: 47 | raise ValueError(f"{part} not in [0, {self.num_parts})") 48 | return self.ix_to_word[self.part_start(part) : self.part_end(part)] 49 | -------------------------------------------------------------------------------- /torchbiggraph/converters/export_to_tsv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 
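# Command-line converter that reads a trained checkpoint and writes two TSV
# files: one line per entity with its embedding, and one line per relation
# type with the parameters of its operator.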
8 | 9 | import argparse 10 | from typing import Iterable, TextIO 11 | 12 | from torchbiggraph.checkpoint_manager import CheckpointManager 13 | from torchbiggraph.config import ConfigFileLoader, ConfigSchema 14 | from torchbiggraph.graph_storages import ( 15 | AbstractEntityStorage, 16 | AbstractRelationTypeStorage, 17 | ENTITY_STORAGES, 18 | RELATION_TYPE_STORAGES, 19 | ) 20 | from torchbiggraph.model import make_model, MultiRelationEmbedder 21 | 22 | 23 | def write(outf: TextIO, key: Iterable[str], value: Iterable[float]) -> None: 24 | outf.write("%s\t%s\n" % ("\t".join(key), "\t".join("%.9f" % x for x in value))) 25 | 26 | 27 | def make_tsv( 28 | config: ConfigSchema, entities_tf: TextIO, relation_types_tf: TextIO 29 | ) -> None: 30 | print("Loading relation types and entities...") 31 | entity_storage = ENTITY_STORAGES.make_instance(config.entity_path) 32 | relation_type_storage = RELATION_TYPE_STORAGES.make_instance(config.entity_path) 33 | 34 | print("Initializing model...") 35 | model = make_model(config) 36 | 37 | print("Loading model check point...") 38 | checkpoint_manager = CheckpointManager(config.checkpoint_path) 39 | state_dict, _ = checkpoint_manager.read_model() 40 | if state_dict is not None: 41 | model.load_state_dict(state_dict, strict=False) 42 | 43 | make_tsv_for_entities(model, checkpoint_manager, entity_storage, entities_tf) 44 | make_tsv_for_relation_types(model, relation_type_storage, relation_types_tf) 45 | 46 | 47 | def make_tsv_for_entities( 48 | model: MultiRelationEmbedder, 49 | checkpoint_manager: CheckpointManager, 50 | entity_storage: AbstractEntityStorage, 51 | entities_tf: TextIO, 52 | ) -> None: 53 | print("Writing entity embeddings...") 54 | for ent_t_name, ent_t_config in model.entities.items(): 55 | for partition in range(ent_t_config.num_partitions): 56 | print( 57 | f"Reading embeddings for entity type {ent_t_name} partition " 58 | f"{partition} from checkpoint..." 59 | ) 60 | entities = entity_storage.load_names(ent_t_name, partition) 61 | embeddings, _ = checkpoint_manager.read(ent_t_name, partition) 62 | 63 | if model.global_embs is not None: 64 | embeddings += model.global_embs[model.EMB_PREFIX + ent_t_name] 65 | 66 | print( 67 | f"Writing embeddings for entity type {ent_t_name} partition " 68 | f"{partition} to output file..." 
69 | ) 70 | for ix in range(len(embeddings)): 71 | write(entities_tf, (entities[ix],), embeddings[ix]) 72 | if (ix + 1) % 5000 == 0: 73 | print(f"- Processed {ix+1}/{len(embeddings)} entities so far...") 74 | print(f"- Processed all {len(embeddings)} entities") 75 | 76 | entities_output_filename = getattr(entities_tf, "name", "the output file") 77 | print(f"Done exporting entity data to {entities_output_filename}") 78 | 79 | 80 | def make_tsv_for_relation_types( 81 | model: MultiRelationEmbedder, 82 | relation_type_storage: AbstractRelationTypeStorage, 83 | relation_types_tf: TextIO, 84 | ) -> None: 85 | print("Writing relation type parameters...") 86 | relation_types = relation_type_storage.load_names() 87 | if model.num_dynamic_rels > 0: 88 | (rel_t_config,) = model.relations 89 | op_name = rel_t_config.operator 90 | (lhs_operator,) = model.lhs_operators 91 | (rhs_operator,) = model.rhs_operators 92 | for side, operator in [("lhs", lhs_operator), ("rhs", rhs_operator)]: 93 | for param_name, all_params in operator.named_parameters(): 94 | for rel_t_name, param in zip(relation_types, all_params): 95 | shape = "x".join(f"{d}" for d in param.shape) 96 | write( 97 | relation_types_tf, 98 | (rel_t_name, side, op_name, param_name, shape), 99 | param.flatten(), 100 | ) 101 | else: 102 | for rel_t_name, rel_t_config, operator in zip( 103 | relation_types, model.relations, model.rhs_operators 104 | ): 105 | if rel_t_name != rel_t_config.name: 106 | raise ValueError( 107 | f"Mismatch in relations names: got {rel_t_name} in the " 108 | f"dictionary and {rel_t_config.name} in the config." 109 | ) 110 | op_name = rel_t_config.operator 111 | for param_name, param in operator.named_parameters(): 112 | shape = "x".join(f"{d}" for d in param.shape) 113 | write( 114 | relation_types_tf, 115 | (rel_t_name, "rhs", op_name, param_name, shape), 116 | param.flatten(), 117 | ) 118 | 119 | relation_types_output_filename = getattr( 120 | relation_types_tf, "name", "the output file" 121 | ) 122 | print(f"Done exporting relation type data to {relation_types_output_filename}") 123 | 124 | 125 | def main(): 126 | config_help = "\n\nConfig parameters:\n\n" + "\n".join(ConfigSchema.help()) 127 | parser = argparse.ArgumentParser( 128 | epilog=config_help, 129 | # Needed to preserve line wraps in epilog. 130 | formatter_class=argparse.RawDescriptionHelpFormatter, 131 | ) 132 | parser.add_argument("config", help="Path to config file") 133 | parser.add_argument("-p", "--param", action="append", nargs="*") 134 | parser.add_argument("--entities-output", required=True) 135 | parser.add_argument("--relation-types-output", required=True) 136 | opt = parser.parse_args() 137 | 138 | loader = ConfigFileLoader() 139 | config = loader.load_config(opt.config, opt.param) 140 | 141 | with open(opt.entities_output, "xt") as entities_tf, open( 142 | opt.relation_types_output, "xt" 143 | ) as relation_types_tf: 144 | make_tsv(config, entities_tf, relation_types_tf) 145 | 146 | 147 | if __name__ == "__main__": 148 | main() 149 | -------------------------------------------------------------------------------- /torchbiggraph/converters/import_from_parquet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 
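# Command-line converter that imports edge lists from Parquet files into PBG's
# partitioned input format; unlike the TSV importer, the column arguments are
# strings rather than integer positions.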
8 | 9 | import argparse 10 | from pathlib import Path 11 | 12 | from torchbiggraph.config import ConfigFileLoader, ConfigSchema 13 | from torchbiggraph.converters.importers import ( 14 | convert_input_data, 15 | ParquetEdgelistReader, 16 | parse_config_partial, 17 | ) 18 | 19 | 20 | def main(): 21 | config_help = "\n\nConfig parameters:\n\n" + "\n".join(ConfigSchema.help()) 22 | parser = argparse.ArgumentParser( 23 | epilog=config_help, 24 | # Needed to preserve line wraps in epilog. 25 | formatter_class=argparse.RawDescriptionHelpFormatter, 26 | ) 27 | parser.add_argument("config", help="Path to config file") 28 | parser.add_argument("-p", "--param", action="append", nargs="*") 29 | parser.add_argument("edge_paths", type=Path, nargs="*", help="Input file paths") 30 | parser.add_argument( 31 | "-l", 32 | "--lhs-col", 33 | type=str, 34 | required=True, 35 | help="Column name for source entity", 36 | ) 37 | parser.add_argument( 38 | "-r", 39 | "--rhs-col", 40 | type=str, 41 | required=True, 42 | help="Column name for target entity", 43 | ) 44 | parser.add_argument("--rel-col", type=str, help="Column name for relation type") 45 | parser.add_argument( 46 | "--weight-col", type=int, help="(Optional) Column index for edge weight" 47 | ) 48 | parser.add_argument( 49 | "--relation-type-min-count", 50 | type=int, 51 | default=1, 52 | help="Min count for relation types", 53 | ) 54 | parser.add_argument( 55 | "--entity-min-count", type=int, default=1, help="Min count for entities" 56 | ) 57 | opt = parser.parse_args() 58 | 59 | loader = ConfigFileLoader() 60 | config_dict = loader.load_raw_config(opt.config, opt.param) 61 | 62 | ( 63 | entity_configs, 64 | relation_configs, 65 | entity_path, 66 | edge_paths, 67 | dynamic_relations, 68 | ) = parse_config_partial( # noqa 69 | config_dict 70 | ) 71 | 72 | convert_input_data( 73 | entity_configs, 74 | relation_configs, 75 | entity_path, 76 | edge_paths, 77 | opt.edge_paths, 78 | ParquetEdgelistReader(opt.lhs_col, opt.rhs_col, opt.rel_col, opt.weight_col), 79 | opt.entity_min_count, 80 | opt.relation_type_min_count, 81 | dynamic_relations, 82 | ) 83 | 84 | 85 | if __name__ == "__main__": 86 | main() 87 | -------------------------------------------------------------------------------- /torchbiggraph/converters/import_from_tsv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import argparse 10 | from pathlib import Path 11 | 12 | from torchbiggraph.config import ConfigFileLoader, ConfigSchema 13 | from torchbiggraph.converters.importers import ( 14 | convert_input_data, 15 | parse_config_partial, 16 | TSVEdgelistReader, 17 | ) 18 | 19 | 20 | def main(): 21 | config_help = "\n\nConfig parameters:\n\n" + "\n".join(ConfigSchema.help()) 22 | parser = argparse.ArgumentParser( 23 | epilog=config_help, 24 | # Needed to preserve line wraps in epilog.
25 | formatter_class=argparse.RawDescriptionHelpFormatter, 26 | ) 27 | parser.add_argument("config", help="Path to config file") 28 | parser.add_argument("-p", "--param", action="append", nargs="*") 29 | parser.add_argument("edge_paths", type=Path, nargs="*", help="Input file paths") 30 | parser.add_argument( 31 | "-l", 32 | "--lhs-col", 33 | type=int, 34 | required=True, 35 | help="Column index for source entity", 36 | ) 37 | parser.add_argument( 38 | "-r", 39 | "--rhs-col", 40 | type=int, 41 | required=True, 42 | help="Column index for target entity", 43 | ) 44 | parser.add_argument("--rel-col", type=int, help="Column index for relation type") 45 | parser.add_argument( 46 | "--weight-col", type=int, help="(Optional) Column index for edge weight" 47 | ) 48 | parser.add_argument( 49 | "--relation-type-min-count", 50 | type=int, 51 | default=1, 52 | help="Min count for relation types", 53 | ) 54 | parser.add_argument( 55 | "--entity-min-count", type=int, default=1, help="Min count for entities" 56 | ) 57 | opt = parser.parse_args() 58 | 59 | loader = ConfigFileLoader() 60 | config_dict = loader.load_raw_config(opt.config, opt.param) 61 | 62 | ( 63 | entity_configs, 64 | relation_configs, 65 | entity_path, 66 | edge_paths, 67 | dynamic_relations, 68 | ) = parse_config_partial( # noqa 69 | config_dict 70 | ) 71 | 72 | convert_input_data( 73 | entity_configs, 74 | relation_configs, 75 | entity_path, 76 | edge_paths, 77 | opt.edge_paths, 78 | TSVEdgelistReader(opt.lhs_col, opt.rhs_col, opt.rel_col, opt.weight_col), 79 | opt.entity_min_count, 80 | opt.relation_type_min_count, 81 | dynamic_relations, 82 | ) 83 | 84 | 85 | if __name__ == "__main__": 86 | main() 87 | -------------------------------------------------------------------------------- /torchbiggraph/converters/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import gzip 10 | import shutil 11 | import tarfile 12 | from pathlib import Path 13 | from typing import Callable, Optional 14 | from urllib.parse import urlparse 15 | from urllib.request import urlretrieve 16 | 17 | from tqdm import tqdm 18 | 19 | 20 | def extract_gzip(gzip_path: Path, remove_finished: bool = False) -> Path: 21 | print(f"Extracting {gzip_path}") 22 | if gzip_path.suffix != ".gz": 23 | raise RuntimeError("Not a gzipped file") 24 | fpath = gzip_path.with_suffix("") 25 | 26 | if fpath.exists(): 27 | print( 28 | "Found a file that indicates that the input data " 29 | "has already been extracted, not doing it again."
30 | ) 31 | print(f"This file is: {fpath}") 32 | return fpath 33 | 34 | with fpath.open("wb") as out_bf, gzip.GzipFile(gzip_path) as zip_f: 35 | shutil.copyfileobj(zip_f, out_bf) 36 | if remove_finished: 37 | gzip_path.unlink() 38 | 39 | return fpath 40 | 41 | 42 | def extract_tar(fpath: Path) -> None: 43 | # extract file 44 | with tarfile.open(fpath, "r:gz") as tar: 45 | tar.extractall(path=fpath.parent) 46 | 47 | 48 | def gen_bar_updater(pbar: tqdm) -> Callable[[int, int, int], None]: 49 | def bar_update(count: int, block_size: int, total_size: int) -> None: 50 | if pbar.total is None and total_size: 51 | pbar.total = total_size 52 | progress_bytes = count * block_size 53 | pbar.update(progress_bytes - pbar.n) 54 | 55 | return bar_update 56 | 57 | 58 | def download_url(url: str, root: Path, filename: Optional[str] = None) -> Path: 59 | """Download a file from a url and place it in root. 60 | Args: 61 | url (str): URL to download file from 62 | root (Path): Directory to place downloaded file in 63 | filename (str, optional): Name to save the file under. 64 | If None, use the basename of the URL 65 | """ 66 | 67 | root = root.expanduser() 68 | if filename is None: 69 | filename = Path(urlparse(url).path).name 70 | fpath = root / filename 71 | if not root.exists(): 72 | root.mkdir(parents=True, exist_ok=True) 73 | 74 | # download the file only if it is not already present 75 | if fpath.is_file(): 76 | print(f"Using downloaded and verified file: {fpath}") 77 | else: 78 | try: 79 | print(f"Downloading {url} to {fpath}") 80 | urlretrieve( 81 | url, 82 | str(fpath), 83 | reporthook=gen_bar_updater(tqdm(unit="B", unit_scale=True)), 84 | ) 85 | except OSError: 86 | print(f"Failed to download from url: {url}") 87 | 88 | return fpath 89 | -------------------------------------------------------------------------------- /torchbiggraph/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import logging 10 | import multiprocessing as mp 11 | import os 12 | from abc import ABC, abstractmethod 13 | from datetime import timedelta 14 | from typing import Callable, List, NamedTuple, Optional 15 | 16 | import torch.distributed as td 17 | import torch.multiprocessing # noqa monkeypatches 18 | from torchbiggraph.types import Rank 19 | from torchbiggraph.util import tag_logs_with_process_name 20 | 21 | 22 | logger = logging.getLogger("torchbiggraph") 23 | 24 | 25 | class ProcessRanks(NamedTuple): 26 | """Assign a unique ordinal rank to each process for distributed training. 27 | 28 | torch.distributed requires that N communicating processes register 29 | themselves with globally unique ranks [0, ..., N-1]. Distributed training 30 | launches several communicating subprocesses on each machine. This class 31 | manages the assignment from processes/subprocesses to ranks.
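    For example (a worked illustration of from_num_invocations below), with
    num_machines=2 and num_partition_servers=-1 (machines double as partition
    servers), trainers get ranks [0, 1], parameter servers [2, 3], parameter
    clients [4, 5], the lock server rank 6 and partition servers [7, 8], for
    a world size of 9.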
32 | """ 33 | 34 | world_size: int 35 | trainers: List[Rank] 36 | parameter_servers: List[Rank] 37 | parameter_clients: List[Rank] 38 | lock_server: Rank 39 | partition_servers: List[Rank] 40 | 41 | @classmethod 42 | def from_num_invocations( 43 | cls, num_machines: int, num_partition_servers: int 44 | ) -> "ProcessRanks": 45 | world_size = 0 46 | 47 | def add_group(group_size: int) -> List[Rank]: 48 | nonlocal world_size 49 | group = [world_size + r for r in range(group_size)] 50 | world_size += group_size 51 | return group 52 | 53 | trainers = add_group(num_machines) 54 | parameter_servers = add_group(num_machines) 55 | parameter_clients = add_group(num_machines) 56 | (lock_server,) = add_group(1) 57 | if num_partition_servers < 0: 58 | # Use machines as partition servers 59 | partition_servers = add_group(num_machines) 60 | else: 61 | partition_servers = add_group(num_partition_servers) 62 | 63 | return cls( 64 | world_size, 65 | trainers, 66 | parameter_servers, 67 | parameter_clients, 68 | lock_server, 69 | partition_servers, 70 | ) 71 | 72 | 73 | def init_process_group( 74 | init_method: Optional[str], 75 | world_size: int, 76 | rank: Rank, 77 | groups: List[List[Rank]], 78 | backend: str = "gloo", 79 | ) -> List["td.ProcessGroup"]: 80 | # With the THD backend there were no timeouts so high variance in 81 | # execution time between trainers was not a problem. With the new c10d 82 | # implementation we do have to take timeouts into account. To simulate 83 | # the old behavior we use a ridiculously high default timeout. 84 | timeout = timedelta(days=365) 85 | logger.info("init_process_group start") 86 | 87 | # Adding code block below to route ftw region traffic to new 88 | # ensembles - T132536412; Easy to extend to other regions and 89 | # handle more traffics by adding more ensembles 90 | try: 91 | # fetch run time trainer cluster which contains region information 92 | runtime_cluster = os.environ["BUMBLEBEE_CLUSTER"] 93 | cluster_region = runtime_cluster.split("-")[1] 94 | filamentZeusMap = {"ftw": ("zelos.8fc7", "zelos.5f14")} 95 | if cluster_region in filamentZeusMap: 96 | logger.info(f"Run time cluster region is: {cluster_region}") 97 | regionList = filamentZeusMap[cluster_region] 98 | parent_flow_id = init_method.split("f")[-1] 99 | zeus_endpoint = regionList[int(parent_flow_id) % len(regionList)] 100 | init_method = f"elasticzeus://{zeus_endpoint}/f{parent_flow_id}" 101 | logger.info("The updated init_method: {}".format(init_method)) 102 | else: 103 | logger.info("Run time cluster region not in Filament ensemble map") 104 | 105 | except Exception as e: 106 | logger.info(f"List BUMBLEBEE_CLUSTER test FAILED due to: {e}") 107 | 108 | if init_method is None: 109 | raise RuntimeError("distributed_init_method must be set when num_machines > 1") 110 | td.init_process_group( 111 | backend, 112 | init_method=init_method, 113 | world_size=world_size, 114 | rank=rank, 115 | timeout=timeout, 116 | ) 117 | logger.info("init_process_group creating groups") 118 | group_objs = [] 119 | for group in groups: 120 | group_objs.append(td.new_group(group, timeout=timeout)) 121 | logger.info("init_process_group done") 122 | return group_objs 123 | 124 | 125 | class Startable(ABC): 126 | @abstractmethod 127 | def start(self, groups: List["td.ProcessGroup"]) -> None: 128 | pass 129 | 130 | 131 | def _server_init( 132 | server: Startable, 133 | process_name: str, 134 | init_method: Optional[str], 135 | world_size: int, 136 | server_rank: Rank, 137 | groups: List[List[Rank]], 138 | subprocess_init: 
Optional[Callable[[], None]] = None, 139 | ) -> None: 140 | tag_logs_with_process_name(process_name) 141 | if subprocess_init is not None: 142 | subprocess_init() 143 | groups = init_process_group( 144 | init_method=init_method, world_size=world_size, rank=server_rank, groups=groups 145 | ) 146 | server.start(groups) 147 | 148 | 149 | def start_server( 150 | server: Startable, 151 | process_name: str, 152 | init_method: Optional[str], 153 | world_size: int, 154 | server_rank: Rank, 155 | groups: List[List[Rank]], 156 | subprocess_init: Optional[Callable[[], None]] = None, 157 | ) -> mp.Process: 158 | p = mp.get_context("spawn").Process( 159 | name=process_name, 160 | target=_server_init, 161 | args=( 162 | server, 163 | process_name, 164 | init_method, 165 | world_size, 166 | server_rank, 167 | groups, 168 | subprocess_init, 169 | ), 170 | ) 171 | p.daemon = True 172 | p.start() 173 | return p 174 | -------------------------------------------------------------------------------- /torchbiggraph/entitylist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | from typing import Any, Sequence, Union 10 | 11 | import torch 12 | from torchbiggraph.tensorlist import TensorList 13 | from torchbiggraph.types import LongTensorType 14 | 15 | 16 | class EntityList: 17 | """A wrapper around an id-based entity list and a featurized entity list. 18 | 19 | self.tensor is the id-based entity list, 20 | self.tensor_list is the featurized entity list. 21 | 22 | This class maintains the indexing and slicing of these two parallel 23 | representations.
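    For example (illustrative values only),
    EntityList.from_tensor(torch.tensor([3, 1, 2])) wraps three id-based
    entities together with an empty featurized list, and a slice such as
    el[0:2] advances both parallel representations in lockstep.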
24 | """ 25 | 26 | @classmethod 27 | def empty(cls) -> "EntityList": 28 | return cls(torch.empty((0,), dtype=torch.long), TensorList.empty()) 29 | 30 | @classmethod 31 | def from_tensor(cls, tensor: LongTensorType) -> "EntityList": 32 | if tensor.dim() != 1: 33 | raise ValueError("Expected 1D tensor, got %dD" % tensor.dim()) 34 | tensor_list = TensorList.empty(num_tensors=tensor.shape[0]) 35 | return cls(tensor, tensor_list) 36 | 37 | @classmethod 38 | def from_tensor_list(cls, tensor_list: TensorList) -> "EntityList": 39 | tensor = torch.full((len(tensor_list),), -1, dtype=torch.long) 40 | return cls(tensor, tensor_list) 41 | 42 | @classmethod 43 | def cat(cls, entity_lists: Sequence["EntityList"]) -> "EntityList": 44 | return cls( 45 | torch.cat([el.tensor for el in entity_lists]), 46 | TensorList.cat(el.tensor_list for el in entity_lists), 47 | ) 48 | 49 | def __init__(self, tensor: LongTensorType, tensor_list: TensorList) -> None: 50 | if not isinstance(tensor, (torch.LongTensor, torch.cuda.LongTensor)): 51 | raise TypeError( 52 | "Expected long tensor as first argument, got %s" % type(tensor) 53 | ) 54 | if not isinstance(tensor_list, TensorList): 55 | raise TypeError( 56 | "Expected tensor list as second argument, got %s" % type(tensor_list) 57 | ) 58 | if tensor.dim() != 1: 59 | raise ValueError( 60 | "Expected 1-dimensional tensor, got %d-dimensional one" % tensor.dim() 61 | ) 62 | if tensor.shape[0] != len(tensor_list): 63 | raise ValueError( 64 | "The tensor and tensor list have different lengths: %d != %d" 65 | % (tensor.shape[0], len(tensor_list)) 66 | ) 67 | # TODO We could check that, for all i, we have either tensor[i] < 0 or 68 | # tensor_list[i] empty, however it's expensive and we're already doing 69 | # something similar at retrieval inside to_tensor(_list). 
70 | self.tensor: LongTensorType = tensor 71 | self.tensor_list: TensorList = tensor_list 72 | 73 | def to_tensor(self) -> LongTensorType: 74 | if len(self.tensor_list.data) != 0: 75 | raise RuntimeError( 76 | "Getting the tensor data of an EntityList " 77 | "that also has tensor list data" 78 | ) 79 | return self.tensor 80 | 81 | def to_tensor_list(self) -> TensorList: 82 | if not self.tensor.eq(-1).all(): 83 | raise RuntimeError( 84 | "Getting the tensor list data of an EntityList " 85 | "that also has tensor data" 86 | ) 87 | return self.tensor_list 88 | 89 | def __eq__(self, other: Any) -> bool: 90 | if not isinstance(other, EntityList): 91 | return NotImplemented 92 | return ( 93 | torch.equal(self.tensor, other.tensor) 94 | and torch.equal(self.tensor_list.offsets, other.tensor_list.offsets) 95 | and torch.equal(self.tensor_list.data, other.tensor_list.data) 96 | ) 97 | 98 | def __str__(self) -> str: 99 | return repr(self) 100 | 101 | def __repr__(self) -> str: 102 | return "EntityList(%r, TensorList(%r, %r))" % ( 103 | self.tensor, 104 | self.tensor_list.offsets, 105 | self.tensor_list.data, 106 | ) 107 | 108 | def __getitem__(self, index: Union[int, slice, LongTensorType]) -> "EntityList": 109 | if isinstance(index, int): 110 | return self[index : index + 1] 111 | 112 | if isinstance( 113 | index, (torch.LongTensor, torch.cuda.LongTensor) 114 | ): 115 | tensor_sub = self.tensor[index] 116 | tensor_list_sub = self.tensor_list[index] 117 | return type(self)(tensor_sub, tensor_list_sub) 118 | 119 | if isinstance(index, slice): 120 | start, stop, step = index.indices(len(self)) 121 | if step != 1: 122 | raise ValueError("Expected slice with step 1, got %d" % step) 123 | tensor_sub = self.tensor[start:stop] 124 | tensor_list_sub = self.tensor_list[start:stop] 125 | return type(self)(tensor_sub, tensor_list_sub) 126 | 127 | raise KeyError("Unknown index type: %s" % type(index)) 128 | 129 | def __len__(self) -> int: 130 | return self.tensor.shape[0] 131 | 132 | def to(self, *args, **kwargs) -> "EntityList": 133 | return type(self)( 134 | self.tensor.to(*args, **kwargs), self.tensor_list.to(*args, **kwargs) 135 | ) 136 | -------------------------------------------------------------------------------- /torchbiggraph/examples/LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For torchbiggraph software 4 | 5 | Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | Additional Disclaimer: This software may be used to download and use third-party 33 | data. The copyright holder has not verified that any such third-party data is 34 | free to download or use. It is the responsibility of the software user, and not 35 | the copyright holder, to comply with any and all legal requirements and 36 | third-party rights related to such third-party data, including without 37 | limitation that the software user has the rights to download and use such 38 | third-party data. The copyright holder makes no warranties that this software 39 | complies with any legal requirements or third-party rights related to such 40 | third-party data. 41 | -------------------------------------------------------------------------------- /torchbiggraph/examples/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | -------------------------------------------------------------------------------- /torchbiggraph/examples/configs/fb15k_config_cpu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | 10 | def get_torchbiggraph_config(): 11 | 12 | config = dict( # noqa 13 | # I/O data 14 | entity_path="data/FB15k", 15 | edge_paths=[ 16 | "data/FB15k/freebase_mtr100_mte100-train_partitioned", 17 | "data/FB15k/freebase_mtr100_mte100-valid_partitioned", 18 | "data/FB15k/freebase_mtr100_mte100-test_partitioned", 19 | ], 20 | checkpoint_path="model/fb15k", 21 | # Graph structure 22 | entities={"all": {"num_partitions": 1}}, 23 | relations=[ 24 | { 25 | "name": "all_edges", 26 | "lhs": "all", 27 | "rhs": "all", 28 | "operator": "complex_diagonal", 29 | } 30 | ], 31 | dynamic_relations=True, 32 | # Scoring model 33 | dimension=400, 34 | global_emb=False, 35 | comparator="dot", 36 | # Training 37 | num_epochs=50, 38 | num_uniform_negs=1000, 39 | loss_fn="softmax", 40 | lr=0.1, 41 | regularization_coef=1e-3, 42 | # Evaluation during training 43 | eval_fraction=0, # to reproduce results, we need to use all training data 44 | ) 45 | 46 | return config 47 | -------------------------------------------------------------------------------- /torchbiggraph/examples/configs/fb15k_config_gpu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 
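# Note: this GPU config matches the CPU config above, except that it also sets
# an explicit batch_size=5000 and num_gpus=1 at the end.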
8 | 9 | 10 | def get_torchbiggraph_config(): 11 | 12 | config = dict( # noqa 13 | # I/O data 14 | entity_path="data/FB15k", 15 | edge_paths=[ 16 | "data/FB15k/freebase_mtr100_mte100-train_partitioned", 17 | "data/FB15k/freebase_mtr100_mte100-valid_partitioned", 18 | "data/FB15k/freebase_mtr100_mte100-test_partitioned", 19 | ], 20 | checkpoint_path="model/fb15k", 21 | # Graph structure 22 | entities={"all": {"num_partitions": 1}}, 23 | relations=[ 24 | { 25 | "name": "all_edges", 26 | "lhs": "all", 27 | "rhs": "all", 28 | "operator": "complex_diagonal", 29 | } 30 | ], 31 | dynamic_relations=True, 32 | # Scoring model 33 | dimension=400, 34 | global_emb=False, 35 | comparator="dot", 36 | # Training 37 | num_epochs=50, 38 | batch_size=5000, 39 | num_uniform_negs=1000, 40 | loss_fn="softmax", 41 | lr=0.1, 42 | regularization_coef=1e-3, 43 | # Evaluation during training 44 | eval_fraction=0, 45 | # GPU 46 | num_gpus=1, 47 | ) 48 | 49 | return config 50 | -------------------------------------------------------------------------------- /torchbiggraph/examples/configs/livejournal_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | 10 | def get_torchbiggraph_config(): 11 | 12 | config = dict( # noqa 13 | # I/O data 14 | entity_path="data/livejournal", 15 | edge_paths=["data/train_partitioned", "data/test_partitioned"], 16 | checkpoint_path="model/livejournal", 17 | # Graph structure 18 | entities={"user_id": {"num_partitions": 1}}, 19 | relations=[ 20 | {"name": "follow", "lhs": "user_id", "rhs": "user_id", "operator": "none"} 21 | ], 22 | # Scoring model 23 | dimension=1024, 24 | global_emb=False, 25 | # Training 26 | num_epochs=30, 27 | lr=0.001, 28 | # Misc 29 | hogwild_delay=2, 30 | ) 31 | 32 | return config 33 | -------------------------------------------------------------------------------- /torchbiggraph/examples/fb15k.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import argparse 10 | from pathlib import Path 11 | 12 | import attr 13 | import pkg_resources 14 | import torch 15 | from torchbiggraph.config import add_to_sys_path, ConfigFileLoader 16 | from torchbiggraph.converters.importers import convert_input_data, TSVEdgelistReader 17 | from torchbiggraph.converters.utils import download_url, extract_tar 18 | from torchbiggraph.eval import do_eval 19 | from torchbiggraph.filtered_eval import FilteredRankingEvaluator 20 | from torchbiggraph.train import train 21 | from torchbiggraph.util import ( 22 | set_logging_verbosity, 23 | setup_logging, 24 | SubprocessInitializer, 25 | ) 26 | 27 | 28 | FB15K_URL = "https://dl.fbaipublicfiles.com/starspace/fb15k.tgz" 29 | FILENAMES = [ 30 | "FB15k/freebase_mtr100_mte100-train.txt", 31 | "FB15k/freebase_mtr100_mte100-valid.txt", 32 | "FB15k/freebase_mtr100_mte100-test.txt", 33 | ] 34 | 35 | # Figure out the path where the sample config was installed by the package manager. 36 | # This can be overridden with --config. 
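# A typical invocation (all flags are optional); the file defines a main()
# entry point, so it can be run directly as a module:
#
#     python -m torchbiggraph.examples.fb15k --data_dir data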
37 | USE_CUDA = torch.cuda.is_available() 38 | if USE_CUDA: 39 | DEFAULT_CONFIG = pkg_resources.resource_filename( 40 | "torchbiggraph.examples", "configs/fb15k_config_gpu.py" 41 | ) 42 | else: 43 | DEFAULT_CONFIG = pkg_resources.resource_filename( 44 | "torchbiggraph.examples", "configs/fb15k_config_cpu.py" 45 | ) 46 | 47 | 48 | def main(): 49 | setup_logging() 50 | parser = argparse.ArgumentParser(description="Example on FB15k") 51 | parser.add_argument("--config", default=DEFAULT_CONFIG, help="Path to config file") 52 | parser.add_argument("-p", "--param", action="append", nargs="*") 53 | parser.add_argument( 54 | "--data_dir", type=Path, default="data", help="where to save processed data" 55 | ) 56 | parser.add_argument( 57 | "--no-filtered", 58 | dest="filtered", 59 | action="store_false", 60 | help="Run unfiltered eval", 61 | ) 62 | args = parser.parse_args() 63 | 64 | # download data 65 | data_dir = args.data_dir 66 | fpath = download_url(FB15K_URL, data_dir) 67 | extract_tar(fpath) 68 | print("Downloaded and extracted file.") 69 | 70 | loader = ConfigFileLoader() 71 | config = loader.load_config(args.config, args.param) 72 | set_logging_verbosity(config.verbose) 73 | subprocess_init = SubprocessInitializer() 74 | subprocess_init.register(setup_logging, config.verbose) 75 | subprocess_init.register(add_to_sys_path, loader.config_dir.name) 76 | input_edge_paths = [data_dir / name for name in FILENAMES] 77 | output_train_path, output_valid_path, output_test_path = config.edge_paths 78 | 79 | convert_input_data( 80 | config.entities, 81 | config.relations, 82 | config.entity_path, 83 | config.edge_paths, 84 | input_edge_paths, 85 | TSVEdgelistReader(lhs_col=0, rhs_col=2, rel_col=1), 86 | dynamic_relations=config.dynamic_relations, 87 | ) 88 | 89 | train_config = attr.evolve(config, edge_paths=[output_train_path]) 90 | train(train_config, subprocess_init=subprocess_init) 91 | 92 | relations = [attr.evolve(r, all_negs=True) for r in config.relations] 93 | eval_config = attr.evolve( 94 | config, edge_paths=[output_test_path], relations=relations, num_uniform_negs=0 95 | ) 96 | if args.filtered: 97 | filter_paths = [output_test_path, output_valid_path, output_train_path] 98 | do_eval( 99 | eval_config, 100 | evaluator=FilteredRankingEvaluator(eval_config, filter_paths), 101 | subprocess_init=subprocess_init, 102 | ) 103 | else: 104 | do_eval(eval_config, subprocess_init=subprocess_init) 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /torchbiggraph/examples/livejournal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 
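# As with the FB15k example above, this script can be run directly as a
# module, e.g.:
#
#     python -m torchbiggraph.examples.livejournal --data_dir data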
8 | 9 | import argparse 10 | import random 11 | from pathlib import Path 12 | 13 | import attr 14 | import pkg_resources 15 | from torchbiggraph.config import add_to_sys_path, ConfigFileLoader 16 | from torchbiggraph.converters.importers import convert_input_data, TSVEdgelistReader 17 | from torchbiggraph.converters.utils import download_url, extract_gzip 18 | from torchbiggraph.eval import do_eval 19 | from torchbiggraph.train import train 20 | from torchbiggraph.util import ( 21 | set_logging_verbosity, 22 | setup_logging, 23 | SubprocessInitializer, 24 | ) 25 | 26 | 27 | URL = "https://snap.stanford.edu/data/soc-LiveJournal1.txt.gz" 28 | TRAIN_FILENAME = "train.txt" 29 | TEST_FILENAME = "test.txt" 30 | FILENAMES = [TRAIN_FILENAME, TEST_FILENAME] 31 | TRAIN_FRACTION = 0.75 32 | 33 | # Figure out the path where the sample config was installed by the package manager. 34 | # This can be overridden with --config. 35 | DEFAULT_CONFIG = pkg_resources.resource_filename( 36 | "torchbiggraph.examples", "configs/livejournal_config.py" 37 | ) 38 | 39 | 40 | def random_split_file(fpath: Path) -> None: 41 | train_file = fpath.parent / TRAIN_FILENAME 42 | test_file = fpath.parent / TEST_FILENAME 43 | 44 | if train_file.exists() and test_file.exists(): 45 | print( 46 | "Found some files that indicate that the input data " 47 | "has already been shuffled and split, not doing it again." 48 | ) 49 | print(f"These files are: {train_file} and {test_file}") 50 | return 51 | 52 | print("Shuffling and splitting train/test file. This may take a while.") 53 | 54 | print(f"Reading data from file: {fpath}") 55 | with fpath.open("rt") as in_tf: 56 | lines = in_tf.readlines() 57 | 58 | # The first few lines are comments 59 | lines = lines[4:] 60 | print("Shuffling data") 61 | random.shuffle(lines) 62 | split_len = int(len(lines) * TRAIN_FRACTION) 63 | 64 | print("Splitting to train and test files") 65 | with train_file.open("wt") as out_tf_train: 66 | for line in lines[:split_len]: 67 | out_tf_train.write(line) 68 | 69 | with test_file.open("wt") as out_tf_test: 70 | for line in lines[split_len:]: 71 | out_tf_test.write(line) 72 | 73 | 74 | def main(): 75 | setup_logging() 76 | parser = argparse.ArgumentParser(description="Example on Livejournal") 77 | parser.add_argument("--config", default=DEFAULT_CONFIG, help="Path to config file") 78 | parser.add_argument("-p", "--param", action="append", nargs="*") 79 | parser.add_argument( 80 | "--data_dir", type=Path, default="data", help="where to save processed data" 81 | ) 82 | 83 | args = parser.parse_args() 84 | 85 | # download data 86 | data_dir = args.data_dir 87 | data_dir.mkdir(parents=True, exist_ok=True) 88 | fpath = download_url(URL, data_dir) 89 | fpath = extract_gzip(fpath) 90 | print("Downloaded and extracted file.") 91 | 92 | # random split file for train and test 93 | random_split_file(fpath) 94 | 95 | loader = ConfigFileLoader() 96 | config = loader.load_config(args.config, args.param) 97 | set_logging_verbosity(config.verbose) 98 | subprocess_init = SubprocessInitializer() 99 | subprocess_init.register(setup_logging, config.verbose) 100 | subprocess_init.register(add_to_sys_path, loader.config_dir.name) 101 | input_edge_paths = [data_dir / name for name in FILENAMES] 102 | output_train_path, output_test_path = config.edge_paths 103 | 104 | convert_input_data( 105 | config.entities, 106 | config.relations, 107 | config.entity_path, 108 | config.edge_paths, 109 | input_edge_paths, 110 | TSVEdgelistReader(lhs_col=0, rhs_col=1, rel_col=None), 111 | 
dynamic_relations=config.dynamic_relations, 112 | ) 113 | 114 | train_config = attr.evolve(config, edge_paths=[output_train_path]) 115 | train(train_config, subprocess_init=subprocess_init) 116 | 117 | eval_config = attr.evolve(config, edge_paths=[output_test_path]) 118 | do_eval(eval_config, subprocess_init=subprocess_init) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /torchbiggraph/filtered_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import logging 10 | from collections import defaultdict 11 | from typing import Dict, List, Tuple 12 | 13 | from torchbiggraph.config import ConfigSchema 14 | from torchbiggraph.edgelist import EdgeList 15 | from torchbiggraph.eval import RankingEvaluator 16 | from torchbiggraph.graph_storages import EDGE_STORAGES 17 | from torchbiggraph.losses import LOSS_FUNCTIONS 18 | from torchbiggraph.model import Scores 19 | from torchbiggraph.types import UNPARTITIONED 20 | 21 | 22 | logger = logging.getLogger("torchbiggraph") 23 | 24 | 25 | class FilteredRankingEvaluator(RankingEvaluator): 26 | """ 27 | This Evaluator is meant for datasets such as FB15K, FB15K-237, WN18, WN18RR 28 | used in knowledge base completion. We only support one non featurized, 29 | non-partitioned entity type and evaluation with all negatives to be 30 | comparable to standard benchmarks. 31 | """ 32 | 33 | def __init__(self, config: ConfigSchema, filter_paths: List[str]) -> None: 34 | loss_fn = LOSS_FUNCTIONS.get_class(config.loss_fn)(margin=config.margin) 35 | relation_weights = [r.weight for r in config.relations] 36 | super().__init__(loss_fn, relation_weights) 37 | 38 | if len(config.relations) != 1 or len(config.entities) != 1: 39 | raise RuntimeError( 40 | "Filtered ranking evaluation should only be used " 41 | "with dynamic relations and one entity type." 42 | ) 43 | if not config.relations[0].all_negs: 44 | raise RuntimeError("Filtered Eval can only be done with all negatives.") 45 | (entity,) = config.entities.values() 46 | if entity.featurized: 47 | raise RuntimeError("Entity cannot be featurized for filtered eval.") 48 | if entity.num_partitions > 1: 49 | raise RuntimeError("Entity cannot be partitioned for filtered eval.") 50 | 51 | self.lhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list) 52 | self.rhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list) 53 | for path in filter_paths: 54 | logger.info(f"Building links map from path {path}") 55 | e_storage = EDGE_STORAGES.make_instance(path) 56 | # Assume unpartitioned. 57 | edges = e_storage.load_edges(UNPARTITIONED, UNPARTITIONED) 58 | for idx in range(len(edges)): 59 | # Assume non-featurized. 60 | cur_lhs = int(edges.lhs.to_tensor()[idx]) 61 | # Assume dynamic relations. 62 | cur_rel = int(edges.rel[idx]) 63 | # Assume non-featurized. 64 | cur_rhs = int(edges.rhs.to_tensor()[idx]) 65 | 66 | self.lhs_map[cur_lhs, cur_rel].append(cur_rhs) 67 | self.rhs_map[cur_rhs, cur_rel].append(cur_lhs) 68 | 69 | logger.info(f"Done building links map from path {path}") 70 | 71 | def _adjust_scores(self, scores: Scores, batch_edges: EdgeList): 72 | 73 | for idx in range(len(batch_edges)): 74 | # Assume non-featurized. 
75 | cur_lhs = int(batch_edges.lhs.to_tensor()[idx]) 76 | # Assume dynamic relations. 77 | cur_rel = int(batch_edges.rel[idx]) 78 | # Assume non-featurized. 79 | cur_rhs = int(batch_edges.rhs.to_tensor()[idx]) 80 | 81 | rhs_edges_filtered = self.lhs_map[cur_lhs, cur_rel] 82 | lhs_edges_filtered = self.rhs_map[cur_rhs, cur_rel] 83 | assert cur_lhs in lhs_edges_filtered 84 | assert cur_rhs in rhs_edges_filtered 85 | 86 | # The rank is computed as the number of non-negative margins (as 87 | # that means a negative with at least as good a score as a positive) 88 | # so to avoid counting positives we give them a negative margin. 89 | scores.lhs_neg[idx][lhs_edges_filtered] = -1e9 90 | scores.rhs_neg[idx][rhs_edges_filtered] = -1e9 91 | -------------------------------------------------------------------------------- /torchbiggraph/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | from abc import ABC, abstractmethod 10 | from typing import Optional 11 | 12 | import torch 13 | from torch import nn as nn 14 | from torch.nn import functional as F 15 | from torchbiggraph.model import match_shape 16 | from torchbiggraph.plugin import PluginRegistry 17 | from torchbiggraph.types import FloatTensorType 18 | 19 | 20 | class AbstractLossFunction(nn.Module, ABC): 21 | """Calculate weighted loss of scores for positive and negative pairs. 22 | 23 | The inputs are a 1-D tensor of size P containing scores for positive pairs 24 | of entities (i.e., those among which an edge exists) and a P x N tensor 25 | containing scores for negative pairs (i.e., where no edge should exist). The 26 | pairs of entities corresponding to pos_scores[i] and to neg_scores[i,j] have 27 | at least one endpoint in common. The output is the loss value these scores 28 | induce. If the method supports weighting (as is the case for the logistic 29 | loss) all positive scores will be weighted by the same weight and so will 30 | all the negative ones. 31 | """ 32 | 33 | def __init__(self, **kwargs): 34 | # loss functions will default ignore any kwargs, but can ask for any 35 | # specific kwargs of interest in their constructor 36 | # FIXME: This is not ideal. Perhaps we should pass in the config 37 | # or a subconfig instead? 
38 | super().__init__() 39 | 40 | @abstractmethod 41 | def forward( 42 | self, 43 | pos_scores: FloatTensorType, 44 | neg_scores: FloatTensorType, 45 | weight: Optional[FloatTensorType], 46 | ) -> FloatTensorType: 47 | pass 48 | 49 | 50 | LOSS_FUNCTIONS = PluginRegistry[AbstractLossFunction]() 51 | 52 | 53 | @LOSS_FUNCTIONS.register_as("logistic") 54 | class LogisticLossFunction(AbstractLossFunction): 55 | def forward( 56 | self, 57 | pos_scores: FloatTensorType, 58 | neg_scores: FloatTensorType, 59 | weight: Optional[FloatTensorType], 60 | ) -> FloatTensorType: 61 | num_pos = match_shape(pos_scores, -1) 62 | num_neg = match_shape(neg_scores, num_pos, -1) 63 | neg_weight = 1 / num_neg if num_neg > 0 else 0 64 | 65 | if weight is not None: 66 | match_shape(weight, num_pos) 67 | pos_loss = F.binary_cross_entropy_with_logits( 68 | pos_scores, 69 | pos_scores.new_ones(()).expand(num_pos), 70 | reduction="sum", 71 | weight=weight, 72 | ) 73 | neg_loss = F.binary_cross_entropy_with_logits( 74 | neg_scores, 75 | neg_scores.new_zeros(()).expand(num_pos, num_neg), 76 | reduction="sum", 77 | weight=weight.unsqueeze(-1) if weight is not None else None, 78 | ) 79 | 80 | loss = pos_loss + neg_weight * neg_loss 81 | 82 | return loss 83 | 84 | 85 | @LOSS_FUNCTIONS.register_as("ranking") 86 | class RankingLossFunction(AbstractLossFunction): 87 | def __init__(self, *, margin, **kwargs): 88 | super().__init__() 89 | self.margin = margin 90 | 91 | def forward( 92 | self, 93 | pos_scores: FloatTensorType, 94 | neg_scores: FloatTensorType, 95 | weight: Optional[FloatTensorType], 96 | ) -> FloatTensorType: 97 | num_pos = match_shape(pos_scores, -1) 98 | num_neg = match_shape(neg_scores, num_pos, -1) 99 | 100 | # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15223. 101 | if num_pos == 0 or num_neg == 0: 102 | return torch.zeros((), device=pos_scores.device, requires_grad=True) 103 | 104 | if weight is not None: 105 | match_shape(weight, num_pos) 106 | loss_per_sample = F.margin_ranking_loss( 107 | neg_scores, 108 | pos_scores.unsqueeze(1), 109 | target=pos_scores.new_full((1, 1), -1, dtype=torch.float), 110 | margin=self.margin, 111 | reduction="none", 112 | ) 113 | loss = (loss_per_sample * weight.unsqueeze(-1)).sum() 114 | else: 115 | # more memory efficient way if no weights 116 | loss = F.margin_ranking_loss( 117 | neg_scores, 118 | pos_scores.unsqueeze(1), 119 | target=pos_scores.new_full((1, 1), -1, dtype=torch.float), 120 | margin=self.margin, 121 | reduction="sum", 122 | ) 123 | 124 | return loss 125 | 126 | 127 | @LOSS_FUNCTIONS.register_as("softmax") 128 | class SoftmaxLossFunction(AbstractLossFunction): 129 | def forward( 130 | self, 131 | pos_scores: FloatTensorType, 132 | neg_scores: FloatTensorType, 133 | weight: Optional[FloatTensorType], 134 | ) -> FloatTensorType: 135 | num_pos = match_shape(pos_scores, -1) 136 | num_neg = match_shape(neg_scores, num_pos, -1) 137 | 138 | # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15870 139 | # and https://github.com/pytorch/pytorch/issues/15223. 
140 | if num_pos == 0 or num_neg == 0: 141 | return torch.zeros((), device=pos_scores.device, requires_grad=True) 142 | 143 | scores = torch.cat( 144 | [pos_scores.unsqueeze(1), neg_scores.logsumexp(dim=1, keepdim=True)], dim=1 145 | ) 146 | if weight is not None: 147 | loss_per_sample = F.cross_entropy( 148 | scores, 149 | pos_scores.new_zeros((num_pos,), dtype=torch.long), 150 | reduction="none", 151 | ) 152 | match_shape(weight, num_pos) 153 | loss_per_sample = loss_per_sample * weight 154 | else: 155 | loss_per_sample = F.cross_entropy( 156 | scores, 157 | pos_scores.new_zeros((num_pos,), dtype=torch.long), 158 | reduction="sum", 159 | ) 160 | 161 | return loss_per_sample.sum() 162 | -------------------------------------------------------------------------------- /torchbiggraph/partitionserver.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import argparse 10 | import logging 11 | from typing import Callable, List, Optional 12 | 13 | import torch.distributed as td 14 | from torchbiggraph.config import add_to_sys_path, ConfigFileLoader, ConfigSchema 15 | from torchbiggraph.distributed import init_process_group, ProcessRanks 16 | from torchbiggraph.parameter_sharing import ParameterServer, ShardedParameterServer 17 | from torchbiggraph.types import Rank, SINGLE_TRAINER 18 | from torchbiggraph.util import ( 19 | set_logging_verbosity, 20 | setup_logging, 21 | SubprocessInitializer, 22 | tag_logs_with_process_name, 23 | ) 24 | 25 | 26 | logger = logging.getLogger("torchbiggraph") 27 | 28 | 29 | # This is a small binary that just runs a partition server. 30 | # You need to run this if you run a distributed run and set 31 | # num_partition_servers > 1. 32 | 33 | 34 | def run_partition_server( 35 | config: ConfigSchema, 36 | rank: Rank = SINGLE_TRAINER, 37 | subprocess_init: Optional[Callable[[], None]] = None, 38 | ) -> None: 39 | tag_logs_with_process_name(f"PartS-{rank}") 40 | if config.num_partition_servers <= 0: 41 | raise RuntimeError("Config doesn't require explicit partition servers") 42 | if not 0 <= rank < config.num_partition_servers: 43 | raise RuntimeError("Invalid rank for partition server") 44 | if not td.is_available(): 45 | raise RuntimeError( 46 | "The installed PyTorch version doesn't provide " 47 | "distributed training capabilities." 
48 | ) 49 | ranks = ProcessRanks.from_num_invocations( 50 | config.num_machines, config.num_partition_servers 51 | ) 52 | 53 | if config.num_partition_servers > config.num_machines: 54 | num_ps_groups = config.num_partition_servers * ( 55 | config.num_groups_per_sharded_partition_server + 1 56 | ) 57 | else: 58 | num_ps_groups = config.num_groups_for_partition_server 59 | groups: List[List[int]] = [ranks.trainers] # barrier group 60 | groups += [ranks.trainers + ranks.partition_servers] * num_ps_groups # ps groups 61 | group_idxs_for_partition_servers = range(1, len(groups)) 62 | 63 | if subprocess_init is not None: 64 | subprocess_init() 65 | groups = init_process_group( 66 | rank=ranks.partition_servers[rank], 67 | world_size=ranks.world_size, 68 | init_method=config.distributed_init_method, 69 | groups=groups, 70 | ) 71 | if config.num_partition_servers > config.num_machines: 72 | ps = ShardedParameterServer( 73 | num_clients=len(ranks.trainers), 74 | num_data_pgs=config.num_groups_per_sharded_partition_server, 75 | group_idxs=group_idxs_for_partition_servers, 76 | log_stats=True, 77 | ) 78 | else: 79 | ps = ParameterServer( 80 | num_clients=len(ranks.trainers), 81 | group_idxs=group_idxs_for_partition_servers, 82 | log_stats=True, 83 | ) 84 | ps.start(groups) 85 | logger.info("ps.start done") 86 | 87 | 88 | def main(): 89 | setup_logging() 90 | config_help = "\n\nConfig parameters:\n\n" + "\n".join(ConfigSchema.help()) 91 | parser = argparse.ArgumentParser( 92 | epilog=config_help, 93 | # Needed to preserve line wraps in epilog. 94 | formatter_class=argparse.RawDescriptionHelpFormatter, 95 | ) 96 | parser.add_argument("config", help="Path to config file") 97 | parser.add_argument("-p", "--param", action="append", nargs="*") 98 | parser.add_argument( 99 | "--rank", 100 | type=int, 101 | default=SINGLE_TRAINER, 102 | help="For multi-machine, this machine's rank", 103 | ) 104 | opt = parser.parse_args() 105 | 106 | loader = ConfigFileLoader() 107 | config = loader.load_config(opt.config, opt.param) 108 | set_logging_verbosity(config.verbose) 109 | subprocess_init = SubprocessInitializer() 110 | subprocess_init.register(setup_logging, config.verbose) 111 | subprocess_init.register(add_to_sys_path, loader.config_dir.name) 112 | 113 | run_partition_server(config, opt.rank, subprocess_init=subprocess_init) 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | -------------------------------------------------------------------------------- /torchbiggraph/plugin.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 
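# A minimal sketch of how the registry defined below is used elsewhere in this
# package (the names here are illustrative, not real plugins):
#
#     WIDGETS = PluginRegistry[AbstractWidget]()
#
#     @WIDGETS.register_as("basic")
#     class BasicWidget(AbstractWidget):
#         ...
#
#     widget_cls = WIDGETS.get_class("basic")  # returns BasicWidget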
8 | 9 | from typing import Callable, Dict, Generic, Type, TypeVar 10 | from urllib.parse import urlparse 11 | 12 | 13 | T = TypeVar("T") 14 | 15 | 16 | class PluginRegistry(Generic[T]): 17 | def __init__(self) -> None: 18 | self.registry: Dict[str, Type[T]] = {} 19 | 20 | def register(self, name: str, class_: Type[T]) -> None: 21 | reg_class = self.registry.setdefault(name, class_) 22 | if reg_class is not class_: 23 | raise RuntimeError( 24 | f"Attempting to re-register {name} " 25 | f"which was already set to {reg_class!r}" 26 | ) 27 | 28 | def register_as(self, name: str) -> Callable[[Type[T]], Type[T]]: 29 | def decorator(class_: Type[T]) -> Type[T]: 30 | self.register(name, class_) 31 | return class_ 32 | 33 | return decorator 34 | 35 | def get_class(self, name: str) -> Type[T]: 36 | try: 37 | return self.registry[name] 38 | except KeyError: 39 | all_names = ", ".join(sorted(self.registry.keys())) 40 | raise NotImplementedError(f"Unknown name {name} (known names: {all_names})") 41 | 42 | 43 | class URLPluginRegistry(PluginRegistry[T], Generic[T]): 44 | def make_instance(self, url: str) -> T: 45 | scheme = urlparse(url).scheme 46 | try: 47 | class_: Type[T] = self.get_class(scheme) 48 | except NotImplementedError as err: 49 | raise NotImplementedError(f"Unsupported URL {url}: {err}") from None 50 | return class_(url) 51 | -------------------------------------------------------------------------------- /torchbiggraph/regularizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | from abc import ABC, abstractmethod 10 | from typing import Optional 11 | 12 | import torch 13 | import torchbiggraph.model 14 | from torchbiggraph.operators import AbstractDynamicOperator, AbstractOperator 15 | from torchbiggraph.plugin import PluginRegistry 16 | from torchbiggraph.types import FloatTensorType, LongTensorType 17 | 18 | 19 | class AbstractRegularizer(ABC): 20 | """ 21 | Computes a weighted penalty for embeddings involved in score computations. 
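    For instance, the N3 regularizer defined below sums the cubes of the
    prepared embeddings and of the operator parameters, and scales the total
    by self.weight.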
22 | """ 23 | 24 | def __init__(self, weight): 25 | self.weight = weight 26 | 27 | @abstractmethod 28 | def forward_dynamic( 29 | self, 30 | src_pos: FloatTensorType, 31 | dst_pos: FloatTensorType, 32 | src_operators: Optional[FloatTensorType], 33 | dst_operators: Optional[FloatTensorType], 34 | ) -> FloatTensorType: 35 | pass 36 | 37 | @abstractmethod 38 | def forward( 39 | self, 40 | src_pos: FloatTensorType, 41 | dst_pos: FloatTensorType, 42 | src_operators: Optional[FloatTensorType], 43 | dst_operators: Optional[FloatTensorType], 44 | ) -> FloatTensorType: 45 | pass 46 | 47 | 48 | REGULARIZERS = PluginRegistry[AbstractRegularizer]() 49 | 50 | 51 | @REGULARIZERS.register_as("N3") 52 | class N3Regularizer(AbstractRegularizer): 53 | """N3 regularizer described in https://arxiv.org/pdf/1806.07297.pdf""" 54 | 55 | def reg_embs( 56 | self, src_pos: FloatTensorType, dst_pos: FloatTensorType 57 | ) -> FloatTensorType: 58 | a, b, rank = torchbiggraph.model.match_shape(src_pos, -1, -1, -1) 59 | torchbiggraph.model.match_shape(dst_pos, a, b, rank) 60 | total = 0 61 | for x in (src_pos, dst_pos): 62 | total += torch.sum(self.modulus(x, rank // 2) ** 3) 63 | return total * self.weight 64 | 65 | def forward_dynamic( 66 | self, 67 | src_pos: FloatTensorType, 68 | dst_pos: FloatTensorType, 69 | operator: AbstractDynamicOperator, 70 | rel_idxs: LongTensorType, 71 | ) -> FloatTensorType: 72 | total = 0 73 | operator_params = operator.get_operator_params_for_reg(rel_idxs) 74 | if operator_params is not None: 75 | operator_params = operator_params.to(src_pos.device) 76 | total += torch.sum(operator_params**3) 77 | for x in (src_pos, dst_pos): 78 | total += torch.sum(operator.prepare_embs_for_reg(x) ** 3) 79 | total *= self.weight 80 | return total 81 | 82 | def forward( 83 | self, 84 | src_pos: FloatTensorType, 85 | dst_pos: FloatTensorType, 86 | operator: AbstractOperator, 87 | ) -> FloatTensorType: 88 | total = 0 89 | operator_params = operator.get_operator_params_for_reg() 90 | if operator_params is not None: 91 | operator_params = operator_params.to(src_pos.device) 92 | batch_size = len(src_pos) 93 | total += torch.sum(operator_params**3) * batch_size 94 | for x in (src_pos, dst_pos): 95 | total += torch.sum(operator.prepare_embs_for_reg(x) ** 3) 96 | total *= self.weight 97 | return total 98 | -------------------------------------------------------------------------------- /torchbiggraph/row_adagrad.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 8 | 9 | import logging 10 | 11 | from torch.optim import Optimizer 12 | 13 | 14 | logger = logging.getLogger("torchbiggraph") 15 | 16 | 17 | class RowAdagrad(Optimizer): 18 | """Implements a row-wise variant of the Adagrad algorithm. 19 | Assumes that all the model parameters are 2-dimensional tensors 20 | containing embedding weights. 21 | 22 | Code mostly copy-pasted from torch/optim/Adagrad, with HOGWILD-safe 23 | update (see async_adagrad.py) 24 | """ 25 | 26 | def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0): 27 | # lr_decay is a little tricky beause keeping track of # of steps 28 | # is not straightforward when they're happening in a distributed way. 
29 | # In any case, we don't use lr_decay in Filament. 30 | assert lr_decay == 0, "lr_decay not currently supported." 31 | defaults = {"lr": lr, "lr_decay": lr_decay, "weight_decay": weight_decay} 32 | super().__init__(params, defaults) 33 | 34 | for group in self.param_groups: 35 | for p in group["params"]: 36 | # state['step'] = 0 37 | if p.dim() != 2: 38 | raise ValueError("RowAdagrad only works on 2D tensors") 39 | state = self.state[p] 40 | state["sum"] = p.new_zeros((p.shape[0],)) 41 | 42 | def share_memory(self): 43 | for group in self.param_groups: 44 | for p in group["params"]: 45 | state = self.state[p] 46 | state["sum"].share_memory_() 47 | 48 | def step(self, closure=None): 49 | """Performs a single optimization step. 50 | 51 | Arguments: 52 | closure (callable, optional): A closure that reevaluates the model 53 | and returns the loss. 54 | """ 55 | loss = None 56 | if closure is not None: 57 | loss = closure() 58 | 59 | for group in self.param_groups: 60 | for p in group["params"]: 61 | if p.grad is None: 62 | continue 63 | 64 | grad = p.grad.data 65 | state = self.state[p] 66 | 67 | # state['step'] += 1 68 | 69 | if group["weight_decay"] != 0: 70 | if grad.is_sparse: 71 | raise RuntimeError( 72 | "weight_decay option is not " 73 | "compatible with sparse gradients" 74 | ) 75 | grad = grad.add(p.data, alpha=group["weight_decay"]) 76 | 77 | # clr = group['lr'] / (1 + (state['step'] - 1) * group['lr_decay']) 78 | clr = group["lr"] 79 | 80 | if grad.is_sparse: 81 | if grad._indices().numel() == 0: 82 | continue 83 | # the update is non-linear so indices must be unique 84 | grad = grad.coalesce() 85 | grad_indices = grad._indices()[0] 86 | grad_values = grad._values() 87 | # multiple HOGWILD processes may perform unsynchronized 88 | # updates to G. Update a local copy of G independently from 89 | # the shared-memory copy, to guarantee that 90 | # local_G >= grad^2 91 | local_G = state["sum"][grad_indices] # _sparse_mask 92 | delta_G = (grad_values * grad_values).mean(1) 93 | state["sum"].index_add_(0, grad_indices, delta_G) 94 | local_G += delta_G 95 | std_values = local_G.sqrt_().add_(1e-10).unsqueeze(1) 96 | p.data.index_add_(0, grad_indices, -clr * grad_values / std_values) 97 | else: 98 | # multiple HOGWILD processes may perform unsynchronized 99 | # updates to G. Update a local copy of G independently from 100 | # the shared-memory copy, to guarantee that 101 | # local_G >= grad^2 102 | local_G = state["sum"].clone() 103 | delta_G = (grad * grad).mean(1) 104 | state["sum"] += delta_G 105 | local_G += delta_G 106 | std = local_G.sqrt().add_(1e-10) 107 | p.data.addcdiv_(grad, std.unsqueeze(1), value=-clr) 108 | 109 | return loss 110 | -------------------------------------------------------------------------------- /torchbiggraph/rpc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import io 9 | import sys 10 | import traceback 11 | 12 | import numpy as np 13 | import torch 14 | import torch.distributed as td 15 | 16 | 17 | # FIXME: is it efficient to torch.save into a buf? It's going to have to copy 18 | # all the tensors. 19 | def _serialize(data): 20 | buf = io.BytesIO() 21 | torch.save(data, buf) 22 | data_bytes = buf.getvalue() 23 | # FIXME: how to properly copy bytes to ByteTensor?
    t = torch.from_numpy(np.frombuffer(data_bytes, dtype=np.uint8))
    return t


def _deserialize(t):
    data_bytes = t.numpy().tobytes()
    buf = io.BytesIO(data_bytes)
    return torch.load(buf)


def send(data, dst):
    """
    Sends an arbitrary torch-serializable object to a destination node.
    This is a blocking send, equivalent to `torch.distributed.send`.

    Args:
        data: An arbitrary torch-serializable object to be sent.
        dst: The rank of the destination node.
    """

    # FIXME: we've got to get rid of this two-pass nonsense for dynamically sized
    # send and receive.
    t = _serialize(data)
    sizet = torch.LongTensor([t.nelement()])
    td.send(sizet, dst)
    td.send(t, dst)


def recv(src=None):
    """
    Receives an arbitrary torch-serializable object from a source node.
    This is a blocking receive, equivalent to `torch.distributed.recv`.

    Args:
        src: The rank of the source node. If None, will receive from any rank.

    Returns:
        data: The data sent from the source node.
        src: The rank of the source node.
    """
    sizet = torch.LongTensor(1)
    src = td.recv(sizet, src)
    t = torch.ByteTensor(sizet.item())
    td.recv(t, src)
    return _deserialize(t), src


_JOIN_KEY = "seU17sb9nwqDZhsH9AyW"


class Server:
    """Base class for an RPC server using `torch.distributed`.
    Users should subclass this class and add the server methods.

    Example:
        init_method = "file://myfile.tmp"
        num_clients = 1
        torch.distributed.init_process_group('gloo',
                                             init_method=init_method,
                                             world_size=num_clients + 1,
                                             rank=0)

        class MyServer(Server):
            def test_func(self, T, k=0):
                return ("the result is ", T + k)

        s = MyServer(num_clients)
        s.start()  # will block until all clients have called `join()`
    """

    def __init__(self, num_clients):
        """
        Args:
            num_clients: The number of clients that will call `join()` upon
                completion.
        """
        self.num_clients = num_clients

    def start(self, groups=None):
        join_clients = []

        while True:
            rpc, src = recv()
            if rpc == _JOIN_KEY:
                join_clients += [src]
                if len(join_clients) == self.num_clients:
                    for client in join_clients:
                        # after sending the join cmd,
                        # each client waits on this ack to know everyone is done
                        # and it's safe to exit
                        send(_JOIN_KEY, client)
                    break
            else:
                F, args, kwargs = rpc
                try:
                    res = getattr(self, F)(*args, **kwargs)
                    send((False, res), src)
                except BaseException as e:
                    # should we print the exception on the server also?
                    # traceback.print_exc()
                    exc_str = traceback.format_exc()
                    send((True, (e, exc_str)), src)


class Client:
    """A client for connecting to a subclass of `rpc.Server`.

    Example:
        init_method = "file://myfile.tmp"
        num_clients = 1
        torch.distributed.init_process_group('gloo',
                                             init_method=init_method,
                                             world_size=num_clients + 1,
                                             rank=1)

        c = Client(MyServer, server_rank=0)

        print(c.test_func(torch.arange(0, 3), k=2))
        # ('the result is ', tensor([ 2, 3, 4]))

        c.join()
    """

    def __init__(self, server_class, server_rank):
        """
        Args:
            server_class: The class of the server object. This should be a
                subclass of `rpc.Server`.
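                Only its method names are looked up locally, for validation;
                each call is then forwarded to the server by name over RPC.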
            server_rank: The rank of the node where the `rpc.Server` is running.
        """
        self.server_class = server_class
        self.server_rank = server_rank

    def __getattr__(self, name):
        if name not in dir(self.server_class):
            raise AttributeError(
                "%s has no attribute %s" % (self.server_class.__name__, name)
            )
        func = getattr(self.server_class, name)
        if not isinstance(func, type(lambda: 1)):  # FIXME
            raise TypeError("%s object is not callable" % (type(func)))

        def inner(*args, **kwargs):
            send((name, args, kwargs), self.server_rank)
            (is_exception, res), _src = recv(self.server_rank)
            if not is_exception:
                return res
            else:
                exc, exc_str = res
                print(exc_str, file=sys.stderr)
                raise exc

        return inner

    def join(self) -> None:
        """Should be called by each client upon completion, to ensure a clean exit."""
        send(_JOIN_KEY, self.server_rank)
        recv(self.server_rank)
--------------------------------------------------------------------------------
/torchbiggraph/stats.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE.txt file in the root directory of this source tree.

from collections import defaultdict
from statistics import mean
from typing import cast, Dict, Iterable, Optional, Type, Union

from torchbiggraph.types import FloatTensorType


def average_of_sums(*tensors: FloatTensorType) -> float:
    return mean(t.sum().item() for t in tensors)


SerializedStats = Dict[str, Union[int, Dict[str, float]]]


class Stats:
    """A class collecting a set of metrics.

    When defining the stats produced by a certain operation (say, training or
    evaluation), subclass this class and pass the metrics you want to collect
    to __init__ as keyword arguments: they are stored in the `metrics` dict,
    and a metric named count is always present. The class provides some
    convenience methods to aggregate, convert and format stats (see below).
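
    A minimal illustration (the `loss` metric is just a made-up name):

        s1 = Stats(count=10, loss=3.0)
        s2 = Stats(count=30, loss=5.0)
        total = Stats.sum([s1, s2])  # count=40, loss=8.0
        print(total.average())  # prints "loss: 0.2 , count: 40"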
33 | 34 | """ 35 | 36 | def __init__(self, *, count: int, **metrics: float) -> None: 37 | self.count = count 38 | self.metrics = metrics 39 | 40 | @classmethod 41 | def sum(cls: Type["Stats"], stats: Iterable["Stats"]) -> "Stats": 42 | """Return a stats whose metrics are the sums of the given stats.""" 43 | total_metrics = defaultdict(lambda: 0) 44 | for s in stats: 45 | for k, v in s.metrics.items(): 46 | total_metrics[k] += v 47 | return cls(count=sum(s.count for s in stats), **total_metrics) 48 | 49 | def average(self) -> "Stats": 50 | """Return these stats with all metrics, except count, averaged.""" 51 | if self.count == 0: 52 | return self 53 | return type(self)( 54 | count=self.count, **{k: v / self.count for k, v in self.metrics.items()} 55 | ) 56 | 57 | @classmethod 58 | def average_list(cls: Type["Stats"], stats: Iterable["Stats"]) -> "Stats": 59 | """Return a stats whose metrics are the average of all stats.""" 60 | 61 | return cls.sum([s * s.count for s in stats]).average() 62 | 63 | def __str__(self) -> str: 64 | return "%s , count: %d" % ( 65 | " , ".join("%s: %.6g" % (k, v) for k, v in self.metrics.items()), 66 | self.count, 67 | ) 68 | 69 | def __eq__(self, other: "Stats") -> bool: 70 | return ( 71 | isinstance(other, Stats) 72 | and self.count == other.count 73 | and self.metrics == other.metrics 74 | ) 75 | 76 | def __mul__(self, c: float) -> "Stats": 77 | return type(self)( 78 | count=self.count, **{k: v * c for k, v in self.metrics.items()} 79 | ) 80 | 81 | def to_dict(self) -> SerializedStats: 82 | return {"count": self.count, "metrics": self.metrics} 83 | 84 | @classmethod 85 | def from_dict(cls, d: SerializedStats) -> "Stats": 86 | if set(d.keys()) != {"count", "metrics"}: 87 | raise ValueError( 88 | f"Expect keys ['count', 'metrics'] from input but get {list(d.keys())}." 89 | ) 90 | return Stats( 91 | count=cast(int, d["count"]), **cast(Dict[str, float], d["metrics"]) 92 | ) 93 | 94 | 95 | class StatsHandler: 96 | def on_stats( 97 | self, 98 | index: int, 99 | eval_stats_before: Optional[Stats] = None, 100 | train_stats: Optional[Stats] = None, 101 | eval_stats_after: Optional[Stats] = None, 102 | eval_stats_chunk_avg: Optional[Stats] = None, 103 | ) -> None: 104 | pass 105 | -------------------------------------------------------------------------------- /torchbiggraph/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE.txt file in the root directory of this source tree. 

import argparse
import logging
from typing import Callable, Optional

from torchbiggraph.batching import AbstractBatchProcessor
from torchbiggraph.config import add_to_sys_path, ConfigFileLoader, ConfigSchema
from torchbiggraph.model import MultiRelationEmbedder
from torchbiggraph.train_cpu import TrainingCoordinator
from torchbiggraph.train_gpu import GPUTrainingCoordinator
from torchbiggraph.types import Rank, SINGLE_TRAINER
from torchbiggraph.util import (
    set_logging_verbosity,
    setup_logging,
    SubprocessInitializer,
)


logger = logging.getLogger("torchbiggraph")
dist_logger = logging.LoggerAdapter(logger, {"distributed": True})


def train(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    trainer: Optional[AbstractBatchProcessor] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    rank: Rank = SINGLE_TRAINER,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> None:
    CoordinatorT = (
        GPUTrainingCoordinator if config.num_gpus > 0 else TrainingCoordinator
    )
    coordinator = CoordinatorT(config, model, trainer, evaluator, rank, subprocess_init)
    coordinator.train()
    coordinator.close()


def main():
    setup_logging()
    config_help = "\n\nConfig parameters:\n\n" + "\n".join(ConfigSchema.help())
    parser = argparse.ArgumentParser(
        epilog=config_help,
        # Needed to preserve line wraps in epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("config", help="Path to config file")
    parser.add_argument("-p", "--param", action="append", nargs="*")
    parser.add_argument(
        "--rank",
        type=int,
        default=SINGLE_TRAINER,
        help="For multi-machine, this machine's rank",
    )
    opt = parser.parse_args()

    loader = ConfigFileLoader()
    config = loader.load_config(opt.config, opt.param)
    set_logging_verbosity(config.verbose)
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)

    train(config, rank=opt.rank, subprocess_init=subprocess_init)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/torchbiggraph/types.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE.txt file in the root directory of this source tree.

from enum import Enum
from typing import Any, Dict, NamedTuple, TypeVar

import torch


# torch.FloatTensor and torch.LongTensor are defined as empty subclasses of
# torch.Tensor by PyTorch's type stub, which means that any operation on them
# returns plain untyped torch.Tensors. This makes it impossible to use the typed
# subtypes to annotate functions as they wouldn't get the type they expect.
# Thus for type checking to work functions must be annotated with torch.Tensor.
# To preserve and expose that information, at least to humans, we use more
# informative aliases for torch.Tensor. (PS: FloatTensor and LongTensor are in
# fact instances of the torch.tensortype metaclass).
ByteTensorType = torch.Tensor  # uint8
CharTensorType = torch.Tensor  # int8
FloatTensorType = torch.Tensor  # float32
LongTensorType = torch.Tensor  # int64


T = TypeVar("T")


class Side(Enum):
    LHS = 0
    RHS = 1

    def pick(self, lhs: T, rhs: T) -> T:
        if self is Side.LHS:
            return lhs
        elif self is Side.RHS:
            return rhs
        else:
            raise NotImplementedError("Unknown side: %s" % self)


EntityName = str
Rank = int
GPURank = int
Partition = int
SubPartition = int
ModuleStateDict = Dict[str, torch.Tensor]
OptimizerStateDict = Dict[str, Any]


class Bucket(NamedTuple):
    lhs: Partition
    rhs: Partition

    def get_partition(self, side: Side) -> Partition:
        return side.pick(self.lhs, self.rhs)

    def __str__(self) -> str:
        return "( %d , %d )" % (self.lhs, self.rhs)


# Use as partition index for unpartitioned entities, which have a single partition.
UNPARTITIONED: Partition = 0
# Use as rank for single-machine training.
SINGLE_TRAINER: Rank = 0
--------------------------------------------------------------------------------