├── LICENSE ├── README.md └── vm.py /LICENSE: -------------------------------------------------------------------------------- 1 | Public domain. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the code from the blog post "Making a simple VM interpreter in Python", 2 | which you can find at https://csl.name/post/vm/ 3 | 4 | Made by Christian Stigen Larsen, with some improvements from the people at r/Python. 5 | Put in the public domain. 6 | 7 | To run, start the REPL with `python vm.py` (or `python vm.py repl`): 8 | 9 | Hit CTRL+D or type "exit" to quit. 10 | > 2 3 + 5 * println 11 | Constant-folded (2 + 3) to 5 12 | Constant-folded (5 * 5) to 25 13 | 25 14 | > ^D 15 | 16 | To test: 17 | 18 | $ python vm.py test 19 | Code before optimization: [2, 3, '+', 5, '*', 'println'] 20 | Constant-folded (2 + 3) to 5 21 | Constant-folded (5 * 5) to 25 22 | Code after optimization: [25, 'println'] 23 | Stack after running original program: 24 | 25 25 | Data stack (top first): 26 | Stack after running optimized program: 27 | 25 28 | Data stack (top first): 29 | Result: OK 30 | ** Program 1: Runs the code for `print((2+3)*4)` 31 | 20 32 | 33 | ** Program 2: Ask for numbers, computes sum and product. 34 | Enter a number: 12 35 | Enter another number: 13 36 | Their sum is: 25 37 | Their product is: 156 38 | 39 | ** Program 3: Shows branching and looping (use CTRL+D to exit). 40 | Enter a number: 1 41 | The number 1 is odd. 42 | Enter a number: 2 43 | The number 2 is even. 44 | Enter a number: 3 45 | The number 3 is odd. 46 | Enter a number: ^D 47 | -------------------------------------------------------------------------------- /vm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | """ 5 | A simple VM interpreter.
"""
Code from the post at http://csl.name/post/vm/
This version should work on both Python 2 and 3.
"""

from __future__ import print_function
from collections import deque
from io import StringIO
import sys
import tokenize

try:
    # Python 2: tokens produced by parse() are ``unicode``, not ``str``.
    _STRING_TYPES = (basestring,)
except NameError:
    # Python 3
    _STRING_TYPES = (str,)

# Arithmetic opcodes that constant_fold() is allowed to evaluate ahead of time.
_FOLDABLE_OPS = frozenset(["+", "-", "*", "/"])

# Layout-only tokens emitted by the Python tokenizer that carry no meaning for
# the VM language.  NEWLINE in particular is emitted implicitly at the end of
# input by newer Python versions even when the text has no trailing newline.
_IGNORED_TOKENS = frozenset([
    tokenize.NEWLINE,
    tokenize.NL,
    tokenize.COMMENT,
    tokenize.INDENT,
    tokenize.DEDENT,
])

# Errors that a bad program (or bad input line) can raise; the REPL reports
# these and keeps running instead of dying.
_REPL_ERRORS = (
    RuntimeError,
    IndexError,
    ValueError,
    TypeError,
    ZeroDivisionError,
    SyntaxError,
    tokenize.TokenError,
)


def get_input(*args, **kw):
    """Read a string from standard input.

    All arguments are forwarded to ``raw_input`` on Python 2 and to ``input``
    on Python 3.  Raises ``EOFError`` when standard input is exhausted.
    """
    if sys.version_info[0] == 2:
        return raw_input(*args, **kw)
    return input(*args, **kw)


class Stack(deque):
    """A stack built on ``deque``; ``push`` is an alias for ``append``."""

    push = deque.append

    def top(self):
        """Return the top element without removing it.

        Raises ``IndexError`` if the stack is empty.
        """
        return self[-1]


