├── .github └── workflows │ ├── python-publish-test.yml │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── assets ├── framework.png └── logo.png ├── docs ├── .readthedocs.yaml ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── _static │ └── favicon.ico │ ├── assets │ ├── dataflow.jpg │ ├── framework.png │ └── logo.png │ ├── conf.py │ ├── developer_guide │ ├── customize_atomic_op.md │ ├── customize_datatpl.md │ ├── customize_evaltpl.md │ ├── customize_modeltpl.md │ └── customize_traintpl.md │ ├── features │ ├── atomic_files.md │ ├── atomic_operations.md │ ├── dataset_folder_protocol.md │ ├── global_cfg_obj.md │ ├── inheritable_config.md │ └── standard_datamodule.md │ ├── get_started │ ├── install.md │ └── quick_start.md │ ├── index.rst │ └── user_guide │ ├── atom_op.md │ ├── datasets.md │ ├── models.md │ ├── reference_table.md │ ├── usage.rst │ └── usage │ ├── aht.md │ ├── atomic_cmds.md │ ├── run_edustudio.md │ └── use_case_of_config.md ├── edustudio ├── __init__.py ├── assets │ └── datasets.yaml ├── atom_op │ ├── __init__.py │ ├── mid2cache │ │ ├── CD │ │ │ ├── __init__.py │ │ │ ├── data_split4cd.py │ │ │ └── filter_records4cd.py │ │ ├── KT │ │ │ ├── __init__.py │ │ │ ├── build_seq_inter_feats.py │ │ │ ├── cpt_as_exer.py │ │ │ ├── data_split4kt.py │ │ │ ├── gen_cpt_seq.py │ │ │ └── gen_unfold_cpt_seq.py │ │ ├── __init__.py │ │ ├── common │ │ │ ├── __init__.py │ │ │ ├── base_mid2cache.py │ │ │ ├── build_cpt_relation.py │ │ │ ├── build_missing_Q.py │ │ │ ├── fill_missing_Q.py │ │ │ ├── filtering_records_by_attr.py │ │ │ ├── gen_q_mat.py │ │ │ ├── label2int.py │ │ │ ├── merge_divided_splits.py │ │ │ └── remapid.py │ │ └── single │ │ │ ├── M2C_CDGK_OP.py │ │ │ ├── M2C_CL4KT_OP.py │ │ │ ├── M2C_CNCDQ_OP.py │ │ │ ├── M2C_DIMKT_OP.py │ │ │ ├── M2C_DKTDSC_OP.py │ │ │ ├── M2C_DKTForget_OP.py │ │ │ ├── M2C_EERNN_OP.py │ │ │ ├── M2C_LPKT_OP.py │ │ │ ├── M2C_MGCD_OP.py │ │ │ ├── M2C_QDKT_OP.py │ │ │ ├── M2C_RCD_OP.py │ │ │ └── 
__init__.py │ └── raw2mid │ │ ├── __init__.py │ │ ├── aaai_2023.py │ │ ├── algebra2005.py │ │ ├── assist_0910.py │ │ ├── assist_1213.py │ │ ├── assist_1516.py │ │ ├── assist_2017.py │ │ ├── bridge2006.py │ │ ├── ednet_kt1.py │ │ ├── frcsub.py │ │ ├── junyi_area_topic_as_cpt.py │ │ ├── junyi_exer_as_cpt.py │ │ ├── math1.py │ │ ├── math2.py │ │ ├── nips12.py │ │ ├── nips34.py │ │ ├── raw2mid.py │ │ ├── simulated5.py │ │ ├── slp_english.py │ │ └── slp_math.py ├── datatpl │ ├── CD │ │ ├── CDGKDataTPL.py │ │ ├── CDInterDataTPL.py │ │ ├── CDInterExtendsQDataTPL.py │ │ ├── CNCDFDataTPL.py │ │ ├── CNCDQDataTPL.py │ │ ├── DCDDataTPL.py │ │ ├── ECDDataTPL.py │ │ ├── FAIRDataTPL.py │ │ ├── HierCDFDataTPL.py │ │ ├── IRRDataTPL.py │ │ ├── MGCDDataTPL.py │ │ ├── RCDDataTPL.py │ │ └── __init__.py │ ├── KT │ │ ├── CL4KTDataTPL.py │ │ ├── DIMKTDataTPL.py │ │ ├── DKTDSCDataTPL.py │ │ ├── DKTForgetDataTPL.py │ │ ├── EERNNDataTPL.py │ │ ├── EKTDataTPL.py │ │ ├── GKTDataTPL.py │ │ ├── KTInterCptAsExerDataTPL.py │ │ ├── KTInterCptUnfoldDataTPL.py │ │ ├── KTInterDataTPL.py │ │ ├── KTInterExtendsQDataTPL.py │ │ ├── LPKTDataTPL.py │ │ ├── QDKTDataTPL.py │ │ ├── RKTDataTPL.py │ │ └── __init__.py │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── base_datatpl.py │ │ ├── edu_datatpl.py │ │ ├── general_datatpl.py │ │ └── proxy_datatpl.py │ └── utils │ │ ├── __init__.py │ │ ├── common.py │ │ ├── pad_seq_util.py │ │ └── spliter_util.py ├── evaltpl │ ├── __init__.py │ ├── base_evaltpl.py │ ├── fairness_evaltpl.py │ ├── identifiability_evaltpl.py │ ├── interpretability_evaltpl.py │ └── prediction_evaltpl.py ├── model │ ├── CD │ │ ├── __init__.py │ │ ├── cdgk.py │ │ ├── cdmfkc.py │ │ ├── cncd_f.py │ │ ├── cncd_q.py │ │ ├── dcd.py │ │ ├── dina.py │ │ ├── ecd.py │ │ ├── faircd.py │ │ ├── hier_cdf.py │ │ ├── irr.py │ │ ├── irt.py │ │ ├── kancd.py │ │ ├── kscd.py │ │ ├── mf.py │ │ ├── mgcd.py │ │ ├── mirt.py │ │ ├── ncdm.py │ │ └── rcd.py │ ├── KT │ │ ├── GKT │ │ │ ├── __init__.py │ │ │ ├── 
building_blocks.py │ │ │ ├── gkt.py │ │ │ └── losses.py │ │ ├── __init__.py │ │ ├── akt.py │ │ ├── atkt.py │ │ ├── ckt.py │ │ ├── cl4kt.py │ │ ├── ct_ncm.py │ │ ├── deep_irt.py │ │ ├── dimkt.py │ │ ├── dkt.py │ │ ├── dkt_dsc.py │ │ ├── dkt_forget.py │ │ ├── dkt_plus.py │ │ ├── dkvmn.py │ │ ├── dtransformer.py │ │ ├── eernn.py │ │ ├── ekt.py │ │ ├── hawkeskt.py │ │ ├── iekt.py │ │ ├── kqn.py │ │ ├── lpkt.py │ │ ├── lpkt_s.py │ │ ├── qdkt.py │ │ ├── qikt.py │ │ ├── rkt.py │ │ ├── saint.py │ │ ├── saint_plus.py │ │ ├── sakt.py │ │ ├── simplekt.py │ │ └── skvmn.py │ ├── __init__.py │ ├── basemodel.py │ ├── gd_basemodel.py │ └── utils │ │ ├── __init__.py │ │ ├── common.py │ │ └── components.py ├── quickstart │ ├── __init__.py │ ├── atom_cmds.py │ ├── init_all.py │ ├── parse_cfg.py │ └── quickstart.py ├── settings.py ├── traintpl │ ├── __init__.py │ ├── adversarial_traintpl.py │ ├── atkt_traintpl.py │ ├── base_traintpl.py │ ├── dcd_traintpl.py │ ├── gd_traintpl.py │ ├── general_traintpl.py │ └── group_cd_traintpl.py └── utils │ ├── __init__.py │ ├── callback │ ├── __init__.py │ ├── callBackList.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── baseLogger.py │ │ ├── callback.py │ │ ├── earlyStopping.py │ │ ├── epochPredict.py │ │ ├── history.py │ │ ├── modelCheckPoint.py │ │ └── tensorboardCallBack.py │ └── modeState.py │ └── common │ ├── __init__.py │ ├── commonUtil.py │ ├── configUtil.py │ └── loggerUtil.py ├── examples ├── 1.run_cd_demo.py ├── 2.run_kt_demo.py ├── 3.run_with_customized_tpl.py ├── 4.run_cmd_demo.py ├── 5.run_with_hyperopt.py ├── 6.run_with_ray.tune.py ├── 7.run_toy_demo.py └── single_model │ ├── run_akt_demo.py │ ├── run_atkt_demo.py │ ├── run_cdgk_demo.py │ ├── run_cdmfkc_demo.py │ ├── run_ckt_demo.py │ ├── run_cl4kt_demo.py │ ├── run_cncd_f_demo.py │ ├── run_cncdq_demo.py │ ├── run_ctncm_demo.py │ ├── run_dcd_demo.py │ ├── run_deepirt_demo.py │ ├── run_dimkt_demo.py │ ├── run_dina_demo.py │ ├── run_dkt_demo.py │ ├── run_dkt_dsc_demo.py │ ├── 
run_dkt_plus_demo.py │ ├── run_dktforget_demo.py │ ├── run_dkvmn_demo.py │ ├── run_dtransformer_demo.py │ ├── run_ecd_demo.py │ ├── run_eernn_demo.py │ ├── run_ekt_demo.py │ ├── run_faircd_irt_demo.py │ ├── run_faircd_mirt_demo.py │ ├── run_faircd_ncdm_demo.py │ ├── run_gkt_demo.py │ ├── run_hawkeskt_demo.py │ ├── run_hiercdf_demo.py │ ├── run_iekt_demo.py │ ├── run_irr_demo.py │ ├── run_irt_demo.py │ ├── run_kancd_demo.py │ ├── run_kqn_demo.py │ ├── run_kscd_demo.py │ ├── run_lpkt_demo.py │ ├── run_lpkt_s_demo.py │ ├── run_mf_demo.py │ ├── run_mgcd_demo.py │ ├── run_mirt_demo.py │ ├── run_ncdm_demo.py │ ├── run_qdkt_demo.py │ ├── run_qikt_demo.py │ ├── run_rcd_demo.py │ ├── run_rkt_demo.py │ ├── run_saint_demo.py │ ├── run_saint_plus_demo.py │ ├── run_sakt_demo.py │ ├── run_simplekt_demo.py │ └── run_skvmn_demo.py ├── requirements.txt ├── setup.py └── tests └── test_run.py /.github/workflows/python-publish-test.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package TEST 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | permissions: 7 | contents: read 8 | 9 | jobs: 10 | deploy: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Set up Python 17 | uses: actions/setup-python@v3 18 | with: 19 | python-version: '3.9' 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install build 24 | pip install pytest 25 | pip install torch --index-url https://download.pytorch.org/whl/cpu 26 | pip install -e . --verbose 27 | pip install -r requirements.txt 28 | - name: Test 29 | run: | 30 | cd tests && pytest && cd .. 
31 | - name: Build package 32 | run: python -m build 33 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | deploy: 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | - name: Set up Python 19 | uses: actions/setup-python@v3 20 | with: 21 | python-version: '3.9' 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install build 26 | pip install pytest 27 | pip install torch --index-url https://download.pytorch.org/whl/cpu 28 | pip install -e . --verbose 29 | pip install -r requirements.txt 30 | - name: Test 31 | run: | 32 | cd tests && pytest && cd .. 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@v1.13.0 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .ipynb_checkpoints/ 3 | __pycache__/ 4 | /data/ 5 | /temp/ 6 | /archive/ 7 | /test/ 8 | /conf/ 9 | /.vscode/ 10 | /examples/**/archive/ 11 | /examples/**/temp/ 12 | /examples/**/data/ 13 | /build/ 14 | /dist/ 15 | *.egg-info/ 16 | docs/build/ 17 | /.pytest_cache/ 18 | /tests/archive/ 19 | /tests/conf/ 20 | /tests/temp/ 21 | /tests/data/ 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 HFUT-LEC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 
a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include edustudio/assets/ *.yaml 2 | -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/assets/framework.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/assets/logo.png -------------------------------------------------------------------------------- /docs/.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.11" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: docs/source/conf.py 17 | 18 | # We recommend specifying your dependencies to enable reproducible builds: 19 | # https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 20 | python: 21 | install: 22 | - requirements: docs/requirements.txt 23 | 24 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the 
environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==6.2.1 2 | sphinx_rtd_theme==1.2.2 3 | myst-parser==2.0.0 4 | esbonio==0.16.1 5 | sphinx_copybutton 6 | sphinxcontrib-napoleon 7 | -------------------------------------------------------------------------------- /docs/source/_static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/docs/source/_static/favicon.ico -------------------------------------------------------------------------------- /docs/source/assets/dataflow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/docs/source/assets/dataflow.jpg -------------------------------------------------------------------------------- /docs/source/assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/docs/source/assets/framework.png -------------------------------------------------------------------------------- /docs/source/assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/docs/source/assets/logo.png 
-------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = 'EduStudio' 10 | copyright = '2023, HFUT-LEC' 11 | author = 'HFUT-LEC' 12 | release = 'v1.1.7' 13 | 14 | import sphinx_rtd_theme 15 | import os 16 | import sys 17 | 18 | sys.path.insert(0, os.path.abspath("../..")) 19 | 20 | # -- General configuration --------------------------------------------------- 21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 22 | 23 | extensions = [ 24 | 'myst_parser', 25 | "sphinx.ext.autodoc", 26 | "sphinx.ext.napoleon", 27 | "sphinx.ext.viewcode", 28 | "sphinx_copybutton", 29 | ] 30 | source_suffix = ['.rst', '.md'] 31 | templates_path = ['_templates'] 32 | exclude_patterns = [] 33 | 34 | 35 | # -- Options for HTML output ------------------------------------------------- 36 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 37 | 38 | html_theme = 'sphinx_rtd_theme' 39 | # html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 40 | html_static_path = ['_static'] 41 | html_favicon = '_static/favicon.ico' 42 | -------------------------------------------------------------------------------- /docs/source/developer_guide/customize_evaltpl.md: -------------------------------------------------------------------------------- 1 | Customize Evaluation Template 2 | ====================== 3 | Here, we present how to develop a new Evaluation Template, and apply it into EduStudio. 
4 | EduStudio provides the EvalTPL Protocol in ``EduStudio.edustudio.evaltpl.baseevaltpl.BaseEvalTPL`` (``BaseEvalTPL``). 5 | 6 | EvalTPL Protocol 7 | ---------------------- 8 | 9 | ### BaseEvalTPL 10 | The protocols in ``BaseEvalTPL`` are listed as follows. 11 | 12 | | name | description | type | note | 13 | | ----------------- | ------------------------- | ------------------ | ----------------------- | 14 | | default_cfg | the default configuration | class variable | | 15 | | eval | calculate metric results | function interface | implemented by subclass | 16 | | _check_params | check parameters | function interface | implemented by subclass | 17 | | set_callback_list | set callback list | function interface | implemented by subclass | 18 | | set_dataloaders | set dataloaders | function interface | implemented by subclass | 19 | | add_extra_data | add extra data | function interface | implemented by subclass | 20 | 21 | 22 | 23 | EvalTPLs 24 | ---------------------- 25 | 26 | EduStudio provides ``PredictionEvalTPL`` and ``InterpretabilityEvalTPL``, which inherit ``BaseEvalTPL``. 27 | 28 | ### PredictionEvalTPL 29 | This EvalTPL is for the model evaluation using binary classification metrics. 30 | The protocols in ``PredictionEvalTPL`` are listed as follows. 31 | 32 | 33 | ### InterpretabilityEvalTPL 34 | This EvalTPL is for the model evaluation for interpretability. It uses states of students and Q matrix for ``eval``, which are domain-specific in student assessment. 35 | 36 | ## Develop a New EvalTPL in EduStudio 37 | 38 | If you want to develop a new EvalTPL in EduStudio, you can inherit ``BaseEvalTPL`` and revise the ``eval`` method. 
39 | 40 | ### Example 41 | 42 | ```python 43 | from .base_evaltpl import BaseEvalTPL 44 | from sklearn.metrics import accuracy_score, coverage_error 45 | 46 | class NewEvalTPL(BaseEvalTPL): 47 | default_cfg = { 48 | 'use_metrics': ['acc', 'coverage_error'] 49 | } 50 | 51 | def __init__(self, cfg): 52 | super().__init__(cfg) 53 | 54 | def eval(self, y_pd, y_gt, **kwargs): 55 | if not isinstance(y_pd, np.ndarray): y_pd = tensor2npy(y_pd) 56 | if not isinstance(y_gt, np.ndarray): y_gt = tensor2npy(y_gt) 57 | metric_result = {} 58 | ignore_metrics = kwargs.get('ignore_metrics', {}) 59 | for metric_name in self.evalfmt_cfg[self.__class__.__name__]['use_metrics']: 60 | if metric_name not in ignore_metrics: 61 | metric_result[metric_name] = self._get_metrics(metric_name)(y_gt, y_pd) 62 | return metric_result 63 | 64 | def _get_metrics(self, metric): 65 | if metric == "acc": 66 | return lambda y_gt, y_pd: accuracy_score(y_gt, np.where(y_pd >= 0.5, 1, 0)) 67 | elif metric == 'coverage_error': 68 | return lambda y_gt, y_pd: coverage_error(y_gt, y_pd) 69 | else: 70 | raise NotImplementedError 71 | ``` 72 | -------------------------------------------------------------------------------- /docs/source/features/atomic_files.md: -------------------------------------------------------------------------------- 1 | # Middle Data Format Protocol 2 | 3 | In `EduStudio`, we adopt a flexible CSV (Comma-Separated Values) file format following [Recbole](https://recbole.io/atomic_files.html). The flexible CSV format is defined in `middata` stage of dataset (see dataset stage protocol for details). 4 | 5 | The Middle Data Format Protocol includes two parts: `Columns name Format` and `Filename Format`. 
6 | 7 | ## Columns Name Format 8 | 9 | | feat_type | Explanations | Examples | 10 | | ------------- | --------------------------- | --------------------------------- | 11 | | **token** | single discrete feature | exer_id, stu_id | 12 | | **token_seq** | discrete features sequence | knowledge concept seq of exercise | 13 | | **float** | single continuous feature | label, start_timestamp | 14 | | **float_seq** | continuous feature sequence | word2vec embedding of exercise | 15 | 16 | 17 | 18 | ## Filename format 19 | 20 | So far, there are six atomic files in edustudio. 21 | 22 | **Note**: Users could also load other types of data except the atomic files below. `{dt}` is the dataset name. 23 | 24 | | filename format | description | 25 | | -------------------- | ---------------------------------------------------- | 26 | | {dt}.inter.csv | Student-Exercise Interaction data | 27 | | {dt}.train.inter.csv | Student-Exercise Interaction data for training set | 28 | | {dt}.valid.inter.csv | Student-Exercise Interaction data for validation set | 29 | | {dt}.test.inter.csv | Student-Exercise Interaction data for test set | 30 | | {dt}.stu.csv | Features of students | 31 | | {dt}.exer.csv | Features of exercises | 32 | 33 | 34 | 35 | ## Example 36 | 37 | ### example_dt.inter.csv 38 | 39 | | stu_id:token | exer_id:token | label:float | 40 | | ------------ | ------------- | ----------- | 41 | | 0 | 1 | 0.0 | 42 | | 1 | 0 | 1.0 | 43 | 44 | ### example_dt.stu.csv 45 | 46 | | stu_id:token | gender:token | occupation:token | 47 | | ------------ | ------------ | ---------------- | 48 | | 0 | 1 | 11 | 49 | | 1 | 0 | 7 | 50 | 51 | ### example_dt.exer.csv 52 | 53 | | exer_id:token | cpt_seq:token_seq | w2v_emb:float_seq | 54 | | ------------- | ----------------- | ---------------------- | 55 | | 0 | [0, 1] | [0.121, 0.123, 0.761] | 56 | | 1 | [1, 2, 3] | [0.229, -0.113, 0.138] | 57 | -------------------------------------------------------------------------------- 
/docs/source/features/atomic_operations.md: -------------------------------------------------------------------------------- 1 | # Atomic Data Operation Protocol 2 | 3 | In `Edustudio`, we view the dataset from three stages: `rawdata`, `middata`, `cachedata`. 4 | 5 | We treat the whole data processing as multiple atomic operations called atomic operation sequence. 6 | The first atomic operation, inheriting the protocol class `BaseRaw2Mid`, is the process from raw data to middle data. 7 | The following atomic operations, inheriting the protocol class `BaseMid2Cache`, construct the process from middle data to cache data. 8 | 9 | 10 | ## Partial Atomic Operation Table 11 | 12 | In the following, we give a table to display some existing atomic operations. For more detailed Atomic Operation Table, please see the `user_guide/Atomic Data Operation List` 13 | 14 | ### Raw2Mid 15 | 16 | For the conversion from rawdata to middata, we implement a specific atomic data operation prefixed with `R2M` for each dataset. 
17 | 18 | | name | Corresponding dataset | 19 | | --------------- | ------------------------------------------------------------ | 20 | | R2M_ASSIST_0910 | ASSISTment 2009-2010 | 21 | | R2M_FrcSub | Frcsub | 22 | | R2M_ASSIST_1213 | ASSISTment 2012-2013 | 23 | | R2M_Math1 | Math1 | 24 | | R2M_Math2 | Math2 | 25 | | R2M_AAAI_2023 | AAAI 2023 Global Knowledge Tracing Challenge | 26 | | R2M_Algebra_0506 | Algebra 2005-2006 | 27 | | R2M_ASSIST_1516 | ASSISTment 2015-2016 | 28 | 29 | ### Mid2Cache 30 | 31 | #### common 32 | 33 | | name | description | 34 | | ---------------------- | --------------------------------------------- | 35 | | M2C_Label2Int | convert label column into discrete values | 36 | | M2C_MergeDividedSplits | merge train/valid/test set into one dataframe | 37 | | M2C_ReMapId | ReMap Column ID | 38 | | M2C_GenQMat | Generate Q-matrix | 39 | 40 | #### CD 41 | 42 | | name | description | 43 | | ---------------------- | ------------------------------------------------------------ | 44 | | M2C_RandomDataSplit4CD | Split datasets Randomly for CD | 45 | | M2C_FilterRecords4CD | Filter students or exercises whose number of interaction records is less than a threshold | 46 | 47 | #### KT 48 | 49 | | name | description | 50 | | ---------------------- | ------------------------------------------- | 51 | | M2C_BuildSeqInterFeats | Build Sequential Features and Split dataset | 52 | | M2C_KCAsExer | Treat knowledge concept as exercise | 53 | | M2C_GenKCSeq | Generate knowledge concept seq | 54 | | M2C_GenUnFoldKCSeq | Unfold knowledge concepts | 55 | 56 | -------------------------------------------------------------------------------- /docs/source/features/global_cfg_obj.md: -------------------------------------------------------------------------------- 1 | # Global Configuration Object 2 | 3 | In `EduStudio`, there are five categories of configuration, which are unified in one global configuration object. 
4 | 5 | ## Five Config Objects 6 | 7 | The description of five config objects is illustrated in Table below. 8 | 9 | | name | description | 10 | | ------------ | ---------------------------------- | 11 | | datatpl_cfg | configuration of data template | 12 | | modeltpl_cfg | configuration of model template | 13 | | traintpl_cfg | configuration of training template | 14 | | evaltpl_cfg | configuration of evaluate template | 15 | | frame_cfg | configuration of framework itself | 16 | 17 | 18 | 19 | ## Four Configuration Portals 20 | 21 | There are four configuration portals: 22 | 23 | - default_cfg: inheritable python class variable 24 | - configuration file 25 | - parameter dictionary 26 | - command line 27 | 28 | -------------------------------------------------------------------------------- /docs/source/features/inheritable_config.md: -------------------------------------------------------------------------------- 1 | # Inheritable Default Configuration 2 | 3 | The management of default configuration in Edustudio is implemented by class variable, i.e. a dictionary object called default_config. 4 | 5 | Templates usually introduce new features through inheritance, and these new features may require corresponding configurations, so the default configuration we provide is inheritable. 6 | 7 | ## Example 8 | 9 | The inheritance example of data template is illustrated as follows. We present an example in the data preparation procedure. There are three data template classes (DataTPLs) that inherit from each other: BaseDataTPL, GeneralDataTPL, and EduDataTPL. If users specify current DataTPL is EduDataTPL, the eventual default\_config of data preparation procedure is a merger of default\_cfg of three templates. When a configuration conflict is encountered, the default\_config of subclass template takes precedence over that of parent class templates. 
As a result, other configuration portals (i.e, configuration file, parameter dictionary, and command line) can only specify parameters that are confined within the default configuration. The advantage of the inheritable design is that it facilitates the reader to locate the numerous hyperparameters. 10 | 11 | ```python 12 | class BaseDataTPL(Dataset): 13 | default_cfg = {'seed': 2023} 14 | 15 | 16 | class GeneralDataTPL(BaseDataTPL): 17 | default_cfg = { 18 | 'seperator': ',', 19 | 'n_folds': 1, 20 | 'is_dataset_divided': False, 21 | 'is_save_cache': False, 22 | 'cache_id': 'cache_default', 23 | 'load_data_from': 'middata', # ['rawdata', 'middata', 'cachedata'] 24 | 'inter_exclude_feat_names': (), 25 | 'raw2mid_op': None, 26 | 'mid2cache_op_seq': [] 27 | } 28 | 29 | 30 | class EduDataTPL(GeneralDataTPL): 31 | default_cfg = { 32 | 'exer_exclude_feat_names': (), 33 | 'stu_exclude_feat_names': (), 34 | } 35 | ``` 36 | 37 | If the currently specified data template is `EduDataTPL`, then the framework will get the final `default_cfg` through API `get_default_cfg`, which would be: 38 | 39 | ```python 40 | default_cfg = { 41 | 'exer_exclude_feat_names': (), 42 | 'stu_exclude_feat_names': (), 43 | 'seperator': ',', 44 | 'n_folds': 1, 45 | 'is_dataset_divided': False, 46 | 'is_save_cache': False, 47 | 'cache_id': 'cache_default', 48 | 'load_data_from': 'middata', # ['rawdata', 'middata', 'cachedata'] 49 | 'inter_exclude_feat_names': (), 50 | 'raw2mid_op': None, 51 | 'mid2cache_op_seq': [], 52 | 'seed': 2023 53 | } 54 | ``` 55 | 56 | The final `default_cfg` follows two rules: 57 | 58 | - The subclass would incorporate the `default_cfg` of all parent classes. 59 | - When a conflict happened for the same key, the subclass would dominate the priority. 
60 | -------------------------------------------------------------------------------- /docs/source/features/standard_datamodule.md: -------------------------------------------------------------------------------- 1 | # Standardized Data Module 2 | 3 | For data module, we provide a standardized design with three protocols (see following sections for details): 4 | - Data Status Protocol 5 | - Middle Data Format Protocol 6 | - Atomic Operation Protocol 7 | 8 | ![](../assets/dataflow.jpg) 9 | 10 | The first step of Data Templates is to load the raw data from the hard disk. Then, a series of processing steps are performed to obtain model-friendly data objects. Finally, these data objects are passed on to other modules. 11 | We simplify the data preparation into three stages: 12 | 13 | - Data loading: Loading necessary data from the hard disk. 14 | - Data processing: Convert the raw data into model-friendly data objects by a range of data processing operations. 15 | - Data delivery: Deliver model-friendly data objects to the training, model, and evaluation templates. 16 | -------------------------------------------------------------------------------- /docs/source/get_started/install.md: -------------------------------------------------------------------------------- 1 | # Install EduStudio 2 | 3 | EduStudio can be installed from ``pip`` and source files. 4 | 5 | ## System requirements 6 | 7 | EduStudio is compatible with the following operating systems: 8 | 9 | - Linux 10 | 11 | - Windows 10 12 | 13 | - macOS X 14 | 15 | Python 3.8 (or later), torch 1.10.0 (or later) are required to install our library. 16 | 17 | ## Install from pip 18 | 19 | To install EduStudio from pip, only the following command is needed: 20 | 21 | ```bash 22 | pip install -U edustudio 23 | ``` 24 | 25 | ## Install from source 26 | 27 | Download the source files from GitHub. 
28 | 29 | ```bash 30 | git clone https://github.com/HFUT-LEC/EduStudio.git && cd EduStudio 31 | ``` 32 | 33 | Run the following command to install: 34 | 35 | ```bash 36 | pip install -e . --verbose 37 | ``` 38 | 39 | -------------------------------------------------------------------------------- /docs/source/get_started/quick_start.md: -------------------------------------------------------------------------------- 1 | # Quick Start 2 | 3 | Here is a quick-start example for using EduStudio. 4 | 5 | ## Create a python file to run 6 | 7 | create a python file (e.g., *run.py*) anywhere, the content is as follows: 8 | 9 | ```python 10 | from edustudio.quickstart import run_edustudio 11 | 12 | run_edustudio( 13 | dataset='FrcSub', 14 | cfg_file_name=None, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | datatpl_cfg_dict={ 19 | 'cls': 'CDInterExtendsQDataTPL' 20 | }, 21 | modeltpl_cfg_dict={ 22 | 'cls': 'KaNCD', 23 | }, 24 | evaltpl_cfg_dict={ 25 | 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], 26 | } 27 | ) 28 | ``` 29 | 30 | Then run the following command: 31 | 32 | ```bash 33 | python run.py 34 | ``` 35 | 36 | To find out which templates are used for a model, we can find in the [Reference Table](https://edustudio.readthedocs.io/en/latest/user_guide/reference_table.html) 37 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. EduStudio documentation master file. 2 | .. title:: EduStudio v1.1.7 3 | .. image:: assets/logo.png 4 | 5 | ========================================================= 6 | 7 | `HomePage `_ | `Docs `_ | `GitHub `_ | `Paper `_ 8 | 9 | Introduction 10 | ------------------------- 11 | EduStudio is a Unified Library for Student Assessment Models including Cognitive Diagnosis(CD) and Knowledge Tracing(KT) based on Pytorch. 
12 | 13 | EduStudio first decomposes the general algorithmic workflow into six steps: `configuration reading`, `data preparation`, `model implementation`, `training control`, `model evaluation`, and `log storage`. Subsequently, to enhance the `reusability` and `scalability` of each step, we extract the commonalities of each algorithm at each step into individual templates for templatization. 14 | 15 | - Configuration Reading (Step 1) aims to collect, categorize and deliver configurations from different configuration portals. 16 | - Data Preparation (Step 2) aims to convert raw data from the hard disk into model-friendly data objects. 17 | - Model Implementation (Step 3) refers to the process of implementing the structure of each model and facilitating the reuse of model components. 18 | - Training Control (Step 4) focuses primarily on the training methods of various models. 19 | - Model Evaluation (Step 5) primarily focuses on the implementation of various evaluation metrics. 20 | - Log Storage (Step 6) aims to implement the storage specification when storing generated data. 21 | 22 | The modularization establishes clear boundaries between various programs in the algorithm pipeline, facilitating the introduction of new content to individual modules and enhancing scalability. 23 | 24 | The overall structure is illustrated as follows: 25 | 26 | .. image:: assets/framework.png 27 | :width: 600 28 | :align: center 29 | 30 | .. raw:: html 31 | 32 |
33 | 34 | .. toctree:: 35 | :maxdepth: 1 36 | :caption: Get Started 37 | 38 | get_started/install 39 | get_started/quick_start 40 | 41 | 42 | .. toctree:: 43 | :maxdepth: 1 44 | :caption: Framework Features 45 | 46 | features/global_cfg_obj 47 | features/inheritable_config 48 | features/standard_datamodule 49 | features/dataset_folder_protocol 50 | features/atomic_files 51 | features/atomic_operations 52 | 53 | .. toctree:: 54 | :maxdepth: 1 55 | :caption: User Guide 56 | 57 | user_guide/atom_op 58 | user_guide/datasets 59 | user_guide/models 60 | user_guide/reference_table 61 | user_guide/usage 62 | 63 | .. toctree:: 64 | :maxdepth: 1 65 | :caption: Developer Guide 66 | 67 | developer_guide/customize_atomic_op 68 | developer_guide/customize_datatpl 69 | developer_guide/customize_modeltpl 70 | developer_guide/customize_traintpl 71 | developer_guide/customize_evaltpl 72 | 73 | 74 | Indices and tables 75 | ================== 76 | 77 | * :ref:`genindex` 78 | * :ref:`modindex` 79 | * :ref:`search` 80 | -------------------------------------------------------------------------------- /docs/source/user_guide/atom_op.md: -------------------------------------------------------------------------------- 1 | # M2C Atomic Data Operation List 2 | 3 | 4 | | M2C Atomic operation | M2C Atomic Type | Description | 5 | | :------------------------: | --------------- | ------------------------------------------------------------ | 6 | | M2C_Label2Int | Data Cleaning | Binarization for answering response | 7 | | M2C_FilterRecords4CD | Data Cleaning | Filter some students or exercises according specific conditions | 8 | | M2C_FilteringRecordsByAttr | Data Cleaning | Filtering Students without attribute values, Commonly used by Fair Models | 9 | | M2C_ReMapId | Data Conversion | ReMap Column ID | 10 | | M2C_BuildMissingQ | Data Conversion | Build Missing Q-matrix | 11 | | M2C_BuildSeqInterFeats | Data Conversion | Build sample format for Question-based KT | 12 | | M2C_CKCAsExer | Data 
Conversion | Build sample format for KC-based KT | 13 | | M2C_MergeDividedSplits | Data Conversion | Merge train/valid/test set into one dataframe | 14 | | M2C_RandomDataSplit4CD | Data Partition | Data partitioning for Cognitive Diagnosis | 15 | | M2C_RandomDataSplit4KT | Data Partition | Data partitioning for Knowledge Tracing | 16 | | M2C_GenKCSeq | Data Generation | Generate Knowledge Component Sequence | 17 | | M2C_GenQMat | Data Generation | Generate Q-matrix (i.e, exercise-KC relation) | 18 | | M2C_BuildKCRelation | Data Generation | Build Knowledge Component Relation Graph | 19 | | M2C_GenUnFoldKCSeq | Data Generation | Generate Unfolded Knowledge Component Sequence | 20 | | M2C_FillMissingQ | Data Generation | Fill Missing Q-matrix | 21 | 22 | -------------------------------------------------------------------------------- /docs/source/user_guide/usage.rst: -------------------------------------------------------------------------------- 1 | Usage 2 | ========================================================= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | usage/run_edustudio 8 | usage/aht 9 | usage/atomic_cmds 10 | usage/use_case_of_config 11 | -------------------------------------------------------------------------------- /docs/source/user_guide/usage/atomic_cmds.md: -------------------------------------------------------------------------------- 1 | # Atomic Commands 2 | 3 | ## Atomic Commands for R2M process 4 | 5 | If there is a demand that process the `FrcSub` dataset from `rawdata` to `middata`, we can run the following command. 
6 | ```bash 7 | edustudio r2m R2M_FrcSub --dt FrcSub --rawpath data/FrcSub/rawdata --midpath data/FrcSub/middata 8 | ``` 9 | 10 | The command would read raw data files from `data/FrcSub/rawdata` and then save the middata in `data/FrcSub/middata` 11 | 12 | -------------------------------------------------------------------------------- /docs/source/user_guide/usage/run_edustudio.md: -------------------------------------------------------------------------------- 1 | # Run EduStudio 2 | 3 | ## Create a python file to run 4 | 5 | create a python file (e.g., *run.py*) anywhere, the content is as follows: 6 | 7 | ```python 8 | from edustudio.quickstart import run_edustudio 9 | 10 | run_edustudio( 11 | dataset='FrcSub', 12 | cfg_file_name=None, 13 | traintpl_cfg_dict={ 14 | 'cls': 'GeneralTrainTPL', 15 | }, 16 | datatpl_cfg_dict={ 17 | 'cls': 'CDInterExtendsQDataTPL' 18 | }, 19 | modeltpl_cfg_dict={ 20 | 'cls': 'KaNCD', 21 | }, 22 | evaltpl_cfg_dict={ 23 | 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], 24 | } 25 | ) 26 | ``` 27 | 28 | Then run the following command: 29 | 30 | ```bash 31 | python run.py 32 | ``` 33 | 34 | ## Run with command 35 | 36 | You can run the following command with parameters based on created file above. 
37 | 38 | ```bash 39 | cd examples 40 | python run.py -dt ASSIST_0910 --modeltpl_cfg.cls NCDM --traintpl_cfg.batch_size 512 41 | ``` 42 | 43 | ## Run with config file 44 | 45 | create a yaml file `conf/ASSIST_0910/NCDM.yaml`: 46 | ```yaml 47 | datatpl_cfg: 48 | cls: CDInterDataTPL 49 | 50 | traintpl_cfg: 51 | cls: GeneralTrainTPL 52 | batch_size: 512 53 | 54 | modeltpl_cfg: 55 | cls: NCDM 56 | 57 | evaltpl_cfg: 58 | clses: [PredictionEvalTPL, InterpretabilityEvalTPL] 59 | ``` 60 | 61 | then, run command: 62 | 63 | ```bash 64 | cd examples 65 | python run.py -dt ASSIST_0910 -f NCDM.yaml 66 | ``` 67 | -------------------------------------------------------------------------------- /docs/source/user_guide/usage/use_case_of_config.md: -------------------------------------------------------------------------------- 1 | # Use cases about specifying configuration 2 | 3 | ## Q1: How to specify the atomic data operation config 4 | 5 | The default_cfg of `M2C_FilterRecords4CD` is as follows: 6 | 7 | ```python 8 | class M2C_FilterRecords4CD(BaseMid2Cache): 9 | default_cfg = { 10 | "stu_least_records": 10, 11 | "exer_least_records": 0, 12 | } 13 | 14 | ``` 15 | 16 | The following example demonstrates how to specify the config of M2C_FilterRecords4CD.
17 | 18 | ```python 19 | from edustudio.quickstart import run_edustudio 20 | 21 | run_edustudio( 22 | dataset='FrcSub', 23 | cfg_file_name=None, 24 | traintpl_cfg_dict={ 25 | 'cls': 'GeneralTrainTPL', 26 | }, 27 | datatpl_cfg_dict={ 28 | 'cls': 'CDInterExtendsQDataTPL', 29 | 'M2C_FilterRecords4CD': { 30 | "stu_least_records": 20, # look here 31 | } 32 | }, 33 | modeltpl_cfg_dict={ 34 | 'cls': 'KaNCD', 35 | }, 36 | evaltpl_cfg_dict={ 37 | 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], 38 | } 39 | ) 40 | ``` 41 | 42 | ## Q2: How to specify the config of an evaluation template 43 | The default_cfg of `PredictionEvalTPL` is as follows: 44 | ```python 45 | class PredictionEvalTPL(BaseEvalTPL): 46 | default_cfg = { 47 | 'use_metrics': ['auc', 'acc', 'rmse'] 48 | } 49 | ``` 50 | 51 | 52 | If we want to use only the auc metric, we can do: 53 | 54 | ```python 55 | from edustudio.quickstart import run_edustudio 56 | 57 | run_edustudio( 58 | dataset='FrcSub', 59 | cfg_file_name=None, 60 | traintpl_cfg_dict={ 61 | 'cls': 'GeneralTrainTPL', 62 | }, 63 | datatpl_cfg_dict={ 64 | 'cls': 'CDInterExtendsQDataTPL', 65 | 'M2C_FilterRecords4CD': { 66 | "stu_least_records": 20, 67 | } 68 | }, 69 | modeltpl_cfg_dict={ 70 | 'cls': 'KaNCD', 71 | }, 72 | evaltpl_cfg_dict={ 73 | 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], 74 | 'PredictionEvalTPL': { 75 | 'use_metrics': ['auc'] # look here 76 | } 77 | } 78 | ) 79 | ``` 80 | -------------------------------------------------------------------------------- /edustudio/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | __version__ = 'v1.1.7' 6 | -------------------------------------------------------------------------------- /edustudio/assets/datasets.yaml: -------------------------------------------------------------------------------- 1 | # 1.
all datasets are stored in https://huggingface.co/datasets/lmcRS/edustudio-datasets 2 | # 2. some datasets may not list here, but can still download, as edustudio will look up from external yaml file: https://huggingface.co/datasets/lmcRS/edustudio-datasets/raw/main/datasets.yaml 3 | 4 | ASSIST_0910: 5 | middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/ASSIST_0910/ASSIST_0910-middata.zip 6 | FrcSub: 7 | middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/FrcSub/FrcSub-middata.zip 8 | Math1: 9 | middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/Math1/Math1-middata.zip 10 | Math2: 11 | middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/Math2/Math2-middata.zip 12 | AAAI_2023: 13 | middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/AAAI_2023/AAAI_2023-middata.zip 14 | PISA_2015_ECD: 15 | middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/PISA_2015_ECD/PISA_2015_ECD-middata.zip 16 | -------------------------------------------------------------------------------- /edustudio/atom_op/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/edustudio/atom_op/__init__.py -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/CD/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_split4cd import M2C_RandomDataSplit4CD 2 | from .filter_records4cd import M2C_FilterRecords4CD 3 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/CD/filter_records4cd.py: -------------------------------------------------------------------------------- 1 | from ..common.base_mid2cache import BaseMid2Cache 2 | 
import pandas as pd 3 | 4 | 5 | class M2C_FilterRecords4CD(BaseMid2Cache): 6 | default_cfg = { 7 | "stu_least_records": 10, 8 | "exer_least_records": 0, 9 | } 10 | 11 | def process(self, **kwargs): 12 | df: pd.DataFrame = kwargs['df'] 13 | assert df is not None 14 | 15 | # 去重,保留第一个记录 16 | df.drop_duplicates( 17 | subset=['stu_id:token', 'exer_id:token', "label:float"], 18 | keep='first', inplace=True, ignore_index=True 19 | ) 20 | 21 | stu_least_records = self.m2c_cfg['stu_least_records'] 22 | exer_least_records = self.m2c_cfg['exer_least_records'] 23 | 24 | # 循环删除user和item 25 | last_count = 0 26 | while last_count != df.__len__(): 27 | last_count = df.__len__() 28 | gp_by_uid = df[['stu_id:token','exer_id:token']].groupby('stu_id:token').agg('count').reset_index() 29 | selected_users = gp_by_uid[gp_by_uid['exer_id:token'] >= stu_least_records].reset_index()['stu_id:token'].to_numpy() 30 | 31 | gp_by_iid = df[['stu_id:token','exer_id:token']].groupby('exer_id:token').agg('count').reset_index() 32 | selected_items = gp_by_iid[gp_by_iid['stu_id:token'] >= exer_least_records].reset_index()['exer_id:token'].to_numpy() 33 | 34 | df = df[df['stu_id:token'].isin(selected_users) & df['exer_id:token'].isin(selected_items)] 35 | 36 | 37 | df = df.reset_index(drop=True) 38 | selected_users = df['stu_id:token'].unique() 39 | selected_items = df['exer_id:token'].unique() 40 | 41 | if kwargs.get('df_exer', None) is not None: 42 | kwargs['df_exer'] = kwargs['df_exer'][kwargs['df_exer']['exer_id:token'].isin(selected_items)] 43 | kwargs['df_exer'].reset_index(drop=True, inplace=True) 44 | 45 | if kwargs.get('df_stu', None) is not None: 46 | kwargs['df_stu'] = kwargs['df_stu'][kwargs['df_stu']['stu_id:token'].isin(selected_users)] 47 | kwargs['df_stu'].reset_index(drop=True, inplace=True) 48 | 49 | kwargs['df'] = df 50 | return kwargs 51 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/KT/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .build_seq_inter_feats import M2C_BuildSeqInterFeats 2 | from .cpt_as_exer import M2C_KCAsExer 3 | from .gen_cpt_seq import M2C_GenKCSeq 4 | from .gen_unfold_cpt_seq import M2C_GenUnFoldKCSeq 5 | from .data_split4kt import M2C_RandomDataSplit4KT 6 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/KT/gen_cpt_seq.py: -------------------------------------------------------------------------------- 1 | from ..common.base_mid2cache import BaseMid2Cache 2 | import numpy as np 3 | from edustudio.datatpl.utils import PadSeqUtil 4 | 5 | 6 | class M2C_GenKCSeq(BaseMid2Cache): 7 | """Generate Knowledge Component Sequence 8 | """ 9 | default_cfg = { 10 | 'cpt_seq_window_size': -1, 11 | } 12 | 13 | def process(self, **kwargs): 14 | df_exer = kwargs['df_exer'] 15 | tmp_df_Q = df_exer.set_index('exer_id:token') 16 | exer_count = kwargs['dt_info']['exer_count'] 17 | 18 | cpt_seq_unpadding = [ 19 | (tmp_df_Q.loc[exer_id].tolist()[0] if exer_id in tmp_df_Q.index else []) for exer_id in range(exer_count) 20 | ] 21 | cpt_seq_padding, _, cpt_seq_mask = PadSeqUtil.pad_sequence( 22 | cpt_seq_unpadding, maxlen=self.m2c_cfg['cpt_seq_window_size'], return_mask=True 23 | ) 24 | 25 | kwargs['cpt_seq_padding'] = cpt_seq_padding 26 | kwargs['cpt_seq_mask'] = cpt_seq_mask 27 | return kwargs 28 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/KT/gen_unfold_cpt_seq.py: -------------------------------------------------------------------------------- 1 | from ..common.base_mid2cache import BaseMid2Cache 2 | import numpy as np 3 | from edustudio.datatpl.utils import PadSeqUtil 4 | import pandas as pd 5 | 6 | 7 | class M2C_GenUnFoldKCSeq(BaseMid2Cache): 8 | default_cfg = {} 9 | 10 | def __init__(self, m2c_cfg, n_folds, is_dataset_divided) -> None: 11 | super().__init__(m2c_cfg) 12 | 
self.n_folds = n_folds 13 | self.is_dataset_divided = is_dataset_divided 14 | 15 | @classmethod 16 | def from_cfg(cls, cfg): 17 | m2c_cfg = cfg.datatpl_cfg.get(cls.__name__) 18 | n_folds = cfg.datatpl_cfg.n_folds 19 | is_dataset_divided = cfg.datatpl_cfg.is_dataset_divided 20 | return cls(m2c_cfg, n_folds, is_dataset_divided) 21 | 22 | def process(self, **kwargs): 23 | df = kwargs['df'] 24 | df_exer = kwargs['df_exer'] 25 | df_train, df_valid, df_test = kwargs['df_train'], kwargs['df_valid'], kwargs['df_test'] 26 | 27 | 28 | if not self.is_dataset_divided: 29 | assert df_train is None and df_valid is None and df_test is None 30 | unique_cpt_seq = df_exer['cpt_seq:token_seq'].explode().unique() 31 | cpt_map = dict(zip(unique_cpt_seq, range(len(unique_cpt_seq)))) 32 | df_Q_unfold = pd.DataFrame({ 33 | 'exer_id:token': df_exer['exer_id:token'].repeat(df_exer['cpt_seq:token_seq'].apply(len)), 34 | 'cpt_unfold:token': df_exer['cpt_seq:token_seq'].explode().replace(cpt_map) 35 | }) 36 | df = pd.merge(df, df_Q_unfold, on=['exer_id:token'], how='left').reset_index(drop=True) 37 | kwargs['df'] = df 38 | else: # dataset is divided 39 | assert df_train is not None and df_test is not None 40 | train_df, val_df, test_df = self._unfold_dataset(df_train, df_valid, df_test, df_exer) 41 | kwargs['df_train'] = train_df 42 | kwargs['df_valid'] = val_df 43 | kwargs['df_test'] = test_df 44 | return kwargs 45 | 46 | def _unfold_dataset(self, train_df, val_df, test_df, df_Q): 47 | unique_cpt_seq = df_Q['cpt_seq'].explode().unique() 48 | cpt_map = dict(zip(unique_cpt_seq, range(len(unique_cpt_seq)))) 49 | df_Q_unfold = pd.DataFrame({ 50 | 'exer_id:token': df_Q['exer_id'].repeat(df_Q['cpt_seq'].apply(len)), 51 | 'cpt_unfold:token': df_Q['cpt_seq'].explode().replace(cpt_map) 52 | }) 53 | train_df_unfold = pd.merge(train_df, df_Q_unfold, on=['exer_id'], how='left').reset_index(drop=True) 54 | val_df_unfold = pd.merge(val_df, df_Q_unfold, on=['exer_id'], how='left').reset_index(drop=True) 55 
| test_df_unfold = pd.merge(test_df, df_Q_unfold, on=['exer_id'], how='left').reset_index(drop=True) 56 | 57 | return train_df_unfold, val_df_unfold, test_df_unfold 58 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/__init__.py: -------------------------------------------------------------------------------- 1 | from .CD import * 2 | from .KT import * 3 | from .single import * 4 | from .common import * 5 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_mid2cache import BaseMid2Cache 2 | from .gen_q_mat import M2C_GenQMat 3 | from .label2int import M2C_Label2Int 4 | from .merge_divided_splits import M2C_MergeDividedSplits 5 | from .remapid import M2C_ReMapId 6 | from .build_cpt_relation import M2C_BuildKCRelation 7 | from .build_missing_Q import M2C_BuildMissingQ 8 | from .fill_missing_Q import M2C_FillMissingQ 9 | from .filtering_records_by_attr import M2C_FilteringRecordsByAttr 10 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/common/base_mid2cache.py: -------------------------------------------------------------------------------- 1 | from edustudio.utils.common import UnifyConfig 2 | import logging 3 | 4 | 5 | class BaseMid2Cache(object): 6 | default_cfg = {} 7 | 8 | def __init__(self, m2c_cfg) -> None: 9 | self.logger = logging.getLogger("edustudio") 10 | self.m2c_cfg = m2c_cfg 11 | self._check_params() 12 | 13 | def _check_params(self): 14 | pass 15 | 16 | @classmethod 17 | def from_cfg(cls, cfg: UnifyConfig): 18 | return cls(cfg.datatpl_cfg.get(cls.__name__)) 19 | 20 | @classmethod 21 | def get_default_cfg(cls, **kwargs): 22 | cfg = UnifyConfig() 23 | for _cls in cls.__mro__: 24 | if not hasattr(_cls, 'default_cfg'): 25 | break 26 | 
cfg.update(_cls.default_cfg, update_unknown_key_only=True) 27 | return cfg 28 | 29 | def process(self, **kwargs): 30 | pass 31 | 32 | def set_dt_info(self, dt_info, **kwargs): 33 | pass 34 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/common/build_cpt_relation.py: -------------------------------------------------------------------------------- 1 | from .base_mid2cache import BaseMid2Cache 2 | import pandas as pd 3 | import numpy as np 4 | from itertools import chain 5 | 6 | 7 | class M2C_BuildKCRelation(BaseMid2Cache): 8 | default_cfg = { 9 | 'relation_type': 'rcd_transition', 10 | 'threshold': None 11 | } 12 | 13 | def process(self, **kwargs): 14 | df = kwargs['df'] 15 | if df is None: 16 | df = pd.concat( 17 | [kwargs['df_train'], kwargs['df_valid'], kwargs['df_test']], 18 | axis=0, ignore_index=True 19 | ) 20 | 21 | if 'order_id:float' in df: 22 | df.sort_values(by=['order_id:float'], axis=0, ignore_index=True, inplace=True) 23 | elif 'order_id:token' in df: 24 | df.sort_values(by=['order_id:token'], axis=0, ignore_index=True, inplace=True) 25 | 26 | kwargs['df'] = df 27 | 28 | if self.m2c_cfg['relation_type'] == 'rcd_transition': 29 | return self.gen_rcd_transition_relation(kwargs) 30 | else: 31 | raise NotImplementedError 32 | 33 | def gen_rcd_transition_relation(self, kwargs): 34 | df:pd.DataFrame = kwargs['df'] 35 | df_exer = kwargs['df_exer'][['exer_id:token', 'cpt_seq:token_seq']] 36 | df = df[['stu_id:token', 'exer_id:token', 'label:float']].merge(df_exer[['exer_id:token', 'cpt_seq:token_seq']], how='left', on='exer_id:token') 37 | cpt_count = np.max(list(chain(*df_exer['cpt_seq:token_seq'].to_list()))) + 1 38 | 39 | n_mat = np.zeros((cpt_count, cpt_count), dtype=np.float32) # n_{i,j} 40 | for _, df_one_stu in df.groupby('stu_id:token'): 41 | for idx in range(df_one_stu.shape[0]): 42 | if idx == df_one_stu.shape[0] - 2: 43 | break 44 | curr_record = df_one_stu.iloc[idx] 45 | next_record = 
df_one_stu.iloc[idx + 1] 46 | if curr_record['label:float'] * next_record['label:float'] == 1: 47 | for cpt_pre in curr_record['cpt_seq:token_seq']: 48 | for cpt_next in next_record['cpt_seq:token_seq']: 49 | if cpt_pre != cpt_next: 50 | n_mat[cpt_pre, cpt_next] += 1 51 | 52 | a = np.sum(n_mat, axis=1)[:,None] 53 | nonzero_mask = (a != 0) 54 | np.seterr(divide='ignore', invalid='ignore') 55 | C_mat = np.where(nonzero_mask, n_mat / a, n_mat) 56 | 57 | max_val = C_mat.max() 58 | np.fill_diagonal(C_mat, max_val) 59 | min_val = C_mat.min() 60 | np.fill_diagonal(C_mat, 0) 61 | T_mat = (C_mat- min_val) / (max_val - min_val) 62 | 63 | threshold = self.m2c_cfg['threshold'] 64 | if threshold is None: 65 | threshold = C_mat.sum() / (C_mat != 0).sum() 66 | threshold *= threshold 67 | threshold *= threshold 68 | 69 | cpt_dep_mat = (T_mat > threshold).astype(np.int64) 70 | 71 | kwargs['cpt_dep_mat'] = cpt_dep_mat 72 | 73 | return kwargs 74 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/common/build_missing_Q.py: -------------------------------------------------------------------------------- 1 | from .base_mid2cache import BaseMid2Cache 2 | import numpy as np 3 | import pandas as pd 4 | from itertools import chain 5 | import torch 6 | from edustudio.utils.common import set_same_seeds 7 | 8 | 9 | class M2C_BuildMissingQ(BaseMid2Cache): 10 | default_cfg = { 11 | 'seed': 20230518, 12 | 'Q_delete_ratio': 0.0, 13 | } 14 | 15 | def process(self, **kwargs): 16 | dt_info = kwargs['dt_info'] 17 | self.item_count = dt_info['exer_count'] 18 | self.cpt_count = dt_info['cpt_count'] 19 | self.df_Q = kwargs['df_exer'][['exer_id:token', 'cpt_seq:token_seq']] 20 | 21 | self.missing_df_Q = self.get_missing_df_Q() 22 | self.missing_Q_mat = self.get_Q_mat_from_df_arr(self.missing_df_Q, self.item_count, self.cpt_count) 23 | 24 | kwargs['missing_df_Q'] = self.missing_df_Q 25 | kwargs['missing_Q_mat'] = self.missing_Q_mat 26 | 27 | 
return kwargs 28 | 29 | def get_missing_df_Q(self): 30 | set_same_seeds(seed=self.m2c_cfg['seed']) 31 | ratio = self.m2c_cfg['Q_delete_ratio'] 32 | iid2cptlist = self.df_Q.set_index('exer_id:token')['cpt_seq:token_seq'].to_dict() 33 | iid_lis = np.array(list(chain(*[[i]*len(iid2cptlist[i]) for i in iid2cptlist]))) 34 | cpt_lis = np.array(list(chain(*list(iid2cptlist.values())))) 35 | entry_arr = np.vstack([iid_lis, cpt_lis]).T 36 | 37 | np.random.shuffle(entry_arr) 38 | 39 | # reference: https://stackoverflow.com/questions/64834655/python-how-to-find-first-duplicated-items-in-an-numpy-array 40 | _, idx = np.unique(entry_arr[:, 1], return_index=True) # 先从每个知识点中选出1题出来 41 | bool_idx = np.zeros_like(entry_arr[:, 1], dtype=bool) 42 | bool_idx[idx] = True 43 | preserved_exers = np.unique(entry_arr[bool_idx, 0]) # 选择符合条件的习题作为保留 44 | 45 | delete_num = int(ratio * self.item_count) 46 | preserved_num = self.item_count - delete_num 47 | 48 | if len(preserved_exers) >= preserved_num: 49 | self.logger.warning( 50 | f"Cant Satisfy Delete Require: {len(preserved_exers)=},{preserved_num=}" 51 | ) 52 | else: 53 | need_preserved_num = preserved_num - len(preserved_exers) 54 | 55 | left_iids = np.arange(self.item_count) 56 | left_iids = left_iids[~np.isin(left_iids, preserved_exers)] 57 | np.random.shuffle(left_iids) 58 | choose_iids = left_iids[0:need_preserved_num] 59 | 60 | preserved_exers = np.hstack([preserved_exers, choose_iids]) 61 | 62 | return self.df_Q.copy()[self.df_Q['exer_id:token'].isin(preserved_exers)].reset_index(drop=True) 63 | 64 | 65 | def get_Q_mat_from_df_arr(self, df_Q_arr, item_count, cpt_count): 66 | Q_mat = torch.zeros((item_count, cpt_count), dtype=torch.int64) 67 | for _, item in df_Q_arr.iterrows(): Q_mat[item['exer_id:token'], item['cpt_seq:token_seq']] = 1 68 | return Q_mat 69 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/common/filtering_records_by_attr.py: 
-------------------------------------------------------------------------------- 1 | from .base_mid2cache import BaseMid2Cache 2 | import pandas as pd 3 | import numpy as np 4 | from itertools import chain 5 | 6 | 7 | class M2C_FilteringRecordsByAttr(BaseMid2Cache): 8 | """Commonly used by Fair Models, and Filtering Students without attribute values 9 | """ 10 | default_cfg = { 11 | 'filter_stu_attrs': ['gender:token'] 12 | } 13 | 14 | def process(self, **kwargs): 15 | df_stu = kwargs['df_stu'] 16 | df = kwargs['df'] 17 | df_stu = df_stu[df_stu[self.m2c_cfg['filter_stu_attrs']].notna().all(axis=1)].reset_index(drop=True) 18 | df = df[df['stu_id:token'].isin(df_stu['stu_id:token'])].reset_index(drop=True) 19 | 20 | kwargs['df'] = df 21 | kwargs['df_stu'] = df_stu 22 | 23 | return kwargs 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/common/gen_q_mat.py: -------------------------------------------------------------------------------- 1 | from .base_mid2cache import BaseMid2Cache 2 | import numpy as np 3 | from itertools import chain 4 | import torch 5 | 6 | 7 | class M2C_GenQMat(BaseMid2Cache): 8 | def process(self, **kwargs): 9 | df_exer = kwargs['df_exer'] 10 | cpt_count = len(set(list(chain(*df_exer['cpt_seq:token_seq'].to_list())))) 11 | # df_exer['cpt_multihot:token_seq'] = df_exer['cpt_seq:token_seq'].apply( 12 | # lambda x: self.multi_hot(cpt_count, np.array(x)).tolist() 13 | # ) 14 | kwargs['df_exer'] = df_exer 15 | tmp_df_exer = df_exer.set_index("exer_id:token") 16 | 17 | kwargs['Q_mat'] = torch.from_numpy(np.array( 18 | [self.multi_hot(cpt_count, tmp_df_exer.loc[exer_id]['cpt_seq:token_seq']).tolist() 19 | for exer_id in range(df_exer['exer_id:token'].max() + 1)] 20 | )) 21 | return kwargs 22 | 23 | @staticmethod 24 | def multi_hot(length, indices): 25 | multi_hot = np.zeros(length, dtype=np.int64) 26 | multi_hot[indices] = 1 27 | return multi_hot 28 | 
-------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/common/label2int.py: -------------------------------------------------------------------------------- 1 | from .base_mid2cache import BaseMid2Cache 2 | import numpy as np 3 | 4 | 5 | class M2C_Label2Int(BaseMid2Cache): 6 | def process(self, **kwargs): 7 | self.op_on_df('df', kwargs) 8 | self.op_on_df('df_train', kwargs) 9 | self.op_on_df('df_valid', kwargs) 10 | self.op_on_df('df_test', kwargs) 11 | return kwargs 12 | 13 | @staticmethod 14 | def op_on_df(column, kwargs): 15 | if column in kwargs and kwargs[column] is not None: 16 | kwargs[column]['label:float'] = (kwargs[column]['label:float'] >= 0.5).astype(np.float32) 17 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/common/merge_divided_splits.py: -------------------------------------------------------------------------------- 1 | from .base_mid2cache import BaseMid2Cache 2 | import pandas as pd 3 | from itertools import chain 4 | 5 | 6 | class M2C_MergeDividedSplits(BaseMid2Cache): 7 | default_cfg = {} 8 | 9 | def process(self, **kwargs): 10 | df_train = kwargs['df_train'] 11 | df_valid = kwargs['df_valid'] 12 | df_test = kwargs['df_test'] 13 | 14 | # 1. ensure the keys in df_train, df_valid, df_test is same 15 | assert df_train is not None and df_test is not None 16 | assert set(df_train.columns) == set(df_test.columns) 17 | # 2. 
包容df_valid不存在的情况 18 | 19 | df = pd.concat((df_train, df_test), ignore_index=True) 20 | 21 | if df_valid is not None: 22 | assert set(df_train.columns) == set(df_valid.columns) 23 | df = pd.concat((df, df_valid), ignore_index=True) 24 | 25 | kwargs['df'] = df 26 | return kwargs 27 | 28 | def set_dt_info(self, dt_info, **kwargs): 29 | if 'stu_id:token' in kwargs['df'].columns: 30 | dt_info['stu_count'] = int(kwargs['df']['stu_id:token'].max() + 1) 31 | if 'exer_id:token' in kwargs['df'].columns: 32 | dt_info['exer_count'] = int(kwargs['df']['exer_id:token'].max() + 1) 33 | if kwargs.get('df_exer', None) is not None: 34 | if 'cpt_seq:token_seq' in kwargs['df_exer']: 35 | dt_info['cpt_count'] = len(set(list(chain(*kwargs['df_exer']['cpt_seq:token_seq'].to_list())))) 36 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/single/M2C_CNCDQ_OP.py: -------------------------------------------------------------------------------- 1 | from ..common import BaseMid2Cache 2 | import torch 3 | 4 | class M2C_CNCDQ_OP(BaseMid2Cache): 5 | default_cfg = {} 6 | 7 | def process(self, **kwargs): 8 | df_exer = kwargs['df_exer'] 9 | exer_count = kwargs['dt_info']['exer_count'] 10 | cpt_count = kwargs['dt_info']['cpt_count'] 11 | kwargs['Q_mask_mat'], kwargs['knowledge_pairs'] = self._get_knowledge_pairs( 12 | df_Q_arr=df_exer, exer_count=exer_count, cpt_count=cpt_count 13 | ) 14 | return kwargs 15 | 16 | def _get_knowledge_pairs(self, df_Q_arr, exer_count, cpt_count): 17 | Q_mask_mat = torch.zeros((exer_count, cpt_count), dtype=torch.int64) 18 | knowledge_pairs = [] 19 | kn_tags = [] 20 | kn_topks = [] 21 | for _, item in df_Q_arr.iterrows(): 22 | # kn_tags.append(item['cpt_seq']) 23 | # kn_topks.append(item['cpt_pre_seq']) 24 | kn_tags = item['cpt_seq'] 25 | kn_topks = item['cpt_pre_seq'] 26 | knowledge_pairs.append((kn_tags, kn_topks)) 27 | for cpt_id in item['cpt_seq']: 28 | Q_mask_mat[item['exer_id'], cpt_id-1] = 1 29 | for 
cpt_id in item['cpt_pre_seq']: 30 | Q_mask_mat[item['exer_id'], cpt_id-1] = 1 31 | return Q_mask_mat, knowledge_pairs 32 | -------------------------------------------------------------------------------- /edustudio/atom_op/mid2cache/single/M2C_DIMKT_OP.py: -------------------------------------------------------------------------------- 1 | from ..common import BaseMid2Cache 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class M2C_DIMKT_OP(BaseMid2Cache): 7 | default_cfg = {} 8 | 9 | def process(self, **kwargs): 10 | self.df_Q = kwargs['df_exer'] 11 | dt_info = kwargs['dt_info'] 12 | self.num_q = dt_info['exer_count'] 13 | self.num_c = dt_info['cpt_count'] 14 | 15 | df_train_folds = kwargs['df_train_folds'] 16 | 17 | kwargs['q_diff_list'] = [] 18 | kwargs['c_diff_list'] = [] 19 | for train_dict in df_train_folds: 20 | self.train_dict = train_dict 21 | self.compute_difficulty() 22 | kwargs['q_diff_list'].append(self.q_dif) 23 | kwargs['c_diff_list'].append(self.c_dif) 24 | return kwargs 25 | 26 | def compute_difficulty(self): 27 | q_dict = dict(zip(self.df_Q['exer_id:token'], self.df_Q['cpt_seq:token_seq'])) 28 | qd={} 29 | qd_count={} 30 | cd={} 31 | cd_count={} 32 | exer_ids = self.train_dict['exer_seq:token_seq'] 33 | label_ids = self.train_dict['label_seq:float_seq'] 34 | mask_ids = self.train_dict['mask_seq:token_seq'] 35 | for ii, ee in enumerate(exer_ids): 36 | for i, e in enumerate(ee): 37 | tmp_mask = mask_ids[ii, i] 38 | if tmp_mask != 0: 39 | tmp_exer = exer_ids[ii, i] 40 | tmp_label = label_ids[ii, i] 41 | cpt = (q_dict[tmp_exer])[0] 42 | cd[cpt] = cd.get(cpt, 0) + tmp_label 43 | cd_count[cpt] = cd_count.get(cpt, 0) + 1 44 | if tmp_exer in qd: 45 | qd[tmp_exer] = qd[tmp_exer] + tmp_label 46 | qd_count[tmp_exer] = qd_count[tmp_exer]+1 47 | else: 48 | qd[tmp_exer] = tmp_label 49 | qd_count[tmp_exer] = 1 50 | else: 51 | break 52 | 53 | 54 | self.q_dif = np.ones(self.num_q) 55 | self.c_dif = np.ones(self.num_c) 56 | for k,v in qd.items(): 57 | 
from ..common import BaseMid2Cache
from itertools import chain
import torch
import pandas as pd
import numpy as np
from edustudio.datatpl.utils import PadSeqUtil
import copy


