├── .github
└── workflows
│ ├── python-publish-test.yml
│ └── python-publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── assets
├── framework.png
└── logo.png
├── docs
├── .readthedocs.yaml
├── Makefile
├── make.bat
├── requirements.txt
└── source
│ ├── _static
│ └── favicon.ico
│ ├── assets
│ ├── dataflow.jpg
│ ├── framework.png
│ └── logo.png
│ ├── conf.py
│ ├── developer_guide
│ ├── customize_atomic_op.md
│ ├── customize_datatpl.md
│ ├── customize_evaltpl.md
│ ├── customize_modeltpl.md
│ └── customize_traintpl.md
│ ├── features
│ ├── atomic_files.md
│ ├── atomic_operations.md
│ ├── dataset_folder_protocol.md
│ ├── global_cfg_obj.md
│ ├── inheritable_config.md
│ └── standard_datamodule.md
│ ├── get_started
│ ├── install.md
│ └── quick_start.md
│ ├── index.rst
│ └── user_guide
│ ├── atom_op.md
│ ├── datasets.md
│ ├── models.md
│ ├── reference_table.md
│ ├── usage.rst
│ └── usage
│ ├── aht.md
│ ├── atomic_cmds.md
│ ├── run_edustudio.md
│ └── use_case_of_config.md
├── edustudio
├── __init__.py
├── assets
│ └── datasets.yaml
├── atom_op
│ ├── __init__.py
│ ├── mid2cache
│ │ ├── CD
│ │ │ ├── __init__.py
│ │ │ ├── data_split4cd.py
│ │ │ └── filter_records4cd.py
│ │ ├── KT
│ │ │ ├── __init__.py
│ │ │ ├── build_seq_inter_feats.py
│ │ │ ├── cpt_as_exer.py
│ │ │ ├── data_split4kt.py
│ │ │ ├── gen_cpt_seq.py
│ │ │ └── gen_unfold_cpt_seq.py
│ │ ├── __init__.py
│ │ ├── common
│ │ │ ├── __init__.py
│ │ │ ├── base_mid2cache.py
│ │ │ ├── build_cpt_relation.py
│ │ │ ├── build_missing_Q.py
│ │ │ ├── fill_missing_Q.py
│ │ │ ├── filtering_records_by_attr.py
│ │ │ ├── gen_q_mat.py
│ │ │ ├── label2int.py
│ │ │ ├── merge_divided_splits.py
│ │ │ └── remapid.py
│ │ └── single
│ │ │ ├── M2C_CDGK_OP.py
│ │ │ ├── M2C_CL4KT_OP.py
│ │ │ ├── M2C_CNCDQ_OP.py
│ │ │ ├── M2C_DIMKT_OP.py
│ │ │ ├── M2C_DKTDSC_OP.py
│ │ │ ├── M2C_DKTForget_OP.py
│ │ │ ├── M2C_EERNN_OP.py
│ │ │ ├── M2C_LPKT_OP.py
│ │ │ ├── M2C_MGCD_OP.py
│ │ │ ├── M2C_QDKT_OP.py
│ │ │ ├── M2C_RCD_OP.py
│ │ │ └── __init__.py
│ └── raw2mid
│ │ ├── __init__.py
│ │ ├── aaai_2023.py
│ │ ├── algebra2005.py
│ │ ├── assist_0910.py
│ │ ├── assist_1213.py
│ │ ├── assist_1516.py
│ │ ├── assist_2017.py
│ │ ├── bridge2006.py
│ │ ├── ednet_kt1.py
│ │ ├── frcsub.py
│ │ ├── junyi_area_topic_as_cpt.py
│ │ ├── junyi_exer_as_cpt.py
│ │ ├── math1.py
│ │ ├── math2.py
│ │ ├── nips12.py
│ │ ├── nips34.py
│ │ ├── raw2mid.py
│ │ ├── simulated5.py
│ │ ├── slp_english.py
│ │ └── slp_math.py
├── datatpl
│ ├── CD
│ │ ├── CDGKDataTPL.py
│ │ ├── CDInterDataTPL.py
│ │ ├── CDInterExtendsQDataTPL.py
│ │ ├── CNCDFDataTPL.py
│ │ ├── CNCDQDataTPL.py
│ │ ├── DCDDataTPL.py
│ │ ├── ECDDataTPL.py
│ │ ├── FAIRDataTPL.py
│ │ ├── HierCDFDataTPL.py
│ │ ├── IRRDataTPL.py
│ │ ├── MGCDDataTPL.py
│ │ ├── RCDDataTPL.py
│ │ └── __init__.py
│ ├── KT
│ │ ├── CL4KTDataTPL.py
│ │ ├── DIMKTDataTPL.py
│ │ ├── DKTDSCDataTPL.py
│ │ ├── DKTForgetDataTPL.py
│ │ ├── EERNNDataTPL.py
│ │ ├── EKTDataTPL.py
│ │ ├── GKTDataTPL.py
│ │ ├── KTInterCptAsExerDataTPL.py
│ │ ├── KTInterCptUnfoldDataTPL.py
│ │ ├── KTInterDataTPL.py
│ │ ├── KTInterExtendsQDataTPL.py
│ │ ├── LPKTDataTPL.py
│ │ ├── QDKTDataTPL.py
│ │ ├── RKTDataTPL.py
│ │ └── __init__.py
│ ├── __init__.py
│ ├── common
│ │ ├── __init__.py
│ │ ├── base_datatpl.py
│ │ ├── edu_datatpl.py
│ │ ├── general_datatpl.py
│ │ └── proxy_datatpl.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── common.py
│ │ ├── pad_seq_util.py
│ │ └── spliter_util.py
├── evaltpl
│ ├── __init__.py
│ ├── base_evaltpl.py
│ ├── fairness_evaltpl.py
│ ├── identifiability_evaltpl.py
│ ├── interpretability_evaltpl.py
│ └── prediction_evaltpl.py
├── model
│ ├── CD
│ │ ├── __init__.py
│ │ ├── cdgk.py
│ │ ├── cdmfkc.py
│ │ ├── cncd_f.py
│ │ ├── cncd_q.py
│ │ ├── dcd.py
│ │ ├── dina.py
│ │ ├── ecd.py
│ │ ├── faircd.py
│ │ ├── hier_cdf.py
│ │ ├── irr.py
│ │ ├── irt.py
│ │ ├── kancd.py
│ │ ├── kscd.py
│ │ ├── mf.py
│ │ ├── mgcd.py
│ │ ├── mirt.py
│ │ ├── ncdm.py
│ │ └── rcd.py
│ ├── KT
│ │ ├── GKT
│ │ │ ├── __init__.py
│ │ │ ├── building_blocks.py
│ │ │ ├── gkt.py
│ │ │ └── losses.py
│ │ ├── __init__.py
│ │ ├── akt.py
│ │ ├── atkt.py
│ │ ├── ckt.py
│ │ ├── cl4kt.py
│ │ ├── ct_ncm.py
│ │ ├── deep_irt.py
│ │ ├── dimkt.py
│ │ ├── dkt.py
│ │ ├── dkt_dsc.py
│ │ ├── dkt_forget.py
│ │ ├── dkt_plus.py
│ │ ├── dkvmn.py
│ │ ├── dtransformer.py
│ │ ├── eernn.py
│ │ ├── ekt.py
│ │ ├── hawkeskt.py
│ │ ├── iekt.py
│ │ ├── kqn.py
│ │ ├── lpkt.py
│ │ ├── lpkt_s.py
│ │ ├── qdkt.py
│ │ ├── qikt.py
│ │ ├── rkt.py
│ │ ├── saint.py
│ │ ├── saint_plus.py
│ │ ├── sakt.py
│ │ ├── simplekt.py
│ │ └── skvmn.py
│ ├── __init__.py
│ ├── basemodel.py
│ ├── gd_basemodel.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── common.py
│ │ └── components.py
├── quickstart
│ ├── __init__.py
│ ├── atom_cmds.py
│ ├── init_all.py
│ ├── parse_cfg.py
│ └── quickstart.py
├── settings.py
├── traintpl
│ ├── __init__.py
│ ├── adversarial_traintpl.py
│ ├── atkt_traintpl.py
│ ├── base_traintpl.py
│ ├── dcd_traintpl.py
│ ├── gd_traintpl.py
│ ├── general_traintpl.py
│ └── group_cd_traintpl.py
└── utils
│ ├── __init__.py
│ ├── callback
│ ├── __init__.py
│ ├── callBackList.py
│ ├── callbacks
│ │ ├── __init__.py
│ │ ├── baseLogger.py
│ │ ├── callback.py
│ │ ├── earlyStopping.py
│ │ ├── epochPredict.py
│ │ ├── history.py
│ │ ├── modelCheckPoint.py
│ │ └── tensorboardCallBack.py
│ └── modeState.py
│ └── common
│ ├── __init__.py
│ ├── commonUtil.py
│ ├── configUtil.py
│ └── loggerUtil.py
├── examples
├── 1.run_cd_demo.py
├── 2.run_kt_demo.py
├── 3.run_with_customized_tpl.py
├── 4.run_cmd_demo.py
├── 5.run_with_hyperopt.py
├── 6.run_with_ray.tune.py
├── 7.run_toy_demo.py
└── single_model
│ ├── run_akt_demo.py
│ ├── run_atkt_demo.py
│ ├── run_cdgk_demo.py
│ ├── run_cdmfkc_demo.py
│ ├── run_ckt_demo.py
│ ├── run_cl4kt_demo.py
│ ├── run_cncd_f_demo.py
│ ├── run_cncdq_demo.py
│ ├── run_ctncm_demo.py
│ ├── run_dcd_demo.py
│ ├── run_deepirt_demo.py
│ ├── run_dimkt_demo.py
│ ├── run_dina_demo.py
│ ├── run_dkt_demo.py
│ ├── run_dkt_dsc_demo.py
│ ├── run_dkt_plus_demo.py
│ ├── run_dktforget_demo.py
│ ├── run_dkvmn_demo.py
│ ├── run_dtransformer_demo.py
│ ├── run_ecd_demo.py
│ ├── run_eernn_demo.py
│ ├── run_ekt_demo.py
│ ├── run_faircd_irt_demo.py
│ ├── run_faircd_mirt_demo.py
│ ├── run_faircd_ncdm_demo.py
│ ├── run_gkt_demo.py
│ ├── run_hawkeskt_demo.py
│ ├── run_hiercdf_demo.py
│ ├── run_iekt_demo.py
│ ├── run_irr_demo.py
│ ├── run_irt_demo.py
│ ├── run_kancd_demo.py
│ ├── run_kqn_demo.py
│ ├── run_kscd_demo.py
│ ├── run_lpkt_demo.py
│ ├── run_lpkt_s_demo.py
│ ├── run_mf_demo.py
│ ├── run_mgcd_demo.py
│ ├── run_mirt_demo.py
│ ├── run_ncdm_demo.py
│ ├── run_qdkt_demo.py
│ ├── run_qikt_demo.py
│ ├── run_rcd_demo.py
│ ├── run_rkt_demo.py
│ ├── run_saint_demo.py
│ ├── run_saint_plus_demo.py
│ ├── run_sakt_demo.py
│ ├── run_simplekt_demo.py
│ └── run_skvmn_demo.py
├── requirements.txt
├── setup.py
└── tests
└── test_run.py
/.github/workflows/python-publish-test.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package TEST
2 |
3 | on:
4 | workflow_dispatch:
5 |
6 | permissions:
7 | contents: read
8 |
9 | jobs:
10 | deploy:
11 |
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - uses: actions/checkout@v3
16 | - name: Set up Python
17 | uses: actions/setup-python@v3
18 | with:
19 | python-version: '3.9'
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install build
24 | pip install pytest
25 | pip install torch --index-url https://download.pytorch.org/whl/cpu
26 | pip install -e . --verbose
27 | pip install -r requirements.txt
28 | - name: Test
29 | run: |
30 | cd tests && pytest && cd ..
31 | - name: Build package
32 | run: python -m build
33 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package
2 |
3 | on:
4 | release:
5 | types: [published]
6 | workflow_dispatch:
7 |
8 | permissions:
9 | contents: read
10 |
11 | jobs:
12 | deploy:
13 |
14 | runs-on: ubuntu-latest
15 |
16 | steps:
17 | - uses: actions/checkout@v3
18 | - name: Set up Python
19 | uses: actions/setup-python@v3
20 | with:
21 | python-version: '3.9'
22 | - name: Install dependencies
23 | run: |
24 | python -m pip install --upgrade pip
25 | pip install build
26 | pip install pytest
27 | pip install torch --index-url https://download.pytorch.org/whl/cpu
28 | pip install -e . --verbose
29 | pip install -r requirements.txt
30 | - name: Test
31 | run: |
32 | cd tests && pytest && cd ..
33 | - name: Build package
34 | run: python -m build
35 | - name: Publish package
36 | uses: pypa/gh-action-pypi-publish@v1.13.0
37 | with:
38 | user: __token__
39 | password: ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | .ipynb_checkpoints/
3 | __pycache__/
4 | /data/
5 | /temp/
6 | /archive/
7 | /test/
8 | /conf/
9 | /.vscode/
10 | /examples/**/archive/
11 | /examples/**/temp/
12 | /examples/**/data/
13 | /build/
14 | /dist/
15 | *.egg-info/
16 | docs/build/
17 | /.pytest_cache/
18 | /tests/archive/
19 | /tests/conf/
20 | /tests/temp/
21 | /tests/data/
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 HFUT-LEC
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include edustudio/assets/ *.yaml
2 |
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/assets/framework.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/assets/logo.png
--------------------------------------------------------------------------------
/docs/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the version of Python and other tools you might need
9 | build:
10 | os: ubuntu-22.04
11 | tools:
12 | python: "3.11"
13 |
14 | # Build documentation in the docs/ directory with Sphinx
15 | sphinx:
16 | configuration: docs/source/conf.py
17 |
18 | # We recommend specifying your dependencies to enable reproducible builds:
19 | # https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
20 | python:
21 | install:
22 | - requirements: docs/requirements.txt
23 |
24 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==6.2.1
2 | sphinx_rtd_theme==1.2.2
3 | myst-parser==2.0.0
4 | esbonio==0.16.1
5 | sphinx_copybutton
6 | sphinxcontrib-napoleon
7 |
--------------------------------------------------------------------------------
/docs/source/_static/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/docs/source/_static/favicon.ico
--------------------------------------------------------------------------------
/docs/source/assets/dataflow.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/docs/source/assets/dataflow.jpg
--------------------------------------------------------------------------------
/docs/source/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/docs/source/assets/framework.png
--------------------------------------------------------------------------------
/docs/source/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/docs/source/assets/logo.png
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 |
9 | project = 'EduStudio'
10 | copyright = '2023, HFUT-LEC'
11 | author = 'HFUT-LEC'
12 | release = 'v1.1.7'
13 |
14 | import sphinx_rtd_theme
15 | import os
16 | import sys
17 |
18 | sys.path.insert(0, os.path.abspath("../.."))
19 |
20 | # -- General configuration ---------------------------------------------------
21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
22 |
23 | extensions = [
24 | 'myst_parser',
25 | "sphinx.ext.autodoc",
26 | "sphinx.ext.napoleon",
27 | "sphinx.ext.viewcode",
28 | "sphinx_copybutton",
29 | ]
30 | source_suffix = ['.rst', '.md']
31 | templates_path = ['_templates']
32 | exclude_patterns = []
33 |
34 |
35 | # -- Options for HTML output -------------------------------------------------
36 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
37 |
38 | html_theme = 'sphinx_rtd_theme'
39 | # html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
40 | html_static_path = ['_static']
41 | html_favicon = '_static/favicon.ico'
42 |
--------------------------------------------------------------------------------
/docs/source/developer_guide/customize_evaltpl.md:
--------------------------------------------------------------------------------
1 | Customize Evaluation Template
2 | ======================
3 | Here, we present how to develop a new Evaluation Template, and apply it into EduStudio.
4 | EduStudio provides the EvalTPL Protocol in ``edustudio.evaltpl.base_evaltpl.BaseEvalTPL`` (``BaseEvalTPL``).
5 |
6 | EvalTPL Protocol
7 | ----------------------
8 |
9 | ### BaseEvalTPL
10 | The protocols in ``BaseEvalTPL`` are listed as follows.
11 |
12 | | name | description | type | note |
13 | | ----------------- | ------------------------- | ------------------ | ----------------------- |
14 | | default_cfg | the default configuration | class variable | |
15 | | eval | calculate metric results | function interface | implemented by subclass |
16 | | _check_params | check parameters | function interface | implemented by subclass |
17 | | set_callback_list | set callback list | function interface | implemented by subclass |
18 | | set_dataloaders | set dataloaders | function interface | implemented by subclass |
19 | | add_extra_data | add extra data | function interface | implemented by subclass |
20 |
21 |
22 |
23 | EvalTPLs
24 | ----------------------
25 |
26 | EduStudio provides ``PredictionEvalTPL`` and ``InterpretabilityEvalTPL``, which inherit from ``BaseEvalTPL``.
27 |
28 | ### PredictionEvalTPL
29 | This EvalTPL is for the model evaluation using binary classification metrics.
30 | The protocols in ``PredictionEvalTPL`` are listed as follows.
31 |
32 |
33 | ### InterpretabilityEvalTPL
34 | This EvalTPL is for the model evaluation for interpretability. It uses states of students and Q matrix for ``eval``, which are domain-specific in student assessment.
35 |
36 | ## Develop a New EvalTPL in EduStudio
37 |
38 | If you want to develop a new EvalTPL in EduStudio, you can inherit from ``BaseEvalTPL`` and override the ``eval`` method.
39 |
40 | ### Example
41 |
42 | ```python
43 | from .base_evaltpl import BaseEvalTPL
44 | from sklearn.metrics import accuracy_score, coverage_error
45 |
46 | class NewEvalTPL(BaseEvalTPL):
47 | default_cfg = {
48 | 'use_metrics': ['acc', 'coverage_error']
49 | }
50 |
51 | def __init__(self, cfg):
52 | super().__init__(cfg)
53 |
54 | def eval(self, y_pd, y_gt, **kwargs):
55 | if not isinstance(y_pd, np.ndarray): y_pd = tensor2npy(y_pd)
56 | if not isinstance(y_gt, np.ndarray): y_gt = tensor2npy(y_gt)
57 | metric_result = {}
58 | ignore_metrics = kwargs.get('ignore_metrics', {})
59 | for metric_name in self.evalfmt_cfg[self.__class__.__name__]['use_metrics']:
60 | if metric_name not in ignore_metrics:
61 | metric_result[metric_name] = self._get_metrics(metric_name)(y_gt, y_pd)
62 | return metric_result
63 |
64 | def _get_metrics(self, metric):
65 | if metric == "acc":
66 | return lambda y_gt, y_pd: accuracy_score(y_gt, np.where(y_pd >= 0.5, 1, 0))
67 | elif metric == 'coverage_error':
68 | return lambda y_gt, y_pd: coverage_error(y_gt, y_pd)
69 | else:
70 | raise NotImplementedError
71 | ```
72 |
--------------------------------------------------------------------------------
/docs/source/features/atomic_files.md:
--------------------------------------------------------------------------------
1 | # Middle Data Format Protocol
2 |
3 | In `EduStudio`, we adopt a flexible CSV (Comma-Separated Values) file format following [Recbole](https://recbole.io/atomic_files.html). The flexible CSV format is defined in `middata` stage of dataset (see dataset stage protocol for details).
4 |
5 | The Middle Data Format Protocol including two parts: `Columns name Format` and `Filename Format`.
6 |
7 | ## Columns Name Format
8 |
9 | | feat_type | Explanations | Examples |
10 | | ------------- | --------------------------- | --------------------------------- |
11 | | **token** | single discrete feature | exer_id, stu_id |
12 | | **token_seq** | discrete features sequence | knowledge concept seq of exercise |
13 | | **float** | single continuous feature | label, start_timestamp |
14 | | **float_seq** | continuous feature sequence | word2vec embedding of exercise |
15 |
16 |
17 |
18 | ## Filename format
19 |
20 | So far, there are five atomic files in edustudio.
21 |
22 | **Note**: Users could also load other types of data beyond the atomic files below. `{dt}` is the dataset name.
23 |
24 | | filename format | description |
25 | | -------------------- | ---------------------------------------------------- |
26 | | {dt}.inter.csv | Student-Exercise Interaction data |
27 | | {dt}.train.inter.csv | Student-Exercise Interaction data for training set |
28 | | {dt}.valid.inter.csv | Student-Exercise Interaction data for validation set |
29 | | {dt}.test.inter.csv  | Student-Exercise Interaction data for test set       |
30 | | {dt}.stu.csv | Features of students |
31 | | {dt}.exer.csv | Features of exercises |
32 |
33 |
34 |
35 | ## Example
36 |
37 | ### example_dt.inter.csv
38 |
39 | | stu_id:token | exer_id:token | label:float |
40 | | ------------ | ------------- | ----------- |
41 | | 0 | 1 | 0.0 |
42 | | 1 | 0 | 1.0 |
43 |
44 | ### example_dt.stu.csv
45 |
46 | | stu_id:token | gender:token | occupation:token |
47 | | ------------ | ------------ | ---------------- |
48 | | 0 | 1 | 11 |
49 | | 1 | 0 | 7 |
50 |
51 | ### example_dt.exer.csv
52 |
53 | | exer_id:token | cpt_seq:token_seq | w2v_emb:float_seq |
54 | | ------------- | ----------------- | ---------------------- |
55 | | 0 | [0, 1] | [0.121, 0.123, 0.761] |
56 | | 1 | [1, 2, 3] | [0.229, -0.113, 0.138] |
57 |
--------------------------------------------------------------------------------
/docs/source/features/atomic_operations.md:
--------------------------------------------------------------------------------
1 | # Atomic Data Operation Protocol
2 |
3 | In `Edustudio`, we view the dataset from three stages: `rawdata`, `middata`, `cachedata`.
4 |
5 | We treat the whole data processing as multiple atomic operations called atomic operation sequence.
6 | The first atomic operation, inheriting the protocol class `BaseRaw2Mid`, is the process from raw data to middle data.
7 | The following atomic operations, inheriting the protocol class `BaseMid2Cache`, construct the process from middle data to cache data.
8 |
9 |
10 | ## Partial Atomic Operation Table
11 |
12 | In the following, we give a table to display some existing atomic operations. For more detailed Atomic Operation Table, please see the `user_guide/Atomic Data Operation List`
13 |
14 | ### Raw2Mid
15 |
16 | For the conversion from rawdata to middata, we implement a specific atomic data operation prefixed with `R2M` for each dataset.
17 |
18 | | name            | Corresponding dataset                                        |
19 | | --------------- | ------------------------------------------------------------ |
20 | | R2M_ASSIST_0910 | ASSISTment 2009-2010 |
21 | | R2M_FrcSub | Frcsub |
22 | | R2M_ASSIST_1213 | ASSISTment 2012-2013 |
23 | | R2M_Math1 | Math1 |
24 | | R2M_Math2 | Math2 |
25 | | R2M_AAAI_2023 | AAAI 2023 Global Knowledge Tracing Challenge |
26 | | R2M_Algebra_0506 | Algebra 2005-2006 |
27 | | R2M_ASSIST_1516 | ASSISTment 2015-2016 |
28 |
29 | ### Mid2Cache
30 |
31 | #### common
32 |
33 | | name | description |
34 | | ---------------------- | --------------------------------------------- |
35 | | M2C_Label2Int | convert label column into discrete values |
36 | | M2C_MergeDividedSplits | merge train/valid/test set into one dataframe |
37 | | M2C_ReMapId | ReMap Column ID |
38 | | M2C_GenQMat | Generate Q-matrix |
39 |
40 | #### CD
41 |
42 | | name | description |
43 | | ---------------------- | ------------------------------------------------------------ |
44 | | M2C_RandomDataSplit4CD | Split datasets Randomly for CD |
45 | | M2C_FilterRecords4CD | Filter students or exercises whose number of interaction records is less than a threshold |
46 |
47 | #### KT
48 |
49 | | name | description |
50 | | ---------------------- | ------------------------------------------- |
51 | | M2C_BuildSeqInterFeats | Build Sequential Features and Split dataset |
52 | | M2C_KCAsExer | Treat knowledge concept as exercise |
53 | | M2C_GenKCSeq | Generate knowledge concept seq |
54 | | M2C_GenUnFoldKCSeq | Unfold knowledge concepts |
55 |
56 |
--------------------------------------------------------------------------------
/docs/source/features/global_cfg_obj.md:
--------------------------------------------------------------------------------
1 | # Global Configuration Object
2 |
3 | In `EduStudio`, there are five categories of configuration, they are unified in one global configuration object.
4 |
5 | ## Five Config Objects
6 |
7 | The description of five config objects is illustrated in Table below.
8 |
9 | | name | description |
10 | | ------------ | ---------------------------------- |
11 | | datatpl_cfg | configuration of data template |
12 | | modeltpl_cfg | configuration of model template |
13 | | traintpl_cfg | configuration of training template |
14 | | evaltpl_cfg | configuration of evaluate template |
15 | | frame_cfg | configuration of framework itself |
16 |
17 |
18 |
19 | ## Four Configuration Portals
20 |
21 | There are four configuration portals:
22 |
23 | - default_cfg: inheritable python class variable
24 | - configuration file
25 | - parameter dictionary
26 | - command line
27 |
28 |
--------------------------------------------------------------------------------
/docs/source/features/inheritable_config.md:
--------------------------------------------------------------------------------
1 | # Inheritable Default Configuration
2 |
3 | The management of default configuration in EduStudio is implemented by a class variable, i.e. a dictionary object called ``default_cfg``.
4 |
5 | Templates usually introduce new features through inheritance, and these new features may require corresponding configurations, so the default configuration we provide is inheritable.
6 |
7 | ## Example
8 |
9 | The inheritance example of data template is illustrated as follows. We present an example in the data preparation procedure. There are three data template classes (DataTPLs) that inherit from each other: BaseDataTPL, GeneralDataTPL, and EduDataTPL. If users specify current DataTPL is EduDataTPL, the eventual default\_config of data preparation procedure is a merger of default\_cfg of three templates. When a configuration conflict is encountered, the default\_config of subclass template takes precedence over that of parent class templates. As a result, other configuration portals (i.e, configuration file, parameter dictionary, and command line) can only specify parameters that are confined within the default configuration. The advantage of the inheritable design is that it facilitates the reader to locate the numerous hyperparameters.
10 |
11 | ```python
12 | class BaseDataTPL(Dataset):
13 | default_cfg = {'seed': 2023}
14 |
15 |
16 | class GeneralDataTPL(BaseDataTPL):
17 | default_cfg = {
18 | 'seperator': ',',
19 | 'n_folds': 1,
20 | 'is_dataset_divided': False,
21 | 'is_save_cache': False,
22 | 'cache_id': 'cache_default',
23 | 'load_data_from': 'middata', # ['rawdata', 'middata', 'cachedata']
24 | 'inter_exclude_feat_names': (),
25 | 'raw2mid_op': None,
26 | 'mid2cache_op_seq': []
27 | }
28 |
29 |
30 | class EduDataTPL(GeneralDataTPL):
31 | default_cfg = {
32 | 'exer_exclude_feat_names': (),
33 | 'stu_exclude_feat_names': (),
34 | }
35 | ```
36 |
37 | If the currently specified data template is `EduDataTPL`, then the framework will get the final `default_cfg` through API `get_default_cfg`, which would be:
38 |
39 | ```python
40 | default_cfg = {
41 | 'exer_exclude_feat_names': (),
42 | 'stu_exclude_feat_names': (),
43 | 'seperator': ',',
44 | 'n_folds': 1,
45 | 'is_dataset_divided': False,
46 | 'is_save_cache': False,
47 | 'cache_id': 'cache_default',
48 | 'load_data_from': 'middata', # ['rawdata', 'middata', 'cachedata']
49 | 'inter_exclude_feat_names': (),
50 | 'raw2mid_op': None,
51 | 'mid2cache_op_seq': [],
52 | 'seed': 2023
53 | }
54 | ```
55 |
56 | The final `default_cfg` follows two rules:
57 |
58 | - The subclass would incorporate the `default_cfg` of all parent classes.
59 | - When a conflict happened for the same key, the subclass would dominate the priority.
60 |
--------------------------------------------------------------------------------
/docs/source/features/standard_datamodule.md:
--------------------------------------------------------------------------------
1 | # Standardized Data Module
2 |
3 | For data module, we provide a standardized design with three protocols (see following sections for details):
4 | - Data Status Protocol
5 | - Middle Data Format Protocol
6 | - Atomic Operation Protocol
7 |
8 | 
9 |
10 | The first step of Data Templates is to load the raw data from the hard disk. Then, a series of processing steps are performed to obtain model-friendly data objects. Finally, these data objects are passed on to other modules.
11 | We simplify the data preparation into three stages:
12 |
13 | - Data loading: Loading necessary data from the hard disk.
14 | - Data processing: Convert the raw data into model-friendly data objects by a range of data processing operations.
15 | - Data delivery: Deliver model-friendly data objects to the training, model, and evaluation templates.
16 |
--------------------------------------------------------------------------------
/docs/source/get_started/install.md:
--------------------------------------------------------------------------------
1 | # Install EduStudio
2 |
3 | EduStudio can be installed from ``pip`` and source files.
4 |
5 | ## System requirements
6 |
7 | EduStudio is compatible with the following operating systems:
8 |
9 | - Linux
10 |
11 | - Windows 10
12 |
13 | - macOS X
14 |
15 | Python 3.8 (or later), torch 1.10.0 (or later) are required to install our library.
16 |
17 | ## Install from pip
18 |
19 | To install EduStudio from pip, only the following command is needed:
20 |
21 | ```bash
22 | pip install -U edustudio
23 | ```
24 |
25 | ## Install from source
26 |
27 | Download the source files from GitHub.
28 |
29 | ```bash
30 | git clone https://github.com/HFUT-LEC/EduStudio.git && cd EduStudio
31 | ```
32 |
33 | Run the following command to install:
34 |
35 | ```bash
36 | pip install -e . --verbose
37 | ```
38 |
39 |
--------------------------------------------------------------------------------
/docs/source/get_started/quick_start.md:
--------------------------------------------------------------------------------
1 | # Quick Start
2 |
3 | Here is a quick-start example for using EduStudio.
4 |
5 | ## Create a python file to run
6 |
7 | create a python file (e.g., *run.py*) anywhere, the content is as follows:
8 |
9 | ```python
10 | from edustudio.quickstart import run_edustudio
11 |
12 | run_edustudio(
13 | dataset='FrcSub',
14 | cfg_file_name=None,
15 | traintpl_cfg_dict={
16 | 'cls': 'GeneralTrainTPL',
17 | },
18 | datatpl_cfg_dict={
19 | 'cls': 'CDInterExtendsQDataTPL'
20 | },
21 | modeltpl_cfg_dict={
22 | 'cls': 'KaNCD',
23 | },
24 | evaltpl_cfg_dict={
25 | 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'],
26 | }
27 | )
28 | ```
29 |
30 | Then run the following command:
31 |
32 | ```bash
33 | python run.py
34 | ```
35 |
36 | To find out which templates are used for a model, we can find in the [Reference Table](https://edustudio.readthedocs.io/en/latest/user_guide/reference_table.html)
37 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. EduStudio documentation master file.
2 | .. title:: EduStudio v1.1.7
3 | .. image:: assets/logo.png
4 |
5 | =========================================================
6 |
7 | `HomePage `_ | `Docs `_ | `GitHub `_ | `Paper `_
8 |
9 | Introduction
10 | -------------------------
11 | EduStudio is a unified library for student assessment models, including Cognitive Diagnosis (CD) and Knowledge Tracing (KT), based on PyTorch.
12 |
13 | EduStudio first decomposes the general algorithmic workflow into six steps: `configuration reading`, `data preparation`, `model implementation`, `training control`, `model evaluation`, and `log storage`. Subsequently, to enhance the `reusability` and `scalability` of each step, we extract the commonalities of each algorithm at each step into individual templates for templatization.
14 |
15 | - Configuration Reading (Step 1) aims to collect, categorize and deliver configurations from different configuration portals.
16 | - Data Preparation (Step 2) aims to convert raw data from the hard disk into model-friendly data objects.
17 | - Model Implementation (Step 3) refers to the process of implementing the structure of each model and facilitating the reuse of model components.
18 | - Training Control (Step 4) focuses primarily on the training methods of various models.
19 | - Model Evaluation (Step 5) primarily focuses on the implementation of various evaluation metrics.
20 | - Log Storage (Step 6) aims to implement the storage specification when storing generated data.
21 |
22 | The modularization establishes clear boundaries between various programs in the algorithm pipeline, facilitating the introduction of new content to individual modules and enhancing scalability.
23 |
24 | The overall structure is illustrated as follows:
25 |
26 | .. image:: assets/framework.png
27 | :width: 600
28 | :align: center
29 |
30 | .. raw:: html
31 |
32 |
33 |
34 | .. toctree::
35 | :maxdepth: 1
36 | :caption: Get Started
37 |
38 | get_started/install
39 | get_started/quick_start
40 |
41 |
42 | .. toctree::
43 | :maxdepth: 1
44 | :caption: Framework Features
45 |
46 | features/global_cfg_obj
47 | features/inheritable_config
48 | features/standard_datamodule
49 | features/dataset_folder_protocol
50 | features/atomic_files
51 | features/atomic_operations
52 |
53 | .. toctree::
54 | :maxdepth: 1
55 | :caption: User Guide
56 |
57 | user_guide/atom_op
58 | user_guide/datasets
59 | user_guide/models
60 | user_guide/reference_table
61 | user_guide/usage
62 |
63 | .. toctree::
64 | :maxdepth: 1
65 | :caption: Developer Guide
66 |
67 | developer_guide/customize_atomic_op
68 | developer_guide/customize_datatpl
69 | developer_guide/customize_modeltpl
70 | developer_guide/customize_traintpl
71 | developer_guide/customize_evaltpl
72 |
73 |
74 | Indices and tables
75 | ==================
76 |
77 | * :ref:`genindex`
78 | * :ref:`modindex`
79 | * :ref:`search`
80 |
--------------------------------------------------------------------------------
/docs/source/user_guide/atom_op.md:
--------------------------------------------------------------------------------
1 | # M2C Atomic Data Operation List
2 |
3 |
4 | | M2C Atomic operation | M2C Atomic Type | Description |
5 | | :------------------------: | --------------- | ------------------------------------------------------------ |
6 | | M2C_Label2Int | Data Cleaning | Binarization for answering response |
7 | | M2C_FilterRecords4CD       | Data Cleaning   | Filter some students or exercises according to specific conditions |
8 | | M2C_FilteringRecordsByAttr | Data Cleaning | Filtering Students without attribute values, Commonly used by Fair Models |
9 | | M2C_ReMapId | Data Conversion | ReMap Column ID |
10 | | M2C_BuildMissingQ | Data Conversion | Build Missing Q-matrix |
11 | | M2C_BuildSeqInterFeats | Data Conversion | Build sample format for Question-based KT |
12 | | M2C_CKCAsExer | Data Conversion | Build sample format for KC-based KT |
13 | | M2C_MergeDividedSplits | Data Conversion | Merge train/valid/test set into one dataframe |
14 | | M2C_RandomDataSplit4CD | Data Partition | Data partitioning for Cognitive Diagnosis |
15 | | M2C_RandomDataSplit4KT | Data Partition | Data partitioning for Knowledge Tracing |
16 | | M2C_GenKCSeq | Data Generation | Generate Knowledge Component Sequence |
17 | | M2C_GenQMat | Data Generation | Generate Q-matrix (i.e, exercise-KC relation) |
18 | | M2C_BuildKCRelation | Data Generation | Build Knowledge Component Relation Graph |
19 | | M2C_GenUnFoldKCSeq | Data Generation | Generate Unfolded Knowledge Component Sequence |
20 | | M2C_FillMissingQ | Data Generation | Fill Missing Q-matrix |
21 |
22 |
--------------------------------------------------------------------------------
/docs/source/user_guide/usage.rst:
--------------------------------------------------------------------------------
1 | Usage
2 | =========================================================
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | usage/run_edustudio
8 | usage/aht
9 | usage/atomic_cmds
10 | usage/use_case_of_config
11 |
--------------------------------------------------------------------------------
/docs/source/user_guide/usage/atomic_cmds.md:
--------------------------------------------------------------------------------
1 | # Atomic Commands
2 |
3 | ## Atomic Commands for R2M process
4 |
5 | If there is a demand that process the `FrcSub` dataset from `rawdata` to `middata`, we can run the following command.
6 | ```bash
7 | edustudio r2m R2M_FrcSub --dt FrcSub --rawpath data/FrcSub/rawdata --midpath data/FrcSub/middata
8 | ```
9 |
10 | The command would read raw data files from `data/FrcSub/rawdata` and then save the middata in `data/FrcSub/middata`
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/user_guide/usage/run_edustudio.md:
--------------------------------------------------------------------------------
1 | # Run EduStudio
2 |
3 | ## Create a python file to run
4 |
5 | create a python file (e.g., *run.py*) anywhere, the content is as follows:
6 |
7 | ```python
8 | from edustudio.quickstart import run_edustudio
9 |
10 | run_edustudio(
11 | dataset='FrcSub',
12 | cfg_file_name=None,
13 | traintpl_cfg_dict={
14 | 'cls': 'GeneralTrainTPL',
15 | },
16 | datatpl_cfg_dict={
17 | 'cls': 'CDInterExtendsQDataTPL'
18 | },
19 | modeltpl_cfg_dict={
20 | 'cls': 'KaNCD',
21 | },
22 | evaltpl_cfg_dict={
23 | 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'],
24 | }
25 | )
26 | ```
27 |
28 | Then run the following command:
29 |
30 | ```bash
31 | python run.py
32 | ```
33 |
34 | ## Run with command
35 |
36 | You can run the following command with parameters based on created file above.
37 |
38 | ```bash
39 | cd examples
40 | python run.py -dt ASSIST_0910 --modeltpl_cfg.cls NCDM --traintpl_cfg.batch_size 512
41 | ```
42 |
43 | ## Run with config file
44 |
45 | create a yaml file `conf/ASSIST_0910/NCDM.yaml`:
46 | ```yaml
47 | datatpl_cfg:
48 | cls: CDInterDataTPL
49 |
50 | traintpl_cfg:
51 | cls: GeneralTrainTPL
52 | batch_size: 512
53 |
54 | modeltpl_cfg:
55 | cls: NCDM
56 |
57 | evaltpl_cfg:
58 |   clses: [PredictionEvalTPL, InterpretabilityEvalTPL]
59 | ```
60 |
61 | then, run command:
62 |
63 | ```bash
64 | cd examples
65 | python run.py -dt ASSIST_0910 -f NCDM.yaml
66 | ```
67 |
--------------------------------------------------------------------------------
/docs/source/user_guide/usage/use_case_of_config.md:
--------------------------------------------------------------------------------
1 | # Use cases about specifying configuration
2 |
3 | ## Q1: How to specify the atomic data operation config
4 |
5 | The default_cfg of `M2C_FilterRecords4CD` is as follows:
6 |
7 | ```python
8 | class M2C_FilterRecords4CD(BaseMid2Cache):
9 | default_cfg = {
10 | "stu_least_records": 10,
11 | "exer_least_records": 0,
12 | }
13 |
14 | ```
15 |
16 | The following example demonstrates how to specify config of M2C_FilterRecords4CD.
17 |
18 | ```python
19 | from edustudio.quickstart import run_edustudio
20 |
21 | run_edustudio(
22 | dataset='FrcSub',
23 | cfg_file_name=None,
24 | traintpl_cfg_dict={
25 | 'cls': 'GeneralTrainTPL',
26 | },
27 | datatpl_cfg_dict={
28 | 'cls': 'CDInterExtendsQDataTPL',
29 | 'M2C_FilterRecords4CD': {
30 | "stu_least_records": 20, # look here
31 | }
32 | },
33 | modeltpl_cfg_dict={
34 | 'cls': 'KaNCD',
35 | },
36 | evaltpl_cfg_dict={
37 | 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'],
38 | }
39 | )
40 | ```
41 |
42 | ## Q2: How to specify the config of evaluate template
43 | The default_cfg of `PredictionEvalTPL` is as follows:
44 | ```python
45 | class PredictionEvalTPL(BaseEvalTPL):
46 | default_cfg = {
47 | 'use_metrics': ['auc', 'acc', 'rmse']
48 | }
49 | ```
50 |
51 |
52 | If we want to use only auc metric, we can do:
53 |
54 | ```python
55 | from edustudio.quickstart import run_edustudio
56 |
57 | run_edustudio(
58 | dataset='FrcSub',
59 | cfg_file_name=None,
60 | traintpl_cfg_dict={
61 | 'cls': 'GeneralTrainTPL',
62 | },
63 | datatpl_cfg_dict={
64 | 'cls': 'CDInterExtendsQDataTPL',
65 | 'M2C_FilterRecords4CD': {
66 | "stu_least_records": 20,
67 | }
68 | },
69 | modeltpl_cfg_dict={
70 | 'cls': 'KaNCD',
71 | },
72 | evaltpl_cfg_dict={
73 | 'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'],
74 |         'PredictionEvalTPL': {
75 |             'use_metrics': ['auc'] # look here
76 | }
77 | }
78 | )
79 | ```
80 |
--------------------------------------------------------------------------------
/edustudio/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 | __version__ = 'v1.1.7'
6 |
--------------------------------------------------------------------------------
/edustudio/assets/datasets.yaml:
--------------------------------------------------------------------------------
1 | # 1. all datasets are stored in https://huggingface.co/datasets/lmcRS/edustudio-datasets
2 | # 2. some datasets may not list here, but can still download, as edustudio will look up from external yaml file: https://huggingface.co/datasets/lmcRS/edustudio-datasets/raw/main/datasets.yaml
3 |
4 | ASSIST_0910:
5 | middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/ASSIST_0910/ASSIST_0910-middata.zip
6 | FrcSub:
7 | middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/FrcSub/FrcSub-middata.zip
8 | Math1:
9 | middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/Math1/Math1-middata.zip
10 | Math2:
11 | middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/Math2/Math2-middata.zip
12 | AAAI_2023:
13 | middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/AAAI_2023/AAAI_2023-middata.zip
14 | PISA_2015_ECD:
15 | middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/PISA_2015_ECD/PISA_2015_ECD-middata.zip
16 |
--------------------------------------------------------------------------------
/edustudio/atom_op/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/edustudio/atom_op/__init__.py
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/CD/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_split4cd import M2C_RandomDataSplit4CD
2 | from .filter_records4cd import M2C_FilterRecords4CD
3 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/CD/filter_records4cd.py:
--------------------------------------------------------------------------------
1 | from ..common.base_mid2cache import BaseMid2Cache
2 | import pandas as pd
3 |
4 |
class M2C_FilterRecords4CD(BaseMid2Cache):
    """Filter interaction records for Cognitive Diagnosis datasets.

    First drops duplicate (student, exercise, label) records, then
    iteratively removes students and exercises that have fewer records
    than the configured minimums until the record count is stable.
    The student/exercise side tables (``df_stu``/``df_exer``), when
    present, are reduced to the surviving ids as well.
    """
    default_cfg = {
        "stu_least_records": 10,   # min records a student needs to survive
        "exer_least_records": 0,   # min records an exercise needs to survive
    }

    def process(self, **kwargs):
        """Filter ``kwargs['df']`` (and side tables) and return the updated kwargs."""
        df: pd.DataFrame = kwargs['df']
        assert df is not None

        # Deduplicate, keeping the first occurrence of each record.
        df.drop_duplicates(
            subset=['stu_id:token', 'exer_id:token', "label:float"],
            keep='first', inplace=True, ignore_index=True
        )

        stu_least_records = self.m2c_cfg['stu_least_records']
        exer_least_records = self.m2c_cfg['exer_least_records']

        # Alternately filter students and exercises until a fixed point:
        # removing records on one side may drop the other side below its
        # minimum, so repeat until the length stops changing.
        last_count = 0
        while last_count != len(df):  # idiomatic len() instead of df.__len__()
            last_count = len(df)
            gp_by_uid = df[['stu_id:token', 'exer_id:token']].groupby('stu_id:token').agg('count').reset_index()
            selected_users = gp_by_uid[gp_by_uid['exer_id:token'] >= stu_least_records].reset_index()['stu_id:token'].to_numpy()

            gp_by_iid = df[['stu_id:token', 'exer_id:token']].groupby('exer_id:token').agg('count').reset_index()
            selected_items = gp_by_iid[gp_by_iid['stu_id:token'] >= exer_least_records].reset_index()['exer_id:token'].to_numpy()

            df = df[df['stu_id:token'].isin(selected_users) & df['exer_id:token'].isin(selected_items)]

        df = df.reset_index(drop=True)
        selected_users = df['stu_id:token'].unique()
        selected_items = df['exer_id:token'].unique()

        # Keep only surviving exercises/students in the side tables.
        if kwargs.get('df_exer', None) is not None:
            kwargs['df_exer'] = kwargs['df_exer'][kwargs['df_exer']['exer_id:token'].isin(selected_items)]
            kwargs['df_exer'].reset_index(drop=True, inplace=True)

        if kwargs.get('df_stu', None) is not None:
            kwargs['df_stu'] = kwargs['df_stu'][kwargs['df_stu']['stu_id:token'].isin(selected_users)]
            kwargs['df_stu'].reset_index(drop=True, inplace=True)

        kwargs['df'] = df
        return kwargs
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/KT/__init__.py:
--------------------------------------------------------------------------------
1 | from .build_seq_inter_feats import M2C_BuildSeqInterFeats
2 | from .cpt_as_exer import M2C_KCAsExer
3 | from .gen_cpt_seq import M2C_GenKCSeq
4 | from .gen_unfold_cpt_seq import M2C_GenUnFoldKCSeq
5 | from .data_split4kt import M2C_RandomDataSplit4KT
6 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/KT/gen_cpt_seq.py:
--------------------------------------------------------------------------------
1 | from ..common.base_mid2cache import BaseMid2Cache
2 | import numpy as np
3 | from edustudio.datatpl.utils import PadSeqUtil
4 |
5 |
class M2C_GenKCSeq(BaseMid2Cache):
    """Generate a padded Knowledge Component sequence for every exercise id."""
    default_cfg = {
        'cpt_seq_window_size': -1,   # max KC sequence length; -1 keeps full length
    }

    def process(self, **kwargs):
        """Attach ``cpt_seq_padding`` and ``cpt_seq_mask`` arrays to kwargs."""
        exer_df = kwargs['df_exer']
        n_exers = kwargs['dt_info']['exer_count']
        indexed_q = exer_df.set_index('exer_id:token')

        # One KC list per exercise id; ids absent from the Q table get [].
        unpadded = []
        for eid in range(n_exers):
            if eid in indexed_q.index:
                unpadded.append(indexed_q.loc[eid].tolist()[0])
            else:
                unpadded.append([])

        padded, _, mask = PadSeqUtil.pad_sequence(
            unpadded, maxlen=self.m2c_cfg['cpt_seq_window_size'], return_mask=True
        )

        kwargs['cpt_seq_padding'] = padded
        kwargs['cpt_seq_mask'] = mask
        return kwargs
28 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/KT/gen_unfold_cpt_seq.py:
--------------------------------------------------------------------------------
1 | from ..common.base_mid2cache import BaseMid2Cache
2 | import numpy as np
3 | from edustudio.datatpl.utils import PadSeqUtil
4 | import pandas as pd
5 |
6 |
class M2C_GenUnFoldKCSeq(BaseMid2Cache):
    """Unfold the per-exercise KC sequence into one row per (exercise, KC) pair.

    Each exercise whose ``cpt_seq:token_seq`` holds k knowledge components
    is expanded into k rows, which are then left-merged into the interaction
    records, adding a ``cpt_unfold:token`` column (KC ids remapped to a
    dense 0..m-1 range).
    """
    default_cfg = {}

    def __init__(self, m2c_cfg, n_folds, is_dataset_divided) -> None:
        super().__init__(m2c_cfg)
        self.n_folds = n_folds
        self.is_dataset_divided = is_dataset_divided

    @classmethod
    def from_cfg(cls, cfg):
        """Build the op from the global config object."""
        m2c_cfg = cfg.datatpl_cfg.get(cls.__name__)
        n_folds = cfg.datatpl_cfg.n_folds
        is_dataset_divided = cfg.datatpl_cfg.is_dataset_divided
        return cls(m2c_cfg, n_folds, is_dataset_divided)

    def process(self, **kwargs):
        df = kwargs['df']
        df_exer = kwargs['df_exer']
        df_train, df_valid, df_test = kwargs['df_train'], kwargs['df_valid'], kwargs['df_test']

        if not self.is_dataset_divided:
            # Undivided: all records live in a single dataframe.
            assert df_train is None and df_valid is None and df_test is None
            unique_cpt_seq = df_exer['cpt_seq:token_seq'].explode().unique()
            cpt_map = dict(zip(unique_cpt_seq, range(len(unique_cpt_seq))))
            df_Q_unfold = pd.DataFrame({
                'exer_id:token': df_exer['exer_id:token'].repeat(df_exer['cpt_seq:token_seq'].apply(len)),
                'cpt_unfold:token': df_exer['cpt_seq:token_seq'].explode().replace(cpt_map)
            })
            df = pd.merge(df, df_Q_unfold, on=['exer_id:token'], how='left').reset_index(drop=True)
            kwargs['df'] = df
        else:  # dataset is divided into train/valid/test
            assert df_train is not None and df_test is not None
            train_df, val_df, test_df = self._unfold_dataset(df_train, df_valid, df_test, df_exer)
            kwargs['df_train'] = train_df
            kwargs['df_valid'] = val_df
            kwargs['df_test'] = test_df
        return kwargs

    def _unfold_dataset(self, train_df, val_df, test_df, df_Q):
        """Unfold each divided split against the Q table.

        BUGFIX: this helper previously used unsuffixed column names
        ('exer_id', 'cpt_seq'), which do not exist in the middata schema —
        columns carry ':token'/':token_seq' suffixes everywhere else in the
        pipeline (including the undivided branch above) — so the
        divided-dataset path raised KeyError.
        """
        unique_cpt_seq = df_Q['cpt_seq:token_seq'].explode().unique()
        cpt_map = dict(zip(unique_cpt_seq, range(len(unique_cpt_seq))))
        df_Q_unfold = pd.DataFrame({
            'exer_id:token': df_Q['exer_id:token'].repeat(df_Q['cpt_seq:token_seq'].apply(len)),
            'cpt_unfold:token': df_Q['cpt_seq:token_seq'].explode().replace(cpt_map)
        })
        train_df_unfold = pd.merge(train_df, df_Q_unfold, on=['exer_id:token'], how='left').reset_index(drop=True)
        val_df_unfold = pd.merge(val_df, df_Q_unfold, on=['exer_id:token'], how='left').reset_index(drop=True)
        test_df_unfold = pd.merge(test_df, df_Q_unfold, on=['exer_id:token'], how='left').reset_index(drop=True)

        return train_df_unfold, val_df_unfold, test_df_unfold
58 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/__init__.py:
--------------------------------------------------------------------------------
1 | from .CD import *
2 | from .KT import *
3 | from .single import *
4 | from .common import *
5 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/common/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_mid2cache import BaseMid2Cache
2 | from .gen_q_mat import M2C_GenQMat
3 | from .label2int import M2C_Label2Int
4 | from .merge_divided_splits import M2C_MergeDividedSplits
5 | from .remapid import M2C_ReMapId
6 | from .build_cpt_relation import M2C_BuildKCRelation
7 | from .build_missing_Q import M2C_BuildMissingQ
8 | from .fill_missing_Q import M2C_FillMissingQ
9 | from .filtering_records_by_attr import M2C_FilteringRecordsByAttr
10 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/common/base_mid2cache.py:
--------------------------------------------------------------------------------
1 | from edustudio.utils.common import UnifyConfig
2 | import logging
3 |
4 |
class BaseMid2Cache(object):
    """Base class for Mid2Cache atomic data operations.

    Subclasses declare a ``default_cfg`` dict and override ``process`` to
    transform the kwargs pipeline; ``get_default_cfg`` merges the
    ``default_cfg`` of every ancestor along the MRO.
    """
    default_cfg = {}

    def __init__(self, m2c_cfg) -> None:
        self.logger = logging.getLogger("edustudio")
        self.m2c_cfg = m2c_cfg
        self._check_params()

    def _check_params(self):
        """Hook for subclasses to validate ``self.m2c_cfg``; no-op by default."""
        pass

    @classmethod
    def from_cfg(cls, cfg: UnifyConfig):
        """Construct the op from the global config, keyed by class name."""
        return cls(cfg.datatpl_cfg.get(cls.__name__))

    @classmethod
    def get_default_cfg(cls, **kwargs):
        """Merge ``default_cfg`` across the MRO (subclass keys win)."""
        merged = UnifyConfig()
        for klass in cls.__mro__:
            # stop at the first ancestor without a default_cfg (i.e. object)
            if not hasattr(klass, 'default_cfg'):
                break
            merged.update(klass.default_cfg, update_unknown_key_only=True)
        return merged

    def process(self, **kwargs):
        """Transform and return the kwargs pipeline; overridden by subclasses."""
        pass

    def set_dt_info(self, dt_info, **kwargs):
        """Hook to record dataset statistics into ``dt_info``; no-op by default."""
        pass
34 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/common/build_cpt_relation.py:
--------------------------------------------------------------------------------
1 | from .base_mid2cache import BaseMid2Cache
2 | import pandas as pd
3 | import numpy as np
4 | from itertools import chain
5 |
6 |
class M2C_BuildKCRelation(BaseMid2Cache):
    """Build a Knowledge Component (KC) dependency graph from response logs.

    Implements the 'rcd_transition' strategy: count KC transitions between
    consecutive correctly-answered records of each student, row-normalize
    and min-max scale the counts, then threshold into a binary dependency
    matrix stored under ``kwargs['cpt_dep_mat']``.
    """
    default_cfg = {
        'relation_type': 'rcd_transition',  # only strategy implemented so far
        'threshold': None,                  # None -> derive threshold from data
    }

    def process(self, **kwargs):
        """Sort records chronologically (when an order column exists) and
        dispatch to the configured relation builder."""
        df = kwargs['df']
        if df is None:
            # dataset arrives pre-divided: merge the splits first
            df = pd.concat(
                [kwargs['df_train'], kwargs['df_valid'], kwargs['df_test']],
                axis=0, ignore_index=True
            )

        if 'order_id:float' in df:
            df.sort_values(by=['order_id:float'], axis=0, ignore_index=True, inplace=True)
        elif 'order_id:token' in df:
            df.sort_values(by=['order_id:token'], axis=0, ignore_index=True, inplace=True)

        kwargs['df'] = df

        if self.m2c_cfg['relation_type'] == 'rcd_transition':
            return self.gen_rcd_transition_relation(kwargs)
        else:
            raise NotImplementedError

    def gen_rcd_transition_relation(self, kwargs):
        """Compute the binary KC dependency matrix (``cpt_dep_mat``)."""
        df: pd.DataFrame = kwargs['df']
        df_exer = kwargs['df_exer'][['exer_id:token', 'cpt_seq:token_seq']]
        df = df[['stu_id:token', 'exer_id:token', 'label:float']].merge(
            df_exer[['exer_id:token', 'cpt_seq:token_seq']], how='left', on='exer_id:token'
        )
        cpt_count = np.max(list(chain(*df_exer['cpt_seq:token_seq'].to_list()))) + 1

        # n_mat[i, j]: number of observed transitions KC i -> KC j between
        # consecutive records of a student when both answers are correct.
        n_mat = np.zeros((cpt_count, cpt_count), dtype=np.float32)
        for _, df_one_stu in df.groupby('stu_id:token'):
            # BUGFIX: iterate over all consecutive pairs. The previous code
            # looped over range(n) and broke at idx == n-2 *before* handling
            # that pair, so the last transition of every student was skipped;
            # worse, a student with exactly one record reached iloc[idx + 1]
            # and raised IndexError.
            for idx in range(df_one_stu.shape[0] - 1):
                curr_record = df_one_stu.iloc[idx]
                next_record = df_one_stu.iloc[idx + 1]
                if curr_record['label:float'] * next_record['label:float'] == 1:
                    for cpt_pre in curr_record['cpt_seq:token_seq']:
                        for cpt_next in next_record['cpt_seq:token_seq']:
                            if cpt_pre != cpt_next:
                                n_mat[cpt_pre, cpt_next] += 1

        # Row-normalize; rows with no outgoing transitions stay all-zero.
        a = np.sum(n_mat, axis=1)[:, None]
        nonzero_mask = (a != 0)
        np.seterr(divide='ignore', invalid='ignore')  # NOTE: global numpy setting
        C_mat = np.where(nonzero_mask, n_mat / a, n_mat)

        # Min-max scale off-diagonal entries (diagonal forced to 0).
        max_val = C_mat.max()
        np.fill_diagonal(C_mat, max_val)
        min_val = C_mat.min()
        np.fill_diagonal(C_mat, 0)
        T_mat = (C_mat - min_val) / (max_val - min_val)

        threshold = self.m2c_cfg['threshold']
        if threshold is None:
            threshold = C_mat.sum() / (C_mat != 0).sum()
            # NOTE(review): the threshold is squared twice (4th power) — this
            # looks like an accidentally duplicated line; kept as-is to
            # preserve existing behavior. Confirm against the RCD reference.
            threshold *= threshold
            threshold *= threshold

        cpt_dep_mat = (T_mat > threshold).astype(np.int64)

        kwargs['cpt_dep_mat'] = cpt_dep_mat

        return kwargs
74 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/common/build_missing_Q.py:
--------------------------------------------------------------------------------
1 | from .base_mid2cache import BaseMid2Cache
2 | import numpy as np
3 | import pandas as pd
4 | from itertools import chain
5 | import torch
6 | from edustudio.utils.common import set_same_seeds
7 |
8 |
class M2C_BuildMissingQ(BaseMid2Cache):
    """Build a partially-missing Q-matrix by deleting some exercises' KC rows.

    Simulates an incomplete Q-matrix: a fraction (``Q_delete_ratio``) of
    exercises have their Q-table rows removed, while attempting to keep at
    least one annotated exercise per knowledge component.
    """
    default_cfg = {
        'seed': 20230518,       # RNG seed for reproducible deletion
        'Q_delete_ratio': 0.0,  # fraction of exercises whose Q rows are deleted
    }

    def process(self, **kwargs):
        """Attach ``missing_df_Q`` (reduced Q table) and ``missing_Q_mat``
        (exercise x KC 0/1 int64 tensor) to kwargs."""
        dt_info = kwargs['dt_info']
        self.item_count = dt_info['exer_count']
        self.cpt_count = dt_info['cpt_count']
        self.df_Q = kwargs['df_exer'][['exer_id:token', 'cpt_seq:token_seq']]

        self.missing_df_Q = self.get_missing_df_Q()
        self.missing_Q_mat = self.get_Q_mat_from_df_arr(self.missing_df_Q, self.item_count, self.cpt_count)

        kwargs['missing_df_Q'] = self.missing_df_Q
        kwargs['missing_Q_mat'] = self.missing_Q_mat

        return kwargs

    def get_missing_df_Q(self):
        """Choose which exercises keep their Q rows; return the reduced Q table."""
        set_same_seeds(seed=self.m2c_cfg['seed'])
        ratio = self.m2c_cfg['Q_delete_ratio']
        iid2cptlist = self.df_Q.set_index('exer_id:token')['cpt_seq:token_seq'].to_dict()
        # Build one (exer_id, cpt_id) row per annotation pair.
        iid_lis = np.array(list(chain(*[[i]*len(iid2cptlist[i]) for i in iid2cptlist])))
        cpt_lis = np.array(list(chain(*list(iid2cptlist.values()))))
        entry_arr = np.vstack([iid_lis, cpt_lis]).T

        np.random.shuffle(entry_arr)

        # reference: https://stackoverflow.com/questions/64834655/python-how-to-find-first-duplicated-items-in-an-numpy-array
        _, idx = np.unique(entry_arr[:, 1], return_index=True)  # first pick one exercise per knowledge component
        bool_idx = np.zeros_like(entry_arr[:, 1], dtype=bool)
        bool_idx[idx] = True
        preserved_exers = np.unique(entry_arr[bool_idx, 0])  # these covering exercises are always preserved

        delete_num = int(ratio * self.item_count)
        preserved_num = self.item_count - delete_num

        if len(preserved_exers) >= preserved_num:
            # Covering every KC already requires more exercises than the
            # requested preserve count -> cannot delete as many as asked.
            self.logger.warning(
                f"Cant Satisfy Delete Require: {len(preserved_exers)=},{preserved_num=}"
            )
        else:
            # Randomly preserve additional exercises up to the target count.
            need_preserved_num = preserved_num - len(preserved_exers)

            left_iids = np.arange(self.item_count)
            left_iids = left_iids[~np.isin(left_iids, preserved_exers)]
            np.random.shuffle(left_iids)
            choose_iids = left_iids[0:need_preserved_num]

            preserved_exers = np.hstack([preserved_exers, choose_iids])

        return self.df_Q.copy()[self.df_Q['exer_id:token'].isin(preserved_exers)].reset_index(drop=True)


    def get_Q_mat_from_df_arr(self, df_Q_arr, item_count, cpt_count):
        """Convert a Q table to a dense (item_count, cpt_count) 0/1 int64 tensor."""
        Q_mat = torch.zeros((item_count, cpt_count), dtype=torch.int64)
        for _, item in df_Q_arr.iterrows(): Q_mat[item['exer_id:token'], item['cpt_seq:token_seq']] = 1
        return Q_mat
69 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/common/filtering_records_by_attr.py:
--------------------------------------------------------------------------------
1 | from .base_mid2cache import BaseMid2Cache
2 | import pandas as pd
3 | import numpy as np
4 | from itertools import chain
5 |
6 |
class M2C_FilteringRecordsByAttr(BaseMid2Cache):
    """Drop students that lack required attribute values.

    Commonly used by fairness models: students missing any attribute in
    ``filter_stu_attrs`` are removed from ``df_stu``, and their interaction
    records are removed from ``df``.
    """
    default_cfg = {
        'filter_stu_attrs': ['gender:token']
    }

    def process(self, **kwargs):
        stu_df = kwargs['df_stu']
        inter_df = kwargs['df']

        # Keep students whose required attributes are all non-null.
        required = self.m2c_cfg['filter_stu_attrs']
        has_all_attrs = stu_df[required].notna().all(axis=1)
        stu_df = stu_df[has_all_attrs].reset_index(drop=True)

        # Keep only interactions belonging to the surviving students.
        keep_mask = inter_df['stu_id:token'].isin(stu_df['stu_id:token'])
        inter_df = inter_df[keep_mask].reset_index(drop=True)

        kwargs['df'] = inter_df
        kwargs['df_stu'] = stu_df

        return kwargs
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/common/gen_q_mat.py:
--------------------------------------------------------------------------------
1 | from .base_mid2cache import BaseMid2Cache
2 | import numpy as np
3 | from itertools import chain
4 | import torch
5 |
6 |
class M2C_GenQMat(BaseMid2Cache):
    """Generate the Q-matrix: a dense exercise x KC multi-hot tensor.

    Stores the result as ``kwargs['Q_mat']`` with one row per exercise id
    in ``0..max(exer_id)``.
    """

    def process(self, **kwargs):
        df_exer = kwargs['df_exer']
        # Number of distinct KC ids across all exercises.
        # (assumes KC ids are dense in 0..cpt_count-1 — TODO confirm)
        cpt_count = len(set(chain(*df_exer['cpt_seq:token_seq'].to_list())))
        kwargs['df_exer'] = df_exer
        by_exer_id = df_exer.set_index("exer_id:token")

        n_rows = df_exer['exer_id:token'].max() + 1
        rows = [
            self.multi_hot(cpt_count, by_exer_id.loc[exer_id]['cpt_seq:token_seq']).tolist()
            for exer_id in range(n_rows)
        ]
        kwargs['Q_mat'] = torch.from_numpy(np.array(rows))
        return kwargs

    @staticmethod
    def multi_hot(length, indices):
        """Return a length-`length` int64 vector with 1s at `indices`."""
        vec = np.zeros(length, dtype=np.int64)
        vec[indices] = 1
        return vec
28 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/common/label2int.py:
--------------------------------------------------------------------------------
1 | from .base_mid2cache import BaseMid2Cache
2 | import numpy as np
3 |
4 |
class M2C_Label2Int(BaseMid2Cache):
    """Binarize answer labels: values >= 0.5 become 1.0, otherwise 0.0."""

    def process(self, **kwargs):
        # Apply the same binarization to every split that is present.
        for split in ('df', 'df_train', 'df_valid', 'df_test'):
            self.op_on_df(split, kwargs)
        return kwargs

    @staticmethod
    def op_on_df(column, kwargs):
        frame = kwargs.get(column)
        if frame is not None:
            frame['label:float'] = (frame['label:float'] >= 0.5).astype(np.float32)
17 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/common/merge_divided_splits.py:
--------------------------------------------------------------------------------
1 | from .base_mid2cache import BaseMid2Cache
2 | import pandas as pd
3 | from itertools import chain
4 |
5 |
class M2C_MergeDividedSplits(BaseMid2Cache):
    """Merge pre-divided train/valid/test splits into one dataframe (``df``)."""
    default_cfg = {}

    def process(self, **kwargs):
        train = kwargs['df_train']
        valid = kwargs['df_valid']
        test = kwargs['df_test']

        # Train and test are mandatory and must share an identical schema.
        assert train is not None and test is not None
        assert set(train.columns) == set(test.columns)

        merged = pd.concat((train, test), ignore_index=True)

        # The valid split is optional; append it when present.
        if valid is not None:
            assert set(train.columns) == set(valid.columns)
            merged = pd.concat((merged, valid), ignore_index=True)

        kwargs['df'] = merged
        return kwargs

    def set_dt_info(self, dt_info, **kwargs):
        """Record student/exercise/KC counts derived from the merged data."""
        merged = kwargs['df']
        if 'stu_id:token' in merged.columns:
            dt_info['stu_count'] = int(merged['stu_id:token'].max() + 1)
        if 'exer_id:token' in merged.columns:
            dt_info['exer_count'] = int(merged['exer_id:token'].max() + 1)
        df_exer = kwargs.get('df_exer', None)
        if df_exer is not None:
            if 'cpt_seq:token_seq' in df_exer:
                dt_info['cpt_count'] = len(set(chain(*df_exer['cpt_seq:token_seq'].to_list())))
36 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/single/M2C_CNCDQ_OP.py:
--------------------------------------------------------------------------------
1 | from ..common import BaseMid2Cache
2 | import torch
3 |
class M2C_CNCDQ_OP(BaseMid2Cache):
    """Build the masked Q-matrix and per-exercise knowledge pairs for CNCD-Q.

    For each exercise, collects its annotated KCs (``cpt_seq``) and the
    top-k related KCs (``cpt_pre_seq``) as a pair, and marks both sets in
    ``Q_mask_mat``.
    """
    default_cfg = {}

    def process(self, **kwargs):
        df_exer = kwargs['df_exer']
        exer_count = kwargs['dt_info']['exer_count']
        cpt_count = kwargs['dt_info']['cpt_count']
        kwargs['Q_mask_mat'], kwargs['knowledge_pairs'] = self._get_knowledge_pairs(
            df_Q_arr=df_exer, exer_count=exer_count, cpt_count=cpt_count
        )
        return kwargs

    def _get_knowledge_pairs(self, df_Q_arr, exer_count, cpt_count):
        """Return (Q_mask_mat, knowledge_pairs) built from the exercise table.

        NOTE(review): the columns here ('exer_id', 'cpt_seq', 'cpt_pre_seq')
        lack the ':token'/':token_seq' suffixes used elsewhere in the
        pipeline, and ``cpt_id - 1`` suggests 1-based KC ids — confirm
        against the dataset this op targets.
        """
        Q_mask_mat = torch.zeros((exer_count, cpt_count), dtype=torch.int64)
        knowledge_pairs = []
        # Dead initializations of kn_tags/kn_topks removed: both were
        # unconditionally overwritten on the first loop iteration.
        for _, item in df_Q_arr.iterrows():
            kn_tags = item['cpt_seq']
            kn_topks = item['cpt_pre_seq']
            knowledge_pairs.append((kn_tags, kn_topks))
            for cpt_id in item['cpt_seq']:
                Q_mask_mat[item['exer_id'], cpt_id-1] = 1
            for cpt_id in item['cpt_pre_seq']:
                Q_mask_mat[item['exer_id'], cpt_id-1] = 1
        return Q_mask_mat, knowledge_pairs
32 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/single/M2C_DIMKT_OP.py:
--------------------------------------------------------------------------------
1 | from ..common import BaseMid2Cache
2 | import torch
3 | import numpy as np
4 |
5 |
class M2C_DIMKT_OP(BaseMid2Cache):
    """Compute per-fold question and concept difficulty tables for the DIMKT model."""
    default_cfg = {}

    def process(self, **kwargs):
        """For each training fold, derive difficulty vectors from the fold's responses.

        Adds to kwargs:
            q_diff_list: one (num_q, 1) tensor per fold of question difficulties.
            c_diff_list: one (num_c, 1) tensor per fold of concept difficulties.
        """
        self.df_Q = kwargs['df_exer']
        dt_info = kwargs['dt_info']
        self.num_q = dt_info['exer_count']
        self.num_c = dt_info['cpt_count']

        df_train_folds = kwargs['df_train_folds']

        kwargs['q_diff_list'] = []
        kwargs['c_diff_list'] = []
        for train_dict in df_train_folds:
            # compute_difficulty reads self.train_dict and writes self.q_dif / self.c_dif
            self.train_dict = train_dict
            self.compute_difficulty()
            kwargs['q_diff_list'].append(self.q_dif)
            kwargs['c_diff_list'].append(self.c_dif)
        return kwargs

    def compute_difficulty(self):
        """Fill self.q_dif / self.c_dif with mean-correctness-based difficulty levels.

        Difficulty = int(mean_label * 100) + 1 per question / per (first) concept;
        entries never observed in the fold stay at the default value 1.
        """
        # exercise id -> its concept sequence
        q_dict = dict(zip(self.df_Q['exer_id:token'], self.df_Q['cpt_seq:token_seq']))
        qd={}
        qd_count={}
        cd={}
        cd_count={}
        # assumes these are 2D arrays indexed as [sequence, position] — TODO confirm
        exer_ids = self.train_dict['exer_seq:token_seq']
        label_ids = self.train_dict['label_seq:float_seq']
        mask_ids = self.train_dict['mask_seq:token_seq']
        for ii, ee in enumerate(exer_ids):
            for i, e in enumerate(ee):
                tmp_mask = mask_ids[ii, i]
                if tmp_mask != 0:
                    tmp_exer = exer_ids[ii, i]
                    tmp_label = label_ids[ii, i]
                    # only the FIRST concept of the exercise contributes to concept difficulty
                    cpt = (q_dict[tmp_exer])[0]
                    cd[cpt] = cd.get(cpt, 0) + tmp_label
                    cd_count[cpt] = cd_count.get(cpt, 0) + 1
                    if tmp_exer in qd:
                        qd[tmp_exer] = qd[tmp_exer] + tmp_label
                        qd_count[tmp_exer] = qd_count[tmp_exer]+1
                    else:
                        qd[tmp_exer] = tmp_label
                        qd_count[tmp_exer] = 1
                else:
                    # assumes post-padding: first zero mask ends the sequence — TODO confirm
                    break


        self.q_dif = np.ones(self.num_q)
        self.c_dif = np.ones(self.num_c)
        for k,v in qd.items():
            self.q_dif[k] = int((qd[k]/qd_count[k])*100)+1
        for k,v in cd.items():
            self.c_dif[k] = int((cd[k]/cd_count[k])*100)+1
        # shape (num_q, 1) / (num_c, 1) so they can be used as embedding inputs
        self.q_dif = torch.tensor(self.q_dif).unsqueeze(1)
        self.c_dif = torch.tensor(self.c_dif).unsqueeze(1)
62 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/single/M2C_EERNN_OP.py:
--------------------------------------------------------------------------------
1 | from ..common import BaseMid2Cache
2 | from itertools import chain
3 | import torch
4 | import pandas as pd
5 | import numpy as np
6 | from edustudio.datatpl.utils import PadSeqUtil
7 | import copy
8 |
9 |
class M2C_EERNN_OP(BaseMid2Cache):
    """Train per-fold word2vec embeddings over exercise texts for the EERNN model.

    For every training fold, a Word2Vec model is fitted on the texts of the
    exercises appearing in that fold; exercise contents are then mapped to word
    ids, padded/truncated to a fixed length, and stacked into an
    (exer_count, max_sentence_len) id matrix.
    """
    default_cfg = {
        'word_emb_dim': 50,       # dimensionality of the trained word vectors
        'max_sentence_len': 100,  # pad/truncate length of each exercise text
    }

    def process(self, **kwargs):
        """Attach 'word_emb_dict_list' and 'content_mat_list' (one entry per fold).

        Bug fix: the original overwrote df_exer['content:token_seq'] with word
        ids inside the fold loop, so from the second fold onward Word2Vec was
        trained on already-mapped ids and the mapping was applied twice. The
        raw token sequences are now preserved and mapped freshly per fold.
        """
        dt_info = kwargs['dt_info']
        train_dict_list = kwargs['df_train_folds']
        df_exer = kwargs['df_exer']
        assert len(train_dict_list) > 0
        import gensim  # local import: gensim is only needed by this op

        # pristine word sequences, never mutated across folds
        raw_content = df_exer['content:token_seq'].copy(deep=True)

        word_emb_dict_list, content_mat_list = [], []
        for train_dict in train_dict_list:
            # train word2vec only on texts of exercises present in this fold
            exers = np.unique(train_dict['exer_seq:token_seq'])
            sentences = raw_content[df_exer['exer_id:token'].isin(exers)].tolist()
            model = gensim.models.word2vec.Word2Vec(
                sentences, vector_size=self.m2c_cfg['word_emb_dim'],
            )
            wv_from_bin = model.wv

            # shift ids by 1 so id 0 is reserved for padding / out-of-vocabulary
            word2id = {k: idx + 1 for k, idx in wv_from_bin.key_to_index.items()}
            word_emb_dict = {word2id[key]: wv_from_bin[key] for key in word2id}
            word_emb_dict[0] = np.zeros(shape=(self.m2c_cfg['word_emb_dim'], ))

            # map words unseen by this fold's vocabulary to 0, then pad
            mapped = raw_content.apply(lambda x: [word2id.get(xx, 0) for xx in x])
            pad_mat, _, _ = PadSeqUtil.pad_sequence(
                mapped.tolist(), maxlen=self.m2c_cfg['max_sentence_len'], padding='post',
                is_truncate=True, truncating='post', value=0,
            )
            df_exer['content:token_seq'] = [pad_mat[i].tolist() for i in range(pad_mat.shape[0])]

            # stack padded id rows ordered by exercise id
            tmp_df_Q = df_exer.set_index('exer_id:token')
            content_mat = torch.from_numpy(np.vstack(
                [tmp_df_Q.loc[exer_id]['content:token_seq'] for exer_id in range(dt_info['exer_count'])]
            ))

            word_emb_dict_list.append(word_emb_dict)
            content_mat_list.append(content_mat)
        kwargs['word_emb_dict_list'] = word_emb_dict_list
        kwargs['content_mat_list'] = content_mat_list
        return kwargs

    def set_dt_info(self, dt_info, **kwargs):
        """Expose the embedding dimensionality to downstream templates."""
        dt_info['word_emb_dim'] = self.m2c_cfg['word_emb_dim']
61 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/single/M2C_MGCD_OP.py:
--------------------------------------------------------------------------------
1 | from ..common import BaseMid2Cache
2 | import pandas as pd
3 | import numpy as np
4 |
5 |
class M2C_MGCD_OP(BaseMid2Cache):
    """Split interactions into group-level and student-level records for MGCD.

    Each group's "group interactions" are the exercises answered by EVERY
    student of the group; groups with fewer than `min_inter` such exercises
    are dropped entirely.
    """
    default_cfg = {
        "group_id_field": "class_id:token",  # student attribute used as the group key
        "min_inter": 5,                      # minimal number of common exercises per group
    }

    def __init__(self, m2c_cfg, is_dataset_divided) -> None:
        super().__init__(m2c_cfg)
        # this op only supports the undivided-dataset protocol
        self.is_dataset_divided = is_dataset_divided
        assert self.is_dataset_divided is False

    @classmethod
    def from_cfg(cls, cfg):
        """Build the op from the global config object."""
        m2c_cfg = cfg.datatpl_cfg.get(cls.__name__)
        is_dataset_divided = cfg.datatpl_cfg.is_dataset_divided
        return cls(m2c_cfg, is_dataset_divided)

    def process(self, **kwargs):
        """Replace kwargs['df'] with group-level interactions; keep the per-student remainder."""
        df_stu = kwargs['df_stu']
        df_inter = kwargs['df']
        df_inter_group, df_inter_stu, kwargs['df_stu'] = self.get_df_group_and_df_inter(
            df_inter=df_inter, df_stu=df_stu
        )
        kwargs['df'] = df_inter_group
        kwargs['df_inter_stu'] = df_inter_stu
        kwargs['df_stu'] = kwargs['df_stu'].rename(columns={self.m2c_cfg['group_id_field']: 'group_id:token'})
        self.group_count = df_inter_group['group_id:token'].nunique()
        return kwargs

    def get_df_group_and_df_inter(self, df_stu: pd.DataFrame, df_inter: pd.DataFrame):
        """Return (group-level interactions, leftover student interactions, surviving students).

        Bug fix: the group column was hard-coded as 'class_id:token' in two
        places, ignoring the configurable 'group_id_field'; both now use the
        configured field.
        """
        gid = self.m2c_cfg['group_id_field']
        # merge so each interaction row carries its group id
        df_inter = df_inter.merge(df_stu[['stu_id:token', gid]], on=['stu_id:token'], how='left')

        df_inter_group = pd.DataFrame()
        df_inter_stu = pd.DataFrame()

        for _, inter_g in df_inter.groupby(gid):
            # intersection of exercise sets: exercises answered by every student of the group
            exers_list = inter_g[['stu_id:token', 'exer_id:token']].groupby('stu_id:token').agg(set)['exer_id:token'].tolist()
            inter_exer_set = None
            for exers in exers_list:
                inter_exer_set = exers if inter_exer_set is None else (inter_exer_set & exers)
            inter_exer_set = np.array(list(inter_exer_set))
            if inter_exer_set.shape[0] >= self.m2c_cfg['min_inter']:
                tmp_group_df = inter_g[inter_g['exer_id:token'].isin(inter_exer_set)]
                tmp_stu_df = inter_g[~inter_g['exer_id:token'].isin(inter_exer_set)]

                df_inter_group = pd.concat([df_inter_group, tmp_group_df], ignore_index=True, axis=0)
                df_inter_stu = pd.concat([df_inter_stu, tmp_stu_df], ignore_index=True, axis=0)
            else:
                # drop the whole group when it has too few common exercises
                df_stu = df_stu[df_stu[gid] != inter_g[gid].values[0]]

        # group response = mean label over the group's students per exercise
        df_inter_group = df_inter_group[['label:float', 'exer_id:token', gid]].groupby(
            [gid, 'exer_id:token']
        ).agg('mean').reset_index().rename(columns={gid: 'group_id:token'})

        return df_inter_group, df_inter_stu, df_stu[df_stu[gid].isin(df_inter_group['group_id:token'].unique())]

    def set_dt_info(self, dt_info, **kwargs):
        """Record the number of surviving groups."""
        dt_info['group_count'] = self.group_count
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/single/M2C_QDKT_OP.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | from ..common import BaseMid2Cache
3 | import torch
4 | from torch.nn import functional as F
5 | import numpy as np
6 |
7 |
class M2C_QDKT_OP(BaseMid2Cache):
    """Build the question-similarity Laplacian matrix used by the QDKT model."""
    default_cfg = {}

    def process(self, **kwargs):
        """Attach kwargs['laplacian_matrix'] computed from the Q matrix."""
        self.df_Q = kwargs['df_exer']
        dt_info = kwargs['dt_info']
        self.num_q = dt_info['exer_count']
        self.num_c = dt_info['cpt_count']
        self.Q_mat = kwargs['Q_mat']
        laplacian_matrix = self.laplacian_matrix_by_vectorization()
        kwargs['laplacian_matrix'] = laplacian_matrix
        return kwargs

    def laplacian_matrix_by_vectorization(self):
        """Vectorized L = D - A over the question-similarity graph.

        Two questions are adjacent when the cosine similarity of their
        L2-normalized Q-matrix rows exceeds 1 - 1/num_questions.
        """
        normQ = F.normalize(self.Q_mat.float(), p=2, dim=-1)
        A = torch.mm(normQ, normQ.T) > (1 - 1/len(normQ))
        A = A.int()  # adjacency matrix (the diagonal self-similarity is always 1)
        D = A.sum(-1, dtype=torch.int32)
        # bug fix: the original indexed with a list of `range` objects
        # (A[[range(n), range(n)]]), which relies on deprecated sequence
        # indexing; an explicit index tensor is the supported form
        diag = torch.arange(len(A))
        A[diag, diag] = D - A[diag, diag]
        return A

    def generate_graph(self):
        """(Unused alternative) NetworkX graph linking questions with identical Q rows."""
        graph = nx.Graph()
        len1 = len(self.Q_mat)
        graph.add_nodes_from([i for i in range(1, len1 + 1)])  # nodes are 1-based
        for index in range(len1 - 1):
            for bindex in range(index + 1, len1):
                if not (False in (self.Q_mat[index, :] == self.Q_mat[bindex, :]).tolist()):
                    graph.add_edge(index + 1, bindex + 1)
        return graph

    # Laplacian of a graph: L = D - A

    def laplacian_matrix(self, graph):
        """(Unused alternative) Laplacian L = D - A of a NetworkX graph with 1-based node ids."""
        # adjacency matrix
        A = np.array(nx.adjacency_matrix(graph).todense())
        A = -A
        for i in range(len(A)):
            # vertex degree; node ids start at 1 (use i instead of i+1 for 0-based graphs)
            degree_i = graph.degree(i + 1)
            A[i][i] = A[i][i] + degree_i
        return A
52 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/single/M2C_RCD_OP.py:
--------------------------------------------------------------------------------
1 | from ..common import BaseMid2Cache
2 | import numpy as np
3 | import importlib
4 |
5 |
class M2C_RCD_OP(BaseMid2Cache):
    """Assemble the DGL graphs ('local_map') required by the RCD model."""
    default_cfg = {}

    def process(self, **kwargs):
        """Build concept/exercise/student graphs and store them under kwargs['local_map']."""
        self.dgl = importlib.import_module("dgl")
        k_from_e, e_from_k = self.build_k_e(kwargs)
        u_from_e, e_from_u = self.build_u_e(kwargs)
        kwargs['local_map'] = {
            'directed_g': self.build_cpt_directd(kwargs),
            'undirected_g': self.build_cpt_undirected(kwargs),
            'k_from_e': k_from_e,
            'e_from_k': e_from_k,
            'u_from_e_list': u_from_e,
            'e_from_u_list': e_from_u,
        }
        return kwargs

    def build_cpt_undirected(self, kwargs):
        """Undirected concept graph: edges where the dependency holds in both directions."""
        n_cpt = kwargs['dt_info']['cpt_count']
        dep = kwargs['cpt_dep_mat']
        mutual = ((dep + dep.T) == 2).astype(np.int64)
        return self.dgl.graph(np.argwhere(mutual == 1).tolist(), num_nodes=n_cpt)

    def build_cpt_directd(self, kwargs):
        """Directed concept graph taken straight from the dependency matrix."""
        n_cpt = kwargs['dt_info']['cpt_count']
        dep = kwargs['cpt_dep_mat']
        return self.dgl.graph(np.argwhere(dep == 1).tolist(), num_nodes=n_cpt)

    def build_k_e(self, kwargs):
        """Bipartite exercise<->concept graphs; concept node ids are offset by exer_count."""
        n_cpt = kwargs['dt_info']['cpt_count']
        n_exer = kwargs['dt_info']['exer_count']
        pairs = kwargs['df_exer'][['exer_id:token','cpt_seq:token_seq']].explode('cpt_seq:token_seq').to_numpy()
        pairs[:, 1] += n_exer  # shift concept ids past the exercise id range

        k_from_e = self.dgl.graph(pairs.tolist(), num_nodes=n_cpt + n_exer)
        e_from_k = self.dgl.graph(pairs[:, [1, 0]].tolist(), num_nodes=n_cpt + n_exer)
        return k_from_e, e_from_k

    def build_u_e(self, kwargs):
        """Per-fold bipartite student<->exercise graphs; student ids offset by exer_count."""
        n_stu = kwargs['dt_info']['stu_count']
        n_exer = kwargs['dt_info']['exer_count']

        u_from_e_list, e_from_u_list = [], []
        for fold_df in kwargs['df_train_folds']:
            stus = fold_df['stu_id:token'].to_numpy() + n_exer
            exers = fold_df['exer_id:token'].to_numpy()
            u_from_e_list.append(
                self.dgl.graph(np.vstack([exers, stus]).T.tolist(), num_nodes=n_stu + n_exer)
            )
            e_from_u_list.append(
                self.dgl.graph(np.vstack([stus, exers]).T.tolist(), num_nodes=n_stu + n_exer)
            )
        return u_from_e_list, e_from_u_list
66 |
--------------------------------------------------------------------------------
/edustudio/atom_op/mid2cache/single/__init__.py:
--------------------------------------------------------------------------------
1 | from .M2C_EERNN_OP import M2C_EERNN_OP
2 | from .M2C_LPKT_OP import M2C_LPKT_OP
3 | from .M2C_DKTForget_OP import M2C_DKTForget_OP
4 | from .M2C_DKTDSC_OP import M2C_DKTDSC_OP
5 | from .M2C_DIMKT_OP import M2C_DIMKT_OP
6 | from .M2C_QDKT_OP import M2C_QDKT_OP
7 | from .M2C_CL4KT_OP import M2C_CL4KT_OP
8 | from .M2C_MGCD_OP import M2C_MGCD_OP
9 | from .M2C_RCD_OP import M2C_RCD_OP
10 | from .M2C_CDGK_OP import M2C_CDGK_OP
11 |
--------------------------------------------------------------------------------
/edustudio/atom_op/raw2mid/__init__.py:
--------------------------------------------------------------------------------
1 | from .raw2mid import BaseRaw2Mid
2 | from .assist_0910 import R2M_ASSIST_0910
3 | from .frcsub import R2M_FrcSub
4 | from .assist_1213 import R2M_ASSIST_1213
5 | from .math1 import R2M_Math1
6 | from .math2 import R2M_Math2
7 | from .aaai_2023 import R2M_AAAI_2023
8 | from .algebra2005 import R2M_Algebra_0506
9 | from .assist_1516 import R2M_ASSIST_1516
10 | from .assist_2017 import R2M_ASSIST_17
11 | from .bridge2006 import R2M_Bridge2Algebra_0607
12 | from .ednet_kt1 import R2M_EdNet_KT1
13 | from .junyi_area_topic_as_cpt import R2M_Junyi_AreaTopicAsCpt
14 | from .junyi_exer_as_cpt import R2M_JunyiExerAsCpt
15 | from .nips12 import R2M_Eedi_20_T12
16 | from .nips34 import R2M_Eedi_20_T34
17 | from .simulated5 import R2M_Simulated5
18 | from .slp_english import R2M_SLP_English
19 | from .slp_math import R2M_SLP_Math
20 |
# look-up table: CLI entry name -> atomic operation's `from_cli` constructor
_cli_api_dict_ = {
    'R2M_ASSIST_0910': R2M_ASSIST_0910.from_cli,
    'R2M_FrcSub': R2M_FrcSub.from_cli,
    'R2M_ASSIST_1213': R2M_ASSIST_1213.from_cli,
    'R2M_Math1': R2M_Math1.from_cli,
    'R2M_Math2': R2M_Math2.from_cli,
    # NOTE(review): key differs from the class name (R2M_AAAI_2023) — confirm intentional alias
    'R2M_Xueersi_2023': R2M_AAAI_2023.from_cli,
    'R2M_Algebra_0506': R2M_Algebra_0506.from_cli,
    'R2M_ASSIST_1516': R2M_ASSIST_1516.from_cli,
    'R2M_ASSIST_17': R2M_ASSIST_17.from_cli,
    'R2M_Bridge2Algebra_0607': R2M_Bridge2Algebra_0607.from_cli,
    'R2M_EdNet_KT1': R2M_EdNet_KT1.from_cli,
    'R2M_Junyi_AreaTopicAsCpt': R2M_Junyi_AreaTopicAsCpt.from_cli,
    'R2M_JunyiExerAsCpt': R2M_JunyiExerAsCpt.from_cli,
    'R2M_Eedi_20_T12': R2M_Eedi_20_T12.from_cli,
    'R2M_Eedi_20_T34': R2M_Eedi_20_T34.from_cli,
    'R2M_Simulated5': R2M_Simulated5.from_cli,
    'R2M_SLP_Math': R2M_SLP_Math.from_cli,
    'R2M_SLP_English': R2M_SLP_English.from_cli,
}
41 |
--------------------------------------------------------------------------------
/edustudio/atom_op/raw2mid/assist_1516.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from .raw2mid import BaseRaw2Mid
3 |
4 | r"""
5 | R2M_ASSIST_1516
6 | ########################
7 | """
8 |
9 |
class R2M_ASSIST_1516(BaseRaw2Mid):
    """R2M_ASSIST_1516 is a class used to handle the ASSISTment 2015-2016 dataset."""

    def process(self):
        """Convert the raw ASSISTments 2015-2016 log into middata csv files."""
        super().process()
        pd.set_option("mode.chained_assignment", None)  # silence SettingWithCopy warnings
        # load the raw log file
        raw_data = pd.read_csv(f"{self.rawpath}/2015_100_skill_builders_main_problems.csv", encoding='utf-8')

        # interactions: sort by log_id to derive a stable ordering column
        inter = raw_data.copy().sort_values(by='log_id', ascending=True)
        inter['order'] = range(len(inter))
        inter['label'] = inter['correct']
        df_inter = (
            inter.rename(columns={'sequence_id': 'exer_id', 'user_id': 'stu_id'})
            .reindex(columns=['stu_id', 'exer_id', 'label', 'order', ])
            .rename(columns={'stu_id': 'stu_id:token', 'exer_id': 'exer_id:token',
                             'label': 'label:float', 'order': 'order_id:float'})
        )

        # student table: one row per distinct user id
        # stu['classId'] = None
        # stu['gender'] = None
        df_stu = pd.DataFrame(set(raw_data['user_id']), columns=['stu_id', ]).sort_values(by='stu_id', ascending=True)

        # exercise table: one row per distinct sequence id
        # exer['cpt_seq'] = None
        # exer['assignment_id'] = None
        df_exer = pd.DataFrame(set(raw_data['sequence_id']), columns=['exer_id']).sort_values(by='exer_id', ascending=True)

        # persist middata under `self.midpath` (only interactions are written for this dataset)
        df_inter.to_csv(f"{self.midpath}/{self.dt}.inter.csv", index=False, encoding='utf-8')
        # df_stu.to_csv(f"{self.midpath}/{self.dt}.stu.csv", index=False, encoding='utf-8')
        # df_exer.to_csv(f"{self.midpath}/{self.dt}.exer.csv", index=False, encoding='utf-8')
        pd.set_option("mode.chained_assignment", "warn")  # restore the default warning
        return
48 |
--------------------------------------------------------------------------------
/edustudio/atom_op/raw2mid/frcsub.py:
--------------------------------------------------------------------------------
1 | from .raw2mid import BaseRaw2Mid
2 | import pandas as pd
3 | pd.set_option("mode.chained_assignment", None) # ingore warning
4 |
5 | r"""
6 | R2M_FrcSub
7 | #####################################
8 | FrcSub dataset preprocess
9 | """
10 |
class R2M_FrcSub(BaseRaw2Mid):
    """R2M_FrcSub is to preprocess FrcSub dataset"""
    def process(self):
        """Convert the raw FrcSub matrices into middata csv files.

        Reads:
            data.txt -- student x exercise 0/1 response matrix (20 exercises)
            q.txt    -- exercise x concept 0/1 Q-matrix (8 concepts)
        Writes:
            {dataset}.inter.csv and {dataset}.exer.csv under self.midpath.
        """
        super().process()

        # load the raw matrices; columns are named by position
        df_inter = pd.read_csv(f"{self.rawpath}/data.txt", sep='\t', names=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                                            '10', '11', '12', '13', '14', '15', '16', '17', '18', '19'])
        df_inter.insert(0, 'stu_id:token', range(len(df_inter)))
        df_exer = pd.read_csv(f"{self.rawpath}/q.txt", sep='\t', names=['0', '1', '2', '3', '4', '5', '6', '7'])
        df_exer.insert(0, 'exer_id:token', range(len(df_exer)))

        # number of knowledge concepts (currently only informational)
        # bug fix: the original used df_exer.shape[1], which also counted the
        # just-inserted 'exer_id:token' column (9 instead of 8)
        cpt_count = df_exer.shape[1] - 1

        # interactions: wide 0/1 matrix -> one row per (student, exercise)
        df_inter = df_inter.melt(id_vars=['stu_id:token'], var_name='exer_id:token', value_name='label:float')
        df_inter['exer_id:token'] = df_inter['exer_id:token'].astype(int)
        df_inter.sort_values(by=['stu_id:token', 'exer_id:token'], inplace=True)

        # Q-matrix: explode the wide 0/1 matrix, keep covered concepts only
        df_exer = df_exer.melt(id_vars=['exer_id:token'], var_name='cpt_seq:token_seq', value_name='value')
        df_exer = df_exer[df_exer['value'] == 1]
        del df_exer['value']

        # join concepts per exercise into a comma-separated 'cpt_seq:token_seq'
        df_exer['cpt_seq:token_seq'] = df_exer['cpt_seq:token_seq'].astype(int)
        df_exer.sort_values(by='cpt_seq:token_seq', inplace=True)
        df_exer['exer_id:token'] = df_exer['exer_id:token'].astype(str)
        df_exer['cpt_seq:token_seq'] = df_exer['cpt_seq:token_seq'].astype(str)
        df_exer = df_exer.groupby('exer_id:token')['cpt_seq:token_seq'].agg(','.join).reset_index()

        # restore a numeric exercise-id ordering
        df_exer['exer_id:token'] = df_exer['exer_id:token'].astype(int)
        df_exer.sort_values(by='exer_id:token', inplace=True)

        # write middata to `self.midpath`
        df_inter.to_csv(f"{self.midpath}/{self.dt}.inter.csv", index=False, encoding='utf-8')
        df_exer.to_csv(f"{self.midpath}/{self.dt}.exer.csv", index=False, encoding='utf-8')
72 |
--------------------------------------------------------------------------------
/edustudio/atom_op/raw2mid/math1.py:
--------------------------------------------------------------------------------
1 | from .raw2mid import BaseRaw2Mid
2 | import pandas as pd
3 | r"""
4 | R2M_Math1
5 | #####################################
6 | Math1 dataset preprocess
7 | """
8 |
9 |
class R2M_Math1(BaseRaw2Mid):
    """R2M_Math1 is to preprocess Math1 dataset"""
    def process(self):
        """Convert the raw Math1 response matrix and Q-matrix into middata csv files."""
        super().process()
        # load the raw matrices; columns are named by position
        df_inter = pd.read_csv(f"{self.rawpath}/data.txt", sep='\t', names=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                                            '10', '11', '12', '13', '14', '15', '16', '17', '18', '19'])
        df_exer = pd.read_csv(f"{self.rawpath}/q.txt", sep='\t', names=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9','10'])

        # interactions: wide 0/1 matrix -> one row per (student, exercise)
        df_inter.insert(0, 'stu_id:token', range(len(df_inter)))
        df_inter = df_inter.melt(id_vars=['stu_id:token'], var_name='exer_id:token', value_name='label:float')
        df_inter['exer_id:token'] = df_inter['exer_id:token'].astype(int)
        df_inter.sort_values(by=['stu_id:token', 'exer_id:token'], inplace=True)

        # Q-matrix: explode the wide 0/1 matrix, keep covered concepts only
        df_exer.insert(0, 'exer_id:token', range(len(df_exer)))
        df_exer = df_exer.melt(id_vars=['exer_id:token'], var_name='cpt_seq:token_seq', value_name='value')
        df_exer = df_exer[df_exer['value'] == 1]
        del df_exer['value']

        # join concepts per exercise into a comma-separated 'cpt_seq:token_seq'
        df_exer['cpt_seq:token_seq'] = df_exer['cpt_seq:token_seq'].astype(int)
        df_exer.sort_values(by='cpt_seq:token_seq', inplace=True)
        df_exer['exer_id:token'] = df_exer['exer_id:token'].astype(str)
        df_exer['cpt_seq:token_seq'] = df_exer['cpt_seq:token_seq'].astype(str)
        df_exer = df_exer.groupby('exer_id:token')['cpt_seq:token_seq'].agg(','.join).reset_index()

        # restore a numeric exercise-id ordering
        df_exer['exer_id:token'] = df_exer['exer_id:token'].astype(int)
        df_exer.sort_values(by='exer_id:token', inplace=True)

        # write middata to `self.midpath`
        df_inter.to_csv(f"{self.midpath}/{self.dt}.inter.csv", index=False, encoding='utf-8')
        df_exer.to_csv(f"{self.midpath}/{self.dt}.exer.csv", index=False, encoding='utf-8')
48 |
--------------------------------------------------------------------------------
/edustudio/atom_op/raw2mid/math2.py:
--------------------------------------------------------------------------------
1 | from .raw2mid import BaseRaw2Mid
2 | import pandas as pd
3 |
4 | r"""
5 | R2M_Math2
6 | #####################################
7 | Math2 dataset preprocess
8 | """
9 |
10 |
class R2M_Math2(BaseRaw2Mid):
    """R2M_Math2 is to preprocess Math2 dataset"""
    def process(self):
        """Convert the raw Math2 response matrix and Q-matrix into middata csv files."""
        super().process()

        # load the raw matrices; columns are named by position
        df_inter = pd.read_csv(f"{self.rawpath}/data.txt", sep='\t', names=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                                            '10', '11', '12', '13', '14', '15', '16', '17', '18', '19'])
        df_inter.insert(0, 'stu_id:token', range(len(df_inter)))
        df_exer = pd.read_csv(f"{self.rawpath}/q.txt", sep='\t', names=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9','10',
                                            '11', '12', '13', '14', '15'])
        df_exer.insert(0, 'exer_id:token', range(len(df_exer)))

        # interactions: wide 0/1 matrix -> one row per (student, exercise)
        df_inter = df_inter.melt(id_vars=['stu_id:token'], var_name='exer_id:token', value_name='label:float')
        df_inter['exer_id:token'] = df_inter['exer_id:token'].astype(int)
        df_inter.sort_values(by=['stu_id:token', 'exer_id:token'], inplace=True)

        # Q-matrix: explode the wide 0/1 matrix, keep covered concepts only
        df_exer = df_exer.melt(id_vars=['exer_id:token'], var_name='cpt_seq:token_seq', value_name='value')
        df_exer = df_exer[df_exer['value'] == 1]
        del df_exer['value']

        # join concepts per exercise into a comma-separated 'cpt_seq:token_seq'
        df_exer['cpt_seq:token_seq'] = df_exer['cpt_seq:token_seq'].astype(int)
        df_exer.sort_values(by='cpt_seq:token_seq', inplace=True)
        df_exer['exer_id:token'] = df_exer['exer_id:token'].astype(str)
        df_exer['cpt_seq:token_seq'] = df_exer['cpt_seq:token_seq'].astype(str)
        df_exer = df_exer.groupby('exer_id:token')['cpt_seq:token_seq'].agg(','.join).reset_index()

        # restore a numeric exercise-id ordering
        df_exer['exer_id:token'] = df_exer['exer_id:token'].astype(int)
        df_exer.sort_values(by='exer_id:token', inplace=True)

        # write middata to `self.midpath`
        df_inter.to_csv(f"{self.midpath}/{self.dt}.inter.csv", index=False, encoding='utf-8')
        df_exer.to_csv(f"{self.midpath}/{self.dt}.exer.csv", index=False, encoding='utf-8')
