├── ade ├── ade.option ├── ADEsentence.h ├── Relation.h ├── BestPerformance.h ├── Tool.h ├── Makefile ├── N3Lhelper.h ├── .project ├── .settings │ └── language.settings.xml ├── Example.h ├── Entity.h ├── .cproject ├── ade.cpp ├── Options.h ├── CNN.h ├── Attention_ZhouACL2016.h ├── utils.h ├── Argument_helper.h └── NNade3.h ├── bb3 ├── bb3.option ├── Document.h ├── NerExample.h ├── Tool.h ├── N3Lhelper.h ├── Makefile ├── .project ├── .settings │ └── language.settings.xml ├── Relation.h ├── Example.h ├── Entity.h ├── bb3.cpp ├── .cproject ├── Options.h ├── CNN.h ├── Attention_ZhouACL2016.h ├── utils.h └── Argument_helper.h ├── .gitattributes └── README.md /ade/ade.option: -------------------------------------------------------------------------------- 1 | wordCutOff = 0 2 | maxIter = 300 3 | adaAlpha = 0.03 4 | dropProb = 0 5 | wordEmbSize = 200 6 | otherEmbSize = 25 7 | rnnHiddenSize = 100 8 | hiddenSize = 100 9 | evalPerIter = 1 10 | wordEmbFineTune = true 11 | abbrPath = 12 | puncPath = 13 | verboseIter = 1 14 | embFile = 15 | wordcontext = 0 16 | regParameter = 0.00000001 17 | batchSize = 1 18 | poolType = 0 19 | 20 | 21 | -------------------------------------------------------------------------------- /bb3/bb3.option: -------------------------------------------------------------------------------- 1 | wordCutOff = 0 2 | maxIter = 300 3 | adaAlpha = 0.03 4 | dropProb = 0 5 | wordEmbSize = 200 6 | otherEmbSize = 25 7 | rnnHiddenSize = 100 8 | hiddenSize = 100 9 | evalPerIter = 1 10 | wordEmbFineTune = true 11 | abbrPath = 12 | puncPath = 13 | verboseIter = 1 14 | embFile = 15 | wordcontext = 0 16 | regParameter = 0.00000001 17 | batchSize = 1 18 | poolType = 0 19 | 20 | 21 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | 4 | # Custom for Visual Studio 5 | 
*.cs diff=csharp 6 | 7 | # Standard to msysgit 8 | *.doc diff=astextplain 9 | *.DOC diff=astextplain 10 | *.docx diff=astextplain 11 | *.DOCX diff=astextplain 12 | *.dot diff=astextplain 13 | *.DOT diff=astextplain 14 | *.pdf diff=astextplain 15 | *.PDF diff=astextplain 16 | *.rtf diff=astextplain 17 | *.RTF diff=astextplain 18 | -------------------------------------------------------------------------------- /ade/ADEsentence.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | #ifndef ADESENTENCE_H_ 4 | #define ADESENTENCE_H_ 5 | 6 | #include 7 | #include "Entity.h" 8 | #include "Relation.h" 9 | 10 | using namespace std; 11 | 12 | 13 | class ADEsentence { 14 | public: 15 | ADEsentence() { 16 | offset = -1; 17 | } 18 | 19 | void clear() { 20 | entities.clear(); 21 | relations.clear(); 22 | text.clear(); 23 | offset = -1; 24 | } 25 | 26 | vector entities; 27 | vector relations; 28 | string text; 29 | int offset; 30 | }; 31 | 32 | #endif /* ADESENTENCE_H_ */ 33 | -------------------------------------------------------------------------------- /bb3/Document.h: -------------------------------------------------------------------------------- 1 | /* 2 | * BiocDocument.h 3 | * 4 | * Created on: Dec 19, 2015 5 | * Author: fox 6 | */ 7 | 8 | #ifndef DOCUMENT_H_ 9 | #define DOCUMENT_H_ 10 | 11 | #include 12 | #include "Entity.h" 13 | #include "Relation.h" 14 | #include "Sent.h" 15 | 16 | using namespace std; 17 | 18 | 19 | class Document { 20 | public: 21 | Document() { 22 | 23 | } 24 | /* virtual ~BiocDocument() { 25 | 26 | }*/ 27 | 28 | string id; 29 | vector entities; 30 | vector relations; 31 | vector sents; 32 | int maxParagraphId; 33 | }; 34 | 35 | #endif /* DOCUMENT_H_ */ 36 | -------------------------------------------------------------------------------- /bb3/NerExample.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Example.h 3 | * 4 | * Created on: Mar 17, 2015 5 | * 
Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_NerEXAMPLE_H_ 9 | #define SRC_NerEXAMPLE_H_ 10 | #include 11 | 12 | using namespace std; 13 | 14 | class NerExample { 15 | 16 | public: 17 | vector m_labels; 18 | int goldLabel; 19 | 20 | vector _words; 21 | vector _postags; 22 | 23 | int _prior_ner; 24 | 25 | int _current_idx; 26 | 27 | 28 | public: 29 | NerExample() 30 | { 31 | 32 | } 33 | /* virtual ~Example() 34 | { 35 | 36 | }*/ 37 | 38 | 39 | 40 | }; 41 | 42 | #endif /* SRC_EXAMPLE_H_ */ 43 | -------------------------------------------------------------------------------- /ade/Relation.h: -------------------------------------------------------------------------------- 1 | /* 2 | * RelationEntity.h 3 | * 4 | * Created on: Mar 9, 2016 5 | * Author: fox 6 | */ 7 | 8 | #ifndef RELATION_H_ 9 | #define RELATION_H_ 10 | 11 | #include 12 | #include 13 | 14 | using namespace std; 15 | 16 | // non-directional live-in relation in mention level 17 | class Relation { 18 | 19 | public: 20 | 21 | Entity entity1; 22 | Entity entity2; 23 | 24 | bool equals(const Relation& another) const { 25 | if(entity1.equals(another.entity1) && entity2.equals(another.entity2)) 26 | return true; 27 | else 28 | return false; 29 | } 30 | }; 31 | 32 | 33 | 34 | #endif /* RELATION_H_ */ 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is a C++ implementation for extracting biomedical entities and relations from raw text. 2 | Please see the publication "A neural joint model for entity and relation extraction from biomedical text" for details. 3 | If you have any questions, please feel free to contact me. 
4 | 5 | Citation in BibTex: 6 | 7 | @Article{Li2017, 8 | author="Li, Fei and Zhang, Meishan and Fu, Guohong and Ji, Donghong", 9 | title="A neural joint model for entity and relation extraction from biomedical text", 10 | journal="BMC Bioinformatics", 11 | year="2017", 12 | volume="18", 13 | number="1", 14 | pages="198", 15 | doi="10.1186/s12859-017-1609-9", 16 | url="http://dx.doi.org/10.1186/s12859-017-1609-9" 17 | } 18 | -------------------------------------------------------------------------------- /ade/BestPerformance.h: -------------------------------------------------------------------------------- 1 | /* 2 | * BestPerformance.h 3 | * 4 | * Created on: Sep 16, 2016 5 | * Author: fox 6 | */ 7 | 8 | #ifndef BESTPERFORMANCE_H_ 9 | #define BESTPERFORMANCE_H_ 10 | 11 | class BestPerformance { 12 | public: 13 | BestPerformance() { 14 | 15 | } 16 | double dev_pEntity = 0; 17 | double dev_rEntity = 0; 18 | double dev_f1Entity = 0; 19 | double dev_pRelation = 0; 20 | double dev_rRelation = 0; 21 | double dev_f1Relation = 0; 22 | 23 | double test_pEntity = 0; 24 | double test_rEntity = 0; 25 | double test_f1Entity = 0; 26 | double test_pRelation = 0; 27 | double test_rRelation = 0; 28 | double test_f1Relation = 0; 29 | }; 30 | 31 | #endif /* BESTPERFORMANCE_H_ */ 32 | -------------------------------------------------------------------------------- /ade/Tool.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tool.h 3 | * 4 | * Created on: Dec 27, 2015 5 | * Author: fox 6 | */ 7 | 8 | #ifndef TOOL_H_ 9 | #define TOOL_H_ 10 | 11 | #include "SentSplitter.h" 12 | #include "Tokenizer.h" 13 | #include "Options.h" 14 | #include "Word2Vec.h" 15 | #include "FoxUtil.h" 16 | 17 | 18 | class Tool { 19 | public: 20 | Options option; 21 | fox::SentSplitter sentSplitter; 22 | fox::Tokenizer tokenizer; 23 | fox::Word2Vec* w2v; 24 | 25 | 26 | Tool(Options option) : option(option), sentSplitter(NULL, &option.abbrPath), 27 | 
tokenizer(&option.puncPath) { 28 | 29 | w2v = new fox::Word2Vec(); 30 | 31 | } 32 | virtual ~Tool() { 33 | delete w2v; 34 | } 35 | 36 | }; 37 | 38 | 39 | 40 | #endif /* TOOL_H_ */ 41 | -------------------------------------------------------------------------------- /bb3/Tool.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tool.h 3 | * 4 | * Created on: Dec 27, 2015 5 | * Author: fox 6 | */ 7 | 8 | #ifndef TOOL_H_ 9 | #define TOOL_H_ 10 | 11 | #include "SentSplitter.h" 12 | #include "Tokenizer.h" 13 | #include "Options.h" 14 | #include "Word2Vec.h" 15 | #include "FoxUtil.h" 16 | 17 | 18 | class Tool { 19 | public: 20 | Options option; 21 | fox::SentSplitter sentSplitter; 22 | fox::Tokenizer tokenizer; 23 | fox::Word2Vec* w2v; 24 | 25 | 26 | Tool(Options option) : option(option), sentSplitter(NULL, &option.abbrPath), 27 | tokenizer(&option.puncPath) { 28 | 29 | w2v = new fox::Word2Vec(); 30 | 31 | } 32 | virtual ~Tool() { 33 | delete w2v; 34 | } 35 | 36 | }; 37 | 38 | 39 | 40 | #endif /* TOOL_H_ */ 41 | -------------------------------------------------------------------------------- /ade/Makefile: -------------------------------------------------------------------------------- 1 | cc=g++ 2 | 3 | #cflags = -O0 -g3 -w \ 4 | -I/home/fox/project/FoxUtil \ 5 | -msse3 -I /home/fox/Downloads/mshadow-master/mshadow -DMSHADOW_USE_CUDA=0 -DMSHADOW_USE_CBLAS=1 -DMSHADOW_USE_MKL=0 \ 6 | -I/home/fox/Downloads/LibN3L-master -DUSE_CUDA=0 \ 7 | 8 | cflags = -O3 -w \ 9 | -I/home/fox/project/FoxUtil \ 10 | -msse3 -I /home/fox/Downloads/mshadow-master/mshadow -DMSHADOW_USE_CUDA=0 -DMSHADOW_USE_CBLAS=1 -DMSHADOW_USE_MKL=0 \ 11 | -I/home/fox/Downloads/LibN3L-master -DUSE_CUDA=0 \ 12 | -static-libgcc -static-libstdc++\ 13 | 14 | libs = -lm -lopenblas -Wl,-rpath,./ \ 15 | 16 | all: ade 17 | 18 | ade: ade.cpp Entity.h Relation.h Tool.h Options.h utils.h 19 | $(cc) -o ade ade.cpp $(cflags) $(libs) 20 | 21 | 22 | 23 | 24 | 25 | 26 | clean: 27 | 
rm -rf *.o 28 | rm -rf ade 29 | 30 | -------------------------------------------------------------------------------- /ade/N3Lhelper.h: -------------------------------------------------------------------------------- 1 | /* 2 | * N3Lhelper.h 3 | * 4 | * Created on: Dec 27, 2015 5 | * Author: fox 6 | */ 7 | 8 | #ifndef N3LHELPER_H_ 9 | #define N3LHELPER_H_ 10 | 11 | #include "N3L.h" 12 | 13 | using namespace std; 14 | 15 | 16 | void alphabet2vectormap(const Alphabet& alphabet, vector& vector, map& IDs) { 17 | 18 | for (int j = 0; j < alphabet.size(); ++j) { 19 | string str = alphabet.from_id(j); 20 | vector.push_back(str); 21 | IDs.insert(map::value_type(str, j)); 22 | } 23 | 24 | } 25 | 26 | template 27 | void array2NRMat(T * array, int sizeX, int sizeY, NRMat& mat) { 28 | for(int i=0;i& vector, map& IDs) { 17 | 18 | for (int j = 0; j < alphabet.size(); ++j) { 19 | string str = alphabet.from_id(j); 20 | vector.push_back(str); 21 | IDs.insert(map::value_type(str, j)); 22 | } 23 | 24 | } 25 | 26 | template 27 | void array2NRMat(T * array, int sizeX, int sizeY, NRMat& mat) { 28 | for(int i=0;i 2 | 3 | ade 4 | 5 | 6 | 7 | 8 | 9 | org.eclipse.cdt.managedbuilder.core.genmakebuilder 10 | clean,full,incremental, 11 | 12 | 13 | 14 | 15 | org.eclipse.cdt.managedbuilder.core.ScannerConfigBuilder 16 | full,incremental, 17 | 18 | 19 | 20 | 21 | 22 | org.eclipse.cdt.core.cnature 23 | org.eclipse.cdt.core.ccnature 24 | org.eclipse.cdt.managedbuilder.core.managedBuildNature 25 | org.eclipse.cdt.managedbuilder.core.ScannerConfigNature 26 | 27 | 28 | -------------------------------------------------------------------------------- /bb3/.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | bb3 4 | 5 | 6 | 7 | 8 | 9 | org.eclipse.cdt.managedbuilder.core.genmakebuilder 10 | clean,full,incremental, 11 | 12 | 13 | 14 | 15 | org.eclipse.cdt.managedbuilder.core.ScannerConfigBuilder 16 | full,incremental, 17 | 18 | 19 | 20 | 21 | 22 | 
org.eclipse.cdt.core.cnature 23 | org.eclipse.cdt.core.ccnature 24 | org.eclipse.cdt.managedbuilder.core.managedBuildNature 25 | org.eclipse.cdt.managedbuilder.core.ScannerConfigNature 26 | 27 | 28 | -------------------------------------------------------------------------------- /ade/.settings/language.settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /bb3/.settings/language.settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /bb3/Relation.h: -------------------------------------------------------------------------------- 1 | /* 2 | * RelationEntity.h 3 | * 4 | * Created on: Mar 9, 2016 5 | * Author: fox 6 | */ 7 | 8 | #ifndef RELATION_H_ 9 | #define RELATION_H_ 10 | 11 | #include 12 | #include 13 | 14 | using namespace std; 15 | 16 | // non-directional live-in relation in mention level 17 | class Relation { 18 | 19 | public: 20 | string id; 21 | string idBacteria; 22 | string idLocation; 23 | 24 | Entity bacteria; 25 | Entity location; 26 | 27 | bool equals(const Relation& another) { 28 | if(bacteria.equals(another.bacteria) && location.equals(another.location)) 29 | return true; 30 | else 31 | return false; 32 | } 33 | 34 | void setId(int _id) { 35 | stringstream ss; 36 | ss<<"R"<<_id; 37 | id = ss.str(); 38 | } 39 | 40 | void setBacId(const string& bacId) { 41 | idBacteria = bacId; 42 | bacteria.id = bacId; 43 | } 44 | 45 | void setLocId(const string& locId) { 46 | idLocation = locId; 47 | location.id = locId; 48 | } 49 | }; 50 | 51 | 52 | 53 | #endif /* RELATION_H_ */ 54 | -------------------------------------------------------------------------------- /bb3/Example.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Example.h 3 | * 4 | * Created on: Mar 17, 2015 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_EXAMPLE_H_ 9 | #define SRC_EXAMPLE_H_ 10 | #include 11 | #include "N3L.h" 12 | 13 | using namespace std; 14 | 15 | class Example { 16 | 17 | public: 18 | //used by ner 19 | vector _nerLabels; 20 | int nerGoldLabel; 21 | 22 | vector _words; 23 | vector _postags; 24 | vector< vector > _seq_chars; 25 | 26 | int _prior_ner; 27 | 28 | int _current_idx; 29 | 30 | // used by relation 31 | bool _isRelation; 32 | 33 | vector _relLabels; 34 | int relGoldLabel; 35 | 36 | vector _deps; 37 | vector _ners; 38 | 39 | vector _idxOnSDP_E12A; 40 | vector _idxOnSDP_E22A; 41 | 42 | hash_set _idx_e1; 43 | hash_set _idx_e2; 44 | 45 | public: 46 | Example(bool isrel) 47 | { 48 | nerGoldLabel = -1; 49 | _prior_ner = -1; 50 | _current_idx = -1; 51 | 52 | _isRelation = isrel; 53 | 54 | relGoldLabel = -1; 55 | 56 | } 57 | /* virtual ~Example() 58 | { 59 | 60 | }*/ 61 | 62 | 63 | 64 | }; 65 | 66 | #endif /* SRC_EXAMPLE_H_ */ 67 | -------------------------------------------------------------------------------- /ade/Example.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Example.h 3 | * 4 | * Created on: Mar 17, 2015 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_EXAMPLE_H_ 9 | #define SRC_EXAMPLE_H_ 10 | #include 11 | #include "N3L.h" 12 | 13 | using namespace std; 14 | 15 | class Example { 16 | 17 | public: 18 | //used by ner 19 | vector _nerLabels; 20 | int nerGoldLabel; 21 | 22 | vector _words; 23 | vector _postags; 24 | vector< vector > _seq_chars; 25 | 26 | int _prior_ner; 27 | 28 | int _current_idx; 29 | 30 | // used by relation 31 | bool _isRelation; 32 | 33 | vector _relLabels; 34 | int relGoldLabel; 35 | 36 | vector _deps; 37 | vector _ners; 38 | 39 | vector _idxOnSDP_E12A; 40 | vector _idxOnSDP_E22A; 41 | 42 | hash_set _idx_e1; 43 | hash_set 
_idx_e2; 44 | 45 | vector _between_words; 46 | 47 | public: 48 | Example(bool isrel) 49 | { 50 | nerGoldLabel = -1; 51 | _prior_ner = -1; 52 | _current_idx = -1; 53 | 54 | _isRelation = isrel; 55 | 56 | relGoldLabel = -1; 57 | 58 | } 59 | /* virtual ~Example() 60 | { 61 | 62 | }*/ 63 | 64 | 65 | 66 | }; 67 | 68 | #endif /* SRC_EXAMPLE_H_ */ 69 | -------------------------------------------------------------------------------- /bb3/Entity.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Entity.h 3 | * 4 | * Created on: Mar 9, 2016 5 | * Author: fox 6 | */ 7 | 8 | #ifndef ENTITY_H_ 9 | #define ENTITY_H_ 10 | 11 | #include 12 | #include 13 | 14 | using namespace std; 15 | 16 | class Entity { 17 | public: 18 | string id; // id is unique in a document 19 | string type; 20 | int begin; // the first character offset at "doc level" 21 | string text; 22 | int end; // the last character offset+1 at "doc level" 23 | 24 | int tkStart; // the token index that this segment starts 25 | int tkEnd; // the token index that this segment ends. 
26 | 27 | // for non-continuous mention 28 | int begin2; 29 | int end2; 30 | 31 | int sentIdx; // the sentence index which this entity belongs to 32 | 33 | Entity() { 34 | id = "-1"; 35 | type = ""; 36 | begin = -1; 37 | text = ""; 38 | end = -1; 39 | tkStart = -1; 40 | tkEnd = -1; 41 | 42 | begin2 = -1; 43 | end2 = -1; 44 | 45 | sentIdx = -1; 46 | 47 | } 48 | 49 | bool equals(const Entity& another) const { 50 | if(type == another.type && begin == another.begin && end == another.end) 51 | return true; 52 | else 53 | return false; 54 | } 55 | 56 | void setId(int _id) { 57 | stringstream ss; 58 | ss<<"T"<<_id; 59 | id = ss.str(); 60 | } 61 | }; 62 | 63 | 64 | 65 | #endif /* ENTITY_H_ */ 66 | -------------------------------------------------------------------------------- /ade/Entity.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Entity.h 3 | * 4 | * Created on: Mar 9, 2016 5 | * Author: fox 6 | */ 7 | 8 | #ifndef ENTITY_H_ 9 | #define ENTITY_H_ 10 | 11 | #include 12 | #include 13 | 14 | using namespace std; 15 | 16 | class Entity { 17 | public: 18 | string id; // id is unique in a document 19 | string type; 20 | int begin; // the first character offset at "doc level" 21 | string text; 22 | int end; // the last character offset+1 at "doc level" 23 | 24 | int tkStart; // the token index that this segment starts 25 | int tkEnd; // the token index that this segment ends. 
26 | 27 | // for non-continuous mention 28 | int begin2; 29 | int end2; 30 | 31 | int sentIdx; // the sentence index which this entity belongs to 32 | 33 | Entity() { 34 | id = "-1"; 35 | type = ""; 36 | begin = -1; 37 | text = ""; 38 | end = -1; 39 | tkStart = -1; 40 | tkEnd = -1; 41 | 42 | begin2 = -1; 43 | end2 = -1; 44 | 45 | sentIdx = -1; 46 | 47 | } 48 | 49 | bool equals(const Entity& another) const { 50 | if(type == another.type && begin == another.begin && end == another.end) 51 | return true; 52 | else 53 | return false; 54 | } 55 | 56 | bool equalsBoundary(const Entity& another) const { 57 | if(begin == another.begin && end == another.end) 58 | return true; 59 | else 60 | return false; 61 | } 62 | 63 | bool equalsType(const Entity& another) const { 64 | if(type == another.type) 65 | return true; 66 | else 67 | return false; 68 | } 69 | 70 | 71 | }; 72 | 73 | 74 | 75 | #endif /* ENTITY_H_ */ 76 | -------------------------------------------------------------------------------- /bb3/bb3.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * cdr.cpp 3 | * 4 | * Created on: Dec 19, 2015 5 | * Author: fox 6 | */ 7 | 8 | #include 9 | #include "utils.h" 10 | #include "FoxUtil.h" 11 | #include 12 | #include "Token.h" 13 | #include "SentSplitter.h" 14 | #include "Tokenizer.h" 15 | #include "N3L.h" 16 | #include "Argument_helper.h" 17 | #include "Options.h" 18 | #include "Tool.h" 19 | 20 | 21 | #include "NNbb3.h" 22 | 23 | 24 | 25 | using namespace std; 26 | 27 | 28 | int main(int argc, char **argv) 29 | { 30 | #if USE_CUDA==1 31 | InitTensorEngine(); 32 | #else 33 | InitTensorEngine(); 34 | #endif 35 | 36 | 37 | string optionFile; 38 | string trainFile; 39 | string devFile; 40 | string testFile; 41 | string outputFile; 42 | string trainNlpFile; 43 | string devNlpFile; 44 | string testNlpFile; 45 | 46 | 47 | 48 | dsr::Argument_helper ah; 49 | ah.new_named_string("train", "", "", "", trainFile); 50 | 
ah.new_named_string("dev", "", "", "", devFile); 51 | ah.new_named_string("test", "", "", "", testFile); 52 | ah.new_named_string("option", "", "", "", optionFile); 53 | ah.new_named_string("output", "", "", "", outputFile); 54 | ah.new_named_string("trainnlp", "", "", "", trainNlpFile); 55 | ah.new_named_string("devnlp", "", "", "", devNlpFile); 56 | ah.new_named_string("testnlp", "", "", "", testNlpFile); 57 | 58 | 59 | ah.process(argc, argv); 60 | cout<<"train file: " <(); 91 | #endif 92 | 93 | return 0; 94 | 95 | } 96 | 97 | -------------------------------------------------------------------------------- /ade/.cproject: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /bb3/.cproject: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /ade/ade.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * cdr.cpp 3 | * 4 | * Created on: Dec 19, 2015 5 | * Author: fox 6 | */ 7 | 8 | #include 9 | #include "utils.h" 10 | #include "FoxUtil.h" 11 | #include 12 | #include "Token.h" 13 | #include "SentSplitter.h" 14 | #include "Tokenizer.h" 15 | #include "N3L.h" 16 | #include "Argument_helper.h" 17 | #include "Options.h" 18 | #include "Tool.h" 19 | #include "BestPerformance.h" 20 | 21 | 22 | #include "NNade3.h" 23 | 24 | 25 | using namespace std; 26 | 27 | 28 | int main(int argc, char **argv) 29 | { 30 | #if 
USE_CUDA==1 31 | InitTensorEngine(); 32 | #else 33 | InitTensorEngine(); 34 | #endif 35 | 36 | 37 | string optionFile; 38 | string annotatedPath; 39 | string processedPath; 40 | string fold; 41 | 42 | 43 | 44 | dsr::Argument_helper ah; 45 | ah.new_named_string("annotated", "", "", "", annotatedPath); 46 | ah.new_named_string("option", "", "", "", optionFile); 47 | ah.new_named_string("processed", "", "", "", processedPath); 48 | ah.new_named_string("fold", "", "", "", fold); 49 | 50 | 51 | 52 | ah.process(argc, argv); 53 | cout<<"annotated path: " < > processedGroups; 66 | vector< vector > annotatedGroups; 67 | loadAnnotatedFile(annotatedPath, annotatedGroups); 68 | loadProcessedFile(processedPath, processedGroups); 69 | 70 | if(!options.embFile.empty()) { 71 | cout<< "load pre-trained emb"<loadFromBinFile(options.embFile, false, true); 73 | } 74 | 75 | vector bestAll; 76 | int currentFold = atoi(fold.c_str()); 77 | for(int crossValid=0;crossValid=0 && crossValid!=currentFold) { 82 | continue; 83 | }*/ 84 | cout<<"###### group ###### "< processedTest; 91 | vector annotatedTest; 92 | vector processedDev; 93 | vector annotatedDev; 94 | vector processedTrain; 95 | vector annotatedTrain; 96 | 97 | for(int groupIdx=0;groupIdx(); 177 | #endif 178 | 179 | return 0; 180 | 181 | } 182 | 183 | -------------------------------------------------------------------------------- /bb3/Options.h: -------------------------------------------------------------------------------- 1 | #ifndef _PARSER_OPTIONS_ 2 | #define _PARSER_OPTIONS_ 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "N3L.h" 11 | 12 | using namespace std; 13 | 14 | class Options { 15 | public: 16 | 17 | int wordCutOff; 18 | 19 | dtype initRange; 20 | int maxIter; 21 | int batchSize; 22 | dtype adaEps; 23 | dtype adaAlpha; 24 | dtype regParameter; 25 | dtype dropProb; 26 | 27 | int evalPerIter; 28 | 29 | bool wordEmbFineTune; 30 | 31 | string abbrPath; 32 | string puncPath; 33 | 
34 | int wordcontext; 35 | int wordEmbSize; 36 | int otherEmbSize; 37 | int hiddenSize; 38 | int rnnHiddenSize; 39 | 40 | int sent_window; 41 | 42 | int verboseIter; 43 | 44 | string output; 45 | string embFile; 46 | 47 | int beamSize1; 48 | int beamSize2; 49 | 50 | int poolType; 51 | 52 | Options() { 53 | wordCutOff = 0; 54 | initRange = 0.01; 55 | maxIter = 1000; 56 | batchSize = 1; 57 | adaEps = 1e-6; 58 | adaAlpha = 0.01; 59 | regParameter = 1e-8; 60 | dropProb = 0.5; 61 | 62 | wordcontext = 0; 63 | wordEmbSize = 50; 64 | otherEmbSize = 50; 65 | hiddenSize = 150; 66 | rnnHiddenSize = 100; 67 | 68 | evalPerIter = 1; 69 | wordEmbFineTune = true; 70 | 71 | abbrPath = ""; 72 | puncPath = ""; 73 | 74 | sent_window = 1; 75 | verboseIter = 0; 76 | 77 | output = ""; 78 | embFile = ""; 79 | 80 | beamSize1 = 1; 81 | beamSize2 = 1; 82 | 83 | poolType = 0; 84 | } 85 | 86 | Options(const Options& options) { 87 | wordCutOff = options.wordCutOff; 88 | initRange = options.initRange; 89 | maxIter = options.maxIter; 90 | batchSize = options.batchSize; 91 | adaEps = options.adaEps; 92 | adaAlpha = options.adaAlpha; 93 | regParameter = options.regParameter; 94 | dropProb = options.dropProb; 95 | 96 | wordcontext = options.wordcontext; 97 | wordEmbSize = options.wordEmbSize; 98 | otherEmbSize = options.otherEmbSize; 99 | hiddenSize = options.hiddenSize; 100 | rnnHiddenSize = options.rnnHiddenSize; 101 | 102 | evalPerIter = options.evalPerIter; 103 | wordEmbFineTune = options.wordEmbFineTune; 104 | 105 | abbrPath = options.abbrPath; 106 | puncPath = options.puncPath; 107 | 108 | sent_window = options.sent_window; 109 | 110 | verboseIter = options.verboseIter; 111 | 112 | output = options.output; 113 | embFile = options.embFile; 114 | 115 | beamSize1 = options.beamSize1; 116 | beamSize2 = options.beamSize2; 117 | 118 | poolType = options.poolType; 119 | } 120 | 121 | /* virtual ~Options() { 122 | 123 | }*/ 124 | 125 | void setOptions(const vector &vecOption) { 126 | int i = 0; 127 | 
for (; i < vecOption.size(); ++i) { 128 | pair pr; 129 | string2pair(vecOption[i], pr, '='); 130 | if (pr.first == "wordCutOff") 131 | wordCutOff = atoi(pr.second.c_str()); 132 | else if (pr.first == "initRange") 133 | initRange = atof(pr.second.c_str()); 134 | else if (pr.first == "maxIter") 135 | maxIter = atoi(pr.second.c_str()); 136 | else if (pr.first == "batchSize") 137 | batchSize = atoi(pr.second.c_str()); 138 | else if (pr.first == "adaEps") 139 | adaEps = atof(pr.second.c_str()); 140 | else if (pr.first == "adaAlpha") 141 | adaAlpha = atof(pr.second.c_str()); 142 | else if (pr.first == "regParameter") 143 | regParameter = atof(pr.second.c_str()); 144 | else if (pr.first == "dropProb") 145 | dropProb = atof(pr.second.c_str()); 146 | 147 | else if (pr.first == "hiddenSize") 148 | hiddenSize = atoi(pr.second.c_str()); 149 | else if (pr.first == "rnnHiddenSize") 150 | rnnHiddenSize = atoi(pr.second.c_str()); 151 | 152 | else if (pr.first == "wordcontext") 153 | wordcontext = atoi(pr.second.c_str()); 154 | else if (pr.first == "wordEmbSize") 155 | wordEmbSize = atoi(pr.second.c_str()); 156 | else if(pr.first == "otherEmbSize") 157 | otherEmbSize = atoi(pr.second.c_str()); 158 | 159 | 160 | else if(pr.first == "evalPerIter") 161 | evalPerIter = atoi(pr.second.c_str()); 162 | 163 | else if (pr.first == "wordEmbFineTune") 164 | wordEmbFineTune = (pr.second == "true") ? 
true : false; 165 | 166 | else if(pr.first == "abbrPath") 167 | abbrPath = pr.second; 168 | 169 | else if(pr.first == "puncPath") 170 | puncPath = pr.second; 171 | else if (pr.first == "sent_window") 172 | sent_window = atoi(pr.second.c_str()); 173 | else if(pr.first == "verboseIter") 174 | verboseIter = atoi(pr.second.c_str()); 175 | 176 | else if(pr.first == "output") 177 | output = pr.second; 178 | else if(pr.first == "embFile") 179 | embFile = pr.second; 180 | 181 | else if(pr.first == "beamSize1") 182 | beamSize1 = atoi(pr.second.c_str()); 183 | else if(pr.first == "beamSize2") 184 | beamSize2 = atoi(pr.second.c_str()); 185 | 186 | else if(pr.first == "poolType") 187 | poolType = atoi(pr.second.c_str()); 188 | } 189 | } 190 | 191 | void showOptions() { 192 | std::cout << "wordCutOff = " << wordCutOff << std::endl; 193 | std::cout << "initRange = " << initRange << std::endl; 194 | std::cout << "maxIter = " << maxIter << std::endl; 195 | std::cout << "batchSize = " << batchSize << std::endl; 196 | std::cout << "adaEps = " << adaEps << std::endl; 197 | std::cout << "adaAlpha = " << adaAlpha << std::endl; 198 | std::cout << "regParameter = " << regParameter << std::endl; 199 | std::cout << "dropProb = " << dropProb << std::endl; 200 | 201 | std::cout << "hiddenSize = " << hiddenSize << std::endl; 202 | std::cout << "rnnHiddenSize = " << rnnHiddenSize << std::endl; 203 | 204 | std::cout<<"wordcontext = " << wordcontext << endl; 205 | std::cout << "wordEmbSize = " << wordEmbSize << std::endl; 206 | std::cout<<"otherEmbSize = "< &vecOption) { 127 | int i = 0; 128 | for (; i < vecOption.size(); ++i) { 129 | pair pr; 130 | string2pair(vecOption[i], pr, '='); 131 | if (pr.first == "wordCutOff") 132 | wordCutOff = atoi(pr.second.c_str()); 133 | else if (pr.first == "initRange") 134 | initRange = atof(pr.second.c_str()); 135 | else if (pr.first == "maxIter") 136 | maxIter = atoi(pr.second.c_str()); 137 | else if (pr.first == "batchSize") 138 | batchSize = 
atoi(pr.second.c_str()); 139 | else if (pr.first == "adaEps") 140 | adaEps = atof(pr.second.c_str()); 141 | else if (pr.first == "adaAlpha") 142 | adaAlpha = atof(pr.second.c_str()); 143 | else if (pr.first == "regParameter") 144 | regParameter = atof(pr.second.c_str()); 145 | else if (pr.first == "dropProb") 146 | dropProb = atof(pr.second.c_str()); 147 | 148 | else if (pr.first == "hiddenSize") 149 | hiddenSize = atoi(pr.second.c_str()); 150 | else if (pr.first == "rnnHiddenSize") 151 | rnnHiddenSize = atoi(pr.second.c_str()); 152 | 153 | else if (pr.first == "wordcontext") 154 | wordcontext = atoi(pr.second.c_str()); 155 | else if (pr.first == "wordEmbSize") 156 | wordEmbSize = atoi(pr.second.c_str()); 157 | else if(pr.first == "otherEmbSize") 158 | otherEmbSize = atoi(pr.second.c_str()); 159 | 160 | 161 | else if(pr.first == "evalPerIter") 162 | evalPerIter = atoi(pr.second.c_str()); 163 | 164 | else if (pr.first == "wordEmbFineTune") 165 | wordEmbFineTune = (pr.second == "true") ? true : false; 166 | 167 | else if(pr.first == "abbrPath") 168 | abbrPath = pr.second; 169 | 170 | else if(pr.first == "puncPath") 171 | puncPath = pr.second; 172 | else if (pr.first == "sent_window") 173 | sent_window = atoi(pr.second.c_str()); 174 | else if(pr.first == "verboseIter") 175 | verboseIter = atoi(pr.second.c_str()); 176 | 177 | else if(pr.first == "output") 178 | output = pr.second; 179 | else if(pr.first == "embFile") 180 | embFile = pr.second; 181 | 182 | else if(pr.first == "beamSize1") 183 | beamSize1 = atoi(pr.second.c_str()); 184 | else if(pr.first == "beamSize2") 185 | beamSize2 = atoi(pr.second.c_str()); 186 | 187 | else if(pr.first == "poolType") 188 | poolType = atoi(pr.second.c_str()); 189 | } 190 | } 191 | 192 | void showOptions() { 193 | std::cout << "wordCutOff = " << wordCutOff << std::endl; 194 | std::cout << "initRange = " << initRange << std::endl; 195 | std::cout << "maxIter = " << maxIter << std::endl; 196 | std::cout << "batchSize = " << batchSize << 
std::endl; 197 | std::cout << "adaEps = " << adaEps << std::endl; 198 | std::cout << "adaAlpha = " << adaAlpha << std::endl; 199 | std::cout << "regParameter = " << regParameter << std::endl; 200 | std::cout << "dropProb = " << dropProb << std::endl; 201 | 202 | std::cout << "hiddenSize = " << hiddenSize << std::endl; 203 | std::cout << "rnnHiddenSize = " << rnnHiddenSize << std::endl; 204 | 205 | std::cout<<"wordcontext = " << wordcontext << endl; 206 | std::cout << "wordEmbSize = " << wordEmbSize << std::endl; 207 | std::cout<<"otherEmbSize = "< 18 | class CNN { 19 | public: 20 | UniLayer _kernel; 21 | int _outputsize; 22 | int _inputsize; 23 | int _windowsize; 24 | int _kernelinputsize; 25 | int _poolType; 26 | 27 | Tensor _null, _nullLoss, _eg2null; 28 | 29 | AttentionZhouACL2016 _att; 30 | 31 | public: 32 | CNN() { 33 | _outputsize = 0; 34 | _inputsize= 0; 35 | _windowsize= 0; 36 | _kernelinputsize= 0; 37 | _poolType = 0; 38 | } 39 | 40 | // outputsize is the dimension of this CNN 41 | // inputsize is the dimension of one element 'x' of the sequence 42 | // if windowsize is 1, the input of kernel will be x[i-1],x[i],x[i+1] 43 | // poolType: 0-max, 5-att 44 | inline void initial(int outputsize, int inputsize, int windowsize, int poolType=0, int funcType = 0, int seed = 0) { 45 | int kernelInputsize = inputsize + 2*windowsize*inputsize; 46 | 47 | _kernel.initial(outputsize, kernelInputsize, true, seed, funcType); 48 | 49 | _null = NewTensor(Shape2(1, inputsize), d_zero); 50 | _nullLoss = NewTensor(Shape2(1, inputsize), d_zero); 51 | _eg2null = NewTensor(Shape2(1, inputsize), d_zero); 52 | 53 | _outputsize = outputsize; 54 | _inputsize = inputsize; 55 | _windowsize = windowsize; 56 | _kernelinputsize = kernelInputsize; 57 | _poolType = poolType; 58 | 59 | if(_poolType == 5) { 60 | _att.initial(_outputsize, seed+13); 61 | } 62 | } 63 | 64 | inline void release() { 65 | _kernel.release(); 66 | 67 | FreeSpace(&_null); 68 | FreeSpace(&_nullLoss); 69 | 
FreeSpace(&_eg2null); 70 | 71 | if(_poolType == 5) { 72 | _att.release(); 73 | } 74 | } 75 | 76 | inline void ComputeForwardScore(Tensor x, Tensor y, Tensor kernelInputs, Tensor kernelOutputs, Tensor poolIndex) { 77 | y = 0.0; 78 | int seq_size = x.size(0); 79 | if (seq_size == 0) 80 | return; 81 | 82 | // convolution 83 | 84 | 85 | for (int idx = 0; idx < seq_size; idx++) { 86 | 87 | int windowBegin = idx-_windowsize; 88 | int windowEnd = idx+_windowsize; 89 | 90 | int kernelIdx = 0; 91 | for(int i=windowBegin;i<=windowEnd;i++) { 92 | if(i<0 || i>=seq_size) { 93 | for(int j=0;j<_inputsize;j++) { 94 | kernelInputs[idx][0][kernelIdx] = _null[0][j]; 95 | kernelIdx++; 96 | } 97 | } else { 98 | for(int j=0;j<_inputsize;j++) { 99 | kernelInputs[idx][0][kernelIdx] = x[i][0][j]; 100 | kernelIdx++; 101 | } 102 | } 103 | } 104 | 105 | _kernel.ComputeForwardScore(kernelInputs[idx], kernelOutputs[idx]); 106 | 107 | } 108 | 109 | maxpool_forward(kernelOutputs, y, poolIndex); 110 | 111 | } 112 | 113 | inline void ComputeBackwardLoss(Tensor x, Tensor y, 114 | Tensor ly, Tensor lx, 115 | Tensor kernelInputs, Tensor kernelOutputs, 116 | Tensor poolIndex, bool bclear = false) { 117 | int seq_size = x.size(0); 118 | if (seq_size == 0) 119 | return; 120 | 121 | if (bclear) 122 | lx = 0.0; 123 | 124 | Tensor l_kernelOutputs= NewTensor(Shape3(seq_size, 1, _outputsize), d_zero); 125 | pool_backward(ly, poolIndex, l_kernelOutputs); 126 | 127 | Tensor l_kernelInputs = NewTensor(Shape3(seq_size, 1, _kernelinputsize), d_zero); 128 | for (int idx = 0; idx < seq_size; idx++) { 129 | 130 | _kernel.ComputeBackwardLoss(kernelInputs[idx], kernelOutputs[idx], l_kernelOutputs[idx], l_kernelInputs[idx]); 131 | 132 | int windowBegin = idx-_windowsize; 133 | int windowEnd = idx+_windowsize; 134 | 135 | int kernelIdx = 0; 136 | for(int i=windowBegin;i<=windowEnd;i++) { 137 | if(i<0 || i>=seq_size) { 138 | for(int j=0;j<_inputsize;j++) { 139 | _nullLoss[0][j] += l_kernelInputs[idx][0][kernelIdx]; 140 | 
kernelIdx++; 141 | } 142 | } else { 143 | for(int j=0;j<_inputsize;j++) { 144 | lx[i][0][j] += l_kernelInputs[idx][0][kernelIdx]; 145 | kernelIdx++; 146 | } 147 | } 148 | } 149 | 150 | 151 | 152 | } 153 | 154 | FreeSpace(&l_kernelOutputs); 155 | FreeSpace(&l_kernelInputs); 156 | 157 | } 158 | 159 | // if attention, use this one 160 | inline void ComputeForwardScore(Tensor x, Tensor y, 161 | Tensor kernelInputs, Tensor kernelOutputs, 162 | Tensor poolIndex, 163 | Tensor M, 164 | Tensor omegaM, Tensor exp_omegaM, 165 | Tensor alpha, Tensor r 166 | ) { 167 | y = 0.0; 168 | int seq_size = x.size(0); 169 | if (seq_size == 0) 170 | return; 171 | 172 | // convolution 173 | 174 | 175 | for (int idx = 0; idx < seq_size; idx++) { 176 | 177 | int windowBegin = idx-_windowsize; 178 | int windowEnd = idx+_windowsize; 179 | 180 | int kernelIdx = 0; 181 | for(int i=windowBegin;i<=windowEnd;i++) { 182 | if(i<0 || i>=seq_size) { 183 | for(int j=0;j<_inputsize;j++) { 184 | kernelInputs[idx][0][kernelIdx] = _null[0][j]; 185 | kernelIdx++; 186 | } 187 | } else { 188 | for(int j=0;j<_inputsize;j++) { 189 | kernelInputs[idx][0][kernelIdx] = x[i][0][j]; 190 | kernelIdx++; 191 | } 192 | } 193 | } 194 | 195 | _kernel.ComputeForwardScore(kernelInputs[idx], kernelOutputs[idx]); 196 | 197 | } 198 | 199 | _att.ComputeForwardScore(kernelOutputs, M, omegaM, exp_omegaM, alpha, r, y); 200 | 201 | } 202 | 203 | inline void ComputeBackwardLoss(Tensor x, Tensor y, 204 | Tensor ly, Tensor lx, 205 | Tensor kernelInputs, Tensor kernelOutputs, 206 | Tensor poolIndex, 207 | Tensor M, Tensor omegaM, Tensor exp_omegaM, 208 | Tensor alpha, Tensor r, 209 | bool bclear = false) { 210 | int seq_size = x.size(0); 211 | if (seq_size == 0) 212 | return; 213 | 214 | if (bclear) 215 | lx = 0.0; 216 | 217 | Tensor l_kernelOutputs= NewTensor(Shape3(seq_size, 1, _outputsize), d_zero); 218 | _att.ComputeBackwardLoss(kernelOutputs, M, omegaM, exp_omegaM, alpha, r, 219 | y, ly, l_kernelOutputs); 220 | 221 | Tensor 
l_kernelInputs = NewTensor(Shape3(seq_size, 1, _kernelinputsize), d_zero); 222 | for (int idx = 0; idx < seq_size; idx++) { 223 | 224 | _kernel.ComputeBackwardLoss(kernelInputs[idx], kernelOutputs[idx], l_kernelOutputs[idx], l_kernelInputs[idx]); 225 | 226 | int windowBegin = idx-_windowsize; 227 | int windowEnd = idx+_windowsize; 228 | 229 | int kernelIdx = 0; 230 | for(int i=windowBegin;i<=windowEnd;i++) { 231 | if(i<0 || i>=seq_size) { 232 | for(int j=0;j<_inputsize;j++) { 233 | _nullLoss[0][j] += l_kernelInputs[idx][0][kernelIdx]; 234 | kernelIdx++; 235 | } 236 | } else { 237 | for(int j=0;j<_inputsize;j++) { 238 | lx[i][0][j] += l_kernelInputs[idx][0][kernelIdx]; 239 | kernelIdx++; 240 | } 241 | } 242 | } 243 | 244 | 245 | 246 | } 247 | 248 | FreeSpace(&l_kernelOutputs); 249 | FreeSpace(&l_kernelInputs); 250 | 251 | } 252 | 253 | inline void updateAdaGrad(dtype regularizationWeight, dtype adaAlpha, dtype adaEps) { 254 | _nullLoss = _nullLoss + _null * regularizationWeight; 255 | _eg2null = _eg2null + _nullLoss * _nullLoss; 256 | _null = _null - _nullLoss * adaAlpha / F(_eg2null + adaEps); 257 | 258 | _kernel.updateAdaGrad(regularizationWeight, adaAlpha, adaEps); 259 | 260 | if(_poolType == 5) { 261 | _att.updateAdaGrad(regularizationWeight, adaAlpha, adaEps); 262 | } 263 | 264 | clearGrad(); 265 | } 266 | 267 | inline void clearGrad() { 268 | _nullLoss = 0; 269 | } 270 | 271 | }; 272 | 273 | #endif /* CNN_H_ */ 274 | -------------------------------------------------------------------------------- /bb3/CNN.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef CNN_H_ 3 | #define CNN_H_ 4 | 5 | // A standard convolutional neural network 6 | 7 | #include "tensor.h" 8 | #include "MyLib.h" 9 | #include "Utiltensor.h" 10 | #include "UniLayer.h" 11 | #include "Attention_ZhouACL2016.h" 12 | 13 | using namespace mshadow; 14 | using namespace mshadow::expr; 15 | using namespace mshadow::utils; 16 | 17 | template 18 | class 
CNN { 19 | public: 20 | UniLayer _kernel; 21 | int _outputsize; 22 | int _inputsize; 23 | int _windowsize; 24 | int _kernelinputsize; 25 | int _poolType; 26 | 27 | Tensor _null, _nullLoss, _eg2null; 28 | 29 | AttentionZhouACL2016 _att; 30 | 31 | public: 32 | CNN() { 33 | _outputsize = 0; 34 | _inputsize= 0; 35 | _windowsize= 0; 36 | _kernelinputsize= 0; 37 | _poolType = 0; 38 | } 39 | 40 | // outputsize is the dimension of this CNN 41 | // inputsize is the dimension of one element 'x' of the sequence 42 | // if windowsize is 1, the input of kernel will be x[i-1],x[i],x[i+1] 43 | // poolType: 0-max, 5-att 44 | inline void initial(int outputsize, int inputsize, int windowsize, int poolType=0, int funcType = 0, int seed = 0) { 45 | int kernelInputsize = inputsize + 2*windowsize*inputsize; 46 | 47 | _kernel.initial(outputsize, kernelInputsize, true, seed, funcType); 48 | 49 | _null = NewTensor(Shape2(1, inputsize), d_zero); 50 | _nullLoss = NewTensor(Shape2(1, inputsize), d_zero); 51 | _eg2null = NewTensor(Shape2(1, inputsize), d_zero); 52 | 53 | _outputsize = outputsize; 54 | _inputsize = inputsize; 55 | _windowsize = windowsize; 56 | _kernelinputsize = kernelInputsize; 57 | _poolType = poolType; 58 | 59 | if(_poolType == 5) { 60 | _att.initial(_outputsize, seed+13); 61 | } 62 | } 63 | 64 | inline void release() { 65 | _kernel.release(); 66 | 67 | FreeSpace(&_null); 68 | FreeSpace(&_nullLoss); 69 | FreeSpace(&_eg2null); 70 | 71 | if(_poolType == 5) { 72 | _att.release(); 73 | } 74 | } 75 | 76 | inline void ComputeForwardScore(Tensor x, Tensor y, Tensor kernelInputs, Tensor kernelOutputs, Tensor poolIndex) { 77 | y = 0.0; 78 | int seq_size = x.size(0); 79 | if (seq_size == 0) 80 | return; 81 | 82 | // convolution 83 | 84 | 85 | for (int idx = 0; idx < seq_size; idx++) { 86 | 87 | int windowBegin = idx-_windowsize; 88 | int windowEnd = idx+_windowsize; 89 | 90 | int kernelIdx = 0; 91 | for(int i=windowBegin;i<=windowEnd;i++) { 92 | if(i<0 || i>=seq_size) { 93 | for(int 
j=0;j<_inputsize;j++) { 94 | kernelInputs[idx][0][kernelIdx] = _null[0][j]; 95 | kernelIdx++; 96 | } 97 | } else { 98 | for(int j=0;j<_inputsize;j++) { 99 | kernelInputs[idx][0][kernelIdx] = x[i][0][j]; 100 | kernelIdx++; 101 | } 102 | } 103 | } 104 | 105 | _kernel.ComputeForwardScore(kernelInputs[idx], kernelOutputs[idx]); 106 | 107 | } 108 | 109 | maxpool_forward(kernelOutputs, y, poolIndex); 110 | 111 | } 112 | 113 | inline void ComputeBackwardLoss(Tensor x, Tensor y, 114 | Tensor ly, Tensor lx, 115 | Tensor kernelInputs, Tensor kernelOutputs, 116 | Tensor poolIndex, bool bclear = false) { 117 | int seq_size = x.size(0); 118 | if (seq_size == 0) 119 | return; 120 | 121 | if (bclear) 122 | lx = 0.0; 123 | 124 | Tensor l_kernelOutputs= NewTensor(Shape3(seq_size, 1, _outputsize), d_zero); 125 | pool_backward(ly, poolIndex, l_kernelOutputs); 126 | 127 | Tensor l_kernelInputs = NewTensor(Shape3(seq_size, 1, _kernelinputsize), d_zero); 128 | for (int idx = 0; idx < seq_size; idx++) { 129 | 130 | _kernel.ComputeBackwardLoss(kernelInputs[idx], kernelOutputs[idx], l_kernelOutputs[idx], l_kernelInputs[idx]); 131 | 132 | int windowBegin = idx-_windowsize; 133 | int windowEnd = idx+_windowsize; 134 | 135 | int kernelIdx = 0; 136 | for(int i=windowBegin;i<=windowEnd;i++) { 137 | if(i<0 || i>=seq_size) { 138 | for(int j=0;j<_inputsize;j++) { 139 | _nullLoss[0][j] += l_kernelInputs[idx][0][kernelIdx]; 140 | kernelIdx++; 141 | } 142 | } else { 143 | for(int j=0;j<_inputsize;j++) { 144 | lx[i][0][j] += l_kernelInputs[idx][0][kernelIdx]; 145 | kernelIdx++; 146 | } 147 | } 148 | } 149 | 150 | 151 | 152 | } 153 | 154 | FreeSpace(&l_kernelOutputs); 155 | FreeSpace(&l_kernelInputs); 156 | 157 | } 158 | 159 | // if attention, use this one 160 | inline void ComputeForwardScore(Tensor x, Tensor y, 161 | Tensor kernelInputs, Tensor kernelOutputs, 162 | Tensor poolIndex, 163 | Tensor M, 164 | Tensor omegaM, Tensor exp_omegaM, 165 | Tensor alpha, Tensor r 166 | ) { 167 | y = 0.0; 168 | int 
seq_size = x.size(0); 169 | if (seq_size == 0) 170 | return; 171 | 172 | // convolution 173 | 174 | 175 | for (int idx = 0; idx < seq_size; idx++) { 176 | 177 | int windowBegin = idx-_windowsize; 178 | int windowEnd = idx+_windowsize; 179 | 180 | int kernelIdx = 0; 181 | for(int i=windowBegin;i<=windowEnd;i++) { 182 | if(i<0 || i>=seq_size) { 183 | for(int j=0;j<_inputsize;j++) { 184 | kernelInputs[idx][0][kernelIdx] = _null[0][j]; 185 | kernelIdx++; 186 | } 187 | } else { 188 | for(int j=0;j<_inputsize;j++) { 189 | kernelInputs[idx][0][kernelIdx] = x[i][0][j]; 190 | kernelIdx++; 191 | } 192 | } 193 | } 194 | 195 | _kernel.ComputeForwardScore(kernelInputs[idx], kernelOutputs[idx]); 196 | 197 | } 198 | 199 | _att.ComputeForwardScore(kernelOutputs, M, omegaM, exp_omegaM, alpha, r, y); 200 | 201 | } 202 | 203 | inline void ComputeBackwardLoss(Tensor x, Tensor y, 204 | Tensor ly, Tensor lx, 205 | Tensor kernelInputs, Tensor kernelOutputs, 206 | Tensor poolIndex, 207 | Tensor M, Tensor omegaM, Tensor exp_omegaM, 208 | Tensor alpha, Tensor r, 209 | bool bclear = false) { 210 | int seq_size = x.size(0); 211 | if (seq_size == 0) 212 | return; 213 | 214 | if (bclear) 215 | lx = 0.0; 216 | 217 | Tensor l_kernelOutputs= NewTensor(Shape3(seq_size, 1, _outputsize), d_zero); 218 | _att.ComputeBackwardLoss(kernelOutputs, M, omegaM, exp_omegaM, alpha, r, 219 | y, ly, l_kernelOutputs); 220 | 221 | Tensor l_kernelInputs = NewTensor(Shape3(seq_size, 1, _kernelinputsize), d_zero); 222 | for (int idx = 0; idx < seq_size; idx++) { 223 | 224 | _kernel.ComputeBackwardLoss(kernelInputs[idx], kernelOutputs[idx], l_kernelOutputs[idx], l_kernelInputs[idx]); 225 | 226 | int windowBegin = idx-_windowsize; 227 | int windowEnd = idx+_windowsize; 228 | 229 | int kernelIdx = 0; 230 | for(int i=windowBegin;i<=windowEnd;i++) { 231 | if(i<0 || i>=seq_size) { 232 | for(int j=0;j<_inputsize;j++) { 233 | _nullLoss[0][j] += l_kernelInputs[idx][0][kernelIdx]; 234 | kernelIdx++; 235 | } 236 | } else { 237 | 
for(int j=0;j<_inputsize;j++) { 238 | lx[i][0][j] += l_kernelInputs[idx][0][kernelIdx]; 239 | kernelIdx++; 240 | } 241 | } 242 | } 243 | 244 | 245 | 246 | } 247 | 248 | FreeSpace(&l_kernelOutputs); 249 | FreeSpace(&l_kernelInputs); 250 | 251 | } 252 | 253 | inline void updateAdaGrad(dtype regularizationWeight, dtype adaAlpha, dtype adaEps) { 254 | _nullLoss = _nullLoss + _null * regularizationWeight; 255 | _eg2null = _eg2null + _nullLoss * _nullLoss; 256 | _null = _null - _nullLoss * adaAlpha / F(_eg2null + adaEps); 257 | 258 | _kernel.updateAdaGrad(regularizationWeight, adaAlpha, adaEps); 259 | 260 | if(_poolType == 5) { 261 | _att.updateAdaGrad(regularizationWeight, adaAlpha, adaEps); 262 | } 263 | 264 | clearGrad(); 265 | } 266 | 267 | inline void clearGrad() { 268 | _nullLoss = 0; 269 | } 270 | 271 | }; 272 | 273 | #endif /* CNN_H_ */ 274 | -------------------------------------------------------------------------------- /ade/Attention_ZhouACL2016.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef SRC_AttentionZhouACL2016_H_ 3 | #define SRC_AttentionZhouACL2016_H_ 4 | #include "tensor.h" 5 | 6 | #include "BiLayer.h" 7 | #include "MyLib.h" 8 | #include "Utiltensor.h" 9 | #include "Pooling.h" 10 | #include "UniLayer.h" 11 | 12 | using namespace mshadow; 13 | using namespace mshadow::expr; 14 | using namespace mshadow::utils; 15 | 16 | // This is re-implementment of the work of Zhou et al. (2016) in ACL. 
17 | // Attention-Based Bidirectional LSTM Networks for Relation Classification 18 | 19 | template 20 | class AttentionZhouACL2016 { 21 | 22 | public: 23 | Tensor _omega; 24 | Tensor _gradOmega; 25 | Tensor _eg2Omega; 26 | 27 | int _dw = 0; 28 | 29 | 30 | AttentionZhouACL2016() { 31 | } 32 | 33 | inline void initial(int inputAndOutputSize, int seed = 0) { 34 | _dw = inputAndOutputSize; 35 | 36 | 37 | dtype bound = sqrt(6.0 / (_dw + 1)); 38 | 39 | _omega = NewTensor(Shape2(1, _dw), d_zero); 40 | _gradOmega = NewTensor(Shape2(1, _dw), d_zero); 41 | _eg2Omega = NewTensor(Shape2(1, _dw), d_zero); 42 | 43 | 44 | random(_omega, -1.0 * bound, 1.0 * bound, seed); 45 | } 46 | 47 | 48 | inline void release() { 49 | FreeSpace(&_omega); 50 | FreeSpace(&_gradOmega); 51 | FreeSpace(&_eg2Omega); 52 | } 53 | 54 | 55 | public: 56 | 57 | // assume that all the vector is row vector, so the formulas should be adjusted. 58 | inline void ComputeForwardScore(std::vector >& H, std::vector >& M, 59 | Tensor omegaM, Tensor exp_omegaM, 60 | Tensor alpha, Tensor r, 61 | Tensor hStar) { 62 | 63 | int seq_size = H.size(); 64 | if(seq_size == 0) return; 65 | 66 | 67 | 68 | for(int i=0;i(H[i]); // M = tanh(H) 70 | for(int j=0;j<_dw;j++) 71 | omegaM[0][i] += _omega[0][j]*M[i][0][j]; // dot(_omega, M[i].T()); // t(w)*M 72 | } 73 | 74 | // alpha = softmax(t(w)*M) 75 | /* int optLabel = -1; 76 | for (int i = 0; i < seq_size; ++i) { 77 | if (optLabel < 0 || omegaM[0][i] > omegaM[0][optLabel]) 78 | optLabel = i; 79 | }*/ 80 | dtype sum = 0; 81 | //dtype maxScore = omegaM[0][optLabel]; 82 | for (int i = 0; i < seq_size; ++i) { 83 | exp_omegaM[0][i] = exp(omegaM[0][i] /*- maxScore*/); 84 | sum += exp_omegaM[0][i]; 85 | } 86 | for (int i = 0; i < seq_size; ++i) { 87 | alpha[0][i] = exp_omegaM[0][i]/sum; 88 | } 89 | 90 | // r = H*t(alpha) 91 | for(int i = 0; i < _dw; ++i) { 92 | for(int j=0; j(r); 99 | 100 | /* for(int i = 0; i < _dw; ++i) { 101 | for(int j=0;j >& H, std::vector >& M, 110 | Tensor omegaM, 
Tensor exp_omegaM, 111 | Tensor alpha, Tensor r, 112 | Tensor hStar, Tensor lhStar, 113 | std::vector >& lH, bool bclear = false) { 114 | 115 | int seq_size = H.size(); 116 | if(seq_size == 0) return; 117 | 118 | if(bclear){ 119 | for (int idx = 0; idx < seq_size; idx++) { 120 | lH[idx] = 0.0; 121 | } 122 | } 123 | 124 | 125 | 126 | Tensor lr = NewTensor(Shape2(r.size(0), r.size(1)), d_zero); 127 | lr += lhStar * F(hStar); 128 | 129 | Tensor lalpha = NewTensor(Shape2(alpha.size(0), alpha.size(1)), d_zero); 130 | for(int i = 0; i < _dw; ++i) { 131 | for(int j=0; j lOmegaM = NewTensor(Shape2(omegaM.size(0), omegaM.size(1)), d_zero); 138 | for (int i = 0; i < seq_size; ++i) { 139 | for(int j=0;j > lM(seq_size); 148 | for (int idx = 0; idx < seq_size; idx++) { 149 | lM[idx] = NewTensor(Shape2(M[idx].size(0), M[idx].size(1)), d_zero); 150 | } 151 | for(int i=0;i(M[i]); 155 | } 156 | 157 | 158 | FreeSpace(&lr); 159 | FreeSpace(&lalpha); 160 | FreeSpace(&lOmegaM); 161 | for (int idx = 0; idx < seq_size; idx++) { 162 | FreeSpace(&(lM[idx])); 163 | } 164 | 165 | /* for(int i = 0; i < _dw; ++i) { 166 | for(int j=0;j H, Tensor M, 176 | Tensor omegaM, Tensor exp_omegaM, 177 | Tensor alpha, Tensor r, 178 | Tensor hStar) { 179 | 180 | int seq_size = H.size(0); 181 | if(seq_size == 0) return; 182 | 183 | 184 | 185 | for(int i=0;i(H[i]); // M = tanh(H) 187 | for(int j=0;j<_dw;j++) 188 | omegaM[0][i] += _omega[0][j]*M[i][0][j]; // dot(_omega, M[i].T()); // t(w)*M 189 | } 190 | 191 | // alpha = softmax(t(w)*M) 192 | dtype sum = 0; 193 | for (int i = 0; i < seq_size; ++i) { 194 | exp_omegaM[0][i] = exp(omegaM[0][i] ); 195 | sum += exp_omegaM[0][i]; 196 | } 197 | for (int i = 0; i < seq_size; ++i) { 198 | alpha[0][i] = exp_omegaM[0][i]/sum; 199 | } 200 | 201 | // r = H*t(alpha) 202 | for(int i = 0; i < _dw; ++i) { 203 | for(int j=0; j(r); 210 | 211 | 212 | } 213 | 214 | 215 | 216 | inline void ComputeBackwardLoss(Tensor H, Tensor M, 217 | Tensor omegaM, Tensor exp_omegaM, 218 | Tensor 
alpha, Tensor r, 219 | Tensor hStar, Tensor lhStar, 220 | Tensor lH, bool bclear = false) { 221 | 222 | int seq_size = H.size(0); 223 | if(seq_size == 0) return; 224 | 225 | if(bclear){ 226 | for (int idx = 0; idx < seq_size; idx++) { 227 | lH[idx] = 0.0; 228 | } 229 | } 230 | 231 | 232 | 233 | Tensor lr = NewTensor(Shape2(r.size(0), r.size(1)), d_zero); 234 | lr += lhStar * F(hStar); 235 | 236 | Tensor lalpha = NewTensor(Shape2(alpha.size(0), alpha.size(1)), d_zero); 237 | for(int i = 0; i < _dw; ++i) { 238 | for(int j=0; j lOmegaM = NewTensor(Shape2(omegaM.size(0), omegaM.size(1)), d_zero); 245 | for (int i = 0; i < seq_size; ++i) { 246 | for(int j=0;j > lM(seq_size); 255 | for (int idx = 0; idx < seq_size; idx++) { 256 | lM[idx] = NewTensor(Shape2(M[idx].size(0), M[idx].size(1)), d_zero); 257 | } 258 | for(int i=0;i(M[i]); 262 | } 263 | 264 | 265 | FreeSpace(&lr); 266 | FreeSpace(&lalpha); 267 | FreeSpace(&lOmegaM); 268 | for (int idx = 0; idx < seq_size; idx++) { 269 | FreeSpace(&(lM[idx])); 270 | } 271 | 272 | 273 | } 274 | 275 | 276 | 277 | inline void updateAdaGrad(dtype regularizationWeight, dtype adaAlpha, dtype adaEps) { 278 | _gradOmega = _gradOmega + _omega * regularizationWeight; 279 | _eg2Omega= _eg2Omega + _gradOmega * _gradOmega; 280 | _omega = _omega - _gradOmega * adaAlpha / F(_eg2Omega + adaEps); 281 | 282 | clearGrad(); 283 | } 284 | 285 | inline void clearGrad() { 286 | _gradOmega = 0; 287 | } 288 | 289 | 290 | }; 291 | 292 | #endif 293 | -------------------------------------------------------------------------------- /bb3/Attention_ZhouACL2016.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef SRC_AttentionZhouACL2016_H_ 3 | #define SRC_AttentionZhouACL2016_H_ 4 | #include "tensor.h" 5 | 6 | #include "BiLayer.h" 7 | #include "MyLib.h" 8 | #include "Utiltensor.h" 9 | #include "Pooling.h" 10 | #include "UniLayer.h" 11 | 12 | using namespace mshadow; 13 | using namespace mshadow::expr; 14 | using 
namespace mshadow::utils; 15 | 16 | // This is re-implementment of the work of Zhou et al. (2016) in ACL. 17 | // Attention-Based Bidirectional LSTM Networks for Relation Classification 18 | 19 | template 20 | class AttentionZhouACL2016 { 21 | 22 | public: 23 | Tensor _omega; 24 | Tensor _gradOmega; 25 | Tensor _eg2Omega; 26 | 27 | int _dw = 0; 28 | 29 | 30 | AttentionZhouACL2016() { 31 | } 32 | 33 | inline void initial(int inputAndOutputSize, int seed = 0) { 34 | _dw = inputAndOutputSize; 35 | 36 | 37 | dtype bound = sqrt(6.0 / (_dw + 1)); 38 | 39 | _omega = NewTensor(Shape2(1, _dw), d_zero); 40 | _gradOmega = NewTensor(Shape2(1, _dw), d_zero); 41 | _eg2Omega = NewTensor(Shape2(1, _dw), d_zero); 42 | 43 | 44 | random(_omega, -1.0 * bound, 1.0 * bound, seed); 45 | } 46 | 47 | 48 | inline void release() { 49 | FreeSpace(&_omega); 50 | FreeSpace(&_gradOmega); 51 | FreeSpace(&_eg2Omega); 52 | } 53 | 54 | 55 | public: 56 | 57 | // assume that all the vector is row vector, so the formulas should be adjusted. 
58 | inline void ComputeForwardScore(std::vector >& H, std::vector >& M, 59 | Tensor omegaM, Tensor exp_omegaM, 60 | Tensor alpha, Tensor r, 61 | Tensor hStar) { 62 | 63 | int seq_size = H.size(); 64 | if(seq_size == 0) return; 65 | 66 | 67 | 68 | for(int i=0;i(H[i]); // M = tanh(H) 70 | for(int j=0;j<_dw;j++) 71 | omegaM[0][i] += _omega[0][j]*M[i][0][j]; // dot(_omega, M[i].T()); // t(w)*M 72 | } 73 | 74 | // alpha = softmax(t(w)*M) 75 | /* int optLabel = -1; 76 | for (int i = 0; i < seq_size; ++i) { 77 | if (optLabel < 0 || omegaM[0][i] > omegaM[0][optLabel]) 78 | optLabel = i; 79 | }*/ 80 | dtype sum = 0; 81 | //dtype maxScore = omegaM[0][optLabel]; 82 | for (int i = 0; i < seq_size; ++i) { 83 | exp_omegaM[0][i] = exp(omegaM[0][i] /*- maxScore*/); 84 | sum += exp_omegaM[0][i]; 85 | } 86 | for (int i = 0; i < seq_size; ++i) { 87 | alpha[0][i] = exp_omegaM[0][i]/sum; 88 | } 89 | 90 | // r = H*t(alpha) 91 | for(int i = 0; i < _dw; ++i) { 92 | for(int j=0; j(r); 99 | 100 | /* for(int i = 0; i < _dw; ++i) { 101 | for(int j=0;j >& H, std::vector >& M, 110 | Tensor omegaM, Tensor exp_omegaM, 111 | Tensor alpha, Tensor r, 112 | Tensor hStar, Tensor lhStar, 113 | std::vector >& lH, bool bclear = false) { 114 | 115 | int seq_size = H.size(); 116 | if(seq_size == 0) return; 117 | 118 | if(bclear){ 119 | for (int idx = 0; idx < seq_size; idx++) { 120 | lH[idx] = 0.0; 121 | } 122 | } 123 | 124 | 125 | 126 | Tensor lr = NewTensor(Shape2(r.size(0), r.size(1)), d_zero); 127 | lr += lhStar * F(hStar); 128 | 129 | Tensor lalpha = NewTensor(Shape2(alpha.size(0), alpha.size(1)), d_zero); 130 | for(int i = 0; i < _dw; ++i) { 131 | for(int j=0; j lOmegaM = NewTensor(Shape2(omegaM.size(0), omegaM.size(1)), d_zero); 138 | for (int i = 0; i < seq_size; ++i) { 139 | for(int j=0;j > lM(seq_size); 148 | for (int idx = 0; idx < seq_size; idx++) { 149 | lM[idx] = NewTensor(Shape2(M[idx].size(0), M[idx].size(1)), d_zero); 150 | } 151 | for(int i=0;i(M[i]); 155 | } 156 | 157 | 158 | 
FreeSpace(&lr); 159 | FreeSpace(&lalpha); 160 | FreeSpace(&lOmegaM); 161 | for (int idx = 0; idx < seq_size; idx++) { 162 | FreeSpace(&(lM[idx])); 163 | } 164 | 165 | /* for(int i = 0; i < _dw; ++i) { 166 | for(int j=0;j H, Tensor M, 176 | Tensor omegaM, Tensor exp_omegaM, 177 | Tensor alpha, Tensor r, 178 | Tensor hStar) { 179 | 180 | int seq_size = H.size(0); 181 | if(seq_size == 0) return; 182 | 183 | 184 | 185 | for(int i=0;i(H[i]); // M = tanh(H) 187 | for(int j=0;j<_dw;j++) 188 | omegaM[0][i] += _omega[0][j]*M[i][0][j]; // dot(_omega, M[i].T()); // t(w)*M 189 | } 190 | 191 | // alpha = softmax(t(w)*M) 192 | dtype sum = 0; 193 | for (int i = 0; i < seq_size; ++i) { 194 | exp_omegaM[0][i] = exp(omegaM[0][i] ); 195 | sum += exp_omegaM[0][i]; 196 | } 197 | for (int i = 0; i < seq_size; ++i) { 198 | alpha[0][i] = exp_omegaM[0][i]/sum; 199 | } 200 | 201 | // r = H*t(alpha) 202 | for(int i = 0; i < _dw; ++i) { 203 | for(int j=0; j(r); 210 | 211 | 212 | } 213 | 214 | 215 | 216 | inline void ComputeBackwardLoss(Tensor H, Tensor M, 217 | Tensor omegaM, Tensor exp_omegaM, 218 | Tensor alpha, Tensor r, 219 | Tensor hStar, Tensor lhStar, 220 | Tensor lH, bool bclear = false) { 221 | 222 | int seq_size = H.size(0); 223 | if(seq_size == 0) return; 224 | 225 | if(bclear){ 226 | for (int idx = 0; idx < seq_size; idx++) { 227 | lH[idx] = 0.0; 228 | } 229 | } 230 | 231 | 232 | 233 | Tensor lr = NewTensor(Shape2(r.size(0), r.size(1)), d_zero); 234 | lr += lhStar * F(hStar); 235 | 236 | Tensor lalpha = NewTensor(Shape2(alpha.size(0), alpha.size(1)), d_zero); 237 | for(int i = 0; i < _dw; ++i) { 238 | for(int j=0; j lOmegaM = NewTensor(Shape2(omegaM.size(0), omegaM.size(1)), d_zero); 245 | for (int i = 0; i < seq_size; ++i) { 246 | for(int j=0;j > lM(seq_size); 255 | for (int idx = 0; idx < seq_size; idx++) { 256 | lM[idx] = NewTensor(Shape2(M[idx].size(0), M[idx].size(1)), d_zero); 257 | } 258 | for(int i=0;i(M[i]); 262 | } 263 | 264 | 265 | FreeSpace(&lr); 266 | 
FreeSpace(&lalpha); 267 | FreeSpace(&lOmegaM); 268 | for (int idx = 0; idx < seq_size; idx++) { 269 | FreeSpace(&(lM[idx])); 270 | } 271 | 272 | 273 | } 274 | 275 | 276 | 277 | inline void updateAdaGrad(dtype regularizationWeight, dtype adaAlpha, dtype adaEps) { 278 | _gradOmega = _gradOmega + _omega * regularizationWeight; 279 | _eg2Omega= _eg2Omega + _gradOmega * _gradOmega; 280 | _omega = _omega - _gradOmega * adaAlpha / F(_eg2Omega + adaEps); 281 | 282 | clearGrad(); 283 | } 284 | 285 | inline void clearGrad() { 286 | _gradOmega = 0; 287 | } 288 | 289 | 290 | }; 291 | 292 | #endif 293 | -------------------------------------------------------------------------------- /ade/utils.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef UTILS_H_ 3 | #define UTILS_H_ 4 | 5 | 6 | #include 7 | #include 8 | #include "Word2Vec.h" 9 | #include "Utf.h" 10 | #include "Entity.h" 11 | #include "Relation.h" 12 | #include "Token.h" 13 | #include "FoxUtil.h" 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include "ADEsentence.h" 20 | #include "Sent.h" 21 | 22 | using namespace std; 23 | 24 | #define USE_IMP 0 25 | 26 | // schema BILOU, three entity types (Bacteria,Habitat,Geographical) 27 | #define TYPE_Disease "Disease" 28 | #define TYPE_Chemical "Chemical" 29 | #define MAX_ENTITY 9 30 | #define B_Disease "B_Disease" 31 | #define I_Disease "I_Disease" 32 | #define L_Disease "L_Disease" 33 | #define U_Disease "U_Disease" 34 | #define B_Chemical "B_Chemical" 35 | #define I_Chemical "I_Chemical" 36 | #define L_Chemical "L_Chemical" 37 | #define U_Chemical "U_Chemical" 38 | #define OTHER "O" 39 | 40 | 41 | void appendEntity(const fox::Token& token, Entity& entity) { 42 | int whitespacetoAdd = token.begin-entity.end; 43 | for(int j=0;j > & groups) { 162 | 163 | struct dirent** namelist = NULL; 164 | int total = scandir(dirPath.c_str(), &namelist, 0, alphasort); 165 | 166 | for(int i=0;id_type == 8) { 169 | 
//file 170 | if(namelist[i]->d_name[0]=='.') 171 | continue; 172 | 173 | string filePath = dirPath; 174 | filePath += "/"; 175 | filePath += namelist[i]->d_name; 176 | string fileName = namelist[i]->d_name; 177 | 178 | ifstream ifs; 179 | ifs.open(filePath.c_str()); 180 | fox::Sent sent; 181 | string line; 182 | vector group; 183 | 184 | while(getline(ifs, line)) { 185 | if(line.empty()){ 186 | // new line 187 | if(sent.tokens.size()!=0) { 188 | group.push_back(sent); 189 | sent.tokens.clear(); 190 | } 191 | } else { 192 | vector splitted; 193 | fox::split_bychar(line, splitted, '\t'); 194 | fox::Token token; 195 | token.word = splitted[0]; 196 | token.begin = atoi(splitted[1].c_str()); 197 | token.end = atoi(splitted[2].c_str()); 198 | token.pos = splitted[3]; 199 | token.lemma = splitted[4]; 200 | token.depGov = atoi(splitted[5].c_str()); 201 | token.depType = splitted[6]; 202 | sent.tokens.push_back(token); 203 | } 204 | 205 | 206 | } 207 | 208 | ifs.close(); 209 | 210 | groups.push_back(group); 211 | } 212 | } 213 | 214 | } 215 | 216 | 217 | void loadAnnotatedFile(const string& dirPath, vector< vector > & groups) 218 | { 219 | struct dirent** namelist = NULL; 220 | int total = scandir(dirPath.c_str(), &namelist, 0, alphasort); 221 | 222 | 223 | for(int i=0;id_type == 8) { 226 | //file 227 | if(namelist[i]->d_name[0]=='.') 228 | continue; 229 | 230 | string filePath = dirPath; 231 | filePath += "/"; 232 | filePath += namelist[i]->d_name; 233 | string fileName = namelist[i]->d_name; 234 | 235 | vector group; 236 | ifstream ifs; 237 | ifs.open(filePath.c_str()); 238 | string line; 239 | ADEsentence sentence; 240 | 241 | while(getline(ifs, line)) { 242 | if(line.empty()) { 243 | group.push_back(sentence); 244 | sentence.clear(); 245 | } else { 246 | vector splitted; 247 | fox::split_bychar(line, splitted, '\t'); 248 | if(splitted[0]=="offset") { 249 | sentence.offset = atoi(splitted[1].c_str()); 250 | } else if(splitted[0]=="EN") { 251 | Entity entity; 252 | 
entity.text = splitted[1]; 253 | entity.type = splitted[2]; 254 | entity.begin = atoi(splitted[3].c_str()); 255 | entity.end = atoi(splitted[4].c_str()); 256 | sentence.entities.push_back(entity); 257 | } else if(splitted[0]=="ADE") { 258 | Entity entity1; 259 | entity1.text = splitted[1]; 260 | entity1.type = TYPE_Disease; 261 | entity1.begin = atoi(splitted[2].c_str()); 262 | entity1.end = atoi(splitted[3].c_str()); 263 | Entity entity2; 264 | entity2.text = splitted[4]; 265 | entity2.type = TYPE_Chemical; 266 | entity2.begin = atoi(splitted[5].c_str()); 267 | entity2.end = atoi(splitted[6].c_str()); 268 | Relation relation; 269 | relation.entity1 = entity1; 270 | relation.entity2 = entity2; 271 | sentence.relations.push_back(relation); 272 | } else { 273 | sentence.text = splitted[0]; 274 | } 275 | } 276 | 277 | } 278 | 279 | ifs.close(); 280 | 281 | groups.push_back(group); 282 | 283 | } 284 | } 285 | 286 | 287 | 288 | 289 | } 290 | 291 | 292 | bool isTokenBeforeEntity(const fox::Token& tok, const Entity& entity) { 293 | if(tok.beginentity.end) 302 | return true; 303 | else 304 | return false; 305 | 306 | } else { 307 | if(tok.end>entity.end2) 308 | return true; 309 | else 310 | return false; 311 | 312 | } 313 | 314 | } 315 | 316 | 317 | string isTokenInEntity(const fox::Token& tok, const Entity& entity) { 318 | 319 | if(tok.begin==entity.begin && tok.end==entity.end) 320 | return "U"; 321 | else if(tok.begin==entity.begin) 322 | return "B"; 323 | else if(tok.end==entity.end) 324 | return "L"; 325 | else if(tok.begin>entity.begin && tok.end=entity.begin && tok.end<=entity.end) 336 | return true; 337 | else 338 | return false; 339 | } else { 340 | 341 | if((tok.begin>=entity.begin && tok.end<=entity.end) || 342 | (tok.begin>=entity.begin2 && tok.end<=entity.end2)) 343 | return true; 344 | else 345 | return false; 346 | 347 | } 348 | } 349 | 350 | bool isTokenBetweenTwoEntities(const fox::Token& tok, const Entity& former, const Entity& latter) { 351 | 352 | 
if(former.end2 == -1) { 353 | if(tok.begin>=former.end && tok.end<=latter.begin) 354 | return true; 355 | else 356 | return false; 357 | } else { 358 | if(tok.begin>=former.end2 && tok.end<=latter.begin) 359 | return true; 360 | else 361 | return false; 362 | } 363 | } 364 | 365 | 366 | void deleteEntity(vector& entities, const Entity& target) 367 | { 368 | vector::iterator iter = entities.begin(); 369 | for(;iter!=entities.end();iter++) { 370 | if((*iter).equals(target)) { 371 | break; 372 | } 373 | } 374 | if(iter!=entities.end()) { 375 | entities.erase(iter); 376 | } 377 | } 378 | 379 | int containsEntity(vector& source, const Entity& target) { 380 | 381 | for(int i=0;i& source, const Relation& target) { 391 | 392 | for(int i=0;i& source, const Entity& target) { 402 | 403 | for(int i=0;i 7 | #include 8 | #include "Word2Vec.h" 9 | #include "Utf.h" 10 | #include "Entity.h" 11 | #include "Token.h" 12 | #include "FoxUtil.h" 13 | #include 14 | #include 15 | #include 16 | #include "Document.h" 17 | #include "NerExample.h" 18 | #include 19 | #include 20 | 21 | using namespace std; 22 | 23 | #define USE_IMP 0 24 | 25 | // schema BILOU, three entity types (Bacteria,Habitat,Geographical) 26 | #define TYPE_Bac "Bacteria" 27 | #define TYPE_Hab "Habitat" 28 | #define TYPE_Geo "Geographical" 29 | #define MAX_ENTITY 13 30 | #define B_Bacteria "B_Bacteria" 31 | #define I_Bacteria "I_Bacteria" 32 | #define L_Bacteria "L_Bacteria" 33 | #define U_Bacteria "U_Bacteria" 34 | #define B_Habitat "B_Habitat" 35 | #define I_Habitat "I_Habitat" 36 | #define L_Habitat "L_Habitat" 37 | #define U_Habitat "U_Habitat" 38 | #define B_Geographical "B_Geographical" 39 | #define I_Geographical "I_Geographical" 40 | #define L_Geographical "L_Geographical" 41 | #define U_Geographical "U_Geographical" 42 | #define OTHER "O" 43 | 44 | 45 | void appendEntity(const fox::Token& token, Entity& entity) { 46 | int whitespacetoAdd = token.begin-entity.end; 47 | for(int j=0;j& entities, vector& relations, 
const string& dir) { 190 | ofstream m_outf; 191 | string path = dir+"/BB-event+ner-"+id+".a2"; 192 | m_outf.open(path.c_str()); 193 | 194 | for(int i=0;i& entities, const string& dir) { 210 | ofstream m_outf; 211 | string path = dir+"/BB-event+ner-"+id+".a2"; 212 | m_outf.open(path.c_str()); 213 | 214 | int count=1; 215 | for(int i=0;i& entities) { 226 | 227 | for(int i=0;i& docs) { 238 | 239 | struct dirent** namelist = NULL; 240 | int total = scandir(dirPath.c_str(), &namelist, 0, alphasort); 241 | int count = 0; 242 | 243 | for(int i=0;id_type == 8) { 246 | //file 247 | if(namelist[i]->d_name[0]=='.') 248 | continue; 249 | 250 | string filePath = dirPath; 251 | filePath += "/"; 252 | filePath += namelist[i]->d_name; 253 | string fileName = namelist[i]->d_name; 254 | 255 | ifstream ifs; 256 | ifs.open(filePath.c_str()); 257 | fox::Sent sent; 258 | string line; 259 | 260 | 261 | while(getline(ifs, line)) { 262 | if(line.empty()){ 263 | // new line 264 | if(sent.tokens.size()!=0) { 265 | docs[count].sents.push_back(sent); 266 | docs[count].sents[docs[count].sents.size()-1].begin = sent.tokens[0].begin; 267 | docs[count].sents[docs[count].sents.size()-1].end = sent.tokens[sent.tokens.size()-1].end; 268 | sent.tokens.clear(); 269 | } 270 | } else { 271 | vector splitted; 272 | fox::split_bychar(line, splitted, '\t'); 273 | fox::Token token; 274 | token.word = splitted[0]; 275 | token.begin = atoi(splitted[1].c_str()); 276 | token.end = atoi(splitted[2].c_str()); 277 | token.pos = splitted[3]; 278 | token.lemma = splitted[4]; 279 | token.depGov = atoi(splitted[5].c_str()); 280 | token.depType = splitted[6]; 281 | sent.tokens.push_back(token); 282 | } 283 | 284 | 285 | 286 | } 287 | 288 | ifs.close(); 289 | count++; 290 | } 291 | } 292 | 293 | } 294 | 295 | bool isEntityOverlapped(const Entity& former, const Entity& latter) { 296 | if(former.end2==-1) { 297 | if(former.end<=latter.begin) 298 | return false; 299 | else 300 | return true; 301 | } else { 302 | 
if(former.end2<=latter.begin) 303 | return false; 304 | else 305 | return true; 306 | } 307 | } 308 | 309 | 310 | int parseBB3(const string& dirPath, vector& documents) 311 | { 312 | struct dirent** namelist = NULL; 313 | int total = scandir(dirPath.c_str(), &namelist, 0, alphasort); 314 | 315 | 316 | for(int i=0;id_type == 8) { 319 | //file 320 | if(namelist[i]->d_name[0]=='.') 321 | continue; 322 | 323 | string filePath = dirPath; 324 | filePath += "/"; 325 | filePath += namelist[i]->d_name; 326 | string fileName = namelist[i]->d_name; 327 | 328 | if(string::npos != filePath.find(".a1")) { // doc 329 | Document doc; 330 | doc.id = fileName.substr(fileName.find_last_of("-")+1, fileName.find(".")-fileName.find_last_of("-")-1); 331 | doc.maxParagraphId = -1; 332 | ifstream ifs; 333 | ifs.open(filePath.c_str()); 334 | string line; 335 | while(getline(ifs, line)) { 336 | 337 | if(!line.empty() && line[0]=='T') { 338 | if(doc.maxParagraphId < atoi(line.substr(1,1).c_str())) 339 | doc.maxParagraphId = atoi(line.substr(1,1).c_str()); 340 | } 341 | } 342 | 343 | ifs.close(); 344 | 345 | documents.push_back(doc); 346 | } else if(string::npos != filePath.find(".a2")) { // entity && relation 347 | Document& doc = documents[documents.size()-1]; 348 | 349 | ifstream ifs; 350 | ifs.open(filePath.c_str()); 351 | 352 | 353 | string line; 354 | while(getline(ifs, line)) { 355 | 356 | if(!line.empty()) { 357 | 358 | if(line[0] == 'T') { // entity 359 | vector splitted; 360 | fox::split_bychar(line, splitted, '\t'); 361 | Entity entity; 362 | entity.id = splitted[0]; 363 | entity.text = splitted[2]; 364 | 365 | vector temp1; 366 | fox::split(splitted[1], temp1, " |;"); 367 | 368 | if(temp1.size() == 3) { 369 | entity.type = temp1[0]; 370 | entity.begin = atoi(temp1[1].c_str()); 371 | entity.end = atoi(temp1[2].c_str()); 372 | 373 | if(doc.entities.size()>0) { 374 | Entity& former = doc.entities[doc.entities.size()-1]; 375 | if(isEntityOverlapped(former, entity)) { 376 | // if two 
entities is overlapped, we keep the one with narrow range 377 | int formerRange = former.end-former.begin; 378 | int entityRange = entity.end-entity.begin; 379 | if(entityRange splitted; 404 | fox::split(line, splitted, "\t| |:"); 405 | Relation relation; 406 | 407 | relation.idBacteria = splitted[3]; 408 | relation.idLocation = splitted[5]; 409 | 410 | 411 | int idxBac = findEntityById(relation.idBacteria, doc.entities); 412 | int idxLoc = findEntityById(relation.idLocation, doc.entities); 413 | if(idxBac ==-1 || idxLoc == -1) 414 | continue; // we get rid of some entities 415 | 416 | relation.bacteria = doc.entities[idxBac]; 417 | relation.location = doc.entities[idxLoc]; 418 | 419 | doc.relations.push_back(relation); 420 | } 421 | 422 | } 423 | } 424 | ifs.close(); 425 | 426 | } 427 | 428 | } 429 | } 430 | 431 | 432 | return 0; 433 | 434 | } 435 | 436 | 437 | bool isTokenBeforeEntity(const fox::Token& tok, const Entity& entity) { 438 | if(tok.beginentity.end) 447 | return true; 448 | else 449 | return false; 450 | 451 | } else { 452 | if(tok.end>entity.end2) 453 | return true; 454 | else 455 | return false; 456 | 457 | } 458 | 459 | } 460 | 461 | 462 | string isTokenInEntity(const fox::Token& tok, const Entity& entity) { 463 | 464 | if(tok.begin==entity.begin && tok.end==entity.end) 465 | return "U"; 466 | else if(tok.begin==entity.begin) 467 | return "B"; 468 | else if(tok.end==entity.end) 469 | return "L"; 470 | else if(tok.begin>entity.begin && tok.end=entity.begin && tok.end<=entity.end) 481 | return true; 482 | else 483 | return false; 484 | } else { 485 | 486 | if((tok.begin>=entity.begin && tok.end<=entity.end) || 487 | (tok.begin>=entity.begin2 && tok.end<=entity.end2)) 488 | return true; 489 | else 490 | return false; 491 | 492 | } 493 | } 494 | 495 | bool isTokenBetweenTwoEntities(const fox::Token& tok, const Entity& former, const Entity& latter) { 496 | 497 | if(former.end2 == -1) { 498 | if(tok.begin>=former.end && tok.end<=latter.begin) 499 | return 
true; 500 | else 501 | return false; 502 | } else { 503 | if(tok.begin>=former.end2 && tok.end<=latter.begin) 504 | return true; 505 | else 506 | return false; 507 | } 508 | } 509 | 510 | // delete the entities not in the window sentence 511 | void deleteEntityOutOfWindow(vector& entities, int begin, int end) { 512 | 513 | 514 | vector::iterator iter = entities.begin(); 515 | for(;iter!=entities.end();) { 516 | if((*iter).begin < begin || (*iter).end > end) { 517 | iter = entities.erase(iter); 518 | } else { 519 | iter++; 520 | } 521 | } 522 | 523 | } 524 | 525 | void findEntityInWindow(int begin, int end, vector& src, vector& dst) { 526 | for(int i=0;i= begin && src[i].end <= end) 529 | dst.push_back(src[i]); 530 | } else { 531 | if(src[i].begin >= begin && src[i].end2 <= end) 532 | dst.push_back(src[i]); 533 | } 534 | 535 | } 536 | 537 | return; 538 | } 539 | 540 | // sentence spans from begin(include) to end(exclude), sorted because doc.entities are sorted 541 | void findEntityInSent(int begin, int end, const Document& doc, vector& results) { 542 | 543 | for(int i=0;i= begin && doc.entities[i].end <= end) 546 | results.push_back(doc.entities[i]); 547 | } else { 548 | if(doc.entities[i].begin >= begin && doc.entities[i].end2 <= end) 549 | results.push_back(doc.entities[i]); 550 | } 551 | 552 | } 553 | 554 | return; 555 | } 556 | 557 | void findEntityInSent(int begin, int end, const vector& source, vector& results) { 558 | 559 | for(int i=0;i= begin && source[i].end <= end) 562 | results.push_back(source[i]); 563 | } else { 564 | if(source[i].begin >= begin && source[i].end2 <= end) 565 | results.push_back(source[i]); 566 | } 567 | 568 | } 569 | 570 | return; 571 | } 572 | 573 | void deleteEntity(vector& entities, const Entity& target) 574 | { 575 | vector::iterator iter = entities.begin(); 576 | for(;iter!=entities.end();iter++) { 577 | if((*iter).equals(target)) { 578 | break; 579 | } 580 | } 581 | if(iter!=entities.end()) { 582 | entities.erase(iter); 583 | } 
584 | } 585 | 586 | int containsEntity(vector& source, const Entity& target) { 587 | 588 | for(int i=0;i& source, const Relation& target) { 598 | 599 | for(int i=0;i& source, const Entity& target) { 609 | 610 | for(int i=0;i 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | //#include 27 | #include 28 | #include 29 | 30 | 31 | /*! 32 | 33 | \mainpage A Simple C++ Argument Parser 34 | 35 | This is a class to aid handling of command line arguments in a C++ 36 | program. It follows (and enforces) the usual conventions. 37 | 38 | New arguments are added by calling functions of the form of new_[optional_/named_]type. 39 | 40 | - type is the type of the value for the argument (int, double, string, vector of strings...) 41 | 42 | - optional means that the user doesn't have to supply it. 43 | 44 | - named means that it is identified by following an "-c" or 45 | "--long-name" identifier. All named arguments are optional. 46 | 47 | Unnamed arguments are expected to appear in order of addition on 48 | the command line. Named arguments can be passed in any order (and 49 | mixed with the unnamed arguments). 50 | 51 | The special argument "--" means that all remaining arguments are 52 | treated as unnamed (so you can pass file names that begin with -). 53 | 54 | When calling a new_foo function to create an argument the 55 | following can/must be passed in 56 | 57 | - For all arguments 58 | + The place to put the value. 59 | + A description of the value of the argument (i.e. it is a filename) 60 | + A description of the argument as a whole (i.e. it is the input file). 61 | 62 | - For named arguments 63 | + A character for the short name. 64 | + A string for the long name. 65 | 66 | When the program is called if "--help" is passed as an argument 67 | the usage information is printed and the program exits.
68 | 69 | There is always an implicit "-v" flag for verbose which sets the 70 | dsr::verbose variable and a "-V" which sets the VERBOSE variable. 71 | 72 | Any extra arguments or arguments with unexpected types are treated 73 | as errors and cause the program to abort. Extra arguments can be 74 | allowed for by adding a std::vector to store them 75 | using the "set_string_vector" function. All extra (unnamed) 76 | arguments are placed there. 77 | 78 | This software is not subject to copyright protection and is in the 79 | public domain. Neither Stanford nor the author assume any 80 | responsibility whatsoever for its use by other parties, and makes no 81 | guarantees, expressed or implied, about its quality, reliability, or 82 | any other characteristic. 83 | 84 | An example using the class is: 85 | \include argument_helper_example.cc 86 | 87 | */ 88 | 89 | namespace dsr{ 90 | extern bool verbose, VERBOSE; 91 | 92 | //! A helper class for parsing command line arguments. 93 | /*! 94 | This is the only class you need to look at in order to use it. 95 | */ 96 | class Argument_helper{ 97 | private: 98 | class Argument_target; 99 | 100 | 101 | class FlagTarget; 102 | class DoubleTarget; 103 | class IntTarget; 104 | class UIntTarget; 105 | class StringTarget; 106 | class CharTarget; 107 | class StringVectorTarget; 108 | 109 | public: 110 | Argument_helper(); 111 | //! Toggle a boolean 112 | void new_flag(const char *key, const char *long_name, const char *description,bool &dest); 113 | 114 | //! add a string argument 115 | void new_string( const char *arg_description, const char *description, std::string &dest); 116 | //! add a string which must have a key. 117 | void new_named_string(const char *key, const char *long_name, 118 | const char *arg_description, 119 | const char *description, std::string &dest); 120 | //! Add an optional string-- any extra arguments are put in these. 
121 | void new_optional_string( const char *arg_description, const char *description, std::string &dest); 122 | 123 | //! Add a required unnamed int. 124 | void new_int( const char *arg_description, const char *description, int &dest); 125 | //! Add a named int. 126 | void new_named_int(const char *key, const char *long_name,const char *value_name, 127 | const char *description, 128 | int &dest); 129 | //! Add an optional unnamed int. 130 | void new_optional_int(const char *value_name, 131 | const char *description, 132 | int &dest); 133 | 134 | //! Add a required unnamed double. 135 | void new_double(const char *value_name, 136 | const char *description, 137 | double &dest); 138 | 139 | //! Add a named double. 140 | void new_named_double(const char *key, const char *long_name,const char *value_name, 141 | const char *description, 142 | double &dest); 143 | //! Add an optional unnamed double. 144 | void new_optional_double(const char *value_name, 145 | const char *description, 146 | double &dest); 147 | 148 | //! Add a required unnamed char. 149 | void new_char(const char *value_name, 150 | const char *description, 151 | char &dest); 152 | //! Add a named char. 153 | void new_named_char(const char *key, const char *long_name,const char *value_name, 154 | const char *description, 155 | char &dest); 156 | //! Add an optional unnamed char. 157 | void new_optional_char(const char *value_name, 158 | const char *description, 159 | char &dest); 160 | 161 | //! Add a required unnamed unsigned int. 162 | void new_unsigned_int(const char *value_name, const char *description, 163 | unsigned int &dest); 164 | //! Add an optional unnamed unsigned int. 165 | void new_optional_unsigned_int(const char *value_name, const char *description, 166 | unsigned int &dest); 167 | //! Add a named unsigned int. 168 | void new_named_unsigned_int(const char *key, const char *long_name, 169 | const char *value_name, const char *description, 170 | unsigned int &dest); 171 | 172 | 173 | //! Add a named target which takes a list of strings. 174 | /*!
175 | Only named makes sense as the string vector default handles unnamed and optional. 176 | */ 177 | void new_named_string_vector(const char *key, const char *long_name, 178 | const char *value_name, const char *description, 179 | std::vector &dest); 180 | 181 | //! add a vector of strings. 182 | /*! Any arguments which are not claimed by earlier unnamed 183 | arguments or which are named are put here. This means you cannot 184 | have a string vector followed by a string. 185 | */ 186 | void set_string_vector(const char *arg_description, const char *description, std::vector &dest); 187 | 188 | //! Set who wrote the program. 189 | void set_author(const char *author); 190 | 191 | //! Set what the program does. 192 | void set_description(const char *descr); 193 | 194 | //! Set what the version is. 195 | void set_version(float v); 196 | void set_version(const char *str); 197 | 198 | //! Set the name of the program. 199 | void set_name(const char *name); 200 | 201 | //! Set when the program was built. 202 | void set_build_date(const char *date); 203 | 204 | //! Process the list of arguments and parse them. 205 | /*! 206 | This returns true if all the required arguments are there. 207 | */ 208 | void process(int argc, const char **argv); 209 | void process(int argc, char **argv){ 210 | process(argc, const_cast(argv)); 211 | } 212 | //! Write how to call the program. 213 | void write_usage(std::ostream &out) const; 214 | //! Write the values of all the possible arguments. 215 | void write_values(std::ostream &out) const; 216 | 217 | ~Argument_helper(); 218 | protected: 219 | typedef std::map SMap; 220 | typedef std::map LMap; 221 | typedef std::vector UVect; 222 | // A map from short names to arguments. 223 | SMap short_names_; 224 | // A map from long names to arguments. 
225 | LMap long_names_; 226 | std::string author_; 227 | std::string name_; 228 | std::string description_; 229 | std::string date_; 230 | float version_; 231 | bool seen_end_named_; 232 | // List of unnamed arguments 233 | std::vector unnamed_arguments_; 234 | std::vector optional_unnamed_arguments_; 235 | std::vector all_arguments_; 236 | std::string extra_arguments_descr_; 237 | std::string extra_arguments_arg_descr_; 238 | std::vector *extra_arguments_; 239 | std::vector::iterator current_unnamed_; 240 | std::vector::iterator current_optional_unnamed_; 241 | void new_argument_target(Argument_target*); 242 | void handle_error() const; 243 | private: 244 | Argument_helper(const Argument_helper &){}; 245 | const Argument_helper& operator=(const Argument_helper &){return *this;} 246 | }; 247 | 248 | 249 | 250 | 251 | bool verbose=false, VERBOSE=false; 252 | 253 | 254 | //////////////////////////////////////////////// Argument Targets 255 | 256 | // This is a base class for representing one argument value. 257 | /* 258 | This is inherited by many classes and which represent the different types. 
259 | */ 260 | class Argument_helper::Argument_target { 261 | public: 262 | std::string key; 263 | std::string long_name; 264 | std::string description; 265 | std::string arg_description; 266 | 267 | Argument_target(const std::string& k, const std::string& lname, 268 | const std::string& descr, 269 | const std::string& arg_descr) { 270 | key=k; 271 | long_name=lname; 272 | description=descr; 273 | arg_description=arg_descr; 274 | } 275 | Argument_target(const std::string& descr, 276 | const std::string& arg_descr) { 277 | key=""; 278 | long_name=""; 279 | description=descr; 280 | arg_description=arg_descr; 281 | } 282 | virtual bool process(int &, const char **&)=0; 283 | virtual void write_name(std::ostream &out) const; 284 | virtual void write_value(std::ostream &out) const=0; 285 | virtual void write_usage(std::ostream &out) const; 286 | virtual ~Argument_target(){} 287 | }; 288 | 289 | void Argument_helper::Argument_target::write_name(std::ostream &out) const { 290 | if (key != "") out << '-' << key; 291 | else if (!long_name.empty()) out << "--" << long_name; 292 | else out << arg_description; 293 | } 294 | 295 | 296 | void Argument_helper::Argument_target::write_usage(std::ostream &out) const { 297 | if (key != "") { 298 | out << '-' << key; 299 | out << "/--" << long_name; 300 | } 301 | out << ' ' << arg_description; 302 | out << "\t" << description; 303 | out << " Value: "; 304 | write_value(out); 305 | out << std::endl; 306 | } 307 | 308 | class Argument_helper::FlagTarget: public Argument_helper::Argument_target{ 309 | public: 310 | bool &val; 311 | FlagTarget(const char *k, const char *lname, 312 | const char *descr, 313 | bool &b): Argument_target(std::string(k), std::string(lname), std::string(descr), 314 | std::string()), val(b){} 315 | virtual bool process(int &, const char **&){ 316 | val= !val; 317 | return true; 318 | } 319 | virtual void write_value(std::ostream &out) const { 320 | out << val; 321 | } 322 | 323 | virtual void 
write_usage(std::ostream &out) const { 324 | if (key != "") { 325 | out << '-' << key; 326 | out << "/--" << long_name; 327 | } 328 | out << "\t" << description; 329 | out << " Value: "; 330 | write_value(out); 331 | out << std::endl; 332 | } 333 | virtual ~FlagTarget(){} 334 | }; 335 | 336 | class Argument_helper::DoubleTarget: public Argument_target{ 337 | public: 338 | double &val; 339 | DoubleTarget(const char *k, const char *lname, 340 | const char *arg_descr, 341 | const char *descr, double &b): Argument_target(std::string(k), std::string(lname), 342 | std::string(descr), 343 | std::string(arg_descr)), val(b){} 344 | DoubleTarget(const char *arg_descr, 345 | const char *descr, double &b): Argument_target(std::string(descr), 346 | std::string(arg_descr)), val(b){} 347 | virtual bool process(int &argc, const char **&argv){ 348 | if (argc==0){ 349 | std::cerr << "Missing value for argument." << std::endl; 350 | return false; 351 | } 352 | if (sscanf(argv[0], "%le", &val) ==1){ 353 | --argc; 354 | ++argv; 355 | return true; 356 | } else { 357 | std::cerr << "Double not found at " << argv[0] << std::endl; 358 | return false; 359 | } 360 | } 361 | virtual void write_value(std::ostream &out) const { 362 | out << val; 363 | } 364 | virtual ~DoubleTarget(){} 365 | }; 366 | 367 | class Argument_helper::IntTarget: public Argument_target{ 368 | public: 369 | int &val; 370 | IntTarget(const char *arg_descr, 371 | const char *descr, int &b): Argument_target(std::string(descr), std::string(arg_descr)), 372 | val(b){} 373 | IntTarget(const char *k, const char *lname, 374 | const char *arg_descr, 375 | const char *descr, int &b): Argument_target(std::string(k), std::string(lname), 376 | std::string(descr), 377 | std::string(arg_descr)), 378 | val(b){} 379 | virtual bool process(int &argc, const char **&argv){ 380 | if (argc==0){ 381 | std::cerr << "Missing value for argument." 
<< std::endl; 382 | return false; 383 | } 384 | if (sscanf(argv[0], "%d", &val) ==1){ 385 | --argc; 386 | ++argv; 387 | return true; 388 | } else { 389 | std::cerr << "Integer not found at " << argv[0] << std::endl; 390 | return false; 391 | } 392 | } 393 | virtual void write_value(std::ostream &out) const { 394 | out << val; 395 | } 396 | virtual ~IntTarget(){} 397 | }; 398 | /* Argument target that reads one command-line token into an unsigned int with sscanf. process() consumes the token (--argc, ++argv) and returns true on success; on failure it reports to std::cerr, consumes nothing and returns false. */ 399 | class Argument_helper::UIntTarget: public Argument_target{ 400 | public: 401 | unsigned int &val; 402 | UIntTarget(const char *arg_descr, 403 | const char *descr, unsigned int &b): Argument_target(std::string(descr), std::string(arg_descr)), 404 | val(b){} 405 | UIntTarget(const char *k, const char *lname, 406 | const char *arg_descr, 407 | const char *descr, unsigned int &b): Argument_target(std::string(k), std::string(lname), 408 | std::string(descr), 409 | std::string(arg_descr)), 410 | val(b){} 411 | virtual bool process(int &argc, const char **&argv){ 412 | if (argc==0){ 413 | std::cerr << "Missing value for argument."
<< std::endl; 414 | return false; 415 | } 416 | /* Plain "%u" is the unsigned conversion; the former "%ud" additionally asked sscanf to match a literal 'd' after the number. */ if (sscanf(argv[0], "%u", &val) ==1){ 417 | --argc; 418 | ++argv; 419 | return true; 420 | } else { 421 | std::cerr << "Unsigned integer not found at " << argv[0] << std::endl; 422 | return false; 423 | } 424 | } 425 | virtual void write_value(std::ostream &out) const { 426 | out << val; 427 | } 428 | virtual ~UIntTarget(){} 429 | }; 430 | 431 | 432 | class Argument_helper::CharTarget: public Argument_target{ 433 | public: 434 | char &val; 435 | CharTarget(const char *k, const char *lname, 436 | const char *arg_descr, 437 | const char *descr, char &b): Argument_target(std::string(k), std::string(lname), 438 | std::string(descr), 439 | std::string(arg_descr)), val(b){} 440 | CharTarget(const char *arg_descr, 441 | const char *descr, char &b): Argument_target(std::string(descr), 442 | std::string(arg_descr)), val(b){} 443 | virtual bool process(int &argc, const char **&argv){ 444 | if (argc==0){ 445 | std::cerr << "Missing value for argument."
<< std::endl; 446 | return false; 447 | } 448 | if (sscanf(argv[0], "%c", &val) ==1){ 449 | --argc; 450 | ++argv; 451 | return true; 452 | } else { 453 | std::cerr << "Character not found at " << argv[0] << std::endl; 454 | return false; 455 | } 456 | } 457 | virtual void write_value(std::ostream &out) const { 458 | out << val; 459 | } 460 | virtual ~CharTarget(){} 461 | }; 462 | 463 | 464 | class Argument_helper::StringTarget: public Argument_target{ 465 | public: 466 | std::string &val; 467 | StringTarget(const char *arg_descr, 468 | const char *descr, std::string &b): Argument_target(std::string(descr), std::string(arg_descr)), 469 | val(b){} 470 | 471 | StringTarget(const char *k, const char *lname, const char *arg_descr, 472 | const char *descr, std::string &b): Argument_target(std::string(k), std::string(lname), 473 | std::string(descr), std::string(arg_descr)), 474 | val(b){} 475 | 476 | virtual bool process(int &argc, const char **&argv){ 477 | if (argc==0){ 478 | std::cerr << "Missing string argument." 
<< std::endl; 479 | return false; 480 | } 481 | val= argv[0]; 482 | --argc; 483 | ++argv; 484 | return true; 485 | } 486 | virtual void write_value(std::ostream &out) const { 487 | out << val; 488 | } 489 | virtual ~StringTarget(){} 490 | }; 491 | 492 | 493 | class Argument_helper::StringVectorTarget: public Argument_target{ 494 | public: 495 | std::vector &val; 496 | 497 | StringVectorTarget(const char *k, const char *lname, const char *arg_descr, 498 | const char *descr, std::vector &b): Argument_target(std::string(k), std::string(lname), 499 | std::string(descr), std::string(arg_descr)), 500 | val(b){} 501 | 502 | virtual bool process(int &argc, const char **&argv){ 503 | while (argc >0 && argv[0][0] != '-'){ 504 | val.push_back(argv[0]); 505 | --argc; 506 | ++argv; 507 | } 508 | return true; 509 | } 510 | virtual void write_value(std::ostream &out) const { 511 | for (unsigned int i=0; i< val.size(); ++i){ 512 | out << val[i] << " "; 513 | } 514 | } 515 | virtual ~StringVectorTarget(){} 516 | }; 517 | 518 | 519 | //////////////////////////////////////////////// Argument_helper functions 520 | 521 | 522 | Argument_helper::Argument_helper(){ 523 | author_="a programmer"; 524 | description_= "This program produces output."; 525 | date_= "some day a long, long time ago."; 526 | version_=-1; 527 | extra_arguments_=NULL; 528 | seen_end_named_=false; 529 | new_flag("v", "verbose", "Whether to print extra information", verbose); 530 | new_flag("V", "VERBOSE", "Whether to print lots of extra information", VERBOSE); 531 | } 532 | 533 | 534 | 535 | void Argument_helper::set_string_vector(const char *arg_description, 536 | const char *description, 537 | std::vector &dest){ 538 | assert(extra_arguments_==NULL); 539 | extra_arguments_descr_= description; 540 | extra_arguments_arg_descr_= arg_description; 541 | extra_arguments_= &dest; 542 | } 543 | 544 | void Argument_helper::set_author(const char *author){ 545 | author_=author; 546 | } 547 | 548 | void 
Argument_helper::set_description(const char *descr){ 549 | description_= descr; 550 | } 551 | 552 | void Argument_helper::set_name(const char *descr){ 553 | name_= descr; 554 | } 555 | 556 | void Argument_helper::set_version(float v){ 557 | version_=v; 558 | } 559 | 560 | void Argument_helper::set_version(const char *s){ 561 | version_=atof(s); 562 | } 563 | 564 | void Argument_helper::set_build_date(const char *date){ 565 | date_=date; 566 | } 567 | 568 | void Argument_helper::new_argument_target(Argument_target *t) { 569 | assert(t!= NULL); 570 | if (t->key != ""){ 571 | if (short_names_.find(t->key) != short_names_.end()){ 572 | std::cerr << "Two arguments are defined with the same string key, namely" << std::endl; 573 | short_names_[t->key]->write_usage(std::cerr); 574 | std::cerr << "\n and \n"; 575 | t->write_usage(std::cerr); 576 | std::cerr << std::endl; 577 | } 578 | short_names_[t->key]= t; 579 | } 580 | if (!t->long_name.empty()){ 581 | if (long_names_.find(t->long_name) != long_names_.end()){ 582 | std::cerr << "Two arguments are defined with the same long key, namely" << std::endl; 583 | long_names_[t->long_name]->write_usage(std::cerr); 584 | std::cerr << "\n and \n"; 585 | t->write_usage(std::cerr); 586 | std::cerr << std::endl; 587 | } 588 | long_names_[t->long_name]= t; 589 | } 590 | all_arguments_.push_back(t); 591 | } 592 | 593 | void Argument_helper::new_flag(const char *key, const char *long_name, const char *description,bool &dest){ 594 | Argument_target *t= new FlagTarget(key, long_name, description, dest); 595 | new_argument_target(t); 596 | }; 597 | 598 | 599 | 600 | void Argument_helper::new_string(const char *arg_description, const char *description, 601 | std::string &dest){ 602 | Argument_target *t= new StringTarget(arg_description, description, dest); 603 | unnamed_arguments_.push_back(t); 604 | all_arguments_.push_back(t); 605 | }; 606 | void Argument_helper::new_optional_string(const char *arg_description, const char *description, 
607 | std::string &dest){ 608 | Argument_target *t= new StringTarget(arg_description, description, dest); 609 | optional_unnamed_arguments_.push_back(t); 610 | }; 611 | void Argument_helper::new_named_string(const char *key, const char *long_name, 612 | const char *arg_description, const char *description, 613 | std::string &dest){ 614 | Argument_target *t= new StringTarget(key, long_name, arg_description, description, dest); 615 | new_argument_target(t); 616 | }; 617 | 618 | 619 | void Argument_helper::new_named_string_vector(const char *key, const char *long_name, 620 | const char *arg_description, const char *description, 621 | std::vector &dest){ 622 | Argument_target *t= new StringVectorTarget(key, long_name, arg_description, description, dest); 623 | new_argument_target(t); 624 | }; 625 | 626 | 627 | 628 | void Argument_helper::new_int(const char *arg_description, const char *description, 629 | int &dest){ 630 | Argument_target *t= new IntTarget(arg_description, description, dest); 631 | unnamed_arguments_.push_back(t); 632 | all_arguments_.push_back(t); 633 | }; 634 | void Argument_helper::new_optional_int(const char *arg_description, const char *description, 635 | int &dest){ 636 | Argument_target *t= new IntTarget(arg_description, description, dest); 637 | optional_unnamed_arguments_.push_back(t); 638 | }; 639 | void Argument_helper::new_named_int(const char *key, const char *long_name, 640 | const char *arg_description, const char *description, 641 | int &dest){ 642 | Argument_target *t= new IntTarget(key, long_name, arg_description, description, dest); 643 | new_argument_target(t); 644 | }; 645 | 646 | void Argument_helper::new_unsigned_int(const char *arg_description, const char *description, 647 | unsigned int &dest){ 648 | Argument_target *t= new UIntTarget(arg_description, description, dest); 649 | unnamed_arguments_.push_back(t); 650 | all_arguments_.push_back(t); 651 | }; 652 | void Argument_helper::new_optional_unsigned_int(const char 
*arg_description, const char *description, 653 | unsigned int &dest){ 654 | Argument_target *t= new UIntTarget(arg_description, description, dest); 655 | optional_unnamed_arguments_.push_back(t); 656 | }; 657 | void Argument_helper::new_named_unsigned_int(const char *key, const char *long_name, 658 | const char *arg_description, const char *description, 659 | unsigned int &dest){ 660 | Argument_target *t= new UIntTarget(key, long_name, arg_description, description, dest); 661 | new_argument_target(t); 662 | }; 663 | 664 | 665 | void Argument_helper::new_double(const char *arg_description, const char *description, 666 | double &dest){ 667 | Argument_target *t= new DoubleTarget(arg_description, description, dest); 668 | unnamed_arguments_.push_back(t); 669 | all_arguments_.push_back(t); 670 | }; 671 | void Argument_helper::new_optional_double(const char *arg_description, const char *description, 672 | double &dest){ 673 | Argument_target *t= new DoubleTarget(arg_description, description, dest); 674 | optional_unnamed_arguments_.push_back(t); 675 | }; 676 | void Argument_helper::new_named_double(const char *key, const char *long_name, 677 | const char *arg_description, const char *description, 678 | double &dest){ 679 | Argument_target *t= new DoubleTarget(key, long_name, arg_description, description, dest); 680 | new_argument_target(t); 681 | }; 682 | 683 | void Argument_helper::new_char(const char *arg_description, const char *description, 684 | char &dest){ 685 | Argument_target *t= new CharTarget(arg_description, description, dest); 686 | unnamed_arguments_.push_back(t); 687 | all_arguments_.push_back(t); 688 | }; 689 | void Argument_helper::new_optional_char(const char *arg_description, const char *description, 690 | char &dest){ 691 | Argument_target *t= new CharTarget(arg_description, description, dest); 692 | optional_unnamed_arguments_.push_back(t); 693 | }; 694 | void Argument_helper::new_named_char(const char *key, const char *long_name, 695 | const char 
*arg_description, const char *description, 696 | char &dest){ 697 | Argument_target *t= new CharTarget(key, long_name, arg_description, description, dest); 698 | new_argument_target(t); 699 | }; 700 | 701 | 702 | 703 | void Argument_helper::write_usage(std::ostream &out) const { 704 | out << name_ << " version " << version_ << ", by " << author_ << std::endl; 705 | out << description_ << std::endl; 706 | out << "Compiled on " << date_ << std::endl << std::endl; 707 | out << "Usage: " << name_ << " "; 708 | for (UVect::const_iterator it= unnamed_arguments_.begin(); it != unnamed_arguments_.end(); ++it){ 709 | (*it)->write_name(out); 710 | out << " "; 711 | } 712 | for (UVect::const_iterator it= optional_unnamed_arguments_.begin(); 713 | it != optional_unnamed_arguments_.end(); ++it){ 714 | out << "["; 715 | (*it)->write_name(out); 716 | out << "] "; 717 | } 718 | if (extra_arguments_ != NULL) { 719 | out << "[" << extra_arguments_arg_descr_ << "]"; 720 | } 721 | 722 | out << std::endl << std::endl; 723 | out << "All arguments:\n"; 724 | for (UVect::const_iterator it= unnamed_arguments_.begin(); it != unnamed_arguments_.end(); ++it){ 725 | (*it)->write_usage(out); 726 | } 727 | for (UVect::const_iterator it= optional_unnamed_arguments_.begin(); 728 | it != optional_unnamed_arguments_.end(); ++it){ 729 | (*it)->write_usage(out); 730 | } 731 | 732 | if (!extra_arguments_arg_descr_.empty()) { 733 | out << extra_arguments_arg_descr_ << ": " << extra_arguments_descr_ << std::endl; 734 | } 735 | for (SMap::const_iterator it= short_names_.begin(); it != short_names_.end(); ++it){ 736 | (it->second)->write_usage(out); 737 | } 738 | } 739 | 740 | 741 | 742 | void Argument_helper::write_values(std::ostream &out) const { 743 | for (UVect::const_iterator it= unnamed_arguments_.begin(); it != unnamed_arguments_.end(); ++it){ 744 | out << (*it)->description; 745 | out << ": "; 746 | (*it)->write_value(out); 747 | out << std::endl; 748 | } 749 | for (UVect::const_iterator it= 
optional_unnamed_arguments_.begin(); 750 | it != optional_unnamed_arguments_.end(); ++it){ 751 | out << (*it)->description; 752 | out << ": "; 753 | (*it)->write_value(out); 754 | out << std::endl; 755 | } 756 | if (extra_arguments_!=NULL){ 757 | for (std::vector::const_iterator it= extra_arguments_->begin(); 758 | it != extra_arguments_->end(); ++it){ 759 | out << *it << " "; 760 | } 761 | } 762 | 763 | for (SMap::const_iterator it= short_names_.begin(); it != short_names_.end(); ++it){ 764 | out << it->second->description; 765 | out << ": "; 766 | it->second->write_value(out); 767 | out << std::endl; 768 | } 769 | } 770 | 771 | Argument_helper::~Argument_helper(){ 772 | for (std::vector::iterator it= all_arguments_.begin(); 773 | it != all_arguments_.end(); ++it){ 774 | delete *it; 775 | } 776 | } 777 | 778 | 779 | void Argument_helper::process(int argc, const char **argv){ 780 | name_= argv[0]; 781 | ++argv; 782 | --argc; 783 | 784 | current_unnamed_= unnamed_arguments_.begin(); 785 | current_optional_unnamed_= optional_unnamed_arguments_.begin(); 786 | 787 | for ( int i=0; i< argc; ++i){ 788 | if (strcmp(argv[i], "--help") == 0){ 789 | write_usage(std::cout); 790 | exit(0); 791 | } 792 | } 793 | 794 | while (argc != 0){ 795 | 796 | const char* cur_arg= argv[0]; 797 | if (cur_arg[0]=='-' && !seen_end_named_){ 798 | --argc; ++argv; 799 | if (cur_arg[1]=='-'){ 800 | if (cur_arg[2] == '\0') { 801 | //std::cout << "Ending flags " << std::endl; 802 | seen_end_named_=true; 803 | } else { 804 | // long argument 805 | LMap::iterator f= long_names_.find(cur_arg+2); 806 | if ( f != long_names_.end()){ 807 | if (!f->second->process(argc, argv)) { 808 | handle_error(); 809 | } 810 | } else { 811 | std::cerr<< "Invalid long argument "<< cur_arg << ".\n"; 812 | handle_error(); 813 | } 814 | } 815 | } else { 816 | if (cur_arg[1]=='\0') { 817 | std::cerr << "Invalid argument " << cur_arg << ".\n"; 818 | handle_error(); 819 | } 820 | SMap::iterator f= 
short_names_.find(cur_arg+1); 821 | if ( f != short_names_.end()){ 822 | if (!f->second->process(argc, argv)) { 823 | handle_error(); 824 | } 825 | } else { 826 | std::cerr<< "Invalid short argument "<< cur_arg << ".\n"; 827 | handle_error(); 828 | } 829 | } 830 | } else { 831 | if (current_unnamed_ != unnamed_arguments_.end()){ 832 | Argument_target *t= *current_unnamed_; 833 | t->process(argc, argv); 834 | ++current_unnamed_; 835 | } else if (current_optional_unnamed_ != optional_unnamed_arguments_.end()){ 836 | Argument_target *t= *current_optional_unnamed_; 837 | t->process(argc, argv); 838 | ++current_optional_unnamed_; 839 | } else if (extra_arguments_!= NULL){ 840 | extra_arguments_->push_back(cur_arg); 841 | --argc; 842 | ++argv; 843 | } else { 844 | std::cerr << "Invalid extra argument " << argv[0] << std::endl; 845 | handle_error(); 846 | } 847 | } 848 | } 849 | 850 | if (current_unnamed_ != unnamed_arguments_.end()){ 851 | std::cerr << "Missing required arguments:" << std::endl; 852 | for (; current_unnamed_ != unnamed_arguments_.end(); ++current_unnamed_){ 853 | (*current_unnamed_)->write_name(std::cerr); 854 | std::cerr << std::endl; 855 | } 856 | std::cerr << std::endl; 857 | handle_error(); 858 | } 859 | 860 | if (VERBOSE) verbose=true; 861 | } 862 | 863 | void Argument_helper::handle_error() const { 864 | write_usage(std::cerr); 865 | exit(1); 866 | } 867 | 868 | } 869 | 870 | 871 | 872 | 873 | 874 | #endif 875 | -------------------------------------------------------------------------------- /bb3/Argument_helper.h: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Argument Helper 4 | * 5 | * Daniel Russel drussel@alumni.princeton.edu 6 | * Stanford University 7 | * 8 | * 9 | * This software is not subject to copyright protection and is in the 10 | * public domain. 
/*
 * Neither Stanford nor the author assume any
 * responsibility whatsoever for its use by other parties, and makes no
 * guarantees, expressed or implied, about its quality, reliability, or
 * any other characteristic.
 *
 */

#ifndef _DSR_ARGS_H_
#define _DSR_ARGS_H_
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <map>
#include <string>
#include <vector>


/*!

  \mainpage A Simple C++ Argument Parser

  This is a class to aid handling of command line arguments in a C++
  program. It follows (and enforces) the usual conventions.

  New arguments are added by calling functions of the form of new_[optional_/named_]type.

  - type is the type of the value for the argument (int, double, string, vector of strings...)

  - optional means that the user doesn't have to supply it.

  - named means that it is identified by following an "-c" or
  "--long-name" identifier. All named arguments are optional.

  Unnamed arguments are expected to appear in order of addition on
  the command line. Named arguments can be passed in any order (and
  mixed with the unnamed arguments).

  The special argument "--" means that all remaining arguments are
  treated as unnamed (so you can pass file names that begin with -).

  When calling a new_foo function to create an argument the
  following can/must be passed in

  - For all arguments
  + The place to put the value.
  + A description of the value of the argument (i.e. it is a filename)
  + A description of the argument as a whole (i.e. it is the input file).

  - For named arguments
  + A string for the short name (normally a single character).
  + A string for the long name.

  When the program is called, if "--help" is passed as an argument
  (before any "--" separator) the usage information is printed and the
  program exits.

  There is always an implicit "-v" flag for verbose which sets the
  dsr::verbose variable and a "-V" which sets the VERBOSE variable.

  Any extra arguments or arguments with unexpected types are treated
  as errors and cause the program to exit with the usage text. Extra
  arguments can be allowed for by adding a std::vector<std::string> to
  store them using the "set_string_vector" function. All extra (unnamed)
  arguments are placed there.

  This software is not subject to copyright protection and is in the
  public domain. Neither Stanford nor the author assume any
  responsibility whatsoever for its use by other parties, and makes no
  guarantees, expressed or implied, about its quality, reliability, or
  any other characteristic.

  An example using the class is:
  \include argument_helper_example.cc

*/

namespace dsr{
  extern bool verbose, VERBOSE;

  //! A helper class for parsing command line arguments.
  /*!
    This is the only class you need to look at in order to use it.
    Every destination passed to a new_* function is held by reference and
    must outlive the call to process().
  */
  class Argument_helper{
  private:
    class Argument_target;

    class FlagTarget;
    class DoubleTarget;
    class IntTarget;
    class UIntTarget;
    class StringTarget;
    class CharTarget;
    class StringVectorTarget;

  public:
    //! Create a parser that already knows "-v/--verbose" and "-V/--VERBOSE".
    Argument_helper();

    //! Add a flag; each occurrence on the command line toggles dest.
    void new_flag(const char *key, const char *long_name, const char *description,bool &dest);

    //! Add a required unnamed string.
    void new_string( const char *arg_description, const char *description, std::string &dest);
    //! Add a string which must be introduced by its key or long name.
    void new_named_string(const char *key, const char *long_name,
                          const char *arg_description,
                          const char *description, std::string &dest);
    //! Add an optional unnamed string (filled after all required unnamed arguments).
    void new_optional_string( const char *arg_description, const char *description, std::string &dest);

    //! Add a required unnamed int.
    void new_int( const char *arg_description, const char *description, int &dest);
    //! Add a named int.
    void new_named_int(const char *key, const char *long_name,const char *value_name,
                       const char *description,
                       int &dest);
    //! Add an optional unnamed int.
    void new_optional_int(const char *value_name,
                          const char *description,
                          int &dest);

    //! Add a required unnamed double.
    void new_double(const char *value_name,
                    const char *description,
                    double &dest);
    //! Add a named double.
    void new_named_double(const char *key, const char *long_name,const char *value_name,
                          const char *description,
                          double &dest);
    //! Add an optional unnamed double.
    void new_optional_double(const char *value_name,
                             const char *description,
                             double &dest);

    //! Add a required unnamed char.
    void new_char(const char *value_name,
                  const char *description,
                  char &dest);
    //! Add a named char.
    void new_named_char(const char *key, const char *long_name,const char *value_name,
                        const char *description,
                        char &dest);
    //! Add an optional unnamed char.
    void new_optional_char(const char *value_name,
                           const char *description,
                           char &dest);

    //! Add a required unnamed unsigned int.
    void new_unsigned_int(const char *value_name, const char *description,
                          unsigned int &dest);
    //! Add an optional unnamed unsigned int.
    void new_optional_unsigned_int(const char *value_name, const char *description,
                                   unsigned int &dest);
    //! Add a named unsigned int.
    void new_named_unsigned_int(const char *key, const char *long_name,
                                const char *value_name, const char *description,
                                unsigned int &dest);


    //! Add a named target which takes a list of strings.
    /*!
      Only named makes sense as the string vector default handles unnamed and optional.
      The list ends at the next argument that begins with '-'.
    */
    void new_named_string_vector(const char *key, const char *long_name,
                                 const char *value_name, const char *description,
                                 std::vector<std::string> &dest);

    //! Set the vector which receives all left-over unnamed arguments.
    /*! Any unnamed arguments which are not claimed by earlier unnamed
      arguments are put here. This means you cannot have a string vector
      followed by a string. May be called at most once.
    */
    void set_string_vector(const char *arg_description, const char *description, std::vector<std::string> &dest);

    //! Set who wrote the program.
    void set_author(const char *author);

    //! Set what the program does.
    void set_description(const char *descr);

    //! Set what the version is.
    void set_version(float v);
    void set_version(const char *str);

    //! Set the name of the program (process() replaces it with argv[0]).
    void set_name(const char *name);

    //! Set when the program was built.
    void set_build_date(const char *date);

    //! Process the list of arguments and parse them.
    /*!
      Nothing is returned: on "--help" the usage text is printed and the
      program exits with status 0; on an invalid, unparsable or missing
      argument the usage text is printed on stderr and the program exits
      with status 1.
    */
    void process(int argc, const char **argv);
    void process(int argc, char **argv){
      process(argc, const_cast<const char **>(argv));
    }
    //! Write how to call the program.
    void write_usage(std::ostream &out) const;
    //! Write the values of all the possible arguments.
    void write_values(std::ostream &out) const;

    ~Argument_helper();
  protected:
    typedef std::map<std::string, Argument_target*> SMap;
    typedef std::map<std::string, Argument_target*> LMap;
    typedef std::vector<Argument_target*> UVect;
    // A map from short names to arguments.
    SMap short_names_;
    // A map from long names to arguments.
    LMap long_names_;
    std::string author_;
    std::string name_;
    std::string description_;
    std::string date_;
    float version_;
    bool seen_end_named_;
    // Required unnamed arguments, in command-line order.
    std::vector<Argument_target*> unnamed_arguments_;
    // Optional unnamed arguments, filled once the required ones are done.
    std::vector<Argument_target*> optional_unnamed_arguments_;
    // Every target owned by this object; the destructor deletes exactly these.
    std::vector<Argument_target*> all_arguments_;
    std::string extra_arguments_descr_;
    std::string extra_arguments_arg_descr_;
    std::vector<std::string> *extra_arguments_;
    std::vector<Argument_target*>::iterator current_unnamed_;
    std::vector<Argument_target*>::iterator current_optional_unnamed_;
    void new_argument_target(Argument_target*);
    void handle_error() const;
  private:
    // Not copyable: the object owns raw pointers to its targets.
    Argument_helper(const Argument_helper &){};
    const Argument_helper& operator=(const Argument_helper &){return *this;}
  };



  // NOTE(review): these are definitions in a header, so the header can only be
  // included from one translation unit per program. Moving them to a source
  // file (or making them C++17 inline variables) would lift that restriction.
  bool verbose=false, VERBOSE=false;


  //////////////////////////////////////////////// Argument Targets

  // Base class representing one argument value.
  /*
    Each concrete target stores a reference to the caller's variable and
    knows how to parse it from argv and how to print it.
  */
  class Argument_helper::Argument_target {
  public:
    std::string key;
    std::string long_name;
    std::string description;
    std::string arg_description;

    Argument_target(const std::string& k, const std::string& lname,
                    const std::string& descr,
                    const std::string& arg_descr)
      : key(k), long_name(lname), description(descr), arg_description(arg_descr) {}
    Argument_target(const std::string& descr,
                    const std::string& arg_descr)
      : description(descr), arg_description(arg_descr) {}
    // Consume the value(s) from the front of argv; return false on failure.
    virtual bool process(int &, const char **&)=0;
    virtual void write_name(std::ostream &out) const;
    virtual void write_value(std::ostream &out) const=0;
    virtual void write_usage(std::ostream &out) const;
    virtual ~Argument_target(){}
  protected:
    // Print "-k/--long", "-k" or "--long", whichever names exist.
    void write_switches(std::ostream &out) const {
      if (!key.empty()) {
        out << '-' << key;
        if (!long_name.empty()) out << "/--" << long_name;
      } else if (!long_name.empty()) {
        out << "--" << long_name;
      }
    }
  };

  inline void Argument_helper::Argument_target::write_name(std::ostream &out) const {
    if (!key.empty()) out << '-' << key;
    else if (!long_name.empty()) out << "--" << long_name;
    else out << arg_description;
  }


  inline void Argument_helper::Argument_target::write_usage(std::ostream &out) const {
    write_switches(out);
    out << ' ' << arg_description;
    out << "\t" << description;
    out << " Value: ";
    write_value(out);
    out << '\n';
  }

  // A boolean switch that takes no value; each occurrence flips it.
  class Argument_helper::FlagTarget: public Argument_helper::Argument_target{
  public:
    bool &val;
    FlagTarget(const char *k, const char *lname,
               const char *descr,
               bool &b): Argument_target(std::string(k), std::string(lname), std::string(descr),
                                         std::string()), val(b){}
    virtual bool process(int &, const char **&){
      val= !val;
      return true;
    }
    virtual void write_value(std::ostream &out) const {
      out << val;
    }
    virtual void write_usage(std::ostream &out) const {
      write_switches(out);
      out << "\t" << description;
      out << " Value: ";
      write_value(out);
      out << '\n';
    }
    virtual ~FlagTarget(){}
  };

  class Argument_helper::DoubleTarget: public Argument_helper::Argument_target{
  public:
    double &val;
    DoubleTarget(const char *k, const char *lname,
                 const char *arg_descr,
                 const char *descr, double &b): Argument_target(std::string(k), std::string(lname),
                                                                std::string(descr),
                                                                std::string(arg_descr)), val(b){}
    DoubleTarget(const char *arg_descr,
                 const char *descr, double &b): Argument_target(std::string(descr),
                                                                std::string(arg_descr)), val(b){}
    virtual bool process(int &argc, const char **&argv){
      if (argc==0){
        std::cerr << "Missing value for argument." << std::endl;
        return false;
      }
      if (std::sscanf(argv[0], "%le", &val) != 1){
        std::cerr << "Double not found at " << argv[0] << std::endl;
        return false;
      }
      --argc;
      ++argv;
      return true;
    }
    virtual void write_value(std::ostream &out) const {
      out << val;
    }
    virtual ~DoubleTarget(){}
  };

  class Argument_helper::IntTarget: public Argument_helper::Argument_target{
  public:
    int &val;
    IntTarget(const char *arg_descr,
              const char *descr, int &b): Argument_target(std::string(descr), std::string(arg_descr)),
                                          val(b){}
    IntTarget(const char *k, const char *lname,
              const char *arg_descr,
              const char *descr, int &b): Argument_target(std::string(k), std::string(lname),
                                                          std::string(descr),
                                                          std::string(arg_descr)),
                                          val(b){}
    virtual bool process(int &argc, const char **&argv){
      if (argc==0){
        std::cerr << "Missing value for argument." << std::endl;
        return false;
      }
      if (std::sscanf(argv[0], "%d", &val) != 1){
        std::cerr << "Integer not found at " << argv[0] << std::endl;
        return false;
      }
      --argc;
      ++argv;
      return true;
    }
    virtual void write_value(std::ostream &out) const {
      out << val;
    }
    virtual ~IntTarget(){}
  };

  class Argument_helper::UIntTarget: public Argument_helper::Argument_target{
  public:
    unsigned int &val;
    UIntTarget(const char *arg_descr,
               const char *descr, unsigned int &b): Argument_target(std::string(descr), std::string(arg_descr)),
                                                    val(b){}
    UIntTarget(const char *k, const char *lname,
               const char *arg_descr,
               const char *descr, unsigned int &b): Argument_target(std::string(k), std::string(lname),
                                                                    std::string(descr),
                                                                    std::string(arg_descr)),
                                                    val(b){}
    virtual bool process(int &argc, const char **&argv){
      if (argc==0){
        std::cerr << "Missing value for argument." << std::endl;
        return false;
      }
      // "%u" is the unsigned conversion; the former "%ud" was "%u" followed
      // by a literal 'd' that the input was then expected to contain.
      if (std::sscanf(argv[0], "%u", &val) != 1){
        std::cerr << "Unsigned integer not found at " << argv[0] << std::endl;
        return false;
      }
      --argc;
      ++argv;
      return true;
    }
    virtual void write_value(std::ostream &out) const {
      out << val;
    }
    virtual ~UIntTarget(){}
  };


  class Argument_helper::CharTarget: public Argument_helper::Argument_target{
  public:
    char &val;
    CharTarget(const char *k, const char *lname,
               const char *arg_descr,
               const char *descr, char &b): Argument_target(std::string(k), std::string(lname),
                                                            std::string(descr),
                                                            std::string(arg_descr)), val(b){}
    CharTarget(const char *arg_descr,
               const char *descr, char &b): Argument_target(std::string(descr),
                                                            std::string(arg_descr)), val(b){}
    virtual bool process(int &argc, const char **&argv){
      if (argc==0){
        std::cerr << "Missing value for argument." << std::endl;
        return false;
      }
      // Takes the first character of the argument; fails only on "".
      if (std::sscanf(argv[0], "%c", &val) != 1){
        std::cerr << "Character not found at " << argv[0] << std::endl;
        return false;
      }
      --argc;
      ++argv;
      return true;
    }
    virtual void write_value(std::ostream &out) const {
      out << val;
    }
    virtual ~CharTarget(){}
  };


  class Argument_helper::StringTarget: public Argument_helper::Argument_target{
  public:
    std::string &val;
    StringTarget(const char *arg_descr,
                 const char *descr, std::string &b): Argument_target(std::string(descr), std::string(arg_descr)),
                                                     val(b){}

    StringTarget(const char *k, const char *lname, const char *arg_descr,
                 const char *descr, std::string &b): Argument_target(std::string(k), std::string(lname),
                                                                     std::string(descr), std::string(arg_descr)),
                                                     val(b){}

    virtual bool process(int &argc, const char **&argv){
      if (argc==0){
        std::cerr << "Missing string argument." << std::endl;
        return false;
      }
      val= argv[0];
      --argc;
      ++argv;
      return true;
    }
    virtual void write_value(std::ostream &out) const {
      out << val;
    }
    virtual ~StringTarget(){}
  };


  class Argument_helper::StringVectorTarget: public Argument_helper::Argument_target{
  public:
    std::vector<std::string> &val;

    StringVectorTarget(const char *k, const char *lname, const char *arg_descr,
                       const char *descr, std::vector<std::string> &b): Argument_target(std::string(k), std::string(lname),
                                                                                        std::string(descr), std::string(arg_descr)),
                                                                        val(b){}

    // Collect values up to (not including) the next argument starting with '-'.
    virtual bool process(int &argc, const char **&argv){
      while (argc >0 && argv[0][0] != '-'){
        val.push_back(argv[0]);
        --argc;
        ++argv;
      }
      return true;
    }
    virtual void write_value(std::ostream &out) const {
      for (std::vector<std::string>::size_type i=0; i< val.size(); ++i){
        out << val[i] << " ";
      }
    }
    virtual ~StringVectorTarget(){}
  };


  //////////////////////////////////////////////// Argument_helper functions


  inline Argument_helper::Argument_helper(){
    author_="a programmer";
    description_= "This program produces output.";
    date_= "some day a long, long time ago.";
    version_=-1;
    extra_arguments_=NULL;
    seen_end_named_=false;
    new_flag("v", "verbose", "Whether to print extra information", verbose);
    new_flag("V", "VERBOSE", "Whether to print lots of extra information", VERBOSE);
  }



  inline void Argument_helper::set_string_vector(const char *arg_description,
                                                 const char *description,
                                                 std::vector<std::string> &dest){
    assert(extra_arguments_==NULL);
    extra_arguments_descr_= description;
    extra_arguments_arg_descr_= arg_description;
    extra_arguments_= &dest;
  }

  inline void Argument_helper::set_author(const char *author){
    author_=author;
  }

  inline void Argument_helper::set_description(const char *descr){
    description_= descr;
  }

  inline void Argument_helper::set_name(const char *descr){
    name_= descr;
  }

  inline void Argument_helper::set_version(float v){
    version_=v;
  }

  inline void Argument_helper::set_version(const char *s){
    version_=static_cast<float>(std::atof(s));
  }

  inline void Argument_helper::set_build_date(const char *date){
    date_=date;
  }

  // Register a named target under its short and/or long name and take
  // ownership of it. A clash with an existing name is reported on stderr and
  // the newer target wins the lookup.
  inline void Argument_helper::new_argument_target(Argument_target *t) {
    assert(t!= NULL);
    if (!t->key.empty()){
      SMap::iterator clash= short_names_.find(t->key);
      if (clash != short_names_.end()){
        std::cerr << "Two arguments are defined with the same string key, namely" << std::endl;
        clash->second->write_usage(std::cerr);
        std::cerr << "\n and \n";
        t->write_usage(std::cerr);
        std::cerr << std::endl;
      }
      short_names_[t->key]= t;
    }
    if (!t->long_name.empty()){
      LMap::iterator clash= long_names_.find(t->long_name);
      if (clash != long_names_.end()){
        std::cerr << "Two arguments are defined with the same long key, namely" << std::endl;
        clash->second->write_usage(std::cerr);
        std::cerr << "\n and \n";
        t->write_usage(std::cerr);
        std::cerr << std::endl;
      }
      long_names_[t->long_name]= t;
    }
    all_arguments_.push_back(t);
  }

  inline void Argument_helper::new_flag(const char *key, const char *long_name, const char *description,bool &dest){
    new_argument_target(new FlagTarget(key, long_name, description, dest));
  }


  // In every new_optional_* function below the target is also recorded in
  // all_arguments_: the destructor deletes only what is listed there, so
  // optional targets used to be leaked.

  inline void Argument_helper::new_string(const char *arg_description, const char *description,
                                          std::string &dest){
    Argument_target *t= new StringTarget(arg_description, description, dest);
    unnamed_arguments_.push_back(t);
    all_arguments_.push_back(t);
  }
  inline void Argument_helper::new_optional_string(const char *arg_description, const char *description,
                                                   std::string &dest){
    Argument_target *t= new StringTarget(arg_description, description, dest);
    optional_unnamed_arguments_.push_back(t);
    all_arguments_.push_back(t);
  }
  inline void Argument_helper::new_named_string(const char *key, const char *long_name,
                                                const char *arg_description, const char *description,
                                                std::string &dest){
    new_argument_target(new StringTarget(key, long_name, arg_description, description, dest));
  }


  inline void Argument_helper::new_named_string_vector(const char *key, const char *long_name,
                                                       const char *arg_description, const char *description,
                                                       std::vector<std::string> &dest){
    new_argument_target(new StringVectorTarget(key, long_name, arg_description, description, dest));
  }


  inline void Argument_helper::new_int(const char *arg_description, const char *description,
                                       int &dest){
    Argument_target *t= new IntTarget(arg_description, description, dest);
    unnamed_arguments_.push_back(t);
    all_arguments_.push_back(t);
  }
  inline void Argument_helper::new_optional_int(const char *arg_description, const char *description,
                                                int &dest){
    Argument_target *t= new IntTarget(arg_description, description, dest);
    optional_unnamed_arguments_.push_back(t);
    all_arguments_.push_back(t);
  }
  inline void Argument_helper::new_named_int(const char *key, const char *long_name,
                                             const char *arg_description, const char *description,
                                             int &dest){
    new_argument_target(new IntTarget(key, long_name, arg_description, description, dest));
  }

  inline void Argument_helper::new_unsigned_int(const char *arg_description, const char *description,
                                                unsigned int &dest){
    Argument_target *t= new UIntTarget(arg_description, description, dest);
    unnamed_arguments_.push_back(t);
    all_arguments_.push_back(t);
  }
  inline void Argument_helper::new_optional_unsigned_int(const char *arg_description, const char *description,
                                                         unsigned int &dest){
    Argument_target *t= new UIntTarget(arg_description, description, dest);
    optional_unnamed_arguments_.push_back(t);
    all_arguments_.push_back(t);
  }
  inline void Argument_helper::new_named_unsigned_int(const char *key, const char *long_name,
                                                      const char *arg_description, const char *description,
                                                      unsigned int &dest){
    new_argument_target(new UIntTarget(key, long_name, arg_description, description, dest));
  }


  inline void Argument_helper::new_double(const char *arg_description, const char *description,
                                          double &dest){
    Argument_target *t= new DoubleTarget(arg_description, description, dest);
    unnamed_arguments_.push_back(t);
    all_arguments_.push_back(t);
  }
  inline void Argument_helper::new_optional_double(const char *arg_description, const char *description,
                                                   double &dest){
    Argument_target *t= new DoubleTarget(arg_description, description, dest);
    optional_unnamed_arguments_.push_back(t);
    all_arguments_.push_back(t);
  }
  inline void Argument_helper::new_named_double(const char *key, const char *long_name,
                                                const char *arg_description, const char *description,
                                                double &dest){
    new_argument_target(new DoubleTarget(key, long_name, arg_description, description, dest));
  }

  inline void Argument_helper::new_char(const char *arg_description, const char *description,
                                        char &dest){
    Argument_target *t= new CharTarget(arg_description, description, dest);
    unnamed_arguments_.push_back(t);
    all_arguments_.push_back(t);
  }
  inline void Argument_helper::new_optional_char(const char *arg_description, const char *description,
                                                 char &dest){
    Argument_target *t= new CharTarget(arg_description, description, dest);
    optional_unnamed_arguments_.push_back(t);
    all_arguments_.push_back(t);
  }
  inline void Argument_helper::new_named_char(const char *key, const char *long_name,
                                              const char *arg_description, const char *description,
                                              char &dest){
    new_argument_target(new CharTarget(key, long_name, arg_description, description, dest));
  }



  inline void Argument_helper::write_usage(std::ostream &out) const {
    out << name_ << " version " << version_ << ", by " << author_ << '\n';
    out << description_ << '\n';
    out << "Compiled on " << date_ << "\n\n";

    // Synopsis: required positionals, then optional ones in brackets.
    out << "Usage: " << name_ << " ";
    for (UVect::const_iterator it= unnamed_arguments_.begin(); it != unnamed_arguments_.end(); ++it){
      (*it)->write_name(out);
      out << " ";
    }
    for (UVect::const_iterator it= optional_unnamed_arguments_.begin();
         it != optional_unnamed_arguments_.end(); ++it){
      out << "[";
      (*it)->write_name(out);
      out << "] ";
    }
    if (extra_arguments_ != NULL) {
      out << "[" << extra_arguments_arg_descr_ << "]";
    }
    out << "\n\n";

    out << "All arguments:\n";
    for (UVect::const_iterator it= unnamed_arguments_.begin(); it != unnamed_arguments_.end(); ++it){
      (*it)->write_usage(out);
    }
    for (UVect::const_iterator it= optional_unnamed_arguments_.begin();
         it != optional_unnamed_arguments_.end(); ++it){
      (*it)->write_usage(out);
    }
    if (!extra_arguments_arg_descr_.empty()) {
      out << extra_arguments_arg_descr_ << ": " << extra_arguments_descr_ << '\n';
    }
    for (SMap::const_iterator it= short_names_.begin(); it != short_names_.end(); ++it){
      (it->second)->write_usage(out);
    }
    // Options that only have a long name are not in short_names_.
    for (LMap::const_iterator it= long_names_.begin(); it != long_names_.end(); ++it){
      if (it->second->key.empty()) (it->second)->write_usage(out);
    }
    out.flush();
  }



  inline void Argument_helper::write_values(std::ostream &out) const {
    for (UVect::const_iterator it= unnamed_arguments_.begin(); it != unnamed_arguments_.end(); ++it){
      out << (*it)->description << ": ";
      (*it)->write_value(out);
      out << '\n';
    }
    for (UVect::const_iterator it= optional_unnamed_arguments_.begin();
         it != optional_unnamed_arguments_.end(); ++it){
      out << (*it)->description << ": ";
      (*it)->write_value(out);
      out << '\n';
    }
    if (extra_arguments_!=NULL){
      for (std::vector<std::string>::const_iterator it= extra_arguments_->begin();
           it != extra_arguments_->end(); ++it){
        out << *it << " ";
      }
    }
    for (SMap::const_iterator it= short_names_.begin(); it != short_names_.end(); ++it){
      out << it->second->description << ": ";
      it->second->write_value(out);
      out << '\n';
    }
    // Options that only have a long name are not in short_names_.
    for (LMap::const_iterator it= long_names_.begin(); it != long_names_.end(); ++it){
      if (!it->second->key.empty()) continue;
      out << it->second->description << ": ";
      it->second->write_value(out);
      out << '\n';
    }
    out.flush();
  }

  inline Argument_helper::~Argument_helper(){
    for (std::vector<Argument_target*>::iterator it= all_arguments_.begin();
         it != all_arguments_.end(); ++it){
      delete *it;
    }
  }


  inline void Argument_helper::process(int argc, const char **argv){
    // Guard against an empty argument vector before touching argv[0].
    if (argc > 0 && argv != NULL && argv[0] != NULL) {
      name_= argv[0];
      ++argv;
      --argc;
    } else {
      argc= 0;
    }

    current_unnamed_= unnamed_arguments_.begin();
    current_optional_unnamed_= optional_unnamed_arguments_.begin();

    // Everything after "--" is data, so a literal "--help" there is not a request.
    for (int i=0; i< argc; ++i){
      if (std::strcmp(argv[i], "--") == 0) break;
      if (std::strcmp(argv[i], "--help") == 0){
        write_usage(std::cout);
        std::exit(0);
      }
    }

    while (argc > 0){
      const char* cur_arg= argv[0];
      if (cur_arg[0]=='-' && !seen_end_named_){
        --argc; ++argv;
        if (cur_arg[1]=='-'){
          if (cur_arg[2] == '\0') {
            // A bare "--" ends option processing.
            seen_end_named_=true;
          } else {
            LMap::iterator f= long_names_.find(cur_arg+2);
            if (f == long_names_.end()){
              std::cerr<< "Invalid long argument "<< cur_arg << ".\n";
              handle_error();
            }
            if (!f->second->process(argc, argv)) {
              handle_error();
            }
          }
        } else {
          if (cur_arg[1]=='\0') {
            std::cerr << "Invalid argument " << cur_arg << ".\n";
            handle_error();
          }
          SMap::iterator f= short_names_.find(cur_arg+1);
          if (f == short_names_.end()){
            std::cerr<< "Invalid short argument "<< cur_arg << ".\n";
            handle_error();
          }
          if (!f->second->process(argc, argv)) {
            handle_error();
          }
        }
      } else {
        // Positional: required targets first, then optional ones, then the
        // catch-all string vector.
        Argument_target *t= NULL;
        if (current_unnamed_ != unnamed_arguments_.end()){
          t= *current_unnamed_;
          ++current_unnamed_;
        } else if (current_optional_unnamed_ != optional_unnamed_arguments_.end()){
          t= *current_optional_unnamed_;
          ++current_optional_unnamed_;
        }
        if (t != NULL){
          // A value of the wrong type used to be ignored here and handed to
          // the next target; it is now reported like any other bad argument.
          if (!t->process(argc, argv)) {
            handle_error();
          }
        } else if (extra_arguments_!= NULL){
          extra_arguments_->push_back(cur_arg);
          --argc;
          ++argv;
        } else {
          std::cerr << "Invalid extra argument " << argv[0] << std::endl;
          handle_error();
        }
      }
    }

    if (current_unnamed_ != unnamed_arguments_.end()){
      std::cerr << "Missing required arguments:" << std::endl;
      for (; current_unnamed_ != unnamed_arguments_.end(); ++current_unnamed_){
        (*current_unnamed_)->write_name(std::cerr);
        std::cerr << std::endl;
      }
      std::cerr << std::endl;
      handle_error();
    }

    // -V implies -v.
    if (VERBOSE) verbose=true;
  }

  // Print the usage text on stderr and terminate with exit status 1.
  inline void Argument_helper::handle_error() const {
    write_usage(std::cerr);
    std::exit(1);
  }

}





#endif
#include "Word2Vec.h" 19 | #include "utils.h" 20 | #include "Example.h" 21 | #include "BestPerformance.h" 22 | 23 | #include "Classifier3.h" 24 | 25 | using namespace nr; 26 | using namespace std; 27 | 28 | // a implement of ACL 2016 end-to-end relation extraction 29 | // use relation f1 on the development set 30 | 31 | class NNade3 { 32 | public: 33 | Options m_options; 34 | 35 | Alphabet m_wordAlphabet; 36 | Alphabet m_posAlphabet; 37 | Alphabet m_nerAlphabet; 38 | Alphabet m_depAlphabet; 39 | Alphabet m_charAlphabet; 40 | 41 | string unknownkey; 42 | string nullkey; 43 | 44 | #if USE_CUDA==1 45 | Classifier m_classifier; 46 | #else 47 | //Classifier3_base m_classifier; 48 | //Classifier3_char m_classifier; 49 | //Classifier3_pos m_classifier; 50 | //Classifier3_ner m_classifier; 51 | //Classifier3_label m_classifier; 52 | //Classifier3_dep m_classifier; 53 | //Classifier3_entity m_classifier; 54 | //Classifier3_nosdp m_classifier; 55 | //Classifier3_nojoint m_classifier; 56 | Classifier3 m_classifier; 57 | #endif 58 | 59 | 60 | NNade3(const Options &options):m_options(options) { 61 | unknownkey = "-#unknown#-"; 62 | nullkey = "-#null#-"; 63 | } 64 | 65 | 66 | BestPerformance trainAndTest(vector & processedTest, vector & annotatedTest, 67 | vector & processedDev, vector & annotatedDev, 68 | vector & processedTrain, vector & annotatedTrain, Tool & tool) { 69 | 70 | BestPerformance ret; 71 | 72 | m_wordAlphabet.clear(); 73 | m_wordAlphabet.from_string(unknownkey); 74 | m_wordAlphabet.from_string(nullkey); 75 | 76 | m_posAlphabet.clear(); 77 | m_posAlphabet.from_string(unknownkey); 78 | m_posAlphabet.from_string(nullkey); 79 | 80 | m_depAlphabet.clear(); 81 | m_depAlphabet.from_string(unknownkey); 82 | m_depAlphabet.from_string(nullkey); 83 | 84 | m_charAlphabet.clear(); 85 | m_charAlphabet.from_string(unknownkey); 86 | m_charAlphabet.from_string(nullkey); 87 | 88 | // ner alphabet should be initialized directly, not from the dataset 89 | m_nerAlphabet.clear(); 90 | 
m_nerAlphabet.from_string(unknownkey); 91 | m_nerAlphabet.from_string(nullkey); 92 | hash_map ner_stat; 93 | ner_stat[B_Disease]++; 94 | ner_stat[I_Disease]++; 95 | ner_stat[L_Disease]++; 96 | ner_stat[U_Disease]++; 97 | ner_stat[B_Chemical]++; 98 | ner_stat[I_Chemical]++; 99 | ner_stat[L_Chemical]++; 100 | ner_stat[U_Chemical]++; 101 | ner_stat[OTHER]++; 102 | stat2Alphabet(ner_stat, m_nerAlphabet, "ner"); 103 | 104 | createAlphabet(processedTrain, tool); 105 | 106 | if (!m_options.wordEmbFineTune) { 107 | createAlphabet(processedDev, tool); 108 | createAlphabet(processedTest, tool); 109 | } 110 | 111 | 112 | NRMat wordEmb; 113 | if(m_options.wordEmbFineTune) { 114 | if(m_options.embFile.empty()) { 115 | cout<<"random emb"< known; 123 | map IDs; 124 | alphabet2vectormap(m_wordAlphabet, known, IDs); 125 | 126 | tool.w2v->getEmbedding((double*)emb, m_options.wordEmbSize, known, unknownkey, IDs); 127 | 128 | wordEmb.resize(m_wordAlphabet.size(), m_options.wordEmbSize); 129 | array2NRMat((double*) emb, m_wordAlphabet.size(), m_options.wordEmbSize, wordEmb); 130 | 131 | delete[] emb; 132 | } 133 | } else { 134 | if(m_options.embFile.empty()) { 135 | assert(0); 136 | } else { 137 | 138 | double* emb = new double[m_wordAlphabet.size()*m_options.wordEmbSize]; 139 | fox::initArray2((double *)emb, (int)m_wordAlphabet.size(), m_options.wordEmbSize, 0.0); 140 | vector known; 141 | map IDs; 142 | alphabet2vectormap(m_wordAlphabet, known, IDs); 143 | 144 | tool.w2v->getEmbedding((double*)emb, m_options.wordEmbSize, known, unknownkey, IDs); 145 | 146 | wordEmb.resize(m_wordAlphabet.size(), m_options.wordEmbSize); 147 | array2NRMat((double*) emb, m_wordAlphabet.size(), m_options.wordEmbSize, wordEmb); 148 | 149 | delete[] emb; 150 | } 151 | } 152 | 153 | NRMat posEmb; 154 | randomInitNrmat(posEmb, m_posAlphabet, m_options.otherEmbSize, 1010); 155 | NRMat nerEmb; 156 | randomInitNrmat(nerEmb, m_nerAlphabet, m_options.otherEmbSize, 1020); 157 | NRMat depEmb; 158 | 
randomInitNrmat(depEmb, m_depAlphabet, m_options.otherEmbSize, 1030); 159 | NRMat charEmb; 160 | randomInitNrmat(charEmb, m_charAlphabet, m_options.otherEmbSize, 1040); 161 | 162 | vector trainExamples; 163 | initialTrainingExamples(tool, annotatedTrain, processedTrain, trainExamples); 164 | cout<<"Total train example number: "< subExamples; 185 | 186 | dtype best = 0; 187 | 188 | // begin to train 189 | for (int iter = 0; iter < m_options.maxIter; ++iter) { 190 | 191 | cout << "##### Iteration " << iter << std::endl; 192 | 193 | // this coding block should be identical to initialTrainingExamples 194 | // we don't regenerate training examples, instead fetch them directly 195 | int exampleIdx = 0; 196 | 197 | for(int sentIdx=0;sentIdx best) { 269 | cout << "Exceeds best performance of " << best << endl; 270 | best = currentDev.dev_f1Relation; 271 | 272 | BestPerformance currentTest = test(tool, annotatedTest, processedTest); 273 | 274 | 275 | ret.dev_pEntity = currentDev.dev_pEntity; 276 | ret.dev_rEntity = currentDev.dev_rEntity; 277 | ret.dev_f1Entity = currentDev.dev_f1Entity; 278 | ret.dev_pRelation = currentDev.dev_pRelation; 279 | ret.dev_rRelation = currentDev.dev_rRelation; 280 | ret.dev_f1Relation = currentDev.dev_f1Relation; 281 | ret.test_pEntity = currentTest.test_pEntity; 282 | ret.test_rEntity = currentTest.test_rEntity; 283 | ret.test_f1Entity = currentTest.test_f1Entity; 284 | ret.test_pRelation = currentTest.test_pRelation; 285 | ret.test_rRelation = currentTest.test_rRelation; 286 | ret.test_f1Relation = currentTest.test_f1Relation; 287 | 288 | }*/ 289 | 290 | } 291 | 292 | } // for iter 293 | 294 | m_classifier.release(); 295 | 296 | return ret; 297 | } 298 | 299 | 300 | void initialTrainingExamples(Tool& tool, vector & annotated, vector & processed, 301 | vector & examples) { 302 | 303 | 304 | for(int sentIdx=0;sentIdx labelSequence; 309 | 310 | for(int tokenIdx=0;tokenIdx & annotated, 373 | vector & processed, int iter) { 374 | 375 | 
BestPerformance ret; 376 | 377 | int ctGoldEntity = 0; 378 | int ctPredictEntity = 0; 379 | int ctCorrectEntity = 0; 380 | int ctGoldRelation = 0, ctPredictRelation = 0, ctCorrectRelation = 0; 381 | 382 | for(int sentIdx=0;sentIdx labelSequence; 387 | 388 | vector anwserEntities; 389 | vector anwserRelations; 390 | 391 | vector fnRelations; 392 | 393 | for(int tokenIdx=0;tokenIdx probs; 401 | int predicted = m_classifier.predictNer(eg, probs); 402 | 403 | string labelName = NERlabelID2labelName(predicted); 404 | labelSequence.push_back(labelName); 405 | 406 | // decode entity label 407 | if(labelName == B_Disease || labelName == U_Disease || 408 | labelName == B_Chemical || labelName == U_Chemical) { 409 | Entity entity; 410 | newEntity(token, labelName, entity); 411 | anwserEntities.push_back(entity); 412 | } else if(labelName == I_Disease || labelName == L_Disease || 413 | labelName == I_Chemical || labelName == L_Chemical) { 414 | if(checkWrongState(labelSequence)) { 415 | Entity& entity = anwserEntities[anwserEntities.size()-1]; 416 | appendEntity(token, entity); 417 | } 418 | } 419 | 420 | lastNerLabel = labelName; 421 | } // token 422 | 423 | for(int bIdx=0;bIdx probs; 443 | int predicted = m_classifier.predictRel(eg, probs); 444 | 445 | string labelName = RellabelID2labelName(predicted); 446 | 447 | // decode relation label 448 | if(labelName == ADE) { 449 | Relation relation; 450 | newRelation(relation, Disease, Chemical); 451 | anwserRelations.push_back(relation); 452 | } 453 | 454 | } // aIdx 455 | 456 | } // bIdx 457 | 458 | // evaluate by ourselves 459 | ctGoldEntity += annotatedSent.entities.size(); 460 | ctPredictEntity += anwserEntities.size(); 461 | for(int i=0;i=anwserRelations.size()) { 519 | fnRelations.push_back(annotatedSent.relations[i]); 520 | } 521 | } 522 | 523 | for(int i=0;i & annotated, vector & processed) { 554 | 555 | BestPerformance ret; 556 | 557 | int ctGoldEntity = 0; 558 | int ctPredictEntity = 0; 559 | int ctCorrectEntity = 0; 
560 | int ctGoldRelation = 0, ctPredictRelation = 0, ctCorrectRelation = 0; 561 | 562 | for(int sentIdx=0;sentIdx labelSequence; 567 | 568 | vector anwserEntities; 569 | vector anwserRelations; 570 | 571 | for(int tokenIdx=0;tokenIdx probs; 579 | int predicted = m_classifier.predictNer(eg, probs); 580 | 581 | string labelName = NERlabelID2labelName(predicted); 582 | labelSequence.push_back(labelName); 583 | 584 | // decode entity label 585 | if(labelName == B_Disease || labelName == U_Disease || 586 | labelName == B_Chemical || labelName == U_Chemical ) { 587 | Entity entity; 588 | newEntity(token, labelName, entity); 589 | anwserEntities.push_back(entity); 590 | } else if(labelName == I_Disease || labelName == L_Disease || 591 | labelName == I_Chemical || labelName == L_Chemical ) { 592 | if(checkWrongState(labelSequence)) { 593 | Entity& entity = anwserEntities[anwserEntities.size()-1]; 594 | appendEntity(token, entity); 595 | } 596 | } 597 | 598 | lastNerLabel = labelName; 599 | } // token 600 | 601 | 602 | for(int bIdx=0;bIdx probs; 622 | int predicted = m_classifier.predictRel(eg, probs); 623 | 624 | string labelName = RellabelID2labelName(predicted); 625 | 626 | // decode relation label 627 | if(labelName == ADE) { 628 | Relation relation; 629 | newRelation(relation, Disease, Chemical); 630 | anwserRelations.push_back(relation); 631 | } 632 | 633 | } // aIdx 634 | 635 | } // bIdx 636 | 637 | // evaluate by ourselves 638 | ctGoldEntity += annotatedSent.entities.size(); 639 | ctPredictEntity += anwserEntities.size(); 640 | for(int i=0;i& labelSequence) { 683 | int positionNew = -1; // the latest type-consistent B 684 | int positionOther = -1; // other label except type-consistent I 685 | 686 | const string& currentLabel = labelSequence[labelSequence.size()-1]; 687 | 688 | for(int j=labelSequence.size()-2;j>=0;j--) { 689 | if(currentLabel==I_Disease || currentLabel==L_Disease) { 690 | if(positionNew==-1 && labelSequence[j]==B_Disease) { 691 | positionNew = j; 
692 | } else if(positionOther==-1 && labelSequence[j]!=I_Disease) { 693 | positionOther = j; 694 | } 695 | } else if(currentLabel==I_Chemical || currentLabel==L_Chemical) { 696 | if(positionNew==-1 && labelSequence[j]==B_Chemical) { 697 | positionNew = j; 698 | } else if(positionOther==-1 && labelSequence[j]!=I_Chemical) { 699 | positionOther = j; 700 | } 701 | } 702 | 703 | if(positionOther!=-1 && positionNew!=-1) 704 | break; 705 | } 706 | 707 | if(positionNew == -1) 708 | return false; 709 | else if(positionOther chars; 730 | featureName2ID(m_charAlphabet, feature_character(sent.tokens[i]), chars); 731 | eg._seq_chars.push_back(chars); 732 | } 733 | 734 | eg._prior_ner = featureName2ID(m_nerAlphabet, lastNerLabel); 735 | eg._current_idx = tokenIdx; 736 | } 737 | 738 | void generateOneRelExample(Example& eg, const string& labelName, const fox::Sent& sent, 739 | const Entity& Disease, const Entity& Chemical, const vector& labelSequence) { 740 | if(!labelName.empty()) { 741 | int labelID = RellabelName2labelID(labelName); 742 | for(int i=0;i chars; 760 | featureName2ID(m_charAlphabet, feature_character(sent.tokens[i]), chars); 761 | eg._seq_chars.push_back(chars); 762 | eg._deps.push_back(featureName2ID(m_depAlphabet, feature_dep(token))); 763 | eg._ners.push_back(featureName2ID(m_nerAlphabet, labelSequence[i])); 764 | 765 | if(boolTokenInEntity(token, Disease)) { 766 | eg._idx_e1.insert(i); 767 | 768 | // if like "Listeria sp.", "." 
is not in dependency tree 769 | if(bacTkEnd == -1 /*&& sent.tokens[i].depGov!=-1*/) 770 | bacTkEnd = i; 771 | else if(bacTkEnd < i /*&& sent.tokens[i].depGov!=-1*/) 772 | bacTkEnd = i; 773 | } 774 | 775 | if(boolTokenInEntity(token, Chemical)) { 776 | eg._idx_e2.insert(i); 777 | 778 | if(locTkEnd == -1 /*&& sent.tokens[i].depGov!=-1*/) 779 | locTkEnd = i; 780 | else if(locTkEnd < i /*&& sent.tokens[i].depGov!=-1*/) 781 | locTkEnd = i; 782 | } 783 | 784 | if(isTokenBetweenTwoEntities(token, former, latter)) { 785 | eg._between_words.push_back(i); 786 | } 787 | } 788 | 789 | if(eg._between_words.size()==0) { 790 | eg._between_words.push_back(featureName2ID(m_wordAlphabet, nullkey)); 791 | } 792 | 793 | assert(bacTkEnd!=-1); 794 | assert(locTkEnd!=-1); 795 | 796 | 797 | // use SDP based on the last word of the entity 798 | vector sdpA; 799 | vector sdpB; 800 | int common = fox::Dependency::getCommonAncestor(sent.tokens, bacTkEnd, locTkEnd, 801 | sdpA, sdpB); 802 | 803 | 804 | //assert(common!=-2); // no common ancestor 805 | assert(common!=-1); // common ancestor is root 806 | 807 | if(common == -2) { 808 | eg._idxOnSDP_E12A.push_back(bacTkEnd); 809 | eg._idxOnSDP_E22A.push_back(locTkEnd); 810 | } else { 811 | for(int sdpANodeIdx=0;sdpANodeIdx & processed, Tool& tool) { 824 | 825 | hash_map word_stat; 826 | hash_map pos_stat; 827 | hash_map dep_stat; 828 | hash_map char_stat; 829 | 830 | for(int i=0;i characters = feature_character(sent.tokens[j]); 841 | for(int i=0;i feature_character(const fox::Token& token) { 877 | vector ret; 878 | string word = feature_word(token); 879 | for(int i=0;i& nrmat, Alphabet& alphabet, int embSize, int seed) { 885 | double* emb = new double[alphabet.size()*embSize]; 886 | fox::initArray2((double *)emb, (int)alphabet.size(), embSize, 0.0); 887 | 888 | vector known; 889 | map IDs; 890 | alphabet2vectormap(alphabet, known, IDs); 891 | 892 | fox::randomInitEmb((double*)emb, embSize, known, unknownkey, 893 | IDs, false, m_options.initRange, 
seed); 894 | 895 | nrmat.resize(alphabet.size(), embSize); 896 | array2NRMat((double*) emb, alphabet.size(), embSize, nrmat); 897 | 898 | delete[] emb; 899 | } 900 | 901 | template 902 | void averageUnkownEmb(Alphabet& alphabet, LookupTable& table, int embSize) { 903 | 904 | // unknown cannot be trained, use the average embedding 905 | int unknownID = alphabet.from_string(unknownkey); 906 | Tensor temp = NewTensor(Shape2(1, embSize), d_zero); 907 | int number = table._nVSize-1; 908 | table._E[unknownID] = 0.0; 909 | for(int i=0;i& stat, Alphabet& alphabet, const string& label) { 921 | 922 | cout << label<<" num: " << stat.size() << endl; 923 | alphabet.set_fixed_flag(false); 924 | hash_map::iterator feat_iter; 925 | for (feat_iter = stat.begin(); feat_iter != stat.end(); feat_iter++) { 926 | // if not fine tune, add all the words; if fine tune, add the words considering wordCutOff 927 | // in order to train unknown 928 | if (!m_options.wordEmbFineTune || feat_iter->second > m_options.wordCutOff) { 929 | alphabet.from_string(feat_iter->first); 930 | } 931 | } 932 | cout << "alphabet "<< label<<" num: " << alphabet.size() << endl; 933 | alphabet.set_fixed_flag(true); 934 | 935 | } 936 | 937 | 938 | void featureName2ID(Alphabet& alphabet, const string& featureName, vector& vfeatureID) { 939 | int id = alphabet.from_string(featureName); 940 | if(id >=0) 941 | vfeatureID.push_back(id); 942 | else 943 | vfeatureID.push_back(0); // assume unknownID is zero 944 | } 945 | 946 | void featureName2ID(Alphabet& alphabet, const vector& featureName, vector& vfeatureID) { 947 | for(int i=0;i=0) 950 | vfeatureID.push_back(id); 951 | else 952 | vfeatureID.push_back(0); // assume unknownID is zero 953 | } 954 | } 955 | 956 | int featureName2ID(Alphabet& alphabet, const string& featureName) { 957 | int id = alphabet.from_string(featureName); 958 | if(id >=0) 959 | return id; 960 | else 961 | return 0; // assume unknownID is zero 962 | } 963 | 964 | }; 965 | 966 | 967 | 968 | #endif /* 
NNBB3_H_ */ 969 | 970 | --------------------------------------------------------------------------------