class M2C_EERNN_OP(BaseMid2Cache):
    """Mid2Cache atomic op for EERNN.

    For each training fold, trains a Word2Vec model on the exercise texts seen
    in that fold and produces (a) a word-id -> embedding dict and (b) a padded
    exercise-content id matrix.
    """

    default_cfg = {
        'word_emb_dim': 50,       # dimensionality of the trained word vectors
        'max_sentence_len': 100,  # exercise texts are padded/truncated to this length
    }

    def process(self, **kwargs):
        """Attach ``word_emb_dict_list`` / ``content_mat_list`` (one entry per
        training fold) to kwargs.

        NOTE(review): ``df_exer['content:token_seq']`` is remapped in place
        inside the fold loop, so folds after the first operate on already
        id-mapped/padded content — presumably intentional only for a single
        fold; verify when n_folds > 1.
        """
        dt_info = kwargs['dt_info']
        train_dict_list = kwargs['df_train_folds']
        df_exer = kwargs['df_exer']
        assert len(train_dict_list) > 0
        # gensim is imported here (not at module top), so it is only required
        # when this op actually runs
        import gensim

        word_emb_dict_list,content_mat_list = [], []
        for train_dict in train_dict_list:
            # only fit word vectors on exercises that occur in this fold's training data
            exers = np.unique(train_dict['exer_seq:token_seq'])
            sentences = df_exer[df_exer['exer_id:token'].isin(exers)]['content:token_seq'].tolist()
            # desc_dict = gensim.corpora.Dictionary(sentences)
            # word_set = list(desc_dict.token2id.keys())
            # word2id = {w:i+1 for i,w in enumerate(word_set)}
            model = gensim.models.word2vec.Word2Vec(
                sentences, vector_size=self.m2c_cfg['word_emb_dim'],
            )
            wv_from_bin = model.wv

            # shift all word ids by +1 so that id 0 is reserved for padding/unknown
            word2id = copy.deepcopy(model.wv.key_to_index)
            word2id = {k: word2id[k] + 1 for k in word2id}
            word_emb_dict = {word2id[key]: wv_from_bin[key] for key in word2id}
            word_emb_dict[0] = np.zeros(shape=(self.m2c_cfg['word_emb_dim'], ))

            # Replace words absent from the trained vocabulary with id 0 in the
            # train/valid/test exercise texts, then pad/truncate to fixed length
            df_exer['content:token_seq'] = df_exer['content:token_seq'].apply(lambda x: [word2id.get(xx, 0) for xx in x])
            pad_mat, _, _ = PadSeqUtil.pad_sequence(
                df_exer['content:token_seq'].tolist(), maxlen=self.m2c_cfg['max_sentence_len'], padding='post',
                is_truncate=True, truncating='post', value=0,
            )
            df_exer['content:token_seq'] = [pad_mat[i].tolist() for i in range(pad_mat.shape[0])]

            # stack rows in exercise-id order so row e is the content of exercise e
            tmp_df_Q = df_exer.set_index('exer_id:token')
            content_mat = torch.from_numpy(np.vstack(
                [tmp_df_Q.loc[exer_id]['content:token_seq'] for exer_id in range(dt_info['exer_count'])]
            ))

            word_emb_dict_list.append(word_emb_dict)
            content_mat_list.append(content_mat)
        kwargs['word_emb_dict_list'] = word_emb_dict_list
        kwargs['content_mat_list'] = content_mat_list
        return kwargs

    def set_dt_info(self, dt_info, **kwargs):
        # expose the embedding size to downstream model templates
        dt_info['word_emb_dim'] = self.m2c_cfg['word_emb_dim']