--------------------------------------------------------------------------------
/edustudio/atom_op/raw2mid/raw2mid.py:
--------------------------------------------------------------------------------
1 | from edustudio.utils.common import UnifyConfig
2 | import logging
3 | import os
4 |
5 |
class BaseRaw2Mid(object):
    """Base class for raw->middata converters (raw2mid atomic operations).

    Subclasses override `process()` to read files from `self.rawpath` and
    write the framework's middata csv files into `self.midpath`.
    """
    def __init__(self, dt, rawpath, midpath) -> None:
        """Store paths and ensure the middata folder exists.

        Args:
            dt: dataset name (used as the middata file prefix).
            rawpath: folder containing the raw dataset files.
            midpath: output folder for the generated middata files.
        """
        self.dt = dt
        self.rawpath = rawpath
        self.midpath = midpath
        self.logger = logging.getLogger("edustudio")
        # exist_ok avoids the check-then-create race of the original code
        os.makedirs(self.midpath, exist_ok=True)

    @classmethod
    def from_cfg(cls, cfg: UnifyConfig):
        """Build a converter from the framework config object."""
        rawdata_folder_path = f"{cfg.frame_cfg.data_folder_path}/rawdata"
        middata_folder_path = f"{cfg.frame_cfg.data_folder_path}/middata"
        dt = cfg.dataset
        return cls(dt, rawdata_folder_path, middata_folder_path)

    def process(self, **kwargs):
        """Subclasses override; the base implementation only logs the start."""
        self.logger.info(f"{self.__class__.__name__} start !")

    @classmethod
    def from_cli(cls, dt, rawpath="./rawpath", midpath="./midpath"):
        """CLI entry point: construct a converter and run it.

        Now returns the converter instance (previously returned None) so
        callers can inspect it; existing callers ignoring the return value
        are unaffected.
        """
        obj = cls(dt, rawpath, midpath)
        obj.process()
        return obj
