├── .github └── workflows │ ├── ci.yml │ ├── lint.yml │ └── test_old.yml ├── .gitignore ├── LICENSE ├── README.md ├── bdpy ├── __init__.py ├── bdata │ ├── __init__.py │ ├── bdata.py │ ├── featureselector.py │ ├── metadata.py │ └── utils.py ├── dataform │ ├── __init__.py │ ├── datastore.py │ ├── features.py │ ├── kvs.py │ ├── pd.py │ ├── sparse.py │ └── utils.py ├── dataset │ ├── __init__.py │ └── utils.py ├── distcomp │ ├── __init__.py │ └── distcomp.py ├── dl │ ├── __init__.py │ ├── caffe.py │ └── torch │ │ ├── __init__.py │ │ ├── base.py │ │ ├── dataset.py │ │ ├── domain │ │ ├── __init__.py │ │ ├── core.py │ │ ├── feature_domain.py │ │ └── image_domain.py │ │ ├── models.py │ │ └── torch.py ├── evals │ ├── __init__.py │ └── metrics.py ├── feature │ ├── __init__.py │ └── feature.py ├── fig │ ├── __init__.py │ ├── draw_group_image_set.py │ ├── fig.py │ ├── makeplots.py │ ├── makeplots2.py │ └── tile_images.py ├── ml │ ├── __init__.py │ ├── crossvalidation.py │ ├── ensemble.py │ ├── learning.py │ ├── model.py │ ├── regress.py │ └── searchlight.py ├── mri │ ├── __init__.py │ ├── fmriprep.py │ ├── glm.py │ ├── image.py │ ├── load_epi.py │ ├── load_mri.py │ ├── roi.py │ └── spm.py ├── opendata │ ├── __init__.py │ └── openneuro.py ├── pipeline │ └── config.py ├── preproc │ ├── __init__.py │ ├── interface.py │ ├── preprocessor.py │ ├── select_top.py │ └── util.py ├── recon │ ├── __init__.py │ ├── torch │ │ ├── __init__.py │ │ ├── icnn.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── critic.py │ │ │ ├── encoder.py │ │ │ ├── generator.py │ │ │ ├── latent.py │ │ │ └── optimizer.py │ │ └── task │ │ │ ├── __init__.py │ │ │ └── inversion.py │ └── utils.py ├── stats │ ├── __init__.py │ └── corr.py ├── task │ ├── __init__.py │ ├── callback.py │ └── core.py └── util │ ├── __init__.py │ ├── info.py │ ├── math.py │ └── utils.py ├── docs ├── _config.yml ├── bdata_api_examples.md ├── dataform_features.md └── index.md ├── examples ├── .gitignore ├── bdata_labels.ipynb ├── data │ ├── 
sample_vmap.h5 │ └── sample_vmap_nomap.h5 ├── example_fmriprep.ipynb └── fig.ipynb ├── mypy.ini ├── pyproject.toml └── tests ├── .gitignore ├── __init__.py ├── bdata ├── __init__.py ├── test_bdata.py ├── test_featureselector.py ├── test_metadata.py └── test_utils.py ├── data ├── array_jl_dense_v1.mat ├── array_jl_sparse_v1.mat ├── mri │ ├── epi0001.hdr │ ├── epi0001.img │ ├── epi0002.hdr │ ├── epi0002.img │ ├── epi0003.hdr │ ├── epi0003.img │ ├── epi0004.hdr │ ├── epi0004.img │ ├── epi0005.hdr │ └── epi0005.img ├── test_models │ ├── fastl2lir-chunk-bd │ │ ├── W │ │ │ ├── 00000000.mat │ │ │ ├── 00000001.mat │ │ │ ├── 00000002.mat │ │ │ ├── 00000003.mat │ │ │ ├── 00000004.mat │ │ │ ├── 00000005.mat │ │ │ ├── 00000006.mat │ │ │ └── 00000007.mat │ │ ├── b │ │ │ ├── 00000000.mat │ │ │ ├── 00000001.mat │ │ │ ├── 00000002.mat │ │ │ ├── 00000003.mat │ │ │ ├── 00000004.mat │ │ │ ├── 00000005.mat │ │ │ ├── 00000006.mat │ │ │ └── 00000007.mat │ │ └── info.yaml │ ├── fastl2lir-chunk-pkl │ │ ├── 00000000.pkl.gz │ │ ├── 00000001.pkl.gz │ │ ├── 00000002.pkl.gz │ │ ├── 00000003.pkl.gz │ │ ├── 00000004.pkl.gz │ │ ├── 00000005.pkl.gz │ │ ├── 00000006.pkl.gz │ │ ├── 00000007.pkl.gz │ │ └── info.yaml │ ├── fastl2lir-nochunk-bd │ │ ├── W.mat │ │ ├── b.mat │ │ └── info.yaml │ ├── fastl2lir-nochunk-pkl │ │ ├── info.yaml │ │ └── model.pkl.gz │ └── lir-nochunk-pkl │ │ ├── info.yaml │ │ └── model.pkl.gz ├── testdata-2d-nan.pkl.gz └── testdata-2d.pkl.gz ├── dataform ├── __init__.py ├── test_features.py ├── test_kvs.py └── test_sparse.py ├── distcomp ├── __init__.py └── test_distcomp.py ├── dl ├── __init__.py └── torch │ ├── __init__.py │ ├── domain │ ├── __init__.py │ ├── test_core.py │ ├── test_feature_domain.py │ └── test_image_domain.py │ ├── test_models.py │ └── test_torch.py ├── env ├── py27 │ └── Pipfile └── py38 │ └── Pipfile ├── evals ├── __init__.py └── test_metrics.py ├── feature ├── __init__.py └── test_feature.py ├── ml ├── __init__.py ├── test_crossvalidation.py ├── 
test_ensemble.py ├── test_learning.py └── test_regress.py ├── preproc ├── __init__.py ├── test_interface.py └── test_select_top.py ├── recon ├── __init__.py └── torch │ ├── __init__.py │ ├── modules │ ├── __init__.py │ ├── test_critic.py │ ├── test_encoder.py │ ├── test_generator.py │ ├── test_latent.py │ └── test_optimizer.py │ └── task │ ├── __init__.py │ └── test_inversion.py ├── task ├── __init__.py ├── test_callback.py └── test_core.py ├── test_mri.py ├── test_pipeline.py ├── test_stats.py └── util ├── __init__.py ├── test_math.py └── test_utils.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: ci 5 | 6 | on: 7 | push: 8 | branches: [ "dev" ] 9 | pull_request: 10 | branches: [ "dev" ] 11 | 12 | jobs: 13 | test: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | python-version: ["3.8", "3.9", "3.10", "3.11"] 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v4 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | cache: "pip" 28 | 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | python -m pip install pytest pytest-cov pytest-github-actions-annotate-failures==0.2.0 33 | pip install .[dev] 34 | 35 | # https://stackoverflow.com/questions/985876/tee-and-exit-status 36 | - name: Test with pytest 37 | run: | 38 | ( 39 | set -o pipefail 40 | pytest --junitxml=pytest.yml \ 41 | --cov-report=term-missing:skip-covered --cov=bdpy tests | tee pytest-coverage.txt 42 | ) 43 | 44 | - name: Upload coverage comment 45 | if: always() 46 | uses: MishaKav/pytest-coverage-comment@v1.1.47 47 | with: 48 | 
pytest-coverage-path: pytest-coverage.txt 49 | junitxml-path: pytest.yml 50 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | push: 5 | branches: [ "lint" ] 6 | 7 | jobs: 8 | lint: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: Lint with ruff 13 | uses: chartboost/ruff-action@v1 14 | 15 | type-check: 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | 21 | - name: Set up Python 3.8 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: 3.8 25 | cache: "pip" 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install mypy mypy-gh-action-report 31 | pip install .[all] 32 | 33 | - name: Static type check with mypy 34 | run: | 35 | mypy . | mypy-gh-action-report 36 | 37 | -------------------------------------------------------------------------------- /.github/workflows/test_old.yml: -------------------------------------------------------------------------------- 1 | name: test_old 2 | 3 | on: 4 | push: 5 | branches: [ "test_old" ] 6 | 7 | jobs: 8 | test_py36_py37: 9 | runs-on: ubuntu-20.04 10 | strategy: 11 | fail-fast: false 12 | matrix: 13 | python-version: ["3.6", "3.7"] 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | cache: "pip" 23 | 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | python -m pip install pytest pytest-cov pytest-github-actions-annotate-failures 28 | pip install .[dev] 29 | 30 | - name: Test with pytest 31 | run: | 32 | ( 33 | set -o pipefail 34 | pytest --junitxml=pytest.yml \ 35 | --cov-report=term-missing:skip-covered --cov=bdpy tests 
| tee pytest-coverage.txt 36 | ) 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *~ 3 | 4 | *.pyc 5 | .python-version 6 | .pydevproject 7 | .project 8 | *.ipynb 9 | .ipynb_checkpoints 10 | build 11 | dist 12 | *.egg-info 13 | 14 | *.npy 15 | *.mat 16 | *.h5 17 | 18 | .coverage 19 | htmlcov 20 | .pylintrc 21 | 22 | junk 23 | tmp 24 | test_local 25 | test_versions 26 | 27 | /venv 28 | /venv-* 29 | /.venv 30 | 31 | *.code-workspace 32 | 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017-2024 Kamitani Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BdPy 2 | 3 | [![PyPI version](https://badge.fury.io/py/bdpy.svg)](https://badge.fury.io/py/bdpy) 4 | [![GitHub license](https://img.shields.io/github/license/KamitaniLab/bdpy)](https://github.com/KamitaniLab/bdpy/blob/master/LICENSE) 5 | [![ci](https://github.com/KamitaniLab/bdpy/actions/workflows/ci.yml/badge.svg)](https://github.com/KamitaniLab/bdpy/actions/workflows/ci.yml) 6 | 7 | Python package for brain decoding analysis 8 | 9 | ## Requirements 10 | 11 | - Python 3.8 or later 12 | - numpy 13 | - scipy 14 | - scikit-learn 15 | - pandas 16 | - h5py 17 | - hdf5storage 18 | - pyyaml 19 | 20 | ### Optional requirements 21 | 22 | - `dataform` module 23 | - pandas 24 | - `dl.caffe` module 25 | - Caffe 26 | - Pillow 27 | - tqdm 28 | - `dl.torch` module 29 | - PyTorch 30 | - Pillow 31 | - `fig` module 32 | - matplotlib 33 | - Pillow 34 | - `bdpy.ml` module 35 | - tqdm 36 | - `mri` module 37 | - nipy 38 | - nibabel 39 | - pandas 40 | - `recon.torch` module 41 | - PyTorch 42 | - Pillow 43 | 44 | ### Optional requirements for testing 45 | - fastl2lir 46 | 47 | ## Installation 48 | 49 | Latest stable release: 50 | 51 | ``` shell 52 | $ pip install bdpy 53 | ``` 54 | 55 | To install the latest development version ("master" branch of the repository), please run the following command. 
56 | 57 | ```shell 58 | $ pip install git+https://github.com/KamitaniLab/bdpy.git 59 | ``` 60 | 61 | ## Packages 62 | 63 | - bdata: BdPy data format (BData) core package 64 | - dataform: Utilities for various data format 65 | - distcomp: Distributed computation utilities 66 | - dl: Deep learning utilities 67 | - feature: Utilities for DNN features 68 | - fig: Utilities for figure creation 69 | - ml: Machine learning utilities 70 | - mri: MRI utilities 71 | - opendata: Open data utilities 72 | - preproc: Utilities for preprocessing 73 | - recon: Reconstruction methods 74 | - stats: Utilities for statistics 75 | - util: Miscellaneous utilities 76 | 77 | ## BdPy data format 78 | 79 | BdPy data format (or BrainDecoderToolbox2 data format; BData) consists of two variables: dataset and metadata. **dataset** stores brain activity data (e.g., voxel signal value for fMRI data), target variables (e.g., ID of stimuli for vision experiments), and additional information specifying experimental design (e.g., run and block numbers for fMRI experiments). Each row corresponds to a single 'sample', and each column representes either single feature (voxel), target, or experiment design information. **metadata** contains data describing meta-information for each column in dataset. 80 | 81 | See [BData API examples](https://github.com/KamitaniLab/bdpy/blob/main/docs/bdata_api_examples.md) for useage of BData. 82 | 83 | ## Developers 84 | 85 | - Shuntaro C. Aoki (Kyoto Univ) 86 | -------------------------------------------------------------------------------- /bdpy/__init__.py: -------------------------------------------------------------------------------- 1 | """BdPy: Brain decoding toolbox for Python. 2 | 3 | Developed by Kamitani Lab, Kyoto Univ. and ATR. 4 | """ 5 | 6 | 7 | # `import bdpy` implicitly imports class `BData` (in package `bdata`) and 8 | # package `util`. 
9 | from .bdata import BData 10 | from .bdata import vstack, metadata_equal 11 | from .util import create_groupvector, divide_chunks, get_refdata, makedir_ifnot, dump_info, average_elemwise 12 | -------------------------------------------------------------------------------- /bdpy/bdata/__init__.py: -------------------------------------------------------------------------------- 1 | """BdPy data package. 2 | 3 | This package is a part of BdPy. 4 | """ 5 | 6 | 7 | from .bdata import BData 8 | from .utils import concat_dataset, vstack, metadata_equal 9 | -------------------------------------------------------------------------------- /bdpy/bdata/featureselector.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Feature selector class 3 | 4 | This file is a part of BdPy 5 | ''' 6 | 7 | 8 | class FeatureSelector(object): 9 | ''' 10 | Feature selector class 11 | 12 | Parameters 13 | ---------- 14 | expression : str 15 | Selection command 16 | 17 | Attributes 18 | ---------- 19 | expression : str 20 | Selection command 21 | token : tuple 22 | Tokens 23 | rpn : tuple 24 | Tokens in reversed polish notation 25 | ''' 26 | 27 | # Class variables ################## 28 | signs = ('(', ')') 29 | operators = ('=', '|', '&', '+', '-', '@') 30 | 31 | __op_order = {'=': 10, 32 | '|': 5, 33 | '&': 5, 34 | '+': 5, 35 | '-': 5, 36 | '@': 3, 37 | '(': -1, 38 | ')': -1} 39 | 40 | # Methods ########################## 41 | 42 | def __init__(self, expression): 43 | self.expression = expression 44 | self.token = self.lexical_analysis(self.expression) 45 | self.rpn = self.parse(self.token) 46 | 47 | self.index = None 48 | 49 | def lexical_analysis(self, expression): 50 | '''Lexical analyser''' 51 | 52 | str_buf = '' 53 | output_buf = [] 54 | 55 | i = 0 56 | while i < len(expression): 57 | if expression[i] == ' ': 58 | # Ignore a white-space 59 | i += 1 60 | continue 61 | elif expression[i] == '"': 62 | i += 1 63 | while expression[i] != '"': 64 | 
str_buf += expression[i] 65 | i += 1 66 | i += 1 67 | continue 68 | elif expression[i] == "'": 69 | i += 1 70 | while expression[i] != "'": 71 | str_buf += expression[i] 72 | i += 1 73 | i += 1 74 | continue 75 | elif self.signs.count(expression[i]) or self.operators.count(expression[i]): 76 | if len(str_buf) > 0: 77 | output_buf.append(str_buf) 78 | str_buf = '' 79 | 80 | output_buf.append(expression[i]) 81 | else: 82 | str_buf += expression[i] 83 | 84 | i += 1 85 | 86 | if len(str_buf) > 0: 87 | output_buf.append(str_buf) 88 | str_buf = '' 89 | 90 | # Convert '+' to '|' 91 | output_buf = ['|' if a == '+' else a for a in output_buf] 92 | 93 | return tuple(output_buf) 94 | 95 | def parse(self, token_list): 96 | '''Parser for selection command''' 97 | 98 | out_que = [] 99 | op_stack = [] 100 | 101 | for token in token_list: 102 | 103 | if self.operators.count(token): 104 | while op_stack: 105 | if self.__op_order[token] > self.__op_order[op_stack[-1]]: 106 | break 107 | out_que.append(op_stack.pop()) 108 | 109 | op_stack.append(token) 110 | elif token == '(': 111 | op_stack.append('(') 112 | elif token == ')': 113 | while op_stack: 114 | if op_stack[-1] == '(': 115 | op_stack.pop() 116 | else: 117 | out_que.append(op_stack.pop()) 118 | else: 119 | out_que.append(token) 120 | 121 | while op_stack: 122 | out_que.append(op_stack.pop()) 123 | 124 | return tuple(out_que) 125 | -------------------------------------------------------------------------------- /bdpy/bdata/metadata.py: -------------------------------------------------------------------------------- 1 | """ 2 | MetaData class 3 | 4 | This file is a part of BdPy 5 | """ 6 | 7 | 8 | import numpy as np 9 | 10 | 11 | class MetaData(object): 12 | """ 13 | MetaData class 14 | 15 | 'MetaData' is a list of dictionaries. Each element has three keys: 'key', 16 | 'value', and 'description'. 
17 | """ 18 | 19 | 20 | def __init__(self, key=None, value=None, description=None): 21 | if key is None: 22 | key = [] 23 | if value is None: 24 | value = np.ndarray((0, 0), dtype=float) 25 | if description is None: 26 | description = [] 27 | 28 | self.__key = key 29 | self.__value = value 30 | self.__description = description 31 | 32 | @property 33 | def key(self): 34 | return self.__key 35 | 36 | @key.setter 37 | def key(self, x): 38 | self.__key = x 39 | 40 | @property 41 | def value(self): 42 | return self.__value 43 | 44 | @value.setter 45 | def value(self, x): 46 | self.__value = x 47 | 48 | @property 49 | def description(self): 50 | return self.__description 51 | 52 | @description.setter 53 | def description(self, x): 54 | self.__description = x 55 | 56 | def set(self, key, value, description, updater=None): 57 | """ 58 | Set meta-data with `key`, `description`, and `value` 59 | 60 | Parameters 61 | ---------- 62 | key : str 63 | Meta-data key 64 | value : array_like 65 | Meta-data value 66 | description : str 67 | Meta-data description 68 | updater : function 69 | Function applied to meta-data value when meta-data named `key` already exists. 70 | It should take two args: new and old meta-data values. 71 | """ 72 | 73 | # If `value` is None, `set` does not update the value. 74 | is_novalue = True if value is None else False 75 | 76 | value = np.array(value) 77 | 78 | if key in self.__key: 79 | # Update existing metadata 80 | 81 | ind = [i for i, k in enumerate(self.__key) if k == key] 82 | 83 | if len(ind) > 1: 84 | raise ValueError('Multiple meta-data with the same key is not supported') 85 | 86 | ind = ind[0] 87 | 88 | self.__description[ind] = description 89 | 90 | # If `value` is None, `set` does not update the value. 
91 | if is_novalue: 92 | return None 93 | 94 | if value.shape[0] > self.get_value_len(): 95 | cols = np.empty((self.__value.shape[0], value.shape[0] - self.get_value_len())) 96 | cols[:] = np.nan 97 | 98 | self.__value = np.hstack([self.__value, cols]) 99 | 100 | if updater is None: 101 | self.__value[ind, :] = value 102 | else: 103 | self.__value[ind, :] = np.array(updater(value, self.__value[ind, :]), dtype=float) 104 | else: 105 | # Add new metadata 106 | self.__key.append(key) 107 | self.__description.append(description) 108 | 109 | if value.shape[0] > self.get_value_len(): 110 | cols = np.empty((self.__value.shape[0], value.shape[0] - self.get_value_len())) 111 | cols[:] = np.nan 112 | 113 | self.__value = np.hstack([self.__value, cols]) 114 | 115 | self.__value = np.vstack([self.__value, value]) 116 | 117 | 118 | def get(self, key, field): 119 | """ 120 | Returns meta-data specified by `key` 121 | 122 | Parameters 123 | ---------- 124 | key : str 125 | Meta-data key 126 | field : str 127 | Field name of meta-data (either 'value' or 'description') 128 | 129 | Returns 130 | ------- 131 | array, str or None 132 | Meta-data value or description. If `key` was not found in 133 | the metadata, `None` is returned. 
134 | """ 135 | if key in self.__key: 136 | ind = self.__key.index(key) 137 | else: 138 | return None 139 | 140 | if field == 'value': 141 | return self.__value[ind, :].astype(float) 142 | 143 | if field == 'description': 144 | return self.__description[ind] 145 | 146 | 147 | def get_value_len(self): 148 | """Returns length of meta-data value""" 149 | return self.__value.shape[1] 150 | 151 | 152 | def keylist(self): 153 | """Returns a list of keys""" 154 | return self.__key 155 | -------------------------------------------------------------------------------- /bdpy/dataform/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | BdPy data format package 3 | 4 | This package is a part of BdPy 5 | """ 6 | 7 | from .pd import * 8 | from .datastore import * 9 | from .sparse import * 10 | from .features import * 11 | from .kvs import SQLite3KeyValueStore 12 | -------------------------------------------------------------------------------- /bdpy/dataform/pd.py: -------------------------------------------------------------------------------- 1 | '''Utilities for Pandas dataframe 2 | 3 | This file is a part of BdPy 4 | ''' 5 | 6 | 7 | __all__ = ['convert_dataframe', 'append_dataframe'] 8 | 9 | 10 | import pandas as pd 11 | 12 | 13 | def convert_dataframe(lst): 14 | '''Convert `lst` to Pandas dataframe 15 | 16 | Parameters 17 | ---------- 18 | lst : list of dicts 19 | 20 | Returns 21 | ------- 22 | Pandas dataframe 23 | ''' 24 | 25 | df_lst = (pd.DataFrame([item.values()], columns=item.keys()) for item in lst) 26 | df = pd.concat(df_lst, axis=0, ignore_index=True) 27 | return df 28 | 29 | 30 | def append_dataframe(df, **kwargs): 31 | '''Append a row to Pandas dataframe `df` 32 | 33 | Parameters 34 | ---------- 35 | df : Pandas dataframe 36 | kwargs : key-value of data to be added in `df` 37 | 38 | Returns 39 | ------- 40 | Pandas dataframe 41 | ''' 42 | 43 | df_append = pd.DataFrame({k : [kwargs[k]] for k in kwargs}) 44 | 
return df.append(df_append, ignore_index=True) 45 | -------------------------------------------------------------------------------- /bdpy/dataform/sparse.py: -------------------------------------------------------------------------------- 1 | '''Sparse array class. 2 | 3 | This file is a part of bdpy. 4 | ''' 5 | 6 | __all__ = ['SparseArray', 'load_array', 'save_array', 'save_multiarrays'] 7 | 8 | 9 | import os 10 | 11 | import numpy as np 12 | import h5py 13 | import hdf5storage 14 | 15 | 16 | def load_array(fname, key='data'): 17 | '''Load an array (dense or sparse).''' 18 | 19 | with h5py.File(fname, 'r') as f: 20 | methods = [attr for attr in dir(f[key]) if callable(getattr(f[key], str(attr)))] 21 | if 'keys' in methods and '__bdpy_sparse_arrray' in f[key].keys(): 22 | # SparseArray 23 | s_ary = SparseArray(fname, key=key) 24 | return s_ary.dense 25 | elif type(f[key][()]) == np.ndarray: 26 | # Dense array 27 | return hdf5storage.loadmat(fname)[key] 28 | else: 29 | raise RuntimeError('Unsupported data type: %s' % type(f[key][()])) 30 | 31 | 32 | def save_array(fname, array, key='data', dtype=np.float64, sparse=False): 33 | '''Save an array (dense or sparse).''' 34 | 35 | if sparse: 36 | # Save as a SparseArray 37 | s_ary = SparseArray(array.astype(dtype)) 38 | s_ary.save(fname, key=key, dtype=dtype) 39 | else: 40 | # Save as a dense array 41 | hdf5storage.savemat(fname, 42 | {key: array.astype(dtype)}, 43 | format='7.3', oned_as='column', 44 | store_python_metadata=True) 45 | 46 | return None 47 | 48 | 49 | def save_multiarrays(fname, arrays): 50 | '''Save arrays (dense).''' 51 | 52 | save_dict = {k: v for k, v in arrays.items()} 53 | hdf5storage.savemat(fname, 54 | save_dict, 55 | format='7.3', oned_as='column', 56 | store_python_metadata=True) 57 | 58 | return None 59 | 60 | 61 | class SparseArray(object): 62 | '''Sparse array class.''' 63 | 64 | def __init__(self, src=None, key='data', background=0): 65 | self.__background = background 66 | 67 | if 
type(src) == np.ndarray: 68 | # Create sparse array from numpy.ndarray 69 | self.__make_sparse(src) 70 | elif os.path.isfile(src): 71 | # Load data from src 72 | self.__load(src, key=key) 73 | else: 74 | raise ValueError('Unsupported input') 75 | 76 | @property 77 | def dense(self): 78 | return self.__make_dense() 79 | 80 | def save(self, fname, key='data', dtype=np.float64): 81 | hdf5storage.savemat(fname, {key: {u'__bdpy_sparse_arrray': True, 82 | u'index': self.__index, 83 | u'value': self.__value.astype(dtype), 84 | u'shape': self.__shape, 85 | u'background' : self.__background}}, 86 | format='7.3', oned_as='column', store_python_metadata=True) 87 | return None 88 | 89 | def __make_sparse(self, array): 90 | self.__index = np.where(array != self.__background) 91 | self.__value = array[self.__index] 92 | self.__shape = array.shape 93 | return None 94 | 95 | def __make_dense(self): 96 | dense = np.ones(self.__shape) * self.__background 97 | dense[self.__index] = self.__value 98 | return dense 99 | 100 | def __load(self, fname, key='data'): 101 | data = hdf5storage.loadmat(fname)[key] 102 | 103 | index = data[u'index'] 104 | if isinstance(index, tuple): 105 | self.__index = index 106 | elif isinstance(index, np.ndarray): 107 | self.__index = tuple(index[0]) 108 | else: 109 | raise TypeError('Unsupported data type ("index").') 110 | 111 | value = data[u'value'] 112 | if value.ndim == 1: 113 | self.__value = value 114 | else: 115 | self.__value = value.flatten() 116 | 117 | array_shape = data[u'shape'] 118 | if isinstance(array_shape, tuple): 119 | self.__shape = array_shape 120 | elif isinstance(array_shape, np.ndarray): 121 | self.__shape = array_shape.flatten() 122 | else: 123 | raise TypeError('Unsupported data type ("shape").') 124 | 125 | self.__background = data[u'background'] 126 | return None 127 | -------------------------------------------------------------------------------- /bdpy/dataform/utils.py: 
-------------------------------------------------------------------------------- 1 | """Utilities for Bdpy dataformat.""" 2 | 3 | from typing import List, Union 4 | 5 | from bdpy.dataform import Features 6 | import numpy as np 7 | 8 | 9 | def get_multi_features(features: List[Features], layer: str, labels: Union[List[str], np.ndarray]) -> np.ndarray: 10 | """Load features from multiple Features.""" 11 | y_list = [] 12 | for label in labels: 13 | for feat in features: 14 | if label not in feat.labels: 15 | continue 16 | f = feat.get(layer=layer, label=label) 17 | y_list.append(f) 18 | return np.vstack(y_list) 19 | -------------------------------------------------------------------------------- /bdpy/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | """Dataset package.""" 2 | -------------------------------------------------------------------------------- /bdpy/dataset/utils.py: -------------------------------------------------------------------------------- 1 | """Dataset utilities.""" 2 | 3 | from typing import List, TypedDict, Union 4 | 5 | import hashlib 6 | import inspect 7 | import os 8 | import subprocess 9 | import urllib.request 10 | 11 | from tqdm import tqdm 12 | 13 | 14 | class FileDict(TypedDict): 15 | name: str 16 | url: str 17 | md5sum: str 18 | 19 | 20 | def download_file(url: str, destination: str, progress_bar: bool = True, md5sum: Union[str, None] = None) -> None: 21 | """Download a file. 22 | 23 | Parameters 24 | ---------- 25 | url: str 26 | File URL. 27 | destination: str 28 | Path to save the file. 29 | progress_bar: bool = True 30 | Show progress bar if True. 31 | md5sum: Union[str, None] = None 32 | md5sum hash of the file. 
33 | 34 | Returns 35 | ------- 36 | None 37 | """ 38 | response = urllib.request.urlopen(url) 39 | file_size = int(response.info()["Content-Length"]) 40 | 41 | def __show_progress(block_num: int, block_size: int, total_size: int) -> None: 42 | downloaded = block_num * block_size 43 | if total_size > 0: 44 | progress_bar.update(downloaded - progress_bar.n) 45 | 46 | with tqdm(total=file_size, unit='B', unit_scale=True, desc=destination, ncols=100) as progress_bar: 47 | urllib.request.urlretrieve(url, destination, __show_progress) 48 | 49 | if md5sum is not None: 50 | md5_hash = hashlib.md5() 51 | with open(destination, 'rb') as f: 52 | for chunk in iter(lambda: f.read(4096), b''): 53 | md5_hash.update(chunk) 54 | md5sum_test = md5_hash.hexdigest() 55 | if md5sum != md5sum_test: 56 | raise ValueError(f'md5sum mismatch. \nExpected: {md5sum}\nActual: {md5sum_test}') 57 | 58 | 59 | def download_splitted_file(file_list: List[FileDict], destination: str, progress_bar: bool = True, md5sum: Union[str, None] = None) -> None: 60 | """Download a file. 61 | 62 | Parameters 63 | ---------- 64 | file_list: List[FileDict] 65 | List of split files. 66 | destination: str 67 | Path to save the file. 68 | progress_bar: bool = True 69 | Show progress bar if True. 70 | md5sum: Union[str, None] = None 71 | md5sum hash of the file. 
72 | 73 | Returns 74 | ------- 75 | None 76 | """ 77 | wdir = os.path.dirname(destination) 78 | 79 | # Download split files 80 | for sf in file_list: 81 | _output = os.path.join(wdir, sf['name']) 82 | if not os.path.exists(_output): 83 | print(f'Downloading {_output} from {sf["url"]}') 84 | download_file(sf['url'], _output, progress_bar=progress_bar, md5sum=sf['md5sum']) 85 | 86 | # Merge files 87 | subprocess.run(f"cat {destination}-* > {destination}", shell=True) 88 | print(f"File created: {destination}") 89 | 90 | # Check md5sum 91 | if md5sum is not None: 92 | md5_hash = hashlib.md5() 93 | with open(destination, 'rb') as f: 94 | for chunk in iter(lambda: f.read(4096), b''): 95 | md5_hash.update(chunk) 96 | md5sum_test = md5_hash.hexdigest() 97 | if md5sum != md5sum_test: 98 | raise ValueError(f'md5sum mismatch. \nExpected: {md5sum}\nActual: {md5sum_test}') 99 | -------------------------------------------------------------------------------- /bdpy/distcomp/__init__.py: -------------------------------------------------------------------------------- 1 | '''Distributed computation package 2 | 3 | This package is a part of BdPy. 4 | ''' 5 | 6 | from .distcomp import * 7 | -------------------------------------------------------------------------------- /bdpy/distcomp/distcomp.py: -------------------------------------------------------------------------------- 1 | '''Distributed computation module 2 | 3 | This file is a part of BdPy. 
4 | ''' 5 | 6 | 7 | __all__ = ['DistComp'] 8 | 9 | 10 | import os 11 | import warnings 12 | import sqlite3 13 | from contextlib import closing 14 | 15 | 16 | class DistComp(object): 17 | '''Distributed computation class''' 18 | 19 | def __init__(self, backend='file', comp_id=None, lockdir='tmp', db_path='./distcomp.db'): 20 | self.__backend = backend # 'file' or 'sqlite3' 21 | self.lockdir = lockdir 22 | self.comp_id = comp_id 23 | self.__db_path = db_path 24 | 25 | self.lockfile = self.__lockfilename(self.comp_id) if self.comp_id != None else None 26 | 27 | if self.__backend == 'sqlite3': 28 | if not os.path.isfile(self.__db_path): 29 | self.__init_db() 30 | 31 | def islocked(self, *args): 32 | if self.__backend == 'file' and len(args) > 0: 33 | raise RuntimeError('File backend does not requires computation ID.') 34 | if self.__backend == 'sqlite3' and len(args) != 1: 35 | raise RuntimeError('SQLite3 backend requires computation ID.') 36 | 37 | if self.__backend == 'file': 38 | if os.path.isfile(self.lockfile): 39 | return True 40 | else: 41 | return False 42 | elif self.__backend == 'sqlite3': 43 | comp_id = args[0] 44 | if self.__status_db(comp_id) == 'locked': 45 | return True 46 | else: 47 | return False 48 | else: 49 | raise ValueError('Unknown backend: %s' % self.__backend) 50 | 51 | def lock(self, *args): 52 | if self.__backend == 'file' and len(args) > 0: 53 | raise RuntimeError('File backend does not requires computation ID.') 54 | if self.__backend == 'sqlite3' and len(args) != 1: 55 | raise RuntimeError('SQLite3 backend requires computation ID.') 56 | 57 | if self.__backend == 'file': 58 | with open(self.lockfile, 'w'): 59 | pass 60 | elif self.__backend == 'sqlite3': 61 | comp_id = args[0] 62 | with sqlite3.connect(self.__db_path, isolation_level='EXCLUSIVE') as db: 63 | try: 64 | db.execute('INSERT INTO computation (name, status) VALUES ("%s", "locked")' % comp_id) 65 | return True 66 | except db.Error: 67 | print('Already locked') 68 | return False 
69 | else: 70 | raise ValueError('Unknown backend: %s' % self.__backend) 71 | 72 | def unlock(self, *args): 73 | if self.__backend == 'file' and len(args) > 0: 74 | raise RuntimeError('File backend does not requires computation ID.') 75 | if self.__backend == 'sqlite3' and len(args) != 1: 76 | raise RuntimeError('SQLite3 backend requires computation ID.') 77 | 78 | if self.__backend == 'file': 79 | try: 80 | os.remove(self.lockfile) 81 | except OSError: 82 | warnings.warn('Failed to unlock the computation. Possibly double running.') 83 | elif self.__backend == 'sqlite3': 84 | comp_id = args[0] 85 | with sqlite3.connect(self.__db_path, isolation_level='EXCLUSIVE') as db: 86 | try: 87 | db.execute('DELETE FROM computation WHERE name = "%s"' % comp_id) 88 | return True 89 | except db.Error: 90 | print('Already unlocked') 91 | return False 92 | else: 93 | raise ValueError('Unknown backend: %s' % self.__backend) 94 | 95 | def islocked_lock(self, *args): 96 | if self.__backend == 'file' and len(args) > 0: 97 | raise RuntimeError('File backend does not requires computation ID.') 98 | 99 | if self.__backend == 'sqlite3': 100 | raise NotImplementedError() 101 | 102 | is_locked = os.path.isfile(self.lockfile) 103 | if not is_locked: 104 | with open(self.lockfile, 'w'): 105 | pass 106 | 107 | return is_locked 108 | 109 | def __lockfilename(self, comp_id): 110 | '''Return the lock file path''' 111 | return os.path.join(self.lockdir, comp_id + '.lock') 112 | 113 | def __init_db(self): 114 | with sqlite3.connect(self.__db_path, isolation_level='EXCLUSIVE') as conn: 115 | c = conn.cursor() 116 | c.execute('CREATE TABLE computation(name TEXT UNIQUE, status TEXT)') 117 | return None 118 | 119 | def __status_db(self, comp_id): 120 | '''Return status of `comp_id`.''' 121 | with sqlite3.connect(self.__db_path, isolation_level='EXCLUSIVE') as db: 122 | r = [row[0] for row in db.execute('SELECT STATUS FROM computation WHERE name = "%s"' % comp_id)] 123 | if len(r) == 0: 124 | st = 
'not_found' 125 | else: 126 | st = r[0] 127 | return st 128 | -------------------------------------------------------------------------------- /bdpy/dl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/bdpy/dl/__init__.py -------------------------------------------------------------------------------- /bdpy/dl/caffe.py: -------------------------------------------------------------------------------- 1 | '''Caffe module.''' 2 | 3 | 4 | import os 5 | 6 | import PIL 7 | import numpy as np 8 | from bdpy.dataform import save_array 9 | from tqdm import tqdm 10 | 11 | 12 | def extract_image_features(image_file, net, layers=[], crop_center=False, image_preproc=[], save_dir=None, verbose=False, progbar=False, return_features=True): 13 | ''' 14 | Extract DNN features of a given image. 15 | 16 | Parameters 17 | ---------- 18 | image_file : str or list 19 | (List of) path to the input image file(s). 20 | net : Caffe network instance 21 | layers : list 22 | List of DNN layers of which features are returned. 23 | crop_center : bool (default: False) 24 | Crop the center of an image or not. 25 | image_preproc : list (default: []) 26 | List of additional preprocessing functions. The function input/output 27 | should be a PIL.Image instance. The preprocessing functions are applied 28 | after RGB conversion, center-cropping, and resizing of the input image. 29 | save_dir : None or str (default: None) 30 | Save the features in the specified directory if not None. 31 | verbose : bool (default: False) 32 | Output verbose messages or not. 33 | return_features: bool (default: True) 34 | Return the extracted features or not. 35 | 36 | Returns 37 | ------- 38 | dict 39 | Dictionary in which keys are DNN layers and values are features. 
40 | ''' 41 | 42 | if isinstance(image_file, str): 43 | image_file = [image_file] 44 | 45 | features_dict = {} 46 | 47 | if progbar: 48 | image_file = tqdm(image_file) 49 | 50 | for imgf in image_file: 51 | if verbose: 52 | print('Image: %s' % imgf) 53 | 54 | image_size = net.blobs['data'].data.shape[-2:] 55 | mean_img = net.transformer.mean['data'] 56 | 57 | # Open the image 58 | img = PIL.Image.open(imgf) 59 | 60 | # Convert non-RGB to RGB 61 | if img.mode == 'CMYK': 62 | img = img.convert('RGB') 63 | 64 | if img.mode == 'RGBA': 65 | bg = PIL.Image.new('RGB', img.size, (255, 255, 255)) 66 | bg.paste(img, mask=img.split()[3]) 67 | img = bg 68 | 69 | # Convert monochrome to RGB 70 | if img.mode == 'L': 71 | img = img.convert('RGB') 72 | 73 | # Center cropping 74 | if crop_center: 75 | w, h = img.size 76 | img = img.crop(((w - min(img.size)) // 2, 77 | (h - min(img.size)) // 2, 78 | (w + min(img.size)) // 2, 79 | (h + min(img.size)) // 2)) 80 | 81 | # Resize 82 | img = img.resize(image_size, PIL.Image.BICUBIC) 83 | 84 | for p in image_preproc: 85 | img = p(img) 86 | 87 | img_array = np.array(img) 88 | 89 | try: 90 | img_array = np.float32(np.transpose(img_array, (2, 0, 1))[::-1]) - np.reshape(mean_img, (3, 1, 1)) 91 | except: 92 | import pdb; pdb.set_trace() 93 | 94 | # Forwarding 95 | net.blobs['data'].reshape(1, 3, img_array.shape[1], img_array.shape[2]) 96 | net.blobs['data'].data[0] = img_array 97 | net.forward() 98 | 99 | # Get features 100 | for lay in layers: 101 | feat = net.blobs[lay].data.copy() 102 | 103 | if return_features: 104 | if lay in features_dict: 105 | features_dict.update({ 106 | lay: np.vstack([features_dict[lay], feat]) 107 | }) 108 | else: 109 | features_dict.update({lay: feat}) 110 | 111 | if not save_dir is None: 112 | # Save the features 113 | save_dir_lay = os.path.join(save_dir, lay.replace('/', ':')) 114 | save_file = os.path.join(save_dir_lay, 115 | os.path.splitext(os.path.basename(imgf))[0] + '.mat') 116 | if not 
os.path.exists(save_dir_lay): 117 | os.makedirs(save_dir_lay) 118 | if os.path.exists(save_file): 119 | if verbose: 120 | print('%s already exists. Skipped.' % save_file) 121 | continue 122 | save_array(save_file, feat, key='feat', dtype=np.float32, sparse=False) 123 | if verbose: 124 | print('%s saved.' % save_file) 125 | 126 | if return_features: 127 | return features_dict 128 | else: 129 | return None 130 | -------------------------------------------------------------------------------- /bdpy/dl/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch import FeatureExtractor, ImageDataset 2 | from .base import * 3 | -------------------------------------------------------------------------------- /bdpy/dl/torch/base.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Base classes. 3 | ''' 4 | 5 | 6 | __all__ = [ 7 | 'DnnFeatureExtractorBase', 8 | 'ReconstructionBase', 9 | ] 10 | 11 | 12 | from typing import Any, Type, Iterable, List, Dict, Tuple, Callable, Union, Optional 13 | 14 | import os 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | _tensor_t = Union[np.ndarray, torch.Tensor] 22 | 23 | 24 | class DnnFeatureExtractorBase(object): 25 | ''' 26 | Base class for PyTorch DNN feature extractors. 27 | 28 | ''' 29 | 30 | def __init__(self, model: Optional[nn.Module] = None, model_cls: Optional[Type[nn.Module]] = None, layers: Iterable[str] = [], device: str = 'cpu', init_args={}): 31 | self.model = model 32 | self.model_cls = model_cls 33 | self.layers = layers 34 | self.device = torch.device(device) 35 | 36 | self.init(**init_args) 37 | 38 | if self.model is None: 39 | raise RuntimeError('`self.model` is None. You should define it it `init()`.') 40 | 41 | self.model.to(self.device) 42 | 43 | def init(self) -> None: 44 | ''' 45 | Custom initialization method. 46 | `init_args` in `__init__()` is passed to this function. 
47 | ''' 48 | return None 49 | 50 | def preprocess(self, x: Any) -> Any: 51 | ''' 52 | Preprocesses the input for the DNN model. 53 | ''' 54 | return x 55 | 56 | def extract_features(self, x: Any) -> Dict[str, np.ndarray]: 57 | ''' 58 | Extracts features from the given input using the DNN model. 59 | ''' 60 | raise NotImplementedError("Subclass must implement extract_features method.") 61 | 62 | def __call__(self, x: Any, **kwargs) -> Dict[str, _tensor_t]: 63 | return self.extract_features(self.preprocess(x), **kwargs) 64 | 65 | 66 | class ReconstructionBase(object): 67 | ''' 68 | Base class for reconstruction. 69 | 70 | ''' 71 | 72 | def __init__(self, model: Optional[nn.Module] = None, model_cls: Optional[Type[nn.Module]] = None, layers: Iterable[str] = [], device: str = 'cpu', init_args={}): 73 | self.model = model 74 | self.model_cls = model_cls 75 | self.layers = layers 76 | self.device = torch.device(device) 77 | 78 | self.init(**init_args) 79 | 80 | if self.model is None: 81 | raise RuntimeError('`self.model` is None. You should define it it `init()`.') 82 | 83 | self.model.to(self.device) 84 | 85 | def init(self) -> None: 86 | ''' 87 | Custom initialization method. 88 | `init_args` in `__init__()` is passed to this function. 89 | ''' 90 | return None 91 | 92 | def preprocess(self, x: Any) -> Any: 93 | ''' 94 | Preprocesses the input for the DNN model. 95 | ''' 96 | return x 97 | 98 | def reconstruct(self, x: Any) -> Any: 99 | ''' 100 | Reconstruction from the given input. 
101 | ''' 102 | raise NotImplementedError("Subclass must implement reconstruct method.") 103 | 104 | def __call__(self, x: Any, **kwargs) -> Any: 105 | return self.reconstruct(self.preprocess(x), **kwargs) 106 | -------------------------------------------------------------------------------- /bdpy/dl/torch/domain/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import Domain, InternalDomain, IrreversibleDomain, ComposedDomain, KeyValueDomain -------------------------------------------------------------------------------- /bdpy/dl/torch/domain/feature_domain.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Dict 4 | 5 | import torch 6 | 7 | from .core import Domain 8 | 9 | _FeatureType = Dict[str, torch.Tensor] 10 | 11 | 12 | def _lnd2nld(feature: torch.Tensor) -> torch.Tensor: 13 | """Convert features having the shape of (L, N, D) to (N, L, D).""" 14 | return feature.permute(1, 0, 2) 15 | 16 | def _nld2lnd(feature: torch.Tensor) -> torch.Tensor: 17 | """Convert features having the shape of (N, L, D) to (L, N, D).""" 18 | return feature.permute(1, 0, 2) 19 | 20 | 21 | class ArbitraryFeatureKeyDomain(Domain): 22 | def __init__( 23 | self, 24 | to_internal: dict[str, str] | None = None, 25 | to_self: dict[str, str] | None = None, 26 | ): 27 | super().__init__() 28 | 29 | if to_internal is None and to_self is None: 30 | raise ValueError("Either to_internal or to_self must be specified.") 31 | 32 | if to_internal is None: 33 | to_internal = {value: key for key, value in to_self.items()} 34 | elif to_self is None: 35 | to_self = {value: key for key, value in to_internal.items()} 36 | 37 | self._to_internal = to_internal 38 | self._to_self = to_self 39 | 40 | def send(self, features: _FeatureType) -> _FeatureType: 41 | return {self._to_internal.get(key, key): value for key, value in features.items()} 42 | 43 | def 
receive(self, features: _FeatureType) -> _FeatureType: 44 | return {self._to_self.get(key, key): value for key, value in features.items()} 45 | -------------------------------------------------------------------------------- /bdpy/evals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/bdpy/evals/__init__.py -------------------------------------------------------------------------------- /bdpy/evals/metrics.py: -------------------------------------------------------------------------------- 1 | '''bdpy.evals.metrics''' 2 | 3 | import warnings 4 | 5 | import numpy as np 6 | from scipy.spatial.distance import cdist 7 | 8 | 9 | def profile_correlation(x, y): 10 | '''Profile correlation.''' 11 | 12 | sample_axis = 0 13 | 14 | orig_shape = x.shape 15 | n_sample = orig_shape[sample_axis] 16 | 17 | _x = x.reshape(n_sample, -1) 18 | _y = y.reshape(n_sample, -1) 19 | 20 | n_feat = _y.shape[1] 21 | 22 | r = np.array( 23 | [ 24 | np.corrcoef( 25 | _x[:, j].ravel(), 26 | _y[:, j].ravel() 27 | )[0, 1] 28 | for j in range(n_feat) 29 | ] 30 | ) 31 | 32 | r = r.reshape((1,) + orig_shape[1:]) 33 | 34 | return r 35 | 36 | 37 | def pattern_correlation(x, y, mean=None, std=None, remove_nan=True): 38 | '''Pattern correlation.''' 39 | 40 | sample_axis = 0 41 | 42 | orig_shape = x.shape 43 | n_sample = orig_shape[sample_axis] 44 | 45 | _x = x.reshape(n_sample, -1) 46 | _y = y.reshape(n_sample, -1) 47 | 48 | if mean is not None and std is not None: 49 | m = mean.reshape(-1) 50 | s = std.reshape(-1) 51 | 52 | _x = (_x - m) / s 53 | _y = (_y - m) / s 54 | 55 | if remove_nan: 56 | # Remove nan columns based on the decoded features 57 | nan_cols = np.isnan(_x).any(axis=0) | np.isnan(_y).any(axis=0) 58 | if nan_cols.any(): 59 | warnings.warn('NaN column removed ({})'.format(np.sum(nan_cols))) 60 | _x = _x[:, ~nan_cols] 61 | _y = _y[:, ~nan_cols] 62 | 63 | r = 
np.array( 64 | [ 65 | np.corrcoef( 66 | _x[i, :].ravel(), 67 | _y[i, :].ravel() 68 | )[0, 1] 69 | for i in range(n_sample) 70 | ] 71 | ) 72 | 73 | return r 74 | 75 | 76 | def pattern_cross_correlation(x, y, mean=None, std=None, remove_nan=True): 77 | '''Pattern correlation. 78 | Output: cross correlation of size (n_sample, n_sample). 79 | The (i,j) element of r corresponds to the correlation between i-th row of x and j-th row of y. 80 | ''' 81 | 82 | sample_axis = 0 83 | 84 | orig_shape = x.shape 85 | n_sample = orig_shape[sample_axis] 86 | 87 | _x = x.reshape(n_sample, -1) 88 | _y = y.reshape(n_sample, -1) 89 | 90 | if mean is not None and std is not None: 91 | if mean.shape[sample_axis] == n_sample: 92 | # if mean and std are different across samples 93 | m = mean.reshape(n_sample, -1) 94 | s = std.reshape(n_sample, -1) 95 | else: 96 | m = mean.reshape(-1) 97 | s = std.reshape(-1) 98 | 99 | _x = (_x - m) / s 100 | _y = (_y - m) / s 101 | 102 | if remove_nan: 103 | # Remove nan columns based on the decoded features 104 | nan_cols = np.isnan(_x).any(axis=0) | np.isnan(_y).any(axis=0) 105 | if nan_cols.any(): 106 | warnings.warn('NaN column removed ({})'.format(np.sum(nan_cols))) 107 | _x = _x[:, ~nan_cols] 108 | _y = _y[:, ~nan_cols] 109 | 110 | r = np.corrcoef( _x, _y)[:n_sample, n_sample:] 111 | 112 | return r 113 | 114 | 115 | def pairwise_identification(pred, true, metric='correlation', remove_nan=True, remove_nan_dist=True, single_trial=False, pred_labels=None, true_labels=None): 116 | '''Pair-wise identification.''' 117 | 118 | p = pred.reshape(pred.shape[0], -1) 119 | t = true.reshape(true.shape[0], -1) 120 | 121 | if remove_nan: 122 | # Remove nan columns based on the decoded features 123 | nan_cols = np.isnan(p).any(axis=0) | np.isnan(t).any(axis=0) 124 | if nan_cols.any(): 125 | warnings.warn('NaN column removed ({})'.format(np.sum(nan_cols))) 126 | p = p[:, ~nan_cols] 127 | t = t[:, ~nan_cols] 128 | 129 | if single_trial: 130 | cr = [] 131 | for i in 
range(p.shape[0]): 132 | d = 1 - cdist(p[i][np.newaxis], t, metric=metric) 133 | # label の情報 134 | ind = np.where(np.array(true_labels) == pred_labels[i])[0][0] 135 | 136 | s = (d - d[0, ind]).ravel() 137 | if remove_nan_dist and np.isnan(s).any(): 138 | warnings.warn('NaN value detected in the distance matrix ({}).'.format(np.sum(np.isnan(s)))) 139 | s = s[~np.isnan(s)] 140 | ac = np.sum(s < 0) / (len(s) - 1) 141 | cr.append(ac) 142 | cr = np.asarray(cr) 143 | else: 144 | d = 1 - cdist(p, t, metric=metric) 145 | 146 | if remove_nan_dist: 147 | cr = [] 148 | for d_ind in range(d.shape[0]): 149 | pef = d[d_ind, :] - d[d_ind, d_ind] 150 | if np.isnan(pef).any(): 151 | warnings.warn('NaN value detected in the distance matrix ({}).'.format(np.sum(np.isnan(pef)))) 152 | pef = pef[~np.isnan(pef)] # Remove nan value from the comparison for identification 153 | pef = np.sum(pef < 0) / (len(pef) - 1) 154 | cr.append(pef) 155 | cr = np.asarray(cr) 156 | else: 157 | cr = np.sum(d - np.diag(d)[:, np.newaxis] < 0, axis=1) / (d.shape[1] - 1) 158 | 159 | return cr 160 | 161 | 162 | def remove_nan_value(array, nan_flag=None, return_nan_flag=False): 163 | '''Helper function: 164 | Remove columns (units) which contain nan values 165 | 166 | array (numpy.array) ... shape should be [sample x units] 167 | nan_flag (numpy.array or list) ... if exist, remove columns according to the nan_flag 168 | return_nan_flag (bool) ... if True, return nan_flag to remove columns of the array. 
169 | 170 | ''' 171 | 172 | if nan_flag is None: 173 | nan_flag = np.isnan(array).any(axis=0) 174 | nan_removed_array = array[:, ~nan_flag] 175 | 176 | if return_nan_flag: 177 | return nan_removed_array, nan_flag 178 | else: 179 | return nan_removed_array 180 | -------------------------------------------------------------------------------- /bdpy/feature/__init__.py: -------------------------------------------------------------------------------- 1 | '''Feature engineering module.''' 2 | 3 | from .feature import normalize_feature 4 | -------------------------------------------------------------------------------- /bdpy/feature/feature.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def normalize_feature(feature, 5 | channel_wise_mean=True, channel_wise_std=True, 6 | channel_axis=0, 7 | std_ddof=1, 8 | shift=None, scale=None, 9 | scaling_only=False): 10 | '''Normalize feature. 11 | 12 | Parameters 13 | ---------- 14 | feature : ndarray 15 | Feature to be normalized. 16 | channel_wise_mean, channel_wise_std : bool (default: True) 17 | If `True`, run channel-wise mean/SD normalization. 18 | channel_axis : int (default: 0) 19 | Channel axis. 20 | shift, scale : None, 'self', or ndarray (default: None) 21 | If shift/scale is `None`, nothing will be added/multiplied to the normalized features. 22 | If `'self'`, mean/SD of `feature` will be added/multiplied to the normalized features. 23 | If ndarrays are given, the arrays will be added/multiplied to the normalized features. 24 | std_ddof : int (default: 1) 25 | Delta degree of freedom for SD. 26 | 27 | Returns 28 | ------- 29 | ndarray 30 | Normalized (and scaled/shifted) features. 
31 | ''' 32 | 33 | if feature.ndim == 1: 34 | axes_along = None 35 | else: 36 | axes = list(range(feature.ndim)) 37 | axes.remove(channel_axis) 38 | axes_along = tuple(axes) 39 | 40 | if channel_wise_mean: 41 | feat_mean = np.mean(feature, axis=axes_along, keepdims=True) 42 | else: 43 | feat_mean = np.mean(feature, keepdims=True) 44 | 45 | if channel_wise_std: 46 | feat_std = np.std(feature, axis=axes_along, ddof=std_ddof, keepdims=True) 47 | else: 48 | feat_std = np.mean(np.std(feature, axis=axes_along, ddof=std_ddof, keepdims=True), keepdims=True) 49 | 50 | if isinstance(shift, str) and shift == 'self': 51 | shift = feat_mean 52 | 53 | if isinstance(scale, str) and scale == 'self': 54 | scale = feat_std 55 | 56 | if scaling_only: 57 | feat_n = (feature / feat_std) * scale 58 | else: 59 | feat_n = ((feature - feat_mean) / feat_std) 60 | 61 | if not scale is None: 62 | feat_n = feat_n * scale 63 | if not shift is None: 64 | feat_n = feat_n + shift 65 | 66 | if not feature.shape == feat_n.shape: 67 | try: 68 | feat_n.reshape(feature.shape) 69 | except: 70 | raise ValueError('Invalid shape of normalized features (original: %s, normalized: %s). ' 71 | + 'Possibly incorrect shift and/or scale.' 72 | % (str(feature.shape), str(feat_n.shape))) 73 | 74 | return feat_n 75 | -------------------------------------------------------------------------------- /bdpy/fig/__init__.py: -------------------------------------------------------------------------------- 1 | '''Figure package 2 | 3 | This package is a part of BdPy. 4 | ''' 5 | 6 | from .fig import * 7 | from .tile_images import tile_images 8 | from .draw_group_image_set import draw_group_image_set 9 | from .makeplots import makeplots 10 | from .makeplots2 import makeplots2 11 | -------------------------------------------------------------------------------- /bdpy/fig/fig.py: -------------------------------------------------------------------------------- 1 | '''Figure module 2 | 3 | This file is a part of BdPy. 
4 | 5 | Functions 6 | --------- 7 | makefigure 8 | Create a figure 9 | box_off 10 | Remove upper and right axes 11 | draw_footnote 12 | Draw footnote on a figure 13 | ''' 14 | 15 | 16 | __all__ = [ 17 | 'box_off', 18 | 'draw_footnote', 19 | 'make_violinplots', 20 | 'makefigure', 21 | ] 22 | 23 | 24 | import matplotlib.pyplot as plt 25 | import numpy as np 26 | import seaborn as sns 27 | 28 | 29 | def makefigure(figtype='a4landscape'): 30 | '''Create a figure''' 31 | 32 | if figtype == 'a4landscape': 33 | figsize = (11.7, 8.3) 34 | elif figtype == 'a4portrait': 35 | figsize = (8.3, 11.7) 36 | else: 37 | raise ValueError('Unknown figure type %s' % figtype) 38 | 39 | return plt.figure(figsize=figsize) 40 | 41 | 42 | def box_off(ax): 43 | '''Remove upper and right axes''' 44 | 45 | ax.spines['right'].set_visible(False) 46 | ax.spines['top'].set_visible(False) 47 | ax.xaxis.set_ticks_position('bottom') 48 | ax.yaxis.set_ticks_position('left') 49 | 50 | 51 | def draw_footnote(fig, string, fontsize=9): 52 | '''Draw footnote on a figure''' 53 | ax = fig.add_axes([0., 0., 1., 1.]) 54 | ax.text(0.5, 0.01, string, horizontalalignment='center', fontsize=fontsize) 55 | ax.patch.set_alpha(0.0) 56 | ax.set_axis_off() 57 | 58 | return ax 59 | 60 | 61 | def make_violinplots(df, x=None, y=None, subplot=None, figure=None, x_list=None, subplot_list=None, figure_list=None, title=None, x_label=None, y_label=None, fontsize=16, points=100): 62 | 63 | x_keys = sorted(df[x].unique()) 64 | subplot_keys = sorted(df[subplot].unique()) 65 | figure_keys = sorted(df[figure].unique()) 66 | 67 | x_list = x_keys if x_list is None else x_list 68 | subplot_list = subplot_keys if subplot_list is None else subplot_list 69 | figure_list = figure_keys if figure_list is None else figure_list 70 | 71 | print('X: {}'.format(x_list)) 72 | print('Subplot: {}'.format(subplot_list)) 73 | print('Figures: {}'.format(figure_list)) 74 | 75 | col_num = np.ceil(np.sqrt(len(subplot_list))) 76 | row_num = 
int(np.ceil(len(subplot_list) / col_num)) 77 | col_num = int(col_num) 78 | 79 | print('Subplot in {} x {}'.format(row_num, col_num)) 80 | 81 | figs = [] 82 | 83 | # Figure loop 84 | for fig_label in figure_list: 85 | print('Creating figure for {}'.format(fig_label)) 86 | fig = makefigure('a4landscape') 87 | 88 | sns.set() 89 | sns.set_style('ticks') 90 | sns.set_palette('gray') 91 | 92 | # Subplot loop 93 | for i, sp_label in enumerate(subplot_list): 94 | print('Creating subplot for {}'.format(sp_label)) 95 | 96 | # Set subplot position 97 | col = int(i / row_num) 98 | row = i - col * row_num 99 | sbpos = (row_num - row - 1) * col_num + col + 1 100 | 101 | # Get data 102 | data = [] 103 | for j, x_lbl in enumerate(x_list): 104 | df_t = df.query('{} == "{}" & {} == "{}" & {} == "{}"'.format(subplot, sp_label, figure, fig_label, x, x_lbl)) 105 | data_t = df_t[y].values 106 | data_t = np.array([np.nan, np.nan]) if len(data_t) == 0 else np.concatenate(data_t) 107 | # violinplot requires at least two elements in the dataset 108 | data.append(data_t) 109 | 110 | # Plot 111 | ax = plt.subplot(row_num, col_num, sbpos) 112 | 113 | ax.hlines(0, xmin=-1, xmax=len(x_list), color='k', linestyle='-', linewidth=0.5) 114 | ax.hlines([-0.4, -0.2, 0.2, 0.4, 0.6, 0.8], xmin=-1, xmax=len(x_list), color='k', linestyle=':', linewidth=0.5) 115 | 116 | xpos = range(len(x_list)) 117 | 118 | ax.violinplot(data, xpos, showmeans=True, showextrema=False, showmedians=False, points=points) 119 | 120 | ax.text(-0.5, 0.85, sp_label, horizontalalignment='left', fontsize=fontsize) 121 | 122 | ax.set_xlim([-1, len(x_list)]) 123 | ax.set_xticks(range(len(x_list))) 124 | if row == 0: 125 | ax.set_xticklabels(x_list, rotation=-45, fontsize=fontsize) 126 | else: 127 | ax.set_xticklabels([]) 128 | 129 | ax.set_ylim([-0.4, 1.0]) # FXIME: auto-scaling 130 | ax.tick_params(axis='y', labelsize=fontsize) 131 | box_off(ax) 132 | 133 | plt.tight_layout() 134 | 135 | # X Label 136 | if x_label is not None: 137 | 
ax = fig.add_axes([0, 0, 1, 1]) 138 | ax.text(0.5, 0, x_label, 139 | verticalalignment='center', horizontalalignment='center', fontsize=fontsize) 140 | ax.patch.set_alpha(0.0) 141 | ax.set_axis_off() 142 | 143 | # Y label 144 | if y_label is not None: 145 | ax = fig.add_axes([0, 0, 1, 1]) 146 | ax.text(0, 0.5, y_label, 147 | verticalalignment='center', horizontalalignment='center', fontsize=fontsize, rotation=90) 148 | ax.patch.set_alpha(0.0) 149 | ax.set_axis_off() 150 | 151 | # Figure title 152 | if title is not None: 153 | ax = fig.add_axes([0, 0, 1, 1]) 154 | ax.text(0.5, 0.99, '{}: {}'.format(title, fig_label), 155 | horizontalalignment='center', fontsize=fontsize) 156 | ax.patch.set_alpha(0.0) 157 | ax.set_axis_off() 158 | 159 | figs.append(fig) 160 | 161 | if len(figs) == 1: 162 | return figs[0] 163 | else: 164 | return figs 165 | -------------------------------------------------------------------------------- /bdpy/ml/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | BdPy machine learning package 3 | 4 | This package is a part of BdPy 5 | """ 6 | 7 | 8 | from .learning import Classification, CrossValidation, ModelTraining, ModelTest 9 | from .crossvalidation import make_cvindex, make_crossvalidationindex, make_cvindex_generator 10 | from .crossvalidation import cvindex_groupwise 11 | from .ensemble import * 12 | from .regress import * 13 | from .searchlight import * 14 | from .model import EnsembleClassifier 15 | -------------------------------------------------------------------------------- /bdpy/ml/ensemble.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for ensemble learning 3 | """ 4 | 5 | from collections import Counter 6 | 7 | import numpy as np 8 | 9 | 10 | __all__ = ['get_majority'] 11 | 12 | 13 | def get_majority(data, axis=0): 14 | """ 15 | Returns a list of majority elements in each row or column. 
16 | 17 | If more than two elements occupies the same numbers in each row or column, 18 | 'get_majority' returns the first-sorted element. 19 | 20 | Parameters 21 | ---------- 22 | data : array_like 23 | axis : 0 or 1, optional 24 | Axis in which elements are counted (default: 0) 25 | 26 | 27 | Returns 28 | ------- 29 | majority_list : list 30 | A list of majority elements 31 | """ 32 | 33 | majority_list = [] 34 | 35 | if axis == 0: 36 | data = np.transpose(data) 37 | 38 | for i in range(data.shape[0]): 39 | target = data[i].tolist() 40 | # Change KS for returning first element if the same numbers 41 | #c = Counter(target) 42 | c = Counter(np.sort(target)) 43 | majority = c.most_common(1) 44 | majority_list.append(majority[0][0]) 45 | 46 | return majority_list 47 | -------------------------------------------------------------------------------- /bdpy/ml/regress.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is a part of BdPy 3 | """ 4 | 5 | 6 | __all__ = ['add_bias'] 7 | 8 | 9 | import numpy as np 10 | 11 | 12 | def add_bias(x, axis=0): 13 | """ 14 | Add bias terms to x 15 | 16 | Parameters 17 | ---------- 18 | x : array_like 19 | Data matrix 20 | axis : 0 or 1, optional 21 | Axis in which bias terms are added (default: 0) 22 | 23 | Returns 24 | ------- 25 | y : array_like 26 | Data matrix with bias terms 27 | """ 28 | 29 | if axis == 0: 30 | vlen = x.shape[1] 31 | y = np.concatenate((x, np.array([np.ones(vlen)])), axis=0) 32 | elif axis == 1: 33 | vlen = x.shape[0] 34 | y = np.concatenate((x, np.array([np.ones(vlen)]).T), axis=1) 35 | else: 36 | raise ValueError('axis should be either 0 or 1') 37 | 38 | return y 39 | -------------------------------------------------------------------------------- /bdpy/ml/searchlight.py: -------------------------------------------------------------------------------- 1 | '''Utilities for searchlight analysis.''' 2 | 3 | 4 | __all__ = ['get_neighbors'] 5 | 6 | 7 | import 
numpy as np 8 | 9 | 10 | def get_neighbors(xyz, space_xyz, shape='sphere', size=9): 11 | ''' 12 | Returns neighboring voxels (cluster). 13 | 14 | Parameters 15 | ---------- 16 | xyz : array_like, shape=(3,) or len=3 17 | Voxel XYZ coordinate in the center of the cluster. 18 | space_xyz : array_like, shape=(3, N) or (N, 3) 19 | XYZ coordinate of all voxels. 20 | shape : {'sphere'}, optional 21 | Shape of the cluster. 22 | size : float, optional 23 | Size of the cluster. 24 | 25 | Returns 26 | ------- 27 | cluster_index : array_like, dtype=bool, shape=(N,) 28 | Boolean index of voxels in the cluster. 29 | ''' 30 | 31 | # Input check 32 | if isinstance(xyz, list): 33 | xyz = np.array(xyz) 34 | 35 | if xyz.ndim != 1: 36 | raise TypeError('xyz should be 1-D array') 37 | 38 | if space_xyz.ndim != 2: 39 | raise TypeError('space_xyz should be 2-D array') 40 | 41 | # Fix input shape 42 | if space_xyz.shape[0] == 3: 43 | space_xyz = space_xyz.T 44 | 45 | if shape == 'sphere': 46 | dist = np.sum((space_xyz - xyz) ** 2, axis=1) 47 | cluster_index = dist <= size ** 2 48 | else: 49 | raise ValueError('Unknown shape: %s' % shape) 50 | 51 | return cluster_index 52 | -------------------------------------------------------------------------------- /bdpy/mri/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | BdPy MRI package 3 | 4 | This package is a part of BdPy 5 | """ 6 | 7 | from .load_epi import load_epi 8 | from .load_mri import load_mri 9 | from .roi import add_roimask, get_roiflag, add_roilabel, add_rois, merge_rois, add_hcp_rois, add_hcp_visual_cortex 10 | from .fmriprep import create_bdata_fmriprep, FmriprepData 11 | from .spm import create_bdata_spm_domestic 12 | from .image import export_brain_image 13 | from .glm import make_paradigm 14 | -------------------------------------------------------------------------------- /bdpy/mri/glm.py: -------------------------------------------------------------------------------- 1 
| '''bdpy.mri.glm''' 2 | 3 | 4 | import csv 5 | 6 | import numpy as np 7 | from nipy.modalities.fmri.experimental_paradigm import BlockParadigm, EventRelatedParadigm 8 | 9 | 10 | def make_paradigm(event_files, num_vols, tr=2., cond_col=2, label_col=None, regressors=None, ignore_col=None, ignore_value=[], trial_wise=False, design='block'): 11 | ''' 12 | Make paradigm for GLM with Nipy from BIDS task event files. 13 | 14 | Parameters 15 | ---------- 16 | event_files : list 17 | List of task event files. 18 | num_vols : list 19 | List of the number of volumes in each run. 20 | tr : int or float 21 | TR in sec. 22 | cond_col : int 23 | Index of the condition column in the task event files. 24 | label_col : int 25 | Index of the label column in the task event files. 26 | regressors : list 27 | Names of regressors (conditions) included in the design matrix. 28 | ignore_col : int 29 | Index of the column to be ingored. 30 | ignore_value : list 31 | List of values to be ignored. 32 | design : 'block' or 'event_related 33 | Specifying experimental design. 34 | trial_wise : bool 35 | Returns trial-wise design matrix if True. 
36 | 37 | Returns 38 | ------- 39 | dict 40 | paradigm : nipy.Paradigm 41 | condition_labels : labels for task regressors 42 | run_regressors : nuisance regressors for runs 43 | run_regressors_label : labels for the run regressors 44 | ''' 45 | 46 | onset = [] 47 | duration = [] 48 | conds = [] 49 | labels = [] 50 | 51 | # Run regressors 52 | run_regs = [] 53 | run_regs_labels = [] 54 | 55 | n_total_vols = np.sum(num_vols) 56 | 57 | trial_count = 0 58 | 59 | # Combining all runs/sessions into a single design matrix 60 | for i, (ef, nv) in enumerate(zip(event_files, num_vols)): 61 | n_run = i + 1 62 | 63 | with open(ef, 'r') as f: 64 | reader = csv.reader(f, delimiter='\t') 65 | header = reader.next() 66 | for row in reader: 67 | if not regressors is None and not row[cond_col] in regressors: 68 | continue 69 | if not ignore_col is None: 70 | if row[ignore_col] in ignore_value: 71 | continue 72 | trial_count += 1 73 | onset.append(float(row[0]) + i * nv * tr) 74 | duration.append(float(row[1])) 75 | if trial_wise: 76 | conds.append('trial-%06d' % trial_count) 77 | else: 78 | conds.append(row[cond_col]) 79 | if not label_col is None: 80 | labels.append(row[label_col]) 81 | 82 | # Run regressors 83 | run_reg = np.zeros((n_total_vols, 1)) 84 | run_reg[i * nv:(i + 1) * nv] = 1 85 | run_regs.append(run_reg) 86 | run_regs_labels.append('run-%02d' % n_run) 87 | 88 | run_regs = np.hstack(run_regs) 89 | 90 | if design == 'event_related': 91 | paradigm = EventRelatedParadigm(con_id=conds, onset=onset) 92 | else: 93 | paradigm = BlockParadigm(con_id=conds, onset=onset, duration=duration) 94 | 95 | return {'paradigm': paradigm, 96 | 'run_regressors': run_regs, 97 | 'run_regressor_labels': run_regs_labels, 98 | 'condition_labels': labels} 99 | -------------------------------------------------------------------------------- /bdpy/mri/image.py: -------------------------------------------------------------------------------- 1 | '''bdpy.mri.image''' 2 | 3 | 4 | from itertools import 
product 5 | 6 | import numpy as np 7 | import nibabel 8 | 9 | from bdpy.mri import load_mri 10 | 11 | 12 | def export_brain_image(brain_data, template, xyz=None, out_file=None): 13 | '''Export a brain data array as a brain image. 14 | 15 | Parameters 16 | ---------- 17 | brain_data : array 18 | Brain data array, shape = (n_sample, n_voxels) 19 | template : str 20 | Path to a template brain image file 21 | xyz : array, optional 22 | Voxel xyz coordinates of the brain data array 23 | 24 | Returns 25 | ------- 26 | nibabel.Nifti1Image 27 | ''' 28 | 29 | if brain_data.ndim == 1: 30 | brain_data = brain_data[np.newaxis, :] 31 | 32 | if brain_data.shape[0] > 1: 33 | raise RuntimeError('4-D image is not supported yet.') 34 | 35 | template_image = nibabel.load(template) 36 | _, brain_xyz, _ = load_mri(template) 37 | 38 | out_table = {} 39 | if xyz is None: 40 | xyz = brain_xyz 41 | 42 | for i in range(brain_data.shape[1]): 43 | x, y, z = xyz[0, i], xyz[1, i], xyz[2, i] 44 | out_table.update({(x, y, z): brain_data[0, i]}) 45 | 46 | out_image_array = np.zeros(template_image.shape[:3]) 47 | for i, j, k in product(range(template_image.shape[0]), range(template_image.shape[1]), range(template_image.shape[2])): 48 | x, y, z = template_image.affine[:3, :3].dot([i, j, k]) + template_image.affine[:3, 3] 49 | if (x, y, z) in out_table: 50 | out_image_array[i, j, k] = out_table[(x, y, z)] 51 | 52 | out_image = nibabel.Nifti1Image(out_image_array, template_image.affine) 53 | 54 | return out_image 55 | -------------------------------------------------------------------------------- /bdpy/mri/load_epi.py: -------------------------------------------------------------------------------- 1 | '''Loading EPIs module. 2 | 3 | This file is a part of BdPy. 4 | ''' 5 | 6 | 7 | import itertools as itr 8 | import os 9 | import re 10 | import string 11 | 12 | import nipy 13 | import numpy as np 14 | import scipy.io as sio 15 | 16 | 17 | def load_epi(datafiles): 18 | '''Load EPI files. 
19 | 20 | The returned data and xyz are flattened by C-like order. 21 | 22 | Parameters 23 | ---------- 24 | datafiles: list 25 | EPI image files. 26 | 27 | Returns 28 | ------- 29 | data: array_like, shape = (M, N) 30 | Voxel signal values (M: the number of samples, N: the nubmer of 31 | voxels). 32 | xyz_array: array_like, shape = (3, N) 33 | XYZ Coordiantes of voxels. 34 | ''' 35 | 36 | data_list = [] 37 | xyz = np.array([]) 38 | 39 | for df in datafiles: 40 | print("Loading %s" % df) 41 | 42 | # Load an EPI image 43 | img = nipy.load_image(df) 44 | 45 | xyz = _check_xyz(xyz, img) 46 | data_list.append(np.array(img.get_data().flatten(), dtype=np.float64)) 47 | 48 | data = np.vstack(data_list) 49 | 50 | return data, xyz 51 | 52 | 53 | def _check_xyz(xyz, img): 54 | '''Check voxel xyz consistency.''' 55 | 56 | xyz_current = _get_xyz(img.coordmap.affine, img.get_data().shape) 57 | 58 | if xyz.size == 0: 59 | xyz = xyz_current 60 | elif (xyz != xyz_current).any(): 61 | raise ValueError("Voxel XYZ coordinates are inconsistent across volumes") 62 | 63 | return xyz 64 | 65 | 66 | def _get_xyz(affine, volume_shape): 67 | '''Return voxel XYZ coordinates based on an affine matrix. 68 | 69 | Parameters 70 | ---------- 71 | affine : array 72 | Affine matrix. 73 | volume_shape : list 74 | Shape of the volume (i, j, k lnegth). 75 | 76 | Returns 77 | ------- 78 | array, shape = (3, N) 79 | x-, y-, and z-coordinates (N: the number of voxels). 80 | ''' 81 | 82 | i_len, j_len, k_len = volume_shape 83 | ijk = np.array(list(itr.product(range(i_len), 84 | range(j_len), 85 | range(k_len), 86 | [1]))).T 87 | 88 | return np.dot(affine, ijk)[:-1] 89 | -------------------------------------------------------------------------------- /bdpy/mri/load_mri.py: -------------------------------------------------------------------------------- 1 | '''load_mri''' 2 | 3 | 4 | import numpy as np 5 | import nipy 6 | 7 | 8 | def load_mri(fpath): 9 | '''Load a MRI image. 
10 | 11 | - Returns data as 2D array (sample x voxel) 12 | - Returns voxle xyz coordinates (3 x voxel) 13 | - Returns voxel ijk indexes (3 x voxel) 14 | - Data, xyz, and ijk are flattened by Fortran-like index order 15 | ''' 16 | img = nipy.load_image(fpath) 17 | 18 | data = img.get_data() 19 | if data.ndim == 4: 20 | data = data.reshape(-1, data.shape[-1], order='F').T 21 | i_len, j_len, k_len, t = img.shape 22 | affine = np.delete(np.delete(img.coordmap.affine, 3, axis=0), 3, axis=1) 23 | elif data.ndim == 3: 24 | data = data.flatten(order='F') 25 | i_len, j_len, k_len = img.shape 26 | affine = img.coordmap.affine 27 | else: 28 | raise ValueError('Invalid shape.') 29 | 30 | ijk = np.array(np.unravel_index(np.arange(i_len * j_len * k_len), 31 | (i_len, j_len, k_len), order='F')) 32 | ijk_b = np.vstack([ijk, np.ones((1, i_len * j_len * k_len))]) 33 | xyz_b = np.dot(affine, ijk_b) 34 | xyz = xyz_b[:-1] 35 | 36 | return data, xyz, ijk 37 | -------------------------------------------------------------------------------- /bdpy/opendata/__init__.py: -------------------------------------------------------------------------------- 1 | from .openneuro import makedata 2 | -------------------------------------------------------------------------------- /bdpy/pipeline/config.py: -------------------------------------------------------------------------------- 1 | """Config management.""" 2 | 3 | 4 | import argparse 5 | from pathlib import Path 6 | import inspect 7 | from datetime import datetime, timezone 8 | 9 | from hydra.experimental import initialize_config_dir, compose 10 | from omegaconf import OmegaConf, DictConfig 11 | 12 | 13 | def init_hydra_cfg() -> DictConfig: 14 | """Initialize Hydra config.""" 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('config', type=str, default=None, help='configuration file') 17 | parser.add_argument('-o', '--override', type=str, nargs='+', default=[], help='configuration override(s)') 18 | parser.add_argument('-a', 
'--analysis', type=str, nargs='?', default=None, help='analysis name (default: script file name)') 19 | args = parser.parse_args() 20 | 21 | config_file = args.config 22 | config_file = Path(config_file) 23 | 24 | override = args.override 25 | 26 | config_name = config_file.stem 27 | config_dir = config_file.absolute().parent 28 | 29 | # Called by 30 | stack = inspect.stack() 31 | if len(stack) >= 2: 32 | frame = stack[1] 33 | called_by = Path(frame.filename) 34 | called_by_name = called_by.stem 35 | called_by_path = str(called_by.absolute()) 36 | else: 37 | called_by_name = 'undef' 38 | called_by_path = 'undef' 39 | 40 | initialize_config_dir(config_dir=str(config_dir), job_name=str(called_by_name)) 41 | cfg = compose(config_name=config_name, overrides=override) 42 | 43 | now = datetime.now(timezone.utc).astimezone() 44 | date_str = now.strftime('%Y-%m-%d %H:%M:%S.%f %Z%z') 45 | 46 | # DictConfig of struct mode doesn't accept insertion of new keys. 47 | # Dirty solution 48 | OmegaConf.set_struct(cfg, False) 49 | cfg._run_ = { 50 | "name": called_by_name, 51 | "path": called_by_path, 52 | "timestamp": date_str, 53 | "config_name": config_name, 54 | "config_path": str(config_file.absolute()), 55 | } 56 | 57 | # code = cfg.get("code", None) 58 | cfg._analysis_name_ = args.analysis 59 | if cfg._analysis_name_ is None: 60 | cfg._analysis_name_ = called_by_name 61 | 62 | OmegaConf.set_struct(cfg, True) 63 | 64 | return cfg 65 | -------------------------------------------------------------------------------- /bdpy/preproc/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | BdPy preprocessing package 3 | 4 | This package is a part of BdPy 5 | """ 6 | 7 | 8 | from .interface import * 9 | from .select_top import * 10 | from .preprocessor import Preprocessor 11 | -------------------------------------------------------------------------------- /bdpy/preproc/interface.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Interface functions for preprocessing 3 | 4 | This file is a part of BdPy 5 | """ 6 | 7 | 8 | from .preprocessor import Average,Detrender,Normalize,Regressout,ReduceOutlier,ShiftSample 9 | from .util import print_start_msg, print_finish_msg 10 | 11 | 12 | def average_sample(x, group=[], verbose=True): 13 | """ 14 | Average samples within groups 15 | 16 | Parameters 17 | ---------- 18 | x : array 19 | Input data array (sample num * feature num) 20 | group : array_like 21 | Group vector (length = sample num) 22 | 23 | Returns 24 | ------- 25 | y : array 26 | Averaged data array (group num * feature num) 27 | index_map : array_like 28 | Vector mapping row indexes from y to x (length = group num) 29 | """ 30 | 31 | if verbose: 32 | print_start_msg() 33 | 34 | p = Average() 35 | y, ind_map = p.run(x, group) 36 | 37 | if verbose: 38 | print_finish_msg() 39 | 40 | return y, ind_map 41 | 42 | 43 | def detrend_sample(x, group=[], keep_mean=True, verbose=True): 44 | """ 45 | Apply linear detrend 46 | 47 | Parameters 48 | ---------- 49 | x : array 50 | Input data array (sample num * feature num) 51 | group : array_like 52 | Group vector (length = sample num) 53 | 54 | Returns 55 | ------- 56 | y : array 57 | Detrended data array (group num * feature num) 58 | """ 59 | 60 | if verbose: 61 | print_start_msg() 62 | 63 | p = Detrender() 64 | y, _ = p.run(x, group, keep_mean=keep_mean) 65 | 66 | if verbose: 67 | print_finish_msg() 68 | 69 | return y 70 | 71 | 72 | def normalize_sample(x, group=[], mode='PercentSignalChange', baseline='All', 73 | zero_threshold=1, verbose=True): 74 | """ 75 | Apply normalization 76 | 77 | Parameters 78 | ---------- 79 | x : array 80 | Input data array (sample num * feature num) 81 | group : array_like 82 | Group vector (length = sample num) 83 | Mode : str 84 | Normalization mode ('PercentSignalChange', 'Zscore', 'DivideMean', or 'SubtractMean'; default = 
'PercentSignalChange') 85 | Baseline : array_like or str 86 | Baseline index vector (default: 'All') 87 | ZeroThreshold : float 88 | Zero threshold (default: 1) 89 | 90 | Returns 91 | ------- 92 | y : array 93 | Normalized data array (sample num * feature num) 94 | """ 95 | 96 | if verbose: 97 | print_start_msg() 98 | 99 | p = Normalize() 100 | y, _ = p.run(x, group, mode = mode, baseline = baseline, zero_threshold = zero_threshold) 101 | 102 | if verbose: 103 | print_finish_msg() 104 | 105 | return y 106 | 107 | 108 | def reduce_outlier(x, group=[], std=True, maxmin=True, remove=False, dimension=1, n_iter=10, std_threshold=3, max_value=None, min_value=None, verbose=True): 109 | '''Outlier reduction.''' 110 | 111 | if verbose: 112 | print_start_msg() 113 | 114 | if remove: 115 | raise NotImplementedError('"remove" option is not implemented yet.') 116 | 117 | p = ReduceOutlier() 118 | y, _ = p.run(x, group, std=std, maxmin=maxmin, dimension=dimension, n_iter=n_iter, std_threshold=std_threshold, max_value=max_value, min_value=min_value) 119 | 120 | if verbose: 121 | print_finish_msg() 122 | 123 | return y 124 | 125 | 126 | def regressout(x, group=[], regressor=[], remove_dc=True, linear_detrend=True, verbose=True): 127 | '''Remove nuisance regressors. 128 | 129 | Parameters 130 | ---------- 131 | x : array, shape = (n_sample, n_feature) 132 | Input data array 133 | group : array_like, lenght = n_sample 134 | Group vector. 135 | regressor : array_like, shape = (n_sample, n_regressor) 136 | Nuisance regressors. 137 | remove_dc : bool 138 | Remove DC component (signal mean) or not (default: True). 139 | linear_detrend : bool 140 | Remove linear trend or not (default: True). 141 | 142 | Returns 143 | ------- 144 | y : array, shape = (n_sample, n_feature) 145 | Signal without nuisance regressors. 
146 | ''' 147 | 148 | if verbose: 149 | print_start_msg() 150 | 151 | p = Regressout() 152 | y, _ = p.run(x, group, regressor=regressor, remove_dc=remove_dc, linear_detrend=linear_detrend) 153 | 154 | if verbose: 155 | print_finish_msg() 156 | 157 | return y 158 | 159 | 160 | def shift_sample(x, group=[], shift_size = 1, verbose = True): 161 | """ 162 | Shift sample within groups 163 | 164 | Parameters 165 | ---------- 166 | x : array 167 | Input data (sample num * feature num) 168 | group : array_like 169 | Group vector (length: sample num) 170 | shift_size : int 171 | Shift size (default: 1) 172 | 173 | Returns 174 | ------- 175 | y : array 176 | Averaged data array (group num * feature num) 177 | index_map : array_like 178 | Vector mapping row indexes from y to x (length: group num) 179 | 180 | Example 181 | ------- 182 | 183 | import numpy as np 184 | from bdpy.preprocessor import shift_sample 185 | 186 | x = np.array([[ 1, 2, 3 ], 187 | [ 11, 12, 13 ], 188 | [ 21, 22, 23 ], 189 | [ 31, 32, 33 ], 190 | [ 41, 42, 43 ], 191 | [ 51, 52, 53 ]]) 192 | grp = np.array([ 1, 1, 1, 2, 2, 2 ]) 193 | 194 | shift_size = 1 195 | 196 | y, index = shift_sample(x, grp, shift_size) 197 | 198 | # >>> y 199 | # array([[11, 12, 13], 200 | # [21, 22, 23], 201 | # [41, 42, 43], 202 | # [51, 52, 53]]) 203 | 204 | # >>> index 205 | # array([0, 1, 3, 4]) 206 | """ 207 | 208 | if verbose: 209 | print_start_msg() 210 | 211 | p = ShiftSample() 212 | y, index_map = p.run(x, group, shift_size = shift_size) 213 | 214 | if verbose: 215 | print_finish_msg() 216 | 217 | return y, index_map 218 | -------------------------------------------------------------------------------- /bdpy/preproc/select_top.py: -------------------------------------------------------------------------------- 1 | """select_top. 2 | 3 | This file is a part of BdPy. 
4 | """ 5 | 6 | 7 | __all__ = ['select_top'] 8 | 9 | 10 | from typing import Tuple, Optional 11 | import numpy as np 12 | from .util import print_start_msg, print_finish_msg 13 | 14 | 15 | def select_top(data: np.ndarray, value: np.ndarray, num: int, axis: Optional[int] = 0, verbose: Optional[bool] = True) -> Tuple[np.ndarray, np.ndarray]: 16 | """Select top `num` features of `value` from `data`. 17 | 18 | Parameters 19 | ---------- 20 | data : array 21 | Data matrix 22 | value : array_like 23 | Vector of values 24 | num : int 25 | Number of selected features 26 | 27 | Returns 28 | ------- 29 | selected_data : array 30 | Selected data matrix 31 | selected_index : array 32 | Index of selected data 33 | 34 | """ 35 | if verbose: 36 | print_start_msg() 37 | 38 | num_elem = data.shape[axis] 39 | 40 | value = np.array([-np.inf if np.isnan(a) else a for a in value]) 41 | sorted_index = np.argsort(value)[::-1] 42 | 43 | rank = np.zeros(num_elem, dtype=int) 44 | rank[sorted_index] = np.array(range(0, num_elem)) 45 | 46 | selected_index_bool = rank < num 47 | 48 | if axis == 0: 49 | selected_data = data[selected_index_bool, :] 50 | selected_index = np.array(range(0, num_elem), dtype=int)[selected_index_bool] 51 | elif axis == 1: 52 | selected_data = data[:, selected_index_bool] 53 | selected_index = np.array(range(0, num_elem), dtype=int)[selected_index_bool] 54 | else: 55 | raise ValueError('Invalid axis') 56 | 57 | if verbose: 58 | print_finish_msg() 59 | 60 | return selected_data, selected_index 61 | -------------------------------------------------------------------------------- /bdpy/preproc/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for preprocessing 3 | """ 4 | 5 | 6 | import inspect 7 | from datetime import datetime 8 | 9 | 10 | def print_start_msg(): 11 | """ 12 | Print process starting message 13 | """ 14 | print("%s Running %s" % (datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 15 | 
inspect.currentframe().f_back.f_code.co_name)) 16 | 17 | 18 | def print_finish_msg(): 19 | """ 20 | Print process finishing message 21 | """ 22 | print("%s DONE" % datetime.now().strftime('%Y-%m-%d %H:%M:%S')) 23 | -------------------------------------------------------------------------------- /bdpy/recon/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /bdpy/recon/torch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | -------------------------------------------------------------------------------- /bdpy/recon/torch/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import build_encoder, BaseEncoder 2 | from .generator import build_generator, BaseGenerator 3 | from .latent import ArbitraryLatent, BaseLatent 4 | from .critic import TargetNormalizedMSE, BaseCritic 5 | from .optimizer import build_optimizer_factory, build_scheduler_factory -------------------------------------------------------------------------------- /bdpy/recon/torch/modules/encoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Iterable 5 | 6 | import torch 7 | import torch.nn as nn 8 | from bdpy.dl.torch import FeatureExtractor 9 | from bdpy.dl.torch.domain import Domain, InternalDomain 10 | 11 | 12 | class BaseEncoder(ABC): 13 | """Encoder network module.""" 14 | 15 | @abstractmethod 16 | def encode(self, images: torch.Tensor) -> dict[str, torch.Tensor]: 17 | """Encode images as a hierarchical feature representation. 18 | 19 | Parameters 20 | ---------- 21 | images : torch.Tensor 22 | Images. 
23 | 24 | Returns 25 | ------- 26 | dict[str, torch.Tensor] 27 | Features indexed by the layer names. 28 | """ 29 | pass 30 | 31 | def __call__(self, images: torch.Tensor) -> dict[str, torch.Tensor]: 32 | """Call self.encode. 33 | 34 | Parameters 35 | ---------- 36 | images : torch.Tensor 37 | Images on the libraries internal domain. 38 | 39 | Returns 40 | ------- 41 | dict[str, torch.Tensor] 42 | Features indexed by the layer names. 43 | """ 44 | return self.encode(images) 45 | 46 | 47 | class NNModuleEncoder(BaseEncoder, nn.Module): 48 | """Encoder network module subclassed from nn.Module.""" 49 | 50 | def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]: 51 | """Call self.encode. 52 | 53 | Parameters 54 | ---------- 55 | images : torch.Tensor 56 | Images on the library's internal domain. 57 | 58 | Returns 59 | ------- 60 | dict[str, torch.Tensor] 61 | Features indexed by the layer names. 62 | """ 63 | return self.encode(images) 64 | 65 | def __call__(self, images: torch.Tensor) -> dict[str, torch.Tensor]: 66 | return nn.Module.__call__(self, images) 67 | 68 | 69 | class SimpleEncoder(NNModuleEncoder): 70 | """Encoder network module with a naive feature extractor. 71 | 72 | Parameters 73 | ---------- 74 | feature_network : nn.Module 75 | Feature network. This network should have a method `forward` that takes 76 | an image tensor and propagates it through the network. 77 | layer_names : list[str] 78 | Layer names to extract features from. 79 | domain : Domain, optional 80 | Domain of the input stimuli to receive. (default: InternalDomain()) 81 | 82 | Examples 83 | -------- 84 | >>> import torch 85 | >>> import torch.nn as nn 86 | >>> from bdpy.recon.torch.modules.encoder import SimpleEncoder 87 | >>> feature_network = nn.Sequential( 88 | ... nn.Conv2d(3, 3, 3), 89 | ... nn.ReLU(), 90 | ... 
) 91 | >>> encoder = SimpleEncoder(feature_network, ['[0]']) 92 | >>> image = torch.randn(1, 3, 64, 64) 93 | >>> features = encoder(image) 94 | >>> features['[0]'].shape 95 | torch.Size([1, 3, 62, 62]) 96 | """ 97 | 98 | def __init__( 99 | self, 100 | feature_network: nn.Module, 101 | layer_names: Iterable[str], 102 | domain: Domain = InternalDomain(), 103 | ) -> None: 104 | super().__init__() 105 | self._feature_extractor = FeatureExtractor( 106 | encoder=feature_network, layers=layer_names, detach=False, device=None 107 | ) 108 | self._domain = domain 109 | self._feature_network = self._feature_extractor._encoder 110 | 111 | def encode(self, images: torch.Tensor) -> dict[str, torch.Tensor]: 112 | """Encode images as a hierarchical feature representation. 113 | 114 | Parameters 115 | ---------- 116 | images : torch.Tensor 117 | Images on the libraries internal domain. 118 | 119 | Returns 120 | ------- 121 | dict[str, torch.Tensor] 122 | Features indexed by the layer names. 123 | """ 124 | images = self._domain.receive(images) 125 | return self._feature_extractor(images) 126 | 127 | 128 | def build_encoder( 129 | feature_network: nn.Module, 130 | layer_names: Iterable[str], 131 | domain: Domain = InternalDomain(), 132 | ) -> BaseEncoder: 133 | """Build an encoder network with a naive feature extractor. 134 | 135 | The function builds an encoder module from a feature network that takes 136 | images on its own domain as input and processes them. The encoder module 137 | receives images on the library's internal domain and returns features on the 138 | library's internal domain indexed by `layer_names`. `domain` is used to 139 | convert the input images to the feature network's domain from the library's 140 | internal domain. 141 | 142 | Parameters 143 | ---------- 144 | feature_network : nn.Module 145 | Feature network. This network should have a method `forward` that takes 146 | an image tensor and propagates it through the network. 
The images should 147 | be on the network's own domain. 148 | layer_names : list[str] 149 | Layer names to extract features from. 150 | domain : Domain, optional 151 | Domain of the input stimuli to receive (default: InternalDomain()). 152 | One needs to specify the domain that corresponds to the feature network's 153 | input domain. 154 | 155 | Returns 156 | ------- 157 | BaseEncoder 158 | Encoder network. 159 | 160 | Examples 161 | -------- 162 | >>> import torch 163 | >>> import torch.nn as nn 164 | >>> from bdpy.recon.torch.modules.encoder import build_encoder 165 | >>> feature_network = nn.Sequential( 166 | ... nn.Conv2d(3, 3, 3), 167 | ... nn.ReLU(), 168 | ... ) 169 | >>> encoder = build_encoder(feature_network, layer_names=['[0]']) 170 | >>> image = torch.randn(1, 3, 64, 64) 171 | >>> features = encoder(image) 172 | >>> features['[0]'].shape 173 | torch.Size([1, 3, 62, 62]) 174 | """ 175 | return SimpleEncoder(feature_network, layer_names, domain) 176 | -------------------------------------------------------------------------------- /bdpy/recon/torch/modules/latent.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Callable, Iterator 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class BaseLatent(ABC): 11 | """Latent variable module.""" 12 | 13 | @abstractmethod 14 | def reset_states(self) -> None: 15 | """Reset the state of the latent variable.""" 16 | pass 17 | 18 | @abstractmethod 19 | def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: 20 | """Return the parameters of the latent variable.""" 21 | pass 22 | 23 | @abstractmethod 24 | def generate(self) -> torch.Tensor: 25 | """Generate a latent variable. 26 | 27 | Returns 28 | ------- 29 | torch.Tensor 30 | Latent variable. 31 | """ 32 | pass 33 | 34 | def __call__(self) -> torch.Tensor: 35 | """Call self.generate. 
36 | 37 | Returns 38 | ------- 39 | torch.Tensor 40 | Latent variable. 41 | """ 42 | return self.generate() 43 | 44 | 45 | class NNModuleLatent(BaseLatent, nn.Module): 46 | """Latent variable module uses __call__ method and parameters method of nn.Module.""" 47 | 48 | def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: 49 | return nn.Module.parameters(self, recurse=recurse) 50 | 51 | def __call__(self) -> torch.Tensor: 52 | return nn.Module.__call__(self) 53 | 54 | def forward(self) -> torch.Tensor: 55 | return self.generate() 56 | 57 | 58 | class ArbitraryLatent(NNModuleLatent): 59 | """Latent variable with arbitrary shape and initialization function. 60 | 61 | Parameters 62 | ---------- 63 | shape : tuple[int, ...] 64 | Shape of the latent variable including the batch dimension. 65 | init_fn : Callable[[torch.Tensor], None] 66 | Function to initialize the latent variable. 67 | 68 | Examples 69 | -------- 70 | >>> from functools import partial 71 | >>> import torch 72 | >>> import torch.nn as nn 73 | >>> from bdpy.recon.torch.modules.latent import ArbitraryLatent 74 | >>> latent = ArbitraryLatent((1, 3, 64, 64), partial(nn.init.normal_, mean=0, std=1)) 75 | >>> latent().shape 76 | torch.Size([1, 3, 64, 64]) 77 | """ 78 | 79 | def __init__(self, shape: tuple[int, ...], init_fn: Callable[[torch.Tensor], None]) -> None: 80 | super().__init__() 81 | self._shape = shape 82 | self._init_fn = init_fn 83 | self._latent = nn.Parameter(torch.empty(shape)) 84 | 85 | def reset_states(self) -> None: 86 | """Reset the state of the latent variable.""" 87 | self._init_fn(self._latent) 88 | 89 | def generate(self) -> torch.Tensor: 90 | """Generate a latent variable. 91 | 92 | Returns 93 | ------- 94 | torch.Tensor 95 | Latent variable. 
96 | """ 97 | return self._latent 98 | -------------------------------------------------------------------------------- /bdpy/recon/torch/modules/optimizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | from functools import partial 5 | from itertools import chain 6 | 7 | if TYPE_CHECKING: 8 | from typing import Dict, Any, Tuple, Union, Iterable, Callable 9 | from typing_extensions import TypeAlias 10 | from torch import Tensor 11 | import torch.optim as optim 12 | from ..modules import BaseGenerator, BaseLatent 13 | 14 | # NOTE: The definition of `_ParamsT` is the same as in `torch.optim.optimizer` 15 | # in torch>=2.2.0. We define it here for compatibility with older versions. 16 | _ParamsT: TypeAlias = Union[ 17 | Iterable[Tensor], Iterable[Dict[str, Any]], Iterable[Tuple[str, Tensor]] 18 | ] 19 | 20 | _OptimizerFactoryType: TypeAlias = Callable[ 21 | [BaseGenerator, BaseLatent], optim.Optimizer 22 | ] 23 | _SchedulerFactoryType: TypeAlias = Callable[ 24 | [optim.Optimizer], optim.lr_scheduler.LRScheduler 25 | ] 26 | _GetParamsFnType: TypeAlias = Callable[[BaseGenerator, BaseLatent], _ParamsT] 27 | 28 | 29 | def build_optimizer_factory( 30 | optimizer_class: type[optim.Optimizer], 31 | *, 32 | get_params_fn: _GetParamsFnType | None = None, 33 | **kwargs, 34 | ) -> _OptimizerFactoryType: 35 | """Build an optimizer factory. 36 | 37 | Parameters 38 | ---------- 39 | optimizer_class : type 40 | Optimizer class. 41 | get_params_fn : Callable[[BaseGenerator, BaseLatent], _ParamsT] | None 42 | Custom function to get parameters from the generator and the latent. 43 | If None, it uses `chain(generator.parameters(), latent.parameters())`. 44 | kwargs : dict 45 | Keyword arguments for the optimizer. 46 | 47 | Returns 48 | ------- 49 | Callable[[BaseGenerator, BaseLatent], optim.Optimizer] 50 | Optimizer factory. 
51 | 52 | Examples 53 | -------- 54 | >>> from torch.optim import Adam 55 | >>> from bdpy.recon.torch.modules import build_optimizer_factory 56 | >>> optimizer_factory = build_optimizer_factory(Adam, lr=1e-3) 57 | >>> optimizer = optimizer_factory(generator, latent) 58 | """ 59 | if get_params_fn is None: 60 | get_params_fn = lambda generator, latent: chain( 61 | generator.parameters(), latent.parameters() 62 | ) 63 | 64 | def init_fn(generator: BaseGenerator, latent: BaseLatent) -> optim.Optimizer: 65 | return optimizer_class(get_params_fn(generator, latent), **kwargs) 66 | 67 | return init_fn 68 | 69 | 70 | def build_scheduler_factory( 71 | scheduler_class: type[optim.lr_scheduler.LRScheduler], **kwargs 72 | ) -> _SchedulerFactoryType: 73 | """Build a scheduler factory. 74 | 75 | Parameters 76 | ---------- 77 | scheduler_class : type 78 | Scheduler class. 79 | kwargs : dict 80 | Keyword arguments for the scheduler. 81 | 82 | Returns 83 | ------- 84 | Callable[[optim.Optimizer], optim.lr_scheduler.LRScheduler] 85 | Scheduler factory. 86 | 87 | Examples 88 | -------- 89 | >>> from torch.optim.lr_scheduler import StepLR 90 | >>> from bdpy.recon.torch.modules import build_scheduler_factory 91 | >>> scheduler_factory = build_scheduler_factory(StepLR, step_size=100, gamma=0.1) 92 | >>> scheduler = scheduler_factory(optimizer) 93 | """ 94 | return partial(scheduler_class, **kwargs) 95 | -------------------------------------------------------------------------------- /bdpy/recon/torch/task/__init__.py: -------------------------------------------------------------------------------- 1 | from .inversion import FeatureInversionTask -------------------------------------------------------------------------------- /bdpy/stats/__init__.py: -------------------------------------------------------------------------------- 1 | """BdPy statistics package. 2 | 3 | This package is a part of BdPy. 
# --- bdpy/stats/__init__.py (continued; its module docstring opens in the
# previous chunk) ---
#
#
# Functions:
#
# - `corrcoef` : Returns correlation coefficient between `x` and `y`
# - `corrmat` : Returns correlation matrix between `x` and `y`
# """
#
#
# from .corr import corrcoef, corrmat


# --- bdpy/stats/corr.py ---
"""Functions dealing with correlation.

This file is a part of BdPy.
"""

__all__ = ['corrcoef', 'corrmat']


import numpy as np


def corrcoef(x: np.ndarray, y: np.ndarray, var: str = 'row') -> np.ndarray:
    """Return correlation coefficients between `x` and `y`.

    Each variable of `x` is correlated (Pearson) with the corresponding
    variable of `y`. If one input holds a single variable and the other holds
    several, the single variable is correlated with each of them.

    Parameters
    ----------
    x, y : array_like
        Matrix or vector (real-valued).
    var : str, 'row' or 'col'
        Specifying whether rows (default) or columns represent variables

    Returns
    -------
    r : ndarray, shape (n_variables,)
        Correlation coefficients

    Raises
    ------
    ValueError
        If `var` is neither 'row' nor 'col'.
    TypeError
        If `x` and `y` cannot be matched to the same shape.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Convert vectors to 2-D arrays holding a single row
    if x.ndim == 1:
        x = x[np.newaxis, :]
    if y.ndim == 1:
        y = y[np.newaxis, :]

    # Normalize x and y to row-var format (one variable per row)
    if var == 'row':
        # Vertical vector --> horizontal
        if x.shape[1] == 1:
            x = x.T
        if y.shape[1] == 1:
            y = y.T
    elif var == 'col':
        # Horizontal vector --> vertical
        if x.shape[0] == 1:
            x = x.T
        if y.shape[0] == 1:
            y = y.T

        # Convert to row-var format
        x = x.T
        y = y.T
    else:
        raise ValueError('Unknown var parameter specified')

    # Match size of x and y: a single variable is paired with every variable
    # of the other input (np.tile replaces the deprecated numpy.matlib.repmat)
    if x.shape[0] == 1 and y.shape[0] != 1:
        x = np.tile(x, (y.shape[0], 1))
    elif x.shape[0] != 1 and y.shape[0] == 1:
        y = np.tile(y, (x.shape[0], 1))

    # Check size of normalized x and y
    if x.shape != y.shape:
        raise TypeError('Input matrixes size mismatch')

    # Row-wise Pearson correlation. This equals the diagonal of the cross
    # block of np.corrcoef(x, y), but avoids materializing the
    # (2 * n_variables, 2 * n_variables) correlation matrix.
    # Promote to at least float64, as np.corrcoef does, to keep precision.
    x = x.astype(np.result_type(x, np.float64), copy=False)
    y = y.astype(np.result_type(y, np.float64), copy=False)
    xc = x - x.mean(axis=1, keepdims=True)
    yc = y - y.mean(axis=1, keepdims=True)
    r = np.sum(xc * yc, axis=1) / np.sqrt(np.sum(xc * xc, axis=1) * np.sum(yc * yc, axis=1))

    # Rounding can push |r| marginally above 1; np.corrcoef clips as well
    return np.clip(r, -1.0, 1.0)


def corrmat(x: np.ndarray, y: np.ndarray, var: str = 'row') -> np.ndarray:
    """Return correlation matrix between `x` and `y`.

    Parameters
    ----------
    x, y : array_like
        Matrix or vector. A 1-D input is treated as a single variable.
    var : str, 'row' or 'col'
        Specifying whether rows (default) or columns represent variables

    Returns
    -------
    rmat : ndarray, shape (n_variables_x, n_variables_y)
        Correlation matrix; ``rmat[i, j]`` is the correlation between the
        i-th variable of `x` and the j-th variable of `y`.

    Raises
    ------
    ValueError
        If `var` is neither 'row' nor 'col'.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Fix x and y to represent variables in each row
    if var == 'row':
        pass
    elif var == 'col':
        x = x.T
        y = y.T
    else:
        raise ValueError('Unknown var parameter specified')

    # A 1-D input holds one variable: (n,) -> (1, n)
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)

    nobs = x.shape[1]

    # Center each variable (row); keepdims replaces the np.matrix transpose
    xc = x - np.mean(x, axis=1, keepdims=True)
    yc = y - np.mean(y, axis=1, keepdims=True)

    # Sample covariance (ddof=1) divided by the outer product of sample stds
    cov = np.dot(xc, yc.T) / (nobs - 1)
    std_prod = np.outer(np.std(x, axis=1, ddof=1), np.std(y, axis=1, ddof=1))

    return cov / std_prod


# --- bdpy/task/__init__.py ---
# from .core import BaseTask


# --- bdpy/task/core.py ---
from abc import ABC, abstractmethod
from typing import Any, Generic, Iterable, Optional, TypeVar, Union

from bdpy.task.callback import BaseCallback, CallbackHandler


_CallbackType = TypeVar("_CallbackType", bound=BaseCallback)


class BaseTask(ABC, Generic[_CallbackType]):
    """Base class for tasks.

    Parameters
    ----------
    callbacks : BaseCallback | Iterable[BaseCallback] | None
        Callbacks to register. If `None`, no callbacks are registered.

    Attributes
    ----------
    _callback_handler : CallbackHandler
        Callback handler.

    Notes
    -----
    This class is designed to be used as a base class for tasks. The task
    implementation should override the `run` method; calling the task
    instance forwards all arguments to `run`. The actual interface of `run`
    depends on the task. For example, the task may take a single input and
    return a single output, or it may take multiple inputs and return
    multiple outputs. The task may also take keyword arguments. Please refer
    to the documentation of the specific task for details.
    """

    # String annotation: evaluated lazily, so no __future__ import is needed
    _callback_handler: "CallbackHandler[_CallbackType]"

    def __init__(
        self,
        callbacks: Optional[Union[_CallbackType, Iterable[_CallbackType]]] = None,
    ) -> None:
        self._callback_handler = CallbackHandler(callbacks)

    def __call__(self, *inputs: Any, **parameters: Any) -> Any:
        """Run the task (delegates to `run`)."""
        return self.run(*inputs, **parameters)

    @abstractmethod
    def run(self, *inputs: Any, **parameters: Any) -> Any:
        """Run the task."""

    def register_callback(self, callback: _CallbackType) -> None:
        """Register a callback.

        Parameters
        ----------
        callback : BaseCallback
            Callback to register.
        """
        self._callback_handler.register(callback)


# --- bdpy/util/__init__.py (continues in the next chunk) ---
# """BdPy utility package.
#
# This package is a part of BdPy.
# --- bdpy/util/__init__.py (continued; its docstring opens in the previous
# chunk) ---
# """
#
#
# from .utils import create_groupvector, divide_chunks, get_refdata, makedir_ifnot
# from .info import dump_info
# from .math import average_elemwise


# --- bdpy/util/info.py ---
"""Information module."""


from typing import Dict, Optional

import datetime
import getpass
import hashlib
import os
import platform
import time
import uuid
import warnings
import yaml

try:
    import pwd  # POSIX-only; not available on Windows
except ImportError:
    pwd = None  # type: ignore[assignment]


def _current_user() -> str:
    """Return the name of the user running this process.

    Uses the effective UID via `pwd` where available (POSIX), otherwise
    falls back to `getpass.getuser()`.
    """
    if pwd is not None:
        return pwd.getpwuid(os.geteuid())[0]
    return getpass.getuser()


def _load_info_yaml(path: str, retries: int = 10, interval: float = 0.1) -> Dict:
    """Load an existing info YAML file, retrying while it parses as empty.

    An empty YAML document loads as ``None`` (e.g. an empty file or,
    presumably, one being rewritten by another process). The read is retried
    up to `retries` times, `interval` seconds apart. After that an empty dict
    is returned with a warning, instead of looping forever.
    """
    for attempt in range(retries + 1):
        if attempt > 0:
            warnings.warn('Failed to load info from %s. Retrying...' % path, stacklevel=3)
            time.sleep(interval)
        with open(path, 'r') as f:
            info = yaml.load(f, Loader=yaml.SafeLoader)
        if info is not None:
            return info

    warnings.warn('Info file %s is empty; starting a new one.' % path, stacklevel=3)
    return {}


def dump_info(output_dir: str, script: Optional[str] = None, parameters: Optional[Dict] = None, info_file: str = 'info.yaml') -> Dict:
    """Dump runtime information.

    Collects information about the current run and merges it into
    ``<output_dir>/<info_file>``, under the ``_runtime_info`` mapping, keyed
    by a freshly generated run ID (``uuid1``). The file is created if absent.

    Parameters
    ----------
    output_dir : str
        Directory holding the info file.
    script : str, optional
        Path to the running script; its absolute path, text and MD5 hash are
        recorded.
    parameters : dict, optional
        Parameters recorded under ``parameters``. ``dict_keys`` values are
        converted to lists before being dumped.
    info_file : str
        Name of the info file inside `output_dir`.

    Returns
    -------
    dict
        The runtime information recorded for this run.
    """
    if script is not None:
        script_path = os.path.abspath(script)
        with open(script_path, 'r') as f:
            script_txt = f.read()
        script_md5 = hashlib.md5(script_txt.encode()).hexdigest()
    else:
        script_path = None
        script_txt = None
        script_md5 = None

    run_id = str(uuid.uuid1())
    run_time = time.time()
    # platform.uname() returns os.uname() fields on POSIX and also works on Windows
    uname = platform.uname()
    run_info = {
        'run_time'   : run_time,
        'time_stamp' : datetime.datetime.fromtimestamp(run_time).strftime('%Y-%m-%d %H:%M:%S'),
        'host'       : uname.node,
        'hardware'   : uname.machine,
        'os'         : uname.system,
        'os_release' : uname.release,
        'os_version' : uname.version,
        'user'       : _current_user(),
        'script_path': script_path,
        'script_txt' : script_txt,
        'script_md5' : script_md5,
    }

    if parameters is not None:
        run_info['parameters'] = {
            k: list(v) if isinstance(v, type({}.keys())) else v
            for k, v in parameters.items()
        }

    run_info_file = os.path.join(output_dir, info_file)

    if os.path.isfile(run_info_file):
        info_yaml = _load_info_yaml(run_info_file)
    else:
        info_yaml = {}

    # A key written as `_runtime_info:` with no value loads as None
    if not info_yaml.get('_runtime_info'):
        info_yaml['_runtime_info'] = {}

    info_yaml['_runtime_info'][run_id] = run_info

    with open(run_info_file, 'w') as f:
        f.write(yaml.dump(info_yaml, default_flow_style=False))

    return run_info


# --- bdpy/util/math.py ---
"""Math utils."""


from typing import List

import numpy as np


def average_elemwise(arrays: List[np.ndarray], keepdims: bool = False) -> np.ndarray:
    """Return element-wise mean of arrays.

    The arrays are summed (with NumPy broadcasting) into a float accumulator
    shaped like the first array having the most dimensions, then divided by
    the number of arrays.

    Parameters
    ----------
    arrays : list of ndarrays
        List of arrays.
    keepdims : bool
        Keep dimension in returned array or not. If False (default), singleton
        dimensions are removed with `np.squeeze`.

    Returns
    -------
    ndarray

    Raises
    ------
    ValueError
        If `arrays` is empty.
    """
    if len(arrays) == 0:
        raise ValueError('`arrays` must contain at least one array.')

    arrays = [np.asarray(a) for a in arrays]
    n_array = len(arrays)

    max_dim_i = np.argmax([a.ndim for a in arrays])
    max_array_shape = arrays[max_dim_i].shape

    arrays_sum = np.zeros(max_array_shape)

    for a in arrays:
        arrays_sum += a

    mean_array = arrays_sum / n_array

    if not keepdims:
        mean_array = np.squeeze(mean_array)

    return mean_array


# --- bdpy/util/utils.py (continues in the next chunk) ---
# """Utility functions.
#
# This file is a part of BdPy.
# --- bdpy/util/utils.py (continued; its docstring opens in the previous
# chunk) ---
# """


__all__ = ['create_groupvector',
           'divide_chunks',
           'get_refdata',
           'makedir_ifnot']


from typing import List, Union

import os
import warnings

import numpy as np


def create_groupvector(group_label: Union[List, np.ndarray], group_size: Union[int, List, np.ndarray]) -> np.ndarray:
    """Create a group vector.

    Parameters
    ----------
    group_label : array_like
        List or array of group labels.
    group_size : int or array_like
        Sample size of each group. An integer (Python or NumPy) applies the
        same size to every group.

    Returns
    -------
    group_vector : ndarray, shape = (N,)
        A vector specifying groups.

    Raises
    ------
    ValueError
        If `group_size` is a sequence whose length differs from `group_label`.
    TypeError
        If `group_size` is neither an integer, a list nor an ndarray.

    Example
    -------

    >>> bdpy.util.create_groupvector([ 1, 2, 3 ], 2)
    array([1, 1, 2, 2, 3, 3])

    >>> bdpy.util.create_groupvector([ 1, 2, 3 ], [ 2, 4, 2 ])
    array([1, 1, 2, 2, 2, 2, 3, 3])
    """
    if isinstance(group_size, (int, np.integer)):
        # Every group label is repeated 'group_size' times
        group_size_list = [int(group_size)] * len(group_label)
    elif isinstance(group_size, (list, np.ndarray)):
        if len(group_label) != len(group_size):
            raise ValueError("Length of 'group_label' and 'group_size' "
                             "is mismatched")
        group_size_list = group_size
    else:
        raise TypeError("Invalid type of 'group_size'")

    return np.repeat(np.asarray(group_label), group_size_list, axis=0)


def divide_chunks(input_list: Union[List, np.ndarray], chunk_size: int = 100) -> List:
    """Divide elements in the input list into groups.

    Parameters
    ----------
    input_list : array_like
        List or array to be divided.
    chunk_size : int
        The number of elements in each chunk (must be positive).

    Returns
    -------
    list
        List of chunks; the last chunk may be shorter.

    Raises
    ------
    ValueError
        If `chunk_size` is smaller than 1.

    Example
    -------

    >>> a = [0, 1, 2, 3, 4, 5, 6]
    >>> divide_chunks(a, chunk_size=2)
    [[0, 1], [2, 3], [4, 5], [6]]
    >>> divide_chunks(a, chunk_size=3)
    [[0, 1, 2], [3, 4, 5], [6]]
    """
    if chunk_size < 1:
        raise ValueError('chunk_size must be a positive integer (got %r)' % (chunk_size,))
    return [input_list[i:i + chunk_size]
            for i in range(0, len(input_list), chunk_size)]


def get_refdata(data: Union[List, np.ndarray], ref_key: Union[List, np.ndarray], foreign_key: Union[List, np.ndarray]) -> Union[List, np.ndarray]:
    """Get data referred by `foreign_key`.

    For each key in `foreign_key`, the first position where `ref_key` equals
    it is looked up, and the corresponding row (element for 1-D data) of
    `data` is returned.

    Parameters
    ----------
    data : array_like
        Data array.
    ref_key : array_like
        Reference keys for `data`.
    foreign_key : array_like
        Foreign keys referring `data` via `ref_key`.

    Returns
    -------
    array_like
        Referred data.

    Raises
    ------
    IndexError
        If a foreign key is not present in `ref_key`.
    """
    if not isinstance(data, np.ndarray):
        data = np.asarray(data)
    # Element-wise comparison below requires an array (a list would compare
    # as a whole and never match)
    ref_key = np.asarray(ref_key)

    ind = []
    for key in foreign_key:
        matches = np.where(ref_key == key)[0]
        if matches.size == 0:
            raise IndexError('Key %r not found in ref_key' % (key,))
        ind.append(matches[0])

    if data.ndim == 1:
        return data[ind]
    else:
        return data[ind, :]


def makedir_ifnot(dir_path: str) -> bool:
    """Make a directory if it does not exist.

    Parameters
    ----------
    dir_path : str
        Path to the directory to be created.

    Returns
    -------
    bool
        True if the directory was created by this call; False if it already
        existed or could not be created (a warning is issued in that case).
    """
    if os.path.isdir(dir_path):
        return False
    try:
        os.makedirs(dir_path)
    except OSError:
        # Another process may have created it between the check and makedirs;
        # only a real failure deserves a warning.
        if not os.path.isdir(dir_path):
            warnings.warn('Failed to create directory %s.' % dir_path, stacklevel=2)
        return False
    return True


# --- docs/_config.yml ---
# theme: jekyll-theme-cayman


# --- docs/bdata_api_examples.md (continues in the next chunk) ---
# # BData API examples
#
# ### Data API
#
# #### Import module and initialization.
#
#     from bdpy import BData
#
#     # Create an empty BData instance
#     bdata = BData()
#
#     # Load BData from a file
#     bdata = BData('data_file.h5')
#
# #### Load data
#
#     # Load BData from 'data_file.h5'
#     bdata.load('data_file.h5')
#
# #### Show data
#
#     # Show 'key' and 'description' of metadata
#     bdata.show_metadata()
#
#     # Get 'value' of the metadata specified by 'key'
#     voxel_x = bdata.get_metadata('voxel_x', where='VoxelData')
#
# #### Data extraction
#
#     # Get an array of voxel data in V1
#     data_v1 = bdata.select('ROI_V1')  # shape=(M, num voxels in V1)
#
#     # `select` accepts some operators
#     data_v1v2 = bdata.select('ROI_V1 + ROI_V2')
#     data_hvc = bdata.select('ROI_LOC + ROI_FFA + ROI_PPA - LOC_LVC')
#
#     # Wildcard
#     data_visual = bdata.select('ROI_V*')
#
#     # Get labels ('image_index') in the dataset
#     label_a = bdata.select('image_index')
#
# #### Data creation
#
#     # Add new data
#     x = numpy.random.rand(bdata.dataset.shape[0])
#     bdata.add(x, 'random_data')
#
#     # Set description of metadata
#     bdata.set_metadatadescription('random_data', 'Random data')
#
#     # Save data
#     bdata.save('output_file.h5')  # File format is selected automatically by extension. .mat, .h5, and .npy are supported.
54 | -------------------------------------------------------------------------------- /docs/dataform_features.md: -------------------------------------------------------------------------------- 1 | # Features and DecodedFeatures 2 | 3 | bdpy provides classes to handle DNN's (true) features and decoded features: `dataform.Features` and `dataform.DecodedFeatures`. 4 | 5 | ## Basic usage 6 | 7 | ``` python 8 | from bdpy.dataform import Features, DecodedFeatures 9 | 10 | 11 | ## Initialize 12 | 13 | features = Features('/path/to/features/dir') 14 | 15 | decoded_features = DecodedFeatures('/path/to/decoded/features/dir') 16 | 17 | ## Get features as an array 18 | 19 | feat = features.get(layer='conv1') 20 | 21 | decfeat = decoded_features.get(layer='conv1', subject='sub-01', roi='VC', label='stimulus-0001') # Decoded features for specified sample (label) 22 | decfeat = decoded_features.get(layer='conv1', subject='sub-01', roi='VC') # Decoded features from all available samples 23 | 24 | # Decoded features with CV 25 | decfeat = decoded_features.get(layer='conv1', subject='sub-01', roi='VC', fold='cv_fold1') 26 | 27 | ## List labels 28 | 29 | feat_labels = features.labels 30 | 31 | decfeat_labels = decoded_features.labels # All available labels 32 | decfeat_labels = decoded_features.selected_label # Labels assigned to decoded features previously obtained by `get` method 33 | ``` 34 | 35 | ## Feature statistics 36 | 37 | ``` python 38 | features.statistic('mean', layer='fc8') 39 | features.statistic('std', layer='fc8') # Default ddof = 1 40 | features.statistic('std, ddof=0', layer='fc8') 41 | 42 | decoded_features.statistic('mean', layer='fc8', subject='sub-01', roi='VC') 43 | decoded_features.statistic('std', layer='fc8', subject='sub-01', roi='VC') # Default ddof = 1 44 | decoded_features.statistic('std, ddof=0', layer='fc8', subject='sub-01', roi='VC') 45 | 46 | # Decoded features with CV 47 | decoded_features.statistic('mean', layer='fc8', subject='sub-01', roi='VC',
fold='cv_fold1') # Mean within the specified fold 48 | decoded_features.statistic('mean', layer='fc8', subject='sub-01', roi='VC') 49 | 50 | # If `fold` is omitted for CV decoded features, decoded features are pooled across all CV folds and then the statistics are calculated. 51 | 52 | ``` 53 | 54 | 55 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Bdpy: Python Package for Brain Decoding 2 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | bdpy 3 | data 4 | figures 5 | -------------------------------------------------------------------------------- /examples/data/sample_vmap.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/examples/data/sample_vmap.h5 -------------------------------------------------------------------------------- /examples/data/sample_vmap_nomap.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/examples/data/sample_vmap_nomap.h5 -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.8 3 | warn_return_any = True 4 | warn_unused_configs = True 5 | 6 | [mypy-numpy] 7 | ignore_missing_imports = True 8 | 9 | [mypy-scipy] 10 | ignore_missing_imports = True 11 | 12 | [mypy-scipy.*] 13 | ignore_missing_imports = True 14 | 15 | [mypy-pandas] 16 | ignore_missing_imports = True 17 | 18 | [mypy-h5py] 19 | ignore_missing_imports = True 20 | 21 | [mypy-hdf5storage]
22 | ignore_missing_imports = True 23 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "bdpy" 7 | version = "0.25.1" 8 | description = "Brain decoder toolbox for Python" 9 | authors = [ 10 | { name = "Shuntaro C. Aoki", email = "s_aoki@i.kyoto-u.ac.jp" } 11 | ] 12 | readme = "README.md" 13 | requires-python = ">= 3.6, < 3.12" 14 | license = { file = "LICENSE" } 15 | keywords = ["neuroscience", "neuroimaging", "brain decoding", "fmri", "machine learning"] 16 | 17 | dependencies = [ 18 | "numpy>=1.20", 19 | "scipy", 20 | "scikit-learn", 21 | "h5py", 22 | "hdf5storage", 23 | "pyyaml", 24 | "pandas", 25 | "tqdm", 26 | "typing-extensions>=4.5", 27 | ] 28 | 29 | [project.optional-dependencies] 30 | caffe = [ 31 | "Pillow", 32 | ] 33 | torch = [ 34 | "torch", 35 | "torchvision", 36 | "Pillow" 37 | ] 38 | fig = [ 39 | "matplotlib", 40 | "Pillow" 41 | ] 42 | mri = [ 43 | "numpy<1.24", 44 | "nibabel==3.2", 45 | "nipy" 46 | ] 47 | pipeline = [ 48 | "hydra-core", 49 | "omegaconf" 50 | ] 51 | all = [ 52 | "bdpy[caffe]", 53 | "bdpy[torch]", 54 | "bdpy[fig]", 55 | "bdpy[mri]", 56 | "bdpy[pipeline]" 57 | ] 58 | dev = [ 59 | "bdpy[all]", 60 | "fastl2lir" 61 | ] 62 | 63 | [project.urls] 64 | Homepage = "https://github.com/KamitaniLab/bdpy" 65 | Repository = "https://github.com/KamitaniLab/bdpy" 66 | "Bug Tracker" = "https://github.com/KamitaniLab/bdpy/issues" 67 | 68 | [tool.rye] 69 | managed = true 70 | dev-dependencies = [ 71 | "bdpy[dev]", 72 | "jupyter", 73 | "pytest", 74 | "pytest-cov", 75 | "ruff", 76 | "mypy" 77 | ] 78 | 79 | [tool.hatch.metadata] 80 | allow-direct-references = true 81 | 82 | [tool.hatch.build.targets.wheel] 83 | only-include = ["bdpy"] 84 | 85 | [tool.hatch.build.targets.sdist] 86 | only-include = ["bdpy"] 87 | 88 | 
# --- pyproject.toml (continued from the previous chunk) ---
# [tool.ruff]
# select = ["E", "F", "N", "D", "ANN", "B", "NPY", "RUF"]
# ignore = ["E501", "ANN101", "D213", "D203"]
# exclude = ["test"]
#
# [tool.ruff.extend-per-file-ignores]
# "__init__.py" = ["F401"]
#
# [tool.ruff.flake8-annotations]
# mypy-init-return = true
#
# [tool.pylint]
# disable = "line-too-long"


# --- tests/.gitignore ---
# bdpy

# --- tests/__init__.py ---
# https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/__init__.py

# --- tests/bdata/__init__.py ---
# https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/bdata/__init__.py


# --- tests/bdata/test_featureselector.py ---
"""Tests for FeatureSelector."""


import unittest

from bdpy.bdata.featureselector import FeatureSelector


class TestFeatureSelector(unittest.TestCase):
    """Tests for FeatureSelector: tokenization and conversion to RPN."""

    def _assert_tokens(self, expression, expected):
        """Check the token tuple produced for `expression`."""
        tokens = FeatureSelector(expression).token
        self.assertEqual(tokens, expected)

    def _assert_rpn(self, expression, expected):
        """Check the reverse-Polish tuple produced for `expression`."""
        rpn = FeatureSelector(expression).rpn
        self.assertEqual(rpn, expected)

    def test_lexical_analysis_0001(self):
        self._assert_tokens('A = B', ('A', '=', 'B'))

    def test_lexical_analysis_0002(self):
        self._assert_tokens('HOGE = 1', ('HOGE', '=', '1'))

    def test_lexical_analysis_0003(self):
        self._assert_tokens(
            '(HOGE = 1) & (FUGA = 0)',
            ('(', 'HOGE', '=', '1', ')',
             '&', '(', 'FUGA', '=', '0', ')'),
        )

    # Disabled test (the `top` operator):
    # def test_lexical_analysis_0004(self):
    #     self._assert_tokens('HOGE top 100', ('HOGE', 'top', '100'))

    def test_parse_0001(self):
        self._assert_rpn('A = B', ('A', 'B', '='))

    def test_parse_0002(self):
        self._assert_rpn('A = 1 | B = 0',
                         ('A', '1', '=', 'B', '0', '=', '|'))

    # Disabled test (the `top` and `@` operators):
    # def test_parse_0003(self):
    #     self._assert_rpn('HOGE top 100 @ A = 1',
    #                      ('HOGE', '100', 'top', 'A', '1', '=', '@'))

    def test_parse_0004(self):
        self._assert_rpn(
            'A = 1 & B = 2 | C = 3 & D = 4',
            ('A', '1', '=', 'B', '2', '=', '&',
             'C', '3', '=', '|', 'D', '4', '=', '&'),
        )

    def test_parse_0005(self):
        self._assert_rpn(
            'A = 1 & B = 2 | C = 3',
            ('A', '1', '=', 'B', '2', '=', '&', 'C', '3', '=', '|'),
        )

    def test_parse_0006(self):
        self._assert_rpn(
            'A = 1 & (B = 2 | C = 3)',
            ('A', '1', '=', 'B', '2', '=', 'C', '3', '=', '|', '&'),
        )


if __name__ == '__main__':
    unittest.main()


# --- tests/bdata/test_metadata.py ---
"""Tests for bdpy.bdata.metadata."""


import unittest

import numpy as np
from numpy.testing import assert_array_equal

from bdpy.bdata import metadata


class TestMetadata(unittest.TestCase):
    """Tests for bdpy.bdata.metadata."""

    @staticmethod
    def _build_ab():
        """Return a fresh MetaData holding the 15-element entries A and B."""
        md = metadata.MetaData()
        md.set('MetaData_A', [1] * 10 + [0] * 5, 'Test metadata A')
        md.set('MetaData_B', [0] * 10 + [1] * 5, 'Test metadata B')
        return md

    @staticmethod
    def _check(md, key, value, description):
        """Assert the value, then the description, stored under `key`."""
        assert_array_equal(md.get(key, 'value'), value)
        assert_array_equal(md.get(key, 'description'), description)

    def test_set_get(self):
        """Test for MetaData.set() and MetaData.get()."""
        md = self._build_ab()

        self._check(md, 'MetaData_A', [1] * 10 + [0] * 5, 'Test metadata A')
        self._check(md, 'MetaData_B', [0] * 10 + [1] * 5, 'Test metadata B')

    def test_set_get_resize(self):
        """Test for MetaData.set() and MetaData.get(); resizing values."""
        md = self._build_ab()
        md.set('MetaData_C', [0] * 15 + [1] * 3, 'Test metadata C')

        self._check(md, 'MetaData_A', [1] * 10 + [0] * 5 + [np.nan] * 3, 'Test metadata A')
        self._check(md, 'MetaData_B', [0] * 10 + [1] * 5 + [np.nan] * 3, 'Test metadata B')
        self._check(md, 'MetaData_C', [0] * 15 + [1] * 3, 'Test metadata C')

    def test_set_get_overwrite(self):
        """Test for MetaData.set() and MetaData.get(); overwriting values."""
        md = self._build_ab()
        md.set('MetaData_A', [10] * 10 + [0] * 5, 'Test metadata A')

        self._check(md, 'MetaData_A', [10] * 10 + [0] * 5, 'Test metadata A')
        self._check(md, 'MetaData_B', [0] * 10 + [1] * 5, 'Test metadata B')

    def test_set_get_overwrite_resize(self):
        """Test for MetaData.set() and MetaData.get(); overwriting and resizing values."""
        md = metadata.MetaData()
        md.set('MetaData_A', [1, 1, 1, 0, 0], 'Test metadata A')
        md.set('MetaData_B', [0, 0, 0, 1, 1], 'Test metadata B')
        md.set('MetaData_A', [2, 2, 2, 0, 0, 1, 1], 'Test metadata A')

        self._check(md, 'MetaData_A', [2, 2, 2, 0, 0, 1, 1], 'Test metadata A')
        self._check(md, 'MetaData_B', [0, 0, 0, 1, 1, np.nan, np.nan], 'Test metadata B')

    def test_set_get_update(self):
        """Test for MetaData.set() and MetaData.get(); updating values."""
        md = self._build_ab()
        md.set('MetaData_A', [10] * 10 + [0] * 5, 'Test metadata A', updater=lambda x, y: x + y)

        self._check(md, 'MetaData_A', [11] * 10 + [0] * 5, 'Test metadata A')
        self._check(md, 'MetaData_B', [0] * 10 + [1] * 5, 'Test metadata B')

    def test_get_notfound(self):
        """Test for MetaData.get(); key not found case."""
        md = self._build_ab()

        self._check(md, 'MetaData_NotFound', None, None)

    def test_get_value_len(self):
        """Test for get_value_len()."""
        assert_array_equal(self._build_ab().get_value_len(), 15)

    def test_keylist(self):
        """Test for keylist()."""
        assert_array_equal(self._build_ab().keylist(), ['MetaData_A', 'MetaData_B'])


if __name__ == '__main__':
    unittest.main()


# --- tests/data/array_jl_dense_v1.mat ---
# https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/array_jl_dense_v1.mat
-------------------------------------------------------------------------------- /tests/data/array_jl_sparse_v1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/array_jl_sparse_v1.mat -------------------------------------------------------------------------------- /tests/data/mri/epi0001.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/mri/epi0001.hdr -------------------------------------------------------------------------------- /tests/data/mri/epi0001.img: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/mri/epi0001.img -------------------------------------------------------------------------------- /tests/data/mri/epi0002.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/mri/epi0002.hdr -------------------------------------------------------------------------------- /tests/data/mri/epi0002.img: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/mri/epi0002.img -------------------------------------------------------------------------------- /tests/data/mri/epi0003.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/mri/epi0003.hdr -------------------------------------------------------------------------------- /tests/data/mri/epi0003.img: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/mri/epi0003.img -------------------------------------------------------------------------------- /tests/data/mri/epi0004.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/mri/epi0004.hdr -------------------------------------------------------------------------------- /tests/data/mri/epi0004.img: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/mri/epi0004.img -------------------------------------------------------------------------------- /tests/data/mri/epi0005.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/mri/epi0005.hdr -------------------------------------------------------------------------------- /tests/data/mri/epi0005.img: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/mri/epi0005.img -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/W/00000000.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/W/00000000.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/W/00000001.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/W/00000001.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/W/00000002.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/W/00000002.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/W/00000003.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/W/00000003.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/W/00000004.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/W/00000004.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/W/00000005.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/W/00000005.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/W/00000006.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/W/00000006.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/W/00000007.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/W/00000007.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/b/00000000.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/b/00000000.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/b/00000001.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/b/00000001.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/b/00000002.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/b/00000002.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/b/00000003.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/b/00000003.mat 
-------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/b/00000004.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/b/00000004.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/b/00000005.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/b/00000005.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/b/00000006.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/b/00000006.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/b/00000007.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-bd/b/00000007.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-bd/info.yaml: -------------------------------------------------------------------------------- 1 | _status: 2 | computation_id: fastl2lir-chunk-bd 3 | computation_status: done 4 | -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-pkl/00000000.pkl.gz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-pkl/00000000.pkl.gz -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-pkl/00000001.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-pkl/00000001.pkl.gz -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-pkl/00000002.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-pkl/00000002.pkl.gz -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-pkl/00000003.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-pkl/00000003.pkl.gz -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-pkl/00000004.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-pkl/00000004.pkl.gz -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-pkl/00000005.pkl.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-pkl/00000005.pkl.gz -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-pkl/00000006.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-pkl/00000006.pkl.gz -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-pkl/00000007.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-chunk-pkl/00000007.pkl.gz -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-chunk-pkl/info.yaml: -------------------------------------------------------------------------------- 1 | _status: 2 | computation_id: fastl2lir-chunk-pkl 3 | computation_status: done 4 | -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-nochunk-bd/W.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-nochunk-bd/W.mat -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-nochunk-bd/b.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-nochunk-bd/b.mat -------------------------------------------------------------------------------- 
/tests/data/test_models/fastl2lir-nochunk-bd/info.yaml: -------------------------------------------------------------------------------- 1 | _status: 2 | computation_id: fastl2lir-nochunk-bd 3 | computation_status: done 4 | -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-nochunk-pkl/info.yaml: -------------------------------------------------------------------------------- 1 | _status: 2 | computation_id: fastl2lir-nochunk-pkl 3 | computation_status: done 4 | -------------------------------------------------------------------------------- /tests/data/test_models/fastl2lir-nochunk-pkl/model.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/fastl2lir-nochunk-pkl/model.pkl.gz -------------------------------------------------------------------------------- /tests/data/test_models/lir-nochunk-pkl/info.yaml: -------------------------------------------------------------------------------- 1 | _status: 2 | computation_id: lir-nochunk-pkl 3 | computation_status: done 4 | -------------------------------------------------------------------------------- /tests/data/test_models/lir-nochunk-pkl/model.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/test_models/lir-nochunk-pkl/model.pkl.gz -------------------------------------------------------------------------------- /tests/data/testdata-2d-nan.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/testdata-2d-nan.pkl.gz -------------------------------------------------------------------------------- /tests/data/testdata-2d.pkl.gz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/data/testdata-2d.pkl.gz -------------------------------------------------------------------------------- /tests/dataform/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/dataform/__init__.py -------------------------------------------------------------------------------- /tests/dataform/test_features.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from typing import List, Tuple 4 | 5 | import os 6 | from glob import glob 7 | import tempfile 8 | 9 | import numpy as np 10 | from numpy.testing import assert_array_equal 11 | import scipy.io as sio 12 | import hdf5storage 13 | 14 | from bdpy.dataform.features import Features 15 | 16 | 17 | def _prepare_mock_data( 18 | tmpdir: str, 19 | mock_layer_names: List[str], 20 | mock_image_names: List[str], 21 | mock_shapes: List[Tuple[int, ...]] 22 | ) -> None: 23 | """Prepare mock data for testing.""" 24 | for layer_name, shape in zip(mock_layer_names, mock_shapes): 25 | os.makedirs(os.path.join(tmpdir, layer_name)) 26 | for image_name in mock_image_names: 27 | data = np.random.rand(*shape) 28 | hdf5storage.savemat( 29 | os.path.join(tmpdir, layer_name, image_name + '.mat'), 30 | {'feat': data}, 31 | format='5') 32 | 33 | 34 | class TestDataformFeatures(unittest.TestCase): 35 | def setUp(self): 36 | self.mock_layer_names = ['fc8', 'conv5'] 37 | self.mock_image_names = [ 38 | 'n01443537_22563', 39 | 'n01443537_22564', 40 | 'n01677366_18182', 41 | 'n01677370_20000', 42 | 'n04572121_3262', 43 | 'n04572121_3263', 44 | 'n04572121_3264' 45 | ] 46 | self.mock_shapes = [(1, 1000), (1, 256, 13, 13)] 47 | self.feature_dir = tempfile.TemporaryDirectory() 48 | 
_prepare_mock_data( 49 | self.feature_dir.name, 50 | self.mock_layer_names, 51 | self.mock_image_names, 52 | self.mock_shapes 53 | ) 54 | 55 | # Loading test data 56 | # AlexNet, fc8, all samples 57 | self.alexnet_fc8_all = np.vstack( 58 | [ 59 | sio.loadmat(f)['feat'] 60 | for f in sorted(glob(os.path.join(self.feature_dir.name, 'fc8', '*.mat'))) 61 | ] 62 | ) 63 | 64 | # AlexNet, conv5, all samples 65 | self.alexnet_conv5_all = np.vstack( 66 | [ 67 | sio.loadmat(f)['feat'] 68 | for f in sorted(glob(os.path.join(self.feature_dir.name, 'conv5', '*.mat'))) 69 | ] 70 | ) 71 | 72 | def tearDown(self): 73 | self.feature_dir.cleanup() 74 | 75 | def test_features_get_features(self): 76 | feat = Features(self.feature_dir.name) 77 | 78 | assert_array_equal( 79 | feat.get_features('fc8'), 80 | self.alexnet_fc8_all 81 | ) 82 | assert_array_equal( 83 | feat.get_features('conv5'), 84 | self.alexnet_conv5_all 85 | ) 86 | 87 | def test_features_get_all(self): 88 | feat = Features(self.feature_dir.name) 89 | 90 | assert_array_equal( 91 | feat.get('fc8'), 92 | self.alexnet_fc8_all 93 | ) 94 | assert_array_equal( 95 | feat.get('conv5'), 96 | self.alexnet_conv5_all 97 | ) 98 | 99 | def test_features_get_label(self): 100 | feat = Features(self.feature_dir.name) 101 | 102 | label_idx = 0 103 | labels = self.mock_image_names[label_idx] 104 | index = np.array([label_idx]) 105 | assert_array_equal( 106 | feat.get('fc8', label=labels), 107 | self.alexnet_fc8_all[index, :] 108 | ) 109 | assert_array_equal( 110 | feat.get('conv5', label=labels), 111 | self.alexnet_conv5_all[index, :] 112 | ) 113 | 114 | index = np.array([0, 2, 5]) 115 | labels = [self.mock_image_names[i] for i in index] 116 | assert_array_equal( 117 | feat.get('fc8', label=labels), 118 | self.alexnet_fc8_all[index, :] 119 | ) 120 | assert_array_equal( 121 | feat.get('conv5', label=labels), 122 | self.alexnet_conv5_all[index, :] 123 | ) 124 | 125 | if __name__ == "__main__": 126 | unittest.main() 127 | 
-------------------------------------------------------------------------------- /tests/dataform/test_sparse.py: -------------------------------------------------------------------------------- 1 | '''Tests for dataform''' 2 | 3 | import os 4 | import tempfile 5 | import unittest 6 | 7 | import numpy as np 8 | 9 | from bdpy.dataform.sparse import load_array, save_array 10 | 11 | 12 | class TestSparse(unittest.TestCase): 13 | 14 | def test_load_save_dense_array(self): 15 | payloads = [ 16 | [(10,), 'test_array_dense_ndim1.mat'], # ndim = 1 17 | [(3, 2), 'test_array_dense_ndim2.mat'], # ndim = 2 18 | [(4, 3, 2), 'test_array_dense_ndim3.mat'] # ndim = 3 19 | ] 20 | 21 | with tempfile.TemporaryDirectory() as tmpdir: 22 | for shape, fname in payloads: 23 | original_data = np.random.rand(*shape) 24 | save_array(tmpdir + '/' + fname, original_data, key='testdata') 25 | from_file = load_array(tmpdir + '/' + fname, key='testdata') 26 | 27 | np.testing.assert_array_equal(original_data, from_file) 28 | 29 | def test_load_save_sparse_array(self): 30 | payloads = [ 31 | [(10,), 'test_array_sparse_ndim1.mat'], # ndim = 1 32 | [(3, 2), 'test_array_sparse_ndim2.mat'], # ndim = 2 33 | [(4, 3, 2), 'test_array_sparse_ndim3.mat'] # ndim = 3 34 | ] 35 | 36 | with tempfile.TemporaryDirectory() as tmpdir: 37 | for shape, fname in payloads: 38 | original_data = np.random.rand(*shape) 39 | original_data[original_data < 0.8] = 0 40 | 41 | save_array(tmpdir + '/' + fname, original_data, key='testdata', sparse=True) 42 | from_file = load_array(tmpdir + '/' + fname, key='testdata') 43 | 44 | np.testing.assert_array_equal(original_data, from_file) 45 | 46 | def test_load_array_jl(self): 47 | data = np.array([[1, 0, 0, 0], 48 | [2, 2, 0, 0], 49 | [3, 3, 3, 0]]) 50 | data_dir = os.path.abspath(os.path.join( 51 | os.path.dirname(__file__), os.pardir, 'data' 52 | )) 53 | 54 | testdata = load_array( 55 | os.path.join(data_dir, 'array_jl_dense_v1.mat'), key='a') 56 | 
np.testing.assert_array_equal(data, testdata) 57 | 58 | testdata = load_array( 59 | os.path.join(data_dir, 'array_jl_sparse_v1.mat'), key='a') 60 | np.testing.assert_array_equal(data, testdata) 61 | 62 | 63 | if __name__ == '__main__': 64 | unittest.main() 65 | -------------------------------------------------------------------------------- /tests/distcomp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/distcomp/__init__.py -------------------------------------------------------------------------------- /tests/distcomp/test_distcomp.py: -------------------------------------------------------------------------------- 1 | '''Tests for distcomp''' 2 | 3 | import unittest 4 | 5 | import os 6 | import tempfile 7 | 8 | from bdpy.distcomp import DistComp 9 | 10 | 11 | class TestDistComp(unittest.TestCase): 12 | def test_distcomp_file(self): 13 | with tempfile.TemporaryDirectory() as lockdir: 14 | comp_id = 'test-distcomp-fs' 15 | 16 | # init 17 | distcomp = DistComp(lockdir=lockdir, comp_id=comp_id) 18 | self.assertTrue(os.path.isdir(lockdir)) 19 | self.assertFalse(distcomp.islocked()) 20 | 21 | # lock 22 | distcomp.lock() 23 | self.assertTrue(os.path.isfile(os.path.join(lockdir, comp_id + '.lock'))) 24 | self.assertTrue(distcomp.islocked()) 25 | 26 | # unlock 27 | distcomp.unlock() 28 | self.assertFalse(os.path.isfile(os.path.join(lockdir, comp_id + '.lock'))) 29 | self.assertFalse(distcomp.islocked()) 30 | 31 | # islocked_lock 32 | distcomp.islocked_lock() 33 | self.assertTrue(os.path.isfile(os.path.join(lockdir, comp_id + '.lock'))) 34 | self.assertTrue(distcomp.islocked()) 35 | 36 | def test_distcomp_sqlite3(self): 37 | with tempfile.TemporaryDirectory() as lockdir: 38 | db_path = os.path.join(lockdir, 'distcomp.db') 39 | comp_id = 'test-distcomp-sqlite3-1' 40 | 41 | # init 42 | distcomp = DistComp(backend='sqlite3', 
db_path=db_path) 43 | self.assertTrue(os.path.isfile(db_path)) 44 | self.assertFalse(distcomp.islocked(comp_id)) 45 | 46 | # lock 47 | distcomp.lock(comp_id) 48 | self.assertTrue(distcomp.islocked(comp_id)) 49 | 50 | # unlock 51 | distcomp.unlock(comp_id) 52 | self.assertFalse(distcomp.islocked(comp_id)) 53 | 54 | # islocked_lock 55 | with self.assertRaises(NotImplementedError): 56 | distcomp.islocked_lock(comp_id) 57 | 58 | 59 | if __name__ == '__main__': 60 | unittest.main() -------------------------------------------------------------------------------- /tests/dl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/dl/__init__.py -------------------------------------------------------------------------------- /tests/dl/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/dl/torch/__init__.py -------------------------------------------------------------------------------- /tests/dl/torch/domain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/dl/torch/domain/__init__.py -------------------------------------------------------------------------------- /tests/dl/torch/domain/test_core.py: -------------------------------------------------------------------------------- 1 | """Tests for bdpy.dl.torch.domain.core.""" 2 | 3 | import unittest 4 | from bdpy.dl.torch.domain import core as core_module 5 | 6 | 7 | class DummyAddDomain(core_module.Domain): 8 | def send(self, num): 9 | return num + 1 10 | 11 | def receive(self, num): 12 | return num - 1 13 | 14 | 15 | class DummyDoubleDomain(core_module.Domain): 16 | def send(self, num): 17 | return num * 2 
18 | 19 | def receive(self, num): 20 | return num // 2 21 | 22 | 23 | class DummyUpperCaseDomain(core_module.Domain): 24 | def send(self, text): 25 | return text.upper() 26 | 27 | def receive(self, value): 28 | return value.lower() 29 | 30 | 31 | class TestDomain(unittest.TestCase): 32 | """Tests for bdpy.dl.torch.domain.core.Domain.""" 33 | def setUp(self): 34 | self.domain = DummyAddDomain() 35 | self.original_space_num = 0 36 | self.internal_space_num = 1 37 | 38 | def test_instantiation(self): 39 | """Test instantiation.""" 40 | self.assertRaises(TypeError, core_module.Domain) 41 | 42 | def test_send(self): 43 | """test send""" 44 | self.assertEqual(self.domain.send(self.original_space_num), self.internal_space_num) 45 | 46 | def test_receive(self): 47 | """test receive""" 48 | self.assertEqual(self.domain.receive(self.internal_space_num), self.original_space_num) 49 | 50 | def test_invertibility(self): 51 | input_candidates = [-1, 0, 1, 0.5] 52 | for x in input_candidates: 53 | assert x == self.domain.send(self.domain.receive(x)) 54 | assert x == self.domain.receive(self.domain.send(x)) 55 | 56 | 57 | class TestInternalDomain(unittest.TestCase): 58 | """Tests for bdpy.dl.torch.domain.core.InternalDomain.""" 59 | def setUp(self): 60 | self.domain = core_module.InternalDomain() 61 | self.num = 1 62 | 63 | def test_send(self): 64 | """test send""" 65 | self.assertEqual(self.domain.send(self.num), self.num) 66 | 67 | def test_receive(self): 68 | """test receive""" 69 | self.assertEqual(self.domain.receive(self.num), self.num) 70 | 71 | def test_invertibility(self): 72 | input_candidates = [-1, 0, 1, 0.5] 73 | for x in input_candidates: 74 | assert x == self.domain.send(self.domain.receive(x)) 75 | assert x == self.domain.receive(self.domain.send(x)) 76 | 77 | 78 | class TestIrreversibleDomain(unittest.TestCase): 79 | """Tests for bdpy.dl.torch.domain.core.IrreversibleDomain.""" 80 | def setUp(self): 81 | self.domain = core_module.IrreversibleDomain() 82 | self.num 
= 1 83 | 84 | def test_send(self): 85 | """test send""" 86 | self.assertEqual(self.domain.send(self.num), self.num) 87 | 88 | def test_receive(self): 89 | """test receive""" 90 | self.assertEqual(self.domain.receive(self.num), self.num) 91 | 92 | 93 | class TestComposedDomain(unittest.TestCase): 94 | """Tests for bdpy.dl.torch.domain.core.ComposedDomain.""" 95 | def setUp(self): 96 | self.composed_domain = core_module.ComposedDomain([ 97 | DummyDoubleDomain(), 98 | DummyAddDomain(), 99 | ]) 100 | self.original_space_num = 0 101 | self.internal_space_num = 2 102 | 103 | def test_send(self): 104 | """test send""" 105 | self.assertEqual(self.composed_domain.send(self.original_space_num), self.internal_space_num) 106 | 107 | def test_receive(self): 108 | """test receive""" 109 | self.assertEqual(self.composed_domain.receive(self.internal_space_num), self.original_space_num) 110 | 111 | 112 | class TestKeyValueDomain(unittest.TestCase): 113 | """Tests for bdpy.dl.torch.domain.core.KeyValueDomain.""" 114 | def setUp(self): 115 | self.key_value_domain = core_module.KeyValueDomain({ 116 | "name": DummyUpperCaseDomain(), 117 | "age": DummyDoubleDomain() 118 | }) 119 | self.original_space_data = {"name": "alice", "age": 30} 120 | self.internal_space_data = {"name": "ALICE", "age": 60} 121 | 122 | def test_send(self): 123 | """test send""" 124 | self.assertEqual(self.key_value_domain.send(self.original_space_data), self.internal_space_data) 125 | 126 | def test_receive(self): 127 | """test receive""" 128 | self.assertEqual(self.key_value_domain.receive(self.internal_space_data), self.original_space_data) 129 | 130 | 131 | if __name__ == "__main__": 132 | #unittest.main() 133 | composed_domain = core_module.ComposedDomain([ 134 | DummyDoubleDomain(), 135 | DummyAddDomain(), 136 | ]) 137 | print(composed_domain.receive(-1)) 138 | print(composed_domain.send(-2)) 139 | -------------------------------------------------------------------------------- 
/tests/dl/torch/domain/test_feature_domain.py: -------------------------------------------------------------------------------- 1 | """Tests for bdpy.dl.torch.domain.feature_domain.""" 2 | 3 | import unittest 4 | import torch 5 | from bdpy.dl.torch.domain import feature_domain as feature_domain_module 6 | 7 | 8 | class TestMethods(unittest.TestCase): 9 | def setUp(self): 10 | self.lnd_tensor = torch.empty((12, 196, 768)) 11 | self.nld_tensor = torch.empty((196, 12, 768)) 12 | 13 | def test_lnd2nld(self): 14 | """test _lnd2nld""" 15 | self.assertEqual(feature_domain_module._lnd2nld(self.lnd_tensor).shape, self.nld_tensor.shape) 16 | 17 | def test_nld2lnd(self): 18 | """test _nld2lnd""" 19 | self.assertEqual(feature_domain_module._nld2lnd(self.nld_tensor).shape, self.lnd_tensor.shape) 20 | 21 | 22 | class TestArbitraryFeatureKeyDomain(unittest.TestCase): 23 | """Tests for bdpy.dl.torch.domain.feature_domain.ArbitraryFeatureKeyDomain.""" 24 | def setUp(self): 25 | self.to_internal_mapping = { 26 | "self_key1": "internal_key1", 27 | "self_key2": "internal_key2" 28 | } 29 | self.to_self_mapping = { 30 | "internal_key1": "self_key1", 31 | "internal_key2": "self_key2" 32 | } 33 | self.features = { 34 | "self_key1": 123, 35 | "self_key2": 456 36 | } 37 | self.internal_features = { 38 | "internal_key1": 123, 39 | "internal_key2": 456 40 | } 41 | 42 | def test_send(self): 43 | """test send""" 44 | # when both are specified 45 | domain = feature_domain_module.ArbitraryFeatureKeyDomain( 46 | to_internal=self.to_internal_mapping, 47 | to_self=self.to_self_mapping 48 | ) 49 | self.assertEqual(domain.send(self.features), self.internal_features) 50 | 51 | # when only to_self is specified 52 | domain = feature_domain_module.ArbitraryFeatureKeyDomain( 53 | to_self=self.to_self_mapping 54 | ) 55 | self.assertEqual(domain.send(self.features), self.internal_features) 56 | 57 | # when only to_internal is specified 58 | domain = feature_domain_module.ArbitraryFeatureKeyDomain( 59 | 
to_internal=self.to_internal_mapping 60 | ) 61 | self.assertEqual(domain.send(self.features), self.internal_features) 62 | 63 | def test_receive(self): 64 | """test receive""" 65 | # when both are specified 66 | domain = feature_domain_module.ArbitraryFeatureKeyDomain( 67 | to_internal=self.to_internal_mapping, 68 | to_self=self.to_self_mapping 69 | ) 70 | self.assertEqual(domain.receive(self.internal_features), self.features) 71 | 72 | # when only to_self is specified 73 | domain = feature_domain_module.ArbitraryFeatureKeyDomain( 74 | to_self=self.to_self_mapping 75 | ) 76 | self.assertEqual(domain.receive(self.internal_features), self.features) 77 | 78 | # when only to_internal is specified 79 | domain = feature_domain_module.ArbitraryFeatureKeyDomain( 80 | to_internal=self.to_internal_mapping 81 | ) 82 | self.assertEqual(domain.receive(self.internal_features), self.features) 83 | 84 | 85 | if __name__ == "__main__": 86 | unittest.main() -------------------------------------------------------------------------------- /tests/dl/torch/test_models.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from bdpy.dl.torch import models 7 | 8 | 9 | def _removeprefix(text: str, prefix: str) -> str: 10 | """Remove prefix from text. 
(Workaround for Python 3.8)""" 11 | if text.startswith(prefix): 12 | return text[len(prefix):] 13 | return text 14 | 15 | 16 | class MockModule(nn.Module): 17 | def __init__(self): 18 | super(MockModule, self).__init__() 19 | self.layer1 = nn.Linear(10, 10) 20 | self.layers = nn.Sequential( 21 | nn.Conv2d(1, 1, 3), 22 | nn.Conv2d(1, 1, 3), 23 | nn.Module(), 24 | nn.Sequential( 25 | nn.Conv2d(1, 1, 4), 26 | nn.Conv2d(1, 1, 8), 27 | ) 28 | ) 29 | inner_network = self.layers[-2] 30 | inner_network.features = nn.Sequential( 31 | nn.Conv2d(1, 1, 5), 32 | nn.Conv2d(1, 1, 5) 33 | ) 34 | 35 | 36 | class TestLayerMap(unittest.TestCase): 37 | def setUp(self): 38 | self.kv_pairs = [ 39 | {'net': 'vgg19', 'payload': {'key': 'fc6', 'value': 'classifier[0]'}}, 40 | {'net': 'vgg19', 'payload': {'key': 'conv5_4', 'value': 'features[34]'}}, 41 | {'net': 'alexnet', 'payload': {'key': 'fc6', 'value': 'classifier[0]'}}, 42 | {'net': 'alexnet', 'payload': {'key': 'conv5', 'value': 'features[12]'}} 43 | ] 44 | 45 | def test_layer_map(self): 46 | for kv_pair in self.kv_pairs: 47 | expected = kv_pair['payload'] 48 | output = models.layer_map(kv_pair['net']) 49 | self.assertIsInstance(output, dict) 50 | self.assertEqual(output[expected['key']], expected['value']) 51 | 52 | 53 | class TestParseLayerName(unittest.TestCase): 54 | def setUp(self): 55 | self.mock = MockModule() 56 | self.accessors = [ 57 | {'name': 'layer1', 'type': nn.Linear, 'attrs': {'in_features': 10, 'out_features': 10}}, 58 | {'name': 'layers[0]', 'type': nn.Conv2d, 'attrs': {'kernel_size': (3, 3)}}, 59 | {'name': 'layers[1]', 'type': nn.Conv2d, 'attrs': {'kernel_size': (3, 3)}}, 60 | {'name': 'layers[2].features[0]', 'type': nn.Conv2d, 'attrs': {'kernel_size': (5, 5)}}, 61 | {'name': 'layers[3][0]', 'type': nn.Conv2d, 'attrs': {'kernel_size': (4, 4)}}, 62 | {'name': 'layers[3][1]', 'type': nn.Conv2d, 'attrs': {'kernel_size': (8, 8)}} 63 | ] 64 | 65 | def test_parse_layer_name(self): 66 | for accessor in self.accessors: 
67 | layer = models._parse_layer_name(self.mock, accessor['name']) 68 | self.assertIsInstance(layer, accessor['type']) 69 | for attr, value in accessor['attrs'].items(): 70 | self.assertEqual(getattr(layer, attr), value) 71 | 72 | # Test non-existing layer access 73 | self.assertRaises( 74 | ValueError, models._parse_layer_name, self.mock, 'not_existing_layer') 75 | # Test invalid layer access 76 | self.assertRaises( 77 | ValueError, models._parse_layer_name, self.mock, 'layers["key"]') 78 | 79 | def test_parse_layer_name_for_sequential(self): 80 | """Test _parse_layer_name for nn.Sequential. 81 | 82 | nn.Sequential is a special case because the submodules are directly 83 | accessible like a list. For example, `model[0]` will return the first 84 | module in the model. 85 | """ 86 | sequential_module = self.mock.layers 87 | accessors = [accessor for accessor in self.accessors if accessor['name'].startswith('layers')] 88 | for accessor in accessors: 89 | accsessor_key = _removeprefix(accessor['name'], 'layers') 90 | layer = models._parse_layer_name(sequential_module, accsessor_key) 91 | self.assertIsInstance(layer, accessor['type']) 92 | for attr, value in accessor['attrs'].items(): 93 | self.assertEqual(getattr(layer, attr), value) 94 | 95 | 96 | class TestVGG19(unittest.TestCase): 97 | def setUp(self): 98 | self.input_shape = (1, 3, 224, 224) 99 | self.model = models.VGG19() 100 | 101 | def test_forward(self): 102 | x = torch.rand(self.input_shape) 103 | output = self.model(x) 104 | self.assertIsInstance(output, torch.Tensor) 105 | self.assertEqual(output.shape, (1, 1000)) 106 | 107 | def test_layer_access(self): 108 | layer_names = models.layer_map('vgg19').values() 109 | for layer_name in layer_names: 110 | self.assertIsInstance( 111 | models._parse_layer_name(self.model, layer_name), nn.Module) 112 | 113 | 114 | class TestAlexNet(unittest.TestCase): 115 | def setUp(self): 116 | self.input_shape = (1, 3, 224, 224) 117 | self.model = models.AlexNet() 118 | 119 | 
def test_forward(self): 120 | x = torch.rand(self.input_shape) 121 | output = self.model(x) 122 | self.assertIsInstance(output, torch.Tensor) 123 | self.assertEqual(output.shape, (1, 1000)) 124 | 125 | def test_layer_access(self): 126 | layer_names = models.layer_map('alexnet').values() 127 | for layer_name in layer_names: 128 | self.assertIsInstance( 129 | models._parse_layer_name(self.model, layer_name), nn.Module) 130 | 131 | 132 | class TestAlexNetGenerator(unittest.TestCase): 133 | def setUp(self): 134 | self.input_shape = (1, 4096) 135 | self.model = models.AlexNetGenerator() 136 | 137 | def test_forward(self): 138 | x = torch.rand(self.input_shape) 139 | output = self.model(x) 140 | self.assertIsInstance(output, torch.Tensor) 141 | self.assertEqual(output.shape, (1, 3, 256, 256)) 142 | 143 | 144 | if __name__ == '__main__': 145 | unittest.main() -------------------------------------------------------------------------------- /tests/dl/torch/test_torch.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | from bdpy.dl.torch import torch as bdtorch 8 | 9 | 10 | class MockInnerModule(nn.Module): 11 | def __init__(self): 12 | super(MockInnerModule, self).__init__() 13 | self.features = nn.Sequential( 14 | nn.Linear(3, 4), 15 | nn.Linear(4, 5) 16 | ) 17 | 18 | def forward(self, x: torch.Tensor) -> torch.Tensor: 19 | x = self.features(x) 20 | return x 21 | 22 | 23 | class MockModule(nn.Module): 24 | def __init__(self): 25 | super(MockModule, self).__init__() 26 | self.layer1 = nn.Linear(10, 1) 27 | self.layers = nn.Sequential( 28 | nn.Linear(1, 2), 29 | nn.Linear(2, 3) 30 | ) 31 | self.out = MockInnerModule() 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: 34 | x = self.layer1(x) 35 | x = self.layers(x) 36 | x = self.out(x) 37 | return x 38 | 39 | 40 | class TestFeatureExtractor(unittest.TestCase): 41 | def setUp(self) -> 
None: 42 | self.encoder = MockModule() 43 | self.input_tensor = np.random.random(size=(10,)).astype(np.float32) 44 | self.layer_list = [ 45 | {'map': {'alias': 'L1', 'entity': 'layer1'}, 'shape': (1, 1)}, 46 | {'map': {'alias': 'L2', 'entity': 'layers[0]'}, 'shape': (1, 2)}, 47 | {'map': {'alias': 'L3', 'entity': 'layers[1]'}, 'shape': (1, 3)}, 48 | {'map': {'alias': 'L4', 'entity': 'out.features[0]'}, 'shape': (1, 4)}, 49 | {'map': {'alias': 'L5', 'entity': 'out.features[1]'}, 'shape': (1, 5)} 50 | ] 51 | 52 | def test_run(self): 53 | self.encoder.eval() 54 | layer_to_shape = {payload['map']['entity']: payload['shape'] for payload in self.layer_list} 55 | extractor = bdtorch.FeatureExtractor(self.encoder, layer_to_shape.keys(), detach=True) 56 | features = extractor.run(self.input_tensor) 57 | for layer, shape in layer_to_shape.items(): 58 | self.assertEqual(features[layer].shape, shape) 59 | 60 | def test_run_with_layer_map(self): 61 | self.encoder.eval() 62 | layer_to_shape = {payload['map']['alias']: payload['shape'] for payload in self.layer_list} 63 | layer_map = {payload['map']['alias']: payload['map']['entity'] for payload in self.layer_list} 64 | extractor = bdtorch.FeatureExtractor( 65 | self.encoder, layer_to_shape.keys(), 66 | layer_mapping=layer_map, detach=True) 67 | features = extractor.run(self.input_tensor) 68 | for layer, shape in layer_to_shape.items(): 69 | self.assertEqual(features[layer].shape, shape) 70 | 71 | 72 | class TestImageDataset(unittest.TestCase): 73 | ... 
74 | 75 | 76 | if __name__ == '__main__': 77 | unittest.main() -------------------------------------------------------------------------------- /tests/env/py27/Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | 8 | [packages] 9 | numpy = "*" 10 | scipy = "*" 11 | scikit-learn = "*" 12 | h5py = "*" 13 | hdf5storage = "*" 14 | pyyaml = "*" 15 | 16 | [requires] 17 | python_version = "2.7" 18 | -------------------------------------------------------------------------------- /tests/env/py38/Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | 8 | [packages] 9 | numpy = "*" 10 | scipy = "*" 11 | scikit-learn = "*" 12 | h5py = "*" 13 | hdf5storage = "*" 14 | pyyaml = "*" 15 | 16 | [requires] 17 | python_version = "3.8" 18 | -------------------------------------------------------------------------------- /tests/evals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/evals/__init__.py -------------------------------------------------------------------------------- /tests/evals/test_metrics.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pickle 4 | 5 | import numpy as np 6 | 7 | from bdpy.evals.metrics import profile_correlation, pattern_correlation, pairwise_identification 8 | 9 | 10 | class TestMetrics(unittest.TestCase): 11 | def test_profile_correlation(self): 12 | # 2-d array 13 | n = 30 14 | x = np.random.rand(10, n) 15 | y = np.random.rand(10, n) 16 | r = np.array([[ 17 | np.corrcoef(x[:, i], y[:, i])[0, 1] 18 | for i in range(n) 19 | ]]) 20 | 21 | 
self.assertTrue(np.array_equal( 22 | profile_correlation(x, y), r 23 | )) 24 | self.assertEqual(profile_correlation(x, y).shape, (1, n)) 25 | 26 | # Multi-d array 27 | x = np.random.rand(10, 4, 3, 2) 28 | y = np.random.rand(10, 4, 3, 2) 29 | xf = x.reshape(10, -1) 30 | yf = y.reshape(10, -1) 31 | r = np.array([[ 32 | np.corrcoef(xf[:, i], yf[:, i])[0, 1] 33 | for i in range(4 * 3 * 2) 34 | ]]) 35 | r = r.reshape(1, 4, 3, 2) 36 | 37 | self.assertTrue(np.array_equal( 38 | profile_correlation(x, y), r 39 | )) 40 | self.assertEqual(profile_correlation(x, y).shape, (1, 4, 3, 2)) 41 | 42 | def test_pattern_correlation(self): 43 | # 2-d array 44 | x = np.random.rand(10, 30) 45 | y = np.random.rand(10, 30) 46 | r = np.array([ 47 | np.corrcoef(x[i, :], y[i, :])[0, 1] 48 | for i in range(10) 49 | ]) 50 | 51 | self.assertTrue(np.array_equal( 52 | pattern_correlation(x, y), r 53 | )) 54 | self.assertEqual(pattern_correlation(x, y).shape, (10,)) 55 | 56 | # Multi-d array 57 | x = np.random.rand(10, 4, 3, 2) 58 | y = np.random.rand(10, 4, 3, 2) 59 | xf = x.reshape(10, -1) 60 | yf = y.reshape(10, -1) 61 | r = np.array([ 62 | np.corrcoef(xf[i, :], yf[i, :])[0, 1] 63 | for i in range(10) 64 | ]) 65 | 66 | self.assertTrue(np.array_equal( 67 | pattern_correlation(x, y), r 68 | )) 69 | self.assertEqual(pattern_correlation(x, y).shape, (10,)) 70 | 71 | def test_2d(self): 72 | with open('tests/data/testdata-2d.pkl.gz', 'rb') as f: 73 | d = pickle.load(f) 74 | self.assertTrue(np.array_equal( 75 | profile_correlation(d['x'], d['y']), 76 | d['r_prof'] 77 | )) 78 | self.assertTrue(np.array_equal( 79 | pattern_correlation(d['x'], d['y']), 80 | d['r_patt'] 81 | )) 82 | self.assertTrue(np.array_equal( 83 | pairwise_identification(d['x'], d['y']), 84 | d['ident_acc'] 85 | )) 86 | 87 | def test_2d_nan(self): 88 | with open('tests/data/testdata-2d-nan.pkl.gz', 'rb') as f: 89 | d = pickle.load(f) 90 | # self.assertTrue(np.array_equal( 91 | # profile_correlation(d['x'], d['y']), 92 | # d['r_prof'] 
93 | # )) 94 | self.assertTrue(np.array_equal( 95 | pattern_correlation(d['x'], d['y'], remove_nan=True), 96 | d['r_patt'], 97 | )) 98 | self.assertTrue(np.array_equal( 99 | pairwise_identification(d['x'], d['y'], remove_nan=True), 100 | d['ident_acc'], 101 | )) 102 | 103 | if __name__ == '__main__': 104 | unittest.main() -------------------------------------------------------------------------------- /tests/feature/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/feature/__init__.py -------------------------------------------------------------------------------- /tests/ml/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/ml/__init__.py -------------------------------------------------------------------------------- /tests/ml/test_crossvalidation.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | '''Tests for ml''' 3 | 4 | 5 | import unittest 6 | 7 | import numpy as np 8 | 9 | from bdpy.ml.crossvalidation import cvindex_groupwise, make_cvindex, make_cvindex_generator 10 | 11 | 12 | class TestCVIndexGroupwise(unittest.TestCase): 13 | 14 | def test_cvindex_groupwise(self): 15 | 16 | # Test data 17 | x = np.array([ 18 | 1, 1, 1, 19 | 2, 2, 2, 20 | 3, 3, 3, 21 | 4, 4, 4, 22 | 5, 5, 5, 23 | 6, 6, 6 24 | ]) 25 | 26 | # Expected output 27 | train_index = [ 28 | np.array([3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]), 29 | np.array([0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]), 30 | np.array([0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17]), 31 | np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17]), 32 | np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17]), 33 | np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 
9, 10, 11, 12, 13, 14]) 34 | ] 35 | 36 | test_index = [ 37 | np.array([ 0, 1, 2]), 38 | np.array([ 3, 4, 5]), 39 | np.array([ 6, 7, 8]), 40 | np.array([ 9, 10, 11]), 41 | np.array([12, 13, 14]), 42 | np.array([15, 16, 17]) 43 | ] 44 | 45 | cvindex = cvindex_groupwise(x) 46 | 47 | for i, (tr, te) in enumerate(cvindex): 48 | self.assertTrue(np.array_equal(train_index[i], tr)) 49 | self.assertTrue(np.array_equal(test_index[i], te)) 50 | 51 | def test_cvindex_groupwise_exclusive(self): 52 | 53 | # Test data 54 | x = np.array([ 55 | 1, 1, 1, 56 | 2, 2, 2, 57 | 3, 3, 3, 58 | 4, 4, 4, 59 | 5, 5, 5, 60 | 6, 6, 6 61 | ]) 62 | 63 | # Exclusive labels 64 | a = np.array([ 65 | 1, 2, 3, 66 | 4, 5, 6, 67 | 1, 2, 3, 68 | 4, 5, 6, 69 | 1, 2, 3, 70 | 4, 5, 6, 71 | ]) 72 | 73 | # Expected output 74 | train_index = [ 75 | np.array([3, 4, 5, 9, 10, 11, 15, 16, 17]), 76 | np.array([0, 1, 2, 6, 7, 8, 12, 13, 14]), 77 | np.array([3, 4, 5, 9, 10, 11, 15, 16, 17]), 78 | np.array([0, 1, 2, 6, 7, 8, 12, 13, 14]), 79 | np.array([3, 4, 5, 9, 10, 11, 15, 16, 17]), 80 | np.array([0, 1, 2, 6, 7, 8, 12, 13, 14]) 81 | ] 82 | 83 | test_index = [ 84 | np.array([ 0, 1, 2]), 85 | np.array([ 3, 4, 5]), 86 | np.array([ 6, 7, 8]), 87 | np.array([ 9, 10, 11]), 88 | np.array([12, 13, 14]), 89 | np.array([15, 16, 17]) 90 | ] 91 | 92 | cvindex = cvindex_groupwise(x, exclusive=a) 93 | 94 | for i, (tr, te) in enumerate(cvindex): 95 | self.assertTrue(np.array_equal(train_index[i], tr)) 96 | self.assertTrue(np.array_equal(test_index[i], te)) 97 | 98 | 99 | class TestMakeCVIndex(unittest.TestCase): 100 | def test_make_cvindex(self): 101 | '''Test for make_cvindex''' 102 | test_input = np.array([1, 1, 2, 2, 3, 3]) 103 | 104 | exp_output_a = np.array([[False, True, True], 105 | [False, True, True], 106 | [True, False, True], 107 | [True, False, True], 108 | [True, True, False], 109 | [True, True, False]]) 110 | exp_output_b = np.array([[True, False, False], 111 | [True, False, False], 112 | [False, True, False], 113 
| [False, True, False], 114 | [False, False, True], 115 | [False, False, True]]) 116 | 117 | test_output_a, test_output_b = make_cvindex(test_input) 118 | 119 | self.assertTrue((test_output_a == exp_output_a).all()) 120 | self.assertTrue((test_output_b == exp_output_b).all()) 121 | 122 | 123 | if __name__ == '__main__': 124 | unittest.main() 125 | -------------------------------------------------------------------------------- /tests/ml/test_ensemble.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from bdpy.ml import ensemble 6 | 7 | 8 | class TestEnsemble(unittest.TestCase): 9 | def test_ensemble_get_majority(self): 10 | '''Tests of bdpy.ml.emsenble.get_majority''' 11 | data = np.array([[1, 3, 2, 1, 2], 12 | [2, 1, 0, 0, 2], 13 | [2, 1, 1, 0, 2], 14 | [1, 3, 3, 1, 1], 15 | [0, 2, 3, 3, 0], 16 | [3, 2, 2, 2, 1], 17 | [3, 1, 3, 2, 0], 18 | [3, 2, 0, 3, 1]]) 19 | # Get the major elements in each colum (axis=0) or row (axis=1). 20 | # The element with the smallest value will be returned when several 21 | # elements were the majority. 
22 | ans_by_column = np.array([3, 1, 3, 0, 1]) 23 | ans_by_row = np.array([1, 0, 1, 1, 0, 2, 3, 3]) 24 | np.testing.assert_array_almost_equal( 25 | ensemble.get_majority(data, axis=0), ans_by_column) 26 | np.testing.assert_array_almost_equal( 27 | ensemble.get_majority(data, axis=1), ans_by_row) 28 | 29 | 30 | if __name__ == '__main__': 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /tests/ml/test_regress.py: -------------------------------------------------------------------------------- 1 | '''Tests for bdpy.ml''' 2 | 3 | 4 | import unittest 5 | 6 | import numpy as np 7 | 8 | from bdpy.ml import regress 9 | 10 | 11 | class TestRegress(unittest.TestCase): 12 | '''Tests for 'ml.regress' module''' 13 | 14 | def test_add_bias_default(self): 15 | '''Test of bdpy.ml.regress.add_bias (default: axis=0)''' 16 | x = np.array([[100, 110, 120], 17 | [200, 210, 220]]) 18 | 19 | exp_y = np.array([[100, 110, 120], 20 | [200, 210, 220], 21 | [1, 1, 1]]) 22 | 23 | test_y = regress.add_bias(x) 24 | 25 | np.testing.assert_array_equal(test_y, exp_y) 26 | 27 | def test_add_bias_axisone(self): 28 | '''Test of bdpy.ml.regress.add_bias (axis=1)''' 29 | x = np.array([[100, 110, 120], 30 | [200, 210, 220]]) 31 | 32 | exp_y = np.array([[100, 110, 120, 1], 33 | [200, 210, 220, 1]]) 34 | 35 | test_y = regress.add_bias(x, axis=1) 36 | 37 | np.testing.assert_array_equal(test_y, exp_y) 38 | 39 | def test_add_bias_axiszero(self): 40 | '''Test of bdpy.ml.regress.add_bias (axis=0)''' 41 | x = np.array([[100, 110, 120], 42 | [200, 210, 220]]) 43 | 44 | exp_y = np.array([[100, 110, 120], 45 | [200, 210, 220], 46 | [1, 1, 1]]) 47 | 48 | test_y = regress.add_bias(x, axis=0) 49 | 50 | np.testing.assert_array_equal(test_y, exp_y) 51 | 52 | def test_add_bias_invalidaxis(self): 53 | '''Exception test of bdpy.ml.regress.add_bias 54 | (invalid input in 'axis')''' 55 | x = np.array([[100, 110, 120], 56 | [200, 210, 220]]) 57 | 58 | 
self.assertRaises(ValueError, lambda: regress.add_bias(x, axis=-1)) 59 | 60 | 61 | if __name__ == '__main__': 62 | unittest.main() -------------------------------------------------------------------------------- /tests/preproc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/preproc/__init__.py -------------------------------------------------------------------------------- /tests/preproc/test_interface.py: -------------------------------------------------------------------------------- 1 | '''Tests for bdpy.preprocessor''' 2 | 3 | import unittest 4 | 5 | import numpy as np 6 | from scipy.signal import detrend 7 | 8 | from bdpy.preproc import interface 9 | 10 | 11 | class TestPreprocessorInterface(unittest.TestCase): 12 | '''Tests of 'preprocessor' module''' 13 | 14 | @classmethod 15 | def test_average_sample(cls): 16 | '''Test for average_sample''' 17 | 18 | x = np.random.rand(10, 100) 19 | group = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2, 2]) 20 | 21 | exp_output_x = np.vstack((np.average(x[0:5, :], axis=0), 22 | np.average(x[5:10, :], axis=0))) 23 | exp_output_ind = np.array([0, 5]) 24 | 25 | test_output_x, test_output_ind = interface.average_sample( 26 | x, group, verbose=True) 27 | 28 | np.testing.assert_array_equal(test_output_x, exp_output_x) 29 | np.testing.assert_array_equal(test_output_ind, exp_output_ind) 30 | 31 | @classmethod 32 | def test_detrend_sample_default(cls): 33 | '''Test for detrend_sample (default)''' 34 | 35 | x = np.random.rand(20, 10) 36 | group = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 37 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) 38 | 39 | exp_output = np.vstack((detrend(x[0:10, :], axis=0, type='linear') 40 | + np.mean(x[0:10, :], axis=0), 41 | detrend(x[10:20, :], axis=0, type='linear') 42 | + np.mean(x[10:20, :], axis=0))) 43 | 44 | test_output = interface.detrend_sample(x, group, verbose=True) 45 | 46 | 
np.testing.assert_array_equal(test_output, exp_output) 47 | 48 | @classmethod 49 | def test_detrend_sample_nokeepmean(cls): 50 | '''Test for detrend_sample (keep_mean=False)''' 51 | 52 | x = np.random.rand(20, 10) 53 | group = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 54 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) 55 | 56 | exp_output = np.vstack((detrend(x[0:10, :], axis=0, type='linear'), 57 | detrend(x[10:20, :], axis=0, type='linear'))) 58 | 59 | test_output = interface.detrend_sample( 60 | x, group, keep_mean=False, verbose=True) 61 | 62 | np.testing.assert_array_equal(test_output, exp_output) 63 | 64 | @classmethod 65 | def test_normalize_sample(cls): 66 | '''Test for normalize_sample (default)''' 67 | 68 | x = np.random.rand(20, 10) 69 | group = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 70 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) 71 | 72 | mean_a = np.mean(x[0:10, :], axis=0) 73 | mean_b = np.mean(x[10:20, :], axis=0) 74 | 75 | exp_output = np.vstack((100 * (x[0:10, :] - mean_a) / mean_a, 76 | 100 * (x[10:20, :] - mean_b) / mean_b)) 77 | 78 | test_output = interface.normalize_sample(x, group, verbose=True) 79 | 80 | np.testing.assert_array_equal(test_output, exp_output) 81 | 82 | @classmethod 83 | def test_shift_sample_singlegroup(cls): 84 | '''Test for shift_sample (single group, shift_size=1)''' 85 | 86 | x = np.array([[1, 2, 3], 87 | [11, 12, 13], 88 | [21, 22, 23], 89 | [31, 32, 33], 90 | [41, 42, 43]]) 91 | grp = np.array([1, 1, 1, 1, 1]) 92 | 93 | exp_output_data = np.array([[11, 12, 13], 94 | [21, 22, 23], 95 | [31, 32, 33], 96 | [41, 42, 43]]) 97 | exp_output_ind = [0, 1, 2, 3] 98 | 99 | # Default shift_size = 1 100 | test_output_data, test_output_ind = interface.shift_sample( 101 | x, grp, verbose=True) 102 | 103 | np.testing.assert_array_equal(test_output_data, exp_output_data) 104 | np.testing.assert_array_equal(test_output_ind, exp_output_ind) 105 | 106 | @classmethod 107 | def test_shift_sample_twogroup(cls): 108 | '''Test for shift_sample (two groups, 
shift_size=1)''' 109 | 110 | x = np.array([[1, 2, 3], 111 | [11, 12, 13], 112 | [21, 22, 23], 113 | [31, 32, 33], 114 | [41, 42, 43], 115 | [51, 52, 53]]) 116 | grp = np.array([1, 1, 1, 2, 2, 2]) 117 | 118 | exp_output_data = np.array([[11, 12, 13], 119 | [21, 22, 23], 120 | [41, 42, 43], 121 | [51, 52, 53]]) 122 | exp_output_ind = [0, 1, 3, 4] 123 | 124 | # Default shift_size=1 125 | test_output_data, test_output_ind = interface.shift_sample( 126 | x, grp, verbose=True) 127 | 128 | np.testing.assert_array_equal(test_output_data, exp_output_data) 129 | np.testing.assert_array_equal(test_output_ind, exp_output_ind) 130 | 131 | 132 | if __name__ == '__main__': 133 | unittest.main() -------------------------------------------------------------------------------- /tests/preproc/test_select_top.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from bdpy.preproc.select_top import select_top 6 | 7 | 8 | class TestPreprocessorSelectTop(unittest.TestCase): 9 | 10 | @classmethod 11 | def test_select_top_default(cls): 12 | '''Test for select_top (default, axis=0)''' 13 | 14 | test_data = np.array([[1, 2, 3, 4, 5], 15 | [11, 12, 13, 14, 15], 16 | [21, 22, 23, 24, 25], 17 | [31, 32, 33, 34, 35], 18 | [41, 42, 43, 44, 45]]) 19 | test_value = np.array([15, 3, 6, 20, 0]) 20 | test_num = 3 21 | 22 | exp_output_data = np.array([[1, 2, 3, 4, 5], 23 | [21, 22, 23, 24, 25], 24 | [31, 32, 33, 34, 35]]) 25 | exp_output_index = np.array([0, 2, 3]) 26 | 27 | test_output_data, test_output_index = select_top( 28 | test_data, test_value, test_num) 29 | 30 | np.testing.assert_array_equal(test_output_data, exp_output_data) 31 | np.testing.assert_array_equal(test_output_index, exp_output_index) 32 | 33 | @classmethod 34 | def test_select_top_axisone(cls): 35 | '''Test for select_top (axis=1)''' 36 | 37 | test_data = np.array([[1, 2, 3, 4, 5], 38 | [11, 12, 13, 14, 15], 39 | [21, 22, 23, 24, 25], 40 | [31, 
32, 33, 34, 35], 41 | [41, 42, 43, 44, 45]]) 42 | test_value = np.array([15, 3, 6, 20, 0]) 43 | test_num = 3 44 | 45 | exp_output_data = np.array([[1, 3, 4], 46 | [11, 13, 14], 47 | [21, 23, 24], 48 | [31, 33, 34], 49 | [41, 43, 44]]) 50 | exp_output_index = np.array([0, 2, 3]) 51 | 52 | test_output_data, test_output_index = select_top( 53 | test_data, test_value, test_num, axis=1) 54 | 55 | np.testing.assert_array_equal(test_output_data, exp_output_data) 56 | np.testing.assert_array_equal(test_output_index, exp_output_index) -------------------------------------------------------------------------------- /tests/recon/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/recon/__init__.py -------------------------------------------------------------------------------- /tests/recon/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/recon/torch/__init__.py -------------------------------------------------------------------------------- /tests/recon/torch/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/recon/torch/modules/__init__.py -------------------------------------------------------------------------------- /tests/recon/torch/modules/test_encoder.py: -------------------------------------------------------------------------------- 1 | """Tests for bdpy.recon.torch.modules.encoder.""" 2 | 3 | import unittest 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from bdpy.dl.torch.domain.image_domain import Zero2OneImageDomain 9 | from bdpy.recon.torch.modules import encoder as encoder_module 10 | 11 | 12 | class MLP(nn.Module): 13 | 
"""A simple MLP.""" 14 | 15 | def __init__(self): 16 | super().__init__() 17 | self.fc1 = nn.Linear(64 * 64 * 3, 256) 18 | self.fc2 = nn.Linear(256, 128) 19 | 20 | def forward(self, x): 21 | x = x.view(x.size(0), -1) 22 | x = self.fc1(x) 23 | x = torch.relu(x) 24 | x = self.fc2(x) 25 | return x 26 | 27 | 28 | class TestBaseEncoder(unittest.TestCase): 29 | """Tests for bdpy.recon.torch.modules.encoder.BaseEncoder.""" 30 | 31 | def test_instantiation(self): 32 | """Test instantiation.""" 33 | self.assertRaises(TypeError, encoder_module.BaseEncoder) 34 | 35 | def test_call(self): 36 | """Test __call__.""" 37 | 38 | class ReturnAsIsEncoder(encoder_module.BaseEncoder): 39 | def encode(self, images): 40 | return {"image": images} 41 | 42 | encoder = ReturnAsIsEncoder() 43 | images = torch.randn(1, 3, 64, 64) 44 | features = encoder(images) 45 | self.assertDictEqual(features, {"image": images}) 46 | 47 | 48 | class TestNNModuleEncoder(unittest.TestCase): 49 | """Tests for bdpy.recon.torch.modules.encoder.NNModuleEncoder.""" 50 | 51 | def test_instantiation(self): 52 | """Test instantiation.""" 53 | self.assertRaises(TypeError, encoder_module.NNModuleEncoder) 54 | 55 | def test_call(self): 56 | """Test __call__.""" 57 | 58 | class ReturnAsIsEncoder(encoder_module.NNModuleEncoder): 59 | def __init__(self) -> None: 60 | super().__init__() 61 | def encode(self, images): 62 | return {"image": images} 63 | 64 | encoder = ReturnAsIsEncoder() 65 | 66 | images = torch.randn(1, 3, 64, 64) 67 | images.requires_grad = True 68 | features = encoder(images) 69 | self.assertIsInstance(features, dict) 70 | self.assertEqual(len(features), 1) 71 | self.assertEqual(features["image"].shape, (1, 3, 64, 64)) 72 | features["image"].sum().backward() 73 | self.assertIsNotNone(images.grad) 74 | 75 | 76 | class TestSimpleEncoder(unittest.TestCase): 77 | """Tests for bdpy.recon.torch.modules.encoder.SimpleEncoder.""" 78 | 79 | def test_call(self): 80 | """Test __call__.""" 81 | encoder = 
encoder_module.SimpleEncoder( 82 | MLP(), ["fc1", "fc2"], domain=Zero2OneImageDomain() 83 | ) 84 | images = torch.randn(1, 3, 64, 64).clamp(0, 1) 85 | images.requires_grad = True 86 | features = encoder(images) 87 | self.assertIsInstance(features, dict) 88 | self.assertEqual(len(features), 2) 89 | self.assertEqual(features["fc1"].shape, (1, 256)) 90 | self.assertEqual(features["fc2"].shape, (1, 128)) 91 | features["fc2"].sum().backward() 92 | self.assertIsNotNone(images.grad) 93 | 94 | 95 | class TestBuildEncoder(unittest.TestCase): 96 | """Tests for bdpy.recon.torch.modules.encoder.build_encoder.""" 97 | 98 | def test_build_encoder(self): 99 | """Test build_encoder.""" 100 | mlp = MLP() 101 | encoder_from_builder = encoder_module.build_encoder( 102 | feature_network=mlp, 103 | layer_names=["fc1", "fc2"], 104 | domain=Zero2OneImageDomain(), 105 | ) 106 | encoder = encoder_module.SimpleEncoder( 107 | mlp, ["fc1", "fc2"], domain=Zero2OneImageDomain() 108 | ) 109 | 110 | images = torch.randn(1, 3, 64, 64).clamp(0, 1) 111 | features_from_builder = encoder_from_builder(images) 112 | features = encoder(images) 113 | self.assertEqual(type(encoder_from_builder), type(encoder)) 114 | self.assertEqual(features_from_builder.keys(), features.keys()) 115 | for key in features_from_builder.keys(): 116 | self.assertTrue(torch.allclose(features_from_builder[key], features[key])) 117 | 118 | 119 | if __name__ == "__main__": 120 | unittest.main() 121 | -------------------------------------------------------------------------------- /tests/recon/torch/modules/test_latent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | from typing import Iterator 4 | import torch.nn as nn 5 | from functools import partial 6 | from bdpy.recon.torch.modules import latent as latent_module 7 | 8 | 9 | class DummyLatent(latent_module.BaseLatent): 10 | def __init__(self): 11 | self.latent = nn.Parameter(torch.tensor([0.0, 1.0, 
2.0])) 12 | 13 | def reset_states(self): 14 | with torch.no_grad(): 15 | self.latent.fill_(0.0) 16 | 17 | def parameters(self, recurse): 18 | return iter(self.latent) 19 | 20 | def generate(self): 21 | return self.latent 22 | 23 | 24 | class TestBaseLatent(unittest.TestCase): 25 | """Tests for bdpy.recon.torch.modules.latent.BaseLatent.""" 26 | def setUp(self): 27 | self.latent_value_expected = nn.Parameter(torch.tensor([0.0, 1.0, 2.0])) 28 | self.latent_reset_value_expected = nn.Parameter(torch.tensor([0.0, 0.0, 0.0])) 29 | 30 | def test_instantiation(self): 31 | """Test instantiation.""" 32 | self.assertRaises(TypeError, latent_module.BaseLatent) 33 | 34 | def test_call(self): 35 | """Test __call__.""" 36 | latent = DummyLatent() 37 | self.assertTrue(torch.equal(latent(), self.latent_value_expected)) 38 | 39 | def test_parameters(self): 40 | """test parameters""" 41 | latent = DummyLatent() 42 | params = latent.parameters(recurse=True) 43 | self.assertIsInstance(params, Iterator) 44 | 45 | def test_reset_states(self): 46 | """test reset_states""" 47 | latent = DummyLatent() 48 | latent.reset_states() 49 | self.assertTrue(torch.equal(latent(), self.latent_reset_value_expected)) 50 | 51 | 52 | class DummyNNModuleLatent(latent_module.NNModuleLatent): 53 | def __init__(self): 54 | super().__init__() 55 | self.latent = nn.Parameter(torch.tensor([0.0, 1.0, 2.0])) 56 | 57 | def reset_states(self): 58 | with torch.no_grad(): 59 | self.latent.fill_(0.0) 60 | 61 | def generate(self): 62 | return self.latent 63 | 64 | 65 | class TestNNModuleLatent(unittest.TestCase): 66 | """Tests for bdpy.recon.torch.modules.latent.NNModuleLatent.""" 67 | def setUp(self): 68 | self.latent_value_expected = nn.Parameter(torch.tensor([0.0, 1.0, 2.0])) 69 | self.latent_reset_value_expected = nn.Parameter(torch.tensor([0.0, 0.0, 0.0])) 70 | 71 | def test_instantiation(self): 72 | """Test instantiation.""" 73 | self.assertRaises(TypeError, latent_module.NNModuleLatent) 74 | 75 | def 
test_call(self): 76 | """Test __call__.""" 77 | latent = DummyNNModuleLatent() 78 | self.assertTrue(torch.equal(latent(), self.latent_value_expected)) 79 | 80 | def test_parameters(self): 81 | """test parameters""" 82 | latent = DummyNNModuleLatent() 83 | params = latent.parameters(recurse=True) 84 | self.assertIsInstance(params, Iterator) 85 | 86 | def test_reset_states(self): 87 | """test reset_states""" 88 | latent = DummyNNModuleLatent() 89 | latent.reset_states() 90 | self.assertTrue(torch.equal(latent(), self.latent_reset_value_expected)) 91 | 92 | 93 | class TestArbitraryLatent(unittest.TestCase): 94 | """Tests for bdpy.recon.torch.modules.latent.ArbitraryLatent.""" 95 | def setUp(self): 96 | self.latent = latent_module.ArbitraryLatent((1, 3, 64, 64), partial(nn.init.normal_, mean=0, std=1)) 97 | self.latent_shape_expected = (1, 3, 64, 64) 98 | self.latent_value_expected = nn.Parameter(torch.tensor([0.0, 1.0, 2.0])) 99 | self.latent_reset_value_expected = nn.Parameter(torch.tensor([0.0, 0.0, 0.0])) 100 | 101 | def test_instantiation(self): 102 | """Test instantiation.""" 103 | self.assertRaises(TypeError, latent_module.ArbitraryLatent) 104 | 105 | def test_call(self): 106 | """Test __call__.""" 107 | self.assertEqual(self.latent().size(), self.latent_shape_expected) 108 | 109 | def test_parameters(self): 110 | """test parameters""" 111 | params = self.latent.parameters(recurse=True) 112 | self.assertIsInstance(params, Iterator) 113 | 114 | def test_reset_states(self): 115 | """test reset_states""" 116 | self.latent.reset_states() 117 | mean = self.latent().mean().item() 118 | std = self.latent().std().item() 119 | self.assertAlmostEqual(mean, 0, places=1) 120 | self.assertAlmostEqual(std, 1, places=1) 121 | 122 | 123 | if __name__ == '__main__': 124 | unittest.main() 125 | -------------------------------------------------------------------------------- /tests/recon/torch/modules/test_optimizer.py: 
-------------------------------------------------------------------------------- 1 | """Tests for bdpy.recon.torch.modules.optimizer""" 2 | 3 | from __future__ import annotations 4 | 5 | import unittest 6 | 7 | from functools import partial 8 | import numpy as np 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from bdpy.recon.torch.modules import build_generator, ArbitraryLatent 12 | from bdpy.recon.torch.modules import build_optimizer_factory, build_scheduler_factory 13 | 14 | 15 | class MLP(nn.Module): 16 | def __init__(self, in_dim, out_dim): 17 | super().__init__() 18 | self.fc = nn.Linear(in_dim, out_dim) 19 | 20 | def forward(self, x): 21 | return self.fc(x) 22 | 23 | 24 | class TestBuildOptimizerFactory(unittest.TestCase): 25 | """Tests for bdpy.recon.torch.modules.optimizer.build_optimizer_factory""" 26 | 27 | def test_build_optimizer_factory(self): 28 | generator = build_generator(MLP(64, 10)) 29 | latent = ArbitraryLatent( 30 | (1, 64), init_fn=partial(nn.init.normal_, mean=0, std=1) 31 | ) 32 | optimizer_factory = build_optimizer_factory(optim.SGD, lr=0.1) 33 | optimizer = optimizer_factory(generator, latent) 34 | self.assertIsInstance( 35 | optimizer, 36 | optim.SGD, 37 | msg="optimizer_factory should return an instance of optim.Optimizer", 38 | ) 39 | 40 | latent.reset_states() 41 | generator.reset_states() 42 | latent_prev = latent().detach().clone().numpy() 43 | optimizer.zero_grad() 44 | output = generator(latent()) 45 | loss = output.sum() 46 | loss.backward() 47 | latent_next_expected = ( 48 | latent_prev - 0.1 * latent().grad.detach().clone().numpy() 49 | ) 50 | optimizer.step() 51 | latent_next = latent().detach().clone().numpy() 52 | np.testing.assert_allclose( 53 | latent_next, 54 | latent_next_expected, 55 | rtol=1e-6, 56 | err_msg="Optimizer does not update the latent variable correctly.", 57 | ) 58 | 59 | # check if all the frozen generator's gradients are None 60 | generator_grad = [p.grad for p in generator.parameters()] 
61 | self.assertTrue( 62 | all([g is None for g in generator_grad]), 63 | msg="Frozen generator's gradients should be None after the optimizer step.", 64 | ) 65 | 66 | 67 | class TestBuildSchedulerFactory(unittest.TestCase): 68 | """Tests for bdpy.recon.torch.modules.optimizer.build_scheduler_factory""" 69 | 70 | def test_build_scheduler_factory(self): 71 | generator = build_generator(MLP(64, 10)) 72 | latent = ArbitraryLatent( 73 | (1, 64), init_fn=partial(nn.init.normal_, mean=0, std=1) 74 | ) 75 | optimizer_factory = build_optimizer_factory(optim.SGD, lr=0.1) 76 | scheduler_factory = build_scheduler_factory( 77 | optim.lr_scheduler.StepLR, step_size=1, gamma=0.1 78 | ) 79 | optimizer = optimizer_factory(generator, latent) 80 | scheduler = scheduler_factory(optimizer) 81 | self.assertIsInstance( 82 | scheduler, 83 | optim.lr_scheduler.StepLR, 84 | msg="Scheduler factory should return an instance of optim.lr_scheduler.LRScheduler", 85 | ) 86 | 87 | latent.reset_states() 88 | generator.reset_states() 89 | optimizer.zero_grad() 90 | output = generator(latent()) 91 | loss = output.sum() 92 | loss.backward() 93 | optimizer.step() 94 | scheduler.step() 95 | self.assertEqual( 96 | optimizer.param_groups[0]["lr"], 97 | 0.1 * 0.1, 98 | "Scheduler does not update the learning rate correctly.", 99 | ) 100 | 101 | # check if reference to the optimizer is kept during re-initialization 102 | for _ in range(10): 103 | optimizer = optimizer_factory(generator, latent) 104 | scheduler = scheduler_factory(optimizer) 105 | else: 106 | self.assertTrue( 107 | scheduler.optimizer is optimizer, 108 | "Scheduler should keep the reference to the optimizer during re-initialization.", 109 | ) 110 | 111 | 112 | if __name__ == "__main__": 113 | unittest.main() 114 | -------------------------------------------------------------------------------- /tests/recon/torch/task/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/recon/torch/task/__init__.py -------------------------------------------------------------------------------- /tests/task/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/task/__init__.py -------------------------------------------------------------------------------- /tests/task/test_core.py: -------------------------------------------------------------------------------- 1 | """Tests for bdpy.task.core.""" 2 | 3 | from __future__ import annotations 4 | 5 | import unittest 6 | from bdpy.task import core as core_module 7 | 8 | 9 | class MockCallback(core_module.BaseCallback): 10 | """Mock callback for testing.""" 11 | def __init__(self): # NOTE(review): does not call super().__init__(); confirm BaseCallback needs no initialization. 12 | self._storage = [] # records every input_ passed to on_some_event 13 | 14 | def on_some_event(self, input_): 15 | self._storage.append(input_) 16 | 17 | 18 | class MockTask(core_module.BaseTask[MockCallback]): 19 | """Mock task for testing BaseTask.""" 20 | def run(self, *inputs, **parameters): # Fire one event, then echo the arguments so tests can check pass-through. 21 | self._callback_handler.fire("on_some_event", input_=1) 22 | return inputs, parameters 23 | 24 | 25 | class TestBaseTask(unittest.TestCase): 26 | """Tests for bdpy.task.core.BaseTask.""" 27 | def setUp(self): 28 | self.input1 = 1.0 29 | self.input2 = 2.0 30 | self.task_name = "reconstruction" 31 | 32 | def test_initialization_without_callbacks(self): 33 | """Test initialization without callbacks.""" 34 | task = MockTask() 35 | self.assertIsInstance(task._callback_handler, core_module.CallbackHandler) 36 | self.assertEqual(len(task._callback_handler._callbacks), 0) 37 | 38 | def 
test_initialization_with_callbacks(self): 39 | """Test initialization with callbacks.""" 40 | mock_callback = MockCallback() 41 | task = MockTask(callbacks=mock_callback) 42 | self.assertEqual(len(task._callback_handler._callbacks), 1) 43 | self.assertIn(mock_callback,
task._callback_handler._callbacks) 44 | 45 | def test_register_callback(self): 46 | """Test register_callback method.""" 47 | task = MockTask() 48 | mock_callback = MockCallback() 49 | task.register_callback(mock_callback) 50 | self.assertIn(mock_callback, task._callback_handler._callbacks) 51 | 52 | def test_call(self): 53 | """Test that __call__ returns run()'s pass-through of inputs and parameters and fires the callback once.""" 54 | mock_callback = MockCallback() 55 | task = MockTask(callbacks=mock_callback) 56 | task_inputs, task_parameters = task(self.input1, self.input2, name=self.task_name) 57 | self.assertEqual(task_inputs, (self.input1, self.input2)) 58 | self.assertEqual(task_parameters["name"], self.task_name) 59 | self.assertEqual(mock_callback._storage, [1]) 60 | 61 | 62 | if __name__ == "__main__": 63 | unittest.main() -------------------------------------------------------------------------------- /tests/test_mri.py: -------------------------------------------------------------------------------- 1 | """Tests for bdpy.mri.""" 2 | 3 | import unittest 4 | 5 | import numpy as np 6 | 7 | import bdpy.mri as bmr 8 | 9 | 10 | class TestMri(unittest.TestCase): 11 | """Tests for 'mri' module.""" 12 | 13 | def test_get_roiflag_pass0001(self) -> None: 14 | """Test get_roiflag with a single ROI (pass case 0001).""" 15 | roi_xyz = [np.array([[1, 2, 3], 16 | [1, 2, 3], 17 | [1, 2, 3]])] 18 | epi_xyz = np.array([[1, 2, 3, 4, 5, 6], 19 | [1, 2, 3, 4, 5, 6], 20 | [1, 2, 3, 4, 5, 6]]) 21 | 22 | exp_output = np.array([1, 1, 1, 0, 0, 0]) # 1 where an epi_xyz column also appears in the ROI 23 | 24 | test_output = bmr.get_roiflag(roi_xyz, epi_xyz) # type: ignore 25 | 26 | self.assertTrue((test_output == exp_output).all()) # NOTE(review): == broadcasts, so the output shape is not checked here 27 | 28 | def test_get_roiflag_pass0002(self) -> None: 29 | """Test get_roiflag with two ROIs (pass case 0002).""" 30 | roi_xyz = [np.array([[1, 2, 3], 31 | [1, 2, 
3], 32 | [1, 2, 3]]), 33 | np.array([[5, 6], 34 | [5, 6], 35 | [5, 6]])] 36 | epi_xyz = np.array([[1, 2, 3, 4, 5, 6], 37 | [1, 2, 3, 4, 5, 6], 38 | [1, 2, 3, 4, 5, 6]]) 39 | 40 | exp_output = np.array([[1, 1, 1, 0, 0, 0], # one row of flags per ROI 41 | [0, 0, 0, 0, 1, 1]]) 42 |
43 | test_output = bmr.get_roiflag(roi_xyz, epi_xyz) # type: ignore 44 | 45 | self.assertTrue((test_output == exp_output).all()) # NOTE(review): == broadcasts, so the output shape is not checked here 46 | 47 | 48 | if __name__ == '__main__': 49 | unittest.main() 50 | -------------------------------------------------------------------------------- /tests/test_pipeline.py: -------------------------------------------------------------------------------- 1 | """Tests for bdpy.pipeline.""" 2 | 3 | 4 | from unittest import TestCase, TestLoader, TextTestRunner 5 | import sys 6 | 7 | import numpy as np # NOTE(review): unused in this module 8 | import yaml 9 | import hydra 10 | 11 | from bdpy.pipeline.config import init_hydra_cfg 12 | 13 | 14 | class TestPipeline(TestCase): 15 | """Tests for bdpy.pipeline.""" 16 | 17 | def setUp(self): # Write the minimal YAML config that every test passes on the command line. 18 | cfg = { 19 | "key_int": 0, 20 | "key_str": "value", 21 | } 22 | with open("/tmp/testconf.yaml", "w") as f: # NOTE(review): fixed POSIX path, shared by concurrent runs and never removed (no tearDown); consider tempfile. 23 | yaml.dump(cfg, f) 24 | 25 | def _rest_argv(self): # Drop CLI args after argv[0] and clear Hydra's global instance so init_hydra_cfg() can run again (name presumably meant "_reset_argv"). 26 | if len(sys.argv) > 1: 27 | del sys.argv[1:] 28 | hydra.core.global_hydra.GlobalHydra.instance().clear() # hydra-core 1.0.6 29 | 30 | def test_config_init_hydra_cfg_default(self): 31 | """Test that init_hydra_cfg loads the values of the YAML file given on the command line.""" 32 | self._rest_argv() 33 | sys.argv.append("/tmp/testconf.yaml") 34 | cfg = init_hydra_cfg() 35 | self.assertEqual(cfg.key_int, 0) 36 | self.assertEqual(cfg.key_str, "value") 37 | 38 | def test_config_init_hydra_cfg_override(self): 39 | """Test that key=value arguments after -o override values from the YAML file.""" 40 | self._rest_argv() 41 | sys.argv.append("/tmp/testconf.yaml") 42 | 
sys.argv.append("-o") 43 | sys.argv.append("key_int=1") 44 | cfg = init_hydra_cfg() 45 | self.assertEqual(cfg.key_int, 1) 46 | self.assertEqual(cfg.key_str, "value") 47 | 48 | self._rest_argv() 49 | sys.argv.append("/tmp/testconf.yaml") 50 | sys.argv.append("-o") 51 | sys.argv.append("key_str=hoge") 52 | cfg = init_hydra_cfg() 53 | self.assertEqual(cfg.key_int, 0) 54 | self.assertEqual(cfg.key_str, "hoge") 55 | 56 | self._rest_argv() 57 | sys.argv.append("/tmp/testconf.yaml") 58 |
sys.argv.append("-o") 59 | sys.argv.append("key_str='hoge fuga'") # quoted override value containing a space 60 | cfg = init_hydra_cfg() 61 | self.assertEqual(cfg.key_int, 0) 62 | self.assertEqual(cfg.key_str, "hoge fuga") 63 | 64 | self._rest_argv() 65 | sys.argv.append("/tmp/testconf.yaml") 66 | sys.argv.append("-o") 67 | sys.argv.append("key_int=1024") # several key=value overrides may follow a single -o 68 | sys.argv.append("key_str=foo") 69 | cfg = init_hydra_cfg() 70 | self.assertEqual(cfg.key_int, 1024) 71 | self.assertEqual(cfg.key_str, "foo") 72 | 73 | def test_config_init_hydra_cfg_run(self): 74 | """Test the run name stored in cfg._run_.name, with and without -a.""" 75 | self._rest_argv() 76 | sys.argv.append("/tmp/testconf.yaml") 77 | cfg = init_hydra_cfg() 78 | self.assertEqual(cfg._run_.name, "test_pipeline") 79 | 80 | self._rest_argv() 81 | sys.argv.append("/tmp/testconf.yaml") 82 | sys.argv.append("-a") 83 | sys.argv.append("overridden_analysis_name") 84 | cfg = init_hydra_cfg() 85 | self.assertEqual(cfg._run_.name, "test_pipeline") # unchanged by -a, which only affects _analysis_name_ 86 | 87 | def test_config_init_hydra_cfg_analysis(self): 88 | """Test the analysis name stored in cfg._analysis_name_ and its -a override.""" 89 | self._rest_argv() 90 | sys.argv.append("/tmp/testconf.yaml") 91 | cfg = init_hydra_cfg() 92 | self.assertEqual(cfg._analysis_name_, "test_pipeline") 93 | 94 | self._rest_argv() 95 | sys.argv.append("/tmp/testconf.yaml") 96 | sys.argv.append("-a") 97 | sys.argv.append("overridden_analysis_name") 98 | cfg = init_hydra_cfg() 99 | self.assertEqual(cfg._analysis_name_, "overridden_analysis_name") 100 | 101 | 102 | if __name__ == "__main__": 103 | suite = TestLoader().loadTestsFromTestCase(TestPipeline) 104 | TextTestRunner(verbosity=2).run(suite) 105 | -------------------------------------------------------------------------------- /tests/test_stats.py: 
-------------------------------------------------------------------------------- 1 | '''Tests for bdpy.stats''' 2 | 3 | import unittest 4 | 5 | import numpy as np 6 | 7 | import bdpy.stats as bdst 8 | 9 | 10 | class TestStats(unittest.TestCase): 11 | '''Tests for
bdpy.stats''' 12 | 13 | def test_corrcoef_matrix_matrix_default(self): 14 | '''Test for corrcoef (matrix and matrix, default, var=row)''' 15 | 16 | x = np.random.rand(100, 10) 17 | y = np.random.rand(100, 10) 18 | 19 | exp_output = np.diag(np.corrcoef(x, y)[:x.shape[0], x.shape[0]:]) 20 | 21 | test_output = bdst.corrcoef(x, y) 22 | 23 | np.testing.assert_array_equal(test_output, exp_output) 24 | 25 | def test_corrcoef_matrix_matrix_varcol(self): 26 | '''Test for corrcoef (matrix and matrix, var=col)''' 27 | 28 | x = np.random.rand(100, 10) 29 | y = np.random.rand(100, 10) 30 | 31 | exp_output = np.diag(np.corrcoef(x, y, rowvar=0)[:x.shape[1], 32 | x.shape[1]:]) 33 | 34 | test_output = bdst.corrcoef(x, y, var='col') 35 | 36 | np.testing.assert_array_equal(test_output, exp_output) 37 | 38 | def test_corrcoef_vector_vector(self): 39 | '''Test for corrcoef (vector and vector)''' 40 | 41 | x = np.random.rand(100) 42 | y = np.random.rand(100) 43 | 44 | exp_output = np.corrcoef(x, y)[0, 1] 45 | 46 | test_output = bdst.corrcoef(x, y) 47 | 48 | np.testing.assert_array_equal(test_output, exp_output) 49 | 50 | def test_corrcoef_hvector_hvector(self): 51 | '''Test for corrcoef (horizontal vector and horizontal vector)''' 52 | 53 | x = np.random.rand(1, 100) 54 | y = np.random.rand(1, 100) 55 | 56 | exp_output = np.corrcoef(x, y)[0, 1] 57 | 58 | test_output = bdst.corrcoef(x, y) 59 | 60 | np.testing.assert_array_equal(test_output, exp_output) 61 | 62 | def test_corrcoef_vvector_vvector(self): 63 | '''Test for corrcoef (vertical vector and vertical vector)''' 64 | 65 | x = np.random.rand(100, 1) 66 | y = np.random.rand(100, 1) 67 | 68 | exp_output = np.corrcoef(x.T, y.T)[0, 1] 69 | 70 | test_output = bdst.corrcoef(x, y) 71 | 72 | np.testing.assert_array_equal(test_output, exp_output) 73 | 74 | def test_corrcoef_matrix_vector_varrow(self): 75 | '''Test for corrcoef (matrix and vector, var=row)''' 76 | 77 | x = np.random.rand(100, 10) 78 | y = np.random.rand(10) 79 | 80 | 
exp_output = np.corrcoef(y, x)[0, 1:] # correlation of y with each row of x 81 | 82 | test_output = bdst.corrcoef(x, y) 83 | 84 | np.testing.assert_array_almost_equal(test_output, exp_output) 85 | 86 | def test_corrcoef_matrix_vector_varcol(self): 87 | '''Test for corrcoef (matrix and vector, var=col)''' 88 | 89 | x = np.random.rand(100, 10) 90 | y = np.random.rand(100) 91 | 92 | exp_output = np.corrcoef(y, x, rowvar=0)[0, 1:] # correlation of y with each column of x 93 | 94 | test_output = bdst.corrcoef(x, y, var='col') 95 | 96 | np.testing.assert_array_almost_equal(test_output, exp_output) 97 | 98 | def test_corrcoef_vector_matrix_varrow(self): 99 | '''Test for corrcoef (vector and matrix, var=row)''' 100 | 101 | x = np.random.rand(10) 102 | y = np.random.rand(100, 10) 103 | 104 | exp_output = np.corrcoef(x, y)[0, 1:] # correlation of x with each row of y 105 | 106 | test_output = bdst.corrcoef(x, y) 107 | 108 | np.testing.assert_array_almost_equal(test_output, exp_output) 109 | 110 | def test_corrcoef_vector_matrix_varcol(self): 111 | '''Test for corrcoef (vector and matrix, var=col)''' 112 | 113 | x = np.random.rand(100) 114 | y = np.random.rand(100, 10) 115 | 116 | exp_output = np.corrcoef(x, y, rowvar=0)[0, 1:] # correlation of x with each column of y 117 | 118 | test_output = bdst.corrcoef(x, y, var='col') 119 | 120 | np.testing.assert_array_almost_equal(test_output, exp_output) 121 | 122 | def test_corrmat_default(self): 123 | '''Test for corrmat (default, var=row)''' 124 | 125 | x = np.random.rand(100, 10) 126 | y = np.random.rand(100, 10) 127 | 128 | exp_output = np.corrcoef(x, y)[:x.shape[0], x.shape[0]:] # full x-rows vs y-rows cross-correlation block 129 | 130 | test_output = bdst.corrmat(x, y) 131 | 132 | np.testing.assert_array_almost_equal(test_output, exp_output) 133 | 134 | def test_corrmat_varcol(self): 135 | '''Test for corrmat (var=col)''' 136 | 137 | x = np.random.rand(100, 10) 138 | y = np.random.rand(100, 10) 139 | 140 | exp_output = np.corrcoef(x, y, 
rowvar=0)[:x.shape[1], x.shape[1]:] # full x-columns vs y-columns cross-correlation block 141 | 142 | test_output = bdst.corrmat(x, y, var='col') 143 | 144 | np.testing.assert_array_almost_equal(test_output, exp_output) 145 | 146 | 147 | if __name__
== '__main__': 148 | unittest.main() -------------------------------------------------------------------------------- /tests/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KamitaniLab/bdpy/d86199fc0683fe257dc0a3d474a1abdc81b0249f/tests/util/__init__.py -------------------------------------------------------------------------------- /tests/util/test_math.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from bdpy.util.math import average_elemwise 6 | 7 | 8 | class TestMath(unittest.TestCase): # Tests for bdpy.util.math.average_elemwise. 9 | 10 | def test_average_elemwise(self): # Default call: the result is 1-D even when an input has shape (1, 3). 11 | a = np.array([1, 2, 3]) 12 | b = np.array([9, 8, 7]) 13 | ans_valid = np.array([5, 5, 5]) 14 | ans_test = average_elemwise([a, b]) 15 | np.testing.assert_array_equal(ans_test, ans_valid) 16 | 17 | a = np.array([[1, 2, 3]]) 18 | b = np.array([[9, 8, 7]]) 19 | ans_valid = np.array([5, 5, 5]) 20 | ans_test = average_elemwise([a, b]) 21 | np.testing.assert_array_equal(ans_test, ans_valid) 22 | 23 | a = np.array([[1, 2, 3]]) 24 | b = np.array([9, 8, 7]) 25 | ans_valid = np.array([5, 5, 5]) 26 | ans_test = average_elemwise([a, b]) 27 | np.testing.assert_array_equal(ans_test, ans_valid) 28 | 29 | a = np.array([1, 2, 3]) 30 | b = np.array([[9, 8, 7]]) 31 | ans_valid = np.array([5, 5, 5]) 32 | ans_test = average_elemwise([a, b]) 33 | np.testing.assert_array_equal(ans_test, ans_valid) 34 | 35 | def test_average_elemwise_keepdims(self): # keepdims=True: the result is (1, 3) when any input is (1, 3), otherwise 1-D. 36 | a = np.array([1, 2, 3]) 37 | b = np.array([9, 8, 7]) 38 | ans_valid = np.array([5, 5, 5]) 39 | ans_test = average_elemwise([a, b], keepdims=True) 40 | np.testing.assert_array_equal(ans_test, ans_valid) 41 | 42 | a = 
np.array([[1, 2, 3]]) 43 | b = np.array([[9, 8, 7]]) 44 | ans_valid = np.array([[5, 5, 5]]) 45 | ans_test = average_elemwise([a, b], keepdims=True) 46 | np.testing.assert_array_equal(ans_test, ans_valid) 47 | 48 | a =
np.array([[1, 2, 3]]) 49 | b = np.array([9, 8, 7]) 50 | ans_valid = np.array([[5, 5, 5]]) 51 | ans_test = average_elemwise([a, b], keepdims=True) 52 | np.testing.assert_array_equal(ans_test, ans_valid) 53 | 54 | a = np.array([1, 2, 3]) 55 | b = np.array([[9, 8, 7]]) 56 | ans_valid = np.array([[5, 5, 5]]) 57 | ans_test = average_elemwise([a, b], keepdims=True) 58 | np.testing.assert_array_equal(ans_test, ans_valid) 59 | 60 | 61 | if __name__ == '__main__': 62 | unittest.main() -------------------------------------------------------------------------------- /tests/util/test_utils.py: -------------------------------------------------------------------------------- 1 | '''Tests for bdpy.util''' 2 | 3 | 4 | import unittest 5 | 6 | import numpy as np 7 | 8 | from bdpy.util import utils 9 | 10 | 11 | class TestCreateGroupVector(unittest.TestCase): 12 | '''Tests for 'utils.create_groupvector''' 13 | 14 | def test_create_groupvector_pass0001(self): 15 | '''Test for create_groupvector (list and scalar inputs).''' 16 | 17 | x = [1, 2, 3] 18 | y = 2 19 | 20 | exp_output = [1, 1, 2, 2, 3, 3] 21 | 22 | test_output = utils.create_groupvector(x, y) 23 | 24 | self.assertTrue((test_output == exp_output).all()) 25 | 26 | def test_create_groupvector_pass0002(self): 27 | '''Test for create_groupvector (list and list inputs).''' 28 | 29 | x = [1, 2, 3] 30 | y = [2, 4, 2] 31 | 32 | exp_output = [1, 1, 2, 2, 2, 2, 3, 3] 33 | 34 | test_output = utils.create_groupvector(x, y) 35 | 36 | self.assertTrue((test_output == exp_output).all()) 37 | 38 | def test_create_groupvector_pass0003(self): 39 | '''Test for create_groupvector (Numpy array and scalar inputs).''' 40 | 41 | x = np.array([1, 2, 3]) 42 | y = 2 43 | 44 | exp_output = np.array([1, 1, 2, 2, 3, 3]) 45 | 46 | test_output = utils.create_groupvector(x, y) 47 | 48 | np.testing.assert_array_equal(test_output, exp_output) 49 | 50 | def test_create_groupvector_pass0005(self): 51 | '''Test for create_groupvector (Numpy arrays inputs).''' 52 | 
53 | x = np.array([1, 2, 3]) 54 | y = np.array([2, 4, 2]) 55 | 56 | exp_output = np.array([1, 1, 2, 2, 2, 2, 3, 3]) 57 | 58 | test_output = utils.create_groupvector(x, y) 59 | 60 | np.testing.assert_array_equal(test_output, exp_output) 61 | 62 | def test_create_groupvector_error(self): 63 | '''Test that create_groupvector raises ValueError for a 3-element x with y=[0].''' 64 | 65 | x = [1, 2, 3] 66 | y = [0] 67 | 68 | self.assertRaises(ValueError, utils.create_groupvector, x, y) 69 | 70 | 71 | class TestDivideChunks(unittest.TestCase): # Tests for utils.divide_chunks. 72 | 73 | def test_divide_chunks(self): 74 | '''Test that divide_chunks splits a list into chunk_size pieces, with a shorter final chunk.''' 75 | 76 | a = [1, 2, 3, 4, 5, 6, 7] 77 | 78 | # Test 1: chunk_size=4 leaves a 3-element final chunk 79 | expected = [[1, 2, 3, 4], 80 | [5, 6, 7]] 81 | actual = utils.divide_chunks(a, chunk_size=4) 82 | self.assertEqual(actual, expected) 83 | 84 | # Test 2: chunk_size=3 leaves a 1-element final chunk 85 | expected = [[1, 2, 3], 86 | [4, 5, 6], 87 | [7]] 88 | actual = utils.divide_chunks(a, chunk_size=3) 89 | self.assertEqual(actual, expected) 90 | 91 | 92 | if __name__ == '__main__': 93 | unittest.main() 94 | --------------------------------------------------------------------------------