kwargs['df_stu'].rename(columns={self.m2c_cfg['group_id_field']: 'group_id:token'}) 32 | self.group_count = df_inter_group['group_id:token'].nunique() 33 | return kwargs 34 | 35 | def get_df_group_and_df_inter(self, df_stu:pd.DataFrame, df_inter:pd.DataFrame): 36 | df_inter = df_inter.merge(df_stu[['stu_id:token', self.m2c_cfg['group_id_field']]], on=['stu_id:token'], how='left') # 两个表合并,这样inter里也有class_id 37 | 38 | df_inter_group = pd.DataFrame() 39 | df_inter_stu = pd.DataFrame() 40 | 41 | for _, inter_g in df_inter.groupby(self.m2c_cfg['group_id_field']): 42 | exers_list = inter_g[['stu_id:token', 'exer_id:token']].groupby('stu_id:token').agg(set)['exer_id:token'].tolist() 43 | inter_exer_set = None # 选择所有学生都做的题目 44 | for exers in exers_list: inter_exer_set = exers if inter_exer_set is None else (inter_exer_set & exers) 45 | inter_exer_set = np.array(list(inter_exer_set)) 46 | if inter_exer_set.shape[0] >= self.m2c_cfg['min_inter']: 47 | tmp_group_df = inter_g[inter_g['exer_id:token'].isin(inter_exer_set)] 48 | tmp_stu_df = inter_g[~inter_g['exer_id:token'].isin(inter_exer_set)] 49 | 50 | df_inter_group = pd.concat([df_inter_group, tmp_group_df], ignore_index=True, axis=0) 51 | df_inter_stu = pd.concat([df_inter_stu, tmp_stu_df], ignore_index=True, axis=0) 52 | else: 53 | df_stu = df_stu[df_stu['class_id:token'] != inter_g['class_id:token'].values[0]] 54 | 55 | df_inter_group = df_inter_group[['label:float', 'exer_id:token', self.m2c_cfg['group_id_field']]].groupby( 56 | [self.m2c_cfg['group_id_field'], 'exer_id:token'] 57 | ).agg('mean').reset_index().rename(columns={self.m2c_cfg['group_id_field']: 'group_id:token'}) 58 | 59 | return df_inter_group, df_inter_stu, df_stu[df_stu['class_id:token'].isin(df_inter_group['group_id:token'].unique())] 60 | 61 | def set_dt_info(self, dt_info, **kwargs): 62 | dt_info['group_count'] = self.group_count -------------------------------------------------------------------------------- 
import networkx as nx
from ..common import BaseMid2Cache
import torch
from torch.nn import functional as F
import numpy as np


class M2C_QDKT_OP(BaseMid2Cache):
    """Mid2Cache atomic op for QDKT.

    Builds a Laplacian-style matrix over exercises, where two exercises are
    adjacent when their Q-matrix rows are (near-)identical.
    """

    default_cfg = {}

    def process(self, **kwargs):
        """Attach ``laplacian_matrix`` (exer_count x exer_count, int) to kwargs."""
        self.df_Q = kwargs['df_exer']
        dt_info = kwargs['dt_info']
        self.num_q = dt_info['exer_count']
        self.num_c = dt_info['cpt_count']
        self.Q_mat = kwargs['Q_mat']
        laplacian_matrix = self.laplacian_matrix_by_vectorization()
        kwargs['laplacian_matrix'] = laplacian_matrix
        return kwargs

    def laplacian_matrix_by_vectorization(self):
        """Vectorized equivalent of ``generate_graph`` + ``laplacian_matrix``.

        Adjacency: cosine similarity of the L2-normalized Q rows above
        1 - 1/n, i.e. effectively identical rows (self-loops included).
        The diagonal is overwritten with (degree - self_loop); off-diagonal
        entries stay +1 — sign convention preserved from the original code.
        """
        normQ = F.normalize(self.Q_mat.float(), p=2, dim=-1)
        A = (torch.mm(normQ, normQ.T) > (1 - 1 / len(normQ))).int()  # adjacency matrix
        D = A.sum(-1, dtype=torch.int32)  # per-node degree (incl. self-loop)
        # Fix: address the diagonal with torch.arange instead of the original
        # [range(n), range(n)] list-of-range fancy indexing, which relied on
        # implicit conversion behavior; result is identical.
        diag = torch.arange(len(A))
        A[diag, diag] = D - A[diag, diag]
        return A

    def generate_graph(self):
        """Reference (non-vectorized) construction: nodes are 1-based exercise
        ids; an edge joins exercises with identical Q-matrix rows."""
        graph = nx.Graph()
        len1 = len(self.Q_mat)
        graph.add_nodes_from([i for i in range(1, len1 + 1)])
        for index in range(len1 - 1):
            for bindex in range(index + 1, len1):
                if not (False in (self.Q_mat[index, :] == self.Q_mat[bindex, :]).tolist()):
                    graph.add_edge(index + 1, bindex + 1)
        return graph

    # Laplacian of a graph: L = D - A
    def laplacian_matrix(self, graph):
        """Reference (non-vectorized) L = D - A via networkx."""
        # adjacency matrix, negated so adding the degree yields D - A
        A = np.array(nx.adjacency_matrix(graph).todense())
        A = -A
        for i in range(len(A)):
            # node degree; ids start at 1 (use i instead of i+1 for 0-based graphs)
            degree_i = graph.degree(i + 1)
            A[i][i] = A[i][i] + degree_i
        return A
import importlib