29 |
30 | # __cli_api_dict__ = {
31 | # 'BaseRaw2Mid': BaseRaw2Mid.from_cli
32 | # }
33 |
--------------------------------------------------------------------------------
/edustudio/datatpl/CD/CDGKDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 |
3 |
class CDGKDataTPL(EduDataTPL):
    """Data template for the CDGK model.

    Runs the standard cognitive-diagnosis mid2cache pipeline and appends the
    CDGK-specific atomic operation ('M2C_CDGK_OP') at the end.
    """
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat', 'M2C_CDGK_OP'],
    }

    # NOTE(review): the commented-out code below loaded an optional
    # {dataset}.cpt2group.csv file; kept for reference.
    # @classmethod
    # def load_data(cls, cfg): # called only when middata exists
    #     kwargs = super().load_data(cfg)
    #     if cfg.datatpl_cfg['has_cpt2group_file'] is True:
    #         new_kwargs = cls.load_cpt2group(cfg)
    #         for df in new_kwargs.values():
    #             if df is not None:
    #                 cls._preprocess_feat(df) # dtype conversion
    #         kwargs.update(new_kwargs)
    #     else:
    #         kwargs['df_cpt2group'] = None
    #     return kwargs


    # @classmethod
    # def load_cpt2group(cls, cfg):
    #     cpt2group_fph = f'{cfg.frame_cfg.data_folder_path}/middata/{cfg.dataset}.cpt2group.csv'
    #     sep = cfg.datatpl_cfg['seperator']
    #     df_cpt2group = cls._load_atomic_csv(cpt2group_fph, sep=sep)
    #     return {"df_cpt2group": df_cpt2group}
