├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── cache └── README ├── datasets └── README ├── download-dependencies ├── fig ├── bin │ ├── chunk │ └── qcreate └── lib │ ├── execrunner.rb │ └── myutils.rb ├── release ├── core.files ├── dataset_debug.files ├── dataset_openweb.files ├── ling.files └── model.files ├── scripts ├── README ├── fake-google-search.py ├── get-webpage.py ├── google-search.py └── weblib │ ├── __init__.py │ ├── blacklist.py │ ├── tee.py │ └── web.py ├── src └── edu │ └── stanford │ └── nlp │ └── semparse │ └── open │ ├── Main.java │ ├── core │ ├── AllOptions.java │ ├── InteractiveDemo.java │ ├── OpenSemanticParser.java │ ├── ParallelizedTrainer.java │ └── eval │ │ ├── CandidateStatistics.java │ │ ├── EvaluationCase.java │ │ ├── EvaluationNormalFail.java │ │ ├── EvaluationSuccess.java │ │ ├── EvaluationSuperFail.java │ │ ├── Evaluator.java │ │ ├── EvaluatorStatistics.java │ │ └── IterativeTester.java │ ├── dataset │ ├── Criteria.java │ ├── CriteriaExactMatch.java │ ├── CriteriaGeneralWeb.java │ ├── Dataset.java │ ├── Example.java │ ├── ExampleCached.java │ ├── ExpectedAnswer.java │ ├── ExpectedAnswerCriteriaMatch.java │ ├── ExpectedAnswerInjectiveMatch.java │ ├── IRScore.java │ ├── entity │ │ ├── TargetEntity.java │ │ ├── TargetEntityNearMatch.java │ │ ├── TargetEntityPersonName.java │ │ ├── TargetEntityString.java │ │ └── TargetEntitySubstring.java │ └── library │ │ ├── DatasetLibrary.java │ │ ├── JSONDataset.java │ │ ├── JSONDatasetReader.java │ │ └── UnaryDatasets.java │ ├── ling │ ├── AveragedWordVector.java │ ├── BrownClusterTable.java │ ├── ClusterRepnUtils.java │ ├── CreateTypeEntityFeatures.java │ ├── FrequencyTable.java │ ├── LingData.java │ ├── LingTester.java │ ├── LingUtils.java │ ├── QueryTypeTable.java │ ├── WordNetClusterTable.java │ └── WordVectorTable.java │ ├── model │ ├── AdvancedWordVectorGradient.java │ ├── AdvancedWordVectorParams.java │ ├── AdvancedWordVectorParamsFullRank.java │ ├── AdvancedWordVectorParamsLowRank.java │ ├── FeatureCountPruner.java │ ├── FeatureDomainPruner.java │ ├── FeatureMatcher.java │ ├── FeatureVector.java │ ├── Learner.java │ ├── LearnerBaseline.java │ ├── LearnerMaxEnt.java │ ├── LearnerMaxEntWithBeamSearch.java │ ├── Params.java │ ├── candidate │ │ ├── Candidate.java │ │ ├── CandidateGenerator.java │ │ ├── CandidateGroup.java │ │ ├── PathEntry.java │ │ ├── PathEntryWithRange.java │ │ ├── PathUtils.java │ │ ├── TreePattern.java │ │ └── TreePatternAndRange.java │ ├── feature │ │ ├── FeatureExtractor.java │ │ ├── FeaturePostProcessor.java │ │ ├── FeaturePostProcessorConjoin.java │ │ ├── FeatureType.java │ │ ├── FeatureTypeCutRange.java │ │ ├── FeatureTypeHoleBased.java │ │ ├── FeatureTypeLinguisticsBased.java │ │ ├── FeatureTypeNaiveEntityBased.java │ │ ├── FeatureTypeNodeBased.java │ │ └── FeatureTypePathBased.java │ └── tree │ │ ├── HTMLFixer.java │ │ ├── KNode.java │ │ ├── KNodeUtils.java │ │ └── KnowledgeTreeBuilder.java │ └── util │ ├── BipartiteMatcher.java │ ├── EditDistance.java │ ├── Multiset.java │ ├── Parallelizer.java │ ├── SHA.java │ ├── SearchResult.java │ ├── StringDoubleArrayList.java │ ├── StringDoublePair.java │ ├── StringSampler.java │ ├── VectorAverager.java │ └── WebUtils.java └── web-entity-extractor /.gitignore: -------------------------------------------------------------------------------- 1 | /fig 2 | /lib 3 | /output 4 | /state 5 | /cache/*.gz 6 | /datasets/*/ 7 | /classes 8 | /models 9 | 10 | .classpath 11 | .project 12 | *.cache 13 | *.pyc 14 | *.bak 15 | *.swp 16 | *~ 17 | 
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | default: 2 | mkdir -p classes 3 | javac -cp lib/\* -d classes `find src -name "*.java"` 4 | 5 | clean: 6 | rm -rf classes 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Web Entity Extractor 2 | 3 | This repository contains a toolkit for extracting entities from a given search query and web page. 4 | 5 | ## Requirements 6 | 7 | The requirements for running the code include: 8 | 9 | * Java 7 10 | * Ruby 1.8.7 or 1.9 11 | * Python 2.7 12 | 13 | Other required libraries and resources can be downloaded using the following commands: 14 | 15 | * `./download-dependencies core`: download required Java libraries 16 | * `./download-dependencies ling`: download linguistic resources 17 | * `./download-dependencies dataset_debug`: download a small dataset for testing the installation 18 | * `./download-dependencies dataset_openweb`: download the OpenWeb dataset, which contains diverse queries and web pages 19 | * `./download-dependencies model`: download a model trained on the training data of the OpenWeb dataset 20 | 21 | ## Compiling 22 | 23 | Run the following commands to download necessary libraries and compile: 24 | 25 | ./download-dependencies core 26 | ./download-dependencies ling 27 | make 28 | 29 | ## Testing 30 | 31 | To train and test on the debug dataset (30 examples) using the default features, run 32 | 33 | ./download-dependencies dataset_debug 34 | ./web-entity-extractor @mode=main @data=debug @feat=default 35 | 36 | For the OpenWeb dataset, make sure the system has enough RAM (~40GB recommended) and run 37 | 38 | ./download-dependencies dataset_openweb 39 | ./web-entity-extractor @memsize=high @mode=main @data=dev @feat=default -numThreads 0 -fold 3 40 | 41 | Alternatively, run the pre-trained model on the dataset using 42 | 43 | ./download-dependencies model 44 | ./download-dependencies dataset_openweb 45 | # Test on the training data 46 | ./web-entity-extractor @memsize=high @mode=load -loadModel models/openweb-devset @data=dev -numThreads 0 47 | # Test on the test data 48 | ./web-entity-extractor @memsize=high @mode=load -loadModel models/openweb-devset @data=test -numThreads 0 49 | 50 | The flag `-numThreads 0` uses all CPUs available, while `-fold 3` runs the system on 3 random splits of the dataset. 51 | Note that the system may take a long time on the first run to cache all linguistic data. 52 | 53 | ## Interactive Mode 54 | 55 | The interactive mode allows the user to apply the trained model on any query and web page. 56 | 57 | To use the interactive mode, first train and save a model by adding `-saveModel [MODELNAME]` to one of the commands above, and then run 58 | 59 | ./web-entity-extractor @mode=interactive -loadModel [MODELNAME] 60 | 61 | ## License 62 | 63 | The code is under the GNU General Public License (v2). See the `LICENSE` file for the full license. 64 | -------------------------------------------------------------------------------- /cache/README: -------------------------------------------------------------------------------- 1 | Linguistic Cache 2 | ================ 3 | 4 | A linguistic cache stores the CoreNLP tagging of phrases. 5 | They are gzip-ed JSON files of the format 6 | 7 | { 8 | "phrase1": { ... (data) ... }, 9 | "phrase2": { ... (data) ... 
}, 10 | ... 11 | } 12 | 13 | where (data) may differ between cache versions. 14 | 15 | Version 2 16 | --------- 17 | 18 | { 19 | "coreNLPTags": { 20 | "tokens": [...], 21 | "lemmaTokens": [...], 22 | "posTags": [...], 23 | "nerTags": [...], 24 | "nerValues": [...], 25 | }, 26 | "isOpenClassPOS": [...], 27 | "brownClusters": [...] 28 | } 29 | 30 | Version 3 & 4 31 | ------------- 32 | 33 | { 34 | "tokens": [...], 35 | "lemmaTokens": [...], 36 | "posTags": [...], 37 | "posTypes": [...], 38 | "nerTags": [...], 39 | "nerValues": [...] 40 | } 41 | 42 | To convert the BUGGY v3 to CORRECT v4, use 43 | convert3to4.py [V3CACHE.v3.json.gz] 44 | 45 | Unfortunately, posTags of v3 contains more information than v2, so v2 46 | cannot be converted to v3 directly. 47 | -------------------------------------------------------------------------------- /datasets/README: -------------------------------------------------------------------------------- 1 | Datasets directory 2 | ------------------ 3 | 4 | Call 5 | ./download-dependencies [datasetName] 6 | to download a dataset into this directory. 7 | 8 | Directory structure: 9 | 10 | datasets/ 11 | |- [datasetFamily] 12 | |- [datasetName].json <-- contains the queries and answers 13 | |- [webpageCacheDirectory].cache <-- stores cached web pages 14 | |- [hashcode] <-- contains a single web page 15 | -------------------------------------------------------------------------------- /download-dependencies: -------------------------------------------------------------------------------- 1 | #!/usr/bin/ruby 2 | 3 | $modules = [:core, 4 | :ling, 5 | :dataset_debug, 6 | :dataset_openweb, 7 | :model, 8 | ] 9 | 10 | def usage 11 | puts "Usage: ./download-dependencies <#{$modules.join('|')}>" 12 | end 13 | 14 | if ARGV.size == 0 15 | usage 16 | exit 1 17 | end 18 | 19 | def run(command) 20 | puts "RUNNING: #{command}" 21 | if not system(command) 22 | puts "FAILED: #{command}" 23 | exit 1 24 | end 25 | end 26 | 27 | BASE_URL = 'http://nlp.stanford.edu/software/sempre/web-entity-extractor-ACL2014/' 28 | def download(release, path, baseUrl=BASE_URL) 29 | url = baseUrl + '/release-' + release.to_s 30 | isDirectory = path.end_with?('/') 31 | path = path.sub(/\/*$/, '') 32 | if release != :code and !path.end_with?('.gz', '.tgz', '.jar') 33 | path += '.tar' if isDirectory 34 | path += '.bz2' 35 | end 36 | run("mkdir -p #{File.dirname(path)}") 37 | run("wget -c '#{url}/#{path}' -O #{path}") 38 | if release != :code and path.end_with?('.bz2') 39 | if isDirectory 40 | run("cd #{File.dirname(path)} && tar xjf #{File.basename(path)}") 41 | else 42 | run("bzip2 -fd #{path}") 43 | end 44 | end 45 | end 46 | 47 | def downloadFromFileList(release) 48 | files = [] 49 | File.foreach(File.join('release', "#{release.to_s}.files")) { |line| 50 | file = line.sub(/#.*$/, '').sub(/^\s*/, '').sub(/\s*$/, '') 51 | next if file.length == 0 52 | files << file 53 | } 54 | files.each { |path| 55 | download(release, path) 56 | } 57 | end 58 | 59 | def main(which) 60 | if not $modules.include?(which) 61 | usage 62 | exit 1 63 | end 64 | downloadFromFileList(which) 65 | end 66 | 67 | ARGV.each { |mod| 68 | mod = mod.to_sym 69 | main(mod) 70 | } 71 | -------------------------------------------------------------------------------- /fig/bin/chunk: -------------------------------------------------------------------------------- 1 | #!/usr/bin/ruby 2 | 3 | require File.dirname($0)+'/../lib/myutils' 4 | 5 | # Print out a chunk of a file. Specifically, print out the header lines. 
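# Illustrative invocation (file name and values are hypothetical; flag syntax
# assumed from the extractArgs spec below):
#   chunk -file big.tsv -headerNumLines 1 -chunkSize 10M -indices 0 2
# would print the first and third 10MB chunks of big.tsv, each starting at a
# line boundary, with the one-line header printed once up front.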
6 | # Use
7 | $file, $headerNumLines, $chunkSize, $indices, $printNumChunks, $verbose = extractArgs(:spec => [
8 | ['file', String, nil, true, 'Big file to read'],
9 | ['headerNumLines', Fixnum, 0, false, 'Number of header lines to include'],
10 | ['chunkSize', String, '100M', false, 'Size of a chunk (e.g., 1024, 1K, 1M, 1G)'],
11 | ['indices', [Fixnum], [], false, 'Which chunk(s) we want'],
12 | ['printNumChunks', TrueClass, false, false, 'Print out number of chunks'],
13 | ['verbose', Fixnum, 0, false, 'Verbosity level (to stderr)'],
14 | nil])
15 | 
16 | f = open($file, 'r')
17 | header = (0...$headerNumLines).map { f.gets }
18 | 
19 | startPos = f.tell
20 | f.seek(0, IO::SEEK_END)
21 | endPos = f.tell
22 | 
23 | def parseSize(s)
24 | return Integer(Float($1) * 1024**1) if s =~ /^(.+)[kK]$/
25 | return Integer(Float($1) * 1024**2) if s =~ /^(.+)[mM]$/
26 | return Integer(Float($1) * 1024**3) if s =~ /^(.+)[gG]$/
27 | return Integer(s)
28 | end
29 | 
30 | $chunkSize = parseSize($chunkSize)
31 | 
32 | totalSize = endPos - startPos
33 | numChunks = (totalSize + $chunkSize - 1) / $chunkSize
34 | $stderr.puts "start = #{startPos}, end = #{endPos}, chunkSize = #{$chunkSize}, totalSize = #{totalSize}, #{numChunks} chunks" if $verbose >= 1
35 | puts numChunks if $printNumChunks
36 | 
37 | # Return the first position of the i-th chunk which is a new line.
38 | # Seek to that point.
39 | findBeginOfLine = lambda { |i|
40 | if i == 0
41 | f.seek(startPos, IO::SEEK_SET)
42 | return startPos
43 | else
44 | pos = [endPos, startPos + i * $chunkSize - 1].min
45 | f.seek(pos, IO::SEEK_SET)
46 | while true
47 | c = f.read(1)
48 | #p [f.tell-1, c]
49 | break if c == nil || c == "\n"
50 | end
51 | return f.tell
52 | end
53 | }
54 | 
55 | dump = lambda { |pos0,pos1|
56 | f.seek(pos0, IO::SEEK_SET)
57 | pos = pos0
58 | while true
59 | n = [16384, pos1-pos].min
60 | break if n == 0
61 | print f.read(n)
62 | pos += n
63 | end
64 | }
65 | 
66 | printedHeader = false
67 | $indices.each { |i|
68 | pos0 = findBeginOfLine.call(i)
69 | pos1 = findBeginOfLine.call(i+1)
70 | next if pos0 == pos1
71 | $stderr.puts "Chunk #{i}/#{numChunks}: #{pos0} to #{pos1}" if $verbose >= 1
72 | if not printedHeader
73 | printedHeader = true
74 | header.each { |line| puts line }
75 | end
76 | dump.call(pos0, pos1)
77 | }
78 | 
-------------------------------------------------------------------------------- /fig/bin/qcreate: --------------------------------------------------------------------------------
1 | #!/usr/bin/ruby
2 | 
3 | help = <<EOF
4 | Runs a command inside a fresh, numbered execution directory.
5 | 
6 | Usage:
7 | qcreate [ -qsub ] [ -statePath <statePath> ] <command> [ <args> ... ]
8 | 
9 | Options: -qsub submits the job via qsub; -statePath sets the state directory.
10 | 
11 | All occurrences of _OUTPATH_ in the arguments are replaced with the execution directory.
12 | For example:
13 | qcreate touch _OUTPATH_/foo
14 | will create the file state/execs/<id>.exec/foo, where <id> is unique to this job.
15 | 
16 | Usage for fig programs:
17 | qcreate java ... -execDir _OUTPATH_ -overwriteExecDir
18 | More generally, _OUTPATH_ will be replaced with the execution directory.
19 | EOF
20 | 
21 | statePath = 'state'
22 | qsub = false
23 | 
24 | # Interpret the prefix of ARGV as options to be interpreted.
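# (Illustrative: for the hypothetical invocation
#   qcreate -qsub -statePath mystate wc -l foo.txt
# this loop consumes -qsub and -statePath mystate, leaving
# ["wc", "-l", "foo.txt"] in ARGV as the command to run.)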
25 | while true 26 | if ARGV[0] == '-qsub' 27 | ARGV.shift 28 | qsub = true 29 | elsif ARGV[0] == '-statePath' 30 | ARGV.shift 31 | statePath = ARGV.shift 32 | if not statePath 33 | puts "Error: missing statePath" 34 | puts help 35 | exit 1 36 | end 37 | else 38 | break 39 | end 40 | end 41 | if ARGV.size == 0 42 | puts "Error: no command specified" 43 | puts help 44 | exit 1 45 | end 46 | 47 | system "mkdir -p #{statePath}/execs" or exit 1 48 | 49 | lastExecFile = statePath+"/lastExec" 50 | if not File.exists?(lastExecFile) 51 | puts "Creating #{lastExecFile}" 52 | system "touch #{lastExecFile}" 53 | end 54 | 55 | f = File.open(lastExecFile, 'r+') 56 | if f.flock(File::LOCK_EX) != 0 57 | puts "Error: unable to lock #{lastExecFile}" 58 | exit 1 59 | end 60 | id = f.read 61 | begin 62 | id = id == '' ? -1 : Integer(id) 63 | id += 1 64 | f.rewind 65 | f.puts id 66 | f.flush 67 | f.truncate(f.pos) 68 | rescue 69 | puts "Error: #{lastExecFile} has '#{id}' which is not an integer" 70 | id = nil 71 | end 72 | f.close 73 | exit 1 if not id 74 | 75 | execPath = statePath + "/execs/#{id}.exec" 76 | puts "Execution directory: #{execPath}" 77 | if not Dir.mkdir(execPath) 78 | puts "Already exists (this shouldn't happen): #{execPath}" 79 | exit 1 80 | end 81 | 82 | cmdFile = "#{execPath}/#{id}.sh" 83 | out = open(cmdFile, 'w') 84 | cmd = ARGV.map{|x| x =~ /[ "]/ ? x.inspect : x}.join(' ').gsub(/_OUTPATH_/, execPath) 85 | cmd += " > #{execPath}/stdout 2> #{execPath}/stderr" if qsub 86 | out.puts "cd #{Dir.pwd}" 87 | out.puts cmd 88 | out.close 89 | puts cmd 90 | 91 | system "git rev-parse HEAD > #{execPath}/git-hash" if File.exists?('.git') 92 | 93 | if qsub 94 | exec("qsub #{cmdFile} -o /dev/null -e /dev/null") 95 | else 96 | exec("bash #{cmdFile}") 97 | end 98 | -------------------------------------------------------------------------------- /release/core.files: -------------------------------------------------------------------------------- 1 | # This file contains all the basic data and software dependencies on 2 | # which the public code release relies to compile, build, and run in 3 | # general. 
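# (Note: ./download-dependencies saves these jars under lib/, which is the
# classpath the Makefile compiles against: javac -cp lib/\* -d classes ...)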
4 | # 5 | # These files are hosted on the Stanford NLP web endpoint and are 6 | # intended to be obtained via the command 7 | # 8 | # ./download-dependencies core 9 | # 10 | 11 | # Fig 12 | lib/fig.jar 13 | 14 | # Guava (Google Libraries) 15 | lib/guava-16.0.1.jar 16 | 17 | # JSON 18 | lib/jackson-core-2.3.2.jar 19 | lib/jackson-databind-2.3.2.jar 20 | lib/jackson-annotations-2.3.2.jar 21 | 22 | # HTML Parser 23 | lib/jsoup-1.7.3.jar 24 | 25 | # Stanford CoreNLP 26 | lib/stanford-corenlp-3.3.1.jar 27 | lib/stanford-corenlp-3.3.1-models.jar 28 | lib/joda-time.jar 29 | lib/jollyday.jar 30 | lib/stanford-corenlp-caseless-2013-11-12-models.jar 31 | 32 | -------------------------------------------------------------------------------- /release/dataset_debug.files: -------------------------------------------------------------------------------- 1 | # The debug dataset (contains only 30 examples) 2 | # 3 | # These files are hosted on the Stanford NLP web endpoint and are 4 | # intended to be obtained via the command 5 | # 6 | # ./download-dependencies dataset_debug 7 | # 8 | 9 | datasets/openweb/debug.json 10 | scripts/frozen.cache/openweb-debug/ 11 | -------------------------------------------------------------------------------- /release/dataset_openweb.files: -------------------------------------------------------------------------------- 1 | # The OpenWeb dataset 2 | # 3 | # These files are hosted on the Stanford NLP web endpoint and are 4 | # intended to be obtained via the command 5 | # 6 | # ./download-dependencies dataset_openweb 7 | # 8 | 9 | datasets/openweb/train.json 10 | datasets/openweb/test.json 11 | scripts/frozen.cache/openweb-train/ 12 | scripts/frozen.cache/openweb-test/ 13 | -------------------------------------------------------------------------------- /release/ling.files: -------------------------------------------------------------------------------- 1 | # This file contains all the linguistic data files. 2 | # 3 | # These files are hosted on the Stanford NLP web endpoint and are 4 | # intended to be obtained via the command 5 | # 6 | # ./download-dependencies ling 7 | # 8 | 9 | # wordreprs (Turian et al.) 10 | #lib/wordreprs/brown-rcv1.clean.tokenized-CoNLL03.txt-c1000-freq1.txt 11 | #lib/wordreprs/embeddings-scaled.EMBEDDING_SIZE=50.txt 12 | 13 | # Others 14 | lib/ling-data/queryType.tsv 15 | lib/ling-data/webFrequency.tsv 16 | lib/ling-data/wordnet/newer-30 17 | -------------------------------------------------------------------------------- /release/model.files: -------------------------------------------------------------------------------- 1 | # Trained models 2 | # 3 | # These files are hosted on the Stanford NLP web endpoint and are 4 | # intended to be obtained via the command 5 | # 6 | # ./download-dependencies model 7 | # 8 | 9 | models/openweb-devset/ 10 | -------------------------------------------------------------------------------- /scripts/README: -------------------------------------------------------------------------------- 1 | Python scripts for getting data from the Web or loading the cached web pages. 2 | 3 | Scripts 4 | ------- 5 | 6 | google-search.py 7 | Perform Google search using Google Custom Search 8 | 9 | fake-google-search.py 10 | Fake Google search by looking from the cached results instead 11 | 12 | get-webpage.py 13 | Either 14 | - Download the web page from the specified URL, or 15 | - Get the cached web page corresponding to the specified hash code 16 | from the specified cache directory. 
17 | Then, dump the result to standard output 18 | -------------------------------------------------------------------------------- /scripts/fake-google-search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """Fake a Google search result. 5 | Load a Google search cache from fake-google-search.cache 6 | Useful for getting a Google search result for a particular data set in the past. 7 | 8 | Usage: 9 | python fake-google-search.py [QUERY] 10 | Return (in standard output): 11 | A JSON-encoded result of the form [{"link": ..., "title": ...}, ...] 12 | """ 13 | 14 | import sys, os, urllib 15 | 16 | def get_cache_filename(query): 17 | cache_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 18 | 'fake-google-search.cache') 19 | key = urllib.quote_plus(query) 20 | return os.path.join(cache_path, key + '.json') 21 | 22 | if __name__ == '__main__': 23 | query = ' '.join(sys.argv[1:]) 24 | cache_filename = get_cache_filename(query) 25 | try: 26 | with open(cache_filename) as fin: 27 | result = fin.read() 28 | except IOError: 29 | result = '' 30 | print result 31 | -------------------------------------------------------------------------------- /scripts/get-webpage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """Retrieve the web page. 5 | Either load from a cache directory or from the Internet. 6 | 7 | If -H (--hashcode) is specified, load the web page corresponding to 8 | the specified hashcode from the cache. 9 | """ 10 | 11 | import sys, os, argparse 12 | from weblib.web import WebpageCache 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-d', '--cache-directory', default='web.cache', 17 | help='use the specified cache directory') 18 | parser.add_argument('-H', '--hashcode', 19 | help='retrieve using hashcode instead of URL') 20 | (opts, args) = parser.parse_known_args() 21 | 22 | cache = WebpageCache(log=False, dirname=opts.cache_directory) 23 | if opts.hashcode: 24 | print cache.read(opts.hashcode, already_hashed=True) or 'ERROR' 25 | else: 26 | url = args[0] 27 | print cache.get_page(url) or 'ERROR' 28 | -------------------------------------------------------------------------------- /scripts/google-search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """Search Google. 5 | 6 | For only a few searches, the default free search should work OK. 7 | 8 | If a lot of searches are needed (100+), consider using Google Custom Search. 9 | Simply put Google API Key and CX Key in the variables below. 10 | 11 | Print only the first search result. 
12 | """ 13 | 14 | import urllib, os, sys, json 15 | from weblib.web import WebpageCache 16 | 17 | # Google Custom Search 18 | GOOGLE_APIKEY = '' 19 | GOOGLE_CX = '' 20 | CACHE_DIRNAME = 'google.cache' 21 | 22 | if __name__ == '__main__': 23 | query = ' '.join(sys.argv[1:]) 24 | cache = WebpageCache(log=False, dirname=CACHE_DIRNAME) 25 | if GOOGLE_APIKEY and GOOGLE_CX: 26 | cache.set_google_custom_search_keys(GOOGLE_APIKEY, GOOGLE_CX) 27 | results = cache.get_urls_from_google_custom_search(query) 28 | else: 29 | results = cache.get_urls_from_google_search(query) 30 | print json.dumps(results) 31 | -------------------------------------------------------------------------------- /scripts/weblib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppasupat/web-entity-extractor-ACL2014/314318be928b584c459376e8f54d8996818044a7/scripts/weblib/__init__.py -------------------------------------------------------------------------------- /scripts/weblib/blacklist.py: -------------------------------------------------------------------------------- 1 | # Some domains that do not like us 2 | 3 | BLACKLIST = set([ 4 | 'www.thehugoawards.org', 5 | 'www.nytimes.com', 6 | 'allaboutexplorers.com', 7 | ]) 8 | -------------------------------------------------------------------------------- /scripts/weblib/tee.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | from codecs import open 6 | 7 | # http://stackoverflow.com/a/616686 8 | class TeeOut(object): 9 | def __init__(self, filename, mode='w'): 10 | import sys 11 | self.file = open(filename, mode, 'utf8') 12 | self.stdout = sys.stdout 13 | sys.stdout = self 14 | def __del__(self): 15 | import sys 16 | sys.stdout = self.stdout 17 | self.file.close() 18 | def write(self, data): 19 | self.file.write(data) 20 | self.file.flush() 21 | self.stdout.write(data) 22 | self.stdout.flush() 23 | 24 | class TeeErr(object): 25 | def __init__(self, filename, mode='w'): 26 | import sys 27 | self.file = open(filename, mode, 'utf8') 28 | self.stderr = sys.stderr 29 | sys.stderr = self 30 | def __del__(self): 31 | import sys 32 | sys.stderr = self.stderr 33 | self.file.close() 34 | def write(self, data): 35 | self.file.write(data) 36 | self.file.flush() 37 | self.stderr.write(data) 38 | self.stderr.flush() 39 | 40 | if __name__ == '__main__': 41 | pass 42 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/core/InteractiveDemo.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.core; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.IOException; 5 | import java.io.InputStreamReader; 6 | 7 | import edu.stanford.nlp.semparse.open.core.eval.CandidateStatistics; 8 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate; 9 | import fig.basic.LogInfo; 10 | 11 | public class InteractiveDemo { 12 | 13 | public final OpenSemanticParser parser; 14 | 15 | public InteractiveDemo(OpenSemanticParser parser) { 16 | this.parser = parser; 17 | } 18 | 19 | public void run() { 20 | LogInfo.log("Starting interactive mode ..."); 21 | try (BufferedReader in = new BufferedReader(new InputStreamReader(System.in))) { 22 | while (true) { 23 | System.out.println("============================================================"); 24 | System.out.print("Query: list of 
"); 25 | String phrase = in.readLine(); 26 | if (phrase == null) {System.out.println(); break;} 27 | if (phrase.isEmpty()) continue; 28 | System.out.print("Web Page URL (blank for Google Search): "); 29 | String url = in.readLine(); 30 | if (url == null) {System.out.println(); break;} 31 | CandidateStatistics pred = url.isEmpty() ? parser.predict(phrase) : parser.predict(phrase, url); 32 | LogInfo.begin_track("PRED (top scoring candidate):"); 33 | if (pred == null) { 34 | LogInfo.log("Rank 1 [Unique Rank 1]: NO CANDIDATE FOUND!"); 35 | } else { 36 | LogInfo.logs("Rank 1 [Unique Rank 1]: (Total Feature Score = %s)", pred.score); 37 | Candidate candidate = pred.candidate; 38 | LogInfo.logs("Extraction Predicate: %s", candidate.pattern); 39 | LogInfo.log(candidate.sampleEntities()); 40 | } 41 | LogInfo.end_track(); 42 | } 43 | } catch (IOException e) { 44 | LogInfo.fail(e); 45 | } 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/core/ParallelizedTrainer.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.core; 2 | 3 | import java.util.concurrent.Callable; 4 | 5 | import edu.stanford.nlp.semparse.open.dataset.Dataset; 6 | 7 | public class ParallelizedTrainer implements Callable { 8 | Dataset dataset; 9 | boolean beVeryQuiet; 10 | 11 | public ParallelizedTrainer(Dataset dataset, boolean beVeryQuiet) { 12 | this.dataset = dataset; 13 | this.beVeryQuiet = beVeryQuiet; 14 | } 15 | 16 | @Override 17 | public OpenSemanticParser call() throws Exception { 18 | OpenSemanticParser parser = new OpenSemanticParser(); 19 | parser.train(dataset, beVeryQuiet); 20 | return parser; 21 | } 22 | } -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/core/eval/CandidateStatistics.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.core.eval; 2 | 3 | import java.util.*; 4 | 5 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate; 6 | import fig.basic.Pair; 7 | 8 | public class CandidateStatistics { 9 | public final Candidate candidate; 10 | public final int rank, uniqueRank; // rank and uniqueRank are 1-indexed 11 | public final double score; 12 | 13 | public CandidateStatistics(Candidate candidate, int rank, int uniqueRank, double score) { 14 | this.candidate = candidate; 15 | this.rank = rank; 16 | this.uniqueRank = uniqueRank; 17 | this.score = score; 18 | } 19 | 20 | /** 21 | * Convert Pair to CandidateStatistics 22 | */ 23 | public static List getRankedCandidateStats(List> rankedCandidates) { 24 | List answer = new ArrayList<>(); 25 | Set> foundPredictedEntities = new HashSet<>(); 26 | for (int rank = 0; rank < rankedCandidates.size(); rank++) { 27 | Pair entry = rankedCandidates.get(rank); 28 | Candidate candidate = entry.getFirst(); 29 | foundPredictedEntities.add(candidate.predictedEntities); 30 | answer.add(new CandidateStatistics(candidate, rank + 1, foundPredictedEntities.size(), entry.getSecond())); 31 | } 32 | return answer; 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/core/eval/EvaluationCase.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.core.eval; 2 | 3 | import java.util.*; 4 | 5 | import 
edu.stanford.nlp.semparse.open.core.OpenSemanticParser; 6 | import edu.stanford.nlp.semparse.open.dataset.Example; 7 | import edu.stanford.nlp.semparse.open.dataset.ExpectedAnswer; 8 | import edu.stanford.nlp.semparse.open.dataset.ExpectedAnswerInjectiveMatch; 9 | import edu.stanford.nlp.semparse.open.dataset.IRScore; 10 | import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntity; 11 | import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntityNearMatch; 12 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate; 13 | import fig.basic.LogInfo; 14 | 15 | public abstract class EvaluationCase { 16 | public final Evaluator evaluator; 17 | public final Example ex; 18 | public final CandidateStatistics pred, firstTrue, best; 19 | 20 | // IR scores compared to expected entities 21 | public final IRScore predIRScore, firstTrueIRScore, bestIRScore; 22 | 23 | // IR scores compared to best true 24 | public final IRScore predIRScoreOnBest, firstTrueIRScoreOnBest; 25 | 26 | protected EvaluationCase(Evaluator evaluator, Example ex, CandidateStatistics pred, 27 | CandidateStatistics firstTrue, CandidateStatistics best) { 28 | this.evaluator = evaluator; 29 | this.ex = ex; 30 | this.pred = pred; 31 | this.firstTrue = firstTrue; 32 | this.best = best; 33 | // Compute IR scores 34 | predIRScore = (pred == null) ? null : ex.expectedAnswer.getIRScore(pred.candidate); 35 | firstTrueIRScore = (firstTrue == null) ? null : ex.expectedAnswer.getIRScore(firstTrue.candidate); 36 | bestIRScore = (best == null) ? null : ex.expectedAnswer.getIRScore(best.candidate); 37 | if (best == null) { 38 | predIRScoreOnBest = firstTrueIRScoreOnBest = null; 39 | } else { 40 | List bestTrueEntites = new ArrayList<>(); 41 | for (String entity : best.candidate.predictedEntities) { 42 | bestTrueEntites.add(new TargetEntityNearMatch(entity)); 43 | } 44 | ExpectedAnswer bestTrueAnswer = new ExpectedAnswerInjectiveMatch(bestTrueEntites); 45 | predIRScoreOnBest = (pred == null) ? null : bestTrueAnswer.getIRScore(pred.candidate); 46 | firstTrueIRScoreOnBest = (firstTrue == null) ? null : bestTrueAnswer.getIRScore(firstTrue.candidate); 47 | } 48 | } 49 | 50 | /** Log TRUE : the first likely correct candidate */ 51 | public void logTrue() { 52 | LogInfo.begin_track("TRUE (likely correct candidate):"); 53 | if (firstTrue == null) { 54 | LogInfo.logs("<%s SUPER FAIL> Correct candidate not found!", evaluator.testSuiteName); 55 | } else { 56 | LogInfo.logs("<%s %s> Rank %d [Unique Rank %d]: (Total Feature Score = %s)", evaluator.testSuiteName, 57 | firstTrue.rank == 1 ? 
"SUCCESS" : "FAIL", firstTrue.rank, 58 | firstTrue.uniqueRank, firstTrue.score); 59 | logCandidate(firstTrue); 60 | } 61 | LogInfo.end_track(); 62 | } 63 | 64 | /** Log PRED : the top scoring candidate */ 65 | public void logPred() { 66 | if (pred != null) { 67 | LogInfo.begin_track("PRED (top scoring candidate):"); 68 | LogInfo.logs("Rank 1 [Unique Rank 1]: (Total Feature Score = %s)", pred.score); 69 | logCandidate(pred); 70 | LogInfo.end_track(); 71 | } 72 | } 73 | 74 | private void logCandidate(CandidateStatistics candidateStat) { 75 | Candidate candidate = candidateStat.candidate; 76 | LogInfo.logs("%s %s", candidate.pattern, candidate.ex.expectedAnswer.getIRScore(candidate)); 77 | if (OpenSemanticParser.opts.logVerbosity >= 2) 78 | LogInfo.log(candidate.sampleEntities()); 79 | if (OpenSemanticParser.opts.logVerbosity >= 3) 80 | evaluator.learner.logFeatureWeights(candidate); 81 | } 82 | 83 | public abstract void logFeatureDiff(); 84 | } 85 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/core/eval/EvaluationNormalFail.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.core.eval; 2 | 3 | import edu.stanford.nlp.semparse.open.dataset.Example; 4 | import fig.basic.LogInfo; 5 | 6 | /** 7 | * When the correct candidate is found but not in the 1st rank 8 | */ 9 | public class EvaluationNormalFail extends EvaluationCase { 10 | 11 | public EvaluationNormalFail(Evaluator evaluator, Example ex, CandidateStatistics pred, 12 | CandidateStatistics firstTrue, CandidateStatistics best) { 13 | super(evaluator, ex, pred, firstTrue, best); 14 | } 15 | 16 | @Override 17 | public void logFeatureDiff() { 18 | LogInfo.begin_track("### %s ###", ex); 19 | LogInfo.logs("GOLD: %s", ex.expectedAnswer.sampleEntities()); 20 | LogInfo.logs("PRED: %s", pred.candidate.pattern); 21 | LogInfo.logs(" (vs target entities) %s", predIRScore); 22 | LogInfo.logs(" (vs best candidate) %s", predIRScoreOnBest); 23 | LogInfo.logs(" %s", pred.candidate.sampleEntities()); 24 | LogInfo.logs("TRUE: %s", firstTrue.candidate.pattern); 25 | LogInfo.logs(" (vs target entities) %s", firstTrueIRScore); 26 | LogInfo.logs(" (vs best candidate) %s", firstTrueIRScoreOnBest); 27 | LogInfo.logs(" (Rank %d [Unique Rank %d])", firstTrue.rank, firstTrue.uniqueRank); 28 | LogInfo.logs(" %s", firstTrue.candidate.sampleEntities()); 29 | LogInfo.logs("BEST: %s", best.candidate.pattern); 30 | LogInfo.logs(" (vs target entities) %s", bestIRScore); 31 | LogInfo.logs(" %s", best.candidate.sampleEntities()); 32 | evaluator.learner.logFeatureDiff(firstTrue.candidate, pred.candidate); 33 | LogInfo.end_track(); 34 | } 35 | 36 | } 37 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/core/eval/EvaluationSuccess.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.core.eval; 2 | 3 | import edu.stanford.nlp.semparse.open.dataset.Example; 4 | 5 | public class EvaluationSuccess extends EvaluationCase { 6 | 7 | public EvaluationSuccess(Evaluator evaluator, Example ex, CandidateStatistics pred, 8 | CandidateStatistics firstTrue, CandidateStatistics best) { 9 | super(evaluator, ex, pred, firstTrue, best); 10 | } 11 | 12 | @Override 13 | public void logFeatureDiff() { 14 | // Do nothing 15 | } 16 | } 17 | 
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/core/eval/EvaluationSuperFail.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.core.eval; 2 | 3 | import edu.stanford.nlp.semparse.open.dataset.Example; 4 | import fig.basic.LogInfo; 5 | 6 | /** 7 | * When the correct candidate is not found 8 | */ 9 | public class EvaluationSuperFail extends EvaluationCase { 10 | 11 | public EvaluationSuperFail(Evaluator evaluator, Example ex, CandidateStatistics pred, 12 | CandidateStatistics firstTrue, CandidateStatistics best) { 13 | super(evaluator, ex, pred, firstTrue, best); 14 | } 15 | 16 | @Override 17 | public void logFeatureDiff() { 18 | LogInfo.begin_track("### %s ###", ex); 19 | LogInfo.logs("GOLD: %s", ex.expectedAnswer.sampleEntities()); 20 | if (pred != null) { 21 | LogInfo.logs("PRED: %s", pred.candidate.pattern); 22 | LogInfo.logs(" (vs target entities) %s", predIRScore); 23 | LogInfo.logs(" (vs best candidate) %s", predIRScoreOnBest); 24 | LogInfo.logs(" %s", pred.candidate.sampleEntities()); 25 | } else { 26 | LogInfo.log("PRED: NOT FOUND!"); 27 | } 28 | LogInfo.log("TRUE: NOT FOUND!"); 29 | if (best != null) { 30 | LogInfo.logs("BEST: %s", best.candidate.pattern); 31 | LogInfo.logs(" (vs target entities) %s", bestIRScore); 32 | LogInfo.logs(" %s", best.candidate.sampleEntities()); 33 | } else { 34 | LogInfo.log("BEST: NOT FOUND!"); 35 | } 36 | LogInfo.end_track(); 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/core/eval/Evaluator.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.core.eval; 2 | 3 | import java.util.*; 4 | 5 | import edu.stanford.nlp.semparse.open.dataset.Example; 6 | import edu.stanford.nlp.semparse.open.model.Learner; 7 | import edu.stanford.nlp.semparse.open.util.Multiset; 8 | import fig.basic.Fmt; 9 | import fig.basic.ListUtils; 10 | import fig.basic.LogInfo; 11 | import fig.exec.Execution; 12 | 13 | /** 14 | * Deal with evaluation. 
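 * (Illustrative example: if the ranked candidates extract [A, A, B] and only B
 * is a correct extraction, then pred = A and firstTrue = B with rank 3, so the
 * example counts as a NORMAL FAIL under the definitions below.)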
15 | * With zero-one loss, "correct" means matching all criteria 16 | * - Wiki : match all target entities 17 | * - Web : match first, second, and last 18 | * 19 | * pred = predicted candidate (first candidate in rankedCandidates) 20 | * true = first correct candidate in rankedCandidates 21 | * best = the best thing we can ever select from the web page (may not match all target entities) 22 | * 23 | * SUPER FAIL = (firstTrue == null) 24 | * NORMAL FAIL = (firstTrue.rank != 1) 25 | * SUCCESS = (firstTrue.rank == 1) 26 | */ 27 | public class Evaluator { 28 | String testSuiteName; 29 | Learner learner; 30 | int numExamples = 0, numSuccess = 0, numNormalFail = 0, numSuperFail = 0, numFound = 0; 31 | List successes, normalFails, superFails; 32 | double sumF1onExpectedEntities = 0, sumF1onBest = 0; 33 | Multiset firstTrueUniqueRanks; 34 | 35 | public Evaluator(String testSuiteName, Learner learner) { 36 | this.testSuiteName = testSuiteName; 37 | this.learner = learner; 38 | successes = new ArrayList<>(); 39 | normalFails = new ArrayList<>(); 40 | superFails = new ArrayList<>(); 41 | firstTrueUniqueRanks = new Multiset<>(); 42 | } 43 | 44 | public EvaluationCase add(Example ex, CandidateStatistics pred, CandidateStatistics firstTrue, CandidateStatistics best) { 45 | numExamples++; 46 | EvaluationCase evaluationCase; 47 | if (firstTrue == null) { 48 | evaluationCase = new EvaluationSuperFail(this, ex, pred, firstTrue, best); 49 | superFails.add(evaluationCase); 50 | numSuperFail++; 51 | firstTrueUniqueRanks.add(Integer.MAX_VALUE); 52 | } else if (firstTrue.rank != 1) { 53 | evaluationCase = new EvaluationNormalFail(this, ex, pred, firstTrue, best); 54 | normalFails.add(evaluationCase); 55 | numNormalFail++; 56 | numFound++; 57 | firstTrueUniqueRanks.add(firstTrue.uniqueRank); 58 | } else { 59 | evaluationCase = new EvaluationSuccess(this, ex, pred, firstTrue, best); 60 | successes.add(evaluationCase); 61 | numSuccess++; 62 | numFound++; 63 | firstTrueUniqueRanks.add(firstTrue.uniqueRank); 64 | } 65 | if (evaluationCase.predIRScore != null) 66 | sumF1onExpectedEntities += evaluationCase.predIRScore.f1; 67 | if (evaluationCase.predIRScoreOnBest != null) 68 | sumF1onBest += evaluationCase.predIRScoreOnBest.f1; 69 | return evaluationCase; 70 | } 71 | 72 | public double[] getAccuracyAtK(int maxK) { 73 | double[] accuracyAtK = new double[maxK + 1]; 74 | accuracyAtK[0] = 0.0; // Accuracy at 0 is always 0 (used for padding) 75 | int sumCorrect = 0; 76 | for (int k = 1; k <= maxK; k++) { 77 | sumCorrect += firstTrueUniqueRanks.count(k); 78 | accuracyAtK[k] = sumCorrect * 1.0 / numExamples; 79 | } 80 | return accuracyAtK; 81 | } 82 | 83 | public void printDetails() { 84 | LogInfo.begin_track("### %s: %d Normal FAILS (correct candidate in other rank) ###", testSuiteName, numNormalFail); 85 | for (EvaluationCase fail : normalFails) fail.logFeatureDiff(); 86 | LogInfo.end_track(); 87 | LogInfo.begin_track("### %s: %d Super FAILS (correct candidate not found) ###", testSuiteName, numSuperFail); 88 | for (EvaluationCase fail : superFails) fail.logFeatureDiff(); 89 | LogInfo.end_track(); 90 | } 91 | 92 | public Evaluator putOutput(String prefix) { 93 | Execution.putOutput(prefix + ".numExamples", numExamples); 94 | Execution.putOutput(prefix + ".accuracy", 1.0 * numSuccess / numExamples); 95 | Execution.putOutput(prefix + ".accuracyFound", 1.0 * numSuccess / numFound); 96 | Execution.putOutput(prefix + ".oracle", 1.0 * numFound / numExamples); 97 | Execution.putOutput(prefix + ".averageF1onTargetEntities", 
sumF1onExpectedEntities * 1.0 / numExamples); 98 | Execution.putOutput(prefix + ".averageF1onBestCandidate", sumF1onBest * 1.0 / numExamples); 99 | Execution.putOutput(prefix + ".accuracyAtK", Fmt.D(ListUtils.subArray(getAccuracyAtK(10), 1))); 100 | return this; 101 | } 102 | 103 | public Evaluator printScores() { 104 | LogInfo.begin_track("%s Evaluation", testSuiteName); 105 | LogInfo.logs("Number of examples: %d", numExamples); 106 | LogInfo.logs("Correct candidate in rank #1: %d", numSuccess); 107 | LogInfo.logs("Correct candidate in other rank: %d", numNormalFail); 108 | LogInfo.logs("Correct candidate not found: %d", numSuperFail); 109 | LogInfo.logs("Oracle: %.3f%% ( %d / %d )", numFound * 100.0 / numExamples, numFound, numExamples); 110 | LogInfo.logs("Accuracy (vs all): %.3f%% ( %d / %d )", numSuccess * 100.0 / numExamples, numSuccess, numExamples); 111 | LogInfo.logs("Accuracy (vs found): %.3f%% ( %d / %d )", numSuccess * 100.0 / numFound, numSuccess, numFound); 112 | LogInfo.logs("Average F1 (vs target entities): %.3f%%", sumF1onExpectedEntities * 100.0 / numExamples); 113 | LogInfo.logs("Average F1 (vs best candidate): %.3f%%", sumF1onBest * 100.0 / numExamples); 114 | LogInfo.logs("Accuracy @ k : %s", Fmt.D(ListUtils.mult(100.0, ListUtils.subArray(getAccuracyAtK(10), 1)))); 115 | LogInfo.end_track(); 116 | return this; 117 | } 118 | 119 | public Evaluator printScores(boolean beVeryQuiet) { 120 | return beVeryQuiet ? this : printScores(); 121 | } 122 | 123 | } 124 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/core/eval/EvaluatorStatistics.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.core.eval; 2 | 3 | import java.util.List; 4 | 5 | import fig.basic.Fmt; 6 | import fig.basic.ListUtils; 7 | import fig.basic.LogInfo; 8 | import fig.exec.Execution; 9 | 10 | public class EvaluatorStatistics { 11 | final int numExamples, numSuccess, numNormalFail, numSuperFail, numFound; 12 | // These values are in [0, 1] 13 | final double oracle, accuracyAll, accuracyFound; 14 | final double avgF1onExpectedEntities, avgF1onBest; 15 | final double[] accuracyAtK; 16 | 17 | public EvaluatorStatistics(Evaluator evaluator) { 18 | numExamples = evaluator.numExamples; 19 | numSuccess = evaluator.numSuccess; 20 | numNormalFail = evaluator.numNormalFail; 21 | numSuperFail = evaluator.numSuperFail; 22 | numFound = evaluator.numFound; 23 | oracle = numFound * 1.0 / numExamples; 24 | accuracyAll = numSuccess * 1.0 / numExamples; 25 | accuracyFound = numSuccess * 1.0 / numFound; 26 | avgF1onExpectedEntities = evaluator.sumF1onExpectedEntities * 1.0 / numExamples; 27 | avgF1onBest = evaluator.sumF1onBest * 1.0 / numExamples; 28 | accuracyAtK = evaluator.getAccuracyAtK(IterativeTester.MAX_K); 29 | } 30 | 31 | private static double getAverage(double[] stuff) { 32 | return ListUtils.sum(stuff) / stuff.length; 33 | } 34 | 35 | // Divide by n instead of (n-1) 36 | private static double getVariance(double[] stuff) { 37 | double sumSq = 0, sum = 0, n = stuff.length; 38 | for (double x: stuff) { sumSq += x * x; sum += x; } 39 | return (sumSq / n) - (sum / n) * (sum / n); 40 | } 41 | 42 | // Divide by n instead of (n-1) 43 | private static double getSD(double[] stuff) { 44 | return Math.sqrt(getVariance(stuff)); 45 | } 46 | 47 | public static void logAverage(List stats, String prefix) { 48 | int n = stats.size(), K = IterativeTester.MAX_K; 49 | double[] oracleL = new 
double[n], 50 | accuracyAllL = new double[n], 51 | accuracyFoundL = new double[n], 52 | avgF1onExpectedEntitiesL = new double[n], 53 | avgF1onBestL = new double[n]; 54 | double[][] accuracyAtKL = new double[K + 1][n]; 55 | 56 | // Compile each statistic into a list 57 | for (int i = 0; i < n; i++) { 58 | EvaluatorStatistics stat = stats.get(i); 59 | oracleL[i] = stat.oracle; 60 | accuracyAllL[i] = stat.accuracyAll; 61 | accuracyFoundL[i] = stat.accuracyFound; 62 | avgF1onExpectedEntitiesL[i] = stat.avgF1onExpectedEntities; 63 | avgF1onBestL[i] = stat.avgF1onBest; 64 | for (int j = 0; j <= K; j++) 65 | accuracyAtKL[j][i] = stat.accuracyAtK[j]; 66 | } 67 | double[] accuracyAtK = new double[K + 1], accuracyAtKSD = new double[K + 1]; 68 | for (int j = 0; j <= K; j++) { 69 | accuracyAtK[j] = getAverage(accuracyAtKL[j]); 70 | accuracyAtKSD[j] = getSD(accuracyAtKL[j]); 71 | } 72 | 73 | // Log the statistics 74 | LogInfo.begin_track("@@@@@ SUMMARY %s (%d folds) @@@@@", prefix, n); 75 | logAverageSingle(oracleL, prefix + ".oracle", "Oracle"); 76 | logAverageSingle(accuracyAllL, prefix + ".accuracy", "Accuracy (vs all)"); 77 | logAverageSingle(accuracyFoundL, prefix + ".accuracyFound", "Accuracy (vs found)"); 78 | logAverageSingle(avgF1onExpectedEntitiesL, prefix + ".averageF1onTargetEntities", "Average F1 (vs target entities)"); 79 | logAverageSingle(avgF1onBestL, prefix + ".averageF1onBestCandidate", "Average F1 (vs best candidate)"); 80 | // Ignore accuracy at 0 (which is always 0) 81 | Execution.putOutput(prefix + ".accuracyAtK", Fmt.D(ListUtils.subArray(accuracyAtK, 1))); 82 | Execution.putOutput(prefix + ".accuracyAtKSD", Fmt.D(ListUtils.subArray(accuracyAtKSD, 1))); 83 | LogInfo.logs("Accuracy @ k : %s", Fmt.D(ListUtils.mult(100.0, ListUtils.subArray(accuracyAtK, 1)))); 84 | LogInfo.logs(" +- %s", Fmt.D(ListUtils.mult(100.0, ListUtils.subArray(accuracyAtKSD, 1)))); 85 | LogInfo.end_track(); 86 | } 87 | 88 | private static void logAverageSingle(double[] stuff, String outputName, String logName) { 89 | Execution.putOutput(outputName, getAverage(stuff)); 90 | Execution.putOutput(outputName + "SD", getSD(stuff)); 91 | LogInfo.logs("%s: %.3f%% +- %.3f%% [%s]", logName, getAverage(stuff) * 100, getSD(stuff) * 100, 92 | Fmt.D(ListUtils.mult(100.0, stuff))); 93 | } 94 | } -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/core/eval/IterativeTester.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.core.eval; 2 | 3 | import java.util.*; 4 | 5 | import edu.stanford.nlp.semparse.open.Main; 6 | import edu.stanford.nlp.semparse.open.core.OpenSemanticParser; 7 | import edu.stanford.nlp.semparse.open.dataset.Dataset; 8 | import fig.basic.LogInfo; 9 | 10 | public class IterativeTester { 11 | private final OpenSemanticParser openSemanticParser; 12 | private final Dataset dataset; 13 | public String message = ""; 14 | List trainStats, testStats; 15 | 16 | public boolean beVeryQuiet = false; 17 | public static final int MAX_K = 10; 18 | 19 | public IterativeTester(OpenSemanticParser openSemanticParser, Dataset dataset) { 20 | this.openSemanticParser = openSemanticParser; 21 | this.dataset = dataset; 22 | this.trainStats = new ArrayList<>(); 23 | this.testStats = new ArrayList<>(); 24 | } 25 | 26 | public void run() { 27 | int oldLogVerbosity = OpenSemanticParser.opts.logVerbosity; 28 | OpenSemanticParser.opts.logVerbosity = 0; 29 | trainStats.add(new 
EvaluatorStatistics(
30 | openSemanticParser.test(dataset.trainExamples, "[" + message + "] ITERATIVE TEST on TRAINING SET")
31 | .printScores(beVeryQuiet).putOutput("train")));
32 | testStats.add(new EvaluatorStatistics(
33 | openSemanticParser.test(dataset.testExamples, "[" + message + "] ITERATIVE TEST on TEST SET")
34 | .printScores(beVeryQuiet).putOutput("test")));
35 | OpenSemanticParser.opts.logVerbosity = oldLogVerbosity;
36 | }
37 | 
38 | public void summarize() {
39 | LogInfo.begin_track("@@@ SUMMARY @@@");
40 | LogInfo.logs("%7s | %7s %7s %7s | %7s %7s %7s", "iter",
41 | "tracc", "trora", "traf1", "tsacc", "tsora", "tsaf1");
42 | for (int i = 0; i < trainStats.size(); i++) {
43 | EvaluatorStatistics t = trainStats.get(i), s = testStats.get(i);
44 | double tF1, sF1;
45 | if ("wiki".equals(Main.opts.dataset.split("[.]")[0])) {
46 | tF1 = t.avgF1onExpectedEntities;
47 | sF1 = s.avgF1onExpectedEntities;
48 | } else {
49 | tF1 = t.avgF1onBest;
50 | sF1 = s.avgF1onBest;
51 | }
52 | LogInfo.logs("%7s | %7.2f %7.2f %7.2f | %7.2f %7.2f %7.2f", i+1,
53 | t.accuracyAll, t.oracle, tF1,
54 | s.accuracyAll, s.oracle, sF1);
55 | }
56 | LogInfo.end_track();
57 | }
58 | 
59 | public EvaluatorStatistics getLastTrainStat() {
60 | return trainStats.get(trainStats.size() - 1);
61 | }
62 | 
63 | public EvaluatorStatistics getLastTestStat() {
64 | return testStats.get(testStats.size() - 1);
65 | }
66 | 
67 | 
68 | 
69 | }
70 | 
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/Criteria.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.dataset;
2 | 
3 | import java.util.List;
4 | 
5 | import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntity;
6 | 
7 | public interface Criteria {
8 | public List<TargetEntity> getTargetEntities();
9 | 
10 | /** Number of criteria **/
11 | public int numCriteria();
12 | 
13 | /** Number of matched criteria **/
14 | public int countMatchedCriteria(List<String> predictedEntities);
15 | 
16 | /** Return a custom IR score */
17 | public IRScore getIRScore(List<String> predictedEntities);
18 | 
19 | /**
20 | * Return the correctness score. Between correct candidates, the one with
21 | * higher correctness score is more correct.
22 | * 
23 | * Normally, this is just getIRScore().f1
24 | */
25 | public double getCorrectnessScore(List<String> predictedEntities);
26 | } 
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/CriteriaExactMatch.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.dataset;
2 | 
3 | import java.util.*;
4 | 
5 | import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntity;
6 | import edu.stanford.nlp.semparse.open.util.BipartiteMatcher;
7 | 
8 | /**
9 | * Only 1 criterion: whether the lists are exactly the same.
10 | */
11 | public class CriteriaExactMatch implements Criteria {
12 | public final List<TargetEntity> targetEntities;
13 | 
14 | public CriteriaExactMatch(TargetEntity... targetEntities) {
15 | this.targetEntities = Arrays.asList(targetEntities);
16 | }
17 | 
18 | public CriteriaExactMatch(List<TargetEntity> targetEntities) {
19 | this.targetEntities = targetEntities;
20 | }
21 | 
22 | @Override
23 | public List<TargetEntity> getTargetEntities() {
24 | return targetEntities;
25 | }
26 | 
27 | @Override
28 | public int numCriteria() {
29 | return 1;
30 | }
31 | 
32 | @Override
33 | public int countMatchedCriteria(List<String> predictedEntities) {
34 | if (predictedEntities.size() == targetEntities.size())
35 | if (new BipartiteMatcher(targetEntities, predictedEntities).findMaximumMatch() == targetEntities.size())
36 | return 1;
37 | return 0;
38 | }
39 | 
40 | Map<List<String>, IRScore> irScoreCache = new HashMap<>();
41 | 
42 | @Override
43 | public IRScore getIRScore(List<String> predictedEntities) {
44 | IRScore answer = irScoreCache.get(predictedEntities);
45 | if (answer == null) {
46 | answer = new IRScore(targetEntities, predictedEntities);
47 | irScoreCache.put(predictedEntities, answer);
48 | }
49 | return answer;
50 | }
51 | 
52 | /**
53 | * More F1 = better candidate.
54 | */
55 | @Override
56 | public double getCorrectnessScore(List<String> predictedEntities) {
57 | return getIRScore(predictedEntities).f1;
58 | }
59 | 
60 | }
61 | 
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/CriteriaGeneralWeb.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.dataset;
2 | 
3 | import java.util.*;
4 | 
5 | import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntity;
6 | import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntityNearMatch;
7 | import edu.stanford.nlp.semparse.open.dataset.library.JSONDataset.JSONDatasetDatum;
8 | 
9 | /**
10 | * Must match first, second, and last entities
11 | */
12 | public class CriteriaGeneralWeb implements Criteria {
13 | public final JSONDatasetDatum datum;
14 | public final TargetEntity first, second, last;
15 | 
16 | public CriteriaGeneralWeb(JSONDatasetDatum datum) {
17 | this.datum = datum;
18 | this.first = new TargetEntityNearMatch(datum.criteria.first);
19 | this.second = new TargetEntityNearMatch(datum.criteria.second);
20 | this.last = new TargetEntityNearMatch(datum.criteria.last);
21 | }
22 | 
23 | @Override
24 | public List<TargetEntity> getTargetEntities() {
25 | return Arrays.asList(first, second, last);
26 | }
27 | 
28 | @Override
29 | public int countMatchedCriteria(List<String> predictedEntities) {
30 | int n = predictedEntities.size(), answer = 0;
31 | String predictedFirst = n > 0 ? predictedEntities.get(0) : "",
32 | predictedSecond = n > 1 ? predictedEntities.get(1) : "",
33 | predictedLast = n > 0 ? predictedEntities.get(n - 1) : "";
34 | if (first.match(predictedFirst)) answer++;
35 | if (second.match(predictedSecond)) answer++;
36 | if (last.match(predictedLast)) answer++;
37 | return answer;
38 | }
39 | 
40 | @Override
41 | public int numCriteria() {
42 | return 3;
43 | }
44 | 
45 | @Override
46 | public IRScore getIRScore(List<String> predictedEntities) {
47 | return new IRScore(countMatchedCriteria(predictedEntities), numCriteria(), numCriteria());
48 | }
49 | 
50 | @Override
51 | public double getCorrectnessScore(List<String> predictedEntities) {
52 | // TODO Make this better
53 | if (countMatchedCriteria(predictedEntities) != numCriteria()) return 0;
54 | return predictedEntities.size(); // Prefer larger set of entities
55 | }
56 | } 
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/Dataset.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.dataset;
2 | 
3 | import java.util.*;
4 | 
5 | import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntity;
6 | import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntityPersonName;
7 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate;
8 | import fig.basic.LogInfo;
9 | 
10 | /**
11 | * A Dataset represents a data set, which has multiple Examples (data instances).
12 | * 
13 | * The examples are divided into training and test.
14 | */
15 | public class Dataset {
16 | public final List<Example> trainExamples = new ArrayList<>();
17 | public final List<Example> testExamples = new ArrayList<>();
18 | 
19 | public Dataset() {
20 | // Do nothing
21 | }
22 | 
23 | public Dataset(List<Example> train, List<Example> test) {
24 | trainExamples.addAll(train);
25 | testExamples.addAll(test);
26 | }
27 | 
28 | public Dataset addTrainExample(Example ex) {
29 | trainExamples.add(ex);
30 | return this;
31 | }
32 | 
33 | public Dataset addTestExample(Example ex) {
34 | testExamples.add(ex);
35 | return this;
36 | }
37 | 
38 | public Dataset addFromDataset(Dataset that) {
39 | this.trainExamples.addAll(that.trainExamples);
40 | this.testExamples.addAll(that.testExamples);
41 | return this;
42 | }
43 | 
44 | public Dataset addTrainFromDataset(Dataset that) {
45 | this.trainExamples.addAll(that.trainExamples);
46 | this.trainExamples.addAll(that.testExamples);
47 | return this;
48 | }
49 | 
50 | public Dataset addTestFromDataset(Dataset that) {
51 | this.testExamples.addAll(that.trainExamples);
52 | this.testExamples.addAll(that.testExamples);
53 | return this;
54 | }
55 | 
56 | /**
57 | * @return a new Dataset with the Examples shuffled up.
58 | * The train/test ratio remains the same.
59 | * The original Dataset is not modified.
60 | */
61 | public Dataset getNewShuffledDataset() {
62 | List<Example> allExamples = new ArrayList<>(trainExamples);
63 | allExamples.addAll(testExamples);
64 | Collections.shuffle(allExamples, new Random(42));
65 | List<Example> newTrain = allExamples.subList(0, trainExamples.size());
66 | List<Example> newTest = allExamples.subList(trainExamples.size(), allExamples.size());
67 | return new Dataset(newTrain, newTest);
68 | }
69 | 
70 | /**
71 | * @return a new Dataset with the specified train/test ratio.
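* For example (illustrative): trainRatio = 0.8 on 100 total examples yields an
* 80/20 train/test split, shuffled with the same fixed seed (42) as getNewShuffledDataset.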
72 | */ 73 | public Dataset getNewSplitDataset(double trainRatio) { 74 | List<Example> allExamples = new ArrayList<>(trainExamples); 75 | allExamples.addAll(testExamples); 76 | Collections.shuffle(allExamples, new Random(42)); 77 | int trainEndIndex = (int) (allExamples.size() * trainRatio); 78 | List<Example> newTrain = allExamples.subList(0, trainEndIndex); 79 | List<Example> newTest = allExamples.subList(trainEndIndex, allExamples.size()); 80 | return new Dataset(newTrain, newTest); 81 | } 82 | 83 | // ============================================================ 84 | // Caching rewards 85 | // ============================================================ 86 | 87 | public void cacheRewards() { 88 | List<Example> uncached = new ArrayList<>(); 89 | for (Example ex : trainExamples) 90 | if (!ex.expectedAnswer.frozenReward) uncached.add(ex); 91 | for (Example ex : testExamples) 92 | if (!ex.expectedAnswer.frozenReward) uncached.add(ex); 93 | if (uncached.isEmpty()) return; 94 | LogInfo.begin_track("Cache rewards ..."); 95 | for (Example ex : uncached) { 96 | LogInfo.begin_track("Computing rewards for example %s ...", ex); 97 | for (Candidate candidate : ex.candidates) { 98 | ex.expectedAnswer.reward(candidate); 99 | } 100 | ex.expectedAnswer.frozenReward = true; 101 | LogInfo.end_track(); 102 | } 103 | LogInfo.end_track(); 104 | } 105 | 106 | // ============================================================ 107 | // Shorthands for creating datasets. 108 | // ============================================================ 109 | 110 | public Example E(String phrase, ExpectedAnswer expectedAnswer) { 111 | return E(phrase, expectedAnswer, true); 112 | } 113 | 114 | public Example E(String phrase, ExpectedAnswer expectedAnswer, boolean isTrain) { 115 | Example ex = new Example(phrase, expectedAnswer); 116 | if (isTrain) 117 | addTrainExample(ex); 118 | else 119 | addTestExample(ex); 120 | return ex; 121 | } 122 | 123 | public ExpectedAnswer L(String... items) { 124 | return L(false, items); 125 | } 126 | 127 | public ExpectedAnswer L(boolean exact, String... items) { 128 | return new ExpectedAnswerInjectiveMatch(items); 129 | } 130 | 131 | public ExpectedAnswer LN(String... items) { 132 | return LN(false, items); 133 | } 134 | 135 | public ExpectedAnswer LN(boolean exact, String... items) { 136 | TargetEntity[] targetEntities = new TargetEntity[items.length]; 137 | for (int i = 0; i < items.length; i++) targetEntities[i] = N(items[i]); 138 | return new ExpectedAnswerInjectiveMatch(targetEntities); 139 | } 140 | 141 | public TargetEntityPersonName N(String full) { 142 | String[] parts = full.split(" "); 143 | if (parts.length == 2) 144 | return new TargetEntityPersonName(parts[0], parts[1]); 145 | else if (parts.length == 3) 146 | return new TargetEntityPersonName(parts[0], parts[1], parts[2]); 147 | throw new RuntimeException("N(...) 
requires two or three words."); 148 | } 149 | 150 | } 151 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/Example.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.dataset; 2 | 3 | import java.util.List; 4 | 5 | import edu.stanford.nlp.semparse.open.ling.AveragedWordVector; 6 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate; 7 | import edu.stanford.nlp.semparse.open.model.candidate.CandidateGroup; 8 | import edu.stanford.nlp.semparse.open.model.tree.KNode; 9 | 10 | /** 11 | * An Example consists of the following: 12 | * - A phrase (e.g., "us cities") 13 | * - expectedAnswer: the entities that we'd like to extract from the knowledge tree 14 | * - A knowledge tree (constructed based on the phrase) 15 | * - A list of candidate answers 16 | */ 17 | public class Example { 18 | public String displayId; // For debugging 19 | 20 | public final String phrase; 21 | public final ExpectedAnswer expectedAnswer; 22 | public KNode tree; // Deterministic function of the phrase 23 | public List candidateGroups; 24 | public List candidates; // Candidate predictions 25 | public AveragedWordVector averagedWordVector; 26 | 27 | public Example(String phrase) { 28 | this(phrase, null); 29 | } 30 | 31 | public Example(String phrase, ExpectedAnswer expectedAnswer) { 32 | this.phrase = phrase; 33 | this.expectedAnswer = expectedAnswer; 34 | } 35 | 36 | @Override public String toString() { 37 | return "[" + phrase + "]"; 38 | } 39 | 40 | public void initAveragedWordVector() { 41 | if (averagedWordVector == null) 42 | averagedWordVector = new AveragedWordVector(phrase); 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/ExampleCached.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.dataset; 2 | 3 | /** 4 | * A CachedExample is an Example that should build the knowledge tree from the 5 | * cached web page instead of from the Web. 6 | * 7 | * It is useful for testing datasets annotated on cached web pages. 
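 *
 * A hypothetical construction sketch (argument values are illustrative only;
 * the constructors appear below):
 * <pre>
 *   Example ex = new ExampleCached("us cities",
 *       "web.cache",        // cacheDirectory holding the frozen page
 *       "(SHA hashcode)",   // identifies the cached copy of the page
 *       "http://example.com/us-cities");
 * </pre>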
8 | */ 9 | public class ExampleCached extends Example { 10 | public final String hashcode, cacheDirectory, url; 11 | 12 | public ExampleCached(String phrase, String url) { 13 | this(phrase, null, null, url, null); 14 | } 15 | 16 | public ExampleCached(String phrase, String cacheDirectory, String hashcode, String url) { 17 | this(phrase, cacheDirectory, hashcode, url, null); 18 | } 19 | 20 | public ExampleCached(String phrase, String cacheDirectory, String hashcode, String url, ExpectedAnswer expectedAnswer) { 21 | super(phrase, expectedAnswer); 22 | this.url = url; 23 | this.hashcode = hashcode; 24 | this.cacheDirectory = cacheDirectory; 25 | } 26 | 27 | @Override public String toString() { 28 | StringBuilder sb = new StringBuilder("[").append(phrase).append("]") 29 | .append("[").append(cacheDirectory).append("/").append(hashcode).append("]"); 30 | if (url != null) 31 | sb.append("[").append(url).append("]"); 32 | return sb.toString(); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/ExpectedAnswerCriteriaMatch.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.dataset; 2 | 3 | import java.util.List; 4 | 5 | import edu.stanford.nlp.semparse.open.core.eval.CandidateStatistics; 6 | import fig.basic.Option; 7 | 8 | /** 9 | * Gives reward = 1 if the predicted entities match all criteria, and reward = 0 otherwise. 10 | */ 11 | public class ExpectedAnswerCriteriaMatch extends ExpectedAnswer { 12 | public static class Options { 13 | @Option(gloss = "Give partial reward for lists that don't exactly match the criteria") 14 | public boolean generous = false; 15 | } 16 | public static Options opts = new Options(); 17 | 18 | public final Criteria criteria; 19 | 20 | public ExpectedAnswerCriteriaMatch(Criteria criteria) { 21 | super(criteria.getTargetEntities()); 22 | this.criteria = criteria; 23 | } 24 | 25 | @Override 26 | public IRScore getIRScore(List predictedEntities) { 27 | return criteria.getIRScore(predictedEntities); 28 | } 29 | 30 | @Override 31 | public double reward(List predictedEntities) { 32 | if (!opts.generous) { 33 | return countCorrectEntities(predictedEntities) == criteria.numCriteria() ? 1 : 0; 34 | } else { 35 | // Generous reward 36 | double f1 = criteria.getIRScore(predictedEntities).f1; 37 | return f1 > ExpectedAnswerInjectiveMatch.opts.irThreshold ? 
f1 : 0; 38 | } 39 | } 40 | 41 | @Override 42 | public int computeCountCorrectEntities(List predictedEntities) { 43 | return criteria.countMatchedCriteria(predictedEntities); 44 | } 45 | 46 | @Override 47 | public boolean isLikelyCorrect(List predictedEntities) { 48 | return countCorrectEntities(predictedEntities) == criteria.numCriteria(); 49 | } 50 | 51 | @Override 52 | public CandidateStatistics findBestCandidate(List rankedCandidateStats) { 53 | double bestCorrectnessScore = 0; 54 | CandidateStatistics best = null; 55 | for (CandidateStatistics candidateStat : rankedCandidateStats) { 56 | double correctnessScore = criteria.getCorrectnessScore(candidateStat.candidate.predictedEntities); 57 | if (correctnessScore > bestCorrectnessScore) { 58 | best = candidateStat; 59 | bestCorrectnessScore = correctnessScore; 60 | } 61 | } 62 | return best; 63 | } 64 | } -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/ExpectedAnswerInjectiveMatch.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.dataset; 2 | 3 | import java.util.List; 4 | 5 | import edu.stanford.nlp.semparse.open.core.eval.CandidateStatistics; 6 | import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntity; 7 | import edu.stanford.nlp.semparse.open.util.BipartiteMatcher; 8 | import fig.basic.LogInfo; 9 | import fig.basic.Option; 10 | 11 | public class ExpectedAnswerInjectiveMatch extends ExpectedAnswer { 12 | public static class Options { 13 | @Option public double irThreshold = 0.8; 14 | @Option public String irCriterion = "recall"; 15 | } 16 | public static Options opts = new Options(); 17 | 18 | public ExpectedAnswerInjectiveMatch(TargetEntity... targetEntities) {super(targetEntities);} 19 | public ExpectedAnswerInjectiveMatch(List targetEntities) {super(targetEntities);} 20 | public ExpectedAnswerInjectiveMatch(String... targetStrings) {super(targetStrings);} 21 | 22 | @Override 23 | public IRScore getIRScore(List predictedEntities) { 24 | return new IRScore(countCorrectEntities(predictedEntities), predictedEntities.size(), targetEntities.size()); 25 | } 26 | 27 | @Override 28 | public double reward(List predictedEntities) { 29 | IRScore score = getIRScore(predictedEntities); 30 | double criterionScore = 0; 31 | switch (opts.irCriterion) { 32 | case "precision": case "p": 33 | criterionScore = score.precision; break; 34 | case "recall": case "r": 35 | criterionScore = score.recall; break; 36 | case "f1": 37 | criterionScore = score.f1; break; 38 | case "raw": 39 | return (score.numCorrect >= score.numGold - opts.irThreshold) ? 1 : 0; 40 | default: 41 | LogInfo.fails("IR Criterion %s not recognized", opts.irCriterion); 42 | } 43 | return (criterionScore < opts.irThreshold) ? 
0 : criterionScore; 44 | } 45 | 46 | @Override 47 | public int computeCountCorrectEntities(List predictedEntities) { 48 | return new BipartiteMatcher(targetEntities, predictedEntities).findMaximumMatch(); 49 | } 50 | 51 | @Override 52 | public boolean isLikelyCorrect(List predictedEntities) { 53 | return reward(predictedEntities) > 0; 54 | } 55 | 56 | @Override 57 | public CandidateStatistics findBestCandidate(List rankedCandidateStats) { 58 | double bestReward = 0; 59 | CandidateStatistics best = null; 60 | for (CandidateStatistics candidateStat : rankedCandidateStats) { 61 | double reward = reward(candidateStat.candidate); 62 | if (reward > bestReward) { 63 | best = candidateStat; 64 | bestReward = reward; 65 | } 66 | } 67 | return best; 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/IRScore.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.dataset; 2 | 3 | import java.util.List; 4 | 5 | import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntity; 6 | import edu.stanford.nlp.semparse.open.util.BipartiteMatcher; 7 | 8 | public class IRScore { 9 | public final int numCorrect, numPredicted, numGold; 10 | public final double precision, recall, f1; 11 | 12 | public IRScore(int numCorrect, int numPredicted, int numGold) { 13 | this.numCorrect = numCorrect; 14 | this.numPredicted = numPredicted; 15 | this.numGold = numGold; 16 | precision = (numPredicted == 0) ? 0 : numCorrect * 1.0 / numPredicted; 17 | recall = numCorrect * 1.0 / numGold; 18 | f1 = (numCorrect == 0) ? 0 : (2 * precision * recall) / (precision + recall); 19 | } 20 | 21 | public IRScore(List expected, List predicted) { 22 | this(new BipartiteMatcher(expected, predicted).findMaximumMatch(), 23 | predicted.size(), expected.size()); 24 | } 25 | 26 | @Override 27 | public String toString() { 28 | return String.format("[ Precision = %.2f (%d/%d) | Recall = %.2f (%d/%d) | F1 = %.2f ]", precision * 100, 29 | numCorrect, numPredicted, recall * 100, numCorrect, numGold, f1 * 100); 30 | } 31 | } -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/entity/TargetEntity.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.dataset.entity; 2 | 3 | import java.util.Collection; 4 | 5 | /** 6 | * A TargetEntity represents an answer key -- an entity that appears in the target web page 7 | * (according to a Turker). 8 | * 9 | * A TargetEntity provides matching methods, which may implement fancy matching schemes 10 | * such as partial matching or person name matching. 
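 *
 * Illustrative behavior of the implementations below (hypothetical calls):
 * <pre>
 *   new TargetEntityNearMatch("Mississippi").match("Missisippi")        // true: edit distance 1
 *   new TargetEntityPersonName("Percy", "Liang").match("Liang, Percy")  // true: "last, first" pattern
 *   new TargetEntityString("Paris").match("Paris, France")              // false: exact match only
 *   new TargetEntitySubstring("Paris").match("Paris, France")           // true: substring match
 * </pre>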
11 | */ 12 | public interface TargetEntity { 13 | public boolean match(String predictedEntity); 14 | public boolean matchAny(Collection predictedEntities); 15 | } 16 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/entity/TargetEntityNearMatch.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.dataset.entity; 2 | 3 | import java.util.Collection; 4 | 5 | import edu.stanford.nlp.semparse.open.ling.LingUtils; 6 | import edu.stanford.nlp.semparse.open.util.EditDistance; 7 | import fig.basic.Option; 8 | 9 | public class TargetEntityNearMatch implements TargetEntity { 10 | public static class Options { 11 | @Option public int nearMatchMaxEditDistance = 2; 12 | @Option(gloss = "level of target entity string normalization " 13 | + "(0 = none / 1 = whitespace / 2 = simple / 3 = aggressive)") 14 | public int targetNormalizeEntities = 2; 15 | } 16 | public static Options opts = new Options(); 17 | 18 | public final String expected, normalizedExpected; 19 | 20 | public TargetEntityNearMatch(String expected) { 21 | this.expected = expected; 22 | this.normalizedExpected = LingUtils.normalize(expected, opts.targetNormalizeEntities); 23 | } 24 | 25 | @Override public String toString() { 26 | StringBuilder sb = new StringBuilder(expected); 27 | if (!expected.equals(normalizedExpected)) 28 | sb.append(" || ").append(normalizedExpected); 29 | return sb.toString(); 30 | } 31 | 32 | @Override 33 | public boolean match(String predictedEntity) { 34 | // Easy cases 35 | if (expected.equals(predictedEntity)) 36 | return true; 37 | // Edit distance 38 | if (EditDistance.withinEditDistance(normalizedExpected, predictedEntity, opts.nearMatchMaxEditDistance)) 39 | return true; 40 | return false; 41 | } 42 | 43 | @Override 44 | public boolean matchAny(Collection predictedEntities) { 45 | for (String predictedEntity : predictedEntities) { 46 | if (match(predictedEntity)) return true; 47 | } 48 | return false; 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/entity/TargetEntityPersonName.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.dataset.entity; 2 | 3 | import java.util.*; 4 | 5 | public class TargetEntityPersonName implements TargetEntity { 6 | 7 | public final String first; 8 | public final String mid; 9 | public final String last; 10 | final List patterns = new ArrayList<>(); 11 | 12 | public TargetEntityPersonName(String first, String last) { 13 | this.first = first; 14 | this.mid = null; 15 | this.last = last; 16 | generatePatterns(); 17 | } 18 | 19 | public TargetEntityPersonName(String first, String mid, String last) { 20 | this.first = first; 21 | if (mid.length() == 2 && mid.charAt(1) == '.') 22 | this.mid = mid.substring(0, 1); 23 | else 24 | this.mid = mid; 25 | this.last = last; 26 | generatePatterns(); 27 | } 28 | 29 | private void generatePatterns() { 30 | patterns.add(first + " " + last); 31 | patterns.add(last + ", " + first); 32 | patterns.add(first.charAt(0) + ". " + last); 33 | patterns.add(last + ", " + first.charAt(0) + "."); 34 | if (mid != null) { 35 | if (mid.length() > 1) { 36 | patterns.add(first + " " + mid + " " + last); 37 | patterns.add(last + ", " + first + " " + mid); 38 | } 39 | patterns.add(first + " " + mid.charAt(0) + ". 
" + last); 40 | patterns.add(last + ", " + first + " " + mid.charAt(0) + "."); 41 | } 42 | } 43 | 44 | @Override 45 | public String toString() { 46 | if (mid != null) 47 | return first + " " + mid + " " + last; 48 | return first + " " + last; 49 | } 50 | 51 | @Override 52 | public boolean match(String predictedEntity) { 53 | for (String pattern : patterns) 54 | if (pattern.equals(predictedEntity)) return true; 55 | return false; 56 | } 57 | 58 | @Override 59 | public boolean matchAny(Collection predictedEntities) { 60 | for (String pattern : patterns) 61 | if (predictedEntities.contains(pattern)) return true; 62 | return false; 63 | } 64 | 65 | } 66 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/entity/TargetEntityString.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.dataset.entity; 2 | 3 | import java.util.Collection; 4 | 5 | public class TargetEntityString implements TargetEntity { 6 | 7 | public final String expected; 8 | 9 | public TargetEntityString(String expected) { 10 | this.expected = expected; 11 | } 12 | 13 | @Override 14 | public String toString() { 15 | return expected; 16 | } 17 | 18 | @Override 19 | public boolean match(String predicted) { 20 | return expected.equals(predicted); 21 | } 22 | 23 | @Override 24 | public boolean matchAny(Collection predictedEntities) { 25 | return predictedEntities.contains(expected); 26 | } 27 | 28 | } 29 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/entity/TargetEntitySubstring.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.dataset.entity; 2 | 3 | import java.util.Collection; 4 | 5 | public class TargetEntitySubstring implements TargetEntity { 6 | 7 | public final String expected; 8 | 9 | public TargetEntitySubstring(String expected) { 10 | this.expected = expected; 11 | } 12 | 13 | @Override public String toString() { 14 | return expected; 15 | } 16 | 17 | @Override 18 | public boolean match(String predictedEntity) { 19 | return predictedEntity.contains(expected); 20 | } 21 | 22 | @Override 23 | public boolean matchAny(Collection predictedEntities) { 24 | for (String predictedEntity : predictedEntities) { 25 | if (match(predictedEntity)) return true; 26 | } 27 | return false; 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/library/DatasetLibrary.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.dataset.library; 2 | 3 | import java.io.IOException; 4 | 5 | import edu.stanford.nlp.semparse.open.dataset.Dataset; 6 | import fig.basic.LogInfo; 7 | 8 | public class DatasetLibrary { 9 | 10 | public static Dataset getDataset(String fullname) { 11 | if (fullname == null) 12 | return null; 13 | 14 | String[] parts = fullname.split("\\."); 15 | if (parts.length != 2) 16 | LogInfo.fails("Expected dataset format = family.name; got " + fullname); 17 | 18 | String family = parts[0], name = parts[1]; 19 | if (family == null) 20 | LogInfo.fails("No dataset family specified."); 21 | 22 | // Special case : Unary family (the very old dataset) 23 | if ("unary".equals(family)) 24 | return new UnaryDatasets().getDataset(name); 25 | 26 | // Load from the `datasets` 
directory 27 | try { 28 | return new JSONDatasetReader(family, name).getDataset(); 29 | } catch (IOException e) { 30 | LogInfo.fail(e); 31 | } 32 | return null; 33 | } 34 | 35 | } 36 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/library/JSONDataset.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.dataset.library; 2 | 3 | import java.util.List; 4 | 5 | import com.fasterxml.jackson.annotation.JsonIgnoreProperties; 6 | 7 | /** 8 | * The format of the JSON file is 9 | * 10 | * <pre>
11 |  *   {
12 |  *     "options": {
13 |  *       "cacheDirectory": "(location of cache directory -- default = web.cache)",
14 |  *       "useHashcode": (true if the page should be loaded from the frozen cache by hashcode
15 |  *                       and not from the Internet -- default = false),
16 |  *       "detailed": (true if detailed data is available -- default = false)
17 |  *     },
18 |  *     "data": [ ... ]
19 |  *   }
20 |  * </pre>
21 | * 22 | * where each element in the data array is 23 | * 24 | * <pre>
25 |  *   {
26 |  *     "hashcode": "(OPTIONAL - hashcode for frozen cache)",
27 |  *     "query": "(MANDATORY - query string)",
28 |  *     "url": "(OPTIONAL - url)",
29 |  *     "entities": [ ...(MANDATORY - target entity strings)... ]
30 |  *     "criteria": { ...(OPTIONAL - mapping from "first", "second", and "last" to entity string)... }
31 |  *   }
32 |  * </pre>
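 *
 * A hypothetical complete dataset file following this format (all values are
 * illustrative only):
 * <pre>
 *   {
 *     "options": { "cacheDirectory": "web.cache", "useHashcode": true },
 *     "data": [
 *       {
 *         "hashcode": "(SHA of the frozen page)",
 *         "query": "us cities",
 *         "url": "http://example.com/us-cities",
 *         "entities": ["New York", "Los Angeles", "Chicago"],
 *         "criteria": { "first": "New York", "second": "Los Angeles", "last": "Chicago" }
 *       }
 *     ]
 *   }
 * </pre>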
33 | * 34 | */ 35 | public class JSONDataset { 36 | 37 | @JsonIgnoreProperties(ignoreUnknown=true) 38 | public static class JSONDatasetOption { 39 | public String cacheDirectory = null; 40 | public boolean useHashcode = false; 41 | public boolean detailed = false; 42 | 43 | @Override 44 | public String toString() { 45 | return new StringBuilder() 46 | .append("useHashcode: ").append(useHashcode).append("\n") 47 | .append("cacheDirectory: ").append(cacheDirectory).append("\n") 48 | .append("detailed: ").append(detailed).append("\n") 49 | .toString(); 50 | } 51 | } 52 | 53 | @JsonIgnoreProperties(ignoreUnknown=true) 54 | public static class JSONDatasetDatum { 55 | public String hashcode; 56 | public String query; 57 | public String url; 58 | public List entities; 59 | public List rawanswers; 60 | public JSONDatasetCriteria criteria; 61 | 62 | @Override 63 | public String toString() { 64 | return new StringBuilder() 65 | .append("[").append(query) 66 | .append(hashcode == null ? "" : " " + hashcode).append("]") 67 | .toString(); 68 | } 69 | } 70 | 71 | public enum JSONDatasetRawAnswerType { Z, L, H }; 72 | 73 | public static class JSONDatasetRawAnswers { 74 | public JSONDatasetRawAnswerType type; 75 | public List answers; 76 | } 77 | 78 | public static class JSONDatasetCriteria { 79 | public String first, second, last; 80 | } 81 | 82 | public JSONDatasetOption options; 83 | public List data; 84 | } 85 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/dataset/library/UnaryDatasets.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.dataset.library; 2 | 3 | import edu.stanford.nlp.semparse.open.dataset.Dataset; 4 | 5 | class UnaryDatasets { 6 | 7 | public Dataset getDataset(final String name) { 8 | if (name == null) 9 | throw new RuntimeException("No dataset specified."); 10 | 11 | if (name.equals("all")) { 12 | return new Dataset() 13 | .addTestFromDataset(getDataset("geo")) 14 | .addFromDataset(getDataset("academia")) 15 | .addFromDataset(getDataset("website")) 16 | .addFromDataset(getDataset("stanford")) 17 | .addFromDataset(getDataset("route")) 18 | .addFromDataset(getDataset("celeb")) 19 | .addFromDataset(getDataset("sport")) 20 | .addFromDataset(getDataset("leader")) 21 | .addFromDataset(getDataset("fiction")); 22 | } 23 | 24 | return new Dataset() { 25 | { 26 | switch (name) { 27 | case "one": 28 | E("European countries", L("Greece", "Germany", "Spain", "France", "Estonia", "Romania")); 29 | break; 30 | 31 | case "geo": 32 | // Easy examples: every page has roughly, should be easy to generalize 33 | E("European countries", L("Greece", "Germany", "Spain", "France", "Estonia", "Romania")); 34 | E("Asian countries", L("Japan", "China", "India", "Singapore", "Kyrgyzstan", "Iran")); 35 | E("Canada provinces", L("Quebec", "British Columbia", "Ontario", "Saskatchewan")); 36 | E("cities in California", L("Los Angeles", "San Jose", "Ontario", "Sacramento", "San Francisco")); 37 | E("Hawaii islands", L("Hawaii", "Maui", "Kauai", "Molokai", "Oahu", "Lanai", "Niihua")); 38 | E("states of the USA", L("California", "Ohio", "Alaska", "Michigan", "Kansas", "New Jersey", "Arizona")); 39 | break; 40 | 41 | case "academia": 42 | E("stanford cs faculty", LN("Percy Liang", "Andrew Ng", "Alex Aiken", "Don Knuth", "Chris Manning")); 43 | E("cmu cs faculty", LN("Avrim Blum", "Umut Acar", "Priya Narasimhan", "Mahadev Satyanarayanan")); 44 | E("Michael I Jordan 
students", LN("Percy Liang", "Tommi Jaakkola", "John Duchi")); 45 | E("Lillian Lee students", LN("Regina Barzilay", "Chenhao Tan", "Bo Pang", "Rie Johnson")); 46 | E("MIT CSAIL professors", LN("Daniel Jackson", "Eric Grimson", "Hal Abelson", "Shafi Goldwasser")); 47 | break; 48 | 49 | case "website": 50 | E("online social networks", L("Facebook", "Twitter", "Myspace", "Google+")); 51 | E("search engines", L("Google", "Yahoo", "Bing")); 52 | E("Chinese web portals", L("Baidu", "Sina", "Sohu")); 53 | E("social bookmarking sites", L("Reddit", "StumbleUpon", "Digg", "Delicious")); 54 | break; 55 | 56 | case "stanford": 57 | // Web pages are a little harder to parse, 58 | // but may be closer to what general people want to know 59 | E("Stanford undergraduate residence halls", LN("Branner Hall", "Lagunita Court", "Wilbur Hall")); 60 | E("Stanford departments", L("Anesthesia", "Dermatology", "Linguistics", "Geophysics")); 61 | E("stores in Stanford Shopping Center", L("Brookstone", "Gap", "Microsoft", "Urban Outfitters")); 62 | E("Stanford Marguerite lines", L("Line X", "Line O", "SLAC", "Shopping Express", "Bohannon")); 63 | E("dining halls in Stanford", L("Ricker", "Wilbur", "Branner", "Lakeside")); 64 | E("libraries in Stanford", L("Green", "Meyer", "Hoover", "East Asia", "Music")); 65 | break; 66 | 67 | case "route": 68 | // The gas station question is tricky: what should the answer format be? 69 | E("Caltrain stops", L("San Francisco", "Palo Alto", "Mountain View", "Santa Clara", "Millbrae")); 70 | E("Boston red line stations", L("Harvard Square", "Kendall", "Broadway", "South Station", "Braintree")); 71 | E("Tokyo Metro subway lines", L("Ginza", "Chiyoda", "Hibiya", "Namboku", "Fukutoshin", "Marunouchi")); 72 | break; 73 | 74 | case "celeb": 75 | E("Justin Bieber's albums", L("Believe", "Under the Mistletoe", "Never Say Never", "My World 2.0")); 76 | E("Shyamalan's movies", L("The Sixth Sense", "Unbrekable", "After Earth", "Signs")); 77 | E("Rebecca Black singles", L("Friday", "My Moment", "Person of Interest", "Sing It", "In Your Words")); 78 | E("members of The Beetles", LN("John Lennon", "Paul McCartney", "George Harrison", "Ringo Starr")); 79 | E("casts of The Room", LN("Tommy Wiseau", "Greg Sestero", "Juliette Danielle", "Philip Haldiman")); 80 | break; 81 | 82 | case "sport": 83 | E("world cup champions", L(true, "Brazil", "Spain", "Argentina", "Uruguay", "Italy", "France", "England", "Germany")); 84 | E("England football clubs", L("Manchester United", "Liverpool", "Chelsea", "Arsenal", "Manchester City")); 85 | E("football teams in California", L("Raiders", "Chargers", "49ers")); 86 | E("countries in olympics 2012", L("China", "United States", "Australia", "Azerbaijan", "North Korea")); 87 | E("Wimbledon winners in men single", LN("Andy Murray", "Roger Federer", "Rafael Nadal", "Lleyton Hewitt")); 88 | break; 89 | 90 | case "leader": 91 | E("world billionaires", LN("Bill Gates", "Warren Buffett", "Larry Page", "Larry Ellison", "Steve Ballmer")); 92 | E("united states presidents", LN("George Washington", "Thomas Jefferson", "Abraham Lincoln", 93 | "Richard Nixon", "Barack Obama", "Andrew Jackson", "Bill Clinton")); 94 | E("united states vice presidents", LN("Joe Biden", "Al Gore", "Nelson Rockefeller", "Dick Cheney", "Aaron Burr")); 95 | E("leaders of ussr", LN("Vladimir Lenin", "Joseph Stalin", "Nikita Khrushchev", "Leonid Brezhnev", "Mikhail Gorbachev")); 96 | E("provosts of Stanford University", LN("Douglas M. Whitaker", "Gerald J. 
Lieberman", 97 | "Donald Kennedy", "John Etchemendy", "Richard Wall Lyman", "Condoleezza Rice")); 98 | break; 99 | 100 | case "fiction": 101 | E("Hogwarts Houses", L(true, "Gryffindor", "Hufflepuff", "Ravenclaw", "Slytherin")); 102 | E("main characters of Friends", LN(true, "Rachel Green", "Monica Geller", "Phoebe Buffay", 103 | "Joey Tribbiani", "Chandler Bing", "Ross Geller")); 104 | E("Twilight Saga books", L(true, "Twilight", "New Moon", "Eclipse", "Breaking Dawn")); 105 | E("disney movies", L("Brave", "Alice In Wonderland", "Wall E", "The Jungle Book", "Pinocchio")); 106 | break; 107 | 108 | default: 109 | throw new RuntimeException("Unsupported dataset: " + name); 110 | } 111 | } 112 | }; 113 | } 114 | 115 | } 116 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/ling/AveragedWordVector.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.ling; 2 | 3 | import java.util.*; 4 | 5 | import edu.stanford.nlp.semparse.open.ling.LingData.POSType; 6 | import edu.stanford.nlp.semparse.open.model.FeatureVector; 7 | import edu.stanford.nlp.semparse.open.util.VectorAverager; 8 | 9 | /** 10 | * AveragedWordVector computes and stores averaged neural net word vectors. 11 | */ 12 | public class AveragedWordVector { 13 | 14 | // The average (mean) of the word vectors of all tokens 15 | public double[] averaged; 16 | // Divide each word vector by word frequency 17 | public double[] freqWeighted; 18 | // Use only open POS class words 19 | public double[] openPOSOnly; 20 | // Use only open POS class words and divide each word vector by word frequency 21 | public double[] freqWeightedOpenPOSOnly; 22 | // Term-wise minimum and maximum 23 | public double[] min, max, minmax; 24 | 25 | public AveragedWordVector(Collection phrases) { 26 | VectorAverager normalAverager = new VectorAverager(WordVectorTable.numDimensions), 27 | freqWeightedAverager = new VectorAverager(WordVectorTable.numDimensions), 28 | openPOSOnlyAverager = new VectorAverager(WordVectorTable.numDimensions), 29 | freqWeightedOpenPOSAverager = new VectorAverager(WordVectorTable.numDimensions); 30 | for (String phrase : phrases) { 31 | LingData lingData = LingData.get(phrase); 32 | for (int i = 0; i < lingData.length; i++) { 33 | String token = lingData.tokens.get(i); 34 | int freq = BrownClusterTable.getSmoothedFrequency(token); 35 | double[] vector = WordVectorTable.getVector(token); 36 | normalAverager.add(vector); 37 | freqWeightedAverager.add(vector, 1.0 / freq); 38 | if (lingData.posTypes.get(i) == POSType.OPEN) { 39 | openPOSOnlyAverager.add(vector); 40 | freqWeightedOpenPOSAverager.add(vector, 1.0 / freq); 41 | } 42 | } 43 | } 44 | averaged = normalAverager.getAverage(); 45 | freqWeighted = freqWeightedAverager.getAverage(); 46 | openPOSOnly = openPOSOnlyAverager.getAverage(); 47 | freqWeightedOpenPOSOnly = freqWeightedOpenPOSAverager.getAverage(); 48 | min = normalAverager.getMin(); 49 | max = normalAverager.getMax(); 50 | minmax = normalAverager.getMinmax(); 51 | } 52 | 53 | public AveragedWordVector(String phrase) { 54 | // Slightly inefficient, but will not be called often. 55 | this(Arrays.asList(phrase)); 56 | } 57 | 58 | public double[] get(boolean freqWeighted, boolean openPOSOnly) { 59 | if (freqWeighted) { 60 | return openPOSOnly ? this.freqWeightedOpenPOSOnly : this.freqWeighted; 61 | } else { 62 | return openPOSOnly ? 
this.openPOSOnly : this.averaged; 63 | } 64 | } 65 | 66 | /** 67 | * Add general features of the form name...[i] 68 | * with value = each element of the averaged word vector. 69 | */ 70 | @Deprecated 71 | public void addTermwiseFeatures(FeatureVector v, String domain, String name) { 72 | if (averaged != null) 73 | for (int i = 0; i < WordVectorTable.numDimensions; i++) 74 | v.add(domain, name + "[" + i + "]", averaged[i]); 75 | if (freqWeighted != null) 76 | for (int i = 0; i < WordVectorTable.numDimensions; i++) 77 | v.add(domain, name + "-freq-weighted[" + i + "]", freqWeighted[i]); 78 | if (openPOSOnly != null) 79 | for (int i = 0; i < WordVectorTable.numDimensions; i++) 80 | v.add(domain, name + "-open-pos[" + i + "]", openPOSOnly[i]); 81 | } 82 | 83 | /** 84 | * Add general features of the form name...[i] 85 | * with value = term-wise product between the averaged word vector and the given vector. 86 | */ 87 | @Deprecated 88 | public void addTermwiseFeatures(FeatureVector v, String domain, String name, double[] factor) { 89 | if (factor == null) return; 90 | if (averaged != null) 91 | for (int i = 0; i < WordVectorTable.numDimensions; i++) 92 | v.add(domain, name + "[" + i + "]", averaged[i] * factor[i]); 93 | if (freqWeighted != null) 94 | for (int i = 0; i < WordVectorTable.numDimensions; i++) 95 | v.add(domain, name + "-freq-weighted[" + i + "]", freqWeighted[i] * factor[i]); 96 | if (openPOSOnly != null) 97 | for (int i = 0; i < WordVectorTable.numDimensions; i++) 98 | v.add(domain, name + "-open-pos[" + i + "]", openPOSOnly[i] * factor[i]); 99 | } 100 | 101 | // Too slow and memory consuming 102 | @Deprecated 103 | public static void addCrossProductFeatures(FeatureVector v, String domain, String name1, String name2, 104 | double[] factor1, double[] factor2) { 105 | if (factor1 == null || factor2 == null) return; 106 | for (int i = 0; i < WordVectorTable.numDimensions; i++) { 107 | for (int j = 0; j < WordVectorTable.numDimensions; j++) { 108 | v.add(domain, name1 + "[" + i + "]*" + name2 + "[" + j + "]", factor1[i] * factor2[j]); 109 | } 110 | } 111 | } 112 | 113 | } 114 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/ling/BrownClusterTable.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.ling; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.IOException; 5 | import java.nio.charset.Charset; 6 | import java.nio.file.Files; 7 | import java.nio.file.Path; 8 | import java.nio.file.Paths; 9 | import java.util.*; 10 | 11 | 12 | import fig.basic.LogInfo; 13 | import fig.basic.Option; 14 | 15 | public class BrownClusterTable { 16 | public static class Options { 17 | @Option public String brownClusterFilename = null; 18 | } 19 | public static Options opts = new Options(); 20 | 21 | public static Map wordClusterMap; 22 | public static Map wordFrequencyMap; 23 | 24 | public static void initModels() { 25 | if (wordClusterMap != null || opts.brownClusterFilename == null || opts.brownClusterFilename.isEmpty()) return; 26 | Path dataPath = Paths.get(opts.brownClusterFilename); 27 | LogInfo.logs("Reading Brown clusters from %s", dataPath); 28 | try (BufferedReader in = Files.newBufferedReader(dataPath, Charset.forName("UTF-8"))) { 29 | wordClusterMap = new HashMap<>(); 30 | wordFrequencyMap = new HashMap<>(); 31 | String line = null; 32 | while ((line = in.readLine()) != null) { 33 | String[] tokens = line.split("\t"); 34 | 
wordClusterMap.put(tokens[1], tokens[0].intern()); 35 | wordFrequencyMap.put(tokens[1], Integer.parseInt(tokens[2])); 36 | } 37 | } catch (IOException e) { 38 | LogInfo.fails("Cannot load Brown cluster from %s", dataPath); 39 | } 40 | } 41 | 42 | public static String getCluster(String word) { 43 | initModels(); 44 | return wordClusterMap.get(word); 45 | } 46 | 47 | public static String getClusterPrefix(String word, int length) { 48 | initModels(); 49 | String answer = wordClusterMap.get(word); 50 | if (answer == null) return null; 51 | return answer.substring(0, Math.min(length, answer.length())); 52 | } 53 | 54 | public static final int[] DEFAULT_PREFIXES = {4, 6, 10, 20}; 55 | 56 | public static List getDefaultClusterPrefixes(String cluster) { 57 | List answer = new ArrayList<>(); 58 | if (cluster != null) 59 | for (int length : DEFAULT_PREFIXES) 60 | answer.add("[" + length + "]" + cluster.substring(0, Math.min(length, cluster.length()))); 61 | return answer; 62 | } 63 | 64 | public static List getDefaultClusterPrefixesFromWord(String word) { 65 | return getDefaultClusterPrefixes(getCluster(word)); 66 | } 67 | 68 | public static List getDefaultClusterPrefixes(String cluster1, String cluster2) { 69 | List answer = new ArrayList<>(); 70 | for (int length : DEFAULT_PREFIXES) { 71 | answer.add(cluster1.substring(0, Math.min(length, cluster1.length())) 72 | + "|" + cluster2.substring(0, Math.min(length, cluster2.length()))); 73 | } 74 | return answer; 75 | } 76 | 77 | public static int getSmoothedFrequency(String word) { 78 | initModels(); 79 | Integer frequency = wordFrequencyMap.get(word); 80 | if (frequency == null) return 1; 81 | return frequency + 1; 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/ling/ClusterRepnUtils.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.ling; 2 | 3 | import fig.basic.StrUtils; 4 | 5 | /* 6 | To run: 7 | java -Xmx30g edu.stanford.nlp.semparse.open.CreateTypeEntityFeatures 8 | */ 9 | public class ClusterRepnUtils { 10 | public static String getRepn(String s) { 11 | if (s.equals("")) return "EMPTY"; 12 | StringBuilder buf = new StringBuilder(); 13 | if (true) { 14 | s = s.replaceAll(",", " , "); 15 | s = s.replaceAll("!", " ! 
"); 16 | s = s.replaceAll(":", " : "); 17 | } 18 | for (String x : s.split(" ")) { 19 | if (x.equals("")) continue; 20 | 21 | if (buf.length() > 0) buf.append(' '); 22 | String c = BrownClusterTable.getCluster(x); 23 | if (c == null) c = LingUtils.computePhraseShape(x); // Unknown: replace with word form 24 | buf.append(c); 25 | } 26 | return buf.toString(); 27 | } 28 | 29 | public static void main(String[] args) { 30 | System.out.println(getRepn(StrUtils.join(args))); 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/ling/CreateTypeEntityFeatures.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.ling; 2 | 3 | import static fig.basic.LogInfo.begin_track; 4 | import static fig.basic.LogInfo.end_track; 5 | import static fig.basic.LogInfo.logs; 6 | 7 | import java.io.BufferedReader; 8 | import java.io.IOException; 9 | import java.io.PrintWriter; 10 | import java.util.HashMap; 11 | import java.util.HashSet; 12 | import java.util.Map; 13 | import java.util.Set; 14 | 15 | import fig.basic.IOUtils; 16 | import fig.basic.MapUtils; 17 | import fig.basic.Option; 18 | import fig.basic.TVMap; 19 | import fig.exec.Execution; 20 | 21 | /* 22 | To run: 23 | java -Xmx30g edu.stanford.nlp.semparse.open.CreateTypeEntityFeatures 24 | */ 25 | public class CreateTypeEntityFeatures implements Runnable { 26 | Map clusterMap = new HashMap(); 27 | Map typeEntityCluster1Counts = new HashMap(); 28 | Map typeEntityCluster2Counts = new HashMap(); 29 | 30 | // Hack 31 | String pluralize(String s) { 32 | if (s.endsWith("y")) return s.substring(0, s.length()-1) + "ies"; 33 | if (s.endsWith("s")) return s + "es"; 34 | return s + "s"; 35 | } 36 | 37 | TVMap nameMap = new TVMap(); 38 | 39 | @Option public String namesPath = "/u/nlp/data/semparse/scr/freebase/names.ttl"; 40 | @Option public String typesPath = "/u/nlp/data/semparse/scr/freebase/types.ttl"; 41 | @Option public int maxLines = Integer.MAX_VALUE; 42 | 43 | void readNames() { 44 | begin_track("Reading names"); 45 | try { 46 | String line; 47 | BufferedReader in = IOUtils.openIn(namesPath); 48 | int numLines = 0; 49 | while ((line = in.readLine()) != null && numLines++ < maxLines) { 50 | String[] tokens = line.split("\t"); 51 | 52 | String id = tokens[0]; 53 | if (id.startsWith("fb:user")) continue; 54 | if (id.startsWith("fb:base")) continue; 55 | if (id.startsWith("fb:freebase")) continue; 56 | if (id.startsWith("fb:common")) continue; 57 | if (id.startsWith("fb:type")) continue; 58 | if (id.startsWith("fb:measurement_unit")) continue; 59 | 60 | // Remove "@en. 
61 | String name = tokens[2].substring(1, tokens[2].length() - 5); 62 | nameMap.put(id, name); 63 | 64 | if (numLines % 1000000 == 0) logs("%d lines", numLines); 65 | } 66 | in.close(); 67 | } catch (IOException e) { 68 | throw new RuntimeException(e); 69 | } 70 | end_track(); 71 | } 72 | 73 | public void run() { 74 | readNames(); 75 | readTypes(); 76 | outputCounts(); 77 | } 78 | 79 | void readTypes() { 80 | begin_track("Reading entities"); 81 | try { 82 | String line; 83 | BufferedReader in = IOUtils.openIn(typesPath); 84 | Set hit = new HashSet(); 85 | int numLines = 0; 86 | // Write raw strings 87 | PrintWriter out = IOUtils.openOutHard("/u/nlp/data/open-semparse/scr/freebase/types-entities.tsv"); 88 | while ((line = in.readLine()) != null && numLines++ < maxLines) { 89 | String[] tokens = line.split("\t"); 90 | String entityId = tokens[0]; 91 | String origTypeId = tokens[2].substring(0, tokens[2].length()-1); 92 | 93 | String entity = nameMap.get(entityId, null); 94 | if (entity == null) continue; 95 | String entityCluster = ClusterRepnUtils.getRepn(entity); 96 | 97 | String origType = nameMap.get(origTypeId, null); 98 | if (origType == null) continue; 99 | 100 | out.println(origTypeId + "\t" + origType + "\t" + entity); 101 | 102 | String[] typeTokens = origType.split(" "); 103 | if (typeTokens.length > 3) continue; // Only keep short phrases 104 | 105 | MapUtils.incr(typeEntityCluster1Counts, origType + "\t" + entityCluster, 1); 106 | 107 | String headType = typeTokens[typeTokens.length-1]; // Just take head 108 | for (String type : headType.split("/")) { 109 | type = pluralize(type.toLowerCase()); // Process the type 110 | if (!hit.contains(type)) { 111 | logs("new type: %s", type); 112 | hit.add(type); 113 | } 114 | String typeCluster = ClusterRepnUtils.getRepn(type); 115 | MapUtils.incr(typeEntityCluster2Counts, typeCluster + "\t" + entityCluster, 1); 116 | } 117 | 118 | if (numLines % 1000000 == 0) logs("%d lines", numLines); 119 | } 120 | in.close(); 121 | out.close(); 122 | } catch (IOException e) { 123 | throw new RuntimeException(e); 124 | } 125 | end_track(); 126 | } 127 | 128 | void outputCounts() { 129 | // Write out counts 130 | begin_track("Writing out %d counts", typeEntityCluster1Counts.size()); 131 | PrintWriter out = IOUtils.openOutHard("/u/nlp/data/open-semparse/scr/freebase/types-entities-cluster1-counts.tsv"); 132 | for (Map.Entry e : typeEntityCluster1Counts.entrySet()) 133 | out.println(e.getKey() + "\t" + e.getValue()); 134 | out.close(); 135 | end_track(); 136 | 137 | begin_track("Writing out %d counts", typeEntityCluster2Counts.size()); 138 | out = IOUtils.openOutHard("/u/nlp/data/open-semparse/scr/freebase/types-entities-cluster2-counts.tsv"); 139 | for (Map.Entry e : typeEntityCluster2Counts.entrySet()) 140 | out.println(e.getKey() + "\t" + e.getValue()); 141 | out.close(); 142 | end_track(); 143 | } 144 | 145 | public static void main(String[] args) { 146 | Execution.run(args, new CreateTypeEntityFeatures()); 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/ling/FrequencyTable.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.ling; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.IOException; 5 | import java.nio.charset.Charset; 6 | import java.nio.file.Files; 7 | import java.nio.file.Path; 8 | import java.nio.file.Paths; 9 | import java.util.*; 10 | 11 | import 
fig.basic.LogInfo; 12 | import fig.basic.Option; 13 | 14 | public class FrequencyTable { 15 | public static class Options { 16 | @Option public String frequencyFilename = null; 17 | @Option public List frequencyAmounts = Arrays.asList(30, 300, 3000); 18 | } 19 | public static Options opts = new Options(); 20 | 21 | public static Map> topWordsLists; 22 | 23 | public static void initModels() { 24 | if (topWordsLists != null || opts.frequencyFilename == null || opts.frequencyFilename.isEmpty()) return; 25 | Path dataPath = Paths.get(opts.frequencyFilename); 26 | LogInfo.logs("Reading word frequency from %s", dataPath); 27 | List words = new ArrayList<>(); 28 | try (BufferedReader in = Files.newBufferedReader(dataPath, Charset.forName("UTF-8"))) { 29 | String line = null; 30 | while ((line = in.readLine()) != null) { 31 | String[] tokens = line.split("\t"); 32 | words.add(tokens[0]); 33 | } 34 | } catch (IOException e) { 35 | LogInfo.fails("Cannot load word frequency from %s", dataPath); 36 | } 37 | topWordsLists = new HashMap<>(); 38 | for (int amount : opts.frequencyAmounts) { 39 | topWordsLists.put(amount, new HashSet<>(words.subList(0, amount))); 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/ling/LingTester.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.ling; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.IOException; 5 | import java.io.InputStreamReader; 6 | 7 | public class LingTester { 8 | 9 | public static void main(String args[]) throws IOException { 10 | try (BufferedReader in = new BufferedReader(new InputStreamReader(System.in))) { 11 | while (true) { 12 | System.out.println("Please enter a sentence:"); 13 | String input = in.readLine(); 14 | LingData data = LingData.get(input); 15 | System.out.println(data.tokens); 16 | System.out.println(data.posTags); 17 | System.out.println(data.nerTags); 18 | System.out.println(data.posTypes); 19 | System.out.println(data.nerValues); 20 | System.out.println(data.lemmaTokens); 21 | } 22 | } 23 | } 24 | 25 | } 26 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/ling/QueryTypeTable.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.ling; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.IOException; 5 | import java.nio.charset.Charset; 6 | import java.nio.file.Files; 7 | import java.nio.file.Path; 8 | import java.nio.file.Paths; 9 | import java.util.*; 10 | 11 | import fig.basic.LogInfo; 12 | import fig.basic.Option; 13 | 14 | public class QueryTypeTable { 15 | public static class Options { 16 | @Option public String queryTypeFilename = null; 17 | } 18 | public static Options opts = new Options(); 19 | 20 | public static Map queryHeadwordMap, queryTypeMap; 21 | 22 | public static void initModels() { 23 | if (queryTypeMap != null || opts.queryTypeFilename == null || opts.queryTypeFilename.isEmpty()) return; 24 | Path dataPath = Paths.get(opts.queryTypeFilename); 25 | LogInfo.logs("Reading query types from %s", dataPath); 26 | try (BufferedReader in = Files.newBufferedReader(dataPath, Charset.forName("UTF-8"))) { 27 | queryHeadwordMap = new HashMap<>(); 28 | queryTypeMap = new HashMap<>(); 29 | String line = null; 30 | while ((line = in.readLine()) != null) { 31 | String[] tokens = line.split("\t"); 32 | 
queryHeadwordMap.put(tokens[0], tokens[1]); 33 | queryTypeMap.put(tokens[0], tokens[2]); 34 | } 35 | } catch (IOException e) { 36 | LogInfo.fails("Cannot load query types from %s", dataPath); 37 | } 38 | } 39 | 40 | public static String getQueryHeadword(String query) { 41 | initModels(); 42 | return queryHeadwordMap.get(query); 43 | } 44 | 45 | public static String getQueryType(String query) { 46 | initModels(); 47 | return queryTypeMap.get(query); 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/ling/WordNetClusterTable.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.ling; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.IOException; 5 | import java.nio.charset.Charset; 6 | import java.nio.file.Files; 7 | import java.nio.file.Path; 8 | import java.nio.file.Paths; 9 | import java.util.*; 10 | 11 | import fig.basic.LogInfo; 12 | import fig.basic.Option; 13 | 14 | public class WordNetClusterTable { 15 | public static class Options { 16 | @Option public String wordnetClusterFilename = null; 17 | } 18 | public static Options opts = new Options(); 19 | 20 | public static Map wordClusterMap; 21 | 22 | public static void initModels() { 23 | if (wordClusterMap != null || opts.wordnetClusterFilename == null || opts.wordnetClusterFilename.isEmpty()) return; 24 | Path dataPath = Paths.get(opts.wordnetClusterFilename); 25 | LogInfo.logs("Reading WordNet clusters from %s", dataPath); 26 | try (BufferedReader in = Files.newBufferedReader(dataPath, Charset.forName("UTF-8"))) { 27 | wordClusterMap = new HashMap<>(); 28 | String line = null; 29 | while ((line = in.readLine()) != null) { 30 | String[] tokens = line.split("\t"); 31 | wordClusterMap.put(tokens[0], tokens[1]); 32 | } 33 | } catch (IOException e) { 34 | LogInfo.fails("Cannot load WordNet clusters from %s", dataPath); 35 | } 36 | } 37 | 38 | public static String getCluster(String word) { 39 | initModels(); 40 | return wordClusterMap.get(word); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/ling/WordVectorTable.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.ling; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.IOException; 5 | import java.nio.charset.Charset; 6 | import java.nio.file.Files; 7 | import java.nio.file.Path; 8 | import java.nio.file.Paths; 9 | import java.util.*; 10 | 11 | import fig.basic.LogInfo; 12 | import fig.basic.Option; 13 | 14 | public class WordVectorTable { 15 | public static class Options { 16 | @Option public String wordVectorFilename = null; 17 | 18 | @Option(gloss = "vector to use for UNKNOWN words (-1 = don't use any vector)") 19 | public int wordVectorUNKindex = 0; 20 | } 21 | public static Options opts = new Options(); 22 | 23 | public static Map wordToIndex; 24 | public static double[][] wordVectors; 25 | public static int numWords, numDimensions; 26 | 27 | public static void initModels() { 28 | if (wordVectors != null || opts.wordVectorFilename == null || opts.wordVectorFilename.isEmpty()) return; 29 | Path dataPath = Paths.get(opts.wordVectorFilename); 30 | LogInfo.logs("Reading word vectors from %s", dataPath); 31 | try (BufferedReader in = Files.newBufferedReader(dataPath, Charset.forName("UTF-8"))) { 32 | String[] headerTokens = in.readLine().split(" "); 33 
| numWords = Integer.parseInt(headerTokens[0]); 34 | numDimensions = Integer.parseInt(headerTokens[1]); 35 | wordToIndex = new HashMap<>(); 36 | wordVectors = new double[numWords][numDimensions]; 37 | for (int i = 0; i < numWords; i++) { 38 | String[] tokens = in.readLine().split(" "); 39 | wordToIndex.put(tokens[0], i); 40 | for (int j = 0; j < numDimensions; j++) { 41 | wordVectors[i][j] = Double.parseDouble(tokens[j+1]); 42 | } 43 | } 44 | LogInfo.logs("Neural network vectors: %s words; %s dimensions per word", numWords, numDimensions); 45 | } catch (IOException e) { 46 | LogInfo.fails("Cannot load neural network vectors from %s", dataPath); 47 | } 48 | } 49 | 50 | public static double[] getVector(String word) { 51 | initModels(); 52 | Integer index = wordToIndex.get(word); 53 | if (index == null) { 54 | index = opts.wordVectorUNKindex; 55 | if (index < 0) return null; 56 | } 57 | return wordVectors[index]; 58 | } 59 | 60 | } 61 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/model/AdvancedWordVectorGradient.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.model; 2 | 3 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate; 4 | 5 | public interface AdvancedWordVectorGradient { 6 | 7 | public void addToGradient(Candidate candidate, double factor); 8 | public void addL2Regularization(double beta); 9 | 10 | } 11 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/model/AdvancedWordVectorParams.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.model; 2 | 3 | import edu.stanford.nlp.semparse.open.ling.WordVectorTable; 4 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate; 5 | import fig.basic.Option; 6 | 7 | /** 8 | * Parameters for advanced word vector features. 
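 *
 * The score is a bilinear form between the query's averaged word vector x and
 * the candidate entities' averaged word vector y (a sketch of the full-rank
 * variant implemented below):
 * <pre>
 *   score(x, y) = sum over i, j of weights[i][j] * x[i] * y[j]   // i.e., x' W y
 * </pre>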
9 | */ 10 | public abstract class AdvancedWordVectorParams { 11 | public static class Options { 12 | @Option(gloss = "Whether to use full rank") 13 | public boolean vecFullRank = true; 14 | 15 | @Option(gloss = "Use pooling (vecOpenPOSOnly and vecFreqWeighted will be ignored)") 16 | public boolean vecPooling = false; 17 | 18 | @Option(gloss = "Only use Open POS words") 19 | public boolean vecOpenPOSOnly = false; 20 | 21 | @Option(gloss = "Use frequency-weighted vectors") 22 | public boolean vecFreqWeighted = false; 23 | } 24 | public static Options opts = new Options(); 25 | 26 | public static AdvancedWordVectorParams create() { 27 | if (opts.vecFullRank) 28 | return new AdvancedWordVectorParamsFullRank(); 29 | else 30 | return new AdvancedWordVectorParamsLowRank(); 31 | } 32 | 33 | protected static int getDim() { 34 | if (opts.vecPooling) { 35 | return 2 * WordVectorTable.numDimensions; 36 | } 37 | return WordVectorTable.numDimensions; 38 | } 39 | 40 | protected static double[] getX(Candidate candidate) { 41 | candidate.ex.initAveragedWordVector(); 42 | if (opts.vecPooling) { 43 | return candidate.ex.averagedWordVector.minmax; 44 | } 45 | return candidate.ex.averagedWordVector.get(opts.vecFreqWeighted, opts.vecOpenPOSOnly); 46 | } 47 | 48 | protected static double[] getY(Candidate candidate) { 49 | candidate.group.initAveragedWordVector(); 50 | if (opts.vecPooling) { 51 | return candidate.group.averagedWordVector.minmax; 52 | } 53 | return candidate.group.averagedWordVector.get(opts.vecFreqWeighted, opts.vecOpenPOSOnly); 54 | } 55 | 56 | public abstract double getScore(Candidate candidate); 57 | 58 | public abstract AdvancedWordVectorGradient createGradient(); 59 | public abstract void update(AdvancedWordVectorGradient gradient); 60 | public abstract void applyL1Regularization(double cutoff); 61 | 62 | public abstract void log(); 63 | public abstract void logFeatureWeights(Candidate candidate); 64 | public abstract void logFeatureDiff(Candidate trueCandidate, Candidate predCandidate); 65 | } 66 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/model/AdvancedWordVectorParamsFullRank.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.model; 2 | 3 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate; 4 | import fig.basic.Fmt; 5 | import fig.basic.LogInfo; 6 | 7 | public class AdvancedWordVectorParamsFullRank extends AdvancedWordVectorParams { 8 | 9 | double[][] weights; 10 | 11 | // Shorthands for word vector dimension 12 | protected final int dim; 13 | 14 | public AdvancedWordVectorParamsFullRank() { 15 | dim = getDim(); 16 | weights = new double[dim][dim]; 17 | if (Params.opts.initWeightsRandomly) { 18 | for (int i = 0; i < dim; i++) { 19 | for (int j = 0; j < dim; j++) { 20 | weights[i][j] = 2 * Params.opts.initRandom.nextDouble() - 1; 21 | } 22 | } 23 | } 24 | initGradientStats(); 25 | } 26 | 27 | // ============================================================ 28 | // Get score 29 | // ============================================================ 30 | 31 | @Override 32 | public double getScore(Candidate candidate) { 33 | return getScore(getX(candidate), getY(candidate)); 34 | } 35 | 36 | public double getScore(double[] x, double[] y) { 37 | if (x == null || y == null) 38 | return 0; 39 | double answer = 0; 40 | for (int i = 0; i < dim; i++) { 41 | for (int j = 0; j < dim; j++) { 42 | answer += weights[i][j] * x[i] * y[j]; 43 | } 
44 | } 45 | return answer; 46 | } 47 | 48 | // ============================================================ 49 | // Compute gradient 50 | // ============================================================ 51 | 52 | @Override 53 | public AdvancedWordVectorGradient createGradient() { 54 | return new AdvancedWordVectorGradientFullRank(); 55 | } 56 | 57 | class AdvancedWordVectorGradientFullRank implements AdvancedWordVectorGradient { 58 | protected final double grad[][]; 59 | 60 | public AdvancedWordVectorGradientFullRank() { 61 | grad = new double[dim][dim]; 62 | } 63 | 64 | @Override 65 | public void addToGradient(Candidate candidate, double factor) { 66 | addToGradient(getX(candidate), getY(candidate), factor); 67 | } 68 | 69 | /** 70 | * Compute the gradient for the word vector pair (x,y) and add it to the 71 | * accumulative gradient. 72 | */ 73 | private void addToGradient(double[] x, double[] y, double factor) { 74 | if (x == null || y == null) 75 | return; 76 | for (int i = 0; i < dim; i++) { 77 | for (int j = 0; j < dim; j++) { 78 | grad[i][j] += x[i] * y[j] * factor; 79 | } 80 | } 81 | } 82 | 83 | @Override 84 | public void addL2Regularization(double beta) { 85 | for (int i = 0; i < dim; i++) { 86 | for (int j = 0; j < dim; j++) { 87 | grad[i][j] -= beta * weights[i][j]; 88 | } 89 | } 90 | } 91 | } 92 | 93 | // ============================================================ 94 | // Weight update 95 | // ============================================================ 96 | 97 | // For AdaGrad 98 | double[][] sumSquaredGradients; 99 | 100 | // For dual averaging 101 | double[][] sumGradients; 102 | 103 | protected void initGradientStats() { 104 | if (Params.opts.adaptiveStepSize) 105 | sumSquaredGradients = new double[dim][dim]; 106 | if (Params.opts.dualAveraging) 107 | sumGradients = new double[dim][dim]; 108 | } 109 | 110 | // Number of stochastic updates we've made so far (for determining step size). 
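   | // With adaptiveStepSize (AdaGrad), coordinate (i,j) gets the step size
   | //   initStepSize / sqrt(sum of squared gradients seen at (i,j) so far);
   | // otherwise a global step size initStepSize / numUpdates^stepSizeReduction
   | // is used. With dualAveraging, the weight is recomputed from the running
   | // gradient sum instead of being incremented.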
111 | int numUpdates; 112 | 113 | @Override 114 | public void update(AdvancedWordVectorGradient gradient) { 115 | AdvancedWordVectorGradientFullRank grad = (AdvancedWordVectorGradientFullRank) gradient; 116 | numUpdates++; 117 | for (int i = 0; i < dim; i++) { 118 | for (int j = 0; j < dim; j++) { 119 | double g = grad.grad[i][j]; 120 | if (Math.abs(g) < 1e-6) continue; 121 | double stepSize; 122 | if (Params.opts.adaptiveStepSize) { 123 | sumSquaredGradients[i][j] += g * g; 124 | stepSize = Params.opts.initStepSize / Math.sqrt(sumSquaredGradients[i][j]); 125 | } else { 126 | stepSize = Params.opts.initStepSize / Math.pow(numUpdates, Params.opts.stepSizeReduction); 127 | } 128 | if (Params.opts.dualAveraging) { 129 | sumGradients[i][j] += g; 130 | weights[i][j] = stepSize * sumGradients[i][j]; 131 | } else { 132 | weights[i][j] += stepSize * g; 133 | } 134 | } 135 | } 136 | } 137 | 138 | @Override 139 | public void applyL1Regularization(double cutoff) { 140 | if (cutoff <= 0) 141 | return; 142 | for (int i = 0; i < dim; i++) { 143 | for (int j = 0; j < dim; j++) { 144 | weights[i][j] = Params.L1Cut(weights[i][j], cutoff); 145 | } 146 | } 147 | } 148 | 149 | // ============================================================ 150 | // Logging 151 | // ============================================================ 152 | 153 | @Override 154 | public void log() { 155 | LogInfo.begin_track("Advanced Word Vector Params"); 156 | for (int i = 0; i < dim; i++) { 157 | StringBuilder sb = new StringBuilder(); 158 | for (int j = 0; j < dim; j++) { 159 | sb.append(String.format("%10s ", Fmt.D(weights[i][j]))); 160 | } 161 | LogInfo.log(sb.toString()); 162 | } 163 | LogInfo.end_track(); 164 | } 165 | 166 | @Override 167 | public void logFeatureWeights(Candidate candidate) { 168 | LogInfo.begin_track("Advanced Word Vector feature weights"); 169 | double[] x = getX(candidate), y = getY(candidate); 170 | if (x == null) { 171 | LogInfo.log("NONE: x (query word vector) is null"); 172 | } else if (y == null) { 173 | LogInfo.log("NONE: y (entities word vector) is null"); 174 | } else { 175 | LogInfo.logs("Advanced Word Vector: %s", Fmt.D(getScore(x, y))); 176 | } 177 | LogInfo.end_track(); 178 | } 179 | 180 | @Override 181 | public void logFeatureDiff(Candidate trueCandidate, Candidate predCandidate) { 182 | LogInfo.begin_track("Advanced Word Vector feature weights"); 183 | // The candidates should be from the same example --> assume x are the same 184 | double[] x = getX(trueCandidate), yTrue = getY(trueCandidate), yPred = getY(predCandidate); 185 | if (x == null) { 186 | LogInfo.log("NONE: x (query word vector) is null"); 187 | } else if (yTrue == null) { 188 | LogInfo.log("NONE: y (entities word vector) is null for trueCandidate"); 189 | } else if (yPred == null) { 190 | LogInfo.log("NONE: y (entities word vector) is null for predCandidate"); 191 | } else { 192 | double trueScore = getScore(x, yTrue), predScore = getScore(x, yPred); 193 | LogInfo.logs("Advanced Word Vector: %s [ %s - %s ]", Fmt.D(trueScore - predScore), 194 | Fmt.D(trueScore), Fmt.D(predScore)); 195 | } 196 | LogInfo.end_track(); 197 | } 198 | 199 | } 200 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/model/FeatureCountPruner.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.model; 2 | 3 | import java.util.*; 4 | 5 | import edu.stanford.nlp.semparse.open.dataset.Example; 6 | import 
edu.stanford.nlp.semparse.open.model.candidate.Candidate;
7 | import edu.stanford.nlp.semparse.open.util.Multiset;
8 | import fig.basic.LogInfo;
9 |
10 | public class FeatureCountPruner implements FeatureMatcher {
11 |
12 |   public Multiset<String> counts = new Multiset<>();
13 |   public boolean beVeryQuiet;
14 |
15 |   public FeatureCountPruner(boolean beVeryQuiet) {
16 |     this.beVeryQuiet = beVeryQuiet;
17 |   }
18 |
19 |   /**
20 |    * Add features from the example to the count.
21 |    *
22 |    * The same feature within the same example counts as 1 feature.
23 |    */
24 |   public void add(Example example) {
25 |     if (!beVeryQuiet) LogInfo.begin_track("Collecting features from %s ...", example);
26 |     Set<String> uniqued = new HashSet<>();
27 |     for (Candidate candidate : example.candidates) {
28 |       for (String name : candidate.getCombinedFeatures().keySet()) {
29 |         uniqued.add(name);
30 |       }
31 |     }
32 |     for (String name : uniqued) counts.add(name);
33 |     if (!beVeryQuiet) LogInfo.end_track();
34 |   }
35 |
36 |   /**
37 |    * Prune the features with count < minimumCount
38 |    */
39 |   public void applyThreshold(int minimumCount) {
40 |     if (!beVeryQuiet) LogInfo.begin_track("Pruning features with count < %d ...", minimumCount);
41 |     if (!beVeryQuiet) LogInfo.logs("Original #Features: %d", counts.elementSet().size());
42 |     counts = counts.getPrunedByCount(minimumCount);
43 |     if (!beVeryQuiet) LogInfo.logs("Pruned #Features: %d", counts.elementSet().size());
44 |     if (!beVeryQuiet) LogInfo.end_track();
45 |   }
46 |
47 |   @Override
48 |   public boolean matches(String feature) {
49 |     return counts.contains(feature);
50 |   }
51 |
52 | }
53 |
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/FeatureDomainPruner.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model;
2 |
3 | public class FeatureDomainPruner implements FeatureMatcher {
4 |
5 |   public enum FeatureDomainPrunerType { ONLY_ALLOW_DOMAIN, ONLY_DISALLOW_DOMAIN };
6 |
7 |   public final String domainPrefix;
8 |   public final FeatureDomainPrunerType type;
9 |
10 |   public FeatureDomainPruner(String domain, FeatureDomainPrunerType type) {
11 |     this.domainPrefix = domain + " :: ";
12 |     this.type = type;
13 |   }
14 |
15 |   @Override
16 |   public boolean matches(String feature) {
17 |     // Always allow "basic"
18 |     if (feature.startsWith("basic :: ")) return true;
19 |     boolean matched = feature.startsWith(domainPrefix);
20 |     if (type == FeatureDomainPrunerType.ONLY_ALLOW_DOMAIN) {
21 |       return matched;
22 |     } else {
23 |       return !matched;
24 |     }
25 |   }
26 | }
27 |
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/FeatureMatcher.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model;
2 |
3 | /**
4 |  * Used to select a subset of features (to update).
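   |  *
   |  * <p>For example, to keep only features from the "ling" domain (a usage
   |  * sketch; full feature names have the form "domain :: name"):
   |  *
   |  * <pre>
   |  *   FeatureMatcher lingOnly = new FeatureDomainPruner("ling",
   |  *       FeatureDomainPruner.FeatureDomainPrunerType.ONLY_ALLOW_DOMAIN);
   |  *   lingOnly.matches("ling :: word-pos = NN");   // true
   |  *   lingOnly.matches("entity :: num-entities");  // false ("basic :: ..." always passes)
   |  * </pre>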
5 |  */
6 | public interface FeatureMatcher {
7 |   public boolean matches(String feature);
8 | }
9 |
10 | /** Matches all features **/
11 | class AllFeatureMatcher implements FeatureMatcher {
12 |   private AllFeatureMatcher() { }
13 |
14 |   @Override
15 |   public boolean matches(String feature) { return true; }
16 |
17 |   public static final AllFeatureMatcher matcher = new AllFeatureMatcher();
18 | }
19 |
20 | /** Matches only the specified feature **/
21 | class ExactFeatureMatcher implements FeatureMatcher {
22 |   private final String match;
23 |   public ExactFeatureMatcher(String match) { this.match = match; }
24 |
25 |   @Override
26 |   public boolean matches(String feature) { return feature.equals(match); }
27 | }
28 |
29 | /** Matches only if all feature matchers in the list match **/
30 | class ConjunctiveFeatureMatcher implements FeatureMatcher {
31 |   private final FeatureMatcher[] matchers;
32 |   public ConjunctiveFeatureMatcher(FeatureMatcher... matchers) { this.matchers = matchers; }
33 |
34 |   @Override
35 |   public boolean matches(String feature) {
36 |     for (FeatureMatcher matcher : matchers)
37 |       if (!matcher.matches(feature)) return false;
38 |     return true;
39 |   }
40 | }
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/Learner.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model;
2 |
3 | import java.util.List;
4 |
5 | import edu.stanford.nlp.semparse.open.core.eval.IterativeTester;
6 | import edu.stanford.nlp.semparse.open.dataset.Dataset;
7 | import edu.stanford.nlp.semparse.open.dataset.Example;
8 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate;
9 | import fig.basic.Pair;
10 |
11 | public interface Learner {
12 |
13 |   // ============================================================
14 |   // Log
15 |   // ============================================================
16 |
17 |   public void logParam();
18 |   public void logFeatureWeights(Candidate candidate);
19 |   public void logFeatureDiff(Candidate trueCandidate, Candidate predCandidate);
20 |   public void shutUp();
21 |
22 |   // ============================================================
23 |   // Predict
24 |   // ============================================================
25 |
26 |   public List<Pair<Candidate, Double>> getRankedCandidates(Example example);
27 |
28 |   // ============================================================
29 |   // Learn
30 |   // ============================================================
31 |
32 |   public void learn(Dataset dataset, FeatureMatcher additionalFeatureMatcher);
33 |   public void setIterativeTester(IterativeTester tester);
34 |
35 |   // ============================================================
36 |   // Persistence
37 |   // ============================================================
38 |
39 |   public void saveModel(String path);
40 |   public void loadModel(String path);
41 |
42 | }
43 |
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/LearnerBaseline.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model;
2 |
3 | import java.util.*;
4 |
5 | import edu.stanford.nlp.semparse.open.core.eval.IterativeTester;
6 | import edu.stanford.nlp.semparse.open.dataset.Dataset;
7 | import edu.stanford.nlp.semparse.open.dataset.Example;
8 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate;
9 | import edu.stanford.nlp.semparse.open.model.candidate.PathEntry;
10 | import fig.basic.LogInfo;
11 | import fig.basic.MapUtils;
12 | import fig.basic.Option;
13 | import fig.basic.Pair;
14 | import fig.basic.ValueComparator;
15 |
16 | /**
17 |  * Baseline classifier.
18 |  */
19 | public class LearnerBaseline implements Learner {
20 |   public static class Options {
21 |     @Option public int baselineSuffixLength = 5;
22 |     @Option public int baselineMaxNumPatterns = 10000;
23 |     @Option public boolean baselineUseMaxSize = false; // false = use most frequent
24 |     @Option public IndexType baselineIndexType = IndexType.STAR;
25 |     @Option public boolean baselineBagOfTags = true;
26 |   }
27 |   public static Options opts = new Options();
28 |
29 |   public enum IndexType {NONE, STAR, FULL};
30 |
31 |   protected IterativeTester iterativeTester;
32 |   public boolean beVeryQuiet = false;
33 |
34 |   /*
35 |    * IDEA:
36 |    * - Look at the training data and record the most frequent tree pattern (suffix)
37 |    * - For a test example, find a suffix that matches -- maybe choose the longest one
38 |    */
39 |
40 |   // Map from suffix to count
41 |   Map<List<String>, Integer> goodPathCounts;
42 |
43 |   // ============================================================
44 |   // Log
45 |   // ============================================================
46 |
47 |   @Override
48 |   public void logParam() {
49 |     LogInfo.begin_track("Params");
50 |     if (goodPathCounts == null) {
51 |       LogInfo.log("No parameters.");
52 |     } else {
53 |       List<Map.Entry<List<String>, Integer>> entries = new ArrayList<>(goodPathCounts.entrySet());
54 |       Collections.sort(entries, new ValueComparator<List<String>, Integer>(true));
55 |       for (Map.Entry<List<String>, Integer> entry : entries) {
56 |         LogInfo.logs("%8d : %s", entry.getValue(), entry.getKey());
57 |       }
58 |     }
59 |     LogInfo.end_track();
60 |   }
61 |
62 |   @Override
63 |   public void logFeatureWeights(Candidate candidate) {
64 |     LogInfo.log("Using BASELINE Learner - no features");
65 |   }
66 |
67 |   @Override
68 |   public void logFeatureDiff(Candidate trueCandidate, Candidate predCandidate) {
69 |     LogInfo.log("Using BASELINE Learner - no features");
70 |   }
71 |
72 |
73 |   @Override
74 |   public void shutUp() {
75 |     beVeryQuiet = true;
76 |   }
77 |
78 |   // ============================================================
79 |   // Predict
80 |   // ============================================================
81 |
82 |   @Override
83 |   public List<Pair<Candidate, Double>> getRankedCandidates(Example example) {
84 |     List<Pair<Candidate, Double>> answer = new ArrayList<>();
85 |     for (Candidate candidate : example.candidates) {
86 |       double score = getScore(candidate);
87 |       answer.add(new Pair<Candidate, Double>(candidate, score));
88 |     }
89 |     Collections.sort(answer, new Pair.ReverseSecondComparator<Candidate, Double>());
90 |     return answer;
91 |   }
92 |
93 |
94 |   protected double getScore(Candidate candidate) {
95 |     List<String> suffix = getPathSuffix(candidate);
96 |     Integer frequency = goodPathCounts.get(suffix);
97 |     if (frequency == null) return 0;
98 |     return opts.baselineUseMaxSize ? candidate.predictedEntities.size() : frequency;
99 |   }
100 |
101 |   // ============================================================
102 |   // Learn
103 |   // ============================================================
104 |
105 |   @Override
106 |   public void setIterativeTester(IterativeTester tester) {
107 |     this.iterativeTester = tester;
108 |   }
109 |
110 |   @Override
111 |   public void learn(Dataset dataset, FeatureMatcher additionalFeatureMatcher) {
112 |     Map<List<String>, Integer> pathCounts = new HashMap<>();
113 |     dataset.cacheRewards();
114 |     // Learn good tree patterns (path suffix)
115 |     if (!beVeryQuiet) LogInfo.begin_track("Learning tree patterns ...");
116 |     for (Example ex : dataset.trainExamples) {
117 |       for (Candidate candidate : ex.candidates) {
118 |         if (candidate.getReward() > 0) {
119 |           // Good candidate -- remember the tree pattern
120 |           MapUtils.incr(pathCounts, getPathSuffix(candidate));
121 |         }
122 |       }
123 |     }
124 |     // Sort by count
125 |     List<Map.Entry<List<String>, Integer>> entries = new ArrayList<>(pathCounts.entrySet());
126 |     Collections.sort(entries, new ValueComparator<List<String>, Integer>(true));
127 |     // Retain the top n paths
128 |     int n = Math.min(opts.baselineMaxNumPatterns, entries.size());
129 |     goodPathCounts = new HashMap<>();
130 |     for (Map.Entry<List<String>, Integer> entry : entries.subList(0, n)) {
131 |       goodPathCounts.put(entry.getKey(), entry.getValue());
132 |     }
133 |     if (!beVeryQuiet) LogInfo.logs("Found %d path patterns.", goodPathCounts.size());
134 |     if (!beVeryQuiet) LogInfo.end_track();
135 |     iterativeTester.run();
136 |   }
137 |
138 |   private List<String> getPathSuffix(Candidate candidate) {
139 |     return getPathSuffix(candidate.pattern.getPath());
140 |   }
141 |
142 |   private List<String> getPathSuffix(List<PathEntry> path) {
143 |     List<String> suffix = new ArrayList<>();
144 |     int startIndex = Math.max(0, path.size() - opts.baselineSuffixLength);
145 |     for (PathEntry entry : path.subList(startIndex, path.size())) {
146 |       String strEntry = "";
147 |       switch (opts.baselineIndexType) {
148 |         case NONE: strEntry = entry.tag; break;
149 |         case STAR: strEntry = entry.tag + (entry.isIndexed() ? "[*]" : ""); break;
"[*]" : ""); break; 150 | case FULL: strEntry = entry.toString(); break; 151 | } 152 | suffix.add(strEntry.intern()); 153 | } 154 | if (opts.baselineBagOfTags) 155 | Collections.sort(suffix); 156 | return suffix; 157 | } 158 | 159 | // ============================================================ 160 | // Persistence 161 | // ============================================================ 162 | 163 | @Override 164 | public void saveModel(String path) { 165 | LogInfo.fail("Not implemented"); 166 | } 167 | 168 | @Override 169 | public void loadModel(String path) { 170 | LogInfo.fail("Not implemented"); 171 | } 172 | 173 | } 174 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/model/LearnerMaxEntWithBeamSearch.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.model; 2 | 3 | import java.util.*; 4 | 5 | import edu.stanford.nlp.semparse.open.dataset.Example; 6 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate; 7 | import fig.basic.LogInfo; 8 | import fig.basic.Option; 9 | import fig.basic.Pair; 10 | 11 | public class LearnerMaxEntWithBeamSearch extends LearnerMaxEnt { 12 | public static class Options { 13 | @Option public int beamSize = 500; 14 | @Option public int beamTrainStartIter = 1; 15 | @Option public String beamCandidateType = "cutrange"; 16 | } 17 | public static Options opts = new Options(); 18 | 19 | @Override 20 | protected List getCandidates(Example example) { 21 | if (trainIter <= opts.beamTrainStartIter) { 22 | return super.getCandidates(example); 23 | } else { 24 | return getBeamSearchedCandidates(example); 25 | } 26 | } 27 | 28 | protected List getBeamSearchedCandidates(Example example) { 29 | List> rankedCandidates = super.getRankedCandidates(example); 30 | rankedCandidates = rankedCandidates.subList(0, Math.min(opts.beamSize, rankedCandidates.size())); 31 | List derivedCandidates = new ArrayList<>(); 32 | for (Pair entry : rankedCandidates) { 33 | derivedCandidates.addAll(getDerivedCandidates(entry.getFirst())); 34 | } 35 | return derivedCandidates; 36 | } 37 | 38 | protected List getDerivedCandidates(Candidate original) { 39 | switch (opts.beamCandidateType) { 40 | case "cutrange": 41 | LogInfo.fails("... not implemented yet ..."); 42 | //return TreePatternAndRange.generateCutRangeCandidates(original); 43 | return null; 44 | case "endcut": 45 | LogInfo.fails("... not implemented yet ..."); 46 | return null; 47 | default: 48 | LogInfo.fails("Unrecognized beam candidate type: %s", opts.beamCandidateType); 49 | return null; 50 | } 51 | } 52 | 53 | @Override 54 | public List> getRankedCandidates(Example example) { 55 | List> answer = new ArrayList<>(); 56 | for (Candidate candidate : getBeamSearchedCandidates(example)) { 57 | double score = candidate.features.dotProduct(params); 58 | answer.add(new Pair(candidate, score)); 59 | } 60 | Collections.sort(answer, new Pair.ReverseSecondComparator()); 61 | return answer; 62 | } 63 | 64 | } 65 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/model/Params.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.model; 2 | 3 | import fig.basic.*; 4 | 5 | import java.io.BufferedReader; 6 | import java.io.IOException; 7 | import java.io.PrintWriter; 8 | import java.util.*; 9 | 10 | /** 11 | * Params contains the parameters of the model. 
12 |  * features to weights.
13 |  *
14 |  * @author Percy Liang
15 |  */
16 | public class Params {
17 |   public static class Options {
18 |     @Option(gloss = "By default, all features have this weight")
19 |     public double defaultWeight = 0;
20 |     @Option(gloss = "Randomly initialize the weights")
21 |     public boolean initWeightsRandomly = false;
22 |     @Option(gloss = "Random source used when initializing the weights randomly")
23 |     public Random initRandom = new Random(1);
24 |
25 |     @Option(gloss = "Initial step size")
26 |     public double initStepSize = 1;
27 |     @Option(gloss = "How fast to reduce the step size")
28 |     public double stepSizeReduction = 0;
29 |     @Option(gloss = "Use the AdaGrad algorithm (different step size for each coordinate)")
30 |     public boolean adaptiveStepSize = true;
31 |     @Option(gloss = "Use dual averaging")
32 |     public boolean dualAveraging = false;
33 |   }
34 |   public static Options opts = new Options();
35 |
36 |   // Discriminative weights
37 |   HashMap<String, Double> weights = new HashMap<String, Double>();
38 |
39 |   public double getWeight(String f) {
40 |     if (opts.initWeightsRandomly)
41 |       return MapUtils.getDouble(weights, f, 2 * opts.initRandom.nextDouble() - 1);
42 |     else
43 |       return MapUtils.getDouble(weights, f, opts.defaultWeight);
44 |   }
45 |
46 |   // ============================================================
47 |   // Weight update
48 |   // ============================================================
49 |
50 |   // For AdaGrad
51 |   Map<String, Double> sumSquaredGradients = new HashMap<String, Double>();
52 |
53 |   // For dual averaging
54 |   Map<String, Double> sumGradients = new HashMap<String, Double>();
55 |
56 |   // Number of stochastic updates we've made so far (for determining step size).
57 |   int numUpdates;
58 |
59 |   /**
60 |    * Update weights by adding |gradient| (modified appropriately with step size).
61 |    */
62 |   public void update(Map<String, Double> gradient) {
63 |     numUpdates++;
64 |
65 |     for (Map.Entry<String, Double> entry : gradient.entrySet()) {
66 |       String f = entry.getKey();
67 |       double g = entry.getValue();
68 |       if (Math.abs(g) < 1e-6) continue;
69 |       double stepSize;
70 |       if (opts.adaptiveStepSize) {
71 |         MapUtils.incr(sumSquaredGradients, f, g * g);
72 |         stepSize = opts.initStepSize / Math.sqrt(sumSquaredGradients.get(f));
73 |       } else {
74 |         stepSize = opts.initStepSize / Math.pow(numUpdates, opts.stepSizeReduction);
75 |       }
76 |       if (Double.isNaN(stepSize) || Double.isNaN(g)) {
77 |         LogInfo.fails("WTF? %s %s %s", f, g, sumSquaredGradients.get(f));
78 |       }
79 |       if (opts.dualAveraging) {
80 |         if (!opts.adaptiveStepSize && opts.stepSizeReduction != 0)
81 |           throw new RuntimeException("Dual averaging not supported when " +
82 |               "step-size changes across iterations for " +
83 |               "features for which the gradient is zero");
84 |         MapUtils.incr(sumGradients, f, g);
85 |         MapUtils.set(weights, f, stepSize * sumGradients.get(f));
86 |       } else {
87 |         MapUtils.incr(weights, f, stepSize * g);
88 |       }
89 |     }
90 |   }
91 |
92 |   public static double L1Cut(double x, double cutoff) {
93 |     return (x > cutoff) ? (x - cutoff) : (x < -cutoff) ? (x + cutoff) : 0;
94 |   }
95 |
96 |   /**
97 |    * Apply L1 regularization:
98 |    * - If weight > cutoff, then weight := weight - cutoff
99 |    * - If weight < -cutoff, then weight := weight + cutoff
100 |    * - Otherwise, weight := 0
101 |    * @param cutoff regularization parameter (>= 0)
102 |    */
103 |   public void applyL1Regularization(double cutoff) {
104 |     if (cutoff <= 0) return;
105 |     for (Map.Entry<String, Double> entry : weights.entrySet()) {
106 |       entry.setValue(L1Cut(entry.getValue(), cutoff));
107 |     }
108 |   }
109 |
110 |   /**
111 |    * Prune features with small weights
112 |    * @param threshold the maximum absolute value for weights to be pruned
113 |    */
114 |   public void prune(double threshold) {
115 |     if (threshold <= 0) return;
116 |     Iterator<Map.Entry<String, Double>> iter = weights.entrySet().iterator();
117 |     while (iter.hasNext()) {
118 |       if (Math.abs(iter.next().getValue()) < threshold) iter.remove();
119 |     }
120 |   }
121 |
122 |   // ============================================================
123 |   // Persistence
124 |   // ============================================================
125 |
126 |   /**
127 |    * Read parameters from |path|.
128 |    */
129 |   public void read(String path) {
130 |     LogInfo.begin_track("Reading parameters from %s", path);
131 |     try {
132 |       BufferedReader in = IOUtils.openIn(path);
133 |       String line;
134 |       while ((line = in.readLine()) != null) {
135 |         String[] pair = line.split("\t");
136 |         weights.put(pair[0], Double.parseDouble(pair[1]));
137 |       }
138 |       in.close();
139 |     } catch (IOException e) {
140 |       throw new RuntimeException(e);
141 |     }
142 |     LogInfo.logs("Read %s weights", weights.size());
143 |     LogInfo.end_track();
144 |   }
145 |
146 |   public void write(PrintWriter out) { write(null, out); }
147 |
148 |   public void write(String prefix, PrintWriter out) {
149 |     List<Map.Entry<String, Double>> entries = new ArrayList<>(weights.entrySet());
150 |     Collections.sort(entries, new ValueComparator<String, Double>(true));
151 |     for (Map.Entry<String, Double> entry : entries) {
152 |       double value = entry.getValue();
153 |       out.println((prefix == null ? "" : prefix + "\t") + entry.getKey() + "\t" + value);
154 |     }
155 |   }
156 |
157 |   public void write(String path) {
158 |     LogInfo.begin_track("Params.write(%s)", path);
159 |     PrintWriter out = IOUtils.openOutHard(path);
160 |     write(out);
161 |     out.close();
162 |     LogInfo.end_track();
163 |   }
164 |
165 |   // ============================================================
166 |   // Logging
167 |   // ============================================================
168 |
169 |   public void log() {
170 |     LogInfo.begin_track("Params");
171 |     List<Map.Entry<String, Double>> entries = new ArrayList<>(weights.entrySet());
172 |     Collections.sort(entries, new ValueComparator<String, Double>(true));
173 |     for (Map.Entry<String, Double> entry : entries) {
174 |       double value = entry.getValue();
175 |       LogInfo.logs("%s\t%s", entry.getKey(), value);
176 |     }
177 |     LogInfo.end_track();
178 |   }
179 | }
180 |
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/candidate/Candidate.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.candidate;
2 |
3 | import java.util.*;
4 |
5 | import edu.stanford.nlp.semparse.open.dataset.Example;
6 | import edu.stanford.nlp.semparse.open.model.FeatureVector;
7 |
8 | /**
9 |  * A Candidate is a possible set of predicted entities.
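   |  *
   |  * <p>Each Candidate pairs a CandidateGroup (the selected nodes and their
   |  * entity strings) with the TreePattern that selected them, so candidates
   |  * produced by different patterns can share one group;
   |  * getCombinedFeatures() merges the candidate's features with the group's.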
10 |  */
11 | public class Candidate {
12 |   public final Example ex;
13 |   public final CandidateGroup group;
14 |   public final TreePattern pattern;
15 |   public final List<String> predictedEntities;
16 |   public FeatureVector features;
17 |
18 |   public Candidate(CandidateGroup group, TreePattern pattern) {
19 |     this.pattern = pattern;
20 |     this.group = group;
21 |     group.candidates.add(this);
22 |     // Perform shallow copy
23 |     this.ex = group.ex;
24 |     this.predictedEntities = group.predictedEntities;
25 |   }
26 |
27 |   public int numEntities() {
28 |     return group.numEntities();
29 |   }
30 |
31 |   public double getReward() {
32 |     return group.ex.expectedAnswer.reward(this);
33 |   }
34 |
35 |   public Map<String, Double> getCombinedFeatures() {
36 |     Map<String, Double> map = new HashMap<>();
37 |     features.increment(1, map);
38 |     group.features.increment(1, map);
39 |     return map;
40 |   }
41 |
42 |   // ============================================================
43 |   // Debug Print
44 |   // ============================================================
45 |
46 |   public String sampleEntities() {
47 |     return group.sampleEntities();
48 |   }
49 |
50 |   public String allEntities() {
51 |     return group.allEntities();
52 |   }
53 | }
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/candidate/CandidateGroup.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.candidate;
2 |
3 | import java.util.*;
4 |
5 | import edu.stanford.nlp.semparse.open.dataset.Example;
6 | import edu.stanford.nlp.semparse.open.ling.AveragedWordVector;
7 | import edu.stanford.nlp.semparse.open.ling.LingUtils;
8 | import edu.stanford.nlp.semparse.open.model.FeatureVector;
9 | import edu.stanford.nlp.semparse.open.model.tree.KNode;
10 | import edu.stanford.nlp.semparse.open.util.StringSampler;
11 | import fig.basic.Option;
12 |
13 | /**
14 |  * A CandidateGroup is a collection of candidates with the same selected KNodes
15 |  * (and thus the same selected entity strings).
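   |  *
   |  * <p>Features that depend only on the entity strings are computed once on
   |  * the group, while pattern-specific features live on each Candidate.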
16 |  */
17 | public class CandidateGroup {
18 |   public static class Options {
19 |     @Option(gloss = "level of entity string normalization when creating candidate group "
20 |         + "(0 = none / 1 = whitespace / 2 = simple / 3 = aggressive)")
21 |     public int lateNormalizeEntities = 2;
22 |   }
23 |   public static Options opts = new Options();
24 |
25 |   public final Example ex;
26 |   public final List<KNode> selectedNodes;
27 |   public final List<String> predictedEntities;
28 |   final List<Candidate> candidates;
29 |   public FeatureVector features;
30 |   public AveragedWordVector averagedWordVector;
31 |
32 |   public CandidateGroup(Example ex, List<KNode> selectedNodes) {
33 |     this.ex = ex;
34 |     this.selectedNodes = new ArrayList<>(selectedNodes);
35 |     List<String> entities = new ArrayList<>();
36 |     for (KNode node : selectedNodes) {
37 |       entities.add(LingUtils.normalize(node.fullText, opts.lateNormalizeEntities));
38 |     }
39 |     predictedEntities = new ArrayList<>(entities);
40 |     candidates = new ArrayList<>();
41 |   }
42 |
43 |   public void initAveragedWordVector() {
44 |     if (averagedWordVector == null)
45 |       averagedWordVector = new AveragedWordVector(predictedEntities);
46 |   }
47 |
48 |   public int numEntities() {
49 |     return predictedEntities.size();
50 |   }
51 |
52 |   public int numCandidate() {
53 |     return candidates.size();
54 |   }
55 |
56 |   public List<Candidate> getCandidates() {
57 |     return Collections.unmodifiableList(candidates);
58 |   }
59 |
60 |   public Candidate addCandidate(TreePattern pattern) {
61 |     return new Candidate(this, pattern);
62 |   }
63 |
64 |   public double getReward() {
65 |     return ex.expectedAnswer.reward(this);
66 |   }
67 |
68 |   // ============================================================
69 |   // Debug Print
70 |   // ============================================================
71 |
72 |   public String sampleEntities() {
73 |     return StringSampler.sampleEntities(predictedEntities, StringSampler.DEFAULT_LIMIT);
74 |   }
75 |
76 |   public String allEntities() {
77 |     return StringSampler.sampleEntities(predictedEntities);
78 |   }
79 |
80 | }
81 |
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/candidate/PathEntry.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.candidate;
2 |
3 | import edu.stanford.nlp.semparse.open.model.tree.KNode;
4 |
5 | /**
6 |  * A PathEntry represents an entry in an XPath.
7 |  * It is immutable.
8 |  *
9 |  * An XPath is just a list of PathEntry.
10 |  * Utilities involving XPath are in the PathUtils class.
11 |  *
12 |  * Use an immutable List<PathEntry> for a fixed XPath (e.g. in TreePattern).
13 |  * The main benefit is that the paths can be hashed and compared consistently.
14 |  *
15 |  * Use a normal List<PathEntry> for an editable path.
16 |  */
17 | public class PathEntry implements Comparable<PathEntry> {
18 |   final public String tag;
19 |   final public int index; // 0-based; -1 = no index
20 |
21 |   public PathEntry(String tag, int index) {
22 |     this.tag = tag;
23 |     this.index = index;
24 |   }
25 |
26 |   public PathEntry(String tag) {
27 |     this.tag = tag;
28 |     this.index = -1;
29 |   }
30 |
31 |   public boolean isIndexed() {
32 |     return this.index != -1;
33 |   }
34 |
35 |   public PathEntry getIndexedVersion(int newIndex) {
36 |     return new PathEntry(tag, newIndex);
37 |   }
38 |
39 |   public PathEntry getNoIndexVersion() {
40 |     return new PathEntry(tag);
41 |   }
42 |
43 |   /** Check if the PathEntry's tag matches the node's tag.
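   |  * A "*" tag matches any node.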
   |  */
44 |   public boolean matchTag(KNode node) {
45 |     return tag.equals("*") || tag.equals(node.value);
46 |   }
47 |
48 |   @Override public String toString() {
49 |     if (this.index == -1)
50 |       return this.tag;
51 |     return this.tag + "[" + (this.index + 1) + "]";
52 |   }
53 |
54 |   @Override public boolean equals(Object obj) {
55 |     if (obj == this)
56 |       return true;
57 |     if (obj == null || obj.getClass() != this.getClass())
58 |       return false;
59 |     PathEntry that = (PathEntry) obj;
60 |     return this.tag.equals(that.tag) && this.index == that.index;
61 |   }
62 |
63 |   @Override public int hashCode() {
64 |     return tag.hashCode() + index;
65 |   }
66 |
67 |   @Override
68 |   public int compareTo(PathEntry o) {
69 |     return toString().compareTo(o.toString());
70 |   }
71 | }
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/candidate/PathEntryWithRange.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.candidate;
2 |
3 | public class PathEntryWithRange extends PathEntry {
4 |
5 |   public final int indexStart, indexEnd;
6 |
7 |   public PathEntryWithRange(String tag, int indexStart, int indexEnd) {
8 |     super(tag);
9 |     this.indexStart = indexStart;
10 |     this.indexEnd = indexEnd;
11 |   }
12 |
13 |   public boolean isIndexed() {
14 |     return true;
15 |   }
16 |
17 |   @Override public String toString() {
18 |     StringBuilder sb = new StringBuilder().append(this.tag).append("[");
19 |     if (indexStart != 0) sb.append(indexStart);
20 |     sb.append(":");
21 |     if (indexEnd != 0) sb.append("-").append(indexEnd);
22 |     sb.append("]");
23 |     return sb.toString();
24 |   }
25 |
26 |   @Override public boolean equals(Object obj) {
27 |     if (obj == this)
28 |       return true;
29 |     if (obj == null || obj.getClass() != this.getClass())
30 |       return false;
31 |     PathEntryWithRange that = (PathEntryWithRange) obj;
32 |     return this.tag.equals(that.tag) && this.indexStart == that.indexStart && this.indexEnd == that.indexEnd;
33 |   }
34 |
35 |   @Override public int hashCode() {
36 |     // Parenthesized so that the shift applies to indexStart alone
37 |     // ('+' binds tighter than '<<' in Java).
38 |     return tag.hashCode() + (indexStart << 8) + indexEnd;
39 |   }
40 |
41 | }
42 |
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/candidate/PathUtils.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.candidate;
2 |
3 | import java.util.*;
4 |
5 | import edu.stanford.nlp.semparse.open.model.tree.KNode;
6 | import edu.stanford.nlp.semparse.open.model.tree.KNodeUtils;
7 | import fig.basic.LogInfo;
8 |
9 | /**
10 |  * Utilities that deal with XPath (list of PathEntry)
11 |  */
12 | public class PathUtils {
13 |
14 |   public static String getXPathString(List<PathEntry> path) {
15 |     StringBuilder sb = new StringBuilder();
16 |     for (PathEntry entry : path) {
17 |       sb.append("/").append(entry);
18 |     }
19 |     return sb.toString();
20 |   }
21 |
22 |   public static String getXPathSuffixString(List<PathEntry> path, int amount) {
23 |     StringBuilder sb = new StringBuilder();
24 |     int startIndex = Math.max(0, path.size() - amount);
25 |     for (PathEntry entry : path.subList(startIndex, path.size())) {
26 |       sb.append("/").append(entry.toString());
27 |     }
28 |     return sb.toString();
29 |   }
30 |
31 |   public static String getXPathSuffixStringNoIndex(List<PathEntry> path, int amount) {
32 |     StringBuilder sb = new StringBuilder();
33 |     int startIndex = Math.max(0, path.size() - amount);
34 |     for (PathEntry entry : path.subList(startIndex, path.size())) {
35 |       sb.append("/").append(entry.tag);
36 |     }
37 |     return sb.toString();
38 |   }
39 |
40 |   public static List<PathEntry> getXPathSuffix(List<PathEntry> path, int amount) {
41 |     List<PathEntry> suffix = new ArrayList<>();
42 |     int startIndex = Math.max(0, path.size() - amount);
43 |     for (PathEntry entry : path.subList(startIndex, path.size())) {
44 |       suffix.add(entry);
45 |     }
46 |     return suffix;
47 |   }
48 |
49 |   public static List<PathEntry> getXPathSuffixNoIndex(List<PathEntry> path, int amount) {
50 |     List<PathEntry> suffix = new ArrayList<>();
51 |     int startIndex = Math.max(0, path.size() - amount);
52 |     for (PathEntry entry : path.subList(startIndex, path.size())) {
53 |       suffix.add(entry.getNoIndexVersion());
54 |     }
55 |     return suffix;
56 |   }
57 |
58 |   /**
59 |    * Execute XPath on the currentNode and add the matched nodes to the answer collection.
60 |    * Only add nodes with short text (i.e., fullText != null and fullText != "")
61 |    */
62 |   public static void executePath(List<PathEntry> path, KNode currentNode, Collection<KNode> answer) {
63 |     if (!path.get(0).matchTag(currentNode))
64 |       LogInfo.fails("XPath mismatch (node %s != xpath %s)", currentNode.value, path.get(0).tag);
65 |     if (path.size() == 1) {
66 |       if (currentNode.fullText != null && !currentNode.fullText.isEmpty())
67 |         answer.add(currentNode);
68 |       return;
69 |     }
70 |     // Go to the next element of the path
71 |     List<PathEntry> descendants = path.subList(1, path.size());
72 |     PathEntry nextPathEntry = descendants.get(0);
73 |     if (nextPathEntry instanceof PathEntryWithRange) {
74 |       PathEntryWithRange nextPathEntryWithRange = (PathEntryWithRange) nextPathEntry;
75 |       List<KNode> children = currentNode.getChildrenOfTag(nextPathEntry.tag);
76 |       int start = nextPathEntryWithRange.indexStart;
77 |       int end = children.size() - nextPathEntryWithRange.indexEnd;
78 |       if (start >= end) return;
79 |       for (KNode child : children.subList(start, end)) {
80 |         executePath(descendants, child, answer);
81 |       }
82 |     } else if (nextPathEntry.isIndexed()) {
83 |       KNode child = currentNode.getChildrenOfTag(nextPathEntry.tag, nextPathEntry.index);
84 |       if (child != null)
85 |         executePath(descendants, child, answer);
86 |     } else {
87 |       for (KNode child : currentNode.getChildrenOfTag(nextPathEntry.tag))
88 |         executePath(descendants, child, answer);
89 |     }
90 |   }
91 |
92 |   /**
93 |    * Return the prefix of the given XPath where the last entry of the prefix refers to the given node.
94 |    */
95 |   public static List<PathEntry> pathPrefixAtNode(List<PathEntry> path, KNode root, KNode node) {
96 |     if (!KNodeUtils.isDescendantOf(node, root)) {
97 |       LogInfo.fails("%s not a descendant of %s", node, root);
98 |     }
99 |     return path.subList(0, node.depth - root.depth);
100 |   }
101 |
102 |   /**
103 |    * Return the suffix of the given XPath where the first entry of the suffix refers to the given node.
104 |    */
105 |   public static List<PathEntry> pathSuffixAtNode(List<PathEntry> path, KNode root, KNode node) {
106 |     if (!KNodeUtils.isDescendantOf(node, root)) {
107 |       LogInfo.fails("%s not a descendant of %s", node, root);
108 |     }
109 |     return path.subList(node.depth - root.depth, path.size());
110 |   }
111 |
112 | }
113 |
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/candidate/TreePattern.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.candidate;
2 |
3 | import java.util.*;
4 |
5 | import edu.stanford.nlp.semparse.open.model.tree.KNode;
6 |
7 | /**
8 |  * A TreePattern is the entire specification for selecting entities from the knowledge tree.
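   |  *
   |  * <p>The pattern is essentially an XPath from the tree root down to the
   |  * record nodes, e.g. (a hypothetical path; indices print 1-based):
   |  *
   |  * <pre>
   |  *   /html/body/div[3]/ul/li
   |  * </pre>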
9 |  */
10 | public class TreePattern {
11 |
12 |   protected final KNode rootNode;
13 |   protected final List<PathEntry> recordPath;
14 |   protected final List<KNode> recordNodes;
15 |
16 |   public TreePattern(KNode rootNode, Collection<PathEntry> recordPath, Collection<KNode> recordNodes) {
17 |     this.rootNode = rootNode;
18 |     this.recordPath = new ArrayList<>(recordPath);
19 |     this.recordNodes = new ArrayList<>(recordNodes);
20 |   }
21 |
22 |   @Override public String toString() {
23 |     return PathUtils.getXPathString(recordPath);
24 |   }
25 |
26 |   public KNode getRoot() {
27 |     return rootNode;
28 |   }
29 |
30 |   public List<PathEntry> getPath() {
31 |     return recordPath;
32 |   }
33 |
34 |   public List<KNode> getNodes() {
35 |     return recordNodes;
36 |   }
37 |
38 | }
39 |
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/candidate/TreePatternAndRange.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.candidate;
2 |
3 | import java.util.*;
4 |
5 | import edu.stanford.nlp.semparse.open.model.feature.FeatureExtractor;
6 | import edu.stanford.nlp.semparse.open.model.tree.KNode;
7 |
8 | public class TreePatternAndRange extends TreePattern {
9 |
10 |   /**
11 |    * The selected nodes are originalRecordNodes[i] where rangeStart <= i < rangeEnd
12 |    */
13 |   public final int amountCutStart, amountCutEnd, rangeStart, rangeEnd;
14 |
15 |   public TreePatternAndRange(KNode rootNode, Collection<PathEntry> recordPath, Collection<KNode> originalRecordNodes,
16 |       int amountCutStart, int amountCutEnd) {
17 |     super(rootNode, recordPath, originalRecordNodes);
18 |     if (amountCutStart < 0 || amountCutEnd < 0 || amountCutStart + amountCutEnd > originalRecordNodes.size())
19 |       throw new IndexOutOfBoundsException("Invalid range: " +
20 |           "cutStart = " + amountCutStart + ", cutEnd = " + amountCutEnd + ", n = " + originalRecordNodes.size());
21 |     this.amountCutStart = amountCutStart;
22 |     this.amountCutEnd = amountCutEnd;
23 |     this.rangeStart = amountCutStart;
24 |     this.rangeEnd = originalRecordNodes.size() - amountCutEnd;
25 |   }
26 |
27 |   public TreePatternAndRange(TreePattern treePattern, int amountCutStart, int amountCutEnd) {
28 |     this(treePattern.rootNode, treePattern.recordPath, treePattern.recordNodes, amountCutStart, amountCutEnd);
29 |   }
30 |
31 |   @Override
32 |   public String toString() {
33 |     return new StringBuilder().append(PathUtils.getXPathString(recordPath))
34 |         .append(" [").append(rangeStart).append(":").append(rangeEnd).append("]").toString();
35 |   }
36 |
37 |   @Override
38 |   public List<KNode> getNodes() {
39 |     return recordNodes.subList(rangeStart, rangeEnd);
40 |   }
41 |
42 |   /* It works, but is painfully slow.
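   |  * Each cut allocates a fresh CandidateGroup and re-runs full feature
   |  * extraction on every derived candidate.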
   |  */
43 |   @Deprecated
44 |   public static List<Candidate> generateCutRangeCandidates(Candidate candidate) {
45 |     List<Candidate> candidates = new ArrayList<>();
46 |     // Remove a few nodes from both sides
47 |     int n = candidate.numEntities();
48 |     for (int i = 0; i < Math.min(5, n - CandidateGenerator.opts.minNumCandidateEntity); i++) {
49 |       // Keep the nodes that survive cutting i entries from the front
50 |       CandidateGroup group = new CandidateGroup(candidate.ex, candidate.group.selectedNodes.subList(i, n));
51 |       candidates.add(group.addCandidate(new TreePatternAndRange(candidate.pattern, i, 0)));
52 |     }
53 |     for (int i = 1; i < Math.min(10, n - CandidateGenerator.opts.minNumCandidateEntity); i++) {
54 |       CandidateGroup group = new CandidateGroup(candidate.ex, candidate.group.selectedNodes.subList(0, n-i));
55 |       candidates.add(group.addCandidate(new TreePatternAndRange(candidate.pattern, 0, i)));
56 |     }
57 |     for (Candidate cutRangeCandidate : candidates) {
58 |       FeatureExtractor.featureExtractor.extract(cutRangeCandidate);
59 |       FeatureExtractor.featureExtractor.extract(cutRangeCandidate.group);
60 |     }
61 |     return candidates;
62 |   }
63 |
64 |   /**
65 |    * Generate cut-range groups from the given group without actually adding any candidate.
66 |    * Suitable for experiments only.
67 |    */
68 |   public static List<CandidateGroup> generateDummyCutRangeCandidateGroups(CandidateGroup group) {
69 |     List<CandidateGroup> candidateGroups = new ArrayList<>();
70 |     // Remove a few nodes from both sides
71 |     int n = group.numEntities();
72 |     for (int i = 0; i < Math.min(5, n - CandidateGenerator.opts.minNumCandidateEntity); i++) {
73 |       candidateGroups.add(new CandidateGroup(group.ex, group.selectedNodes.subList(i, n)));
74 |     }
75 |     for (int i = 1; i < Math.min(10, n - CandidateGenerator.opts.minNumCandidateEntity); i++) {
76 |       candidateGroups.add(new CandidateGroup(group.ex, group.selectedNodes.subList(0, n-i)));
77 |     }
78 |     return candidateGroups;
79 |   }
80 |
81 | }
82 |
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/feature/FeatureExtractor.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.feature;
2 |
3 | import java.util.*;
4 |
5 | import edu.stanford.nlp.semparse.open.model.FeatureVector;
6 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate;
7 | import edu.stanford.nlp.semparse.open.model.candidate.CandidateGroup;
8 |
9 | /**
10 |  * A FeatureExtractor populates a candidate's features.
11 |  * It calls the extract method of different FeatureTypes.
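   |  *
   |  * <p>Typical usage (a sketch): extract the group-level features once, then
   |  * the pattern-specific features of each candidate in the group:
   |  *
   |  * <pre>
   |  *   FeatureExtractor.featureExtractor.extract(group);
   |  *   for (Candidate c : group.getCandidates())
   |  *     FeatureExtractor.featureExtractor.extract(c);
   |  * </pre>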
12 |  */
13 | public class FeatureExtractor {
14 |
15 |   protected final List<FeatureType> featureTypes = Arrays.asList(
16 |       new FeatureTypeNaiveEntityBased(),
17 |       new FeatureTypeNodeBased(),
18 |       new FeatureTypePathBased(),
19 |       new FeatureTypeLinguisticsBased(),
20 |       //new FeatureTypeQueryBased(),
21 |       new FeatureTypeHoleBased()
22 |       //new FeatureTypeCutRange()
23 |       );
24 |   protected final List<FeaturePostProcessor> featurePostProcessors = Arrays.asList(
25 |       (FeaturePostProcessor) new FeaturePostProcessorConjoin());
26 |
27 |   public void extract(Candidate candidate) {
28 |     if (candidate.features != null) return;
29 |     candidate.features = new FeatureVector();
30 |     for (FeatureType featureType : featureTypes) {
31 |       featureType.extract(candidate);
32 |     }
33 |     for (FeaturePostProcessor featurePostProcessor : featurePostProcessors) {
34 |       featurePostProcessor.process(candidate);
35 |     }
36 |   }
37 |
38 |   public void extract(CandidateGroup group) {
39 |     if (group.features != null) return;
40 |     group.features = new FeatureVector();
41 |     group.features.add("basic", "bias");
42 |     for (FeatureType featureType : featureTypes) {
43 |       featureType.extract(group);
44 |     }
45 |     for (FeaturePostProcessor featurePostProcessor : featurePostProcessors) {
46 |       featurePostProcessor.process(group);
47 |     }
48 |   }
49 |
50 |   public static final FeatureExtractor featureExtractor = new FeatureExtractor();
51 |
52 | }
53 |
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/feature/FeaturePostProcessor.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.feature;
2 |
3 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate;
4 | import edu.stanford.nlp.semparse.open.model.candidate.CandidateGroup;
5 | import fig.basic.LogInfo;
6 |
7 | public abstract class FeaturePostProcessor {
8 |
9 |   public abstract void process(Candidate candidate);
10 |   public abstract void process(CandidateGroup group);
11 |
12 |   public static void checkFeaturePostProcessorOptionsSanity() {
13 |     if (FeaturePostProcessorConjoin.opts.useConjoin) {
14 |       LogInfo.begin_track("Feature post-processor: Conjoin");
15 |       FeaturePostProcessorConjoin.debugPrintOptions();
16 |       LogInfo.end_track();
17 |     }
18 |   }
19 |
20 | }
21 |
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/feature/FeaturePostProcessorConjoin.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.feature;
2 |
3 | import java.util.regex.Pattern;
4 |
5 | import edu.stanford.nlp.semparse.open.dataset.Example;
6 | import edu.stanford.nlp.semparse.open.ling.LingUtils;
7 | import edu.stanford.nlp.semparse.open.ling.QueryTypeTable;
8 | import edu.stanford.nlp.semparse.open.ling.WordNetClusterTable;
9 | import edu.stanford.nlp.semparse.open.model.FeatureMatcher;
10 | import edu.stanford.nlp.semparse.open.model.FeatureVector;
11 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate;
12 | import edu.stanford.nlp.semparse.open.model.candidate.CandidateGroup;
13 | import fig.basic.Fmt;
14 | import fig.basic.LogInfo;
15 | import fig.basic.Option;
16 |
17 | public class FeaturePostProcessorConjoin extends FeaturePostProcessor {
18 |   public static class Options {
19 |     @Option(gloss = "conjoin features with an abstract representation of the query")
20 |     public boolean useConjoin = false;
21 |
22 |     @Option public 
String cjQueryTypeName = null; 23 | @Option public boolean cjConjoinWithWordNetClusters = false; 24 | 25 | @Option public String cjRegExConjoin = "^(ling|entity).*"; 26 | @Option public boolean cjKeepOriginalFeatures = false; 27 | @Option public double cjScaleConjoinFeatures = 1.0; 28 | } 29 | public static Options opts = new Options(); 30 | 31 | public static void debugPrintOptions() { 32 | if (opts.cjQueryTypeName != null && !opts.cjQueryTypeName.isEmpty()) 33 | LogInfo.logs("Conjoining query type: %s", opts.cjQueryTypeName); 34 | else 35 | LogInfo.log("Conjoining ALL query types"); 36 | if (opts.cjRegExConjoin != null && !opts.cjRegExConjoin.isEmpty()) 37 | LogInfo.logs("Conjoining features matching regex: %s", opts.cjRegExConjoin); 38 | else 39 | LogInfo.log("Conjoining ALL features"); 40 | if (opts.cjKeepOriginalFeatures) 41 | LogInfo.log("... also keep original features"); 42 | if (opts.cjScaleConjoinFeatures != 1.0) 43 | LogInfo.logs("... also scale conjoined features by %s", Fmt.D(opts.cjScaleConjoinFeatures)); 44 | } 45 | 46 | @Override 47 | public void process(Candidate candidate) { 48 | if (!opts.useConjoin) return; 49 | String prefix = getConjoiningPrefix(candidate.ex); 50 | candidate.features = getConjoinedFeatureVector(candidate.features, prefix); 51 | } 52 | 53 | @Override 54 | public void process(CandidateGroup group) { 55 | if (!opts.useConjoin) return; 56 | String prefix = getConjoiningPrefix(group.ex); 57 | group.features = getConjoinedFeatureVector(group.features, prefix); 58 | } 59 | 60 | // ============================================================ 61 | // Compute the abstract representation g(query) 62 | // ============================================================ 63 | 64 | private String getQueryType(Example ex) { 65 | return getQueryType(ex.phrase); 66 | } 67 | 68 | private String getQueryType(String phrase) { 69 | String queryType; 70 | if (opts.cjConjoinWithWordNetClusters) { 71 | queryType = WordNetClusterTable.getCluster(LingUtils.findHeadWord(phrase, true)); 72 | } else { 73 | queryType = QueryTypeTable.getQueryType(phrase); 74 | } 75 | return "" + queryType; 76 | } 77 | 78 | private String getConjoiningPrefix(Example ex) { 79 | if (opts.cjQueryTypeName != null && !opts.cjQueryTypeName.isEmpty()) 80 | return opts.cjQueryTypeName.equals(getQueryType(ex)) ? "I" : "O"; 81 | else 82 | return getQueryType(ex); 83 | } 84 | 85 | // ============================================================ 86 | // Converting feature f to (g(query), f) 87 | // ============================================================ 88 | 89 | class RegExFeatureMatcher implements FeatureMatcher { 90 | public final Pattern regex; 91 | public final boolean inverse; 92 | 93 | public RegExFeatureMatcher(String regex) { 94 | this(Pattern.compile(regex), false); 95 | } 96 | 97 | public RegExFeatureMatcher(Pattern regex) { 98 | this(regex, false); 99 | } 100 | 101 | public RegExFeatureMatcher(String regex, boolean inverse) { 102 | this(Pattern.compile(regex), inverse); 103 | } 104 | 105 | public RegExFeatureMatcher(Pattern regex, boolean inverse) { 106 | this.regex = regex; 107 | this.inverse = inverse; 108 | } 109 | 110 | @Override 111 | public boolean matches(String feature) { 112 | boolean match = regex.matcher(feature).matches(); 113 | return inverse ? 
!match : match; 114 | } 115 | } 116 | 117 | private FeatureVector getConjoinedFeatureVector(FeatureVector vOld, String queryType) { 118 | FeatureVector v = new FeatureVector(); 119 | if (opts.cjRegExConjoin != null) { 120 | FeatureMatcher matcher = new RegExFeatureMatcher(opts.cjRegExConjoin), 121 | invMatcher = new RegExFeatureMatcher(opts.cjRegExConjoin, true); 122 | if (opts.cjKeepOriginalFeatures) { 123 | v.addConjoin(vOld, "ALL"); 124 | } else { 125 | v.addConjoin(vOld, "ALL", invMatcher); 126 | } 127 | if (opts.cjScaleConjoinFeatures != 1.0) { 128 | v.addConjoin(vOld, queryType, matcher, opts.cjScaleConjoinFeatures); 129 | } else { 130 | v.addConjoin(vOld, queryType, matcher); 131 | } 132 | } else { 133 | if (opts.cjKeepOriginalFeatures) v.addConjoin(vOld, "ALL"); 134 | v.addConjoin(vOld, queryType); 135 | } 136 | return v; 137 | } 138 | 139 | } 140 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/model/feature/FeatureTypeCutRange.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.model.feature; 2 | 3 | import java.util.List; 4 | 5 | import edu.stanford.nlp.semparse.open.model.FeatureVector; 6 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate; 7 | import edu.stanford.nlp.semparse.open.model.candidate.CandidateGroup; 8 | import edu.stanford.nlp.semparse.open.model.candidate.PathEntry; 9 | import edu.stanford.nlp.semparse.open.model.candidate.TreePatternAndRange; 10 | 11 | /** 12 | * Features on the range-cut candidates 13 | */ 14 | public class FeatureTypeCutRange extends FeatureType { 15 | 16 | @Override 17 | public void extract(Candidate candidate) { 18 | extractCutRangeFeatures(candidate); 19 | } 20 | 21 | @Override 22 | public void extract(CandidateGroup group) { 23 | // Do nothing 24 | } 25 | 26 | protected void extractCutRangeFeatures(Candidate candidate) { 27 | if ((candidate.pattern instanceof TreePatternAndRange) && isAllowedDomain("cutrange")) { 28 | // Amount of range cut 29 | TreePatternAndRange pattern = (TreePatternAndRange) candidate.pattern; 30 | if (pattern.amountCutStart + pattern.amountCutEnd > 0) { 31 | addConjunctiveFeatures(candidate.features, "cutrange", "has-cut", pattern); 32 | if (pattern.amountCutStart == 1 && pattern.amountCutEnd == 0) 33 | addConjunctiveFeatures(candidate.features, "cutrange", "cut-first-only", pattern); 34 | if (pattern.amountCutStart == 0 && pattern.amountCutEnd == 1) 35 | addConjunctiveFeatures(candidate.features, "cutrange", "cut-last-only", pattern); 36 | if (pattern.amountCutStart > 0) 37 | addConjunctiveFeatures(candidate.features, "cutrange", "cut-front", pattern); 38 | if (pattern.amountCutEnd > 0) 39 | addConjunctiveFeatures(candidate.features, "cutrange", "cut-back", pattern); 40 | } 41 | // What are at the cut points? 
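   |       // (Not implemented: features describing the nodes just inside and
   |       // outside the cut points could be added here.)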
42 |
43 |     }
44 |   }
45 |
46 |   protected void addConjunctiveFeatures(FeatureVector v, String domain, String name,
47 |       TreePatternAndRange pattern) {
48 |     List<PathEntry> path = pattern.getPath();
49 |     v.add(domain, name);
50 |     v.add(domain, name + " | tag = " + path.get(path.size() - 1));
51 |   }
52 |
53 | }
54 |
--------------------------------------------------------------------------------
/src/edu/stanford/nlp/semparse/open/model/feature/FeatureTypeLinguisticsBased.java:
--------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.feature;
2 |
3 | import java.util.*;
4 |
5 | import edu.stanford.nlp.semparse.open.ling.BrownClusterTable;
6 | import edu.stanford.nlp.semparse.open.ling.LingData;
7 | import edu.stanford.nlp.semparse.open.ling.LingUtils;
8 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate;
9 | import edu.stanford.nlp.semparse.open.model.candidate.CandidateGroup;
10 | import edu.stanford.nlp.semparse.open.util.Multiset;
11 | import fig.basic.Option;
12 |
13 | /**
14 |  * Extract features by looking at linguistic properties
15 |  */
16 | public class FeatureTypeLinguisticsBased extends FeatureType {
17 |   public static class Options {
18 |     @Option public boolean lingLemmatizedTokens = false;
19 |     @Option public boolean lingWordPOS = false;
20 |     @Option public boolean lingAltWordPOS = true;
21 |     @Option public boolean lingBinWordPOS = true;
22 |     @Option public boolean lingEntityPOS = false;
23 |     @Option public boolean lingCollapsedPOS = true;
24 |   }
25 |   public static Options opts = new Options();
26 |
27 |   @Override
28 |   public void extract(Candidate candidate) {
29 |     // Do nothing
30 |   }
31 |
32 |   @Override
33 |   public void extract(CandidateGroup group) {
34 |     extractLingFeatures(group);
35 |     extractClusterFeatures(group);
36 |     //extractFakeWordVectorFeatures(group);
37 |   }
38 |
39 |   protected void extractLingFeatures(CandidateGroup group) {
40 |     if (isAllowedDomain("ling")) {
41 |       // Counts!
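   |       // Tally, over all predicted entities in the group: the POS tag of each
   |       // word, the POS sequence of each entity (raw and collapsed), and the
   |       // first and last tokens.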
42 |       Multiset<String> countWordPOS = new Multiset<>(),
43 |           countEntityPOS = new Multiset<>(),
44 |           countEntityCollapsedPOS = new Multiset<>(),
45 |           countFirstToken = new Multiset<>(),
46 |           countLastToken = new Multiset<>();
47 |       for (String entity : group.predictedEntities) {
48 |         LingData lingData = LingData.get(entity);
49 |         if (lingData.length > 0) {
50 |           for (String pos : lingData.posTags) countWordPOS.add(pos);
51 |           // POS
52 |           countEntityPOS.add(LingUtils.join(lingData.posTags));
53 |           countEntityCollapsedPOS.add(LingUtils.collapse(lingData.posTags));
54 |           // Tokens
55 |           if (opts.lingLemmatizedTokens) {
56 |             countFirstToken.add(lingData.lemmaTokens.get(0));
57 |             countLastToken.add(lingData.lemmaTokens.get(lingData.length - 1));
58 |           } else {
59 |             countFirstToken.add(lingData.tokens.get(0));
60 |             countLastToken.add(lingData.tokens.get(lingData.length - 1));
61 |           }
62 |         }
63 |       }
64 |       if (opts.lingWordPOS)
65 |         addVotingFeatures(group.features, "ling", "word-pos", countWordPOS);
66 |       if (opts.lingAltWordPOS) {
67 |         addEntropyFeatures(group.features, "ling", "word-pos", countWordPOS);
68 |         for (String pos : countWordPOS.elementSet()) {
69 |           if (opts.lingBinWordPOS) {
70 |             addPercentFeatures(group.features, "ling", "word-pos = " + pos,
71 |                 countWordPOS.count(pos) * 1.0 / countWordPOS.size());
72 |           } else {
73 |             group.features.add("ling", "word-pos = " + pos);
74 |           }
75 |         }
76 |       }
77 |       if (opts.lingEntityPOS)
78 |         addVotingFeatures(group.features, "ling", "entity-pos", countEntityPOS);
79 |       if (opts.lingCollapsedPOS)
80 |         addVotingFeatures(group.features, "ling", "entity-collapsed-pos", countEntityCollapsedPOS);
81 |       addEntropyFeatures(group.features, "ling", "first-token", countFirstToken);
82 |       addEntropyFeatures(group.features, "ling", "last-token", countLastToken);
83 |     }
84 |   }
85 |
86 |   protected void extractClusterFeatures(CandidateGroup group) {
87 |     if (isAllowedDomain("cluster")) {
88 |       // Query cluster prefixes
89 |       Set<String> queryClusters = new HashSet<>(),
90 |           entityPrefixes = new HashSet<>();
91 |       for (String token : LingData.get(group.ex.phrase).getTokens(true, true)) {
92 |         queryClusters.addAll(BrownClusterTable.getDefaultClusterPrefixesFromWord(token));
93 |         queryClusters.add(token); // Also add the raw token
94 |       }
95 |       // Entity cluster
96 |       Multiset<String> entityTokenClusters = new Multiset<>();
97 |       for (String entity : group.predictedEntities) {
98 |         for (String token : LingData.get(entity).tokens) {
99 |           String cluster = BrownClusterTable.getCluster(token);
100 |           if (cluster != null) {
101 |             entityTokenClusters.add(cluster);
102 |             for (String prefix : BrownClusterTable.getDefaultClusterPrefixes(cluster))
103 |               entityPrefixes.add(prefix);
104 |           }
105 |         }
106 |       }
107 |       // Add features
108 |       for (String queryCluster : queryClusters) {
109 |         for (String prefix : entityPrefixes) {
110 |           group.features.add("cluster", "query = " + queryCluster + " | entity ~ " + prefix);
111 |         }
112 |         // Entity Entropy
113 |         double normalizedEntropy = getNormalizedEntropy(entityTokenClusters);
114 |         group.features.add("cluster", "query = " + queryCluster + " | entity-normalized-entropy", normalizedEntropy);
115 |       }
116 |     }
117 |   }
118 |
119 |   /** Used to debug the advanced word vector. Basically, this is the slower version.
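   |  * It materializes every cross term x[i] * y[j] as an explicit feature, which
   |  * is what the bilinear score in AdvancedWordVectorParamsFullRank computes
   |  * implicitly.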
**/ 120 | protected void extractFakeWordVectorFeatures(CandidateGroup group) { 121 | if (isAllowedDomain("fake-wordvec")) { 122 | group.ex.initAveragedWordVector(); 123 | group.initAveragedWordVector(); 124 | double[] x = group.ex.averagedWordVector.averaged; 125 | double[] y = group.averagedWordVector.averaged; 126 | if (x == null || y == null) return; 127 | for (int i = 0; i < x.length; i++) { 128 | for (int j = 0; j < y.length; j++) { 129 | group.features.add("fake-wordvec", "[" + i + "][" + j + "]", x[i] * y[j]); 130 | } 131 | } 132 | } 133 | } 134 | 135 | } 136 |
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/model/feature/FeatureTypeNaiveEntityBased.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.feature; 2 | 3 | import java.util.*; 4 | 5 | import edu.stanford.nlp.semparse.open.ling.FrequencyTable; 6 | import edu.stanford.nlp.semparse.open.ling.LingUtils; 7 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate; 8 | import edu.stanford.nlp.semparse.open.model.candidate.CandidateGroup; 9 | import edu.stanford.nlp.semparse.open.util.Multiset; 10 | import fig.basic.MapUtils; 11 | import fig.basic.Option; 12 | 13 | /** 14 | * Extract features by looking at the selected entities (strings). 15 | */ 16 | public class FeatureTypeNaiveEntityBased extends FeatureType { 17 | public static class Options { 18 | @Option public boolean useCountEntities = false; 19 | @Option public boolean addPhraseShapeFeature = true; 20 | @Option public boolean addCollapsedPhraseShapeFeature = false; 21 | @Option public boolean useDiscreteCountNumWords = true; 22 | @Option public boolean useMeanSDCountNumWords = true; 23 | } 24 | public static Options opts = new Options(); 25 | 26 | @Override 27 | public void extract(Candidate candidate) { 28 | // Do nothing 29 | } 30 | 31 | @Override 32 | public void extract(CandidateGroup group) { 33 | extractEntityFeatures(group); 34 | extractDocumentFrequencyFeatures(group); 35 | } 36 | 37 | protected void extractEntityFeatures(CandidateGroup group) { 38 | if (isAllowedDomain("entity")) { 39 | Multiset<String> countEntity = new Multiset<>(), 40 | countPhraseShape = new Multiset<>(), 41 | countCollapsedPhraseShape = new Multiset<>(), 42 | countWordShape = new Multiset<>(); 43 | Multiset<Integer> countNumWord = new Multiset<>(); 44 | for (String entity : group.predictedEntities) { 45 | countEntity.add(entity); 46 | String wordForm = LingUtils.computePhraseShape(entity); 47 | countPhraseShape.add(wordForm); 48 | String[] wordForms = wordForm.split(" "); 49 | for (String word : wordForms) { 50 | countWordShape.add(word); 51 | } 52 | countNumWord.add(wordForms.length); 53 | countCollapsedPhraseShape.add(LingUtils.collapse(wordForms)); 54 | } 55 | if (opts.useCountEntities) 56 | addQuantizedFeatures(group.features, "entity", "num-entities", group.predictedEntities.size()); 57 | addEntropyFeatures(group.features, "entity", "entity", countEntity); 58 | addDuplicationFeatures(group.features, "entity", "entity", countEntity); 59 | if (opts.addPhraseShapeFeature) 60 | addVotingFeatures(group.features, "entity", "phrase-shape", countPhraseShape); 61 | if (opts.addCollapsedPhraseShapeFeature) 62 | addVotingFeatures(group.features, "entity", "collapsed-phrase-shape", countCollapsedPhraseShape); 63 | addVotingFeatures(group.features, "entity", "word-shape", countWordShape); 64 | if (opts.useDiscreteCountNumWords) 65 |
addVotingFeatures(group.features, "entity", "num-word", countNumWord); 66 | if (opts.useMeanSDCountNumWords) 67 | addMeanDeviationFeatures(group.features, "entity", "num-word", countNumWord); 68 | } 69 | } 70 | 71 | protected static Set<String> BADWORDS = new HashSet<>(Arrays.asList( 72 | "com", "new", "about", "my", "home", "search", 73 | "information", "view", "page", "site", "click", 74 | "http", "contact", "www", "ord", "free", "now", "subscribe", 75 | "see", "service", "services", "online", "re", "data", 76 | "email", "top", "find", "system", "support", "comments", 77 | "policy", "last", "privacy", "post", "date", "time", "print")); 78 | 79 | 80 | protected void extractDocumentFrequencyFeatures(CandidateGroup group) { 81 | if (isAllowedDomain("token-freq")) { 82 | // Oracle experiment: use a fixed set of words 83 | int numTokens = 0, numBadWords = 0, numDigits = 0; 84 | Map<String, Integer> numFrequentWords = new HashMap<>(); 85 | for (String entity : group.predictedEntities) { 86 | for (String token : LingUtils.getAlphaOrNumericTokens(entity)) { 87 | numTokens++; 88 | for (Map.Entry<String, Set<String>> entry : FrequencyTable.topWordsLists.entrySet()) { 89 | if (entry.getValue().contains(token)) { 90 | MapUtils.incr(numFrequentWords, entry.getKey()); 91 | } 92 | } 93 | if (BADWORDS.contains(token)) numBadWords++; 94 | if (token.matches("\\d+")) numDigits++; 95 | } 96 | } 97 | addPercentFeatures(group.features, "token-freq", "bad-words-ratio", numBadWords * 1.0 / numTokens); 98 | for (Map.Entry<String, Integer> entry : numFrequentWords.entrySet()) { 99 | addPercentFeatures(group.features, "token-freq", "frequent-" + entry.getKey(), entry.getValue() * 1.0 / numTokens); 100 | } 101 | addPercentFeatures(group.features, "token-freq", "digits-ratio", numDigits * 1.0 / numTokens); 102 | } 103 | } 104 | } 105 |
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/model/feature/FeatureTypeNodeBased.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.feature; 2 | 3 | import java.util.*; 4 | 5 | import edu.stanford.nlp.semparse.open.model.FeatureVector; 6 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate; 7 | import edu.stanford.nlp.semparse.open.model.candidate.CandidateGroup; 8 | import edu.stanford.nlp.semparse.open.model.tree.KNode; 9 | import edu.stanford.nlp.semparse.open.util.Multiset; 10 | import fig.basic.Option; 11 | 12 | /** 13 | * Extract features by looking at the selected nodes in the knowledge tree.
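 * Statistics are collected over the selected nodes themselves, then over their parents, and so
 * on, up to maxAncestorCount levels up the tree.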
14 | */ 15 | public class FeatureTypeNodeBased extends FeatureType { 16 | public static class Options { 17 | @Option public boolean rangeUseCollapsedTimestamp = true; 18 | @Option public boolean soaUseIndexedFeatures = false; 19 | @Option public boolean soaUseNoIndexFeatures = true; 20 | @Option public boolean soaUseIdClassFeatures = true; 21 | @Option public boolean soaAverage = false; 22 | } 23 | public static Options opts = new Options(); 24 | 25 | @Override 26 | public void extract(Candidate candidate) { 27 | // Do nothing 28 | } 29 | 30 | @Override 31 | public void extract(CandidateGroup group) { 32 | extractSelfOrAncestorsFeatures(group); 33 | extractNodeRangeFeatures(group); 34 | } 35 | 36 | public void extractSelfOrAncestorsFeatures(CandidateGroup group) { 37 | if (isAllowedDomain("self-or-ancestors")) { 38 | FeatureVector v = new FeatureVector(); 39 | // Majority id / class / number of children of the nodes and parents 40 | Set<KNode> currentKNodes = new HashSet<>(group.selectedNodes); 41 | for (int ancestorCount = 0; ancestorCount < FeatureType.opts.maxAncestorCount; ancestorCount++) { 42 | Multiset<String> countTag = new Multiset<>(), 43 | countId = new Multiset<>(), 44 | countClass = new Multiset<>(), 45 | countNumChildren = new Multiset<>(); 46 | Multiset<Integer> countChildIndex = new Multiset<>(); 47 | Set<KNode> parents = new HashSet<>(); 48 | for (KNode node : currentKNodes) { 49 | // Properties of the current node 50 | countTag.add(node.value); 51 | String nodeId = node.getAttribute("id"); 52 | if (!nodeId.isEmpty()) 53 | countId.add(nodeId); 54 | String nodeClass = node.getAttribute("class"); 55 | if (!nodeClass.isEmpty()) 56 | countClass.add(nodeClass); 57 | // Properties relating to children 58 | List<KNode> children = node.getChildren(); 59 | int numChildren = children.size(); 60 | countNumChildren.add((numChildren <= 3) ?
"" + numChildren : "many"); 61 | // Traverse up to parent 62 | if (node.parent != null) { 63 | countChildIndex.add(node.getChildIndex()); 64 | parents.add(node.parent); 65 | } 66 | } 67 | if (parents.isEmpty()) break; 68 | // Count how many children the parents have 69 | int countChildrenOfParents = 0; 70 | for (KNode parent : parents) { 71 | countChildrenOfParents += parent.countChildren(); 72 | } 73 | double percentChildrenOfParents = currentKNodes.size() * 1.0 / countChildrenOfParents; 74 | String domain = "self-or-ancestors"; 75 | // With indexed prefix 76 | if (opts.soaUseIndexedFeatures) { 77 | String prefix = "(n-" + ancestorCount + ")-"; 78 | addVotingFeatures(v, domain, prefix + "tag", countTag); 79 | if (opts.soaUseIdClassFeatures) { 80 | addVotingFeatures(v, domain, prefix + "id", countId, true); 81 | addVotingFeatures(v, domain, prefix + "class", countClass, true); 82 | } 83 | addVotingFeatures(v, domain, prefix + "num-children", countNumChildren); 84 | addVotingFeatures(v, domain, prefix + "child-index", countChildIndex, false, false); 85 | addPercentFeatures(v, domain, prefix + "children-of-parent", percentChildrenOfParents); 86 | if (parents.size() == 1) v.add(domain, prefix + "same-parent"); 87 | } 88 | // Without indexed prefix 89 | if (opts.soaUseNoIndexFeatures) { 90 | addVotingFeatures(v, domain, "tag", countTag); 91 | if (opts.soaUseIdClassFeatures) { 92 | addVotingFeatures(v, domain, "id", countId, true); 93 | addVotingFeatures(v, domain, "class", countClass, true); 94 | } 95 | addVotingFeatures(v, domain, "num-children", countNumChildren); 96 | addVotingFeatures(v, domain, "child-index", countChildIndex, false, false); 97 | addPercentFeatures(v, domain, "children-of-parent", percentChildrenOfParents); 98 | if (parents.size() == 1) v.add(domain, "same-parent"); 99 | } 100 | // Traverse up the tree 101 | currentKNodes = parents; 102 | } 103 | // Add features 104 | if (opts.soaAverage) { 105 | for (Map.Entry entry : v.toMap().entrySet()) { 106 | group.features.addFromString(entry.getKey(), entry.getValue() / FeatureType.opts.maxAncestorCount); 107 | } 108 | } else { 109 | group.features.add(v); 110 | } 111 | } 112 | } 113 | 114 | public void extractNodeRangeFeatures(CandidateGroup group) { 115 | if (isAllowedDomain("node-range")) { 116 | List selectedKNodes = group.selectedNodes; 117 | // Find root 118 | KNode root = selectedKNodes.get(0); 119 | while (root.timestampIn != 1) { 120 | root = root.parent; 121 | } 122 | int rangeMinTimestamp, rangeMaxTimestamp, pageMaxTimestamp; 123 | if (opts.rangeUseCollapsedTimestamp) { 124 | // Use only timestampIn (collapsed) 125 | rangeMinTimestamp = selectedKNodes.get(0).timestampInCollapsed; 126 | rangeMaxTimestamp = selectedKNodes.get(selectedKNodes.size() - 1).timestampInCollapsed; 127 | pageMaxTimestamp = root.timestampOut / 2; 128 | } else { 129 | // Use the original timestamps 130 | rangeMinTimestamp = selectedKNodes.get(0).timestampIn; 131 | rangeMaxTimestamp = selectedKNodes.get(selectedKNodes.size() - 1).timestampOut; 132 | pageMaxTimestamp = root.timestampOut; 133 | } 134 | addPercentFeatures(group.features, "node-range", "length", 135 | (rangeMaxTimestamp - rangeMinTimestamp) * 1.0 / pageMaxTimestamp); 136 | addPercentFeatures(group.features, "node-range", "start", 137 | (rangeMinTimestamp) * 1.0 / pageMaxTimestamp); 138 | addPercentFeatures(group.features, "node-range", "end", 139 | (rangeMaxTimestamp) * 1.0 / pageMaxTimestamp); 140 | } 141 | } 142 | 143 | } 144 | 
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/model/feature/FeatureTypePathBased.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.feature; 2 | 3 | import java.util.List; 4 | 5 | import edu.stanford.nlp.semparse.open.model.candidate.Candidate; 6 | import edu.stanford.nlp.semparse.open.model.candidate.CandidateGroup; 7 | import edu.stanford.nlp.semparse.open.model.candidate.PathEntry; 8 | import edu.stanford.nlp.semparse.open.model.candidate.PathUtils; 9 | import edu.stanford.nlp.semparse.open.model.candidate.TreePattern; 10 | import fig.basic.Option; 11 | 12 | /** 13 | * Extract features by looking at the XPath 14 | */ 15 | public class FeatureTypePathBased extends FeatureType { 16 | public static class Options { 17 | @Option public boolean pathFeatureUsePrefix = true; 18 | @Option public int pathMaxAncestorCount = 0; 19 | @Option public boolean pathUsePathSuffix = true; 20 | } 21 | public static Options opts = new Options(); 22 | 23 | @Override 24 | public void extract(Candidate candidate) { 25 | extractPathTailFeatures(candidate); 26 | } 27 | 28 | @Override 29 | public void extract(CandidateGroup group) { 30 | // Do nothing 31 | } 32 | 33 | protected void extractPathTailFeatures(Candidate candidate) { 34 | if (isAllowedDomain("path-tail")) { 35 | TreePattern pattern = candidate.pattern; 36 | List<PathEntry> path = pattern.getPath(); 37 | int maxCount = opts.pathMaxAncestorCount > 0 ? opts.pathMaxAncestorCount : FeatureType.opts.maxAncestorCount; 38 | for (int ancestorCount = 1; ancestorCount <= maxCount; ancestorCount++) { 39 | if (ancestorCount <= path.size()) { 40 | if (opts.pathUsePathSuffix) 41 | candidate.features.add("path-tail", PathUtils.getXPathSuffixStringNoIndex(path, ancestorCount)); 42 | PathEntry entry = path.get(path.size() - ancestorCount); 43 | if (opts.pathFeatureUsePrefix) { 44 | String prefix = "(n-" + ancestorCount + ")-"; 45 | candidate.features.add("path-tail", prefix + "tag = " + entry.tag); 46 | candidate.features.add("path-tail", prefix + "indexed = " + (entry.index != -1)); 47 | candidate.features.add("path-tail", prefix + "tag-indexed = " + entry.tag + " " + (entry.index != -1)); 48 | } else { 49 | candidate.features.add("path-tail", "tag = " + entry.tag); 50 | candidate.features.add("path-tail", "indexed = " + (entry.index != -1)); 51 | candidate.features.add("path-tail", "tag-indexed = " + entry.tag + " " + (entry.index != -1)); 52 | } 53 | } 54 | } 55 | } 56 | } 57 | 58 | } 59 |
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/model/tree/HTMLFixer.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.tree; 2 | 3 | import java.util.*; 4 | 5 | import org.jsoup.nodes.Document; 6 | import org.jsoup.nodes.Element; 7 | import org.jsoup.nodes.Node; 8 | import org.jsoup.select.Elements; 9 | 10 | 11 | import fig.basic.LogInfo; 12 | 13 | /** 14 | * Fix problematic HTML structures that decrease our accuracy. 15 | * 16 | * All fixes are done in place, so the document will be mutated.
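 *
 * For example, a cell declared as <td colspan=3> gets two empty sibling <td></td> cells added
 * after it so that every row spans the same number of columns; rowspan is expanded the same way
 * down the rows, and text separated by <br> tags is rewrapped into <p> elements.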
17 | */ 18 | public class HTMLFixer { 19 | 20 | private Document document; 21 | 22 | public HTMLFixer(Document doc) { 23 | this.document = doc; 24 | } 25 | 26 | // ============================================================ 27 | // Fix Table (colspan / rowspan) 28 | // ============================================================ 29 | 30 | public void fixAllTables() { 31 | LogInfo.begin_track("Fix table ..."); 32 | for (Element table : document.getElementsByTag("tbody")) { 33 | fixTable(table); 34 | } 35 | LogInfo.end_track(); 36 | } 37 | 38 | /** 39 | * Normalize colspan and rowspan in the table 40 | * @param tbody An Element with tag name "tbody" 41 | */ 42 | private void fixTable(Element tbody) { 43 | // Fix colspan 44 | int numColumns = 0; 45 | for (Element tr : tbody.children()) { 46 | for (Element cell : new ArrayList<>(tr.children())) { 47 | int colspan = parseIntHard(cell.attr("colspan")), rowspan = parseIntHard(cell.attr("rowspan")); 48 | if (colspan <= 1) continue; 49 | cell.attr("old-colspan", cell.attr("colspan")); 50 | cell.removeAttr("colspan"); 51 | String tagName = cell.tagName(); 52 | for (int i = 2; i <= colspan; i++) { 53 | if (rowspan <= 1) 54 | cell.after(String.format("<%s></%s>", tagName, tagName)); 55 | else 56 | cell.after(String.format("<%s rowspan=%d></%s>", tagName, rowspan, tagName)); 57 | } 58 | } 59 | numColumns = Math.max(numColumns, tr.children().size()); 60 | } 61 | // Fix rowspan (assuming each column has 1 cell without colspan) 62 | int[] counts = new int[numColumns]; // For each column, track how many rows we should create new elements for 63 | String[] tags = new String[numColumns]; // For each column, track what type of elements to create 64 | for (Element tr : tbody.children()) { 65 | Element currentCell = null; 66 | List<Element> cells = new ArrayList<>(tr.children()); 67 | for (int i = 0, k = 0; i < numColumns; i++) { 68 | if (counts[i] > 0) { 69 | // Create a new element caused by rowspan 70 | String newCell = String.format("<%s></%s>", tags[i], tags[i]); 71 | if (currentCell == null) 72 | tr.prepend(newCell); 73 | else 74 | currentCell.after(newCell); 75 | counts[i]--; 76 | } else { 77 | if (k >= cells.size()) continue; // Unfilled row 78 | currentCell = cells.get(k++); 79 | int rowSpan = parseIntHard(currentCell.attr("rowspan")); 80 | if (rowSpan <= 1) continue; 81 | counts[i] = rowSpan - 1; 82 | tags[i] = currentCell.tagName(); 83 | currentCell.attr("old-rowspan", currentCell.attr("rowspan")); 84 | currentCell.removeAttr("rowspan"); 85 | } 86 | } 87 | } 88 | } 89 | 90 | private int parseIntHard(String s) { 91 | if (s.isEmpty()) return 0; 92 | try { 93 | return Integer.parseInt(s); 94 | } catch (NumberFormatException e) { 95 | return 0; 96 | } 97 | } 98 | 99 | // ============================================================ 100 | // Fix BR 101 | // ============================================================ 102 | 103 | public void fixAllBRs() { 104 | LogInfo.begin_track("Fix BR ..."); 105 | Elements brList; 106 | while (!(brList = document.getElementsByTag("br")).isEmpty()) { 107 | fixBR(brList.get(0).parent()); 108 | } 109 | LogInfo.end_track(); 110 | } 111 | 112 | /** 113 | * Fix BR tags by wrapping each part in P tag instead 114 | */ 115 | private void fixBR(Element parent) { 116 | List<Node> childNodes = parent.childNodesCopy(); 117 | while (parent.childNodeSize() > 0) { 118 | Node child = parent.childNode(0); 119 | child.remove(); 120 | } 121 | Element currentChild = document.createElement("p"); 122 | for (Node node : childNodes) { 123 | if (node instanceof Element &&
"br".equals(((Element) node).tagName())) { 124 | parent.appendChild(currentChild); 125 | currentChild = document.createElement("p"); 126 | } else { 127 | currentChild.appendChild(node); 128 | } 129 | } 130 | parent.appendChild(currentChild); 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/model/tree/KNode.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.model.tree; 2 | 3 | import java.util.*; 4 | 5 | import fig.basic.LogInfo; 6 | import fig.basic.StrUtils; 7 | 8 | /** 9 | * A node in the knowledge tree. 10 | * 11 | * Nodes can be one of the following: 12 | *
* <ul> 13 | * <li> type = QUERY, value = search query (e.g., "us cities")
14 | * <li> type = URL, value = url (e.g., http://en.wikipedia.org/wiki/List_of_United_States_cities_by_population)
15 | * <li> type = TAG, value = tag of the DOM node (e.g., "h1")
16 | * <li> type = ATTR, value = attribute (e.g., "class")
17 | * <li> type = TEXT, value = a string (e.g., "San Francisco")
18 | * </ul>
19 | */ 20 | public class KNode { 21 | public enum Type { QUERY, URL, TAG, ATTR, TEXT }; 22 | 23 | public final Type type; 24 | public final String value; 25 | private final List<KNode> children; 26 | private final List<KNode> attributes; 27 | 28 | // fullText == '' if the node is empty 29 | // fullText == null if the full text is longer than the specified length. 30 | public final String fullText; 31 | 32 | // parent of root node is null 33 | public final KNode parent; 34 | 35 | // depth of root node is 0 36 | public final int depth; 37 | 38 | // timestamps of depth first search (used for firing range features) 39 | public int timestampIn, timestampOut, timestampInCollapsed; 40 | 41 | public KNode(KNode parent, Type type, String value) { 42 | this(parent, type, value, ""); 43 | } 44 | 45 | public KNode(KNode parent, Type type, String value, String fullText) { 46 | this.type = type; 47 | this.value = value; 48 | this.children = new ArrayList<>(); 49 | this.attributes = new ArrayList<>(); 50 | this.fullText = fullText; 51 | 52 | this.parent = parent; 53 | if (this.parent == null) { 54 | this.depth = 0; 55 | } else { 56 | this.depth = this.parent.depth + 1; 57 | if (type == Type.ATTR) { 58 | this.parent.attributes.add(this); 59 | } else { 60 | this.parent.children.add(this); 61 | } 62 | } 63 | } 64 | 65 | /** 66 | * Create a child and return the child. 67 | */ 68 | public KNode createChild(Type type, String value) { 69 | return new KNode(this, type, value); 70 | } 71 | 72 | /** 73 | * Create a child and return the child. 74 | */ 75 | public KNode createChild(Type type, String value, String fullText) { 76 | return new KNode(this, type, value, fullText); 77 | } 78 | 79 | /** 80 | * Create a child using the information from the `original` node and return the child. 81 | */ 82 | public KNode createChild(KNode original) { 83 | return new KNode(this, original.type, original.value, original.fullText); 84 | } 85 | 86 | public KNode createAttribute(String attributeName, String attributeValue) { 87 | KNode attributeNode = createChild(Type.ATTR, attributeName, attributeValue); 88 | attributeNode.createChild(Type.TEXT, attributeValue, attributeValue); 89 | return attributeNode; 90 | } 91 | 92 | // Getters 93 | 94 | public List<KNode> getChildren() { 95 | return Collections.unmodifiableList(children); 96 | } 97 | 98 | public List<KNode> getChildrenOfTag(String tag) { 99 | List<KNode> answer = new ArrayList<>(); 100 | for (KNode child : children) { 101 | if (child.type == Type.TAG && (tag.equals(child.value) || tag.equals("*"))) { 102 | answer.add(child); 103 | } 104 | } 105 | return answer; 106 | } 107 | 108 | // The index is 0-based 109 | public KNode getChildrenOfTag(String tag, int index) { 110 | int count = 0; 111 | for (KNode child : children) { 112 | if (child.type == Type.TAG && (tag.equals(child.value) || tag.equals("*"))) { 113 | if (count == index) return child; 114 | count++; 115 | } 116 | } 117 | return null; 118 | } 119 | 120 | public List<KNode> getAttributes() { 121 | return Collections.unmodifiableList(attributes); 122 | } 123 | 124 | public String getAttribute(String attributeName) { 125 | for (KNode attributeNode : attributes) { 126 | if (attributeNode.value.equals(attributeName)) 127 | return attributeNode.fullText; 128 | } 129 | return ""; 130 | } 131 | 132 | public String[] getAttributeList(String attributeName) { 133 | for (KNode attributeNode : attributes) { 134 | if (attributeNode.value.equals(attributeName) && !attributeNode.fullText.isEmpty()) 135 | return attributeNode.fullText.split(" "); 136 | } 137 | return new
String[0]; 138 | } 139 | 140 | // The index is 0-based 141 | public int getChildIndex() { 142 | return this.parent.children.indexOf(this); 143 | } 144 | 145 | // The index is 0-based 146 | public int getChildIndexOfSameTag() { 147 | int count = 0; 148 | for (KNode child : this.parent.children) { 149 | if (child.type == Type.TAG && child.value.equals(this.value)) { 150 | if (child == this) return count; 151 | count++; 152 | } 153 | } 154 | return -1; 155 | } 156 | 157 | public int countChildren() { 158 | return this.children.size(); 159 | } 160 | 161 | public int countChildren(String tag) { 162 | int count = 0; 163 | for (KNode child : children) { 164 | if (child.type == Type.TAG && child.value.equals(tag)) count++; 165 | } 166 | return count; 167 | } 168 | 169 | // Debug Print 170 | 171 | public void debugPrint(int indent) { 172 | LogInfo.logs(StrUtils.repeat(" ", indent) + "%s '%s'", type, value); 173 | for (KNode child : children) { 174 | child.debugPrint(indent + 2); 175 | } 176 | } 177 | 178 | // Timestamp 179 | 180 | public void generateTimestamp() { 181 | generateTimestamp(1); 182 | generateTimestampInCollapsed(1); 183 | } 184 | 185 | protected int generateTimestamp(int currentTimestamp) { 186 | timestampIn = currentTimestamp++; 187 | for (KNode node : children) { 188 | currentTimestamp = node.generateTimestamp(currentTimestamp); 189 | } 190 | timestampOut = currentTimestamp++; 191 | return currentTimestamp; 192 | } 193 | 194 | protected int generateTimestampInCollapsed(int currentTimestampInCollapsed) { 195 | timestampInCollapsed = currentTimestampInCollapsed++; 196 | for (KNode node : children) { 197 | currentTimestampInCollapsed = node.generateTimestampInCollapsed(currentTimestampInCollapsed); 198 | } 199 | return currentTimestampInCollapsed; 200 | } 201 | 202 | } 203 |
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/model/tree/KNodeUtils.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.model.tree; 2 | 3 | import java.util.*; 4 | 5 | public class KNodeUtils { 6 | 7 | /** 8 | * @return The lowest common ancestor of all specified nodes. 9 | */ 10 | public static KNode lowestCommonAncestor(Collection<KNode> nodes) { 11 | int minDepth = Integer.MAX_VALUE; 12 | for (KNode node : nodes) 13 | minDepth = Math.min(minDepth, node.depth); 14 | Set<KNode> sameDepthAncestors = new HashSet<>(); 15 | for (KNode node : nodes) { 16 | while (node.depth != minDepth) node = node.parent; 17 | sameDepthAncestors.add(node); 18 | } 19 | while (sameDepthAncestors.size() != 1) { 20 | if (minDepth == 0) return null; 21 | Set<KNode> newSameDepthAncestors = new HashSet<>(); 22 | for (KNode node : sameDepthAncestors) newSameDepthAncestors.add(node.parent); 23 | sameDepthAncestors = newSameDepthAncestors; 24 | } 25 | return sameDepthAncestors.iterator().next(); 26 | } 27 | 28 | public static boolean isDescendantOf(KNode allegedDescendant, KNode node) { 29 | KNode currentNode = allegedDescendant; 30 | while (currentNode != null) { 31 | if (currentNode == node) return true; 32 | currentNode = currentNode.parent; 33 | } 34 | return false; 35 | } 36 | 37 | /** 38 | * Print the tree to standard error. Useful for debugging.
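 * Text nodes print their stored text ("..." when the full text was not kept); other nodes
 * print as <tag> ... </tag> with two spaces of indentation per level.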
39 | */ 40 | public static void printTree(KNode node) { printTree(node, 0); } 41 | 42 | public static void printTree(KNode node, int indent) { 43 | if ("text".equals(node.value)) { 44 | String text = node.fullText; 45 | if (text == null) text = "..."; 46 | System.err.printf("%s%s\n", new String(new char[indent]).replace('\0', ' '), text); 47 | return; 48 | } 49 | System.err.printf("%s<%s>\n", new String(new char[indent]).replace('\0', ' '), node.value); 50 | for (KNode child : node.getChildren()) { 51 | printTree(child, indent + 2); 52 | } 53 | System.err.printf("%s</%s>\n", new String(new char[indent]).replace('\0', ' '), node.value); 54 | } 55 | 56 | /** 57 | * Copy a KNode and its subtree to a new parent. 58 | */ 59 | public static KNode copyTree(KNode node, KNode newParent) { 60 | KNode newNode = newParent.createChild(node); 61 | for (KNode x : node.getChildren()) copyTree(x, newNode); 62 | for (KNode x : node.getAttributes()) copyTree(x, newNode); 63 | return newNode; 64 | } 65 | 66 | } 67 |
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/util/BipartiteMatcher.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.util; 2 | 3 | import java.util.*; 4 | 5 | import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntity; 6 | 7 | public class BipartiteMatcher { 8 | 9 | private final int SOURCE = 1; 10 | private final int SINK = -1; 11 | 12 | private final Map<Object, Integer> fromMap; 13 | private final Map<Object, Integer> toMap; 14 | private final Map<Integer, List<Integer>> edges; 15 | 16 | public BipartiteMatcher() { 17 | this.fromMap = new HashMap<>(); 18 | this.toMap = new HashMap<>(); 19 | this.edges = new HashMap<>(); 20 | } 21 | 22 | public BipartiteMatcher(List<TargetEntity> targetEntities, List<String> predictedEntities) { 23 | this(); 24 | for (int i = 0; i < targetEntities.size(); i++) { 25 | TargetEntity targetEntity = targetEntities.get(i); 26 | for (int j = 0; j < predictedEntities.size(); j++) { 27 | if (targetEntity.match(predictedEntities.get(j))) { 28 | this.addEdge(i, j); 29 | } 30 | } 31 | } 32 | } 33 | 34 | public void addEdge(Object fromObj, Object toObj) { 35 | Integer from = fromMap.get(fromObj), to = toMap.get(toObj); 36 | if (from == null) { 37 | from = 2 + fromMap.size(); 38 | fromMap.put(fromObj, from); 39 | if (!edges.containsKey(SOURCE)) edges.put(SOURCE, new ArrayList<>()); 40 | edges.get(SOURCE).add(from); 41 | } 42 | if (to == null) { 43 | to = -2 - toMap.size(); 44 | toMap.put(toObj, to); 45 | if (!edges.containsKey(to)) edges.put(to, new ArrayList<>()); 46 | edges.get(to).add(SINK); 47 | } 48 | if (!edges.containsKey(from)) edges.put(from, new ArrayList<>()); 49 | edges.get(from).add(to); 50 | } 51 | 52 | private List<Integer> foundPath; 53 | private Set<Integer> foundNodes; 54 | 55 | public int findMaximumMatch() { 56 | int count = 0; 57 | this.foundPath = new ArrayList<>(); 58 | this.foundNodes = new HashSet<>(); 59 | while (findPath(SOURCE)) { 60 | count++; 61 | for (int i = 0; i < foundPath.size() - 1; i++) { 62 | int from = foundPath.get(i), to = foundPath.get(i+1); 63 | edges.get(from).remove(Integer.valueOf(to)); 64 | if (!edges.containsKey(to)) edges.put(to, new ArrayList<>()); 65 | edges.get(to).add(from); 66 | } 67 | foundPath.clear(); 68 | foundNodes.clear(); 69 | } 70 | return count; 71 | } 72 | 73 | private boolean findPath(int node) { 74 | // DFS 75 | foundNodes.add(node); 76 | foundPath.add(node); 77 | if (node == SINK) return true; 78 | for (int dest : edges.get(node)) { 79 | if
(!foundNodes.contains(dest)) { 80 | if (findPath(dest)) return true; 81 | } 82 | } 83 | foundPath.remove(foundPath.size() - 1); 84 | return false; 85 | } 86 | 87 | public static void main(String[] args) { 88 | // Test Method 89 | BipartiteMatcher bm = new BipartiteMatcher(); 90 | bm.addEdge("A", 1); bm.addEdge("A", 2); bm.addEdge("A", 4); 91 | bm.addEdge("B", 1); bm.addEdge("C", 2); bm.addEdge("C", 1); 92 | bm.addEdge("D", 4); bm.addEdge("D", 5); bm.addEdge("E", 3); 93 | System.out.println(bm.findMaximumMatch()); 94 | } 95 | 96 | } 97 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/util/EditDistance.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.util; 2 | 3 | public class EditDistance { 4 | 5 | private static int min3(int x1, int x2, int x3) { 6 | return Math.min(Math.min(x1, x2), x3); 7 | } 8 | 9 | /** 10 | * Return true if |a| and |b| are within Levenshtein edit distance |limit|. 11 | */ 12 | public static boolean withinEditDistance(String a, String b, int limit) { 13 | if (a == null || b == null || Math.abs(a.length() - b.length()) > limit) return false; 14 | // memory array: 0 limit-1 limit limit+1 2*limit 15 | // actual column: (i-limit) .. (i-1) i (i+1) .. (i+limit) 16 | int[] memory = new int[2 * limit + 1]; 17 | int INFINITY = limit + 1; 18 | // First row 19 | for (int j = 0; j < limit; j++) memory[j] = INFINITY; 20 | for (int j = 0; j <= limit; j++) memory[limit + j] = j; 21 | // Consequent rows 22 | for (int i = 0; i < a.length(); i++) { 23 | int[] newMemory = new int[2 * limit + 1]; 24 | for (int j = 0; j <= 2 * limit; j++) { 25 | int actualJ = i + j - limit; 26 | if (actualJ < 0 || actualJ >= b.length()) { 27 | newMemory[j] = INFINITY; 28 | } else { 29 | newMemory[j] = min3( 30 | (j-1 < 0 ? INFINITY : newMemory[j-1] + 1), 31 | (j+1 > 2*limit ? INFINITY : memory[j+1] + 1), 32 | memory[j] + (a.charAt(i) == b.charAt(actualJ) ? 
0 : 1)); 33 | } 34 | } 35 | memory = newMemory; 36 | } 37 | return memory[limit + (b.length() - a.length())] <= limit; 38 | } 39 | 40 | public static void main(String[] args) { 41 | System.out.println(withinEditDistance("this", "this", 0)); 42 | System.out.println(withinEditDistance("this", "this", 3)); 43 | System.out.println(withinEditDistance("this", "that", 2)); 44 | System.out.println(withinEditDistance("this", "these", 2)); 45 | System.out.println(withinEditDistance("why?", "who!", 1)); 46 | System.out.println(withinEditDistance("This is a book", "That is a Book", 3)); 47 | System.out.println(withinEditDistance("This is a book", "That is a Books", 4)); 48 | System.out.println(withinEditDistance("hello", "HELLO", 3)); 49 | System.out.println(withinEditDistance("hello", "\"hello\"", 2)); 50 | System.out.println(withinEditDistance("cafe", "Café", 2)); 51 | } 52 | 53 | 54 | } 55 |
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/util/Multiset.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.util; 2 | 3 | import java.util.*; 4 | 5 | public class Multiset<T> { 6 | protected final HashMap<T, Integer> map = new HashMap<>(); 7 | protected int size = 0; 8 | 9 | public void add(T entry) { 10 | Integer count = map.get(entry); 11 | if (count == null) 12 | count = 0; 13 | map.put(entry, count + 1); 14 | size++; 15 | } 16 | 17 | public void add(T entry, int incr) { 18 | Integer count = map.get(entry); 19 | if (count == null) 20 | count = 0; 21 | map.put(entry, count + incr); 22 | size += incr; 23 | } 24 | 25 | public boolean contains(T entry) { 26 | return map.containsKey(entry); 27 | } 28 | 29 | public int count(T entry) { 30 | Integer count = map.get(entry); 31 | if (count == null) 32 | count = 0; 33 | return count; 34 | } 35 | 36 | public Set<T> elementSet() { 37 | return map.keySet(); 38 | } 39 | 40 | public Set<Map.Entry<T, Integer>> entrySet() { 41 | return map.entrySet(); 42 | } 43 | 44 | public int size() { 45 | return size; 46 | } 47 | 48 | public boolean isEmpty() { 49 | return size == 0; 50 | } 51 | 52 | public Multiset<T> getPrunedByCount(int minCount) { 53 | Multiset<T> pruned = new Multiset<>(); 54 | for (Map.Entry<T, Integer> entry : map.entrySet()) { 55 | if (entry.getValue() >= minCount) 56 | pruned.add(entry.getKey(), entry.getValue()); 57 | } 58 | return pruned; 59 | } 60 | }
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/util/Parallelizer.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.util; 2 | 3 | import java.util.List; 4 | import java.util.concurrent.Callable; 5 | import java.util.concurrent.ExecutorService; 6 | import java.util.concurrent.Executors; 7 | import java.util.concurrent.Future; 8 | import java.util.concurrent.TimeUnit; 9 | 10 | import fig.basic.LogInfo; 11 | import fig.basic.Option; 12 | 13 | public class Parallelizer { 14 | public static class Options { 15 | @Option(gloss = "Number of threads for execution") 16 | public int numThreads = 1; 17 | } 18 | public static Options opts = new Options(); 19 | 20 | public static int getNumThreads() { 21 | int numThreads = Runtime.getRuntime().availableProcessors(); 22 | if (opts.numThreads > 0 && numThreads > opts.numThreads) 23 | numThreads = opts.numThreads; 24 | return numThreads; 25 | } 26 | 27 | public static void run(List<Runnable> tasks) { 28 | LogInfo.begin_threads(); 29 |
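// Submit every task to a fixed-size thread pool, then block until all tasks finish (up to one day).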
ExecutorService service = Executors.newFixedThreadPool(getNumThreads()); 30 | try { 31 | for (Runnable task : tasks) { 32 | service.submit(task); 33 | } 34 | service.shutdown(); 35 | service.awaitTermination(1, TimeUnit.DAYS); 36 | } catch (InterruptedException e) { 37 | LogInfo.fail(e); 38 | } 39 | LogInfo.end_threads(); 40 | } 41 | 42 | public static <V, T extends Callable<V>> List<Future<V>> runAndReturnStuff(List<T> tasks) { 43 | LogInfo.begin_threads(); 44 | List<Future<V>> results = null; 45 | ExecutorService service = Executors.newFixedThreadPool(getNumThreads()); 46 | try { 47 | // Invoke all trainers 48 | results = service.invokeAll(tasks); 49 | service.shutdown(); 50 | service.awaitTermination(1, TimeUnit.DAYS); 51 | } catch (InterruptedException e) { 52 | LogInfo.fail(e); 53 | } 54 | LogInfo.end_threads(); 55 | return results; 56 | } 57 | 58 | } 59 |
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/util/SHA.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.util; 2 | 3 | import java.math.BigInteger; 4 | import java.security.MessageDigest; 5 | import java.security.NoSuchAlgorithmException; 6 | 7 | public class SHA { 8 | public static String toSHA1(String input) { 9 | try { 10 | MessageDigest crypt = MessageDigest.getInstance("SHA-1"); 11 | crypt.reset(); 12 | crypt.update(input.getBytes()); 13 | return new BigInteger(1, crypt.digest()).toString(16); 14 | } catch (NoSuchAlgorithmException e) { 15 | throw new RuntimeException(e); 16 | } 17 | } 18 | } 19 |
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/util/SearchResult.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.util; 2 | 3 | // Corresponds to a webpage 4 | public class SearchResult { 5 | // query: how did we get to this web page (possibly null)? 6 | public SearchResult(String query, String url, String title) { 7 | this.query = query; 8 | this.url = url; 9 | this.title = title; 10 | } 11 | public final String query; 12 | public final String url; 13 | public final String title; 14 | 15 | @Override public String toString() { return url; } 16 | }
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/util/StringDoubleArrayList.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.util; 2 | 3 | import java.util.Iterator; 4 | 5 | /** 6 | * A slightly more memory-efficient List of Pair<String, Double>. 7 | * 8 | * Many parts of the code are from http://developer.classpath.org/doc/java/util/ArrayList-source.html 9 | */ 10 | public class StringDoubleArrayList implements Iterable<StringDoublePair> { 11 | 12 | public static final int DEFAULT_CAPACITY = 10; 13 | 14 | private int size; 15 | 16 | // The two data arrays must have equal length.
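// strings[i] pairs with doubles[i]; ensureCapacity always grows both arrays together.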
17 | private String[] strings; 18 | private double[] doubles; 19 | 20 | public StringDoubleArrayList(int capacity) { 21 | if (capacity < 0) 22 | throw new IllegalArgumentException(); 23 | strings = new String[capacity]; 24 | doubles = new double[capacity]; 25 | } 26 | 27 | public StringDoubleArrayList() { 28 | this(DEFAULT_CAPACITY); 29 | } 30 | 31 | public int size() { 32 | return size; 33 | } 34 | 35 | public void ensureCapacity(int minCapacity) { 36 | if (minCapacity - strings.length > 0) { // subtract to prevent overflow 37 | { 38 | String[] newStrings = new String[Math.max(strings.length * 2, minCapacity)]; 39 | System.arraycopy(strings, 0, newStrings, 0, size); 40 | strings = newStrings; 41 | } 42 | { 43 | double[] newDoubles = new double[Math.max(doubles.length * 2, minCapacity)]; 44 | System.arraycopy(doubles, 0, newDoubles, 0, size); 45 | doubles = newDoubles; 46 | } 47 | } 48 | } 49 | 50 | public void add(String s, double d) { 51 | if (size == strings.length) 52 | ensureCapacity(size + 1); 53 | strings[size] = s; 54 | doubles[size] = d; 55 | size++; 56 | } 57 | 58 | private void checkBoundExclusive(int index) { 59 | if (index >= size) 60 | throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size); 61 | } 62 | 63 | public String getString(int index) { 64 | checkBoundExclusive(index); 65 | return strings[index]; 66 | } 67 | 68 | public double getDouble(int index) { 69 | checkBoundExclusive(index); 70 | return doubles[index]; 71 | } 72 | 73 | public class StringDoubleArrayListIterator implements Iterator<StringDoublePair>, StringDoublePair { 74 | int index = -1; 75 | 76 | @Override 77 | public boolean hasNext() { 78 | return index < size - 1; 79 | } 80 | 81 | @Override 82 | public StringDoublePair next() { 83 | index++; 84 | return this; 85 | } 86 | 87 | @Override 88 | public void remove() { 89 | throw new RuntimeException("Cannot remove stuff from StringDoubleArrayList"); 90 | } 91 | 92 | @Override 93 | public String getFirst() { 94 | return strings[index]; 95 | } 96 | 97 | @Override 98 | public double getSecond() { 99 | return doubles[index]; 100 | } 101 | } 102 | 103 | @Override 104 | public Iterator<StringDoublePair> iterator() { 105 | return new StringDoubleArrayListIterator(); 106 | } 107 | 108 | } 109 |
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/util/StringDoublePair.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.util; 2 | 3 | public interface StringDoublePair { 4 | public String getFirst(); 5 | public double getSecond(); 6 | } 7 |
-------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/util/StringSampler.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.util; 2 | 3 | import java.util.List; 4 | 5 | public class StringSampler { 6 | 7 | public static String sampleEntities(List<?> entities) { 8 | return sampleEntities(entities, entities.size()); 9 | } 10 | 11 | private static final int TRAILING_ENTITIES = 3; 12 | private static final int MAX_TEXT_LENGTH = 60; 13 | public static final int DEFAULT_LIMIT = 30; 14 | 15 | public static String sampleEntities(List<?> entities, int limit) { 16 | StringBuilder sb = new StringBuilder("{"); 17 | int n = entities.size(); 18 | if (n <= limit) { 19 | for (int i = 0; i < n; i++) { 20 | chopString(sb, entities.get(i).toString(), i > 0); 21 | } 22 | } else { 23 | for (int i =
0; i < limit - TRAILING_ENTITIES; i++) { 24 | chopString(sb, entities.get(i).toString(), i > 0); 25 | } 26 | sb.append(", ... (").append(n - limit).append(" more) ..."); 27 | for (int i = n - TRAILING_ENTITIES - 1; i < n; i++) { 28 | chopString(sb, entities.get(i).toString(), i > 0); 29 | } 30 | } 31 | return sb.append("} (").append(n).append(" total)").toString(); 32 | } 33 | 34 | private static void chopString(StringBuilder sb, String x, boolean addComma) { 35 | if (addComma) sb.append(", "); 36 | sb.append('"'); 37 | x = x.replace("\n", " "); 38 | if (x.length() > MAX_TEXT_LENGTH) 39 | sb.append(x.substring(0, MAX_TEXT_LENGTH)).append("...\""); 40 | else 41 | sb.append(x).append('"'); 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/edu/stanford/nlp/semparse/open/util/VectorAverager.java: -------------------------------------------------------------------------------- 1 | package edu.stanford.nlp.semparse.open.util; 2 | 3 | import fig.basic.ListUtils; 4 | 5 | /** 6 | * VectorAverager computes the term-wise average of several vectors. 7 | * It also computes term-wise minimum and maximum. 8 | */ 9 | public class VectorAverager { 10 | final int dim; 11 | double[] aggregate, min, max; 12 | int count = 0; 13 | double sumFactors = 0; 14 | 15 | public VectorAverager(int numDimensions) { 16 | dim = numDimensions; 17 | aggregate = new double[dim]; 18 | } 19 | 20 | public void add(double[] vector) { 21 | if (vector == null) return; 22 | count++; 23 | sumFactors++; 24 | if (min == null) min = vector.clone(); 25 | if (max == null) max = vector.clone(); 26 | for (int i = 0; i < dim; i++) { 27 | aggregate[i] += vector[i]; 28 | } 29 | } 30 | 31 | public void add(double[] vector, double factor) { 32 | if (vector == null) return; 33 | count++; 34 | sumFactors += factor; 35 | if (min == null) min = ListUtils.mult(factor, vector); 36 | if (max == null) max = ListUtils.mult(factor, vector); 37 | for (int i = 0; i < dim; i++) { 38 | aggregate[i] += vector[i] * factor; 39 | min[i] = Math.min(min[i], vector[i] * factor); 40 | max[i] = Math.max(max[i], vector[i] * factor); 41 | } 42 | } 43 | 44 | public double[] getSum() { 45 | return aggregate.clone(); 46 | } 47 | /* 48 | public double[] getAverage() { 49 | if (count == 0) return null; 50 | double[] answer = new double[dim]; 51 | for (int i = 0; i < dim; i++) 52 | answer[i] = aggregate[i] / count; 53 | return answer; 54 | } 55 | */ 56 | public double[] getAverage() { 57 | if (sumFactors < 1e-6) return null; 58 | double[] answer = new double[dim]; 59 | for (int i = 0; i < dim; i++) 60 | answer[i] = aggregate[i] / sumFactors; 61 | return answer; 62 | } 63 | 64 | public double[] getMin() { 65 | if (count == 0) return null; 66 | return min.clone(); 67 | } 68 | 69 | public double[] getMax() { 70 | if (count == 0) return null; 71 | return max.clone(); 72 | } 73 | 74 | /** Term-wise largest magnitude **/ 75 | public double[] getAbsMax() { 76 | if (count == 0) return null; 77 | double[] absMax = new double[dim]; 78 | for (int i = 0; i < dim; i++) { 79 | absMax[i] = Math.max(Math.abs(max[i]), Math.abs(min[i])); 80 | } 81 | return absMax; 82 | } 83 | 84 | /** Min and Max concatenated **/ 85 | public double[] getMinmax() { 86 | if (count == 0) return null; 87 | double[] minmax = new double[dim * 2]; 88 | System.arraycopy(min, 0, minmax, 0, dim); 89 | System.arraycopy(max, 0, minmax, dim, dim); 90 | return minmax; 91 | } 92 | } 93 | -------------------------------------------------------------------------------- 
/src/edu/stanford/nlp/semparse/open/util/WebUtils.java: --------------------------------------------------------------------------------
1 | package edu.stanford.nlp.semparse.open.util; 2 | 3 | import java.io.IOException; 4 | import java.util.*; 5 | 6 | import org.jsoup.Jsoup; 7 | import org.jsoup.nodes.Document; 8 | 9 | import com.fasterxml.jackson.databind.JsonNode; 10 | import com.fasterxml.jackson.databind.ObjectMapper; 11 | 12 | import fig.basic.Utils; 13 | 14 | /** 15 | * Handy utilities for interacting with the web. 16 | */ 17 | public class WebUtils { 18 | private static ObjectMapper jsonMapper = new ObjectMapper(); 19 | 20 | /** 21 | * Return the contents of a webpage. 22 | */ 23 | private static Document executeGetWebpageScript(String flags) { 24 | try { 25 | String contents = Utils.systemGetStringOutput("./scripts/get-webpage.py " + flags); 26 | return Jsoup.parse(contents); 27 | } catch (IOException e) { 28 | throw new RuntimeException(e); 29 | } catch (InterruptedException e) { 30 | throw new RuntimeException(e); 31 | } 32 | } 33 | 34 | public static Document getWebpage(String url) { 35 | url = url.replaceAll("'", "'\"'\"'"); 36 | return executeGetWebpageScript(" '" + url + "' "); 37 | } 38 | 39 | public static Document getWebpageFromHashcode(String cacheDirectory, String hashcode) { 40 | String flags = " -H " + hashcode; 41 | if (cacheDirectory != null && !cacheDirectory.isEmpty()) 42 | flags += " -d " + cacheDirectory; 43 | return executeGetWebpageScript(flags); 44 | } 45 | 46 | /** 47 | * Return the search results for a given query. 48 | */ 49 | public static List<SearchResult> googleSearch(String query) { 50 | // Query is just a single webpage 51 | if (query.startsWith("http://")) 52 | return Collections.singletonList(new SearchResult(query, query, null)); 53 | 54 | try { 55 | query = query.replaceAll("'", "'\"'\"'"); 56 | String contents = Utils.systemGetStringOutput("./scripts/google-search.py '" + query + "'"); 57 | JsonNode root = jsonMapper.readTree(contents.getBytes("UTF-8")); 58 | List<SearchResult> pages = new ArrayList<>(); 59 | for (JsonNode item : root) { 60 | pages.add(new SearchResult(query, item.get(0).asText(), item.get(1).asText())); 61 | } 62 | return pages; 63 | } catch (IOException e) { 64 | throw new RuntimeException(e); 65 | } catch (InterruptedException e) { 66 | throw new RuntimeException(e); 67 | } 68 | } 69 | 70 | /** 71 | * Equivalent to doing a Google search, but reading the results from a file instead.
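 * The script prints JSON records with "link" and "title" fields, which are parsed into
 * SearchResult objects below.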
72 | */ 73 | public static List<SearchResult> fakeGoogleSearch(String query) { 74 | try { 75 | query = query.replaceAll("'", "'\"'\"'"); 76 | String contents = Utils.systemGetStringOutput("./scripts/fake-google-search.py '" + query + "'"); 77 | JsonNode root = jsonMapper.readTree(contents.getBytes("UTF-8")); 78 | List<SearchResult> pages = new ArrayList<>(); 79 | for (JsonNode item : root) { 80 | pages.add(new SearchResult(query, item.get("link").asText(), item.get("title").asText())); 81 | } 82 | return pages; 83 | } catch (IOException e) { 84 | throw new RuntimeException(e); 85 | } catch (InterruptedException e) { 86 | throw new RuntimeException(e); 87 | } 88 | } 89 | } 90 |
-------------------------------------------------------------------------------- /web-entity-extractor: --------------------------------------------------------------------------------
1 | #!/usr/bin/ruby 2 | 3 | require './fig/lib/execrunner' 4 | 5 | system "mkdir -p state/execs" 6 | system "mkdir -p cache" 7 | 8 | run!( 9 | # For running in a queue 10 | letDefault(:q, 0), sel(:q, l(), l('fig/bin/q', '-shareWorkingPath', o('mem', '5g'), o('memGrace', 10), '-add', '---')), 11 | 12 | # Profiling 13 | letDefault(:prof, 0), sel(:prof, l(), '-Xrunhprof:cpu=samples,depth=100,file=_OUTPATH_/java.hprof.txt'), 14 | 15 | 'fig/bin/qcreate', o('statePath', 'state'), 16 | 'java', '-cp', 'classes:lib/*', 17 | 18 | # Set memory size 19 | letDefault(:memsize, 'tiny'), 20 | sel(:memsize, { 21 | 'tiny' => l('-Xms2G', '-Xmx4G'), 22 | 'low' => l('-Xms12G', '-Xmx20G'), 23 | 'medium' => l('-Xms24G', '-Xmx30G'), 24 | 'high' => l('-Xms36G', '-Xmx50G'), 25 | 'higher' => l('-Xms48G', '-Xmx70G'), 26 | 'impressive' => l('-Xms60G', '-Xmx90G'), 27 | }), 28 | 29 | # Determine class to load 30 | sel(:mode, { 31 | 'main' => l( 32 | 'edu.stanford.nlp.semparse.open.Main', 33 | nil), 34 | 'load' => l( 35 | 'edu.stanford.nlp.semparse.open.Main', 36 | let(:feat, 'none'), 37 | nil), 38 | 'interactive' => l( 39 | 'edu.stanford.nlp.semparse.open.Main', 40 | let(:feat, 'none'), 41 | let(:data, 'interactive'), 42 | nil), 43 | }), 44 | 45 | # For fig 46 | o('execDir', '_OUTPATH_'), o('overwriteExecDir'), 47 | o('addToView', 0), 48 | 49 | o('numThreads', 1), 50 | o('numTrainIters', 5), 51 | o('featureMinimumCount', 0), # Disable pruning by default 52 | o('pruneSmallFeaturesThreshold', 0), # Disable pruning by default 53 | o('beta', 0.01), 54 | 55 | # Features 56 | sel(:feat, { 57 | 'none' => l(o('useAllFeatures', false)), 58 | 'structural' => l(o('useAllFeatures', false), o('include', 'self-or-ancestors', 'node-range', 'hole')), 59 | 'denotation' => l(o('useAllFeatures', false), o('include', 'entity', 'ling')), 60 | 'default' => l( 61 | o('useAllFeatures', false), 62 | o('include', 'self-or-ancestors', 'node-range', 'hole', 'entity', 'ling'), 63 | o('featureMinimumCount', 2), 64 | nil), 65 | }), 66 | 67 | sel(0, 68 | l(), # Don't use word clusters 69 | l(o('brownClusterFilename', 'lib/wordreprs/brown-rcv1.clean.tokenized-CoNLL03.txt-c1000-freq1.txt')), # Turian et al. 70 | nil), 71 | 72 | sel(0, 73 | l(), # Don't use word vectors 74 | l(o('wordVectorFilename', 'lib/wordreprs/embeddings-scaled.EMBEDDING_SIZE=50.txt')), # Turian et al.
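# (A literal integer as sel's first argument picks that branch directly: sel(0, ...) keeps the
# empty l(), so these word-vector settings stay off unless the script is edited.)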
75 | l(o('wordVectorFilename', 'lib/cbow.vectors'), o('wordVectorUNKindex', -1)), # CBOW (word2vec) 76 | nil), 77 | 78 | sel(0, 79 | l(), # Don't use frequency table 80 | l(o('frequencyFilename', 'lib/ling-data/webFrequency.tsv')), # default 81 | nil), 82 | 83 | sel(0, 84 | l(), # Don't use wordnet clusters 85 | l(o('wordnetClusterFilename', 'lib/ling-data/wordnet/newer-30')), # default 86 | nil), 87 | 88 | sel(1, 89 | l(), # Don't use query types 90 | l(o('queryTypeFilename', 'lib/ling-data/queryType.tsv')), # default 91 | nil), 92 | 93 | sel(1, 94 | l(), 95 | l(o('lateNorm', 3), o('targetNorm', 3)), 96 | nil), 97 | 98 | # Dataset 99 | sel(:data, { 100 | 'debug' => l( 101 | o('dataset', 'openweb.debug'), 102 | nil), 103 | 104 | 'dev' => l( 105 | o('dataset', 'openweb.train'), 106 | nil), 107 | 108 | 'test' => l( 109 | o('dataset', 'openweb.test'), 110 | nil), 111 | 112 | 'real' => l( 113 | o('dataset', 'openweb.train@test'), 114 | nil), 115 | 116 | 'custom' => l(), 117 | 'interactive' => l(), 118 | }), 119 | 120 | nil) 121 | --------------------------------------------------------------------------------