class M2C_RCD_OP(BaseMid2Cache):
    """Mid2Cache atomic op for RCD.

    Builds the dgl graphs that form RCD's ``local_map``: concept->concept
    (directed and undirected), concept<->exercise, and per-training-fold
    student<->exercise graphs.
    """

    default_cfg = {}

    def process(self, **kwargs):
        """Attach ``local_map`` (dict of dgl graphs) to kwargs."""
        # dgl is imported here via importlib (not at module top), so it is only
        # required when this op actually runs
        self.dgl = importlib.import_module("dgl")
        k_e, e_k = self.build_k_e(kwargs)
        u_e_list, e_u_list = self.build_u_e(kwargs)
        local_map = {
            'directed_g': self.build_cpt_directd(kwargs),
            'undirected_g': self.build_cpt_undirected(kwargs),
            'k_from_e': k_e,
            'e_from_k': e_k,
            'u_from_e_list': u_e_list,
            'e_from_u_list': e_u_list,
        }
        kwargs['local_map'] = local_map
        return kwargs

    def build_cpt_undirected(self, kwargs):
        """Undirected concept graph: keep an edge only where the dependency
        holds in both directions (mat + mat.T == 2)."""
        cpt_count = kwargs['dt_info']['cpt_count']
        cpt_dep_mat = kwargs['cpt_dep_mat']
        cpt_dep_mat_undirect = ((cpt_dep_mat + cpt_dep_mat.T) == 2).astype(np.int64)
        # undirected (only prerequisite)
        g_undirected = self.dgl.graph(np.argwhere(cpt_dep_mat_undirect == 1).tolist(), num_nodes=cpt_count)
        return g_undirected

    def build_cpt_directd(self, kwargs):
        """Directed concept graph taken straight from the dependency matrix."""
        cpt_count = kwargs['dt_info']['cpt_count']
        cpt_dep_mat = kwargs['cpt_dep_mat']
        # directed (prerequisite + similarity)
        g_directed =self.dgl.graph(np.argwhere(cpt_dep_mat == 1).tolist(), num_nodes=cpt_count)
        return g_directed

    def build_k_e(self, kwargs):
        """Bipartite exercise/concept graphs (both directions).

        Concept node ids are offset by ``exer_count`` so exercises and
        concepts share a single node-id space.
        """
        cpt_count = kwargs['dt_info']['cpt_count']
        exer_count = kwargs['dt_info']['exer_count']
        df_exer = kwargs['df_exer']

        # one (exer_id, cpt_id) edge per concept tagged on each exercise
        edges = df_exer[['exer_id:token','cpt_seq:token_seq']].explode('cpt_seq:token_seq').to_numpy()
        edges[:, 1] += exer_count  # shift concept ids past the exercise id range

        k_e = self.dgl.graph(edges.tolist(), num_nodes=cpt_count + exer_count)
        e_k = self.dgl.graph(edges[:,[1,0]].tolist(), num_nodes=cpt_count + exer_count)
        return k_e, e_k

    def build_u_e(self, kwargs):
        """Per-training-fold student/exercise graphs (both directions).

        Student node ids are offset by ``exer_count``, mirroring the id-space
        convention of :meth:`build_k_e`.
        """
        stu_count = kwargs['dt_info']['stu_count']
        exer_count = kwargs['dt_info']['exer_count']
        df_train_folds = kwargs['df_train_folds']

        u_from_e_list= []
        e_from_u_list = []
        for train_df in df_train_folds:
            stu_id = train_df['stu_id:token'].to_numpy() + exer_count
            exer_id = train_df['exer_id:token'].to_numpy()
            u_e = self.dgl.graph(np.vstack([exer_id, stu_id]).T.tolist(), num_nodes=stu_count + exer_count)
            e_u = self.dgl.graph(np.vstack([stu_id, exer_id]).T.tolist(), num_nodes=stu_count + exer_count)
            u_from_e_list.append(u_e)
            e_from_u_list.append(e_u)
        return u_from_e_list, e_from_u_list
import pandas as pd
from .raw2mid import BaseRaw2Mid

r"""
R2M_ASSIST_1516
########################
"""


class R2M_ASSIST_1516(BaseRaw2Mid):
    """R2M_ASSIST_1516 is a class used to handle the ASSISTment 2015-2016 dataset."""

    def process(self):
        """Convert the raw ASSISTments 2015-2016 CSV into middata format.

        Only the interaction file is written; the student/exercise frames are
        built but their export is currently commented out.
        """
        super().process()
        pd.set_option("mode.chained_assignment", None)  # ignore warning
        # Read the raw data
        raw_data = pd.read_csv(f"{self.rawpath}/2015_100_skill_builders_main_problems.csv", encoding='utf-8')

        # Interaction info:
        # sort by log_id to derive the chronological 'order' sequence
        inter = pd.DataFrame.copy(raw_data).sort_values(by='log_id', ascending=True)
        inter['order'] = range(len(inter))
        inter['label'] = inter['correct']
        df_inter = inter.rename(columns={'sequence_id': 'exer_id', 'user_id': 'stu_id'}).reindex(
            columns=['stu_id', 'exer_id', 'label', 'order', ]).rename(
            columns={'stu_id': 'stu_id:token', 'exer_id': 'exer_id:token', 'label': 'label:float',
                     'order': 'order_id:float'})

        # Student info
        stu = pd.DataFrame(set(raw_data['user_id']), columns=['stu_id', ])
        # stu['classId'] = None
        # stu['gender'] = None
        df_stu = stu.sort_values(by='stu_id', ascending=True)

        # Exercise info
        exer = pd.DataFrame(set(raw_data['sequence_id']), columns=['exer_id'])
        # exer['cpt_seq'] = None
        # exer['assignment_id'] = None
        df_exer = exer.sort_values(by='exer_id', ascending=True)

        # Save the result to `self.midpath`

        df_inter.to_csv(f"{self.midpath}/{self.dt}.inter.csv", index=False, encoding='utf-8')
        # df_stu.to_csv(f"{self.midpath}/{self.dt}.stu.csv", index=False, encoding='utf-8')
        # df_exer.to_csv(f"{self.midpath}/{self.dt}.exer.csv", index=False, encoding='utf-8')
        pd.set_option("mode.chained_assignment", "warn")  # restore the default warning behavior
        return
# Preprocess 16 | # 17 | # 此处对数据集进行处理 18 | 19 | # 读取文本文件转换为 dataframe 20 | df_inter = pd.read_csv(f"{self.rawpath}/data.txt", sep='\t', names=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 21 | '10', '11', '12', '13', '14', '15', '16', '17', '18', '19']) 22 | df_inter.insert(0, 'stu_id:token', range(len(df_inter))) 23 | df_exer = pd.read_csv(f"{self.rawpath}/q.txt", sep='\t', names=['0', '1', '2', '3', '4', '5', '6', '7']) 24 | df_exer.insert(0, 'exer_id:token', range(len(df_exer))) 25 | # print(df_inter) 26 | # print(df_exer) 27 | 28 | # 统计知识点个数 29 | cpt_count = df_exer.shape[1] 30 | 31 | # 处理 df_inter 拆成很多行 32 | df_inter = df_inter.melt(id_vars=['stu_id:token'], var_name='exer_id:token', value_name='label:float') 33 | df_inter['exer_id:token'] = df_inter['exer_id:token'].astype(int) 34 | df_inter .sort_values(by = ['stu_id:token','exer_id:token'], inplace=True) 35 | # print(df_inter) 36 | 37 | # 处理 df_exer 先拆成很多行,再合并 38 | 39 | # 拆成很多行 40 | df_exer = df_exer.melt(id_vars=['exer_id:token'], var_name='cpt_seq:token_seq', value_name='value') 41 | df_exer = df_exer[df_exer['value'] == 1] 42 | del df_exer['value'] 43 | 44 | # 合并 cpt_seq:token_seq 45 | df_exer['cpt_seq:token_seq'] = df_exer['cpt_seq:token_seq'].astype(int) 46 | df_exer.sort_values(by='cpt_seq:token_seq', inplace=True) 47 | df_exer['exer_id:token'] = df_exer['exer_id:token'].astype(str) 48 | df_exer['cpt_seq:token_seq'] = df_exer['cpt_seq:token_seq'].astype(str) 49 | df_exer = df_exer.groupby('exer_id:token')['cpt_seq:token_seq'].agg(','.join).reset_index() 50 | 51 | # 按 exer_id:token 进行排序 52 | df_exer['exer_id:token'] = df_exer['exer_id:token'].astype(int) 53 | df_exer.sort_values(by='exer_id:token', inplace=True) 54 | # print(df_exer) 55 | 56 | # # Stat dataset 57 | # 58 | # 此处统计数据集,保存到cfg对象中 59 | 60 | # cfg.stu_count = len(df_inter['stu_id:token'].unique()) 61 | # cfg.exer_count = len(df_exer) 62 | # cfg.cpt_count = cpt_count 63 | # cfg.interaction_count = len(df_inter) 64 | # cfg 65 | 66 | # # 
from .raw2mid import BaseRaw2Mid
import pandas as pd
r"""
R2M_Math1
#####################################
Math1 dataset preprocess
"""


class R2M_Math1(BaseRaw2Mid):
    """R2M_Math1 is to preprocess Math1 dataset"""
    def process(self):
        """Convert the raw Math1 matrix files into middata interaction and
        exercise CSVs."""
        super().process()
        # Read the plain-text matrices into dataframes (one column per exercise / concept)
        df_inter = pd.read_csv(f"{self.rawpath}/data.txt", sep='\t', names=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                                                                           '10', '11', '12', '13', '14', '15', '16', '17', '18', '19'])
        df_exer = pd.read_csv(f"{self.rawpath}/q.txt", sep='\t', names=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9','10'])

        # df_inter: melt the wide student-by-exercise matrix into one row per interaction
        df_inter.insert(0, 'stu_id:token', range(len(df_inter)))
        df_inter = df_inter.melt(id_vars=['stu_id:token'], var_name='exer_id:token', value_name='label:float')
        df_inter['exer_id:token'] = df_inter['exer_id:token'].astype(int)
        df_inter .sort_values(by = ['stu_id:token','exer_id:token'],inplace=True)

        # df_exer: first melt into one row per (exercise, concept), then merge back

        # melt into rows, keeping only cells marked 1 (concept tagged on exercise)
        df_exer.insert(0, 'exer_id:token', range(len(df_exer)))
        df_exer = df_exer.melt(id_vars=['exer_id:token'], var_name='cpt_seq:token_seq', value_name='value')
        df_exer = df_exer[df_exer['value'] == 1]
        del df_exer['value']

        # merge concept ids into a comma-separated cpt_seq:token_seq per exercise
        df_exer['cpt_seq:token_seq'] = df_exer['cpt_seq:token_seq'].astype(int)
        df_exer.sort_values(by='cpt_seq:token_seq', inplace=True)
        df_exer['exer_id:token'] = df_exer['exer_id:token'].astype(str)
        df_exer['cpt_seq:token_seq'] = df_exer['cpt_seq:token_seq'].astype(str)
        df_exer = df_exer.groupby('exer_id:token')['cpt_seq:token_seq'].agg(','.join).reset_index()

        # sort by exer_id:token
        df_exer['exer_id:token'] = df_exer['exer_id:token'].astype(int)
        df_exer.sort_values(by='exer_id:token', inplace=True)

        # Save the result to `self.midpath`

        df_inter.to_csv(f"{self.midpath}/{self.dt}.inter.csv", index=False, encoding='utf-8')
        df_exer.to_csv(f"{self.midpath}/{self.dt}.exer.csv", index=False, encoding='utf-8')
value_name='value') 32 | df_exer = df_exer[df_exer['value'] == 1] 33 | del df_exer['value'] 34 | 35 | # 合并 cpt_seq:token_seq 36 | df_exer['cpt_seq:token_seq'] = df_exer['cpt_seq:token_seq'].astype(int) 37 | df_exer.sort_values(by='cpt_seq:token_seq', inplace=True) 38 | df_exer['exer_id:token'] = df_exer['exer_id:token'].astype(str) 39 | df_exer['cpt_seq:token_seq'] = df_exer['cpt_seq:token_seq'].astype(str) 40 | df_exer = df_exer.groupby('exer_id:token')['cpt_seq:token_seq'].agg(','.join).reset_index() 41 | 42 | # 按 exer_id:token 进行排序 43 | df_exer['exer_id:token'] = df_exer['exer_id:token'].astype(int) 44 | df_exer.sort_values(by='exer_id:token', inplace=True) 45 | 46 | # 保存数据 47 | df_inter.to_csv(f"{self.midpath}/{self.dt}.inter.csv", index=False, encoding='utf-8') 48 | df_exer.to_csv(f"{self.midpath}/{self.dt}.exer.csv", index=False, encoding='utf-8') -------------------------------------------------------------------------------- /edustudio/atom_op/raw2mid/raw2mid.py: -------------------------------------------------------------------------------- 1 | from edustudio.utils.common import UnifyConfig 2 | import logging 3 | import os 4 | 5 | 6 | class BaseRaw2Mid(object): 7 | def __init__(self, dt, rawpath, midpath) -> None: 8 | self.dt = dt 9 | self.rawpath = rawpath 10 | self.midpath = midpath 11 | self.logger = logging.getLogger("edustudio") 12 | if not os.path.exists(self.midpath): 13 | os.makedirs(self.midpath) 14 | 15 | @classmethod 16 | def from_cfg(cls, cfg: UnifyConfig): 17 | rawdata_folder_path = f"{cfg.frame_cfg.data_folder_path}/rawdata" 18 | middata_folder_path = f"{cfg.frame_cfg.data_folder_path}/middata" 19 | dt = cfg.dataset 20 | return cls(dt, rawdata_folder_path, middata_folder_path) 21 | 22 | def process(self, **kwargs): 23 | self.logger.info(f"{self.__class__.__name__} start !") 24 | 25 | @classmethod 26 | def from_cli(cls, dt, rawpath="./rawpath", midpath="./midpath"): 27 | obj = cls(dt, rawpath, midpath) 28 | obj.process() 29 | 30 | # 
__cli_api_dict__ = { 31 | # 'BaseRaw2Mid': BaseRaw2Mid.from_cli 32 | # } 33 | -------------------------------------------------------------------------------- /edustudio/datatpl/CD/CDGKDataTPL.py: -------------------------------------------------------------------------------- 1 | from ..common import EduDataTPL 2 | 3 | 4 | class CDGKDataTPL(EduDataTPL): 5 | default_cfg = { 6 | 'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat', 'M2C_CDGK_OP'], 7 | } 8 | 9 | # @classmethod 10 | # def load_data(cls, cfg): # 只在middata存在时调用 11 | # kwargs = super().load_data(cfg) 12 | # if cfg.datatpl_cfg['has_cpt2group_file'] is True: 13 | # new_kwargs = cls.load_cpt2group(cfg) 14 | # for df in new_kwargs.values(): 15 | # if df is not None: 16 | # cls._preprocess_feat(df) # 类型转换 17 | # kwargs.update(new_kwargs) 18 | # else: 19 | # kwargs['df_cpt2group'] = None 20 | # return kwargs 21 | 22 | 23 | # @classmethod 24 | # def load_cpt2group(cls, cfg): 25 | # cpt2group_fph = f'{cfg.frame_cfg.data_folder_path}/middata/{cfg.dataset}.cpt2group.csv' 26 | # sep = cfg.datatpl_cfg['seperator'] 27 | # df_cpt2group = cls._load_atomic_csv(cpt2group_fph, sep=sep) 28 | # return {"df_cpt2group": df_cpt2group} 29 | -------------------------------------------------------------------------------- /edustudio/datatpl/CD/CDInterDataTPL.py: -------------------------------------------------------------------------------- 1 | from ..common import GeneralDataTPL 2 | 3 | class CDInterDataTPL(GeneralDataTPL): 4 | default_cfg = { 5 | 'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD'], 6 | } 7 | 8 | -------------------------------------------------------------------------------- /edustudio/datatpl/CD/CDInterExtendsQDataTPL.py: -------------------------------------------------------------------------------- 1 | from ..common import EduDataTPL 2 | 3 | class CDInterExtendsQDataTPL(EduDataTPL): 4 | 
default_cfg = { 5 | 'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat'], 6 | } 7 | -------------------------------------------------------------------------------- /edustudio/datatpl/CD/CNCDFDataTPL.py: -------------------------------------------------------------------------------- 1 | from ..common import EduDataTPL 2 | 3 | class CNCDFDataTPL(EduDataTPL): 4 | default_cfg = { 5 | 'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat'], 6 | } 7 | 8 | def get_extra_data(self): 9 | return { 10 | 'content': self.df_exer['content:token_seq'].to_list(), 11 | 'Q_mat': self.final_kwargs['Q_mat'] 12 | } 13 | -------------------------------------------------------------------------------- /edustudio/datatpl/CD/CNCDQDataTPL.py: -------------------------------------------------------------------------------- 1 | from ..common import EduDataTPL 2 | import torch 3 | import pandas as pd 4 | import os 5 | 6 | class CNCDQDataTPL(EduDataTPL): 7 | default_cfg = { 8 | 'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat'], 9 | } 10 | 11 | @property 12 | def common_str2df(self): 13 | return { 14 | "df": self.df, "df_train": self.df_train, "df_valid": self.df_valid, 15 | "df_test": self.df_test, "dt_info": self.datatpl_cfg['dt_info'], 16 | "df_stu": self.df_stu, "df_exer": self.df_exer, 17 | "df_questionnaire": self.df_questionnaire 18 | } 19 | 20 | @classmethod 21 | def load_data(cls, cfg): # 只在middata存在时调用 22 | kwargs = super().load_data(cfg) 23 | new_kwargs = cls.load_data_from_questionnaire(cfg) 24 | for df in new_kwargs.values(): 25 | if df is not None: 26 | cls._preprocess_feat(df) # 类型转换 27 | kwargs.update(new_kwargs) 28 | return kwargs 29 | 30 | @classmethod 31 | def load_data_from_questionnaire(cls, cfg): 32 | file_path = 
import os
from ..common.edu_datatpl import EduDataTPL
import json
from edustudio.datatpl.common.general_datatpl import DataTPLStatus
import torch


class DCDDataTPL(EduDataTPL):
    """Data template for DCD.

    Extends EduDataTPL with per-fold +/-1 interaction matrices, missing-Q
    handling, and an optional concept-relation dictionary loaded from json.
    """

    default_cfg = {
        'n_folds': 5,
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat', 'M2C_BuildMissingQ', 'M2C_FillMissingQ'],
        'cpt_relation_file_name': 'cpt_relation',
    }

    def __init__(self, cfg, df, df_train=None, df_valid=None, df_test=None, dict_cpt_relation=None, status=None, df_stu=None, df_exer=None):
        # Fix: the default used to be `status=DataTPLStatus()`, a mutable
        # default shared across every instance; use a None sentinel and
        # create a fresh status per instance instead.
        self.dict_cpt_relation = dict_cpt_relation
        if status is None:
            status = DataTPLStatus()
        super().__init__(cfg, df, df_train, df_valid, df_test, df_stu, df_exer, status)

    def _check_params(self):
        """Validate template parameters (Q_delete_ratio must lie in [0, 1)).

        Fix: this hook was misspelled ``_check_param`` while calling
        ``super()._check_params()``, so it never actually overrode the base
        hook and the ratio check never ran.
        """
        super()._check_params()
        assert 0 <= self.datatpl_cfg['Q_delete_ratio'] < 1

    @property
    def common_str2df(self):
        # expose the concept-relation dict to the mid2cache pipeline
        dic = super().common_str2df
        dic['dict_cpt_relation'] = self.dict_cpt_relation
        return dic

    def process_data(self):
        """Run the standard pipeline, then build one interaction matrix per
        training fold and fall back to identity concept relations."""
        super().process_data()
        dt_info = self.final_kwargs['dt_info']
        user_count = dt_info['stu_count']
        item_count = dt_info['exer_count']
        self.interact_mat_list = []
        for interact_df in self.final_kwargs['df_train_folds']:
            # +1 for correct answers, -1 for the rest, 0 = unobserved
            interact_mat = torch.zeros((user_count, item_count), dtype=torch.int8)
            idx = interact_df[interact_df['label:float'] == 1][['stu_id:token','exer_id:token']].to_numpy()
            interact_mat[idx[:,0], idx[:,1]] = 1
            idx = interact_df[interact_df['label:float'] != 1][['stu_id:token','exer_id:token']].to_numpy()
            interact_mat[idx[:,0], idx[:,1]] = -1
            self.interact_mat_list.append(interact_mat)

        self.final_kwargs['interact_mat_list'] = self.interact_mat_list

        if self.final_kwargs['dict_cpt_relation'] is None:
            # no relation file was loaded: each concept relates only to itself
            self.final_kwargs['dict_cpt_relation'] = {i: [i] for i in range(self.final_kwargs['dt_info']['cpt_count'])}

    @classmethod
    def load_data(cls, cfg):
        """Load middata plus the optional ``cpt_relation`` json (warns and
        uses None when the file is absent)."""
        kwargs = super().load_data(cfg)
        fph = f"{cfg.frame_cfg.data_folder_path}/middata/{cfg.datatpl_cfg['cpt_relation_file_name']}.json"
        if os.path.exists(fph):
            with open(fph, 'r', encoding='utf-8') as f:
                kwargs['dict_cpt_relation'] = json.load(f)
        else:
            cfg.logger.warning("without cpt_relation.json")
            kwargs['dict_cpt_relation'] = None
        return kwargs

    def get_extra_data(self):
        """Expose the current fold's filled Q matrix and interaction matrix."""
        extra_dict = super().get_extra_data()
        extra_dict['filling_Q_mat'] = self.filling_Q_mat
        extra_dict['interact_mat'] = self.interact_mat
        return extra_dict

    def set_info_for_fold(self, fold_id):
        """Select the per-fold matrices referenced by :meth:`get_extra_data`."""
        super().set_info_for_fold(fold_id)
        self.filling_Q_mat = self.final_kwargs['filling_Q_mat_list'][fold_id]
        self.interact_mat = self.final_kwargs['interact_mat_list'][fold_id]
super().process_load_data_from_middata() 17 | self.final_kwargs['qqq_group_list'] = self.read_QQQ_group(self.cfg) 18 | qqq_list = self.df_stu['qqq_seq:token_seq'].to_list() 19 | self.QQQ_list = torch.tensor(qqq_list) 20 | self.final_kwargs['qqq_list'] = self.QQQ_list 21 | qqq_count = torch.max(self.QQQ_list) + 1 22 | self.cfg['datatpl_cfg']['dt_info']['qqq_count'] = qqq_count 23 | 24 | def get_extra_data(self): 25 | return { 26 | 'qqq_group_list': self.final_kwargs['qqq_group_list'], 27 | 'Q_mat': self.final_kwargs['Q_mat'], 28 | 'qqq_list': self.final_kwargs['qqq_list'] 29 | } 30 | 31 | def read_QQQ_group(self, cfg): 32 | group_path = f'{cfg.frame_cfg.data_folder_path}/middata/{cfg.dataset}_QQQ-group.csv' 33 | assert os.path.exists(group_path) 34 | df_QQQ_group = pd.read_csv(group_path, encoding='utf-8', usecols=['qqq_id:token', 'group_id:token']) 35 | gp = df_QQQ_group.groupby("group_id:token")['qqq_id:token'] 36 | gps = gp.groups 37 | gps_list = [] 38 | for k, v in gps.items(): 39 | gps_list.append(list(v)) 40 | return gps_list 41 | -------------------------------------------------------------------------------- /edustudio/datatpl/CD/FAIRDataTPL.py: -------------------------------------------------------------------------------- 1 | from ..common import EduDataTPL 2 | 3 | class FAIRDataTPL(EduDataTPL): 4 | default_cfg = { 5 | 'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilteringRecordsByAttr', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat'], 6 | } 7 | 8 | -------------------------------------------------------------------------------- /edustudio/datatpl/CD/HierCDFDataTPL.py: -------------------------------------------------------------------------------- 1 | from ..common import EduDataTPL 2 | import torch 3 | from typing import Dict 4 | import pandas as pd 5 | import os 6 | 7 | class HierCDFDataTPL(EduDataTPL): 8 | default_cfg = { 9 | 'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 
# ---- edustudio/datatpl/CD/MGCDDataTPL.py ----
from ..common import EduDataTPL


class MGCDDataTPL(EduDataTPL):
    """Data template for MGCD (multi-granularity/group cognitive diagnosis)."""

    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_MGCD_OP', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat'],
    }

    def get_extra_data(self):
        """Add the student-level interaction frame and group frame to the extra data."""
        extra = super().get_extra_data()
        extra['inter_student'] = self.final_kwargs['df_inter_stu']
        extra['df_G'] = self.final_kwargs['df_stu']
        return extra

    def df2dict(self):
        """Unwrap the extra dataframes in addition to the parent's conversion."""
        super().df2dict()
        self._unwrap_feat(self.final_kwargs['df_inter_stu'])
        self._unwrap_feat(self.final_kwargs['df_stu'])


# ---- edustudio/datatpl/CD/RCDDataTPL.py ----
from ..common import EduDataTPL
import json
import numpy as np


class RCDDataTPL(EduDataTPL):
    """Data template for RCD: builds the concept-relation graph local maps."""

    default_cfg = {
        'mid2cache_op_seq': [
            'M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId',
            'M2C_RandomDataSplit4CD', 'M2C_BuildKCRelation',
            'M2C_GenQMat', 'M2C_RCD_OP'
        ],
        'M2C_BuildKCRelation': {
            'relation_type': 'rcd_transition',
            'threshold': None
        }
    }

    def get_extra_data(self):
        """Expose the graph local maps produced by M2C_RCD_OP."""
        extra_dict = super().get_extra_data()
        extra_dict['local_map'] = self.local_map
        return extra_dict

    def set_info_for_fold(self, fold_id):
        """Select the per-fold user<->exercise edges inside the shared local map."""
        super().set_info_for_fold(fold_id)
        self.local_map = self.final_kwargs['local_map']
        self.local_map['u_from_e'] = self.local_map['u_from_e_list'][fold_id]
        self.local_map['e_from_u'] = self.local_map['e_from_u_list'][fold_id]


# ---- edustudio/datatpl/CD/__init__.py ----
from .CDInterDataTPL import CDInterDataTPL
from .CDInterExtendsQDataTPL import CDInterExtendsQDataTPL
from .IRRDataTPL import IRRDataTPL
from .HierCDFDataTPL import HierCDFDataTPL
from .CNCDFDataTPL import CNCDFDataTPL
from .MGCDDataTPL import MGCDDataTPL
from .CNCDQDataTPL import CNCDQDataTPL
from .RCDDataTPL import RCDDataTPL
from .CDGKDataTPL import CDGKDataTPL
from .ECDDataTPL import ECDDataTPL
from .DCDDataTPL import DCDDataTPL
from .FAIRDataTPL import FAIRDataTPL
# ---- edustudio/datatpl/KT/CL4KTDataTPL.py ----
import numpy as np

from ..common import EduDataTPL


class CL4KTDataTPL(EduDataTPL):
    """Data template for CL4KT (contrastive learning for knowledge tracing)."""

    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_GenUnFoldKCSeq', 'M2C_CL4KT_OP', 'M2C_RandomDataSplit4KT'],
        'M2C_CL4KT_OP': {
            'sequence_truncation': 'recent',
        }
    }

    def __getitem__(self, index):
        # pass-through override kept for interface stability (no extra features added yet)
        return super().__getitem__(index)


# ---- edustudio/datatpl/KT/DIMKTDataTPL.py ----
import numpy as np

