├── .github
│   └── workflows
│       ├── main.yml
│       └── python-test.yml
├── .gitignore
├── CHANGE.txt
├── LICENSE
├── Makefile
├── README.md
├── XKT
│   ├── DKT
│   │   ├── DKT.py
│   │   ├── __init__.py
│   │   └── etl.py
│   ├── DKVMN
│   │   ├── DKVMN.py
│   │   ├── __init__.py
│   │   ├── etl.py
│   │   └── net.py
│   ├── GKT
│   │   ├── MGKT.py
│   │   ├── __init__.py
│   │   ├── etl.py
│   │   └── net.py
│   ├── SKT
│   │   ├── MSKT.py
│   │   ├── __init__.py
│   │   ├── etl.py
│   │   ├── net.py
│   │   └── utils.py
│   ├── __init__.py
│   ├── meta.py
│   └── utils
│       ├── __init__.py
│       ├── etl.py
│       ├── loss.py
│       ├── nn
│       │   ├── __init__.py
│       │   └── rnn.py
│       └── tests.py
├── docs
│   └── DISCUSSION.md
├── examples
│   ├── DKT
│   │   ├── DKT.ipynb
│   │   ├── DKT.py
│   │   └── prepare_dataset.ipynb
│   ├── DKVMN
│   │   ├── DKVMN.ipynb
│   │   ├── DKVMN.py
│   │   └── prepare_dataset.ipynb
│   ├── GKT
│   │   ├── MGKT.py
│   │   └── prepare_dataset.ipynb
│   └── SKT
│       ├── MSKT.py
│       └── prepare_dataset.ipynb
├── pytest.ini
├── scripts
│   ├── DKT
│   │   ├── DKT.py
│   │   ├── README.md
│   │   ├── config.yml
│   │   └── search_space.json
│   ├── DKVMN
│   │   ├── DKVMN.py
│   │   ├── README.md
│   │   ├── config.yml
│   │   └── search_space.json
│   ├── GKT
│   │   ├── MGKT.py
│   │   ├── README.md
│   │   ├── config.yml
│   │   └── search_space.json
│   └── SKT
│       ├── MSKT.py
│       ├── README.md
│       ├── config.yml
│       └── search_space.json
├── setup.cfg
├── setup.py
└── tests
    ├── __init__.py
    ├── dkt
    │   ├── __init__.py
    │   ├── conftest.py
    │   └── test_dkt.py
    ├── dkvmn
    │   ├── __init__.py
    │   ├── conftest.py
    │   └── test_dkvmn.py
    ├── gkt
    │   ├── __init__.py
    │   ├── conftest.py
    │   └── test_mgkt.py
    └── skt
        ├── __init__.py
        ├── conftest.py
        └── test_mskt.py

/.github/workflows/main.yml:
--------------------------------------------------------------------------------
 1 | name: Upload Python Package
 2 | 
 3 | on:
 4 |   release:
 5 |     types: [created]
 6 | 
 7 | jobs:
 8 |   deploy:
 9 | 
10 |     runs-on: ubuntu-latest
11 | 
12 |     steps:
13 |     - uses: actions/checkout@v2
14 |     - name: Set up Python
15 |       uses: actions/setup-python@v2
16 |       with:
17 |         python-version: '3.x'
18 |     - name: Install dependencies
19 |       run: |
20 |         python -m pip install --upgrade pip
21 |         pip install setuptools wheel twine
22 |     - name: Build and publish
23 |       env:
24 |         TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
25 |         TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
26 |       run: |
27 |         python setup.py sdist bdist_wheel
28 |         twine upload dist/*
29 | 
--------------------------------------------------------------------------------
/.github/workflows/python-test.yml:
--------------------------------------------------------------------------------
 1 | name: test
 2 | 
 3 | on: [push, pull_request]
 4 | 
 5 | jobs:
 6 |   build:
 7 | 
 8 |     runs-on: ubuntu-latest
 9 |     strategy:
10 |       matrix:
11 |         python-version: [3.6, 3.7, 3.8, 3.9]
12 | 
13 |     steps:
14 |     - uses: actions/checkout@v2
15 |     - name: Set up Python ${{ matrix.python-version }}
16 |       uses: actions/setup-python@v2
17 |       with:
18 |         python-version: ${{ matrix.python-version }}
19 |     - name: Install dependencies
20 |       run: |
21 |         pip install -e .[test]
22 |         pip install codecov
23 |     - name: Test with pytest
24 |       run: |
25 |         pytest
26 |         codecov
27 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | 
 6 | # C extensions
 7 | *.so
 8 | 
 9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
106 | # IDE
107 | .idea/
108 | .vscode/
109 | .DS_Store
110 | 
111 | # User Definition
112 | data/
113 | !data/demo
--------------------------------------------------------------------------------
/CHANGE.txt:
--------------------------------------------------------------------------------
1 | v0.0.2:
2 |     * refactor the architecture
3 |     * add GKT
4 |     * add SKT
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2019 tswsxk
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
 1 | VERSION=`ls dist/*.tar.gz | sed "s/dist\/XKT-\(.*\)\.tar\.gz/\1/g"`
 2 | 
 3 | ifdef ENVPIP
 4 | PIP = $(ENVPIP)
 5 | else
 6 | PIP = pip3
 7 | endif
 8 | 
 9 | ifdef ENVPYTHON
10 | PYTHON = $(ENVPYTHON)
11 | else
12 | PYTHON = python3
13 | endif
14 | 
15 | ifdef ENVPYTEST
16 | PYTEST = $(ENVPYTEST)
17 | else
18 | PYTEST = pytest
19 | endif
20 | 
21 | help:
22 | 
23 | 	@echo "install        install XKT"
24 | 	@echo "test           run test"
25 | 	@echo "release        publish to PyPI and release on GitHub"
26 | 	@echo "release_test   publish to TestPyPI"
27 | 	@echo "clean          remove all build, test, coverage and Python artifacts"
28 | 	@echo "clean-build    remove build artifacts"
29 | 	@echo "clean-pyc      remove Python file artifacts"
30 | 	@echo "clean-test     remove test and coverage artifacts"
31 | 
32 | .PHONY: install test build release release_test version .test .build clean
33 | 
34 | install:
35 | 	@echo "install XKT"
36 | 	$(PIP) install -e . --user
37 | 
38 | test:
39 | 	@echo "run test"
40 | 	$(PYTEST)
41 | 
42 | build: test clean
43 | 	$(PYTHON) setup.py bdist_wheel sdist
44 | 
45 | .test:
46 | 	$(PYTEST) > /dev/null
47 | 
48 | .build: clean
49 | 	$(PYTHON) setup.py bdist_wheel sdist > /dev/null
50 | 
51 | version: .build
52 | 	@echo $(VERSION)
53 | 
54 | release: test build
55 | 	@echo "publish to PyPI and release on GitHub"
56 | 	@echo "version $(VERSION)"
57 | 
58 | 	-@twine upload dist/* && git tag "v$(VERSION)"
59 | 	git push && git push --tags
60 | 
61 | release_test: test build
62 | 	@echo "publish to TestPyPI"
63 | 	@echo "version $(VERSION)"
64 | 
65 | 	-@twine upload --repository test dist/*
66 | 
67 | clean: clean-build clean-pyc clean-test
68 | 
69 | clean-build:
70 | 	rm -rf build/*
71 | 	rm -rf dist/*
72 | 	rm -rf .eggs/*
73 | 	find . -name '*.egg-info' -exec rm -fr {} +
74 | 	find . -name '*.egg' -exec rm -f {} +
75 | 
76 | clean-pyc:
77 | 	find . -name '*.pyc' -exec rm -f {} +
78 | 	find . -name '*.pyo' -exec rm -f {} +
79 | 	find . -name '*~' -exec rm -f {} +
80 | 	find . -name '__pycache__' -exec rm -rf {} +
81 | 
82 | clean-test:
83 | 	rm -f .coverage
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # XKT
 2 | 
 3 | [![PyPI](https://img.shields.io/pypi/v/XKT.svg)](https://pypi.python.org/pypi/XKT)
 4 | [![test](https://github.com/tswsxk/XKT/actions/workflows/python-test.yml/badge.svg?branch=master)](https://github.com/tswsxk/XKT/actions/workflows/python-test.yml)
 5 | [![codecov](https://codecov.io/gh/tswsxk/XKT/branch/master/graph/badge.svg)](https://codecov.io/gh/tswsxk/XKT)
 6 | [![Download](https://img.shields.io/pypi/dm/XKT.svg?style=flat)](https://pypi.python.org/pypi/XKT)
 7 | [![License](https://img.shields.io/github/license/bigdata-ustc/XKT)](LICENSE)
 8 | 
 9 | 
10 | Multiple Knowledge Tracing models implemented with mxnet-gluon.
11 | 
12 | Those who prefer pytorch can visit the sister projects:
13 | * [EduKTM](https://github.com/bigdata-ustc/EduKTM)
14 | * [TKT](https://github.com/bigdata-ustc/TKT)
15 | 
16 | where the former is easier to understand and
17 | the latter shares the same architecture as XKT.
18 | 
19 | For convenient downloading and preprocessing of knowledge tracing datasets,
20 | visit [EduData](https://github.com/bigdata-ustc/EduData) for a handy API.
21 | 
22 | 
23 | ## Tutorial
24 | 
25 | ### Installation
26 | 
27 | 1. First get the repo on your computer via `git` or any way you like.
28 | 2. Suppose you put the project under your own `home` directory; then you can either
29 |    1. run `pip install -e .` to install the package, or
30 |    2. run `export PYTHONPATH=$PYTHONPATH:~/XKT`
31 | 
32 | ### Quick Start
33 | 
34 | To learn how to use XKT, readers are encouraged to see
35 | * [examples](examples), containing script usage and notebook demos, and
36 | * [scripts](scripts), containing command-line interfaces that can be used to conduct hyper-parameter searching.
37 | 
38 | ### Data Format
39 | In `XKT`, all sequences are stored in `json` format, such as:
40 | ```json
41 | [[419, 1], [419, 1], [419, 1], [665, 0], [665, 0]]
42 | ```
43 | Each item in the sequence represents one interaction. The first element of the item is the exercise id,
44 | and the second one indicates whether the learner answered the exercise correctly: 0 for wrong and 1 for correct.
45 | Each line holds one `json` record, corresponding to one learner's interaction sequence.
46 | 
47 | A demo loading program is presented as follows:
48 | ```python
49 | import json
50 | from tqdm import tqdm
51 | 
52 | def extract(data_src):
53 |     responses = []
54 |     step = 200
55 |     with open(data_src) as f:
56 |         for line in tqdm(f, "reading data from %s" % data_src):
57 |             data = json.loads(line)
58 |             for i in range(0, len(data), step):
59 |                 if len(data[i: i + step]) < 2:
60 |                     continue
61 |                 responses.append(data[i: i + step])
62 | 
63 |     return responses
64 | ```
65 | The above program can be found in `XKT/utils/etl.py`.
66 | 
67 | If the dataset is stored in `tl` format instead, e.g.:
68 | 
69 | ```text
70 | 5
71 | 419,419,419,665,665
72 | 1,1,1,0,0
73 | ```
74 | 
75 | refer to the [EduData Documentation](https://github.com/bigdata-ustc/EduData#format-converter) for a format converter.
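For a quick one-off conversion from `tl` to the `json` format above, a minimal sketch is given below. It is illustrative only, not EduData's implementation: the function name `tl2json` and its arguments are our own, and it assumes every record occupies exactly the three lines shown above (length, exercise ids, correctness flags).

```python
import json


def tl2json(src, dst):
    """Rewrite tl-format records as one json line per learner,
    i.e. a list of [exercise_id, correct] pairs."""
    with open(src) as f, open(dst, "w") as wf:
        while True:
            length_line = f.readline()
            if not length_line.strip():  # end of file (or a trailing blank line)
                break
            exercises = f.readline().strip().split(",")
            answers = f.readline().strip().split(",")
            sequence = [[int(e), int(a)] for e, a in zip(exercises, answers)]
            assert len(sequence) == int(length_line)  # sanity-check the declared length
            wf.write(json.dumps(sequence) + "\n")
```

For real datasets, prefer the format converter in EduData linked above.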
76 | 
77 | 
78 | ## Citation
79 | 
80 | If this repository is helpful for you, please cite our work
81 | 
82 | ```bibtex
83 | @inproceedings{tong2020structure,
84 |   title={Structure-based Knowledge Tracing: An Influence Propagation View},
85 |   author={Tong, Shiwei and Liu, Qi and Huang, Wei and Huang, Zhenya and Chen, Enhong and Liu, Chuanren and Ma, Haiping and Wang, Shijin},
86 |   booktitle={2020 IEEE International Conference on Data Mining (ICDM)},
87 |   pages={541--550},
88 |   year={2020},
89 |   organization={IEEE}
90 | }
91 | ```
92 | 
93 | 
94 | ## Appendix
95 | 
96 | ### Model
97 | There are many implementations of different knowledge tracing models in different frameworks;
98 | the following are those implemented in Python (starred ones are the authors' versions):
99 | 
100 | * DKT [[tensorflow]](https://github.com/mhagiwara/deep-knowledge-tracing)
101 | 
102 | * DKT+ [[tensorflow*]](https://github.com/ckyeungac/deep-knowledge-tracing-plus)
103 | 
104 | * DKVMN [[mxnet*]](https://github.com/jennyzhang0215/DKVMN)
105 | 
106 | * KTM [[libfm]](https://github.com/jilljenn/ktm)
107 | 
108 | * EKT [[pytorch*]](https://github.com/bigdata-ustc/ekt)
109 | 
110 | More models can be found [here](https://paperswithcode.com/task/knowledge-tracing).
111 | 
112 | ### Dataset
113 | There are some datasets suitable for this task;
114 | refer to the [EduData ktbd doc](https://github.com/bigdata-ustc/EduData/blob/master/docs/ktbd.md)
115 | for these datasets.
--------------------------------------------------------------------------------
/XKT/DKT/DKT.py:
--------------------------------------------------------------------------------
 1 | # coding: utf-8
 2 | # 2021/8/22 @ tongshiwei
 3 | 
 4 | from tqdm import tqdm
 5 | import mxnet as mx
 6 | from mxnet import gluon
 7 | from XKT.meta import KTM
 8 | from XKT.utils import SLMLoss
 9 | from baize import get_params_filepath, get_epoch_params_filepath, path_append
10 | from baize.const import CFG_JSON
11 | from baize.mxnet import light_module as lm, Configuration, fit_wrapper, split_and_load
12 | from baize.metrics import classification_report
13 | from .etl import etl
14 | 
15 | 
16 | class DKTNet(gluon.HybridBlock):
17 |     def __init__(self, ku_num, hidden_num,
18 |                  add_embedding_layer=False, embedding_dim=None, embedding_dropout=None,
19 |                  dropout=0.0, rnn_type=None,
20 |                  prefix=None, params=None, **kwargs):
21 |         """
22 |         Deep Knowledge Tracing Model
23 | 
24 |         Parameters
25 |         ----------
26 |         ku_num: int
27 |             Number of knowledge units
28 |         hidden_num : int
29 |             Number of units in the RNN output
30 |         add_embedding_layer: bool
31 |             Whether to add an embedding layer
32 |         embedding_dim: int or None
33 |             When embedding_dim is None, the embedding_dim will be equal to hidden_num
34 |         embedding_dropout: float or None
35 |             When not set, it is equal to dropout
36 |         dropout: float
37 |             Fraction of the input units to drop. Must be a number between 0 and 1.
38 |         rnn_type: str or None
39 |             rnn, lstm or gru
40 |         prefix : str
41 |             Prefix for name of `Block`s
42 |         params : Parameter or None
43 |             Container for weight sharing between cells.
44 |             Created if `None`.
45 | """ 46 | super(DKTNet, self).__init__(prefix, params) 47 | 48 | self.length = None 49 | self.ku_num = ku_num 50 | self.hidden_dim = hidden_num 51 | self.add_embedding_layer = add_embedding_layer 52 | 53 | with self.name_scope(): 54 | if add_embedding_layer is True: 55 | embedding_dim = self.hidden_dim if embedding_dim is None else embedding_dim 56 | embedding_dropout = dropout if embedding_dropout is None else embedding_dropout 57 | self.embedding = gluon.nn.HybridSequential() 58 | self.embedding.add( 59 | gluon.nn.Embedding(ku_num * 2, embedding_dim), 60 | gluon.nn.Dropout(embedding_dropout) 61 | ) 62 | cell = gluon.rnn.LSTMCell 63 | else: 64 | self.embedding = lambda x, F: F.one_hot(x, ku_num * 2) 65 | cell = gluon.rnn.RNNCell 66 | 67 | if rnn_type is not None: 68 | if rnn_type in {"elman", "rnn"}: 69 | cell = gluon.rnn.RNNCell 70 | elif rnn_type == "lstm": 71 | cell = gluon.rnn.LSTMCell 72 | elif rnn_type == "gru": 73 | cell = gluon.rnn.GRUCell 74 | else: 75 | raise TypeError("unknown rnn type: %s" % rnn_type) 76 | 77 | self.rnn = gluon.rnn.HybridSequentialRNNCell() 78 | self.rnn.add( 79 | cell(hidden_num), 80 | ) 81 | self.dropout = gluon.nn.Dropout(dropout) 82 | self.nn = gluon.nn.HybridSequential() 83 | self.nn.add( 84 | gluon.nn.Dense(ku_num, flatten=False) 85 | ) 86 | 87 | def hybrid_forward(self, F, responses, mask=None, begin_state=None, *args, **kwargs): 88 | length = self.length if self.length else len(responses[0]) 89 | 90 | if self.add_embedding_layer: 91 | input_data = self.embedding(responses) 92 | else: 93 | input_data = self.embedding(responses, F) 94 | 95 | outputs, states = self.rnn.unroll(length, input_data, begin_state=begin_state, merge_outputs=True, 96 | valid_length=mask) 97 | 98 | output = self.nn(self.dropout(outputs)) 99 | output = F.sigmoid(output) 100 | return output, states 101 | 102 | 103 | @fit_wrapper 104 | def fit(net, batch_data, loss_function, *args, **kwargs): 105 | data, data_mask, label, pick_index, label_mask = batch_data 106 | output, _ = net(data, data_mask) 107 | loss = loss_function(output, pick_index, label, label_mask) 108 | return sum(loss) 109 | 110 | 111 | def evaluation(net, test_data, ctx=mx.cpu(), *args, **kwargs): 112 | ground_truth = [] 113 | prediction = [] 114 | pred_labels = [] 115 | 116 | for batch_data in tqdm(test_data, "evaluating"): 117 | ctx_data = split_and_load( 118 | ctx, *batch_data, 119 | even_split=False 120 | ) 121 | for (data, data_mask, label, pick_index, label_mask) in ctx_data: 122 | output, _ = net(data, data_mask) 123 | output = mx.nd.slice(output, (None, None), (None, -1)) 124 | output = mx.nd.pick(output, pick_index) 125 | pred = output.asnumpy().tolist() 126 | label = label.asnumpy().tolist() 127 | for i, length in enumerate(label_mask.asnumpy().tolist()): 128 | length = int(length) 129 | ground_truth.extend(label[i][:length]) 130 | prediction.extend(pred[i][:length]) 131 | pred_labels.extend([0 if p < 0.5 else 1 for p in pred[i][:length]]) 132 | 133 | return classification_report(ground_truth, y_pred=pred_labels, y_score=prediction) 134 | 135 | 136 | def get_net(**kwargs): 137 | 138 | return DKTNet(**kwargs) 139 | 140 | 141 | class DKT(KTM): 142 | """ 143 | Examples 144 | -------- 145 | >>> import mxnet as mx 146 | >>> model = DKT(init_net=True, hyper_params={"ku_num": 3, "hidden_num": 5}) 147 | >>> model.net.initialize() 148 | >>> inputs = mx.nd.ones((2, 4)) 149 | >>> outputs, (states, *_) = model(inputs) 150 | >>> outputs.shape 151 | (2, 4, 3) 152 | >>> states.shape 153 | (2, 5) 154 | >>> outputs, 
(states, *_) = model(inputs, begin_state=[states]) 155 | >>> outputs.shape 156 | (2, 4, 3) 157 | >>> states.shape 158 | (2, 5) 159 | """ 160 | def __init__(self, init_net=True, cfg_path=None, *args, **kwargs): 161 | super(DKT, self).__init__(Configuration(params_path=cfg_path, *args, **kwargs)) 162 | if init_net: 163 | self.net = get_net(**self.cfg.hyper_params) 164 | 165 | def __call__(self, x, mask=None, begin_state=None): 166 | return super(DKT, self).__call__(x, mask, begin_state) 167 | 168 | def train(self, train_data, valid_data=None, re_init_net=False, enable_hyper_search=False, 169 | save=False, *args, **kwargs) -> ...: 170 | self.cfg.update(**kwargs) 171 | 172 | print(self.cfg) 173 | 174 | lm.train( 175 | net=self.net, 176 | cfg=self.cfg, 177 | get_net=get_net if re_init_net is True else None, 178 | fit_f=fit, 179 | eval_f=evaluation, 180 | trainer=None, 181 | loss_function=SLMLoss(**self.cfg.loss_params), 182 | train_data=train_data, 183 | test_data=valid_data, 184 | enable_hyper_search=enable_hyper_search, 185 | dump_result=save, 186 | params_save=save, 187 | primary_key="macro_auc" 188 | ) 189 | 190 | def eval(self, test_data, *args, **kwargs) -> ...: 191 | return evaluation(self.net, test_data, *args, **kwargs) 192 | 193 | @classmethod 194 | def from_pretrained(cls, model_dir, best_epoch=None, *args, **kwargs): 195 | cfg_path = path_append(model_dir, CFG_JSON) 196 | model = DKT(init_net=True, cfg_path=cfg_path, model_dir=model_dir) 197 | cfg = model.cfg 198 | model.load( 199 | get_epoch_params_filepath(cfg.model_name, best_epoch, cfg.model_dir) 200 | if best_epoch is not None else get_params_filepath(cfg.model_name, cfg.model_dir) 201 | ) 202 | return model 203 | 204 | @classmethod 205 | def benchmark_train(cls, train_path, valid_path=None, enable_hyper_search=False, 206 | save=False, *args, **kwargs): 207 | dkt = DKT(init_net=not enable_hyper_search, *args, **kwargs) 208 | train_data = etl(train_path, dkt.cfg) 209 | valid_data = etl(valid_path, dkt.cfg) if valid_path is not None else None 210 | dkt.train(train_data, valid_data, re_init_net=enable_hyper_search, enable_hyper_search=enable_hyper_search, 211 | save=save) 212 | 213 | @classmethod 214 | def benchmark_eval(cls, test_path, model_path, best_epoch, *args, **kwargs): 215 | dkt = DKT.from_pretrained(model_path, best_epoch) 216 | test_data = etl(test_path, dkt.cfg) 217 | return dkt.eval(test_data) 218 | -------------------------------------------------------------------------------- /XKT/DKT/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | 4 | from .DKT import DKT, Configuration 5 | from .etl import etl 6 | -------------------------------------------------------------------------------- /XKT/DKT/etl.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | 4 | import mxnet.ndarray as nd 5 | from tqdm import tqdm 6 | from XKT.utils import extract 7 | from baize.utils import FixedBucketSampler, PadSequence 8 | 9 | 10 | def transform(raw_data, batch_size, num_buckets=100): 11 | # 定义数据转换接口 12 | # raw_data --> batch_data 13 | 14 | responses = raw_data 15 | 16 | batch_idxes = FixedBucketSampler([len(rs) for rs in responses], batch_size, num_buckets=num_buckets) 17 | batch = [] 18 | 19 | def index(r): 20 | correct = 0 if r[1] <= 0 else 1 21 | return r[0] * 2 + correct 22 | 23 | for batch_idx in tqdm(batch_idxes, "batchify"): 24 | batch_rs = [] 25 | 
batch_pick_index = [] 26 | batch_labels = [] 27 | for idx in batch_idx: 28 | batch_rs.append([index(r) for r in responses[idx]]) 29 | if len(responses[idx]) <= 1: # pragma: no cover 30 | pick_index, labels = [], [] 31 | else: 32 | pick_index, labels = zip(*[(r[0], 0 if r[1] <= 0 else 1) for r in responses[idx][1:]]) 33 | batch_pick_index.append(list(pick_index)) 34 | batch_labels.append(list(labels)) 35 | 36 | max_len = max([len(rs) for rs in batch_rs]) 37 | padder = PadSequence(max_len, pad_val=0) 38 | batch_rs, data_mask = zip(*[(padder(rs), len(rs)) for rs in batch_rs]) 39 | 40 | max_len = max([len(rs) for rs in batch_labels]) 41 | padder = PadSequence(max_len, pad_val=0) 42 | batch_labels, label_mask = zip(*[(padder(labels), len(labels)) for labels in batch_labels]) 43 | batch_pick_index = [padder(pick_index) for pick_index in batch_pick_index] 44 | # Load 45 | batch.append( 46 | [nd.array(batch_rs), nd.array(data_mask), nd.array(batch_labels), 47 | nd.array(batch_pick_index), 48 | nd.array(label_mask)] 49 | ) 50 | 51 | return batch 52 | 53 | 54 | def etl(data_src, cfg=None, batch_size=None, **kwargs): # pragma: no cover 55 | batch_size = batch_size if batch_size is not None else cfg.batch_size 56 | raw_data = extract(data_src) 57 | return transform(raw_data, batch_size, **kwargs) 58 | -------------------------------------------------------------------------------- /XKT/DKVMN/DKVMN.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | 4 | from tqdm import tqdm 5 | import mxnet as mx 6 | from XKT.meta import KTM 7 | from XKT.utils import LMLoss 8 | from baize import get_params_filepath, get_epoch_params_filepath, path_append 9 | from baize.const import CFG_JSON 10 | from baize.mxnet import light_module as lm, Configuration as CFG, fit_wrapper, split_and_load 11 | from baize.metrics import classification_report 12 | # from .etl import etl 13 | # from .net import get_net 14 | 15 | from XKT.DKVMN.etl import etl 16 | from XKT.DKVMN.net import get_net 17 | 18 | 19 | class Configuration(CFG): 20 | init_params = {"initializer": "uniform"} 21 | optimizer_params_update = {'wd': 0} 22 | 23 | 24 | @fit_wrapper 25 | def fit(net, batch_data, loss_function, *args, **kwargs): 26 | keys, values, data_mask, label, label_mask = batch_data 27 | output, _ = net(keys, values, mask=data_mask) 28 | loss = loss_function(output, label, label_mask) 29 | return sum(loss) 30 | 31 | 32 | def evaluation(net, test_data, ctx=mx.cpu(), verbose=True, *args, **kwargs): 33 | ground_truth = [] 34 | prediction = [] 35 | pred_labels = [] 36 | 37 | for batch_data in tqdm(test_data, "evaluating", disable=not verbose): 38 | ctx_data = split_and_load( 39 | ctx, *batch_data, 40 | even_split=False 41 | ) 42 | for (keys, values, data_mask, label, label_mask) in ctx_data: 43 | output, _ = net(keys, values, mask=data_mask) 44 | pred = output.asnumpy().tolist() 45 | label = label.asnumpy().tolist() 46 | for i, length in enumerate(label_mask.asnumpy().tolist()): 47 | length = int(length) 48 | ground_truth.extend(label[i][:length]) 49 | prediction.extend(pred[i][:length]) 50 | pred_labels.extend([0 if p < 0.5 else 1 for p in pred[i][:length]]) 51 | 52 | return classification_report(ground_truth, y_pred=pred_labels, y_score=prediction) 53 | 54 | 55 | class DKVMN(KTM): 56 | """ 57 | Examples 58 | -------- 59 | >>> import mxnet as mx 60 | >>> model = DKVMN( 61 | ... init_net=True, 62 | ... hyper_params={ 63 | ... 
"ku_num": 3, "hidden_num": 5, "key_embedding_dim": 2, "value_embedding_dim": 2, "key_memory_size": 3 64 | ... } 65 | ... ) 66 | >>> model.net.initialize() 67 | >>> item_id = mx.nd.ones((2, 4)) 68 | >>> response = mx.nd.ones((2, 4)) 69 | >>> outputs, _ = model(item_id, response) 70 | >>> outputs.shape 71 | (2, 4) 72 | """ 73 | 74 | def __init__(self, init_net=True, cfg_path=None, *args, **kwargs): 75 | super(DKVMN, self).__init__(Configuration(params_path=cfg_path, *args, **kwargs)) 76 | if init_net: 77 | self.net = get_net(**self.cfg.hyper_params) 78 | 79 | def train(self, train_data, valid_data=None, re_init_net=False, enable_hyper_search=False, 80 | save=False, *args, **kwargs) -> ...: 81 | self.cfg.update(**kwargs) 82 | 83 | if not enable_hyper_search: 84 | print(self.cfg) 85 | 86 | lm.train( 87 | net=self.net, 88 | cfg=self.cfg, 89 | get_net=get_net if re_init_net is True else None, 90 | fit_f=fit, 91 | eval_f=evaluation, 92 | trainer=None, 93 | loss_function=LMLoss(**self.cfg.loss_params), 94 | train_data=train_data, 95 | test_data=valid_data, 96 | enable_hyper_search=enable_hyper_search, 97 | dump_result=save, 98 | params_save=save, 99 | primary_key="macro_auc", 100 | ) 101 | 102 | def eval(self, test_data, *args, **kwargs) -> ...: 103 | return evaluation(self.net, test_data, *args, **kwargs) 104 | 105 | @classmethod 106 | def from_pretrained(cls, model_dir, best_epoch=None, *args, **kwargs): 107 | cfg_path = path_append(model_dir, CFG_JSON) 108 | model = DKVMN(init_net=True, cfg_path=cfg_path, model_dir=model_dir) 109 | cfg = model.cfg 110 | model.load( 111 | get_epoch_params_filepath(cfg.model_name, best_epoch, cfg.model_dir) 112 | if best_epoch is not None else get_params_filepath(cfg.model_name, cfg.model_dir) 113 | ) 114 | return model 115 | 116 | @classmethod 117 | def benchmark_train(cls, train_path, valid_path=None, enable_hyper_search=False, 118 | save=False, *args, **kwargs): 119 | dkt = DKVMN(init_net=not enable_hyper_search, *args, **kwargs) 120 | train_data = etl(train_path, dkt.cfg) 121 | valid_data = etl(valid_path, dkt.cfg) if valid_path is not None else None 122 | dkt.train(train_data, valid_data, re_init_net=enable_hyper_search, enable_hyper_search=enable_hyper_search, 123 | save=save) 124 | 125 | @classmethod 126 | def benchmark_eval(cls, test_path, model_path, best_epoch, *args, **kwargs): 127 | dkt = DKVMN.from_pretrained(model_path, best_epoch) 128 | test_data = etl(test_path, dkt.cfg) 129 | return dkt.eval(test_data) 130 | -------------------------------------------------------------------------------- /XKT/DKVMN/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | 4 | from .DKVMN import DKVMN 5 | from .etl import etl 6 | -------------------------------------------------------------------------------- /XKT/DKVMN/etl.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | 4 | import mxnet.ndarray as nd 5 | from tqdm import tqdm 6 | from XKT.utils import extract 7 | from baize.utils import FixedBucketSampler, PadSequence 8 | 9 | 10 | def transform(raw_data, batch_size, num_buckets=100): 11 | # 定义数据转换接口 12 | # raw_data --> batch_data 13 | 14 | responses = raw_data 15 | 16 | batch_idxes = FixedBucketSampler([len(rs) for rs in responses], batch_size, num_buckets=num_buckets) 17 | batch = [] 18 | 19 | def response_index(r): 20 | correct = 0 if r[1] <= 0 else 1 21 | return r[0] * 2 + correct 22 
| 23 | def question_index(r): 24 | return r[0] 25 | 26 | for batch_idx in tqdm(batch_idxes, "batchify"): 27 | batch_qs = [] 28 | batch_rs = [] 29 | batch_labels = [] 30 | for idx in batch_idx: 31 | batch_qs.append([question_index(r) for r in responses[idx]]) 32 | batch_rs.append([response_index(r) for r in responses[idx]]) 33 | labels = [0 if r[1] <= 0 else 1 for r in responses[idx][:]] 34 | batch_labels.append(list(labels)) 35 | 36 | max_len = max([len(rs) for rs in batch_rs]) 37 | padder = PadSequence(max_len, pad_val=0) 38 | batch_qs, _ = zip(*[(padder(qs), len(qs)) for qs in batch_qs]) 39 | batch_rs, data_mask = zip(*[(padder(rs), len(rs)) for rs in batch_rs]) 40 | 41 | max_len = max([len(rs) for rs in batch_labels]) 42 | padder = PadSequence(max_len, pad_val=0) 43 | batch_labels, label_mask = zip(*[(padder(labels), len(labels)) for labels in batch_labels]) 44 | batch.append( 45 | [ 46 | nd.array(batch_qs, dtype="float32"), 47 | nd.array(batch_rs, dtype="float32"), 48 | nd.array(data_mask), 49 | nd.array(batch_labels), 50 | nd.array(label_mask) 51 | ] 52 | ) 53 | 54 | return batch 55 | 56 | 57 | def etl(data_src, cfg=None, batch_size=None, **kwargs): # pragma: no cover 58 | batch_size = batch_size if batch_size is not None else cfg.batch_size 59 | raw_data = extract(data_src) 60 | return transform(raw_data, batch_size, **kwargs) 61 | -------------------------------------------------------------------------------- /XKT/DKVMN/net.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | 4 | from baize.mxnet.utils import format_sequence, mask_sequence_variable_length 5 | from mxnet import gluon 6 | from mxnet import ndarray 7 | 8 | 9 | def get_net(ku_num, key_embedding_dim, value_embedding_dim, hidden_num, 10 | key_memory_size, 11 | nettype="DKVMN", dropout=0.0, **kwargs): 12 | return DKVMN( 13 | ku_num=ku_num, 14 | key_embedding_dim=key_embedding_dim, 15 | value_embedding_dim=value_embedding_dim, 16 | hidden_num=hidden_num, 17 | key_memory_size=key_memory_size, 18 | nettype=nettype, 19 | dropout=dropout, 20 | **kwargs 21 | ) 22 | 23 | 24 | class KVMNCell(gluon.HybridBlock): 25 | def __init__(self, memory_state_dim, memory_size, input_size=0, prefix=None, params=None, *args, **kwargs): 26 | super(KVMNCell, self).__init__(prefix=prefix, params=params) 27 | 28 | self._input_size = input_size 29 | self.memory_size = memory_size 30 | self.memory_state_dim = memory_state_dim 31 | 32 | def addressing(self, F, control_input, memory): 33 | """ 34 | 35 | Parameters 36 | ---------- 37 | F 38 | control_input: Shape (batch_size, control_state_dim) 39 | memory: Shape (memory_size, memory_state_dim) 40 | 41 | Returns 42 | ------- 43 | correlation_weight: Shape (batch_size, memory_size) 44 | """ 45 | similarity_score = F.FullyConnected(data=control_input, 46 | num_hidden=self.memory_size, 47 | weight=memory, 48 | no_bias=True, 49 | name="similarity_score") 50 | correlation_weight = F.SoftmaxActivation(similarity_score) # Shape: (batch_size, memory_size) 51 | return correlation_weight 52 | 53 | def reset(self): 54 | pass 55 | 56 | def hybrid_forward(self, F, control_input, memory, *args, **kwargs): 57 | return self.addressing(F, control_input, memory) 58 | 59 | 60 | class KVMNReadCell(KVMNCell): 61 | def __init__(self, memory_state_dim, memory_size, input_size=0, prefix=None, params=None): 62 | super(KVMNReadCell, self).__init__(memory_state_dim, memory_size, input_size, prefix, params) 63 | 64 | def read(self, memory, 
control_input=None, read_weight=None): 65 | return self(memory, control_input, read_weight) 66 | 67 | def hybrid_forward(self, F, memory, control_input=None, read_weight=None): 68 | """ 69 | 70 | Parameters 71 | ---------- 72 | F 73 | control_input: Shape (batch_size, control_state_dim) 74 | memory: Shape (batch_size, memory_size, memory_state_dim) 75 | read_weight: Shape (batch_size, memory_size) 76 | 77 | Returns 78 | ------- 79 | read_content: Shape (batch_size, memory_state_dim) 80 | """ 81 | if read_weight is None: 82 | read_weight = self.addressing(F, control_input=control_input, memory=memory) 83 | read_weight = F.Reshape(read_weight, shape=(-1, 1, self.memory_size)) 84 | read_content = F.Reshape(data=F.batch_dot(read_weight, memory), 85 | # Shape (batch_size, 1, memory_state_dim) 86 | shape=(-1, self.memory_state_dim)) # Shape (batch_size, memory_state_dim) 87 | return read_content 88 | 89 | 90 | class KVMNWriteCell(KVMNCell): 91 | def __init__(self, memory_state_dim, memory_size, input_size=0, 92 | erase_signal_weight_initializer=None, erase_signal_bias_initializer=None, 93 | add_signal_weight_initializer=None, add_signal_bias_initializer=None, 94 | prefix=None, params=None): 95 | super(KVMNWriteCell, self).__init__(memory_state_dim, memory_size, input_size, prefix, params) 96 | with self.name_scope(): 97 | self.erase_signal_weight = self.params.get('erase_signal_weight', shape=(memory_state_dim, input_size), 98 | init=erase_signal_weight_initializer, 99 | allow_deferred_init=True) 100 | 101 | self.erase_signal_bias = self.params.get('erase_signal_bias', shape=(memory_state_dim,), 102 | init=erase_signal_bias_initializer, 103 | allow_deferred_init=True) 104 | 105 | self.add_signal_weight = self.params.get('add_signal_weight', shape=(memory_state_dim, input_size), 106 | init=add_signal_weight_initializer, 107 | allow_deferred_init=True) 108 | 109 | self.add_signal_bias = self.params.get('add_signal_bias', shape=(memory_state_dim,), 110 | init=add_signal_bias_initializer, 111 | allow_deferred_init=True) 112 | 113 | def read(self, F, memory, control_input=None, read_weight=None): 114 | if read_weight is None: 115 | read_weight = self.addressing(F, control_input=control_input, memory=memory) 116 | read_weight = F.Reshape(read_weight, shape=(-1, 1, self.memory_size)) 117 | read_content = F.Reshape(data=F.batch_dot(read_weight, memory, name=self.name + "read_content_batch_dot"), 118 | # Shape (batch_size, 1, memory_state_dim) 119 | shape=(-1, self.memory_state_dim)) # Shape (batch_size, memory_state_dim) 120 | return read_content 121 | 122 | def write(self, memory, control_input, write_weight): 123 | return self(memory, control_input, write_weight) 124 | 125 | def hybrid_forward(self, F, memory, control_input, write_weight, 126 | erase_signal_weight, erase_signal_bias, add_signal_weight, add_signal_bias, 127 | ): 128 | if write_weight is None: 129 | write_weight = self.addressing( 130 | F, control_input=control_input, memory=memory 131 | ) # Shape Shape (batch_size, memory_size) 132 | 133 | # erase_signal Shape (batch_size, memory_state_dim) 134 | erase_signal = F.FullyConnected(data=control_input, 135 | num_hidden=self.memory_state_dim, 136 | weight=erase_signal_weight, 137 | bias=erase_signal_bias) 138 | erase_signal = F.Activation(data=erase_signal, act_type='sigmoid', name=self.name + "_erase_signal") 139 | # add_signal Shape (batch_size, memory_state_dim) 140 | add_signal = F.FullyConnected(data=control_input, 141 | num_hidden=self.memory_state_dim, 142 | weight=add_signal_weight, 
143 | bias=add_signal_bias) 144 | add_signal = F.Activation(data=add_signal, act_type='tanh', name=self.name + "_add_signal") 145 | # erase_mult Shape (batch_size, memory_size, memory_state_dim) 146 | erase_mult = 1 - F.batch_dot(F.Reshape(write_weight, shape=(-1, self.memory_size, 1)), 147 | F.Reshape(erase_signal, shape=(-1, 1, self.memory_state_dim)), 148 | name=self.name + "_erase_mult") 149 | 150 | aggre_add_signal = F.batch_dot(F.Reshape(write_weight, shape=(-1, self.memory_size, 1)), 151 | F.Reshape(add_signal, shape=(-1, 1, self.memory_state_dim)), 152 | name=self.name + "_aggre_add_signal") 153 | new_memory = memory * erase_mult + aggre_add_signal 154 | return new_memory 155 | 156 | 157 | class DKVMNCell(gluon.HybridBlock): 158 | def __init__(self, key_memory_size, key_memory_state_dim, value_memory_size, value_memory_state_dim, 159 | prefix=None, params=None): 160 | super(DKVMNCell, self).__init__(prefix, params) 161 | self._modified = False 162 | self.reset() 163 | 164 | with self.name_scope(): 165 | self.key_head = KVMNReadCell( 166 | memory_size=key_memory_size, 167 | memory_state_dim=key_memory_state_dim, 168 | prefix=self.prefix + "->key_head" 169 | ) 170 | self.value_head = KVMNWriteCell( 171 | memory_size=value_memory_size, 172 | memory_state_dim=value_memory_state_dim, 173 | prefix=self.prefix + "->value_head" 174 | ) 175 | 176 | self.key_memory_size = key_memory_size 177 | self.key_memory_state_dim = key_memory_state_dim 178 | self.value_memory_size = value_memory_size 179 | self.value_memory_state_dim = value_memory_state_dim 180 | 181 | def forward(self, *args): 182 | """Unrolls the recurrent cell for one time step. 183 | 184 | Parameters 185 | ---------- 186 | inputs : sym.Variable 187 | Input symbol, 2D, of shape (batch_size * num_units). 188 | states : list of sym.Variable 189 | RNN state from previous step or the output of begin_state(). 190 | 191 | Returns 192 | ------- 193 | output : Symbol 194 | Symbol corresponding to the output from the RNN when unrolling 195 | for a single time step. 196 | states : list of Symbol 197 | The new state of this RNN after this unrolling. 198 | The type of this symbol is same as the output of `begin_state()`. 199 | This can be used as an input state to the next time step 200 | of this RNN. 201 | 202 | See Also 203 | -------- 204 | begin_state: This function can provide the states for the first time step. 205 | unroll: This function unrolls an RNN for a given number of (>=1) time steps. 206 | """ 207 | # pylint: disable= arguments-differ 208 | self._counter += 1 209 | return super(DKVMNCell, self).forward(*args) 210 | 211 | def reset(self): 212 | """Reset before re-using the cell for another graph.""" 213 | self._init_counter = -1 214 | self._counter = -1 215 | for cell in self._children.values(): 216 | cell.reset() 217 | 218 | def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs): 219 | """Initial state for this cell. 220 | 221 | Parameters 222 | ---------- 223 | func : callable, default symbol.zeros 224 | Function for creating initial state. 225 | 226 | For Symbol API, func can be `symbol.zeros`, `symbol.uniform`, 227 | `symbol.var etc`. Use `symbol.var` if you want to directly 228 | feed input as states. 229 | 230 | For NDArray API, func can be `ndarray.zeros`, `ndarray.ones`, etc. 231 | batch_size: int, default 0 232 | Only required for NDArray API. Size of the batch ('N' in layout) 233 | dimension of input. 234 | 235 | **kwargs : 236 | Additional keyword arguments passed to func. 
For example 237 | `mean`, `std`, `dtype`, etc. 238 | 239 | Returns 240 | ------- 241 | states : nested list of Symbol 242 | Starting states for the first RNN step. 243 | """ 244 | assert not self._modified, \ 245 | "After applying modifier cells (e.g. ZoneoutCell) the base " \ 246 | "cell cannot be called directly. Call the modifier cell instead." 247 | states = [] 248 | for info in self.state_info(batch_size): 249 | self._init_counter += 1 250 | if info is not None: 251 | info.update(kwargs) 252 | else: 253 | info = kwargs 254 | state = func(name='%sbegin_state_%d' % (self._prefix, self._init_counter), 255 | **info) 256 | states.append(state) 257 | return states 258 | 259 | def state_info(self, batch_size=0): 260 | return [ 261 | {'shape': (batch_size, self.key_memory_size, self.key_memory_state_dim), '__layout__': 'NC'}, 262 | {'shape': (batch_size, self.value_memory_size, self.key_memory_state_dim), '__layout__': 'NC'} 263 | ] 264 | 265 | def _alias(self): 266 | return 'dkvmn_cell' 267 | 268 | def attention(self, F, control_input, memory): 269 | correlation_weight = self.key_head.addressing(F, control_input=control_input, memory=memory) 270 | return correlation_weight # (batch_size, memory_size) 271 | 272 | def read(self, F, read_weight, memory): 273 | read_content = self.value_head.read(F, memory=memory, read_weight=read_weight) 274 | return read_content # (batch_size, memory_state_dim) 275 | 276 | def write(self, F, write_weight, control_input, memory): 277 | memory_value = self.value_head.write(control_input=control_input, 278 | memory=memory, 279 | write_weight=write_weight) 280 | return memory_value 281 | 282 | def hybrid_forward(self, F, keys, values, key_memory, value_memory): 283 | # Attention 284 | correlation_weight = self.attention(F, keys, key_memory) 285 | 286 | # Read Process 287 | read_content = self.read(F, correlation_weight, value_memory) 288 | 289 | # Write Process 290 | next_value_memory = self.write(F, correlation_weight, values, value_memory) 291 | 292 | return read_content, [key_memory, next_value_memory] 293 | 294 | def unroll(self, length, keys, values, key_memory, value_memory, layout='NTC', merge_outputs=None, 295 | valid_length=None): 296 | """Unrolls an RNN cell across time steps. 297 | 298 | Parameters 299 | ---------- 300 | length : int 301 | Number of steps to unroll. 302 | inputs : Symbol, list of Symbol, or None 303 | If `inputs` is a single Symbol (usually the output 304 | of Embedding symbol), it should have shape 305 | (batch_size, length, ...) if `layout` is 'NTC', 306 | or (length, batch_size, ...) if `layout` is 'TNC'. 307 | 308 | If `inputs` is a list of symbols (usually output of 309 | previous unroll), they should all have shape 310 | (batch_size, ...). 311 | begin_memory : nested list of Symbol, optional 312 | Input states created by `begin_state()` 313 | or output state of another cell. 314 | Created from `begin_state()` if `None`. 315 | layout : str, optional 316 | `layout` of input symbol. Only used if inputs 317 | is a single Symbol. 318 | merge_outputs : bool, optional 319 | If `False`, returns outputs as a list of Symbols. 320 | If `True`, concatenates output across time steps 321 | and returns a single symbol with shape 322 | (batch_size, length, ...) if layout is 'NTC', 323 | or (length, batch_size, ...) if layout is 'TNC'. 324 | If `None`, output whatever is faster. 325 | valid_length : Symbol, NDArray or None 326 | `valid_length` specifies the length of the sequences in the batch without padding. 
327 | This option is especially useful for building sequence-to-sequence models where 328 | the input and output sequences would potentially be padded. 329 | If `valid_length` is None, all sequences are assumed to have the same length. 330 | If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,). 331 | The ith element will be the length of the ith sequence in the batch. 332 | The last valid state will be return and the padded outputs will be masked with 0. 333 | Note that `valid_length` must be smaller or equal to `length`. 334 | 335 | Returns 336 | ------- 337 | outputs : list of Symbol or Symbol 338 | Symbol (if `merge_outputs` is True) or list of Symbols 339 | (if `merge_outputs` is False) corresponding to the output from 340 | the RNN from this unrolling. 341 | 342 | states : list of Symbol 343 | The new state of this RNN after this unrolling. 344 | The type of this symbol is same as the output of `begin_state()`. 345 | """ 346 | # pylint: disable=too-many-locals 347 | self.reset() 348 | 349 | keys, axis, F, batch_size = format_sequence(length, keys, layout, False) 350 | values, axis, F, batch_size = format_sequence(length, values, layout, False) 351 | 352 | states = F.broadcast_to(F.expand_dims(value_memory, axis=0), 353 | shape=(batch_size, self.value_memory_size, self.value_memory_state_dim)) 354 | outputs = [] 355 | all_states = [] 356 | for i in range(length): 357 | output, [_, new_states] = self(keys[i], values[i], key_memory, states) 358 | states = new_states 359 | outputs.append(output) 360 | if valid_length is not None: 361 | all_states.append(states) 362 | if valid_length is not None: 363 | states = [F.SequenceLast(F.stack(*ele_list, axis=0), 364 | sequence_length=valid_length, 365 | use_sequence_length=True, 366 | axis=0) 367 | for ele_list in zip(*all_states)] 368 | outputs = mask_sequence_variable_length(F, outputs, length, valid_length, axis, True) 369 | 370 | # all_read_value_content = F.Concat(*outputs, num_args=length, dim=0) 371 | outputs, _, _, _ = format_sequence(length, outputs, layout, merge_outputs) 372 | 373 | return outputs, states 374 | 375 | 376 | class DKVMN(gluon.HybridBlock): 377 | def __init__(self, ku_num, key_embedding_dim, value_embedding_dim, hidden_num, 378 | key_memory_size, value_memory_size=None, key_memory_state_dim=None, value_memory_state_dim=None, 379 | nettype="DKVMN", dropout=0.0, 380 | key_memory_initializer=None, value_memory_initializer=None, 381 | **kwargs): 382 | super(DKVMN, self).__init__(kwargs.get("prefix"), kwargs.get("params")) 383 | 384 | ku_num = int(ku_num) 385 | key_embedding_dim = int(key_embedding_dim) 386 | value_embedding_dim = int(value_embedding_dim) 387 | hidden_num = int(hidden_num) 388 | key_memory_size = int(key_memory_size) 389 | value_memory_size = int(value_memory_size) if value_memory_size is not None else key_memory_size 390 | 391 | self.length = None 392 | self.nettype = nettype 393 | self._mask = None 394 | 395 | key_memory_state_dim = int(key_memory_state_dim) if key_memory_state_dim else key_embedding_dim 396 | value_memory_state_dim = int(value_memory_state_dim) if value_memory_state_dim else value_embedding_dim 397 | 398 | with self.name_scope(): 399 | self.key_memory = self.params.get( 400 | 'key_memory', shape=(key_memory_size, key_memory_state_dim), 401 | init=key_memory_initializer, 402 | ) 403 | 404 | self.value_memory = self.params.get( 405 | 'value_memory', shape=(value_memory_size, value_memory_state_dim), 406 | init=value_memory_initializer, 407 | ) 408 | 409 | 
embedding_dropout = kwargs.get("embedding_dropout", 0.2) 410 | self.key_embedding = gluon.nn.Embedding(ku_num, key_embedding_dim) 411 | self.value_embedding = gluon.nn.Embedding(2 * ku_num, value_embedding_dim) 412 | self.embedding_dropout = gluon.nn.Dropout(embedding_dropout) 413 | 414 | self.dkvmn = DKVMNCell(key_memory_size, key_memory_state_dim, value_memory_size, value_memory_state_dim) 415 | self.input_nn = gluon.nn.Dense(50, flatten=False) # 50 is set by the paper authors 416 | self.input_act = gluon.nn.Activation('tanh') 417 | self.read_content_nn = gluon.nn.Dense(hidden_num, flatten=False) 418 | self.read_content_act = gluon.nn.Activation('tanh') 419 | self.dropout = gluon.nn.Dropout(dropout) 420 | self.nn = gluon.nn.HybridSequential() 421 | self.nn.add( 422 | gluon.nn.Dense(ku_num, activation="tanh", flatten=False), 423 | self.dropout, 424 | gluon.nn.Dense(1, flatten=False), 425 | ) 426 | 427 | def __call__(self, *args, mask=None): 428 | self._mask = mask 429 | result = super(DKVMN, self).__call__(*args) 430 | self._mask = None 431 | return result 432 | 433 | def hybrid_forward(self, F, questions, responses, key_memory, value_memory, *args, **kwargs): 434 | length = self.length if self.length else len(responses[0]) 435 | 436 | q_data = self.embedding_dropout(self.key_embedding(questions)) 437 | r_data = self.embedding_dropout(self.value_embedding(responses)) 438 | 439 | read_contents, states = self.dkvmn.unroll( 440 | length, q_data, r_data, key_memory, value_memory, merge_outputs=True 441 | ) 442 | 443 | input_embed_content = self.input_act(self.input_nn(q_data)) 444 | read_content_embed = self.read_content_act( 445 | self.read_content_nn( 446 | F.Concat(read_contents, input_embed_content, num_args=2, dim=2) 447 | ) 448 | ) 449 | 450 | output = self.nn(read_content_embed) 451 | output = F.sigmoid(output) 452 | output = F.squeeze(output, axis=2) 453 | return output, states 454 | -------------------------------------------------------------------------------- /XKT/GKT/MGKT.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | 4 | from tqdm import tqdm 5 | import mxnet as mx 6 | from XKT.meta import KTM 7 | from XKT.utils import SLMLoss 8 | from baize import get_params_filepath, get_epoch_params_filepath, path_append 9 | from baize.const import CFG_JSON 10 | from baize.mxnet import light_module as lm, Configuration, fit_wrapper, split_and_load 11 | from baize.metrics import classification_report 12 | from .etl import etl 13 | from .net import get_net 14 | 15 | 16 | @fit_wrapper 17 | def fit(net, batch_data, loss_function, *args, **kwargs): 18 | item_id, data, data_mask, label, next_item_id, label_mask = batch_data 19 | output, _ = net(item_id, data, data_mask) 20 | loss = loss_function(output, next_item_id, label, label_mask) 21 | return sum(loss) 22 | 23 | 24 | def evaluation(net, test_data, ctx=mx.cpu(), *args, **kwargs): 25 | ground_truth = [] 26 | prediction = [] 27 | pred_labels = [] 28 | 29 | for batch_data in tqdm(test_data, "evaluating"): 30 | ctx_data = split_and_load( 31 | ctx, *batch_data, 32 | even_split=False 33 | ) 34 | for (item_id, data, data_mask, label, next_item_id, label_mask) in ctx_data: 35 | output, _ = net(item_id, data, data_mask) 36 | output = mx.nd.slice(output, (None, None), (None, -1)) 37 | output = mx.nd.pick(output, next_item_id) 38 | pred = output.asnumpy().tolist() 39 | label = label.asnumpy().tolist() 40 | for i, length in enumerate(label_mask.asnumpy().tolist()): 
41 | length = int(length) 42 | ground_truth.extend(label[i][:length]) 43 | prediction.extend(pred[i][:length]) 44 | pred_labels.extend([0 if p < 0.5 else 1 for p in pred[i][:length]]) 45 | 46 | return classification_report(ground_truth, y_pred=pred_labels, y_score=prediction) 47 | 48 | 49 | class MGKT(KTM): 50 | def __init__(self, init_net=True, cfg_path=None, *args, **kwargs): 51 | super(MGKT, self).__init__(Configuration(params_path=cfg_path, *args, **kwargs)) 52 | if init_net: 53 | self.net = get_net(**self.cfg.hyper_params) 54 | 55 | def train(self, train_data, valid_data=None, re_init_net=False, enable_hyper_search=False, 56 | save=False, *args, **kwargs) -> ...: 57 | self.cfg.update(**kwargs) 58 | 59 | if not enable_hyper_search: 60 | print(self.cfg) 61 | 62 | lm.train( 63 | net=self.net, 64 | cfg=self.cfg, 65 | get_net=get_net if re_init_net is True else None, 66 | fit_f=fit, 67 | eval_f=evaluation, 68 | trainer=None, 69 | loss_function=SLMLoss(**self.cfg.loss_params), 70 | train_data=train_data, 71 | test_data=valid_data, 72 | enable_hyper_search=enable_hyper_search, 73 | dump_result=save, 74 | params_save=save, 75 | primary_key="macro_auc", 76 | ) 77 | 78 | def eval(self, test_data, *args, **kwargs) -> ...: 79 | return evaluation(self.net, test_data, *args, **kwargs) 80 | 81 | @classmethod 82 | def from_pretrained(cls, model_dir, best_epoch=None, *args, **kwargs): 83 | cfg_path = path_append(model_dir, CFG_JSON) 84 | model = MGKT(init_net=True, cfg_path=cfg_path, model_dir=model_dir) 85 | cfg = model.cfg 86 | model.load( 87 | get_epoch_params_filepath(cfg.model_name, best_epoch, cfg.model_dir) 88 | if best_epoch is not None else get_params_filepath(cfg.model_name, cfg.model_dir) 89 | ) 90 | return model 91 | 92 | @classmethod 93 | def benchmark_train(cls, train_path, valid_path=None, enable_hyper_search=False, 94 | save=False, *args, **kwargs): 95 | dkt = MGKT(init_net=not enable_hyper_search, *args, **kwargs) 96 | train_data = etl(train_path, dkt.cfg) 97 | valid_data = etl(valid_path, dkt.cfg) if valid_path is not None else None 98 | dkt.train(train_data, valid_data, re_init_net=enable_hyper_search, enable_hyper_search=enable_hyper_search, 99 | save=save) 100 | 101 | @classmethod 102 | def benchmark_eval(cls, test_path, model_path, best_epoch, *args, **kwargs): 103 | dkt = MGKT.from_pretrained(model_path, best_epoch) 104 | test_data = etl(test_path, dkt.cfg) 105 | return dkt.eval(test_data) 106 | -------------------------------------------------------------------------------- /XKT/GKT/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | 4 | from .MGKT import MGKT 5 | from .etl import etl 6 | -------------------------------------------------------------------------------- /XKT/GKT/etl.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | 4 | import mxnet.ndarray as nd 5 | from tqdm import tqdm 6 | from XKT.utils import extract 7 | from baize.utils import FixedBucketSampler, PadSequence 8 | 9 | 10 | def transform(raw_data, batch_size, num_buckets=100): 11 | # 定义数据转换接口 12 | # raw_data --> batch_data 13 | 14 | responses = raw_data 15 | 16 | batch_idxes = FixedBucketSampler([len(rs) for rs in responses], batch_size, num_buckets=num_buckets) 17 | batch = [] 18 | 19 | def one_hot(r): 20 | correct = 0 if r[1] <= 0 else 1 21 | return r[0] * 2 + correct 22 | 23 | for batch_idx in tqdm(batch_idxes, "batchify"): 24 | 
batch_qs = [] 25 | batch_rs = [] 26 | batch_pick_index = [] 27 | batch_labels = [] 28 | for idx in batch_idx: 29 | batch_qs.append([r[0] for r in responses[idx]]) 30 | batch_rs.append([one_hot(r) for r in responses[idx]]) 31 | if len(responses[idx]) <= 1: # pragma: no cover 32 | pick_index, labels = [], [] 33 | else: 34 | pick_index, labels = zip(*[(r[0], 0 if r[1] <= 0 else 1) for r in responses[idx][1:]]) 35 | batch_pick_index.append(list(pick_index)) 36 | batch_labels.append(list(labels)) 37 | 38 | max_len = max([len(rs) for rs in batch_rs]) 39 | padder = PadSequence(max_len, pad_val=0) 40 | batch_qs = [padder(qs) for qs in batch_qs] 41 | batch_rs, data_mask = zip(*[(padder(rs), len(rs)) for rs in batch_rs]) 42 | 43 | max_len = max([len(rs) for rs in batch_labels]) 44 | padder = PadSequence(max_len, pad_val=0) 45 | batch_labels, label_mask = zip(*[(padder(labels), len(labels)) for labels in batch_labels]) 46 | batch_pick_index = [padder(pick_index) for pick_index in batch_pick_index] 47 | batch.append( 48 | [nd.array(batch_qs), nd.array(batch_rs), nd.array(data_mask), nd.array(batch_labels), 49 | nd.array(batch_pick_index), 50 | nd.array(label_mask)]) 51 | 52 | return batch 53 | 54 | 55 | def etl(data_src, cfg=None, batch_size=None, **kwargs): # pragma: no cover 56 | batch_size = batch_size if batch_size is not None else cfg.batch_size 57 | raw_data = extract(data_src) 58 | return transform(raw_data, batch_size, **kwargs) 59 | -------------------------------------------------------------------------------- /XKT/GKT/net.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | 4 | 5 | __all__ = ["get_net"] 6 | 7 | import json 8 | import networkx as nx 9 | from baize.mxnet.utils import format_sequence, mask_sequence_variable_length 10 | from mxnet import gluon 11 | import mxnet as mx 12 | from XKT.utils.nn import GRUCell, begin_states, get_states, expand_tensor 13 | 14 | 15 | def get_net(ku_num, graph, **kwargs): 16 | return GKT(ku_num, graph, **kwargs) 17 | 18 | 19 | class GKT(gluon.Block): 20 | def __init__(self, ku_num, graph, latent_dim=None, 21 | hidden_num=None, dropout=0.0, prefix=None, params=None): 22 | super(GKT, self).__init__(prefix=prefix, params=params) 23 | self.ku_num = int(ku_num) 24 | self.hidden_num = self.ku_num if hidden_num is None else int(hidden_num) 25 | self.latent_dim = self.ku_num if latent_dim is None else int(latent_dim) 26 | self.graph = nx.DiGraph() 27 | self.graph.add_nodes_from(range(ku_num)) 28 | try: 29 | with open(graph) as f: 30 | self.graph.add_weighted_edges_from(json.load(f)) 31 | except ValueError: 32 | with open(graph) as f: 33 | self.graph.add_weighted_edges_from([e + [1.0] for e in json.load(f)]) 34 | 35 | with self.name_scope(): 36 | self.rnn = GRUCell(self.hidden_num) 37 | self.response_embedding = gluon.nn.Embedding(2 * self.ku_num, self.latent_dim) 38 | self.concept_embedding = gluon.nn.Embedding(self.ku_num, self.latent_dim) 39 | self.f_self = gluon.nn.Dense(self.hidden_num, flatten=False) 40 | self.n_out = gluon.nn.Dense(self.hidden_num, flatten=False) 41 | self.n_in = gluon.nn.Dense(self.hidden_num, flatten=False) 42 | self.dropout = gluon.nn.Dropout(dropout) 43 | self.out = gluon.nn.Dense(1, flatten=False) 44 | 45 | def in_weight(self, x, ordinal=True, with_weight=True): 46 | if isinstance(x, mx.nd.NDArray): 47 | x = x.asnumpy().tolist() 48 | if isinstance(x, list): 49 | return [self.in_weight(_x) for _x in x] 50 | elif isinstance(x, (int, float)): 51 | if 
not ordinal: 52 | return list(self.graph.predecessors(int(x))) 53 | else: 54 | _ret = [0] * self.ku_num 55 | for i in self.graph.predecessors(int(x)): 56 | if with_weight: 57 | _ret[i] = self.graph[i][x]['weight'] 58 | else: 59 | _ret[i] = 1 60 | return _ret 61 | else: 62 | raise TypeError("cannot handle %s" % type(x)) 63 | 64 | def out_weight(self, x, ordinal=True, with_weight=True): 65 | if isinstance(x, mx.nd.NDArray): 66 | x = x.asnumpy().tolist() 67 | if isinstance(x, list): 68 | return [self.out_weight(_x) for _x in x] 69 | elif isinstance(x, (int, float)): 70 | if not ordinal: 71 | return list(self.graph.successors(int(x))) 72 | else: 73 | _ret = [0] * self.ku_num 74 | for i in self.graph.successors(int(x)): 75 | if with_weight: 76 | _ret[i] = self.graph[x][i]['weight'] 77 | else: 78 | _ret[i] = 1 79 | return _ret 80 | else: 81 | raise TypeError("cannot handle %s" % type(x)) 82 | 83 | def neighbors(self, x, ordinal=True, with_weight=False): 84 | if isinstance(x, mx.nd.NDArray): 85 | x = x.asnumpy().tolist() 86 | if isinstance(x, list): 87 | return [self.neighbors(_x) for _x in x] 88 | elif isinstance(x, (int, float)): 89 | if not ordinal: 90 | return list(self.graph.neighbors(int(x))) 91 | else: 92 | _ret = [0] * self.ku_num 93 | for i in self.graph.neighbors(int(x)): 94 | if with_weight: 95 | _ret[i] = self.graph[x][i]['weight'] 96 | else: 97 | _ret[i] = 1 98 | return _ret 99 | else: 100 | raise TypeError("cannot handle %s" % type(x)) 101 | 102 | def forward(self, questions, answers, valid_length=None, states=None, layout='NTC', compressed_out=True, *args, 103 | **kwargs): 104 | ctx = questions.context 105 | length = questions.shape[1] 106 | 107 | inputs, axis, F, batch_size = format_sequence(length, questions, layout, False) 108 | answers, _, _, _ = format_sequence(length, answers, layout, False) 109 | states = begin_states([(batch_size, self.ku_num, self.hidden_num)], self.prefix)[0] 110 | states = states.as_in_context(ctx) 111 | outputs = [] 112 | all_states = [] 113 | for i in range(length): 114 | # neighbors - aggregate 115 | _neighbors = self.neighbors(inputs[i]) 116 | neighbors_mask = expand_tensor(mx.nd.array(_neighbors, ctx=ctx), -1, self.hidden_num) 117 | _neighbors_mask = expand_tensor(mx.nd.array(_neighbors, ctx=ctx), -1, self.hidden_num + self.latent_dim) 118 | 119 | # get concept embedding 120 | concept_embeddings = self.concept_embedding.weight.data(ctx) 121 | concept_embeddings = expand_tensor(concept_embeddings, 0, batch_size) 122 | 123 | agg_states = mx.nd.concat(concept_embeddings, states, dim=-1) 124 | 125 | # aggregate 126 | _neighbors_states = _neighbors_mask * agg_states 127 | 128 | # self - aggregate 129 | _concept_embedding = get_states(inputs[i], states) 130 | _self_hidden_states = mx.nd.concat(_concept_embedding, self.response_embedding(answers[i]), dim=-1) 131 | 132 | _self_mask = mx.nd.one_hot(inputs[i], self.ku_num) 133 | _self_mask = expand_tensor(_self_mask, -1, self.hidden_num) 134 | 135 | self_hidden_states = expand_tensor(_self_hidden_states, 1, self.ku_num) 136 | 137 | # aggregate 138 | _hidden_states = mx.nd.concat(_neighbors_states, self_hidden_states, dim=-1) 139 | 140 | _in_state = self.n_in(_hidden_states) 141 | _out_state = self.n_out(_hidden_states) 142 | in_weight = expand_tensor(mx.nd.array(self.in_weight(inputs[i]), ctx=ctx), -1, self.hidden_num) 143 | out_weight = expand_tensor(mx.nd.array(self.out_weight(inputs[i]), ctx=ctx), -1, self.hidden_num) 144 | 145 | next_neighbors_states = in_weight * _in_state + out_weight * _out_state 146 | 
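# Shape sketch for the aggregation above, with B = batch size, K = ku_num, H = hidden_num, L = latent_dim: states is (B, K, H); agg_states = concat(concept_embeddings, states) is (B, K, L + H); _hidden_states is (B, K, 2 * (L + H)); next_neighbors_states is (B, K, H) and is non-zero only at rows that share an edge with the current input concept, scaled by the corresponding edge weights.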
147 | # self - update 148 | next_self_states = self.f_self(_self_hidden_states) 149 | next_self_states = expand_tensor(next_self_states, 1, self.ku_num) 150 | next_self_states = _self_mask * next_self_states 151 | 152 | next_states = neighbors_mask * next_neighbors_states + next_self_states 153 | 154 | next_states, _ = self.rnn(next_states, [states]) 155 | next_states = (_self_mask + neighbors_mask) * next_states + (1 - _self_mask - neighbors_mask) * states 156 | 157 | states = self.dropout(next_states) 158 | output = mx.nd.sigmoid(mx.nd.squeeze(self.out(states), axis=-1)) 159 | outputs.append(output) 160 | if valid_length is not None and not compressed_out: 161 | all_states.append([states]) 162 | 163 | if valid_length is not None: 164 | if compressed_out: 165 | states = None 166 | else: 167 | states = [mx.nd.SequenceLast(mx.nd.stack(*ele_list, axis=0), 168 | sequence_length=valid_length, 169 | use_sequence_length=True, 170 | axis=0) 171 | for ele_list in zip(*all_states)] 172 | outputs = mask_sequence_variable_length(mx.nd, outputs, length, valid_length, axis, True) 173 | outputs, _, _, _ = format_sequence(length, outputs, layout, merge=True) 174 | 175 | return outputs, states 176 | -------------------------------------------------------------------------------- /XKT/SKT/MSKT.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | 4 | from tqdm import tqdm 5 | import mxnet as mx 6 | from XKT.meta import KTM 7 | from XKT.utils import SLMLoss 8 | from baize import get_params_filepath, get_epoch_params_filepath, path_append 9 | from baize.const import CFG_JSON 10 | from baize.mxnet import light_module as lm, Configuration, fit_wrapper, split_and_load 11 | from baize.metrics import classification_report 12 | from .etl import etl 13 | from .net import get_net 14 | 15 | 16 | @fit_wrapper 17 | def fit(net, batch_data, loss_function, *args, **kwargs): 18 | item_id, data, data_mask, label, next_item_id, label_mask = batch_data 19 | output, _ = net(item_id, data, data_mask) 20 | loss = loss_function(output, next_item_id, label, label_mask) 21 | return sum(loss) 22 | 23 | 24 | def evaluation(net, test_data, ctx=mx.cpu(), *args, **kwargs): 25 | ground_truth = [] 26 | prediction = [] 27 | pred_labels = [] 28 | 29 | for batch_data in tqdm(test_data, "evaluating"): 30 | ctx_data = split_and_load( 31 | ctx, *batch_data, 32 | even_split=False 33 | ) 34 | for (item_id, data, data_mask, label, next_item_id, label_mask) in ctx_data: 35 | output, _ = net(item_id, data, data_mask) 36 | output = mx.nd.slice(output, (None, None), (None, -1)) 37 | output = mx.nd.pick(output, next_item_id) 38 | pred = output.asnumpy().tolist() 39 | label = label.asnumpy().tolist() 40 | for i, length in enumerate(label_mask.asnumpy().tolist()): 41 | length = int(length) 42 | ground_truth.extend(label[i][:length]) 43 | prediction.extend(pred[i][:length]) 44 | pred_labels.extend([0 if p < 0.5 else 1 for p in pred[i][:length]]) 45 | 46 | return classification_report(ground_truth, y_pred=pred_labels, y_score=prediction) 47 | 48 | 49 | class MSKT(KTM): 50 | def __init__(self, init_net=True, cfg_path=None, *args, **kwargs): 51 | super(MSKT, self).__init__(Configuration(params_path=cfg_path, *args, **kwargs)) 52 | if init_net: 53 | self.net = get_net(**self.cfg.hyper_params) 54 | 55 | def train(self, train_data, valid_data=None, re_init_net=False, enable_hyper_search=False, 56 | save=False, *args, **kwargs) -> ...: 57 | self.cfg.update(**kwargs) 58 | 59 | if 
not enable_hyper_search: 60 | print(self.cfg) 61 | 62 | lm.train( 63 | net=self.net, 64 | cfg=self.cfg, 65 | get_net=get_net if re_init_net is True else None, 66 | fit_f=fit, 67 | eval_f=evaluation, 68 | trainer=None, 69 | loss_function=SLMLoss(**self.cfg.loss_params), 70 | train_data=train_data, 71 | test_data=valid_data, 72 | enable_hyper_search=enable_hyper_search, 73 | dump_result=save, 74 | params_save=save, 75 | primary_key="macro_auc", 76 | ) 77 | 78 | def eval(self, test_data, *args, **kwargs) -> ...: 79 | return evaluation(self.net, test_data, *args, **kwargs) 80 | 81 | @classmethod 82 | def from_pretrained(cls, model_dir, best_epoch=None, *args, **kwargs): 83 | cfg_path = path_append(model_dir, CFG_JSON) 84 | model = MSKT(init_net=True, cfg_path=cfg_path, model_dir=model_dir) 85 | cfg = model.cfg 86 | model.load( 87 | get_epoch_params_filepath(cfg.model_name, best_epoch, cfg.model_dir) 88 | if best_epoch is not None else get_params_filepath(cfg.model_name, cfg.model_dir) 89 | ) 90 | return model 91 | 92 | @classmethod 93 | def benchmark_train(cls, train_path, valid_path=None, enable_hyper_search=False, 94 | save=False, *args, **kwargs): 95 | dkt = MSKT(init_net=not enable_hyper_search, *args, **kwargs) 96 | train_data = etl(train_path, dkt.cfg) 97 | valid_data = etl(valid_path, dkt.cfg) if valid_path is not None else None 98 | dkt.train(train_data, valid_data, re_init_net=enable_hyper_search, enable_hyper_search=enable_hyper_search, 99 | save=save) 100 | 101 | @classmethod 102 | def benchmark_eval(cls, test_path, model_path, best_epoch, *args, **kwargs): 103 | dkt = MSKT.from_pretrained(model_path, best_epoch) 104 | test_data = etl(test_path, dkt.cfg) 105 | return dkt.eval(test_data) 106 | -------------------------------------------------------------------------------- /XKT/SKT/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | 4 | from .MSKT import MSKT 5 | from .etl import etl 6 | -------------------------------------------------------------------------------- /XKT/SKT/etl.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | 4 | import mxnet.ndarray as nd 5 | from tqdm import tqdm 6 | from XKT.utils import extract 7 | from baize.utils import FixedBucketSampler, PadSequence 8 | 9 | 10 | def transform(raw_data, batch_size, num_buckets=100): 11 | responses = raw_data 12 | 13 | batch_idxes = FixedBucketSampler([len(rs) for rs in responses], batch_size, num_buckets=num_buckets) 14 | batch = [] 15 | 16 | def one_hot(r): 17 | correct = 0 if r[1] <= 0 else 1 18 | return r[0] * 2 + correct 19 | 20 | for batch_idx in tqdm(batch_idxes, "batchify"): 21 | batch_qs = [] 22 | batch_rs = [] 23 | batch_pick_index = [] 24 | batch_labels = [] 25 | for idx in batch_idx: 26 | batch_qs.append([r[0] for r in responses[idx]]) 27 | batch_rs.append([one_hot(r) for r in responses[idx]]) 28 | if len(responses[idx]) <= 1: 29 | pick_index, labels = [], [] 30 | else: 31 | pick_index, labels = zip(*[(r[0], 0 if r[1] <= 0 else 1) for r in responses[idx][1:]]) 32 | batch_pick_index.append(list(pick_index)) 33 | batch_labels.append(list(labels)) 34 | 35 | max_len = max([len(rs) for rs in batch_rs]) 36 | padder = PadSequence(max_len, pad_val=0) 37 | batch_qs = [padder(qs) for qs in batch_qs] 38 | batch_rs, data_mask = zip(*[(padder(rs), len(rs)) for rs in batch_rs]) 39 | 40 | max_len = max([len(rs) for rs in batch_labels]) 41 | padder 
= PadSequence(max_len, pad_val=0) 42 | batch_labels, label_mask = zip(*[(padder(labels), len(labels)) for labels in batch_labels]) 43 | batch_pick_index = [padder(pick_index) for pick_index in batch_pick_index] 44 | batch.append( 45 | [nd.array(batch_qs), nd.array(batch_rs), nd.array(data_mask), nd.array(batch_labels), 46 | nd.array(batch_pick_index), 47 | nd.array(label_mask)]) 48 | 49 | return batch 50 | 51 | 52 | def etl(data_src, cfg=None, batch_size=None, **kwargs): # pragma: no cover 53 | batch_size = batch_size if batch_size is not None else cfg.batch_size 54 | raw_data = extract(data_src) 55 | return transform(raw_data, batch_size, **kwargs) 56 | -------------------------------------------------------------------------------- /XKT/SKT/net.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | from baize.mxnet.utils import format_sequence, mask_sequence_variable_length 4 | from mxnet import gluon 5 | import mxnet as mx 6 | from XKT.utils.nn import GRUCell, begin_states, get_states, expand_tensor 7 | from .utils import Graph 8 | 9 | 10 | def get_net(ku_num, graph_params=None, net_type="SKT", k=2, **kwargs): 11 | if net_type == "SKT": 12 | return SKT(ku_num, graph_params, **kwargs) 13 | elif net_type == "SKT_TE": 14 | return SKT_TE(ku_num, **kwargs) 15 | elif net_type == "SKTPart": 16 | return SKTPart(ku_num, graph_params, **kwargs) 17 | elif net_type == "SKTSync": 18 | return SKTSync(ku_num, graph_params, **kwargs) 19 | else: 20 | raise NotImplementedError 21 | 22 | 23 | class SKT(gluon.Block): 24 | def __init__(self, ku_num, graph_params=None, 25 | alpha=0.5, 26 | latent_dim=None, activation=None, 27 | hidden_num=90, concept_dim=None, 28 | # dropout=0.5, self_dropout=0.0, 29 | dropout=0.0, self_dropout=0.5, 30 | # dropout=0.0, self_dropout=0.0, 31 | sync_activation="relu", sync_dropout=0.0, 32 | prop_activation="relu", prop_dropout=0.0, 33 | agg_activation="relu", agg_dropout=0.0, 34 | prefix=None, params=None): 35 | super(SKT, self).__init__(prefix=prefix, params=params) 36 | self.ku_num = int(ku_num) 37 | self.hidden_num = self.ku_num if hidden_num is None else int(hidden_num) 38 | self.latent_dim = self.hidden_num if latent_dim is None else int(latent_dim) 39 | self.concept_dim = self.hidden_num if concept_dim is None else int(concept_dim) 40 | graph_params = graph_params if graph_params is not None else [] 41 | self.graph = Graph.from_file(ku_num, graph_params) 42 | self.alpha = alpha 43 | 44 | sync_activation = sync_activation if activation is None else activation 45 | prop_activation = prop_activation if activation is None else activation 46 | agg_activation = agg_activation if activation is None else activation 47 | 48 | with self.name_scope(): 49 | self.rnn = GRUCell(self.hidden_num) 50 | self.response_embedding = gluon.nn.Embedding(2 * self.ku_num, self.latent_dim) 51 | self.concept_embedding = gluon.nn.Embedding(self.ku_num, self.concept_dim) 52 | self.f_self = gluon.rnn.GRUCell(self.hidden_num) 53 | # self.f_self = gluon.nn.Sequential() 54 | # self.f_self.add( 55 | # gluon.nn.Dense(self.hidden_num), 56 | # gluon.nn.Activation("relu") 57 | # ) 58 | self.self_dropout = gluon.nn.Dropout(self_dropout) 59 | self.f_prop = gluon.nn.Sequential() 60 | self.f_prop.add( 61 | gluon.nn.Dense(self.hidden_num, flatten=False), 62 | gluon.nn.Activation(prop_activation), 63 | gluon.nn.Dropout(prop_dropout), 64 | ) 65 | self.f_sync = gluon.nn.Sequential() 66 | self.f_sync.add( 67 | 
gluon.nn.Dense(self.hidden_num, flatten=False), 68 | gluon.nn.Activation(sync_activation), 69 | gluon.nn.Dropout(sync_dropout), 70 | ) 71 | self.f_agg = gluon.nn.Sequential() 72 | self.f_agg.add( 73 | gluon.nn.Dense(self.hidden_num, flatten=False), 74 | # gluon.nn.InstanceNorm(), 75 | # gluon.nn.LayerNorm(), 76 | # gluon.nn.BatchNorm(), 77 | gluon.nn.Activation(agg_activation), 78 | gluon.nn.Dropout(agg_dropout), 79 | ) 80 | self.dropout = gluon.nn.Dropout(dropout) 81 | self.out = gluon.nn.Dense(1, flatten=False) 82 | 83 | def neighbors(self, x, ordinal=True): 84 | return self.graph.neighbors(x, ordinal) 85 | 86 | def successors(self, x, ordinal=True): 87 | return self.graph.successors(x, ordinal) 88 | 89 | def forward(self, questions, answers, valid_length=None, states=None, layout='NTC', compressed_out=True, 90 | *args, **kwargs): 91 | ctx = questions.context 92 | length = questions.shape[1] 93 | 94 | inputs, axis, F, batch_size = format_sequence(length, questions, layout, False) 95 | answers, _, _, _ = format_sequence(length, answers, layout, False) 96 | if states is None: 97 | states = begin_states([(batch_size, self.ku_num, self.hidden_num)], self.prefix)[0] 98 | states = states.as_in_context(ctx) 99 | outputs = [] 100 | all_states = [] 101 | for i in range(length): 102 | # self - influence 103 | _self_state = get_states(inputs[i], states) 104 | # fc 105 | # _next_self_state = self.f_self(mx.nd.concat(_self_state, self.response_embedding(answers[i]), dim=-1)) 106 | # gru 107 | _next_self_state, _ = self.f_self(self.response_embedding(answers[i]), [_self_state]) 108 | # _next_self_state = self.f_self(mx.nd.concat(_self_hidden_states, _self_state)) 109 | # _next_self_state, _ = self.f_self(_self_hidden_states, [_self_state]) 110 | _next_self_state = self.self_dropout(_next_self_state) 111 | 112 | # get self mask 113 | _self_mask = mx.nd.expand_dims(mx.nd.one_hot(inputs[i], self.ku_num), -1) 114 | _self_mask = mx.nd.broadcast_to(_self_mask, (0, 0, self.hidden_num)) 115 | 116 | # find neighbors 117 | _neighbors = self.neighbors(inputs[i]) 118 | _neighbors_mask = mx.nd.expand_dims(mx.nd.array(_neighbors, ctx=ctx), -1) 119 | _neighbors_mask = mx.nd.broadcast_to(_neighbors_mask, (0, 0, self.hidden_num)) 120 | 121 | # synchronization 122 | _broadcast_next_self_states = mx.nd.expand_dims(_next_self_state, 1) 123 | _broadcast_next_self_states = mx.nd.broadcast_to(_broadcast_next_self_states, (0, self.ku_num, 0)) 124 | # _sync_diff = mx.nd.concat(states, _broadcast_next_self_states, concept_embeddings, dim=-1) 125 | _sync_diff = mx.nd.concat(states, _broadcast_next_self_states, dim=-1) 126 | _sync_inf = _neighbors_mask * self.f_sync(_sync_diff) 127 | 128 | # reflection on current vertex 129 | _reflec_inf = mx.nd.sum(_sync_inf, axis=1) 130 | _reflec_inf = mx.nd.broadcast_to(mx.nd.expand_dims(_reflec_inf, 1), (0, self.ku_num, 0)) 131 | _sync_inf = _sync_inf + _self_mask * _reflec_inf 132 | 133 | # find successors 134 | _successors = self.successors(inputs[i]) 135 | _successors_mask = mx.nd.expand_dims(mx.nd.array(_successors, ctx=ctx), -1) 136 | _successors_mask = mx.nd.broadcast_to(_successors_mask, (0, 0, self.hidden_num)) 137 | 138 | # propagation 139 | # _prop_diff = mx.nd.concat(_next_self_state - _self_state, self.concept_embedding(inputs[i]), dim=-1) 140 | _prop_diff = _next_self_state - _self_state 141 | 142 | # 1 143 | _prop_inf = self.f_prop(_prop_diff) 144 | _prop_inf = _successors_mask * mx.nd.broadcast_to(mx.nd.expand_dims(_prop_inf, axis=1), (0, self.ku_num, 0)) 145 | # 2 146 | # 
_broadcast_diff = mx.nd.broadcast_to(mx.nd.expand_dims(_prop_diff, axis=1), (0, self.ku_num, 0)) 147 | # _pro_inf = _successors_mask * self.f_prop( 148 | # mx.nd.concat(_broadcast_diff, concept_embeddings, dim=-1) 149 | # ) 150 | # _pro_inf = _successors_mask * self.f_prop( 151 | # _broadcast_diff 152 | # ) 153 | # concept embedding 154 | concept_embeddings = self.concept_embedding.weight.data(ctx) 155 | concept_embeddings = expand_tensor(concept_embeddings, 0, batch_size) 156 | # concept_embeddings = (_self_mask + _successors_mask + _neighbors_mask) * concept_embeddings 157 | 158 | # aggregate 159 | _inf = self.f_agg(self.alpha * _sync_inf + (1 - self.alpha) * _prop_inf) 160 | # next_states, _ = self.rnn(_inf, [states]) 161 | next_states, _ = self.rnn(mx.nd.concat(_inf, concept_embeddings, dim=-1), [states]) 162 | # states = (1 - _self_mask) * next_states + _self_mask * _broadcast_next_self_states 163 | states = next_states 164 | output = mx.nd.sigmoid(mx.nd.squeeze(self.out(self.dropout(states)), axis=-1)) 165 | outputs.append(output) 166 | if valid_length is not None and not compressed_out: 167 | all_states.append([states]) 168 | 169 | if valid_length is not None: 170 | if compressed_out: 171 | states = None 172 | else: 173 | states = [mx.nd.SequenceLast(mx.nd.stack(*ele_list, axis=0), 174 | sequence_length=valid_length, 175 | use_sequence_length=True, 176 | axis=0) 177 | for ele_list in zip(*all_states)] 178 | outputs = mask_sequence_variable_length(mx.nd, outputs, length, valid_length, axis, True) 179 | outputs, _, _, _ = format_sequence(length, outputs, layout, merge=True) 180 | 181 | return outputs, states 182 | 183 | 184 | class SKTPart(SKT): 185 | def __init__(self, ku_num, graph_params=None, 186 | latent_dim=None, activation=None, 187 | hidden_num=90, concept_dim=None, 188 | dropout=0.0, self_dropout=0.0, 189 | prop_activation="relu", prop_dropout=0.0, 190 | prefix=None, params=None): 191 | super(SKT, self).__init__(prefix=prefix, params=params) 192 | self.ku_num = int(ku_num) 193 | self.hidden_num = self.ku_num if hidden_num is None else int(hidden_num) 194 | self.latent_dim = self.hidden_num if latent_dim is None else int(latent_dim) 195 | self.concept_dim = self.hidden_num if concept_dim is None else int(concept_dim) 196 | graph_params = graph_params if graph_params is not None else [] 197 | self.graph = Graph.from_file(ku_num, graph_params) 198 | 199 | prop_activation = prop_activation if activation is None else activation 200 | 201 | with self.name_scope(): 202 | self.rnn = GRUCell(self.hidden_num) 203 | self.response_embedding = gluon.nn.Embedding(2 * self.ku_num, self.latent_dim) 204 | self.concept_embedding = gluon.nn.Embedding(self.ku_num, self.concept_dim) 205 | self.f_self = gluon.rnn.GRUCell(self.hidden_num) 206 | # self.f_self = gluon.nn.Sequential() 207 | # self.f_self.add( 208 | # gluon.nn.Dense(self.hidden_num), 209 | # gluon.nn.Activation("relu") 210 | # ) 211 | self.self_dropout = gluon.nn.Dropout(self_dropout) 212 | self.f_prop = gluon.nn.Sequential() 213 | self.f_prop.add( 214 | gluon.nn.Dense(self.hidden_num, flatten=False), 215 | gluon.nn.Activation(prop_activation), 216 | gluon.nn.Dropout(prop_dropout), 217 | ) 218 | # self.f_sync = gluon.nn.Sequential() 219 | # self.f_sync.add( 220 | # gluon.nn.Dense(self.hidden_num, flatten=False), 221 | # gluon.nn.Activation(sync_activation), 222 | # gluon.nn.Dropout(sync_dropout), 223 | # ) 224 | # self.f_reflec = gluon.nn.Sequential() 225 | # self.f_reflec.add( 226 | # gluon.nn.Dense(self.hidden_num, flatten=False), 
227 | # gluon.nn.Activation(sync_activation), 228 | # gluon.nn.Dropout(sync_dropout), 229 | # ) 230 | # self.f_agg = gluon.nn.Sequential() 231 | # self.f_agg.add( 232 | # gluon.nn.Dense(self.hidden_num, flatten=False), 233 | # # gluon.nn.InstanceNorm(), 234 | # # gluon.nn.LayerNorm(), 235 | # # gluon.nn.BatchNorm(), 236 | # gluon.nn.Activation(agg_activation), 237 | # gluon.nn.Dropout(agg_dropout), 238 | # ) 239 | self.dropout = gluon.nn.Dropout(dropout) 240 | self.out = gluon.nn.Dense(1, flatten=False) 241 | 242 | def forward(self, questions, answers, valid_length=None, states=None, layout='NTC', compressed_out=True, 243 | *args, **kwargs): 244 | ctx = questions.context 245 | length = questions.shape[1] 246 | 247 | inputs, axis, F, batch_size = format_sequence(length, questions, layout, False) 248 | answers, _, _, _ = format_sequence(length, answers, layout, False) 249 | if states is None: 250 | states = begin_states([(batch_size, self.ku_num, self.hidden_num)], self.prefix)[0] 251 | states = states.as_in_context(ctx) 252 | outputs = [] 253 | all_states = [] 254 | for i in range(length): 255 | # self - influence 256 | _self_state = get_states(inputs[i], states) 257 | # fc 258 | # _next_self_state = self.f_self(mx.nd.concat(_self_state, self.response_embedding(answers[i]), dim=-1)) 259 | # gru 260 | _next_self_state, _ = self.f_self(self.response_embedding(answers[i]), [_self_state]) 261 | # _next_self_state = self.f_self(mx.nd.concat(_self_hidden_states, _self_state)) 262 | # _next_self_state, _ = self.f_self(_self_hidden_states, [_self_state]) 263 | _next_self_state = self.self_dropout(_next_self_state) 264 | 265 | # get self mask 266 | _self_mask = mx.nd.expand_dims(mx.nd.one_hot(inputs[i], self.ku_num), -1) 267 | _self_mask = mx.nd.broadcast_to(_self_mask, (0, 0, self.hidden_num)) 268 | # self-concept embedding 269 | # _self_concept_embedding = self.concept_embedding(inputs[i]) 270 | # _broadcast_self_concept_embedding = mx.nd.expand_dims(_self_concept_embedding, dim=1) 271 | # _broadcast_self_concept_embedding = mx.nd.broadcast_to(_broadcast_self_concept_embedding, 272 | # (0, self.ku_num, 0)) 273 | # concept embedding 274 | concept_embeddings = self.concept_embedding.weight.data(ctx) 275 | concept_embeddings = expand_tensor(concept_embeddings, 0, batch_size) 276 | # concept_embeddings = (_self_mask + _successors_mask + _neighbors_mask) * concept_embeddings 277 | 278 | # find successors 279 | _successors = self.successors(inputs[i]) 280 | _successors_mask = mx.nd.expand_dims(mx.nd.array(_successors, ctx=ctx), -1) 281 | _successors_mask = mx.nd.broadcast_to(_successors_mask, (0, 0, self.hidden_num)) 282 | 283 | _broadcast_next_self_states = mx.nd.expand_dims(_next_self_state, 1) 284 | _broadcast_next_self_states = mx.nd.broadcast_to(_broadcast_next_self_states, (0, self.ku_num, 0)) 285 | 286 | # propagation 287 | # _prop_diff = mx.nd.concat(_next_self_state - _self_state, self.concept_embedding(inputs[i]), dim=-1) 288 | _prop_diff = _next_self_state - _self_state 289 | 290 | # 1 291 | _prop_inf = self.f_prop( 292 | mx.nd.concat(mx.nd.broadcast_to(mx.nd.expand_dims(_prop_diff, axis=1), (0, self.ku_num, 0)), 293 | concept_embeddings, dim=-1)) 294 | _prop_inf = _successors_mask * _prop_inf 295 | 296 | # aggregate 297 | # _inf = self.f_agg(_prop_inf) 298 | _inf = _prop_inf 299 | # next_states, _ = self.rnn(_inf, [states]) 300 | next_states, _ = self.rnn(_inf, [states]) 301 | updated = _successors_mask * next_states + _self_mask * _broadcast_next_self_states 302 | states = updated + (1 - 
_successors_mask - _self_mask) * states 303 | # states = next_states 304 | output = mx.nd.sigmoid(mx.nd.squeeze(self.out(self.dropout(states)), axis=-1)) 305 | outputs.append(output) 306 | if valid_length is not None and not compressed_out: 307 | all_states.append([states]) 308 | 309 | if valid_length is not None: 310 | if compressed_out: 311 | states = None 312 | else: 313 | states = [mx.nd.SequenceLast(mx.nd.stack(*ele_list, axis=0), 314 | sequence_length=valid_length, 315 | use_sequence_length=True, 316 | axis=0) 317 | for ele_list in zip(*all_states)] 318 | outputs = mask_sequence_variable_length(mx.nd, outputs, length, valid_length, axis, True) 319 | outputs, _, _, _ = format_sequence(length, outputs, layout, merge=True) 320 | 321 | return outputs, states 322 | 323 | 324 | class SKT_TE(gluon.Block): 325 | def __init__(self, ku_num, 326 | latent_dim=None, 327 | hidden_num=90, concept_dim=None, 328 | dropout=0.0, self_dropout=0.5, 329 | prefix=None, params=None): 330 | super(SKT_TE, self).__init__(prefix=prefix, params=params) 331 | self.ku_num = int(ku_num) 332 | self.hidden_num = self.ku_num if hidden_num is None else int(hidden_num) 333 | self.latent_dim = self.hidden_num if latent_dim is None else int(latent_dim) 334 | self.concept_dim = self.hidden_num if concept_dim is None else int(concept_dim) 335 | 336 | with self.name_scope(): 337 | self.response_embedding = gluon.nn.Embedding(2 * self.ku_num, self.latent_dim) 338 | self.f_self = gluon.rnn.GRUCell(self.hidden_num) 339 | self.self_dropout = gluon.nn.Dropout(self_dropout) 340 | self.dropout = gluon.nn.Dropout(dropout) 341 | self.out = gluon.nn.Dense(1, flatten=False) 342 | 343 | def forward(self, questions, answers, valid_length=None, states=None, layout='NTC', compressed_out=True, 344 | *args, **kwargs): 345 | ctx = questions.context 346 | length = questions.shape[1] 347 | 348 | inputs, axis, F, batch_size = format_sequence(length, questions, layout, False) 349 | answers, _, _, _ = format_sequence(length, answers, layout, False) 350 | if states is None: 351 | states = begin_states([(batch_size, self.ku_num, self.hidden_num)], self.prefix)[0] 352 | states = states.as_in_context(ctx) 353 | outputs = [] 354 | all_states = [] 355 | for i in range(length): 356 | # self - influence 357 | _self_state = get_states(inputs[i], states) 358 | # fc 359 | # _next_self_state = self.f_self(mx.nd.concat(_self_state, self.response_embedding(answers[i]), dim=-1)) 360 | # gru 361 | _next_self_state, _ = self.f_self(self.response_embedding(answers[i]), [_self_state]) 362 | # _next_self_state = self.f_self(mx.nd.concat(_self_hidden_states, _self_state)) 363 | # _next_self_state, _ = self.f_self(_self_hidden_states, [_self_state]) 364 | _next_self_state = self.self_dropout(_next_self_state) 365 | 366 | # get self mask 367 | _self_mask = mx.nd.expand_dims(mx.nd.one_hot(inputs[i], self.ku_num), -1) 368 | _self_mask = mx.nd.broadcast_to(_self_mask, (0, 0, self.hidden_num)) 369 | 370 | _broadcast_next_self_states = mx.nd.expand_dims(_next_self_state, 1) 371 | _broadcast_next_self_states = mx.nd.broadcast_to(_broadcast_next_self_states, (0, self.ku_num, 0)) 372 | 373 | states = (1 - _self_mask) * states + _self_mask * _broadcast_next_self_states 374 | output = mx.nd.sigmoid(mx.nd.squeeze(self.out(self.dropout(states)), axis=-1)) 375 | outputs.append(output) 376 | if valid_length is not None and not compressed_out: 377 | all_states.append([states]) 378 | 379 | if valid_length is not None: 380 | if compressed_out: 381 | states = None 382 | else: 383 | 
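# mx.nd.SequenceLast below picks, for every sequence in the batch, the state at its true (unpadded) length along the time axis, so states produced on padded steps are discarded.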
states = [mx.nd.SequenceLast(mx.nd.stack(*ele_list, axis=0), 384 | sequence_length=valid_length, 385 | use_sequence_length=True, 386 | axis=0) 387 | for ele_list in zip(*all_states)] 388 | outputs = mask_sequence_variable_length(mx.nd, outputs, length, valid_length, axis, True) 389 | outputs, _, _, _ = format_sequence(length, outputs, layout, merge=True) 390 | 391 | return outputs, states 392 | 393 | 394 | class SKTSync(SKT): 395 | def __init__(self, ku_num, graph_params=None, 396 | alpha=0.5, 397 | latent_dim=None, activation=None, 398 | hidden_num=90, concept_dim=None, 399 | dropout=0.0, self_dropout=0.0, 400 | sync_activation="relu", sync_dropout=0.0, 401 | prop_activation="relu", prop_dropout=0.0, 402 | agg_activation="relu", agg_dropout=0.0, 403 | prefix=None, params=None): 404 | super(SKT, self).__init__(prefix=prefix, params=params) 405 | self.ku_num = int(ku_num) 406 | self.hidden_num = self.ku_num if hidden_num is None else int(hidden_num) 407 | self.latent_dim = self.hidden_num if latent_dim is None else int(latent_dim) 408 | self.concept_dim = self.hidden_num if concept_dim is None else int(concept_dim) 409 | graph_params = graph_params if graph_params is not None else [] 410 | self.graph = Graph.from_file(ku_num, graph_params) 411 | self.alpha = alpha 412 | 413 | sync_activation = sync_activation if activation is None else activation 414 | 415 | with self.name_scope(): 416 | self.rnn = GRUCell(self.hidden_num) 417 | self.response_embedding = gluon.nn.Embedding(2 * self.ku_num, self.latent_dim) 418 | self.concept_embedding = gluon.nn.Embedding(self.ku_num, self.concept_dim) 419 | self.f_self = gluon.rnn.GRUCell(self.hidden_num) 420 | # self.f_self = gluon.nn.Sequential() 421 | # self.f_self.add( 422 | # gluon.nn.Dense(self.hidden_num), 423 | # gluon.nn.Activation("relu") 424 | # ) 425 | self.self_dropout = gluon.nn.Dropout(self_dropout) 426 | self.f_sync = gluon.nn.Sequential() 427 | self.f_sync.add( 428 | gluon.nn.Dense(self.hidden_num, flatten=False), 429 | gluon.nn.Activation(sync_activation), 430 | gluon.nn.Dropout(sync_dropout), 431 | ) 432 | self.f_reflec = gluon.nn.Sequential() 433 | self.f_reflec.add( 434 | gluon.nn.Dense(self.hidden_num, flatten=False), 435 | gluon.nn.Activation(sync_activation), 436 | gluon.nn.Dropout(sync_dropout), 437 | ) 438 | self.dropout = gluon.nn.Dropout(dropout) 439 | self.out = gluon.nn.Dense(1, flatten=False) 440 | 441 | def forward(self, questions, answers, valid_length=None, states=None, layout='NTC', compressed_out=True, 442 | *args, **kwargs): 443 | ctx = questions.context 444 | length = questions.shape[1] 445 | 446 | inputs, axis, F, batch_size = format_sequence(length, questions, layout, False) 447 | answers, _, _, _ = format_sequence(length, answers, layout, False) 448 | if states is None: 449 | states = begin_states([(batch_size, self.ku_num, self.hidden_num)], self.prefix)[0] 450 | states = states.as_in_context(ctx) 451 | outputs = [] 452 | all_states = [] 453 | for i in range(length): 454 | # self - influence 455 | _self_state = get_states(inputs[i], states) 456 | # fc 457 | # _next_self_state = self.f_self(mx.nd.concat(_self_state, self.response_embedding(answers[i]), dim=-1)) 458 | # gru 459 | _next_self_state, _ = self.f_self(self.response_embedding(answers[i]), [_self_state]) 460 | # _next_self_state = self.f_self(mx.nd.concat(_self_hidden_states, _self_state)) 461 | # _next_self_state, _ = self.f_self(_self_hidden_states, [_self_state]) 462 | _next_self_state = self.self_dropout(_next_self_state) 463 | 464 | # get self mask 465 
| _self_mask = mx.nd.expand_dims(mx.nd.one_hot(inputs[i], self.ku_num), -1) 466 | _self_mask = mx.nd.broadcast_to(_self_mask, (0, 0, self.hidden_num)) 467 | # self-concept embedding 468 | _self_concept_embedding = self.concept_embedding(inputs[i]) 469 | # _broadcast_self_concept_embedding = mx.nd.expand_dims(_self_concept_embedding, dim=1) 470 | # _broadcast_self_concept_embedding = mx.nd.broadcast_to(_broadcast_self_concept_embedding, 471 | # (0, self.ku_num, 0)) 472 | # concept embedding 473 | concept_embeddings = self.concept_embedding.weight.data(ctx) 474 | concept_embeddings = expand_tensor(concept_embeddings, 0, batch_size) 475 | # concept_embeddings = (_self_mask + _successors_mask + _neighbors_mask) * concept_embeddings 476 | 477 | # find neighbors 478 | _neighbors = self.neighbors(inputs[i]) 479 | _neighbors_mask = mx.nd.expand_dims(mx.nd.array(_neighbors, ctx=ctx), -1) 480 | _neighbors_mask = mx.nd.broadcast_to(_neighbors_mask, (0, 0, self.hidden_num)) 481 | 482 | # synchronization 483 | _broadcast_next_self_states = mx.nd.expand_dims(_next_self_state, 1) 484 | _broadcast_next_self_states = mx.nd.broadcast_to(_broadcast_next_self_states, (0, self.ku_num, 0)) 485 | # _sync_diff = mx.nd.concat(states, _broadcast_next_self_states, concept_embeddings, dim=-1) 486 | _sync_diff = mx.nd.concat(states, _broadcast_next_self_states, dim=-1) 487 | _sync_inf = _neighbors_mask * self.f_sync( 488 | mx.nd.concat(_sync_diff, concept_embeddings, dim=-1) 489 | ) 490 | 491 | # reflection on current vertex 492 | _reflec_diff = mx.nd.concat(mx.nd.sum(_neighbors_mask * states, axis=1) + _next_self_state, 493 | _self_concept_embedding, dim=-1) 494 | # _reflec_diff = mx.nd.concat(mx.nd.sum(_neighbors_mask * states, axis=1), _next_self_state, 495 | # _self_concept_embedding, dim=-1) 496 | _reflec_inf = self.f_reflec(_reflec_diff) 497 | _reflec_inf = mx.nd.broadcast_to(mx.nd.expand_dims(_reflec_inf, 1), (0, self.ku_num, 0)) 498 | _sync_inf = _sync_inf + _self_mask * _reflec_inf 499 | 500 | # aggregate 501 | _inf = _sync_inf 502 | next_states, _ = self.rnn(_inf, [states]) 503 | states = (_neighbors_mask + _self_mask) * next_states + (1 - _neighbors_mask - _self_mask) * states 504 | # states = next_states 505 | output = mx.nd.sigmoid(mx.nd.squeeze(self.out(self.dropout(states)), axis=-1)) 506 | outputs.append(output) 507 | if valid_length is not None and not compressed_out: 508 | all_states.append([states]) 509 | 510 | if valid_length is not None: 511 | if compressed_out: 512 | states = None 513 | else: 514 | states = [mx.nd.SequenceLast(mx.nd.stack(*ele_list, axis=0), 515 | sequence_length=valid_length, 516 | use_sequence_length=True, 517 | axis=0) 518 | for ele_list in zip(*all_states)] 519 | outputs = mask_sequence_variable_length(mx.nd, outputs, length, valid_length, axis, True) 520 | outputs, _, _, _ = format_sequence(length, outputs, layout, merge=True) 521 | 522 | return outputs, states 523 | -------------------------------------------------------------------------------- /XKT/SKT/utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/23 @ tongshiwei 3 | 4 | import json 5 | import mxnet as mx 6 | from longling import as_list 7 | import networkx as nx 8 | 9 | 10 | class Graph(object): 11 | def __init__(self, ku_num, directed_graphs, undirected_graphs): 12 | self.ku_num = ku_num 13 | self.directed_graphs = as_list(directed_graphs) 14 | self.undirected_graphs = as_list(undirected_graphs) 15 | 16 | @staticmethod 17 | def _info(graph: 
nx.Graph): 18 | return {"edges": len(graph.edges)} 19 | 20 | @property 21 | def info(self): 22 | return { 23 | "directed": [self._info(graph) for graph in self.directed_graphs], 24 | "undirected": [self._info(graph) for graph in self.undirected_graphs] 25 | } 26 | 27 | def neighbors(self, x, ordinal=True, merge_to_one=True, with_weight=False, excluded=None): 28 | excluded = set() if excluded is None else excluded 29 | 30 | if isinstance(x, mx.nd.NDArray): 31 | x = x.asnumpy().tolist() 32 | if isinstance(x, list): 33 | return [self.neighbors(_x, ordinal, merge_to_one, with_weight, excluded) for _x in x] 34 | elif isinstance(x, (int, float)): 35 | if not ordinal: 36 | if len(self.undirected_graphs) == 0: 37 | return None if not merge_to_one else [] 38 | elif len(self.undirected_graphs) == 1: 39 | return [v for v in self.undirected_graphs[0].neighbors(int(x)) if v not in excluded] 40 | else: 41 | if not merge_to_one: 42 | return [[v for v in graph.neighbors(int(x)) if v not in excluded] for graph in 43 | self.undirected_graphs] 44 | else: 45 | _ret = [] 46 | for graph in self.undirected_graphs: 47 | _ret.extend([v for v in graph.neighbors(int(x)) if v not in excluded]) 48 | return _ret 49 | else: # ordinal 50 | if not merge_to_one: 51 | if len(self.undirected_graphs) == 0: 52 | return None 53 | elif len(self.undirected_graphs) == 1: 54 | graph = self.undirected_graphs[0] 55 | _ret = [0] * self.ku_num 56 | for i in graph.neighbors(int(x)): 57 | if i in excluded: 58 | continue 59 | if with_weight: 60 | _ret[i] = graph[x][i].get('weight', 1) 61 | else: 62 | _ret[i] = 1 63 | return _ret 64 | else: 65 | _ret = [] 66 | for graph in self.undirected_graphs: 67 | __ret = [0] * self.ku_num 68 | for i in graph.neighbors(int(x)): 69 | if i not in excluded: 70 | if with_weight: 71 | __ret[i] = graph[x][i].get('weight', 1) 72 | else: 73 | __ret[i] = 1 74 | _ret.append(__ret) 75 | return _ret 76 | else: 77 | if len(self.undirected_graphs) == 0: 78 | return [0] * self.ku_num 79 | else: 80 | _ret = [0] * self.ku_num 81 | for graph in self.undirected_graphs: 82 | for i in graph.neighbors(int(x)): 83 | if i in excluded: 84 | continue 85 | if with_weight: 86 | _ret[i] += graph[x][i].get('weight', 1) 87 | else: 88 | _ret[i] = 1 89 | return _ret 90 | else: 91 | raise TypeError("cannot handle %s" % type(x)) 92 | 93 | def successors(self, x, ordinal=True, merge_to_one=True, excluded=None): 94 | excluded = set() if excluded is None else excluded 95 | 96 | if isinstance(x, mx.nd.NDArray): 97 | x = x.asnumpy().tolist() 98 | if isinstance(x, list): 99 | return [self.successors(_x, ordinal, merge_to_one, excluded) for _x in x] 100 | elif isinstance(x, (int, float)): 101 | if not ordinal: 102 | if len(self.directed_graphs) == 0: 103 | return None if not merge_to_one else [] 104 | elif len(self.directed_graphs) == 1: 105 | return [v for v in self.directed_graphs[0].successors(int(x)) if v not in excluded] 106 | else: 107 | if not merge_to_one: 108 | return [[v for v in graph.successors(int(x)) if v not in excluded] for graph in 109 | self.directed_graphs] 110 | else: 111 | _ret = [] 112 | for graph in self.directed_graphs: 113 | _ret.extend([v for v in graph.successors(int(x)) if v not in excluded]) 114 | return _ret 115 | else: 116 | if not merge_to_one: 117 | if len(self.directed_graphs) == 0: 118 | return None 119 | elif len(self.directed_graphs) == 1: 120 | _ret = [0] * self.ku_num 121 | for i in self.directed_graphs[0].successors(int(x)): 122 | if i in excluded: 123 | continue 124 | _ret[i] = 1 125 | return _ret 126 | else: 127 | _ret = [] 128 | for graph in self.directed_graphs: 129 | __ret = [0] * self.ku_num 130 | for i in graph.successors(int(x)): 131 | if i not in excluded: 132 | __ret[i] = 1 133 | _ret.append(__ret) 134 | return _ret 135 | else: 136 | if len(self.directed_graphs) == 0: 137 | return [0] * self.ku_num 138 | else: 139 | _ret = [0] * self.ku_num 140 | for graph in self.directed_graphs: 141 | for i in graph.successors(int(x)): 142 | if i in excluded: 143 | continue 144 | _ret[i] = 1 145 | return _ret 146 | else: 147 | raise TypeError("cannot handle %s" % type(x)) 148 | 149 | @classmethod 150 | def from_file(cls, graph_nodes_num, graph_params): 151 | directed_graphs = [] 152 | undirected_graphs = [] 153 | for graph_param in graph_params: 154 | graph, directed = load_graph(graph_nodes_num, *as_list(graph_param)) 155 | if directed: 156 | directed_graphs.append(graph) 157 | else: 158 | undirected_graphs.append(graph) 159 | return cls(graph_nodes_num, directed_graphs, undirected_graphs) 160 | 161 |
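# A usage sketch of the loader below: each graph file is a JSON list of [prerequisite, successor] pairs, optionally followed by a weight, e.g. [[0, 1], [1, 2, 0.8]]; a matching graph_params entry for Graph.from_file would be [("prerequisite.json", True)], i.e. (filename, directed) -- the filename here is only illustrative.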
162 | def load_graph(graph_nodes_num, filename=None, directed: bool = True, threshold=0.0): 163 | directed = bool(directed) 164 | if directed: 165 | graph = nx.DiGraph() 166 | else: 167 | graph = nx.Graph() 168 | 169 | graph.add_nodes_from(range(graph_nodes_num)) 170 | if threshold < 0.0: 171 | for i in range(graph_nodes_num): 172 | for j in range(graph_nodes_num): 173 | graph.add_edge(i, j) 174 | else: 175 | assert filename is not None 176 | with open(filename) as f: 177 | for data in json.load(f): 178 | pre, suc = data[0], data[1] 179 | if len(data) >= 3 and float(data[2]) < threshold: 180 | continue 181 | elif len(data) >= 3: 182 | weight = float(data[2]) 183 | graph.add_edge(pre, suc, weight=weight) 184 | continue 185 | graph.add_edge(pre, suc) 186 | return graph, directed -------------------------------------------------------------------------------- /XKT/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # create by tongshiwei on 2019/7/2 3 | 4 | from .meta import KTM 5 | from .DKT import DKT 6 | from .DKVMN import DKVMN 7 | from .GKT import MGKT 8 | from .SKT import MSKT 9 | -------------------------------------------------------------------------------- /XKT/meta.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/22 @ tongshiwei 3 | from baize import path_append, get_params_filepath 4 | from baize.const import CFG_JSON 5 | from baize.mxnet import load_net, save_params, Configuration 6 | 7 | 8 | class KTM(object): 9 | def __init__(self, cfg: Configuration): 10 | self.cfg = cfg 11 | self.net = None 12 | 13 | def __call__(self, *args, **kwargs): 14 | assert self.net 15 | 16 | return self.net(*args, **kwargs) 17 | 18 | def train(self, *args, **kwargs) -> ...: 19 | raise NotImplementedError 20 | 21 | def eval(self, *args, **kwargs) -> ...: 22 | raise NotImplementedError 23 | 24 | def save(self, model_dir=None, *args, **kwargs) -> ...: 25 | model_dir = model_dir if model_dir is not None else self.cfg.model_dir 26 | select = kwargs.get("select", self.cfg.save_select) 27 | save_params(get_params_filepath(self.cfg.model_name, model_dir), self.net, select) 28 | self.cfg.dump(path_append(model_dir, CFG_JSON, to_str=True)) 29 | return model_dir 30 | 31 | def load(self, model_path, *args, **kwargs) -> ...: 32 | load_net(model_path, self.net) 33 | 34 | @classmethod 35 | def from_pretrained(cls, *args, **kwargs) -> ...: 36 | raise NotImplementedError 37 | 38 | @classmethod 39 | def benchmark_train(cls, *args, **kwargs) -> ...: 40 | raise
NotImplementedError 41 | 42 | @classmethod 43 | def benchmark_eval(cls, *args, **kwargs) -> ...: 44 | raise NotImplementedError 45 | -------------------------------------------------------------------------------- /XKT/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # create by tongshiwei on 2019-7-13 3 | 4 | from .etl import * 5 | from .loss import SequenceLogisticMaskLoss as SLMLoss, LogisticMaskLoss as LMLoss 6 | from .loss import SequenceLogisticMaskLoss, LogisticMaskLoss 7 | -------------------------------------------------------------------------------- /XKT/utils/etl.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # create by tongshiwei on 2019-7-13 3 | 4 | import json 5 | 6 | from tqdm import tqdm 7 | 8 | __all__ = ["extract", "extract_iter"] 9 | 10 | 11 | def extract(data_src, max_step=200): 12 | responses = [] 13 | step = max_step 14 | with open(data_src) as f: 15 | for line in tqdm(f, "reading data from %s" % data_src): 16 | data = json.loads(line) 17 | if step is not None: 18 | for i in range(0, len(data), step): 19 | if len(data[i: i + step]) < 2: 20 | continue 21 | responses.append(data[i: i + step]) 22 | else: 23 | responses.append(data) 24 | 25 | return responses 26 | 27 | 28 | def extract_iter(data_src): 29 | step = 200 30 | with open(data_src) as f: 31 | for line in tqdm(f, "reading data from %s" % data_src): 32 | data = json.loads(line) 33 | for i in range(0, len(data), step): 34 | if len(data[i: i + step]) < 2: 35 | continue 36 | yield data[i: i + step] 37 | -------------------------------------------------------------------------------- /XKT/utils/loss.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # create by tongshiwei on 2019-7-30 3 | 4 | __all__ = ["SequenceLogisticMaskLoss", "LogisticMaskLoss"] 5 | 6 | from mxnet import gluon 7 | 8 | 9 | class SequenceLogisticMaskLoss(gluon.HybridBlock): 10 | """ 11 | Notes 12 | ----- 13 | The loss has been averaged, so when calling the step method of the trainer, batch_size should be 1 14 | """ 15 | 16 | def __init__(self, lr=0.0, lw1=0.0, lw2=0.0, **kwargs): 17 | super(SequenceLogisticMaskLoss, self).__init__(**kwargs) 18 | self.lr = lr 19 | self.lw1 = lw1 20 | self.lw2 = lw2 21 | with self.name_scope(): 22 | self.loss = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=True) 23 | 24 | def hybrid_forward(self, F, pred_rs, pick_index, label, label_mask, *args, **kwargs): 25 | if self.lw1 > 0.0 or self.lw2 > 0.0: 26 | pre_pred_rs = F.slice_axis(pred_rs, axis=1, begin=0, end=-1) 27 | post_pred_rs = F.slice_axis(pred_rs, axis=1, begin=1, end=None) 28 | diff = post_pred_rs - pre_pred_rs 29 | _weight_mask = F.squeeze( 30 | F.SequenceMask(F.expand_dims(F.ones_like(pre_pred_rs), -1), sequence_length=label_mask, 31 | use_sequence_length=True, axis=1) 32 | ) 33 | diff = _weight_mask * diff 34 | w1 = F.mean(F.norm(diff, 1, -1)) / diff.shape[-1] 35 | w2 = F.mean(F.norm(diff, 2, -1)) / diff.shape[-1] 36 | # w2 = F.mean(F.sqrt(diff ** 2)) 37 | w1 = w1 * self.lw1 if self.lw1 > 0.0 else 0.0 38 | w2 = w2 * self.lw2 if self.lw2 > 0.0 else 0.0 39 | else: 40 | w1 = 0.0 41 | w2 = 0.0 42 | 43 | if self.lr > 0.0: 44 | re_pred_rs = F.slice_axis(pred_rs, axis=1, begin=1, end=None) 45 | re_pred_rs = F.pick(re_pred_rs, pick_index) 46 | re_weight_mask = F.squeeze( 47 | F.SequenceMask(F.expand_dims(F.ones_like(re_pred_rs), -1), sequence_length=label_mask, 48 | use_sequence_length=True, axis=1) 49 | ) 50 | wr = self.loss(re_pred_rs, label, re_weight_mask) 51 | wr = F.mean(wr) * self.lr 52 | else: 53 | wr = 0.0 54 | 55 | pred_rs = F.slice_axis(pred_rs, axis=1, begin=0, end=-1) 56 | pred_rs = F.pick(pred_rs, pick_index) 57 | weight_mask = F.squeeze( 58 | F.SequenceMask(F.expand_dims(F.ones_like(pred_rs), -1), sequence_length=label_mask, 59 | use_sequence_length=True, axis=1) 60 | ) 61 | loss = self.loss(pred_rs, label, weight_mask) 62 | # loss = F.sum(loss, axis=-1) 63 | loss = F.mean(loss) + w1 + w2 + wr 64 | return loss 65 | 66 |
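# A minimal usage sketch, with shapes read off hybrid_forward above: loss_fn = SequenceLogisticMaskLoss(lr=0.0, lw1=0.0, lw2=0.0); loss = loss_fn(pred_rs, pick_index, label, label_mask), where pred_rs is the (B, T + 1, K) sigmoid output over K concepts, pick_index and label are (B, T) holding the question ids and 0/1 correctness of steps 1..T, and label_mask holds each student's true (unpadded) label count.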
67 | class LogisticMaskLoss(gluon.HybridBlock): 68 | """ 69 | Notes 70 | ----- 71 | The loss has been averaged, so when calling the step method of the trainer, batch_size should be 1 72 | """ 73 | 74 | def __init__(self, **kwargs): 75 | super(LogisticMaskLoss, self).__init__(**kwargs) 76 | 77 | with self.name_scope(): 78 | self.loss = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=True) 79 | 80 | def hybrid_forward(self, F, pred_rs, label, label_mask, *args, **kwargs): 81 | weight_mask = F.squeeze( 82 | F.SequenceMask(F.expand_dims(F.ones_like(pred_rs), -1), sequence_length=label_mask, 83 | use_sequence_length=True, axis=1) 84 | ) 85 | loss = self.loss(pred_rs, label, weight_mask) 86 | loss = F.mean(loss) 87 | return loss 88 | -------------------------------------------------------------------------------- /XKT/utils/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/23 @ tongshiwei 3 | 4 | from .rnn import * 5 | -------------------------------------------------------------------------------- /XKT/utils/nn/rnn.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/23 @ tongshiwei 3 | import mxnet as mx 4 | from baize import as_list 5 | from mxnet import gluon 6 | 7 | __all__ = ["GRUCell", "begin_states", "get_states", "expand_tensor", "expand_states"] 8 | 9 | 10 | def begin_states(shapes, prefix, func=mx.nd.zeros): 11 | states = [] 12 | for i, shape in enumerate(as_list(shapes)): 13 | state = func(name='%sbegin_state_%d' % (prefix, i), shape=shape) 14 | states.append(state) 15 | return states 16 | 17 | 18 | def get_states(indexes, states): 19 | if isinstance(indexes, mx.nd.NDArray): 20 | indexes = indexes.asnumpy().tolist() 21 | if isinstance(indexes, list): 22 | return mx.nd.stack(*[get_states(index, state) for (index, state) in zip(indexes, states)]) 23 | elif isinstance(indexes, (int, float)): 24 | return states[int(indexes)] 25 | else: 26 | raise TypeError("cannot handle %s" % type(indexes)) 27 | 28 | 29 | def expand_states(indexes, states, expand_num): 30 | if isinstance(indexes, mx.nd.NDArray): 31 | indexes = indexes.asnumpy().tolist() 32 | if isinstance(indexes, list): 33 | return mx.nd.stack(*[expand_states(index, state, expand_num) for (index, state) in zip(indexes, states)]) 34 | elif isinstance(indexes, (int, float)): 35 | _expand_state = mx.nd.broadcast_to(mx.nd.expand_dims(states, 0), (expand_num, 0)) 36 | _mask = mx.nd.array([[0] * len(states) for _ in range(expand_num)], ctx=states.context) 37 | return _expand_state * _mask 38 | else: 39 | raise TypeError("cannot handle %s" % type(indexes)) 40 | 41 | 42 | def expand_tensor(tensor, expand_axis, expand_num, ctx=None, dtype=None) -> mx.nd.NDArray: 43 | if not isinstance(tensor, mx.nd.NDArray): 44 | tensor = mx.nd.array(tensor, ctx, dtype) 45 | assert len(tensor.shape) == 2 46 | 47 | _tensor = mx.nd.expand_dims(tensor, expand_axis) 48 | 49 | shape = [0] * 3 50 | shape[expand_axis] = expand_num 51 | 52 | _tensor = mx.nd.broadcast_to(_tensor, tuple(shape)) 53 | 54 | return _tensor 55 | 56 |
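# For example, by the broadcast logic above, expand_tensor on a (2, 3) input with expand_axis=1 and expand_num=4 returns shape (2, 4, 3), while expand_axis=0 returns (4, 2, 3); the zero entries in `shape` keep the original dimensions unchanged.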
57 | class GRUCell(gluon.nn.Block): 58 | def __init__(self, hidden_num, prefix=None, params=None): 59 | super(GRUCell, self).__init__(prefix, params) 60 | with self.name_scope(): 61 | self.i2h = gluon.nn.Dense(3 * hidden_num, flatten=False) 62 | self.h2h = gluon.nn.Dense(3 * hidden_num, flatten=False) 63 | self.reset_act = gluon.nn.Activation("sigmoid") 64 | self.update_act = gluon.nn.Activation("sigmoid") 65 | self.act = gluon.nn.Activation("tanh") 66 | 67 | def forward(self, inputs, states): 68 | prev_state_h = states[0] 69 | 70 | i2h = self.i2h(inputs) 71 | h2h = self.h2h(prev_state_h) 72 | i2h_r, i2h_z, i2h = mx.nd.SliceChannel(i2h, 3, axis=-1) 73 | h2h_r, h2h_z, h2h = mx.nd.SliceChannel(h2h, 3, axis=-1) 74 | 75 | reset_gate = self.reset_act(i2h_r + h2h_r) 76 | update_gate = self.update_act(i2h_z + h2h_z) 77 | next_h_tmp = self.act(i2h + reset_gate * h2h) 78 | ones = mx.nd.ones_like(update_gate) 79 | next_h = (ones - update_gate) * next_h_tmp + update_gate * prev_state_h 80 | 81 | return next_h, [next_h] 82 | -------------------------------------------------------------------------------- /XKT/utils/tests.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/24 @ tongshiwei 3 | 4 | __all__ = ["pseudo_data_generation"] 5 | 6 | 7 | def pseudo_data_generation(ku_num): 8 | import random 9 | random.seed(10) 10 | 11 | raw_data = [ 12 | [ 13 | (random.randint(0, ku_num - 1), random.randint(-1, 1)) 14 | for _ in range(random.randint(2, 20)) 15 | ] for _ in range(100) 16 | ] 17 | 18 | return raw_data 19 | -------------------------------------------------------------------------------- /docs/DISCUSSION.md: -------------------------------------------------------------------------------- 1 | # Discussion 2 | 3 | ## The performance of DKT is not as good as the paper reported 4 | 5 | For one thing, there may exist some approximation bias in the AUC evaluation, as posted in [DKT issue](https://github.com/chrispiech/DeepKnowledgeTracing/issues/6) 6 | 7 | To verify that, a Python version that reproduces the AUC-calculating algorithm of the original source code is: 8 | 9 | ```python 10 | import random 11 | from sklearn.metrics import roc_auc_score 12 | 13 | x = sorted([random.random() for _ in range(126000)]) 14 | y = [random.randint(0, 1) for _ in range(126000)] 15 | print(roc_auc_score(y, x)) 16 | 17 | 18 | def auc_dkt(x, y): 19 | true_positives = 0 20 | false_positives = 0 21 | 22 | total_positives = sum([1 for e in y if e == 1]) 23 | total_negatives = sum([1 for e in y if e == 0]) 24 | 25 | last_fpr = None 26 | last_tpr = None 27 | 28 | _auc = 0 29 | 30 | for i, (_x, _y) in enumerate(zip(x, y)): 31 | if _y == 1: 32 | true_positives += 1 33 | else: 34 | false_positives += 1 35 | 36 | fpr = false_positives / total_negatives 37 | tpr = false_positives / total_positives 38 | 39 | if i % 500 == 0: 40 | if last_fpr is not None: 41 | trapezoid = (tpr + last_tpr) * (fpr - last_fpr) * 0.5 42 | _auc += trapezoid 43 | last_fpr = fpr 44 | last_tpr = tpr 45 | return _auc 46 | 47 | 48 | print(auc_dkt(x, y)) 49 | ``` 50 | 51 | and get the result: 52 | 53 | ```text 54 | 0.5005808208832101 55 | 0.4938467272251457 56 | ``` 57 | 58 | That means there is potential approximation bias in the AUC evaluation.
However, it should still be noted that such bias becomes quite small as the dataset grows larger. 59 | 60 | For another thing, the framework chosen to build the neural network does matter. We found that `mxnet` performs quite badly with the simplest RNN architecture, while `pytorch`'s results fall only a little behind those reported in the paper. -------------------------------------------------------------------------------- /examples/DKT/DKT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true, 7 | "pycharm": { 8 | "name": "#%% md\n" 9 | } 10 | }, 11 | "source": [ 12 | "# Deep Knowledge Tracing\n", 13 | "\n", 14 | "This notebook will show you how to train and use the DKT.\n", 15 | "First, we will show how to get the data (here we use a0910 as the dataset).\n", 16 | "Then we will show how to train a DKT and persist its parameters.\n", 17 | "Finally, we will show how to load the parameters from the file and evaluate on the test dataset.\n", 18 | "\n", 19 | "The script version can be found in [DKT.py](DKT.py)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "source": [ 25 | "## Data Preparation\n", 26 | "\n", 27 | "Before we process the data, we need to first acquire the dataset, as shown in [prepare_dataset.ipynb](prepare_dataset.ipynb)" 28 | ], 29 | "metadata": { 30 | "collapsed": false 31 | } 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 1, 36 | "outputs": [ 37 | { 38 | "name": "stderr", 39 | "output_type": "stream", 40 | "text": [ 41 | "reading data from ../../data/a0910c/train.json: 3966it [00:00, 23302.50it/s]\n", 42 | "batchify: 100%|██████████| 130/130 [00:00<00:00, 629.18it/s]\n", 43 | "reading data from ../../data/a0910c/valid.json: 472it [00:00, 27564.52it/s]\n", 44 | "e:\\program\\baize\\baize\\extlib\\bucketing.py:327: UserWarning: Some buckets are empty and will be removed. Unused bucket keys=[55, 58, 59, 61, 65, 69, 74, 76, 77, 79, 80, 88, 90, 94, 95, 96, 99]\n", 45 | " warnings.warn('Some buckets are empty and will be removed. 
Unused bucket keys=%s' %\n", 50 | "batchify: 100%|██████████| 101/101 [00:00<00:00, 931.51it/s]\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "from XKT.DKT import etl\n", 56 | "batch_size = 32\n", 57 | "train = etl(\"../../data/a0910c/train.json\", batch_size=batch_size)\n", 58 | "valid = etl(\"../../data/a0910c/valid.json\", batch_size=batch_size)\n", 59 | "test = etl(\"../../data/a0910c/test.json\", batch_size=batch_size)" 60 | ], 61 | "metadata": { 62 | "collapsed": false, 63 | "pycharm": { 64 | "name": "#%%\n" 65 | } 66 | } 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "source": [ 71 | "## Training and Persistence" 72 | ], 73 | "metadata": { 74 | "collapsed": false, 75 | "pycharm": { 76 | "name": "#%% md\n" 77 | } 78 | } 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 2, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "logger: \n", 89 | "model_name: model\n", 90 | "model_dir: model\n", 91 | "begin_epoch: 0\n", 92 | "end_epoch: 2\n", 93 | "batch_size: 32\n", 94 | "save_epoch: 1\n", 95 | "optimizer: Adam\n", 96 | "optimizer_params: {'learning_rate': 0.001, 'wd': 0.0001, 'clip_gradient': 1}\n", 97 | "lr_params: {}\n", 98 | "train_select: None\n", 99 | "save_select: None\n", 100 | "ctx: cpu(0)\n", 101 | "train_ctx: None\n", 102 | "eval_ctx: None\n", 103 | "toolbox_params: {}\n", 104 | "hyper_params: {'ku_num': 146, 'hidden_num': 100}\n", 105 | "init_params: {}\n", 106 | "loss_params: {}\n", 107 | "caption: \n", 108 | "validation_result_file: model\\result.json\n", 109 | "cfg_path: model\\configuration.json\n", 110 | "Epoch| Total-E Batch Total-B Loss-SequenceLogisticMaskLoss Progress \n", 111 | " 0| 1 130 130 0.639498 [00:04<00:00, 30.63it/s] \n", 112 | "Epoch [0]\tLoss - SequenceLogisticMaskLoss: 0.639498\n", 113 | " precision recall f1 support\n", 114 | "0.0 0.427845 0.205795 0.277913 7765\n", 115 | "1.0 0.689022 0.864755 0.766951 15801\n", 116 | "macro_avg 0.558433 0.535275 0.522432 23566\n", 117 | "accuracy: 0.647628\tmacro_auc: 0.562898\tmacro_aupoc: 0.713164\n", 118 | "Epoch| Total-E Batch Total-B Loss-SequenceLogisticMaskLoss Progress \n", 119 | " 1| 1 130 130 0.624943 [00:04<00:00, 31.97it/s] \n", 120 | "Epoch [1]\tLoss - SequenceLogisticMaskLoss: 0.624943\n", 121 | " precision recall f1 support\n", 122 | "0.0 0.472871 0.153767 0.232070 7765\n", 123 | "1.0 0.687705 0.915765 0.785517 15801\n", 124 | "macro_avg 0.580288 0.534766 0.508793 23566\n", 125 | "accuracy: 0.664686\tmacro_auc: 0.579386\tmacro_aupoc: 0.722303\n" 126 | ] 127 | }, 128 | { 129 | "name": "stderr", 130 | "output_type": "stream", 131 | "text": [ 132 | "evaluating: 100%|██████████| 84/84 [00:00<00:00, 130.42it/s]\n", 133 | "evaluating: 100%|██████████| 84/84 [00:00<00:00, 130.78it/s]\n", 134 | "model, INFO writing configuration parameters to G:\\program\\XKT\\examples\\DKT\\dkt\\configuration.json\n" 135 | ] 136 | }, 137 | { 138 | "data": { 139 | "text/plain": "'dkt'" 140 | }, 141 | "execution_count": 2, 142 | "metadata": {}, 143 | "output_type": "execute_result" 144 | } 145 | ], 146 | "source": [ 147 | "from XKT import DKT\n", 148 | "model = DKT(hyper_params=dict(ku_num=146, hidden_num=100))\n", 149 | "model.train(train, valid, end_epoch=2)\n", 150 | "model.save(\"dkt\")" 151 | ], 152 | "metadata": { 153 | "collapsed": false, 154 | "pycharm": { 155 | "name": "#%%\n" 156 | } 157 | } 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "source": [ 162 | "## Loading and Testing" 163 | ], 164 | "metadata": { 165 | "collapsed": false 166 | } 167 | 
}, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 3, 171 | "outputs": [ 172 | { 173 | "name": "stderr", 174 | "output_type": "stream", 175 | "text": [ 176 | "evaluating: 100%|██████████| 101/101 [00:00<00:00, 113.98it/s]\n" 177 | ] 178 | }, 179 | { 180 | "name": "stdout", 181 | "output_type": "stream", 182 | "text": [ 183 | " precision recall f1 support\n", 184 | "0.0 0.484619 0.157390 0.237611 17517\n", 185 | "1.0 0.670330 0.911000 0.772351 32944\n", 186 | "macro_avg 0.577475 0.534195 0.504981 50461\n", 187 | "accuracy: 0.649393\tmacro_auc: 0.570926\tmacro_aupoc: 0.702939\n" 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | "model = DKT.from_pretrained(\"dkt\")\n", 193 | "print(model.eval(test))\n" 194 | ], 195 | "metadata": { 196 | "collapsed": false, 197 | "pycharm": { 198 | "name": "#%%\n" 199 | } 200 | } 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "source": [ 205 | "## Predict" 206 | ], 207 | "metadata": { 208 | "collapsed": false, 209 | "pycharm": { 210 | "name": "#%% md\n" 211 | } 212 | } 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "outputs": [], 218 | "source": [ 219 | "import mxnet as mx\n", 220 | "inputs = mx.nd.ones((2, 3)) # (2 students, 3 steps)\n", 221 | "outputs, _ = model(inputs)\n", 222 | "outputs.shape" 223 | ], 224 | "metadata": { 225 | "collapsed": false, 226 | "pycharm": { 227 | "name": "#%%\n" 228 | } 229 | } 230 | } 231 | ], 232 | "metadata": { 233 | "kernelspec": { 234 | "display_name": "Python 3", 235 | "language": "python", 236 | "name": "python3" 237 | }, 238 | "language_info": { 239 | "codemirror_mode": { 240 | "name": "ipython", 241 | "version": 2 242 | }, 243 | "file_extension": ".py", 244 | "mimetype": "text/x-python", 245 | "name": "python", 246 | "nbconvert_exporter": "python", 247 | "pygments_lexer": "ipython2", 248 | "version": "2.7.6" 249 | } 250 | }, 251 | "nbformat": 4, 252 | "nbformat_minor": 0 253 | } -------------------------------------------------------------------------------- /examples/DKT/DKT.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/5/26 @ tongshiwei 3 | import mxnet as mx 4 | from XKT.DKT import etl 5 | 6 | from XKT import DKT 7 | 8 | batch_size = 32 9 | train = etl("../../data/a0910c/train.json", batch_size=batch_size) 10 | valid = etl("../../data/a0910c/valid.json", batch_size=batch_size) 11 | test = etl("../../data/a0910c/test.json", batch_size=batch_size) 12 | 13 | model = DKT(hyper_params=dict(ku_num=146, hidden_num=100)) 14 | model.train(train, valid, end_epoch=2) 15 | model.save("dkt") 16 | 17 | model = DKT.from_pretrained("dkt") 18 | print(model.eval(test)) 19 | 20 | inputs = mx.nd.ones((2, 3)) 21 | outputs, _ = model(inputs) 22 | print(outputs) 23 | -------------------------------------------------------------------------------- /examples/DKT/prepare_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": { 7 | "collapsed": true, 8 | "pycharm": { 9 | "name": "#%%\n" 10 | } 11 | }, 12 | "outputs": [ 13 | { 14 | "name": "stderr", 15 | "output_type": "stream", 16 | "text": [ 17 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/a0910c/readme.txt is saved as ..\\..\\data\\a0910c\\readme.txt\n", 18 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/a0910c/test.json is saved as ..\\..\\data\\a0910c\\test.json\n", 19 | "downloader, INFO 
http://base.ustc.edu.cn/data/ktbd/a0910c/train.json is saved as ..\\..\\data\\a0910c\\train.json\n", 20 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/a0910c/valid.json is saved as ..\\..\\data\\a0910c\\valid.json\n" 21 | ] 22 | }, 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "Downloading ..\\..\\data\\a0910c\\readme.txt 100.00%: 21.0B | 21.0B\n", 28 | "Downloading ..\\..\\data\\a0910c\\test.json 100.00%: 466KB | 466KB\n", 29 | "Downloading ..\\..\\data\\a0910c\\train.json 100.00%: 1.72MB | 1.72MB\n", 30 | "Downloading ..\\..\\data\\a0910c\\valid.json 100.00%: 217KB | 217KB\n" 31 | ] 32 | }, 33 | { 34 | "data": { 35 | "text/plain": "'..\\\\..\\\\data\\\\a0910c'" 36 | }, 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "output_type": "execute_result" 40 | } 41 | ], 42 | "source": [ 43 | "from EduData import get_data\n", 44 | "\n", 45 | "get_data(\"ktbd-a0910c\", \"../../data\")" 46 | ] 47 | } 48 | ], 49 | "metadata": { 50 | "kernelspec": { 51 | "display_name": "Python 3", 52 | "language": "python", 53 | "name": "python3" 54 | }, 55 | "language_info": { 56 | "codemirror_mode": { 57 | "name": "ipython", 58 | "version": 2 59 | }, 60 | "file_extension": ".py", 61 | "mimetype": "text/x-python", 62 | "name": "python", 63 | "nbconvert_exporter": "python", 64 | "pygments_lexer": "ipython2", 65 | "version": "2.7.6" 66 | } 67 | }, 68 | "nbformat": 4, 69 | "nbformat_minor": 0 70 | } -------------------------------------------------------------------------------- /examples/DKVMN/DKVMN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true, 7 | "pycharm": { 8 | "name": "#%% md\n" 9 | } 10 | }, 11 | "source": [ 12 | "# Dynamic Key-Value Memory Network\n", 13 | "\n", 14 | "This notebook will show you how to train and use the DKVMN model.\n", 15 | "First, we will show how to get the data (here we use a0910 as the dataset).\n", 16 | "Then we will show how to train a DKVMN model and persist its parameters.\n", 17 | "Finally, we will show how to load the parameters from file and evaluate on the test dataset.\n", 18 | "\n", 19 | "The script version can be found in [DKVMN.py](DKVMN.py)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "source": [ 25 | "## Data Preparation\n", 26 | "\n", 27 | "Before processing the data, we first need to acquire the dataset, as shown in [prepare_dataset.ipynb](prepare_dataset.ipynb)" 28 | ], 29 | "metadata": { 30 | "collapsed": false 31 | } 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 1, 36 | "outputs": [ 37 | { 38 | "name": "stderr", 39 | "output_type": "stream", 40 | "text": [ 41 | "reading data from ../../data/a0910c/train.json: 3966it [00:00, 23302.50it/s]\n", 42 | "batchify: 100%|██████████| 130/130 [00:00<00:00, 629.18it/s]\n", 43 | "reading data from ../../data/a0910c/valid.json: 472it [00:00, 27564.52it/s]\n", 44 | "e:\\program\\baize\\baize\\extlib\\bucketing.py:327: UserWarning: Some buckets are empty and will be removed. Unused bucket keys=[55, 58, 59, 61, 65, 69, 74, 76, 77, 79, 80, 88, 90, 94, 95, 96, 99]\n", 45 | " warnings.warn('Some buckets are empty and will be removed. 
Unused bucket keys=%s' %\n", 46 | "batchify: 100%|██████████| 84/84 [00:00<00:00, 1241.57it/s]\n", 47 | "reading data from ../../data/a0910c/test.json: 1088it [00:00, 13857.38it/s]\n", 48 | "e:\\program\\baize\\baize\\extlib\\bucketing.py:327: UserWarning: Some buckets are empty and will be removed. Unused bucket keys=[73, 88]\n", 49 | " warnings.warn('Some buckets are empty and will be removed. Unused bucket keys=%s' %\n", 50 | "batchify: 100%|██████████| 101/101 [00:00<00:00, 931.51it/s]\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "from XKT.DKVMN import etl\n", 56 | "batch_size = 32\n", 57 | "train = etl(\"../../data/a0910c/train.json\", batch_size=batch_size)\n", 58 | "valid = etl(\"../../data/a0910c/valid.json\", batch_size=batch_size)\n", 59 | "test = etl(\"../../data/a0910c/test.json\", batch_size=batch_size)" 60 | ], 61 | "metadata": { 62 | "collapsed": false, 63 | "pycharm": { 64 | "name": "#%%\n" 65 | } 66 | } 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "source": [ 71 | "## Training and Persistence" 72 | ], 73 | "metadata": { 74 | "collapsed": false, 75 | "pycharm": { 76 | "name": "#%% md\n" 77 | } 78 | } 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 2, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "logger: \n", 89 | "model_name: model\n", 90 | "model_dir: model\n", 91 | "begin_epoch: 0\n", 92 | "end_epoch: 2\n", 93 | "batch_size: 32\n", 94 | "save_epoch: 1\n", 95 | "optimizer: Adam\n", 96 | "optimizer_params: {'learning_rate': 0.001, 'wd': 0.0001, 'clip_gradient': 1}\n", 97 | "lr_params: {}\n", 98 | "train_select: None\n", 99 | "save_select: None\n", 100 | "ctx: cpu(0)\n", 101 | "train_ctx: None\n", 102 | "eval_ctx: None\n", 103 | "toolbox_params: {}\n", 104 | "hyper_params: {'ku_num': 146, 'hidden_num': 100}\n", 105 | "init_params: {}\n", 106 | "loss_params: {}\n", 107 | "caption: \n", 108 | "validation_result_file: model\\result.json\n", 109 | "cfg_path: model\\configuration.json\n", 110 | "Epoch| Total-E Batch Total-B Loss-SequenceLogisticMaskLoss Progress \n", 111 | " 0| 1 130 130 0.639498 [00:04<00:00, 30.63it/s] \n", 112 | "Epoch [0]\tLoss - SequenceLogisticMaskLoss: 0.639498\n", 113 | " precision recall f1 support\n", 114 | "0.0 0.427845 0.205795 0.277913 7765\n", 115 | "1.0 0.689022 0.864755 0.766951 15801\n", 116 | "macro_avg 0.558433 0.535275 0.522432 23566\n", 117 | "accuracy: 0.647628\tmacro_auc: 0.562898\tmacro_aupoc: 0.713164\n", 118 | "Epoch| Total-E Batch Total-B Loss-SequenceLogisticMaskLoss Progress \n", 119 | " 1| 1 130 130 0.624943 [00:04<00:00, 31.97it/s] \n", 120 | "Epoch [1]\tLoss - SequenceLogisticMaskLoss: 0.624943\n", 121 | " precision recall f1 support\n", 122 | "0.0 0.472871 0.153767 0.232070 7765\n", 123 | "1.0 0.687705 0.915765 0.785517 15801\n", 124 | "macro_avg 0.580288 0.534766 0.508793 23566\n", 125 | "accuracy: 0.664686\tmacro_auc: 0.579386\tmacro_aupoc: 0.722303\n" 126 | ] 127 | }, 128 | { 129 | "name": "stderr", 130 | "output_type": "stream", 131 | "text": [ 132 | "evaluating: 100%|██████████| 84/84 [00:00<00:00, 130.42it/s]\n", 133 | "evaluating: 100%|██████████| 84/84 [00:00<00:00, 130.78it/s]\n", 134 | "model, INFO writing configuration parameters to G:\\program\\XKT\\examples\\DKT\\dkt\\configuration.json\n" 135 | ] 136 | }, 137 | { 138 | "data": { 139 | "text/plain": "'dkt'" 140 | }, 141 | "execution_count": 2, 142 | "metadata": {}, 143 | "output_type": "execute_result" 144 | } 145 | ], 146 | "source": [ 147 | "from XKT import DKVMN\n", 148 | "model = 
DKVMN(\n", 149 | " hyper_params=dict(\n", 150 | " ku_num=146,\n", 151 | " key_embedding_dim=10,\n", 152 | " value_embedding_dim=10,\n", 153 | " key_memory_size=20,\n", 154 | " hidden_num=100\n", 155 | " )\n", 156 | ")\n", 157 | "model.train(train, valid, end_epoch=2)\n", 158 | "model.save(\"dkvmn\")" 159 | ], 160 | "metadata": { 161 | "collapsed": false, 162 | "pycharm": { 163 | "name": "#%%\n" 164 | } 165 | } 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "source": [ 170 | "## Loading and Testing" 171 | ], 172 | "metadata": { 173 | "collapsed": false 174 | } 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 3, 179 | "outputs": [ 180 | { 181 | "name": "stderr", 182 | "output_type": "stream", 183 | "text": [ 184 | "evaluating: 100%|██████████| 101/101 [00:00<00:00, 113.98it/s]\n" 185 | ] 186 | }, 187 | { 188 | "name": "stdout", 189 | "output_type": "stream", 190 | "text": [ 191 | " precision recall f1 support\n", 192 | "0.0 0.484619 0.157390 0.237611 17517\n", 193 | "1.0 0.670330 0.911000 0.772351 32944\n", 194 | "macro_avg 0.577475 0.534195 0.504981 50461\n", 195 | "accuracy: 0.649393\tmacro_auc: 0.570926\tmacro_aupoc: 0.702939\n" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "model = DKVMN.from_pretrained(\"dkvmn\")\n", 201 | "print(model.eval(test))" 202 | ], 203 | "metadata": { 204 | "collapsed": false, 205 | "pycharm": { 206 | "name": "#%%\n" 207 | } 208 | } 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "source": [ 213 | "## Predict" 214 | ], 215 | "metadata": { 216 | "collapsed": false 217 | } 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "outputs": [], 223 | "source": [ 224 | "import mxnet as mx\n", 225 | "inputs = mx.nd.ones((2, 3)) # (2 students, 3 steps)\n", 226 | "outputs, _ = model(inputs)\n", 227 | "outputs.shape" 228 | ], 229 | "metadata": { 230 | "collapsed": false, 231 | "pycharm": { 232 | "name": "#%%\n" 233 | } 234 | } 235 | } 236 | ], 237 | "metadata": { 238 | "kernelspec": { 239 | "display_name": "Python 3", 240 | "language": "python", 241 | "name": "python3" 242 | }, 243 | "language_info": { 244 | "codemirror_mode": { 245 | "name": "ipython", 246 | "version": 2 247 | }, 248 | "file_extension": ".py", 249 | "mimetype": "text/x-python", 250 | "name": "python", 251 | "nbconvert_exporter": "python", 252 | "pygments_lexer": "ipython2", 253 | "version": "2.7.6" 254 | } 255 | }, 256 | "nbformat": 4, 257 | "nbformat_minor": 0 258 | } -------------------------------------------------------------------------------- /examples/DKVMN/DKVMN.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/5/26 @ tongshiwei 3 | import mxnet as mx 4 | from XKT.DKVMN import etl 5 | 6 | from XKT import DKVMN 7 | 8 | batch_size = 32 9 | train = etl("../../data/a0910c/train.json", batch_size=batch_size) 10 | valid = etl("../../data/a0910c/valid.json", batch_size=batch_size) 11 | test = etl("../../data/a0910c/test.json", batch_size=batch_size) 12 | 13 | model = DKVMN( 14 | hyper_params=dict( 15 | ku_num=146, 16 | key_embedding_dim=10, 17 | value_embedding_dim=10, 18 | key_memory_size=20, 19 | hidden_num=100 20 | ) 21 | ) 22 | model.train(train, valid, end_epoch=2) 23 | model.save("dkvmn") 24 | 25 | model = DKVMN.from_pretrained("dkvmn") 26 | print(model.eval(test)) 27 | 28 | 29 | inputs = mx.nd.ones((2, 3)) 30 | outputs, _ = model(inputs) 31 | print(outputs) -------------------------------------------------------------------------------- 
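A quick aside on the DKVMN hyper-parameters used above: `key_memory_size` is the number of memory slots, the key embedding is matched against a static key memory to produce attention weights, and those weights are used to read a student's state out of the dynamic value memory. The following is a minimal `mxnet` sketch of one such read step; all names and shapes are illustrative only and are not taken from `XKT.DKVMN.net`:

```python
import mxnet as mx

key_memory_size, key_dim, value_dim = 20, 10, 10  # mirrors the hyper_params above

key_memory = mx.nd.random.normal(shape=(key_memory_size, key_dim))      # static keys
value_memory = mx.nd.random.normal(shape=(key_memory_size, value_dim))  # dynamic values

k_t = mx.nd.random.normal(shape=(key_dim,))      # embedded question at step t
w_t = mx.nd.softmax(mx.nd.dot(key_memory, k_t))  # correlation weights over the slots
r_t = mx.nd.dot(w_t, value_memory)               # read vector fed to the prediction layer

print(w_t.shape, r_t.shape)  # (20,), (10,)
```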
/examples/DKVMN/prepare_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": { 7 | "collapsed": true, 8 | "pycharm": { 9 | "name": "#%%\n" 10 | } 11 | }, 12 | "outputs": [ 13 | { 14 | "name": "stderr", 15 | "output_type": "stream", 16 | "text": [ 17 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/a0910c/readme.txt is saved as ..\\..\\data\\a0910c\\readme.txt\n", 18 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/a0910c/test.json is saved as ..\\..\\data\\a0910c\\test.json\n", 19 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/a0910c/train.json is saved as ..\\..\\data\\a0910c\\train.json\n", 20 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/a0910c/valid.json is saved as ..\\..\\data\\a0910c\\valid.json\n" 21 | ] 22 | }, 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "Downloading ..\\..\\data\\a0910c\\readme.txt 100.00%: 21.0B | 21.0B\n", 28 | "Downloading ..\\..\\data\\a0910c\\test.json 100.00%: 466KB | 466KB\n", 29 | "Downloading ..\\..\\data\\a0910c\\train.json 100.00%: 1.72MB | 1.72MB\n", 30 | "Downloading ..\\..\\data\\a0910c\\valid.json 100.00%: 217KB | 217KB\n" 31 | ] 32 | }, 33 | { 34 | "data": { 35 | "text/plain": "'..\\\\..\\\\data\\\\a0910c'" 36 | }, 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "output_type": "execute_result" 40 | } 41 | ], 42 | "source": [ 43 | "from EduData import get_data\n", 44 | "\n", 45 | "get_data(\"ktbd-a0910c\", \"../../data\")" 46 | ] 47 | } 48 | ], 49 | "metadata": { 50 | "kernelspec": { 51 | "display_name": "Python 3", 52 | "language": "python", 53 | "name": "python3" 54 | }, 55 | "language_info": { 56 | "codemirror_mode": { 57 | "name": "ipython", 58 | "version": 2 59 | }, 60 | "file_extension": ".py", 61 | "mimetype": "text/x-python", 62 | "name": "python", 63 | "nbconvert_exporter": "python", 64 | "pygments_lexer": "ipython2", 65 | "version": "2.7.6" 66 | } 67 | }, 68 | "nbformat": 4, 69 | "nbformat_minor": 0 70 | } -------------------------------------------------------------------------------- /examples/GKT/MGKT.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/5/26 @ tongshiwei 3 | from XKT.GKT import etl 4 | 5 | from XKT import MGKT 6 | 7 | batch_size = 16 8 | train = etl("../../data/assistment_2009_2010/train.json", batch_size=batch_size) 9 | valid = etl("../../data/assistment_2009_2010/test.json", batch_size=batch_size) 10 | test = etl("../../data/assistment_2009_2010/test.json", batch_size=batch_size) 11 | 12 | model = MGKT( 13 | hyper_params=dict( 14 | ku_num=124, 15 | graph="../../data/assistment_2009_2010/transition_graph.json", 16 | hidden_num=5, 17 | ) 18 | ) 19 | model.train(train, valid, end_epoch=2) 20 | model.save("mgkt") 21 | 22 | model = MGKT.from_pretrained("mgkt") 23 | print(model.eval(test)) 24 | -------------------------------------------------------------------------------- /examples/GKT/prepare_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true, 8 | "pycharm": { 9 | "name": "#%%\n" 10 | } 11 | }, 12 | "outputs": [ 13 | { 14 | "name": "stderr", 15 | "output_type": "stream", 16 | "text": [ 17 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/correct_transition_graph.json is 
saved as ..\\..\\data\\assistment_2009_2010\\correct_transition_graph.json\n", 18 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/ctrans_sim.json is saved as ..\\..\\data\\assistment_2009_2010\\ctrans_sim.json\n", 19 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/dense_graph.json is saved as ..\\..\\data\\assistment_2009_2010\\dense_graph.json\n", 20 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/test.json is saved as ..\\..\\data\\assistment_2009_2010\\test.json\n", 21 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/train.json is saved as ..\\..\\data\\assistment_2009_2010\\train.json\n", 22 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/trans_sim.json is saved as ..\\..\\data\\assistment_2009_2010\\trans_sim.json\n", 23 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/transition_graph.json is saved as ..\\..\\data\\assistment_2009_2010\\transition_graph.json\n" 24 | ] 25 | }, 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "Downloading ..\\..\\data\\assistment_2009_2010\\correct_transition_graph.json 100.00%: 34.1KB | 34.1KB\n", 31 | "Downloading ..\\..\\data\\assistment_2009_2010\\ctrans_sim.json 100.00%: 200KB | 200KB\n", 32 | "Downloading ..\\..\\data\\assistment_2009_2010\\dense_graph.json 100.00%: 361KB | 361KB\n", 33 | "Downloading ..\\..\\data\\assistment_2009_2010\\test.json 100.00%: 1.02MB | 1.02MB\n", 34 | "Downloading ..\\..\\data\\assistment_2009_2010\\train.json 100.00%: 3.53MB | 3.53MB\n", 35 | "Downloading ..\\..\\data\\assistment_2009_2010\\trans_sim.json 100.00%: 350KB | 350KB\n", 36 | "Downloading ..\\..\\data\\assistment_2009_2010\\transition_graph.json 100.00%: 52.9KB | 52.9KB\n" 37 | ] 38 | }, 39 | { 40 | "data": { 41 | "text/plain": "'..\\\\..\\\\data\\\\assistment_2009_2010'" 42 | }, 43 | "execution_count": 1, 44 | "metadata": {}, 45 | "output_type": "execute_result" 46 | } 47 | ], 48 | "source": [ 49 | "from EduData import get_data\n", 50 | "\n", 51 | "get_data(\"ktbd-a0910\", \"../../data\")" 52 | ] 53 | } 54 | ], 55 | "metadata": { 56 | "kernelspec": { 57 | "display_name": "Python 3", 58 | "language": "python", 59 | "name": "python3" 60 | }, 61 | "language_info": { 62 | "codemirror_mode": { 63 | "name": "ipython", 64 | "version": 2 65 | }, 66 | "file_extension": ".py", 67 | "mimetype": "text/x-python", 68 | "name": "python", 69 | "nbconvert_exporter": "python", 70 | "pygments_lexer": "ipython2", 71 | "version": "2.7.6" 72 | } 73 | }, 74 | "nbformat": 4, 75 | "nbformat_minor": 0 76 | } -------------------------------------------------------------------------------- /examples/SKT/MSKT.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/5/26 @ tongshiwei 3 | from XKT.SKT import etl 4 | 5 | from XKT import MSKT 6 | 7 | batch_size = 16 8 | train = etl("../../data/assistment_2009_2010/train.json", batch_size=batch_size) 9 | valid = etl("../../data/assistment_2009_2010/test.json", batch_size=batch_size) 10 | test = etl("../../data/assistment_2009_2010/test.json", batch_size=batch_size) 11 | 12 | model = MSKT( 13 | hyper_params=dict( 14 | ku_num=124, 15 | graph_params=[ 16 | ['../../data/assistment_2009_2010/correct_transition_graph.json', True], 17 | ['../../data/assistment_2009_2010/ctrans_sim.json', False] 18 | ], 19 | hidden_num=5, 20 | ) 21 | ) 22 | model.train(train, valid, end_epoch=2) 23 | 
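# NOTE: each graph_params entry above is a [graph_file, flag] pair; the test
# fixtures in tests/skt/conftest.py additionally pass a third element, e.g.
# [path, True, 0.5], so the flag presumably marks the graph as directed and the
# optional number an edge weight -- treat this reading as an assumption and
# check XKT/SKT/utils.py for the exact semantics.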
model.save("mskt") 24 | 25 | model = MSKT.from_pretrained("mskt") 26 | print(model.eval(test)) 27 | -------------------------------------------------------------------------------- /examples/SKT/prepare_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true, 8 | "pycharm": { 9 | "name": "#%%\n" 10 | } 11 | }, 12 | "outputs": [ 13 | { 14 | "name": "stderr", 15 | "output_type": "stream", 16 | "text": [ 17 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/correct_transition_graph.json is saved as ..\\..\\data\\assistment_2009_2010\\correct_transition_graph.json\n", 18 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/ctrans_sim.json is saved as ..\\..\\data\\assistment_2009_2010\\ctrans_sim.json\n", 19 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/dense_graph.json is saved as ..\\..\\data\\assistment_2009_2010\\dense_graph.json\n", 20 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/test.json is saved as ..\\..\\data\\assistment_2009_2010\\test.json\n", 21 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/train.json is saved as ..\\..\\data\\assistment_2009_2010\\train.json\n", 22 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/trans_sim.json is saved as ..\\..\\data\\assistment_2009_2010\\trans_sim.json\n", 23 | "downloader, INFO http://base.ustc.edu.cn/data/ktbd/assistment_2009_2010/transition_graph.json is saved as ..\\..\\data\\assistment_2009_2010\\transition_graph.json\n" 24 | ] 25 | }, 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "Downloading ..\\..\\data\\assistment_2009_2010\\correct_transition_graph.json 100.00%: 34.1KB | 34.1KB\n", 31 | "Downloading ..\\..\\data\\assistment_2009_2010\\ctrans_sim.json 100.00%: 200KB | 200KB\n", 32 | "Downloading ..\\..\\data\\assistment_2009_2010\\dense_graph.json 100.00%: 361KB | 361KB\n", 33 | "Downloading ..\\..\\data\\assistment_2009_2010\\test.json 100.00%: 1.02MB | 1.02MB\n", 34 | "Downloading ..\\..\\data\\assistment_2009_2010\\train.json 100.00%: 3.53MB | 3.53MB\n", 35 | "Downloading ..\\..\\data\\assistment_2009_2010\\trans_sim.json 100.00%: 350KB | 350KB\n", 36 | "Downloading ..\\..\\data\\assistment_2009_2010\\transition_graph.json 100.00%: 52.9KB | 52.9KB\n" 37 | ] 38 | }, 39 | { 40 | "data": { 41 | "text/plain": "'..\\\\..\\\\data\\\\assistment_2009_2010'" 42 | }, 43 | "execution_count": 1, 44 | "metadata": {}, 45 | "output_type": "execute_result" 46 | } 47 | ], 48 | "source": [ 49 | "from EduData import get_data\n", 50 | "\n", 51 | "get_data(\"ktbd-a0910\", \"../../data\")" 52 | ] 53 | } 54 | ], 55 | "metadata": { 56 | "kernelspec": { 57 | "display_name": "Python 3", 58 | "language": "python", 59 | "name": "python3" 60 | }, 61 | "language_info": { 62 | "codemirror_mode": { 63 | "name": "ipython", 64 | "version": 2 65 | }, 66 | "file_extension": ".py", 67 | "mimetype": "text/x-python", 68 | "name": "python", 69 | "nbconvert_exporter": "python", 70 | "pygments_lexer": "ipython2", 71 | "version": "2.7.6" 72 | } 73 | }, 74 | "nbformat": 4, 75 | "nbformat_minor": 0 76 | } -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | # For pytest usage, refer to 
https://hb4dsai.readthedocs.io/zh/latest/Architecture/Test.html 3 | norecursedirs = docs *build* trash dev scripts examples 4 | 5 | # Deal with marker warnings 6 | markers = 7 | flake8: flake8 8 | 9 | # Enable line length testing with maximum line length of 120 10 | flake8-max-line-length = 120 11 | 12 | # Ignore module level import not at top of file (E402) 13 | # Others can be found in https://flake8.pycqa.org/en/latest/user/error-codes.html 14 | flake8-ignore = E402 F401 F403 15 | 16 | # --doctest-modules is used for unit tests 17 | addopts = --doctest-modules --cov --cov-report=term-missing --flake8 18 | -------------------------------------------------------------------------------- /scripts/DKT/DKT.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/5/26 @ tongshiwei 3 | from fire import Fire 4 | from baize import path_append 5 | from EduData import get_data 6 | from XKT import DKT 7 | 8 | DATASET = { 9 | "a0910c": ( 10 | "ktbd-a0910c", 11 | "a0910c/train.json", 12 | "a0910c/valid.json", 13 | "a0910c/test.json", 14 | {"ku_num": 146, "hidden_num": 100} 15 | ) 16 | } 17 | 18 | 19 | def get_dataset_and_config(dataset, train=None, valid=None, test=None, hyper_params=None): 20 | if dataset in DATASET: 21 | dataset, train, valid, test, hyper_params = DATASET[dataset] 22 | 23 | data_dir = "../../data" 24 | else: 25 | data_dir = dataset 26 | hyper_params = {} if hyper_params is None else hyper_params 27 | 28 | get_data(dataset, data_dir) 29 | 30 | train = path_append(data_dir, train) 31 | valid = path_append(data_dir, valid) 32 | test = path_append(data_dir, test) 33 | return train, valid, test, hyper_params 34 | 35 | 36 | def run(mode, model, dataset, epoch, train_path=None, valid_path=None, test_path=None, embedding_dim=None, *args, 37 | **kwargs): 38 | train, valid, test, hyper_params = get_dataset_and_config(dataset, train_path, valid_path, test_path) 39 | loss_params = {} 40 | if mode in {"hs", "train"}: 41 | if model in {"dkt+", "edkt+"}: 42 | loss_params = {"lr": 0.1, "lw1": 0.5, "lw2": 0.5} 43 | if model in {"edkt", "edkt+"}: 44 | hyper_params.update(dict(add_embedding_layer=True, embedding_dim=embedding_dim)) 45 | 46 | DKT.benchmark_train( 47 | train, 48 | valid, 49 | enable_hyper_search=True if mode == "hs" else False, 50 | end_epoch=epoch, 51 | loss_params=loss_params, 52 | hyper_params=hyper_params, 53 | save=False if mode == "hs" else True, 54 | model_dir=model, 55 | model_name=model, 56 | *args, **kwargs 57 | ) 58 | elif mode == "test": 59 | print(DKT.benchmark_eval(test, model, epoch)) 60 | else: 61 | raise ValueError("unknown mode %s" % mode) 62 | 63 | 64 | if __name__ == '__main__': 65 | Fire(run) 66 | 67 | # run("train", "dkt", "a0910c", 10) 68 | -------------------------------------------------------------------------------- /scripts/DKT/README.md: -------------------------------------------------------------------------------- 1 | # Benchmark 2 | 3 | ## Hyper-parameter search 4 | 5 | ```sh 6 | nnictl create --config config.yml 7 | ``` 8 | 9 | ## Train 10 | 11 | * dkt 12 | ```sh 13 | python3 DKT.py train dkt a0910c 10 --hyper_params_update '{"dropout": 0.5}' 14 | ``` 15 | 16 | * edkt 17 | ```sh 18 | python3 DKT.py train edkt a0910c 10 --embedding_dim 50 --hyper_params_update '{"hidden_num": 100}' 19 | ``` 20 | 21 | ## Test 22 | ```sh 23 | python3 DKT.py test dkt a0910c 1 24 | ``` 25 | -------------------------------------------------------------------------------- /scripts/DKT/config.yml: 
-------------------------------------------------------------------------------- 1 | searchSpaceFile: search_space.json 2 | trialCommand: python DKT.py hs dkt a0910c 10 # NOTE: change "python3" to "python" if you are using Windows 3 | trialGpuNumber: 0 4 | trialConcurrency: 1 5 | tuner: 6 | name: GridSearch 7 | trainingService: 8 | platform: local 9 | experimentWorkingDirectory: ./nni-experiments -------------------------------------------------------------------------------- /scripts/DKT/search_space.json: -------------------------------------------------------------------------------- 1 | { 2 | "hyper_params_update:hidden_num": {"_type":"choice","_value":[50, 100, 200]}, 3 | "hyper_params_update:dropout": {"_type":"choice","_value":[0, 0.2, 0.5]}, 4 | "optimizer_params_update:learning_rate": {"_type":"choice","_value":[0.1, 0.01, 0.001]}, 5 | "optimizer_params_update:wd": {"_type":"choice","_value":[0, 0.0001]} 6 | } -------------------------------------------------------------------------------- /scripts/DKVMN/DKVMN.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/5/26 @ tongshiwei 3 | from fire import Fire 4 | from baize import path_append 5 | from EduData import get_data 6 | from XKT import DKVMN 7 | 8 | DATASET = { 9 | "a0910c": ( 10 | "ktbd-a0910c", 11 | "a0910c/train.json", 12 | "a0910c/valid.json", 13 | "a0910c/test.json", 14 | dict( 15 | ku_num=146, 16 | key_embedding_dim=10, 17 | value_embedding_dim=10, 18 | key_memory_size=20, 19 | hidden_num=100 20 | ) 21 | ) 22 | } 23 | 24 | 25 | def get_dataset_and_config(dataset, train=None, valid=None, test=None, hyper_params=None): 26 | if dataset in DATASET: 27 | dataset, train, valid, test, hyper_params = DATASET[dataset] 28 | 29 | data_dir = "../../data" 30 | else: 31 | data_dir = dataset 32 | hyper_params = {} if hyper_params is None else hyper_params 33 | 34 | get_data(dataset, data_dir) 35 | 36 | train = path_append(data_dir, train) 37 | valid = path_append(data_dir, valid) 38 | test = path_append(data_dir, test) 39 | return train, valid, test, hyper_params 40 | 41 | 42 | def run(mode, model, dataset, epoch, train_path=None, valid_path=None, test_path=None, *args, **kwargs): 43 | train, valid, test, hyper_params = get_dataset_and_config(dataset, train_path, valid_path, test_path) 44 | loss_params = {} 45 | if mode in {"hs", "train"}: 46 | DKVMN.benchmark_train( 47 | train, 48 | valid, 49 | enable_hyper_search=True if mode == "hs" else False, 50 | end_epoch=epoch, 51 | loss_params=loss_params, 52 | hyper_params=hyper_params, 53 | save=False if mode == "hs" else True, 54 | model_dir=model, 55 | model_name=model, 56 | *args, **kwargs 57 | ) 58 | elif mode == "test": 59 | print(DKVMN.benchmark_eval(test, model, epoch)) 60 | else: 61 | raise ValueError("unknown mode %s" % mode) 62 | 63 | 64 | if __name__ == '__main__': 65 | Fire(run) 66 | 67 | # run("train", "dkvmn", "a0910c", 2) 68 | -------------------------------------------------------------------------------- /scripts/DKVMN/README.md: -------------------------------------------------------------------------------- 1 | # Benchmark 2 | 3 | ## Hyper-parameter search 4 | 5 | ```sh 6 | nnictl create --config config.yml 7 | ``` 8 | 9 | ## Train 10 | 11 | ```sh 12 | python3 DKVMN.py train dkvmn a0910c 10 --hyper_params_update '{"dropout": 0.5}' 13 | ``` 14 | 15 | 16 | ## Test 17 | ```sh 18 | python3 DKVMN.py test dkvmn a0910c 1 19 | ``` 20 | -------------------------------------------------------------------------------- 
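The `search_space.json` files in these script folders all share one naming convention: each key is a flat `section:parameter` pair, where the section (`hyper_params_update`, `optimizer_params_update`) names the keyword-argument dict of `benchmark_train` that the sampled value should be merged into. The snippet below is only a sketch of how such a flat NNI sample could be folded back into nested updates, under the assumption that this is roughly what the XKT/baize glue does internally:

```python
def fold(sample):
    # fold flat "section:param" keys into nested update dicts
    nested = {}
    for key, value in sample.items():
        section, param = key.split(":")
        nested.setdefault(section, {})[param] = value
    return nested

print(fold({
    "hyper_params_update:hidden_num": 50,
    "optimizer_params_update:learning_rate": 0.01,
}))
# {'hyper_params_update': {'hidden_num': 50}, 'optimizer_params_update': {'learning_rate': 0.01}}
```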
/scripts/DKVMN/config.yml: -------------------------------------------------------------------------------- 1 | searchSpaceFile: search_space.json 2 | trialCommand: python DKVMN.py hs dkvmn a0910c 10 # NOTE: change "python3" to "python" if you are using Windows 3 | trialGpuNumber: 0 4 | trialConcurrency: 1 5 | tuner: 6 | name: GridSearch 7 | trainingService: 8 | platform: local 9 | experimentWorkingDirectory: ./nni-experiments -------------------------------------------------------------------------------- /scripts/DKVMN/search_space.json: -------------------------------------------------------------------------------- 1 | { 2 | "hyper_params_update:hidden_num": {"_type":"choice","_value":[30, 50, 100]}, 3 | "hyper_params_update:key_embedding_dim": {"_type":"choice","_value":[20, 30, 50]}, 4 | "hyper_params_update:value_embedding_dim": {"_type":"choice","_value":[20, 30, 50]}, 5 | "hyper_params_update:key_memory_size": {"_type":"choice","_value":[10, 20, 30, 50]}, 6 | "hyper_params_update:dropout": {"_type":"choice","_value":[0.5]}, 7 | "optimizer_params_update:learning_rate": {"_type":"choice","_value":[0.01, 0.001]}, 8 | "optimizer_params_update:wd": {"_type":"choice","_value":[0, 0.0001]} 9 | } -------------------------------------------------------------------------------- /scripts/GKT/MGKT.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/25 @ tongshiwei 3 | 4 | from fire import Fire 5 | from baize import path_append 6 | from EduData import get_data 7 | from XKT import MGKT 8 | 9 | DATASET = { 10 | "a0910": ( 11 | "ktbd-a0910", 12 | "assistment_2009_2010/train.json", 13 | "assistment_2009_2010/valid.json", 14 | "assistment_2009_2010/test.json", 15 | "assistment_2009_2010/transition_graph.json", 16 | dict( 17 | ku_num=124, 18 | hidden_num=16, 19 | ) 20 | ) 21 | } 22 | 23 | 24 | def get_dataset_and_config(dataset, train=None, valid=None, test=None, hyper_params=None): 25 | if dataset in DATASET: 26 | dataset, train, valid, test, graph, hyper_params = DATASET[dataset] 27 | 28 | data_dir = "../../data" 29 | graph = path_append(data_dir, graph, to_str=True) 30 | else: 31 | data_dir = dataset 32 | assert hyper_params 33 | graph = None 34 | 35 | get_data(dataset, data_dir) 36 | 37 | if graph is not None: 38 | hyper_params["graph"] = graph 39 | 40 | train = path_append(data_dir, train) 41 | valid = path_append(data_dir, valid) 42 | test = path_append(data_dir, test) 43 | 44 | return train, valid, test, hyper_params 45 | 46 | 47 | def run(mode, model, dataset, epoch, train_path=None, valid_path=None, test_path=None, *args, **kwargs): 48 | train, valid, test, hyper_params = get_dataset_and_config(dataset, train_path, valid_path, test_path) 49 | loss_params = {} 50 | if mode in {"hs", "train"}: 51 | MGKT.benchmark_train( 52 | train, 53 | valid, 54 | enable_hyper_search=True if mode == "hs" else False, 55 | end_epoch=epoch, 56 | loss_params=loss_params, 57 | hyper_params=hyper_params, 58 | save=False if mode == "hs" else True, 59 | model_dir=model, 60 | model_name=model, 61 | *args, **kwargs 62 | ) 63 | elif mode == "test": 64 | print(MGKT.benchmark_eval(test, model, epoch)) 65 | else: 66 | raise ValueError("unknown mode %s" % mode) 67 | 68 | 69 | if __name__ == '__main__': 70 | Fire(run) 71 | 72 | # run("train", "gkt", "a0910", 2) 73 | -------------------------------------------------------------------------------- /scripts/GKT/README.md: -------------------------------------------------------------------------------- 1 | # 
Benchmark 2 | 3 | ## Hyper-parameter search 4 | 5 | ```sh 6 | nnictl create --config config.yml 7 | ``` 8 | 9 | ## Train 10 | 11 | ```sh 12 | python3 MGKT.py train gkt a0910 1 13 | ``` 14 | 15 | 16 | ## Test 17 | ```sh 18 | python3 MGKT.py test gkt a0910 0 19 | ``` 20 | -------------------------------------------------------------------------------- /scripts/GKT/config.yml: -------------------------------------------------------------------------------- 1 | searchSpaceFile: search_space.json 2 | trialCommand: python MGKT.py hs gkt a0910 5 # NOTE: change "python3" to "python" if you are using Windows 3 | trialGpuNumber: 0 4 | trialConcurrency: 1 5 | tuner: 6 | name: GridSearch 7 | trainingService: 8 | platform: local 9 | experimentWorkingDirectory: ./nni-experiments -------------------------------------------------------------------------------- /scripts/GKT/search_space.json: -------------------------------------------------------------------------------- 1 | { 2 | "hyper_params_update:hidden_num": {"_type":"choice","_value":[16, 32, 64]} 3 | } -------------------------------------------------------------------------------- /scripts/SKT/MSKT.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/25 @ tongshiwei 3 | 4 | from fire import Fire 5 | from baize import path_append 6 | from EduData import get_data 7 | from XKT import MSKT 8 | 9 | DATASET = { 10 | "a0910": ( 11 | "ktbd-a0910", 12 | "assistment_2009_2010/train.json", 13 | "assistment_2009_2010/valid.json", 14 | "assistment_2009_2010/test.json", 15 | [ 16 | ['assistment_2009_2010/correct_transition_graph.json', True], 17 | ['assistment_2009_2010/ctrans_sim.json', False] 18 | ], 19 | dict( 20 | ku_num=124, 21 | hidden_num=16, 22 | ) 23 | ) 24 | } 25 | 26 | 27 | def get_dataset_and_config(dataset, train=None, valid=None, test=None, hyper_params=None): 28 | if dataset in DATASET: 29 | dataset, train, valid, test, graphs, hyper_params = DATASET[dataset] 30 | 31 | data_dir = "../../data" 32 | for graph in graphs: 33 | graph[0] = path_append(data_dir, graph[0], to_str=True) 34 | else: 35 | data_dir = dataset 36 | assert hyper_params 37 | graphs = None 38 | 39 | get_data(dataset, data_dir) 40 | 41 | if graphs is not None: 42 | hyper_params["graph_params"] = graphs 43 | 44 | train = path_append(data_dir, train) 45 | valid = path_append(data_dir, valid) 46 | test = path_append(data_dir, test) 47 | 48 | return train, valid, test, hyper_params 49 | 50 | 51 | def run(mode, model, dataset, epoch, train_path=None, valid_path=None, test_path=None, *args, **kwargs): 52 | train, valid, test, hyper_params = get_dataset_and_config(dataset, train_path, valid_path, test_path) 53 | loss_params = {} 54 | if mode in {"hs", "train"}: 55 | hyper_params.update({"net_type": model.upper()}) 56 | MSKT.benchmark_train( 57 | train, 58 | valid, 59 | enable_hyper_search=True if mode == "hs" else False, 60 | end_epoch=epoch, 61 | loss_params=loss_params, 62 | hyper_params=hyper_params, 63 | save=False if mode == "hs" else True, 64 | model_dir=model, 65 | model_name=model, 66 | *args, **kwargs 67 | ) 68 | elif mode == "test": 69 | print(MSKT.benchmark_eval(test, model, epoch)) 70 | else: 71 | raise ValueError("unknown mode %s" % mode) 72 | 73 | 74 | if __name__ == '__main__': 75 | Fire(run) 76 | 77 | # run("train", "skt", "a0910", 2) 78 | -------------------------------------------------------------------------------- /scripts/SKT/README.md: 
-------------------------------------------------------------------------------- 1 | # Benchmark 2 | 3 | ## Hyper-parameter search 4 | 5 | ```sh 6 | nnictl create --config config.yml 7 | ``` 8 | 9 | ## Train 10 | 11 | ```sh 12 | python3 MSKT.py train skt a0910 1 13 | ``` 14 | 15 | 16 | ## Test 17 | ```sh 18 | python3 MSKT.py test skt a0910 0 19 | ``` 20 | -------------------------------------------------------------------------------- /scripts/SKT/config.yml: -------------------------------------------------------------------------------- 1 | searchSpaceFile: search_space.json 2 | trialCommand: python MSKT.py hs skt a0910 5 # NOTE: change "python3" to "python" if you are using Windows 3 | trialGpuNumber: 0 4 | trialConcurrency: 1 5 | tuner: 6 | name: GridSearch 7 | trainingService: 8 | platform: local 9 | experimentWorkingDirectory: ./nni-experiments -------------------------------------------------------------------------------- /scripts/SKT/search_space.json: -------------------------------------------------------------------------------- 1 | { 2 | "hyper_params_update:hidden_num": {"_type":"choice","_value":[16, 32, 64]} 3 | } -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [coverage:run] 2 | source=XKT 3 | [coverage:report] 4 | exclude_lines = 5 | pragma: no cover 6 | pass 7 | raise NotImplementedError 8 | if __name__ == '__main__': 9 | def __str__ 10 | def __repr__ -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # create by tongshiwei on 2019/6/25 3 | 4 | import logging 5 | from setuptools import setup, find_packages 6 | 7 | test_deps = [ 8 | 'pytest>=4', 9 | 'pytest-cov>=2.6.0', 10 | 'pytest-flake8', 11 | 'EduData>=0.0.4', 12 | ] 13 | 14 | try: 15 | import mxnet 16 | 17 | mxnet_requires = [] 18 | except ModuleNotFoundError: 19 | mxnet_requires = ["mxnet"] 20 | except Exception as e: 21 | mxnet_requires = [] 22 | logging.error(e) 23 | 24 | setup( 25 | name='XKT', 26 | version='0.0.2', 27 | packages=find_packages(), 28 | python_requires='>=3.6', 29 | long_description='Refer to full documentation https://github.com/bigdata-ustc/XKT/blob/master/README.md' 30 | ' for detailed information.', 31 | description='This project aims to ' 32 | 'provide multiple knowledge tracing models.', 33 | extras_require={ 34 | 'test': test_deps, 35 | }, 36 | install_requires=mxnet_requires + [ 37 | 'tqdm', 38 | 'PyBaize>=0.0.5', 39 | ], # And any other dependencies foo needs 40 | entry_points={ 41 | }, 42 | ) 43 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # create by tongshiwei on 2019-7-13 3 | -------------------------------------------------------------------------------- /tests/dkt/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/24 @ tongshiwei 3 | -------------------------------------------------------------------------------- /tests/dkt/conftest.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/23 @ tongshiwei 3 | import pytest 4 | 5 | import json 6 | from baize import as_out_io, path_append 7 | from XKT.utils.tests import 
pseudo_data_generation 8 | from XKT.DKT.etl import transform 9 | 10 | 11 | @pytest.fixture(scope="package") 12 | def conf(): 13 | item_num = 10 14 | hidden_num = 10 15 | batch_size = 32 16 | return item_num, hidden_num, batch_size 17 | 18 | 19 | @pytest.fixture(scope="package") 20 | def pseudo_data(conf): 21 | ques_num, *_ = conf 22 | return pseudo_data_generation(ques_num) 23 | 24 | 25 | @pytest.fixture(scope="package") 26 | def data(pseudo_data, conf): 27 | *_, batch_size = conf 28 | return transform(pseudo_data, batch_size) 29 | 30 | 31 | @pytest.fixture(scope="package") 32 | def train_file(pseudo_data, tmpdir_factory): 33 | data_dir = tmpdir_factory.mktemp("data") 34 | filepath = path_append(data_dir, "data.json", to_str=True) 35 | with as_out_io(filepath) as wf: 36 | for line in pseudo_data: 37 | print(json.dumps(line), file=wf) 38 | return filepath 39 | -------------------------------------------------------------------------------- /tests/dkt/test_dkt.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2019/12/6 @ tongshiwei 3 | import pytest 4 | from XKT import DKT 5 | 6 | 7 | @pytest.mark.parametrize("add_embedding_layer", [True, False]) 8 | def test_train(data, conf, tmpdir, add_embedding_layer): 9 | ku_num, hidden_num, batch_size = conf 10 | model = DKT( 11 | batch_size=batch_size, 12 | hyper_params={ 13 | "ku_num": ku_num, 14 | "hidden_num": hidden_num, 15 | "add_embedding_layer": add_embedding_layer, 16 | "embedding_dim": hidden_num 17 | } 18 | ) 19 | print(model.cfg) 20 | model.train(data, valid_data=data, end_epoch=1) 21 | filepath = tmpdir.mkdir("dkt") 22 | model.save(filepath) 23 | model = DKT.from_pretrained(filepath) 24 | model.eval(data) 25 | 26 | 27 | @pytest.mark.parametrize("rnn_type", ["rnn", "lstm", "gru"]) 28 | def test_benchmark(train_file, conf, tmpdir, rnn_type): 29 | ku_num, hidden_num, batch_size = conf 30 | model_dir = str(tmpdir.mkdir("dkt")) 31 | DKT.benchmark_train( 32 | train_path=train_file, 33 | valid_path=train_file, 34 | enable_hyper_search=False, 35 | save=True, 36 | model_dir=model_dir, 37 | end_epoch=1, 38 | batch_size=batch_size, 39 | hyper_params={ 40 | "rnn_type": rnn_type, 41 | "ku_num": ku_num, 42 | "hidden_num": hidden_num, 43 | }, 44 | 45 | ) 46 | DKT.benchmark_eval(train_file, model_dir, best_epoch=0) 47 | 48 | 49 | def test_exception(): 50 | with pytest.raises(TypeError): 51 | DKT(hyper_params={"ku_num": 10, "rnn_type": "error", "hidden_num": 10}) 52 | -------------------------------------------------------------------------------- /tests/dkvmn/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/24 @ tongshiwei 3 | -------------------------------------------------------------------------------- /tests/dkvmn/conftest.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/23 @ tongshiwei 3 | import pytest 4 | 5 | import json 6 | from baize import as_out_io, path_append 7 | from XKT.utils.tests import pseudo_data_generation 8 | from XKT.DKVMN.etl import transform 9 | 10 | 11 | @pytest.fixture(scope="package") 12 | def conf(): 13 | item_num = 10 14 | hidden_num = 10 15 | batch_size = 32 16 | return item_num, hidden_num, batch_size 17 | 18 | 19 | @pytest.fixture(scope="package") 20 | def pseudo_data(conf): 21 | ques_num, *_ = conf 22 | return pseudo_data_generation(ques_num) 23 | 24 | 25 | @pytest.fixture(scope="package") 26 | def data(pseudo_data, 
conf): 27 | *_, batch_size = conf 28 | return transform(pseudo_data, batch_size) 29 | 30 | 31 | @pytest.fixture(scope="package") 32 | def train_file(pseudo_data, tmpdir_factory): 33 | data_dir = tmpdir_factory.mktemp("data") 34 | filepath = path_append(data_dir, "data.json", to_str=True) 35 | with as_out_io(filepath) as wf: 36 | for line in pseudo_data: 37 | print(json.dumps(line), file=wf) 38 | return filepath 39 | -------------------------------------------------------------------------------- /tests/dkvmn/test_dkvmn.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2019/12/6 @ tongshiwei 3 | import pytest 4 | from XKT import DKVMN 5 | 6 | 7 | def test_train(data, conf, tmpdir): 8 | ku_num, hidden_num, batch_size = conf 9 | model = DKVMN( 10 | batch_size=batch_size, 11 | hyper_params=dict( 12 | ku_num=ku_num, 13 | key_embedding_dim=hidden_num, 14 | value_embedding_dim=hidden_num, 15 | hidden_num=hidden_num, 16 | key_memory_size=hidden_num) 17 | ) 18 | print(model.cfg) 19 | model.train(data, valid_data=data, end_epoch=1) 20 | filepath = tmpdir.mkdir("dkvmn") 21 | model.save(filepath) 22 | model = DKVMN.from_pretrained(filepath) 23 | model.eval(data) 24 | 25 | 26 | def test_benchmark(train_file, conf, tmpdir): 27 | ku_num, hidden_num, batch_size = conf 28 | model_dir = str(tmpdir.mkdir("dkvmn")) 29 | DKVMN.benchmark_train( 30 | train_path=train_file, 31 | valid_path=train_file, 32 | enable_hyper_search=False, 33 | save=True, 34 | model_dir=model_dir, 35 | end_epoch=1, 36 | batch_size=batch_size, 37 | hyper_params=dict( 38 | ku_num=ku_num, 39 | key_embedding_dim=hidden_num, 40 | value_embedding_dim=hidden_num, 41 | hidden_num=hidden_num, 42 | key_memory_size=hidden_num 43 | ) 44 | ) 45 | DKVMN.benchmark_eval(train_file, model_dir, best_epoch=0) 46 | -------------------------------------------------------------------------------- /tests/gkt/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/24 @ tongshiwei 3 | -------------------------------------------------------------------------------- /tests/gkt/conftest.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/23 @ tongshiwei 3 | import pytest 4 | 5 | import json 6 | from baize import as_out_io, path_append 7 | from XKT.utils.tests import pseudo_data_generation 8 | from XKT.GKT.etl import transform 9 | 10 | 11 | @pytest.fixture(scope="package") 12 | def conf(): 13 | item_num = 10 14 | hidden_num = 10 15 | batch_size = 32 16 | return item_num, hidden_num, batch_size 17 | 18 | 19 | @pytest.fixture(scope="package") 20 | def pseudo_data(conf): 21 | ques_num, *_ = conf 22 | return pseudo_data_generation(ques_num) 23 | 24 | 25 | @pytest.fixture(scope="package") 26 | def data(pseudo_data, conf): 27 | *_, batch_size = conf 28 | return transform(pseudo_data, batch_size) 29 | 30 | 31 | @pytest.fixture(scope="package") 32 | def train_file(pseudo_data, tmpdir_factory): 33 | data_dir = tmpdir_factory.mktemp("data") 34 | filepath = path_append(data_dir, "data.json", to_str=True) 35 | with as_out_io(filepath) as wf: 36 | for line in pseudo_data: 37 | print(json.dumps(line), file=wf) 38 | return filepath 39 | 40 | 41 | @pytest.fixture(scope="package") 42 | def graph_file(conf, tmpdir_factory): 43 | graph_dir = tmpdir_factory.mktemp("graph") 44 | filepath = path_append(graph_dir, "graph.json", to_str=True) 45 | with as_out_io(filepath) as wf: 46 | 
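# the pseudo graph below is a plain edge list over knowledge units: each entry
# is a [source, target] pair, mirroring what files such as transition_graph.json
# are assumed to contain (verify against XKT/GKT/etl.py)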
json.dump([[0, 1], [0, 2]], wf) 47 | return filepath 48 | -------------------------------------------------------------------------------- /tests/gkt/test_mgkt.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2019/12/6 @ tongshiwei 3 | 4 | from XKT import MGKT 5 | 6 | 7 | def test_benchmark(train_file, graph_file, conf, tmpdir): 8 | ku_num, hidden_num, batch_size = conf 9 | model_dir = str(tmpdir.mkdir("mgkt")) 10 | MGKT.benchmark_train( 11 | train_path=train_file, 12 | valid_path=train_file, 13 | enable_hyper_search=False, 14 | save=True, 15 | model_dir=model_dir, 16 | end_epoch=1, 17 | batch_size=batch_size, 18 | hyper_params={ 19 | "ku_num": ku_num, 20 | "hidden_num": hidden_num, 21 | "graph": graph_file 22 | }, 23 | ) 24 | MGKT.benchmark_eval(train_file, model_dir, best_epoch=0) 25 | -------------------------------------------------------------------------------- /tests/skt/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/25 @ tongshiwei 3 | -------------------------------------------------------------------------------- /tests/skt/conftest.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2021/8/23 @ tongshiwei 3 | import pytest 4 | 5 | import json 6 | from baize import as_out_io, path_append 7 | from XKT.utils.tests import pseudo_data_generation 8 | from XKT.GKT.etl import transform 9 | 10 | 11 | @pytest.fixture(scope="package") 12 | def conf(): 13 | item_num = 10 14 | hidden_num = 10 15 | batch_size = 32 16 | return item_num, hidden_num, batch_size 17 | 18 | 19 | @pytest.fixture(scope="package") 20 | def pseudo_data(conf): 21 | ques_num, *_ = conf 22 | return pseudo_data_generation(ques_num) 23 | 24 | 25 | @pytest.fixture(scope="package") 26 | def data(pseudo_data, conf): 27 | *_, batch_size = conf 28 | return transform(pseudo_data, batch_size) 29 | 30 | 31 | @pytest.fixture(scope="package") 32 | def train_file(pseudo_data, tmpdir_factory): 33 | data_dir = tmpdir_factory.mktemp("data") 34 | filepath = path_append(data_dir, "data.json", to_str=True) 35 | with as_out_io(filepath) as wf: 36 | for line in pseudo_data: 37 | print(json.dumps(line), file=wf) 38 | return filepath 39 | 40 | 41 | @pytest.fixture(scope="package") 42 | def graphs(conf, tmpdir_factory): 43 | graph_dir = tmpdir_factory.mktemp("graph") 44 | _graphs = [] 45 | filepath = path_append(graph_dir, "graph_1.json", to_str=True) 46 | with as_out_io(filepath) as wf: 47 | json.dump([[0, 1], [0, 2]], wf) 48 | _graphs.append([filepath, False]) 49 | _graphs.append([filepath, True]) 50 | filepath = path_append(graph_dir, "graph_2.json", to_str=True) 51 | with as_out_io(filepath) as wf: 52 | json.dump([[1, 2, 0.9], [1, 5, 0.1]], wf) 53 | _graphs.append([filepath, True, 0.5]) 54 | return _graphs 55 | -------------------------------------------------------------------------------- /tests/skt/test_mskt.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 2019/12/6 @ tongshiwei 3 | 4 | import pytest 5 | from XKT import MSKT 6 | 7 | 8 | @pytest.mark.parametrize("net_type", ["SKT", "SKT_TE", "SKTPart", "SKTSync"]) 9 | def test_benchmark(train_file, graphs, conf, tmpdir, net_type): 10 | ku_num, hidden_num, batch_size = conf 11 | model_dir = str(tmpdir.mkdir("mskt")) 12 | MSKT.benchmark_train( 13 | train_path=train_file, 14 | valid_path=train_file, 15 | enable_hyper_search=False, 16 
| save=True, 17 | model_dir=model_dir, 18 | end_epoch=1, 19 | batch_size=batch_size, 20 | hyper_params={ 21 | "ku_num": ku_num, 22 | "hidden_num": hidden_num, 23 | "graph_params": graphs, 24 | "net_type": net_type 25 | }, 26 | ) 27 | MSKT.benchmark_eval(train_file, model_dir, best_epoch=0) 28 | --------------------------------------------------------------------------------
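With `pytest.ini` wiring coverage, doctest and flake8 checks into a bare `pytest` run, the whole suite above can be executed from the repository root once the `test` extras declared in `setup.py` are installed; a minimal invocation, assuming a local editable install:

```sh
pip install -e .[test]
pytest
```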