├── data
├── others
├── chart.jpg
├── example_1.jpg
├── example_2.jpg
└── TrainingStatistic.xlsx
├── dict
├── 40000_classes.npy
├── SROIE_classes.npy
├── table_classes.npy
├── 20000TC_classes.npy
├── 20000T_classes.npy
├── SROIEnc_classes.npy
├── 20000TC_dictionary.npy
├── 20000T_dictionary.npy
├── 40000_dictionary.npy
├── SROIE_dictionary.npy
├── SROIEnc_dictionary.npy
├── table_dictionary.npy
├── 20000T_index_to_word.npy
├── 20000T_word_to_index.npy
├── 40000_index_to_word.npy
├── 40000_word_to_index.npy
├── SROIE_index_to_word.npy
├── SROIE_word_to_index.npy
├── table_index_to_word.npy
├── table_word_to_index.npy
├── 20000TC_index_to_word.npy
├── 20000TC_word_to_index.npy
├── SROIEnc_index_to_word.npy
└── SROIEnc_word_to_index.npy
├── .gitignore
├── invoice_data
├── Faktura1.pdf_0.jpg
└── Faktura1.pdf_1.jpg
├── .idea
├── misc.xml
├── vcs.xml
├── modules.xml
├── CUTIE.iml
└── workspace.xml
├── .settings
└── org.eclipse.core.resources.prefs
├── .pydevproject
├── .project
├── requirements.txt
├── main_data_tokenizer.py
├── main_build_dict.py
├── helper.py
├── README.md
├── deprecated
├── model_cutie_dilate.py
├── model_cutie_att.py
├── model_cutie_sep.py
├── model_cutie_unet8.py
├── model_cutie_fpn8.py
├── model_cutie_res.py
├── model_cutie_res16.py
├── model_cutie_res_att.py
└── model_cutie_res_att_bert.py
├── download_data.py
├── model_cutie_res_bert.py
├── model_cutie_aspp.py
├── model_cutie2_fpn.py
├── model_cutie2_dilate.py
├── model_cutie2_aspp.py
├── model_cutie.py
├── main_evaluate_json.py
├── model_cutie2.py
├── bert_embedding.py
├── export_data.py
├── model_cutie_hr.py
├── tokenization.py
├── utils.py
├── main_train_json.py
└── model_framework.py
/data:
--------------------------------------------------------------------------------
1 | ../data/CUTIE/
--------------------------------------------------------------------------------
/others/chart.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/others/chart.jpg
--------------------------------------------------------------------------------
/dict/40000_classes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/40000_classes.npy
--------------------------------------------------------------------------------
/dict/SROIE_classes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIE_classes.npy
--------------------------------------------------------------------------------
/dict/table_classes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/table_classes.npy
--------------------------------------------------------------------------------
/others/example_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/others/example_1.jpg
--------------------------------------------------------------------------------
/others/example_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/others/example_2.jpg
--------------------------------------------------------------------------------
/dict/20000TC_classes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000TC_classes.npy
--------------------------------------------------------------------------------
/dict/20000T_classes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000T_classes.npy
--------------------------------------------------------------------------------
/dict/SROIEnc_classes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIEnc_classes.npy
--------------------------------------------------------------------------------
/dict/20000TC_dictionary.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000TC_dictionary.npy
--------------------------------------------------------------------------------
/dict/20000T_dictionary.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000T_dictionary.npy
--------------------------------------------------------------------------------
/dict/40000_dictionary.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/40000_dictionary.npy
--------------------------------------------------------------------------------
/dict/SROIE_dictionary.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIE_dictionary.npy
--------------------------------------------------------------------------------
/dict/SROIEnc_dictionary.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIEnc_dictionary.npy
--------------------------------------------------------------------------------
/dict/table_dictionary.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/table_dictionary.npy
--------------------------------------------------------------------------------
/dict/20000T_index_to_word.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000T_index_to_word.npy
--------------------------------------------------------------------------------
/dict/20000T_word_to_index.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000T_word_to_index.npy
--------------------------------------------------------------------------------
/dict/40000_index_to_word.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/40000_index_to_word.npy
--------------------------------------------------------------------------------
/dict/40000_word_to_index.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/40000_word_to_index.npy
--------------------------------------------------------------------------------
/dict/SROIE_index_to_word.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIE_index_to_word.npy
--------------------------------------------------------------------------------
/dict/SROIE_word_to_index.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIE_word_to_index.npy
--------------------------------------------------------------------------------
/dict/table_index_to_word.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/table_index_to_word.npy
--------------------------------------------------------------------------------
/dict/table_word_to_index.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/table_word_to_index.npy
--------------------------------------------------------------------------------
/others/TrainingStatistic.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/others/TrainingStatistic.xlsx
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | outputs/
2 | data/
3 | logs/
4 | graph/
5 | results/
6 | *.pyc
7 | .DS_Store
8 | */.DS_Store
9 | task*
--------------------------------------------------------------------------------
/dict/20000TC_index_to_word.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000TC_index_to_word.npy
--------------------------------------------------------------------------------
/dict/20000TC_word_to_index.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/20000TC_word_to_index.npy
--------------------------------------------------------------------------------
/dict/SROIEnc_index_to_word.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIEnc_index_to_word.npy
--------------------------------------------------------------------------------
/dict/SROIEnc_word_to_index.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/dict/SROIEnc_word_to_index.npy
--------------------------------------------------------------------------------
/invoice_data/Faktura1.pdf_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/invoice_data/Faktura1.pdf_0.jpg
--------------------------------------------------------------------------------
/invoice_data/Faktura1.pdf_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsymbol/CUTIE/HEAD/invoice_data/Faktura1.pdf_1.jpg
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.settings/org.eclipse.core.resources.prefs:
--------------------------------------------------------------------------------
1 | eclipse.preferences.version=1
2 | encoding/data_loader_json.py=utf-8
3 | encoding/main_data_tokenizer.py=utf-8
4 | encoding/main_evaluate_json.py=utf-8
5 | encoding/tokenization.py=utf-8
6 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.pydevproject:
--------------------------------------------------------------------------------
1 |
2 |
3 | Default
4 | python 2.7
5 |
6 |
--------------------------------------------------------------------------------
/.project:
--------------------------------------------------------------------------------
1 |
2 |
3 | CUTIE
4 |
5 |
6 |
7 |
8 |
9 | org.python.pydev.PyDevBuilder
10 |
11 |
12 |
13 |
14 |
15 | org.python.pydev.pythonNature
16 |
17 |
18 |
--------------------------------------------------------------------------------
/.idea/CUTIE.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.6.1
2 | astor==0.7.1
3 | atomicwrites==1.2.1
4 | attrs==18.2.0
5 | certifi==2019.3.9
6 | chardet==3.0.4
7 | cycler==0.10.0
8 | gast==0.2.0
9 | grpcio==1.17.1
10 | h5py==2.9.0
11 | idna==2.8
12 | Keras-Applications==1.0.6
13 | Keras-Preprocessing==1.0.5
14 | kiwisolver==1.0.1
15 | llvmlite==0.27.0
16 | Markdown==3.0.1
17 | matplotlib==3.0.2
18 | more-itertools==5.0.0
19 | numba==0.42.0
20 | numpy==1.15.4
21 | opencv-python==4.0.0.21
22 | pandas==0.23.4
23 | Pillow==5.4.0
24 | pluggy==0.8.1
25 | protobuf==3.6.1
26 | py==1.7.0
27 | pyparsing==2.3.0
28 | pytest==4.1.1
29 | python-dateutil==2.7.5
30 | pytz==2018.7
31 | requests==2.21.0
32 | scipy==1.2.0
33 | six==1.12.0
34 | tensorboard==1.12.1
35 | tensorflow==1.12.0
36 | termcolor==1.1.0
37 | urllib3==1.24.1
38 | Werkzeug==0.14.1
39 |
--------------------------------------------------------------------------------
/main_data_tokenizer.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-02
3 | # xh.zhao@outlook.com
4 | import tensorflow as tf
5 | import argparse
6 |
7 | import tokenization
8 | from data_loader_json import DataLoader
9 |
# Command-line options of the tokenizer script: (flag, value type, default).
parser = argparse.ArgumentParser(description='Data Tokenizer parameters')
for _flag, _value_type, _fallback in (
        ('--dict_path', str, 'dict/vocab.txt'),
        ('--doc_path', str, 'data/meals'),
        ('--batch_size', int, 32)):
    parser.add_argument(_flag, type=_value_type, default=_fallback)
del _flag, _value_type, _fallback
params = parser.parse_args()

if __name__ == '__main__':
    # Intended to convert data into tokenized data with updated bbox.
    # Only the tokenizer itself is built here; no DataLoader is created.
    tokenizer_options = dict(vocab_file=params.dict_path, do_lower_case=True)
    tokenizer = tokenization.FullTokenizer(**tokenizer_options)