from .KTInterExtendsQDataTPL import KTInterExtendsQDataTPL
import torch


class DIMKTDataTPL(KTInterExtendsQDataTPL):
    """Data template for DIMKT: adds question/concept difficulty sequences per sample."""

    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_GenUnFoldKCSeq', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT', 'M2C_GenKCSeq', "M2C_DIMKT_OP"],
        'M2C_BuildSeqInterFeats': {
            # 'window_size': 200,
            "extra_inter_feats": ['start_timestamp:float', 'cpt_unfold:token']
        }
    }

    def __getitem__(self, index):
        """Attach per-step question (`qd_seq`) and concept (`cd_seq`) difficulty features."""
        dic = super().__getitem__(index)
        dic['qd_seq'] = np.stack(
            [self.q_dif[exer_seq][0] for exer_seq in dic['exer_seq']], axis=0
        )
        dic['cd_seq'] = np.stack(
            [self.c_dif[cpt_seq][0] for cpt_seq in dic['cpt_unfold_seq']], axis=0
        )
        dic['cd_seq'] = np.squeeze(dic['cd_seq'])
        # zero out the difficulty entries at padded positions
        pad = dic['mask_seq'] == 0
        dic['qd_seq'][pad] = 0
        dic['cd_seq'][pad] = 0
        return dic

    def set_info_for_fold(self, fold_id):
        """Select the per-fold difficulty tables."""
        super().set_info_for_fold(fold_id)
        self.q_dif = self.final_kwargs['q_diff_list'][fold_id]
        self.c_dif = self.final_kwargs['c_diff_list'][fold_id]


# ---- edustudio/datatpl/KT/DKTDSCDataTPL.py ----
from ..common import EduDataTPL
import torch
import numpy as np
import pandas as pd


class DKTDSCDataTPL(EduDataTPL):
    """Data template for DKT-DSC: attaches a cluster id to every (student, segment) step."""

    default_cfg = {
        'mid2cache_op_seq': ["M2C_KCAsExer", 'M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT', "M2C_DKTDSC_OP"],
    }

    def __getitem__(self, index):
        """Look up the cluster id for each (student, segment) pair in the sample.

        BUGFIX: the original wrapped ``self.cluster.get(key)`` in a bare ``except:``;
        ``dict.get`` does not raise on a missing key -- it returned None, which only
        fell back to 0 because assigning None into an int ndarray raises TypeError.
        Using ``get(key, 0)`` expresses the intended default directly.
        """
        dic = super().__getitem__(index)
        steps = len(dic['seg_seq'])
        stu_id = dic['stu_id'].repeat(steps)
        cluster = np.zeros_like(dic['exer_seq'])  # every slot is assigned below
        for i in range(steps):
            # default to cluster 0 when the (student, segment) pair is unknown
            cluster[i] = self.cluster.get((int(stu_id[i]), int(dic['seg_seq'][i])), 0)
        dic['cluster'] = torch.from_numpy(cluster)
        return dic

    def set_info_for_fold(self, fold_id):
        """Select the per-fold (student, segment) -> cluster mapping."""
        super().set_info_for_fold(fold_id)
        self.cluster = self.final_kwargs['cluster_list'][fold_id]


# ---- edustudio/datatpl/KT/DKTForgetDataTPL.py ----
from ..common import EduDataTPL


class DKTForgetDataTPL(EduDataTPL):
    """Data template for DKT-Forget: exposes per-fold gap/past-count vocabulary sizes."""

    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT', "M2C_DKTForget_OP"],
        'M2C_BuildSeqInterFeats': {
            "extra_inter_feats": ['start_timestamp:float']
        }
    }

    def set_info_for_fold(self, fold_id):
        # NOTE(review): unlike sibling templates, this does not call
        # super().set_info_for_fold(fold_id) -- confirm this is intentional.
        dt_info = self.datatpl_cfg['dt_info']
        dt_info['n_pcount'] = dt_info['n_pcount_list'][fold_id]
        dt_info['n_rgap'] = dt_info['n_rgap_list'][fold_id]
        dt_info['n_sgap'] = dt_info['n_sgap_list'][fold_id]
# ---- edustudio/datatpl/KT/EERNNDataTPL.py ----
import numpy as np
from ..common import EduDataTPL


class EERNNDataTPL(EduDataTPL):
    """Data template for EERNN: provides word embeddings and exercise text content."""

    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT', 'M2C_EERNN_OP'],
    }

    def get_extra_data(self, **kwargs):
        """Add the word-embedding matrix and the exercise content matrix."""
        extra = super().get_extra_data(**kwargs)
        extra['w2v_word_emb'] = self.w2v_word_emb
        extra['exer_content'] = self.content_mat
        return extra

    def set_info_for_fold(self, fold_id):
        """Materialize the fold's embedding matrix (row k = embedding of word id k)."""
        dt_info = self.datatpl_cfg['dt_info']
        dt_info['word_count'] = len(self.word_emb_dict_list[fold_id])
        self.w2v_word_emb = np.vstack(
            [self.word_emb_dict_list[fold_id][k] for k in range(dt_info['word_count'])]
        )
        self.content_mat = self.content_mat_list[fold_id]

    def save_cache(self):
        """Persist the word-embedding dicts and content matrices next to the parent cache."""
        super().save_cache()
        self.save_pickle(f"{self.cache_folder_path}/word_emb_dict_list.pkl", self.word_emb_dict_list)
        self.save_pickle(f"{self.cache_folder_path}/content_mat_list.pkl", self.content_mat_list)

    def load_cache(self):
        """Restore what save_cache() wrote."""
        super().load_cache()
        self.word_emb_dict_list = self.load_pickle(f"{self.cache_folder_path}/word_emb_dict_list.pkl")
        self.content_mat_list = self.load_pickle(f"{self.cache_folder_path}/content_mat_list.pkl")


# ---- edustudio/datatpl/KT/EKTDataTPL.py ----
from .EERNNDataTPL import EERNNDataTPL
import numpy as np


class EKTDataTPL(EERNNDataTPL):
    """Data template for EKT: EERNN features plus padded concept sequences per exercise."""

    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT', 'M2C_GenKCSeq', 'M2C_EERNN_OP'],
    }

    def __getitem__(self, index):
        """Attach padded per-exercise concept sequences and their masks."""
        dic = super().__getitem__(index)
        dic['cpt_seq'] = np.stack(
            [self.cpt_seq_padding[exer_seq] for exer_seq in dic['exer_seq']], axis=0
        )
        dic['cpt_seq_mask'] = np.stack(
            [self.cpt_seq_mask[exer_seq] for exer_seq in dic['exer_seq']], axis=0
        )
        return dic


# ---- edustudio/datatpl/KT/KTInterCptAsExerDataTPL.py ----
from ..common import EduDataTPL


class KTInterCptAsExerDataTPL(EduDataTPL):
    """KT interaction template where each knowledge concept is treated as an exercise."""

    default_cfg = {
        'mid2cache_op_seq': ["M2C_KCAsExer", 'M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT'],
    }


# ---- edustudio/datatpl/KT/KTInterCptUnfoldDataTPL.py ----
from ..common import EduDataTPL
import numpy as np


class KTInterCptUnfoldDataTPL(EduDataTPL):
    """KT interaction template with unfolded concept ids as an extra interaction feature."""

    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_GenUnFoldKCSeq', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT'],
        'M2C_BuildSeqInterFeats': {
            "extra_inter_feats": ['start_timestamp:float', 'cpt_unfold:token']
        }
    }
# ---- edustudio/datatpl/KT/KTInterDataTPL.py ----
from ..common import GeneralDataTPL


class KTInterDataTPL(GeneralDataTPL):
    """Plain KT interaction-sequence template."""

    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT'],
    }


# ---- edustudio/datatpl/KT/KTInterExtendsQDataTPL.py ----
from ..common import EduDataTPL
import numpy as np


class KTInterExtendsQDataTPL(EduDataTPL):
    """KT interaction template that also yields padded concept sequences per exercise."""

    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT', 'M2C_GenKCSeq'],
    }

    def __getitem__(self, index):
        """Attach padded per-exercise concept sequences and their masks."""
        dic = super().__getitem__(index)
        dic['cpt_seq'] = np.stack(
            [self.cpt_seq_padding[exer_seq] for exer_seq in dic['exer_seq']], axis=0
        )
        dic['cpt_seq_mask'] = np.stack(
            [self.cpt_seq_mask[exer_seq] for exer_seq in dic['exer_seq']], axis=0
        )
        return dic


# ---- edustudio/datatpl/KT/LPKTDataTPL.py ----
from ..common import EduDataTPL


class LPKTDataTPL(EduDataTPL):
    """Data template for LPKT: needs answer-time and interval-time vocabularies."""

    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT', 'M2C_LPKT_OP', "M2C_GenQMat"],
        'M2C_BuildSeqInterFeats': {
            "extra_inter_feats": ['start_timestamp:float', 'answer_time:float']
        }
    }

    def set_info_for_fold(self, fold_id):
        """Select the per-fold time-vocabulary sizes."""
        dt_info = self.datatpl_cfg['dt_info']
        dt_info['answer_time_count'] = dt_info['answer_time_count_list'][fold_id]
        dt_info['interval_time_count'] = dt_info['interval_time_count_list'][fold_id]


# ---- edustudio/datatpl/KT/QDKTDataTPL.py ----
import networkx as nx
import numpy as np

from .KTInterExtendsQDataTPL import KTInterExtendsQDataTPL
import torch


class QDKTDataTPL(KTInterExtendsQDataTPL):
    """Data template for QDKT: provides the Laplacian matrix and per-fold train dict."""

    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT', 'M2C_GenKCSeq', 'M2C_GenQMat', 'M2C_QDKT_OP'],
    }

    def get_extra_data(self, **kwargs):
        return {
            'laplacian_matrix': self.final_kwargs['laplacian_matrix'],
            'train_dict': self.train_dict
        }

    def set_info_for_fold(self, fold_id):
        super().set_info_for_fold(fold_id)
        # dict_train_folds is presumably populated by the parent pipeline -- TODO confirm
        self.train_dict = self.dict_train_folds[fold_id]


# ---- edustudio/datatpl/KT/RKTDataTPL.py ----
import networkx as nx
import numpy as np

from scipy import sparse
from ..common import EduDataTPL


class RKTDataTPL(EduDataTPL):
    """Data template for RKT: builds the dense problem-problem correlation matrix."""

    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_GenQMat', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT'],
        'M2C_BuildSeqInterFeats': {
            "extra_inter_feats": ['start_timestamp:float']
        }
    }

    def process_load_data_from_middata(self):
        super().process_load_data_from_middata()
        self.final_kwargs['pro_pro_dense'] = self.get_pro_pro_corr()

    def get_pro_pro_corr(self):
        """Return 1 - A, where A[i, j] = 1 iff problems i and j share a concept.

        reference: https://github.com/shalini1194/RKT/issues/2
        """
        n_pro = self.cfg['datatpl_cfg']['dt_info']['exer_count']
        n_cpt = self.cfg['datatpl_cfg']['dt_info']['cpt_count']

        # (exercise, concept) incidence triples
        triples = []
        for index in range(len(self.df_exer)):
            row = self.df_exer.iloc[index]
            for cpt in row['cpt_seq:token_seq']:
                triples.append([row['exer_id:token'], cpt, 1])
        triples = np.array(triples).astype(np.int32)
        pro_cpt_sparse = sparse.coo_matrix(
            (triples[:, 2].astype(np.float32), (triples[:, 0], triples[:, 1])),
            shape=(n_pro, n_cpt))
        pro_cpt_csc = pro_cpt_sparse.tocsc()
        pro_cpt_csr = pro_cpt_sparse.tocsr()

        # connect every pair of problems that share at least one concept
        edges = []
        for p in range(n_pro):
            shared_cpts = pro_cpt_csr.getrow(p).indices
            similar = pro_cpt_csc[:, shared_cpts].indices
            edges += list(zip([p] * similar.shape[0], similar))
        edges = np.array(list(set(edges))).astype(np.int32)
        ones = np.ones(edges.shape[0]).astype(np.float32)
        pro_pro_sparse = sparse.coo_matrix((ones, (edges[:, 0], edges[:, 1])), shape=(n_pro, n_pro))
        return 1 - pro_pro_sparse.tocoo().toarray()

    def get_extra_data(self):
        return {
            "pro_pro_dense": self.final_kwargs['pro_pro_dense']
        }

    def set_info_for_fold(self, fold_id):
        super().set_info_for_fold(fold_id)
        # dict_train_folds is presumably populated by the parent pipeline -- TODO confirm
        self.train_dict = self.dict_train_folds[fold_id]
# ---- edustudio/datatpl/KT/__init__.py ----
from .KTInterDataTPL import KTInterDataTPL
from .KTInterExtendsQDataTPL import KTInterExtendsQDataTPL
from .KTInterCptAsExerDataTPL import KTInterCptAsExerDataTPL
from .EERNNDataTPL import EERNNDataTPL
from .LPKTDataTPL import LPKTDataTPL
from .DKTForgetDataTPL import DKTForgetDataTPL
from .KTInterCptUnfoldDataTPL import KTInterCptUnfoldDataTPL
from .DKTDSCDataTPL import DKTDSCDataTPL
from .DIMKTDataTPL import DIMKTDataTPL
from .QDKTDataTPL import QDKTDataTPL
from .RKTDataTPL import RKTDataTPL
from .CL4KTDataTPL import CL4KTDataTPL
from .EKTDataTPL import EKTDataTPL
from .GKTDataTPL import GKTDataTPL


# ---- edustudio/datatpl/__init__.py ----
from .CD import *
from .KT import *
from .common import *


# ---- edustudio/datatpl/common/__init__.py ----
from .base_datatpl import BaseDataTPL
from .edu_datatpl import EduDataTPL
from .general_datatpl import GeneralDataTPL
from .proxy_datatpl import BaseProxyDataTPL


# ---- edustudio/datatpl/common/proxy_datatpl.py ----
import importlib
from .base_datatpl import BaseDataTPL
from edustudio.utils.common import UnifyConfig


class BaseProxyDataTPL(object):
    """The basic protocol for implementing a proxy data template."""

    default_cfg = {'backbone_datatpl_cls': 'BaseDataTPL'}

    @classmethod
    def from_cfg_proxy(cls, cfg):
        """Instantiate the proxy by dynamically inheriting from its backbone template.

        Args:
            cfg (UnifyConfig): the global config object

        Returns:
            BaseProxyDataTPL
        """
        backbone = cls.get_backbone_cls(cfg.datatpl_cfg.backbone_datatpl_cls)
        return cls.get_new_cls(p_cls=backbone).from_cfg(cfg)

    @classmethod
    def get_backbone_cls(cls, backbone_datatpl_cls):
        """Resolve the backbone template from a class name or class object."""
        if isinstance(backbone_datatpl_cls, str):
            # look the class up by name in the edustudio.datatpl package
            return importlib.import_module('edustudio.datatpl').__getattribute__(backbone_datatpl_cls)
        if issubclass(backbone_datatpl_cls, BaseDataTPL):
            return backbone_datatpl_cls
        raise ValueError(f"Unknown type of backbone_datatpl_cls: {backbone_datatpl_cls}")

    @classmethod
    def get_new_cls(cls, p_cls):
        """Dynamic inheritance: build a subclass of (cls, p_cls).

        Args:
            p_cls (BaseDataTPL): parent (backbone) class

        Returns:
            type: the dynamically created proxy class
        """
        return type(cls.__name__ + "_proxy", (cls, p_cls), {})

    @classmethod
    def get_default_cfg(cls, backbone_datatpl_cls, **kwargs):
        """Assemble the final default_cfg along the proxy's dynamic MRO.

        Args:
            backbone_datatpl_cls: backbone data template class (or its name); when
                None, the nearest 'backbone_datatpl_cls' entry in the MRO's
                default_cfg dicts is used.

        Returns:
            UnifyConfig: the final default config object
        """
        bb_cls = None
        if backbone_datatpl_cls is not None:
            bb_cls = cls.get_backbone_cls(backbone_datatpl_cls)
        else:
            # walk the MRO until a default_cfg names a backbone
            for klass in cls.__mro__:
                if not hasattr(klass, 'default_cfg'):
                    break
                bb_cls = klass.default_cfg.get('backbone_datatpl_cls', None)
                if bb_cls is not None:
                    break
            assert bb_cls is not None
            bb_cls = cls.get_backbone_cls(bb_cls)

        cfg = UnifyConfig()
        cfg.backbone_datatpl_cls = bb_cls
        cfg.backbone_datatpl_cls_name = bb_cls.__name__
        new_cls = cls.get_new_cls(p_cls=bb_cls)
        # merge proxy-side default_cfg dicts, then delegate once to the backbone
        for klass in new_cls.__mro__:
            if not hasattr(klass, 'default_cfg'):
                break
            if issubclass(klass, BaseProxyDataTPL):
                cfg.update(klass.default_cfg, update_unknown_key_only=True)
            else:
                cfg.update(klass.get_default_cfg(**kwargs), update_unknown_key_only=True)
                break
        return cfg
# ---- edustudio/datatpl/utils/__init__.py ----
from .pad_seq_util import PadSeqUtil
from .common import BigfileDownloader, DecompressionUtil
from .spliter_util import SpliterUtil


# ---- edustudio/datatpl/utils/common.py ----
from contextlib import closing
from tqdm import tqdm
import requests
import zipfile


class BigfileDownloader(object):
    """Streaming HTTP downloader with a tqdm progress bar."""

    @staticmethod
    def download(url, title, filepath, chunk_size=10240):
        """Download `url` to `filepath` in chunks.

        Raises:
            Exception: when the server does not answer with HTTP 200.
        """
        with closing(requests.get(url, stream=True, allow_redirects=True)) as resp:
            if resp.status_code != 200:
                raise Exception("[ERROR]: {} - {} -{}".format(str(resp.status_code), title, url))
            content_size = int(resp.headers['content-length'])
            with tqdm(total=content_size, desc=title, ncols=100) as pbar:
                with open(filepath, 'wb') as f:
                    for data in resp.iter_content(chunk_size=chunk_size):
                        f.write(data)
                        pbar.update(len(data))


class DecompressionUtil(object):
    """Archive extraction helpers."""

    @staticmethod
    def unzip_file(zip_src, dst_dir):
        """Extract every member of `zip_src` into `dst_dir`.

        Raises:
            Exception: when `zip_src` is not a valid zip archive.
        """
        if not zipfile.is_zipfile(zip_src):
            raise Exception(f'{zip_src} is not a zip file')
        fz = zipfile.ZipFile(zip_src, 'r')
        for member in tqdm(fz.namelist(), desc='unzip...', ncols=100):
            fz.extract(member, dst_dir)


# ---- edustudio/evaltpl/__init__.py ----
from .base_evaltpl import BaseEvalTPL
from .prediction_evaltpl import PredictionEvalTPL
from .interpretability_evaltpl import InterpretabilityEvalTPL
from .fairness_evaltpl import FairnessEvalTPL
from .identifiability_evaltpl import IdentifiabilityEvalTPL
# ---- edustudio/evaltpl/base_evaltpl.py ----
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

import logging
from edustudio.utils.common import UnifyConfig
from edustudio.utils.callback import CallbackList


class BaseEvalTPL(object):
    """The basic protocol for implementing an evaluate template."""

    default_cfg = {}

    def __init__(self, cfg):
        # keep handles on every sub-config of the global config object
        self.cfg: UnifyConfig = cfg
        self.datatpl_cfg: UnifyConfig = cfg.datatpl_cfg
        self.evaltpl_cfg: UnifyConfig = cfg.evaltpl_cfg
        self.traintpl_cfg: UnifyConfig = cfg.traintpl_cfg
        self.frame_cfg: UnifyConfig = cfg.frame_cfg
        self.modeltpl_cfg: UnifyConfig = cfg.modeltpl_cfg
        self.logger: logging.Logger = logging.getLogger("edustudio")
        self.name = self.__class__.__name__
        self._check_params()

    @classmethod
    def get_default_cfg(cls):
        """Merge this class's default_cfg with its parent's (child keys win)."""
        cfg = UnifyConfig(cls.default_cfg)
        parent = cls.__base__
        if hasattr(parent, 'get_default_cfg'):
            cfg.update(parent.get_default_cfg(), update_unknown_key_only=True)
        return cfg

    def eval(self, **kwargs):
        """Hook: compute metrics. Overridden by concrete templates."""
        pass

    def _check_params(self):
        """Hook: validate config values. Overridden by concrete templates."""
        pass

    def set_callback_list(self, callbacklist: CallbackList):
        self.callback_list = callbacklist

    def set_dataloaders(self, train_loader, test_loader, valid_loader=None):
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

    def add_extra_data(self, **kwargs):
        self.extra_data = kwargs


# ---- edustudio/evaltpl/prediction_evaltpl.py ----
from .base_evaltpl import BaseEvalTPL
import numpy as np
from edustudio.utils.common import tensor2npy
from sklearn.metrics import mean_squared_error, roc_auc_score, accuracy_score, f1_score, label_ranking_loss, coverage_error


class PredictionEvalTPL(BaseEvalTPL):
    """Student performance prediction evaluation."""

    default_cfg = {
        'use_metrics': ['auc', 'acc', 'rmse']
    }

    def __init__(self, cfg):
        super().__init__(cfg)

    def eval(self, y_pd, y_gt, **kwargs):
        """Compute every configured metric (minus `ignore_metrics`) on the predictions."""
        if not isinstance(y_pd, np.ndarray):
            y_pd = tensor2npy(y_pd)
        if not isinstance(y_gt, np.ndarray):
            y_gt = tensor2npy(y_gt)
        ignore_metrics = kwargs.get('ignore_metrics', {})
        return {
            name: self._get_metrics(name)(y_gt, y_pd)
            for name in self.evaltpl_cfg[self.__class__.__name__]['use_metrics']
            if name not in ignore_metrics
        }

    def _get_metrics(self, metric):
        """Map a metric name to a callable(y_gt, y_pd)."""
        if metric == "auc":
            return roc_auc_score
        if metric == "mse":
            return mean_squared_error
        if metric == 'rmse':
            return lambda y_gt, y_pd: mean_squared_error(y_gt, y_pd) ** 0.5
        if metric == "acc":
            # threshold probabilities at 0.5 before computing accuracy
            return lambda y_gt, y_pd: accuracy_score(y_gt, np.where(y_pd >= 0.5, 1, 0))
        if metric == "f1_macro":
            return lambda y_gt, y_pd: f1_score(y_gt, y_pd, average='macro')
        if metric == "f1_micro":
            return lambda y_gt, y_pd: f1_score(y_gt, y_pd, average='micro')
        if metric == "ranking_loss":
            return label_ranking_loss
        if metric == 'coverage_error':
            return coverage_error
        if metric == 'samples_auc':
            return lambda y_gt, y_pd: roc_auc_score(y_gt, y_pd, average='samples')
        raise NotImplementedError


# ---- edustudio/model/CD/__init__.py ----
from .irt import IRT
from .mf import MF
from .mirt import MIRT
from .ncdm import NCDM
from .dina import DINA, STEDINA
from .kancd import KaNCD
from .kscd import KSCD
from .cncd_q import CNCD_Q
from .irr import IRR
from .hier_cdf import HierCDF
from .rcd import RCD
from .cncd_f import CNCD_F
from .cdgk import CDGK_SINGLE, CDGK_MULTI
from .cdmfkc import CDMFKC
from .ecd import *
from .mgcd import MGCD
from .dcd import DCD
from .faircd import FairCD_IRT, FairCD_MIRT, FairCD_NCDM
# ---- edustudio/model/CD/irr.py ----
r"""
IRR
##################################
Reference:
    Tong et al. "Item Response Ranking for Cognitive Diagnosis." in IJCAI 2021.
"""
from ..basemodel import BaseProxyModel
import torch
import torch.nn as nn
import torch.nn.functional as F


class PairSCELoss(nn.Module):
    """IRR pairwise loss, implemented with softmax cross entropy."""

    def __init__(self):
        super(PairSCELoss, self).__init__()
        self._loss = nn.CrossEntropyLoss()

    def forward(self, pred1, pred2, sign=1, *args):
        """Compute the pairwise cross-entropy loss.

        Args:
            pred1 (torch.Tensor): positive prediction
            pred2 (torch.Tensor): negative prediction
            sign (int, optional): 1: pred1 should be greater than pred2; -1: otherwise.
                Defaults to 1.

        Returns:
            torch.Tensor: the loss value
        """
        pred = torch.stack([pred1, pred2], dim=1)
        # sign=1 -> target class 0 (pred1 wins); sign=-1 -> target class 1 (pred2 wins)
        target = ((torch.ones(pred1.shape[0], device=pred.device) - sign) / 2).long()
        return self._loss(pred, target)


class IRR(BaseProxyModel):
    """Proxy model adding the IRR pairwise ranking loss on top of a backbone CD model.

    backbone_modeltpl_cls: the backbone model of IRR.
    """

    default_cfg = {
        "backbone_modeltpl_cls": "IRT",
        'pair_weight': 0.5,
    }

    def __init__(self, cfg):
        """Pass parameters from other templates into the model.

        Args:
            cfg (UnifyConfig): parameters from other templates
        """
        super().__init__(cfg)

    def build_model(self):
        """Initialize the backbone components plus the pairwise loss."""
        super().build_model()
        self.irr_pair_loss = PairSCELoss()

    def get_main_loss(self, **kwargs):
        """Mix the pairwise ranking loss with the pointwise BCE loss.

        Returns:
            dict: {'loss_main': loss}
        """
        pair_exer = kwargs['pair_exer']
        pair_pos_stu = kwargs['pair_pos_stu']
        pair_neg_stu = kwargs['pair_neg_stu']

        kwargs['exer_id'] = pair_exer
        kwargs['stu_id'] = pair_pos_stu
        pos_pd = self(**kwargs).flatten()
        kwargs['stu_id'] = pair_neg_stu
        neg_pd = self(**kwargs).flatten()
        pos_label = torch.ones(pos_pd.shape[0]).to(self.device)
        neg_label = torch.zeros(neg_pd.shape[0]).to(self.device)
        point_loss = F.binary_cross_entropy(input=pos_pd, target=pos_label) + F.binary_cross_entropy(input=neg_pd, target=neg_label)

        return {
            'loss_main': self.modeltpl_cfg['pair_weight'] * self.irr_pair_loss(pos_pd, neg_pd) + (1 - self.modeltpl_cfg['pair_weight']) * point_loss
        }


