├── saved_models
    └── saved_model_path.txt
├── pictures
    ├── elman.png
    ├── jordan.png
    ├── dialogue_system.png
    └── output_file_format.png
├── models
    ├── __pycache__
    │   └── rnn.cpython-36.pyc
    └── rnn.py
├── eval
    ├── conlleval.md
    └── conlleval.pl
├── atisdata.py
├── README.md
├── LICENSE
├── data
    └── atis_slot_names.txt
└── main.py
/saved_models/saved_model_path.txt: -------------------------------------------------------------------------------- 1 | saved model path 2 | -------------------------------------------------------------------------------- /pictures/elman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llhthinker/slot-filling/HEAD/pictures/elman.png -------------------------------------------------------------------------------- /pictures/jordan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llhthinker/slot-filling/HEAD/pictures/jordan.png -------------------------------------------------------------------------------- /pictures/dialogue_system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llhthinker/slot-filling/HEAD/pictures/dialogue_system.png -------------------------------------------------------------------------------- /pictures/output_file_format.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llhthinker/slot-filling/HEAD/pictures/output_file_format.png -------------------------------------------------------------------------------- /models/__pycache__/rnn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llhthinker/slot-filling/HEAD/models/__pycache__/rnn.cpython-36.pyc -------------------------------------------------------------------------------- /eval/conlleval.md: 
-------------------------------------------------------------------------------- 1 | For example, the data format is as follows: 2 | 3 | word\tPOS tag\tnamed entity\tcorrect label\tpredicted label 4 | 5 | ![img](../pictures/output_file_format.png) 6 | 7 | Make sure that the last column is the label predicted by the model and the second-to-last column is the correct label. 8 | 9 | Assume the file is named output.txt. 10 | 11 | It can then be evaluated with the following command: 12 | 13 | ```bash 14 | perl ./conlleval.pl -d "\t" < output.txt 15 | ``` 16 | 17 | The -d option specifies the delimiter between the fields on each line; the default is a single space. -------------------------------------------------------------------------------- /atisdata.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | 3 | class ATISData(data.Dataset): 4 | def __init__(self, X, y): 5 | self.len = len(X) 6 | self.x_data = X 7 | self.y_data = y 8 | 9 | def __getitem__(self, index): 10 | return self.x_data[index], self.y_data[index] 11 | 12 | def __len__(self): 13 | return self.len 14 | 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Requirements 2 | 3 | - `PyTorch 0.3.0` or newer 4 | - `Python 3.6` 5 | - `Perl` 6 | 7 | ## Usage 8 | 9 | - Training and Prediction 10 | 11 | ```bash 12 | python main.py [-h] [--train-data-path TRAIN_DATA_PATH] 13 | [--test-data-path TEST_DATA_PATH] 14 | [--slot-names-path SLOT_NAMES_PATH] 15 | [--saved-model-path SAVED_MODEL_PATH] 16 | [--result-path RESULT_PATH] [--mode {elman,jordan,hybrid,lstm}] 17 | [--bidirectional] [--cuda] 18 | ``` 19 | 20 | - [Evaluation](./eval/conlleval.md) 21 | 22 | ```bash 23 | perl eval/conlleval.pl -d "\t" < data/output.txt 24 | ``` 25 | 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the 
"Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/atis_slot_names.txt: -------------------------------------------------------------------------------- 1 | compartment 2 | stoploc.state_code 3 | depart_date.today_relative 4 | arrive_date.date_relative 5 | depart_date.date_relative 6 | return_date.month_name 7 | depart_date.day_name 8 | fromloc.airport_code 9 | cost_relative 10 | connect 11 | return_time.period_mod 12 | arrive_time.period_mod 13 | flight_number 14 | depart_time.time_relative 15 | arrive_time.period_of_day 16 | depart_time.period_of_day 17 | fare_amount 18 | city_name 19 | depart_date.day_number 20 | toloc.state_code 21 | arrive_date.month_name 22 | stoploc.airport_code 23 | arrive_date.today_relative 24 | airport_code 25 | arrive_time.start_time 26 | period_of_day 27 | arrive_time.time 28 | toloc.state_name 29 | booking_class 30 | arrive_time.end_time 31 | meal 32 | arrive_time.time_relative 33 | return_date.day_number 34 | day_name 35 | or 36 | economy 37 | 
fromloc.airport_name 38 | return_date.day_name 39 | class_type 40 | meal_code 41 | depart_time.time 42 | return_date.today_relative 43 | round_trip 44 | restriction_code 45 | fare_basis_code 46 | flight 47 | airline_name 48 | time_relative 49 | airline_code 50 | fromloc.state_name 51 | flight_stop 52 | day_number 53 | flight_mod 54 | depart_time.start_time 55 | today_relative 56 | arrive_date.day_number 57 | arrive_date.day_name 58 | depart_time.period_mod 59 | mod 60 | depart_date.month_name 61 | flight_days 62 | stoploc.airport_name 63 | flight_time 64 | fromloc.city_name 65 | transport_type 66 | return_time.period_of_day 67 | state_code 68 | toloc.country_name 69 | return_date.date_relative 70 | depart_date.year 71 | toloc.city_name 72 | time 73 | airport_name 74 | stoploc.city_name 75 | meal_description 76 | toloc.airport_code 77 | days_code 78 | toloc.airport_name 79 | state_name 80 | depart_time.end_time 81 | aircraft_code 82 | month_name 83 | fromloc.state_code 84 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | import numpy as np 5 | from collections import Counter 6 | 7 | from builddataset import * 8 | from atisdata import ATISData 9 | from models.rnn import SlotFilling 10 | 11 | import torch 12 | from torch import nn, optim 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | 16 | 17 | def train(train_data_path, test_data_path, slot_names_path, mode, bidirectional, saved_model_path, cuda): 18 | train_data = load_data(train_data_path) 19 | label2idx, idx2label = build_label_vocab(slot_names_path) 20 | word2idx, idx2word = build_vocab(train_data) 21 | train_X, train_y = build_dataset(train_data, word2idx, label2idx) 22 | train_set = ATISData(train_X, train_y) 23 | train_loader = DataLoader(dataset=train_set, 24 | batch_size=1, 25 | 
shuffle=True) 26 | 27 | test_data = load_data(test_data_path) 28 | test_X, test_y = build_dataset(test_data, word2idx, label2idx) 29 | test_set = ATISData(test_X, test_y) 30 | test_loader = DataLoader(dataset=test_set, 31 | batch_size=1, 32 | shuffle=False) 33 | 34 | vocab_size = len(word2idx) 35 | label_size = len(label2idx) 36 | 37 | model = SlotFilling(vocab_size, label_size, mode=mode, bidirectional=bidirectional, cuda=cuda) 38 | if cuda: 39 | model = model.cuda() 40 | loss_fn = nn.CrossEntropyLoss() 41 | optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 42 | epoch_num = 10 43 | print_step = 1000 44 | for epoch in range(epoch_num): 45 | start_time = time.time() 46 | running_loss, count = 0.0, 0 47 | model.is_training = True # do_eval() turns dropout off, so re-enable it each epoch 48 | for X, y in train_loader: 49 | optimizer.zero_grad() 50 | if torch.__version__ < "0.4.*": 51 | X, y = Variable(X), Variable(y) 52 | if cuda: 53 | X, y = X.cuda(), y.cuda() 54 | output = model(X) 55 | output = output.squeeze(0) 56 | y = y.squeeze(0) 57 | # print(output.size()) 58 | # print(y.size()) 59 | sentence_len = y.size(0) 60 | loss = loss_fn(output, y) 61 | loss.backward() 62 | optimizer.step() 63 | if torch.__version__ < "0.4.*": 64 | running_loss += loss.data[0] / sentence_len 65 | else: 66 | running_loss += loss.item() / sentence_len 67 | count += 1 68 | if count % print_step == 0: 69 | print("epoch: %d, loss: %.4f" % (epoch, running_loss / print_step)) 70 | running_loss = 0.0 71 | count = 0 72 | print("time: ", time.time() - start_time) 73 | do_eval(model, test_loader, cuda) 74 | torch.save(model.state_dict(), saved_model_path) 75 | 76 | 77 | def predict(train_data_path, test_data_path, slot_names_path, mode, bidirectional, saved_model_path, result_path, cuda): 78 | train_data = load_data(train_data_path) 79 | label2idx, idx2label = build_label_vocab(slot_names_path) 80 | word2idx, idx2word = build_vocab(train_data) 81 | 82 | test_data = load_data(test_data_path) 83 | test_X, test_y = build_dataset(test_data, word2idx, label2idx) 84 | test_set = 
ATISData(test_X, test_y) 85 | test_loader = DataLoader(dataset=test_set, 86 | batch_size=1, 87 | shuffle=False) 88 | 89 | vocab_size = len(word2idx) 90 | label_size = len(label2idx) 91 | 92 | model = SlotFilling(vocab_size, label_size, mode=mode, bidirectional=bidirectional, cuda=cuda) 93 | model.load_state_dict(torch.load(saved_model_path)) 94 | if cuda: 95 | model = model.cuda() 96 | predicted = do_eval(model, test_loader, cuda) 97 | predicted_labels = [idx2label[idx] for idx in predicted] 98 | gen_result_file(test_data, predicted_labels, result_path) 99 | 100 | def gen_result_file(test_data, predicted, result_path): 101 | f = open(result_path, 'w', encoding='utf-8') 102 | idx = 0 103 | for sentence, true_labels in test_data: 104 | for word, true_label in zip(sentence, true_labels): 105 | predicted_label = predicted[idx] 106 | idx += 1 107 | f.write(word + "\t" + true_label + "\t" + predicted_label + "\n") 108 | f.write("\n") 109 | f.close() 110 | 111 | def accuracy(predictions, labels): 112 | return (100.0 * np.sum(np.array(predictions) == np.array(labels)) / len(labels)) 113 | 114 | def do_eval(model, test_loader, cuda): 115 | model.is_training = False 116 | predicted = [] 117 | true_label = [] 118 | for X, y in test_loader: 119 | X = Variable(X) 120 | if cuda: 121 | X = X.cuda() 122 | output = model(X) 123 | output = output.squeeze(0) 124 | _, output = torch.max(output, 1) 125 | if cuda: 126 | output = output.cpu() 127 | predicted.extend(output.data.numpy().tolist()) 128 | y = y.squeeze(0) 129 | true_label.extend(y.numpy().tolist()) 130 | print("Acc: %.3f" % accuracy(predicted, true_label)) 131 | return predicted 132 | 133 | if __name__ == "__main__": 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument('--train-data-path', type=str, default="./data/atis.train.txt") 136 | parser.add_argument('--test-data-path', type=str, default="./data/atis.test.txt") 137 | parser.add_argument('--slot-names-path', type=str, default="./data/atis_slot_names.txt") 138 | 
parser.add_argument('--saved-model-path', type=str, default="./saved_models/epoch10elman.model") 139 | parser.add_argument('--result-path', type=str, default="./data/output.txt") 140 | parser.add_argument('--mode', type=str, default='elman', 141 | choices=['elman', 'jordan', 'hybrid', 'lstm']) 142 | parser.add_argument('--bidirectional', action='store_true', default=False) 143 | parser.add_argument('--cuda', action='store_true', default=False) 144 | 145 | args = parser.parse_args() 146 | 147 | if os.path.exists(args.saved_model_path): 148 | print("predicting...") 149 | predict(args.train_data_path, args.test_data_path, args.slot_names_path, 150 | args.mode, args.bidirectional, args.saved_model_path, args.result_path, args.cuda) 151 | else: 152 | print("training") 153 | train(args.train_data_path, args.test_data_path, args.slot_names_path, 154 | args.mode, args.bidirectional, args.saved_model_path, args.cuda) 155 | 156 | -------------------------------------------------------------------------------- /models/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | 7 | class SlotFilling(nn.Module): 8 | def __init__(self, vocab_size, label_size, mode='elman', bidirectional=False, cuda=False, is_training=True): 9 | 10 | super(SlotFilling, self).__init__() 11 | self.is_training = is_training 12 | embedding_dim = 100 13 | hidden_size = 75 14 | self.embedding = nn.Embedding(vocab_size, embedding_dim) 15 | 16 | if mode == 'lstm': 17 | self.rnn = nn.LSTM(input_size=embedding_dim, 18 | hidden_size=hidden_size, 19 | bidirectional=bidirectional, 20 | batch_first=True) 21 | else: 22 | self.rnn = RNN(input_size=embedding_dim, 23 | hidden_size=hidden_size, 24 | mode=mode, 25 | cuda=cuda, 26 | bidirectional=bidirectional, 27 | batch_first=True) 28 | if bidirectional: 29 | self.fc = nn.Linear(2*hidden_size, label_size) 30 | 
else: 31 | self.fc = nn.Linear(hidden_size, label_size) 32 | 33 | def forward(self, X): 34 | embed = self.embedding(X) 35 | embed = F.dropout(embed, p=0.2, training=self.is_training) 36 | outputs, _ = self.rnn(embed) 37 | outputs = self.fc(outputs) 38 | return outputs 39 | 40 | 41 | class ElmanRNNCell(nn.Module): 42 | def __init__(self, input_size, hidden_size): 43 | super(ElmanRNNCell, self).__init__() 44 | self.hidden_size = hidden_size 45 | self.i2h_fc1 = nn.Linear(input_size, hidden_size) 46 | self.i2h_fc2 = nn.Linear(hidden_size, hidden_size) 47 | self.h2o_fc = nn.Linear(hidden_size, hidden_size) 48 | 49 | def forward(self, input, hidden): 50 | hidden = F.sigmoid(self.i2h_fc1(input) + self.i2h_fc2(hidden)) 51 | output = F.sigmoid(self.h2o_fc(hidden)) 52 | return output, hidden 53 | 54 | 55 | class JordanRNNCell(nn.Module): 56 | def __init__(self, input_size, hidden_size): 57 | super(JordanRNNCell, self).__init__() 58 | self.hidden_size = hidden_size 59 | self.i2h_fc1 = nn.Linear(input_size, hidden_size) 60 | self.i2h_fc2 = nn.Linear(hidden_size, hidden_size) 61 | self.h2o_fc = nn.Linear(hidden_size, hidden_size) 62 | self.y_0 = nn.Parameter(nn.init.xavier_uniform(torch.Tensor(1, hidden_size)), requires_grad=True) 63 | 64 | def forward(self, input, hidden=None): 65 | if hidden is None: 66 | hidden = self.y_0 67 | hidden = F.sigmoid(self.i2h_fc1(input) + self.i2h_fc2(hidden)) 68 | output = F.sigmoid(self.h2o_fc(hidden)) 69 | return output, output 70 | 71 | 72 | class HybridRNNCell(nn.Module): 73 | def __init__(self, input_size, hidden_size): 74 | super(HybridRNNCell, self).__init__() 75 | self.hidden_size = hidden_size 76 | self.i2h_fc1 = nn.Linear(input_size, hidden_size) 77 | self.i2h_fc2 = nn.Linear(hidden_size, hidden_size) 78 | self.i2h_fc3 = nn.Linear(hidden_size, hidden_size) 79 | self.h2o_fc = nn.Linear(hidden_size, hidden_size) 80 | self.y_0 = nn.Parameter(nn.init.xavier_uniform(torch.Tensor(1, hidden_size)), requires_grad=True) 81 | 82 | def 
forward(self, input, hidden, output=None): 83 | if output is None: 84 | output = self.y_0 85 | hidden = F.sigmoid(self.i2h_fc1(input)+self.i2h_fc2(hidden)+self.i2h_fc3(output)) 86 | output = F.sigmoid(self.h2o_fc(hidden)) 87 | return output, hidden 88 | 89 | 90 | class RNN(nn.Module): 91 | 92 | def __init__(self, input_size, hidden_size, mode='elman', cuda=False, bidirectional=False, batch_first=True): 93 | super(RNN, self).__init__() 94 | self.mode = mode 95 | self.cuda = cuda 96 | if mode == 'elman': 97 | RNNCell = ElmanRNNCell 98 | elif mode == 'jordan': 99 | RNNCell = JordanRNNCell 100 | elif mode == 'hybrid': 101 | RNNCell = HybridRNNCell 102 | else: 103 | raise RuntimeError(mode + " is not a simple rnn mode") 104 | self.forward_cell = RNNCell(input_size=input_size, 105 | hidden_size=hidden_size) 106 | self.hidden_size = hidden_size 107 | self.bidirectional = bidirectional 108 | self.batch_first = batch_first 109 | if bidirectional: 110 | self.reversed_cell = RNNCell(input_size=input_size, 111 | hidden_size=hidden_size) 112 | 113 | def _forward(self, inputs, hidden): 114 | outputs = [] 115 | seq_len = inputs.size(1) 116 | # batch_size*seq_len*n 117 | # -> seq_len*batch_size*n 118 | inputs = inputs.transpose(0, 1) 119 | # print("hidden size:", hidden.size()) 120 | output = None 121 | for i in range(seq_len): 122 | step_input = inputs[i] # batch_size*n 123 | if self.mode == 'hybrid': 124 | output, hidden = self.forward_cell(step_input, hidden, output) 125 | else: 126 | output, hidden = self.forward_cell(step_input, hidden) 127 | outputs.append(output) 128 | 129 | return outputs, hidden 130 | 131 | def _reversed_forward(self, inputs, hidden): 132 | outputs = [] 133 | seq_len = inputs.size(1) 134 | # batch_size*seq_len*n 135 | # -> seq_len*batch_size*n 136 | inputs = inputs.transpose(0, 1) 137 | output = None 138 | for i in range(seq_len): 139 | step_input = inputs[seq_len-i-1] # batch_size*n 140 | if self.mode == 'hybrid': 141 | output, hidden = 
self.reversed_cell(step_input, hidden, output) 142 | else: 143 | output, hidden = self.reversed_cell(step_input, hidden) 144 | outputs.append(output) 145 | 146 | outputs.reverse() 147 | return outputs, hidden 148 | 149 | def forward(self, inputs, hidden=None): 150 | if hidden is None and self.mode != "jordan": 151 | # if hidden is None: 152 | batch_size = inputs.size(0) 153 | # print(batch_size) 154 | hidden = torch.autograd.Variable(torch.zeros(batch_size, 155 | self.hidden_size)) 156 | if self.cuda: 157 | hidden = hidden.cuda() 158 | 159 | output_forward, hidden_forward = self._forward(inputs, hidden) 160 | output_forward = torch.stack(output_forward, dim=0) 161 | if not self.bidirectional: 162 | if self.batch_first: 163 | output_forward = output_forward.transpose(0,1) 164 | return output_forward, hidden_forward 165 | 166 | output_reversed, hidden_reversed = self._reversed_forward(inputs, hidden) 167 | hidden = torch.cat([hidden_forward, hidden_reversed], dim=hidden_forward.dim() - 1) 168 | output_reversed = torch.stack(output_reversed, dim=0) 169 | output = torch.cat([output_forward, output_reversed], 170 | dim=output_reversed.data.dim() - 1) 171 | if self.batch_first: 172 | output = output.transpose(0,1) 173 | return output, hidden 174 | 175 | -------------------------------------------------------------------------------- /eval/conlleval.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | # conlleval: evaluate result of processing CoNLL-2000 shared task 3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file 4 | # README: http://cnts.uia.ac.be/conll2000/chunking/output.html 5 | # options: l: generate LaTeX output for tables like in 6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex 7 | # r: accept raw result tags (without B- and I- prefix; 8 | # assumes one word per chunk) 9 | # d: alternative delimiter tag (default is single space) 10 | # o: alternative outside tag (default is O) 11 | 
# note: the file should contain lines with items separated 12 | # by $delimiter characters (default space). The final 13 | # two items should contain the correct tag and the 14 | # guessed tag in that order. Sentences should be 15 | # separated from each other by empty lines or lines 16 | # with $boundary fields (default -X-). 17 | # url: http://lcg-www.uia.ac.be/conll2000/chunking/ 18 | # started: 1998-09-25 19 | # version: 2004-01-26 20 | # author: Erik Tjong Kim Sang 21 | 22 | use strict; 23 | 24 | my $false = 0; 25 | my $true = 42; 26 | 27 | my $boundary = "-X-"; # sentence boundary 28 | my $correct; # current corpus chunk tag (I,O,B) 29 | my $correctChunk = 0; # number of correctly identified chunks 30 | my $correctTags = 0; # number of correct chunk tags 31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.) 32 | my $delimiter = " "; # field delimiter 33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979) 34 | my $firstItem; # first feature (for sentence boundary checks) 35 | my $foundCorrect = 0; # number of chunks in corpus 36 | my $foundGuessed = 0; # number of identified chunks 37 | my $guessed; # current guessed chunk tag 38 | my $guessedType; # type of current guessed chunk tag 39 | my $i; # miscellaneous counter 40 | my $inCorrect = $false; # currently processed chunk is correct until now 41 | my $lastCorrect = "O"; # previous chunk tag in corpus 42 | my $latex = 0; # generate LaTeX formatted output 43 | my $lastCorrectType = ""; # type of previously identified chunk tag 44 | my $lastGuessed = "O"; # previously identified chunk tag 45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus 46 | my $lastType; # temporary storage for detecting duplicates 47 | my $line; # line 48 | my $nbrOfFeatures = -1; # number of features per line 49 | my $precision = 0.0; # precision score 50 | my $oTag = "O"; # outside tag, default O 51 | my $raw = 0; # raw input: add B to every token 52 | my $recall = 0.0; # recall score 53 | my 
$tokenCounter = 0; # token counter (ignores sentence breaks) 54 | 55 | my %correctChunk = (); # number of correctly identified chunks per type 56 | my %foundCorrect = (); # number of chunks in corpus per type 57 | my %foundGuessed = (); # number of identified chunks per type 58 | 59 | my @features; # features on line 60 | my @sortedTypes; # sorted list of chunk type names 61 | 62 | # sanity check 63 | while (@ARGV and $ARGV[0] =~ /^-/) { 64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); } 65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); } 66 | elsif ($ARGV[0] eq "-d") { 67 | shift(@ARGV); 68 | if (not defined $ARGV[0]) { 69 | die "conlleval: -d requires delimiter character"; 70 | } 71 | $delimiter = shift(@ARGV); 72 | } elsif ($ARGV[0] eq "-o") { 73 | shift(@ARGV); 74 | if (not defined $ARGV[0]) { 75 | die "conlleval: -o requires delimiter character"; 76 | } 77 | $oTag = shift(@ARGV); 78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; } 79 | } 80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; } 81 | # process input 82 | while (<STDIN>) { 83 | chomp($line = $_); 84 | @features = split(/$delimiter/,$line); 85 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; } 86 | elsif ($nbrOfFeatures != $#features and @features != 0) { 87 | printf STDERR "unexpected number of features: %d (%d)\n", 88 | $#features+1,$nbrOfFeatures+1; 89 | exit(1); 90 | } 91 | if (@features == 0 or 92 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); } 93 | if (@features < 2) { 94 | die "conlleval: unexpected number of features in line $line\n"; 95 | } 96 | if ($raw) { 97 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; } 98 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; } 99 | if ($features[$#features] ne "O") { 100 | $features[$#features] = "B-$features[$#features]"; 101 | } 102 | if ($features[$#features-1] ne "O") { 103 | $features[$#features-1] = "B-$features[$#features-1]"; 104 | } 105 | } 106 
| # 20040126 ET code which allows hyphens in the types 107 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 108 | $guessed = $1; 109 | $guessedType = $2; 110 | } else { 111 | $guessed = $features[$#features]; 112 | $guessedType = ""; 113 | } 114 | pop(@features); 115 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 116 | $correct = $1; 117 | $correctType = $2; 118 | } else { 119 | $correct = $features[$#features]; 120 | $correctType = ""; 121 | } 122 | pop(@features); 123 | # ($guessed,$guessedType) = split(/-/,pop(@features)); 124 | # ($correct,$correctType) = split(/-/,pop(@features)); 125 | $guessedType = $guessedType ? $guessedType : ""; 126 | $correctType = $correctType ? $correctType : ""; 127 | $firstItem = shift(@features); 128 | 129 | # 1999-06-26 sentence breaks should always be counted as out of chunk 130 | if ( $firstItem eq $boundary ) { $guessed = "O"; } 131 | 132 | if ($inCorrect) { 133 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 134 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 135 | $lastGuessedType eq $lastCorrectType) { 136 | $inCorrect=$false; 137 | $correctChunk++; 138 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 139 | $correctChunk{$lastCorrectType}+1 : 1; 140 | } elsif ( 141 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) != 142 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or 143 | $guessedType ne $correctType ) { 144 | $inCorrect=$false; 145 | } 146 | } 147 | 148 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 149 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 150 | $guessedType eq $correctType) { $inCorrect = $true; } 151 | 152 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) { 153 | $foundCorrect++; 154 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ? 
155 | $foundCorrect{$correctType}+1 : 1; 156 | } 157 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) { 158 | $foundGuessed++; 159 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ? 160 | $foundGuessed{$guessedType}+1 : 1; 161 | } 162 | if ( $firstItem ne $boundary ) { 163 | if ( $correct eq $guessed and $guessedType eq $correctType ) { 164 | $correctTags++; 165 | } 166 | $tokenCounter++; 167 | } 168 | 169 | $lastGuessed = $guessed; 170 | $lastCorrect = $correct; 171 | $lastGuessedType = $guessedType; 172 | $lastCorrectType = $correctType; 173 | } 174 | if ($inCorrect) { 175 | $correctChunk++; 176 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 177 | $correctChunk{$lastCorrectType}+1 : 1; 178 | } 179 | 180 | if (not $latex) { 181 | # compute overall precision, recall and FB1 (default values are 0.0) 182 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 183 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 184 | $FB1 = 2*$precision*$recall/($precision+$recall) 185 | if ($precision+$recall > 0); 186 | 187 | # print overall performance 188 | printf "processed $tokenCounter tokens with $foundCorrect phrases; "; 189 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n"; 190 | if ($tokenCounter>0) { 191 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter; 192 | printf "precision: %6.2f%%; ",$precision; 193 | printf "recall: %6.2f%%; ",$recall; 194 | printf "FB1: %6.2f\n",$FB1; 195 | } 196 | } 197 | 198 | # sort chunk type names 199 | undef($lastType); 200 | @sortedTypes = (); 201 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) { 202 | if (not($lastType) or $lastType ne $i) { 203 | push(@sortedTypes,($i)); 204 | } 205 | $lastType = $i; 206 | } 207 | # print performance per chunk type 208 | if (not $latex) { 209 | for $i (@sortedTypes) { 210 | $correctChunk{$i} = $correctChunk{$i} ? 
$correctChunk{$i} : 0; 211 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; } 212 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 213 | if (not($foundCorrect{$i})) { $recall = 0.0; } 214 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 215 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 216 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 217 | printf "%17s: ",$i; 218 | printf "precision: %6.2f%%; ",$precision; 219 | printf "recall: %6.2f%%; ",$recall; 220 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i}; 221 | } 222 | } else { 223 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline"; 224 | for $i (@sortedTypes) { 225 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0; 226 | if (not($foundGuessed{$i})) { $precision = 0.0; } 227 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 228 | if (not($foundCorrect{$i})) { $recall = 0.0; } 229 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 230 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 231 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 232 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\", 233 | $i,$precision,$recall,$FB1; 234 | } 235 | print "\\hline\n"; 236 | $precision = 0.0; 237 | $recall = 0; 238 | $FB1 = 0.0; 239 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 240 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 241 | $FB1 = 2*$precision*$recall/($precision+$recall) 242 | if ($precision+$recall > 0); 243 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n", 244 | $precision,$recall,$FB1; 245 | } 246 | 247 | exit 0; 248 | 249 | # endOfChunk: checks if a chunk ended between the previous and current word 250 | # arguments: previous and current chunk tags, previous and current types 251 | # note: this code is capable of handling other chunk representations 252 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 253 | 
# Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 254 | 255 | sub endOfChunk { 256 | my $prevTag = shift(@_); 257 | my $tag = shift(@_); 258 | my $prevType = shift(@_); 259 | my $type = shift(@_); 260 | my $chunkEnd = $false; 261 | 262 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; } 263 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; } 264 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; } 265 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 266 | 267 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; } 268 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; } 269 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; } 270 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 271 | 272 | if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) { 273 | $chunkEnd = $true; 274 | } 275 | 276 | # corrected 1998-12-22: these chunks are assumed to have length 1 277 | if ( $prevTag eq "]" ) { $chunkEnd = $true; } 278 | if ( $prevTag eq "[" ) { $chunkEnd = $true; } 279 | 280 | return($chunkEnd); 281 | } 282 | 283 | # startOfChunk: checks if a chunk started between the previous and current word 284 | # arguments: previous and current chunk tags, previous and current types 285 | # note: this code is capable of handling other chunk representations 286 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 287 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 288 | 289 | sub startOfChunk { 290 | my $prevTag = shift(@_); 291 | my $tag = shift(@_); 292 | my $prevType = shift(@_); 293 | my $type = shift(@_); 294 | my $chunkStart = $false; 295 | 296 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; } 297 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; } 298 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; } 299 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 300 | 301 | if ( 
$prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; } 302 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; } 303 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; } 304 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 305 | 306 | if ($tag ne "O" and $tag ne "." and $prevType ne $type) { 307 | $chunkStart = $true; 308 | } 309 | 310 | # corrected 1998-12-22: these chunks are assumed to have length 1 311 | if ( $tag eq "[" ) { $chunkStart = $true; } 312 | if ( $tag eq "]" ) { $chunkStart = $true; } 313 | 314 | return($chunkStart); 315 | } 316 | --------------------------------------------------------------------------------
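The chunk-level precision, recall and FB1 that conlleval.pl reports can be cross-checked from Python. The block below is a simplified sketch, not a port of the Perl script. It assumes plain `B-`/`I-`/`O` tags and scores one sentence at a time. The helper names `extract_chunks` and `chunk_f1` are made up for illustration and do not exist in this repository.

```python
def extract_chunks(tags):
    """Return the set of (type, start, end) spans in one BIO-tagged sentence."""
    chunks = set()
    start, chunk_type = None, None
    # A trailing "O" sentinel closes a chunk that runs to the end of the sentence.
    for i, tag in enumerate(list(tags) + ["O"]):
        # Split on the first hyphen only, so types may themselves contain hyphens.
        prefix, _, tag_type = tag.partition("-")
        # An open chunk ends at O, at a new B- tag, or when the type changes.
        if start is not None and (prefix in ("O", "B") or tag_type != chunk_type):
            chunks.add((chunk_type, start, i - 1))
            start, chunk_type = None, None
        # A chunk starts at B-, or at an I- tag that does not continue a chunk.
        if prefix in ("B", "I") and start is None:
            start, chunk_type = i, tag_type
    return chunks


def chunk_f1(gold_tags, pred_tags):
    """Chunk-level precision, recall and F1 (all in percent) for one sentence."""
    gold, pred = extract_chunks(gold_tags), extract_chunks(pred_tags)
    correct = len(gold & pred)  # a chunk counts only if type and both ends match
    precision = 100.0 * correct / len(pred) if pred else 0.0
    recall = 100.0 * correct / len(gold) if gold else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total > 0 else 0.0
    return precision, recall, f1


if __name__ == "__main__":
    gold = ["O", "B-toloc.city_name", "I-toloc.city_name", "O", "B-depart_date.day_name"]
    pred = ["O", "B-toloc.city_name", "I-toloc.city_name", "O", "O"]
    print("precision: %.2f, recall: %.2f, FB1: %.2f" % chunk_f1(gold, pred))
```

In the usage example the prediction finds one of the two gold slots and proposes nothing else, so precision is 100% and recall is 50%. To score a whole result file, chunks would need to be collected per sentence before counting, since conlleval.pl treats every sentence boundary as an `O` tag.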