├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── addressnet
│   ├── __init__.py
│   ├── dataset.py
│   ├── lookups.py
│   ├── model.py
│   ├── predict.py
│   ├── pretrained
│   │   ├── __init__.py
│   │   ├── checkpoint
│   │   ├── graph.pbtxt
│   │   ├── model.ckpt.data-00000-of-00001
│   │   ├── model.ckpt.index
│   │   └── model.ckpt.meta
│   └── typo.py
├── example-result.png
├── generate_tf_records.py
├── predict.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Jason Rigby
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | graft addressnet/pretrained
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AddressNet
2 | 
3 | [Read more on Towards Data Science](https://towardsdatascience.com/addressnet-how-to-build-a-robust-street-address-parser-using-a-recurrent-neural-network-518d97b9aebd)
4 | 
5 | ## Background
6 | 
7 | This project is an attempt to create a recurrent neural network that
8 | segments an Australian street address into its components such that it
9 | can be more easily matched against a structured address database. The
10 | primary use-case for a model such as this is to transform legacy address
11 | data (e.g. unvalidated addresses, such as those collected on paper or by
12 | phone) into a reportable form at minimal cost. Once structured address
13 | data is produced, searching databases such as GNAF for geocoding
14 | information is much easier!
15 | 
16 | ## Installation
17 | Get the latest code by installing directly from git using:
18 | ```
19 | pip install git+https://github.com/jasonrig/address-net.git
20 | ```
21 | 
22 | Or from PyPI:
23 | ```
24 | pip install address-net
25 | pip install address-net[tf] # install TensorFlow (CPU version)
26 | pip install address-net[tf_gpu] # install TensorFlow (GPU version)
27 | ```
28 | 
29 | You will need an appropriate version of TensorFlow installed, ideally greater
30 | than version 1.12. It is not installed automatically since the CPU and GPU
31 | versions of TensorFlow exist in separate packages.
32 | 
33 | ## Model output
34 | This model performs character-level classification, assigning each
35 | character one of the following 22 classes as defined by the
36 | [GNAF database](https://data.gov.au/dataset/geocoded-national-address-file-g-naf):
37 | 
38 | 1. Separator/Blank
39 | 2. Building name
40 | 3. Level number prefix
41 | 4. Level number
42 | 5. Level number suffix
43 | 6. Level type
44 | 7. Flat number prefix
45 | 8. Flat number
46 | 9. Flat number suffix
47 | 10. Flat type
48 | 11. Number first prefix
49 | 12. Number first
50 | 13. Number first suffix
51 | 14. Number last prefix
52 | 15. Number last
53 | 16. Number last suffix
54 | 17. Street name
55 | 18. Street suffix
56 | 19. Street type
57 | 20. Locality name
58 | 21. State
59 | 22. Postcode
60 | 
61 | An example result from this model for "168A Separation Street Northcote,
62 | VIC 3070" would be:
63 | 
64 | ![address classification for 168A Separation Street Northcote,
65 | VIC 3070](./example-result.png)
66 | 
67 | ## Architecture
68 | This model uses a character-level vocabulary consisting of digits,
69 | lower-case ASCII characters, punctuation and whitespace as defined in
70 | Python's `string` package. These characters are encoded using embedding
71 | vectors of eight units in length.
72 | 
73 | The encoded text is fed through a bidirectional, three-layer Gated
74 | Recurrent Unit (GRU) Recurrent Neural Network (RNN) with 128 units per layer.
75 | The outputs from the forward and backward pass are concatenated and fed
76 | through a dense layer with ELU activations to produce logits for each class.
77 | The final output probabilities are generated through a softmax transformation.
78 | 
79 | Regularisation is achieved in three ways:
80 | 
81 | 1. 
Data augmentation: addresses constructed from the GNAF dataset are
82 | semi-randomly generated so that a huge variety of permutations is
83 | produced
84 | 2. Noise: each address is passed through a random typo generator that
85 | creates plausible errors: insertions, transpositions, deletions and
86 | substitutions of nearby keys on the keyboard
87 | 3. Dropout: dropout is applied to the outputs and states of the RNN layers
88 | 
89 | ## Data sources
90 | The data used to produce this model comes from the
91 | [GNAF database](https://data.gov.au/dataset/geocoded-national-address-file-g-naf)
92 | and is available under a permissive Creative Commons-like license. The
93 | GNAF data is available as a series of SQL files that can be imported into
94 | databases such as PostgreSQL, including a summary view named
95 | "address_view". Code included in `generate_tf_records.py` was used to
96 | consume a CSV dump of this view, producing a TFRecord file that is
97 | natively supported by TensorFlow.
98 | 
99 | ## Pretrained model
100 | While you are free to train this model using the `model_fn` provided,
101 | a pretrained model is supplied with this package under
102 | `addressnet/pretrained` and is the default model loaded when using the
103 | prediction function. Thus, using this package should be as simple as:
104 | ```python
105 | from addressnet.predict import predict_one
106 | 
107 | if __name__ == "__main__":
108 |     # This is a fake address!
109 |     print(predict_one("casa del gelato, 10A 24-26 high street road mount waverley vic 3183"))
110 | ```
111 | 
112 | Expected output:
113 | ```python
114 | {
115 |     'building_name': 'CASA DEL GELATO',
116 |     'flat_number': '10',
117 |     'flat_number_suffix': 'A',
118 |     'number_first': '24',
119 |     'number_last': '26',
120 |     'street_name': 'HIGH STREET',
121 |     'street_type': 'ROAD',
122 |     'locality_name': 'MOUNT WAVERLEY',
123 |     'state': 'VICTORIA',
124 |     'postcode': '3183'
125 | }
126 | ```
127 | 
128 | Because the model is robust to small typographical errors, the text it
129 | extracts may contain them too; a simple string similarity algorithm is
130 | used to normalise fields such as `street_type` and `state`, since we
131 | know exhaustively what they should be.
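
For batches of addresses, the `predict` function in `addressnet.predict` accepts a
list and lazily yields one result dictionary per input. A minimal sketch (both
addresses are made up):
```python
from addressnet.predict import predict

if __name__ == "__main__":
    # These are fake addresses!
    addresses = [
        "5/20 smith st fitzroy vic 3065",
        "level 2 10 bourke street, melbourne",
    ]
    # predict() returns a generator, so results are yielded as they are computed
    for parsed in predict(addresses):
        print(parsed)
```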
132 | -------------------------------------------------------------------------------- /addressnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonrig/address-net/28e7c2de030bae56f81c66d7e640dcc2d04fdfb6/addressnet/__init__.py -------------------------------------------------------------------------------- /addressnet/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Callable, List 2 | from collections import OrderedDict 3 | 4 | import random 5 | import tensorflow as tf 6 | import numpy as np 7 | import string 8 | 9 | import addressnet.lookups as lookups 10 | from addressnet.typo import generate_typo 11 | 12 | # Schema used to decode data from the TFRecord file 13 | _features = OrderedDict([ 14 | ('building_name', tf.io.FixedLenFeature([], tf.string)), 15 | ('lot_number_prefix', tf.io.FixedLenFeature([], tf.string)), 16 | ('lot_number', tf.io.FixedLenFeature([], tf.string)), 17 | ('lot_number_suffix', tf.io.FixedLenFeature([], tf.string)), 18 | ('flat_number_prefix', tf.io.FixedLenFeature([], tf.string)), 19 | ('flat_number_suffix', tf.io.FixedLenFeature([], tf.string)), 20 | ('level_number_prefix', tf.io.FixedLenFeature([], tf.string)), 21 | ('level_number_suffix', tf.io.FixedLenFeature([], tf.string)), 22 | ('number_first_prefix', tf.io.FixedLenFeature([], tf.string)), 23 | ('number_first_suffix', tf.io.FixedLenFeature([], tf.string)), 24 | ('number_last_prefix', tf.io.FixedLenFeature([], tf.string)), 25 | ('number_last_suffix', tf.io.FixedLenFeature([], tf.string)), 26 | ('street_name', tf.io.FixedLenFeature([], tf.string)), 27 | ('locality_name', tf.io.FixedLenFeature([], tf.string)), 28 | ('postcode', tf.io.FixedLenFeature([], tf.string)), 29 | ('flat_number', tf.io.FixedLenFeature([], tf.int64)), 30 | ('level_number', tf.io.FixedLenFeature([], tf.int64)), 31 | ('number_first', tf.io.FixedLenFeature([], tf.int64)), 32 | ('number_last', tf.io.FixedLenFeature([], tf.int64)), 33 | ('flat_type', tf.io.FixedLenFeature([], tf.int64)), 34 | ('level_type', tf.io.FixedLenFeature([], tf.int64)), 35 | ('street_type_code', tf.io.FixedLenFeature([], tf.int64)), 36 | ('street_suffix_code', tf.io.FixedLenFeature([], tf.int64)), 37 | ('state_abbreviation', tf.io.FixedLenFeature([], tf.int64)), 38 | ('latitude', tf.io.FixedLenFeature([], tf.float32)), 39 | ('longitude', tf.io.FixedLenFeature([], tf.float32)) 40 | ]) 41 | 42 | # List of fields used as labels in the training data 43 | labels_list = [ 44 | 'building_name', # 1 45 | 'level_number_prefix', # 2 46 | 'level_number', # 3 47 | 'level_number_suffix', # 4 48 | 'level_type', # 5 49 | 'flat_number_prefix', # 6 50 | 'flat_number', # 7 51 | 'flat_number_suffix', # 8 52 | 'flat_type', # 9 53 | 'number_first_prefix', # 10 54 | 'number_first', # 11 55 | 'number_first_suffix', # 12 56 | 'number_last_prefix', # 13 57 | 'number_last', # 14 58 | 'number_last_suffix', # 15 59 | 'street_name', # 16 60 | 'street_suffix_code', # 17 61 | 'street_type_code', # 18 62 | 'locality_name', # 19 63 | 'state_abbreviation', # 20 64 | 'postcode' # 21 65 | ] 66 | # Number of labels in total (+1 for the blank category) 67 | n_labels = len(labels_list) + 1 68 | 69 | # Allowable characters for the encoded representation 70 | vocab = list(string.digits + string.ascii_lowercase + string.punctuation + string.whitespace) 71 | 72 | 73 | def vocab_lookup(characters: str) -> (int, np.ndarray): 74 | """ 75 | Converts 
a string into a list of vocab indices
76 |     :param characters: the string to convert
77 |     (characters not found in the vocabulary are encoded as 0)
78 |     :return: the string length and an array of vocab indices
79 |     """
80 |     result = list()
81 |     for c in characters.lower():
82 |         try:
83 |             result.append(vocab.index(c) + 1)
84 |         except ValueError:
85 |             result.append(0)
86 |     return len(characters), np.array(result, dtype=np.int64)
87 | 
88 | 
89 | def decode_data(record: List[Union[str, int, float]]) -> Union[str, int, float]:
90 |     """
91 |     Decodes a record from the tfrecord file by converting all strings to UTF-8 encoding, and any numeric field with
92 |     a value of -1 to None.
93 |     :param record: the record to decode
94 |     :return: an iterator for yielding the decoded fields
95 |     """
96 |     for item in record:
97 |         try:
98 |             # Attempt to treat the item in the record as a string
99 |             yield item.decode("UTF-8")
100 |         except AttributeError:
101 |             # Treat the item as a number and encode -1 as None (see generate_tf_records.py)
102 |             yield item if item != -1 else None
103 | 
104 | 
105 | def labels(text: Union[str, int], field_name: Optional[str], mutate: bool = True) -> (str, np.ndarray):
106 |     """
107 |     Generates a numpy matrix labelling each character by field type. Strings have artificial typos introduced if
108 |     mutate == True
109 |     :param text: the text to label
110 |     :param field_name: the name of the field to which the text belongs, or None if the label is blank
111 |     :param mutate: introduce artificial typos
112 |     :return: the original text and the numpy matrix of labels
113 |     """
114 | 
115 |     # Ensure the input is a string, encoding None to an empty string
116 |     if text is None:
117 |         text = ''
118 |     else:
119 |         # Introduce artificial typos if mutate == True
120 |         text = generate_typo(str(text)) if mutate else str(text)
121 |     labels_matrix = np.zeros((len(text), n_labels), dtype=np.bool)
122 | 
123 |     # If no field is supplied, then encode the label using the blank category
124 |     if field_name is None:
125 |         labels_matrix[:, 0] = True
126 |     else:
127 |         labels_matrix[:, labels_list.index(field_name) + 1] = True
128 |     return text, labels_matrix
129 | 
130 | 
131 | def random_separator(min_length: int = 1, max_length: int = 3, possible_sep_chars: Optional[str] = r",./\ ") -> str:
132 |     """
133 |     Generates a space-padded separator of random length using a random character from possible_sep_chars
134 |     :param min_length: minimum length of the separator
135 |     :param max_length: maximum length of the separator
136 |     :param possible_sep_chars: string of possible characters to use for the separator
137 |     :return: the separator string
138 |     """
139 |     chars = [" "] * random.randint(min_length, max_length)
140 |     if len(chars) > 0 and possible_sep_chars:
141 |         sep_char = random.choice(possible_sep_chars)
142 |         chars[random.randrange(len(chars))] = sep_char
143 |     return ''.join(chars)
144 | 
145 | 
146 | def join_labels(lbls: [np.ndarray], sep: Union[str, Callable[..., str]] = " ") -> np.ndarray:
147 |     """
148 |     Concatenates a series of label matrices with a separator
149 |     :param lbls: a list of numpy matrices
150 |     :param sep: the separator string or function that returns the sep string
151 |     :return: the concatenated labels
152 |     """
153 |     if len(lbls) < 2:
154 |         return lbls
155 | 
156 |     joined_labels = None
157 |     sep_str = None
158 | 
159 |     # if `sep` is not a function, set the separator (`sep_str`) to `sep`, otherwise leave as None
160 |     if not callable(sep):
161 |         sep_str = sep
162 | 
163 |     for l in lbls:
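        # Accumulate matrices into `joined_labels`, regenerating the separator on
        # each pass when `sep` is callable so random separators can vary between joins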
164 |         if joined_labels is None:
165 |             joined_labels = l
166 |         else:
167 |             # If `sep` is a function, call it on each iteration
168 |             if callable(sep):
169 |                 sep_str = sep()
170 | 
171 |             # Skip zero-length labels
172 |             if l.shape[0] == 0:
173 |                 continue
174 |             elif sep_str is not None and len(sep_str) > 0 and joined_labels.shape[0] > 0:
175 |                 # Join using sep_str if it's present and non-zero in length
176 |                 joined_labels = np.concatenate([joined_labels, labels(sep_str, None, mutate=False)[1], l], axis=0)
177 |             else:
178 |                 # Otherwise, directly concatenate the labels
179 |                 joined_labels = np.concatenate([joined_labels, l], axis=0)
180 | 
181 |     assert joined_labels is not None, "No labels were joined!"
182 |     assert joined_labels.shape[1] == n_labels, "The number of labels generated was unexpected: got %i but wanted %i" % (
183 |         joined_labels.shape[1], n_labels)
184 | 
185 |     return joined_labels
186 | 
187 | 
188 | def join_str_and_labels(parts: [(str, np.ndarray)], sep: Union[str, Callable[..., str]] = " ") -> (str, np.ndarray):
189 |     """
190 |     Joins the strings and labels using the given separator
191 |     :param parts: a list of string/label tuples
192 |     :param sep: a string or function that returns the string to be used as a separator
193 |     :return: the joined string and labels
194 |     """
195 |     # Keep only the parts with strings of length > 0
196 |     parts = [p for p in parts if len(p[0]) > 0]
197 | 
198 |     # If there are no parts at all, return an empty string and an array of shape (0, n_labels)
199 |     if len(parts) == 0:
200 |         return '', np.zeros((0, n_labels))
201 |     # If there's only one part, just give it back as-is
202 |     elif len(parts) == 1:
203 |         return parts[0]
204 | 
205 |     # Pre-generate the separators - this is important if `sep` is a function returning non-deterministic results
206 |     n_sep = len(parts) - 1
207 |     if callable(sep):
208 |         seps = [sep() for _ in range(n_sep)]
209 |     else:
210 |         seps = [sep] * n_sep
211 |     seps += ['']
212 | 
213 |     # Join the strings using the list of separators
214 |     strings = ''.join(sum([(s[0][0], s[1]) for s in zip(parts, seps)], ()))
215 | 
216 |     # Join the labels using an iterator function
217 |     sep_iter = iter(seps)
218 |     lbls = join_labels([s[1] for s in parts], sep=lambda: next(sep_iter))
219 | 
220 |     assert len(strings) == lbls.shape[0], "string length %i (%s), label length %i using sep %s" % (
221 |         len(strings), strings, lbls.shape[0], seps)
222 |     return strings, lbls
223 | 
224 | 
225 | def choose(option1: Callable = lambda: None, option2: Callable = lambda: None):
226 |     """
227 |     Randomly run either option 1 or option 2
228 |     :param option1: a possible function to run
229 |     :param option2: another possible function to run
230 |     :return: the result of the function
231 |     """
232 |     if random.getrandbits(1):
233 |         return option1()
234 |     else:
235 |         return option2()
236 | 
237 | 
238 | def synthesise_address(*record) -> (int, np.ndarray, np.ndarray):
239 |     """
240 |     Uses the record information to construct a formatted address with labels. The addresses generated involve
241 |     semi-random permutations and corruptions to help avoid over-fitting.
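    For example, a single record might come out as "2 / 5 example street suburbton" on
    one pass and "unit 2 5 example st, suburbton vic 3999" on the next (a made-up
    address, purely illustrative).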
242 | :param record: the decoded item from the TFRecord file 243 | :return: the address string length, encoded text and labels 244 | """ 245 | fields = dict(zip(_features.keys(), decode_data(record))) 246 | 247 | # Generate the individual address components: 248 | if fields['level_type'] > 0: 249 | level = generate_level_number(fields['level_type'], fields['level_number_prefix'], fields['level_number'], 250 | fields['level_number_suffix']) 251 | else: 252 | level = ('', np.zeros((0, n_labels))) 253 | 254 | if fields['flat_type'] > 0: 255 | flat_number = generate_flat_number( 256 | fields['flat_type'], fields['flat_number_prefix'], fields['flat_number'], fields['flat_number_suffix']) 257 | else: 258 | flat_number = ('', np.zeros((0, n_labels))) 259 | 260 | street_number = generate_street_number(fields['number_first_prefix'], fields['number_first'], 261 | fields['number_first_suffix'], fields['number_last_prefix'], 262 | fields['number_last'], fields['number_last_suffix']) 263 | street = generate_street_name(fields['street_name'], fields['street_suffix_code'], fields['street_type_code']) 264 | suburb = labels(fields['locality_name'], 'locality_name') 265 | state = generate_state(fields['state_abbreviation']) 266 | postcode = labels(fields['postcode'], 'postcode') 267 | building_name = labels(fields['building_name'], 'building_name') 268 | 269 | # Begin composing the formatted address, building up the `parts` variable... 270 | 271 | suburb_state_postcode = list() 272 | # Keep the suburb? 273 | choose(lambda: suburb_state_postcode.append(suburb)) 274 | # Keep state? 275 | choose(lambda: suburb_state_postcode.append(state)) 276 | # Keep postcode? 277 | choose(lambda: suburb_state_postcode.append(postcode)) 278 | 279 | random.shuffle(suburb_state_postcode) 280 | 281 | parts = [[building_name], [level]] 282 | 283 | # Keep the street number? (If street number is dropped, the flat number is also dropped) 284 | def keep_street_number(): 285 | # force flat number to be next to street number only if the flat number is only digits (i.e. 
does not have a 286 | # flat type) 287 | if flat_number[0].isdigit(): 288 | parts.append([flat_number, street_number, street]) 289 | else: 290 | parts.append([flat_number]) 291 | parts.append([street_number, street]) 292 | 293 | choose(keep_street_number, lambda: parts.append([street])) 294 | 295 | random.shuffle(parts) 296 | 297 | # Suburb, state, postcode is always at the end of an address 298 | parts.append(suburb_state_postcode) 299 | 300 | # Flatten the address components into an unnested list 301 | parts = sum(parts, []) 302 | 303 | # Join each address component/label with a random separator 304 | address, address_lbl = join_str_and_labels(parts, sep=lambda: random_separator(1, 3)) 305 | 306 | # Encode 307 | length, text_encoded = vocab_lookup(address) 308 | return length, text_encoded, address_lbl 309 | 310 | 311 | def generate_state(state_abbreviation: int) -> (str, np.ndarray): 312 | """ 313 | Generates the string and labels for the state, randomly abbreviated 314 | :param state_abbreviation: the state code 315 | :return: string and labels 316 | """ 317 | state = lookups.lookup_state(state_abbreviation, reverse_lookup=True) 318 | return labels(choose(lambda: lookups.expand_state(state), lambda: state), 'state_abbreviation') 319 | 320 | 321 | def generate_level_number(level_type: int, level_number_prefix: str, level_number: int, level_number_suffix: str) -> ( 322 | str, np.ndarray): 323 | """ 324 | Generates the level number for the address 325 | :param level_type: level type code 326 | :param level_number_prefix: number prefix 327 | :param level_number: level number 328 | :param level_number_suffix: level number suffix 329 | :return: string and labels 330 | """ 331 | 332 | level_type = labels(lookups.lookup_level_type(level_type, reverse_lookup=True), 'level_type') 333 | 334 | # Decide whether to transform the level number 335 | def do_transformation(): 336 | if not level_number_prefix and not level_number_suffix and level_type[0]: 337 | # If there is no prefix/suffix, decide whether to convert to ordinal numbers (1st, 2nd, etc.) 338 | def use_ordinal_numbers(lvl_num, lvl_type): 339 | # Use ordinal words (first, second, third) or numbers (1st, 2nd, 3rd)? 340 | lvl_num = choose(lambda: lookups.num2word(lvl_num, output='ordinal_words'), 341 | lambda: lookups.num2word(lvl_num, output='ordinal')) 342 | lvl_num = labels(lvl_num, 'level_number') 343 | return join_str_and_labels([lvl_num, lvl_type], 344 | sep=lambda: random_separator(1, 3, possible_sep_chars=None)) 345 | 346 | def use_cardinal_numbers(lvl_num, lvl_type): 347 | # Treat level 1 as GROUND? 
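            # i.e. randomly emit either the literal number 1 or the word GROUND,
            # so both "level 1" and "level ground" styles appear in the training data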
348 | if lvl_num == 1: 349 | lvl_num = choose(lambda: "GROUND", lambda: 1) 350 | else: 351 | lvl_num = lookups.num2word(lvl_num, output='cardinal') 352 | lvl_num = labels(lvl_num, 'level_number') 353 | return join_str_and_labels([lvl_type, lvl_num], 354 | sep=lambda: random_separator(1, 3, possible_sep_chars=None)) 355 | 356 | return choose(lambda: use_ordinal_numbers(level_number, level_type), 357 | lambda: use_cardinal_numbers(level_number, level_type)) 358 | 359 | transformed_value = choose(do_transformation) 360 | if transformed_value: 361 | return transformed_value 362 | else: 363 | level_number_prefix = labels(level_number_prefix, 'level_number_prefix') 364 | level_number = labels(level_number, 'level_number') 365 | level_number_suffix = labels(level_number_suffix, 'level_number_suffix') 366 | return join_str_and_labels([level_type, level_number_prefix, level_number, level_number_suffix], 367 | sep=lambda: random_separator(1, 3, possible_sep_chars=None)) 368 | 369 | 370 | def generate_flat_number( 371 | flat_type: int, flat_number_prefix: str, flat_number: int, flat_number_suffix: str) -> (str, np.ndarray): 372 | """ 373 | Generates the flat number for the address 374 | :param flat_type: flat type code 375 | :param flat_number_prefix: number prefix 376 | :param flat_number: number 377 | :param flat_number_suffix: number suffix 378 | :return: string and labels 379 | """ 380 | flat_type = labels(lookups.lookup_flat_type(flat_type, reverse_lookup=True), 'flat_type') 381 | flat_number_prefix = labels(flat_number_prefix, 'flat_number_prefix') 382 | flat_number = labels(flat_number, 'flat_number') 383 | flat_number_suffix = labels(flat_number_suffix, 'flat_number_suffix') 384 | 385 | flat_number = join_str_and_labels([flat_number_prefix, flat_number, flat_number_suffix], 386 | sep=lambda: random_separator(0, 2, possible_sep_chars=None)) 387 | 388 | return choose( 389 | lambda: join_str_and_labels([flat_type, flat_number], sep=random_separator(0, 2, possible_sep_chars=None)), 390 | lambda: flat_number) 391 | 392 | 393 | def generate_street_number(number_first_prefix: str, number_first: int, number_first_suffix, 394 | number_last_prefix, number_last, number_last_suffix) -> (str, np.ndarray): 395 | """ 396 | Generates a street number using the prefix, suffix, first and last number components 397 | :param number_first_prefix: prefix to the first street number 398 | :param number_first: first street number 399 | :param number_first_suffix: suffix to the first street number 400 | :param number_last_prefix: prefix to the last street number 401 | :param number_last: last street number 402 | :param number_last_suffix: suffix to the last street number 403 | :return: the street number 404 | """ 405 | 406 | number_first_prefix = labels(number_first_prefix, 'number_first_prefix') 407 | number_first = labels(number_first, 'number_first') 408 | number_first_suffix = labels(number_first_suffix, 'number_first_suffix') 409 | 410 | number_last_prefix = labels(number_last_prefix, 'number_last_prefix') 411 | number_last = labels(number_last, 'number_last') 412 | number_last_suffix = labels(number_last_suffix, 'number_last_suffix') 413 | 414 | a = join_str_and_labels([number_first_prefix, number_first, number_first_suffix], 415 | lambda: random_separator(0, 2, possible_sep_chars=None)) 416 | b = join_str_and_labels([number_last_prefix, number_last, number_last_suffix], 417 | lambda: random_separator(0, 2, possible_sep_chars=None)) 418 | 419 | return join_str_and_labels([a, b], sep=random_separator(1, 3, 
possible_sep_chars=r"---- \/")) 420 | 421 | 422 | def generate_street_name(street_name: str, street_suffix_code: str, street_type_code: str) -> (str, np.ndarray): 423 | """ 424 | Generates a possible street name variation 425 | :param street_name: the street's name 426 | :param street_suffix_code: the street suffix code 427 | :param street_type_code: the street type code 428 | :return: string and labels 429 | """ 430 | street_name, street_name_lbl = labels(street_name, 'street_name') 431 | 432 | street_type = lookups.lookup_street_type(street_type_code, reverse_lookup=True) 433 | street_type = choose(lambda: lookups.abbreviate_street_type(street_type), lambda: street_type) 434 | street_type, street_type_lbl = labels(street_type, 'street_type_code') 435 | 436 | street_suffix = lookups.lookup_street_suffix(street_suffix_code, reverse_lookup=True) 437 | street_suffix = choose(lambda: lookups.expand_street_type_suffix(street_suffix), lambda: street_suffix) 438 | street_suffix, street_suffix_lbl = labels(street_suffix, 'street_suffix_code') 439 | 440 | return choose(lambda: join_str_and_labels([ 441 | (street_name, street_name_lbl), 442 | (street_suffix, street_suffix_lbl), 443 | (street_type, street_type_lbl) 444 | ]), lambda: join_str_and_labels([ 445 | (street_name, street_name_lbl), 446 | (street_type, street_type_lbl), 447 | (street_suffix, street_suffix_lbl) 448 | ])) 449 | 450 | 451 | def dataset(filenames: [str], batch_size: int = 10, shuffle_buffer: int = 1000, prefetch_buffer_size: int = 10000, 452 | num_parallel_calls: int = 8) -> Callable: 453 | """ 454 | Creates a Tensorflow dataset and iterator operations 455 | :param filenames: the tfrecord filenames 456 | :param batch_size: training batch size 457 | :param shuffle_buffer: shuffle buffer size 458 | :param prefetch_buffer_size: size of the prefetch buffer 459 | :param num_parallel_calls: number of parallel calls for the mapping functions 460 | :return: the input_fn 461 | """ 462 | 463 | def input_fn() -> tf.data.Dataset: 464 | ds = tf.data.TFRecordDataset(filenames, compression_type="GZIP") 465 | ds = ds.shuffle(buffer_size=shuffle_buffer) 466 | ds = ds.map(lambda record: tf.parse_single_example(record, features=_features), num_parallel_calls=8) 467 | ds = ds.map( 468 | lambda record: tf.py_func(synthesise_address, [record[k] for k in _features.keys()], 469 | [tf.int64, tf.int64, tf.bool], 470 | stateful=False), 471 | num_parallel_calls=num_parallel_calls 472 | ) 473 | 474 | ds = ds.padded_batch(batch_size, ([], [None], [None, n_labels])) 475 | 476 | ds = ds.map( 477 | lambda _lengths, _encoded_text, _labels: ({'lengths': _lengths, 'encoded_text': _encoded_text}, _labels), 478 | num_parallel_calls=num_parallel_calls 479 | ) 480 | ds = ds.prefetch(buffer_size=prefetch_buffer_size) 481 | return ds 482 | 483 | return input_fn 484 | 485 | 486 | def predict_input_fn(input_text: List[str]) -> Callable: 487 | """ 488 | An input function for one prediction example 489 | :param input_text: the input text 490 | :return: 491 | """ 492 | 493 | def input_fn() -> tf.data.Dataset: 494 | predict_ds = tf.data.Dataset.from_generator( 495 | lambda: (vocab_lookup(address) for address in input_text), 496 | (tf.int64, tf.int64), 497 | (tf.TensorShape([]), tf.TensorShape([None])) 498 | ) 499 | predict_ds = predict_ds.batch(1) 500 | predict_ds = predict_ds.map( 501 | lambda lengths, encoded_text: {'lengths': lengths, 'encoded_text': encoded_text} 502 | ) 503 | return predict_ds 504 | 505 | return input_fn 506 | 
-------------------------------------------------------------------------------- /addressnet/lookups.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Union 3 | 4 | # Categorical types as per the GNAF dataset, see: https://data.gov.au/dataset/geocoded-national-address-file-g-naf 5 | flat_types = ('ANTENNA', 'APARTMENT', 'AUTOMATED TELLER MACHINE', 'BARBECUE', 'BLOCK', 'BOATSHED', 'BUILDING', 6 | 'BUNGALOW', 'CAGE', 'CARPARK', 'CARSPACE', 'CLUB', 'COOLROOM', 'COTTAGE', 'DUPLEX', 'FACTORY', 'FLAT', 7 | 'GARAGE', 'HALL', 'HOUSE', 'KIOSK', 'LEASE', 'LOBBY', 'LOFT', 'LOT', 'MAISONETTE', 'MARINE BERTH', 8 | 'OFFICE', 'PENTHOUSE', 'REAR', 'RESERVE', 'ROOM', 'SECTION', 'SHED', 'SHOP', 'SHOWROOM', 'SIGN', 'SITE', 9 | 'STALL', 'STORE', 'STRATA UNIT', 'STUDIO', 'SUBSTATION', 'SUITE', 'TENANCY', 'TOWER', 'TOWNHOUSE', 10 | 'UNIT', 'VAULT', 'VILLA', 'WARD', 'WAREHOUSE', 'WORKSHOP') 11 | 12 | level_types = ('BASEMENT', 'FLOOR', 'GROUND', 'LEVEL', 'LOBBY', 'LOWER GROUND FLOOR', 'MEZZANINE', 'OBSERVATION DECK', 13 | 'PARKING', 'PENTHOUSE', 'PLATFORM', 'PODIUM', 'ROOFTOP', 'SUB-BASEMENT', 'UPPER GROUND FLOOR') 14 | 15 | street_types = ('ACCESS', 'ACRE', 'AIRWALK', 'ALLEY', 'ALLEYWAY', 'AMBLE', 'APPROACH', 'ARCADE', 'ARTERIAL', 'ARTERY', 16 | 'AVENUE', 'BANAN', 'BANK', 'BAY', 'BEACH', 'BEND', 'BOARDWALK', 'BOULEVARD', 'BOULEVARDE', 'BOWL', 17 | 'BRACE', 'BRAE', 'BRANCH', 'BREAK', 'BRETT', 'BRIDGE', 'BROADWALK', 'BROADWAY', 'BROW', 'BULL', 18 | 'BUSWAY', 'BYPASS', 'BYWAY', 'CAUSEWAY', 'CENTRE', 'CENTREWAY', 'CHASE', 'CIRCLE', 'CIRCLET', 19 | 'CIRCUIT', 'CIRCUS', 'CLOSE', 'CLUSTER', 'COLONNADE', 'COMMON', 'COMMONS', 'CONCORD', 'CONCOURSE', 20 | 'CONNECTION', 'COPSE', 'CORNER', 'CORSO', 'COURSE', 'COURT', 'COURTYARD', 'COVE', 'CRESCENT', 'CREST', 21 | 'CRIEF', 'CROOK', 'CROSS', 'CROSSING', 'CRUISEWAY', 'CUL-DE-SAC', 'CUT', 'CUTTING', 'DALE', 'DASH', 22 | 'DELL', 'DENE', 'DEVIATION', 'DIP', 'DISTRIBUTOR', 'DIVIDE', 'DOCK', 'DOMAIN', 'DOWN', 'DOWNS', 23 | 'DRIVE', 'DRIVEWAY', 'EASEMENT', 'EAST', 'EDGE', 'ELBOW', 'END', 'ENTRANCE', 'ESPLANADE', 'ESTATE', 24 | 'EXPRESSWAY', 'EXTENSION', 'FAIRWAY', 'FIREBREAK', 'FIRELINE', 'FIRETRACK', 'FIRETRAIL', 'FLAT', 25 | 'FLATS', 'FOLLOW', 'FOOTWAY', 'FORD', 'FORESHORE', 'FORK', 'FORMATION', 'FREEWAY', 'FRONT', 'FRONTAGE', 26 | 'GAP', 'GARDEN', 'GARDENS', 'GATE', 'GATEWAY', 'GLADE', 'GLEN', 'GRANGE', 'GREEN', 'GROVE', 'GULLY', 27 | 'HARBOUR', 'HAVEN', 'HEATH', 'HEIGHTS', 'HIGHROAD', 'HIGHWAY', 'HIKE', 'HILL', 'HILLS', 'HOLLOW', 28 | 'HUB', 'INLET', 'INTERCHANGE', 'ISLAND', 'JUNCTION', 'KEY', 'KEYS', 'KNOLL', 'LADDER', 'LANDING', 29 | 'LANE', 'LANEWAY', 'LEAD', 'LEADER', 'LINE', 'LINK', 'LOOKOUT', 'LOOP', 'LYNNE', 'MALL', 'MANOR', 30 | 'MART', 'MAZE', 'MEAD', 'MEANDER', 'MEW', 'MEWS', 'MILE', 'MOTORWAY', 'NOOK', 'NORTH', 'NULL', 31 | 'OUTLET', 'OUTLOOK', 'OVAL', 'PALMS', 'PARADE', 'PARADISE', 'PARK', 'PARKWAY', 'PART', 'PASS', 32 | 'PASSAGE', 'PATH', 'PATHWAY', 'PENINSULA', 'PIAZZA', 'PLACE', 'PLAZA', 'POCKET', 'POINT', 'PORT', 33 | 'PRECINCT', 'PROMENADE', 'PURSUIT', 'QUAD', 'QUADRANT', 'QUAY', 'QUAYS', 'RAMBLE', 'RAMP', 'RANGE', 34 | 'REACH', 'REEF', 'RESERVE', 'REST', 'RETREAT', 'RETURN', 'RIDE', 'RIDGE', 'RIGHT OF WAY', 'RING', 35 | 'RISE', 'RISING', 'RIVER', 'ROAD', 'ROADS', 'ROADWAY', 'ROTARY', 'ROUND', 'ROUTE', 'ROW', 'ROWE', 36 | 'RUE', 'RUN', 'SERVICEWAY', 'SHUNT', 'SKYLINE', 'SLOPE', 'SOUTH', 'SPUR', 'SQUARE', 'STEPS', 37 | 'STRAIGHT', 'STRAIT', 'STRAND', 'STREET', 
'STRIP', 'SUBWAY', 'TARN', 'TERRACE', 'THOROUGHFARE',
38 |                 'THROUGHWAY', 'TOLLWAY', 'TOP', 'TOR', 'TRACK', 'TRAIL', 'TRAMWAY', 'TRAVERSE', 'TRIANGLE', 'TRUNKWAY',
39 |                 'TUNNEL', 'TURN', 'TWIST', 'UNDERPASS', 'VALE', 'VALLEY', 'VERGE', 'VIADUCT', 'VIEW', 'VIEWS', 'VILLA',
40 |                 'VILLAGE', 'VILLAS', 'VISTA', 'VUE', 'WADE', 'WALK', 'WALKWAY', 'WATERS', 'WATERWAY', 'WAY', 'WEST',
41 |                 'WHARF', 'WOOD', 'WOODS', 'WYND', 'YARD')
42 | 
43 | street_suffix_types = OrderedDict([('CN', 'CENTRAL'), ('DE', 'DEVIATION'), ('E', 'EAST'), ('EX', 'EXTENSION'),
44 |                                    ('IN', 'INNER'), ('LR', 'LOWER'), ('ML', 'MALL'), ('N', 'NORTH'),
45 |                                    ('NE', 'NORTH EAST'), ('NW', 'NORTH WEST'), ('OF', 'OFF'), ('ON', 'ON'),
46 |                                    ('OT', 'OUTER'), ('OP', 'OVERPASS'), ('S', 'SOUTH'), ('SE', 'SOUTH EAST'),
47 |                                    ('SW', 'SOUTH WEST'), ('UP', 'UPPER'), ('W', 'WEST')])
48 | 
49 | states = OrderedDict([('ACT', 'AUSTRALIAN CAPITAL TERRITORY'), ('NSW', 'NEW SOUTH WALES'),
50 |                       ('NT', 'NORTHERN TERRITORY'), ('OT', 'OTHER TERRITORIES'), ('QLD', 'QUEENSLAND'),
51 |                       ('SA', 'SOUTH AUSTRALIA'), ('TAS', 'TASMANIA'), ('VIC', 'VICTORIA'),
52 |                       ('WA', 'WESTERN AUSTRALIA')])
53 | 
54 | # Abbreviations from METeOR identifier: 429387
55 | # see https://meteor.aihw.gov.au/content/index.phtml/itemId/429387/pageDefinitionItemId/tag.MeteorPrinterFriendlyPage
56 | street_type_abbreviation = {'ACCESS': 'ACCS', 'ALLEY': 'ALLY', 'ALLEYWAY': 'ALWY', 'AMBLE': 'AMBL', 'APPROACH': 'APP',
57 |                             'ARCADE': 'ARC', 'ARTERIAL': 'ARTL', 'ARTERY': 'ARTY', 'AVENUE': 'AV', 'BANAN': 'BA',
58 |                             'BEND': 'BEND', 'BOARDWALK': 'BWLK', 'BOULEVARD': 'BVD', 'BRACE': 'BR', 'BRAE': 'BRAE',
59 |                             'BREAK': 'BRK', 'BROW': 'BROW', 'BYPASS': 'BYPA', 'BYWAY': 'BYWY', 'CAUSEWAY': 'CSWY',
60 |                             'CENTRE': 'CTR', 'CHASE': 'CH', 'CIRCLE': 'CIR', 'CIRCUIT': 'CCT', 'CIRCUS': 'CRCS',
61 |                             'CLOSE': 'CL', 'CONCOURSE': 'CON', 'COPSE': 'CPS', 'CORNER': 'CNR', 'COURT': 'CT',
62 |                             'COURTYARD': 'CTYD', 'COVE': 'COVE', 'CRESCENT': 'CR', 'CREST': 'CRST', 'CROSS': 'CRSS',
63 |                             'CUL-DE-SAC': 'CSAC', 'CUTTING': 'CUTT', 'DALE': 'DALE', 'DIP': 'DIP', 'DRIVE': 'DR',
64 |                             'DRIVEWAY': 'DVWY', 'EDGE': 'EDGE', 'ELBOW': 'ELB', 'END': 'END', 'ENTRANCE': 'ENT',
65 |                             'ESPLANADE': 'ESP', 'EXPRESSWAY': 'EXP', 'FAIRWAY': 'FAWY', 'FOLLOW': 'FOLW',
66 |                             'FOOTWAY': 'FTWY', 'FORMATION': 'FORM', 'FREEWAY': 'FWY', 'FRONTAGE': 'FRTG',
67 |                             'GAP': 'GAP', 'GARDENS': 'GDNS', 'GATE': 'GTE', 'GLADE': 'GLDE', 'GLEN': 'GLEN',
68 |                             'GRANGE': 'GRA', 'GREEN': 'GRN', 'GROVE': 'GR', 'HEIGHTS': 'HTS', 'HIGHROAD': 'HIRD',
69 |                             'HIGHWAY': 'HWY', 'HILL': 'HILL', 'INTERCHANGE': 'INTG', 'JUNCTION': 'JNC', 'KEY': 'KEY',
70 |                             'LANE': 'LANE', 'LANEWAY': 'LNWY', 'LINE': 'LINE', 'LINK': 'LINK', 'LOOKOUT': 'LKT',
71 |                             'LOOP': 'LOOP', 'MALL': 'MALL', 'MEANDER': 'MNDR', 'MEWS': 'MEWS', 'MOTORWAY': 'MTWY',
72 |                             'NOOK': 'NOOK', 'OUTLOOK': 'OTLK', 'PARADE': 'PDE', 'PARKWAY': 'PWY', 'PASS': 'PASS',
73 |                             'PASSAGE': 'PSGE', 'PATH': 'PATH', 'PATHWAY': 'PWAY', 'PIAZZA': 'PIAZ', 'PLAZA': 'PLZA',
74 |                             'POCKET': 'PKT', 'POINT': 'PNT', 'PORT': 'PORT', 'PROMENADE': 'PROM', 'QUADRANT': 'QDRT',
75 |                             'QUAYS': 'QYS', 'RAMBLE': 'RMBL', 'REST': 'REST', 'RETREAT': 'RTT', 'RIDGE': 'RDGE',
76 |                             'RISE': 'RISE', 'ROAD': 'RD', 'ROTARY': 'RTY', 'ROUTE': 'RTE', 'ROW': 'ROW', 'RUE': 'RUE',
77 |                             'SERVICEWAY': 'SVWY', 'SHUNT': 'SHUN', 'SPUR': 'SPUR', 'SQUARE': 'SQ', 'STREET': 'ST',
78 |                             'SUBWAY': 'SBWY', 'TARN': 'TARN', 'TERRACE': 'TCE', 'THOROUGHFARE': 'THFR',
79 |                             'TOLLWAY': 'TLWY', 'TOP': 'TOP', 'TOR': 'TOR', 'TRACK': 'TRK', 'TRAIL': 'TRL',
80 |                             'TURN': 'TURN', 'UNDERPASS': 'UPAS', 
'VALE': 'VALE', 'VIADUCT': 'VIAD', 'VIEW': 'VIEW', 81 | 'VISTA': 'VSTA', 'WALK': 'WALK', 'WALKWAY': 'WKWY', 'WHARF': 'WHRF', 'WYND': 'WYND'} 82 | 83 | ordinal_words = [ 84 | 'first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth', 'ninth', 'tenth', 'eleventh', 85 | 'twelfth', 'thirteenth', 'fourteenth', 'fifteenth', 'sixteenth', 'seventeenth', 'eighteenth', 'nineteenth', 86 | 'twentieth', 'twenty-first', 'twenty-second', 'twenty-third', 'twenty-fourth', 'twenty-fifth', 'twenty-sixth', 87 | 'twenty-seventh', 'twenty-eighth', 'twenty-ninth', 'thirtieth', 'thirty-first', 'thirty-second', 'thirty-third', 88 | 'thirty-fourth', 'thirty-fifth', 'thirty-sixth', 'thirty-seventh', 'thirty-eighth', 'thirty-ninth', 'fortieth', 89 | 'forty-first', 'forty-second', 'forty-third', 'forty-fourth', 'forty-fifth', 'forty-sixth', 'forty-seventh', 90 | 'forty-eighth', 'forty-ninth', 'fiftieth', 'fifty-first', 'fifty-second', 'fifty-third', 'fifty-fourth', 91 | 'fifty-fifth', 'fifty-sixth', 'fifty-seventh', 'fifty-eighth', 'fifty-ninth', 'sixtieth', 'sixty-first', 92 | 'sixty-second', 'sixty-third', 'sixty-fourth', 'sixty-fifth', 'sixty-sixth', 'sixty-seventh', 'sixty-eighth', 93 | 'sixty-ninth', 'seventieth', 'seventy-first', 'seventy-second', 'seventy-third', 'seventy-fourth', 'seventy-fifth', 94 | 'seventy-sixth', 'seventy-seventh', 'seventy-eighth', 'seventy-ninth', 'eightieth', 'eighty-first', 'eighty-second', 95 | 'eighty-third', 'eighty-fourth', 'eighty-fifth', 'eighty-sixth', 'eighty-seventh', 'eighty-eighth', 'eighty-ninth', 96 | 'ninetieth', 'ninety-first', 'ninety-second', 'ninety-third', 'ninety-fourth', 'ninety-fifth', 'ninety-sixth', 97 | 'ninety-seventh', 'ninety-eighth', 'ninety-ninth', 'one hundredth' 98 | ] 99 | 100 | cardinal_words = [ 101 | 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 102 | 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen', 'twenty', 'twenty-one', 'twenty-two', 103 | 'twenty-three', 'twenty-four', 'twenty-five', 'twenty-six', 'twenty-seven', 'twenty-eight', 'twenty-nine', 'thirty', 104 | 'thirty-one', 'thirty-two', 'thirty-three', 'thirty-four', 'thirty-five', 'thirty-six', 'thirty-seven', 105 | 'thirty-eight', 'thirty-nine', 'forty', 'forty-one', 'forty-two', 'forty-three', 'forty-four', 'forty-five', 106 | 'forty-six', 'forty-seven', 'forty-eight', 'forty-nine', 'fifty', 'fifty-one', 'fifty-two', 'fifty-three', 107 | 'fifty-four', 'fifty-five', 'fifty-six', 'fifty-seven', 'fifty-eight', 'fifty-nine', 'sixty', 'sixty-one', 108 | 'sixty-two', 'sixty-three', 'sixty-four', 'sixty-five', 'sixty-six', 'sixty-seven', 'sixty-eight', 'sixty-nine', 109 | 'seventy', 'seventy-one', 'seventy-two', 'seventy-three', 'seventy-four', 'seventy-five', 'seventy-six', 110 | 'seventy-seven', 'seventy-eight', 'seventy-nine', 'eighty', 'eighty-one', 'eighty-two', 'eighty-three', 111 | 'eighty-four', 'eighty-five', 'eighty-six', 'eighty-seven', 'eighty-eight', 'eighty-nine', 'ninety', 'ninety-one', 112 | 'ninety-two', 'ninety-three', 'ninety-four', 'ninety-five', 'ninety-six', 'ninety-seven', 'ninety-eight', 113 | 'ninety-nine', 'one hundred' 114 | ] 115 | 116 | 117 | def _lookup(t: str, types: [str]) -> int: 118 | """ 119 | Looks up the value, t, from the array of types 120 | :param t: value to lookup 121 | :param types: list of types from which to lookup 122 | :return: an integer value > 0 if found, or 0 if not found 123 | """ 124 | try: 125 | return types.index(t.strip().upper()) + 1 126 | 
except ValueError:
127 |         return 0
128 | 
129 | 
130 | def _reverse_lookup(idx: int, types: [str]) -> str:
131 |     """
132 |     Converts an integer value back to the string representation
133 |     :param idx: integer value
134 |     :param types: list of types
135 |     :return: the string value, or an empty string if not found (idx == 0)
136 |     """
137 |     if idx == 0:
138 |         return ''
139 |     else:
140 |         return types[idx - 1]
141 | 
142 | 
143 | def lookup_state(state: Union[str, int], reverse_lookup=False) -> Union[str, int]:
144 |     """
145 |     Converts the representation for the geographic state
146 |     :param state: string or int to lookup
147 |     :param reverse_lookup: True if converting int to string, or False if string to int
148 |     :return: the encoded value
149 |     """
150 |     if reverse_lookup:
151 |         return _reverse_lookup(state, list(states.keys()))
152 |     return _lookup(state, list(states.keys()))
153 | 
154 | 
155 | def expand_state(state: str) -> str:
156 |     """
157 |     Converts an abbreviated state name to the full name, e.g. "VIC" -> "VICTORIA"
158 |     :param state: abbreviated state
159 |     :return: full state
160 |     """
161 |     return states[state.strip().upper()]
162 | 
163 | 
164 | def lookup_street_type(street_type: Union[str, int], reverse_lookup=False) -> Union[str, int]:
165 |     """
166 |     Converts the representation for the street type
167 |     :param street_type: string or int to lookup
168 |     :param reverse_lookup: True if converting int to string, or False if string to int
169 |     :return: the encoded value
170 |     """
171 |     if reverse_lookup:
172 |         return _reverse_lookup(street_type, street_types)
173 |     return _lookup(street_type, street_types)
174 | 
175 | 
176 | def abbreviate_street_type(street_type: str) -> str:
177 |     """
178 |     Converts a full street type to the abbreviated name, e.g. "STREET" -> "ST"
179 |     :param street_type: full street type
180 |     :return: abbreviated street type
181 |     """
182 |     try:
183 |         return street_type_abbreviation[street_type.strip().upper()]
184 |     except KeyError:
185 |         return street_type
186 | 
187 | 
188 | def lookup_street_suffix(street_suffix: Union[str, int], reverse_lookup=False) -> Union[str, int]:
189 |     """
190 |     Converts the representation for the street type suffix
191 |     :param street_suffix: string or int to lookup
192 |     :param reverse_lookup: True if converting int to string, or False if string to int
193 |     :return: the encoded value
194 |     """
195 |     if reverse_lookup:
196 |         return _reverse_lookup(street_suffix, list(street_suffix_types.keys()))
197 |     return _lookup(street_suffix, list(street_suffix_types.keys()))
198 | 
199 | 
200 | def expand_street_type_suffix(street_suffix: str) -> str:
201 |     """
202 |     Converts an abbreviated street suffix to the full name, e.g.
"N" -> "NORTH" 203 | :param street_suffix: abbreviated street suffix 204 | :return: full street suffix 205 | """ 206 | try: 207 | return street_suffix_types[street_suffix.strip().upper()] 208 | except KeyError: 209 | return street_suffix 210 | 211 | 212 | def lookup_level_type(level_type: Union[str, int], reverse_lookup=False) -> Union[str, int]: 213 | """ 214 | Converts the representation for the level type 215 | :param level_type: string or int to lookup 216 | :param reverse_lookup: True if converting int to string, or False if string to int 217 | :return: the encoded value 218 | """ 219 | if reverse_lookup: 220 | return _reverse_lookup(level_type, level_types) 221 | return _lookup(level_type, level_types) 222 | 223 | 224 | def lookup_flat_type(flat_type: Union[str, int], reverse_lookup=False) -> Union[str, int]: 225 | """ 226 | Converts the representation for the flat type 227 | :param flat_type: string or int to lookup 228 | :param reverse_lookup: True if converting int to string, or False if string to int 229 | :return: the encoded value 230 | """ 231 | if reverse_lookup: 232 | return _reverse_lookup(flat_type, flat_types) 233 | return _lookup(flat_type, flat_types) 234 | 235 | 236 | # Adapted from http://code.activestate.com/recipes/576888-format-a-number-as-an-ordinal/ 237 | def num2word(value, output='ordinal_words'): 238 | """ 239 | Converts zero or a *postive* integer (or their string 240 | representations) to an ordinal/cardinal value. 241 | :param value: the number to convert 242 | :param output: one of 'ordinal_words', 'ordinal', 'cardinal' 243 | """ 244 | try: 245 | value = int(value) 246 | except ValueError: 247 | return value 248 | 249 | assert output in ( 250 | 'ordinal_words', 'ordinal', 'cardinal'), "`output` must be one of 'ordinal_words', 'ordinal' or 'cardinal'" 251 | 252 | if output == 'ordinal_words' and (0 < value < 100): 253 | val = ordinal_words[value - 1] 254 | elif output == 'ordinal_words': 255 | raise ValueError("'ordinal_words' only supported between 1 and 100") 256 | elif output == 'ordinal': 257 | if value % 100 // 10 != 1: 258 | if value % 10 == 1: 259 | val = u"%d%s" % (value, "st") 260 | elif value % 10 == 2: 261 | val = u"%d%s" % (value, "nd") 262 | elif value % 10 == 3: 263 | val = u"%d%s" % (value, "rd") 264 | else: 265 | val = u"%d%s" % (value, "th") 266 | else: 267 | val = u"%d%s" % (value, "th") 268 | else: 269 | val = cardinal_words[value - 1] 270 | 271 | return val.upper() 272 | -------------------------------------------------------------------------------- /addressnet/model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import tensorflow as tf 4 | 5 | from addressnet.dataset import vocab, n_labels 6 | 7 | 8 | def model_fn(features: Dict[str, tf.Tensor], labels: tf.Tensor, mode: str, params) -> tf.estimator.EstimatorSpec: 9 | """ 10 | The AddressNet model function suitable for tf.estimator.Estimator 11 | :param features: a dictionary containing tensors for the encoded_text and lengths 12 | :param labels: a label for each character designating its position in the address 13 | :param mode: indicates whether the model is being trained, evaluated or used in prediction mode 14 | :param params: model hyperparameters, including rnn_size and rnn_layers 15 | :return: the appropriate tf.estimator.EstimatorSpec for the model mode 16 | """ 17 | encoded_text, lengths = features['encoded_text'], features['lengths'] 18 | rnn_size = params.get("rnn_size", 128) 19 | rnn_layers = 
params.get("rnn_layers", 3) 20 | 21 | embeddings = tf.get_variable("embeddings", dtype=tf.float32, initializer=tf.random_normal(shape=(len(vocab), 8))) 22 | encoded_strings = tf.nn.embedding_lookup(embeddings, encoded_text) 23 | 24 | logits, loss = nnet(encoded_strings, lengths, rnn_layers, rnn_size, labels, mode == tf.estimator.ModeKeys.TRAIN) 25 | 26 | predicted_classes = tf.argmax(logits, axis=2) 27 | 28 | if mode == tf.estimator.ModeKeys.PREDICT: 29 | predictions = { 30 | 'class_ids': predicted_classes, 31 | 'probabilities': tf.nn.softmax(logits) 32 | } 33 | return tf.estimator.EstimatorSpec(mode, predictions=predictions) 34 | 35 | if mode == tf.estimator.ModeKeys.EVAL: 36 | metrics = {} 37 | return tf.estimator.EstimatorSpec( 38 | mode, loss=loss, eval_metric_ops=metrics) 39 | 40 | if mode == tf.estimator.ModeKeys.TRAIN: 41 | train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss, global_step=tf.train.get_global_step()) 42 | return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) 43 | 44 | 45 | def nnet(encoded_strings: tf.Tensor, lengths: tf.Tensor, rnn_layers: int, rnn_size: int, labels: tf.Tensor = None, 46 | training: bool = True) -> (tf.Tensor, Optional[tf.Tensor]): 47 | """ 48 | Generates the RNN component of the model 49 | :param encoded_strings: a tensor containing the encoded strings (embedding vectors) 50 | :param lengths: a tensor of string lengths 51 | :param rnn_layers: number of layers to use in the RNN 52 | :param rnn_size: number of units in each layer 53 | :param labels: labels for each character in the string (optional) 54 | :param training: if True, dropout will be enabled on the RNN 55 | :return: logits and loss (loss will be None if labels is not provided) 56 | """ 57 | 58 | def rnn_cell(): 59 | probs = 0.8 if training else 1.0 60 | return tf.contrib.rnn.DropoutWrapper(tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell(rnn_size), 61 | state_keep_prob=probs, output_keep_prob=probs) 62 | 63 | rnn_cell_fw = tf.nn.rnn_cell.MultiRNNCell([rnn_cell() for _ in range(rnn_layers)]) 64 | rnn_cell_bw = tf.nn.rnn_cell.MultiRNNCell([rnn_cell() for _ in range(rnn_layers)]) 65 | 66 | (rnn_output_fw, rnn_output_bw), states = tf.nn.bidirectional_dynamic_rnn(rnn_cell_fw, rnn_cell_bw, encoded_strings, 67 | lengths, dtype=tf.float32) 68 | rnn_output = tf.concat([rnn_output_fw, rnn_output_bw], axis=2) 69 | logits = tf.layers.dense(rnn_output, n_labels, activation=tf.nn.elu) 70 | 71 | loss = None 72 | if labels is not None: 73 | mask = tf.sequence_mask(lengths, dtype=tf.float32) 74 | loss = tf.losses.softmax_cross_entropy(labels, logits, weights=mask) 75 | return logits, loss 76 | -------------------------------------------------------------------------------- /addressnet/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List, Union 3 | import textdistance 4 | import tensorflow as tf 5 | 6 | from addressnet.dataset import predict_input_fn, labels_list 7 | from addressnet.lookups import street_types, street_type_abbreviation, states, street_suffix_types, flat_types, \ 8 | level_types 9 | from addressnet.model import model_fn 10 | from functools import lru_cache 11 | 12 | 13 | def _get_best_match(target: str, candidates: Union[List[str], Dict[str, str]], keep_idx: int = 0) -> str: 14 | """ 15 | Returns the most similar string to the target given a dictionary or list of candidates. 
If a dictionary is provided, 16 | the keys and values are compared to the target, but only the requested component of the matched tuple is returned. 17 | :param target: the target string to be matched 18 | :param candidates: a key-value dictionary or list of strings 19 | :param keep_idx: 0 to return the key, 1 to return the value of the best match (no effect if list is supplied) 20 | :return: the matched string 21 | """ 22 | max_sim = None 23 | best = None 24 | 25 | try: 26 | candidates_list = candidates.items() 27 | except AttributeError: 28 | candidates_list = [(i,) for i in candidates] 29 | keep_idx = 0 30 | 31 | for kv in candidates_list: 32 | if target in kv: 33 | return kv[keep_idx] 34 | 35 | for i in kv: 36 | similarity = _str_sim(i, target) 37 | if max_sim is None or similarity > max_sim: 38 | best = kv[keep_idx] 39 | max_sim = similarity 40 | return best 41 | 42 | 43 | def _str_sim(a, b, fn=textdistance.jaro_winkler): 44 | """ 45 | Wrapper function for the string similarity function 46 | :param a: a string to compare 47 | :param b: another string to compare 48 | :param fn: the string similarity function from the textdistance package 49 | :return: the similarity ratio 50 | """ 51 | return fn.normalized_similarity(a.lower(), b.lower()) 52 | 53 | 54 | def normalise_state(s: str) -> str: 55 | """ 56 | Converts the state parameter to a standard non-abbreviated form 57 | :param s: state string 58 | :return: state name in full 59 | """ 60 | if s in states: 61 | return states[s] 62 | return _get_best_match(s, states, keep_idx=1) 63 | 64 | 65 | def normalise_street_type(s: str) -> str: 66 | """ 67 | Converts the street type parameter to a standard non-abbreviated form 68 | :param s: street type string 69 | :return: street type in full 70 | """ 71 | if s in street_types: 72 | return s 73 | return _get_best_match(s, street_type_abbreviation, keep_idx=0) 74 | 75 | 76 | def normalise_street_suffix(s: str) -> str: 77 | """ 78 | Converts the street suffix parameter to a standard non-abbreviated form 79 | :param s: street suffix string 80 | :return: street suffix in full 81 | """ 82 | if s in street_suffix_types: 83 | return street_suffix_types[s] 84 | return _get_best_match(s, street_suffix_types, keep_idx=1) 85 | 86 | 87 | def normalise_flat_type(s: str) -> str: 88 | """ 89 | Converts the flat type parameter to a standard non-abbreviated form 90 | :param s: flat type string 91 | :return: flat type in full 92 | """ 93 | if s in flat_types: 94 | return s 95 | return _get_best_match(s, flat_types) 96 | 97 | 98 | def normalise_level_type(s: str) -> str: 99 | """ 100 | Converts the level type parameter to a standard non-abbreviated form 101 | :param s: level type string 102 | :return: level type in full 103 | """ 104 | if s in level_types: 105 | return s 106 | return _get_best_match(s, level_types) 107 | 108 | 109 | @lru_cache(maxsize=2) 110 | def _get_estimator(model_fn, model_dir): 111 | return tf.estimator.Estimator(model_fn=model_fn, 112 | model_dir=model_dir) 113 | 114 | 115 | def predict_one(address: str, model_dir: str = None) -> Dict[str, str]: 116 | """ 117 | Segments a given address into its components and attempts to normalise categorical components, 118 | e.g. 
state, street type 119 | :param address: the input address string 120 | :param model_dir: path to trained model 121 | :return: a dictionary with the address components separated 122 | """ 123 | return next(predict([address], model_dir)) 124 | 125 | 126 | def predict(address: List[str], model_dir: str = None) -> List[Dict[str, str]]: 127 | """ 128 | Segments a set of addresses into their components and attempts to normalise categorical components, 129 | e.g. state, street type 130 | :param address: the input list of address strings 131 | :param model_dir: path to trained model 132 | :return: a list of dictionaries with the address components separated 133 | """ 134 | if model_dir is None: 135 | model_dir = os.path.join(os.path.dirname(__file__), 'pretrained') 136 | assert os.path.isdir(model_dir), "invalid model_dir provided: %s" % model_dir 137 | address_net_estimator = _get_estimator(model_fn, model_dir) 138 | result = address_net_estimator.predict(predict_input_fn(address)) 139 | class_names = [l.replace("_code", "") for l in labels_list] 140 | class_names = [l.replace("_abbreviation", "") for l in class_names] 141 | for addr, res in zip(address, result): 142 | mappings = dict() 143 | for char, class_id in zip(addr.upper(), res['class_ids']): 144 | if class_id == 0: 145 | continue 146 | cls = class_names[class_id - 1] 147 | mappings[cls] = mappings.get(cls, "") + char 148 | 149 | if 'state' in mappings: 150 | mappings['state'] = normalise_state(mappings['state']) 151 | if 'street_type' in mappings: 152 | mappings['street_type'] = normalise_street_type(mappings['street_type']) 153 | if 'street_suffix' in mappings: 154 | mappings['street_suffix'] = normalise_street_suffix(mappings['street_suffix']) 155 | if 'flat_type' in mappings: 156 | mappings['flat_type'] = normalise_flat_type(mappings['flat_type']) 157 | if 'level_type' in mappings: 158 | mappings['level_type'] = normalise_level_type(mappings['level_type']) 159 | 160 | yield mappings 161 | -------------------------------------------------------------------------------- /addressnet/pretrained/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonrig/address-net/28e7c2de030bae56f81c66d7e640dcc2d04fdfb6/addressnet/pretrained/__init__.py -------------------------------------------------------------------------------- /addressnet/pretrained/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt" 2 | -------------------------------------------------------------------------------- /addressnet/pretrained/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonrig/address-net/28e7c2de030bae56f81c66d7e640dcc2d04fdfb6/addressnet/pretrained/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /addressnet/pretrained/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonrig/address-net/28e7c2de030bae56f81c66d7e640dcc2d04fdfb6/addressnet/pretrained/model.ckpt.index -------------------------------------------------------------------------------- /addressnet/pretrained/model.ckpt.meta: -------------------------------------------------------------------------------- 
--------------------------------------------------------------------------------
/addressnet/pretrained/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonrig/address-net/28e7c2de030bae56f81c66d7e640dcc2d04fdfb6/addressnet/pretrained/__init__.py
--------------------------------------------------------------------------------
/addressnet/pretrained/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "model.ckpt"
2 | 
--------------------------------------------------------------------------------
/addressnet/pretrained/model.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonrig/address-net/28e7c2de030bae56f81c66d7e640dcc2d04fdfb6/addressnet/pretrained/model.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/addressnet/pretrained/model.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonrig/address-net/28e7c2de030bae56f81c66d7e640dcc2d04fdfb6/addressnet/pretrained/model.ckpt.index
--------------------------------------------------------------------------------
/addressnet/pretrained/model.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonrig/address-net/28e7c2de030bae56f81c66d7e640dcc2d04fdfb6/addressnet/pretrained/model.ckpt.meta
--------------------------------------------------------------------------------
/addressnet/typo.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | 
4 | # Contains nearby characters on the keyboard for substitution when generating typos
5 | character_replacement = dict()
6 | 
7 | character_replacement['a'] = 'qwsz'
8 | character_replacement['b'] = 'nhgv '
9 | character_replacement['c'] = 'vfdx '
10 | character_replacement['d'] = 'fresxc'
11 | character_replacement['e'] = 'sdfr43ws'
12 | character_replacement['f'] = 'gtrdcv'
13 | character_replacement['g'] = 'hytfvb'
14 | character_replacement['h'] = 'juytgbn'
15 | character_replacement['i'] = 'ujklo98'
16 | character_replacement['j'] = 'mkiuyhn'
17 | character_replacement['k'] = 'jm,loij'
18 | character_replacement['l'] = 'k,.;pok'
19 | character_replacement['m'] = 'njk, '
20 | character_replacement['n'] = 'bhjm '
21 | character_replacement['o'] = 'plki90p'
22 | character_replacement['p'] = 'ol;[-0o'
23 | character_replacement['q'] = 'asw21'
24 | character_replacement['r'] = 'tfde45'
25 | character_replacement['s'] = 'dxzawe'
26 | character_replacement['t'] = 'ygfr56'
27 | character_replacement['u'] = 'ijhy78'
28 | character_replacement['v'] = 'cfgb '
29 | character_replacement['w'] = 'saq23e'
30 | character_replacement['x'] = 'zsdc'
31 | character_replacement['y'] = 'uhgt67'
32 | character_replacement['z'] = 'xsa'
33 | character_replacement['1'] = '2q'
34 | character_replacement['2'] = '3wq1'
35 | character_replacement['3'] = '4ew2'
36 | character_replacement['4'] = '5re3'
37 | character_replacement['5'] = '6tr4'
38 | character_replacement['6'] = '7yt5'
39 | character_replacement['7'] = '8uy6'
40 | character_replacement['8'] = '9iu7'
41 | character_replacement['9'] = '0oi8'
42 | character_replacement['0'] = '-po9'
43 | 
44 | 
45 | def generate_typo(s: str, sub_rate: float = 0.01, del_rate: float = 0.005, dupe_rate: float = 0.005,
46 |                   transpose_rate: float = 0.01) -> str:
47 |     """
48 |     Generates a new string containing some plausible typos
49 |     :param s: the input string
50 |     :param sub_rate: character substitution rate (0 < x < 1)
51 |     :param del_rate: character deletion rate (0 < x < 1)
52 |     :param dupe_rate: character duplication rate (0 < x < 1)
53 |     :param transpose_rate: character transposition rate (0 < x < 1)
54 |     :return: the string with typos
55 |     """
56 |     if len(s) == 0:
57 |         return s
58 | 
59 |     new_string = list()
60 |     for char in s.lower():
61 | 
62 |         # Draw four random numbers to decide which edit, if any, to apply
63 |         do = np.random.uniform(size=(4,))
64 |         do_swap = do[0] < sub_rate
65 |         do_delete = do[1] < del_rate
66 |         do_duplicate = do[2] < dupe_rate
67 |         do_transpose = do[3] < transpose_rate
68 | 
69 |         if do_swap and char in character_replacement:
70 |             # Substitute the character with a randomly selected nearby key
71 |             new_string.append(random.choice(character_replacement[char]))
72 |         elif do_delete:
73 |             # Don't include this character in the replacement string
74 |             continue
75 |         elif do_duplicate:
76 |             # Add this character twice to the new string
77 |             new_string.extend([char] * 2)
78 |         elif do_transpose and len(new_string) > 0:
79 |             # Swap this and the previous character
80 |             new_string.append(new_string[-1])
81 |             new_string[-2] = char
82 |         else:
83 |             # Keep the character
84 |             new_string.append(char)
85 | 
86 |     # If an empty string was generated, try again
87 |     if len(new_string) == 0:
88 |         return generate_typo(s, sub_rate, del_rate, dupe_rate, transpose_rate)
89 | 
90 |     return ''.join(new_string)
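91 | 
92 | # Illustrative example (hypothetical output; the function is stochastic, so
93 | # every call can differ):
94 | #
95 | #     >>> generate_typo("24 high street")
96 | #     '24 hihg streer'
97 | #
98 | # Each rate is an independent per-character probability, so with the default
99 | # del_rate of 0.005 a 20-character address loses a character about once in
100 | # every ten calls.
101 | 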
--------------------------------------------------------------------------------
/example-result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonrig/address-net/28e7c2de030bae56f81c66d7e640dcc2d04fdfb6/example-result.png
--------------------------------------------------------------------------------
/generate_tf_records.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import csv
3 | import tensorflow as tf
4 | import argparse
5 | 
6 | from addressnet.lookups import lookup_flat_type, lookup_level_type, lookup_street_type, lookup_street_suffix, \
7 |     lookup_state
8 | 
9 | 
10 | def _str_feature(data: str) -> tf.train.Feature:
11 |     """
12 |     Creates a string feature
13 |     :param data: string data
14 |     :return: a tf.train.Feature object holding the string data
15 |     """
16 |     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data.encode()]))
17 | 
18 | 
19 | def _int_feature(data: int, none_value: int = -1) -> tf.train.Feature:
20 |     """
21 |     Creates an integer feature
22 |     :param data: integer data
23 |     :param none_value: value to use when the data is missing or not a valid integer
24 |     :return: a tf.train.Feature object holding the integer data
25 |     """
26 |     try:
27 |         val = int(data)
28 |     except (ValueError, TypeError):
29 |         val = none_value
30 |     return tf.train.Feature(int64_list=tf.train.Int64List(value=[val]))
31 | 
32 | 
33 | def _float_feature(data: float) -> tf.train.Feature:
34 |     """
35 |     Creates a float feature
36 |     :param data: float data
37 |     :return: a tf.train.Feature object holding the float data
38 |     """
39 |     return tf.train.Feature(float_list=tf.train.FloatList(value=[float(data)]))
40 | 
41 | 
42 | def generate_tf_records(input_file_path: str, output_file_path: str, input_gzip: bool = True):
43 |     """
44 |     Processes the input CSV file to produce a tfrecord file
45 |     :param input_file_path: input CSV file
46 |     :param output_file_path: output tfrecord file
47 |     :param input_gzip: whether or not the input file is gzip compressed
48 |     """
49 |     file_open = gzip.open if input_gzip else open
50 |     file_open_mode = "rt" if input_gzip else "r"
51 |     with file_open(input_file_path, file_open_mode, newline="") as f:
52 |         csv_reader = csv.DictReader(f)
53 | 
54 |         string_fields = ('building_name', 'lot_number_prefix', 'lot_number', 'lot_number_suffix', 'flat_number_prefix',
55 |                          'flat_number_suffix', 'level_number_prefix', 'level_number_suffix', 'number_first_prefix',
56 |                          'number_first_suffix', 'number_last_prefix', 'number_last_suffix', 'street_name',
57 |                          'locality_name', 'postcode')
58 | 
59 |         int_fields = ('flat_number', 'level_number', 'number_first', 'number_last')
60 | 
61 |         int_lookup_fields = (
62 |             ('flat_type', lookup_flat_type), ('level_type', lookup_level_type), ('street_type_code', lookup_street_type),
63 |             ('street_suffix_code', lookup_street_suffix), ('state_abbreviation', lookup_state))
64 | 
65 |         float_fields = ('latitude', 'longitude')
66 | 
67 |         tf_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
68 |         with tf.python_io.TFRecordWriter(output_file_path, options=tf_options) as tf_writer:
69 |             for row in csv_reader:
70 |                 record = dict()
71 |                 for field in string_fields:
72 |                     record[field] = _str_feature(row[field])
73 |                 for field in int_fields:
74 |                     record[field] = _int_feature(row[field])
75 |                 for field, lookup_fn in int_lookup_fields:
76 |                     record[field] = _int_feature(lookup_fn(row[field]))
77 |                 for field in float_fields:
78 |                     record[field] = _float_feature(row[field])
79 | 
80 |                 example = tf.train.Example(features=tf.train.Features(feature=record))
81 |                 tf_writer.write(example.SerializeToString())
82 | 
83 | 
84 | if __name__ == "__main__":
85 |     parser = argparse.ArgumentParser()
86 |     parser.add_argument("gnaf_csv", help="CSV file exported from GNAF `address_view`")
87 |     parser.add_argument("tf_record_output", help="Path to tfrecords output")
88 |     parser.add_argument("--gzipped_input", action="store_true", default=False)
89 |     args = parser.parse_args()
90 | 
91 |     print("Generating tfrecords files...")
92 |     generate_tf_records(args.gnaf_csv, args.tf_record_output, args.gzipped_input)
93 |     print("Done!")
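94 | 
95 | # Example invocation (hypothetical file names):
96 | #
97 | #     python generate_tf_records.py gnaf_address_view.csv.gz addresses.tfrecord --gzipped_input
98 | #
99 | # A minimal sketch of reading the GZIP-compressed records back using the same
100 | # TF 1.x API (tf.data supports compressed TFRecord files directly):
101 | #
102 | #     dataset = tf.data.TFRecordDataset("addresses.tfrecord", compression_type="GZIP")
103 | #     iterator = dataset.make_one_shot_iterator()
104 | #     next_serialised_example = iterator.get_next()
105 | 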
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | from addressnet.predict import predict_one
2 | 
3 | if __name__ == "__main__":
4 |     print(predict_one("casa del gelato, 10A 24-26 high street road mount waverley vic 3183"))
5 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | setup(
4 |     name='address-net',
5 |     version='1.0',
6 |     packages=['addressnet'],
7 |     url='https://github.com/jasonrig/address-net',
8 |     license='MIT',
9 |     author='Jason Rigby',
10 |     author_email='hello@jasonrig.by',
11 |     description='Splits Australian addresses into their components',
12 |     extras_require={
13 |         "tf": ["tensorflow>=1.12,<2.0"],
14 |         "tf_gpu": ["tensorflow-gpu>=1.12,<2.0"],
15 |     },
16 |     install_requires=[
17 |         'numpy',
18 |         'textdistance'
19 |     ],
20 |     include_package_data=True
21 | )
22 | 
--------------------------------------------------------------------------------