22 |
23 |
--------------------------------------------------------------------------------
/main_build_dict.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-01
3 | # xiaohui.zhao@outlook.com
4 | import tensorflow as tf
5 | import argparse
6 |
7 | from data_loader_json import DataLoader
8 |
def _str2bool(value):
    """Convert a command-line string into a bool.

    ``type=bool`` would call ``bool()`` on the raw string, which is True for
    every non-empty value (including ``'False'``), so boolean options need an
    explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False, matching the previous bool('') result.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'boolean value expected, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='CUTIE parameters')
parser.add_argument('--dict_path', type=str, default='dict/SROIE')
parser.add_argument('--doc_path', type=str, default='data/SROIE')
parser.add_argument('--test_path', type=str, default='')  # leave empty if no test data provided
parser.add_argument('--text_case', type=_str2bool, default=True)  # case sensitive
parser.add_argument('--tokenize', type=_str2bool, default=True)  # tokenize input text
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--use_cutie2', type=_str2bool, default=False)
params = parser.parse_args()

if __name__ == '__main__':
    # Run this program before training to create a basic dictionary for training.
    data_loader = DataLoader(params, update_dict=True, load_dictionary=False)
22 |
--------------------------------------------------------------------------------
/helper.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import argparse
3 | import sys
4 | import os, re, json
5 | import time
6 |
7 | # if __name__ == "__main__":
8 | # src = '/Users/xiaohui.zhao/workspace/CUTIE/hotel_imgs/'
9 | # dst = '/Users/xiaohui.zhao/workspace/data/CUTIE/mixed_true_even/train_hotel'
10 | # json_files = {}
11 | # for dirpath,dirnames,filenames in os.walk(dst):
12 | # for filename in filenames:
13 | # file = re.split(r'\.', filename)[0]
14 | # file_path = os.path.join(dirpath,filename)
15 | # json_files.update({file: file_path})
16 | #
17 | # files = []
18 | # for dirpath,dirnames,filenames in os.walk(src):
19 | # for filename in filenames:
20 | # #file = os.path.join(dirpath,filename)
21 | # file = re.split(r'\.', filename)[0]
22 | # if file in json_files:
23 | # os.system('cp '+ os.path.join(dirpath,filename) + ' ' + dst)
24 |
def rename_fields_to_columns(data):
    """Overwrite each field's ``field_name`` with ``Column<i>``.

    Args:
        data: mapping with a ``'fields'`` list of dicts; modified in place.

    Returns:
        The same ``data`` object, with the i-th field named ``'Column{i}'``.
    """
    for position, field in enumerate(data['fields']):
        field['field_name'] = 'Column{}'.format(position)
    return data


if __name__ == "__main__":
    src = '/Users/xiaohui.zhao/workspace/data/CUTIE/column_identity'
    dst = '/Users/xiaohui.zhao/workspace/data/CUTIE/column'
    for dirpath, _dirnames, filenames in os.walk(src):
        for filename in filenames:
            file_path = os.path.join(dirpath, filename)
            # Skip the images stored next to the JSON annotations.
            if file_path.endswith('png'):
                continue
            with open(file_path, encoding='utf-8') as f:
                data = json.load(f)
            print(len(data['fields']))
            rename_fields_to_columns(data)

            # All outputs go into the single flat ``dst`` folder.
            target_path = os.path.join(dst, filename)
            with open(target_path, 'w', encoding='utf-8') as f:
                json.dump(data, f)
43 |
44 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CUTIE
2 | TensorFlow implementation of the paper "CUTIE: Learning to Understand Documents with Convolutional Universal Text Information Extractor."
3 | Xiaohui Zhao [Paper Link](https://arxiv.org/abs/1903.12363v4)
4 |
5 | ----
6 | CUTIE 是用于“票据文档” 2D 关键信息提取/命名实体识别/槽位填充的算法。
7 | 使用CUTIE前,需先使用OCR算法对“票据文档” 中的文字执行检测和识别,而后将格式化的文本输入CUTIE网络,具体流程可参照论文。
8 |
9 | CUTIE can be considered a type of 2-D Key Information Extraction, 2-D NER (Named Entity Recognition) or 2-D Slot Filling algorithm.
10 | Before training or running inference with CUTIE, extract the structured text from your scanned document images with any OCR algorithm. Refer to the CUTIE paper for details about the procedure.
11 |
12 | ### Results
13 |
14 | Result evaluated on 4,484 receipt documents, including taxi receipts, meals entertainment receipts, and hotel receipts, with 9 different key information classes. (AP / softAP)
15 | |Method | #Params | Taxi | Hotel |
16 | | ----------|:---------:| :-----: | :-----: |
17 | | CloudScan | - | 82.0 / - | 60.0 / - |
18 | | BERT | 110M | 88.1 / - | 71.7 / - |
19 | | CUTIE |**14M** |**94.0 / 97.3**|**74.6 / 87.0**|
20 |
21 | 
22 |
23 | 
24 |
25 |
26 | ### Installation & Usage
27 |
28 | ```
29 | pip install -r requirements.txt
30 | ```
31 |
32 | 1. Generate your own dictionary with main_build_dict.py / main_data_tokenizer.py
33 | 2. Train your model with main_train_json.py
34 |
35 | CUTIE achieves best performance with rows/cols well configured. For more insights, refer to statistics in the file (others/TrainingStatistic.xlsx).
36 |
37 | 
38 |
39 |
40 | ### Others
41 |
42 | For information about the input example, refer to [issue discussion](https://github.com/vsymbol/CUTIE/issues/7).
43 | - Apply any OCR tool that helps you detect and recognize words in the scanned document image.
44 | - Label the image OCR results with their key information classes, as in the .json file in the invoice_data folder. (thanks to @4kssoft)
45 |
--------------------------------------------------------------------------------
/deprecated/model_cutie_dilate.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2019-03
3 | # xiaohui.zhao@accenture.com
4 | import tensorflow as tf
5 | from model_cutie import CUTIE
6 |
class CUTIERes(CUTIE):
    """Deprecated CUTIE variant with an encoder of dilated convolutions.

    The network embeds the ``data`` input, applies two ``conv`` layers
    followed by eight ``dilate_conv`` layers, and ends with a 1x1 ``conv``
    (``cls_logits``) and a ``softmax`` layer.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the graph inputs, copy the hyper-parameters and build the net.

        Args:
            num_vocabs: vocabulary size handed to the ``embed`` layer.
            num_classes: output size of ``cls_logits`` and last dimension of
                the ``ghm_weights`` placeholder.
            params: options object. ``embedding_size`` is required;
                ``use_ghm``, ``weight_decay``, ``hard_negative_ratio`` and
                ``batch_size`` fall back to defaults when missing.
            trainable: stored as ``self.trainable``.
        """
        self.name = "CUTIE_dilate"

        # A missing ``use_ghm`` behaves like 0: the flag tensor is false and
        # the classifier activation is 'relu' instead of 'sigmoid'.
        ghm_setting = getattr(params, 'use_ghm', 0)

        self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')
        self.use_ghm = tf.equal(1, ghm_setting)
        self.activation = 'sigmoid' if ghm_setting else 'relu'
        self.ghm_weights = tf.placeholder(
            tf.float32, shape=[None, None, None, num_classes], name='ghm_weights')
        self.layers = {
            'data': self.data,
            'gt_classes': self.gt_classes,
            'ghm_weights': self.ghm_weights,
        }

        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0.0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.setup()

    def setup(self):
        """Register the embedding, encoder and classification layers."""
        # input
        self.feed('data').embed(self.num_vocabs, self.embedding_size, name='embedding')

        # encoder: two plain convolutions, then the dilated stack
        net = self.feed('embedding')
        for layer_name in ('encoder1_1', 'encoder1_2'):
            net = net.conv(3, 5, 128, 1, 1, name=layer_name)
        dilated_stack = (
            ('encoder1_3', 256),
            ('encoder1_4', 256),
            ('encoder1_5', 512),
            ('encoder1_6', 512),
            ('encoder1_7', 256),
            ('encoder1_8', 256),
            ('encoder1_9', 128),
            ('encoder1_10', 128),
        )
        for layer_name, num_out in dilated_stack:
            net = net.dilate_conv(3, 5, num_out, 1, 1, 2, name=layer_name)

        # classification (sigmoid activation when ghm is used)
        logits = self.feed('encoder1_10').conv(
            1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits')
        logits.softmax(name='softmax')
--------------------------------------------------------------------------------
/deprecated/model_cutie_att.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2019-03
3 | # xiaohui.zhao@accenture.com
4 | import tensorflow as tf
5 | from model_cutie import CUTIE
6 |
class CUTIERes(CUTIE):
    """Deprecated CUTIE variant that mixes attention layers into the encoder.

    After the embedding and a first ``conv`` (``encoder0_1``), the encoder
    runs stacks of ``conv`` layers; twice, ``encoder0_1`` is fed together
    with the latest encoder output into an ``attention`` layer before the
    next stack. A 1x1 ``conv`` (``cls_logits``) and ``softmax`` close the net.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the graph inputs, copy the hyper-parameters and build the net.

        Args:
            num_vocabs: vocabulary size handed to the ``embed`` layer.
            num_classes: output size of ``cls_logits`` and last dimension of
                the ``ghm_weights`` placeholder.
            params: options object. ``embedding_size`` is required;
                ``use_ghm``, ``weight_decay``, ``hard_negative_ratio`` and
                ``batch_size`` fall back to defaults when missing.
            trainable: stored as ``self.trainable``.
        """
        self.name = "CUTIE_attention"

        # A missing ``use_ghm`` behaves like 0: the flag tensor is false and
        # the classifier activation is 'relu' instead of 'sigmoid'.
        ghm_setting = getattr(params, 'use_ghm', 0)

        self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')
        self.use_ghm = tf.equal(1, ghm_setting)
        self.activation = 'sigmoid' if ghm_setting else 'relu'
        self.ghm_weights = tf.placeholder(
            tf.float32, shape=[None, None, None, num_classes], name='ghm_weights')
        self.layers = {
            'data': self.data,
            'gt_classes': self.gt_classes,
            'ghm_weights': self.ghm_weights,
        }

        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0.0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.setup()

    def setup(self):
        """Register the embedding, encoder/attention and classification layers."""

        def conv_chain(net, layer_names):
            # Append one identically configured conv layer per name.
            for layer_name in layer_names:
                net = net.conv(3, 5, 128, 1, 1, name=layer_name)
            return net

        # input
        embedded = self.feed('data').embed(
            self.num_vocabs, self.embedding_size, name='embedding')
        embedded.conv(3, 5, 128, 1, 1, name='encoder0_1')

        # encoder
        conv_chain(
            self.feed('encoder0_1'),
            ('encoder1_1', 'encoder1_2', 'encoder1_3', 'encoder1_4'))

        attended = self.feed('encoder0_1', 'encoder1_4').attention(1, name='attention2')
        conv_chain(attended, ('encoder1_5', 'encoder1_6', 'encoder1_7', 'encoder1_8'))

        attended = self.feed('encoder0_1', 'encoder1_8').attention(1, name='attention5')
        conv_chain(attended, ('encoder1_9', 'encoder1_10'))

        # classification (sigmoid activation when ghm is used)
        logits = self.feed('encoder1_10').conv(
            1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits')
        logits.softmax(name='softmax')
--------------------------------------------------------------------------------
/deprecated/model_cutie_sep.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-12
3 | # xiaohui.zhao@accenture.com
4 | import tensorflow as tf
5 | from model_cutie import CUTIE
6 |
class CUTIESep(CUTIE):
    """Deprecated CUTIE variant built from ``sepconv`` layers.

    The encoder alternates pairs of ``sepconv`` layers with a ``conv``
    ("bottleneck") and a ``max_pool``; the decoder alternates ``up_conv``,
    ``conv`` and pairs of ``sepconv`` layers. A 1x1 ``conv``
    (``cls_logits``) and ``softmax`` close the net.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the graph inputs, copy the hyper-parameters and build the net.

        Args:
            num_vocabs: vocabulary size handed to the ``embed`` layer.
            num_classes: output size of ``cls_logits``.
            params: options object. ``embedding_size`` is required;
                ``weight_decay``, ``hard_negative_ratio`` and ``batch_size``
                default to 0 when missing.
            trainable: stored as ``self.trainable``.
        """
        self.name = "CUTIE_seperatable_residual"

        self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')
        self.layers = {'data': self.data, 'gt_classes': self.gt_classes}
        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.setup()

    def setup(self):
        """Register the embedding, encoder, decoder and classification layers."""

        def sep_pair(net, first_name, second_name):
            # Two consecutive sepconv layers with identical settings.
            net = net.sepconv(3, 5, 4, 2, name=first_name)
            return net.sepconv(3, 5, 4, 2, name=second_name)

        # input
        self.feed('data').embed(self.num_vocabs, self.embedding_size, name='embedding')

        # encoder: (sepconv pair, bottleneck conv, pool) x 3, then a final pair
        encoder_stages = (
            ('encoder1_1', 'encoder1_2', 'bottleneck1', 128, 'pool1'),
            ('encoder2_1', 'encoder2_2', 'bottleneck2', 256, 'pool2'),
            ('encoder3_1', 'encoder3_2', 'bottleneck3', 512, 'pool3'),
        )
        net = self.feed('embedding')
        for sep_a, sep_b, neck_name, neck_out, pool_name in encoder_stages:
            net = sep_pair(net, sep_a, sep_b)
            net = net.conv(3, 5, neck_out, 1, 1, name=neck_name)
            net = net.max_pool(2, 2, 2, 2, name=pool_name)
        sep_pair(net, 'encoder4_1', 'encoder4_2')

        # decoder: (up_conv, bottleneck conv, sepconv pair) x 3
        decoder_stages = (
            ('up1', 512, 'bottleneck4', 256, 'decoder1_1', 'decoder1_2'),
            ('up2', 256, 'bottleneck5', 128, 'decoder2_1', 'decoder2_2'),
            ('up3', 128, 'bottleneck6', 64, 'decoder3_1', 'decoder3_2'),
        )
        net = self.feed('encoder4_2')
        for up_name, up_out, neck_name, neck_out, sep_a, sep_b in decoder_stages:
            net = net.up_conv(3, 5, up_out, 1, 1, name=up_name)
            net = net.conv(3, 5, neck_out, 1, 1, name=neck_name)
            net = sep_pair(net, sep_a, sep_b)

        # classification
        logits = self.feed('decoder3_2').conv(1, 1, self.num_classes, 1, 1, name='cls_logits')
        logits.softmax(name='softmax')
--------------------------------------------------------------------------------
/download_data.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import argparse
3 | from requests.auth import HTTPBasicAuth
4 | import base64
5 | import os
6 | import time
7 |
8 |
def init_args():
    """Parse the command-line options of the download script.

    All four options are mandatory strings.

    Returns:
        argparse.Namespace with ``user``, ``password``, ``date`` and
        ``invoice_type_id``.
    """
    required_options = (
        ('--user', 'login name'),
        ('--password', 'login password'),
        ('--date', 'search date'),
        ('--invoice_type_id', 'invoice type id'),
    )
    parser = argparse.ArgumentParser()
    for flag, description in required_options:
        parser.add_argument(flag, type=str, help=description, required=True)
    return parser.parse_args()
17 |
18 |
def get_tasks_of_page(date=None, invoice_type=None, next_url=None):
    """Fetch one page of tasks and download the image of every listed task.

    Relies on the module-level ``auth``, ``headers`` and ``index`` that the
    script entry point sets up; ``index`` is incremented once per task.

    Args:
        date: search value for the first page; unused when ``next_url`` is set.
        invoice_type: invoice type id for the first page; unused when
            ``next_url`` is set.
        next_url: full URL of a follow-up page, as returned by a previous call.

    Returns:
        The ``next`` value of the JSON response, or ``None`` when the request
        failed or the response carries no ``next`` value.
    """
    global index

    api = next_url if next_url else \
        'http://52.193.30.103/argus/api/task/myte/?search={0}&invoice_type={1}'.format(date, invoice_type)

    r = requests.get(api, auth=auth, headers=headers)

    if not r.ok:
        print('Download task info failed.')
        print('Call api: {url}'.format(url=api))
        print('[Err Reason]:{err}'.format(err=r.reason))
        print('[Err text]:{err}'.format(err=r.text))
        return None

    # Decode the body once instead of re-parsing it for every field.
    payload = r.json()
    count = payload.get('count')

    # A response without 'results' is treated as an empty page.
    for result in payload.get('results') or []:
        index += 1

        print('\r download: {0}/{1}'.format(index, count), end='')
        get_image(result['id'])

    return payload.get('next')
45 |
46 |
def get_image(task_id):
    """Download the picture of one task and store it in the current ``path``.

    Uses the module-level ``auth``, ``headers`` and ``path``. The file is
    named after the last path component of the record's ``quality_file`` and
    holds the base64-decoded ``picture`` field. On an HTTP error the failure
    is printed and nothing is written.

    Args:
        task_id: id of the task whose file record is requested.
    """
    api = 'http://52.193.30.103/argus/api/task_file/myte/{0}'.format(task_id)
    r = requests.get(api, auth=auth, headers=headers)

    if not r.ok:
        print('Download image {id} failed.'.format(id=task_id))
        print('Call api: {url}'.format(url=api))
        print('[Err Reason]:{err}'.format(err=r.reason))
        print('[Err text]:{err}'.format(err=r.text))
        return

    # Only the first record of the response is used.
    record = r.json()[0]
    file_name = os.path.basename(record.get('quality_file'))
    target = os.path.join(path, file_name)

    with open(target, 'wb') as f:
        encoded_picture = record.get('picture').encode("utf8")
        f.write(base64.b64decode(encoded_picture))
64 |
65 |
if __name__ == "__main__":
    # ``auth``, ``headers``, ``index`` and ``path`` are module-level names
    # that get_tasks_of_page() and get_image() read.
    args = init_args()
    auth = HTTPBasicAuth(args.user, args.password)
    headers = {'Content-Type': "application/json", 'Cache-Control': "no-cache"}

    # One timestamp per run, shared by the folders of all requested dates.
    time_flag = int(time.time())
    root_dir = os.path.abspath('.')

    for date in args.date.split(','):
        index = 0
        path = '{root}/invoice_type_{type}-{time}/{date}'.format(
            root=root_dir, type=args.invoice_type_id, time=time_flag, date=date
        )
        os.makedirs(path)
        print('\ndownload path: {0}'.format(path))

        # Follow the returned next-page URL until none is left.
        next_url = get_tasks_of_page(date=date, invoice_type=args.invoice_type_id)
        while next_url:
            next_url = get_tasks_of_page(next_url=next_url)
88 |
--------------------------------------------------------------------------------
/deprecated/model_cutie_unet8.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-12
3 | # xiaohui.zhao@accenture.com
4 | import tensorflow as tf
5 | from model_cutie import CUTIE
6 |
class CUTIEUNet(CUTIE):
    """Deprecated CUTIE variant with three pooling steps and skip connections.

    The encoder interleaves ``conv`` stacks with three ``max_pool`` layers;
    the decoder applies ``up_conv`` three times and concatenates each result
    with an earlier encoder layer before further ``conv`` layers. A 1x1
    ``conv`` (``cls_logits``) and ``softmax`` close the net.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the graph inputs, copy the hyper-parameters and build the net.

        Args:
            num_vocabs: vocabulary size handed to the ``embed`` layer.
            num_classes: output size of ``cls_logits``.
            params: options object. ``embedding_size`` is required;
                ``weight_decay``, ``hard_negative_ratio`` and ``batch_size``
                fall back to defaults when missing.
            trainable: stored as ``self.trainable``.
        """
        self.name = "CUTIE_unet_8x"  # 8x down sampling

        self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')
        self.layers = {'data': self.data, 'gt_classes': self.gt_classes}
        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0.0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.setup()

    def setup(self):
        """Register the embedding, encoder, decoder and classification layers."""

        def conv_chain(net, specs):
            # Append one conv layer per (name, third conv argument) pair.
            for layer_name, num_out in specs:
                net = net.conv(3, 5, num_out, 1, 1, name=layer_name)
            return net

        # input
        self.feed('data').embed(self.num_vocabs, self.embedding_size, name='embedding')

        # encoder
        net = self.feed('embedding')
        net = conv_chain(net, (('encoder1_1', 64), ('encoder1_2', 128)))
        net = net.max_pool(2, 2, 2, 2, name='pool1')
        net = conv_chain(net, (('encoder2_1', 128), ('encoder2_2', 256)))
        net = net.max_pool(2, 2, 2, 2, name='pool2')
        net = conv_chain(
            net, (('encoder3_1', 256), ('encoder3_2', 512), ('encoder3_3', 512)))
        net = net.max_pool(2, 2, 2, 2, name='pool3')
        conv_chain(net, (
            ('encoder4_1', 512),
            ('encoder4_2', 512),
            ('encoder4_3', 512),
            ('encoder4_4', 512),
        ))

        # decoder: each up_conv output is concatenated with an encoder layer
        self.feed('encoder4_4').up_conv(3, 5, 512, 1, 1, name='up1')

        net = self.feed('up1', 'encoder3_3').concat(3, name='concat1')
        net = conv_chain(net, (
            ('decoder1_1', 256),
            ('decoder1_2', 256),
            ('decoder1_3', 256),
            ('decoder1_4', 256),
        ))
        net.up_conv(3, 5, 256, 1, 1, name='up2')

        net = self.feed('up2', 'encoder2_2').concat(3, name='concat2')
        net = conv_chain(
            net, (('decoder2_1', 128), ('decoder2_2', 128), ('decoder2_3', 128)))
        net.up_conv(3, 5, 128, 1, 1, name='up3')

        net = self.feed('up3', 'encoder1_2').concat(3, name='concat3')
        conv_chain(net, (('decoder3_1', 64), ('decoder3_2', 64)))

        # classification
        logits = self.feed('decoder3_2').conv(1, 1, self.num_classes, 1, 1, name='cls_logits')
        logits.softmax(name='softmax')
--------------------------------------------------------------------------------
/deprecated/model_cutie_fpn8.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-12
3 | # xiaohui.zhao@accenture.com
4 | import tensorflow as tf
5 | from model_cutie import CUTIE
6 |
class CUTIEFPN(CUTIE):
    """CUTIE variant with an FPN-like encoder/decoder (8x down sampling).

    The encoder is a sequence of conv layers interleaved with three max_pool
    layers; the decoder has three up_conv stages, each merged with the last
    encoder layer of the matching stage.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the input placeholders, read hyper-parameters and build the net.

        `params` must provide `embedding_size`; `weight_decay`,
        `hard_negative_ratio` and `batch_size` are optional and fall back to 0.
        """
        self.name = "CUTIE_fpn_8x"  # 8x down sampling

        self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')
        self.layers = {'data': self.data, 'gt_classes': self.gt_classes}
        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0.0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.setup()

    def setup(self):
        """Declare the embedding, encoder, decoder and classification layers."""
        # input
        (self.feed('data')
             .embed(self.num_vocabs, self.embedding_size, name='embedding'))

        # encoder
        (self.feed('embedding')
             .conv(3, 5, 64, 1, 1, name='encoder1_1')
             .conv(3, 5, 128, 1, 1, name='encoder1_2')
             .max_pool(2, 2, 2, 2, name='pool1')
             .conv(3, 5, 128, 1, 1, name='encoder2_1')
             .conv(3, 5, 256, 1, 1, name='encoder2_2')
             .max_pool(2, 2, 2, 2, name='pool2')
             .conv(3, 5, 256, 1, 1, name='encoder3_1')
             .conv(3, 5, 512, 1, 1, name='encoder3_2')
             .conv(3, 5, 512, 1, 1, name='encoder3_3')
             .max_pool(2, 2, 2, 2, name='pool3')
             .conv(3, 5, 512, 1, 1, name='encoder4_1')
             .conv(3, 5, 512, 1, 1, name='encoder4_2')
             .conv(3, 5, 512, 1, 1, name='encoder4_3')
             .conv(3, 5, 512, 1, 1, name='encoder4_4'))

        # decoder
        # The skip connections must name layers that the encoder above actually
        # defines: stage 3 ends at 'encoder3_3' and stage 2 at 'encoder2_2'
        # (the previous 'encoder3_4' / 'encoder2_3' names were never created).
        (self.feed('encoder4_4')
             .up_conv(3, 5, 512, 1, 1, name='up1'))
        (self.feed('up1', 'encoder3_3')
             .concat(3, name='concat1')  # 1x1 before the add operation for FPN
             .conv(3, 5, 256, 1, 1, name='decoder1_1')
             .conv(3, 5, 256, 1, 1, name='decoder1_2')
             .conv(3, 5, 256, 1, 1, name='decoder1_3')
             .conv(3, 5, 256, 1, 1, name='decoder1_4')
             .up_conv(3, 5, 256, 1, 1, name='up2'))
        (self.feed('up2', 'encoder2_2')
             .concat(3, name='concat2')
             .conv(3, 5, 128, 1, 1, name='decoder2_1')
             .conv(3, 5, 128, 1, 1, name='decoder2_2')
             .conv(3, 5, 128, 1, 1, name='decoder2_3')
             .up_conv(3, 5, 128, 1, 1, name='up3'))
        (self.feed('up3', 'encoder1_2')
             .concat(3, name='concat3')
             .conv(3, 5, 64, 1, 1, name='decoder3_1')
             .conv(3, 5, 64, 1, 1, name='decoder3_2'))

        # classification
        (self.feed('decoder3_2')
             .conv(1, 1, self.num_classes, 1, 1, name='cls_logits')
             .softmax(name='softmax'))
--------------------------------------------------------------------------------
/deprecated/model_cutie_res.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-12
3 | # xiaohui.zhao@accenture.com
4 | import tensorflow as tf
5 | from model_cutie import CUTIE
6 |
class CUTIERes(CUTIE):
    """CUTIE variant with skip-concatenated encoder/decoder (8x down sampling).

    Three max_pool layers split the encoder into four stages; the decoder has
    three up_conv stages, each concatenated with the last encoder layer of the
    matching stage.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the input placeholders, read hyper-parameters and build the net.

        `params` must provide `embedding_size`. `use_ghm`, `weight_decay`,
        `hard_negative_ratio` and `batch_size` are optional; a truthy `use_ghm`
        switches the activation of the logits layer from 'relu' to 'sigmoid'.
        """
        self.name = "CUTIE_residual_8x"  # 8x down sampling

        # graph inputs
        self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')

        # GHM switches: a missing `use_ghm` behaves like 0
        ghm_flag = getattr(params, 'use_ghm', 0)
        self.use_ghm = tf.equal(1, ghm_flag)
        self.activation = 'sigmoid' if ghm_flag else 'relu'
        self.ghm_weights = tf.placeholder(
            tf.float32, shape=[None, None, None, num_classes], name='ghm_weights')
        self.layers = {
            'data': self.data,
            'gt_classes': self.gt_classes,
            'ghm_weights': self.ghm_weights,
        }

        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        # hyper-parameters (optional ones default to zero)
        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0.0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.setup()

    def setup(self):
        """Declare the embedding, encoder, decoder and classification layers."""

        def conv_pair(net, prefix, first, second):
            # Two chained conv(3, 5, ., 1, 1) layers named '<prefix>_1' and '<prefix>_2'.
            net = net.conv(3, 5, first, 1, 1, name=prefix + '_1')
            return net.conv(3, 5, second, 1, 1, name=prefix + '_2')

        # input
        self.feed('data').embed(self.num_vocabs, self.embedding_size, name='embedding')

        # encoder
        net = conv_pair(self.feed('embedding'), 'encoder1', 64, 128)
        net = net.max_pool(2, 2, 2, 2, name='pool1')
        net = conv_pair(net, 'encoder2', 128, 256)
        net = net.max_pool(2, 2, 2, 2, name='pool2')
        net = conv_pair(net, 'encoder3', 256, 512)
        net = net.max_pool(2, 2, 2, 2, name='pool3')
        conv_pair(net, 'encoder4', 512, 512)

        # decoder
        self.feed('encoder4_2').up_conv(3, 5, 512, 1, 1, name='up1')

        net = self.feed('up1', 'encoder3_2').concat(3, name='concat1')
        conv_pair(net, 'decoder1', 256, 256).up_conv(3, 5, 256, 1, 1, name='up2')

        net = self.feed('up2', 'encoder2_2').concat(3, name='concat2')
        conv_pair(net, 'decoder2', 128, 128).up_conv(3, 5, 128, 1, 1, name='up3')

        net = self.feed('up3', 'encoder1_2').concat(3, name='concat3')
        conv_pair(net, 'decoder3', 64, 64)

        # classification (sigmoid for ghm)
        logits = self.feed('decoder3_2').conv(
            1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits')
        logits.softmax(name='softmax')
--------------------------------------------------------------------------------
/deprecated/model_cutie_res16.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-12
3 | # xiaohui.zhao@accenture.com
4 | import tensorflow as tf
5 | from model_cutie import CUTIE
6 |
class CUTIERes(CUTIE):
    """CUTIE variant with skip-concatenated encoder/decoder (16x down sampling).

    Four max_pool layers split the encoder into five stages; the decoder has
    four up_conv stages, each concatenated with the last encoder layer of the
    matching stage.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the input placeholders, read hyper-parameters and build the net.

        `params` must provide `embedding_size`; `weight_decay`,
        `hard_negative_ratio` and `batch_size` are optional and fall back to 0.
        """
        self.name = "CUTIE_residual_16x"  # 16x down sampling

        # graph inputs
        self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')
        self.layers = {'data': self.data, 'gt_classes': self.gt_classes}

        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        # hyper-parameters (optional ones default to zero)
        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0.0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.setup()

    def setup(self):
        """Declare the embedding, encoder, decoder and classification layers."""

        def conv_pair(net, prefix, first, second):
            # Two chained conv(3, 5, ., 1, 1) layers named '<prefix>_1' and '<prefix>_2'.
            net = net.conv(3, 5, first, 1, 1, name=prefix + '_1')
            return net.conv(3, 5, second, 1, 1, name=prefix + '_2')

        # input
        self.feed('data').embed(self.num_vocabs, self.embedding_size, name='embedding')

        # encoder
        net = conv_pair(self.feed('embedding'), 'encoder1', 64, 128)
        net = net.max_pool(2, 2, 2, 2, name='pool1')
        net = conv_pair(net, 'encoder2', 128, 256)
        net = net.max_pool(2, 2, 2, 2, name='pool2')
        net = conv_pair(net, 'encoder3', 256, 512)
        net = net.max_pool(2, 2, 2, 2, name='pool3')
        net = conv_pair(net, 'encoder4', 512, 512)
        net = net.max_pool(2, 2, 2, 2, name='pool4')
        conv_pair(net, 'encoder5', 1024, 1024)

        # decoder
        self.feed('encoder5_2').up_conv(3, 5, 1024, 1, 1, name='up1')

        net = self.feed('up1', 'encoder4_2').concat(3, name='concat1')
        conv_pair(net, 'decoder1', 512, 512).up_conv(3, 5, 512, 1, 1, name='up2')

        net = self.feed('up2', 'encoder3_2').concat(3, name='concat2')
        conv_pair(net, 'decoder2', 256, 256).up_conv(3, 5, 256, 1, 1, name='up3')

        net = self.feed('up3', 'encoder2_2').concat(3, name='concat3')
        conv_pair(net, 'decoder3', 128, 128).up_conv(3, 5, 128, 1, 1, name='up4')

        net = self.feed('up4', 'encoder1_2').concat(3, name='concat4')
        conv_pair(net, 'decoder4', 64, 64)

        # classification
        logits = self.feed('decoder4_2').conv(1, 1, self.num_classes, 1, 1, name='cls_logits')
        logits.softmax(name='softmax')
--------------------------------------------------------------------------------
/deprecated/model_cutie_res_att.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-12
3 | # xiaohui.zhao@accenture.com
4 | import tensorflow as tf
5 | from model_cutie import CUTIE
6 |
class CUTIERes(CUTIE):
    """CUTIE encoder/decoder variant with attention merges (8x down sampling).

    Same layer layout as the concat-based residual model, except that every
    decoder stage merges the up_conv output with the matching encoder layer
    through an `attention` layer.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the input placeholders, read hyper-parameters and build the net.

        `params` must provide `embedding_size`. `use_ghm`, `weight_decay`,
        `hard_negative_ratio` and `batch_size` are optional; a truthy `use_ghm`
        switches the activation of the logits layer from 'relu' to 'sigmoid'.
        """
        self.name = "CUTIE_residual_attention_8x"  # 8x down sampling

        # graph inputs
        self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')

        # GHM switches: a missing `use_ghm` behaves like 0
        ghm_flag = getattr(params, 'use_ghm', 0)
        self.use_ghm = tf.equal(1, ghm_flag)
        self.activation = 'sigmoid' if ghm_flag else 'relu'
        self.ghm_weights = tf.placeholder(
            tf.float32, shape=[None, None, None, num_classes], name='ghm_weights')
        self.layers = {
            'data': self.data,
            'gt_classes': self.gt_classes,
            'ghm_weights': self.ghm_weights,
        }

        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        # hyper-parameters (optional ones default to zero)
        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0.0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.setup()

    def setup(self):
        """Declare the embedding, encoder, decoder and classification layers."""

        def conv_pair(net, prefix, first, second):
            # Two chained conv(3, 5, ., 1, 1) layers named '<prefix>_1' and '<prefix>_2'.
            net = net.conv(3, 5, first, 1, 1, name=prefix + '_1')
            return net.conv(3, 5, second, 1, 1, name=prefix + '_2')

        # input
        self.feed('data').embed(self.num_vocabs, self.embedding_size, name='embedding')

        # encoder
        net = conv_pair(self.feed('embedding'), 'encoder1', 64, 128)
        net = net.max_pool(2, 2, 2, 2, name='pool1')
        net = conv_pair(net, 'encoder2', 128, 256)
        net = net.max_pool(2, 2, 2, 2, name='pool2')
        net = conv_pair(net, 'encoder3', 256, 512)
        net = net.max_pool(2, 2, 2, 2, name='pool3')
        conv_pair(net, 'encoder4', 512, 512)

        # decoder
        self.feed('encoder4_2').up_conv(3, 5, 512, 1, 1, name='up1')

        net = self.feed('up1', 'encoder3_2').attention(1, name='attention1')
        conv_pair(net, 'decoder1', 256, 256).up_conv(3, 5, 256, 1, 1, name='up2')

        net = self.feed('up2', 'encoder2_2').attention(1, name='attention2')
        conv_pair(net, 'decoder2', 128, 128).up_conv(3, 5, 128, 1, 1, name='up3')

        net = self.feed('up3', 'encoder1_2').attention(1, name='attention3')
        conv_pair(net, 'decoder3', 64, 64)

        # classification (sigmoid for ghm)
        logits = self.feed('decoder3_2').conv(
            1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits')
        logits.softmax(name='softmax')
--------------------------------------------------------------------------------
/model_cutie_res_bert.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-12
3 | # xiaohui.zhao@outlook.com
4 | import tensorflow as tf
5 | from model_cutie import CUTIE
6 |
class CUTIERes(CUTIE):
    """CUTIE encoder/decoder variant fed by a BERT embedding (8x down sampling).

    The input grid goes through a non-trainable `bert_embed` layer of size 768
    instead of the plain `embed` layer; the encoder/decoder uses concat merges.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the input placeholders, read hyper-parameters and build the net.

        `params` must provide `embedding_size`. `use_ghm`, `weight_decay`,
        `hard_negative_ratio` and `batch_size` are optional; a truthy `use_ghm`
        switches the activation of the logits layer from 'relu' to 'sigmoid'.
        """
        self.name = "CUTIE_residual_bert_8x"  # 8x down sampling

        # graph inputs
        self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')

        # GHM switches: a missing `use_ghm` behaves like 0
        ghm_flag = getattr(params, 'use_ghm', 0)
        self.use_ghm = tf.equal(1, ghm_flag)
        self.activation = 'sigmoid' if ghm_flag else 'relu'
        self.ghm_weights = tf.placeholder(
            tf.float32, shape=[None, None, None, num_classes], name='ghm_weights')
        self.layers = {
            'data': self.data,
            'gt_classes': self.gt_classes,
            'ghm_weights': self.ghm_weights,
        }

        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        # hyper-parameters (optional ones default to zero)
        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0.0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.embedding_table = None
        self.setup()

    def setup(self):
        """Declare the embedding, encoder, decoder and classification layers."""

        def conv_pair(net, prefix, first, second):
            # Two chained conv(3, 5, ., 1, 1) layers named '<prefix>_1' and '<prefix>_2'.
            net = net.conv(3, 5, first, 1, 1, name=prefix + '_1')
            return net.conv(3, 5, second, 1, 1, name=prefix + '_2')

        # input
        self.feed('data').bert_embed(self.num_vocabs, 768, name='embeddings', trainable=False)

        # encoder
        net = conv_pair(self.feed('embeddings'), 'encoder1', 64, 128)
        net = net.max_pool(2, 2, 2, 2, name='pool1')
        net = conv_pair(net, 'encoder2', 128, 256)
        net = net.max_pool(2, 2, 2, 2, name='pool2')
        net = conv_pair(net, 'encoder3', 256, 512)
        net = net.max_pool(2, 2, 2, 2, name='pool3')
        conv_pair(net, 'encoder4', 512, 512)

        # decoder
        self.feed('encoder4_2').up_conv(3, 5, 512, 1, 1, name='up1')

        net = self.feed('up1', 'encoder3_2').concat(3, name='concat1')
        conv_pair(net, 'decoder1', 256, 256).up_conv(3, 5, 256, 1, 1, name='up2')

        net = self.feed('up2', 'encoder2_2').concat(3, name='concat2')
        conv_pair(net, 'decoder2', 128, 128).up_conv(3, 5, 128, 1, 1, name='up3')

        net = self.feed('up3', 'encoder1_2').concat(3, name='concat3')
        conv_pair(net, 'decoder3', 64, 64)

        # classification (sigmoid for ghm)
        logits = self.feed('decoder3_2').conv(
            1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits')
        logits.softmax(name='softmax')
--------------------------------------------------------------------------------
/deprecated/model_cutie_res_att_bert.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-12
3 | # xiaohui.zhao@accenture.com
4 | import tensorflow as tf
5 | from model_cutie import CUTIE
6 |
class CUTIERes(CUTIE):
    """CUTIE variant with BERT embedding and attention merges (8x down sampling).

    The input grid goes through a non-trainable `bert_embed` layer of size 768,
    and every decoder stage merges the up_conv output with the matching encoder
    layer through an `attention` layer.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the input placeholders, read hyper-parameters and build the net.

        `params` must provide `embedding_size`. `use_ghm`, `weight_decay`,
        `hard_negative_ratio` and `batch_size` are optional; a truthy `use_ghm`
        switches the activation of the logits layer from 'relu' to 'sigmoid'.
        """
        self.name = "CUTIE_residual_attention_bert_8x"  # 8x down sampling

        # graph inputs
        self.data = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='grid_table')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')

        # GHM switches: a missing `use_ghm` behaves like 0
        ghm_flag = getattr(params, 'use_ghm', 0)
        self.use_ghm = tf.equal(1, ghm_flag)
        self.activation = 'sigmoid' if ghm_flag else 'relu'
        self.ghm_weights = tf.placeholder(
            tf.float32, shape=[None, None, None, num_classes], name='ghm_weights')
        self.layers = {
            'data': self.data,
            'gt_classes': self.gt_classes,
            'ghm_weights': self.ghm_weights,
        }

        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        # hyper-parameters (optional ones default to zero)
        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0.0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.embedding_table = None
        self.setup()

    def setup(self):
        """Declare the embedding, encoder, decoder and classification layers."""

        def conv_pair(net, prefix, first, second):
            # Two chained conv(3, 5, ., 1, 1) layers named '<prefix>_1' and '<prefix>_2'.
            net = net.conv(3, 5, first, 1, 1, name=prefix + '_1')
            return net.conv(3, 5, second, 1, 1, name=prefix + '_2')

        # input
        self.feed('data').bert_embed(self.num_vocabs, 768, name='embeddings', trainable=False)

        # encoder
        net = conv_pair(self.feed('embeddings'), 'encoder1', 64, 128)
        net = net.max_pool(2, 2, 2, 2, name='pool1')
        net = conv_pair(net, 'encoder2', 128, 256)
        net = net.max_pool(2, 2, 2, 2, name='pool2')
        net = conv_pair(net, 'encoder3', 256, 512)
        net = net.max_pool(2, 2, 2, 2, name='pool3')
        conv_pair(net, 'encoder4', 512, 512)

        # decoder
        self.feed('encoder4_2').up_conv(3, 5, 512, 1, 1, name='up1')

        net = self.feed('up1', 'encoder3_2').attention(1, name='attention1')
        conv_pair(net, 'decoder1', 256, 256).up_conv(3, 5, 256, 1, 1, name='up2')

        net = self.feed('up2', 'encoder2_2').attention(1, name='attention2')
        conv_pair(net, 'decoder2', 128, 128).up_conv(3, 5, 128, 1, 1, name='up3')

        net = self.feed('up3', 'encoder1_2').attention(1, name='attention3')
        conv_pair(net, 'decoder3', 64, 64)

        # classification (sigmoid for ghm)
        logits = self.feed('decoder3_2').conv(
            1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits')
        logits.softmax(name='softmax')
--------------------------------------------------------------------------------
/model_cutie_aspp.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2019-03
3 | # xiaohui.zhao@outlook.com
4 | import tensorflow as tf
5 | from model_cutie import CUTIE
6 |
class CUTIERes(CUTIE):
    """CUTIE variant built from conv / dilate_conv layers plus an ASPP module.

    The encoder is followed by three parallel dilate_conv branches and a
    global_pool branch, which are concatenated, fused by a 1x1 conv and then
    combined with the first encoder layer before classification.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the input placeholders, read hyper-parameters and build the net.

        `params` must provide `embedding_size`. `use_ghm`,
        `data_augmentation_dropout` (default 1), `weight_decay`,
        `hard_negative_ratio` and `batch_size` are optional; a truthy `use_ghm`
        switches the activation of the logits layer from 'relu' to 'sigmoid'.
        """
        self.name = "CUTIE_atrousSPP"  #

        # graph inputs
        self.data_grid = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='data_grid')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')
        # the next two placeholders are not used in CUTIEv1
        self.data_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='data_image')
        self.ps_1d_indices = tf.placeholder(tf.int32, shape=[None, None], name='ps_1d_indices')

        # GHM switches: a missing `use_ghm` behaves like 0
        ghm_flag = getattr(params, 'use_ghm', 0)
        self.use_ghm = tf.equal(1, ghm_flag)
        self.activation = 'sigmoid' if ghm_flag else 'relu'
        self.dropout = getattr(params, 'data_augmentation_dropout', 1)
        self.ghm_weights = tf.placeholder(
            tf.float32, shape=[None, None, None, num_classes], name='ghm_weights')
        self.layers = {
            'data_grid': self.data_grid,
            'gt_classes': self.gt_classes,
            'ghm_weights': self.ghm_weights,
        }

        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        # hyper-parameters (optional ones default to zero)
        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0.0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.setup()

    def setup(self):
        """Declare the embedding, encoder, ASPP, fusion and classification layers."""
        # input
        self.feed('data_grid').embed(
            self.num_vocabs, self.embedding_size, name='embedding', dropout=self.dropout)

        # encoder: encoder1_1..encoder1_4 are plain convs,
        # encoder1_5..encoder1_8 are dilate_convs called with 2, 4, 8, 16
        net = self.feed('embedding')
        for idx in range(1, 5):
            net = net.conv(3, 5, 256, 1, 1, name='encoder1_%d' % idx)
        for idx, rate in enumerate((2, 4, 8, 16), start=5):
            net = net.dilate_conv(3, 5, 256, 1, 1, rate, name='encoder1_%d' % idx)

        # Atrous Spatial Pyramid Pooling module
        #(self.feed('encoder1_8')
        # .conv(1, 1, 256, 1, 1, name='aspp_0'))
        for idx, rate in enumerate((4, 8, 16), start=1):
            self.feed('encoder1_8').dilate_conv(3, 5, 256, 1, 1, rate, name='aspp_%d' % idx)
        self.feed('encoder1_8').global_pool(name='aspp_4')
        merged = self.feed('aspp_1', 'aspp_2', 'aspp_3', 'aspp_4').concat(3, name='aspp_concat')
        merged.conv(1, 1, 256, 1, 1, name='aspp_1x1')

        # combine low level features
        fused = self.feed('encoder1_1', 'aspp_1x1').concat(3, name='concat1')
        fused.conv(3, 5, 64, 1, 1, name='decoder1_1')

        # classification (sigmoid for ghm)
        logits = self.feed('decoder1_1').conv(
            1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits')
        logits.softmax(name='softmax')
--------------------------------------------------------------------------------
/model_cutie2_fpn.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2019-04
3 | # xiaohui.zhao@outlook.com
4 | import tensorflow as tf
5 | from model_cutie2 import CUTIE2 as CUTIE
6 |
class CUTIE2(CUTIE):
    """CUTIE2 variant that fuses a strided-conv image encoder with the grid embedding.

    The image branch is sampled by `positional_sampling`, concatenated with the
    grid embedding, fused by a 1x1 conv and passed through a conv / dilate_conv
    encoder followed by an ASPP module and the classification layers.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the input placeholders, read hyper-parameters and build the net.

        `params` must provide `embedding_size`. `use_ghm`,
        `data_augmentation_dropout` (default 1), `weight_decay`,
        `hard_negative_ratio` and `batch_size` are optional; a truthy `use_ghm`
        switches the activation of the logits layer from 'relu' to 'sigmoid'.
        """
        # NOTE(review): this name is shared with the other CUTIE2 model files in
        # the repository (looks like a copy-paste) -- confirm whether this
        # variant should carry its own name before changing it.
        self.name = "CUTIE2_dilate"  #

        self.data_grid = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='data_grid')
        self.data_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='data_image')
        self.ps_1d_indices = tf.placeholder(tf.int32, shape=[None, None], name='ps_1d_indices')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')

        self.use_ghm = tf.equal(1, getattr(params, 'use_ghm', 0))
        self.activation = 'sigmoid' if getattr(params, 'use_ghm', 0) else 'relu'
        self.dropout = getattr(params, 'data_augmentation_dropout', 1)
        self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights')
        self.layers = {'data_grid': self.data_grid, 'data_image': self.data_image,
                       'ps_1d_indices': self.ps_1d_indices,
                       'gt_classes': self.gt_classes, 'ghm_weights': self.ghm_weights}

        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0.0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.setup()

    def setup(self):
        """Declare the grid/image branches, encoder, ASPP and classification layers."""
        ## grid
        (self.feed('data_grid')
             .embed(self.num_vocabs, self.embedding_size, name='embedding', dropout=self.dropout))

        ## image
        # Every layer needs its own name: the previous definition reused
        # 'image_encoder1_1' / 'image_encoder1_2' for several layers, and never
        # created the 'image_featuremap' layer that is fed below. The last
        # layer of the branch now carries that name.
        (self.feed('data_image')
             .conv(3, 3, 32, 1, 1, name='image_encoder1_1')
             .conv(3, 3, 64, 2, 2, name='image_encoder1_2')
             .conv(3, 3, 64, 1, 1, name='image_encoder2_1')
             .conv(3, 3, 128, 2, 2, name='image_encoder2_2')
             .conv(3, 3, 128, 1, 1, name='image_encoder3_1')
             .conv(3, 3, 256, 2, 2, name='image_featuremap'))

        # feature map positional mapping
        # NOTE(review): the sibling dilate model passes 128 here for a
        # 128-channel feature map, while this branch ends with 256 channels but
        # passes 32 -- confirm the intended value against positional_sampling.
        (self.feed('image_featuremap', 'ps_1d_indices', 'data_grid')
             .positional_sampling(32, name='positional_sampling'))

        ## concatenate image with grid
        (self.feed('positional_sampling', 'embedding')
             .concat(3, name='concat')
             .conv(1, 1, 256, 1, 1, name='feature_fuser'))

        # encoder
        (self.feed('feature_fuser')
             .conv(3, 5, 256, 1, 1, name='encoder1_1')
             .conv(3, 5, 256, 1, 1, name='encoder1_2')
             .conv(3, 5, 256, 1, 1, name='encoder1_3')
             .conv(3, 5, 256, 1, 1, name='encoder1_4')
             .dilate_conv(3, 5, 256, 1, 1, 2, name='encoder1_5')
             .dilate_conv(3, 5, 256, 1, 1, 4, name='encoder1_6')
             .dilate_conv(3, 5, 256, 1, 1, 8, name='encoder1_7')
             .dilate_conv(3, 5, 256, 1, 1, 16, name='encoder1_8'))

        # Atrous Spatial Pyramid Pooling module
        #(self.feed('encoder1_8')
        # .conv(1, 1, 256, 1, 1, name='aspp_0'))
        (self.feed('encoder1_8')
             .dilate_conv(3, 5, 256, 1, 1, 4, name='aspp_1'))
        (self.feed('encoder1_8')
             .dilate_conv(3, 5, 256, 1, 1, 8, name='aspp_2'))
        (self.feed('encoder1_8')
             .dilate_conv(3, 5, 256, 1, 1, 16, name='aspp_3'))
        (self.feed('encoder1_8')
             .global_pool(name='aspp_4'))
        (self.feed('aspp_1', 'aspp_2', 'aspp_3', 'aspp_4')
             .concat(3, name='aspp_concat')
             .conv(1, 1, 256, 1, 1, name='aspp_1x1'))

        # combine low level features
        (self.feed('encoder1_1', 'aspp_1x1')
             .concat(3, name='concat1')
             .conv(3, 5, 64, 1, 1, name='decoder1_1'))

        # classification
        (self.feed('decoder1_1')
             .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits')  # sigmoid for ghm
             .softmax(name='softmax'))
--------------------------------------------------------------------------------
/model_cutie2_dilate.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2019-04
3 | # xiaohui.zhao@outlook.com
4 | import tensorflow as tf
5 | from model_cutie2 import CUTIE2 as CUTIE
6 |
class CUTIE2(CUTIE):
    """CUTIE2 variant that fuses a dilated-conv image encoder with the grid embedding.

    The image branch is sampled by `positional_sampling`, concatenated with the
    grid embedding, fused by a 1x1 conv and passed through a conv / dilate_conv
    encoder followed by an ASPP module and the classification layers.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the input placeholders, read hyper-parameters and build the net.

        `params` must provide `embedding_size`. `use_ghm`,
        `data_augmentation_dropout` (default 1), `weight_decay`,
        `hard_negative_ratio` and `batch_size` are optional; a truthy `use_ghm`
        switches the activation of the logits layer from 'relu' to 'sigmoid'.
        """
        self.name = "CUTIE2_dilate"  #

        # graph inputs
        self.data_grid = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='data_grid')
        self.data_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='data_image')
        self.ps_1d_indices = tf.placeholder(tf.int32, shape=[None, None], name='ps_1d_indices')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')

        # GHM switches: a missing `use_ghm` behaves like 0
        ghm_flag = getattr(params, 'use_ghm', 0)
        self.use_ghm = tf.equal(1, ghm_flag)
        self.activation = 'sigmoid' if ghm_flag else 'relu'
        self.dropout = getattr(params, 'data_augmentation_dropout', 1)
        self.ghm_weights = tf.placeholder(
            tf.float32, shape=[None, None, None, num_classes], name='ghm_weights')
        self.layers = {
            'data_grid': self.data_grid,
            'data_image': self.data_image,
            'ps_1d_indices': self.ps_1d_indices,
            'gt_classes': self.gt_classes,
            'ghm_weights': self.ghm_weights,
        }

        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        # hyper-parameters (optional ones default to zero)
        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0.0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.setup()

    def setup(self):
        """Declare the grid/image branches, encoder, ASPP and classification layers."""
        dilate_args = (2, 4, 8, 16)  # last positional argument of the chained dilate_convs

        ## grid
        self.feed('data_grid').embed(
            self.num_vocabs, self.embedding_size, name='embedding', dropout=self.dropout)

        ## image: two plain convs, then dilate_convs named image_encoder1_5..image_encoder1_8
        img = self.feed('data_image')
        img = img.conv(3, 3, 128, 1, 1, name='image_encoder1_1')
        img = img.conv(3, 3, 128, 1, 1, name='image_encoder1_2')
        for idx, rate in enumerate(dilate_args, start=5):
            img = img.dilate_conv(3, 3, 128, 1, 1, rate, name='image_encoder1_%d' % idx)

        # feature map positional mapping
        sampler = self.feed('image_encoder1_8', 'ps_1d_indices', 'data_grid')
        sampler.positional_sampling(128, name='positional_sampling')

        ## concate image with grid
        joined = self.feed('positional_sampling', 'embedding').concat(3, name='concat')
        joined.conv(1, 1, 256, 1, 1, name='feature_fuser')

        # encoder: encoder1_1..encoder1_4 are plain convs,
        # encoder1_5..encoder1_8 are dilate_convs
        net = self.feed('feature_fuser')
        for idx in range(1, 5):
            net = net.conv(3, 5, 256, 1, 1, name='encoder1_%d' % idx)
        for idx, rate in enumerate(dilate_args, start=5):
            net = net.dilate_conv(3, 5, 256, 1, 1, rate, name='encoder1_%d' % idx)

        # Atrous Spatial Pyramid Pooling module
        #(self.feed('encoder1_8')
        # .conv(1, 1, 256, 1, 1, name='aspp_0'))
        for idx, rate in enumerate((4, 8, 16), start=1):
            self.feed('encoder1_8').dilate_conv(3, 5, 256, 1, 1, rate, name='aspp_%d' % idx)
        self.feed('encoder1_8').global_pool(name='aspp_4')
        merged = self.feed('aspp_1', 'aspp_2', 'aspp_3', 'aspp_4').concat(3, name='aspp_concat')
        merged.conv(1, 1, 256, 1, 1, name='aspp_1x1')

        # combine low level features
        fused = self.feed('encoder1_1', 'aspp_1x1').concat(3, name='concat1')
        fused.conv(3, 5, 64, 1, 1, name='decoder1_1')

        # classification (sigmoid for ghm)
        logits = self.feed('decoder1_1').conv(
            1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits')
        logits.softmax(name='softmax')
--------------------------------------------------------------------------------
/model_cutie2_aspp.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2019-04
3 | # xiaohui.zhao@outlook.com
4 | import tensorflow as tf
5 | from model_cutie2 import CUTIE2 as CUTIE
6 |
class CUTIE2(CUTIE):
    """CUTIE2 variant that adds an ASPP-style head to the image branch and to the grid encoder.

    The network consumes a word-index grid (`data_grid`), an image
    (`data_image`) and sampling indices (`ps_1d_indices`); the layers are
    declared in `setup` and end in `cls_logits` / `softmax`.
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the input placeholders, read hyper-parameters and build the graph.

        Args:
            num_vocabs: vocabulary size handed to the embedding layer.
            num_classes: number of classes; sizes the last axis of
                `ghm_weights` and is passed to the final 1x1 conv.
            params: configuration object. `embedding_size` is required; all
                other attributes read here are optional and fall back to a
                default when missing.
            trainable: stored on the instance as `self.trainable`.
        """
        # NOTE(review): the string says "dilate" although this is the ASPP
        # model file -- confirm nothing keys on it before renaming.
        self.name = "CUTIE2_dilate" #

        # graph inputs
        self.data_grid = tf.placeholder(
            tf.int32, shape=[None, None, None, 1], name='data_grid')
        self.data_image = tf.placeholder(
            tf.float32, shape=[None, None, None, 3], name='data_image')
        self.ps_1d_indices = tf.placeholder(
            tf.int32, shape=[None, None], name='ps_1d_indices')
        self.gt_classes = tf.placeholder(
            tf.int32, shape=[None, None, None], name='gt_classes')

        # a missing `use_ghm` behaves like use_ghm == 0
        ghm_flag = getattr(params, 'use_ghm', 0)
        self.use_ghm = tf.equal(1, ghm_flag)  # params.use_ghm
        self.activation = 'sigmoid' if ghm_flag else 'relu'
        self.dropout = getattr(params, 'data_augmentation_dropout', 1)
        self.ghm_weights = tf.placeholder(
            tf.float32, shape=[None, None, None, num_classes], name='ghm_weights')

        # layers that `feed` / `get_output` can look up by name
        self.layers = {
            'data_grid': self.data_grid,
            'data_image': self.data_image,
            'ps_1d_indices': self.ps_1d_indices,
            'gt_classes': self.gt_classes,
            'ghm_weights': self.ghm_weights,
        }

        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        self.embedding_size = params.embedding_size
        self.weight_decay = getattr(params, 'weight_decay', 0.0)
        self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
        self.batch_size = getattr(params, 'batch_size', 0)

        self.layer_inputs = []
        self.setup()

    def setup(self):
        """Declare every layer of the network, in the order the graph is built."""
        ## grid
        self.feed('data_grid').embed(
            self.num_vocabs, self.embedding_size, name='embedding', dropout=self.dropout)

        ## image
        net = self.feed('data_image')
        for idx in (1, 2):
            net = net.conv(3, 3, 32, 1, 1, name='image_encoder1_%d' % idx)
        for idx in (5, 6, 7, 8):
            net = net.dilate_conv(3, 3, 32, 1, 1, 2, name='image_encoder1_%d' % idx)

        # image branch: three dilated branches plus a global pool, then fused
        for idx, rate in enumerate((4, 8, 16), start=1):
            self.feed('image_encoder1_8').dilate_conv(
                3, 3, 32, 1, 1, rate, name='image_aspp_%d' % idx)
        self.feed('image_encoder1_8').global_pool(name='image_aspp_4')

        net = self.feed('image_aspp_1', 'image_aspp_2', 'image_aspp_3', 'image_aspp_4')
        net = net.concat(3, name='image_aspp_concat')
        net.conv(1, 1, 32, 1, 1, name='image_aspp_1x1')

        net = self.feed('image_encoder1_1', 'image_aspp_1x1')
        net = net.concat(3, name='image_concat1')
        net.conv(3, 3, 32, 1, 1, name='image_featuremap')

        # feature map positional mapping
        sampler_inputs = self.feed('image_featuremap', 'ps_1d_indices', 'data_grid')
        sampler_inputs.positional_sampling(32, name='positional_sampling')

        ## concate image with grid
        net = self.feed('positional_sampling', 'embedding')
        net = net.concat(3, name='concat')
        net.conv(1, 1, 256, 1, 1, name='feature_fuser')

        # encoder
        net = self.feed('feature_fuser')
        for idx in (1, 2, 3, 4):
            net = net.conv(3, 5, 256, 1, 1, name='encoder1_%d' % idx)
        for idx, rate in zip((5, 6, 7, 8), (2, 4, 8, 16)):
            net = net.dilate_conv(3, 5, 256, 1, 1, rate, name='encoder1_%d' % idx)

        # Atrous Spatial Pyramid Pooling module
        #(self.feed('encoder1_8')
        #     .conv(1, 1, 256, 1, 1, name='aspp_0'))
        for idx, rate in enumerate((4, 8, 16), start=1):
            self.feed('encoder1_8').dilate_conv(
                3, 5, 256, 1, 1, rate, name='aspp_%d' % idx)
        self.feed('encoder1_8').global_pool(name='aspp_4')

        net = self.feed('aspp_1', 'aspp_2', 'aspp_3', 'aspp_4')
        net = net.concat(3, name='aspp_concat')
        net.conv(1, 1, 256, 1, 1, name='aspp_1x1')

        # combine low level features
        net = self.feed('encoder1_1', 'aspp_1x1')
        net = net.concat(3, name='concat1')
        net.conv(3, 5, 64, 1, 1, name='decoder1_1')

        # classification
        net = self.feed('decoder1_1')
        net = net.conv(1, 1, self.num_classes, 1, 1,
                       activation=self.activation, name='cls_logits')  # sigmoid for ghm
        net.softmax(name='softmax')
--------------------------------------------------------------------------------
/model_cutie.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-12
3 | # xh.zhao@outlook.com
4 | import tensorflow as tf
5 | from model_framework import Model
6 |
7 |
8 | class CUTIE(Model):
def __init__(self, num_vocabs, num_classes, params, trainable=True):
    """Declare the graph inputs, read hyper-parameters and build the network.

    Args:
        num_vocabs: vocabulary size passed to the embedding layer in `setup`.
        num_classes: number of classes; sizes the last axis of `ghm_weights`
            and is passed to the final 1x1 conv in `setup`.
        params: configuration object. `embedding_size` is required;
            `use_ghm`, `weight_decay`, `hard_negative_ratio` and
            `batch_size` are optional and fall back to defaults.
        trainable: stored on the instance as `self.trainable`.
    """
    self.name = "CUTIE_benchmark"

    # graph inputs
    self.data = tf.placeholder(
        tf.int32, shape=[None, None, None, 1], name='grid_table')
    self.gt_classes = tf.placeholder(
        tf.int32, shape=[None, None, None], name='gt_classes')

    # a missing `use_ghm` behaves like use_ghm == 0
    ghm_flag = getattr(params, 'use_ghm', 0)
    self.use_ghm = tf.equal(1, ghm_flag)  # params.use_ghm
    self.activation = 'sigmoid' if ghm_flag else 'relu'
    self.ghm_weights = tf.placeholder(
        tf.float32, shape=[None, None, None, num_classes], name='ghm_weights')

    # layers that `feed` / `get_output` can look up by name
    self.layers = {
        'data': self.data,
        'gt_classes': self.gt_classes,
        'ghm_weights': self.ghm_weights,
    }

    self.num_vocabs = num_vocabs
    self.num_classes = num_classes
    self.trainable = trainable

    self.embedding_size = params.embedding_size
    self.weight_decay = getattr(params, 'weight_decay', 0.0)
    self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
    self.batch_size = getattr(params, 'batch_size', 0)

    self.layer_inputs = []
    self.setup()
30 |
31 |
32 | def setup(self):
33 | # input
34 | (self.feed('data')
35 | .embed(self.num_vocabs, self.embedding_size, name='embedding'))
36 |
37 | # encoder
38 | (self.feed('embedding')
39 | .conv(3, 5, 64, 1, 1, name='encoder1_1')
40 | .conv(3, 5, 128, 1, 1, name='encoder1_2')
41 | .max_pool(2, 2, 2, 2, name='pool1')
42 | .conv(3, 5, 128, 1, 1, name='encoder2_1')
43 | .conv(3, 5, 256, 1, 1, name='encoder2_2')
44 | .max_pool(2, 2, 2, 2, name='pool2')
45 | .conv(3, 5, 256, 1, 1, name='encoder3_1')
46 | .conv(3, 5, 512, 1, 1, name='encoder3_2')
47 | .max_pool(2, 2, 2, 2, name='pool3')
48 | .conv(3, 5, 512, 1, 1, name='encoder4_1')
49 | .conv(3, 5, 512, 1, 1, name='encoder4_2'))
50 |
51 | # decoder
52 | (self.feed('encoder4_2')
53 | .up_conv(3, 5, 512, 1, 1, name='up1')
54 | .conv(3, 5, 256, 1, 1, name='decoder1_1')
55 | .conv(3, 5, 256, 1, 1, name='decoder1_2')
56 | .up_conv(3, 5, 256, 1, 1, name='up2')
57 | .conv(3, 5, 128, 1, 1, name='decoder2_1')
58 | .conv(3, 5, 128, 1, 1, name='decoder2_2')
59 | .up_conv(3, 5, 128, 1, 1, name='up3')
60 | .conv(3, 5, 64, 1, 1, name='decoder3_1')
61 | .conv(3, 5, 64, 1, 1, name='decoder3_2'))
62 |
63 | # classification
64 | (self.feed('decoder3_2')
65 | .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits')
66 | .softmax(name='softmax'))
67 |
def disp_results(self, data_input, data_label, model_output, threshold):
    """Score thresholded per-class outputs against one-hot ground truth.

    Cells whose input value is 0 are ignored. For every remaining cell each
    of the `self.num_classes` outputs is treated as an independent binary
    decision (`output > threshold`) and compared with the one-hot encoding
    of the cell's label.

    Fixes over the previous version: false positives are now the negatives
    that were predicted positive (they used to be `num_n - num_tp`, which
    made precision equal tp / num_n), and ratios with an empty denominator
    are reported as 0.0 instead of raising ZeroDivisionError.

    Args:
        data_input: array of input indices, one value per cell once flattened.
        data_label: array of class ids, one value per cell once flattened.
        model_output: array reshapeable to [cells, self.num_classes].
        threshold: an output strictly above this value counts as positive.

    Returns:
        Tuple `(accuracy, precision, recall)` of Python floats.
    """
    import numpy as np  # local import: numpy is not imported at file level here

    inputs = np.asarray(data_input).reshape([-1])  # [b * h * w]
    class_ids = np.asarray(data_label).reshape([-1])  # [b * h * w]
    scores = np.asarray(model_output).reshape([-1, self.num_classes])  # [b * h * w, classes]

    # ignore none word input
    keep = inputs != 0
    truth = class_ids[keep][:, None] == np.arange(self.num_classes)
    predicted = scores[keep] > threshold

    num_all = int(truth.size)
    num_correct = int(np.count_nonzero(truth == predicted))
    num_tp = int(np.count_nonzero(truth & predicted))
    num_fp = int(np.count_nonzero(~truth & predicted))
    num_fn = int(np.count_nonzero(truth & ~predicted))

    def _ratio(numerator, denominator):
        # an empty denominator means the metric is undefined; report 0.0
        return float(numerator) / denominator if denominator else 0.0

    # accuracy, precision, recall
    accuracy = _ratio(num_correct, num_all)
    precision = _ratio(num_tp, num_tp + num_fp)
    recall = _ratio(num_tp, num_tp + num_fn)

    return accuracy, precision, recall
101 |
102 |
def inference(self):
    """Return the output of the layer registered under the name 'softmax'."""
    output_layer = 'softmax'  # cls_logits
    return self.get_output(output_layer)
105 |
106 |
107 | def build_loss(self):
108 | labels = self.get_output('gt_classes')
109 | cls_logits = self.get_output('cls_logits')
110 | cls_logits = tf.cond(self.use_ghm, lambda: cls_logits*self.get_output('ghm_weights'),
111 | lambda: cls_logits, name="GradientHarmonizingMechanism")
112 |
113 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=cls_logits)
114 |
115 | with tf.variable_scope('HardNegativeMining'):
116 | labels = tf.reshape(labels, [-1])
117 | cross_entropy = tf.reshape(cross_entropy, [-1])
118 |
119 | fg_idx = tf.where(tf.not_equal(labels, 0))
120 | fgs = tf.gather(cross_entropy, fg_idx)
121 | bg_idx = tf.where(tf.equal(labels, 0))
122 | bgs = tf.gather(cross_entropy, bg_idx)
123 |
124 | num = self.hard_negative_ratio * tf.shape(fgs)[0]
125 | num_bg = tf.cond(tf.shape(bgs)[0]>time per step: %.2fs <<'%(timer_stop - timer_start))
102 |
103 | if not params.is_table:
104 | recall, acc_strict, acc_soft, res = cal_accuracy(data_loader, np.array(data['grid_table']),
105 | np.array(data['gt_classes']), model_output_val,
106 | np.array(data['label_mapids']), data['bbox_mapids'])
107 | else:
108 | recall, acc_strict, acc_soft, res = cal_accuracy_table(data_loader, np.array(data['grid_table']),
109 | np.array(data['gt_classes']), model_output_val,
110 | np.array(data['label_mapids']), data['bbox_mapids'])
111 | # recall, acc_strict, acc_soft, res = cal_save_results(data_loader, np.array(data['grid_table']),
112 | # np.array(data['gt_classes']), model_output_val,
113 | # np.array(data['label_mapids']), data['bbox_mapids'],
114 | # data['file_name'][0], params.save_prefix)
115 | recalls += [recall]
116 | accs_strict += [acc_strict]
117 | accs_soft += [acc_soft]
118 | if acc_strict != 1:
119 | print(res.decode()) # show res for current batch
120 |
121 | # visualize result
122 | shape = data['shape']
123 | file_name = data['file_name'][0] # use one single file_name
124 | bboxes = data['bboxes'][file_name]
125 | if not params.is_table:
126 | vis_bbox(data_loader, params.doc_path, np.array(data['grid_table'])[0],
127 | np.array(data['gt_classes'])[0], np.array(model_output_val)[0], file_name,
128 | np.array(bboxes), shape)
129 | else:
130 | vis_table(data_loader, params.doc_path, np.array(data['grid_table'])[0],
131 | np.array(data['gt_classes'])[0], np.array(model_output_val)[0], file_name,
132 | np.array(bboxes), shape)
133 |
134 | recall = sum(recalls) / len(recalls)
135 | acc_strict = sum(accs_strict) / len(accs_strict)
136 | acc_soft = sum(accs_soft) / len(accs_soft)
137 | print('EVALUATION ACC (Recall/Acc): %.3f / %.3f (%.3f) \n'%(recall, acc_strict, acc_soft))
138 |
--------------------------------------------------------------------------------
/model_cutie2.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2019-04
3 | # xiaohui.zhao@outlook.com
4 | import tensorflow as tf
5 | from model_framework import Model
6 |
7 | class CUTIE2(Model):
def __init__(self, num_vocabs, num_classes, params, trainable=True):
    """Declare the graph inputs, read hyper-parameters and call `setup`.

    Args:
        num_vocabs: stored as `self.num_vocabs`.
        num_classes: number of classes; sizes the last axis of `ghm_weights`.
        params: configuration object. `embedding_size` is required;
            `use_ghm`, `data_augmentation_dropout`, `weight_decay`,
            `hard_negative_ratio` and `batch_size` are optional and fall
            back to defaults.
        trainable: stored on the instance as `self.trainable`.
    """
    self.name = "CUTTIE_benchmark" #

    # graph inputs
    self.data_grid = tf.placeholder(
        tf.int32, shape=[None, None, None, 1], name='grid')
    self.data_image = tf.placeholder(
        tf.uint8, shape=[None, None, None, 3], name='image')
    self.gt_classes = tf.placeholder(
        tf.int32, shape=[None, None, None], name='gt_classes')

    # a missing `use_ghm` behaves like use_ghm == 0
    ghm_flag = getattr(params, 'use_ghm', 0)
    self.use_ghm = tf.equal(1, ghm_flag)  # params.use_ghm
    self.activation = 'sigmoid' if ghm_flag else 'relu'
    self.dropout = getattr(params, 'data_augmentation_dropout', 1)
    self.ghm_weights = tf.placeholder(
        tf.float32, shape=[None, None, None, num_classes], name='ghm_weights')

    # layers that `feed` / `get_output` can look up by name
    self.layers = {
        'data_grid': self.data_grid,
        'data_image': self.data_image,
        'gt_classes': self.gt_classes,
        'ghm_weights': self.ghm_weights,
    }

    self.num_vocabs = num_vocabs
    self.num_classes = num_classes
    self.trainable = trainable

    self.embedding_size = params.embedding_size
    self.weight_decay = getattr(params, 'weight_decay', 0.0)
    self.hard_negative_ratio = getattr(params, 'hard_negative_ratio', 0.0)
    self.batch_size = getattr(params, 'batch_size', 0)

    self.layer_inputs = []
    self.setup()
33 |
34 |
35 | def setup(self):
36 | # input
37 | # (self.feed('data_grid')
38 | # .embed(self.num_vocabs, self.embedding_size, name='embedding', dropout=self.dropout))
39 | # (self.feed('data_image')
40 | # .conv(3, 3, 64, 1, 1, name='image_encoder1_1')
41 | # .conv(3, 3, 128, 1, 1, name='image_encoder1_2'))
42 | #
43 | # # encoder
44 | # (self.feed('embedding')
45 | # .conv(3, 5, 64, 1, 1, name='encoder1_1')
46 | # .conv(3, 5, 128, 1, 1, name='encoder1_2')
47 | # .max_pool(2, 2, 2, 2, name='pool1')
48 | # .conv(3, 5, 128, 1, 1, name='encoder2_1')
49 | # .conv(3, 5, 256, 1, 1, name='encoder2_2')
50 | # .max_pool(2, 2, 2, 2, name='pool2')
51 | # .conv(3, 5, 256, 1, 1, name='encoder3_1')
52 | # .conv(3, 5, 512, 1, 1, name='encoder3_2')
53 | # .max_pool(2, 2, 2, 2, name='pool3')
54 | # .conv(3, 5, 512, 1, 1, name='encoder4_1')
55 | # .conv(3, 5, 512, 1, 1, name='encoder4_2'))
56 | #
57 | # # decoder
58 | # (self.feed('encoder4_2')
59 | # .up_conv(3, 5, 512, 1, 1, name='up1')
60 | # .conv(3, 5, 256, 1, 1, name='decoder1_1')
61 | # .conv(3, 5, 256, 1, 1, name='decoder1_2')
62 | # .up_conv(3, 5, 256, 1, 1, name='up2')
63 | # .conv(3, 5, 128, 1, 1, name='decoder2_1')
64 | # .conv(3, 5, 128, 1, 1, name='decoder2_2')
65 | # .up_conv(3, 5, 128, 1, 1, name='up3')
66 | # .conv(3, 5, 64, 1, 1, name='decoder3_1')
67 | # .conv(3, 5, 64, 1, 1, name='decoder3_2'))
68 | #
69 | # # classification
70 | # (self.feed('decoder3_2')
71 | # .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits')
72 | # .softmax(name='softmax'))
73 | pass
74 |
def disp_results(self, data_input, data_label, model_output, threshold):
    """Score thresholded per-class outputs against one-hot ground truth.

    Cells whose input value is 0 are ignored. For every remaining cell each
    of the `self.num_classes` outputs is treated as an independent binary
    decision (`output > threshold`) and compared with the one-hot encoding
    of the cell's label.

    Fixes over the previous version: false positives are now the negatives
    that were predicted positive (they used to be `num_n - num_tp`, which
    made precision equal tp / num_n), and ratios with an empty denominator
    are reported as 0.0 instead of raising ZeroDivisionError.

    Args:
        data_input: array of input indices, one value per cell once flattened.
        data_label: array of class ids, one value per cell once flattened.
        model_output: array reshapeable to [cells, self.num_classes].
        threshold: an output strictly above this value counts as positive.

    Returns:
        Tuple `(accuracy, precision, recall)` of Python floats.
    """
    import numpy as np  # local import: numpy is not imported at file level here

    inputs = np.asarray(data_input).reshape([-1])  # [b * h * w]
    class_ids = np.asarray(data_label).reshape([-1])  # [b * h * w]
    scores = np.asarray(model_output).reshape([-1, self.num_classes])  # [b * h * w, classes]

    # ignore none word input
    keep = inputs != 0
    truth = class_ids[keep][:, None] == np.arange(self.num_classes)
    predicted = scores[keep] > threshold

    num_all = int(truth.size)
    num_correct = int(np.count_nonzero(truth == predicted))
    num_tp = int(np.count_nonzero(truth & predicted))
    num_fp = int(np.count_nonzero(~truth & predicted))
    num_fn = int(np.count_nonzero(truth & ~predicted))

    def _ratio(numerator, denominator):
        # an empty denominator means the metric is undefined; report 0.0
        return float(numerator) / denominator if denominator else 0.0

    # accuracy, precision, recall
    accuracy = _ratio(num_correct, num_all)
    precision = _ratio(num_tp, num_tp + num_fp)
    recall = _ratio(num_tp, num_tp + num_fn)

    return accuracy, precision, recall
108 |
109 |
def inference(self):
    """Return the output of the layer registered under the name 'softmax'."""
    output_layer = 'softmax'  # cls_logits
    return self.get_output(output_layer)
112 |
113 |
114 | def build_loss(self):
115 | labels = self.get_output('gt_classes')
116 | cls_logits = self.get_output('cls_logits')
117 | cls_logits = tf.cond(self.use_ghm, lambda: cls_logits*self.get_output('ghm_weights'),
118 | lambda: cls_logits, name="GradientHarmonizingMechanism")
119 |
120 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=cls_logits)
121 |
122 | with tf.variable_scope('HardNegativeMining'):
123 | labels = tf.reshape(labels, [-1])
124 | cross_entropy = tf.reshape(cross_entropy, [-1])
125 |
126 | fg_idx = tf.where(tf.not_equal(labels, 0))
127 | fgs = tf.gather(cross_entropy, fg_idx)
128 | bg_idx = tf.where(tf.equal(labels, 0))
129 | bgs = tf.gather(cross_entropy, bg_idx)
130 |
131 | num = self.hard_negative_ratio * tf.shape(fgs)[0]
132 | num_bg = tf.cond(tf.shape(bgs)[0] Download start')
167 | task_info = download_info(args.server, args.project, args.task)
168 |
169 | if args.label is True:
170 | print('==> Download label info start...')
171 | download_label_info(task_info)
172 | print('==> Download label info done.')
173 |
174 | if args.postprocessing is True:
175 | print('==> Download postprocessing info start...')
176 | download_postprocessing_result(task_info)
177 | print('==> Download postprocessing info done.')
178 |
179 | if args.source_file is True:
180 | print('==> Download source file start...')
181 | output_path = 'task_{id}_source_file_{time}'.format(id=args.task, time=int(time.time()))
182 | download_file(output_path, 'source_file_path', task_info)
183 | print('==> Download source file done.')
184 |
185 | if args.quality_file is True:
186 | print('==> Download quality file start...')
187 | output_path = 'task_{id}_quality_file_{time}'.format(id=args.task, time=int(time.time()))
188 | download_file(output_path, 'quality_file_path', task_info)
189 | print('==> Download quality file done.')
190 |
191 | if args.engine_raw_result is True:
192 | print('==> Download engine raw file start...')
193 | output_path = 'task_{id}_engine_raw_file_{time}'.format(id=args.task, time=int(time.time()))
194 | download_file(output_path, 'engine_raw_result_path', task_info)
195 | print('==> Download engine raw file done.')
196 |
197 | if args.engine_result is True:
198 | print('==> Download engine file start...')
199 | output_path = 'task_{id}_engine_file_{time}'.format(id=args.task, time=int(time.time()))
200 | download_file(output_path, 'engine_result_path', task_info)
201 | print('==> Download engine file done.')
202 |
--------------------------------------------------------------------------------
/model_cutie_hr.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-12
3 | # xiaohui.zhao@outlook.com
4 | import tensorflow as tf
5 | from model_cutie import CUTIE
6 |
class CUTIERes(CUTIE):
    """High-resolution CUTIE variant built from up to four parallel branches.

    `setup` names layers `encoder<stage>_<block>_<i>`. Each block sums the
    previous outputs of its own stage with features brought in from the other
    stages through `down<src>_<block>_<dst>` (strided conv) and
    `up<src>_<block>_<dst>` (up_conv) layers. Stages 1-4 use 64/128/256/512
    as the third conv argument (presumably the output channel count -- confirm
    against model_framework).
    """

    def __init__(self, num_vocabs, num_classes, params, trainable=True):
        """Create the input placeholders, read hyper-parameters and build the graph.

        Args:
            num_vocabs: vocabulary size handed to the embedding layer.
            num_classes: number of classes; sizes the last axis of
                `ghm_weights` and is passed to the final 1x1 conv.
            params: configuration object. `embedding_size` is required; all
                other attributes read here are optional and fall back to a
                default when missing.
            trainable: stored on the instance as `self.trainable`.
        """
        self.name = "CUTIE_highresolution_8x" # 8x down sampling

        self.data_grid = tf.placeholder(tf.int32, shape=[None, None, None, 1], name='data_grid')
        self.gt_classes = tf.placeholder(tf.int32, shape=[None, None, None], name='gt_classes')
        # The next two placeholders are created but not added to self.layers below.
        self.data_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='data_image') # not used in CUTIEv1
        self.ps_1d_indices = tf.placeholder(tf.int32, shape=[None, None], name='ps_1d_indices') # not used in CUTIEv1

        # A missing `use_ghm` attribute yields a constant-False tensor and the 'relu' activation.
        self.use_ghm = tf.equal(1, params.use_ghm) if hasattr(params, 'use_ghm') else tf.equal(1, 0) #params.use_ghm
        self.activation = 'sigmoid' if (hasattr(params, 'use_ghm') and params.use_ghm) else 'relu'
        self.dropout = params.data_augmentation_dropout if hasattr(params, 'data_augmentation_dropout') else 1
        self.ghm_weights = tf.placeholder(tf.float32, shape=[None, None, None, num_classes], name='ghm_weights')
        # Named layers that `feed` / `get_output` look up.
        self.layers = dict({'data_grid': self.data_grid, 'gt_classes': self.gt_classes, 'ghm_weights':self.ghm_weights})

        self.num_vocabs = num_vocabs
        self.num_classes = num_classes
        self.trainable = trainable

        self.embedding_size = params.embedding_size
        self.weight_decay = params.weight_decay if hasattr(params, 'weight_decay') else 0.0
        self.hard_negative_ratio = params.hard_negative_ratio if hasattr(params, 'hard_negative_ratio') else 0.0
        self.batch_size = params.batch_size if hasattr(params, 'batch_size') else 0

        self.layer_inputs = []
        self.setup()


    def setup(self):
        """Declare all layers: embedding, nine blocks over up to four stages, classifier.

        Stage 2 is introduced at block 2, stage 3 at block 4 and stage 4 at
        block 6; only stage 1 remains in block 9, which feeds the classifier.
        """
        # input
        (self.feed('data_grid')
             .embed(self.num_vocabs, self.embedding_size, name='embedding', dropout=self.dropout))

        # stage 1 block 1
        (self.feed('embedding')
             .conv(3, 5, 64, 1, 1, name='encoder1_1_1')
             .conv(3, 5, 64, 1, 1, name='encoder1_1_2')
             .conv(3, 5, 128, 2, 2, name='down1_1_2'))


        ## introduce stage 2
        # stage 1 block 2
        (self.feed('encoder1_1_2')
             .conv(3, 5, 64, 1, 1, name='encoder1_2_1')
             .conv(3, 5, 64, 1, 1, name='encoder1_2_2')
             .conv(3, 5, 128, 2, 2, name='down1_2_2'))
        # stage 2 block 2
        (self.feed('down1_1_2')
             .conv(3, 5, 128, 1, 1, name='encoder2_2_1')
             .conv(3, 5, 128, 1, 1, name='encoder2_2_2')
             .up_conv(3, 5, 64, 1, 1, factor=2, name='up2_2_1'))


        # stage 1 block 3
        (self.feed('encoder1_1_2', 'encoder1_2_2', 'up2_2_1')
             .add(name='add1_3')
             .conv(3, 5, 64, 1, 1, name='encoder1_3_1')
             .conv(3, 5, 64, 1, 1, name='encoder1_3_2')
             .conv(3, 5, 128, 2, 2, name='down1_3_2'))
        (self.feed('encoder1_3_2')
             .conv(3, 5, 256, 4, 4, name='down1_3_3'))
        # stage 2 block 3
        (self.feed('down1_1_2', 'encoder2_2_2', 'down1_2_2')
             .add(name='add2_3')
             .conv(3, 5, 128, 1, 1, name='encoder2_3_1')
             .conv(3, 5, 128, 1, 1, name='encoder2_3_2')
             .up_conv(3, 5, 64, 1, 1, factor=2, name='up2_3_1'))
        (self.feed('encoder2_3_2')
             .conv(3, 5, 256, 2, 2, name='down2_3_3'))


        ## introduce stage 3
        # stage 1 block 4
        (self.feed('add1_3', 'encoder1_3_2', 'up2_3_1')
             .add(name='add1_4')
             .conv(3, 5, 64, 1, 1, name='encoder1_4_1')
             .conv(3, 5, 64, 1, 1, name='encoder1_4_2')
             .conv(3, 5, 128, 2, 2, name='down1_4_2'))
        (self.feed('encoder1_4_2')
             .conv(3, 5, 256, 4, 4, name='down1_4_3'))
        # stage 2 block 4
        # NOTE(review): unlike the other up2_*_1 layers, no explicit `factor` is
        # passed here; this relies on up_conv's default -- confirm it is 2.
        (self.feed('add2_3', 'encoder2_3_2', 'down1_3_2')
             .add(name='add2_4')
             .conv(3, 5, 128, 1, 1, name='encoder2_4_1')
             .conv(3, 5, 128, 1, 1, name='encoder2_4_2')
             .up_conv(3, 5, 64, 1, 1, name='up2_4_1'))
        (self.feed('encoder2_4_2')
             .conv(3, 5, 256, 2, 2, name='down2_4_3'))
        # stage 3 block 4
        (self.feed('down1_3_3', 'down2_3_3')
             .add(name='add3_4')
             .conv(3, 5, 256, 1, 1, name='encoder3_4_1')
             .conv(3, 5, 256, 1, 1, name='encoder3_4_2')
             .up_conv(3, 5, 64, 1, 1, factor=4, name='up3_4_1'))
        (self.feed('encoder3_4_2')
             .up_conv(3, 5, 128, 1, 1, factor=2, name='up3_4_2'))


        # stage 1 block 5
        (self.feed('add1_4', 'encoder1_4_2', 'up2_4_1', 'up3_4_1')
             .add(name='add1_5')
             .conv(3, 5, 64, 1, 1, name='encoder1_5_1')
             .conv(3, 5, 64, 1, 1, name='encoder1_5_2')
             .conv(3, 5, 128, 2, 2, name='down1_5_2'))
        (self.feed('encoder1_5_2')
             .conv(3, 5, 256, 4, 4, name='down1_5_3'))
        (self.feed('encoder1_5_2')
             .conv(3, 5, 512, 8, 8, name='down1_5_4'))
        # stage 2 block 5
        (self.feed('add2_4', 'encoder2_4_2', 'down1_4_2', 'up3_4_2')
             .add(name='add2_5')
             .conv(3, 5, 128, 1, 1, name='encoder2_5_1')
             .conv(3, 5, 128, 1, 1, name='encoder2_5_2')
             .up_conv(3, 5, 64, 1, 1, factor=2, name='up2_5_1'))
        (self.feed('encoder2_5_2')
             .conv(3, 5, 256, 2, 2, name='down2_5_3'))
        (self.feed('encoder2_5_2')
             .conv(3, 5, 512, 4, 4, name='down2_5_4'))
        # stage 3 block 5
        (self.feed('add3_4', 'encoder3_4_2', 'down1_4_3', 'down2_4_3')
             .add(name='add3_5')
             .conv(3, 5, 256, 1, 1, name='encoder3_5_1')
             .conv(3, 5, 256, 1, 1, name='encoder3_5_2')
             .up_conv(3, 5, 64, 1, 1, factor=4, name='up3_5_1'))
        (self.feed('encoder3_5_2')
             .up_conv(3, 5, 128, 1, 1, factor=2, name='up3_5_2'))
        (self.feed('encoder3_5_2')
             .conv(3, 5, 512, 2, 2, name='down3_5_4'))


        ## introduce stage 4
        # stage 1 block 6
        (self.feed('add1_5', 'encoder1_5_2', 'up2_5_1', 'up3_5_1')
             .add(name='add1_6')
             .conv(3, 5, 64, 1, 1, name='encoder1_6_1')
             .conv(3, 5, 64, 1, 1, name='encoder1_6_2')
             .conv(3, 5, 128, 2, 2, name='down1_6_2'))
        (self.feed('encoder1_6_2')
             .conv(3, 5, 256, 4, 4, name='down1_6_3'))
        (self.feed('encoder1_6_2')
             .conv(3, 5, 512, 8, 8, name='down1_6_4'))
        # stage 2 block 6
        # NOTE(review): no explicit `factor` here either (see up2_4_1) -- confirm
        # that up_conv's default matches the factor=2 used by the sibling layers.
        (self.feed('add2_5', 'encoder2_5_2', 'down1_5_2', 'up3_5_2')
             .add(name='add2_6')
             .conv(3, 5, 128, 1, 1, name='encoder2_6_1')
             .conv(3, 5, 128, 1, 1, name='encoder2_6_2')
             .up_conv(3, 5, 64, 1, 1, name='up2_6_1'))
        (self.feed('encoder2_6_2')
             .conv(3, 5, 256, 2, 2, name='down2_6_3'))
        (self.feed('encoder2_6_2')
             .conv(3, 5, 512, 4, 4, name='down2_6_4'))
        # stage 3 block 6
        (self.feed('add3_5', 'encoder3_5_2', 'down1_5_3', 'down2_5_3')
             .add(name='add3_6')
             .conv(3, 5, 256, 1, 1, name='encoder3_6_1')
             .conv(3, 5, 256, 1, 1, name='encoder3_6_2')
             .up_conv(3, 5, 64, 1, 1, factor=4, name='up3_6_1'))
        (self.feed('encoder3_6_2')
             .up_conv(3, 5, 128, 1, 1, factor=2, name='up3_6_2'))
        (self.feed('encoder3_6_2')
             .conv(3, 5, 512, 2, 2, name='down3_6_4'))
        # stage 4 block 6
        (self.feed('down1_5_4', 'down2_5_4', 'down3_5_4')
             .add(name='add4_6')
             .conv(3, 5, 512, 1, 1, name='encoder4_6_1')
             .conv(3, 5, 512, 1, 1, name='encoder4_6_2')
             .up_conv(3, 5, 64, 1, 1, factor=8, name='up4_6_1'))
        (self.feed('encoder4_6_2')
             .up_conv(3, 5, 128, 1, 1, factor=4, name='up4_6_2'))
        (self.feed('encoder4_6_2')
             .up_conv(3, 5, 256, 1, 1, factor=2, name='up4_6_3'))


        # stage 1 block 7
        (self.feed('add1_6', 'encoder1_6_2', 'up2_6_1', 'up3_6_1', 'up4_6_1')
             .add(name='add1_7')
             .conv(3, 5, 64, 1, 1, name='encoder1_7_1')
             .conv(3, 5, 64, 1, 1, name='encoder1_7_2')
             .conv(3, 5, 128, 2, 2, name='down1_7_2'))
        (self.feed('encoder1_7_2')
             .conv(3, 5, 256, 4, 4, name='down1_7_3'))
        (self.feed('encoder1_7_2')
             .conv(3, 5, 512, 8, 8, name='down1_7_4'))
        # stage 2 block 7
        (self.feed('add2_6', 'encoder2_6_2', 'down1_6_2', 'up3_6_2', 'up4_6_2')
             .add(name='add2_7')
             .conv(3, 5, 128, 1, 1, name='encoder2_7_1')
             .conv(3, 5, 128, 1, 1, name='encoder2_7_2')
             .up_conv(3, 5, 64, 1, 1, factor=2, name='up2_7_1'))
        (self.feed('encoder2_7_2')
             .conv(3, 5, 256, 2, 2, name='down2_7_3'))
        (self.feed('encoder2_7_2')
             .conv(3, 5, 512, 4, 4, name='down2_7_4'))
        # stage 3 block 7
        (self.feed('add3_6', 'encoder3_6_2', 'down1_6_3', 'down2_6_3', 'up4_6_3')
             .add(name='add3_7')
             .conv(3, 5, 256, 1, 1, name='encoder3_7_1')
             .conv(3, 5, 256, 1, 1, name='encoder3_7_2')
             .up_conv(3, 5, 64, 1, 1, factor=4, name='up3_7_1'))
        (self.feed('encoder3_7_2')
             .up_conv(3, 5, 128, 1, 1, factor=2, name='up3_7_2'))
        (self.feed('encoder3_7_2')
             .conv(3, 5, 512, 2, 2, name='down3_7_4'))
        # stage 4 block 7
        (self.feed('add4_6', 'encoder4_6_2', 'down1_6_4', 'down2_6_4', 'down3_6_4')
             .add(name='add4_7')
             .conv(3, 5, 512, 1, 1, name='encoder4_7_1')
             .conv(3, 5, 512, 1, 1, name='encoder4_7_2')
             .up_conv(3, 5, 64, 1, 1, factor=8, name='up4_7_1'))
        (self.feed('encoder4_7_2')
             .up_conv(3, 5, 128, 1, 1, factor=4, name='up4_7_2'))
        (self.feed('encoder4_7_2')
             .up_conv(3, 5, 256, 1, 1, factor=2, name='up4_7_3'))


        # Block 8: every stage only produces an up-sampled output towards stage 1.
        # stage 1 block 8
        (self.feed('add1_7', 'encoder1_7_2', 'up2_7_1', 'up3_7_1', 'up4_7_1')
             .add(name='add1_8')
             .conv(3, 5, 64, 1, 1, name='encoder1_8_1')
             .conv(3, 5, 64, 1, 1, name='encoder1_8_2'))
        # stage 2 block 8
        (self.feed('add2_7', 'encoder2_7_2', 'down1_7_2', 'up3_7_2', 'up4_7_2')
             .add(name='add2_8')
             .conv(3, 5, 128, 1, 1, name='encoder2_8_1')
             .conv(3, 5, 128, 1, 1, name='encoder2_8_2')
             .up_conv(3, 5, 64, 1, 1, factor=2, name='up2_8_1'))
        # stage 3 block 8
        (self.feed('add3_7', 'encoder3_7_2', 'down1_7_3', 'down2_7_3', 'up4_7_3')
             .add(name='add3_8')
             .conv(3, 5, 256, 1, 1, name='encoder3_8_1')
             .conv(3, 5, 256, 1, 1, name='encoder3_8_2')
             .up_conv(3, 5, 64, 1, 1, factor=4, name='up3_8_1'))
        # stage 4 block 8
        (self.feed('add4_7', 'encoder4_7_2', 'down1_7_4', 'down2_7_4', 'down3_7_4')
             .add(name='add4_8')
             .conv(3, 5, 512, 1, 1, name='encoder4_8_1')
             .conv(3, 5, 512, 1, 1, name='encoder4_8_2')
             .up_conv(3, 5, 64, 1, 1, factor=8, name='up4_8_1'))


        # stage 1 block 9
        (self.feed('add1_8', 'encoder1_8_2', 'up2_8_1', 'up3_8_1', 'up4_8_1')
             .add(name='add1_9')
             .conv(3, 5, 64, 1, 1, name='encoder1_9_1')
             .conv(3, 5, 64, 1, 1, name='encoder1_9_2'))


        # classification
        (self.feed('encoder1_9_2')
             .conv(1, 1, self.num_classes, 1, 1, activation=self.activation, name='cls_logits') # sigmoid for ghm
             .softmax(name='softmax'))
--------------------------------------------------------------------------------
/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
    """Raise ValueError when `do_lower_case` contradicts a known checkpoint name.

    Returns None without checking anything when `init_checkpoint` is empty,
    when it does not look like `<model_name>/bert_model.ckpt`, or when the
    model name is not one of the known cased / lowercased models.
    """
    # The casing has to be passed in by the user and there is no explicit check
    # as to whether it matches the checkpoint. The casing information probably
    # should have been stored in the bert_config.json file, but it's not, so
    # we have to heuristically detect it to validate.
    if not init_checkpoint:
        return

    found = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
    if found is None:
        return
    model_name = found.group(1)

    lower_models = (
        "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
        "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12",
    )
    cased_models = (
        "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
        "multi_cased_L-12_H-768_A-12",
    )

    # (flag the user passed, casing of the model, flag they should pass)
    mismatch = None
    if model_name in lower_models and not do_lower_case:
        mismatch = ("False", "lowercased", "True")
    if model_name in cased_models and do_lower_case:
        mismatch = ("True", "cased", "False")
    if mismatch is None:
        return

    actual_flag, case_name, opposite_flag = mismatch
    raise ValueError(
        "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
        "However, `%s` seems to be a %s model, so you "
        "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
        "how the model was pre-training. If this error is wrong, please "
        "just comment out this check." % (actual_flag, init_checkpoint,
                                          model_name, case_name, opposite_flag))
76 |
77 |
def convert_to_unicode(text):
    """Return `text` as a Unicode string, decoding byte strings as UTF-8.

    Bytes that cannot be decoded are dropped (the "ignore" error handler).

    Raises:
        ValueError: if `text` is neither a text nor a byte string, or if `six`
            reports neither Python 2 nor Python 3.
    """
    if not (six.PY3 or six.PY2):
        raise ValueError("Not running on Python2 or Python 3?")

    if six.PY3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
    else:
        # Python 2: `str` is the byte type and `unicode` is the text type.
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        if isinstance(text, unicode):
            return text

    raise ValueError("Unsupported string type: %s" % (type(text)))
96 |
97 |
def printable_text(text):
    """Return `text` as the native `str` type, for `print` or `tf.logging`.

    On Python 3 that is a Unicode string (bytes are decoded as UTF-8, dropping
    undecodable bytes); on Python 2 it is a byte string (Unicode text is
    encoded as UTF-8).

    Raises:
        ValueError: if `text` is neither a text nor a byte string, or if `six`
            reports neither Python 2 nor Python 3.
    """
    if not (six.PY3 or six.PY2):
        raise ValueError("Not running on Python2 or Python 3?")

    if six.PY3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
    else:
        # Python 2: `str` is already the byte string the callers want.
        if isinstance(text, str):
            return text
        if isinstance(text, unicode):
            return text.encode("utf-8")

    raise ValueError("Unsupported string type: %s" % (type(text)))
119 |
120 |
def load_vocab(vocab_file):
    """Read a vocabulary file (one token per line) into an ordered token -> id map.

    Ids are the zero-based line numbers. Each line is stripped of surrounding
    whitespace before being used as a key, so a token that appears twice keeps
    the id of its last occurrence.
    """
    vocab = collections.OrderedDict()
    with tf.gfile.GFile(vocab_file, "r") as reader:
        next_id = 0
        line = convert_to_unicode(reader.readline())
        # readline() yields an empty string only at end of file.
        while line:
            vocab[line.strip()] = next_id
            next_id += 1
            line = convert_to_unicode(reader.readline())
    return vocab
134 |
135 |
def convert_by_vocab(vocab, items):
    """Look up every element of `items` in the mapping `vocab`.

    Works in either direction (tokens -> ids or ids -> tokens) depending on the
    mapping passed in. Returns a new list; a missing key raises `KeyError`.
    """
    return [vocab[entry] for entry in items]
142 |
143 |
def convert_tokens_to_ids(vocab, tokens):
    """Map each token in `tokens` to its id in `vocab` (delegates to `convert_by_vocab`)."""
    return convert_by_vocab(vocab, tokens)
146 |
147 |
def convert_ids_to_tokens(inv_vocab, ids):
    """Map each id in `ids` to its token in `inv_vocab` (delegates to `convert_by_vocab`)."""
    return convert_by_vocab(inv_vocab, ids)
150 |
151 |
def whitespace_tokenize(text):
    """Split `text` on runs of whitespace; blank input yields an empty list."""
    trimmed = text.strip()
    return trimmed.split() if trimmed else []
159 |
160 |
class FullTokenizer(object):
    """End-to-end tokenizer: basic tokenization followed by WordPiece splitting."""

    def __init__(self, vocab_file, do_lower_case=True):
        """Load the vocabulary from `vocab_file` and build both sub-tokenizers."""
        self.vocab = load_vocab(vocab_file)
        # Reverse lookup (id -> token) for convert_ids_to_tokens.
        self.inv_vocab = dict((idx, tok) for tok, idx in self.vocab.items())
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        """Return the flat list of WordPiece tokens for `text`."""
        return [
            piece
            for word in self.basic_tokenizer.tokenize(text)
            for piece in self.wordpiece_tokenizer.tokenize(word)
        ]

    def convert_tokens_to_ids(self, tokens):
        """Map tokens to vocabulary ids."""
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        """Map vocabulary ids back to tokens."""
        return convert_by_vocab(self.inv_vocab, ids)
183 |
184 |
class BasicTokenizer(object):
    """Basic tokenizer: text cleanup, optional lower casing, punctuation splitting."""

    def __init__(self, do_lower_case=True):
        """Create the tokenizer.

        Args:
            do_lower_case: when true, tokens are lower-cased and stripped of
                accents before punctuation splitting.
        """
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Return the list of basic tokens for `text`."""
        cleaned = self._clean_text(convert_to_unicode(text))

        # CJK characters are surrounded with spaces so each one becomes its own
        # token. This is applied to every model, not only the multilingual and
        # Chinese ones.
        spaced = self._tokenize_chinese_chars(cleaned)

        pieces = []
        for word in whitespace_tokenize(spaced):
            if self.do_lower_case:
                word = self._run_strip_accents(word.lower())
            pieces.extend(self._run_split_on_punc(word))

        # Re-split on whitespace after joining the pieces.
        return whitespace_tokenize(" ".join(pieces))

    def _run_strip_accents(self, text):
        """Remove combining marks (category "Mn") after NFD normalization."""
        decomposed = unicodedata.normalize("NFD", text)
        return "".join(
            ch for ch in decomposed if unicodedata.category(ch) != "Mn")

    def _run_split_on_punc(self, text):
        """Split `text` so every punctuation character is a separate piece."""
        groups = []
        word_open = False  # True while the last group is a word still being built
        for ch in text:
            if _is_punctuation(ch):
                groups.append([ch])
                word_open = False
            elif word_open:
                groups[-1].append(ch)
            else:
                groups.append([ch])
                word_open = True
        return ["".join(group) for group in groups]

    def _tokenize_chinese_chars(self, text):
        """Surround every CJK ideograph in `text` with single spaces."""
        return "".join(
            " " + ch + " " if self._is_chinese_char(ord(ch)) else ch
            for ch in text)

    def _is_chinese_char(self, cp):
        """Return True if codepoint `cp` lies in one of the CJK ideograph ranges.

        Only the CJK Unified Ideographs ranges listed below count. Hangul,
        Hiragana and Katakana are in other blocks and are handled like any
        other space-separated script.
        """
        cjk_ranges = (
            (0x4E00, 0x9FFF),
            (0x3400, 0x4DBF),
            (0x20000, 0x2A6DF),
            (0x2A700, 0x2B73F),
            (0x2B740, 0x2B81F),
            (0x2B820, 0x2CEAF),
            (0xF900, 0xFAFF),
            (0x2F800, 0x2FA1F),
        )
        return any(low <= cp <= high for low, high in cjk_ranges)

    def _clean_text(self, text):
        """Drop NUL, U+FFFD and control characters; turn whitespace into spaces."""
        kept = []
        for ch in text:
            codepoint = ord(ch)
            if codepoint in (0, 0xfffd) or _is_control(ch):
                continue
            kept.append(" " if _is_whitespace(ch) else ch)
        return "".join(kept)
