├── doc
│   └── measurements.ods
├── .gitignore
├── testdata
│   ├── model.v3.19.xg.batch1.txt.gz
│   ├── model.v3.19.xg.batch2.txt.gz
│   ├── model.v3.19.xg.batch3.txt.gz
│   ├── model.v3.19.xg.batch1.metadata.json
│   ├── model.v3.19.xg.batch2.metadata.json
│   └── model.v3.19.xg.batch3.metadata.json
├── Makefile
├── eval
│   ├── mlmodel.go
│   ├── predict
│   │   └── predict.go
│   ├── common_test.go
│   ├── modutils
│   │   └── modutils.go
│   ├── examples.go
│   ├── zero
│   │   └── model.go
│   ├── ym
│   │   └── yesman.go
│   ├── feats
│   │   ├── letters.go
│   │   ├── query_evaluation.go
│   │   └── score.go
│   ├── report.go
│   ├── xg
│   │   └── xgboost.go
│   ├── rf
│   │   └── model.go
│   ├── nn
│   │   └── model.go
│   └── model.go
├── tools
│   └── tools.go
├── conf-sample.json
├── cql
│   ├── parentmap.go
│   ├── common.go
│   ├── query_test.go
│   ├── grammar_test.go
│   ├── query.go
│   └── rgsimple.go
├── scripts
│   ├── rfchart.py
│   └── learnxgb.py
├── featurize.go
├── apiserver
│   ├── common.go
│   ├── handler.go
│   ├── apiserver.go
│   └── test_page.go
├── go.mod
├── dataimport
│   └── camus.go
├── README.md
├── cnf
│   └── conf.go
├── learn.go
├── repl.go
├── lognorm.go
├── cqlizer.go
└── LICENSE

/doc/measurements.ods:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czcorpus/cqlizer/main/doc/measurements.ods
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | cqlizer
2 | conf.json
3 | data/
4 | testdata/testing.sqlite
5 | venv/
6 | _models/
7 | 
8 | cql_features*
--------------------------------------------------------------------------------
/testdata/model.v3.19.xg.batch1.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czcorpus/cqlizer/main/testdata/model.v3.19.xg.batch1.txt.gz
--------------------------------------------------------------------------------
/testdata/model.v3.19.xg.batch2.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czcorpus/cqlizer/main/testdata/model.v3.19.xg.batch2.txt.gz
--------------------------------------------------------------------------------
/testdata/model.v3.19.xg.batch3.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czcorpus/cqlizer/main/testdata/model.v3.19.xg.batch3.txt.gz
--------------------------------------------------------------------------------
/testdata/model.v3.19.xg.batch1.metadata.json:
--------------------------------------------------------------------------------
1 | {"objective": "binary", "metric": ["auc", "binary_logloss"], "scale_pos_weight": 2.000586166471278, "max_depth": 6, "learning_rate": 0.05, "num_leaves": 81, "min_child_samples": 20, "subsample": 0.8, "colsample_bytree": 0.8, "random_state": 42, "verbose": -1}
--------------------------------------------------------------------------------
/testdata/model.v3.19.xg.batch2.metadata.json:
--------------------------------------------------------------------------------
1 | {"objective": "binary", "metric": ["auc", "binary_logloss"], "scale_pos_weight": 2.000586166471278, "max_depth": 6, "learning_rate": 0.05, "num_leaves": 81, "min_child_samples": 20, "subsample": 0.8, "colsample_bytree": 0.8, "random_state": 42, "verbose": -1}
--------------------------------------------------------------------------------
/testdata/model.v3.19.xg.batch3.metadata.json:
-------------------------------------------------------------------------------- 1 | {"objective": "binary", "metric": ["auc", "binary_logloss"], "scale_pos_weight": 2.000586166471278, "max_depth": 6, "learning_rate": 0.05, "num_leaves": 81, "min_child_samples": 20, "subsample": 0.8, "colsample_bytree": 0.8, "random_state": 42, "verbose": -1} -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | VERSION=`git describe --tags` 2 | BUILD=`date +%FT%T%z` 3 | HASH=`git rev-parse --short HEAD` 4 | 5 | LDFLAGS=-ldflags "-w -s -X main.version=${VERSION} -X main.buildDate=${BUILD} -X main.gitCommit=${HASH}" 6 | 7 | 8 | .PHONY: clean tools build generate 9 | 10 | all: generate build 11 | 12 | build: 13 | @echo "building the project without running unit tests" 14 | @go build ${LDFLAGS} -o cqlizer 15 | 16 | 17 | tools: 18 | @echo "installing local dependencies" 19 | @go install github.com/mna/pigeon 20 | 21 | generate: 22 | @echo "generating query parser code" 23 | @go generate ./cqlizer.go 24 | 25 | -------------------------------------------------------------------------------- /eval/mlmodel.go: -------------------------------------------------------------------------------- 1 | package eval 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/czcorpus/cqlizer/eval/nn" 7 | "github.com/czcorpus/cqlizer/eval/rf" 8 | "github.com/czcorpus/cqlizer/eval/xg" 9 | "github.com/czcorpus/cqlizer/eval/ym" 10 | ) 11 | 12 | var ErrNoSuchModel = errors.New("no such model") 13 | 14 | func GetMLModel(modelType, modelPath string) (MLModel, error) { 15 | 16 | var mlModel MLModel 17 | var err error 18 | 19 | switch modelType { 20 | case "rf": 21 | mlModel, err = rf.LoadFromFile(modelPath) 22 | case "nn": 23 | mlModel, err = nn.LoadFromFile(modelPath) 24 | case "xg": 25 | mlModel, err = xg.LoadFromFile(modelPath) 26 | case "ym": 27 | mlModel = &ym.Model{} 28 | default: 29 | err = ErrNoSuchModel 30 | } 31 | return mlModel, err 32 | } 33 | -------------------------------------------------------------------------------- /tools/tools.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Tomas Machalek 2 | // Copyright 2024 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
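// A hypothetical sketch of how the GetMLModel factory from eval/mlmodel.go
// above is typically used (the model type and path mirror conf-sample.json
// and are assumptions, not fixed constants):
//
//	model, err := eval.GetMLModel("xg", "./testdata/model.v3.19.xg.batch1.txt.gz")
//	if err != nil {
//		// an unknown model type yields eval.ErrNoSuchModel
//		return err
//	}
//	model.SetClassThreshold(0.75) // vote threshold, configured per ensemble member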
16 | 17 | //go:build tools 18 | 19 | package tools 20 | 21 | import ( 22 | _ "github.com/mna/pigeon" 23 | ) 24 | -------------------------------------------------------------------------------- /eval/predict/predict.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package predict 18 | 19 | type Prediction struct { 20 | Votes []float64 21 | PredictedClass int 22 | } 23 | 24 | func (p Prediction) FastOrSlow() string { 25 | if p.PredictedClass == 0 { 26 | return "fast" 27 | } 28 | return "slow" 29 | } 30 | 31 | func (p Prediction) SlowQueryVote() float64 { 32 | return p.Votes[1] 33 | } 34 | -------------------------------------------------------------------------------- /eval/common_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
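// For orientation, a hypothetical predict.Prediction value as produced by the
// models under eval/ (the vote numbers are invented for illustration):
//
//	p := predict.Prediction{Votes: []float64{0.2, 0.8}, PredictedClass: 1}
//	p.FastOrSlow()    // "slow" (class 0 means fast, any other class slow)
//	p.SlowQueryVote() // 0.8, i.e. the vote for the "slow" class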
16 | 17 | package eval 18 | 19 | import ( 20 | "testing" 21 | 22 | "github.com/czcorpus/cqlizer/eval/modutils" 23 | "github.com/stretchr/testify/assert" 24 | ) 25 | 26 | func TestExtractModelNameBaseFromFeatFile(t *testing.T) { 27 | f := modutils.ExtractModelNameBaseFromFeatFile("cql_features.v7.144.nonzero.msgpack") 28 | assert.Equal(t, "cql_features.v7.144", f) 29 | } 30 | -------------------------------------------------------------------------------- /conf-sample.json: -------------------------------------------------------------------------------- 1 | { 2 | "listenAddress": "192.168.1.10", 3 | "listenPort": 8080, 4 | "logging": { 5 | "level": "info" 6 | }, 7 | "serverReadTimeoutSecs": 120, 8 | "serverWriteTimeoutSecs": 60, 9 | "syntheticRecordsTimeCorrection": 1.1, 10 | "rfEnsemble": [ 11 | { 12 | "modelType": "xg", 13 | "modelPath": "./testdata/model.v3.19.xg.batch1.txt.gz", 14 | "voteThreshold": 0.75 15 | }, 16 | { 17 | "modelType": "xg", 18 | "modelPath": "./testdata/model.v3.19.xg.batch2.txt.gz", 19 | "voteThreshold": 0.75 20 | }, 21 | { 22 | "modelType": "xg", 23 | "modelPath": "./testdata/model.v3.19.xg.batch3.txt.gz", 24 | "voteThreshold": 0.75 25 | } 26 | ], 27 | "corporaProps": { 28 | "my_corpus_4g": { 29 | "size": 4000000000, 30 | "lang": "en", 31 | "altCorpus": "my_smaller_corpus" 32 | }, 33 | "my_corpus_5g": { 34 | "size": 5000000000, 35 | "lang": "en", 36 | "altCorpus": "my_smaller_corpus" 37 | }, 38 | "my_corpus_6g": { 39 | "size": 6000000000, 40 | "lang": "en", 41 | "altCorpus": "my_smaller_corpus" 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /cql/parentmap.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Tomas Machalek 2 | // Copyright 2024 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
16 | 17 | package cql 18 | 19 | import "reflect" 20 | 21 | const ( 22 | maxSrchLimit = 10000 23 | ) 24 | 25 | type parentMap map[ASTNode]ASTNode 26 | 27 | func (pm parentMap) findParentByType(node, prototype ASTNode, maxDist int) ASTNode { 28 | pType := reflect.TypeOf(prototype) 29 | var i int 30 | 31 | curr := node 32 | for curr != nil { 33 | p2Type := reflect.TypeOf(curr) 34 | if pType == p2Type { 35 | return curr 36 | } 37 | if i == maxDist && maxDist > 0 { 38 | return nil 39 | } 40 | curr = pm[curr] 41 | i++ 42 | if i == maxSrchLimit { 43 | panic("possibly infinite parent search") 44 | } 45 | } 46 | return nil 47 | } 48 | -------------------------------------------------------------------------------- /scripts/rfchart.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | 4 | import matplotlib.pyplot as plt 5 | 6 | # Read arguments: script_name, "-o", output_path, "-t", title 7 | if len(sys.argv) < 5 or sys.argv[1] != "-o" or sys.argv[3] != "-t": 8 | print("Error: Expected arguments: -o -t ", file=sys.stderr) 9 | sys.exit(1) 10 | 11 | output_file = sys.argv[2] 12 | title = sys.argv[4] 13 | 14 | # Read CSV from stdin 15 | reader = csv.reader(sys.stdin, delimiter=";") 16 | headers = next(reader) 17 | rows = list(reader) 18 | 19 | if not rows: 20 | print("Error: CSV contains no data rows", file=sys.stderr) 21 | sys.exit(1) 22 | 23 | x_column = headers[0] 24 | y_columns = headers[1:] 25 | 26 | # Convert data 27 | x_data = [float(row[0]) for row in rows] 28 | y_datasets = [ 29 | [float(row[col_idx]) for row in rows] for col_idx in range(1, len(headers)) 30 | ] 31 | 32 | # Create plot 33 | plt.figure(figsize=(10, 6)) 34 | plt.ylim(0, 1) 35 | markers = ["o", "s", "^", "D", "v", "<", ">", "p", "*", "h"] 36 | 37 | for idx, (y_data, col_name) in enumerate(zip(y_datasets, y_columns)): 38 | marker = markers[idx % len(markers)] 39 | plt.plot(x_data, y_data, marker=marker, label=col_name, linewidth=2, markersize=8) 40 | 41 | plt.xlabel(x_column, fontsize=12) 42 | plt.ylabel("Values", fontsize=12) 43 | plt.title(title, fontsize=14) 44 | plt.legend(fontsize=10, loc="lower left") 45 | plt.grid(True, alpha=0.3) 46 | plt.tight_layout() 47 | plt.savefig(output_file, dpi=300, bbox_inches="tight") 48 | -------------------------------------------------------------------------------- /eval/modutils/modutils.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
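// The rfchart.py script above reads a semicolon-delimited CSV from stdin and
// renders a line chart; a hypothetical invocation (file names are assumptions):
//
//	python scripts/rfchart.py -o rf-eval.png -t "RF evaluation" < scores.csv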
16 | 17 | package modutils 18 | 19 | import ( 20 | "fmt" 21 | "regexp" 22 | ) 23 | 24 | var feat2modelRegexp = regexp.MustCompile(`(.+\.v\d+\.\d+).*`) 25 | 26 | func FormatRoughSize(value int64) string { 27 | if value < 100000 { 28 | return "~0" 29 | } 30 | 31 | if value >= 1000000000 { // 1 billion or more 32 | billions := float64(value) / 1000000000.0 33 | return fmt.Sprintf("%.1fG", billions) 34 | } 35 | 36 | if value >= 100000 { // 1 million or more 37 | millions := float64(value) / 1000000.0 38 | return fmt.Sprintf("%.1fM", millions) 39 | } 40 | 41 | // Between 100,000 and 1,000,000 42 | return fmt.Sprintf("%d", value) 43 | } 44 | 45 | func ExtractModelNameBaseFromFeatFile(filename string) string { 46 | return feat2modelRegexp.ReplaceAllString(filename, "$1") 47 | } 48 | -------------------------------------------------------------------------------- /eval/examples.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package eval 18 | 19 | var ObligatoryExamples = []QueryStatsRecord{ 20 | {Corpus: "syn_v13", CorpusSize: 6400899055, TimeProc: 500, Query: "aword,[]"}, 21 | {Corpus: "syn_v13", CorpusSize: 6400899055, TimeProc: 500, Query: "aword,[word=\".*\"]"}, 22 | {Corpus: "syn_v13", CorpusSize: 6400899055, TimeProc: 500, Query: "aword,[word=\".+\"]"}, 23 | {Corpus: "syn_v13", CorpusSize: 6400899055, TimeProc: 500, Query: "aword,[lemma=\".*\"]"}, 24 | {Corpus: "syn_v13", CorpusSize: 6400899055, TimeProc: 500, Query: "aword,[lemma=\".+\"]"}, 25 | {Corpus: "syn_v13", CorpusSize: 6400899055, TimeProc: 500, Query: "aword,[lc=\".*\"]"}, 26 | {Corpus: "syn_v13", CorpusSize: 6400899055, TimeProc: 500, Query: "aword,[lc=\".+\"]"}, 27 | {Corpus: "syn_v13", CorpusSize: 6400899055, TimeProc: 500, Query: "aword,[tag=\"N.*\"]"}, 28 | {Corpus: "syn_v13", CorpusSize: 6400899055, TimeProc: 500, Query: "aword,[tag=\"N.+\"]"}, 29 | {Corpus: "syn_v13", CorpusSize: 6400899055, TimeProc: 500, Query: "aword,[pos=\"N\"]"}, 30 | } 31 | -------------------------------------------------------------------------------- /cql/common.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2024 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 
7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package cql 18 | 19 | import ( 20 | "fmt" 21 | "reflect" 22 | ) 23 | 24 | type ASTString string 25 | 26 | func (s ASTString) Text() string { 27 | return string(s) 28 | } 29 | 30 | func (s ASTString) String() string { 31 | return string(s) 32 | } 33 | 34 | type ASTNode interface { 35 | Text() string 36 | } 37 | 38 | func fromIdxOfUntypedSlice[T any](arr any, idx int) T { 39 | if arr == nil { 40 | var t T 41 | return t 42 | } 43 | tmp, ok := arr.([]any) 44 | if !ok { 45 | panic("value must be a slice") 46 | } 47 | v := tmp[idx] 48 | if v == nil { 49 | var t T 50 | return t 51 | } 52 | vt, ok := v.(T) 53 | if !ok { 54 | panic(fmt.Sprintf("value with idx %d has invalid type %s", idx, reflect.TypeOf(v))) 55 | } 56 | return vt 57 | } 58 | 59 | func anyToSlice(v any) []any { 60 | if v == nil { 61 | return []any{} 62 | } 63 | vt, ok := v.([]any) 64 | if !ok { 65 | panic(fmt.Sprintf("expecting an []any slice, got %s", reflect.TypeOf(v))) 66 | } 67 | return vt 68 | } 69 | 70 | func typedOrPanic[T any](v any) T { 71 | if v == nil { 72 | var ans T 73 | return ans 74 | } 75 | vt, ok := v.(T) 76 | if !ok { 77 | panic(fmt.Sprintf("unexpected type %s of: %v", reflect.TypeOf(v), v)) 78 | } 79 | return vt 80 | } 81 | -------------------------------------------------------------------------------- /eval/zero/model.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
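// The generic helpers in cql/common.go above unpack the untyped ([]any) values
// produced by the pigeon-generated parser; a hypothetical illustration of their
// behaviour:
//
//	raw := any([]any{ASTString("word"), nil})
//	fromIdxOfUntypedSlice[ASTString](raw, 0) // ASTString("word")
//	fromIdxOfUntypedSlice[ASTString](raw, 1) // zero value "" (nil element)
//	anyToSlice(nil)                          // returns an empty []any for nil input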
16 | 17 | package zero 18 | 19 | import ( 20 | "context" 21 | "fmt" 22 | 23 | "github.com/czcorpus/cqlizer/eval/feats" 24 | "github.com/czcorpus/cqlizer/eval/predict" 25 | ) 26 | 27 | type ZeroModel struct { 28 | SlowQueriesThresholdTime float64 29 | ClassThreshold float64 30 | } 31 | 32 | func (zm *ZeroModel) IsInferenceOnly() bool { 33 | return true 34 | } 35 | 36 | func (zm *ZeroModel) CreateModelFileName(featsFile string) string { 37 | return "zero-model" 38 | } 39 | 40 | func (zm *ZeroModel) Train(ctx context.Context, data []feats.QueryEvaluation, slowQueriesTime float64, comment string) error { 41 | return fmt.Errorf("cannot train zero model") 42 | } 43 | 44 | func (zm *ZeroModel) Predict(feats feats.QueryEvaluation) predict.Prediction { 45 | return predict.Prediction{} 46 | } 47 | 48 | func (zm *ZeroModel) SetClassThreshold(v float64) { 49 | zm.ClassThreshold = v 50 | } 51 | 52 | func (zm *ZeroModel) GetClassThreshold() float64 { 53 | return zm.ClassThreshold 54 | } 55 | 56 | func (zm *ZeroModel) GetSlowQueriesThresholdTime() float64 { 57 | return zm.SlowQueriesThresholdTime 58 | } 59 | 60 | func (zm *ZeroModel) SaveToFile(string) error { 61 | return fmt.Errorf("cannot save zero model") 62 | } 63 | 64 | func (zm *ZeroModel) GetInfo() string { 65 | return "ZeroModel" 66 | } 67 | -------------------------------------------------------------------------------- /featurize.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
16 | 17 | package main 18 | 19 | import ( 20 | "context" 21 | "fmt" 22 | "os" 23 | 24 | "github.com/czcorpus/cqlizer/cnf" 25 | "github.com/czcorpus/cqlizer/dataimport" 26 | "github.com/czcorpus/cqlizer/eval" 27 | "github.com/rs/zerolog/log" 28 | "github.com/vmihailenco/msgpack/v5" 29 | ) 30 | 31 | func runActionFeaturize( 32 | ctx context.Context, 33 | conf *cnf.Conf, 34 | srcPath, dstPath string, 35 | debug bool, 36 | ) { 37 | model := eval.NewPredictor(nil, conf) 38 | dataimport.ReadStatsFile(ctx, srcPath, model) 39 | model.Deduplicate() 40 | 41 | if debug { 42 | for i, v := range model.Evaluations { 43 | fmt.Printf("feats[%d] for %s\n", i, v.OrigQuery) 44 | fmt.Println(v.Show()) 45 | } 46 | 47 | } else { 48 | srz, err := msgpack.Marshal(model) 49 | if err != nil { 50 | log.Fatal().Err(err).Msg("failed to serialize cql queries features") 51 | return 52 | } 53 | fmt.Println("importing features from ", srcPath) 54 | 55 | file, err := os.Create(dstPath) 56 | if err != nil { 57 | log.Fatal().Err(err).Str("file", dstPath).Msg("failed to save features to a file") 58 | return 59 | } 60 | defer file.Close() 61 | if _, err := file.Write(srz); err != nil { 62 | log.Fatal().Err(err).Str("file", dstPath).Msg("failed to save features to a file") 63 | return 64 | } 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /eval/ym/yesman.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package ym 18 | 19 | import ( 20 | "context" 21 | "fmt" 22 | 23 | "github.com/czcorpus/cqlizer/eval/feats" 24 | "github.com/czcorpus/cqlizer/eval/predict" 25 | ) 26 | 27 | // Model is a constant classifier model which evaluates any query as slow (ym = yes-man). It is for debugging 28 | // purposes (for debugging and developing cqlizer's clients). 
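// A hypothetical call illustrating the constant behaviour described above:
//
//	var m Model
//	p := m.Predict(feats.QueryEvaluation{}) // the input features are irrelevant here
//	// p.PredictedClass == 1 and p.Votes == []float64{0, 1}, i.e. always "slow"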
29 | type Model struct { 30 | SlowQueriesThresholdTime float64 31 | ClassThreshold float64 32 | } 33 | 34 | func (ym *Model) IsInferenceOnly() bool { 35 | return true 36 | } 37 | 38 | func (ym *Model) CreateModelFileName(featsFile string) string { 39 | return "ym-model" 40 | } 41 | 42 | func (ym *Model) Train(ctx context.Context, data []feats.QueryEvaluation, slowQueriesTime float64, comment string) error { 43 | return nil 44 | } 45 | 46 | func (ym *Model) Predict(feats feats.QueryEvaluation) predict.Prediction { 47 | return predict.Prediction{ 48 | Votes: []float64{0, 1}, 49 | PredictedClass: 1, 50 | } 51 | } 52 | 53 | func (ym *Model) SetClassThreshold(v float64) { 54 | ym.ClassThreshold = v 55 | } 56 | 57 | func (ym *Model) GetClassThreshold() float64 { 58 | return ym.ClassThreshold 59 | } 60 | 61 | func (ym *Model) GetSlowQueriesThresholdTime() float64 { 62 | return ym.SlowQueriesThresholdTime 63 | } 64 | 65 | func (ym *Model) SaveToFile(string) error { 66 | return fmt.Errorf("cannot save ym model") 67 | } 68 | 69 | func (ym *Model) GetInfo() string { 70 | return "Constant classifier model (always 1)" 71 | } 72 | -------------------------------------------------------------------------------- /cql/query_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2024 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
16 | 17 | package cql 18 | 19 | import ( 20 | "testing" 21 | 22 | "github.com/stretchr/testify/assert" 23 | ) 24 | 25 | func TestQueryGetAttrs(t *testing.T) { 26 | q, err := ParseCQL("test", `[word="hi|hello"] [lemma="people" & tag="N.*"] within <text foo="b: ar" & zoo="b,az">`) 27 | assert.NoError(t, err) 28 | attrs := q.ExtractProps() 29 | assert.Equal( 30 | t, 31 | []QueryProp{ 32 | {Name: "word", Value: "hi|hello"}, 33 | {Name: "lemma", Value: "people"}, 34 | {Name: "tag", Value: "N.*"}, 35 | {Structure: "text", Name: "foo", Value: "b: ar"}, 36 | {Structure: "text", Name: "zoo", Value: "b,az"}, 37 | {Structure: "text"}, 38 | }, 39 | attrs, 40 | ) 41 | } 42 | 43 | func TestRegressionAtSign(t *testing.T) { 44 | q, err := ParseCQL("test", `[tag="X@.*"]`) 45 | assert.NoError(t, err) 46 | attrs := q.ExtractProps() 47 | assert.Equal( 48 | t, 49 | []QueryProp{ 50 | {Name: "tag", Value: "X@.*"}, 51 | }, 52 | attrs, 53 | ) 54 | } 55 | 56 | func TestQueryGetAttrsSimpleStruct(t *testing.T) { 57 | q, err := ParseCQL("test", `[word="x"] within <s>`) 58 | assert.NoError(t, err) 59 | attrs := q.ExtractProps() 60 | assert.Equal( 61 | t, 62 | []QueryProp{ 63 | {Name: "word", Value: "x"}, 64 | {Structure: "s", Name: "", Value: ""}, 65 | }, 66 | attrs, 67 | ) 68 | } 69 | 70 | func TestRegexpOnlyQuery(t *testing.T) { 71 | q, err := ParseCQL("test", `"attr.*"`) 72 | assert.NoError(t, err) 73 | attrs := q.ExtractProps() 74 | assert.Equal( 75 | t, 76 | []QueryProp{ 77 | {Value: "attr.*"}, 78 | }, 79 | attrs, 80 | ) 81 | } 82 | -------------------------------------------------------------------------------- /apiserver/common.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
16 | 17 | package apiserver 18 | 19 | import ( 20 | "context" 21 | 22 | "github.com/czcorpus/cqlizer/cnf" 23 | "github.com/czcorpus/cqlizer/eval" 24 | "github.com/czcorpus/cqlizer/eval/feats" 25 | "github.com/czcorpus/cqlizer/eval/predict" 26 | "github.com/gin-gonic/gin" 27 | ) 28 | 29 | // VersionInfo provides a detailed information about the actual build 30 | type VersionInfo struct { 31 | Version string `json:"version"` 32 | BuildDate string `json:"buildDate"` 33 | GitCommit string `json:"gitCommit"` 34 | } 35 | 36 | // --------------------- 37 | 38 | type service interface { 39 | Start(ctx context.Context) 40 | Stop(ctx context.Context) error 41 | } 42 | 43 | // ------ 44 | 45 | type evaluation struct { 46 | CorpusSize int `json:"corpusSize"` 47 | Votes []vote `json:"votes"` 48 | IsSlowQuery bool `json:"isSlowQuery"` 49 | AltCorpus string `json:"altCorpus,omitempty"` 50 | } 51 | 52 | type vote struct { 53 | Votes []float64 `json:"votes"` 54 | Result int `json:"result"` 55 | } 56 | 57 | // ------ 58 | 59 | type ensembleModel struct { 60 | model eval.MLModel 61 | srcPath string 62 | threshold float64 63 | } 64 | 65 | func (md ensembleModel) Predict(queryEval feats.QueryEvaluation) predict.Prediction { 66 | return md.model.Predict(queryEval) 67 | } 68 | 69 | // ----- 70 | 71 | func corsMiddleware(conf *cnf.Conf) gin.HandlerFunc { 72 | return func(ctx *gin.Context) { 73 | 74 | var allowedOrigin string 75 | currOrigin := ctx.Request.Header.Get("Origin") 76 | for _, origin := range conf.CorsAllowedOrigins { 77 | if currOrigin == origin || origin == "*" { 78 | allowedOrigin = origin 79 | break 80 | } 81 | } 82 | if allowedOrigin != "" { 83 | ctx.Writer.Header().Set("Access-Control-Allow-Origin", allowedOrigin) 84 | ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true") 85 | ctx.Writer.Header().Set( 86 | "Access-Control-Allow-Headers", 87 | "Content-Type, Content-Length, Accept-Encoding, Authorization, Accept, Origin, Cache-Control, X-Requested-With", 88 | ) 89 | ctx.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE") 90 | } 91 | 92 | if ctx.Request.Method == "OPTIONS" { 93 | ctx.AbortWithStatus(204) 94 | return 95 | } 96 | ctx.Next() 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/czcorpus/cqlizer 2 | 3 | go 1.24.0 4 | 5 | require ( 6 | github.com/chzyer/readline v1.5.1 7 | github.com/czcorpus/cnc-gokit v0.20.0 8 | github.com/dmitryikh/leaves v0.0.0-20230708180554-25d19a787328 9 | github.com/fatih/color v1.7.0 10 | github.com/gin-gonic/gin v1.10.0 11 | github.com/malaschitz/randomForest v0.0.0-20251101172028-7c30b8b21d88 12 | github.com/mna/pigeon v1.2.1 13 | github.com/patrikeh/go-deep v0.0.0-20230427173908-a2775168ab3d 14 | github.com/rs/zerolog v1.34.0 15 | github.com/schollz/progressbar/v3 v3.18.0 16 | github.com/stretchr/testify v1.11.1 17 | github.com/vmihailenco/msgpack/v5 v5.4.1 18 | ) 19 | 20 | replace github.com/patrikeh/go-deep => /home/tomas/work/korpus/tools/go-deep 21 | 22 | require ( 23 | github.com/BurntSushi/toml v1.5.0 // indirect 24 | github.com/bytedance/gopkg v0.1.3 // indirect 25 | github.com/bytedance/sonic v1.14.2 // indirect 26 | github.com/bytedance/sonic/loader v0.4.0 // indirect 27 | github.com/cloudwego/base64x v0.1.6 // indirect 28 | github.com/davecgh/go-spew v1.1.1 // indirect 29 | github.com/gabriel-vasile/mimetype v1.4.11 // indirect 30 | 
github.com/gin-contrib/sse v1.1.0 // indirect 31 | github.com/go-playground/locales v0.14.1 // indirect 32 | github.com/go-playground/universal-translator v0.18.1 // indirect 33 | github.com/go-playground/validator/v10 v10.28.0 // indirect 34 | github.com/goccy/go-json v0.10.5 // indirect 35 | github.com/json-iterator/go v1.1.12 // indirect 36 | github.com/klauspost/cpuid/v2 v2.3.0 // indirect 37 | github.com/kr/pretty v0.3.1 // indirect 38 | github.com/leodido/go-urn v1.4.0 // indirect 39 | github.com/mattn/go-colorable v0.1.14 // indirect 40 | github.com/mattn/go-isatty v0.0.20 // indirect 41 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect 42 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 43 | github.com/modern-go/reflect2 v1.0.2 // indirect 44 | github.com/natefinch/lumberjack v2.0.0+incompatible // indirect 45 | github.com/pelletier/go-toml/v2 v2.2.4 // indirect 46 | github.com/pmezard/go-difflib v1.0.0 // indirect 47 | github.com/rivo/uniseg v0.4.7 // indirect 48 | github.com/rogpeppe/go-internal v1.13.1 // indirect 49 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect 50 | github.com/ugorji/go/codec v1.3.1 // indirect 51 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect 52 | golang.org/x/arch v0.23.0 // indirect 53 | golang.org/x/crypto v0.45.0 // indirect 54 | golang.org/x/mod v0.30.0 // indirect 55 | golang.org/x/net v0.47.0 // indirect 56 | golang.org/x/sync v0.18.0 // indirect 57 | golang.org/x/sys v0.38.0 // indirect 58 | golang.org/x/term v0.37.0 // indirect 59 | golang.org/x/text v0.31.0 // indirect 60 | golang.org/x/tools v0.39.0 // indirect 61 | gonum.org/v1/gonum v0.16.0 // indirect 62 | google.golang.org/protobuf v1.36.10 // indirect 63 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 64 | gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect 65 | gopkg.in/yaml.v2 v2.4.0 // indirect 66 | gopkg.in/yaml.v3 v3.0.1 // indirect 67 | ) 68 | -------------------------------------------------------------------------------- /dataimport/camus.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package dataimport 18 | 19 | import ( 20 | "bufio" 21 | "context" 22 | "encoding/json" 23 | "fmt" 24 | "os" 25 | 26 | "github.com/czcorpus/cqlizer/eval" 27 | "github.com/rs/zerolog/log" 28 | ) 29 | 30 | type StatsFileProcessor interface { 31 | ProcessEntry(entry eval.QueryStatsRecord) error 32 | SetStats(numProcessed, numFailed int) 33 | } 34 | 35 | // ReadStatsFile reads a JSONL file where each line is a QueryStatsRecord 36 | // and calls the processor for each entry. 
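// A hypothetical minimal StatsFileProcessor implementation, shown only to
// illustrate the contract expected by ReadStatsFile (cqlizer itself passes the
// predictor created by eval.NewPredictor here, see featurize.go):
//
//	type countingProcessor struct{ processed, failed int }
//
//	func (p *countingProcessor) ProcessEntry(e eval.QueryStatsRecord) error {
//		p.processed++ // e.g. featurize e.Query here
//		return nil
//	}
//
//	func (p *countingProcessor) SetStats(numProcessed, numFailed int) {
//		p.processed, p.failed = numProcessed, numFailed
//	}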
37 | func ReadStatsFile(ctx context.Context, filePath string, processor StatsFileProcessor) error { 38 | file, err := os.Open(filePath) 39 | if err != nil { 40 | return fmt.Errorf("failed to open file: %w", err) 41 | } 42 | defer file.Close() 43 | 44 | scanner := bufio.NewScanner(file) 45 | const maxCapacity = 1024 * 1024 // 1 MB 46 | buf := make([]byte, maxCapacity) 47 | scanner.Buffer(buf, maxCapacity) 48 | lineNum := 0 49 | numProc := 0 50 | numFailed := 0 51 | for scanner.Scan() { 52 | select { 53 | case <-ctx.Done(): 54 | log.Warn().Msg("interrupting CQL file processing") 55 | return nil 56 | default: 57 | } 58 | lineNum++ 59 | line := scanner.Bytes() 60 | 61 | // Skip empty lines 62 | if len(line) == 0 { 63 | continue 64 | } 65 | 66 | var record eval.QueryStatsRecord 67 | if err := json.Unmarshal(line, &record); err != nil { 68 | log.Error().Err(err).Int("line", lineNum).Msg("failed to parse JSON, skipping") 69 | continue 70 | } 71 | 72 | if err := processor.ProcessEntry(record); err != nil { 73 | log.Error(). 74 | Err(err). 75 | Any("entry", record). 76 | Int("line", lineNum). 77 | Msg("failed to process CQL entry, skipping") 78 | numFailed++ 79 | continue 80 | 81 | } else { 82 | numProc++ 83 | } 84 | } 85 | 86 | if err := scanner.Err(); err != nil { 87 | return fmt.Errorf("failed to read query log file: %w", err) 88 | } 89 | 90 | for _, item := range eval.ObligatoryExamples { 91 | 92 | if err := processor.ProcessEntry(item); err != nil { 93 | log.Error(). 94 | Err(err). 95 | Any("entry", item). 96 | Int("line", lineNum). 97 | Msg("failed to process CQL entry, skipping") 98 | numFailed++ 99 | continue 100 | 101 | } else { 102 | numProc++ 103 | } 104 | } 105 | 106 | processor.SetStats(numProc, numFailed) 107 | if err := scanner.Err(); err != nil { 108 | return fmt.Errorf("error reading file: %w", err) 109 | } 110 | fmt.Printf("Stats file processed. Num imported queries: %d, num failed: %d\n", numProc, numFailed) 111 | 112 | return nil 113 | } 114 | -------------------------------------------------------------------------------- /cql/grammar_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2024 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
16 | 17 | package cql 18 | 19 | import ( 20 | "fmt" 21 | "testing" 22 | 23 | "github.com/stretchr/testify/assert" 24 | ) 25 | 26 | func TestRegexQuery(t *testing.T) { 27 | q1 := "[word=\"moto[a-z]\"]" 28 | p, err := ParseCQL("#", q1) 29 | assert.NoError(t, err) 30 | fmt.Println("p: ", p) 31 | } 32 | 33 | func TestRgOrQuery(t *testing.T) { 34 | q1 := "\"ſb(é|ě)r(ka|ku|ki|ze)\"" 35 | _, err := ParseCQL("#", q1) 36 | assert.NoError(t, err) 37 | } 38 | 39 | func TestRgOrQuery2(t *testing.T) { 40 | q1 := "[lemma=\"de|-|\"]" 41 | _, err := ParseCQL("#", q1) 42 | assert.NoError(t, err) 43 | } 44 | 45 | func TestJustRgQuery(t *testing.T) { 46 | q1 := "\"more|less\"" 47 | _, err := ParseCQL("#", q1) 48 | assert.NoError(t, err) 49 | } 50 | 51 | func TestParallelQuery(t *testing.T) { 52 | q := "[word=\"Skifahren\"] within <text group=\"Syndicate|Subtitles\" /> within " + 53 | "intercorp_v15_cs:[word=\"lyžování\"]" 54 | _, err := ParseCQL("#", q) 55 | assert.NoError(t, err) 56 | } 57 | 58 | func TestRgUnicodeProp(t *testing.T) { 59 | q1 := `[mwe_lemma=".+_\p{Lu}+" & mwe_tag=".*1"]` 60 | _, err := ParseCQL("#", q1) 61 | assert.NoError(t, err) 62 | } 63 | 64 | func TestRgPosixCharCls(t *testing.T) { 65 | q1 := `[word="^[[:alpha:]]{17}$"]` 66 | _, err := ParseCQL("#", q1) 67 | assert.NoError(t, err) 68 | } 69 | 70 | func TestRgUncommonChar(t *testing.T) { 71 | q1 := `[word="Bułka"]` 72 | _, err := ParseCQL("#", q1) 73 | assert.NoError(t, err) 74 | } 75 | 76 | func TestAlignedQuery(t *testing.T) { 77 | q1 := `[word="test"] within <text group=\"Acquis|Bible|Core|Europarl|PressEurop|Subtitles\" /> within intercorp_v12_cs:[word="Je"]` 78 | _, err := ParseCQL("#", q1) 79 | assert.NoError(t, err) 80 | } 81 | 82 | func TestRegress001(t *testing.T) { 83 | q1 := `[feats="VerbForm=Fin" & upos="VERB"]` 84 | _, err := ParseCQL("#", q1) 85 | assert.NoError(t, err) 86 | } 87 | 88 | func TestRgress002(t *testing.T) { 89 | q1 := `[(lemma="(?i)demokraticko\-liberálním" | sublemma="(?i)demokraticko\-liberálním" | word="(?i)demokraticko\-liberálním")]` 90 | _, err := ParseCQL("#", q1) 91 | assert.NoError(t, err) 92 | } 93 | 94 | func TestRgress003(t *testing.T) { 95 | q1 := `[lemma=".+t(o/ö)n"]` 96 | _, err := ParseCQL("#", q1) 97 | assert.NoError(t, err) 98 | } 99 | 100 | func TestRgress004(t *testing.T) { 101 | q1 := `(meet [col_lemma="didaktický_test"][col_lemma="didaktický_test" & lemma="didaktický"] 0 15)` 102 | _, err := ParseCQL("#", q1) 103 | assert.NoError(t, err) 104 | } 105 | 106 | func TestRgress005(t *testing.T) { 107 | q1 := `[word="ni{n,5}n"]` 108 | _, err := ParseCQL("#", q1) 109 | assert.NoError(t, err) 110 | } 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CQLizer 2 | 3 | CQLizer is a data-driven CQL (Corpus Query Language) writing helper tool for linguistic corpus analysis. It uses machine learning models to predict query performance and help users write efficient CQL queries. 
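Once the API server is running (see Usage below), it can be queried over plain HTTP. The following is a sketch assuming the server listens on localhost:8080 and that a corpus named `my_corpus_4g` is configured, as in `conf-sample.json`:

```bash
# ask the model ensemble whether the CQL query [word="test"] is likely to be slow
curl 'http://localhost:8080/cql/my_corpus_4g?q=%5Bword%3D%22test%22%5D'

# the same without a configured corpus, passing corpus size and language directly
curl 'http://localhost:8080/cql?q=%5Bword%3D%22test%22%5D&corpusSize=1000000000&lang=en'
```

The response is a JSON object containing `isSlowQuery`, the individual model `votes` and, where configured, an `altCorpus` suggestion.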
4 | 
5 | ## Features
6 | 
7 | - CQL query parsing and AST generation using PEG grammar
8 | - Machine learning-based query performance prediction (Random Forest, Neural Network, XGBoost)
9 | - Data import from KonText log files
10 | - Multiple interfaces: CLI, REPL, and API server
11 | 
12 | ## Requirements
13 | 
14 | - Go 1.24+
15 | - [pigeon](https://github.com/mna/pigeon) parser generator (for development)
16 | 
17 | ## Installation
18 | 
19 | ```bash
20 | # Install dependencies
21 | make tools
22 | 
23 | # Build the binary
24 | make build
25 | ```
26 | 
27 | ## Usage
28 | 
29 | ### CLI Commands
30 | 
31 | ```bash
32 | # Show version information
33 | cqlizer version
34 | 
35 | # Start interactive REPL
36 | cqlizer repl <model_file.json>
37 | 
38 | # Extract features from query logs
39 | cqlizer featurize config.json logfile.jsonl output.msgpack
40 | 
41 | # Train a model
42 | cqlizer learn [options] config.json features_file.msgpack
43 | 
44 | # Start API server
45 | cqlizer server config.json
46 | 
47 | # Start MCP server (experimental)
48 | cqlizer mcp-server config.json
49 | ```
50 | 
51 | ### Learning Options
52 | 
53 | ```bash
54 | # Random Forest
55 | cqlizer learn -model rf -num-trees 100 config.json features.msgpack
56 | 
57 | # Neural Network
58 | cqlizer learn -model nn config.json features.msgpack
59 | ```
60 | 
61 | #### XGBoost Model
62 | 
63 | For XGBoost, the `learn` action extracts features into a format compatible with LightGBM. After running the extraction, use the Python script to train the model.
64 | 
65 | First, set up a Python virtual environment with the required dependencies:
66 | 
67 | ```bash
68 | python3 -m venv venv
69 | source venv/bin/activate
70 | pip install lightgbm==3.3.5 msgpack numpy scikit-learn
71 | ```
72 | 
73 | Then run the training:
74 | 
75 | ```bash
76 | # Step 1: Extract features
77 | cqlizer learn -model xg config.json features.msgpack
78 | 
79 | # Step 2: Train the model using Python
80 | python scripts/learnxgb.py --input ./cql_features.v3.17.msgpack --output ./cql_model.v3.17.model.xg.txt
81 | ```
82 | 
83 | Use `cqlizer help <command>` for detailed information about specific commands.
84 | 
85 | ## Configuration
86 | 
87 | Most actions need a proper JSON configuration file. You can use the sample configuration file `conf-sample.json` as a base.
88 | 
89 | Note: `conf-sample.json` is configured with testing XGBoost model files located in the `testdata/` directory.
90 | 
91 | ### Server-Specific Model Training
92 | 
93 | **Important**: The model is trained on data from a specific server which has its own load and performance characteristics. For proper deployment in production, it is necessary to train the model on data obtained from your own server to ensure accurate performance predictions.
94 | 
95 | The most affected feature is corpus size: on a machine less powerful than the one used to train the sample model, there will likely be too many false negatives, while a more powerful server will produce more false positives.
96 | 
97 | 
98 | 
99 | ## Development
100 | 
101 | ```bash
102 | # Generate parser from PEG grammar
103 | make generate
104 | 
105 | # Run tests
106 | go test ./...
107 | 108 | # Build everything 109 | make all 110 | ``` 111 | -------------------------------------------------------------------------------- /apiserver/handler.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package apiserver 18 | 19 | import ( 20 | "fmt" 21 | "math" 22 | "net/http" 23 | "strings" 24 | 25 | "github.com/czcorpus/cnc-gokit/unireq" 26 | "github.com/czcorpus/cnc-gokit/uniresp" 27 | "github.com/czcorpus/cqlizer/eval/feats" 28 | "github.com/gin-gonic/gin" 29 | ) 30 | 31 | func (api *apiServer) handleVersion(ctx *gin.Context) { 32 | uniresp.WriteJSONResponse(ctx.Writer, api.version) 33 | } 34 | 35 | func (api *apiServer) handleEvalSimple(ctx *gin.Context) { 36 | q := ctx.Query("q") 37 | defaultAttr := ctx.QueryArray("defaultAttr") 38 | cqlChunks := make([]string, len(defaultAttr)) 39 | for i, da := range defaultAttr { 40 | cqlChunks[i] = fmt.Sprintf("%s=\"%s\"", da, q) 41 | } 42 | q = strings.Join(cqlChunks, " | ") 43 | api.evaluateRawQuery(ctx, q) 44 | 45 | } 46 | 47 | func (api *apiServer) handleEvalCQL(ctx *gin.Context) { 48 | q := ctx.Query("q") 49 | api.evaluateRawQuery(ctx, q) 50 | } 51 | 52 | func (api *apiServer) evaluateRawQuery(ctx *gin.Context, q string) { 53 | corpname := ctx.Param("corpusId") 54 | //aligned := ctx.QueryArray("aligned") 55 | var corpusInfo feats.CorpusProps 56 | var ok bool 57 | if corpname != "" { 58 | corpusInfo, ok = api.conf.CorporaProps[corpname] 59 | 60 | if !ok { 61 | uniresp.RespondWithErrorJSON( 62 | ctx, fmt.Errorf("corpus not found"), http.StatusNotFound, 63 | ) 64 | return 65 | } 66 | 67 | if ctx.Query("corpusSize") != "" { 68 | uniresp.RespondWithErrorJSON( 69 | ctx, fmt.Errorf("cannot specify corpusSize for a concrete corpus"), http.StatusBadRequest, 70 | ) 71 | return 72 | } 73 | 74 | } else { 75 | corpusInfo.Size, ok = unireq.GetURLIntArgOrFail(ctx, "corpusSize", 1000000000) 76 | if !ok { 77 | return 78 | } 79 | corpusInfo.Lang = ctx.Query("lang") 80 | } 81 | charProb := feats.GetCharProbabilityProvider(corpusInfo.Lang) 82 | queryEval, err := feats.NewQueryEvaluation(q, float64(corpusInfo.Size), 0, 3, charProb) 83 | if err != nil { 84 | uniresp.RespondWithErrorJSON(ctx, err, http.StatusInternalServerError) 85 | return 86 | } 87 | predictions := make([]vote, 0, len(api.rfEnsemble)) 88 | for _, md := range api.rfEnsemble { 89 | pr := md.Predict(queryEval) 90 | predictions = append( 91 | predictions, 92 | vote{ 93 | Votes: pr.Votes, 94 | Result: pr.PredictedClass, 95 | }, 96 | ) 97 | } 98 | 99 | var votesFor int 100 | for _, pred := range predictions { 101 | votesFor += pred.Result 102 | } 103 | resp := evaluation{ 104 | CorpusSize: corpusInfo.Size, 105 | Votes: predictions, 106 | IsSlowQuery: votesFor > 
int(math.Floor(float64(len(api.rfEnsemble))/2)), 107 | AltCorpus: corpusInfo.AltCorpus, 108 | } 109 | 110 | uniresp.WriteJSONResponse(ctx.Writer, resp) 111 | } 112 | -------------------------------------------------------------------------------- /eval/feats/letters.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package feats 18 | 19 | // ------------------------ 20 | 21 | type charProbabilityProvider interface { 22 | CharProbability(r rune) float64 23 | } 24 | 25 | // ------------------------ 26 | 27 | type charsProbabilityMap map[rune]float64 28 | 29 | func (chmap charsProbabilityMap) CharProbability(r rune) float64 { 30 | v, ok := chmap[r] 31 | if ok { 32 | return v 33 | } 34 | return 1 / float64(len(chmap)) * 0.1 35 | } 36 | 37 | // -------- 38 | 39 | type fallbackCharProbProvider struct{} 40 | 41 | func (fb fallbackCharProbProvider) CharProbability(r rune) float64 { 42 | return 1.0 / 30.0 43 | } 44 | 45 | // -------- 46 | 47 | // source: https://nlp.fi.muni.cz/cs/FrekvenceSlovLemmat 48 | 49 | var czCharsProbs = charsProbabilityMap{ 50 | 'a': 6.698, 51 | 'i': 4.571, 52 | 's': 4.620, 53 | 'á': 2.129, 54 | 'í': 3.103, 55 | 'š': 0.817, 56 | 'b': 1.665, 57 | 'j': 1.983, 58 | 't': 5.554, 59 | 'c': 1.601, 60 | 'k': 3.752, 61 | 'ť': 0.038, 62 | 'č': 1.017, 63 | 'l': 4.097, 64 | 'u': 3.131, 65 | 'd': 3.613, 66 | 'm': 3.262, 67 | 'ú': 0.145, 68 | 'ď': 0.019, 69 | 'n': 6.676, 70 | 'ů': 0.569, 71 | 'e': 7.831, 72 | 'ň': 0.073, 73 | 'v': 4.378, 74 | 'é': 1.178, 75 | 'o': 8.283, 76 | 'w': 0.072, 77 | 'ě': 1.491, 78 | 'ó': 0.032, 79 | 'x': 0.092, 80 | 'f': 0.394, 81 | 'p': 3.454, 82 | 'y': 1.752, 83 | 'g': 0.343, 84 | 'q': 0.006, 85 | 'ý': 0.942, 86 | 'h': 1.296, 87 | 'r': 3.977, 88 | 'z': 2.123, 89 | 'ř': 1.186, 90 | 'ž': 1.022, 91 | } 92 | 93 | // source: https://en.wikipedia.org/wiki/Letter_frequency 94 | 95 | var enCharsProbs = charsProbabilityMap{ 96 | 97 | 'a': 8.2, 98 | 'b': 1.5, 99 | 'c': 2.8, 100 | 'd': 4.3, 101 | 'e': 12.7, 102 | 'f': 2.2, 103 | 'g': 2.0, 104 | 'h': 6.1, 105 | 'i': 7.0, 106 | 'j': 0.15, 107 | 'k': 0.77, 108 | 'l': 4.0, 109 | 'm': 2.4, 110 | 'n': 6.7, 111 | 'o': 7.5, 112 | 'p': 1.9, 113 | 'q': 0.095, 114 | 'r': 6.0, 115 | 's': 6.3, 116 | 't': 9.1, 117 | 'u': 2.8, 118 | 'v': 0.98, 119 | 'w': 2.4, 120 | 'x': 0.15, 121 | 'y': 2.0, 122 | 'z': 0.074, 123 | } 124 | 125 | // source: https://www.sttmedia.com/characterfrequency-german 126 | 127 | var deCharsProbs = charsProbabilityMap{ 128 | 'a': 5.58, 129 | 'ä': 0.54, 130 | 'b': 1.96, 131 | 'c': 3.16, 132 | 'd': 4.98, 133 | 'e': 16.93, 134 | 'f': 1.49, 135 | 'g': 3.02, 136 | 'h': 4.98, 137 | 'i': 8.02, 138 | 'j': 0.24, 139 | 'k': 1.32, 140 | 'l': 3.60, 141 | 'm': 2.55, 142 | 'n': 10.53, 143 | 'o': 2.24, 144 | 'ö': 0.30, 145 | 'p': 
0.67, 146 | 'q': 0.02, 147 | 'r': 6.89, 148 | 'ß': 0.37, 149 | 's': 6.42, 150 | 't': 5.79, 151 | 'u': 3.83, 152 | 'ü': 0.65, 153 | 'v': 0.84, 154 | 'w': 1.78, 155 | 'x': 0.05, 156 | 'y': 0.05, 157 | 'z': 1.21, 158 | } 159 | 160 | // ----------------------- 161 | 162 | func GetCharProbabilityProvider(lang string) charProbabilityProvider { 163 | switch lang { 164 | case "cs": 165 | return czCharsProbs 166 | case "en": 167 | return enCharsProbs 168 | case "de": 169 | return deCharsProbs 170 | default: 171 | return fallbackCharProbProvider{} 172 | } 173 | } 174 | -------------------------------------------------------------------------------- /cnf/conf.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2024 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package cnf 18 | 19 | import ( 20 | "encoding/json" 21 | "fmt" 22 | "os" 23 | "time" 24 | 25 | "github.com/czcorpus/cnc-gokit/logging" 26 | "github.com/czcorpus/cqlizer/eval/feats" 27 | "github.com/rs/zerolog/log" 28 | ) 29 | 30 | const ( 31 | dfltServerWriteTimeoutSecs = 30 32 | dfltLanguage = "en" 33 | dfltMaxNumConcurrentJobs = 4 34 | dfltVertMaxNumErrors = 100 35 | dfltTimeZone = "Europe/Prague" 36 | ) 37 | 38 | type RFEnsembleConf struct { 39 | ModelPath string `json:"modelPath"` 40 | VoteThreshold float64 `json:"voteThreshold"` 41 | ModelType string `json:"modelType"` 42 | Disabled bool `json:"disabled"` 43 | } 44 | 45 | type Conf struct { 46 | srcPath string 47 | Logging logging.LoggingConf `json:"logging"` 48 | ListenAddress string `json:"listenAddress"` 49 | PublicURL string `json:"publicUrl"` 50 | ListenPort int `json:"listenPort"` 51 | ServerReadTimeoutSecs int `json:"serverReadTimeoutSecs"` 52 | TestingPageURLPathPrefix string `json:"testingPageURLPathPrefix"` 53 | ServerWriteTimeoutSecs int `json:"serverWriteTimeoutSecs"` 54 | CorsAllowedOrigins []string `json:"corsAllowedOrigins"` 55 | TimeZone string `json:"timeZone"` 56 | RFEnsemble []RFEnsembleConf `json:"rfEnsemble"` 57 | CorporaProps map[string]feats.CorpusProps `json:"corporaProps"` 58 | 59 | // SyntheticTimeCorrection - for stats records generated via benchmarking, 60 | // it may be needed to increase the times as MQuery will probably perform a bit better 61 | // and if performed during low traffic hours, this difference can be even bigger. 
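// For example, with a correction factor of 1.1 (the value used in the sample
// configuration), a benchmarked processing time of 1200 ms would be treated as
// 1200 * 1.1 = 1320 ms; this assumes the factor is applied multiplicatively,
// which is what the neutral default of 1 (no correction) suggests.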
62 | SyntheticTimeCorrection float64 `json:"syntheticTimeCorrection"` 63 | MQueryBenchmarkingURL string `json:"mqueryBenchmarkingUrl"` 64 | } 65 | 66 | func LoadConfig(path string) *Conf { 67 | if path == "" { 68 | log.Fatal().Msg("Cannot load config - path not specified") 69 | } 70 | rawData, err := os.ReadFile(path) 71 | if err != nil { 72 | log.Fatal().Err(err).Msg("Cannot load config") 73 | } 74 | var conf Conf 75 | conf.srcPath = path 76 | err = json.Unmarshal(rawData, &conf) 77 | if err != nil { 78 | log.Fatal().Err(err).Msg("Cannot load config") 79 | } 80 | return &conf 81 | } 82 | 83 | func ValidateAndDefaults(conf *Conf) { 84 | if conf.ServerWriteTimeoutSecs == 0 { 85 | conf.ServerWriteTimeoutSecs = dfltServerWriteTimeoutSecs 86 | log.Warn().Msgf( 87 | "serverWriteTimeoutSecs not specified, using default: %d", 88 | dfltServerWriteTimeoutSecs, 89 | ) 90 | } 91 | if conf.PublicURL == "" { 92 | conf.PublicURL = fmt.Sprintf("http://%s", conf.ListenAddress) 93 | log.Warn().Str("address", conf.PublicURL).Msg("publicUrl not set, using listenAddress") 94 | } 95 | 96 | if conf.TimeZone == "" { 97 | log.Warn(). 98 | Str("timeZone", dfltTimeZone). 99 | Msg("time zone not specified, using default") 100 | } 101 | if _, err := time.LoadLocation(conf.TimeZone); err != nil { 102 | log.Fatal().Err(err).Msg("invalid time zone") 103 | } 104 | 105 | if conf.SyntheticTimeCorrection == 0 { 106 | log.Warn().Msg("SyntheticRecordsTimeCorrection is not set - we must set it to 1") 107 | conf.SyntheticTimeCorrection = 1 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /apiserver/apiserver.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
16 | 17 | package apiserver 18 | 19 | import ( 20 | "context" 21 | "fmt" 22 | "net/http" 23 | "sync" 24 | "time" 25 | 26 | "github.com/czcorpus/cnc-gokit/logging" 27 | "github.com/czcorpus/cnc-gokit/uniresp" 28 | "github.com/czcorpus/cqlizer/cnf" 29 | "github.com/czcorpus/cqlizer/eval" 30 | "github.com/gin-gonic/gin" 31 | "github.com/rs/zerolog/log" 32 | ) 33 | 34 | // ----- 35 | 36 | type apiServer struct { 37 | conf *cnf.Conf 38 | server *http.Server 39 | rfEnsemble []ensembleModel 40 | version VersionInfo 41 | } 42 | 43 | func (api *apiServer) Start(ctx context.Context) { 44 | if !api.conf.Logging.Level.IsDebugMode() { 45 | gin.SetMode(gin.ReleaseMode) 46 | } 47 | 48 | engine := gin.New() 49 | engine.Use(gin.Recovery()) 50 | engine.Use(logging.GinMiddleware()) 51 | engine.Use(uniresp.AlwaysJSONContentType()) 52 | engine.Use(corsMiddleware(api.conf)) 53 | engine.NoMethod(uniresp.NoMethodHandler) 54 | engine.NoRoute(uniresp.NotFoundHandler) 55 | 56 | engine.GET("/test", api.handleTestPage) 57 | engine.GET("/cql/:corpusId", api.handleEvalCQL) 58 | engine.GET("/cql", api.handleEvalCQL) 59 | engine.GET("/simple/:corpusId", api.handleEvalSimple) 60 | engine.GET("/simple", api.handleEvalSimple) 61 | 62 | engine.GET("/version", api.handleVersion) 63 | 64 | log.Info().Msgf("starting to listen at %s:%d", api.conf.ListenAddress, api.conf.ListenPort) 65 | api.server = &http.Server{ 66 | Handler: engine, 67 | Addr: fmt.Sprintf("%s:%d", api.conf.ListenAddress, api.conf.ListenPort), 68 | WriteTimeout: time.Duration(api.conf.ServerWriteTimeoutSecs) * time.Second, 69 | ReadTimeout: time.Duration(api.conf.ServerReadTimeoutSecs) * time.Second, 70 | } 71 | go func() { 72 | if err := api.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { 73 | log.Fatal().Err(err).Msg("server error") 74 | } 75 | }() 76 | } 77 | 78 | func (api *apiServer) Stop(ctx context.Context) error { 79 | log.Warn().Msg("shutting down CQLizer HTTP API server") 80 | return api.server.Shutdown(ctx) 81 | } 82 | 83 | // ------------------------- 84 | 85 | func Run( 86 | ctx context.Context, 87 | conf *cnf.Conf, 88 | version VersionInfo, 89 | ) { 90 | 91 | server := &apiServer{ 92 | conf: conf, 93 | rfEnsemble: make([]ensembleModel, 0, len(conf.RFEnsemble)), 94 | version: version, 95 | } 96 | 97 | for _, rfc := range conf.RFEnsemble { 98 | if rfc.Disabled { 99 | continue 100 | } 101 | mlModel, err := eval.GetMLModel(rfc.ModelType, rfc.ModelPath) 102 | if err != nil { 103 | log.Fatal().Err(err).Msg("Error loading RF model") 104 | return 105 | } 106 | mlModel.SetClassThreshold(rfc.VoteThreshold) 107 | 108 | log.Info(). 109 | Float64("voteThreshold", rfc.VoteThreshold). 110 | Str("type", rfc.ModelType). 111 | Str("file", rfc.ModelPath). 
112 | Msg("loaded model") 113 | server.rfEnsemble = append( 114 | server.rfEnsemble, 115 | ensembleModel{ 116 | model: mlModel, 117 | srcPath: rfc.ModelPath, 118 | threshold: rfc.VoteThreshold, 119 | }, 120 | ) 121 | } 122 | 123 | services := []service{server} 124 | for _, m := range services { 125 | m.Start(ctx) 126 | } 127 | <-ctx.Done() 128 | log.Warn().Msg("shutdown signal received") 129 | 130 | shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 131 | defer cancel() 132 | 133 | var wg sync.WaitGroup 134 | for _, s := range services { 135 | wg.Add(1) 136 | go func(srv service) { 137 | defer wg.Done() 138 | if err := srv.Stop(shutdownCtx); err != nil { 139 | log.Error().Err(err).Type("service", srv).Msg("Error shutting down service") 140 | } 141 | }(s) 142 | } 143 | 144 | done := make(chan struct{}) 145 | go func() { 146 | wg.Wait() 147 | close(done) 148 | }() 149 | 150 | select { 151 | case <-done: 152 | log.Info().Msg("Graceful shutdown completed") 153 | case <-shutdownCtx.Done(): 154 | log.Warn().Msg("Shutdown timed out") 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /cql/query.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2024 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package cql 18 | 19 | import ( 20 | "encoding/json" 21 | "strings" 22 | 23 | "github.com/czcorpus/cnc-gokit/collections" 24 | ) 25 | 26 | // QueryProp is a generalized query property: 27 | // a) positional attribute with a value 28 | // b) structural attribute with a value 29 | // c) structure 30 | type QueryProp struct { 31 | Structure string 32 | Name string 33 | Value string 34 | } 35 | 36 | func (qp QueryProp) IsStructure() bool { 37 | return qp.Structure != "" && qp.Name == "" && qp.Value == "" 38 | } 39 | 40 | func (qp QueryProp) IsStructAttr() bool { 41 | return qp.Structure != "" && qp.Name != "" && qp.Value != "" 42 | } 43 | 44 | func (qp QueryProp) IsPosattr() bool { 45 | // we do not test qp.Name here as the query can 46 | // be also just a regexp expecting a default attribute 47 | return qp.Structure == "" && qp.Value != "" 48 | } 49 | 50 | // Query represents root node of a CQL syntax tree. 51 | // 52 | // Sequence (_ BINAND _ GlobPart)? 
(_ WithinOrContaining)* EOF { 53 | type Query struct { 54 | origValue string 55 | Sequence *Sequence 56 | GlobPart *GlobPart 57 | WithinOrContaining []*WithinOrContaining 58 | } 59 | 60 | func (q *Query) MarshalJSON() ([]byte, error) { 61 | return json.Marshal(struct { 62 | Expansion Query 63 | RuleName string 64 | }{ 65 | RuleName: "Query", 66 | Expansion: *q, 67 | }) 68 | } 69 | 70 | func (q *Query) Len() int { 71 | return len(q.origValue) 72 | } 73 | 74 | func (q *Query) Text() string { 75 | return q.origValue 76 | } 77 | 78 | func (q *Query) ForEachElement(fn func(parent, v ASTNode)) { 79 | fn(nil, q) 80 | if q.Sequence != nil { 81 | q.Sequence.ForEachElement(q, fn) 82 | } 83 | if q.GlobPart != nil { 84 | q.GlobPart.ForEachElement(q, fn) 85 | } 86 | for _, item := range q.WithinOrContaining { 87 | item.ForEachElement(q, fn) 88 | } 89 | } 90 | 91 | func (q *Query) DFS(fn func(v ASTNode)) { 92 | if q.Sequence != nil { 93 | q.Sequence.DFS(fn) 94 | } 95 | if q.GlobPart != nil { 96 | q.GlobPart.DFS(fn) 97 | } 98 | for _, item := range q.WithinOrContaining { 99 | item.DFS(fn) 100 | } 101 | fn(q) 102 | } 103 | 104 | func (q *Query) ExtractProps() []QueryProp { 105 | ans := make([]QueryProp, 0, 10) 106 | parents := make(parentMap) 107 | structs := collections.NewSet[string]() 108 | q.ForEachElement(func(parent, v ASTNode) { 109 | parents[v] = parent 110 | switch typedV := v.(type) { 111 | case *AttVal: 112 | if typedV.Variant1 != nil { 113 | newItem := QueryProp{ 114 | Name: typedV.Variant1.AttName.String(), 115 | Value: strings.Trim(typedV.Variant1.RawString.SimpleString.Text(), "\""), 116 | } 117 | stSrch := parents.findParentByType(typedV, &Structure{}, 0) 118 | if stSrch != nil { 119 | t, ok := stSrch.(*Structure) 120 | if !ok { 121 | // this can happen only if findParentByType is broken 122 | panic("found structure is not a *Structure") 123 | } 124 | newItem.Structure = t.AttName.String() 125 | } 126 | ans = append(ans, newItem) 127 | 128 | } else if typedV.Variant2 != nil { 129 | newItem := QueryProp{ 130 | Name: typedV.Variant2.AttName.String(), 131 | Value: strings.Trim(typedV.Variant2.RegExp.Text(), "\""), 132 | } 133 | stSrch := parents.findParentByType(typedV, &Structure{}, 0) 134 | if stSrch != nil { 135 | t, ok := stSrch.(*Structure) 136 | if !ok { 137 | // this can happen only if findParentByType() is broken 138 | panic("found structure is not a *Structure") 139 | } 140 | newItem.Structure = t.AttName.String() 141 | } 142 | ans = append(ans, newItem) 143 | 144 | } 145 | case *Structure: 146 | structs.Add(typedV.AttName.String()) 147 | case *RegExp: 148 | srch := parents.findParentByType(typedV, &OnePosition{}, 1) 149 | if srch != nil { 150 | val := make([]string, len(typedV.RegExpRaw)) 151 | for i, v := range typedV.RegExpRaw { 152 | val[i] = v.Text() 153 | } 154 | ans = append(ans, QueryProp{Value: strings.Join(val, " ")}) 155 | } 156 | } 157 | }) 158 | for _, v := range structs.ToSlice() { 159 | ans = append(ans, QueryProp{Structure: v}) 160 | } 161 | return ans 162 | } 163 | -------------------------------------------------------------------------------- /scripts/learnxgb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 4 | # Copyright 2025 Department of Linguistics, 5 | # Faculty of Arts, Charles University 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the 
License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | """Train LightGBM model for CQL query performance classification.""" 20 | 21 | import argparse 22 | import json 23 | import os 24 | 25 | import lightgbm as lgb 26 | import msgpack 27 | import numpy as np 28 | from sklearn.metrics import auc, classification_report, precision_recall_curve 29 | from sklearn.model_selection import train_test_split 30 | 31 | 32 | def load_msgpack_features(path: str) -> tuple[np.ndarray, np.ndarray]: 33 | """Load features from msgpack file. 34 | 35 | Adjust unpacking based on your actual msgpack structure. 36 | """ 37 | with open(path, "rb") as f: 38 | data = msgpack.unpack(f) 39 | X = np.array([item for item in data["features"]]) 40 | y = np.array([item for item in data["label"]]) 41 | return X, y 42 | 43 | 44 | def train_model(X: np.ndarray, y: np.ndarray, output_path: str): 45 | """Train LightGBM and save model.""" 46 | 47 | X_train, X_test, y_train, y_test = train_test_split( 48 | X, y, test_size=0.2, random_state=42, stratify=y 49 | ) 50 | 51 | # Calculate scale_pos_weight for class imbalance (your 1-5% slow queries) 52 | neg_count = np.sum(y_train == 0) 53 | pos_count = np.sum(y_train == 1) 54 | scale_pos_weight = neg_count / pos_count 55 | 56 | print(f"Scale pos weight: {scale_pos_weight:.2f}") 57 | 58 | params = { 59 | "objective": "binary", 60 | "metric": ["auc", "binary_logloss"], 61 | "scale_pos_weight": scale_pos_weight, 62 | "max_depth": 6, 63 | "learning_rate": 0.05, 64 | "num_leaves": 81, 65 | "min_child_samples": 20, 66 | "subsample": 0.8, 67 | "colsample_bytree": 0.8, 68 | "random_state": 42, 69 | "verbose": -1, 70 | } 71 | 72 | train_data = lgb.Dataset(X_train, label=y_train) 73 | valid_data = lgb.Dataset(X_test, label=y_test, reference=train_data) 74 | 75 | model = lgb.train( 76 | params, 77 | train_data, 78 | num_boost_round=200, 79 | valid_sets=[train_data, valid_data], 80 | valid_names=["train", "valid"], 81 | callbacks=[ 82 | lgb.early_stopping(stopping_rounds=20), 83 | lgb.log_evaluation(period=10), 84 | ], 85 | ) 86 | 87 | # Evaluate 88 | y_prob = model.predict(X_test, num_iteration=model.best_iteration) 89 | y_pred = (y_prob > 0.5).astype(int) 90 | 91 | print("\nClassification Report:") 92 | print(classification_report(y_test, y_pred, target_names=["normal", "slow"])) 93 | 94 | # PR-AUC (more meaningful than ROC-AUC for imbalanced data) 95 | precision, recall, _ = precision_recall_curve(y_test, y_prob) 96 | pr_auc = auc(recall, precision) 97 | print(f"PR-AUC: {pr_auc:.4f}") 98 | 99 | # Feature importance 100 | print("\nTop 10 Feature Importances (gain):") 101 | importance = model.feature_importance(importance_type="gain") 102 | indices = np.argsort(importance)[::-1][:10] 103 | for i, idx in enumerate(indices): 104 | print(f" {i + 1}. 
Feature {idx}: {importance[idx]:.4f}") 105 | 106 | # Save model in text format (compatible with leaves) 107 | model.save_model(output_path) 108 | print(f"\nModel saved to: {output_path}") 109 | print(f"Best iteration: {model.best_iteration}") 110 | 111 | with open(os.path.splitext(output_path)[0] + ".metadata.json", "w") as fw: 112 | json.dump(params, fw) 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser( 117 | description="Train LightGBM for CQL classification" 118 | ) 119 | parser.add_argument("--input", "-i", required=True, help="Path to msgpack features") 120 | parser.add_argument( 121 | "--output", 122 | "-o", 123 | default="model.txt", 124 | help="Output model path (.txt for leaves compatibility)", 125 | ) 126 | args = parser.parse_args() 127 | 128 | X, y = load_msgpack_features(args.input) 129 | print(f"Loaded {len(X)} samples, {X.shape[1]} features") 130 | print( 131 | f"Class distribution: {np.sum(y == 0)} normal, {np.sum(y == 1)} slow ({100 * np.mean(y):.2f}% positive)" 132 | ) 133 | train_model(X, y, args.output) 134 | -------------------------------------------------------------------------------- /eval/report.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
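Note on scripts/learnxgb.py (above): load_msgpack_features expects a msgpack map with a "features" key (one feature vector per query) and a "label" key (0 = normal, 1 = slow); this is the same shape that Model.SaveToFile in eval/xg/xgboost.go (further below) writes for the Go-extracted features. A small sketch producing a compatible file by hand; the file name and the feature values are illustrative only.

    package main

    import (
        "os"

        "github.com/vmihailenco/msgpack/v5"
    )

    func main() {
        payload := map[string]any{
            // one feature vector per query, in the same order as the labels
            "features": [][]float64{
                {0.5, 1, 0, 3},
                {2.3, 0, 1, 7},
            },
            // 0 = normal query, 1 = slow query
            "label": []int{0, 1},
        }
        data, err := msgpack.Marshal(payload)
        if err != nil {
            panic(err)
        }
        if err := os.WriteFile("sample.feats.xg.msgpack", data, 0o644); err != nil {
            panic(err)
        }
    }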
16 | 17 | package eval 18 | 19 | import ( 20 | "bytes" 21 | "fmt" 22 | "math" 23 | "os" 24 | "os/exec" 25 | "path/filepath" 26 | "slices" 27 | "strings" 28 | 29 | "github.com/czcorpus/cqlizer/eval/feats" 30 | ) 31 | 32 | type misclassification struct { 33 | Evaluation feats.QueryEvaluation `json:"evaluation"` 34 | MLOutput float64 `json:"mlOutput"` 35 | Threshold float64 `json:"threhold"` 36 | NumRepeat int `json:"numRepeat"` 37 | Type string `json:"type"` 38 | } 39 | 40 | func (m misclassification) AbsErrorSize() float64 { 41 | return math.Abs(m.MLOutput - m.Threshold) 42 | } 43 | 44 | // ------------------------ 45 | 46 | type Reporter struct { 47 | RFAccuracyScript string 48 | misclassQueries map[string]misclassification 49 | MisclassQueriesOutPath string 50 | } 51 | 52 | func (reporter *Reporter) AddMisclassifiedQuery(q feats.QueryEvaluation, mlOut, threshold, slowProcTime float64) { 53 | predictedSlow := mlOut >= threshold 54 | actuallySlow := q.ProcTime >= slowProcTime 55 | var tp string 56 | if actuallySlow && !predictedSlow { 57 | tp = "FN" 58 | 59 | } else if !actuallySlow && predictedSlow { 60 | tp = "FP" 61 | } 62 | if reporter.misclassQueries == nil { 63 | reporter.misclassQueries = make(map[string]misclassification) 64 | } 65 | curr, ok := reporter.misclassQueries[q.UniqKey()] 66 | if ok { 67 | curr.MLOutput += mlOut 68 | curr.NumRepeat += 1 69 | if tp != curr.Type { 70 | curr.Type = "*" 71 | } 72 | reporter.misclassQueries[q.UniqKey()] = curr 73 | 74 | } else { 75 | reporter.misclassQueries[q.UniqKey()] = misclassification{ 76 | Evaluation: q, 77 | MLOutput: mlOut, 78 | Threshold: threshold, 79 | NumRepeat: 1, 80 | Type: tp, 81 | } 82 | } 83 | } 84 | 85 | func (reporter *Reporter) sortedMisclassifiedQueries() []misclassification { 86 | ans := make([]misclassification, len(reporter.misclassQueries)) 87 | i := 0 88 | for _, v := range reporter.misclassQueries { 89 | v.MLOutput /= float64(v.NumRepeat) 90 | ans[i] = v 91 | i++ 92 | } 93 | slices.SortFunc( 94 | ans, 95 | func(v1, v2 misclassification) int { 96 | if v1.NumRepeat < v2.NumRepeat { 97 | return 1 98 | 99 | } else if v1.NumRepeat > v2.NumRepeat { 100 | return -1 101 | 102 | } else { 103 | if v1.AbsErrorSize() < v2.AbsErrorSize() { 104 | return 1 105 | } 106 | return -1 107 | } 108 | }, 109 | ) 110 | return ans 111 | } 112 | 113 | func (reporter *Reporter) ShowMisclassifiedQueries() { 114 | for i, v := range reporter.misclassQueries { 115 | fmt.Fprintf(os.Stderr, "%s\t%.2f\t%s\n", i, v.AbsErrorSize(), v.Evaluation.OrigQuery) 116 | } 117 | } 118 | 119 | func (reporter *Reporter) SaveMisclassifiedQueries() error { 120 | data := reporter.sortedMisclassifiedQueries() 121 | if reporter.MisclassQueriesOutPath == "" { 122 | return fmt.Errorf("misclassQueriesOutPath is not set") 123 | } 124 | 125 | f, err := os.Create(reporter.MisclassQueriesOutPath) 126 | if err != nil { 127 | return fmt.Errorf("failed to create file %s: %w", reporter.MisclassQueriesOutPath, err) 128 | } 129 | defer f.Close() 130 | 131 | for _, item := range data { 132 | _, err := fmt.Fprintf(f, "%.0f\t%.2f\t%0.2f\t%s(%d)\t%s\n", 133 | math.Exp(item.Evaluation.CorpusSize), item.Evaluation.ProcTime, item.MLOutput, item.Type, item.NumRepeat, item.Evaluation.OrigQuery) 134 | if err != nil { 135 | return fmt.Errorf("failed to write to file: %w", err) 136 | } 137 | } 138 | 139 | return nil 140 | } 141 | 142 | // PlotModelAccuracy creates a chart from CSV data using a Python plotting script. 
143 | // The output file name is derived from the provided modelPath 144 | func (reporter *Reporter) PlotRFAccuracy(data, chartLabel, modelPath string) error { 145 | chartFilePath := fmt.Sprintf("%s.png", strings.TrimSuffix(modelPath, filepath.Ext(modelPath))) 146 | cmd := exec.Command("python3", "-c", reporter.RFAccuracyScript, "-o", chartFilePath, "-t", chartLabel) 147 | cmd.Stdin = bytes.NewBufferString(data) 148 | var stdout, stderr bytes.Buffer 149 | cmd.Stdout = &stdout 150 | cmd.Stderr = &stderr 151 | err := cmd.Run() 152 | if err != nil { 153 | return fmt.Errorf("failed to execute plotting script: %w\nStderr: %s", err, stderr.String()) 154 | } 155 | 156 | return nil 157 | } 158 | -------------------------------------------------------------------------------- /learn.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package main 18 | 19 | import ( 20 | "context" 21 | "fmt" 22 | "io" 23 | "math" 24 | "os" 25 | "os/signal" 26 | "strings" 27 | "syscall" 28 | "time" 29 | 30 | "github.com/czcorpus/cqlizer/cnf" 31 | "github.com/czcorpus/cqlizer/eval" 32 | "github.com/czcorpus/cqlizer/eval/nn" 33 | "github.com/czcorpus/cqlizer/eval/rf" 34 | "github.com/czcorpus/cqlizer/eval/xg" 35 | "github.com/rs/zerolog/log" 36 | "github.com/schollz/progressbar/v3" 37 | "github.com/vmihailenco/msgpack/v5" 38 | ) 39 | 40 | func runActionKlogImport( 41 | conf *cnf.Conf, 42 | srcPath string, 43 | modelType string, 44 | numTrees int, 45 | voteThreshold float64, 46 | misclassLogPath string, 47 | ) { 48 | ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 49 | defer stop() 50 | 51 | /* 52 | model := &eval.BasicModel{ 53 | SlowQueryPercentile: slowQueryPerc, 54 | } 55 | dataimport.ReadStatsFile(ctx, srcPath, model) 56 | */ 57 | 58 | f, err := os.Open(srcPath) 59 | if err != nil { 60 | log.Fatal().Err(err).Msg("failed to open features file") 61 | return 62 | } 63 | defer f.Close() 64 | data, err := io.ReadAll(f) 65 | if err != nil { 66 | log.Fatal().Err(err).Msg("failed to open features file") 67 | return 68 | } 69 | 70 | var mlModel eval.MLModel 71 | switch modelType { 72 | case "rf": 73 | mlModel = rf.NewModel(numTrees, voteThreshold) 74 | case "nn": 75 | mlModel = nn.NewModel() 76 | case "xg": 77 | mlModel = xg.NewModel() 78 | default: 79 | log.Fatal().Str("modelType", modelType).Msg("Unknown model") 80 | return 81 | } 82 | 83 | model := eval.NewPredictor(mlModel, conf) 84 | if err := msgpack.Unmarshal(data, &model); err != nil { 85 | log.Fatal().Err(err).Msg("failed to open features file") 86 | return 87 | } 88 | 89 | allEvals := model.BalanceSample() 90 | reporter := &eval.Reporter{ 91 | RFAccuracyScript: rfChartScript, 92 | MisclassQueriesOutPath: misclassLogPath, 93 | } 94 | 
95 | if err := model.CreateAndTestModel(ctx, allEvals, srcPath, reporter); err != nil { 96 | fmt.Fprintf(os.Stderr, "RF training failed: %v\n", err) 97 | os.Exit(1) 98 | } 99 | } 100 | 101 | func runActionEvaluate( 102 | conf *cnf.Conf, 103 | modelPath string, 104 | modelType string, 105 | tstDataPath string, 106 | misclassLogPath string, 107 | ) { 108 | ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 109 | defer stop() 110 | mlModel, err := eval.GetMLModel(modelType, modelPath) 111 | if err != nil { 112 | log.Fatal().Err(err).Msg("Failed to load the ML model") 113 | return 114 | } 115 | 116 | f, err := os.Open(tstDataPath) 117 | if err != nil { 118 | log.Fatal().Err(err).Msg("failed to open features file") 119 | return 120 | } 121 | defer f.Close() 122 | data, err := io.ReadAll(f) 123 | if err != nil { 124 | log.Fatal().Err(err).Msg("failed to open features file") 125 | return 126 | } 127 | 128 | predictor := eval.NewPredictor(mlModel, conf) 129 | if err := msgpack.Unmarshal(data, &predictor); err != nil { 130 | log.Fatal().Err(err).Msg("failed to open features file") 131 | return 132 | } 133 | predictor.FindAndSetDataMidpoint() 134 | 135 | reporter := &eval.Reporter{ 136 | RFAccuracyScript: rfChartScript, 137 | MisclassQueriesOutPath: misclassLogPath, 138 | } 139 | 140 | log.Info(). 141 | Int("evalDataSize", len(predictor.Evaluations)). 142 | Msg("calculating precision and recall using full data") 143 | 144 | bar := progressbar.Default(int64(math.Ceil((1-0.5)/0.01)), "testing the model") 145 | var csv strings.Builder 146 | csv.WriteString("vote;precision;recall;f-beta\n") 147 | for v := 0.5; v < 1; v += 0.01 { 148 | select { 149 | case <-ctx.Done(): 150 | return 151 | default: 152 | } 153 | mlModel.SetClassThreshold(v) 154 | precall := predictor.PrecisionAndRecall(reporter) 155 | csv.WriteString(precall.CSV(v) + "\n") 156 | bar.Add(1) 157 | } 158 | unixt := time.Now().Unix() 159 | chartPath := fmt.Sprintf("./test-%d.png", unixt) 160 | if err := reporter.PlotRFAccuracy(csv.String(), mlModel.GetInfo(), chartPath); err != nil { 161 | log.Fatal().Err(err).Msgf("failed to generate accuracy chart") 162 | return 163 | 164 | } else { 165 | log.Info().Str("file", chartPath).Msg("saved evaluation chart") 166 | } 167 | reporter.SaveMisclassifiedQueries() 168 | } 169 | -------------------------------------------------------------------------------- /repl.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
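Note on eval.Reporter as used in learn.go (above): AddMisclassifiedQuery records a query as "FN" when it was actually slow (ProcTime >= slowProcTime) but the model's output stayed below the threshold, and as "FP" in the opposite case; repeated submissions of the same query are aggregated by UniqKey and their model outputs averaged before SaveMisclassifiedQueries writes the TSV. A minimal direct-use sketch, assuming NewQueryEvaluation is called the same way repl.go calls it (the two zero arguments are copied from there; the query, corpus size, timings and output path are made-up values):

    package main

    import (
        "log"

        "github.com/czcorpus/cqlizer/eval"
        "github.com/czcorpus/cqlizer/eval/feats"
    )

    func main() {
        charProbs := feats.GetCharProbabilityProvider("en")
        qe, err := feats.NewQueryEvaluation(`[word="dog.*"]`, 120000000, 0, 0, charProbs)
        if err != nil {
            log.Fatal(err)
        }
        qe.ProcTime = 7.5 // measured processing time in seconds (illustrative)

        rep := &eval.Reporter{MisclassQueriesOutPath: "./misclassified.tsv"}
        // model output 0.41 < threshold 0.85 => predicted fast; 7.5s >= 2.0s => actually slow => "FN"
        rep.AddMisclassifiedQuery(qe, 0.41, 0.85, 2.0)
        if err := rep.SaveMisclassifiedQueries(); err != nil {
            log.Fatal(err)
        }
    }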
16 | 17 | package main 18 | 19 | import ( 20 | "fmt" 21 | "io" 22 | "os" 23 | "path/filepath" 24 | "strconv" 25 | "strings" 26 | 27 | "github.com/chzyer/readline" 28 | "github.com/czcorpus/cqlizer/eval" 29 | "github.com/czcorpus/cqlizer/eval/feats" 30 | "github.com/czcorpus/cqlizer/eval/modutils" 31 | "github.com/fatih/color" 32 | "github.com/rs/zerolog/log" 33 | ) 34 | 35 | func ensureConfigDir() (string, error) { 36 | homeDir, err := os.UserHomeDir() 37 | if err != nil { 38 | return "", err 39 | } 40 | configDir := filepath.Join(homeDir, ".config", "cqlizer") 41 | if err := os.MkdirAll(configDir, 0755); err != nil { 42 | return "", err 43 | } 44 | return configDir, nil 45 | } 46 | 47 | func runActionREPL(modelType, modelPath string) { 48 | mlModel, err := eval.GetMLModel(modelType, modelPath) 49 | if err != nil { 50 | fmt.Printf("Error loading model: %v\n", err) 51 | os.Exit(1) 52 | } 53 | 54 | titleColor := color.New(color.FgHiMagenta).SprintFunc() 55 | greenColor := color.New(color.FgGreen).SprintFunc() 56 | redColor := color.New(color.FgRed).SprintFunc() 57 | 58 | // Default corpus size (can be overridden with 'set corpussize <value>') 59 | corpusSize := 6400000000.0 // 6.4G tokens default 60 | voteThreshold := 0.85 61 | lang := "cs" 62 | 63 | mlModel.SetClassThreshold(voteThreshold) 64 | 65 | fmt.Println("CQL Query Complexity Estimator") 66 | fmt.Println("Commands:") 67 | fmt.Println(" <CQL query> - Estimate query execution time") 68 | fmt.Println(" set corpussize <size> - Set corpus size (e.g., 'set corpussize 121826797')") 69 | fmt.Println(" set lang <lang> - Set corpus language (e.g., 'set lang cs')") 70 | fmt.Println(" set vote <value 0..1> - set model vote threshold") 71 | fmt.Println(" setup - view current settings") 72 | fmt.Println(" exit - Exit REPL") 73 | fmt.Printf("\nCurrent corpus size: %s tokens\n\n", modutils.FormatRoughSize(int64(corpusSize))) 74 | 75 | var historyFile string 76 | historyDir, err := ensureConfigDir() 77 | if err != nil { 78 | log.Error().Err(err).Msg("failed to determine user config directory - falling back to session-local history") 79 | 80 | } else { 81 | historyFile = filepath.Join(historyDir, "cql-history.txt") 82 | } 83 | 84 | rl, err := readline.NewEx(&readline.Config{ 85 | Prompt: color.New(color.FgHiGreen).Sprintf("/cql> "), 86 | HistoryFile: historyFile, 87 | }) 88 | if err != nil { 89 | fmt.Printf("Error initializing readline: %v\n", err) 90 | os.Exit(1) 91 | } 92 | defer rl.Close() 93 | 94 | for { 95 | line, err := rl.Readline() 96 | if err != nil { 97 | if err == readline.ErrInterrupt || err == io.EOF { 98 | fmt.Println("\nCQLizer out!") 99 | break 100 | } 101 | fmt.Printf("Error reading input: %v\n", err) 102 | continue 103 | } 104 | input := strings.TrimSpace(line) 105 | 106 | if input == "exit" { 107 | fmt.Println("Goodbye!") 108 | break 109 | } 110 | 111 | if strings.HasPrefix(input, "set ") { 112 | parsedInput := strings.Fields(input)[1:] 113 | switch parsedInput[0] { 114 | case "corpussize": 115 | if len(parsedInput) == 2 { 116 | corpusSize, err = strconv.ParseFloat(parsedInput[1], 64) 117 | if err != nil { 118 | fmt.Println("Error: Invalid corpus size") 119 | } 120 | 121 | } else { 122 | fmt.Println("Usage: set corpussize <size>") 123 | } 124 | case "vote": 125 | if len(parsedInput) == 2 { 126 | voteThreshold, err = strconv.ParseFloat(parsedInput[1], 64) 127 | if err != nil { 128 | fmt.Println("failed to parse number") 129 | } 130 | mlModel.SetClassThreshold(voteThreshold) 131 | 132 | } else { 133 | fmt.Println("Usage: set vote 
<value 0..1>") 134 | } 135 | case "lang": 136 | if len(parsedInput) == 2 { 137 | lang = parsedInput[1] 138 | 139 | } else { 140 | fmt.Println("Usage: set lang <lang>") 141 | } 142 | default: 143 | fmt.Println("Unknown 'set' command") 144 | } 145 | continue 146 | 147 | } else if input == "setup" { 148 | fmt.Printf("%s:\t%s\n", titleColor("Corpus size"), modutils.FormatRoughSize(int64(corpusSize))) 149 | fmt.Printf("%s:\t\t%s\n", titleColor("Model"), modelPath) 150 | fmt.Printf("%s:\t%.2f\n", titleColor("Vote threshold"), voteThreshold) 151 | continue 152 | } 153 | 154 | // Treat as CQL query 155 | charProbs := feats.GetCharProbabilityProvider(lang) 156 | queryEval, err := feats.NewQueryEvaluation(input, corpusSize, 0, 0, charProbs) 157 | if err != nil { 158 | fmt.Printf("Error parsing CQL: %v\n", err) 159 | continue 160 | } 161 | 162 | // Display results 163 | 164 | fmt.Printf("%s:\n", titleColor("Pos. features")) 165 | for i, pos := range queryEval.Positions { 166 | fmt.Printf(" %s: wildcards=%0.2f, range=%d, smallCard=%d, numConcreteChars=%.2f, posNumAlts: %d\n", 167 | titleColor(fmt.Sprintf("[%d]", i)), 168 | pos.Regexp.WildcardScore, pos.Regexp.HasRange, pos.HasSmallCardAttr, pos.Regexp.NumConcreteChars, pos.NumAlternatives) 169 | } 170 | fmt.Printf("%s: glob=%d, meet=%d, union=%d, within=%d, containing=%d\n", 171 | titleColor("Global features"), 172 | queryEval.NumGlobConditions, queryEval.ContainsMeet, 173 | queryEval.ContainsUnion, queryEval.ContainsWithin, queryEval.ContainsContaining) 174 | 175 | if mlModel != nil { 176 | rfPRediction := mlModel.Predict(queryEval) 177 | var predResult string 178 | if rfPRediction.PredictedClass == 1 { 179 | predResult = redColor(rfPRediction.FastOrSlow() + " query") 180 | 181 | } else { 182 | predResult = greenColor(rfPRediction.FastOrSlow() + " query") 183 | } 184 | fmt.Printf("model prediction: %s\n", predResult) 185 | fmt.Printf("vote 0: %.2f, vote 1: %.2f\n", rfPRediction.Votes[0], rfPRediction.Votes[1]) 186 | } 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /eval/xg/xgboost.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
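Note on repl.go (above): the same classification flow can be used non-interactively. The sketch below loads a model through eval.GetMLModel and scores a single query the way the REPL does; the model path points at one of the bundled test models, and the query, corpus size, language and threshold are placeholders (whether the bundled model matches the current feature layout is not guaranteed).

    package main

    import (
        "fmt"
        "log"

        "github.com/czcorpus/cqlizer/eval"
        "github.com/czcorpus/cqlizer/eval/feats"
    )

    func main() {
        mlModel, err := eval.GetMLModel("xg", "testdata/model.v3.19.xg.batch1.txt.gz")
        if err != nil {
            log.Fatal(err)
        }
        mlModel.SetClassThreshold(0.85)

        charProbs := feats.GetCharProbabilityProvider("cs")
        queryEval, err := feats.NewQueryEvaluation(`[word="work"] [tag="N.*"]`, 6400000000.0, 0, 0, charProbs)
        if err != nil {
            log.Fatal(err)
        }

        pred := mlModel.Predict(queryEval)
        fmt.Printf("%s query (votes: fast=%.2f slow=%.2f)\n", pred.FastOrSlow(), pred.Votes[0], pred.Votes[1])
    }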
16 | 17 | package xg 18 | 19 | import ( 20 | "bufio" 21 | "compress/gzip" 22 | "context" 23 | "encoding/json" 24 | "fmt" 25 | "io" 26 | "os" 27 | "path/filepath" 28 | "strings" 29 | 30 | "github.com/czcorpus/cnc-gokit/fs" 31 | "github.com/czcorpus/cqlizer/eval/feats" 32 | "github.com/czcorpus/cqlizer/eval/modutils" 33 | "github.com/czcorpus/cqlizer/eval/predict" 34 | "github.com/dmitryikh/leaves" 35 | "github.com/rs/zerolog/log" 36 | "github.com/vmihailenco/msgpack/v5" 37 | ) 38 | 39 | type metadata struct { 40 | Objective string `json:"objective"` 41 | Metric [2]string `json:"metric"` 42 | ScalePosWeight float64 `json:"scale_pos_weight"` 43 | MaxDepth int `json:"max_depth"` 44 | LearningRate float64 `json:"learning_rate"` 45 | NumLeaves int `json:"num_leaves"` 46 | MinChildSamples int `json:"min_child_samples"` 47 | Subsample float64 `json:"subsample"` 48 | ColsampleBytree float64 `json:"colsample_bytree"` 49 | RandomState int `json:"random_state"` 50 | Verbose int `json:"verbose"` 51 | } 52 | 53 | type Model struct { 54 | ClassThreshold float64 55 | SlowQueriesThresholdTime float64 56 | trainXData [][]float64 57 | trainYData []int 58 | xgboost *leaves.Ensemble 59 | metadata metadata 60 | } 61 | 62 | func (m *Model) IsInferenceOnly() bool { 63 | return true 64 | } 65 | 66 | func (m *Model) CreateModelFileName(featsFile string) string { 67 | return modutils.ExtractModelNameBaseFromFeatFile(featsFile) + ".feats.xg.msgpack" 68 | } 69 | 70 | func (m *Model) Train(ctx context.Context, data []feats.QueryEvaluation, slowQueriesTime float64, comment string) error { 71 | if len(data) == 0 { 72 | return fmt.Errorf("no training data provided") 73 | } 74 | if slowQueriesTime <= 0 { 75 | return fmt.Errorf("failed to train RF model - invalid value of SlowQueriesThresholdTime") 76 | } 77 | m.SlowQueriesThresholdTime = slowQueriesTime 78 | 79 | var xData [][]float64 80 | var yData []int 81 | numProblematic := 0 82 | for i, eval := range data { 83 | if i%100 == 0 && ctx != nil && ctx.Err() != nil { 84 | return ctx.Err() 85 | } 86 | features := feats.ExtractFeatures(eval) 87 | isPositive := 0 88 | if eval.ProcTime >= m.SlowQueriesThresholdTime { 89 | numProblematic++ 90 | isPositive = 1 91 | } 92 | xData = append(xData, features) 93 | yData = append(yData, isPositive) 94 | } 95 | m.trainXData = xData 96 | m.trainYData = yData 97 | return nil 98 | } 99 | 100 | func (m *Model) Predict(eval feats.QueryEvaluation) predict.Prediction { 101 | features := feats.ExtractFeatures(eval) 102 | pred := m.xgboost.PredictSingle(features, 0) 103 | var ans int 104 | if pred > m.ClassThreshold { 105 | ans = 1 106 | } 107 | return predict.Prediction{ 108 | Votes: []float64{1 - pred, pred}, 109 | PredictedClass: ans, 110 | } 111 | } 112 | 113 | func (m *Model) SetClassThreshold(v float64) { 114 | m.ClassThreshold = v 115 | } 116 | 117 | func (m *Model) GetClassThreshold() float64 { 118 | return m.ClassThreshold 119 | } 120 | 121 | func (m *Model) GetSlowQueriesThresholdTime() float64 { 122 | return m.SlowQueriesThresholdTime 123 | } 124 | 125 | func (m *Model) SaveToFile(filePath string) error { 126 | file, err := os.Create(filePath) 127 | if err != nil { 128 | return fmt.Errorf("failed to save RF model to a file: %w", err) 129 | } 130 | defer file.Close() 131 | out := make(map[string]any) 132 | out["features"] = m.trainXData 133 | out["label"] = m.trainYData 134 | 135 | outData, err := msgpack.Marshal(out) 136 | if err != nil { 137 | return fmt.Errorf("failed to create XGBoost training data: %w", err) 138 | } 139 | _, err = 
file.Write(outData) 140 | if err != nil { 141 | return fmt.Errorf("failed to create XGBoost training data: %w", err) 142 | } 143 | return nil 144 | } 145 | 146 | func (m *Model) GetInfo() string { 147 | return fmt.Sprintf( 148 | "XGBoost model, metric: %s / %s, NL: %d, SPV: %.2f, LR: %.2f", 149 | m.metadata.Metric[0], 150 | m.metadata.Metric[1], 151 | m.metadata.NumLeaves, 152 | m.metadata.ScalePosWeight, 153 | m.metadata.LearningRate, 154 | ) 155 | } 156 | 157 | func loadMetadata(modelPath string) (metadata, error) { 158 | var mt metadata 159 | var metadataFilePath string 160 | ext := filepath.Ext(modelPath) 161 | if ext == ".gz" || ext == ".gzip" { 162 | modelPath = modelPath[:len(modelPath)-len(ext)] 163 | ext = filepath.Ext(modelPath) 164 | } 165 | metadataFilePath = modelPath[:len(modelPath)-len(ext)] + ".metadata.json" 166 | isFile, err := fs.IsFile(metadataFilePath) 167 | if err != nil { 168 | return mt, fmt.Errorf("failed to load XG model metadata: %w", err) 169 | } 170 | if !isFile { 171 | log.Warn().Msg("Cannot load XG model metadata - no file found. For inference, this doesn't matter.") 172 | return mt, nil 173 | } 174 | data, err := os.ReadFile(metadataFilePath) 175 | if err != nil { 176 | return mt, fmt.Errorf("failed to load XG model metadata: %w", err) 177 | } 178 | if err := json.Unmarshal(data, &mt); err != nil { 179 | return mt, fmt.Errorf("failed to load XG model metadata: %w", err) 180 | } 181 | return mt, nil 182 | } 183 | 184 | func LoadFromFile(filePath string) (*Model, error) { 185 | file, err := os.Open(filePath) 186 | if err != nil { 187 | return nil, fmt.Errorf("failed to open file: %w", err) 188 | } 189 | defer file.Close() 190 | 191 | var reader io.Reader = file 192 | if strings.HasSuffix(filePath, ".gz") || strings.HasSuffix(filePath, ".gzip") { 193 | gzReader, err := gzip.NewReader(file) 194 | if err != nil { 195 | return nil, fmt.Errorf("failed to create gzip reader: %w", err) 196 | } 197 | defer gzReader.Close() 198 | reader = gzReader 199 | } 200 | 201 | model, err := leaves.LGEnsembleFromReader(bufio.NewReader(reader), true) 202 | if err != nil { 203 | return nil, fmt.Errorf("failed to load XG model: %w", err) 204 | } 205 | metadata, err := loadMetadata(filePath) 206 | if err != nil { 207 | return nil, fmt.Errorf("failed to load XG model: %w", err) 208 | } 209 | return &Model{xgboost: model, metadata: metadata}, nil 210 | } 211 | 212 | func NewModel() *Model { 213 | return &Model{ 214 | ClassThreshold: 0.5, 215 | } 216 | } 217 | -------------------------------------------------------------------------------- /eval/rf/model.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
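Note on eval/xg/xgboost.go (above): LoadFromFile transparently gunzips files ending in .gz or .gzip, and loadMetadata strips the .gz suffix and then the remaining extension before looking for the sibling "<base>.metadata.json" (e.g. model.v3.19.xg.batch1.txt.gz pairs with model.v3.19.xg.batch1.metadata.json in testdata). A small sketch compressing a LightGBM text model, as written by scripts/learnxgb.py, so it can still be loaded this way; file names are illustrative.

    package main

    import (
        "compress/gzip"
        "io"
        "log"
        "os"
    )

    func main() {
        src, err := os.Open("model.txt")
        if err != nil {
            log.Fatal(err)
        }
        defer src.Close()

        dst, err := os.Create("model.txt.gz") // metadata stays in model.metadata.json next to it
        if err != nil {
            log.Fatal(err)
        }
        defer dst.Close()

        zw := gzip.NewWriter(dst)
        if _, err := io.Copy(zw, src); err != nil {
            log.Fatal(err)
        }
        if err := zw.Close(); err != nil {
            log.Fatal(err)
        }
    }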
16 | 17 | package rf 18 | 19 | import ( 20 | "compress/gzip" 21 | "context" 22 | "encoding/json" 23 | "fmt" 24 | "io" 25 | "os" 26 | "strings" 27 | 28 | "github.com/czcorpus/cqlizer/eval/feats" 29 | "github.com/czcorpus/cqlizer/eval/modutils" 30 | "github.com/czcorpus/cqlizer/eval/predict" 31 | randomforest "github.com/malaschitz/randomForest" 32 | "github.com/rs/zerolog/log" 33 | ) 34 | 35 | type jsonizedRFModel struct { 36 | Forest json.RawMessage `json:"forest"` 37 | Comment string `json:"comment"` 38 | slowQueriesThresholdTime float64 `json:"slowQueriesThresholdTime"` 39 | } 40 | 41 | // Model wraps a Random Forest classifier for regression via quantile binning 42 | type Model struct { 43 | Forest *randomforest.Forest `json:"forest"` 44 | NumTrees int `json:"numTrees"` 45 | VotingThreshold float64 `json:"votingThreshold"` 46 | SlowQueriesThresholdTime float64 `json:"slowQueriesThresholdTime"` 47 | Comment string `json:"comment"` 48 | } 49 | 50 | // NewModel creates a new Random Forest model with time binning 51 | func NewModel(numTrees int, votingThreshold float64) *Model { 52 | return &Model{ 53 | Forest: &randomforest.Forest{}, 54 | NumTrees: numTrees, 55 | VotingThreshold: votingThreshold, 56 | } 57 | } 58 | 59 | func (m *Model) IsInferenceOnly() bool { 60 | return false 61 | } 62 | 63 | func (m *Model) CreateModelFileName(featsFile string) string { 64 | return modutils.ExtractModelNameBaseFromFeatFile(featsFile) + ".model.rf.json" 65 | } 66 | 67 | func (m *Model) GetClassThreshold() float64 { 68 | return m.VotingThreshold 69 | } 70 | 71 | func (m *Model) SetClassThreshold(v float64) { 72 | m.VotingThreshold = v 73 | } 74 | 75 | func (m *Model) GetSlowQueriesThresholdTime() float64 { 76 | return m.SlowQueriesThresholdTime 77 | } 78 | 79 | func (m *Model) GetInfo() string { 80 | return fmt.Sprintf("RF model, num. trees: %d, slow q. threshold time: %.2fs", m.NumTrees, m.SlowQueriesThresholdTime) 81 | } 82 | 83 | // Train trains the random forest on query evaluations and actual times 84 | // note: the `comment` argument will be stored with the model for easier model review 85 | func (m *Model) Train(ctx context.Context, data []feats.QueryEvaluation, slowQueriesThresholdTime float64, comment string) error { 86 | if len(data) == 0 { 87 | return fmt.Errorf("no training data provided") 88 | } 89 | if slowQueriesThresholdTime <= 0 { 90 | return fmt.Errorf("failed to train RF model - invalid value of SlowQueriesThresholdTime") 91 | } 92 | m.SlowQueriesThresholdTime = slowQueriesThresholdTime 93 | if m.NumTrees <= 0 { 94 | return fmt.Errorf("failed to train RF model - invalid value of NumTrees") 95 | } 96 | var xData [][]float64 97 | var yData []int 98 | numProblematic := 0 99 | for i, eval := range data { 100 | if i%100 == 0 && ctx != nil && ctx.Err() != nil { 101 | return ctx.Err() 102 | } 103 | features := feats.ExtractFeatures(eval) 104 | isPositive := 0 105 | if eval.ProcTime >= m.SlowQueriesThresholdTime { 106 | numProblematic++ 107 | isPositive = 1 108 | } 109 | xData = append(xData, features) 110 | yData = append(yData, isPositive) 111 | } 112 | log.Debug(). 113 | Int("numPositive", numProblematic). 114 | Int("dataSize", len(data)). 
115 | Msg("prepared training vectors") 116 | 117 | m.Forest.Data = randomforest.ForestData{ 118 | X: xData, 119 | Class: yData, 120 | } 121 | m.Forest.Train(m.NumTrees) 122 | m.Comment = comment 123 | return nil 124 | } 125 | 126 | // Predict estimates query execution time using the trained forest 127 | func (m *Model) Predict(eval feats.QueryEvaluation) predict.Prediction { 128 | features := feats.ExtractFeatures(eval) 129 | votes := m.Forest.Vote(features) 130 | var ans int 131 | if votes[1] > m.VotingThreshold { 132 | ans = 1 133 | } 134 | return predict.Prediction{ 135 | Votes: votes, 136 | PredictedClass: ans, 137 | } 138 | } 139 | 140 | // SaveToFile saves the RF model to a file 141 | func (m *Model) SaveToFile(filePath string) error { 142 | file, err := os.Create(filePath) 143 | if err != nil { 144 | return fmt.Errorf("failed to save RF model to a file: %w", err) 145 | } 146 | defer file.Close() 147 | 148 | tmpModel := jsonizedRFModel{ 149 | Comment: m.Comment, 150 | slowQueriesThresholdTime: m.SlowQueriesThresholdTime, 151 | } 152 | 153 | bytes, err := json.Marshal(&m.Forest) 154 | if err != nil { 155 | return fmt.Errorf("failed to save RF model to a file: %w", err) 156 | } 157 | 158 | tmpModel.Forest = bytes 159 | 160 | bytes, err = json.Marshal(tmpModel) 161 | if err != nil { 162 | return fmt.Errorf("failed to save RF model to a file: %w", err) 163 | } 164 | _, err = file.Write(bytes) 165 | if err != nil { 166 | return fmt.Errorf("failed to save RF model to a file: %w", err) 167 | } 168 | return nil 169 | } 170 | 171 | // LoadFromFile loads model metadata from file 172 | // Note: This is a placeholder - the actual forest cannot be serialized/deserialized 173 | // with the current randomForest package 174 | func LoadFromFile(filePath string) (*Model, error) { 175 | file, err := os.Open(filePath) 176 | if err != nil { 177 | return nil, fmt.Errorf("failed to open file: %w", err) 178 | } 179 | defer file.Close() 180 | 181 | var reader io.Reader = file 182 | if strings.HasSuffix(filePath, ".gz") || strings.HasSuffix(filePath, ".gzip") { 183 | gzReader, err := gzip.NewReader(file) 184 | if err != nil { 185 | return nil, fmt.Errorf("failed to create gzip reader: %w", err) 186 | } 187 | defer gzReader.Close() 188 | reader = gzReader 189 | } 190 | 191 | var tmpModel jsonizedRFModel 192 | data, err := io.ReadAll(reader) 193 | if err != nil { 194 | return nil, fmt.Errorf("failed to load Random Forest model from file: %w", err) 195 | } 196 | if err := json.Unmarshal(data, &tmpModel); err != nil { 197 | return nil, fmt.Errorf("failed to load Random Forest model from file: %w", err) 198 | } 199 | 200 | model := &Model{ 201 | Comment: tmpModel.Comment, 202 | SlowQueriesThresholdTime: tmpModel.slowQueriesThresholdTime, 203 | } 204 | 205 | var forest randomforest.Forest 206 | if err := json.Unmarshal(tmpModel.Forest, &forest); err != nil { 207 | return nil, fmt.Errorf("failed to load Random Forest model from file: %w", err) 208 | } 209 | model.Forest = &forest 210 | model.NumTrees = forest.NTrees 211 | return model, nil 212 | } 213 | -------------------------------------------------------------------------------- /eval/nn/model.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance 
with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package nn 18 | 19 | import ( 20 | "compress/gzip" 21 | "context" 22 | "encoding/json" 23 | "fmt" 24 | "io" 25 | "os" 26 | "strings" 27 | 28 | "github.com/czcorpus/cqlizer/eval/feats" 29 | "github.com/czcorpus/cqlizer/eval/modutils" 30 | "github.com/czcorpus/cqlizer/eval/predict" 31 | "github.com/patrikeh/go-deep" 32 | "github.com/patrikeh/go-deep/training" 33 | "github.com/rs/zerolog/log" 34 | ) 35 | 36 | type FeatureStats struct { 37 | Min float64 38 | Max float64 39 | } 40 | 41 | var ( 42 | //networkLayout = []int{20, 14, 7, 1} 43 | networkLayout = []int{50, 15, 1} 44 | //networkLayout = []int{30, 10, 1} 45 | numEpochs = 800 46 | //learningRate = 0.001 47 | learningRate = 0.0005 48 | ) 49 | 50 | type jsonizedModel struct { 51 | NeuralNet *deep.Dump `json:"neuralNet"` 52 | DataRanges []FeatureStats `json:"dataRanges"` 53 | SlowQueriesThresholdTime float64 `json:"slowQueriesThresholdTime"` 54 | ClassThreshold float64 `json:"classThreshold"` 55 | } 56 | 57 | // Model is a neural-network based model for evaluating CQL queries. 58 | // It is rather experimental and does not perform as well as other 59 | // models here so it is not recommended for production use. 60 | type Model struct { 61 | NeuralNet *deep.Neural 62 | DataRanges []FeatureStats 63 | SlowQueriesThresholdTime float64 64 | ClassThreshold float64 65 | } 66 | 67 | func (m *Model) IsInferenceOnly() bool { 68 | return false 69 | } 70 | 71 | func (m *Model) CreateModelFileName(featsFile string) string { 72 | return modutils.ExtractModelNameBaseFromFeatFile(featsFile) + ".model.nn.json" 73 | } 74 | 75 | func (m *Model) GetClassThreshold() float64 { 76 | return m.ClassThreshold 77 | } 78 | 79 | func (m *Model) SetClassThreshold(v float64) { 80 | m.ClassThreshold = v 81 | } 82 | 83 | func (m *Model) GetSlowQueriesThresholdTime() float64 { 84 | return m.SlowQueriesThresholdTime 85 | } 86 | 87 | func (m *Model) GetInfo() string { 88 | return fmt.Sprintf("NN model, layout: #%v, epochs: %d, slow q. threshold time: %.2fs", networkLayout, numEpochs, m.SlowQueriesThresholdTime) 89 | } 90 | 91 | // Train 92 | // TODO: comment is not stored 93 | func (m *Model) Train(ctx context.Context, data []feats.QueryEvaluation, slowQueriesTime float64, comment string) error { 94 | if len(data) == 0 { 95 | return fmt.Errorf("no training data provided") 96 | } 97 | if slowQueriesTime <= 0 { 98 | return fmt.Errorf("failed to train RF model - invalid value of SlowQueriesThresholdTime") 99 | } 100 | m.SlowQueriesThresholdTime = slowQueriesTime 101 | var featData = training.Examples{} 102 | //numTotal := len(dataModel.Evaluations) 103 | numProblematic := 0 104 | data2 := make([]feats.QueryEvaluation, len(data), len(data)*4) 105 | copy(data2, data) 106 | data2 = append(data2, data...) 107 | data2 = append(data2, data...) 108 | data2 = append(data2, data...) 
109 | data = data2 110 | for _, eval := range data { 111 | features := feats.ExtractFeatures(eval) 112 | response := 0.0 113 | if eval.ProcTime >= m.SlowQueriesThresholdTime { 114 | numProblematic++ 115 | response = 1.0 116 | } 117 | featData = append( 118 | featData, 119 | training.Example{ 120 | Input: features, 121 | Response: []float64{response}, 122 | }, 123 | ) 124 | } 125 | log.Debug(). 126 | Int("numPositive", numProblematic). 127 | Int("dataSize", len(data)). 128 | Msg("prepared training vectors") 129 | 130 | m.DataRanges = m.getDataStats(featData) 131 | for _, item := range featData { 132 | m.normalizeNNFeats(item.Input) 133 | } 134 | 135 | // TODO !!!!!! we use the same training and heldout data !!! 136 | // trn, heldout := featData, featData 137 | trn, heldout := featData.Split(0.5) 138 | 139 | fmt.Printf("STATS: >>> %#v\n", m.DataRanges) 140 | 141 | //for _, item := range heldout { 142 | //m.normalizeNNFeats(item) 143 | //} 144 | 145 | m.NeuralNet = deep.NewNeural(&deep.Config{ 146 | Inputs: 50, 147 | Layout: networkLayout, 148 | Activation: deep.ActivationReLU, 149 | Mode: deep.ModeBinary, 150 | Weight: deep.NewUniform(1.0, 0.0), 151 | Bias: true, 152 | }) 153 | 154 | //optimizer := training.NewSGD(0.05, 0.4, 1e-5, true) 155 | optimizer := training.NewAdam(learningRate, 0.9, 0.999, 1e-8) 156 | // params: optimizer, verbosity (print stats at every 50th iteration) 157 | trainer := training.NewTrainer(optimizer, 50) 158 | trainer.TrainContext(ctx, m.NeuralNet, trn, heldout, numEpochs) 159 | return nil 160 | } 161 | 162 | func (m *Model) getDataStats(data training.Examples) []FeatureStats { 163 | stats := make([]FeatureStats, feats.NumFeatures) 164 | for _, item := range data { 165 | for i := 0; i < len(item.Input); i++ { 166 | if item.Input[i] > stats[i].Max { 167 | stats[i].Max = item.Input[i] 168 | } 169 | if item.Input[i] < stats[i].Min { 170 | stats[i].Min = item.Input[i] 171 | } 172 | } 173 | } 174 | return stats 175 | } 176 | 177 | func (m *Model) normalizeNNFeats(data []float64) { 178 | for i := 0; i < feats.NumFeatures; i++ { 179 | min := m.DataRanges[i].Min 180 | max := m.DataRanges[i].Max 181 | 182 | if max == min { 183 | data[i] = 0.0 // constant feature 184 | 185 | } else { 186 | data[i] = (data[i] - min) / (max - min) 187 | } 188 | } 189 | } 190 | 191 | func (m *Model) Predict(eval feats.QueryEvaluation) predict.Prediction { 192 | features := feats.ExtractFeatures(eval) 193 | m.normalizeNNFeats(features) 194 | out := m.NeuralNet.Predict(features) 195 | var predClass int 196 | if out[0] >= m.ClassThreshold { 197 | predClass = 1 198 | } 199 | return predict.Prediction{ 200 | Votes: []float64{1 - out[0], out[0]}, 201 | PredictedClass: predClass, 202 | } 203 | } 204 | 205 | func (m *Model) SaveToFile(filePath string) error { 206 | file, err := os.Create(filePath) 207 | if err != nil { 208 | return fmt.Errorf("failed to save RF model to a file: %w", err) 209 | } 210 | defer file.Close() 211 | dmp := m.NeuralNet.Dump() 212 | tmpModel := jsonizedModel{ 213 | NeuralNet: dmp, 214 | DataRanges: m.DataRanges, 215 | SlowQueriesThresholdTime: m.SlowQueriesThresholdTime, 216 | ClassThreshold: m.ClassThreshold, 217 | } 218 | bytes, err := json.Marshal(tmpModel) 219 | if err != nil { 220 | return fmt.Errorf("failed to save NN to file: %w", err) 221 | } 222 | _, err = file.Write(bytes) 223 | if err != nil { 224 | return fmt.Errorf("failed to save NN model to a file: %w", err) 225 | } 226 | return nil 227 | } 228 | 229 | func LoadFromFile(filePath string) (*Model, error) { 230 | 
file, err := os.Open(filePath) 231 | if err != nil { 232 | return nil, fmt.Errorf("failed to open file: %w", err) 233 | } 234 | defer file.Close() 235 | 236 | var reader io.Reader = file 237 | if strings.HasSuffix(filePath, ".gz") || strings.HasSuffix(filePath, ".gzip") { 238 | gzReader, err := gzip.NewReader(file) 239 | if err != nil { 240 | return nil, fmt.Errorf("failed to create gzip reader: %w", err) 241 | } 242 | defer gzReader.Close() 243 | reader = gzReader 244 | } 245 | 246 | var model jsonizedModel 247 | data, err := io.ReadAll(reader) 248 | if err != nil { 249 | return nil, fmt.Errorf("failed to load Neural Network model from file %s: %w", filePath, err) 250 | } 251 | if err := json.Unmarshal(data, &model); err != nil { 252 | return nil, fmt.Errorf("failed to load Neural Network model from file %s: %w", filePath, err) 253 | } 254 | nn := deep.FromDump(model.NeuralNet) 255 | return &Model{ 256 | NeuralNet: nn, 257 | DataRanges: model.DataRanges, 258 | SlowQueriesThresholdTime: model.SlowQueriesThresholdTime, 259 | ClassThreshold: model.ClassThreshold, 260 | }, nil 261 | } 262 | 263 | func NewModel() *Model { 264 | return &Model{ 265 | ClassThreshold: 0.5, 266 | } 267 | } 268 | -------------------------------------------------------------------------------- /lognorm.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
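Note on eval/nn/model.go (above): normalizeNNFeats applies per-feature min-max scaling using the ranges collected from the training data, (x - Min) / (Max - Min), and forces constant features (Max == Min) to 0; the same stored DataRanges are reused at prediction time. A tiny worked example of the formula with illustrative values:

    package main

    import (
        "fmt"

        "github.com/czcorpus/cqlizer/eval/nn"
    )

    func main() {
        r := nn.FeatureStats{Min: 2, Max: 10}
        for _, x := range []float64{2, 6, 10} {
            fmt.Printf("%.1f -> %.2f\n", x, (x-r.Min)/(r.Max-r.Min)) // prints 0.00, 0.50, 1.00
        }
    }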
16 | 17 | package main 18 | 19 | import ( 20 | "context" 21 | "encoding/json" 22 | "fmt" 23 | "io" 24 | "net/http" 25 | "net/url" 26 | "os" 27 | "slices" 28 | "strings" 29 | "time" 30 | 31 | "github.com/czcorpus/cnc-gokit/collections" 32 | "github.com/czcorpus/cqlizer/cnf" 33 | "github.com/czcorpus/cqlizer/dataimport" 34 | "github.com/czcorpus/cqlizer/eval" 35 | "github.com/czcorpus/cqlizer/eval/feats" 36 | "github.com/rs/zerolog/log" 37 | ) 38 | 39 | func getMQueryCorpora(mqueryURL string) ([]string, error) { 40 | urlObj, err := url.Parse(mqueryURL) 41 | if err != nil { 42 | return []string{}, fmt.Errorf("cannot measure request: %w", err) 43 | } 44 | urlObj = urlObj.JoinPath("/corplist") 45 | resp, err := http.Get(urlObj.String()) 46 | if err != nil { 47 | return []string{}, fmt.Errorf("failed to fetch installed corpora from MQuery: %w", err) 48 | } 49 | var respObj corporaResp 50 | rawResp, err := io.ReadAll(resp.Body) 51 | if err != nil { 52 | return []string{}, fmt.Errorf("failed to fetch installed corpora from MQuery: %w", err) 53 | } 54 | if err := json.Unmarshal(rawResp, &respObj); err != nil { 55 | return []string{}, fmt.Errorf("failed to fetch installed corpora from MQuery: %w", err) 56 | } 57 | ans := make([]string, len(respObj.Corpora)) 58 | for i, rc := range respObj.Corpora { 59 | ans[i] = rc.ID 60 | } 61 | return ans, nil 62 | } 63 | 64 | type missingRecFixer struct { 65 | ctx context.Context 66 | mqueryURL string 67 | uniqEntries map[string]eval.QueryStatsRecord 68 | onlyAllowedCorpora []string 69 | batchSize int 70 | batchOffset int 71 | corporaProps map[string]feats.CorpusProps 72 | } 73 | 74 | func (fixer *missingRecFixer) ProcessEntry(entry eval.QueryStatsRecord) error { 75 | if !slices.Contains(fixer.onlyAllowedCorpora, entry.Corpus) { 76 | return nil 77 | } 78 | if entry.TimeProc == 0 || entry.CorpusSize == 0 { 79 | _, ok := fixer.uniqEntries[entry.UniqKey()] 80 | if !ok { 81 | fixer.uniqEntries[entry.UniqKey()] = entry 82 | } 83 | } 84 | return nil 85 | } 86 | 87 | func (fixer *missingRecFixer) SetStats(numProcessed, numFailed int) { 88 | 89 | } 90 | 91 | func (fixer *missingRecFixer) RunBenchmark() { 92 | procEntries := make([]eval.QueryStatsRecord, len(fixer.uniqEntries)) 93 | i := 0 94 | for _, entry := range fixer.uniqEntries { 95 | procEntries[i] = entry 96 | i++ 97 | } 98 | slices.SortFunc(procEntries, func(v1, v2 eval.QueryStatsRecord) int { 99 | return strings.Compare(v1.UniqKey(), v2.UniqKey()) 100 | }) 101 | var isIncompleteProc bool 102 | if fixer.batchSize == 0 { 103 | fixer.batchSize = len(procEntries) 104 | 105 | } else { 106 | if len(procEntries) > fixer.batchOffset+fixer.batchSize { 107 | isIncompleteProc = true 108 | } 109 | fixer.batchSize = min(len(procEntries), fixer.batchSize) 110 | } 111 | for _, entry := range procEntries[fixer.batchOffset:fixer.batchSize] { 112 | select { 113 | case <-fixer.ctx.Done(): 114 | return 115 | default: 116 | } 117 | 118 | if entry.CorpusSize == 0 && entry.Corpus != "" { // legacy records with just corpnames 119 | entry.CorpusSize = int64(fixer.corporaProps[entry.Corpus].Size) 120 | } 121 | if entry.TimeProc > 0 { 122 | continue 123 | // we are also dealing with records with just missing "CorpusSize" property which is fixable 124 | // without benchmarking 125 | } 126 | 127 | if entry.CorpusSize == 0 { 128 | log.Warn().Str("corpname", entry.Corpus).Str("q", entry.Query).Msg("entry ignored due to missing corpus size") 129 | continue 130 | } 131 | 132 | t0, err := fixer.measureRequest(fixer.ctx, fixer.mqueryURL, 
entry.Corpus, entry.GetCQL()) 133 | if err != nil { 134 | log.Error().Err(err).Msg("failed to perform benchmark query, skipping") 135 | continue 136 | } 137 | entry.TimeProc = t0 138 | entry.IsSynthetic = true 139 | data, err := json.Marshal(entry) 140 | if err != nil { 141 | log.Error().Err(err).Msg("failed to perform benchmark query, skipping") 142 | continue 143 | } 144 | fmt.Println(string(data)) 145 | } 146 | if isIncompleteProc { 147 | fmt.Fprintf( 148 | os.Stderr, 149 | "Finished the current batch (%d ... %d), More data are available (next offset: %d)", 150 | fixer.batchOffset, 151 | fixer.batchSize-1, 152 | fixer.batchOffset+fixer.batchSize, 153 | ) 154 | } else { 155 | fmt.Fprintf( 156 | os.Stderr, 157 | "Finished the current batch (%d ... %d). No more data are available.", 158 | fixer.batchOffset, 159 | fixer.batchSize-1, 160 | ) 161 | } 162 | } 163 | 164 | type corporaRespCorpus struct { 165 | ID string `json:"id"` 166 | } 167 | 168 | type corporaResp struct { 169 | Corpora []corporaRespCorpus `json:"corpora"` 170 | } 171 | 172 | func (fixer *missingRecFixer) measureRequest( 173 | ctx context.Context, 174 | mqueryURL, corpname, q string, 175 | ) (float64, error) { 176 | urlObj, err := url.Parse(mqueryURL) 177 | if err != nil { 178 | return -1, fmt.Errorf("cannot measure request: %w", err) 179 | } 180 | urlObj = urlObj.JoinPath(fmt.Sprintf("/concordance/%s", corpname)) 181 | args := make(url.Values) 182 | args.Add("q", q) 183 | urlObj.RawQuery = args.Encode() 184 | req, err := http.NewRequestWithContext( 185 | ctx, 186 | "GET", 187 | urlObj.String(), 188 | nil, 189 | ) 190 | if err != nil { 191 | return -1, fmt.Errorf("failed to perform MQuery search (corpus: %s, q: %s): %w", corpname, q, err) 192 | } 193 | t0 := time.Now() 194 | resp, err := http.DefaultClient.Do(req) 195 | if err != nil { 196 | return -1, fmt.Errorf("failed to perform MQuery search (corpus: %s, q: %s): %w", corpname, q, err) 197 | } 198 | if resp.StatusCode != 200 { 199 | return -1, fmt.Errorf("failed to perform MQuery search (corpus: %s, q: %s) - status %s", corpname, q, resp.Status) 200 | } 201 | return float64(time.Since(t0).Seconds()), nil 202 | } 203 | 204 | func runActionBenchmarkMissing( 205 | ctx context.Context, 206 | conf *cnf.Conf, 207 | srcPath string, 208 | onlyAllowedCorpora []string, 209 | batchSize int, 210 | batchOffset int, 211 | ) { 212 | if len(onlyAllowedCorpora) == 0 { 213 | mqCorpora, err := getMQueryCorpora(conf.MQueryBenchmarkingURL) 214 | if err != nil { 215 | log.Fatal().Err(err).Msg("failed to run benchmark action") 216 | return 217 | } 218 | onlyAllowedCorpora = mqCorpora 219 | } 220 | fmt.Fprintln(os.Stderr, "Only queries for the following corpora will be tested:") 221 | for _, v := range onlyAllowedCorpora { 222 | fmt.Fprintf(os.Stderr, "\t%s\n", v) 223 | } 224 | fixer := &missingRecFixer{ 225 | ctx: ctx, 226 | mqueryURL: conf.MQueryBenchmarkingURL, 227 | uniqEntries: make(map[string]eval.QueryStatsRecord), 228 | onlyAllowedCorpora: onlyAllowedCorpora, 229 | batchSize: batchSize, 230 | batchOffset: batchOffset, 231 | corporaProps: conf.CorporaProps, 232 | } 233 | dataimport.ReadStatsFile(ctx, srcPath, fixer) 234 | fmt.Fprintf(os.Stderr, "queries loaded and deduplicated, num processable queries: %d\n", len(fixer.uniqEntries)) 235 | fixer.RunBenchmark() 236 | } 237 | 238 | // ------------------------------------------------- 239 | 240 | type zeroRemover struct { 241 | ctx context.Context 242 | numProcessed int 243 | numZero int 244 | foundCorpora map[string]int 245 | } 246 | 247 | 
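// ProcessEntry re-emits records with a non-zero processing time to stdout as
// JSON lines; for zero-time records it only updates the counters used in the
// per-corpus summary printed at the end of the run.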
func (remover *zeroRemover) ProcessEntry(entry eval.QueryStatsRecord) error { 248 | if entry.TimeProc > 0 { 249 | data, err := json.Marshal(entry) 250 | if err != nil { 251 | return fmt.Errorf("failed to marshal eval.QueryStatsRecord record: %w", err) 252 | } 253 | fmt.Println(string(data)) 254 | remover.numProcessed++ 255 | 256 | } else { 257 | remover.numZero++ 258 | remover.foundCorpora[entry.Corpus]++ 259 | } 260 | return nil 261 | } 262 | 263 | func (remover *zeroRemover) SetStats(numProcessed, numFailed int) { 264 | 265 | } 266 | 267 | type corpAndSize struct { 268 | c string 269 | s int 270 | } 271 | 272 | func runActionRemoveZero( 273 | ctx context.Context, 274 | conf *cnf.Conf, 275 | srcPath string, 276 | ) { 277 | rm := &zeroRemover{ 278 | ctx: ctx, 279 | foundCorpora: make(map[string]int), 280 | } 281 | 282 | var mqueryCorpora []string 283 | var err error 284 | if conf.MQueryBenchmarkingURL != "" { 285 | mqueryCorpora, err = getMQueryCorpora(conf.MQueryBenchmarkingURL) 286 | if err != nil { 287 | log.Error().Err(err).Msg("Failed to get MQuery corpora, skipping the feature") 288 | } 289 | } 290 | 291 | dataimport.ReadStatsFile(ctx, srcPath, rm) 292 | corpora := collections.MapToEntriesSorted( 293 | rm.foundCorpora, 294 | func(a, b collections.MapEntry[string, int]) int { 295 | return b.V - a.V 296 | }, 297 | ) 298 | corpora2 := make([]collections.MapEntry[string, int], 0, len(corpora)) 299 | for _, corp := range corpora { 300 | if slices.ContainsFunc(mqueryCorpora, func(v string) bool { 301 | return v == corp.K 302 | }) || len(mqueryCorpora) == 0 { 303 | corpora2 = append(corpora2, corp) 304 | } 305 | } 306 | fmt.Fprintln(os.Stderr, "\nCorpora with benchmarkable zero time requests:") 307 | if conf.MQueryBenchmarkingURL == "" { 308 | fmt.Fprintln(os.Stderr, "(without MQuery check - it is not known which corpora are installed for benchmarking)") 309 | } 310 | for _, entry := range corpora2 { 311 | fmt.Fprintf(os.Stderr, "\t%s: %d\n", entry.K, entry.V) 312 | } 313 | log.Info(). 314 | Int("numProcessed", rm.numProcessed). 315 | Int("numZero", rm.numZero). 316 | Msg("removed zero entries") 317 | } 318 | -------------------------------------------------------------------------------- /eval/feats/query_evaluation.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2024 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
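// The functions below turn a raw CQL query into a fixed-size feature vector.
// A minimal usage sketch (hypothetical caller; it assumes a
// charProbabilityProvider implementation named charProbs obtained elsewhere in
// the package, not shown in this file):
//
//	evaluation, err := NewQueryEvaluation(`[lemma="dog.*"]`, 1.2e8, 0, 0, charProbs)
//	if err != nil {
//		// handle CQL parse error
//	}
//	vec := ExtractFeatures(evaluation) // NumFeatures (= 50) values, the last one being the bias term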
16 | 17 | package feats 18 | 19 | import ( 20 | "math" 21 | "slices" 22 | "strings" 23 | 24 | "github.com/czcorpus/cqlizer/cql" 25 | ) 26 | 27 | const ( 28 | MaxPositions = 4 29 | ) 30 | 31 | type CorpusProps struct { 32 | Size int `json:"size"` 33 | Lang string `json:"lang"` 34 | AltCorpus string `json:"altCorpus"` 35 | } 36 | 37 | func logScaled(v float64) float64 { 38 | if v >= 1 { 39 | return math.Log(v) 40 | } 41 | return 0 42 | } 43 | 44 | // NewQueryEvaluation creates a QueryEvaluation from a CQL query string and corpus size 45 | func NewQueryEvaluation(cqlQuery string, corpusSize, namedSubcorpusSize, procTime float64, charProbs charProbabilityProvider) (QueryEvaluation, error) { 46 | query, err := cql.ParseCQL("", cqlQuery) 47 | if err != nil { 48 | return QueryEvaluation{}, err 49 | } 50 | 51 | eval := QueryEvaluation{ 52 | OrigQuery: cqlQuery, 53 | ProcTime: procTime, 54 | Positions: make([]Position, 0, MaxPositions), 55 | CorpusSize: logScaled(corpusSize), 56 | NamedSubcorpusSize: logScaled(namedSubcorpusSize), 57 | } 58 | 59 | // Extract features from the parsed query 60 | extractFeaturesFromQuery(query, &eval, charProbs) 61 | return eval, nil 62 | } 63 | 64 | // extractFeaturesFromQuery walks the AST and extracts relevant features 65 | func extractFeaturesFromQuery(query *cql.Query, eval *QueryEvaluation, charProbs charProbabilityProvider) { 66 | // First pass: collect all OnePosition nodes in order and extract their features 67 | if query.Sequence != nil { 68 | positionIndex := 0 69 | query.Sequence.ForEachElement(query.Sequence, func(parent, v cql.ASTNode) { 70 | switch typedNode := v.(type) { 71 | case *cql.Repetition: 72 | if positionIndex < MaxPositions { 73 | var pos Position 74 | if typedNode.IsAnyPosition() { 75 | pos.NumAlternatives = 1 76 | pos.Regexp.StartsWithWildCard = 1 77 | pos.Regexp.WildcardScore = 500 // TODO is this equivalent score to [attr=".*"] 78 | } 79 | pos.PosRepetition = typedNode.RepetitionScore() 80 | typedNode.ForEachElement(typedNode, func(parent, v2 cql.ASTNode) { 81 | switch typedNode2 := v2.(type) { 82 | case *cql.OnePosition: 83 | extractPositionFeatures(typedNode2, charProbs, &pos) 84 | pos.Index = positionIndex 85 | eval.Positions = append(eval.Positions, pos) 86 | positionIndex++ 87 | } 88 | }) 89 | } 90 | } 91 | }) 92 | } 93 | 94 | // Second pass: extract global features from entire query 95 | query.ForEachElement(func(parent, v cql.ASTNode) { 96 | switch typedNode := v.(type) { 97 | case *cql.GlobPart: 98 | eval.NumGlobConditions++ 99 | 100 | case *cql.WithinOrContaining: 101 | if typedNode.NumWithinParts() > 0 { 102 | eval.ContainsWithin = 1 103 | } 104 | if typedNode.NumContainingParts() > 0 { 105 | eval.ContainsContaining = 1 106 | } 107 | 108 | typedNode.ForEachElement(typedNode, func(parent, v2 cql.ASTNode) { 109 | switch typedNode2 := v2.(type) { 110 | case *cql.Repetition: 111 | eval.AdhocSubcorpus += typedNode2.SubcorpusDefScore() 112 | } 113 | }) 114 | 115 | case *cql.MeetOp: 116 | eval.ContainsMeet = 1 117 | 118 | case *cql.UnionOp: 119 | eval.ContainsUnion = 1 120 | 121 | case *cql.AlignedPart: 122 | eval.AlignedPart = 1 123 | } 124 | }) 125 | } 126 | 127 | func textToProbs(v string, probsMap charProbabilityProvider) float64 { 128 | var ansProb float64 = 0 129 | var size int 130 | for _, c := range v { 131 | ansProb += probsMap.CharProbability(c) 132 | size++ 133 | } 134 | return ansProb / float64(size) 135 | } 136 | 137 | // extractPositionFeatures analyzes a position to extract all features including regexp and attribute 
info 138 | func extractPositionFeatures(pos *cql.OnePosition, charProbs charProbabilityProvider, outPos *Position) { 139 | 140 | // Check if this is an empty position query [] 141 | numAlternatives := 0 142 | // Traverse the position to find regexp patterns and attribute info 143 | // Using DFS-like approach to maintain proper parent-child context 144 | pos.ForEachElement(pos, func(parent, v cql.ASTNode) { 145 | switch typedNode := v.(type) { 146 | case *cql.RegExp: 147 | analyzeRegExp(typedNode, &outPos.Regexp, charProbs) 148 | 149 | case *cql.RgSimple: 150 | // Use the built-in method to count wildcards 151 | outPos.Regexp.WildcardScore += typedNode.WildcardScore() 152 | 153 | case *cql.RawString: 154 | // Simple string - count characters 155 | text := typedNode.Text() 156 | if len(text) > 2 { 157 | text = strings.Trim(text, `"`) 158 | outPos.Regexp.NumConcreteChars = float64(len(text) - 2) // -2 for quotes 159 | outPos.Regexp.AvgCharProb = textToProbs(text, charProbs) 160 | } 161 | 162 | case *cql.RgAlt: 163 | outPos.Regexp.CharClasses = typedNode.Score() 164 | 165 | case *cql.AttVal: 166 | // Check if this is a small cardinality attribute 167 | if isSmallCardinalityAttr(typedNode) { 168 | outPos.HasSmallCardAttr = 500 169 | } 170 | if !typedNode.IsRecursive() { 171 | numAlternatives++ 172 | } 173 | if typedNode.IsNegation() { 174 | outPos.HasNegation = 1 175 | } 176 | } 177 | }) 178 | 179 | if numAlternatives > 0 { 180 | outPos.NumAlternatives = numAlternatives 181 | 182 | } else { 183 | outPos.NumAlternatives = 1 // AUTO-FIX 184 | // TODO - this should be solved within the AST, 185 | // it is caused by direct regexp queries: "foo" 186 | } 187 | outPos.Regexp.NumConcreteChars /= float64(outPos.NumAlternatives) 188 | } 189 | 190 | // analyzeRegExp examines a RegExp node to extract features 191 | func analyzeRegExp(re *cql.RegExp, regexp *Regexp, charProbs charProbabilityProvider) { 192 | if len(re.RegExpRaw) == 0 { 193 | return 194 | } 195 | 196 | // Check if starts with wildcard 197 | firstRaw := re.RegExpRaw[0] 198 | if startsWithWildcard(firstRaw) { 199 | regexp.StartsWithWildCard = 1 200 | } 201 | 202 | // Count concrete chars and check for ranges 203 | var concreteChars int 204 | var avgCharProb float64 205 | hasRange := false 206 | 207 | for _, raw := range re.RegExpRaw { 208 | raw.ForEachElement(raw, func(parent, v cql.ASTNode) { 209 | switch typedNode := v.(type) { 210 | case *cql.RgRange: 211 | hasRange = true 212 | 213 | case *cql.RgChar: 214 | if typedNode.IsConstant() { 215 | concreteChars++ 216 | avgCharProb += textToProbs(typedNode.Text(), charProbs) 217 | } 218 | } 219 | }) 220 | } 221 | 222 | if concreteChars > 0 { 223 | regexp.NumConcreteChars += float64(concreteChars) 224 | avgCharProb /= float64(concreteChars) 225 | regexp.AvgCharProb += avgCharProb 226 | } 227 | if hasRange { 228 | regexp.HasRange = 1 229 | } 230 | } 231 | 232 | // startsWithWildcard checks if a RegExpRaw starts with a wildcard operator 233 | func startsWithWildcard(raw *cql.RegExpRaw) bool { 234 | if len(raw.Values) == 0 { 235 | return false 236 | } 237 | return strings.HasPrefix(raw.Text(), ".+") || strings.HasPrefix(raw.Text(), ".*") 238 | } 239 | 240 | // isSmallCardinalityAttr checks if an attribute has small cardinality 241 | // These are attributes like tag, pos, etc. 
that have few possible values 242 | func isSmallCardinalityAttr(attVal *cql.AttVal) bool { 243 | var attrName string 244 | 245 | // Extract attribute name from either variant 246 | if attVal.Variant1 != nil && attVal.Variant1.AttName != "" { 247 | attrName = strings.ToLower(attVal.Variant1.AttName.String()) 248 | 249 | } else if attVal.Variant2 != nil && attVal.Variant2.AttName != "" { 250 | attrName = strings.ToLower(attVal.Variant2.AttName.String()) 251 | 252 | } else { 253 | return false 254 | } 255 | 256 | // List of known small cardinality attributes 257 | // tag, pos are typical linguistic attributes with limited value sets 258 | smallCardAttrs := []string{"tag", "pos", "postag", "xpos", "upos", "deprel"} 259 | return slices.Contains(smallCardAttrs, attrName) 260 | } 261 | 262 | // ExtractFeatures converts QueryEvaluation to feature vector (same as Huber) 263 | func ExtractFeatures(eval QueryEvaluation) []float64 { 264 | features := make([]float64, NumFeatures) 265 | idx := 0 266 | 267 | // Extract features for up to 4 positions 268 | for i := range MaxPositions { 269 | if i < len(eval.Positions) { 270 | pos := eval.Positions[i] 271 | // Position-specific features (normalized by concrete chars) 272 | features[idx] = float64(pos.Regexp.StartsWithWildCard) 273 | features[idx+1] = pos.Regexp.WildcardScore 274 | features[idx+2] = float64(pos.Regexp.HasRange) 275 | features[idx+3] = float64(pos.HasSmallCardAttr) 276 | features[idx+4] = float64(pos.Regexp.NumConcreteChars) 277 | features[idx+5] = pos.Regexp.AvgCharProb 278 | features[idx+6] = float64(pos.NumAlternatives) 279 | features[idx+7] = pos.PosRepetition 280 | features[idx+8] = pos.Regexp.CharClasses 281 | features[idx+9] = float64(pos.HasNegation) 282 | } 283 | // If position doesn't exist, features remain 0 284 | idx += 10 285 | } 286 | 287 | // Global features 288 | features[40] = float64(eval.NumGlobConditions) 289 | features[41] = float64(eval.ContainsMeet) 290 | features[42] = float64(eval.ContainsUnion) 291 | features[43] = float64(eval.ContainsWithin) 292 | features[44] = eval.AdhocSubcorpus 293 | features[45] = float64(eval.ContainsContaining) 294 | features[46] = logScaled(eval.CorpusSize) 295 | features[47] = logScaled(eval.NamedSubcorpusSize) 296 | features[48] = float64(eval.AlignedPart) 297 | features[49] = 1.0 // Bias term 298 | 299 | return features 300 | } 301 | -------------------------------------------------------------------------------- /cqlizer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2024 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
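// A typical workflow, as suggested by the usage strings defined below
// (file names are placeholders):
//
//	cqlizer featurize config.json logfile.txt
//	cqlizer learn config.json features_file.msgpack
//	cqlizer evaluate config.json model_file testing_data
//	cqlizer server config.json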
16 | 17 | //go:generate pigeon -o ./cql/grammar.go ./cql/grammar.peg 18 | 19 | package main 20 | 21 | import ( 22 | "context" 23 | _ "embed" 24 | "flag" 25 | "fmt" 26 | "os" 27 | "os/signal" 28 | "path/filepath" 29 | "strings" 30 | "syscall" 31 | 32 | "github.com/czcorpus/cnc-gokit/logging" 33 | "github.com/czcorpus/cqlizer/apiserver" 34 | "github.com/czcorpus/cqlizer/cnf" 35 | ) 36 | 37 | const ( 38 | actionMCPServer = "mcp-server" 39 | actionREPL = "repl" 40 | actionVersion = "version" 41 | actionHelp = "help" 42 | actionLearn = "learn" 43 | actionFeaturize = "featurize" 44 | actionEvaluate = "evaluate" 45 | actionBenchmarkMissing = "benchmark-missing" 46 | actionRemoveZero = "remove-zero" 47 | actionAPIServer = "server" 48 | 49 | exitErrorGeneralFailure = iota 50 | exitErrorImportFailed 51 | exiterrrorREPLReading 52 | exitErrorFailedToOpenIdex 53 | exitErrorFailedToOpenQueryPersistence 54 | exitErrorFailedToOpenW2VModel 55 | ) 56 | 57 | var ( 58 | version string 59 | buildDate string 60 | gitCommit string 61 | ) 62 | 63 | //go:embed scripts/rfchart.py 64 | var rfChartScript string 65 | 66 | // --------------------------------------------- 67 | 68 | func topLevelUsage() { 69 | fmt.Fprintf(os.Stderr, "CQLIZER - a data-driven CQL writing helper tool\n") 70 | fmt.Fprintf(os.Stderr, "-----------------------------\n\n") 71 | fmt.Fprintf(os.Stderr, "Commands:\n") 72 | fmt.Fprintf(os.Stderr, "\t%s\t\t\tshow version info\n", actionVersion) 73 | fmt.Fprintf(os.Stderr, "\t%s\t\t\thelp on a specific action (cqlizer help ACTION)\n", actionHelp) 74 | fmt.Fprintf(os.Stderr, "\t%s\t\t\trun API server providing CQL evaluation functions\n", actionAPIServer) 75 | fmt.Fprintf(os.Stderr, "\t%s\t\ttransform query log into features\n", actionFeaturize) 76 | fmt.Fprintf(os.Stderr, "\t%s\t\tremove zero processing time items from a log\n", actionRemoveZero) 77 | fmt.Fprintf(os.Stderr, "\t%s\t\t\tlearn model based on provided features\n", actionLearn) 78 | fmt.Fprintf(os.Stderr, "\t%s\t\t\tevaluate model (precision, recall, f-beta) using provided data\n", actionLearn) 79 | fmt.Fprintf(os.Stderr, "\t%s\tbenchmark queries with zero processing time (using MQuery)\n", actionBenchmarkMissing) 80 | fmt.Fprintf(os.Stderr, "\t%s\t\t\tREPL for CQL evaluation\n", actionREPL) 81 | fmt.Fprintf(os.Stderr, "\t%s\t\tmcp-server MCP (experimental/unfinished) \n", actionMCPServer) 82 | fmt.Fprintf(os.Stderr, "\nUse `cqlizer help ACTION` for information about a specific action\n\n") 83 | } 84 | 85 | func setup(confPath string) *cnf.Conf { 86 | conf := cnf.LoadConfig(confPath) 87 | if conf.Logging.Level == "" { 88 | conf.Logging.Level = "info" 89 | } 90 | logging.SetupLogging(conf.Logging) 91 | cnf.ValidateAndDefaults(conf) 92 | return conf 93 | } 94 | 95 | func cleanVersionInfo(v string) string { 96 | return strings.TrimLeft(strings.Trim(v, "'"), "v") 97 | } 98 | 99 | func runActionMCPServer() { 100 | 101 | } 102 | 103 | func runActionVersion(ver apiserver.VersionInfo) { 104 | fmt.Fprintln(os.Stderr, "CQLizer version: ", ver) 105 | } 106 | 107 | func main() { 108 | version := apiserver.VersionInfo{ 109 | Version: cleanVersionInfo(version), 110 | BuildDate: cleanVersionInfo(buildDate), 111 | GitCommit: cleanVersionInfo(gitCommit), 112 | } 113 | 114 | cmdMCP := flag.NewFlagSet(actionMCPServer, flag.ExitOnError) 115 | cmdMCP.Usage = func() { 116 | fmt.Fprintf( 117 | os.Stderr, 118 | "Usage:\t%s %s [options] config.json\n\t", 119 | filepath.Base(os.Args[0]), actionMCPServer) 120 | fmt.Fprintf(os.Stderr, "\nOptions:\n") 121 | 
cmdMCP.PrintDefaults() 122 | fmt.Fprintf(os.Stderr, "\nSrun CQLizer as a MCP server\n") 123 | } 124 | 125 | cmdVersion := flag.NewFlagSet(actionVersion, flag.ExitOnError) 126 | cmdVersion.Usage = func() { 127 | cmdVersion.PrintDefaults() 128 | // TOOD 129 | } 130 | 131 | cmdHelp := flag.NewFlagSet(actionHelp, flag.ExitOnError) 132 | cmdHelp.Usage = func() { 133 | cmdVersion.PrintDefaults() 134 | } 135 | 136 | cmdREPL := flag.NewFlagSet(actionREPL, flag.ExitOnError) 137 | replModel := cmdREPL.String("model", "rf", "Specifies model which will be used (xg, rf, nn)") 138 | cmdREPL.Usage = func() { 139 | cmdREPL.PrintDefaults() 140 | } 141 | 142 | cmdKlogImport := flag.NewFlagSet(actionLearn, flag.ExitOnError) 143 | numTrees := cmdKlogImport.Int("num-trees", 100, "Number of trees for Random Forest (default: 100)") 144 | klogImportModel := cmdKlogImport.String("model", "rf", "Specifies model which will be used (xg, rf, nn)") 145 | voteThreshold := cmdKlogImport.Float64("vote-threshold", 0, "RF Vote threshold for marking CQL as problematic. This affects only evaluation. If none, then range from 0.7 to 0.99 is examined") 146 | klogImportMisclassOut := cmdKlogImport.String("misclassed-query-log", "", "Specify a path to store misclassified queries. If none, no logging is performed.") 147 | 148 | cmdKlogImport.Usage = func() { 149 | fmt.Fprintf(os.Stderr, "Usage: %s learn [options] config.json features_file.msgpack\n", os.Args[0]) 150 | fmt.Fprintf(os.Stderr, "\nOptions:\n") 151 | cmdKlogImport.PrintDefaults() 152 | } 153 | 154 | cmdEvaluate := flag.NewFlagSet(actionEvaluate, flag.ExitOnError) 155 | cmdEvaluateModel := cmdEvaluate.String("model", "rf", "Specifies model which will be used (xg, rf, nn)") 156 | cmdEvaluateMisclassOut := cmdEvaluate.String("misclassed-query-log", "", "Specify a path to store misclassified queries. If none, no logging is performed.") 157 | cmdEvaluate.Usage = func() { 158 | fmt.Fprintf(os.Stderr, "Usage: %s evaluate [options] config.json model_file testing_data \n", os.Args[0]) 159 | fmt.Fprintf(os.Stderr, "\nOptions:\n") 160 | cmdEvaluate.PrintDefaults() 161 | } 162 | 163 | cmdFeaturize := flag.NewFlagSet(actionFeaturize, flag.ExitOnError) 164 | featurizeDebug := cmdFeaturize.Bool( 165 | "debug", 166 | false, 167 | "if set then features will be written to stdout in human readable form and no feats file will be created", 168 | ) 169 | cmdFeaturize.Usage = func() { 170 | fmt.Fprintf(os.Stderr, "Usage: %s featurize [options] config.json logfile.txt\n", os.Args[0]) 171 | fmt.Fprintf(os.Stderr, "\nOptions:\n") 172 | cmdFeaturize.PrintDefaults() 173 | } 174 | 175 | cmdBenchmarkMissing := flag.NewFlagSet(actionBenchmarkMissing, flag.ExitOnError) 176 | benchmarkSpecCorpora := cmdBenchmarkMissing.String("corpora", "", "A forced list of comma-separated corpora to process, everything else ignored. If not set, all the corpora found in MQuery will be used.") 177 | benchmarkBatchSize := cmdBenchmarkMissing.Int("batch-size", 0, "Max. 
number of items to process at once") 178 | benchmarkBatchOffset := cmdBenchmarkMissing.Int("batch-offset", 0, "Where (in the sorted list of entries; zero indexed) to start with the current run.") 179 | cmdBenchmarkMissing.Usage = func() { 180 | fmt.Fprintf(os.Stderr, "Usage: %s benchmark-missing logfile.jsonl mquery_url\n", os.Args[0]) 181 | cmdBenchmarkMissing.PrintDefaults() 182 | } 183 | 184 | cmdRemoveZero := flag.NewFlagSet(actionRemoveZero, flag.ExitOnError) 185 | cmdRemoveZero.Usage = func() { 186 | fmt.Fprintf(os.Stderr, "Usage: %s remove-zero logfile.jsonl\n", os.Args[0]) 187 | cmdRemoveZero.PrintDefaults() 188 | } 189 | 190 | cmdAPIServer := flag.NewFlagSet(actionAPIServer, flag.ExitOnError) 191 | cmdAPIServer.Usage = func() { 192 | fmt.Fprintf(os.Stderr, "Usage: %s server [options] config.json\n", os.Args[0]) 193 | fmt.Fprintf(os.Stderr, "\nOptions:\n") 194 | cmdAPIServer.PrintDefaults() 195 | } 196 | 197 | action := actionHelp 198 | if len(os.Args) > 1 { 199 | action = os.Args[1] 200 | } 201 | 202 | switch action { 203 | case actionHelp: 204 | var subj string 205 | if len(os.Args) > 2 { 206 | cmdHelp.Parse(os.Args[2:]) 207 | subj = cmdHelp.Arg(0) 208 | } 209 | if subj == "" { 210 | topLevelUsage() 211 | return 212 | } 213 | switch subj { 214 | case actionLearn: 215 | cmdKlogImport.PrintDefaults() 216 | case actionMCPServer: 217 | cmdMCP.PrintDefaults() 218 | case actionREPL: 219 | cmdREPL.PrintDefaults() 220 | } 221 | case actionVersion: 222 | cmdVersion.Parse(os.Args[2:]) 223 | runActionVersion(version) 224 | case actionMCPServer: 225 | cmdMCP.Parse(os.Args[2:]) 226 | runActionMCPServer() 227 | case actionREPL: 228 | cmdREPL.Parse(os.Args[2:]) 229 | if cmdREPL.NArg() < 1 { 230 | fmt.Fprintf(os.Stderr, "Error: model file path required\n") 231 | fmt.Fprintf(os.Stderr, "Usage: %s repl <model_file.json>\n", os.Args[0]) 232 | os.Exit(1) 233 | } 234 | modelPath := cmdREPL.Arg(0) 235 | runActionREPL(*replModel, modelPath) 236 | case actionLearn: 237 | cmdKlogImport.Parse(os.Args[2:]) 238 | conf := setup(cmdKlogImport.Arg(0)) 239 | 240 | runActionKlogImport( 241 | conf, 242 | cmdKlogImport.Arg(1), 243 | *klogImportModel, 244 | *numTrees, 245 | *voteThreshold, 246 | *klogImportMisclassOut, 247 | ) 248 | case actionEvaluate: 249 | cmdEvaluate.Parse(os.Args[2:]) 250 | conf := setup(cmdEvaluate.Arg(0)) 251 | runActionEvaluate( 252 | conf, 253 | cmdEvaluate.Arg(1), 254 | *cmdEvaluateModel, 255 | cmdEvaluate.Arg(2), 256 | *cmdEvaluateMisclassOut, 257 | ) 258 | case actionFeaturize: 259 | cmdFeaturize.Parse(os.Args[2:]) 260 | conf := setup(cmdFeaturize.Arg(0)) 261 | ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 262 | defer stop() 263 | runActionFeaturize( 264 | ctx, 265 | conf, 266 | cmdFeaturize.Arg(1), 267 | cmdFeaturize.Arg(2), 268 | *featurizeDebug, 269 | ) 270 | case actionBenchmarkMissing: 271 | cmdBenchmarkMissing.Parse(os.Args[2:]) 272 | conf := setup(cmdBenchmarkMissing.Arg(0)) 273 | ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 274 | defer stop() 275 | corpora := []string{} 276 | if *benchmarkSpecCorpora != "" { 277 | corpora = strings.Split(*benchmarkSpecCorpora, ",") 278 | } 279 | runActionBenchmarkMissing( 280 | ctx, 281 | conf, 282 | cmdBenchmarkMissing.Arg(1), 283 | corpora, 284 | *benchmarkBatchSize, 285 | *benchmarkBatchOffset, 286 | ) 287 | case actionRemoveZero: 288 | cmdRemoveZero.Parse(os.Args[2:]) 289 | conf := setup(cmdRemoveZero.Arg(0)) 290 | ctx, stop := 
signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 291 | defer stop() 292 | runActionRemoveZero( 293 | ctx, 294 | conf, 295 | cmdRemoveZero.Arg(1), 296 | ) 297 | 298 | case actionAPIServer: 299 | cmdAPIServer.Parse(os.Args[2:]) 300 | conf := setup(cmdAPIServer.Arg(0)) 301 | ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 302 | defer stop() 303 | apiserver.Run(ctx, conf, version) 304 | default: 305 | fmt.Fprintf(os.Stderr, "Unknown action, please use 'help' to get more information") 306 | } 307 | 308 | } 309 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /eval/feats/score.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 
7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package feats 18 | 19 | import ( 20 | "encoding/json" 21 | "fmt" 22 | "math" 23 | "os" 24 | "strings" 25 | ) 26 | 27 | const NumFeatures = 50 28 | 29 | type CostProvider interface { 30 | Cost(model ModelParams) float64 31 | } 32 | 33 | // ----------------------------------- 34 | 35 | type ModelParams struct { 36 | WildcardPrefix0 float64 37 | Wildcards0 float64 38 | RangeOp0 float64 39 | 40 | // SmallCardAttr0 if 1, it means that we search by an attribute which 41 | // has only a few possible values and thus the resulting set will be large. 42 | // This typically applies for attributes/searches like [tag="..."], [pos="..."] 43 | // and we also consider the special `[]` query (= any word) as part of that. 44 | SmallCardAttr0 float64 45 | ConcreteChars0 float64 46 | AvgCharProb0 float64 47 | NumPosAlts0 float64 48 | PosRepetition0 float64 49 | CharClasses0 float64 50 | HasNegation0 float64 51 | 52 | WildcardPrefix1 float64 53 | Wildcards1 float64 54 | RangeOp1 float64 55 | SmallCardAttr1 float64 56 | ConcreteChars1 float64 57 | AvgCharProb1 float64 58 | NumPosAlts1 float64 59 | PosRepetition1 float64 60 | CharClasses1 float64 61 | HasNegation1 float64 62 | 63 | WildcardPrefix2 float64 64 | Wildcards2 float64 65 | RangeOp2 float64 66 | SmallCardAttr2 float64 67 | ConcreteChars2 float64 68 | AvgCharProb2 float64 69 | NumPosAlts2 float64 70 | PosRepetition2 float64 71 | CharClasses2 float64 72 | HasNegation2 float64 73 | 74 | WildcardPrefix3 float64 75 | Wildcards3 float64 76 | RangeOp3 float64 77 | SmallCardAttr3 float64 78 | ConcreteChars3 float64 79 | AvgCharProb3 float64 80 | NumPosAlts3 float64 81 | PosRepetition3 float64 82 | CharClasses3 float64 83 | HasNegation3 float64 84 | 85 | GlobCond float64 86 | Meet float64 87 | Union float64 88 | Within float64 89 | AdhocSubcorpus float64 90 | Containing float64 91 | CorpusSize float64 // Impact of corpus size on query time 92 | NamedSubcorpusSize float64 93 | AlignedPart float64 94 | Bias float64 95 | } 96 | 97 | func (p ModelParams) ToSlice() []float64 { 98 | return []float64{ 99 | p.WildcardPrefix0, 100 | p.Wildcards0, 101 | p.RangeOp0, 102 | p.SmallCardAttr0, 103 | p.ConcreteChars0, 104 | p.AvgCharProb0, 105 | p.NumPosAlts0, 106 | p.PosRepetition0, 107 | p.CharClasses0, 108 | p.HasNegation0, 109 | p.WildcardPrefix1, 110 | p.Wildcards1, 111 | p.RangeOp1, 112 | p.SmallCardAttr1, 113 | p.ConcreteChars1, 114 | p.AvgCharProb1, 115 | p.NumPosAlts1, 116 | p.PosRepetition1, 117 | p.CharClasses1, 118 | p.HasNegation1, 119 | p.WildcardPrefix2, 120 | p.Wildcards2, 121 | p.RangeOp2, 122 | p.SmallCardAttr2, 123 | p.ConcreteChars2, 124 | p.AvgCharProb2, 125 | p.NumPosAlts2, 126 | p.PosRepetition2, 127 | p.CharClasses2, 128 | p.HasNegation2, 129 | p.WildcardPrefix3, 130 | p.Wildcards3, 131 | p.RangeOp3, 132 | p.SmallCardAttr3, 133 | p.ConcreteChars3, 134 | p.AvgCharProb3, 135 | p.NumPosAlts3, 136 | p.PosRepetition3, 137 | p.CharClasses3, 138 | p.HasNegation3, 139 | p.GlobCond, 140 | p.Meet, 141 | p.Union, 142 | p.Within, 143 | p.AdhocSubcorpus, 144 | p.Containing, 145 | p.CorpusSize, 146 | 
p.NamedSubcorpusSize, 147 | p.AlignedPart, 148 | p.Bias, 149 | } 150 | } 151 | 152 | func SliceToModelParams(slice []float64) ModelParams { 153 | if len(slice) != NumFeatures { 154 | panic(fmt.Sprintf("slice must have %d elements", NumFeatures)) 155 | } 156 | return ModelParams{ 157 | WildcardPrefix0: slice[0], 158 | Wildcards0: slice[1], 159 | RangeOp0: slice[2], 160 | SmallCardAttr0: slice[3], 161 | ConcreteChars0: slice[4], 162 | AvgCharProb0: slice[5], 163 | NumPosAlts0: slice[6], 164 | PosRepetition0: slice[7], 165 | CharClasses0: slice[8], 166 | HasNegation0: slice[9], 167 | WildcardPrefix1: slice[10], 168 | Wildcards1: slice[11], 169 | RangeOp1: slice[12], 170 | SmallCardAttr1: slice[13], 171 | ConcreteChars1: slice[14], 172 | AvgCharProb1: slice[15], 173 | NumPosAlts1: slice[16], 174 | PosRepetition1: slice[17], 175 | CharClasses1: slice[18], 176 | HasNegation1: slice[19], 177 | WildcardPrefix2: slice[20], 178 | Wildcards2: slice[21], 179 | RangeOp2: slice[22], 180 | SmallCardAttr2: slice[23], 181 | ConcreteChars2: slice[24], 182 | AvgCharProb2: slice[25], 183 | NumPosAlts2: slice[26], 184 | PosRepetition2: slice[27], 185 | CharClasses2: slice[28], 186 | HasNegation2: slice[29], 187 | WildcardPrefix3: slice[30], 188 | Wildcards3: slice[31], 189 | RangeOp3: slice[32], 190 | SmallCardAttr3: slice[33], 191 | ConcreteChars3: slice[34], 192 | AvgCharProb3: slice[35], 193 | NumPosAlts3: slice[36], 194 | PosRepetition3: slice[37], 195 | CharClasses3: slice[38], 196 | HasNegation3: slice[39], 197 | GlobCond: slice[40], 198 | Meet: slice[41], 199 | Union: slice[42], 200 | Within: slice[43], 201 | AdhocSubcorpus: slice[44], 202 | Containing: slice[45], 203 | CorpusSize: slice[46], 204 | NamedSubcorpusSize: slice[47], 205 | AlignedPart: slice[48], 206 | Bias: slice[49], 207 | } 208 | } 209 | 210 | // SaveToFile saves the model parameters to a JSON file 211 | func (p ModelParams) SaveToFile(filePath string) error { 212 | file, err := os.Create(filePath) 213 | if err != nil { 214 | return fmt.Errorf("failed to create file: %w", err) 215 | } 216 | defer file.Close() 217 | 218 | encoder := json.NewEncoder(file) 219 | encoder.SetIndent("", " ") // Pretty-print with 2-space indentation 220 | if err := encoder.Encode(p); err != nil { 221 | return fmt.Errorf("failed to encode model: %w", err) 222 | } 223 | 224 | return nil 225 | } 226 | 227 | // LoadModelFromFile loads model parameters from a JSON file 228 | func LoadModelFromFile(filePath string) (ModelParams, error) { 229 | file, err := os.Open(filePath) 230 | if err != nil { 231 | return ModelParams{}, fmt.Errorf("failed to open file: %w", err) 232 | } 233 | defer file.Close() 234 | 235 | var params ModelParams 236 | decoder := json.NewDecoder(file) 237 | if err := decoder.Decode(¶ms); err != nil { 238 | return ModelParams{}, fmt.Errorf("failed to decode model: %w", err) 239 | } 240 | return params, nil 241 | } 242 | 243 | // ----------------------------------- 244 | 245 | type Regexp struct { 246 | StartsWithWildCard int `msgpack:"startsWithWildCard"` 247 | NumConcreteChars float64 `msgpack:"numConcreteChars"` 248 | AvgCharProb float64 `msgpack:"avgCharProb"` 249 | WildcardScore float64 `msgpack:"wildcardScore"` 250 | HasRange int `msgpack:"hasRange"` 251 | CharClasses float64 `msgpack:"charClasses"` 252 | } 253 | 254 | // ----------------------------------- 255 | 256 | type Position struct { 257 | Index int `msgpack:"index"` 258 | Regexp Regexp `msgpack:"regexp"` 259 | HasSmallCardAttr int `msgpack:"hasSmallCardAttr"` // 1 if searching by 
attribute with small cardinality (tag, pos, etc.) or empty query [] 260 | NumAlternatives int `msgpack:"numAlternatives"` // at least 1, solves situations like [lemma="foo" | word="fooish"] 261 | PosRepetition float64 `msgpack:"posRepetition"` // stuff like [word="foo"]+ 262 | HasNegation int `msgpack:"hasNegation"` 263 | } 264 | 265 | // ----------------------------------- 266 | 267 | type QueryEvaluation struct { 268 | ProcTime float64 `msgpack:"procTime"` 269 | 270 | OrigQuery string `msgpack:"q"` 271 | Positions []Position `msgpack:"positions"` 272 | NumGlobConditions int `msgpack:"numGlobConditions"` 273 | ContainsMeet int `msgpack:"containsMeet"` 274 | ContainsUnion int `msgpack:"containsUnion"` 275 | ContainsWithin int `msgpack:"containsWithin"` 276 | AdhocSubcorpus float64 `msgpack:"adhocSubcorpus"` 277 | ContainsContaining int `msgpack:"containsContaining"` 278 | CorpusSize float64 `msgpack:"corpusSize"` // Size of the corpus being searched (e.g., number of tokens) 279 | NamedSubcorpusSize float64 `msgpack:"namedSubcorpusSize"` 280 | AlignedPart int `msgpack:"alignedPart"` 281 | } 282 | 283 | func (eval QueryEvaluation) UniqKey() string { 284 | return fmt.Sprintf("%s-%.5f", eval.OrigQuery, eval.CorpusSize) 285 | } 286 | 287 | func (eval QueryEvaluation) Show() string { 288 | var ans strings.Builder 289 | for i, pos := range eval.Positions { 290 | ans.WriteString(fmt.Sprintf("position %d:\n", i)) 291 | ans.WriteString(fmt.Sprintf(" HasSmallCardAttr: %d\n", pos.HasSmallCardAttr)) 292 | ans.WriteString(fmt.Sprintf(" NumAlternatives: %d\n", pos.NumAlternatives)) 293 | ans.WriteString(fmt.Sprintf(" PosRepetition: %.2f\n", pos.PosRepetition)) 294 | ans.WriteString(fmt.Sprintf(" HasNegation: %d\n", pos.HasNegation)) 295 | ans.WriteString(" regexp: \n") 296 | ans.WriteString(fmt.Sprintf(" StartsWithWildCard: %d\n", pos.Regexp.StartsWithWildCard)) 297 | ans.WriteString(fmt.Sprintf(" NumConcreteChars: %.2f\n", pos.Regexp.NumConcreteChars)) 298 | ans.WriteString(fmt.Sprintf(" AvgCharProb: %.2f\n", pos.Regexp.AvgCharProb)) 299 | ans.WriteString(fmt.Sprintf(" WildcardScore: %.2f\n", pos.Regexp.WildcardScore)) 300 | ans.WriteString(fmt.Sprintf(" HasRange: %d\n", pos.Regexp.HasRange)) 301 | ans.WriteString(fmt.Sprintf(" CharClasses: %.2f\n", pos.Regexp.CharClasses)) 302 | } 303 | ans.WriteString(fmt.Sprintf("NumGlobConditions: %d\n", eval.NumGlobConditions)) 304 | ans.WriteString(fmt.Sprintf("ContainsMeet: %d\n", eval.ContainsMeet)) 305 | ans.WriteString(fmt.Sprintf("ContainsUnion: %d\n", eval.ContainsUnion)) 306 | ans.WriteString(fmt.Sprintf("ContainsWithin: %d\n", eval.ContainsWithin)) 307 | ans.WriteString(fmt.Sprintf("AdhocSubcorpus: %.2f\n", eval.AdhocSubcorpus)) 308 | ans.WriteString(fmt.Sprintf("ContainsContaining: %d\n", eval.ContainsContaining)) 309 | ans.WriteString(fmt.Sprintf("CorpusSize: %0.2f\n", eval.CorpusSize)) 310 | ans.WriteString(fmt.Sprintf("NamedSubcorpusSize: %0.2f\n", eval.NamedSubcorpusSize)) 311 | ans.WriteString(fmt.Sprintf("AlignedPart: %d\n", eval.AlignedPart)) 312 | 313 | return ans.String() 314 | } 315 | 316 | func (eval QueryEvaluation) Cost(model ModelParams) float64 { 317 | var total float64 318 | 319 | // Compute position-specific costs 320 | for i := 0; i < len(eval.Positions) && i < MaxPositions; i++ { 321 | pos := eval.Positions[i] 322 | // Get position-specific parameters 323 | var wildcardPrefix, wildcards, rangeOp, smallCardAttr, concreteChars, 324 | avgCharProb, numPosAlts, posRepetition, charClasses, hasNegation float64 325 | switch i { 326 | case 0: 
327 | wildcardPrefix = model.WildcardPrefix0 328 | wildcards = model.Wildcards0 329 | rangeOp = model.RangeOp0 330 | smallCardAttr = model.SmallCardAttr0 331 | concreteChars = model.ConcreteChars0 332 | avgCharProb = model.AvgCharProb0 333 | numPosAlts = model.NumPosAlts0 334 | posRepetition = model.PosRepetition0 335 | charClasses = model.CharClasses0 336 | hasNegation = model.HasNegation0 337 | case 1: 338 | wildcardPrefix = model.WildcardPrefix1 339 | wildcards = model.Wildcards1 340 | rangeOp = model.RangeOp1 341 | smallCardAttr = model.SmallCardAttr1 342 | concreteChars = model.ConcreteChars1 343 | avgCharProb = model.AvgCharProb1 344 | numPosAlts = model.NumPosAlts1 345 | posRepetition = model.PosRepetition1 346 | charClasses = model.CharClasses1 347 | hasNegation = model.HasNegation1 348 | case 2: 349 | wildcardPrefix = model.WildcardPrefix2 350 | wildcards = model.Wildcards2 351 | rangeOp = model.RangeOp2 352 | smallCardAttr = model.SmallCardAttr2 353 | concreteChars = model.ConcreteChars2 354 | avgCharProb = model.AvgCharProb2 355 | numPosAlts = model.NumPosAlts2 356 | posRepetition = model.PosRepetition2 357 | charClasses = model.CharClasses2 358 | hasNegation = model.HasNegation2 359 | case 3: 360 | wildcardPrefix = model.WildcardPrefix3 361 | wildcards = model.Wildcards3 362 | rangeOp = model.RangeOp3 363 | smallCardAttr = model.SmallCardAttr3 364 | concreteChars = model.ConcreteChars3 365 | avgCharProb = model.AvgCharProb3 366 | numPosAlts = model.NumPosAlts3 367 | posRepetition = model.PosRepetition3 368 | charClasses = model.CharClasses3 369 | hasNegation = model.HasNegation3 370 | } 371 | 372 | // Calculate position cost 373 | positionCost := (wildcardPrefix*float64(pos.Regexp.StartsWithWildCard) + 374 | wildcards*float64(pos.Regexp.WildcardScore) + 375 | rangeOp*float64(pos.Regexp.HasRange) + 376 | smallCardAttr*float64(pos.HasSmallCardAttr)) + 377 | concreteChars*float64(pos.Regexp.NumConcreteChars) + 378 | avgCharProb*float64(pos.Regexp.AvgCharProb) + 379 | numPosAlts*float64(pos.NumAlternatives) + 380 | posRepetition*pos.PosRepetition + 381 | charClasses*pos.Regexp.CharClasses + 382 | hasNegation*float64(pos.HasNegation) 383 | 384 | total += positionCost 385 | } 386 | 387 | // Add global costs 388 | total += model.GlobCond * float64(eval.NumGlobConditions) 389 | total += model.Meet * float64(eval.ContainsMeet) 390 | total += model.Union * float64(eval.ContainsUnion) 391 | total += model.Within * float64(eval.ContainsWithin) 392 | total += model.AdhocSubcorpus * float64(eval.AdhocSubcorpus) 393 | total += model.Containing * float64(eval.ContainsContaining) 394 | total += model.CorpusSize * math.Log(eval.CorpusSize) 395 | total += model.NamedSubcorpusSize * math.Log(eval.NamedSubcorpusSize) 396 | total += model.Bias 397 | 398 | return total 399 | } 400 | -------------------------------------------------------------------------------- /cql/rgsimple.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 
7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package cql 18 | 19 | import ( 20 | "encoding/json" 21 | "fmt" 22 | "strconv" 23 | "strings" 24 | ) 25 | 26 | type RgSimple struct { 27 | // RgRange / RgChar / RgAlt / RgPosixClass 28 | origValue string 29 | Values []ASTNode 30 | } 31 | 32 | func (r *RgSimple) Text() string { 33 | return r.origValue 34 | } 35 | 36 | func (r *RgSimple) WildcardScore() float64 { 37 | ans := 0.0 38 | r.ForEachElement(r, func(parent, item ASTNode) { 39 | switch tItem := item.(type) { 40 | case *RgChar: 41 | if tItem.Text() == "?" { 42 | ans += 1 43 | } 44 | } 45 | }) 46 | ans += float64(strings.Count(r.Text(), ".*")) * 20 47 | ans += float64(strings.Count(r.Text(), ".+")) * 20 48 | return ans 49 | } 50 | 51 | // ------------------------------------------------- 52 | 53 | type RgGrouped struct { 54 | Values []*RegExpRaw // the stuff here is A|B|C... 55 | } 56 | 57 | func (r *RgGrouped) Text() string { 58 | return "#RgGrouped" 59 | } 60 | 61 | func (r *RgGrouped) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 62 | fn(parent, r) 63 | for _, v := range r.Values { 64 | v.ForEachElement(r, fn) 65 | } 66 | } 67 | 68 | func (r *RgGrouped) DFS(fn func(v ASTNode)) { 69 | for _, v := range r.Values { 70 | v.DFS(fn) 71 | } 72 | fn(r) 73 | } 74 | 75 | // ---------------------------------------------------- 76 | 77 | type RgPosixClass struct { 78 | Value ASTString 79 | } 80 | 81 | func (r *RgPosixClass) Text() string { 82 | return "RgPosixClass" 83 | } 84 | 85 | func (r *RgPosixClass) MarshalJSON() ([]byte, error) { 86 | return json.Marshal(r.Value) 87 | } 88 | 89 | func (r *RgPosixClass) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 90 | fn(parent, r.Value) 91 | } 92 | 93 | func (r *RgPosixClass) DFS(fn func(v ASTNode)) { 94 | fn(r.Value) 95 | } 96 | 97 | // ---------------------------------------------------- 98 | 99 | type RgLook struct { 100 | Value ASTString 101 | } 102 | 103 | func (r *RgLook) Text() string { 104 | return "#RgLook" 105 | } 106 | 107 | func (r *RgLook) MarshalJSON() ([]byte, error) { 108 | return json.Marshal(struct { 109 | RuleName string 110 | Expansion RgLook 111 | }{ 112 | RuleName: "RgLook", 113 | Expansion: *r, 114 | }) 115 | } 116 | 117 | func (r *RgLook) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 118 | fn(parent, r.Value) 119 | } 120 | 121 | func (r *RgLook) DFS(fn func(v ASTNode)) { 122 | fn(r.Value) 123 | } 124 | 125 | // ---------------------------------------------------- 126 | 127 | type RgLookOperator struct { 128 | } 129 | 130 | // ----------------------------------------------------- 131 | 132 | type RgAlt struct { 133 | Values []*RgAltVal 134 | Not bool 135 | } 136 | 137 | func (r *RgAlt) NumItems() int { 138 | return len(r.Values) 139 | } 140 | 141 | func (r *RgAlt) Score() float64 { 142 | var ans float64 143 | for _, v := range r.Values { 144 | ans += v.SrchScore() 145 | } 146 | if r.Not { 147 | ans *= 5 // rough estimate 148 | } 149 | return ans 150 | } 151 | 152 | func (r *RgAlt) Text() string { 153 | var ans strings.Builder 154 | for i, v := range r.Values { 155 | if i > 0 { 
156 | ans.WriteString(", ") 157 | } 158 | ans.WriteString(v.Text()) 159 | } 160 | return fmt.Sprintf("#RgAlt(%s)", ans.String()) 161 | } 162 | 163 | func (r *RgAlt) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 164 | fn(parent, r) 165 | for _, item := range r.Values { 166 | item.ForEachElement(r, fn) 167 | } 168 | } 169 | 170 | func (r *RgAlt) DFS(fn func(v ASTNode)) { 171 | for _, item := range r.Values { 172 | item.DFS(fn) 173 | } 174 | fn(r) 175 | } 176 | 177 | // -------------------------------------------------------- 178 | 179 | type rgCharVariant1 struct { 180 | Value ASTString 181 | IsUnicodeClass bool 182 | } 183 | 184 | type rgCharVariant2 struct { 185 | RgOp *RgOp 186 | } 187 | 188 | type rgCharVariant3 struct { 189 | RgRepeat *RgRepeat 190 | } 191 | 192 | type rgCharVariant4 struct { 193 | RgAny *RgAny 194 | } 195 | 196 | type rgCharVariant5 struct { 197 | RgQM *RgQM 198 | } 199 | 200 | type RgChar struct { 201 | variant1 *rgCharVariant1 202 | variant2 *rgCharVariant2 203 | variant3 *rgCharVariant3 204 | variant4 *rgCharVariant4 205 | variant5 *rgCharVariant5 206 | } 207 | 208 | func (rc *RgChar) IsUnicodeClass() bool { 209 | if rc.variant1 != nil { 210 | return rc.variant1.IsUnicodeClass 211 | } 212 | return false 213 | } 214 | 215 | func (rc *RgChar) Info() string { 216 | if rc.variant1 != nil { 217 | return fmt.Sprintf("#RgChar[%s]", rc.variant1.Value.String()) 218 | 219 | } else if rc.variant2 != nil { 220 | return fmt.Sprintf("#RgChar[%s]", rc.variant2.RgOp.Value.String()) 221 | 222 | } else if rc.variant3 != nil { 223 | return fmt.Sprintf("#RgChar[%s]", rc.variant3.RgRepeat.Value.String()) 224 | 225 | } else if rc.variant4 != nil { 226 | return fmt.Sprintf("#RgChar[%s]", rc.variant4.RgAny.Value.String()) 227 | 228 | } else if rc.variant5 != nil { 229 | return fmt.Sprintf("#RgChar[%s]", rc.variant5.RgQM.Value.String()) 230 | } 231 | return "#RgChar(_unknown_)" 232 | } 233 | 234 | func (rc *RgChar) Text() string { 235 | if rc.variant1 != nil { 236 | return rc.variant1.Value.String() 237 | 238 | } else if rc.variant2 != nil { 239 | return rc.variant2.RgOp.Value.String() 240 | 241 | } else if rc.variant3 != nil { 242 | return rc.variant3.RgRepeat.Value.String() 243 | 244 | } else if rc.variant4 != nil { 245 | return rc.variant4.RgAny.Value.String() 246 | 247 | } else if rc.variant5 != nil { 248 | return rc.variant5.RgQM.Value.String() 249 | } 250 | return "" 251 | } 252 | 253 | func (rc *RgChar) IsRgOperator(v string) bool { 254 | return rc.variant2 != nil && rc.variant2.RgOp.Value.String() == v 255 | } 256 | 257 | func (rc *RgChar) IsConstant() bool { 258 | return rc.variant1 != nil 259 | } 260 | 261 | func (r *RgChar) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 262 | fn(parent, r) 263 | if r.variant1 != nil { 264 | fn(r, r.variant1.Value) 265 | 266 | } else if r.variant2 != nil { 267 | r.variant2.RgOp.ForEachElement(r, fn) 268 | 269 | } else if r.variant3 != nil { 270 | r.variant3.RgRepeat.ForEachElement(r, fn) 271 | 272 | } else if r.variant4 != nil { 273 | r.variant4.RgAny.ForEachElement(r, fn) 274 | 275 | } else if r.variant5 != nil { 276 | r.variant5.RgQM.ForEachElement(r, fn) 277 | } 278 | } 279 | 280 | func (r *RgChar) DFS(fn func(v ASTNode)) { 281 | if r.variant1 != nil { 282 | fn(r.variant1.Value) 283 | 284 | } else if r.variant2 != nil { 285 | r.variant2.RgOp.DFS(fn) 286 | } else if r.variant3 != nil { 287 | r.variant3.RgRepeat.DFS(fn) 288 | 289 | } else if r.variant4 != nil { 290 | r.variant4.RgAny.DFS(fn) 291 | 292 | } else if 
r.variant5 != nil { 293 | r.variant5.RgQM.DFS(fn) 294 | } 295 | fn(r) 296 | } 297 | 298 | // ----------------------------------------------------------- 299 | 300 | type RgRepeat struct { 301 | effect float64 302 | Value ASTString 303 | } 304 | 305 | func (rr *RgRepeat) Text() string { 306 | return rr.Value.String() 307 | } 308 | 309 | func (rr *RgRepeat) MarshalJSON() ([]byte, error) { 310 | return json.Marshal( 311 | struct { 312 | RuleName string 313 | Expansion string 314 | Effect float64 315 | }{ 316 | RuleName: "RgRepeat", 317 | Expansion: string(rr.Value), 318 | Effect: rr.effect, 319 | }, 320 | ) 321 | } 322 | 323 | func (rr *RgRepeat) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 324 | fn(parent, rr.Value) 325 | } 326 | 327 | func (rr *RgRepeat) DFS(fn func(ASTNode)) { 328 | fn(rr.Value) 329 | } 330 | 331 | // ----------------------------------------------------------- 332 | 333 | type RgQM struct { 334 | effect float64 335 | Value ASTString 336 | } 337 | 338 | func (rr *RgQM) Text() string { 339 | return rr.Value.String() 340 | } 341 | 342 | func (rr *RgQM) MarshalJSON() ([]byte, error) { 343 | return json.Marshal( 344 | struct { 345 | RuleName string 346 | Expansion string 347 | Effect float64 348 | }{ 349 | RuleName: "RgQM", 350 | Expansion: string(rr.Value), 351 | Effect: rr.effect, 352 | }, 353 | ) 354 | } 355 | 356 | func (rr *RgQM) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 357 | fn(parent, rr.Value) 358 | } 359 | 360 | func (rr *RgQM) DFS(fn func(ASTNode)) { 361 | fn(rr.Value) 362 | } 363 | 364 | // ----------------------------------------------------------- 365 | 366 | type RgAny struct { 367 | effect float64 368 | Value ASTString 369 | } 370 | 371 | func (rr *RgAny) Text() string { 372 | return rr.Value.String() 373 | } 374 | 375 | func (rr *RgAny) MarshalJSON() ([]byte, error) { 376 | return json.Marshal( 377 | struct { 378 | RuleName string 379 | Expansion string 380 | Effect float64 381 | }{ 382 | RuleName: "RgAny", 383 | Expansion: string(rr.Value), 384 | Effect: rr.effect, 385 | }, 386 | ) 387 | } 388 | 389 | func (rr *RgAny) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 390 | fn(parent, rr.Value) 391 | } 392 | 393 | func (rr *RgAny) DFS(fn func(ASTNode)) { 394 | fn(rr.Value) 395 | } 396 | 397 | // ----------------------------------------------------------- 398 | 399 | type RgRange struct { 400 | RgRangeSpec *RgRangeSpec 401 | } 402 | 403 | func (r *RgRange) Text() string { 404 | if r.RgRangeSpec != nil { 405 | return r.RgRangeSpec.Text() 406 | } 407 | return "RgRange{?, ?}" 408 | } 409 | 410 | func (r *RgRange) MarshalJSON() ([]byte, error) { 411 | return json.Marshal(struct { 412 | RuleName string 413 | Expansion RgRange 414 | }{ 415 | RuleName: "RgRange", 416 | Expansion: *r, 417 | }) 418 | } 419 | 420 | // NumericRepr returns a numeric representation 421 | // of a repeat range operation ({a, b}). If something 422 | // is undefined, -1 is used. 
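// For example, assuming the parser fills Number1/Number2 in the expected way, a range
// like {2,5} yields [2, 5], while an open-ended {3,} (or a bare {3}) yields [3, -1].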
423 | func (r *RgRange) NumericRepr() [2]int { 424 | if r.RgRangeSpec == nil { 425 | return [2]int{-1, -1} 426 | } 427 | v1, err := strconv.Atoi(r.RgRangeSpec.Number1.String()) 428 | if err != nil { 429 | panic("non-integer 1st value in RgRange") // should not happen - guaranteed by the parser 430 | } 431 | v2 := -1 432 | if r.RgRangeSpec.Number2 != "" { 433 | v2, err = strconv.Atoi(r.RgRangeSpec.Number2.String()) 434 | if err != nil { 435 | panic("non-integer 2nd value in RgRange") // should not happen - guaranteed by the parser 436 | } 437 | } 438 | return [2]int{v1, v2} 439 | } 440 | 441 | func (r *RgRange) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 442 | fn(parent, r) 443 | r.RgRangeSpec.ForEachElement(r, fn) 444 | } 445 | 446 | func (r *RgRange) DFS(fn func(v ASTNode)) { 447 | r.RgRangeSpec.DFS(fn) 448 | fn(r) 449 | } 450 | 451 | // ------------------------------------------------------------- 452 | 453 | type RgRangeSpec struct { 454 | origValue string 455 | Number1 ASTString 456 | Number2 ASTString 457 | } 458 | 459 | func (r *RgRangeSpec) Text() string { 460 | return r.origValue 461 | } 462 | 463 | func (r *RgRangeSpec) MarshalJSON() ([]byte, error) { 464 | return json.Marshal(struct { 465 | RuleName string 466 | Expansion RgRangeSpec 467 | }{ 468 | RuleName: "RgRangeSpec", 469 | Expansion: *r, 470 | }) 471 | } 472 | 473 | func (r *RgRangeSpec) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 474 | fn(parent, r) 475 | fn(parent, r.Number1) 476 | fn(parent, r.Number2) 477 | } 478 | 479 | func (r *RgRangeSpec) DFS(fn func(v ASTNode)) { 480 | fn(r.Number1) 481 | fn(r.Number2) 482 | fn(r) 483 | } 484 | 485 | // ------------------------------------------------------------- 486 | 487 | type AnyLetter struct { 488 | Value ASTString 489 | } 490 | 491 | func (a *AnyLetter) Text() string { 492 | return string(a.Value) 493 | } 494 | 495 | func (a *AnyLetter) MarshalJSON() ([]byte, error) { 496 | return json.Marshal(a.Value) 497 | } 498 | 499 | func (a *AnyLetter) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 500 | fn(parent, a.Value) 501 | } 502 | 503 | func (a *AnyLetter) DFS(fn func(v ASTNode)) { 504 | fn(a.Value) 505 | } 506 | 507 | // ------------------------------------------------------------- 508 | 509 | type RgOp struct { 510 | Value ASTString 511 | } 512 | 513 | func (r *RgOp) Text() string { 514 | return string(r.Value) 515 | } 516 | 517 | func (r *RgOp) MarshalJSON() ([]byte, error) { 518 | return json.Marshal(struct { 519 | RuleName string 520 | Expansion RgOp 521 | }{ 522 | RuleName: "RgOp", 523 | Expansion: *r, 524 | }) 525 | } 526 | 527 | func (r *RgOp) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 528 | fn(parent, r.Value) 529 | } 530 | 531 | func (r *RgOp) DFS(fn func(v ASTNode)) { 532 | fn(r.Value) 533 | } 534 | 535 | // ---------------------------------------------------------------- 536 | 537 | type rgAltValVariant1 struct { 538 | RgChar *RgChar 539 | } 540 | 541 | type rgAltValVariant2 struct { 542 | Value ASTString 543 | } 544 | 545 | type rgAltValVariant3 struct { 546 | From ASTString 547 | To ASTString 548 | } 549 | 550 | type RgAltVal struct { 551 | variant1 *rgAltValVariant1 552 | variant2 *rgAltValVariant2 553 | variant3 *rgAltValVariant3 554 | } 555 | 556 | func (rc *RgAltVal) SrchScore() float64 { 557 | if rc.variant1 != nil { 558 | textLen := float64(len(rc.variant1.RgChar.Text())) 559 | if rc.variant1.RgChar.IsUnicodeClass() { 560 | textLen *= 20 561 | } 562 | return textLen 563 | } 564 | if rc.variant2 
!= nil { 565 | return float64(len(rc.variant2.Value.Text())) 566 | } 567 | if rc.variant3 != nil { 568 | return float64(len(rc.variant3.From)) * 10 // TODO this is just a rough estimate 569 | } 570 | return 0 571 | } 572 | 573 | func (rc *RgAltVal) Text() string { 574 | return "#RgAltVal" 575 | } 576 | 577 | func (r *RgAltVal) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 578 | fn(parent, r) 579 | if r.variant1 != nil { 580 | r.variant1.RgChar.ForEachElement(r, fn) 581 | 582 | } else if r.variant2 != nil { 583 | fn(r, r.variant2.Value) 584 | 585 | } else if r.variant3 != nil { 586 | fn(r, r.variant3.From) 587 | fn(r, r.variant3.To) 588 | } 589 | } 590 | 591 | func (r *RgAltVal) DFS(fn func(v ASTNode)) { 592 | if r.variant1 != nil { 593 | r.variant1.RgChar.DFS(fn) 594 | 595 | } else if r.variant2 != nil { 596 | fn(r.variant2.Value) 597 | 598 | } else if r.variant3 != nil { 599 | fn(r.variant3.From) 600 | fn(r.variant3.To) 601 | } 602 | fn(r) 603 | } 604 | 605 | func (r *RgSimple) ForEachElement(parent ASTNode, fn func(parent, v ASTNode)) { 606 | fn(parent, r) 607 | for _, item := range r.Values { 608 | switch tItem := item.(type) { 609 | case *RgRange: 610 | tItem.ForEachElement(r, fn) 611 | case *RgChar: 612 | tItem.ForEachElement(r, fn) 613 | case *RgAlt: 614 | tItem.ForEachElement(r, fn) 615 | case *RgPosixClass: 616 | tItem.ForEachElement(r, fn) 617 | } 618 | } 619 | } 620 | 621 | func (r *RgSimple) DFS(fn func(v ASTNode)) { 622 | for _, item := range r.Values { 623 | switch tItem := item.(type) { 624 | case *RgRange: 625 | tItem.DFS(fn) 626 | case *RgChar: 627 | tItem.DFS(fn) 628 | case *RgAlt: 629 | tItem.DFS(fn) 630 | case *RgPosixClass: 631 | tItem.DFS(fn) 632 | } 633 | } 634 | fn(r) 635 | } 636 | -------------------------------------------------------------------------------- /eval/model.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
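// Package eval provides training and evaluation of machine learning models used to
// predict whether a CQL query is likely to be slow, based on features extracted from
// parsed queries.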
package eval

import (
	"context"
	"errors"
	"fmt"
	"math"
	"math/rand/v2"
	"slices"
	"strings"
	"unicode/utf8"

	"github.com/czcorpus/cqlizer/cnf"
	"github.com/czcorpus/cqlizer/eval/feats"
	"github.com/czcorpus/cqlizer/eval/predict"
	"github.com/czcorpus/cqlizer/eval/zero"
	"github.com/rs/zerolog/log"
	"github.com/schollz/progressbar/v3"
)

type PrecAndRecall struct {
	Precision float64
	Recall    float64
	FBeta     float64
}

func (pr PrecAndRecall) CSV(x float64) string {
	return fmt.Sprintf("%.2f;%.2f;%.2f;%.2f", x, pr.Precision, pr.Recall, pr.FBeta)
}

// findKneeDistance locates the "knee" of the processing-time curve (items must be
// sorted by ProcTime in ascending order) as the point with the maximum perpendicular
// distance from the straight line connecting the first and the last point. It returns
// the processing time at the knee along with its index.
func findKneeDistance(items []feats.QueryEvaluation) (threshold float64, kneeIdx int) {

	n := len(items)
	if n == 0 {
		return 0, 0
	}
	if n < 2 {
		// a single item - there is no meaningful knee to detect
		return items[0].ProcTime, 0
	}

	// Line from first to last point
	x1, y1 := 0.0, items[0].ProcTime
	x2, y2 := float64(n-1), items[n-1].ProcTime

	// Line equation coefficients: ax + by + c = 0
	a := y2 - y1
	b := x1 - x2
	c := x2*y1 - x1*y2

	normFactor := math.Sqrt(a*a + b*b)

	maxDist := 0.0
	kneeIdx = 0

	for i := 0; i < n; i++ {
		// Perpendicular distance from point to line
		dist := math.Abs(a*float64(i)+b*items[i].ProcTime+c) / normFactor
		if dist > maxDist {
			maxDist = dist
			kneeIdx = i
		}
	}
	threshold = items[kneeIdx].ProcTime
	return threshold, kneeIdx
}

// ------------------------------------

type QueryStatsRecord struct {
	Corpus        string  `json:"corpus"`
	CorpusSize    int64   `json:"corpusSize"`
	SubcorpusSize int64   `json:"subcorpusSize"`
	TimeProc      float64 `json:"timeProc"`
	Query         string  `json:"query"`

	// IsSynthetic specifies whether the record comes from the production
	// KonText stats log or whether it was generated using a benchmarking
	// module (= MQuery).
	IsSynthetic bool `json:"isSynthetic,omitempty"`
}

// GetCQL returns the bare CQL query, stripping a leading "q" marker or a
// comma-separated prefix if present.
func (rec QueryStatsRecord) GetCQL() string {
	if strings.HasPrefix(rec.Query, "q") {
		return rec.Query[1:]
	}
	tmp := strings.SplitN(rec.Query, ",", 2)
	if len(tmp) > 1 {
		return tmp[1]
	}
	return rec.Query
}

func (rec QueryStatsRecord) UniqKey() string {
	return fmt.Sprintf("%d/%d/%s", rec.CorpusSize, rec.SubcorpusSize, rec.Query)
}

// ----------------------------

type LearningDataStats struct {
	NumProcessed       int     `msgpack:"numProcessed"`
	NumFailed          int     `msgpack:"numFailed"`
	DeduplicationRatio float64 `msgpack:"deduplicationRatio"`
}

func (stats LearningDataStats) AsComment() string {
	return fmt.Sprintf("source data - total items: %d, failed imports: %d, deduplicated ratio: %.2f", stats.NumProcessed, stats.NumFailed, stats.DeduplicationRatio)
}

// ----------------------------

// MLModel is a generalization of a machine learning model used to extract knowledge
// about CQL queries.
type MLModel interface {

	// Train trains the model based on input data. In case the model
	// supports only inference (e.g. our XGBoost), this should just prepare
	// the data in the format required by the actual program performing the learning.
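	// The slowQueriesTime argument is the processing-time threshold separating "slow"
	// from "fast" queries, and comment is a free-form note (e.g. a summary of the source
	// data) which implementations may store along with the model or the exported data.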
	Train(ctx context.Context, data []feats.QueryEvaluation, slowQueriesTime float64, comment string) error

	Predict(feats.QueryEvaluation) predict.Prediction
	SetClassThreshold(v float64)
	GetClassThreshold() float64
	GetSlowQueriesThresholdTime() float64
	SaveToFile(string) error
	GetInfo() string

	// IsInferenceOnly specifies whether the model supports only inference,
	// i.e. the actual training is performed by an external program (see Train).
	IsInferenceOnly() bool

	// CreateModelFileName should generate a proper model filename based
	// on the feature (i.e. input) file name. This should keep data and
	// model names organized and easy to search through.
	CreateModelFileName(featFile string) string
}

// ----------------------------

type misclassifiedQueryReporter interface {
	AddMisclassifiedQuery(q feats.QueryEvaluation, mlOut, threshold, slowProcTime float64)
}

// ----------------------------

type Predictor struct {
	mlModel MLModel

	Evaluations []feats.QueryEvaluation

	LearningDataStats LearningDataStats

	// slowQueryPercentile specifies which percentile of queries (by processing time)
	// is considered "slow". This is the value entered by the user.
	slowQueryPercentile float64

	// midpointIdx is derived from slowQueryPercentile and represents the index
	// (in the sorted data) at which the slow-query region starts.
	midpointIdx int

	// binMidpoint is the threshold processing time at which the slow-query region
	// starts. The value is derived from slowQueryPercentile.
	binMidpoint float64

	corpora map[string]feats.CorpusProps

	syntheticTimeCorrection float64
}

func NewPredictor(
	mlModel MLModel,
	conf *cnf.Conf,
) *Predictor {
	if mlModel == nil {
		mlModel = &zero.ZeroModel{}
	}
	return &Predictor{
		corpora:                 conf.CorporaProps,
		mlModel:                 mlModel,
		syntheticTimeCorrection: conf.SyntheticTimeCorrection,
		binMidpoint:             mlModel.GetSlowQueriesThresholdTime(),
	}
}

func (model *Predictor) FindAndSetDataMidpoint() {
	slices.SortFunc(model.Evaluations, func(v1, v2 feats.QueryEvaluation) int {
		if v1.ProcTime < v2.ProcTime {
			return -1

		} else if v1.ProcTime > v2.ProcTime {
			return 1
		}
		return 0
	})
	for i := 0; i < len(model.Evaluations); i++ {
		if model.Evaluations[i].ProcTime > 450 {
			model.Evaluations[i].ProcTime = 450
			fmt.Println("HUGE QUERY ------------ ", model.Evaluations[i].Positions)
		}
	}
	model.binMidpoint, model.midpointIdx = findKneeDistance(model.Evaluations)
}

func (model *Predictor) BalanceSample() []feats.QueryEvaluation {
	slices.SortFunc(model.Evaluations, func(v1, v2 feats.QueryEvaluation) int {
		if v1.ProcTime < v2.ProcTime {
			return -1

		} else if v1.ProcTime > v2.ProcTime {
			return 1
		}
		return 0
	})
	for i := 0; i < len(model.Evaluations); i++ {
		if model.Evaluations[i].ProcTime > 450 {
			model.Evaluations[i].ProcTime = 450
			fmt.Println("HUGE QUERY ------------ ", model.Evaluations[i].Positions)
		}
	}
	//model.MidpointIdx, model.BinMidpoint = model.computeThreshold()
	log.Info().Msg("creating a balanced sample for learning")
	model.binMidpoint, model.midpointIdx = findKneeDistance(model.Evaluations)
	model.slowQueryPercentile =
float64(model.midpointIdx) / float64(len(model.Evaluations)) 236 | log.Info(). 237 | Float64("thresholdTime", model.binMidpoint). 238 | Int("thresholdIdx", model.midpointIdx). 239 | Float64("slowQueryPercentile", model.slowQueryPercentile). 240 | Int("totalQueries", len(model.Evaluations)). 241 | Int("positiveExamples", len(model.Evaluations)-model.midpointIdx). 242 | Float64("maxProcTime", model.Evaluations[len(model.Evaluations)-1].ProcTime). 243 | Float64("minProcTime", model.Evaluations[0].ProcTime). 244 | Msg("calculated threshold for slow queries") 245 | 246 | numPositive := len(model.Evaluations) - model.midpointIdx 247 | balEval := make([]feats.QueryEvaluation, numPositive*3) 248 | for i := 0; i < numPositive*2; i++ { 249 | balEval[i] = model.Evaluations[rand.IntN(model.midpointIdx)] 250 | } 251 | for i := range numPositive { 252 | balEval[i+numPositive*2] = model.Evaluations[model.midpointIdx+i] 253 | } 254 | oldEvals := model.Evaluations 255 | model.Evaluations = balEval 256 | return oldEvals 257 | } 258 | 259 | func (model *Predictor) ProcessEntry(entry QueryStatsRecord) error { 260 | if entry.CorpusSize == 0 { 261 | cProps, ok := model.corpora[entry.Corpus] 262 | if ok { 263 | entry.CorpusSize = int64(cProps.Size) 264 | log.Warn().Msg("fixed missing corpus size") 265 | 266 | } else { 267 | return fmt.Errorf("zero corpus size, unknown corpus %s - cannot fix", entry.Corpus) 268 | } 269 | } 270 | if entry.TimeProc <= 0 { 271 | return fmt.Errorf("invalid processing time %.2f", entry.TimeProc) 272 | } 273 | if entry.IsSynthetic { 274 | entry.TimeProc *= model.syntheticTimeCorrection 275 | } 276 | 277 | // Parse the CQL query and create evaluation with corpus size 278 | corpInfo := model.corpora[entry.Corpus] 279 | eval, err := feats.NewQueryEvaluation( 280 | entry.GetCQL(), 281 | float64(entry.CorpusSize), 282 | float64(entry.SubcorpusSize), 283 | entry.TimeProc, 284 | feats.GetCharProbabilityProvider(corpInfo.Lang), 285 | ) 286 | if err != nil { 287 | errMsg := err.Error() 288 | if utf8.RuneCountInString(errMsg) > 80 { 289 | errMsg = string([]rune(errMsg)[:80]) 290 | } 291 | log.Warn(). 292 | Err(errors.New(errMsg)). 293 | Str("query", entry.GetCQL()). 
294 | Msg("Warning: Failed to parse query") 295 | return nil // Skip unparseable queries 296 | } 297 | 298 | model.Evaluations = append(model.Evaluations, eval) 299 | 300 | return nil 301 | } 302 | 303 | func (model *Predictor) SetStats(numProcessed, numFailed int) { 304 | model.LearningDataStats.NumProcessed = numProcessed 305 | model.LearningDataStats.NumFailed = numFailed 306 | } 307 | 308 | func (model *Predictor) PrecisionAndRecall(misclassQueries misclassifiedQueryReporter) PrecAndRecall { 309 | 310 | numTruePositives := 0 311 | numRelevant := 0 312 | numRetrieved := 0 313 | 314 | for i := 0; i < len(model.Evaluations); i++ { 315 | trulySlow := model.Evaluations[i].ProcTime >= model.binMidpoint 316 | prediction := model.mlModel.Predict(model.Evaluations[i]) 317 | if trulySlow != (prediction.PredictedClass == 1) && misclassQueries != nil { 318 | misclassQueries.AddMisclassifiedQuery( 319 | model.Evaluations[i], prediction.SlowQueryVote(), model.mlModel.GetClassThreshold(), model.mlModel.GetSlowQueriesThresholdTime()) 320 | } 321 | if trulySlow { 322 | numRelevant++ 323 | } 324 | if prediction.PredictedClass == 1 { 325 | numRetrieved++ 326 | if trulySlow { 327 | numTruePositives++ 328 | 329 | } else { 330 | /* 331 | fmt.Printf( 332 | "WE SAY %s IS SLOW (%0.2f) BUT IT IS NOT (time %.2f, corpsize: %0.2f)\n", 333 | model.Evaluations[i].OrigQuery, prediction, model.Evaluations[i].ProcTime, math.Exp(model.Evaluations[i].CorpusSize), 334 | ) 335 | */ 336 | } 337 | } 338 | } 339 | precision := float64(numTruePositives) / float64(numRetrieved) 340 | recall := float64(numTruePositives) / float64(numRelevant) 341 | beta := 1.0 342 | fbeta := 0.0 343 | if precision+recall > 0 { 344 | betaSquared := beta * beta 345 | fbeta = (1 + betaSquared) * (precision * recall) / (betaSquared*precision + recall) 346 | } 347 | return PrecAndRecall{Precision: precision, Recall: recall, FBeta: fbeta} 348 | 349 | } 350 | 351 | func (model *Predictor) showSampleEvaluations(rfModel MLModel, maxSamples int, votingThreshold float64) { 352 | 353 | if len(model.Evaluations) < maxSamples { 354 | maxSamples = len(model.Evaluations) 355 | } 356 | 357 | // Test predictions on training data (for diagnostic purposes) 358 | fmt.Printf("\nSample predictions on training data (voting threshold: %.2f):\n", votingThreshold) 359 | 360 | fmt.Println("negative examples test: ") 361 | for i := 0; i < maxSamples; i++ { 362 | randomIdx := rand.IntN(model.midpointIdx) 363 | predicted := float64(rfModel.Predict(model.Evaluations[randomIdx]).PredictedClass) / 100.0 364 | actual := model.Evaluations[randomIdx].ProcTime < model.binMidpoint 365 | fmt.Printf( 366 | " %d, match: %t, vote NO: %.2f (time: %.2f)\n", 367 | randomIdx, actual == (predicted < votingThreshold), 1-predicted, model.Evaluations[randomIdx].ProcTime, 368 | ) 369 | } 370 | 371 | fmt.Println("POSITIVE examples test: ") 372 | for i := 0; i < maxSamples; i++ { 373 | randomIdx := rand.IntN(len(model.Evaluations)-model.midpointIdx) + model.midpointIdx 374 | predicted := float64(rfModel.Predict(model.Evaluations[randomIdx]).PredictedClass) / 100.0 375 | actual := model.Evaluations[randomIdx].ProcTime >= model.binMidpoint 376 | fmt.Printf( 377 | " %d, match: %t, vote YES: %.2f (time: %.2f)\n", 378 | randomIdx, actual == (predicted >= votingThreshold), predicted, model.Evaluations[randomIdx].ProcTime, 379 | ) 380 | } 381 | } 382 | 383 | func (model *Predictor) Deduplicate() { 384 | uniq := make(map[string][]feats.QueryEvaluation) 385 | for _, v := range model.Evaluations { 386 | _, 
ok := uniq[v.UniqKey()]
		if !ok {
			uniq[v.UniqKey()] = make([]feats.QueryEvaluation, 0, 10)
		}
		uniq[v.UniqKey()] = append(uniq[v.UniqKey()], v)
	}
	for _, evals := range uniq {
		slices.SortFunc(evals, func(v1, v2 feats.QueryEvaluation) int {
			if v1.ProcTime < v2.ProcTime {
				return -1
			}
			return 1
		})
		sum := 0.0
		sum2 := 0.0
		n := 0.0
		for _, v := range evals {
			sum += v.ProcTime
			sum2 += v.ProcTime * v.ProcTime
			n += 1
		}
		mean := sum / n
		//variance := (sum2 / n) - (mean * mean)
		//stdDev := math.Sqrt(variance)
		var median float64
		if len(evals) <= 2 {
			median = mean

		} else {
			// middle is the median index for odd-sized groups and the upper
			// median for even-sized ones
			middle := len(evals) / 2
			median = evals[middle].ProcTime
		}
		evals[0].ProcTime = median
	}
	model.Evaluations = make([]feats.QueryEvaluation, len(uniq))
	i := 0
	for _, u := range uniq {
		model.Evaluations[i] = u[0]
		i++
	}
	model.LearningDataStats.DeduplicationRatio = float64(len(uniq)) / float64(model.LearningDataStats.NumProcessed)
	log.Info().Int("newSize", len(model.Evaluations)).Msg("deduplicated queries")
}

// CreateAndTestModel trains the configured ML model on the current evaluations and
// saves it to a file whose name is derived from featsFile. Unless the model is
// inference-only, it then evaluates precision and recall on testData over a range of
// classification thresholds and uses the provided reporter to plot the accuracy chart
// and store misclassified queries.
func (model *Predictor) CreateAndTestModel(
	ctx context.Context,
	testData []feats.QueryEvaluation,
	featsFile string,
	reporter *Reporter,
) error {
	if len(model.Evaluations) == 0 {
		return fmt.Errorf("no training data available")
	}

	log.Info().
		Int("trainingDataSize", len(model.Evaluations)).
		Msg("training the ML model")

	outputPath := model.mlModel.CreateModelFileName(featsFile)

	if err := model.mlModel.Train(ctx, model.Evaluations, model.binMidpoint, model.LearningDataStats.AsComment()); err != nil {
		return fmt.Errorf("model training failed: %w", err)
	}

	if err := model.mlModel.SaveToFile(outputPath); err != nil {
		return fmt.Errorf("error saving model: %w", err)

	} else {
		log.Debug().Str("path", outputPath).Msg("saved model file")
	}

	if model.mlModel.IsInferenceOnly() {
		return nil
	}
	// ----- testing
	slices.SortFunc(
		testData,
		func(v1, v2 feats.QueryEvaluation) int {
			if v1.ProcTime < v2.ProcTime {
				return -1
			}
			if v1.ProcTime > v2.ProcTime {
				return 1
			}
			return 0
		})
	model.Evaluations = testData

	log.Info().
		Int("evalDataSize", len(model.Evaluations)).
478 | Msg("calculating precision and recall using full data") 479 | 480 | bar := progressbar.Default(int64(math.Ceil((1-0.5)/0.01)), "testing the model") 481 | var csv strings.Builder 482 | csv.WriteString("vote;precision;recall;f-beta\n") 483 | for v := 0.5; v < 1; v += 0.01 { 484 | select { 485 | case <-ctx.Done(): 486 | return nil 487 | default: 488 | } 489 | model.mlModel.SetClassThreshold(v) 490 | precall := model.PrecisionAndRecall(reporter) 491 | csv.WriteString(precall.CSV(v) + "\n") 492 | bar.Add(1) 493 | } 494 | if err := reporter.PlotRFAccuracy(csv.String(), model.mlModel.GetInfo(), outputPath); err != nil { 495 | return fmt.Errorf("failed to generate accuracy chart: %w", err) 496 | } 497 | reporter.SaveMisclassifiedQueries() 498 | return nil 499 | } 500 | -------------------------------------------------------------------------------- /apiserver/test_page.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 Tomas Machalek <tomas.machalek@gmail.com> 2 | // Copyright 2025 Department of Linguistics, 3 | // Faculty of Arts, Charles University 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | 17 | package apiserver 18 | 19 | import ( 20 | "fmt" 21 | "net/http" 22 | "path/filepath" 23 | "slices" 24 | "strings" 25 | 26 | "github.com/czcorpus/cqlizer/eval/modutils" 27 | "github.com/gin-gonic/gin" 28 | ) 29 | 30 | type corpusSelProp struct { 31 | Name string 32 | Size int64 33 | } 34 | 35 | func (api *apiServer) handleTestPage(ctx *gin.Context) { 36 | // Build corpus options from configuration 37 | var corpusOptions strings.Builder 38 | corpora := make([]corpusSelProp, 0, len(api.conf.CorporaProps)) 39 | for c, v := range api.conf.CorporaProps { 40 | if v.Size > 100000000 { 41 | corpora = append(corpora, corpusSelProp{Name: c, Size: int64(v.Size)}) 42 | } 43 | } 44 | slices.SortFunc(corpora, func(v1, v2 corpusSelProp) int { 45 | return int(v2.Size - v1.Size) 46 | }) 47 | for _, corpus := range corpora { 48 | corpusOptions.WriteString( 49 | fmt.Sprintf( 50 | "<option value=\"%s\">%s (%s)</option>\n", 51 | corpus.Name, corpus.Name, modutils.FormatRoughSize(corpus.Size), 52 | ), 53 | ) 54 | } 55 | 56 | // Get URL prefix for proxy support 57 | urlPrefix := api.conf.TestingPageURLPathPrefix 58 | if urlPrefix != "" && !strings.HasPrefix(urlPrefix, "/") { 59 | urlPrefix = "/" + urlPrefix 60 | } 61 | urlPrefix = strings.TrimSuffix(urlPrefix, "/") 62 | 63 | slowQueryVoteThreshold := 0.0 64 | var modelFiles strings.Builder 65 | for i, mod := range api.rfEnsemble { 66 | slowQueryVoteThreshold += mod.threshold 67 | if i > 0 { 68 | modelFiles.WriteString(", ") 69 | } 70 | modelFiles.WriteString(filepath.Base(mod.srcPath)) 71 | } 72 | slowQueryVoteThreshold /= float64(len(api.rfEnsemble)) 73 | 74 | html := fmt.Sprintf(`<!DOCTYPE html> 75 | <html lang="en"> 76 | <head> 77 | <meta charset="UTF-8"> 78 | <meta name="viewport" content="width=device-width, initial-scale=1.0"> 79 | <title>CQL Query 
Complexity Predictor - Test Page</title>
[NOTE: the rest of the embedded HTML template (page styles, form markup and the client-side
script, original source lines 80-545) was lost during extraction. The surviving fragments
show a page header "CQL Query Complexity Predictor", an intro "This tool predicts whether
a CQL query will be slow or fast based on its complexity. Models ensemble: %s", a corpus
selector filled from the generated corpus options, a "Results" panel and a footer with
"Version: %s" and "Build: %s"; the urlPrefix and slowQueryVoteThreshold values are
interpolated in the lost part of the template. The template string is closed and filled
below.]
`,
		modelFiles.String(),
		corpusOptions.String(),
		api.version.Version,
		api.version.BuildDate,
		urlPrefix,
		slowQueryVoteThreshold)

	ctx.Header("Content-Type", "text/html; charset=utf-8")
	ctx.String(http.StatusOK, html)
}
--------------------------------------------------------------------------------
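For orientation, the following is a minimal, hypothetical sketch (not part of the repository) of how the Predictor API defined in eval/model.go above is meant to be driven during training; the records, conf, mlModel, featsFile and reporter values are assumed to be prepared by the caller, and the actual command-line wiring in the repository may differ.

package sketch // illustrative only

import (
	"context"

	"github.com/czcorpus/cqlizer/cnf"
	"github.com/czcorpus/cqlizer/eval"
)

// trainSketch shows the intended call order of the eval.Predictor methods.
func trainSketch(
	ctx context.Context,
	records []eval.QueryStatsRecord,
	conf *cnf.Conf,
	mlModel eval.MLModel,
	featsFile string,
	reporter *eval.Reporter,
) error {
	p := eval.NewPredictor(mlModel, conf)
	numFailed := 0
	for _, rec := range records {
		// ProcessEntry parses the CQL query and appends a feature evaluation;
		// records it cannot use (zero corpus size, invalid time) yield an error.
		if err := p.ProcessEntry(rec); err != nil {
			numFailed++
		}
	}
	p.SetStats(len(records), numFailed)
	p.Deduplicate()               // collapse repeated queries, keeping a median proc. time
	testData := p.BalanceSample() // balanced training sample; returns the original data
	return p.CreateAndTestModel(ctx, testData, featsFile, reporter)
}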