├── data ├── others ├── chart.jpg ├── example_1.jpg ├── example_2.jpg └── TrainingStatistic.xlsx ├── dict ├── 40000_classes.npy ├── SROIE_classes.npy ├── table_classes.npy ├── 20000TC_classes.npy ├── 20000T_classes.npy ├── SROIEnc_classes.npy ├── 20000TC_dictionary.npy ├── 20000T_dictionary.npy ├── 40000_dictionary.npy ├── SROIE_dictionary.npy ├── SROIEnc_dictionary.npy ├── table_dictionary.npy ├── 20000T_index_to_word.npy ├── 20000T_word_to_index.npy ├── 40000_index_to_word.npy ├── 40000_word_to_index.npy ├── SROIE_index_to_word.npy ├── SROIE_word_to_index.npy ├── table_index_to_word.npy ├── table_word_to_index.npy ├── 20000TC_index_to_word.npy ├── 20000TC_word_to_index.npy ├── SROIEnc_index_to_word.npy └── SROIEnc_word_to_index.npy ├── .gitignore ├── invoice_data ├── Faktura1.pdf_0.jpg └── Faktura1.pdf_1.jpg ├── .idea ├── misc.xml ├── vcs.xml ├── modules.xml ├── CUTIE.iml └── workspace.xml ├── .settings └── org.eclipse.core.resources.prefs ├── .pydevproject ├── .project ├── requirements.txt ├── main_data_tokenizer.py ├── main_build_dict.py ├── helper.py ├── README.md ├── deprecated ├── model_cutie_dilate.py ├── model_cutie_att.py ├── model_cutie_sep.py ├── model_cutie_unet8.py ├── model_cutie_fpn8.py ├── model_cutie_res.py ├── model_cutie_res16.py ├── model_cutie_res_att.py └── model_cutie_res_att_bert.py ├── download_data.py ├── model_cutie_res_bert.py ├── model_cutie_aspp.py ├── model_cutie2_fpn.py ├── model_cutie2_dilate.py ├── model_cutie2_aspp.py ├── model_cutie.py ├── main_evaluate_json.py ├── model_cutie2.py ├── bert_embedding.py ├── export_data.py ├── model_cutie_hr.py ├── tokenization.py ├── utils.py ├── main_train_json.py └── model_framework.py /data: -------------------------------------------------------------------------------- 1 | ../data/CUTIE/ -------------------------------------------------------------------------------- /others/chart.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/others/chart.jpg -------------------------------------------------------------------------------- /dict/40000_classes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/40000_classes.npy -------------------------------------------------------------------------------- /dict/SROIE_classes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIE_classes.npy -------------------------------------------------------------------------------- /dict/table_classes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/table_classes.npy -------------------------------------------------------------------------------- /others/example_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/others/example_1.jpg -------------------------------------------------------------------------------- /others/example_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/others/example_2.jpg -------------------------------------------------------------------------------- /dict/20000TC_classes.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000TC_classes.npy -------------------------------------------------------------------------------- /dict/20000T_classes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000T_classes.npy -------------------------------------------------------------------------------- /dict/SROIEnc_classes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIEnc_classes.npy -------------------------------------------------------------------------------- /dict/20000TC_dictionary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000TC_dictionary.npy -------------------------------------------------------------------------------- /dict/20000T_dictionary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000T_dictionary.npy -------------------------------------------------------------------------------- /dict/40000_dictionary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/40000_dictionary.npy -------------------------------------------------------------------------------- /dict/SROIE_dictionary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIE_dictionary.npy -------------------------------------------------------------------------------- /dict/SROIEnc_dictionary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIEnc_dictionary.npy -------------------------------------------------------------------------------- /dict/table_dictionary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/table_dictionary.npy -------------------------------------------------------------------------------- /dict/20000T_index_to_word.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000T_index_to_word.npy -------------------------------------------------------------------------------- /dict/20000T_word_to_index.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000T_word_to_index.npy -------------------------------------------------------------------------------- /dict/40000_index_to_word.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/40000_index_to_word.npy -------------------------------------------------------------------------------- /dict/40000_word_to_index.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/40000_word_to_index.npy -------------------------------------------------------------------------------- /dict/SROIE_index_to_word.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIE_index_to_word.npy -------------------------------------------------------------------------------- /dict/SROIE_word_to_index.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIE_word_to_index.npy -------------------------------------------------------------------------------- /dict/table_index_to_word.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/table_index_to_word.npy -------------------------------------------------------------------------------- /dict/table_word_to_index.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/table_word_to_index.npy -------------------------------------------------------------------------------- /others/TrainingStatistic.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/others/TrainingStatistic.xlsx -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/ 2 | data/ 3 | logs/ 4 | graph/ 5 | results/ 6 | *.pyc 7 | .DS_Store 8 | */.DS_Store 9 | task* -------------------------------------------------------------------------------- /dict/20000TC_index_to_word.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000TC_index_to_word.npy -------------------------------------------------------------------------------- /dict/20000TC_word_to_index.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000TC_word_to_index.npy -------------------------------------------------------------------------------- /dict/SROIEnc_index_to_word.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIEnc_index_to_word.npy -------------------------------------------------------------------------------- /dict/SROIEnc_word_to_index.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIEnc_word_to_index.npy -------------------------------------------------------------------------------- /invoice_data/Faktura1.pdf_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/invoice_data/Faktura1.pdf_0.jpg -------------------------------------------------------------------------------- /invoice_data/Faktura1.pdf_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/invoice_data/Faktura1.pdf_1.jpg -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 
-------------------------------------------------------------------------------- /.settings/org.eclipse.core.resources.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | encoding/data_loader_json.py=utf-8 3 | encoding/main_data_tokenizer.py=utf-8 4 | encoding/main_evaluate_json.py=utf-8 5 | encoding/tokenization.py=utf-8 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.pydevproject: -------------------------------------------------------------------------------- 1 | 2 | 3 | Default 4 | python 2.7 5 | 6 | -------------------------------------------------------------------------------- /.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | CUTIE 4 | 5 | 6 | 7 | 8 | 9 | org.python.pydev.PyDevBuilder 10 | 11 | 12 | 13 | 14 | 15 | org.python.pydev.pythonNature 16 | 17 | 18 | -------------------------------------------------------------------------------- /.idea/CUTIE.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.6.1 2 | astor==0.7.1 3 | atomicwrites==1.2.1 4 | attrs==18.2.0 5 | certifi==2019.3.9 6 | chardet==3.0.4 7 | cycler==0.10.0 8 | gast==0.2.0 9 | grpcio==1.17.1 10 | h5py==2.9.0 11 | idna==2.8 12 | Keras-Applications==1.0.6 13 | Keras-Preprocessing==1.0.5 14 | kiwisolver==1.0.1 15 | llvmlite==0.27.0 16 | Markdown==3.0.1 17 | matplotlib==3.0.2 18 | more-itertools==5.0.0 19 | numba==0.42.0 20 | numpy==1.15.4 21 | opencv-python==4.0.0.21 22 | pandas==0.23.4 23 | Pillow==5.4.0 24 | pluggy==0.8.1 25 | protobuf==3.6.1 26 | py==1.7.0 27 | pyparsing==2.3.0 28 | pytest==4.1.1 29 | python-dateutil==2.7.5 30 | pytz==2018.7 31 | requests==2.21.0 32 | scipy==1.2.0 33 | six==1.12.0 34 | tensorboard==1.12.1 35 | tensorflow==1.12.0 36 | termcolor==1.1.0 37 | urllib3==1.24.1 38 | Werkzeug==0.14.1 39 | -------------------------------------------------------------------------------- /main_data_tokenizer.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2018-02 3 | # xh.zhao@outlook.com 4 | import tensorflow as tf 5 | import argparse 6 | 7 | import tokenization 8 | from data_loader_json import DataLoader 9 | 10 | parser = argparse.ArgumentParser(description='Data Tokenizer parameters') 11 | parser.add_argument('--dict_path', type=str, default='dict/vocab.txt') 12 | parser.add_argument('--doc_path', type=str, default='data/meals') 13 | parser.add_argument('--batch_size', type=int, default=32) 14 | params = parser.parse_args() 15 | 16 | #class DataTokenizer(DataLoader): 17 | 18 | if __name__ == '__main__': 19 | ## convert data into tokenized data with updated bbox 20 | #data_loader = DataLoader(params, update_dict=True, load_dictionary=False) 21 | tokenizer = tokenization.FullTokenizer(vocab_file=params.dict_path, do_lower_case=True) 22 | 23 | -------------------------------------------------------------------------------- /main_build_dict.py: 
-------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2018-01 3 | # xiaohui.zhao@outlook.com 4 | import tensorflow as tf 5 | import argparse 6 | 7 | from data_loader_json import DataLoader 8 | 9 | parser = argparse.ArgumentParser(description='CUTIE parameters') 10 | parser.add_argument('--dict_path', type=str, default='dict/SROIE') 11 | parser.add_argument('--doc_path', type=str, default='data/SROIE') 12 | parser.add_argument('--test_path', type=str, default='') # leave empty if no test data provided 13 | parser.add_argument('--text_case', type=bool, default=True) # case sensitive 14 | parser.add_argument('--tokenize', type=bool, default=True) # tokenize input text 15 | parser.add_argument('--batch_size', type=int, default=32) 16 | parser.add_argument('--use_cutie2', type=bool, default=False) 17 | params = parser.parse_args() 18 | 19 | if __name__ == '__main__': 20 | ## run this program before training to create a basic dictionary for training 21 | data_loader = DataLoader(params, update_dict=True, load_dictionary=False) 22 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import argparse 3 | import sys 4 | import os, re, json 5 | import time 6 | 7 | # if __name__ == "__main__": 8 | # src = '/Users/xiaohui.zhao/workspace/CUTIE/hotel_imgs/' 9 | # dst = '/Users/xiaohui.zhao/workspace/data/CUTIE/mixed_true_even/train_hotel' 10 | # json_files = {} 11 | # for dirpath,dirnames,filenames in os.walk(dst): 12 | # for filename in filenames: 13 | # file = re.split(r'\.', filename)[0] 14 | # file_path = os.path.join(dirpath,filename) 15 | # json_files.update({file: file_path}) 16 | # 17 | # files = [] 18 | # for dirpath,dirnames,filenames in os.walk(src): 19 | # for filename in filenames: 20 | # #file = os.path.join(dirpath,filename) 21 | # file = re.split(r'\.', filename)[0] 22 | # if file in json_files: 23 | # os.system('cp '+ os.path.join(dirpath,filename) + ' ' + dst) 24 | 25 | if __name__ == "__main__": 26 | src = '/Users/xiaohui.zhao/workspace/data/CUTIE/column_identity' 27 | dst = '/Users/xiaohui.zhao/workspace/data/CUTIE/column' 28 | json_files = {} 29 | for dirpath,dirnames,filenames in os.walk(src): 30 | for filename in filenames: 31 | file_path = os.path.join(dirpath,filename) 32 | if file_path[-3:] == 'png': 33 | continue 34 | with open(file_path, encoding='utf-8') as f: 35 | data = json.load(f) 36 | print(len(data['fields'])) 37 | for i in range(len(data['fields'])): 38 | data['fields'][i]['field_name'] = 'Column{}'.format(i) 39 | 40 | target_path = os.path.join(dst,filename) 41 | with open(target_path, 'w') as f: 42 | json.dump(data, f) 43 | 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CUTIE 2 | TensorFlow implementation of the paper "CUTIE: Learning to Understand Documents with Convolutional Universal Text Information Extractor." 3 | Xiaohui Zhao [Paper Link](https://arxiv.org/abs/1903.12363v4) 4 | 5 | ---- 6 | CUTIE is a 2D key information extraction / named entity recognition / slot filling algorithm for scanned receipt and invoice documents. 7 | Before applying CUTIE, first run an OCR algorithm to detect and recognize the text in the document, then feed the structured text into the CUTIE network; the detailed procedure is described in the paper. 8 | 9 | CUTIE can be considered as a 2-dimensional key information extraction, 2-D NER (Named Entity Recognition), or 2-D slot filling algorithm.
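For intuition, the following minimal sketch shows how OCR output might be arranged into the grid-shaped input that the network consumes (the `grid_table` / `data_grid` placeholders in the model files have shape `[None, None, None, 1]`). It is a hypothetical example, not the repository's `DataLoader` from data_loader_json.py: the function name `build_grid`, the toy vocabulary, and the grid / image sizes are assumptions chosen only for illustration.

```python
# Minimal sketch (NOT the repository's DataLoader): place OCR words onto a
# rows x cols grid of vocabulary indices according to their bounding-box centers.
import numpy as np

def build_grid(ocr_words, word_to_index, rows=64, cols=64, image_w=1000, image_h=1400):
    # ocr_words: list of {'text': str, 'bbox': [x0, y0, x1, y1]} in pixel coordinates
    grid = np.zeros((rows, cols), dtype=np.int32)          # 0 marks an empty cell
    for word in ocr_words:
        x0, y0, x1, y1 = word['bbox']
        cx, cy = (x0 + x1) / 2.0, (y0 + y1) / 2.0          # box center
        col = min(int(cx / image_w * cols), cols - 1)
        row = min(int(cy / image_h * rows), rows - 1)
        grid[row, col] = word_to_index.get(word['text'].lower(),
                                           word_to_index.get('<unk>', 1))
    # add batch and channel dimensions: [1, rows, cols, 1]
    return grid[np.newaxis, :, :, np.newaxis]

# toy usage with a made-up vocabulary
vocab = {'<unk>': 1, 'total': 2, '35.00': 3}
words = [{'text': 'TOTAL', 'bbox': [120, 900, 220, 930]},
         {'text': '35.00', 'bbox': [400, 900, 480, 930]}]
print(build_grid(words, vocab).shape)                      # (1, 64, 64, 1)
```

A real pipeline also has to handle several words falling into the same cell and has to pick a grid size that fits the dataset, which is why the rows/cols configuration discussed below (see others/TrainingStatistic.xlsx) matters.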
10 | Before training / inference with CUTIE, prepare the structured texts from your scanned document images with any OCR algorithm of your choice. Refer to the CUTIE paper for details about the procedure. 11 | 12 | ### Results 13 | 14 | Results evaluated on 4,484 receipt documents, including taxi receipts, meals & entertainment receipts, and hotel receipts, with 9 different key information classes (AP / softAP). 15 | |Method | #Params | Taxi | Hotel | 16 | | ----------|:---------:| :-----: | :-----: | 17 | | CloudScan | - | 82.0 / - | 60.0 / - | 18 | | BERT | 110M | 88.1 / - | 71.7 / - | 19 | | CUTIE |**14M** |**94.0 / 97.3**|**74.6 / 87.0**| 20 | 21 | ![Taxi](https://github.com/vsymbol/CUTIE/raw/master/others/example_1.jpg) 22 | 23 | ![Hotel](https://github.com/vsymbol/CUTIE/raw/master/others/example_2.jpg) 24 | 25 | 26 | ### Installation & Usage 27 | 28 | ``` 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | 1. Generate your own dictionary with main_build_dict.py / main_data_tokenizer.py 33 | 2. Train your model with main_train_json.py 34 | 35 | CUTIE achieves its best performance when the grid rows/cols are well configured. For more insight, refer to the statistics in others/TrainingStatistic.xlsx. 36 | 37 | ![Chart](https://github.com/vsymbol/CUTIE/raw/master/others/chart.jpg) 38 | 39 | 40 | ### Others 41 | 42 | For information about the input example format, refer to the [issue discussion](https://github.com/vsymbol/CUTIE/issues/7). 43 | - Apply any OCR tool that helps you detect and recognize words in the scanned document image. 44 | - Label the image OCR results with key information classes, following the .json format in the invoice_data folder (thanks to @4kssoft). 45 | -------------------------------------------------------------------------------- /deprecated/model_cutie_dilate.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2019-03 3 | # xiaohui.zhao@accenture.com 4 | import tensorflow as tf 5 | from model_cutie import CUTIE 6 | 7 | class CUTIERes(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE_dilate" # 10 | 11 | self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table') 12 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 13 | self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') else tf.equal(1, 0) #params.use_ghm 14 | self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu' 15 | self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights') 16 | self.layers = dict({'data': self.data, 'gt_classes': self.gt_classes, 'ghm_weights':self.ghm_weights}) 17 | 18 | self.num_vocabs = num_vocabs 19 | self.num_classes = num_classes 20 | self.trainable = trainable 21 | 22 | self.embedding_size = params.embedding_size 23 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 24 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 25 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 26 | 27 | self.layer_inputs = [] 28 | self.setup() 29 | 30 | 31 | def setup(self): 32 | # input 33 | (self.feed('data') 34 | .embed(self.num_vocabs, self.embedding_size, name='embedding')) 35 | 36 | # encoder 37 | (self.feed('embedding') 38 | .conv(3, 5, 128, 1, 1, name='encoder1_1') 39 | .conv(3, 5, 128, 1, 1, name='encoder1_2') 40 | .dilate_conv(3,
5, 256, 1, 1, 2, name='encoder1_3') 41 | .dilate_conv(3, 5, 256, 1, 1, 2, name='encoder1_4') 42 | .dilate_conv(3, 5, 512, 1, 1, 2, name='encoder1_5') 43 | .dilate_conv(3, 5, 512, 1, 1, 2, name='encoder1_6') 44 | .dilate_conv(3, 5, 256, 1, 1, 2, name='encoder1_7') 45 | .dilate_conv(3, 5, 256, 1, 1, 2, name='encoder1_8') 46 | .dilate_conv(3, 5, 128, 1, 1, 2, name='encoder1_9') 47 | .dilate_conv(3, 5, 128, 1, 1, 2, name='encoder1_10')) 48 | 49 | # classification 50 | (self.feed('encoder1_10') 51 | .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') # sigmoid for ghm 52 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /deprecated/model_cutie_att.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2019-03 3 | # xiaohui.zhao@accenture.com 4 | import tensorflow as tf 5 | from model_cutie import CUTIE 6 | 7 | class CUTIERes(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE_attention" # 10 | 11 | self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table') 12 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 13 | self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') else tf.equal(1, 0) #params.use_ghm 14 | self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu' 15 | self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights') 16 | self.layers = dict({'data': self.data, 'gt_classes': self.gt_classes, 'ghm_weights':self.ghm_weights}) 17 | 18 | self.num_vocabs = num_vocabs 19 | self.num_classes = num_classes 20 | self.trainable = trainable 21 | 22 | self.embedding_size = params.embedding_size 23 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 24 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 25 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 26 | 27 | self.layer_inputs = [] 28 | self.setup() 29 | 30 | 31 | def setup(self): 32 | # input 33 | (self.feed('data') 34 | .embed(self.num_vocabs, self.embedding_size, name='embedding') 35 | .conv(3, 5, 128, 1, 1, name='encoder0_1')) 36 | 37 | # encoder 38 | (self.feed('encoder0_1') 39 | .conv(3, 5, 128, 1, 1, name='encoder1_1') 40 | .conv(3, 5, 128, 1, 1, name='encoder1_2') 41 | .conv(3, 5, 128, 1, 1, name='encoder1_3') 42 | .conv(3, 5, 128, 1, 1, name='encoder1_4')) 43 | 44 | (self.feed('encoder0_1', 'encoder1_4') 45 | .attention(1, name='attention2') 46 | .conv(3, 5, 128, 1, 1, name='encoder1_5') 47 | .conv(3, 5, 128, 1, 1, name='encoder1_6') 48 | .conv(3, 5, 128, 1, 1, name='encoder1_7') 49 | .conv(3, 5, 128, 1, 1, name='encoder1_8')) 50 | 51 | (self.feed('encoder0_1', 'encoder1_8') 52 | .attention(1, name='attention5') 53 | .conv(3, 5, 128, 1, 1, name='encoder1_9') 54 | .conv(3, 5, 128, 1, 1, name='encoder1_10')) 55 | 56 | # classification 57 | (self.feed('encoder1_10') 58 | .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') # sigmoid for ghm 59 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /deprecated/model_cutie_sep.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2018-12 3 | # 
xiaohui.zhao@accenture.com 4 | import tensorflow as tf 5 | from model_cutie import CUTIE 6 | 7 | class CUTIESep(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE_seperatable_residual" 10 | 11 | self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table') 12 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 13 | self.layers = dict({'data': self.data, 'gt_classes': self.gt_classes}) 14 | self.num_vocabs = num_vocabs 15 | self.num_classes = num_classes 16 | self.trainable = trainable 17 | 18 | self.embedding_size = params.embedding_size 19 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0 20 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0 21 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 22 | 23 | self.layer_inputs = [] 24 | self.setup() 25 | 26 | 27 | def setup(self): 28 | # input 29 | (self.feed('data') 30 | .embed(self.num_vocabs, self.embedding_size, name='embedding')) 31 | 32 | # encoder 33 | (self.feed('embedding') 34 | .sepconv(3, 5, 4, 2, name='encoder1_1') 35 | .sepconv(3, 5, 4, 2, name='encoder1_2') 36 | .conv(3, 5, 128, 1, 1, name='bottleneck1') 37 | .max_pool(2, 2, 2, 2, name='pool1') 38 | .sepconv(3, 5, 4, 2, name='encoder2_1') 39 | .sepconv(3, 5, 4, 2, name='encoder2_2') 40 | .conv(3, 5, 256, 1, 1, name='bottleneck2') 41 | .max_pool(2, 2, 2, 2, name='pool2') 42 | .sepconv(3, 5, 4, 2, name='encoder3_1') 43 | .sepconv(3, 5, 4, 2, name='encoder3_2') 44 | .conv(3, 5, 512, 1, 1, name='bottleneck3') 45 | .max_pool(2, 2, 2, 2, name='pool3') 46 | .sepconv(3, 5, 4, 2, name='encoder4_1') 47 | .sepconv(3, 5, 4, 2, name='encoder4_2')) 48 | 49 | # decoder 50 | (self.feed('encoder4_2') 51 | .up_conv(3, 5, 512, 1, 1, name='up1') 52 | .conv(3, 5, 256, 1,1, name='bottleneck4') 53 | .sepconv(3, 5, 4, 2, name='decoder1_1') 54 | .sepconv(3, 5, 4, 2, name='decoder1_2') 55 | .up_conv(3, 5, 256, 1, 1, name='up2') 56 | .conv(3, 5, 128, 1,1, name='bottleneck5') 57 | .sepconv(3, 5, 4, 2, name='decoder2_1') 58 | .sepconv(3, 5, 4, 2, name='decoder2_2') 59 | .up_conv(3, 5, 128, 1, 1, name='up3') 60 | .conv(3, 5, 64, 1,1, name='bottleneck6') 61 | .sepconv(3, 5, 4, 2, name='decoder3_1') 62 | .sepconv(3, 5, 4, 2, name='decoder3_2')) 63 | 64 | # classification 65 | (self.feed('decoder3_2') 66 | .conv(1, 1, self.num_classes, 1, 1, name='cls_logits') 67 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /download_data.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import argparse 3 | from requests.auth import HTTPBasicAuth 4 | import base64 5 | import os 6 | import time 7 | 8 | 9 | def init_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--user', type=str, help='login name', required=True) 12 | parser.add_argument('--password', type=str, help='login password', required=True) 13 | parser.add_argument('--date', type=str, help='search date', required=True) 14 | parser.add_argument('--invoice_type_id', type=str, help='invoice type id', required=True) 15 | 16 | return parser.parse_args() 17 | 18 | 19 | def get_tasks_of_page(date=None, invoice_type=None, next_url=None): 20 | api = next_url if next_url else \ 21 | 'http://52.193.30.103/argus/api/task/myte/?search={0}&invoice_type={1}'.format(date, invoice_type) 22 | 23 | r = 
requests.get(api, auth=auth, headers=headers) 24 | 25 | if r.ok: 26 | results = r.json().get('results') 27 | next = r.json().get('next') 28 | count = r.json().get('count') 29 | 30 | for result in results: 31 | global index 32 | index += 1 33 | 34 | print('\r download: {0}/{1}'.format(index, count), end='') 35 | get_image(result['id']) 36 | 37 | else: 38 | print('Download task info failed.') 39 | print('Call api: {url}'.format(url=api)) 40 | print('[Err Reason]:{err}'.format(err=r.reason)) 41 | print('[Err text]:{err}'.format(err=r.text)) 42 | next = None 43 | 44 | return next 45 | 46 | 47 | def get_image(task_id): 48 | api = 'http://52.193.30.103/argus/api/task_file/myte/{0}'.format(task_id) 49 | r = requests.get(api, auth=auth, headers=headers) 50 | if r.ok: 51 | result = r.json()[0] 52 | quality_file = result.get('quality_file') 53 | file_name = os.path.split(quality_file)[-1] 54 | 55 | with open(os.path.join(path, file_name), 'wb') as f: 56 | image_64_decode = base64.b64decode(result.get('picture').encode("utf8")) 57 | f.write(image_64_decode) 58 | 59 | else: 60 | print('Download image {id} failed.'.format(id=task_id)) 61 | print('Call api: {url}'.format(url=api)) 62 | print('[Err Reason]:{err}'.format(err=r.reason)) 63 | print('[Err text]:{err}'.format(err=r.text)) 64 | 65 | 66 | if __name__ == "__main__": 67 | args = init_args() 68 | auth = HTTPBasicAuth(args.user, args.password) 69 | headers = { 70 | 'Content-Type': "application/json", 71 | 'Cache-Control': "no-cache", 72 | } 73 | time_flag = int(time.time()) 74 | for date in args.date.split(','): 75 | index = 0 76 | path = '{root}/invoice_type_{type}-{time}/{date}'.format( 77 | root=os.path.abspath('.'), date=date, time=time_flag, type=args.invoice_type_id 78 | ) 79 | os.makedirs(path) 80 | print('\ndownload path: {0}'.format(path)) 81 | 82 | next_url = get_tasks_of_page(date=date, invoice_type=args.invoice_type_id) 83 | while True: 84 | if next_url: 85 | next_url = get_tasks_of_page(next_url=next_url) 86 | else: 87 | break 88 | -------------------------------------------------------------------------------- /deprecated/model_cutie_unet8.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2018-12 3 | # xiaohui.zhao@accenture.com 4 | import tensorflow as tf 5 | from model_cutie import CUTIE 6 | 7 | class CUTIEUNet(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE_unet_8x" # 8x down sampling 10 | 11 | self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table') 12 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 13 | self.layers = dict({'data': self.data, 'gt_classes': self.gt_classes}) 14 | self.num_vocabs = num_vocabs 15 | self.num_classes = num_classes 16 | self.trainable = trainable 17 | 18 | self.embedding_size = params.embedding_size 19 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 20 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 21 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 22 | 23 | self.layer_inputs = [] 24 | self.setup() 25 | 26 | 27 | def setup(self): 28 | # input 29 | (self.feed('data') 30 | .embed(self.num_vocabs, self.embedding_size, name='embedding')) 31 | 32 | # encoder 33 | (self.feed('embedding') 34 | .conv(3, 5, 64, 1, 1, name='encoder1_1') 35 | .conv(3, 5, 128, 1, 1, name='encoder1_2') 36 | 
.max_pool(2, 2, 2, 2, name='pool1') 37 | .conv(3, 5, 128, 1, 1, name='encoder2_1') 38 | .conv(3, 5, 256, 1, 1, name='encoder2_2') 39 | .max_pool(2, 2, 2, 2, name='pool2') 40 | .conv(3, 5, 256, 1, 1, name='encoder3_1') 41 | .conv(3, 5, 512, 1, 1, name='encoder3_2') 42 | .conv(3, 5, 512, 1, 1, name='encoder3_3') 43 | .max_pool(2, 2, 2, 2, name='pool3') 44 | .conv(3, 5, 512, 1, 1, name='encoder4_1') 45 | .conv(3, 5, 512, 1, 1, name='encoder4_2') 46 | .conv(3, 5, 512, 1, 1, name='encoder4_3') 47 | .conv(3, 5, 512, 1, 1, name='encoder4_4')) 48 | 49 | # decoder 50 | (self.feed('encoder4_4') 51 | .up_conv(3, 5, 512, 1, 1, name='up1')) 52 | (self.feed('up1', 'encoder3_3') 53 | .concat(3, name='concat1') 54 | .conv(3, 5, 256, 1, 1, name='decoder1_1') 55 | .conv(3, 5, 256, 1, 1, name='decoder1_2') 56 | .conv(3, 5, 256, 1, 1, name='decoder1_3') 57 | .conv(3, 5, 256, 1, 1, name='decoder1_4') 58 | .up_conv(3, 5, 256, 1, 1, name='up2')) 59 | (self.feed('up2', 'encoder2_2') 60 | .concat(3, name='concat2') 61 | .conv(3, 5, 128, 1, 1, name='decoder2_1') 62 | .conv(3, 5, 128, 1, 1, name='decoder2_2') 63 | .conv(3, 5, 128, 1, 1, name='decoder2_3') 64 | .up_conv(3, 5, 128, 1, 1, name='up3')) 65 | (self.feed('up3', 'encoder1_2') 66 | .concat(3, name='concat3') 67 | .conv(3, 5, 64, 1, 1, name='decoder3_1') 68 | .conv(3, 5, 64, 1, 1, name='decoder3_2')) 69 | 70 | # classification 71 | (self.feed('decoder3_2') 72 | .conv(1, 1, self.num_classes, 1, 1, name='cls_logits') 73 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /deprecated/model_cutie_fpn8.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2018-12 3 | # xiaohui.zhao@accenture.com 4 | import tensorflow as tf 5 | from model_cutie import CUTIE 6 | 7 | class CUTIEFPN(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE_fpn_8x" # 8x down sampling 10 | 11 | self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table') 12 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 13 | self.layers = dict({'data': self.data, 'gt_classes': self.gt_classes}) 14 | self.num_vocabs = num_vocabs 15 | self.num_classes = num_classes 16 | self.trainable = trainable 17 | 18 | self.embedding_size = params.embedding_size 19 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 20 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 21 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 22 | 23 | self.layer_inputs = [] 24 | self.setup() 25 | 26 | 27 | def setup(self): 28 | # input 29 | (self.feed('data') 30 | .embed(self.num_vocabs, self.embedding_size, name='embedding')) 31 | 32 | # encoder 33 | (self.feed('embedding') 34 | .conv(3, 5, 64, 1, 1, name='encoder1_1') 35 | .conv(3, 5, 128, 1, 1, name='encoder1_2') 36 | .max_pool(2, 2, 2, 2, name='pool1') 37 | .conv(3, 5, 128, 1, 1, name='encoder2_1') 38 | .conv(3, 5, 256, 1, 1, name='encoder2_2') 39 | .max_pool(2, 2, 2, 2, name='pool2') 40 | .conv(3, 5, 256, 1, 1, name='encoder3_1') 41 | .conv(3, 5, 512, 1, 1, name='encoder3_2') 42 | .conv(3, 5, 512, 1, 1, name='encoder3_3') 43 | .max_pool(2, 2, 2, 2, name='pool3') 44 | .conv(3, 5, 512, 1, 1, name='encoder4_1') 45 | .conv(3, 5, 512, 1, 1, name='encoder4_2') 46 | .conv(3, 5, 512, 1, 1, name='encoder4_3') 47 | .conv(3, 5, 512, 1, 1, 
name='encoder4_4')) 48 | 49 | # decoder 50 | (self.feed('encoder4_4') 51 | .up_conv(3, 5, 512, 1, 1, name='up1')) 52 | (self.feed('up1', 'encoder3_4') 53 | .concat(3, name='concat1') # 1x1 before the add operation for FPN 54 | .conv(3, 5, 256, 1, 1, name='decoder1_1') 55 | .conv(3, 5, 256, 1, 1, name='decoder1_2') 56 | .conv(3, 5, 256, 1, 1, name='decoder1_3') 57 | .conv(3, 5, 256, 1, 1, name='decoder1_4') 58 | .up_conv(3, 5, 256, 1, 1, name='up2')) 59 | (self.feed('up2', 'encoder2_3') 60 | .concat(3, name='concat2') 61 | .conv(3, 5, 128, 1, 1, name='decoder2_1') 62 | .conv(3, 5, 128, 1, 1, name='decoder2_2') 63 | .conv(3, 5, 128, 1, 1, name='decoder2_3') 64 | .up_conv(3, 5, 128, 1, 1, name='up3')) 65 | (self.feed('up3', 'encoder1_2') 66 | .concat(3, name='concat3') 67 | .conv(3, 5, 64, 1, 1, name='decoder3_1') 68 | .conv(3, 5, 64, 1, 1, name='decoder3_2')) 69 | 70 | # classification 71 | (self.feed('decoder3_2') 72 | .conv(1, 1, self.num_classes, 1, 1, name='cls_logits') 73 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /deprecated/model_cutie_res.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2018-12 3 | # xiaohui.zhao@accenture.com 4 | import tensorflow as tf 5 | from model_cutie import CUTIE 6 | 7 | class CUTIERes(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE_residual_8x" # 8x down sampling 10 | 11 | self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table') 12 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 13 | self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') else tf.equal(1, 0) #params.use_ghm 14 | self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu' 15 | self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights') 16 | self.layers = dict({'data': self.data, 'gt_classes': self.gt_classes, 'ghm_weights':self.ghm_weights}) 17 | 18 | self.num_vocabs = num_vocabs 19 | self.num_classes = num_classes 20 | self.trainable = trainable 21 | 22 | self.embedding_size = params.embedding_size 23 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 24 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 25 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 26 | 27 | self.layer_inputs = [] 28 | self.setup() 29 | 30 | 31 | def setup(self): 32 | # input 33 | (self.feed('data') 34 | .embed(self.num_vocabs, self.embedding_size, name='embedding')) 35 | 36 | # encoder 37 | (self.feed('embedding') 38 | .conv(3, 5, 64, 1, 1, name='encoder1_1') 39 | .conv(3, 5, 128, 1, 1, name='encoder1_2') 40 | .max_pool(2, 2, 2, 2, name='pool1') 41 | .conv(3, 5, 128, 1, 1, name='encoder2_1') 42 | .conv(3, 5, 256, 1, 1, name='encoder2_2') 43 | .max_pool(2, 2, 2, 2, name='pool2') 44 | .conv(3, 5, 256, 1, 1, name='encoder3_1') 45 | .conv(3, 5, 512, 1, 1, name='encoder3_2') 46 | .max_pool(2, 2, 2, 2, name='pool3') 47 | .conv(3, 5, 512, 1, 1, name='encoder4_1') 48 | .conv(3, 5, 512, 1, 1, name='encoder4_2')) 49 | 50 | # decoder 51 | (self.feed('encoder4_2') 52 | .up_conv(3, 5, 512, 1, 1, name='up1')) 53 | (self.feed('up1', 'encoder3_2') 54 | .concat(3, name='concat1') 55 | .conv(3, 5, 256, 1, 1, name='decoder1_1') 56 | .conv(3, 5, 256, 1, 1, 
name='decoder1_2') 57 | .up_conv(3, 5, 256, 1, 1, name='up2')) 58 | (self.feed('up2', 'encoder2_2') 59 | .concat(3, name='concat2') 60 | .conv(3, 5, 128, 1, 1, name='decoder2_1') 61 | .conv(3, 5, 128, 1, 1, name='decoder2_2') 62 | .up_conv(3, 5, 128, 1, 1, name='up3')) 63 | (self.feed('up3', 'encoder1_2') 64 | .concat(3, name='concat3') 65 | .conv(3, 5, 64, 1, 1, name='decoder3_1') 66 | .conv(3, 5, 64, 1, 1, name='decoder3_2')) 67 | 68 | # classification 69 | (self.feed('decoder3_2') 70 | .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') # sigmoid for ghm 71 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /deprecated/model_cutie_res16.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2018-12 3 | # xiaohui.zhao@accenture.com 4 | import tensorflow as tf 5 | from model_cutie import CUTIE 6 | 7 | class CUTIERes(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE_residual_16x" # 16x down sampling 10 | 11 | self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table') 12 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 13 | self.layers = dict({'data': self.data, 'gt_classes': self.gt_classes}) 14 | self.num_vocabs = num_vocabs 15 | self.num_classes = num_classes 16 | self.trainable = trainable 17 | 18 | self.embedding_size = params.embedding_size 19 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 20 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 21 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 22 | 23 | self.layer_inputs = [] 24 | self.setup() 25 | 26 | 27 | def setup(self): 28 | # input 29 | (self.feed('data') 30 | .embed(self.num_vocabs, self.embedding_size, name='embedding')) 31 | 32 | # encoder 33 | (self.feed('embedding') 34 | .conv(3, 5, 64, 1, 1, name='encoder1_1') 35 | .conv(3, 5, 128, 1, 1, name='encoder1_2') 36 | .max_pool(2, 2, 2, 2, name='pool1') 37 | .conv(3, 5, 128, 1, 1, name='encoder2_1') 38 | .conv(3, 5, 256, 1, 1, name='encoder2_2') 39 | .max_pool(2, 2, 2, 2, name='pool2') 40 | .conv(3, 5, 256, 1, 1, name='encoder3_1') 41 | .conv(3, 5, 512, 1, 1, name='encoder3_2') 42 | .max_pool(2, 2, 2, 2, name='pool3') 43 | .conv(3, 5, 512, 1, 1, name='encoder4_1') 44 | .conv(3, 5, 512, 1, 1, name='encoder4_2') 45 | .max_pool(2, 2, 2, 2, name='pool4') 46 | .conv(3, 5, 1024, 1, 1, name='encoder5_1') 47 | .conv(3, 5, 1024, 1, 1, name='encoder5_2')) 48 | 49 | # decoder 50 | (self.feed('encoder5_2') 51 | .up_conv(3, 5, 1024, 1, 1, name='up1')) 52 | (self.feed('up1', 'encoder4_2') 53 | .concat(3, name='concat1') 54 | .conv(3, 5, 512, 1, 1, name='decoder1_1') 55 | .conv(3, 5, 512, 1, 1, name='decoder1_2') 56 | .up_conv(3, 5, 512, 1, 1, name='up2')) 57 | (self.feed('up2', 'encoder3_2') 58 | .concat(3, name='concat2') 59 | .conv(3, 5, 256, 1, 1, name='decoder2_1') 60 | .conv(3, 5, 256, 1, 1, name='decoder2_2') 61 | .up_conv(3, 5, 256, 1, 1, name='up3')) 62 | (self.feed('up3', 'encoder2_2') 63 | .concat(3, name='concat3') 64 | .conv(3, 5, 128, 1, 1, name='decoder3_1') 65 | .conv(3, 5, 128, 1, 1, name='decoder3_2') 66 | .up_conv(3, 5, 128, 1, 1, name='up4')) 67 | (self.feed('up4', 'encoder1_2') 68 | .concat(3, name='concat4') 69 | .conv(3, 5, 64, 1, 1, name='decoder4_1') 70 | .conv(3, 5, 64, 1, 1, 
name='decoder4_2')) 71 | 72 | # classification 73 | (self.feed('decoder4_2') 74 | .conv(1, 1, self.num_classes, 1, 1, name='cls_logits') 75 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /deprecated/model_cutie_res_att.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2018-12 3 | # xiaohui.zhao@accenture.com 4 | import tensorflow as tf 5 | from model_cutie import CUTIE 6 | 7 | class CUTIERes(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE_residual_attention_8x" # 8x down sampling 10 | 11 | self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table') 12 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 13 | self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') else tf.equal(1, 0) #params.use_ghm 14 | self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu' 15 | self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights') 16 | self.layers = dict({'data': self.data, 'gt_classes': self.gt_classes, 'ghm_weights':self.ghm_weights}) 17 | 18 | self.num_vocabs = num_vocabs 19 | self.num_classes = num_classes 20 | self.trainable = trainable 21 | 22 | self.embedding_size = params.embedding_size 23 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 24 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 25 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 26 | 27 | self.layer_inputs = [] 28 | self.setup() 29 | 30 | 31 | def setup(self): 32 | # input 33 | (self.feed('data') 34 | .embed(self.num_vocabs, self.embedding_size, name='embedding')) 35 | 36 | # encoder 37 | (self.feed('embedding') 38 | .conv(3, 5, 64, 1, 1, name='encoder1_1') 39 | .conv(3, 5, 128, 1, 1, name='encoder1_2') 40 | .max_pool(2, 2, 2, 2, name='pool1') 41 | .conv(3, 5, 128, 1, 1, name='encoder2_1') 42 | .conv(3, 5, 256, 1, 1, name='encoder2_2') 43 | .max_pool(2, 2, 2, 2, name='pool2') 44 | .conv(3, 5, 256, 1, 1, name='encoder3_1') 45 | .conv(3, 5, 512, 1, 1, name='encoder3_2') 46 | .max_pool(2, 2, 2, 2, name='pool3') 47 | .conv(3, 5, 512, 1, 1, name='encoder4_1') 48 | .conv(3, 5, 512, 1, 1, name='encoder4_2')) 49 | 50 | # decoder 51 | (self.feed('encoder4_2') 52 | .up_conv(3, 5, 512, 1, 1, name='up1')) 53 | (self.feed('up1', 'encoder3_2') 54 | .attention(1, name='attention1') 55 | .conv(3, 5, 256, 1, 1, name='decoder1_1') 56 | .conv(3, 5, 256, 1, 1, name='decoder1_2') 57 | .up_conv(3, 5, 256, 1, 1, name='up2')) 58 | (self.feed('up2', 'encoder2_2') 59 | .attention(1, name='attention2') 60 | .conv(3, 5, 128, 1, 1, name='decoder2_1') 61 | .conv(3, 5, 128, 1, 1, name='decoder2_2') 62 | .up_conv(3, 5, 128, 1, 1, name='up3')) 63 | (self.feed('up3', 'encoder1_2') 64 | .attention(1, name='attention3') 65 | .conv(3, 5, 64, 1, 1, name='decoder3_1') 66 | .conv(3, 5, 64, 1, 1, name='decoder3_2')) 67 | 68 | # classification 69 | (self.feed('decoder3_2') 70 | .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') # sigmoid for ghm 71 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /model_cutie_res_bert.py: -------------------------------------------------------------------------------- 1 | # written by 
Xiaohui Zhao 2 | # 2018-12 3 | # xiaohui.zhao@outlook.com 4 | import tensorflow as tf 5 | from model_cutie import CUTIE 6 | 7 | class CUTIERes(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE_residual_bert_8x" # 8x down sampling 10 | 11 | self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table') 12 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 13 | self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') else tf.equal(1, 0) #params.use_ghm 14 | self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu' 15 | self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights') 16 | self.layers = dict({'data': self.data, 'gt_classes': self.gt_classes, 'ghm_weights':self.ghm_weights}) 17 | 18 | self.num_vocabs = num_vocabs 19 | self.num_classes = num_classes 20 | self.trainable = trainable 21 | 22 | self.embedding_size = params.embedding_size 23 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 24 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 25 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 26 | 27 | self.layer_inputs = [] 28 | self.embedding_table = None 29 | self.setup() 30 | 31 | 32 | def setup(self): 33 | # input 34 | (self.feed('data') 35 | .bert_embed(self.num_vocabs, 768, name='embeddings', trainable=False)) 36 | 37 | # encoder 38 | (self.feed('embeddings') 39 | .conv(3, 5, 64, 1, 1, name='encoder1_1') 40 | .conv(3, 5, 128, 1, 1, name='encoder1_2') 41 | .max_pool(2, 2, 2, 2, name='pool1') 42 | .conv(3, 5, 128, 1, 1, name='encoder2_1') 43 | .conv(3, 5, 256, 1, 1, name='encoder2_2') 44 | .max_pool(2, 2, 2, 2, name='pool2') 45 | .conv(3, 5, 256, 1, 1, name='encoder3_1') 46 | .conv(3, 5, 512, 1, 1, name='encoder3_2') 47 | .max_pool(2, 2, 2, 2, name='pool3') 48 | .conv(3, 5, 512, 1, 1, name='encoder4_1') 49 | .conv(3, 5, 512, 1, 1, name='encoder4_2')) 50 | 51 | # decoder 52 | (self.feed('encoder4_2') 53 | .up_conv(3, 5, 512, 1, 1, name='up1')) 54 | (self.feed('up1', 'encoder3_2') 55 | .concat(3, name='concat1') 56 | .conv(3, 5, 256, 1, 1, name='decoder1_1') 57 | .conv(3, 5, 256, 1, 1, name='decoder1_2') 58 | .up_conv(3, 5, 256, 1, 1, name='up2')) 59 | (self.feed('up2', 'encoder2_2') 60 | .concat(3, name='concat2') 61 | .conv(3, 5, 128, 1, 1, name='decoder2_1') 62 | .conv(3, 5, 128, 1, 1, name='decoder2_2') 63 | .up_conv(3, 5, 128, 1, 1, name='up3')) 64 | (self.feed('up3', 'encoder1_2') 65 | .concat(3, name='concat3') 66 | .conv(3, 5, 64, 1, 1, name='decoder3_1') 67 | .conv(3, 5, 64, 1, 1, name='decoder3_2')) 68 | 69 | # classification 70 | (self.feed('decoder3_2') 71 | .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') # sigmoid for ghm 72 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /deprecated/model_cutie_res_att_bert.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2018-12 3 | # xiaohui.zhao@accenture.com 4 | import tensorflow as tf 5 | from model_cutie import CUTIE 6 | 7 | class CUTIERes(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE_residual_attention_bert_8x" # 8x down sampling 10 | 11 | self.data = tf.placeholder(tf.int32, 
shape=[None, None, None, 1], name='grid_table') 12 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 13 | self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') else tf.equal(1, 0) #params.use_ghm 14 | self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu' 15 | self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights') 16 | self.layers = dict({'data': self.data, 'gt_classes': self.gt_classes, 'ghm_weights':self.ghm_weights}) 17 | 18 | self.num_vocabs = num_vocabs 19 | self.num_classes = num_classes 20 | self.trainable = trainable 21 | 22 | self.embedding_size = params.embedding_size 23 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 24 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 25 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 26 | 27 | self.layer_inputs = [] 28 | self.embedding_table = None 29 | self.setup() 30 | 31 | 32 | def setup(self): 33 | # input 34 | (self.feed('data') 35 | .bert_embed(self.num_vocabs, 768, name='embeddings', trainable=False)) 36 | 37 | # encoder 38 | (self.feed('embeddings') 39 | .conv(3, 5, 64, 1, 1, name='encoder1_1') 40 | .conv(3, 5, 128, 1, 1, name='encoder1_2') 41 | .max_pool(2, 2, 2, 2, name='pool1') 42 | .conv(3, 5, 128, 1, 1, name='encoder2_1') 43 | .conv(3, 5, 256, 1, 1, name='encoder2_2') 44 | .max_pool(2, 2, 2, 2, name='pool2') 45 | .conv(3, 5, 256, 1, 1, name='encoder3_1') 46 | .conv(3, 5, 512, 1, 1, name='encoder3_2') 47 | .max_pool(2, 2, 2, 2, name='pool3') 48 | .conv(3, 5, 512, 1, 1, name='encoder4_1') 49 | .conv(3, 5, 512, 1, 1, name='encoder4_2')) 50 | 51 | # decoder 52 | (self.feed('encoder4_2') 53 | .up_conv(3, 5, 512, 1, 1, name='up1')) 54 | (self.feed('up1', 'encoder3_2') 55 | .attention(1, name='attention1') 56 | .conv(3, 5, 256, 1, 1, name='decoder1_1') 57 | .conv(3, 5, 256, 1, 1, name='decoder1_2') 58 | .up_conv(3, 5, 256, 1, 1, name='up2')) 59 | (self.feed('up2', 'encoder2_2') 60 | .attention(1, name='attention2') 61 | .conv(3, 5, 128, 1, 1, name='decoder2_1') 62 | .conv(3, 5, 128, 1, 1, name='decoder2_2') 63 | .up_conv(3, 5, 128, 1, 1, name='up3')) 64 | (self.feed('up3', 'encoder1_2') 65 | .attention(1, name='attention3') 66 | .conv(3, 5, 64, 1, 1, name='decoder3_1') 67 | .conv(3, 5, 64, 1, 1, name='decoder3_2')) 68 | 69 | # classification 70 | (self.feed('decoder3_2') 71 | .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') # sigmoid for ghm 72 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /model_cutie_aspp.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2019-03 3 | # xiaohui.zhao@outlook.com 4 | import tensorflow as tf 5 | from model_cutie import CUTIE 6 | 7 | class CUTIERes(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE_atrousSPP" # 10 | 11 | self.data_grid = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='data_grid') 12 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 13 | self.data_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='data_image') # not used in CUTIEv1 14 | self.ps_1d_indices = tf.placeholder(tf.int32, shape=[None, None], name='ps_1d_indices') # not used in 
CUTIEv1 15 | 16 | self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') else tf.equal(1, 0) #params.use_ghm 17 | self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu' 18 | self.dropout = params.data_augmentation_dropout if hasattr(params, 'data_augmentation_dropout') else 1 19 | self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights') 20 | self.layers = dict({'data_grid': self.data_grid, 'gt_classes': self.gt_classes, 'ghm_weights':self.ghm_weights}) 21 | 22 | self.num_vocabs = num_vocabs 23 | self.num_classes = num_classes 24 | self.trainable = trainable 25 | 26 | self.embedding_size = params.embedding_size 27 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 28 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 29 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 30 | 31 | self.layer_inputs = [] 32 | self.setup() 33 | 34 | 35 | def setup(self): 36 | # input 37 | (self.feed('data_grid') 38 | .embed(self.num_vocabs, self.embedding_size, name='embedding', dropout=self.dropout)) 39 | 40 | # encoder 41 | (self.feed('embedding') 42 | .conv(3, 5, 256, 1, 1, name='encoder1_1') 43 | .conv(3, 5, 256, 1, 1, name='encoder1_2') 44 | .conv(3, 5, 256, 1, 1, name='encoder1_3') 45 | .conv(3, 5, 256, 1, 1, name='encoder1_4') 46 | .dilate_conv(3, 5, 256, 1, 1, 2, name='encoder1_5') 47 | .dilate_conv(3, 5, 256, 1, 1, 4, name='encoder1_6') 48 | .dilate_conv(3, 5, 256, 1, 1, 8, name='encoder1_7') 49 | .dilate_conv(3, 5, 256, 1, 1, 16, name='encoder1_8')) 50 | 51 | # Atrous Spatial Pyramid Pooling module 52 | #(self.feed('encoder1_8') 53 | # .conv(1, 1, 256, 1, 1, name='aspp_0')) 54 | (self.feed('encoder1_8') 55 | .dilate_conv(3, 5, 256, 1, 1, 4, name='aspp_1')) 56 | (self.feed('encoder1_8') 57 | .dilate_conv(3, 5, 256, 1, 1, 8, name='aspp_2')) 58 | (self.feed('encoder1_8') 59 | .dilate_conv(3, 5, 256, 1, 1, 16, name='aspp_3')) 60 | (self.feed('encoder1_8') 61 | .global_pool(name='aspp_4')) 62 | (self.feed('aspp_1', 'aspp_2', 'aspp_3', 'aspp_4') 63 | .concat(3, name='aspp_concat') 64 | .conv(1, 1, 256, 1, 1, name='aspp_1x1')) 65 | 66 | # combine low level features 67 | (self.feed('encoder1_1', 'aspp_1x1') 68 | .concat(3, name='concat1') 69 | .conv(3, 5, 64, 1, 1, name='decoder1_1')) 70 | 71 | # classification 72 | (self.feed('decoder1_1') 73 | .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') # sigmoid for ghm 74 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /model_cutie2_fpn.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2019-04 3 | # xiaohui.zhao@outlook.com 4 | import tensorflow as tf 5 | from model_cutie2 import CUTIE2 as CUTIE 6 | 7 | class CUTIE2(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE2_dilate" # 10 | 11 | self.data_grid = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='data_grid') 12 | self.data_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='data_image') 13 | self.ps_1d_indices = tf.placeholder(tf.int32, shape=[None, None], name='ps_1d_indices') 14 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 15 | 16 | self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') 
else tf.equal(1, 0) #params.use_ghm 17 | self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu' 18 | self.dropout = params.data_augmentation_dropout if hasattr(params, 'data_augmentation_dropout') else 1 19 | self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights') 20 | self.layers = dict({'data_grid': self.data_grid, 'data_image': self.data_image, 'ps_1d_indices': self.ps_1d_indices, 21 | 'gt_classes': self.gt_classes, 'ghm_weights':self.ghm_weights}) 22 | 23 | self.num_vocabs = num_vocabs 24 | self.num_classes = num_classes 25 | self.trainable = trainable 26 | 27 | self.embedding_size = params.embedding_size 28 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 29 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 30 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 31 | 32 | self.layer_inputs = [] 33 | self.setup() 34 | 35 | 36 | def setup(self): 37 | ## grid 38 | (self.feed('data_grid') 39 | .embed(self.num_vocabs, self.embedding_size, name='embedding', dropout=self.dropout)) 40 | 41 | ## image 42 | (self.feed('data_image') 43 | .conv(3, 3, 32, 1, 1, name='image_encoder1_1') 44 | .conv(3, 3, 64, 2, 2, name='image_encoder1_1') 45 | .conv(3, 3, 64, 1, 1, name='image_encoder1_2') 46 | .conv(3, 3, 128, 2, 2, name='image_encoder1_2') 47 | .conv(3, 3, 128, 1, 1, name='image_encoder1_2') 48 | .conv(3, 3, 256, 2, 2, name='image_encoder1_2')) 49 | 50 | # feature map positional mapping 51 | (self.feed('image_featuremap', 'ps_1d_indices', 'data_grid') 52 | .positional_sampling(32, name='positional_sampling')) 53 | 54 | ## concate image with grid 55 | (self.feed('positional_sampling', 'embedding') 56 | .concat(3, name='concat') 57 | .conv(1, 1, 256, 1, 1, name='feature_fuser')) 58 | 59 | # encoder 60 | (self.feed('feature_fuser') 61 | .conv(3, 5, 256, 1, 1, name='encoder1_1') 62 | .conv(3, 5, 256, 1, 1, name='encoder1_2') 63 | .conv(3, 5, 256, 1, 1, name='encoder1_3') 64 | .conv(3, 5, 256, 1, 1, name='encoder1_4') 65 | .dilate_conv(3, 5, 256, 1, 1, 2, name='encoder1_5') 66 | .dilate_conv(3, 5, 256, 1, 1, 4, name='encoder1_6') 67 | .dilate_conv(3, 5, 256, 1, 1, 8, name='encoder1_7') 68 | .dilate_conv(3, 5, 256, 1, 1, 16, name='encoder1_8')) 69 | 70 | # Atrous Spatial Pyramid Pooling module 71 | #(self.feed('encoder1_8') 72 | # .conv(1, 1, 256, 1, 1, name='aspp_0')) 73 | (self.feed('encoder1_8') 74 | .dilate_conv(3, 5, 256, 1, 1, 4, name='aspp_1')) 75 | (self.feed('encoder1_8') 76 | .dilate_conv(3, 5, 256, 1, 1, 8, name='aspp_2')) 77 | (self.feed('encoder1_8') 78 | .dilate_conv(3, 5, 256, 1, 1, 16, name='aspp_3')) 79 | (self.feed('encoder1_8') 80 | .global_pool(name='aspp_4')) 81 | (self.feed('aspp_1', 'aspp_2', 'aspp_3', 'aspp_4') 82 | .concat(3, name='aspp_concat') 83 | .conv(1, 1, 256, 1, 1, name='aspp_1x1')) 84 | 85 | # combine low level features 86 | (self.feed('encoder1_1', 'aspp_1x1') 87 | .concat(3, name='concat1') 88 | .conv(3, 5, 64, 1, 1, name='decoder1_1')) 89 | 90 | # classification 91 | (self.feed('decoder1_1') 92 | .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') # sigmoid for ghm 93 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /model_cutie2_dilate.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2019-04 3 | # 
xiaohui.zhao@outlook.com 4 | import tensorflow as tf 5 | from model_cutie2 import CUTIE2 as CUTIE 6 | 7 | class CUTIE2(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE2_dilate" # 10 | 11 | self.data_grid = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='data_grid') 12 | self.data_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='data_image') 13 | self.ps_1d_indices = tf.placeholder(tf.int32, shape=[None, None], name='ps_1d_indices') 14 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 15 | 16 | self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') else tf.equal(1, 0) #params.use_ghm 17 | self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu' 18 | self.dropout = params.data_augmentation_dropout if hasattr(params, 'data_augmentation_dropout') else 1 19 | self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights') 20 | self.layers = dict({'data_grid': self.data_grid, 'data_image': self.data_image, 'ps_1d_indices': self.ps_1d_indices, 21 | 'gt_classes': self.gt_classes, 'ghm_weights':self.ghm_weights}) 22 | 23 | self.num_vocabs = num_vocabs 24 | self.num_classes = num_classes 25 | self.trainable = trainable 26 | 27 | self.embedding_size = params.embedding_size 28 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 29 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 30 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 31 | 32 | self.layer_inputs = [] 33 | self.setup() 34 | 35 | 36 | def setup(self): 37 | ## grid 38 | (self.feed('data_grid') 39 | .embed(self.num_vocabs, self.embedding_size, name='embedding', dropout=self.dropout)) 40 | 41 | ## image 42 | (self.feed('data_image') 43 | .conv(3, 3, 128, 1, 1, name='image_encoder1_1') 44 | .conv(3, 3, 128, 1, 1, name='image_encoder1_2') 45 | .dilate_conv(3, 3, 128, 1, 1, 2, name='image_encoder1_5') 46 | .dilate_conv(3, 3, 128, 1, 1, 4, name='image_encoder1_6') 47 | .dilate_conv(3, 3, 128, 1, 1, 8, name='image_encoder1_7') 48 | .dilate_conv(3, 3, 128, 1, 1, 16, name='image_encoder1_8')) 49 | 50 | # feature map positional mapping 51 | (self.feed('image_encoder1_8', 'ps_1d_indices', 'data_grid') 52 | .positional_sampling(128, name='positional_sampling')) 53 | 54 | ## concate image with grid 55 | (self.feed('positional_sampling', 'embedding') 56 | .concat(3, name='concat') 57 | .conv(1, 1, 256, 1, 1, name='feature_fuser')) 58 | 59 | # encoder 60 | (self.feed('feature_fuser') 61 | .conv(3, 5, 256, 1, 1, name='encoder1_1') 62 | .conv(3, 5, 256, 1, 1, name='encoder1_2') 63 | .conv(3, 5, 256, 1, 1, name='encoder1_3') 64 | .conv(3, 5, 256, 1, 1, name='encoder1_4') 65 | .dilate_conv(3, 5, 256, 1, 1, 2, name='encoder1_5') 66 | .dilate_conv(3, 5, 256, 1, 1, 4, name='encoder1_6') 67 | .dilate_conv(3, 5, 256, 1, 1, 8, name='encoder1_7') 68 | .dilate_conv(3, 5, 256, 1, 1, 16, name='encoder1_8')) 69 | 70 | # Atrous Spatial Pyramid Pooling module 71 | #(self.feed('encoder1_8') 72 | # .conv(1, 1, 256, 1, 1, name='aspp_0')) 73 | (self.feed('encoder1_8') 74 | .dilate_conv(3, 5, 256, 1, 1, 4, name='aspp_1')) 75 | (self.feed('encoder1_8') 76 | .dilate_conv(3, 5, 256, 1, 1, 8, name='aspp_2')) 77 | (self.feed('encoder1_8') 78 | .dilate_conv(3, 5, 256, 1, 1, 16, name='aspp_3')) 79 | (self.feed('encoder1_8') 80 | 
.global_pool(name='aspp_4')) 81 | (self.feed('aspp_1', 'aspp_2', 'aspp_3', 'aspp_4') 82 | .concat(3, name='aspp_concat') 83 | .conv(1, 1, 256, 1, 1, name='aspp_1x1')) 84 | 85 | # combine low level features 86 | (self.feed('encoder1_1', 'aspp_1x1') 87 | .concat(3, name='concat1') 88 | .conv(3, 5, 64, 1, 1, name='decoder1_1')) 89 | 90 | # classification 91 | (self.feed('decoder1_1') 92 | .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') # sigmoid for ghm 93 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /model_cutie2_aspp.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2019-04 3 | # xiaohui.zhao@outlook.com 4 | import tensorflow as tf 5 | from model_cutie2 import CUTIE2 as CUTIE 6 | 7 | class CUTIE2(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE2_dilate" # 10 | 11 | self.data_grid = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='data_grid') 12 | self.data_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='data_image') 13 | self.ps_1d_indices = tf.placeholder(tf.int32, shape=[None, None], name='ps_1d_indices') 14 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 15 | 16 | self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') else tf.equal(1, 0) #params.use_ghm 17 | self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu' 18 | self.dropout = params.data_augmentation_dropout if hasattr(params, 'data_augmentation_dropout') else 1 19 | self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights') 20 | self.layers = dict({'data_grid': self.data_grid, 'data_image': self.data_image, 'ps_1d_indices': self.ps_1d_indices, 21 | 'gt_classes': self.gt_classes, 'ghm_weights':self.ghm_weights}) 22 | 23 | self.num_vocabs = num_vocabs 24 | self.num_classes = num_classes 25 | self.trainable = trainable 26 | 27 | self.embedding_size = params.embedding_size 28 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 29 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 30 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 31 | 32 | self.layer_inputs = [] 33 | self.setup() 34 | 35 | 36 | def setup(self): 37 | ## grid 38 | (self.feed('data_grid') 39 | .embed(self.num_vocabs, self.embedding_size, name='embedding', dropout=self.dropout)) 40 | 41 | ## image 42 | (self.feed('data_image') 43 | .conv(3, 3, 32, 1, 1, name='image_encoder1_1') 44 | .conv(3, 3, 32, 1, 1, name='image_encoder1_2') 45 | .dilate_conv(3, 3, 32, 1, 1, 2, name='image_encoder1_5') 46 | .dilate_conv(3, 3, 32, 1, 1, 2, name='image_encoder1_6') 47 | .dilate_conv(3, 3, 32, 1, 1, 2, name='image_encoder1_7') 48 | .dilate_conv(3, 3, 32, 1, 1, 2, name='image_encoder1_8')) 49 | (self.feed('image_encoder1_8') 50 | .dilate_conv(3, 3, 32, 1, 1, 4, name='image_aspp_1')) 51 | (self.feed('image_encoder1_8') 52 | .dilate_conv(3, 3, 32, 1, 1, 8, name='image_aspp_2')) 53 | (self.feed('image_encoder1_8') 54 | .dilate_conv(3, 3, 32, 1, 1, 16, name='image_aspp_3')) 55 | (self.feed('image_encoder1_8') 56 | .global_pool(name='image_aspp_4')) 57 | (self.feed('image_aspp_1', 'image_aspp_2', 'image_aspp_3', 'image_aspp_4') 58 | .concat(3, name='image_aspp_concat') 59 | .conv(1, 
1, 32, 1, 1, name='image_aspp_1x1')) 60 | (self.feed('image_encoder1_1', 'image_aspp_1x1') 61 | .concat(3, name='image_concat1') 62 | .conv(3, 3, 32, 1, 1, name='image_featuremap')) 63 | 64 | # feature map positional mapping 65 | (self.feed('image_featuremap', 'ps_1d_indices', 'data_grid') 66 | .positional_sampling(32, name='positional_sampling')) 67 | 68 | ## concate image with grid 69 | (self.feed('positional_sampling', 'embedding') 70 | .concat(3, name='concat') 71 | .conv(1, 1, 256, 1, 1, name='feature_fuser')) 72 | 73 | # encoder 74 | (self.feed('feature_fuser') 75 | .conv(3, 5, 256, 1, 1, name='encoder1_1') 76 | .conv(3, 5, 256, 1, 1, name='encoder1_2') 77 | .conv(3, 5, 256, 1, 1, name='encoder1_3') 78 | .conv(3, 5, 256, 1, 1, name='encoder1_4') 79 | .dilate_conv(3, 5, 256, 1, 1, 2, name='encoder1_5') 80 | .dilate_conv(3, 5, 256, 1, 1, 4, name='encoder1_6') 81 | .dilate_conv(3, 5, 256, 1, 1, 8, name='encoder1_7') 82 | .dilate_conv(3, 5, 256, 1, 1, 16, name='encoder1_8')) 83 | 84 | # Atrous Spatial Pyramid Pooling module 85 | #(self.feed('encoder1_8') 86 | # .conv(1, 1, 256, 1, 1, name='aspp_0')) 87 | (self.feed('encoder1_8') 88 | .dilate_conv(3, 5, 256, 1, 1, 4, name='aspp_1')) 89 | (self.feed('encoder1_8') 90 | .dilate_conv(3, 5, 256, 1, 1, 8, name='aspp_2')) 91 | (self.feed('encoder1_8') 92 | .dilate_conv(3, 5, 256, 1, 1, 16, name='aspp_3')) 93 | (self.feed('encoder1_8') 94 | .global_pool(name='aspp_4')) 95 | (self.feed('aspp_1', 'aspp_2', 'aspp_3', 'aspp_4') 96 | .concat(3, name='aspp_concat') 97 | .conv(1, 1, 256, 1, 1, name='aspp_1x1')) 98 | 99 | # combine low level features 100 | (self.feed('encoder1_1', 'aspp_1x1') 101 | .concat(3, name='concat1') 102 | .conv(3, 5, 64, 1, 1, name='decoder1_1')) 103 | 104 | # classification 105 | (self.feed('decoder1_1') 106 | .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') # sigmoid for ghm 107 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /model_cutie.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2018-12 3 | # xh.zhao@outlook.com 4 | import tensorflow as tf 5 | from model_framework import Model 6 | 7 | 8 | class CUTIE(Model): 9 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 10 | self.name = "CUTIE_benchmark" 11 | 12 | self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table') 13 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 14 | self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') else tf.equal(1, 0) #params.use_ghm 15 | self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu' 16 | self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights') 17 | self.layers = dict({'data': self.data, 'gt_classes': self.gt_classes, 'ghm_weights': self.ghm_weights}) 18 | 19 | self.num_vocabs = num_vocabs 20 | self.num_classes = num_classes 21 | self.trainable = trainable 22 | 23 | self.embedding_size = params.embedding_size 24 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 25 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 26 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 27 | 28 | self.layer_inputs = [] 29 | self.setup() 30 | 31 | 32 | def setup(self): 33 | # input 34 | 
(self.feed('data') 35 | .embed(self.num_vocabs, self.embedding_size, name='embedding')) 36 | 37 | # encoder 38 | (self.feed('embedding') 39 | .conv(3, 5, 64, 1, 1, name='encoder1_1') 40 | .conv(3, 5, 128, 1, 1, name='encoder1_2') 41 | .max_pool(2, 2, 2, 2, name='pool1') 42 | .conv(3, 5, 128, 1, 1, name='encoder2_1') 43 | .conv(3, 5, 256, 1, 1, name='encoder2_2') 44 | .max_pool(2, 2, 2, 2, name='pool2') 45 | .conv(3, 5, 256, 1, 1, name='encoder3_1') 46 | .conv(3, 5, 512, 1, 1, name='encoder3_2') 47 | .max_pool(2, 2, 2, 2, name='pool3') 48 | .conv(3, 5, 512, 1, 1, name='encoder4_1') 49 | .conv(3, 5, 512, 1, 1, name='encoder4_2')) 50 | 51 | # decoder 52 | (self.feed('encoder4_2') 53 | .up_conv(3, 5, 512, 1, 1, name='up1') 54 | .conv(3, 5, 256, 1, 1, name='decoder1_1') 55 | .conv(3, 5, 256, 1, 1, name='decoder1_2') 56 | .up_conv(3, 5, 256, 1, 1, name='up2') 57 | .conv(3, 5, 128, 1, 1, name='decoder2_1') 58 | .conv(3, 5, 128, 1, 1, name='decoder2_2') 59 | .up_conv(3, 5, 128, 1, 1, name='up3') 60 | .conv(3, 5, 64, 1, 1, name='decoder3_1') 61 | .conv(3, 5, 64, 1, 1, name='decoder3_2')) 62 | 63 | # classification 64 | (self.feed('decoder3_2') 65 | .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') 66 | .softmax(name='softmax')) 67 | 68 | def disp_results(self, data_input, data_label, model_output, threshold): 69 | data_input_flat = data_input.reshape([-1]) # [b * h * w] 70 | labels = [] # [b * h * w, classes] 71 | for item in data_label.reshape([-1]): 72 | labels.append([i==item for i in range(self.num_classes)]) 73 | logits = model_output.reshape([-1, self.num_classes]) # [b * h * w, classes] 74 | 75 | # ignore none word input 76 | labels_flat = [] 77 | results_flat = [] 78 | for idx, item in enumerate(data_input_flat): 79 | if item != 0: 80 | labels_flat.extend(labels[idx]) 81 | results_flat.extend(logits[idx] > threshold) 82 | 83 | num_p = sum(labels_flat) 84 | num_n = sum([1-label for label in labels_flat]) 85 | num_all = len(results_flat) 86 | num_correct = sum([True for i in range(num_all) if labels_flat[i] == results_flat[i]]) 87 | 88 | labels_flat_p = [label!=0 for label in labels_flat] 89 | labels_flat_n = [label==0 for label in labels_flat] 90 | num_tp = sum([labels_flat_p[i] * results_flat[i] for i in range(num_all)]) 91 | num_tn = sum([labels_flat_n[i] * (not results_flat[i]) for i in range(num_all)]) 92 | num_fp = num_n - num_tp 93 | num_fn = num_p - num_tp 94 | 95 | # accuracy, precision, recall 96 | accuracy = num_correct / num_all 97 | precision = num_tp / (num_tp + num_fp) 98 | recall = num_tp / (num_tp + num_fn) 99 | 100 | return accuracy, precision, recall 101 | 102 | 103 | def inference(self): 104 | return self.get_output('softmax') #cls_logits 105 | 106 | 107 | def build_loss(self): 108 | labels = self.get_output('gt_classes') 109 | cls_logits = self.get_output('cls_logits') 110 | cls_logits = tf.cond(self.use_ghm, lambda: cls_logits*self.get_output('ghm_weights'), 111 | lambda: cls_logits, name="GradientHarmonizingMechanism") 112 | 113 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=cls_logits) 114 | 115 | with tf.variable_scope('HardNegativeMining'): 116 | labels = tf.reshape(labels, [-1]) 117 | cross_entropy = tf.reshape(cross_entropy, [-1]) 118 | 119 | fg_idx = tf.where(tf.not_equal(labels, 0)) 120 | fgs = tf.gather(cross_entropy, fg_idx) 121 | bg_idx = tf.where(tf.equal(labels, 0)) 122 | bgs = tf.gather(cross_entropy, bg_idx) 123 | 124 | num = self.hard_negative_ratio * tf.shape(fgs)[0] 125 | num_bg = 
tf.cond(tf.shape(bgs)[0]>time per step: %.2fs <<'%(timer_stop - timer_start)) 102 | 103 | if not params.is_table: 104 | recall, acc_strict, acc_soft, res = cal_accuracy(data_loader, np.array(data['grid_table']), 105 | np.array(data['gt_classes']), model_output_val, 106 | np.array(data['label_mapids']), data['bbox_mapids']) 107 | else: 108 | recall, acc_strict, acc_soft, res = cal_accuracy_table(data_loader, np.array(data['grid_table']), 109 | np.array(data['gt_classes']), model_output_val, 110 | np.array(data['label_mapids']), data['bbox_mapids']) 111 | # recall, acc_strict, acc_soft, res = cal_save_results(data_loader, np.array(data['grid_table']), 112 | # np.array(data['gt_classes']), model_output_val, 113 | # np.array(data['label_mapids']), data['bbox_mapids'], 114 | # data['file_name'][0], params.save_prefix) 115 | recalls += [recall] 116 | accs_strict += [acc_strict] 117 | accs_soft += [acc_soft] 118 | if acc_strict != 1: 119 | print(res.decode()) # show res for current batch 120 | 121 | # visualize result 122 | shape = data['shape'] 123 | file_name = data['file_name'][0] # use one single file_name 124 | bboxes = data['bboxes'][file_name] 125 | if not params.is_table: 126 | vis_bbox(data_loader, params.doc_path, np.array(data['grid_table'])[0], 127 | np.array(data['gt_classes'])[0], np.array(model_output_val)[0], file_name, 128 | np.array(bboxes), shape) 129 | else: 130 | vis_table(data_loader, params.doc_path, np.array(data['grid_table'])[0], 131 | np.array(data['gt_classes'])[0], np.array(model_output_val)[0], file_name, 132 | np.array(bboxes), shape) 133 | 134 | recall = sum(recalls) / len(recalls) 135 | acc_strict = sum(accs_strict) / len(accs_strict) 136 | acc_soft = sum(accs_soft) / len(accs_soft) 137 | print('EVALUATION ACC (Recall/Acc): %.3f / %.3f (%.3f) \n'%(recall, acc_strict, acc_soft)) 138 | -------------------------------------------------------------------------------- /model_cutie2.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2019-04 3 | # xiaohui.zhao@outlook.com 4 | import tensorflow as tf 5 | from model_framework import Model 6 | 7 | class CUTIE2(Model): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTTIE_benchmark" # 10 | 11 | self.data_grid = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid') 12 | self.data_image = tf.placeholder(tf.uint8, shape=[None, None, None, 3], name='image') 13 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 14 | 15 | self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') else tf.equal(1, 0) #params.use_ghm 16 | self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu' 17 | self.dropout = params.data_augmentation_dropout if hasattr(params, 'data_augmentation_dropout') else 1 18 | self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights') 19 | self.layers = dict({'data_grid': self.data_grid, 'data_image': self.data_image, 20 | 'gt_classes': self.gt_classes, 'ghm_weights':self.ghm_weights}) 21 | 22 | self.num_vocabs = num_vocabs 23 | self.num_classes = num_classes 24 | self.trainable = trainable 25 | 26 | self.embedding_size = params.embedding_size 27 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 28 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 29 | self.batch_size = 
params.batch_size if hasattr(params, 'batch_size') else 0 30 | 31 | self.layer_inputs = [] 32 | self.setup() 33 | 34 | 35 | def setup(self): 36 | # input 37 | # (self.feed('data_grid') 38 | # .embed(self.num_vocabs, self.embedding_size, name='embedding', dropout=self.dropout)) 39 | # (self.feed('data_image') 40 | # .conv(3, 3, 64, 1, 1, name='image_encoder1_1') 41 | # .conv(3, 3, 128, 1, 1, name='image_encoder1_2')) 42 | # 43 | # # encoder 44 | # (self.feed('embedding') 45 | # .conv(3, 5, 64, 1, 1, name='encoder1_1') 46 | # .conv(3, 5, 128, 1, 1, name='encoder1_2') 47 | # .max_pool(2, 2, 2, 2, name='pool1') 48 | # .conv(3, 5, 128, 1, 1, name='encoder2_1') 49 | # .conv(3, 5, 256, 1, 1, name='encoder2_2') 50 | # .max_pool(2, 2, 2, 2, name='pool2') 51 | # .conv(3, 5, 256, 1, 1, name='encoder3_1') 52 | # .conv(3, 5, 512, 1, 1, name='encoder3_2') 53 | # .max_pool(2, 2, 2, 2, name='pool3') 54 | # .conv(3, 5, 512, 1, 1, name='encoder4_1') 55 | # .conv(3, 5, 512, 1, 1, name='encoder4_2')) 56 | # 57 | # # decoder 58 | # (self.feed('encoder4_2') 59 | # .up_conv(3, 5, 512, 1, 1, name='up1') 60 | # .conv(3, 5, 256, 1, 1, name='decoder1_1') 61 | # .conv(3, 5, 256, 1, 1, name='decoder1_2') 62 | # .up_conv(3, 5, 256, 1, 1, name='up2') 63 | # .conv(3, 5, 128, 1, 1, name='decoder2_1') 64 | # .conv(3, 5, 128, 1, 1, name='decoder2_2') 65 | # .up_conv(3, 5, 128, 1, 1, name='up3') 66 | # .conv(3, 5, 64, 1, 1, name='decoder3_1') 67 | # .conv(3, 5, 64, 1, 1, name='decoder3_2')) 68 | # 69 | # # classification 70 | # (self.feed('decoder3_2') 71 | # .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') 72 | # .softmax(name='softmax')) 73 | pass 74 | 75 | def disp_results(self, data_input, data_label, model_output, threshold): 76 | data_input_flat = data_input.reshape([-1]) # [b * h * w] 77 | labels = [] # [b * h * w, classes] 78 | for item in data_label.reshape([-1]): 79 | labels.append([i==item for i in range(self.num_classes)]) 80 | logits = model_output.reshape([-1, self.num_classes]) # [b * h * w, classes] 81 | 82 | # ignore none word input 83 | labels_flat = [] 84 | results_flat = [] 85 | for idx, item in enumerate(data_input_flat): 86 | if item != 0: 87 | labels_flat.extend(labels[idx]) 88 | results_flat.extend(logits[idx] > threshold) 89 | 90 | num_p = sum(labels_flat) 91 | num_n = sum([1-label for label in labels_flat]) 92 | num_all = len(results_flat) 93 | num_correct = sum([True for i in range(num_all) if labels_flat[i] == results_flat[i]]) 94 | 95 | labels_flat_p = [label!=0 for label in labels_flat] 96 | labels_flat_n = [label==0 for label in labels_flat] 97 | num_tp = sum([labels_flat_p[i] * results_flat[i] for i in range(num_all)]) 98 | num_tn = sum([labels_flat_n[i] * (not results_flat[i]) for i in range(num_all)]) 99 | num_fp = num_n - num_tp 100 | num_fn = num_p - num_tp 101 | 102 | # accuracy, precision, recall 103 | accuracy = num_correct / num_all 104 | precision = num_tp / (num_tp + num_fp) 105 | recall = num_tp / (num_tp + num_fn) 106 | 107 | return accuracy, precision, recall 108 | 109 | 110 | def inference(self): 111 | return self.get_output('softmax') #cls_logits 112 | 113 | 114 | def build_loss(self): 115 | labels = self.get_output('gt_classes') 116 | cls_logits = self.get_output('cls_logits') 117 | cls_logits = tf.cond(self.use_ghm, lambda: cls_logits*self.get_output('ghm_weights'), 118 | lambda: cls_logits, name="GradientHarmonizingMechanism") 119 | 120 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=cls_logits) 121 | 
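# Hard negative mining (the block below): the per-cell cross-entropy is split into
# foreground cells (label != 0) and background cells (label == 0). All foreground
# losses are kept, while the number of background cells allowed to contribute is
# capped at hard_negative_ratio * (number of foreground cells), so the mostly
# empty grid does not dominate the loss.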
122 | with tf.variable_scope('HardNegativeMining'): 123 | labels = tf.reshape(labels, [-1]) 124 | cross_entropy = tf.reshape(cross_entropy, [-1]) 125 | 126 | fg_idx = tf.where(tf.not_equal(labels, 0)) 127 | fgs = tf.gather(cross_entropy, fg_idx) 128 | bg_idx = tf.where(tf.equal(labels, 0)) 129 | bgs = tf.gather(cross_entropy, bg_idx) 130 | 131 | num = self.hard_negative_ratio * tf.shape(fgs)[0] 132 | num_bg = tf.cond(tf.shape(bgs)[0] Download start') 167 | task_info = download_info(args.server, args.project, args.task) 168 | 169 | if args.label is True: 170 | print('==> Download label info start...') 171 | download_label_info(task_info) 172 | print('==> Download label info done.') 173 | 174 | if args.postprocessing is True: 175 | print('==> Download postprocessing info start...') 176 | download_postprocessing_result(task_info) 177 | print('==> Download postprocessing info done.') 178 | 179 | if args.source_file is True: 180 | print('==> Download source file start...') 181 | output_path = 'task_{id}_source_file_{time}'.format(id=args.task, time=int(time.time())) 182 | download_file(output_path, 'source_file_path', task_info) 183 | print('==> Download source file done.') 184 | 185 | if args.quality_file is True: 186 | print('==> Download quality file start...') 187 | output_path = 'task_{id}_quality_file_{time}'.format(id=args.task, time=int(time.time())) 188 | download_file(output_path, 'quality_file_path', task_info) 189 | print('==> Download quality file done.') 190 | 191 | if args.engine_raw_result is True: 192 | print('==> Download engine raw file start...') 193 | output_path = 'task_{id}_engine_raw_file_{time}'.format(id=args.task, time=int(time.time())) 194 | download_file(output_path, 'engine_raw_result_path', task_info) 195 | print('==> Download engine raw file done.') 196 | 197 | if args.engine_result is True: 198 | print('==> Download engine file start...') 199 | output_path = 'task_{id}_engine_file_{time}'.format(id=args.task, time=int(time.time())) 200 | download_file(output_path, 'engine_result_path', task_info) 201 | print('==> Download engine file done.') 202 | -------------------------------------------------------------------------------- /model_cutie_hr.py: -------------------------------------------------------------------------------- 1 | # written by Xiaohui Zhao 2 | # 2018-12 3 | # xiaohui.zhao@outlook.com 4 | import tensorflow as tf 5 | from model_cutie import CUTIE 6 | 7 | class CUTIERes(CUTIE): 8 | def __init__(self, num_vocabs, num_classes, params, trainable=True): 9 | self.name = "CUTIE_highresolution_8x" # 8x down sampling 10 | 11 | self.data_grid = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='data_grid') 12 | self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes') 13 | self.data_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='data_image') # not used in CUTIEv1 14 | self.ps_1d_indices = tf.placeholder(tf.int32, shape=[None, None], name='ps_1d_indices') # not used in CUTIEv1 15 | 16 | self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') else tf.equal(1, 0) #params.use_ghm 17 | self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu' 18 | self.dropout = params.data_augmentation_dropout if hasattr(params, 'data_augmentation_dropout') else 1 19 | self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights') 20 | self.layers = dict({'data_grid': self.data_grid, 'gt_classes': self.gt_classes, 
'ghm_weights':self.ghm_weights}) 21 | 22 | self.num_vocabs = num_vocabs 23 | self.num_classes = num_classes 24 | self.trainable = trainable 25 | 26 | self.embedding_size = params.embedding_size 27 | self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0 28 | self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0 29 | self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0 30 | 31 | self.layer_inputs = [] 32 | self.setup() 33 | 34 | 35 | def setup(self): 36 | # input 37 | (self.feed('data_grid') 38 | .embed(self.num_vocabs, self.embedding_size, name='embedding', dropout=self.dropout)) 39 | 40 | # stage 1 block 1 41 | (self.feed('embedding') 42 | .conv(3, 5, 64, 1, 1, name='encoder1_1_1') 43 | .conv(3, 5, 64, 1, 1, name='encoder1_1_2') 44 | .conv(3, 5, 128, 2, 2, name='down1_1_2')) 45 | 46 | 47 | ## introduce stage 2 48 | # stage 1 block 2 49 | (self.feed('encoder1_1_2') 50 | .conv(3, 5, 64, 1, 1, name='encoder1_2_1') 51 | .conv(3, 5, 64, 1, 1, name='encoder1_2_2') 52 | .conv(3, 5, 128, 2, 2, name='down1_2_2')) 53 | # stage 2 block 2 54 | (self.feed('down1_1_2') 55 | .conv(3, 5, 128, 1, 1, name='encoder2_2_1') 56 | .conv(3, 5, 128, 1, 1, name='encoder2_2_2') 57 | .up_conv(3, 5, 64, 1, 1, factor=2, name='up2_2_1')) 58 | 59 | 60 | # stage 1 block 3 61 | (self.feed('encoder1_1_2', 'encoder1_2_2', 'up2_2_1') 62 | .add(name='add1_3') 63 | .conv(3, 5, 64, 1, 1, name='encoder1_3_1') 64 | .conv(3, 5, 64, 1, 1, name='encoder1_3_2') 65 | .conv(3, 5, 128, 2, 2, name='down1_3_2')) 66 | (self.feed('encoder1_3_2') 67 | .conv(3, 5, 256, 4, 4, name='down1_3_3')) 68 | # stage 2 block 3 69 | (self.feed('down1_1_2', 'encoder2_2_2', 'down1_2_2') 70 | .add(name='add2_3') 71 | .conv(3, 5, 128, 1, 1, name='encoder2_3_1') 72 | .conv(3, 5, 128, 1, 1, name='encoder2_3_2') 73 | .up_conv(3, 5, 64, 1, 1, factor=2, name='up2_3_1')) 74 | (self.feed('encoder2_3_2') 75 | .conv(3, 5, 256, 2, 2, name='down2_3_3')) 76 | 77 | 78 | ## introduce stage 3 79 | # stage 1 block 4 80 | (self.feed('add1_3', 'encoder1_3_2', 'up2_3_1') 81 | .add(name='add1_4') 82 | .conv(3, 5, 64, 1, 1, name='encoder1_4_1') 83 | .conv(3, 5, 64, 1, 1, name='encoder1_4_2') 84 | .conv(3, 5, 128, 2, 2, name='down1_4_2')) 85 | (self.feed('encoder1_4_2') 86 | .conv(3, 5, 256, 4, 4, name='down1_4_3')) 87 | # stage 2 block 4 88 | (self.feed('add2_3', 'encoder2_3_2', 'down1_3_2') 89 | .add(name='add2_4') 90 | .conv(3, 5, 128, 1, 1, name='encoder2_4_1') 91 | .conv(3, 5, 128, 1, 1, name='encoder2_4_2') 92 | .up_conv(3, 5, 64, 1, 1, name='up2_4_1')) 93 | (self.feed('encoder2_4_2') 94 | .conv(3, 5, 256, 2, 2, name='down2_4_3')) 95 | # stage 3 block 4 96 | (self.feed('down1_3_3', 'down2_3_3') 97 | .add(name='add3_4') 98 | .conv(3, 5, 256, 1, 1, name='encoder3_4_1') 99 | .conv(3, 5, 256, 1, 1, name='encoder3_4_2') 100 | .up_conv(3, 5, 64, 1, 1, factor=4, name='up3_4_1')) 101 | (self.feed('encoder3_4_2') 102 | .up_conv(3, 5, 128, 1, 1, factor=2, name='up3_4_2')) 103 | 104 | 105 | # stage 1 block 5 106 | (self.feed('add1_4', 'encoder1_4_2', 'up2_4_1', 'up3_4_1') 107 | .add(name='add1_5') 108 | .conv(3, 5, 64, 1, 1, name='encoder1_5_1') 109 | .conv(3, 5, 64, 1, 1, name='encoder1_5_2') 110 | .conv(3, 5, 128, 2, 2, name='down1_5_2')) 111 | (self.feed('encoder1_5_2') 112 | .conv(3, 5, 256, 4, 4, name='down1_5_3')) 113 | (self.feed('encoder1_5_2') 114 | .conv(3, 5, 512, 8, 8, name='down1_5_4')) 115 | # stage 2 block 5 116 | (self.feed('add2_4', 'encoder2_4_2', 'down1_4_2', 
'up3_4_2') 117 | .add(name='add2_5') 118 | .conv(3, 5, 128, 1, 1, name='encoder2_5_1') 119 | .conv(3, 5, 128, 1, 1, name='encoder2_5_2') 120 | .up_conv(3, 5, 64, 1, 1, factor=2, name='up2_5_1')) 121 | (self.feed('encoder2_5_2') 122 | .conv(3, 5, 256, 2, 2, name='down2_5_3')) 123 | (self.feed('encoder2_5_2') 124 | .conv(3, 5, 512, 4, 4, name='down2_5_4')) 125 | # stage 3 block 5 126 | (self.feed('add3_4', 'encoder3_4_2', 'down1_4_3', 'down2_4_3') 127 | .add(name='add3_5') 128 | .conv(3, 5, 256, 1, 1, name='encoder3_5_1') 129 | .conv(3, 5, 256, 1, 1, name='encoder3_5_2') 130 | .up_conv(3, 5, 64, 1, 1, factor=4, name='up3_5_1')) 131 | (self.feed('encoder3_5_2') 132 | .up_conv(3, 5, 128, 1, 1, factor=2, name='up3_5_2')) 133 | (self.feed('encoder3_5_2') 134 | .conv(3, 5, 512, 2, 2, name='down3_5_4')) 135 | 136 | 137 | ## introduce stage 4 138 | # stage 1 block 6 139 | (self.feed('add1_5', 'encoder1_5_2', 'up2_5_1', 'up3_5_1') 140 | .add(name='add1_6') 141 | .conv(3, 5, 64, 1, 1, name='encoder1_6_1') 142 | .conv(3, 5, 64, 1, 1, name='encoder1_6_2') 143 | .conv(3, 5, 128, 2, 2, name='down1_6_2')) 144 | (self.feed('encoder1_6_2') 145 | .conv(3, 5, 256, 4, 4, name='down1_6_3')) 146 | (self.feed('encoder1_6_2') 147 | .conv(3, 5, 512, 8, 8, name='down1_6_4')) 148 | # stage 2 block 6 149 | (self.feed('add2_5', 'encoder2_5_2', 'down1_5_2', 'up3_5_2') 150 | .add(name='add2_6') 151 | .conv(3, 5, 128, 1, 1, name='encoder2_6_1') 152 | .conv(3, 5, 128, 1, 1, name='encoder2_6_2') 153 | .up_conv(3, 5, 64, 1, 1, name='up2_6_1')) 154 | (self.feed('encoder2_6_2') 155 | .conv(3, 5, 256, 2, 2, name='down2_6_3')) 156 | (self.feed('encoder2_6_2') 157 | .conv(3, 5, 512, 4, 4, name='down2_6_4')) 158 | # stage 3 block 6 159 | (self.feed('add3_5', 'encoder3_5_2', 'down1_5_3', 'down2_5_3') 160 | .add(name='add3_6') 161 | .conv(3, 5, 256, 1, 1, name='encoder3_6_1') 162 | .conv(3, 5, 256, 1, 1, name='encoder3_6_2') 163 | .up_conv(3, 5, 64, 1, 1, factor=4, name='up3_6_1')) 164 | (self.feed('encoder3_6_2') 165 | .up_conv(3, 5, 128, 1, 1, factor=2, name='up3_6_2')) 166 | (self.feed('encoder3_6_2') 167 | .conv(3, 5, 512, 2, 2, name='down3_6_4')) 168 | # stage 4 block 6 169 | (self.feed('down1_5_4', 'down2_5_4', 'down3_5_4') 170 | .add(name='add4_6') 171 | .conv(3, 5, 512, 1, 1, name='encoder4_6_1') 172 | .conv(3, 5, 512, 1, 1, name='encoder4_6_2') 173 | .up_conv(3, 5, 64, 1, 1, factor=8, name='up4_6_1')) 174 | (self.feed('encoder4_6_2') 175 | .up_conv(3, 5, 128, 1, 1, factor=4, name='up4_6_2')) 176 | (self.feed('encoder4_6_2') 177 | .up_conv(3, 5, 256, 1, 1, factor=2, name='up4_6_3')) 178 | 179 | 180 | # stage 1 block 7 181 | (self.feed('add1_6', 'encoder1_6_2', 'up2_6_1', 'up3_6_1', 'up4_6_1') 182 | .add(name='add1_7') 183 | .conv(3, 5, 64, 1, 1, name='encoder1_7_1') 184 | .conv(3, 5, 64, 1, 1, name='encoder1_7_2') 185 | .conv(3, 5, 128, 2, 2, name='down1_7_2')) 186 | (self.feed('encoder1_7_2') 187 | .conv(3, 5, 256, 4, 4, name='down1_7_3')) 188 | (self.feed('encoder1_7_2') 189 | .conv(3, 5, 512, 8, 8, name='down1_7_4')) 190 | # stage 2 block 7 191 | (self.feed('add2_6', 'encoder2_6_2', 'down1_6_2', 'up3_6_2', 'up4_6_2') 192 | .add(name='add2_7') 193 | .conv(3, 5, 128, 1, 1, name='encoder2_7_1') 194 | .conv(3, 5, 128, 1, 1, name='encoder2_7_2') 195 | .up_conv(3, 5, 64, 1, 1, factor=2, name='up2_7_1')) 196 | (self.feed('encoder2_7_2') 197 | .conv(3, 5, 256, 2, 2, name='down2_7_3')) 198 | (self.feed('encoder2_7_2') 199 | .conv(3, 5, 512, 4, 4, name='down2_7_4')) 200 | # stage 3 block 7 201 | (self.feed('add3_6', 'encoder3_6_2', 
'down1_6_3', 'down2_6_3', 'up4_6_3') 202 | .add(name='add3_7') 203 | .conv(3, 5, 256, 1, 1, name='encoder3_7_1') 204 | .conv(3, 5, 256, 1, 1, name='encoder3_7_2') 205 | .up_conv(3, 5, 64, 1, 1, factor=4, name='up3_7_1')) 206 | (self.feed('encoder3_7_2') 207 | .up_conv(3, 5, 128, 1, 1, factor=2, name='up3_7_2')) 208 | (self.feed('encoder3_7_2') 209 | .conv(3, 5, 512, 2, 2, name='down3_7_4')) 210 | # stage 4 block 7 211 | (self.feed('add4_6', 'encoder4_6_2', 'down1_6_4', 'down2_6_4', 'down3_6_4') 212 | .add(name='add4_7') 213 | .conv(3, 5, 512, 1, 1, name='encoder4_7_1') 214 | .conv(3, 5, 512, 1, 1, name='encoder4_7_2') 215 | .up_conv(3, 5, 64, 1, 1, factor=8, name='up4_7_1')) 216 | (self.feed('encoder4_7_2') 217 | .up_conv(3, 5, 128, 1, 1, factor=4, name='up4_7_2')) 218 | (self.feed('encoder4_7_2') 219 | .up_conv(3, 5, 256, 1, 1, factor=2, name='up4_7_3')) 220 | 221 | 222 | # stage 1 block 8 223 | (self.feed('add1_7', 'encoder1_7_2', 'up2_7_1', 'up3_7_1', 'up4_7_1') 224 | .add(name='add1_8') 225 | .conv(3, 5, 64, 1, 1, name='encoder1_8_1') 226 | .conv(3, 5, 64, 1, 1, name='encoder1_8_2')) 227 | # stage 2 block 8 228 | (self.feed('add2_7', 'encoder2_7_2', 'down1_7_2', 'up3_7_2', 'up4_7_2') 229 | .add(name='add2_8') 230 | .conv(3, 5, 128, 1, 1, name='encoder2_8_1') 231 | .conv(3, 5, 128, 1, 1, name='encoder2_8_2') 232 | .up_conv(3, 5, 64, 1, 1, factor=2, name='up2_8_1')) 233 | # stage 3 block 8 234 | (self.feed('add3_7', 'encoder3_7_2', 'down1_7_3', 'down2_7_3', 'up4_7_3') 235 | .add(name='add3_8') 236 | .conv(3, 5, 256, 1, 1, name='encoder3_8_1') 237 | .conv(3, 5, 256, 1, 1, name='encoder3_8_2') 238 | .up_conv(3, 5, 64, 1, 1, factor=4, name='up3_8_1')) 239 | # stage 4 block 8 240 | (self.feed('add4_7', 'encoder4_7_2', 'down1_7_4', 'down2_7_4', 'down3_7_4') 241 | .add(name='add4_8') 242 | .conv(3, 5, 512, 1, 1, name='encoder4_8_1') 243 | .conv(3, 5, 512, 1, 1, name='encoder4_8_2') 244 | .up_conv(3, 5, 64, 1, 1, factor=8, name='up4_8_1')) 245 | 246 | 247 | # stage 1 block 9 248 | (self.feed('add1_8', 'encoder1_8_2', 'up2_8_1', 'up3_8_1', 'up4_8_1') 249 | .add(name='add1_9') 250 | .conv(3, 5, 64, 1, 1, name='encoder1_9_1') 251 | .conv(3, 5, 64, 1, 1, name='encoder1_9_2')) 252 | 253 | 254 | # classification 255 | (self.feed('encoder1_9_2') 256 | .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') # sigmoid for ghm 257 | .softmax(name='softmax')) -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 
103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 
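# In effect, _tokenize_chinese_chars() pads every CJK character with spaces, so the
# whitespace tokenization below treats each CJK character as a standalone token
# (e.g. "发票" yields the two tokens "发" and "票").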
207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 
313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat.startswith("C"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | --------------------------------------------------------------------------------
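A minimal usage sketch for the BERT-style tokenizer above (illustrative only, not part of the repository); it assumes a WordPiece vocabulary file in the standard one-token-per-line format, named vocab.txt here:

import tokenization

# Illustrative vocab path; any BERT WordPiece vocab file (one token per line) works.
tokenizer = tokenization.FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)

tokens = tokenizer.tokenize("Total amount: 23.50 EUR")
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens)  # exact pieces depend on the vocabulary, e.g. ['total', 'amount', ':', '23', '.', '50', 'eur']
print(ids)     # the corresponding integer ids from vocab.txt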