298 |
299 |
class WordpieceTokenizer(object):
    """Splits tokens into WordPiece sub-tokens using a fixed vocabulary."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        """Store the vocabulary, the unknown-token marker and the word length cap."""
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Split `text` into word pieces by greedy longest-match-first lookup.

        Every piece after the first of a word is prefixed with "##", e.g.
        "unaffable" -> ["un", "##aff", "##able"] given a suitable vocabulary.
        A word that is longer than `max_input_chars_per_word`, or that cannot
        be fully covered by vocabulary entries, becomes a single `unk_token`.

        Args:
            text: a single token or whitespace-separated tokens, expected to
                have been through `BasicTokenizer` already.

        Returns:
            A list of word-piece tokens.
        """
        text = convert_to_unicode(text)

        result = []
        for word in whitespace_tokenize(text):
            word_len = len(word)
            if word_len > self.max_input_chars_per_word:
                result.append(self.unk_token)
                continue

            word_pieces = []
            begin = 0
            while begin < word_len:
                # Try the longest remaining substring first, then shrink.
                for stop in range(word_len, begin, -1):
                    candidate = word[begin:stop]
                    if begin > 0:
                        candidate = "##" + candidate
                    if candidate in self.vocab:
                        break
                else:
                    # Nothing starting at `begin` is in the vocabulary.
                    break
                word_pieces.append(candidate)
                begin = stop

            # Stopping before the end of the word means it could not be covered.
            if begin < word_len:
                result.append(self.unk_token)
            else:
                result.extend(word_pieces)
        return result