29 |
--------------------------------------------------------------------------------
/edustudio/datatpl/CD/CDInterDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import GeneralDataTPL
2 |
class CDInterDataTPL(GeneralDataTPL):
    """Interaction-only data template for cognitive diagnosis.

    Same pipeline as CDInterExtendsQDataTPL but without the 'M2C_GenQMat' step.
    """
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD'],
    }
7 |
8 |
--------------------------------------------------------------------------------
/edustudio/datatpl/CD/CDInterExtendsQDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 |
class CDInterExtendsQDataTPL(EduDataTPL):
    """Standard cognitive-diagnosis data template: CD pipeline plus Q-matrix generation."""
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat'],
    }
7 |
--------------------------------------------------------------------------------
/edustudio/datatpl/CD/CNCDFDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 |
class CNCDFDataTPL(EduDataTPL):
    """Data template for CNCD-F: standard CD pipeline plus exercise text content."""
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat'],
    }

    def get_extra_data(self):
        """Hand the exercise text token sequences and the Q matrix to the model."""
        extra = {}
        extra['content'] = self.df_exer['content:token_seq'].to_list()
        extra['Q_mat'] = self.final_kwargs['Q_mat']
        return extra
13 |
--------------------------------------------------------------------------------
/edustudio/datatpl/CD/CNCDQDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 | import torch
3 | import pandas as pd
4 | import os
5 |
class CNCDQDataTPL(EduDataTPL):
    """Data template for CNCD-Q: CD pipeline plus an optional questionnaire table."""
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat'],
    }

    @property
    def common_str2df(self):
        """Expose the named dataframes handed to the mid2cache pipeline."""
        mapping = {
            "df": self.df,
            "df_train": self.df_train,
            "df_valid": self.df_valid,
            "df_test": self.df_test,
            "dt_info": self.datatpl_cfg['dt_info'],
            "df_stu": self.df_stu,
            "df_exer": self.df_exer,
            "df_questionnaire": self.df_questionnaire,
        }
        return mapping

    @classmethod
    def load_data(cls, cfg):  # invoked only when middata exists
        """Load the standard middata plus the questionnaire table."""
        kwargs = super().load_data(cfg)
        extra = cls.load_data_from_questionnaire(cfg)
        for frame in extra.values():
            if frame is not None:
                cls._preprocess_feat(frame)  # dtype conversion
        kwargs.update(extra)
        return kwargs

    @classmethod
    def load_data_from_questionnaire(cls, cfg):
        """Load {dataset}.questionnaire.csv when present; otherwise map to None."""
        file_path = f'{cfg.frame_cfg.data_folder_path}/middata/{cfg.dataset}.questionnaire.csv'
        if not os.path.exists(file_path):
            return {"df_questionnaire": None}
        sep = cfg.datatpl_cfg['seperator']
        df_questionnaire = pd.read_csv(file_path, sep=sep, encoding='utf-8', usecols=['cpt_head:token', 'cpt_tail:token'])
        return {"df_questionnaire": df_questionnaire}
38 |
--------------------------------------------------------------------------------
/edustudio/datatpl/CD/DCDDataTPL.py:
--------------------------------------------------------------------------------
1 | import os
2 | from ..common.edu_datatpl import EduDataTPL
3 | import json
4 | from edustudio.datatpl.common.general_datatpl import DataTPLStatus
5 | import torch
6 |
7 |
class DCDDataTPL(EduDataTPL):
    """Data template for the DCD model.

    Extends the standard CD pipeline with per-fold signed student-exercise
    interaction matrices and an optional concept-relation json file.
    """
    default_cfg = {
        'n_folds': 5,
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat', 'M2C_BuildMissingQ', 'M2C_FillMissingQ'],
        'cpt_relation_file_name': 'cpt_relation',
    }

    # NOTE(review): `status=DataTPLStatus()` is a mutable default argument shared
    # across calls — confirm DataTPLStatus is stateless, or switch to a None sentinel.
    def __init__(self, cfg, df, df_train=None, df_valid=None, df_test=None, dict_cpt_relation=None, status=DataTPLStatus(), df_stu=None, df_exer=None):
        # optional {cpt_id -> related cpt ids} mapping; may be defaulted in process_data
        self.dict_cpt_relation = dict_cpt_relation
        super().__init__(cfg, df, df_train, df_valid, df_test, df_stu, df_exer, status)

    # NOTE(review): named `_check_param` but it calls `super()._check_params()` —
    # if the framework invokes `_check_params`, this override is never executed; confirm.
    def _check_param(self):
        super()._check_params()
        assert 0 <= self.datatpl_cfg['Q_delete_ratio'] < 1

    @property
    def common_str2df(self):
        """Extend the parent mapping with 'dict_cpt_relation'."""
        dic = super().common_str2df
        dic['dict_cpt_relation'] = self.dict_cpt_relation
        return dic


    def process_data(self):
        """Run the pipeline, then build one signed interaction matrix per training fold."""
        super().process_data()
        dt_info = self.final_kwargs['dt_info']
        user_count = dt_info['stu_count']
        item_count = dt_info['exer_count']
        self.interact_mat_list = []
        for interact_df in self.final_kwargs['df_train_folds']:
            # +1 for correct answers, -1 for incorrect, 0 for unobserved pairs
            interact_mat = torch.zeros((user_count, item_count), dtype=torch.int8)
            idx = interact_df[interact_df['label:float'] == 1][['stu_id:token','exer_id:token']].to_numpy()
            interact_mat[idx[:,0], idx[:,1]] = 1
            idx = interact_df[interact_df['label:float'] != 1][['stu_id:token','exer_id:token']].to_numpy()
            interact_mat[idx[:,0], idx[:,1]] = -1
            self.interact_mat_list.append(interact_mat)

        self.final_kwargs['interact_mat_list'] = self.interact_mat_list

        # fall back to the identity relation when no cpt_relation file was loaded
        if self.final_kwargs['dict_cpt_relation'] is None:
            self.final_kwargs['dict_cpt_relation'] = {i: [i] for i in range(self.final_kwargs['dt_info']['cpt_count'])}

    @classmethod
    def load_data(cls, cfg):
        """Load middata plus the optional {cpt_relation_file_name}.json mapping."""
        kwargs = super().load_data(cfg)
        fph = f"{cfg.frame_cfg.data_folder_path}/middata/{cfg.datatpl_cfg['cpt_relation_file_name']}.json"
        if os.path.exists(fph):
            with open(fph, 'r', encoding='utf-8') as f:
                kwargs['dict_cpt_relation'] = json.load(f)
        else:
            cfg.logger.warning("without cpt_relation.json")
            kwargs['dict_cpt_relation'] = None
        return kwargs

    def get_extra_data(self):
        """Expose the current fold's filled Q matrix and interaction matrix to the model."""
        extra_dict = super().get_extra_data()
        # both attributes are selected per fold by set_info_for_fold below
        extra_dict['filling_Q_mat'] = self.filling_Q_mat
        extra_dict['interact_mat'] = self.interact_mat
        return extra_dict

    def set_info_for_fold(self, fold_id):
        """Select the per-fold matrices before training fold `fold_id`."""
        super().set_info_for_fold(fold_id)
        self.filling_Q_mat = self.final_kwargs['filling_Q_mat_list'][fold_id]
        self.interact_mat = self.final_kwargs['interact_mat_list'][fold_id]
71 |
--------------------------------------------------------------------------------
/edustudio/datatpl/CD/ECDDataTPL.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | from ..common import EduDataTPL
6 | import pandas as pd
7 |
8 |
class ECDDataTPL(EduDataTPL):
    """Data template for ECD.

    Besides the standard CD pipeline it loads per-student questionnaire
    (QQQ) sequences and their group partition from middata.
    """
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD',
                             'M2C_GenQMat'],
    }

    def process_load_data_from_middata(self):
        """Extend the pipeline output with QQQ groups, sequences, and count."""
        super().process_load_data_from_middata()
        self.final_kwargs['qqq_group_list'] = self.read_QQQ_group(self.cfg)
        # assumes every student row carries an equal-length qqq_seq — TODO confirm
        qqq_list = self.df_stu['qqq_seq:token_seq'].to_list()
        self.QQQ_list = torch.tensor(qqq_list)
        self.final_kwargs['qqq_list'] = self.QQQ_list
        # NOTE(review): stored as a 0-dim tensor (torch.max(...)+1), not a
        # plain int — kept as-is to preserve downstream behavior.
        qqq_count = torch.max(self.QQQ_list) + 1
        self.cfg['datatpl_cfg']['dt_info']['qqq_count'] = qqq_count

    def get_extra_data(self):
        """Expose QQQ groups, the Q matrix, and QQQ sequences to the model."""
        return {
            'qqq_group_list': self.final_kwargs['qqq_group_list'],
            'Q_mat': self.final_kwargs['Q_mat'],
            'qqq_list': self.final_kwargs['qqq_list']
        }

    def read_QQQ_group(self, cfg):
        """Load ``<dataset>_QQQ-group.csv`` and return one list of qqq ids per group."""
        group_path = f'{cfg.frame_cfg.data_folder_path}/middata/{cfg.dataset}_QQQ-group.csv'
        assert os.path.exists(group_path)
        df_QQQ_group = pd.read_csv(group_path, encoding='utf-8', usecols=['qqq_id:token', 'group_id:token'])
        gp = df_QQQ_group.groupby("group_id:token")['qqq_id:token']
        # idiom: build the per-group index lists with a comprehension
        return [list(v) for v in gp.groups.values()]
41 |
--------------------------------------------------------------------------------
/edustudio/datatpl/CD/FAIRDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 |
class FAIRDataTPL(EduDataTPL):
    """Data template for fairness-oriented CD.

    Standard CD pipeline with an extra attribute-based record-filtering step.
    """

    default_cfg = {
        'mid2cache_op_seq': [
            'M2C_Label2Int',
            'M2C_FilteringRecordsByAttr',
            'M2C_FilterRecords4CD',
            'M2C_ReMapId',
            'M2C_RandomDataSplit4CD',
            'M2C_GenQMat',
        ],
    }
7 |
8 |
--------------------------------------------------------------------------------
/edustudio/datatpl/CD/HierCDFDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 | import torch
3 | from typing import Dict
4 | import pandas as pd
5 | import os
6 |
class HierCDFDataTPL(EduDataTPL):
    """Data template for HierCDF.

    Extends EduDataTPL with an optional prerequisite relation between
    knowledge concepts, loaded from ``<dataset>.cpt_relation.prerequisite.csv``.
    """
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat'],
        'M2C_ReMapId': {
            # remap concept ids consistently across these three columns
            'share_id_columns': [{'cpt_seq:token_seq', 'cpt_head:token', 'cpt_tail:token'}],
        }
    }

    def __init__(self, cfg, df_cpt_relation=None, **kwargs):
        # df_cpt_relation: optional DataFrame holding prerequisite edges with
        # columns 'cpt_head:token' / 'cpt_tail:token'
        self.df_cpt_relation = df_cpt_relation
        super().__init__(cfg, **kwargs)

    @property
    def common_str2df(self):
        """Mapping of argument name -> DataFrame/object handed to atomic ops."""
        return {
            "df": self.df, "df_train": self.df_train, "df_valid": self.df_valid,
            "df_test": self.df_test, "dt_info": self.datatpl_cfg['dt_info'],
            "df_stu": self.df_stu, "df_exer": self.df_exer,
            "df_cpt_relation": self.df_cpt_relation
        }

    @classmethod
    def load_data(cls, cfg): # only invoked when middata exists
        """Load standard middata plus the optional concept-relation file."""
        kwargs = super().load_data(cfg)
        new_kwargs = cls.load_data_from_cpt_relation(cfg)
        for df in new_kwargs.values():
            if df is not None:
                cls._preprocess_feat(df) # type conversion
        kwargs.update(new_kwargs)
        return kwargs

    def get_extra_data(self):
        """Expose the prerequisite edges (unwrapped) to the model side."""
        extra_dict = super().get_extra_data()
        extra_dict['cpt_relation_edges'] = self._unwrap_feat(self.final_kwargs['df_cpt_relation'])
        return extra_dict

    @classmethod
    def load_data_from_cpt_relation(cls, cfg):
        """Read the prerequisite CSV if present.

        Returns:
            dict: ``{"df_cpt_relation": df-or-None}``.
        """
        file_path = f'{cfg.frame_cfg.data_folder_path}/middata/{cfg.dataset}.cpt_relation.prerequisite.csv'
        df_cpt_relation = None
        if os.path.exists(file_path):
            sep = cfg.datatpl_cfg['seperator']
            df_cpt_relation = pd.read_csv(file_path, sep=sep, encoding='utf-8', usecols=['cpt_head:token', 'cpt_tail:token'])
        return {"df_cpt_relation": df_cpt_relation}
51 |
--------------------------------------------------------------------------------
/edustudio/datatpl/CD/MGCDDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 |
3 |
class MGCDDataTPL(EduDataTPL):
    """Data template for MGCD (group-level cognitive diagnosis)."""
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_MGCD_OP', 'M2C_ReMapId', 'M2C_RandomDataSplit4CD', 'M2C_GenQMat'],
    }

    def get_extra_data(self):
        """Expose per-student interactions and the group table alongside base extras."""
        extra = super().get_extra_data()
        extra['inter_student'] = self.final_kwargs['df_inter_stu']
        extra['df_G'] = self.final_kwargs['df_stu']
        return extra

    def df2dict(self):
        """Unwrap the extra student frames in addition to the base conversion."""
        super().df2dict()
        for key in ('df_inter_stu', 'df_stu'):
            self._unwrap_feat(self.final_kwargs[key])
