├── .gitignore ├── .travis.yml ├── Makefile ├── README.md ├── analyze.go ├── analyze_test.go ├── ast.go ├── compile.go ├── demo ├── README.md └── optimize_constant.sc ├── env.go ├── env_test.go ├── example ├── array_test.sc ├── bubble_sort.sc ├── emoji.sc ├── factorial.sc ├── fib.sc ├── fizzbuzz.sc ├── gcd.sc ├── global_var.sc ├── if_test.sc ├── many_args.sc ├── optimize_constant.sc ├── pointer_test.sc ├── prime.sc ├── putchar.sc ├── quick_sort.sc ├── sum.sc └── sum_for.sc ├── ir.go ├── ir_test.go ├── lexer.go ├── lexer_test.go ├── main.go ├── main_test.go ├── optimize.go ├── optimize_test.go ├── parse.go ├── parse_test.go ├── parser.go.y ├── report ├── 1.pdf ├── 1.tex ├── 2.pdf ├── 2.tex ├── final.pdf └── final.tex ├── run ├── run-sample ├── run-test ├── sample ├── ng0.sc ├── ng1.sc ├── ok0.sc └── ok1.sc ├── src └── scc ├── test ├── advanced │ ├── ack.sc │ ├── bubble.sc │ ├── insert.sc │ ├── loop_sum.sc │ ├── matmul.sc │ ├── merge.sc │ ├── quick.sc │ ├── recur_sum.sc │ ├── ret_ptr.sc │ ├── share_ints.sc │ └── short.sc ├── basic │ ├── arith.sc │ ├── array.sc │ ├── cmp.sc │ ├── fib.sc │ ├── gcd.sc │ ├── global.sc │ ├── logic.sc │ ├── scope.sc │ ├── swap.sc │ └── while.sc └── err │ ├── name01.sc │ ├── name02.sc │ ├── name03.sc │ ├── name04.sc │ ├── name05.sc │ ├── name06.sc │ ├── name07.sc │ ├── name08.sc │ ├── name09.sc │ ├── name10.sc │ ├── name11.sc │ ├── shape01.sc │ ├── shape02.sc │ ├── shape03.sc │ ├── type01.sc │ ├── type02.sc │ ├── type03.sc │ ├── type04.sc │ ├── type05.sc │ ├── type06.sc │ ├── type07.sc │ ├── type08.sc │ ├── type09.sc │ ├── type10.sc │ ├── type11.sc │ ├── type12.sc │ ├── type13.sc │ ├── type14.sc │ ├── type15.sc │ ├── type16.sc │ ├── type17.sc │ └── type18.sc ├── type.go ├── type_test.go └── util.go /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/go 3 | 4 | ### Go ### 5 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 6 | *.o 7 
| *.a 8 | *.so 9 | 10 | # Folders 11 | _obj 12 | _test 13 | 14 | # Architecture specific extensions/prefixes 15 | *.[568vq] 16 | [568vq].out 17 | 18 | *.cgo1.go 19 | *.cgo2.c 20 | _cgo_defun.c 21 | _cgo_gotypes.go 22 | _cgo_export.* 23 | 24 | _testmain.go 25 | 26 | *.exe 27 | *.test 28 | *.prof 29 | 30 | y.output 31 | small-c 32 | 33 | # generated by yacc 34 | parser.go 35 | 36 | *.dvi 37 | *.log 38 | *.aux 39 | 40 | *.s 41 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.6 5 | - tip 6 | 7 | before_install: 8 | - wget https://github.com/ymyzk/spim-for-kuis/archive/master.zip 9 | - unzip master.zip 10 | - cd ./spim-for-kuis-master/spim/ 11 | - sudo make DEST_DIR=/usr spim install 12 | - cd ../../ 13 | - make deps 14 | 15 | install: 16 | - go get golang.org/x/tools/cmd/cover 17 | - go get github.com/mattn/goveralls 18 | 19 | script: make ci-test 20 | 21 | env: 22 | global: 23 | secure: "WK4J/RFtZjHY9wZLSW3UqEzwfzpl/klMpTzHLuMjTvfCE5+Op7bAwsn4uPosXVPkTkQRDrP4ckl153EFjq57MkvxLwaut3I7qlQSpspibP5HEzpMVkwwhQkYSMw0qCDYajYoWLNTidGcIdm4FMTr/tXdJkMy2pBKiXj8lKiZRofVL6n4oHMnqevsIMWVY7wi1j3+iY8JrVBes3kfYB9zwMPeHHWOVd6SygBM1+Ug6qTFSSUPQn5YSr8W8dPzhHImmqlL3A2CxofAnhw01P0jA5EQgaWiTa9vExQ3uJK2QP8tKQQL62LxQoEN+DYRRa5Ix0kebgQXk5kYgHmF35/fzLXN++FZ8SvJPzocMKkf4Qcx/4gCiuXbNU3UdWP3Z8UHgfJ61m70qeTpqD9zHZlA73cBvcvcwE6exmi7KS4WCTYt4B/l/Sbs5h2CeWq/91QQQXw6UFhfHfL8awmVIUBvvH0L8RlR6skO8OrTe/fO/4hZrMTnWfshFicakZAtTt8/eBBMltoZ31GM9BaRuPkXFbRkUSYdbeIDHKhAhxKGbAK7pgDGkJrijh3MVGT+PiH5JIytpbjqf1OBE9zr24gAcoqWt02o9puDhHRiu1LwrVVLi2Y5RNJ9QzNlg6jib661VC1ssDLxxrVLdzz3qM4Qb1SO48itPrlkTxFveAMkMIE=" 24 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: exec 2 | 3 | exec: deps parser.go 4 | go build 5 | 6 | deps: 7 
| go get -d -v 8 | 9 | parser.go: parser.go.y 10 | go tool yacc -o parser.go parser.go.y 11 | 12 | test: parser.go 13 | go test -v -cover ./... 14 | 15 | ci-test: parser.go 16 | go test -v -covermode=count -coverprofile=coverage.out 17 | $(HOME)/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci -repotoken $(COVERALLS_TOKEN) 18 | 19 | examples := $(wildcard example/*.sc) 20 | destfiles := $(patsubst example/%.sc,example/%.s,$(examples)) 21 | example: $(destfiles) 22 | 23 | example/%.s: example/%.sc exec 24 | ./small-c $< > $@ 25 | 26 | report: report/1.pdf report/2.pdf report/final.pdf 27 | 28 | report/%.pdf: %.dvi 29 | dvipdfmx -o $@ $< 30 | 31 | %.dvi: report/%.tex 32 | platex $< 33 | 34 | .PHONY: test ci-test 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Small C 2 | [![Build Status](https://travis-ci.org/uiureo/small-c.svg?branch=master)](https://travis-ci.org/uiureo/small-c) [![Coverage Status](https://coveralls.io/repos/github/uiureo/small-c/badge.svg?branch=master)](https://coveralls.io/github/uiureo/small-c?branch=master) 3 | 4 | Small C compiler in Go. "Small C" is a small subset of C. 5 | 6 | The target assembly language is MIPS. 7 | 8 | This compiler is for [京都大学工学部情報学科計算機科学コース / 計算機科学実験及び演習3](http://www.fos.kuis.kyoto-u.ac.jp/~umatani/le3b/). 9 | 10 | ## Run 11 | 12 | ``` sh 13 | make 14 | ./small-c example/quick_sort.sc 15 | ``` 16 | 17 | ## Test 18 | The test command requires [spim CLI](https://github.com/ymyzk/spim-for-kuis). 19 | 20 | ```sh 21 | make test 22 | ``` 23 | 24 | ## Example 25 | This example shows bubble sort algorithm in Small C. 
26 | 27 | example/bubble_sort.sc: 28 | ``` c 29 | void bubble_sort(int *p, int size) { 30 | int i, j, tmp; 31 | 32 | for (i = 0; i < size; i = i + 1) { 33 | for (j = 1; j < size; j = j + 1) { 34 | int current; 35 | int prev; 36 | 37 | current = *(p + j); 38 | prev = *(p + j - 1); 39 | if (current < prev) { 40 | tmp = current; 41 | *(p + j) = prev; 42 | *(p + j - 1) = tmp; 43 | } 44 | } 45 | } 46 | } 47 | ``` 48 | 49 | example/fizzbuzz.sc: 50 | ```c 51 | int mod(int x, int y) { 52 | return x - (x / y) * y; 53 | } 54 | 55 | void puts(int *s) { 56 | while (*s != 0) { 57 | putchar(*s); 58 | s = s + 1; 59 | } 60 | } 61 | 62 | int main() { 63 | int i; 64 | int fizz[5]; 65 | int buzz[5]; 66 | 67 | fizz[0] = 'F'; fizz[1] = 'i'; fizz[2] = 'z'; fizz[3] = 'z'; fizz[4] = 0; 68 | buzz[0] = 'B'; buzz[1] = 'u'; buzz[2] = 'z'; buzz[3] = 'z'; buzz[4] = 0; 69 | 70 | for (i = 1; i <= 30; i = i + 1) { 71 | if (mod(i, 3) == 0) { 72 | puts(fizz); 73 | } 74 | 75 | if (mod(i, 5) == 0) { 76 | puts(buzz); 77 | } 78 | 79 | if (mod(i, 3) != 0 && mod(i, 5) != 0) { 80 | print(i); 81 | } 82 | 83 | putchar(' '); 84 | } 85 | } 86 | 87 | ``` 88 | 89 | There are other examples in `src/example/` 90 | 91 | ## License 92 | MIT 93 | -------------------------------------------------------------------------------- /analyze.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | // Analyze ast and register variables to env 9 | func Analyze(statements []Statement, env *Env) []error { 10 | var errs []error 11 | for _, statement := range statements { 12 | errs = append(errs, analyzeStatement(statement, env)...) 
13 | } 14 | 15 | return errs 16 | } 17 | 18 | func analyzeStatement(statement Statement, env *Env) []error { 19 | var errs []error 20 | 21 | switch s := statement.(type) { 22 | case *FunctionDefinition: 23 | errs = analyzeFunctionDefinition(s, env) 24 | 25 | case *Declaration: 26 | errs = analyzeDeclaration(s, env) 27 | 28 | case *CompoundStatement: 29 | errs = analyzeCompoundStatement(s, env) 30 | 31 | case *IfStatement: 32 | errs = append(errs, analyzeExpression(s.Condition, env)...) 33 | errs = append(errs, analyzeStatement(s.TrueStatement, env)...) 34 | errs = append(errs, analyzeStatement(s.FalseStatement, env)...) 35 | 36 | case *WhileStatement: 37 | // ForStatement is converted to WhileStatement 38 | errs = analyzeExpression(s.Condition, env) 39 | errs = append(errs, analyzeStatement(s.Statement, env)...) 40 | 41 | case *ExpressionStatement: 42 | errs = analyzeExpression(s.Value, env) 43 | 44 | case *ReturnStatement: 45 | // Set current function symbol to check type 46 | s.FunctionSymbol = env.Get("#func") 47 | errs = analyzeExpression(s.Value, env) 48 | 49 | } 50 | 51 | return errs 52 | } 53 | 54 | func analyzeFunctionDefinition(s *FunctionDefinition, env *Env) []error { 55 | errs := []error{} 56 | 57 | identifier := findIdentifierExpression(s.Identifier) 58 | argTypes := []SymbolType{} 59 | 60 | for _, p := range s.Parameters { 61 | parameter, ok := p.(*ParameterDeclaration) 62 | if ok { 63 | argType := BasicType{Name: parameter.TypeName} 64 | argTypes = append(argTypes, composeType(parameter.Identifier, argType)) 65 | } 66 | } 67 | 68 | returnType := composeType(s.Identifier, BasicType{Name: s.TypeName}) 69 | symbolType := FunctionType{Return: returnType, Args: argTypes} 70 | 71 | kind := "" 72 | if s.Statement != nil { 73 | kind = "fun" 74 | } else { 75 | kind = "proto" 76 | } 77 | 78 | err := env.Register(identifier, &Symbol{ 79 | Kind: kind, 80 | Type: symbolType, 81 | }) 82 | 83 | if err != nil { 84 | errs = append(errs, SemanticError{ 85 | Pos: 
s.Pos(), 86 | Err: err, 87 | }) 88 | } 89 | 90 | if s.Statement != nil { 91 | paramEnv := env.CreateChild() 92 | // Set special symbol to analyze function type 93 | paramEnv.Add(&Symbol{ 94 | Name: "#func", 95 | Type: symbolType, 96 | }) 97 | 98 | for _, p := range s.Parameters { 99 | parameter, ok := p.(*ParameterDeclaration) 100 | 101 | if ok { 102 | identifier := findIdentifierExpression(parameter.Identifier) 103 | argType := composeType(parameter.Identifier, BasicType{Name: parameter.TypeName}) 104 | 105 | err := paramEnv.Register(identifier, &Symbol{ 106 | Kind: "parm", 107 | Type: argType, 108 | }) 109 | 110 | if err != nil { 111 | errs = append(errs, SemanticError{ 112 | Pos: parameter.Pos(), 113 | Err: fmt.Errorf("parameter `%s` is already defined", identifier.Name), 114 | }) 115 | } 116 | } 117 | } 118 | 119 | errs = append(errs, analyzeStatement(s.Statement, paramEnv)...) 120 | } 121 | 122 | return errs 123 | } 124 | 125 | func analyzeDeclaration(s *Declaration, env *Env) []error { 126 | errs := []error{} 127 | for _, declarator := range s.Declarators { 128 | symbolType := composeType(declarator.Identifier, BasicType{Name: s.VarType}) 129 | if declarator.Size > 0 { 130 | symbolType = ArrayType{Value: symbolType, Size: declarator.Size} 131 | } 132 | 133 | identifier := findIdentifierExpression(declarator.Identifier) 134 | err := env.Register(identifier, &Symbol{ 135 | Kind: "var", 136 | Type: symbolType, 137 | }) 138 | 139 | if err != nil { 140 | errs = append(errs, SemanticError{ 141 | Pos: declarator.Pos(), 142 | Err: err, 143 | }) 144 | } 145 | } 146 | 147 | return errs 148 | } 149 | 150 | func analyzeCompoundStatement(s *CompoundStatement, env *Env) []error { 151 | var errs []error 152 | newEnv := env.CreateChild() 153 | for _, declaration := range s.Declarations { 154 | errs = append(errs, analyzeStatement(declaration, newEnv)...) 
155 | } 156 | 157 | for _, statement := range s.Statements { 158 | errs = append(errs, analyzeStatement(statement, newEnv)...) 159 | } 160 | 161 | return errs 162 | } 163 | 164 | func analyzeExpression(expression Expression, env *Env) []error { 165 | var errs []error 166 | 167 | switch e := expression.(type) { 168 | case *IdentifierExpression: 169 | symbol := env.Get(e.Name) 170 | 171 | if symbol == nil { 172 | errs = append(errs, SemanticError{ 173 | Pos: e.Pos(), 174 | Err: fmt.Errorf("reference error: `%v` is undefined", e.Name), 175 | }) 176 | } else { 177 | if !symbol.IsVariable() { 178 | errs = append(errs, SemanticError{ 179 | Pos: e.Pos(), 180 | Err: fmt.Errorf("`%v` is not variable", e.Name), 181 | }) 182 | } else { 183 | e.Symbol = symbol 184 | } 185 | } 186 | 187 | case *ExpressionList: 188 | for _, value := range e.Values { 189 | errs = append(errs, analyzeExpression(value, env)...) 190 | } 191 | 192 | case *BinaryExpression: 193 | errs = append(errs, analyzeExpression(e.Left, env)...) 194 | errs = append(errs, analyzeExpression(e.Right, env)...) 195 | 196 | if e.Operator == "=" { 197 | leftIsAssignable := true 198 | 199 | switch left := e.Left.(type) { 200 | case *IdentifierExpression: 201 | expressionErrs := analyzeExpression(left, env) 202 | errs = append(errs, expressionErrs...) 
203 | 204 | if len(errs) == 0 { 205 | _, isArrayType := left.Symbol.Type.(ArrayType) 206 | if !left.Symbol.IsVariable() || isArrayType { 207 | leftIsAssignable = false 208 | } 209 | } 210 | 211 | case *UnaryExpression: 212 | if left.Operator != "*" { 213 | leftIsAssignable = false 214 | } 215 | 216 | default: 217 | leftIsAssignable = false 218 | } 219 | 220 | if !leftIsAssignable { 221 | errs = append(errs, SemanticError{ 222 | Pos: e.Left.Pos(), 223 | Err: errors.New("expression is not assignable"), 224 | }) 225 | } 226 | } 227 | 228 | case *UnaryExpression: 229 | if e.Operator == "&" { 230 | switch v := e.Value.(type) { 231 | case *IdentifierExpression: 232 | default: 233 | errs = append(errs, SemanticError{ 234 | Pos: v.Pos(), 235 | Err: errors.New("the operand of `&` must be on memory"), 236 | }) 237 | } 238 | } 239 | 240 | return append(errs, analyzeExpression(e.Value, env)...) 241 | 242 | case *ArrayReferenceExpression: 243 | errs = append(errs, analyzeExpression(e.Target, env)...) 244 | errs = append(errs, analyzeExpression(e.Index, env)...) 
245 | 246 | case *FunctionCallExpression: 247 | identifier := findIdentifierExpression(e.Identifier) 248 | symbol := env.Get(identifier.Name) 249 | if symbol == nil { 250 | return []error{ 251 | SemanticError{ 252 | Pos: identifier.Pos(), 253 | Err: fmt.Errorf("unknown function `%v` call", identifier.Name), 254 | }, 255 | } 256 | } 257 | 258 | if !(symbol.Kind == "fun" || symbol.Kind == "proto") { 259 | return []error{ 260 | SemanticError{ 261 | Pos: identifier.Pos(), 262 | Err: fmt.Errorf("`%v` is not a function", identifier.Name), 263 | }, 264 | } 265 | } 266 | 267 | identifier.Symbol = symbol 268 | return analyzeExpression(e.Argument, env) 269 | } 270 | 271 | return errs 272 | } 273 | 274 | func findIdentifierExpression(expression Expression) *IdentifierExpression { 275 | switch e := expression.(type) { 276 | case *IdentifierExpression: 277 | return e 278 | case *UnaryExpression: 279 | return findIdentifierExpression(e.Value) 280 | } 281 | 282 | panic("IdentifierExpression not found") 283 | } 284 | 285 | func composeType(identifier Expression, symbolType SymbolType) SymbolType { 286 | switch e := identifier.(type) { 287 | case *UnaryExpression: 288 | if e.Operator == "*" { 289 | return PointerType{Value: composeType(e.Value, symbolType)} 290 | } 291 | case *IdentifierExpression: 292 | return symbolType 293 | } 294 | 295 | return nil 296 | } 297 | -------------------------------------------------------------------------------- /analyze_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestAnalyze(t *testing.T) { 9 | env := &Env{} 10 | statements, _ := Parse(` 11 | int sum(int a, int b) { 12 | return a + b; 13 | } 14 | `) 15 | 16 | errs := Analyze(statements, env) 17 | if len(errs) != 0 { 18 | t.Errorf("expect no error, but got: %v", errs) 19 | } 20 | } 21 | 22 | func TestAnalyzeDeclaration(t *testing.T) { 23 | { 24 | statements, err := 
Parse("int a, b, c;\n") 25 | 26 | if err != nil { 27 | t.Errorf("parse error: %v", err) 28 | return 29 | } 30 | 31 | declaration := statements[0].(*Declaration) 32 | 33 | env := &Env{} 34 | analyzeDeclaration(declaration, env) 35 | 36 | if len(env.Table) != 3 { 37 | t.Errorf("env.Table should have 3 vars, but %v", env.Table) 38 | } 39 | 40 | symbol := env.Table["a"] 41 | if !(symbol != nil && symbol.Name == "a" && symbol.Kind == "var") { 42 | t.Errorf("symbol should be a variable, got %v", symbol) 43 | } 44 | } 45 | 46 | { 47 | statements, _ := Parse("int a[10], b;\n") 48 | declaration := statements[0].(*Declaration) 49 | 50 | env := &Env{} 51 | analyzeDeclaration(declaration, env) 52 | 53 | symbol := env.Table["a"] 54 | 55 | isArrayType := symbol != nil && reflect.TypeOf(symbol.Type).Name() == "ArrayType" 56 | correctSize := symbol != nil && symbol.Type.(ArrayType).Size == 10 57 | if !(isArrayType && correctSize) { 58 | t.Errorf("expect `a` to be an array: %v", symbol) 59 | } 60 | } 61 | 62 | { 63 | statements, _ := Parse("int a, b, a;\n") 64 | declaration := statements[0].(*Declaration) 65 | 66 | errs := analyzeDeclaration(declaration, &Env{}) 67 | 68 | if len(errs) == 0 { 69 | t.Errorf("should return an error when variables are double defined: %v", errs) 70 | return 71 | } 72 | } 73 | } 74 | 75 | func TestAnalyzeFunctionDefinition(t *testing.T) { 76 | { 77 | statements, _ := Parse(` 78 | int foo(int a, int b) { 79 | return a + b; 80 | } 81 | `) 82 | 83 | fn := statements[0].(*FunctionDefinition) 84 | env := &Env{} 85 | analyzeFunctionDefinition(fn, env) 86 | 87 | symbol := env.Table["foo"] 88 | if symbol == nil { 89 | t.Errorf("env should have `foo` as symbol: %v", env) 90 | return 91 | } 92 | 93 | symbolType, ok := symbol.Type.(FunctionType) 94 | if !ok { 95 | t.Errorf("symbol type should be FunctionType: %v", symbol) 96 | return 97 | } 98 | 99 | returnIsInt := symbolType.Return.(BasicType).Name == "int" 100 | if !returnIsInt { 101 | t.Errorf("expect return 
type to be int, but got %v", symbolType) 102 | } 103 | 104 | argsHaveTwoInt := len(symbolType.Args) == 2 && symbolType.Args[0].String() == "int" 105 | 106 | if !argsHaveTwoInt { 107 | t.Errorf("expect args to be (int, int): %v", symbolType.Args) 108 | } 109 | } 110 | 111 | { 112 | statements, _ := Parse(` 113 | int foo(int a, int a) { 114 | int b; 115 | 116 | return a + b; 117 | } 118 | `) 119 | 120 | fn := statements[0].(*FunctionDefinition) 121 | errs := analyzeFunctionDefinition(fn, &Env{}) 122 | 123 | if len(errs) != 1 { 124 | t.Errorf("should return `parameter already defined` error: %v", errs) 125 | } 126 | } 127 | } 128 | 129 | func TestAnalyzeCompoundStatement(t *testing.T) { 130 | statements, _ := Parse(` 131 | int main() { 132 | int a; 133 | int a; 134 | } 135 | `) 136 | 137 | def := statements[0].(*FunctionDefinition) 138 | compoundStatement := def.Statement.(*CompoundStatement) 139 | errs := analyzeCompoundStatement(compoundStatement, &Env{}) 140 | 141 | if len(errs) != 1 { 142 | t.Errorf("should have 1 error: %v", errs) 143 | } 144 | } 145 | 146 | func TestAnalyzeExpression(t *testing.T) { 147 | { 148 | env := &Env{} 149 | env.Add(&Symbol{Name: "foo", Kind: "var"}) 150 | 151 | errs := analyzeExpression(&IdentifierExpression{Name: "foo"}, env) 152 | if len(errs) > 0 { 153 | t.Errorf("expect no error, got %v", errs) 154 | } 155 | 156 | errs = analyzeExpression(&IdentifierExpression{Name: "bar"}, env) 157 | if len(errs) != 1 { 158 | t.Errorf("expect reference error, got %v", errs) 159 | } 160 | } 161 | 162 | { 163 | env := &Env{} 164 | env.Add(&Symbol{Name: "foo", Kind: "fun"}) 165 | 166 | errs := analyzeExpression(&IdentifierExpression{Name: "foo"}, env) 167 | if len(errs) != 1 { 168 | t.Errorf("expect not variable error, got %v", errs) 169 | } 170 | } 171 | 172 | { 173 | e := &FunctionCallExpression{ 174 | Identifier: &IdentifierExpression{ 175 | Name: "foo", 176 | }, 177 | Argument: &ExpressionList{}, 178 | } 179 | 180 | env := &Env{} 181 | 
env.Add(&Symbol{Name: "foo", Kind: "fun"}) 182 | 183 | errs := analyzeExpression(e, env) 184 | 185 | if len(errs) != 0 { 186 | t.Errorf("expect no error, but got: %v", errs) 187 | } 188 | 189 | env.Table["foo"] = &Symbol{Name: "foo", Kind: "var"} 190 | errs = analyzeExpression(e, env) 191 | 192 | if len(errs) != 1 { 193 | t.Errorf("expect not function error, got %v", errs) 194 | } 195 | } 196 | 197 | { 198 | env := &Env{} 199 | errs := analyzeExpression(&UnaryExpression{ 200 | Operator: "&", 201 | Value: &NumberExpression{ 202 | Value: "10", 203 | }, 204 | }, env) 205 | 206 | if len(errs) != 1 { 207 | t.Errorf("expect memory reference error, but got %v", errs) 208 | } 209 | } 210 | 211 | { 212 | env := &Env{} 213 | env.Add(&Symbol{ 214 | Name: "f", 215 | Kind: "fun", 216 | }) 217 | 218 | errs := analyzeExpression(&BinaryExpression{ 219 | Operator: "=", 220 | Left: &NumberExpression{Value: "1"}, 221 | Right: &NumberExpression{Value: "2"}, 222 | }, env) 223 | 224 | if len(errs) != 1 { 225 | t.Errorf("expect assignment error, but got: %v", errs) 226 | } 227 | } 228 | } 229 | 230 | func TestAnalyzeArrayAssignment(t *testing.T) { 231 | statements, _ := Parse(` 232 | int main() { 233 | int data[10]; 234 | data = 10; 235 | } 236 | `) 237 | 238 | errs := Analyze(statements, &Env{}) 239 | if len(errs) != 1 { 240 | t.Errorf("should have 1 error: %v", errs) 241 | } 242 | } 243 | -------------------------------------------------------------------------------- /ast.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "text/scanner" 5 | ) 6 | 7 | type Token struct { 8 | lit string 9 | pos scanner.Position 10 | } 11 | 12 | type Node interface { 13 | Pos() scanner.Position 14 | } 15 | 16 | type Expression interface { 17 | Node 18 | } 19 | 20 | type ExpressionList struct { 21 | Values []Expression 22 | } 23 | 24 | func (e *ExpressionList) Pos() scanner.Position { 25 | first := e.Values[0] 26 | return first.Pos() 
27 | } 28 | 29 | type NumberExpression struct { 30 | pos scanner.Position 31 | Value string 32 | } 33 | 34 | func (e *NumberExpression) Pos() scanner.Position { return e.pos } 35 | 36 | type IdentifierExpression struct { 37 | pos scanner.Position 38 | Name string 39 | Symbol *Symbol 40 | } 41 | 42 | func (e *IdentifierExpression) Pos() scanner.Position { return e.pos } 43 | 44 | type UnaryExpression struct { 45 | pos scanner.Position 46 | Operator string 47 | Value Expression 48 | } 49 | 50 | func (e *UnaryExpression) Pos() scanner.Position { return e.pos } 51 | 52 | type BinaryExpression struct { 53 | Left Expression 54 | Operator string 55 | Right Expression 56 | } 57 | 58 | func (e *BinaryExpression) Pos() scanner.Position { 59 | return e.Left.Pos() 60 | } 61 | 62 | func (e *BinaryExpression) IsAssignment() bool { 63 | return e.Operator == "=" 64 | } 65 | 66 | func (e *BinaryExpression) IsArithmetic() bool { 67 | return e.Operator == "+" || e.Operator == "-" || e.Operator == "/" || e.Operator == "*" 68 | } 69 | 70 | func (e *BinaryExpression) IsLogical() bool { 71 | return e.Operator == "&&" || e.Operator == "||" 72 | } 73 | 74 | func (e *BinaryExpression) IsEqual() bool { 75 | switch e.Operator { 76 | case "==", "!=", ">=", ">", "<=", "<": 77 | return true 78 | } 79 | 80 | return false 81 | } 82 | 83 | type FunctionCallExpression struct { 84 | Identifier Expression 85 | Argument Expression 86 | } 87 | 88 | func (e *FunctionCallExpression) Pos() scanner.Position { 89 | identifier := e.Identifier.(*IdentifierExpression) 90 | return identifier.Pos() 91 | } 92 | 93 | type ArrayReferenceExpression struct { 94 | Target Expression 95 | Index Expression 96 | } 97 | 98 | func (e *ArrayReferenceExpression) Pos() scanner.Position { 99 | return e.Target.Pos() 100 | } 101 | 102 | type PointerExpression struct { 103 | pos scanner.Position 104 | Value Expression 105 | } 106 | 107 | func (e *PointerExpression) Pos() scanner.Position { return e.pos } 108 | 109 | type Declarator 
struct { 110 | Identifier Expression 111 | Size int 112 | } 113 | 114 | func (e *Declarator) Pos() scanner.Position { 115 | switch identifier := e.Identifier.(type) { 116 | case *IdentifierExpression: 117 | return identifier.Pos() 118 | 119 | case *UnaryExpression: 120 | return identifier.Pos() 121 | } 122 | 123 | panic("unexpected identifier") 124 | } 125 | 126 | type Declaration struct { 127 | pos scanner.Position 128 | VarType string 129 | Declarators []*Declarator 130 | } 131 | 132 | func (e *Declaration) Pos() scanner.Position { return e.pos } 133 | 134 | type FunctionDefinition struct { 135 | pos scanner.Position 136 | TypeName string 137 | Identifier Expression 138 | Parameters []Expression 139 | Statement Statement 140 | } 141 | 142 | func (e *FunctionDefinition) Pos() scanner.Position { return e.pos } 143 | 144 | type Statement interface { 145 | Node 146 | } 147 | 148 | type CompoundStatement struct { 149 | pos scanner.Position 150 | Declarations []Statement 151 | Statements []Statement 152 | } 153 | 154 | func (e *CompoundStatement) Pos() scanner.Position { return e.pos } 155 | 156 | type ExpressionStatement struct { 157 | Value Expression 158 | } 159 | 160 | func (e *ExpressionStatement) Pos() scanner.Position { 161 | return e.Value.Pos() 162 | } 163 | 164 | type IfStatement struct { 165 | pos scanner.Position 166 | Condition Expression 167 | TrueStatement Statement 168 | FalseStatement Statement 169 | } 170 | 171 | func (e *IfStatement) Pos() scanner.Position { return e.pos } 172 | func (e *IfStatement) Statements() []Statement { 173 | return []Statement{e.TrueStatement, e.FalseStatement} 174 | } 175 | 176 | type WhileStatement struct { 177 | pos scanner.Position 178 | Condition Expression 179 | Statement Statement 180 | } 181 | 182 | func (e *WhileStatement) Pos() scanner.Position { return e.pos } 183 | func (e *WhileStatement) Statements() []Statement { 184 | return []Statement{e.Statement} 185 | } 186 | 187 | type ForStatement struct { 188 | pos 
scanner.Position 189 | Init Expression 190 | Condition Expression 191 | Loop Expression 192 | Statement Statement 193 | } 194 | 195 | func (e *ForStatement) Pos() scanner.Position { return e.pos } 196 | 197 | type ReturnStatement struct { 198 | pos scanner.Position 199 | Value Expression 200 | FunctionSymbol *Symbol 201 | } 202 | 203 | func (e *ReturnStatement) Pos() scanner.Position { return e.pos } 204 | 205 | type ParameterDeclaration struct { 206 | pos scanner.Position 207 | TypeName string 208 | Identifier Expression 209 | } 210 | 211 | func (e *ParameterDeclaration) Pos() scanner.Position { return e.pos } 212 | -------------------------------------------------------------------------------- /compile.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | func CalculateOffset(ir *IRProgram) { 9 | globalOffset := 0 10 | // global vars 11 | for _, d := range ir.Declarations { 12 | size := d.Var.Type.ByteSize() 13 | globalOffset -= size 14 | d.Var.Offset = globalOffset 15 | } 16 | 17 | for _, f := range ir.Functions { 18 | calculateOffsetFunction(f) 19 | } 20 | } 21 | 22 | func calculateOffsetFunction(ir *IRFunctionDefinition) { 23 | offset := 0 24 | 25 | for i := len(ir.Parameters) - 1; i >= 0; i-- { 26 | p := ir.Parameters[i] 27 | size := p.Var.Type.ByteSize() 28 | 29 | // arg 4 => 4($fp), arg 5 => 8($fp) 30 | if i >= 4 { 31 | p.Var.Offset = (i - 3) * size 32 | } else { 33 | p.Var.Offset = offset - (size - 4) 34 | offset -= size 35 | } 36 | } 37 | 38 | minOffset := calculateOffsetStatement(ir.Body, offset) 39 | ir.VarSize = -minOffset 40 | } 41 | 42 | func calculateOffsetStatement(statement IRStatement, base int) int { 43 | offset := base 44 | minOffset := 0 45 | 46 | switch s := statement.(type) { 47 | case *IRCompoundStatement: 48 | for _, d := range s.Declarations { 49 | size := d.Var.Type.ByteSize() 50 | d.Var.Offset = offset - (size - 4) 51 | offset -= size 
52 | } 53 | 54 | minOffset = offset 55 | for _, s := range s.Statements { 56 | statementOffset := calculateOffsetStatement(s, offset) 57 | 58 | if statementOffset < minOffset { 59 | minOffset = statementOffset 60 | } 61 | } 62 | } 63 | 64 | return minOffset 65 | } 66 | 67 | // Compile takes ir program as input and returns mips code 68 | func Compile(program *IRProgram) string { 69 | CalculateOffset(program) 70 | 71 | code := "" 72 | code += ".data\n" 73 | code += ".text\n.globl main\n" 74 | for _, f := range program.Functions { 75 | code += "\n" + strings.Join(compileFunction(f), "\n") + "\n" 76 | } 77 | 78 | return code 79 | } 80 | 81 | func compileFunction(function *IRFunctionDefinition) []string { 82 | size := function.VarSize + 4*2 // arguments + local vars + $ra + $fp 83 | 84 | var code []string 85 | code = append( 86 | code, 87 | fmt.Sprintf("%s:", function.Var.Name), 88 | fmt.Sprintf("addi $sp, $sp, %d", -size), 89 | "sw $ra, 4($sp)", 90 | "sw $fp, 0($sp)", 91 | fmt.Sprintf("addi $fp, $sp, %d", size-4), 92 | ) 93 | 94 | for i := len(function.Parameters) - 1; i >= 0; i-- { 95 | p := function.Parameters[i] 96 | // arg 4,5,6... is passed via 4($fp), 8($fp), ... 97 | if i < 4 { 98 | code = append(code, fmt.Sprintf("sw $a%d, %d($fp)", i, p.Var.Offset)) 99 | } 100 | } 101 | 102 | code = append(code, compileStatement(function.Body, function)...) 103 | 104 | code = append( 105 | code, 106 | function.Var.Name+"_exit:", 107 | "lw $fp, 0($sp)", 108 | "lw $ra, 4($sp)", 109 | fmt.Sprintf("addi $sp, $sp, %d", size), 110 | "jr $ra", 111 | ) 112 | 113 | return code 114 | } 115 | 116 | func compileStatement(statement IRStatement, function *IRFunctionDefinition) []string { 117 | var code []string 118 | 119 | switch s := statement.(type) { 120 | case *IRCompoundStatement: 121 | for _, statement := range s.Statements { 122 | code = append(code, compileStatement(statement, function)...) 
123 | } 124 | 125 | case *IRAssignmentStatement: 126 | code = append(code, assignExpression("$t0", s.Expression)...) 127 | code = append(code, sw("$t0", s.Var)) 128 | 129 | case *IRCallStatement: 130 | for i := len(s.Vars) - 1; i >= 0; i-- { 131 | v := s.Vars[i] 132 | 133 | if i >= 4 { 134 | code = append(code, lw("$t0", v)) 135 | code = append(code, 136 | "addi $sp, $sp, -4", 137 | fmt.Sprintf("sw %s, 0($sp)", "$t0"), 138 | ) 139 | } else { 140 | code = append(code, lw(fmt.Sprintf("$a%d", i), v)) 141 | } 142 | } 143 | 144 | code = append(code, fmt.Sprintf("jal %s", s.Func.Name)) 145 | if len(s.Vars) > 4 { 146 | code = append(code, fmt.Sprintf("addi $sp, $sp, %d", 4*(len(s.Vars)-4))) 147 | } 148 | code = append(code, sw("$v0", s.Dest)) 149 | 150 | case *IRReturnStatement: 151 | if s.Var != nil { 152 | code = append(code, 153 | lw("$v0", s.Var), 154 | ) 155 | } 156 | 157 | code = append(code, 158 | fmt.Sprintf("j %s_exit", function.Var.Name), 159 | ) 160 | 161 | case *IRWriteStatement: 162 | return []string{ 163 | lw("$t0", s.Src), 164 | lw("$t1", s.Dest), 165 | "sw $t0, 0($t1)", 166 | } 167 | 168 | case *IRReadStatement: 169 | return []string{ 170 | lw("$t0", s.Src), 171 | "lw $t1, 0($t0)", 172 | sw("$t1", s.Dest), 173 | } 174 | 175 | case *IRLabelStatement: 176 | return append(code, s.Name+":") 177 | 178 | case *IRIfStatement: 179 | falseLabel := label("ir_if_false") 180 | endLabel := label("ir_if_end") 181 | 182 | code = append(code, 183 | lw("$t0", s.Var), 184 | fmt.Sprintf("beq $t0, $zero, %s", falseLabel), 185 | ) 186 | 187 | if len(s.TrueLabel) > 0 { 188 | code = append(code, 189 | fmt.Sprintf("j %s", s.TrueLabel), 190 | ) 191 | } else { 192 | code = append(code, 193 | fmt.Sprintf("j %s", endLabel), 194 | ) 195 | } 196 | 197 | code = append(code, 198 | falseLabel+":", 199 | ) 200 | 201 | if len(s.FalseLabel) > 0 { 202 | code = append(code, 203 | fmt.Sprintf("j %s", s.FalseLabel), 204 | ) 205 | } 206 | 207 | code = append(code, 208 | endLabel+":", 209 | ) 210 
| 211 | case *IRGotoStatement: 212 | code = append(code, jmp(s.Label)) 213 | 214 | case *IRSystemCallStatement: 215 | switch s.Name { 216 | case "print": 217 | return []string{ 218 | "li $v0, 1", 219 | lw("$a0", s.Var), 220 | "syscall", 221 | } 222 | case "putchar": 223 | return []string{ 224 | "li $v0, 11", 225 | lw("$a0", s.Var), 226 | "syscall", 227 | } 228 | 229 | default: 230 | panic("invalid system call: " + s.Name) 231 | 232 | } 233 | } 234 | 235 | return code 236 | } 237 | 238 | func assignExpression(register string, expression IRExpression) []string { 239 | var code []string 240 | 241 | switch e := expression.(type) { 242 | case *IRNumberExpression: 243 | code = append(code, fmt.Sprintf("li %s, %d", register, e.Value)) 244 | 245 | case *IRBinaryExpression: 246 | leftRegister := "$t1" 247 | rightRegister := "$t2" 248 | 249 | code = append(code, assignExpression(leftRegister, e.Left)...) 250 | code = append(code, 251 | "addi $sp, $sp, -4", 252 | fmt.Sprintf("sw %s, 0($sp)", leftRegister), 253 | ) 254 | code = append(code, assignExpression(rightRegister, e.Right)...) 255 | code = append(code, 256 | fmt.Sprintf("lw %s, 0($sp)", leftRegister), 257 | "addi $sp, $sp, 4", 258 | ) 259 | 260 | operation := assignBinaryOperation(register, e.Operator, leftRegister, rightRegister) 261 | 262 | return append(code, operation...) 
263 | 264 | case *IRVariableExpression: 265 | // *(a + 4) 266 | _, isArrayType := e.Var.Type.(ArrayType) 267 | if isArrayType { 268 | return []string{ 269 | fmt.Sprintf("addi %s, %s, %d", register, e.Var.AddressPointer(), e.Var.Offset), 270 | } 271 | } 272 | 273 | return append(code, lw(register, e.Var)) 274 | 275 | case *IRAddressExpression: 276 | return []string{ 277 | fmt.Sprintf("addi %s, %s, %d", register, e.Var.AddressPointer(), e.Var.Offset), 278 | } 279 | } 280 | 281 | return code 282 | } 283 | // assignBinaryOperation emits MIPS code storing `left operator right` into register (comparisons yield 0 or 1); left and right are only read, never modified. 284 | func assignBinaryOperation(register string, operator string, left string, right string) []string { 285 | inst := operatorToInst[operator] 286 | if len(inst) > 0 { 287 | return []string{ 288 | fmt.Sprintf("%s %s, %s, %s", inst, register, left, right), 289 | } 290 | } 291 | 292 | switch operator { 293 | case "==": 294 | equalLabel := label("beq_true") 295 | endLabel := label("beq_end") 296 | 297 | return []string{ 298 | fmt.Sprintf("beq %s, %s, %s", left, right, equalLabel), 299 | li(register, 0), 300 | fmt.Sprintf("j %s", endLabel), 301 | equalLabel + ":", 302 | li(register, 1), 303 | endLabel + ":", 304 | } 305 | 306 | case "!=": 307 | equalLabel := label("beq_true") 308 | endLabel := label("beq_end") 309 | 310 | return []string{ 311 | fmt.Sprintf("beq %s, %s, %s", left, right, equalLabel), 312 | li(register, 1), 313 | fmt.Sprintf("j %s", endLabel), 314 | equalLabel + ":", 315 | li(register, 0), 316 | endLabel + ":", 317 | } 318 | 319 | case ">": 320 | // a > b <=> (a <= b) < 1 321 | return append(assignBinaryOperation(register, "<=", left, right), 322 | fmt.Sprintf("slti %s, %s, 1", register, register), 323 | ) 324 | 325 | case "<=": 326 | // a <= b <=> !(b < a); unlike `a - 1 < b` this cannot hit addi's overflow trap at INT_MIN and does not clobber left 327 | return []string{ 328 | fmt.Sprintf("slt %s, %s, %s", register, right, left), 329 | fmt.Sprintf("xori %s, %s, 1", register, register), 330 | } 331 | 332 | case ">=": 333 | // a >= b <=> b <= a 334 | return assignBinaryOperation(register, "<=", right, left) 335 | } 336 | 337 | panic("unimplemented operator: " + operator) 338 
| } 339 | 340 | var operatorToInst = map[string]string{ 341 | "+": "add", 342 | "-": "sub", 343 | "*": "mul", 344 | "/": "div", 345 | "<": "slt", 346 | } 347 | 348 | func jmp(label string) string { 349 | return fmt.Sprintf("j %s", label) 350 | } 351 | 352 | func li(register string, value int) string { 353 | return fmt.Sprintf("li %s, %d", register, value) 354 | } 355 | 356 | func lw(register string, src *Symbol) string { 357 | return fmt.Sprintf("lw %s, %d(%s)", register, src.Offset, src.AddressPointer()) 358 | } 359 | 360 | func sw(register string, dest *Symbol) string { 361 | return fmt.Sprintf("sw %s, %d(%s)", register, dest.Offset, dest.AddressPointer()) 362 | } 363 | -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | # 最終報告 2 | ## データフロー解析 3 | * 到達可能定義解析 4 | 5 | データフロー解析処理、最適化処理は `optimize.go` で行っている。 6 | 7 | ```go 8 | // optimize.go 9 | type DataflowBlock struct { 10 | Name string // BEGIN, END 11 | Statements []IRStatement 12 | Next []*DataflowBlock 13 | Prev []*DataflowBlock 14 | } 15 | 16 | func Optimize(program *IRProgram) *IRProgram { 17 | for i, f := range program.Functions { 18 | statements := flatStatement(f) 19 | 20 | // 中間表現プログラム列をデータフローのブロックごとに分ける 21 | blocks := splitStatementsIntoBlocks(statements) 22 | 23 | // ブロックの配列からデータフローを構成 24 | // blockそれぞれについて, block.Nextを設定していく 25 | buildDataflowGraph(blocks) 26 | 27 | // データフローを見て不動点反復により到達可能定義解析する 28 | // 返り値はブロックごとに, 各シンボルの到達可能な定義文 を入れたmap 29 | // blockOut = (DataflowBlock -> (*Symbol -> []IRStatement)) 30 | blockOut := searchReachingDefinitions(blocks) 31 | 32 | // ... 
33 | } 34 | 35 | return program 36 | } 37 | 38 | // 不動点反復なので、状態が収束するまで地道に解析して状態を更新していくという雰囲気 39 | func searchReachingDefinitions(blocks []*DataflowBlock) map[*DataflowBlock]BlockState { 40 | blockOut := make(map[*DataflowBlock]BlockState) 41 | 42 | changed := true 43 | for changed { 44 | changed = false 45 | 46 | for _, block := range blocks { 47 | inState := analyzeBlock(blockOut, block) 48 | if !inState.Equal(blockOut[block]) { 49 | changed = true 50 | } 51 | 52 | blockOut[block] = inState 53 | } 54 | } 55 | 56 | return blockOut 57 | } 58 | 59 | // ひとつのプログラム点を見て状態を更新する 60 | // 到達可能定義解析の実質的な処理 61 | func analyzeReachingDefinition(statement IRStatement, inState BlockState) BlockState { 62 | switch s := statement.(type) { 63 | case *IRAssignmentStatement: 64 | inState[s.Var] = []IRStatement{s} 65 | symbols := extractAddressVarsFromExpression(s.Expression) 66 | for _, symbol := range symbols { 67 | inState[symbol] = append(inState[symbol], s) 68 | } 69 | 70 | case *IRReadStatement: 71 | inState[s.Dest] = []IRStatement{s} 72 | 73 | // ポインタ参照書き込みがあったら, 諦めムードにしておく 74 | case *IRWriteStatement: 75 | for symbol := range inState { 76 | inState[symbol] = append(inState[symbol], s) 77 | } 78 | 79 | case *IRCallStatement: 80 | inState[s.Dest] = []IRStatement{s} 81 | 82 | } 83 | 84 | return inState 85 | } 86 | 87 | ``` 88 | 89 | ## 最適化 90 | * 定数畳み込み 91 | * 無駄な命令の除去 92 | 93 | を到達可能定義解析を用いて実装した。 94 | 95 | ```go 96 | func Optimize(program *IRProgram) *IRProgram { 97 | for i, f := range program.Functions { 98 | // ... 
99 | blockOut := searchReachingDefinitions(blocks) 100 | 101 | // 実装の都合で文ごとの到達可能定義を計算しなおしている 102 | allStatementState := reachingDefinitionsOfStatements(blocks, blockOut, statements) 103 | 104 | // 定数畳み込み 105 | program.Functions[i] = transformByConstantFolding(program.Functions[i], allStatementState) 106 | // 無駄コード除去 107 | program.Functions[i] = transformByDeadCodeElimination(program.Functions[i], allStatementState) 108 | } 109 | 110 | return program 111 | } 112 | ``` 113 | 114 | ### 定数畳み込み 115 | 116 | ```go 117 | func transformByConstantFolding(f *IRFunctionDefinition, allStatementState map[IRStatement]BlockState) *IRFunctionDefinition { 118 | traversed := Traverse(f, func(statement IRStatement) IRStatement { 119 | foldConstantStatement(statement, allStatementState) 120 | return statement 121 | }) 122 | 123 | return traversed.(*IRFunctionDefinition) 124 | } 125 | 126 | // 代入文ならexpressionを見て、それが定数だったら埋め込む 127 | func foldConstantStatement(statement IRStatement, allStatementState map[IRStatement]BlockState) (bool, int) { 128 | switch s := statement.(type) { 129 | case *IRAssignmentStatement: 130 | isConstant, value := foldConstantExpression(s, s.Expression, allStatementState) 131 | if isConstant { 132 | s.Expression = &IRNumberExpression{Value: value} 133 | return true, value 134 | } 135 | } 136 | 137 | return false, 0 138 | } 139 | 140 | // 到達可能定義の情報を使って、再帰的に定数畳み込みしていく 141 | func foldConstantExpression(statement IRStatement, expression IRExpression, allStatementState map[IRStatement]BlockState) (bool, int) { 142 | switch e := expression.(type) { 143 | case *IRNumberExpression: 144 | return true, e.Value 145 | 146 | case *IRVariableExpression: 147 | symbol := e.Var 148 | definitionOfVar := allStatementState[statement][symbol] 149 | if len(definitionOfVar) == 1 && definitionOfVar[0] != statement { 150 | return foldConstantStatement(definitionOfVar[0], allStatementState) 151 | } 152 | 153 | return false, 0 154 | 155 | case *IRBinaryExpression: 156 | leftIsConstant, 
leftValue := foldConstantExpression(statement, e.Left, allStatementState) 157 | rightIsConstant, rightValue := foldConstantExpression(statement, e.Right, allStatementState) 158 | 159 | if leftIsConstant { 160 | e.Left = &IRNumberExpression{Value: leftValue} 161 | } 162 | 163 | if rightIsConstant { 164 | e.Right = &IRNumberExpression{Value: rightValue} 165 | } 166 | 167 | if leftIsConstant && rightIsConstant { 168 | switch e.Operator { 169 | case "+": 170 | return true, leftValue + rightValue 171 | 172 | case "-": 173 | return true, leftValue - rightValue 174 | 175 | case "*": 176 | return true, leftValue * rightValue 177 | 178 | case "/": 179 | return true, leftValue / rightValue 180 | 181 | case "<": 182 | value := 0 183 | if leftValue < rightValue { 184 | value = 1 185 | } 186 | return true, value 187 | 188 | case ">": 189 | value := 0 190 | if leftValue > rightValue { 191 | value = 1 192 | } 193 | return true, value 194 | 195 | case "<=": 196 | value := 0 197 | if leftValue <= rightValue { 198 | value = 1 199 | } 200 | return true, value 201 | 202 | case ">=": 203 | value := 0 204 | if leftValue >= rightValue { 205 | value = 1 206 | } 207 | return true, value 208 | 209 | case "==": 210 | value := 0 211 | if leftValue == rightValue { 212 | value = 1 213 | } 214 | return true, value 215 | 216 | case "!=": 217 | value := 0 218 | if leftValue != rightValue { 219 | value = 1 220 | } 221 | return true, value 222 | 223 | } 224 | 225 | panic("unexpected operator: " + e.Operator) 226 | } 227 | 228 | return false, 0 229 | } 230 | 231 | return false, 0 232 | } 233 | ``` 234 | 235 | ### 無駄な命令の除去 236 | 237 | ```go 238 | // 収束するまで繰り返す 239 | // * 文を使用しているか到達可能定義を用いて見ていく 240 | // * 消しても大丈夫そうな使われていない文を発見したら消す 241 | // 242 | // 最後に要らない宣言を削除 243 | func transformByDeadCodeElimination(f *IRFunctionDefinition, allStatementState map[IRStatement]BlockState) *IRFunctionDefinition { 244 | changed := true 245 | for changed { 246 | changed = false 247 | 248 | used := 
make(map[IRStatement]bool) 249 | markAsUsed := func(s IRStatement, symbol *Symbol) { 250 | for _, definition := range allStatementState[s][symbol] { 251 | used[definition] = true 252 | } 253 | } 254 | 255 | Traverse(f, func(statement IRStatement) IRStatement { 256 | switch s := statement.(type) { 257 | case *IRCompoundStatement: 258 | used[s] = true 259 | 260 | case *IRAssignmentStatement: 261 | if s.Var.IsGlobal() { 262 | used[s] = true 263 | } 264 | 265 | vars := extractVarsFromExpression(s.Expression) 266 | for _, v := range vars { 267 | markAsUsed(s, v) 268 | } 269 | 270 | case *IRReadStatement: 271 | if s.Dest.IsGlobal() { 272 | used[s] = true 273 | } 274 | 275 | markAsUsed(s, s.Src) 276 | 277 | case *IRWriteStatement: 278 | markAsUsed(s, s.Src) 279 | markAsUsed(s, s.Dest) 280 | 281 | case *IRCallStatement: 282 | if s.Dest.IsGlobal() { 283 | used[s] = true 284 | } 285 | 286 | for _, argVar := range s.Vars { 287 | markAsUsed(s, argVar) 288 | } 289 | 290 | case *IRSystemCallStatement: 291 | markAsUsed(s, s.Var) 292 | 293 | case *IRReturnStatement: 294 | markAsUsed(s, s.Var) 295 | 296 | case *IRIfStatement: 297 | markAsUsed(s, s.Var) 298 | } 299 | 300 | return statement 301 | }) 302 | 303 | transformed := Traverse(f, func(statement IRStatement) IRStatement { 304 | switch statement.(type) { 305 | case *IRAssignmentStatement, *IRReadStatement: 306 | if !used[statement] { 307 | changed = true 308 | return nil 309 | } 310 | } 311 | 312 | return statement 313 | }) 314 | 315 | f = transformed.(*IRFunctionDefinition) 316 | } 317 | 318 | return removeUnusedVariableDeclaration(f) 319 | } 320 | 321 | ``` 322 | 323 | ## 例 324 | 簡単な例を用いて最適化を試す。 325 | 326 | ``` c 327 | // demo/optimize_constant.sc 328 | int main() { 329 | int a, b; 330 | int c; 331 | c = 3; 332 | 333 | a = c; // 3 334 | b = a + c; // 3 + 3 335 | print(a + b == 9); // 3 + 6 == 9 336 | } 337 | ``` 338 | 339 | `-optimize=false` オプションをつけて、比較用に最適化しなかった結果を出力する。 340 | ``` c 341 | ❯ ./small-c -optimize=false 
demo/optimize_constant.sc > demo/optimize_constant.s 342 | ❯ ./small-c demo/optimize_constant.sc > demo/optimize_constant_optimized.s 343 | ``` 344 | 345 | ### 最適化前 346 | 347 | ``` sh 348 | ❯ spim -show_stats -f demo/optimize_constant.s 349 | Loaded: /usr/local/share/spim/exceptions.s 350 | 1 351 | --- Summary --- 352 | # of executed instructions 353 | - Total: 47 354 | - Memory: 21 355 | - Others: 26 356 | 357 | --- Details --- 358 | add 2 359 | addi 9 360 | addiu 2 361 | addu 1 362 | beq 1 363 | jal 1 364 | jr 1 365 | lw 12 366 | ori 5 367 | sll 2 368 | sw 9 369 | syscall 2 370 | 371 | ``` 372 | 373 | ### 最適化後 374 | 375 | ```sh 376 | ❯ spim -show_stats -f demo/optimize_constant_optimized.s 377 | Loaded: /usr/local/share/spim/exceptions.s 378 | 1 379 | --- Summary --- 380 | # of executed instructions 381 | - Total: 22 382 | - Memory: 7 383 | - Others: 15 384 | 385 | --- Details --- 386 | addi 3 387 | addiu 2 388 | addu 1 389 | jal 1 390 | jr 1 391 | lw 4 392 | ori 3 393 | sll 2 394 | sw 3 395 | syscall 2 396 | 397 | ``` 398 | 399 | Total: 47 -> 22 400 | 401 | このように定数畳み込みと無駄な命令の除去を組み合わせると大きな効果がある場合がある。 402 | -------------------------------------------------------------------------------- /demo/optimize_constant.sc: -------------------------------------------------------------------------------- 1 | int main() { 2 | int a, b; 3 | int c; 4 | c = 3; 5 | 6 | a = c; 7 | b = a + c; 8 | print(a + b == 9); 9 | } 10 | -------------------------------------------------------------------------------- /env.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "text/scanner" 6 | ) 7 | 8 | type Env struct { 9 | Table map[string]*Symbol 10 | Level int 11 | Children []*Env 12 | Parent *Env 13 | } 14 | 15 | func (env *Env) CreateChild() *Env { 16 | newEnv := &Env{Parent: env, Level: env.Level + 1} 17 | env.Children = append(env.Children, newEnv) 18 | return newEnv 19 | } 20 | 21 | func (env *Env) 
Add(symbol *Symbol) error { 22 | if env.Table == nil { 23 | env.Table = map[string]*Symbol{} 24 | } 25 | 26 | name := symbol.Name 27 | found := env.Table[name] 28 | if found != nil { 29 | if symbol.IsVariable() { 30 | if (found.Kind == "proto" || found.Kind == "fun") && symbol.IsGlobal() { 31 | return fmt.Errorf("function `%v` is already defined", name) 32 | } 33 | } 34 | 35 | if found.Kind != "proto" { 36 | return fmt.Errorf("`%s` is already defined", name) 37 | } 38 | 39 | if found.Kind == "proto" && (symbol.Kind == "fun" || symbol.Kind == "proto") { 40 | functionType, _ := found.Type.(FunctionType) 41 | if symbol.Type.String() != functionType.String() { 42 | return fmt.Errorf("prototype mismatch error: function `%v`: `%v` != `%v`", name, functionType, symbol.Type) 43 | } 44 | } 45 | } 46 | 47 | if symbol.Level == 0 { 48 | symbol.Level = env.Level 49 | } 50 | 51 | env.Table[name] = symbol 52 | return nil 53 | } 54 | 55 | func (env *Env) Register(identifier *IdentifierExpression, symbol *Symbol) error { 56 | symbol.Name = identifier.Name 57 | err := env.Add(symbol) 58 | 59 | if err == nil { 60 | identifier.Symbol = symbol 61 | } 62 | 63 | return err 64 | } 65 | 66 | func (env *Env) Get(name string) *Symbol { 67 | symbol := env.Table[name] 68 | 69 | if symbol != nil { 70 | return symbol 71 | } 72 | 73 | if env.Parent != nil { 74 | return env.Parent.Get(name) 75 | } 76 | 77 | return nil 78 | } 79 | 80 | type Symbol struct { 81 | Name string 82 | Level int 83 | Kind string 84 | Type SymbolType 85 | Offset int 86 | } 87 | 88 | func (symbol *Symbol) IsVariable() bool { 89 | return symbol.Kind == "var" || symbol.Kind == "parm" 90 | } 91 | 92 | func (symbol *Symbol) IsGlobal() bool { 93 | return symbol.Level == 0 94 | } 95 | 96 | func (symbol *Symbol) AddressPointer() string { 97 | if symbol.IsGlobal() { 98 | return "$gp" 99 | } else { 100 | return "$fp" 101 | } 102 | } 103 | 104 | type SemanticError struct { 105 | error 106 | Pos scanner.Position 107 | Err error 108 | } 
109 | 110 | func (e SemanticError) Error() string { 111 | return e.Err.Error() 112 | } 113 | -------------------------------------------------------------------------------- /env_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "testing" 4 | 5 | func TestType(t *testing.T) { 6 | data := [][]string{ 7 | {PointerType{Value: BasicType{Name: "int"}}.String(), "int*"}, 8 | { 9 | ArrayType{Value: BasicType{Name: "int"}, Size: 4}.String(), 10 | "int[4]", 11 | }, 12 | { 13 | FunctionType{ 14 | Return: BasicType{Name: "int"}, 15 | Args: []SymbolType{BasicType{Name: "int"}, BasicType{Name: "int"}}, 16 | }.String(), 17 | "(int, int) -> int", 18 | }, 19 | } 20 | 21 | for _, pair := range data { 22 | if pair[0] != pair[1] { 23 | t.Errorf("expect `%v`, got `%v`", pair[1], pair[0]) 24 | } 25 | } 26 | } 27 | 28 | func TestCreateChild(t *testing.T) { 29 | env := &Env{} 30 | 31 | child := env.CreateChild() 32 | if !(len(env.Children) > 0 && env.Children[0] == child && child.Level == env.Level+1) { 33 | t.Errorf("the return value should be a child: parent: %v, child: %v", env, child) 34 | } 35 | } 36 | 37 | func TestAdd(t *testing.T) { 38 | env := &Env{} 39 | err := env.Add(&Symbol{ 40 | Name: "foo", 41 | Kind: "var", 42 | }) 43 | 44 | if err != nil { 45 | t.Errorf("expect err == nil, but %v", err) 46 | } 47 | 48 | env.Add(&Symbol{ 49 | Name: "bar", 50 | Kind: "var", 51 | }) 52 | 53 | err = env.Add(&Symbol{ 54 | Name: "bar", 55 | Kind: "var", 56 | }) 57 | 58 | if err == nil { 59 | t.Errorf("should return already defined error, but err == nil") 60 | return 61 | } 62 | 63 | { 64 | env := &Env{} 65 | funcType := FunctionType{Return: Int(), Args: []SymbolType{Int()}} 66 | 67 | env.Add(&Symbol{Name: "f", Kind: "proto", Type: funcType}) 68 | err = env.Add(&Symbol{Name: "f", Kind: "proto", Type: funcType}) 69 | 70 | if err != nil { 71 | t.Errorf("kind `proto` can be defined double, but got \"%v\"", err) 72 | } 73 | } 
74 | 75 | { 76 | env := &Env{} 77 | env.Add(&Symbol{ 78 | Name: "f", 79 | Kind: "proto", 80 | Type: FunctionType{Return: Int(), Args: []SymbolType{Int()}}, 81 | }) 82 | 83 | err := env.Add(&Symbol{ 84 | Name: "f", 85 | Kind: "fun", 86 | Type: FunctionType{Return: Void(), Args: []SymbolType{Int()}}, 87 | }) 88 | 89 | if err == nil { 90 | t.Errorf("expect prototype mismatch error, but got nil") 91 | } 92 | } 93 | } 94 | 95 | func TestRegister(t *testing.T) { 96 | env := &Env{} 97 | identifier := &IdentifierExpression{Name: "foo"} 98 | 99 | err := env.Register(identifier, &Symbol{ 100 | Kind: "var", 101 | }) 102 | 103 | if !(err == nil && identifier.Symbol.Name == "foo") { 104 | t.Errorf("expect identifier.Symbol == `foo`, but: %v", identifier.Symbol) 105 | } 106 | } 107 | 108 | func TestGet(t *testing.T) { 109 | parent := &Env{} 110 | parent.Add(&Symbol{Name: "parent"}) 111 | 112 | child := parent.CreateChild() 113 | child.Add(&Symbol{Name: "child"}) 114 | 115 | symbol := parent.Get("parent") 116 | if !(symbol != nil && symbol.Name == "parent") { 117 | t.Errorf("should return `parent` symbol: %v", symbol) 118 | } 119 | 120 | symbol = child.Get("parent") 121 | if !(symbol != nil && symbol.Name == "parent") { 122 | t.Errorf("should return `parent` symbol from parent: %v", symbol) 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /example/array_test.sc: -------------------------------------------------------------------------------- 1 | int main() { 2 | int data[4]; 3 | int i; 4 | 5 | for (i = 0; i < 10; i = i + 1) { 6 | data[i] = i; 7 | } 8 | 9 | for (i = 0; i < 10; i = i + 1) { 10 | print(data[i]); 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /example/bubble_sort.sc: -------------------------------------------------------------------------------- 1 | void bubble_sort(int *p, int size) { 2 | int i, j, tmp; 3 | 4 | for (i = 0; i < size; i = i + 1) { 5 | for (j = 
1; j < size; j = j + 1) { 6 | int current; 7 | int prev; 8 | 9 | current = *(p + j); 10 | prev = *(p + j - 1); 11 | if (current < prev) { 12 | tmp = current; 13 | *(p + j) = prev; 14 | *(p + j - 1) = tmp; 15 | } 16 | } 17 | } 18 | } 19 | 20 | int main() { 21 | int data[8]; 22 | int size; 23 | int i; 24 | 25 | size = 8; 26 | 27 | data[0] = 4; 28 | data[1] = 2; 29 | data[2] = 1; 30 | data[3] = 3; 31 | data[4] = 6; 32 | data[5] = 8; 33 | data[6] = 7; 34 | data[7] = 5; 35 | 36 | bubble_sort(data, size); 37 | 38 | for (i = 0; i < size; i = i + 1) { 39 | print(data[i]); 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /example/emoji.sc: -------------------------------------------------------------------------------- 1 | int main() { 2 | int 🐶, 🐱; 3 | 4 | 🐱 = 0; 5 | for (🐶 = 0; 🐶 < 10; 🐶 = 🐶 + 1) { 6 | 🐱 = 🐱 + 🐶; 7 | } 8 | 9 | print(🐱); 10 | } 11 | -------------------------------------------------------------------------------- /example/factorial.sc: -------------------------------------------------------------------------------- 1 | int fact(int n) { 2 | if (n == 1) { 3 | return 1; 4 | } else { 5 | return n * fact(n - 1); 6 | } 7 | } 8 | 9 | int main() { 10 | print(fact(4)); 11 | } 12 | -------------------------------------------------------------------------------- /example/fib.sc: -------------------------------------------------------------------------------- 1 | int fib(int n) { 2 | if (n <= 1) { 3 | return 1; 4 | } else { 5 | return fib(n - 1) + fib(n - 2); 6 | } 7 | } 8 | 9 | int main() { 10 | print(fib(10)); 11 | } 12 | -------------------------------------------------------------------------------- /example/fizzbuzz.sc: -------------------------------------------------------------------------------- 1 | int mod(int x, int y) { 2 | return x - (x / y) * y; 3 | } 4 | 5 | void puts(int *s) { 6 | while (*s != 0) { 7 | putchar(*s); 8 | s = s + 1; 9 | } 10 | } 11 | 12 | int main() { 13 | int i; 14 | int fizz[5]; 15 | 
int buzz[5]; 16 | 17 | fizz[0] = 'F'; fizz[1] = 'i'; fizz[2] = 'z'; fizz[3] = 'z'; fizz[4] = 0; 18 | buzz[0] = 'B'; buzz[1] = 'u'; buzz[2] = 'z'; buzz[3] = 'z'; buzz[4] = 0; 19 | 20 | for (i = 1; i <= 30; i = i + 1) { 21 | if (mod(i, 3) == 0) { 22 | puts(fizz); 23 | } 24 | 25 | if (mod(i, 5) == 0) { 26 | puts(buzz); 27 | } 28 | 29 | if (mod(i, 3) != 0 && mod(i, 5) != 0) { 30 | print(i); 31 | } 32 | 33 | putchar(' '); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /example/gcd.sc: -------------------------------------------------------------------------------- 1 | // greatest common divisor by euclidean algorithm 2 | int mod(int x, int y) { 3 | return x - (x / y) * y; 4 | } 5 | 6 | int gcd(int x, int y) { 7 | if (y == 0) { 8 | return x; 9 | } else { 10 | return gcd(y, mod(x, y)); 11 | } 12 | } 13 | 14 | int main() { 15 | print(gcd(1071, 1029)); // 21 16 | } 17 | -------------------------------------------------------------------------------- /example/global_var.sc: -------------------------------------------------------------------------------- 1 | int a, data[10]; 2 | 3 | int main() { 4 | int *p; 5 | a = 42; 6 | 7 | data[0] = 1; 8 | data[1] = 2; 9 | data[2] = 3; 10 | 11 | p = &a; 12 | print(*p == 42); 13 | print(data[0] + data[2] == 4); 14 | } 15 | -------------------------------------------------------------------------------- /example/if_test.sc: -------------------------------------------------------------------------------- 1 | int main() { 2 | int i; 3 | int a, b; 4 | 5 | i = 42; 6 | if (i > 42) { 7 | print(1); 8 | } 9 | 10 | if (i >= 43) { 11 | print(2); 12 | } 13 | 14 | if (i < 42) { 15 | print(3); 16 | } 17 | 18 | if (i <= 41) { 19 | print(4); 20 | } 21 | 22 | a = 1; 23 | b = 2; 24 | if (a >= b) { 25 | print(5); 26 | } 27 | 28 | a = 1; 29 | b = 1; 30 | if (a != b) { 31 | print(6); 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /example/many_args.sc: 
-------------------------------------------------------------------------------- 1 | int sum6(int a, int b, int c, int d, int e, int f) { 2 | return a + b + c + d + e + f; 3 | } 4 | 5 | int main() { 6 | print(sum6(1, 1, 1, 1, 1, 1)); 7 | } 8 | -------------------------------------------------------------------------------- /example/optimize_constant.sc: -------------------------------------------------------------------------------- 1 | int main() { 2 | int a, b; 3 | int c; 4 | c = 3; 5 | 6 | a = c; 7 | b = a + c; 8 | print(a + b == 9); 9 | } 10 | -------------------------------------------------------------------------------- /example/pointer_test.sc: -------------------------------------------------------------------------------- 1 | int main() { 2 | int a; 3 | int b; 4 | int *p; 5 | int data[2]; 6 | a = 1; 7 | 8 | p = &a; 9 | *p = 0; 10 | b = *p; 11 | 12 | data[0] = 1; 13 | data[1] = 2; 14 | 15 | print(a == 0 && b == 0 && *(1 + data) == 2); 16 | } 17 | -------------------------------------------------------------------------------- /example/prime.sc: -------------------------------------------------------------------------------- 1 | int prime[10]; 2 | 3 | int mod(int x, int y) { 4 | return x - (x / y) * y; 5 | } 6 | 7 | int main() { 8 | int x, i, j, k, N; 9 | 10 | N = 10; 11 | x = 1; 12 | k = 1; 13 | prime[0] = 2; 14 | 15 | while (k < N) { 16 | int m; 17 | 18 | x = x + 2; 19 | j = 0; 20 | 21 | while (j < k && mod(x, prime[j]) != 0) { 22 | j = j + 1; 23 | } 24 | 25 | if (j == k) { 26 | prime[k] = x; 27 | k = k + 1; 28 | } 29 | } 30 | 31 | for (i = 0; i < N; i = i + 1) { 32 | print(prime[i]); 33 | putchar(' '); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /example/putchar.sc: -------------------------------------------------------------------------------- 1 | int main() { 2 | putchar('h'); 3 | putchar('e'); 4 | putchar('l'); 5 | putchar('l'); 6 | putchar('o'); 7 | putchar(' '); 8 | putchar('w'); 9 | 
putchar('o'); 10 | putchar('r'); 11 | putchar('l'); 12 | putchar('d'); 13 | } 14 | -------------------------------------------------------------------------------- /example/quick_sort.sc: -------------------------------------------------------------------------------- 1 | void swap(int *p, int *q); 2 | 3 | void quick_sort(int *p, int left, int right) { 4 | int i, j, pivot; 5 | 6 | if (left >= right) { 7 | return; 8 | } 9 | 10 | pivot = *(p + left + (right - left) / 2); 11 | 12 | i = left; 13 | j = right; 14 | 15 | while (i < j) { 16 | while (*(p + i) < pivot) { 17 | i = i + 1; 18 | } 19 | 20 | while (pivot < *(p + j)) { 21 | j = j - 1; 22 | } 23 | 24 | if (i < j) { 25 | swap(p + i, p + j); 26 | 27 | i = i + 1; 28 | j = j - 1; 29 | } 30 | } 31 | 32 | quick_sort(p, left, i - 1); 33 | quick_sort(p, j + 1, right); 34 | } 35 | 36 | void swap(int *p, int *q) { 37 | int tmp; 38 | 39 | tmp = *p; 40 | *p = *q; 41 | *q = tmp; 42 | } 43 | 44 | int main() { 45 | int i, size, data[8]; 46 | size = 8; 47 | 48 | for (i = 0; i < size; i = i + 1) { 49 | data[i] = size - i; 50 | } 51 | 52 | swap(data, data + 4); 53 | swap(data + 1, data + 5); 54 | swap(data + 3, data + 7); 55 | // quick_sort's right bound is inclusive (it dereferences p + right), so pass the last valid index 56 | quick_sort(data, 0, size - 1); 57 | 58 | for (i = 0; i < size; i = i + 1) { 59 | print(data[i]); 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /example/sum.sc: -------------------------------------------------------------------------------- 1 | int sum(int a, int b) { 2 | return a + b; 3 | } 4 | 5 | int main() { 6 | print(sum(100, 20) == 120); 7 | } 8 | -------------------------------------------------------------------------------- /example/sum_for.sc: -------------------------------------------------------------------------------- 1 | int main() { 2 | int sum; 3 | int i; 4 | 5 | sum = 0; 6 | 7 | for (i = 0; i < 10; i = i + 1) { 8 | sum = sum + i; 9 | } 10 | 11 | print(sum); 12 | } 13 | 
-------------------------------------------------------------------------------- /ir.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | // intermediate representation 11 | type IRProgram struct { 12 | Declarations []*IRVariableDeclaration 13 | Functions []*IRFunctionDefinition 14 | } 15 | 16 | type traverseAction (func(statement IRStatement) IRStatement) 17 | 18 | func Traverse(statement IRStatement, action traverseAction) IRStatement { 19 | switch s := statement.(type) { 20 | case *IRFunctionDefinition: 21 | body := Traverse(s.Body, action) 22 | if body != nil { 23 | s.Body = body 24 | return s 25 | } 26 | 27 | return nil 28 | 29 | case *IRCompoundStatement: 30 | statements := []IRStatement{} 31 | for _, statement := range s.Statements { 32 | transformed := Traverse(statement, action) 33 | if transformed != nil { 34 | statements = append(statements, transformed) 35 | } 36 | } 37 | 38 | s.Statements = statements 39 | 40 | return action(s) 41 | 42 | default: 43 | return action(statement) 44 | } 45 | } 46 | 47 | func (s *IRProgram) String() string { 48 | var declStrs []string 49 | for _, decl := range s.Declarations { 50 | declStrs = append(declStrs, decl.String()) 51 | } 52 | 53 | var stmtStrs []string 54 | for _, statement := range s.Functions { 55 | stmtStrs = append(stmtStrs, statement.String()) 56 | } 57 | 58 | return strings.Join(declStrs, "\n") + "\n\n" + strings.Join(stmtStrs, "\n\n") 59 | } 60 | 61 | type IRStatement interface { 62 | String() string 63 | } 64 | type IRExpression interface { 65 | String() string 66 | } 67 | 68 | type IRVariableDeclaration struct { 69 | Var *Symbol 70 | } 71 | 72 | func (s *IRVariableDeclaration) String() string { 73 | return fmt.Sprintf("%v %v", s.Var.Type, s.Var.Name) 74 | } 75 | 76 | type IRFunctionDefinition struct { 77 | Var *Symbol 78 | Parameters []*IRVariableDeclaration 79 | Body 
IRStatement 80 | VarSize int 81 | } 82 | 83 | func (s *IRFunctionDefinition) String() string { 84 | var params []string 85 | for _, p := range s.Parameters { 86 | params = append(params, p.String()) 87 | } 88 | 89 | return fmt.Sprintf("%v(%v)\n%v", s.Var.Name, strings.Join(params, ", "), s.Body) 90 | } 91 | 92 | type IRAssignmentStatement struct { 93 | Var *Symbol 94 | Expression IRExpression 95 | } 96 | 97 | func (s *IRAssignmentStatement) String() string { 98 | return fmt.Sprintf("%v = %v", s.Var.Name, s.Expression) 99 | } 100 | 101 | type IRWriteStatement struct { 102 | Dest *Symbol 103 | Src *Symbol 104 | } 105 | 106 | func (s *IRWriteStatement) String() string { 107 | return fmt.Sprintf("*%v = %v", s.Dest.Name, s.Src.Name) 108 | } 109 | 110 | type IRReadStatement struct { 111 | Dest *Symbol 112 | Src *Symbol 113 | } 114 | 115 | func (s *IRReadStatement) String() string { 116 | return fmt.Sprintf("%v = *%v", s.Dest.Name, s.Src.Name) 117 | } 118 | 119 | type IRLabelStatement struct { 120 | Name string 121 | } 122 | 123 | func (s *IRLabelStatement) String() string { 124 | return fmt.Sprintf("%s:", s.Name) 125 | } 126 | 127 | type IRIfStatement struct { 128 | Var *Symbol 129 | TrueLabel string 130 | FalseLabel string 131 | } 132 | 133 | func (s *IRIfStatement) String() string { 134 | if len(s.FalseLabel) == 0 { 135 | return fmt.Sprintf("if (%s) %s", s.Var.Name, s.TrueLabel) 136 | } 137 | 138 | return fmt.Sprintf("if (%s) %s else %s", s.Var.Name, s.TrueLabel, s.FalseLabel) 139 | } 140 | 141 | type IRGotoStatement struct { 142 | Label string 143 | } 144 | 145 | func (s *IRGotoStatement) String() string { 146 | return fmt.Sprintf("goto %s", s.Label) 147 | } 148 | 149 | type IRCallStatement struct { 150 | Dest *Symbol 151 | Func *Symbol 152 | Vars []*Symbol 153 | } 154 | 155 | func (s *IRCallStatement) String() string { 156 | var args []string 157 | for _, symbol := range s.Vars { 158 | args = append(args, symbol.Name) 159 | } 160 | 161 | return fmt.Sprintf("%s = 
%s(%s)", s.Dest.Name, s.Func.Name, strings.Join(args, ", ")) 162 | } 163 | type IRReturnStatement struct { 164 | Var *Symbol 165 | } 166 | func (s *IRReturnStatement) String() string { 167 | if s.Var == nil { // a value-less `return;` has no Var (the code generator also checks for this) 168 | return "return" 169 | } 170 | return fmt.Sprintf("return %s", s.Var.Name) 171 | } 172 | type IRSystemCallStatement struct { 173 | Name string 174 | Var *Symbol 175 | } 176 | 177 | func (s *IRSystemCallStatement) String() string { 178 | return fmt.Sprintf("%v(%s)", s.Name, s.Var.Name) 179 | } 180 | 181 | type IRCompoundStatement struct { 182 | Declarations []*IRVariableDeclaration 183 | Statements []IRStatement 184 | } 185 | 186 | func (s *IRCompoundStatement) String() string { 187 | var declStrs []string 188 | for _, decl := range s.Declarations { 189 | declStrs = append(declStrs, decl.String()) 190 | } 191 | 192 | var stmtStrs []string 193 | for _, statement := range s.Statements { 194 | stmtStrs = append(stmtStrs, statement.String()) 195 | } 196 | 197 | str := "" 198 | if len(declStrs) > 0 { 199 | str += strings.Join(declStrs, "\n") + "\n" 200 | } 201 | str += strings.Join(stmtStrs, "\n") 202 | 203 | return str 204 | } 205 | 206 | // IRExpression 207 | 208 | type IRVariableExpression struct { 209 | Var *Symbol 210 | } 211 | 212 | func (e *IRVariableExpression) String() string { 213 | return e.Var.Name 214 | } 215 | 216 | type IRNumberExpression struct { 217 | Value int 218 | } 219 | 220 | func (e *IRNumberExpression) String() string { 221 | return strconv.Itoa(e.Value) 222 | } 223 | 224 | type IRBinaryExpression struct { 225 | Operator string 226 | Left IRExpression 227 | Right IRExpression 228 | } 229 | 230 | func (e *IRBinaryExpression) String() string { 231 | return fmt.Sprintf("(%s %v %v)", e.Operator, e.Left, e.Right) 232 | } 233 | 234 | type IRAddressExpression struct { 235 | Var *Symbol 236 | } 237 | 238 | func (e *IRAddressExpression) String() string { 239 | return fmt.Sprintf("&%v", e.Var.Name) 240 | } 241 | 242 | var counter = map[string]int{} 243 | 244 | func label(name 
string) string { 245 | labelName := fmt.Sprintf("%s_%d", name, counter[name]) 246 | counter[name]++ 247 | 248 | return labelName 249 | } 250 | 251 | func tmpvar() *Symbol { 252 | return &Symbol{ 253 | Name: label("#tmp"), 254 | Level: 2, // not global 255 | Type: Int(), 256 | } 257 | } 258 | 259 | // CompileIR convert Statements to intermediate representation 260 | func CompileIR(statements []Statement) *IRProgram { 261 | var decls []*IRVariableDeclaration 262 | var funcs []*IRFunctionDefinition 263 | 264 | var irStatements []IRStatement 265 | for _, statement := range statements { 266 | switch s := statement.(type) { 267 | case *Declaration: 268 | symbols := findSymbolsFromDeclaration(s) 269 | decls = append(decls, IRVariableDeclarations(symbols)...) 270 | default: 271 | irStatements = append(irStatements, compileIRStatement(s)) 272 | } 273 | } 274 | 275 | for _, statement := range irStatements { 276 | switch s := statement.(type) { 277 | case *IRFunctionDefinition: 278 | funcs = append(funcs, s) 279 | case *IRVariableDeclaration: 280 | decls = append(decls, s) 281 | } 282 | } 283 | 284 | return &IRProgram{ 285 | Declarations: decls, 286 | Functions: funcs, 287 | } 288 | } 289 | 290 | func compileIRStatement(statement Statement) IRStatement { 291 | if statement == nil { 292 | return nil 293 | } 294 | 295 | switch s := statement.(type) { 296 | case *FunctionDefinition: 297 | if s.Statement == nil { 298 | return nil 299 | } 300 | 301 | identifier := findIdentifierExpression(s.Identifier) 302 | 303 | var paramSymbols []*Symbol 304 | for _, p := range s.Parameters { 305 | parameter, ok := p.(*ParameterDeclaration) 306 | if ok { 307 | identifier := findIdentifierExpression(parameter.Identifier) 308 | paramSymbols = append(paramSymbols, identifier.Symbol) 309 | } 310 | } 311 | 312 | return &IRFunctionDefinition{ 313 | Var: identifier.Symbol, 314 | Parameters: IRVariableDeclarations(paramSymbols), 315 | Body: compileIRStatement(s.Statement), 316 | } 317 | 318 | case 
*CompoundStatement: 319 | var symbols []*Symbol 320 | for _, d := range s.Declarations { 321 | declaration, ok := d.(*Declaration) 322 | if ok { 323 | symbols = append(symbols, findSymbolsFromDeclaration(declaration)...) 324 | } 325 | } 326 | 327 | var statements []IRStatement 328 | for _, statement := range s.Statements { 329 | statements = append(statements, compileIRStatement(statement)) 330 | } 331 | 332 | return &IRCompoundStatement{ 333 | Declarations: IRVariableDeclarations(symbols), 334 | Statements: statements, 335 | } 336 | 337 | case *ExpressionStatement: 338 | if s.Value == nil { 339 | return nil 340 | } 341 | 342 | switch e := s.Value.(type) { 343 | case *ExpressionList: 344 | var statements []IRStatement 345 | for _, value := range e.Values { 346 | statements = append(statements, compileIRStatement(&ExpressionStatement{Value: value})) 347 | } 348 | 349 | return &IRCompoundStatement{ 350 | Statements: statements, 351 | } 352 | 353 | case *FunctionCallExpression: 354 | name := findIdentifierExpression(e.Identifier).Name 355 | 356 | if isSystemCall(name) { 357 | tmp := tmpvar() 358 | arg, decls, beforeArg := compileIRExpression(e.Argument) 359 | 360 | return &IRCompoundStatement{ 361 | Declarations: append(decls, &IRVariableDeclaration{Var: tmp}), 362 | Statements: append(beforeArg, 363 | &IRAssignmentStatement{ 364 | Var: tmp, 365 | Expression: arg, 366 | }, 367 | &IRSystemCallStatement{ 368 | Name: name, 369 | Var: tmp, 370 | }, 371 | ), 372 | } 373 | } 374 | } 375 | 376 | _, decls, beforeValue := compileIRExpression(s.Value) 377 | 378 | return &IRCompoundStatement{ 379 | Declarations: decls, 380 | Statements: beforeValue, 381 | } 382 | 383 | case *IfStatement: 384 | conditionVar := tmpvar() 385 | 386 | trueLabel := label("true") 387 | falseLabel := label("false") 388 | endLabel := label("end") 389 | 390 | condition, decls, beforeCondition := compileIRExpression(s.Condition) 391 | 392 | statements := []IRStatement{ 393 | &IRAssignmentStatement{ 394 | 
Var: conditionVar, 395 | Expression: condition, 396 | }, 397 | &IRIfStatement{ 398 | Var: conditionVar, 399 | TrueLabel: trueLabel, 400 | FalseLabel: falseLabel, 401 | }, 402 | &IRLabelStatement{Name: trueLabel}, 403 | compileIRStatement(s.TrueStatement), 404 | &IRGotoStatement{Label: endLabel}, 405 | &IRLabelStatement{Name: falseLabel}, 406 | } 407 | 408 | if s.FalseStatement != nil { 409 | statements = append(statements, compileIRStatement(s.FalseStatement)) 410 | } 411 | 412 | statements = append(statements, &IRLabelStatement{Name: endLabel}) 413 | 414 | return &IRCompoundStatement{ 415 | Declarations: append(IRVariableDeclarations([]*Symbol{conditionVar}), decls...), 416 | Statements: append(beforeCondition, statements...), 417 | } 418 | 419 | case *WhileStatement: 420 | conditionVar := tmpvar() 421 | 422 | beginLabel := label("while_begin") 423 | endLabel := label("while_end") 424 | 425 | condition, decls, beforeCondition := compileIRExpression(s.Condition) 426 | statements := append([]IRStatement{&IRLabelStatement{Name: beginLabel}}, beforeCondition...) 
427 | 428 | statements = append(statements, 429 | &IRAssignmentStatement{ 430 | Var: conditionVar, 431 | Expression: condition, 432 | }, 433 | &IRIfStatement{ 434 | Var: conditionVar, 435 | FalseLabel: endLabel, 436 | }, 437 | compileIRStatement(s.Statement), 438 | &IRGotoStatement{Label: beginLabel}, 439 | &IRLabelStatement{Name: endLabel}, 440 | ) 441 | 442 | return &IRCompoundStatement{ 443 | Declarations: append(IRVariableDeclarations([]*Symbol{conditionVar}), decls...), 444 | Statements: statements, 445 | } 446 | 447 | case *ReturnStatement: 448 | // return exp; 449 | // 450 | // tmp = 451 | // return tmp 452 | 453 | if s.Value == nil { 454 | return &IRReturnStatement{} 455 | } 456 | 457 | tmp := tmpvar() 458 | 459 | value, decls, beforeValue := compileIRExpression(s.Value) 460 | return &IRCompoundStatement{ 461 | Declarations: append(IRVariableDeclarations([]*Symbol{tmp}), decls...), 462 | Statements: append(beforeValue, 463 | &IRAssignmentStatement{Var: tmp, Expression: value}, 464 | &IRReturnStatement{Var: tmp}, 465 | ), 466 | } 467 | 468 | default: 469 | panic("unexpected statement") 470 | } 471 | } 472 | 473 | func IRVariableDeclarations(symbols []*Symbol) []*IRVariableDeclaration { 474 | var declarations []*IRVariableDeclaration 475 | for _, symbol := range symbols { 476 | declarations = append(declarations, &IRVariableDeclaration{ 477 | Var: symbol, 478 | }) 479 | } 480 | 481 | return declarations 482 | } 483 | 484 | func findSymbolsFromDeclaration(declaration *Declaration) []*Symbol { 485 | var symbols []*Symbol 486 | for _, declarator := range declaration.Declarators { 487 | identifier := findIdentifierExpression(declarator.Identifier) 488 | symbols = append(symbols, identifier.Symbol) 489 | } 490 | 491 | return symbols 492 | } 493 | 494 | func compileIRExpression(expression Expression) (IRExpression, []*IRVariableDeclaration, []IRStatement) { 495 | switch e := expression.(type) { 496 | case *NumberExpression: 497 | value, _ := strconv.Atoi(e.Value) 
498 | return &IRNumberExpression{ 499 | Value: value, 500 | }, nil, nil 501 | 502 | case *IdentifierExpression: 503 | return &IRVariableExpression{ 504 | Var: e.Symbol, 505 | }, nil, nil 506 | 507 | case *UnaryExpression: 508 | if e.Operator == "*" { 509 | result := tmpvar() 510 | tmp := tmpvar() 511 | irValue, decls, beforeValue := compileIRExpression(e.Value) 512 | 513 | statements := []IRStatement{ 514 | &IRAssignmentStatement{ 515 | Var: tmp, 516 | Expression: irValue, 517 | }, 518 | &IRReadStatement{Dest: result, Src: tmp}, 519 | } 520 | 521 | decls = append(IRVariableDeclarations([]*Symbol{result, tmp}), decls...) 522 | statements = append(beforeValue, statements...) 523 | 524 | return &IRVariableExpression{ 525 | Var: result, 526 | }, decls, statements 527 | } 528 | 529 | if e.Operator == "&" { 530 | value, decls, statements := compileIRExpression(e.Value) 531 | v, _ := value.(*IRVariableExpression) 532 | 533 | return &IRAddressExpression{ 534 | Var: v.Var, 535 | }, decls, statements 536 | } 537 | 538 | case *BinaryExpression: 539 | // return (a || b) && c 540 | // v; 541 | // if (a) { 542 | // v = 1 543 | // } else if (b) { 544 | // v = 1; 545 | // } else { 546 | // v = 0; 547 | // } 548 | // int v; 549 | // if (a) { 550 | // if (b) { 551 | // v = 1; 552 | // } else { 553 | // v = 0; 554 | // } 555 | // } else { 556 | // v = 0; 557 | // } 558 | 559 | if e.IsAssignment() { 560 | // a = (b = c); 561 | // *(p + 2) = 4 562 | switch left := e.Left.(type) { 563 | case *UnaryExpression: 564 | // address = left 565 | // tmpRight = exp 566 | // *address = tmpRight 567 | if left.Operator == "*" { 568 | address := tmpvar() 569 | tmpRight := tmpvar() 570 | 571 | right, rightDecls, beforeRight := compileIRExpression(e.Right) 572 | leftExpression, leftDecls, beforeLeft := compileIRExpression(left.Value) 573 | 574 | decls := append(rightDecls, leftDecls...) 575 | decls = append(IRVariableDeclarations([]*Symbol{address, tmpRight}), decls...) 
576 | 577 | statements := []IRStatement{ 578 | &IRAssignmentStatement{ 579 | Var: address, 580 | Expression: leftExpression, 581 | }, 582 | &IRAssignmentStatement{ 583 | Var: tmpRight, 584 | Expression: right, 585 | }, 586 | &IRWriteStatement{Dest: address, Src: tmpRight}, 587 | } 588 | statements = append(append(beforeLeft, beforeRight...), statements...) 589 | 590 | return &IRVariableExpression{tmpRight}, decls, statements 591 | } 592 | 593 | default: 594 | tmp := tmpvar() 595 | decls := IRVariableDeclarations([]*Symbol{tmp}) 596 | 597 | symbol := findIdentifierExpression(e.Left).Symbol 598 | right, rightDecls, beforeRight := compileIRExpression(e.Right) 599 | 600 | decls = append(decls, rightDecls...) 601 | statements := append(beforeRight, 602 | &IRAssignmentStatement{ 603 | Var: symbol, 604 | Expression: right, 605 | }, 606 | ) 607 | 608 | return &IRVariableExpression{Var: symbol}, decls, statements 609 | } 610 | } 611 | 612 | switch e.Operator { 613 | case "&&": 614 | tmp := tmpvar() 615 | 616 | decls := IRVariableDeclarations([]*Symbol{tmp}) 617 | statements := []IRStatement{ 618 | compileIRStatement(&IfStatement{ 619 | Condition: e.Left, 620 | TrueStatement: &IfStatement{ 621 | Condition: e.Right, 622 | TrueStatement: assignStatementBySymbol(tmp, 1), 623 | FalseStatement: assignStatementBySymbol(tmp, 0), 624 | }, 625 | FalseStatement: assignStatementBySymbol(tmp, 0), 626 | }), 627 | } 628 | 629 | return &IRVariableExpression{Var: tmp}, decls, statements 630 | 631 | case "||": 632 | tmp := tmpvar() 633 | 634 | decls := IRVariableDeclarations([]*Symbol{tmp}) 635 | statements := []IRStatement{ 636 | compileIRStatement(&IfStatement{ 637 | Condition: e.Left, 638 | TrueStatement: assignStatementBySymbol(tmp, 1), 639 | FalseStatement: &IfStatement{ 640 | Condition: e.Right, 641 | TrueStatement: assignStatementBySymbol(tmp, 1), 642 | FalseStatement: assignStatementBySymbol(tmp, 0), 643 | }, 644 | }), 645 | } 646 | 647 | return &IRVariableExpression{Var: tmp}, 
decls, statements 648 | } 649 | 650 | left, leftDecls, beforeLeft := compileIRExpression(e.Left) 651 | right, rightDecls, beforeRight := compileIRExpression(e.Right) 652 | 653 | t, _ := typeOfExpression(e) 654 | switch t.(type) { 655 | case PointerType: 656 | leftType, _ := typeOfExpression(e.Left) 657 | 658 | if _, isInt := leftType.(BasicType); isInt { 659 | // 4 * r + l 660 | left = &IRBinaryExpression{ 661 | Operator: "*", 662 | Left: &IRNumberExpression{Value: 4}, // int -> 4 bytes 663 | Right: left, 664 | } 665 | } else { 666 | // l + 4 * r 667 | right = &IRBinaryExpression{ 668 | Operator: "*", 669 | Left: &IRNumberExpression{Value: 4}, // int -> 4 bytes 670 | Right: right, 671 | } 672 | } 673 | } 674 | 675 | return &IRBinaryExpression{ 676 | Operator: e.Operator, 677 | Left: left, 678 | Right: right, 679 | }, append(leftDecls, rightDecls...), append(beforeLeft, beforeRight...) 680 | 681 | case *FunctionCallExpression: 682 | funcIdentifier := findIdentifierExpression(e.Identifier) 683 | 684 | var args []Expression 685 | switch arg := e.Argument.(type) { 686 | case *ExpressionList: 687 | args = arg.Values 688 | default: 689 | if arg != nil { 690 | args = []Expression{arg} 691 | } 692 | } 693 | 694 | var argSymbols []*Symbol 695 | var statements []IRStatement 696 | var decls []*IRVariableDeclaration 697 | 698 | for _, arg := range args { 699 | tmp := tmpvar() 700 | argSymbols = append(argSymbols, tmp) 701 | 702 | expression, expressionDecls, beforeExpression := compileIRExpression(arg) 703 | 704 | decls = append(decls, expressionDecls...) 705 | 706 | statements = append(statements, beforeExpression...) 707 | statements = append(statements, &IRAssignmentStatement{ 708 | Var: tmp, 709 | Expression: expression, 710 | }) 711 | } 712 | 713 | result := tmpvar() 714 | 715 | // result = f(a0, a1, ...) 
716 | statements = append(statements, &IRCallStatement{ 717 | Dest: result, 718 | Func: funcIdentifier.Symbol, 719 | Vars: argSymbols, 720 | }) 721 | 722 | decls = append(decls, IRVariableDeclarations(append(argSymbols, result))...) 723 | return &IRVariableExpression{ 724 | Var: result, 725 | }, decls, statements 726 | 727 | } 728 | 729 | panic(fmt.Sprintf("unexpected expression: `%v`", reflect.TypeOf(expression))) 730 | } 731 | 732 | func assignStatementBySymbol(symbol *Symbol, value int) *ExpressionStatement { 733 | return &ExpressionStatement{ 734 | Value: &BinaryExpression{ 735 | Operator: "=", 736 | Left: &IdentifierExpression{Symbol: symbol}, 737 | Right: &NumberExpression{Value: strconv.Itoa(value)}, 738 | }, 739 | } 740 | } 741 | 742 | func isSystemCall(name string) bool { 743 | switch name { 744 | case "print", "putchar": 745 | return true 746 | } 747 | 748 | return false 749 | } 750 | -------------------------------------------------------------------------------- /ir_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestCompileIR(t *testing.T) { 8 | { 9 | statements := ast(` 10 | int main() { 11 | int a; 12 | a = 1 + 2; 13 | } 14 | `) 15 | 16 | ir := CompileIR(statements) 17 | 18 | if len(ir.Functions) != 1 { 19 | t.Errorf("expect len(functions) == 1, got %v", len(ir.Functions)) 20 | } 21 | } 22 | 23 | { 24 | statements := ast(` 25 | int main() { 26 | int a; 27 | 28 | if (a > 0) { 29 | a = 1; 30 | } else { 31 | a = 2; 32 | } 33 | } 34 | `) 35 | 36 | CompileIR(statements) 37 | } 38 | 39 | { 40 | statements := ast(` 41 | int main() { 42 | int a; 43 | while (a > 0) { 44 | a = a - 1; 45 | } 46 | } 47 | `) 48 | 49 | CompileIR(statements) 50 | } 51 | 52 | { 53 | statements := ast(` 54 | int main() { 55 | int a; 56 | int *p; 57 | p = &a; 58 | } 59 | `) 60 | 61 | CompileIR(statements) 62 | } 63 | 64 | { 65 | statements := ast(` 66 | int main() { 67 | int 
a[10]; 68 | int *p; 69 | p = a; 70 | *(p + 1) = 1; 71 | } 72 | `) 73 | 74 | CompileIR(statements) 75 | } 76 | } 77 | 78 | func TestCompileIRStatement(t *testing.T) { 79 | // int a; 80 | // int *p; 81 | symbolP := &Symbol{Name: "p", Type: Pointer(Int())} 82 | symbolA := &Symbol{Name: "a", Type: Int()} 83 | 84 | { 85 | // *p = a; 86 | // 87 | // tmp = a 88 | // *p = tmp 89 | s := &ExpressionStatement{ 90 | Value: &BinaryExpression{ 91 | Operator: "=", 92 | Left: &UnaryExpression{ 93 | Operator: "*", 94 | Value: &IdentifierExpression{Symbol: symbolP}, 95 | }, 96 | Right: &IdentifierExpression{Symbol: symbolA}, 97 | }, 98 | } 99 | 100 | ir := compileIRStatement(s) 101 | compoundStatement, ok := ir.(*IRCompoundStatement) 102 | if !ok { 103 | t.Errorf("expect *IRCompoundStatement, but got %v", ir) 104 | return 105 | } 106 | 107 | _, ok = compoundStatement.Statements[2].(*IRWriteStatement) 108 | if !ok { 109 | t.Errorf("expect WriteStatement %v", ir) 110 | } 111 | } 112 | 113 | { 114 | // a = *p; 115 | s := &ExpressionStatement{ 116 | Value: &BinaryExpression{ 117 | Operator: "=", 118 | Left: &IdentifierExpression{Symbol: symbolA}, 119 | Right: &UnaryExpression{ 120 | Operator: "*", 121 | Value: &IdentifierExpression{Symbol: symbolP}, 122 | }, 123 | }, 124 | } 125 | 126 | ir := compileIRStatement(s) 127 | compoundStatement, ok := ir.(*IRCompoundStatement) 128 | if !ok { 129 | t.Errorf("expect *IRCompoundStatement, but got %v", ir) 130 | return 131 | } 132 | 133 | if len(compoundStatement.Statements) == 0 { 134 | t.Error(compoundStatement) 135 | } 136 | } 137 | } 138 | 139 | func TestCompileIRExpression(t *testing.T) { 140 | // 0 || 1 141 | e := &BinaryExpression{ 142 | Operator: "||", 143 | Left: &NumberExpression{Value: "0"}, 144 | Right: &NumberExpression{Value: "1"}, 145 | } 146 | 147 | _, decls, before := compileIRExpression(e) 148 | if len(before) == 0 || len(decls) == 0 { 149 | t.Errorf("expect decls and statements, got %v %v", before, decls) 150 | } 151 | } 152 | 
-------------------------------------------------------------------------------- /lexer.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "regexp" 6 | "strings" 7 | "text/scanner" 8 | ) 9 | 10 | type Lexer struct { 11 | scanner scanner.Scanner 12 | result []Statement 13 | token Token 14 | pos scanner.Position 15 | errMessage string 16 | } 17 | 18 | func (l *Lexer) Init(code string) { 19 | l.scanner.Init(strings.NewReader(code)) 20 | } 21 | 22 | var keywords = map[string]int{ 23 | "int": TYPE, 24 | "void": TYPE, 25 | "if": IF, 26 | "else": ELSE, 27 | "return": RETURN, 28 | "while": WHILE, 29 | "for": FOR, 30 | } 31 | 32 | func (l *Lexer) Lex(lval *yySymType) int { 33 | tok := l.scanner.Scan() 34 | 35 | if tok == scanner.EOF { 36 | return -1 37 | } 38 | 39 | lit := l.scanner.TokenText() 40 | pos := l.scanner.Pos() 41 | 42 | lval.token = Token{lit: lit, pos: pos} 43 | l.token = lval.token 44 | 45 | if regexp.MustCompile(`^(0|[1-9][0-9]*)$`).MatchString(lit) { 46 | return NUMBER 47 | } 48 | 49 | if keywords[lit] != 0 { 50 | return keywords[lit] 51 | } 52 | 53 | two := fmt.Sprintf("%c%c", tok, l.scanner.Peek()) 54 | operators := map[string]int{ 55 | "==": EQL, 56 | "!=": NEQ, 57 | "<=": LEQ, 58 | ">=": GEQ, 59 | "&&": LOGICAL_AND, 60 | "||": LOGICAL_OR, 61 | } 62 | 63 | if operators[two] != 0 { 64 | l.scanner.Next() 65 | lval.token = Token{lit: two, pos: pos} 66 | l.token = lval.token 67 | return operators[two] 68 | } 69 | 70 | if regexp.MustCompile(`^'.*'$`).MatchString(lit) { 71 | return CHAR 72 | } 73 | 74 | switch lit { 75 | case "(", ")", "{", "}", "&", ";", ",", "[", "]", "+", "-", "*", "/", "<", ">", "=": 76 | return int(tok) 77 | 78 | default: 79 | return IDENT 80 | } 81 | } 82 | 83 | func (l *Lexer) Error(e string) { 84 | l.pos = l.token.pos 85 | l.errMessage = e 86 | } 87 | -------------------------------------------------------------------------------- /lexer_test.go: 
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestLex(t *testing.T) { 8 | testLex(t, `42 7 0`, []int{NUMBER, NUMBER, NUMBER}) 9 | testLex(t, `a == 100`, []int{IDENT, EQL, NUMBER}) 10 | } 11 | 12 | func testLex(t *testing.T, code string, tokens []int) { 13 | l := new(Lexer) 14 | l.Init(`a == 100`) 15 | 16 | tokenTypes := []int{IDENT, EQL, NUMBER} 17 | result := []int{} 18 | 19 | var sym yySymType 20 | for { 21 | tokenNumber := l.Lex(&sym) 22 | if tokenNumber == -1 { 23 | break 24 | } 25 | result = append(result, tokenNumber) 26 | } 27 | 28 | if len(result) != len(tokenTypes) { 29 | t.Errorf("expect %v tokens, got %v: %v", len(tokenTypes), len(result), result) 30 | } 31 | 32 | for i, resultType := range result { 33 | if resultType != tokenTypes[i] { 34 | t.Errorf("%v: expect %v, got %v", i, tokenTypes[i], resultType) 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | 7 | "io/ioutil" 8 | "os" 9 | 10 | "github.com/k0kubun/pp" 11 | ) 12 | 13 | func main() { 14 | optimize := flag.Bool("optimize", true, "Enable optimization") 15 | flag.Parse() 16 | 17 | var src string 18 | 19 | if len(os.Args) > 1 { 20 | filename := os.Args[len(os.Args)-1] 21 | data, err := ioutil.ReadFile(filename) 22 | if err != nil { 23 | fmt.Println(err) 24 | } 25 | 26 | src = string(data) 27 | } else { 28 | data, _ := ioutil.ReadAll(os.Stdin) 29 | src = string(data) 30 | } 31 | 32 | code, errs := CompileSource(src, *optimize) 33 | if len(errs) > 0 { 34 | Exit(src, errs) 35 | } 36 | fmt.Println(code) 37 | } 38 | 39 | func CompileSource(src string, optimize bool) (string, []error) { 40 | debug := len(os.Getenv("DEBUG")) > 0 41 | 42 | statements, err := Parse(src) 43 | if err != nil { 44 | return "", 
[]error{err} 45 | } 46 | 47 | for i, statement := range statements { 48 | statements[i] = Walk(statement) 49 | } 50 | 51 | if debug { 52 | pp.Println(statements) 53 | } 54 | 55 | prelude, _ := Parse(` 56 | void print(int i); 57 | void putchar(int ch); 58 | `) 59 | statements = append(prelude, statements...) 60 | 61 | env := &Env{} 62 | errs := Analyze(statements, env) 63 | if len(errs) > 0 { 64 | return "", errs 65 | } 66 | 67 | err = CheckType(statements) 68 | if err != nil { 69 | return "", []error{err} 70 | } 71 | 72 | irProgram := CompileIR(statements) 73 | 74 | if optimize { 75 | irProgram = Optimize(irProgram) 76 | } 77 | 78 | code := Compile(irProgram) 79 | 80 | if debug { 81 | fmt.Println(irProgram) 82 | } 83 | 84 | return code, nil 85 | } 86 | 87 | func Exit(src string, errs []error) { 88 | for _, err := range errs { 89 | switch e := err.(type) { 90 | case SemanticError: 91 | err = fmt.Errorf("%d:%d: %v", e.Pos.Line, e.Pos.Column, e.Err) 92 | 93 | default: 94 | } 95 | 96 | fmt.Fprintln(os.Stderr, err) 97 | } 98 | 99 | os.Exit(1) 100 | } 101 | -------------------------------------------------------------------------------- /main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io/ioutil" 5 | "os/exec" 6 | "path/filepath" 7 | "regexp" 8 | "strings" 9 | "testing" 10 | ) 11 | 12 | func TestSimulateExample(t *testing.T) { 13 | examples := [](struct { 14 | Filename string 15 | Output string 16 | }){ 17 | {"example/sum.sc", "1"}, 18 | {"example/sum_for.sc", "45"}, 19 | {"example/many_args.sc", "6"}, 20 | {"example/factorial.sc", "24"}, 21 | {"example/fib.sc", "89"}, 22 | {"example/global_var.sc", "11"}, 23 | {"example/if_test.sc", ""}, 24 | {"example/pointer_test.sc", "1"}, 25 | {"example/optimize_constant.sc", "1"}, 26 | {"example/bubble_sort.sc", "12345678"}, 27 | {"example/quick_sort.sc", "12345678"}, 28 | {"example/putchar.sc", "hello world"}, 29 | {"example/gcd.sc", "21"}, 30 | 
{"example/prime.sc", "2 3 5 7 11 13 17 19 23 29 "}, 31 | {"example/emoji.sc", "45"}, 32 | {"example/fizzbuzz.sc", "1 2 Fizz 4 Buzz Fizz 7 8 Fizz Buzz 11 Fizz 13 14 FizzBuzz 16 17 Fizz 19 Buzz Fizz 22 23 Fizz Buzz 26 Fizz 28 29 FizzBuzz "}, 33 | } 34 | 35 | for _, example := range examples { 36 | sourceFilename := example.Filename 37 | filename := regexp.MustCompile("\\.sc$").ReplaceAllString(sourceFilename, ".s") 38 | 39 | { 40 | err := compileAndSave(sourceFilename) 41 | 42 | if err != nil { 43 | t.Errorf("%v: %v", sourceFilename, err) 44 | continue 45 | } 46 | } 47 | 48 | output, err := runSpim(filename) 49 | if err != nil { 50 | t.Error(err) 51 | continue 52 | } 53 | 54 | expected := example.Output 55 | 56 | if output != expected { 57 | t.Errorf("`%v`: expect `%v`, got `%v`", filename, expected, output) 58 | } 59 | } 60 | } 61 | 62 | func TestSampleOk(t *testing.T) { 63 | sampleFiles, _ := filepath.Glob("sample/ok*.sc") 64 | for _, sampleFile := range sampleFiles { 65 | testOk(t, sampleFile) 66 | } 67 | } 68 | 69 | func TestSampleNg(t *testing.T) { 70 | sampleFiles, _ := filepath.Glob("sample/ng*.sc") 71 | for _, filename := range sampleFiles { 72 | err := compileAndSave(filename) 73 | if err == nil { 74 | t.Errorf("%v: expect error, got ok", filename) 75 | } 76 | } 77 | } 78 | 79 | func TestBasic(t *testing.T) { 80 | filenames, _ := filepath.Glob("test/basic/*.sc") 81 | for _, filename := range filenames { 82 | testOk(t, filename) 83 | } 84 | } 85 | 86 | func TestAdvanced(t *testing.T) { 87 | filenames, _ := filepath.Glob("test/advanced/*.sc") 88 | for _, filename := range filenames { 89 | testOk(t, filename) 90 | } 91 | } 92 | 93 | func TestErr(t *testing.T) { 94 | filenames, _ := filepath.Glob("test/err/*.sc") 95 | for _, filename := range filenames { 96 | err := compileAndSave(filename) 97 | if err == nil { 98 | t.Errorf("%v: expect error, got ok", filename) 99 | } 100 | } 101 | } 102 | 103 | func compileAndSave(filename string) error { 104 | src, err := 
ioutil.ReadFile(filename) 105 | if err != nil { 106 | return err 107 | } 108 | 109 | code, errs := CompileSource(string(src), true) 110 | for _, err := range errs { 111 | return err 112 | } 113 | 114 | dest := regexp.MustCompile("\\.sc$").ReplaceAllString(filename, ".s") 115 | err = ioutil.WriteFile(dest, []byte(code), 0777) 116 | if err != nil { 117 | return err 118 | } 119 | 120 | return nil 121 | } 122 | 123 | func runSpim(filename string) (string, error) { 124 | byteOut, err := exec.Command("spim", "-file", filename).Output() 125 | if err != nil { 126 | return "", err 127 | } 128 | 129 | lines := strings.Split(string(byteOut), "\n") 130 | output := lines[len(lines)-1] 131 | 132 | return output, nil 133 | } 134 | 135 | func testOk(t *testing.T, sourceFilename string) { 136 | filename := regexp.MustCompile("\\.sc$").ReplaceAllString(sourceFilename, ".s") 137 | 138 | { 139 | err := compileAndSave(sourceFilename) 140 | 141 | if err != nil { 142 | t.Errorf("%v: %v", sourceFilename, err) 143 | return 144 | } 145 | 146 | output, err := runSpim(filename) 147 | if err != nil { 148 | t.Error(err) 149 | return 150 | } 151 | 152 | expected := "1" 153 | if output != expected { 154 | t.Errorf("`%v`: expect `%v`, got `%v`", filename, expected, output) 155 | } 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /optimize.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | type DataflowBlock struct { 4 | Name string 5 | Statements []IRStatement 6 | Next []*DataflowBlock 7 | Prev []*DataflowBlock 8 | } 9 | 10 | func (block *DataflowBlock) AddEdge(another *DataflowBlock) { 11 | block.Next = append(block.Next, another) 12 | another.Prev = append(another.Prev, block) 13 | } 14 | 15 | type BlockState map[*Symbol][]IRStatement 16 | 17 | func (state BlockState) Equal(anotherState BlockState) bool { 18 | for symbol, statements := range state { 19 | if len(state[symbol]) != 
len(anotherState[symbol]) { 20 | return false 21 | } 22 | 23 | for i := range statements { 24 | if state[symbol][i] != anotherState[symbol][i] { 25 | return false 26 | } 27 | } 28 | } 29 | 30 | return true 31 | } 32 | 33 | func Optimize(program *IRProgram) *IRProgram { 34 | for i, f := range program.Functions { 35 | statements := flatStatement(f) 36 | 37 | blocks := splitStatementsIntoBlocks(statements) 38 | 39 | buildDataflowGraph(blocks) 40 | blockOut := searchReachingDefinitions(blocks) 41 | allStatementState := reachingDefinitionsOfStatements(blocks, blockOut, statements) 42 | 43 | program.Functions[i] = transformByConstantFolding(program.Functions[i], allStatementState) 44 | program.Functions[i] = transformByDeadCodeElimination(program.Functions[i], allStatementState) 45 | } 46 | 47 | return program 48 | } 49 | 50 | func transformByConstantFolding(f *IRFunctionDefinition, allStatementState map[IRStatement]BlockState) *IRFunctionDefinition { 51 | traversed := Traverse(f, func(statement IRStatement) IRStatement { 52 | foldConstantStatement(statement, allStatementState) 53 | return statement 54 | }) 55 | 56 | return traversed.(*IRFunctionDefinition) 57 | } 58 | 59 | func transformByDeadCodeElimination(f *IRFunctionDefinition, allStatementState map[IRStatement]BlockState) *IRFunctionDefinition { 60 | changed := true 61 | for changed { 62 | changed = false 63 | 64 | used := make(map[IRStatement]bool) 65 | markAsUsed := func(s IRStatement, symbol *Symbol) { 66 | for _, definition := range allStatementState[s][symbol] { 67 | used[definition] = true 68 | } 69 | } 70 | 71 | Traverse(f, func(statement IRStatement) IRStatement { 72 | switch s := statement.(type) { 73 | case *IRCompoundStatement: 74 | used[s] = true 75 | 76 | case *IRAssignmentStatement: 77 | if s.Var.IsGlobal() { 78 | used[s] = true 79 | } 80 | 81 | vars := extractVarsFromExpression(s.Expression) 82 | for _, v := range vars { 83 | markAsUsed(s, v) 84 | } 85 | 86 | case *IRReadStatement: 87 | if 
s.Dest.IsGlobal() { 88 | used[s] = true 89 | } 90 | 91 | markAsUsed(s, s.Src) 92 | 93 | case *IRWriteStatement: 94 | markAsUsed(s, s.Src) 95 | markAsUsed(s, s.Dest) 96 | 97 | case *IRCallStatement: 98 | if s.Dest.IsGlobal() { 99 | used[s] = true 100 | } 101 | 102 | for _, argVar := range s.Vars { 103 | markAsUsed(s, argVar) 104 | } 105 | 106 | case *IRSystemCallStatement: 107 | markAsUsed(s, s.Var) 108 | 109 | case *IRReturnStatement: 110 | markAsUsed(s, s.Var) 111 | 112 | case *IRIfStatement: 113 | markAsUsed(s, s.Var) 114 | } 115 | 116 | return statement 117 | }) 118 | 119 | transformed := Traverse(f, func(statement IRStatement) IRStatement { 120 | switch statement.(type) { 121 | case *IRAssignmentStatement, *IRReadStatement: 122 | if !used[statement] { 123 | changed = true 124 | return nil 125 | } 126 | } 127 | 128 | return statement 129 | }) 130 | 131 | f = transformed.(*IRFunctionDefinition) 132 | } 133 | 134 | return removeUnusedVariableDeclaration(f) 135 | } 136 | 137 | func removeUnusedVariableDeclaration(f *IRFunctionDefinition) *IRFunctionDefinition { 138 | used := make(map[*Symbol]bool) 139 | Traverse(f, func(statement IRStatement) IRStatement { 140 | switch s := statement.(type) { 141 | case *IRAssignmentStatement: 142 | used[s.Var] = true 143 | case *IRReadStatement: 144 | used[s.Dest] = true 145 | case *IRCallStatement: 146 | used[s.Dest] = true 147 | } 148 | 149 | return statement 150 | }) 151 | 152 | transformed := Traverse(f, func(statement IRStatement) IRStatement { 153 | switch s := statement.(type) { 154 | case *IRCompoundStatement: 155 | newDeclarations := []*IRVariableDeclaration{} 156 | for _, d := range s.Declarations { 157 | _, isArrayType := d.Var.Type.(ArrayType) 158 | if used[d.Var] || isArrayType { 159 | newDeclarations = append(newDeclarations, d) 160 | } 161 | } 162 | 163 | s.Declarations = newDeclarations 164 | return s 165 | } 166 | 167 | return statement 168 | }) 169 | 170 | return transformed.(*IRFunctionDefinition) 171 | } 172 | 
173 | func extractAddressVarsFromExpression(expression IRExpression) []*Symbol { 174 | switch e := expression.(type) { 175 | case *IRBinaryExpression: 176 | var vars []*Symbol 177 | vars = append(vars, extractAddressVarsFromExpression(e.Left)...) 178 | vars = append(vars, extractAddressVarsFromExpression(e.Right)...) 179 | return vars 180 | 181 | case *IRAddressExpression: 182 | return []*Symbol{e.Var} 183 | 184 | } 185 | 186 | return nil 187 | } 188 | // extractVarsFromExpression returns every symbol referenced by expression, including symbols whose address is taken. 189 | func extractVarsFromExpression(expression IRExpression) []*Symbol { 190 | switch e := expression.(type) { 191 | case *IRNumberExpression: 192 | return nil 193 | 194 | case *IRVariableExpression: 195 | return []*Symbol{e.Var} 196 | 197 | case *IRBinaryExpression: 198 | var vars []*Symbol 199 | vars = append(vars, extractVarsFromExpression(e.Left)...) 200 | vars = append(vars, extractVarsFromExpression(e.Right)...) 201 | return vars 202 | 203 | case *IRAddressExpression: 204 | return []*Symbol{e.Var} 205 | 206 | } 207 | 208 | return nil 209 | } 210 | // searchReachingDefinitions re-runs analyzeBlock over all blocks until no block's out-state changes, and returns the out-state of every block. 211 | func searchReachingDefinitions(blocks []*DataflowBlock) map[*DataflowBlock]BlockState { 212 | blockOut := make(map[*DataflowBlock]BlockState) 213 | 214 | changed := true 215 | for changed { 216 | changed = false 217 | 218 | for _, block := range blocks { 219 | inState := analyzeBlock(blockOut, block) 220 | if !inState.Equal(blockOut[block]) { 221 | changed = true 222 | } 223 | 224 | blockOut[block] = inState 225 | } 226 | } 227 | 228 | return blockOut 229 | } 230 | // reachingDefinitionsOfStatements maps each statement to the definitions reaching it (its in-state); statements with an empty in-state get no entry. The statements parameter is currently unused. 
231 | func reachingDefinitionsOfStatements(blocks []*DataflowBlock, blockOut map[*DataflowBlock]BlockState, statements []IRStatement) map[IRStatement]BlockState { 232 | allStatementState := make(map[IRStatement]BlockState) 233 | for _, block := range blocks { 234 | inState := blockIn(blockOut, block) 235 | 236 | for _, statement := range block.Statements { 237 | for key, value := range inState { 238 | if allStatementState[statement] == nil { 239 | allStatementState[statement] = make(BlockState) 240 |
} 241 | allStatementState[statement][key] = value 242 | } 243 | inState = analyzeReachingDefinition(statement, inState) 244 | } 245 | } 246 | 247 | return allStatementState 248 | } 249 | // foldConstantStatement folds the right-hand side of an assignment; when it is constant, the expression is replaced in place by a number and the value is returned. 250 | func foldConstantStatement(statement IRStatement, allStatementState map[IRStatement]BlockState) (bool, int) { 251 | switch s := statement.(type) { 252 | case *IRAssignmentStatement: 253 | isConstant, value := foldConstantExpression(s, s.Expression, allStatementState) 254 | if isConstant { 255 | s.Expression = &IRNumberExpression{Value: value} 256 | return true, value 257 | } 258 | } 259 | 260 | return false, 0 261 | } 262 | // foldConstantExpression reports whether expression evaluates to a constant (and its value), replacing constant operands of binary expressions in place. 263 | func foldConstantExpression(statement IRStatement, expression IRExpression, allStatementState map[IRStatement]BlockState) (bool, int) { 264 | switch e := expression.(type) { 265 | case *IRNumberExpression: 266 | return true, e.Value 267 | 268 | case *IRVariableExpression: 269 | symbol := e.Var 270 | definitionOfVar := allStatementState[statement][symbol] 271 | if len(definitionOfVar) == 1 && definitionOfVar[0] != statement { 272 | return foldConstantStatement(definitionOfVar[0], allStatementState) 273 | } 274 | 275 | return false, 0 276 | 277 | case *IRBinaryExpression: 278 | leftIsConstant, leftValue := foldConstantExpression(statement, e.Left, allStatementState) 279 | rightIsConstant, rightValue := foldConstantExpression(statement, e.Right, allStatementState) 280 | 281 | if leftIsConstant { 282 | e.Left = &IRNumberExpression{Value: leftValue} 283 | } 284 | 285 | if rightIsConstant { 286 | e.Right = &IRNumberExpression{Value: rightValue} 287 | } 288 | 289 | if leftIsConstant && rightIsConstant && !(e.Operator == "/" && rightValue == 0) { // x / 0 is not folded: evaluating it 
here would panic the compiler 290 | switch e.Operator { 291 | case "+": 292 | return true, leftValue + rightValue 293 | 294 | case "-": 295 | return true, leftValue - rightValue 296 | 297 | case "*": 298 | return true, leftValue * rightValue 299 | 300 | case "/": 301 | return true, leftValue / rightValue 302 | 303 | case "<": 304 | value := 0 305 | if leftValue < rightValue { 306 | value
= 1 307 | } 308 | return true, value 309 | 310 | case ">": 311 | value := 0 312 | if leftValue > rightValue { 313 | value = 1 314 | } 315 | return true, value 316 | 317 | case "<=": 318 | value := 0 319 | if leftValue <= rightValue { 320 | value = 1 321 | } 322 | return true, value 323 | 324 | case ">=": 325 | value := 0 326 | if leftValue >= rightValue { 327 | value = 1 328 | } 329 | return true, value 330 | 331 | case "==": 332 | value := 0 333 | if leftValue == rightValue { 334 | value = 1 335 | } 336 | return true, value 337 | 338 | case "!=": 339 | value := 0 340 | if leftValue != rightValue { 341 | value = 1 342 | } 343 | return true, value 344 | 345 | } 346 | 347 | panic("unexpected operator: " + e.Operator) 348 | } 349 | 350 | return false, 0 351 | } 352 | 353 | return false, 0 354 | } 355 | // blockIn merges the out-states of block's predecessors, keeping each reaching definition once per symbol. 356 | func blockIn(blockOut map[*DataflowBlock]BlockState, block *DataflowBlock) BlockState { 357 | inState := BlockState{} 358 | for _, prevBlock := range block.Prev { 359 | for key, statements := range blockOut[prevBlock] { 360 | for _, statement := range statements { 361 | found := false 362 | for _, v := range inState[key] { 363 | if v == statement { 364 | found = true 365 | break 366 | } 367 | } 368 | 369 | if !found { 370 | inState[key] = append(inState[key], statement) 371 | } 372 | } 373 | } 374 | } 375 | 376 | return inState 377 | } 378 | // analyzeBlock computes block's out-state by applying analyzeReachingDefinition to each statement in order. 379 | func analyzeBlock(blockOut map[*DataflowBlock]BlockState, block *DataflowBlock) BlockState { 380 | inState := blockIn(blockOut, block) 381 | for _, statement := range block.Statements { 382 | inState = analyzeReachingDefinition(statement, inState) 383 | } 384 | 385 | return inState 386 | } 387 | // analyzeReachingDefinition updates inState in place with the definitions made by statement and returns it. 
An IRWriteStatement is recorded as a possible definition of every symbol already in inState. 388 | func analyzeReachingDefinition(statement IRStatement, inState BlockState) BlockState { 389 | switch s := statement.(type) { 390 | case *IRAssignmentStatement: 391 | inState[s.Var] = []IRStatement{s} 392 | symbols := extractAddressVarsFromExpression(s.Expression) 393 | for _, symbol := range symbols { 394 | inState[symbol] =
append(inState[symbol], s) 395 | } 396 | 397 | case *IRReadStatement: 398 | inState[s.Dest] = []IRStatement{s} 399 | 400 | case *IRWriteStatement: 401 | for symbol := range inState { 402 | inState[symbol] = append(inState[symbol], s) 403 | } 404 | 405 | case *IRCallStatement: 406 | inState[s.Dest] = []IRStatement{s} 407 | 408 | } 409 | 410 | return inState 411 | } 412 | // splitStatementsIntoBlocks cuts a flat statement list into basic blocks: function definitions and labels start a new block; if, goto and return statements end the current one. 413 | func splitStatementsIntoBlocks(statements []IRStatement) []*DataflowBlock { 414 | var blocks []*DataflowBlock 415 | block := &DataflowBlock{} 416 | for _, statement := range statements { 417 | switch s := statement.(type) { 418 | case *IRFunctionDefinition, *IRLabelStatement: 419 | // in 420 | if len(block.Statements) > 0 { 421 | blocks = append(blocks, block) 422 | } 423 | 424 | block = &DataflowBlock{Statements: []IRStatement{s}} 425 | 426 | case *IRIfStatement, *IRGotoStatement, *IRReturnStatement: 427 | // out 428 | block.Statements = append(block.Statements, s) 429 | blocks = append(blocks, block) 430 | block = &DataflowBlock{} 431 | 432 | default: 433 | block.Statements = append(block.Statements, s) 434 | } 435 | } 436 | 437 | if len(block.Statements) > 0 { 438 | blocks = append(blocks, block) 439 | } 440 | 441 | return blocks 442 | } 443 | // buildDataflowGraph connects blocks by control flow (fall-through, goto targets, if labels, return to END) and returns a synthetic BEGIN block whose successor is blocks[0]; blocks must be non-empty. 
444 | func buildDataflowGraph(blocks []*DataflowBlock) *DataflowBlock { 445 | beginBlock := &DataflowBlock{Name: "BEGIN"} 446 | beginBlock.Next = append(beginBlock.Next, blocks[0]) 447 | 448 | endBlock := &DataflowBlock{Name: "END"} 449 | lastBlock := blocks[len(blocks)-1] 450 | lastBlock.Next = append(lastBlock.Next, endBlock) 451 | 452 | for i, block := range blocks { 453 | lastStatement := block.Statements[len(block.Statements)-1] 454 | switch s := lastStatement.(type) { 455 | case *IRGotoStatement: 456 | // goto label -> label block 457 | nextBlock := findBlockByLabel(blocks, s.Label) 458 | block.AddEdge(nextBlock) 459 | 460 | case *IRIfStatement: 461 | // if block -> true_label block, false_label block 462 | if len(s.TrueLabel) > 0 { 463 |
trueLabelBlock := findBlockByLabel(blocks, s.TrueLabel) 464 | block.AddEdge(trueLabelBlock) 465 | } 466 | 467 | if len(s.FalseLabel) > 0 { 468 | falseLabelBlock := findBlockByLabel(blocks, s.FalseLabel) 469 | block.AddEdge(falseLabelBlock) 470 | } 471 | 472 | if len(s.TrueLabel) == 0 || len(s.FalseLabel) == 0 { 473 | if i < len(blocks)-1 { 474 | nextBlock := blocks[i+1] 475 | block.AddEdge(nextBlock) 476 | } 477 | } 478 | 479 | case *IRReturnStatement: 480 | // return block -> end block 481 | block.AddEdge(endBlock) 482 | 483 | default: 484 | if i < len(blocks)-1 { 485 | nextBlock := blocks[i+1] 486 | block.AddEdge(nextBlock) 487 | } 488 | } 489 | } 490 | 491 | return beginBlock 492 | } 493 | // findBlockByLabel returns the block whose first statement is the label named label, or nil if there is none. 494 | func findBlockByLabel(blocks []*DataflowBlock, label string) *DataflowBlock { 495 | for _, block := range blocks { 496 | inStatement := block.Statements[0] 497 | labelStatement, ok := inStatement.(*IRLabelStatement) 498 | if ok && labelStatement.Name == label { 499 | return block 500 | } 501 | } 502 | 503 | return nil 504 | } 505 | // flatStatement flattens function definitions (header, parameters, body) and compound statements (declarations, then statements) into a linear list. 506 | func flatStatement(statement IRStatement) []IRStatement { 507 | switch s := statement.(type) { 508 | case *IRFunctionDefinition: 509 | var statements []IRStatement 510 | 511 | statements = append(statements, s) 512 | 513 | for _, p := range s.Parameters { 514 | statements = append(statements, flatStatement(p)...) 515 | } 516 | 517 | statements = append(statements, flatStatement(s.Body)...) 518 | 519 | return statements 520 | 521 | case *IRCompoundStatement: 522 | var statements []IRStatement 523 | 524 | for _, d := range s.Declarations { 525 | statements = append(statements, flatStatement(d)...) 526 | } 527 | 528 | for _, child := range s.Statements { 529 | statements = append(statements, flatStatement(child)...)
530 | } 531 | 532 | return statements 533 | } 534 | 535 | return []IRStatement{statement} 536 | } 537 | -------------------------------------------------------------------------------- /optimize_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestExtractVarsFromExpression(t *testing.T) { 8 | symbol := &Symbol{Name: "foo"} 9 | 10 | { 11 | // foo + 42 12 | vars := extractVarsFromExpression( 13 | &IRBinaryExpression{ 14 | Operator: "+", 15 | Left: &IRVariableExpression{Var: symbol}, 16 | Right: &IRNumberExpression{Value: 42}, 17 | }, 18 | ) 19 | 20 | expected := len(vars) == 1 && vars[0] == symbol 21 | if !expected { 22 | t.Errorf("expect vars of `foo + 42` to be `foo`, got %v", vars) 23 | } 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /parse.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // Parse parses src and returns the top-level statements of its AST; a syntax error is reported as line:column: message. 8 | func Parse(src string) ([]Statement, error) { 9 | l := new(Lexer) 10 | l.Init(src) 11 | yyErrorVerbose = true 12 | 13 | fail := yyParse(l) 14 | if fail == 1 { 15 | err := fmt.Errorf("%d:%d: %s", l.pos.Line, l.pos.Column, l.errMessage) 16 | 17 | return nil, err 18 | } 19 | 20 | return l.result, nil 21 | } 22 | 23 | // Walk iterates over statement nodes and replaces syntax sugar (for example, for loops become while loops). 24 | func Walk(statement Statement) Statement { 25 | switch s := statement.(type) { 26 | case *FunctionDefinition: 27 | for i, p := range s.Parameters { 28 | s.Parameters[i] = WalkExpression(p) 29 | } 30 | 31 | s.Statement = Walk(s.Statement) 32 | 33 | return s 34 | 35 | case *CompoundStatement: 36 | for i, st := range s.Statements { 37 | s.Statements[i] = Walk(st) 38 | } 39 | 40 | for i, d := range 
s.Declarations { 41 | s.Declarations[i] = Walk(d) 42 | } 43 | 44 | return s 45 | 46 | case *ForStatement: 47 | // for (init; cond;
loop) s 48 | // => init; while (cond) { s; loop; } 49 | 50 | var statements []Statement 51 | if s.Init != nil { 52 | statements = append(statements, &ExpressionStatement{Value: WalkExpression(s.Init)}) 53 | } 54 | 55 | body := Walk(s.Statement) 56 | whileBody := []Statement{body} 57 | if s.Loop != nil { 58 | whileBody = append(whileBody, &ExpressionStatement{Value: WalkExpression(s.Loop)}) 59 | } 60 | 61 | var condition Expression 62 | if s.Condition != nil { 63 | condition = WalkExpression(s.Condition) 64 | } else { 65 | condition = &NumberExpression{Value: "1"} 66 | } 67 | 68 | statements = append(statements, 69 | &WhileStatement{ 70 | pos: s.Pos(), 71 | Condition: condition, 72 | Statement: &CompoundStatement{ 73 | Statements: whileBody, 74 | }, 75 | }, 76 | ) 77 | 78 | return &CompoundStatement{ 79 | Statements: statements, 80 | } 81 | 82 | case *WhileStatement: 83 | s.Condition = WalkExpression(s.Condition) 84 | s.Statement = Walk(s.Statement) 85 | 86 | case *IfStatement: 87 | s.Condition = WalkExpression(s.Condition) 88 | s.TrueStatement = Walk(s.TrueStatement) 89 | s.FalseStatement = Walk(s.FalseStatement) 90 | 91 | return s 92 | 93 | case *ReturnStatement: 94 | s.Value = WalkExpression(s.Value) 95 | return s 96 | 97 | case *ExpressionStatement: 98 | s.Value = WalkExpression(s.Value) 99 | return s 100 | } 101 | 102 | return statement 103 | } 104 | // WalkExpression rewrites expression syntax sugar recursively: -x becomes 0 - x, &*e becomes e, and a[i] becomes *(a + i). 
105 | func WalkExpression(expression Expression) Expression { 106 | switch e := expression.(type) { 107 | case *ExpressionList: 108 | for i, value := range e.Values { 109 | e.Values[i] = WalkExpression(value) 110 | } 111 | 112 | return e 113 | 114 | case *FunctionCallExpression: 115 | e.Argument = WalkExpression(e.Argument) 116 | 117 | return e 118 | 119 | case *BinaryExpression: 120 | e.Left = WalkExpression(e.Left) 121 | e.Right = WalkExpression(e.Right) 122 | 123 | return e 124 | 125 | case *UnaryExpression: 126 | e.Value = WalkExpression(e.Value) 127 | 128 | if e.Operator == "-" { 129 | return
&BinaryExpression{ 130 | Left: &NumberExpression{pos: e.Pos(), Value: "0"}, 131 | Operator: "-", 132 | Right: e.Value, 133 | } 134 | } else if e.Operator == "&" { 135 | // &(*e) -> e 136 | switch value := e.Value.(type) { 137 | case *UnaryExpression: 138 | if value.Operator == "*" { 139 | return value.Value 140 | } 141 | } 142 | } 143 | 144 | return e 145 | 146 | case *ArrayReferenceExpression: 147 | // a[100] => *(a + 100) 148 | e.Target = WalkExpression(e.Target) 149 | e.Index = WalkExpression(e.Index) 150 | 151 | return &UnaryExpression{ 152 | Operator: "*", 153 | Value: &BinaryExpression{ 154 | Left: e.Target, 155 | Operator: "+", 156 | Right: e.Index, 157 | }, 158 | } 159 | } 160 | 161 | return expression 162 | } 163 | -------------------------------------------------------------------------------- /parse_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func TestParse(t *testing.T) { 9 | statements, err := Parse(` 10 | int data[8]; 11 | 12 | int main() { 13 | int i; 14 | int j; 15 | int tmp; 16 | int size; 17 | 18 | size = 8; 19 | for (i = 0; i < size; i = i + 1) { 20 | for (j = 1; j < size; j = j + 1) { 21 | if (data[j] < data[j-1]) { 22 | tmp = data[j]; 23 | data[j] = data[j-1]; 24 | data[j-1] = tmp; 25 | } 26 | } 27 | } 28 | } 29 | `) 30 | 31 | if err != nil { 32 | t.Error(err) 33 | return 34 | } 35 | 36 | if len(statements) == 0 { 37 | t.Errorf("expect len(statements) > 0, actual: %v", len(statements)) 38 | } 39 | } 40 | 41 | func TestParseError(t *testing.T) { 42 | _, err := Parse(` 43 | wtf this is wtf 44 | `) 45 | 46 | if !(err != nil && strings.Contains(err.Error(), "syntax error")) { 47 | t.Errorf("expect syntax error, but success") 48 | } 49 | } 50 | 51 | func TestParseDeclaration(t *testing.T) { 52 | _, err := Parse(` 53 | int foo, bar; 54 | void baz; 55 | int a[100]; 56 | int *pointer; 57 | 58 | int sum(int a, int b); 59 | int
*foo(); 60 | `) 61 | 62 | if err != nil { 63 | t.Error(err) 64 | } 65 | } 66 | 67 | func TestParseCompoundStatement(t *testing.T) { 68 | _, err := Parse(` 69 | int foo() { 70 | int a; 71 | 72 | { 73 | a = a + b; 74 | }; 75 | 76 | return b; 77 | } 78 | `) 79 | 80 | if err != nil { 81 | t.Error(err) 82 | } 83 | } 84 | 85 | func TestParseFunctionDefinition(t *testing.T) { 86 | _, err := Parse(` 87 | int sum(int a, int b) { 88 | return a + b; 89 | } 90 | `) 91 | 92 | if err != nil { 93 | t.Error(err) 94 | } 95 | } 96 | 97 | func TestParseIfStatement(t *testing.T) { 98 | _, err := Parse(` 99 | int foo(int a) { 100 | if (a == 0) a = 1; 101 | if (a != 0) a = 1; 102 | if (a > 0) a = 1; 103 | if (a >= 0) a = 1; 104 | if (a < 0) a = 1; 105 | if (a <= 0) a = 1; 106 | if (a && b) return 1; 107 | if (a || b) return 1; 108 | } 109 | `) 110 | 111 | if err != nil { 112 | t.Error(err) 113 | } 114 | } 115 | 116 | func TestParseWhileStatement(t *testing.T) { 117 | _, err := Parse(` 118 | int main() { 119 | int a; 120 | 121 | a = 100; 122 | while (a) { 123 | a = a - 1; 124 | } 125 | } 126 | `) 127 | 128 | if err != nil { 129 | t.Error(err) 130 | } 131 | } 132 | 133 | func TestParseForStatement(t *testing.T) { 134 | statements, err := Parse(` 135 | int main() { 136 | for (i = 0; i < 100; i = i + 1) { 137 | sum = sum + i; 138 | } 139 | 140 | for (;;) { 141 | return; 142 | } 143 | } 144 | `) 145 | 146 | if err != nil { 147 | t.Error(err) 148 | return 149 | } 150 | 151 | switch mainStatements(statements)[0].(type) { 152 | case *ForStatement: 153 | default: 154 | t.Error("expected ForStatement") 155 | } 156 | } 157 | 158 | func mainStatements(statements []Statement) []Statement { 159 | main := statements[0].(*FunctionDefinition) 160 | 161 | return main.Statement.(*CompoundStatement).Statements 162 | } 163 | 164 | func TestParseUnaryExpression(t *testing.T) { 165 | _, err := Parse(` 166 | int main() { 167 | a = &a; 168 | a = -a; 169 | a = *a; 170 | } 171 | `) 172 | 173 | if err != nil { 174
| t.Error(err) 175 | } 176 | } 177 | 178 | func TestWalkExpression(t *testing.T) { 179 | { 180 | e := WalkExpression(&UnaryExpression{ 181 | Operator: "-", 182 | Value: &NumberExpression{Value: "42"}, 183 | }) 184 | 185 | _, ok := e.(*BinaryExpression) 186 | if !ok { 187 | t.Errorf("expect *BinaryExpression") 188 | } 189 | } 190 | 191 | { 192 | e := WalkExpression(&UnaryExpression{ 193 | Operator: "&", 194 | Value: &UnaryExpression{ 195 | Operator: "*", 196 | Value: &IdentifierExpression{ 197 | Name: "foo", 198 | }, 199 | }, 200 | }) 201 | 202 | _, ok := e.(*IdentifierExpression) 203 | if !ok { 204 | t.Errorf("expect *IdentifierExpression") 205 | } 206 | } 207 | 208 | { 209 | // a[10] -> *(a + 10) 210 | e := WalkExpression(&ArrayReferenceExpression{ 211 | Target: &IdentifierExpression{Name: "a"}, 212 | Index: &NumberExpression{Value: "10"}, 213 | }) 214 | 215 | unaryExpression, ok := e.(*UnaryExpression) 216 | _, isBinary := (unaryExpression.Value).(*BinaryExpression) 217 | 218 | if !(ok && unaryExpression.Operator == "*" && isBinary) { 219 | t.Errorf("it should be *(a + 10), but: %v", e) 220 | } 221 | } 222 | } 223 | -------------------------------------------------------------------------------- /parser.go.y: -------------------------------------------------------------------------------- 1 | %{ 2 | package main 3 | 4 | import ( 5 | "strconv" 6 | ) 7 | 8 | %} 9 | 10 | %union { 11 | token Token 12 | 13 | expression Expression 14 | expressions []Expression 15 | 16 | declarator *Declarator 17 | declarators []*Declarator 18 | 19 | statement Statement 20 | statements []Statement 21 | 22 | parameter_declaration *ParameterDeclaration 23 | } 24 | 25 | %type expression optional_expression identifier_expression identifier 26 | %type add_expression mult_expression assign_expression primary_expression logical_or_expression logical_and_expression equal_expression relation_expression unary_expression postfix_expression 27 | %type parameters optional_parameters 28 | %type
statements declarations optional_statements optional_declarations program 29 | %type statement compound_statement external_declaration declaration function_definition function_prototype 30 | %type declarator 31 | %type declarators 32 | %type parameter_declaration 33 | %token NUMBER CHAR IDENT TYPE IF LOGICAL_OR LOGICAL_AND RETURN EQL NEQ GEQ LEQ ELSE WHILE FOR '-' '*' '&' '{' 34 | 35 | %% 36 | 37 | program 38 | : external_declaration 39 | { 40 | $$ = []Statement{$1} 41 | yylex.(*Lexer).result = $$ 42 | } 43 | | program external_declaration 44 | { 45 | $$ = append($1, $2) 46 | yylex.(*Lexer).result = $$ 47 | } 48 | 49 | external_declaration 50 | : declaration 51 | | function_prototype 52 | | function_definition 53 | 54 | declarations 55 | : declaration 56 | { 57 | $$ = []Statement{ $1 } 58 | } 59 | | declarations declaration 60 | { 61 | $$ = append($1, $2) 62 | } 63 | 64 | declaration 65 | : TYPE declarators ';' 66 | { 67 | $$ = &Declaration{ pos: $1.pos, VarType: $1.lit, Declarators: $2 } 68 | } 69 | 70 | declarators 71 | : declarator 72 | { 73 | $$ = []*Declarator{ $1 } 74 | } 75 | | declarators ',' declarator 76 | { 77 | $$ = append($1, $3) 78 | } 79 | 80 | declarator 81 | : identifier_expression 82 | { 83 | $$ = &Declarator{ Identifier: $1 } 84 | } 85 | | identifier_expression '[' NUMBER ']' 86 | { 87 | i, _ := strconv.Atoi($3.lit) 88 | $$ = &Declarator{ Identifier: $1, Size: i } 89 | } 90 | 91 | function_prototype 92 | : TYPE identifier_expression '(' optional_parameters ')' ';' 93 | { 94 | $$ = &FunctionDefinition{ pos: $1.pos, TypeName: $1.lit, Identifier: $2, Parameters: $4 } 95 | } 96 | 97 | function_definition 98 | : TYPE identifier_expression '(' optional_parameters ')' compound_statement 99 | { 100 | $$ = &FunctionDefinition{ pos: $1.pos, TypeName: $1.lit, Identifier: $2, Parameters: $4, Statement: $6 } 101 | } 102 | 103 | identifier_expression 104 | : identifier 105 | | '*' identifier 106 | { 107 | $$ = &UnaryExpression{ pos: $1.pos, Operator: "*",
Value: $2 } 108 | } 109 | 110 | optional_parameters 111 | : { $$ = nil } 112 | | parameters 113 | 114 | parameters 115 | : parameter_declaration 116 | { 117 | $$ = []Expression{ $1 } 118 | } 119 | | parameters ',' parameter_declaration 120 | { 121 | $$ = append($1, $3) 122 | } 123 | 124 | parameter_declaration 125 | : TYPE identifier_expression 126 | { 127 | $$ = &ParameterDeclaration{ pos: $1.pos, TypeName: $1.lit, Identifier: $2 } 128 | } 129 | 130 | compound_statement 131 | : '{' optional_declarations optional_statements '}' 132 | { 133 | $$ = &CompoundStatement{ pos: $1.pos, Declarations: $2, Statements: $3 } 134 | } 135 | 136 | optional_declarations 137 | : { $$ = nil } 138 | | declarations 139 | 140 | optional_statements 141 | : { $$ = nil } 142 | | statements 143 | 144 | statements 145 | : statement 146 | { 147 | $$ = []Statement{ $1 } 148 | } 149 | | statements statement 150 | { 151 | $$ = append($1, $2) 152 | } 153 | 154 | statement 155 | : ';' 156 | { 157 | $$ = nil 158 | } 159 | | expression ';' 160 | { 161 | $$ = &ExpressionStatement{ Value: $1 } 162 | } 163 | | compound_statement 164 | | IF '(' expression ')' statement 165 | { 166 | $$ = &IfStatement{ pos: $1.pos, Condition: $3, TrueStatement: $5 } 167 | } 168 | | IF '(' expression ')' statement ELSE statement 169 | { 170 | $$ = &IfStatement{ pos: $1.pos, Condition: $3, TrueStatement: $5, FalseStatement: $7 } 171 | } 172 | | WHILE '(' expression ')' statement 173 | { 174 | $$ = &WhileStatement{ pos: $1.pos, Condition: $3, Statement: $5 } 175 | } 176 | | FOR '(' optional_expression ';' optional_expression ';' optional_expression ')' statement 177 | { 178 | $$ = &ForStatement{ pos: $1.pos, Init: $3, Condition: $5, Loop: $7, Statement: $9 } 179 | } 180 | | RETURN optional_expression ';' 181 | { 182 | $$ = &ReturnStatement{ pos: $1.pos, Value: $2 } 183 | } 184 | 185 | optional_expression: { $$ = nil } 186 | | expression 187 | 188 | expression 189 | : assign_expression 190 | { 191 | $$ = $1 192 | } 193 | |
expression ',' assign_expression 194 | { 195 | switch e := $1.(type) { 196 | case *ExpressionList: 197 | $$ = &ExpressionList{ Values: append(e.Values, $3) } 198 | default: 199 | $$ = &ExpressionList{ Values: []Expression{$1, $3} } 200 | } 201 | } 202 | 203 | assign_expression 204 | : logical_or_expression 205 | | logical_or_expression '=' assign_expression 206 | { 207 | $$ = &BinaryExpression{ Left: $1, Operator: "=", Right: $3 } 208 | } 209 | 210 | logical_or_expression 211 | : logical_and_expression 212 | | logical_or_expression LOGICAL_OR logical_and_expression 213 | { 214 | $$ = &BinaryExpression{ Left: $1, Operator: "||", Right: $3} 215 | } 216 | 217 | logical_and_expression 218 | : equal_expression 219 | | logical_and_expression LOGICAL_AND equal_expression 220 | { 221 | $$ = &BinaryExpression{ Left: $1, Operator: "&&", Right: $3} 222 | } 223 | 224 | equal_expression 225 | : relation_expression 226 | | equal_expression EQL relation_expression 227 | { 228 | $$ = &BinaryExpression{ Left: $1, Operator: "==", Right: $3} 229 | } 230 | | equal_expression NEQ relation_expression 231 | { 232 | $$ = &BinaryExpression{ Left: $1, Operator: "!=", Right: $3} 233 | } 234 | 235 | relation_expression 236 | : add_expression 237 | | relation_expression '>' add_expression 238 | { 239 | $$ = &BinaryExpression{ Left: $1, Operator: ">", Right: $3} 240 | } 241 | | relation_expression '<' add_expression 242 | { 243 | $$ = &BinaryExpression{ Left: $1, Operator: "<", Right: $3} 244 | } 245 | | relation_expression GEQ add_expression 246 | { 247 | $$ = &BinaryExpression{ Left: $1, Operator: ">=", Right: $3} 248 | } 249 | | relation_expression LEQ add_expression 250 | { 251 | $$ = &BinaryExpression{ Left: $1, Operator: "<=", Right: $3} 252 | } 253 | 254 | add_expression 255 | : mult_expression 256 | | add_expression '+' mult_expression 257 | { 258 | $$ = &BinaryExpression{ Left: $1, Operator: "+", Right: $3 } 259 | } 260 | | add_expression '-' mult_expression 261 | { 262 | $$ =
&BinaryExpression{ Left: $1, Operator: "-", Right: $3 } 263 | } 264 | 265 | mult_expression 266 | : unary_expression 267 | | mult_expression '*' unary_expression 268 | { 269 | $$ = &BinaryExpression{ Left: $1, Operator: "*", Right: $3 } 270 | } 271 | | mult_expression '/' unary_expression 272 | { 273 | $$ = &BinaryExpression{ Left: $1, Operator: "/", Right: $3 } 274 | } 275 | 276 | unary_expression 277 | : postfix_expression 278 | | '-' unary_expression 279 | { 280 | $$ = &UnaryExpression{ pos: $1.pos, Operator: "-", Value: $2 } 281 | } 282 | | '&' unary_expression 283 | { 284 | $$ = &UnaryExpression{ pos: $1.pos, Operator: "&", Value: $2 } 285 | } 286 | | '*' unary_expression 287 | { 288 | $$ = &UnaryExpression{ pos: $1.pos, Operator: "*", Value: $2 } 289 | } 290 | 291 | postfix_expression 292 | : primary_expression 293 | | postfix_expression '[' expression ']' 294 | { 295 | $$ = &ArrayReferenceExpression{ Target: $1, Index: $3 } 296 | } 297 | | identifier '(' optional_expression ')' 298 | { 299 | $$ = &FunctionCallExpression{ Identifier: $1, Argument: $3 } 300 | } 301 | 302 | primary_expression 303 | : NUMBER 304 | { 305 | $$ = &NumberExpression{ pos: $1.pos, Value: $1.lit } 306 | } 307 | | identifier 308 | | '(' expression ')' 309 | { 310 | $$ = $2 311 | } 312 | | CHAR 313 | { 314 | literal := $1.lit 315 | ch := literal[1:len(literal)-1][0] 316 | i := int(ch) 317 | 318 | $$ = &NumberExpression{ pos: $1.pos, Value: strconv.Itoa(i) } 319 | } 320 | 321 | identifier 322 | : IDENT 323 | { 324 | $$ = &IdentifierExpression{ pos: $1.pos, Name: $1.lit } 325 | } 326 | 327 | %% 328 | -------------------------------------------------------------------------------- /report/1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiur/small-c/e5a2a01bbbc1e93aa3d406f717a8030c281e7812/report/1.pdf -------------------------------------------------------------------------------- /report/1.tex:
-------------------------------------------------------------------------------- 1 | \documentclass[a4j]{jarticle} 2 | \usepackage[dvipdfmx]{graphicx} 3 | 4 | \begin{document} 5 | 6 | \title{計算機科学実験3 中間レポート1} 7 | \author{杉本風斗} 8 | 9 | \maketitle 10 | 11 | \section{概要} 12 | Small Cの字句・構文解析および抽象構文木の変換処理をするプログラムを作成した. \\ 13 | 実装言語はGo (https://golang.org/) を使用した. 14 | 15 | \section{プログラムの概要} 16 | プログラムの実行方法とソースファイルの構造を説明する. \\ 17 | 18 | \subsection{使い方} 19 | 実装に使用したGoのバージョンは1.6である. ソースファイルを置いたディレクトリ内で以下のようにしてビルドして実行できる. \\ 20 | 21 | \begin{verbatim} 22 | $ make 23 | $ ./small-c program.sc 24 | \end{verbatim} 25 | 26 | 引数にソースファイルを与えて実行すると, 抽象構文木がpretty printされて出力される. \\ 27 | 簡単なSmall Cプログラムを与えて実行した例を示す. 28 | 29 | \begin{verbatim} 30 | $ cat test.sc 31 | int main() { 32 | return 1 + 2; 33 | } 34 | 35 | $ ./small-c test.sc 36 | []main.Statement{ 37 | &main.FunctionDefinition{ 38 | pos: 1, 39 | TypeName: "int", 40 | Identifier: &main.IdentifierExpression{ 41 | pos: 5, 42 | Name: "main", 43 | Symbol: (*main.Symbol)(nil), 44 | }, 45 | Parameters: []main.Expression{}, 46 | Statement: &main.CompoundStatement{ 47 | pos: 12, 48 | Declarations: []main.Statement{}, 49 | Statements: []main.Statement{ 50 | &main.ReturnStatement{ 51 | pos: 12, 52 | Value: &main.BinaryExpression{ 53 | Left: &main.NumberExpression{ 54 | pos: 23, 55 | Value: "1", 56 | }, 57 | Operator: "+", 58 | Right: &main.NumberExpression{ 59 | pos: 27, 60 | Value: "2", 61 | }, 62 | }, 63 | FunctionSymbol: (*main.Symbol)(nil), 64 | }, 65 | }, 66 | }, 67 | }, 68 | } 69 | \end{verbatim} 70 | 71 | 出力の形式はあとで説明する抽象構文木の構造体定義に基づいている. \\ 72 | 73 | 不正な文法のファイルを与える例を示す. ソースファイル中のエラーの位置とともにエラーメッセージが表示される. 74 | 75 | \begin{verbatim} 76 | $ cat error.sc 77 | int a 78 | int b; 79 | 80 | $ ./small-c error.sc 81 | 2:1: syntax error: unexpected TYPE, expecting ';' or ',' 82 | \end{verbatim} 83 | 84 | \subsection{ソースファイルの構造} 85 | プログラムを構成する主要なソースファイルを説明する.
86 | 87 | \begin{itemize} 88 | \item parser.go.y: yaccを用いた字句解析および構文解析 89 | \item ast.go: 抽象構文木の構造体の定義 90 | \item parse.go: 構文解析のラッパ関数および抽象構文木の変換処理 91 | \item main.go: コマンドラインから呼び出されるmain関数 92 | \end{itemize} 93 | 94 | parse\_test.go などのファイルは開発用のユニットテストである. 以下のコマンドでまとめて実行できる. 95 | 96 | \begin{verbatim} 97 | $ make test 98 | \end{verbatim} 99 | 100 | \section{構文解析} 101 | \subsection{抽象構文木の構造体定義} 102 | 抽象構文木のデータ構造を説明する. 構造体の定義はast.goに書かれている. 説明のため, 以下にはast.goの内容から構造体定義の部分だけ抜き出したものを示す. \\ 103 | 104 | 変換処理や意味解析の処理が複雑にならないように, 構造体の数が多くならないように工夫をした. \\ 105 | 106 | 大きく分けてExpression, Statement, 定義の三種類がある. 構文木要素のソースコード上の位置はposというfieldに格納されている. ただし, BinaryExpressionなど子要素を含む複合的な構造体は, 子要素からソースコード上の位置を求めることができるので, ソースコード上の位置を直接fieldに格納していない. 107 | 108 | \begin{verbatim} 109 | type Node interface { 110 | Pos() token.Pos 111 | } 112 | 113 | type Expression interface { 114 | Node 115 | } 116 | 117 | type ExpressionList struct { 118 | Values []Expression 119 | } 120 | 121 | type NumberExpression struct { 122 | pos token.Pos 123 | Value string 124 | } 125 | 126 | type IdentifierExpression struct { 127 | pos token.Pos 128 | Name string 129 | Symbol *Symbol 130 | } 131 | 132 | type UnaryExpression struct { 133 | pos token.Pos 134 | Operator string 135 | Value Expression 136 | } 137 | 138 | type BinaryExpression struct { 139 | Left Expression 140 | Operator string 141 | Right Expression 142 | } 143 | 144 | type FunctionCallExpression struct { 145 | Identifier Expression 146 | Argument Expression 147 | } 148 | 149 | type ArrayReferenceExpression struct { 150 | Target Expression 151 | Index Expression 152 | } 153 | 154 | type PointerExpression struct { 155 | pos token.Pos 156 | Value Expression 157 | } 158 | 159 | type Declarator struct { 160 | Identifier Expression 161 | Size int 162 | } 163 | 164 | type Declaration struct { 165 | pos token.Pos 166 | VarType string 167 | Declarators []*Declarator 168 | } 169 | 170 | type FunctionDefinition struct { 171 | pos token.Pos 172 | TypeName 
string 173 | Identifier Expression 174 | Parameters []Expression 175 | Statement Statement 176 | } 177 | 178 | type Statement interface { 179 | Node 180 | } 181 | 182 | type CompoundStatement struct { 183 | pos token.Pos 184 | Declarations []Statement 185 | Statements []Statement 186 | } 187 | 188 | type ExpressionStatement struct { 189 | Value Expression 190 | } 191 | 192 | type IfStatement struct { 193 | pos token.Pos 194 | Condition Expression 195 | TrueStatement Statement 196 | FalseStatement Statement 197 | } 198 | 199 | type WhileStatement struct { 200 | pos token.Pos 201 | Condition Expression 202 | Statement Statement 203 | } 204 | 205 | type ForStatement struct { 206 | pos token.Pos 207 | Init Expression 208 | Condition Expression 209 | Loop Expression 210 | Statement Statement 211 | } 212 | 213 | type ReturnStatement struct { 214 | pos token.Pos 215 | Value Expression 216 | FunctionSymbol *Symbol 217 | } 218 | 219 | type ParameterDeclaration struct { 220 | pos token.Pos 221 | TypeName string 222 | Identifier Expression 223 | } 224 | \end{verbatim} 225 | 226 | \subsection{字句解析} 227 | 字句解析には, goの標準ライブラリのgo/scannerを使った. 字句解析処理は, Lexer構造体のLex()関数に書いており, 構文解析部から逐次 Lex() を呼び出すという仕組みになっている. \\ 228 | 229 | \begin{verbatim} 230 | type Lexer struct { 231 | scanner.Scanner 232 | result []Statement 233 | token Token 234 | pos token.Pos 235 | errMessage string 236 | } 237 | 238 | func (l *Lexer) Lex(lval *yySymType) int { 239 | pos, tok, lit := l.Scan() 240 | token_number := int(tok) 241 | 242 | // 省略 243 | 244 | lval.token = Token{ tok: tok, lit: lit, pos: pos } 245 | l.token = lval.token 246 | 247 | return token_number 248 | } 249 | \end{verbatim} 250 | 251 | \subsection{構文解析} 252 | 253 | 構文解析器にはyaccのGo実装であるgoyacc (https://golang.org/cmd/yacc/)を使用した. 構文定義の文法は本家yaccと同じである. 本家yaccと違うのは, プログラムを記述する場所でCではなくGoで記述できるという点のみだと考えてよい. \\ 254 | 255 | parser.go.yの構文定義の一部を例として示す. 構文解析器から得られたトークン情報などから構造体を順番に組み立てる処理をしている. 
256 | 257 | \begin{verbatim} 258 | statement 259 | : ';' 260 | { 261 | $$ = nil 262 | } 263 | | expression ';' 264 | { 265 | $$ = &ExpressionStatement{ Value: $1 } 266 | } 267 | | compound_statement 268 | | IF '(' expression ')' statement 269 | { 270 | $$ = &IfStatement{ pos: $1.pos, Condition: $3, TrueStatement: $5 } 271 | } 272 | | IF '(' expression ')' statement ELSE statement 273 | { 274 | $$ = &IfStatement{ pos: $1.pos, Condition: $3, TrueStatement: $5, FalseStatement: $7 } 275 | } 276 | | WHILE '(' expression ')' statement 277 | { 278 | $$ = &WhileStatement{ pos: $1.pos, Condition: $3, Statement: $5 } 279 | } 280 | | FOR '(' optional_expression ';' optional_expression ';' optional_expression ')' statement 281 | { 282 | $$ = &ForStatement{ pos: $1.pos, Init: $3, Condition: $5, Loop: $7, Statement: $9 } 283 | } 284 | | RETURN optional_expression ';' 285 | { 286 | $$ = &ReturnStatement{ pos: $1.pos, Value: $2 } 287 | } 288 | \end{verbatim} 289 | 290 | parse.goでは, ほかの部分から呼び出すための構文解析関数Parse()を定義している. yaccから生成された関数を呼び出したり, エラー情報を適切にくっつけたりしている. 291 | 292 | \begin{verbatim} 293 | func Parse(src string) ([]Statement, error) { 294 | fset := token.NewFileSet() 295 | file := fset.AddFile("", fset.Base(), len(src)) 296 | 297 | l := new(Lexer) 298 | l.Init(file, []byte(src), nil, scanner.ScanComments) 299 | yyErrorVerbose = true 300 | 301 | fail := yyParse(l) 302 | if fail == 1 { 303 | lineNumber, columnNumber := posToLineInfo(src, int(l.pos)) 304 | err := fmt.Errorf("%d:%d: %s", lineNumber, columnNumber, l.errMessage) 305 | 306 | return nil, err 307 | } 308 | 309 | return l.result, nil 310 | } 311 | \end{verbatim} 312 | 313 | \section{抽象構文木の変換処理} 314 | 抽象構文木の変換処理は, parse.goのWalk()関数に書いている. \\ 315 | 316 | 構文解析部が返した抽象構文木を再帰的にたどっていき, 置き換えるべき表現を見つけたら変換した構造体を返すという処理をしている. \\ 317 | 318 | for文をwhile文に置き換える処理の例を示す.
319 | 320 | \begin{verbatim} 321 | func Walk(statement Statement) Statement { 322 | switch s := statement.(type) { 323 | case *ForStatement: 324 | // for (init; cond; loop) s 325 | // => init; while (cond) { s; loop; } 326 | body := Walk(s.Statement) 327 | return &CompoundStatement{ 328 | Statements: []Statement{ 329 | &ExpressionStatement{Value: s.Init}, 330 | &WhileStatement{ 331 | pos: s.Pos(), 332 | Condition: s.Condition, 333 | Statement: &CompoundStatement{ 334 | Statements: []Statement{ 335 | body, 336 | &ExpressionStatement{Value: s.Loop}, 337 | }, 338 | }, 339 | }, 340 | }, 341 | } 342 | 343 | \end{verbatim} 344 | 345 | \section{main関数} 346 | main関数では, 構文解析関数 Parse()と抽象構文木を変換する関数Walk()を呼び出して, 結果を出力する処理をしている. ソースファイルを見れば明らかだと思うので, ここでは説明を省略する. 347 | 348 | \section{感想} 349 | Goや型付きの言語での開発に不慣れなせいか, 抽象構文木の構造体をうまく定義するのに苦労した. 構文解析器を書いている途中で構造体の定義がよくないことがわかって何度も書き換えたりした. \\ 350 | 351 | はじめにオブジェクト指向っぽい空気感で書いていたところ, 何度も戸惑ったり痛い目にあったりした. Goは文法がオブジェクト指向言語っぽいが, 継承などの機能はないので実際はオブジェクト指向ではない. \\ 352 | 353 | もう少し気合を入れて開発することで, はやくGoに慣れていきたい. 354 | 355 | \end{document} 356 | -------------------------------------------------------------------------------- /report/2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiur/small-c/e5a2a01bbbc1e93aa3d406f717a8030c281e7812/report/2.pdf -------------------------------------------------------------------------------- /report/2.tex: -------------------------------------------------------------------------------- 1 | \documentclass[a4j]{jarticle} 2 | \usepackage[dvipdfmx]{graphicx} 3 | 4 | \begin{document} 5 | 6 | \title{計算機科学実験3 中間レポート2} 7 | \author{杉本風斗} 8 | 9 | \maketitle 10 | 11 | \section{概要} 12 | 課題9の意味解析と課題10の中間表現への変換を実装した. \\ 13 | まずは具体的な実行例を説明する. 14 | 15 | \subsection{使い方} 16 | 実行方法はレポート1で説明したのと同じである. ソースファイルを置いたディレクトリ内で以下のようにしてビルドして実行できる.
\\ 17 | 18 | \begin{verbatim} 19 | $ make 20 | $ ./small-c program.sc 21 | \end{verbatim} 22 | 23 | 引数にソースファイルを与えて実行すると, 中間表現を表す文字列が出力される. \\ 24 | 簡単なSmall Cプログラムを与えて実行した例を示す. 25 | 26 | \begin{verbatim} 27 | $ cat example/0.sc 28 | int main() { 29 | return 1 + 2; 30 | } 31 | 32 | $ ./small-c example/0.sc 33 | 34 | 35 | main() 36 | int #tmp_0 37 | #tmp_0 = (+ 1 2) 38 | return #tmp_0 39 | 40 | \end{verbatim} 41 | 42 | 出力の形式については中間表現への変換の項であとで説明する. \\ 43 | 44 | 意味解析に失敗するソースコードを入力に与える例を示す. \\ 45 | \begin{verbatim} 46 | $ cat type_error.sc 47 | int main() { 48 | int a; 49 | int *p; 50 | p = a; 51 | } 52 | 53 | $ ./small-c type_error.sc 54 | 4:3: type error: int* = int 55 | 56 | \end{verbatim} 57 | 58 | \begin{verbatim} 59 | $ cat decl_error.sc 60 | int main() { 61 | int a, b; 62 | int a; 63 | } 64 | 65 | $ ./small-c decl_error.sc 66 | 3:7: `a` is already defined 67 | 68 | \end{verbatim} 69 | 70 | \section{課題9: 意味解析} 71 | 72 | \subsection{オブジェクト情報の収集} 73 | サンプルプログラムで動作させた例を示す. 環境変数DEBUGを設定すると, 解析した構文木の内容がpretty printされて表示されるようになっている. \\ 74 | 75 | IdentifierExpression などの Symbol フィールドに型などのオブジェクト情報が埋め込まれていることが確認できる. \\ 76 | 77 | \begin{verbatim} 78 | $ cat example/sum.sc 79 | int sum(int a, int b) { 80 | return a + b; 81 | } 82 | 83 | int main() { 84 | print(sum(100, 20)); 85 | } 86 | 87 | $ DEBUG=1 ./small-c example/sum.sc 88 | ...
89 | []main.Statement{ 90 | &main.FunctionDefinition{ 91 | pos: 1, 92 | TypeName: "void", 93 | Identifier: &main.IdentifierExpression{ 94 | pos: 6, 95 | Name: "print", 96 | Symbol: &main.Symbol{ 97 | Name: "print", 98 | Level: 0, 99 | Kind: "proto", 100 | Type: main.FunctionType{ 101 | Return: main.BasicType{ 102 | Name: "void", 103 | }, 104 | Args: []main.SymbolType{ 105 | main.BasicType{ 106 | Name: "int", 107 | }, 108 | }, 109 | }, 110 | Offset: 0, 111 | }, 112 | }, 113 | Parameters: []main.Expression{ 114 | &main.ParameterDeclaration{ 115 | pos: 12, 116 | TypeName: "int", 117 | Identifier: &main.IdentifierExpression{ 118 | pos: 16, 119 | Name: "i", 120 | Symbol: (*main.Symbol)(nil), 121 | }, 122 | }, 123 | }, 124 | Statement: nil, 125 | }, 126 | &main.FunctionDefinition{ 127 | pos: 1, 128 | TypeName: "int", 129 | Identifier: &main.IdentifierExpression{ 130 | pos: 5, 131 | Name: "sum", 132 | Symbol: &main.Symbol{ 133 | Name: "sum", 134 | Level: 0, 135 | Kind: "fun", 136 | Type: main.FunctionType{ 137 | Return: main.BasicType{ 138 | Name: "int", 139 | }, 140 | Args: []main.SymbolType{ 141 | main.BasicType{ 142 | Name: "int", 143 | }, 144 | main.BasicType{ 145 | Name: "int", 146 | }, 147 | }, 148 | }, 149 | Offset: 0, 150 | }, 151 | }, 152 | Parameters: []main.Expression{ 153 | &main.ParameterDeclaration{ 154 | pos: 9, 155 | TypeName: "int", 156 | Identifier: &main.IdentifierExpression{ 157 | pos: 13, 158 | Name: "a", 159 | Symbol: &main.Symbol{ 160 | Name: "a", 161 | Level: 1, 162 | Kind: "parm", 163 | Type: main.BasicType{ 164 | Name: "int", 165 | }, 166 | Offset: 0, 167 | }, 168 | }, 169 | }, 170 | &main.ParameterDeclaration{ 171 | pos: 16, 172 | TypeName: "int", 173 | Identifier: &main.IdentifierExpression{ 174 | pos: 20, 175 | Name: "b", 176 | Symbol: &main.Symbol{ 177 | Name: "b", 178 | Level: 1, 179 | Kind: "parm", 180 | Type: main.BasicType{ 181 | Name: "int", 182 | }, 183 | Offset: 0, 184 | }, 185 | }, 186 | }, 187 | }, 188 | Statement: 
&main.CompoundStatement{ 189 | pos: 23, 190 | Declarations: []main.Statement{}, 191 | Statements: []main.Statement{ 192 | &main.ReturnStatement{ 193 | pos: 23, 194 | Value: &main.BinaryExpression{ 195 | Left: &main.IdentifierExpression{ 196 | pos: 34, 197 | Name: "a", 198 | Symbol: &main.Symbol{...}, 199 | }, 200 | Operator: "+", 201 | Right: &main.IdentifierExpression{ 202 | pos: 38, 203 | Name: "b", 204 | Symbol: &main.Symbol{...}, 205 | }, 206 | }, 207 | FunctionSymbol: &main.Symbol{ 208 | Name: "#func", 209 | Level: 1, 210 | Kind: "", 211 | Type: main.FunctionType{ 212 | Return: main.BasicType{ 213 | Name: "int", 214 | }, 215 | Args: []main.SymbolType{...}, 216 | }, 217 | Offset: 0, 218 | }, 219 | }, 220 | }, 221 | }, 222 | }, 223 | &main.FunctionDefinition{ 224 | pos: 44, 225 | TypeName: "int", 226 | Identifier: &main.IdentifierExpression{ 227 | pos: 48, 228 | Name: "main", 229 | Symbol: &main.Symbol{ 230 | Name: "main", 231 | Level: 0, 232 | Kind: "fun", 233 | Type: main.FunctionType{ 234 | Return: main.BasicType{ 235 | Name: "int", 236 | }, 237 | Args: []main.SymbolType{}, 238 | }, 239 | Offset: 0, 240 | }, 241 | }, 242 | Parameters: []main.Expression{}, 243 | Statement: &main.CompoundStatement{ 244 | pos: 55, 245 | Declarations: []main.Statement{}, 246 | Statements: []main.Statement{ 247 | &main.ExpressionStatement{ 248 | Value: &main.FunctionCallExpression{ 249 | Identifier: &main.IdentifierExpression{ 250 | pos: 59, 251 | Name: "print", 252 | Symbol: &main.Symbol{...}, 253 | }, 254 | Argument: &main.FunctionCallExpression{ 255 | Identifier: &main.IdentifierExpression{ 256 | pos: 65, 257 | Name: "sum", 258 | Symbol: &main.Symbol{...}, 259 | }, 260 | Argument: &main.ExpressionList{ 261 | Values: []main.Expression{ 262 | &main.NumberExpression{ 263 | pos: 69, 264 | Value: "100", 265 | }, 266 | &main.NumberExpression{ 267 | pos: 74, 268 | Value: "20", 269 | }, 270 | }, 271 | }, 272 | }, 273 | }, 274 | }, 275 | }, 276 | }, 277 | }, 278 | } 279 | 280 | ... 
281 | \end{verbatim} 282 | 283 | \subsubsection{実装} 284 | オブジェクト情報の収集は, analyze.goのAnalyze()関数に実装されている. \\ 285 | 286 | Analyze()関数は入力された抽象構文木を再帰的に辿り, オブジェクト情報の収集や式の形の検査を行う. 抽象構文木の関数定義ノードを解析する関数の例を説明する. \\ 287 | 288 | 関数定義に含まれる関数名やパラメータを見て, オブジェクト情報を環境に登録したり, 型情報を取得する処理を行っている. \\ 289 | 290 | エラーを発見した場合, その場で処理を中断し例外を吐くことはしないで, エラーを配列に格納して解析処理を続けている. 複数のエラーがソースコードにあった場合に, 一度に見えるようにするためである. 出力側では特に工夫していないので, 大量の同じようなエラーが出力されることがあるが, とりあえずあまり気にしないことにしている. \\ 291 | 292 | また, 関数定義がプロトタイプ宣言でない場合には内容のstatementが含まれるので, statementを解析する関数を呼び出している. \\ 293 | 294 | \begin{verbatim} 295 | func analyzeFunctionDefinition(s *FunctionDefinition, env *Env) []error { 296 | errs := []error{} 297 | 298 | identifier := findIdentifierExpression(s.Identifier) 299 | argTypes := []SymbolType{} 300 | 301 | for _, p := range s.Parameters { 302 | parameter, ok := p.(*ParameterDeclaration) 303 | if ok { 304 | argType := BasicType{Name: parameter.TypeName} 305 | argTypes = append(argTypes, composeType(parameter.Identifier, argType)) 306 | } 307 | } 308 | 309 | returnType := BasicType{Name: s.TypeName} 310 | symbolType := FunctionType{Return: returnType, Args: argTypes} 311 | 312 | kind := "" 313 | if s.Statement != nil { 314 | kind = "fun" 315 | } else { 316 | kind = "proto" 317 | } 318 | 319 | err := env.Register(identifier, &Symbol{ 320 | Kind: kind, 321 | Type: symbolType, 322 | }) 323 | 324 | if err != nil { 325 | errs = append(errs, SemanticError{ 326 | Pos: s.Pos(), 327 | Err: err, 328 | }) 329 | } 330 | 331 | if s.Statement != nil { 332 | paramEnv := env.CreateChild() 333 | // Set special symbol to analyze function type 334 | paramEnv.Add(&Symbol{ 335 | Name: "#func", 336 | Type: symbolType, 337 | }) 338 | 339 | for _, p := range s.Parameters { 340 | parameter, ok := p.(*ParameterDeclaration) 341 | 342 | if ok { 343 | identifier := findIdentifierExpression(parameter.Identifier) 344 | argType := composeType(parameter.Identifier, BasicType{Name: parameter.TypeName}) 345 | 346 | err := 
paramEnv.Register(identifier, &Symbol{ 347 | Kind: "parm", 348 | Type: argType, 349 | }) 350 | 351 | if err != nil { 352 | errs = append(errs, SemanticError{ 353 | Pos: parameter.Pos(), 354 | Err: fmt.Errorf("parameter `%s` is already defined", identifier.Name), 355 | }) 356 | } 357 | } 358 | } 359 | 360 | errs = append(errs, analyzeStatement(s.Statement, paramEnv)...) 361 | } 362 | 363 | return errs 364 | } 365 | \end{verbatim} 366 | 367 | \subsection{重複定義の検査,式の形の検査} 368 | 重複定義の検査,式の形の検査は, オブジェクト情報の収集と同じ Analyze() 関数で行っている. オブジェクト情報の収集と同じであるので, 実装の説明は省略して, 実行例をしめす. \\ 369 | 370 | \begin{verbatim} 371 | $ cat error.sc 372 | int f() {} 373 | 374 | int main() { 375 | int a, b, a; 376 | int *p; 377 | f = 100; 378 | p = &(b + 10); 379 | } 380 | 381 | $ ./small-c error.sc 382 | 4:13: `a` is already defined 383 | 5:3: `f` is not variable 384 | 5:3: `f` is not variable 385 | 6:9: the operand of `&` must be on memory 386 | 387 | \end{verbatim} 388 | 389 | \subsection{型検査} 390 | 型検査の実行例は概要に示したとおりである. \\ 391 | 392 | 型検査はtype.goに実装されている. mainからはCheckType()関数を呼び出している. CheckType()は解析済みの抽象構文木を受け取って, 型エラーがあった場合エラーオブジェクトを返す. \\ 393 | 394 | Analyze()関数と同様に, 再帰的に構文木を辿って, 埋め込まれたオブジェクト情報から型を順番に調べている. \\ 395 | 396 | \begin{verbatim} 397 | func CheckType(statements []Statement) error { 398 | for _, s := range statements { 399 | err := CheckTypeOfStatement(s) 400 | if err != nil { 401 | return err 402 | } 403 | } 404 | 405 | return nil 406 | } 407 | \end{verbatim} 408 | 409 | \section{課題10: 中間表現への変換} 410 | 実行例は概要で示したとおりである. \\ 411 | 412 | 中間表現への変換処理は ir.go に実装されている. mainからは CompileIR() を呼び出している. CompileIR() は解析済みの抽象構文木を受け取って, 中間表現プログラムの構造体を返す. \\ 413 | 414 | \subsection{中間表現の構造体の定義} 415 | 416 | 中間表現の構造体定義の例を示す. 講義資料の説明に沿う形で定義している. 
417 | 418 | \begin{verbatim} 419 | type IRProgram struct { 420 | Declarations []*IRVariableDeclaration 421 | Functions []*IRFunctionDefinition 422 | } 423 | 424 | type IRStatement interface { 425 | String() string 426 | } 427 | type IRExpression interface { 428 | String() string 429 | } 430 | 431 | type IRVariableDeclaration struct { 432 | Var *Symbol 433 | } 434 | 435 | type IRFunctionDefinition struct { 436 | Var *Symbol 437 | Parameters []*IRVariableDeclaration 438 | Body IRStatement 439 | VarSize int 440 | } 441 | 442 | type IRCompoundStatement struct { 443 | Declarations []*IRVariableDeclaration 444 | Statements []IRStatement 445 | } 446 | 447 | type IRAssignmentStatement struct { 448 | Var *Symbol 449 | Expression IRExpression 450 | } 451 | \end{verbatim} 452 | 453 | \subsection{中間表現の文字列出力} 454 | 中間表現の構造体には, 出力用に簡単な文字列に変換するString()メソッドが実装されている. \\ 455 | 456 | たとえば, プログラムを実行した際にはこのように表示される. コード中にコメントをつけて説明する. \\ 457 | 458 | 文字列出力では, 複合文の表示は省略されている. 複合文を表示してしまうととても読みにくくなってしまうためである. \\ 459 | 460 | \begin{verbatim} 461 | sum(int a, int b) // 関数の定義 462 | int #tmp_0 // 変数定義. #がついた名前はコンパイラが実装の都合で生成した一時変数を表している 463 | #tmp_0 = (+ a b) // 式はS式で表現している 464 | return #tmp_0 465 | 466 | main() 467 | int #tmp_2 468 | int #tmp_3 469 | int #tmp_4 470 | int #tmp_1 471 | #tmp_2 = 100 472 | #tmp_3 = 20 473 | #tmp_4 = sum(#tmp_2, #tmp_3) // 関数呼び出し 474 | #tmp_1 = #tmp_4 475 | print(#tmp_1) 476 | \end{verbatim} 477 | 478 | \subsection{変換処理の実装} 479 | 抽象構文木を再帰的に辿り, 文や式を変換する実装になっている. \\ 480 | 481 | 文を変換するとき, 一時変数が新たに必要になった場合には, その一時変数の定義を含んだ複合文に変換している. \\ 482 | 483 | 式を変換するときは, 式をそのまま複合文に変換することはできないので, 式を中間表現に変換した結果とともに, 必要な一時変数の宣言のリストと, その式より前に実行すべき文のリストを返している. 式で必要な一時変数は, 呼び出し元の文の変換処理で複合文にまとめて格納される. \\ 484 | 485 | If文の中間表現への変換を例に示す. \\ 486 | 487 | tmpvar()はuniqueな名前の一時変数を返し, label()はジャンプラベル用のuniqueな文字列を返す. \\ 488 | 489 | 条件式を条件判定用の一時変数に格納するようにし, 中間表現の制御構文を使って中間表現の列に変換している. 
条件式用に一時変数が必要になるので, その一時変数の宣言を含めた複合文に変換して返すようにしている.
\\ 490 | 491 | \begin{verbatim} 492 | func compileIRStatement(statement Statement) IRStatement { 493 | ... 494 | case *IfStatement: 495 | conditionVar := tmpvar() 496 | 497 | trueLabel := label("true") 498 | falseLabel := label("false") 499 | endLabel := label("end") 500 | 501 | condition, decls, beforeCondition := compileIRExpression(s.Condition) 502 | 503 | statements := []IRStatement{ 504 | &IRAssignmentStatement{ 505 | Var: conditionVar, 506 | Expression: condition, 507 | }, 508 | &IRIfStatement{ 509 | Var: conditionVar, 510 | TrueLabel: trueLabel, 511 | FalseLabel: falseLabel, 512 | }, 513 | &IRLabelStatement{ Name: trueLabel }, 514 | compileIRStatement(s.TrueStatement), 515 | &IRGotoStatement{ Label: endLabel }, 516 | &IRLabelStatement{ Name: falseLabel }, 517 | } 518 | 519 | if s.FalseStatement != nil { 520 | statements = append(statements, compileIRStatement(s.FalseStatement)) 521 | } 522 | 523 | statements = append(statements, &IRLabelStatement{ Name: endLabel }) 524 | 525 | return &IRCompoundStatement{ 526 | Declarations: append(IRVariableDeclarations([]*Symbol{conditionVar}), decls...), 527 | Statements: append(beforeCondition, statements...), 528 | } 529 | ... 530 | } 531 | \end{verbatim} 532 | 533 | \section{感想と工夫した点} 534 | デバッグの際に正しい中間表現が生成されていることを確認するのに苦労した. \\ 535 | 536 | 中間表現の変換では一時変数を大量に使うので, 変換結果の中間表現は複合文が大量にネストしたオブジェクトになってしまい, そのまま表示するとかなり読みにくい. \\ 537 | 538 | そこで, 中間表現を簡単な文字列に変換する処理をつくったところ, デバッグがかなりやりやすくなった. 実装がちょっとめんどくさい気がしたけど, 結果的には時間の節約になって嬉しかった. 
\\ 539 | 540 | \end{document} 541 | -------------------------------------------------------------------------------- /report/final.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiur/small-c/e5a2a01bbbc1e93aa3d406f717a8030c281e7812/report/final.pdf -------------------------------------------------------------------------------- /report/final.tex: -------------------------------------------------------------------------------- 1 | \documentclass[a4j]{jarticle} 2 | \usepackage[dvipdfmx]{graphicx} 3 | 4 | \begin{document} 5 | 6 | \title{計算機科学実験3 最終レポート} 7 | \author{杉本風斗} 8 | 9 | \maketitle 10 | 11 | \section{概要} 12 | 動作はSmall Cの動作仕様に準じている. \\ 13 | makeで実行プログラムをビルドし, 引数にソースファイルを渡すことでMIPSコードを生成する. 14 | 15 | \begin{verbatim} 16 | $ make 17 | $ ./small-c program.sc 18 | \end{verbatim} 19 | 20 | また, make test でユニットテストやspimエミュレータを用いた統合テストを実行できるように工夫した. テスト実行時にコードカバレッジも測定される. 21 | 22 | \begin{verbatim} 23 | $ make test 24 | \end{verbatim} 25 | 26 | \section{課題11: データフロー解析} 27 | 到達可能定義解析を実装した. \\ 28 | データフロー解析と最適化の処理は, optimize.goのOptimize()関数に実装している. \\ 29 | 30 | 以下にコードを示す. 簡単な説明をコード中のコメントにつけた. \\ 31 | 32 | \begin{verbatim} 33 | // optimize.go 34 | type DataflowBlock struct { 35 | Name string // BEGIN, END 36 | Statements []IRStatement 37 | Next []*DataflowBlock 38 | Prev []*DataflowBlock 39 | } 40 | 41 | func Optimize(program *IRProgram) *IRProgram { 42 | for i, f := range program.Functions { 43 | statements := flatStatement(f) 44 | 45 | // 中間表現プログラム列をデータフローのブロックごとに分ける 46 | blocks := splitStatementsIntoBlocks(statements) 47 | 48 | // ブロックの配列からデータフローを構成 49 | // blockそれぞれについて, block.Nextを設定していく 50 | buildDataflowGraph(blocks) 51 | 52 | // データフローを見て不動点反復により到達可能定義解析する 53 | // 返り値はブロックごとに, 各シンボルの到達可能な定義文 を入れたmap 54 | // blockOut = (DataflowBlock -> (*Symbol -> []IRStatement)) 55 | blockOut := searchReachingDefinitions(blocks) 56 | 57 | // ... 
58 | } 59 | 60 | return program 61 | } 62 | 63 | // 不動点反復なので、状態が収束するまで地道に解析して状態を更新していくという雰囲気 64 | func searchReachingDefinitions(blocks []*DataflowBlock) map[*DataflowBlock]BlockState { 65 | blockOut := make(map[*DataflowBlock]BlockState) 66 | 67 | changed := true 68 | for changed { 69 | changed = false 70 | 71 | for _, block := range blocks { 72 | inState := analyzeBlock(blockOut, block) 73 | if !inState.Equal(blockOut[block]) { 74 | changed = true 75 | } 76 | 77 | blockOut[block] = inState 78 | } 79 | } 80 | 81 | return blockOut 82 | } 83 | 84 | // ひとつのプログラム点を見て状態を更新する 85 | // 到達可能定義解析の実質的な処理 86 | func analyzeReachingDefinition(statement IRStatement, inState BlockState) BlockState { 87 | switch s := statement.(type) { 88 | case *IRAssignmentStatement: 89 | inState[s.Var] = []IRStatement{s} 90 | symbols := extractAddressVarsFromExpression(s.Expression) 91 | for _, symbol := range symbols { 92 | inState[symbol] = append(inState[symbol], s) 93 | } 94 | 95 | case *IRReadStatement: 96 | inState[s.Dest] = []IRStatement{s} 97 | 98 | // ポインタ参照書き込みがあったら, 諦めムードにしておく 99 | case *IRWriteStatement: 100 | for symbol := range inState { 101 | inState[symbol] = append(inState[symbol], s) 102 | } 103 | 104 | case *IRCallStatement: 105 | inState[s.Dest] = []IRStatement{s} 106 | 107 | } 108 | 109 | return inState 110 | } 111 | \end{verbatim} 112 | 113 | \section{課題12: 最適化} 114 | 定数畳み込みと無駄な命令の除去を実装した. \\ 115 | 116 | 両方とも, 到達可能定義解析で得た情報を利用して実装している. 117 | 118 | \begin{verbatim} 119 | func Optimize(program *IRProgram) *IRProgram { 120 | for i, f := range program.Functions { 121 | // ... 
122 | blockOut := searchReachingDefinitions(blocks) 123 | 124 | // 実装の都合で文ごとの到達可能定義を計算しなおしている 125 | allStatementState := reachingDefinitionsOfStatements(blocks, blockOut, statements) 126 | 127 | // 定数畳み込み 128 | program.Functions[i] = transformByConstantFolding(program.Functions[i], allStatementState) 129 | // 無駄コード除去 130 | program.Functions[i] = transformByDeadCodeElimination(program.Functions[i], allStatementState) 131 | } 132 | 133 | return program 134 | } 135 | \end{verbatim} 136 | 137 | \subsection{定数畳み込み} 138 | 中間表現の代入文に対して, 再帰的に定数畳み込みを行う. オペランドが両方とも定数の演算子を発見した場合, 直接計算結果を埋め込む. \\ 139 | 140 | 以下にコードを示す. 141 | 142 | \begin{verbatim} 143 | func transformByConstantFolding(f *IRFunctionDefinition, allStatementState map[IRStatement]BlockState) *IRFunctionDefinition { 144 | traversed := Traverse(f, func(statement IRStatement) IRStatement { 145 | foldConstantStatement(statement, allStatementState) 146 | return statement 147 | }) 148 | 149 | return traversed.(*IRFunctionDefinition) 150 | } 151 | 152 | // 代入文ならexpressionを見て、それが定数だったら埋め込む 153 | func foldConstantStatement(statement IRStatement, allStatementState map[IRStatement]BlockState) (bool, int) { 154 | switch s := statement.(type) { 155 | case *IRAssignmentStatement: 156 | isConstant, value := foldConstantExpression(s, s.Expression, allStatementState) 157 | if isConstant { 158 | s.Expression = &IRNumberExpression{Value: value} 159 | return true, value 160 | } 161 | } 162 | 163 | return false, 0 164 | } 165 | 166 | // 到達可能定義の情報を使って、再帰的に定数畳み込みしていく 167 | func foldConstantExpression(statement IRStatement, expression IRExpression, allStatementState map[IRStatement]BlockState) (bool, int) { 168 | switch e := expression.(type) { 169 | case *IRNumberExpression: 170 | return true, e.Value 171 | 172 | case *IRVariableExpression: 173 | symbol := e.Var 174 | definitionOfVar := allStatementState[statement][symbol] 175 | if len(definitionOfVar) == 1 && definitionOfVar[0] != statement { 176 | return 
foldConstantStatement(definitionOfVar[0], allStatementState) 177 | } 178 | 179 | return false, 0 180 | 181 | case *IRBinaryExpression: 182 | leftIsConstant, leftValue := foldConstantExpression(statement, e.Left, allStatementState) 183 | rightIsConstant, rightValue := foldConstantExpression(statement, e.Right, allStatementState) 184 | 185 | if leftIsConstant { 186 | e.Left = &IRNumberExpression{Value: leftValue} 187 | } 188 | 189 | if rightIsConstant { 190 | e.Right = &IRNumberExpression{Value: rightValue} 191 | } 192 | 193 | if leftIsConstant && rightIsConstant { 194 | switch e.Operator { 195 | case "+": 196 | return true, leftValue + rightValue 197 | 198 | case "-": 199 | return true, leftValue - rightValue 200 | 201 | case "*": 202 | return true, leftValue * rightValue 203 | 204 | case "/": 205 | return true, leftValue / rightValue 206 | 207 | case "<": 208 | value := 0 209 | if leftValue < rightValue { 210 | value = 1 211 | } 212 | return true, value 213 | 214 | case ">": 215 | value := 0 216 | if leftValue > rightValue { 217 | value = 1 218 | } 219 | return true, value 220 | 221 | case "<=": 222 | value := 0 223 | if leftValue <= rightValue { 224 | value = 1 225 | } 226 | return true, value 227 | 228 | case ">=": 229 | value := 0 230 | if leftValue >= rightValue { 231 | value = 1 232 | } 233 | return true, value 234 | 235 | case "==": 236 | value := 0 237 | if leftValue == rightValue { 238 | value = 1 239 | } 240 | return true, value 241 | 242 | case "!=": 243 | value := 0 244 | if leftValue != rightValue { 245 | value = 1 246 | } 247 | return true, value 248 | 249 | } 250 | 251 | panic("unexpected operator: " + e.Operator) 252 | } 253 | 254 | return false, 0 255 | } 256 | 257 | return false, 0 258 | } 259 | \end{verbatim} 260 | 261 | \subsection{無駄な命令の除去} 262 | 263 | 文を使用しているかどうかを到達可能定義を用いて記録し, 無駄な文を発見したら消す操作を収束するまで繰り返す. 無駄な命令を除去した結果, 不要になった変数宣言を最後に削除している. 
264 | 265 | \begin{verbatim} 266 | func transformByDeadCodeElimination(f *IRFunctionDefinition, allStatementState map[IRStatement]BlockState) *IRFunctionDefinition { 267 | changed := true 268 | for changed { 269 | changed = false 270 | 271 | used := make(map[IRStatement]bool) 272 | markAsUsed := func(s IRStatement, symbol *Symbol) { 273 | for _, definition := range allStatementState[s][symbol] { 274 | used[definition] = true 275 | } 276 | } 277 | 278 | Traverse(f, func(statement IRStatement) IRStatement { 279 | switch s := statement.(type) { 280 | case *IRCompoundStatement: 281 | used[s] = true 282 | 283 | case *IRAssignmentStatement: 284 | if s.Var.IsGlobal() { 285 | used[s] = true 286 | } 287 | 288 | vars := extractVarsFromExpression(s.Expression) 289 | for _, v := range vars { 290 | markAsUsed(s, v) 291 | } 292 | 293 | case *IRReadStatement: 294 | if s.Dest.IsGlobal() { 295 | used[s] = true 296 | } 297 | 298 | markAsUsed(s, s.Src) 299 | 300 | case *IRWriteStatement: 301 | markAsUsed(s, s.Src) 302 | markAsUsed(s, s.Dest) 303 | 304 | case *IRCallStatement: 305 | if s.Dest.IsGlobal() { 306 | used[s] = true 307 | } 308 | 309 | for _, argVar := range s.Vars { 310 | markAsUsed(s, argVar) 311 | } 312 | 313 | case *IRSystemCallStatement: 314 | markAsUsed(s, s.Var) 315 | 316 | case *IRReturnStatement: 317 | markAsUsed(s, s.Var) 318 | 319 | case *IRIfStatement: 320 | markAsUsed(s, s.Var) 321 | } 322 | 323 | return statement 324 | }) 325 | 326 | transformed := Traverse(f, func(statement IRStatement) IRStatement { 327 | switch statement.(type) { 328 | case *IRAssignmentStatement, *IRReadStatement: 329 | if !used[statement] { 330 | changed = true 331 | return nil 332 | } 333 | } 334 | 335 | return statement 336 | }) 337 | 338 | f = transformed.(*IRFunctionDefinition) 339 | } 340 | 341 | return removeUnusedVariableDeclaration(f) 342 | } 343 | \end{verbatim} 344 | 345 | \subsection{最適化の効果} 346 | 最適化の効果を例を用いて説明する. 
\\ 347 | 348 | \begin{verbatim} 349 | // demo/optimize_constant.sc 350 | int main() { 351 | int a, b; 352 | int c; 353 | c = 3; 354 | 355 | a = c; // 3 356 | b = a + c; // 3 + 3 357 | print(a + b == 9); // 3 + 6 == 9 358 | } 359 | \end{verbatim} 360 | 361 | 比較用に最適化を無効化するオプションをつけてコードを生成する. 362 | 363 | \begin{verbatim} 364 | $ ./small-c -optimize=false demo/optimize_constant.sc > demo/optimize_constant.s 365 | $ ./small-c demo/optimize_constant.sc > demo/optimize_constant_optimized.s 366 | \end{verbatim} 367 | 368 | \begin{verbatim} 369 | # 最適化前 370 | $ spim -show_stats -f demo/optimize_constant.s 371 | Loaded: /usr/local/share/spim/exceptions.s 372 | 1 373 | --- Summary --- 374 | # of executed instructions 375 | - Total: 47 376 | - Memory: 21 377 | - Others: 26 378 | 379 | --- Details --- 380 | add 2 381 | addi 9 382 | addiu 2 383 | addu 1 384 | beq 1 385 | jal 1 386 | jr 1 387 | lw 12 388 | ori 5 389 | sll 2 390 | sw 9 391 | syscall 2 392 | 393 | # 最適化後 394 | $ spim -show_stats -f demo/optimize_constant_optimized.s 395 | Loaded: /usr/local/share/spim/exceptions.s 396 | 1 397 | --- Summary --- 398 | # of executed instructions 399 | - Total: 22 400 | - Memory: 7 401 | - Others: 15 402 | 403 | --- Details --- 404 | addi 3 405 | addiu 2 406 | addu 1 407 | jal 1 408 | jr 1 409 | lw 4 410 | ori 3 411 | sll 2 412 | sw 3 413 | syscall 2 414 | 415 | \end{verbatim} 416 | 417 | 合計命令数が 47 から 22 まで削減された. このように定数畳み込みと無駄な命令の除去を組み合わせると大きな効果がある場合がある. \\ 418 | 419 | ただし, 今回の簡単な最適化では, ポインタが多く用いられるようなプログラムではそれほど効果は得られない. ポインタ経由の書き込みがあると, 到達可能定義解析ですべての変数についてその書き込みも定義の候補に加えるため, 定義が一つに定まらず定数畳み込みができなくなるからである. 420 | 421 | \section{課題13: 相対番地の計算} 422 | 相対番地の計算は, compile.goの CalculateOffset() 関数に実装している. \\ 423 | 複合文を再帰的に探して, それに含まれる変数宣言に対して相対番地を計算している. \\ 424 | グローバル変数の場合は, グローバルポインタ \texttt{\$gp} からの相対番地を計算する.
425 | 426 | \begin{verbatim} 427 | func CalculateOffset(ir *IRProgram) { 428 | globalOffset := 0 429 | // global vars 430 | for _, d := range ir.Declarations { 431 | size := d.Var.Type.ByteSize() 432 | globalOffset -= size 433 | d.Var.Offset = globalOffset 434 | } 435 | 436 | for _, f := range ir.Functions { 437 | calculateOffsetFunction(f) 438 | } 439 | } 440 | 441 | func calculateOffsetFunction(ir *IRFunctionDefinition) { 442 | offset := 0 443 | 444 | for i := len(ir.Parameters) - 1; i >= 0; i-- { 445 | p := ir.Parameters[i] 446 | size := p.Var.Type.ByteSize() 447 | 448 | // arg 4 => 4($fp), arg 5 => 8($fp) 449 | if i >= 4 { 450 | p.Var.Offset = (i - 3) * size 451 | } else { 452 | p.Var.Offset = offset - (size - 4) 453 | offset -= size 454 | } 455 | } 456 | 457 | minOffset := calculateOffsetStatement(ir.Body, offset) 458 | ir.VarSize = -minOffset 459 | } 460 | 461 | func calculateOffsetStatement(statement IRStatement, base int) int { 462 | offset := base 463 | minOffset := 0 464 | 465 | switch s := statement.(type) { 466 | case *IRCompoundStatement: 467 | for _, d := range s.Declarations { 468 | size := d.Var.Type.ByteSize() 469 | d.Var.Offset = offset - (size - 4) 470 | offset -= size 471 | } 472 | 473 | minOffset = offset 474 | for _, s := range s.Statements { 475 | statementOffset := calculateOffsetStatement(s, offset) 476 | 477 | if statementOffset < minOffset { 478 | minOffset = statementOffset 479 | } 480 | } 481 | } 482 | 483 | return minOffset 484 | } 485 | \end{verbatim} 486 | 487 | \section{課題15: コード生成} 488 | コード生成は, compile.goのCompile()関数に実装している. \\ 489 | Compile() は中間表現プログラムを受け取り, 生成したMIPSコードを文字列として返す. \\ 490 | 491 | 変数参照では, グローバル変数の場合は \texttt{\$gp}, それ以外の場合は \texttt{\$fp} をベースポインタとし, 計算した相対番地を使って参照するようにしている.
492 | 493 | \begin{verbatim} 494 | // Compile takes ir program as input and returns mips code 495 | func Compile(program *IRProgram) string { 496 | CalculateOffset(program) 497 | 498 | code := "" 499 | code += ".data\n" 500 | code += ".text\n.globl main\n" 501 | for _, f := range program.Functions { 502 | code += "\n" + strings.Join(compileFunction(f), "\n") + "\n" 503 | } 504 | 505 | return code 506 | } 507 | 508 | func compileFunction(function *IRFunctionDefinition) []string { 509 | size := function.VarSize + 4*2 // arguments + local vars + $ra + $fp 510 | 511 | var code []string 512 | code = append( 513 | code, 514 | fmt.Sprintf("%s:", function.Var.Name), 515 | fmt.Sprintf("addi $sp, $sp, %d", -size), 516 | "sw $ra, 4($sp)", 517 | "sw $fp, 0($sp)", 518 | fmt.Sprintf("addi $fp, $sp, %d", size-4), 519 | ) 520 | 521 | for i := len(function.Parameters) - 1; i >= 0; i-- { 522 | p := function.Parameters[i] 523 | // arg 4,5,6... is passed via 4($fp), 8($fp), ... 524 | if i < 4 { 525 | code = append(code, fmt.Sprintf("sw $a%d, %d($fp)", i, p.Var.Offset)) 526 | } 527 | } 528 | 529 | code = append(code, compileStatement(function.Body, function)...) 530 | 531 | code = append( 532 | code, 533 | function.Var.Name+"_exit:", 534 | "lw $fp, 0($sp)", 535 | "lw $ra, 4($sp)", 536 | fmt.Sprintf("addi $sp, $sp, %d", size), 537 | "jr $ra", 538 | ) 539 | 540 | return code 541 | } 542 | 543 | func compileStatement(statement IRStatement, function *IRFunctionDefinition) []string { 544 | var code []string 545 | 546 | switch s := statement.(type) { 547 | case *IRCompoundStatement: 548 | for _, statement := range s.Statements { 549 | code = append(code, compileStatement(statement, function)...) 550 | } 551 | 552 | case *IRAssignmentStatement: 553 | code = append(code, assignExpression("$t0", s.Expression)...) 
554 | code = append(code, sw("$t0", s.Var)) 555 | 556 | case *IRCallStatement: 557 | for i := len(s.Vars) - 1; i >= 0; i-- { 558 | v := s.Vars[i] 559 | 560 | if i >= 4 { 561 | code = append(code, lw("$t0", v)) 562 | code = append(code, 563 | "addi $sp, $sp, -4", 564 | fmt.Sprintf("sw %s, 0($sp)", "$t0"), 565 | ) 566 | } else { 567 | code = append(code, lw(fmt.Sprintf("$a%d", i), v)) 568 | } 569 | } 570 | 571 | code = append(code, fmt.Sprintf("jal %s", s.Func.Name)) 572 | if len(s.Vars) > 4 { 573 | code = append(code, fmt.Sprintf("addi $sp, $sp, %d", 4*(len(s.Vars)-4))) 574 | } 575 | code = append(code, sw("$v0", s.Dest)) 576 | 577 | case *IRReturnStatement: 578 | if s.Var != nil { 579 | code = append(code, 580 | lw("$v0", s.Var), 581 | ) 582 | } 583 | 584 | code = append(code, 585 | fmt.Sprintf("j %s_exit", function.Var.Name), 586 | ) 587 | 588 | case *IRWriteStatement: 589 | return []string{ 590 | lw("$t0", s.Src), 591 | lw("$t1", s.Dest), 592 | "sw $t0, 0($t1)", 593 | } 594 | 595 | case *IRReadStatement: 596 | return []string{ 597 | lw("$t0", s.Src), 598 | "lw $t1, 0($t0)", 599 | sw("$t1", s.Dest), 600 | } 601 | 602 | case *IRLabelStatement: 603 | return append(code, s.Name+":") 604 | 605 | case *IRIfStatement: 606 | falseLabel := label("ir_if_false") 607 | endLabel := label("ir_if_end") 608 | 609 | code = append(code, 610 | lw("$t0", s.Var), 611 | fmt.Sprintf("beq $t0, $zero, %s", falseLabel), 612 | ) 613 | 614 | if len(s.TrueLabel) > 0 { 615 | code = append(code, 616 | fmt.Sprintf("j %s", s.TrueLabel), 617 | ) 618 | } else { 619 | code = append(code, 620 | fmt.Sprintf("j %s", endLabel), 621 | ) 622 | } 623 | 624 | code = append(code, 625 | falseLabel+":", 626 | ) 627 | 628 | if len(s.FalseLabel) > 0 { 629 | code = append(code, 630 | fmt.Sprintf("j %s", s.FalseLabel), 631 | ) 632 | } 633 | 634 | code = append(code, 635 | endLabel+":", 636 | ) 637 | 638 | case *IRGotoStatement: 639 | code = append(code, jmp(s.Label)) 640 | 641 | case *IRSystemCallStatement: 642 | 
switch s.Name { 643 | case "print": 644 | return []string{ 645 | "li $v0, 1", 646 | lw("$a0", s.Var), 647 | "syscall", 648 | } 649 | case "putchar": 650 | return []string{ 651 | "li $v0, 11", 652 | lw("$a0", s.Var), 653 | "syscall", 654 | } 655 | 656 | default: 657 | panic("invalid system call: " + s.Name) 658 | 659 | } 660 | } 661 | 662 | return code 663 | } 664 | 665 | func assignExpression(register string, expression IRExpression) []string { 666 | var code []string 667 | 668 | switch e := expression.(type) { 669 | case *IRNumberExpression: 670 | code = append(code, fmt.Sprintf("li %s, %d", register, e.Value)) 671 | 672 | case *IRBinaryExpression: 673 | leftRegister := "$t1" 674 | rightRegister := "$t2" 675 | 676 | code = append(code, assignExpression(leftRegister, e.Left)...) 677 | code = append(code, 678 | "addi $sp, $sp, -4", 679 | fmt.Sprintf("sw %s, 0($sp)", leftRegister), 680 | ) 681 | code = append(code, assignExpression(rightRegister, e.Right)...) 682 | code = append(code, 683 | fmt.Sprintf("lw %s, 0($sp)", leftRegister), 684 | "addi $sp, $sp, 4", 685 | ) 686 | 687 | operation := assignBinaryOperation(register, e.Operator, leftRegister, rightRegister) 688 | 689 | return append(code, operation...) 
690 | 691 | case *IRVariableExpression: 692 | // *(a + 4) 693 | _, isArrayType := e.Var.Type.(ArrayType) 694 | if isArrayType { 695 | return []string{ 696 | fmt.Sprintf("addi %s, %s, %d", register, e.Var.AddressPointer(), e.Var.Offset), 697 | } 698 | } 699 | 700 | return append(code, lw(register, e.Var)) 701 | 702 | case *IRAddressExpression: 703 | return []string{ 704 | fmt.Sprintf("addi %s, %s, %d", register, e.Var.AddressPointer(), e.Var.Offset), 705 | } 706 | } 707 | 708 | return code 709 | } 710 | // assignBinaryOperation emits MIPS code that computes `left operator right` into register; comparison operators produce 0 or 1. 711 | func assignBinaryOperation(register string, operator string, left string, right string) []string { 712 | inst := operatorToInst[operator] 713 | if len(inst) > 0 { 714 | return []string{ 715 | fmt.Sprintf("%s %s, %s, %s", inst, register, left, right), 716 | } 717 | } 718 | 719 | switch operator { 720 | case "==": 721 | equalLabel := label("beq_true") 722 | endLabel := label("beq_end") 723 | // Use the left/right parameters rather than hard-coded $t1/$t2 so callers may pass any registers. 724 | return []string{ 725 | fmt.Sprintf("beq %s, %s, %s", left, right, equalLabel), 726 | li(register, 0), 727 | fmt.Sprintf("j %s", endLabel), 728 | equalLabel + ":", 729 | li(register, 1), 730 | endLabel + ":", 731 | } 732 | 733 | case "!=": 734 | equalLabel := label("beq_true") 735 | endLabel := label("beq_end") 736 | // Same branch as "==", with the 0/1 results swapped. 737 | return []string{ 738 | fmt.Sprintf("beq %s, %s, %s", left, right, equalLabel), 739 | li(register, 1), 740 | fmt.Sprintf("j %s", endLabel), 741 | equalLabel + ":", 742 | li(register, 0), 743 | endLabel + ":", 744 | } 745 | 746 | case ">": 747 | // a > b <=> b < a 748 | return []string{ 749 | fmt.Sprintf("slt %s, %s, %s", register, right, left), 750 | } 751 | 752 | case "<=": 753 | // a <= b <=> !(b < a); unlike "a - 1 < b" this cannot trigger addi's overflow trap and does not clobber left. 754 | return []string{ 755 | fmt.Sprintf("slt %s, %s, %s", register, right, left), 756 | fmt.Sprintf("xori %s, %s, 1", register, register), 757 | } 758 | 759 | case ">=": 760 | // a >= b <=> b <= a 761 | return assignBinaryOperation(register, "<=", right, left) 762 | } 763 | 764 | panic("unimplemented operator: " + operator) 765
| } 766 | \end{verbatim} 767 | 768 | \section{感想} 769 | 最適化の実装は難しかったが, コンパイラの書籍(ドラゴンブック)を読んだりしながら苦労して実装する中で, コンパイラの奥深さを垣間見ることができた. 770 | 771 | 今回Go言語を使って実装していて思ったのは, Go言語は表現力に乏しく, 書いていて楽しくない, ということである. モダンなシンタックス, 型による支援, 強力なエコシステムは大きな魅力ではあるが, 他の関数型言語に比べると多くの行を書かなければならない. 次に言語を選ぶときは, 大企業のマーケティングに騙されず, 書いていて楽しい言語を選びたい. 772 | 773 | \end{document} 774 | -------------------------------------------------------------------------------- /run: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | SCC_OPTION='-s /usr/local/bin/spim -c ./small-c' 3 | src/scc -e ${SCC_OPTION} "$@" 4 | -------------------------------------------------------------------------------- /run-sample: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | echo "Running sample..." 4 | ./run sample/ok*.sc 5 | ./run sample/ng*.sc 2>&1 > /dev/null 6 | -------------------------------------------------------------------------------- /run-test: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | echo "Running basic..." 4 | ./run test/basic/*.sc 5 | 6 | printf '\nRunning advanced...\n' 7 | ./run test/advanced/*.sc 8 | 9 | printf '\nRunning err...\n'
10 | ./run test/err/*.sc 2>&1 > /dev/null 11 | -------------------------------------------------------------------------------- /sample/ng0.sc: -------------------------------------------------------------------------------- 1 | int f(int x) { 2 | return x + 1; 3 | } 4 | 5 | void main() { 6 | print(f()); 7 | } 8 | -------------------------------------------------------------------------------- /sample/ng1.sc: -------------------------------------------------------------------------------- 1 | void main() { 2 | int **a[2]; 3 | } -------------------------------------------------------------------------------- /sample/ok0.sc: -------------------------------------------------------------------------------- 1 | void main() { 2 | print(1 + 2 == 3); 3 | } 4 | -------------------------------------------------------------------------------- /sample/ok1.sc: -------------------------------------------------------------------------------- 1 | void main() { 2 | int x, y; 3 | x = 2; 4 | y = 3; 5 | print(x + y == 5 && x * y == 6); 6 | } 7 | -------------------------------------------------------------------------------- /src/scc: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env racket 2 | #lang racket 3 | 4 | ;; =========== mips-util.rkt =========================== 5 | (require parser-tools/lex 6 | (prefix-in : parser-tools/lex-sre) 7 | parser-tools/yacc) 8 | 9 | (define-tokens tokens-with-value 10 | (NUM CHAR STR ID DIR)) 11 | 12 | (define-empty-tokens tokens-without-value 13 | (COLON COMMA NEWLINE 14 | DOLLAR LPAR RPAR 15 | EOF)) 16 | 17 | (define-lex-abbrevs 18 | (digit (char-range "0" "9")) 19 | (digit-non-zero (char-range "1" "9")) 20 | (number (:or "0" 21 | (:: digit-non-zero 22 | (:* digit)))) 23 | (identifier-char (:or (char-range "a" "z") 24 | (char-range "A" "Z") 25 | "_")) 26 | (identifier (:: identifier-char 27 | (:* (:or identifier-char digit))))) 28 | 29 | (define mips-lexer 30 | (lexer 31 | ("$" (token-DOLLAR)) 32 | (":" (token-COLON)) 33 | ("(" (token-LPAR)) 34 | (")" (token-RPAR)) 35 | ("," (token-COMMA)) 36 | ((:: (:or "" "+" "-") number) (token-NUM (string->number lexeme))) 37 | (identifier (token-ID (string->symbol lexeme))) 38 | ((:: "." 
identifier) (token-DIR (string->symbol lexeme))) 39 | ((:: "'" (:or any-char (:: "\\" any-char)) "'") 40 | (token-CHAR 41 | (string-ref 42 | (read 43 | (open-input-string 44 | (string-append "\"" 45 | (substring lexeme 1 (- (string-length lexeme) 1)) 46 | "\""))) 47 | 0))) 48 | ((:: "\"" (:* (:or any-char (:: "\\" any-char))) "\"") 49 | (token-STR (read (open-input-string lexeme)))) 50 | ("\n" (token-NEWLINE)) 51 | ((:or " " "\t") (mips-lexer input-port)) 52 | ((:: "#" (:* (:~ "\n"))) (mips-lexer input-port)) 53 | ((eof) (token-EOF)))) 54 | 55 | (define mem-instrs 56 | (apply set '(lb lbu ld lh lhu ll lw lwc1 lwl lwr ulh ulhu 57 | ulw sb sc sd sh sw swc1 sdc1 swl swr ush usw ))) 58 | 59 | (define instrs 60 | (set-union 61 | mem-instrs 62 | (apply set '(abs add addi addiu addu and andi b bclf bclt beq 63 | beqz bge bgeu bgez bgezal bgt bgtu bgtz ble bleu 64 | blez blt bltu bltz bltzal bne bnez clo clz div 65 | divu j jal jalr jr li lui la move movf movn movt 66 | movz mfc0 mfc1 mfhi mflo mthi mtlo mtc0 mtc1 madd 67 | maddu msub msubu mul mulo mulou mult multu neg negu 68 | nop nor not or ori rem remu rol ror seq sge sgeu 69 | sgt sgtu sle sleu slt slti sltiu sltu sne sll sllv 70 | sra srav srl srlv sub subu syscall xor xori)))) 71 | 72 | (define directives 73 | (apply set '(.align .ascii .asciiz .byte .data .double .extern 74 | .float .globl .half .kdata .ktext .set .space .text .word 75 | .rdata .sdata))) 76 | 77 | (define mips-parser 78 | (parser 79 | (start program) 80 | (end EOF) 81 | ;(debug "mips-parser.tbl") 82 | (suppress) 83 | (error (lambda (tok-ok? 
tok-name tok-value) 84 | (error "parse error:" tok-name tok-value))) 85 | (tokens tokens-with-value tokens-without-value) 86 | (grammar 87 | (program (() '()) 88 | ((line program) (if $1 (cons $1 $2) $2))) 89 | (line ((instruction NEWLINE) $1) 90 | ((directive NEWLINE) $1) 91 | ((label) $1) 92 | ((NEWLINE) #f)) 93 | (label ((ID COLON) `(#:label ,$1)) 94 | ((ID COLON NEWLINE) `(#:label ,$1))) 95 | (instruction ((ID operands-opt) 96 | (if (set-member? instrs $1) 97 | (cons $1 $2) 98 | (error (format "illegal opcode: ~a" $1))))) 99 | (directive ((DIR operands-opt) 100 | (if (set-member? directives $1) 101 | (cons $1 $2) 102 | (error (format "illegal directive: ~a" $1))))) 103 | (operands-opt (() '()) 104 | ((operands) $1)) 105 | (operands ((operand) (list $1)) 106 | ((operand COMMA operands) (cons $1 $3))) 107 | (operand ((NUM) $1) 108 | ((ID) $1) 109 | ((DOLLAR ID) `($ ,$2)) 110 | ((NUM LPAR operand RPAR) `(,$1 ,$3)) 111 | ((STR) $1) 112 | ((CHAR) $1))))) 113 | 114 | (define (mips-parse-port port) 115 | (mips-parser (lambda () (mips-lexer port)))) 116 | 117 | (define (mips-parse-string str) 118 | (mips-parse-port (open-input-string str))) 119 | 120 | (define (mips-parse-file fname) 121 | (mips-parse-port (open-input-file fname))) 122 | 123 | (define (mips-count-file fname) 124 | (let* ((code (mips-parse-file fname)) 125 | (mems (filter (lambda (line) (set-member? mem-instrs (first line))) 126 | code))) 127 | (display (format "total: ~a~%" (length code))) 128 | (display (format "memory access: ~a~%" (length mems))))) 129 | 130 | ;; =========== end of mips-util.rkt ==================== 131 | 132 | (define spim-command "/home/lab4/umatani/local/bin/spim -file ~a") 133 | 134 | (define racket-compiler-module "compiler") 135 | (define racket-compiler-function 'compile) 136 | 137 | (define (usage) 138 | (display "Usage: scc (