class Machine:
    """A tiny stack machine that executes a sequence of opcodes.

    ``code`` is an indexable sequence whose items are either integers (pushed
    on the data stack), quoted strings such as ``'"hello"'`` (pushed without
    their quotes) or one of the opcode names listed in ``_OPCODES``.
    """

    # Opcode -> name of the method implementing it.  Methods are looked up by
    # name at dispatch time so the table is built only once for the class.
    _OPCODES = {
        "%": "mod",
        "*": "mul",
        "+": "plus",
        "-": "minus",
        "/": "div",
        "==": "eq",
        "cast_int": "cast_int",
        "cast_str": "cast_str",
        "drop": "drop",
        "dup": "dup",
        "exit": "exit",
        "if": "if_stmt",
        "jmp": "jmp",
        "over": "over",
        "print": "print",
        "println": "println",
        "read": "read",
        "stack": "dump_stack",
        "swap": "swap",
    }

    def __init__(self, code):
        self.data_stack = Stack()
        self.return_stack = Stack()
        self.instruction_pointer = 0
        self.code = code

    def pop(self):
        """Remove and return the top of the data stack (IndexError if empty)."""
        return self.data_stack.pop()

    def push(self, value):
        """Push ``value`` on the data stack."""
        self.data_stack.push(value)

    def top(self):
        """Return the top of the data stack without removing it."""
        return self.data_stack.top()

    def run(self):
        """Execute opcodes until the instruction pointer runs off the code."""
        while self.instruction_pointer < len(self.code):
            opcode = self.code[self.instruction_pointer]
            # Advance first so that ``jmp`` can simply overwrite the pointer.
            self.instruction_pointer += 1
            self.dispatch(opcode)

    def dispatch(self, op):
        """Execute a single opcode.

        Raises ``RuntimeError`` for anything that is not a known opcode, an
        integer or a quoted string.
        """
        handler = self._OPCODES.get(op)
        if handler is not None:
            getattr(self, handler)()
        elif isinstance(op, int):
            self.push(op)  # push numbers on stack
        elif (isinstance(op, _STRING_TYPES) and len(op) >= 2
                and op[0] == op[-1] and op[0] in ('"', "'")):
            # Push quoted strings on stack, minus the quotes.  The length
            # check keeps a lone quote character from passing as "".
            self.push(op[1:-1])
        else:
            raise RuntimeError("Unknown opcode: '%s'" % op)

    # OPERATIONS FOLLOW:

    def plus(self):
        # Pop the right operand first so that non-commutative additions
        # (string concatenation) keep their written order: "a" "b" + -> "ab".
        last = self.pop()
        self.push(self.pop() + last)

    def exit(self):
        sys.exit(0)

    def minus(self):
        last = self.pop()
        self.push(self.pop() - last)

    def mul(self):
        last = self.pop()
        self.push(self.pop() * last)

    def div(self):
        last = self.pop()
        self.push(self.pop() / last)

    def mod(self):
        last = self.pop()
        self.push(self.pop() % last)

    def dup(self):
        self.push(self.top())

    def over(self):
        # ( a b -- a b a )
        b = self.pop()
        a = self.pop()
        self.push(a)
        self.push(b)
        self.push(a)

    def drop(self):
        self.pop()

    def swap(self):
        # ( a b -- b a )
        b = self.pop()
        a = self.pop()
        self.push(b)
        self.push(a)

    def print(self):
        sys.stdout.write(str(self.pop()))
        sys.stdout.flush()

    def println(self):
        sys.stdout.write("%s\n" % self.pop())
        sys.stdout.flush()

    def read(self):
        self.push(get_input())

    def cast_int(self):
        self.push(int(self.pop()))

    def cast_str(self):
        self.push(str(self.pop()))

    def eq(self):
        self.push(self.pop() == self.pop())

    def if_stmt(self):
        # ( test true_clause false_clause -- chosen_clause )
        false_clause = self.pop()
        true_clause = self.pop()
        test = self.pop()
        self.push(true_clause if test else false_clause)

    def jmp(self):
        """Jump to the absolute code index found on top of the stack."""
        addr = self.pop()
        if isinstance(addr, int) and 0 <= addr < len(self.code):
            self.instruction_pointer = addr
        else:
            raise RuntimeError("JMP address must be a valid integer.")

    def dump_stack(self):
        """Print the data stack, top element first, without modifying it."""
        print("Data stack (top first):")

        for v in reversed(self.data_stack):
            print(" - type %s, value '%s'" % (type(v), v))


def parse(text):
    """Yield the opcodes found in the source string ``text``.

    Numbers are yielded as ``int``; operators, names and string literals are
    yielded as their source text.  Raises ``RuntimeError`` for token kinds
    the VM language has no use for and ``ValueError`` for number literals
    that ``int()`` cannot convert.
    """
    # Note that the tokenizer module is intended for parsing Python source
    # code, so if you're going to expand on the parser, you may have to use
    # another tokenizer.

    if sys.version_info[0] == 2:
        stream = StringIO(unicode(text))
    else:
        stream = StringIO(text)

    for toknum, tokval, _, _, _ in tokenize.generate_tokens(stream.readline):
        if toknum == tokenize.NUMBER:
            yield int(tokval)
        elif toknum in (tokenize.OP, tokenize.STRING, tokenize.NAME):
            yield tokval
        elif toknum == tokenize.ENDMARKER:
            break
        elif toknum in _IGNORED_TOKENS:
            # Line structure means nothing here; without this branch every
            # input fails on the implicit end-of-line NEWLINE token.
            continue
        else:
            raise RuntimeError("Unknown token %s: '%s'" %
                               (tokenize.tok_name[toknum], tokval))