# ---- edustudio/model/CD/irt.py ----
r"""
IRT
##########################################

Reference Code:
    https://github.com/bigdata-ustc/EduCDM/tree/main/EduCDM/IRT

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from ..gd_basemodel import GDBaseModel


class IRT(GDBaseModel):
    r"""
    IRT (item response theory) cognitive diagnosis model.
    """

    default_cfg = {
        "a_range": -1.0,    # disc range (negative means unbounded)
        "diff_range": -1.0, # diff range (negative means unbounded)
        "fix_a": False,     # fix discrimination to the constant 0.0
        "fix_c": True,      # fix guessing to the constant 0.0
    }

    def __init__(self, cfg):
        super().__init__(cfg)

    def build_cfg(self):
        """Normalize ranges and read entity counts from the data template info."""
        if self.modeltpl_cfg['a_range'] is not None and self.modeltpl_cfg['a_range'] < 0: self.modeltpl_cfg['a_range'] = None
        if self.modeltpl_cfg['diff_range'] is not None and self.modeltpl_cfg['diff_range'] < 0: self.modeltpl_cfg['diff_range'] = None

        self.n_user = self.datatpl_cfg['dt_info']['stu_count']
        self.n_item = self.datatpl_cfg['dt_info']['exer_count']

        # when c is trainable, a must be trainable too
        if self.modeltpl_cfg['fix_c'] is False: assert self.modeltpl_cfg['fix_a'] is False

    def build_model(self):
        """Create the embeddings; fixed parameters become the scalar 0.0."""
        self.theta = nn.Embedding(self.n_user, 1)  # student ability
        self.a = 0.0 if self.modeltpl_cfg['fix_a'] else nn.Embedding(self.n_item, 1)  # exer discrimination
        self.b = nn.Embedding(self.n_item, 1)  # exer difficulty
        self.c = 0.0 if self.modeltpl_cfg['fix_c'] else nn.Embedding(self.n_item, 1)  # guessing

    def forward(self, stu_id, exer_id, **kwargs):
        """Predict correctness probability via the item response function.

        BUGFIX: when fix_a is True, self.a is the scalar 0.0, not an Embedding;
        the original called self.a(exer_id) unconditionally and crashed with
        "'float' object is not callable". The fixed-a path now skips both the
        lookup and the range/softplus transforms.
        """
        fixed_a = self.modeltpl_cfg['fix_a']
        theta = self.theta(stu_id)
        a = self.a if fixed_a else self.a(exer_id)
        b = self.b(exer_id)
        c = self.c if self.modeltpl_cfg['fix_c'] else self.c(exer_id).sigmoid()

        if self.modeltpl_cfg['diff_range'] is not None:
            b = self.modeltpl_cfg['diff_range'] * (torch.sigmoid(b) - 0.5)
        if not fixed_a:
            if self.modeltpl_cfg['a_range'] is not None:
                a = self.modeltpl_cfg['a_range'] * torch.sigmoid(a)
            else:
                a = F.softplus(a)  # keep discrimination positive (monotonicity assumption)
        if torch.isnan(theta).any() or torch.isnan(b).any() or \
                (isinstance(a, torch.Tensor) and torch.isnan(a).any()):  # pragma: no cover
            raise ValueError('ValueError:theta,a,b may contains nan! The diff_range or a_range is too large.')
        return self.irf(theta, a, b, c)

    @staticmethod
    def irf(theta, a, b, c, D=1.702):
        """3PL item response function: c + (1-c) * sigmoid(D * a * (theta - b))."""
        return c + (1 - c) / (1 + torch.exp(-D * a * (theta - b)))

    def get_main_loss(self, **kwargs):
        """Binary cross entropy against the observed correctness labels."""
        stu_id = kwargs['stu_id']
        exer_id = kwargs['exer_id']
        label = kwargs['label']
        pd = self(stu_id, exer_id).flatten()
        loss = F.binary_cross_entropy(input=pd, target=label)
        return {
            'loss_main': loss
        }

    def get_loss_dict(self, **kwargs):
        return self.get_main_loss(**kwargs)

    @torch.no_grad()
    def predict(self, stu_id, exer_id, **kwargs):
        return {
            'y_pd': self(stu_id, exer_id).flatten(),
        }
The diff_range or a_range is too large.') 60 | return self.irf(theta, a, b, c) 61 | 62 | @staticmethod 63 | def irf(theta, a, b, c, D=1.702): 64 | return c + (1 - c) / (1 + torch.exp(-D * a * (theta - b))) 65 | 66 | def get_main_loss(self, **kwargs): 67 | stu_id = kwargs['stu_id'] 68 | exer_id = kwargs['exer_id'] 69 | label = kwargs['label'] 70 | pd = self(stu_id, exer_id).flatten() 71 | loss = F.binary_cross_entropy(input=pd, target=label) 72 | return { 73 | 'loss_main': loss 74 | } 75 | 76 | def get_loss_dict(self, **kwargs): 77 | return self.get_main_loss(**kwargs) 78 | 79 | @torch.no_grad() 80 | def predict(self, stu_id, exer_id, **kwargs): 81 | return { 82 | 'y_pd': self(stu_id, exer_id).flatten(), 83 | } 84 | 85 | -------------------------------------------------------------------------------- /edustudio/model/CD/mf.py: -------------------------------------------------------------------------------- 1 | from edustudio.model import GDBaseModel 2 | import torch.nn as nn 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class MF(GDBaseModel): 8 | default_cfg = { 9 | 'emb_dim': 32, 10 | 'reg_user': 0.0, 11 | 'reg_item': 0.0 12 | } 13 | 14 | def build_cfg(self): 15 | self.n_user = self.datatpl_cfg['dt_info']['stu_count'] 16 | self.n_item = self.datatpl_cfg['dt_info']['exer_count'] 17 | self.emb_size = self.modeltpl_cfg['emb_dim'] 18 | self.reg_user = self.modeltpl_cfg['reg_user'] 19 | self.reg_item = self.modeltpl_cfg['reg_item'] 20 | 21 | def build_model(self): 22 | self.user_emb = nn.Embedding( 23 | num_embeddings=self.n_user, 24 | embedding_dim=self.emb_size 25 | ) 26 | self.item_emb = nn.Embedding( 27 | num_embeddings=self.n_item, 28 | embedding_dim=self.emb_size 29 | ) 30 | 31 | def forward(self, user_idx: torch.LongTensor, item_idx: torch.LongTensor): 32 | assert len(user_idx.shape) == 1 and len(item_idx.shape) == 1 and user_idx.shape[0] == item_idx.shape[0] 33 | return torch.einsum("ij,ij->i", self.user_emb(user_idx), 
                            self.item_emb(item_idx)).sigmoid()

    @torch.no_grad()
    def predict(self, stu_id, exer_id, **kwargs):
        return {
            'y_pd': self(stu_id, exer_id).flatten(),
        }

    def get_main_loss(self, **kwargs):
        # Standard point-wise BCE on predicted correctness probabilities.
        stu_id = kwargs['stu_id']
        exer_id = kwargs['exer_id']
        label = kwargs['label']
        pd = self(stu_id, exer_id).flatten()
        loss = F.binary_cross_entropy(input=pd, target=label)
        return {
            'loss_main': loss
        }

    def get_loss_dict(self, **kwargs):
        return self.get_main_loss(**kwargs)


# ---- file: edustudio/model/CD/mirt.py ----
r"""
MIRT
##########################################

Reference:
    Mark D Reckase et al. "Multidimensional item response theory models". Springer, 2009.

Reference Code:
    https://github.com/bigdata-ustc/EduCDM/tree/main/EduCDM/MIRT

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from ..gd_basemodel import GDBaseModel


class MIRT(GDBaseModel):
    """Multidimensional Item Response Theory.

    (Translated from Chinese) The original note listed three variants:
    1) fix_a=True, fix_c=True; 2) fix_a=False, fix_c=True;
    3) fix_a=False, fix_c=False.
    NOTE(review): this class exposes no fix_a/fix_c options -- the note
    looks stale (likely copied from IRT); confirm and remove.
    """
    default_cfg = {
        "a_range": -1.0,  # disc range; negative means "no range constraint"
        "emb_dim": 32     # number of latent ability/discrimination dimensions
    }
    def __init__(self, cfg):
        super().__init__(cfg)

    def build_cfg(self):
        # Negative a_range is a sentinel for "no range constraint".
        if self.modeltpl_cfg['a_range'] is not None and self.modeltpl_cfg['a_range'] < 0: self.modeltpl_cfg['a_range'] = None

        self.n_user = self.datatpl_cfg['dt_info']['stu_count']
        self.n_item = self.datatpl_cfg['dt_info']['exer_count']
        self.emb_dim = self.modeltpl_cfg['emb_dim']

    def build_model(self):
        self.theta = nn.Embedding(self.n_user, self.emb_dim)  # student ability
        self.a = nn.Embedding(self.n_item, self.emb_dim)  # exer discrimination
        self.b = nn.Embedding(self.n_item, 1)  # exer intercept term

    def forward(self, stu_id, exer_id, **kwargs):
        theta = self.theta(stu_id)
        a = self.a(exer_id)
        b = self.b(exer_id).flatten()

        if self.modeltpl_cfg['a_range'] is not None:
            a = self.modeltpl_cfg['a_range'] * torch.sigmoid(a)
        else:
            a = F.softplus(a)  # keep discrimination positive (monotonicity assumption)
        # NaN guard: (x != x) is True only for NaN.
        if torch.max(theta != theta) or torch.max(a != a) or torch.max(b != b):  # pragma: no cover
            raise ValueError('ValueError:theta,a,b may contains nan! The diff_range or a_range is too large.')
        return self.irf(theta, a, b)

    @staticmethod
    def irf(theta, a, b):
        # sigmoid(<a, theta> - b), written out explicitly.  (Translated) the
        # original author asked why the sum is negated -- it is just the
        # sigmoid with b acting as a (subtracted) intercept.
        return 1 / (1 + torch.exp(- torch.sum(torch.multiply(a, theta), axis=-1) + b))

    def get_main_loss(self, **kwargs):
        stu_id = kwargs['stu_id']
        exer_id = kwargs['exer_id']
        label = kwargs['label']
        pd = self(stu_id, exer_id).flatten()
        loss = F.binary_cross_entropy(input=pd, target=label)
        return {
            'loss_main': loss
        }

    def get_loss_dict(self, **kwargs):
        return self.get_main_loss(**kwargs)

    @torch.no_grad()
    def predict(self, stu_id, exer_id, **kwargs):
        return {
            'y_pd': self(stu_id, exer_id).flatten(),
        }


# ---- file: edustudio/model/KT/GKT/__init__.py ----
from .gkt import GKT


# ---- file: edustudio/model/KT/__init__.py ----
from .dkt import DKT
from .dkvmn import DKVMN
from .dkt_plus import DKT_plus
from .iekt import IEKT
from .dkt_forget import DKTForget
from .akt import AKT
from .ckt import CKT
from .hawkeskt import HawkesKT
from .kqn import KQN
from .deep_irt import DeepIRT
from .sakt import SAKT
from .dkt_dsc import DKTDSC
from .lpkt import LPKT
from .simplekt import SimpleKT
from .saint import SAINT
from .skvmn import SKVMN
from .saint_plus import SAINT_plus
from .ct_ncm import CT_NCM
from .lpkt_s import LPKT_S
from .rkt import RKT
from .qikt import QIKT
from .dtransformer import DTransformer
from .qdkt import QDKT
from .cl4kt import CL4KT
from .dimkt import DIMKT
from .GKT import *
from .atkt import ATKT
from .eernn import EERNNM, EERNNA
from .ekt import EKTM, EKTA


# ---- file: edustudio/model/KT/dkt.py ----
r"""
DKT
##########################################

Reference:
    Chris Piech et al. "Deep knowledge tracing" in NIPS 2015.

"""

from ..gd_basemodel import GDBaseModel
import torch.nn as nn
import torch
import torch.nn.functional as F


class DKT(GDBaseModel):
    """Deep Knowledge Tracing: an RNN/LSTM over joint (exercise, response)
    embeddings, predicting next-step correctness per exercise."""
    default_cfg = {
        'emb_size': 100,
        'hidden_size': 100,
        'num_layers': 1,
        'dropout_rate': 0.2,
        'rnn_or_lstm': 'lstm',  # which recurrent cell to use
    }

    def __init__(self, cfg):
        super().__init__(cfg)

    def build_cfg(self):
        self.n_user = self.datatpl_cfg['dt_info']['stu_count']
        self.n_item = self.datatpl_cfg['dt_info']['exer_count']
        assert self.modeltpl_cfg['rnn_or_lstm'] in {'rnn', 'lstm'}

    def build_model(self):
        # 2 * n_item rows: index = exer_id + label * n_item, so correct and
        # incorrect responses to the same exercise get distinct vectors.
        self.exer_emb = nn.Embedding(
            self.n_item * 2, self.modeltpl_cfg['emb_size']
        )
        if self.modeltpl_cfg['rnn_or_lstm'] == 'rnn':
            self.seq_model = nn.RNN(
                self.modeltpl_cfg['emb_size'], self.modeltpl_cfg['hidden_size'],
                self.modeltpl_cfg['num_layers'], batch_first=True
            )
        else:
            self.seq_model = nn.LSTM(
                self.modeltpl_cfg['emb_size'], self.modeltpl_cfg['hidden_size'],
                self.modeltpl_cfg['num_layers'], batch_first=True
            )
        self.dropout_layer = nn.Dropout(self.modeltpl_cfg['dropout_rate'])
        self.fc_layer = nn.Linear(self.modeltpl_cfg['hidden_size'], self.n_item)

    def forward(self, exer_seq, label_seq, **kwargs):
        input_x = self.exer_emb(exer_seq + label_seq.long() * self.n_item)
        output, _ = self.seq_model(input_x)
        output = self.dropout_layer(output)
        y_pd = self.fc_layer(output).sigmoid()
        return y_pd

    @torch.no_grad()
    def predict(self, **kwargs):
        y_pd = self(**kwargs)
        # Shift by one: the prediction at step t is evaluated against the
        # exercise/label at step t+1; padded steps are dropped via mask_seq.
        y_pd = y_pd[:, :-1].gather(
            index=kwargs['exer_seq'][:, 1:].unsqueeze(dim=-1), dim=2
        ).squeeze(dim=-1)
        y_pd = y_pd[kwargs['mask_seq'][:, 1:] == 1]
        y_gt = None
        if kwargs.get('label_seq', None) is not None:
            y_gt = kwargs['label_seq'][:, 1:]
            y_gt = y_gt[kwargs['mask_seq'][:, 1:] == 1]
        return {
            'y_pd': y_pd,
            'y_gt': y_gt
        }

    def get_main_loss(self, **kwargs):
        # Same shift-and-mask scheme as predict(), then BCE.
        y_pd = self(**kwargs)
        y_pd = y_pd[:, :-1].gather(
            index=kwargs['exer_seq'][:, 1:].unsqueeze(dim=-1), dim=2
        ).squeeze(dim=-1)
        y_pd = y_pd[kwargs['mask_seq'][:, 1:] == 1]
        y_gt = kwargs['label_seq'][:, 1:]
        y_gt = y_gt[kwargs['mask_seq'][:, 1:] == 1]
        loss = F.binary_cross_entropy(
            input=y_pd, target=y_gt
        )
        return {
            'loss_main': loss
        }

    def get_loss_dict(self, **kwargs):
        return self.get_main_loss(**kwargs)


# ---- file: edustudio/model/KT/dkt_plus.py ----
from .dkt import DKT
import torch.nn as nn
import torch
import torch.nn.functional as F


class DKT_plus(DKT):
    """DKT+: DKT with regularizers for reconstruction (lambda_r) and
    prediction-curve waviness (lambda_w1 / lambda_w2)."""
    default_cfg = {
        'lambda_r': 0.01,
        'lambda_w1': 0.003,
        'lambda_w2': 3.0,
        'reg_all_KCs': True  # regularize waviness over all KCs vs only answered ones
    }

    def __init__(self, cfg):
        super().__init__(cfg)

    def get_main_loss(self, **kwargs):
        q = kwargs['exer_seq'][:, :-1].unsqueeze(dim=-1)
        q_shft = kwargs['exer_seq'][:, 1:].unsqueeze(dim=-1)

        y = self(**kwargs)
        pred = y[:, :-1].gather(index=q, dim=2).squeeze(dim=-1)
        pred_shft = y[:, :-1].gather(index=q_shft, dim=2).squeeze(dim=-1)

        y_curr = pred[kwargs['mask_seq'][:, :-1] == 1]
        y_next = pred_shft[kwargs['mask_seq'][:, 1:] == 1]

        gt_curr = kwargs['label_seq'][:, :-1][kwargs['mask_seq'][:, :-1] == 1]
        gt_next = kwargs['label_seq'][:, 1:][kwargs['mask_seq'][:, 1:] == 1]

        loss_main = F.binary_cross_entropy(input=y_next, target=gt_next)
        # Reconstruction regularizer: also fit the current-step answer.
        loss_r = self.modeltpl_cfg['lambda_r'] * F.binary_cross_entropy(input=y_curr, target=gt_curr)

        # Waviness regularizers: penalize step-to-step change of predictions.
        if self.modeltpl_cfg['reg_all_KCs']:
            diff = y[:, 1:] - y[:, :-1]
        else:
            diff = (pred_shft - pred)[kwargs['mask_seq'][:, 1:] == 1]
        # NOTE(review): len(diff) is the size of the first dimension, not the
        # element count, and the DKT+ paper uses the squared L2 norm for w2 --
        # confirm against the reference implementation before changing.
        loss_w1 = torch.norm(diff, 1) / len(diff)
        loss_w1 = self.modeltpl_cfg['lambda_w1'] * loss_w1 / self.n_item
        loss_w2 = torch.norm(diff, 2) / len(diff)
        loss_w2 = self.modeltpl_cfg['lambda_w2'] * loss_w2 / self.n_item

        return {
            'loss_main': loss_main,
            'loss_r': loss_r,
            'loss_w1': loss_w1,
            'loss_w2': loss_w2
        }


# ---- file: edustudio/model/__init__.py ----
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

from .basemodel import BaseModel, BaseProxyModel
from .gd_basemodel import GDBaseModel
from .CD import *
from .KT import *


# ---- file: edustudio/model/gd_basemodel.py ----
import torch.nn as nn
import torch
from abc import abstractmethod
from .utils.common import xavier_normal_initialization, xavier_uniform_initialization, kaiming_normal_initialization, kaiming_uniform_initialization
from .basemodel import BaseModel


class GDBaseModel(BaseModel):
    """
    The model that using gradient descent method can inherit the class
    """
    default_cfg = {
        'param_init_type': 'xavier_normal',  # initialization method of model paramters
        'pretrained_file_path': "",  # file path of pretrained model parameters
    }

    def __init__(self, cfg):
        super().__init__(cfg)
        self.device = self.traintpl_cfg['device']
        # Shared flags that callbacks flip to talk back to the train loop
        # (e.g. EarlyStopping sets 'stop_training').
        self.share_callback_dict = {
            "stop_training": False
        }

    @abstractmethod
    def build_cfg(self):
        """Construct model config
        """
        pass

    @abstractmethod
    def build_model(self):
        """Construct model component
        """
        pass

    def _init_params(self):
        """Initialize the model parameters
        """
        # Dispatch on the configured scheme; 'default' keeps PyTorch's
        # per-layer defaults untouched.
        if self.modeltpl_cfg['param_init_type'] == 'default':
            pass
        elif self.modeltpl_cfg['param_init_type'] == 'xavier_normal':
            self.apply(xavier_normal_initialization)
        elif self.modeltpl_cfg['param_init_type'] == 'xavier_uniform':
            self.apply(xavier_uniform_initialization)
        elif self.modeltpl_cfg['param_init_type'] == 'kaiming_normal':
            self.apply(kaiming_normal_initialization)
        elif self.modeltpl_cfg['param_init_type'] == 'kaiming_uniform':
            self.apply(kaiming_uniform_initialization)
        elif self.modeltpl_cfg['param_init_type'] == 'init_from_pretrained':
            self._load_params_from_pretrained()

    def _load_params_from_pretrained(self):
        """Load pretrained model parameters
        """
        self.load_state_dict(torch.load(self.modeltpl_cfg['pretrained_file_path']))

    def predict(self, **kwargs):
        """predict process
        """
        pass

    def get_loss_dict(self, **kwargs):
        """Get a dict object. The key is the loss name, the value is the loss
        """
        pass


# ---- file: edustudio/model/utils/__init__.py ----
# (empty module in this dump)


# ---- file: edustudio/model/utils/common.py ----
import torch.nn as nn
from torch.nn.init import xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, constant_


def xavier_normal_initialization(module):
    r""" using `xavier_normal_`_ in PyTorch to initialize the parameters in
    nn.Embedding and nn.Linear layers. For bias in nn.Linear layers,
    using constant 0 to initialize.
    .. _`xavier_normal_`:
        https://pytorch.org/docs/stable/nn.init.html?highlight=xavier_normal_#torch.nn.init.xavier_normal_
    Examples:
        >>> self.apply(xavier_normal_initialization)
    """
    if isinstance(module, nn.Embedding):
        xavier_normal_(module.weight.data)
    elif isinstance(module, nn.Linear):
        xavier_normal_(module.weight.data)
        if module.bias is not None:
            constant_(module.bias.data, 0)


def xavier_uniform_initialization(module):
    r""" using `xavier_uniform_`_ in PyTorch to initialize the parameters in
    nn.Embedding and nn.Linear layers. For bias in nn.Linear layers,
    using constant 0 to initialize.
    .. _`xavier_uniform_`:
        https://pytorch.org/docs/stable/nn.init.html?highlight=xavier_uniform_#torch.nn.init.xavier_uniform_
    Examples:
        >>> self.apply(xavier_uniform_initialization)
    """
    if isinstance(module, nn.Embedding):
        xavier_uniform_(module.weight.data)
    elif isinstance(module, nn.Linear):
        xavier_uniform_(module.weight.data)
        if module.bias is not None:
            constant_(module.bias.data, 0)


def kaiming_normal_initialization(module):
    r""" using `kaiming_normal_`_ in PyTorch to initialize the parameters in
    nn.Embedding and nn.Linear layers. For bias in nn.Linear layers,
    using constant 0 to initialize.
    .. _`kaiming_normal`:
        https://pytorch.org/docs/stable/nn.init.html?highlight=kaiming_normal_#torch.nn.init.kaiming_normal_
    Examples:
        >>> self.apply(kaiming_normal_initialization)
    """
    if isinstance(module, nn.Embedding):
        kaiming_normal_(module.weight.data)
    elif isinstance(module, nn.Linear):
        kaiming_normal_(module.weight.data)
        if module.bias is not None:
            constant_(module.bias.data, 0)

def kaiming_uniform_initialization(module):
    r""" using `kaiming_uniform_`_ in PyTorch to initialize the parameters in
    nn.Embedding and nn.Linear layers. For bias in nn.Linear layers,
    using constant 0 to initialize.
    .. _`kaiming_uniform`:
        https://pytorch.org/docs/stable/nn.init.html?highlight=kaiming_uniform_#torch.nn.init.kaiming_uniform_
    Examples:
        >>> self.apply(kaiming_uniform_initialization)
    """
    if isinstance(module, nn.Embedding):
        kaiming_uniform_(module.weight.data)
    elif isinstance(module, nn.Linear):
        kaiming_uniform_(module.weight.data)
        if module.bias is not None:
            constant_(module.bias.data, 0)


# ---- file: edustudio/quickstart/__init__.py ----
from .quickstart import run_edustudio


# ---- file: edustudio/quickstart/atom_cmds.py ----
import fire
from edustudio.atom_op.raw2mid import _cli_api_dict_ as raw2mid
from edustudio.utils.common import Logger

def entrypoint():
    """CLI entry: exposes raw->middata atomic operations under the `r2m` command."""
    Logger().get_std_logger()
    fire.Fire(
        {"r2m": raw2mid}
    )


# ---- file: edustudio/quickstart/init_all.py ----
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

from edustudio.utils.common import UnifyConfig, PathUtil, Logger
import os