21 |
--------------------------------------------------------------------------------
/edustudio/datatpl/CD/RCDDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 | import json
3 | import numpy as np
4 |
5 |
class RCDDataTPL(EduDataTPL):
    """Data template for RCD: exposes the relation graphs ('local_map') built by M2C_RCD_OP."""
    default_cfg = {
        'mid2cache_op_seq': [
            'M2C_Label2Int', 'M2C_FilterRecords4CD', 'M2C_ReMapId',
            'M2C_RandomDataSplit4CD', 'M2C_BuildKCRelation',
            'M2C_GenQMat', 'M2C_RCD_OP'
        ],
        'M2C_BuildKCRelation': {
            'relation_type': 'rcd_transition',
            'threshold': None
        }
    }

    def get_extra_data(self):
        """Hand the fold-specific relation graphs to the model side."""
        extra_dict = super().get_extra_data()
        extra_dict['local_map'] = self.local_map
        return extra_dict

    def set_info_for_fold(self, fold_id):
        """Select the per-fold user/exercise graphs.

        NOTE(review): ``self.local_map`` aliases ``final_kwargs['local_map']``,
        so the two assignments below mutate the shared dict in place — the
        previous fold's entries are overwritten on each call.
        """
        super().set_info_for_fold(fold_id)
        self.local_map = self.final_kwargs['local_map']
        self.local_map['u_from_e'] = self.local_map['u_from_e_list'][fold_id]
        self.local_map['e_from_u'] = self.local_map['e_from_u_list'][fold_id]
29 |
--------------------------------------------------------------------------------
/edustudio/datatpl/CD/__init__.py:
--------------------------------------------------------------------------------
1 | from .CDInterDataTPL import CDInterDataTPL
2 | from .CDInterExtendsQDataTPL import CDInterExtendsQDataTPL
3 | from .IRRDataTPL import IRRDataTPL
4 | from .HierCDFDataTPL import HierCDFDataTPL
5 | from .CNCDFDataTPL import CNCDFDataTPL
6 | from .MGCDDataTPL import MGCDDataTPL
7 | from .CNCDQDataTPL import CNCDQDataTPL
8 | from .RCDDataTPL import RCDDataTPL
9 | from .CDGKDataTPL import CDGKDataTPL
10 | from .ECDDataTPL import ECDDataTPL
11 | from .DCDDataTPL import DCDDataTPL
12 | from .FAIRDataTPL import FAIRDataTPL
13 |
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/CL4KTDataTPL.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from ..common import EduDataTPL
4 |
class CL4KTDataTPL(EduDataTPL):
    """Data template for CL4KT (contrastive learning for knowledge tracing).

    Note: a former ``__getitem__`` override only delegated to the parent and
    returned its result unchanged; it was removed as a no-op, so inherited
    item access behaves identically.
    """
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_GenUnFoldKCSeq', 'M2C_CL4KT_OP', 'M2C_RandomDataSplit4KT'],
        'M2C_CL4KT_OP': {
            # keep the most recent interactions when truncating long sequences
            'sequence_truncation': 'recent',
        }
    }
17 |
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/DIMKTDataTPL.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .KTInterExtendsQDataTPL import KTInterExtendsQDataTPL
4 | import torch
5 |
6 |
class DIMKTDataTPL(KTInterExtendsQDataTPL):
    """Data template for DIMKT: attaches question- and concept-difficulty sequences."""
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_GenUnFoldKCSeq', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT', 'M2C_GenKCSeq', "M2C_DIMKT_OP"],
        'M2C_BuildSeqInterFeats': {
            # 'window_size': 200,
            "extra_inter_feats": ['start_timestamp:float', 'cpt_unfold:token']
        }
    }

    def __getitem__(self, index):
        """Extend the base sample with difficulty sequences (qd/cd), zeroed at padding.

        NOTE(review): the ``[0]`` indexing assumes each entry of ``q_dif``/``c_dif``
        is a 1-element container — TODO confirm against M2C_DIMKT_OP's output.
        """
        dic = super().__getitem__(index)
        dic['qd_seq'] = np.stack(
            [self.q_dif[exer_seq][0] for exer_seq in dic['exer_seq']], axis=0
        )
        dic['cd_seq'] = np.stack(
            [self.c_dif[cpt_seq][0] for cpt_seq in dic['cpt_unfold_seq']], axis=0
        )
        # NOTE(review): np.squeeze removes *all* length-1 axes; for a length-1
        # sequence this would also drop the sequence axis — verify padding upstream.
        dic['cd_seq'] = np.squeeze(dic['cd_seq'])
        mask = dic['mask_seq']==0
        # zero out difficulty values at padded positions
        dic['qd_seq'][mask]=0
        dic['cd_seq'][mask] = 0
        return dic

    def set_info_for_fold(self, fold_id):
        """Select the fold-specific difficulty tables built by M2C_DIMKT_OP."""
        super().set_info_for_fold(fold_id)
        self.q_dif = self.final_kwargs['q_diff_list'][fold_id]
        self.c_dif = self.final_kwargs['c_diff_list'][fold_id]
34 |
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/DKTDSCDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 | import torch
3 | import numpy as np
4 | import pandas as pd
5 |
6 |
class DKTDSCDataTPL(EduDataTPL):
    """Data template for DKT-DSC: attaches a per-segment cluster-id sequence."""
    default_cfg = {
        'mid2cache_op_seq': ["M2C_KCAsExer", 'M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats','M2C_RandomDataSplit4KT', "M2C_DKTDSC_OP"],
    }

    def __getitem__(self, index):
        """Extend the base sample with the cluster id of each (student, segment) pair.

        Pairs missing from the fold's cluster mapping fall back to cluster 0.
        """
        dic = super().__getitem__(index)
        step = len(dic['seg_seq'])
        stu_id = dic['stu_id'].repeat(step)
        cluster = np.ones_like(dic['exer_seq'])
        for i in range(step):
            # BUGFIX: the original wrapped the lookup in a bare `except:` that
            # silently swallowed *any* error; the only expected failure is a
            # missing key (dict.get -> None), handled explicitly here.
            cid = self.cluster.get((int(stu_id[i]), int(dic['seg_seq'][i])))
            cluster[i] = 0 if cid is None else cid
        dic['cluster'] = torch.from_numpy(cluster)
        return dic

    def set_info_for_fold(self, fold_id):
        """Select the fold-specific cluster mapping built by M2C_DKTDSC_OP."""
        super().set_info_for_fold(fold_id)
        self.cluster = self.final_kwargs['cluster_list'][fold_id]
34 |
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/DKTForgetDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 |
3 |
class DKTForgetDataTPL(EduDataTPL):
    """Data template for DKT-Forget."""
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats','M2C_RandomDataSplit4KT', "M2C_DKTForget_OP"],
        'M2C_BuildSeqInterFeats': {
            "extra_inter_feats": ['start_timestamp:float']
        }
    }

    def set_info_for_fold(self, fold_id):
        """Select the fold-specific gap/count statistics computed by M2C_DKTForget_OP."""
        dt_info = self.datatpl_cfg['dt_info']
        for key in ('n_pcount', 'n_rgap', 'n_sgap'):
            dt_info[key] = dt_info[f'{key}_list'][fold_id]
17 |
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/EERNNDataTPL.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | from ..common import EduDataTPL
4 |
5 |
class EERNNDataTPL(EduDataTPL):
    """Data template for EERNN: adds word-embedding and exercise-content data."""
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats','M2C_RandomDataSplit4KT', 'M2C_EERNN_OP'],
    }

    def get_extra_data(self, **kwargs):
        """Expose the fold's word-embedding matrix and exercise contents."""
        extra = super().get_extra_data(**kwargs)
        extra['w2v_word_emb'] = self.w2v_word_emb
        extra['exer_content'] = self.content_mat
        return extra

    def set_info_for_fold(self, fold_id):
        """Select the fold's embedding dict, stack it into a matrix, pick its content matrix."""
        dt_info = self.datatpl_cfg['dt_info']
        emb_dict = self.word_emb_dict_list[fold_id]
        dt_info['word_count'] = len(emb_dict)
        # rows ordered by word id 0..word_count-1
        self.w2v_word_emb = np.vstack([emb_dict[w] for w in range(len(emb_dict))])
        self.content_mat = self.content_mat_list[fold_id]

    def save_cache(self):
        """Persist the EERNN-specific artifacts next to the base cache."""
        super().save_cache()
        self.save_pickle(f"{self.cache_folder_path}/word_emb_dict_list.pkl", self.word_emb_dict_list)
        self.save_pickle(f"{self.cache_folder_path}/content_mat_list.pkl", self.content_mat_list)

    def load_cache(self):
        """Restore the EERNN-specific artifacts saved by save_cache()."""
        super().load_cache()
        self.word_emb_dict_list = self.load_pickle(f"{self.cache_folder_path}/word_emb_dict_list.pkl")
        self.content_mat_list = self.load_pickle(f"{self.cache_folder_path}/content_mat_list.pkl")
40 |
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/EKTDataTPL.py:
--------------------------------------------------------------------------------
1 |
2 | from .EERNNDataTPL import EERNNDataTPL
3 | import numpy as np
4 |
5 |
class EKTDataTPL(EERNNDataTPL):
    """Data template for EKT: EERNN data plus per-exercise KC sequences."""
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT', 'M2C_GenKCSeq', 'M2C_EERNN_OP'],
    }

    def __getitem__(self, index):
        """Attach padded KC sequences and their masks for every exercise in the sample."""
        dic = super().__getitem__(index)
        exer_seq = dic['exer_seq']
        dic['cpt_seq'] = np.stack([self.cpt_seq_padding[e] for e in exer_seq], axis=0)
        dic['cpt_seq_mask'] = np.stack([self.cpt_seq_mask[e] for e in exer_seq], axis=0)
        return dic
20 |
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/KTInterCptAsExerDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 |
class KTInterCptAsExerDataTPL(EduDataTPL):
    """KT interaction template that first treats each knowledge concept as an exercise."""

    default_cfg = {
        'mid2cache_op_seq': [
            "M2C_KCAsExer",
            'M2C_Label2Int',
            'M2C_ReMapId',
            'M2C_BuildSeqInterFeats',
            'M2C_RandomDataSplit4KT',
        ],
    }
7 |
8 |
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/KTInterCptUnfoldDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 | import numpy as np
3 |
4 |
class KTInterCptUnfoldDataTPL(EduDataTPL):
    """KT interaction template that unfolds multi-concept records into single-concept ones."""

    default_cfg = {
        'mid2cache_op_seq': [
            'M2C_Label2Int',
            'M2C_ReMapId',
            'M2C_GenUnFoldKCSeq',
            'M2C_BuildSeqInterFeats',
            'M2C_RandomDataSplit4KT',
        ],
        'M2C_BuildSeqInterFeats': {
            "extra_inter_feats": ['start_timestamp:float', 'cpt_unfold:token']
        },
    }
12 |
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/KTInterDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import GeneralDataTPL
2 |
class KTInterDataTPL(GeneralDataTPL):
    """Minimal knowledge-tracing interaction template (no Q-matrix side data)."""

    default_cfg = {
        'mid2cache_op_seq': [
            'M2C_Label2Int',
            'M2C_ReMapId',
            'M2C_BuildSeqInterFeats',
            'M2C_RandomDataSplit4KT',
        ],
    }
7 |
8 |
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/KTInterExtendsQDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 | import numpy as np
3 |
4 |
class KTInterExtendsQDataTPL(EduDataTPL):
    """KT interaction template that attaches padded KC sequences per exercise."""
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT', 'M2C_GenKCSeq'],
    }

    def __getitem__(self, index):
        """Attach the padded KC sequence and mask for every exercise in the sample."""
        dic = super().__getitem__(index)
        exer_seq = dic['exer_seq']
        cpt_rows = [self.cpt_seq_padding[e] for e in exer_seq]
        mask_rows = [self.cpt_seq_mask[e] for e in exer_seq]
        dic['cpt_seq'] = np.stack(cpt_rows, axis=0)
        dic['cpt_seq_mask'] = np.stack(mask_rows, axis=0)
        return dic
19 |
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/LPKTDataTPL.py:
--------------------------------------------------------------------------------
1 | from ..common import EduDataTPL
2 |
3 |
class LPKTDataTPL(EduDataTPL):
    """Data template for LPKT (uses answer/interval time features)."""
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT', 'M2C_LPKT_OP', "M2C_GenQMat"],
        'M2C_BuildSeqInterFeats': {
            "extra_inter_feats": ['start_timestamp:float', 'answer_time:float']
        }
    }

    def set_info_for_fold(self, fold_id):
        """Select the fold-specific time-discretization counts computed by M2C_LPKT_OP."""
        dt_info = self.datatpl_cfg['dt_info']
        for key in ('answer_time_count', 'interval_time_count'):
            dt_info[key] = dt_info[f'{key}_list'][fold_id]
16 |
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/QDKTDataTPL.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import numpy as np
3 |
4 | from .KTInterExtendsQDataTPL import KTInterExtendsQDataTPL
5 | import torch
6 |
7 |
class QDKTDataTPL(KTInterExtendsQDataTPL):
    """Data template for QDKT: exposes the Laplacian matrix and the fold's train dict."""
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT', 'M2C_GenKCSeq','M2C_GenQMat','M2C_QDKT_OP'],
    }

    def get_extra_data(self, **kwargs):
        # NOTE(review): unlike sibling templates this does NOT merge
        # super().get_extra_data() — presumably intentional; confirm the QDKT
        # model does not expect the base extras here.
        return {
            'laplacian_matrix': self.final_kwargs['laplacian_matrix'],
            'train_dict': self.train_dict
        }

    def set_info_for_fold(self, fold_id):
        """Select the training dict of the current fold."""
        super().set_info_for_fold(fold_id)
        self.train_dict = self.dict_train_folds[fold_id]
22 |
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/RKTDataTPL.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import numpy as np
3 |
4 | from scipy import sparse
5 | from ..common import EduDataTPL
6 |
7 |
class RKTDataTPL(EduDataTPL):
    """Data template for RKT: builds a dense exercise-exercise relation matrix."""
    default_cfg = {
        'mid2cache_op_seq': ['M2C_Label2Int', 'M2C_ReMapId','M2C_GenQMat', 'M2C_BuildSeqInterFeats', 'M2C_RandomDataSplit4KT'],
        'M2C_BuildSeqInterFeats': {
            "extra_inter_feats": ['start_timestamp:float']
        }
    }

    def process_load_data_from_middata(self):
        """Run the base pipeline, then attach the exercise relation matrix."""
        super().process_load_data_from_middata()
        self.final_kwargs['pro_pro_dense'] = self.get_pro_pro_corr()

    def get_pro_pro_corr(self):
        """Build a dense (exer_count, exer_count) matrix from shared concepts.

        Two exercises are related when they share at least one knowledge
        concept; the return value is ``1 - A`` where ``A[i, j] == 1`` for
        related pairs (including i == j) and 0 otherwise.
        """
        # reference: https://github.com/shalini1194/RKT/issues/2
        pro_cpt_adj = []
        pro_num = self.cfg['datatpl_cfg']['dt_info']['exer_count']
        cpt_num = self.cfg['datatpl_cfg']['dt_info']['cpt_count']
        # collect (exercise, concept, 1) triples from the exercise table
        for index in range(len(self.df_exer)):
            tmp_df = self.df_exer.iloc[index]
            exer_id = tmp_df['exer_id:token']
            cpt_seq = tmp_df['cpt_seq:token_seq']
            for cpt in cpt_seq:
                pro_cpt_adj.append([exer_id, cpt, 1])
        pro_cpt_adj = np.array(pro_cpt_adj).astype(np.int32)
        pro_cpt_sparse = sparse.coo_matrix((pro_cpt_adj[:, 2].astype(np.float32),
                                            (pro_cpt_adj[:, 0], pro_cpt_adj[:, 1])), shape=(pro_num, cpt_num))
        # CSC gives fast column (concept) slicing, CSR fast row (exercise) slicing
        pro_cpt_csc = pro_cpt_sparse.tocsc()
        pro_cpt_csr = pro_cpt_sparse.tocsr()
        pro_pro_adj = []
        for p in range(pro_num):
            # concepts of exercise p, then all exercises touching those concepts
            tmp_skills = pro_cpt_csr.getrow(p).indices
            similar_pros = pro_cpt_csc[:, tmp_skills].indices
            zipped = zip([p] * similar_pros.shape[0], similar_pros)
            pro_pro_adj += list(zipped)

        # dedupe exercise pairs before building the sparse relation matrix
        pro_pro_adj = list(set(pro_pro_adj))
        pro_pro_adj = np.array(pro_pro_adj).astype(np.int32)
        data = np.ones(pro_pro_adj.shape[0]).astype(np.float32)
        pro_pro_sparse = sparse.coo_matrix((data, (pro_pro_adj[:, 0], pro_pro_adj[:, 1])), shape=(pro_num, pro_num))
        return 1-pro_pro_sparse.tocoo().toarray()

    def get_extra_data(self):
        # NOTE(review): does not merge super().get_extra_data() — presumably
        # intentional; verify the RKT model needs no base extras.
        return {
            "pro_pro_dense": self.final_kwargs['pro_pro_dense']
        }

    def set_info_for_fold(self, fold_id):
        """Select the training dict of the current fold.

        NOTE(review): ``self.train_dict`` is set but not used elsewhere in this
        class — confirm it is consumed by the training template.
        """
        super().set_info_for_fold(fold_id)
        self.train_dict = self.dict_train_folds[fold_id]
--------------------------------------------------------------------------------
/edustudio/datatpl/KT/__init__.py:
--------------------------------------------------------------------------------
1 | from .KTInterDataTPL import KTInterDataTPL
2 | from .KTInterExtendsQDataTPL import KTInterExtendsQDataTPL
3 | from .KTInterCptAsExerDataTPL import KTInterCptAsExerDataTPL
4 | from .EERNNDataTPL import EERNNDataTPL
5 | from .LPKTDataTPL import LPKTDataTPL
6 | from .DKTForgetDataTPL import DKTForgetDataTPL
7 | from .KTInterCptUnfoldDataTPL import KTInterCptUnfoldDataTPL
8 | from .DKTDSCDataTPL import DKTDSCDataTPL
9 | from .DIMKTDataTPL import DIMKTDataTPL
10 | from .QDKTDataTPL import QDKTDataTPL
11 | from .RKTDataTPL import RKTDataTPL
12 | from .CL4KTDataTPL import CL4KTDataTPL
13 | from .EKTDataTPL import EKTDataTPL
14 | from .GKTDataTPL import GKTDataTPL
15 |
--------------------------------------------------------------------------------
/edustudio/datatpl/__init__.py:
--------------------------------------------------------------------------------
1 | from .CD import *
2 | from .KT import *
3 | from .common import *
4 |
5 |
--------------------------------------------------------------------------------
/edustudio/datatpl/common/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_datatpl import BaseDataTPL
2 | from .edu_datatpl import EduDataTPL
3 | from .general_datatpl import GeneralDataTPL
4 | from .proxy_datatpl import BaseProxyDataTPL
5 |
--------------------------------------------------------------------------------
/edustudio/datatpl/common/proxy_datatpl.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from .base_datatpl import BaseDataTPL
3 | from edustudio.utils.common import UnifyConfig
4 |
5 |
class BaseProxyDataTPL(object):
    """The basic protocol for implementing a proxy data template"""
    default_cfg = {'backbone_datatpl_cls': 'BaseDataTPL'}

    @classmethod
    def from_cfg_proxy(cls, cfg):
        """an interface to instantiate a proxy model

        Args:
            cfg (UnifyConfig): the global config object

        Returns:
            BaseProxyDataTPL
        """
        backbone_datatpl_cls = cls.get_backbone_cls(cfg.datatpl_cfg.backbone_datatpl_cls)
        new_cls = cls.get_new_cls(p_cls=backbone_datatpl_cls)
        return new_cls.from_cfg(cfg)

    @classmethod
    def get_backbone_cls(cls, backbone_datatpl_cls):
        """Resolve the backbone data-template class from a name or class object.

        Args:
            backbone_datatpl_cls (str | type): class name exported by
                ``edustudio.datatpl``, or a ``BaseDataTPL`` subclass.

        Raises:
            ValueError: when the argument is neither of the above.
                (BUGFIX: a non-class argument previously leaked a TypeError
                out of ``issubclass`` instead of this ValueError.)
        """
        if isinstance(backbone_datatpl_cls, str):
            backbone_datatpl_cls = importlib.import_module('edustudio.datatpl').\
                __getattribute__(backbone_datatpl_cls)
        elif isinstance(backbone_datatpl_cls, type) and issubclass(backbone_datatpl_cls, BaseDataTPL):
            pass  # already a usable class (removed the redundant self-assignment)
        else:
            raise ValueError(f"Unknown type of backbone_datatpl_cls: {backbone_datatpl_cls}")
        return backbone_datatpl_cls

    @classmethod
    def get_new_cls(cls, p_cls):
        """dynamic inheritance

        Args:
            p_cls (BaseModel): parent class

        Returns:
            BaseProxyModel: A inherited class
        """
        new_cls = type(cls.__name__ + "_proxy", (cls, p_cls), {})
        return new_cls

    @classmethod
    def get_default_cfg(cls, backbone_datatpl_cls, **kwargs):
        """Get the final default_cfg

        Args:
            backbone_datatpl_cls (BaseDataTPL): backbone data template class name

        Returns:
            UnifyConfig: the final default config object
        """
        bb_cls = None
        if backbone_datatpl_cls is not None:
            bb_cls = cls.get_backbone_cls(backbone_datatpl_cls)
        else:
            # walk the MRO until a proxy ancestor declares a backbone class
            for _cls in cls.__mro__:
                if not hasattr(_cls, 'default_cfg'):
                    break
                bb_cls = _cls.default_cfg.get('backbone_datatpl_cls', None)
                if bb_cls is not None: break
            assert bb_cls is not None
            bb_cls = cls.get_backbone_cls(bb_cls)

        cfg = UnifyConfig()
        cfg.backbone_datatpl_cls = bb_cls
        cfg.backbone_datatpl_cls_name = bb_cls.__name__
        new_cls = cls.get_new_cls(p_cls=bb_cls)
        # merge proxy-level default_cfg entries, then defer to the backbone's
        # own get_default_cfg for everything below the proxy layer
        for _cls in new_cls.__mro__:
            if not hasattr(_cls, 'default_cfg'):
                break
            if issubclass(_cls, BaseProxyDataTPL):
                cfg.update(_cls.default_cfg, update_unknown_key_only=True)
            else:
                cfg.update(_cls.get_default_cfg(**kwargs), update_unknown_key_only=True)
                break
        return cfg
84 |
--------------------------------------------------------------------------------
/edustudio/datatpl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .pad_seq_util import PadSeqUtil
2 | from .common import BigfileDownloader, DecompressionUtil
3 | from .spliter_util import SpliterUtil
4 |
5 |
--------------------------------------------------------------------------------
/edustudio/datatpl/utils/common.py:
--------------------------------------------------------------------------------
1 | from contextlib import closing
2 | from tqdm import tqdm
3 | import requests
4 | import zipfile
5 |
6 |
class BigfileDownloader(object):
    """Utility for streaming large downloads to disk with a progress bar."""

    @staticmethod
    def download(url, title, filepath, chunk_size=10240):
        """Stream ``url`` into ``filepath`` in ``chunk_size`` pieces.

        Args:
            url (str): source URL (redirects are followed).
            title (str): label shown on the tqdm progress bar.
            filepath (str): destination file path (overwritten).
            chunk_size (int): bytes per read chunk.

        Raises:
            Exception: when the HTTP status code is not 200.
        """
        with closing(requests.get(url, stream=True, allow_redirects=True)) as resp:
            if resp.status_code != 200:
                raise Exception("[ERROR]: {} - {} -{}".format(str(resp.status_code), title, url))
            # BUGFIX: 'content-length' may be absent (e.g. chunked transfer);
            # previously this raised KeyError. tqdm accepts total=None.
            content_length = resp.headers.get('content-length')
            total = int(content_length) if content_length is not None else None
            with tqdm(total=total, desc=title, ncols=100) as pbar:
                with open(filepath, 'wb') as f:
                    for data in resp.iter_content(chunk_size=chunk_size):
                        f.write(data)
                        pbar.update(len(data))
20 |
21 |
class DecompressionUtil(object):
    """Utility for extracting zip archives."""

    @staticmethod
    def unzip_file(zip_src, dst_dir):
        """Extract every member of ``zip_src`` into ``dst_dir``.

        Raises:
            Exception: when ``zip_src`` is not a valid zip archive.
        """
        if not zipfile.is_zipfile(zip_src):
            raise Exception(f'{zip_src} is not a zip file')
        # BUGFIX: use a context manager so the archive handle is closed
        # (the original ZipFile was never closed)
        with zipfile.ZipFile(zip_src, 'r') as fz:
            for file in tqdm(fz.namelist(), desc='unzip...', ncols=100):
                fz.extract(file, dst_dir)
32 |
--------------------------------------------------------------------------------
/edustudio/evaltpl/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_evaltpl import BaseEvalTPL
2 | from .prediction_evaltpl import PredictionEvalTPL
3 | from .interpretability_evaltpl import InterpretabilityEvalTPL
4 | from .fairness_evaltpl import FairnessEvalTPL
5 | from .identifiability_evaltpl import IdentifiabilityEvalTPL
6 |
7 |
--------------------------------------------------------------------------------
/edustudio/evaltpl/base_evaltpl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | import logging
6 | from edustudio.utils.common import UnifyConfig
7 | from edustudio.utils.callback import CallbackList
8 |
9 |
class BaseEvalTPL(object):
    """The basic protocol for implementing an evaluate template"""
    default_cfg = {}

    def __init__(self, cfg):
        # keep a handle on each config section for convenient access
        self.cfg: UnifyConfig = cfg
        self.datatpl_cfg: UnifyConfig = cfg.datatpl_cfg
        self.evaltpl_cfg: UnifyConfig = cfg.evaltpl_cfg
        self.traintpl_cfg: UnifyConfig = cfg.traintpl_cfg
        self.frame_cfg: UnifyConfig = cfg.frame_cfg
        self.modeltpl_cfg: UnifyConfig = cfg.modeltpl_cfg
        self.logger: logging.Logger = logging.getLogger("edustudio")
        self.name = self.__class__.__name__
        self._check_params()

    @classmethod
    def get_default_cfg(cls):
        """Merge default_cfg up the class hierarchy (subclass keys take precedence)."""
        parent_class = cls.__base__
        cfg = UnifyConfig(cls.default_cfg)
        if hasattr(parent_class, 'get_default_cfg'):
            cfg.update(parent_class.get_default_cfg(), update_unknown_key_only=True)
        return cfg

    def eval(self, **kwargs):
        """Compute metric results; subclasses override."""
        pass

    def _check_params(self):
        """Validate config values; subclasses may override."""
        pass

    def set_callback_list(self, callbacklist: CallbackList):
        # callback list shared with the training template
        self.callback_list = callbacklist

    def set_dataloaders(self, train_loader, test_loader, valid_loader=None):
        # keep references to the three dataloaders (valid_loader may be None)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

    def add_extra_data(self, **kwargs):
        # stash arbitrary extra data supplied by the data template
        self.extra_data = kwargs
50 |
51 |
--------------------------------------------------------------------------------
/edustudio/evaltpl/prediction_evaltpl.py:
--------------------------------------------------------------------------------
1 | from .base_evaltpl import BaseEvalTPL
2 | import numpy as np
3 | from edustudio.utils.common import tensor2npy
4 | from sklearn.metrics import mean_squared_error, roc_auc_score, accuracy_score, f1_score, label_ranking_loss, coverage_error
5 |
6 |
class PredictionEvalTPL(BaseEvalTPL):
    """Student Performance Prediction Evaluation.

    Computes the metrics listed in ``use_metrics`` over predicted scores
    (``y_pd``) against ground truth (``y_gt``).

    Note: a redundant ``__init__`` that only called ``super().__init__(cfg)``
    was removed; construction behavior is unchanged.
    """
    default_cfg = {
        'use_metrics': ['auc', 'acc', 'rmse']
    }

    def eval(self, y_pd, y_gt, **kwargs):
        """Evaluate predictions.

        Args:
            y_pd: predicted scores (ndarray or tensor; tensors are converted).
            y_gt: ground-truth labels (ndarray or tensor).
            **kwargs: optionally ``ignore_metrics``, a collection of metric
                names to skip.

        Returns:
            dict: metric name -> value.
        """
        if not isinstance(y_pd, np.ndarray): y_pd = tensor2npy(y_pd)
        if not isinstance(y_gt, np.ndarray): y_gt = tensor2npy(y_gt)
        metric_result = {}
        ignore_metrics = kwargs.get('ignore_metrics', {})
        for metric_name in self.evaltpl_cfg[self.__class__.__name__]['use_metrics']:
            if metric_name not in ignore_metrics:
                metric_result[metric_name] = self._get_metrics(metric_name)(y_gt, y_pd)
        return metric_result

    def _get_metrics(self, metric):
        """Return a callable ``(y_gt, y_pd) -> float`` for the given metric name.

        Raises:
            NotImplementedError: for unknown metric names.
        """
        # idiom: dispatch table instead of a long if/elif chain
        metrics = {
            'auc': roc_auc_score,
            'mse': mean_squared_error,
            'rmse': lambda y_gt, y_pd: mean_squared_error(y_gt, y_pd) ** 0.5,
            # threshold predictions at 0.5 for accuracy
            'acc': lambda y_gt, y_pd: accuracy_score(y_gt, np.where(y_pd >= 0.5, 1, 0)),
            'f1_macro': lambda y_gt, y_pd: f1_score(y_gt, y_pd, average='macro'),
            'f1_micro': lambda y_gt, y_pd: f1_score(y_gt, y_pd, average='micro'),
            'ranking_loss': label_ranking_loss,
            'coverage_error': coverage_error,
            'samples_auc': lambda y_gt, y_pd: roc_auc_score(y_gt, y_pd, average='samples'),
        }
        if metric not in metrics:
            raise NotImplementedError
        return metrics[metric]