def constant_fold(code):
    """Constant-folds simple mathematical expressions like 2 3 + to 5.

    ``code`` is updated in place when anything was folded, and returned.
    A division by zero is left in the program so that it fails when the
    program runs rather than while it is being optimized.

    NOTE: folding shortens the program, and ``jmp`` addresses are absolute
    indices into it; they are not adjusted here.
    """
    folded = []
    for item in code:
        if (len(folded) >= 2
                and isinstance(folded[-2], int)
                and isinstance(folded[-1], int)
                and item in _FOLDABLE_OPS):
            a, b = folded[-2], folded[-1]
            # Evaluate with a real Machine so folding can never disagree
            # with what the interpreter would have computed.
            machine = Machine((a, b, item))
            try:
                machine.run()
            except ZeroDivisionError:
                folded.append(item)
                continue
            result = machine.top()
            folded[-2:] = [result]
            print("Constant-folded (%s %s %s) to %s" % (a, item, b, result))
        else:
            folded.append(item)

    if len(folded) != len(code):
        code[:] = folded
    return code


def repl():
    """Read, optimize and run one line of VM code at a time.

    Program errors are reported and the loop continues; ``EOFError`` from the
    input prompt is not handled here and propagates to the caller.
    """
    print('Hit CTRL+D or type "exit" to quit.')

    while True:
        try:
            source = get_input("> ")
            code = list(parse(source))
            code = constant_fold(code)
            Machine(code).run()
        except _REPL_ERRORS as e:
            print("%s: %s" % (type(e).__name__, e))
        except KeyboardInterrupt:
            print("\nKeyboardInterrupt")


def test(code=None):
    """Check that constant folding does not change a program's result.

    Runs ``code`` (a default sample program when omitted) both as given and
    after ``constant_fold``, prints a report and returns True when both runs
    leave identical data stacks.  The caller's list is not modified.
    """
    if code is None:
        code = [2, 3, "+", 5, "*", "println"]

    # constant_fold() works in place, so fold a copy; otherwise the
    # "original" program below would already be the optimized one.
    original = list(code)
    print("Code before optimization: %s" % str(original))
    optimized = constant_fold(list(original))
    print("Code after optimization: %s" % str(optimized))

    print("Stack after running original program:")
    a = Machine(original)
    a.run()
    a.dump_stack()

    print("Stack after running optimized program:")
    b = Machine(optimized)
    b.run()
    b.dump_stack()

    result = a.data_stack == b.data_stack
    print("Result: %s" % ("OK" if result else "FAIL"))
    return result


# This is a command-line self-check, not a unit test; keep pytest from
# collecting it when the module's names are star-imported into a test file.
test.__test__ = False


def examples():
    """Run the three demonstration programs (the last two read from stdin)."""
    print("** Program 1: Runs the code for `print((2+3)*4)`")
    Machine([2, 3, "+", 4, "*", "println"]).run()

    print("\n** Program 2: Ask for numbers, computes sum and product.")
    Machine([
        '"Enter a number: "', "print", "read", "cast_int",
        '"Enter another number: "', "print", "read", "cast_int",
        "over", "over",
        '"Their sum is: "', "print", "+", "println",
        '"Their product is: "', "print", "*", "println"
    ]).run()

    print("\n** Program 3: Shows branching and looping (use CTRL+D to exit).")
    Machine([
        '"Enter a number: "', "print", "read", "cast_int",
        '"The number "', "print", "dup", "print", '" is "', "print",
        2, "%", 0, "==", '"even."', '"odd."', "if", "println",
        0, "jmp"  # loop forever!
    ]).run()


def main(argv=None):
    """Command-line entry point.

    ``argv`` holds the arguments after the program name and defaults to
    ``sys.argv[1:]``.  With no argument or ``repl`` the REPL is started;
    ``test`` runs the self-test followed by the examples.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    try:
        if args:
            cmd = args[0]
            if cmd == "repl":
                repl()
            elif cmd == "test":
                test()
                examples()
            else:
                print("Commands: repl, test")
        else:
            repl()
    except EOFError:
        # CTRL+D at a prompt: finish the line and exit quietly.
        print("")


if __name__ == "__main__":
    main()