def init_all(cfg: UnifyConfig):
    """initialize process

    Args:
        cfg (UnifyConfig): the global config obejct
    """
    frame_cfg = cfg.frame_cfg
    dataset = cfg.dataset
    # Config accepts either a class object or its name as a string.
    traintpl_cls_name = cfg.traintpl_cfg.cls if isinstance(cfg.traintpl_cfg.cls, str) else cfg.traintpl_cfg.cls.__name__
    model_cls_name = cfg.modeltpl_cfg.cls if isinstance(cfg.modeltpl_cfg.cls, str) else cfg.modeltpl_cfg.cls.__name__

    frame_cfg.data_folder_path = f"{frame_cfg.DATA_FOLDER_PATH}/{dataset}"
    # PathUtil.check_path_exist(frame_cfg.data_folder_path)

    frame_cfg.TEMP_FOLDER_PATH = os.path.realpath(frame_cfg.TEMP_FOLDER_PATH)
    frame_cfg.ARCHIVE_FOLDER_PATH = os.path.realpath(frame_cfg.ARCHIVE_FOLDER_PATH)

    # Per-run folders: .../<dataset>/<traintpl>/<model>[/<run-id>]
    frame_cfg.temp_folder_path = f"{frame_cfg.TEMP_FOLDER_PATH}/{dataset}/{traintpl_cls_name}/{model_cls_name}/{frame_cfg.ID}"
    frame_cfg.archive_folder_path = f"{frame_cfg.ARCHIVE_FOLDER_PATH}/{dataset}/{traintpl_cls_name}/{model_cls_name}"
    PathUtil.auto_create_folder_path(
        frame_cfg.temp_folder_path, frame_cfg.archive_folder_path
    )
    log_filepath = f"{frame_cfg.temp_folder_path}/{frame_cfg.ID}.log"
    if frame_cfg['LOG_WITHOUT_DATE']:
        cfg.logger = Logger(
            filepath=log_filepath, fmt='[%(levelname)s]: %(message)s', date_fmt=None,
            DISABLE_LOG_STDOUT=cfg['frame_cfg']['DISABLE_LOG_STDOUT']
        ).get_std_logger()
    else:
        cfg.logger = Logger(
            filepath=log_filepath, DISABLE_LOG_STDOUT=cfg['frame_cfg']['DISABLE_LOG_STDOUT']
        ).get_std_logger()

    # Globally disable tqdm progress bars if requested.
    if frame_cfg['DISABLE_TQDM_BAR'] is True:
        from tqdm import tqdm
        from functools import partialmethod
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)


# ---- file: edustudio/settings.py ----
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

from edustudio.utils.common import PathUtil as pathUtil
from edustudio.utils.common import IDUtil as idUtil
import os
from edustudio import __version__

ID = idUtil.get_random_id_bytime()  # RUN ID
EDUSTUDIO_VERSION = __version__

WORK_DIR = os.getcwd()

DATA_FOLDER_PATH = f"{WORK_DIR}/data"
TEMP_FOLDER_PATH = f"{WORK_DIR}/temp"
ARCHIVE_FOLDER_PATH = f"{WORK_DIR}/archive"
CFG_FOLDER_PATH = f"{WORK_DIR}/conf"

# Working folders are created at import time (module-level side effect).
pathUtil.auto_create_folder_path(
    TEMP_FOLDER_PATH,
    ARCHIVE_FOLDER_PATH,
    DATA_FOLDER_PATH,
    CFG_FOLDER_PATH,
)

DISABLE_TQDM_BAR = False
LOG_WITHOUT_DATE = False
TQDM_NCOLS = 100
DISABLE_LOG_STDOUT = False

curr_file_folder = os.path.dirname(__file__)
DT_INFO_FILE_PATH = os.path.realpath(curr_file_folder + "/assets/datasets.yaml")

DT_INFO_DICT = {}  # additional dataset info entrypoint, example: {'FrcSub': {middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/FrcSub/FrcSub-middata.zip} }


# ---- file: edustudio/traintpl/__init__.py ----
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

from .base_traintpl import BaseTrainTPL
from .gd_traintpl import GDTrainTPL
from .general_traintpl import GeneralTrainTPL
from .atkt_traintpl import AtktTrainTPL
from .dcd_traintpl import DCDTrainTPL
from .adversarial_traintpl import AdversarialTrainTPL
from .group_cd_traintpl import GroupCDTrainTPL


# ---- file: edustudio/utils/__init__.py ----
# (empty module in this dump)


# ---- file: edustudio/utils/callback/__init__.py ----
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

from .callbacks.callback import Callback
from .callbacks.earlyStopping import EarlyStopping
from .callbacks.baseLogger import BaseLogger
from .callbacks.modelCheckPoint import ModelCheckPoint
from .callbacks.history import History
from .callbacks.epochPredict import EpochPredict
from .callBackList import CallbackList
from .modeState import ModeState
from .callbacks.tensorboardCallBack import TensorboardCallback


# ---- file: edustudio/utils/callback/callbacks/baseLogger.py ----
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

from .callback import Callback
from collections import defaultdict


class BaseLogger(Callback):
    """Callback that pretty-prints per-epoch metric logs.

    Metrics whose name contains one of the ``group_by_contains`` substrings
    are printed first, one line per substring; the remaining metrics are
    printed in batches of ``group_by_count`` per line.
    """
    def __init__(self, logger=None, **kwargs):
        super(BaseLogger, self).__init__()
        self.log = print
        if logger is not None and hasattr(logger, 'info'):
            self.log = logger.info

        self.join_str = kwargs.get("join_str", " | ")
        # BUG FIX: the default used to be ('loss') -- a parenthesized STRING,
        # not a tuple -- so iterating it walked the characters
        # 'l', 'o', 's', 's' and grouped metrics by single letters.
        # A one-element tuple requires a trailing comma.
        self.group_by_contains = kwargs.get('group_by_contains', ('loss',))
        self.group_by_count = kwargs.get('group_by_count', 5)
        assert self.group_by_count >= 1 and type(self.group_by_count) is int

    def on_train_begin(self, logs=None, **kwargs):
        super().on_train_begin()
        self.log("Start Training...")

    def on_epoch_end(self, epoch: int, logs: dict = None, **kwargs):
        info = f"[EPOCH={epoch:03d}]: "
        # flag[i] marks the i-th logs entry as already printed by a group rule.
        flag = [False] * len(logs)
        for group_rule in self.group_by_contains:
            v_list = []
            for i, (k, v) in enumerate(logs.items()):
                if group_rule in k and not flag[i]:
                    v_list.append(f"{k}: {v:.4f}")
                    flag[i] = True
            if len(v_list) > 0: self.log(f"{info}{self.join_str.join(v_list)}")

        # Keep only the metrics not consumed by a group rule.
        logs = {k: logs[k] for i, k in enumerate(logs) if not flag[i]}

        # Print the remaining metrics, group_by_count per line.
        v_list = []
        for i, (k, v) in enumerate(logs.items()):
            v_list.append(f"{k}: {v:.4f}")
            if (i+1) % self.group_by_count == 0:
                self.log(f"{info}{self.join_str.join(v_list)}")
                v_list = []
        if len(v_list) > 0: self.log(f"{info}{self.join_str.join(v_list)}")

    def on_train_end(self, logs=None, **kwargs):
        self.log("Training Completed!")


# ---- file: edustudio/utils/callback/callbacks/callback.py ----
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

import torch.nn as nn
from ..modeState import ModeState
import logging


class Callback(object):
    """Base class for training callbacks; all hooks are no-ops by default."""
    def __init__(self):
        self.model = None
        self.mode_state = None
        self.logger = None
        self.callback_list = None

    def set_model(self, model: nn.Module):
        self.model = model

    def set_state(self, mode_state: ModeState):
        self.mode_state = mode_state

    def set_logger(self, logger: logging.Logger):
        self.logger = logger

    def set_callback_list(self, callback_list):
        self.callback_list = callback_list

    def on_train_begin(self, logs=None, **kwargs):
        self.logger.info(f"[CALLBACK]-{self.__class__.__name__} has been registered!")

    def on_train_end(self, logs=None, **kwargs):
        pass

    def on_epoch_begin(self, epoch, logs=None, **kwargs):
        pass

    def on_epoch_end(self, epoch, logs=None, **kwargs):
        pass

    def on_train_batch_begin(self, batch, logs=None, **kwargs):
        pass

    def on_train_batch_end(self, batch, logs=None, **kwargs):
        pass
# ---- file: edustudio/utils/callback/callbacks/earlyStopping.py ----
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

from typing import List
from .callback import Callback
import numpy as np


class Metric(object):
    """Tracks the best value/epoch seen so far for one monitored metric."""
    def __init__(self, name, type_):
        assert type_ in ['max', 'min']
        self.name = name
        self.type_ = type_
        self.best_epoch = 1
        if self.type_ == 'max':
            self.best_value = -np.inf
        else:
            self.best_value = np.inf

    def better_than(self, value):
        # True when the STORED BEST is better than `value`
        # (i.e. `value` is NOT an improvement).
        return value < self.best_value if self.type_ == 'max' else value > self.best_value

    def update(self, epoch, value):
        self.best_epoch = epoch
        self.best_value = value


class EarlyStopping(Callback):
    def __init__(self, metric_list:List[list], num_stop_rounds: int = 20, start_round=1):
        """_summary_

        Args:
            metric_list (List[list]): [['rmse', 'min'],['ndcg', 'max']]
            num_stop_rounds (int, optional): all metrics have no improvement in latest num_stop_rounds, suggest to stop training. Defaults to 20.
            start_round (int, optional): start detecting from epoch start_round, . Defaults to 1.
        """
        super().__init__()
        assert num_stop_rounds >= 1
        assert start_round >= 1
        self.start_round = start_round
        self.num_stop_round = num_stop_rounds
        self.stop_training = False
        self.metric_list = [Metric(name=metric_name, type_=metric_type) for metric_name, metric_type in metric_list]

    def on_train_begin(self, logs=None, **kwargs):
        super().on_train_begin()
        self.model.share_callback_dict['stop_training'] = False

    def on_epoch_end(self, epoch: int, logs: dict = None, **kwargs):
        # Always update the per-metric bests first, even before start_round,
        # so best_epoch stays accurate once detection begins.
        for metric in self.metric_list:
            if not metric.better_than(logs[metric.name]):
                metric.update(epoch=epoch, value=logs[metric.name])

        # BUG FIX: the old code initialized the stop flag to True and only
        # cleared it when `start_round <= epoch`, so any start_round > 1
        # made it suggest stopping at the very first epoch.  Detection must
        # simply be skipped before start_round.
        if epoch < self.start_round:
            return

        # Suggest stopping only when EVERY monitored metric has gone at
        # least num_stop_round epochs without an improvement (identical to
        # the original post-start_round condition).
        if all(epoch - m.best_epoch >= self.num_stop_round for m in self.metric_list):
            self.logger.info("Suggest to stop training now")
            self.stop_training = True
            self.model.share_callback_dict['stop_training'] = True


# ---- file: edustudio/utils/callback/callbacks/epochPredict.py ----
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

from .callback import Callback
import os
import pickle


class EpochPredict(Callback):
    """Runs model.predict at the end of every epoch and pickles the result
    to `<save_folder_path>/<fmt>.pkl` (fmt must contain '{epoch}')."""
    def __init__(self, save_folder_path, fmt="predict-{epoch}", **kwargs):
        super(EpochPredict, self).__init__()
        self.kwargs = kwargs  # forwarded verbatim to model.predict
        self.save_folder_path = save_folder_path
        if not os.path.exists(save_folder_path):
            os.makedirs(self.save_folder_path)
        self.file_fmt = fmt + ".pkl"
        self.pd = None  # last prediction result
        assert "{epoch}" in self.file_fmt

    def on_epoch_end(self, epoch, logs=None, **kwargs):
        self.pd = self.model.predict(**self.kwargs)
        filepath = os.path.join(self.save_folder_path, self.file_fmt.format(epoch=epoch))
        self.to_pickle(filepath, self.pd)

    @staticmethod
    def to_pickle(filepath, obj):
        with open(filepath, 'wb') as f:
            pickle.dump(obj, f, protocol=4)


# ---- file: edustudio/utils/callback/callbacks/history.py ----
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

from .callback import Callback
from collections import defaultdict
from typing import List, Sequence, Optional
import json
import os
import numpy as np
from edustudio.utils.common import NumpyEncoder


class History(Callback):
    """Records scalar per-epoch logs, dumps them to history.json at the end
    of training, and optionally plots one curve per metric."""
    def __init__(self, folder_path, exclude_metrics=(), plot_curve=False):
        super(History, self).__init__()
        self.log_as_time = {}  # epoch -> {metric name: scalar value}
        self.exclude_metrics = set(exclude_metrics)
        self.folder_path = folder_path
        if not os.path.exists(self.folder_path):
            os.makedirs(self.folder_path)

        self.plot_curve = plot_curve
        if self.plot_curve:
            import matplotlib  # availability probe: fail fast if plotting was requested
            if not os.path.exists(self.folder_path+"/plots/"):
                os.makedirs(self.folder_path+"/plots/")
        self.names = None  # metric-name set; asserted identical across epochs

    def on_epoch_end(self, epoch, logs=None, **kwargs):
        # Keep only scalar entries (int/str/float) not explicitly excluded.
        self.log_as_time[epoch] = {k:v for k,v in logs.items() if isinstance(v, (int, str, float)) and k not in self.exclude_metrics}
        if self.names is None:
            self.names = set(self.log_as_time[epoch].keys())
        else:
            assert self.names == set(self.log_as_time[epoch].keys())


    def on_train_end(self, logs=None, **kwargs):
        self.dump_json(self.log_as_time, os.path.join(self.folder_path, "history.json"))

        if self.plot_curve:
            self.logger.info("[CALLBACK]-History: Plot Curve...")
            self.plot()
            self.logger.info("[CALLBACK]-History: Plot Succeed!")

    def plot(self):
        import matplotlib.pyplot as plt
        # NOTE(review): assumes epochs are contiguous 1..N -- holds as long
        # as on_epoch_end was called for every epoch.
        epoch_num = len(self.log_as_time)
        x_arr = np.arange(1, epoch_num+1)
        for name in self.names:
            value_arr = [self.log_as_time[i][name] for i in range(1, epoch_num+1)]
            plt.figure()
            plt.title(name)
            plt.xlabel("epoch")
            plt.ylabel("value")
            plt.plot(x_arr, value_arr)
            plt.autoscale()
            plt.savefig(f"{self.folder_path}/plots/{name}.png", dpi=500, bbox_inches='tight', pad_inches=0.1)

    @staticmethod
    def dump_json(data, filepath):
        if not os.path.exists(os.path.dirname(filepath)):
            os.makedirs(os.path.dirname(filepath))
        with open(filepath, 'w', encoding='utf8') as f:
            json.dump(data, fp=f, indent=2, ensure_ascii=False, cls=NumpyEncoder)


# ---- file: edustudio/utils/callback/callbacks/modelCheckPoint.py ----
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

from .callback import Callback
from typing import Tuple, List
import torch
import os
import shutil
import glob
from .baseLogger import BaseLogger
import numpy as np
from collections import namedtuple


class Metric(object):
    """Best-so-far tracker for one monitored metric (checkpoint variant:
    additionally remembers the full log dict of the best epoch)."""
    def __init__(self, name, type_):
        assert type_ in ['max', 'min']
        self.name = name
        self.type_ = type_
        self.best_epoch = 1
        self.best_log = dict()

        if self.type_ == 'max':
            self.best_value = -np.inf
        else:
            self.best_value = np.inf

    def better_than(self, value):
        # True when the STORED BEST is better than `value`
        # (i.e. `value` is NOT an improvement).
        return value < self.best_value if self.type_ == 'max' else value > self.best_value

    def update(self, epoch, value, log):
        self.best_epoch = epoch
        self.best_value = value
        self.best_log = log

class ModelCheckPoint(Callback):
    """Saves model weights whenever a monitored metric improves; optionally
    also saves a checkpoint every epoch (save_best_only=False)."""
    def __init__(self, metric_list: List[list], save_folder_path, save_best_only=True):
        super().__init__()
        self.save_folder_path = save_folder_path
        if not os.path.exists(self.save_folder_path):
            os.makedirs(self.save_folder_path)
        self.save_best_only = save_best_only
        self.metric_list = [Metric(name=metric_name, type_=metric_type) for metric_name, metric_type in metric_list]


    def on_epoch_end(self, epoch, logs=None, **kwargs):
        for metric in self.metric_list:
            if not metric.better_than(logs[metric.name]):
                metric.update(epoch=epoch, value=logs[metric.name], log=logs)

                # NOTE(review): "{[epoch]}" renders as a one-element list,
                # e.g. "[5]" -- probably intended "{epoch}".
                self.logger.info(f"Update best epoch as {[epoch]} for {metric.name}!")
                filepath = os.path.join(self.save_folder_path, f"best-epoch-for-{metric.name}.pth")
                torch.save(obj=self.model.state_dict(), f=filepath)

        if not self.save_best_only:
            filepath = os.path.join(self.save_folder_path, f"{epoch:03d}.pth")
            torch.save(obj=self.model.state_dict(), f=filepath)

    def on_train_end(self, logs=None, **kwargs):
        self.logger.info("==="*10)
        for metric in self.metric_list:
            self.logger.info(f"[For {metric.name}], the Best Epoch is: {metric.best_epoch}, the value={metric.best_value:.4f}")
            # Re-echo the full log line of the best epoch through BaseLogger.
            for cb in self.callback_list.callbacks:
                if isinstance(cb, BaseLogger):
                    cb.on_epoch_end(metric.best_epoch, logs=metric.best_log)
            # rename best epoch filename
            os.renames(
                os.path.join(self.save_folder_path, f"best-epoch-for-{metric.name}.pth"),
                os.path.join(self.save_folder_path, f"best-epoch-{metric.best_epoch:03d}-for-{metric.name}.pth")
            )
        self.logger.info("===" * 10)


# ---- file: edustudio/utils/callback/callbacks/tensorboardCallBack.py ----
# -*- coding: utf-8 -*-
# @Author : Xiangzhi Chen
# @Github : kervias

from .callback import Callback

