├── .gitignore
├── LICENSE
├── NAACL_2021_Appendix.pdf
├── README.md
├── aida_ontology_cleaned.csv
├── docs
│   ├── _config.yml
│   ├── figures
│   │   └── model.png
│   └── index.md
├── event_role_ACE.json
├── event_role_KAIROS.json
├── outputs
│   └── wikievents-pointer-pred
│       ├── predictions.html
│       └── predictions.jsonl
├── pronoun_list.txt
├── scripts
│   ├── test_ace.sh
│   ├── test_kairos.sh
│   ├── test_rams.sh
│   ├── train_ace.sh
│   ├── train_kairos.sh
│   └── train_rams.sh
├── src
│   └── genie
│       ├── ACE_data_module.py
│       ├── KAIROS_data_module.py
│       ├── __init__.py
│       ├── constrained_gen.py
│       ├── convert_gen_to_output.py
│       ├── data.py
│       ├── data_module.py
│       ├── model.py
│       ├── network.py
│       ├── pipeline_scorer.py
│       ├── scorer.py
│       └── utils.py
├── train.py
└── viz
    ├── visualize_output_KAIROS.py
    └── visualize_output_RAMS.py

/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints/
2 | logs/
3 | data/
4 | preprocessed_*/
5 | __pycache__/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Zoey Li
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/NAACL_2021_Appendix.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raspberryice/gen-arg/253e0889b2377e0f7084cb406cf5d4142ee8a365/NAACL_2021_Appendix.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Argument Extraction by Generation
2 | 
3 | Code for the paper "Document-Level Argument Extraction by Conditional Generation" (NAACL 2021).
4 | 
5 | 
6 | ## Dependencies
7 | - pytorch=1.6
8 | - transformers=3.1.0
9 | - pytorch-lightning=1.0.6
10 | - spacy=2.3.2
11 | 
12 | ## Model Checkpoints
13 | Checkpoints trained from this repo for both the WikiEvents dataset and the ACE dataset are available at [s3://gen-arg-data/checkpoints/].
14 | 
15 | You can download all the contents of the S3 bucket using the AWS CLI: `aws s3 cp s3://gen-arg-data/checkpoints/ ./ --recursive`
16 | 
17 | ### Model Predictions
18 | The model predictions on WikiEvents are provided in `outputs/wikievents-pointer-pred`.
19 | Running this file through the `scorer.py` script should give you the exact same numbers as Table 5.
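For example, you can score the provided predictions with the same command used in `scripts/test_kairos.sh` (the paths below assume the WikiEvents data has been downloaded to `data/wikievents` as described in the Datasets section):

```bash
python src/genie/scorer.py --gen-file=outputs/wikievents-pointer-pred/predictions.jsonl \
    --test-file=data/wikievents/test.jsonl \
    --dataset=KAIROS \
    --coref-file=data/wikievents/coref/test.jsonlines \
    --head-only \
    --coref
```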
20 | 
21 | ## Datasets
22 | - RAMS (download at [https://nlp.jhu.edu/rams/])
23 | - ACE05 (access from LDC [https://catalog.ldc.upenn.edu/LDC2006T06] and preprocess following OneIE [http://blender.cs.illinois.edu/software/oneie/])
24 | - WikiEvents (available at [s3://gen-arg-data/wikievents/])
25 | 
26 | You can download the data through the AWS CLI or the AWS console.
27 | Alternatively, you can download individual files by
28 | - `wget https://gen-arg-data.s3.us-east-2.amazonaws.com/wikievents/data/<split>.jsonl` for split={train, dev, test}.
29 | - `wget https://gen-arg-data.s3.us-east-2.amazonaws.com/wikievents/data/coref/<split>.jsonlines` for split={train, dev, test}.
30 | 
31 | Additional processed test files for RAMS can be downloaded by
32 | - `wget https://gen-arg-data.s3.us-east-2.amazonaws.com/RAMS/test_head_coref.jsonlines`
33 | - `wget https://gen-arg-data.s3.us-east-2.amazonaws.com/RAMS/test_head.jsonlines`
34 | 
--------------------------------------------------------------------------------
/aida_ontology_cleaned.csv:
--------------------------------------------------------------------------------
1 | event_type,template,arg1,arg2,arg3,arg4,arg5
2 | artifactexistence.artifactfailure.mechanicalfailure,<arg1> mechanical artifact failed due to <arg2> instrument at <arg3> place,evt152arg01mechanicalartifact ,evt152arg02instrument ,evt152arg03place ,,
3 | artifactexistence.damagedestroy.unspecified,<arg1> damaged or destroyed <arg2> using <arg3> instrument in <arg4> place,evt001arg01damagerdestroyer,evt001arg02artifact,evt001arg03instrument,evt001arg04place,
4 | artifactexistence.damagedestroy.damage,<arg1> damaged <arg2> using <arg3> instrument in <arg4> place,evt002arg01damager,evt002arg02artifact,evt002arg03instrument,evt002arg04place,
5 | artifactexistence.damagedestroy.destroy,<arg1> destroyed <arg2> using <arg3> instrument in <arg4> place,evt003arg01destroyer,evt003arg02artifact,evt003arg03instrument,evt003arg04place,
6 | artifactexistence.shortage.shortage,<arg1> experienced a shortage of <arg2> supply at <arg3> place,evt149arg01experiencer ,evt149arg02supply ,evt149arg03place ,,
7 | conflict.attack.unspecified,<arg1> attacked <arg2> using <arg3> at <arg4> place,evt004arg01attacker,evt004arg02target,evt004arg03instrument,evt004arg04place,
8 | conflict.attack.airstrikemissilestrike,<arg1> attacked <arg2> using <arg3> at <arg4> place,evt005arg01attacker,evt005arg02target,evt005arg03instrument,evt005arg04place,
9 | conflict.attack.biologicalchemicalpoisonattack,<arg1> attacked <arg2> using <arg3> at <arg4> place,evt006arg01attacker,evt006arg02target,evt006arg03instrument,evt006arg04place,
10 | conflict.attack.bombing,<arg1> attacked <arg2> using <arg3> at <arg4> place,evt007arg01attacker,evt007arg02target,evt007arg03instrument,evt007arg04place,
11 | conflict.attack.firearmattack,<arg1> attacked <arg2> using <arg3> at <arg4> place,evt008arg01attacker,evt008arg02target,evt008arg03instrument,evt008arg04place,
12 | conflict.attack.hanging,<arg1> attacked <arg2> using <arg3> at <arg4> place,evt009arg01attacker,evt009arg02target,evt009arg03instrument,evt009arg04place,
13 | conflict.attack.invade,<arg1> attacked <arg2> using <arg3> at <arg4> place,evt010arg01attacker,evt010arg02target,evt010arg03instrument,evt010arg04place,
14 | conflict.attack.selfdirectedbattle,<arg1> attacked <arg2> using <arg3> at <arg4> place,evt011arg01attacker,evt011arg02target,evt011arg03instrument,evt011arg04place,
15 | conflict.attack.setfire,<arg1> attacked <arg2> using <arg3> at <arg4> place,evt012arg01attacker,evt012arg02target,evt012arg03instrument,evt012arg04place,
16 | conflict.attack.stabbing,<arg1> attacked <arg2> using <arg3> at <arg4> place,evt013arg01attacker,evt013arg02target,evt013arg03instrument,evt013arg04place,
17 | conflict.attack.stealrobhijack,"<arg1> attacked <arg2> using <arg3> at <arg4> place, in order to take <arg5>",evt014arg01attacker,evt014arg02target,evt014arg03instrument,evt014arg04place,evt014arg05artifact
18 | conflict.attack.strangling,<arg1> attacked <arg2> using <arg3> at <arg4> place,evt015arg01attacker,evt015arg02target,evt015arg03instrument,evt015arg04place,
19 | conflict.coup.coup,<arg1> was deposed by <arg2> at <arg3> place,evt151arg01deposedentity ,evt151arg02deposingentity ,evt151arg03place ,,
20 | conflict.demonstrate.unspecified,<arg1> was in a demonstration at <arg2> place,evt016arg01demonstrator,evt016arg02place,,,
21 | conflict.demonstrate.marchprotestpoliticalgathering,<arg1> was in a demonstration or protest at <arg2> place,evt017arg01demonstrator,evt017arg02place,,,
22 | conflict.yield.unspecified,<arg1> yielded to <arg2> at <arg3> place,evt018arg01yielder,evt018arg02recipient,evt018arg03place,,
23 | conflict.yield.retreat,<arg1> retreated from <arg2> place to <arg3> place,evt019arg01retreater,evt019arg02origin,evt019arg03destination,,
24 | conflict.yield.surrender,<arg1> surrendered to <arg2> at <arg3> place,evt020arg01surrenderer,evt020arg02recipient,evt020arg03place,,
25 | contact.collaborate.unspecified,<arg1> communicated with <arg2> at <arg3> place,evt021arg01participant,evt021arg02participant,evt021arg03place,,
26 | contact.collaborate.correspondence,<arg1> communicated remotely with <arg2> at <arg3> place,evt022arg01participant,evt022arg02participant,evt022arg03place,,
27 | contact.collaborate.meet,<arg1> met face-to-face with <arg2> at <arg3> place,evt023arg01participant,evt023arg02participant,evt023arg03place,,
28 | contact.commandorder.unspecified,<arg1> communicated with <arg2> about <arg4> topic at <arg3> place,evt024arg01communicator,evt024arg02recipient,evt024arg03place,evt024arg04topic,
29 | contact.commandorder.broadcast,<arg1> communicated to <arg2> about <arg4> topic at <arg3> place (one-way communication),evt025arg01communicator,evt025arg02recipient,evt025arg03place,evt025arg04topic,
30 | contact.commandorder.correspondence,<arg1> communicated remotely with <arg2> about <arg4> topic at <arg3> place,evt026arg01communicator,evt026arg02recipient,evt026arg03place,evt026arg04topic,
31 | contact.commandorder.meet,<arg1> met face-to-face with <arg2> about <arg4> topic at <arg3> place,evt027arg01communicator,evt027arg02recipient,evt027arg03place,evt027arg04topic,
32 | contact.commitmentpromiseexpressintent.unspecified,<arg1> communicated with <arg2> about <arg4> topic at <arg3> place,evt028arg01communicator,evt028arg02recipient,evt028arg03place,evt028arg04topic,
33 | contact.commitmentpromiseexpressintent.broadcast,<arg1> communicated to <arg2> about <arg4> topic at <arg3> place (one-way communication),evt029arg01communicator,evt029arg02recipient,evt029arg03place,evt029arg04topic,
34 | contact.commitmentpromiseexpressintent.correspondence,<arg1> communicated remotely with <arg2> about <arg4> topic at <arg3> place,evt030arg01communicator,evt030arg02recipient,evt030arg03place,evt030arg04topic,
35 | contact.commitmentpromiseexpressintent.meet,<arg1> met face-to-face with <arg2> about <arg4> topic at <arg3> place,evt031arg01communicator,evt031arg02recipient,evt031arg03place,evt031arg04topic,
36 | contact.discussion.unspecified,<arg1> communicated with <arg2> at <arg3> place,evt032arg01participant,evt032arg02participant,evt032arg03place,,
37 | contact.discussion.correspondence,<arg1> communicated remotely with <arg2> at <arg3> place,evt033arg01participant,evt033arg02participant,evt033arg03place,,
38 | contact.discussion.meet,<arg1> met face-to-face with <arg2> at <arg3> place,evt034arg01participant,evt034arg02participant,evt034arg03place,,
39 | contact.funeralvigil.unspecified,<arg1> communicated with <arg2> during a funeral or vigil for <arg3> at <arg4> place,evt035arg01participant,evt035arg02participant,evt035arg03deceased,evt035arg04place,
40 | contact.funeralvigil.meet,<arg1> met face-to-face with <arg2> during a funeral or vigil for <arg3> at <arg4> place,evt036arg01participant,evt036arg02participant,evt036arg03deceased,evt036arg04place,
41 | contact.mediastatement.unspecified,<arg1> communicated with <arg2> at <arg3> place,evt037arg01communicator,evt037arg02recipient,evt037arg03place,,
42 | contact.mediastatement.broadcast,<arg1> communicated to <arg2> at <arg3> place (one-way communication),evt038arg01communicator,evt038arg02recipient,evt038arg03place,,
43 | contact.negotiate.unspecified,<arg1> communicated with <arg2> about <arg4> topic at <arg3> place,evt039arg01participant,evt039arg02participant,evt039arg03place,evt039arg04topic,
44 | contact.negotiate.correspondence,<arg1> communicated remotely with <arg2> about <arg4> topic at <arg3> place,evt040arg01participant,evt040arg02participant,evt040arg03place,evt040arg04topic,
45 | contact.negotiate.meet,<arg1> met face-to-face with <arg2> about <arg4> topic at <arg3> place,evt041arg01participant,evt041arg02participant,evt041arg03place,evt041arg04topic,
46 | contact.prevarication.unspecified,<arg1> communicated with <arg2> about <arg4> topic at <arg3> place,evt042arg01communicator,evt042arg02recipient,evt042arg03place,evt042arg04topic,
47 | contact.prevarication.broadcast,<arg1> communicated to <arg2> about <arg4> topic at <arg3> place (one-way communication),evt043arg01communicator,evt043arg02recipient,evt043arg03place,evt043arg04topic,
48 | contact.prevarication.correspondence,<arg1> communicated remotely with <arg2> about <arg4> topic at <arg3> place,evt044arg01communicator,evt044arg02recipient,evt044arg03place,evt044arg04topic,
49 | contact.prevarication.meet,<arg1> met face-to-face with <arg2> about <arg4> topic at <arg3> place,evt045arg01communicator,evt045arg02recipient,evt045arg03place,evt045arg04topic,
50 | contact.publicstatementinperson.unspecified,<arg1> communicated with <arg2> at <arg3> place,evt046arg01communicator,evt046arg02recipient,evt046arg03place,,
51 | contact.publicstatementinperson.broadcast,<arg1> communicated to <arg2> at <arg3> place (one-way communication),evt047arg01communicator,evt047arg02recipient,evt047arg03place,,
52 | contact.requestadvise.unspecified,<arg1> communicated with <arg2> about <arg4> topic at <arg3> place,evt048arg01communicator,evt048arg02recipient,evt048arg03place,evt048arg04topic,
53 | contact.requestadvise.broadcast,<arg1> communicated to <arg2> about <arg4> topic at <arg3> place (one-way communication),evt049arg01communicator,evt049arg02recipient,evt049arg03place,evt049arg04topic,
54 | contact.requestadvise.correspondence,<arg1> communicated remotely with <arg2> about <arg4> topic at <arg3> place,evt050arg01communicator,evt050arg02recipient,evt050arg03place,evt050arg04topic,
55 | contact.requestadvise.meet,<arg1> met face-to-face with <arg2> about <arg4> topic at <arg3> place,evt051arg01communicator,evt051arg02recipient,evt051arg03place,evt051arg04topic,
56 | contact.threatencoerce.unspecified,<arg1> communicated with <arg2> about <arg4> topic at <arg3> place,evt052arg01communicator,evt052arg02recipient,evt052arg03place,evt052arg04topic,
57 | contact.threatencoerce.broadcast,<arg1> communicated to <arg2> about <arg4> topic at <arg3> place (one-way communication),evt053arg01communicator,evt053arg02recipient,evt053arg03place,evt053arg04topic,
58 | contact.threatencoerce.correspondence,<arg1> communicated remotely with <arg2> about <arg4> topic at <arg3> place,evt054arg01communicator,evt054arg02recipient,evt054arg03place,evt054arg04topic,
59 | contact.threatencoerce.meet,<arg1> met face-to-face with <arg2> about <arg4> topic at <arg3> place,evt055arg01communicator,evt055arg02recipient,evt055arg03place,evt055arg04topic,
60 | disaster.accidentcrash.accidentcrash,<arg1> person in <arg2> vehicle crashed into <arg3> at <arg4> place,evt057arg01driverpassenger,evt057arg02vehicle,evt057arg03crashobject,evt057arg04place,
61 | disaster.diseaseoutbreak.diseaseoutbreak,<arg1> disease broke out among <arg2> victims or population at <arg3> place,evt148arg01disease ,evt148arg02victim ,evt148arg03place ,,
62 | disaster.fireexplosion.fireexplosion,<arg1> caught fire or exploded from <arg2> instrument at <arg3> place,evt059arg01fireexplosionobject,evt059arg02instrument,evt059arg03place,,
63 | genericcrime.genericcrime.genericcrime,<arg1> committed a crime against <arg2> at <arg3> place,evt154arg01perpetrator ,evt154arg02victim ,evt154arg03place ,,
64 | government.agreements.unspecified,<arg1> and <arg2> signed an agreement in <arg3> place,evt060arg01participant,evt060arg02participant,evt060arg03place,,
65 | government.agreements.acceptagreementcontractceasefire,<arg1> and <arg2> signed an agreement in <arg3> place,evt061arg01participant,evt061arg02participant,evt061arg03place,,
66 | government.agreements.rejectnullifyagreementcontractceasefire,<arg1> rejected or nullified an agreement with <arg2> in <arg3> place,evt062arg01rejecternullifier,evt062arg02otherparticipant,evt062arg03place,,
67 | government.agreements.violateagreement,<arg1> violated an agreement with <arg2> in <arg3> place,evt063arg01violator,evt063arg02otherparticipant,evt063arg03place,,
68 | government.convene.convene,<arg1> convened <arg2> at <arg3> place,evt145arg01convener ,evt145arg02convenedthing ,evt145arg03place ,,
69 | government.formation.unspecified,<arg1> was formed by <arg2> in <arg3> place,evt064arg01gpe,evt064arg02founder,evt064arg03place,,
70 | government.formation.mergegpe,<arg1> merged with <arg2> at <arg3> place,evt065arg01participant,evt065arg02participant,evt065arg03place,,
71 | government.formation.startgpe,<arg1> was started by <arg2> in <arg3> place,evt066arg01gpe,evt066arg02founder,evt066arg03place,,
72 | government.legislate.legislate,<arg1> legislature enacted <arg2> law in <arg3> place,evt068arg01governmentbody,evt068arg02law,evt068arg03place,,
73 | government.spy.spy,<arg1> spied on <arg2> to the benefit of <arg3> in <arg4> place,evt070arg01spy,evt070arg02observedentity,evt070arg03beneficiary,evt070arg04place,
74 | government.vote.unspecified,<arg1> voted for <arg2> on <arg3> ballot with <arg4> results in <arg5> place,evt071arg01voter,evt071arg02candidate,evt071arg03ballot,evt071arg04result,evt071arg05place
75 | government.vote.castvote,<arg1> voted for <arg2> on <arg3> ballot with <arg4> results in <arg5> place,evt072arg01voter,evt072arg02candidate,evt072arg03ballot,evt072arg04result,evt072arg05place
76 | government.vote.violationspreventvote,<arg1> prevented <arg2> from voting for <arg3> on <arg4> ballot in <arg5> place,evt073arg01preventer,evt073arg02voter,evt073arg03candidate,evt073arg04ballot,evt073arg05place
77 | inspection.sensoryobserve.unspecified,<arg1> observed <arg2> in <arg3> place,evt074arg01observer,evt074arg02observedentity,evt074arg03place,,
78 | inspection.sensoryobserve.inspectpeopleorganization,<arg1> inspected <arg2> in <arg3> place,evt075arg01inspector,evt075arg02inspectedentity,evt075arg03place,,
79 | inspection.sensoryobserve.monitorelection,<arg1> monitored <arg2> taking part in an election in <arg3> place,evt076arg01monitor,evt076arg02monitoredentity,evt076arg03place,,
80 | inspection.sensoryobserve.physicalinvestigateinspect,<arg1> inspected <arg2> in <arg3> place,evt077arg01inspector,evt077arg02inspectedentity,evt077arg03place,,
81 | inspection.targetaimat.targetaimat,<arg1> physically targeted <arg2> with <arg3> instrument at <arg4> place,evt153arg01targeter ,evt153arg02target ,evt153arg03instrument ,evt153arg04place ,
82 | justice.arrestjaildetain.arrestjaildetain,<arg1> arrested or jailed <arg2> for <arg3> crime at <arg4> place,evt079arg01jailer,evt079arg02detainee,evt079arg03crime,evt079arg04place,
83 | justice.initiatejudicialprocess.unspecified,<arg1> initiated judicial process pertaining to <arg2> before <arg3> court or judge for <arg4> crime in <arg5> place,evt080arg01prosecutor,evt080arg02defendant,evt080arg03judgecourt,evt080arg04crime,evt080arg05place
84 | justice.initiatejudicialprocess.chargeindict,<arg1> charged or indicted <arg2> before <arg3> court or judge for <arg4> crime in <arg5> place,evt081arg01prosecutor,evt081arg02defendant,evt081arg03judgecourt,evt081arg04crime,evt081arg05place
85 | justice.initiatejudicialprocess.trialhearing,<arg1> tried <arg2> before <arg3> court or judge for <arg4> crime in <arg5> place,evt082arg01prosecutor,evt082arg02defendant,evt082arg03judgecourt,evt082arg04crime,evt082arg05place
86 | justice.investigate.unspecified,<arg1> investigated <arg2> in <arg3> place,evt083arg01investigator,evt083arg02defendant,evt083arg03place,,
87 | justice.investigate.investigatecrime,<arg1> investigated <arg2> for <arg3> crime in <arg4> place,evt084arg01investigator,evt084arg02defendant,evt084arg03crime,evt084arg04place,
88 | justice.judicialconsequences.unspecified,"<arg1> court or judge decided consequences of <arg3> crime, committed by <arg2>, in <arg4> place",evt085arg01judgecourt,evt085arg02defendant,evt085arg03crime,evt085arg04place,
89 | justice.judicialconsequences.convict,<arg1> court or judge convicted <arg2> of <arg3> crime in <arg4> place,evt086arg01judgecourt,evt086arg02defendant,evt086arg03crime,evt086arg04place,
90 | justice.judicialconsequences.execute,<arg1> executed <arg2> for <arg3> crime in <arg4> place,evt087arg01executioner,evt087arg02defendant,evt087arg03crime,evt087arg04place,
91 | justice.judicialconsequences.extradite,<arg1> extradited <arg2> for <arg3> crime from <arg4> place to <arg5> place,evt088arg01extraditer,evt088arg02defendant,evt088arg03crime,evt088arg04origin,evt088arg05destination
92 | life.die.unspecified,"<arg1> died at <arg2> place from <arg4> medical issue, killed by <arg3> killer",evt089arg01victim,evt089arg02place,evt089arg03killer,evt089arg04medicalissue,
93 | life.die.deathcausedbyviolentevents,<arg1> killed <arg2> using <arg3> instrument or <arg5> medical issue at <arg4> place,evt090arg01killer,evt090arg02victim,evt090arg03instrument,evt090arg04place,evt090arg05medicalissue
94 | life.die.nonviolentdeath,"<arg1> died at <arg2> place from <arg4> medical issue, killed by <arg3> killer",evt091arg01victim,evt091arg02place,evt091arg03killer,evt091arg04medicalissue,
95 | life.injure.unspecified,<arg1> was injured by <arg2> with <arg4> medical issue at <arg3> place,evt092arg01victim,evt092arg02injurer,evt092arg03place,evt092arg04medicalissue,
96 | life.injure.illnessdegradationhungerthirst,<arg1> has extreme hunger or thirst from <arg4> medical issue imposed by <arg3> injurer at <arg2> place,evt093arg01victim,evt093arg02place,evt093arg03injurer,evt093arg04medicalissue,
97 | life.injure.illnessdegradationphysical,<arg1> person has some physical degradation from <arg4> medical issue imposed by <arg3> injurer at <arg2> place,evt094arg01victim,evt094arg02place,evt094arg03injurer,evt094arg04medicalissue,
98 | life.injure.illnessdegredationsickness,"<arg1> has <arg3> sickness or illness at <arg4> place, deliberately infected by <arg2>",evt150arg01victim ,evt150arg02injurer ,evt150arg03disease ,evt150arg04place ,
99 | life.injure.injurycausedbyviolentevents,<arg1> injured <arg2> using <arg3> instrument or <arg5> medical issue at <arg4> place,evt095arg01injurer,evt095arg02victim,evt095arg03instrument,evt095arg04place,evt095arg05medicalissue
100 | manufacture.artifact.unspecified,<arg1> manufactured or created or produced <arg2> using <arg3> at <arg4> place,evt096arg01manufacturer,evt096arg02artifact,evt096arg03instrument,evt096arg04place,
101 | manufacture.artifact.build,<arg1> manufactured or created or produced <arg2> using <arg3> at <arg4> place,evt097arg01manufacturer,evt097arg02artifact,evt097arg03instrument,evt097arg04place,
102 | manufacture.artifact.createintellectualproperty,<arg1> manufactured or created or produced <arg2> using <arg3> at <arg4> place,evt098arg01manufacturer,evt098arg02artifact,evt098arg03instrument,evt098arg04place,
103 | manufacture.artifact.createmanufacture,<arg1> manufactured or created or produced <arg2> using <arg3> at <arg4> place,evt099arg01manufacturer,evt099arg02artifact,evt099arg03instrument,evt099arg04place,
104 | medical.intervention.intervention,<arg1> treater treated <arg2> patient for <arg3> medical issue with <arg4> means at <arg5> place,evt147arg01treater ,evt147arg02patient ,evt147arg03medicalissue ,evt147arg04instrument ,evt147arg05place
105 | movement.transportartifact.unspecified,<arg1> transported <arg2> in <arg3> from <arg4> place to <arg5> place,evt100arg01transporter,evt100arg02artifact,evt100arg03vehicle,evt100arg04origin,evt100arg05destination
106 | movement.transportartifact.bringcarryunload,<arg1> transported <arg2> in <arg3> from <arg4> place to <arg5> place,evt101arg01transporter,evt101arg02artifact,evt101arg03vehicle,evt101arg04origin,evt101arg05destination
107 | movement.transportartifact.disperseseparate,<arg1> transported <arg2> in <arg3> from <arg4> place to <arg5> place,evt102arg01transporter,evt102arg02artifact,evt102arg03vehicle,evt102arg04origin,evt102arg05destination
108 | movement.transportartifact.fall,<arg1> fell from <arg2> place to <arg3> place,evt103arg01artifact,evt103arg02origin,evt103arg03destination,,
109 | movement.transportartifact.grantentry,<arg1> grants <arg2> entry to <arg4> place from <arg3> place,evt104arg01transporter,evt104arg02artifact,evt104arg03origin,evt104arg04destination,
110 | movement.transportartifact.hide,"<arg1> concealed <arg2> in <arg3> place, transported in <arg4> vehicle from <arg5> place",evt105arg01transporter,evt105arg02artifact,evt105arg03hidingplace,evt105arg04vehicle,evt105arg05origin
111 | movement.transportartifact.lossofcontrol,<arg1> lost control of <arg2> moving at <arg3> place,evt146arg01controller ,evt146arg02controlledthing ,evt146arg03place ,,
112 | movement.transportartifact.nonviolentthrowlaunch,<arg1> transported <arg2> in <arg3> from <arg4> place to <arg5> place,evt106arg01transporter,evt106arg02artifact,evt106arg03vehicle,evt106arg04origin,evt106arg05destination
113 | movement.transportartifact.prevententry,<arg1> prevents <arg2> from transporting <arg3> from <arg4> place to <arg5> place,evt107arg01preventer,evt107arg02transporter,evt107arg03artifact,evt107arg04origin,evt107arg05destination
114 | movement.transportartifact.preventexit,<arg1> prevents <arg2> from transporting <arg3> from <arg4> place to <arg5> place,evt108arg01preventer,evt108arg02transporter,evt108arg03artifact,evt108arg04origin,evt108arg05destination
115 | movement.transportartifact.receiveimport,<arg1> transported <arg2> in <arg3> from <arg4> place to <arg5> place,evt109arg01transporter,evt109arg02artifact,evt109arg03vehicle,evt109arg04origin,evt109arg05destination
116 | movement.transportartifact.sendsupplyexport,<arg1> transported <arg2> in <arg3> from <arg4> place to <arg5> place,evt110arg01transporter,evt110arg02artifact,evt110arg03vehicle,evt110arg04origin,evt110arg05destination
117 | movement.transportartifact.smuggleextract,<arg1> transported <arg2> in <arg3> from <arg4> place to <arg5> place,evt111arg01transporter,evt111arg02artifact,evt111arg03vehicle,evt111arg04origin,evt111arg05destination
118 | movement.transportperson.unspecified,<arg1> transported <arg2> in <arg3> from <arg4> place to <arg5> place,evt112arg01transporter,evt112arg02passenger,evt112arg03vehicle,evt112arg04origin,evt112arg05destination
119 | movement.transportperson.bringcarryunload,<arg1> transported <arg2> in <arg3> from <arg4> place to <arg5> place,evt113arg01transporter,evt113arg02passenger,evt113arg03vehicle,evt113arg04origin,evt113arg05destination
120 | movement.transportperson.disperseseparate,<arg1> transported <arg2> in <arg3> from <arg4> place to <arg5> place,evt114arg01transporter,evt114arg02passenger,evt114arg03vehicle,evt114arg04origin,evt114arg05destination
121 | movement.transportperson.evacuationrescue,<arg1> transported <arg2> in <arg3> from <arg4> place to <arg5> place,evt115arg01transporter,evt115arg02passenger,evt115arg03vehicle,evt115arg04origin,evt115arg05destination
122 | movement.transportperson.fall,<arg1> fell from <arg2> place to <arg3> place,evt116arg01passenger,evt116arg02origin,evt116arg03destination,,
123 | movement.transportperson.grantentryasylum,<arg1> grants entry to <arg2> transporting <arg3> from <arg4> place to <arg5> place,evt117arg01granter,evt117arg02transporter,evt117arg03passenger,evt117arg04origin,evt117arg05destination
124 | movement.transportperson.hide,"<arg1> concealed <arg2> in <arg3> place, transported in <arg4> vehicle from <arg5> place",evt118arg01transporter,evt118arg02passenger,evt118arg03hidingplace,evt118arg04vehicle,evt118arg05origin
125 | movement.transportperson.prevententry,<arg1> prevents <arg2> from transporting <arg3> from <arg4> place to <arg5> place,evt119arg01preventer,evt119arg02transporter,evt119arg03passenger,evt119arg04origin,evt119arg05destination
126 | movement.transportperson.preventexit,<arg1> prevents <arg2> from transporting <arg3> from <arg4> place to <arg5> place,evt120arg01preventer,evt120arg02transporter,evt120arg03passenger,evt120arg04origin,evt120arg05destination
127 | movement.transportperson.selfmotion,<arg1> moved in <arg2> from <arg3> place to <arg4> place,evt121arg01transporter,evt121arg02vehicle,evt121arg03origin,evt121arg04destination,
128 | movement.transportperson.smuggleextract,<arg1> transported <arg2> in <arg3> from <arg4> place to <arg5> place,evt122arg01transporter,evt122arg02passenger,evt122arg03vehicle,evt122arg04origin,evt122arg05destination
129 | personnel.elect.unspecified,<arg1> elected <arg2> in <arg3> place,evt123arg01voter,evt123arg02candidate,evt123arg03place,,
130 | personnel.elect.winelection,<arg1> elected <arg2> in <arg3> place,evt124arg01voter,evt124arg02candidate,evt124arg03place,,
131 | personnel.endposition.unspecified,<arg1> stopped working at <arg2> in <arg3> place,evt125arg01employee,evt125arg02placeofemployment,evt125arg03place,,
132 | personnel.endposition.firinglayoff,<arg1> stopped working at <arg2> in <arg3> place,evt126arg01employee,evt126arg02placeofemployment,evt126arg03place,,
133 | personnel.endposition.quitretire,<arg1> stopped working at <arg2> in <arg3> place,evt127arg01employee,evt127arg02placeofemployment,evt127arg03place,,
134 | personnel.startposition.unspecified,<arg1> started working at <arg2> in <arg3> place,evt128arg01employee,evt128arg02placeofemployment,evt128arg03place,,
135 | personnel.startposition.hiring,<arg1> started working at <arg2> in <arg3> place,evt129arg01employee,evt129arg02placeofemployment,evt129arg03place,,
136 | transaction.transaction.unspecified,A transaction occurred between <arg1> and <arg2> for the benefit of <arg3> at <arg4> place,evt130arg01participant,evt130arg02participant,evt130arg03beneficiary,evt130arg04place,
137 | transaction.transaction.embargosanction,<arg1> prevented <arg2> from giving <arg4> to <arg3> at <arg5> place,evt131arg01preventer,evt131arg02giver,evt131arg03recipient,evt131arg04artifactmoney,evt131arg05place
138 | transaction.transaction.giftgrantprovideaid,<arg1> gave something to <arg2> for the benefit of <arg3> at <arg4> place,evt132arg01giver,evt132arg02recipient,evt132arg03beneficiary,evt132arg04place,
139 | transaction.transfermoney.unspecified,<arg1> gave <arg4> money to <arg2> for the benefit of <arg3> at <arg5> place,evt133arg01giver,evt133arg02recipient,evt133arg03beneficiary,evt133arg04money,evt133arg05place
140 | transaction.transfermoney.borrowlend,<arg1> gave <arg4> money to <arg2> for the benefit of <arg3> at <arg5> place,evt134arg01giver,evt134arg02recipient,evt134arg03beneficiary,evt134arg04money,evt134arg05place
141 | transaction.transfermoney.embargosanction,<arg1> prevented <arg2> from giving <arg4> to <arg3> at <arg5> place,evt135arg01preventer,evt135arg02giver,evt135arg03recipient,evt135arg04money,evt135arg05place
142 | transaction.transfermoney.giftgrantprovideaid,<arg1> gave <arg4> money to <arg2> for the benefit of <arg3> at <arg5> place,evt136arg01giver,evt136arg02recipient,evt136arg03beneficiary,evt136arg04money,evt136arg05place
143 | transaction.transfermoney.payforservice,<arg1> gave <arg4> money to <arg2> for the benefit of <arg3> at <arg5> place,evt137arg01giver,evt137arg02recipient,evt137arg03beneficiary,evt137arg04money,evt137arg05place
144 | transaction.transfermoney.purchase,<arg1> gave <arg4> money to <arg2> for the benefit of <arg3> at <arg5> place,evt138arg01giver,evt138arg02recipient,evt138arg03beneficiary,evt138arg04money,evt138arg05place
145 | transaction.transferownership.unspecified,<arg1> gave <arg4> to <arg2> for the benefit of <arg3> at <arg5> place,evt139arg01giver,evt139arg02recipient,evt139arg03beneficiary,evt139arg04artifact,evt139arg05place
146 | transaction.transferownership.borrowlend,<arg1> gave <arg4> to <arg2> for the benefit of <arg3> at <arg5> place,evt140arg01giver,evt140arg02recipient,evt140arg03beneficiary,evt140arg04artifact,evt140arg05place
147 | transaction.transferownership.embargosanction,<arg1> prevented <arg2> from giving <arg4> to <arg3> at <arg5> place,evt141arg01preventer,evt141arg02giver,evt141arg03recipient,evt141arg04artifact,evt141arg05place
148 | transaction.transferownership.giftgrantprovideaid,<arg1> gave <arg4> to <arg2> for the benefit of <arg3> at <arg5> place,evt142arg01giver,evt142arg02recipient,evt142arg03beneficiary,evt142arg04artifact,evt142arg05place
149 | transaction.transferownership.purchase,<arg1> gave <arg4> to <arg2> for the benefit of <arg3> at <arg5> place,evt143arg01giver,evt143arg02recipient,evt143arg03beneficiary,evt143arg04artifact,evt143arg05place
150 | transaction.transaction.transfercontrol,<arg1> transferred control of <arg4> to <arg2> for the benefit of <arg3> in <arg5> place,evt144arg01giver,evt144arg02recipient,evt144arg03beneficiary,evt144arg04territoryorfacility,evt144arg05place
151 | 
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
--------------------------------------------------------------------------------
/docs/figures/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raspberryice/gen-arg/253e0889b2377e0f7084cb406cf5d4142ee8a365/docs/figures/model.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ## Intro
2 | 
3 | > TLDR:
4 | > This paper treats document-level argument extraction as a conditional generation problem by filling in event templates. (A worked example of template filling appears below.)
5 | 
6 | 
7 | ![Argument Extraction Model Overview](figures/model.png)
8 | 
9 | ## Dataset
10 | 
11 | The WikiEvents dataset is under `data/wikievents`.
12 | 
13 | The RAMS dataset is available for download [here](https://nlp.jhu.edu/rams/).
14 | 
15 | The ACE dataset is provided by LDC and unfortunately we cannot release it directly. See [this link](https://catalog.ldc.upenn.edu/LDC2006T06) for details.
16 | 
17 | ## Code
18 | 
19 | - v0.1 (April 12, 2021): Basic generation model for argument extraction. (This does not include the post-processing script.)
20 | 
21 | 
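To make the template-filling formulation concrete, here is a minimal sketch of how a filled target template is produced, following the `re.sub` logic in `src/genie/ACE_data_module.py`; the sentence and argument strings below are invented for illustration:

```python
import re

# Template for Conflict:Attack from event_role_ACE.json.
template = "<arg1> attacked <arg2> hurting <arg5> victims using <arg3> instrument at <arg4> place"

# Hypothetical gold arguments for one event mention (invented for illustration).
arguments = {"arg1": "the rebels", "arg2": "a convoy", "arg3": "rockets", "arg4": "Kabul"}

# Fill each slot that has a gold argument, as in create_gold_gen().
for arg_idx, text in arguments.items():
    template = re.sub('<{}>'.format(arg_idx), text, template)

# Any remaining numbered slot collapses to the generic <arg> token,
# mirroring re.sub(r'<arg\d>', '<arg>', template) in the data modules.
output_template = re.sub(r'<arg\d>', '<arg>', template)

print(output_template)
# the rebels attacked a convoy hurting <arg> victims using rockets instrument at Kabul place
```

At training time the target sequence is exactly such a filled template; slots with no argument are left as the generic `<arg>` token.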
22 | 
23 | This page is currently under construction.
24 | 
--------------------------------------------------------------------------------
/event_role_ACE.json:
--------------------------------------------------------------------------------
1 | {
2 |     "Movement:Transport": {
3 |         "event_id": "evt001",
4 |         "template":"<arg1> transported <arg2> in <arg3> vehicle from <arg4> place to <arg5> place",
5 |         "roles":[ "Agent",
6 |             "Artifact",
7 |             "Vehicle",
8 |             "Origin",
9 |             "Destination"],
10 |         "i-label": 1,
11 |         "keywords": ["transport","move","travel", "head"]
12 |     },
13 |     "Personnel:Elect": {
14 |         "event_id": "evt002",
15 |         "template": "<arg1> elected <arg2> in <arg3> place",
16 |         "roles":[ "Entity", "Person", "Place"],
17 |         "i-label":2,
18 |         "keywords":["elect", "vote", "election"]
19 |     },
20 |     "Personnel:Start-Position": {
21 |         "event_id": "evt003",
22 |         "template": "<arg1> started working at <arg2> organization in <arg3> place",
23 |         "roles":["Person", "Entity", "Place"],
24 |         "i-label": 3,
25 |         "keywords":["hire","employ", "appoint"]
26 |     },
27 |     "Personnel:Nominate": {
28 |         "event_id": "evt004",
29 |         "template": "<arg1> nominated <arg2>",
30 |         "roles":["Agent", "Person"],
31 |         "i-label":4,
32 |         "keywords":["nominate", "name"]
33 |     },
34 |     "Personnel:End-Position": {
35 |         "event_id": "evt005",
36 |         "template": "<arg1> stopped working at <arg2> organization in <arg3> place",
37 |         "roles":[
38 |             "Person",
39 |             "Entity",
40 |             "Place"
41 |         ],
42 |         "i-label": 5,
43 |         "keywords":["resign", "retire", "former", "fire","dismiss"]
44 |     },
45 |     "Conflict:Attack": {
46 |         "event_id": "evt006",
47 |         "template":"<arg1> attacked <arg2> hurting <arg5> victims using <arg3> instrument at <arg4> place",
48 |         "roles":[
49 |             "Attacker",
50 |             "Target",
51 |             "Instrument",
52 |             "Place",
53 |             "Victim"
54 |         ],
55 |         "i-label": 6,
56 |         "keywords":["attack","invade", "shoot", "explode", "war", "terrorism"]
57 |     },
58 |     "Contact:Meet": {
59 |         "event_id":"evt007",
60 |         "template": "<arg1> met with <arg2> in <arg3> place",
61 |         "roles":[
62 |             "Entity",
63 |             "Entity",
64 |             "Place"
65 |         ],
66 |         "i-label":7,
67 |         "keywords": ["meet", "meeting", "talk", "summit"]
68 |     },
69 |     "Life:Marry": {
70 |         "event_id": "evt008",
71 |         "template": "<arg1> married <arg2> in <arg3> place",
72 |         "roles":[
73 |             "Person",
74 |             "Person",
75 |             "Place"
76 |         ],
77 |         "i-label": 8,
78 |         "keywords": ["marry", "wed", "spouse"]
79 |     },
80 |     "Transaction:Transfer-Money": {
81 |         "event_id": "evt009",
82 |         "template": "<arg1> gave money to <arg2> for the benefit of <arg3> in <arg4> place",
83 |         "roles":[
84 |             "Giver",
85 |             "Recipient",
86 |             "Beneficiary",
87 |             "Place"
88 |         ],
89 |         "i-label": 9,
90 |         "keywords": ["pay", "transfer", "donate", "lend", "borrow", "compensate"]
91 |     },
92 |     "Conflict:Demonstrate": {
93 |         "event_id": "evt010",
94 |         "template": "<arg1> demonstrated at <arg2> place",
95 |         "roles":[
96 |             "Entity",
97 |             "Place"
98 |         ],
99 |         "i-label": 10,
100 |         "keywords":["demonstrate", "protest","rally", "riot"]
101 |     },
102 |     "Business:End-Org": {
103 |         "event_id": "evt011",
104 |         "template": "<arg1> organization shut down at <arg2> place",
105 |         "roles":[
106 |             "Org",
107 |             "Place"
108 |         ],
109 |         "i-label": 11,
110 |         "keywords": ["shut", "fold", "collapse", "demise", "failure"]
111 |     },
112 |     "Justice:Sue": {
113 |         "event_id": "evt012",
114 |         "template": "<arg1> sued <arg2> before <arg3> court or judge in <arg4> place",
115 |         "roles":[
116 |             "Plaintiff",
117 |             "Defendant",
118 |             "Adjudicator",
119 |             "Place"
120 |         ],
121 |         "i-label": 12,
122 |         "keywords": ["sue", "lawsuit"]
123 |     },
124 |     "Life:Injure": {
125 |         "event_id": "evt013",
126 |         "template": "<arg1> injured <arg2> with <arg3> instrument in <arg4> place",
127 |         "roles":["Agent",
128 |             "Victim",
129 |             "Instrument",
130 |             "Place"
131 |         ],
132 |         "i-label": 13,
133 |         "keywords": ["injure", "hurt", "wound"]
134 |     },
135 |     "Life:Die": {
136 |         "event_id": "evt014",
137 |         "template": "<arg1> killed <arg2> with <arg3> instrument in <arg4> place",
138 |         "roles":[
139 |             "Agent",
140 |             "Victim",
141 |             "Instrument",
142 |             "Place"
143 |         ],
144 |         "i-label": 14,
145 |         "keywords": ["die", "kill", "death", "assassinate", "suicide"]
146 |     },
147 |     "Justice:Arrest-Jail": {
148 |         "event_id": "evt015",
149 |         "template": "<arg1> arrested <arg2> in <arg3> place",
150 |         "roles":[
151 |             "Agent",
152 |             "Person",
153 |             "Place"
154 |         ],
155 |         "i-label": 15,
156 |         "keywords":["arrest", "jail", "imprison"]
157 |     },
158 |     "Contact:Phone-Write": {
159 |         "event_id": "evt016",
160 |         "template": "<arg1> communicated remotely with <arg2> at <arg3> place",
161 |         "roles":[
162 |             "Entity",
163 |             "Entity",
164 |             "Place"
165 |         ],
166 |         "i-label": 16,
167 |         "keywords": ["call", "mail", "letter", "telephone", "correspondence"]
168 |     },
169 |     "Transaction:Transfer-Ownership": {
170 |         "event_id": "evt017",
171 |         "template":"<arg1> gave <arg4> to <arg2> for the benefit of <arg3> at <arg5> place",
172 |         "roles":[
173 |             "Seller",
174 |             "Buyer",
175 |             "Beneficiary",
176 |             "Artifact",
177 |             "Place"],
178 |         "i-label": 17,
179 |         "keywords": ["purchase", "buy", "acquire"]
180 |     },
181 |     "Business:Start-Org": {
182 |         "event_id": "evt018",
183 |         "template": "<arg1> started <arg2> organization at <arg3> place",
184 |         "roles":[
185 |             "Agent",
186 |             "Org",
187 |             "Place"
188 |         ],
189 |         "i-label": 18,
190 |         "keywords":["establish", "launch", "initiate"]
191 |     },
192 |     "Justice:Execute": {
193 |         "event_id": "evt019",
194 |         "template": "<arg1> executed <arg2> at <arg3> place",
195 |         "roles":[
196 |             "Agent",
197 |             "Person",
198 |             "Place"
199 |         ],
200 |         "i-label": 19,
201 |         "keywords":["execute", "execution"]
202 |     },
203 |     "Justice:Trial-Hearing": {
204 |         "event_id": "evt020",
205 |         "template":"<arg1> tried <arg2> before <arg3> court or judge in <arg4> place",
206 |         "roles": [
207 |             "Prosecutor",
208 |             "Defendant",
209 |             "Adjudicator",
210 |             "Place"
211 |         ],
212 |         "i-label": 20,
213 |         "keywords": ["trial", "tried", "hearing", "testify"]
214 |     },
215 |     "Life:Be-Born": {
216 |         "event_id": "evt021",
217 |         "template": "<arg1> was born in <arg2> place",
218 |         "roles":[
219 |             "Person",
220 |             "Place"
221 |         ],
222 |         "i-label": 21,
223 |         "keywords": ["born","birth"]
224 |     },
225 |     "Justice:Charge-Indict": {
226 |         "event_id": "evt022",
227 |         "template": "<arg1> charged or indicted <arg2> before <arg3> court or judge in <arg4> place",
228 |         "roles":[
229 |             "Prosecutor",
230 |             "Defendant",
231 |             "Adjudicator",
232 |             "Place"
233 |         ],
234 |         "i-label": 22,
235 |         "keywords": ["charge", "indict", "accuse","accusation"]
236 |     },
237 |     "Justice:Convict": {
238 |         "event_id": "evt023",
239 |         "template": "<arg1> court or judge convicted <arg2> in <arg3> place",
240 |         "roles":[
241 |             "Adjudicator",
242 |             "Defendant",
243 |             "Place"
244 |         ],
245 |         "i-label": 23,
246 |         "keywords":["convict"]
247 |     },
248 |     "Justice:Sentence": {
249 |         "event_id": "evt024",
250 |         "template": "<arg1> court or judge sentenced <arg2> in <arg3> place",
251 |         "roles":[
252 |             "Adjudicator",
253 |             "Defendant",
254 |             "Place"
255 |         ],
256 |         "i-label": 24,
257 |         "keywords":["sentence"]
258 |     },
259 |     "Business:Declare-Bankruptcy": {
260 |         "event_id": "evt025",
261 |         "template": "<arg1> declared bankruptcy at <arg2> place",
262 |         "roles":[
263 |             "Org",
264 |             "Place"
265 |         ],
266 |         "i-label": 25,
267 |         "keywords": ["bankrupt","bankruptcy","broke"]
268 |     },
269 |     "Justice:Release-Parole": {
270 |         "event_id": "evt026",
271 |         "template": "<arg1> released or paroled <arg2> in <arg3> place",
272 |         "roles":[
273 |             "Entity",
274 |             "Person",
275 |             "Place"
276 |         ],
277 |         "i-label": 26,
278 |         "keywords":["release", "parole", "free"]
279 |     },
280 |     "Justice:Fine": {
281 |         "event_id": "evt027",
282 |         "template": "<arg1> court or judge fined <arg2> at <arg3> place",
283 |         "roles": [
284 |             "Adjudicator",
285 |             "Entity",
286 |             "Place"
287 |         ],
288 |         "i-label": 27,
289 |         "keywords":["fine", "forfeit", "penalize", "fee", "penalty"]
290 |     },
291 |     "Justice:Pardon": {
292 |         "event_id": "evt028",
293 |         "template": "<arg1> court or judge pardoned <arg2> at <arg3> place",
294 |         "roles":[
295 |             "Adjudicator",
296 |             "Defendant",
297 |             "Place"
298 |         ],
299 |         "i-label": 28,
300 |         "keywords": ["pardon","forgiveness","forgive","mercy", "amnesty"]
301 |     },
302 |     "Justice:Appeal": {
303 |         "event_id": "evt029",
304 |         "template": "<arg1> appealed to <arg2> court or judge at <arg3> place",
305 |         "roles":[
306 |             "Plaintiff",
307 |             "Adjudicator",
308 |             "Place"
309 |         ],
310 |         "i-label": 29,
311 |         "keywords": ["appeal", "retrial"]
312 |     },
313 |     "Justice:Extradite": {
314 |         "event_id": "evt030",
315 |         "template": "<arg1> extradited <arg2> from <arg3> place to <arg4> place",
316 |         "roles":[
317 |             "Agent",
318 |             "Person",
319 |             "Origin",
320 |             "Destination"
321 |         ],
322 |         "i-label": 30,
323 |         "keywords": ["extradite", "deport", "expel", "extradition"]
324 |     },
325 |     "Life:Divorce": {
326 |         "event_id": "evt031",
327 |         "template":"<arg1> divorced <arg2> in <arg3> place",
328 |         "roles":[
329 |             "Person",
330 |             "Person",
331 |             "Place"
332 |         ],
333 |         "i-label": 31,
334 |         "keywords": ["divorce", "split", "dissolution"]
335 |     },
336 |     "Business:Merge-Org": {
337 |         "event_id": "evt032",
338 |         "template": "<arg1> organization merged with <arg2> organization",
339 |         "roles":[
340 |             "Org",
341 |             "Org"
342 |         ],
343 |         "i-label": 32,
344 |         "keywords": ["merge", "acquire", "merger"]
345 |     },
346 |     "Justice:Acquit": {
347 |         "event_id": "evt033",
348 |         "template": "<arg1> court or judge acquitted <arg2>",
349 |         "roles":[
350 |             "Adjudicator",
351 |             "Defendant"
352 |         ],
353 |         "i-label": 33,
354 |         "keywords":["acquit", "absolve"]
355 |     }
356 | }
357 | 
--------------------------------------------------------------------------------
/pronoun_list.txt:
--------------------------------------------------------------------------------
1 | all
2 | another
3 | any
4 | anybody
5 | anyone
6 | anything
7 | as
8 | aught
9 | both
10 | each
11 | each other
12 | either
13 | enough
14 | everybody
15 | everyone
16 | everything
17 | few
18 | he
19 | her
20 | hers
21 | herself
22 | him
23 | himself
24 | his
25 | I
26 | idem
27 | it
28 | its
29 | itself
30 | many
31 | me
32 | mine
33 | most
34 | my
35 | myself
36 | naught
37 | neither
38 | no one
39 | nobody
40 | none
41 | nothing
42 | nought
43 | one
44 | one another
45 | other
46 | others
47 | ought
48 | our
49 | ours
50 | ourself
51 | ourselves
52 | several
53 | she
54 | some
55 | somebody
56 | someone
57 | something
58 | somewhat
59 | such
60 | suchlike
61 | that
62 | thee
63 | their
64 | theirs
65 | theirself
66 | theirselves
67 | them
68 | themself
69 | themselves
70 | there
71 | these
72 | they
73 | thine
74 | this
75 | those
76 | thou
77 | thy
78 | thyself
79 | us
80 | we
81 | what
82 | whatever
83 | whatnot
84 | whatsoever
85 | whence
86 | where
87 | whereby
88 | wherefrom
89 | wherein
90 | whereinto
91 | whereof
92 | whereon
93 | wherever
94 | wheresoever
95 | whereto
96 | whereunto
97 | wherewith
98 | wherewithal
99 | whether
100 | which
101 | whichever
102 | whichsoever
103 | who
104 | whoever
105 | whom
106 | whomever
107 | whomso
108 | whomsoever
109 | whose
110 | whosever
111 | whosesoever
112 | whoso
113 | whosoever
114 | ye
115 | yon
116 | yonder
117 | you
118 | your
119 | yours
120 | yourself
121 | yourselves
--------------------------------------------------------------------------------
/scripts/test_ace.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 | set -x
4 | CKPT_PATH=constrained-gen-ACE
5 | MODEL=constrained-gen
6 | DATA_DIR=data/ace/pro_mttrig_id/json
7 | CKPT_NAME=epoch=2-v0.ckpt
8 | 
9 | python train.py --model=$MODEL --ckpt_name=$CKPT_PATH-pred \
10 |     --load_ckpt=checkpoints/$CKPT_PATH/$CKPT_NAME \
11 |     --dataset=ACE \
12 |     --mark_trigger \
13 |     --eval_only \
14 |     --train_file=${DATA_DIR}/train.oneie.json \
15 |     --val_file=${DATA_DIR}/dev.oneie.json \
16 |     --test_file=${DATA_DIR}/test.oneie.json \
17 |     --train_batch_size=4 \
18 |     --eval_batch_size=4 \
19 |     --learning_rate=3e-5 \
20 |     --accumulate_grad_batches=4
21 | 
22 | 
23 | python src/genie/scorer.py --gen-file=checkpoints/$CKPT_PATH-pred/predictions.jsonl --dataset=ACE \
24 |     --test-file=data/ace/pro_mttrig_id/json/test.oneie.json
25 | 
26 | python src/genie/scorer.py --gen-file=checkpoints/$CKPT_PATH-pred/predictions.jsonl --dataset=ACE \
27 |     --test-file=data/ace/pro_mttrig_id/json/test.oneie.json --head-only
--------------------------------------------------------------------------------
/scripts/test_kairos.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 | set -x
4 | CKPT_NAME=gen-KAIROS
5 | MODEL=constrained-gen
6 | 
7 | rm -rf checkpoints/${CKPT_NAME}-pred
8 | python train.py --model=$MODEL --ckpt_name=${CKPT_NAME}-pred \
9 |     --load_ckpt=checkpoints/${CKPT_NAME}/epoch=2-v0.ckpt \
10 |     --dataset=KAIROS \
11 |     --eval_only \
12 |     --mark_trigger \
13 |     --train_file=data/wikievents/train.jsonl \
14 |     --val_file=data/wikievents/dev.jsonl \
15 |     --test_file=data/wikievents/test.jsonl \
16 |     --coref_dir=data/wikievents/coref \
17 |     --train_batch_size=4 \
18 |     --eval_batch_size=4 \
19 |     --learning_rate=3e-5 \
20 |     --accumulate_grad_batches=4 \
21 |     --num_train_epochs=3
22 | 
23 | python src/genie/scorer.py --gen-file=checkpoints/$CKPT_NAME-pred/predictions.jsonl \
24 |     --test-file=data/wikievents/test.jsonl \
25 |     --dataset=KAIROS \
26 |     --coref-file=data/wikievents/coref/test.jsonlines \
27 |     --head-only \
28 |     --coref
--------------------------------------------------------------------------------
/scripts/test_rams.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 | set -x
4 | CKPT_NAME=gen-RAMS
5 | MODEL=gen
6 | 
7 | python train.py --model=$MODEL --ckpt_name=$CKPT_NAME-pred \
8 |     --load_ckpt=checkpoints/$CKPT_NAME/epoch=2-v0.ckpt \
9 |     --dataset=RAMS \
10 |     --eval_only \
11 |     --train_file=data/RAMS_1.0/data/train.jsonlines \
12 |     --val_file=data/RAMS_1.0/data/dev.jsonlines \
13 |     --test_file=data/RAMS_1.0/data/test.jsonlines \
14 |     --train_batch_size=2 \
15 |     --eval_batch_size=4 \
16 |     --learning_rate=3e-5 \
17 |     --accumulate_grad_batches=4 \
18 |     --num_train_epochs=3
19 | 
20 | 
21 | # span eval
22 | python src/genie/convert_gen_to_output.py --gen-file=checkpoints/$CKPT_NAME-pred/predictions.jsonl \
23 |     --output-file=checkpoints/$CKPT_NAME-pred/span_output.jsonl
24 | 
25 | python data/RAMS_1.0/scorer/scorer.py -g=data/RAMS_1.0/data/test.jsonlines -p=checkpoints/$CKPT_NAME-pred/span_output.jsonl \
26 |     --reuse_gold_format --do_all > checkpoints/$CKPT_NAME-pred/span_metrics.txt
27 | 
28 | # head eval
29 | python src/genie/convert_gen_to_output.py --gen-file=checkpoints/$CKPT_NAME-pred/predictions.jsonl \
30 |     --output-file=checkpoints/$CKPT_NAME-pred/output.jsonl --head-only
31 | 
32 | python data/RAMS_1.0/scorer/scorer.py -g=data/RAMS_1.0/data/test_head.jsonlines -p=checkpoints/$CKPT_NAME-pred/output.jsonl \
33 |     --reuse_gold_format --do_all > checkpoints/$CKPT_NAME-pred/head_metrics.txt
34 | 
35 | # head + coref eval
36 | python src/genie/convert_gen_to_output.py --gen-file=checkpoints/$CKPT_NAME-pred/predictions.jsonl \
37 |     --test-file=data/RAMS_1.0/data/test_head_coref.jsonlines \
38 |     --output-file=checkpoints/$CKPT_NAME-pred/coref_output.jsonl --head-only --coref
39 | 
40 | python data/RAMS_1.0/scorer/scorer.py -g=data/RAMS_1.0/data/test_head_coref.jsonlines -p=checkpoints/$CKPT_NAME-pred/coref_output.jsonl \
41 |     --reuse_gold_format --do_all > checkpoints/$CKPT_NAME-pred/coref_metrics.txt
42 | 
43 | 
44 | # visualize
45 | python viz/visualize_output_RAMS.py --result-file=checkpoints/$CKPT_NAME-pred/span_output.jsonl
--------------------------------------------------------------------------------
/scripts/train_ace.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 | set -x
4 | DATA_DIR=data/ace/pro_mttrig_id/json
5 | MODEL=constrained-gen
6 | CKPT_NAME=constrained-gen-ACE
7 | 
8 | 
9 | rm -rf checkpoints/${CKPT_NAME}
10 | python train.py --model=${MODEL} --ckpt_name=${CKPT_NAME} \
11 |     --dataset=ACE \
12 |     --tmp_dir=preprocessed_ACE \
13 |     --train_file=${DATA_DIR}/train.oneie.json \
14 |     --val_file=${DATA_DIR}/dev.oneie.json \
15 |     --test_file=${DATA_DIR}/test.oneie.json \
16 |     --train_batch_size=4 \
17 |     --eval_batch_size=4 \
18 |     --learning_rate=3e-5 \
19 |     --accumulate_grad_batches=4 \
20 |     --num_train_epochs=6 \
21 |     --mark_trigger
22 | 
--------------------------------------------------------------------------------
/scripts/train_kairos.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 | set -x
4 | 
5 | CKPT_NAME='gen-KAIROS'
6 | rm -rf checkpoints/${CKPT_NAME}
7 | 
8 | # does not use informative mentions
9 | python train.py --model=constrained-gen --ckpt_name=${CKPT_NAME} \
10 |     --dataset=KAIROS \
11 |     --train_file=data/wikievents/train.jsonl \
12 |     --val_file=data/wikievents/dev.jsonl \
13 |     --test_file=data/wikievents/test.jsonl \
14 |     --train_batch_size=2 \
15 |     --eval_batch_size=4 \
16 |     --learning_rate=3e-5 \
17 |     --accumulate_grad_batches=8 \
18 |     --num_train_epochs=3 \
19 |     --mark_trigger \
20 |     --coref_dir=data/wikievents/coref
21 | 
--------------------------------------------------------------------------------
/scripts/train_rams.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 | set -x
4 | 
5 | python train.py --model=gen --ckpt_name='gen-RAMS' \
6 |     --dataset=RAMS \
7 |     --train_file=data/RAMS_1.0/data/train.jsonlines \
8 |     --val_file=data/RAMS_1.0/data/dev.jsonlines \
9 |     --test_file=data/RAMS_1.0/data/test.jsonlines \
10 |     --train_batch_size=2 \
11 |     --eval_batch_size=4 \
12 |     --learning_rate=3e-5 \
13 |     --accumulate_grad_batches=4 \
14 |     --num_train_epochs=3 \
15 |     --mark_trigger
16 | 
--------------------------------------------------------------------------------
/src/genie/ACE_data_module.py:
--------------------------------------------------------------------------------
1 | from json import load
2 | import os
3 | import json
4 | import re
5 | from collections import defaultdict
6 | import argparse
7 | 
8 | import transformers
9 | from transformers import BartTokenizer
10 | import torch
11 | from torch.utils.data import DataLoader
12 | import pytorch_lightning as pl
13 | 
14 | from .data import IEDataset, my_collate
15 | from .utils import load_ontology
16 | 
17 | MAX_LENGTH=170
18 | MAX_TGT_LENGTH=72
19 | 
20 | class ACEDataModule(pl.LightningDataModule):
21 |     def __init__(self, args):
22 |         super().__init__()
23 |         self.hparams = args
24 |         self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
25 |         self.tokenizer.add_tokens([' <arg>',' <tgt>'])
26 | 
27 | 
28 |     def create_gold_gen(self, ex, ontology_dict, mark_trigger=True, index=0):
29 |         '''
30 |         If there are multiple events per example, use index parameter.
31 | 
32 |         Input: <s> Template with special <arg> placeholders </s> </s> Passage </s>
33 |         Output: <s> Template with arguments and <arg> when no argument is found.
34 |         '''
35 | 
36 |         evt_type = ex['event_mentions'][index]['event_type']
37 | 
38 |         context_words = ex['tokens']
39 |         template = ontology_dict[evt_type]['template']
40 |         input_template = re.sub(r'<arg\d>', '<arg>', template)
41 | 
42 | 
43 |         space_tokenized_input_template = input_template.split()
44 |         tokenized_input_template = []
45 |         for w in space_tokenized_input_template:
46 |             tokenized_input_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))
47 | 
48 |         role2arg = defaultdict(list)
49 | 
50 |         for argument in ex['event_mentions'][index]['arguments']:
51 |             role2arg[argument['role']].append(argument)
52 | 
53 |         role2arg = dict(role2arg)
54 |         arg_idx2text = defaultdict(list)
55 |         for role in role2arg.keys():
56 |             if role not in ontology_dict[evt_type]:
57 |                 # annotation error
58 |                 continue
59 |             for i, argument in enumerate(role2arg[role]):
60 |                 arg_text = argument['text']
61 |                 if i < len(ontology_dict[evt_type][role]):
62 |                     # enough slots to fill in
63 |                     arg_idx = ontology_dict[evt_type][role][i]
64 | 
65 | 
66 |                 else:
67 |                     # multiple participants for the same role
68 |                     arg_idx = ontology_dict[evt_type][role][-1]
69 | 
70 |                 arg_idx2text[arg_idx].append(arg_text)
71 | 
72 |         for arg_idx, text_list in arg_idx2text.items():
73 |             text = ' and '.join(text_list)
74 |             template = re.sub('<{}>'.format(arg_idx), text, template)
75 | 
76 | 
77 | 
78 |         trigger = ex['event_mentions'][index]['trigger']
79 |         # trigger span does not include last index
80 | 
81 |         if mark_trigger:
82 |             prefix = self.tokenizer.tokenize(' '.join(context_words[:trigger['start']]), add_prefix_space=True)
83 |             tgt = self.tokenizer.tokenize(' '.join(context_words[trigger['start']: trigger['end']]), add_prefix_space=True)
84 | 
85 |             suffix = self.tokenizer.tokenize(' '.join(context_words[trigger['end']:]), add_prefix_space=True)
86 |             context = prefix + [' <tgt>', ] + tgt + [' <tgt>', ] + suffix
87 |         else:
88 |             context = self.tokenizer.tokenize(' '.join(context_words), add_prefix_space=True)
89 | 
90 |         output_template = re.sub(r'<arg\d>', '<arg>', template)
91 |         space_tokenized_template = output_template.split()
92 |         tokenized_template = []
93 |         for w in space_tokenized_template:
94 |             tokenized_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))
95 | 
96 |         return tokenized_input_template, tokenized_template, context
97 | 
98 | 
99 | 
100 | 
101 | 
102 |     def prepare_data(self):
103 |         if self.hparams.tmp_dir:
104 |             data_dir = self.hparams.tmp_dir
105 |         else:
106 |             data_dir = 'preprocessed_{}'.format(self.hparams.dataset)
107 | 
108 |         if not os.path.exists(data_dir):
109 |             print('creating tmp dir ....')
110 |             os.makedirs(data_dir)
111 |             if self.hparams.dataset == 'combined':
112 |                 ontology_dict = load_ontology(dataset='KAIROS')
113 |             else:
114 |                 ontology_dict = load_ontology(dataset=self.hparams.dataset)
115 | 
116 |             for split, f in [('train',self.hparams.train_file), ('val',self.hparams.val_file), ('test',self.hparams.test_file)]:
117 |                 if (split in ['train', 'val']) and not f: # possible for eval_only
118 |                     continue
119 |                 with open(f,'r') as reader, open(os.path.join(data_dir,'{}.jsonl'.format(split)), 'w') as writer:
120 |                     for lidx, line in enumerate(reader):
121 |                         ex = json.loads(line.strip())
122 | 
123 |                         for i in range(len(ex['event_mentions'])):
124 |                             evt_type = ex['event_mentions'][i]['event_type']
125 | 
126 |                             if evt_type not in ontology_dict: # should be a rare event type
127 |                                 print(evt_type)
128 |                                 continue
129 | 
130 |                             input_template, output_template, context = self.create_gold_gen(ex, ontology_dict, self.hparams.mark_trigger, index=i)
131 | 
132 | 
133 |                             input_tokens = self.tokenizer.encode_plus(input_template, context,
134 |                                 add_special_tokens=True,
135 |                                 add_prefix_space=True,
136 |                                 max_length=MAX_LENGTH,
137 |                                 truncation='only_second',
138 |                                 padding='max_length')
139 |                             tgt_tokens = self.tokenizer.encode_plus(output_template,
140 |                                 add_special_tokens=True,
141 |                                 add_prefix_space=True,
142 |                                 max_length=MAX_TGT_LENGTH,
143 |                                 truncation=True,
144 |                                 padding='max_length')
145 | 
146 |                             processed_ex = {
147 |                                 'doc_key': ex['sent_id'], # this is not unique
148 |                                 'input_token_ids': input_tokens['input_ids'],
149 |                                 'input_attn_mask': input_tokens['attention_mask'],
150 |                                 'tgt_token_ids': tgt_tokens['input_ids'],
151 |                                 'tgt_attn_mask': tgt_tokens['attention_mask'],
152 |                             }
153 |                             writer.write(json.dumps(processed_ex) + '\n')
154 | 
155 | 
156 | 
157 | 
158 |     def train_dataloader(self):
159 |         if self.hparams.tmp_dir:
160 |             data_dir = self.hparams.tmp_dir
161 |         else:
162 |             data_dir = 'preprocessed_{}'.format(self.hparams.dataset)
163 | 
164 |         dataset = IEDataset(os.path.join(data_dir, 'train.jsonl'))
165 | 
166 |         dataloader = DataLoader(dataset,
167 |             pin_memory=True, num_workers=2,
168 |             collate_fn=my_collate,
169 |             batch_size=self.hparams.train_batch_size,
170 |             shuffle=True)
171 |         return dataloader
172 | 
173 | 
174 |     def val_dataloader(self):
175 |         if self.hparams.tmp_dir:
176 |             data_dir = self.hparams.tmp_dir
177 |         else:
178 |             data_dir = 'preprocessed_{}'.format(self.hparams.dataset)
179 | 
180 |         dataset = IEDataset(os.path.join(data_dir, 'val.jsonl'))
181 | 
182 | 
183 |         dataloader = DataLoader(dataset, pin_memory=True, num_workers=2,
184 |             collate_fn=my_collate,
185 |             batch_size=self.hparams.eval_batch_size, shuffle=False)
186 |         return dataloader
187 | 
188 |     def test_dataloader(self):
189 |         if self.hparams.tmp_dir:
190 |             data_dir = self.hparams.tmp_dir
191 |         else:
192 |             data_dir = 'preprocessed_{}'.format(self.hparams.dataset)
193 | 
194 |         dataset = IEDataset(os.path.join(data_dir, 'test.jsonl'))
195 | 
196 | 
197 |         dataloader = DataLoader(dataset, pin_memory=True, num_workers=2,
198 |             collate_fn=my_collate,
199 |             batch_size=self.hparams.eval_batch_size, shuffle=False)
200 | 
201 |         return dataloader
202 | 
203 | 
204 | if __name__ == '__main__':
205 |     parser = argparse.ArgumentParser()
206 |     parser.add_argument('--train-file', type=str)
207 |     parser.add_argument('--val-file', type=str)
208 |     parser.add_argument('--test-file', type=str)
209 |     parser.add_argument('--tmp_dir', default='tmp')
210 |     parser.add_argument('--train_batch_size', type=int, default=2)
211 |     parser.add_argument('--eval_batch_size', type=int, default=4)
212 |     parser.add_argument('--mark-trigger', action='store_true', default=True)
213 |     parser.add_argument('--dataset', type=str, default='combined')
214 |     args = parser.parse_args()
215 | 
216 |     dm = ACEDataModule(args=args)
217 |     dm.prepare_data()
218 | 
219 |     # training dataloader
220 |     dataloader = dm.train_dataloader()
221 | 
222 |     for idx, batch in enumerate(dataloader):
223 |         print(batch)
224 |         break
225 | 
226 | 
--------------------------------------------------------------------------------
/src/genie/KAIROS_data_module.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import re
4 | import random
5 | from collections import defaultdict
6 | import argparse
7 | 
8 | import transformers
9 | from transformers import BartTokenizer
10 | import torch
11 | from torch.utils.data import DataLoader
12 | import pytorch_lightning as pl
13 | 
14 | from .data import IEDataset, my_collate
15 | from .utils import load_ontology, check_pronoun, clean_mention
16 | 
17 | MAX_CONTEXT_LENGTH=400 # measured in words
18 | MAX_LENGTH=512
19 | MAX_TGT_LENGTH=70
20 | 
21 | class KAIROSDataModule(pl.LightningDataModule):
22 |     '''
23 |     Dataset processing for KAIROS. Involves chunking for long documents.
24 |     '''
25 |     def __init__(self, args):
26 |         super().__init__()
27 |         self.hparams = args
28 |         self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
29 |         self.tokenizer.add_tokens([' <arg>',' <tgt>'])
30 | 
31 | 
32 |     def create_gold_gen(self, ex, ontology_dict, mark_trigger=True, index=0, ent2info=None, use_info=False):
33 |         '''
34 |         If there are multiple events per example, use index parameter.
35 | 
36 |         Input: <s> Template with special <arg> placeholders </s> </s> Passage </s>
37 |         Output: <s> Template with arguments and <arg> when no argument is found.
38 |         '''
39 |         if use_info and ent2info == None:
40 |             raise ValueError('entity to informative mention mapping required.')
41 | 
42 |         evt_type = ex['event_mentions'][index]['event_type']
43 | 
44 | 
45 |         template = ontology_dict[evt_type]['template']
46 |         input_template = re.sub(r'<arg\d>', '<arg>', template)
47 | 
48 | 
49 |         space_tokenized_input_template = input_template.split()
50 |         tokenized_input_template = []
51 |         for w in space_tokenized_input_template:
52 |             tokenized_input_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))
53 | 
54 |         role2arg = defaultdict(list)
55 | 
56 |         for argument in ex['event_mentions'][index]['arguments']:
57 |             role2arg[argument['role']].append(argument)
58 | 
59 |         role2arg = dict(role2arg)
60 | 
61 |         # create output template
62 |         arg_idx2text = defaultdict(list)
63 |         for role in role2arg.keys():
64 |             if role not in ontology_dict[evt_type]:
65 |                 # annotation error
66 |                 continue
67 |             for i, argument in enumerate(role2arg[role]):
68 |                 use_arg = True
69 |                 if use_info:
70 |                     ent_id = argument['entity_id']
71 |                     if ent_id in ent2info:
72 |                         arg_text = clean_mention(ent2info[ent_id])
73 |                         if check_pronoun(arg_text):
74 |                             # skipping this argument
75 |                             use_arg = False
76 |                         # if arg_text != argument['text']:
77 |                         #     print('Original mention:{}, Informative mention:{}'.format(argument['text'], arg_text))
78 |                     else:
79 |                         arg_text = argument['text']
80 |                 else:
81 |                     arg_text = argument['text']
82 | 
83 |                 # assign the argument index
84 |                 if i < len(ontology_dict[evt_type][role]):
85 |                     # enough slots to fill in
86 |                     arg_idx = ontology_dict[evt_type][role][i]
87 | 
88 |                 else:
89 |                     # multiple participants for the same role
90 |                     arg_idx = ontology_dict[evt_type][role][-1]
91 | 
92 |                 if use_arg:
93 |                     arg_idx2text[arg_idx].append(arg_text)
94 | 
95 |         for arg_idx, text_list in arg_idx2text.items():
96 |             text = ' and '.join(text_list)
97 |             template = re.sub('<{}>'.format(arg_idx), text, template)
98 | 
99 | 
100 | 
101 |         trigger = ex['event_mentions'][index]['trigger']
102 |         offset = 0
103 |         # trigger span does not include last index
104 |         context_words = ex['tokens']
105 |         center_sent = trigger['sent_idx']
106 |         if len(context_words) > MAX_CONTEXT_LENGTH:
107 |             cur_len = len(ex['sentences'][center_sent][0])
108 |             context_words = [tup[0] for tup in ex['sentences'][center_sent][0]]
109 |             if cur_len > MAX_CONTEXT_LENGTH:
110 |                 # one sentence is very long
111 |                 trigger_start = trigger['start']
112 |                 start_idx = max(0, trigger_start - MAX_CONTEXT_LENGTH // 2)
113 |                 end_idx = min(len(context_words), trigger_start + MAX_CONTEXT_LENGTH // 2)
114 |                 context_words = context_words[start_idx: end_idx]
115 |                 offset = start_idx
116 | 
117 |             else:
118 |                 # take a sliding window
119 |                 left = center_sent - 1
120 |                 right = center_sent + 1
121 | 
122 |                 total_sents = len(ex['sentences'])
123 |                 prev_len = 0
124 |                 while cur_len > prev_len:
125 |                     prev_len = cur_len
126 |                     # try expanding the sliding window
127 |                     if left >= 0:
128 |                         left_sent_tokens = [tup[0] for tup in ex['sentences'][left][0]]
129 |                         if cur_len + len(left_sent_tokens) <= MAX_CONTEXT_LENGTH:
130 |                             context_words = left_sent_tokens + context_words
131 |                             left -= 1
132 |                             cur_len += len(left_sent_tokens)
133 | 
134 |                     if right < total_sents:
135 |                         right_sent_tokens = [tup[0] for tup in ex['sentences'][right][0]]
136 |                         if cur_len + len(right_sent_tokens) <= MAX_CONTEXT_LENGTH:
137 |                             context_words = context_words + right_sent_tokens
138 |                             right += 1
139 |                             cur_len += len(right_sent_tokens)
140 |                 # update trigger offset
141 |                 offset = sum([len(ex['sentences'][idx][0]) for idx in range(left + 1)])
142 | 
143 | 
144 |         assert(len(context_words) <= MAX_CONTEXT_LENGTH)
145 | 
146 |         trigger['start'] = trigger['start'] - offset
147 |         trigger['end'] = trigger['end'] - offset
148 |         if mark_trigger:
149 |             prefix = self.tokenizer.tokenize(' '.join(context_words[:trigger['start']]), add_prefix_space=True)
150 |             tgt = self.tokenizer.tokenize(' '.join(context_words[trigger['start']: trigger['end']]), add_prefix_space=True)
151 | 
152 |             suffix = self.tokenizer.tokenize(' '.join(context_words[trigger['end']:]), add_prefix_space=True)
153 |             context = prefix + [' <tgr>', ] + tgt + [' <tgr>', ] + suffix
154 |         else:
155 |             context = self.tokenizer.tokenize(' '.join(context_words), add_prefix_space=True)
156 | 
157 |         output_template = re.sub(r'<arg\d>', '<arg>', template)
158 |         space_tokenized_template = output_template.split()
159 |         tokenized_template = []
160 |         for w in space_tokenized_template:
161 |             tokenized_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))
162 | 
163 |         return tokenized_input_template, tokenized_template, context
164 | 
165 | 
166 | 
167 | 
168 | 
169 |     def prepare_data(self):
170 |         data_dir = 'preprocessed_{}'.format(self.hparams.dataset)
171 |         if not os.path.exists(data_dir):
172 |             os.makedirs(data_dir)
173 |         ontology_dict = load_ontology(self.hparams.dataset)
174 |         max_tokens = 0
175 |         max_tgt = 0
176 | 
177 | 
178 | 
179 |         for split, f in [('train', self.hparams.train_file), ('val', self.hparams.val_file), ('test', self.hparams.test_file)]:
180 |             coref_split = 'dev' if split == 'val' else split
181 |             coref_reader = open(os.path.join(self.hparams.coref_dir, '{}.jsonlines'.format(coref_split)))
182 |             with open(f, 'r') as reader, open(os.path.join(data_dir, '{}.jsonl'.format(split)), 'w') as writer:
183 |                 for line, coref_line in zip(reader, coref_reader):
184 |                     ex = json.loads(line.strip())
185 |                     corefs = json.loads(coref_line.strip())
186 |                     assert(ex['doc_id'] == corefs['doc_key'])
187 |                     # mapping from entity id to informative mention
188 |                     ent2info = {}
189 |                     for cidx, cluster in enumerate(corefs['clusters']):
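                        # each coreference cluster carries a single informative mention;
                        # mapping every entity id in the cluster to it lets create_gold_gen()
                        # replace pronouns and other short mentions with a more descriptive
                        # string when use_info is set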
190 | for eid in cluster: 191 | ent2info[eid] = corefs['informative_mentions'][cidx] 192 | 193 | 194 | for i in range(len(ex['event_mentions'])): 195 | if split=='train' and len(ex['event_mentions'][i]['arguments']) ==0: 196 | # skip mentions with no arguments 197 | continue 198 | evt_type = ex['event_mentions'][i]['event_type'] 199 | 200 | if evt_type not in ontology_dict: # should be a rare event type 201 | continue 202 | 203 | input_template, output_template, context= self.create_gold_gen(ex, ontology_dict, self.hparams.mark_trigger, 204 | index=i, ent2info=ent2info, use_info=self.hparams.use_info) 205 | 206 | 207 | max_tokens = max(len(context) + len(input_template) +2, max_tokens) 208 | # print(len(context) + len(input_template) +2 ) 209 | max_tgt = max(len(output_template) +1 , max_tgt) 210 | assert(len(output_template) < MAX_TGT_LENGTH) 211 | input_tokens = self.tokenizer.encode_plus(input_template, context, 212 | add_special_tokens=True, 213 | add_prefix_space=True, 214 | max_length=MAX_LENGTH, 215 | truncation='only_second', 216 | padding='max_length') 217 | tgt_tokens = self.tokenizer.encode_plus(output_template, 218 | add_special_tokens=True, 219 | add_prefix_space=True, 220 | max_length=MAX_TGT_LENGTH, 221 | truncation=True, 222 | padding='max_length') 223 | 224 | processed_ex = { 225 | 'event_idx': i, 226 | 'doc_key': ex['doc_id'], 227 | 'input_token_ids':input_tokens['input_ids'], 228 | 'input_attn_mask': input_tokens['attention_mask'], 229 | 'tgt_token_ids': tgt_tokens['input_ids'], 230 | 'tgt_attn_mask': tgt_tokens['attention_mask'], 231 | } 232 | writer.write(json.dumps(processed_ex) + '\n') 233 | 234 | 235 | print('longest context:{}'.format(max_tokens)) 236 | print('longest target {}'.format(max_tgt)) 237 | 238 | def train_dataloader(self): 239 | dataset = IEDataset('preprocessed_{}/train.jsonl'.format(self.hparams.dataset)) 240 | 241 | dataloader = DataLoader(dataset, 242 | pin_memory=True, num_workers=2, 243 | collate_fn=my_collate, 244 | batch_size=self.hparams.train_batch_size, 245 | shuffle=True) 246 | return dataloader 247 | 248 | 249 | def val_dataloader(self): 250 | dataset = IEDataset('preprocessed_{}/val.jsonl'.format(self.hparams.dataset)) 251 | 252 | dataloader = DataLoader(dataset, pin_memory=True, num_workers=2, 253 | collate_fn=my_collate, 254 | batch_size=self.hparams.eval_batch_size, shuffle=False) 255 | return dataloader 256 | 257 | def test_dataloader(self): 258 | dataset = IEDataset('preprocessed_{}/test.jsonl'.format(self.hparams.dataset)) 259 | 260 | dataloader = DataLoader(dataset, pin_memory=True, num_workers=2, 261 | collate_fn=my_collate, 262 | batch_size=self.hparams.eval_batch_size, shuffle=False) 263 | 264 | return dataloader 265 | 266 | 267 | if __name__ == '__main__': 268 | parser = argparse.ArgumentParser() 269 | parser.add_argument('--train-file',type=str,default='data/kairos/train.jsonl') 270 | parser.add_argument('--val-file', type=str, default='data/kairos/dev.jsonl') 271 | parser.add_argument('--test-file', type=str, default='data/kairos/test.jsonl') 272 | parser.add_argument('--coref-dir', type=str, default='data/kairos/coref') 273 | parser.add_argument('--use_info', action='store_true', default=True, help='use informative mentions instead of the nearest mention.') 274 | parser.add_argument('--train_batch_size', type=int, default=2) 275 | parser.add_argument('--eval_batch_size', type=int, default=4) 276 | parser.add_argument('--dataset', type=str, default='KAIROS') 277 | parser.add_argument('--mark-trigger', action='store_true', 
default=True) 278 | args = parser.parse_args() 279 | 280 | dm = KAIROSDataModule(args=args) 281 | dm.prepare_data() 282 | 283 | # training dataloader 284 | dataloader = dm.train_dataloader() 285 | 286 | for idx, batch in enumerate(dataloader): 287 | print(batch) 288 | break 289 | 290 | # val dataloader -------------------------------------------------------------------------------- /src/genie/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raspberryice/gen-arg/253e0889b2377e0f7084cb406cf5d4142ee8a365/src/genie/__init__.py -------------------------------------------------------------------------------- /src/genie/constrained_gen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import re 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from transformers import ( 6 | BartModel, 7 | ) 8 | from transformers.modeling_utils import PreTrainedModel 9 | from transformers.modeling_outputs import Seq2SeqLMOutput 10 | from transformers.generation_utils import top_k_top_p_filtering 11 | from typing import Iterable, List, Optional 12 | from transformers.file_utils import ModelOutput 13 | 14 | 15 | class BartConstrainedGen(PreTrainedModel): 16 | def __init__(self, config, tokenizer): 17 | super(BartConstrainedGen, self).__init__(config) 18 | self.config = config 19 | self.tokenizer = tokenizer 20 | self.transformer = BartModel.from_pretrained('facebook/bart-large') 21 | self.register_buffer("final_logits_bias", torch.zeros((1, self.transformer.shared.num_embeddings))) 22 | 23 | 24 | def resize_token_embeddings(self): 25 | old_num_tokens = self.transformer.shared.num_embeddings 26 | new_embeddings = self.transformer.resize_token_embeddings(len(self.tokenizer)) 27 | self.transformer.shared = new_embeddings 28 | self._resize_final_logits_bias(len(self.tokenizer), old_num_tokens) 29 | self.vocab_size = len(self.tokenizer) 30 | 31 | return new_embeddings 32 | 33 | def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None: 34 | if new_num_tokens <= old_num_tokens: 35 | new_bias = self.final_logits_bias[:, :new_num_tokens] 36 | else: 37 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 38 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 39 | self.register_buffer("final_logits_bias", new_bias) 40 | 41 | 42 | def _init_weights(self, module): 43 | """ Initialize the weights """ 44 | if isinstance(module, (nn.Linear, nn.Embedding)): 45 | # Slightly different from the TF version which uses truncated_normal for initialization 46 | # cf https://github.com/pytorch/pytorch/pull/5617 47 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 48 | elif isinstance(module, torch.nn.LayerNorm): # if use apex, this should be FusedLayerNorm 49 | module.bias.data.zero_() 50 | module.weight.data.fill_(1.0) 51 | if isinstance(module, nn.Linear) and module.bias is not None: 52 | module.bias.data.zero_() 53 | 54 | 55 | def get_encoder(self): 56 | return self.transformer.encoder 57 | 58 | 59 | def get_output_embeddings(self): 60 | # this method is needed for generation 61 | vocab_size, emb_size = self.transformer.shared.weight.shape 62 | lin_layer = nn.Linear(vocab_size, emb_size, bias=False) 63 | lin_layer.weight.data = self.transformer.shared.weight.data 64 | return lin_layer 65 | 66 | 67 | def prepare_inputs_for_generation( 68 | self, decoder_input_ids, past, 
attention_mask, use_cache, encoder_outputs, input_embeds, encoder_input_ids, **kwargs): 69 | return { 70 | "input_ids": encoder_input_ids, # encoder_outputs is defined. input_ids not needed 71 | "encoder_outputs": encoder_outputs, 72 | "past_key_values": past, 73 | "decoder_input_ids": decoder_input_ids, 74 | "attention_mask": attention_mask, 75 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 76 | "input_embeds": input_embeds, 77 | } 78 | 79 | def adjust_logits_during_generation(self, logits, cur_len, max_length): 80 | if cur_len == 1 and self.config.force_bos_token_to_be_generated: 81 | self._force_token_ids_generation(logits, self.config.bos_token_id) 82 | elif cur_len == max_length - 1 and self.config.eos_token_id is not None: 83 | self._force_token_ids_generation(logits, self.config.eos_token_id) 84 | return logits 85 | 86 | def _force_token_ids_generation(self, scores, token_id) -> None: 87 | """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" 88 | scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf") 89 | 90 | @staticmethod 91 | def _reorder_cache(past, beam_idx): 92 | reordered_past = [] 93 | for layer_past in past: 94 | # get the correct batch idx from decoder layer's batch dim for cross and self-attn 95 | layer_past_new = { 96 | attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() 97 | } 98 | reordered_past.append(layer_past_new) 99 | return reordered_past 100 | 101 | 102 | def convert_pointer_logits_to_lm_logits(self, pointer_logits, input_ids): 103 | ''' 104 | pointer_logits: (batch, seq_len, input_seq_len) 105 | input_ids: (batch, input_seq_len) 106 | lm_logits: (batch, seq_len, vocab_size) 107 | ''' 108 | batch_size = pointer_logits.size(0) 109 | seq_len = pointer_logits.size(1) 110 | input_seq_len = input_ids.size(1) 111 | lm_logits = torch.full((batch_size, seq_len, self.vocab_size), fill_value=-1000,dtype=pointer_logits.dtype).to(pointer_logits.device) 112 | 113 | 114 | # scatter may be technically incorrect for duplicate indexes, but not using it gets slow 115 | index = input_ids.unsqueeze(dim=1).expand_as(pointer_logits) 116 | lm_logits.scatter_(dim=2, index=index, src=pointer_logits) 117 | 118 | 119 | 120 | return lm_logits 121 | 122 | def remove_unseen(self, lm_logits, input_ids): 123 | # input_ids (batch, seq) 124 | seen_lm_logits = torch.full_like(lm_logits, fill_value=-1000).to(lm_logits.device) #(batch, seq, vocab) 125 | seen_vocab = set(input_ids.reshape(-1).tolist()) 126 | for i in range(self.transformer.vocab_size): 127 | if i in (seen_vocab): 128 | seen_lm_logits[:, :, i] = lm_logits[:, :, i] 129 | elif i == self.tokenizer.encode(' and', add_special_tokens=False)[0]: 130 | seen_lm_logits[:, :, i] = lm_logits[:, :, i] 131 | return seen_lm_logits 132 | 133 | 134 | def forward(self, input_ids, 135 | attention_mask=None, 136 | encoder_outputs=None, 137 | use_cache=False, 138 | past_key_values=None, 139 | decoder_input_ids=None, 140 | decoder_attention_mask=None, 141 | output_attentions=None, 142 | output_hidden_states=None, 143 | return_dict=None, 144 | input_embeds=None, 145 | task=-1): 146 | 147 | # generation 148 | if task==-1: 149 | outputs = self.transformer( 150 | input_ids=input_ids, 151 | attention_mask=attention_mask, 152 | decoder_input_ids=decoder_input_ids, 153 | decoder_attention_mask=decoder_attention_mask, 154 | use_cache=use_cache, 155 | encoder_outputs=encoder_outputs, 156 | 
past_key_values=past_key_values, 157 | output_attentions=output_attentions, 158 | output_hidden_states=output_hidden_states, 159 | input_embeds = input_embeds, 160 | return_dict=return_dict,) 161 | 162 | 163 | decoder_output = outputs[0] #(batch, seq_len, hidden_dim) 164 | if encoder_outputs==None: 165 | encoder_outputs = outputs[1] # (batch, input_seq_len, hidden_dim) 166 | # BaseModelOutput if return dict 167 | 168 | if input_embeds==None: 169 | # get encoder side embeddings 170 | input_embeds = self.transformer.encoder.embed_tokens(input_ids) * self.transformer.encoder.embed_scale #(batch, seq_len, input_seq_len) 171 | pointer_logits = torch.einsum('ijk,ilk->ijl', decoder_output, input_embeds) #(batch, seq_len, input_seq_len) 172 | lm_logits = self.convert_pointer_logits_to_lm_logits(pointer_logits, input_ids) 173 | 174 | 175 | masked_lm_loss = None 176 | 177 | if not return_dict: 178 | output = (lm_logits,) + outputs[1:] 179 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 180 | 181 | return Seq2SeqLMOutput( 182 | loss=masked_lm_loss, 183 | logits=lm_logits, 184 | past_key_values=outputs.past_key_values, 185 | decoder_hidden_states=outputs.decoder_hidden_states, 186 | decoder_attentions=outputs.decoder_attentions, 187 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 188 | encoder_hidden_states=outputs.encoder_hidden_states, 189 | encoder_attentions=outputs.encoder_attentions, 190 | ) 191 | 192 | #training 193 | elif task==0: 194 | 195 | assert(decoder_input_ids!=None) 196 | y_ids = decoder_input_ids[:, :-1] 197 | labels = decoder_input_ids[:, 1:].clone() 198 | labels[labels== self.tokenizer.pad_token_id] = -100 199 | # labels are just decoder_input_ids shifted to the right by 1 200 | 201 | outputs = self.transformer( 202 | input_ids, 203 | attention_mask=attention_mask, 204 | decoder_input_ids=y_ids, 205 | decoder_attention_mask=decoder_attention_mask[:, :-1], 206 | use_cache=False, 207 | past_key_values=past_key_values, 208 | output_attentions=output_attentions, 209 | output_hidden_states=output_hidden_states, 210 | return_dict=return_dict,) 211 | 212 | decoder_output = outputs[0] #(batch, seq_len, hidden_dim) 213 | encoder_output = outputs[1] # (batch, input_seq_len, hidden_dim) 214 | # lm_logits = F.linear(decoder_output, self.transformer.shared.weight, bias=self.final_logits_bias) 215 | # lm_logits = self.remove_unseen(lm_logits, input_ids) 216 | # get encoder side embeddings 217 | input_embeds = self.transformer.encoder.embed_tokens(input_ids) * self.transformer.encoder.embed_scale #(batch, seq_len, input_seq_len) 218 | 219 | pointer_logits = torch.einsum('ijk,ilk->ijl', decoder_output, input_embeds) #(batch, seq_len, input_seq_len) 220 | # decrease prob if neccesary 221 | 222 | lm_logits = self.convert_pointer_logits_to_lm_logits(pointer_logits, input_ids) 223 | 224 | outputs = (lm_logits,) + outputs[1:] # Add cache, hidden states and attention if they are here 225 | loss_fct = nn.CrossEntropyLoss() 226 | 227 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.vocab_size), labels.view(-1)) 228 | outputs = (masked_lm_loss,) + outputs 229 | 230 | return outputs 231 | 232 | 233 | 234 | 235 | # # this is a simplified generate class for the pointer generator taken from https://github.com/huggingface/transformers/blob/v3.1.0/src/transformers/generation_utils.py 236 | @torch.no_grad() 237 | def generate( 238 | self, 239 | input_ids: Optional[torch.LongTensor] = None, 240 | max_length: Optional[int] = None, 241 | min_length: Optional[int] = 
None, 242 | do_sample: Optional[bool] = None, 243 | early_stopping: Optional[bool] = None, 244 | num_beams: Optional[int] = None, 245 | temperature: Optional[float] = None, 246 | top_k: Optional[int] = None, 247 | top_p: Optional[float] = None, 248 | repetition_penalty: Optional[float] = None, 249 | bad_words_ids: Optional[Iterable[int]] = None, 250 | bos_token_id: Optional[int] = None, 251 | pad_token_id: Optional[int] = None, 252 | eos_token_id: Optional[int] = None, 253 | length_penalty: Optional[float] = None, 254 | no_repeat_ngram_size: Optional[int] = None, 255 | num_return_sequences: Optional[int] = None, 256 | attention_mask: Optional[torch.LongTensor] = None, 257 | decoder_start_token_id: Optional[int] = None, 258 | use_cache: Optional[bool] = None, 259 | **model_kwargs 260 | ) -> torch.LongTensor: 261 | r""" 262 | Generates sequences for models with a language modeling head. The method currently supports greedy decoding, 263 | beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. 264 | Adapted in part from `Facebook's XLM beam search code 265 | `__. 266 | Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the 267 | attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values 268 | indicated are the default values of those config. 269 | Most of these parameters are explained in more detail in `this blog post 270 | `__. 271 | Parameters: 272 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 273 | The sequence used as a prompt for the generation. If :obj:`None` the method initializes 274 | it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`. 275 | max_length (:obj:`int`, `optional`, defaults to 20): 276 | The maximum length of the sequence to be generated. 277 | min_length (:obj:`int`, `optional`, defaults to 10): 278 | The minimum length of the sequence to be generated. 279 | do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`): 280 | Whether or not to use sampling ; use greedy decoding otherwise. 281 | early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): 282 | Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. 283 | num_beams (:obj:`int`, `optional`, defaults to 1): 284 | Number of beams for beam search. 1 means no beam search. 285 | temperature (:obj:`float`, `optional`, defaults tp 1.0): 286 | The value used to module the next token probabilities. 287 | top_k (:obj:`int`, `optional`, defaults to 50): 288 | The number of highest probability vocabulary tokens to keep for top-k-filtering. 289 | top_p (:obj:`float`, `optional`, defaults to 1.0): 290 | If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or 291 | higher are kept for generation. 292 | repetition_penalty (:obj:`float`, `optional`, defaults to 1.0): 293 | The parameter for repetition penalty. 1.0 means no penalty. See `this paper 294 | `__ for more details. 295 | pad_token_id (:obj:`int`, `optional`): 296 | The id of the `padding` token. 297 | bos_token_id (:obj:`int`, `optional`): 298 | The id of the `beginning-of-sequence` token. 299 | eos_token_id (:obj:`int`, `optional`): 300 | The id of the `end-of-sequence` token. 301 | length_penalty (:obj:`float`, `optional`, defaults to 1.0): 302 | Exponential penalty to the length. 1.0 means no penalty. 
303 | Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in 304 | order to encourage the model to produce longer sequences. 305 | no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): 306 | If set to int > 0, all ngrams of that size can only occur once. 307 | bad_words_ids(:obj:`List[int]`, `optional`): 308 | List of token ids that are not allowed to be generated. In order to get the tokens of the words that 309 | should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`. 310 | num_return_sequences(:obj:`int`, `optional`, defaults to 1): 311 | The number of independently computed returned sequences for each element in the batch. 312 | attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 313 | Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for 314 | tokens that are not masked, and 0 for masked tokens. 315 | If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token. 316 | `What are attention masks? <../glossary.html#attention-mask>`__ 317 | decoder_start_token_id (:obj:`int`, `optional`): 318 | If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token. 319 | use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`): 320 | Whether or not the model should use the past last key/values attentions (if applicable to the model) to 321 | speed up decoding. 322 | model_kwargs: 323 | Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. 324 | Return: 325 | :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: 326 | The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or 327 | shorter if all batches finished early due to the :obj:`eos_token_id`. 328 | Examples:: 329 | tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer 330 | model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. 331 | outputs = model.generate(max_length=40) # do greedy decoding 332 | print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) 333 | tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer 334 | model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache. 335 | input_context = 'The dog' 336 | input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context 337 | outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog' 338 | for i in range(3): # 3 output sequences were generated 339 | print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) 340 | tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer 341 | model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. 
342 | input_context = 'The dog' 343 | input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context 344 | outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling 345 | for i in range(3): # 3 output sequences were generated 346 | print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) 347 | tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer 348 | model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache. 349 | input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl 350 | input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context 351 | outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences 352 | print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) 353 | tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer 354 | model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache. 355 | input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl 356 | bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']] 357 | input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context 358 | outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated 359 | """ 360 | 361 | max_length = max_length if max_length is not None else self.config.max_length 362 | min_length = min_length if min_length is not None else self.config.min_length 363 | do_sample = do_sample if do_sample is not None else self.config.do_sample 364 | early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping 365 | use_cache = use_cache if use_cache is not None else self.config.use_cache 366 | num_beams = num_beams if num_beams is not None else self.config.num_beams 367 | temperature = temperature if temperature is not None else self.config.temperature 368 | top_k = top_k if top_k is not None else self.config.top_k 369 | top_p = top_p if top_p is not None else self.config.top_p 370 | repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty 371 | bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id 372 | pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id 373 | eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id 374 | length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty 375 | no_repeat_ngram_size = ( 376 | no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size 377 | ) 378 | bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids 379 | num_return_sequences = ( 380 | num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences 381 | ) 382 | decoder_start_token_id = ( 383 | decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id 384 | ) 385 | 386 | if input_ids is not None: 387 | batch_size = input_ids.shape[0] # 
overriden by the input batch_size 388 | else: 389 | batch_size = 1 390 | 391 | assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer." 392 | assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." 393 | assert isinstance(do_sample, bool), "`do_sample` should be a boolean." 394 | assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." 395 | assert isinstance(use_cache, bool), "`use_cache` should be a boolean." 396 | assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." 397 | assert temperature > 0, "`temperature` should be strictly positive." 398 | assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." 399 | assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." 400 | assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." 401 | assert input_ids is not None or ( 402 | isinstance(bos_token_id, int) and bos_token_id >= 0 403 | ), "If input_ids is not defined, `bos_token_id` should be a positive integer." 404 | assert pad_token_id is None or ( 405 | isinstance(pad_token_id, int) and (pad_token_id >= 0) 406 | ), "`pad_token_id` should be a positive integer." 407 | assert (eos_token_id is None) or ( 408 | isinstance(eos_token_id, int) and (eos_token_id >= 0) 409 | ), "`eos_token_id` should be a positive integer." 410 | assert length_penalty > 0, "`length_penalty` should be strictly positive." 411 | assert ( 412 | isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0 413 | ), "`no_repeat_ngram_size` should be a positive integer." 414 | assert ( 415 | isinstance(num_return_sequences, int) and num_return_sequences > 0 416 | ), "`num_return_sequences` should be a strictly positive integer." 417 | assert ( 418 | bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list) 419 | ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" 420 | 421 | if input_ids is None: 422 | assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( 423 | "you should either supply a context to complete as `input_ids` input " 424 | "or a `bos_token_id` (integer >= 0) as a first token to start the generation." 425 | ) 426 | input_ids = torch.full( 427 | (batch_size, 1), 428 | bos_token_id, 429 | dtype=torch.long, 430 | device=next(self.parameters()).device, 431 | ) 432 | else: 433 | assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." 434 | 435 | # not allow to duplicate outputs when greedy decoding 436 | if do_sample is False: 437 | if num_beams == 1: 438 | # no_beam_search greedy generation conditions 439 | assert ( 440 | num_return_sequences == 1 441 | ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1" 442 | 443 | else: 444 | # beam_search greedy generation conditions 445 | assert ( 446 | num_beams >= num_return_sequences 447 | ), "Greedy beam search decoding cannot return more sequences than it has beams. 
Please set num_beams >= num_return_sequences" 448 | 449 | # create attention mask if necessary 450 | # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 451 | if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids): 452 | attention_mask = input_ids.ne(pad_token_id).long() 453 | elif attention_mask is None: 454 | attention_mask = input_ids.new_ones(input_ids.shape) 455 | 456 | # set pad_token_id to eos_token_id if not set. Important that this is done after 457 | # attention_mask is created 458 | if pad_token_id is None and eos_token_id is not None: 459 | pad_token_id = eos_token_id 460 | 461 | # current position and vocab size 462 | if hasattr(self.config, "vocab_size"): 463 | vocab_size = self.config.vocab_size 464 | elif ( 465 | self.config.is_encoder_decoder 466 | and hasattr(self.config, "decoder") 467 | and hasattr(self.config.decoder, "vocab_size") 468 | ): 469 | vocab_size = self.config.decoder.vocab_size 470 | 471 | # set effective batch size and effective batch multiplier according to do_sample 472 | if do_sample: 473 | effective_batch_size = batch_size * num_return_sequences 474 | effective_batch_mult = num_return_sequences 475 | else: 476 | effective_batch_size = batch_size 477 | effective_batch_mult = 1 478 | 479 | if self.config.is_encoder_decoder: 480 | if decoder_start_token_id is None: 481 | # see if BOS token can be used for decoder_start_token_id 482 | if bos_token_id is not None: 483 | decoder_start_token_id = bos_token_id 484 | elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"): 485 | decoder_start_token_id = self.config.decoder.bos_token_id 486 | else: 487 | raise ValueError( 488 | "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" 489 | ) 490 | 491 | assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) 492 | assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) 493 | 494 | # get encoder and store encoder outputs 495 | encoder = self.get_encoder() 496 | encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True) 497 | input_embeds = encoder.embed_tokens(input_ids) * encoder.embed_scale 498 | 499 | # Expand input ids if num_beams > 1 or num_return_sequences > 1 500 | if num_return_sequences > 1 or num_beams > 1: 501 | input_ids_len = input_ids.shape[-1] 502 | input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) 503 | attention_mask = attention_mask.unsqueeze(1).expand( 504 | batch_size, effective_batch_mult * num_beams, input_ids_len 505 | ) 506 | 507 | input_ids = input_ids.contiguous().view( 508 | effective_batch_size * num_beams, input_ids_len 509 | ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) 510 | attention_mask = attention_mask.contiguous().view( 511 | effective_batch_size * num_beams, input_ids_len 512 | ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) 513 | 514 | encoder_input_ids = input_ids 515 | 516 | if self.config.is_encoder_decoder: 517 | # create empty decoder_input_ids 518 | input_ids = torch.full( 519 | (effective_batch_size * num_beams, 1), 520 | decoder_start_token_id, 521 | dtype=torch.long, 522 | device=next(self.parameters()).device, 523 | ) 524 | cur_len = 1 525 | 526 | assert ( 527 | batch_size == encoder_outputs.last_hidden_state.shape[0] 528 | ), f"expected 
encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} " 529 | 530 | # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) 531 | expanded_batch_idxs = ( 532 | torch.arange(batch_size) 533 | .view(-1, 1) 534 | .repeat(1, num_beams * effective_batch_mult) 535 | .view(-1) 536 | .to(input_ids.device) 537 | ) 538 | 539 | # expand encoder_outputs 540 | encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( 541 | 0, expanded_batch_idxs 542 | ) 543 | 544 | # save encoder_outputs in `model_kwargs` 545 | model_kwargs["encoder_outputs"] = encoder_outputs 546 | model_kwargs["input_embeds"] = input_embeds 547 | model_kwargs["encoder_input_ids"] = encoder_input_ids 548 | 549 | else: 550 | cur_len = input_ids.shape[-1] 551 | 552 | assert ( 553 | cur_len < max_length 554 | ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`" 555 | 556 | output = self._generate_no_beam_search( 557 | input_ids, 558 | cur_len=cur_len, 559 | max_length=max_length, 560 | min_length=min_length, 561 | do_sample=do_sample, 562 | temperature=temperature, 563 | top_k=top_k, 564 | top_p=top_p, 565 | repetition_penalty=repetition_penalty, 566 | no_repeat_ngram_size=no_repeat_ngram_size, 567 | bad_words_ids=bad_words_ids, 568 | pad_token_id=pad_token_id, 569 | eos_token_id=eos_token_id, 570 | batch_size=effective_batch_size, 571 | attention_mask=attention_mask, 572 | use_cache=use_cache, 573 | model_kwargs=model_kwargs, 574 | ) 575 | 576 | return output 577 | 578 | def _generate_no_beam_search( 579 | self, 580 | input_ids, 581 | cur_len, 582 | max_length, 583 | min_length, 584 | do_sample, 585 | temperature, 586 | top_k, 587 | top_p, 588 | repetition_penalty, 589 | no_repeat_ngram_size, 590 | bad_words_ids, 591 | pad_token_id, 592 | eos_token_id, 593 | batch_size, 594 | attention_mask, 595 | use_cache, 596 | model_kwargs, 597 | ): 598 | """Generate sequences for each example without beam search (num_beams == 1). 599 | All returned sequence are generated independantly. 
600 |         """
601 |         # length of generated sentences / unfinished sentences
602 |         unfinished_sents = input_ids.new(batch_size).fill_(1)
603 |         sent_lengths = input_ids.new(batch_size).fill_(max_length)
604 | 
605 |         past = None
606 |         while cur_len < max_length:
607 |             model_inputs = self.prepare_inputs_for_generation(
608 |                 input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
609 |             )
610 | 
611 |             outputs = self(**model_inputs, return_dict=True)
612 |             # calling forward here
613 | 
614 |             # outputs.logits (batch, seq_len, input_seq_len)
615 |             next_token_logits = outputs.logits[:, -1, :]
616 | 
617 |             scores = self.postprocess_next_token_scores(
618 |                 scores=next_token_logits,
619 |                 input_ids=input_ids,
620 |                 no_repeat_ngram_size=no_repeat_ngram_size,
621 |                 bad_words_ids=bad_words_ids,
622 |                 cur_len=cur_len,
623 |                 min_length=min_length,
624 |                 max_length=max_length,
625 |                 eos_token_id=eos_token_id,
626 |                 repetition_penalty=repetition_penalty,
627 |                 batch_size=batch_size,
628 |                 num_beams=1,
629 |             )
630 | 
631 |             # if the model has a past, set the past variable to speed up decoding
632 |             if "past_key_values" in outputs:
633 |                 past = outputs.past_key_values
634 |             elif "mems" in outputs:
635 |                 past = outputs.mems
636 | 
637 |             if do_sample:
638 |                 # Temperature (higher temperature => more likely to sample low probability tokens)
639 |                 if temperature != 1.0:
640 |                     scores = scores / temperature
641 |                 # Top-p/top-k filtering
642 |                 next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
643 |                 # Sample
644 |                 probs = F.softmax(next_token_logscores, dim=-1)
645 |                 next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
646 |             else:
647 |                 # Greedy decoding
648 |                 next_token = torch.argmax(next_token_logits, dim=-1)
649 | 
650 |             # update generations and finished sentences
651 |             if eos_token_id is not None:
652 |                 # pad finished sentences if eos_token_id exists
653 |                 tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
654 |             else:
655 |                 tokens_to_add = next_token
656 | 
657 |             # add token and increase length by one
658 |             input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
659 |             cur_len = cur_len + 1
660 | 
661 |             if eos_token_id is not None:
662 |                 eos_in_sents = tokens_to_add == eos_token_id
663 |                 # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
664 |                 is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
665 |                 sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
666 |                 # unfinished_sents is set to zero if eos in sentence
667 |                 unfinished_sents.mul_((~eos_in_sents).long())
668 | 
669 |             # stop when there is an </s> in each sentence, or if we exceed the maximum length
670 |             if unfinished_sents.max() == 0:
671 |                 break
672 | 
673 |             # extend attention_mask for the newly generated token (decoder-only models)
674 |             if self.config.is_encoder_decoder is False:
675 |                 attention_mask = torch.cat(
676 |                     [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
677 |                 )
678 | 
679 |         return input_ids
680 | 
681 |     # Function from `generation_utils.py` of the Transformers library
682 |     def postprocess_next_token_scores(
683 |         self,
684 |         scores,
685 |         input_ids,
686 |         no_repeat_ngram_size,
687 |         bad_words_ids,
688 |         cur_len,
689 |         min_length,
690 |         max_length,
691 |         eos_token_id,
692 |         repetition_penalty,
693 |         batch_size,
694 |         num_beams,
695 |     ):
696 |         # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
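        # Worked example (illustrative values, not part of the original file): with
        # repetition_penalty=1.2, a previously generated token with logit +2.0 is rescaled
        # to 2.0 / 1.2 ~= 1.67, while one with logit -2.0 becomes -2.0 * 1.2 = -2.4; in both
        # cases the token becomes less likely, which discourages repetition.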
697 |         if repetition_penalty != 1.0:
698 |             self.enforce_repetition_penalty_(
699 |                 scores,
700 |                 batch_size,
701 |                 num_beams,
702 |                 input_ids,
703 |                 repetition_penalty,
704 |             )
705 | 
706 |         # set eos token prob to zero if min_length is not reached
707 |         if eos_token_id is not None and cur_len < min_length:
708 |             scores[:, eos_token_id] = -float("inf")
709 | 
710 |         return scores
711 | 
-------------------------------------------------------------------------------- /src/genie/convert_gen_to_output.py: --------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import re
5 | from copy import deepcopy
6 | from tqdm import tqdm
7 | 
8 | from utils import find_head, WhitespaceTokenizer, find_arg_span
9 | import spacy
10 | 
11 | 
12 | 
13 | def extract_args_from_template(ex, template, ontology_dict):
14 |     # extract argument text
15 |     template_words = template.strip().split()
16 |     predicted_words = ex['predicted'].strip().split()
17 |     predicted_args = {}
18 |     t_ptr = 0
19 |     p_ptr = 0
20 |     evt_type = get_event_type(ex)[0]
21 | 
22 |     while t_ptr < len(template_words) and p_ptr < len(predicted_words):
23 |         if re.match(r'<(arg\d+)>', template_words[t_ptr]):
24 |             m = re.match(r'<(arg\d+)>', template_words[t_ptr])
25 |             arg_num = m.group(1)
26 |             arg_name = ontology_dict[evt_type.replace('n/a', 'unspecified')][arg_num]
27 | 
28 |             if predicted_words[p_ptr] == '<arg>':
29 |                 # missing argument
30 |                 p_ptr += 1
31 |                 t_ptr += 1
32 |             else:
33 |                 arg_start = p_ptr
34 |                 while (p_ptr < len(predicted_words)) and (predicted_words[p_ptr] != template_words[t_ptr + 1]):
35 |                     p_ptr += 1
36 |                 arg_text = predicted_words[arg_start:p_ptr]
37 |                 predicted_args[arg_name] = arg_text
38 |                 t_ptr += 1
39 |                 # aligned
40 |         else:
41 |             t_ptr += 1
42 |             p_ptr += 1
43 | 
44 |     return predicted_args
45 | 
46 | 
47 | 
48 | 
49 | 
50 | 
51 | def get_event_type(ex):
52 |     evt_type = []
53 |     for evt in ex['evt_triggers']:
54 |         for t in evt[2]:
55 |             evt_type.append(t[0])
56 |     return evt_type
57 | 
58 | def check_coref(ex, arg_span, gold_spans):
59 |     for clus in ex['corefs']:
60 |         if arg_span in clus:
61 |             matched_gold_spans = [span for span in gold_spans if span in clus]
62 |             if len(matched_gold_spans) > 0:
63 |                 return matched_gold_spans[0]
64 |     return arg_span
65 | 
66 | 
67 | if __name__ == '__main__':
68 |     parser = argparse.ArgumentParser()
69 |     parser.add_argument('--gen-file', type=str, default='checkpoints/gen-new-tokenization-pred/sample_predictions.jsonl')
70 |     parser.add_argument('--test-file', type=str, default='data/RAMS_1.0/data/test_head_coref.jsonlines')
71 |     parser.add_argument('--output-file', type=str, default='test_output.jsonl')
72 |     parser.add_argument('--ontology-file', type=str, default='aida_ontology_cleaned.csv')
73 |     parser.add_argument('--head-only', action='store_true', default=False)
74 |     parser.add_argument('--coref', action='store_true', default=False)
75 |     args = parser.parse_args()
76 | 
77 |     nlp = spacy.load('en_core_web_sm')
78 |     nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)
79 |     # read ontology (use the --ontology-file argument instead of a hard-coded path)
80 |     ontology_dict = {}
81 |     with open(args.ontology_file, 'r') as f:
82 |         for lidx, line in enumerate(f):
83 |             if lidx == 0:  # header
84 |                 continue
85 |             fields = line.strip().split(',')
86 |             if len(fields) < 2:
87 |                 break
88 |             evt_type = fields[0]
89 |             arguments = fields[2:]
90 | 
91 |             ontology_dict[evt_type] = {
92 |                 'template': fields[1]
93 |             }
94 | 
95 |             for i, arg in enumerate(arguments):
96 |                 if arg != '':
97 |                     ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
98 |                     ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
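    # Illustrative shape of one entry (hypothetical event type and role names):
    #   ontology_dict['conflict.attack'] = {
    #       'template': '<arg1> attacked <arg2> at <arg3> place',
    #       'arg1': 'attacker', 'attacker': 'arg1',   # placeholder <-> role, both directions
    #       ...
    #   }
    # The two-way mapping is what lets extract_args_from_template() map a matched
    # <argN> placeholder back to the role name it stands for.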
99 | 
100 | 
101 |     examples = {}
102 |     with open(args.test_file, 'r') as f:
103 |         for line in f:
104 |             ex = json.loads(line.strip())
105 |             ex['ref_evt_links'] = deepcopy(ex['gold_evt_links'])
106 |             ex['gold_evt_links'] = []
107 |             examples[ex['doc_key']] = ex
108 | 
109 | 
110 |     with open(args.gen_file, 'r') as f:
111 |         for line in f:
112 |             pred = json.loads(line.strip())
113 |             examples[pred['doc_key']]['predicted'] = pred['predicted']
114 |             examples[pred['doc_key']]['gold'] = pred['gold']
115 | 
116 | 
117 |     writer = open(args.output_file, 'w')
118 |     for ex in tqdm(examples.values()):
119 |         if 'predicted' not in ex:  # skip examples that have no model prediction
120 |             continue
121 |         # get template
122 |         evt_type = get_event_type(ex)[0]
123 |         context_words = [w for sent in ex['sentences'] for w in sent]
124 |         template = ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
125 |         # extract argument text
126 | 
127 |         predicted_args = extract_args_from_template(ex, template, ontology_dict)
128 |         # get trigger
129 |         # extract argument span
130 |         trigger_start = ex['evt_triggers'][0][0]
131 |         trigger_end = ex['evt_triggers'][0][1]
132 |         doc = None
133 |         if args.head_only:
134 |             doc = nlp(' '.join(context_words))
135 | 
136 |         for argname in predicted_args:
137 |             arg_span = find_arg_span(predicted_args[argname], context_words,
138 |                 trigger_start, trigger_end, head_only=args.head_only, doc=doc)
139 |             if arg_span:  # None means the predicted text cannot be found in the context (hallucination)
140 | 
141 |                 if args.head_only and args.coref:
142 |                     # consider coreferential mentions as matching
143 |                     assert('corefs' in ex)
144 |                     gold_spans = [a[1] for a in ex['ref_evt_links'] if a[2] == argname]
145 |                     arg_span = check_coref(ex, list(arg_span), gold_spans)
146 | 
147 |                 ex['gold_evt_links'].append([[trigger_start, trigger_end], list(arg_span), argname])
148 | 
149 |         writer.write(json.dumps(ex) + '\n')
150 | 
151 |     writer.close()
152 | 
153 | 
154 | 
155 | 
-------------------------------------------------------------------------------- /src/genie/data.py: --------------------------------------------------------------------------------
1 | import os
2 | import json
3 | 
4 | import torch
5 | from torch.utils.data import DataLoader, Dataset
6 | 
7 | def my_collate(batch):
8 |     '''
9 |     'doc_key': ex['doc_key'],
10 |     'input_token_ids': input_tokens['input_ids'],
11 |     'input_attn_mask': input_tokens['attention_mask'],
12 |     'tgt_token_ids': tgt_tokens['input_ids'],
13 |     'tgt_attn_mask': tgt_tokens['attention_mask'],
14 |     '''
15 |     doc_keys = [ex['doc_key'] for ex in batch]
16 |     input_token_ids = torch.stack([torch.LongTensor(ex['input_token_ids']) for ex in batch])
17 |     input_attn_mask = torch.stack([torch.BoolTensor(ex['input_attn_mask']) for ex in batch])
18 |     tgt_token_ids = torch.stack([torch.LongTensor(ex['tgt_token_ids']) for ex in batch])
19 |     tgt_attn_mask = torch.stack([torch.BoolTensor(ex['tgt_attn_mask']) for ex in batch])
20 | 
21 |     return {
22 |         'input_token_ids': input_token_ids,
23 |         'input_attn_mask': input_attn_mask,
24 |         'tgt_token_ids': tgt_token_ids,
25 |         'tgt_attn_mask': tgt_attn_mask,
26 |         'doc_key': doc_keys,
27 |     }
28 | 
29 | 
30 | class IEDataset(Dataset):
31 |     def __init__(self, input_file):
32 |         super().__init__()
33 |         self.examples = []
34 |         with open(input_file, 'r') as f:
35 |             for line in f:
36 |                 ex = json.loads(line.strip())
37 |                 self.examples.append(ex)
38 | 
39 |     def __len__(self):
40 |         return len(self.examples)
41 | 
42 |     def __getitem__(self, idx):
43 |         return self.examples[idx]
44 | 
45 | 
46 | 
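# Minimal usage sketch (added for illustration; the toy token ids below are made up
# and not from any real preprocessed file). It mirrors the __main__ demos in the
# data module files: my_collate stacks the per-example lists into batched tensors.
if __name__ == '__main__':
    fake_ex = {
        'doc_key': 'doc-0',
        'input_token_ids': [0, 100, 200, 2] + [1] * 4,  # padded to length 8
        'input_attn_mask': [1, 1, 1, 1] + [0] * 4,
        'tgt_token_ids': [0, 100, 2] + [1] * 5,
        'tgt_attn_mask': [1, 1, 1] + [0] * 5,
    }
    batch = my_collate([fake_ex, fake_ex])
    print(batch['input_token_ids'].shape)  # torch.Size([2, 8])
    print(batch['doc_key'])                # ['doc-0', 'doc-0']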
-------------------------------------------------------------------------------- /src/genie/data_module.py: --------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import re
4 | import random
5 | from collections import defaultdict
6 | import argparse
7 | 
8 | import transformers
9 | from transformers import BartTokenizer
10 | import torch
11 | from torch.utils.data import DataLoader
12 | import pytorch_lightning as pl
13 | 
14 | from .data import IEDataset, my_collate
15 | 
16 | MAX_LENGTH = 424
17 | MAX_TGT_LENGTH = 72
18 | DOC_STRIDE = 256
19 | 
20 | class RAMSDataModule(pl.LightningDataModule):
21 |     def __init__(self, args):
22 |         super().__init__()
23 |         self.hparams = args
24 |         self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
25 |         self.tokenizer.add_tokens([' <arg>', ' <tgr>'])  # argument placeholder and trigger marker special tokens
26 | 
27 |     def get_event_type(self, ex):
28 |         evt_type = []
29 |         for evt in ex['evt_triggers']:
30 |             for t in evt[2]:
31 |                 evt_type.append(t[0])
32 |         return evt_type
33 | 
34 |     def create_gold_gen(self, ex, ontology_dict, mark_trigger=True):
35 |         '''Assumes that each line only contains 1 event.
36 |         Input: template with special <arg> placeholders, paired with the passage.
37 |         Output: template with the gold arguments filled in; <arg> is kept where no argument is found.
38 |         '''
39 | 
40 |         evt_type = self.get_event_type(ex)[0]
41 |         context_words = [w for sent in ex['sentences'] for w in sent]
42 |         template = ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
43 |         input_template = re.sub(r'<arg\d>', '<arg>', template)
44 |         space_tokenized_input_template = input_template.split(' ')
45 |         tokenized_input_template = []
46 |         for w in space_tokenized_input_template:
47 |             tokenized_input_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))
48 | 
49 | 
50 |         for triple in ex['gold_evt_links']:
51 |             trigger_span, argument_span, arg_name = triple
52 |             arg_num = ontology_dict[evt_type.replace('n/a', 'unspecified')][arg_name]
53 |             arg_text = ' '.join(context_words[argument_span[0]:argument_span[1] + 1])
54 | 
55 |             template = re.sub('<{}>'.format(arg_num), arg_text, template)
56 | 
57 | 
58 |         trigger = ex['evt_triggers'][0]
59 |         if mark_trigger:
60 |             trigger_span_start = trigger[0]
61 |             trigger_span_end = trigger[1] + 2  # one for inclusion, one for extra start marker
62 |             prefix = self.tokenizer.tokenize(' '.join(context_words[:trigger[0]]), add_prefix_space=True)
63 |             tgt = self.tokenizer.tokenize(' '.join(context_words[trigger[0]: trigger[1] + 1]), add_prefix_space=True)
64 | 
65 |             suffix = self.tokenizer.tokenize(' '.join(context_words[trigger[1] + 1:]), add_prefix_space=True)
66 |             context = prefix + [' <tgr>', ] + tgt + [' <tgr>', ] + suffix
67 |         else:
68 |             context = self.tokenizer.tokenize(' '.join(context_words), add_prefix_space=True)
69 | 
70 |         output_template = re.sub(r'<arg\d>', '<arg>', template)
71 |         space_tokenized_template = output_template.split(' ')
72 |         tokenized_template = []
73 |         for w in space_tokenized_template:
74 |             tokenized_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))
75 | 
76 |         return tokenized_input_template, tokenized_template, context
77 | 
78 | 
79 | 
80 |     def load_ontology(self):
81 |         # read ontology
82 |         ontology_dict = {}
83 |         with open('aida_ontology_cleaned.csv', 'r') as f:
84 |             for lidx, line in enumerate(f):
85 |                 if lidx == 0:  # header
86 |                     continue
87 |                 fields = line.strip().split(',')
88 |                 if len(fields) < 2:
89 |                     break
90 |                 evt_type = fields[0]
91 |                 args = fields[2:]
92 | 
93 |                 ontology_dict[evt_type] = {
94 |                     'template': fields[1]
95 |                 }
96 | 
97 |                 for i, arg in enumerate(args):
98 |                     if arg
!='': 99 | ontology_dict[evt_type]['arg{}'.format(i+1)] = arg 100 | ontology_dict[evt_type][arg] = 'arg{}'.format(i+1) 101 | 102 | return ontology_dict 103 | 104 | def prepare_data(self): 105 | if not os.path.exists('preprocessed_data'): 106 | os.makedirs('preprocessed_data') 107 | 108 | ontology_dict = self.load_ontology() 109 | 110 | for split,f in [('train',self.hparams.train_file), ('val',self.hparams.val_file), ('test',self.hparams.test_file)]: 111 | with open(f,'r') as reader, open('preprocessed_data/{}.jsonl'.format(split), 'w') as writer: 112 | for lidx, line in enumerate(reader): 113 | ex = json.loads(line.strip()) 114 | input_template, output_template, context= self.create_gold_gen(ex, ontology_dict, self.hparams.mark_trigger) 115 | 116 | 117 | input_tokens = self.tokenizer.encode_plus(input_template, context, 118 | add_special_tokens=True, 119 | add_prefix_space=True, 120 | max_length=MAX_LENGTH, 121 | truncation='only_second', 122 | padding='max_length') 123 | tgt_tokens = self.tokenizer.encode_plus(output_template, 124 | add_special_tokens=True, 125 | add_prefix_space=True, 126 | max_length=MAX_TGT_LENGTH, 127 | truncation=True, 128 | padding='max_length') 129 | 130 | processed_ex = { 131 | # 'idx': lidx, 132 | 'doc_key': ex['doc_key'], 133 | 'input_token_ids':input_tokens['input_ids'], 134 | 'input_attn_mask': input_tokens['attention_mask'], 135 | 'tgt_token_ids': tgt_tokens['input_ids'], 136 | 'tgt_attn_mask': tgt_tokens['attention_mask'], 137 | } 138 | writer.write(json.dumps(processed_ex) + '\n') 139 | 140 | 141 | 142 | 143 | def train_dataloader(self): 144 | dataset = IEDataset('preprocessed_data/train.jsonl') 145 | 146 | dataloader = DataLoader(dataset, 147 | pin_memory=True, num_workers=2, 148 | collate_fn=my_collate, 149 | batch_size=self.hparams.train_batch_size, 150 | shuffle=True) 151 | return dataloader 152 | 153 | 154 | def val_dataloader(self): 155 | dataset = IEDataset('preprocessed_data/val.jsonl') 156 | 157 | dataloader = DataLoader(dataset, pin_memory=True, num_workers=2, 158 | collate_fn=my_collate, 159 | batch_size=self.hparams.eval_batch_size, shuffle=False) 160 | return dataloader 161 | 162 | def test_dataloader(self): 163 | dataset = IEDataset('preprocessed_data/test.jsonl') 164 | 165 | dataloader = DataLoader(dataset, pin_memory=True, num_workers=2, 166 | collate_fn=my_collate, 167 | batch_size=self.hparams.eval_batch_size, shuffle=False) 168 | 169 | return dataloader 170 | 171 | 172 | if __name__ == '__main__': 173 | parser = argparse.ArgumentParser() 174 | parser.add_argument('--train-file',type=str,default='data/RAMS_1.0/data/train.jsonlines') 175 | parser.add_argument('--val-file', type=str, default='data/RAMS_1.0/data/dev.jsonlines') 176 | parser.add_argument('--test-file', type=str, default='data/RAMS_1.0/data/test.jsonlines') 177 | parser.add_argument('--train_batch_size', type=int, default=2) 178 | parser.add_argument('--eval_batch_size', type=int, default=4) 179 | parser.add_argument('--mark-trigger', action='store_true', default=True) 180 | args = parser.parse_args() 181 | 182 | dm = RAMSDataModule(args=args) 183 | dm.prepare_data() 184 | 185 | # training dataloader 186 | dataloader = dm.train_dataloader() 187 | 188 | for idx, batch in enumerate(dataloader): 189 | print(batch) 190 | break 191 | 192 | # val dataloader -------------------------------------------------------------------------------- /src/genie/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import 
torch
4 | import logging
5 | import json
6 | 
7 | 
8 | import pytorch_lightning as pl
9 | from transformers import BartTokenizer, BartConfig
10 | from transformers import AdamW, get_linear_schedule_with_warmup
11 | 
12 | from .network import BartGen
13 | from .constrained_gen import BartConstrainedGen
14 | 
15 | logger = logging.getLogger(__name__)
16 | 
17 | class GenIEModel(pl.LightningModule):
18 |     def __init__(self, args):
19 |         super().__init__()
20 |         self.hparams = args
21 | 
22 | 
23 |         self.config = BartConfig.from_pretrained('facebook/bart-large')
24 |         self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
25 |         self.tokenizer.add_tokens([' <arg>', ' <tgr>'])  # must match the tokens added in the data modules
26 | 
27 | 
28 |         if self.hparams.model == 'gen':
29 |             self.model = BartGen(self.config, self.tokenizer)
30 |             self.model.resize_token_embeddings()
31 |         elif self.hparams.model == 'constrained-gen':
32 |             self.model = BartConstrainedGen(self.config, self.tokenizer)
33 |             self.model.resize_token_embeddings()
34 |         else:
35 |             raise NotImplementedError
36 | 
37 | 
38 | 
39 |     def forward(self, inputs):
40 | 
41 |         return self.model(**inputs)
42 | 
43 | 
44 |     def training_step(self, batch, batch_idx):
45 |         '''
46 |         processed_ex = {
47 |             'doc_key': ex['doc_key'],
48 |             'input_token_ids': input_tokens['input_ids'],
49 |             'input_attn_mask': input_tokens['attention_mask'],
50 |             'tgt_token_ids': tgt_tokens['input_ids'],
51 |             'tgt_attn_mask': tgt_tokens['attention_mask'],
52 |         }
53 |         '''
54 |         inputs = {
55 |             "input_ids": batch["input_token_ids"],
56 |             "attention_mask": batch["input_attn_mask"],
57 |             "decoder_input_ids": batch['tgt_token_ids'],
58 |             "decoder_attention_mask": batch["tgt_attn_mask"],
59 |             "task": 0
60 |         }
61 | 
62 |         outputs = self.model(**inputs)
63 |         loss = outputs[0]
64 |         loss = torch.mean(loss)
65 | 
66 |         log = {
67 |             'train/loss': loss,
68 |         }
69 |         return {
70 |             'loss': loss,
71 |             'log': log
72 |         }
73 | 
74 | 
75 |     def validation_step(self, batch, batch_idx):
76 |         inputs = {
77 |             "input_ids": batch["input_token_ids"],
78 |             "attention_mask": batch["input_attn_mask"],
79 |             "decoder_input_ids": batch['tgt_token_ids'],
80 |             "decoder_attention_mask": batch["tgt_attn_mask"],
81 |             "task": 0,
82 |         }
83 |         outputs = self.model(**inputs)
84 |         loss = outputs[0]
85 |         loss = torch.mean(loss)
86 | 
87 | 
88 | 
89 |         return loss
90 | 
91 | 
92 |     def validation_epoch_end(self, outputs):
93 |         avg_loss = torch.mean(torch.stack(outputs))
94 |         log = {
95 |             'val/loss': avg_loss,
96 |         }
97 |         return {
98 |             'loss': avg_loss,
99 |             'log': log
100 |         }
101 | 
102 | 
103 | 
104 | 
105 |     def test_step(self, batch, batch_idx):
106 |         if self.hparams.sample_gen:
107 |             sample_output = self.model.generate(batch['input_token_ids'], do_sample=True,
108 |                 top_k=20, top_p=0.95, max_length=30, num_return_sequences=1, num_beams=1,
109 |             )
110 |         else:
111 |             sample_output = self.model.generate(batch['input_token_ids'], do_sample=False,
112 |                 max_length=30, num_return_sequences=1, num_beams=1,
113 |             )
114 | 
115 |         sample_output = sample_output.reshape(batch['input_token_ids'].size(0), 1, -1)
116 |         doc_key = batch['doc_key']  # list
117 |         tgt_token_ids = batch['tgt_token_ids']
118 | 
119 |         return (doc_key, sample_output, tgt_token_ids)
120 | 
121 |     def test_epoch_end(self, outputs):
122 |         # evaluate F1
123 |         with open('checkpoints/{}/predictions.jsonl'.format(self.hparams.ckpt_name), 'w') as writer:
124 |             for tup in outputs:
125 |                 for idx in range(len(tup[0])):
126 | 
127 |                     pred = {
128 |                         'doc_key': tup[0][idx],
129 |                         'predicted': self.tokenizer.decode(tup[1][idx].squeeze(0), skip_special_tokens=True),
130 | 
'gold': self.tokenizer.decode(tup[2][idx].squeeze(0), skip_special_tokens=True) 131 | } 132 | writer.write(json.dumps(pred)+'\n') 133 | 134 | return {} 135 | 136 | 137 | def configure_optimizers(self): 138 | self.train_len = len(self.train_dataloader()) 139 | if self.hparams.max_steps > 0: 140 | t_total = self.hparams.max_steps 141 | self.hparams.num_train_epochs = self.hparams.max_steps // self.train_len // self.hparams.accumulate_grad_batches + 1 142 | else: 143 | t_total = self.train_len // self.hparams.accumulate_grad_batches * self.hparams.num_train_epochs 144 | 145 | logger.info('{} training steps in total.. '.format(t_total)) 146 | 147 | # Prepare optimizer and schedule (linear warmup and decay) 148 | no_decay = ["bias", "LayerNorm.weight"] 149 | optimizer_grouped_parameters = [ 150 | { 151 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 152 | "weight_decay": self.hparams.weight_decay, 153 | }, 154 | {"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 155 | ] 156 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon) 157 | # scheduler is called only once per epoch by default 158 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total) 159 | scheduler_dict = { 160 | 'scheduler': scheduler, 161 | 'interval': 'step', 162 | 'name': 'linear-schedule', 163 | } 164 | 165 | return [optimizer, ], [scheduler_dict,] -------------------------------------------------------------------------------- /src/genie/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from transformers import ( 5 | AdamW, 6 | get_linear_schedule_with_warmup, 7 | BartModel, 8 | 9 | ) 10 | from transformers.modeling_utils import PreTrainedModel 11 | from transformers.modeling_outputs import Seq2SeqLMOutput 12 | 13 | 14 | 15 | class BartGen(PreTrainedModel): 16 | def __init__(self, config, tokenizer): 17 | super(BartGen, self).__init__(config) 18 | self.config = config 19 | self.tokenizer = tokenizer 20 | self.transformer = BartModel.from_pretrained('facebook/bart-large') 21 | self.register_buffer("final_logits_bias", torch.zeros((1, self.transformer.shared.num_embeddings))) 22 | 23 | def resize_token_embeddings(self): 24 | old_num_tokens = self.transformer.shared.num_embeddings 25 | new_embeddings = self.transformer.resize_token_embeddings(len(self.tokenizer)) 26 | self.transformer.shared = new_embeddings 27 | self._resize_final_logits_bias(len(self.tokenizer), old_num_tokens) 28 | self.vocab_size = len(self.tokenizer) 29 | 30 | return new_embeddings 31 | 32 | def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None: 33 | if new_num_tokens <= old_num_tokens: 34 | new_bias = self.final_logits_bias[:, :new_num_tokens] 35 | else: 36 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 37 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 38 | self.register_buffer("final_logits_bias", new_bias) 39 | 40 | 41 | def _init_weights(self, module): 42 | """ Initialize the weights """ 43 | if isinstance(module, (nn.Linear, nn.Embedding)): 44 | # Slightly different from the TF version which uses truncated_normal for initialization 45 | # cf 
https://github.com/pytorch/pytorch/pull/5617 46 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 47 | elif isinstance(module, torch.nn.LayerNorm): # if use apex, this should be FusedLayerNorm 48 | module.bias.data.zero_() 49 | module.weight.data.fill_(1.0) 50 | if isinstance(module, nn.Linear) and module.bias is not None: 51 | module.bias.data.zero_() 52 | 53 | 54 | def get_encoder(self): 55 | return self.transformer.encoder 56 | 57 | 58 | def get_output_embeddings(self): 59 | # this method is needed for generation 60 | vocab_size, emb_size = self.transformer.shared.weight.shape 61 | lin_layer = nn.Linear(vocab_size, emb_size, bias=False) 62 | lin_layer.weight.data = self.transformer.shared.weight.data 63 | return lin_layer 64 | 65 | 66 | def prepare_inputs_for_generation( 67 | self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs 68 | ): 69 | return { 70 | "input_ids": None, # encoder_outputs is defined. input_ids not needed 71 | "encoder_outputs": encoder_outputs, 72 | "past_key_values": past, 73 | "decoder_input_ids": decoder_input_ids, 74 | "attention_mask": attention_mask, 75 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 76 | } 77 | 78 | def adjust_logits_during_generation(self, logits, cur_len, max_length): 79 | if cur_len == 1 and self.config.force_bos_token_to_be_generated: 80 | self._force_token_ids_generation(logits, self.config.bos_token_id) 81 | elif cur_len == max_length - 1 and self.config.eos_token_id is not None: 82 | self._force_token_ids_generation(logits, self.config.eos_token_id) 83 | return logits 84 | 85 | def _force_token_ids_generation(self, scores, token_id) -> None: 86 | """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" 87 | scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf") 88 | 89 | @staticmethod 90 | def _reorder_cache(past, beam_idx): 91 | reordered_past = [] 92 | for layer_past in past: 93 | # get the correct batch idx from decoder layer's batch dim for cross and self-attn 94 | layer_past_new = { 95 | attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() 96 | } 97 | reordered_past.append(layer_past_new) 98 | return reordered_past 99 | 100 | 101 | 102 | def forward(self, input_ids, 103 | attention_mask=None, 104 | encoder_outputs=None, 105 | use_cache=False, 106 | past_key_values=None, 107 | decoder_input_ids=None, 108 | decoder_attention_mask=None, 109 | output_attentions=None, 110 | output_hidden_states=None, 111 | return_dict=None, 112 | task=-1): 113 | 114 | # generation 115 | if task==-1: 116 | outputs = self.transformer( 117 | input_ids=input_ids, 118 | attention_mask=attention_mask, 119 | decoder_input_ids=decoder_input_ids, 120 | decoder_attention_mask=decoder_attention_mask, 121 | use_cache=use_cache, 122 | encoder_outputs=encoder_outputs, 123 | past_key_values=past_key_values, 124 | output_attentions=output_attentions, 125 | output_hidden_states=output_hidden_states, 126 | return_dict=return_dict,) 127 | 128 | lm_logits = F.linear(outputs[0], self.transformer.shared.weight, bias=self.final_logits_bias) 129 | masked_lm_loss = None 130 | 131 | if not return_dict: 132 | output = (lm_logits,) + outputs[1:] 133 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 134 | 135 | return Seq2SeqLMOutput( 136 | loss=masked_lm_loss, 137 | logits=lm_logits, 138 | past_key_values=outputs.past_key_values, 139 | 
decoder_hidden_states=outputs.decoder_hidden_states, 140 | decoder_attentions=outputs.decoder_attentions, 141 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 142 | encoder_hidden_states=outputs.encoder_hidden_states, 143 | encoder_attentions=outputs.encoder_attentions, 144 | ) 145 | 146 | #training 147 | elif task==0: 148 | 149 | assert(decoder_input_ids!=None) 150 | y_ids = decoder_input_ids[:, :-1] 151 | labels = decoder_input_ids[:, 1:].clone() 152 | labels[labels== self.tokenizer.pad_token_id] = -100 153 | # labels are just decoder_input_ids shifted to the right by 1 154 | 155 | outputs = self.transformer( 156 | input_ids, 157 | attention_mask=attention_mask, 158 | decoder_input_ids=y_ids, 159 | decoder_attention_mask=decoder_attention_mask[:, :-1], 160 | use_cache=False, 161 | past_key_values=past_key_values, 162 | output_attentions=output_attentions, 163 | output_hidden_states=output_hidden_states, 164 | return_dict=return_dict,) 165 | 166 | sequence_output = outputs[0] 167 | 168 | lm_logits = F.linear(sequence_output, self.transformer.shared.weight, bias=self.final_logits_bias) 169 | outputs = (lm_logits,) + outputs[1:] # Add cache, hidden states and attention if they are here 170 | loss_fct = nn.CrossEntropyLoss() 171 | 172 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.vocab_size), labels.view(-1)) 173 | outputs = (masked_lm_loss,) + outputs 174 | 175 | return outputs 176 | 177 | 178 | 179 | -------------------------------------------------------------------------------- /src/genie/pipeline_scorer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file scores an argument prediction file from predicted triggers. 3 | ''' 4 | import os 5 | import json 6 | import argparse 7 | 8 | import re 9 | from copy import deepcopy 10 | from collections import defaultdict 11 | from tqdm import tqdm 12 | import spacy 13 | 14 | 15 | from utils import load_ontology,find_arg_span, compute_f1, get_entity_span, find_head, WhitespaceTokenizer 16 | 17 | nlp = spacy.load('en_core_web_sm') 18 | nlp.tokenizer = WhitespaceTokenizer(nlp.vocab) 19 | 20 | ''' 21 | Scorer for argument extraction on ACE & KAIROS. 22 | For the RAMS dataset, the official scorer is used. 
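Head F1 counts a predicted argument as correct when the head word of its span matches the head word of a gold argument for the same event type (and, for classification, the same role); Coref F1 additionally credits predictions that are coreferent with a gold argument according to the coreference clusters.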
23 | 24 | Outputs: 25 | Head F1 26 | Coref F1 27 | ''' 28 | def clean_span(ex, span): 29 | tokens = ex['tokens'] 30 | if tokens[span[0]].lower() in {'the', 'an', 'a'}: 31 | if span[0]!=span[1]: 32 | return (span[0]+1, span[1]) 33 | return span 34 | 35 | def extract_args_from_template(ex, template, ontology_dict,): 36 | # extract argument text 37 | template_words = template.strip().split() 38 | predicted_words = ex['predicted'].strip().split() 39 | predicted_args = defaultdict(list) # each argname may have multiple participants 40 | t_ptr= 0 41 | p_ptr= 0 42 | evt_type = ex['event']['event_type'] 43 | while t_ptr < len(template_words) and p_ptr < len(predicted_words): 44 | if re.match(r'<(arg\d+)>', template_words[t_ptr]): 45 | m = re.match(r'<(arg\d+)>', template_words[t_ptr]) 46 | arg_num = m.group(1) 47 | try: 48 | arg_name = ontology_dict[evt_type][arg_num] 49 | except KeyError: 50 | print(evt_type) 51 | exit() 52 | 53 | if predicted_words[p_ptr] == '<arg>': 54 | # missing argument 55 | p_ptr +=1 56 | t_ptr +=1 57 | else: 58 | arg_start = p_ptr 59 | while (p_ptr < len(predicted_words)) and ((t_ptr== len(template_words)-1) or (predicted_words[p_ptr] != template_words[t_ptr+1])): 60 | p_ptr+=1 61 | arg_text = predicted_words[arg_start:p_ptr] 62 | predicted_args[arg_name].append(arg_text) 63 | t_ptr+=1 64 | # aligned 65 | else: 66 | t_ptr+=1 67 | p_ptr+=1 68 | 69 | return predicted_args 70 | 71 | def create_coref_mapping(args): 72 | ''' 73 | Coref mapping for the entire data split. 74 | ''' 75 | coref_mapping = defaultdict(dict) # span to canonical entity_id mapping for each doc 76 | if args.dataset == 'KAIROS' and args.coref_file: 77 | with open(args.coref_file, 'r') as f, open(args.test_file, 'r') as test_reader: 78 | for line, test_line in zip(f, test_reader): 79 | coref_ex = json.loads(line) 80 | ex = json.loads(test_line) 81 | doc_id = coref_ex['doc_key'] 82 | 83 | for cluster, name in zip(coref_ex['clusters'], coref_ex['informative_mentions']): 84 | canonical = cluster[0] 85 | for ent_id in cluster: 86 | ent_span = get_entity_span(ex, ent_id) 87 | ent_span = (ent_span[0], ent_span[1]-1) 88 | coref_mapping[doc_id][ent_span] = canonical 89 | # this does not include singleton clusters 90 | else: 91 | # for the ACE dataset 92 | with open(args.test_file) as f: 93 | for line in f: 94 | doc=json.loads(line.strip()) 95 | doc_id = doc['sent_id'] 96 | for entity in doc['entity_mentions']: 97 | mention_id = entity['id'] 98 | ent_id = '-'.join(mention_id.split('-')[:-1]) 99 | coref_mapping[doc_id][(entity['start'], entity['end']-1)] = ent_id # all indexes are inclusive 100 | return coref_mapping 101 | 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument('--gen-file',type=str,default='checkpoints/gen-all-ACE-freq-pipeline/predictions.jsonl' ) 107 | parser.add_argument('--tgr-pred-file', type=str,default='data/ace/zs-freq-10/pred.oneie.json') 108 | parser.add_argument('--test-file', type=str,default='data/ace/zs-freq-10/test.oneie.json') 109 | parser.add_argument('--coref-file', type=str) 110 | parser.add_argument('--head-only', action='store_true') 111 | parser.add_argument('--coref', action='store_true', default=True) 112 | parser.add_argument('--dataset',type=str, default='ACE', choices=['ACE', 'KAIROS']) 113 | parser.add_argument('--seen-types', type=str) 114 | args = parser.parse_args() 115 | 116 | 117 | 118 | ontology_dict = load_ontology(dataset=args.dataset) 119 | seen_types = set() 120 | if args.seen_types: 121 | with 
open(args.seen_types) as f: 122 | for line in f: 123 | e = line.strip() 124 | assert(e in ontology_dict) 125 | seen_types.add(e) 126 | 127 | 128 | if args.dataset == 'KAIROS' and not args.coref_file: 129 | print('coreference file needed for the KAIROS dataset.') 130 | raise ValueError 131 | 132 | coref_mapping = create_coref_mapping(args) 133 | 134 | examples = {} 135 | doc2ex = defaultdict(list) # a document contains multiple events 136 | with open(args.gen_file,'r') as f: 137 | for lidx, line in enumerate(f): # this solution relies on keeping the exact same order 138 | pred = json.loads(line.strip()) 139 | examples[lidx] = { 140 | 'predicted': pred['predicted'], 141 | 'gold': pred['gold'], 142 | 'doc_id': pred['doc_key'] 143 | } 144 | doc2ex[pred['doc_key']].append(lidx) 145 | 146 | with open(args.tgr_pred_file, 'r') as f: 147 | for line in f: 148 | doc = json.loads(line.strip()) 149 | if 'sent_id' in doc.keys(): 150 | doc_id = doc['sent_id'] 151 | # print('evaluating on sentence level') 152 | else: 153 | doc_id = doc['doc_id'] 154 | # print('evaluating on document level') 155 | for idx, eid in enumerate(doc2ex[doc_id]): 156 | examples[eid]['tokens'] = doc['tokens'] 157 | try: 158 | examples[eid]['event'] = doc['event_mentions'][idx] 159 | examples[eid]['entity_mentions'] = doc['entity_mentions'] 160 | except IndexError: 161 | print(doc_id) 162 | exit() 163 | 164 | 165 | gold_evt_dict ={} # doc_key -> list of event_mentions 166 | with open(args.test_file, 'r') as f: 167 | for line in f: 168 | doc = json.loads(line.strip()) 169 | if 'sent_id' in doc.keys(): 170 | doc_id = doc['sent_id'] 171 | # print('evaluating on sentence level') 172 | else: 173 | doc_id = doc['doc_id'] 174 | # print('evaluating on document level') 175 | gold_evt_dict[doc_id] = doc['event_mentions'] 176 | 177 | 178 | gold_arg_num =0 179 | # directly compute the number of gold args 180 | for event_list in gold_evt_dict.values(): 181 | for e in event_list: 182 | if e['event_type'] not in seen_types: 183 | gold_arg_num += len(e['arguments']) 184 | 185 | 186 | pred_arg_num =0 187 | 188 | arg_idn_num =0 189 | arg_class_num =0 190 | 191 | arg_idn_coref_num =0 192 | arg_class_coref_num =0 193 | 194 | for ex in tqdm(examples.values()): 195 | # an example is a single predicted event 196 | context_words = ex['tokens'] 197 | doc_id = ex['doc_id'] 198 | doc = None 199 | if args.head_only: 200 | doc = nlp(' '.join(context_words)) 201 | 202 | # get template 203 | evt_type = ex['event']['event_type'] 204 | if evt_type in seen_types: 205 | continue 206 | 207 | if evt_type not in ontology_dict: 208 | continue 209 | template = ontology_dict[evt_type]['template'] 210 | # extract argument text 211 | predicted_args = extract_args_from_template(ex,template, ontology_dict) 212 | # get trigger 213 | # extract argument span 214 | trigger_start = ex['event']['trigger']['start'] 215 | trigger_end = ex['event']['trigger']['end'] 216 | 217 | predicted_set = set() 218 | for argname in predicted_args: 219 | for entity in predicted_args[argname]:# this argument span is inclusive 220 | arg_span = find_arg_span(entity, context_words, 221 | trigger_start, trigger_end, head_only=args.head_only, doc=doc) 222 | 223 | if arg_span: # None means the predicted span is a hallucination 224 | 225 | predicted_set.add((arg_span[0], arg_span[1], evt_type, argname)) 226 | 227 | else: 228 | new_entity = [] 229 | for w in entity: 230 | if w == 'and' and len(new_entity) >0: 231 | arg_span = find_arg_span(new_entity, context_words, trigger_start, trigger_end, 232 | head_only=args.head_only, 
doc=doc) 233 | if arg_span: predicted_set.add((arg_span[0], arg_span[1], evt_type, argname)) 234 | new_entity = [] 235 | else: 236 | new_entity.append(w) 237 | 238 | if len(new_entity) >0: # last entity 239 | arg_span = find_arg_span(new_entity, context_words, trigger_start, trigger_end, 240 | head_only=args.head_only, doc=doc) 241 | if arg_span: predicted_set.add((arg_span[0], arg_span[1], evt_type, argname)) 242 | 243 | 244 | gold_set = set() 245 | gold_canonical_set = set() # set of canonical mention ids, singleton mentions will not be here 246 | # check if this event is in the gold events 247 | for e in gold_evt_dict[doc_id]: 248 | if (e['event_type'] == evt_type) and (e['trigger'] == ex['event']['trigger']): 249 | # trigger extraction is correct 250 | 251 | for arg in e['arguments']: 252 | argname = arg['role'] 253 | entity_id = arg['entity_id'] 254 | span = get_entity_span(ex, entity_id) 255 | span = (span[0], span[1]-1) 256 | span = clean_span(ex, span) 257 | # clean up span by removing `a` `the` 258 | if args.head_only and span[0]!=span[1]: 259 | span = find_head(span[0], span[1], doc=doc) 260 | 261 | gold_set.add((span[0], span[1], evt_type, argname)) 262 | if span in coref_mapping[doc_id]: 263 | canonical_id = coref_mapping[doc_id][span] 264 | gold_canonical_set.add((canonical_id, evt_type, argname)) 265 | 266 | pred_arg_num += len(predicted_set) 267 | 268 | # check matches 269 | for pred_arg in predicted_set: 270 | arg_start, arg_end, event_type, role = pred_arg 271 | gold_idn = {item for item in gold_set 272 | if item[0] == arg_start and item[1] == arg_end 273 | and item[2] == event_type} 274 | if gold_idn: 275 | arg_idn_num += 1 276 | gold_class = {item for item in gold_idn if item[-1] == role} 277 | if gold_class: 278 | arg_class_num += 1 279 | elif args.coref:# check coref matches 280 | arg_start, arg_end, event_type, role = pred_arg 281 | span = (arg_start, arg_end) 282 | if span in coref_mapping[doc_id]: 283 | canonical_id = coref_mapping[doc_id][span] 284 | gold_idn_coref = {item for item in gold_canonical_set 285 | if item[0] == canonical_id and item[1] == event_type} 286 | if gold_idn_coref: 287 | arg_idn_coref_num +=1 288 | gold_class_coref = {item for item in gold_idn_coref 289 | if item[2] == role} 290 | if gold_class_coref: 291 | arg_class_coref_num +=1 292 | 293 | if args.head_only: 294 | print('Evaluation by matching head words only....') 295 | 296 | 297 | role_id_prec, role_id_rec, role_id_f = compute_f1( 298 | pred_arg_num, gold_arg_num, arg_idn_num) 299 | role_prec, role_rec, role_f = compute_f1( 300 | pred_arg_num, gold_arg_num, arg_class_num) 301 | 302 | 303 | print('Role identification: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format( 304 | role_id_prec * 100.0, role_id_rec * 100.0, role_id_f * 100.0)) 305 | print('Role: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format( 306 | role_prec * 100.0, role_rec * 100.0, role_f * 100.0)) 307 | 308 | if args.coref: 309 | role_id_prec, role_id_rec, role_id_f = compute_f1( 310 | pred_arg_num, gold_arg_num, arg_idn_num + arg_idn_coref_num) 311 | role_prec, role_rec, role_f = compute_f1( 312 | pred_arg_num, gold_arg_num, arg_class_num + arg_class_coref_num) 313 | 314 | 315 | print('Coref Role identification: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format( 316 | role_id_prec * 100.0, role_id_rec * 100.0, role_id_f * 100.0)) 317 | print('Coref Role: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format( 318 | role_prec * 100.0, role_rec * 100.0, role_f * 100.0)) 319 | -------------------------------------------------------------------------------- 
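As a quick illustration of the alignment logic shared by both scorers, here is a minimal, self-contained sketch of how `extract_args_from_template` recovers role fillers by walking the blank template and the generated sequence in lockstep. The function name `extract_args`, the toy template, and the role names below are illustrative only, not taken from an actual ontology file:

import re
from collections import defaultdict

def extract_args(predicted, template, arg2role):
    # Walk the blank template and the generated sequence in lockstep; whenever the
    # template shows an <argN> placeholder, consume predicted words until the next
    # literal template word is reached.
    t_words, p_words = template.split(), predicted.split()
    args = defaultdict(list)
    t = p = 0
    while t < len(t_words) and p < len(p_words):
        m = re.match(r'<(arg\d+)>', t_words[t])
        if m:
            if p_words[p] == '<arg>':  # placeholder kept by the model: missing argument
                p += 1
            else:
                start = p
                while p < len(p_words) and (t == len(t_words) - 1 or p_words[p] != t_words[t + 1]):
                    p += 1
                args[arg2role[m.group(1)]].append(p_words[start:p])
        else:
            p += 1  # literal template word, assumed copied verbatim by the model
        t += 1
    return dict(args)

template = '<arg1> attacked <arg2> using <arg3> at <arg4> place'
predicted = 'the rebels attacked the convoy using rockets at <arg> place'
roles = {'arg1': 'attacker', 'arg2': 'target', 'arg3': 'instrument', 'arg4': 'place'}
print(extract_args(predicted, template, roles))
# {'attacker': [['the', 'rebels']], 'target': [['the', 'convoy']], 'instrument': [['rockets']]}

Note that filler words are consumed greedily until the next literal template word reappears, so a filler that happens to contain that word would be truncated early; this is a known limitation of the alignment heuristic.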
/src/genie/scorer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import re 5 | from copy import deepcopy 6 | from collections import defaultdict 7 | from tqdm import tqdm 8 | import spacy 9 | 10 | 11 | from utils import load_ontology,find_arg_span, compute_f1, get_entity_span, find_head, WhitespaceTokenizer 12 | 13 | nlp = spacy.load('en_core_web_sm') 14 | nlp.tokenizer = WhitespaceTokenizer(nlp.vocab) 15 | 16 | ''' 17 | Scorer for argument extraction on ACE & KAIROS. 18 | For the RAMS dataset, the official scorer is used. 19 | 20 | Outputs: 21 | Head F1 22 | Coref F1 23 | ''' 24 | def clean_span(ex, span): 25 | tokens = ex['tokens'] 26 | if tokens[span[0]].lower() in {'the', 'an', 'a'}: 27 | if span[0]!=span[1]: 28 | return (span[0]+1, span[1]) 29 | return span 30 | 31 | def extract_args_from_template(ex, template, ontology_dict,): 32 | # extract argument text 33 | template_words = template.strip().split() 34 | predicted_words = ex['predicted'].strip().split() 35 | predicted_args = defaultdict(list) # each argname may have multiple participants 36 | t_ptr= 0 37 | p_ptr= 0 38 | evt_type = ex['event']['event_type'] 39 | while t_ptr < len(template_words) and p_ptr < len(predicted_words): 40 | if re.match(r'<(arg\d+)>', template_words[t_ptr]): 41 | m = re.match(r'<(arg\d+)>', template_words[t_ptr]) 42 | arg_num = m.group(1) 43 | try: 44 | arg_name = ontology_dict[evt_type][arg_num] 45 | except KeyError: 46 | print(evt_type) 47 | exit() 48 | 49 | if predicted_words[p_ptr] == '<arg>': 50 | # missing argument 51 | p_ptr +=1 52 | t_ptr +=1 53 | else: 54 | arg_start = p_ptr 55 | while (p_ptr < len(predicted_words)) and ((t_ptr== len(template_words)-1) or (predicted_words[p_ptr] != template_words[t_ptr+1])): 56 | p_ptr+=1 57 | arg_text = predicted_words[arg_start:p_ptr] 58 | predicted_args[arg_name].append(arg_text) 59 | t_ptr+=1 60 | # aligned 61 | else: 62 | t_ptr+=1 63 | p_ptr+=1 64 | 65 | return predicted_args 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--gen-file',type=str,default='checkpoints/gen-all-ACE-freq-pred/predictions.jsonl' ) 76 | parser.add_argument('--test-file', type=str,default='data/ace/zs-freq-10/test.oneie.json') 77 | parser.add_argument('--coref-file', type=str) 78 | parser.add_argument('--head-only', action='store_true') 79 | parser.add_argument('--coref', action='store_true') 80 | parser.add_argument('--dataset',type=str, default='ACE', choices=['ACE', 'KAIROS','AIDA']) 81 | args = parser.parse_args() 82 | 83 | 84 | ontology_dict = load_ontology(dataset=args.dataset) 85 | 86 | if args.dataset == 'KAIROS' and args.coref and not args.coref_file: 87 | print('coreference file needed for the KAIROS dataset.') 88 | raise ValueError 89 | if args.dataset == 'AIDA' and args.coref: 90 | raise NotImplementedError 91 | 92 | examples = {} 93 | doc2ex = defaultdict(list) # a document contains multiple events 94 | with open(args.gen_file,'r') as f: 95 | for lidx, line in enumerate(f): # this solution relies on keeping the exact same order 96 | pred = json.loads(line.strip()) 97 | examples[lidx] = { 98 | 'predicted': pred['predicted'], 99 | 'gold': pred['gold'], 100 | 'doc_id': pred['doc_key'] 101 | } 102 | doc2ex[pred['doc_key']].append(lidx) 103 | 104 | with open(args.test_file, 'r') as f: 105 | for line in f: 106 | doc = json.loads(line.strip()) 107 | if 'sent_id' in doc.keys(): 108 | doc_id = doc['sent_id'] 109 | 
# print('evaluating on sentence level') 110 | else: 111 | doc_id = doc['doc_id'] 112 | # print('evaluating on document level') 113 | for idx, eid in enumerate(doc2ex[doc_id]): 114 | examples[eid]['tokens'] = doc['tokens'] 115 | examples[eid]['event'] = doc['event_mentions'][idx] 116 | examples[eid]['entity_mentions'] = doc['entity_mentions'] 117 | 118 | coref_mapping = defaultdict(dict) # span to canonical entity_id mapping for each doc 119 | if args.coref: 120 | if args.dataset == 'KAIROS' and args.coref_file: 121 | with open(args.coref_file, 'r') as f, open(args.test_file, 'r') as test_reader: 122 | for line, test_line in zip(f, test_reader): 123 | coref_ex = json.loads(line) 124 | ex = json.loads(test_line) 125 | doc_id = coref_ex['doc_key'] 126 | 127 | for cluster, name in zip(coref_ex['clusters'], coref_ex['informative_mentions']): 128 | canonical = cluster[0] 129 | for ent_id in cluster: 130 | ent_span = get_entity_span(ex, ent_id) 131 | ent_span = (ent_span[0], ent_span[1]-1) 132 | coref_mapping[doc_id][ent_span] = canonical 133 | # this does not include singleton clusters 134 | else: 135 | # for the ACE dataset 136 | with open(args.test_file) as f: 137 | for line in f: 138 | doc=json.loads(line.strip()) 139 | doc_id = doc['sent_id'] 140 | for entity in doc['entity_mentions']: 141 | mention_id = entity['id'] 142 | ent_id = '-'.join(mention_id.split('-')[:-1]) 143 | coref_mapping[doc_id][(entity['start'], entity['end']-1)] = ent_id # all indexes are inclusive 144 | 145 | 146 | 147 | pred_arg_num =0 148 | gold_arg_num =0 149 | arg_idn_num =0 150 | arg_class_num =0 151 | 152 | arg_idn_coref_num =0 153 | arg_class_coref_num =0 154 | 155 | for ex in tqdm(examples.values()): 156 | context_words = ex['tokens'] 157 | doc_id = ex['doc_id'] 158 | doc = None 159 | if args.head_only: 160 | doc = nlp(' '.join(context_words)) 161 | 162 | # get template 163 | evt_type = ex['event']['event_type'] 164 | 165 | if evt_type not in ontology_dict: 166 | continue 167 | template = ontology_dict[evt_type]['template'] 168 | # extract argument text 169 | predicted_args = extract_args_from_template(ex,template, ontology_dict) 170 | # get trigger 171 | # extract argument span 172 | trigger_start = ex['event']['trigger']['start'] 173 | trigger_end = ex['event']['trigger']['end'] 174 | 175 | predicted_set = set() 176 | for argname in predicted_args: 177 | for entity in predicted_args[argname]:# this argument span is inclusive, FIXME: this might be problematic 178 | arg_span = find_arg_span(entity, context_words, 179 | trigger_start, trigger_end, head_only=args.head_only, doc=doc) 180 | 181 | if arg_span: # None means the predicted span is a hallucination 182 | 183 | predicted_set.add((arg_span[0], arg_span[1], evt_type, argname)) 184 | 185 | else: 186 | new_entity = [] 187 | for w in entity: 188 | if w == 'and' and len(new_entity) >0: 189 | arg_span = find_arg_span(new_entity, context_words, trigger_start, trigger_end, 190 | head_only=args.head_only, doc=doc) 191 | if arg_span: predicted_set.add((arg_span[0], arg_span[1], evt_type, argname)) 192 | new_entity = [] 193 | else: 194 | new_entity.append(w) 195 | 196 | if len(new_entity) >0: # last entity 197 | arg_span = find_arg_span(new_entity, context_words, trigger_start, trigger_end, 198 | head_only=args.head_only, doc=doc) 199 | if arg_span: predicted_set.add((arg_span[0], arg_span[1], evt_type, argname)) 200 | 201 | 202 | # get gold spans 203 | gold_set = set() 204 | gold_canonical_set = set() # set of canonical mention ids, singleton mentions will not be here 205 | for arg in 
ex['event']['arguments']: 206 | argname = arg['role'] 207 | entity_id = arg['entity_id'] 208 | span = get_entity_span(ex, entity_id) 209 | span = (span[0], span[1]-1) 210 | span = clean_span(ex, span) 211 | # clean up span by removing `a` `the` 212 | if args.head_only and span[0]!=span[1]: 213 | span = find_head(span[0], span[1], doc=doc) 214 | 215 | gold_set.add((span[0], span[1], evt_type, argname)) 216 | if args.coref: 217 | if span in coref_mapping[doc_id]: 218 | canonical_id = coref_mapping[doc_id][span] 219 | gold_canonical_set.add((canonical_id, evt_type, argname)) 220 | 221 | 222 | pred_arg_num += len(predicted_set) 223 | gold_arg_num += len(gold_set) 224 | # check matches 225 | for pred_arg in predicted_set: 226 | arg_start, arg_end, event_type, role = pred_arg 227 | gold_idn = {item for item in gold_set 228 | if item[0] == arg_start and item[1] == arg_end 229 | and item[2] == event_type} 230 | if gold_idn: 231 | arg_idn_num += 1 232 | gold_class = {item for item in gold_idn if item[-1] == role} 233 | if gold_class: 234 | arg_class_num += 1 235 | elif args.coref:# check coref matches 236 | arg_start, arg_end, event_type, role = pred_arg 237 | span = (arg_start, arg_end) 238 | if span in coref_mapping[doc_id]: 239 | canonical_id = coref_mapping[doc_id][span] 240 | gold_idn_coref = {item for item in gold_canonical_set 241 | if item[0] == canonical_id and item[1] == event_type} 242 | if gold_idn_coref: 243 | arg_idn_coref_num +=1 244 | gold_class_coref = {item for item in gold_idn_coref 245 | if item[2] == role} 246 | if gold_class_coref: 247 | arg_class_coref_num +=1 248 | 249 | 250 | 251 | if args.head_only: 252 | print('Evaluation by matching head words only....') 253 | 254 | 255 | role_id_prec, role_id_rec, role_id_f = compute_f1( 256 | pred_arg_num, gold_arg_num, arg_idn_num) 257 | role_prec, role_rec, role_f = compute_f1( 258 | pred_arg_num, gold_arg_num, arg_class_num) 259 | 260 | 261 | print('Role identification: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format( 262 | role_id_prec * 100.0, role_id_rec * 100.0, role_id_f * 100.0)) 263 | print('Role: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format( 264 | role_prec * 100.0, role_rec * 100.0, role_f * 100.0)) 265 | 266 | if args.coref: 267 | role_id_prec, role_id_rec, role_id_f = compute_f1( 268 | pred_arg_num, gold_arg_num, arg_idn_num + arg_idn_coref_num) 269 | role_prec, role_rec, role_f = compute_f1( 270 | pred_arg_num, gold_arg_num, arg_class_num + arg_class_coref_num) 271 | 272 | 273 | print('Coref Role identification: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format( 274 | role_id_prec * 100.0, role_id_rec * 100.0, role_id_f * 100.0)) 275 | print('Coref Role: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format( 276 | role_prec * 100.0, role_rec * 100.0, role_f * 100.0)) 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | -------------------------------------------------------------------------------- /src/genie/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import spacy 3 | from spacy.tokens import Doc 4 | PRONOUN_FILE='pronoun_list.txt' 5 | pronoun_set = set() 6 | with open(PRONOUN_FILE, 'r') as f: 7 | for line in f: 8 | pronoun_set.add(line.strip()) 9 | 10 | 11 | def check_pronoun(text): 12 | if text.lower() in pronoun_set: 13 | return True 14 | else: 15 | return False 16 | 17 | 18 | def clean_mention(text): 19 | ''' 20 | Clean up a mention by removing 'a', 'an', 'the' prefixes. 
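For example, clean_mention('the White House') returns 'White House'; text without such a prefix is returned unchanged.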
21 | ''' 22 | prefixes = ['the ', 'The ', 'an ', 'An ', 'a ', 'A '] 23 | for prefix in prefixes: 24 | if text.startswith(prefix): 25 | return text[len(prefix):] 26 | return text 27 | 28 | 29 | def safe_div(num, denom): 30 | if denom > 0: 31 | return num / denom 32 | else: 33 | return 0 34 | 35 | def compute_f1(predicted, gold, matched): 36 | precision = safe_div(matched, predicted) 37 | recall = safe_div(matched, gold) 38 | f1 = safe_div(2 * precision * recall, precision + recall) 39 | return precision, recall, f1 40 | 41 | class WhitespaceTokenizer: 42 | def __init__(self, vocab): 43 | self.vocab = vocab 44 | 45 | def __call__(self, text): 46 | words = text.split(" ") 47 | return Doc(self.vocab, words=words) 48 | 49 | def find_head(arg_start, arg_end, doc): 50 | cur_i = arg_start 51 | while doc[cur_i].head.i >= arg_start and doc[cur_i].head.i <=arg_end: 52 | if doc[cur_i].head.i == cur_i: 53 | # self is the head 54 | break 55 | else: 56 | cur_i = doc[cur_i].head.i 57 | 58 | arg_head = cur_i 59 | 60 | return (arg_head, arg_head) 61 | 62 | 63 | def load_ontology(dataset, ontology_file=None): 64 | ''' 65 | Read ontology file for event to argument mapping. 66 | ''' 67 | ontology_dict ={} 68 | if not ontology_file: # use the default file path 69 | if not dataset: 70 | raise ValueError 71 | with open('event_role_{}.json'.format(dataset),'r') as f: 72 | ontology_dict = json.load(f) 73 | else: 74 | with open(ontology_file,'r') as f: 75 | ontology_dict = json.load(f) 76 | 77 | for evt_name, evt_dict in ontology_dict.items(): 78 | for i, argname in enumerate(evt_dict['roles']): 79 | evt_dict['arg{}'.format(i+1)] = argname 80 | # argname -> role is not a one-to-one mapping 81 | if argname in evt_dict: 82 | evt_dict[argname].append('arg{}'.format(i+1)) 83 | else: 84 | evt_dict[argname] = ['arg{}'.format(i+1)] 85 | 86 | return ontology_dict 87 | 88 | def find_arg_span(arg, context_words, trigger_start, trigger_end, head_only=False, doc=None): 89 | match = None 90 | arg_len = len(arg) 91 | min_dis = len(context_words) # minimum distance to trigger 92 | for i, w in enumerate(context_words): 93 | if context_words[i:i+arg_len] == arg: 94 | if i < trigger_start: 95 | dis = abs(trigger_start-i-arg_len) 96 | else: 97 | dis = abs(i-trigger_end) 98 | if dis< min_dis: 99 | match = (i, i+arg_len-1) 100 | min_dis = dis 101 | 102 | if match and head_only: 103 | assert(doc!=None) 104 | match = find_head(match[0], match[1], doc) 105 | return match 106 | 107 | def get_entity_span(ex, entity_id): 108 | for ent in ex['entity_mentions']: 109 | if ent['id'] == entity_id: 110 | return (ent['start'], ent['end']) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import timeit 6 | from datetime import datetime 7 | 8 | import torch 9 | import wandb 10 | from pytorch_lightning import Trainer 11 | from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping, ModelCheckpoint 12 | from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger 13 | from pytorch_lightning.utilities.seed import seed_everything 14 | 15 | 16 | 17 | from src.genie.data_module import RAMSDataModule 18 | from src.genie.ACE_data_module import ACEDataModule 19 | from src.genie.KAIROS_data_module import KAIROSDataModule 20 | from src.genie.model import GenIEModel 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def main(): 27 
| parser = argparse.ArgumentParser() 28 | 29 | # Required parameters 30 | parser.add_argument( 31 | "--model", 32 | type=str, 33 | required=True, 34 | choices=['gen','constrained-gen'] 35 | ) 36 | parser.add_argument( 37 | "--dataset", 38 | type=str, 39 | required=True, 40 | choices=['RAMS', 'ACE', 'KAIROS'] 41 | ) 42 | parser.add_argument('--tmp_dir', type=str) 43 | parser.add_argument( 44 | "--ckpt_name", 45 | default=None, 46 | type=str, 47 | help="The output directory where the model checkpoints and predictions will be written.", 48 | ) 49 | parser.add_argument( 50 | "--load_ckpt", 51 | default=None, 52 | type=str, 53 | ) 54 | parser.add_argument( 55 | "--train_file", 56 | default=None, 57 | type=str, 58 | help="The input training file. If a data dir is specified, will look for the file there" 59 | + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", 60 | ) 61 | parser.add_argument( 62 | "--val_file", 63 | default=None, 64 | type=str, 65 | help="The input evaluation file. If a data dir is specified, will look for the file there" 66 | + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", 67 | ) 68 | parser.add_argument( 69 | '--test_file', 70 | type=str, 71 | default=None, 72 | ) 73 | parser.add_argument('--input_dir', type=str, default=None) 74 | parser.add_argument('--coref_dir', type=str, default='data/kairos/coref_outputs') 75 | parser.add_argument('--use_info', action='store_true', default=False, help='use informative mentions instead of the nearest mention.') 76 | parser.add_argument('--mark_trigger', action='store_true') 77 | parser.add_argument('--sample-gen', action='store_true', help='Do sampling when generation.') 78 | parser.add_argument("--train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.") 79 | parser.add_argument( 80 | "--eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation." 81 | ) 82 | parser.add_argument( 83 | "--eval_only", action="store_true", 84 | ) 85 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 86 | parser.add_argument( 87 | "--accumulate_grad_batches", 88 | type=int, 89 | default=1, 90 | help="Number of updates steps to accumulate before performing a backward/update pass.", 91 | ) 92 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 93 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 94 | parser.add_argument("--gradient_clip_val", default=1.0, type=float, help="Max gradient norm.") 95 | parser.add_argument( 96 | "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform." 97 | ) 98 | parser.add_argument( 99 | "--max_steps", 100 | default=-1, 101 | type=int, 102 | help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.", 103 | ) 104 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 105 | 106 | parser.add_argument("--gpus", default=-1, help='-1 means train on all gpus') 107 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 108 | parser.add_argument( 109 | "--fp16", 110 | action="store_true", 111 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 112 | ) 113 | parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features") 114 | args = parser.parse_args() 115 | 116 | # Setup logging 117 | logging.basicConfig( 118 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 119 | datefmt="%m/%d/%Y %H:%M:%S", 120 | level=logging.INFO, 121 | ) 122 | # Set seed 123 | seed_everything(args.seed) 124 | 125 | logger.info("Training/evaluation parameters %s", args) 126 | 127 | 128 | if not args.ckpt_name: 129 | d = datetime.now() 130 | time_str = d.strftime('%m-%dT%H%M') 131 | args.ckpt_name = '{}_{}lr{}_{}'.format(args.model, args.train_batch_size * args.accumulate_grad_batches, 132 | args.learning_rate, time_str) 133 | 134 | 135 | args.ckpt_dir = os.path.join(f'./checkpoints/{args.ckpt_name}') 136 | 137 | os.makedirs(args.ckpt_dir) 138 | 139 | checkpoint_callback = ModelCheckpoint( 140 | dirpath=args.ckpt_dir, 141 | save_top_k=2, 142 | monitor='val/loss', 143 | mode='min', 144 | save_weights_only=True, 145 | filename='{epoch}', # this cannot contain slashes 146 | 147 | ) 148 | 149 | 150 | 151 | 152 | lr_logger = LearningRateMonitor() 153 | tb_logger = TensorBoardLogger('logs/') 154 | 155 | model = GenIEModel(args) 156 | if args.dataset == 'RAMS': 157 | dm = RAMSDataModule(args) 158 | elif args.dataset == 'ACE': 159 | dm = ACEDataModule(args) 160 | elif args.dataset == 'KAIROS': 161 | dm = KAIROSDataModule(args) 162 | 163 | 164 | 165 | if args.max_steps < 0 : 166 | args.max_epochs = args.min_epochs = args.num_train_epochs 167 | 168 | 169 | 170 | trainer = Trainer( 171 | logger=tb_logger, 172 | min_epochs=args.num_train_epochs, 173 | max_epochs=args.num_train_epochs, 174 | gpus=args.gpus, 175 | checkpoint_callback=checkpoint_callback, 176 | accumulate_grad_batches=args.accumulate_grad_batches, 177 | gradient_clip_val=args.gradient_clip_val, 178 | num_sanity_val_steps=0, 179 | val_check_interval=0.5, # use float to check every n epochs 180 | precision=16 if args.fp16 else 32, 181 | callbacks = [lr_logger, ], 182 | 183 | ) 184 | 185 | if args.load_ckpt: 186 | model.load_state_dict(torch.load(args.load_ckpt,map_location=model.device)['state_dict']) 187 | 188 | if args.eval_only: 189 | dm.setup('test') 190 | trainer.test(model, datamodule=dm) #also loads training dataloader 191 | else: 192 | dm.setup('fit') 193 | trainer.fit(model, dm) 194 | 195 | 196 | 197 | 198 | if __name__ == "__main__": 199 | main() -------------------------------------------------------------------------------- /viz/visualize_output_KAIROS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from copy import deepcopy 5 | import spacy 6 | from spacy import displacy 7 | import re 8 | from collections import defaultdict 9 | 10 | def find_head(arg_start, arg_end, doc): 11 | cur_i = arg_start 12 | while doc[cur_i].head.i >= arg_start and doc[cur_i].head.i <=arg_end: 13 | if doc[cur_i].head.i == cur_i: 14 | # self is the head 15 | break 16 | else: 17 | cur_i 
= doc[cur_i].head.i 18 | 19 | arg_head = cur_i 20 | 21 | return (arg_head, arg_head) 22 | 23 | def extract_args_from_template(predicted, template, ontology_dict, evt_type): 24 | # extract argument text 25 | template_words = template.strip().split() 26 | predicted_words = predicted.strip().split() 27 | predicted_args = defaultdict(list) # argname -> List of text 28 | t_ptr= 0 29 | p_ptr= 0 30 | 31 | while t_ptr < len(template_words) and p_ptr < len(predicted_words): 32 | if re.match(r'<(arg\d+)>', template_words[t_ptr]): 33 | m = re.match(r'<(arg\d+)>', template_words[t_ptr]) 34 | arg_num = m.group(1) 35 | arg_name = ontology_dict[evt_type][arg_num] 36 | 37 | if predicted_words[p_ptr] == '<arg>': 38 | # missing argument 39 | p_ptr +=1 40 | t_ptr +=1 41 | else: 42 | arg_start = p_ptr 43 | while (p_ptr < len(predicted_words)) and ((t_ptr == len(template_words)-1) or (predicted_words[p_ptr] != template_words[t_ptr+1])): 44 | p_ptr+=1 45 | arg_text = predicted_words[arg_start:p_ptr] 46 | predicted_args[arg_name].append(arg_text) 47 | t_ptr+=1 48 | # aligned 49 | else: 50 | t_ptr+=1 51 | p_ptr+=1 52 | 53 | return dict(predicted_args) 54 | 55 | def find_arg_span(arg, context_words, trigger_start, trigger_end, head_only=False, doc=None): 56 | match = None 57 | arg_len = len(arg) 58 | min_dis = len(context_words) # minimum distance to trigger 59 | for i, w in enumerate(context_words): 60 | if context_words[i:i+arg_len] == arg: 61 | if i < trigger_start: 62 | dis = abs(trigger_start-i-arg_len) 63 | else: 64 | dis = abs(i-trigger_end) 65 | if dis< min_dis: 66 | match = (i, i+arg_len-1) 67 | min_dis = dis 68 | 69 | if match and head_only: 70 | assert(doc!=None) 71 | match = find_head(match[0], match[1], doc) 72 | return match 73 | 74 | def load_ontology(dataset): 75 | ''' 76 | Read ontology file for event to argument mapping. 
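The returned dict maps each event type to its JSON entry (including the 'template' string), augmented with a two-way index: ontology_dict[evt_type]['arg1'] gives the role name filling the first slot, and ontology_dict[evt_type][role_name] lists every slot that role fills, since a role may appear in more than one slot.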
77 | ''' 78 | ontology_dict ={} 79 | with open('event_role_{}.json'.format(dataset),'r') as f: 80 | ontology_dict = json.load(f) 81 | 82 | for evt_name, evt_dict in ontology_dict.items(): 83 | for i, argname in enumerate(evt_dict['roles']): 84 | evt_dict['arg{}'.format(i+1)] = argname 85 | # argname -> role is not a one-to-one mapping 86 | if argname in evt_dict: 87 | evt_dict[argname].append('arg{}'.format(i+1)) 88 | else: 89 | evt_dict[argname] = ['arg{}'.format(i+1)] 90 | 91 | return ontology_dict 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument('--result-file',type=str, default='checkpoints/gen-KAIROS-pointer-pred/predictions.jsonl') 96 | parser.add_argument('--test-file', type=str, default='data/kairos/test.jsonl') 97 | parser.add_argument('--gold', action='store_true') 98 | args = parser.parse_args() 99 | 100 | ontology_dict = load_ontology('KAIROS') 101 | 102 | render_dicts = [] 103 | 104 | reader= open(args.result_file, 'r') 105 | 106 | with open(args.test_file,'r') as f: 107 | for line in f: 108 | doc = json.loads(line) 109 | # use sent_id for ACE 110 | context_words = doc['tokens'] 111 | render_dict = { 112 | "text":' '.join(context_words), 113 | "ents": [], 114 | "title": '{}_gold'.format(doc['doc_id']) if args.gold else doc['doc_id'], 115 | 116 | } 117 | word2char = {} # word index to start, end char index (end is not inclusive) 118 | ptr =0 119 | for idx, w in enumerate(context_words): 120 | word2char[idx] = (ptr, ptr+ len(w)) 121 | ptr = word2char[idx][1] +1 122 | 123 | links = [] # (start_word, end_word, label) 124 | for eidx, e in enumerate(doc['event_mentions']): 125 | predicted = json.loads(reader.readline()) 126 | filled_template = predicted['predicted'] 127 | evt_type = e['event_type'] 128 | label = 'E{}-{}'.format(eidx, e['event_type']) 129 | trigger_start= e['trigger']['start'] 130 | trigger_end = e['trigger']['end'] -1 131 | trigger_tup = (trigger_start, trigger_end, label) 132 | links.append(trigger_tup) 133 | if args.gold: 134 | # use gold arguments 135 | for arg in e['arguments']: 136 | label = 'E{}-{}'.format(eidx, arg['role']) 137 | ent_id = arg['entity_id'] 138 | # get entity span 139 | matched_ent = [entity for entity in doc['entity_mentions'] if entity['id'] == ent_id][0] 140 | arg_start = matched_ent['start'] 141 | arg_end = matched_ent['end'] -1 142 | links.append((arg_start, arg_end, label)) 143 | else: # use predicted arguments 144 | template = ontology_dict[evt_type]['template'] 145 | # extract argument text 146 | predicted_args = extract_args_from_template(filled_template,template, ontology_dict, evt_type) 147 | # get trigger 148 | # extract argument span 149 | for argname in predicted_args: 150 | for argtext in predicted_args[argname]: 151 | arg_span = find_arg_span(argtext, context_words, 152 | trigger_start, trigger_end, head_only=False, doc=None) 153 | if arg_span: # None means the predicted span is a hallucination 154 | label = 'E{}-{}'.format(eidx, argname) 155 | links.append((arg_span[0], arg_span[1], label)) 156 | 157 | sorted_links = sorted(links, key=lambda x: x[0]) # sort by start idx 158 | 159 | for tup in sorted_links: 160 | arg_start, arg_end, arg_name = tup 161 | label = arg_name 162 | render_dict["ents"].append({ 163 | "start": word2char[arg_start][0], 164 | "end": word2char[arg_end][1], 165 | "label": label, 166 | }) 167 | render_dicts.append(render_dict) 168 | 169 | 170 | 171 | 172 | file_name = args.result_file.split('.')[0] 173 | if args.gold: 174 | file_name += '.gold' 175 | 176 | html = 
displacy.render(render_dicts, style="ent", manual=True, page=True) 177 | 178 | with open('{}.html'.format(file_name), 'w') as f: 179 | f.write(html) 180 | 181 | -------------------------------------------------------------------------------- /viz/visualize_output_RAMS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from copy import deepcopy 5 | import spacy 6 | from spacy import displacy 7 | import re 8 | 9 | 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--result-file',type=str, default='gen-trigger-pred-output.jsonl') 14 | parser.add_argument('--gold', action='store_true') 15 | args = parser.parse_args() 16 | 17 | render_dicts = [] 18 | 19 | 20 | 21 | with open(args.result_file, 'r') as f: 22 | for line in f: 23 | ex = json.loads(line.strip()) 24 | title = ex['doc_key'] 25 | context_words = [w for sent in ex['sentences'] for w in sent ] 26 | 27 | 28 | render_dict = { 29 | "text":' '.join(context_words), 30 | "ents": [], 31 | "title": '{}_gold'.format(ex['doc_key']) if args.gold else ex['doc_key'], 32 | } 33 | 34 | word2char = {} # word index to start, end char index (end is not inclusive) 35 | ptr =0 36 | for idx, w in enumerate(context_words): 37 | word2char[idx] = (ptr, ptr+ len(w)) 38 | ptr = word2char[idx][1] +1 39 | 40 | if args.gold: 41 | links = ex['ref_evt_links'] 42 | else: 43 | links = ex['gold_evt_links'] 44 | 45 | tmp = ex['evt_triggers'][0] 46 | trigger_start = tmp[0] 47 | trigger_end = tmp[1] 48 | trigger_type = tmp[2][0][0] 49 | 50 | links.append([(trigger_start, trigger_end), (trigger_start, trigger_end), trigger_type]) 51 | 52 | sorted_links = sorted(links, key=lambda x: x[1][0]) 53 | 54 | for tup in sorted_links: 55 | trigger_span, arg_span, arg_name = tup 56 | m = re.match(r'evt\d+arg\d+(\w+)', arg_name) 57 | if m: 58 | label = m.group(1) 59 | else: 60 | label = arg_name 61 | render_dict["ents"].append({ 62 | "start": word2char[arg_span[0]][0], 63 | "end": word2char[arg_span[1]][1], 64 | "label": label, 65 | }) 66 | 67 | 68 | 69 | render_dicts.append(render_dict) 70 | 71 | 72 | # ex = [{"text": "But Google is starting from behind.", 73 | # "ents": [{"start": 4, "end": 10, "label": "ORG"}], 74 | # "title": "doc1"}, 75 | # {"text": "But Google is starting from behind.", 76 | # "ents": [{"start": 4, "end": 10, "label": "ORG"}], 77 | # "title": "doc2"}, 78 | 79 | # ] 80 | 81 | file_name = args.result_file.split('.')[0] 82 | if args.gold: 83 | file_name += '.gold' 84 | 85 | html = displacy.render(render_dicts, style="ent", manual=True, page=True) 86 | 87 | with open('{}.html'.format(file_name), 'w') as f: 88 | f.write(html) 89 | 90 | --------------------------------------------------------------------------------
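For reference, a minimal, self-contained example of the manual-entity input format that both visualization scripts assemble for displacy. The text, labels, and output file name are toy values, not drawn from any dataset:

from spacy import displacy

render_dicts = [{
    "text": "The rebels attacked the convoy .",
    "ents": [
        {"start": 4, "end": 10, "label": "E0-conflict.attack"},  # 'rebels'
        {"start": 20, "end": 30, "label": "E0-target"},          # 'the convoy'
    ],
    "title": "toy_doc",
}]

html = displacy.render(render_dicts, style="ent", manual=True, page=True)
with open("toy_doc.html", "w") as f:
    f.write(html)

Each entry in "ents" uses character offsets into "text", which is why the scripts first build the word-to-character map word2char before emitting spans.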