360 |
361 |
362 | def _is_whitespace(char):
363 | """Checks whether `chars` is a whitespace character."""
364 | # \t, \n, and \r are technically contorl characters but we treat them
365 | # as whitespace since they are generally considered as such.
366 | if char == " " or char == "\t" or char == "\n" or char == "\r":
367 | return True
368 | cat = unicodedata.category(char)
369 | if cat == "Zs":
370 | return True
371 | return False
372 |
373 |
374 | def _is_control(char):
375 | """Checks whether `chars` is a control character."""
376 | # These are technically control characters but we count them as whitespace
377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r":
379 | return False
380 | cat = unicodedata.category(char)
381 | if cat.startswith("C"):
382 | return True
383 | return False
384 |
385 |
386 | def _is_punctuation(char):
387 | """Checks whether `chars` is a punctuation character."""
388 | cp = ord(char)
389 | # We treat all non-letter/number ASCII as punctuation.
390 | # Characters such as "^", "$", and "`" are not in the Unicode
391 | # Punctuation class but we treat them as punctuation anyways, for
392 | # consistency.
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
395 | return True
396 | cat = unicodedata.category(char)
397 | if cat.startswith("P"):
398 | return True
399 | return False
400 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 | 1583894111450
190 |
191 |
192 | 1583894111450
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 |
366 |
367 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-01
3 | # xiaohui.zhao@outlook.com
4 | import numpy as np
5 | import csv
6 | from os.path import join
7 | try:
8 | import cv2
9 | except ImportError:
10 | pass
11 |
12 | c_threshold = 0.5
13 |
def cal_accuracy(data_loader, grid_table, gt_classes, model_output_val, label_mapids, bbox_mapids):
    """Score per-class predictions for a batch and build a text report.

    For every batch element and every class `c >= 1`, the grid cells whose
    score for `c` exceeds `c_threshold` are compared with the ground truth.
    Cells whose input id is 0 are ignored.

    Args:
        data_loader: object providing `num_classes`, `classes` and
            `index_to_word`.
        grid_table: array indexed as `[b, row, col, 0]` holding word ids
            (0 is skipped as "not a word").
        gt_classes: array indexed as `[b, row, col]` holding class ids.
        model_output_val: per-cell class scores; reshaped per batch element to
            `[-1, data_loader.num_classes]`.
        label_mapids: per batch element, an object indexable by class id that
            yields the ground-truth bbox ids for that class.
        bbox_mapids: per batch element, a mapping from flat cell index
            (`row * cols + col`) to bbox id.

    Returns:
        `(prevalence, accuracy_strict, accuracy_soft, res)`. The three scores
        are averaged over `batch_size * (num_scores - 1)`. `res` is a UTF-8
        encoded report covering the first batch element only.
    """
    res_parts = []
    num_correct = 0
    num_correct_strict = 0
    num_correct_soft = 0
    batch_size = grid_table.shape[0]
    num_all = batch_size * (model_output_val.shape[-1]-1)

    # The flat cell index of every grid position is the same for all batch elements.
    rows, cols = grid_table.shape[1:3]
    bbox_id = np.arange(rows * cols)

    for b in range(batch_size):
        data_input_flat = grid_table[b,:,:,0].reshape([-1])
        labels = gt_classes[b,:,:].reshape([-1])
        logits = model_output_val[b,:,:,:].reshape([-1, data_loader.num_classes])
        label_mapid = label_mapids[b]
        bbox_mapid = bbox_mapids[b]

        # ignore inputs that are not word
        indexes = np.where(data_input_flat != 0)[0]
        data_selected = data_input_flat[indexes]
        labels_selected = labels[indexes]
        logits_array_selected = logits[indexes]
        bbox_id_selected = bbox_id[indexes]

        for c in range(1, data_loader.num_classes):
            labels_indexes = np.where(labels_selected == c)[0]
            logits_indexes = np.where(logits_array_selected[:,c] > c_threshold)[0]

            label_bbox_ids = set(label_mapid[c])  # GT bbox_ids related to the type of class
            logit_bbox_ids = set(bbox_mapid[bbox] for bbox in bbox_id_selected[logits_indexes]
                                 if bbox in bbox_mapid)

            if label_bbox_ids == logit_bbox_ids:  # strict: predicted ids match the GT ids exactly
                num_correct_strict += 1
                num_correct_soft += 1
            elif label_bbox_ids.issubset(logit_bbox_ids):  # soft: every GT id was predicted
                num_correct_soft += 1

            # Fraction of GT cells that were predicted. A class with no GT
            # cells counts as fully correct.
            num_gt = len(labels_indexes)
            if num_gt == 0:
                num_correct += 1
            else:
                num_correct += len(np.intersect1d(labels_indexes, logits_indexes)) / num_gt

            # The text report only covers the first element of the batch.
            if b != 0:
                continue

            res_parts.append('\n{}(GT/Inf):\t"'.format(data_loader.classes[c]))
            res_parts.append(' '.join(data_loader.index_to_word[i] for i in data_selected[labels_indexes]))
            res_parts.append('" | "')
            res_parts.append(' '.join(data_loader.index_to_word[i] for i in data_selected[logits_indexes]))
            res_parts.append('"')

            # wrong inferences results
            if not np.array_equal(labels_indexes, logits_indexes):
                res_parts.append('\n \t FALSES =>>')
                # setdiff1d returns the predicted-but-not-GT indexes in ascending order.
                for i in np.setdiff1d(logits_indexes, labels_indexes):
                    w = data_loader.index_to_word[data_selected[i]]
                    l = data_loader.classes[labels_selected[i]]
                    res_parts.append(' "%s"/%s, '%(w, l))

    res = ''.join(res_parts)
    prevalence = num_correct / num_all
    accuracy_strict = num_correct_strict / num_all
    accuracy_soft = num_correct_soft / num_all
    return prevalence, accuracy_strict, accuracy_soft, res.encode("utf-8")