49 |
50 |
--------------------------------------------------------------------------------
/edustudio/model/CD/__init__.py:
--------------------------------------------------------------------------------
1 | from .irt import IRT
2 | from .mf import MF
3 | from .mirt import MIRT
4 | from .ncdm import NCDM
5 | from .dina import DINA, STEDINA
6 | from .kancd import KaNCD
7 | from .kscd import KSCD
8 | from .cncd_q import CNCD_Q
9 | from .irr import IRR
10 | from .hier_cdf import HierCDF
11 | from .rcd import RCD
12 | from .cncd_f import CNCD_F
13 | from .cdgk import CDGK_SINGLE, CDGK_MULTI
14 | from .cdmfkc import CDMFKC
15 | from .ecd import *
16 | from .mgcd import MGCD
17 | from .dcd import DCD
18 | from .faircd import FairCD_IRT, FairCD_MIRT, FairCD_NCDM
19 |
--------------------------------------------------------------------------------
/edustudio/model/CD/irr.py:
--------------------------------------------------------------------------------
1 | r"""
2 | IRR
3 | ##################################
4 | Reference:
5 | Tong et al. "Item Response Ranking for Cognitive Diagnosis." in IJCAI 2021.
6 | """
7 | from ..basemodel import BaseProxyModel
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 |
class PairSCELoss(nn.Module):
    """Pairwise softmax cross-entropy loss used by IRR."""

    def __init__(self):
        """The pairwise loss is realized via standard cross entropy."""
        super().__init__()
        self._loss = nn.CrossEntropyLoss()

    def forward(self, pred1, pred2, sign=1, *args):
        """Compute the pairwise ranking loss.

        Args:
            pred1 (torch.Tensor): positive prediction
            pred2 (torch.Tensor): negative prediction
            sign (int, optional): 1: pred1 should be greater than pred2; -1: otherwise. Defaults to 1.

        Returns:
            torch.Tensor: PairSCELoss
        """
        stacked = torch.stack([pred1, pred2], dim=1)
        # sign=1 -> target class 0 (pred1 should win); sign=-1 -> target class 1.
        target = ((torch.ones(pred1.shape[0], device=stacked.device) - sign) / 2).long()
        return self._loss(stacked, target)
33 |
34 |
35 |
class IRR(BaseProxyModel):
    """Item Response Ranking proxy model.

    backbone_modeltpl_cls: The backbone model of IRR
    """
    default_cfg = {
        "backbone_modeltpl_cls": "IRT",
        'pair_weight': 0.5,
    }

    def __init__(self, cfg):
        """Pass parameters from other templates into the model.

        Args:
            cfg (UnifyConfig): parameters from other templates
        """
        super().__init__(cfg)

    def build_model(self):
        """Build the backbone components, then attach the pairwise loss."""
        super().build_model()
        self.irr_pair_loss = PairSCELoss()

    def get_main_loss(self, **kwargs):
        """Combine the pairwise ranking loss with a pointwise BCE loss.

        Returns:
            dict: {'loss_main': loss}
        """
        # Score the same exercise for the positive and the negative student.
        kwargs['exer_id'] = kwargs['pair_exer']
        kwargs['stu_id'] = kwargs['pair_pos_stu']
        pos_pred = self(**kwargs).flatten()
        kwargs['stu_id'] = kwargs['pair_neg_stu']
        neg_pred = self(**kwargs).flatten()

        ones = torch.ones(pos_pred.shape[0]).to(self.device)
        zeros = torch.zeros(neg_pred.shape[0]).to(self.device)
        point_loss = (
            F.binary_cross_entropy(input=pos_pred, target=ones)
            + F.binary_cross_entropy(input=neg_pred, target=zeros)
        )

        w = self.modeltpl_cfg['pair_weight']
        return {
            'loss_main': w * self.irr_pair_loss(pos_pred, neg_pred) + (1 - w) * point_loss
        }
80 |
--------------------------------------------------------------------------------
/edustudio/model/CD/irt.py:
--------------------------------------------------------------------------------
1 | r"""
2 | IRT
3 | ##########################################
4 |
5 | Reference Code:
6 | https://github.com/bigdata-ustc/EduCDM/tree/main/EduCDM/IRT
7 |
8 | """
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from ..gd_basemodel import GDBaseModel
14 |
15 |
class IRT(GDBaseModel):
    r"""
    Item Response Theory (logistic model with optional discrimination and
    guessing parameters).
    """
    default_cfg = {
        "a_range": -1.0,  # disc range; a non-positive value disables the restriction
        "diff_range": -1.0,  # diff range; a non-positive value disables the restriction
        "fix_a": False,  # if True, discrimination is the constant 0.0 (no per-exercise a)
        "fix_c": True,  # if True, guessing parameter is the constant 0.0 (no per-exercise c)
    }

    def __init__(self, cfg):
        super().__init__(cfg)

    def build_cfg(self):
        """Normalize ranges and read dataset sizes from the data-template config."""
        if self.modeltpl_cfg['a_range'] is not None and self.modeltpl_cfg['a_range'] < 0: self.modeltpl_cfg['a_range'] = None
        if self.modeltpl_cfg['diff_range'] is not None and self.modeltpl_cfg['diff_range'] < 0: self.modeltpl_cfg['diff_range'] = None

        self.n_user = self.datatpl_cfg['dt_info']['stu_count']
        self.n_item = self.datatpl_cfg['dt_info']['exer_count']

        # When c is learnable (not fixed), a must be learnable as well.
        if self.modeltpl_cfg['fix_c'] is False: assert self.modeltpl_cfg['fix_a'] is False

    def build_model(self):
        self.theta = nn.Embedding(self.n_user, 1)  # student ability
        self.a = 0.0 if self.modeltpl_cfg['fix_a'] else nn.Embedding(self.n_item, 1)  # exer discrimination
        self.b = nn.Embedding(self.n_item, 1)  # exer difficulty
        self.c = 0.0 if self.modeltpl_cfg['fix_c'] else nn.Embedding(self.n_item, 1)  # guessing parameter

    def forward(self, stu_id, exer_id, **kwargs):
        """Return predicted correctness probability for each (student, exercise) pair."""
        theta = self.theta(stu_id)
        b = self.b(exer_id)
        c = self.c if self.modeltpl_cfg['fix_c'] else self.c(exer_id).sigmoid()

        if self.modeltpl_cfg['fix_a']:
            # BUGFIX: when fix_a is True, self.a is a plain float (see build_model),
            # so the original unconditional `self.a(exer_id)` raised a TypeError.
            a = self.a
            a_has_nan = False
        else:
            a = self.a(exer_id)
            if self.modeltpl_cfg['a_range'] is not None:
                a = self.modeltpl_cfg['a_range'] * torch.sigmoid(a)
            else:
                a = F.softplus(a)  # keep discrimination positive (monotonicity assumption)
            a_has_nan = bool(torch.max(a != a))

        if self.modeltpl_cfg['diff_range'] is not None:
            b = self.modeltpl_cfg['diff_range'] * (torch.sigmoid(b) - 0.5)
        if torch.max(theta != theta) or a_has_nan or torch.max(b != b):  # pragma: no cover
            raise ValueError('ValueError:theta,a,b may contains nan! The diff_range or a_range is too large.')
        return self.irf(theta, a, b, c)

    @staticmethod
    def irf(theta, a, b, c, D=1.702):
        """Item response function: 3PL logistic curve with scaling constant D."""
        return c + (1 - c) / (1 + torch.exp(-D * a * (theta - b)))

    def get_main_loss(self, **kwargs):
        """Binary cross entropy between predicted and observed correctness."""
        stu_id = kwargs['stu_id']
        exer_id = kwargs['exer_id']
        label = kwargs['label']
        pd = self(stu_id, exer_id).flatten()
        loss = F.binary_cross_entropy(input=pd, target=label)
        return {
            'loss_main': loss
        }

    def get_loss_dict(self, **kwargs):
        return self.get_main_loss(**kwargs)

    @torch.no_grad()
    def predict(self, stu_id, exer_id, **kwargs):
        return {
            'y_pd': self(stu_id, exer_id).flatten(),
        }
84 |
85 |
--------------------------------------------------------------------------------
/edustudio/model/CD/mf.py:
--------------------------------------------------------------------------------
1 | from edustudio.model import GDBaseModel
2 | import torch.nn as nn
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
class MF(GDBaseModel):
    """Matrix Factorization: predicts correctness from the sigmoid of the
    inner product of a student embedding and an exercise embedding."""
    default_cfg = {
        'emb_dim': 32,
        'reg_user': 0.0,  # L2 penalty weight on student embeddings (0 disables)
        'reg_item': 0.0   # L2 penalty weight on exercise embeddings (0 disables)
    }

    def build_cfg(self):
        """Read dataset sizes and hyper-parameters from the config objects."""
        self.n_user = self.datatpl_cfg['dt_info']['stu_count']
        self.n_item = self.datatpl_cfg['dt_info']['exer_count']
        self.emb_size = self.modeltpl_cfg['emb_dim']
        self.reg_user = self.modeltpl_cfg['reg_user']
        self.reg_item = self.modeltpl_cfg['reg_item']

    def build_model(self):
        self.user_emb = nn.Embedding(
            num_embeddings=self.n_user,
            embedding_dim=self.emb_size
        )
        self.item_emb = nn.Embedding(
            num_embeddings=self.n_item,
            embedding_dim=self.emb_size
        )

    def forward(self, user_idx: torch.LongTensor, item_idx: torch.LongTensor):
        """Return sigmoid(<user_emb, item_emb>) for each aligned (user, item) pair."""
        assert len(user_idx.shape) == 1 and len(item_idx.shape) == 1 and user_idx.shape[0] == item_idx.shape[0]
        return torch.einsum("ij,ij->i", self.user_emb(user_idx), self.item_emb(item_idx)).sigmoid()

    @torch.no_grad()
    def predict(self, stu_id, exer_id, **kwargs):
        return {
            'y_pd': self(stu_id, exer_id).flatten(),
        }

    def get_main_loss(self, **kwargs):
        """Binary cross entropy between predicted and observed correctness."""
        stu_id = kwargs['stu_id']
        exer_id = kwargs['exer_id']
        label = kwargs['label']
        pd = self(stu_id, exer_id).flatten()
        loss = F.binary_cross_entropy(input=pd, target=label)
        return {
            'loss_main': loss
        }

    def get_loss_dict(self, **kwargs):
        losses = self.get_main_loss(**kwargs)
        # BUGFIX: reg_user/reg_item were read from the config but never applied.
        # Apply an L2 penalty on the embeddings of the current batch; with the
        # default weights of 0.0 the loss dict is unchanged (backward compatible).
        if self.reg_user > 0 or self.reg_item > 0:
            losses['loss_reg'] = (
                self.reg_user * self.user_emb(kwargs['stu_id']).pow(2).sum()
                + self.reg_item * self.item_emb(kwargs['exer_id']).pow(2).sum()
            )
        return losses
--------------------------------------------------------------------------------
/edustudio/model/CD/mirt.py:
--------------------------------------------------------------------------------
1 | r"""
2 | MIRT
3 | ##########################################
4 |
5 | Reference:
6 | Mark D Reckase et al. "Multidimensional item response theory models". Springer, 2009.
7 |
8 | Reference Code:
9 | https://github.com/bigdata-ustc/EduCDM/tree/main/EduCDM/MIRT
10 |
11 | """
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from ..gd_basemodel import GDBaseModel
17 |
18 |
class MIRT(GDBaseModel):
    """Multidimensional Item Response Theory (MIRT)."""
    default_cfg = {
        "a_range": -1.0,  # disc range; a non-positive value disables the restriction
        "emb_dim": 32
    }

    def __init__(self, cfg):
        super().__init__(cfg)

    def build_cfg(self):
        """Normalize `a_range` and read dataset sizes from the data-template config."""
        a_range = self.modeltpl_cfg['a_range']
        if a_range is not None and a_range < 0:
            self.modeltpl_cfg['a_range'] = None

        self.n_user = self.datatpl_cfg['dt_info']['stu_count']
        self.n_item = self.datatpl_cfg['dt_info']['exer_count']
        self.emb_dim = self.modeltpl_cfg['emb_dim']

    def build_model(self):
        self.theta = nn.Embedding(self.n_user, self.emb_dim)  # student ability
        self.a = nn.Embedding(self.n_item, self.emb_dim)      # exer discrimination
        self.b = nn.Embedding(self.n_item, 1)                 # exer intercept term

    def forward(self, stu_id, exer_id, **kwargs):
        """Return predicted correctness probability for each (student, exercise) pair."""
        ability = self.theta(stu_id)
        disc = self.a(exer_id)
        intercept = self.b(exer_id).flatten()

        a_range = self.modeltpl_cfg['a_range']
        if a_range is None:
            disc = F.softplus(disc)  # keep discrimination positive (monotonicity assumption)
        else:
            disc = a_range * torch.sigmoid(disc)
        # `x != x` is True only for NaN entries.
        if torch.max(ability != ability) or torch.max(disc != disc) or torch.max(intercept != intercept):  # pragma: no cover
            raise ValueError('ValueError:theta,a,b may contains nan! The diff_range or a_range is too large.')
        return self.irf(ability, disc, intercept)

    @staticmethod
    def irf(theta, a, b):
        """Item response function: sigmoid of the (a . theta) score minus b."""
        return 1 / (1 + torch.exp(b - torch.sum(torch.multiply(a, theta), axis=-1)))

    def get_main_loss(self, **kwargs):
        """Binary cross entropy between predicted and observed correctness."""
        pred = self(kwargs['stu_id'], kwargs['exer_id']).flatten()
        bce = F.binary_cross_entropy(input=pred, target=kwargs['label'])
        return {
            'loss_main': bce
        }

    def get_loss_dict(self, **kwargs):
        return self.get_main_loss(**kwargs)

    @torch.no_grad()
    def predict(self, stu_id, exer_id, **kwargs):
        return {
            'y_pd': self(stu_id, exer_id).flatten(),
        }
79 |
--------------------------------------------------------------------------------
/edustudio/model/KT/GKT/__init__.py:
--------------------------------------------------------------------------------
1 | from .gkt import GKT
--------------------------------------------------------------------------------
/edustudio/model/KT/__init__.py:
--------------------------------------------------------------------------------
1 | from .dkt import DKT
2 | from .dkvmn import DKVMN
3 | from .dkt_plus import DKT_plus
4 | from .iekt import IEKT
5 | from .dkt_forget import DKTForget
6 | from .akt import AKT
7 | from .ckt import CKT
8 | from .hawkeskt import HawkesKT
9 | from .kqn import KQN
10 | from .deep_irt import DeepIRT
11 | from .sakt import SAKT
12 | from .dkt_dsc import DKTDSC
13 | from .lpkt import LPKT
14 | from .simplekt import SimpleKT
15 | from .saint import SAINT
16 | from .skvmn import SKVMN
17 | from .saint_plus import SAINT_plus
18 | from .ct_ncm import CT_NCM
19 | from .lpkt_s import LPKT_S
20 | from .rkt import RKT
21 | from .qikt import QIKT
22 | from .dtransformer import DTransformer
23 | from .qdkt import QDKT
24 | from .cl4kt import CL4KT
25 | from .dimkt import DIMKT
26 | from .GKT import *
27 | from .atkt import ATKT
28 | from .eernn import EERNNM, EERNNA
29 | from .ekt import EKTM, EKTA
30 |
--------------------------------------------------------------------------------
/edustudio/model/KT/dkt.py:
--------------------------------------------------------------------------------
1 | r"""
2 | DKT
3 | ##########################################
4 |
5 | Reference:
6 | Chris Piech et al. "Deep knowledge tracing" in NIPS 2015.
7 |
8 | """
9 |
10 | from ..gd_basemodel import GDBaseModel
11 | import torch.nn as nn
12 | import torch
13 | import torch.nn.functional as F
14 |
15 |
class DKT(GDBaseModel):
    """Deep Knowledge Tracing: an RNN/LSTM over (exercise, response) pairs that
    predicts, per time step, the probability of answering each exercise correctly."""
    default_cfg = {
        'emb_size': 100,
        'hidden_size': 100,
        'num_layers': 1,
        'dropout_rate': 0.2,
        'rnn_or_lstm': 'lstm',
    }

    def __init__(self, cfg):
        super().__init__(cfg)

    def build_cfg(self):
        """Read dataset sizes and validate the sequence-model choice."""
        self.n_user = self.datatpl_cfg['dt_info']['stu_count']
        self.n_item = self.datatpl_cfg['dt_info']['exer_count']
        assert self.modeltpl_cfg['rnn_or_lstm'] in {'rnn', 'lstm'}

    def build_model(self):
        # 2 * n_item embeddings: one slot per (exercise, correct/incorrect) combination.
        self.exer_emb = nn.Embedding(
            self.n_item * 2, self.modeltpl_cfg['emb_size']
        )
        if self.modeltpl_cfg['rnn_or_lstm'] == 'rnn':
            self.seq_model = nn.RNN(
                self.modeltpl_cfg['emb_size'], self.modeltpl_cfg['hidden_size'],
                self.modeltpl_cfg['num_layers'], batch_first=True
            )
        else:
            self.seq_model = nn.LSTM(
                self.modeltpl_cfg['emb_size'], self.modeltpl_cfg['hidden_size'],
                self.modeltpl_cfg['num_layers'], batch_first=True
            )
        self.dropout_layer = nn.Dropout(self.modeltpl_cfg['dropout_rate'])
        self.fc_layer = nn.Linear(self.modeltpl_cfg['hidden_size'], self.n_item)

    def forward(self, exer_seq, label_seq, **kwargs):
        """Return per-step probabilities over all exercises, shape (batch, seq, n_item)."""
        input_x = self.exer_emb(exer_seq + label_seq.long() * self.n_item)
        output, _ = self.seq_model(input_x)
        output = self.dropout_layer(output)
        y_pd = self.fc_layer(output).sigmoid()
        return y_pd

    def _gather_next_pred(self, y_pd, exer_seq, mask_seq):
        """Pick, at each step t, the prediction for the exercise answered at t+1,
        then keep only positions that are valid according to the mask.
        (Shared by predict and get_main_loss, which previously duplicated it.)"""
        y_pd = y_pd[:, :-1].gather(
            index=exer_seq[:, 1:].unsqueeze(dim=-1), dim=2
        ).squeeze(dim=-1)
        return y_pd[mask_seq[:, 1:] == 1]

    @torch.no_grad()
    def predict(self, **kwargs):
        y_pd = self._gather_next_pred(self(**kwargs), kwargs['exer_seq'], kwargs['mask_seq'])
        y_gt = None
        if kwargs.get('label_seq', None) is not None:
            y_gt = kwargs['label_seq'][:, 1:]
            y_gt = y_gt[kwargs['mask_seq'][:, 1:] == 1]
        return {
            'y_pd': y_pd,
            'y_gt': y_gt
        }

    def get_main_loss(self, **kwargs):
        """Binary cross entropy on next-step predictions at valid (masked) positions."""
        y_pd = self._gather_next_pred(self(**kwargs), kwargs['exer_seq'], kwargs['mask_seq'])
        y_gt = kwargs['label_seq'][:, 1:]
        y_gt = y_gt[kwargs['mask_seq'][:, 1:] == 1]
        loss = F.binary_cross_entropy(
            input=y_pd, target=y_gt
        )
        return {
            'loss_main': loss
        }

    def get_loss_dict(self, **kwargs):
        return self.get_main_loss(**kwargs)
90 |
--------------------------------------------------------------------------------
/edustudio/model/KT/dkt_plus.py:
--------------------------------------------------------------------------------
1 | from .dkt import DKT
2 | import torch.nn as nn
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
class DKT_plus(DKT):
    """DKT with extra regularization: a reconstruction loss on the current step
    (`loss_r`) and L1/L2 "waviness" penalties on consecutive predictions
    (`loss_w1`, `loss_w2`)."""
    default_cfg = {
        'lambda_r': 0.01,    # weight of the reconstruction (current-step) loss
        'lambda_w1': 0.003,  # weight of the L1 smoothness penalty
        'lambda_w2': 3.0,    # weight of the L2 smoothness penalty
        'reg_all_KCs': True  # penalize prediction change over all KCs vs. only answered exercises
    }

    def __init__(self, cfg):
        super().__init__(cfg)

    def get_main_loss(self, **kwargs):
        exer_seq = kwargs['exer_seq']
        label_seq = kwargs['label_seq']
        mask_curr = kwargs['mask_seq'][:, :-1] == 1
        mask_next = kwargs['mask_seq'][:, 1:] == 1

        y = self(**kwargs)
        idx_curr = exer_seq[:, :-1].unsqueeze(dim=-1)
        idx_next = exer_seq[:, 1:].unsqueeze(dim=-1)
        pred = y[:, :-1].gather(index=idx_curr, dim=2).squeeze(dim=-1)
        pred_shft = y[:, :-1].gather(index=idx_next, dim=2).squeeze(dim=-1)

        # Main loss: next-step prediction; loss_r: reconstruct the current step.
        loss_main = F.binary_cross_entropy(
            input=pred_shft[mask_next], target=label_seq[:, 1:][mask_next])
        loss_r = self.modeltpl_cfg['lambda_r'] * F.binary_cross_entropy(
            input=pred[mask_curr], target=label_seq[:, :-1][mask_curr])

        # Smoothness penalties on the change between consecutive predictions.
        if self.modeltpl_cfg['reg_all_KCs']:
            diff = y[:, 1:] - y[:, :-1]
        else:
            diff = (pred_shft - pred)[mask_next]
        loss_w1 = torch.norm(diff, 1) / len(diff)
        loss_w1 = self.modeltpl_cfg['lambda_w1'] * loss_w1 / self.n_item
        loss_w2 = torch.norm(diff, 2) / len(diff)
        loss_w2 = self.modeltpl_cfg['lambda_w2'] * loss_w2 / self.n_item

        return {
            'loss_main': loss_main,
            'loss_r': loss_r,
            'loss_w1': loss_w1,
            'loss_w2': loss_w2
        }
50 |
--------------------------------------------------------------------------------
/edustudio/model/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | from .basemodel import BaseModel, BaseProxyModel
6 | from .gd_basemodel import GDBaseModel
7 | from .CD import *
8 | from .KT import *
--------------------------------------------------------------------------------
/edustudio/model/gd_basemodel.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from abc import abstractmethod
4 | from .utils.common import xavier_normal_initialization, xavier_uniform_initialization, kaiming_normal_initialization, kaiming_uniform_initialization
5 | from .basemodel import BaseModel
6 |
7 |
class GDBaseModel(BaseModel):
    """
    Base class for models that are trained with gradient descent.
    """
    default_cfg = {
        'param_init_type': 'xavier_normal',  # initialization method of model parameters
        'pretrained_file_path': "",  # file path of pretrained model parameters
    }

    def __init__(self, cfg):
        super().__init__(cfg)
        self.device = self.traintpl_cfg['device']
        # Shared flags that callbacks may flip (e.g. to request early stop).
        self.share_callback_dict = {
            "stop_training": False
        }

    @abstractmethod
    def build_cfg(self):
        """Construct model config
        """

    @abstractmethod
    def build_model(self):
        """Construct model component
        """

    def _init_params(self):
        """Initialize the model parameters according to `param_init_type`.
        """
        initializers = {
            'xavier_normal': xavier_normal_initialization,
            'xavier_uniform': xavier_uniform_initialization,
            'kaiming_normal': kaiming_normal_initialization,
            'kaiming_uniform': kaiming_uniform_initialization,
        }
        init_type = self.modeltpl_cfg['param_init_type']
        if init_type == 'init_from_pretrained':
            self._load_params_from_pretrained()
        elif init_type in initializers:
            self.apply(initializers[init_type])
        # 'default' (or any unrecognized value) keeps PyTorch's built-in init.

    def _load_params_from_pretrained(self):
        """Load pretrained model parameters from `pretrained_file_path`.
        """
        self.load_state_dict(torch.load(self.modeltpl_cfg['pretrained_file_path']))

    def predict(self, **kwargs):
        """Prediction process; subclasses override.
        """
        pass

    def get_loss_dict(self, **kwargs):
        """Return a dict mapping loss name -> loss value; subclasses override.
        """
        pass
66 |
--------------------------------------------------------------------------------
/edustudio/model/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/edustudio/model/utils/__init__.py
--------------------------------------------------------------------------------
/edustudio/model/utils/common.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn.init import xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, constant_
3 |
4 |
def xavier_normal_initialization(module):
    r"""Initialize `module` in place: `xavier_normal_`_ for the weights of
    nn.Embedding and nn.Linear layers, constant 0 for nn.Linear biases.
    Other module types are left untouched.
    .. _`xavier_normal_`:
        https://pytorch.org/docs/stable/nn.init.html?highlight=xavier_normal_#torch.nn.init.xavier_normal_
    Examples:
        >>> self.apply(xavier_normal_initialization)
    """
    if isinstance(module, (nn.Embedding, nn.Linear)):
        xavier_normal_(module.weight.data)
        if isinstance(module, nn.Linear) and module.bias is not None:
            constant_(module.bias.data, 0)
20 |
21 |
def xavier_uniform_initialization(module):
    r"""Initialize `module` in place: `xavier_uniform_`_ for the weights of
    nn.Embedding and nn.Linear layers, constant 0 for nn.Linear biases.
    Other module types are left untouched.
    .. _`xavier_uniform_`:
        https://pytorch.org/docs/stable/nn.init.html?highlight=xavier_uniform_#torch.nn.init.xavier_uniform_
    Examples:
        >>> self.apply(xavier_uniform_initialization)
    """
    if isinstance(module, (nn.Embedding, nn.Linear)):
        xavier_uniform_(module.weight.data)
        if isinstance(module, nn.Linear) and module.bias is not None:
            constant_(module.bias.data, 0)
37 |
38 |
def kaiming_normal_initialization(module):
    r"""Initialize `module` in place: `kaiming_normal_`_ for the weights of
    nn.Embedding and nn.Linear layers, constant 0 for nn.Linear biases.
    Other module types are left untouched.
    .. _`kaiming_normal`:
        https://pytorch.org/docs/stable/nn.init.html?highlight=kaiming_normal_#torch.nn.init.kaiming_normal_
    Examples:
        >>> self.apply(kaiming_normal_initialization)
    """
    if isinstance(module, (nn.Embedding, nn.Linear)):
        kaiming_normal_(module.weight.data)
        if isinstance(module, nn.Linear) and module.bias is not None:
            constant_(module.bias.data, 0)
54 |
def kaiming_uniform_initialization(module):
    r"""Initialize `module` in place: `kaiming_uniform_`_ for the weights of
    nn.Embedding and nn.Linear layers, constant 0 for nn.Linear biases.
    Other module types are left untouched.
    .. _`kaiming_uniform`:
        https://pytorch.org/docs/stable/nn.init.html?highlight=kaiming_uniform_#torch.nn.init.kaiming_uniform_
    Examples:
        >>> self.apply(kaiming_uniform_initialization)
    """
    if isinstance(module, (nn.Embedding, nn.Linear)):
        kaiming_uniform_(module.weight.data)
        if isinstance(module, nn.Linear) and module.bias is not None:
            constant_(module.bias.data, 0)
70 |
--------------------------------------------------------------------------------
/edustudio/quickstart/__init__.py:
--------------------------------------------------------------------------------
1 | from .quickstart import run_edustudio
2 |
--------------------------------------------------------------------------------
/edustudio/quickstart/atom_cmds.py:
--------------------------------------------------------------------------------
1 | import fire
2 | from edustudio.atom_op.raw2mid import _cli_api_dict_ as raw2mid
3 | from edustudio.utils.common import Logger
4 |
def entrypoint():
    """CLI entrypoint: set up the standard logger, then dispatch the `r2m`
    (raw-to-middata) atomic operations via python-fire."""
    # Initialize logging before any command runs.
    Logger().get_std_logger()
    fire.Fire(
        {"r2m": raw2mid}
    )
