├── .github └── workflows │ ├── publish_pypi.yml │ ├── ruff.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml ├── setup.py ├── src └── fsrs_optimizer │ ├── __init__.py │ ├── __main__.py │ ├── fsrs_optimizer.py │ └── fsrs_simulator.py └── tests ├── __init__.py ├── model_test.py └── simulator_test.py /.github/workflows/publish_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Pypi 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | jobs: 9 | publish: 10 | runs-on: ubuntu-latest 11 | environment: release 12 | permissions: 13 | # IMPORTANT: this permission is mandatory for trusted publishing 14 | id-token: write 15 | steps: 16 | - uses: actions/checkout@v3 17 | - uses: actions/setup-python@v4 18 | with: 19 | python-version: '3.8' 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install build 24 | - name: Build package 25 | run: python -m build 26 | - name: pypi-publish 27 | uses: pypa/gh-action-pypi-publish@v1.8.8 28 | -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: Ruff 2 | on: [ push, pull_request ] 3 | jobs: 4 | ruff: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v4 8 | - uses: chartboost/ruff-action@v1 9 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test Python 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | test: 7 | 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 13 | uses: actions/setup-python@v3 14 | with: 15 | python-version: '3.12' 16 | 17 | - name: Install uv 18 | uses: astral-sh/setup-uv@v5 19 | with: 20 | enable-cache: true 21 | 22 | - name: Install the project 23 | run: uv sync --all-extras --dev 24 | 25 | - name: Run tests 26 | # For example, using `pytest` 27 | run: uv run pytest tests 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # IDE 132 | .idea 133 | .vscode 134 | 135 | # Mac 136 | .DS_Store 137 | 138 | # Anki files 139 | media 140 | meta 141 | *.anki2 142 | *.anki21 143 | *.apkg 144 | *.colpkg 145 | 146 | # Output files 147 | *.csv 148 | *.tsv 149 | *.png 150 | *.json 151 | .fsrs_optimizer 152 | *.ipynb 153 | *.zip 154 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Open Spaced Repetition 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FSRS Optimizer 2 | 3 | [![PyPi](https://img.shields.io/pypi/v/FSRS-Optimizer)](https://pypi.org/project/FSRS-Optimizer/) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 4 | 5 | The FSRS Optimizer is a Python library capable of utilizing personal spaced repetition review logs to refine the FSRS algorithm. Designed with the intent of delivering a standardized, universal optimizer to various FSRS implementations across numerous programming languages, this tool is set to establish a ubiquitous standard for spaced repetition review logs. By facilitating the uniformity of learning data among different spaced repetition softwares, it guarantees learners consistent review schedules across a multitude of platforms. 6 | 7 | Delve into the underlying principles of the FSRS Optimizer's training process at: https://github.com/open-spaced-repetition/fsrs4anki/wiki/The-mechanism-of-optimization 8 | 9 | Explore the mathematical formula of the FSRS model at: https://github.com/open-spaced-repetition/fsrs4anki/wiki/The-Algorithm 10 | 11 | # Review Logs Schema 12 | 13 | The `review_logs` table captures the review activities performed by users. Each log records the details of a single review instance. The schema for this table is as follows: 14 | 15 | | Column Name | Data Type | Description | Constraints | 16 | |-------------|-----------|-------------|-------------| 17 | | card_id | integer or string | The unique identifier of the flashcard being reviewed | Not null | 18 | | review_time | timestamp in *miliseconds* | The exact moment when the review took place | Not null | 19 | | review_rating | integer | The user's rating for the review. This rating is subjective and depends on how well the user believes they remembered the information on the card | Not null, Values: {1 (Again), 2 (Hard), 3 (Good), 4 (Easy)} | 20 | | review_state | integer | The state of the card at the time of review. This describes the learning phase of the card | Optional, Values: {0 (New), 1 (Learning), 2 (Review), 3 (Relearning)} | 21 | | review_duration | integer | The time spent on reviewing the card, typically in miliseconds | Optional, Non-negative | 22 | 23 | Extra Info: 24 | - `timezone`: The time zone of the user when they performed the review, which is used to identify the start of a new day. 25 | - `day_start`: The hour (0-23) at which the user starts a new day, which is used to separate reviews that are divided by sleep into different days. 26 | 27 | Notes: 28 | - All timestamp fields are expected to be in UTC. 29 | - The `card_id` should correspond to a valid card in the corresponding flashcards dataset. 30 | - `review_rating` should be a reflection of the user's memory of the card at the time of the review. 31 | - `review_state` helps to understand the learning progress of the card. 32 | - `review_duration` measures the cost of the review. 33 | - `timezone` should be a string from the IANA Time Zone Database (e.g., "America/New_York"). For more information, refer to this [list of IANA time zones](https://gist.github.com/heyalexej/8bf688fd67d7199be4a1682b3eec7568). 34 | - `day_start` determines the start of the learner's day and is used to correctly assign reviews to days, especially when reviews are divided by sleep. 35 | 36 | Please ensure your data conforms to this schema for optimal compatibility with the optimization process. 37 | 38 | # Optimize FSRS with your review logs 39 | 40 | **Installation** 41 | 42 | Install the package with the command: 43 | 44 | ``` 45 | python -m pip install fsrs-optimizer 46 | ``` 47 | 48 | You should upgrade regularly to make sure you have the most recent version of FSRS-Optimizer: 49 | 50 | ``` 51 | python -m pip install fsrs-optimizer --upgrade 52 | ``` 53 | 54 | **Opimization** 55 | 56 | If you have a file named `revlog.csv` with the above schema, you can run: 57 | 58 | ``` 59 | python -m fsrs_optimizer "revlog.csv" 60 | ``` 61 | 62 | **Expected Functionality** 63 | 64 | ![image](https://github.com/open-spaced-repetition/fsrs-optimizer/assets/32575846/fad7154a-9667-4eea-b868-d94c94a50912) 65 | 66 | ![image](https://github.com/open-spaced-repetition/fsrs-optimizer/assets/32575846/f868aac4-2e9e-4101-b8ad-eccc1d9b1bd5) 67 | 68 | --- 69 | 70 | ## Alternative 71 | 72 | Are you getting tired of installing torch? Try [fsrs-rs-python](https://github.com/open-spaced-repetition/fsrs-rs-python)! 73 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "FSRS-Optimizer" 7 | version = "6.1.2" 8 | readme = "README.md" 9 | dependencies = [ 10 | "matplotlib>=3.7.0", 11 | "numpy>=1.22.4", 12 | "pandas>=1.5.3", 13 | "pytz>=2022.7.1", 14 | "scikit_learn>=1.4.0", 15 | "torch>=1.13.1", 16 | "tqdm>=4.64.1", 17 | "statsmodels>=0.13.5", 18 | "scipy<1.14.1" 19 | ] 20 | requires-python = ">=3.9,<3.13" 21 | 22 | [project.urls] 23 | Homepage = "https://github.com/open-spaced-repetition/fsrs-optimizer" 24 | [tool.ruff.lint] 25 | ignore = ["F405", "F403", "E712", "F541", "E722", "E741"] 26 | [project.optional-dependencies] 27 | test = [ 28 | "ruff", 29 | "mypy", 30 | "pytest", 31 | "pytest-cov", 32 | ] 33 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /src/fsrs_optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .fsrs_optimizer import * 2 | -------------------------------------------------------------------------------- /src/fsrs_optimizer/__main__.py: -------------------------------------------------------------------------------- 1 | import fsrs_optimizer 2 | import argparse 3 | import shutil 4 | import json 5 | import pytz 6 | import os 7 | import functools 8 | from pathlib import Path 9 | 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | def prompt(msg: str, fallback): 14 | default = "" 15 | if fallback: 16 | default = f"(default: {fallback})" 17 | 18 | response = input(f"{msg} {default}: ") 19 | if response == "": 20 | if fallback is not None: 21 | return fallback 22 | else: # If there is no fallback 23 | raise Exception("You failed to enter a required parameter") 24 | return response 25 | 26 | 27 | def process(filepath, filter_out_flags: list[int]): 28 | suffix = filepath.split("/")[-1].replace(".", "_").replace("@", "_") 29 | proj_dir = Path(f"{suffix}") 30 | proj_dir.mkdir(parents=True, exist_ok=True) 31 | os.chdir(proj_dir) 32 | 33 | try: # Try and remember the last values inputted. 34 | with open(config_save, "r") as f: 35 | remembered_fallbacks = json.load(f) 36 | except FileNotFoundError: 37 | remembered_fallbacks = { # Defaults to this if not there 38 | "timezone": None, # Timezone starts with no default 39 | "next_day": 4, 40 | "revlog_start_date": "2006-10-05", 41 | "preview": "y", 42 | "filter_out_suspended_cards": "n", 43 | "enable_short_term": "y", 44 | } 45 | 46 | # Prompts the user with the key and then falls back on the last answer given. 47 | def remembered_fallback_prompt(key: str, pretty: str = None): 48 | if pretty is None: 49 | pretty = key 50 | remembered_fallbacks[key] = prompt( 51 | f"input {pretty}", remembered_fallbacks.get(key, None) 52 | ) 53 | 54 | print("The defaults will switch to whatever you entered last.\n") 55 | 56 | if not args.yes: 57 | print( 58 | "Timezone list: https://gist.github.com/heyalexej/8bf688fd67d7199be4a1682b3eec7568" 59 | ) 60 | remembered_fallback_prompt("timezone", "used timezone") 61 | if remembered_fallbacks["timezone"] not in pytz.all_timezones: 62 | raise Exception("Not a valid timezone, Check the list for more information") 63 | 64 | remembered_fallback_prompt("next_day", "used next day start hour") 65 | remembered_fallback_prompt( 66 | "revlog_start_date", 67 | "the date at which before reviews will be ignored | YYYY-MM-DD", 68 | ) 69 | remembered_fallback_prompt( 70 | "filter_out_suspended_cards", "filter out suspended cards? (y/n)" 71 | ) 72 | remembered_fallback_prompt( 73 | "enable_short_term", "enable short-term component in FSRS model? (y/n)" 74 | ) 75 | 76 | graphs_input = prompt("Save graphs? (y/n)", remembered_fallbacks["preview"]) 77 | else: 78 | graphs_input = remembered_fallbacks["preview"] 79 | 80 | if graphs_input.lower() != "y": 81 | remembered_fallbacks["preview"] = "n" 82 | else: 83 | remembered_fallbacks["preview"] = "y" 84 | 85 | with open( 86 | config_save, "w+" 87 | ) as f: # Save the settings to load next time the program is run 88 | json.dump(remembered_fallbacks, f) 89 | 90 | save_graphs = graphs_input != "n" 91 | enable_short_term = remembered_fallbacks["enable_short_term"] == "y" 92 | 93 | optimizer = fsrs_optimizer.Optimizer(enable_short_term=enable_short_term) 94 | if filepath.endswith(".apkg") or filepath.endswith(".colpkg"): 95 | optimizer.anki_extract( 96 | f"{filepath}", 97 | remembered_fallbacks["filter_out_suspended_cards"] == "y", 98 | filter_out_flags, 99 | ) 100 | else: 101 | # copy the file to the current directory and rename it as revlog.csv 102 | shutil.copyfile(f"{filepath}", "revlog.csv") 103 | analysis = optimizer.create_time_series( 104 | remembered_fallbacks["timezone"], 105 | remembered_fallbacks["revlog_start_date"], 106 | remembered_fallbacks["next_day"], 107 | save_graphs, 108 | ) 109 | print(analysis) 110 | 111 | filename = os.path.splitext(os.path.basename(filepath))[0] 112 | 113 | optimizer.define_model() 114 | figures = optimizer.pretrain(verbose=save_graphs) 115 | for i, f in enumerate(figures): 116 | f.savefig(f"pretrain_{i}.png") 117 | plt.close(f) 118 | figures = optimizer.train(verbose=save_graphs, recency_weight=True) 119 | for i, f in enumerate(figures): 120 | f.savefig(f"train_{i}.png") 121 | plt.close(f) 122 | 123 | optimizer.predict_memory_states() 124 | try: 125 | figures = optimizer.find_optimal_retention(verbose=save_graphs) 126 | for i, f in enumerate(figures): 127 | f.savefig(f"find_optimal_retention_{i}.png") 128 | plt.close(f) 129 | except Exception as e: 130 | print(e) 131 | print("Failed to find optimal retention") 132 | optimizer.optimal_retention = 0.9 133 | 134 | print(optimizer.preview(optimizer.optimal_retention)) 135 | 136 | profile = f"""{{ 137 | // Generated, Optimized anki deck settings 138 | "deckName": "{filename}",// PLEASE CHANGE THIS TO THE DECKS PROPER NAME 139 | "w": {optimizer.w}, 140 | "requestRetention": {optimizer.optimal_retention}, 141 | "maximumInterval": 36500, 142 | }}, 143 | """ 144 | 145 | print("Paste this into your scheduling code") 146 | print(profile) 147 | 148 | if args.out: 149 | with open(args.out, "a+") as f: 150 | f.write(profile) 151 | 152 | loss_before, loss_after = optimizer.evaluate() 153 | print(f"Loss before training: {loss_before:.4f}") 154 | print(f"Loss after training: {loss_after:.4f}") 155 | metrics, figures = optimizer.calibration_graph(verbose=False) 156 | for partition in metrics: 157 | print(f"Last rating = {partition}") 158 | for metric in metrics[partition]: 159 | print(f"{metric}: {metrics[partition][metric]:.4f}") 160 | print() 161 | 162 | metrics["Log loss"] = loss_after 163 | if save_graphs: 164 | for i, f in enumerate(figures): 165 | f.savefig(f"calibration_{i}.png") 166 | plt.close(f) 167 | figures = optimizer.formula_analysis() 168 | if save_graphs: 169 | for i, f in enumerate(figures): 170 | f.savefig(f"formula_analysis_{i}.png") 171 | plt.close(f) 172 | figures = optimizer.compare_with_sm2() 173 | if save_graphs: 174 | for i, f in enumerate(figures): 175 | f.savefig(f"compare_with_sm2_{i}.png") 176 | plt.close(f) 177 | 178 | evaluation = { 179 | "filename": filename, 180 | "size": optimizer.dataset.shape[0], 181 | "parameters": optimizer.w, 182 | "metrics": metrics, 183 | } 184 | 185 | with open("evaluation.json", "w+") as f: 186 | json.dump(evaluation, f) 187 | 188 | 189 | def create_arg_parser(): 190 | parser = argparse.ArgumentParser() 191 | 192 | parser.add_argument("filenames", nargs="+") 193 | parser.add_argument( 194 | "-y", 195 | "--yes", 196 | action=argparse.BooleanOptionalAction, 197 | help="If set automatically defaults on all stdin settings.", 198 | ) 199 | parser.add_argument( 200 | "--flags", 201 | help="Remove any cards with the given flags from the training set.", 202 | default=[], 203 | nargs="+", 204 | ) 205 | parser.add_argument( 206 | "-o", "--out", help="File to APPEND the automatically generated profile to." 207 | ) 208 | 209 | return parser 210 | 211 | 212 | if __name__ == "__main__": 213 | config_save = os.path.expanduser(".fsrs_optimizer") 214 | 215 | parser = create_arg_parser() 216 | args = parser.parse_args() 217 | 218 | def lift(file_or_dir): 219 | return os.listdir(file_or_dir) if os.path.isdir(file_or_dir) else [file_or_dir] 220 | 221 | def flatten(fl): 222 | return sum(fl, []) 223 | 224 | def mapC(f): 225 | return lambda x: map(f, x) 226 | 227 | def filterC(f): 228 | return lambda x: filter(f, x) 229 | 230 | def pipe(functions, value): 231 | return functools.reduce(lambda out, f: f(out), functions, value) 232 | 233 | curdir = os.getcwd() 234 | 235 | files = pipe( 236 | [ 237 | mapC(lift), # map file to [ file ], dir to [ file1, file2, ... ] 238 | flatten, # flatten into [ file1, file2, ... ] 239 | mapC(os.path.abspath), # map to absolute path 240 | filterC(lambda f: not os.path.isdir(f)), # file filter 241 | filterC( 242 | lambda f: f.lower().endswith(".apkg") 243 | or f.lower().endswith(".colpkg") 244 | or f.lower().endswith(".csv") 245 | ), # extension filter 246 | ], 247 | args.filenames, 248 | ) 249 | 250 | for filename in files: 251 | try: 252 | print(f"Processing {filename}") 253 | process(filename, args.flags) 254 | except Exception as e: 255 | print(e) 256 | print(f"Failed to process {filename}") 257 | finally: 258 | plt.close("all") 259 | os.chdir(curdir) 260 | continue 261 | -------------------------------------------------------------------------------- /src/fsrs_optimizer/fsrs_optimizer.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import sqlite3 3 | import time 4 | import pandas as pd 5 | import numpy as np 6 | import os 7 | import math 8 | from typing import List, Optional, Tuple 9 | from datetime import timedelta, datetime 10 | from collections import defaultdict 11 | import statsmodels.api as sm # type: ignore 12 | from statsmodels.nonparametric.smoothers_lowess import lowess # type: ignore 13 | import matplotlib.pyplot as plt 14 | import matplotlib.ticker as ticker 15 | import torch 16 | from torch import nn 17 | from torch import Tensor 18 | from torch.utils.data import Dataset 19 | from torch.nn.utils.rnn import pad_sequence 20 | from sklearn.model_selection import TimeSeriesSplit # type: ignore 21 | from sklearn.metrics import ( # type: ignore 22 | log_loss, 23 | root_mean_squared_error, 24 | mean_absolute_error, 25 | mean_absolute_percentage_error, 26 | r2_score, 27 | roc_auc_score, 28 | ) 29 | from scipy.optimize import minimize # type: ignore 30 | from itertools import accumulate 31 | from tqdm.auto import tqdm # type: ignore 32 | import warnings 33 | 34 | try: 35 | from .fsrs_simulator import * 36 | except ImportError: 37 | from fsrs_simulator import * # type: ignore 38 | 39 | warnings.filterwarnings("ignore", category=UserWarning) 40 | 41 | New = 0 42 | Learning = 1 43 | Review = 2 44 | Relearning = 3 45 | 46 | DEFAULT_PARAMETER = [ 47 | 0.212, 48 | 1.2931, 49 | 2.3065, 50 | 8.2956, 51 | 6.4133, 52 | 0.8334, 53 | 3.0194, 54 | 0.001, 55 | 1.8722, 56 | 0.1666, 57 | 0.796, 58 | 1.4835, 59 | 0.0614, 60 | 0.2629, 61 | 1.6483, 62 | 0.6014, 63 | 1.8729, 64 | 0.5425, 65 | 0.0912, 66 | 0.0658, 67 | 0.1542, 68 | ] 69 | 70 | DEFAULT_PARAMS_STDDEV_TENSOR = torch.tensor( 71 | [ 72 | 6.43, 73 | 9.66, 74 | 17.58, 75 | 27.85, 76 | 0.57, 77 | 0.28, 78 | 0.6, 79 | 0.12, 80 | 0.39, 81 | 0.18, 82 | 0.33, 83 | 0.3, 84 | 0.09, 85 | 0.16, 86 | 0.57, 87 | 0.25, 88 | 1.03, 89 | 0.31, 90 | 0.32, 91 | 0.14, 92 | 0.27, 93 | ], 94 | dtype=torch.float, 95 | ) 96 | 97 | 98 | class FSRS(nn.Module): 99 | def __init__(self, w: List[float], float_delta_t: bool = False): 100 | super(FSRS, self).__init__() 101 | self.w = nn.Parameter(torch.tensor(w, dtype=torch.float32)) 102 | self.float_delta_t = float_delta_t 103 | 104 | def stability_after_success( 105 | self, state: Tensor, r: Tensor, rating: Tensor 106 | ) -> Tensor: 107 | hard_penalty = torch.where(rating == 2, self.w[15], 1) 108 | easy_bonus = torch.where(rating == 4, self.w[16], 1) 109 | new_s = state[:, 0] * ( 110 | 1 111 | + torch.exp(self.w[8]) 112 | * (11 - state[:, 1]) 113 | * torch.pow(state[:, 0], -self.w[9]) 114 | * (torch.exp((1 - r) * self.w[10]) - 1) 115 | * hard_penalty 116 | * easy_bonus 117 | ) 118 | return new_s 119 | 120 | def stability_after_failure(self, state: Tensor, r: Tensor) -> Tensor: 121 | old_s = state[:, 0] 122 | new_s = ( 123 | self.w[11] 124 | * torch.pow(state[:, 1], -self.w[12]) 125 | * (torch.pow(old_s + 1, self.w[13]) - 1) 126 | * torch.exp((1 - r) * self.w[14]) 127 | ) 128 | new_minimum_s = old_s / torch.exp(self.w[17] * self.w[18]) 129 | return torch.minimum(new_s, new_minimum_s) 130 | 131 | def stability_short_term(self, state: Tensor, rating: Tensor) -> Tensor: 132 | sinc = torch.exp(self.w[17] * (rating - 3 + self.w[18])) * torch.pow( 133 | state[:, 0], -self.w[19] 134 | ) 135 | new_s = state[:, 0] * torch.where(rating >= 3, sinc.clamp(min=1), sinc) 136 | return new_s 137 | 138 | def init_d(self, rating: Tensor) -> Tensor: 139 | new_d = self.w[4] - torch.exp(self.w[5] * (rating - 1)) + 1 140 | return new_d 141 | 142 | def linear_damping(self, delta_d: Tensor, old_d: Tensor) -> Tensor: 143 | return delta_d * (10 - old_d) / 9 144 | 145 | def next_d(self, state: Tensor, rating: Tensor) -> Tensor: 146 | delta_d = -self.w[6] * (rating - 3) 147 | new_d = state[:, 1] + self.linear_damping(delta_d, state[:, 1]) 148 | new_d = self.mean_reversion(self.init_d(4), new_d) 149 | return new_d 150 | 151 | def step(self, X: Tensor, state: Tensor) -> Tensor: 152 | """ 153 | :param X: shape[batch_size, 2], X[:,0] is elapsed time, X[:,1] is rating 154 | :param state: shape[batch_size, 2], state[:,0] is stability, state[:,1] is difficulty 155 | :return state: 156 | """ 157 | if torch.equal(state, torch.zeros_like(state)): 158 | keys = torch.tensor([1, 2, 3, 4]) 159 | keys = keys.view(1, -1).expand(X[:, 1].long().size(0), -1) 160 | index = (X[:, 1].long().unsqueeze(1) == keys).nonzero(as_tuple=True) 161 | # first learn, init memory states 162 | new_s = torch.ones_like(state[:, 0]) 163 | new_s[index[0]] = self.w[index[1]] 164 | new_d = self.init_d(X[:, 1]) 165 | new_d = new_d.clamp(1, 10) 166 | else: 167 | r = power_forgetting_curve(X[:, 0], state[:, 0], -self.w[20]) 168 | short_term = X[:, 0] < 1 169 | success = X[:, 1] > 1 170 | new_s = ( 171 | torch.where( 172 | short_term, 173 | self.stability_short_term(state, X[:, 1]), 174 | torch.where( 175 | success, 176 | self.stability_after_success(state, r, X[:, 1]), 177 | self.stability_after_failure(state, r), 178 | ), 179 | ) 180 | if not self.float_delta_t 181 | else torch.where( 182 | success, 183 | self.stability_after_success(state, r, X[:, 1]), 184 | self.stability_after_failure(state, r), 185 | ) 186 | ) 187 | new_d = self.next_d(state, X[:, 1]) 188 | new_d = new_d.clamp(1, 10) 189 | new_s = new_s.clamp(S_MIN, 36500) 190 | return torch.stack([new_s, new_d], dim=1) 191 | 192 | def forward( 193 | self, inputs: Tensor, state: Optional[Tensor] = None 194 | ) -> Tuple[Tensor, Tensor]: 195 | """ 196 | :param inputs: shape[seq_len, batch_size, 2] 197 | """ 198 | if state is None: 199 | state = torch.zeros((inputs.shape[1], 2)) 200 | outputs = [] 201 | for X in inputs: 202 | state = self.step(X, state) 203 | outputs.append(state) 204 | return torch.stack(outputs), state 205 | 206 | def mean_reversion(self, init: Tensor, current: Tensor) -> Tensor: 207 | return self.w[7] * init + (1 - self.w[7]) * current 208 | 209 | 210 | class ParameterClipper: 211 | def __init__(self, frequency: int = 1): 212 | self.frequency = frequency 213 | 214 | def __call__(self, module): 215 | if hasattr(module, "w"): 216 | w = module.w.data 217 | w[0] = w[0].clamp(S_MIN, 100) 218 | w[1] = w[1].clamp(S_MIN, 100) 219 | w[2] = w[2].clamp(S_MIN, 100) 220 | w[3] = w[3].clamp(S_MIN, 100) 221 | w[4] = w[4].clamp(1, 10) 222 | w[5] = w[5].clamp(0.001, 4) 223 | w[6] = w[6].clamp(0.001, 4) 224 | w[7] = w[7].clamp(0.001, 0.75) 225 | w[8] = w[8].clamp(0, 4.5) 226 | w[9] = w[9].clamp(0, 0.8) 227 | w[10] = w[10].clamp(0.001, 3.5) 228 | w[11] = w[11].clamp(0.001, 5) 229 | w[12] = w[12].clamp(0.001, 0.25) 230 | w[13] = w[13].clamp(0.001, 0.9) 231 | w[14] = w[14].clamp(0, 4) 232 | w[15] = w[15].clamp(0, 1) 233 | w[16] = w[16].clamp(1, 6) 234 | w[17] = w[17].clamp(0, 2) 235 | w[18] = w[18].clamp(0, 2) 236 | w[19] = w[19].clamp(0, 0.8) 237 | w[20] = w[20].clamp(0.1, 0.8) 238 | module.w.data = w 239 | 240 | 241 | def lineToTensor(line: str) -> Tensor: 242 | ivl = line[0].split(",") 243 | response = line[1].split(",") 244 | tensor = torch.zeros(len(response), 2) 245 | for li, response in enumerate(response): 246 | tensor[li][0] = float(ivl[li]) 247 | tensor[li][1] = int(response) 248 | return tensor 249 | 250 | 251 | class BatchDataset(Dataset): 252 | def __init__( 253 | self, 254 | dataframe: pd.DataFrame, 255 | batch_size: int = 0, 256 | sort_by_length: bool = True, 257 | max_seq_len: int = math.inf, 258 | device: str = "cpu", 259 | ): 260 | if dataframe.empty: 261 | raise ValueError("Training data is inadequate.") 262 | dataframe["seq_len"] = dataframe["tensor"].map(len) 263 | if dataframe["seq_len"].min() > max_seq_len: 264 | raise ValueError("Training data is inadequate.") 265 | dataframe = dataframe[dataframe["seq_len"] <= max_seq_len] 266 | if sort_by_length: 267 | dataframe = dataframe.sort_values(by=["seq_len"], kind="stable") 268 | del dataframe["seq_len"] 269 | self.x_train = pad_sequence( 270 | dataframe["tensor"].to_list(), batch_first=True, padding_value=0 271 | ) 272 | self.t_train = torch.tensor(dataframe["delta_t"].values, dtype=torch.float) 273 | self.y_train = torch.tensor(dataframe["y"].values, dtype=torch.float) 274 | self.seq_len = torch.tensor( 275 | dataframe["tensor"].map(len).values, dtype=torch.long 276 | ) 277 | if "weights" in dataframe.columns: 278 | self.weights = torch.tensor(dataframe["weights"].values, dtype=torch.float) 279 | else: 280 | self.weights = torch.ones(len(dataframe), dtype=torch.float) 281 | length = len(dataframe) 282 | batch_num, remainder = divmod(length, max(1, batch_size)) 283 | self.batch_num = batch_num + 1 if remainder > 0 else batch_num 284 | self.batches = [None] * self.batch_num 285 | if batch_size > 0: 286 | for i in range(self.batch_num): 287 | start_index = i * batch_size 288 | end_index = min((i + 1) * batch_size, length) 289 | sequences = self.x_train[start_index:end_index] 290 | seq_lens = self.seq_len[start_index:end_index] 291 | max_seq_len = max(seq_lens) 292 | sequences_truncated = sequences[:, :max_seq_len] 293 | self.batches[i] = ( 294 | sequences_truncated.transpose(0, 1).to(device), 295 | self.t_train[start_index:end_index].to(device), 296 | self.y_train[start_index:end_index].to(device), 297 | seq_lens.to(device), 298 | self.weights[start_index:end_index].to(device), 299 | ) 300 | 301 | def __getitem__(self, idx): 302 | return self.batches[idx] 303 | 304 | def __len__(self): 305 | return self.batch_num 306 | 307 | 308 | class BatchLoader: 309 | def __init__(self, dataset: BatchDataset, shuffle: bool = True, seed: int = 2023): 310 | self.dataset = dataset 311 | self.batch_nums = len(dataset.batches) 312 | self.shuffle = shuffle 313 | self.generator = torch.Generator() 314 | self.generator.manual_seed(seed) 315 | 316 | def __iter__(self): 317 | if self.shuffle: 318 | yield from ( 319 | self.dataset[idx] 320 | for idx in torch.randperm( 321 | self.batch_nums, generator=self.generator 322 | ).tolist() 323 | ) 324 | else: 325 | yield from (self.dataset[idx] for idx in range(self.batch_nums)) 326 | 327 | def __len__(self): 328 | return self.batch_nums 329 | 330 | 331 | class Trainer: 332 | def __init__( 333 | self, 334 | train_set: pd.DataFrame, 335 | test_set: Optional[pd.DataFrame], 336 | init_w: List[float], 337 | n_epoch: int = 5, 338 | lr: float = 4e-2, 339 | gamma: float = 1, 340 | batch_size: int = 512, 341 | max_seq_len: int = 64, 342 | float_delta_t: bool = False, 343 | enable_short_term: bool = True, 344 | ) -> None: 345 | if not enable_short_term: 346 | init_w[17] = 0 347 | init_w[18] = 0 348 | self.model = FSRS(init_w, float_delta_t) 349 | self.init_w_tensor = torch.tensor(init_w, dtype=torch.float) 350 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) 351 | self.clipper = ParameterClipper() 352 | self.gamma = gamma 353 | self.batch_size = batch_size 354 | self.max_seq_len = max_seq_len 355 | self.build_dataset(train_set, test_set) 356 | self.n_epoch = n_epoch 357 | self.batch_nums = self.train_data_loader.batch_nums 358 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 359 | self.optimizer, T_max=self.batch_nums * n_epoch 360 | ) 361 | self.avg_train_losses = [] 362 | self.avg_eval_losses = [] 363 | self.loss_fn = nn.BCELoss(reduction="none") 364 | self.float_delta_t = float_delta_t 365 | self.enable_short_term = enable_short_term 366 | 367 | def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame]): 368 | self.train_set = BatchDataset( 369 | train_set, batch_size=self.batch_size, max_seq_len=self.max_seq_len 370 | ) 371 | self.train_data_loader = BatchLoader(self.train_set) 372 | 373 | self.test_set = ( 374 | [] 375 | if test_set is None 376 | else BatchDataset( 377 | test_set, batch_size=self.batch_size, max_seq_len=self.max_seq_len 378 | ) 379 | ) 380 | 381 | def train(self, verbose: bool = True): 382 | self.verbose = verbose 383 | best_loss = np.inf 384 | epoch_len = len(self.train_set.y_train) 385 | if verbose: 386 | pbar = tqdm(desc="train", colour="red", total=epoch_len * self.n_epoch) 387 | print_len = max(self.batch_nums * self.n_epoch // 10, 1) 388 | for k in range(self.n_epoch): 389 | weighted_loss, w = self.eval() 390 | if weighted_loss < best_loss: 391 | best_loss = weighted_loss 392 | best_w = w 393 | 394 | for i, batch in enumerate(self.train_data_loader): 395 | self.model.train() 396 | self.optimizer.zero_grad() 397 | sequences, delta_ts, labels, seq_lens, weights = batch 398 | real_batch_size = seq_lens.shape[0] 399 | outputs, _ = self.model(sequences) 400 | stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0] 401 | retentions = power_forgetting_curve( 402 | delta_ts, stabilities, -self.model.w[20] 403 | ) 404 | loss = (self.loss_fn(retentions, labels) * weights).sum() 405 | penalty = ( 406 | torch.sum( 407 | torch.square(self.model.w - self.init_w_tensor) 408 | / torch.square(DEFAULT_PARAMS_STDDEV_TENSOR) 409 | ) 410 | * self.gamma 411 | * real_batch_size 412 | / epoch_len 413 | ) 414 | loss += penalty 415 | loss.backward() 416 | if self.float_delta_t: 417 | for param in self.model.parameters(): 418 | param.grad[:4] = torch.zeros(4) 419 | if not self.enable_short_term: 420 | for param in self.model.parameters(): 421 | param.grad[17:19] = torch.zeros(2) 422 | self.optimizer.step() 423 | self.scheduler.step() 424 | self.model.apply(self.clipper) 425 | if verbose: 426 | pbar.update(real_batch_size) 427 | if verbose and (k * self.batch_nums + i + 1) % print_len == 0: 428 | tqdm.write( 429 | f"iteration: {k * epoch_len + (i + 1) * self.batch_size}" 430 | ) 431 | for name, param in self.model.named_parameters(): 432 | tqdm.write( 433 | f"{name}: {list(map(lambda x: round(float(x), 4),param))}" 434 | ) 435 | if verbose: 436 | pbar.close() 437 | 438 | weighted_loss, w = self.eval() 439 | if weighted_loss < best_loss: 440 | best_loss = weighted_loss 441 | best_w = w 442 | 443 | return best_w 444 | 445 | def eval(self): 446 | self.model.eval() 447 | with torch.no_grad(): 448 | losses = [] 449 | for dataset in (self.train_set, self.test_set): 450 | if len(dataset) == 0: 451 | losses.append(0) 452 | continue 453 | sequences, delta_ts, labels, seq_lens, weights = ( 454 | dataset.x_train, 455 | dataset.t_train, 456 | dataset.y_train, 457 | dataset.seq_len, 458 | dataset.weights, 459 | ) 460 | real_batch_size = seq_lens.shape[0] 461 | outputs, _ = self.model(sequences.transpose(0, 1)) 462 | stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0] 463 | retentions = power_forgetting_curve( 464 | delta_ts, stabilities, -self.model.w[20] 465 | ) 466 | loss = (self.loss_fn(retentions, labels) * weights).mean() 467 | penalty = torch.sum( 468 | torch.square(self.model.w - self.init_w_tensor) 469 | / torch.square(DEFAULT_PARAMS_STDDEV_TENSOR) 470 | ) 471 | loss += penalty * self.gamma / len(self.train_set.y_train) 472 | losses.append(loss) 473 | self.avg_train_losses.append(losses[0]) 474 | self.avg_eval_losses.append(losses[1]) 475 | 476 | w = list( 477 | map( 478 | lambda x: round(float(x), 4), 479 | dict(self.model.named_parameters())["w"].data, 480 | ) 481 | ) 482 | 483 | weighted_loss = ( 484 | losses[0] * len(self.train_set) + losses[1] * len(self.test_set) 485 | ) / (len(self.train_set) + len(self.test_set)) 486 | 487 | return weighted_loss, w 488 | 489 | def plot(self): 490 | fig = plt.figure() 491 | ax = fig.gca() 492 | ax.plot(self.avg_train_losses, label="train") 493 | ax.plot(self.avg_eval_losses, label="test") 494 | ax.set_xlabel("epoch") 495 | ax.set_ylabel("loss") 496 | ax.legend() 497 | return fig 498 | 499 | 500 | class Collection: 501 | def __init__(self, w: List[float], float_delta_t: bool = False) -> None: 502 | self.model = FSRS(w, float_delta_t) 503 | self.model.eval() 504 | 505 | def predict(self, t_history: str, r_history: str): 506 | with torch.no_grad(): 507 | line_tensor = lineToTensor( 508 | list(zip([t_history], [r_history]))[0] 509 | ).unsqueeze(1) 510 | output_t = self.model(line_tensor) 511 | return output_t[-1][0] 512 | 513 | def batch_predict(self, dataset): 514 | fast_dataset = BatchDataset(dataset, sort_by_length=False) 515 | with torch.no_grad(): 516 | outputs, _ = self.model(fast_dataset.x_train.transpose(0, 1)) 517 | stabilities, difficulties = outputs[ 518 | fast_dataset.seq_len - 1, torch.arange(len(fast_dataset)) 519 | ].transpose(0, 1) 520 | return stabilities.tolist(), difficulties.tolist() 521 | 522 | 523 | def remove_outliers(group: pd.DataFrame) -> pd.DataFrame: 524 | # threshold = np.mean(group['delta_t']) * 1.5 525 | # threshold = group['delta_t'].quantile(0.95) 526 | # Q1 = group['delta_t'].quantile(0.25) 527 | # Q3 = group['delta_t'].quantile(0.75) 528 | # IQR = Q3 - Q1 529 | # threshold = Q3 + 1.5 * IQR 530 | # group = group[group['delta_t'] <= threshold] 531 | grouped_group = ( 532 | group.groupby(by=["first_rating", "delta_t"], group_keys=False) 533 | .agg({"y": ["mean", "count"]}) 534 | .reset_index() 535 | ) 536 | sort_index = grouped_group.sort_values( 537 | by=[("y", "count"), "delta_t"], ascending=[True, False] 538 | ).index 539 | 540 | total = sum(grouped_group[("y", "count")]) 541 | has_been_removed = 0 542 | for i in sort_index: 543 | count = grouped_group.loc[i, ("y", "count")] 544 | delta_t = grouped_group.loc[i, "delta_t"].values[0] 545 | if has_been_removed + count >= max(total * 0.05, 20): 546 | if count < 6 or delta_t > (100 if group.name[0] != "4" else 365): 547 | group.drop(group[group["delta_t"] == delta_t].index, inplace=True) 548 | has_been_removed += count 549 | else: 550 | group.drop(group[group["delta_t"] == delta_t].index, inplace=True) 551 | has_been_removed += count 552 | return group 553 | 554 | 555 | def remove_non_continuous_rows(group): 556 | discontinuity = group["i"].diff().fillna(1).ne(1) 557 | if not discontinuity.any(): 558 | return group 559 | else: 560 | first_non_continuous_index = discontinuity.idxmax() 561 | return group.loc[: first_non_continuous_index - 1] 562 | 563 | 564 | def fit_stability(delta_t, retention, size): 565 | def loss(stability): 566 | y_pred = power_forgetting_curve(delta_t, stability) 567 | loss = sum( 568 | -(retention * np.log(y_pred) + (1 - retention) * np.log(1 - y_pred)) * size 569 | ) 570 | return loss 571 | 572 | res = minimize(loss, x0=1, bounds=[(S_MIN, 36500)]) 573 | return res.x[0] 574 | 575 | 576 | class Optimizer: 577 | float_delta_t: bool = False 578 | enable_short_term: bool = True 579 | 580 | def __init__( 581 | self, float_delta_t: bool = False, enable_short_term: bool = True 582 | ) -> None: 583 | tqdm.pandas() 584 | self.float_delta_t = float_delta_t 585 | self.enable_short_term = enable_short_term 586 | global S_MIN 587 | S_MIN = 1e-6 if float_delta_t else 0.001 588 | 589 | def anki_extract( 590 | self, 591 | filename: str, 592 | filter_out_suspended_cards: bool = False, 593 | filter_out_flags: List[int] = [], 594 | ): 595 | """Step 1""" 596 | # Extract the collection file or deck file to get the .anki21 database. 597 | with zipfile.ZipFile(f"{filename}", "r") as zip_ref: 598 | zip_ref.extractall("./") 599 | tqdm.write("Deck file extracted successfully!") 600 | 601 | if os.path.isfile("collection.anki21b"): 602 | os.remove("collection.anki21b") 603 | raise Exception( 604 | "Please export the file with `support older Anki versions` if you use the latest version of Anki." 605 | ) 606 | elif os.path.isfile("collection.anki21"): 607 | con = sqlite3.connect("collection.anki21") 608 | elif os.path.isfile("collection.anki2"): 609 | con = sqlite3.connect("collection.anki2") 610 | else: 611 | raise Exception("Collection not exist!") 612 | cur = con.cursor() 613 | 614 | def flags2str(flags: List[int]) -> str: 615 | return f"({','.join(map(str, flags))})" 616 | 617 | res = cur.execute( 618 | f""" 619 | SELECT * 620 | FROM revlog 621 | WHERE cid IN ( 622 | SELECT id 623 | FROM cards 624 | WHERE queue != 0 625 | AND id <= {time.time() * 1000} 626 | {"AND queue != -1" if filter_out_suspended_cards else ""} 627 | {"AND flags NOT IN %s" % flags2str(filter_out_flags) if len(filter_out_flags) > 0 else ""} 628 | ) 629 | AND ease BETWEEN 1 AND 4 630 | AND ( 631 | type != 3 632 | OR factor != 0 633 | ) 634 | AND id <= {time.time() * 1000} 635 | ORDER BY cid, id 636 | """ 637 | ) 638 | revlog = res.fetchall() 639 | if len(revlog) == 0: 640 | raise Exception("No review log found!") 641 | df = pd.DataFrame(revlog) 642 | df.columns = [ 643 | "review_time", 644 | "card_id", 645 | "usn", 646 | "review_rating", 647 | "ivl", 648 | "last_ivl", 649 | "factor", 650 | "review_duration", 651 | "review_state", 652 | ] 653 | df["i"] = df.groupby("card_id").cumcount() + 1 654 | df["is_learn_start"] = (df["review_state"] == 0) & ( 655 | (df["review_state"].shift() != 0) | (df["i"] == 1) 656 | ) 657 | df["sequence_group"] = df["is_learn_start"].cumsum() 658 | last_learn_start = ( 659 | df[df["is_learn_start"]].groupby("card_id")["sequence_group"].last() 660 | ) 661 | df["last_learn_start"] = ( 662 | df["card_id"].map(last_learn_start).fillna(0).astype(int) 663 | ) 664 | df["mask"] = df["last_learn_start"] <= df["sequence_group"] 665 | df = df[df["mask"] == True].copy() 666 | df["review_state"] = df["review_state"] + 1 667 | df.loc[df["is_learn_start"], "review_state"] = New 668 | df = df.groupby("card_id").filter( 669 | lambda group: group["review_state"].iloc[0] == New 670 | ) 671 | df.drop( 672 | columns=[ 673 | "i", 674 | "is_learn_start", 675 | "sequence_group", 676 | "last_learn_start", 677 | "mask", 678 | "usn", 679 | "ivl", 680 | "last_ivl", 681 | "factor", 682 | ], 683 | inplace=True, 684 | ) 685 | df.to_csv("revlog.csv", index=False) 686 | tqdm.write("revlog.csv saved.") 687 | 688 | def extract_simulation_config(self, df): 689 | df_tmp = df[ 690 | (df["review_duration"] > 0) & (df["review_duration"] < 1200000) 691 | ].copy() 692 | 693 | state_rating_costs = ( 694 | df_tmp[df_tmp["review_state"] != 4] 695 | .groupby(["review_state", "review_rating"])["review_duration"] 696 | .median() 697 | .unstack(fill_value=0) 698 | ) / 1000 699 | state_rating_counts = ( 700 | df_tmp[df_tmp["review_state"] != 4] 701 | .groupby(["review_state", "review_rating"])["review_duration"] 702 | .count() 703 | .unstack(fill_value=0) 704 | ) 705 | 706 | # Ensure all ratings (1-4) exist in columns 707 | for rating in range(1, 5): 708 | if rating not in state_rating_costs.columns: 709 | state_rating_costs[rating] = 0 710 | if rating not in state_rating_counts.columns: 711 | state_rating_counts[rating] = 0 712 | 713 | # Ensure all states exist in index 714 | for state in [Learning, Review, Relearning]: 715 | if state not in state_rating_costs.index: 716 | state_rating_costs.loc[state] = 0 717 | if state not in state_rating_counts.index: 718 | state_rating_counts.loc[state] = 0 719 | 720 | self.state_rating_costs = state_rating_costs.values.round(2).tolist() 721 | for i, (rating_costs, default_rating_cost, rating_counts) in enumerate( 722 | zip( 723 | state_rating_costs.values.tolist(), 724 | DEFAULT_STATE_RATING_COSTS, 725 | state_rating_counts.values.tolist(), 726 | ) 727 | ): 728 | for j, (cost, default_cost, count) in enumerate( 729 | zip(rating_costs, default_rating_cost, rating_counts) 730 | ): 731 | weight = count / (50 + count) 732 | self.state_rating_costs[i][j] = cost * weight + default_cost * ( 733 | 1 - weight 734 | ) 735 | 736 | df1 = ( 737 | df_tmp.groupby(by=["card_id", "real_days"]) 738 | .agg( 739 | { 740 | "review_state": "first", 741 | "review_rating": ["first", list], 742 | "review_duration": "sum", 743 | } 744 | ) 745 | .reset_index() 746 | ) 747 | del df1["real_days"] 748 | df1.columns = [ 749 | "card_id", 750 | "first_state", 751 | "first_rating", 752 | "same_day_ratings", 753 | "sum_review_duration", 754 | ] 755 | model = FirstOrderMarkovChain() 756 | learning_step_rating_sequences = df1[df1["first_state"] == Learning][ 757 | "same_day_ratings" 758 | ] 759 | result = model.fit(learning_step_rating_sequences) 760 | learning_transition_matrix, learning_transition_counts = ( 761 | result.transition_matrix[:3], 762 | result.transition_counts[:3], 763 | ) 764 | self.learning_step_transitions = learning_transition_matrix.round(2).tolist() 765 | for i, (rating_probs, default_rating_probs, transition_counts) in enumerate( 766 | zip( 767 | learning_transition_matrix.tolist(), 768 | DEFAULT_LEARNING_STEP_TRANSITIONS, 769 | learning_transition_counts.tolist(), 770 | ) 771 | ): 772 | weight = sum(transition_counts) / (50 + sum(transition_counts)) 773 | for j, (prob, default_prob) in enumerate( 774 | zip(rating_probs, default_rating_probs) 775 | ): 776 | self.learning_step_transitions[i][j] = prob * weight + default_prob * ( 777 | 1 - weight 778 | ) 779 | 780 | relearning_step_rating_sequences = df1[ 781 | (df1["first_state"] == Review) & (df1["first_rating"] == 1) 782 | ]["same_day_ratings"] 783 | result = model.fit(relearning_step_rating_sequences) 784 | relearning_transition_matrix, relearning_transition_counts = ( 785 | result.transition_matrix[:3], 786 | result.transition_counts[:3], 787 | ) 788 | self.relearning_step_transitions = relearning_transition_matrix.round( 789 | 2 790 | ).tolist() 791 | for i, (rating_probs, default_rating_probs, transition_counts) in enumerate( 792 | zip( 793 | relearning_transition_matrix.tolist(), 794 | DEFAULT_RELEARNING_STEP_TRANSITIONS, 795 | relearning_transition_counts.tolist(), 796 | ) 797 | ): 798 | weight = sum(transition_counts) / (50 + sum(transition_counts)) 799 | for j, (prob, default_prob) in enumerate( 800 | zip(rating_probs, default_rating_probs) 801 | ): 802 | self.relearning_step_transitions[i][j] = ( 803 | prob * weight + default_prob * (1 - weight) 804 | ) 805 | 806 | button_usage_dict = defaultdict( 807 | int, 808 | ( 809 | df1.groupby(by=["first_state", "first_rating"])["card_id"] 810 | .count() 811 | .to_dict() 812 | ), 813 | ) 814 | self.learn_buttons = ( 815 | np.array([button_usage_dict[(1, i)] for i in range(1, 5)]) + 1 816 | ) 817 | self.review_buttons = ( 818 | np.array([button_usage_dict[(2, i)] for i in range(1, 5)]) + 1 819 | ) 820 | self.first_rating_prob = self.learn_buttons / self.learn_buttons.sum() 821 | self.review_rating_prob = ( 822 | self.review_buttons[1:] / self.review_buttons[1:].sum() 823 | ) 824 | 825 | weight = sum(self.learn_buttons) / (50 + sum(self.learn_buttons)) 826 | self.first_rating_prob = ( 827 | self.first_rating_prob * weight + DEFAULT_FIRST_RATING_PROB * (1 - weight) 828 | ) 829 | 830 | weight = sum(self.review_buttons[1:]) / (50 + sum(self.review_buttons[1:])) 831 | self.review_rating_prob = ( 832 | self.review_rating_prob * weight + DEFAULT_REVIEW_RATING_PROB * (1 - weight) 833 | ) 834 | 835 | def create_time_series( 836 | self, 837 | timezone: str, 838 | revlog_start_date: str, 839 | next_day_starts_at: int, 840 | analysis: bool = True, 841 | ): 842 | """Step 2""" 843 | df = pd.read_csv("./revlog.csv") 844 | df.sort_values(by=["card_id", "review_time"], inplace=True, ignore_index=True) 845 | df["review_date"] = pd.to_datetime(df["review_time"] // 1000, unit="s") 846 | df["review_date"] = ( 847 | df["review_date"].dt.tz_localize("UTC").dt.tz_convert(timezone) 848 | ) 849 | df.drop(df[df["review_date"].dt.year < 2006].index, inplace=True) 850 | df["real_days"] = df["review_date"] - timedelta(hours=int(next_day_starts_at)) 851 | df["real_days"] = pd.DatetimeIndex( 852 | df["real_days"].dt.floor( 853 | "D", ambiguous="infer", nonexistent="shift_forward" 854 | ) 855 | ).to_julian_date() 856 | # df.drop_duplicates(["card_id", "real_days"], keep="first", inplace=True) 857 | if self.float_delta_t: 858 | df["delta_t"] = df["review_time"].diff().fillna(0) / 1000 / 86400 859 | else: 860 | df["delta_t"] = df.real_days.diff() 861 | df.fillna({"delta_t": 0}, inplace=True) 862 | df["i"] = df.groupby("card_id").cumcount() + 1 863 | df.loc[df["i"] == 1, "delta_t"] = -1 864 | if df.empty: 865 | raise ValueError("Training data is inadequate.") 866 | 867 | if ( 868 | "review_state" in df.columns 869 | and "review_duration" in df.columns 870 | and not (df["review_duration"] == 0).all() 871 | ): 872 | df["review_state"] = df["review_state"].map( 873 | lambda x: x if x != New else Learning 874 | ) 875 | self.extract_simulation_config(df) 876 | df.drop(columns=["review_duration", "review_state"], inplace=True) 877 | 878 | def cum_concat(x): 879 | return list(accumulate(x)) 880 | 881 | t_history_list = df.groupby("card_id", group_keys=False)["delta_t"].apply( 882 | lambda x: cum_concat( 883 | [[max(0, round(i, 6) if self.float_delta_t else int(i))] for i in x] 884 | ) 885 | ) 886 | df["t_history"] = [ 887 | ",".join(map(str, item[:-1])) 888 | for sublist in t_history_list 889 | for item in sublist 890 | ] 891 | r_history_list = df.groupby("card_id", group_keys=False)["review_rating"].apply( 892 | lambda x: cum_concat([[i] for i in x]) 893 | ) 894 | df["r_history"] = [ 895 | ",".join(map(str, item[:-1])) 896 | for sublist in r_history_list 897 | for item in sublist 898 | ] 899 | last_rating = [] 900 | for t_sublist, r_sublist in zip(t_history_list, r_history_list): 901 | for t_history, r_history in zip(t_sublist, r_sublist): 902 | flag = True 903 | for t, r in zip(reversed(t_history[:-1]), reversed(r_history[:-1])): 904 | if t > 0: 905 | last_rating.append(r) 906 | flag = False 907 | break 908 | if flag: 909 | last_rating.append(r_history[0]) 910 | df["last_rating"] = last_rating 911 | 912 | df = df.groupby("card_id").filter( 913 | lambda group: group["review_time"].min() 914 | > time.mktime(datetime.strptime(revlog_start_date, "%Y-%m-%d").timetuple()) 915 | * 1000 916 | ) 917 | df = df[ 918 | (df["review_rating"] != 0) 919 | & (df["r_history"].str.contains("0") == 0) 920 | & (df["delta_t"] != 0) 921 | ].copy() 922 | df["i"] = df.groupby("card_id").cumcount() + 1 923 | df["first_rating"] = df["r_history"].map(lambda x: x[0] if len(x) > 0 else "") 924 | df["y"] = df["review_rating"].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x]) 925 | 926 | if not self.float_delta_t: 927 | df[df["i"] == 2] = ( 928 | df[df["i"] == 2] 929 | .groupby(by=["first_rating"], as_index=False, group_keys=False) 930 | .apply(remove_outliers) 931 | ) 932 | df.dropna(inplace=True) 933 | 934 | df = df.groupby("card_id", as_index=False, group_keys=False).progress_apply( 935 | remove_non_continuous_rows 936 | ) 937 | 938 | df["review_time"] = df["review_time"].astype(int) 939 | df["review_rating"] = df["review_rating"].astype(int) 940 | df["delta_t"] = df["delta_t"].astype(float if self.float_delta_t else int) 941 | df["i"] = df["i"].astype(int) 942 | df["t_history"] = df["t_history"].astype(str) 943 | df["r_history"] = df["r_history"].astype(str) 944 | df["last_rating"] = df["last_rating"].astype(int) 945 | df["y"] = df["y"].astype(int) 946 | 947 | df.to_csv("revlog_history.tsv", sep="\t", index=False) 948 | tqdm.write("Trainset saved.") 949 | 950 | self.S0_dataset_group = ( 951 | df[df["i"] == 2] 952 | .groupby(by=["first_rating", "delta_t"], group_keys=False) 953 | .agg({"y": ["mean", "count"]}) 954 | .reset_index() 955 | ) 956 | self.S0_dataset_group.to_csv("stability_for_pretrain.tsv", sep="\t", index=None) 957 | del df["first_rating"] 958 | 959 | if not analysis: 960 | return 961 | 962 | df["r_history"] = df.apply( 963 | lambda row: wrap_short_term_ratings(row["r_history"], row["t_history"]), 964 | axis=1, 965 | ) 966 | 967 | df["retention"] = df.groupby(by=["r_history", "delta_t"], group_keys=False)[ 968 | "y" 969 | ].transform("mean") 970 | df["total_cnt"] = df.groupby(by=["r_history", "delta_t"], group_keys=False)[ 971 | "review_time" 972 | ].transform("count") 973 | tqdm.write("Retention calculated.") 974 | 975 | df.drop( 976 | columns=[ 977 | "review_time", 978 | "card_id", 979 | "review_date", 980 | "real_days", 981 | "review_rating", 982 | "t_history", 983 | "last_rating", 984 | "y", 985 | ], 986 | inplace=True, 987 | ) 988 | df.drop_duplicates(inplace=True) 989 | df["retention"] = df["retention"].map(lambda x: max(min(0.99, x), 0.01)) 990 | 991 | def cal_stability(group: pd.DataFrame) -> pd.DataFrame: 992 | group_cnt = sum(group.groupby("delta_t").first()["total_cnt"]) 993 | if group_cnt < 10: 994 | return pd.DataFrame() 995 | group["group_cnt"] = group_cnt 996 | if group["i"].values[0] > 1: 997 | group["stability"] = round( 998 | fit_stability( 999 | group["delta_t"], group["retention"], group["total_cnt"] 1000 | ), 1001 | 1, 1002 | ) 1003 | else: 1004 | group["stability"] = 0.0 1005 | group["avg_retention"] = round( 1006 | sum(group["retention"] * pow(group["total_cnt"], 2)) 1007 | / sum(pow(group["total_cnt"], 2)), 1008 | 3, 1009 | ) 1010 | group["avg_interval"] = round( 1011 | sum(group["delta_t"] * pow(group["total_cnt"], 2)) 1012 | / sum(pow(group["total_cnt"], 2)), 1013 | 1, 1014 | ) 1015 | del group["total_cnt"] 1016 | del group["retention"] 1017 | del group["delta_t"] 1018 | return group 1019 | 1020 | df = df.groupby(by=["r_history"], group_keys=False).progress_apply( 1021 | cal_stability 1022 | ) 1023 | if df.empty: 1024 | return "No enough data for stability calculation." 1025 | tqdm.write("Stability calculated.") 1026 | df.reset_index(drop=True, inplace=True) 1027 | df.drop_duplicates(inplace=True) 1028 | df.sort_values(by=["r_history"], inplace=True, ignore_index=True) 1029 | 1030 | if df.shape[0] > 0: 1031 | for idx in tqdm(df.index, desc="analysis"): 1032 | item = df.loc[idx] 1033 | index = df[ 1034 | (df["i"] == item["i"] + 1) 1035 | & (df["r_history"].str.startswith(item["r_history"])) 1036 | ].index 1037 | df.loc[index, "last_stability"] = item["stability"] 1038 | df["factor"] = round(df["stability"] / df["last_stability"], 2) 1039 | df = df[(df["i"] >= 2) & (df["group_cnt"] >= 100)].copy() 1040 | df["last_recall"] = df["r_history"].map(lambda x: x[-1]) 1041 | df = df[ 1042 | df.groupby(["r_history"], group_keys=False)["group_cnt"].transform( 1043 | "max" 1044 | ) 1045 | == df["group_cnt"] 1046 | ] 1047 | df.to_csv("./stability_for_analysis.tsv", sep="\t", index=None) 1048 | tqdm.write("Analysis saved!") 1049 | caption = "1:again, 2:hard, 3:good, 4:easy\n" 1050 | df["first_rating"] = df["r_history"].map(lambda x: x[1]) 1051 | analysis = ( 1052 | df[df["r_history"].str.contains(r"^\([1-4][^124]*$", regex=True)][ 1053 | [ 1054 | "first_rating", 1055 | "i", 1056 | "r_history", 1057 | "avg_interval", 1058 | "avg_retention", 1059 | "stability", 1060 | "factor", 1061 | "group_cnt", 1062 | ] 1063 | ] 1064 | .sort_values(by=["first_rating", "i"]) 1065 | .to_string(index=False) 1066 | ) 1067 | return caption + analysis 1068 | 1069 | def define_model(self): 1070 | """Step 3""" 1071 | self.init_w = DEFAULT_PARAMETER.copy() 1072 | """ 1073 | For details about the parameters, please see: 1074 | https://github.com/open-spaced-repetition/fsrs4anki/wiki/The-Algorithm 1075 | """ 1076 | 1077 | def pretrain(self, dataset=None, verbose=True): 1078 | if dataset is None: 1079 | self.dataset = pd.read_csv( 1080 | "./revlog_history.tsv", 1081 | sep="\t", 1082 | index_col=None, 1083 | dtype={"r_history": str, "t_history": str}, 1084 | ) 1085 | else: 1086 | self.dataset = dataset 1087 | self.dataset["r_history"] = self.dataset["r_history"].fillna("") 1088 | self.dataset["first_rating"] = self.dataset["r_history"].map( 1089 | lambda x: x[0] if len(x) > 0 else "" 1090 | ) 1091 | self.S0_dataset_group = ( 1092 | self.dataset[self.dataset["i"] == 2] 1093 | .groupby(by=["first_rating", "delta_t"], group_keys=False) 1094 | .agg({"y": ["mean", "count"]}) 1095 | .reset_index() 1096 | ) 1097 | self.dataset = self.dataset[ 1098 | (self.dataset["i"] > 1) & (self.dataset["delta_t"] > 0) 1099 | ] 1100 | if self.dataset.empty: 1101 | raise ValueError("Training data is inadequate.") 1102 | rating_stability = {} 1103 | rating_count = {} 1104 | average_recall = self.dataset["y"].mean() 1105 | plots = [] 1106 | r_s0_default = {str(i): DEFAULT_PARAMETER[i - 1] for i in range(1, 5)} 1107 | 1108 | for first_rating in ("1", "2", "3", "4"): 1109 | group = self.S0_dataset_group[ 1110 | self.S0_dataset_group["first_rating"] == first_rating 1111 | ] 1112 | if group.empty: 1113 | if verbose: 1114 | tqdm.write( 1115 | f"Not enough data for first rating {first_rating}. Expected at least 1, got 0." 1116 | ) 1117 | continue 1118 | delta_t = group["delta_t"] 1119 | recall = ( 1120 | (group["y"]["mean"] * group["y"]["count"] + average_recall * 1) 1121 | / (group["y"]["count"] + 1) 1122 | if not self.float_delta_t 1123 | else group["y"]["mean"] 1124 | ) 1125 | count = group["y"]["count"] 1126 | 1127 | init_s0 = r_s0_default[first_rating] 1128 | 1129 | def loss(stability): 1130 | y_pred = power_forgetting_curve(delta_t, stability) 1131 | logloss = sum( 1132 | -(recall * np.log(y_pred) + (1 - recall) * np.log(1 - y_pred)) 1133 | * count 1134 | ) 1135 | l1 = np.abs(stability - init_s0) / 16 if not self.float_delta_t else 0 1136 | return logloss + l1 1137 | 1138 | res = minimize( 1139 | loss, 1140 | x0=init_s0, 1141 | bounds=((S_MIN, 100),), 1142 | options={"maxiter": int(sum(count))}, 1143 | ) 1144 | params = res.x 1145 | stability = params[0] 1146 | rating_stability[int(first_rating)] = stability 1147 | rating_count[int(first_rating)] = sum(count) 1148 | predict_recall = power_forgetting_curve(delta_t, *params) 1149 | rmse = root_mean_squared_error(recall, predict_recall, sample_weight=count) 1150 | 1151 | if verbose: 1152 | fig = plt.figure() 1153 | ax = fig.gca() 1154 | ax.plot(delta_t, recall, label="Exact") 1155 | ax.plot( 1156 | np.linspace(0, 30), 1157 | power_forgetting_curve(np.linspace(0, 30), *params), 1158 | label=f"Weighted fit (RMSE: {rmse:.4f})", 1159 | ) 1160 | count_percent = np.array([x / sum(count) for x in count]) 1161 | ax.scatter(delta_t, recall, s=count_percent * 1000, alpha=0.5) 1162 | ax.legend(loc="upper right", fancybox=True, shadow=False) 1163 | ax.grid(True) 1164 | ax.set_ylim(0, 1) 1165 | ax.set_xlabel("Interval") 1166 | ax.set_ylabel("Recall") 1167 | ax.set_title( 1168 | f"Forgetting curve for first rating {first_rating} (n={sum(count)}, s={stability:.2f})" 1169 | ) 1170 | plots.append(fig) 1171 | tqdm.write(str(rating_stability)) 1172 | 1173 | for small_rating, big_rating in ( 1174 | (1, 2), 1175 | (2, 3), 1176 | (3, 4), 1177 | (1, 3), 1178 | (2, 4), 1179 | (1, 4), 1180 | ): 1181 | if small_rating in rating_stability and big_rating in rating_stability: 1182 | # if rating_count[small_rating] > 300 and rating_count[big_rating] > 300: 1183 | # continue 1184 | if rating_stability[small_rating] > rating_stability[big_rating]: 1185 | if rating_count[small_rating] > rating_count[big_rating]: 1186 | rating_stability[big_rating] = rating_stability[small_rating] 1187 | else: 1188 | rating_stability[small_rating] = rating_stability[big_rating] 1189 | 1190 | w1 = 0.41 1191 | w2 = 0.54 1192 | 1193 | if len(rating_stability) == 0: 1194 | raise Exception("Not enough data for pretraining!") 1195 | elif len(rating_stability) == 1: 1196 | rating = list(rating_stability.keys())[0] 1197 | factor = rating_stability[rating] / r_s0_default[str(rating)] 1198 | init_s0 = list(map(lambda x: x * factor, r_s0_default.values())) 1199 | elif len(rating_stability) == 2: 1200 | if 1 not in rating_stability and 2 not in rating_stability: 1201 | rating_stability[2] = np.power( 1202 | rating_stability[3], 1 / (1 - w2) 1203 | ) * np.power(rating_stability[4], 1 - 1 / (1 - w2)) 1204 | rating_stability[1] = np.power(rating_stability[2], 1 / w1) * np.power( 1205 | rating_stability[3], 1 - 1 / w1 1206 | ) 1207 | elif 1 not in rating_stability and 3 not in rating_stability: 1208 | rating_stability[3] = np.power(rating_stability[2], 1 - w2) * np.power( 1209 | rating_stability[4], w2 1210 | ) 1211 | rating_stability[1] = np.power(rating_stability[2], 1 / w1) * np.power( 1212 | rating_stability[3], 1 - 1 / w1 1213 | ) 1214 | elif 1 not in rating_stability and 4 not in rating_stability: 1215 | rating_stability[4] = np.power( 1216 | rating_stability[2], 1 - 1 / w2 1217 | ) * np.power(rating_stability[3], 1 / w2) 1218 | rating_stability[1] = np.power(rating_stability[2], 1 / w1) * np.power( 1219 | rating_stability[3], 1 - 1 / w1 1220 | ) 1221 | elif 2 not in rating_stability and 3 not in rating_stability: 1222 | rating_stability[2] = np.power( 1223 | rating_stability[1], w1 / (w1 + w2 - w1 * w2) 1224 | ) * np.power(rating_stability[4], 1 - w1 / (w1 + w2 - w1 * w2)) 1225 | rating_stability[3] = np.power( 1226 | rating_stability[1], 1 - w2 / (w1 + w2 - w1 * w2) 1227 | ) * np.power(rating_stability[4], w2 / (w1 + w2 - w1 * w2)) 1228 | elif 2 not in rating_stability and 4 not in rating_stability: 1229 | rating_stability[2] = np.power(rating_stability[1], w1) * np.power( 1230 | rating_stability[3], 1 - w1 1231 | ) 1232 | rating_stability[4] = np.power( 1233 | rating_stability[2], 1 - 1 / w2 1234 | ) * np.power(rating_stability[3], 1 / w2) 1235 | elif 3 not in rating_stability and 4 not in rating_stability: 1236 | rating_stability[3] = np.power( 1237 | rating_stability[1], 1 - 1 / (1 - w1) 1238 | ) * np.power(rating_stability[2], 1 / (1 - w1)) 1239 | rating_stability[4] = np.power( 1240 | rating_stability[2], 1 - 1 / w2 1241 | ) * np.power(rating_stability[3], 1 / w2) 1242 | init_s0 = [ 1243 | item[1] for item in sorted(rating_stability.items(), key=lambda x: x[0]) 1244 | ] 1245 | elif len(rating_stability) == 3: 1246 | if 1 not in rating_stability: 1247 | rating_stability[1] = np.power(rating_stability[2], 1 / w1) * np.power( 1248 | rating_stability[3], 1 - 1 / w1 1249 | ) 1250 | elif 2 not in rating_stability: 1251 | rating_stability[2] = np.power(rating_stability[1], w1) * np.power( 1252 | rating_stability[3], 1 - w1 1253 | ) 1254 | elif 3 not in rating_stability: 1255 | rating_stability[3] = np.power(rating_stability[2], 1 - w2) * np.power( 1256 | rating_stability[4], w2 1257 | ) 1258 | elif 4 not in rating_stability: 1259 | rating_stability[4] = np.power( 1260 | rating_stability[2], 1 - 1 / w2 1261 | ) * np.power(rating_stability[3], 1 / w2) 1262 | init_s0 = [ 1263 | item[1] for item in sorted(rating_stability.items(), key=lambda x: x[0]) 1264 | ] 1265 | elif len(rating_stability) == 4: 1266 | init_s0 = [ 1267 | item[1] for item in sorted(rating_stability.items(), key=lambda x: x[0]) 1268 | ] 1269 | 1270 | self.init_w[0:4] = list(map(lambda x: max(min(100, x), S_MIN), init_s0)) 1271 | if verbose: 1272 | tqdm.write(f"Pretrain finished!") 1273 | return plots 1274 | 1275 | def train( 1276 | self, 1277 | lr: float = 4e-2, 1278 | n_epoch: int = 5, 1279 | gamma: float = 1.0, 1280 | batch_size: int = 512, 1281 | verbose: bool = True, 1282 | split_by_time: bool = False, 1283 | recency_weight: bool = False, 1284 | ): 1285 | """Step 4""" 1286 | self.dataset["tensor"] = self.dataset.progress_apply( 1287 | lambda x: lineToTensor(list(zip([x["t_history"]], [x["r_history"]]))[0]), 1288 | axis=1, 1289 | ) 1290 | self.dataset["group"] = self.dataset["r_history"] + self.dataset["t_history"] 1291 | if verbose: 1292 | tqdm.write("Tensorized!") 1293 | 1294 | w = [] 1295 | plots = [] 1296 | self.dataset.sort_values(by=["review_time"], inplace=True) 1297 | if split_by_time: 1298 | tscv = TimeSeriesSplit(n_splits=5) 1299 | for i, (train_index, test_index) in enumerate(tscv.split(self.dataset)): 1300 | if verbose: 1301 | tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}") 1302 | train_set = self.dataset.iloc[train_index].copy() 1303 | test_set = self.dataset.iloc[test_index].copy() 1304 | trainer = Trainer( 1305 | train_set, 1306 | test_set, 1307 | self.init_w, 1308 | n_epoch=n_epoch, 1309 | lr=lr, 1310 | gamma=gamma, 1311 | batch_size=batch_size, 1312 | float_delta_t=self.float_delta_t, 1313 | enable_short_term=self.enable_short_term, 1314 | ) 1315 | w.append(trainer.train(verbose=verbose)) 1316 | self.w = w[-1] 1317 | self.evaluate() 1318 | metrics, figures = self.calibration_graph(self.dataset.iloc[test_index]) 1319 | for j, f in enumerate(figures): 1320 | f.savefig(f"graph_{j}_test_{i}.png") 1321 | plt.close(f) 1322 | if verbose: 1323 | print(metrics) 1324 | plots.append(trainer.plot()) 1325 | else: 1326 | if recency_weight: 1327 | x = np.linspace(0, 1, len(self.dataset)) 1328 | self.dataset["weights"] = 0.25 + 0.75 * np.power(x, 3) 1329 | trainer = Trainer( 1330 | self.dataset, 1331 | None, 1332 | self.init_w, 1333 | n_epoch=n_epoch, 1334 | lr=lr, 1335 | gamma=gamma, 1336 | batch_size=batch_size, 1337 | float_delta_t=self.float_delta_t, 1338 | enable_short_term=self.enable_short_term, 1339 | ) 1340 | w.append(trainer.train(verbose=verbose)) 1341 | if verbose: 1342 | plots.append(trainer.plot()) 1343 | 1344 | w = np.array(w) 1345 | avg_w = np.round(np.mean(w, axis=0), 4) 1346 | self.w = avg_w.tolist() 1347 | 1348 | if verbose: 1349 | tqdm.write("\nTraining finished!") 1350 | return plots 1351 | 1352 | def preview(self, requestRetention: float, verbose=False, n_steps=3): 1353 | my_collection = Collection(self.w, self.float_delta_t) 1354 | preview_text = "1:again, 2:hard, 3:good, 4:easy\n" 1355 | n_learning_steps = n_steps if not self.float_delta_t else 0 1356 | for first_rating in (1, 2, 3, 4): 1357 | preview_text += f"\nfirst rating: {first_rating}\n" 1358 | t_history = "0" 1359 | d_history = "0" 1360 | s_history = "0" 1361 | r_history = f"{first_rating}" # the first rating of the new card 1362 | if first_rating in (1, 2): 1363 | left = n_learning_steps 1364 | elif first_rating == 3: 1365 | left = n_learning_steps - 1 1366 | else: 1367 | left = 1 1368 | # print("stability, difficulty, lapses") 1369 | for i in range(10): 1370 | states = my_collection.predict(t_history, r_history) 1371 | stability = round(float(states[0]), 1) 1372 | difficulty = round(float(states[1]), 1) 1373 | if verbose: 1374 | print( 1375 | "{0:9.2f} {1:11.2f} {2:7.0f}".format( 1376 | *list(map(lambda x: round(float(x), 4), states)) 1377 | ) 1378 | ) 1379 | left -= 1 1380 | next_t = ( 1381 | next_interval( 1382 | states[0].detach().numpy(), requestRetention, self.float_delta_t 1383 | ) 1384 | if left <= 0 1385 | else 0 1386 | ) 1387 | t_history += f",{next_t}" 1388 | d_history += f",{difficulty}" 1389 | s_history += f",{stability}" 1390 | r_history += f",3" 1391 | r_history = wrap_short_term_ratings(r_history, t_history) 1392 | preview_text += f"rating history: {r_history}\n" 1393 | preview_text += ( 1394 | "interval history: " 1395 | + ",".join( 1396 | [ 1397 | ( 1398 | f"{ivl:.4f}d" 1399 | if ivl < 1 and ivl > 0 1400 | else ( 1401 | f"{ivl:.1f}d" 1402 | if ivl < 30 1403 | else ( 1404 | f"{ivl / 30:.1f}m" 1405 | if ivl < 365 1406 | else f"{ivl / 365:.1f}y" 1407 | ) 1408 | ) 1409 | ) 1410 | for ivl in map( 1411 | int if not self.float_delta_t else float, 1412 | t_history.split(","), 1413 | ) 1414 | ] 1415 | ) 1416 | + "\n" 1417 | ) 1418 | preview_text += ( 1419 | "factor history: " 1420 | + ",".join( 1421 | ["0.0"] 1422 | + [ 1423 | ( 1424 | f"{float(ivl) / float(pre_ivl):.2f}" 1425 | if pre_ivl != "0" 1426 | else "0.0" 1427 | ) 1428 | for ivl, pre_ivl in zip( 1429 | t_history.split(",")[1:], 1430 | t_history.split(",")[:-1], 1431 | ) 1432 | ] 1433 | ) 1434 | + "\n" 1435 | ) 1436 | preview_text += f"difficulty history: {d_history}\n" 1437 | preview_text += f"stability history: {s_history}\n" 1438 | return preview_text 1439 | 1440 | def preview_sequence(self, test_rating_sequence: str, requestRetention: float): 1441 | my_collection = Collection(self.w, self.float_delta_t) 1442 | 1443 | t_history = "0" 1444 | d_history = "0" 1445 | for i in range(len(test_rating_sequence.split(","))): 1446 | r_history = test_rating_sequence[: 2 * i + 1] 1447 | states = my_collection.predict(t_history, r_history) 1448 | next_t = next_interval( 1449 | states[0].detach().numpy(), requestRetention, self.float_delta_t 1450 | ) 1451 | t_history += f",{next_t}" 1452 | difficulty = round(float(states[1]), 1) 1453 | d_history += f",{difficulty}" 1454 | preview_text = f"rating history: {test_rating_sequence}\n" 1455 | preview_text += ( 1456 | "interval history: " 1457 | + ",".join( 1458 | [ 1459 | ( 1460 | f"{ivl:.4f}d" 1461 | if ivl < 1 and ivl > 0 1462 | else ( 1463 | f"{ivl:.1f}d" 1464 | if ivl < 30 1465 | else ( 1466 | f"{ivl / 30:.1f}m" if ivl < 365 else f"{ivl / 365:.1f}y" 1467 | ) 1468 | ) 1469 | ) 1470 | for ivl in map( 1471 | int if not self.float_delta_t else float, 1472 | t_history.split(","), 1473 | ) 1474 | ] 1475 | ) 1476 | + "\n" 1477 | ) 1478 | preview_text += ( 1479 | "factor history: " 1480 | + ",".join( 1481 | ["0.0"] 1482 | + [ 1483 | f"{float(ivl) / float(pre_ivl):.2f}" if pre_ivl != "0" else "0.0" 1484 | for ivl, pre_ivl in zip( 1485 | t_history.split(",")[1:], 1486 | t_history.split(",")[:-1], 1487 | ) 1488 | ] 1489 | ) 1490 | + "\n" 1491 | ) 1492 | preview_text += f"difficulty history: {d_history}" 1493 | return preview_text 1494 | 1495 | def predict_memory_states(self): 1496 | my_collection = Collection(self.w, self.float_delta_t) 1497 | 1498 | stabilities, difficulties = my_collection.batch_predict(self.dataset) 1499 | stabilities = map(lambda x: round(x, 2), stabilities) 1500 | difficulties = map(lambda x: round(x, 2), difficulties) 1501 | self.dataset["stability"] = list(stabilities) 1502 | self.dataset["difficulty"] = list(difficulties) 1503 | prediction = self.dataset.groupby(by=["t_history", "r_history"]).agg( 1504 | {"stability": "mean", "difficulty": "mean", "review_time": "count"} 1505 | ) 1506 | prediction.reset_index(inplace=True) 1507 | prediction.sort_values(by=["r_history"], inplace=True) 1508 | prediction.rename(columns={"review_time": "count"}, inplace=True) 1509 | prediction.to_csv("./prediction.tsv", sep="\t", index=None) 1510 | prediction["difficulty"] = prediction["difficulty"].map(lambda x: int(round(x))) 1511 | self.difficulty_distribution = ( 1512 | prediction.groupby(by=["difficulty"])["count"].sum() 1513 | / prediction["count"].sum() 1514 | ) 1515 | self.difficulty_distribution_padding = np.zeros(10) 1516 | for i in range(10): 1517 | if i + 1 in self.difficulty_distribution.index: 1518 | self.difficulty_distribution_padding[i] = ( 1519 | self.difficulty_distribution.loc[i + 1] 1520 | ) 1521 | return self.difficulty_distribution 1522 | 1523 | def find_optimal_retention( 1524 | self, 1525 | learn_span=365, 1526 | max_ivl=36500, 1527 | verbose=True, 1528 | ): 1529 | """should not be called before predict_memory_states""" 1530 | if verbose: 1531 | print("Learn buttons: ", self.learn_buttons) 1532 | print("Review buttons: ", self.review_buttons) 1533 | print("First rating prob: ", self.first_rating_prob) 1534 | print("Review rating prob: ", self.review_rating_prob) 1535 | print("Learning step transitions: ", self.learning_step_transitions) 1536 | print("Relearning step transitions: ", self.relearning_step_transitions) 1537 | print("State rating costs: ", self.state_rating_costs) 1538 | 1539 | simulate_config = { 1540 | "w": self.w, 1541 | "deck_size": learn_span * 10, 1542 | "learn_span": learn_span, 1543 | "max_cost_perday": math.inf, 1544 | "learn_limit_perday": 10, 1545 | "review_limit_perday": math.inf, 1546 | "max_ivl": max_ivl, 1547 | "first_rating_prob": self.first_rating_prob, 1548 | "review_rating_prob": self.review_rating_prob, 1549 | "learning_step_transitions": self.learning_step_transitions, 1550 | "relearning_step_transitions": self.relearning_step_transitions, 1551 | "state_rating_costs": self.state_rating_costs, 1552 | } 1553 | self.optimal_retention = optimal_retention(**simulate_config) 1554 | 1555 | tqdm.write( 1556 | f"\n-----suggested retention (experimental): {self.optimal_retention:.2f}-----" 1557 | ) 1558 | 1559 | if not verbose: 1560 | return () 1561 | 1562 | ( 1563 | _, 1564 | review_cnt_per_day, 1565 | learn_cnt_per_day, 1566 | memorized_cnt_per_day, 1567 | cost_per_day, 1568 | _, 1569 | ) = simulate(**simulate_config) 1570 | 1571 | def moving_average(data, window_size=365 // 20): 1572 | weights = np.ones(window_size) / window_size 1573 | return np.convolve(data, weights, mode="valid") 1574 | 1575 | fig1 = plt.figure() 1576 | ax = fig1.gca() 1577 | ax.plot( 1578 | moving_average(review_cnt_per_day), 1579 | label=f"R={self.optimal_retention*100:.0f}%", 1580 | ) 1581 | ax.set_title("Review Count per Day") 1582 | ax.legend() 1583 | ax.grid(True) 1584 | fig2 = plt.figure() 1585 | ax = fig2.gca() 1586 | ax.plot( 1587 | moving_average(learn_cnt_per_day), 1588 | label=f"R={self.optimal_retention*100:.0f}%", 1589 | ) 1590 | ax.set_title("Learn Count per Day") 1591 | ax.legend() 1592 | ax.grid(True) 1593 | fig3 = plt.figure() 1594 | ax = fig3.gca() 1595 | ax.plot( 1596 | np.cumsum(learn_cnt_per_day), label=f"R={self.optimal_retention*100:.0f}%" 1597 | ) 1598 | ax.set_title("Cumulative Learn Count") 1599 | ax.legend() 1600 | ax.grid(True) 1601 | fig4 = plt.figure() 1602 | ax = fig4.gca() 1603 | ax.plot(memorized_cnt_per_day, label=f"R={self.optimal_retention*100:.0f}%") 1604 | ax.set_title("Memorized Count per Day") 1605 | ax.legend() 1606 | ax.grid(True) 1607 | 1608 | fig5 = plt.figure() 1609 | ax = fig5.gca() 1610 | ax.plot(cost_per_day, label=f"R={self.optimal_retention*100:.0f}%") 1611 | ax.set_title("Cost per Day") 1612 | ax.legend() 1613 | ax.grid(True) 1614 | 1615 | fig6 = workload_graph(simulate_config) 1616 | 1617 | return (fig1, fig2, fig3, fig4, fig5, fig6) 1618 | 1619 | def evaluate(self, save_to_file=True): 1620 | my_collection = Collection(DEFAULT_PARAMETER, self.float_delta_t) 1621 | if "tensor" not in self.dataset.columns: 1622 | self.dataset["tensor"] = self.dataset.progress_apply( 1623 | lambda x: lineToTensor( 1624 | list(zip([x["t_history"]], [x["r_history"]]))[0] 1625 | ), 1626 | axis=1, 1627 | ) 1628 | stabilities, difficulties = my_collection.batch_predict(self.dataset) 1629 | self.dataset["stability"] = stabilities 1630 | self.dataset["difficulty"] = difficulties 1631 | self.dataset["p"] = power_forgetting_curve( 1632 | self.dataset["delta_t"], 1633 | self.dataset["stability"], 1634 | -my_collection.model.w[20].detach().numpy(), 1635 | ) 1636 | self.dataset["log_loss"] = self.dataset.apply( 1637 | lambda row: -np.log(row["p"]) if row["y"] == 1 else -np.log(1 - row["p"]), 1638 | axis=1, 1639 | ) 1640 | if "weights" not in self.dataset.columns: 1641 | self.dataset["weights"] = 1 1642 | self.dataset["log_loss"] = ( 1643 | self.dataset["log_loss"] 1644 | * self.dataset["weights"] 1645 | / self.dataset["weights"].mean() 1646 | ) 1647 | loss_before = self.dataset["log_loss"].mean() 1648 | 1649 | my_collection = Collection(self.w, self.float_delta_t) 1650 | stabilities, difficulties = my_collection.batch_predict(self.dataset) 1651 | self.dataset["stability"] = stabilities 1652 | self.dataset["difficulty"] = difficulties 1653 | self.dataset["p"] = power_forgetting_curve( 1654 | self.dataset["delta_t"], 1655 | self.dataset["stability"], 1656 | -my_collection.model.w[20].detach().numpy(), 1657 | ) 1658 | self.dataset["log_loss"] = self.dataset.apply( 1659 | lambda row: -np.log(row["p"]) if row["y"] == 1 else -np.log(1 - row["p"]), 1660 | axis=1, 1661 | ) 1662 | self.dataset["log_loss"] = ( 1663 | self.dataset["log_loss"] 1664 | * self.dataset["weights"] 1665 | / self.dataset["weights"].mean() 1666 | ) 1667 | loss_after = self.dataset["log_loss"].mean() 1668 | if save_to_file: 1669 | tmp = self.dataset.copy() 1670 | tmp["stability"] = tmp["stability"].map(lambda x: round(x, 2)) 1671 | tmp["difficulty"] = tmp["difficulty"].map(lambda x: round(x, 2)) 1672 | tmp["p"] = tmp["p"].map(lambda x: round(x, 2)) 1673 | tmp["log_loss"] = tmp["log_loss"].map(lambda x: round(x, 2)) 1674 | tmp.rename(columns={"p": "retrievability"}, inplace=True) 1675 | tmp[ 1676 | [ 1677 | "review_time", 1678 | "card_id", 1679 | "review_date", 1680 | "r_history", 1681 | "t_history", 1682 | "delta_t", 1683 | "review_rating", 1684 | "stability", 1685 | "difficulty", 1686 | "retrievability", 1687 | "log_loss", 1688 | ] 1689 | ].to_csv("./evaluation.tsv", sep="\t", index=False) 1690 | del tmp 1691 | return loss_before, loss_after 1692 | 1693 | def calibration_graph(self, dataset=None, verbose=True): 1694 | if dataset is None: 1695 | dataset = self.dataset 1696 | fig1 = plt.figure() 1697 | rmse = rmse_matrix(dataset) 1698 | if verbose: 1699 | tqdm.write(f"RMSE(bins): {rmse:.4f}") 1700 | metrics_all = {} 1701 | metrics = plot_brier( 1702 | dataset["p"], dataset["y"], bins=20, ax=fig1.add_subplot(111) 1703 | ) 1704 | metrics["RMSE(bins)"] = rmse 1705 | metrics["AUC"] = ( 1706 | roc_auc_score(y_true=dataset["y"], y_score=dataset["p"]) 1707 | if len(dataset["y"].unique()) == 2 1708 | else np.nan 1709 | ) 1710 | metrics["LogLoss"] = log_loss(y_true=dataset["y"], y_pred=dataset["p"]) 1711 | metrics_all["all"] = metrics 1712 | fig2 = plt.figure(figsize=(16, 12)) 1713 | for last_rating in (1, 2, 3, 4): 1714 | calibration_data = dataset[dataset["last_rating"] == last_rating] 1715 | if calibration_data.empty: 1716 | continue 1717 | rmse = rmse_matrix(calibration_data) 1718 | if verbose: 1719 | tqdm.write(f"\nLast rating: {last_rating}") 1720 | tqdm.write(f"RMSE(bins): {rmse:.4f}") 1721 | metrics = plot_brier( 1722 | calibration_data["p"], 1723 | calibration_data["y"], 1724 | bins=20, 1725 | ax=fig2.add_subplot(2, 2, int(last_rating)), 1726 | title=f"Last rating: {last_rating}", 1727 | ) 1728 | metrics["RMSE(bins)"] = rmse 1729 | metrics["AUC"] = ( 1730 | roc_auc_score( 1731 | y_true=calibration_data["y"], 1732 | y_score=calibration_data["p"], 1733 | ) 1734 | if len(calibration_data["y"].unique()) == 2 1735 | else np.nan 1736 | ) 1737 | metrics["LogLoss"] = log_loss( 1738 | y_true=calibration_data["y"], y_pred=calibration_data["p"] 1739 | ) 1740 | metrics_all[last_rating] = metrics 1741 | 1742 | fig3 = plt.figure() 1743 | self.calibration_helper( 1744 | dataset[["stability", "p", "y"]].copy(), 1745 | "stability", 1746 | lambda x: math.pow(1.2, math.floor(math.log(x, 1.2))), 1747 | True, 1748 | fig3.add_subplot(111), 1749 | ) 1750 | 1751 | fig4 = plt.figure(figsize=(16, 12)) 1752 | for last_rating in (1, 2, 3, 4): 1753 | calibration_data = dataset[dataset["last_rating"] == last_rating] 1754 | if calibration_data.empty: 1755 | continue 1756 | self.calibration_helper( 1757 | calibration_data[["stability", "p", "y"]].copy(), 1758 | "stability", 1759 | lambda x: math.pow(1.2, math.floor(math.log(x, 1.2))), 1760 | True, 1761 | fig4.add_subplot(2, 2, int(last_rating)), 1762 | ) 1763 | fig5 = plt.figure() 1764 | self.calibration_helper( 1765 | dataset[["difficulty", "p", "y"]].copy(), 1766 | "difficulty", 1767 | lambda x: round(x), 1768 | False, 1769 | fig5.add_subplot(111), 1770 | ) 1771 | return metrics_all, (fig1, fig2, fig3, fig4, fig5) 1772 | 1773 | def calibration_helper(self, calibration_data, key, bin_func, semilogx, ax1): 1774 | ax2 = ax1.twinx() 1775 | lns = [] 1776 | 1777 | def to_percent(temp, position): 1778 | return "%1.0f" % (100 * temp) + "%" 1779 | 1780 | calibration_data["bin"] = calibration_data[key].map(bin_func) 1781 | calibration_group = calibration_data.groupby("bin").count() 1782 | 1783 | lns1 = ax1.bar( 1784 | x=calibration_group.index, 1785 | height=calibration_group["y"], 1786 | width=calibration_group.index / 5.5 if key == "stability" else 0.8, 1787 | ec="k", 1788 | lw=0.2, 1789 | label="Number of predictions", 1790 | alpha=0.5, 1791 | ) 1792 | ax1.set_ylabel("Number of predictions") 1793 | ax1.set_xlabel(key.title()) 1794 | if semilogx: 1795 | ax1.semilogx() 1796 | lns.append(lns1) 1797 | 1798 | calibration_group = calibration_data.groupby(by="bin").agg("mean") 1799 | lns2 = ax2.plot(calibration_group["y"], label="Actual retention") 1800 | lns3 = ax2.plot(calibration_group["p"], label="Predicted retention") 1801 | ax2.set_ylabel("Retention") 1802 | ax2.set_ylim(0, 1) 1803 | lns.append(lns2[0]) 1804 | lns.append(lns3[0]) 1805 | 1806 | labs = [l.get_label() for l in lns] 1807 | ax2.legend(lns, labs, loc="lower right") 1808 | ax2.grid(linestyle="--") 1809 | ax2.yaxis.set_major_formatter(ticker.FuncFormatter(to_percent)) 1810 | return ax1 1811 | 1812 | def formula_analysis(self): 1813 | analysis_df = self.dataset[self.dataset["i"] > 2].copy() 1814 | analysis_df["tensor"] = analysis_df["tensor"].map(lambda x: x[:-1]) 1815 | my_collection = Collection(self.w, self.float_delta_t) 1816 | stabilities, difficulties = my_collection.batch_predict(analysis_df) 1817 | analysis_df["last_s"] = stabilities 1818 | analysis_df["last_d"] = difficulties 1819 | analysis_df["last_delta_t"] = analysis_df["t_history"].map( 1820 | lambda x: ( 1821 | int(x.split(",")[-1]) 1822 | if not self.float_delta_t 1823 | else float(x.split(",")[-1]) 1824 | ) 1825 | ) 1826 | analysis_df["last_r"] = power_forgetting_curve( 1827 | analysis_df["delta_t"], analysis_df["last_s"] 1828 | ) 1829 | analysis_df["last_s_bin"] = analysis_df["last_s"].map( 1830 | lambda x: math.pow(1.2, math.floor(math.log(x, 1.2))) 1831 | ) 1832 | analysis_df["last_d_bin"] = analysis_df["last_d"].map(lambda x: round(x)) 1833 | bins = 20 1834 | analysis_df["last_r_bin"] = analysis_df["last_r"].map( 1835 | lambda x: ( 1836 | np.log( 1837 | np.minimum(np.floor(np.exp(np.log(bins + 1) * x) - 1), bins - 1) + 1 1838 | ) 1839 | / np.log(bins) 1840 | ).round(3) 1841 | ) 1842 | figs = [] 1843 | for group_key in ("last_s_bin", "last_d_bin", "last_r_bin"): 1844 | for last_rating in (1, 3): 1845 | analysis_group = ( 1846 | analysis_df[analysis_df["last_rating"] == last_rating] 1847 | .groupby( 1848 | by=["last_s_bin", "last_d_bin", "last_r_bin", "delta_t"], 1849 | group_keys=True, 1850 | as_index=False, 1851 | ) 1852 | .agg( 1853 | { 1854 | "y": ["mean", "count"], 1855 | "p": "mean", 1856 | "stability": "mean", 1857 | "last_d": "mean", 1858 | } 1859 | ) 1860 | ) 1861 | analysis_group.columns = [ 1862 | "_".join(col_name).rstrip("_") 1863 | for col_name in analysis_group.columns 1864 | ] 1865 | 1866 | def cal_stability(tmp): 1867 | delta_t = tmp["delta_t"] 1868 | recall = tmp["y_mean"] 1869 | count = tmp["y_count"] 1870 | total_count = sum(count) 1871 | 1872 | tmp["true_s"] = fit_stability(delta_t, recall, count) 1873 | tmp["predicted_s"] = np.average( 1874 | tmp["stability_mean"], weights=count 1875 | ) 1876 | tmp["total_count"] = total_count 1877 | return tmp 1878 | 1879 | analysis_group = analysis_group.groupby( 1880 | by=[group_key], group_keys=False 1881 | ).apply(cal_stability) 1882 | analysis_group.dropna(inplace=True) 1883 | analysis_group.drop_duplicates(subset=[group_key], inplace=True) 1884 | analysis_group.sort_values(by=[group_key], inplace=True) 1885 | mape = mean_absolute_percentage_error( 1886 | analysis_group["true_s"], 1887 | analysis_group["predicted_s"], 1888 | sample_weight=analysis_group["total_count"], 1889 | ) 1890 | fig = plt.figure() 1891 | ax1 = fig.add_subplot(111) 1892 | ax1.set_title(f"MAPE={mape:.2f}, last rating={last_rating}") 1893 | ax1.scatter( 1894 | analysis_group[group_key], 1895 | analysis_group["true_s"], 1896 | s=np.sqrt(analysis_group["total_count"]), 1897 | label="True stability", 1898 | alpha=0.5, 1899 | ) 1900 | ax1.plot( 1901 | analysis_group[group_key], 1902 | analysis_group["predicted_s"], 1903 | label="Predicted stability", 1904 | color="orange", 1905 | ) 1906 | ax1.set_ylim(0, analysis_group["predicted_s"].max() * 1.1) 1907 | ax1.legend(loc="upper left") 1908 | ax1.set_xlabel(group_key) 1909 | if group_key == "last_s_bin": 1910 | ax1.set_ylim( 1911 | max(analysis_group["predicted_s"].min(), S_MIN), 1912 | analysis_group["predicted_s"].max() * 1.1, 1913 | ) 1914 | ax1.set_xscale("log") 1915 | ax1.set_yscale("log") 1916 | ax1.set_ylabel("Next Stability (days)") 1917 | ax1.grid() 1918 | ax1.xaxis.set_major_formatter(ticker.FormatStrFormatter("%.2f")) 1919 | figs.append(fig) 1920 | return figs 1921 | 1922 | def bw_matrix(self): 1923 | B_W_Metric_raw = self.dataset[["difficulty", "stability", "p", "y"]].copy() 1924 | B_W_Metric_raw["s_bin"] = B_W_Metric_raw["stability"].map( 1925 | lambda x: round(math.pow(1.4, math.floor(math.log(x, 1.4))), 2) 1926 | ) 1927 | B_W_Metric_raw["d_bin"] = B_W_Metric_raw["difficulty"].map( 1928 | lambda x: int(round(x)) 1929 | ) 1930 | B_W_Metric = ( 1931 | B_W_Metric_raw.groupby(by=["s_bin", "d_bin"]).agg("mean").reset_index() 1932 | ) 1933 | B_W_Metric_count = ( 1934 | B_W_Metric_raw.groupby(by=["s_bin", "d_bin"]).agg("count").reset_index() 1935 | ) 1936 | B_W_Metric["B-W"] = B_W_Metric["p"] - B_W_Metric["y"] 1937 | n = len(self.dataset) 1938 | bins = len(B_W_Metric) 1939 | B_W_Metric_pivot = B_W_Metric[ 1940 | B_W_Metric_count["p"] > max(50, n / (3 * bins)) 1941 | ].pivot(index="s_bin", columns="d_bin", values="B-W") 1942 | return ( 1943 | B_W_Metric_pivot.apply(pd.to_numeric) 1944 | .style.background_gradient(cmap="seismic", axis=None, vmin=-0.2, vmax=0.2) 1945 | .format("{:.2%}", na_rep="") 1946 | ) 1947 | 1948 | def compare_with_sm2(self): 1949 | self.dataset["sm2_ivl"] = self.dataset["tensor"].map(sm2) 1950 | self.dataset["sm2_p"] = np.exp( 1951 | np.log(0.9) * self.dataset["delta_t"] / self.dataset["sm2_ivl"] 1952 | ) 1953 | self.dataset["log_loss"] = self.dataset.apply( 1954 | lambda row: ( 1955 | -np.log(row["sm2_p"]) if row["y"] == 1 else -np.log(1 - row["sm2_p"]) 1956 | ), 1957 | axis=1, 1958 | ) 1959 | tqdm.write(f"Loss of SM-2: {self.dataset['log_loss'].mean():.4f}") 1960 | dataset = self.dataset[["sm2_p", "p", "y"]].copy() 1961 | dataset.rename(columns={"sm2_p": "R (SM2)", "p": "R (FSRS)"}, inplace=True) 1962 | fig1 = plt.figure() 1963 | plot_brier( 1964 | dataset["R (SM2)"], 1965 | dataset["y"], 1966 | bins=20, 1967 | ax=fig1.add_subplot(111), 1968 | ) 1969 | universal_metrics, fig2 = cross_comparison(dataset, "SM2", "FSRS") 1970 | 1971 | tqdm.write(f"Universal Metric of FSRS: {universal_metrics[0]:.4f}") 1972 | tqdm.write(f"Universal Metric of SM2: {universal_metrics[1]:.4f}") 1973 | 1974 | return fig1, fig2 1975 | 1976 | 1977 | # code from https://github.com/papousek/duolingo-halflife-regression/blob/master/evaluation.py 1978 | def load_brier(predictions, real, bins=20): 1979 | # https://www.scirp.org/pdf/ojs_2021101415023495.pdf 1980 | # Note that my implementation isn't exactly the same as in the paper, but it still has good coverage, better than Clopper-Pearson 1981 | # I also made it possible to deal with k=0 and k=n, which was an issue with how this method is described in the paper 1982 | def likelihood_interval(k, n, alpha=0.05): 1983 | def log_likelihood(p: np.ndarray, k, n): 1984 | assert k <= n 1985 | p_hat = k / n 1986 | 1987 | def log_likelihood_f(k, n, p): 1988 | one_minus_p = np.ones_like(p) - p 1989 | if k == 0: 1990 | return n * np.log(one_minus_p) 1991 | elif k == n: 1992 | return k * np.log(p) 1993 | else: 1994 | return k * np.log(p) + (n - k) * np.log(one_minus_p) 1995 | 1996 | return log_likelihood_f(k, n, p) - log_likelihood_f(k, n, p_hat) 1997 | 1998 | def calc(x: np.ndarray, y: np.ndarray, target_p: float): 1999 | def loss(guess_y: float, target_p: float) -> float: 2000 | # Find segments where the horizontal line intersects the curve 2001 | # This creates a boolean array where True indicates a potential intersection 2002 | intersect_segments = ((y[:-1] <= guess_y) & (y[1:] >= guess_y)) | ( 2003 | (y[:-1] >= guess_y) & (y[1:] <= guess_y) 2004 | ) 2005 | 2006 | # Get indices of segments where intersections occur 2007 | intersection_indices = np.where(intersect_segments)[0] 2008 | 2009 | # If we don't have intersections, return a large error 2010 | if len(intersection_indices) < 2: 2011 | return 1e100 2012 | 2013 | # Find the first two intersection points (we only need two for a connected curve) 2014 | intersection_points = [] 2015 | 2016 | for idx in intersection_indices[ 2017 | :2 2018 | ]: # Take at most first two intersections 2019 | # Linear interpolation to find the x value at the intersection 2020 | x1, x2 = x[idx], x[idx + 1] 2021 | y1, y2 = y[idx], y[idx + 1] 2022 | 2023 | # If points are exactly the same, just take the x 2024 | if y1 == y2: 2025 | intersection_points.append(x1) 2026 | else: 2027 | # Linear interpolation 2028 | t = (guess_y - y1) / (y2 - y1) 2029 | intersection_x = x1 + t * (x2 - x1) 2030 | intersection_points.append(intersection_x) 2031 | 2032 | # Get the range bounds 2033 | x_low, x_high = min(intersection_points), max(intersection_points) 2034 | 2035 | # Find indices of x values that fall within our range 2036 | in_range = (x >= x_low) & (x <= x_high) 2037 | 2038 | # Calculate the sum of probabilities in the range 2039 | probability_sum = np.sum(y[in_range]) 2040 | 2041 | # Return the absolute difference from target probability 2042 | return abs(probability_sum - target_p) 2043 | 2044 | def bracket(xa, xb, maxiter, target_p): 2045 | u_lim = xa 2046 | l_lim = xb 2047 | 2048 | grow_limit = 100.0 2049 | gold = 1.6180339 2050 | verysmall_num = 1e-21 2051 | 2052 | fa = loss(xa, target_p) 2053 | fb = loss(xb, target_p) 2054 | funccalls = 2 2055 | 2056 | if fa < fb: # Switch so fa > fb 2057 | xa, xb = xb, xa 2058 | fa, fb = fb, fa 2059 | xc = max(min(xb + gold * (xb - xa), u_lim), l_lim) 2060 | fc = loss(xc, target_p) 2061 | funccalls += 1 2062 | 2063 | iter = 0 2064 | while fc < fb: 2065 | tmp1 = (xb - xa) * (fb - fc) 2066 | tmp2 = (xb - xc) * (fb - fa) 2067 | val = tmp2 - tmp1 2068 | if np.abs(val) < verysmall_num: 2069 | denom = 2.0 * verysmall_num 2070 | else: 2071 | denom = 2.0 * val 2072 | w = max( 2073 | min( 2074 | (xb - ((xb - xc) * tmp2 - (xb - xa) * tmp1) / denom), u_lim 2075 | ), 2076 | l_lim, 2077 | ) 2078 | wlim = max(min(xb + grow_limit * (xc - xb), u_lim), l_lim) 2079 | 2080 | if iter > maxiter: 2081 | print("Failed to converge") 2082 | break 2083 | 2084 | iter += 1 2085 | if (w - xc) * (xb - w) > 0.0: 2086 | fw = loss(w, target_p) 2087 | funccalls += 1 2088 | if fw < fc: 2089 | xa = max(min(xb, u_lim), l_lim) 2090 | xb = max(min(w, u_lim), l_lim) 2091 | fa = fb 2092 | fb = fw 2093 | break 2094 | elif fw > fb: 2095 | xc = max(min(w, u_lim), l_lim) 2096 | fc = fw 2097 | break 2098 | w = max(min(xc + gold * (xc - xb), u_lim), l_lim) 2099 | fw = loss(w, target_p) 2100 | funccalls += 1 2101 | elif (w - wlim) * (wlim - xc) >= 0.0: 2102 | w = wlim 2103 | fw = loss(w, target_p) 2104 | funccalls += 1 2105 | elif (w - wlim) * (xc - w) > 0.0: 2106 | fw = loss(w, target_p) 2107 | funccalls += 1 2108 | if fw < fc: 2109 | xb = max(min(xc, u_lim), l_lim) 2110 | xc = max(min(w, u_lim), l_lim) 2111 | w = max(min(xc + gold * (xc - xb), u_lim), l_lim) 2112 | fb = fc 2113 | fc = fw 2114 | fw = loss(w, target_p) 2115 | funccalls += 1 2116 | else: 2117 | w = max(min(xc + gold * (xc - xb), u_lim), l_lim) 2118 | fw = loss(w, target_p) 2119 | funccalls += 1 2120 | xa = max(min(xb, u_lim), l_lim) 2121 | xb = max(min(xc, u_lim), l_lim) 2122 | xc = max(min(w, u_lim), l_lim) 2123 | fa = fb 2124 | fb = fc 2125 | fc = fw 2126 | 2127 | return xa, xb, xc, fa, fb, fc, funccalls 2128 | 2129 | def brent_minimization(tol, maxiter): 2130 | mintol = 1.0e-11 2131 | cg = 0.3819660 2132 | 2133 | xa, xb, xc, fa, fb, fc, funccalls = bracket( 2134 | xa=min(y), xb=max(y), maxiter=maxiter, target_p=target_p 2135 | ) 2136 | 2137 | ################################# 2138 | # BEGIN 2139 | ################################# 2140 | x = w = v = xb 2141 | fw = fv = fx = fb 2142 | if xa < xc: 2143 | a = xa 2144 | b = xc 2145 | else: 2146 | a = xc 2147 | b = xa 2148 | deltax = 0.0 2149 | iter = 0 2150 | 2151 | while iter < maxiter: 2152 | tol1 = tol * np.abs(x) + mintol 2153 | tol2 = 2.0 * tol1 2154 | xmid = 0.5 * (a + b) 2155 | # check for convergence 2156 | if np.abs(x - xmid) < (tol2 - 0.5 * (b - a)): 2157 | break 2158 | 2159 | if np.abs(deltax) <= tol1: 2160 | if x >= xmid: 2161 | deltax = a - x # do a golden section step 2162 | else: 2163 | deltax = b - x 2164 | rat = cg * deltax 2165 | else: # do a parabolic step 2166 | tmp1 = (x - w) * (fx - fv) 2167 | tmp2 = (x - v) * (fx - fw) 2168 | p = (x - v) * tmp2 - (x - w) * tmp1 2169 | tmp2 = 2.0 * (tmp2 - tmp1) 2170 | if tmp2 > 0.0: 2171 | p = -p 2172 | tmp2 = np.abs(tmp2) 2173 | dx_temp = deltax 2174 | deltax = rat 2175 | # check parabolic fit 2176 | if ( 2177 | (p > tmp2 * (a - x)) 2178 | and (p < tmp2 * (b - x)) 2179 | and (np.abs(p) < np.abs(0.5 * tmp2 * dx_temp)) 2180 | ): 2181 | rat = p * 1.0 / tmp2 # if parabolic step is useful 2182 | u = x + rat 2183 | if (u - a) < tol2 or (b - u) < tol2: 2184 | if xmid - x >= 0: 2185 | rat = tol1 2186 | else: 2187 | rat = -tol1 2188 | else: 2189 | if x >= xmid: 2190 | deltax = a - x # if it's not do a golden section step 2191 | else: 2192 | deltax = b - x 2193 | rat = cg * deltax 2194 | 2195 | if np.abs(rat) < tol1: # update by at least tol1 2196 | if rat >= 0: 2197 | u = x + tol1 2198 | else: 2199 | u = x - tol1 2200 | else: 2201 | u = x + rat 2202 | fu = loss(u, target_p) # calculate new output value 2203 | funccalls += 1 2204 | 2205 | if fu > fx: # if it's bigger than current 2206 | if u < x: 2207 | a = u 2208 | else: 2209 | b = u 2210 | if (fu <= fw) or (w == x): 2211 | v = w 2212 | w = u 2213 | fv = fw 2214 | fw = fu 2215 | elif (fu <= fv) or (v == x) or (v == w): 2216 | v = u 2217 | fv = fu 2218 | else: 2219 | if u >= x: 2220 | a = x 2221 | else: 2222 | b = x 2223 | v = w 2224 | w = x 2225 | x = u 2226 | fv = fw 2227 | fw = fx 2228 | fx = fu 2229 | 2230 | iter += 1 2231 | # print(f'Iteration={iter}') 2232 | # print(f'x={x:.3f}') 2233 | ################################# 2234 | # END 2235 | ################################# 2236 | 2237 | xmin = x 2238 | fval = fx 2239 | 2240 | success = not (np.isnan(fval) or np.isnan(xmin)) and (0 <= xmin <= 1) 2241 | 2242 | if success: 2243 | # print(f'Loss function called {funccalls} times') 2244 | return xmin 2245 | else: 2246 | raise Exception( 2247 | "The algorithm terminated without finding a valid value." 2248 | ) 2249 | 2250 | best_guess_y = brent_minimization(1e-5, 50) 2251 | 2252 | intersect_segments = ( 2253 | (y[:-1] <= best_guess_y) & (y[1:] >= best_guess_y) 2254 | ) | ((y[:-1] >= best_guess_y) & (y[1:] <= best_guess_y)) 2255 | intersection_indices = np.where(intersect_segments)[0] 2256 | intersection_points = [] 2257 | 2258 | for idx in intersection_indices[:2]: 2259 | x1, x2 = x[idx], x[idx + 1] 2260 | y1, y2 = y[idx], y[idx + 1] 2261 | if y1 == y2: 2262 | intersection_points.append(x1) 2263 | else: 2264 | t = (best_guess_y - y1) / (y2 - y1) 2265 | intersection_x = x1 + t * (x2 - x1) 2266 | intersection_points.append(intersection_x) 2267 | 2268 | x_low, x_high = min(intersection_points), max(intersection_points) 2269 | in_range = (x >= x_low) & (x <= x_high) 2270 | probability_sum = np.sum(y[in_range]) 2271 | return x_low, x_high, probability_sum 2272 | 2273 | p_hat = k / n 2274 | # continuity correction 2275 | if k == 0 or k == n: 2276 | k, n = k + 0.5, n + 1 2277 | 2278 | probs = np.arange(1e-5, 1, 1e-5) 2279 | 2280 | likelihoods = np.exp(log_likelihood(probs, k, n)) 2281 | likelihoods = np.asarray(likelihoods) 2282 | y = likelihoods / np.sum(likelihoods) 2283 | 2284 | x_low_cred, x_high_cred, probsum = calc(probs, y, 1 - alpha) 2285 | assert 0 <= probsum <= 1 2286 | 2287 | if p_hat == 1.0: 2288 | x_high_cred = 1.0 2289 | elif p_hat == 0.0: 2290 | x_low_cred = 0.0 2291 | 2292 | assert not np.isnan(x_low_cred) 2293 | assert not np.isnan(x_high_cred) 2294 | assert ( 2295 | x_low_cred <= p_hat <= x_high_cred 2296 | ), f"{x_low_cred}, {p_hat}, {k / n}, {x_high_cred}" 2297 | return x_low_cred, x_high_cred 2298 | 2299 | counts = np.zeros(bins) 2300 | correct = np.zeros(bins) 2301 | prediction = np.zeros(bins) 2302 | 2303 | two_d_list = [[] for _ in range(bins)] 2304 | 2305 | def get_bin(x, bins=bins): 2306 | return np.floor(np.exp(np.log(bins + 1) * x)) - 1 2307 | 2308 | for p, r in zip(predictions, real): 2309 | bin = int(min(get_bin(p), bins - 1)) 2310 | counts[bin] += 1 2311 | correct[bin] += r 2312 | prediction[bin] += p 2313 | two_d_list[bin].append(r) # for confidence interval calculations 2314 | 2315 | np.seterr(invalid="ignore") 2316 | prediction_means = prediction / counts 2317 | real_means = correct / counts 2318 | size = len(predictions) 2319 | answer_mean = sum(correct) / size 2320 | 2321 | real_means_upper = [] 2322 | real_means_lower = [] 2323 | for n in range(len(two_d_list)): 2324 | if len(two_d_list[n]) > 0: 2325 | lower_bound, upper_bound = likelihood_interval( 2326 | sum(two_d_list[n]), len(two_d_list[n]) 2327 | ) 2328 | else: 2329 | lower_bound, upper_bound = float("NaN"), float("NaN") 2330 | real_means_upper.append(upper_bound) 2331 | real_means_lower.append(lower_bound) 2332 | 2333 | assert len(real_means_lower) == len(prediction_means) == len(real_means_upper) 2334 | # sanity check 2335 | for n in range(len(real_means)): 2336 | # check that the mean is within the bounds, unless they are NaNs 2337 | if not np.isnan(real_means_lower[n]): 2338 | assert ( 2339 | real_means_lower[n] <= real_means[n] <= real_means_upper[n] 2340 | ), f"{real_means_lower[n]:4f}, {real_means[n]:4f}, {real_means_upper[n]:4f}" 2341 | 2342 | return { 2343 | "reliability": sum(counts * (real_means - prediction_means) ** 2) / size, 2344 | "resolution": sum(counts * (real_means - answer_mean) ** 2) / size, 2345 | "uncertainty": answer_mean * (1 - answer_mean), 2346 | "detail": { 2347 | "bin_count": bins, 2348 | "bin_counts": counts, 2349 | "bin_prediction_means": prediction_means, 2350 | "bin_real_means_upper_bounds": real_means_upper, 2351 | "bin_real_means_lower_bounds": real_means_lower, 2352 | "bin_real_means": real_means, 2353 | }, 2354 | } 2355 | 2356 | 2357 | def plot_brier(predictions, real, bins=20, ax=None, title=None): 2358 | y, p = zip(*sorted(zip(real, predictions), key=lambda x: x[1])) 2359 | observation = lowess( 2360 | y, p, it=0, delta=0.01 * (max(p) - min(p)), is_sorted=True, return_sorted=False 2361 | ) 2362 | ici = np.mean(np.abs(observation - p)) 2363 | e_50 = np.median(np.abs(observation - p)) 2364 | e_90 = np.quantile(np.abs(observation - p), 0.9) 2365 | e_max = np.max(np.abs(observation - p)) 2366 | brier = load_brier(predictions, real, bins=bins) 2367 | bin_prediction_means = brier["detail"]["bin_prediction_means"] 2368 | 2369 | bin_real_means = brier["detail"]["bin_real_means"] 2370 | bin_real_means_upper_bounds = brier["detail"]["bin_real_means_upper_bounds"] 2371 | bin_real_means_lower_bounds = brier["detail"]["bin_real_means_lower_bounds"] 2372 | bin_real_means_errors_upper = bin_real_means_upper_bounds - bin_real_means 2373 | bin_real_means_errors_lower = bin_real_means - bin_real_means_lower_bounds 2374 | 2375 | bin_counts = brier["detail"]["bin_counts"] 2376 | mask = bin_counts > 0 2377 | r2 = r2_score( 2378 | bin_real_means[mask], 2379 | bin_prediction_means[mask], 2380 | sample_weight=bin_counts[mask], 2381 | ) 2382 | mae = mean_absolute_error( 2383 | bin_real_means[mask], 2384 | bin_prediction_means[mask], 2385 | sample_weight=bin_counts[mask], 2386 | ) 2387 | ax.set_xlim([0, 1]) 2388 | ax.set_ylim([0, 1]) 2389 | ax.grid(True) 2390 | try: 2391 | fit_wls = sm.WLS( 2392 | bin_real_means[mask], 2393 | sm.add_constant(bin_prediction_means[mask]), 2394 | weights=bin_counts[mask], 2395 | ).fit() 2396 | params = fit_wls.params 2397 | y_regression = [params[0] + params[1] * x for x in [0, 1]] 2398 | ax.plot( 2399 | [0, 1], 2400 | y_regression, 2401 | label=f"y = {params[0]:.3f} + {params[1]:.3f}x", 2402 | color="green", 2403 | ) 2404 | except: 2405 | pass 2406 | # ax.plot( 2407 | # bin_prediction_means[mask], 2408 | # bin_correct_means[mask], 2409 | # label="Actual Calibration", 2410 | # color="#1f77b4", 2411 | # marker="*", 2412 | # ) 2413 | assert not any(np.isnan(bin_real_means_errors_upper[mask])) 2414 | assert not any(np.isnan(bin_real_means_errors_lower[mask])) 2415 | ax.errorbar( 2416 | bin_prediction_means[mask], 2417 | bin_real_means[mask], 2418 | yerr=[bin_real_means_errors_lower[mask], bin_real_means_errors_upper[mask]], 2419 | label="Actual Calibration", 2420 | color="#1f77b4", 2421 | ecolor="black", 2422 | elinewidth=1.0, 2423 | capsize=3.5, 2424 | capthick=1.0, 2425 | marker="", 2426 | ) 2427 | # ax.plot(p, observation, label="Lowess Smoothing", color="red") 2428 | ax.plot((0, 1), (0, 1), label="Perfect Calibration", color="#ff7f0e") 2429 | bin_count = brier["detail"]["bin_count"] 2430 | counts = np.array(bin_counts) 2431 | bins = np.log((np.arange(bin_count)) + 1) / np.log(bin_count + 1) 2432 | widths = np.diff(bins) 2433 | widths = np.append(widths, 1 - bins[-1]) 2434 | ax.legend(loc="upper center") 2435 | ax.set_xlabel("Predicted R") 2436 | ax.set_ylabel("Actual R") 2437 | ax2 = ax.twinx() 2438 | ax2.set_ylabel("Number of reviews") 2439 | ax2.bar( 2440 | bins, 2441 | counts, 2442 | width=widths, 2443 | ec="k", 2444 | linewidth=0, 2445 | alpha=0.5, 2446 | label="Number of reviews", 2447 | align="edge", 2448 | ) 2449 | ax2.legend(loc="lower center") 2450 | if title: 2451 | ax.set_title(title) 2452 | metrics = { 2453 | "R-squared": r2, 2454 | "MAE": mae, 2455 | "ICI": ici, 2456 | "E50": e_50, 2457 | "E90": e_90, 2458 | "EMax": e_max, 2459 | } 2460 | return metrics 2461 | 2462 | 2463 | def sm2(history): 2464 | ivl = 0 2465 | ef = 2.5 2466 | reps = 0 2467 | for delta_t, rating in history: 2468 | delta_t = delta_t.item() 2469 | rating = rating.item() + 1 2470 | if rating > 2: 2471 | if reps == 0: 2472 | ivl = 1 2473 | reps = 1 2474 | elif reps == 1: 2475 | ivl = 6 2476 | reps = 2 2477 | else: 2478 | ivl = ivl * ef 2479 | reps += 1 2480 | else: 2481 | ivl = 1 2482 | reps = 0 2483 | ef = max(1.3, ef + (0.1 - (5 - rating) * (0.08 + (5 - rating) * 0.02))) 2484 | ivl = max(1, round(ivl + 0.01)) 2485 | return ivl 2486 | 2487 | 2488 | def cross_comparison(dataset, algoA, algoB): 2489 | if algoA != algoB: 2490 | cross_comparison_record = dataset[[f"R ({algoA})", f"R ({algoB})", "y"]].copy() 2491 | bin_algo = ( 2492 | algoA, 2493 | algoB, 2494 | ) 2495 | pair_algo = [(algoA, algoB), (algoB, algoA)] 2496 | else: 2497 | cross_comparison_record = dataset[[f"R ({algoA})", "y"]].copy() 2498 | bin_algo = (algoA,) 2499 | pair_algo = [(algoA, algoA)] 2500 | 2501 | def get_bin(x, bins=20): 2502 | return ( 2503 | np.log(np.minimum(np.floor(np.exp(np.log(bins + 1) * x) - 1), bins - 1) + 1) 2504 | / np.log(bins) 2505 | ).round(3) 2506 | 2507 | for algo in bin_algo: 2508 | cross_comparison_record[f"{algo}_B-W"] = ( 2509 | cross_comparison_record[f"R ({algo})"] - cross_comparison_record["y"] 2510 | ) 2511 | cross_comparison_record[f"{algo}_bin"] = cross_comparison_record[ 2512 | f"R ({algo})" 2513 | ].map(get_bin) 2514 | 2515 | fig = plt.figure(figsize=(6, 6)) 2516 | ax = fig.gca() 2517 | ax.axhline(y=0.0, color="black", linestyle="-") 2518 | 2519 | universal_metric_list = [] 2520 | 2521 | for algoA, algoB in pair_algo: 2522 | cross_comparison_group = cross_comparison_record.groupby(by=f"{algoA}_bin").agg( 2523 | {"y": ["mean"], f"{algoB}_B-W": ["mean"], f"R ({algoB})": ["mean", "count"]} 2524 | ) 2525 | universal_metric = root_mean_squared_error( 2526 | y_true=cross_comparison_group["y", "mean"], 2527 | y_pred=cross_comparison_group[f"R ({algoB})", "mean"], 2528 | sample_weight=cross_comparison_group[f"R ({algoB})", "count"], 2529 | ) 2530 | cross_comparison_group[f"R ({algoB})", "percent"] = ( 2531 | cross_comparison_group[f"R ({algoB})", "count"] 2532 | / cross_comparison_group[f"R ({algoB})", "count"].sum() 2533 | ) 2534 | ax.scatter( 2535 | cross_comparison_group.index, 2536 | cross_comparison_group[f"{algoB}_B-W", "mean"], 2537 | s=cross_comparison_group[f"R ({algoB})", "percent"] * 1024, 2538 | alpha=0.5, 2539 | ) 2540 | ax.plot( 2541 | cross_comparison_group[f"{algoB}_B-W", "mean"], 2542 | label=f"{algoB} by {algoA}, UM={universal_metric:.4f}", 2543 | ) 2544 | universal_metric_list.append(universal_metric) 2545 | 2546 | ax.legend(loc="lower center") 2547 | ax.grid(linestyle="--") 2548 | ax.set_title(f"{algoA} vs {algoB}") 2549 | ax.set_xlabel("Predicted R") 2550 | ax.set_ylabel("B-W Metric") 2551 | ax.set_xlim(0, 1) 2552 | ax.set_xticks(np.arange(0, 1.1, 0.1)) 2553 | return universal_metric_list, fig 2554 | 2555 | 2556 | def rmse_matrix(df): 2557 | tmp = df.copy() 2558 | 2559 | def count_lapse(r_history, t_history): 2560 | lapse = 0 2561 | for r, t in zip(r_history.split(","), t_history.split(",")): 2562 | if t != "0" and r == "1": 2563 | lapse += 1 2564 | return lapse 2565 | 2566 | tmp["lapse"] = tmp.apply( 2567 | lambda x: count_lapse(x["r_history"], x["t_history"]), axis=1 2568 | ) 2569 | tmp["delta_t"] = tmp["delta_t"].map( 2570 | lambda x: round(2.48 * np.power(3.62, np.floor(np.log(x) / np.log(3.62))), 2) 2571 | ) 2572 | tmp["i"] = tmp["i"].map( 2573 | lambda x: round(1.99 * np.power(1.89, np.floor(np.log(x) / np.log(1.89))), 0) 2574 | ) 2575 | tmp["lapse"] = tmp["lapse"].map( 2576 | lambda x: ( 2577 | round(1.65 * np.power(1.73, np.floor(np.log(x) / np.log(1.73))), 0) 2578 | if x != 0 2579 | else 0 2580 | ) 2581 | ) 2582 | if "weights" not in tmp.columns: 2583 | tmp["weights"] = 1 2584 | tmp = ( 2585 | tmp.groupby(["delta_t", "i", "lapse"]) 2586 | .agg({"y": "mean", "p": "mean", "weights": "sum"}) 2587 | .reset_index() 2588 | ) 2589 | return root_mean_squared_error(tmp["y"], tmp["p"], sample_weight=tmp["weights"]) 2590 | 2591 | 2592 | def wrap_short_term_ratings(r_history, t_history): 2593 | result = [] 2594 | in_zero_sequence = False 2595 | 2596 | for t, r in zip(t_history.split(","), r_history.split(",")): 2597 | if t in ("-1", "0"): 2598 | if not in_zero_sequence: 2599 | result.append("(") 2600 | in_zero_sequence = True 2601 | result.append(r) 2602 | result.append(",") 2603 | else: 2604 | if in_zero_sequence: 2605 | result[-1] = ")," 2606 | in_zero_sequence = False 2607 | result.append(r) 2608 | result.append(",") 2609 | 2610 | if in_zero_sequence: 2611 | result[-1] = ")" 2612 | else: 2613 | result.pop() 2614 | return "".join(result) 2615 | 2616 | 2617 | class FirstOrderMarkovChain: 2618 | def __init__(self, n_states=4): 2619 | """ 2620 | Initialize a first-order Markov chain model 2621 | 2622 | Parameters: 2623 | n_states: Number of states, default is 4 (corresponding to states 1,2,3,4) 2624 | """ 2625 | self.n_states = n_states 2626 | self.transition_matrix = None 2627 | self.initial_distribution = None 2628 | self.transition_counts = None 2629 | self.initial_counts = None 2630 | 2631 | def fit(self, sequences, smoothing=1.0): 2632 | """ 2633 | Fit the Markov chain model based on given sequences 2634 | 2635 | Parameters: 2636 | sequences: List of sequences, each sequence is a list containing 1,2,3,4 2637 | smoothing: Laplace smoothing parameter to avoid zero probability issues 2638 | """ 2639 | # Initialize transition count matrix and initial state counts 2640 | self.transition_counts = np.zeros((self.n_states, self.n_states)) 2641 | self.initial_counts = np.zeros(self.n_states) 2642 | 2643 | # Count transition frequencies and initial state frequencies 2644 | for sequence in sequences: 2645 | if len(sequence) == 0: 2646 | continue 2647 | 2648 | # Record initial state 2649 | self.initial_counts[sequence[0] - 1] += 1 2650 | 2651 | # Record transitions 2652 | for i in range(len(sequence) - 1): 2653 | current_state = sequence[i] - 1 # Convert to 0-indexed 2654 | next_state = sequence[i + 1] - 1 # Convert to 0-indexed 2655 | self.transition_counts[current_state, next_state] += 1 2656 | 2657 | # Apply Laplace smoothing and calculate probabilities 2658 | self.transition_counts += smoothing 2659 | self.initial_counts += smoothing 2660 | 2661 | # Calculate transition probability matrix 2662 | self.transition_matrix = np.zeros((self.n_states, self.n_states)) 2663 | for i in range(self.n_states): 2664 | row_sum = np.sum(self.transition_counts[i]) 2665 | if row_sum > 0: 2666 | self.transition_matrix[i] = self.transition_counts[i] / row_sum 2667 | else: 2668 | # If a state never appears, assume uniform distribution 2669 | self.transition_matrix[i] = np.ones(self.n_states) / self.n_states 2670 | 2671 | # Calculate initial state distribution 2672 | self.initial_distribution = self.initial_counts / np.sum(self.initial_counts) 2673 | 2674 | return self 2675 | 2676 | def generate_sequence(self, length): 2677 | """ 2678 | Generate a new sequence 2679 | 2680 | Parameters: 2681 | length: Length of the sequence to generate 2682 | 2683 | Returns: 2684 | Generated sequence (elements are 1,2,3,4) 2685 | """ 2686 | if self.transition_matrix is None or self.initial_distribution is None: 2687 | raise ValueError("Model not yet fitted, please call the fit method first") 2688 | 2689 | sequence = [] 2690 | 2691 | # Generate initial state 2692 | current_state = np.random.choice(self.n_states, p=self.initial_distribution) 2693 | sequence.append(current_state + 1) # Convert to 1-indexed 2694 | 2695 | # Generate subsequent states 2696 | for _ in range(length - 1): 2697 | current_state = np.random.choice( 2698 | self.n_states, p=self.transition_matrix[current_state] 2699 | ) 2700 | sequence.append(current_state + 1) # Convert to 1-indexed 2701 | 2702 | return sequence 2703 | 2704 | def log_likelihood(self, sequences): 2705 | """ 2706 | Calculate the log-likelihood of sequences 2707 | 2708 | Parameters: 2709 | sequences: List of sequences 2710 | 2711 | Returns: 2712 | Log-likelihood value 2713 | """ 2714 | if self.transition_matrix is None or self.initial_distribution is None: 2715 | raise ValueError("Model not yet fitted, please call the fit method first") 2716 | 2717 | log_likelihood = 0.0 2718 | 2719 | for sequence in sequences: 2720 | if len(sequence) == 0: 2721 | continue 2722 | 2723 | # Log probability of initial state 2724 | log_likelihood += np.log(self.initial_distribution[sequence[0] - 1]) 2725 | 2726 | # Log probability of transitions 2727 | for i in range(len(sequence) - 1): 2728 | current_state = sequence[i] - 1 2729 | next_state = sequence[i + 1] - 1 2730 | log_likelihood += np.log( 2731 | self.transition_matrix[current_state, next_state] 2732 | ) 2733 | 2734 | return log_likelihood 2735 | 2736 | def print_model(self): 2737 | """Print model parameters""" 2738 | print("Initial state distribution:") 2739 | for i in range(self.n_states): 2740 | print(f"State {i+1}: {self.initial_distribution[i]:.4f}") 2741 | 2742 | print("\nTransition probability matrix:") 2743 | print(" | " + " ".join([f" {i+1} " for i in range(self.n_states)])) 2744 | print("----+" + "------" * self.n_states) 2745 | for i in range(self.n_states): 2746 | print( 2747 | f" {i+1} | " 2748 | + " ".join( 2749 | [f"{self.transition_matrix[i,j]:.4f}" for j in range(self.n_states)] 2750 | ) 2751 | ) 2752 | 2753 | print("Initial counts:") 2754 | print(self.initial_counts.astype(int)) 2755 | print("Transition counts:") 2756 | print(self.transition_counts.astype(int)) 2757 | -------------------------------------------------------------------------------- /src/fsrs_optimizer/fsrs_simulator.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from matplotlib import pyplot as plt 4 | from typing import Optional 5 | from concurrent.futures import ProcessPoolExecutor, as_completed 6 | 7 | 8 | DECAY = -0.1542 9 | FACTOR = 0.9 ** (1 / DECAY) - 1 10 | S_MIN = 0.001 11 | 12 | 13 | def power_forgetting_curve(t, s, decay=DECAY): 14 | factor = 0.9 ** (1 / decay) - 1 15 | return (1 + factor * t / s) ** decay 16 | 17 | 18 | def next_interval( 19 | s, r, float_ivl: bool = False, fuzz: bool = False, decay: float = DECAY 20 | ): 21 | factor = 0.9 ** (1 / decay) - 1 22 | ivl = s / factor * (r ** (1 / decay) - 1) 23 | if float_ivl: 24 | ivl = np.round(ivl, 6) 25 | else: 26 | ivl = np.maximum(1, np.round(ivl).astype(int)) 27 | if fuzz: 28 | fuzz_mask = ivl >= 3 29 | ivl[fuzz_mask] = fuzz_interval(ivl[fuzz_mask]) 30 | return ivl 31 | 32 | 33 | FUZZ_RANGES = [ 34 | { 35 | "start": 2.5, 36 | "end": 7.0, 37 | "factor": 0.15, 38 | }, 39 | { 40 | "start": 7.0, 41 | "end": 20.0, 42 | "factor": 0.1, 43 | }, 44 | { 45 | "start": 20.0, 46 | "end": math.inf, 47 | "factor": 0.05, 48 | }, 49 | ] 50 | 51 | 52 | def get_fuzz_range(interval): 53 | delta = np.ones_like(interval, dtype=float) 54 | for range in FUZZ_RANGES: 55 | delta += range["factor"] * np.maximum( 56 | np.minimum(interval, range["end"]) - range["start"], 0.0 57 | ) 58 | min_ivl = np.round(interval - delta).astype(int) 59 | max_ivl = np.round(interval + delta).astype(int) 60 | min_ivl = np.maximum(2, min_ivl) 61 | min_ivl = np.minimum(min_ivl, max_ivl) 62 | return min_ivl, max_ivl 63 | 64 | 65 | def fuzz_interval(interval): 66 | min_ivl, max_ivl = get_fuzz_range(interval) 67 | # max_ivl + 1 because randint upper bound is exclusive 68 | return np.random.randint(min_ivl, max_ivl + 1, size=min_ivl.shape) 69 | 70 | 71 | columns = [ 72 | "difficulty", 73 | "stability", 74 | "retrievability", 75 | "delta_t", 76 | "reps", 77 | "lapses", 78 | "last_date", 79 | "due", 80 | "ivl", 81 | "cost", 82 | "rand", 83 | "rating", 84 | "ease", 85 | ] 86 | col = {key: i for i, key in enumerate(columns)} 87 | 88 | DEFAULT_LEARN_COSTS = np.array([33.79, 24.3, 13.68, 6.5]) 89 | DEFAULT_REVIEW_COSTS = np.array([23.0, 11.68, 7.33, 5.6]) 90 | DEFAULT_FIRST_RATING_PROB = np.array([0.24, 0.094, 0.495, 0.171]) 91 | DEFAULT_REVIEW_RATING_PROB = np.array([0.224, 0.631, 0.145]) 92 | DEFAULT_LEARNING_STEP_COUNT = 2 93 | DEFAULT_RELEARNING_STEP_COUNT = 1 94 | DEFAULT_LEARNING_STEP_TRANSITIONS = np.array( 95 | [ 96 | [0.3687, 0.0628, 0.5108, 0.0577], 97 | [0.0441, 0.4553, 0.4457, 0.0549], 98 | [0.0518, 0.0470, 0.8462, 0.0550], 99 | ], 100 | ) 101 | DEFAULT_RELEARNING_STEP_TRANSITIONS = np.array( 102 | [ 103 | [0.2157, 0.0643, 0.6595, 0.0605], 104 | [0.0500, 0.4638, 0.4475, 0.0387], 105 | [0.1056, 0.1434, 0.7266, 0.0244], 106 | ], 107 | ) 108 | DEFAULT_STATE_RATING_COSTS = np.array( 109 | [ 110 | [12.75, 12.26, 8.0, 6.38], 111 | [13.05, 11.74, 7.42, 5.6], 112 | [10.56, 10.0, 7.37, 5.4], 113 | ] 114 | ) 115 | 116 | 117 | def simulate( 118 | w, 119 | request_retention=0.9, 120 | deck_size=10000, 121 | learn_span=365, 122 | max_cost_perday=1800, 123 | learn_limit_perday=math.inf, 124 | review_limit_perday=math.inf, 125 | max_ivl=36500, 126 | first_rating_prob=DEFAULT_FIRST_RATING_PROB, 127 | review_rating_prob=DEFAULT_REVIEW_RATING_PROB, 128 | learning_step_count=DEFAULT_LEARNING_STEP_COUNT, 129 | relearning_step_count=DEFAULT_RELEARNING_STEP_COUNT, 130 | learning_step_transitions=DEFAULT_LEARNING_STEP_TRANSITIONS, 131 | relearning_step_transitions=DEFAULT_RELEARNING_STEP_TRANSITIONS, 132 | state_rating_costs=DEFAULT_STATE_RATING_COSTS, 133 | seed=42, 134 | fuzz=False, 135 | scheduler_name="fsrs", 136 | ): 137 | np.random.seed(seed) 138 | card_table = np.zeros((len(columns), deck_size)) 139 | card_table[col["due"]] = learn_span 140 | card_table[col["difficulty"]] = 1e-10 141 | card_table[col["stability"]] = 1e-10 142 | card_table[col["rating"]] = np.random.choice( 143 | [1, 2, 3, 4], deck_size, p=first_rating_prob 144 | ) 145 | card_table[col["rating"]] = card_table[col["rating"]].astype(int) 146 | card_table[col["ease"]] = 0 147 | 148 | revlogs = {} 149 | review_cnt_per_day = np.zeros(learn_span) 150 | learn_cnt_per_day = np.zeros(learn_span) 151 | memorized_cnt_per_day = np.zeros(learn_span) 152 | cost_per_day = np.zeros(learn_span) 153 | 154 | # Anki scheduler constants 155 | GRADUATING_IVL = 1 156 | EASY_IVL = 4 157 | NEW_IVL = 0 158 | HARD_IVL = 1.2 159 | INTERVAL_MODIFIER = 1.0 160 | EASY_BONUS = 1.3 161 | MIN_IVL = 1 162 | 163 | def anki_sm2_scheduler(scheduled_interval, real_interval, ease, rating): 164 | # Handle new cards (ease == 0) 165 | is_new_card = ease == 0 166 | new_card_intervals = np.where(rating == 4, EASY_IVL, GRADUATING_IVL) 167 | new_card_eases = np.full_like(ease, 2.5) 168 | 169 | # Handle review cards 170 | delay = real_interval - scheduled_interval 171 | 172 | # Calculate intervals for each rating 173 | again_interval = np.minimum( 174 | np.maximum( 175 | np.round(scheduled_interval * NEW_IVL * INTERVAL_MODIFIER + 0.01), 176 | MIN_IVL, 177 | ), 178 | max_ivl, 179 | ) 180 | hard_interval = np.minimum( 181 | np.maximum( 182 | np.round(scheduled_interval * HARD_IVL * INTERVAL_MODIFIER + 0.01), 183 | np.maximum(scheduled_interval + 1, MIN_IVL), 184 | ), 185 | max_ivl, 186 | ) 187 | good_interval = np.minimum( 188 | np.maximum( 189 | np.round( 190 | (scheduled_interval + delay / 2) * ease * INTERVAL_MODIFIER + 0.01 191 | ), 192 | np.maximum(hard_interval + 1, MIN_IVL), 193 | ), 194 | max_ivl, 195 | ) 196 | easy_interval = np.minimum( 197 | np.maximum( 198 | np.round(real_interval * ease * INTERVAL_MODIFIER * EASY_BONUS + 0.01), 199 | np.maximum(good_interval + 1, MIN_IVL), 200 | ), 201 | max_ivl, 202 | ) 203 | 204 | # Select intervals based on rating 205 | review_intervals = np.choose( 206 | rating - 1, [again_interval, hard_interval, good_interval, easy_interval] 207 | ) 208 | 209 | # Calculate new eases 210 | review_eases = np.choose( 211 | rating - 1, 212 | [ 213 | ease - 0.2, 214 | ease - 0.15, 215 | ease, 216 | ease + 0.15, 217 | ], 218 | ) 219 | review_eases = np.maximum(review_eases, 1.3) 220 | 221 | # Combine new card and review card results 222 | intervals = np.where(is_new_card, new_card_intervals, review_intervals) 223 | eases = np.where(is_new_card, new_card_eases, review_eases) 224 | 225 | return intervals, eases 226 | 227 | def stability_after_success(s, r, d, rating): 228 | hard_penalty = np.where(rating == 2, w[15], 1) 229 | easy_bonus = np.where(rating == 4, w[16], 1) 230 | return np.maximum( 231 | S_MIN, 232 | s 233 | * ( 234 | 1 235 | + np.exp(w[8]) 236 | * (11 - d) 237 | * np.power(s, -w[9]) 238 | * (np.exp((1 - r) * w[10]) - 1) 239 | * hard_penalty 240 | * easy_bonus 241 | ), 242 | ) 243 | 244 | def stability_after_failure(s, r, d): 245 | return np.maximum( 246 | S_MIN, 247 | np.minimum( 248 | w[11] 249 | * np.power(d, -w[12]) 250 | * (np.power(s + 1, w[13]) - 1) 251 | * np.exp((1 - r) * w[14]), 252 | s / np.exp(w[17] * w[18]), 253 | ), 254 | ) 255 | 256 | MAX_RELEARN_STEPS = 5 257 | 258 | # learn_state: 1: Learning, 2: Review, 3: Relearning 259 | def memory_state_short_term( 260 | s: np.ndarray, d: np.ndarray, init_rating: Optional[np.ndarray] = None 261 | ): 262 | if init_rating is not None: 263 | s = np.choose(init_rating - 1, w) 264 | d = np.clip(init_d(init_rating), 1, 10) 265 | costs = state_rating_costs[0] 266 | max_consecutive = learning_step_count - np.choose( 267 | init_rating - 1, [0, 0, 1, 1] 268 | ) 269 | cost = np.choose(init_rating - 1, costs).sum() 270 | else: 271 | costs = state_rating_costs[2] 272 | max_consecutive = relearning_step_count 273 | cost = 0 274 | 275 | def step(s, next_weights): 276 | rating = np.random.choice([1, 2, 3, 4], p=next_weights) 277 | sinc = (math.e ** (w[17] * (rating - 3 + w[18]))) * (s ** -w[19]) 278 | new_s = s * (sinc.clip(min=1) if rating >= 3 else sinc) 279 | 280 | return (new_s, rating) 281 | 282 | def loop(s, d, max_consecutive, init_rating): 283 | i = 0 284 | consecutive = 0 285 | step_transitions = ( 286 | relearning_step_transitions 287 | if init_rating is None 288 | else learning_step_transitions 289 | ) 290 | rating = init_rating or 1 291 | cost = 0 292 | while ( 293 | i < MAX_RELEARN_STEPS and consecutive < max_consecutive and rating < 4 294 | ): 295 | (s, rating) = step(s, step_transitions[rating - 1]) 296 | d = next_d(d, rating) 297 | cost += costs[rating - 1] 298 | i += 1 299 | if rating > 2: 300 | consecutive += 1 301 | elif rating == 1: 302 | consecutive = 0 303 | 304 | return s, d, cost 305 | 306 | if len(s) != 0: 307 | new_s, new_d, cost = np.vectorize(loop, otypes=["float", "float", "float"])( 308 | s, d, max_consecutive, init_rating 309 | ) 310 | else: 311 | new_s, new_d, cost = [], [], [] 312 | 313 | return new_s, new_d, cost 314 | 315 | def init_d(rating): 316 | return w[4] - np.exp(w[5] * (rating - 1)) + 1 317 | 318 | def linear_damping(delta_d, old_d): 319 | return delta_d * (10 - old_d) / 9 320 | 321 | def next_d(d, rating): 322 | delta_d = -w[6] * (rating - 3) 323 | new_d = d + linear_damping(delta_d, d) 324 | new_d = mean_reversion(init_d(4), new_d) 325 | return np.clip(new_d, 1, 10) 326 | 327 | def mean_reversion(init, current): 328 | return w[7] * init + (1 - w[7]) * current 329 | 330 | for today in range(learn_span): 331 | new_s = np.copy(card_table[col["stability"]]) 332 | new_d = np.copy(card_table[col["difficulty"]]) 333 | 334 | has_learned = card_table[col["stability"]] > 1e-10 335 | card_table[col["delta_t"]][has_learned] = ( 336 | today - card_table[col["last_date"]][has_learned] 337 | ) 338 | card_table[col["retrievability"]][has_learned] = power_forgetting_curve( 339 | card_table[col["delta_t"]][has_learned], 340 | card_table[col["stability"]][has_learned], 341 | -w[20], 342 | ) 343 | card_table[col["cost"]] = 0 344 | need_review = card_table[col["due"]] <= today 345 | card_table[col["rand"]][need_review] = np.random.rand(np.sum(need_review)) 346 | forget = card_table[col["rand"]] > card_table[col["retrievability"]] 347 | card_table[col["rating"]][need_review & forget] = 1 348 | card_table[col["rating"]][need_review & ~forget] = np.random.choice( 349 | [2, 3, 4], np.sum(need_review & ~forget), p=review_rating_prob 350 | ) 351 | card_table[col["cost"]][need_review] = np.choose( 352 | card_table[col["rating"]][need_review].astype(int) - 1, 353 | state_rating_costs[1], 354 | ) 355 | true_review = need_review & (np.cumsum(need_review) <= review_limit_perday) 356 | card_table[col["last_date"]][true_review] = today 357 | 358 | card_table[col["lapses"]][true_review & forget] += 1 359 | card_table[col["reps"]][true_review & ~forget] += 1 360 | 361 | new_s[true_review & forget] = stability_after_failure( 362 | card_table[col["stability"]][true_review & forget], 363 | card_table[col["retrievability"]][true_review & forget], 364 | card_table[col["difficulty"]][true_review & forget], 365 | ) 366 | new_d[true_review & forget] = next_d( 367 | card_table[col["difficulty"]][true_review & forget], 368 | card_table[col["rating"]][true_review & forget], 369 | ) 370 | ( 371 | card_table[col["stability"]][true_review & forget], 372 | card_table[col["difficulty"]][true_review & forget], 373 | costs, 374 | ) = memory_state_short_term( 375 | new_s[true_review & forget], 376 | new_d[true_review & forget], 377 | ) 378 | new_s[true_review & ~forget] = stability_after_success( 379 | new_s[true_review & ~forget], 380 | card_table[col["retrievability"]][true_review & ~forget], 381 | new_d[true_review & ~forget], 382 | card_table[col["rating"]][true_review & ~forget], 383 | ) 384 | 385 | new_d[true_review & ~forget] = next_d( 386 | card_table[col["difficulty"]][true_review & ~forget], 387 | card_table[col["rating"]][true_review & ~forget], 388 | ) 389 | 390 | card_table[col["cost"]][true_review & forget] = [ 391 | a + b for a, b in zip(card_table[col["cost"]][true_review & forget], costs) 392 | ] 393 | 394 | need_learn = card_table[col["stability"]] == 1e-10 395 | card_table[col["cost"]][need_learn] = np.choose( 396 | card_table[col["rating"]][need_learn].astype(int) - 1, 397 | state_rating_costs[0], 398 | ) 399 | true_learn = need_learn & (np.cumsum(need_learn) <= learn_limit_perday) 400 | card_table[col["last_date"]][true_learn] = today 401 | new_s[true_learn] = np.choose( 402 | card_table[col["rating"]][true_learn].astype(int) - 1, w[:4] 403 | ) 404 | ( 405 | new_s[true_learn], 406 | new_d[true_learn], 407 | costs, 408 | ) = memory_state_short_term( 409 | new_s[true_learn], 410 | new_d[true_learn], 411 | init_rating=card_table[col["rating"]][true_learn].astype(int), 412 | ) 413 | 414 | card_table[col["cost"]][true_learn] = [ 415 | a + b for a, b in zip(card_table[col["cost"]][true_learn], costs) 416 | ] 417 | 418 | below_cost_limit = np.cumsum(card_table[col["cost"]]) <= max_cost_perday 419 | reviewed = (true_review | true_learn) & below_cost_limit 420 | 421 | card_table[col["stability"]][reviewed] = new_s[reviewed] 422 | card_table[col["difficulty"]][reviewed] = new_d[reviewed] 423 | 424 | if scheduler_name == "fsrs": 425 | card_table[col["ivl"]][reviewed] = np.clip( 426 | next_interval( 427 | card_table[col["stability"]][reviewed], 428 | request_retention, 429 | fuzz=fuzz, 430 | decay=-w[20], 431 | ), 432 | 1, 433 | max_ivl, 434 | ) 435 | card_table[col["due"]][reviewed] = today + card_table[col["ivl"]][reviewed] 436 | else: # anki scheduler 437 | scheduled_intervals = card_table[col["ivl"]][reviewed] 438 | eases = card_table[col["ease"]][reviewed] 439 | real_intervals = card_table[col["delta_t"]][reviewed] 440 | ratings = card_table[col["rating"]][reviewed].astype(int) 441 | 442 | delta_ts, new_eases = anki_sm2_scheduler( 443 | scheduled_intervals, real_intervals, eases, ratings 444 | ) 445 | card_table[col["ivl"]][reviewed] = delta_ts 446 | card_table[col["due"]][reviewed] = today + delta_ts 447 | card_table[col["ease"]][reviewed] = new_eases 448 | 449 | revlogs[today] = { 450 | "card_id": np.where(reviewed)[0], 451 | "rating": card_table[col["rating"]][reviewed], 452 | } 453 | 454 | has_learned = card_table[col["stability"]] > 1e-10 455 | card_table[col["delta_t"]][has_learned] = ( 456 | today - card_table[col["last_date"]][has_learned] 457 | ) 458 | card_table[col["retrievability"]][has_learned] = power_forgetting_curve( 459 | card_table[col["delta_t"]][has_learned], 460 | card_table[col["stability"]][has_learned], 461 | ) 462 | 463 | review_cnt_per_day[today] = np.sum(true_review & reviewed) 464 | learn_cnt_per_day[today] = np.sum(true_learn & reviewed) 465 | memorized_cnt_per_day[today] = card_table[col["retrievability"]].sum() 466 | cost_per_day[today] = card_table[col["cost"]][reviewed].sum() 467 | return ( 468 | card_table, 469 | review_cnt_per_day, 470 | learn_cnt_per_day, 471 | memorized_cnt_per_day, 472 | cost_per_day, 473 | revlogs, 474 | ) 475 | 476 | 477 | def optimal_retention(**kwargs): 478 | return brent(**kwargs) 479 | 480 | 481 | CMRR_TARGET_WORKLOAD_ONLY = True 482 | CMRR_TARGET_MEMORIZED_PER_WORKLOAD = False 483 | CMRR_TARGET_MEMORIZED_STABILITY_PER_WORKLOAD = "memorized_stability_per_workload" 484 | 485 | 486 | def run_simulation(args): 487 | target, kwargs = args 488 | 489 | (card_table, _, _, memorized_cnt_per_day, cost_per_day, _) = simulate(**kwargs) 490 | 491 | if target == CMRR_TARGET_WORKLOAD_ONLY: 492 | return np.mean(cost_per_day) 493 | if target == CMRR_TARGET_MEMORIZED_PER_WORKLOAD: 494 | return np.sum(cost_per_day) / memorized_cnt_per_day[-1] 495 | if target == CMRR_TARGET_MEMORIZED_STABILITY_PER_WORKLOAD: 496 | return np.sum(cost_per_day) / np.sum( 497 | np.max(card_table[col["stability"]], 0) * card_table[col["retrievability"]] 498 | ) 499 | 500 | 501 | def sample( 502 | r, 503 | w, 504 | deck_size=10000, 505 | learn_span=365, 506 | max_cost_perday=1800, 507 | learn_limit_perday=math.inf, 508 | review_limit_perday=math.inf, 509 | max_ivl=36500, 510 | first_rating_prob=DEFAULT_FIRST_RATING_PROB, 511 | review_rating_prob=DEFAULT_REVIEW_RATING_PROB, 512 | learning_step_transitions=DEFAULT_LEARNING_STEP_TRANSITIONS, 513 | relearning_step_transitions=DEFAULT_RELEARNING_STEP_TRANSITIONS, 514 | state_rating_costs=DEFAULT_STATE_RATING_COSTS, 515 | workload_only=CMRR_TARGET_MEMORIZED_PER_WORKLOAD, 516 | ): 517 | results = [] 518 | 519 | def best_sample_size(days_to_simulate): 520 | if days_to_simulate <= 30: 521 | return 45 522 | elif days_to_simulate >= 365: 523 | return 4 524 | else: 525 | a1, a2, a3 = 8.20e-07, 2.41e-03, 1.30e-02 526 | factor = a1 * np.power(days_to_simulate, 2) + a2 * days_to_simulate + a3 527 | default_sample_size = 4 528 | return int(default_sample_size / factor) 529 | 530 | SAMPLE_SIZE = best_sample_size(learn_span) 531 | 532 | with ProcessPoolExecutor() as executor: 533 | futures = [] 534 | for i in range(SAMPLE_SIZE): 535 | kwargs = { 536 | "w": w, 537 | "request_retention": r, 538 | "deck_size": deck_size, 539 | "learn_span": learn_span, 540 | "max_cost_perday": max_cost_perday, 541 | "learn_limit_perday": learn_limit_perday, 542 | "review_limit_perday": review_limit_perday, 543 | "max_ivl": max_ivl, 544 | "first_rating_prob": first_rating_prob, 545 | "review_rating_prob": review_rating_prob, 546 | "learning_step_transitions": learning_step_transitions, 547 | "relearning_step_transitions": relearning_step_transitions, 548 | "state_rating_costs": state_rating_costs, 549 | "seed": 42 + i, 550 | } 551 | futures.append(executor.submit(run_simulation, (workload_only, kwargs))) 552 | 553 | for future in as_completed(futures): 554 | results.append(future.result()) 555 | return np.mean(results) 556 | 557 | 558 | def brent(tol=0.01, maxiter=20, **kwargs): 559 | mintol = 1.0e-11 560 | cg = 0.3819660 561 | 562 | xb = 0.70 563 | fb = sample(xb, **kwargs) 564 | funccalls = 1 565 | 566 | ################################# 567 | # BEGIN 568 | ################################# 569 | x = w = v = xb 570 | fw = fv = fx = fb 571 | a = 0.70 572 | b = 0.95 573 | deltax = 0.0 574 | iter = 0 575 | 576 | while iter < maxiter: 577 | tol1 = tol * np.abs(x) + mintol 578 | tol2 = 2.0 * tol1 579 | xmid = 0.5 * (a + b) 580 | # check for convergence 581 | if np.abs(x - xmid) < (tol2 - 0.5 * (b - a)): 582 | break 583 | 584 | if np.abs(deltax) <= tol1: 585 | if x >= xmid: 586 | deltax = a - x # do a golden section step 587 | else: 588 | deltax = b - x 589 | rat = cg * deltax 590 | else: # do a parabolic step 591 | tmp1 = (x - w) * (fx - fv) 592 | tmp2 = (x - v) * (fx - fw) 593 | p = (x - v) * tmp2 - (x - w) * tmp1 594 | tmp2 = 2.0 * (tmp2 - tmp1) 595 | if tmp2 > 0.0: 596 | p = -p 597 | tmp2 = np.abs(tmp2) 598 | dx_temp = deltax 599 | deltax = rat 600 | # check parabolic fit 601 | if ( 602 | (p > tmp2 * (a - x)) 603 | and (p < tmp2 * (b - x)) 604 | and (np.abs(p) < np.abs(0.5 * tmp2 * dx_temp)) 605 | ): 606 | rat = p * 1.0 / tmp2 # if parabolic step is useful 607 | u = x + rat 608 | if (u - a) < tol2 or (b - u) < tol2: 609 | if xmid - x >= 0: 610 | rat = tol1 611 | else: 612 | rat = -tol1 613 | else: 614 | if x >= xmid: 615 | deltax = a - x # if it's not do a golden section step 616 | else: 617 | deltax = b - x 618 | rat = cg * deltax 619 | 620 | if np.abs(rat) < tol1: # update by at least tol1 621 | if rat >= 0: 622 | u = x + tol1 623 | else: 624 | u = x - tol1 625 | else: 626 | u = x + rat 627 | fu = sample(u, **kwargs) # calculate new output value 628 | funccalls += 1 629 | 630 | if fu > fx: # if it's bigger than current 631 | if u < x: 632 | a = u 633 | else: 634 | b = u 635 | if (fu <= fw) or (w == x): 636 | v = w 637 | w = u 638 | fv = fw 639 | fw = fu 640 | elif (fu <= fv) or (v == x) or (v == w): 641 | v = u 642 | fv = fu 643 | else: 644 | if u >= x: 645 | a = x 646 | else: 647 | b = x 648 | v = w 649 | w = x 650 | x = u 651 | fv = fw 652 | fw = fx 653 | fx = fu 654 | 655 | iter += 1 656 | ################################# 657 | # END 658 | ################################# 659 | 660 | xmin = x 661 | fval = fx 662 | 663 | success = ( 664 | iter < maxiter 665 | and not (np.isnan(fval) or np.isnan(xmin)) 666 | and (0.70 <= xmin <= 0.95) 667 | ) 668 | 669 | if success: 670 | return xmin 671 | else: 672 | raise Exception("The algorithm terminated without finding a valid value.") 673 | 674 | 675 | def workload_graph(default_params, sampling_size=30): 676 | R = np.linspace(0.7, 0.999, sampling_size).tolist() 677 | default_params["max_cost_perday"] = math.inf 678 | default_params["learn_limit_perday"] = int( 679 | default_params["deck_size"] / default_params["learn_span"] 680 | ) 681 | default_params["review_limit_perday"] = math.inf 682 | workload = [sample(r=r, workload_only=True, **default_params) for r in R] 683 | 684 | # this is for testing 685 | # workload = [min(x, 2.3 * min(workload)) for x in workload] 686 | min_w = min(workload) # minimum workload 687 | max_w = max(workload) # maximum workload 688 | min1_index = R.index(R[workload.index(min_w)]) 689 | 690 | min_w2 = 0 691 | min_w3 = 0 692 | target2 = 1.5 * min_w 693 | target3 = 2 * min_w 694 | 695 | for i in range(len(workload) - 1): 696 | if (workload[i] <= target2) and (workload[i + 1] >= target2): 697 | if abs(workload[i] - target2) < abs(workload[i + 1] - target2): 698 | min_w2 = workload[i] 699 | else: 700 | min_w2 = workload[i + 1] 701 | 702 | for i in range(len(workload) - 1): 703 | if (workload[i] <= target3) and (workload[i + 1] >= target3): 704 | if abs(workload[i] - target3) < abs(workload[i + 1] - target3): 705 | min_w3 = workload[i] 706 | else: 707 | min_w3 = workload[i + 1] 708 | 709 | if min_w2 == 0: 710 | min2_index = len(R) 711 | else: 712 | min2_index = R.index(R[workload.index(min_w2)]) 713 | 714 | min1_5_index = int(math.ceil((min2_index + 3 * min1_index) / 4)) 715 | if min_w3 == 0: 716 | min3_index = len(R) 717 | else: 718 | min3_index = R.index(R[workload.index(min_w3)]) 719 | 720 | fig = plt.figure(figsize=(16, 8)) 721 | ax = fig.gca() 722 | if min1_index > 0: 723 | ax.fill_between( 724 | x=R[: min1_index + 1], 725 | y1=0, 726 | y2=workload[: min1_index + 1], 727 | color="red", 728 | alpha=1, 729 | ) 730 | ax.fill_between( 731 | x=R[min1_index : min1_5_index + 1], 732 | y1=0, 733 | y2=workload[min1_index : min1_5_index + 1], 734 | color="gold", 735 | alpha=1, 736 | ) 737 | else: 738 | # handle the case when there is no red area to the left 739 | ax.fill_between( 740 | x=R[: min1_5_index + 1], 741 | y1=0, 742 | y2=workload[: min1_5_index + 1], 743 | color="gold", 744 | alpha=1, 745 | ) 746 | 747 | ax.fill_between( 748 | x=R[min1_5_index : min2_index + 1], 749 | y1=0, 750 | y2=workload[min1_5_index : min2_index + 1], 751 | color="limegreen", 752 | alpha=1, 753 | ) 754 | ax.fill_between( 755 | x=R[min2_index : min3_index + 1], 756 | y1=0, 757 | y2=workload[min2_index : min3_index + 1], 758 | color="gold", 759 | alpha=1, 760 | ) 761 | ax.fill_between( 762 | x=R[min3_index:], 763 | y1=0, 764 | y2=workload[min3_index:], 765 | color="red", 766 | alpha=1, 767 | ) 768 | ax.set_yticks([]) 769 | ax.set_xticks([0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]) 770 | ax.xaxis.set_tick_params(labelsize=14) 771 | ax.set_xlim(0.7, 0.99) 772 | 773 | if max_w >= 3.5 * min_w: 774 | lim = 3.5 * min_w 775 | elif max_w >= 3 * min_w: 776 | lim = 3 * min_w 777 | elif max_w >= 2.5 * min_w: 778 | lim = 2.5 * min_w 779 | elif max_w >= 2 * min_w: 780 | lim = 2 * min_w 781 | else: 782 | lim = 1.1 * max_w 783 | 784 | ax.set_ylim(0, lim) 785 | ax.set_ylabel("Workload (minutes of study per day)", fontsize=20) 786 | ax.set_xlabel("Desired Retention", fontsize=20) 787 | ax.axhline(y=min_w, color="black", alpha=0.75, ls="--") 788 | ax.text( 789 | 0.701, 790 | min_w, 791 | "minimum workload", 792 | ha="left", 793 | va="bottom", 794 | color="black", 795 | fontsize=12, 796 | ) 797 | if lim >= 1.8 * min_w: 798 | ax.axhline(y=1.5 * min_w, color="black", alpha=0.75, ls="--") 799 | ax.text( 800 | 0.701, 801 | 1.5 * min_w, 802 | "minimum workload x1.5", 803 | ha="left", 804 | va="bottom", 805 | color="black", 806 | fontsize=12, 807 | ) 808 | if lim >= 2.3 * min_w: 809 | ax.axhline(y=2 * min_w, color="black", alpha=0.75, ls="--") 810 | ax.text( 811 | 0.701, 812 | 2 * min_w, 813 | "minimum workload x2", 814 | ha="left", 815 | va="bottom", 816 | color="black", 817 | fontsize=12, 818 | ) 819 | if lim >= 2.8 * min_w: 820 | ax.axhline(y=2.5 * min_w, color="black", alpha=0.75, ls="--") 821 | ax.text( 822 | 0.701, 823 | 2.5 * min_w, 824 | "minimum workload x2.5", 825 | ha="left", 826 | va="bottom", 827 | color="black", 828 | fontsize=12, 829 | ) 830 | if lim >= 3.3 * min_w: 831 | ax.axhline(y=3 * min_w, color="black", alpha=0.75, ls="--") 832 | ax.text( 833 | 0.701, 834 | 3 * min_w, 835 | "minimum workload x3", 836 | ha="left", 837 | va="bottom", 838 | color="black", 839 | fontsize=12, 840 | ) 841 | fig.tight_layout(h_pad=0, w_pad=0) 842 | return fig 843 | 844 | 845 | if __name__ == "__main__": 846 | default_params = { 847 | "w": [ 848 | 0.2172, 849 | 1.1771, 850 | 3.2602, 851 | 16.1507, 852 | 7.0114, 853 | 0.57, 854 | 2.0966, 855 | 0.0069, 856 | 1.5261, 857 | 0.112, 858 | 1.0178, 859 | 1.849, 860 | 0.1133, 861 | 0.3127, 862 | 2.2934, 863 | 0.2191, 864 | 3.0004, 865 | 0.7536, 866 | 0.3332, 867 | 0.1437, 868 | 0.2, 869 | ], 870 | "deck_size": 20000, 871 | "learn_span": 365, 872 | "max_cost_perday": 1800, 873 | "learn_limit_perday": math.inf, 874 | "review_limit_perday": math.inf, 875 | "max_ivl": 36500, 876 | } 877 | 878 | schedulers = ["fsrs", "anki"] 879 | for scheduler_name in schedulers: 880 | ( 881 | _, 882 | review_cnt_per_day, 883 | learn_cnt_per_day, 884 | memorized_cnt_per_day, 885 | _, 886 | revlogs, 887 | ) = simulate( 888 | w=default_params["w"], 889 | max_cost_perday=math.inf, 890 | learn_limit_perday=10, 891 | review_limit_perday=50, 892 | scheduler_name=scheduler_name, 893 | ) 894 | 895 | def moving_average(data, window_size=365 // 20): 896 | weights = np.ones(window_size) / window_size 897 | return np.convolve(data, weights, mode="valid") 898 | 899 | plt.figure(1) 900 | plt.plot( 901 | moving_average(review_cnt_per_day), 902 | label=scheduler_name, 903 | ) 904 | plt.title("Review Count per Day") 905 | plt.legend() 906 | plt.grid(True) 907 | 908 | plt.figure(2) 909 | plt.plot( 910 | moving_average(learn_cnt_per_day), 911 | label=scheduler_name, 912 | ) 913 | plt.title("Learn Count per Day") 914 | plt.legend() 915 | plt.grid(True) 916 | 917 | plt.figure(3) 918 | plt.plot( 919 | np.cumsum(learn_cnt_per_day), 920 | label=scheduler_name, 921 | ) 922 | plt.title("Cumulative Learn Count") 923 | plt.legend() 924 | plt.grid(True) 925 | 926 | plt.figure(4) 927 | plt.plot( 928 | memorized_cnt_per_day, 929 | label=scheduler_name, 930 | ) 931 | plt.title("Memorized Count per Day") 932 | plt.legend() 933 | plt.grid(True) 934 | 935 | plt.figure(5) 936 | plt.plot( 937 | [ 938 | sum(rating > 1 for rating in day["rating"]) / len(day["rating"]) 939 | for _, day in sorted(revlogs.items(), key=lambda a: a[0]) 940 | ], 941 | label=scheduler_name, 942 | ) 943 | plt.ylim(0, 1) 944 | plt.title("True retention per day") 945 | plt.legend() 946 | plt.grid(True) 947 | 948 | plt.show() 949 | workload_graph(default_params, sampling_size=30).savefig("workload.png") 950 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-spaced-repetition/fsrs-optimizer/e0f48aec8f5c083de52c69af7ca1e32576f1ac07/tests/__init__.py -------------------------------------------------------------------------------- /tests/model_test.py: -------------------------------------------------------------------------------- 1 | from src.fsrs_optimizer import * 2 | 3 | 4 | class Test_Model: 5 | def test_next_stability(self): 6 | model = FSRS(DEFAULT_PARAMETER) 7 | stability = torch.tensor([5.0] * 4) 8 | difficulty = torch.tensor([1.0, 2.0, 3.0, 4.0]) 9 | retention = torch.tensor([0.9, 0.8, 0.7, 0.6]) 10 | rating = torch.tensor([1, 2, 3, 4]) 11 | state = torch.stack([stability, difficulty]).unsqueeze(0) 12 | s_recall = model.stability_after_success(state, retention, rating) 13 | assert torch.allclose( 14 | s_recall, 15 | torch.tensor([25.602541, 28.226582, 58.656002, 127.226685]), 16 | atol=1e-4, 17 | ) 18 | s_forget = model.stability_after_failure(state, retention) 19 | assert torch.allclose( 20 | s_forget, 21 | torch.tensor([1.0525396, 1.1894329, 1.3680838, 1.584989]), 22 | atol=1e-4, 23 | ) 24 | s_short_term = model.stability_short_term(state, rating) 25 | assert torch.allclose( 26 | s_short_term, 27 | torch.tensor([1.596818, 2.7470093, 5.0, 8.12961]), 28 | atol=1e-4, 29 | ) 30 | 31 | def test_next_difficulty(self): 32 | model = FSRS(DEFAULT_PARAMETER) 33 | stability = torch.tensor([5.0] * 4) 34 | difficulty = torch.tensor([5.0] * 4) 35 | rating = torch.tensor([1, 2, 3, 4]) 36 | state = torch.stack([stability, difficulty]).unsqueeze(0) 37 | d_recall = model.next_d(state, rating) 38 | assert torch.allclose( 39 | d_recall, 40 | torch.tensor([8.341763, 6.6659956, 4.990228, 3.3144615]), 41 | atol=1e-4, 42 | ) 43 | 44 | def test_power_forgetting_curve(self): 45 | delta_t = torch.tensor([0, 1, 2, 3, 4, 5]) 46 | stability = torch.tensor([1, 2, 3, 4, 4, 2]) 47 | retention = power_forgetting_curve(delta_t, stability) 48 | assert torch.allclose( 49 | retention, 50 | torch.tensor([1.0, 0.9403443, 0.9253786, 0.9185229, 0.9, 0.8261359]), 51 | atol=1e-4, 52 | ) 53 | 54 | def test_forward(self): 55 | model = FSRS(DEFAULT_PARAMETER) 56 | delta_ts = torch.tensor( 57 | [ 58 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 59 | [1.0, 1.0, 1.0, 1.0, 2.0, 2.0], 60 | ] 61 | ) 62 | ratings = torch.tensor( 63 | [ 64 | [1.0, 2.0, 3.0, 4.0, 1.0, 2.0], 65 | [1.0, 2.0, 3.0, 4.0, 1.0, 2.0], 66 | ] 67 | ) 68 | inputs = torch.stack([delta_ts, ratings], dim=2) 69 | _, state = model.forward(inputs) 70 | stability = state[:, 0] 71 | difficulty = state[:, 1] 72 | assert torch.allclose( 73 | stability, 74 | torch.tensor( 75 | [ 76 | 0.10088589, 77 | 3.2494123, 78 | 7.3153, 79 | 18.014914, 80 | 0.112798266, 81 | 4.4694576, 82 | ] 83 | ), 84 | atol=1e-4, 85 | ) 86 | assert torch.allclose( 87 | difficulty, 88 | torch.tensor([8.806304, 6.7404594, 2.1112142, 1.0, 8.806304, 6.7404594]), 89 | atol=1e-4, 90 | ) 91 | 92 | def test_loss_and_grad(self): 93 | model = FSRS(DEFAULT_PARAMETER) 94 | clipper = ParameterClipper() 95 | init_w = torch.tensor(DEFAULT_PARAMETER) 96 | params_stddev = DEFAULT_PARAMS_STDDEV_TENSOR 97 | optimizer = torch.optim.Adam(model.parameters(), lr=4e-2) 98 | loss_fn = nn.BCELoss(reduction="none") 99 | t_histories = torch.tensor( 100 | [ 101 | [0.0, 0.0, 0.0, 0.0], 102 | [0.0, 0.0, 0.0, 0.0], 103 | [0.0, 0.0, 0.0, 1.0], 104 | [0.0, 1.0, 1.0, 3.0], 105 | [1.0, 3.0, 3.0, 5.0], 106 | [3.0, 6.0, 6.0, 12.0], 107 | ] 108 | ) 109 | r_histories = torch.tensor( 110 | [ 111 | [1.0, 2.0, 3.0, 4.0], 112 | [3.0, 4.0, 2.0, 4.0], 113 | [1.0, 4.0, 4.0, 3.0], 114 | [4.0, 3.0, 3.0, 3.0], 115 | [3.0, 1.0, 3.0, 3.0], 116 | [2.0, 3.0, 3.0, 4.0], 117 | ] 118 | ) 119 | delta_ts = torch.tensor([4.0, 11.0, 12.0, 23.0]) 120 | labels = torch.tensor([1, 1, 1, 0], dtype=torch.float32, requires_grad=False) 121 | inputs = torch.stack([t_histories, r_histories], dim=2) 122 | seq_lens = inputs.shape[0] 123 | real_batch_size = inputs.shape[1] 124 | outputs, _ = model.forward(inputs) 125 | stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0] 126 | retentions = power_forgetting_curve(delta_ts, stabilities, -model.w[20]) 127 | loss = loss_fn(retentions, labels).sum() 128 | assert torch.allclose(loss, torch.tensor(4.047898), atol=1e-4) 129 | loss.backward() 130 | assert torch.allclose( 131 | model.w.grad, 132 | torch.tensor( 133 | [ 134 | -0.095688485, 135 | -0.0051607806, 136 | -0.00080300873, 137 | 0.007462064, 138 | 0.03677408, 139 | -0.084962785, 140 | 0.059571628, 141 | -2.1566951, 142 | 0.5738574, 143 | -2.8749206, 144 | 0.7123072, 145 | -0.028993709, 146 | 0.0099172965, 147 | -0.2189217, 148 | -0.0017800558, 149 | -0.089381434, 150 | 0.299141, 151 | 0.0708902, 152 | -0.01219162, 153 | -0.25424173, 154 | 0.27452517, 155 | ] 156 | ), 157 | atol=1e-4, 158 | ) 159 | optimizer.step() 160 | assert torch.allclose( 161 | model.w, 162 | torch.tensor( 163 | [ 164 | 0.252, 165 | 1.3331, 166 | 2.3464994, 167 | 8.2556, 168 | 6.3733, 169 | 0.87340003, 170 | 2.9794, 171 | 0.040999997, 172 | 1.8322, 173 | 0.20660001, 174 | 0.756, 175 | 1.5235, 176 | 0.021400042, 177 | 0.3029, 178 | 1.6882998, 179 | 0.64140004, 180 | 1.8329, 181 | 0.5025, 182 | 0.13119997, 183 | 0.1058, 184 | 0.1142, 185 | ] 186 | ), 187 | atol=1e-4, 188 | ) 189 | 190 | optimizer.zero_grad() 191 | penalty = ( 192 | torch.sum(torch.square(model.w - init_w) / torch.square(params_stddev)) 193 | * 512 194 | / 1000 195 | * 2 196 | ) 197 | assert torch.allclose(penalty, torch.tensor(0.6771115), atol=1e-4) 198 | penalty.backward() 199 | assert torch.allclose( 200 | model.w.grad, 201 | torch.tensor( 202 | [ 203 | 0.0019813816, 204 | 0.00087788026, 205 | 0.00026506148, 206 | -0.000105618295, 207 | -0.25213888, 208 | 1.0448985, 209 | -0.22755535, 210 | 5.688889, 211 | -0.5385926, 212 | 2.5283954, 213 | -0.75225013, 214 | 0.9102214, 215 | -10.113569, 216 | 3.1999993, 217 | 0.2521374, 218 | 1.3107208, 219 | -0.07721739, 220 | -0.85244584, 221 | 0.79999936, 222 | 4.1795917, 223 | -1.1237311, 224 | ] 225 | ), 226 | atol=1e-5, 227 | ) 228 | 229 | optimizer.zero_grad() 230 | t_histories = torch.tensor( 231 | [ 232 | [0.0, 0.0, 0.0, 0.0], 233 | [0.0, 0.0, 0.0, 0.0], 234 | [0.0, 0.0, 0.0, 1.0], 235 | [0.0, 1.0, 1.0, 3.0], 236 | [1.0, 3.0, 3.0, 5.0], 237 | [3.0, 6.0, 6.0, 12.0], 238 | ] 239 | ) 240 | r_histories = torch.tensor( 241 | [ 242 | [1.0, 2.0, 3.0, 4.0], 243 | [3.0, 4.0, 2.0, 4.0], 244 | [1.0, 4.0, 4.0, 3.0], 245 | [4.0, 3.0, 3.0, 3.0], 246 | [3.0, 1.0, 3.0, 3.0], 247 | [2.0, 3.0, 3.0, 4.0], 248 | ] 249 | ) 250 | delta_ts = torch.tensor([4.0, 11.0, 12.0, 23.0]) 251 | labels = torch.tensor([1, 1, 1, 0], dtype=torch.float32, requires_grad=False) 252 | inputs = torch.stack([t_histories, r_histories], dim=2) 253 | outputs, _ = model.forward(inputs) 254 | stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0] 255 | retentions = power_forgetting_curve(delta_ts, stabilities, -model.w[20]) 256 | loss = loss_fn(retentions, labels).sum() 257 | assert torch.allclose(loss, torch.tensor(3.76888), atol=1e-4) 258 | loss.backward() 259 | assert torch.allclose( 260 | model.w.grad, 261 | torch.tensor( 262 | [ 263 | -0.040530164, 264 | -0.0041278866, 265 | -0.0006833144, 266 | 0.007239434, 267 | 0.009416521, 268 | -0.12156768, 269 | 0.039193563, 270 | -0.86553144, 271 | 0.57743585, 272 | -2.571437, 273 | 0.76415884, 274 | -0.024242667, 275 | 0.0, 276 | -0.16912507, 277 | -0.0017008218, 278 | -0.061857328, 279 | 0.28093633, 280 | 0.06636292, 281 | 0.0057900245, 282 | -0.19041246, 283 | 0.6214733, 284 | ] 285 | ), 286 | atol=1e-4, 287 | ) 288 | optimizer.step() 289 | model.apply(clipper) 290 | assert torch.allclose( 291 | model.w, 292 | torch.tensor( 293 | [ 294 | 0.2882918, 295 | 1.3726242, 296 | 2.3862023, 297 | 8.215636, 298 | 6.339949, 299 | 0.9131501, 300 | 2.940647, 301 | 0.07696302, 302 | 1.7921939, 303 | 0.2464219, 304 | 0.71595156, 305 | 1.5631561, 306 | 0.001, 307 | 0.34230903, 308 | 1.7282416, 309 | 0.68038, 310 | 1.7929853, 311 | 0.46259063, 312 | 0.1426339, 313 | 0.14509763, 314 | 0.1, 315 | ] 316 | ), 317 | atol=1e-4, 318 | ) 319 | -------------------------------------------------------------------------------- /tests/simulator_test.py: -------------------------------------------------------------------------------- 1 | from src.fsrs_optimizer import * 2 | 3 | FSRS_RS_MEMORIZED = 5361.807 4 | 5 | 6 | class Test_Simulator: 7 | def test_simulate(self): 8 | ( 9 | card_table, 10 | review_cnt_per_day, 11 | learn_cnt_per_day, 12 | memorized_cnt_per_day, 13 | cost_per_day, 14 | revlogs, 15 | ) = simulate(w=DEFAULT_PARAMETER, request_retention=0.9) 16 | deviation = abs(1 - (memorized_cnt_per_day[-1] / FSRS_RS_MEMORIZED)) 17 | assert ( 18 | deviation < 0.06 19 | ), f"{memorized_cnt_per_day[-1]:.2f} is not within 5% of the expected {FSRS_RS_MEMORIZED:.2f} ({deviation:.2%} deviation)" 20 | 21 | def test_optimal_retention(self): 22 | default_params = { 23 | "w": DEFAULT_PARAMETER, 24 | "deck_size": 10000, 25 | "learn_span": 1000, 26 | "max_cost_perday": math.inf, 27 | "learn_limit_perday": 10, 28 | "review_limit_perday": math.inf, 29 | "max_ivl": 36500, 30 | } 31 | r = optimal_retention(**default_params) 32 | assert r == 0.7 33 | --------------------------------------------------------------------------------