93 |
94 |
def cal_save_results(data_loader, grid_table, gt_classes, model_output_val, label_mapids, bbox_mapids, file_names, save_prefix):
    """Score per-class predictions and append them to per-class CSV files.

    The scoring is the same as in `cal_accuracy`. In addition, for every batch
    element and class, one row (`TaskID`, `GT`, `Predicted`) is appended to
    `data/results/<save_prefix>_<class>.csv`, and also to
    `data/results/<save_prefix>_Diff_<class>.csv` when GT and prediction
    differ. The `data/results/` directory must already exist.

    Args:
        data_loader: object providing `num_classes`, `classes` and
            `index_to_word`.
        grid_table: array indexed as `[b, row, col, 0]` holding word ids
            (0 is skipped as "not a word").
        gt_classes: array indexed as `[b, row, col]` holding class ids.
        model_output_val: per-cell class scores.
        label_mapids: per batch element, indexable by class id, yielding the
            ground-truth bbox ids.
        bbox_mapids: per batch element, mapping from flat cell index to bbox id.
        file_names: only `file_names[0]` is used, as the `TaskID` of every row.
        save_prefix: prefix of the CSV file names.

    Returns:
        `(prevalence, accuracy_strict, accuracy_soft, res)`, where `res` is a
        UTF-8 encoded report covering the first batch element only.
    """
    res_parts = []
    num_correct = 0
    num_correct_strict = 0
    num_correct_soft = 0
    batch_size = grid_table.shape[0]
    num_all = batch_size * (model_output_val.shape[-1]-1)
    fieldnames = ['TaskID', 'GT', 'Predicted']

    # The flat cell index of every grid position is the same for all batch elements.
    rows, cols = grid_table.shape[1:3]
    bbox_id = np.arange(rows * cols)

    for b in range(batch_size):
        # NOTE(review): every batch element is reported under file_names[0];
        # confirm against the caller whether file_names[b] was intended.
        filename = file_names[0]

        data_input_flat = grid_table[b,:,:,0].reshape([-1])
        labels = gt_classes[b,:,:].reshape([-1])
        logits = model_output_val[b,:,:,:].reshape([-1, data_loader.num_classes])
        label_mapid = label_mapids[b]
        bbox_mapid = bbox_mapids[b]

        # ignore inputs that are not word
        indexes = np.where(data_input_flat != 0)[0]
        data_selected = data_input_flat[indexes]
        labels_selected = labels[indexes]
        logits_array_selected = logits[indexes]
        bbox_id_selected = bbox_id[indexes]

        for c in range(1, data_loader.num_classes):
            labels_indexes = np.where(labels_selected == c)[0]
            logits_indexes = np.where(logits_array_selected[:,c] > c_threshold)[0]

            label_bbox_ids = set(label_mapid[c])  # GT bbox_ids related to the type of class
            logit_bbox_ids = set(bbox_mapid[bbox] for bbox in bbox_id_selected[logits_indexes]
                                 if bbox in bbox_mapid)

            if label_bbox_ids == logit_bbox_ids:  # strict: predicted ids match the GT ids exactly
                num_correct_strict += 1
                num_correct_soft += 1
            elif label_bbox_ids.issubset(logit_bbox_ids):  # soft: every GT id was predicted
                num_correct_soft += 1

            # Fraction of GT cells that were predicted. A class with no GT
            # cells counts as fully correct.
            num_gt = len(labels_indexes)
            if num_gt == 0:
                num_correct += 1
            else:
                num_correct += len(np.intersect1d(labels_indexes, logits_indexes)) / num_gt

            gt = str(' '.join(data_loader.index_to_word[i] for i in data_selected[labels_indexes]))
            predict = str(' '.join(data_loader.index_to_word[i] for i in data_selected[logits_indexes]))

            # write results to csv; the files are closed as soon as the row is written
            row = {'TaskID':filename, 'GT':gt, 'Predicted':predict}
            csv_filename = 'data/results/' + save_prefix + '_' + data_loader.classes[c] + '.csv'
            with open(csv_filename, 'a') as csv_file:
                csv.DictWriter(csv_file, fieldnames=fieldnames).writerow(row)

            if gt != predict:
                csv_diff_filename = 'data/results/' + save_prefix + '_Diff_' + data_loader.classes[c] + '.csv'
                with open(csv_diff_filename, 'a') as csv_file:
                    csv.DictWriter(csv_file, fieldnames=fieldnames).writerow(row)

            # The text report only covers the first element of the batch.
            if b != 0:
                continue

            res_parts.append('\n{}(GT/Inf):\t"'.format(data_loader.classes[c]))
            res_parts.append(gt + '" | "' + predict + '"')

            # wrong inferences results
            if not np.array_equal(labels_indexes, logits_indexes):
                res_parts.append('\n \t FALSES =>>')
                # setdiff1d returns the predicted-but-not-GT indexes in ascending order.
                for i in np.setdiff1d(logits_indexes, labels_indexes):
                    w = data_loader.index_to_word[data_selected[i]]
                    l = data_loader.classes[labels_selected[i]]
                    res_parts.append(' "%s"/%s, '%(w, l))

    res = ''.join(res_parts)
    prevalence = num_correct / num_all
    accuracy_strict = num_correct_strict / num_all
    accuracy_soft = num_correct_soft / num_all
    return prevalence, accuracy_strict, accuracy_soft, res.encode("utf-8")