10 |
--------------------------------------------------------------------------------
/edustudio/quickstart/init_all.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | from edustudio.utils.common import UnifyConfig, PathUtil, Logger
6 | import os
7 |
8 |
def init_all(cfg: UnifyConfig):
    """Initialize the run environment: folder layout, logger, and tqdm behavior.

    Args:
        cfg (UnifyConfig): the global config object
    """
    frame_cfg = cfg.frame_cfg
    dataset = cfg.dataset
    # `cls` may be either a class object or its name as a string.
    traintpl_cls_name = cfg.traintpl_cfg.cls if isinstance(cfg.traintpl_cfg.cls, str) else cfg.traintpl_cfg.cls.__name__
    model_cls_name = cfg.modeltpl_cfg.cls if isinstance(cfg.modeltpl_cfg.cls, str) else cfg.modeltpl_cfg.cls.__name__

    frame_cfg.data_folder_path = f"{frame_cfg.DATA_FOLDER_PATH}/{dataset}"
    # PathUtil.check_path_exist(frame_cfg.data_folder_path)

    frame_cfg.TEMP_FOLDER_PATH = os.path.realpath(frame_cfg.TEMP_FOLDER_PATH)
    frame_cfg.ARCHIVE_FOLDER_PATH = os.path.realpath(frame_cfg.ARCHIVE_FOLDER_PATH)

    # Per-run temp folder (keyed by the run ID) and a shared archive folder.
    frame_cfg.temp_folder_path = f"{frame_cfg.TEMP_FOLDER_PATH}/{dataset}/{traintpl_cls_name}/{model_cls_name}/{frame_cfg.ID}"
    frame_cfg.archive_folder_path = f"{frame_cfg.ARCHIVE_FOLDER_PATH}/{dataset}/{traintpl_cls_name}/{model_cls_name}"
    PathUtil.auto_create_folder_path(
        frame_cfg.temp_folder_path, frame_cfg.archive_folder_path
    )
    # Build the run logger; optionally drop the date prefix from log lines.
    log_filepath = f"{frame_cfg.temp_folder_path}/{frame_cfg.ID}.log"
    if frame_cfg['LOG_WITHOUT_DATE']:
        cfg.logger = Logger(
            filepath=log_filepath, fmt='[%(levelname)s]: %(message)s', date_fmt=None,
            DISABLE_LOG_STDOUT=cfg['frame_cfg']['DISABLE_LOG_STDOUT']
        ).get_std_logger()
    else:
        cfg.logger = Logger(
            filepath=log_filepath, DISABLE_LOG_STDOUT=cfg['frame_cfg']['DISABLE_LOG_STDOUT']
        ).get_std_logger()

    # Globally disable tqdm progress bars if configured.
    if frame_cfg['DISABLE_TQDM_BAR'] is True:
        from tqdm import tqdm
        from functools import partialmethod
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
46 |
--------------------------------------------------------------------------------
/edustudio/settings.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | from edustudio.utils.common import PathUtil as pathUtil
6 | from edustudio.utils.common import IDUtil as idUtil
7 | import os
8 | from edustudio import __version__
9 |
ID = idUtil.get_random_id_bytime() # RUN ID
EDUSTUDIO_VERSION = __version__

WORK_DIR = os.getcwd()  # all default folders are rooted at the current working directory

DATA_FOLDER_PATH = f"{WORK_DIR}/data"        # datasets
TEMP_FOLDER_PATH = f"{WORK_DIR}/temp"        # per-run temporary artifacts
ARCHIVE_FOLDER_PATH = f"{WORK_DIR}/archive"  # archived run results
CFG_FOLDER_PATH = f"{WORK_DIR}/conf"         # user configuration files

# NOTE: importing this module creates the folders above as a side effect.
pathUtil.auto_create_folder_path(
    TEMP_FOLDER_PATH,
    ARCHIVE_FOLDER_PATH,
    DATA_FOLDER_PATH,
    CFG_FOLDER_PATH,
)

DISABLE_TQDM_BAR = False   # globally disable tqdm progress bars
LOG_WITHOUT_DATE = False   # log lines without the date prefix
TQDM_NCOLS = 100           # tqdm bar width in columns
DISABLE_LOG_STDOUT = False # suppress logging to stdout (file logging only)

curr_file_folder = os.path.dirname(__file__)
# Built-in registry of known datasets shipped with the package.
DT_INFO_FILE_PATH = os.path.realpath(curr_file_folder + "/assets/datasets.yaml")

DT_INFO_DICT = {} # additional dataset info entrypoint, example: {'FrcSub': {middata_url: https://huggingface.co/datasets/lmcRS/edustudio-datasets/resolve/main/FrcSub/FrcSub-middata.zip} }
36 |
--------------------------------------------------------------------------------
/edustudio/traintpl/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | from .base_traintpl import BaseTrainTPL
6 | from .gd_traintpl import GDTrainTPL
7 | from .general_traintpl import GeneralTrainTPL
8 | from .atkt_traintpl import AtktTrainTPL
9 | from .dcd_traintpl import DCDTrainTPL
10 | from .adversarial_traintpl import AdversarialTrainTPL
11 | from .group_cd_traintpl import GroupCDTrainTPL
12 |
--------------------------------------------------------------------------------
/edustudio/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/edustudio/utils/__init__.py
--------------------------------------------------------------------------------
/edustudio/utils/callback/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | from .callbacks.callback import Callback
6 | from .callbacks.earlyStopping import EarlyStopping
7 | from .callbacks.baseLogger import BaseLogger
8 | from .callbacks.modelCheckPoint import ModelCheckPoint
9 | from .callbacks.history import History
10 | from .callbacks.epochPredict import EpochPredict
11 | from .callBackList import CallbackList
12 | from .modeState import ModeState
13 | from .callbacks.tensorboardCallBack import TensorboardCallback
14 |
--------------------------------------------------------------------------------
/edustudio/utils/callback/callbacks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HFUT-LEC/EduStudio/4c06ab7a3321607582acc5918788111ee7ad5549/edustudio/utils/callback/callbacks/__init__.py
--------------------------------------------------------------------------------
/edustudio/utils/callback/callbacks/baseLogger.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | from .callback import Callback
6 | from collections import defaultdict
7 |
8 |
class BaseLogger(Callback):
    """Callback that logs epoch metrics, grouping related keys on shared lines."""

    def __init__(self, logger=None, **kwargs):
        """
        Args:
            logger: object exposing an `info` method (e.g. logging.Logger);
                falls back to `print` when absent.
        Keyword Args:
            join_str (str): separator between "key: value" items on one line.
            group_by_contains (tuple[str]): substrings; keys containing the same
                substring are logged together on one line.
            group_by_count (int): max number of remaining items per logged line.
        """
        super(BaseLogger, self).__init__()
        self.log = print
        if logger is not None and hasattr(logger, 'info'):
            self.log = logger.info

        self.join_str = kwargs.get("join_str", " | ")
        # BUGFIX: the default was ('loss') — a plain string, not a tuple — so
        # iteration yielded the characters 'l','o','s','s' as grouping rules
        # instead of the single rule 'loss'.
        self.group_by_contains = kwargs.get('group_by_contains', ('loss',))
        self.group_by_count = kwargs.get('group_by_count', 5)
        assert self.group_by_count >= 1 and type(self.group_by_count) is int

    def on_train_begin(self, logs=None, **kwargs):
        super().on_train_begin()
        self.log("Start Training...")

    def on_epoch_end(self, epoch: int, logs: dict = None, **kwargs):
        logs = logs or {}  # guard: the declared default is None
        info = f"[EPOCH={epoch:03d}]: "
        flag = [False] * len(logs)
        # First pass: one line per configured substring rule, covering all
        # not-yet-logged keys that contain that substring.
        for group_rule in self.group_by_contains:
            v_list = []
            for i, (k, v) in enumerate(logs.items()):
                if group_rule in k and not flag[i]:
                    v_list.append(f"{k}: {v:.4f}")
                    flag[i] = True
            if len(v_list) > 0: self.log(f"{info}{self.join_str.join(v_list)}")

        # Second pass: log the remaining keys in batches of group_by_count.
        logs = {k: logs[k] for i, k in enumerate(logs) if not flag[i]}

        v_list = []
        for i, (k, v) in enumerate(logs.items()):
            v_list.append(f"{k}: {v:.4f}")
            if (i+1) % self.group_by_count == 0:
                self.log(f"{info}{self.join_str.join(v_list)}")
                v_list = []
        if len(v_list) > 0: self.log(f"{info}{self.join_str.join(v_list)}")

    def on_train_end(self, logs=None, **kwargs):
        self.log("Training Completed!")
48 |
--------------------------------------------------------------------------------
/edustudio/utils/callback/callbacks/callback.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | import torch.nn as nn
6 | from ..modeState import ModeState
7 | import logging
8 |
9 |
class Callback(object):
    """Base class for training-lifecycle callbacks.

    Subclasses override the `on_*` hook methods. The training template injects
    the model, mode state, logger, and owning callback list via the setters.
    """
    def __init__(self):
        self.model = None          # attached later via set_model
        self.mode_state = None     # attached later via set_state
        self.logger = None         # attached later via set_logger
        self.callback_list = None  # attached later via set_callback_list

    def set_model(self, model: nn.Module):
        """Attach the model being trained."""
        self.model = model

    def set_state(self, mode_state: ModeState):
        """Attach the current mode state (train/eval phase indicator)."""
        self.mode_state = mode_state

    def set_logger(self, logger: logging.Logger):
        """Attach the run logger."""
        self.logger = logger

    def set_callback_list(self, callback_list):
        """Attach the CallbackList that owns this callback."""
        self.callback_list = callback_list

    def on_train_begin(self, logs=None, **kwargs):
        # Default behavior: announce that the callback is active.
        self.logger.info(f"[CALLBACK]-{self.__class__.__name__} has been registered!")

    def on_train_end(self, logs=None, **kwargs):
        """Hook: called once after training finishes."""
        pass

    def on_epoch_begin(self, epoch, logs=None, **kwargs):
        """Hook: called before each epoch."""
        pass

    def on_epoch_end(self, epoch, logs=None, **kwargs):
        """Hook: called after each epoch with that epoch's metrics in `logs`."""
        pass

    def on_train_batch_begin(self, batch, logs=None, **kwargs):
        """Hook: called before each training batch."""
        pass

    def on_train_batch_end(self, batch, logs=None, **kwargs):
        """Hook: called after each training batch."""
        pass
46 |
--------------------------------------------------------------------------------
/edustudio/utils/callback/callbacks/earlyStopping.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | from typing import List
6 | from .callback import Callback
7 | import numpy as np
8 |
9 |
class Metric(object):
    """Track the best observed value (and its epoch) for one metric.

    NOTE: despite its name, ``better_than(value)`` returns True when the
    CURRENT BEST is strictly better than ``value`` — i.e. True means
    ``value`` is NOT an improvement. Callers therefore update on ``not
    better_than(...)``, which also treats ties as improvements.
    """

    def __init__(self, name, type_):
        assert type_ in ['max', 'min']
        self.name = name
        self.type_ = type_
        self.best_epoch = 1
        # Seed with the worst possible value so the first observation wins.
        self.best_value = -np.inf if self.type_ == 'max' else np.inf

    def better_than(self, value):
        """Return True when the stored best strictly beats `value`."""
        if self.type_ == 'max':
            return value < self.best_value
        return value > self.best_value

    def update(self, epoch, value):
        """Record `value` (and the epoch that produced it) as the new best."""
        self.best_epoch = epoch
        self.best_value = value
27 |
28 |
class EarlyStopping(Callback):
    """Suggest stopping once every tracked metric has stopped improving."""

    def __init__(self, metric_list: List[list], num_stop_rounds: int = 20, start_round=1):
        """
        Args:
            metric_list (List[list]): e.g. [['rmse', 'min'], ['ndcg', 'max']]
            num_stop_rounds (int, optional): suggest stopping when no metric has
                improved within the latest num_stop_rounds epochs. Defaults to 20.
            start_round (int, optional): epoch at which detection begins. Defaults to 1.
        """
        super().__init__()
        assert num_stop_rounds >= 1
        assert start_round >= 1
        self.start_round = start_round
        self.num_stop_round = num_stop_rounds
        self.stop_training = False
        self.metric_list = [Metric(name=metric_name, type_=metric_type) for metric_name, metric_type in metric_list]

    def on_train_begin(self, logs=None, **kwargs):
        super().on_train_begin()
        # Shared flag the training loop polls to decide whether to stop.
        self.model.share_callback_dict['stop_training'] = False

    def on_epoch_end(self, epoch: int, logs: dict = None, **kwargs):
        # Refresh each metric's best value/epoch. Metric.better_than returns
        # True when the stored best beats the new value, so `not` means
        # "improved or tied".
        for metric in self.metric_list:
            if not metric.better_than(logs[metric.name]):
                metric.update(epoch=epoch, value=logs[metric.name])

        # BUG FIX: previously, epochs before `start_round` skipped the
        # staleness check but still fell through with flag=True, so stopping
        # was falsely suggested on every epoch before detection even began.
        if epoch < self.start_round:
            return

        # Stop only when EVERY metric is stale for at least num_stop_round epochs.
        if all(epoch - metric.best_epoch >= self.num_stop_round for metric in self.metric_list):
            self.logger.info("Suggest to stop training now")
            self.stop_training = True
            self.model.share_callback_dict['stop_training'] = True
64 |
--------------------------------------------------------------------------------
/edustudio/utils/callback/callbacks/epochPredict.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | from .callback import Callback
6 | import os
7 | import pickle
8 |
9 |
class EpochPredict(Callback):
    """After every epoch, run ``model.predict(**kwargs)`` and pickle the result."""

    def __init__(self, save_folder_path, fmt="predict-{epoch}", **kwargs):
        """
        Args:
            save_folder_path: folder receiving one pickle file per epoch.
            fmt: filename template; must contain the "{epoch}" placeholder.
            **kwargs: forwarded verbatim to ``model.predict``.
        """
        super(EpochPredict, self).__init__()
        self.kwargs = kwargs
        self.save_folder_path = save_folder_path
        # exist_ok avoids the check-then-create race of exists()+makedirs().
        os.makedirs(self.save_folder_path, exist_ok=True)
        self.file_fmt = fmt + ".pkl"
        self.pd = None  # latest prediction result
        assert "{epoch}" in self.file_fmt

    def on_epoch_end(self, epoch, logs=None, **kwargs):
        """Predict with the current model and persist the result for this epoch."""
        self.pd = self.model.predict(**self.kwargs)
        filepath = os.path.join(self.save_folder_path, self.file_fmt.format(epoch=epoch))
        self.to_pickle(filepath, self.pd)

    @staticmethod
    def to_pickle(filepath, obj):
        """Pickle `obj` to `filepath` (protocol 4 for wide compatibility)."""
        with open(filepath, 'wb') as f:
            pickle.dump(obj, f, protocol=4)
30 |
--------------------------------------------------------------------------------
/edustudio/utils/callback/callbacks/history.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | from .callback import Callback
6 | from collections import defaultdict
7 | from typing import List, Sequence, Optional
8 | import json
9 | import os
10 | import numpy as np
11 | from edustudio.utils.common import NumpyEncoder
12 |
13 |
class History(Callback):
    """Record scalar epoch metrics, dump them to JSON and optionally plot curves."""

    def __init__(self, folder_path, exclude_metrics=(), plot_curve=False):
        super(History, self).__init__()
        self.log_as_time = {}  # epoch -> {metric_name: scalar value}
        self.exclude_metrics = set(exclude_metrics)
        self.folder_path = folder_path
        # exist_ok avoids the check-then-create race of exists()+makedirs().
        os.makedirs(self.folder_path, exist_ok=True)

        self.plot_curve = plot_curve
        if self.plot_curve:
            import matplotlib  # fail fast at construction if matplotlib is missing
            os.makedirs(self.folder_path + "/plots/", exist_ok=True)
        self.names = None  # metric-name set, fixed after the first recorded epoch

    def on_epoch_end(self, epoch, logs=None, **kwargs):
        """Record the scalar (int/str/float), non-excluded entries of `logs`."""
        self.log_as_time[epoch] = {k: v for k, v in logs.items()
                                   if isinstance(v, (int, str, float)) and k not in self.exclude_metrics}
        if self.names is None:
            self.names = set(self.log_as_time[epoch].keys())
        else:
            # The recorded metric set must stay identical across epochs.
            assert self.names == set(self.log_as_time[epoch].keys())

    def on_train_end(self, logs=None, **kwargs):
        self.dump_json(self.log_as_time, os.path.join(self.folder_path, "history.json"))

        if self.plot_curve:
            self.logger.info("[CALLBACK]-History: Plot Curve...")
            self.plot()
            self.logger.info("[CALLBACK]-History: Plot Succeed!")

    def plot(self):
        """Save one metric-vs-epoch curve per metric.

        NOTE(review): assumes epochs are keyed contiguously as 1..N — confirm
        against the trainer's epoch numbering.
        """
        import matplotlib.pyplot as plt
        epoch_num = len(self.log_as_time)
        x_arr = np.arange(1, epoch_num + 1)
        for name in self.names:
            value_arr = [self.log_as_time[i][name] for i in range(1, epoch_num + 1)]
            plt.figure()
            plt.title(name)
            plt.xlabel("epoch")
            plt.ylabel("value")
            plt.plot(x_arr, value_arr)
            plt.autoscale()
            plt.savefig(f"{self.folder_path}/plots/{name}.png", dpi=500, bbox_inches='tight', pad_inches=0.1)
            # BUG FIX: release the figure; otherwise one open figure leaks per metric.
            plt.close()

    @staticmethod
    def dump_json(data, filepath):
        """Serialize `data` to `filepath` as indented UTF-8 JSON (numpy-aware)."""
        dir_name = os.path.dirname(filepath)
        # BUG FIX: for a bare filename dirname is '' and os.makedirs('') raised.
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        with open(filepath, 'w', encoding='utf8') as f:
            json.dump(data, fp=f, indent=2, ensure_ascii=False, cls=NumpyEncoder)
66 |
--------------------------------------------------------------------------------
/edustudio/utils/callback/callbacks/modelCheckPoint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | from .callback import Callback
6 | from typing import Tuple, List
7 | import torch
8 | import os
9 | import shutil
10 | import glob
11 | from .baseLogger import BaseLogger
12 | import numpy as np
13 | from collections import namedtuple
14 |
15 |
class Metric(object):
    """Track the best value, its epoch, and the full log dict for one metric.

    NOTE: despite its name, ``better_than(value)`` returns True when the
    CURRENT BEST strictly beats ``value`` — True means ``value`` is NOT an
    improvement. Callers update on ``not better_than(...)``.
    """

    def __init__(self, name, type_):
        assert type_ in ['max', 'min']
        self.name = name
        self.type_ = type_
        self.best_epoch = 1
        self.best_log = dict()
        # Seed with the worst possible value so the first observation wins.
        self.best_value = -np.inf if self.type_ == 'max' else np.inf

    def better_than(self, value):
        """Return True when the stored best strictly beats `value`."""
        if self.type_ == 'max':
            return value < self.best_value
        return value > self.best_value

    def update(self, epoch, value, log):
        """Record `value` (with its epoch and full metrics dict) as the new best."""
        self.best_epoch = epoch
        self.best_value = value
        self.best_log = log
36 |
class ModelCheckPoint(Callback):
    """Save model weights when tracked metrics improve; summarise bests at the end."""

    def __init__(self, metric_list: List[list], save_folder_path, save_best_only=True):
        """
        Args:
            metric_list (List[list]): e.g. [['auc', 'max'], ['rmse', 'min']]
            save_folder_path: folder that receives the .pth checkpoints.
            save_best_only: if False, also save the weights after every epoch.
        """
        super().__init__()
        self.save_folder_path = save_folder_path
        # exist_ok avoids the check-then-create race of exists()+makedirs().
        os.makedirs(self.save_folder_path, exist_ok=True)
        self.save_best_only = save_best_only
        self.metric_list = [Metric(name=metric_name, type_=metric_type) for metric_name, metric_type in metric_list]

    def on_epoch_end(self, epoch, logs=None, **kwargs):
        for metric in self.metric_list:
            if not metric.better_than(logs[metric.name]):
                metric.update(epoch=epoch, value=logs[metric.name], log=logs)
                # BUG FIX: write the best-epoch checkpoint only when the metric
                # actually improved — previously it was rewritten every epoch,
                # so the "best" file really held the last epoch's weights.
                # Also log the epoch as a plain int instead of "[epoch]".
                self.logger.info(f"Update best epoch as {epoch} for {metric.name}!")
                filepath = os.path.join(self.save_folder_path, f"best-epoch-for-{metric.name}.pth")
                torch.save(obj=self.model.state_dict(), f=filepath)

        if not self.save_best_only:
            filepath = os.path.join(self.save_folder_path, f"{epoch:03d}.pth")
            torch.save(obj=self.model.state_dict(), f=filepath)

    def on_train_end(self, logs=None, **kwargs):
        self.logger.info("===" * 10)
        for metric in self.metric_list:
            self.logger.info(f"[For {metric.name}], the Best Epoch is: {metric.best_epoch}, the value={metric.best_value:.4f}")
            # Re-emit the best epoch's full metric line through any BaseLogger.
            for cb in self.callback_list.callbacks:
                if isinstance(cb, BaseLogger):
                    cb.on_epoch_end(metric.best_epoch, logs=metric.best_log)
            # Rename the checkpoint so the best epoch number is in the filename.
            os.renames(
                os.path.join(self.save_folder_path, f"best-epoch-for-{metric.name}.pth"),
                os.path.join(self.save_folder_path, f"best-epoch-{metric.best_epoch:03d}-for-{metric.name}.pth")
            )
        self.logger.info("===" * 10)
73 |
--------------------------------------------------------------------------------
/edustudio/utils/callback/callbacks/tensorboardCallBack.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | from .callback import Callback
6 |
class TensorboardCallback(Callback):
    """Mirror every epoch-end metric into TensorBoard as a scalar series."""

    def __init__(self, log_dir, comment=''):
        super(TensorboardCallback, self).__init__()
        # Imported lazily so tensorboard is only required when this callback is used.
        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir=log_dir, comment=comment)

    def on_epoch_end(self, epoch, logs=None, **kwargs):
        """Write each entry of `logs` as a scalar at global step `epoch`."""
        for metric_name, metric_value in logs.items():
            self.writer.add_scalar(tag=metric_name, scalar_value=metric_value, global_step=epoch)
16 |
--------------------------------------------------------------------------------
/edustudio/utils/callback/modeState.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | from enum import Enum
6 |
7 |
class ModeState(Enum):
    """Coarse phase of a run: START, TRAINING, or END."""

    START = 1
    TRAINING = 2
    END = 3
12 |
--------------------------------------------------------------------------------
/edustudio/utils/common/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 |
6 | from .commonUtil import PathUtil, IDUtil, tensor2npy, tensor2cpu, IOUtil, set_same_seeds, get_gpu_usage, DecoratorTimer
7 | from .configUtil import UnifyConfig, NumpyEncoder
8 | from .loggerUtil import Logger
9 |
--------------------------------------------------------------------------------
/edustudio/utils/common/loggerUtil.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Xiangzhi Chen
3 | # @Github : kervias
4 |
5 | import logging
6 | import pytz
7 | import sys
8 | import os
9 |
10 |
class Logger(object):
    """Configure and wrap the shared "edustudio" ``logging.Logger``.

    Optionally logs to a file and (unless disabled) to stdout, formatting
    timestamps in the configured timezone.
    """

    def __init__(self, filepath: str = None,
                 fmt: str = "%(asctime)s[%(levelname)s]: %(message)s",
                 date_fmt: str = "%Y-%m-%d %H:%M:%S",
                 timezone: str = "Asia/Shanghai",
                 level=logging.DEBUG,
                 DISABLE_LOG_STDOUT=False
                 ):
        """
        Args:
            filepath: when given, also append log records to this file.
            fmt / date_fmt: logging format strings.
            timezone: tz name used to render record timestamps.
            level: minimum level for logger and handlers.
            DISABLE_LOG_STDOUT: when True, never attach a stdout handler.
        """
        self.timezone = timezone
        if filepath:
            dir_name = os.path.dirname(filepath)
            # BUG FIX: for a bare filename dirname is '' and os.makedirs('')
            # raised; exist_ok also removes the check-then-create race.
            if dir_name:
                os.makedirs(dir_name, exist_ok=True)

        self.logger = logging.getLogger("edustudio")
        self.logger.setLevel(level)
        formatter = logging.Formatter(
            fmt, datefmt=date_fmt
        )
        formatter.converter = self.converter

        if filepath:
            # BUG FIX: avoid stacking a duplicate FileHandler (and thus
            # duplicated lines) when Logger is constructed more than once
            # with the same file.
            abs_path = os.path.abspath(filepath)
            has_file_handler = any(
                isinstance(h, logging.FileHandler) and h.baseFilename == abs_path
                for h in self.logger.handlers
            )
            if not has_file_handler:
                fh = logging.FileHandler(filepath, mode='a+', encoding='utf-8')
                fh.setLevel(level)
                fh.setFormatter(formatter)
                self.logger.addHandler(fh)

        # Attach a stdout handler once. `type(...) is` (not isinstance) is
        # deliberate: FileHandler subclasses StreamHandler and must not count.
        has_stream_handler = any(
            type(handler) is logging.StreamHandler for handler in self.logger.handlers
        )
        if not has_stream_handler and not DISABLE_LOG_STDOUT:
            ch = logging.StreamHandler(sys.stdout)
            ch.setLevel(level)
            ch.setFormatter(formatter)
            self.logger.addHandler(ch)

    def _flush(self):
        """Force all handlers to flush so messages appear immediately."""
        for handler in self.logger.handlers:
            handler.flush()

    def debug(self, message):
        self.logger.debug(message)
        self._flush()

    def info(self, message):
        self.logger.info(message)
        self._flush()

    def warning(self, message):
        self.logger.warning(message)
        self._flush()

    def error(self, message):
        self.logger.error(message)
        self._flush()

    def critical(self, message):
        self.logger.critical(message)
        self._flush()

    def converter(self, sec):
        """Convert a POSIX timestamp into a time tuple in ``self.timezone``."""
        # BUG FIX: use the stdlib datetime module directly instead of
        # `pytz.datetime.datetime`, which relied on pytz's internal imports.
        import datetime
        tz = pytz.timezone(self.timezone)
        return datetime.datetime.fromtimestamp(sec, tz).timetuple()

    def get_std_logger(self):
        """Return the underlying standard ``logging.Logger`` instance."""
        return self.logger
81 |
--------------------------------------------------------------------------------
/examples/1.run_cd_demo.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# Cognitive-diagnosis demo: KaNCD on FrcSub (CPU), evaluated for prediction,
# interpretability and identifiability.
traintpl_cfg = {
    'cls': 'GeneralTrainTPL',
    'device': 'cpu'
}
datatpl_cfg = {
    'cls': 'CDInterExtendsQDataTPL'
}
modeltpl_cfg = {
    'cls': 'KaNCD',
}
evaltpl_cfg = {
    'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL', 'IdentifiabilityEvalTPL'],
    'PredictionEvalTPL': {
        'use_metrics': ['auc'],
    },
    'InterpretabilityEvalTPL': {
        'use_metrics': ['doa_all', 'doc_all'],
        'test_only_metrics': ['doa_all', 'doc_all'],
    },
    'IdentifiabilityEvalTPL': {
        'use_metrics': ['IDS']
    }
}
frame_cfg = {
    'ID': '123456',
}

run_edustudio(
    dataset='FrcSub',
    cfg_file_name=None,
    traintpl_cfg_dict=traintpl_cfg,
    datatpl_cfg_dict=datatpl_cfg,
    modeltpl_cfg_dict=modeltpl_cfg,
    evaltpl_cfg_dict=evaltpl_cfg,
    frame_cfg_dict=frame_cfg
)
39 |
--------------------------------------------------------------------------------
/examples/2.run_kt_demo.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# Knowledge-tracing demo: DKT on ASSIST_0910.
run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    datatpl_cfg_dict={'cls': 'KTInterCptUnfoldDataTPL'},
    modeltpl_cfg_dict={'cls': 'DKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
25 |
--------------------------------------------------------------------------------
/examples/3.run_with_customized_tpl.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio
from edustudio.model import DKT

# Demo: pass the model CLASS object itself instead of its name string.
run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    datatpl_cfg_dict={'cls': 'KTInterCptUnfoldDataTPL'},
    modeltpl_cfg_dict={'cls': DKT},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
26 |
--------------------------------------------------------------------------------
/examples/4.run_cmd_demo.py:
--------------------------------------------------------------------------------
import fire
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))


from edustudio.atom_op.raw2mid import _cli_api_dict_ as raw2mid
from edustudio.utils.common import Logger


def entrypoint():
    """Expose the raw->middle atomic operations as the `r2m` CLI command group."""
    Logger().get_std_logger()  # initialise the shared "edustudio" logger once
    fire.Fire({"r2m": raw2mid})


entrypoint()
19 |
--------------------------------------------------------------------------------
/examples/5.run_with_hyperopt.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
5 | os.chdir(os.path.dirname(os.path.abspath(__file__)))
6 |
7 | from edustudio.quickstart import run_edustudio
8 | from hyperopt import hp
9 | from hyperopt import fmin, tpe, space_eval
10 | from edustudio.utils.common import IDUtil as idUtil
11 | import uuid
12 |
def deliver_cfg(args):
    """Split flat "group.key" hyperopt parameters into per-group config dicts."""
    grouped = {name: {} for name in ('traintpl_cfg', 'datatpl_cfg', 'modeltpl_cfg', 'evaltpl_cfg')}
    for dotted_key, value in args.items():
        group, key = dotted_key.split(".")
        assert group in grouped
        grouped[group][key] = value
    # Unique run ID: time-based id plus a uuid suffix so trials never collide.
    grouped['frame_cfg'] = {
        'ID': idUtil.get_random_id_bytime() + str(uuid.uuid4()).split("-")[-1]
    }
    return grouped
