├── tests ├── __init__.py └── test_derivative_calculator.py ├── src └── derivative_calculator │ ├── __init__.py │ ├── py.typed │ ├── interface.py │ ├── tokenizer.py │ ├── math_parser.py │ ├── interpreter.py │ ├── utils.py │ └── symb_diff_tool.py ├── setup.py ├── requirements.txt ├── tox.ini ├── pyproject.toml ├── .github └── workflows │ └── tests.yml ├── setup.cfg ├── README.md └── .gitignore /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/derivative_calculator/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/derivative_calculator/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == "__main__":  # all metadata and options are declared in setup.cfg 4 | setup() 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs==22.2.0 2 | cachetools==5.3.0 3 | chardet==5.1.0 4 | colorama==0.4.6 5 | coverage==7.1.0 6 | distlib==0.3.6 7 | exceptiongroup==1.1.0 8 | filelock==3.9.0 9 | flake8==6.0.0 10 | iniconfig==2.0.0 11 | mccabe==0.7.0 12 | mypy==0.991 13 | mypy-extensions==0.4.3 14 | packaging==23.0 15 | platformdirs==2.6.2 16 | pluggy==1.0.0 17 | pycodestyle==2.10.0 18 | pyflakes==3.0.1 19 | pyproject_api==1.5.0 20 | pytest==7.2.1 21 | pytest-cov==4.0.0 22 | tomli==2.0.1 23 | tox==4.4.2 24 | typing==3.7.4.3 25 | typing_extensions==4.4.0 26 | virtualenv==20.17.1 27 | -------------------------------------------------------------------------------- 
/tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | minversion = 4.3.5 3 | envlist = py310, flake8 4 | isolated_build = true 5 | 6 | [gh-actions] 7 | python = 8 | 3.10: py310, flake8 9 | 10 | [testenv] 11 | setenv = 12 | PYTHONPATH = {toxinidir} 13 | deps = 14 | -r{toxinidir}/requirements.txt 15 | commands = 16 | pytest --basetemp={envtmpdir} 17 | 18 | [testenv:flake8] 19 | basepython = python3.10 20 | deps = flake8 21 | commands = flake8 src tests 22 | 23 | [testenv:mypy] 24 | basepython = python3.10 25 | deps = 26 | -r{toxinidir}/requirements.txt 27 | commands = mypy src 28 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.pytest.ini_options] 6 | addopts = "--cov=derivative_calculator" 7 | testpaths = [ 8 | "tests", 9 | ] 10 | 11 | [tool.mypy] 12 | mypy_path = "src" 13 | check_untyped_defs = true 14 | disallow_any_generics = true 15 | ignore_missing_imports = true 16 | no_implicit_optional = true 17 | show_error_codes = true 18 | strict_equality = true 19 | warn_redundant_casts = true 20 | warn_return_any = true 21 | warn_unreachable = true 22 | warn_unused_configs = true 23 | no_implicit_reexport = true -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | test: 9 | runs-on: ${{ matrix.os }} 10 | strategy: 11 | matrix: 12 | os: [ubuntu-latest, windows-latest] 13 | python-version: ['3.10'] 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: 
${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install tox tox-gh-actions 25 | - name: Test with tox 26 | run: tox -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = derivative-calculator 3 | description = derivative calculator 4 | author = Santiago Cuenca 5 | platforms = unix, linux, osx, cygwin, win32 6 | classifiers = 7 | Programming Language :: Python :: 3 8 | Programming Language :: Python :: 3 :: Only 9 | Programming Language :: Python :: 3.10 10 | 11 | [options] 12 | packages = 13 | derivative_calculator 14 | python_requires = >=3.10 15 | package_dir = 16 | =src 17 | zip_safe = no 18 | 19 | [options.extras_require] 20 | testing = 21 | pytest>=7.2.1 22 | pytest-cov>=4.0.0 23 | mypy>=0.991 24 | flake8>=6.0.0 25 | tox>=4.4.2 26 | 27 | [options.package_data] 28 | derivative_calculator = py.typed 29 | 30 | [flake8] 31 | max-line-length = 92 32 | -------------------------------------------------------------------------------- /src/derivative_calculator/interface.py: -------------------------------------------------------------------------------- 1 | import os 2 | from derivative_calculator.tokenizer import Token, Tokenizer, VAR 3 | from derivative_calculator.math_parser import Var, Parser 4 | from derivative_calculator.symb_diff_tool import deriv 5 | from derivative_calculator.interpreter import Interpreter 6 | 7 | 8 | def main(): 9 | 10 | print() 11 | print(' ------------------------- ') 12 | print(' - Derivative calculator - ') 13 | print(' ------------------------- ') 14 | print() 15 | print(' - Supported functions are: exp, log, sin, cos, tan, cosec, sec, cot.') 16 | print(' - Powers are represented by a double asterisk (**).') 17 | print(' - Valid variable inputs are single alphabet letters.') 18 | print() 19 | 20 | while True: 21 | 
try: 22 | expr = input('Enter mathematical function: ') 23 | if not isinstance(expr, str): 24 | raise ValueError 25 | break 26 | except ValueError: 27 | print("Invalid input: the mathematical expression must be in string format.") 28 | 29 | while True: 30 | try: 31 | var = input('Derivate respect to: ') 32 | if not (var.isalpha and len(var) == 1): 33 | raise ValueError 34 | break 35 | except ValueError: 36 | print("Invalid input: the variable must be a single alphabet letter.") 37 | 38 | expr_ast = Parser(Tokenizer(expr)).parse() # AST representing function 39 | token_var = Var(Token(VAR, var)) # Token object containing variable 40 | 41 | deriv_ast = deriv(expr_ast, token_var) 42 | inter = Interpreter() 43 | deriv_output = inter.visit(deriv_ast) 44 | 45 | print('Derivative: ', deriv_output) 46 | 47 | restart = input("Would you like to restart this program? (y/n): ") 48 | if restart == "y": 49 | os.system('cls||clear') 50 | main() 51 | if restart == "n": 52 | print("Good bye.") 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # derivative-calculator 2 | Symbolic differentiation tool that includes its own parser 3 | 4 | ![Tests](https://github.com/scfenton6/derivative-calculator/actions/workflows/tests.yml/badge.svg) 5 | 6 | ## Description 7 | 8 | This repository contains a symbolic differentiation tool written in python that includes: 9 | 10 | - It's own LL(1) parser to convert a string representing a mathematical function into an Abstract Syntax Tree. 11 | 12 | - A tool that performs symbolic differentiation on the AST applying the chain rule recursively to produce an AST containing the function's derivative. 13 | 14 | - An interpreter that takes an AST representing a mathematical function and returns the function in string format. 
15 | 16 | Both the symbolic differentiation tool and the interpreter simplify binary arithmetic operations between numbers, multiplication and division by one, and addition, subtraction and multiplication by zero. Additionally, the interpreter simplifies expressions containing an arbitrary number of prefix signs. 17 | 18 | ## Motivation 19 | 20 | I started this project to learn about parsing. To this end, I followed the series [Let's Build A Simple Interpreter](https://ruslanspivak.com/lsbasi-part1), which explains in detail, with accompanying code snippets, how to build a parser and an interpreter from scratch; the code for my parser is mainly based on it. 21 | 22 | To put the parser to use, I wrote a tool that performs symbolic differentiation by walking the Abstract Syntax Tree produced by the parser with a depth-first search (DFS). The notation and general structure of the symbolic differentiation tool are based on Subsection 2.3.2 of the book [Structure and Interpretation of Computer Programs, 2nd ed.](https://web.mit.edu/6.001/6.037/sicp.pdf). 23 | 24 | ## How to use 25 | 26 | To use the derivative calculator, install the package (for example with `pip install -e .`, since the modules import each other as `derivative_calculator.*`) and run `python -m derivative_calculator.interface`. Enter a string containing a mathematical function, followed by a single letter naming the variable you want to differentiate with respect to. After you press Enter, the program outputs the derivative of the function you entered.
27 | 28 | ![Alt Text](https://media2.giphy.com/media/91R0PMpB1X0kP2VBwf/giphy.gif?cid=790b7611f817a0bcf269b594fb1a937519db11ef889d6062&rid=giphy.gif&ct=g) 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ -------------------------------------------------------------------------------- /src/derivative_calculator/tokenizer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Tokenizer implementation. 3 | Tokenizer.token_iter is the main program. 
4 | ''' 5 | 6 | import typing 7 | 8 | 9 | INTEGER, VAR, PLUS, MINUS, MUL, DIV, POW, FUNC, LPAREN, RPAREN, EOF = ( 10 | 'INTEGER', 'VAR', 'PLUS', 'MINUS', 'MUL', 'DIV', 'POW', 'FUNC', '(', ')', 'EOF' 11 | ) 12 | valid_functions = ( 13 | 'exp', 'log', 'sin', 'cos', 'tan', 'cosec', 'sec', 'cot' 14 | ) 15 | 16 | 17 | class Token: 18 | def __init__(self, type: str, value: typing.Any) -> None: 19 | self.type = type 20 | self.value = value 21 | 22 | 23 | class Tokenizer: 24 | def __init__(self, text: str) -> None: 25 | self.text = text 26 | self.pos = 0 27 | self.current_char: typing.Optional[str] = self.text[self.pos] 28 | 29 | def error(self) -> None: 30 | raise Exception('Invalid character') 31 | 32 | def advance(self) -> None: 33 | """Advance pos pointer and set current_char variable.""" 34 | self.pos += 1 35 | if self.pos > len(self.text) - 1: 36 | self.current_char = None 37 | else: 38 | self.current_char = self.text[self.pos] 39 | 40 | def skip_whitespace(self) -> None: 41 | while self.current_char is not None and self.current_char.isspace(): 42 | self.advance() 43 | 44 | def handle_integer(self) -> int: 45 | """Return a (multidigit) integer consumed from the input.""" 46 | result = '' 47 | while self.current_char is not None and self.current_char.isdigit(): 48 | result += self.current_char 49 | self.advance() 50 | return int(result) 51 | 52 | def handle_alpha_seq(self) -> tuple[str, str]: # type: ignore[return] 53 | """Determine whether or not our sequence of alpha 54 | characters is a variable or a valid function.""" 55 | result = '' 56 | while self.current_char is not None and self.current_char.isalpha(): 57 | result += self.current_char 58 | self.advance() 59 | 60 | if len(result) == 1: 61 | return VAR, result 62 | 63 | if result in valid_functions: 64 | return FUNC, result 65 | 66 | self.error() 67 | 68 | def handle_asterisk(self) -> tuple[str, str]: # type: ignore[return] 69 | """Determine whether or not our asterisk is followed 70 | by another asterisk. 
In the former case it would 71 | represent the product operator, and in the latter it 72 | would represent the power operator""" 73 | result = '' 74 | while self.current_char is not None and self.current_char == '*': 75 | result += self.current_char 76 | self.advance() 77 | 78 | if len(result) == 1: 79 | return MUL, result 80 | 81 | if len(result) == 2: 82 | return POW, result 83 | 84 | self.error() 85 | 86 | def get_next_token(self) -> Token: 87 | """ 88 | Tokenizer: breaks a sentence apart into tokens one token at a time. 89 | """ 90 | while self.current_char is not None: 91 | 92 | if self.current_char.isspace(): 93 | self.skip_whitespace() 94 | continue 95 | 96 | if self.current_char.isdigit(): 97 | return Token(INTEGER, self.handle_integer()) 98 | 99 | if self.current_char.isalpha(): 100 | op, val = self.handle_alpha_seq() 101 | return Token(op, val) 102 | 103 | if self.current_char == '+': 104 | self.advance() 105 | return Token(PLUS, '+') 106 | 107 | if self.current_char == '-': 108 | self.advance() 109 | return Token(MINUS, '-') 110 | 111 | if self.current_char == '*': 112 | op, val = self.handle_asterisk() 113 | return Token(op, val) 114 | 115 | if self.current_char == '/': 116 | self.advance() 117 | return Token(DIV, '/') 118 | 119 | if self.current_char == '(': 120 | self.advance() 121 | return Token(LPAREN, '(') 122 | 123 | if self.current_char == ')': 124 | self.advance() 125 | return Token(RPAREN, ')') 126 | 127 | self.error() 128 | 129 | return Token(EOF, None) 130 | -------------------------------------------------------------------------------- /src/derivative_calculator/math_parser.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Parser for mathematical functions. 3 | Main program is Parser.parse. 
4 | ''' 5 | 6 | import typing 7 | from derivative_calculator.tokenizer import ( 8 | Token, 9 | Tokenizer, 10 | INTEGER, 11 | VAR, 12 | PLUS, 13 | MINUS, 14 | MUL, 15 | DIV, 16 | POW, 17 | FUNC, 18 | LPAREN, 19 | RPAREN, 20 | ) 21 | 22 | Node = typing.Union['UnaryOp', 'BinOp', 'Num', 'Var'] 23 | 24 | 25 | class UnaryOp: 26 | '''AST node representing a unary operation''' 27 | def __init__(self, op: Token, expr: Node) -> None: 28 | self.token = self.op = op 29 | self.value = self.token.value 30 | self.expr = expr 31 | 32 | 33 | class BinOp: 34 | '''AST node representing a binary operation''' 35 | def __init__(self, left: Node, op: Token, right: Node) -> None: 36 | self.left = left 37 | self.token = self.op = op 38 | self.right = right 39 | 40 | 41 | class Num: 42 | '''AST node representing a number''' 43 | def __init__(self, token: Token) -> None: 44 | self.token = token 45 | self.value: int = token.value 46 | 47 | 48 | class Var: 49 | '''AST node representing a variable''' 50 | def __init__(self, token: Token) -> None: 51 | self.token = token 52 | self.value: str = token.value 53 | 54 | 55 | class Parser: 56 | ''' 57 | Parser for mathematical functions. 
58 | Expression grammars are ordered by priority level: 59 | add_substr_expr: mul_div_expr ((PLUS | MINUS) mul_div_expr)* 60 | mul_div_expr: pow_expr ((MUL | DIV) pow_expr)* 61 | pow_expr: factor (POW factor)* 62 | factor : (PLUS | MINUS | FUNC) factor | INTEGER | LPAREN add_substr_expr RPAREN 63 | ''' 64 | def __init__(self, tokenizer: Tokenizer) -> None: 65 | self.tokenizer = tokenizer 66 | self.current_token = self.tokenizer.get_next_token() 67 | 68 | def error(self) -> None: 69 | raise Exception('Invalid syntax') 70 | 71 | def eat(self, token_type: str) -> None: 72 | if self.current_token.type == token_type: 73 | self.current_token = self.tokenizer.get_next_token() 74 | else: 75 | self.error() 76 | 77 | def factor(self) -> Node: # type: ignore[return] 78 | """ 79 | factor : (PLUS | MINUS | FUNC) factor | INTEGER | LPAREN add_substr_expr RPAREN 80 | """ 81 | token: Token = self.current_token 82 | 83 | if token.type == INTEGER: 84 | self.eat(INTEGER) 85 | return Num(token) 86 | 87 | elif token.type == VAR: 88 | self.eat(VAR) 89 | return Var(token) 90 | 91 | if token.type == PLUS: 92 | self.eat(PLUS) 93 | plus_node = UnaryOp(token, self.factor()) 94 | return plus_node 95 | 96 | elif token.type == MINUS: 97 | self.eat(MINUS) 98 | minus_node = UnaryOp(token, self.factor()) 99 | return minus_node 100 | 101 | elif token.type == FUNC: 102 | self.eat(FUNC) 103 | func_node = UnaryOp(token, self.factor()) 104 | return func_node 105 | 106 | elif token.type == LPAREN: 107 | self.eat(LPAREN) 108 | paren_node: Node = self.add_substr_expr() 109 | self.eat(RPAREN) 110 | return paren_node 111 | 112 | def pow_expr(self) -> Node: 113 | '''pow_expr: factor (POW factor)*''' 114 | fact_node: Node = self.factor() 115 | 116 | while self.current_token.type == POW: 117 | token = self.current_token 118 | self.eat(POW) 119 | fact_node = BinOp(left=fact_node, op=token, right=self.factor()) 120 | 121 | return fact_node 122 | 123 | def mul_div_expr(self) -> Node: 124 | ''' 125 | mul_div_expr: 
pow_expr ((MUL | DIV) pow_expr)* 126 | ''' 127 | pow_node: Node = self.pow_expr() 128 | 129 | while self.current_token.type in (MUL, DIV): 130 | token = self.current_token 131 | if token.type == MUL: 132 | self.eat(MUL) 133 | elif token.type == DIV: 134 | self.eat(DIV) 135 | 136 | pow_node = BinOp(left=pow_node, op=token, right=self.pow_expr()) 137 | 138 | return pow_node 139 | 140 | def add_substr_expr(self) -> Node: 141 | ''' 142 | add_substr_expr: mul_div_expr ((PLUS | MINUS) mul_div_expr)* 143 | ''' 144 | mul_node: Node = self.mul_div_expr() 145 | 146 | while self.current_token.type in (PLUS, MINUS): 147 | token = self.current_token 148 | if token.type == PLUS: 149 | self.eat(PLUS) 150 | elif token.type == MINUS: 151 | self.eat(MINUS) 152 | 153 | mul_node = BinOp(left=mul_node, op=token, right=self.mul_div_expr()) 154 | 155 | return mul_node 156 | 157 | def parse(self) -> Node: 158 | return self.add_substr_expr() 159 | -------------------------------------------------------------------------------- /src/derivative_calculator/interpreter.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Interpreter that visits an abstract syntax tree representing a 3 | mathematical function and returns the function in string format. 4 | Main program is Interpreter.interpret. 
5 | ''' 6 | 7 | import typing 8 | from derivative_calculator.tokenizer import PLUS, MINUS, MUL, DIV, POW, FUNC 9 | import derivative_calculator.utils as utils 10 | from derivative_calculator.math_parser import UnaryOp, BinOp, Num, Var 11 | 12 | Node = typing.Union[UnaryOp, BinOp, Num, Var] 13 | 14 | 15 | class NodeVisitor: 16 | def visit(self, node: Node) -> str: 17 | method_name = 'visit_' + type(node).__name__ 18 | visitor = getattr(self, method_name) 19 | return visitor(node) 20 | 21 | 22 | class Interpreter(NodeVisitor): 23 | def __init__(self) -> None: 24 | self.prec = {PLUS: 1, MINUS: 1, MUL: 2, DIV: 2, POW: 3, FUNC: 4} 25 | 26 | def binOpHelper(self, node: BinOp, op: str, prec: int) -> str: 27 | ''' 28 | Adds parentheses to binary operation if needed 29 | ''' 30 | result = '' 31 | left, right = node.left, node.right 32 | 33 | if (utils.is_rational_number(left) or 34 | (isinstance(left, (UnaryOp, BinOp)) and self.prec[left.op.type] < prec)): 35 | result += r'(%s)%s' % (self.visit(left), op) 36 | else: 37 | result += r'%s%s' % (self.visit(left), op) 38 | if (utils.is_rational_number(right) or 39 | (isinstance(right, (UnaryOp, BinOp)) and self.prec[right.op.type] < prec)): 40 | result += r'(%s)' % self.visit(right) 41 | else: 42 | result += r'%s' % self.visit(right) 43 | return result 44 | 45 | def visit_BinOp(self, node: BinOp) -> str: # type: ignore[return] 46 | ''' 47 | Tool for visiting a binary operation. 48 | We try to simplify the expression if 49 | possible, and if not we call the pertaining 50 | helper function. 
51 | ''' 52 | left, right = node.left, node.right 53 | 54 | # simplify left and right nodes in case they are prefix sign expressions 55 | if utils.is_prefix_sign(left): 56 | left = utils.simplifyPrefixSign(left) 57 | 58 | if utils.is_prefix_sign(right): 59 | right = utils.simplifyPrefixSign(right) 60 | 61 | if utils.is_sum(node): 62 | if isinstance(left, Num) and isinstance(right, Num): 63 | return str(left.value + right.value) 64 | if utils.is_zero(left): 65 | return str(self.visit(right)) 66 | if utils.is_zero(right): 67 | return str(self.visit(left)) 68 | return self.binOpHelper(node, '+', 1) 69 | 70 | if utils.is_substr(node): 71 | if isinstance(left, Num) and isinstance(right, Num): 72 | return str(left.value - right.value) 73 | if utils.is_zero(left): 74 | return r'-(%s)' % self.visit(right) 75 | if utils.is_zero(right): 76 | return str(self.visit(left)) 77 | return self.binOpHelper(node, '-', 1) 78 | 79 | elif utils.is_prod(node): 80 | if isinstance(left, Num) and isinstance(right, Num): 81 | return str(left.value * right.value) 82 | if utils.is_zero(left) or utils.is_zero(right): 83 | return '0' 84 | if utils.is_one(left): 85 | return str(self.visit(right)) 86 | if utils.is_one(right): 87 | return str(self.visit(left)) 88 | return self.binOpHelper(node, '*', 2) 89 | 90 | elif utils.is_div(node): 91 | if utils.is_zero(left): 92 | return '0' 93 | if utils.is_one(right): 94 | return str(self.visit(left)) 95 | return self.binOpHelper(node, '/', 4) 96 | 97 | elif utils.is_pow(node): 98 | if isinstance(left, Num) and isinstance(right, Num): 99 | return str(left.value ** right.value) 100 | if (utils.is_one(left) or utils.is_zero(right)): 101 | return '1' 102 | if utils.is_zero(left): 103 | return '0' 104 | if utils.is_one(right): 105 | return str(self.visit(left)) 106 | return self.binOpHelper(node, '**', 3) 107 | 108 | def visit_UnaryOp(self, node: UnaryOp) -> str: # type: ignore[return] 109 | if utils.is_prefix_sign(node): 110 | simpl_node: UnaryOp | Num = 
utils.simplifyPrefixSign(node) 111 | if isinstance(simpl_node, Num): 112 | return str(self.visit(simpl_node)) 113 | else: 114 | sign = simpl_node.op.value 115 | if isinstance(simpl_node, BinOp): 116 | return r'%s(%s)' % (sign, self.visit(simpl_node.expr)) 117 | else: 118 | return r'%s%s' % (sign, self.visit(simpl_node.expr)) 119 | if utils.is_func(node): 120 | return r'%s(%s)' % (node.op.value, self.visit(node.expr)) 121 | 122 | def visit_Num(self, node: Num) -> str: # type: ignore[return] 123 | return str(node.value) 124 | 125 | def visit_Var(self, node: Var) -> str: # type: ignore[return] 126 | return str(node.value) 127 | 128 | def interpret(self) -> str: 129 | tree = self.parser.parse() 130 | if tree is None: 131 | return '' 132 | return self.visit(tree) 133 | -------------------------------------------------------------------------------- /tests/test_derivative_calculator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import typing 3 | from derivative_calculator.math_parser import Var, Num, BinOp, UnaryOp, Parser 4 | from derivative_calculator.interpreter import Interpreter 5 | from derivative_calculator.symb_diff_tool import deriv 6 | from derivative_calculator.tokenizer import ( 7 | Token, 8 | Tokenizer, 9 | INTEGER, 10 | VAR, 11 | PLUS, 12 | MUL, 13 | DIV, 14 | POW, 15 | FUNC, 16 | ) 17 | 18 | Node = typing.Union[UnaryOp, BinOp, Num, Var] 19 | 20 | 21 | def get_parsed_expr(expr: str) -> Node: 22 | ''' 23 | Takes a string containing an expression 24 | and returns the AST tree corresponding to the 25 | parsed expression 26 | ''' 27 | return Parser(Tokenizer(expr)).parse() 28 | 29 | 30 | def interpret_ast(ast_expr: Node) -> str: 31 | ''' 32 | Takes an AST tree and returns the string 33 | containing its interpreted expression 34 | ''' 35 | return Interpreter().visit(ast_expr) 36 | 37 | 38 | def get_derivative(expr: str, var: str) -> str: 39 | ''' 40 | Takes an math function and a variable and returns a 
string 41 | containing the interpreted expression of its derivative 42 | ''' 43 | deriv_ast: Node = deriv(get_parsed_expr(expr), Var(Token(VAR, var))) 44 | return interpret_ast(deriv_ast) 45 | 46 | 47 | def compare_ast(node_1: Node, node_2: Node) -> bool: # type: ignore[return] 48 | ''' 49 | Function that compares two abstract syntax trees 50 | node for node, returning True if they are equal 51 | and False if not 52 | ''' 53 | if not (node_1 or node_2): 54 | return True 55 | 56 | if (node_1 and not node_2) or (not node_1 and node_2): 57 | return False 58 | 59 | if type(node_1) is not type(node_2): 60 | return False 61 | 62 | if ((isinstance(node_1, Num) and isinstance(node_2, Num)) or 63 | (isinstance(node_1, Var) and isinstance(node_2, Var))): 64 | if node_1.value == node_2.value: 65 | return True 66 | return False 67 | 68 | if isinstance(node_1, BinOp) and isinstance(node_2, BinOp): 69 | if node_1.op.type == node_2.op.type: 70 | return (compare_ast(node_1.left, node_2.left) and 71 | compare_ast(node_1.right, node_2.right)) 72 | return False 73 | 74 | if isinstance(node_1, UnaryOp) and isinstance(node_2, UnaryOp): 75 | if node_1.op.type == node_2.op.type: 76 | return compare_ast(node_1.expr, node_2.expr) 77 | return False 78 | 79 | 80 | node_1 = BinOp( # node representing the function 3*x**2+5 81 | BinOp( 82 | left=Num(Token(INTEGER, 3)), 83 | op=Token(MUL, '*'), 84 | right=BinOp( 85 | left=Var(Token(VAR, 'x')), 86 | op=Token(POW, '**'), 87 | right=Num(Token(INTEGER, 2)) 88 | ) 89 | ), 90 | op=Token(PLUS, '+'), 91 | right=Num(Token(INTEGER, 5)) 92 | ) 93 | 94 | node_2 = BinOp( # node representing the function x**(1/2)*y 95 | BinOp( 96 | left=Var(Token(VAR, 'x')), 97 | op=Token(POW, '**'), 98 | right=BinOp( 99 | left=Num(Token(INTEGER, 1)), 100 | op=Token(DIV, '/'), 101 | right=Num(Token(INTEGER, 2)) 102 | ) 103 | ), 104 | op=Token(MUL, '*'), 105 | right=Var(Token(VAR, 'y')) 106 | ) 107 | 108 | node_3 = BinOp( # node representing the function log(x**2)/(x+y) 109 | 
UnaryOp( 110 | op=Token(FUNC, 'log'), 111 | expr=BinOp( 112 | left=Var(Token(VAR, 'x')), 113 | op=Token(POW, '**'), 114 | right=Num(Token(INTEGER, 2)) 115 | ) 116 | ), 117 | op=Token(DIV, '/'), 118 | right=BinOp( 119 | left=Var(Token(VAR, 'x')), 120 | op=Token(PLUS, '+'), 121 | right=Var(Token(VAR, 'y')), 122 | ) 123 | ) 124 | 125 | 126 | def test_parser() -> None: 127 | ''' 128 | We get the return values resulting from calling our parser with 129 | different inputs and we compare them with our manually built ASTs 130 | containing the expected results 131 | ''' 132 | assert compare_ast(get_parsed_expr('3*x**2+5'), node_1) 133 | assert compare_ast(get_parsed_expr('x**(1/2)*y'), node_2) 134 | assert compare_ast(get_parsed_expr('log(x**2)/(x+y)'), node_3) 135 | 136 | 137 | def test_interpreter() -> None: 138 | ''' 139 | Here we pass our manually built ASTs to the interpreter and 140 | compare the string outputs with the expected results 141 | ''' 142 | assert interpret_ast(node_1) == '3*x**2+5' 143 | assert interpret_ast(node_2) == 'x**(1/2)*y' 144 | assert interpret_ast(node_3) == 'log(x**2)/(x+y)' 145 | 146 | 147 | @pytest.mark.parametrize("test_input,expected", [ 148 | (get_derivative('x', 'x'), '1'), 149 | (get_derivative('x', 'y'), '0'), 150 | (get_derivative('x**5', 'x'), '5*x**4'), 151 | (get_derivative('3*2*y', 'y'), '6'), 152 | (get_derivative('x*y', 'y'), 'x'), 153 | (get_derivative('sin(x)', 'x'), 'cos(x)'), 154 | (get_derivative('x**(1/2)', 'x'), '(1/2)*x**((1/2)-1)'), 155 | (get_derivative('exp(x)', 'x'), 'exp(x)'), 156 | (get_derivative('log(x**2)', 'x'), '1/(x**2)*2*x'), 157 | (get_derivative('(1+x)*3**x', 'x'), '(1+x)*3**x*log(3)+3**x') 158 | ]) 159 | def test_derivatives(test_input: str, expected: str) -> None: 160 | ''' 161 | We test our whole program at once, calling the derivative 162 | function with the parsed AST tree for each pair math function/ 163 | variable and comparing the interpreted result with the expected 164 | result 165 | ''' 166 | 
assert test_input == expected 167 | -------------------------------------------------------------------------------- /src/derivative_calculator/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Collection of utils used in the symbolic 3 | differentiation tool and in the interpreter. 4 | ''' 5 | 6 | import typing 7 | from derivative_calculator.math_parser import Num, Var, UnaryOp, BinOp 8 | from derivative_calculator.tokenizer import Token, INTEGER, PLUS, MINUS, MUL, DIV, POW, FUNC 9 | 10 | Node = typing.Union[UnaryOp, BinOp, Num, Var] 11 | 12 | 13 | def binOp_type(node: Node, OP: str) -> bool: 14 | return isinstance(node, BinOp) and node.op.type == OP 15 | 16 | 17 | def is_number(node: Node) -> bool: 18 | return isinstance(node, Num) 19 | 20 | 21 | def is_zero(node: Node) -> bool: 22 | return isinstance(node, Num) and node.value == 0 23 | 24 | 25 | def is_one(node: Node) -> bool: 26 | return isinstance(node, Num) and node.value == 1 27 | 28 | 29 | def is_var(node: Node) -> bool: 30 | return isinstance(node, Var) 31 | 32 | 33 | def same_var(x: Node, y: Node) -> bool: 34 | return (isinstance(x, Var) 35 | and isinstance(y, Var)) and x.value == y.value 36 | 37 | 38 | def is_prefix_sign(node: Node) -> bool: 39 | return isinstance(node, UnaryOp) and node.op.type in (PLUS, MINUS) 40 | 41 | 42 | def is_func(node: Node) -> bool: 43 | return isinstance(node, UnaryOp) and node.op.type == FUNC 44 | 45 | 46 | def is_sum(node: Node) -> bool: 47 | return isinstance(node, BinOp) and node.op.type == PLUS 48 | 49 | 50 | def is_substr(node: Node) -> bool: 51 | return isinstance(node, BinOp) and node.op.type == MINUS 52 | 53 | 54 | def is_prod(node: Node) -> bool: 55 | return isinstance(node, BinOp) and node.op.type == MUL 56 | 57 | 58 | def is_div(node: Node) -> bool: 59 | return isinstance(node, BinOp) and node.op.type == DIV 60 | 61 | 62 | def is_pow(node: Node) -> bool: 63 | return isinstance(node, BinOp) and node.op.type == POW 
64 | 65 | 66 | def is_rational_number(node: BinOp) -> bool: 67 | return is_div(node) and (is_number(node.left) and is_number(node.right)) 68 | 69 | 70 | def simplifyPrefixSign(node: UnaryOp) -> typing.Union[Num, UnaryOp]: 71 | plus_token = Token(PLUS, '+') 72 | minus_token = Token(MINUS, '-') 73 | prefixes: tuple[str, str] = (plus_token.value, minus_token.value) 74 | minus_counter = 0 75 | while is_prefix_sign(node): 76 | if node.op.type == MINUS: 77 | minus_counter += 1 78 | node = node.expr 79 | # at this point we know that node cannot be prefix sign expression 80 | sign = prefixes[minus_counter % 2] 81 | if isinstance(node, Num): 82 | prefix_sign = -1 if sign == '-' else 1 83 | token_value: int = prefix_sign * node.value 84 | return Num(Token(INTEGER, token_value)) 85 | else: 86 | curr_token: Token = minus_token if sign == '-' else plus_token 87 | return UnaryOp(curr_token, node) 88 | 89 | 90 | def make_sum(x: Node, y: Node) -> Node: 91 | if isinstance(x, UnaryOp) and x.op.type in (PLUS, MINUS): 92 | x = simplifyPrefixSign(x) 93 | 94 | if isinstance(y, UnaryOp) and y.op.type in (PLUS, MINUS): 95 | y = simplifyPrefixSign(y) 96 | 97 | if isinstance(x, Num) and isinstance(y, Num): 98 | return Num(Token(INTEGER, x.value + y.value)) 99 | 100 | if isinstance(x, Num) and x.value == 0: 101 | return y 102 | 103 | if isinstance(y, Num) and y.value == 0: 104 | return x 105 | return BinOp(x, Token(PLUS, '+'), y) 106 | 107 | 108 | def make_substr(x: Node, y: Node) -> Node: 109 | if isinstance(x, UnaryOp) and x.op.type in (PLUS, MINUS): 110 | x = simplifyPrefixSign(x) 111 | 112 | if isinstance(y, UnaryOp) and y.op.type in (PLUS, MINUS): 113 | y = simplifyPrefixSign(y) 114 | 115 | if isinstance(x, Num) and isinstance(y, Num): 116 | return Num(Token(INTEGER, x.value - y.value)) 117 | 118 | if isinstance(x, Num) and x.value == 0: 119 | return UnaryOp(op=Token(MINUS, '-'), expr=y) 120 | 121 | if isinstance(y, Num) and y.value == 0: 122 | return x 123 | 124 | return BinOp(x, 
Token(MINUS, '-'), y) 125 | 126 | 127 | def make_prod(x: Node, y: Node) -> Node: 128 | if isinstance(x, UnaryOp) and x.op.type in (PLUS, MINUS): 129 | x = simplifyPrefixSign(x) 130 | 131 | if isinstance(y, UnaryOp) and y.op.type in (PLUS, MINUS): 132 | y = simplifyPrefixSign(y) 133 | 134 | if isinstance(x, Num) and isinstance(y, Num): 135 | return Num(Token(INTEGER, x.value * y.value)) 136 | 137 | if (isinstance(x, Num) and x.value == 0 or 138 | isinstance(y, Num) and y.value == 0): 139 | return Num(Token(INTEGER, 0)) 140 | 141 | if isinstance(x, Num) and x.value == 1: 142 | return y 143 | 144 | if isinstance(y, Num) and y.value == 0: 145 | return x 146 | 147 | return BinOp(x, Token(MUL, '*'), y) 148 | 149 | 150 | def make_div(x: Node, y: Node) -> Node: 151 | if isinstance(x, UnaryOp) and x.op.type in (PLUS, MINUS): 152 | x = simplifyPrefixSign(x) 153 | 154 | if isinstance(y, UnaryOp) and y.op.type in (PLUS, MINUS): 155 | y = simplifyPrefixSign(y) 156 | 157 | if isinstance(y, Num) and y.value == 0: 158 | raise Exception('Error: division by zero') 159 | 160 | if isinstance(x, Num) and x.value == 0: 161 | return Num(Token(INTEGER, 0)) 162 | 163 | if isinstance(y, Num) and y.value == 1: 164 | return x 165 | 166 | return BinOp(x, Token(DIV, '/'), y) 167 | 168 | 169 | def make_power(x: Node, y: Node) -> Node: 170 | if isinstance(x, UnaryOp) and x.op.type in (PLUS, MINUS): 171 | x = simplifyPrefixSign(x) 172 | 173 | if isinstance(y, UnaryOp) and y.op.type in (PLUS, MINUS): 174 | y = simplifyPrefixSign(y) 175 | 176 | if isinstance(x, Num) and isinstance(y, Num): 177 | return Num(Token(INTEGER, x.value ** y.value)) 178 | 179 | if (isinstance(x, Num) and x.value == 1 or 180 | isinstance(y, Num) and y.value == 0): 181 | return Num(Token(INTEGER, 1)) 182 | 183 | if isinstance(x, Num) and x.value == 0: 184 | return Num(Token(INTEGER, 0)) 185 | 186 | if isinstance(y, Num) and y.value == 1: 187 | return x 188 | 189 | return BinOp(x, Token(POW, '**'), y) 190 | 191 | 192 | def 
make_func(func: str, arg: Node) -> UnaryOp: 193 | return UnaryOp(Token(FUNC, func), arg) 194 | -------------------------------------------------------------------------------- /src/derivative_calculator/symb_diff_tool.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Symbolic differentiation tool that takes an abstract syntax 3 | tree containing a math function as input and applies the 4 | derivative rules and the chain rule on it via depth first search 5 | to return an abstract syntax tree containing the derivative of the 6 | function. 7 | 8 | Derivative rules: https://en.wikipedia.org/wiki/Differentiation_rules 9 | Chain rule: https://en.wikipedia.org/wiki/Chain_rule 10 | ''' 11 | 12 | import typing 13 | from derivative_calculator.math_parser import Num, Var, UnaryOp, BinOp 14 | from derivative_calculator.tokenizer import Token, INTEGER, MINUS 15 | import derivative_calculator.utils as utils 16 | 17 | Node = typing.Union[UnaryOp, BinOp, Num, Var] 18 | 19 | 20 | def deriv(node: Node, var: Var) -> Node: 21 | if utils.is_number(node): 22 | return Num(Token(INTEGER, 0)) 23 | 24 | if utils.is_var(node): 25 | if utils.same_var(node, var): 26 | return Num(Token(INTEGER, 1)) 27 | return Num(Token(INTEGER, 0)) 28 | 29 | if isinstance(node, UnaryOp): 30 | if utils.is_prefix_sign(node): 31 | simpl_node: UnaryOp | Num = utils.simplifyPrefixSign(node) 32 | if isinstance(simpl_node, Num): 33 | return Num(Token(INTEGER, 0)) 34 | else: 35 | return UnaryOp(simpl_node.token, deriv(simpl_node.expr, var)) 36 | 37 | if utils.is_func(node): 38 | if node.value == 'exp': 39 | return utils.make_prod( 40 | node, 41 | deriv(node.expr, var) 42 | ) 43 | 44 | if node.value == 'log': 45 | return utils.make_prod( 46 | utils.make_div( 47 | Num(Token(INTEGER, 1)), 48 | node.expr 49 | ), 50 | deriv(node.expr, var) 51 | ) 52 | 53 | if node.value == 'sin': 54 | return utils.make_prod( 55 | utils.make_func( 56 | func='cos', 57 | arg=node.expr 58 | ), 59 
| deriv(node.expr, var) 60 | ) 61 | 62 | if node.value == 'cos': 63 | return utils.make_prod( 64 | UnaryOp( 65 | Token(MINUS, '-'), 66 | utils.make_func( 67 | func='sin', 68 | arg=node.expr 69 | ) 70 | ), 71 | deriv(node.expr, var) 72 | ) 73 | 74 | if node.value == 'tan': 75 | return utils.make_prod( 76 | utils.make_power( 77 | utils.make_func( 78 | func='sec', 79 | arg=node.expr 80 | ), 81 | Num(Token(INTEGER, 2)) 82 | ), 83 | deriv(node.expr, var) 84 | ) 85 | 86 | if node.value == 'cosec': 87 | return utils.make_prod( 88 | utils.make_prod( 89 | UnaryOp( 90 | op=Token(MINUS, '-'), 91 | expr=node 92 | ), 93 | utils.make_func( 94 | func='cot', 95 | arg=node.expr 96 | ) 97 | ), 98 | deriv(node.expr, var) 99 | ) 100 | 101 | if node.value == 'sec': 102 | return utils.make_prod( 103 | utils.make_prod( 104 | node, 105 | utils.make_func( 106 | func='tan', 107 | arg=node.expr 108 | ) 109 | ), 110 | deriv(node.expr, var) 111 | ) 112 | 113 | if node.value == 'cot': 114 | return utils.make_prod( 115 | UnaryOp( 116 | op=Token(MINUS, '-'), 117 | expr=utils.make_power( 118 | utils.make_func( 119 | func='cosec', 120 | arg=node.expr 121 | ), 122 | Num(Token(INTEGER, 2)) 123 | ) 124 | ), 125 | deriv(node.expr, var) 126 | ) 127 | 128 | if isinstance(node, BinOp): 129 | if utils.is_sum(node): 130 | return utils.make_sum( 131 | deriv(node.left, var), 132 | deriv(node.right, var) 133 | ) 134 | 135 | if utils.is_substr(node): 136 | return utils.make_substr( 137 | deriv(node.left, var), 138 | deriv(node.right, var) 139 | ) 140 | 141 | if utils.is_prod(node): 142 | return utils.make_sum( 143 | utils.make_prod( 144 | node.left, 145 | deriv(node.right, var) 146 | ), 147 | utils.make_prod( 148 | deriv(node.left, var), 149 | node.right 150 | ) 151 | ) 152 | 153 | if utils.is_div(node): 154 | return utils.make_div( 155 | utils.make_substr( 156 | utils.make_prod( 157 | node.right, 158 | deriv(node.left, var) 159 | ), 160 | utils.make_prod( 161 | node.left, 162 | deriv(node.right, var) 163 | ) 
164 | ), 165 | utils.make_power( 166 | node.right, 167 | Num(Token(INTEGER, 2)) 168 | ) 169 | ) 170 | 171 | if utils.is_pow(node): 172 | base: Node = node.left 173 | exponent: Node = node.right 174 | 175 | return utils.make_sum( 176 | utils.make_prod( 177 | utils.make_prod( 178 | exponent, 179 | utils.make_power( 180 | base, 181 | utils.make_substr( 182 | exponent, 183 | Num(Token(INTEGER, 1)) 184 | ) 185 | ) 186 | ), 187 | deriv(base, var) 188 | ), 189 | utils.make_prod( 190 | utils.make_prod( 191 | node, 192 | utils.make_func( 193 | 'log', 194 | base 195 | ) 196 | ), 197 | deriv(exponent, var) 198 | ) 199 | ) 200 | 201 | raise Exception('Could not find any tokens matching input') 202 | --------------------------------------------------------------------------------