186 |
187 |
def vis_bbox(data_loader, file_prefix, grid_table, gt_classes, model_output_val, file_name, bboxes, shape):
    """Draw ground-truth and predicted boxes over a document image and show it.

    Ground-truth cells are drawn as filled rectangles and predictions (best
    score above `c_threshold`, class id != 0) as outlined rectangles. A legend
    is drawn for every class. The result is written to
    `results/<file_name without its last 4 chars>.png` and displayed in a
    window until a key is pressed.

    Args:
        data_loader: object providing `num_classes`, `classes`, `rows`, `cols`.
        file_prefix: directory containing the image.
        grid_table: word-id grid for one document (flattened here).
        gt_classes: class-id grid for one document (flattened here).
        model_output_val: per-cell class scores for one document.
        file_name: image file name, relative to `file_prefix`.
        bboxes: per-cell sequences, either empty or `(x, y, w, h)`.
        shape: used only when the image cannot be read.
            NOTE(review): in that case `cv2.resize` is still called with
            `img=None` and fails; confirm the intended fallback.
    """
    data_input_flat = grid_table.reshape([-1])
    labels = gt_classes.reshape([-1])
    logits = model_output_val.reshape([-1, data_loader.num_classes])
    bboxes = bboxes.reshape([-1])

    max_len = 768*2 # upper boundary of image display size
    img = cv2.imread(join(file_prefix, file_name))
    if img is not None:
        shape = list(img.shape)

    bbox_pad = 1
    gt_color = [[255, 250, 240], [152, 245, 255], [119,204,119], [100, 149, 237],
                [192, 255, 62], [119,119,204], [114,124,114], [240, 128, 128], [255, 105, 180]]
    inf_color = [[255, 222, 173], [0, 255, 255], [50,219,50], [72, 61, 139],
                 [154, 205, 50], [50,50,219], [64,76,64], [255, 0, 0], [255, 20, 147]]

    font = cv2.FONT_HERSHEY_COMPLEX
    ft_color = [50, 50, 250]

    # Scale so the longest side becomes max_len.
    factor = max_len / max(shape)
    shape[0], shape[1] = [int(s*factor) for s in shape[:2]]

    img = cv2.resize(img, (shape[1], shape[0]))
    overlay_box = np.zeros(shape, dtype=img.dtype)
    overlay_line = np.zeros(shape, dtype=img.dtype)
    for i in range(len(data_input_flat)):
        if len(bboxes[i]) > 0:
            x,y,w,h = [int(p*factor) for p in bboxes[i]]
        else:
            # No bbox for this cell: fall back to its grid position. The flat
            # index is row * cols + col, so both coordinates divide by cols.
            row = i // data_loader.cols
            col = i % data_loader.cols
            x = shape[1] // data_loader.cols * col
            y = shape[0] // data_loader.rows * row
            w = shape[1] // data_loader.cols * 2
            # NOTE(review): the height is derived from cols, not rows -- confirm intent.
            h = shape[0] // data_loader.cols * 2

        if data_input_flat[i] and labels[i]:
            gt_id = labels[i]
            cv2.rectangle(overlay_box, (x,y), (x+w,y+h), gt_color[gt_id], -1)

        if max(logits[i]) > c_threshold:
            inf_id = np.argmax(logits[i])
            if inf_id:
                cv2.rectangle(overlay_line, (x+bbox_pad,y+bbox_pad), \
                              (x+bbox_pad+w,y+bbox_pad+h), inf_color[inf_id], max_len//768*2)

    # legends
    w = shape[1] // data_loader.cols * 4
    h = shape[0] // data_loader.cols * 2
    for i in range(1, len(data_loader.classes)):
        row = i * 3
        col = 0
        x = shape[1] // data_loader.cols * col
        y = shape[0] // data_loader.rows * row
        cv2.rectangle(img, (x,y), (x+w,y+h), gt_color[i], -1)
        cv2.putText(img, data_loader.classes[i], (x+w,y+h), font, 0.8, ft_color)

        row = i * 3 + 1
        col = 0
        x = shape[1] // data_loader.cols * col
        y = shape[0] // data_loader.rows * row
        cv2.rectangle(img, (x+bbox_pad,y+bbox_pad), \
                      (x+bbox_pad+w,y+bbox_pad+h), inf_color[i], max_len//384)

    alpha = 0.4
    cv2.addWeighted(overlay_box, alpha, img, 1-alpha, 0, img)
    cv2.addWeighted(overlay_line, 1-alpha, img, 1, 0, img)
    cv2.imwrite('results/' + file_name[:-4]+'.png', img)
    cv2.imshow("test", img)
    cv2.waitKey(0)
263 |
264 |
def cal_accuracy_table(data_loader, grid_table, gt_classes, model_output_val, label_mapids, bbox_mapids):
    """Score table-region predictions by comparing region extents per class.

    For every batch element and class `c >= 1`, the top row and the right-most
    column of the ground-truth cells are compared with those of the cells
    whose score for `c` exceeds `c_threshold`:

    * equal top row and equal right-most column: strict, soft and prevalence
      hit;
    * prediction starts above and ends right of the ground truth: soft and
      prevalence hit;
    * class absent from both ground truth and prediction: counted as a full
      hit, matching how `cal_accuracy` treats an empty class;
    * class absent from only one of them: no hit.

    Args:
        data_loader: object providing `num_classes`.
        grid_table: array whose shape starts with `(batch, rows, cols)`.
        gt_classes: array indexed as `[b, row, col]` holding class ids.
        model_output_val: per-cell class scores; reshaped per batch element to
            `[rows, cols, data_loader.num_classes]`.
        label_mapids: unused; kept for signature parity with `cal_accuracy`.
        bbox_mapids: unused; kept for signature parity with `cal_accuracy`.

    Returns:
        `(prevalence, accuracy_strict, accuracy_soft, res)`, where `res` is
        always the empty byte string.
    """
    res = ''
    num_correct = 0
    num_correct_strict = 0
    num_correct_soft = 0
    batch_size = grid_table.shape[0]
    num_all = batch_size * (model_output_val.shape[-1]-1)
    rows, cols = grid_table.shape[1:3]

    for b in range(batch_size):
        labels = gt_classes[b,:,:]
        logits = model_output_val[b,:,:,:].reshape([rows, cols, data_loader.num_classes])

        for c in range(1, data_loader.num_classes):
            label_rows, label_cols = np.where(labels == c)
            logit_rows, logit_cols = np.where(logits[:,:,c] > c_threshold)

            # min()/max() raise ValueError on empty input, so a class that is
            # missing on either side is handled before taking extents.
            gt_missing = label_rows.size == 0
            inf_missing = logit_rows.size == 0
            if gt_missing or inf_missing:
                if gt_missing and inf_missing:
                    num_correct_strict += 1
                    num_correct_soft += 1
                    num_correct += 1
                continue

            gt_top, gt_right = label_rows.min(), label_cols.max()
            inf_top, inf_right = logit_rows.min(), logit_cols.max()
            if gt_top == inf_top and gt_right == inf_right:
                num_correct_strict += 1
                num_correct_soft += 1
                num_correct += 1
            elif gt_top > inf_top and gt_right < inf_right:
                num_correct_soft += 1
                num_correct += 1

    prevalence = num_correct / num_all
    accuracy_strict = num_correct_strict / num_all
    accuracy_soft = num_correct_soft / num_all
    return prevalence, accuracy_strict, accuracy_soft, res.encode("utf-8")
300 |
301 |
def vis_table(data_loader, file_prefix, grid_table, gt_classes, model_output_val, file_name, bboxes, shape):
    """Draw ground-truth and predicted table regions over a document image.

    Works like `vis_bbox` and additionally draws one rectangle around all
    ground-truth cells (filled) and one around all predicted cells (outlined).
    The result is written to `results/<file_name without its last 4
    chars>.png` and displayed in a window until a key is pressed.

    Args:
        data_loader: object providing `num_classes`, `classes`, `rows`, `cols`.
        file_prefix: directory containing the image.
        grid_table: word-id grid for one document (flattened here).
        gt_classes: class-id grid for one document (flattened here).
        model_output_val: per-cell class scores for one document.
        file_name: image file name, relative to `file_prefix`.
        bboxes: per-cell sequences, either empty or `(x, y, w, h)`.
        shape: used only when the image cannot be read.
            NOTE(review): in that case `cv2.resize` is still called with
            `img=None` and fails; confirm the intended fallback.
    """
    data_input_flat = grid_table.reshape([-1])
    labels = gt_classes.reshape([-1])
    logits = model_output_val.reshape([-1, data_loader.num_classes])
    bboxes = bboxes.reshape([-1])

    max_len = 768*2 # upper boundary of image display size
    img = cv2.imread(join(file_prefix, file_name))
    if img is not None:
        shape = list(img.shape)

    bbox_pad = 1
    gt_color = [[255, 250, 240], [152, 245, 255], [119,204,119], [100, 149, 237],
                [192, 255, 62], [119,119,204], [114,124,114], [240, 128, 128], [255, 105, 180]]
    inf_color = [[255, 222, 173], [0, 255, 255], [50,219,50], [72, 61, 139],
                 [154, 205, 50], [50,50,219], [64,76,64], [255, 0, 0], [255, 20, 147]]

    font = cv2.FONT_HERSHEY_COMPLEX
    ft_color = [50, 50, 250]

    # Scale so the longest side becomes max_len.
    factor = max_len / max(shape)
    shape[0], shape[1] = [int(s*factor) for s in shape[:2]]

    img = cv2.resize(img, (shape[1], shape[0]))
    overlay_box = np.zeros(shape, dtype=img.dtype)
    overlay_line = np.zeros(shape, dtype=img.dtype)
    # Running extents (left, top, right, bottom) of all GT / predicted cells.
    gt_x, gt_y, gt_r, gt_b = 99999, 99999, 0, 0
    inf_x, inf_y, inf_r, inf_b = 99999, 99999, 0, 0
    gt_found = False
    inf_found = False
    for i in range(len(data_input_flat)):
        if len(bboxes[i]) > 0:
            x,y,w,h = [int(p*factor) for p in bboxes[i]]
        else:
            # No bbox for this cell: fall back to its grid position. The flat
            # index is row * cols + col, so both coordinates divide by cols.
            row = i // data_loader.cols
            col = i % data_loader.cols
            x = shape[1] // data_loader.cols * col
            y = shape[0] // data_loader.rows * row
            w = shape[1] // data_loader.cols * 2
            # NOTE(review): the height is derived from cols, not rows -- confirm intent.
            h = shape[0] // data_loader.cols * 2

        if data_input_flat[i] and labels[i]:
            gt_id = labels[i]
            cv2.rectangle(overlay_box, (x,y), (x+w,y+h), gt_color[gt_id], -1)
            gt_x = min(x, gt_x)
            gt_y = min(y, gt_y)
            gt_r = max(x+w, gt_r)
            gt_b = max(y+h, gt_b)
            gt_found = True

        if max(logits[i]) > c_threshold:
            inf_id = np.argmax(logits[i])
            if inf_id:
                cv2.rectangle(overlay_line, (x+bbox_pad,y+bbox_pad), \
                              (x+bbox_pad+w,y+bbox_pad+h), inf_color[inf_id], max_len//768*2)
                inf_x = min(x, inf_x)
                inf_y = min(y, inf_y)
                inf_r = max(x+w, inf_r)
                inf_b = max(y+h, inf_b)
                inf_found = True

    # Only draw a region when at least one cell contributed to it. Otherwise
    # the sentinel extents would produce a rectangle spanning the whole image.
    if gt_found:
        cv2.rectangle(overlay_box, (gt_x,gt_y), (gt_r,gt_b), [180,180,215], -1)
    if inf_found:
        cv2.rectangle(overlay_line, (inf_x+bbox_pad,inf_y+bbox_pad), (inf_r+bbox_pad,inf_b+bbox_pad), [0,115,255], max_len//768*2)

    # legends
    w = shape[1] // data_loader.cols * 4
    h = shape[0] // data_loader.cols * 2
    for i in range(1, len(data_loader.classes)):
        row = i * 3
        col = 0
        x = shape[1] // data_loader.cols * col
        y = shape[0] // data_loader.rows * row
        cv2.rectangle(img, (x,y), (x+w,y+h), gt_color[i], -1)
        cv2.putText(img, data_loader.classes[i], (x+w,y+h), font, 0.8, ft_color)

        row = i * 3 + 1
        col = 0
        x = shape[1] // data_loader.cols * col
        y = shape[0] // data_loader.rows * row
        cv2.rectangle(img, (x+bbox_pad,y+bbox_pad), \
                      (x+bbox_pad+w,y+bbox_pad+h), inf_color[i], max_len//384)

    alpha = 0.4
    cv2.addWeighted(overlay_box, alpha, img, 1-alpha, 0, img)
    cv2.addWeighted(overlay_line, 1-alpha, img, 1, 0, img)
    cv2.imwrite('results/' + file_name[:-4]+'.png', img)
    cv2.imshow("test", img)
    cv2.waitKey(0)
--------------------------------------------------------------------------------
/main_train_json.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-12
3 | # xh.zhao@outlook.com
4 | import tensorflow as tf
5 | import numpy as np
6 | import argparse, os
7 | import timeit
8 | from pprint import pprint
9 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
10 |
11 | from data_loader_json import DataLoader
12 | from utils import *
13 |
14 | from model_cutie_aspp import CUTIERes as CUTIEv1
15 | from model_cutie2_aspp import CUTIE2 as CUTIEv2
16 |
def _str2bool(value):
    """Convert a command-line token to a real boolean.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including ``'False'`` and ``'0'`` -- parses as True.
    This converter makes ``--flag False`` behave as written while still
    accepting the spellings that already worked (``True``, empty string).

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('', 'false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='CUTIE parameters')
# data
parser.add_argument('--use_cutie2', type=_str2bool, default=False) # True to read image from doc_path
parser.add_argument('--doc_path', type=str, default='data/SROIE')
parser.add_argument('--save_prefix', type=str, default='SROIE', help='prefix for ckpt') # TBD: save log/models with prefix
parser.add_argument('--test_path', type=str, default='') # leave empty if no test data provided

# ckpt
parser.add_argument('--restore_ckpt', type=_str2bool, default=False)
parser.add_argument('--restore_bertembedding_only', type=_str2bool, default=False) # effective when restore_ckpt is True
parser.add_argument('--embedding_file', type=str, default='../graph/bert/multi_cased_L-12_H-768_A-12/bert_model.ckpt')
parser.add_argument('--ckpt_path', type=str, default='../graph/CUTIE/graph/')
parser.add_argument('--ckpt_file', type=str, default='meals/CUTIE_highresolution_8x_d20000c9(r80c80)_iter_40000.ckpt')

# dict
parser.add_argument('--load_dict', type=_str2bool, default=True, help='True to work based on an existing dict')
parser.add_argument('--load_dict_from_path', type=str, default='dict/SROIE') # 40000 or 20000TC or table
parser.add_argument('--tokenize', type=_str2bool, default=True) # tokenize input text
parser.add_argument('--text_case', type=_str2bool, default=True) # case sensitive
parser.add_argument('--update_dict', type=_str2bool, default=False)
parser.add_argument('--dict_path', type=str, default='dict/---') # not used if load_dict is True

# data manipulation
parser.add_argument('--segment_grid', type=_str2bool, default=False) # segment grid into two parts if grid is larger than cols_target
parser.add_argument('--rows_segment', type=int, default=72)
parser.add_argument('--cols_segment', type=int, default=72)
parser.add_argument('--augment_strategy', type=int, default=1) # 1 for increasing grid shape size, 2 for gaussian around target shape
parser.add_argument('--positional_mapping_strategy', type=int, default=1)
parser.add_argument('--rows_target', type=int, default=64)
parser.add_argument('--cols_target', type=int, default=64)
parser.add_argument('--rows_ulimit', type=int, default=80) # used when data augmentation is true
parser.add_argument('--cols_ulimit', type=int, default=80)
parser.add_argument('--fill_bbox', type=_str2bool, default=False) # fill bbox with dict_id / label_id

parser.add_argument('--data_augmentation_extra', type=_str2bool, default=True) # randomly expand rows/cols
parser.add_argument('--data_augmentation_dropout', type=float, default=1)
parser.add_argument('--data_augmentation_extra_rows', type=int, default=16)
parser.add_argument('--data_augmentation_extra_cols', type=int, default=16)

# training
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--iterations', type=int, default=40000)
parser.add_argument('--lr_decay_step', type=int, default=13000)
parser.add_argument('--learning_rate', type=float, default=0.0001)
parser.add_argument('--lr_decay_factor', type=float, default=0.1)

# loss optimization
parser.add_argument('--hard_negative_ratio', type=int, help='the ratio between negative and positive losses', default=3)
parser.add_argument('--use_ghm', type=int, default=0) # 1 to use GHM, 0 to not use
parser.add_argument('--ghm_bins', type=int, default=30) # to be tuned
# float, not int: 0.75 is one of the documented settings and int() would reject it
parser.add_argument('--ghm_momentum', type=float, default=0) # 0 / 0.75

# log
parser.add_argument('--log_path', type=str, default='../graph/CUTIE/log/')
parser.add_argument('--log_disp_step', type=int, default=200)
parser.add_argument('--log_save_step', type=int, default=200)
parser.add_argument('--validation_step', type=int, default=200)
parser.add_argument('--test_step', type=int, default=400)
parser.add_argument('--ckpt_save_step', type=int, default=1000)

# model
parser.add_argument('--embedding_size', type=int, default=128) # not used for bert embedding which has 768 as default
parser.add_argument('--weight_decay', type=float, default=0.0005)
parser.add_argument('--eps', type=float, default=1e-6)

# inference
#parser.add_argument('--c_threshold', type=float, default=0.5)
params = parser.parse_args()

# GHM state shared with calc_ghm_weights: bin edges over [0, 1] and the
# per-bin running counts.  The last edge is nudged by eps so that a gradient
# norm of exactly 1.0 still falls inside the final bin.
edges = [float(x)/params.ghm_bins for x in range(params.ghm_bins+1)]
edges[-1] += params.eps
acc_sum = [0.0 for _ in range(params.ghm_bins)]
def calc_ghm_weights(logits, labels):
    """
    calculate gradient harmonizing mechanism weights

    Relies on the module-level ``params`` (ghm_bins, ghm_momentum), ``edges``
    and ``acc_sum``, and on ``num_classes`` which the training script defines
    before calling this function.  ``acc_sum`` is updated in place.

    Args:
        logits: numpy array indexed with four axes (see N below); its flattened
            size must equal ``labels.size * num_classes``.
        labels: numpy array of integer class ids.

    Returns:
        numpy array of per-element loss weights with the same shape as logits.
    """
    bins = params.ghm_bins
    momentum = params.ghm_momentum
    shape = logits.shape

    logits_flat = logits.reshape([-1])
    labels_flat = labels.reshape([-1])

    # one-hot encode the labels into a flat vector laid out like logits_flat
    one_hot = np.zeros(labels_flat.size * num_classes, dtype=np.int64)
    one_hot[np.arange(labels_flat.size) * num_classes + labels_flat] = 1

    grad = np.abs(logits_flat - one_hot) # equation for logits from the sigmoid activation

    weights = np.ones(logits_flat.shape)
    N = shape[0] * shape[1] * shape[2] * shape[3]
    M = 0  # number of non-empty bins
    for i in range(bins):
        # elements whose gradient norm falls into bin i: [edges[i], edges[i+1])
        idxes = np.multiply(grad >= edges[i], grad < edges[i+1])
        num_in_bin = np.sum(idxes)
        if num_in_bin > 0:
            # running count of the bin population, smoothed by the momentum
            acc_sum[i] = momentum * acc_sum[i] + (1-momentum) * num_in_bin
            weights[np.where(idxes)] = N / acc_sum[i]
            M += 1
    if M > 0:
        weights = weights / M

    return weights.reshape(shape)
120 |
def save_ckpt(sess, path, save_prefix, data_loader, network, num_words, num_classes, iter):
    """Save the session to ``<path>/<save_prefix>/`` using the module-level ``ckpt_saver``.

    The file name is built from the network name, dictionary size, class count,
    the data loader's row/column upper limits and the iteration number.  The
    target directory is created when it does not exist yet.
    """
    target_dir = os.path.join(path, save_prefix)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    suffix = '_d{:d}c{:d}(r{:d}c{:d})_iter_{:d}'.format(
        num_words, num_classes, data_loader.rows_ulimit, data_loader.cols_ulimit, iter)
    filename = os.path.join(target_dir, network.name + suffix + '.ckpt')

    ckpt_saver.save(sess, filename)
    print('\nCheckpoint saved to: {:s}\n'.format(filename))
129 |
if __name__ == '__main__':
    # Training entry point: build the data loader and the network, then run the
    # optimisation loop with periodic logging, validation, testing and checkpointing.
    pprint(params)

    # data
    data_loader = DataLoader(params, update_dict=params.update_dict, load_dictionary=params.load_dict, data_split=0.75)
    num_words = max(20000, data_loader.num_words)
    num_classes = data_loader.num_classes
    # NOTE(review): this loop only pulls batches and discards them; it looks like a
    # leftover smoke test of the loader.  Kept because removing it would change
    # which batches the training loop sees -- confirm before deleting.
    for _ in range(2000):
        a = data_loader.next_batch()
        b = data_loader.fetch_validation_data()
        # c = data_loader.fetch_test_data()

    # model
    if params.use_cutie2:
        network = CUTIEv2(num_words, num_classes, params)
    else:
        network = CUTIEv1(num_words, num_classes, params)
    model_loss, regularization_loss, total_loss, model_logits, model_output = network.build_loss()

    # operators
    global_step = tf.Variable(0, trainable=False)
    lr = tf.Variable(params.learning_rate, trainable=False)
    optimizer = tf.train.AdamOptimizer(lr)
    tvars = tf.trainable_variables()
    grads = tf.gradients(total_loss, tvars)
    clipped_grads, norm = tf.clip_by_global_norm(grads, 10.0)
    train_op = optimizer.apply_gradients(list(zip(clipped_grads, tvars)), global_step=global_step)
    # train_dummy is a fetchable tensor that forces train_op to run
    with tf.control_dependencies([train_op]):
        train_dummy = tf.constant(0)

    tf.contrib.training.add_gradients_summaries(zip(clipped_grads, tvars))
    summary_op = tf.summary.merge_all()

    # calculate the number of parameters
    total_parameters = 0
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print(network.name, ': ', total_parameters/1000/1000, 'M parameters \n')

    # training
    loss_curve = []
    training_recall, validation_recall, test_recall = [], [], []
    training_acc_strict, validation_acc_strict, test_acc_strict = [], [], []
    training_acc_soft, validation_acc_soft, test_acc_soft = [], [], []

    # ckpt_saver is read as a global by save_ckpt(); it must keep covering the
    # whole model, so it is never rebound below.
    ckpt_saver = tf.train.Saver(max_to_keep=200)
    summary_path = os.path.join(params.log_path, params.save_prefix, network.name)
    summary_writer = tf.summary.FileWriter(summary_path, tf.get_default_graph(), flush_secs=10)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        iter_start = 0

        # restore parameters
        if params.restore_ckpt:
            if params.restore_bertembedding_only:
                if 'bert' not in network.name:
                    raise Exception('no bert embedding was designed in the built model, '
                                    'switch restore_bertembedding_only off or build a related model')
                ckpt_path = params.embedding_file
                try:
                    # A dedicated Saver restores just the embedding table.  Reusing the
                    # name ckpt_saver here would make every later save_ckpt() call write
                    # checkpoints that contain only that one variable.
                    load_variable = {"bert/embeddings/word_embeddings": network.embedding_table}
                    embedding_saver = tf.train.Saver(load_variable)
                    print('Restoring from {}...'.format(ckpt_path))
                    embedding_saver.restore(sess, ckpt_path)
                    print('Restored from {}'.format(ckpt_path))
                except Exception as err:
                    # keep the original error chained so the real cause stays visible
                    raise Exception('Check your path {:s}'.format(ckpt_path)) from err
            else:
                ckpt_path = os.path.join(params.ckpt_path, params.ckpt_file)
                try:
                    print('Restoring from {}...'.format(ckpt_path))
                    ckpt_saver.restore(sess, ckpt_path)
                    print('Restored from {}'.format(ckpt_path))
                    # The step counter restarts from iter_start (0); the iteration encoded
                    # in the checkpoint file name is deliberately not parsed.
                    sess.run(global_step.assign(iter_start))
                except Exception as err:
                    raise Exception('Check your pretrained {:s}'.format(ckpt_path)) from err

        # iterations
        print(" Let's roll! ")
        for iter in range(iter_start, params.iterations+1):
            timer_start = timeit.default_timer()

            # learning rate decay
            if iter!=0 and iter%params.lr_decay_step==0:
                sess.run(tf.assign(lr, lr.eval()*params.lr_decay_factor))

            data = data_loader.next_batch()
            # The step is split into two partial runs so the GHM weights can be
            # computed in numpy from the logits before the loss is evaluated.
            feeds = [network.data_grid, network.gt_classes, network.data_image, network.ps_1d_indices, network.ghm_weights]
            fetches = [model_loss, regularization_loss, total_loss, summary_op, train_dummy, model_logits, model_output]
            h = sess.partial_run_setup(fetches, feeds)

            # one step inference
            feed_dict = {
                network.data_grid: data['grid_table'],
                network.gt_classes: data['gt_classes']
            }
            if params.use_cutie2:
                feed_dict = {
                    network.data_grid: data['grid_table'],
                    network.gt_classes: data['gt_classes'],
                    network.data_image: data['data_image'],
                    network.ps_1d_indices: data['ps_1d_indices']
                }
            fetches = [model_logits, model_output]
            (model_logit_val, model_output_val) = sess.partial_run(h, fetches, feed_dict)

            # one step training
            ghm_weights = np.ones(np.shape(model_logit_val))
            if params.use_ghm:
                ghm_weights = calc_ghm_weights(np.array(model_logit_val), np.array(data['gt_classes']))
            feed_dict = {
                network.ghm_weights: ghm_weights,
            }
            fetches = [model_loss, regularization_loss, total_loss, summary_op, train_dummy]
            (model_loss_val, regularization_loss_val, total_loss_val, summary_str, _) =\
                sess.partial_run(h, fetches=fetches, feed_dict=feed_dict)

            # calculate training accuracy and display results
            if iter%params.log_disp_step == 0:
                timer_stop = timeit.default_timer()
                print('\t >>time per step: %.2fs <<'%(timer_stop - timer_start))

                recall, acc_strict, acc_soft, res = cal_accuracy(data_loader, np.array(data['grid_table']),
                                                                 np.array(data['gt_classes']), model_output_val,
                                                                 np.array(data['label_mapids']), np.array(data['bbox_mapids']))
                loss_curve += [total_loss_val]
                training_recall += [recall]
                training_acc_strict += [acc_strict]
                training_acc_soft += [acc_soft]

                #print(res.decode())
                print('\nIter: %d/%d, total loss: %.4f, model loss: %.4f, regularization loss: %.4f'%\
                      (iter, params.iterations, total_loss_val, model_loss_val, regularization_loss_val))
                print('LOSS CURVE: ' + ' >'.join(['{:d}:{:.3f}'.
                                  format(i*params.log_disp_step,w) for i,w in enumerate(loss_curve)]))
                print('TRAINING ACC CURVE: ' + ' >'.join(['{:d}:{:.3f}'.
                                  format(i*params.log_disp_step,w) for i,w in enumerate(training_acc_strict)]))
                print('TRAINING ACC (Recall/Acc): %.3f / %.3f (%.3f) | highest %.3f / %.3f (%.3f)'\
                      %(recall, acc_strict, acc_soft, max(training_recall), max(training_acc_strict), max(training_acc_soft)))

            # calculate validation accuracy and display results
            if iter%params.validation_step == 0 and len(data_loader.validation_docs):
                recalls, accs_strict, accs_soft = [], [], []
                for _ in range(len(data_loader.validation_docs)):
                    data = data_loader.fetch_validation_data()
                    grid_tables = data['grid_table']
                    gt_classes = data['gt_classes']

                    feed_dict = {
                        network.data_grid: grid_tables,
                    }
                    if params.use_cutie2:
                        feed_dict = {
                            network.data_grid: grid_tables,
                            network.data_image: data['data_image'],
                            network.ps_1d_indices: data['ps_1d_indices']
                        }
                    fetches = [model_output]
                    [model_output_val] = sess.run(fetches=fetches, feed_dict=feed_dict)
                    recall, acc_strict, acc_soft, res = cal_accuracy(data_loader, np.array(grid_tables),
                                                                     np.array(gt_classes), model_output_val,
                                                                     np.array(data['label_mapids']), np.array(data['bbox_mapids']))
                    recalls += [recall]
                    accs_strict += [acc_strict]
                    accs_soft += [acc_soft]

                recall = sum(recalls) / len(recalls)
                acc_strict = sum(accs_strict) / len(accs_strict)
                acc_soft = sum(accs_soft) / len(accs_soft)
                validation_recall += [recall]
                validation_acc_strict += [acc_strict]
                validation_acc_soft += [acc_soft]
                #print(res.decode()) # show res from the last execution of the loop

                print('VALIDATION ACC (STRICT) CURVE: ' + ' >'.join(['{:d}:{:.3f}'.
                                  format(i*params.validation_step,w) for i,w in enumerate(validation_acc_strict)]))
                print('VALIDATION ACC (SOFT) CURVE: ' + ' >'.join(['{:d}:{:.3f}'.
                                  format(i*params.validation_step,w) for i,w in enumerate(validation_acc_soft)]))
                print('TRAINING RECALL CURVE: ' + ' >'.join(['{:d}:{:.2f}'.
                                  format(i*params.log_disp_step,w) for i,w in enumerate(training_recall)]))
                print('VALIDATION RECALL CURVE: ' + ' >'.join(['{:d}:{:.2f}'.
                                  format(i*params.validation_step,w) for i,w in enumerate(validation_recall)]))

                idx = np.argmax(validation_acc_strict)
                print('VALIDATION Statistic %d(%d) (Recall/Acc): %.3f / %.3f (%.3f) | highest %.3f / %.3f (%.3f) \n'
                      %(iter, idx*params.validation_step, recall, acc_strict, acc_soft,
                        validation_recall[idx], validation_acc_strict[idx], validation_acc_soft[idx]))

                # save best performance checkpoint
                if iter>=params.ckpt_save_step and validation_acc_strict[-1] > max(validation_acc_strict[:-1]+[0]):
                    # save as iter+1 to indicate best validation
                    save_ckpt(sess, params.ckpt_path, params.save_prefix, data_loader, network, num_words, num_classes, iter+1)
                    print('\nBest up-to-date performance validation checkpoint saved.\n')

            # calculate test accuracy and display results
            if params.test_path!='' and iter%params.test_step == 0 and len(data_loader.test_docs):

                recalls, accs_strict, accs_soft = [], [], []
                while True:
                    data = data_loader.fetch_test_data()
                    if data is None:
                        break
                    grid_tables = data['grid_table']
                    gt_classes = data['gt_classes']

                    feed_dict = {
                        network.data_grid: grid_tables,
                    }
                    if params.use_cutie2:
                        feed_dict = {
                            network.data_grid: grid_tables,
                            network.data_image: data['data_image'],
                            network.ps_1d_indices: data['ps_1d_indices']
                        }
                    fetches = [model_output]
                    [model_output_val] = sess.run(fetches=fetches, feed_dict=feed_dict)
                    recall, acc_strict, acc_soft, res = cal_accuracy(data_loader, np.array(grid_tables),
                                                                     np.array(gt_classes), model_output_val,
                                                                     np.array(data['label_mapids']), np.array(data['bbox_mapids']))
                    recalls += [recall]
                    accs_strict += [acc_strict]
                    accs_soft += [acc_soft]

                recall = sum(recalls) / len(recalls)
                acc_strict = sum(accs_strict) / len(accs_strict)
                acc_soft = sum(accs_soft) / len(accs_soft)
                test_recall += [recall]
                test_acc_strict += [acc_strict]
                test_acc_soft += [acc_soft]
                idx = np.argmax(test_acc_strict)
                print('\n TEST ACC (Recall/Acc): %.3f / %.3f (%.3f) | highest %.3f / %.3f (%.3f) \n'
                      %(recall, acc_strict, acc_soft, test_recall[idx], test_acc_strict[idx], test_acc_soft[idx]))
                print('TEST ACC (STRICT) CURVE: ' + ' >'.join(['{:d}:{:.3f}'.
                                  format(i*params.test_step,w) for i,w in enumerate(test_acc_strict)]))
                print('TEST ACC (SOFT) CURVE: ' + ' >'.join(['{:d}:{:.3f}'.
                                  format(i*params.test_step,w) for i,w in enumerate(test_acc_soft)]))
                print('TEST RECALL CURVE: ' + ' >'.join(['{:d}:{:.2f}'.
                                  format(i*params.test_step,w) for i,w in enumerate(test_recall)]))

                # save best performance checkpoint
                if iter>=params.ckpt_save_step and test_acc_strict[-1] > max(test_acc_strict[:-1]+[0]):
                    # save as iter+2 to indicate best test
                    save_ckpt(sess, params.ckpt_path, params.save_prefix, data_loader, network, num_words, num_classes, iter+2)
                    print('\nBest up-to-date performance test checkpoint saved.\n')

            # save checkpoints
            if iter>=params.log_save_step and iter%params.ckpt_save_step == 0:
                save_ckpt(sess, params.ckpt_path, params.save_prefix, data_loader, network, num_words, num_classes, iter)

            # save logs
            if iter>=params.log_save_step and iter%params.log_save_step == 0:
                summary_writer.add_summary(summary_str, iter+1)

    pprint(params)
    pprint('Data rows/cols:{},{}'.format(data_loader.rows, data_loader.cols))
    summary_writer.close()
399 |
--------------------------------------------------------------------------------
/model_framework.py:
--------------------------------------------------------------------------------
1 | # written by Xiaohui Zhao
2 | # 2018-12
3 | # xiaohui.zhao@outlook.com
4 | import tensorflow as tf
5 | import math
6 |
def layer(op):
    """Decorator that turns a layer-building method into a chainable call.

    The decorated method receives the model's pending inputs
    (``self.layer_inputs``) as its ``layer_input`` argument: the single pending
    tensor, or a list when several are pending.  Its result is stored in
    ``self.layers`` under the layer name and fed back through ``self.feed`` so
    that it becomes the input of the next layer.  The wrapper returns ``self``.

    Raises:
        RuntimeError: if no input has been fed before the layer is called.
    """
    import functools  # local import: keeps the file's import header untouched

    @functools.wraps(op)  # keep the wrapped method's name and docstring
    def layer_decorated(self, *args, **kwargs):
        # use the caller's name=..., otherwise an automatically generated one
        name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
        if len(self.layer_inputs) == 0:
            raise RuntimeError('No input variables found for layers %s' % name)
        elif len(self.layer_inputs) == 1:
            layer_input = self.layer_inputs[0]
        else:
            layer_input = list(self.layer_inputs)

        layer_output = op(self, layer_input, *args, **kwargs)

        self.layers[name] = layer_output
        self.feed(layer_output)

        return self
    return layer_decorated
24 |
25 |
26 | class Model(object):
27 | def __init__(self, trainable=True):
28 | self.layers = dict()
29 | self.trainable = trainable
30 |
31 | self.layer_inputs = []
32 | self.setup()
33 |
34 |
def build_loss(self):
    """Abstract hook; the base implementation always raises NotImplementedError."""
    message = 'Must be subclassed.'
    raise NotImplementedError(message)
37 |
38 |
39 | def setup(self):
40 | raise NotImplementedError('Must be subclassed.')
41 |
42 |
@layer
def embed(self, layer_input, vocabulary_size, embedding_size, name, dropout=1, trainable=True):
    """Look up an embedding vector for every id of a rank-3 id tensor.

    A ``[vocabulary_size, embedding_size]`` table named 'weights' is created in
    the ``name`` variable scope with a uniform(-1, 1) initialiser and no
    regulariser.  ``dropout`` is handed to ``tf.nn.dropout`` as its second
    positional argument.  The result has the input's first three dimensions
    plus a trailing ``embedding_size`` axis.
    """
    with tf.variable_scope(name):
        table = self.make_var('weights', [vocabulary_size, embedding_size],
                              tf.random_uniform_initializer(-1.0, 1.0), None, trainable)
        in_shape = tf.shape(layer_input)

        # flatten the ids, look them up, then restore the grid layout
        flat_ids = tf.reshape(layer_input, [-1])
        vectors = tf.nn.embedding_lookup(table, flat_ids)
        vectors = tf.nn.dropout(vectors, dropout)
        return tf.reshape(vectors, [in_shape[0], in_shape[1], in_shape[2], embedding_size])
55 |
56 |
@layer
def bert_embed(self, layer_input, vocab_size, embedding_size=768, use_one_hot_embeddings=False,
               initializer_range=0.02, name="embeddings", trainable=False):
    """Embed word ids with a table created under the "bert/embeddings" variable scope.

    The scope names are fixed in the code; ``name`` only serves as the layer's
    registry key.  The table returned by ``embedding_lookup`` is stored on
    ``self.embedding_table`` and the embedded tensor is returned.
    """
    with tf.variable_scope("bert"), tf.variable_scope("embeddings"):
        # Perform embedding lookup on the word ids.
        lookup = self.embedding_lookup(
            input_ids=layer_input,
            vocab_size=vocab_size,
            embedding_size=embedding_size,
            initializer_range=initializer_range,
            word_embedding_name="word_embeddings",
            use_one_hot_embeddings=use_one_hot_embeddings,
            trainable=trainable)
    embedding_output, embedding_table = lookup
    # the inherited class need a self.embedding_table variable
    self.embedding_table = embedding_table
    return embedding_output
71 |
72 |
@layer
def positional_sampling(self, layer_input, feature_dimension, name='positional_sampling'):
    """Gather feature vectors at given flat positions and lay them out on the grid.

    ``layer_input`` holds, in order: the feature map, the per-sample indices to
    gather, and the grid whose shape defines the output.  The feature map is
    flattened to ``[batch, -1, feature_dimension]``, the indices to
    ``[batch, -1]``, and the gathered vectors are reshaped to the grid's first
    three dimensions plus ``feature_dimension``.
    """
    featuremap = layer_input[0]
    batch_indices = layer_input[1]
    grid = layer_input[2]

    grid_shape = tf.shape(grid)

    flat_features = tf.reshape(featuremap, [grid_shape[0], -1, feature_dimension])
    flat_indices = tf.reshape(batch_indices, [grid_shape[0], -1])
    sampled = tf.batch_gather(flat_features, flat_indices)

    out_shape = [grid_shape[0], grid_shape[1], grid_shape[2], feature_dimension]
    return tf.reshape(sampled, out_shape)
87 |
88 |
@layer
def sepconv(self, layer_input, k_h, k_w, cardinality, compression, name, activation='relu', trainable=True):
    """Customized separable convolution with a residual connection.

    Builds ``cardinality`` parallel branches (1x1 reduce -> k_h x k_w -> 1x1
    expand), sums them and adds the block input.  ReLU is always used; the
    ``activation`` argument is not read.
    """
    def convolve(input, filter):
        return tf.nn.conv2d(input, filter, [1, 1, 1, 1], 'SAME')

    def activate(z):
        return tf.nn.relu(z, 'relu')

    with tf.variable_scope(name):
        init_weights = tf.truncated_normal_initializer(0.0, 0.01)
        init_biases = tf.constant_initializer(0.0)
        regularizer = self.l2_regularizer(self.weight_decay)
        in_channels = layer_input.get_shape().as_list()[-1]

        # per-branch width; note this is a true division, as in the original
        branch_channels = in_channels / cardinality / compression

        branches = []
        for branch in range(cardinality):
            out = self.convolution(convolve, activate, layer_input, 1, 1, in_channels, branch_channels,
                                   init_weights, init_biases, regularizer, trainable, '0_{}'.format(branch))
            out = self.convolution(convolve, activate, out, k_h, k_w, branch_channels, branch_channels,
                                   init_weights, init_biases, regularizer, trainable, '1_{}'.format(branch))
            out = self.convolution(convolve, activate, out, 1, 1, branch_channels, in_channels,
                                   init_weights, init_biases, regularizer, trainable, '2_{}'.format(branch))
            branches.append(out)
        merged = tf.add_n(branches)
        return tf.add(merged, layer_input)
113 |
114 |
@layer
def up_sepconv(self, layer_input, k_h, k_w, cardinality, compression, name, activation='relu', trainable=True):
    """Customized upscaling separable convolution with a residual connection.

    The input is first doubled in height and width with nearest-neighbour
    resizing; the same branch structure as ``sepconv`` is then applied and the
    resized input is added to the summed branches.  ReLU is always used; the
    ``activation`` argument is not read.
    """
    def convolve(input, filter):
        return tf.nn.conv2d(input, filter, [1, 1, 1, 1], 'SAME')

    def activate(z):
        return tf.nn.relu(z, 'relu')

    with tf.variable_scope(name):
        in_shape = tf.shape(layer_input)
        height = in_shape[1]
        width = in_shape[2]
        layer_input = tf.image.resize_nearest_neighbor(layer_input, [2*height, 2*width])

        init_weights = tf.truncated_normal_initializer(0.0, 0.01)
        init_biases = tf.constant_initializer(0.0)
        regularizer = self.l2_regularizer(self.weight_decay)
        in_channels = layer_input.get_shape().as_list()[-1]

        # per-branch width; note this is a true division, as in the original
        branch_channels = in_channels / cardinality / compression

        branches = []
        for branch in range(cardinality):
            out = self.convolution(convolve, activate, layer_input, 1, 1, in_channels, branch_channels,
                                   init_weights, init_biases, regularizer, trainable, '0_{}'.format(branch))
            out = self.convolution(convolve, activate, out, k_h, k_w, branch_channels, branch_channels,
                                   init_weights, init_biases, regularizer, trainable, '1_{}'.format(branch))
            out = self.convolution(convolve, activate, out, 1, 1, branch_channels, in_channels,
                                   init_weights, init_biases, regularizer, trainable, '2_{}'.format(branch))
            branches.append(out)
        merged = tf.add_n(branches)
        return tf.add(merged, layer_input)
143 |
144 |
@layer
def dense_block(self, layer_input, k_h, k_w, c_o, depth, name, activation='relu', trainable=True):
    """Stack ``depth`` bottleneck units (1x1 reduce -> k_h x k_w) on the input.

    Each unit halves the channel count with a 1x1 convolution, applies a
    k_h x k_w convolution producing ``c_o`` channels, and concatenates the
    result with the block input along the channel axis.  ReLU is always used;
    the ``activation`` argument is not read.

    Fixes relative to the previous version:
      * the k_h x k_w kernel is declared with the bottleneck's actual channel
        count (c_i // 2); it was declared with c_i, which cannot match the
        tensor it is applied to;
      * every convolution gets its own name suffix so the units do not request
        the same variable name inside one scope.
    """
    convolve = lambda input, filter: tf.nn.conv2d(input, filter, [1,1,1,1], 'SAME')
    activate = lambda z: tf.nn.relu(z, 'relu')
    with tf.variable_scope(name) as scope:
        init_weights = tf.truncated_normal_initializer(0.0, 0.01)
        init_biases = tf.constant_initializer(0.0)
        regularizer = self.l2_regularizer(self.weight_decay)

        layer_tmp = layer_input
        for d in range(depth):
            c_i = layer_tmp.get_shape()[-1]
            c_mid = c_i // 2
            a = self.convolution(convolve, activate, layer_tmp, 1, 1, c_i, c_mid,
                                 init_weights, init_biases, regularizer, trainable, '{}_0'.format(d))

            a = self.convolution(convolve, activate, a, k_h, k_w, c_mid, c_o,
                                 init_weights, init_biases, regularizer, trainable, '{}_1'.format(d))

            # NOTE(review): concatenates with the block input, not with the running
            # layer_tmp, so earlier units' outputs are dropped when depth > 1 --
            # confirm whether DenseNet-style accumulation was intended.
            layer_tmp = tf.concat([a, layer_input], 3)

        return layer_tmp
166 |
167 |
@layer
def conv(self, layer_input, k_h, k_w, c_o, s_h, s_w, name, activation='relu', trainable=True):
    """2-D convolution with 'SAME' padding, strides (s_h, s_w) and ``c_o`` output channels.

    The activation is a sigmoid when ``activation == 'sigmoid'`` and a ReLU for
    any other value.  Variables are created inside the ``name`` scope.
    """
    def convolve(input, filter):
        return tf.nn.conv2d(input, filter, [1, s_h, s_w, 1], 'SAME')

    if activation == 'sigmoid':
        def activate(z):
            return tf.nn.sigmoid(z, 'sigmoid')
    else:
        def activate(z):
            return tf.nn.relu(z, 'relu')

    with tf.variable_scope(name):
        init_weights = tf.truncated_normal_initializer(0.0, 0.01)
        init_biases = tf.constant_initializer(0.0)
        regularizer = self.l2_regularizer(self.weight_decay)
        in_channels = layer_input.get_shape()[-1]

        return self.convolution(convolve, activate, layer_input, k_h, k_w, in_channels, c_o,
                                init_weights, init_biases, regularizer, trainable)
186 |
187 |
@layer
def dilate_conv(self, layer_input, k_h, k_w, c_o, s_h, s_w, rate, name, activation='relu', trainable=True):
    """Atrous (dilated) convolution with dilation ``rate``, 'SAME' padding and ``c_o`` output channels.

    ``s_h`` and ``s_w`` are accepted but not read.  The activation is a
    sigmoid when ``activation == 'sigmoid'`` and a ReLU for any other value.
    """
    def convolve(input, filter):
        return tf.nn.atrous_conv2d(input, filter, rate, 'SAME', 'DILATE')

    if activation == 'sigmoid':
        def activate(z):
            return tf.nn.sigmoid(z, 'sigmoid')
    else:
        def activate(z):
            return tf.nn.relu(z, 'relu')

    with tf.variable_scope(name):
        init_weights = tf.truncated_normal_initializer(0.0, 0.01)
        init_biases = tf.constant_initializer(0.0)
        regularizer = self.l2_regularizer(self.weight_decay)
        in_channels = layer_input.get_shape()[-1]

        return self.convolution(convolve, activate, layer_input, k_h, k_w, in_channels, c_o,
                                init_weights, init_biases, regularizer, trainable)
205 |
206 |
@layer
def dilate_module(self, layer_input, k_h, k_w, c_o, s_h, s_w, rate, name, activation='relu', trainable=True):
    """Atrous convolution layer; the body currently performs the same steps as ``dilate_conv``.

    ``s_h`` and ``s_w`` are accepted but not read.  The activation is a
    sigmoid when ``activation == 'sigmoid'`` and a ReLU for any other value.
    """
    def convolve(input, filter):
        return tf.nn.atrous_conv2d(input, filter, rate, 'SAME', 'DILATE')

    if activation == 'sigmoid':
        def activate(z):
            return tf.nn.sigmoid(z, 'sigmoid')
    else:
        def activate(z):
            return tf.nn.relu(z, 'relu')

    with tf.variable_scope(name):
        init_weights = tf.truncated_normal_initializer(0.0, 0.01)
        init_biases = tf.constant_initializer(0.0)
        regularizer = self.l2_regularizer(self.weight_decay)
        in_channels = layer_input.get_shape()[-1]

        return self.convolution(convolve, activate, layer_input, k_h, k_w, in_channels, c_o,
                                init_weights, init_biases, regularizer, trainable)
224 |
225 |
@layer
def up_conv(self, layer_input, k_h, k_w, c_o, s_h, s_w, name, factor=2, activation='relu', trainable=True):
    """Upscale by ``factor`` with nearest-neighbour resizing, then convolve.

    The convolution uses 'SAME' padding, strides (s_h, s_w), ``c_o`` output
    channels and a ReLU; the ``activation`` argument is not read.
    """
    def convolve(input, filter):
        return tf.nn.conv2d(input, filter, [1, s_h, s_w, 1], 'SAME')

    def activate(z):
        return tf.nn.relu(z, 'relu')

    with tf.variable_scope(name):
        in_shape = tf.shape(layer_input)
        height = in_shape[1]
        width = in_shape[2]
        layer_input = tf.image.resize_nearest_neighbor(layer_input, [factor*height, factor*width])

        init_weights = tf.truncated_normal_initializer(0.0, 0.01)
        init_biases = tf.constant_initializer(0.0)
        regularizer = self.l2_regularizer(self.weight_decay)
        in_channels = layer_input.get_shape()[-1]

        return self.convolution(convolve, activate, layer_input, k_h, k_w, in_channels, c_o,
                                init_weights, init_biases, regularizer, trainable)
245 |
246 |
@layer
def attention(self, layer_input, num_heads, name, att_dropout=0.0, hidden_dropout=0.1, trainable=True):
    """
    implement self attention with residual addition,
    layer_input[0] and layer_input[1] should have the same shape for residual addition

    Query, key and value are 1x1 convolutions of layer_input[0]; query and key
    have ``channels // num_heads`` channels.  The attended features pass
    through one more 1x1 convolution, layer_input[1] is added, and the sum is
    instance-normalised without centre/scale parameters.  ``att_dropout`` and
    ``hidden_dropout`` are currently not applied.
    """
    f = layer_input[0]
    x = layer_input[1]

    def convolve(input, filter):
        return tf.nn.conv2d(input, filter, [1, 1, 1, 1], 'SAME')

    with tf.variable_scope(name):
        init_weights = tf.truncated_normal_initializer(0.0, 0.02)
        regularizer = self.l2_regularizer(self.weight_decay)
        shape = tf.shape(f)
        c_i = f.get_shape()[-1]
        c_o = f.get_shape()[-1]
        c_a = c_o // num_heads # attention kernel depth, size per head

        def project(var_name, depth):
            # 1x1 convolution of f, flattened to [B, H*W, depth]
            kernel = self.make_var(var_name, [1, 1, c_i, depth], init_weights, regularizer, trainable)
            return tf.reshape(convolve(f, kernel), [shape[0], -1, depth])

        query_layer = project('weights_query', c_a)  # [B, H*W, c_a]
        key_layer = project('weights_key', c_a)      # [B, H*W, c_a]
        value_layer = project('weights_value', c_o)  # [B, H*W, c_o]

        # scaled dot-product scores between every pair of positions
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) # [B, H*W, H*W]
        attention_scores = tf.multiply(attention_scores, 1.0 / math.sqrt(float(c_a.value)))
        attention_probs = tf.nn.softmax(attention_scores)

        context_layer = tf.matmul(attention_probs, value_layer) # [B, H*W, c_o]
        context_layer = tf.reshape(context_layer, shape) # [B, H, W, c_o]

        out_kernel = self.make_var('output', [1, 1, c_o, c_o], init_weights, regularizer, trainable)
        attention_output = convolve(context_layer, out_kernel)
        attention_output = attention_output + x

        return tf.contrib.layers.instance_norm(attention_output, center=False, scale=False)
292 |
293 |
@layer
def concat(self, layer_input, axis, name):
    """Concatenate the pending input tensors along ``axis``."""
    joined = tf.concat(values=layer_input, axis=axis)
    return joined
297 |
298 |
@layer
def add(self, layer_input, name):
    """Element-wise sum of all pending input tensors."""
    total = tf.math.add_n(inputs=layer_input)
    return total
302 |
303 |
@layer
def max_pool(self, layer_input, k_h, k_w, s_h, s_w, name, padding='SAME'):
    """Max pooling with window (k_h, k_w) and strides (s_h, s_w)."""
    window = [1, k_h, k_w, 1]
    strides = [1, s_h, s_w, 1]
    return tf.nn.max_pool(layer_input, ksize=window, strides=strides, padding=padding, name=name)
307 |
308 |
@layer
def global_pool(self, layer_input, name):
    """Average over the two spatial axes, then resize the 1x1 result back to the input's height and width."""
    dims = tf.shape(layer_input)
    height, width = dims[1], dims[2]
    pooled = tf.reduce_mean(layer_input, [1, 2], keepdims=True, name=name)
    return tf.image.resize_nearest_neighbor(pooled, [height, width])
316 |
317 |
@layer
def softmax(self, layer_input, name):
    """Apply ``tf.nn.softmax`` to the pending input and name the op ``name``."""
    probabilities = tf.nn.softmax(layer_input, name=name)
    return probabilities
321 |
322 |
def embedding_lookup(self, input_ids, vocab_size, embedding_size=768,
                     initializer_range=0.02, word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False, trainable=False):
    """Looks up words embeddings for id tensor.

    A table of 119547 rows (hard-coded as ``bert_vocab_size``) is always
    created under ``word_embedding_name``.  When ``vocab_size`` exceeds that,
    an extra, always-trainable ``<word_embedding_name>_plus`` table supplies
    the additional rows and the two are concatenated for the lookup.

    Args:
      input_ids: integer id Tensor. A rank-3 input gets a trailing axis of
        size 1 appended before the lookup.
      vocab_size: int. Size of the embedding vocabulary.
      embedding_size: int. Width of the word embeddings.
      initializer_range: float. Embedding initialization range.
      word_embedding_name: string. Name of the embedding table.
      use_one_hot_embeddings: bool. If True, use one-hot method for word
        embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better
        for TPUs.
      trainable: bool. Whether the base table is trainable.

    Returns:
      Tuple ``(output, bert_embedding_table)``: ``output`` is a float Tensor
      shaped like ``input_ids`` (after the optional expansion) with the last
      axis multiplied by ``embedding_size``; ``bert_embedding_table`` is the
      base table only, without the ``_plus`` extension.
    """
    bert_vocab_size = 119547
    # This function assumes that the input is of shape [batch_size, rows, cols,
    # num_inputs].
    #
    # If the input is a 3D tensor of shape [batch_size, rows, cols], we
    # reshape to [batch_size, rows, cols, 1].
    if input_ids.shape.ndims == 3: # originally 2
        input_ids = tf.expand_dims(input_ids, axis=[-1])

    bert_embedding_table = embedding_table = tf.get_variable(
        name=word_embedding_name,
        shape=[bert_vocab_size, embedding_size],
        initializer=tf.truncated_normal_initializer(stddev=initializer_range),
        trainable=trainable)
    table_rows = bert_vocab_size
    if vocab_size > bert_vocab_size: # handle dict augmentation
        embedding_table_plus = tf.get_variable(
            name=word_embedding_name + '_plus',
            shape=[vocab_size-bert_vocab_size, embedding_size],
            initializer=tf.truncated_normal_initializer(stddev=initializer_range),
            trainable=True)
        embedding_table = tf.concat([embedding_table, embedding_table_plus], 0)
        table_rows = vocab_size

    if use_one_hot_embeddings:
        flat_input_ids = tf.reshape(input_ids, [-1])
        # The one-hot depth must equal the number of table rows for the matmul;
        # using vocab_size alone mismatched whenever vocab_size < bert_vocab_size.
        one_hot_input_ids = tf.one_hot(flat_input_ids, depth=table_rows)
        output = tf.matmul(one_hot_input_ids, embedding_table)
    else:
        output = tf.nn.embedding_lookup(embedding_table, input_ids)

    input_shape = self.get_shape_list(input_ids)

    output = tf.reshape(output,
                        input_shape[0:-1] + [input_shape[-1] * embedding_size])
    return (output, bert_embedding_table)
376 |
def get_shape_list(self, tensor, expected_rank=None, name=None):
    """Return the shape of `tensor` as a list, using static sizes where known.

    Args:
      tensor: A tf.Tensor object whose shape is wanted.
      expected_rank: (optional) int. When given, `assert_rank` is called with
        the tensor, this rank and the name before the shape is read.
      name: Optional tensor name for the rank check; defaults to
        `tensor.name`.

    Returns:
      A list with one entry per dimension. Statically known dimensions are
      python integers; unknown ones are scalar slices of `tf.shape(tensor)`.
    """
    if name is None:
        name = tensor.name

    if expected_rank is not None:
        # NOTE(review): `assert_rank` is not defined in the visible part of
        # this file -- confirm it exists at module level.
        assert_rank(tensor, expected_rank, name)

    dims = tensor.shape.as_list()
    unknown_axes = [axis for axis, size in enumerate(dims) if size is None]

    if unknown_axes:
        # Only build the dynamic shape op when at least one size is unknown.
        runtime_shape = tf.shape(tensor)
        for axis in unknown_axes:
            dims[axis] = runtime_shape[axis]
    return dims
412 |
413 |
def convolution(self, convolve, activate, input, k_h, k_w, c_i, c_o, init_weights, init_biases,
                regularizer, trainable, name=''):
    """Build a convolve -> bias-add -> activate -> instance-norm stage.

    Args:
      convolve: Callable taking `(input, kernel)` and returning the convolved tensor.
      activate: Callable applied to the bias-added convolution result.
      input: Tensor fed to `convolve`.
      k_h, k_w: Kernel height and width.
      c_i, c_o: Number of input and output channels.
      init_weights: Initializer for the kernel variable.
      init_biases: Initializer for the bias variable.
      regularizer: Regularizer attached to the kernel only (biases get none).
      trainable: Whether the kernel and bias variables are trainable.
      name: Suffix appended to the variable names 'weights' and 'biases'.

    Returns:
      The activated output after `tf.contrib.layers.instance_norm` with
      centering and scaling disabled.
    """
    kernel_shape = [k_h, k_w, c_i, c_o]
    filters = self.make_var('weights' + name, kernel_shape, init_weights, regularizer, trainable)
    offsets = self.make_var('biases' + name, [c_o], init_biases, None, trainable)

    for tag, variable in (('w', filters), ('b', offsets)):
        tf.summary.histogram(tag, variable)

    # Order used here: convolve, add bias, activate, then normalize.
    pre_activation = tf.nn.bias_add(convolve(input, filters), offsets)
    activated = activate(pre_activation)
    return tf.contrib.layers.instance_norm(activated, center=False, scale=False)
425 |
426 |
def l2_regularizer(self, weight_decay=0.0005, scope=None):
    """Create a callable computing `weight_decay * tf.nn.l2_loss(tensor)`.

    Args:
      weight_decay: Multiplier applied to the L2 loss of the tensor.
      scope: Optional name scope; 'l2_regularizer' is used as default name.

    Returns:
      A function mapping a tensor to its decayed L2 penalty tensor, whose
      final op is named 'decayed_value'.
    """
    def regularizer(tensor):
        with tf.name_scope(scope, default_name='l2_regularizer', values=[tensor]):
            decay = tf.convert_to_tensor(weight_decay, name='weight_decay')
            penalty = tf.nn.l2_loss(tensor)
            return tf.multiply(decay, penalty, name='decayed_value')

    return regularizer
433 |
434 |
def make_var(self, name, shape, initializer=None, regularizer=None, trainable=True):
    """Create or fetch a variable through `tf.get_variable`.

    Args:
      name: Variable name.
      shape: Variable shape.
      initializer: Optional initializer forwarded to `tf.get_variable`.
      regularizer: Optional regularizer forwarded to `tf.get_variable`.
      trainable: Whether the variable is trainable.

    Returns:
      The variable returned by `tf.get_variable`.
    """
    variable = tf.get_variable(
        name,
        shape,
        initializer=initializer,
        regularizer=regularizer,
        trainable=trainable,
    )
    return variable
437 |
438 |
def feed(self, *args):
    """Set the inputs for the next layer operation.

    String arguments are treated as layer names and replaced by the matching
    entry of `self.layers` (which is also printed); any other argument is
    used as-is.

    Args:
      *args: One or more layer names or layer objects.

    Returns:
      `self`, so calls can be chained.

    Raises:
      KeyError: If a string argument is not a key of `self.layers`.
    """
    assert len(args) != 0

    # Reset first so a failed lookup leaves only the inputs resolved so far.
    self.layer_inputs = []
    for entry in args:
        if isinstance(entry, str):
            try:
                entry = self.layers[entry]
                print(entry)
            except KeyError:
                # Show the known layer names to help diagnose the bad name.
                print(list(self.layers.keys()))
                raise KeyError('Unknown layer name fed: %s' % entry)
        self.layer_inputs.append(entry)
    return self
453 |
454 |
def get_output(self, layer):
    """Return the entry of `self.layers` stored under `layer`.

    Args:
      layer: Key to look up in `self.layers`.

    Returns:
      The stored layer output.

    Raises:
      KeyError: If `layer` is not a key of `self.layers`; the known keys are
        printed before raising.
    """
    try:
        return self.layers[layer]
    except KeyError:
        print(list(self.layers.keys()))
        raise KeyError('Unknown layer name fed: %s' % layer)
462 |
463 |
def get_unique_name(self, prefix):
    """Return `'<prefix>_<n>'` where n is one more than the number of
    existing layer names in `self.layers` that start with `prefix`.

    Args:
      prefix: Name prefix to count and to build the new name from.

    Returns:
      The generated name string.
    """
    matches = 0
    for layer_name, _unused in self.layers.items():
        if layer_name.startswith(prefix):
            matches += 1
    return '%s_%d' % (prefix, matches + 1)
--------------------------------------------------------------------------------