├── __init__.py ├── brent-phone.tar.gz ├── brent-text.tar.gz ├── .gitignore ├── option_parser.py ├── README.md ├── launch_test.py ├── util.py ├── launch_resume.py └── launch_train.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /brent-phone.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kzhai/PyAdaGram/HEAD/brent-phone.tar.gz -------------------------------------------------------------------------------- /brent-text.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kzhai/PyAdaGram/HEAD/brent-text.tar.gz -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | *~ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | *.o 11 | *.d 12 | 13 | # Distribution / packaging 14 | .Python 15 | env/ 16 | bin/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # Installer logs 31 | pip-log.txt 32 | pip-delete-this-directory.txt 33 | 34 | # Unit test / coverage reports 35 | htmlcov/ 36 | .tox/ 37 | .coverage 38 | .cache 39 | nosetests.xml 40 | coverage.xml 41 | 42 | # Translations 43 | *.mo 44 | 45 | # Mr Developer 46 | .mr.developer.cfg 47 | .project 48 | .pydevproject 49 | 50 | # Rope 51 | .ropeproject 52 | 53 | # Django stuff: 54 | *.log 55 | *.pot 56 | 57 | # Sphinx documentation 58 | docs/_build/ 59 | 60 | -------------------------------------------------------------------------------- /option_parser.py: -------------------------------------------------------------------------------- 1 | import optparse 2 | 3 | delimiter = '-' 4 | 5 | 6 | def floatable(str): 7 | try: 8 | float(str) 9 | return True 10 | except ValueError: 11 | return False 12 | 13 | 14 | def intable(str): 15 | try: 16 | int(str) 17 | return True 18 | except ValueError: 19 | return False 20 | 21 | 22 | def process_floats(option, opt_str, value, parser): 23 | assert value is None 24 | value = {} 25 | 26 | for arg in parser.rargs: 27 | # stop on --foo like options 28 | if arg[:2] == "--" and len(arg) > 2: 29 | break 30 | # stop on -a, but not on -3 or -3.0 31 | if arg[:1] == "-" and len(arg) > 1 and not floatable(arg): 32 | break 33 | 34 | tokens = arg.split("=") 35 | value[tokens[0]] = float(tokens[1]) 36 | 37 | del parser.rargs[:len(value)] 38 | setattr(parser.values, option.dest, value) 39 | 40 | return 41 | 42 | 43 | def process_ints(option, opt_str, value, parser): 44 | assert value is None 45 | value = {} 46 | 47 | for arg in parser.rargs: 48 | # stop on --foo like options 49 | if arg[:2] == "--" and len(arg) > 2: 50 | break 51 | # stop on -a, but not on -3 or -3.0 52 | if arg[:1] == "-" and len(arg) > 1 and not intable(arg): 53 | break 54 | 55 | tokens = arg.split("=") 56 | value[tokens[0]] = int(tokens[1]) 57 | 58 | del parser.rargs[:len(value)] 59 | setattr(parser.values, option.dest, value) 60 | 61 | return 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | PyAdaGram 2 | ========== 3 | 4 | PyAdaGram is an online
Adaptor Grammar model package, developed by the Cloud Computing Research Team at the [University of Maryland, College Park](http://www.umd.edu). 5 | You may find more details about this project in our paper [Online Adaptor Grammars with Hybrid Inference](http://kzhai.github.io/paper/2014_tacl.pdf), which appeared in TACL 2014. 6 | 7 | Please download the latest version from our [GitHub repository](https://github.com/kzhai/PyAdaGram). 8 | 9 | Please report any bugs or problems to Ke Zhai (kzhai@umd.edu). 10 | 11 | Install and Build 12 | ---------- 13 | 14 | This package depends on several external Python libraries, such as numpy, scipy and nltk; please make sure they are installed before launching. 15 | 16 | Launch and Execute 17 | ---------- 18 | 19 | Assume the PyAdaGram package is downloaded under the directory ```$PROJECT_SPACE/src/```, i.e., 20 | 21 | $PROJECT_SPACE/src/PyAdaGram 22 | 23 | To prepare the example dataset, 24 | 25 | tar zxvf brent-phone.tar.gz 26 | 27 | To launch PyAdaGram, first change to the directory containing the PyAdaGram source code, 28 | 29 | cd $PROJECT_SPACE/src/PyAdaGram 30 | 31 | and run the following command on the example dataset, 32 | 33 | ```bash 34 | python -m launch_train \ 35 | --input_directory=./brent-phone/ \ 36 | --output_directory=./ \ 37 | --grammar_file=./brent-phone/grammar.unigram \ 38 | --number_of_documents=9790 \ 39 | --batch_size=10 40 | ``` 41 | 42 | The generic command to run PyAdaGram is 43 | 44 | ```bash 45 | python -m launch_train \ 46 | --input_directory=$INPUT_DIRECTORY/$CORPUS_NAME \ 47 | --output_directory=$OUTPUT_DIRECTORY \ 48 | --grammar_file=$GRAMMAR_FILE \ 49 | --number_of_documents=$NUMBER_OF_DOCUMENTS \ 50 | --batch_size=$BATCH_SIZE 51 | ``` 52 | 53 | You should be able to find the output under the directory ```$OUTPUT_DIRECTORY/$CORPUS_NAME```. 54 | 55 | At any time, you can get help information and usage hints by running the following command 56 | 57 | ```bash 58 | python -m launch_train --help 59 | ``` 60 | 61 | To launch the test script, run the following command 62 | 63 | ```bash 64 | python -m launch_test \ 65 | --input_directory=$DATA_DIRECTORY \ 66 | --model_directory=$MODEL_DIRECTORY \ 67 | --non_terminal_symbol=$NON_TERMINAL_SYMBOL 68 | ``` 69 | Notes on the expected grammar file format and on resuming training with ```launch_resume.py``` are given at the end of this document. -------------------------------------------------------------------------------- /launch_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import pickle 3 | import string 4 | import numpy 5 | import getopt 6 | import sys 7 | import random 8 | import time 9 | import re 10 | import pprint 11 | import codecs 12 | import datetime, os 13 | import scipy.io 14 | import nltk 15 | import numpy 16 | import optparse 17 | 18 | 19 | # bash script to terminate all sub-processes 20 | # kill $(ps aux | grep 'python infag' | awk '{print $2}') 21 | 22 | def parse_args(): 23 | parser = optparse.OptionParser() 24 | parser.set_defaults( 25 | # parameter set 1 26 | input_directory=None, 27 | model_directory=None, 28 | non_terminal_symbol="Word", 29 | number_of_samples=10, 30 | number_of_processes=0, 31 | ) 32 | # parameter set 1 33 | parser.add_option("--input_directory", type="string", dest="input_directory", 34 | help="input directory [None]") 35 | parser.add_option("--model_directory", type="string", dest="model_directory", 36 | help="model directory [None]") 37 | parser.add_option("--non_terminal_symbol", type="string", dest="non_terminal_symbol", 38 | help="non-terminal symbol [Word]") 39 | parser.add_option("--number_of_samples", type="int", dest="number_of_samples", 40 | help="number of samples [10]")
41 | parser.add_option("--number_of_processes", type="int", dest="number_of_processes", 42 | help="number of processes [0]") 43 | 44 | (options, args) = parser.parse_args() 45 | return options 46 | 47 | 48 | def main(): 49 | options = parse_args() 50 | 51 | # parameter set 1 52 | assert (options.input_directory is not None) 53 | input_directory = options.input_directory 54 | assert (options.model_directory is not None) 55 | model_directory = options.model_directory 56 | assert (options.non_terminal_symbol is not None) 57 | non_terminal_symbol = options.non_terminal_symbol 58 | 59 | assert (options.number_of_samples > 0) 60 | number_of_samples = options.number_of_samples 61 | assert (options.number_of_processes >= 0) 62 | number_of_processes = options.number_of_processes 63 | 64 | print("========== ========== ========== ========== ==========") 65 | # parameter set 1 66 | print("input_directory=" + input_directory) 67 | print("model_directory=" + model_directory) 68 | print("non_terminal_symbol=" + non_terminal_symbol) 69 | 70 | print("number_of_samples=" + str(number_of_samples)) 71 | print("number_of_processes=" + str(number_of_processes)) 72 | print("========== ========== ========== ========== ==========") 73 | 74 | # Documents 75 | train_docs = [] 76 | input_stream = open(os.path.join(input_directory, 'train.dat'), 'r') 77 | for line in input_stream: 78 | train_docs.append(line.strip()) 79 | input_stream.close() 80 | print("successfully load %d training documents..." % (len(train_docs))) 81 | 82 | refer_docs = [] 83 | input_stream = open(os.path.join(input_directory, 'truth.dat'), 'r') 84 | for line in input_stream: 85 | refer_docs.append(line.strip()) 86 | input_stream.close() 87 | print("successfully load %d testing documents..." % (len(refer_docs))) 88 | 89 | for model_file in os.listdir(model_directory): 90 | if not model_file.startswith("model-"): 91 | continue 92 | 93 | model_file_path = os.path.join(model_directory, model_file) 94 | 95 | try: 96 | cpickle_file = open(model_file_path, 'rb') 97 | infinite_adaptor_grammar = pickle.load(cpickle_file) 98 | print("successfully load model from %s" % (model_file_path)) 99 | cpickle_file.close() 100 | except ValueError: 101 | print("warning: unsuccessfully load model from %s due to value error..." % (model_file_path)) 102 | continue 103 | except EOFError: 104 | print("warning: unsuccessfully load model from %s due to EOF error..." 
% (model_file_path)) 105 | continue 106 | 107 | non_terminal = nltk.grammar.Nonterminal(non_terminal_symbol) 108 | # assert non_terminal in infinite_adaptor_grammar._adapted_non_terminals, ( 109 | # non_terminal, infinite_adaptor_grammar._adapted_non_terminals) 110 | 111 | inference_parameter = (refer_docs, non_terminal, model_directory, model_file) 112 | infinite_adaptor_grammar.inference(train_docs, inference_parameter, number_of_samples, number_of_processes) 113 | 114 | ''' 115 | #from launch_train import shuffle_lists 116 | #shuffle_lists(train_docs, refer_docs) 117 | if number_of_processes==0: 118 | infinite_adaptor_grammar.inference(train_docs, refer_docs, non_terminal, model_directory, model_file, number_of_samples) 119 | else: 120 | from hybrid_process import inference_process 121 | inference_process(infinite_adaptor_grammar, train_docs, refer_docs, non_terminal, model_directory, number_of_samples, number_of_processes) 122 | ''' 123 | 124 | 125 | if __name__ == '__main__': 126 | main() 127 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import nltk 3 | import numpy 4 | import scipy 5 | import sys 6 | import math 7 | import time 8 | 9 | """Utility classes for PyAdaGram: log-space arithmetic, adapted productions, and 10 | hypergraph/chart-parsing helpers.""" 11 | 12 | 13 | def log_add(log_a, log_b): 14 | if log_a < log_b: 15 | return log_b + numpy.log(1 + numpy.exp(log_a - log_b)) 16 | else: 17 | return log_a + numpy.log(1 + numpy.exp(log_b - log_a)) 18 | 19 | 20 | class GraphNode(object): 21 | def __init__(self, non_terminal, next_nodes=None): 22 | self._non_terminal = non_terminal 23 | self._next_nodes = next_nodes if next_nodes is not None else set()  # avoid sharing a mutable default argument 24 | 25 | def add_node(self, next_node): 26 | assert (isinstance(next_node, self.__class__)) 27 | self._next_nodes.add(next_node) 28 | 29 | 30 | class AdaptedProduction(nltk.grammar.Production): 31 | def __init__(self, lhs, rhs, productions): 32 | super(AdaptedProduction, self).__init__(lhs, rhs) 33 | self._productions = productions 34 | productions_hash = 0 35 | for production in productions: 36 | productions_hash = productions_hash // 1000000000 37 | productions_hash += production.__hash__() 38 | self._hash = hash((self._lhs, self._rhs, productions_hash)) 39 | 40 | def get_production_list(self): 41 | return self._productions 42 | 43 | def match_grammaton(self, production_list): 44 | if len(self._productions) != len(production_list): 45 | return False 46 | for x in range(len(self._productions)): 47 | if self._productions[x] != production_list[x]: 48 | return False 49 | return True 50 | 51 | def __eq__(self, other): 52 | """ 53 | @return: true if this C{Production} is equal to C{other}.
54 | @rtype: C{boolean} 55 | """ 56 | if not isinstance(other, self.__class__): 57 | return False 58 | if self._lhs != other._lhs: 59 | return False 60 | if self._rhs != other._rhs: 61 | return False 62 | return self.match_grammaton(other._productions) 63 | 64 | def __str__(self): 65 | str = "%s -> %s (" % (self._lhs, " ".join(["%s" % elt for elt in self._rhs])) 66 | str += ", ".join(["%s" % production for production in self._productions]) 67 | str += ")" 68 | return str 69 | 70 | def retrieve_tokens_of_adapted_non_terminal(self, adapted_non_terminal): 71 | if self.lhs() == adapted_non_terminal: 72 | return ["".join(self.rhs())] 73 | else: 74 | token_list = [] 75 | for candidate_production in self._productions: 76 | if isinstance(candidate_production, self.__class__): 77 | token_list += candidate_production.retrieve_tokens_of_adapted_non_terminal(adapted_non_terminal) 78 | else: 79 | continue 80 | return token_list 81 | 82 | def __hash__(self): 83 | return super(AdaptedProduction, self).__hash__() 84 | 85 | 86 | class HyperNode(object): 87 | def __init__(self, 88 | node, 89 | span): 90 | self._node = node 91 | self._span = span 92 | 93 | self._derivation = [] 94 | self._log_probability = [] 95 | 96 | # TODO: this would incur duplication in derivation 97 | # self._derivation_log_probability = {} 98 | 99 | self._accumulated_log_probability = float('NaN') 100 | 101 | def add_new_derivation(self, production, log_probability, hyper_nodes=None): 102 | self._derivation.append((production, hyper_nodes)) 103 | self._log_probability.append(log_probability) 104 | 105 | ''' 106 | if (production, hyper_nodes) not in self._derivation_log_probability: 107 | self._derivation_log_probability[(production, hyper_nodes)] = log_probability 108 | else: 109 | self._derivation_log_probability[(production, hyper_nodes)] = log_add(log_probability, self._derivation_log_probability[(production, hyper_nodes)]) 110 | ''' 111 | 112 | if math.isnan(self._accumulated_log_probability): 113 | self._accumulated_log_probability = log_probability 114 | else: 115 | self._accumulated_log_probability = log_add(log_probability, self._accumulated_log_probability) 116 | 117 | return 118 | 119 | def random_sample_derivation(self): 120 | # print self._derivation_log_probability.keys() 121 | random_number = numpy.random.random() 122 | 123 | ''' 124 | for (production, hyper_nodes) in self._derivation_log_probability: 125 | current_probability = numpy.exp(self._derivation_log_probability[(production, hyper_nodes)] - self._accumulated_log_probability) 126 | if random_number>current_probability: 127 | random_number -= current_probability 128 | else: 129 | return production, hyper_nodes 130 | ''' 131 | 132 | # print "<<<<<<<<<>>>>>>>>>" 133 | # for x in xrange(len(self._derivation)): 134 | # print self._derivation[x], numpy.exp(self._log_probability[x] - self._accumulated_log_probability) 135 | # sys.exit() 136 | 137 | assert (len(self._derivation) == len(self._log_probability)) 138 | for x in range(len(self._derivation)): 139 | current_probability = numpy.exp(self._log_probability[x] - self._accumulated_log_probability) 140 | if random_number > current_probability: 141 | random_number -= current_probability 142 | else: 143 | # return self._derivation[x][0], self._derivation[x][1] 144 | return self._derivation[x][0], self._derivation[x][1], self._log_probability[ 145 | x] - self._accumulated_log_probability 146 | 147 | def __len__(self): 148 | # return len(self._derivation_log_probability) 149 | return len(self._log_probability) 150 | 151 | 
def __str__(self): 152 | output_string = "[%s (%d:%d) " % (self._node, self._span[0], self._span[1]) 153 | 154 | ''' 155 | for (production, hyper_nodes) in self._derivation_log_probability: 156 | 157 | output_string += "<%s" % (production) 158 | print production, len(hyper_nodes) 159 | if hyper_nodes!=None: 160 | print len(hyper_nodes) 161 | for hyper_node in hyper_nodes: 162 | output_string += " %s" % (hyper_node) 163 | output_string += "> " 164 | return output_string 165 | ''' 166 | 167 | for x in range(len(self._derivation)): 168 | production = self._derivation[x][0] 169 | hyper_nodes = self._derivation[x][1] 170 | 171 | output_string += "[%s" % (production) 172 | # print production, len(hyper_nodes) 173 | if hyper_nodes is not None: 174 | # print len(hyper_nodes) 175 | for hyper_node in hyper_nodes: 176 | output_string += " %s" % (hyper_node) 177 | output_string += "] " 178 | return output_string 179 | 180 | def __repr__(self): 181 | return self.__str__() 182 | 183 | def __hash__(self): 184 | return hash((self._node, self._span)) 185 | 186 | 187 | class PassiveEdge(object): 188 | def __init__(self, 189 | node, 190 | left, 191 | right 192 | ): 193 | self._node = node 194 | self._left = left 195 | self._right = right 196 | 197 | 198 | class ActiveEdge(object): 199 | def __init__(self, 200 | lhs, 201 | rhs, 202 | left, 203 | right, 204 | parsed=0 205 | ): 206 | self._lhs = lhs 207 | self._rhs = rhs 208 | self._left = left 209 | self._right = right 210 | self._parsed = parsed 211 | 212 | 213 | class HyperGraph(object): 214 | def __init__(self): 215 | self._top_down_approach = True 216 | return 217 | 218 | def parse(self, sentence, grammar): 219 | words = sentence.split() 220 | self.initialize(words, grammar, PassiveEdge(grammar._start_symbol, 0, len(words))) 221 | 222 | while len(self._finishing_agenda) > 0: 223 | while len(self._explore_agenda) > 0: 224 | break 225 | 226 | edge = self._finishing_agenda.pop() 227 | self.finish_edge(edge) 228 | 229 | return 230 | 231 | def initialize(self, words, grammar, goal=None): 232 | # TODO: create new chart and agenda 233 | self._explore_agenda = [] 234 | self._finishing_agenda = [] 235 | for x in range(len(words)): 236 | self._finishing_agenda.append(PassiveEdge(words[x], x, x + 1)) 237 | 238 | if self._top_down_approach: 239 | for production in grammar.productions(lhs=grammar._start_symbol): 240 | self._finishing_agenda.append(ActiveEdge(production.lhs(), production.rhs(), 0, 0)) 241 | 242 | def finish_edge(self, edge): 243 | # TODO: add edge to chart 244 | self.do_fundamental_rule(edge) 245 | self.do_rule_introduction(edge) 246 | 247 | def do_fundamental_rule(self, edge): 248 | return 249 | 250 | def do_rule_introduction(self, edge): 251 | return 252 | -------------------------------------------------------------------------------- /launch_resume.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import pickle 3 | import string 4 | import numpy 5 | import getopt 6 | import sys 7 | import random 8 | import time 9 | import math 10 | import re 11 | import pprint 12 | import codecs 13 | import datetime 14 | import os 15 | import scipy.io 16 | import nltk 17 | import numpy 18 | import collections 19 | import optparse 20 | 21 | 22 | # bash script to terminate all sub-processes 23 | # kill $(ps aux | grep 'python infag' | awk '{print $2}') 24 | 25 | def parse_args(): 26 | parser = optparse.OptionParser() 27 | parser.set_defaults( 28 | # parameter set 1 29 | input_directory=None, 30 |
output_directory=None, 31 | corpus_name=None, 32 | model_directory=None, 33 | 34 | # parameter set 2 35 | online_iterations=-1, 36 | snapshot_interval=-1, 37 | number_of_processes=0, 38 | ) 39 | # parameter set 1 40 | parser.add_option("--input_directory", type="string", dest="input_directory", 41 | help="input directory [None]") 42 | parser.add_option("--output_directory", type="string", dest="output_directory", 43 | help="output directory [None]") 44 | parser.add_option("--corpus_name", type="string", dest="corpus_name", 45 | help="the corpus name [None]") 46 | parser.add_option("--model_directory", type="string", dest="model_directory", 47 | help="the model directory [None]") 48 | 49 | # parameter set 2 50 | parser.add_option("--online_iterations", type="int", dest="online_iterations", 51 | help="resume iteration to run training [nonpos=previous settings]") 52 | parser.add_option("--number_of_processes", type="int", dest="number_of_processes", 53 | help="number of processes [0]") 54 | parser.add_option("--snapshot_interval", type="int", dest="snapshot_interval", 55 | help="snapshot interval [nonpos=previous settings]") 56 | 57 | (options, args) = parser.parse_args() 58 | return options 59 | 60 | # model output directory names are expected to encode the previous training settings, including -S<snapshot_interval>- and -O<online_iterations>- segments 61 | model_setting_pattern = re.compile( 62 | r"\w+\-\d+\-D\d+\-P\d+\-S(?P<snapshot>\d+)\-B\d+\-O(?P<online>\d+)\-t\d+\-k[\d\.]+\-G\w+\-T[\w\&\.]+\-ap[\w\&\.]+\-bp[\w\&\.]+") 63 | snapshot_pattern = re.compile(r"\-S(?P<snapshot>\d+)\-") 64 | online_pattern = re.compile(r"\-O(?P<online>\d+)\-") 65 | 66 | 67 | def main(): 68 | options = parse_args() 69 | 70 | # parameter set 1 71 | assert (options.corpus_name is not None) 72 | assert (options.input_directory is not None) 73 | assert (options.output_directory is not None) 74 | assert (options.model_directory is not None) 75 | corpus_name = options.corpus_name 76 | input_directory = options.input_directory 77 | input_directory = os.path.join(input_directory, corpus_name) 78 | output_directory = options.output_directory 79 | if not os.path.exists(output_directory): 80 | os.mkdir(output_directory) 81 | output_directory = os.path.join(output_directory, corpus_name) 82 | if not os.path.exists(output_directory): 83 | os.mkdir(output_directory) 84 | model_directory = options.model_directory 85 | if not model_directory.endswith("/"): 86 | model_directory += "/" 87 | 88 | # look for model snapshot 89 | model_setting = os.path.basename(os.path.dirname(model_directory)) 90 | model_pattern_match_object = re.match(model_setting_pattern, model_setting) 91 | model_pattern_match_dictionary = model_pattern_match_object.groupdict() 92 | previous_online_iterations = int(model_pattern_match_dictionary["online"]) 93 | previous_snapshot_interval = int(model_pattern_match_dictionary["snapshot"]) 94 | model_file_path = os.path.join(model_directory, "model-%d" % (previous_online_iterations)) 95 | 96 | # load model snapshot 97 | try: 98 | cpickle_file = open(model_file_path, 'rb')  # snapshots are written with 'wb', so read in binary mode 99 | infinite_adaptor_grammar = pickle.load(cpickle_file) 100 | print("successfully load model from %s" % (model_directory)) 101 | cpickle_file.close() 102 | except ValueError: 103 | print("warning: unsuccessfully load model from %s due to value error..." % (model_file_path)) 104 | return 105 | except EOFError: 106 | print("warning: unsuccessfully load model from %s due to EOF error..."
% (model_file_path)) 107 | return 108 | 109 | batch_size = infinite_adaptor_grammar._batch_size 110 | number_of_documents = infinite_adaptor_grammar._number_of_strings 111 | 112 | # parameter set 2 113 | online_iterations = number_of_documents // batch_size 114 | if options.online_iterations > 0: 115 | online_iterations = options.online_iterations 116 | assert (options.number_of_processes >= 0) 117 | number_of_processes = options.number_of_processes 118 | snapshot_interval = previous_snapshot_interval 119 | if options.snapshot_interval > 0: 120 | snapshot_interval = options.snapshot_interval 121 | 122 | # adjust model output path name 123 | model_setting = re.sub(snapshot_pattern, "-S%d-" % (snapshot_interval), model_setting) 124 | model_setting = re.sub(online_pattern, "-O%d-" % (online_iterations + previous_online_iterations), model_setting) 125 | 126 | output_directory = os.path.join(output_directory, model_setting) 127 | os.mkdir(os.path.abspath(output_directory)) 128 | 129 | # store all the options to a output stream 130 | options_output_file = open(os.path.join(output_directory, "option.txt"), 'w') 131 | # parameter set 1 132 | options_output_file.write("input_directory=" + input_directory + "\n") 133 | options_output_file.write("corpus_name=" + corpus_name + "\n") 134 | options_output_file.write("model_directory=" + model_directory + "\n") 135 | # parameter set 2 136 | options_output_file.write("snapshot_interval=" + str(snapshot_interval) + "\n") 137 | options_output_file.write("online_iterations=" + str(online_iterations) + "\n") 138 | options_output_file.write("number_of_processes=" + str(number_of_processes) + "\n") 139 | # parameter set 3 140 | options_output_file.write("number_of_documents=" + str(number_of_documents) + "\n") 141 | options_output_file.write("batch_size=" + str(batch_size) + "\n") 142 | options_output_file.close() 143 | 144 | print("========== ========== ========== ========== ==========") 145 | # parameter set 1 146 | print("input_directory=" + input_directory) 147 | print("corpus_name=" + corpus_name) 148 | print("model_directory=" + model_directory) 149 | # parameter set 2 150 | print("snapshot_interval=" + str(snapshot_interval)) 151 | print("online_iterations=" + str(online_iterations)) 152 | print("number_of_processes=" + str(number_of_processes)) 153 | # parameter set 3 154 | print("number_of_documents=" + str(number_of_documents)) 155 | print("batch_size=" + str(batch_size)) 156 | print("========== ========== ========== ========== ==========") 157 | 158 | # Documents 159 | train_docs = [] 160 | input_stream = open(os.path.join(input_directory, 'train.dat'), 'r') 161 | for line in input_stream: 162 | train_docs.append(line.strip()) 163 | input_stream.close() 164 | print("successfully load all training documents...") 165 | 166 | random.shuffle(train_docs) 167 | training_clock = time.time() 168 | snapshot_clock = time.time() 169 | for iteration in range(previous_online_iterations, previous_online_iterations + online_iterations): 170 | start_index = batch_size * iteration 171 | end_index = batch_size * (iteration + 1) 172 | if start_index // number_of_documents < end_index // number_of_documents: 173 | # train_doc_set = train_docs[(batch_size * iteration) % (number_of_documents) :] + train_docs[: (batch_size * (iteration+1)) % (number_of_documents)] 174 | train_doc_set = train_docs[(batch_size * iteration) % (number_of_documents):] 175 | random.shuffle(train_docs) 176 | train_doc_set += train_docs[: (batch_size * (iteration + 1)) % (number_of_documents)] 177 
| else: 178 | train_doc_set = train_docs[(batch_size * iteration) % (number_of_documents): (batch_size * ( 179 | iteration + 1)) % number_of_documents] 180 | 181 | clock_iteration = time.time() 182 | # print "processing document:", train_doc_set 183 | clock_e_step, clock_m_step = infinite_adaptor_grammar.learning(train_doc_set, number_of_processes) 184 | 185 | if (iteration + 1) % snapshot_interval == 0: 186 | # cpickle_file = open(os.path.join(output_directory, "model-%d" % (iteration+1)), 'wb') 187 | # cPickle.dump(infinite_adaptor_grammar, cpickle_file) 188 | # cpickle_file.close() 189 | infinite_adaptor_grammar.export_adaptor_grammar( 190 | os.path.join(output_directory, "adagram-" + str((iteration + 1)))) 191 | # infinite_adaptor_grammar.export_aggregated_adaptor_grammar(os.path.join(output_directory, "ag-" + str((iteration+1)))) 192 | 193 | if (iteration + 1) % 1000 == 0: 194 | snapshot_clock = time.time() - snapshot_clock 195 | print('Processing 1000 mini-batches take %g seconds...' % (snapshot_clock)) 196 | snapshot_clock = time.time() 197 | 198 | clock_iteration = time.time() - clock_iteration 199 | print('E-step, M-step and iteration %d take %g, %g and %g seconds respectively...' % ( 200 | infinite_adaptor_grammar._counter, clock_e_step, clock_m_step, clock_iteration)) 201 | 202 | infinite_adaptor_grammar.export_adaptor_grammar(os.path.join(output_directory, "adagram-" + str((iteration + 1)))) 203 | # infinite_adaptor_grammar.export_aggregated_adaptor_grammar(os.path.join(output_directory, "ag-" + str((iteration+1)))) 204 | 205 | cpickle_file = open(os.path.join(output_directory, "model-%d" % (iteration + 1)), 'wb') 206 | pickle.dump(infinite_adaptor_grammar, cpickle_file) 207 | cpickle_file.close() 208 | 209 | training_clock = time.time() - training_clock 210 | print('Training finished in %g seconds...' 
% (training_clock)) 211 | 212 | 213 | if __name__ == '__main__': 214 | main() 215 | -------------------------------------------------------------------------------- /launch_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import pickle 3 | import string 4 | import numpy 5 | import getopt 6 | import sys 7 | import random 8 | import time 9 | import math 10 | import re 11 | import pprint 12 | import codecs 13 | import datetime 14 | import optparse 15 | import os 16 | import nltk 17 | import numpy 18 | import scipy 19 | import scipy.io 20 | import collections 21 | 22 | 23 | # bash script to terminate all sub-processes 24 | # kill $(ps aux | grep 'python infag' | awk '{print $2}') 25 | 26 | def shuffle_lists(list1, list2): 27 | assert (len(list1) == len(list2)) 28 | list1_shuf = [] 29 | list2_shuf = [] 30 | index_shuf = list(range(len(list1))) 31 | random.shuffle(index_shuf) 32 | for i in index_shuf: 33 | list1_shuf.append(list1[i]) 34 | list2_shuf.append(list2[i]) 35 | return list1_shuf, list2_shuf 36 | 37 | def parse_args(): 38 | parser = optparse.OptionParser() 39 | parser.set_defaults( 40 | # parameter set 1 41 | input_directory=None, 42 | output_directory=None, 43 | # corpus_name=None, 44 | grammar_file=None, 45 | 46 | # parameter set 2 47 | # number_of_topics=25, 48 | number_of_documents=-1, 49 | batch_size=-1, 50 | training_iterations=-1, 51 | number_of_processes=0, 52 | # multiprocesses=False, 53 | 54 | # parameter set 3 55 | grammaton_prune_interval=10, 56 | snapshot_interval=10, 57 | kappa=0.75, 58 | tau=64.0, 59 | 60 | # parameter set 4 61 | # desired_truncation_level={}, 62 | # alpha_theta={}, 63 | # alpha_pi={}, 64 | # beta_pi={}, 65 | 66 | # parameter set 5 67 | train_only=False, 68 | heldout_data=0.0, 69 | enable_word_model=False 70 | 71 | # fix_vocabulary=False 72 | ) 73 | # parameter set 1 74 | parser.add_option("--input_directory", type="string", dest="input_directory", 75 | help="input directory [None]") 76 | parser.add_option("--output_directory", type="string", dest="output_directory", 77 | help="output directory [None]") 78 | # parser.add_option("--corpus_name", type="string", dest="corpus_name", 79 | # help="the corpus name [None]") 80 | parser.add_option("--grammar_file", type="string", dest="grammar_file", 81 | help="the grammar file [None]") 82 | 83 | # parameter set 2 84 | # parser.add_option("--number_of_topics", type="int", dest="number_of_topics", 85 | # help="second level truncation [25]") 86 | parser.add_option("--number_of_documents", type="int", dest="number_of_documents", 87 | help="number of documents [-1]") 88 | parser.add_option("--batch_size", type="int", dest="batch_size", 89 | help="batch size [-1 in batch mode]") 90 | parser.add_option("--training_iterations", type="int", dest="training_iterations", 91 | help="max iteration to run training [number_of_documents/batch_size]") 92 | parser.add_option("--number_of_processes", type="int", dest="number_of_processes", 93 | help="number of processes [0]") 94 | # parser.add_option("--multiprocesses", action="store_true", dest="multiprocesses", 95 | # help="multiprocesses [false]") 96 | 97 | # parameter set 3 98 | parser.add_option("--kappa", type="float", dest="kappa", 99 | help="learning rate [0.75]") 100 | parser.add_option("--tau", type="float", dest="tau", 101 | help="slow down [64.0]") 102 | parser.add_option("--grammaton_prune_interval", type="int", dest="grammaton_prune_interval", 103 | help="grammaton prune interval [10]") 104 | parser.add_option("--snapshot_interval",
type="int", dest="snapshot_interval", 105 | help="snapshot interval [grammaton_prune_interval]") 106 | 107 | # parameter set 4 108 | ''' 109 | parser.add_option("--desired_truncation_level", dest="desired_truncation_level", 110 | action="callback", callback=process_ints, help="desired truncation level") 111 | parser.add_option("--alpha_theta", dest="alpha_theta", 112 | action="callback", callback=process_floats, 113 | help="hyper-parameter for Dirichlet distribution of PCFG productions [1.0/number_of_pcfg_productions]") 114 | parser.add_option("--alpha_pi", dest="alpha_pi", 115 | action="callback", callback=process_floats, 116 | help="hyper-parameter for Pitman-Yor process") 117 | parser.add_option("--beta_pi", dest="beta_pi", 118 | action="callback", callback=process_floats, 119 | help="hyper-parameter for Pitman-Yor process") 120 | ''' 121 | 122 | # parameter set 5 123 | parser.add_option("--train_only", action="store_true", dest="train_only", 124 | help="train mode only [false]") 125 | parser.add_option("--heldout_data", type="int", dest="heldout_data", 126 | help="portion of heldout data [0.0]") 127 | parser.add_option("--enable_word_model", action="store_true", dest="enable_word_model", 128 | help="enable word model [false]") 129 | 130 | ''' 131 | parser.add_option("--fix_vocabulary", action="store_true", dest="fix_vocabulary", 132 | help="run this program with fix vocabulary") 133 | ''' 134 | 135 | (options, args) = parser.parse_args() 136 | return options 137 | 138 | 139 | def main(): 140 | options = parse_args() 141 | 142 | # parameter set 1 143 | # assert(options.corpus_name!=None) 144 | assert (options.input_directory is not None) 145 | assert (options.output_directory is not None) 146 | 147 | input_directory = options.input_directory 148 | input_directory = input_directory.rstrip("/") 149 | corpus_name = os.path.basename(input_directory) 150 | 151 | output_directory = options.output_directory 152 | if not os.path.exists(output_directory): 153 | os.mkdir(output_directory) 154 | output_directory = os.path.join(output_directory, corpus_name) 155 | if not os.path.exists(output_directory): 156 | os.mkdir(output_directory) 157 | 158 | assert (options.grammar_file is not None) 159 | grammar_file = options.grammar_file 160 | assert (os.path.exists(grammar_file)) 161 | 162 | # Documents 163 | train_docs = [] 164 | input_stream = open(os.path.join(input_directory, 'train.dat'), 'r') 165 | for line in input_stream: 166 | train_docs.append(line.strip()) 167 | input_stream.close() 168 | print("successfully load all training documents...") 169 | 170 | # parameter set 2 171 | if options.number_of_documents > 0: 172 | number_of_documents = options.number_of_documents 173 | else: 174 | number_of_documents = len(train_docs) 175 | if options.batch_size > 0: 176 | batch_size = options.batch_size 177 | else: 178 | batch_size = number_of_documents 179 | # assert(number_of_documents % batch_size==0) 180 | training_iterations = number_of_documents // batch_size 181 | if options.training_iterations > 0: 182 | training_iterations = options.training_iterations 183 | # training_iterations=int(math.ceil(1.0*number_of_documents/batch_size)) 184 | # multiprocesses = options.multiprocesses 185 | assert (options.number_of_processes >= 0) 186 | number_of_processes = options.number_of_processes 187 | 188 | # parameter set 3 189 | assert (options.grammaton_prune_interval > 0) 190 | grammaton_prune_interval = options.grammaton_prune_interval 191 | snapshot_interval = grammaton_prune_interval 192 | if 
options.snapshot_interval > 0: 193 | snapshot_interval = options.snapshot_interval 194 | assert (options.tau >= 0) 195 | tau = options.tau 196 | # assert(options.kappa>=0.5 and options.kappa<=1) 197 | assert (options.kappa >= 0 and options.kappa <= 1) 198 | kappa = options.kappa 199 | if batch_size <= 0: 200 | print("warning: running in batch mode...") 201 | kappa = 0 202 | 203 | # read in adaptor grammars 204 | desired_truncation_level = {} 205 | alpha_pi = {} 206 | beta_pi = {} 207 | 208 | grammar_rules = [] 209 | adapted_non_terminals = set() 210 | # for line in codecs.open(grammar_file, 'r', encoding='utf-8'): 211 | for line in open(grammar_file, 'r'): 212 | line = line.strip() 213 | if line.startswith("%"): 214 | continue 215 | if line.startswith("@"): 216 | tokens = line.split() 217 | assert (len(tokens) == 5) 218 | adapted_non_terminal = nltk.Nonterminal(tokens[1]) 219 | adapted_non_terminals.add(adapted_non_terminal) 220 | desired_truncation_level[adapted_non_terminal] = int(tokens[2]) 221 | alpha_pi[adapted_non_terminal] = float(tokens[3]) 222 | beta_pi[adapted_non_terminal] = float(tokens[4]) 223 | continue 224 | grammar_rules.append(line) 225 | grammar_rules = "\n".join(grammar_rules) 226 | 227 | # Warning: if you are using nltk 2.x, please use parse_grammar() 228 | # from nltk.grammar import parse_grammar, standard_nonterm_parser 229 | # start, productions = parse_grammar(grammar_rules, standard_nonterm_parser, probabilistic=False) 230 | from nltk.grammar import read_grammar, standard_nonterm_parser 231 | start, productions = read_grammar(grammar_rules, standard_nonterm_parser, probabilistic=False) 232 | 233 | # create output directory 234 | now = datetime.datetime.now() 235 | suffix = now.strftime("%y%b%d-%H%M%S") + "" 236 | # desired_truncation_level_string = "".join(["%s%d" % (symbol, desired_truncation_level[symbol]) for symbol in desired_truncation_level]) 237 | # alpha_pi_string = "".join(["%s%d" % (symbol, alpha_pi[symbol]) for symbol in alpha_pi]) 238 | # beta_pi_string = "".join(["%s%d" % (symbol, beta_pi[symbol]) for symbol in beta_pi]) 239 | # output_directory += "-" + str(now.microsecond) + "/" 240 | suffix += "-D%d-P%d-S%d-B%d-O%d-t%d-k%g-G%s/" % (number_of_documents, 241 | # number_of_topics, 242 | grammaton_prune_interval, 243 | snapshot_interval, 244 | batch_size, 245 | training_iterations, 246 | tau, 247 | kappa, 248 | # alpha_theta, 249 | # alpha_pi_string, 250 | # beta_pi_string, 251 | # desired_truncation_level_string, 252 | os.path.basename(grammar_file) 253 | ) 254 | 255 | output_directory = os.path.join(output_directory, suffix) 256 | os.mkdir(os.path.abspath(output_directory)) 257 | 258 | # store all the options to a input_stream 259 | options_output_file = open(output_directory + "option.txt", 'w') 260 | # parameter set 1 261 | options_output_file.write("input_directory=" + input_directory + "\n") 262 | options_output_file.write("corpus_name=" + corpus_name + "\n") 263 | options_output_file.write("grammar_file=" + str(grammar_file) + "\n") 264 | # parameter set 2 265 | options_output_file.write("number_of_processes=" + str(number_of_processes) + "\n") 266 | # options_output_file.write("multiprocesses=" + str(multiprocesses) + "\n") 267 | options_output_file.write("number_of_documents=" + str(number_of_documents) + "\n") 268 | options_output_file.write("batch_size=" + str(batch_size) + "\n") 269 | options_output_file.write("training_iterations=" + str(training_iterations) + "\n") 270 | 271 | # parameter set 3 272 | 
options_output_file.write("grammaton_prune_interval=" + str(grammaton_prune_interval) + "\n") 273 | options_output_file.write("snapshot_interval=" + str(snapshot_interval) + "\n") 274 | options_output_file.write("tau=" + str(tau) + "\n") 275 | options_output_file.write("kappa=" + str(kappa) + "\n") 276 | 277 | # parameter set 4 278 | # options_output_file.write("alpha_theta=" + str(alpha_theta) + "\n") 279 | options_output_file.write("alpha_pi=%s\n" % alpha_pi) 280 | options_output_file.write("beta_pi=%s\n" % beta_pi) 281 | options_output_file.write("desired_truncation_level=%s\n" % desired_truncation_level) 282 | # parameter set 5 283 | # options_output_file.write("heldout_data=" + str(heldout_data) + "\n") 284 | options_output_file.close() 285 | 286 | print("========== ========== ========== ========== ==========") 287 | # parameter set 1 288 | print(("output_directory=" + output_directory)) 289 | print(("input_directory=" + input_directory)) 290 | print(("corpus_name=" + corpus_name)) 291 | print(("grammar_file=" + str(grammar_file))) 292 | 293 | # parameter set 2 294 | print(("number_of_documents=" + str(number_of_documents))) 295 | print(("batch_size=" + str(batch_size))) 296 | print(("training_iterations=" + str(training_iterations))) 297 | print(("number_of_processes=" + str(number_of_processes))) 298 | # print "multiprocesses=" + str(multiprocesses) 299 | 300 | # parameter set 3 301 | print(("grammaton_prune_interval=" + str(grammaton_prune_interval))) 302 | print(("snapshot_interval=" + str(snapshot_interval))) 303 | print(("tau=" + str(tau))) 304 | print(("kappa=" + str(kappa))) 305 | 306 | # parameter set 4 307 | # print "alpha_theta=" + str(alpha_theta) 308 | print(("alpha_pi=%s" % alpha_pi)) 309 | print(("beta_pi=%s" % beta_pi)) 310 | print(("desired_truncation_level=%s" % desired_truncation_level)) 311 | # parameter set 5 312 | # print "heldout_data=" + str(heldout_data) 313 | print("========== ========== ========== ========== ==========") 314 | 315 | import hybrid 316 | adagram_inferencer = hybrid.Hybrid(start, 317 | productions, 318 | adapted_non_terminals 319 | ) 320 | 321 | adagram_inferencer._initialize(number_of_documents, 322 | batch_size, 323 | tau, 324 | kappa, 325 | alpha_pi, 326 | beta_pi, 327 | None, 328 | desired_truncation_level, 329 | grammaton_prune_interval 330 | ) 331 | 332 | ''' 333 | clock_iteration = time.time() 334 | clock_e_step, clock_m_step = adagram_inferencer.seed(train_docs) 335 | clock_iteration = time.time()-clock_iteration 336 | print 'E-step, M-step and Seed take %g, %g and %g seconds respectively...' 
% (clock_e_step, clock_m_step, clock_iteration)p 337 | ''' 338 | 339 | # adagram_inferencer.export_adaptor_grammar(os.path.join(output_directory, "infag-0")) 340 | # adagram_inferencer.export_aggregated_adaptor_grammar(os.path.join(output_directory, "ag-0")) 341 | 342 | random.shuffle(train_docs) 343 | training_clock = time.time() 344 | snapshot_clock = time.time() 345 | for iteration in range(training_iterations): 346 | start_index = batch_size * iteration 347 | end_index = batch_size * (iteration + 1) 348 | if start_index // number_of_documents < end_index // number_of_documents: 349 | # train_doc_set = train_docs[(batch_size * iteration) % (number_of_documents) :] + train_docs[: (batch_size * (iteration+1)) % (number_of_documents)] 350 | train_doc_set = train_docs[(batch_size * iteration) % (number_of_documents):] 351 | random.shuffle(train_docs) 352 | train_doc_set += train_docs[: (batch_size * (iteration + 1)) % (number_of_documents)] 353 | else: 354 | train_doc_set = train_docs[(batch_size * iteration) % (number_of_documents): (batch_size * ( 355 | iteration + 1)) % number_of_documents] 356 | 357 | clock_iteration = time.time() 358 | # print "processing document:", train_doc_set 359 | clock_e_step, clock_m_step = adagram_inferencer.learning(train_doc_set, number_of_processes) 360 | 361 | if (iteration + 1) % snapshot_interval == 0: 362 | # cpickle_file = open(os.path.join(output_directory, "model-%d" % (adagram_inferencer._counter+1)), 'wb') 363 | # cPickle.dump(adagram_inferencer, cpickle_file) 364 | # cpickle_file.close() 365 | adagram_inferencer.export_adaptor_grammar(os.path.join(output_directory, "adagram-" + str((iteration + 1)))) 366 | # adagram_inferencer.export_aggregated_adaptor_grammar(os.path.join(output_directory, "ag-" + str((iteration+1)))) 367 | 368 | if (iteration + 1) % 1000 == 0: 369 | snapshot_clock = time.time() - snapshot_clock 370 | print(('Processing 1000 mini-batches take %g seconds...' % (snapshot_clock))) 371 | snapshot_clock = time.time() 372 | 373 | clock_iteration = time.time() - clock_iteration 374 | print(('E-step, M-step and iteration %d take %g, %g and %g seconds respectively...' % ( 375 | adagram_inferencer._counter, clock_e_step, clock_m_step, clock_iteration))) 376 | 377 | adagram_inferencer.export_adaptor_grammar( 378 | os.path.join(output_directory, "adagram-" + str(adagram_inferencer._counter + 1))) 379 | # adagram_inferencer.export_aggregated_adaptor_grammar(os.path.join(output_directory, "ag-" + str((iteration+1)))) 380 | 381 | cpickle_file = open(os.path.join(output_directory, "model-%d" % (iteration + 1)), 'wb') 382 | pickle.dump(adagram_inferencer, cpickle_file) 383 | cpickle_file.close() 384 | 385 | training_clock = time.time() - training_clock 386 | print(('Training finished in %g seconds...' % (training_clock))) 387 | 388 | 389 | if __name__ == '__main__': 390 | main() 391 | --------------------------------------------------------------------------------
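The training commands in the README pass a ```--grammar_file``` argument (e.g., ```grammar.unigram```), but the README never shows what such a file looks like. Based on the parsing code in ```launch_train.py```, the file holds plain NLTK-style CFG rules; lines beginning with ```%``` are treated as comments, and each line beginning with ```@``` declares an adapted non-terminal followed by its desired truncation level, alpha_pi, and beta_pi hyper-parameters. The sketch below is purely illustrative: the symbol names and numeric values are invented here, and the actual grammar.unigram bundled in brent-phone.tar.gz may differ.

```
% adapted non-terminal declaration: @ <symbol> <truncation_level> <alpha_pi> <beta_pi>
@ Word 1000 0.01 100.0
Sentence -> Word Sentence
Sentence -> Word
Word -> Chars
Chars -> Char Chars
Chars -> Char
Char -> 'a'
Char -> 'b'
```

The start symbol is taken from the left-hand side of the first rule (here ```Sentence```), and the adapted non-terminal declared with ```@``` (here ```Word```) should also appear in the rules so that its subtrees can be adapted.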
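The README also does not cover ```launch_resume.py```, which continues training from a saved model snapshot. Based on the options defined in its ```parse_args()```, a resume run looks roughly like the following, where ```$MODEL_DIRECTORY``` is a model output directory produced by an earlier training run (its basename encodes the previous settings, and it must contain the matching ```model-N``` pickle snapshot, with N being the previous number of online iterations):

```bash
python -m launch_resume \
    --input_directory=$INPUT_DIRECTORY \
    --corpus_name=$CORPUS_NAME \
    --output_directory=$OUTPUT_DIRECTORY \
    --model_directory=$MODEL_DIRECTORY
```

Note that, unlike ```launch_train.py```, the corpus name is passed separately here and is joined onto ```$INPUT_DIRECTORY``` internally. The optional ```--online_iterations```, ```--snapshot_interval```, and ```--number_of_processes``` flags may be added as well; when omitted or non-positive, the first two fall back to settings recovered from the loaded model and its directory name, and the process count defaults to 0, which appears to select the single-process code path.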