├── .gitignore ├── README.md ├── cloudpickle.py ├── crf.py ├── crf_test.py ├── example.py ├── features.py ├── sample.txt ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | .*.swp 3 | *~ 4 | # C extensions 5 | *.so 6 | 7 | # Packages 8 | *.egg 9 | *.egg-info 10 | dist 11 | build 12 | eggs 13 | parts 14 | bin 15 | var 16 | sdist 17 | develop-eggs 18 | .installed.cfg 19 | lib 20 | lib64 21 | 22 | # Installer logs 23 | pip-log.txt 24 | 25 | # Unit test / coverage reports 26 | .coverage 27 | .tox 28 | nosetests.xml 29 | 30 | # Translations 31 | *.mo 32 | 33 | # Mr Developer 34 | .mr.developer.cfg 35 | .project 36 | .pydevproject 37 | 38 | *.data 39 | *.pickle 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | python-crf 2 | ========== 3 | 4 | Python implementation of linear-chain conditional random fields. 5 | 6 | ## Application 7 | 8 | Used for feature extraction from products: 9 | 10 | 1. Extract keywords from the respective fields. 11 | 2. Train the CRF using the keyword sets. 12 | -------------------------------------------------------------------------------- /cloudpickle.py: -------------------------------------------------------------------------------- 1 | """ 2 | This class is defined to override standard pickle functionality 3 | 4 | The goals of it follow: 5 | -Serialize lambdas and nested functions to compiled byte code 6 | -Deal with main module correctly 7 | -Deal with other non-serializable objects 8 | 9 | It does not include an unpickler, as standard python unpickling suffices. 10 | 11 | This module was extracted from the `cloud` package, developed by `PiCloud, Inc. 12 | <http://www.picloud.com>`_. 13 | 14 | Copyright (c) 2012, Regents of the University of California. 15 | Copyright (c) 2009 `PiCloud, Inc. <http://www.picloud.com>`_. 16 | All rights reserved. 17 | 18 | Redistribution and use in source and binary forms, with or without 19 | modification, are permitted provided that the following conditions 20 | are met: 21 | * Redistributions of source code must retain the above copyright 22 | notice, this list of conditions and the following disclaimer. 23 | * Redistributions in binary form must reproduce the above copyright 24 | notice, this list of conditions and the following disclaimer in the 25 | documentation and/or other materials provided with the distribution. 26 | * Neither the name of the University of California, Berkeley nor the 27 | names of its contributors may be used to endorse or promote 28 | products derived from this software without specific prior written 29 | permission. 30 | 31 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 32 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 33 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 34 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 35 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 36 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 37 | TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 38 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 39 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 40 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 41 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
42 | """ 43 | 44 | 45 | import operator 46 | import os 47 | import pickle 48 | import struct 49 | import sys 50 | import types 51 | from functools import partial 52 | import itertools 53 | from copy_reg import _extension_registry, _inverted_registry, _extension_cache 54 | import new 55 | import dis 56 | import traceback 57 | 58 | #relevant opcodes 59 | STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) 60 | DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) 61 | LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) 62 | GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] 63 | 64 | HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) 65 | EXTENDED_ARG = chr(dis.EXTENDED_ARG) 66 | 67 | import logging 68 | cloudLog = logging.getLogger("Cloud.Transport") 69 | 70 | try: 71 | import ctypes 72 | except (MemoryError, ImportError): 73 | logging.warning('Exception raised on importing ctypes. Likely python bug.. some functionality will be disabled', exc_info = True) 74 | ctypes = None 75 | PyObject_HEAD = None 76 | else: 77 | 78 | # for reading internal structures 79 | PyObject_HEAD = [ 80 | ('ob_refcnt', ctypes.c_size_t), 81 | ('ob_type', ctypes.c_void_p), 82 | ] 83 | 84 | 85 | try: 86 | from cStringIO import StringIO 87 | except ImportError: 88 | from StringIO import StringIO 89 | 90 | # These helper functions were copied from PiCloud's util module. 91 | def islambda(func): 92 | return getattr(func,'func_name') == '' 93 | 94 | def xrange_params(xrangeobj): 95 | """Returns a 3 element tuple describing the xrange start, step, and len 96 | respectively 97 | 98 | Note: Only guarentees that elements of xrange are the same. parameters may 99 | be different. 100 | e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same 101 | though w/ iteration 102 | """ 103 | 104 | xrange_len = len(xrangeobj) 105 | if not xrange_len: #empty 106 | return (0,1,0) 107 | start = xrangeobj[0] 108 | if xrange_len == 1: #one element 109 | return start, 1, 1 110 | return (start, xrangeobj[1] - xrangeobj[0], xrange_len) 111 | 112 | #debug variables intended for developer use: 113 | printSerialization = False 114 | printMemoization = False 115 | 116 | useForcedImports = True #Should I use forced imports for tracking? 117 | 118 | 119 | 120 | class CloudPickler(pickle.Pickler): 121 | 122 | dispatch = pickle.Pickler.dispatch.copy() 123 | savedForceImports = False 124 | savedDjangoEnv = False #hack tro transport django environment 125 | 126 | def __init__(self, file, protocol=None, min_size_to_save= 0): 127 | pickle.Pickler.__init__(self,file,protocol) 128 | self.modules = set() #set of modules needed to depickle 129 | self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env 130 | 131 | def dump(self, obj): 132 | # note: not thread safe 133 | # minimal side-effects, so not fixing 134 | recurse_limit = 3000 135 | base_recurse = sys.getrecursionlimit() 136 | if base_recurse < recurse_limit: 137 | sys.setrecursionlimit(recurse_limit) 138 | self.inject_addons() 139 | try: 140 | return pickle.Pickler.dump(self, obj) 141 | except RuntimeError, e: 142 | if 'recursion' in e.args[0]: 143 | msg = """Could not pickle object as excessively deep recursion required. 
144 | Try _fast_serialization=2 or contact PiCloud support""" 145 | raise pickle.PicklingError(msg) 146 | finally: 147 | new_recurse = sys.getrecursionlimit() 148 | if new_recurse == recurse_limit: 149 | sys.setrecursionlimit(base_recurse) 150 | 151 | def save_buffer(self, obj): 152 | """Fallback to save_string""" 153 | pickle.Pickler.save_string(self,str(obj)) 154 | dispatch[buffer] = save_buffer 155 | 156 | #block broken objects 157 | def save_unsupported(self, obj, pack=None): 158 | raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) 159 | dispatch[types.GeneratorType] = save_unsupported 160 | 161 | #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it 162 | try: 163 | slice(0,1).__reduce__() 164 | except TypeError: #can't pickle - 165 | dispatch[slice] = save_unsupported 166 | 167 | #itertools objects do not pickle! 168 | for v in itertools.__dict__.values(): 169 | if type(v) is type: 170 | dispatch[v] = save_unsupported 171 | 172 | 173 | def save_dict(self, obj): 174 | """hack fix 175 | If the dict is a global, deal with it in a special way 176 | """ 177 | #print 'saving', obj 178 | if obj is __builtins__: 179 | self.save_reduce(_get_module_builtins, (), obj=obj) 180 | else: 181 | pickle.Pickler.save_dict(self, obj) 182 | dispatch[pickle.DictionaryType] = save_dict 183 | 184 | 185 | def save_module(self, obj, pack=struct.pack): 186 | """ 187 | Save a module as an import 188 | """ 189 | #print 'try save import', obj.__name__ 190 | self.modules.add(obj) 191 | self.save_reduce(subimport,(obj.__name__,), obj=obj) 192 | dispatch[types.ModuleType] = save_module #new type 193 | 194 | def save_codeobject(self, obj, pack=struct.pack): 195 | """ 196 | Save a code object 197 | """ 198 | #print 'try to save codeobj: ', obj 199 | args = ( 200 | obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, 201 | obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, 202 | obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars 203 | ) 204 | self.save_reduce(types.CodeType, args, obj=obj) 205 | dispatch[types.CodeType] = save_codeobject #new type 206 | 207 | def save_function(self, obj, name=None, pack=struct.pack): 208 | """ Registered with the dispatch to handle all function types. 209 | 210 | Determines what kind of function obj is (e.g. lambda, defined at 211 | interactive prompt, etc) and handles the pickling appropriately. 
212 | """ 213 | write = self.write 214 | 215 | name = obj.__name__ 216 | modname = pickle.whichmodule(obj, name) 217 | #print 'which gives %s %s %s' % (modname, obj, name) 218 | try: 219 | themodule = sys.modules[modname] 220 | except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__ 221 | modname = '__main__' 222 | 223 | if modname == '__main__': 224 | themodule = None 225 | 226 | if themodule: 227 | self.modules.add(themodule) 228 | 229 | if not self.savedDjangoEnv: 230 | #hack for django - if we detect the settings module, we transport it 231 | django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '') 232 | if django_settings: 233 | django_mod = sys.modules.get(django_settings) 234 | if django_mod: 235 | cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name) 236 | self.savedDjangoEnv = True 237 | self.modules.add(django_mod) 238 | write(pickle.MARK) 239 | self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod) 240 | write(pickle.POP_MARK) 241 | 242 | 243 | # if func is lambda, def'ed at prompt, is in main, or is nested, then 244 | # we'll pickle the actual function object rather than simply saving a 245 | # reference (as is done in default pickler), via save_function_tuple. 246 | if islambda(obj) or obj.func_code.co_filename == '' or themodule == None: 247 | #Force server to import modules that have been imported in main 248 | modList = None 249 | if themodule == None and not self.savedForceImports: 250 | mainmod = sys.modules['__main__'] 251 | if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): 252 | modList = list(mainmod.___pyc_forcedImports__) 253 | self.savedForceImports = True 254 | self.save_function_tuple(obj, modList) 255 | return 256 | else: # func is nested 257 | klass = getattr(themodule, name, None) 258 | if klass is None or klass is not obj: 259 | self.save_function_tuple(obj, [themodule]) 260 | return 261 | 262 | if obj.__dict__: 263 | # essentially save_reduce, but workaround needed to avoid recursion 264 | self.save(_restore_attr) 265 | write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n') 266 | self.memoize(obj) 267 | self.save(obj.__dict__) 268 | write(pickle.TUPLE + pickle.REDUCE) 269 | else: 270 | write(pickle.GLOBAL + modname + '\n' + name + '\n') 271 | self.memoize(obj) 272 | dispatch[types.FunctionType] = save_function 273 | 274 | def save_function_tuple(self, func, forced_imports): 275 | """ Pickles an actual func object. 276 | 277 | A func comprises: code, globals, defaults, closure, and dict. We 278 | extract and save these, injecting reducing functions at certain points 279 | to recreate the func object. Keep in mind that some of these pieces 280 | can contain a ref to the func itself. Thus, a naive save on these 281 | pieces could trigger an infinite loop of save's. To get around that, 282 | we first create a skeleton func object using just the code (this is 283 | safe, since this won't contain a ref to the func), and memoize it as 284 | soon as it's created. The other stuff can then be filled in later. 
285 | """ 286 | save = self.save 287 | write = self.write 288 | 289 | # save the modules (if any) 290 | if forced_imports: 291 | write(pickle.MARK) 292 | save(_modules_to_main) 293 | #print 'forced imports are', forced_imports 294 | 295 | forced_names = map(lambda m: m.__name__, forced_imports) 296 | save((forced_names,)) 297 | 298 | #save((forced_imports,)) 299 | write(pickle.REDUCE) 300 | write(pickle.POP_MARK) 301 | 302 | code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) 303 | 304 | save(_fill_function) # skeleton function updater 305 | write(pickle.MARK) # beginning of tuple that _fill_function expects 306 | 307 | # create a skeleton function object and memoize it 308 | save(_make_skel_func) 309 | save((code, len(closure), base_globals)) 310 | write(pickle.REDUCE) 311 | self.memoize(func) 312 | 313 | # save the rest of the func data needed by _fill_function 314 | save(f_globals) 315 | save(defaults) 316 | save(closure) 317 | save(dct) 318 | write(pickle.TUPLE) 319 | write(pickle.REDUCE) # applies _fill_function on the tuple 320 | 321 | @staticmethod 322 | def extract_code_globals(co): 323 | """ 324 | Find all globals names read or written to by codeblock co 325 | """ 326 | code = co.co_code 327 | names = co.co_names 328 | out_names = set() 329 | 330 | n = len(code) 331 | i = 0 332 | extended_arg = 0 333 | while i < n: 334 | op = code[i] 335 | 336 | i = i+1 337 | if op >= HAVE_ARGUMENT: 338 | oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg 339 | extended_arg = 0 340 | i = i+2 341 | if op == EXTENDED_ARG: 342 | extended_arg = oparg*65536L 343 | if op in GLOBAL_OPS: 344 | out_names.add(names[oparg]) 345 | #print 'extracted', out_names, ' from ', names 346 | return out_names 347 | 348 | def extract_func_data(self, func): 349 | """ 350 | Turn the function into a tuple of data necessary to recreate it: 351 | code, globals, defaults, closure, dict 352 | """ 353 | code = func.func_code 354 | 355 | # extract all global ref's 356 | func_global_refs = CloudPickler.extract_code_globals(code) 357 | if code.co_consts: # see if nested function have any global refs 358 | for const in code.co_consts: 359 | if type(const) is types.CodeType and const.co_names: 360 | func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const)) 361 | # process all variables referenced by global environment 362 | f_globals = {} 363 | for var in func_global_refs: 364 | #Some names, such as class functions are not global - we don't need them 365 | if func.func_globals.has_key(var): 366 | f_globals[var] = func.func_globals[var] 367 | 368 | # defaults requires no processing 369 | defaults = func.func_defaults 370 | 371 | def get_contents(cell): 372 | try: 373 | return cell.cell_contents 374 | except ValueError, e: #cell is empty error on not yet assigned 375 | raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope') 376 | 377 | 378 | # process closure 379 | if func.func_closure: 380 | closure = map(get_contents, func.func_closure) 381 | else: 382 | closure = [] 383 | 384 | # save the dict 385 | dct = func.func_dict 386 | 387 | if printSerialization: 388 | outvars = ['code: ' + str(code) ] 389 | outvars.append('globals: ' + str(f_globals)) 390 | outvars.append('defaults: ' + str(defaults)) 391 | outvars.append('closure: ' + str(closure)) 392 | print 'function ', func, 'is extracted to: ', ', '.join(outvars) 393 | 394 | base_globals = self.globals_ref.get(id(func.func_globals), {}) 395 | 
self.globals_ref[id(func.func_globals)] = base_globals 396 | 397 | return (code, f_globals, defaults, closure, dct, base_globals) 398 | 399 | def save_global(self, obj, name=None, pack=struct.pack): 400 | write = self.write 401 | memo = self.memo 402 | 403 | if name is None: 404 | name = obj.__name__ 405 | 406 | modname = getattr(obj, "__module__", None) 407 | if modname is None: 408 | modname = pickle.whichmodule(obj, name) 409 | 410 | try: 411 | __import__(modname) 412 | themodule = sys.modules[modname] 413 | except (ImportError, KeyError, AttributeError): #should never occur 414 | raise pickle.PicklingError( 415 | "Can't pickle %r: Module %s cannot be found" % 416 | (obj, modname)) 417 | 418 | if modname == '__main__': 419 | themodule = None 420 | 421 | if themodule: 422 | self.modules.add(themodule) 423 | 424 | sendRef = True 425 | typ = type(obj) 426 | #print 'saving', obj, typ 427 | try: 428 | try: #Deal with case when getattribute fails with exceptions 429 | klass = getattr(themodule, name) 430 | except (AttributeError): 431 | if modname == '__builtin__': #new.* are misrepeported 432 | modname = 'new' 433 | __import__(modname) 434 | themodule = sys.modules[modname] 435 | try: 436 | klass = getattr(themodule, name) 437 | except AttributeError, a: 438 | #print themodule, name, obj, type(obj) 439 | raise pickle.PicklingError("Can't pickle builtin %s" % obj) 440 | else: 441 | raise 442 | 443 | except (ImportError, KeyError, AttributeError): 444 | if typ == types.TypeType or typ == types.ClassType: 445 | sendRef = False 446 | else: #we can't deal with this 447 | raise 448 | else: 449 | if klass is not obj and (typ == types.TypeType or typ == types.ClassType): 450 | sendRef = False 451 | if not sendRef: 452 | #note: Third party types might crash this - add better checks! 
453 | d = dict(obj.__dict__) #copy dict proxy to a dict 454 | if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties 455 | d.pop('__dict__',None) 456 | d.pop('__weakref__',None) 457 | 458 | # hack as __new__ is stored differently in the __dict__ 459 | new_override = d.get('__new__', None) 460 | if new_override: 461 | d['__new__'] = obj.__new__ 462 | 463 | self.save_reduce(type(obj),(obj.__name__,obj.__bases__, 464 | d),obj=obj) 465 | #print 'internal reduce dask %s %s' % (obj, d) 466 | return 467 | 468 | if self.proto >= 2: 469 | code = _extension_registry.get((modname, name)) 470 | if code: 471 | assert code > 0 472 | if code <= 0xff: 473 | write(pickle.EXT1 + chr(code)) 474 | elif code <= 0xffff: 475 | write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8)) 476 | else: 477 | write(pickle.EXT4 + pack("= 2 and getattr(func, "__name__", "") == "__newobj__": 588 | #Added fix to allow transient 589 | cls = args[0] 590 | if not hasattr(cls, "__new__"): 591 | raise pickle.PicklingError( 592 | "args[0] from __newobj__ args has no __new__") 593 | if obj is not None and cls is not obj.__class__: 594 | raise pickle.PicklingError( 595 | "args[0] from __newobj__ args has the wrong class") 596 | args = args[1:] 597 | save(cls) 598 | 599 | #Don't pickle transient entries 600 | if hasattr(obj, '__transient__'): 601 | transient = obj.__transient__ 602 | state = state.copy() 603 | 604 | for k in list(state.keys()): 605 | if k in transient: 606 | del state[k] 607 | 608 | save(args) 609 | write(pickle.NEWOBJ) 610 | else: 611 | save(func) 612 | save(args) 613 | write(pickle.REDUCE) 614 | 615 | if obj is not None: 616 | self.memoize(obj) 617 | 618 | # More new special cases (that work with older protocols as 619 | # well): when __reduce__ returns a tuple with 4 or 5 items, 620 | # the 4th and 5th item should be iterators that provide list 621 | # items and dict items (as (key, value) tuples), or None. 622 | 623 | if listitems is not None: 624 | self._batch_appends(listitems) 625 | 626 | if dictitems is not None: 627 | self._batch_setitems(dictitems) 628 | 629 | if state is not None: 630 | #print 'obj %s has state %s' % (obj, state) 631 | save(state) 632 | write(pickle.BUILD) 633 | 634 | 635 | def save_xrange(self, obj): 636 | """Save an xrange object in python 2.5 637 | Python 2.6 supports this natively 638 | """ 639 | range_params = xrange_params(obj) 640 | self.save_reduce(_build_xrange,range_params) 641 | 642 | #python2.6+ supports xrange pickling. some py2.5 extensions might as well. 
We just test it 643 | try: 644 | xrange(0).__reduce__() 645 | except TypeError: #can't pickle -- use PiCloud pickler 646 | dispatch[xrange] = save_xrange 647 | 648 | def save_partial(self, obj): 649 | """Partial objects do not serialize correctly in python2.x -- this fixes the bugs""" 650 | self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords)) 651 | 652 | if sys.version_info < (2,7): #2.7 supports partial pickling 653 | dispatch[partial] = save_partial 654 | 655 | 656 | def save_file(self, obj): 657 | """Save a file""" 658 | import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute 659 | from ..transport.adapter import SerializingAdapter 660 | 661 | if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): 662 | raise pickle.PicklingError("Cannot pickle files that do not map to an actual file") 663 | if obj.name == '': 664 | return self.save_reduce(getattr, (sys,'stdout'), obj=obj) 665 | if obj.name == '': 666 | return self.save_reduce(getattr, (sys,'stderr'), obj=obj) 667 | if obj.name == '': 668 | raise pickle.PicklingError("Cannot pickle standard input") 669 | if hasattr(obj, 'isatty') and obj.isatty(): 670 | raise pickle.PicklingError("Cannot pickle files that map to tty objects") 671 | if 'r' not in obj.mode: 672 | raise pickle.PicklingError("Cannot pickle files that are not opened for reading") 673 | name = obj.name 674 | try: 675 | fsize = os.stat(name).st_size 676 | except OSError: 677 | raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name) 678 | 679 | if obj.closed: 680 | #create an empty closed string io 681 | retval = pystringIO.StringIO("") 682 | retval.close() 683 | elif not fsize: #empty file 684 | retval = pystringIO.StringIO("") 685 | try: 686 | tmpfile = file(name) 687 | tst = tmpfile.read(1) 688 | except IOError: 689 | raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) 690 | tmpfile.close() 691 | if tst != '': 692 | raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name) 693 | elif fsize > SerializingAdapter.max_transmit_data: 694 | raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" % 695 | (name,SerializingAdapter.max_transmit_data)) 696 | else: 697 | try: 698 | tmpfile = file(name) 699 | contents = tmpfile.read(SerializingAdapter.max_transmit_data) 700 | tmpfile.close() 701 | except IOError: 702 | raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) 703 | retval = pystringIO.StringIO(contents) 704 | curloc = obj.tell() 705 | retval.seek(curloc) 706 | 707 | retval.name = name 708 | self.save(retval) #save stringIO 709 | self.memoize(obj) 710 | 711 | dispatch[file] = save_file 712 | """Special functions for Add-on libraries""" 713 | 714 | def inject_numpy(self): 715 | numpy = sys.modules.get('numpy') 716 | if not numpy or not hasattr(numpy, 'ufunc'): 717 | return 718 | self.dispatch[numpy.ufunc] = self.__class__.save_ufunc 719 | 720 | numpy_tst_mods = ['numpy', 'scipy.special'] 721 | def save_ufunc(self, obj): 722 | """Hack function for saving numpy ufunc objects""" 723 | name = obj.__name__ 724 | for tst_mod_name in self.numpy_tst_mods: 725 | tst_mod = sys.modules.get(tst_mod_name, None) 726 | if tst_mod: 727 | if name in tst_mod.__dict__: 728 | self.save_reduce(_getobject, (tst_mod_name, name)) 729 | return 730 | raise pickle.PicklingError('cannot save %s. 
Cannot resolve what module it is defined in' % str(obj)) 731 | 732 | def inject_timeseries(self): 733 | """Handle bugs with pickling scikits timeseries""" 734 | tseries = sys.modules.get('scikits.timeseries.tseries') 735 | if not tseries or not hasattr(tseries, 'Timeseries'): 736 | return 737 | self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries 738 | 739 | def save_timeseries(self, obj): 740 | import scikits.timeseries.tseries as ts 741 | 742 | func, reduce_args, state = obj.__reduce__() 743 | if func != ts._tsreconstruct: 744 | raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func)) 745 | state = (1, 746 | obj.shape, 747 | obj.dtype, 748 | obj.flags.fnc, 749 | obj._data.tostring(), 750 | ts.getmaskarray(obj).tostring(), 751 | obj._fill_value, 752 | obj._dates.shape, 753 | obj._dates.__array__().tostring(), 754 | obj._dates.dtype, #added -- preserve type 755 | obj.freq, 756 | obj._optinfo, 757 | ) 758 | return self.save_reduce(_genTimeSeries, (reduce_args, state)) 759 | 760 | def inject_email(self): 761 | """Block email LazyImporters from being saved""" 762 | email = sys.modules.get('email') 763 | if not email: 764 | return 765 | self.dispatch[email.LazyImporter] = self.__class__.save_unsupported 766 | 767 | def inject_addons(self): 768 | """Plug in system. Register additional pickling functions if modules already loaded""" 769 | self.inject_numpy() 770 | self.inject_timeseries() 771 | self.inject_email() 772 | 773 | """Python Imaging Library""" 774 | def save_image(self, obj): 775 | if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \ 776 | and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()): 777 | #if image not loaded yet -- lazy load 778 | self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj) 779 | else: 780 | #image is loaded - just transmit it over 781 | self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj) 782 | 783 | """ 784 | def memoize(self, obj): 785 | pickle.Pickler.memoize(self, obj) 786 | if printMemoization: 787 | print 'memoizing ' + str(obj) 788 | """ 789 | 790 | 791 | 792 | # Shorthands for legacy support 793 | 794 | def dump(obj, file, protocol=2): 795 | CloudPickler(file, protocol).dump(obj) 796 | 797 | def dumps(obj, protocol=2): 798 | file = StringIO() 799 | 800 | cp = CloudPickler(file,protocol) 801 | cp.dump(obj) 802 | 803 | #print 'cloud dumped', str(obj), str(cp.modules) 804 | 805 | return file.getvalue() 806 | 807 | 808 | #hack for __import__ not working as desired 809 | def subimport(name): 810 | __import__(name) 811 | return sys.modules[name] 812 | 813 | #hack to load django settings: 814 | def django_settings_load(name): 815 | modified_env = False 816 | 817 | if 'DJANGO_SETTINGS_MODULE' not in os.environ: 818 | os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps 819 | modified_env = True 820 | try: 821 | module = subimport(name) 822 | except Exception, i: 823 | print >> sys.stderr, 'Cloud not import django settings %s:' % (name) 824 | print_exec(sys.stderr) 825 | if modified_env: 826 | del os.environ['DJANGO_SETTINGS_MODULE'] 827 | else: 828 | #add project directory to sys,path: 829 | if hasattr(module,'__file__'): 830 | dirname = os.path.split(module.__file__)[0] + '/' 831 | sys.path.append(dirname) 832 | 833 | # restores function attributes 834 | def _restore_attr(obj, attr): 835 | for key, val in attr.items(): 836 | setattr(obj, key, val) 837 | return obj 838 | 839 | def _get_module_builtins(): 840 | return 
pickle.__builtins__ 841 | 842 | def print_exec(stream): 843 | ei = sys.exc_info() 844 | traceback.print_exception(ei[0], ei[1], ei[2], None, stream) 845 | 846 | def _modules_to_main(modList): 847 | """Force every module in modList to be placed into main""" 848 | if not modList: 849 | return 850 | 851 | main = sys.modules['__main__'] 852 | for modname in modList: 853 | if type(modname) is str: 854 | try: 855 | mod = __import__(modname) 856 | except Exception, i: #catch all... 857 | sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \ 858 | A version mismatch is likely. Specific error was:\n' % modname) 859 | print_exec(sys.stderr) 860 | else: 861 | setattr(main,mod.__name__, mod) 862 | else: 863 | #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD) 864 | #In old version actual module was sent 865 | setattr(main,modname.__name__, modname) 866 | 867 | #object generators: 868 | def _build_xrange(start, step, len): 869 | """Built xrange explicitly""" 870 | return xrange(start, start + step*len, step) 871 | 872 | def _genpartial(func, args, kwds): 873 | if not args: 874 | args = () 875 | if not kwds: 876 | kwds = {} 877 | return partial(func, *args, **kwds) 878 | 879 | 880 | def _fill_function(func, globals, defaults, closure, dict): 881 | """ Fills in the rest of function data into the skeleton function object 882 | that were created via _make_skel_func(). 883 | """ 884 | func.func_globals.update(globals) 885 | func.func_defaults = defaults 886 | func.func_dict = dict 887 | 888 | if len(closure) != len(func.func_closure): 889 | raise pickle.UnpicklingError("closure lengths don't match up") 890 | for i in range(len(closure)): 891 | _change_cell_value(func.func_closure[i], closure[i]) 892 | 893 | return func 894 | 895 | def _make_skel_func(code, num_closures, base_globals = None): 896 | """ Creates a skeleton function object that contains just the provided 897 | code and the correct number of cells in func_closure. All other 898 | func attributes (e.g. func_globals) are empty. 
899 | """ 900 | #build closure (cells): 901 | if not ctypes: 902 | raise Exception('ctypes failed to import; cannot build function') 903 | 904 | cellnew = ctypes.pythonapi.PyCell_New 905 | cellnew.restype = ctypes.py_object 906 | cellnew.argtypes = (ctypes.py_object,) 907 | dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures))) 908 | 909 | if base_globals is None: 910 | base_globals = {} 911 | base_globals['__builtins__'] = __builtins__ 912 | 913 | return types.FunctionType(code, base_globals, 914 | None, None, dummy_closure) 915 | 916 | # this piece of opaque code is needed below to modify 'cell' contents 917 | cell_changer_code = new.code( 918 | 1, 1, 2, 0, 919 | ''.join([ 920 | chr(dis.opmap['LOAD_FAST']), '\x00\x00', 921 | chr(dis.opmap['DUP_TOP']), 922 | chr(dis.opmap['STORE_DEREF']), '\x00\x00', 923 | chr(dis.opmap['RETURN_VALUE']) 924 | ]), 925 | (), (), ('newval',), '', 'cell_changer', 1, '', ('c',), () 926 | ) 927 | 928 | def _change_cell_value(cell, newval): 929 | """ Changes the contents of 'cell' object to newval """ 930 | return new.function(cell_changer_code, {}, None, (), (cell,))(newval) 931 | 932 | """Constructors for 3rd party libraries 933 | Note: These can never be renamed due to client compatibility issues""" 934 | 935 | def _getobject(modname, attribute): 936 | mod = __import__(modname) 937 | return mod.__dict__[attribute] 938 | 939 | def _generateImage(size, mode, str_rep): 940 | """Generate image from string representation""" 941 | import Image 942 | i = Image.new(mode, size) 943 | i.fromstring(str_rep) 944 | return i 945 | 946 | def _lazyloadImage(fp): 947 | import Image 948 | fp.seek(0) #works in almost any case 949 | return Image.open(fp) 950 | 951 | """Timeseries""" 952 | def _genTimeSeries(reduce_args, state): 953 | import scikits.timeseries.tseries as ts 954 | from numpy import ndarray 955 | from numpy.ma import MaskedArray 956 | 957 | 958 | time_series = ts._tsreconstruct(*reduce_args) 959 | 960 | #from setstate modified 961 | (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state 962 | #print 'regenerating %s' % dtyp 963 | 964 | MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv)) 965 | _dates = time_series._dates 966 | #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ 967 | ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm)) 968 | _dates.freq = frq 969 | _dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None, 970 | toobj=None, toord=None, tostr=None)) 971 | # Update the _optinfo dictionary 972 | time_series._optinfo.update(infodict) 973 | return time_series 974 | 975 | -------------------------------------------------------------------------------- /crf.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO: 3 | x modify function so that it returns (func, gradient) 4 | x implement regularisation 5 | - implement viterbi 6 | """ 7 | import marshal 8 | import numpy as np 9 | from scipy import misc,optimize 10 | 11 | START = '|-' 12 | END = '-|' 13 | 14 | def log_dot_vm(loga,logM): 15 | return misc.logsumexp(loga.reshape(loga.shape+(1,))+logM,axis=0) 16 | def log_dot_mv(logM,logb): 17 | return misc.logsumexp(logM+logb.reshape((1,)+logb.shape),axis=1) 18 | 19 | class CRF: 20 | def __init__(self,feature_functions,labels,sigma=10,transition_feature=True): 21 | self.ft_fun = feature_functions 22 | 23 | self.labels = [START] + labels + [ END ] 24 | if transition_feature: 25 | self.ft_fun = self.ft_fun + 
Transitions.functions(self.labels[1:],self.labels[:-1]) 26 | self.theta = np.random.randn(len(self.ft_fun)) 27 | 28 | self.label_id = { l:i for i,l in enumerate(self.labels) } 29 | self.v = sigma ** 2 30 | self.v2 = self.v * 2 31 | 32 | def regulariser(self,w): 33 | return np.sum(w ** 2) /self.v2 34 | def regulariser_deriv(self,w): 35 | return w / self.v # elementwise gradient of sum(w**2)/(2*sigma**2) is w/sigma**2 36 | 37 | def all_features(self,x_vec): 38 | """ 39 | Axes: 40 | 0 - T or time or sequence index 41 | 1 - y' or previous label 42 | 2 - y or current label 43 | 3 - k or feature index, holding f_k(y',y,x_vec,i) 44 | """ 45 | result = np.zeros((len(x_vec)+1,len(self.labels),len(self.labels),len(self.ft_fun))) 46 | for i in range(len(x_vec)+1): 47 | for j,yp in enumerate(self.labels): 48 | for k,y in enumerate(self.labels): 49 | for l,f in enumerate(self.ft_fun): 50 | result[i,j,k,l] = f(yp,y,x_vec,i) 51 | return result 52 | 53 | def forward(self,M,start=0): 54 | alphas = np.NINF*np.ones((M.shape[0],M.shape[1])) 55 | alpha = alphas[0] 56 | alpha[start] = 0 57 | for i in range(M.shape[0]-1): 58 | alpha = alphas[i+1] = log_dot_vm(alpha,M[i]) 59 | alpha = log_dot_vm(alpha,M[-1]) 60 | return (alphas,alpha) 61 | 62 | def backward(self,M,end=-1): 63 | #betas = np.NINF*np.ones((M.shape[0],M.shape[1])) 64 | betas = np.zeros((M.shape[0],M.shape[1])) 65 | beta = betas[-1] 66 | beta[end] = 0 67 | for i in reversed(range(M.shape[0]-1)): 68 | beta = betas[i] = log_dot_mv(M[i+1],beta) 69 | beta = log_dot_mv(M[0],beta) 70 | return (betas,beta) 71 | 72 | def create_vector_list(self,x_vecs,y_vecs): 73 | print len(x_vecs) 74 | observations = [ self.all_features(x_vec) for x_vec in x_vecs ] 75 | labels = len(y_vecs)*[None] 76 | 77 | for i in range(len(y_vecs)): 78 | assert(len(y_vecs[i]) == len(x_vecs[i])) 79 | y_vecs[i].insert(0,START) 80 | y_vecs[i].append(END) 81 | labels[i] = np.array([ self.label_id[y] for y in y_vecs[i] ],copy=False,dtype=np.int) 82 | 83 | return (observations,labels) 84 | 85 | def neg_likelihood_and_deriv(self,x_vec_list,y_vec_list,theta,debug=False): 86 | likelihood = 0 87 | derivative = np.zeros(len(self.theta)) 88 | for x_vec,y_vec in zip(x_vec_list,y_vec_list): 89 | """ 90 | all_features: len(x_vec) + 1 x Y x Y x K 91 | M: len(x_vec) + 1 x Y x Y 92 | alphas: len(x_vec) + 1 x Y 93 | betas: len(x_vec) + 1 x Y 94 | log_probs: len(x_vec) + 1 x Y x Y (Y is the size of the state space) 95 | `unnormalised` value here is alpha * M * beta, an unnormalised probability 96 | """ 97 | all_features = x_vec 98 | length = x_vec.shape[0] 99 | #y_vec = [START] + y_vec + [END] 100 | yp_vec_ids = y_vec[:-1] 101 | y_vec_ids = y_vec[1:] 102 | log_M = np.dot(all_features,theta) 103 | log_alphas,last = self.forward(log_M,self.label_id[START]) 104 | log_betas, zero = self.backward(log_M,self.label_id[END]) 105 | time,state = log_alphas.shape 106 | """ 107 | Reshaping allows me to do the entire computation of the unnormalised 108 | probabilities in one step, which means it's faster, because it's done 109 | in numpy 110 | """ 111 | log_alphas1 = log_alphas.reshape(time,state,1) 112 | log_betas1 = log_betas.reshape(time,1,state) 113 | log_Z = misc.logsumexp(last) 114 | log_probs = log_alphas1 + log_M + log_betas1 - log_Z 115 | log_probs = log_probs.reshape(log_probs.shape+(1,)) 116 | """ 117 | Find the expected value of f_k over all transitions 118 | and empirical values 119 | (numpy makes it so easy, only if you do it right) 120 | """ 121 | exp_features = np.sum( np.exp(log_probs) * all_features, axis= (0,1,2) ) 122 | emp_features = np.sum( 
all_features[range(length),yp_vec_ids,y_vec_ids], axis = 0 ) 123 | 124 | likelihood += np.sum(log_M[range(length),yp_vec_ids,y_vec_ids]) - log_Z 125 | derivative += emp_features - exp_features 126 | if debug: 127 | print "EmpFeatures:" 128 | print emp_features 129 | print "ExpFeatures:" 130 | print exp_features 131 | 132 | return ( 133 | - ( likelihood - self.regulariser(theta)), 134 | - ( derivative - self.regulariser_deriv(theta)) 135 | ) 136 | 137 | def predict(self,x_vec, debug=False): 138 | # small overhead, no copying is done 139 | """ 140 | all_features: len(x_vec+1) x Y' x Y x K 141 | log_potential: len(x_vec+1) x Y' x Y 142 | argmaxes: len(x_vec+1) x Y' 143 | """ 144 | all_features = self.all_features(x_vec) 145 | log_potential = np.dot(all_features,self.theta) 146 | return [ self.labels[i] for i in self.slow_predict(log_potential,len(x_vec),len(self.labels)) ] 147 | 148 | def slow_predict(self,log_potential,N,K,debug=False): 149 | """ 150 | Find the most likely assignment to labels given parameters using the 151 | Viterbi algorithm. 152 | """ 153 | g0 = log_potential[0,0] 154 | g = log_potential[1:] 155 | 156 | B = np.ones((N,K), dtype=np.int32) * -1 157 | # compute max-marginals and backtrace matrix 158 | V = g0 159 | for t in xrange(1,N): 160 | U = np.empty(K) 161 | for y in xrange(K): 162 | w = V + g[t-1,:,y] 163 | B[t,y] = b = w.argmax() 164 | U[y] = w[b] 165 | V = U 166 | # extract the best path by brack-tracking 167 | y = V.argmax() 168 | trace = [] 169 | for t in reversed(xrange(N)): 170 | trace.append(y) 171 | y = B[t, y] 172 | trace.reverse() 173 | return trace 174 | 175 | def log_predict(self,log_potential,N,K,debug=False): 176 | if debug: 177 | print 178 | print 179 | print "Log Potentials:" 180 | print log_potential 181 | print 182 | print 183 | prev_state = log_potential[0,self.label_id[START]] 184 | prev_state_v = prev_state.reshape((K,1)) 185 | argmaxes = np.zeros((N,K),dtype=np.int) 186 | if debug: 187 | print "T=0" 188 | print prev_state 189 | print 190 | for i in range(1,N): 191 | curr_state = prev_state_v + log_potential[i] 192 | argmaxes[i] = np.nanargmax(curr_state,axis=0) 193 | prev_state[:] = curr_state[argmaxes[i],range(K)] 194 | if debug: 195 | print 196 | print "T=%d"%i 197 | print curr_state 198 | print prev_state 199 | print argmaxes[i] 200 | print 201 | curr_state = prev_state + log_potential[-1,self.label_id[END]] 202 | prev_label = np.argmax(curr_state) 203 | if debug: print prev_label 204 | result = [] 205 | for i in reversed(range(N)): 206 | if debug:print result 207 | result.append(prev_label) 208 | prev_label = argmaxes[i,prev_label] 209 | result.reverse() 210 | return result 211 | 212 | def train(self,x_vecs,y_vecs,debug=False): 213 | vectorised_x_vecs,vectorised_y_vecs = self.create_vector_list(x_vecs,y_vecs) 214 | l = lambda theta: self.neg_likelihood_and_deriv(vectorised_x_vecs,vectorised_y_vecs,theta) 215 | val = optimize.fmin_l_bfgs_b(l,self.theta) 216 | if debug: print val 217 | self.theta,_,_ = val 218 | return self.theta 219 | 220 | 221 | class FeatureSet(object): 222 | @classmethod 223 | def functions(cls,lbls,*arguments): 224 | def gen(): 225 | for lbl in lbls: 226 | for arg in arguments: 227 | if isinstance(arg,tuple): 228 | yield cls(lbl,*arg) 229 | else: 230 | yield cls(lbl,arg) 231 | return list(gen()) 232 | def __repr__(self): 233 | return "%s(%s)"%(self.__class__.__name__,self.__dict__) 234 | 235 | class Transitions(FeatureSet): 236 | def __init__(self,curr_lbl,prev_lbl): 237 | self.prev_label = prev_lbl 238 | self.label = 
curr_lbl 239 | 240 | def __call__(self,yp,y,x_v,i): 241 | if yp==self.prev_label and y==self.label: 242 | return 1 243 | else: 244 | return 0 245 | 246 | -------------------------------------------------------------------------------- /crf_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import misc 3 | 4 | def logdotexp_vec_mat(loga, logM): 5 | return misc.logsumexp(loga.reshape(loga.shape + (1,))+logM,axis=0) 6 | 7 | def logdotexp_mat_vec(logM, logb): 8 | old = np.array([misc.logsumexp(x + logb) for x in logM], copy=False) 9 | new = misc.logsumexp(logM+logb.reshape((1,)+logb.shape),axis=1) 10 | print "N:",new 11 | print "O:",old 12 | return old 13 | 14 | def logalphas(Mlist): 15 | logalpha = Mlist[0][0] # alpha(1) 16 | logalphas = [logalpha] 17 | for logM in Mlist[1:]: 18 | logalpha = logdotexp_vec_mat(logalpha, logM) 19 | logalphas.append(logalpha) 20 | return logalphas 21 | 22 | def logbetas(Mlist): 23 | logbeta = Mlist[-1][:,2] 24 | logbetas = [logbeta] 25 | for logM in Mlist[-2::-1]: # reverse 26 | logbeta = logdotexp_mat_vec(logM, logbeta) 27 | logbetas.append(logbeta) 28 | return logbetas[::-1] 29 | M = [ 30 | np.log(np.array( 31 | [[0.2,0.2,0.6], 32 | [0.2,0.6,0.2], 33 | [0.7,0.0,0.3]])), 34 | np.log(np.array( 35 | [[0.0,0.0,0.0], 36 | [0.2,0.6,0.2], 37 | [0.7,0.1,0.2]])), 38 | np.log(np.array( 39 | [[0.1,0.2,0.7], 40 | [0.2,0.6,0.2], 41 | [0.7,0.1,0.2]])), 42 | ] 43 | alphas = np.exp(np.array(logalphas(M))) 44 | betas = np.exp(np.array(logbetas(M))) 45 | print "Alpha:" 46 | print alphas 47 | print "Beta:" 48 | print betas 49 | 50 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | from crf import * 2 | from collections import defaultdict 3 | import re 4 | import sys 5 | 6 | word_data = [] 7 | label_data = [] 8 | all_labels = set() 9 | word_sets = defaultdict(set) 10 | obsrvs = set() 11 | for line in open(sys.argv[1]): 12 | words,labels = [],[] 13 | for token in line.strip().split(): 14 | word,label= token.split('/') 15 | all_labels.add(label) 16 | word_sets[label].add(word.lower()) 17 | obsrvs.add(word.lower()) 18 | words.append(word) 19 | labels.append(label) 20 | 21 | word_data.append(words) 22 | label_data.append(labels) 23 | if __name__ == "__main__": 24 | labels = list(all_labels) 25 | lbls = [START] + labels + [END] 26 | transition_functions = [ 27 | lambda yp,y,x_v,i,_yp=_yp,_y=_y: 1 if yp==_yp and y==_y else 0 28 | for _yp in lbls[:-1] for _y in lbls[1:] 29 | ] 30 | """ 31 | observation_functions = [ 32 | lambda yp,y,x_v,i,_yp=_yp,_y=_y: 1 if i < len(x_v) and y==_y and x_v[i].lower()==_o else 0 33 | for _y in lbls[1:] 34 | for _o in obsrvs 35 | ] 36 | """ 37 | def set_membership(tag): 38 | def fun(yp,y,x_v,i): 39 | if i < len(x_v) and x_v[i].lower() in word_sets[tag]: 40 | return 1 41 | else: 42 | return 0 43 | return fun 44 | observation_functions = [set_membership(t) for t in word_sets ] 45 | misc_functions = [ 46 | lambda yp,y,x_v,i: 1 if i < len(x_v) and re.match('^[^0-9a-zA-Z]+$',x_v[i]) else 0, 47 | lambda yp,y,x_v,i: 1 if i < len(x_v) and re.match('^[A-Z\.]+$',x_v[i]) else 0, 48 | lambda yp,y,x_v,i: 1 if i < len(x_v) and re.match('^[0-9\.]+$',x_v[i]) else 0 49 | ] 50 | tagval_functions = [ 51 | lambda yp,y,x_v,i,_y=_y,_x=_x: 1 if i < len(x_v) and y==_y and x_v[i].lower() ==_x else 0 52 | for _y in labels 53 | for _x in obsrvs] 54 | crf = CRF( labels = labels, 
55 | feature_functions = transition_functions + tagval_functions + observation_functions + misc_functions ) 56 | vectorised_x_vecs,vectorised_y_vecs = crf.create_vector_list(word_data,label_data) 57 | l = lambda theta: crf.neg_likelihood_and_deriv(vectorised_x_vecs,vectorised_y_vecs,theta) 58 | #crf.theta = optimize.fmin_bfgs(l, crf.theta, maxiter=100) 59 | print "Minimizing..." 60 | def print_value(theta): 61 | print crf.neg_likelihood_and_deriv(vectorised_x_vecs,vectorised_y_vecs,theta) 62 | val = optimize.fmin_l_bfgs_b(l, crf.theta,callback=print_value) 63 | print val 64 | theta,_,_ = val 65 | """ 66 | theta = crf.theta 67 | for _ in range(10000): 68 | value, gradient = l(theta) 69 | print value 70 | theta = theta - 0.1*gradient 71 | """ 72 | crf.theta = theta 73 | print crf.neg_likelihood_and_deriv(vectorised_x_vecs,vectorised_y_vecs,theta) 74 | print 75 | print "Latest:" 76 | for x_vec in word_data[-5:]: 77 | print x_vec 78 | print crf.predict(x_vec) 79 | -------------------------------------------------------------------------------- /features.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import defaultdict 3 | from crf import FeatureSet 4 | alphas = re.compile('^[a-zA-Z]+$') 5 | 6 | def fit_dataset(filename): 7 | labels = set() 8 | obsrvs = set() 9 | word_sets = defaultdict(set) 10 | 11 | sents_words = [] 12 | sents_labels = [] 13 | 14 | for line in open(filename,'r'): 15 | sent_words = [] 16 | sent_labels = [] 17 | try: 18 | for token in line.strip().split(): 19 | word,label= token.rsplit('/',2) 20 | if alphas.match(word): 21 | orig_word = word 22 | word = word.lower() 23 | labels.add(label) 24 | obsrvs.add(word) 25 | word_sets[label].add(word) 26 | sent_words.append(orig_word) 27 | sent_labels.append(label) 28 | else: 29 | continue 30 | sents_words.append(sent_words) 31 | sents_labels.append(sent_labels) 32 | except Exception: 33 | print line 34 | return (labels,obsrvs,word_sets,sents_words,sents_labels) 35 | 36 | class Membership(FeatureSet): 37 | def __init__(self,label,word_set): 38 | self.label = label 39 | self.word_set = word_set 40 | def __call__(self,yp,y,x_v,i): 41 | if i < len(x_v) and y == self.label and (x_v[i].lower() in self.word_set): 42 | return 1 43 | else: 44 | return 0 45 | class FileMembership(Membership): 46 | @classmethod 47 | def functions(cls,lbls,*filenames): 48 | sets = [ 49 | set([ line.strip().lower() for line in open(filename,'r') ]) 50 | for filename in filenames 51 | ] 52 | return super(FileMembership, cls).functions(lbls,*sets) 53 | 54 | class MatchRegex(FeatureSet): 55 | def __init__(self,label,regex): 56 | self.label = label 57 | self.regex = re.compile(regex) 58 | def __call__(self,yp,y,x_v,i): 59 | if i < len(x_v) and y==self.label and self.regex.match(x_v[i]): 60 | return 1 61 | else: 62 | return 0 63 | 64 | if __name__ == "__main__": 65 | val = Membership.functions(['HI','HO'],set(['hi']),set(['ho'])) + MatchRegex.functions(['HI','HO'],'\w+','\d+') 66 | print val[0]('HO','HO',['hi'],0) 67 | print val[0]('HO','HI',['hi'],0) 68 | -------------------------------------------------------------------------------- /sample.txt: -------------------------------------------------------------------------------- 1 | Confidence/NN in/IN the/DT pound/NN is/VBZ widely/RB expected/VBN to/TO take/VB another/DT sharp/JJ dive/NN if/IN trade/NN figures/NNS for/IN September/NNP ,/, due/JJ for/IN release/NN tomorrow/NN ,/, fail/VB to/TO show/VB a/DT substantial/JJ improvement/NN from/IN 
July/NNP and/CC August/NNP 's/POS near-record/JJ deficits/NNS ./. 2 | Chancellor/NNP of/IN the/DT Exchequer/NNP Nigel/NNP Lawson/NNP 's/POS restated/VBN commitment/NN to/TO a/DT firm/NN monetary/JJ policy/NN has/VBZ helped/VBN to/TO prevent/VB a/DT freefall/NN in/IN sterling/NN over/IN the/DT past/JJ week/NN ./. 3 | But/CC analysts/NNS reckon/VBP underlying/VBG support/NN for/IN sterling/NN has/VBZ been/VBN eroded/VBN by/IN the/DT chancellor/NN 's/POS failure/NN to/TO announce/VB any/DT new/JJ policy/NN measures/NNS in/IN his/PRP$ Mansion/NNP House/NNP speech/NN last/JJ Thursday/NNP ./. 4 | This/DT has/VBZ increased/VBN the/DT risk/NN of/IN the/DT government/NN being/VBG forced/VBN to/TO increase/VB base/NN rates/NNS to/TO 16/CD %/NN from/IN their/PRP$ current/JJ 15/CD %/NN level/NN to/TO defend/VB the/DT pound/NN ,/, economists/NNS and/CC foreign/JJ exchange/NN market/NN analysts/NNS say/VBP ./. 5 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import crf 4 | from scipy import misc 5 | from collections import defaultdict 6 | from numpy import empty, zeros, ones, log, exp, sqrt, add, int32, abs 7 | def argmax(X): 8 | """ 9 | Find the most likely assignment to labels given parameters using the 10 | Viterbi algorithm. 11 | """ 12 | N,K,_ = X.shape 13 | g0 = X[0,0] 14 | g = X[1:] 15 | 16 | B = ones((N,K), dtype=int32) * -1 17 | # compute max-marginals and backtrace matrix 18 | V = g0 19 | for t in xrange(1,N): 20 | U = empty(K) 21 | for y in xrange(K): 22 | w = V + g[t-1,:,y] 23 | B[t,y] = b = w.argmax() 24 | U[y] = w[b] 25 | V = U 26 | # extract the best path by brack-tracking 27 | y = V.argmax() 28 | trace = [] 29 | for t in reversed(xrange(N)): 30 | trace.append(y) 31 | y = B[t, y] 32 | trace.reverse() 33 | return trace 34 | 35 | def forward(g0, g, N, K): 36 | """ 37 | Calculate matrix of forward unnormalized log-probabilities. 38 | 39 | a[i,y] log of the sum of scores of all sequences from 0 to i where 40 | the label at position i is y. 41 | """ 42 | a = np.zeros((N,K)) 43 | a[0,:] = g0 44 | for t in xrange(1,N): 45 | ayp = a[t-1,:] 46 | for y in xrange(K): 47 | a[t,y] = misc.logsumexp(ayp + g[t-1,:,y]) 48 | return a 49 | 50 | def backward(g, N, K): 51 | """ Calculate matrix of backward unnormalized log-probabilities. """ 52 | b = np.zeros((N,K)) 53 | for t in reversed(xrange(0,N-1)): 54 | by = b[t+1,:] 55 | for yp in xrange(K): 56 | b[t,yp] = misc.logsumexp(by + g[t,yp,:]) 57 | return b 58 | 59 | 60 | 61 | def expectation(N,K,log_M): 62 | """ 63 | Expectation of the sufficient statistics given ``x`` and current 64 | parameter settings. 65 | """ 66 | g0 = log_M[0,0] 67 | g = log_M[1:] 68 | a = forward(g0,g,N,K) 69 | b = backward(g,N,K) 70 | print "Forward:" 71 | print a 72 | print "Backward:" 73 | print b 74 | # log-normalizing constant 75 | logZ = misc.logsumexp(a[N-1,:]) 76 | 77 | E = defaultdict(float) 78 | 79 | # The first factor needs to be special case'd 80 | # E[ f( y_0 ) ] = p(y_0 | y_[1:N], x) * f(y_0) 81 | c = exp(g0 + b[0,:] - logZ).clip(0.0, 1.0) 82 | for y in xrange(K): 83 | p = c[y] 84 | if p < 1e-40: continue # skip really small updates. 
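# NOTE (assumption): `f` in the loops below is presumably a per-sequence feature table mapping (t, yp, y) to the indices of the feature functions that fire there; it is never defined in this test module, so expectation() cannot run as written.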
85 | for k in f[0, None, y]: 86 | E[k] += p 87 | 88 | for t in xrange(1,N): 89 | # vectorized computation of the marginal for this transition factor 90 | c = exp((add.outer(a[t-1,:], b[t,:]) + g[t-1,:,:] - logZ)).clip(0.0, 1.0) 91 | 92 | for yp in xrange(K): 93 | for y in xrange(K): 94 | # we can also use the following to compute ``p`` but its quite 95 | # a bit slower than the computation of vectorized quantity ``c``. 96 | #p = exp(a[t-1,yp] + g[t-1,yp,y] + b[t,y] - logZ).clip(0.0, 1.0) 97 | p = c[yp, y] 98 | if p < 1e-40: continue # skip really small updates. 99 | # expectation of this factor is p*f(t, yp, y) 100 | for k in f[t, yp, y]: 101 | E[k] += p 102 | 103 | return E 104 | 105 | 106 | class TestCRF(unittest.TestCase): 107 | 108 | def setUp(self): 109 | self.matrix = 0.001 + np.random.poisson(lam=1.5, size=(3,3)).astype(np.float) 110 | self.vector = 0.001 + np.random.poisson(lam=1.5, size=(3,)).astype(np.float) 111 | self.M = 0.001 + np.random.poisson(lam=1.5, size=(10,3,3)).astype(np.float) 112 | labels = ['A','B','C'] 113 | obsrvs = ['a','b','c','d','e','f'] 114 | lbls = [crf.START] + labels + [crf.END] 115 | 116 | transition_functions = [ 117 | lambda yp,y,x_v,i,_yp=_yp,_y=_y: 1 if yp==_yp and y==_y else 0 118 | for _yp in lbls[:-1] 119 | for _y in lbls[1:]] 120 | observation_functions = [ 121 | lambda yp,y,x_v,i,_y=_y,_x=_x: 1 if i < len(x_v) and y==_y and x_v[i]==_x else 0 122 | for _y in labels 123 | for _x in obsrvs] 124 | self.crf = crf.CRF( labels = labels, 125 | feature_functions = transition_functions + observation_functions ) 126 | 127 | 128 | def test_log_dot_mv(self): 129 | self.assertTrue( 130 | (np.around(np.exp( 131 | crf.log_dot_mv( 132 | np.log(self.matrix), 133 | np.log(self.vector) 134 | ) 135 | ),10) == np.around(np.dot(self.matrix,self.vector),10)).all() 136 | ) 137 | 138 | def test_log_dot_vm(self): 139 | self.assertTrue( 140 | (np.around(np.exp( 141 | crf.log_dot_vm( 142 | np.log(self.vector), 143 | np.log(self.matrix) 144 | ) 145 | ),10) == np.around(np.dot(self.vector,self.matrix),10)).all() 146 | ) 147 | 148 | def test_forward(self): 149 | M = self.M/self.M.sum(axis=2).reshape(self.M.shape[:-1]+(1,)) 150 | res = np.around(np.exp(self.crf.forward(np.log(M))[0]).sum(axis=1),10) 151 | res_true = np.around(np.ones(M.shape[0]),10) 152 | self.assertTrue((res == res_true).all()) 153 | 154 | def test_predict(self): 155 | label_pred = self.crf.slow_predict(self.M,self.M.shape[0],self.M.shape[1]) 156 | label_act = argmax(self.M) 157 | self.assertTrue(label_pred == label_act) 158 | 159 | def test_integrated(self): 160 | x_vec = ["a","b","c","d","e","f"] 161 | y_vec = ["A","B","C","A","B","C"] 162 | self.crf.train([x_vec],[y_vec]) 163 | l = lambda theta: crf.neg_likelihood_and_deriv(vectorised_x_vecs,vectorised_y_vecs,theta) 164 | self.assertTrue(self.crf.predict(x_vec)==y_vec[1:-1]) 165 | 166 | if __name__ == '__main__': 167 | unittest.main() 168 | 169 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from crf import CRF 2 | from features import * 3 | import re, sys 4 | import pickle 5 | training_file = sys.argv[1] 6 | 7 | if __name__ == '__main__': 8 | labels,obsrvs,word_sets,word_data,label_data = fit_dataset(training_file) 9 | crf = CRF( 10 | labels=list(labels), 11 | feature_functions = Membership.functions(labels,*word_sets.values()) + 12 | MatchRegex.functions(labels, 13 | '^[^0-9a-zA-Z\-]+$', 14 | '^[^0-9\-]+$', 15 | '^[A-Z]+$', 16 
| '^-?[1-9][0-9]*\.[0-9]+$', 17 | '^[1-9][0-9\.]+[a-z]+$', 18 | '^[0-9]+$', 19 | '^[A-Z][a-z]+$', 20 | '^([A-Z][a-z]*)+$', 21 | '^[^aeiouAEIOU]+$' 22 | ))# + [ 23 | # lambda yp,y,x_v,i,_y=_y,_x=_x: 24 | # 1 if i < len(x_v) and y==_y and x_v[i].lower() ==_x else 0 25 | # for _y in labels 26 | # for _x in obsrvs 27 | #]) 28 | crf.train(word_data[:-5],label_data[:-5]) 29 | pickle.dump(crf,open(sys.argv[2],'wb')) 30 | for i in range(-5,0): 31 | print word_data[i] 32 | print crf.predict(word_data[i]) 33 | print label_data[i] 34 | --------------------------------------------------------------------------------
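A minimal usage sketch to close the loop on train.py: load the model it pickles and tag a new sentence. The predict.py below is hypothetical (it is not part of the repository) and assumes the same Python 2 environment, with crf.py and features.py importable from the working directory.

# predict.py (hypothetical) -- load a CRF pickled by train.py and tag one sentence.
# crf.py and features.py must be importable here, since unpickling the model
# recreates the Transitions / Membership / MatchRegex feature objects by reference.
import sys
import pickle

if __name__ == '__main__':
    # sys.argv[1] is the pickle written by train.py (its second command-line argument)
    crf = pickle.load(open(sys.argv[1], 'rb'))
    # the remaining arguments are the tokens of the sentence to label
    tokens = sys.argv[2:]
    # CRF.predict runs Viterbi over the feature potentials and returns one label per token
    print tokens
    print crf.predict(tokens)

Invoked as, for example, python predict.py model.pickle Confidence in the pound is widely expected to fall, where model.pickle is whatever output path was passed to train.py.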