28 |
29 |
# objective function
def objective_function(args):
    """Hyperopt objective: run one experiment and return the value to MINIMISE.

    BUG FIX: hyperopt's ``fmin`` minimises its objective, but this function
    previously returned the raw AUC (a maximise metric), so the search
    selected the WORST configuration. Return the negated AUC instead.
    """
    g_args = deliver_cfg(args)
    cfg, res = run_edustudio(
        dataset='FrcSub',
        cfg_file_name=None,
        traintpl_cfg_dict=g_args['traintpl_cfg'],
        datatpl_cfg_dict=g_args['datatpl_cfg'],
        modeltpl_cfg_dict=g_args['modeltpl_cfg'],
        evaltpl_cfg_dict=g_args['evaltpl_cfg'],
        frame_cfg_dict=g_args['frame_cfg'],
        return_cfg_and_result=True,
    )
    return -res['auc']
44 |
45 |
# Search space: the template classes are fixed; only batch size, epoch count
# and embedding dimension vary. NOTE(review): the label 'datapl_cfg.cls'
# (missing "t") is hyperopt's internal choice label, kept byte-for-byte.
space = {
    'traintpl_cfg.cls': hp.choice('traintpl_cfg.cls', ['GeneralTrainTPL']),
    'datatpl_cfg.cls': hp.choice('datapl_cfg.cls', ['CDInterExtendsQDataTPL']),
    'modeltpl_cfg.cls': hp.choice('modeltpl_cfg.cls', ['KaNCD']),
    'evaltpl_cfg.clses': hp.choice('evaltpl_cfg.clses', [['PredictionEvalTPL', 'InterpretabilityEvalTPL']]),
    'traintpl_cfg.batch_size': hp.choice('traintpl_cfg.batch_size', [256,]),
    'traintpl_cfg.epoch_num': hp.choice('traintpl_cfg.epoch_num', [2]),
    'modeltpl_cfg.emb_dim': hp.choice('modeltpl_cfg.emb_dim', [20,40])
}

# Optimise with TPE over 10 evaluations, then report the best point.
best = fmin(objective_function, space, algo=tpe.suggest, max_evals=10, verbose=False)

print("==" * 10)
print(best)
print(space_eval(space, best))
63 |
--------------------------------------------------------------------------------
/examples/6.run_with_ray.tune.py:
--------------------------------------------------------------------------------
# run following after installed edustudio
import sys
import os

# Make the repository root importable and run relative to this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio
from ray import tune
import ray
# Cap the resources Ray may schedule across all trials (4 CPUs, 1 GPU).
ray.init(num_cpus=4, num_gpus=1)
from edustudio.utils.common import IDUtil as idUtil
import uuid
14 |
def deliver_cfg(args):
    """Split flat "group.key" Ray Tune parameters into per-group config dicts."""
    grouped = {name: {} for name in ('traintpl_cfg', 'datatpl_cfg', 'modeltpl_cfg', 'evaltpl_cfg', 'frame_cfg')}
    for dotted_key, value in args.items():
        group, key = dotted_key.split(".")
        assert group in grouped
        grouped[group][key] = value
    # Unique run ID: time-based id plus a uuid suffix so trials never collide.
    grouped['frame_cfg'] = {
        'ID': idUtil.get_random_id_bytime() + str(uuid.uuid4()).split("-")[-1]
    }
    return grouped
31 |
32 |
# objective function
def objective_function(args):
    """Ray Tune trial: run one EduStudio experiment and return its result dict.

    Tune later picks the winner via get_best_result(metric="auc", mode="max").
    """
    g_args = deliver_cfg(args)
    print(g_args)
    cfg, res = run_edustudio(
        dataset='FrcSub',
        cfg_file_name=None,
        traintpl_cfg_dict=g_args['traintpl_cfg'],
        datatpl_cfg_dict=g_args['datatpl_cfg'],
        modeltpl_cfg_dict=g_args['modeltpl_cfg'],
        evaltpl_cfg_dict=g_args['evaltpl_cfg'],
        frame_cfg_dict=g_args['frame_cfg'],
        return_cfg_and_result=True
    )
    return res
48 |
49 |
# Grid search: template classes fixed; batch size, epoch count, device and
# embedding dimension enumerated. Keys follow the "group.key" convention
# consumed by deliver_cfg.
search_space = {
    'traintpl_cfg.cls': tune.grid_search(['GeneralTrainTPL']),
    'datatpl_cfg.cls': tune.grid_search(['CDInterExtendsQDataTPL']),
    'modeltpl_cfg.cls': tune.grid_search(['KaNCD']),
    'evaltpl_cfg.clses': tune.grid_search([['PredictionEvalTPL', 'InterpretabilityEvalTPL']]),
    'traintpl_cfg.batch_size': tune.grid_search([256,]),
    'traintpl_cfg.epoch_num': tune.grid_search([2]),
    'traintpl_cfg.device': tune.grid_search(["cuda:0"]),
    'modeltpl_cfg.emb_dim': tune.grid_search([20,40]),
    'frame_cfg.DISABLE_LOG_STDOUT': tune.grid_search([False]),
}

# One trial at a time, each granted a full GPU.
tuner = tune.Tuner(
    tune.with_resources(objective_function, {"gpu": 1}),
    param_space=search_space,
    tune_config=tune.TuneConfig(max_concurrent_trials=1),
)
results = tuner.fit()

print("==" * 10)
print(results.get_best_result(metric="auc", mode="max").config)
71 |
--------------------------------------------------------------------------------
/examples/7.run_toy_demo.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# Toy smoke-test: wire the framework's Base* templates together on ASSIST_0910.
run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,
    traintpl_cfg_dict={'cls': 'BaseTrainTPL'},
    datatpl_cfg_dict={'cls': 'BaseDataTPL'},
    modeltpl_cfg_dict={'cls': 'BaseModel'},
    evaltpl_cfg_dict={'clses': ['BaseEvalTPL']},
)
25 |
--------------------------------------------------------------------------------
/examples/single_model/run_akt_demo.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# AKT on ASSIST_0910.
run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    datatpl_cfg_dict={'cls': 'KTInterCptUnfoldDataTPL'},
    modeltpl_cfg_dict={'cls': 'AKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
25 |
--------------------------------------------------------------------------------
/examples/single_model/run_atkt_demo.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# ATKT on ASSIST_0910 (uses its dedicated training template).
run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,
    traintpl_cfg_dict={'cls': 'AtktTrainTPL'},
    datatpl_cfg_dict={'cls': 'KTInterCptUnfoldDataTPL'},
    modeltpl_cfg_dict={'cls': 'ATKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
25 |
--------------------------------------------------------------------------------
/examples/single_model/run_cdgk_demo.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# CDGK_MULTI on FrcSub with enlarged batch sizes and a single subgraph.
traintpl_cfg = {
    'cls': 'GeneralTrainTPL',
    'batch_size': 1024,
    'eval_batch_size': 1024
}
datatpl_cfg = {
    'cls': 'CDGKDataTPL',
    'M2C_CDGK_OP': {
        'subgraph_count': 1,
    }
}

run_edustudio(
    dataset='FrcSub',
    cfg_file_name=None,
    traintpl_cfg_dict=traintpl_cfg,
    datatpl_cfg_dict=datatpl_cfg,
    modeltpl_cfg_dict={'cls': 'CDGK_MULTI'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
30 |
--------------------------------------------------------------------------------
/examples/single_model/run_cdmfkc_demo.py:
--------------------------------------------------------------------------------
import sys
import os
import torch

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# CDMFKC on FrcSub with a raised learning rate and a long epoch budget.
run_edustudio(
    dataset='FrcSub',
    cfg_file_name=None,
    traintpl_cfg_dict={
        'cls': 'GeneralTrainTPL',
        'lr': 0.01,
        'epoch_num': 1000
    },
    datatpl_cfg_dict={'cls': 'CDInterExtendsQDataTPL'},
    modeltpl_cfg_dict={'cls': 'CDMFKC'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
28 |
--------------------------------------------------------------------------------
/examples/single_model/run_ckt_demo.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# CKT on ASSIST_0910 with small batch sizes.
run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,
    datatpl_cfg_dict={'cls': 'KTInterExtendsQDataTPL'},
    traintpl_cfg_dict={
        'cls': 'GeneralTrainTPL',
        'batch_size': 32,
        'eval_batch_size': 32
    },
    modeltpl_cfg_dict={'cls': 'CKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
27 |
--------------------------------------------------------------------------------
/examples/single_model/run_cl4kt_demo.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# CL4KT on ASSIST_0910 (uses its dedicated data template).
run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,
    datatpl_cfg_dict={'cls': 'CL4KTDataTPL'},
    traintpl_cfg_dict={
        'cls': 'GeneralTrainTPL',
        'eval_batch_size': 1024,
    },
    modeltpl_cfg_dict={'cls': 'CL4KT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
26 |
--------------------------------------------------------------------------------
/examples/single_model/run_cncd_f_demo.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# CNCD_F on the AAAI_2023 dataset.
run_edustudio(
    dataset='AAAI_2023',
    cfg_file_name=None,
    datatpl_cfg_dict={'cls': 'CNCDFDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'CNCD_F'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
25 |
--------------------------------------------------------------------------------
/examples/single_model/run_cncdq_demo.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# CNCD_Q on the AAAI_2023 dataset.
run_edustudio(
    dataset='AAAI_2023',
    cfg_file_name=None,
    datatpl_cfg_dict={'cls': 'CNCDQDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'CNCD_Q'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
25 |
--------------------------------------------------------------------------------
/examples/single_model/run_ctncm_demo.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# CT_NCM on ASSIST_0910.
run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,
    datatpl_cfg_dict={'cls': 'KTInterCptUnfoldDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'CT_NCM'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
25 |
26 |
--------------------------------------------------------------------------------
/examples/single_model/run_dcd_demo.py:
--------------------------------------------------------------------------------
import os
import sys

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# DCD on FrcSub: 80% of Q-matrix entries are deleted and deliberately left unfilled.
datatpl_cfg = {
    'cls': 'DCDDataTPL',
    'M2C_BuildMissingQ': {
        'Q_delete_ratio': 0.8
    },
    'M2C_FillMissingQ': {
        'Q_fill_type': "None"
    }
}
modeltpl_cfg = {
    'cls': 'DCD',
    'EncoderUserHidden': [768],
    'EncoderItemHidden': [768],
    'lambda_main': 1.0,
    'lambda_q': 10.0,
    'b_sample_type': 'gumbel_softmax',
    'b_sample_kwargs': {'tau': 1.0, 'hard': True},
    'align_margin_loss_kwargs': {'margin': 0.7, 'topk': 2, 'd1': 1, 'margin_lambda': 0.5, 'norm': 2, 'norm_lambda': 0.5, 'start_epoch': 1},
    'beta_user': 0.0,
    'beta_item': 0.0,
    'g_beta_user': 1.0,
    'g_beta_item': 1.0,
    'alpha_user': 0.0,
    'alpha_item': 0.0,
    'gamma_user': 1.0,
    'gamma_item': 1.0,
    'sampling_type': 'mws',
    'bernoulli_prior_p': 0.2,
}
traintpl_cfg = {
    'cls': 'DCDTrainTPL',
    'seed': 2023,
    'epoch_num': 400,
    'lr': 0.0005,
    'num_workers': 0,
    'batch_size': 2048,
    'num_stop_rounds': 10,
    'early_stop_metrics': [('auc', 'max'), ('doa_all', 'max')],
    'best_epoch_metric': 'doa_all',
}
evaltpl_cfg = {
    'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL'],
    'InterpretabilityEvalTPL': {'test_only_metrics': []}
}

run_edustudio(
    dataset='FrcSub',
    datatpl_cfg_dict=datatpl_cfg,
    modeltpl_cfg_dict=modeltpl_cfg,
    traintpl_cfg_dict=traintpl_cfg,
    evaltpl_cfg_dict=evaltpl_cfg
)
55 |
--------------------------------------------------------------------------------
/examples/single_model/run_deepirt_demo.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# DeepIRT on ASSIST_0910.
run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,
    datatpl_cfg_dict={'cls': 'KTInterExtendsQDataTPL'},
    traintpl_cfg_dict={
        'cls': 'GeneralTrainTPL',
        'eval_batch_size': 512,
    },
    modeltpl_cfg_dict={'cls': 'DeepIRT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
26 |
--------------------------------------------------------------------------------
/examples/single_model/run_dimkt_demo.py:
--------------------------------------------------------------------------------
import sys
import os

# Make the repository root importable and run from this script's directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from edustudio.quickstart import run_edustudio

# DIMKT on ASSIST_0910 (uses its dedicated data template).
run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,
    datatpl_cfg_dict={'cls': 'DIMKTDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'DIMKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
25 |
--------------------------------------------------------------------------------
/examples/single_model/run_dina_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate DINA on the FrcSub dataset."""
import sys
import os
import torch  # kept to preserve the original module's import side effects

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='FrcSub',
    cfg_file_name=None,  # all configuration is supplied inline below
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    datatpl_cfg_dict={'cls': 'CDInterExtendsQDataTPL'},
    modeltpl_cfg_dict={'cls': 'DINA'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_dkt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate DKT on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'KTInterCptAsExerDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'DKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_dkt_dsc_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate DKT-DSC on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'DKTDSCDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'DKTDSC'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_dkt_plus_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate DKT+ on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'KTInterCptAsExerDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'DKT_plus'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_dktforget_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate DKTForget on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'DKTForgetDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'DKTForget'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_dkvmn_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate DKVMN on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'KTInterExtendsQDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL', 'eval_batch_size': 256},
    modeltpl_cfg_dict={'cls': 'DKVMN'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_dtransformer_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate DTransformer on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'KTInterCptUnfoldDataTPL'},
    # Small batches: the transformer model is memory-hungry.
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL', 'batch_size': 16, 'eval_batch_size': 16},
    modeltpl_cfg_dict={'cls': 'DTransformer'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_ecd_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate an ECD model on the PISA_2015_ECD dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='PISA_2015_ECD',  # ECD only supports the PISA dataset
    cfg_file_name=None,
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL', 'batch_size': 2048, 'eval_batch_size': 2048},
    datatpl_cfg_dict={'cls': 'ECDDataTPL'},
    modeltpl_cfg_dict={'cls': 'ECD_IRT'},  # variants: ECD_IRT, ECD_MIRT, ECD_NCD
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_eernn_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate EERNNA on the AAAI_2023 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='AAAI_2023',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'EERNNDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'EERNNA'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_ekt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate EKTM on the AAAI_2023 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='AAAI_2023',
    cfg_file_name=None,
    datatpl_cfg_dict={
        'cls': 'EKTDataTPL',
        # Truncate interaction sequences to windows of 50 steps.
        'M2C_BuildSeqInterFeats': {'window_size': 50},
    },
    # Small batches: the model keeps per-exercise text state in memory.
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL', 'batch_size': 8, 'eval_batch_size': 8},
    modeltpl_cfg_dict={'cls': 'EKTM'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_faircd_irt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: adversarial training of FairCD_IRT on SLP_English."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='SLP_English',
    cfg_file_name=None,
    traintpl_cfg_dict={'cls': 'AdversarialTrainTPL', 'batch_size': 1024},
    datatpl_cfg_dict={'cls': 'FAIRDataTPL'},
    modeltpl_cfg_dict={'cls': 'FairCD_IRT'},
    # Evaluate both predictive accuracy and fairness metrics.
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL', 'FairnessEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_faircd_mirt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: adversarial training of FairCD_MIRT on SLP_English."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='SLP_English',
    cfg_file_name=None,
    traintpl_cfg_dict={'cls': 'AdversarialTrainTPL', 'batch_size': 1024},
    datatpl_cfg_dict={'cls': 'FAIRDataTPL'},
    modeltpl_cfg_dict={'cls': 'FairCD_MIRT'},
    # Evaluate both predictive accuracy and fairness metrics.
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL', 'FairnessEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_faircd_ncdm_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: adversarial training of FairCD_NCDM on SLP_English."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='SLP_English',
    cfg_file_name=None,
    traintpl_cfg_dict={'cls': 'AdversarialTrainTPL', 'batch_size': 1024},
    datatpl_cfg_dict={'cls': 'FAIRDataTPL'},
    modeltpl_cfg_dict={'cls': 'FairCD_NCDM'},
    # Evaluate both predictive accuracy and fairness metrics.
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL', 'FairnessEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_gkt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate GKT on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'GKTDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'GKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_hawkeskt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate HawkesKT on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'KTInterCptUnfoldDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'HawkesKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_hiercdf_demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
5 | os.chdir(os.path.dirname(os.path.abspath(__file__)))
6 |
7 | from edustudio.quickstart import run_edustudio
8 | from edustudio.atom_op.mid2cache.CD import M2C_FilterRecords4CD
9 |
10 | import networkx as nx
11 | import pandas as pd
12 |
class M2C_MyFilterRecords4CD(M2C_FilterRecords4CD):
    """Record filter that first prunes the concept-relation graph to a DAG.

    Extends ``M2C_FilterRecords4CD``: before the standard record filtering,
    the concept relation edges are reduced to an acyclic subset and the
    interaction records are restricted to exercises that appear in that
    pruned graph. After the parent filter runs, relation edges whose
    endpoints were filtered out are dropped as well.
    """

    @staticmethod
    def build_DAG(edges):
        """Greedily add *edges* to a digraph, dropping any edge that closes a cycle.

        Returns the edge view of the resulting directed acyclic graph; the
        result depends on the iteration order of *edges* (earlier edges win).
        """
        graph = nx.DiGraph()
        for edge in edges:
            graph.add_edge(*edge)
            if not nx.is_directed_acyclic_graph(graph):
                # This edge introduced a cycle: back it out and continue.
                graph.remove_edge(*edge)
        return graph.edges

    def process(self, **kwargs):
        """Prune the relation graph, filter records, and return updated kwargs.

        Expects ``kwargs`` to carry ``df`` (interaction records) and
        ``df_cpt_relation`` (edges with ``cpt_head:token``/``cpt_tail:token``).
        """
        # NOTE: the original code first read kwargs['df_exer'] into
        # selected_items, but that value was overwritten before use; the dead
        # assignment has been removed.
        edges = kwargs['df_cpt_relation'][['cpt_head:token', 'cpt_tail:token']].to_numpy()
        edges = list(self.build_DAG(edges))
        kwargs['df_cpt_relation'] = pd.DataFrame({"cpt_head:token": [i[0] for i in edges], "cpt_tail:token": [i[1] for i in edges]})
        # Keep only interactions whose exercise appears in the pruned DAG
        # (here concepts double as exercises, hence the flattened edge list).
        selected_items = kwargs['df_cpt_relation'].to_numpy().flatten()
        df = kwargs['df']
        df = df[df['exer_id:token'].isin(selected_items)].reset_index(drop=True)
        kwargs['df'] = df

        # Apply the standard filtering (e.g. stu_least_records) from the parent op.
        kwargs = super().process(**kwargs)

        # Drop relation edges that reference exercises removed by the parent filter.
        selected_items = kwargs['df']['exer_id:token'].unique()
        kwargs['df_cpt_relation'] = kwargs['df_cpt_relation'][kwargs['df_cpt_relation']['cpt_head:token'].isin(selected_items)]
        kwargs['df_cpt_relation'] = kwargs['df_cpt_relation'][kwargs['df_cpt_relation']['cpt_tail:token'].isin(selected_items)]
        kwargs['df_cpt_relation'] = kwargs['df_cpt_relation'].reset_index(drop=True)
        return kwargs
42 |
43 |
# Launch HierCDF on the JunyiExerAsCpt dataset, wiring the custom filter op
# (M2C_MyFilterRecords4CD, defined above) into the mid2cache pipeline.
run_edustudio(
    dataset='JunyiExerAsCpt',
    cfg_file_name=None,
    traintpl_cfg_dict={
        'cls': 'GeneralTrainTPL',
        'batch_size': 2048,
        'eval_batch_size': 2048,
    },
    datatpl_cfg_dict={
        'cls': 'HierCDFDataTPL',
        # To rebuild caches from raw data, uncomment:
        # 'load_data_from': 'rawdata',
        # 'raw2mid_op': 'R2M_JunyiExerAsCpt',
        'mid2cache_op_seq': [
            'M2C_Label2Int',
            M2C_MyFilterRecords4CD,  # custom op passed as a class, not a name
            'M2C_ReMapId',
            'M2C_RandomDataSplit4CD',
            'M2C_GenQMat',
        ],
        'M2C_ReMapId': {
            # These columns share one id space and must be remapped together.
            'share_id_columns': [{'cpt_seq:token_seq', 'cpt_head:token', 'cpt_tail:token', 'exer_id:token'}],
        },
        'M2C_MyFilterRecords4CD': {"stu_least_records": 60},
    },
    modeltpl_cfg_dict={'cls': 'HierCDF'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
71 |
--------------------------------------------------------------------------------
/examples/single_model/run_iekt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate IEKT on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'KTInterExtendsQDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'IEKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_irr_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate IRR on the FrcSub dataset."""
import sys
import os
import torch  # kept to preserve the original module's import side effects

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='FrcSub',
    cfg_file_name=None,
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL', 'lr': 0.0001},
    datatpl_cfg_dict={'cls': 'IRRDataTPL'},
    modeltpl_cfg_dict={'cls': 'IRR'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_irt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate IRT on the FrcSub dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='FrcSub',
    cfg_file_name=None,  # all configuration is supplied inline below
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    datatpl_cfg_dict={'cls': 'CDInterDataTPL'},
    modeltpl_cfg_dict={'cls': 'IRT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_kancd_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate KaNCD on the FrcSub dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='FrcSub',
    cfg_file_name=None,  # all configuration is supplied inline below
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    datatpl_cfg_dict={'cls': 'CDInterExtendsQDataTPL'},
    modeltpl_cfg_dict={'cls': 'KaNCD'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_kqn_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate KQN on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'KTInterCptAsExerDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'KQN'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_kscd_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate KSCD on the FrcSub dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='FrcSub',
    cfg_file_name=None,  # all configuration is supplied inline below
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    datatpl_cfg_dict={'cls': 'CDInterExtendsQDataTPL'},
    modeltpl_cfg_dict={'cls': 'KSCD'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_lpkt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate LPKT on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'LPKTDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'LPKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_lpkt_s_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate LPKT_S on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'LPKTDataTPL'},  # LPKT_S shares LPKT's data template
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'LPKT_S'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_mf_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate MF on the FrcSub dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='FrcSub',
    cfg_file_name=None,  # all configuration is supplied inline below
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    datatpl_cfg_dict={'cls': 'CDInterDataTPL'},
    modeltpl_cfg_dict={'cls': 'MF'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_mgcd_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: group-level cognitive diagnosis with MGCD on ASSIST_0910."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,
    traintpl_cfg_dict={
        'cls': 'GroupCDTrainTPL',
        # MGCD is scored by regression error, so select/stop on RMSE (lower is better).
        'early_stop_metrics': [('rmse', 'min')],
        'best_epoch_metric': 'rmse',
        'batch_size': 512,
    },
    datatpl_cfg_dict={
        'cls': 'MGCDDataTPL',
        # To rebuild caches from raw data, uncomment:
        # 'load_data_from': 'rawdata',
        # 'raw2mid_op': 'R2M_ASSIST_0910'
    },
    modeltpl_cfg_dict={'cls': 'MGCD'},
    evaltpl_cfg_dict={
        'clses': ['PredictionEvalTPL'],
        'PredictionEvalTPL': {'use_metrics': ['rmse']},
    },
)
--------------------------------------------------------------------------------
/examples/single_model/run_mirt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate MIRT on the FrcSub dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='FrcSub',
    cfg_file_name=None,  # all configuration is supplied inline below
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    datatpl_cfg_dict={'cls': 'CDInterDataTPL'},
    modeltpl_cfg_dict={'cls': 'MIRT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_ncdm_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate NCDM on the SLP-Math dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='SLP-Math',
    cfg_file_name=None,  # all configuration is supplied inline below
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    datatpl_cfg_dict={'cls': 'CDInterExtendsQDataTPL'},
    modeltpl_cfg_dict={'cls': 'NCDM'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_qdkt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate QDKT on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'QDKTDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'QDKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_qikt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate QIKT on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,
    datatpl_cfg_dict={'cls': 'KTInterExtendsQDataTPL'},
    # Small batches: QIKT's multi-head objectives are memory-hungry.
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL', 'batch_size': 16, 'eval_batch_size': 16},
    modeltpl_cfg_dict={'cls': 'QIKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_rcd_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate RCD on the FrcSub dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='FrcSub',
    cfg_file_name=None,  # all configuration is supplied inline below
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    datatpl_cfg_dict={'cls': 'RCDDataTPL'},
    modeltpl_cfg_dict={'cls': 'RCD'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_rkt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate RKT on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'RKTDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'RKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_saint_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate SAINT on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'KTInterCptUnfoldDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'SAINT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_saint_plus_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate SAINT+ on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'KTInterCptUnfoldDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'SAINT_plus'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_sakt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate SAKT on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'KTInterCptAsExerDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'SAKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_simplekt_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate SimpleKT on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'KTInterCptUnfoldDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'SimpleKT'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/examples/single_model/run_skvmn_demo.py:
--------------------------------------------------------------------------------
"""Quick-start demo: train and evaluate SKVMN on the ASSIST_0910 dataset."""
import sys
import os

# Make the repository root importable and work from this script's folder.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE + "/../../")
os.chdir(_HERE)

from edustudio.quickstart import run_edustudio

run_edustudio(
    dataset='ASSIST_0910',
    cfg_file_name=None,  # all configuration is supplied inline below
    datatpl_cfg_dict={'cls': 'KTInterCptAsExerDataTPL'},
    traintpl_cfg_dict={'cls': 'GeneralTrainTPL'},
    modeltpl_cfg_dict={'cls': 'SKVMN'},
    evaltpl_cfg_dict={'clses': ['PredictionEvalTPL']},
)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.10.0
2 | numpy>=1.17.2
3 | scipy>=1.6.0
4 | pandas>=1.0.5
5 | tqdm>=4.48.2
6 | scikit_learn>=0.23.2
7 | pyyaml>=5.1.0
8 | tensorboard>=2.5.0
9 | requests>=2.27.1
10 | pytz>=2022.1
11 | matplotlib>=3.5.1
12 | deepdiff>=6.3.1
13 | networkx>=2.8
14 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 | import os
6 |
7 | from setuptools import setup, find_packages
8 |
9 | install_requires = [
10 | "torch>=1.10.0",
11 | "numpy>=1.17.2",
12 | "scipy>=1.6.0",
13 | "pandas>=1.0.5",
14 | "tqdm>=4.48.2",
15 | "scikit_learn>=0.23.2",
16 | "pyyaml>=5.1.0",
17 | "tensorboard>=2.5.0",
18 | "requests>=2.27.1",
19 | "pytz>=2022.1",
20 | "matplotlib>=3.5.1",
21 | "deepdiff>=6.3.1",
22 | "networkx>=2.8"
23 | ]
24 |
25 | setup_requires = []
26 |
27 | extras_require = {}
28 |
29 | classifiers = ["License :: OSI Approved :: MIT License"]
30 |
31 | long_description = (
32 | "EduStudio is a Unified and Templatized Framework "
33 | "for Student Assessment Models including "
34 | "Cognitive Diagnosis(CD) and Knowledge Tracing(KT) based on Pytorch."
35 | )
36 |
37 | # Readthedocs requires Sphinx extensions to be specified as part of
38 | # install_requires in order to build properly.
39 | on_rtd = os.environ.get("READTHEDOCS", None) == "True"
40 | if on_rtd:
41 | install_requires.extend(setup_requires)
42 |
43 | setup(
44 | name="edustudio",
45 | version="v1.1.7",
46 | description="a Unified and Templatized Framework for Student Assessment Models",
47 | long_description=long_description,
48 | python_requires='>=3.8',
49 | long_description_content_type="text/markdown",
50 | url="https://github.com/HFUT-LEC/EduStudio",
51 | author="HFUT-LEC",
52 | author_email="lmcRS.hfut@gmail.com",
53 | packages=[package for package in find_packages() if package.startswith("edustudio")],
54 | include_package_data=True,
55 | install_requires=install_requires,
56 | setup_requires=setup_requires,
57 | extras_require=extras_require,
58 | zip_safe=False,
59 | classifiers=classifiers,
60 | entry_points={
61 | "console_scripts": [
62 | "edustudio = edustudio.quickstart.atom_cmds:entrypoint",
63 | ],
64 | }
65 | )
66 |
--------------------------------------------------------------------------------
/tests/test_run.py:
--------------------------------------------------------------------------------
from edustudio.quickstart import run_edustudio


class TestRun:
    """Smoke tests that drive a full (tiny) EduStudio run end to end."""

    def test_cd(self):
        """Run 2 CPU epochs of KaNCD on FrcSub through the whole pipeline."""
        print("---------- test cd ------------")
        train_cfg = {
            'cls': 'GeneralTrainTPL',
            'epoch_num': 2,
            'device': 'cpu',
        }
        data_cfg = {'cls': 'CDInterExtendsQDataTPL'}
        model_cfg = {'cls': 'KaNCD'}
        eval_cfg = {'clses': ['PredictionEvalTPL', 'InterpretabilityEvalTPL']}
        run_edustudio(
            dataset='FrcSub',
            cfg_file_name=None,
            traintpl_cfg_dict=train_cfg,
            datatpl_cfg_dict=data_cfg,
            modeltpl_cfg_dict=model_cfg,
            evaltpl_cfg_dict=eval_cfg,
        )
24 |
--------------------------------------------------------------------------------