class TensorboardCallback(Callback):
    def __init__(self, log_dir,
comment=''): 9 | super(TensorboardCallback, self).__init__() 10 | from torch.utils.tensorboard import SummaryWriter 11 | self.writer = SummaryWriter(log_dir=log_dir, comment=comment) 12 | 13 | def on_epoch_end(self, epoch, logs=None, **kwargs): 14 | for name, value in logs.items(): 15 | self.writer.add_scalar(tag=name, scalar_value=value, global_step=epoch) 16 | -------------------------------------------------------------------------------- /edustudio/utils/callback/modeState.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : Xiangzhi Chen 3 | # @Github : kervias 4 | 5 | from enum import Enum 6 | 7 | 8 | class ModeState(Enum): 9 | START = 1 10 | TRAINING = 2 11 | END = 3 12 | -------------------------------------------------------------------------------- /edustudio/utils/common/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : Xiangzhi Chen 3 | # @Github : kervias 4 | 5 | 6 | from .commonUtil import PathUtil, IDUtil, tensor2npy, tensor2cpu, IOUtil, set_same_seeds, get_gpu_usage, DecoratorTimer 7 | from .configUtil import UnifyConfig, NumpyEncoder 8 | from .loggerUtil import Logger 9 | -------------------------------------------------------------------------------- /edustudio/utils/common/loggerUtil.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : Xiangzhi Chen 3 | # @Github : kervias 4 | 5 | import logging 6 | import pytz 7 | import sys 8 | import os 9 | 10 | 11 | class Logger(object): 12 | def __init__(self, filepath: str = None, 13 | fmt: str = "%(asctime)s[%(levelname)s]: %(message)s", 14 | date_fmt: str = "%Y-%m-%d %H:%M:%S", 15 | timezone: str = "Asia/Shanghai", 16 | level=logging.DEBUG, 17 | DISABLE_LOG_STDOUT=False 18 | ): 19 | self.timezone = timezone 20 | if filepath: 21 | dir_name = os.path.dirname(filepath) 22 
| if not os.path.exists(dir_name): 23 | os.makedirs(dir_name) 24 | 25 | self.logger = logging.getLogger("edustudio") 26 | self.logger.setLevel(level) 27 | formatter = logging.Formatter( 28 | fmt, datefmt=date_fmt 29 | ) 30 | formatter.converter = self.converter 31 | 32 | if filepath: 33 | # write into file 34 | fh = logging.FileHandler(filepath, mode='a+', encoding='utf-8') 35 | fh.setLevel(level) 36 | fh.setFormatter(formatter) 37 | self.logger.addHandler(fh) 38 | 39 | # show on console 40 | ch = logging.StreamHandler(sys.stdout) 41 | ch.setLevel(level) 42 | ch.setFormatter(formatter) 43 | flag = False 44 | for handler in self.logger.handlers: 45 | if type(handler) is logging.StreamHandler: 46 | flag = True 47 | if flag is False and DISABLE_LOG_STDOUT is False: 48 | self.logger.addHandler(ch) 49 | 50 | def _flush(self): 51 | for handler in self.logger.handlers: 52 | handler.flush() 53 | 54 | def debug(self, message): 55 | self.logger.debug(message) 56 | self._flush() 57 | 58 | def info(self, message): 59 | self.logger.info(message) 60 | self._flush() 61 | 62 | def warning(self, message): 63 | self.logger.warning(message) 64 | self._flush() 65 | 66 | def error(self, message): 67 | self.logger.error(message) 68 | self._flush() 69 | 70 | def critical(self, message): 71 | self.logger.critical(message) 72 | self._flush() 73 | 74 | def converter(self, sec): 75 | tz = pytz.timezone(self.timezone) 76 | dt = pytz.datetime.datetime.fromtimestamp(sec, tz) 77 | return dt.timetuple() 78 | 79 | def get_std_logger(self): 80 | return self.logger 81 | -------------------------------------------------------------------------------- /examples/1.run_cd_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | 
dataset='FrcSub', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'GeneralTrainTPL', 14 | 'device': 'cpu' 15 | }, 16 | datatpl_cfg_dict={ 17 | 'cls': 'CDInterExtendsQDataTPL' 18 | }, 19 | modeltpl_cfg_dict={ 20 | 'cls': 'KaNCD', 21 | }, 22 | evaltpl_cfg_dict={ 23 | 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL', 'IdentifiabilityEvalTPL'], 24 | 'PredictionEvalTPL': { 25 | 'use_metrics': ['auc'], 26 | }, 27 | 'InterpretabilityEvalTPL': { 28 | 'use_metrics': ['doa_all', 'doc_all'], 29 | 'test_only_metrics': ['doa_all', 'doc_all'], 30 | }, 31 | 'IdentifiabilityEvalTPL': { 32 | 'use_metrics': ['IDS'] 33 | } 34 | }, 35 | frame_cfg_dict={ 36 | 'ID': '123456', 37 | } 38 | ) 39 | -------------------------------------------------------------------------------- /examples/2.run_kt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'GeneralTrainTPL', 14 | }, 15 | datatpl_cfg_dict={ 16 | 'cls': 'KTInterCptUnfoldDataTPL' 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'DKT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/3.run_with_customized_tpl.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | from edustudio.model import DKT 9 | 10 | run_edustudio( 11 | dataset='ASSIST_0910', 12 | cfg_file_name=None, 13 | 
traintpl_cfg_dict={ 14 | 'cls': 'GeneralTrainTPL', 15 | }, 16 | datatpl_cfg_dict={ 17 | 'cls': 'KTInterCptUnfoldDataTPL' 18 | }, 19 | modeltpl_cfg_dict={ 20 | 'cls': DKT, 21 | }, 22 | evaltpl_cfg_dict={ 23 | 'clses': ['PredictionEvalTPL'], 24 | } 25 | ) 26 | -------------------------------------------------------------------------------- /examples/4.run_cmd_demo.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import sys 3 | import os 4 | 5 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../") 6 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 7 | 8 | 9 | from edustudio.atom_op.raw2mid import _cli_api_dict_ as raw2mid 10 | from edustudio.utils.common import Logger 11 | 12 | def entrypoint(): 13 | Logger().get_std_logger() 14 | fire.Fire( 15 | {"r2m": raw2mid} 16 | ) 17 | 18 | entrypoint() 19 | -------------------------------------------------------------------------------- /examples/5.run_with_hyperopt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | from hyperopt import hp 9 | from hyperopt import fmin, tpe, space_eval 10 | from edustudio.utils.common import IDUtil as idUtil 11 | import uuid 12 | 13 | def deliver_cfg(args): 14 | g_args = { 15 | 'traintpl_cfg': {}, 16 | 'datatpl_cfg': {}, 17 | 'modeltpl_cfg': {}, 18 | 'evaltpl_cfg': {}, 19 | } 20 | for k,v in args.items(): 21 | g, k = k.split(".") 22 | assert g in g_args 23 | g_args[g][k] = v 24 | g_args['frame_cfg'] = { 25 | 'ID': idUtil.get_random_id_bytime() + str(uuid.uuid4()).split("-")[-1] 26 | } 27 | return g_args 28 | 29 | 30 | # objective function 31 | def objective_function(args): 32 | g_args = deliver_cfg(args) 33 | cfg, res = run_edustudio( 34 | dataset='FrcSub', 35 | 
cfg_file_name=None, 36 | traintpl_cfg_dict=g_args['traintpl_cfg'], 37 | datatpl_cfg_dict=g_args['datatpl_cfg'], 38 | modeltpl_cfg_dict=g_args['modeltpl_cfg'], 39 | evaltpl_cfg_dict=g_args['evaltpl_cfg'], 40 | frame_cfg_dict=g_args['frame_cfg'], 41 | return_cfg_and_result=True, 42 | ) 43 | return res['auc'] 44 | 45 | 46 | space = { 47 | 'traintpl_cfg.cls': hp.choice('traintpl_cfg.cls', ['GeneralTrainTPL']), 48 | 'datatpl_cfg.cls': hp.choice('datapl_cfg.cls', ['CDInterExtendsQDataTPL']), 49 | 'modeltpl_cfg.cls': hp.choice('modeltpl_cfg.cls', ['KaNCD']), 50 | 'evaltpl_cfg.clses': hp.choice('evaltpl_cfg.clses', [['PredictionEvalTPL', 'InterpretabilityEvalTPL']]), 51 | 52 | 53 | 'traintpl_cfg.batch_size': hp.choice('traintpl_cfg.batch_size', [256,]), 54 | 'traintpl_cfg.epoch_num': hp.choice('traintpl_cfg.epoch_num', [2]), 55 | 'modeltpl_cfg.emb_dim': hp.choice('modeltpl_cfg.emb_dim', [20,40]) 56 | } 57 | 58 | best = fmin(objective_function, space, algo=tpe.suggest, max_evals=10, verbose=False) 59 | 60 | print("=="*10) 61 | print(best) 62 | print(space_eval(space, best)) 63 | -------------------------------------------------------------------------------- /examples/6.run_with_ray.tune.py: -------------------------------------------------------------------------------- 1 | # run following after installed edustudio 2 | import sys 3 | import os 4 | 5 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../") 6 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 7 | 8 | from edustudio.quickstart import run_edustudio 9 | from ray import tune 10 | import ray 11 | ray.init(num_cpus=4, num_gpus=1) 12 | from edustudio.utils.common import IDUtil as idUtil 13 | import uuid 14 | 15 | def deliver_cfg(args): 16 | g_args = { 17 | 'traintpl_cfg': {}, 18 | 'datatpl_cfg': {}, 19 | 'modeltpl_cfg': {}, 20 | 'evaltpl_cfg': {}, 21 | 'frame_cfg': {}, 22 | } 23 | for k,v in args.items(): 24 | g, k = k.split(".") 25 | assert g in g_args 26 | g_args[g][k] = v 27 | 
g_args['frame_cfg'] = { 28 | 'ID': idUtil.get_random_id_bytime() + str(uuid.uuid4()).split("-")[-1] 29 | } 30 | return g_args 31 | 32 | 33 | # objective function 34 | def objective_function(args): 35 | g_args = deliver_cfg(args) 36 | print(g_args) 37 | cfg, res = run_edustudio( 38 | dataset='FrcSub', 39 | cfg_file_name=None, 40 | traintpl_cfg_dict=g_args['traintpl_cfg'], 41 | datatpl_cfg_dict=g_args['datatpl_cfg'], 42 | modeltpl_cfg_dict=g_args['modeltpl_cfg'], 43 | evaltpl_cfg_dict=g_args['evaltpl_cfg'], 44 | frame_cfg_dict=g_args['frame_cfg'], 45 | return_cfg_and_result=True 46 | ) 47 | return res 48 | 49 | 50 | search_space= { 51 | 'traintpl_cfg.cls': tune.grid_search(['GeneralTrainTPL']), 52 | 'datatpl_cfg.cls': tune.grid_search(['CDInterExtendsQDataTPL']), 53 | 'modeltpl_cfg.cls': tune.grid_search(['KaNCD']), 54 | 'evaltpl_cfg.clses': tune.grid_search([['PredictionEvalTPL', 'InterpretabilityEvalTPL']]), 55 | 56 | 57 | 'traintpl_cfg.batch_size': tune.grid_search([256,]), 58 | 'traintpl_cfg.epoch_num': tune.grid_search([2]), 59 | 'traintpl_cfg.device': tune.grid_search(["cuda:0"]), 60 | 'modeltpl_cfg.emb_dim': tune.grid_search([20,40]), 61 | 'frame_cfg.DISABLE_LOG_STDOUT': tune.grid_search([False]), 62 | } 63 | 64 | tuner = tune.Tuner( 65 | tune.with_resources(objective_function, {"gpu": 1}), param_space=search_space, tune_config=tune.TuneConfig(max_concurrent_trials=1), 66 | ) 67 | results = tuner.fit() 68 | 69 | print("=="*10) 70 | print(results.get_best_result(metric="auc", mode="max").config) 71 | -------------------------------------------------------------------------------- /examples/7.run_toy_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | 
cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'BaseTrainTPL', 14 | }, 15 | datatpl_cfg_dict={ 16 | 'cls': 'BaseDataTPL' 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'BaseModel', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['BaseEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_akt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'GeneralTrainTPL', 14 | }, 15 | datatpl_cfg_dict={ 16 | 'cls': 'KTInterCptUnfoldDataTPL' 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'AKT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_atkt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'AtktTrainTPL', 14 | }, 15 | datatpl_cfg_dict={ 16 | 'cls': 'KTInterCptUnfoldDataTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'ATKT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_cdgk_demo.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='FrcSub', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'GeneralTrainTPL', 14 | 'batch_size': 1024, 15 | 'eval_batch_size': 1024 16 | }, 17 | datatpl_cfg_dict={ 18 | 'cls': 'CDGKDataTPL', 19 | 'M2C_CDGK_OP': { 20 | 'subgraph_count': 1, 21 | } 22 | }, 23 | modeltpl_cfg_dict={ 24 | 'cls': 'CDGK_MULTI', 25 | }, 26 | evaltpl_cfg_dict={ 27 | 'clses': ['PredictionEvalTPL'], 28 | } 29 | ) 30 | -------------------------------------------------------------------------------- /examples/single_model/run_cdmfkc_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | 5 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 6 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 7 | 8 | from edustudio.quickstart import run_edustudio 9 | 10 | run_edustudio( 11 | dataset='FrcSub', 12 | cfg_file_name=None, 13 | traintpl_cfg_dict={ 14 | 'cls': 'GeneralTrainTPL', 15 | 'lr': 0.01, 16 | 'epoch_num': 1000 17 | }, 18 | datatpl_cfg_dict={ 19 | 'cls': 'CDInterExtendsQDataTPL', 20 | }, 21 | modeltpl_cfg_dict={ 22 | 'cls': 'CDMFKC', 23 | }, 24 | evaltpl_cfg_dict={ 25 | 'clses': ['PredictionEvalTPL'], 26 | } 27 | ) 28 | -------------------------------------------------------------------------------- /examples/single_model/run_ckt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | 
run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterExtendsQDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | 'batch_size': 32, 18 | 'eval_batch_size': 32 19 | }, 20 | modeltpl_cfg_dict={ 21 | 'cls': 'CKT', 22 | }, 23 | evaltpl_cfg_dict={ 24 | 'clses': ['PredictionEvalTPL'], 25 | } 26 | ) 27 | -------------------------------------------------------------------------------- /examples/single_model/run_cl4kt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'CL4KTDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | 'eval_batch_size': 1024, 18 | }, 19 | modeltpl_cfg_dict={ 20 | 'cls': 'CL4KT', 21 | }, 22 | evaltpl_cfg_dict={ 23 | 'clses': ['PredictionEvalTPL'], 24 | } 25 | ) 26 | -------------------------------------------------------------------------------- /examples/single_model/run_cncd_f_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='AAAI_2023', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'CNCDFDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'CNCD_F', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | 
-------------------------------------------------------------------------------- /examples/single_model/run_cncdq_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='AAAI_2023', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'CNCDQDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'CNCD_Q', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_ctncm_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterCptUnfoldDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'CT_NCM', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | 26 | -------------------------------------------------------------------------------- /examples/single_model/run_dcd_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 4 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 5 | 6 | from edustudio.quickstart import run_edustudio 7 | 8 | run_edustudio( 9 | dataset='FrcSub', 
10 | datatpl_cfg_dict={ 11 | 'cls': 'DCDDataTPL', 12 | 'M2C_BuildMissingQ': { 13 | 'Q_delete_ratio': 0.8 14 | }, 15 | 'M2C_FillMissingQ': { 16 | 'Q_fill_type': "None" 17 | } 18 | }, 19 | modeltpl_cfg_dict={ 20 | 'cls': 'DCD', 21 | 'EncoderUserHidden': [768], 22 | 'EncoderItemHidden': [768], 23 | 'lambda_main': 1.0, 24 | 'lambda_q': 10.0, 25 | 'b_sample_type': 'gumbel_softmax', 26 | 'b_sample_kwargs': {'tau': 1.0, 'hard': True}, 27 | 'align_margin_loss_kwargs': {'margin': 0.7, 'topk': 2, 'd1': 1, 'margin_lambda': 0.5, 'norm': 2, 'norm_lambda': 0.5, 'start_epoch': 1}, 28 | 'beta_user': 0.0, 29 | 'beta_item': 0.0, 30 | 'g_beta_user': 1.0, 31 | 'g_beta_item': 1.0, 32 | 'alpha_user': 0.0, 33 | 'alpha_item': 0.0, 34 | 'gamma_user': 1.0, 35 | 'gamma_item': 1.0, 36 | 'sampling_type': 'mws', 37 | 'bernoulli_prior_p': 0.2, 38 | }, 39 | traintpl_cfg_dict={ 40 | 'cls': 'DCDTrainTPL', 41 | 'seed': 2023, 42 | 'epoch_num': 400, 43 | 'lr': 0.0005, 44 | 'num_workers': 0, 45 | 'batch_size': 2048, 46 | 'num_stop_rounds': 10, 47 | 'early_stop_metrics': [('auc','max'), ('doa_all', 'max')], 48 | 'best_epoch_metric': 'doa_all', 49 | }, 50 | evaltpl_cfg_dict={ 51 | 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], 52 | 'InterpretabilityEvalTPL': {'test_only_metrics': []} 53 | } 54 | ) 55 | -------------------------------------------------------------------------------- /examples/single_model/run_deepirt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterExtendsQDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | 'eval_batch_size': 512, 18 | }, 19 | modeltpl_cfg_dict={ 20 | 'cls': 
'DeepIRT', 21 | }, 22 | evaltpl_cfg_dict={ 23 | 'clses': ['PredictionEvalTPL'], 24 | } 25 | ) 26 | -------------------------------------------------------------------------------- /examples/single_model/run_dimkt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'DIMKTDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'DIMKT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_dina_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | 5 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 6 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 7 | 8 | from edustudio.quickstart import run_edustudio 9 | 10 | run_edustudio( 11 | dataset='FrcSub', 12 | cfg_file_name=None, 13 | traintpl_cfg_dict={ 14 | 'cls': 'GeneralTrainTPL', 15 | }, 16 | datatpl_cfg_dict={ 17 | 'cls': 'CDInterExtendsQDataTPL', 18 | }, 19 | modeltpl_cfg_dict={ 20 | 'cls': 'DINA', 21 | }, 22 | evaltpl_cfg_dict={ 23 | 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], 24 | } 25 | ) 26 | -------------------------------------------------------------------------------- /examples/single_model/run_dkt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | 
os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterCptAsExerDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'DKT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_dkt_dsc_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'DKTDSCDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'DKTDSC', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_dkt_plus_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterCptAsExerDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'DKT_plus', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | 
-------------------------------------------------------------------------------- /examples/single_model/run_dktforget_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'DKTForgetDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'DKTForget', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_dkvmn_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterExtendsQDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | 'eval_batch_size': 256, 18 | }, 19 | modeltpl_cfg_dict={ 20 | 'cls': 'DKVMN', 21 | }, 22 | evaltpl_cfg_dict={ 23 | 'clses': ['PredictionEvalTPL'], 24 | } 25 | ) 26 | 27 | 28 | -------------------------------------------------------------------------------- /examples/single_model/run_dtransformer_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import 
run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterCptUnfoldDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | 'batch_size': 16, 18 | 'eval_batch_size': 16 19 | }, 20 | modeltpl_cfg_dict={ 21 | 'cls': 'DTransformer', 22 | }, 23 | evaltpl_cfg_dict={ 24 | 'clses': ['PredictionEvalTPL'], 25 | } 26 | ) 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /examples/single_model/run_ecd_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='PISA_2015_ECD',#ecd only supports PISA dataset 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'GeneralTrainTPL', 14 | 'batch_size': 2048, 15 | 'eval_batch_size': 2048, 16 | }, 17 | datatpl_cfg_dict={ 18 | 'cls': 'ECDDataTPL', 19 | }, 20 | modeltpl_cfg_dict={ 21 | 'cls': 'ECD_IRT',#ECD_IRT,ECD_MIRT,ECD_NCD 22 | }, 23 | evaltpl_cfg_dict={ 24 | 'clses': ['PredictionEvalTPL'], 25 | } 26 | ) 27 | -------------------------------------------------------------------------------- /examples/single_model/run_eernn_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='AAAI_2023', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'EERNNDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'EERNNA', 20 | }, 21 | 
evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | 26 | -------------------------------------------------------------------------------- /examples/single_model/run_ekt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='AAAI_2023', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'EKTDataTPL', 14 | 'M2C_BuildSeqInterFeats': { 15 | 'window_size': 50, 16 | } 17 | }, 18 | traintpl_cfg_dict={ 19 | 'cls': 'GeneralTrainTPL', 20 | 'batch_size': 8, 21 | 'eval_batch_size': 8, 22 | }, 23 | modeltpl_cfg_dict={ 24 | 'cls': 'EKTM', 25 | }, 26 | evaltpl_cfg_dict={ 27 | 'clses': ['PredictionEvalTPL'], 28 | } 29 | ) 30 | -------------------------------------------------------------------------------- /examples/single_model/run_faircd_irt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='SLP_English', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'AdversarialTrainTPL', 14 | 'batch_size': 1024 15 | }, 16 | datatpl_cfg_dict={ 17 | 'cls': 'FAIRDataTPL', 18 | }, 19 | modeltpl_cfg_dict={ 20 | 'cls': 'FairCD_IRT', 21 | }, 22 | evaltpl_cfg_dict={ 23 | 'clses': ['PredictionEvalTPL', 'FairnessEvalTPL'], 24 | } 25 | ) 26 | -------------------------------------------------------------------------------- /examples/single_model/run_faircd_mirt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 
3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='SLP_English', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'AdversarialTrainTPL', 14 | 'batch_size': 1024 15 | }, 16 | datatpl_cfg_dict={ 17 | 'cls': 'FAIRDataTPL', 18 | }, 19 | modeltpl_cfg_dict={ 20 | 'cls': 'FairCD_MIRT', 21 | }, 22 | evaltpl_cfg_dict={ 23 | 'clses': ['PredictionEvalTPL', 'FairnessEvalTPL'], 24 | } 25 | ) 26 | -------------------------------------------------------------------------------- /examples/single_model/run_faircd_ncdm_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='SLP_English', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'AdversarialTrainTPL', 14 | 'batch_size': 1024 15 | }, 16 | datatpl_cfg_dict={ 17 | 'cls': 'FAIRDataTPL', 18 | }, 19 | modeltpl_cfg_dict={ 20 | 'cls': 'FairCD_NCDM', 21 | }, 22 | evaltpl_cfg_dict={ 23 | 'clses': ['PredictionEvalTPL', 'FairnessEvalTPL'], 24 | } 25 | ) 26 | -------------------------------------------------------------------------------- /examples/single_model/run_gkt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'GKTDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 
'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'GKT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_hawkeskt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterCptUnfoldDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'HawkesKT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_hiercdf_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | from edustudio.atom_op.mid2cache.CD import M2C_FilterRecords4CD 9 | 10 | import networkx as nx 11 | import pandas as pd 12 | 13 | class M2C_MyFilterRecords4CD(M2C_FilterRecords4CD): 14 | 15 | @staticmethod 16 | def build_DAG(edges): 17 | graph = nx.DiGraph() 18 | for edge in edges: 19 | graph.add_edge(*edge) 20 | if not nx.is_directed_acyclic_graph(graph): 21 | graph.remove_edge(*edge) 22 | return graph.edges 23 | 24 | def process(self, **kwargs): 25 | selected_items = kwargs['df_exer']['exer_id:token'].to_numpy() 26 | # kwargs['df_cpt_relation'] = kwargs['df_cpt_relation'].drop_duplicates(["cpt_head:token", 
"cpt_tail:token"]) 27 | edges = kwargs['df_cpt_relation'][['cpt_head:token', 'cpt_tail:token']].to_numpy() 28 | edges = list(self.build_DAG(edges)) 29 | kwargs['df_cpt_relation'] = pd.DataFrame({"cpt_head:token": [i[0] for i in edges], "cpt_tail:token": [i[1] for i in edges]}) 30 | selected_items = kwargs['df_cpt_relation'].to_numpy().flatten() 31 | df = kwargs['df'] 32 | df = df[df['exer_id:token'].isin(selected_items)].reset_index(drop=True) 33 | kwargs['df'] = df 34 | 35 | kwargs = super().process(**kwargs) 36 | 37 | selected_items = kwargs['df']['exer_id:token'].unique() 38 | kwargs['df_cpt_relation'] = kwargs['df_cpt_relation'][kwargs['df_cpt_relation']['cpt_head:token'].isin(selected_items)] 39 | kwargs['df_cpt_relation'] = kwargs['df_cpt_relation'][kwargs['df_cpt_relation']['cpt_tail:token'].isin(selected_items)] 40 | kwargs['df_cpt_relation'] = kwargs['df_cpt_relation'].reset_index(drop=True) 41 | return kwargs 42 | 43 | 44 | run_edustudio( 45 | dataset='JunyiExerAsCpt', 46 | cfg_file_name=None, 47 | traintpl_cfg_dict={ 48 | 'cls': 'GeneralTrainTPL', 49 | 'batch_size': 2048, 50 | 'eval_batch_size': 2048, 51 | }, 52 | datatpl_cfg_dict={ 53 | 'cls': 'HierCDFDataTPL', 54 | # 'load_data_from': 'rawdata', 55 | # 'raw2mid_op': 'R2M_JunyiExerAsCpt', 56 | 'mid2cache_op_seq': ['M2C_Label2Int', M2C_MyFilterRecords4CD, 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat'], 57 | 'M2C_ReMapId': { 58 | 'share_id_columns': [{'cpt_seq:token_seq', 'cpt_head:token', 'cpt_tail:token', 'exer_id:token'}], 59 | }, 60 | 'M2C_MyFilterRecords4CD': { 61 | "stu_least_records": 60, 62 | } 63 | }, 64 | modeltpl_cfg_dict={ 65 | 'cls': 'HierCDF', 66 | }, 67 | evaltpl_cfg_dict={ 68 | 'clses': ['PredictionEvalTPL'], 69 | } 70 | ) 71 | -------------------------------------------------------------------------------- /examples/single_model/run_iekt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | 
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterExtendsQDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'IEKT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_irr_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | 5 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 6 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 7 | 8 | from edustudio.quickstart import run_edustudio 9 | 10 | run_edustudio( 11 | dataset='FrcSub', 12 | cfg_file_name=None, 13 | traintpl_cfg_dict={ 14 | 'cls': 'GeneralTrainTPL', 15 | 'lr': 0.0001, 16 | }, 17 | datatpl_cfg_dict={ 18 | 'cls': 'IRRDataTPL', 19 | }, 20 | modeltpl_cfg_dict={ 21 | 'cls': 'IRR', 22 | }, 23 | evaltpl_cfg_dict={ 24 | 'clses': ['PredictionEvalTPL'], 25 | } 26 | ) 27 | -------------------------------------------------------------------------------- /examples/single_model/run_irt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='FrcSub', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'GeneralTrainTPL', 14 | }, 15 | datatpl_cfg_dict={ 16 | 'cls': 'CDInterDataTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'IRT', 20 | }, 21 | 
evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_kancd_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='FrcSub', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'GeneralTrainTPL', 14 | }, 15 | datatpl_cfg_dict={ 16 | 'cls': 'CDInterExtendsQDataTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'KaNCD', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_kqn_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterCptAsExerDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'KQN', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_kscd_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from 
edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='FrcSub', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'GeneralTrainTPL', 14 | }, 15 | datatpl_cfg_dict={ 16 | 'cls': 'CDInterExtendsQDataTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'KSCD', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_lpkt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'LPKTDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'LPKT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_lpkt_s_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'LPKTDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'LPKT_S', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- 
/examples/single_model/run_mf_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='FrcSub', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'GeneralTrainTPL', 14 | }, 15 | datatpl_cfg_dict={ 16 | 'cls': 'CDInterDataTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'MF', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_mgcd_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'GroupCDTrainTPL', 14 | 'early_stop_metrics': [('rmse','min')], 15 | 'best_epoch_metric': 'rmse', 16 | 'batch_size': 512 17 | }, 18 | datatpl_cfg_dict={ 19 | 'cls': 'MGCDDataTPL', 20 | # 'load_data_from': 'rawdata', 21 | # 'raw2mid_op': 'R2M_ASSIST_0910' 22 | }, 23 | modeltpl_cfg_dict={ 24 | 'cls': 'MGCD', 25 | }, 26 | evaltpl_cfg_dict={ 27 | 'clses': ['PredictionEvalTPL'], 28 | 'PredictionEvalTPL': { 29 | 'use_metrics': ['rmse'] 30 | } 31 | } 32 | ) 33 | -------------------------------------------------------------------------------- /examples/single_model/run_mirt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | 
os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='FrcSub', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'GeneralTrainTPL', 14 | }, 15 | datatpl_cfg_dict={ 16 | 'cls': 'CDInterDataTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'MIRT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_ncdm_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='SLP-Math', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'GeneralTrainTPL', 14 | }, 15 | datatpl_cfg_dict={ 16 | 'cls': 'CDInterExtendsQDataTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'NCDM', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_qdkt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'QDKTDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'QDKT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | 
-------------------------------------------------------------------------------- /examples/single_model/run_qikt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterExtendsQDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | 'batch_size': 16, 18 | 'eval_batch_size': 16 19 | }, 20 | modeltpl_cfg_dict={ 21 | 'cls': 'QIKT', 22 | }, 23 | evaltpl_cfg_dict={ 24 | 'clses': ['PredictionEvalTPL'], 25 | } 26 | ) 27 | -------------------------------------------------------------------------------- /examples/single_model/run_rcd_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='FrcSub', 11 | cfg_file_name=None, 12 | traintpl_cfg_dict={ 13 | 'cls': 'GeneralTrainTPL', 14 | }, 15 | datatpl_cfg_dict={ 16 | 'cls': 'RCDDataTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'RCD', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_rkt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | 
run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'RKTDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'RKT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_saint_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterCptUnfoldDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'SAINT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | 26 | -------------------------------------------------------------------------------- /examples/single_model/run_saint_plus_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterCptUnfoldDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'SAINT_plus', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | 26 | -------------------------------------------------------------------------------- 
/examples/single_model/run_sakt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterCptAsExerDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'SAKT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_simplekt_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 'KTInterCptUnfoldDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'SimpleKT', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /examples/single_model/run_skvmn_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | from edustudio.quickstart import run_edustudio 8 | 9 | run_edustudio( 10 | dataset='ASSIST_0910', 11 | cfg_file_name=None, 12 | datatpl_cfg_dict={ 13 | 'cls': 
'KTInterCptAsExerDataTPL', 14 | }, 15 | traintpl_cfg_dict={ 16 | 'cls': 'GeneralTrainTPL', 17 | }, 18 | modeltpl_cfg_dict={ 19 | 'cls': 'SKVMN', 20 | }, 21 | evaltpl_cfg_dict={ 22 | 'clses': ['PredictionEvalTPL'], 23 | } 24 | ) 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.10.0 2 | numpy>=1.17.2 3 | scipy>=1.6.0 4 | pandas>=1.0.5 5 | tqdm>=4.48.2 6 | scikit_learn>=0.23.2 7 | pyyaml>=5.1.0 8 | tensorboard>=2.5.0 9 | requests>=2.27.1 10 | pytz>=2022.1 11 | matplotlib>=3.5.1 12 | deepdiff>=6.3.1 13 | networkx>=2.8 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import os 6 | 7 | from setuptools import setup, find_packages 8 | 9 | install_requires = [ 10 | "torch>=1.10.0", 11 | "numpy>=1.17.2", 12 | "scipy>=1.6.0", 13 | "pandas>=1.0.5", 14 | "tqdm>=4.48.2", 15 | "scikit_learn>=0.23.2", 16 | "pyyaml>=5.1.0", 17 | "tensorboard>=2.5.0", 18 | "requests>=2.27.1", 19 | "pytz>=2022.1", 20 | "matplotlib>=3.5.1", 21 | "deepdiff>=6.3.1", 22 | "networkx>=2.8" 23 | ] 24 | 25 | setup_requires = [] 26 | 27 | extras_require = {} 28 | 29 | classifiers = ["License :: OSI Approved :: MIT License"] 30 | 31 | long_description = ( 32 | "EduStudio is a Unified and Templatized Framework " 33 | "for Student Assessment Models including " 34 | "Cognitive Diagnosis(CD) and Knowledge Tracing(KT) based on Pytorch." 35 | ) 36 | 37 | # Readthedocs requires Sphinx extensions to be specified as part of 38 | # install_requires in order to build properly. 
39 | on_rtd = os.environ.get("READTHEDOCS", None) == "True" 40 | if on_rtd: 41 | install_requires.extend(setup_requires) 42 | 43 | setup( 44 | name="edustudio", 45 | version="v1.1.7", 46 | description="a Unified and Templatized Framework for Student Assessment Models", 47 | long_description=long_description, 48 | python_requires='>=3.8', 49 | long_description_content_type="text/markdown", 50 | url="https://github.com/HFUT-LEC/EduStudio", 51 | author="HFUT-LEC", 52 | author_email="lmcRS.hfut@gmail.com", 53 | packages=[package for package in find_packages() if package.startswith("edustudio")], 54 | include_package_data=True, 55 | install_requires=install_requires, 56 | setup_requires=setup_requires, 57 | extras_require=extras_require, 58 | zip_safe=False, 59 | classifiers=classifiers, 60 | entry_points={ 61 | "console_scripts": [ 62 | "edustudio = edustudio.quickstart.atom_cmds:entrypoint", 63 | ], 64 | } 65 | ) 66 | -------------------------------------------------------------------------------- /tests/test_run.py: -------------------------------------------------------------------------------- 1 | from edustudio.quickstart import run_edustudio 2 | 3 | class TestRun: 4 | def test_cd(self): 5 | print("---------- test cd ------------") 6 | run_edustudio( 7 | dataset='FrcSub', 8 | cfg_file_name=None, 9 | traintpl_cfg_dict={ 10 | 'cls': 'GeneralTrainTPL', 11 | 'epoch_num': 2, 12 | 'device': 'cpu' 13 | }, 14 | datatpl_cfg_dict={ 15 | 'cls': 'CDInterExtendsQDataTPL' 16 | }, 17 | modeltpl_cfg_dict={ 18 | 'cls': 'KaNCD', 19 | }, 20 | evaltpl_cfg_dict={ 21 | 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'], 22 | } 23 | ) 24 | --------------------------------------------------------------------------------