├── .gitignore ├── Dockerfile ├── README.md ├── dev_guide.md ├── docs ├── Makefile ├── _templates │ ├── module.rst │ └── package.rst ├── conf.py ├── index.md ├── make.bat └── source │ ├── devguide.md │ └── userguide.md ├── mmmt ├── __init__.py ├── data │ ├── __init__.py │ ├── graph │ │ ├── __init__.py │ │ ├── concept_to_graph.py │ │ ├── data_to_graph.py │ │ ├── dgl_data_loader.py │ │ ├── general_file_loader.py │ │ ├── graph_to_graph.py │ │ ├── mat_file_loader.py │ │ └── visualization.py │ ├── operators │ │ ├── __init__.py │ │ ├── op_build_graph.py │ │ ├── op_concat_names.py │ │ ├── op_forwardpass.py │ │ └── op_resample.py │ └── representation │ │ ├── __init__.py │ │ ├── auto_encoder_trainer.py │ │ ├── encoded_unimodal_to_concept.py │ │ ├── fusion.py │ │ ├── modality_encoding.py │ │ └── model_builder_trainer.py ├── models │ ├── __init__.py │ ├── classic │ │ ├── __init__.py │ │ ├── fusion_mlp.py │ │ ├── late_fusion.py │ │ └── uncertainty_late_fusion.py │ ├── graph │ │ ├── __init__.py │ │ ├── gcn.py │ │ ├── mgnn.py │ │ ├── module_configurator.py │ │ ├── multi_behavioral_gnn.py │ │ ├── multiplex_gcn.py │ │ ├── multiplex_gin.py │ │ └── relational_gcn.py │ ├── head │ │ ├── __init__.py │ │ └── mlp.py │ ├── model_builder.py │ ├── multimodal_graph_model.py │ └── multimodal_mlp.py ├── pipeline │ ├── __init__.py │ ├── defaults.yaml │ ├── object_registry.py │ └── pipeline.py └── py.typed ├── mmmt_examples ├── README.md ├── __init__.py └── knight │ ├── __init__.py │ ├── demonstration_notebook.ipynb │ ├── full_mmmt_pipeline.py │ ├── get_splits.py │ ├── knight_eval.py │ ├── mmmt_pipeline_config.yaml │ ├── mmmt_pipeline_config_demonstration.yaml │ ├── mmmt_pipeline_config_demonstration_mlp.yaml │ ├── op_knight.py │ ├── pipeline.gv.png │ ├── splits_final.pkl │ └── user_input.png ├── pyproject.toml ├── setup.cfg ├── setup.py └── test ├── __init__.py ├── data ├── graph │ ├── test_dgl_data_loader.py │ ├── test_general_file_loader.py │ ├── test_graph_to_graph.py │ ├── test_mat_file_loader.py │ └── test_visualization.py ├── operators │ └── test_operators.py └── representation │ ├── test_encoded_unimodal_to_concept.py │ └── test_modality_encoding.py ├── models ├── classic │ └── test_classic_models.py ├── graph │ ├── test_graph_models.py │ └── test_module_configurator.py ├── head │ └── test_head_models.py ├── test_mgm.py └── test_model_builder.py └── pipeline ├── test.yaml ├── test_mlp.yaml └── test_pipeline.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | docs/api/* 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | .idea/ 108 | .vscode/ 109 | .DS_Store 110 | tmp/ 111 | conda_env*/ 112 | _examples 113 | mlruns/ 114 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8-slim 2 | 3 | RUN groupadd -r mmmtuser && useradd -r -g mmmtuser mmmtuser 4 | 5 | RUN apt-get update 6 | RUN apt-get -y install git 7 | RUN apt-get -y install gcc 8 | 9 | WORKDIR /home/mmmtuser 10 | 11 | RUN chown mmmtuser:mmmtuser /home/mmmtuser 12 | 13 | USER mmmtuser 14 | 15 | ARG GIT_TOKEN 16 | RUN pip3 install "git+https://$GIT_TOKEN@github.com/BiomedSciAI/multimodal-model-toolkit" 17 | 18 | CMD ["bash"] 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # multimodal-model-toolkit 2 | 3 | **multimodal-model-toolkit (MMMT, pronounced mammut)** is a platform for accelerating research and development with data in multiple 4 | modalities, from data pre-processing, to model evaluation. 5 | 6 | MMMT has a modular structure and distinguishes the following modules coordinated by one pipeline: 7 | 1. data loading 8 | 2. data representation 9 | 1. unimodal representation 10 | 2. multimodal representation 11 | 3. training 12 | 4. inference 13 | 5. evaluation 14 | 15 | ## Setup 16 | 17 | MMMT supports Python 3.8 and 3.9. To install: 18 | ```sh 19 | pip install git+ssh://git@github.com/BiomedSciAI/multimodal-model-toolkit 20 | ``` 21 | 22 | If you are contributing to the development of MMMT, see [dev_guide.md](dev_guide.md). 
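A quick way to verify the installation is to import the package and list its top-level subpackages (a minimal sanity check only; it does not run any MMMT pipeline):
```sh
python -c "import mmmt; print(mmmt.__all__)"
# expected output: ['data', 'models', 'pipeline']
```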
23 | 24 | ## Current methods for fusing representations 25 | List of the methods currently integrated in the toolkit: 26 | 27 | | Method | Short description | Link to publication | 28 | |--------|-------------------|----------------------------------------------------| 29 | | multiplex_gcn | multiplex GCN for message passing according to sGCN Conv for sparse graphs | https://arxiv.org/abs/2210.14377 | 30 | | multiplex_gin | multiplex GIN framework for message passing via multiplex walks | early https://arxiv.org/abs/2210.14377 | 31 | | relational_gcn | relational GCN | https://arxiv.org/pdf/1703.06103.pdf | 32 | | gcn | baseline GCN | https://arxiv.org/abs/1609.02907v4 | 33 | | mgnn | mGNN framework for message passing | https://arxiv.org/abs/2109.10119 | 34 | | multi_behavioral_gnn | multibehavioral GNN framework for message passing | https://dl.acm.org/doi/pdf/10.1145/3340531.3412119 | 35 | 36 | 37 | ## User interface 38 | To simplify the configuration of a computation using MMMT, we use the concept of a pipeline, which can be fully specified in a `yaml` file. 39 | 40 | The yaml file describes the phases of the computation (e.g. data loading), and each phase contains the list of steps to be executed, specifying which object to use and which argument values to pass. 41 | 42 | Besides clear modularization, this enables the user to launch multiple computations by calling the same starting script (e.g. [full_mmmt_pipeline.py](mmmt_examples/knight/full_mmmt_pipeline.py)) with different yaml files. 43 | 44 | Default configuration values for a computation involving all possible phases are defined [here](mmmt/pipeline/defaults.yaml). 45 | 46 | 47 | ## Examples 48 | In [mmmt_examples](mmmt_examples/README.md) we keep a list of examples of MMMT applications. 49 | 50 | The goal of these scripts is to showcase how to apply MMMT to selected datasets. 51 | 52 | ### Datasets used so far in [mmmt_examples](mmmt_examples/README.md) 53 | | Dataset name | Short description | Link to dataset | 54 | |--------------|-----------------------------------|-----------------------------------| 55 | | KNIGHT | Kidney clinical Notes and Imaging to Guide and Help personalize Treatment and biomarkers discovery | https://research.ibm.com/haifa/Workshops/KNIGHT/ | 56 | -------------------------------------------------------------------------------- /dev_guide.md: -------------------------------------------------------------------------------- 1 | # Developer guide 2 | 3 | ## Developer setup 4 | An [editable installation](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) 5 | is preferable for developers: 6 | ```sh 7 | # from your local clone: 8 | pip install -e . 9 | ``` 10 | 11 | ## Continuous deployment 12 | The continuous deployment (CD) pipeline uses 13 | [`python-semantic-release`](https://github.com/relekang/python-semantic-release) and 14 | [`setuptools_scm`](https://github.com/pypa/setuptools_scm) to automatically manage 15 | versioning and releasing. For each commit to `main`, i.e. upon merging, the pipeline 16 | assesses whether a release is needed and, if so, creates and pushes a Git tag, a 17 | GitHub release, and a PyPI/Artifactory package version. 
18 | 19 | To check if a release is required and determine the target version, 20 | python-semantic-release parses the relevant commit messages assuming 21 | [Angular style](https://github.com/angular/angular.js/blob/master/DEVELOPERS.md#commits), 22 | and decides *if* and *how* to bump the version 23 | (as per [SemVer](https://semver.org)) based on config options `parser_angular_patch_types` 24 | and `parser_angular_minor_types`. 25 | 26 | Consider an example where the current version is `3.4.1` and `pyproject.toml` includes: 27 | ``` 28 | # ... 29 | [tool.semantic_release] 30 | parser_angular_allowed_types = "build,chore,ci,docs,feat,fix,perf,style,refactor,test" 31 | parser_angular_patch_types = "fix,perf" 32 | parser_angular_minor_types = "feat" 33 | # ... 34 | ``` 35 | 36 | Then: 37 | - merging commit `"docs: update README"` would *not* lead to a new release 38 | - merging commit `"perf: optimize parser"` would lead to new release `3.4.2` (patch bump) 39 | - merging commit `"feat: add parser"` would lead to a new release `3.5.0` (minor bump) 40 | - merging a commit with a `"BREAKING CHANGE:"` 41 | [footer](https://github.com/angular/angular.js/blob/master/DEVELOPERS.md#footer) would 42 | lead to a new release `4.0.0` (major bump) 43 | - note that this logic is currently hard-coded in python-semantic-release, i.e. it is 44 | currently not possible to configure custom cues/patterns for triggering major bumps. 45 | 46 | Note that each type defined in `parser_angular_patch_types` or `parser_angular_minor_types` 47 | should also be in `parser_angular_allowed_types`, otherwise it will not trigger. 48 | 49 | For more details about python-semantic-release check its 50 | [docs](https://python-semantic-release.readthedocs.io) and [default config](https://github.com/relekang/python-semantic-release/blob/master/semantic_release/defaults.cfg), 51 | as well as MMMT's local config in `pyproject.toml`, under `[tool.semantic_release]`. 52 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | clean: 18 | @-rm -rf $(BUILDDIR)/* 19 | @-rm -rf api/*.rst 20 | 21 | # Catch-all target: route all unknown targets to Sphinx using the new 22 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 23 | %: Makefile 24 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 25 | -------------------------------------------------------------------------------- /docs/_templates/module.rst: -------------------------------------------------------------------------------- 1 | {# The :autogenerated: tag is picked up by breadcrumbs.html to suppress "Edit on Github" link #} 2 | :autogenerated: 3 | 4 | {{ fullname }} module 5 | {% for item in range(7 + fullname|length) -%}={%- endfor %} 6 | 7 | .. currentmodule:: {{ fullname }} 8 | 9 | .. 
automodule:: {{ fullname }} 10 | {% if members -%} 11 | :members: {{ members|join(", ") }} 12 | :undoc-members: 13 | :show-inheritance: 14 | :member-order: bysource 15 | 16 | Summary 17 | ------- 18 | 19 | {%- if exceptions %} 20 | 21 | Exceptions: 22 | 23 | .. autosummary:: 24 | :nosignatures: 25 | {% for item in exceptions %} 26 | {{ item }} 27 | {%- endfor %} 28 | {%- endif %} 29 | 30 | {%- if classes %} 31 | 32 | Classes: 33 | 34 | .. autosummary:: 35 | :nosignatures: 36 | {% for item in classes %} 37 | {{ item }} 38 | {%- endfor %} 39 | {%- endif %} 40 | 41 | {%- if functions %} 42 | 43 | Functions: 44 | 45 | .. autosummary:: 46 | :nosignatures: 47 | {% for item in functions %} 48 | {{ item }} 49 | {%- endfor %} 50 | {%- endif %} 51 | {%- endif %} 52 | 53 | {% set data = get_members(typ='data', in_list='__all__') %} 54 | {%- if data %} 55 | 56 | Data: 57 | 58 | .. autosummary:: 59 | :nosignatures: 60 | {% for item in data %} 61 | {{ item }} 62 | {%- endfor %} 63 | {%- endif %} 64 | 65 | {% set all_refs = get_members(in_list='__all__', include_imported=True, out_format='refs') %} 66 | {% if all_refs %} 67 | ``__all__``: {{ all_refs|join(", ") }} 68 | {%- endif %} 69 | 70 | 71 | {% if members %} 72 | Reference 73 | --------- 74 | 75 | {%- endif %} 76 | -------------------------------------------------------------------------------- /docs/_templates/package.rst: -------------------------------------------------------------------------------- 1 | {# The :autogenerated: tag is picked up by breadcrumbs.html to suppress "Edit on Github" link #} 2 | :autogenerated: 3 | 4 | {{ fullname }} package 5 | {% for item in range(8 + fullname|length) -%}={%- endfor %} 6 | 7 | .. automodule:: {{ fullname }} 8 | {% if members -%} 9 | :members: {{ members|join(", ") }} 10 | :undoc-members: 11 | :show-inheritance: 12 | {%- endif %} 13 | 14 | {% if submodules %} 15 | Submodules: 16 | 17 | .. toctree:: 18 | :maxdepth: 1 19 | {% for item in submodules %} 20 | {{ fullname }}.{{ item }} 21 | {%- endfor %} 22 | {%- endif -%} 23 | 24 | {% if subpackages %} 25 | 26 | Subpackages: 27 | 28 | .. 
toctree:: 29 | :maxdepth: 1 30 | {% for item in subpackages %} 31 | {{ fullname }}.{{ item }} 32 | {%- endfor %} 33 | {%- endif %} 34 | 35 | {% set all = get_members(in_list='__all__', include_imported=True) %} 36 | {% if members or all %} 37 | Summary 38 | ------- 39 | 40 | {%- set exceptions = get_members(typ='exception', in_list='__all__', include_imported=True, out_format='table') -%} 41 | {%- set classes = get_members(typ='class', in_list='__all__', include_imported=True, out_format='table') -%} 42 | {%- set functions = get_members(typ='function', in_list='__all__', include_imported=True, out_format='table') -%} 43 | {%- set data = get_members(typ='data', in_list='__all__', include_imported=True, out_format='table') -%} 44 | {%- set private_exceptions = get_members(typ='exception', in_list='__private__', out_format='table') -%} 45 | {%- set private_classes = get_members(typ='class', in_list='__private__', out_format='table') -%} 46 | {%- set private_functions = get_members(typ='function', in_list='__private__', out_format='table') -%} 47 | 48 | {%- if exceptions %} 49 | 50 | ``__all__`` Exceptions: 51 | 52 | {% for line in exceptions %} 53 | {{ line }} 54 | {%- endfor %} 55 | {%- endif %} 56 | {%- if private_exceptions %} 57 | 58 | Private Exceptions: 59 | 60 | {% for line in private_exceptions %} 61 | {{ line }} 62 | {%- endfor %} 63 | {%- endif %} 64 | 65 | {%- if classes %} 66 | 67 | ``__all__`` Classes: 68 | 69 | {% for line in classes %} 70 | {{ line }} 71 | {%- endfor %} 72 | {%- endif %} 73 | {%- if private_classes %} 74 | 75 | Private Classes: 76 | 77 | {% for line in private_classes %} 78 | {{ line }} 79 | {%- endfor %} 80 | {%- endif %} 81 | 82 | {%- if functions %} 83 | 84 | ``__all__`` Functions: 85 | 86 | {% for line in functions %} 87 | {{ line }} 88 | {%- endfor %} 89 | {%- endif %} 90 | {%- if private_functions %} 91 | 92 | Private Functions: 93 | 94 | {% for line in private_functions %} 95 | {{ line }} 96 | {%- endfor %} 97 | {%- endif %} 98 | 99 | {%- if data %} 100 | 101 | ``__all__`` Data: 102 | 103 | {% for line in data %} 104 | {{ line }} 105 | {%- endfor %} 106 | {%- endif %} 107 | 108 | {%- endif %} 109 | 110 | 111 | {% if members %} 112 | Reference 113 | --------- 114 | 115 | {%- endif %} 116 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath("..")) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = "mmmt" 22 | copyright = "IBM Corp. 
2022" 23 | author = "IBM Research" 24 | 25 | # -- Generate API (auto) documentation ------------------------------------------------ 26 | 27 | 28 | def run_apidoc(app): 29 | """Generage API documentation""" 30 | import better_apidoc 31 | 32 | better_apidoc.APP = app 33 | better_apidoc.main( 34 | [ 35 | "better-apidoc", 36 | "-t", 37 | os.path.join(".", "_templates"), 38 | "--force", 39 | "--no-toc", 40 | "--separate", 41 | "-o", 42 | os.path.join(".", "api"), 43 | os.path.join("..", project), 44 | ] 45 | ) 46 | 47 | 48 | # -- General configuration --------------------------------------------------- 49 | 50 | # Add any Sphinx extension module names here, as strings. They can be 51 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 52 | # ones. 53 | extensions = [ 54 | "sphinx.ext.autodoc", 55 | "sphinx.ext.autosummary", 56 | "sphinx.ext.todo", 57 | "sphinx.ext.coverage", 58 | "sphinx.ext.viewcode", 59 | "sphinx.ext.githubpages", 60 | "sphinx.ext.napoleon", 61 | "sphinx_autodoc_typehints", 62 | "sphinx_rtd_theme", 63 | "myst_parser", 64 | ] 65 | 66 | # Add any paths that contain templates here, relative to this directory. 67 | templates_path = ["_templates"] 68 | 69 | # List of patterns, relative to source directory, that match files and 70 | # directories to ignore when looking for source files. 71 | # This pattern also affects html_static_path and html_extra_path. 72 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 73 | 74 | 75 | # -- Options for HTML output ------------------------------------------------- 76 | 77 | # The theme to use for HTML and HTML Help pages. See the documentation for 78 | # a list of builtin themes. 79 | # 80 | html_theme = "sphinx_rtd_theme" 81 | 82 | # Add any paths that contain custom static files (such as style sheets) here, 83 | # relative to this directory. They are copied after the builtin static files, 84 | # so a file named "default.css" will overwrite the builtin "default.css". 85 | # Commented out while no static files present, to prevent warning. 86 | # 87 | # html_static_path = ["_static"] 88 | 89 | 90 | # -- Extension configuration ------------------------------------------------- 91 | add_module_names = False 92 | 93 | 94 | napoleon_google_docstring = True 95 | napoleon_include_init_with_doc = True 96 | 97 | coverage_ignore_modules = [] 98 | coverage_ignore_functions = [] 99 | coverage_ignore_classes = [] 100 | 101 | coverage_show_missing_items = True 102 | 103 | # The name of the Pygments (syntax highlighting) style to use. 104 | pygments_style = "sphinx" 105 | 106 | # -- Options for todo extension ---------------------------------------------- 107 | 108 | # If true, `todo` and `todoList` produce output, else they produce nothing. 
109 | todo_include_todos = True 110 | 111 | 112 | def setup(app): 113 | app.connect("builder-inited", run_apidoc) 114 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Multimodal Models Toolkit 2 | 3 | ```{include} ../README.md 4 | ``` 5 | ## Check the API 6 | ```{toctree} 7 | --- 8 | maxdepth: 3 9 | --- 10 | API 11 | ``` 12 | 13 | ## User guide 14 | 15 | ```{toctree} 16 | --- 17 | maxdepth: 2 18 | --- 19 | User guide 20 | ``` 21 | 22 | ## Development Guide 23 | 24 | ```{toctree} 25 | --- 26 | maxdepth: 2 27 | --- 28 | Development guide 29 | ``` 30 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/devguide.md: -------------------------------------------------------------------------------- 1 | ```{include} ../../dev_guide.md 2 | ``` 3 | -------------------------------------------------------------------------------- /mmmt/__init__.py: -------------------------------------------------------------------------------- 1 | """Module initialization.""" 2 | import mmmt.data 3 | import mmmt.models 4 | import mmmt.pipeline 5 | 6 | __all__ = ["data", "models", "pipeline"] 7 | -------------------------------------------------------------------------------- /mmmt/data/__init__.py: -------------------------------------------------------------------------------- 1 | import mmmt.data.graph 2 | import mmmt.data.representation 3 | import mmmt.data.operators 4 | -------------------------------------------------------------------------------- /mmmt/data/graph/__init__.py: -------------------------------------------------------------------------------- 1 | import mmmt.data.graph.data_to_graph 2 | from mmmt.data.graph.graph_to_graph import GraphTransform 3 | from mmmt.data.graph.dgl_data_loader import DGLFileLoader 4 | from mmmt.data.graph.mat_file_loader import MatFileLoader 5 | -------------------------------------------------------------------------------- /mmmt/data/graph/concept_to_graph.py: -------------------------------------------------------------------------------- 1 | import os 2 | from fuse.data.datasets.caching.samples_cacher import SamplesCacher 3 | from fuse.data.datasets.dataset_default import DatasetDefault 4 | from fuse.utils.ndict import NDict 5 | 6 | from mmmt.data.operators.op_build_graph import OpBuildBaseGraph, OpBuildDerivedGraph 7 | 8 | 
import mlflow 9 | 10 | 11 | class ConceptToGraph: 12 | """ 13 | Transformation of samples in concatenated space to graphs using concepts 14 | """ 15 | 16 | def __init__( 17 | self, 18 | args_dict, 19 | ): 20 | """ 21 | Args: 22 | args_dict (dict): arguments of ConceptToGraph as defined in configuration yaml 23 | """ 24 | 25 | self.training_sample_ids = args_dict["pipeline"]["data_splits"]["train_ids"] 26 | self.val_sample_ids = args_dict["pipeline"]["data_splits"]["val_ids"] 27 | self.test_sample_ids = args_dict["pipeline"]["data_splits"]["test_ids"] 28 | 29 | self.concept_encoder_model = args_dict["pipeline"][ 30 | args_dict["step_args"]["io"]["concept_encoder_model_key"] 31 | ] 32 | 33 | self.thresh_q = args_dict["step_args"]["thresh_q"] 34 | self.graph_module = args_dict["step_args"]["module_identifier"] 35 | self.cache_graph_config = { 36 | "workers": args_dict["num_workers"], 37 | "restart_cache": args_dict["restart_cache"], 38 | "math_epsilon": 1e-5, 39 | } 40 | self.root_dir = args_dict["root_dir"] 41 | 42 | self.fused_dataset_key = args_dict["step_args"]["io"]["fused_dataset_key"] 43 | self.input_key = args_dict["step_args"]["io"]["input_key"] 44 | self.output_key = args_dict["step_args"]["io"]["output_key"] 45 | 46 | self.concept_to_graph_config = args_dict["step_args"] 47 | self.mmmt_pipeline = args_dict["pipeline"] 48 | self.pipeline = args_dict["pipeline"]["fuse_pipeline"] 49 | 50 | def __call__( 51 | self, 52 | ): 53 | """ 54 | Build graphs from samples by extending the FuseMedML pipeline with specific operators 55 | 56 | Returns: 57 | train, validation and test sets with samples structured as graphs 58 | """ 59 | with mlflow.start_run(run_name=f"{self.__class__.__qualname__}", nested=True): 60 | mlflow.log_params(NDict(self.concept_to_graph_config).flatten()) 61 | mlflow.log_param("parent", "yes") 62 | self.pipeline.extend( 63 | [ 64 | ( 65 | OpBuildBaseGraph(self.concept_encoder_model, self.thresh_q), 66 | dict( 67 | key_in_concat=self.input_key, 68 | key_in_concept="data.forward_pass.multimodal", 69 | key_out="data.base_graph", 70 | ), 71 | ), 72 | ( 73 | OpBuildDerivedGraph(self.graph_module), 74 | dict( 75 | key_in="data.base_graph", 76 | key_out=self.output_key, 77 | ), 78 | ), 79 | ] 80 | ) 81 | 82 | if self.cache_graph_config is None: 83 | self.cache_graph_config = {} 84 | if "cache_dirs" not in self.cache_graph_config: 85 | self.cache_graph_config["cache_dirs"] = [ 86 | os.path.join(self.root_dir, "cache_graph") 87 | ] 88 | 89 | cacher_graph = SamplesCacher( 90 | "cache_graph", 91 | self.pipeline, 92 | audit_first_sample=False, 93 | audit_rate=None, # disabling audit because deepdiff returns an error when comparing tensors, which are contained in derived_graphs from modules ['gcn', 'rgcn', 'mplex-prop'] 94 | **self.cache_graph_config, 95 | ) 96 | graph_train_dataset = DatasetDefault( 97 | sample_ids=self.training_sample_ids, 98 | static_pipeline=self.pipeline, 99 | cacher=cacher_graph, 100 | ) 101 | graph_train_dataset.create() 102 | 103 | graph_validation_dataset = DatasetDefault( 104 | sample_ids=self.val_sample_ids, 105 | static_pipeline=self.pipeline, 106 | cacher=cacher_graph, 107 | ) 108 | graph_validation_dataset.create() 109 | 110 | if self.test_sample_ids is not None: 111 | graph_test_dataset = DatasetDefault( 112 | sample_ids=self.test_sample_ids, 113 | static_pipeline=self.pipeline, 114 | cacher=cacher_graph, 115 | ) 116 | graph_test_dataset.create() 117 | else: 118 | graph_test_dataset = None 119 | 120 | self.mmmt_pipeline[self.fused_dataset_key] = { 121 | 
"graph_train_dataset": graph_train_dataset, 122 | "graph_validation_dataset": graph_validation_dataset, 123 | "graph_test_dataset": graph_test_dataset, 124 | } 125 | -------------------------------------------------------------------------------- /mmmt/data/graph/data_to_graph.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import numpy as np 5 | import scipy.sparse as spp 6 | import dgl 7 | 8 | 9 | def create_edge_list(adj): 10 | """ 11 | Create the list of edges from an adjacency matrix. 12 | 13 | Args: 14 | adj: adjacency matrix in CSC format 15 | 16 | Returns: 17 | list_edge_tuples: list of edges, each edge is a tuple 18 | """ 19 | rows, cols = adj.nonzero() 20 | list_edge_tuples = [] 21 | 22 | for ind, r in enumerate(rows): 23 | list_edge_tuples.append((r, cols[ind])) 24 | 25 | return list_edge_tuples 26 | 27 | 28 | def heterograph_creator(dataset): 29 | """ 30 | Create a heterograph according to a dataset coming from DGL data loader. 31 | 32 | Args: 33 | dataset: a dataset from DGL loader utility 34 | 35 | Returns: 36 | hg: heterograph object with node ("feat") 37 | """ 38 | 39 | g = {} 40 | edge_types = np.unique(dataset.edge_type) # find unique edge types 41 | 42 | for j in range(len(edge_types)): 43 | # extract edge relations of type i 44 | i = edge_types[j] 45 | edge_list = ( 46 | dataset.edge_src[dataset.edge_type == i], 47 | dataset.edge_dst[dataset.edge_type == i], 48 | ) 49 | 50 | g[("feat", str(j), "feat")] = edge_list 51 | 52 | hg = dgl.heterograph(g) 53 | 54 | return hg 55 | 56 | 57 | def multigraph_object_creator( 58 | common_encoder, common_embedding, concatenated_modality_inputs, thresh_q 59 | ): 60 | """ 61 | Creates a multigraph from common encoder and embeddings using a quantile-based threshold 62 | 63 | Args: 64 | common_encoder : common encoder 65 | common_embedding: latent representation of all samples from common encoder 66 | concatenated_modality_inputs: concatenated embeddings of all samples 67 | thresh_q: quantile for thresholding 68 | 69 | Returns: 70 | multi_graphs: multigraph dgl object 71 | """ 72 | 73 | num_rel_type = common_embedding.shape[ 74 | 0 75 | ] # size of encoding - number of relation types 76 | num_nodes = concatenated_modality_inputs.shape[0] # size of graph - number of nodes 77 | 78 | common_embedding_diff = torch.zeros( 79 | [num_nodes, num_rel_type] 80 | ) # initialise impact matrix 81 | 82 | for n in range(num_nodes): 83 | # compute feature impact 84 | concatenated_modality_inputs_copied = copy.deepcopy( 85 | concatenated_modality_inputs 86 | ) 87 | concatenated_modality_inputs_copied[n] = 0.0 88 | common_embedding_perturbed = common_encoder( 89 | torch.from_numpy(concatenated_modality_inputs_copied) 90 | ) 91 | common_embedding_diff[n] = abs( 92 | common_embedding_perturbed.detach().squeeze() - common_embedding 93 | ) 94 | 95 | thresh = torch.quantile(common_embedding_diff, thresh_q, dim=0).expand_as( 96 | common_embedding_diff 97 | ) 98 | 99 | impacted_features = (common_embedding_diff > thresh).float() 100 | 101 | g = {} 102 | 103 | n_edges = {} 104 | 105 | for rel_type in range(num_rel_type): 106 | 107 | # create planar adjacency matrix 108 | adj_matrix = impacted_features[:, rel_type].reshape((-1, 1)) 109 | 110 | # removes self-connections 111 | adj_matrix_no_self_connections = spp.coo_matrix( 112 | (adj_matrix.mm(adj_matrix.transpose(0, 1))).mul( 113 | torch.ones(adj_matrix.shape[0]) - torch.eye(adj_matrix.shape[0]) 114 | ) 115 | ) 116 | 117 | rows, cols = 
adj_matrix_no_self_connections.nonzero() 118 | 119 | n_edges[rel_type] = adj_matrix_no_self_connections.sum() 120 | list_edge_tuples = [] 121 | for ind, r in enumerate(rows): 122 | list_edge_tuples.append((r, cols[ind])) 123 | 124 | # create planar graph from adjacency matrix, then add to heterograph 125 | g[("feat", str(rel_type), "feat")] = list_edge_tuples 126 | 127 | # create multigraph object, second argument ensures number of nodes are consistent across planes 128 | multi_graph = dgl.heterograph(g, {"feat": num_nodes}) 129 | 130 | node_features = np.array( 131 | [[x] for x in concatenated_modality_inputs.tolist()] 132 | ).astype(np.float32) 133 | 134 | return multi_graph, node_features 135 | -------------------------------------------------------------------------------- /mmmt/data/graph/dgl_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from dgl.contrib.data import load_data 5 | 6 | from .general_file_loader import GeneralFileLoader 7 | from .data_to_graph import heterograph_creator 8 | 9 | import logging 10 | 11 | logging.basicConfig(level=logging.INFO) 12 | 13 | 14 | class DGLFileLoader(GeneralFileLoader): 15 | """ 16 | DGL file loader 17 | Load graph and labels, separate training, validation and testing sets. 18 | """ 19 | 20 | def __init__(self, dataset_name, split_ratios, seed): 21 | """ 22 | Args: 23 | dataset_name: name of the dataset to retrieve 24 | split_ratios: list of 3 floating number for training, validation, testing split ratios, 25 | the sum has to give 1 26 | seed: pseudo-random number generator seed 27 | """ 28 | 29 | super().__init__(dataset_name, split_ratios, seed) 30 | 31 | if self.dataset_name == "AIFB": 32 | 33 | # load data 34 | self.data = load_data(dataset="aifb") 35 | 36 | # prepare split ratios 37 | self.set_random_seed(seed) 38 | self.num_label = len(self.data.train_idx) + len(self.data.test_idx) 39 | 40 | self.n_train_samples = int(np.ceil(split_ratios[0] * self.num_label)) 41 | self.n_val_samples = int(np.ceil(split_ratios[1] * self.num_label)) 42 | self.n_test_samples = ( 43 | len(self.data.labels) - self.n_train_samples - self.n_val_samples 44 | ) 45 | 46 | else: 47 | logging.error( 48 | "The dataset " + self.dataset_name + " has not yet been implemented" 49 | ) 50 | raise ValueError("Unknown dataset!") 51 | 52 | logging.info(dataset_name + " has been loaded") 53 | 54 | def build_graph(self): 55 | """ 56 | Build and return graph, labels, data split, the number of classes of the dataset, 57 | and the number of relation types in the graph. 
58 | 59 | Returns: 60 | graph, labels, data split, the number of classes of the dataset, 61 | and the number of relation types in the graph 62 | """ 63 | 64 | # split set 65 | def_train_idx = self.data.train_idx 66 | def_test_idx = self.data.test_idx 67 | 68 | labeled_indices = np.concatenate((def_train_idx, def_test_idx), axis=0) 69 | 70 | ptation = np.random.permutation(self.num_label) 71 | train_idx = labeled_indices[ptation[: self.n_train_samples]] 72 | val_idx = labeled_indices[ 73 | ptation[self.n_train_samples : self.n_train_samples + self.n_val_samples] 74 | ] 75 | test_idx = labeled_indices[ptation[self.n_train_samples + self.n_val_samples :]] 76 | data_splits = [train_idx, val_idx, test_idx] 77 | 78 | # create graph 79 | g = heterograph_creator(self.data) 80 | 81 | # prepare label 82 | labels = self.data.labels 83 | labels = torch.from_numpy(labels).view(-1) 84 | 85 | # extract graph properties 86 | num_rels = len(g.etypes) 87 | n_classes = self.data.num_classes 88 | 89 | logging.info("Graph and metadata is ready") 90 | 91 | return g, labels, data_splits, n_classes, num_rels 92 | 93 | 94 | if __name__ == "__main__": 95 | 96 | DGL_FL = DGLFileLoader("AIFB", [0.7, 0.2, 0.1], 0) 97 | g, labels, data_splits, n_classes, num_rels = DGL_FL.build_graph() 98 | -------------------------------------------------------------------------------- /mmmt/data/graph/general_file_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | import logging 5 | import math 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | 9 | 10 | class GeneralFileLoader: 11 | """ 12 | General file loader 13 | Load graph and labels, separate training, validation and testing sets. 14 | """ 15 | 16 | def __init__(self, dataset_name, split_ratios, seed=0): 17 | """ 18 | Args: 19 | dataset_name: name of the dataset to retrieve 20 | split_ratios: list of 3 floating number for training, validation, testing split ratios, 21 | the sum has to give 1 22 | seed: pseudo-random number generator seed 23 | """ 24 | 25 | self.dataset_name = dataset_name 26 | self.data = None 27 | self.selected_classes = None 28 | self.num_label = None 29 | self.n_train_samples = None 30 | self.n_val_samples = None 31 | self.n_test_samples = None 32 | 33 | # check split ratio 34 | assert math.isclose(sum(split_ratios), 1), "split ratios do not sum 1" 35 | 36 | @staticmethod 37 | def set_random_seed(seed=0): 38 | """ 39 | Set the pseudo-random number generator seed, this function is only containing the set seed for numpy, 40 | which is used for data splits. For network initialization additional set seed functions should be used. 41 | 42 | Args: 43 | seed: pseudo-random number generator seed 44 | """ 45 | np.random.seed(seed) 46 | 47 | def build_graph(self): 48 | """ 49 | Build and return graph, labels, data split, the number of classes of the dataset, 50 | and the number of relation types in the graph. 
51 | 52 | Returns: 53 | graph, labels, data split, the number of classes of the dataset, 54 | and the number of relation types in the graph 55 | """ 56 | g, labels, data_splits, n_classes, num_rels = None, None, None, None, None 57 | 58 | return g, labels, data_splits, n_classes, num_rels 59 | -------------------------------------------------------------------------------- /mmmt/data/graph/graph_to_graph.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import scipy.sparse as spp 3 | import numpy as np 4 | 5 | import logging 6 | 7 | 8 | class GraphTransform: 9 | """ 10 | Class containing methods to transform between graph structures 11 | """ 12 | 13 | def __init__( 14 | self, graph: dgl.DGLGraph, node_features, graph_module="mplex", alpha=0.5 15 | ): 16 | """ 17 | Args: 18 | graph: a graph in dgl format 19 | node_features: features of the nodes 20 | graph_module: the method to use for transformation, available values: ['mplex', 'mplex-prop', 'multibehav', 'mgnn', 'rgcn', 'gcn'] 21 | alpha: parameter to control the intra- and inter-layer message passing, only applicable for 'mplex' and 'mplex-prop' 22 | """ 23 | self.graph = graph 24 | self.node_features = node_features 25 | self.graph_module = graph_module 26 | self.alpha = alpha 27 | 28 | def transform(self): 29 | """ 30 | Coordinate the graph transformation. 31 | 32 | Returns: 33 | transformed graph(s) and node features 34 | Raises a ValueError if the method is not recognized 35 | """ 36 | 37 | transformations = { 38 | "mplex": self.create_multiplex_graph_object, 39 | "mplex-prop": self.create_multiplex_prop_graph_object, 40 | "multibehav": self.create_multibehav_graph_object, 41 | "mgnn": self.create_mGNN_graph_object, 42 | "rgcn": self.pass_multigraph, 43 | "gcn": self.create_homogeneous_graph, 44 | } 45 | 46 | if self.graph_module in transformations: 47 | return transformations[self.graph_module]() 48 | else: 49 | logging.error( 50 | self.graph_module 51 | + " method not implemented. Choose within " 52 | + str(transformations.keys()) 53 | ) 54 | raise ValueError( 55 | "Unknown method! 
Choose within " + str(transformations.keys()) 56 | ) 57 | 58 | def compute_mplex_walk_mat(self): 59 | """ 60 | Computes the multiplex walk matrices 61 | 62 | Returns: 63 | AC: first intra-layer, then same layer 64 | CA: first same layer, then intra-layer 65 | """ 66 | e_types = self.graph.etypes 67 | number_e_types = len(e_types) 68 | number_nodes = self.graph.num_nodes() 69 | mplex_node_features = None 70 | 71 | adj_dimension = number_nodes * number_e_types 72 | intra_L = np.zeros((adj_dimension, adj_dimension)) 73 | 74 | for ind_et, e_type in enumerate(e_types): 75 | adj = self.graph[e_type].adjacency_matrix(transpose=True).to_dense().numpy() 76 | intra_L[ 77 | ind_et * number_nodes : ind_et * number_nodes + number_nodes, 78 | ind_et * number_nodes : ind_et * number_nodes + number_nodes, 79 | ] = adj 80 | if ind_et == 0: 81 | mplex_node_features = self.node_features 82 | else: 83 | mplex_node_features = np.concatenate( 84 | (mplex_node_features, self.node_features), axis=0 85 | ) 86 | 87 | mat_A = spp.coo_matrix( 88 | self.alpha * np.ones((number_e_types, number_e_types)) 89 | ) - spp.coo_matrix(self.alpha * np.eye(number_e_types)) 90 | mat_B = spp.coo_matrix(np.eye(number_nodes)) 91 | 92 | mat_C = spp.coo_matrix((1 - self.alpha) * np.eye(number_e_types)) 93 | 94 | C = spp.kron(mat_A, mat_B) + spp.kron(mat_C, mat_B) 95 | 96 | intra_L = spp.coo_matrix(intra_L) 97 | 98 | AC = intra_L.dot(C) 99 | CA = C.dot(intra_L) 100 | 101 | return AC, CA, mplex_node_features 102 | 103 | def create_multiplex_graph_object( 104 | self, 105 | ): 106 | """ 107 | Creates multiplex graph from heterograph objects 108 | 109 | Returns: 110 | message passing object type 111 | g1: type I supra-adjacency 112 | g2: type II supra-adjacency 113 | """ 114 | 115 | AC, CA, mplex_node_features = self.compute_mplex_walk_mat() 116 | 117 | g1 = dgl.from_scipy(spp.coo_matrix(AC)) 118 | g2 = dgl.from_scipy(spp.coo_matrix(CA)) 119 | 120 | # # type II graph can be inferred from type I graph TODO: verify 121 | # g2 = dgl.from_scipy(g1.adj(scipy_fmt="coo").T) 122 | 123 | return [g1, g2], mplex_node_features 124 | 125 | def create_multiplex_prop_graph_object(self): 126 | """ 127 | Creates multiplex propagation graphs from heterograph and node features 128 | 129 | Returns: 130 | g1: Type I graph objects 131 | g2: Type II graph objects 132 | """ 133 | 134 | AC, CA, mplex_node_features = self.compute_mplex_walk_mat() 135 | 136 | g1 = dgl.from_scipy(spp.coo_matrix(AC), eweight_name="w") 137 | g2 = dgl.from_scipy(spp.coo_matrix(CA), eweight_name="w") 138 | 139 | return [g1, g2], mplex_node_features 140 | 141 | def create_multibehav_graph_object(self): 142 | """ 143 | Creates quotient graph and multilayered graph according to https://dl.acm.org/doi/pdf/10.1145/3340531.3412119 144 | 145 | Returns: 146 | quotient_graph: quotient graph 147 | mplex_graph: multilayered graph 148 | """ 149 | e_types = self.graph.etypes 150 | number_e_types = len(e_types) 151 | number_nodes = self.graph.num_nodes() 152 | 153 | adj_dimension = number_nodes * number_e_types 154 | quotient_adj = np.zeros((number_nodes, number_nodes)) 155 | intra_L = np.zeros((adj_dimension, adj_dimension)) 156 | 157 | for ind_et, e_type in enumerate(e_types): 158 | adj = self.graph[e_type].adjacency_matrix(transpose=True).to_dense().numpy() 159 | intra_L[ 160 | ind_et * number_nodes : ind_et * number_nodes + number_nodes, 161 | ind_et * number_nodes : ind_et * number_nodes + number_nodes, 162 | ] = adj 163 | quotient_adj = quotient_adj + adj / number_e_types 164 | 165 | 
quotient_adj = spp.coo_matrix(quotient_adj) 166 | C = spp.coo_matrix( 167 | np.kron(np.ones((number_e_types, number_e_types)), np.eye(number_nodes)) 168 | - np.eye(number_nodes * number_e_types) 169 | ) 170 | 171 | intra_L = spp.coo_matrix(intra_L) 172 | mplex_adj = intra_L + C 173 | 174 | quotient_graph = dgl.from_scipy(quotient_adj) 175 | mplex_graph = dgl.from_scipy(mplex_adj) 176 | 177 | return [quotient_graph, mplex_graph], self.node_features 178 | 179 | def create_mGNN_graph_object(self): 180 | """ 181 | Creates intra and inter multigraph objects according to https://arxiv.org/pdf/2109.10119.pdf 182 | 183 | Returns: 184 | message passing object types 185 | g_inter_layer: inter graph 186 | g_intra_layer: intra graph 187 | """ 188 | e_types = self.graph.etypes 189 | number_e_types = len(e_types) 190 | number_nodes = self.graph.num_nodes() 191 | mplex_node_features = None 192 | 193 | adj_dimension = number_nodes * number_e_types 194 | intra_L = np.zeros((adj_dimension, adj_dimension)) 195 | 196 | for ind_et, e_type in enumerate(e_types): 197 | adj = self.graph[e_type].adjacency_matrix(transpose=True).to_dense().numpy() 198 | intra_L[ 199 | ind_et * number_nodes : ind_et * number_nodes + number_nodes, 200 | ind_et * number_nodes : ind_et * number_nodes + number_nodes, 201 | ] = adj 202 | if ind_et == 0: 203 | mplex_node_features = self.node_features 204 | else: 205 | mplex_node_features = np.concatenate( 206 | (mplex_node_features, self.node_features), axis=0 207 | ) 208 | 209 | C = spp.coo_matrix( 210 | np.kron(np.ones((number_e_types, number_e_types)), np.eye(number_nodes)) 211 | - np.eye(number_nodes * number_e_types) 212 | ) 213 | 214 | intra_L = spp.coo_matrix(intra_L) 215 | 216 | g_inter_layer = dgl.from_scipy(C) 217 | g_intra_layer = dgl.from_scipy(intra_L) 218 | 219 | return [g_inter_layer, g_intra_layer], mplex_node_features 220 | 221 | def pass_multigraph(self): 222 | return [self.graph], self.node_features 223 | 224 | def create_homogeneous_graph(self): 225 | return [dgl.to_homogeneous(self.graph)], self.node_features 226 | -------------------------------------------------------------------------------- /mmmt/data/graph/mat_file_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import dgl 3 | import urllib 4 | 5 | import numpy as np 6 | import torch 7 | import scipy 8 | 9 | from .general_file_loader import GeneralFileLoader 10 | from .data_to_graph import create_edge_list 11 | 12 | import logging 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | 16 | 17 | class MatFileLoader(GeneralFileLoader): 18 | """ 19 | MAT file loader 20 | Load graph and labels, separate training, validation and testing sets. 
21 | """ 22 | 23 | def __init__(self, dataset_name, split_ratios, seed): 24 | """ 25 | Args: 26 | dataset_name: name of the dataset to retrieve 27 | split_ratios: list of 3 floating number for training, validation, testing split ratios, 28 | the sum has to give 1 29 | seed: pseudo-random number generator seed 30 | """ 31 | 32 | super().__init__(dataset_name, split_ratios, seed) 33 | 34 | if self.dataset_name == "ACM": 35 | 36 | # load data 37 | data_url = "https://data.dgl.ai/dataset/ACM.mat" 38 | tmp_path = "./tmp" 39 | os.makedirs(tmp_path, exist_ok=True) 40 | 41 | data_file_path = os.path.join(tmp_path, "ACM.mat") 42 | 43 | urllib.request.urlretrieve(data_url, data_file_path) 44 | self.data = scipy.io.loadmat(data_file_path) 45 | 46 | # prepare split ratios 47 | # for ACM dataset the class 6 has no samples, therefore we are not including it 48 | self.selected_classes = [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13] 49 | 50 | self.set_random_seed(seed) 51 | self.num_label = self.data["PvsC"].shape[0] 52 | 53 | self.n_train_samples = int(np.ceil(self.num_label * split_ratios[0])) # 800 54 | self.n_val_samples = int(np.ceil(self.num_label * split_ratios[1])) # 200 55 | self.n_test_samples = ( 56 | self.num_label - self.n_train_samples - self.n_val_samples 57 | ) 58 | 59 | else: 60 | logging.error( 61 | "The dataset " + self.dataset_name + " has not yet been implemented" 62 | ) 63 | raise ValueError("Unknown dataset!") 64 | 65 | logging.info(dataset_name + " has been loaded") 66 | 67 | def build_graph(self): 68 | """ 69 | Build and return graph, labels, data split, the number of classes of the dataset, 70 | and the number of relation types in the graph. 71 | 72 | Returns: 73 | graph, labels, data split, the number of classes of the dataset, 74 | and the number of relation types in the graph 75 | """ 76 | 77 | # processing of MAT file for ACM is very specific, therefore currently isolated in the relative if branch 78 | if self.dataset_name == "ACM": 79 | # split set 80 | pvc = self.data["PvsC"].tocsr() 81 | p_selected = pvc[:, self.selected_classes].tocoo() 82 | pid = p_selected.row 83 | 84 | shuffle = np.random.permutation(pid) 85 | train_idx = torch.tensor(shuffle[0 : self.n_train_samples]).long() 86 | val_idx = torch.tensor( 87 | shuffle[ 88 | self.n_train_samples : self.n_train_samples + self.n_val_samples 89 | ] 90 | ).long() 91 | test_idx = torch.tensor( 92 | shuffle[self.n_train_samples + self.n_val_samples :] 93 | ).long() 94 | data_splits = [train_idx, val_idx, test_idx] 95 | 96 | # create graph 97 | ppA = self.data["PvsA"].dot(self.data["PvsA"].transpose()) > 1 98 | ppL = self.data["PvsL"].dot(self.data["PvsL"].transpose()) >= 1 99 | ppP = self.data["PvsP"] 100 | 101 | g = dgl.heterograph( 102 | { 103 | ("feat", "0", "feat"): create_edge_list(ppA), 104 | ("feat", "1", "feat"): create_edge_list(ppP), 105 | ("feat", "2", "feat"): create_edge_list(ppL), 106 | } 107 | ) 108 | 109 | # prepare labels 110 | labels = pvc.indices 111 | for ind, lbl in enumerate(self.selected_classes): 112 | labels[labels == lbl] = ind 113 | 114 | labels = torch.tensor(labels).long() 115 | 116 | # extract graph properties 117 | num_rels = len(g.etypes) 118 | n_classes = len(self.selected_classes) 119 | 120 | logging.info("Graph and metadata is ready") 121 | 122 | return g, labels, data_splits, n_classes, num_rels 123 | 124 | 125 | if __name__ == "__main__": 126 | M_FL = MatFileLoader("ACM", [0.064, 0.016, 0.92], 0) 127 | g, labels, data_splits, n_classes, num_rels = M_FL.build_graph() 128 | 
-------------------------------------------------------------------------------- /mmmt/data/graph/visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import dgl 4 | import graphviz 5 | import matplotlib.pyplot as plt 6 | import mlflow 7 | import networkx as nx 8 | import numpy as np 9 | import nxviz as nv 10 | from fuse.utils.ndict import NDict 11 | from nxviz import annotate 12 | from itertools import chain 13 | 14 | 15 | class GraphVisualization: 16 | """Visualization of the graph(s) induced by a dataset""" 17 | 18 | def __init__( 19 | self, 20 | args_dict, 21 | ): 22 | """ 23 | Args: 24 | args_dict (dict): arguments of DatasetGraphVisualization as defined in configuration yaml 25 | """ 26 | 27 | self.selected_samples = args_dict["step_args"]["selected_samples"] 28 | self.fused_dataset_key = args_dict["step_args"]["io"]["fused_dataset_key"] 29 | self.file_prefix = args_dict["step_args"]["io"]["file_prefix"] 30 | self.feature_group_sizes = args_dict["step_args"]["feature_group_sizes"] 31 | 32 | self.visualization_config = args_dict["step_args"] 33 | self.mmmt_pipeline = args_dict["pipeline"] 34 | self.pipeline = args_dict["pipeline"]["fuse_pipeline"] 35 | self.root_dir = args_dict["root_dir"] 36 | self.mmmt_pipeline = args_dict["pipeline"] 37 | 38 | def __call__( 39 | self, 40 | ): 41 | """ 42 | Build graphs from samples by extending the FuseMedML pipeline with specific operators 43 | 44 | Returns: 45 | train, validation and test sets with samples structured as graphs 46 | """ 47 | with mlflow.start_run(run_name=f"{self.__class__.__qualname__}", nested=True): 48 | mlflow.log_params(NDict(self.visualization_config).flatten()) 49 | 50 | node_categories = list( 51 | chain.from_iterable( 52 | [ 53 | [category] * size 54 | for category, size in self.feature_group_sizes.items() 55 | ] 56 | ) 57 | ) 58 | 59 | for split_name, split_samples in self.selected_samples.items(): 60 | selected_dataset = self.mmmt_pipeline[self.fused_dataset_key][ 61 | split_name 62 | ] 63 | if split_samples == "all": 64 | filename = os.path.join( 65 | self.root_dir, f"{self.file_prefix}_{split_name}_all.png" 66 | ) 67 | self.visualize_dataset( 68 | selected_dataset, 69 | samples=None, 70 | save_file=filename, 71 | node_categories=node_categories, 72 | ) 73 | mlflow.log_artifact(filename) 74 | else: 75 | samples_name = "_".join(map(str, split_samples)) 76 | filename = os.path.join( 77 | self.root_dir, 78 | f"{self.file_prefix}_{split_name}_{samples_name}.png", 79 | ) 80 | self.visualize_dataset( 81 | selected_dataset, 82 | samples=split_samples, 83 | save_file=filename, 84 | node_categories=node_categories, 85 | ) 86 | mlflow.log_artifact(filename) 87 | 88 | @staticmethod 89 | def visualize_dataset( 90 | dataset, 91 | graph_key="data.base_graph.graph", 92 | node_names_key="names.data.input.multimodal", 93 | node_categories=None, 94 | samples=None, 95 | max_edge_thickness=20, 96 | save_file=None, 97 | ): 98 | """Visualization of a Graph Based Multimodal Representation 99 | 100 | :param dataset: FuseMedML dataset to be visualized 101 | :type dataset: FuseMedML dataset 102 | :param graph_key: dictionary key that points to the graph, defaults to "data.base_graph.graph" 103 | :type graph_key: str, optional 104 | :param node_names_key: key that contains names of the nodes, defaults to "names.data.input.multimodal" 105 | :type node_names_key: str, optional 106 | :param node_categories: list of categories for the nodes, defaults to None 107 | :type 
node_categories: [str], optional 108 | :param samples: visulize only some samples, defaults to None 109 | :type samples: [int], optional 110 | :param max_edge_thickness: maximum edge thickness, defaults to 20 111 | :type max_edge_thickness: int, optional 112 | :return: NetworkX graph used for visualization 113 | :rtype: nx.Graph 114 | """ 115 | 116 | if node_categories is None: 117 | node_categories = dataset[0][node_names_key] 118 | node_names = dataset[0].get(node_names_key, node_categories) 119 | 120 | adj = np.zeros_like( 121 | dgl.to_homogeneous(dataset[0][graph_key]) 122 | .adjacency_matrix() 123 | .to_dense() 124 | .numpy() 125 | ) 126 | for sample in dataset: 127 | if not samples or sample["data"]["sample_id"] in samples: 128 | adj = ( 129 | adj 130 | + dgl.to_homogeneous(sample[graph_key]) 131 | .adjacency_matrix() 132 | .to_dense() 133 | .numpy() 134 | ) 135 | 136 | adj = max_edge_thickness * adj / adj.max() 137 | 138 | attrs = { 139 | num: { 140 | "value": num, 141 | "name": node_name, 142 | "group": node_categories[num].split(".")[-1], 143 | } 144 | for (num, node_name) in enumerate(node_names) 145 | } 146 | 147 | G = nx.from_numpy_matrix(adj) 148 | nx.set_node_attributes(G, attrs) 149 | 150 | plt.figure(figsize=(12, 12)) 151 | pos = nv.nodes.circos(G, group_by="group", color_by="group") 152 | nv.edges.circos(G, pos, lw_by="weight") 153 | annotate.circos_group(G, group_by="group") 154 | nv.plots.despine() 155 | nv.plots.aspect_equal() 156 | 157 | if save_file is not None: 158 | plt.savefig(save_file, format="PNG") 159 | 160 | return G 161 | 162 | def visualize_encoding_strategy(encoding_strategy): 163 | 164 | concept_encoder_name = "Concept Encoder" 165 | 166 | encoding_graph = nx.DiGraph() 167 | encoding_graph.add_node(concept_encoder_name) 168 | for modality_key, modality_encoder in encoding_strategy.get( 169 | "modality_encoders" 170 | ).items(): 171 | modality_name = modality_key.split()[-1] 172 | encoding_graph.add_node(modality_name) 173 | encoding_graph.add_edge(modality_name, concept_encoder_name) 174 | 175 | encoding_graph.add_edge(concept_encoder_name, "Graph") 176 | plt.figure(figsize=(12, 12)) 177 | nx.draw_networkx(encoding_graph) 178 | 179 | return encoding_graph 180 | 181 | def visualize_fuse_pipeline(pipeline, name=None, source_keys=None): 182 | """Visualize FuseMedML pipeline 183 | 184 | :param pipeline: FuseMedML pipeline 185 | :type pipeline: FuseMedML pipeline 186 | :param name: name of the pipeline, defaults to None 187 | :type name: str, optional 188 | :param source_keys: names of the keys that are used as source, defaults to None 189 | :type source_keys: [str], optional 190 | :return: GraphViz version of the pipeline 191 | :rtype: graphviz graph 192 | """ 193 | 194 | if name is None: 195 | name = pipeline.get_name() 196 | graph = graphviz.Digraph(name=name, strict=True, format="png") 197 | 198 | for step_id, pipeline_step in enumerate(pipeline._ops_and_kwargs): 199 | op_name = pipeline_step[0] 200 | op_params = pipeline_step[1] 201 | op_short_name = op_name.__str__().split()[0].split(".")[-1] 202 | 203 | if step_id == 0 and source_keys is not None: 204 | for skey in source_keys: 205 | graph.edge(op_short_name, skey) 206 | 207 | op_inputs = [] 208 | op_outputs = [] 209 | 210 | for op_key, data_key in op_params.items(): 211 | if op_key.startswith("key"): 212 | 213 | if op_key.startswith("key_in") or op_key.startswith("keys_in"): 214 | if isinstance(data_key, list): 215 | op_inputs.extend(data_key) 216 | else: 217 | op_inputs.append(data_key) 218 | if 
op_key.startswith("key_out") or op_key.startswith("keys_out"): 219 | if isinstance(data_key, list): 220 | op_outputs.extend(data_key) 221 | else: 222 | op_outputs.append(data_key) 223 | 224 | if isinstance(data_key, list): # multiple keys 225 | for item in data_key: 226 | graph.node(item, shape="plaintext") 227 | 228 | else: # single key 229 | graph.node(data_key, shape="plaintext") 230 | 231 | if op_inputs and op_outputs: 232 | for input in op_inputs: 233 | for output in op_outputs: 234 | 235 | op_node_id = f"{op_short_name}_{step_id}" 236 | 237 | graph.node(op_node_id, label=op_short_name, shape="box") 238 | graph.edge(input, op_node_id) 239 | graph.edge(op_node_id, output) 240 | 241 | return graph 242 | -------------------------------------------------------------------------------- /mmmt/data/operators/__init__.py: -------------------------------------------------------------------------------- 1 | from mmmt.data.operators.op_build_graph import OpBuildBaseGraph, OpBuildDerivedGraph 2 | from mmmt.data.operators.op_concat_names import OpConcatNames 3 | from mmmt.data.operators.op_forwardpass import OpForwardPass 4 | from mmmt.data.operators.op_resample import Op3DResample 5 | -------------------------------------------------------------------------------- /mmmt/data/operators/op_build_graph.py: -------------------------------------------------------------------------------- 1 | from fuse.data.ops.op_base import OpBase 2 | from typing import List, Union 3 | from fuse.utils.ndict import NDict 4 | import logging 5 | 6 | import mmmt 7 | 8 | 9 | class OpBuildBaseGraph(OpBase): 10 | """ 11 | Operator for building the base graph using mmmt library 12 | """ 13 | 14 | def __init__(self, model, thresh_q): 15 | """Constructor method 16 | 17 | :param model: concept encoding model 18 | :type model: torch.nn.Module 19 | :param thresh_q: saliency threshold 20 | :type thresh_q: float 21 | """ 22 | super().__init__() 23 | self.model = model 24 | self.thresh_q = thresh_q 25 | self.model.eval() 26 | logging.debug(self.model) 27 | 28 | def __call__( 29 | self, 30 | sample_dict: NDict, 31 | key_in_concat="data.input.multimodal", 32 | key_in_concept="data.forward_pass.multimodal", 33 | key_out="data.graph", 34 | **kwargs, 35 | ) -> Union[None, dict, List[dict]]: 36 | """produces a graph for a given sample 37 | 38 | :param sample_dict: sample dictionary 39 | :type sample_dict: NDict 40 | :param key_in_concat: input key (node features) in sample dict, defaults to "data.input.multimodal" 41 | :type key_in_concat: str, optional 42 | :param key_in_concept: concept key in sample dict, defaults to "data.forward_pass.multimodal" 43 | :type key_in_concept: str, optional 44 | :param key_out: graph key to generate, defaults to "data.graph" 45 | :type key_out: str, optional 46 | :return: updated sample dict 47 | :rtype: Union[None, dict, List[dict]] 48 | """ 49 | 50 | ( 51 | multi_graph, 52 | node_features, 53 | ) = mmmt.data.graph.data_to_graph.multigraph_object_creator( 54 | self.model, 55 | sample_dict[key_in_concept], 56 | sample_dict[key_in_concat], 57 | self.thresh_q, 58 | ) 59 | 60 | sample_dict[key_out] = {"graph": multi_graph, "node_features": node_features} 61 | 62 | return sample_dict 63 | 64 | 65 | class OpBuildDerivedGraph(OpBase): 66 | """ 67 | Operator for building the derived graph (e.g. 
mplex) using mmmt library 68 | """ 69 | 70 | def __init__(self, graph_module): 71 | """Constructor method 72 | 73 | :param graph_module: type of graph module to be used downstream, available values: ['mplex', 'mplex-prop', 'multibehav', 'mGNN'] 74 | :type graph_module: str 75 | """ 76 | super().__init__() 77 | self.graph_module = graph_module 78 | 79 | def __call__( 80 | self, 81 | sample_dict: NDict, 82 | key_in="data.base_graph", 83 | key_out="data.derived_graph", 84 | **kwargs, 85 | ) -> Union[None, dict, List[dict]]: 86 | """transforms a graph for a sample, to adapt it to the donwstream graph module 87 | 88 | :param sample_dict: sample dictionary 89 | :type sample_dict: NDict 90 | :param key_in: dict key with base graph, defaults to "data.base_graph" 91 | :type key_in: str, optional 92 | :param key_out: dict key where the transformed graph will be stored, defaults to "data.derived_graph" 93 | :type key_out: str, optional 94 | :return: update sample dict 95 | :rtype: Union[None, dict, List[dict]] 96 | """ 97 | 98 | graph = sample_dict[key_in]["graph"] 99 | node_features = sample_dict[key_in]["node_features"] 100 | 101 | GT = mmmt.data.graph.graph_to_graph.GraphTransform( 102 | graph, node_features, self.graph_module 103 | ) 104 | derived_graph, derived_node_features = GT.transform() 105 | 106 | sample_dict[key_out] = { 107 | "graph": derived_graph, 108 | "node_features": derived_node_features, 109 | } 110 | 111 | return sample_dict 112 | -------------------------------------------------------------------------------- /mmmt/data/operators/op_concat_names.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Sequence 2 | 3 | from fuse.data.ops.op_base import OpBase 4 | from fuse.utils.ndict import NDict 5 | 6 | 7 | class OpConcatNames(OpBase): 8 | """ 9 | Concatenate feature names. When two keys containing feature vectors are concatenated, the names 10 | of the features also need to be concatenated. Like this: 11 | [clinical1... clinicalN] + [imagingM... imagingM] produce 12 | [clinical1... clinicalN, imaging1... imagingM] 13 | """ 14 | 15 | def __call__( 16 | self, sample_dict: NDict, keys_in: Sequence[str], key_out: str 17 | ) -> Union[None, dict, List[dict]]: 18 | """performs the concatenation of N lists of names 19 | 20 | :param sample_dict: sample dict 21 | :type sample_dict: NDict 22 | :param keys_in: dictionary keys that contain names that will be concatenated 23 | :type keys_in: Sequence[str] 24 | :param key_out: new dictionary key that will contain the concatenated names 25 | :type key_out: str 26 | :return: updated sample dict 27 | :rtype: Union[None, dict, List[dict]] 28 | """ 29 | 30 | feature_names = [] 31 | for concat_key in keys_in: 32 | names_key_in = f"names.{concat_key}" 33 | feature_names = feature_names + sample_dict[names_key_in] 34 | 35 | sample_dict[key_out] = feature_names 36 | 37 | return sample_dict 38 | -------------------------------------------------------------------------------- /mmmt/data/operators/op_forwardpass.py: -------------------------------------------------------------------------------- 1 | from fuse.data.ops.op_base import OpBase 2 | from typing import List, Union 3 | from fuse.utils.ndict import NDict 4 | import torch 5 | import logging 6 | import numpy as np 7 | 8 | 9 | class OpForwardPass(OpBase): 10 | """ 11 | Operator for applying a forward pass of a pretrained model. 
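    Example (illustrative sketch only: ``image_encoder``, ``dataset_pipeline`` and the dictionary keys are placeholders, not names defined by the toolkit; the op is typically registered in a FuseMedML pipeline):

        # image_encoder: any pretrained torch.nn.Module (placeholder)
        dataset_pipeline.extend([
            (
                OpForwardPass(image_encoder, modality_dimensions=3),
                dict(key_in="data.input.img", key_out="data.encoded.img"),
            ),
        ])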
12 | """ 13 | 14 | def __init__(self, model, modality_dimensions: int, add_feature_names: bool = True): 15 | """Constructor method 16 | 17 | :param model: pretrained model to be applied 18 | :type model: torch.nn.Module 19 | :param modality_dimensions: dimensions required by the model 20 | :type modality_dimensions: int 21 | """ 22 | super().__init__() 23 | self.model = model 24 | self.modality_dimensions = modality_dimensions 25 | self.add_feature_names = add_feature_names 26 | self.model.eval() 27 | logging.debug(self.model) 28 | 29 | def __call__( 30 | self, 31 | sample_dict: NDict, 32 | key_in=None, 33 | key_out=None, 34 | **kwargs, 35 | ) -> Union[None, dict, List[dict]]: 36 | """performs a forward pass on key_in and stores the output on key_out 37 | 38 | :param sample_dict: sample dictionary 39 | :type sample_dict: NDict 40 | :param key_in: input dictionary key, defaults to None 41 | :type key_in: str, optional 42 | :param key_out: dictionary key to store the output, defaults to None 43 | :type key_out: str, optional 44 | :raises ValueError: if the number of dimensions is not supported 45 | :return: updated sample dict 46 | :rtype: Union[None, dict, List[dict]] 47 | """ 48 | 49 | if isinstance(sample_dict[key_in], np.ndarray): 50 | input_tensor = torch.from_numpy(sample_dict[key_in]) 51 | else: 52 | input_tensor = sample_dict[key_in] 53 | 54 | while len(input_tensor.shape) < self.modality_dimensions + 1: 55 | input_tensor = input_tensor.unsqueeze(0) 56 | 57 | logging.debug( 58 | f"input tensor shape: {input_tensor.shape} and type: {input_tensor.dtype}" 59 | ) 60 | 61 | # get sample id and input key 62 | input_tensor = input_tensor.float() 63 | logging.debug(type(input_tensor)) 64 | output_tensor = self.model(input_tensor) 65 | sample_dict[key_out] = output_tensor.detach().squeeze() 66 | 67 | if self.add_feature_names: 68 | names_key_in = f"names.{key_in}" 69 | names_key_out = f"names.{key_out}" 70 | 71 | names_in = sample_dict.get( 72 | names_key_in, list(range(torch.numel(input_tensor))) 73 | ) 74 | sample_dict[names_key_in] = names_in 75 | 76 | if torch.equal(output_tensor, input_tensor): 77 | sample_dict[names_key_out] = sample_dict[names_key_in] 78 | else: 79 | feature_names = [] 80 | flat_key_out = key_out.replace(".", "_") 81 | for feature_n in range(torch.numel(output_tensor)): 82 | feature_names.append(f"{flat_key_out}.feat_{feature_n}") 83 | sample_dict[names_key_out] = feature_names 84 | 85 | return sample_dict 86 | -------------------------------------------------------------------------------- /mmmt/data/operators/op_resample.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import torch.nn.functional as F 4 | from fuse.data.ops.op_base import OpBase 5 | from fuse.utils.ndict import NDict 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class Op3DResample(OpBase): 11 | """ 12 | Resampler of 3D tensors to desired size 13 | """ 14 | 15 | def __init__(self, desired_size=[16, 16, 16], mode="nearest"): 16 | """constructor method 17 | 18 | :param desired_size: desired size of the tensor, defaults to [16, 16, 16] 19 | :type desired_size: list, optional 20 | :param mode: interpolation method, defaults to "nearest" 21 | :type mode: str, optional 22 | """ 23 | super().__init__() 24 | 25 | self.desired_size = desired_size 26 | self.mode = mode 27 | 28 | def __call__( 29 | self, 30 | sample_dict: NDict, 31 | key_in="data.input.img", 32 | key_out="data.input.img", 33 | **kwargs, 34 | ) -> Union[None, 
dict, List[dict]]: 35 | """performs the resampling of a sample 36 | 37 | :param sample_dict: sample dictionary 38 | :type sample_dict: NDict 39 | :param key_in: input dictionary key, defaults to "data.input.img" 40 | :type key_in: str, optional 41 | :param key_out: output dictionary key, defaults to "data.input.img" 42 | :type key_out: str, optional 43 | :return: updated sample dict 44 | :rtype: Union[None, dict, List[dict]] 45 | """ 46 | 47 | dimensions = len(self.desired_size) 48 | if dimensions == 1: 49 | print("Only one dimension") 50 | new_size = [self.desired_size for dim in sample_dict[key_in].shape] 51 | self.desired_size = new_size 52 | 53 | input = sample_dict[key_in] 54 | if type(input) == np.ndarray: 55 | input = torch.from_numpy(input) 56 | 57 | sample_dict[key_out] = F.interpolate( 58 | sample_dict[key_in].unsqueeze(0).unsqueeze(0), 59 | self.desired_size, 60 | mode=self.mode, 61 | ).squeeze() 62 | 63 | return sample_dict 64 | -------------------------------------------------------------------------------- /mmmt/data/representation/__init__.py: -------------------------------------------------------------------------------- 1 | from mmmt.data.representation.auto_encoder_trainer import AutoEncoderTrainer 2 | from mmmt.data.representation.auto_encoder_trainer import AutoEncoderBuilder 3 | from mmmt.data.representation.model_builder_trainer import ModelBuilderTrainer 4 | from mmmt.data.representation.encoded_unimodal_to_concept import ( 5 | EncodedUnimodalToConcept, 6 | ) 7 | from mmmt.data.representation.modality_encoding import ModalityEncoding 8 | -------------------------------------------------------------------------------- /mmmt/data/representation/fusion.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from collections import Iterable 3 | 4 | import torch 5 | from fuse.data.datasets.dataset_base import DatasetDefault 6 | from fuse.data.ops.ops_common import OpConcat 7 | from fuse.data.pipelines.pipeline_default import PipelineDefault 8 | 9 | from mmmt.data.operators.op_forwardpass import OpForwardPass 10 | 11 | 12 | class FusionBaseClass: 13 | """ 14 | Fusion base class. 15 | 16 | """ 17 | 18 | def __init__(self, input_modality_keys: Iterable, output_key: str, **kwargs): 19 | super().__init__() 20 | self.input_modality_keys = input_modality_keys 21 | self.output_key = output_key 22 | 23 | def train( 24 | self, 25 | dataset: DatasetDefault = None, 26 | **kwargs, 27 | ) -> None: 28 | """trains/fits the fusion method on an user specified dataset. 
29 | 30 | :param dataset: fuse dataset on which to train the fusion method, defaults to None 31 | :type dataset: DatasetDefault, optional 32 | """ 33 | pass 34 | 35 | @abstractmethod 36 | def toNonDifferentiableFusePipeline( 37 | self, dataset_pipeline: PipelineDefault, **kwargs 38 | ) -> PipelineDefault: 39 | """provides the trained method as a fuse operator by extending a fuse pipeline object 40 | 41 | :param dataset_pipeline: data pipeline to be extended 42 | :type dataset_pipeline: PipelineDefault 43 | :return: Extended pipeline with this fusion operator 44 | :rtype: PipelineDefault 45 | """ 46 | raise NotImplementedError 47 | 48 | 49 | class DifferentiableFusion(FusionBaseClass): 50 | """ 51 | Differentiable fusion class 52 | """ 53 | 54 | def __init__( 55 | self, 56 | input_modality_keys: Iterable, 57 | output_key: str, 58 | modality_dimensions: int, 59 | **kwargs, 60 | ): 61 | super().__init__() 62 | self.modality_dimensions = modality_dimensions 63 | 64 | @abstractmethod 65 | def toPytorchModule(self) -> torch.nn.Module: 66 | """returns a pytorch module as differentiable fusion method 67 | 68 | :return: pytorch module so that it can be composed with others 69 | :rtype: torch.nn.Module 70 | """ 71 | raise NotImplementedError 72 | 73 | def toNonDifferentiableFusePipeline( 74 | self, dataset_pipeline: PipelineDefault, **kwargs 75 | ) -> PipelineDefault: 76 | """provides the trained method as non differentiable a fuse operator by extending a 77 | fuse pipeline object. Subclasses might overwrite this method. By default it will concatenate 78 | input_modality_keys and feed a forward pass of the trained model 79 | 80 | :param dataset_pipeline: data pipeline to be extended 81 | :type dataset_pipeline: PipelineDefault 82 | :return: Extended pipeline with this fusion operator 83 | :rtype: PipelineDefault 84 | """ 85 | self.concat_key = self.output_key + "_concat" 86 | dataset_pipeline.extend( 87 | [ 88 | ( 89 | OpConcat(), 90 | dict( 91 | keys_in=self.input_modality_keys, 92 | key_out=self.concat_key, 93 | axis=0, 94 | ), 95 | ), 96 | ( 97 | OpForwardPass( 98 | self.toPytorchModule(), 99 | modality_dimensions=self.modality_dimensions, 100 | add_feature_names=False, 101 | ), 102 | dict(key_in=self.concat_key, key_out=self.output_key), 103 | ), 104 | ] 105 | ) 106 | -------------------------------------------------------------------------------- /mmmt/data/representation/model_builder_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, OrderedDict, Sequence, Union 2 | 3 | import mlflow 4 | import mlflow.pytorch 5 | import pandas as pd 6 | import pytorch_lightning as pl 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from fuse.dl.lightning.pl_funcs import convert_predictions_to_dataframe 11 | from fuse.dl.lightning.pl_module import LightningModuleDefault 12 | from fuse.dl.losses.loss_base import LossBase 13 | from fuse.dl.losses.loss_default import LossDefault 14 | from fuse.dl.models.model_wrapper import ModelWrapSeqToDict 15 | from fuse.eval.metrics.metrics_common import MetricBase 16 | from torch.utils.data.dataloader import DataLoader 17 | 18 | 19 | class ModelBuilderTrainer: 20 | """ 21 | Class containing methods to train the model constructed using the model builder 22 | """ 23 | 24 | def __init__( 25 | self, 26 | instantiated_model: torch.nn.Module, 27 | model_in_key: str, 28 | model_out_key: str, 29 | label_key: str, 30 | ) -> None: 31 | """ 32 | Args: 33 
| instantiated_model: configured model to be trained 34 | model_in_key: key in the FuseMedML pipeline where the samples to train the model (e.g. derived graphs) are stored 35 | model_out_key: key in the FuseMedML pipeline where the model output should be stored 36 | label_key: key in the FuseMedML pipeline where the labels are stored 37 | """ 38 | # store arguments 39 | self._instantiated_model = instantiated_model 40 | self._model_in_key = model_in_key 41 | self._model_out_key = model_out_key 42 | self._label_key = label_key 43 | 44 | # set to None train configuration - should be set 45 | self._model_dir = None 46 | self._losses = None 47 | self._best_epoch_source = None 48 | self._optimizers_and_lr_schs = None 49 | self._train_metrics = None 50 | self._validation_metrics = None 51 | self._callbacks = None 52 | 53 | # wrap the model 54 | self._model = ModelWrapSeqToDict( 55 | model=self._instantiated_model, 56 | model_inputs=[self._model_in_key], 57 | model_outputs=[self._model_out_key], 58 | ) 59 | 60 | def set_train_config( 61 | self, 62 | model_dir: str, 63 | losses: Optional[Dict[str, LossBase]] = None, 64 | best_epoch_source: Optional[Union[Dict, List[Dict]]] = None, 65 | optimizers_and_lr_schs: Any = None, 66 | train_metrics: Optional[OrderedDict[str, MetricBase]] = None, 67 | validation_metrics: Optional[OrderedDict[str, MetricBase]] = None, 68 | callbacks: Optional[Sequence[pl.Callback]] = None, 69 | pl_trainer_num_epochs: int = 100, 70 | pl_trainer_accelerator: str = "gpu", 71 | pl_trainer_devices: int = 1, 72 | pl_trainer_strategy: Optional[str] = None, 73 | ): 74 | 75 | if losses is None: 76 | losses = { 77 | "cls_loss": LossDefault( 78 | pred=self._model_out_key, 79 | target=self._label_key, 80 | callable=F.cross_entropy, 81 | weight=1.0, 82 | ), 83 | } 84 | 85 | if best_epoch_source is None: 86 | best_epoch_source = dict(monitor="validation.losses.total_loss", mode="min") 87 | 88 | if optimizers_and_lr_schs is None: 89 | # create optimizer 90 | optimizer = optim.Adam( 91 | self._model.parameters(), 92 | lr=1e-4, 93 | weight_decay=0.001, 94 | ) 95 | 96 | # create learning scheduler 97 | lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer) 98 | lr_sch_config = dict( 99 | scheduler=lr_scheduler, monitor="validation.losses.total_loss" 100 | ) 101 | 102 | # optimizer and lr sch - see pl.LightningModule.configure_optimizers return value for all options 103 | optimizers_and_lr_schs = dict( 104 | optimizer=optimizer, lr_scheduler=lr_sch_config 105 | ) 106 | 107 | self._model_dir = model_dir 108 | self._losses = losses 109 | self._best_epoch_source = best_epoch_source 110 | self._optimizers_and_lr_schs = optimizers_and_lr_schs 111 | 112 | self._train_metrics = train_metrics 113 | self._validation_metrics = validation_metrics 114 | self._callbacks = callbacks 115 | 116 | self._pl_trainer_num_epochs = pl_trainer_num_epochs 117 | self._pl_trainer_accelerator = pl_trainer_accelerator 118 | self._pl_trainer_devices = pl_trainer_devices 119 | self._pl_trainer_strategy = pl_trainer_strategy 120 | 121 | def fit( 122 | self, train_dataloader: DataLoader, validation_dataloader: DataLoader 123 | ) -> None: 124 | assert ( 125 | self._model_dir is not None 126 | ), "Error expecting train configuration. 
Call to method set_train_config() to set it" 127 | 128 | pl_module = LightningModuleDefault( 129 | model_dir=self._model_dir, 130 | model=self._model, 131 | losses=self._losses, 132 | train_metrics=self._train_metrics, 133 | validation_metrics=self._validation_metrics, 134 | best_epoch_source=self._best_epoch_source, 135 | optimizers_and_lr_schs=self._optimizers_and_lr_schs, 136 | ) 137 | 138 | # create lightning trainer. 139 | pl_trainer = pl.Trainer( 140 | default_root_dir=self._model_dir, 141 | max_epochs=self._pl_trainer_num_epochs, 142 | accelerator=self._pl_trainer_accelerator, 143 | strategy=self._pl_trainer_strategy, 144 | devices=self._pl_trainer_devices, 145 | auto_select_gpus=True, 146 | ) 147 | 148 | # train 149 | mlflow.pytorch.autolog() 150 | pl_trainer.fit(pl_module, train_dataloader, validation_dataloader) 151 | 152 | def load_checkpoint(self, checkpoint_filename: str) -> torch.nn.Module: 153 | LightningModuleDefault.load_from_checkpoint( 154 | checkpoint_filename, 155 | model_dir=self._model_dir, 156 | model=self._model, 157 | map_location="cpu", 158 | strict=True, 159 | ) 160 | return self._model 161 | 162 | def predict( 163 | self, 164 | infer_dataloader: DataLoader, 165 | model_dir: str, 166 | checkpoint_filename: str, 167 | keys_to_extract: Sequence[str], 168 | ) -> pd.DataFrame: 169 | """ 170 | Method for using the model to predict the label of samples 171 | 172 | Args: 173 | infer_dataloader: dataloader to infer - each batch expected to be a dictionary batch_dict) 174 | model_dir: path to directory with checkpoints 175 | checkpoint_filename: path to checkpoint file as stored in self.fit() method 176 | keys_to_extract: sequence of keys to extract and dump into the dataframe 177 | 178 | Returns: 179 | dataframe containing the predictions 180 | """ 181 | 182 | pl_module = LightningModuleDefault.load_from_checkpoint( 183 | checkpoint_filename, 184 | model_dir=model_dir, 185 | model=self._model, 186 | map_location="cpu", 187 | strict=True, 188 | ) 189 | pl_module.set_predictions_keys(keys_to_extract) 190 | 191 | # create lightning trainer. 
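        # (descriptive note: the checkpoint weights were already restored into pl_module by
        # load_from_checkpoint above; this trainer instance only drives batched prediction)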
192 | pl_trainer = pl.Trainer( 193 | default_root_dir=self._model_dir, 194 | accelerator=self._pl_trainer_accelerator, 195 | strategy=self._pl_trainer_strategy, 196 | devices=self._pl_trainer_devices, 197 | auto_select_gpus=True, 198 | ) 199 | 200 | preds = pl_trainer.predict( 201 | model=pl_module, dataloaders=infer_dataloader, return_predictions=True 202 | ) 203 | 204 | df = convert_predictions_to_dataframe(preds) 205 | 206 | return df 207 | -------------------------------------------------------------------------------- /mmmt/models/__init__.py: -------------------------------------------------------------------------------- 1 | import mmmt.models.graph 2 | import mmmt.models.head 3 | import mmmt.models.model_builder 4 | -------------------------------------------------------------------------------- /mmmt/models/classic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/multimodal-models-toolkit/256f12b369769003986b68044086f4898cb096bf/mmmt/models/classic/__init__.py -------------------------------------------------------------------------------- /mmmt/models/classic/fusion_mlp.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Any, Dict, List, Optional, OrderedDict, Sequence, Tuple, Union 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import pytorch_lightning as pl 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | 10 | from fuse.dl.models.heads.head_1D_classifier import Head1DClassifier 11 | from fuse.dl.losses import LossBase, LossDefault 12 | from fuse.dl.lightning.pl_module import LightningModuleDefault 13 | from fuse.eval import MetricBase 14 | from fuse.eval.metrics.classification.metrics_classification_common import MetricAUCROC 15 | 16 | 17 | class FusionMLPClassifer: 18 | """ 19 | Basic feature (mid) fusion algorithm using MLP, including training. 
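    Example (minimal sketch; the keys, the feature size and the dataloaders are placeholders for a hypothetical setup, and any extra keyword arguments would be forwarded to Head1DClassifier):

        clf = FusionMLPClassifer(
            input_keys=(("model.backbone_features", 193),),
            target_key="data.gt.label",
        )
        # train_dataloader / validation_dataloader: standard torch DataLoaders (placeholders)
        clf.set_train_config(model_dir="fusion_mlp_model", pl_trainer_accelerator="cpu")
        clf.fit(train_dataloader, validation_dataloader)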
20 | """ 21 | 22 | def __init__( 23 | self, input_keys: Sequence[Tuple[str, int]], target_key: str, **arch_kwargs 24 | ): 25 | """ 26 | :param input_keys: List of feature map inputs - tuples of (batch_dict key, channel depth) 27 | If multiple inputs are used, they are concatenated on the channel axis 28 | for example: 29 | input_keys=(('model.backbone_features', 193),) 30 | :param arch_kwargs: arguments to create the MLP - see Head1DClassifier 31 | :param target_key: labels key 32 | """ 33 | 34 | self._input_keys = input_keys 35 | self._arch_kwargs = arch_kwargs 36 | self._scores_key = "model.output.classifier" 37 | self._logits_key = "model.logits.classifier" 38 | self._target_key = target_key 39 | 40 | self._model = Head1DClassifier("classifier", input_keys, **arch_kwargs) 41 | 42 | # set to None train configuration - should be set 43 | self._model_dir = None 44 | self._losses = None 45 | self._best_epoch_source = None 46 | self._optimizers_and_lr_schs = None 47 | self._train_metrics = None 48 | self._validation_metrics = None 49 | self._callbacks = None 50 | 51 | def set_train_config( 52 | self, 53 | model_dir: str, 54 | losses: Optional[Dict[str, LossBase]] = None, 55 | best_epoch_source: Optional[Union[Dict, List[Dict]]] = None, 56 | optimizers_and_lr_schs: Any = None, 57 | train_metrics: Optional[OrderedDict[str, MetricBase]] = None, 58 | validation_metrics: Optional[OrderedDict[str, MetricBase]] = None, 59 | callbacks: Optional[Sequence[pl.Callback]] = None, 60 | pl_trainer_num_epochs: int = 100, 61 | pl_trainer_accelerator: str = "gpu", 62 | pl_trainer_devices: int = 1, 63 | pl_trainer_strategy: Optional[str] = None, 64 | ): 65 | self._model_dir = model_dir 66 | if losses is None: 67 | self._losses = { 68 | "cls_loss": LossDefault( 69 | pred=self._logits_key, 70 | target=self._target_key, 71 | callable=F.cross_entropy, 72 | weight=1.0, 73 | ), 74 | } 75 | else: 76 | self._losses = losses 77 | 78 | if train_metrics is None: 79 | self._train_metrics = { 80 | "auc": MetricAUCROC(pred=self._scores_key, target=self._target_key) 81 | } 82 | else: 83 | self._train_metrics = train_metrics 84 | 85 | if validation_metrics is None: 86 | self._validation_metrics = deepcopy( 87 | self._train_metrics 88 | ) # use the same metrics in validation as well 89 | else: 90 | self._validation_metrics = validation_metrics 91 | 92 | # either a dict with arguments to pass to ModelCheckpoint or list dicts for multiple ModelCheckpoint callbacks (to monitor and save checkpoints for more then one metric). 
93 | if best_epoch_source is None: 94 | # assumes binary classification and that MetricAUCROC is in validation_metrics 95 | self._best_epoch_source = dict( 96 | monitor="validation.metrics.auc", 97 | mode="max", 98 | ) 99 | else: 100 | self._best_epoch_source = best_epoch_source 101 | 102 | if optimizers_and_lr_schs is None: 103 | optimizer = optim.SGD( 104 | self._model.parameters(), 105 | lr=1e-3, 106 | weight_decay=0.0, 107 | momentum=0.9, 108 | nesterov=True, 109 | ) 110 | 111 | # create learning scheduler 112 | if optimizers_and_lr_schs is None: 113 | lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer) 114 | lr_sch_config = dict( 115 | scheduler=lr_scheduler, monitor="validation.losses.total_loss" 116 | ) 117 | 118 | # optimizer and lr sch - see pl.LightningModule.configure_optimizers return value for all options 119 | self._optimizers_and_lr_schs = dict( 120 | optimizer=optimizer, lr_scheduler=lr_sch_config 121 | ) 122 | else: 123 | self._optimizers_and_lr_schs = optimizers_and_lr_schs 124 | 125 | self._callbacks = callbacks 126 | self._pl_trainer_num_epochs = pl_trainer_num_epochs 127 | self._pl_trainer_accelerator = pl_trainer_accelerator 128 | self._pl_trainer_devices = pl_trainer_devices 129 | self._pl_trainer_strategy = pl_trainer_strategy 130 | 131 | def fit(self, train_dataloader: DataLoader, validation_dataloader: DataLoader): 132 | pl_module = LightningModuleDefault( 133 | model_dir=self._model_dir, 134 | model=self._model, 135 | losses=self._losses, 136 | train_metrics=self._train_metrics, 137 | validation_metrics=self._validation_metrics, 138 | best_epoch_source=self._best_epoch_source, 139 | optimizers_and_lr_schs=self._optimizers_and_lr_schs, 140 | ) 141 | 142 | # create lightning trainer. 143 | pl_trainer = pl.Trainer( 144 | default_root_dir=self._model_dir, 145 | max_epochs=self._pl_trainer_num_epochs, 146 | accelerator=self._pl_trainer_accelerator, 147 | devices=self._pl_trainer_devices, 148 | auto_select_gpus=True, 149 | ) 150 | 151 | # train 152 | pl_trainer.fit(pl_module, train_dataloader, validation_dataloader) 153 | 154 | def load_checkpoint(self, checkpoint_filename: str) -> torch.nn.Module: 155 | LightningModuleDefault.load_from_checkpoint( 156 | checkpoint_filename, 157 | model_dir=self._model_dir, 158 | model=self._model, 159 | map_location="cpu", 160 | strict=True, 161 | ) 162 | return self._model 163 | 164 | def model(self) -> torch.nn.Module: 165 | return self._model 166 | -------------------------------------------------------------------------------- /mmmt/models/classic/late_fusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | 4 | 5 | class LateFusion: 6 | """ 7 | Basic late fusion 8 | """ 9 | 10 | def __init__(self, n_mods, n_classes, weights=None): 11 | """ 12 | Args: 13 | n_mods: number of modalities 14 | n_classes: number of classes/tasks 15 | """ 16 | self.n_mods = n_mods 17 | self.n_classes = n_classes 18 | 19 | """weights: weights used for weighted average with size [n_mods, n_classes]. 
Initialized with equal weights.""" 20 | if weights is None: 21 | self.weights = np.ones((n_mods, n_classes), float) / n_mods 22 | else: 23 | self.weights = weights 24 | 25 | def apply_fusion(self, predictions): 26 | """ 27 | Calculate fused predictions using a weighted average 28 | 29 | predictions: predictions made by every model with size [n_mod, n_sample, n_class] 30 | n_mod is the number of models/modalities 31 | n_sample is the number of data samples 32 | n_class is the number of classes 33 | 34 | return fused of size [n_sample, n_class]. 35 | """ 36 | fused = np.zeros((predictions.shape[1], self.n_classes)) 37 | logging.info("fused %s, weights %s, predictions %s", fused.shape, self.weights.shape, predictions.shape) 38 | for modality in range(self.n_mods): 39 | for class_id in range(self.n_classes): 40 | fused[:, class_id] = ( 41 | fused[:, class_id] 42 | + self.weights[modality, class_id] 43 | * predictions[modality, :, class_id] 44 | ) 45 | return fused 46 | -------------------------------------------------------------------------------- /mmmt/models/classic/uncertainty_late_fusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mmmt.models.classic.late_fusion import LateFusion 3 | 4 | 5 | class UncertaintyLateFusion(LateFusion): 6 | """ 7 | Implementation of the late fusion method described in the following paper: 8 | Wang, Hongzhi, Vaishnavi Subramanian, and Tanveer Syeda-Mahmood. 9 | Modeling uncertainty in multi-modal fusion for lung cancer survival analysis. 10 | In 2021 IEEE 18th International Symposium on Biomedical Imaging (ISBI), pp. 1169-1172. IEEE, 2021. 11 | """ 12 | 13 | def __init__(self, n_mods, n_classes): 14 | """ 15 | Args: 16 | n_mods: number of modalities 17 | n_classes: number of classes/tasks 18 | """ 19 | super().__init__(n_mods, n_classes) 20 | self.k = -1  # -1 selects all modalities by default (see compute_fusion_weights) 21 | self.alpha = 0.01 22 | 23 | def solve_weight(self, m): 24 | """ 25 | Compute the fusion weights of a class_id from the covariance matrix. 26 | 27 | Args: 28 | m: covariance matrix of a class_id 29 | 30 | Returns: 31 | Fusion weights for the different models/modalities and one class_id 32 | """ 33 | invM = np.linalg.inv(m) 34 | w = np.matmul(invM, np.ones((m.shape[0], 1))) 35 | w /= np.sum(w) 36 | return w 37 | 38 | def compute_fusion_weights(self, predictions, ground_truth): 39 | """ 40 | Calculate fusion weights with model selection 41 | 42 | Args: 43 | predictions: predictions made by every model with size [n_mod, n_sample, n_class] 44 | n_mod is the number of models/modalities 45 | n_sample is the number of data samples 46 | n_class is the number of classes 47 | ground_truth: one-hot representation of the ground truth with size [n_sample, n_class] 48 | K (taken from self.k): number of models to be selected for fusion. Only the top K models will be used for fusion, default K=n_mod. 49 | alpha (taken from self.alpha): weight for adding an identity matrix to make the covariance matrix well conditioned. Typical values are 0.1 or 0.01 50 | 51 | Returns: 52 | fusion_weights of size [n_mod, n_class]. 53 | fusion_weights[:, L] is the fusion weights for class L. Note that there are only K non-zero values. 
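        Example (illustrative sketch; the prediction arrays and one-hot labels are placeholders):

            # val_predictions, val_one_hot_labels, test_predictions: numpy arrays (placeholders)
            fusion = UncertaintyLateFusion(n_mods=3, n_classes=2)
            weights = fusion.compute_fusion_weights(val_predictions, val_one_hot_labels)
            fusion.weights = weights  # reuse the estimated weights for the weighted average
            fused_test = fusion.apply_fusion(test_predictions)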
54 | """ 55 | self.n_mods = predictions.shape[0] 56 | n_samples = predictions.shape[1] 57 | self.n_classes = predictions.shape[2] 58 | if self.k == -1 or self.k > self.n_mods: 59 | self.k = self.n_mods 60 | 61 | selected = np.zeros((self.n_classes, self.n_mods), int) 62 | errors = np.zeros((self.n_mods, n_samples, self.n_classes), float) 63 | sumerrors = np.zeros((self.n_classes, self.n_mods), float) 64 | weights = np.zeros((self.n_mods, self.n_classes)) 65 | 66 | for modality in range(self.n_mods): 67 | errors[modality, :, :] = np.abs(ground_truth - predictions[modality, :, :]) 68 | for class_id in range(self.n_classes): 69 | sumerrors[class_id, modality] += np.sum(errors[modality, :, class_id]) 70 | 71 | for class_id in range(self.n_classes): 72 | inds = np.argsort(sumerrors[class_id, :]) 73 | selected[class_id, :] = inds 74 | 75 | convariance_matrix = np.zeros((self.n_classes, self.k, self.k)) 76 | 77 | for class_id in range(self.n_classes): 78 | for mod1_index in range(self.k): 79 | for mod2_index in range(mod1_index, self.k): 80 | jerror = ( 81 | errors[selected[class_id, mod1_index], :, class_id] 82 | * errors[selected[class_id, mod2_index], :, class_id] 83 | ) 84 | common_error = np.sum(jerror) 85 | convariance_matrix[class_id, mod1_index, mod2_index] = common_error 86 | convariance_matrix[ 87 | class_id, mod2_index, mod1_index 88 | ] = convariance_matrix[class_id, mod1_index, mod2_index] 89 | 90 | for class_id in range(self.n_classes): 91 | convariance_matrix[class_id, :, :] = convariance_matrix[class_id, :, :] / ( 92 | np.max(convariance_matrix[class_id, :, :]) + 1e-10 93 | ) 94 | m = convariance_matrix[class_id, :, :] 95 | for mod_index in range(self.k): 96 | m[mod_index, mod_index] += self.alpha 97 | w = self.solve_weight(m) 98 | weights[selected[class_id, 0 : self.k], class_id] = np.matrix.flatten(w) 99 | return weights 100 | -------------------------------------------------------------------------------- /mmmt/models/graph/__init__.py: -------------------------------------------------------------------------------- 1 | from mmmt.models.graph.module_configurator import ModuleConfigurator 2 | from mmmt.models.graph.gcn import GCN 3 | from mmmt.models.graph.mgnn import mGNN 4 | from mmmt.models.graph.multi_behavioral_gnn import MultiBehavioralGNN 5 | from mmmt.models.graph.multiplex_gcn import MultiplexGCN 6 | from mmmt.models.graph.multiplex_gin import MultiplexGIN 7 | from mmmt.models.graph.relational_gcn import Relational_GCN 8 | -------------------------------------------------------------------------------- /mmmt/models/graph/gcn.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | from dgl.nn.pytorch.conv import GraphConv 4 | 5 | 6 | class GCN(nn.Module): 7 | """ 8 | Baseline GCN 9 | """ 10 | 11 | def __init__(self, module_input): 12 | """ 13 | Args: 14 | module_input: dictionary for module initialization containing the following keys: 15 | - in_size: input feature dimension 16 | - gl_hidden_size: size of the hidden layer 17 | """ 18 | super().__init__() 19 | 20 | in_size = module_input["node_emb_dim"] 21 | gl_hidden_size = module_input["gl_hidden_size"] 22 | 23 | # create layers 24 | self.layer1 = GraphConv(in_size, gl_hidden_size[0], allow_zero_in_degree=True) 25 | self.layer2 = GraphConv( 26 | gl_hidden_size[0], gl_hidden_size[-1], allow_zero_in_degree=True 27 | ) 28 | 29 | self.out_size = gl_hidden_size[-1] 30 | 31 | def forward(self, derived_graph): 32 | """ 33 | Forward 
pass of the model. 34 | 35 | Args: 36 | derived_graph: dictionary of graph topology (key 'graph') and node features (key 'node_features') 37 | 38 | Returns: 39 | Graph latent representation 40 | """ 41 | g = derived_graph["graph"][0] # only one graph is present 42 | node_features = derived_graph["node_features"] 43 | 44 | h = self.layer1(g, node_features) 45 | h = F.leaky_relu(h) 46 | h = self.layer2(g, h) 47 | 48 | return h 49 | -------------------------------------------------------------------------------- /mmmt/models/graph/mgnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from dgl.nn.pytorch.conv import GATConv 5 | 6 | 7 | class mGNN(nn.Module): 8 | """ 9 | mGNN framework for message passing from https://arxiv.org/abs/2109.10119 10 | """ 11 | 12 | def __init__(self, module_input): 13 | """ 14 | Args: 15 | module_input: dictionary for module initialization containing the following keys: 16 | - in_size: input feature dimension 17 | - gl_hidden_size: size of the hidden layer 18 | - num_att_heads: number of attention heads 19 | - n_layers: number of layers of the multiplex graphs, input of the graph model 20 | """ 21 | super().__init__() 22 | 23 | in_size = module_input["node_emb_dim"] 24 | gl_hidden_size = module_input["gl_hidden_size"] 25 | num_att_heads = module_input["num_att_heads"] 26 | n_layers = module_input["n_layers"] 27 | 28 | self.n_layers = n_layers 29 | 30 | self.convC1 = GATConv( 31 | in_size, in_size, num_heads=num_att_heads, allow_zero_in_degree=True 32 | ) 33 | self.convA1 = GATConv( 34 | in_size, in_size, num_heads=num_att_heads, allow_zero_in_degree=True 35 | ) 36 | 37 | self.convC2 = GATConv( 38 | 2 * in_size * num_att_heads, 39 | gl_hidden_size[0], 40 | num_heads=num_att_heads, 41 | allow_zero_in_degree=True, 42 | ) 43 | self.convA2 = GATConv( 44 | 2 * in_size * num_att_heads, 45 | gl_hidden_size[0], 46 | num_heads=num_att_heads, 47 | allow_zero_in_degree=True, 48 | ) 49 | 50 | self.out_size = gl_hidden_size[-1] * 2 * num_att_heads 51 | 52 | def forward(self, derived_graph): 53 | """ 54 | Forward pass of the model. 
55 | 56 | Args: 57 | derived_graph: dictionary of graph topology (key 'graph') and node features (key 'node_features') 58 | 59 | Returns: 60 | Graph latent representation 61 | """ 62 | g1, g2 = derived_graph["graph"] 63 | node_features = derived_graph["node_features"] 64 | 65 | h_C1 = self.convC1(g1, node_features).reshape(node_features.size()[0], -1) 66 | h_A1 = self.convA1(g2, node_features).reshape(node_features.size()[0], -1) 67 | 68 | h = F.leaky_relu(torch.cat((h_C1, h_A1), dim=1)) 69 | 70 | h_C2 = self.convC2(g1, h).reshape(node_features.size()[0], -1) 71 | h_A2 = self.convA2(g2, h).reshape(node_features.size()[0], -1) 72 | 73 | # graph based readout 74 | h = torch.cat((h_C2, h_A2), dim=1) 75 | 76 | return h 77 | -------------------------------------------------------------------------------- /mmmt/models/graph/module_configurator.py: -------------------------------------------------------------------------------- 1 | import mmmt 2 | 3 | import logging 4 | 5 | 6 | class ModuleConfigurator: 7 | """ 8 | Configures the graph neural network taking the inputs from the user configuration stored in graph_module 9 | """ 10 | 11 | def __init__(self, graph_module): 12 | """ 13 | Args: 14 | graph_module: user input that defines which flavour to use for the graph neural network 15 | """ 16 | self.graph_module = graph_module 17 | self.graph_module_identifier = graph_module["module_identifier"] 18 | 19 | def get_module(self, graph_sample): 20 | """ 21 | Instantiate the chosen graph neural network and provides metadata for the construction of the head module downstream 22 | 23 | Args: 24 | graph_sample: a graph sample resulting from the combination of the available modalities 25 | 26 | Returns: 27 | A tuple containing the instantiated graph neural network, the input size needed for the head module and the number of nodes arriving to the head module 28 | """ 29 | modules = { 30 | "mplex": mmmt.models.graph.MultiplexGIN, 31 | "mplex-prop": mmmt.models.graph.MultiplexGCN, 32 | "mgnn": mmmt.models.graph.mGNN, 33 | "multibehav": mmmt.models.graph.MultiBehavioralGNN, 34 | "gcn": mmmt.models.graph.GCN, 35 | "rgcn": mmmt.models.graph.Relational_GCN, 36 | } 37 | 38 | if self.graph_module_identifier in modules: 39 | return ( 40 | modules[self.graph_module_identifier](self.graph_module), 41 | self.get_head_in_size(), 42 | self.get_head_num_nodes(graph_sample), 43 | ) 44 | else: 45 | logging.error( 46 | self.graph_module_identifier 47 | + " method not implemented. Choose within " 48 | + str(modules.keys()) 49 | ) 50 | raise ValueError("Unknown method! 
Choose within " + str(modules.keys())) 51 | 52 | def get_head_in_size(self): 53 | """ 54 | Provides the input size needed for the head module 55 | 56 | Returns: 57 | input size needed for the head module 58 | """ 59 | gl_hidden_size = self.graph_module.get("gl_hidden_size", [0]) 60 | node_emb_dim = self.graph_module.get("node_emb_dim", 0) 61 | num_att_heads = self.graph_module.get("num_att_heads", 1) 62 | sizes = { 63 | "mplex": 4 * node_emb_dim, 64 | "mplex-prop": gl_hidden_size[-1] * 2, 65 | "mgnn": gl_hidden_size[-1] * 2 * num_att_heads, 66 | } 67 | return sizes.get(self.graph_module_identifier, gl_hidden_size[-1]) 68 | # no need to capture exceptions as already done in self.get_module 69 | 70 | def get_head_num_nodes(self, graph_sample): 71 | """ 72 | Provides the number of nodes arriving to the head module 73 | 74 | Args: 75 | graph_sample: a graph sample resulting from the combination of the available modalities 76 | 77 | Returns: 78 | number of nodes arriving to the head module 79 | """ 80 | if self.graph_module_identifier == "multibehav": 81 | return graph_sample[1].num_nodes() # * self.graph_batch_size 82 | else: 83 | return graph_sample[0].num_nodes() # * self.graph_batch_size 84 | -------------------------------------------------------------------------------- /mmmt/models/graph/multi_behavioral_gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from dgl.nn.pytorch.conv import GraphConv 5 | 6 | 7 | class MultiBehavioralGNN(nn.Module): 8 | """ 9 | Multibehavioral GNN framework for message passing from https://dl.acm.org/doi/pdf/10.1145/3340531.3412119 10 | """ 11 | 12 | def __init__(self, module_input): 13 | """ 14 | Args: 15 | module_input: dictionary for module initialization containing the following keys: 16 | - in_size: input feature dimension 17 | - gl_hidden_size: size of the hidden layer 18 | - n_layers: number of concepts 19 | """ 20 | super().__init__() 21 | 22 | in_size = module_input["node_emb_dim"] 23 | gl_hidden_size = module_input["gl_hidden_size"] 24 | n_layers = module_input["n_layers"] 25 | 26 | self.n_layers = n_layers 27 | 28 | self.convQ1 = GraphConv(in_size, gl_hidden_size[0], allow_zero_in_degree=True) 29 | self.convM1 = GraphConv( 30 | gl_hidden_size[0], gl_hidden_size[-1], allow_zero_in_degree=True 31 | ) 32 | 33 | self.out_size = gl_hidden_size[-1] 34 | 35 | def forward(self, derived_graph): 36 | """ 37 | Forward pass of the model. 38 | 39 | Args: 40 | derived_graph: dictionary of graph topology (key 'graph') and node features (key 'node_features') 41 | 42 | Returns: 43 | Graph latent representation 44 | """ 45 | g1, g2 = derived_graph["graph"] 46 | Q_node_features = derived_graph["node_features"] 47 | h_Q1 = self.convQ1(g1, Q_node_features) 48 | h_Q1 = F.leaky_relu(h_Q1) 49 | 50 | # node based feature stack for Multiplex graph 51 | M_node_features = torch.kron(torch.ones(self.n_layers, 1), h_Q1) 52 | h_M1 = self.convM1(g2, M_node_features) 53 | h_M1 = F.leaky_relu(h_M1) 54 | 55 | return h_M1 56 | -------------------------------------------------------------------------------- /mmmt/models/graph/multiplex_gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from dgl.nn.pytorch.conv import SGConv 4 | 5 | 6 | class MultiplexGCN(nn.Module): 7 | """ 8 | Multiplex GCN for message passing according to sGCN Conv for sparse graphs. 
9 | """ 10 | 11 | def __init__(self, module_input): 12 | """ 13 | Args: 14 | module_input: dictionary for module initialization containing the following keys: 15 | - in_size: input feature dimension 16 | - gl_hidden_size: list of sizes of the hidden layers 17 | """ 18 | super().__init__() 19 | 20 | in_size = module_input["node_emb_dim"] 21 | gl_hidden_size = module_input["gl_hidden_size"] 22 | 23 | self.convAC1 = SGConv(in_size, gl_hidden_size[0], allow_zero_in_degree=True) 24 | self.convCA1 = SGConv(in_size, gl_hidden_size[0], allow_zero_in_degree=True) 25 | 26 | self.convAC2 = SGConv( 27 | 2 * gl_hidden_size[0], gl_hidden_size[1], allow_zero_in_degree=True 28 | ) 29 | self.convCA2 = SGConv( 30 | 2 * gl_hidden_size[0], gl_hidden_size[1], allow_zero_in_degree=True 31 | ) 32 | 33 | self.out_size = gl_hidden_size[-1] * 2 34 | 35 | def forward(self, derived_graph): 36 | """ 37 | Forward pass of the model. 38 | 39 | Args: 40 | derived_graph: dictionary of graph topology (key 'graph') and node features (key 'node_features') 41 | 42 | Returns: 43 | Graph latent representation 44 | """ 45 | g1, g2 = derived_graph["graph"] 46 | node_features = derived_graph["node_features"] 47 | 48 | ef_1 = g1.edata["w"].float() 49 | ef_2 = g2.edata["w"].float() 50 | 51 | h_AC1 = self.convAC1(g1, node_features, edge_weight=ef_1) 52 | h_CA1 = self.convCA1(g2, node_features, edge_weight=ef_2) 53 | 54 | h = torch.cat((h_AC1, h_CA1), dim=1) 55 | 56 | h_AC2 = self.convAC2(g1, h, edge_weight=ef_1) 57 | h_CA2 = self.convCA2(g2, h, edge_weight=ef_2) 58 | 59 | # #aggregate across features 60 | h = torch.cat((h_AC2, h_CA2), dim=1) 61 | 62 | return h 63 | -------------------------------------------------------------------------------- /mmmt/models/graph/multiplex_gin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from dgl.nn.pytorch.conv import GINConv 5 | 6 | 7 | class MultiplexGIN(nn.Module): 8 | """ 9 | Multiplex GIN framework for message passing via multiplex walks. 10 | """ 11 | 12 | def __init__(self, module_input): 13 | """ 14 | Args: 15 | module_input: dictionary for module initialization 16 | 17 | """ 18 | super().__init__() 19 | 20 | self.convAC1 = GINConv(aggregator_type="mean", activation=F.leaky_relu) 21 | self.convCA1 = GINConv(aggregator_type="mean", activation=F.leaky_relu) 22 | 23 | self.convAC2 = GINConv(aggregator_type="mean", activation=F.leaky_relu) 24 | self.convCA2 = GINConv(aggregator_type="mean", activation=F.leaky_relu) 25 | 26 | self.out_size = 4 * module_input["node_emb_dim"] 27 | 28 | def forward(self, derived_graph): 29 | """ 30 | Forward pass of the model. 
31 | 32 | Args: 33 | derived_graph: dictionary of graph topology (key 'graph') and node features (key 'node_features') 34 | 35 | Returns: 36 | Graph latent representation 37 | """ 38 | g1, g2 = derived_graph["graph"] 39 | node_features = derived_graph["node_features"] 40 | 41 | h_AC1 = self.convAC1(g1, node_features) 42 | h_CA1 = self.convCA1(g2, node_features) 43 | 44 | h = torch.cat((h_AC1, h_CA1), dim=1) 45 | 46 | h_AC2 = self.convAC2(g1, h) 47 | h_CA2 = self.convCA2(g2, h) 48 | 49 | # #aggregate across features 50 | h = torch.cat((h_AC2, h_CA2), dim=1) 51 | return h 52 | -------------------------------------------------------------------------------- /mmmt/models/graph/relational_gcn.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from dgl.nn.pytorch.conv import RelGraphConv 5 | 6 | 7 | class Relational_GCN(nn.Module): 8 | """ 9 | Relational GCN in https://arxiv.org/pdf/1703.06103.pdf 10 | (Uses in-built Relational Graph Conv) 11 | """ 12 | 13 | def __init__(self, module_input): 14 | """ 15 | Args: 16 | module_input: dictionary for module initialization containing the following keys: 17 | - in_size: input feature dimension 18 | - n_layers: number of concepts 19 | - num_bases: number of bases of the RelGraphConv 20 | - gl_hidden_size: size of the hidden layer 21 | """ 22 | super().__init__() 23 | 24 | in_size = module_input["node_emb_dim"] 25 | n_layers = module_input["n_layers"] 26 | num_bases = module_input["num_bases"] 27 | gl_hidden_size = module_input["gl_hidden_size"] 28 | 29 | # create layers 30 | self.layer1 = RelGraphConv( 31 | in_size, 32 | gl_hidden_size[0], 33 | num_rels=n_layers, 34 | regularizer="basis", 35 | num_bases=num_bases, 36 | ) 37 | self.layer2 = RelGraphConv( 38 | gl_hidden_size[0], 39 | gl_hidden_size[-1], 40 | num_rels=n_layers, 41 | regularizer="basis", 42 | num_bases=num_bases, 43 | ) 44 | 45 | self.out_size = gl_hidden_size[-1] 46 | 47 | def forward(self, derived_graph): 48 | """ 49 | Forward pass of the model. 50 | 51 | Args: 52 | derived_graph: dictionary of graph topology (key 'graph') and node features (key 'node_features') 53 | 54 | Returns: 55 | Graph latent representation 56 | """ 57 | g = derived_graph["graph"][0] # only one graph is present 58 | node_features = derived_graph["node_features"] 59 | 60 | etype = dgl.to_homogeneous(g).edata[dgl.ETYPE] 61 | 62 | h = self.layer1(dgl.to_homogeneous(g), node_features, etype) 63 | h = F.leaky_relu(h) 64 | 65 | h = self.layer2(dgl.to_homogeneous(g), h, etype) 66 | 67 | return h 68 | -------------------------------------------------------------------------------- /mmmt/models/head/__init__.py: -------------------------------------------------------------------------------- 1 | from mmmt.models.head.mlp import MLP 2 | -------------------------------------------------------------------------------- /mmmt/models/head/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MLP(nn.Module): 7 | """ 8 | MLP model for combination with the graph model. 
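    Example (illustrative sketch; all values are placeholders):

        head = MLP(
            num_nodes=16,
            in_size=8,
            hidden_size=[100, 20],
            out_size=2,
            dropout=0.5,
            add_softmax=True,
        )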
9 | """ 10 | 11 | def __init__( 12 | self, 13 | num_nodes, 14 | in_size, 15 | hidden_size: list, 16 | out_size, 17 | dropout, 18 | add_softmax=False, 19 | ): 20 | """ 21 | Args: 22 | num_nodes: number of nodes in the graph 23 | in_size: input feature dimension 24 | hidden_size: list of hidden layer dimension of the classification head, 25 | for one hidden layer len(hidden_layer) == 1 26 | out_size: output dimension, which normally corresponds to the number of classes 27 | add_softmax: whether to add a softmax layer at the end 28 | """ 29 | super().__init__() 30 | 31 | self.in_size = in_size 32 | self.num_nodes = num_nodes 33 | self.out_size = out_size 34 | 35 | if in_size > 1: 36 | self.agg = torch.nn.Linear(in_size, 1) 37 | else: 38 | self.agg = None 39 | 40 | self.dp_out = torch.nn.Dropout(dropout) 41 | 42 | self.dense_layers = [] 43 | 44 | in_features = self.num_nodes 45 | for out_features in hidden_size: 46 | self.dense_layers.append(torch.nn.Linear(in_features, out_features)) 47 | in_features = out_features 48 | 49 | self.dense_layers = nn.ModuleList(self.dense_layers) 50 | 51 | if add_softmax: 52 | self.final_layer = nn.Sequential( 53 | torch.nn.Linear(hidden_size[-1], out_size), torch.nn.Softmax(2) 54 | ) 55 | else: 56 | self.final_layer = torch.nn.Linear(hidden_size[-1], out_size) 57 | 58 | def forward(self, h): 59 | """ 60 | Forward pass of the classifier head model. 61 | 62 | Args: 63 | h: graph latent representation 64 | 65 | Returns: 66 | Logits, with defined dimension out_size 67 | """ 68 | 69 | if self.agg: 70 | h = self.agg(h.float()) 71 | 72 | h = self.dp_out(h) 73 | 74 | h = h.transpose(2, 0) 75 | 76 | else: 77 | h = h[None, :, :] 78 | 79 | h = self.dp_out(h) 80 | 81 | for layer in self.dense_layers: 82 | h = F.leaky_relu(layer(h)) 83 | 84 | h = self.final_layer(h) 85 | 86 | return h.squeeze(dim=0) 87 | -------------------------------------------------------------------------------- /mmmt/models/model_builder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import logging 3 | 4 | 5 | class ModelBuilder(nn.Module): 6 | """ 7 | Builder to unite graph model and head model. 8 | """ 9 | 10 | def __init__(self, graph_model_object, head_object, batch_size): 11 | """ 12 | Args: 13 | graph_model_object: instance of class of the graph model 14 | head_object: instance of class of the head model 15 | """ 16 | super().__init__() 17 | 18 | self.graph_model_object = graph_model_object 19 | self.head_object = head_object 20 | self.batch_size = batch_size 21 | 22 | self.check_compatibility() 23 | 24 | def check_compatibility(self): 25 | """ 26 | Compatibility checks between the defined graph neural network and the configured head module. 27 | If incompatibility is detected and assertion error is raised, if graph model metadata is missing, an attribute error is raised. 28 | """ 29 | if hasattr(self.head_object, "in_size") and hasattr( 30 | self.graph_model_object, "out_size" 31 | ): 32 | assert self.head_object.in_size == self.graph_model_object.out_size 33 | else: 34 | if hasattr(self.head_object, "in_size"): 35 | raise AttributeError( 36 | "out_size attribute in graph_object is not provided" 37 | ) 38 | else: 39 | raise AttributeError("in_size attribute in head_object is not provided") 40 | 41 | def forward(self, derived_graph): 42 | """ 43 | Forward pass of the model. 
44 | 45 | Args: 46 | derived_graph: dictionary of graph topology (key 'graph') and node features (key 'node_features') 47 | 48 | Returns: 49 | Graph label logits, with defined dimension out_size 50 | """ 51 | 52 | h = self.graph_model_object(derived_graph) 53 | 54 | n_samples = h.shape[0] / self.head_object.num_nodes 55 | if n_samples != int(n_samples): 56 | logging.error( 57 | "irregular number of nodes in the batch: " 58 | + str(h.shape[0]) 59 | + " vs. " 60 | + str(self.head_object.num_nodes) 61 | ) 62 | if n_samples != self.batch_size: 63 | logging.debug( 64 | "number of samples per batch is not as defined, this should happen maximum once per epoch." 65 | ) 66 | 67 | h = ( 68 | h.reshape(-1, self.head_object.num_nodes, h.shape[1]) 69 | .transpose(1, 0) 70 | .reshape(-1, h.shape[1]) 71 | .reshape((self.head_object.num_nodes, -1, h.shape[1])) 72 | ) 73 | 74 | # Note: 75 | # from dgl documentation: in a batched graph is obtained by concatenating the corresponding features 76 | # from all graphs in order. 77 | # https://docs.dgl.ai/en/0.8.x/guide/training-graph.html?highlight=batch 78 | 79 | h = self.head_object(h) 80 | 81 | return h 82 | -------------------------------------------------------------------------------- /mmmt/models/multimodal_mlp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | from fuse.utils.file_io.file_io import save_dataframe 4 | 5 | from torch.utils.data.dataloader import DataLoader 6 | from fuse.data.utils.collates import CollateDefault 7 | from fuse.utils.file_io.file_io import create_dir 8 | import mmmt 9 | 10 | 11 | class MultimodalMLP: 12 | """ 13 | Construction, training and inference of the multimodal graph model. 14 | """ 15 | 16 | def __init__(self, args_dict): 17 | """ 18 | Args: 19 | graph_train_dataset: training dataset for the multimodal graph model 20 | graph_validation_dataset: validation dataset for the multimodal graph model 21 | graph_model_configuration: user input, configuration of multimodal graph model 22 | """ 23 | 24 | self.train_config = copy.deepcopy(args_dict["step_args"]["training"]) 25 | 26 | train_dataset = args_dict["pipeline"][ 27 | args_dict["step_args"]["io"]["input_key"] 28 | ]["concatenated_training_dataset"] 29 | validation_dataset = args_dict["pipeline"][ 30 | args_dict["step_args"]["io"]["input_key"] 31 | ]["concatenated_validation_dataset"] 32 | self.test_dataset = args_dict["pipeline"][ 33 | args_dict["step_args"]["io"]["input_key"] 34 | ]["concatenated_test_dataset"] 35 | 36 | self.root_dir = args_dict["root_dir"] 37 | 38 | self.model_in_key = args_dict["step_args"]["io"]["input_key"] 39 | self.target_key = args_dict["step_args"]["io"]["target_key"] 40 | self.model_out_key = args_dict["step_args"]["io"]["prediction_key"] 41 | 42 | self.hidden_size = args_dict["step_args"]["model_config"]["hidden_size"] 43 | self.dropout = args_dict["step_args"]["model_config"]["dropout"] 44 | self.add_softmax = args_dict["step_args"]["model_config"].get( 45 | "add_softmax", True 46 | ) 47 | 48 | self.batch_size = args_dict["step_args"]["training"]["batch_size"] 49 | del self.train_config["batch_size"] 50 | 51 | self.obj_reg = args_dict["object_registry"] 52 | self.train_config["train_metrics"] = { 53 | args_dict["step_args"]["training"]["train_metrics"][ 54 | "key" 55 | ]: self.obj_reg.instance_object( 56 | args_dict["step_args"]["training"]["train_metrics"]["object"], 57 | args_dict["step_args"]["training"]["train_metrics"]["args"], 58 | ) 59 | } 60 | 
self.train_config["validation_metrics"] = { 61 | args_dict["step_args"]["training"]["validation_metrics"][ 62 | "key" 63 | ]: self.obj_reg.instance_object( 64 | args_dict["step_args"]["training"]["validation_metrics"]["object"], 65 | args_dict["step_args"]["training"]["validation_metrics"]["args"], 66 | ) 67 | } 68 | self.model_dir = os.path.join(self.root_dir, self.train_config["model_dir"]) 69 | self.train_config["model_dir"] = self.model_dir 70 | 71 | self.num_workers = args_dict["num_workers"] 72 | self.out_size = args_dict["step_args"]["model_config"]["num_classes"] 73 | 74 | self.test_results_filename = args_dict["step_args"]["testing"][ 75 | "test_results_filename" 76 | ] 77 | self.evaluation_directory = args_dict["step_args"]["testing"][ 78 | "evaluation_directory" 79 | ] 80 | 81 | self.checkpoint_filename = "best_epoch.ckpt" 82 | 83 | self.mb_trainer = None 84 | 85 | # Configure MLP Module 86 | num_feat = train_dataset[0][args_dict["step_args"]["io"]["input_key"]].shape[0] 87 | 88 | mlp_model = mmmt.models.head.mlp.MLP( 89 | num_feat, 90 | 1, 91 | self.hidden_size, 92 | self.out_size, 93 | self.dropout, 94 | self.add_softmax, 95 | ) 96 | 97 | self.build_model(mlp_model) 98 | 99 | # Configure data loaders 100 | train_dataloader = DataLoader( 101 | dataset=train_dataset, 102 | batch_size=self.batch_size, 103 | collate_fn=CollateDefault(), 104 | num_workers=self.num_workers, 105 | ) 106 | 107 | validation_dataloader = DataLoader( 108 | dataset=validation_dataset, 109 | batch_size=self.batch_size, 110 | collate_fn=CollateDefault(), 111 | num_workers=self.num_workers, 112 | ) 113 | 114 | self.train_dataloader = train_dataloader 115 | self.validation_dataloader = validation_dataloader 116 | 117 | def build_model(self, head_model): 118 | """ 119 | Multimodal MLP model construction from the MLP model. 
120 | 121 | Args: 122 | head_model: head model, where signal is aggragated up to a target 123 | """ 124 | 125 | MB = head_model 126 | 127 | self.mb_trainer = mmmt.data.representation.ModelBuilderTrainer( 128 | MB, 129 | self.model_in_key, 130 | self.model_out_key, 131 | self.target_key, 132 | ) 133 | 134 | def __call__(self): 135 | self.train() 136 | self.test() 137 | 138 | def train( 139 | self, 140 | ): 141 | """ 142 | Multimodal MLP model training 143 | 144 | Returns: 145 | the model with the weights from the best epoch, measured using the validation set 146 | """ 147 | 148 | self.mb_trainer.set_train_config(**self.train_config) 149 | 150 | self.mb_trainer.fit(self.train_dataloader, self.validation_dataloader) 151 | 152 | def test( 153 | self, 154 | ): 155 | """ 156 | Apply model and extract both output and ground-truth labels 157 | 158 | """ 159 | 160 | # run inference and eval 161 | eval_dir = os.path.join(self.root_dir, self.evaluation_directory) 162 | create_dir(eval_dir) 163 | 164 | output_filename = os.path.join(eval_dir, self.test_results_filename) 165 | 166 | test_dataloader = DataLoader( 167 | dataset=self.test_dataset, 168 | batch_size=self.batch_size, 169 | collate_fn=CollateDefault(), 170 | num_workers=self.num_workers, 171 | ) 172 | test_df = self.mb_trainer.predict( 173 | test_dataloader, 174 | self.model_dir, 175 | os.path.join(self.model_dir, self.checkpoint_filename), 176 | [self.target_key, self.model_out_key], 177 | ) 178 | test_df = test_df.rename( 179 | columns={self.target_key: "target", self.model_out_key: "pred"} 180 | ) 181 | save_dataframe(test_df, output_filename) 182 | 183 | return test_df 184 | -------------------------------------------------------------------------------- /mmmt/pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/multimodal-models-toolkit/256f12b369769003986b68044086f4898cb096bf/mmmt/pipeline/__init__.py -------------------------------------------------------------------------------- /mmmt/pipeline/defaults.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | - fuse_object: "" 3 | 4 | # io: 5 | # input_key: None 6 | # output_key: "dataset_pipeline" 7 | - object: "get_splits_str_ids" 8 | io: 9 | input_key: null 10 | output_key: "data_splits" 11 | 12 | cache: 13 | num_workers: 1 14 | restart_cache: True 15 | root_dir: "path/to/cache" 16 | 17 | mlflow: 18 | MLFLOW_TRACKING_URI: null 19 | MLFLOW_EXPERIMENT_NAME: null 20 | 21 | modality_encoding_strategy: 22 | - object: "ModalityEncoding" 23 | 24 | 25 | fusion_strategy: 26 | - object: "EncodedUnimodalToConcept" # early or late 27 | args: 28 | use_autoencoders: True 29 | add_feature_names: False 30 | encoding_layers: 31 | - 32 32 | - &n_layers 16 33 | use_pretrained: True 34 | batch_size: 3 35 | training: 36 | model_dir: "model_concept" 37 | pl_trainer_num_epochs: 1 38 | pl_trainer_accelerator: "cpu" 39 | io: 40 | concept_encoder_model_key: "concept_encoder_model" 41 | output_key: "data.input.concatenated" 42 | - object: "ConceptToGraph" 43 | args: 44 | module_identifier: &graph_module "mplex" 45 | thresh_q: 0.95 46 | io: 47 | concept_encoder_model_key: "concept_encoder_model" 48 | fused_dataset_key: "fused_dataset" 49 | input_key: "data.input.concatenated" 50 | output_key: "data.derived_graph" 51 | 52 | task_strategy: 53 | - object: "MultimodalGraphModel" 54 | args: 55 | io: 56 | fused_dataset_key: "fused_dataset" 57 | input_key: "data.derived_graph" 58 
| target_key: &target "data.gt.gt_global.task_1_label" 59 | prediction_key: &prediction "model.out" 60 | 61 | 62 | model_config: 63 | graph_model: 64 | module_identifier: *graph_module 65 | n_layers: *n_layers 66 | node_emb_dim: 1 # really needed? 67 | head_model: 68 | head_hidden_size: 69 | - 100 70 | - 20 71 | dropout: 0.5 72 | add_softmax: True 73 | num_classes: 2 74 | 75 | training: 76 | model_dir: "model_mplex" 77 | batch_size: 3 78 | best_epoch_source: 79 | mode: "max" 80 | monitor: "validation.metrics.auc" 81 | train_metrics: 82 | key: "auc" 83 | object: "MetricAUCROC" 84 | args: 85 | pred: *prediction 86 | target: *target 87 | validation_metrics: 88 | key: "auc" 89 | object: "MetricAUCROC" 90 | args: 91 | pred: *prediction 92 | target: *target 93 | pl_trainer_num_epochs: 1 94 | pl_trainer_accelerator: "cpu" 95 | pl_trainer_devices: 1 96 | 97 | testing: 98 | test_results_filename: &test_results_filename "test_results.pickle" 99 | evaluation_directory: &evaluation_directory "eval" 100 | 101 | - object: "Eval" 102 | args: 103 | test_results_filename: *test_results_filename 104 | evaluation_directory: *evaluation_directory 105 | -------------------------------------------------------------------------------- /mmmt/pipeline/object_registry.py: -------------------------------------------------------------------------------- 1 | import fuse.data.ops.ops_cast 2 | import fuse.data.ops.ops_common 3 | import fuse.eval.metrics.classification.metrics_classification_common 4 | 5 | import mmmt.models.multimodal_graph_model 6 | import mmmt.models.multimodal_mlp 7 | import mmmt.data.operators.op_forwardpass 8 | import mmmt.data.representation.encoded_unimodal_to_concept 9 | import mmmt.data.graph.concept_to_graph 10 | import mmmt.data.representation.modality_encoding 11 | import mmmt.data.representation.encoded_unimodal_to_concept 12 | import mmmt.data.graph.visualization 13 | 14 | 15 | class ObjectRegistry: 16 | """ 17 | Registry of objects commonly used in an MMMT pipeline. 18 | """ 19 | 20 | def __init__(self, specific_objects=None): 21 | """ 22 | 23 | Args: 24 | specific_objects (dict, optional): Specific objects needed by the pipeline and not contaiined in the default objects. Defaults to None. 
25 | """ 26 | 27 | self.object_dict = { 28 | "OpToTensor": { 29 | "object": fuse.data.ops.ops_cast.OpToTensor, 30 | "need_cache": True, 31 | }, 32 | "OpLambda": { 33 | "object": fuse.data.ops.ops_common.OpLambda, 34 | "need_cache": True, 35 | }, 36 | "MetricAUCROC": { 37 | "object": fuse.eval.metrics.classification.metrics_classification_common.MetricAUCROC, 38 | "need_cache": True, 39 | }, 40 | "MetricROCCurve": { 41 | "object": fuse.eval.metrics.classification.metrics_classification_common.MetricROCCurve, 42 | "need_cache": True, 43 | }, 44 | "MetricAccuracy": { 45 | "object": fuse.eval.metrics.classification.metrics_classification_common.MetricAccuracy, 46 | "need_cache": True, 47 | }, 48 | "ForwardPass": { 49 | "object": mmmt.data.operators.op_forwardpass.OpForwardPass, 50 | "need_cache": True, 51 | }, 52 | "MultimodalGraphModel": { 53 | "object": mmmt.models.multimodal_graph_model.MultimodalGraphModel, 54 | "need_cache": True, 55 | "need_pipeline": True, 56 | "need_object_registry": True, 57 | "need_call_method": True, 58 | }, 59 | "MultimodalMLP": { 60 | "object": mmmt.models.multimodal_mlp.MultimodalMLP, 61 | "need_cache": True, 62 | "need_pipeline": True, 63 | "need_object_registry": True, 64 | "need_call_method": True, 65 | }, 66 | "ConceptToGraph": { 67 | "object": mmmt.data.graph.concept_to_graph.ConceptToGraph, 68 | "need_cache": True, 69 | "need_pipeline": True, 70 | "need_call_method": True, 71 | }, 72 | "ModalityEncoding": { 73 | "object": mmmt.data.representation.modality_encoding.ModalityEncoding, 74 | "need_cache": True, 75 | "need_pipeline": True, 76 | "need_call_method": True, 77 | "need_object_registry": True, 78 | }, 79 | "EncodedUnimodalToConcept": { 80 | "object": mmmt.data.representation.encoded_unimodal_to_concept.EncodedUnimodalToConcept, 81 | "need_cache": True, 82 | "need_pipeline": True, 83 | "need_call_method": True, 84 | }, 85 | "GraphVisualization": { 86 | "object": mmmt.data.graph.visualization.GraphVisualization, 87 | "need_cache": True, 88 | "need_pipeline": True, 89 | "need_call_method": True, 90 | }, 91 | } 92 | 93 | if specific_objects: 94 | self.object_dict.update(specific_objects) 95 | 96 | def instance_object(self, op_key, op_arguments): 97 | """ 98 | Instanciate the selected object. 
99 | 100 | Args: 101 | op_key (str): object identifier 102 | op_arguments (dict): arguments needed by the selected object 103 | 104 | Returns: 105 | instanciated object 106 | """ 107 | 108 | if op_arguments is None: 109 | return self.object_dict[op_key]["object"]() 110 | elif any(isinstance(i, dict) for i in op_arguments.values()): 111 | return self.object_dict[op_key]["object"](op_arguments) 112 | else: 113 | return self.object_dict[op_key]["object"](**op_arguments) 114 | -------------------------------------------------------------------------------- /mmmt/pipeline/pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | from copy import copy 5 | 6 | import mlflow 7 | import yaml 8 | 9 | from mmmt.pipeline.object_registry import ObjectRegistry 10 | 11 | # add the option to extract values from env varaibles 12 | env_pattern = re.compile(r".*?\${(.*?)}.*?") 13 | 14 | 15 | def env_constructor(loader, node): 16 | value = loader.construct_scalar(node) 17 | for group in env_pattern.findall(value): 18 | if group not in os.environ or os.environ[group] is None: 19 | raise Exception(f"Error: missing env var {group}") 20 | print(f"Configuration file: env variable read {group}={os.environ.get(group)}") 21 | value = value.replace(f"${{{group}}}", os.environ.get(group)) 22 | return value 23 | 24 | 25 | yaml.add_implicit_resolver("!pathex", env_pattern) 26 | yaml.add_constructor("!pathex", env_constructor) 27 | 28 | logging.basicConfig(level=logging.INFO) 29 | 30 | 31 | class MMMTPipeline: 32 | """ 33 | MMMT pipeline, used to interpret the configuration set by the user. Values not defined by the user are taken from the defaults.yaml. 34 | 35 | A typical pipeline involves the following steps (as described in D'Souza et al. [1]_): 36 | 37 | 1. Each modality is encoded from its raw representation into feature vectors using 38 | the relevant key in `encoding_strategy`. It allows to use a pretrained model, train a 39 | modality-specific autoencoder, a combination of model + autoencoder or none (raw 40 | representation needs to be a feature vector) 41 | 42 | 2. All the encoded modalities are combined together with a `concept_encoder` that is a 43 | simple autoencoder that projects the modalities into a small latent space (autoencoder 44 | bottleneck). 45 | 46 | 3. The unimodal embedded features are organized as a graph, where the links are obtained 47 | through detecting saliency of the features on each individual dimension of the concept 48 | embedding space (one edge type per dimension). 49 | Like this, the `data.base_graph` key will contain, for each sample, a 50 | representation of the unimodal embedded features. 51 | 52 | 4. If the graph_module to be used is supporting multiplexed graphs, each of the edge types 53 | defines a graph layer, and nodes will be replicated across all layers and interlayer links 54 | are added into the `data.derived_graph` key. 55 | 56 | 5. The samples transformed in graphs are used to train, validate and test a GNN as specified 57 | in the configuration yaml file. 58 | 59 | 60 | .. rubric:: References 61 | .. [1] D'Souza, Niharika, et al. "Fusing Modalities by Multiplexed Graph Neural Networks for Outcome Prediction in Tuberculosis." International Conference on Medical Image Computing and Computer-Assisted Intervention. 2022. 
62 | """ 63 | 64 | def __init__( 65 | self, 66 | user_configs_pipeline_path, 67 | specific_objects, 68 | defaults="mmmt/pipeline/defaults.yaml", 69 | ): 70 | """ 71 | 72 | Args: 73 | user_configs_pipeline_path (str): path to case-specific configuration 74 | specific_objects (dict): dictionary of case-specific objects 75 | """ 76 | 77 | self.obj_reg = ObjectRegistry(specific_objects) 78 | 79 | defaults_mmmt_pipeline = yaml.full_load(open(defaults, "r")) 80 | 81 | user_configs = yaml.full_load(open(user_configs_pipeline_path, "r")) 82 | 83 | self.mmmt_pipeline_config = self.update_config( 84 | defaults_mmmt_pipeline, user_configs 85 | ) 86 | 87 | self.mlflow_configurator() 88 | 89 | self.cache = self.mmmt_pipeline_config["cache"] 90 | 91 | self.pipeline = {} 92 | 93 | logging.info("MMMT pipeline initialized") 94 | 95 | def update_config(self, to_be_updated, user_update): 96 | """ 97 | Recursive function to update the nested configuration dictionary. 98 | Args: 99 | to_be_updated (dict): dictionary to be updated 100 | user_update (dict): dictionary containing the updated values 101 | 102 | Returns: 103 | updated dictionary 104 | """ 105 | for k, v in user_update.items(): 106 | if isinstance(v, dict): 107 | to_be_updated[k] = self.update_config(to_be_updated.get(k, {}), v) 108 | elif isinstance(v, list): 109 | if k not in to_be_updated: 110 | to_be_updated[k] = copy(v) 111 | else: 112 | for ind, elem in enumerate(v): 113 | if isinstance(elem, dict): 114 | to_be_updated[k][ind] = self.update_config( 115 | to_be_updated[k][ind], elem 116 | ) 117 | else: 118 | if ind == len(to_be_updated[k]): 119 | to_be_updated[k].append(elem) 120 | else: 121 | to_be_updated[k][ind] = elem 122 | else: 123 | to_be_updated[k] = v 124 | return to_be_updated 125 | 126 | def mlflow_configurator( 127 | self, 128 | ): 129 | for mlflow_env_key in self.mmmt_pipeline_config["mlflow"]: 130 | if ( 131 | mlflow_env_key not in os.environ 132 | and self.mmmt_pipeline_config["mlflow"][mlflow_env_key] 133 | ): 134 | os.environ[mlflow_env_key] = self.mmmt_pipeline_config["mlflow"][ 135 | mlflow_env_key 136 | ] 137 | 138 | def run_pipeline(self, debugging=False): 139 | """ 140 | Run each pipeline step. 141 | Args: 142 | debugging (bool, optional): boolean to control the number of samples. If it is boolean only few samples are used. Defaults to False. 143 | """ 144 | with mlflow.start_run(nested=True): 145 | for phase in self.mmmt_pipeline_config: 146 | if phase in ["cache", "mlflow"]: 147 | continue 148 | for step in self.mmmt_pipeline_config[phase]: 149 | 150 | self.process_step(step) 151 | 152 | logging.info(phase + " processed") 153 | 154 | # for debugging 155 | if phase == "data" and debugging: 156 | self.pipeline["data_splits"]["train_ids"] = self.pipeline[ 157 | "data_splits" 158 | ]["train_ids"][:3] 159 | self.pipeline["data_splits"]["val_ids"] = self.pipeline[ 160 | "data_splits" 161 | ]["val_ids"][:3] 162 | self.pipeline["data_splits"]["test_ids"] = self.pipeline[ 163 | "data_splits" 164 | ]["test_ids"][:3] 165 | 166 | def process_step(self, step): 167 | """ 168 | Process a step taking the configuration, in particular it expect an object and its arguments. 169 | Args: 170 | step (dict): Configuration of the step to initialize and execute. 
171 | """ 172 | if "fuse_object" in step: 173 | if "fuse_pipeline" not in self.pipeline: 174 | self.pipeline["fuse_pipeline"] = {} 175 | if not self.pipeline["fuse_pipeline"]: 176 | self.pipeline["fuse_pipeline"] = self.obj_reg.instance_object( 177 | step["fuse_object"], step["args"] 178 | ) 179 | else: 180 | self.pipeline["fuse_pipeline"].extend( 181 | (self.obj_reg.object_dict[step["fuse_object"]], step["args"]) 182 | ) 183 | else: 184 | need_cache = self.obj_reg.object_dict[step["object"]].get( 185 | "need_cache", False 186 | ) 187 | need_pipeline = self.obj_reg.object_dict[step["object"]].get( 188 | "need_pipeline", False 189 | ) 190 | need_object_registry = self.obj_reg.object_dict[step["object"]].get( 191 | "need_object_registry", False 192 | ) 193 | need_call_method = self.obj_reg.object_dict[step["object"]].get( 194 | "need_call_method", False 195 | ) 196 | 197 | if need_cache or need_pipeline or need_object_registry: 198 | step_args = {"step_args": step["args"]} 199 | 200 | if need_cache: 201 | step_args.update(self.cache) 202 | if need_pipeline: 203 | step_args["pipeline"] = self.pipeline 204 | if need_object_registry: 205 | step_args["object_registry"] = self.obj_reg 206 | else: 207 | step_args = step["args"] 208 | 209 | if "io" in step: 210 | 211 | self.pipeline[step["io"]["output_key"]] = self.obj_reg.instance_object( 212 | step["object"], step_args 213 | ) 214 | if need_call_method: 215 | self.pipeline[step["io"]["output_key"]]() 216 | else: 217 | self.pipeline[step["object"]] = self.obj_reg.instance_object( 218 | step["object"], step_args 219 | ) 220 | if need_call_method: 221 | self.pipeline[step["object"]]() 222 | -------------------------------------------------------------------------------- /mmmt/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/multimodal-models-toolkit/256f12b369769003986b68044086f4898cb096bf/mmmt/py.typed -------------------------------------------------------------------------------- /mmmt_examples/README.md: -------------------------------------------------------------------------------- 1 | # mmmt examples 2 | 3 | In [mmmt_examples](mmmt_examples/README.md) we keep a list of examples of MMMT applications. 4 | 5 | The goal of these scripts is to showcase how to apply MMMT to selected datasets. 6 | 7 | ## Structure in mmmt_examples 8 | Each selected dataset has a folder in `mmmt_examples` containing 9 | 1. pipeline 10 | 11 | a runnable script going from data loading to model training and evaluation. 12 | 13 | *It requires installation of MMMT, see [MMMT readme](./../README.md)* 14 | 15 | 2. scripts specific to the selected dataset 16 | 17 | in particular, fuse operators for dataset loading and evaluation 18 | 19 | In addition, for the KNIGHT example we also provide a **demonstration notebook**: a runnable notebook with the same functionality as the pipeline, plus visualizations of the generated sample graphs. 20 | 21 | 22 | ## Datasets 23 | The code needed to download the selected datasets is not part of mmmt_examples, but instructions are given in the pipeline and demonstration notebook.
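For orientation, once a dataset has been downloaded, running an example boils down to a handful of lines. The sketch below mirrors `mmmt_examples/knight/full_mmmt_pipeline.py`; the configuration path is a placeholder, and `specific_objects` is shown empty, whereas a real example fills it with its dataset-specific fuse operators, split loader and evaluation function.

```python
# Minimal sketch of launching an MMMT example pipeline (paths are placeholders).
from mmmt.pipeline.pipeline import MMMTPipeline

# A real example registers its dataset-specific objects here,
# e.g. its fuse static pipeline, split loader and evaluation function.
specific_objects = {}

mmmtp = MMMTPipeline(
    "path/to/mmmt_pipeline_config.yaml",
    specific_objects,
    defaults="path/to/mmmt_pipeline_config.yaml",
)
mmmtp.run_pipeline(debugging=True)  # debugging=True restricts each split to 3 samples
```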
24 | 25 | ### Datasets used so far in [mmmt_examples](mmmt_examples/README.md) 26 | | Dataset name | Short description | Link to dataset | 27 | |--------------|-----------------------------------|-----------------------------------| 28 | | KNIGHT | Kidney clinical Notes and Imaging to Guide and Help personalize Treatment and biomarkers discovery | https://research.ibm.com/haifa/Workshops/KNIGHT/ | 29 | 30 | 31 | ## Pipeline configuration 32 | 33 | Pipeline configuration is specified in the configuration yaml file. 34 | 35 | In the current examples we configure: 36 | 1. How to process the cache (`cache`) 37 | 2. How to read the data (`data`) 38 | 3. How to encode each modality independently (`modality_encoding_strategy`) 39 | 4. How to fuse the encoded modalities (`fusion_strategy`) 40 | 5. How to solve the task (`task_strategy`) 41 | 42 | 43 | ### As a reference we list below configuration examples for the 6 graph modules currently available in MMMT for the solution of the task 44 | 45 | The following arguments belong to `args.model_config.graph_model` of the object `MultimodalGraphModel`. 46 | 47 | 1. Multiplex GIN 48 | 49 | ```yaml 50 | module_identifier: "mplex" 51 | n_layers: *n_layers 52 | node_emb_dim: 1 53 | ``` 54 | 55 | 2. Multiplex GCN 56 | 57 | ```yaml 58 | module_identifier: "mplex-prop" 59 | n_layers: *n_layers 60 | gl_hidden_size: 61 | - 2 62 | - 2 63 | node_emb_dim: 1 64 | ``` 65 | 66 | 3. mGNN 67 | 68 | ```yaml 69 | module_identifier: "mgnn" 70 | n_layers: *n_layers 71 | gl_hidden_size: 72 | - 2 73 | num_att_heads: 4 74 | node_emb_dim: 1 75 | ``` 76 | 77 | 4. MultiBehavioral GNN 78 | 79 | 80 | ```yaml 81 | module_identifier: "multibehav" 82 | n_layers: *n_layers 83 | gl_hidden_size: 84 | - 2 85 | node_emb_dim: 1 86 | ``` 87 | 88 | 89 | 90 | 5. GCN 91 | 92 | 93 | ```yaml 94 | module_identifier: "gcn" 95 | n_layers: *n_layers 96 | gl_hidden_size: 97 | - 2 98 | node_emb_dim: 1 99 | ``` 100 | 101 | 102 | 6. 
R-GCN 103 | 104 | ```yaml 105 | module_identifier: "rgcn" 106 | n_layers: *n_layers 107 | gl_hidden_size: 108 | - 2 109 | num_bases: 8 110 | node_emb_dim: 1 111 | ``` 112 | -------------------------------------------------------------------------------- /mmmt_examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/multimodal-models-toolkit/256f12b369769003986b68044086f4898cb096bf/mmmt_examples/__init__.py -------------------------------------------------------------------------------- /mmmt_examples/knight/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/multimodal-models-toolkit/256f12b369769003986b68044086f4898cb096bf/mmmt_examples/knight/__init__.py -------------------------------------------------------------------------------- /mmmt_examples/knight/full_mmmt_pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from mmmt.pipeline.pipeline import MMMTPipeline 3 | 4 | import fuseimg.datasets.knight 5 | import knight_eval 6 | import get_splits 7 | 8 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | 12 | # pre-requisite for KNIGHT data 13 | # git clone https://github.com/neheller/KNIGHT.git 14 | # python KNIGHT/knight/scripts/get_imaging.py 15 | # mv KNIGHT/knight/data downloads/knight_data 16 | 17 | 18 | 19 | if __name__ == "__main__": 20 | 21 | parser = ArgumentParser( 22 | description="Multimodal fusion Experiments", 23 | formatter_class=ArgumentDefaultsHelpFormatter, 24 | ) 25 | 26 | parser.add_argument( 27 | "--path_to_config", 28 | dest="path_to_config", 29 | type=str, 30 | default="mmmt_examples/knight/mmmt_pipeline_config_demonstration.yaml", 31 | help="Path to pipeline configuration", 32 | ) 33 | 34 | # 2 yaml files are provided in this example: 35 | # - mmmt_pipeline_config.yaml for training a model using all the data and GPUs 36 | # - mmmt_pipeline_config_demonstration.yaml for training a model with a small subset of the data and CPUs 37 | 38 | parser_args = parser.parse_args() 39 | 40 | mmmt_pipeline_config_path = parser_args.path_to_config 41 | 42 | # Specify specific objects needed for this particular example 43 | specific_objects = { 44 | "KNIGHT.static_pipeline": { 45 | "object": fuseimg.datasets.knight.KNIGHT.static_pipeline, 46 | }, 47 | "Eval": { 48 | "object": knight_eval.knight_eval, 49 | "need_cache": True, 50 | }, 51 | "get_splits_str_ids": { 52 | "object": get_splits.get_splits_str_ids, 53 | }, 54 | } 55 | 56 | # Initialize the pipeline 57 | MMMTP = MMMTPipeline( 58 | mmmt_pipeline_config_path, specific_objects, defaults=mmmt_pipeline_config_path 59 | ) 60 | 61 | # Run the pipeline - option debugging=True will only use the first 3 samples for each dataset 62 | MMMTP.run_pipeline(debugging=True) 63 | 64 | logging.info("MMMT pipeline completed") 65 | -------------------------------------------------------------------------------- /mmmt_examples/knight/get_splits.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import random 3 | 4 | 5 | def get_splits(pickle_path="splits_final.pkl", split_id=0): 6 | splits = pd.read_pickle(pickle_path) 7 | 8 | train_val_range = [] 9 | for sample_id_str in splits[split_id]["train"]: 10 | sample_id = int(sample_id_str.split("_")[-1]) 11 | 
train_val_range.append(sample_id) 12 | 13 | val_length = int(0.2 * len(train_val_range)) 14 | val_ids = random.choices(train_val_range, k=val_length) 15 | train_ids = [v for v in train_val_range if v not in val_ids] 16 | 17 | test_ids = [] 18 | for sample_id_str in splits[split_id]["val"]: 19 | sample_id = int(sample_id_str.split("_")[-1]) 20 | test_ids.append(sample_id) 21 | 22 | data_splits = {"train_ids": train_ids, "val_ids": val_ids, "test_ids": test_ids} 23 | 24 | return data_splits 25 | 26 | 27 | def get_splits_str_ids(pickle_path="splits_final.pkl", split_id=0): 28 | splits = pd.read_pickle(pickle_path) 29 | 30 | train_val_range = [] 31 | for sample_id_str in splits[split_id]["train"]: 32 | train_val_range.append(sample_id_str) 33 | 34 | val_length = int(0.2 * len(train_val_range)) 35 | val_ids = random.choices(train_val_range, k=val_length) 36 | train_ids = [v for v in train_val_range if v not in val_ids] 37 | 38 | test_ids = [] 39 | for sample_id_str in splits[split_id]["val"]: 40 | test_ids.append(sample_id_str) 41 | 42 | data_splits = {"train_ids": train_ids, "val_ids": val_ids, "test_ids": test_ids} 43 | 44 | return data_splits 45 | -------------------------------------------------------------------------------- /mmmt_examples/knight/knight_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from fuse.eval.evaluator import EvaluatorDefault 4 | from fuse.eval.metrics.classification.metrics_classification_common import ( 5 | MetricAUCROC, 6 | MetricROCCurve, 7 | ) 8 | from fuse.eval.metrics.metrics_common import CI 9 | 10 | 11 | def knight_eval(args_dict: dict) -> dict: 12 | """ 13 | Evaluation of the binary task in knight 14 | Expect as input either a dataframe or a path to a dataframe that includes 3 columns: 15 | 1. "id" - unique identifier per sample - can be a running index 16 | 2. "pred" - prediction scores per sample 17 | 3. 
"target" - the ground truth label 18 | """ 19 | 20 | evaluation_directory = os.path.join( 21 | args_dict["root_dir"], args_dict["step_args"]["evaluation_directory"] 22 | ) 23 | test_results_filename = os.path.join( 24 | evaluation_directory, args_dict["step_args"]["test_results_filename"] 25 | ) 26 | 27 | metrics = { 28 | "auc": CI( 29 | MetricAUCROC( 30 | pred="pred", 31 | target="target", 32 | ), 33 | stratum="target", 34 | ), 35 | "roc_curve": MetricROCCurve( 36 | pred="pred", 37 | target="target", 38 | output_filename=os.path.join(evaluation_directory, "roc.png"), 39 | ), 40 | } 41 | 42 | evaluator = EvaluatorDefault() 43 | return evaluator.eval( 44 | ids=None, 45 | data=test_results_filename, 46 | metrics=metrics, 47 | output_dir=evaluation_directory, 48 | ) 49 | -------------------------------------------------------------------------------- /mmmt_examples/knight/mmmt_pipeline_config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | - fuse_object: "KNIGHT.static_pipeline" 3 | args: 4 | data_path: "path-to-knight-data" 5 | resize_to: 6 | - 70 7 | - 256 8 | - 256 9 | 10 | - object: "get_splits_str_ids" 11 | io: 12 | input_key: null 13 | output_key: "data_splits" 14 | args: 15 | pickle_path: "path-to-splits-file" 16 | 17 | 18 | mlflow: 19 | MLFLOW_TRACKING_URI: "path-to-track-with-mlflow" 20 | MLFLOW_EXPERIMENT_NAME: "experiment-name" 21 | 22 | cache: 23 | num_workers: 1 24 | restart_cache: True 25 | root_dir: "_examples/knight" 26 | 27 | modality_encoding_strategy: 28 | - object: "ModalityEncoding" 29 | args: 30 | data.input.clinical.all: 31 | model: null 32 | use_autoencoder: False 33 | output_key: "data.input.encoded_clinical" 34 | data.input.img: 35 | model_path: "path-to-model" 36 | add_feature_names: &add_feature_names False 37 | dimensions: 4 38 | output_key: "data.input.encoded_img" 39 | use_autoencoder: True 40 | encoding_layers: 41 | - 128 42 | - 64 43 | batch_size: 3 44 | training: 45 | pl_trainer_num_epochs: 1 46 | pl_trainer_accelerator: "cpu" 47 | 48 | 49 | fusion_strategy: 50 | - object: "EncodedUnimodalToConcept" 51 | args: 52 | use_autoencoders: True 53 | add_feature_names: *add_feature_names 54 | encoding_layers: 55 | - 32 56 | - &n_layers 16 57 | batch_size: 3 58 | training: 59 | pl_trainer_num_epochs: 1 60 | pl_trainer_accelerator: "cpu" 61 | io: 62 | input_keys: 63 | - "data.input.encoded_clinical" 64 | - "data.input.encoded_img" 65 | 66 | - object: "ConceptToGraph" 67 | args: 68 | module_identifier: &graph_module "mplex" 69 | thresh_q: 0.95 70 | 71 | 72 | task_strategy: 73 | - object: "MultimodalGraphModel" 74 | args: 75 | model_config: 76 | graph_model: 77 | module_identifier: *graph_module 78 | n_layers: *n_layers 79 | node_emb_dim: 1 80 | head_model: 81 | head_hidden_size: 82 | - 100 83 | - 20 84 | dropout: 0.5 85 | num_classes: 2 86 | 87 | training: 88 | batch_size: 3 89 | pl_trainer_num_epochs: 1 90 | pl_trainer_accelerator: "cpu" 91 | pl_trainer_devices: 1 92 | -------------------------------------------------------------------------------- /mmmt_examples/knight/mmmt_pipeline_config_demonstration.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | - fuse_object: "KNIGHT.static_pipeline" 3 | args: 4 | data_path: "path-to-knight-data" 5 | resize_to: 6 | - 70 7 | - 256 8 | - 256 9 | 10 | - object: "get_splits_str_ids" 11 | io: 12 | input_key: null 13 | output_key: "data_splits" 14 | args: 15 | pickle_path: "path-to-splits-file" 16 | 17 | mlflow: 18 | 
MLFLOW_TRACKING_URI: "path-to-track-with-mlflow" 19 | MLFLOW_EXPERIMENT_NAME: "experiment-name" 20 | 21 | cache: 22 | num_workers: 1 23 | restart_cache: True 24 | root_dir: "_examples/knight" 25 | 26 | modality_encoding_strategy: 27 | - object: "ModalityEncoding" 28 | args: 29 | data.input.clinical.all: 30 | model: null 31 | use_autoencoder: False 32 | output_key: "data.input.encoded_clinical" 33 | data.input.img: 34 | model_path: "path-to-model" 35 | add_feature_names: &add_feature_names True 36 | dimensions: 4 37 | output_key: "data.input.encoded_img" 38 | use_autoencoder: True 39 | encoding_layers: 40 | - 64 41 | - 32 42 | batch_size: 3 43 | training: 44 | pl_trainer_num_epochs: 1 45 | pl_trainer_accelerator: "cpu" 46 | 47 | 48 | fusion_strategy: 49 | - object: "EncodedUnimodalToConcept" 50 | args: 51 | use_autoencoders: True 52 | add_feature_names: *add_feature_names 53 | encoding_layers: 54 | - 32 55 | - &n_layers 16 56 | batch_size: 3 57 | training: 58 | pl_trainer_num_epochs: 1 59 | pl_trainer_accelerator: "cpu" 60 | io: 61 | input_keys: 62 | - "data.input.encoded_clinical" 63 | - "data.input.encoded_img" 64 | 65 | - object: "ConceptToGraph" 66 | args: 67 | module_identifier: &graph_module "mplex" 68 | thresh_q: 0.95 69 | 70 | 71 | task_strategy: 72 | - object: "MultimodalGraphModel" 73 | args: 74 | model_config: 75 | graph_model: 76 | module_identifier: *graph_module 77 | n_layers: *n_layers 78 | node_emb_dim: 1 79 | head_model: 80 | head_hidden_size: 81 | - 50 82 | - 20 83 | dropout: 0.5 84 | num_classes: 2 85 | 86 | training: 87 | batch_size: 3 88 | pl_trainer_num_epochs: 1 89 | pl_trainer_accelerator: "cpu" 90 | pl_trainer_devices: 1 91 | -------------------------------------------------------------------------------- /mmmt_examples/knight/mmmt_pipeline_config_demonstration_mlp.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | - fuse_object: "KNIGHT.static_pipeline" 3 | args: 4 | data_path: "path-to-knight-data" 5 | resize_to: 6 | - 70 7 | - 256 8 | - 256 9 | 10 | - object: "get_splits_str_ids" 11 | io: 12 | input_key: null 13 | output_key: "data_splits" 14 | args: 15 | pickle_path: "path-to-splits-file" 16 | 17 | mlflow: 18 | MLFLOW_TRACKING_URI: "path-to-track-with-mlflow" 19 | MLFLOW_EXPERIMENT_NAME: "experiment-name" 20 | 21 | cache: 22 | num_workers: 1 23 | restart_cache: True 24 | root_dir: "_examples/knight" 25 | 26 | modality_encoding_strategy: 27 | - object: "ModalityEncoding" 28 | args: 29 | data.input.clinical.all: 30 | model: null 31 | use_autoencoder: False 32 | output_key: "data.input.encoded_clinical" 33 | data.input.img: 34 | model_path: "path-to-model" 35 | add_feature_names: &add_feature_names True 36 | dimensions: 4 37 | output_key: "data.input.encoded_img" 38 | use_autoencoder: True 39 | encoding_layers: 40 | - 32 41 | batch_size: 3 42 | training: 43 | pl_trainer_num_epochs: 1 44 | pl_trainer_accelerator: "cpu" 45 | 46 | 47 | fusion_strategy: 48 | - object: "EncodedUnimodalToConcept" 49 | args: 50 | use_autoencoders: True 51 | add_feature_names: *add_feature_names 52 | encoding_layers: 53 | - 32 54 | - 16 55 | batch_size: 3 56 | training: 57 | pl_trainer_num_epochs: 1 58 | pl_trainer_accelerator: "cpu" 59 | io: 60 | input_keys: 61 | - "data.input.encoded_clinical" 62 | - "data.input.encoded_img" 63 | concept_encoder_model_key: "concept_encoder_model" 64 | output_key: "data.input.multimodal" 65 | 66 | 67 | task_strategy: 68 | - object: "MultimodalMLP" 69 | args: 70 | io: 71 | input_key: "data.input.multimodal" 
72 | target_key: &target "data.gt.gt_global.task_1_label" 73 | prediction_key: &prediction "model.out" 74 | 75 | model_config: 76 | hidden_size: 77 | - 100 78 | - 20 79 | dropout: 0.5 80 | add_softmax: True 81 | num_classes: 2 82 | 83 | training: 84 | model_dir: "model_mlp" 85 | batch_size: 3 86 | best_epoch_source: 87 | mode: "max" 88 | monitor: "validation.metrics.auc" 89 | train_metrics: 90 | key: "auc" 91 | object: "MetricAUCROC" 92 | args: 93 | pred: *prediction 94 | target: *target 95 | validation_metrics: 96 | key: "auc" 97 | object: "MetricAUCROC" 98 | args: 99 | pred: *prediction 100 | target: *target 101 | pl_trainer_num_epochs: 1 102 | pl_trainer_accelerator: "cpu" 103 | pl_trainer_devices: 1 104 | 105 | testing: 106 | test_results_filename: &test_results_filename "test_results.pickle" 107 | evaluation_directory: &evaluation_directory "eval" 108 | -------------------------------------------------------------------------------- /mmmt_examples/knight/op_knight.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import List, Union 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from fuse.data.ops.op_base import OpBase 8 | from fuse.data.utils.sample import get_sample_id 9 | from fuse.utils.ndict import NDict 10 | 11 | CLINICAL_FEATURES_CHALLENGE = [ 12 | "age_at_nephrectomy", 13 | "gender", 14 | "body_mass_index", 15 | "comorbidities", 16 | "smoking_history", 17 | "age_when_quit_smoking", 18 | "pack_years", 19 | "chewing_tobacco_use", 20 | "alcohol_use", 21 | "last_preop_egfr", 22 | "radiographic_size", 23 | "voxel_spacing", 24 | ] 25 | 26 | 27 | class OpDecodeKNIGHT(OpBase): 28 | """ 29 | Decoding operator class for the KNIGHT dataset. It has been reduced to 30 | minimal hardcoded values 31 | """ 32 | 33 | def __init__( 34 | self, 35 | knight_data_path, 36 | categorical_threshold=10, 37 | selected_cols=None, 38 | label="aua_risk_group", 39 | ): 40 | super().__init__() 41 | 42 | # replacement values for non existing data points 43 | replacements = { 44 | "None": np.nan, 45 | "not_applicable": np.nan, 46 | None: np.nan, 47 | "": np.nan, 48 | } 49 | self.selected_cols = selected_cols 50 | self.label = label 51 | 52 | self.knight_data_path = os.path.abspath(knight_data_path) 53 | json_file = os.path.join(self.knight_data_path, "knight.json") 54 | with open(json_file) as file: 55 | json_content = json.load(file) 56 | 57 | # raw dataframe without processing 58 | self.raw_dataframe = pd.read_json(json_file) 59 | 60 | # preprocessing of the data, applying replacements and casting types 61 | # replacements of "non-values" 62 | dataframe = pd.json_normalize(json_content).replace(replacements) 63 | 64 | clinical_feature_list = [] 65 | for col in dataframe.columns: 66 | for feature_name in CLINICAL_FEATURES_CHALLENGE: 67 | if feature_name in col: 68 | clinical_feature_list.append(col) 69 | 70 | self.label_dataframe = dataframe[[self.label]] 71 | dataframe = dataframe[clinical_feature_list] 72 | 73 | # identify categorical columns 74 | for col in dataframe.columns: 75 | if len(dataframe[col].unique()) < categorical_threshold: 76 | dataframe[col] = dataframe[col].astype("category") 77 | 78 | # convert categorical columns to one-hot-encoded features 79 | for k in dataframe.select_dtypes(include="category").columns: 80 | dataframe = pd.concat( 81 | [dataframe, pd.get_dummies(dataframe[k], prefix=k)], 82 | axis=1, 83 | join="inner", 84 | ) 85 | # convert non-categorical values to numbers (except case_id) 86 | for 
k in dataframe.select_dtypes(exclude="category").columns: 87 | if k != "case_id": 88 | dataframe[k] = pd.to_numeric(dataframe[k], errors="coerce") 89 | 90 | # only numerical features and one-hot encoded categories are kept 91 | self.dataframe = dataframe.select_dtypes("number").fillna(0) 92 | 93 | def __call__(self, sample_dict: NDict, **kwargs) -> Union[None, dict, List[dict]]: 94 | 95 | raw_key = "data.raw_clinical" 96 | clinical_key = "data.input.clinical_features" 97 | image_path_key = "data.input.image_path" 98 | label_key = "data.gt.label" 99 | clinical_names_key = "names.data.input.clinical_features" 100 | 101 | sid = get_sample_id(sample_dict) 102 | sample_dict[raw_key] = self.raw_dataframe.iloc[sid] 103 | sample_dict[clinical_key] = self.dataframe.iloc[sid].to_numpy(dtype=np.float32) 104 | sample_dict[image_path_key] = os.path.join( 105 | self.knight_data_path, 106 | sample_dict[raw_key]["case_id"], 107 | "imaging.nii.gz", 108 | ) 109 | sample_dict[label_key] = int( 110 | self.label_dataframe.iloc[sid]["aua_risk_group"] 111 | in ["high_risk", "very_high_risk"] 112 | ) 113 | sample_dict[clinical_names_key] = list(self.dataframe.columns) 114 | 115 | return sample_dict 116 | -------------------------------------------------------------------------------- /mmmt_examples/knight/pipeline.gv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/multimodal-models-toolkit/256f12b369769003986b68044086f4898cb096bf/mmmt_examples/knight/pipeline.gv.png -------------------------------------------------------------------------------- /mmmt_examples/knight/splits_final.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/multimodal-models-toolkit/256f12b369769003986b68044086f4898cb096bf/mmmt_examples/knight/splits_final.pkl -------------------------------------------------------------------------------- /mmmt_examples/knight/user_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/multimodal-models-toolkit/256f12b369769003986b68044086f4898cb096bf/mmmt_examples/knight/user_input.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # configuration approach followed: 2 | # - whenever possible, prefer pyproject.toml 3 | # - for configurations insufficiently supported by pyproject.toml, use setup.cfg instead 4 | # - setup.py discouraged; minimal stub included only for compatibility with legacy tools 5 | 6 | [build-system] 7 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] 8 | build-backend = "setuptools.build_meta" 9 | 10 | [project] 11 | name = "mmmt" 12 | description = "Open-source framework designed to solve multimodal machine learning tasks" 13 | authors = [ 14 | { name = "Andrea Giovannini", email = "agv@zurich.ibm.com"}, 15 | { name = "Antonio Foncubierta Rodriguez", email = "fra@zurich.ibm.com"}, 16 | { name = "Hongzhi Wang", email = "hongzhiw@us.ibm.com"}, 17 | { name = "Ken Wong", email = "clwong@us.ibm.com"}, 18 | { name = "Kevin Thandiackal", email = "kth@zurich.ibm.com"}, 19 | { name = "Michal Ozery-Flato", email = "ozery@il.ibm.com"}, 20 | { name = "Moshiko Raboh", email = "moshiko.raboh@ibm.com"}, 21 | { name = "Niharika D'Souza", email = "Niharika.DSouza@ibm.com"}, 22 | { name = "Panos Vagenas", email = 
"pva@zurich.ibm.com"}, 23 | { name = "Tanveer Syeda-Mahmood", email = "stf@us.ibm.com"}, 24 | ] 25 | readme = "README.md" 26 | # due to how PEP 440 defines version matching, prefer [incl, excl) definitions like below: 27 | requires-python = ">=3.8, <3.10" 28 | dependencies = [ 29 | "numpy>=1.23.5", 30 | "dgl==0.9.1", 31 | "rdflib>=6.2.0", 32 | "torch>=1.13.1", 33 | "scikit-learn>=1.1.3", 34 | "matplotlib>=3.6.2", 35 | "seaborn>=0.12.1", 36 | "mlflow>=2.0.1", 37 | "fuse-med-ml[all]>=0.2.9", 38 | "networkx==2.8.8", 39 | "nxviz>=0.7.4", 40 | "graphviz>=0.20.1", 41 | "pandas>=1.5.1", 42 | ] 43 | dynamic = ["version"] 44 | 45 | [project.optional-dependencies] 46 | dev = [ 47 | # tests 48 | "pytest", 49 | "pytest-cov", 50 | # checks 51 | "black", 52 | "flake8", 53 | "pep8-naming", 54 | # docs 55 | "Sphinx", 56 | "sphinx-autodoc-typehints", 57 | "sphinx-rtd-theme", 58 | "myst-parser", 59 | "better-apidoc", 60 | ] 61 | 62 | [tool.setuptools_scm] 63 | version_scheme = "post-release" 64 | 65 | [tool.semantic_release] 66 | # for default values check: 67 | # https://github.com/relekang/python-semantic-release/blob/master/semantic_release/defaults.cfg 68 | 69 | version_source = "tag_only" 70 | branch = "main" 71 | 72 | # configure types which should trigger minor and patch version bumps respectively 73 | # (note that they must be a subset of the configured allowed types): 74 | parser_angular_allowed_types = "build,chore,ci,docs,feat,fix,perf,style,refactor,test" 75 | parser_angular_minor_types = "feat" 76 | parser_angular_patch_types = "fix,perf" 77 | 78 | # unstaging the changelog (i.e. command part before &&) is a workaround for 79 | # https://github.com/relekang/python-semantic-release/issues/381: 80 | build_command = "git restore --staged CHANGELOG.md && python -m build" 81 | 82 | github_token_var="GITHUB_TOKEN" 83 | 84 | hvcs_domain="github.ibm.com" 85 | hvcs_api_domain="github.ibm.com/api/v3" # see https://ibm-analytics.slack.com/archives/C3SSJ6CSE/p1660313562338669?thread_ts=1660300230.162449&cid=C3SSJ6CSE 86 | 87 | [tool.mypy] 88 | check_untyped_defs = true 89 | 90 | [[tool.mypy.overrides]] 91 | module = "pytest.*" 92 | ignore_missing_imports = true 93 | 94 | [tool.black] 95 | line-length = 88 96 | skip-string-normalization = false 97 | target-version = ['py37'] 98 | 99 | [tool.isort] 100 | multi_line_output = 3 101 | include_trailing_comma = true 102 | force_grid_wrap = 0 103 | use_parentheses = true 104 | ensure_newline_before_comments = true 105 | line_length = 88 106 | force_to_top = ["rdkit", "scikit-learn"] 107 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # configuration approach followed: 2 | # - whenever possible, prefer pyproject.toml 3 | # - for configurations insufficiently supported by pyproject.toml, use setup.cfg instead 4 | # - setup.py discouraged; minimal stub included only for compatibility with legacy tools 5 | 6 | # pyproject.toml support for configs outside PEP 621 is currently only in beta 7 | # see https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html 8 | [options] 9 | packages = find: 10 | 11 | [options.packages.find] 12 | exclude = 13 | test 14 | 15 | # flake8 currently does not support pyproject.toml 16 | # see https://github.com/PyCQA/flake8/issues/234 17 | [flake8] 18 | max-line-length = 80 19 | select = C,E,F,W,B,B950 20 | ignore = E203, E501, W503 21 | 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # configuration approach followed: 2 | # - whenever possible, prefer pyproject.toml 3 | # - for configurations insufficiently supported by pyproject.toml, use setup.cfg instead 4 | # - setup.py discouraged; minimal stub included only for compatibility with legacy tools 5 | 6 | # Minimal stub for backwards compatibility, e.g. for legacy tools without PEP 660 support. 7 | # See https://setuptools.pypa.io/en/latest/userguide/quickstart.html 8 | from setuptools import setup 9 | 10 | setup() 11 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/multimodal-models-toolkit/256f12b369769003986b68044086f4898cb096bf/test/__init__.py -------------------------------------------------------------------------------- /test/data/graph/test_dgl_data_loader.py: -------------------------------------------------------------------------------- 1 | """Unit test for complex_module.core.""" 2 | 3 | __copyright__ = """ 4 | LICENSED INTERNAL CODE. PROPERTY OF IBM. 5 | IBM Research Licensed Internal Code 6 | (C) Copyright IBM Corp. 2021 7 | ALL RIGHTS RESERVED 8 | """ 9 | 10 | import unittest 11 | from mmmt.data.graph.dgl_data_loader import DGLFileLoader 12 | from dgl import DGLGraph 13 | import torch 14 | 15 | 16 | class DGLFileLoaderTestCase(unittest.TestCase): 17 | """DGLFileLoaderTestCase class.""" 18 | 19 | def setUp(self): 20 | """Setting up the test.""" 21 | self.existing_datasets = ["AIFB"] 22 | self.non_existing_dataset = "this_dataset_does_not_exist" 23 | pass 24 | 25 | def test_dataset_names(self): 26 | """Test that supported datasets return the correct types.""" 27 | for dataset_name in self.existing_datasets: 28 | loader = DGLFileLoader(dataset_name, [0.3, 0.3, 0.4], 0) 29 | g, labels, data_splits, n_classes, num_rels = loader.build_graph() 30 | self.assertIsInstance(g, DGLGraph) 31 | self.assertIsInstance(labels, torch.Tensor) 32 | self.assertIsInstance(data_splits, list) 33 | self.assertIsInstance(n_classes, int) 34 | self.assertIsInstance(num_rels, int) 35 | 36 | def test_incorrect_splits(self): 37 | """Test that supported datasets return the correct types.""" 38 | with self.assertRaises(AssertionError): 39 | loader = DGLFileLoader("ACM", [0.6, 0.3, 0.4], 0) 40 | 41 | def test_non_existing_dataset(self): 42 | """Test that a non supported dataset raises an error""" 43 | with self.assertRaises(ValueError): 44 | DGLFileLoader(self.non_existing_dataset, [0.3, 0.3, 0.4], 0) 45 | 46 | def test_not_enough_parameters(self): 47 | """Test that a non supported dataset raises an error""" 48 | with self.assertRaises(TypeError): 49 | DGLFileLoader() 50 | 51 | def tearDown(self): 52 | """Tear down the tests.""" 53 | pass 54 | -------------------------------------------------------------------------------- /test/data/graph/test_general_file_loader.py: -------------------------------------------------------------------------------- 1 | """Unit test for complex_module.core.""" 2 | 3 | __copyright__ = """ 4 | LICENSED INTERNAL CODE. PROPERTY OF IBM. 5 | IBM Research Licensed Internal Code 6 | (C) Copyright IBM Corp. 
2021 7 | ALL RIGHTS RESERVED 8 | """ 9 | 10 | import unittest 11 | from mmmt.data.graph.general_file_loader import GeneralFileLoader 12 | 13 | 14 | class GeneralFileLoaderTestCase(unittest.TestCase): 15 | """GeneralFileLoaderTestCase class.""" 16 | 17 | def setUp(self): 18 | """Setting up the test.""" 19 | pass 20 | 21 | def test_incorrect_splits(self): 22 | """Test that supported datasets return the correct types.""" 23 | with self.assertRaises(AssertionError): 24 | loader = GeneralFileLoader(None, [0.6, 0.3, 0.4], 0) 25 | 26 | def test_not_enough_parameters(self): 27 | """Test that a non supported dataset raises an error""" 28 | with self.assertRaises(TypeError): 29 | GeneralFileLoader() 30 | 31 | def tearDown(self): 32 | """Tear down the tests.""" 33 | pass 34 | -------------------------------------------------------------------------------- /test/data/graph/test_graph_to_graph.py: -------------------------------------------------------------------------------- 1 | """Unit test for mmmt.data.graph.graph2graph""" 2 | 3 | __copyright__ = """ 4 | LICENSED INTERNAL CODE. PROPERTY OF IBM. 5 | IBM Research Licensed Internal Code 6 | (C) Copyright IBM Corp. 2021 7 | ALL RIGHTS RESERVED 8 | """ 9 | import unittest 10 | from mmmt.data.graph.graph_to_graph import GraphTransform 11 | import numpy as np 12 | import dgl 13 | 14 | 15 | class GraphTransformTestCase(unittest.TestCase): 16 | """GraphTransformTestCase class.""" 17 | 18 | def setUp(self): 19 | """Setting up the test.""" 20 | # preparing dummy data 21 | self.num_nodes = 10 22 | self.num_edges = 50 23 | self.graph = dgl.rand_graph(self.num_nodes, self.num_edges) 24 | self.multi_graph_data = { 25 | ("feat", "mod1", "feat"): dgl.rand_graph( 26 | self.num_nodes, self.num_edges 27 | ).edges(), 28 | ("feat", "mod2", "feat"): dgl.rand_graph( 29 | self.num_nodes, self.num_edges 30 | ).edges(), 31 | ("feat", "mod3", "feat"): dgl.rand_graph( 32 | self.num_nodes, self.num_edges 33 | ).edges(), 34 | } 35 | self.multigraph = dgl.heterograph(self.multi_graph_data) 36 | 37 | self.node_features = np.random.rand(self.num_nodes) 38 | 39 | self.non_existing_method = "test" 40 | 41 | def test_mplex(self): 42 | """Test transformation to multiplex graph.""" 43 | GT = GraphTransform(self.multigraph, self.node_features, "mplex") 44 | derived_graphs, mplex_node_features = GT.transform() 45 | graph_mplex_i, graph_mplex_ii = derived_graphs 46 | total_num_edges = 0 47 | for etype in graph_mplex_i.etypes: 48 | total_num_edges += graph_mplex_i.num_edges(etype) 49 | 50 | self.assertIsInstance(graph_mplex_i, dgl.DGLGraph) 51 | self.assertEqual( 52 | graph_mplex_i.num_nodes() * len(graph_mplex_i.etypes), 53 | self.num_nodes * len(self.multi_graph_data), 54 | ) 55 | self.assertEqual( 56 | total_num_edges, self.num_edges * len(self.multi_graph_data) ** 2 57 | ) 58 | 59 | self.assertIsInstance(graph_mplex_ii, dgl.DGLGraph) 60 | self.assertEqual( 61 | graph_mplex_ii.num_nodes() * len(graph_mplex_ii.etypes), 62 | self.num_nodes * len(self.multi_graph_data), 63 | ) 64 | 65 | def test_mplex_prop(self): 66 | """Test transformation to multiplex graph with properties.""" 67 | GT = GraphTransform( 68 | self.multigraph, self.node_features, "mplex-prop", alpha=0.3 69 | ) 70 | derived_graphs, mplex_node_features = GT.transform() 71 | graph_mplex_i, graph_mplex_ii = derived_graphs 72 | total_num_edges = 0 73 | for etype in graph_mplex_i.etypes: 74 | total_num_edges += graph_mplex_i.num_edges(etype) 75 | ef_1 = graph_mplex_i.edata["w"].float() 76 | ef_2 = 
graph_mplex_ii.edata["w"].float() 77 | 78 | self.assertIsInstance(graph_mplex_i, dgl.DGLGraph) 79 | self.assertEqual( 80 | graph_mplex_i.num_nodes() * len(graph_mplex_i.etypes), 81 | self.num_nodes * len(self.multi_graph_data), 82 | ) 83 | self.assertEqual( 84 | total_num_edges, self.num_edges * len(self.multi_graph_data) ** 2 85 | ) 86 | 87 | self.assertIsInstance(graph_mplex_ii, dgl.DGLGraph) 88 | self.assertEqual( 89 | graph_mplex_ii.num_nodes() * len(graph_mplex_ii.etypes), 90 | self.num_nodes * len(self.multi_graph_data), 91 | ) 92 | 93 | self.assertEqual(len(ef_1), self.num_edges * len(self.multi_graph_data) ** 2) 94 | self.assertEqual(len(ef_2), self.num_edges * len(self.multi_graph_data) ** 2) 95 | 96 | def test_multibehav(self): 97 | """Test transformation to mulit-behavioural graph.""" 98 | GT = GraphTransform(self.multigraph, self.node_features, "multibehav") 99 | derived_graphs, mplex_node_features = GT.transform() 100 | quotient_graph, mplex_graph = derived_graphs 101 | 102 | self.assertIsInstance(quotient_graph, dgl.DGLGraph) 103 | self.assertEqual( 104 | quotient_graph.num_nodes() * len(quotient_graph.etypes), self.num_nodes 105 | ) 106 | 107 | self.assertIsInstance(mplex_graph, dgl.DGLGraph) 108 | self.assertEqual( 109 | mplex_graph.num_nodes() * len(mplex_graph.etypes), 110 | self.num_nodes * len(self.multi_graph_data), 111 | ) 112 | self.assertEqual( 113 | mplex_graph.num_edges(), 114 | self.num_edges * len(self.multi_graph_data) 115 | + self.num_nodes * len(self.multi_graph_data) * 2, 116 | ) 117 | 118 | def test_mGNN(self): 119 | """Test transformation to mGNN.""" 120 | GT = GraphTransform(self.multigraph, self.node_features, "mgnn") 121 | derived_graph, mplex_node_features = GT.transform() 122 | g_inter_layer, g_intra_layer = derived_graph 123 | 124 | self.assertIsInstance(g_inter_layer, dgl.DGLGraph) 125 | self.assertEqual( 126 | self.num_nodes * len(self.multi_graph_data), g_inter_layer.num_nodes() 127 | ) 128 | self.assertEqual( 129 | self.num_nodes * len(self.multi_graph_data) * 2, g_inter_layer.num_edges() 130 | ) 131 | 132 | self.assertIsInstance(g_intra_layer, dgl.DGLGraph) 133 | self.assertEqual( 134 | self.num_nodes * len(self.multi_graph_data), g_intra_layer.num_nodes() 135 | ) 136 | self.assertEqual( 137 | self.num_edges * len(self.multi_graph_data), g_intra_layer.num_edges() 138 | ) 139 | 140 | def test_non_existing_method(self): 141 | """Test that a non supported dataset raises an error""" 142 | with self.assertRaises(ValueError): 143 | GraphTransform( 144 | self.graph, self.node_features, self.non_existing_method 145 | ).transform() 146 | 147 | def tearDown(self): 148 | """Tear down the tests.""" 149 | pass 150 | -------------------------------------------------------------------------------- /test/data/graph/test_mat_file_loader.py: -------------------------------------------------------------------------------- 1 | """Unit test for complex_module.core.""" 2 | 3 | __copyright__ = """ 4 | LICENSED INTERNAL CODE. PROPERTY OF IBM. 5 | IBM Research Licensed Internal Code 6 | (C) Copyright IBM Corp. 
2021 7 | ALL RIGHTS RESERVED 8 | """ 9 | 10 | import unittest 11 | from mmmt.data.graph.mat_file_loader import MatFileLoader 12 | from dgl import DGLGraph 13 | import torch 14 | 15 | 16 | class MatFileLoaderTestCase(unittest.TestCase): 17 | """MatFileLoaderTestCase class.""" 18 | 19 | def setUp(self): 20 | """Setting up the test.""" 21 | self.existing_datasets = ["ACM"] 22 | self.non_existing_dataset = "this_dataset_does_not_exist" 23 | pass 24 | 25 | def test_dataset_names(self): 26 | """Test that supported datasets return the correct types.""" 27 | for dataset_name in self.existing_datasets: 28 | loader = MatFileLoader(dataset_name, [0.3, 0.3, 0.4], 0) 29 | g, labels, data_splits, n_classes, num_rels = loader.build_graph() 30 | self.assertIsInstance(g, DGLGraph) 31 | self.assertIsInstance(labels, torch.Tensor) 32 | self.assertIsInstance(data_splits, list) 33 | self.assertIsInstance(n_classes, int) 34 | self.assertIsInstance(num_rels, int) 35 | 36 | def test_incorrect_splits(self): 37 | """Test that supported datasets return the correct types.""" 38 | with self.assertRaises(AssertionError): 39 | loader = MatFileLoader("ACM", [0.6, 0.3, 0.4], 0) 40 | 41 | def test_non_existing_dataset(self): 42 | """Test that a non supported dataset raises an error""" 43 | with self.assertRaises(ValueError): 44 | MatFileLoader(self.non_existing_dataset, [0.3, 0.3, 0.4], 0) 45 | 46 | def test_not_enough_parameters(self): 47 | """Test that a non supported dataset raises an error""" 48 | with self.assertRaises(TypeError): 49 | MatFileLoader() 50 | 51 | def tearDown(self): 52 | """Tear down the tests.""" 53 | pass 54 | -------------------------------------------------------------------------------- /test/data/graph/test_visualization.py: -------------------------------------------------------------------------------- 1 | """Unit test for complex_module.core.""" 2 | 3 | __copyright__ = """ 4 | LICENSED INTERNAL CODE. PROPERTY OF IBM. 5 | IBM Research Licensed Internal Code 6 | (C) Copyright IBM Corp. 
2021 7 | ALL RIGHTS RESERVED 8 | """ 9 | 10 | import unittest 11 | from mmmt.data.graph.visualization import GraphVisualization 12 | import dgl 13 | from dgl import DGLGraph 14 | import torch 15 | from fuse.data.datasets.dataset_default import DatasetDefault 16 | from fuse.data.pipelines.pipeline_default import PipelineDefault 17 | from fuse.data.ops.ops_read import OpReadDataframe 18 | 19 | 20 | import pandas as pd 21 | 22 | 23 | class GraphVisualizationTestCase(unittest.TestCase): 24 | """MatFileLoaderTestCase class.""" 25 | 26 | def setUp(self): 27 | """Setting up the test.""" 28 | 29 | self.modality_names = ( 30 | ["modality1.features"] * 2 31 | + ["modality2.features"] * 2 32 | + ["modality3.features"] * 2 33 | ) 34 | 35 | data = { 36 | "sample_id": [0, 1, 2, 3], 37 | "node_names": [ 38 | self.modality_names, 39 | self.modality_names, 40 | self.modality_names, 41 | self.modality_names, 42 | ], 43 | "graph": [ 44 | dgl.rand_graph(6, 2), 45 | dgl.rand_graph(6, 1), 46 | dgl.rand_graph(6, 2), 47 | dgl.rand_graph(6, 4), 48 | ], 49 | } 50 | df = pd.DataFrame(data) 51 | pipeline = PipelineDefault("test_mock_data", [(OpReadDataframe(df), dict())]) 52 | sample_ids = list(range(4)) 53 | 54 | self.dataset = DatasetDefault(sample_ids, static_pipeline=pipeline) 55 | self.dataset.create() 56 | pass 57 | 58 | def test_visualization(self): 59 | """Test that visualization works.""" 60 | G = GraphVisualization.visualize_dataset( 61 | self.dataset, graph_key="graph", node_names_key="node_names" 62 | ) 63 | assert len(G.nodes) == 6 64 | 65 | def tearDown(self): 66 | """Tear down the tests.""" 67 | pass 68 | -------------------------------------------------------------------------------- /test/data/operators/test_operators.py: -------------------------------------------------------------------------------- 1 | """Unit test for complex_module.core.""" 2 | 3 | __copyright__ = """ 4 | LICENSED INTERNAL CODE. PROPERTY OF IBM. 5 | IBM Research Licensed Internal Code 6 | (C) Copyright IBM Corp. 
2021 7 | ALL RIGHTS RESERVED 8 | """ 9 | 10 | import unittest 11 | import torch 12 | import pandas as pd 13 | from fuse.data.datasets.dataset_default import DatasetDefault 14 | from fuse.data.pipelines.pipeline_default import PipelineDefault 15 | from fuse.data.ops.ops_read import OpReadDataframe 16 | from fuse.data.ops.ops_cast import OpToTensor 17 | from mmmt.data.operators.op_forwardpass import OpForwardPass 18 | from mmmt.data.operators.op_concat_names import OpConcatNames 19 | from mmmt.data.operators.op_resample import Op3DResample 20 | import numpy as np 21 | 22 | 23 | class GBMRTestCase(unittest.TestCase): 24 | """CoreTestCase class.""" 25 | 26 | def setUp(self): 27 | """Setting up the test.""" 28 | 29 | data = { 30 | "sample_id": [0, 1, 2, 3], 31 | "modality1": [ 32 | np.random.randint((1, 5, 5)), 33 | np.random.randint((1, 5, 5)), 34 | np.random.randint((1, 5, 5)), 35 | np.random.randint((1, 5, 5)), 36 | ], 37 | "modality2": [ 38 | np.random.randint((1, 5)), 39 | np.random.randint((1, 5)), 40 | np.random.randint((1, 5)), 41 | np.random.randint((1, 5)), 42 | ], 43 | "modality3D": [ 44 | torch.rand((5, 5, 3)), 45 | torch.rand((5, 5, 3)), 46 | torch.rand((5, 5, 3)), 47 | torch.rand((5, 5, 3)), 48 | ], 49 | } 50 | self.df = pd.DataFrame(data) 51 | 52 | pass 53 | 54 | def test_resample(self): 55 | """Test resample().""" 56 | pipeline_list = [ 57 | (OpReadDataframe(self.df), dict()), 58 | (OpToTensor(), dict(key="modality3D", dtype=torch.float)), 59 | ( 60 | Op3DResample([4, 2, 3]), 61 | dict(key_in="modality3D", key_out="modality3D_resampled"), 62 | ), 63 | ] 64 | pipeline = PipelineDefault("test_mock_data", pipeline_list) 65 | sample_ids = list(range(4)) 66 | 67 | dataset = DatasetDefault(sample_ids, static_pipeline=pipeline) 68 | dataset.create() 69 | assert torch.numel(dataset[0]["modality3D_resampled"]) == 24 70 | 71 | def test_forwardpass_and_concat_names(self): 72 | """Test forwardpass().""" 73 | identity = torch.nn.Identity() 74 | pipeline_list = [ 75 | (OpReadDataframe(self.df), dict()), 76 | ( 77 | OpForwardPass(identity, 1), 78 | dict(key_in="modality2", key_out="modality2fp"), 79 | ), 80 | ( 81 | OpConcatNames(), 82 | dict(keys_in=["modality2fp", "modality2fp"], key_out="names.concat"), 83 | ), 84 | ] 85 | pipeline = PipelineDefault("test_mock_data", pipeline_list) 86 | sample_ids = list(range(4)) 87 | 88 | dataset = DatasetDefault(sample_ids, static_pipeline=pipeline) 89 | dataset.create() 90 | 91 | assert len(dataset[0]["names.modality2fp"]) + len( 92 | dataset[0]["names.modality2fp"] 93 | ) == len(dataset[0]["names.concat"]) 94 | 95 | def tearDown(self): 96 | """Tear down the tests.""" 97 | pass 98 | -------------------------------------------------------------------------------- /test/data/representation/test_encoded_unimodal_to_concept.py: -------------------------------------------------------------------------------- 1 | """Unit test for complex_module.core.""" 2 | 3 | __copyright__ = """ 4 | LICENSED INTERNAL CODE. PROPERTY OF IBM. 5 | IBM Research Licensed Internal Code 6 | (C) Copyright IBM Corp. 
2021 7 | ALL RIGHTS RESERVED 8 | """ 9 | import os 10 | from tempfile import mkdtemp 11 | from mmmt.data.representation.encoded_unimodal_to_concept import ( 12 | EncodedUnimodalToConcept, 13 | ) 14 | import unittest 15 | import numpy as np 16 | from fuse.data.pipelines.pipeline_default import PipelineDefault 17 | from fuse.data.ops.ops_read import OpReadDataframe 18 | import pandas as pd 19 | 20 | 21 | class EncodedUnimodalToConceptTestCase(unittest.TestCase): 22 | """CoreTestCase class.""" 23 | 24 | def setUp(self): 25 | """Setting up the test.""" 26 | root = mkdtemp(prefix="EUTCTestCase") 27 | 28 | data = { 29 | "sample_id": [0, 1, 2, 3], 30 | "modality1": [ 31 | np.random.rand(64).astype(np.float32), 32 | np.random.rand(64).astype(np.float32), 33 | np.random.rand(64).astype(np.float32), 34 | np.random.rand(64).astype(np.float32), 35 | ], 36 | "modality2": [ 37 | np.random.rand(16).astype(np.float32), 38 | np.random.rand(16).astype(np.float32), 39 | np.random.rand(16).astype(np.float32), 40 | np.random.rand(16).astype(np.float32), 41 | ], 42 | } 43 | df = pd.DataFrame(data) 44 | 45 | pipeline_list = [ 46 | (OpReadDataframe(df), dict()), 47 | ] 48 | dataset_pipeline = PipelineDefault("static", pipeline_list) 49 | 50 | # Define splits 51 | training_sample_ids = [0, 1] 52 | val_sample_ids = [2] 53 | test_sample_ids = [3] 54 | 55 | self.fusion_strategy = { 56 | "pipeline": { 57 | "fuse_pipeline": dataset_pipeline, 58 | "data_splits": { 59 | "train_ids": training_sample_ids, 60 | "val_ids": val_sample_ids, 61 | "test_ids": test_sample_ids, 62 | }, 63 | }, 64 | "num_workers": 1, 65 | "restart_cache": True, 66 | "root_dir": root, 67 | "step_args": { 68 | "use_autoencoders": True, 69 | "add_feature_names": False, 70 | "encoding_layers": [32, 16], 71 | "use_pretrained": False, 72 | "batch_size": 3, 73 | "training": { 74 | "model_dir": "model_concept", 75 | "pl_trainer_num_epochs": 1, 76 | "pl_trainer_accelerator": "cpu", 77 | }, 78 | "io": { 79 | "concept_encoder_model_key": "concept_encoder_model", 80 | "input_keys": ["modality1", "modality2"], 81 | "output_key": "data.input.concatenated", 82 | }, 83 | }, 84 | } 85 | 86 | pass 87 | 88 | def test_empty_contructor(self): 89 | """Test salutation().""" 90 | with self.assertRaises(TypeError): 91 | EncodedUnimodalToConcept() 92 | 93 | def test_pipeline_too_soon(self): 94 | """Test salutation().""" 95 | EUTC = EncodedUnimodalToConcept(self.fusion_strategy) 96 | 97 | with self.assertRaises(AttributeError): 98 | EUTC.get_pipeline() 99 | 100 | def test_eutc(self): 101 | 102 | EUTC = EncodedUnimodalToConcept(self.fusion_strategy) 103 | 104 | EUTC.__call__() 105 | 106 | def tearDown(self): 107 | """Tear down the tests.""" 108 | pass 109 | 110 | 111 | if __name__ == "__main__": 112 | unittest.main() 113 | -------------------------------------------------------------------------------- /test/data/representation/test_modality_encoding.py: -------------------------------------------------------------------------------- 1 | """Unit test for complex_module.core.""" 2 | 3 | __copyright__ = """ 4 | LICENSED INTERNAL CODE. PROPERTY OF IBM. 5 | IBM Research Licensed Internal Code 6 | (C) Copyright IBM Corp. 
2021 7 | ALL RIGHTS RESERVED 8 | """ 9 | import os 10 | from tempfile import mkdtemp 11 | from mmmt.data.representation.modality_encoding import ( 12 | ModalityEncoding, 13 | ) 14 | import unittest 15 | import numpy as np 16 | from fuse.data.pipelines.pipeline_default import PipelineDefault 17 | from fuse.data.ops.ops_read import OpReadDataframe 18 | import pandas as pd 19 | 20 | 21 | class ModalityEncodingTestCase(unittest.TestCase): 22 | """Test case for ModalityEncoding.""" 23 | 24 | def setUp(self): 25 | """Setting up the test.""" 26 | root = mkdtemp(prefix="METestCase") 27 | 28 | data = { 29 | "sample_id": [0, 1, 2, 3], 30 | "modality1": [ 31 | np.random.rand(64).astype(np.float32), 32 | np.random.rand(64).astype(np.float32), 33 | np.random.rand(64).astype(np.float32), 34 | np.random.rand(64).astype(np.float32), 35 | ], 36 | "modality2": [ 37 | np.random.rand(16).astype(np.float32), 38 | np.random.rand(16).astype(np.float32), 39 | np.random.rand(16).astype(np.float32), 40 | np.random.rand(16).astype(np.float32), 41 | ], 42 | } 43 | df = pd.DataFrame(data) 44 | 45 | pipeline_list = [ 46 | (OpReadDataframe(df), dict()), 47 | ] 48 | dataset_pipeline = PipelineDefault("static", pipeline_list) 49 | 50 | # Define splits 51 | training_sample_ids = [0, 1] 52 | val_sample_ids = [2] 53 | test_sample_ids = [3] 54 | 55 | self.encoding_strategy = { 56 | "pipeline": { 57 | "fuse_pipeline": dataset_pipeline, 58 | "data_splits": { 59 | "train_ids": training_sample_ids, 60 | "val_ids": val_sample_ids, 61 | "test_ids": test_sample_ids, 62 | }, 63 | }, 64 | "num_workers": 1, 65 | "restart_cache": True, 66 | "root_dir": root, 67 | "step_args": { 68 | "modality1": { 69 | "model": None, 70 | "use_autoencoder": False, 71 | "output_key": "modality1_output_key", 72 | }, 73 | "modality2": { 74 | "model": None, 75 | "add_feature_names": False, 76 | "use_autoencoder": True, 77 | "output_key": "modality2_output_key", 78 | "encoding_layers": [16, 4], 79 | "use_pretrained": False, 80 | "batch_size": 3, 81 | "training": { 82 | "model_dir": "model_image_features", 83 | "pl_trainer_num_epochs": 1, 84 | "pl_trainer_accelerator": "cpu", 85 | }, 86 | }, 87 | }, 88 | "object_registry": None, 89 | } 90 | 91 | pass 92 | 93 | def test_empty_constructor(self): 94 | """Test that constructing without an encoding strategy raises TypeError.""" 95 | with self.assertRaises(TypeError): 96 | ModalityEncoding() 97 | 98 | def test_pipeline_too_soon(self): 99 | """Test that get_pipeline() fails if called before running the step.""" 100 | ME = ModalityEncoding(self.encoding_strategy) 101 | 102 | with self.assertRaises(AttributeError): 103 | ME.get_pipeline() 104 | 105 | def test_me(self): 106 | """Test running the ModalityEncoding step.""" 107 | ME = ModalityEncoding(self.encoding_strategy) 108 | 109 | ME() 110 | 111 | def tearDown(self): 112 | """Tear down the tests.""" 113 | pass 114 | 115 | 116 | if __name__ == "__main__": 117 | unittest.main() 118 | -------------------------------------------------------------------------------- /test/models/classic/test_classic_models.py: -------------------------------------------------------------------------------- 1 | """Unit test for models defined in mmmt.models.classic""" 2 | 3 | import unittest 4 | import numpy as np 5 | from mmmt.models.classic.late_fusion import LateFusion 6 | from mmmt.models.classic.uncertainty_late_fusion import UncertaintyLateFusion 7 | 8 | 9 | class LateFusionTestCase(unittest.TestCase): 10 | """CoreTestCase class.""" 11 | 12 | def setUp(self): 13 | """Setting up the test.""" 14 | pass 15 | 16 | def test_late_fusion(self): 17 | n_mods = 20 18 | n_samples = 1000 19 | n_classes = 10 20 | 21 | lf =
LateFusion(n_mods, n_classes) 22 | 23 | predictions_test = np.random.rand(n_mods, n_samples, n_classes) 24 | fused_test = lf.apply_fusion(predictions_test) 25 | self.assertEqual(fused_test.shape[0], predictions_test.shape[1]) 26 | 27 | 28 | class UncertaintyFusionTestCase(unittest.TestCase): 29 | """CoreTestCase class.""" 30 | 31 | def setUp(self): 32 | """Setting up the test.""" 33 | pass 34 | 35 | def test_uncertainty_late_fusion(self): 36 | n_mods = 20 37 | n_samples = 1000 38 | n_classes = 10 39 | 40 | predictions_valid = np.random.rand(n_mods, n_samples, n_classes) 41 | GT_valid = np.random.rand(n_samples, n_classes) 42 | 43 | ulf = UncertaintyLateFusion(n_mods, n_classes) 44 | ulf.k = 5 45 | 46 | ulf.compute_fusion_weights(predictions_valid, GT_valid) 47 | 48 | predictions_test = np.random.rand(n_mods, n_samples, n_classes) 49 | fused_test = ulf.apply_fusion(predictions_test) 50 | self.assertEqual(fused_test.shape[0], predictions_test.shape[1]) 51 | -------------------------------------------------------------------------------- /test/models/graph/test_graph_models.py: -------------------------------------------------------------------------------- 1 | """Unit test for models defined in mmmt.models.graph""" 2 | 3 | __copyright__ = """ 4 | LICENSED INTERNAL CODE. PROPERTY OF IBM. 5 | IBM Research Licensed Internal Code 6 | (C) Copyright IBM Corp. 2021 7 | ALL RIGHTS RESERVED 8 | """ 9 | import unittest 10 | import dgl 11 | import numpy as np 12 | import torch 13 | 14 | from mmmt.models.graph.multiplex_gin import MultiplexGIN 15 | from mmmt.models.graph.multiplex_gcn import MultiplexGCN 16 | from mmmt.data.graph.graph_to_graph import GraphTransform 17 | from mmmt.models.graph.relational_gcn import Relational_GCN 18 | from mmmt.models.graph.mgnn import mGNN 19 | from mmmt.models.graph.gcn import GCN 20 | from mmmt.models.graph.multi_behavioral_gnn import MultiBehavioralGNN 21 | 22 | 23 | class CoreTestCase(unittest.TestCase): 24 | """CoreTestCase class.""" 25 | 26 | def setUp(self): 27 | """Setting up the test.""" 28 | # preparing dummy data 29 | self.num_nodes = 10 30 | self.num_edges = 100 31 | self.graph = dgl.rand_graph(self.num_nodes, self.num_edges) 32 | multi_graph_data = { 33 | ("feat", "mod1", "feat"): dgl.rand_graph( 34 | self.num_nodes, self.num_edges 35 | ).edges(), 36 | ("feat", "mod2", "feat"): dgl.rand_graph( 37 | self.num_nodes, self.num_edges 38 | ).edges(), 39 | ("feat", "mod3", "feat"): dgl.rand_graph( 40 | self.num_nodes, self.num_edges 41 | ).edges(), 42 | } 43 | self.multigraph = dgl.heterograph(multi_graph_data) 44 | self.n_layers = len(self.multigraph.etypes) 45 | 46 | self.node_emb_dim = 4 47 | self.features = np.random.rand(self.num_nodes, self.node_emb_dim).astype( 48 | np.float32 49 | ) 50 | 51 | def test_mplx_gin(self): 52 | GT = GraphTransform(self.multigraph, self.features, graph_module="mplex") 53 | derived_graph, features_multigraph = GT.transform() 54 | graph_dict = { 55 | "graph": derived_graph, 56 | "node_features": torch.from_numpy(features_multigraph), 57 | } 58 | 59 | module_input = { 60 | "node_emb_dim": self.node_emb_dim, 61 | } 62 | 63 | MPLX_GIN = MultiplexGIN(module_input) 64 | 65 | h = MPLX_GIN.forward(graph_dict) 66 | self.assertEqual(self.num_nodes * self.n_layers, h.shape[0]) 67 | self.assertEqual(self.node_emb_dim * 4, h.shape[1]) 68 | 69 | def test_mplx_gcn(self): 70 | GT = GraphTransform(self.multigraph, self.features, graph_module="mplex-prop") 71 | derived_graph, features_multigraph = GT.transform() 72 | graph_dict = { 73 | "graph": 
derived_graph, 74 | "node_features": torch.from_numpy(features_multigraph), 75 | } 76 | 77 | in_size = self.node_emb_dim 78 | gl_hidden_size = [max(int(in_size / 2), 1), 5] 79 | module_input = { 80 | "node_emb_dim": self.node_emb_dim, 81 | "gl_hidden_size": gl_hidden_size, 82 | } 83 | 84 | MPLX_GCN = MultiplexGCN(module_input) 85 | 86 | h = MPLX_GCN.forward(graph_dict) 87 | self.assertEqual(self.num_nodes * self.n_layers, h.shape[0]) 88 | self.assertEqual(gl_hidden_size[-1] * 2, h.shape[1]) 89 | 90 | def test_rgcn(self): 91 | GT = GraphTransform(self.multigraph, self.features, graph_module="rgcn") 92 | graph, features_graph = GT.transform() 93 | graph_dict = { 94 | "graph": graph, 95 | "node_features": torch.from_numpy(features_graph), 96 | } 97 | 98 | in_size = self.node_emb_dim 99 | gl_hidden_size = [int(in_size / 2)] 100 | num_bases = 8 101 | module_input = { 102 | "node_emb_dim": self.node_emb_dim, 103 | "n_layers": self.n_layers, 104 | "num_bases": num_bases, 105 | "gl_hidden_size": gl_hidden_size, 106 | } 107 | 108 | RGCN = Relational_GCN(module_input) 109 | 110 | h = RGCN.forward(graph_dict) 111 | self.assertEqual(self.num_nodes, h.shape[0]) 112 | self.assertEqual(gl_hidden_size[-1], h.shape[1]) 113 | 114 | def test_mgnn(self): 115 | GT = GraphTransform(self.multigraph, self.features, graph_module="mgnn") 116 | derived_graph, features_multigraph = GT.transform() 117 | graph_dict = { 118 | "graph": derived_graph, 119 | "node_features": torch.from_numpy(features_multigraph), 120 | } 121 | 122 | in_size = self.node_emb_dim 123 | gl_hidden_size = [int(in_size / 2)] 124 | num_att_heads = 8 125 | module_input = { 126 | "node_emb_dim": self.node_emb_dim, 127 | "n_layers": self.n_layers, 128 | "num_att_heads": num_att_heads, 129 | "gl_hidden_size": gl_hidden_size, 130 | } 131 | 132 | MGNN = mGNN(module_input) 133 | 134 | h = MGNN.forward(graph_dict) 135 | self.assertEqual(self.num_nodes * self.n_layers, h.shape[0]) 136 | self.assertEqual(gl_hidden_size[-1] * 2 * num_att_heads, h.shape[1]) 137 | # factor 2 above for concatenation of inter and intra layer 138 | 139 | def test_gcn(self): 140 | GT = GraphTransform(self.multigraph, self.features, graph_module="gcn") 141 | multigraph, features_multigraph = GT.transform() 142 | graph_dict = { 143 | "graph": multigraph, 144 | "node_features": torch.from_numpy(features_multigraph), 145 | } 146 | 147 | in_size = self.node_emb_dim 148 | gl_hidden_size = [int(in_size / 2)] 149 | module_input = { 150 | "node_emb_dim": self.node_emb_dim, 151 | "gl_hidden_size": gl_hidden_size, 152 | } 153 | 154 | GCN_obj = GCN(module_input) 155 | 156 | h = GCN_obj.forward(graph_dict) 157 | self.assertEqual(self.num_nodes, h.shape[0]) 158 | self.assertEqual(gl_hidden_size[-1], h.shape[1]) 159 | 160 | def test_MBGCN(self): 161 | GT = GraphTransform(self.multigraph, self.features, graph_module="multibehav") 162 | derived_graph, features_multigraph = GT.transform() 163 | graph_dict = { 164 | "graph": derived_graph, 165 | "node_features": torch.from_numpy(self.features), 166 | } 167 | 168 | in_size = self.node_emb_dim 169 | gl_hidden_size = [int(in_size / 2)] 170 | module_input = { 171 | "node_emb_dim": self.node_emb_dim, 172 | "n_layers": self.n_layers, 173 | "gl_hidden_size": gl_hidden_size, 174 | } 175 | 176 | MBGCN = MultiBehavioralGNN(module_input) 177 | 178 | h = MBGCN.forward(graph_dict) 179 | self.assertEqual(self.num_nodes * self.n_layers, h.shape[0]) 180 | self.assertEqual(gl_hidden_size[-1], h.shape[1]) 181 | 182 | def tearDown(self): 183 | """Tear down the 
tests.""" 184 | pass 185 | -------------------------------------------------------------------------------- /test/models/head/test_head_models.py: -------------------------------------------------------------------------------- 1 | """Unit test for models defined in mmmt.models.head""" 2 | 3 | __copyright__ = """ 4 | LICENSED INTERNAL CODE. PROPERTY OF IBM. 5 | IBM Research Licensed Internal Code 6 | (C) Copyright IBM Corp. 2021 7 | ALL RIGHTS RESERVED 8 | """ 9 | import unittest 10 | import torch 11 | from mmmt.models.head.mlp import MLP 12 | 13 | 14 | class CoreTestCase(unittest.TestCase): 15 | """CoreTestCase class.""" 16 | 17 | def setUp(self): 18 | """Setting up the test.""" 19 | # preparing dummy data 20 | self.num_nodes = 10 21 | self.head_in_size = 32 22 | self.out_size = 5 23 | self.batch_size = 2 24 | 25 | self.h = torch.rand(self.num_nodes, self.batch_size, self.head_in_size) 26 | pass 27 | 28 | def test_mlp(self): 29 | # preparing head model object 30 | head_hidden_size = [20, 10] 31 | dropout = 0.5 32 | 33 | MLP_obj = MLP( 34 | self.num_nodes, self.head_in_size, head_hidden_size, self.out_size, dropout 35 | ) 36 | 37 | h = MLP_obj(self.h) 38 | 39 | self.assertEqual(self.out_size * self.batch_size, torch.numel(h)) 40 | 41 | def tearDown(self): 42 | """Tear down the tests.""" 43 | pass 44 | -------------------------------------------------------------------------------- /test/models/test_model_builder.py: -------------------------------------------------------------------------------- 1 | """Unit test for mmmt.models.model_builder""" 2 | 3 | __copyright__ = """ 4 | LICENSED INTERNAL CODE. PROPERTY OF IBM. 5 | IBM Research Licensed Internal Code 6 | (C) Copyright IBM Corp. 2021 7 | ALL RIGHTS RESERVED 8 | """ 9 | import unittest 10 | import dgl 11 | import torch 12 | import numpy as np 13 | import mmmt 14 | from mmmt.data.graph.graph_to_graph import GraphTransform 15 | 16 | 17 | class CoreTestCase(unittest.TestCase): 18 | """CoreTestCase class.""" 19 | 20 | def setUp(self): 21 | """Setting up the test.""" 22 | # preparing dummy data 23 | num_nodes = 10 24 | num_edges = 100 25 | graph = dgl.rand_graph(num_nodes, num_edges) 26 | node_emb_dim = 4 27 | features = np.random.rand(num_nodes, node_emb_dim).astype(np.float32) 28 | 29 | self.graph_model_input = { 30 | "graph_module": { 31 | "module_identifier": "rgcn", 32 | "thresh_q": 0.95, 33 | "node_emb_dim": node_emb_dim, 34 | "gl_hidden_size": [2], 35 | "num_att_heads": 4, 36 | "num_bases": 8, 37 | "n_layers": 1, 38 | }, 39 | "head_module": { 40 | "head_hidden_size": [100, 20], 41 | "dropout": 0.5, 42 | }, 43 | } 44 | 45 | self.out_size = 5 46 | 47 | GT = GraphTransform( 48 | graph, 49 | features, 50 | graph_module=self.graph_model_input["graph_module"]["module_identifier"], 51 | ) 52 | derived_graph, features_multigraph = GT.transform() 53 | self.graph_dict = { 54 | "graph": derived_graph, 55 | "node_features": torch.from_numpy(features_multigraph), 56 | } 57 | self.graph_sample = derived_graph 58 | 59 | self.batch_size = 1 60 | 61 | pass 62 | 63 | def test_model_builder(self): 64 | """Test ModelBuilder.""" 65 | head_hidden_size = self.graph_model_input["head_module"]["head_hidden_size"] 66 | dropout = self.graph_model_input["head_module"]["dropout"] 67 | MC = mmmt.models.graph.module_configurator.ModuleConfigurator( 68 | self.graph_model_input["graph_module"] 69 | ) 70 | graph_model, head_in_size, head_num_nodes = MC.get_module(self.graph_sample) 71 | head_model = mmmt.models.head.mlp.MLP( 72 | head_num_nodes, head_in_size, 
head_hidden_size, self.out_size, dropout 73 | ) 74 | 75 | MB = mmmt.models.model_builder.ModelBuilder( 76 | graph_model, head_model, self.batch_size 77 | ) 78 | h = MB.forward(self.graph_dict) 79 | self.assertEqual(h.shape[1], self.out_size) 80 | 81 | def tearDown(self): 82 | """Tear down the tests.""" 83 | pass 84 | -------------------------------------------------------------------------------- /test/pipeline/test.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | - fuse_object: "static_pipeline" 3 | args: 4 | name: "Test dataset" 5 | - object: "get_splits" 6 | args: 7 | name: "splits" 8 | io: 9 | input_key: null 10 | output_key: "data_splits" 11 | 12 | mlflow: 13 | MLFLOW_TRACKING_URI: null 14 | MLFLOW_EXPERIMENT_NAME: null 15 | 16 | cache: 17 | num_workers: 1 18 | restart_cache: True 19 | root_dir: "_examples/tests/cache" 20 | 21 | modality_encoding_strategy: 22 | - object: "ModalityEncoding" 23 | args: 24 | data.input.raw.modality1: 25 | model: null 26 | output_key: "data.input.encoded.modality1" 27 | use_autoencoder: True 28 | encoding_layers: 29 | - 128 30 | - 64 31 | use_pretrained: False 32 | batch_size: 5 33 | # training config for modality autoencoder. Details in mmmt.data.representation.AutoEncoderTrainer.set_train_config(..) 34 | training: 35 | model_dir: "model_modality1" 36 | pl_trainer_num_epochs: 1 37 | pl_trainer_accelerator: "cpu" 38 | data.input.raw.modality2: 39 | model: null 40 | output_key: "data.input.encoded.modality2" 41 | use_autoencoder: True 42 | encoding_layers: 43 | - 32 44 | - 32 45 | use_pretrained: False 46 | batch_size: 5 47 | # training config for modality autoencoder. Details in mmmt.data.representation.AutoEncoderTrainer.set_train_config(..) 48 | training: 49 | model_dir: "model_modality2" 50 | pl_trainer_num_epochs: 1 51 | pl_trainer_accelerator: "cpu" 52 | data.input.raw.modality3: 53 | model: null 54 | output_key: "data.input.encoded.modality3" 55 | use_autoencoder: True 56 | encoding_layers: 57 | - 64 58 | - 64 59 | use_pretrained: False 60 | batch_size: 5 61 | # training config for modality autoencoder. Details in mmmt.data.representation.AutoEncoderTrainer.set_train_config(..) 
62 | training: 63 | model_dir: "model_modality3" 64 | pl_trainer_num_epochs: 1 65 | pl_trainer_accelerator: "cpu" 66 | 67 | 68 | fusion_strategy: 69 | - object: "EncodedUnimodalToConcept" # early or late 70 | args: 71 | use_autoencoders: True 72 | add_feature_names: False 73 | encoding_layers: 74 | - 32 75 | - &n_layers 16 76 | use_pretrained: False 77 | batch_size: 5 78 | training: 79 | model_dir: "model_concept" 80 | pl_trainer_num_epochs: 1 81 | pl_trainer_accelerator: "cpu" 82 | io: 83 | concept_encoder_model_key: "concept_encoder_model" 84 | input_keys: 85 | - "data.input.encoded.modality1" 86 | - "data.input.encoded.modality2" 87 | - "data.input.encoded.modality3" 88 | output_key: "data.input.concatenated" 89 | - object: "ConceptToGraph" 90 | args: 91 | module_identifier: &graph_module "mplex" 92 | thresh_q: 0.95 93 | io: 94 | concept_encoder_model_key: "concept_encoder_model" 95 | fused_dataset_key: "fused_dataset" 96 | input_key: "data.input.concatenated" 97 | output_key: "data.derived_graph" 98 | - object: "GraphVisualization" 99 | args: 100 | selected_samples: 101 | graph_train_dataset: 102 | - 0 103 | - 1 104 | graph_validation_dataset: "all" 105 | graph_test_dataset: "all" 106 | feature_group_sizes: 107 | modality1: 64 108 | modality2: 32 109 | modality3: 64 110 | io: 111 | file_prefix: "graph_visualization" 112 | fused_dataset_key: "fused_dataset" 113 | task_strategy: 114 | - object: "MultimodalGraphModel" 115 | args: 116 | io: 117 | fused_dataset_key: "fused_dataset" 118 | input_key: "data.derived_graph" 119 | target_key: &target "data.ground_truth" 120 | prediction_key: &prediction "model.out" 121 | 122 | 123 | model_config: 124 | graph_model: 125 | module_identifier: *graph_module 126 | n_layers: *n_layers 127 | node_emb_dim: 1 # really needed? 128 | head_model: 129 | head_hidden_size: 130 | - 100 131 | - 20 132 | dropout: 0.5 133 | add_softmax: True 134 | num_classes: 2 135 | 136 | training: 137 | model_dir: "model_mplex" 138 | batch_size: 1 139 | best_epoch_source: 140 | mode: "max" 141 | monitor: "validation.metrics.accuracy" 142 | train_metrics: 143 | key: "accuracy" 144 | object: "MetricAccuracy" 145 | args: 146 | pred: *prediction 147 | target: *target 148 | validation_metrics: 149 | key: "accuracy" 150 | object: "MetricAccuracy" 151 | args: 152 | pred: *prediction 153 | target: *target 154 | pl_trainer_num_epochs: 5 155 | pl_trainer_accelerator: "cpu" 156 | pl_trainer_devices: 1 157 | 158 | testing: 159 | test_results_filename: &test_results_filename "test_results.pickle" 160 | evaluation_directory: &evaluation_directory "eval" 161 | -------------------------------------------------------------------------------- /test/pipeline/test_mlp.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | - fuse_object: "static_pipeline" 3 | args: 4 | name: "Test dataset" 5 | - object: "get_splits" 6 | args: 7 | name: "splits" 8 | io: 9 | input_key: null 10 | output_key: "data_splits" 11 | 12 | mlflow: 13 | MLFLOW_TRACKING_URI: null 14 | MLFLOW_EXPERIMENT_NAME: null 15 | 16 | cache: 17 | num_workers: 1 18 | restart_cache: True 19 | root_dir: "_examples/tests/cache" 20 | 21 | modality_encoding_strategy: 22 | - object: "ModalityEncoding" 23 | args: 24 | data.input.raw.modality1: 25 | model: null 26 | output_key: "data.input.encoded.modality1" 27 | use_autoencoder: True 28 | encoding_layers: 29 | - 128 30 | - 64 31 | use_pretrained: False 32 | batch_size: 5 33 | # training config for modality autoencoder. 
Details in mmmt.data.representation.AutoEncoderTrainer.set_train_config(..) 34 | training: 35 | model_dir: "model_modality1" 36 | pl_trainer_num_epochs: 1 37 | pl_trainer_accelerator: "cpu" 38 | data.input.raw.modality2: 39 | model: null 40 | output_key: "data.input.encoded.modality2" 41 | use_autoencoder: True 42 | encoding_layers: 43 | - 32 44 | - 32 45 | use_pretrained: False 46 | batch_size: 5 47 | # training config for modality autoencoder. Details in mmmt.data.representation.AutoEncoderTrainer.set_train_config(..) 48 | training: 49 | model_dir: "model_modality2" 50 | pl_trainer_num_epochs: 1 51 | pl_trainer_accelerator: "cpu" 52 | data.input.raw.modality3: 53 | model: null 54 | output_key: "data.input.encoded.modality3" 55 | use_autoencoder: True 56 | encoding_layers: 57 | - 64 58 | - 64 59 | use_pretrained: False 60 | batch_size: 5 61 | # training config for modality autoencoder. Details in mmmt.data.representation.AutoEncoderTrainer.set_train_config(..) 62 | training: 63 | model_dir: "model_modality3" 64 | pl_trainer_num_epochs: 1 65 | pl_trainer_accelerator: "cpu" 66 | 67 | 68 | fusion_strategy: 69 | - object: "EncodedUnimodalToConcept" # early or late 70 | args: 71 | use_autoencoders: True 72 | add_feature_names: False 73 | encoding_layers: 74 | - 32 75 | - &n_layers 16 76 | use_pretrained: False 77 | batch_size: 5 78 | training: 79 | model_dir: "model_concept" 80 | pl_trainer_num_epochs: 1 81 | pl_trainer_accelerator: "cpu" 82 | io: 83 | concept_encoder_model_key: "concept_encoder_model" 84 | input_keys: 85 | - "data.input.encoded.modality1" 86 | - "data.input.encoded.modality2" 87 | - "data.input.encoded.modality3" 88 | output_key: "data.input.concatenated" 89 | 90 | task_strategy: 91 | - object: "MultimodalMLP" 92 | args: 93 | io: 94 | input_key: "data.input.concatenated" 95 | target_key: &target "data.ground_truth" 96 | prediction_key: &prediction "model.out" 97 | 98 | 99 | model_config: 100 | hidden_size: 101 | - 100 102 | - 20 103 | dropout: 0.5 104 | add_softmax: True 105 | num_classes: 2 106 | 107 | training: 108 | model_dir: "model_mlp" 109 | batch_size: 1 110 | best_epoch_source: 111 | mode: "max" 112 | monitor: "validation.metrics.accuracy" 113 | train_metrics: 114 | key: "accuracy" 115 | object: "MetricAccuracy" 116 | args: 117 | pred: *prediction 118 | target: *target 119 | validation_metrics: 120 | key: "accuracy" 121 | object: "MetricAccuracy" 122 | args: 123 | pred: *prediction 124 | target: *target 125 | pl_trainer_num_epochs: 5 126 | pl_trainer_accelerator: "cpu" 127 | pl_trainer_devices: 1 128 | 129 | testing: 130 | test_results_filename: &test_results_filename "test_results.pickle" 131 | evaluation_directory: &evaluation_directory "eval" 132 | -------------------------------------------------------------------------------- /test/pipeline/test_pipeline.py: -------------------------------------------------------------------------------- 1 | """Unit test for complex_module.core.""" 2 | 3 | __copyright__ = """ 4 | LICENSED INTERNAL CODE. PROPERTY OF IBM. 5 | IBM Research Licensed Internal Code 6 | (C) Copyright IBM Corp. 
2021 7 | ALL RIGHTS RESERVED 8 | """ 9 | import os 10 | from tempfile import mkdtemp 11 | from mmmt.data.representation.modality_encoding import ( 12 | ModalityEncoding, 13 | ) 14 | import unittest 15 | import numpy as np 16 | from fuse.data.pipelines.pipeline_default import PipelineDefault 17 | from fuse.data.ops.ops_read import OpReadDataframe 18 | from fuse.data.ops.ops_common import OpToOneHot 19 | import pandas as pd 20 | from mmmt.pipeline.pipeline import MMMTPipeline 21 | 22 | 23 | class TESTDataset: 24 | @staticmethod 25 | def static_pipeline(name=None) -> PipelineDefault: 26 | data = { 27 | "sample_id": [i for i in range(100)], 28 | "data.input.raw.modality1": [ 29 | np.random.rand(64).astype(np.float32) * (i % 2) for i in range(100) 30 | ], 31 | "data.input.raw.modality2": [ 32 | np.random.rand(32).astype(np.float32) * (i % 2) for i in range(100) 33 | ], 34 | "data.input.raw.modality3": [ 35 | np.random.rand(128).astype(np.float32) * (i % 2) for i in range(100) 36 | ], 37 | "data.ground_truth": [np.random.randint(2, size=1) for i in range(100)], 38 | } 39 | 40 | df = pd.DataFrame(data) 41 | df["data.ground_truth"] = df.applymap( 42 | lambda x: int(2 * np.mean(x)), na_action="ignore" 43 | )["data.input.raw.modality1"] 44 | 45 | pipeline_list = [ 46 | (OpReadDataframe(df), dict()), 47 | ( 48 | OpToOneHot(2), 49 | {"key_in": "data.ground_truth", "key_out": "data.ground_truth"}, 50 | ), 51 | ] 52 | dataset_pipeline = PipelineDefault("static", pipeline_list) 53 | print(name) 54 | return dataset_pipeline 55 | 56 | @staticmethod 57 | def get_splits(name=None): 58 | 59 | data_splits = { 60 | "train_ids": list(range(50)), 61 | "val_ids": list(range(50, 75)), 62 | "test_ids": list(range(75, 100)), 63 | } 64 | print(name) 65 | return data_splits 66 | 67 | 68 | class PipelineTestCase(unittest.TestCase): 69 | """Test case for MMMTPipeline.""" 70 | 71 | def setUp(self): 72 | """Setting up the test.""" 73 | 74 | self.specific_objects = { 75 | "static_pipeline": { 76 | "object": TESTDataset.static_pipeline, 77 | }, 78 | "get_splits": { 79 | "object": TESTDataset.get_splits, 80 | }, 81 | } 82 | 83 | pass 84 | 85 | def test_create_pipeline(self): 86 | """Test creating a pipeline.""" 87 | MMMTP = MMMTPipeline( 88 | "test/pipeline/test.yaml", 89 | self.specific_objects, 90 | defaults="test/pipeline/test.yaml", 91 | ) 92 | 93 | def test_run_pipeline(self): 94 | """Test running the pipeline.""" 95 | MMMTP = MMMTPipeline( 96 | "test/pipeline/test.yaml", 97 | self.specific_objects, 98 | defaults="test/pipeline/test.yaml", 99 | ) 100 | MMMTP.run_pipeline() 101 | 102 | def test_run_mlp_pipeline(self): 103 | """Test running the MLP pipeline.""" 104 | MMMTP = MMMTPipeline( 105 | "test/pipeline/test_mlp.yaml", 106 | self.specific_objects, 107 | defaults="test/pipeline/test_mlp.yaml", 108 | ) 109 | MMMTP.run_pipeline() 110 | 111 | def tearDown(self): 112 | """Tear down the tests.""" 113 | pass 114 | 115 | 116 | if __name__ == "__main__": 117 | unittest.main() 118 | --------------------------------------------------------------------------------
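
The pipeline test above also serves as a usage reference for `MMMTPipeline`: the entry point takes a YAML configuration plus a registry of project-specific callables (here `static_pipeline` and `get_splits`, the names referenced in the test YAML) and is driven with `run_pipeline()`. The following is a minimal, illustrative sketch of invoking that same entry point outside `unittest`; the toy data builders and the reuse of `test/pipeline/test.yaml` are stand-ins for demonstration only, not part of the toolkit.

```python
# Illustrative sketch only: mirrors what test/pipeline/test_pipeline.py does, as a
# standalone script. The two builders below are hypothetical stand-ins for a
# project's own dataset pipeline and split definition.
import numpy as np
import pandas as pd
from fuse.data.ops.ops_common import OpToOneHot
from fuse.data.ops.ops_read import OpReadDataframe
from fuse.data.pipelines.pipeline_default import PipelineDefault
from mmmt.pipeline.pipeline import MMMTPipeline

N_SAMPLES = 100


def build_static_pipeline(name=None) -> PipelineDefault:
    """Random three-modality toy data, keyed as expected by test/pipeline/test.yaml."""
    # `name` is filled from the `args` block of the YAML `data` section.
    df = pd.DataFrame(
        {
            "sample_id": list(range(N_SAMPLES)),
            "data.input.raw.modality1": [np.random.rand(64).astype(np.float32) for _ in range(N_SAMPLES)],
            "data.input.raw.modality2": [np.random.rand(32).astype(np.float32) for _ in range(N_SAMPLES)],
            "data.input.raw.modality3": [np.random.rand(128).astype(np.float32) for _ in range(N_SAMPLES)],
            "data.ground_truth": [int(np.random.randint(2)) for _ in range(N_SAMPLES)],
        }
    )
    return PipelineDefault(
        "static",
        [
            (OpReadDataframe(df), dict()),
            # One-hot encode the binary label, as the test dataset does.
            (OpToOneHot(2), {"key_in": "data.ground_truth", "key_out": "data.ground_truth"}),
        ],
    )


def build_splits(name=None):
    """Fixed 50/25/25 train/validation/test split over the toy samples."""
    return {
        "train_ids": list(range(50)),
        "val_ids": list(range(50, 75)),
        "test_ids": list(range(75, 100)),
    }


if __name__ == "__main__":
    # Keys must match the object names referenced in the YAML configuration.
    specific_objects = {
        "static_pipeline": {"object": build_static_pipeline},
        "get_splits": {"object": build_splits},
    }
    pipeline = MMMTPipeline(
        "test/pipeline/test.yaml", specific_objects, defaults="test/pipeline/test.yaml"
    )
    pipeline.run_pipeline()
```

For real data, the two builders would be replaced by functions returning the project's own fuse pipeline and split dictionary, and the YAML configuration adapted accordingly.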