├── README.MD
├── create_dataset.py
├── preprocess_data.py
├── train.py
└── convertXMLtoPCKL.py

/README.MD:
--------------------------------------------------------------------------------
# patent-classification
Implementation of the "Optimizing neural networks for patent classification" paper for the WIPO-alpha dataset.

## How to install

Install the following requirements:
- python3
- pyfasttext
- keras

Download the WIPO-alpha dataset and extract it into `resources/wipo-alpha`.

Download a fastText word embedding model and put it in `resources`.

Create the dataset by executing:

```
python create_dataset.py -train
python create_dataset.py -test
```

Train the model by executing:
```
python train.py -embedding <fasttext_model_file>
```
where `<fasttext_model_file>` is the name of the fastText model file inside `resources`.
--------------------------------------------------------------------------------
/create_dataset.py:
--------------------------------------------------------------------------------
from convertXMLtoPCKL import convert
import argparse
import os.path
import pickle

parser = argparse.ArgumentParser()
parser.add_argument('-train', action="store_true")
parser.add_argument('-test', action="store_true")

args = vars(parser.parse_args())

if (args['train'] and args['test']) or (not args['train'] and not args['test']):
    print("Error in arguments")
    print("run \"python create_dataset.py\" with either the -train or the -test argument")
    exit()

if args['train']:
    # Training mode: build the label alphabet from scratch.
    LabelsIdDict = {}
    convert(LabelsIdDict)
else:
    # Test mode: reuse the label alphabet created during the training run.
    labels_ID_path = './resources/labels_ID.pkl'
    if os.path.isfile(labels_ID_path):
        with open(labels_ID_path, 'rb') as pckl:
            LabelsIdDict = pickle.load(pckl)
        convert(LabelsIdDict)
    else:
        print("run code with \"-train\" argument " +
              "first to create label alphabet")
        exit()
--------------------------------------------------------------------------------
/preprocess_data.py:
--------------------------------------------------------------------------------
import numpy as np
import pickle
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
import os.path
from pyfasttext import FastText


def readData(file_name, max_length):
    """Load pickled texts and truncate each document to max_length tokens."""
    data_path = './resources/' + file_name
    if not os.path.isfile(data_path):
        print(file_name + " not found")
        exit()
    with open(data_path, 'rb') as pckl:
        text = pickle.load(pckl)
    for o, doc in enumerate(text):
        text[o] = " ".join(doc.split()[:max_length])
    return text


def editData(text_file, label_file, max_length, tokenizer):
    """Turn pickled texts and labels into padded integer sequences."""
    text_path = './resources/' + text_file
    if not os.path.isfile(text_path):
        print(text_file + " not found")
        exit()
    with open(text_path, 'rb') as pckl:
        texts = pickle.load(pckl)
    for o, doc in enumerate(texts):
        texts[o] = " ".join(doc.split()[:max_length])
    sequences = tokenizer.texts_to_sequences(texts)
    del texts
    data = pad_sequences(sequences, maxlen=max_length,
                         padding='post', truncating='post')
    del sequences
    label_path = './resources/' + label_file
    if not os.path.isfile(label_path):
        print(label_file + " not found")
        exit()
    with open(label_path, 'rb') as pckl:
        labels = pickle.load(pckl)
    data = data.astype(np.uint16)
    return data, labels


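# preprocess() ties the helpers above together: it fits a Keras Tokenizer on
# the training texts, converts the train and test splits into padded integer
# sequences with editData(), and builds an embedding matrix in which row i
# holds the fastText vector of the word with tokenizer index i (row 0 is left
# as zeros and acts as the padding row).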
def preprocess(fasttext_name, embedding_dim, max_length,
               max_num_words):
    fastmodel = FastText('./resources/' + fasttext_name)
    texts_tokenize = readData('train_texts.pkl', max_length)
    print("Tokenizing data ...")
    tokenizer = Tokenizer(num_words=max_num_words)
    tokenizer.fit_on_texts(texts_tokenize)
    print("Tokenization and fitting done!")
    print("Loading data ...")
    x_train, y_train = editData('train_texts.pkl', 'train_labels.pkl',
                                max_length, tokenizer)
    print("Training data loaded")
    x_val, y_val = editData('test_texts.pkl', 'test_labels.pkl',
                            max_length, tokenizer)
    print("Test data loaded")
    word_index = tokenizer.word_index
    print('Preparing embedding matrix ...')
    embedding_matrix = np.zeros((max_num_words, embedding_dim))
    for word, i in word_index.items():
        if i >= max_num_words:
            continue
        embedding_matrix[i] = fastmodel[word]
    print("Embedding done!")
    return x_train, y_train, x_val, y_val, embedding_matrix
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import numpy as np
from keras.utils import to_categorical
from keras.layers import Dense, Input, Embedding, Conv2D, MaxPool2D
from keras.layers import Reshape, Flatten, Dropout, Concatenate
from keras.models import Model
from keras.optimizers import Adam
from preprocess_data import preprocess
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-embedding', action='store', dest='embedding')
fasttext_name = parser.parse_args().embedding

# Hyperparameters
embedding_dim = 300
learning_rate = 0.001145
bs = 128
drop = 0.2584
max_length = 1431
max_num_words = 23140
filter_sizes = [6]
num_filters = 2426
nclasses = 451

x_train, y_train, x_val, y_val, embedding_matrix = preprocess(fasttext_name,
                                                              embedding_dim,
                                                              max_length,
                                                              max_num_words)

print("Starting training ...")

embedding_layer = Embedding(max_num_words,
                            embedding_dim,
                            weights=[embedding_matrix],
                            input_length=max_length,
                            trainable=True)

# uint16 inputs keep the padded index sequences small in memory
sequence_input = Input(shape=(max_length,), dtype='uint16')
embedded_sequences = embedding_layer(sequence_input)
reshape = Reshape((max_length, embedding_dim, 1))(embedded_sequences)

# One convolution + max-over-time pooling block per filter size
maxpool_blocks = []
for filter_size in filter_sizes:
    conv = Conv2D(num_filters, kernel_size=(filter_size, embedding_dim),
                  padding='valid', activation='relu',
                  kernel_initializer='he_uniform',
                  bias_initializer='zeros')(reshape)
    maxpool = MaxPool2D(pool_size=(max_length - filter_size + 1, 1),
                        strides=(1, 1), padding='valid')(conv)
    maxpool_blocks.append(maxpool)

if len(maxpool_blocks) > 1:
    concatenated_tensor = Concatenate(axis=1)(maxpool_blocks)
else:
    concatenated_tensor = maxpool_blocks[0]

flatten = Flatten()(concatenated_tensor)
dropout = Dropout(drop)(flatten)
output = Dense(units=nclasses, activation='softmax')(dropout)

model = Model(inputs=sequence_input, outputs=output)

adam = Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999,
            epsilon=1e-08, decay=0.0)
model.compile(optimizer=adam,
              loss='categorical_crossentropy',
              metrics=['accuracy'])
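
# Optional sanity check: print the architecture and parameter counts.
model.summary()

# One-hot encode the labels; casting to float16 keeps the
# (samples, nclasses) label matrices small in memory.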
y_train = to_categorical(np.asarray(y_train),
                         num_classes=nclasses).astype(np.float16)
y_val = to_categorical(np.asarray(y_val),
                       num_classes=nclasses).astype(np.float16)

history = model.fit(x_train, y_train,
                    batch_size=bs, shuffle=True,
                    epochs=20,
                    validation_data=(x_val, y_val))
--------------------------------------------------------------------------------
/convertXMLtoPCKL.py:
--------------------------------------------------------------------------------
import xml.etree.ElementTree as ET
import os
import os.path
import pickle
import re


def convert(labels_ID):
    # An empty labels_ID dictionary means the label alphabet still has to be
    # built, i.e. we are processing the training split; otherwise it is reused
    # for the test split.
    status = ''
    if not labels_ID:
        status = 'train'
        train = True
        number_of_labels = 0
    else:
        status = 'test'
        train = False

    base_dir = './resources/wipo-alpha/' + status

    if not os.path.isdir(base_dir):
        print("Dataset not found")
        print("please download the wipo-alpha dataset to the resources folder")
        exit()

    all_files_text = './resources/' + status + 'Directory.txt'
    cutoff = 2500

    texts = []
    labels = []
    all_files = open(all_files_text, 'r')
    for entry in all_files:
        path = base_dir + entry[1:]
        path = path.replace('\n', '')
        try:
            tree = ET.parse(path)
        except:
            print("File " + path + " could not be parsed")
            continue
        root = tree.getroot()
        text = ''
        label = None
        found_desc = 0
        text_claims = ''  # claims are appended after the description
        for child in root:
            if child.attrib.get('mc') is not None:
                label = child.attrib['mc'][:4]
                for child_child in child:
                    if child_child.tag in ['ti']:
                        text += child_child.text + ' '
                    if child_child.tag in ['ab', 'txt']:
                        if child_child.tag in ['txt']:
                            found_desc = 1
                        try:
                            text += child_child.text + ' '
                        except:
                            pass
                    if child_child.tag in ['cl']:
                        try:
                            text_claims += child_child.text + ' '
                        except:
                            pass
        if label is None:
            # skip records that carry no main classification code
            continue
        if text_claims != '':
            if found_desc:
                found_desc = 0
                text += text_claims

        text = text.replace('\n', ' ')
        text = re.sub('<[^>]+>', '', text)  # remove tags
        regex = re.compile('[^a-zA-Z]')  # remove all but alphabetic characters
        text = regex.sub(' ', text)
        text = re.sub(r'\b\w{1,2}\b', '', text)  # remove words shorter than 3 characters
        text = " ".join(text.split()[:cutoff]).lower()

        if train:
            if label not in labels_ID:
                labels_ID[label] = number_of_labels
                number_of_labels += 1
        if label not in labels_ID:
            continue
        id_of_label = labels_ID.get(label)
        texts.append(text)
        labels.append(id_of_label)
    all_files.close()
    save_data(status, texts, labels, labels_ID)


def save_data(status, texts, labels, labels_ID):
    with open('./resources/' + status + '_texts.pkl', 'wb') as pckl:
        pickle.dump(texts, pckl)
    with open('./resources/' + status + '_labels.pkl', 'wb') as pckl:
        pickle.dump(labels, pckl)
    if status == 'train':
        with open('./resources/labels_ID.pkl', 'wb') as pckl:
            pickle.dump(labels_ID, pckl)
    print('Done! Number of classes of ' + status + ' documents is ' +
          str(len(labels_ID)))
    print("Total number of samples is " + str(len(texts)))
--------------------------------------------------------------------------------