├── .gitattributes ├── .gitignore ├── LICENSE ├── data ├── pid2name.json ├── pubmed_unsupervised.json ├── train_wiki.json ├── val_nyt.json ├── val_pubmed.json ├── val_semeval.json └── val_wiki.json ├── download_pretrain.sh ├── fewshot_re_kit ├── __init__.py ├── data_loader.py ├── framework.py ├── network │ ├── __init__.py │ ├── embedding.py │ └── encoder.py ├── old_data_loader.py └── sentence_encoder.py ├── models ├── __init__.py ├── d.py ├── gnn.py ├── gnn_iclr.py ├── metanet.py ├── mtb.py ├── pair.py ├── proto.py ├── proto_norm.py ├── siamese.py └── snail.py ├── paper ├── fewrel1.0.pdf ├── fewrel1_appendix.pdf └── fewrel2.0.pdf ├── readme.md └── train_demo.py /.gitattributes: -------------------------------------------------------------------------------- 1 | pretrain filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # folders 2 | checkpoint 3 | test_result 4 | 5 | # files 6 | *.pyc 7 | *.swp 8 | *.tar 9 | *.sh 10 | sbatch* 11 | *.ipynb 12 | 13 | # data 14 | pretrain/glove 15 | pretrain/bert-base-uncased 16 | 17 | # virtualenv 18 | .virtual 19 | 20 | # result 21 | result* 22 | 23 | # test data 24 | data/test_wiki.json 25 | data/test_pubmed.json 26 | data/train_wiki+pubmed.json 27 | data/pubmed.json 28 | 29 | # tmp file 30 | pretrain.tar 31 | 32 | # mac cache 33 | .DS_Store 34 | 35 | # editor cache 36 | .idea 37 | .vscode 38 | .venv 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 THUNLP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /data/pid2name.json: -------------------------------------------------------------------------------- 1 | {"P2384": ["statement describes", "formalization of the statement contains a bound variable in this class"], "P2388": ["office held by head of the organization", "position of the head of this item"], "P2389": ["organization directed from the office or person", "No description defined"], "P2634": ["sitter", "person who posed during the creation of a work, whether or not that person is eventually depicted as oneself"], "P607": ["conflict", "battles, wars or other military engagements in which the person or item participated"], "P608": ["exhibition history", "exhibitions where the item is or was displayed"], "P609": ["terminus location", "location of the terminus of a linear feature"], "P237": ["coat of arms", "subject's coat of arms"], "P2925": ["domain of saint or deity", "domain(s) which this saint or deity controls or protects"], "P927": ["anatomical location", "where in the body does this anatomical feature lie"], "P2319": ["elector", 
"people or other entities which are qualified to participate in the subject election"], "P2318": ["debut participant", "participant for whom this is their debut appearance in a series of events"], "P560": ["direction", "qualifier to indicate the direction of the parent statement's value relative to the subject item"], "P925": ["presynaptic connection", "neuron connects on its presynaptic end to"], "P3967": ["final event", "final event of a competition"], "P3966": ["programming paradigm", "programming paradigm in which a programming language is classified"], "P3320": ["board member", "member(s) of the board for the organization"], "P3969": ["signed form", "manually coded form of this language"], "P1462": ["standards body", "organisation that published or maintains the standard governing an item"], "P4185": ["iconographic symbol", "identifying element typically depicted as accompanying or worn by this religious figure, hero, fictional or historical character"], "P127": ["owned by", "owner of the subject"], "P126": ["maintained by", "person or organization in charge of keeping the subject (for instance an infrastructure) in functioning order"], "P88": ["commissioned by", "person or organization that commissioned this work"], "P123": ["publisher", "organization or person responsible for publishing books, periodicals, games or software"], "P122": ["basic form of government", "subject's government"], "P121": ["item operated", "equipment, installation or service operated by the subject"], "P81": ["connecting line", "railway or public transport line(s) subject is directly connected to; use as a qualifier to P197"], "P4345": ["excavation director", "person leading the archaeological or anthropological excavation at this site"], "P4614": ["drainage basin", "area where precipitation collects and drains off into a common outlet, such as into a river, bay, or other body of water"], "P86": ["composer", "person(s) who wrote the music [for lyricist, use \"lyrics by\" (P676)]"], 
"P129": ["physically interacts with", "physical entity that the subject interacts with"], "P128": ["regulates (molecular biology)", "process regulated by a protein or RNA in molecular biology"], "P2541": ["operating area", "area this organisation operates in, serves or has responsibility for"], "P2546": ["sidekick of", "close companion of a fictional character"], "P2545": ["bowling style", "type of bowling employed by a cricketer"], "P161": ["cast member", "actor in the subject production [use \"character role\" (P453) and/or \"name of the character role\" (P4633) as qualifiers] [use \"voice actor\" (P725) for voice-only role]"], "P1408": ["licensed to broadcast to", "place that a radio/TV station is licensed/required to broadcast to"], "P1018": ["language regulatory body", "regulatory body of a language"], "P518": ["applies to part", "part, aspect, or form of the item to which the claim applies"], "P516": ["powered by", "equipment or engine used to power the subject"], "P1403": ["original combination", "for animals: the combination (binomen or trinomen) where the species-group name used in this taxon name was first published"], "P514": ["interleaves with", "stratigraphic relation in which two units overlap each other marginally"], "P512": ["academic degree", "academic degree that the person holds"], "P1013": ["criterion used", "property by which a distinction or classification is made"], "P511": ["honorific prefix", "word or expression used before a name, in addressing or referring to a person"], "P2358": ["Roman praenomen", "standard part of the name of a Roman, link to items for specific Roman praenomen only."], "P2445": ["metasubclass of", "relation between two metaclasses: instances of this metaclass are likely to be subclasses of classes that are instances of the target metaclass"], "P3190": ["innervates", "anatomical structures innervated by this nerve"], "P1889": ["different from", "item that is different from another item, with which it is often 
confused"], "P1884": ["hair color", "person's hair color. Use P585 as qualifier if there's more than one value."], "P1885": ["cathedral", "principal church of a religious district"], "P1880": ["measured by", "scale by which a phenomenon is measured"], "P1881": ["list of characters", "Wikimedia page with the list of characters for this work"], "P1068": ["instruction set", "instruction set on which the processor architecture is based"], "P16": ["highway system", "system (or specific country specific road type) of which the highway is a part"], "P1706": ["together with", "qualifier to specify the item that this property is shared with"], "P19": ["place of birth", "most specific known (e.g. city instead of country, or hospital instead of city) birth location of a person, animal or fictional character"], "P4661": ["reservoir created", "reservoir created upstream of a dam by this dam"], "P1066": ["student of", "person who has taught this person"], "P4149": ["conjugate base", "species formed by losing a proton (H\u207a)"], "P1536": ["immediate cause of", "immediate cause of this effect"], "P1537": ["contributing factor of", "thing that is significantly influenced by this cause, but does not directly result from it. See 'Help:Modeling causes' for examples and discussion."], "P1535": ["used by", "item or concept that makes use of the subject (use sub-properties when appropriate)"], "P1532": ["country for sport", "country a person or a team represents when playing a sport"], "P1533": ["family name identical to this given name", "last name that is the same as a given first name. Use on items for given names"], "P1531": ["parent of this hybrid, breed, or cultivar", "No description defined"], "P505": ["general manager", "general manager of a sports team. If they are also an on-field manager use P286 instead"], "P3592": ["Saros cycle of eclipse", "No description defined"], "P468": ["dan/kyu rank", "rank system used in several board games (e.g. 
go, shogi, renju), martial arts (e.g. judo, kendo, wushu) and some other games"], "P469": ["lakes on river", "lakes, reservoirs, waterfalls the river flows through"], "P1080": ["from fictional universe", "subject's fictional entity is in the object narrative. See also P1441 (present in work) and P1445 (fictional universe described in)"], "P1429": ["has pet", "pet that a person owns"], "P462": ["color", "color of subject"], "P588": ["coolant", "substance used by the subject to dissipate excess thermal energy"], "P460": ["said to be the same as", "this item is said to be the same as that item, but the statement is disputed"], "P461": ["opposite of", "item that is the opposite of this item"], "P466": ["occupant", "a person or organization occupying property"], "P467": ["legislated by", "indicates that an act or bill was passed by a legislature. The value can be a particular session of the legislature"], "P912": ["has facility", "the subject item has this type of facility, e.g. toilet, car park"], "P770": ["cause of destruction", "item which caused the destruction of the subject item"], "P3349": ["designed to carry", "what the vehicle or class of vehicles is or was designed to carry"], "P1830": ["owner of", "entities owned by the subject"], "P868": ["foods traditionally associated", "foods usually during the ceremony or associated with a certain settlement"], "P4000": ["has fruit type", "morphology of the fruit of this taxon, as defined in botany"], "P647": ["drafted by", "which team the player was drafted by"], "P642": ["of", "qualifier stating that a statement applies within the scope of a particular item"], "P641": ["sport", "sport in which the subject participates or belongs to"], "P885": ["origin of the watercourse", "main source of a river or stream"], "P880": ["CPU", "central processing unit found within the subject item"], "P768": ["electoral district", "electoral district this person is representing, or of the office that is being contested. 
Use as qualifier for \"position held\" (P39) or \"office contested\" (P541)"], "P765": ["surface played on", "the surface on which the tournament is played"], "P767": ["contributor(s) to the creative work or subject", "person or organization that contributed to a subject: co-creator of a creative work or subject"], "P3448": ["stepparent", "subject has the object as their stepparent"], "P3447": ["mirrors data from", "the website automatically crawls or mirrors another website as its data source"], "P3113": ["does not have part", "expected part that the item does not have (for qualities, use P6477)"], "P3842": ["located in present-day administrative territorial entity", "the item was located in the territory of this present-day administrative unit; however the two did not at any point coexist in time"], "P169": ["chief executive officer", "highest-ranking corporate officer appointed as the CEO within an organization"], "P241": ["military branch", "branch to which this military unit, award, office, or person belongs, e.g. 
Royal Navy"], "P163": ["flag", "subject's flag"], "P1906": ["office held by head of state", "political office that is fulfilled by the head of state of this item"], "P248": ["stated in", "to be used in the references field to refer to the information document or database in which a claim is made; for qualifiers use P805"], "P167": ["structure replaced by", "the item which replaced this building or structure, at the same geographic location"], "P166": ["award received", "award or recognition received by a person, organisation or creative work"], "P2882": ["relegated", "competitor or team relegated at the end of competition"], "P2881": ["promoted", "competitor or team promoted at the end of competition"], "P1210": ["supercharger", "supercharger or turbocharger used by an engine"], "P2341": ["indigenous to", "area or ethnic group that a language, folk dance, cooking style, food or other cultural expression is found (or was originally found)"], "P2348": ["time period", "time period (historic period or era, sports season, theatre season, legislative period etc.) 
in which the subject occurred"], "P2679": ["author of foreword", "person who wrote the preface, foreword, or introduction of the book but who isn't an author of the rest of the book"], "P2155": ["solid solution series with", "the mineral forms a continous (true) or discontinous \"solid solution series\" with another mineral"], "P2157": ["lithography", "structure size (as advertised or listed) of a microchip"], "P2152": ["antiparticle", "particle with the same rest mass and opposite charges"], "P2159": ["solves", "problem that this algorithm or method solves"], "P2012": ["cuisine", "type of food served by a restaurant or restaurant chain"], "P2408": ["set in period", "historical, contemporary or future period the work is set in"], "P3938": ["named by", "person or organisation that coined the name; use as qualifier for P2561 and its subproperties"], "P3833": ["diaspora", "diaspora that a cultural group belongs to"], "P2822": ["by-product of", "chemical or industrial process which produces the item as a by-product"], "P3931": ["copyright holder", "person or organisation who holds the copyright of a work according to the Berne Convention"], "P4628": ["ICTV virus genome composition", "classification of viruses of the International Committee on Taxonomy of Viruses by the molecular composition of the virus genome (DNA, RNA, double or single stranded and translational polarity)"], "P4545": ["sexually homologous with", "body part that originates from the same tissue or cell during fetal development in the opposing sex"], "P1990": ["species kept", "taxa, preferably species, present at a zoo, botanical garden, collection, or other institution. 
NOT specific animals, not for any geographic location"], "P1620": ["plaintiff", "party who initiates a lawsuit"], "P2989": ["has grammatical case", "case used in this language"], "P58": ["screenwriter", "person(s) who wrote the script for subject item"], "P59": ["constellation", "the area of the celestial sphere of which the subject is a part (from a scientific standpoint, not an astrological one)"], "P1625": ["has melody", "this work has the melody of the following work"], "P54": ["member of sports team", "sports teams or clubs that the subject currently represents or formerly represented"], "P1999": ["UNESCO language status", "degree of endangerment of a language conferred by the UNESCO Atlas of World Languages in Danger"], "P57": ["director", "director(s) of film, TV-series, stageplay, video game or similar"], "P50": ["author", "main creator(s) of a written work (use on works, not humans); use P2093 when Wikidata item is unknown or does not exist"], "P53": ["family", "family, including dynasty and nobility houses. Not family name (use P734 for family name)."], "P545": ["torch lit by", "person that lit the torch at an event, like the Olympic Games"], "P547": ["commemorates", "what the place, monument, memorial, or holiday, commemorates"], "P546": ["docking port", "intended docking port for a spacecraft"], "P541": ["office contested", "title of office which election will determine the next holder of"], "P543": ["oath made by", "person(s) that made the oath at an event, like the Olympic Games"], "P542": ["officially opened by", "person that officially opened the event or place"], "P427": ["taxonomic type", "the type genus of this family (or subfamily, etc), or the type species of this genus (or subgenus, etc)"], "P1479": ["has contributing factor", "thing that significantly influenced, but did not directly cause, this outcome or effect. Used in conjunction with 'has cause' and 'has immediate cause'. 
See '[[Help:Modeling causes]]'."], "P425": ["field of this occupation", "activity corresponding to this occupation (use only for occupations - for people use Property:P101, for profession use P425, for companies use P452)"], "P930": ["type of electrification", "electrification system scheme and/or voltage"], "P931": ["place served by transport hub", "territorial entity or entities served by this transport hub (airport, train station, etc.)"], "P1574": ["exemplar of", "property for manuscripts, autographs, incunabula, distinct printed copies"], "P421": ["located in time zone", "time zone for this item"], "P3716": ["social classification", "social class as recognized in traditional or state law"], "P1660": ["has index case", "initial patient in the population of an epidemiological investigation"], "P2715": ["elected in", "qualifier for statements in position held to link the election in which a person gained a position from, or reelection in which the position was confirmed"], "P2652": ["partnership with", "commercial partnership between this organization and another organization or institution"], "P2650": ["interested in", "item of special or vested interest to this person or organisation"], "P3494": ["points classification", "No description defined"], "P3719": ["regulated by", "organization that acts as regulator of an activity, financial market, or stock exchange"], "P3491": ["muscle insertion", "the anatomic entity to which the end of a muscle is anchored"], "P3490": ["muscle origin", "the anatomic entity to which the beginning of a muscle is anchored"], "P1877": ["after a work by", "artist whose work strongly inspired/ was copied in this item"], "P1876": ["vessel", "vessel involved in this mission, voyage or event"], "P1875": ["represented by", "express a relationship between a subject and their agent"], "P3301": ["broadcast by", "the channel, network, website or service that broadcast this item over radio, TV or the Internet"], "P3300": ["musical conductor", 
"the person who directs a musical group, orchestra or chorus"], "P1050": ["medical condition", "any state relevant to the health of an organism, including diseases and positive conditions"], "P1269": ["facet of", "topic of which this item is an aspect, item that offers a broader perspective on the same topic"], "P1057": ["chromosome", "chromosome on which an entity is localized"], "P1056": ["product or material produced", "material or product produced by a government agency, business, industry, facility, or process"], "P941": ["inspired by", "work, human, place or event which inspired this creative work or fictional entity"], "P943": ["programmer", "the programmer that wrote the piece of software"], "P942": ["theme music", "the theme music/song used by the item"], "P945": ["allegiance", "the country (or other power) that the person, or organization, served"], "P944": ["Code of nomenclature", "the Code that governs the scientific name of this taxon"], "P1196": ["manner of death", "general circumstances of a person's death; e.g. natural causes, accident, suicide, homicide, pending investigation or special 'unknown value'. Use 'cause of death' (P509) for more immediate or underlying causes, e.g. 
heart attack, car accident"], "P720": ["asteroid spectral type", "spectral classifications of asteroids based on spectral shape, color, and albedo"], "P1372": ["binding of software library", "software library in another programming language provided by the subject software binding"], "P689": ["afflicts", "type of organism which a condition or disease afflicts"], "P725": ["voice actor", "performer of a spoken role in a creative work such as animation, video game, radio drama, or dubbing over [use \"character role\" (P453) as qualifier] [use \"cast member\" (P161) for live acting]"], "P2849": ["produced by", "links a biologic/biochemical entity with the entity it has been produced by"], "P1376": ["capital of", "country, state, department, canton or other administrative division of which the municipality is the governmental seat"], "P726": ["candidate", "person or party that is an option for an office in an election"], "P682": ["biological process", "is involved in the biological process"], "P3085": ["qualifies for event", "this event qualifies for that event"], "P680": ["molecular function", "represents gene ontology function annotations"], "P681": ["cell component", "component of the cell in which this item is present"], "P3080": ["game artist", "game artist(s) that produced art assets for a role-playing games, collectible card games, video game, etc."], "P3081": ["damaged", "physical items damaged by this event"], "P3082": ["destroyed", "physical items destroyed by this event"], "P2872": ["tourist office", "official tourist office of this destination"], "P4733": ["produced sound", "item for the sound generated by the subject, for instance the cry of an animal"], "P3403": ["coextensive with", "this item has the same boundary as the target item; area associated with (this) entity is identical with the area associated with that entity"], "P2975": ["host", "an organism harboring another organism or organisms on or in itself"], "P2974": ["habitat", "the natural 
environment in which an organism lives, or the physical environment that surrounds a species population"], "P834": ["train depot", "depot which serves this railway line, not public stations"], "P3019": ["railway signalling system", "type of signalling used on this railway line"], "P3018": ["located in protected area", "the protected area a place or geographical feature belongs to"], "P831": ["parent club", "parent club of this team"], "P833": ["interchange station", "station to which passengers can transfer to from this station, normally without extra expense"], "P1283": ["filmography", "item with a list of films a person has contributed to. Don't use to add film items. Instead, add actors as cast on items for films the participated in."], "P3015": ["backup or reserve team or crew", "team or crew that is kept ready to act as reserve"], "P1995": ["health specialty", "main specialty that diagnoses, prevent human illness, injury and other physical and mental impairments"], "P807": ["separated from", "subject was founded or started by separating from identified object"], "P2378": ["issued by", "organisation that issues or allocates an identifier"], "P2375": ["has superpartner", "partner particle, in supersymmetry; inverse of \"superpartner of\""], "P800": ["notable work", "notable scientific, artistic or literary work, or other work of significance among subject's works"], "P2376": ["superpartner of", "partner particle, in supersymmetry; inverse of \"has superpartner\""], "P206": ["located in or next to body of water", "sea, lake or river"], "P1304": ["central bank", "country's central bank"], "P205": ["basin country", "country that have drainage to/from or border the body of water"], "P750": ["distributor", "distributor of a creative work; distributor for a record label"], "P751": ["introduced feature", "feature introduced by this version of a product item"], "P1303": ["instrument", "musical instrument that a person plays"], "P1302": ["primary destinations", "major 
cities that a road serves"], "P371": ["presenter", "someone who takes a main role in presenting a radio or television program or a performing arts show"], "P611": ["religious order", "order of monks or nuns to which an individual or religious house belongs"], "P610": ["highest point", "point with highest elevation in a region, on a path, of a race"], "P208": ["executive body", "branch of government for the daily administration of the state"], "P209": ["highest judicial authority", "supreme judicial body within a country, administrative division, or other organization"], "P3275": ["storyboard artist", "person credited as the storyboard artist of this work"], "P1629": ["subject item of this property", "relationship represented by the property"], "P3279": ["statistical leader", "leader of a sports tournament in one of statistical qualities (points, assists, rebounds etc.). Don't use for overall winner. Use a qualifier to link to the item about the quality."], "P3975": ["secretary general", "leader of a political or international organization, sometimes below the chairperson (P488)"], "P1652": ["referee", "referee or umpire of a match"], "P1657": ["MPAA film rating", "US film classification administered by the Motion Picture Association of America"], "P1656": ["unveiled by", "person who unveils a statue, sculpture, memorial or plaque, etc."], "P1654": ["wing configuration", "configuration of wing(s) used by an aircraft"], "P2289": ["venous drainage", "vein draining the anatomical structure"], "P85": ["anthem", "subject's official anthem"], "P2283": ["uses", "item or concept used by the subject or in the operation (see also instrument [P1303] and armament [P520])"], "P84": ["architect", "person or architectural firm that designed this building"], "P2286": ["arterial supply", "arterial supply of an anatomical structure"], "P98": ["editor", "editor of a compiled work such as a book or a periodical (newspaper or an academic journal)"], "P1547": ["depends on software", 
"subject software depends on object software"], "P4379": ["youth wing", "group of younger people associated with this organization"], "P111": ["measured physical quantity", "value of a physical property expressed as number multiplied by a unit"], "P1950": ["second family name in Spanish name", "second (generally maternal) family name in Spanish names (do not use for other double barrelled names)"], "P1951": ["investor", "individual or organization which invests money in the item for the purpose of obtaining financial return on their investment"], "P114": ["airline alliance", "alliance the airline belongs to"], "P1546": ["motto", "description of the motto of the subject"], "P91": ["sexual orientation", "the sexual orientation of the person \u2014 use IF AND ONLY IF they have stated it themselves, unambiguously, or it has been widely agreed upon by historians after their death"], "P92": ["main regulatory text", "text setting the main rules by which the subject is regulated"], "P119": ["place of burial", "location of grave, resting place, place of ash-scattering, etc, (e.g. town/city or cemetery) for a person or animal. There may be several places: e.g. re-burials, cenotaphs, parts of body buried separately."], "P97": ["noble title", "titles held by the person"], "P2579": ["studied by", "subject is studied by this science or domain"], "P2578": ["studies", "subject item is the academic field studying the object item of this property"], "P2577": ["admissible rule in", "this logic inference rule is admissible in that logical system"], "P2575": ["measures", "physical quantity that this device measures"], "P509": ["cause of death", "underlying or immediate cause of death. Underlying cause (e.g. car accident, stomach cancer) preferred. Use 'manner of death' (P1196) for broadest category, e.g. 
natural causes, accident, homicide, suicide"], "P277": ["programming language", "the programming language(s) in which the software is developed"], "P1478": ["has immediate cause", "nearest, proximate thing that directly resulted in the subject as outcome or effect. Used in conjunction with 'has cause' (i.e. underlying cause) and 'has contributing factor'. See 'Help:Modeling causes'."], "P276": ["location", "location of the item, physical object or event is within. In case of an administrative entity use P131. In case of a distinct terrain feature use P706."], "P501": ["enclave within", "territory is entirely surrounded by the other (enclaved)"], "P1576": ["lifestyle", "typical way of life of an individual, group, or culture"], "P2695": ["type locality (geology)", "the locality where a particular rock type, stratigraphic unit or mineral species is defined from (can coincide but not the same as p189)"], "P1002": ["engine configuration", "configuration of an engine's cylinders"], "P504": ["home port", "home port of the vessel (if different from \"ship registry\"): For civilian ships, the primary port from which the ship operates. Port of registry \u2192P532 should be listed in \"Ship registry\". For warships, this will be the ship's assigned naval base"], "P1000": ["record held", "notable record achieved by a person or entity, include qualifiers for dates held"], "P1001": ["applies to jurisdiction", "the item (an institution, law, public office ...) 
or statement belongs to or has power over or applies to the value (a territorial jurisdiction: a country, state, municipality, ...)"], "P3610": ["fare zone", "fare zone that the station is in"], "P2597": ["Gram staining", "Gram stain type of a bacterial strain"], "P2596": ["culture", "human culture or people (or several cultures) associated with this item"], "P3902": ["had as last meal", "components of the last meal had by a person before death"], "P2453": ["nominee", "qualifier used with \u00abnominated for\u00bb to specify which person or organization was nominated"], "P636": ["route of administration", "path by which a drug, fluid, poison, or other substance is taken into the body"], "P25": ["mother", "female parent of the subject. For stepmother, use \"stepparent\" (P3448)"], "P27": ["country of citizenship", "the object is a country that recognizes the subject as its citizen"], "P26": ["spouse", "the subject has the object as their spouse (husband, wife, partner, etc.). Use \"partner\" (P451) for non-married companions"], "P21": ["sex or gender", "sexual identity of subject: male (Q6581097), female (Q6581072), intersex (Q1097630), transgender female (Q1052281), transgender male (Q2449503). Animals: male animal (Q44148), female animal (Q43445). Groups of same gender use \"subclass of\" (P279)"], "P20": ["place of death", "most specific known (e.g. city instead of country, or hospital instead of city) death location of a person, animal or fictional character"], "P22": ["father", "male parent of the subject. 
For stepfather, use \"stepparent\" (P3448)"], "P2058": ["depositor", "depositor/depositaries for the treaty"], "P4675": ["appears in the form of", "this fictional or mythical entity takes the form of that entity"], "P3712": ["objective of project or action", "desired result or outcome"], "P376": ["located on astronomical location", "astronomical body on which features or places are situated"], "P805": ["statement is subject of", "(qualifier only) item which describes the relation identified in this statement"], "P375": ["space launch vehicle", "type of rocket or other vehicle for launching subject payload into outer space"], "P479": ["input method", "input method or device used to interact with a software product"], "P802": ["student", "notable student(s) of an individual"], "P803": ["professorship", "professorship position held by this academic person"], "P176": ["manufacturer", "manufacturer or producer of this product"], "P2937": ["parliamentary term", "term of a parliament or any deliberative assembly"], "P2936": ["language used", "language widely used (spoken or written) in this place or at this event"], "P2935": ["connector", "connectors which the device has/supports"], "P1136": ["solved by", "person that solved a scientific question"], "P1137": ["fossil found in this unit", "fossils that are found in this stratigraphic unit"], "P3354": ["positive therapeutic predictor", "the presence of the genetic variant helps to predict response to a treatment"], "P3355": ["negative therapeutic predictor", "the presence of the genetic variant helps to predict no response or resistance to a treatment"], "P3092": ["film crew member", "member of the crew creating an audiovisual work, used for miscellaneous roles qualified with the job title when no specific property exists. 
Don't use if such a property is available: notably for cast member (P161), director (P57), etc."], "P3729": ["next lower rank", "lower rank or level in a ranked hierarchy like sport league, military ranks. If there are several possible, list each one and qualify with \"criterion used\" (P1013), avoid using ranks and date qualifiers. For sports leagues/taxa, use specific properties instead."], "P4791": ["commanded by", "commander of a military unit/army/security service, operation, etc."], "P4151": ["game mechanics", "constructs of rules or methods designed for interaction with the game state"], "P4792": ["dam", "construction impounding this watercourse or creating this reservoir"], "P1455": ["list of works", "link to the article with the works of a person"], "P971": ["category combines topics", "this category combines (intersects) these two or more topics"], "P972": ["catalog", "catalog for the item, or, as a qualifier of P528 \u2013 catalog for which the 'catalog code' is valid"], "P1434": ["takes place in fictional universe", "the subject is a work describing a fictional universe, i.e. 
whose plot occurs in this universe."], "P1433": ["published in", "larger work that a given work was published in, like a book, journal or music album"], "P1432": ["B-side (DEPRECATED)", "song/track which is the B-side of this single"], "P1431": ["executive producer", "executive producer of a movie or TV show"], "P517": ["interaction", "subset of the four fundamental forces (strong (Q11415), electromagnetic (Q849919), weak (Q11418), and gravitation (Q11412) with which a particle interacts"], "P870": ["instrumentation", "combination of musical instruments employed in a composition"], "P873": ["phase point", "phase point to describe critical point and triple point (see talk page for an example)"], "P872": ["printed by", "organization or person who printed the creative work (if different from \"publisher\")"], "P3091": ["mount", "creature ridden by the subject, for instance a horse"], "P655": ["translator", "person who adapts a creative work to another language"], "P654": ["direction relative to location", "qualifier for geographical locations to express relative direction"], "P658": ["tracklist", "songs contained in this item"], "P1227": ["astronomical filter", "passband used to isolate particular sections of the electromagnetic spectrum"], "P710": ["participant", "person, group of people or organization (object) that actively takes/took part in an event or process (subject). Preferably qualify with \"object has role\" (P3831). Use P1923 for participants that are teams."], "P1340": ["eye color", "color of the irises of a person's eyes"], "P1343": ["described by source", "dictionary, encyclopedia, etc. 
where this item is described"], "P3680": ["statement supported by", "entity that supports a given statement"], "P1347": ["military casualty classification", "allowed values: killed in action (Q210392), missing in action (Q2344557), died of wounds (Q16861372), prisoner of war (Q179637), killed in flight accident (Q16861407), others used in military casualty classification"], "P1346": ["winner", "winner of an event - do not use for awards (use P166 instead), nor for wars or battles"], "P488": ["chairperson", "presiding member of an organization, group or body"], "P489": ["currency symbol description", "item with description of currency symbol"], "P734": ["family name", "surname of a person"], "P783": ["hymenium type", "type of spore-bearing surface or that mushroom"], "P210": ["party chief representative", "chief representative of a party in an institution or an administrative unit (use qualifier to identify the party)"], "P483": ["recorded at", "studio or location where a musical composition was recorded"], "P485": ["archives at", "the institution holding the subject's archives"], "P3438": ["vehicle normally used", "vehicle the subject normally uses"], "P744": ["asteroid family", "population of asteroids that share similar proper orbital elements"], "P3433": ["biological variant of", "a variant of a physical biological entity (e.g., gene sequence, protein sequence, epigenetic mark)"], "P3437": ["people or cargo transported", "the type of passengers or cargo a vehicle actually carries/carried"], "P156": ["followed by", "immediately following item in a series of which the subject is a part [if the subject has been replaced, e.g. political offices, use \"replaced by\" (P1366)]"], "P157": ["killed by", "person who killed the subject"], "P155": ["follows", "immediately prior item in a series of which the subject is a part [if the subject has replaced the preceding item, e.g. 
political offices, use \"replaces\" (P1365)]"], "P150": ["contains administrative territorial entity", "(list of) direct subdivisions of an administrative territorial entity"], "P1878": ["Vox-ATypI classification", "classification for typefaces"], "P629": ["edition or translation of", "is an edition or translation of this entity"], "P4330": ["contains", "item or substance located within this item but not part of it"], "P397": ["parent astronomical body", "major astronomical body the item belongs to"], "P1910": ["decreased expression in", "indicates that a decreased expression of the subject gene is found in the object disease"], "P664": ["organizer", "person or institution organizing an event"], "P1913": ["gene duplication association with", "This property should link a gene and a disease due to a duplication"], "P398": ["child astronomical body", "minor body that belongs to the item"], "P399": ["companion of", "two or more astronomic bodies of the same type relating to each other"], "P2238": ["official symbol", "official symbol of administrative entities"], "P1318": ["proved by", "person who proved something"], "P2894": ["day of week", "day of the week on which this item occurs, applies to or is valid"], "P195": ["collection", "art, museum, archival, or bibliographic collection the subject is part of"], "P2237": ["None", "None"], "P2414": ["substrate of", "substrate that an enzyme acts upon to create a product"], "P1686": ["for work", "qualifier of award received (P166) to specify the work that an award was given to the creator for"], "P3094": ["develops from", "this class of items develops from another class of items"], "P1192": ["connecting service", "service stopping at a station"], "P3828": ["wears", "clothing or accessory worn on subject's body"], "P61": ["discoverer or inventor", "the entity who discovered, first described, invented, or developed this discovery or invention"], "P2095": ["co-driver", "rally team member who performs as a co-driver or 
co-pilot"], "P65": ["site of astronomical discovery", "the place where an astronomical object was discovered (observatory, satellite)"], "P66": ["ancestral home", "place of origin for ancestors of subject"], "P69": ["educated at", "educational institution attended by subject"], "P2098": ["substitute/deputy/replacement of office/officeholder", "function that serves as deputy/replacement of this function/office (scope/conditions vary depending on office)"], "P2522": ["victory", "competition or event won by the subject"], "P500": ["exclave of", "territory is legally or politically attached to a main territory with which it is not physically contiguous because of surrounding alien territory. It may also be an enclave."], "P1962": ["None", "None"], "P530": ["diplomatic relation", "diplomatic relations of the country"], "P532": ["port of registry", "ship's port of registry. This is generally painted on the ship's stern (for the \"home port\", see Property:P504)"], "P533": ["target", "target of a terrorist attack or military operation"], "P928": ["activating neurotransmitter", "which neurotransmitter activates the neuron"], "P538": ["fracturing", "type of fracture a crystal or mineral forms"], "P926": ["postsynaptic connection", "neuron connects on its postsynaptic end to"], "P4599": ["monomer of", "polymer composed of this monomer subunits"], "P924": ["possible treatment", "health treatment used to resolve or ameliorate a medical condition"], "P923": ["medical examinations", "examinations that might be used to diagnose the medical condition"], "P688": ["encodes", "the product of a gene (protein or RNA)"], "P437": ["distribution format", "method (or type) of distribution for the subject"], "P3764": ["pole position", "person, who starts race at first row (leader in the starting grid)"], "P3484": ["None", "None"], "P2176": ["drug used for treatment", "drug, procedure, or therapy that can be used to treat a medical condition"], "P4446": ["reward program", "reward program 
associated with the item"], "P1344": ["participant of", "event a person or an organization was/is a participant in, inverse of P710 or P1923"], "P1040": ["film editor", "Person who works with the raw footage, selecting shots and combining them into sequences to create a finished motion picture"], "P1046": ["discovery method", "way an exoplanet was discovered"], "P1049": ["worshipped by", "religion or group/civilization that worships a given deity"], "P4292": ["possessed by spirit", "item which is spiritually possessing this item"], "P4688": ["geomorphological unit", "topographic or bathymetric feature to which this geographical item belongs"], "P1389": ["product certification", "certification for a product, qualify with P1001 (\"applies to jurisdiction\") if needed"], "P289": ["vessel class", "series of vessels built to the same design of which this vessel is a member"], "P286": ["head coach", "on-field manager or head coach of a sports club (not to be confused with a general manager P505, which is not a coaching position) or person"], "P287": ["designed by", "person(s) that designed the item"], "P1387": ["political alignment", "political position within the political spectrum"], "P17": ["country", "sovereign state of this item; don't use on humans"], "P282": ["writing system", "alphabet, character set or other system of writing used by a language, supported by a typeface"], "P1383": ["contains settlement", "settlement which an administrative division contains"], "P1382": ["partially coincident with", "object that is partially part of, but not fully part of (P361), the subject"], "P1142": ["political ideology", "political ideology of this organization or person"], "P1268": ["represents", "organization or individual that an entity represents"], "P991": ["successful candidate", "person(s) elected after the election"], "P449": ["original network", "network(s) the radio or television show was originally aired on, not including later re-runs or additional syndication"], 
"P2962": ["title of chess person", "title awarded by a chess federation to a person"], "P1264": ["valid in period", "time period when a statement is valid"], "P3005": ["valid in place", "place where a statement is valid"], "P1064": ["track gauge", "spacing of the rails on a railway track"], "P3161": ["has grammatical mood", "language has this grammatical mood/mode for signaling modality"], "P840": ["narrative location", "the narrative of the work is set in this location"], "P841": ["feast day", "saint's principal feast day"], "P1817": ["addressee", "person or organization to whom a letter or note is addressed"], "P1811": ["list of episodes", "link to the article with the list of episodes for this series"], "P1542": ["has effect", "effect of this cause"], "P523": ["temporal range start", "the start of a process or appearance of a life form relative to the geologic time scale"], "P4743": ["animal breed", "subject item belongs to a specific group of domestic animals, generally given by association"], "P1299": ["depicted by", "object depicting this subject"], "P2366": ["Roman agnomen", "optional part of the name of a Roman, link to items about specific Roman agnomen only."], "P2365": ["Roman cognomen", "standard part of the name of a Roman, link to items about specific Roman cognomen only."], "P3189": ["innervated by", "nerves which innervate this anatomical structure"], "P2360": ["intended public", "this work, product, object or event is intended for, or has been designed to that person or group of people, animals, plants, etc"], "P4220": ["order of battle", "arrangement of units and hierarchical organization of the armed forces involved in the specified military action"], "P747": ["has edition", "link to an edition of this item"], "P669": ["located on street", "street, road, or square, where the item is located. To add the number, use Property:P670 \"street number\" as qualifier. 
Use property P6375 if there is no item for the street"], "P199": ["business division", "organizational divisions of this organization (which are not independent legal entities)"], "P1312": ["has facet polytope", "facet of a polytope, in the next-lower dimension"], "P1313": ["office held by head of government", "political office that is fulfilled by the head of the government of this item"], "P1310": ["statement disputed by", "entity that disputes a given statement"], "P740": ["location of formation", "location where a group or organization was formed"], "P193": ["main building contractor", "the main organization responsible for construction of this structure or building"], "P190": ["twinned administrative body", "twin towns, sister cities, twinned municipalities and other localities that have a partnership or cooperative agreement, either legally or informally acknowledged by their governments"], "P196": ["minor planet group", "is in grouping of minor planets according to similar orbital characteristics"], "P197": ["adjacent station", "the stations next to this station, sharing the same line(s)"], "P749": ["parent organization", "parent organization of an organization, opposite of subsidiaries (P355)"], "P748": ["appointed by", "who appointed the person to the office, can be used as a qualifier"], "P3261": ["anatomical branch of", "main stem of this blood vessel, lymphatic vessel or nerve"], "P3262": ["has anatomical branch", "branches of this blood vessel, lymphatic vessel or nerve"], "P344": ["director of photography", "person responsible for the framing, lighting, and filtration of the subject work"], "P2554": ["production designer", "production designer(s) of this motion picture, play, video game or similar"], "P694": ["replaced synonym (for nom. 
nov.)", "previously published name on which a replacement name (avowed substitute, nomen novum) is based."], "P690": ["space group", "symmetry classification for 2 and 3 dimensional patterns or crystals"], "P3461": ["designated as terrorist by", "country or organization that has officially designated a given group as a terrorist organization (e.g. for India, listed on http://mha.nic.in/BO )"], "P3501": ["Catholic rite", "Catholic rite associated with this item"], "P2992": ["software quality assurance", "quality assurance process in place for a particular software"], "P0": ["None", "None"], "P3989": ["members have occupation", "all members of this group share the occupation"], "P2515": ["costume designer", "person who designed the costumes for a film, television programme, etc"], "P2512": ["series spin-off", "series' spin-offs"], "P1639": ["pendant of", "other work in a pair of opposing artworks, such as wedding portraits, commissioned together, but not always"], "P2978": ["wheel arrangement", "wheel/axle arrangement for locomotives, railcars and other rolling stock"], "P1464": ["category for people born here", "category item that groups people born in this place"], "P2293": ["genetic association", "general link between a disease and the causal genetic entity, if the detailed mechanism is unknown/unavailable"], "P2291": ["charted in", "chart where the element reached a position"], "P264": ["record label", "brand and trademark associated with the marketing of subject music recordings and music videos"], "P1923": ["participating team", "Like 'Participant' (P710) but for teams. 
For an event like a cycle race or a football match you can use this property to list the teams and P710 to list the individuals (with 'member of sports team' (P54) as a qualifier for the individuals)"], "P108": ["employer", "person or organization for which the subject works or worked"], "P263": ["official residence", "the residence at which heads of government and other senior figures officially reside"], "P105": ["taxon rank", "level in a taxonomic hierarchy"], "P106": ["occupation", "occupation of a person; see also \"field of work\" (Property:P101), \"position held\" (Property:P39)"], "P101": ["field of work", "specialization of a person or organization; see P106 for the occupation"], "P103": ["native language", "language or languages a person has learned from early childhood"], "P102": ["member of political party", "the political party of which this politician is or has been a member"], "P2568": ["repealed by", "document is repealed/inactived by specified other document"], "P2560": ["GPU", "graphics processing unit within a system"], "P2563": ["superhuman feature or ability", "superhuman, supernatural, or paranormal abilities that the fictional subject exhibits"], "P2564": ["K\u00f6ppen climate classification", "indicates the characteristic climate of a place"], "P2567": ["amended by", "document is amended by specified other document"], "P1039": ["type of kinship", "qualifier of \"relative (P1038)\" to indicate unusual family relationships (ancestor, son-in-law, adoptions, etc). 
Avoid using for relationships that can be derived from the family tree or have an explicit property (spouse, child, etc)"], "P2743": ["this zoological name is coordinate with", "links coordinate zoological names"], "P2746": ["production statistics", "amount of a certain good produced in/by the item"], "P1568": ["domain", "set of \"input\" or argument values for which a mathematical function is defined"], "P1032": ["Digital Rights Management system", "technologies to control the use of digital content and devices after sale"], "P1037": ["director/manager", "person who manages any kind of group"], "P828": ["has cause", "underlying cause, thing that ultimately resulted in this effect"], "P1034": ["main food source", "species, genus or family that an organism depends on for nutrition"], "P3602": ["candidacy in election", "election where the subject is a candidate"], "P1445": ["fictional universe described in", "to link a fictional universe with a work that describes it: \"described in the work:\" "], "P3912": ["newspaper format", "physical size of a newspaper (berliner, broadsheet, tabloid, etc.)"], "P78": ["top-level Internet domain", "Internet domain name system top-level code"], "P2738": ["disjoint union of", "every instance of this class is an instance of exactly one class in that list of classes"], "P463": ["member of", "organization or club to which the subject belongs. 
Do not use for membership in ethnic or social groups, nor for holding a position such as a member of parliament (use P39 for that)."], "P4099": ["metrically compatible typeface", "typeface metrically compatible with this typeface: glyph position, width and height match"], "P3919": ["contributed to creative work", "person is cited as contributing to some creative or published work or series (qualify with \"subject has role\", P2868)"], "P3985": ["supports programming language", "programming language which is supported by this programming tool"], "P414": ["stock exchange", "exchange on which this company is traded"], "P837": ["day in year for periodic occurrence", "when a specific holiday or periodic event occurs. Can be used as property or qualifier"], "P38": ["currency", "currency used by item"], "P39": ["position held", "subject currently or formerly holds the object position or public office"], "P1640": ["curator", "content specialist responsible for this collection or exhibition"], "P1560": ["given name version for other gender", "equivalent name (with respect to the meaning of the name) in the same language: female version of a male first name, male version of a female first name. 
Add primarily the closest matching one"], "P30": ["continent", "continent of which the subject is a part"], "P31": ["instance of", "that class of which this subject is a particular example and member (subject typically an individual member with a proper name label); different from P279; using this property as a qualifier is deprecated\u2014use P2868 or P3831 instead"], "P4647": ["location of first performance", "location where a work was first debuted, performed or broadcasted"], "P37": ["official language", "language designated as official by this item"], "P35": ["head of state", "official with the highest formal authority in a country/state"], "P400": ["platform", "platform for which a work was developed or released, or the specific platform version of a software product"], "P1454": ["legal form", "legal form of an organization"], "P1456": ["list of monuments", "link to the list of heritage monuments in the place/area"], "P404": ["game mode", "a video game's available playing mode(s)"], "P406": ["soundtrack release", "music release that incorporates music directly recorded from the soundtrack of an audiovisual work"], "P407": ["language of work or name", "language associated with this creative work (such as books, shows, songs, or websites) or a name (for persons use P103 and P1412)"], "P817": ["decay mode", "type of decay that a radioactive isotope undergoes (should be used as a qualifier for \"decays to\")"], "P816": ["decays to", "what isotope does this radioactive isotope decay to"], "P814": ["IUCN protected areas category", "protected areas category by the World Commission on Protected Areas. Used with dedicated items for each category."], "P812": ["academic major", "subject someone studied at college/university"], "P366": ["use", "main use of the subject (includes current and former usage)"], "P568": ["overlies", "stratigraphic unit that this unit lies over (i.e. 
the underlying unit)"], "P567": ["underlies", "stratigraphic unit that this unit lies under (i.e. the overlying unit)"], "P566": ["basionym", "the legitimate, previously published name on which a new combination or name at new rank is based"], "P2922": ["month of the year", "month of the year during which this item occurs, applies to or is valid in"], "P2739": ["typeface/font used", "style of type used in a work"], "P562": ["central bank/issuer", "central bank or other issuing authority for the currency"], "P2632": ["place of detention", "place where this person is or was detained"], "P2633": ["geography of topic", "item that deals with the geography of the subject. Sample: \"Rio de Janeiro\" uses this property with value \"geography of Rio de Janeiro\" (Q10288853). For the location of a subject,use \"location\" (P276)"], "P624": ["guidance system", "guidance system of a missile"], "P3730": ["next higher rank", "higher rank or level in a ranked hierarchy like sport league, military ranks. If there are several possible, list each one and qualify with \"criterion used\" (P1013), avoid using ranks and date qualifiers. 
For sports leagues/taxa, use specific properties instead."], "P1420": ["taxon synonym", "(incorrect) name(s) listed as synonym(s) of a taxon name"], "P87": ["librettist", "author of the libretto (words) of an opera, operetta, oratorio or cantata, or of the book of a musical"], "P1589": ["lowest point", "point with lowest elevation in the country, region, city or area"], "P1425": ["ecoregion (WWF)", "ecoregion of the item (choose from WWF's list)"], "P1427": ["start point", "starting place of this journey, flight, voyage, trek, migration etc."], "P1851": ["input set", "a superset of the domain of a function or relation that may include some inputs for which the function is not defined; to specify the set of only those inputs for which the function is defined use domain (P1568)"], "P1582": ["natural product of taxon", "links a natural product with its source (animal, plant, fungal, algal, etc.)"], "P4147": ["conjugate acid", "species formed by accepting a proton (H\u207a)"], "P1855": ["Wikidata property example", "example where this Wikidata property is used; target item is one that would use this property, with qualifier the property being described given the associated value"], "P1419": ["shape", "shape of an object"], "P1924": ["vaccine for", "disease that a vaccine is for"], "P1290": ["godparent", "person who is the godparent of a given person"], "P1211": ["fuel system", "fuel system that an engine uses"], "P708": ["diocese", "administrative division of the church to which the element belongs; use P5607 for other types of ecclesiastical territorial entities"], "P2813": ["mouthpiece", "media that speaks for an organization or a movement, and that is usually edited by its members"], "P703": ["found in taxon", "the taxon in which the item can be found"], "P702": ["encoded by", "the gene that encodes some gene product"], "P707": ["satellite bus", "general model on which multiple-production satellite spacecraft is based"], "P706": ["located on terrain feature", 
"located on the specified landform. Should not be used when the value is only political/administrative (P131) or a mountain range (P4552)."], "P3137": ["parent peak", "parent is the peak whose territory this peak resides in, based on the contour of the lowest col"], "P2959": ["permanent duplicated item", "this item duplicates another item and the two can't be merged, as one Wikimedia project includes two pages, e. g. in different scripts or languages (applies to some wiki, e.g.: cdowiki, gomwiki). Use \"duplicate item\" for other wikis."], "P495": ["country of origin", "country of origin of this item (creative work, food, phrase, product, etc.)"], "P790": ["approved by", "item is approved by other item(s) [qualifier: statement is approved by other item(s)]"], "P4584": ["first appearance", "work in which a fictional/mythical character or entity first appeared"], "P3781": ["has active ingredient", "has part biologically active component. Inverse of \"active ingredient in\""], "P3780": ["active ingredient in", "is part of and forms biologically active component. Inverse of \"has active ingredient\""], "P795": ["located on linear feature", "linear feature along which distance is specified from a specified datum point"], "P797": ["authority", "entity having executive power on given entity"], "P159": ["headquarters location", "specific location where an organization's headquarters is or has been situated. 
Inverse property of \"occupant\" (P466)."], "P4586": ["type foundry", "type foundry releasing or distributing a font or typeface"], "P3033": ["package management system", "package management system used to publish the software"], "P3032": ["adjacent building", "building adjacent to the item"], "P411": ["canonization status", "stage in the process of attaining sainthood per the subject's religious organization"], "P141": ["IUCN conservation status", "conservation status assigned by the International Union for Conservation of Nature"], "P140": ["religion", "religion of a person, organization or religious building, or associated with this subject"], "P634": ["captain", "captain of this sports team"], "P144": ["based on", "the work(s) used as the basis for subject item"], "P631": ["structural engineer", "person, group or organisation responsible for the structural engineering of a building or structure"], "P3342": ["significant person", "person linked to the item in any possible way"], "P149": ["architectural style", "architectural style of a structure"], "P915": ["filming location", "actual place where this scene/film was shot. For the setting, use \"narrative location\" (P840)"], "P832": ["public holiday", "official public holiday that occurs in this place in its honor, usually a non-working day"], "P2321": ["general classification of race participants", "classification of race participants"], "P1435": ["heritage designation", "heritage designation of a cultural or natural site"], "P684": ["ortholog", "orthologous gene in another species (use with 'species' qualifier)"], "P2329": ["antagonist muscle", "No description defined"], "P618": ["source of energy", "describes the source of energy an animated object (machine or animal) uses"], "P2079": ["fabrication method", "method, process or technique used to grow, cook, weave, build, assemble, manufacture the item"], "P200": ["inflows", "major inflow sources \u2014 rivers, aquifers, glacial runoff, etc. 
Some terms may not be place names, e.g. none"], "P2175": ["medical condition treated", "disease that this pharmaceutical drug, procedure, or therapy is used to treat"], "P1677": ["index case of", "primary case, patient zero: initial patient in the population of an epidemiological investigation"], "P201": ["lake outflow", "rivers and other outflows waterway names. If evaporation or seepage are notable outflows, they may be included. Some terms may not be place names, e.g. evaporation"], "P1672": ["this taxon is source of", "links a taxon to natural products it produces. Note that it does not say \"this taxon is the source of\" or \"this taxon is a source of\" as this may vary. Some products may be yielded by more than one taxon."], "P1038": ["relative", "family member (qualify with \"type of kinship\", P1039; for direct family member please use specific property)"], "P3815": ["volcano observatory", "institution that monitors this volcanic landform or phenomenon"], "P112": ["founded by", "founder or co-founder of this organization, religion or place"], "P937": ["work location", "location where persons were active"], "P134": ["has dialect (DEPRECATED)", "a former property, to be replaced by P4913, that describes a lot of things that are \"dialects\" related"], "P135": ["movement", "literary, artistic, scientific or philosophical movement associated with this person or work"], "P136": ["genre", "creative work's genre or an artist's field of work (P101). Use main subject (P921) to relate creative works to their topic"], "P137": ["operator", "person, profession, or organization that operates the equipment, facility, or service; use country for diplomatic missions"], "P131": ["located in the administrative territorial entity", "the item is located on the territory of the following administrative entity. 
Use P276 (location) for specifying the location of non-administrative places and for items about events"], "P138": ["named after", "entity or event that inspired the subject's name, or namesake (in at least one language)"], "P793": ["significant event", "significant or notable events associated with the subject"], "P4353": ["nominated by", "who nominated a person for an office; can be used as a qualifier"], "P4600": ["polymer of", "monomer of which this polymer compose"], "P2551": ["used metre", "rhythmic structure of the poetic text or musical piece"], "P2499": ["league level above", "the league above this sports league"], "P4428": ["implementation of", "implementation of a standard, program, specification or programming language"], "P1716": ["brand", "commercial brand associated with the item"], "P612": ["mother house", "principal house or community for a religious institute"], "P2670": ["has parts of the class", "the subject instance has parts of the object class (the subject is usually not a class)"], "P522": ["type of orbit", "orbit a satellite has around its central body"], "P521": ["scheduled service destination", "airport or station connected by regular direct service to the subject; for the destination of a trip see P1444"], "P520": ["armament", "equippable weapon item for the subject"], "P527": ["has part", "part of this subject; inverse property of \"part of\" (P361). 
See also \"has parts of the class\" (P2670)."], "P2675": ["reply to", "the intellectual work to which the subsequent work is a direct reply"], "P524": ["temporal range end", "the end of a process or extinction of a life form relative to the geologic time scale"], "P1411": ["nominated for", "award nomination received by a person, organisation or creative work (inspired from \"award received\" (Property:P166))"], "P913": ["notation", "mathematical notation or another symbol"], "P910": ["topic's main category", "main Wikimedia category"], "P1412": ["languages spoken, written or signed", "language(s) that a person speaks, writes or signs, including the native language(s)"], "P1414": ["GUI toolkit or framework", "framework or toolkit a program uses to display the graphical user interface"], "P1416": ["affiliation", "organization that a person or organization is affiliated with (not necessary member of or employed by)"], "P1789": ["chief operating officer", "the chief operating officer of an organization"], "P4100": ["parliamentary group", "parliamentary group which a member of a parliament belongs to"], "P3679": ["stock market index", "method of measuring the value of a section of the stock market"], "P1891": ["signatory", "person, country, or organization that has signed an official document (use\u00a0P50 for author)"], "P4426": ["Y-DNA Haplogroup", "Y-DNA haplogroup of a person or organism"], "P1775": ["follower of", "for unknown artists who work in the manner of the named artist"], "P1079": ["launch contractor", "organization contracted to launch the rocket"], "P921": ["main subject", "primary topic of a work (see also P180: depicts)"], "P769": ["significant drug interaction", "clinically significant interaction between two pharmacologically active substances (i.e., drugs and/or active metabolites) where concomitant intake can lead to altered effectiveness or adverse drug events."], "P1075": ["rector", "rector of a university"], "P1074": ["fictional analog of", "used 
to link an entity or class of entities appearing in a creative work with the analogous entity or class of entities in the real world"], "P1073": ["writable file format", "file format a program can create and/or write to"], "P2789": ["connects with", "item with which the item is physically connected"], "P1071": ["location of final assembly", "place where the item was made; location of final assembly"], "P881": ["type of variable star", "type of variable star"], "P736": ["cover art by", "name of person or team creating cover artwork for album, single, book, etc."], "P355": ["subsidiary", "subsidiary of a company or organization, opposite of parent organization (P749)"], "P1398": ["structure replaces", "the item this building or structure replaced, at the same geographic location"], "P1399": ["convicted of", "crime a person was convicted of"], "P358": ["discography", "link to discography in artist or band page"], "P1393": ["proxy", "person authorized to act for another"], "P1158": ["location of landing", "location where the craft landed"], "P457": ["foundational text", "text through which an institution or object has been created or established"], "P2851": ["payment types accepted", "types of payment accepted by a venue"], "P2852": ["emergency phone number", "telephone number to contact the emergency services"], "P2853": ["electrical plug type", "standard plug type for mains electricity in a country"], "P452": ["industry", "industry of company or organization"], "P451": ["partner", "someone in a relationship without being married. 
Use \"spouse\" for married couples."], "P450": ["astronaut mission", "space mission that the subject is or has been a member of (do not include future missions)"], "P408": ["software engine", "software engine employed by the subject item"], "P3179": ["territory overlaps", "part or all of the area associated with (this) entity overlaps part or all of the area associated with that entity"], "P3075": ["official religion", "official religion in this administrative entity"], "P2416": ["sports discipline competed in", "discipline an athlete competed in within a sport"], "P3173": ["offers view on", "things, places another place offers views on"], "P3174": ["art director", "person credited as the art director/artistic director of this work"], "P1552": ["has quality", "the entity has an inherent or distinguishing non-material characteristic"], "P403": ["mouth of the watercourse", "the body of water to which the watercourse drains"], "P1557": ["manifestation of", "inherent and characteristic embodiment of a given concept"], "P859": ["sponsor", "organization or individual that sponsors this item"], "P3373": ["sibling", "the subject has the object as their sibling (brother, sister, etc.). Use \"relative\" (P1038) for siblings-in-law (brother-in-law, sister-in-law, etc.) 
and step-siblings (step-brothers, step-sisters, etc.)"], "P1308": ["officeholder", "persons who hold and/or held an office or noble title"], "P1809": ["choreographer", "person(s) who did the choreography"], "P185": ["doctoral student", "doctoral student(s) of a professor"], "P184": ["doctoral advisor", "person who supervised the doctorate or PhD thesis of the subject"], "P186": ["material used", "material the subject is made of or derived from"], "P180": ["depicts", "depicted entity (see also P921: main subject)"], "P183": ["endemic to", "sole location or habitat type where the taxon lives"], "P1078": ["valvetrain configuration", "configuration of the valvetrain utilized by this engine"], "P189": ["location of discovery", "where the item was located when discovered"], "P2354": ["has list", "Wikimedia list related to this subject"], "P1322": ["dual to", "dual of a polytope, graph or curve"], "P1321": ["place of origin (Switzerland)", "lieu d'origine of a Swiss national. Not be confused with place of birth or place of residence"], "P1327": ["partner in business or sport", "professional collaborator"], "P739": ["ammunition", "cartridge or other ammunition used by the subject firearm"], "P676": ["lyrics by", "author of song lyrics; also use P86 for music composer"], "P674": ["characters", "characters which appear in this item (like plays, operas, operettas, books, comics, films, TV series, video games)"], "P2828": ["corporate officer", "person who holds a specific position"], "P194": ["legislative body", "legislative body governing this entity; political institution with elected representatives, such as a parliament/legislature or council"], "P2825": ["via", "intermediate point on a journey - stopover location, waystation or routing point"], "P2821": ["by-product", "product of a chemical or industrial process, of secondary economic value"], "P2820": ["cardinality of this set", "measure of number of elements of a set"], "P1072": ["readable file format", "file format a 
program can open and read"], "P3450": ["sports season of league or competition", "property that shows the competition of which the item is a season. Use P5138 for \"season of club or team\"."], "P3512": ["means of locomotion", "method that the subject uses to move from one place to another"], "P3103": ["has tense", "grammatical category expressing time reference of the language. To include a sample, use qualifier \"quote\" (P1683), sample: \"He writes\". If an activity before is needed \"He sees\". If an activity afterwards is needed: \"He reads\"."], "P2505": ["carries", "item (e.g. road, railway, canal) carried by a bridge, a tunnel or a mountain pass"], "P2502": ["classification of race", "race for which this classification applies"], "P2500": ["league level below", "the league below this sports league"], "P4312": ["camera setup", "filmmaking method that the cameras were placed by. Use single-camera (Q2918907) or multiple-camera (Q738160)"], "P2868": ["subject has role", "(qualifier) role/generic identity of the item that the statement is on (\"subject\") in the context of the statement. For the role of the value of the statement (\"object\"), use P3831 (\"object has role\"). For acting roles, use P453 (\"character role\")."], "P1336": ["territory claimed by", "administrative divisions that claim control of a given area"], "P1571": ["codomain", "codomain of a function"], "P178": ["developer", "organisation or person that developed the item"], "P179": ["part of the series", "series which contains the subject"], "P275": ["license", "license under which this copyrighted work is released"], "P1773": ["attributed to", "uncertain but considered creator of an artwork"], "P272": ["production company", "company that produced this film, audio or performing arts work"], "P2184": ["history of topic", "historical development of an item's topic"], "P170": ["creator", "maker of this creative work or other object (where no more specific property exists). 
Paintings with unknown painters, use \"anonymous\" (Q4233718) as value."], "P171": ["parent taxon", "closest parent taxon of the taxon in question"], "P172": ["ethnic group", "subject's ethnicity (consensus is that a VERY high standard of proof is needed for this field to be used. In general this means 1) the subject claims it him/herself, or 2) it is widely agreed on by scholars, or 3) is fictional and portrayed as such)."], "P175": ["performer", "actor, musician, band or other performer associated with this role or musical work"], "P279": ["subclass of", "all instances of these items are instances of those items; this item is a class (subset) of that item. Not to be confused with P31 (instance of)"], "P177": ["crosses", "obstacle (body of water, road, ...) which this bridge crosses over or this tunnel goes under"], "P291": ["place of publication", "geographical place of publication of the edition (use 1st edition when referring to works)"], "P2210": ["relative to", "qualifier: what a statement value is relative to"], "P3205": ["patient of", "was treated or studied as a patient by this person"], "P1202": ["carries scientific instrument", "scientific instruments carried by a vessel, satellite, or device that are not required for propelling or navigating"], "P2094": ["competition class", "official classification by a regulating body under which the subject (events, teams, participants, or equipment) qualifies for inclusion"], "P974": ["tributary", "stream or river that flows into this main stem (or parent) river"], "P2438": ["narrator", "narrator, character or person that tells the story"], "P2439": ["None", "None"], "P737": ["influenced by", "this person, idea, etc. is informed by that other person, idea, etc., e.g. 
\"Heidegger was influenced by Aristotle\"."], "P36": ["capital", "primary city of a country, province, state or other type of administrative territorial entity"], "P4552": ["mountain range", "range or subrange to which the geographical item belongs"], "P1637": ["undercarriage", "type of aircraft landing gear the item is equipped with"], "P47": ["shares border with", "countries or administrative subdivisions, of equal level, that this item borders, either by land or water"], "P1731": ["Fach", "describes the special ablilites of an operatic singers voice"], "P4387": ["update method", "method used by an app/OS to receive updates or self-update"], "P40": ["child", "subject has object as biological, foster, and/or adoptive child"], "P413": ["position played on team / speciality", "position or specialism of a player on a team, e.g. Small Forward"], "P412": ["voice type", "person's voice type. expected values: soprano, mezzo-soprano, contralto, countertenor, tenor, baritone, bass (and derivatives)"], "P1444": ["destination point", "intended destination for this route (journey, flight, sailing, exploration, migration, etc.)"], "P410": ["military rank", "military rank achieved by a person (should usually have a \"start time\" qualifier), or military rank associated with a position"], "P417": ["patron saint", "patron saint adopted by the subject"], "P415": ["radio format", "describes the overall content broadcast on a radio station"], "P1441": ["present in work", "work in which this fictional entity (Q14897293) or historical person is present (use P2860 for works citing other works, P361/P1433 for works being part of / published in other works, P1343 for entities described in non-fictional accounts)"], "P822": ["mascot", "mascot of an organization, e.g. 
a sports team or university"], "P823": ["speaker", "person who is speaker for this event, ceremony, keynote, or presentation"], "P418": ["seal description", "description of a subject's seal"], "P826": ["tonality", "key of a musical composition"], "P825": ["dedicated to", "person or organization to whom the subject was dedicated"], "P559": ["terminus", "the feature (intersecting road, train station, etc.) at the end of a linear feature"], "P553": ["website account on", "a website that the person or organization has an account on (use with P554) Note: only used with reliable source or if the person or organization disclosed it."], "P550": ["chivalric order", "the chivalric order which a person belongs to"], "P551": ["residence", "the place where the person is or has been, resident"], "P556": ["crystal system", "type of crystal for minerals\u00a0and/or for crystal compounds"], "P2647": ["source of material", "place the material used was mined, quarried, found, or produced"], "P1594": ["judge", "Judge, magistrate or equivalent, presiding at a trial"], "P1595": ["charge", "offence with which someone is charged, at a trial"], "P1596": ["penalty", "penalty passed at a trial"], "P1591": ["defendant", "person or organization accused, at a trial"], "P1592": ["prosecutor", "person representing the prosecuting authority, at a trial"], "P1593": ["defender", "person representing the defendant, at a trial"], "P3701": ["incarnation of", "incarnation of another religious or supernatural being"], "P1598": ["consecrator", "bishop who presided as consecrator or co-consecrator of this bishop"], "P361": ["part of", "object of which the subject is a part (it's not useful to link objects which are themselves parts of other objects already listed as parts of the subject). 
Inverse property of \"has part\" (P527, see also \"has parts of the class\" (P2670))."], "P1027": ["conferred by", "person or organization who awards a prize to or bestows an honor upon a recipient"], "P598": ["commander of", "for persons who are notable as commanding officers, the units they commanded"], "P360": ["is a list of", "common element between all listed items"], "P1840": ["investigated by", "person or organization involved in investigation of the item"], "P4132": ["linguistic typology", "classification of languages according to their linguistic trait (as opposed to historical families like romance languages)"], "P1028": ["donated by", "person or organization who donated the object"], "P1029": ["crew member", "person that participated operating or serving aboard this vehicle"], "P3093": ["recovered by", "person, organisation or vehicle that recovered the item. Use the most specific value known."], "P1366": ["replaced by", "other person or item which continues the item by replacing it in its role. Use P156 (followed by) if the item is not replaced (e.g. books in a series), nor identical, but adds to the series without dropping the role of this item in that series"], "P1365": ["replaces", "person or item replaced. Use P1398 (structure replaces) for structures. Use P155 (follows) if the previous item was not replaced or if predecessor and successor are identical"], "P735": ["given name", "first name or another given name of this person; values used with the property shouldn't link disambiguations nor family names"], "P1363": ["points/goal scored by", "person who scored a point or goal in a game"], "P3095": ["practiced by", "type of agents that study this subject or work in this profession"], "P113": ["airline hub", "airport that serves as a hub for an airline"], "P110": ["illustrator", "person drawing the pictures in a book"], "P364": ["original language of film or TV show", "language in which a film or a performance work was originally created. 
Deprecated for written works; use P407 (\"language of work or name\") instead."], "P1165": ["home world", "home planet or natural satellite for a fictional character or species"], "P2860": ["cites work", "citation from one creative work to another"], "P2673": ["next crossing upstream", "next crossing of this river, canal, etc. upstream of this subject"], "P2869": ["record or record progression", "links to item on the record or record progression"], "P306": ["operating system", "operating system (OS) on which a software works or the OS installed on hardware"], "P1201": ["space tug", "spacecraft vehicle designed to move the payload from a reference orbit to the target orbit, or direct it to an interplanetary trajectory"], "P301": ["category's main topic", "primary topic of the subject Wikimedia category"], "P162": ["producer", "person(s) who produced the film, musical work, theatrical production, etc. (for film, this does not include executive producers, associate producers, etc.) [for production company, use P272, video games - use P178]"], "P6": ["head of government", "head of the executive power of this town, city, municipality, state, country, or other governmental body"], "P3148": ["repeals", "this document or act repeals that other document or act"], "P115": ["home venue", "home stadium or venue of a sports team or applicable performing arts organization"], "P2674": ["next crossing downstream", "next crossing of this river, canal, etc. 
downstream of this subject"], "P4044": ["therapeutic area", "disease area in which a medical intervention is applied"], "P780": ["symptoms", "possible symptoms of a medical condition"], "P4043": ["emulates", "emulates the identified platform, CPU, or system"], "P789": ["edibility", "whether a mushroom can be eaten or not"], "P788": ["mushroom ecological type", "property classifying the ecological type of a mushroom"], "P3022": ["flag bearer", "person who carries the national flag of their country at an opening or closing ceremony"], "P118": ["league", "league in which team or player plays or has played in"]} -------------------------------------------------------------------------------- /download_pretrain.sh: -------------------------------------------------------------------------------- 1 | mkdir pretrain 2 | wget -P pretrain https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/glove/glove.6B.50d_mat.npy 3 | wget -P pretrain https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/glove/glove.6B.50d_word2id.json 4 | wget -P pretrain https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/bert-base-uncased/config.json 5 | wget -P pretrain https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/bert-base-uncased/pytorch_model.bin 6 | wget -P pretrain https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/bert-base-uncased/vocab.txt 7 | -------------------------------------------------------------------------------- /fewshot_re_kit/__init__.py: -------------------------------------------------------------------------------- 1 | from fewshot_re_kit import data_loader 2 | from fewshot_re_kit import framework 3 | from fewshot_re_kit import sentence_encoder 4 | 5 | -------------------------------------------------------------------------------- /fewshot_re_kit/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import os 4 | import numpy as np 5 
| import random 6 | import json 7 | 8 | class FewRelDataset(data.Dataset): 9 | """ 10 | FewRel Dataset 11 | """ 12 | def __init__(self, name, encoder, N, K, Q, na_rate, root): 13 | self.root = root 14 | path = os.path.join(root, name + ".json") 15 | if not os.path.exists(path): 16 | print("[ERROR] Data file does not exist!") 17 | assert(0) 18 | self.json_data = json.load(open(path)) 19 | self.classes = list(self.json_data.keys()) 20 | self.N = N 21 | self.K = K 22 | self.Q = Q 23 | self.na_rate = na_rate 24 | self.encoder = encoder 25 | 26 | def __getraw__(self, item): 27 | word, pos1, pos2, mask = self.encoder.tokenize(item['tokens'], 28 | item['h'][2][0], 29 | item['t'][2][0]) 30 | return word, pos1, pos2, mask 31 | 32 | def __additem__(self, d, word, pos1, pos2, mask): 33 | d['word'].append(word) 34 | d['pos1'].append(pos1) 35 | d['pos2'].append(pos2) 36 | d['mask'].append(mask) 37 | 38 | def __getitem__(self, index): 39 | target_classes = random.sample(self.classes, self.N) 40 | support_set = {'word': [], 'pos1': [], 'pos2': [], 'mask': [] } 41 | query_set = {'word': [], 'pos1': [], 'pos2': [], 'mask': [] } 42 | query_label = [] 43 | Q_na = int(self.na_rate * self.Q) 44 | na_classes = list(filter(lambda x: x not in target_classes, 45 | self.classes)) 46 | 47 | for i, class_name in enumerate(target_classes): 48 | indices = np.random.choice( 49 | list(range(len(self.json_data[class_name]))), 50 | self.K + self.Q, False) 51 | count = 0 52 | for j in indices: 53 | word, pos1, pos2, mask = self.__getraw__( 54 | self.json_data[class_name][j]) 55 | word = torch.tensor(word).long() 56 | pos1 = torch.tensor(pos1).long() 57 | pos2 = torch.tensor(pos2).long() 58 | mask = torch.tensor(mask).long() 59 | if count < self.K: 60 | self.__additem__(support_set, word, pos1, pos2, mask) 61 | else: 62 | self.__additem__(query_set, word, pos1, pos2, mask) 63 | count += 1 64 | 65 | query_label += [i] * self.Q 66 | 67 | # NA 68 | for j in range(Q_na): 69 | cur_class = 
np.random.choice(na_classes, 1, False)[0] 70 | index = np.random.choice( 71 | list(range(len(self.json_data[cur_class]))), 72 | 1, False)[0] 73 | word, pos1, pos2, mask = self.__getraw__( 74 | self.json_data[cur_class][index]) 75 | word = torch.tensor(word).long() 76 | pos1 = torch.tensor(pos1).long() 77 | pos2 = torch.tensor(pos2).long() 78 | mask = torch.tensor(mask).long() 79 | self.__additem__(query_set, word, pos1, pos2, mask) 80 | query_label += [self.N] * Q_na 81 | 82 | return support_set, query_set, query_label 83 | 84 | def __len__(self): 85 | return 1000000000 86 | 87 | def collate_fn(data): 88 | batch_support = {'word': [], 'pos1': [], 'pos2': [], 'mask': []} 89 | batch_query = {'word': [], 'pos1': [], 'pos2': [], 'mask': []} 90 | batch_label = [] 91 | support_sets, query_sets, query_labels = zip(*data) 92 | for i in range(len(support_sets)): 93 | for k in support_sets[i]: 94 | batch_support[k] += support_sets[i][k] 95 | for k in query_sets[i]: 96 | batch_query[k] += query_sets[i][k] 97 | batch_label += query_labels[i] 98 | for k in batch_support: 99 | batch_support[k] = torch.stack(batch_support[k], 0) 100 | for k in batch_query: 101 | batch_query[k] = torch.stack(batch_query[k], 0) 102 | batch_label = torch.tensor(batch_label) 103 | return batch_support, batch_query, batch_label 104 | 105 | def get_loader(name, encoder, N, K, Q, batch_size, 106 | num_workers=8, collate_fn=collate_fn, na_rate=0, root='./data'): 107 | dataset = FewRelDataset(name, encoder, N, K, Q, na_rate, root) 108 | data_loader = data.DataLoader(dataset=dataset, 109 | batch_size=batch_size, 110 | shuffle=False, 111 | pin_memory=True, 112 | num_workers=num_workers, 113 | collate_fn=collate_fn) 114 | return iter(data_loader) 115 | 116 | class FewRelDatasetPair(data.Dataset): 117 | """ 118 | FewRel Pair Dataset 119 | """ 120 | def __init__(self, name, encoder, N, K, Q, na_rate, root, encoder_name): 121 | self.root = root 122 | path = os.path.join(root, name + ".json") 123 | if not 
os.path.exists(path): 124 | print("[ERROR] Data file does not exist!") 125 | assert(0) 126 | self.json_data = json.load(open(path)) 127 | self.classes = list(self.json_data.keys()) 128 | self.N = N 129 | self.K = K 130 | self.Q = Q 131 | self.na_rate = na_rate 132 | self.encoder = encoder 133 | self.encoder_name = encoder_name 134 | self.max_length = encoder.max_length 135 | 136 | def __getraw__(self, item): 137 | word = self.encoder.tokenize(item['tokens'], 138 | item['h'][2][0], 139 | item['t'][2][0]) 140 | return word 141 | 142 | def __additem__(self, d, word, pos1, pos2, mask): 143 | d['word'].append(word) 144 | d['pos1'].append(pos1) 145 | d['pos2'].append(pos2) 146 | d['mask'].append(mask) 147 | 148 | def __getitem__(self, index): 149 | target_classes = random.sample(self.classes, self.N) 150 | support = [] 151 | query = [] 152 | fusion_set = {'word': [], 'mask': [], 'seg': []} 153 | query_label = [] 154 | Q_na = int(self.na_rate * self.Q) 155 | na_classes = list(filter(lambda x: x not in target_classes, 156 | self.classes)) 157 | 158 | for i, class_name in enumerate(target_classes): 159 | indices = np.random.choice( 160 | list(range(len(self.json_data[class_name]))), 161 | self.K + self.Q, False) 162 | count = 0 163 | for j in indices: 164 | word = self.__getraw__( 165 | self.json_data[class_name][j]) 166 | if count < self.K: 167 | support.append(word) 168 | else: 169 | query.append(word) 170 | count += 1 171 | 172 | query_label += [i] * self.Q 173 | 174 | # NA 175 | for j in range(Q_na): 176 | cur_class = np.random.choice(na_classes, 1, False)[0] 177 | index = np.random.choice( 178 | list(range(len(self.json_data[cur_class]))), 179 | 1, False)[0] 180 | word = self.__getraw__( 181 | self.json_data[cur_class][index]) 182 | query.append(word) 183 | query_label += [self.N] * Q_na 184 | 185 | for word_query in query: 186 | for word_support in support: 187 | if self.encoder_name == 'bert': 188 | SEP = self.encoder.tokenizer.convert_tokens_to_ids(['[SEP]']) 189 | 
CLS = self.encoder.tokenizer.convert_tokens_to_ids(['[CLS]']) 190 | word_tensor = torch.zeros((self.max_length)).long() 191 | else: 192 | SEP = self.encoder.tokenizer.convert_tokens_to_ids(['']) 193 | CLS = self.encoder.tokenizer.convert_tokens_to_ids(['']) 194 | word_tensor = torch.ones((self.max_length)).long() 195 | new_word = CLS + word_support + SEP + word_query + SEP 196 | for i in range(min(self.max_length, len(new_word))): 197 | word_tensor[i] = new_word[i] 198 | mask_tensor = torch.zeros((self.max_length)).long() 199 | mask_tensor[:min(self.max_length, len(new_word))] = 1 200 | seg_tensor = torch.ones((self.max_length)).long() 201 | seg_tensor[:min(self.max_length, len(word_support) + 1)] = 0 202 | fusion_set['word'].append(word_tensor) 203 | fusion_set['mask'].append(mask_tensor) 204 | fusion_set['seg'].append(seg_tensor) 205 | 206 | return fusion_set, query_label 207 | 208 | def __len__(self): 209 | return 1000000000 210 | 211 | def collate_fn_pair(data): 212 | batch_set = {'word': [], 'seg': [], 'mask': []} 213 | batch_label = [] 214 | fusion_sets, query_labels = zip(*data) 215 | for i in range(len(fusion_sets)): 216 | for k in fusion_sets[i]: 217 | batch_set[k] += fusion_sets[i][k] 218 | batch_label += query_labels[i] 219 | for k in batch_set: 220 | batch_set[k] = torch.stack(batch_set[k], 0) 221 | batch_label = torch.tensor(batch_label) 222 | return batch_set, batch_label 223 | 224 | def get_loader_pair(name, encoder, N, K, Q, batch_size, 225 | num_workers=8, collate_fn=collate_fn_pair, na_rate=0, root='./data', encoder_name='bert'): 226 | dataset = FewRelDatasetPair(name, encoder, N, K, Q, na_rate, root, encoder_name) 227 | data_loader = data.DataLoader(dataset=dataset, 228 | batch_size=batch_size, 229 | shuffle=False, 230 | pin_memory=True, 231 | num_workers=num_workers, 232 | collate_fn=collate_fn) 233 | return iter(data_loader) 234 | 235 | class FewRelUnsupervisedDataset(data.Dataset): 236 | """ 237 | FewRel Unsupervised Dataset 238 | """ 239 | 
def __init__(self, name, encoder, N, K, Q, na_rate, root): 240 | self.root = root 241 | path = os.path.join(root, name + ".json") 242 | if not os.path.exists(path): 243 | print("[ERROR] Data file does not exist!") 244 | assert(0) 245 | self.json_data = json.load(open(path)) 246 | self.N = N 247 | self.K = K 248 | self.Q = Q 249 | self.na_rate = na_rate 250 | self.encoder = encoder 251 | 252 | def __getraw__(self, item): 253 | word, pos1, pos2, mask = self.encoder.tokenize(item['tokens'], 254 | item['h'][2][0], 255 | item['t'][2][0]) 256 | return word, pos1, pos2, mask 257 | 258 | def __additem__(self, d, word, pos1, pos2, mask): 259 | d['word'].append(word) 260 | d['pos1'].append(pos1) 261 | d['pos2'].append(pos2) 262 | d['mask'].append(mask) 263 | 264 | def __getitem__(self, index): 265 | total = self.N * self.K 266 | support_set = {'word': [], 'pos1': [], 'pos2': [], 'mask': [] } 267 | 268 | indices = np.random.choice(list(range(len(self.json_data))), total, False) 269 | for j in indices: 270 | word, pos1, pos2, mask = self.__getraw__( 271 | self.json_data[j]) 272 | word = torch.tensor(word).long() 273 | pos1 = torch.tensor(pos1).long() 274 | pos2 = torch.tensor(pos2).long() 275 | mask = torch.tensor(mask).long() 276 | self.__additem__(support_set, word, pos1, pos2, mask) 277 | 278 | return support_set 279 | 280 | def __len__(self): 281 | return 1000000000 282 | 283 | def collate_fn_unsupervised(data): 284 | batch_support = {'word': [], 'pos1': [], 'pos2': [], 'mask': []} 285 | support_sets = data 286 | for i in range(len(support_sets)): 287 | for k in support_sets[i]: 288 | batch_support[k] += support_sets[i][k] 289 | for k in batch_support: 290 | batch_support[k] = torch.stack(batch_support[k], 0) 291 | return batch_support 292 | 293 | def get_loader_unsupervised(name, encoder, N, K, Q, batch_size, 294 | num_workers=8, collate_fn=collate_fn_unsupervised, na_rate=0, root='./data'): 295 | dataset = FewRelUnsupervisedDataset(name, encoder, N, K, Q, na_rate, root) 
296 | data_loader = data.DataLoader(dataset=dataset, 297 | batch_size=batch_size, 298 | shuffle=False, 299 | pin_memory=True, 300 | num_workers=num_workers, 301 | collate_fn=collate_fn) 302 | return iter(data_loader) 303 | 304 | 305 | -------------------------------------------------------------------------------- /fewshot_re_kit/framework.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sklearn.metrics 3 | import numpy as np 4 | import sys 5 | import time 6 | from . import sentence_encoder 7 | from . import data_loader 8 | import torch 9 | from torch import autograd, optim, nn 10 | from torch.autograd import Variable 11 | from torch.nn import functional as F 12 | # from pytorch_pretrained_bert import BertAdam 13 | from transformers import AdamW, get_linear_schedule_with_warmup 14 | 15 | def warmup_linear(global_step, warmup_step): 16 | if global_step < warmup_step: 17 | return global_step / warmup_step 18 | else: 19 | return 1.0 20 | 21 | class FewShotREModel(nn.Module): 22 | def __init__(self, my_sentence_encoder): 23 | ''' 24 | sentence_encoder: Sentence encoder 25 | 26 | You need to set self.cost as your own loss function. 27 | ''' 28 | nn.Module.__init__(self) 29 | self.sentence_encoder = nn.DataParallel(my_sentence_encoder) 30 | self.cost = nn.CrossEntropyLoss() 31 | 32 | def forward(self, support, query, N, K, Q): 33 | ''' 34 | support: Inputs of the support set. 35 | query: Inputs of the query set. 36 | N: Num of classes 37 | K: Num of instances for each class in the support set 38 | Q: Num of instances for each class in the query set 39 | return: logits, pred 40 | ''' 41 | raise NotImplementedError 42 | 43 | def loss(self, logits, label): 44 | ''' 45 | logits: Logits with the size (..., class_num) 46 | label: Label with whatever size. 
47 | return: [Loss] (A single value) 48 | ''' 49 | N = logits.size(-1) 50 | return self.cost(logits.view(-1, N), label.view(-1)) 51 | 52 | def accuracy(self, pred, label): 53 | ''' 54 | pred: Prediction results with whatever size 55 | label: Label with whatever size 56 | return: [Accuracy] (A single value) 57 | ''' 58 | return torch.mean((pred.view(-1) == label.view(-1)).type(torch.FloatTensor)) 59 | 60 | class FewShotREFramework: 61 | 62 | def __init__(self, train_data_loader, val_data_loader, test_data_loader, adv_data_loader=None, adv=False, d=None): 63 | ''' 64 | train_data_loader: DataLoader for training. 65 | val_data_loader: DataLoader for validating. 66 | test_data_loader: DataLoader for testing. 67 | ''' 68 | self.train_data_loader = train_data_loader 69 | self.val_data_loader = val_data_loader 70 | self.test_data_loader = test_data_loader 71 | self.adv_data_loader = adv_data_loader 72 | self.adv = adv 73 | if adv: 74 | self.adv_cost = nn.CrossEntropyLoss() 75 | self.d = d 76 | self.d.cuda() 77 | 78 | def __load_model__(self, ckpt): 79 | ''' 80 | ckpt: Path of the checkpoint 81 | return: Checkpoint dict 82 | ''' 83 | if os.path.isfile(ckpt): 84 | checkpoint = torch.load(ckpt) 85 | print("Successfully loaded checkpoint '%s'" % ckpt) 86 | return checkpoint 87 | else: 88 | raise Exception("No checkpoint found at '%s'" % ckpt) 89 | 90 | def item(self, x): 91 | ''' 92 | PyTorch before and after 0.4 93 | ''' 94 | torch_version = torch.__version__.split('.') 95 | if int(torch_version[0]) == 0 and int(torch_version[1]) < 4: 96 | return x[0] 97 | else: 98 | return x.item() 99 | 100 | def train(self, 101 | model, 102 | model_name, 103 | B, N_for_train, N_for_eval, K, Q, 104 | na_rate=0, 105 | learning_rate=1e-1, 106 | lr_step_size=20000, 107 | weight_decay=1e-5, 108 | train_iter=30000, 109 | val_iter=1000, 110 | val_step=2000, 111 | test_iter=3000, 112 | load_ckpt=None, 113 | save_ckpt=None, 114 | pytorch_optim=optim.SGD, 115 | bert_optim=False, 116 | warmup=True, 
117 | warmup_step=300, 118 | grad_iter=1, 119 | fp16=False, 120 | pair=False, 121 | adv_dis_lr=1e-1, 122 | adv_enc_lr=1e-1, 123 | use_sgd_for_bert=False): 124 | ''' 125 | model: a FewShotREModel instance 126 | model_name: Name of the model 127 | B: Batch size 128 | N: Num of classes for each batch 129 | K: Num of instances for each class in the support set 130 | Q: Num of instances for each class in the query set 131 | ckpt_dir: Directory of checkpoints 132 | learning_rate: Initial learning rate 133 | lr_step_size: Decay learning rate every lr_step_size steps 134 | weight_decay: Rate of decaying weight 135 | train_iter: Num of iterations of training 136 | val_iter: Num of iterations of validating 137 | val_step: Validate every val_step steps 138 | test_iter: Num of iterations of testing 139 | ''' 140 | print("Start training...") 141 | 142 | # Init 143 | if bert_optim: 144 | print('Use bert optim!') 145 | parameters_to_optimize = list(model.named_parameters()) 146 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 147 | parameters_to_optimize = [ 148 | {'params': [p for n, p in parameters_to_optimize 149 | if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 150 | {'params': [p for n, p in parameters_to_optimize 151 | if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 152 | ] 153 | if use_sgd_for_bert: 154 | optimizer = torch.optim.SGD(parameters_to_optimize, lr=learning_rate) 155 | else: 156 | optimizer = AdamW(parameters_to_optimize, lr=learning_rate, correct_bias=False) 157 | if self.adv: 158 | optimizer_encoder = AdamW(parameters_to_optimize, lr=1e-5, correct_bias=False) 159 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=train_iter) 160 | else: 161 | optimizer = pytorch_optim(model.parameters(), 162 | learning_rate, weight_decay=weight_decay) 163 | if self.adv: 164 | optimizer_encoder = pytorch_optim(model.parameters(), lr=adv_enc_lr) 165 | scheduler = 
optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size) 166 | 167 | if self.adv: 168 | optimizer_dis = pytorch_optim(self.d.parameters(), lr=adv_dis_lr) 169 | 170 | if load_ckpt: 171 | state_dict = self.__load_model__(load_ckpt)['state_dict'] 172 | own_state = model.state_dict() 173 | for name, param in state_dict.items(): 174 | if name not in own_state: 175 | print('ignore {}'.format(name)) 176 | continue 177 | print('load {} from {}'.format(name, load_ckpt)) 178 | own_state[name].copy_(param) 179 | start_iter = 0 180 | else: 181 | start_iter = 0 182 | 183 | if fp16: 184 | from apex import amp 185 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 186 | 187 | model.train() 188 | if self.adv: 189 | self.d.train() 190 | 191 | # Training 192 | best_acc = 0 193 | iter_loss = 0.0 194 | iter_loss_dis = 0.0 195 | iter_right = 0.0 196 | iter_right_dis = 0.0 197 | iter_sample = 0.0 198 | for it in range(start_iter, start_iter + train_iter): 199 | if pair: 200 | batch, label = next(self.train_data_loader) 201 | if torch.cuda.is_available(): 202 | for k in batch: 203 | batch[k] = batch[k].cuda() 204 | label = label.cuda() 205 | logits, pred = model(batch, N_for_train, K, 206 | Q * N_for_train + na_rate * Q) 207 | else: 208 | support, query, label = next(self.train_data_loader) 209 | if torch.cuda.is_available(): 210 | for k in support: 211 | support[k] = support[k].cuda() 212 | for k in query: 213 | query[k] = query[k].cuda() 214 | label = label.cuda() 215 | 216 | logits, pred = model(support, query, 217 | N_for_train, K, Q * N_for_train + na_rate * Q) 218 | loss = model.loss(logits, label) / float(grad_iter) 219 | right = model.accuracy(pred, label) 220 | if fp16: 221 | with amp.scale_loss(loss, optimizer) as scaled_loss: 222 | scaled_loss.backward() 223 | # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 10) 224 | else: 225 | loss.backward() 226 | torch.nn.utils.clip_grad_norm_(model.parameters(), 10) 227 | 228 | if it % grad_iter == 0: 
229 | optimizer.step() 230 | scheduler.step() 231 | optimizer.zero_grad() 232 | 233 | # Adv part 234 | if self.adv: 235 | support_adv = next(self.adv_data_loader) 236 | if torch.cuda.is_available(): 237 | for k in support_adv: 238 | support_adv[k] = support_adv[k].cuda() 239 | 240 | features_ori = model.sentence_encoder(support) 241 | features_adv = model.sentence_encoder(support_adv) 242 | features = torch.cat([features_ori, features_adv], 0) 243 | total = features.size(0) 244 | dis_labels = torch.cat([torch.zeros((total // 2)).long().cuda(), 245 | torch.ones((total // 2)).long().cuda()], 0) 246 | dis_logits = self.d(features) 247 | loss_dis = self.adv_cost(dis_logits, dis_labels) 248 | _, pred = dis_logits.max(-1) 249 | right_dis = float((pred == dis_labels).long().sum()) / float(total) 250 | 251 | loss_dis.backward(retain_graph=True) 252 | optimizer_dis.step() 253 | optimizer_dis.zero_grad() 254 | optimizer_encoder.zero_grad() 255 | 256 | loss_encoder = self.adv_cost(dis_logits, 1 - dis_labels) 257 | 258 | loss_encoder.backward(retain_graph=True) 259 | optimizer_encoder.step() 260 | optimizer_dis.zero_grad() 261 | optimizer_encoder.zero_grad() 262 | 263 | iter_loss_dis += self.item(loss_dis.data) 264 | iter_right_dis += right_dis 265 | 266 | iter_loss += self.item(loss.data) 267 | iter_right += self.item(right.data) 268 | iter_sample += 1 269 | if self.adv: 270 | sys.stdout.write('step: {0:4} | loss: {1:2.6f}, accuracy: {2:3.2f}%, dis_loss: {3:2.6f}, dis_acc: {4:2.6f}' 271 | .format(it + 1, iter_loss / iter_sample, 272 | 100 * iter_right / iter_sample, 273 | iter_loss_dis / iter_sample, 274 | 100 * iter_right_dis / iter_sample) + '\r') 275 | else: 276 | sys.stdout.write('step: {0:4} | loss: {1:2.6f}, accuracy: {2:3.2f}%'.format(it + 1, iter_loss / iter_sample, 100 * iter_right / iter_sample) + '\r') 277 | sys.stdout.flush() 278 | 279 | if (it + 1) % val_step == 0: 280 | acc = self.eval(model, B, N_for_eval, K, Q, val_iter, 281 | na_rate=na_rate, pair=pair) 282 | 
model.train() 283 | if acc > best_acc: 284 | print('Best checkpoint') 285 | torch.save({'state_dict': model.state_dict()}, save_ckpt) 286 | best_acc = acc 287 | iter_loss = 0. 288 | iter_loss_dis = 0. 289 | iter_right = 0. 290 | iter_right_dis = 0. 291 | iter_sample = 0. 292 | 293 | print("\n####################\n") 294 | print("Finish training " + model_name) 295 | 296 | def eval(self, 297 | model, 298 | B, N, K, Q, 299 | eval_iter, 300 | na_rate=0, 301 | pair=False, 302 | ckpt=None): 303 | ''' 304 | model: a FewShotREModel instance 305 | B: Batch size 306 | N: Num of classes for each batch 307 | K: Num of instances for each class in the support set 308 | Q: Num of instances for each class in the query set 309 | eval_iter: Num of iterations 310 | ckpt: Checkpoint path. Set as None if using current model parameters. 311 | return: Accuracy 312 | ''' 313 | print("") 314 | 315 | model.eval() 316 | if ckpt is None: 317 | print("Use val dataset") 318 | eval_dataset = self.val_data_loader 319 | else: 320 | print("Use test dataset") 321 | if ckpt != 'none': 322 | state_dict = self.__load_model__(ckpt)['state_dict'] 323 | own_state = model.state_dict() 324 | for name, param in state_dict.items(): 325 | if name not in own_state: 326 | continue 327 | own_state[name].copy_(param) 328 | eval_dataset = self.test_data_loader 329 | 330 | iter_right = 0.0 331 | iter_sample = 0.0 332 | with torch.no_grad(): 333 | for it in range(eval_iter): 334 | if pair: 335 | batch, label = next(eval_dataset) 336 | if torch.cuda.is_available(): 337 | for k in batch: 338 | batch[k] = batch[k].cuda() 339 | label = label.cuda() 340 | logits, pred = model(batch, N, K, Q * N + Q * na_rate) 341 | else: 342 | support, query, label = next(eval_dataset) 343 | if torch.cuda.is_available(): 344 | for k in support: 345 | support[k] = support[k].cuda() 346 | for k in query: 347 | query[k] = query[k].cuda() 348 | label = label.cuda() 349 | logits, pred = model(support, query, N, K, Q * N + Q * na_rate) 350 | 
351 | right = model.accuracy(pred, label) 352 | iter_right += self.item(right.data) 353 | iter_sample += 1 354 | 355 | sys.stdout.write('[EVAL] step: {0:4} | accuracy: {1:3.2f}%'.format(it + 1, 100 * iter_right / iter_sample) + '\r') 356 | sys.stdout.flush() 357 | print("") 358 | return iter_right / iter_sample 359 | -------------------------------------------------------------------------------- /fewshot_re_kit/network/__init__.py: -------------------------------------------------------------------------------- 1 | from . import embedding 2 | from . import encoder 3 | -------------------------------------------------------------------------------- /fewshot_re_kit/network/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | 7 | class Embedding(nn.Module): 8 | 9 | def __init__(self, word_vec_mat, max_length, word_embedding_dim=50, pos_embedding_dim=5): 10 | nn.Module.__init__(self) 11 | 12 | self.max_length = max_length 13 | self.word_embedding_dim = word_embedding_dim 14 | self.pos_embedding_dim = pos_embedding_dim 15 | 16 | # Word embedding 17 | # unk = torch.randn(1, word_embedding_dim) / math.sqrt(word_embedding_dim) 18 | # blk = torch.zeros(1, word_embedding_dim) 19 | word_vec_mat = torch.from_numpy(word_vec_mat) 20 | self.word_embedding = nn.Embedding(word_vec_mat.shape[0], self.word_embedding_dim, padding_idx=word_vec_mat.shape[0] - 1) 21 | self.word_embedding.weight.data.copy_(word_vec_mat) 22 | 23 | # Position Embedding 24 | self.pos1_embedding = nn.Embedding(2 * max_length, pos_embedding_dim, padding_idx=0) 25 | self.pos2_embedding = nn.Embedding(2 * max_length, pos_embedding_dim, padding_idx=0) 26 | 27 | def forward(self, inputs): 28 | word = inputs['word'] 29 | pos1 = inputs['pos1'] 30 | pos2 = inputs['pos2'] 31 | 32 | x = torch.cat([self.word_embedding(word), 33 | 
self.pos1_embedding(pos1), 34 | self.pos2_embedding(pos2)], 2) 35 | return x 36 | 37 | 38 | -------------------------------------------------------------------------------- /fewshot_re_kit/network/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from torch import optim 7 | 8 | class Encoder(nn.Module): 9 | def __init__(self, max_length, word_embedding_dim=50, pos_embedding_dim=5, hidden_size=230): 10 | nn.Module.__init__(self) 11 | 12 | self.max_length = max_length 13 | self.hidden_size = hidden_size 14 | self.embedding_dim = word_embedding_dim + pos_embedding_dim * 2 15 | self.conv = nn.Conv1d(self.embedding_dim, self.hidden_size, 3, padding=1) 16 | self.pool = nn.MaxPool1d(max_length) 17 | 18 | # For PCNN 19 | self.mask_embedding = nn.Embedding(4, 3) 20 | self.mask_embedding.weight.data.copy_(torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]])) 21 | self.mask_embedding.weight.requires_grad = False 22 | self._minus = -100 23 | 24 | def forward(self, inputs): 25 | return self.cnn(inputs) 26 | 27 | def cnn(self, inputs): 28 | x = self.conv(inputs.transpose(1, 2)) 29 | x = F.relu(x) 30 | x = self.pool(x) 31 | return x.squeeze(2) # n x hidden_size 32 | 33 | def pcnn(self, inputs, mask): 34 | x = self.conv(inputs.transpose(1, 2)) # n x hidden x length 35 | mask = 1 - self.mask_embedding(mask).transpose(1, 2) # n x 3 x length 36 | pool1 = self.pool(F.relu(x + self._minus * mask[:, 0:1, :])) 37 | pool2 = self.pool(F.relu(x + self._minus * mask[:, 1:2, :])) 38 | pool3 = self.pool(F.relu(x + self._minus * mask[:, 2:3, :])) 39 | x = torch.cat([pool1, pool2, pool3], 1) 40 | x = x.squeeze(2) # n x (hidden_size * 3) 41 | 42 | -------------------------------------------------------------------------------- /fewshot_re_kit/old_data_loader.py: 
# --------------------------------------------------------------------------------
import json
import os
import multiprocessing
import numpy as np
import random
import torch
from torch.autograd import Variable


class FileDataLoader:
    '''Interface for loaders that sample few-shot episodes.'''

    def next_batch(self, B, N, K, Q):
        '''
        B: batch size.
        N: the number of relations for each batch
        K: the number of support instances for each relation
        Q: the number of query instances for each relation
        return: support_set, query_set, query_label
        '''
        raise NotImplementedError


class JSONFileDataLoader(FileDataLoader):
    '''
    Episode sampler over a JSON relation file plus a JSON word-vector file.

    Pre-processed arrays are cached under `_processed_data/` and reused on later
    runs unless `reprocess` is set or the cached `max_length` differs.
    '''

    def _load_preprocessed_file(self):
        '''
        Try to load the cached arrays for this data / word-vector file pair.
        return: True if every cached file exists and matches self.max_length, else False.
        '''
        name_prefix = '.'.join(self.file_name.split('/')[-1].split('.')[:-1])
        word_vec_name_prefix = '.'.join(self.word_vec_file_name.split('/')[-1].split('.')[:-1])
        processed_data_dir = '_processed_data'
        if not os.path.isdir(processed_data_dir):
            return False
        word_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_word.npy')
        pos1_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pos1.npy')
        pos2_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pos2.npy')
        mask_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_mask.npy')
        length_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_length.npy')
        rel2scope_file_name = os.path.join(processed_data_dir, name_prefix + '_rel2scope.json')
        word_vec_mat_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_mat.npy')
        word2id_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_word2id.json')
        required_files = [word_npy_file_name, pos1_npy_file_name, pos2_npy_file_name,
                          mask_npy_file_name, length_npy_file_name, rel2scope_file_name,
                          word_vec_mat_file_name, word2id_file_name]
        if not all(os.path.exists(path) for path in required_files):
            return False
        print("Pre-processed files exist. Loading them...")
        self.data_word = np.load(word_npy_file_name)
        self.data_pos1 = np.load(pos1_npy_file_name)
        self.data_pos2 = np.load(pos2_npy_file_name)
        self.data_mask = np.load(mask_npy_file_name)
        self.data_length = np.load(length_npy_file_name)
        with open(rel2scope_file_name) as f:
            self.rel2scope = json.load(f)
        self.word_vec_mat = np.load(word_vec_mat_file_name)
        with open(word2id_file_name) as f:
            self.word2id = json.load(f)
        if self.data_word.shape[1] != self.max_length:
            print("Pre-processed files don't match current settings. Reprocessing...")
            return False
        print("Finish loading")
        return True

    def __init__(self, file_name, word_vec_file_name, max_length=40, case_sensitive=False, reprocess=False, cuda=True):
        '''
        file_name: Json file storing the data in the following format
            {
                "P155": # relation id
                    [
                        {
                            "h": ["song for a future generation", "Q7561099", [[16, 17, ...]]], # head entity [word, id, location]
                            "t": ["whammy kiss", "Q7990594", [[11, 12]]], # tail entity [word, id, location]
                            "tokens": ["Hot", "Dance", "Club", ...], # sentence
                        },
                        ...
                    ],
                "P177":
                    [
                        ...
                    ]
                ...
            }
        word_vec_file_name: Json file storing word vectors in the following format
            [
                {'word': 'the', 'vec': [0.418, 0.24968, ...]},
                {'word': ',', 'vec': [0.013441, 0.23682, ...]},
                ...
            ]
        max_length: The length that all the sentences need to be extend to.
        case_sensitive: Whether the data processing is case-sensitive, default as False.
        reprocess: Do the pre-processing whether there exist pre-processed files, default as False.
        cuda: Use cuda or not, default as True.
        raises: Exception if pre-processing is needed and either input file is missing.
        '''
        self.file_name = file_name
        self.word_vec_file_name = word_vec_file_name
        self.case_sensitive = case_sensitive
        self.max_length = max_length
        self.cuda = cuda

        if reprocess or not self._load_preprocessed_file():  # Try to load pre-processed files:
            # Check files
            if file_name is None or not os.path.isfile(file_name):
                raise Exception("[ERROR] Data file doesn't exist")
            if word_vec_file_name is None or not os.path.isfile(word_vec_file_name):
                raise Exception("[ERROR] Word vector file doesn't exist")

            # Load files
            print("Loading data file...")
            with open(self.file_name, "r") as f:
                self.ori_data = json.load(f)
            print("Finish loading")
            print("Loading word vector file...")
            with open(self.word_vec_file_name, "r") as f:
                self.ori_word_vec = json.load(f)
            print("Finish loading")

            # Eliminate case sensitive
            if not case_sensitive:
                print("Elimiating case sensitive problem...")
                for relation in self.ori_data:
                    for ins in self.ori_data[relation]:
                        ins['tokens'][:] = [token.lower() for token in ins['tokens']]
                print("Finish eliminating")

            # Pre-process word vec
            self.word2id = {}
            self.word_vec_tot = len(self.ori_word_vec)
            UNK = self.word_vec_tot
            BLANK = self.word_vec_tot + 1
            self.word_vec_dim = len(self.ori_word_vec[0]['vec'])
            print("Got {} words of {} dims".format(self.word_vec_tot, self.word_vec_dim))
            print("Building word vector matrix and mapping...")
            self.word_vec_mat = np.zeros((self.word_vec_tot, self.word_vec_dim), dtype=np.float32)
            for cur_id, word in enumerate(self.ori_word_vec):
                w = word['word']
                if not case_sensitive:
                    w = w.lower()
                self.word2id[w] = cur_id
                self.word_vec_mat[cur_id, :] = word['vec']
                # L2-normalise each vector.
                self.word_vec_mat[cur_id] = self.word_vec_mat[cur_id] / np.sqrt(np.sum(self.word_vec_mat[cur_id] ** 2))
            # UNK / BLANK ids are the two indices just past the last row of
            # word_vec_mat. NOTE(review): the matrix has no rows for them;
            # confirm the consumer appends those rows.
            self.word2id['UNK'] = UNK
            self.word2id['BLANK'] = BLANK
            print("Finish building")

            # Pre-process data
            print("Pre-processing data...")
            self.instance_tot = sum(len(self.ori_data[relation]) for relation in self.ori_data)
            self.data_word = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_pos1 = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_pos2 = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_mask = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_length = np.zeros((self.instance_tot), dtype=np.int32)
            self.rel2scope = {}  # left close right open
            i = 0
            for relation in self.ori_data:
                self.rel2scope[relation] = [i, i]
                for ins in self.ori_data[relation]:
                    # First token index of the head / tail entity mention.
                    pos1 = ins['h'][2][0][0]
                    pos2 = ins['t'][2][0][0]
                    words = ins['tokens']
                    cur_ref_data_word = self.data_word[i]
                    for j, word in enumerate(words[:max_length]):
                        cur_ref_data_word[j] = self.word2id.get(word, UNK)
                    # Pad the rest with BLANK. Start from the (capped) sentence length
                    # rather than the loop variable, which is unbound or stale when
                    # `words` is empty.
                    for j in range(min(len(words), max_length), max_length):
                        cur_ref_data_word[j] = BLANK
                    self.data_length[i] = min(len(words), max_length)
                    if pos1 >= max_length:
                        pos1 = max_length - 1
                    if pos2 >= max_length:
                        pos2 = max_length - 1
                    pos_min = min(pos1, pos2)
                    pos_max = max(pos1, pos2)
                    for j in range(max_length):
                        # Relative positions, shifted by max_length to be non-negative.
                        self.data_pos1[i][j] = j - pos1 + max_length
                        self.data_pos2[i][j] = j - pos2 + max_length
                        # Segment mask: 0 past the sentence, 1 up to the first entity,
                        # 2 up to the second entity, 3 after it.
                        if j >= self.data_length[i]:
                            self.data_mask[i][j] = 0
                        elif j <= pos_min:
                            self.data_mask[i][j] = 1
                        elif j <= pos_max:
                            self.data_mask[i][j] = 2
                        else:
                            self.data_mask[i][j] = 3
                    i += 1
                self.rel2scope[relation][1] = i

            print("Finish pre-processing")

            print("Storing processed files...")
            name_prefix = '.'.join(file_name.split('/')[-1].split('.')[:-1])
            word_vec_name_prefix = '.'.join(word_vec_file_name.split('/')[-1].split('.')[:-1])
            processed_data_dir = '_processed_data'
            if not os.path.isdir(processed_data_dir):
                os.mkdir(processed_data_dir)
            np.save(os.path.join(processed_data_dir, name_prefix + '_word.npy'), self.data_word)
            np.save(os.path.join(processed_data_dir, name_prefix + '_pos1.npy'), self.data_pos1)
            np.save(os.path.join(processed_data_dir, name_prefix + '_pos2.npy'), self.data_pos2)
            np.save(os.path.join(processed_data_dir, name_prefix + '_mask.npy'), self.data_mask)
            np.save(os.path.join(processed_data_dir, name_prefix + '_length.npy'), self.data_length)
            with open(os.path.join(processed_data_dir, name_prefix + '_rel2scope.json'), 'w') as f:
                json.dump(self.rel2scope, f)
            np.save(os.path.join(processed_data_dir, word_vec_name_prefix + '_mat.npy'), self.word_vec_mat)
            with open(os.path.join(processed_data_dir, word_vec_name_prefix + '_word2id.json'), 'w') as f:
                json.dump(self.word2id, f)
            print("Finish storing")

    def next_one(self, N, K, Q):
        '''
        Sample one N-way K-shot episode with Q queries per relation.
        return: support_set (dict of arrays shaped (N, K, max_length)),
                query_set (dict of arrays shaped (N * Q, max_length)),
                query_label (array of N * Q class indices in [0, N)),
                with the queries shuffled.
        '''
        # random.sample needs a sequence; dict views are rejected on Python 3.11+.
        target_classes = random.sample(list(self.rel2scope.keys()), N)
        support_set = {'word': [], 'pos1': [], 'pos2': [], 'mask': []}
        query_set = {'word': [], 'pos1': [], 'pos2': [], 'mask': []}
        query_label = []
        sources = {'word': self.data_word, 'pos1': self.data_pos1,
                   'pos2': self.data_pos2, 'mask': self.data_mask}

        for i, class_name in enumerate(target_classes):
            scope = self.rel2scope[class_name]
            indices = np.random.choice(list(range(scope[0], scope[1])), K + Q, False)
            for key, source in sources.items():
                support_part, query_part, _ = np.split(source[indices], [K, K + Q])
                support_set[key].append(support_part)
                query_set[key].append(query_part)
            query_label += [i] * Q

        for key in support_set:
            support_set[key] = np.stack(support_set[key], 0)
        for key in query_set:
            query_set[key] = np.concatenate(query_set[key], 0)
        query_label = np.array(query_label)

        perm = np.random.permutation(N * Q)
        for key in query_set:
            query_set[key] = query_set[key][perm]
        query_label = query_label[perm]

        return support_set, query_set, query_label

    def next_batch(self, B, N, K, Q):
        '''
        Sample B episodes and collate them into LongTensors.
        return: support (dict of (B * N * K, max_length) tensors),
                query (dict of (B * N * Q, max_length) tensors),
                label ((B, N * Q) tensor); moved to GPU when self.cuda is set.
        '''
        keys = ('word', 'pos1', 'pos2', 'mask')
        support = {key: [] for key in keys}
        query = {key: [] for key in keys}
        label = []
        for _ in range(B):
            current_support, current_query, current_label = self.next_one(N, K, Q)
            for key in keys:
                support[key].append(current_support[key])
                query[key].append(current_query[key])
            label.append(current_label)
        for key in keys:
            support[key] = Variable(torch.from_numpy(np.stack(support[key], 0)).long().view(-1, self.max_length))
            query[key] = Variable(torch.from_numpy(np.stack(query[key], 0)).long().view(-1, self.max_length))
        label = Variable(torch.from_numpy(np.stack(label, 0).astype(np.int64)).long())

        # To cuda
        if self.cuda:
            for key in support:
                support[key] = support[key].cuda()
            for key in query:
                query[key] = query[key].cuda()
            label = label.cuda()

        return support, query, label

# --------------------------------------------------------------------------------
# /fewshot_re_kit/sentence_encoder.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import os
from torch import optim
from . import network
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassification, RobertaModel, RobertaTokenizer, RobertaForSequenceClassification


class CNNSentenceEncoder(nn.Module):
    '''Word/position embedding followed by a CNN encoder.'''

    def __init__(self, word_vec_mat, word2id, max_length, word_embedding_dim=50,
                 pos_embedding_dim=5, hidden_size=230):
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.max_length = max_length
        self.embedding = network.embedding.Embedding(word_vec_mat, max_length,
                                                     word_embedding_dim, pos_embedding_dim)
        self.encoder = network.encoder.Encoder(max_length, word_embedding_dim,
                                               pos_embedding_dim, hidden_size)
        self.word2id = word2id

    def forward(self, inputs):
        x = self.embedding(inputs)
        x = self.encoder(x)
        return x

    def tokenize(self, raw_tokens, pos_head, pos_tail):
        '''
        Map a tokenised sentence to fixed-length model inputs.
        raw_tokens: list of word strings (lower-cased before lookup)
        pos_head / pos_tail: token indices of the head / tail entity (first one is used)
        return: indexed_tokens (max_length ids, padded with '[PAD]'),
                pos1, pos2 (int32 relative positions shifted by max_length),
                mask (int32, 1 on real tokens, 0 on padding)
        '''
        # token -> index
        indexed_tokens = []
        for token in raw_tokens:
            token = token.lower()
            if token in self.word2id:
                indexed_tokens.append(self.word2id[token])
            else:
                indexed_tokens.append(self.word2id['[UNK]'])

        # Count real tokens before padding; measuring after padding always
        # gave max_length, so the mask used to be all ones.
        length = min(len(indexed_tokens), self.max_length)

        # padding
        while len(indexed_tokens) < self.max_length:
            indexed_tokens.append(self.word2id['[PAD]'])
        indexed_tokens = indexed_tokens[:self.max_length]

        # pos
        pos1 = np.zeros((self.max_length), dtype=np.int32)
        pos2 = np.zeros((self.max_length), dtype=np.int32)
        pos1_in_index = min(self.max_length, pos_head[0])
        pos2_in_index = min(self.max_length, pos_tail[0])
        for i in range(self.max_length):
            pos1[i] = i - pos1_in_index + self.max_length
            pos2[i] = i - pos2_in_index + self.max_length

        # mask
        mask = np.zeros((self.max_length), dtype=np.int32)
        mask[:length] = 1

        return indexed_tokens, pos1, pos2, mask


class BERTSentenceEncoder(nn.Module):
    '''BERT encoder returning the pooled output or the entity-marker hidden states.'''

    def __init__(self, pretrain_path, max_length, cat_entity_rep=False, mask_entity=False):
        nn.Module.__init__(self)
        self.bert = BertModel.from_pretrained(pretrain_path)
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.cat_entity_rep = cat_entity_rep
        self.mask_entity = mask_entity

    def forward(self, inputs):
        '''
        inputs: dict with 'word' and 'mask' (plus 'pos1'/'pos2' when cat_entity_rep)
        return: the model's second output (BertModel's pooled output), or, when
                cat_entity_rep is set, the last-layer states at pos1 and pos2 concatenated
        '''
        outputs = self.bert(inputs['word'], attention_mask=inputs['mask'])
        if not self.cat_entity_rep:
            # Index instead of tuple-unpacking: newer transformers return a
            # ModelOutput (a dict) whose unpacking yields key names, not tensors.
            return outputs[1]
        tensor_range = torch.arange(inputs['word'].size()[0])
        h_state = outputs[0][tensor_range, inputs["pos1"]]
        t_state = outputs[0][tensor_range, inputs["pos2"]]
        state = torch.cat((h_state, t_state), -1)
        return state

    def tokenize(self, raw_tokens, pos_head, pos_tail):
        '''
        WordPiece-tokenise a sentence, wrapping the head entity in
        [unused0]...[unused2] and the tail in [unused1]...[unused3]
        (entity words become [unused4] when mask_entity is set).
        return: indexed_tokens (max_length ids, 0-padded),
                index of the head start marker, index of the tail start marker
                (both clipped to max_length - 1; 0 if the entity was not found),
                mask (int32, 1 on real tokens)
        '''
        # token -> index
        tokens = ['[CLS]']
        pos1_in_index = 1
        pos2_in_index = 1
        for cur_pos, token in enumerate(raw_tokens):
            token = token.lower()
            if cur_pos == pos_head[0]:
                tokens.append('[unused0]')
                pos1_in_index = len(tokens)
            if cur_pos == pos_tail[0]:
                tokens.append('[unused1]')
                pos2_in_index = len(tokens)
            in_head = pos_head[0] <= cur_pos <= pos_head[-1]
            in_tail = pos_tail[0] <= cur_pos <= pos_tail[-1]
            if self.mask_entity and (in_head or in_tail):
                tokens += ['[unused4]']
            else:
                tokens += self.tokenizer.tokenize(token)
            if cur_pos == pos_head[-1]:
                tokens.append('[unused2]')
            if cur_pos == pos_tail[-1]:
                tokens.append('[unused3]')
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens)

        # padding
        while len(indexed_tokens) < self.max_length:
            indexed_tokens.append(0)
        indexed_tokens = indexed_tokens[:self.max_length]

        # mask
        mask = np.zeros((self.max_length), dtype=np.int32)
        mask[:len(tokens)] = 1

        pos1_in_index = min(self.max_length, pos1_in_index)
        pos2_in_index = min(self.max_length, pos2_in_index)

        return indexed_tokens, pos1_in_index - 1, pos2_in_index - 1, mask


class BERTPAIRSentenceEncoder(nn.Module):
    '''BERT sequence-pair classifier with 2 labels; returns its logits.'''

    def __init__(self, pretrain_path, max_length):
        nn.Module.__init__(self)
        self.bert = BertForSequenceClassification.from_pretrained(
            pretrain_path,
            num_labels=2)
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def forward(self, inputs):
        '''inputs: dict with 'word', 'seg' (token type ids) and 'mask'; return: logits'''
        x = self.bert(inputs['word'], token_type_ids=inputs['seg'], attention_mask=inputs['mask'])[0]
        return x

    def tokenize(self, raw_tokens, pos_head, pos_tail):
        '''
        WordPiece-tokenise a sentence with entity markers and no [CLS], padding
        or truncation. return: list of token ids
        '''
        tokens = []
        for cur_pos, token in enumerate(raw_tokens):
            token = token.lower()
            if cur_pos == pos_head[0]:
                tokens.append('[unused0]')
            if cur_pos == pos_tail[0]:
                tokens.append('[unused1]')
            tokens += self.tokenizer.tokenize(token)
            if cur_pos == pos_head[-1]:
                tokens.append('[unused2]')
            if cur_pos == pos_tail[-1]:
                tokens.append('[unused3]')
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens)

        return indexed_tokens


class RobertaSentenceEncoder(nn.Module):
    '''RoBERTa encoder returning the second model output or the entity-marker hidden states.'''

    def __init__(self, pretrain_path, max_length, cat_entity_rep=False):
        nn.Module.__init__(self)
        self.roberta = RobertaModel.from_pretrained(pretrain_path)
        self.max_length = max_length
        self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        self.cat_entity_rep = cat_entity_rep

    def forward(self, inputs):
        '''
        inputs: dict with 'word' and 'mask' (plus 'pos1'/'pos2' when cat_entity_rep)
        return: the model's second output, or, when cat_entity_rep is set, the
                last-layer states at pos1 and pos2 concatenated
        '''
        outputs = self.roberta(inputs['word'], attention_mask=inputs['mask'])
        if not self.cat_entity_rep:
            # Index instead of tuple-unpacking (see BERTSentenceEncoder.forward).
            return outputs[1]
        tensor_range = torch.arange(inputs['word'].size()[0])
        h_state = outputs[0][tensor_range, inputs["pos1"]]
        t_state = outputs[0][tensor_range, inputs["pos2"]]
        state = torch.cat((h_state, t_state), -1)
        return state

    def tokenize(self, raw_tokens, pos_head, pos_tail):
        '''
        BPE-tokenise a sentence and insert madeupword0000..0003 markers around
        the head and tail entities, prefixed by the tokenizer's start token.
        return: indexed_tokens (max_length ids, padded with 1),
                index of the head start marker, index of the tail start marker
                (both clipped to max_length - 1), mask (int32, 1 on real tokens)
        raises: Exception if an entity boundary cannot be mapped to a BPE offset
        '''
        def getIns(bped, bpeTokens, tokens, L):
            # Number of BPE tokens covering the first L words, checked to be a
            # prefix of the full BPE string `bped`.
            resL = 0
            tkL = " ".join(tokens[:L])
            bped_tkL = " ".join(self.tokenizer.tokenize(tkL))
            if bped.find(bped_tkL) == 0:
                resL = len(bped_tkL.split())
            else:
                tkL += " "
                bped_tkL = " ".join(self.tokenizer.tokenize(tkL))
                if bped.find(bped_tkL) == 0:
                    resL = len(bped_tkL.split())
                else:
                    raise Exception("Cannot locate the position")
            return resL

        s = " ".join(raw_tokens)
        sst = self.tokenizer.tokenize(s)
        bped = " ".join(sst)
        hiL = getIns(bped, sst, raw_tokens, pos_head[0])
        hiR = getIns(bped, sst, raw_tokens, pos_head[-1] + 1)
        tiL = getIns(bped, sst, raw_tokens, pos_tail[0])
        tiR = getIns(bped, sst, raw_tokens, pos_tail[-1] + 1)

        E1b = 'madeupword0000'
        E1e = 'madeupword0001'
        E2b = 'madeupword0002'
        E2e = 'madeupword0003'
        ins = sorted([(hiL, E1b), (hiR, E1e), (tiL, E2b), (tiR, E2e)])
        pE1 = 0
        pE2 = 0
        # Insert in ascending order; each earlier insertion shifts later offsets by one.
        for i, (offset, marker) in enumerate(ins):
            sst.insert(offset + i, marker)
            if marker == E1b:
                pE1 = offset + i
            elif marker == E2b:
                pE2 = offset + i
        # +1 accounts for the start token prepended below.
        pos1_in_index = pE1 + 1
        pos2_in_index = pE2 + 1
        # Prepend the model's start token; the empty string it replaces mapped
        # to the unknown-token id.
        sst = [self.tokenizer.cls_token] + sst
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(sst)

        # padding
        while len(indexed_tokens) < self.max_length:
            indexed_tokens.append(1)
        indexed_tokens = indexed_tokens[:self.max_length]

        # mask
        mask = np.zeros((self.max_length), dtype=np.int32)
        mask[:len(sst)] = 1

        # Clip to the last valid index: these are used to gather hidden states,
        # so max_length itself would be out of range.
        pos1_in_index = min(self.max_length - 1, pos1_in_index)
        pos2_in_index = min(self.max_length - 1, pos2_in_index)

        return indexed_tokens, pos1_in_index, pos2_in_index, mask


class RobertaPAIRSentenceEncoder(nn.Module):
    '''RoBERTa sequence-pair classifier with 2 labels; returns its logits.'''

    def __init__(self, pretrain_path, max_length):
        nn.Module.__init__(self)
        self.roberta = RobertaForSequenceClassification.from_pretrained(
            pretrain_path,
            num_labels=2)
        self.max_length = max_length
        self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    def forward(self, inputs):
        '''inputs: dict with 'word' and 'mask'; return: logits'''
        x = self.roberta(inputs['word'], attention_mask=inputs['mask'])[0]
        return x

    def tokenize(self, raw_tokens, pos_head, pos_tail):
        '''
        BPE-tokenise a sentence with madeupword0000..0003 entity markers and no
        start token, padding or truncation. return: list of token ids
        raises: Exception if an entity boundary cannot be mapped to a BPE offset
        '''
        def getIns(bped, bpeTokens, tokens, L):
            # Number of BPE tokens covering the first L words, checked to be a
            # prefix of the full BPE string `bped`.
            resL = 0
            tkL = " ".join(tokens[:L])
            bped_tkL = " ".join(self.tokenizer.tokenize(tkL))
            if bped.find(bped_tkL) == 0:
                resL = len(bped_tkL.split())
            else:
                tkL += " "
                bped_tkL = " ".join(self.tokenizer.tokenize(tkL))
                if bped.find(bped_tkL) == 0:
                    resL = len(bped_tkL.split())
                else:
                    raise Exception("Cannot locate the position")
            return resL

        s = " ".join(raw_tokens)
        sst = self.tokenizer.tokenize(s)
        bped = " ".join(sst)
        hiL = getIns(bped, sst, raw_tokens, pos_head[0])
        hiR = getIns(bped, sst, raw_tokens, pos_head[-1] + 1)
        tiL = getIns(bped, sst, raw_tokens, pos_tail[0])
        tiR = getIns(bped, sst, raw_tokens, pos_tail[-1] + 1)

        E1b = 'madeupword0000'
        E1e = 'madeupword0001'
        E2b = 'madeupword0002'
        E2e = 'madeupword0003'
        ins = sorted([(hiL, E1b), (hiR, E1e), (tiL, E2b), (tiR, E2e)])
        for i, (offset, marker) in enumerate(ins):
            sst.insert(offset + i, marker)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(sst)
        return indexed_tokens

# --------------------------------------------------------------------------------
# /models/__init__.py:
# --------------------------------------------------------------------------------
from models import proto
from models import snail
from models import gnn
from models import metanet
from models import siamese
from models import proto_norm
from models import mtb

# --------------------------------------------------------------------------------
# /models/d.py:
# --------------------------------------------------------------------------------
import sys
sys.path.append('..')
import fewshot_re_kit
import torch
from torch import autograd, optim, nn
from torch.autograd import Variable
from torch.nn import functional as F


class Discriminator(nn.Module):
    '''Two-layer MLP (Linear-ReLU-Dropout-Linear) producing `num_labels` logits.'''

    def __init__(self, hidden_size=230, num_labels=2):
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.num_labels = num_labels
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.drop = nn.Dropout()
        # Output width follows num_labels (it was hard-coded to 2, ignoring the argument).
        self.fc2 = nn.Linear(hidden_size, num_labels)

    def forward(self, x):
        '''x: (n, hidden_size) features; return: (n, num_labels) logits'''
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.drop(x)
        logits = self.fc2(x)
        return logits
-------------------------------------------------------------------------------- /models/gnn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import fewshot_re_kit 4 | import torch 5 | from torch import autograd, optim, nn 6 | from torch.autograd import Variable 7 | from torch.nn import functional as F 8 | from . import gnn_iclr 9 | 10 | class GNN(fewshot_re_kit.framework.FewShotREModel): 11 | 12 | def __init__(self, sentence_encoder, N, hidden_size=230): 13 | ''' 14 | N: Num of classes 15 | ''' 16 | fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder) 17 | self.hidden_size = hidden_size 18 | self.node_dim = hidden_size + N 19 | self.gnn_obj = gnn_iclr.GNN_nl(N, self.node_dim, nf=96, J=1) 20 | 21 | def forward(self, support, query, N, K, NQ): 22 | ''' 23 | support: Inputs of the support set. 24 | query: Inputs of the query set. 25 | N: Num of classes 26 | K: Num of instances for each class in the support set 27 | Q: Num of instances for each class in the query set 28 | ''' 29 | support = self.sentence_encoder(support) 30 | query = self.sentence_encoder(query) 31 | support = support.view(-1, N, K, self.hidden_size) 32 | query = query.view(-1, NQ, self.hidden_size) 33 | 34 | B = support.size(0) 35 | D = self.hidden_size 36 | 37 | support = support.unsqueeze(1).expand(-1, NQ, -1, -1, -1).contiguous().view(-1, N * K, D) # (B * NQ, N * K, D) 38 | query = query.view(-1, 1, D) # (B * NQ, 1, D) 39 | labels = Variable(torch.zeros((B * NQ, 1 + N * K, N), dtype=torch.float)).cuda() 40 | for b in range(B * NQ): 41 | for i in range(N): 42 | for k in range(K): 43 | labels[b][1 + i * K + k][i] = 1 44 | nodes = torch.cat([torch.cat([query, support], 1), labels], -1) # (B * NQ, 1 + N * K, D + N) 45 | 46 | logits = self.gnn_obj(nodes) # (B * NQ, N) 47 | _, pred = torch.max(logits, 1) 48 | return logits, pred 49 | 
-------------------------------------------------------------------------------- /models/gnn_iclr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | 4 | # Pytorch requirements 5 | 6 | ''' 7 | GNN models implemented by vgsatorras from https://github.com/vgsatorras/few-shot-gnn 8 | ''' 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Variable 13 | import torch.nn.functional as F 14 | 15 | if torch.cuda.is_available(): 16 | dtype = torch.cuda.FloatTensor 17 | dtype_l = torch.cuda.LongTensor 18 | else: 19 | dtype = torch.FloatTensor 20 | dtype_l = torch.cuda.LongTensor 21 | 22 | 23 | def gmul(input): 24 | W, x = input 25 | # x is a tensor of size (bs, N, num_features) 26 | # W is a tensor of size (bs, N, N, J) 27 | x_size = x.size() 28 | W_size = W.size() 29 | N = W_size[-2] 30 | W = W.split(1, 3) 31 | W = torch.cat(W, 1).squeeze(3) # W is now a tensor of size (bs, J*N, N) 32 | output = torch.bmm(W, x) # output has size (bs, J*N, num_features) 33 | output = output.split(N, 1) 34 | output = torch.cat(output, 2) # output has size (bs, N, J*num_features) 35 | return output 36 | 37 | 38 | class Gconv(nn.Module): 39 | def __init__(self, nf_input, nf_output, J, bn_bool=True): 40 | super(Gconv, self).__init__() 41 | self.J = J 42 | self.num_inputs = J*nf_input 43 | self.num_outputs = nf_output 44 | self.fc = nn.Linear(self.num_inputs, self.num_outputs) 45 | 46 | self.bn_bool = bn_bool 47 | if self.bn_bool: 48 | self.bn = nn.BatchNorm1d(self.num_outputs) 49 | 50 | def forward(self, input): 51 | W = input[0] 52 | x = gmul(input) # out has size (bs, N, num_inputs) 53 | #if self.J == 1: 54 | # x = torch.abs(x) 55 | x_size = x.size() 56 | x = x.contiguous() 57 | x = x.view(-1, self.num_inputs) 58 | x = self.fc(x) # has size (bs*N, num_outputs) 59 | 60 | if self.bn_bool: 61 | x = self.bn(x) 62 | 63 | x = x.view(x_size[0], x_size[1], self.num_outputs) 64 | return 
W, x 65 | 66 | 67 | class Wcompute(nn.Module): 68 | def __init__(self, input_features, nf, operator='J2', activation='softmax', ratio=[2,2,1,1], num_operators=1, drop=False): 69 | super(Wcompute, self).__init__() 70 | self.num_features = nf 71 | self.operator = operator 72 | self.conv2d_1 = nn.Conv2d(input_features, int(nf * ratio[0]), 1, stride=1) 73 | self.bn_1 = nn.BatchNorm2d(int(nf * ratio[0])) 74 | self.drop = drop 75 | if self.drop: 76 | self.dropout = nn.Dropout(0.3) 77 | self.conv2d_2 = nn.Conv2d(int(nf * ratio[0]), int(nf * ratio[1]), 1, stride=1) 78 | self.bn_2 = nn.BatchNorm2d(int(nf * ratio[1])) 79 | self.conv2d_3 = nn.Conv2d(int(nf * ratio[1]), nf*ratio[2], 1, stride=1) 80 | self.bn_3 = nn.BatchNorm2d(nf*ratio[2]) 81 | self.conv2d_4 = nn.Conv2d(nf*ratio[2], nf*ratio[3], 1, stride=1) 82 | self.bn_4 = nn.BatchNorm2d(nf*ratio[3]) 83 | self.conv2d_last = nn.Conv2d(nf, num_operators, 1, stride=1) 84 | self.activation = activation 85 | 86 | def forward(self, x, W_id): 87 | W1 = x.unsqueeze(2) 88 | W2 = torch.transpose(W1, 1, 2) #size: bs x N x N x num_features 89 | W_new = torch.abs(W1 - W2) #size: bs x N x N x num_features 90 | W_new = torch.transpose(W_new, 1, 3) #size: bs x num_features x N x N 91 | 92 | W_new = self.conv2d_1(W_new) 93 | W_new = self.bn_1(W_new) 94 | W_new = F.leaky_relu(W_new) 95 | if self.drop: 96 | W_new = self.dropout(W_new) 97 | 98 | W_new = self.conv2d_2(W_new) 99 | W_new = self.bn_2(W_new) 100 | W_new = F.leaky_relu(W_new) 101 | 102 | W_new = self.conv2d_3(W_new) 103 | W_new = self.bn_3(W_new) 104 | W_new = F.leaky_relu(W_new) 105 | 106 | W_new = self.conv2d_4(W_new) 107 | W_new = self.bn_4(W_new) 108 | W_new = F.leaky_relu(W_new) 109 | 110 | W_new = self.conv2d_last(W_new) 111 | W_new = torch.transpose(W_new, 1, 3) #size: bs x N x N x 1 112 | 113 | if self.activation == 'softmax': 114 | W_new = W_new - W_id.expand_as(W_new) * 1e8 115 | W_new = torch.transpose(W_new, 2, 3) 116 | # Applying Softmax 117 | W_new = W_new.contiguous() 
118 | W_new_size = W_new.size() 119 | W_new = W_new.view(-1, W_new.size(3)) 120 | W_new = F.softmax(W_new) 121 | W_new = W_new.view(W_new_size) 122 | # Softmax applied 123 | W_new = torch.transpose(W_new, 2, 3) 124 | 125 | elif self.activation == 'sigmoid': 126 | W_new = F.sigmoid(W_new) 127 | W_new *= (1 - W_id) 128 | elif self.activation == 'none': 129 | W_new *= (1 - W_id) 130 | else: 131 | raise (NotImplementedError) 132 | 133 | if self.operator == 'laplace': 134 | W_new = W_id - W_new 135 | elif self.operator == 'J2': 136 | W_new = torch.cat([W_id, W_new], 3) 137 | else: 138 | raise(NotImplementedError) 139 | 140 | return W_new 141 | 142 | 143 | class GNN_nl_omniglot(nn.Module): 144 | def __init__(self, args, input_features, nf, J): 145 | super(GNN_nl_omniglot, self).__init__() 146 | self.args = args 147 | self.input_features = input_features 148 | self.nf = nf 149 | self.J = J 150 | 151 | self.num_layers = 2 152 | for i in range(self.num_layers): 153 | module_w = Wcompute(self.input_features + int(nf / 2) * i, 154 | self.input_features + int(nf / 2) * i, 155 | operator='J2', activation='softmax', ratio=[2, 1.5, 1, 1], drop=False) 156 | module_l = Gconv(self.input_features + int(nf / 2) * i, int(nf / 2), 2) 157 | self.add_module('layer_w{}'.format(i), module_w) 158 | self.add_module('layer_l{}'.format(i), module_l) 159 | 160 | self.w_comp_last = Wcompute(self.input_features + int(self.nf / 2) * self.num_layers, 161 | self.input_features + int(self.nf / 2) * (self.num_layers - 1), 162 | operator='J2', activation='softmax', ratio=[2, 1.5, 1, 1], drop=True) 163 | self.layer_last = Gconv(self.input_features + int(self.nf / 2) * self.num_layers, args.train_N_way, 2, bn_bool=True) 164 | 165 | def forward(self, x): 166 | W_init = Variable(torch.eye(x.size(1)).unsqueeze(0).repeat(x.size(0), 1, 1).unsqueeze(3)) 167 | if self.args.cuda: 168 | W_init = W_init.cuda() 169 | 170 | for i in range(self.num_layers): 171 | Wi = self._modules['layer_w{}'.format(i)](x, W_init) 
172 | 173 | x_new = F.leaky_relu(self._modules['layer_l{}'.format(i)]([Wi, x])[1]) 174 | x = torch.cat([x, x_new], 2) 175 | 176 | Wl=self.w_comp_last(x, W_init) 177 | out = self.layer_last([Wl, x])[1] 178 | 179 | return out[:, 0, :] 180 | 181 | 182 | class GNN_nl(nn.Module): 183 | def __init__(self, N, input_features, nf, J): 184 | super(GNN_nl, self).__init__() 185 | # self.args = args 186 | self.input_features = input_features 187 | self.nf = nf 188 | self.J = J 189 | 190 | self.num_layers = 2 191 | 192 | for i in range(self.num_layers): 193 | if i == 0: 194 | module_w = Wcompute(self.input_features, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 195 | module_l = Gconv(self.input_features, int(nf / 2), 2) 196 | else: 197 | module_w = Wcompute(self.input_features + int(nf / 2) * i, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 198 | module_l = Gconv(self.input_features + int(nf / 2) * i, int(nf / 2), 2) 199 | self.add_module('layer_w{}'.format(i), module_w) 200 | self.add_module('layer_l{}'.format(i), module_l) 201 | 202 | self.w_comp_last = Wcompute(self.input_features + int(self.nf / 2) * self.num_layers, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 203 | self.layer_last = Gconv(self.input_features + int(self.nf / 2) * self.num_layers, N, 2, bn_bool=False) 204 | 205 | def forward(self, x): 206 | W_init = Variable(torch.eye(x.size(1)).unsqueeze(0).repeat(x.size(0), 1, 1).unsqueeze(3)) 207 | W_init = W_init.cuda() 208 | 209 | for i in range(self.num_layers): 210 | Wi = self._modules['layer_w{}'.format(i)](x, W_init) 211 | 212 | x_new = F.leaky_relu(self._modules['layer_l{}'.format(i)]([Wi, x])[1]) 213 | x = torch.cat([x, x_new], 2) 214 | 215 | Wl=self.w_comp_last(x, W_init) 216 | out = self.layer_last([Wl, x])[1] 217 | 218 | return out[:, 0, :] 219 | 220 | class GNN_active(nn.Module): 221 | def __init__(self, args, input_features, nf, J): 222 | super(GNN_active, self).__init__() 223 | self.args = args 224 | 
self.input_features = input_features 225 | self.nf = nf 226 | self.J = J 227 | 228 | self.num_layers = 2 229 | for i in range(self.num_layers // 2): 230 | if i == 0: 231 | module_w = Wcompute(self.input_features, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 232 | module_l = Gconv(self.input_features, int(nf / 2), 2) 233 | else: 234 | module_w = Wcompute(self.input_features + int(nf / 2) * i, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 235 | module_l = Gconv(self.input_features + int(nf / 2) * i, int(nf / 2), 2) 236 | 237 | self.add_module('layer_w{}'.format(i), module_w) 238 | self.add_module('layer_l{}'.format(i), module_l) 239 | 240 | self.conv_active_1 = nn.Conv1d(self.input_features + int(nf / 2) * 1, self.input_features + int(nf / 2) * 1, 1) 241 | self.bn_active = nn.BatchNorm1d(self.input_features + int(nf / 2) * 1) 242 | self.conv_active_2 = nn.Conv1d(self.input_features + int(nf / 2) * 1, 1, 1) 243 | 244 | for i in range(int(self.num_layers/2), self.num_layers): 245 | if i == 0: 246 | module_w = Wcompute(self.input_features, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 247 | module_l = Gconv(self.input_features, int(nf / 2), 2) 248 | else: 249 | module_w = Wcompute(self.input_features + int(nf / 2) * i, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 250 | module_l = Gconv(self.input_features + int(nf / 2) * i, int(nf / 2), 2) 251 | self.add_module('layer_w{}'.format(i), module_w) 252 | self.add_module('layer_l{}'.format(i), module_l) 253 | 254 | self.w_comp_last = Wcompute(self.input_features + int(self.nf / 2) * self.num_layers, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 255 | self.layer_last = Gconv(self.input_features + int(self.nf / 2) * self.num_layers, args.train_N_way, 2, bn_bool=False) 256 | 257 | def active(self, x, oracles_yi, hidden_labels): 258 | x_active = torch.transpose(x, 1, 2) 259 | x_active = self.conv_active_1(x_active) 260 | x_active = 
F.leaky_relu(self.bn_active(x_active)) 261 | x_active = self.conv_active_2(x_active) 262 | x_active = torch.transpose(x_active, 1, 2) 263 | 264 | x_active = x_active.squeeze(-1) 265 | x_active = x_active - (1-hidden_labels)*1e8 266 | x_active = F.softmax(x_active) 267 | x_active = x_active*hidden_labels 268 | 269 | if self.args.active_random == 1: 270 | #print('random active') 271 | x_active.data.fill_(1./x_active.size(1)) 272 | decision = torch.multinomial(x_active) 273 | x_active = x_active.detach() 274 | else: 275 | if self.training: 276 | decision = torch.multinomial(x_active) 277 | else: 278 | _, decision = torch.max(x_active, 1) 279 | decision = decision.unsqueeze(-1) 280 | 281 | decision = decision.detach() 282 | 283 | mapping = torch.FloatTensor(decision.size(0),x_active.size(1)).zero_() 284 | mapping = Variable(mapping) 285 | if self.args.cuda: 286 | mapping = mapping.cuda() 287 | mapping.scatter_(1, decision, 1) 288 | 289 | mapping_bp = (x_active*mapping).unsqueeze(-1) 290 | mapping_bp = mapping_bp.expand_as(oracles_yi) 291 | 292 | label2add = mapping_bp*oracles_yi #bsxNodesxN_way 293 | padd = torch.zeros(x.size(0), x.size(1), x.size(2) - label2add.size(2)) 294 | padd = Variable(padd).detach() 295 | if self.args.cuda: 296 | padd = padd.cuda() 297 | label2add = torch.cat([label2add, padd], 2) 298 | 299 | x = x+label2add 300 | return x 301 | 302 | 303 | def forward(self, x, oracles_yi, hidden_labels): 304 | W_init = Variable(torch.eye(x.size(1)).unsqueeze(0).repeat(x.size(0), 1, 1).unsqueeze(3)) 305 | if self.args.cuda: 306 | W_init = W_init.cuda() 307 | 308 | for i in range(self.num_layers // 2): 309 | Wi = self._modules['layer_w{}'.format(i)](x, W_init) 310 | x_new = F.leaky_relu(self._modules['layer_l{}'.format(i)]([Wi, x])[1]) 311 | x = torch.cat([x, x_new], 2) 312 | 313 | x = self.active(x, oracles_yi, hidden_labels) 314 | 315 | for i in range(int(self.num_layers/2), self.num_layers): 316 | Wi = self._modules['layer_w{}'.format(i)](x, W_init) 317 | 
x_new = F.leaky_relu(self._modules['layer_l{}'.format(i)]([Wi, x])[1]) 318 | x = torch.cat([x, x_new], 2) 319 | 320 | 321 | Wl=self.w_comp_last(x, W_init) 322 | out = self.layer_last([Wl, x])[1] 323 | 324 | return out[:, 0, :] 325 | 326 | if __name__ == '__main__': 327 | # test modules 328 | bs = 4 329 | nf = 10 330 | num_layers = 5 331 | N = 8 332 | x = torch.ones((bs, N, nf)) 333 | W1 = torch.eye(N).unsqueeze(0).unsqueeze(-1).expand(bs, N, N, 1) 334 | W2 = torch.ones(N).unsqueeze(0).unsqueeze(-1).expand(bs, N, N, 1) 335 | J = 2 336 | W = torch.cat((W1, W2), 3) 337 | input = [Variable(W), Variable(x)] 338 | ######################### test gmul ############################## 339 | # feature_maps = [num_features, num_features, num_features] 340 | # out = gmul(input) 341 | # print(out[0, :, num_features:]) 342 | ######################### test gconv ############################## 343 | # feature_maps = [num_features, num_features, num_features] 344 | # gconv = Gconv(feature_maps, J) 345 | # _, out = gconv(input) 346 | # print(out.size()) 347 | ######################### test gnn ############################## 348 | # x = torch.ones((bs, N, 1)) 349 | # input = [Variable(W), Variable(x)] 350 | # gnn = GNN(num_features, num_layers, J) 351 | # out = gnn(input) 352 | # print(out.size()) 353 | 354 | 355 | -------------------------------------------------------------------------------- /models/metanet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import fewshot_re_kit 4 | from fewshot_re_kit.network.embedding import Embedding 5 | from fewshot_re_kit.network.encoder import Encoder 6 | import torch 7 | from torch import autograd, optim, nn 8 | from torch.autograd import Variable 9 | from torch.nn import functional as F 10 | import numpy as np 11 | 12 | def log_and_sign(inputs, k=7): 13 | eps = 1e-7 14 | log = torch.log(torch.abs(inputs) + eps) / k 15 | log[log < -1.0] = -1.0 16 | sign = log * 
np.exp(k) 17 | sign[sign < -1.0] = -1.0 18 | sign[sign > 1.0] = 1.0 19 | return torch.cat([log, sign], 1) 20 | 21 | class LearnerForAttention(nn.Module): 22 | 23 | def __init__(self): 24 | nn.Module.__init__(self) 25 | self.conv_lstm = nn.LSTM(2, 20, batch_first=True) 26 | self.conv_fc = nn.Linear(20, 1) 27 | self.fc_lstm = nn.LSTM(2, 20, batch_first=True) 28 | self.fc_fc = nn.Linear(20, 1) 29 | 30 | def forward(self, inputs, is_conv): 31 | size = inputs.size() 32 | x = inputs.view((-1, 1)) 33 | x = log_and_sign(x) # (-1, 2) 34 | 35 | #### NO BACKPROP 36 | x = Variable(x, requires_grad=False).unsqueeze(0) # (1, param_size, 2) 37 | #### 38 | 39 | if is_conv: 40 | x, _ = self.conv_lstm(x) # (1, param_size, 1) 41 | x = x.squeeze() 42 | x = self.conv_fc(x) 43 | else: 44 | x, _ = self.fc_lstm(x) # (1, param_size, 1) 45 | x = x.squeeze() 46 | x = self.fc_fc(x) 47 | return x.view(size) 48 | 49 | class LearnerForBasic(nn.Module): 50 | 51 | def __init__(self): 52 | nn.Module.__init__(self) 53 | self.conv_fc1 = nn.Linear(2, 20) 54 | self.conv_fc2 = nn.Linear(20, 20) 55 | self.conv_fc3 = nn.Linear(20, 1) 56 | self.fc_fc1 = nn.Linear(2, 20) 57 | self.fc_fc2 = nn.Linear(20, 20) 58 | self.fc_fc3 = nn.Linear(20, 1) 59 | 60 | 61 | def forward(self, inputs, is_conv): 62 | size = inputs.size() 63 | x = inputs.view((-1, 1)) 64 | x = log_and_sign(x) # (-1, 2) 65 | 66 | #### NO BACKPROP 67 | x = Variable(x, requires_grad=False) 68 | #### 69 | 70 | if is_conv: 71 | x = F.relu(self.conv_fc1(x)) 72 | x = F.relu(self.conv_fc2(x)) 73 | x = self.conv_fc3(x) 74 | else: 75 | x = F.relu(self.fc_fc1(x)) 76 | x = F.relu(self.fc_fc2(x)) 77 | x = self.fc_fc3(x) 78 | return x.view(size) 79 | 80 | class MetaNet(fewshot_re_kit.framework.FewShotREModel): 81 | 82 | def __init__(self, N, K, embedding, max_length, hidden_size=230): 83 | ''' 84 | N: num of classes 85 | K: num of instances for each class 86 | word_vec_mat, max_length, hidden_size: same as sentence_encoder 87 | ''' 88 | 
fewshot_re_kit.framework.FewShotREModel.__init__(self, None) 89 | self.max_length = max_length 90 | self.hidden_size = hidden_size 91 | self.N = N 92 | self.K = K 93 | 94 | # self.embedding = Embedding(word_vec_mat, max_length, word_embedding_dim=50, pos_embedding_dim=5) 95 | self.embedding = embedding 96 | 97 | self.basic_encoder = Encoder(max_length, word_embedding_dim=50, pos_embedding_dim=5, hidden_size=hidden_size) 98 | self.attention_encoder = Encoder(max_length, word_embedding_dim=50, pos_embedding_dim=5, hidden_size=hidden_size) 99 | 100 | self.basic_fast_conv_W = None 101 | self.attention_fast_conv_W = None 102 | 103 | self.basic_fc = nn.Linear(hidden_size, N, bias=False) 104 | self.attention_fc = nn.Linear(hidden_size, N, bias=False) 105 | 106 | self.basic_fast_fc_W = None 107 | self.attention_fast_fc_W = None 108 | 109 | self.learner_basic = LearnerForBasic() 110 | self.learner_attention = LearnerForAttention() 111 | 112 | def basic_emb(self, inputs, size, use_fast=False): 113 | x = self.embedding(inputs) 114 | output = self.basic_encoder(x) 115 | if use_fast: 116 | output += F.relu(F.conv1d(x.transpose(-1, -2), self.basic_fast_conv_W, padding=1)).max(-1)[0] 117 | return output.view(size) 118 | 119 | def attention_emb(self, inputs, size, use_fast=False): 120 | x = self.embedding(inputs) 121 | output = self.attention_encoder(x) 122 | if use_fast: 123 | output += F.relu(F.conv1d(x.transpose(-1, -2), self.attention_fast_conv_W, padding=1)).max(-1)[0] 124 | return output.view(size) 125 | 126 | def attention_score(self, s_att, q_att): 127 | ''' 128 | s_att: (B, N, K, D) 129 | q_att: (B, NQ, D) 130 | ''' 131 | s_att = s_att.view(s_att.size(0), s_att.size(1) * s_att.size(2), s_att.size(3)) # (B, N * K, D) 132 | s_att = s_att.unsqueeze(1) # (B, 1, N * K, D) 133 | q_att = q_att.unsqueeze(2) # (B, NQ, 1, D) 134 | cos = F.cosine_similarity(s_att, q_att, dim=-1) # (B, NQ, N * K) 135 | score = F.softmax(cos, -1) # (B, NQ, N * K) 136 | return score 137 | 138 | def 
forward(self, support, query, N, K, Q): 139 | ''' 140 | support: Inputs of the support set. 141 | query: Inputs of the query set. 142 | N: Num of classes 143 | K: Num of instances for each class in the support set 144 | Q: Num of instances for each class in the query set 145 | ''' 146 | 147 | # learn fast parameters for attention encoder 148 | s = self.attention_emb(support, (-1, N, K, self.hidden_size)) 149 | logits = self.attention_fc(s) # (B, N, K, N) 150 | 151 | B = s.size(0) 152 | NQ = N * Q 153 | assert(B == 1) 154 | 155 | self.zero_grad() 156 | tmp_label = Variable(torch.tensor([[x] * K for x in range(N)] * B, dtype=torch.long).cuda()) 157 | loss = self.cost(logits.view(-1, N), tmp_label.view(-1)) 158 | loss.backward(retain_graph=True) 159 | 160 | grad_conv = self.attention_encoder.conv.weight.grad 161 | grad_fc = self.attention_fc.weight.grad 162 | 163 | self.attention_fast_conv_W = self.learner_attention(grad_conv, is_conv=True) 164 | self.attention_fast_fc_W = self.learner_attention(grad_fc, is_conv=False) 165 | 166 | # learn fast parameters for basic encoder (each class) 167 | s = self.basic_emb(support, (-1, N, K, self.hidden_size)) 168 | logits = self.basic_fc(s) # (B, N, K, N) 169 | 170 | basic_fast_conv_params = [] 171 | basic_fast_fc_params = [] 172 | for i in range(N): 173 | for j in range(K): 174 | self.zero_grad() 175 | tmp_label = Variable(torch.tensor([i], dtype=torch.long).cuda()) 176 | loss = self.cost(logits[:, i, j].view(-1, N), tmp_label.view(-1)) 177 | loss.backward(retain_graph=True) 178 | 179 | grad_conv = self.basic_encoder.conv.weight.grad 180 | grad_fc = self.basic_fc.weight.grad 181 | 182 | basic_fast_conv_params.append(self.learner_basic(grad_conv, is_conv=True)) 183 | basic_fast_fc_params.append(self.learner_basic(grad_fc, is_conv=False)) 184 | basic_fast_conv_params = torch.stack(basic_fast_conv_params, 0) # (N * K, conv_weight_size) 185 | basic_fast_fc_params = torch.stack(basic_fast_fc_params, 0) # (N * K, fc_weight_size) 186 | 
187 | # final 188 | self.zero_grad() 189 | s_att = self.attention_emb(support, (-1, N, K, self.hidden_size), use_fast=True) 190 | q_att = self.attention_emb(query, (-1, NQ, self.hidden_size), use_fast=True) 191 | score = self.attention_score(s_att, q_att).squeeze(0) # assume B = 1, (NQ, N * K) 192 | size_conv_param = basic_fast_conv_params.size()[1:] 193 | size_fc_param = basic_fast_fc_params.size()[1:] 194 | final_fast_conv_param = torch.matmul(score, basic_fast_conv_params.view(N * K, -1)) # (NQ, conv_weight_size) 195 | final_fast_fc_param = torch.matmul(score, basic_fast_fc_params.view(N * K, -1)) # (NQ, fc_weight_size) 196 | stack_logits = [] 197 | for i in range(NQ): 198 | self.basic_fast_conv_W = final_fast_conv_param[i].view(size_conv_param) 199 | self.basic_fast_fc_W = final_fast_fc_param[i].view(size_fc_param) 200 | q = self.basic_emb({'word': query['word'][i:i+1], 'pos1': query['pos1'][i:i+1], 'pos2': query['pos2'][i:i+1], 'mask': query['mask'][i:i+1]}, (self.hidden_size), use_fast=True) 201 | logits = self.basic_fc(q) + F.linear(q, self.basic_fast_fc_W) 202 | stack_logits.append(logits) 203 | logits = torch.stack(stack_logits, 0) 204 | 205 | _, pred = torch.max(logits.view(-1, N), 1) 206 | return logits, pred 207 | -------------------------------------------------------------------------------- /models/mtb.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import fewshot_re_kit 4 | import torch 5 | from torch import autograd, optim, nn 6 | from torch.autograd import Variable 7 | from torch.nn import functional as F 8 | 9 | class Mtb(fewshot_re_kit.framework.FewShotREModel): 10 | """ 11 | Use the same few-shot model as the paper "Matching the Blanks: Distributional Similarity for Relation Learning". 
12 | """ 13 | 14 | def __init__(self, sentence_encoder, use_dropout=True, combiner="max"): 15 | fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder) 16 | # self.fc = nn.Linear(hidden_size, hidden_size) 17 | self.drop = nn.Dropout() 18 | self.use_dropout = use_dropout 19 | self.layer_norm = torch.nn.LayerNorm(sentence_encoder.bert.config.hidden_size * (2 if sentence_encoder.cat_entity_rep else 1)) 20 | self.combiner = combiner 21 | 22 | def __dist__(self, x, y, dim): 23 | return (x * y).sum(dim) 24 | 25 | def __batch_dist__(self, S, Q): 26 | return self.__dist__(S.unsqueeze(1), Q.unsqueeze(2), 3) 27 | 28 | def forward(self, support, query, N, K, total_Q): 29 | ''' 30 | support: Inputs of the support set. 31 | query: Inputs of the query set. 32 | N: Num of classes 33 | K: Num of instances for each class in the support set 34 | Q: Num of instances in the query set 35 | ''' 36 | support = self.sentence_encoder(support) # (B * N * K, D), where D is the hidden size 37 | query = self.sentence_encoder(query) # (B * total_Q, D) 38 | hidden_size = support.size(-1) 39 | if self.use_dropout: 40 | support = self.drop(support) 41 | query = self.drop(query) 42 | support = self.layer_norm(support) 43 | query = self.layer_norm(query) 44 | support = support.view(-1, N, K, hidden_size).unsqueeze(1) # (B, 1, N, K, D) 45 | query = query.view(-1, total_Q, hidden_size).unsqueeze(2).unsqueeze(2) # (B, total_Q, 1, 1, D) 46 | 47 | logits = (support * query).sum(-1) # (B, total_Q, N, K) 48 | 49 | # aggregate result 50 | if self.combiner == "max": 51 | combined_logits, _ = logits.max(-1) # (B, total, N) 52 | elif self.combiner == "avg": 53 | combined_logits = logits.mean(-1) # (B, total, N) 54 | else: 55 | raise NotImplementedError 56 | _, pred = torch.max(combined_logits.view(-1, N), -1) 57 | 58 | return combined_logits, pred 59 | 60 | 61 | -------------------------------------------------------------------------------- /models/pair.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import fewshot_re_kit 4 | import torch 5 | from torch import autograd, optim, nn 6 | from torch.autograd import Variable 7 | from torch.nn import functional as F 8 | 9 | class Pair(fewshot_re_kit.framework.FewShotREModel): 10 | 11 | def __init__(self, sentence_encoder, hidden_size=230): 12 | fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder) 13 | self.hidden_size = hidden_size 14 | # self.fc = nn.Linear(hidden_size, hidden_size) 15 | self.drop = nn.Dropout() 16 | 17 | def forward(self, batch, N, K, total_Q): 18 | ''' 19 | support: Inputs of the support set. 20 | query: Inputs of the query set. 21 | N: Num of classes 22 | K: Num of instances for each class in the support set 23 | Q: Num of instances in the query set 24 | ''' 25 | logits = self.sentence_encoder(batch) 26 | logits = logits.view(-1, total_Q, N, K, 2) 27 | logits = logits.mean(3) # (-1, total_Q, N, 2) 28 | logits_na, _ = logits[:, :, :, 0].min(2, keepdim=True) # (-1, totalQ, 1) 29 | logits = logits[:, :, :, 1] # (-1, total_Q, N) 30 | logits = torch.cat([logits, logits_na], 2) # (B, total_Q, N + 1) 31 | _, pred = torch.max(logits.view(-1, N + 1), 1) 32 | return logits, pred 33 | -------------------------------------------------------------------------------- /models/proto.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import fewshot_re_kit 4 | import torch 5 | from torch import autograd, optim, nn 6 | from torch.autograd import Variable 7 | from torch.nn import functional as F 8 | 9 | class Proto(fewshot_re_kit.framework.FewShotREModel): 10 | 11 | def __init__(self, sentence_encoder, dot=False): 12 | fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder) 13 | # self.fc = nn.Linear(hidden_size, hidden_size) 14 | self.drop = nn.Dropout() 15 | self.dot = dot 16 | 
17 | def __dist__(self, x, y, dim): 18 | if self.dot: 19 | return (x * y).sum(dim) 20 | else: 21 | return -(torch.pow(x - y, 2)).sum(dim) 22 | 23 | def __batch_dist__(self, S, Q): 24 | return self.__dist__(S.unsqueeze(1), Q.unsqueeze(2), 3) 25 | 26 | def forward(self, support, query, N, K, total_Q): 27 | ''' 28 | support: Inputs of the support set. 29 | query: Inputs of the query set. 30 | N: Num of classes 31 | K: Num of instances for each class in the support set 32 | Q: Num of instances in the query set 33 | ''' 34 | support_emb = self.sentence_encoder(support) # (B * N * K, D), where D is the hidden size 35 | query_emb = self.sentence_encoder(query) # (B * total_Q, D) 36 | hidden_size = support_emb.size(-1) 37 | support = self.drop(support_emb) 38 | query = self.drop(query_emb) 39 | support = support.view(-1, N, K, hidden_size) # (B, N, K, D) 40 | query = query.view(-1, total_Q, hidden_size) # (B, total_Q, D) 41 | 42 | # Prototypical Networks 43 | # Ignore NA policy 44 | support = torch.mean(support, 2) # Calculate prototype for each class 45 | logits = self.__batch_dist__(support, query) # (B, total_Q, N) 46 | minn, _ = logits.min(-1) 47 | logits = torch.cat([logits, minn.unsqueeze(2) - 1], 2) # (B, total_Q, N + 1) 48 | _, pred = torch.max(logits.view(-1, N + 1), 1) 49 | return logits, pred 50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /models/proto_norm.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import fewshot_re_kit 4 | import torch 5 | from torch import autograd, optim, nn 6 | from torch.autograd import Variable 7 | from torch.nn import functional as F 8 | 9 | def l2norm(X): 10 | norm = torch.pow(X, 2).sum(dim=-1, keepdim=True).sqrt() 11 | X = torch.div(X, norm) 12 | return X 13 | 14 | class ProtoNorm(fewshot_re_kit.framework.FewShotREModel): 15 | 16 | def __init__(self, sentence_encoder, hidden_size=230): 17 | 
fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder) 18 | self.hidden_size = hidden_size 19 | # self.fc = nn.Linear(hidden_size, hidden_size) 20 | self.drop = nn.Dropout() 21 | 22 | def __dist__(self, x, y, dim): 23 | return (torch.pow(x - y, 2)).sum(dim) 24 | 25 | def __batch_dist__(self, S, Q): 26 | return self.__dist__(S.unsqueeze(1), Q.unsqueeze(2), 3) 27 | 28 | def forward(self, support, query, N, K, total_Q): 29 | ''' 30 | support: Inputs of the support set. 31 | query: Inputs of the query set. 32 | N: Num of classes 33 | K: Num of instances for each class in the support set 34 | Q: Num of instances in the query set 35 | ''' 36 | support = self.sentence_encoder(support) # (B * N * K, D), where D is the hidden size 37 | query = self.sentence_encoder(query) # (B * total_Q, D) 38 | support = l2norm(support) 39 | query = l2norm(query) 40 | support = self.drop(support) 41 | query = self.drop(query) 42 | support = support.view(-1, N, K, self.hidden_size) # (B, N, K, D) 43 | query = query.view(-1, total_Q, self.hidden_size) # (B, total_Q, D) 44 | 45 | # Prototypical Networks 46 | # Ignore NA policy 47 | support = torch.mean(support, 2) # Calculate prototype for each class 48 | logits = -self.__batch_dist__(support, query) # (B, total_Q, N) 49 | minn, _ = logits.min(-1) 50 | logits = torch.cat([logits, minn.unsqueeze(2) - 1], 2) # (B, total_Q, N + 1) 51 | _, pred = torch.max(logits.view(-1, N + 1), 1) 52 | return logits, pred 53 | 54 | 55 | -------------------------------------------------------------------------------- /models/siamese.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import fewshot_re_kit 4 | import torch 5 | from torch import autograd, optim, nn 6 | from torch.autograd import Variable 7 | from torch.nn import functional as F 8 | 9 | class Siamese(fewshot_re_kit.framework.FewShotREModel): 10 | 11 | def __init__(self, sentence_encoder, hidden_size=230, 
dropout=0): 12 | fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder) 13 | self.hidden_size = hidden_size 14 | self.normalize = nn.LayerNorm(normalized_shape=hidden_size) 15 | self.drop = nn.Dropout(dropout) 16 | 17 | def forward(self, support, query, N, K, total_Q): 18 | ''' 19 | support: Inputs of the support set. 20 | query: Inputs of the query set. 21 | N: Num of classes 22 | K: Num of instances for each class in the support set 23 | Q: Num of instances in the query set 24 | ''' 25 | 26 | support = self.sentence_encoder(support) # (B * N * K, D), where D is the hidden size 27 | query = self.sentence_encoder(query) # (B * total_Q, D) 28 | 29 | # Layer Norm 30 | support = self.normalize(support) 31 | query = self.normalize(query) 32 | 33 | # Dropout ? 34 | support = self.drop(support) 35 | query = self.drop(query) 36 | 37 | support = support.view(-1, N * K, self.hidden_size) # (B, N * K, D) 38 | query = query.view(-1, total_Q, self.hidden_size) # (B, total_Q, D) 39 | B = support.size(0) # Batch size 40 | support = support.unsqueeze(1) # (B, 1, N * K, D) 41 | query = query.unsqueeze(2) # (B, total_Q, 1, D) 42 | 43 | # Dot production 44 | z = (support * query).sum(-1) # (B, total_Q, N * K) 45 | z = z.view(-1, total_Q, N, K) # (B, total_Q, N, K) 46 | 47 | # Max combination 48 | logits = z.max(-1)[0] # (B, total_Q, N) 49 | 50 | # NA 51 | minn, _ = logits.min(-1) 52 | logits = torch.cat([logits, minn.unsqueeze(2) - 1], 2) # (B, total_Q, N + 1) 53 | 54 | _, pred = torch.max(logits.view(-1, N+1), 1) 55 | return logits, pred 56 | -------------------------------------------------------------------------------- /models/snail.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import fewshot_re_kit 4 | import torch 5 | from torch import autograd, optim, nn 6 | from torch.autograd import Variable 7 | from torch.nn import functional as F 8 | import numpy as np 9 | 10 | class 
CausalConv1d(nn.Module): 11 | 12 | def __init__(self, in_channels, out_channels, kernel_size=2, dilation=2): 13 | super(CausalConv1d, self).__init__() 14 | self.padding = dilation 15 | self.causal_conv = nn.Conv1d( 16 | in_channels, 17 | out_channels, 18 | kernel_size, 19 | padding = self.padding, 20 | dilation = dilation) 21 | 22 | def forward(self, minibatch): 23 | return self.causal_conv(minibatch)[:, :, :-self.padding] 24 | 25 | 26 | class DenseBlock(nn.Module): 27 | 28 | def __init__(self, in_channels, filters, dilation=2): 29 | super(DenseBlock, self).__init__() 30 | self.causal_conv1 = CausalConv1d( 31 | in_channels, 32 | filters, 33 | dilation=dilation) 34 | self.causal_conv2 = CausalConv1d( 35 | in_channels, 36 | filters, 37 | dilation=dilation) 38 | 39 | def forward(self, minibatch): 40 | tanh = F.tanh(self.causal_conv1(minibatch)) 41 | sig = F.sigmoid(self.causal_conv2(minibatch)) 42 | out = torch.cat([minibatch, tanh*sig], dim=1) 43 | return out 44 | 45 | class TCBlock(nn.Module): 46 | 47 | def __init__(self, in_channels, filters, seq_len): 48 | super(TCBlock, self).__init__() 49 | layer_count = np.ceil(np.log2(seq_len)).astype(np.int32) 50 | blocks = [] 51 | channel_count = in_channels 52 | for layer in range(layer_count): 53 | block = DenseBlock(channel_count, filters, dilation=2**layer) 54 | blocks.append(block) 55 | channel_count += filters 56 | self.tcblock = nn.Sequential(*blocks) 57 | self._dim = channel_count 58 | 59 | def forward(self, minibatch): 60 | return self.tcblock(minibatch) 61 | 62 | @property 63 | def dim(self): 64 | return self._dim 65 | 66 | class AttentionBlock(nn.Module): 67 | def __init__(self, dims, k_size, v_size, seq_len): 68 | 69 | super(AttentionBlock, self).__init__() 70 | self.key_layer = nn.Linear(dims, k_size) 71 | self.query_layer = nn.Linear(dims, k_size) 72 | self.value_layer = nn.Linear(dims, v_size) 73 | self.sqrt_k = np.sqrt(k_size) 74 | mask = np.tril(np.ones((seq_len, seq_len))).astype(np.float32) 75 | self.mask 
= nn.Parameter(torch.from_numpy(mask), requires_grad=False) 76 | self.minus = - 100. 77 | self._dim = dims + v_size 78 | 79 | def forward(self, minibatch, current_seq_len): 80 | keys = self.key_layer(minibatch) 81 | #queries = self.query_layer(minibatch) 82 | queries = keys 83 | values = self.value_layer(minibatch) 84 | current_mask = self.mask[:current_seq_len, :current_seq_len] 85 | logits = current_mask * torch.div(torch.bmm(queries, keys.transpose(2,1)), self.sqrt_k) + self.minus * (1. - current_mask) 86 | probs = F.softmax(logits, 2) 87 | read = torch.bmm(probs, values) 88 | return torch.cat([minibatch, read], dim=2) 89 | 90 | @property 91 | def dim(self): 92 | return self._dim 93 | 94 | class SNAIL(fewshot_re_kit.framework.FewShotREModel): 95 | 96 | def __init__(self, sentence_encoder, N, K, hidden_size=230): 97 | ''' 98 | N: num of classes 99 | K: num of instances for each class in the support set 100 | ''' 101 | fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder) 102 | self.hidden_size = hidden_size 103 | self.drop = nn.Dropout() 104 | self.seq_len = N * K + 1 105 | self.att0 = AttentionBlock(hidden_size + N, 64, 32, self.seq_len) 106 | self.tc1 = TCBlock(self.att0.dim, 128, self.seq_len) 107 | self.att1 = AttentionBlock(self.tc1.dim, 256, 128, self.seq_len) 108 | self.tc2 = TCBlock(self.att1.dim, 128, self.seq_len) 109 | self.att2 = AttentionBlock(self.tc2.dim, 512, 256, self.seq_len) 110 | self.disc = nn.Linear(self.att2.dim, N, bias=False) 111 | self.bn1 = nn.BatchNorm1d(self.tc1.dim) 112 | self.bn2 = nn.BatchNorm1d(self.tc2.dim) 113 | 114 | def forward(self, support, query, N, K, NQ): 115 | support = self.sentence_encoder(support) # (B * N * K, D), where D is the hidden size 116 | query = self.sentence_encoder(query) # (B * N * Q, D) 117 | # support = self.drop(support) 118 | # query = self.drop(query) 119 | support = support.view(-1, N, K, self.hidden_size) # (B, N, K, D) 120 | query = query.view(-1, NQ, self.hidden_size) # (B, N * 
Q, D) 121 | B = support.size(0) # Batch size 122 | 123 | support = support.unsqueeze(1).expand(-1, NQ, -1, -1, -1).contiguous().view(-1, N * K, self.hidden_size) # (B * NQ, N * K, D) 124 | query = query.view(-1, 1, self.hidden_size) # (B * NQ, 1, D) 125 | minibatch = torch.cat([support, query], 1) 126 | labels = torch.zeros((B * NQ, N * K + 1, N)).float().cuda() 127 | minibatch = torch.cat((minibatch, labels), 2) 128 | for i in range(N): 129 | for j in range(K): 130 | minibatch[:, i * K + j, i] = 1 131 | 132 | x = self.att0(minibatch, self.seq_len).transpose(1, 2) 133 | #x = self.bn1(x).transpose(1, 2) 134 | x = self.bn1(self.tc1(x)).transpose(1, 2) 135 | #x = self.tc1(x).transpose(1, 2) 136 | x = self.att1(x, self.seq_len).transpose(1, 2) 137 | x = self.bn2(self.tc2(x)).transpose(1, 2) 138 | #x = self.tc2(x).transpose(1, 2) 139 | x = self.att2(x, self.seq_len) 140 | x = x[:, -1, :] 141 | logits = self.disc(x) 142 | _, pred = torch.max(logits, -1) 143 | return logits, pred 144 | 145 | -------------------------------------------------------------------------------- /paper/fewrel1.0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/FewRel/278a2315d2138810a379cd8d5718914dc56e2582/paper/fewrel1.0.pdf -------------------------------------------------------------------------------- /paper/fewrel1_appendix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/FewRel/278a2315d2138810a379cd8d5718914dc56e2582/paper/fewrel1_appendix.pdf -------------------------------------------------------------------------------- /paper/fewrel2.0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/FewRel/278a2315d2138810a379cd8d5718914dc56e2582/paper/fewrel2.0.pdf -------------------------------------------------------------------------------- /readme.md: 
-------------------------------------------------------------------------------- 1 | # FewRel Dataset, Toolkits and Baseline Models 2 | 3 | Our benchmark website: [https://thunlp.github.io/fewrel.html](https://thunlp.github.io/fewrel.html) 4 | 5 | FewRel is a large-scale few-shot relation extraction dataset, which contains more than one hundred relations and tens of thousands of annotated instances across different domains. Our dataset is presented in our EMNLP 2018 paper [FewRel: A Large-Scale Few-Shot Relation Classification Dataset with State-of-the-Art Evaluation](https://www.aclweb.org/anthology/D18-1514.pdf) and a follow-up version is presented in our EMNLP 2019 paper [FewRel 2.0: Towards More Challenging Few-Shot Relation Classification](https://www.aclweb.org/anthology/D19-1649.pdf). 6 | 7 | Based on our dataset and designed few-shot settings, we have two different benchmarks: 8 | 9 | * FewRel 1.0: This is the first one to incorporate few-shot learning with relation extraction, where your model needs to handle both the few-shot challenge and extracting entity relations from plain text. This benchmark provides a training dataset with 64 relations and a validation set with 16 relations. Once you submit your code to our [benchmark website](https://thunlp.github.io/1/fewrel1.html), it will be evaluated on a hidden test set with 20 relations. Each relation has 100 human-annotated instances. 10 | 11 | * FewRel 2.0: We found that there are two long-neglected aspects in previous few-shot research: (1) How well models can transfer across different domains. (2) Whether few-shot models can detect instances belonging to none of the given few-shot classes. To dig deeper into these two aspects, we propose the 2.0 version of our dataset, with newly-added **domain adaptation (DA)** and **none-of-the-above (NOTA) detection** challenges.
Find out more in our [paper](https://www.aclweb.org/anthology/D19-1649.pdf) and on the evaluation websites [FewRel 2.0 domain adaptation](https://thunlp.github.io/2/fewrel2_da.html) / [FewRel 2.0 none-of-the-above detection](https://thunlp.github.io/2/fewrel2_nota.html) 12 | 13 | ## Citing 14 | If you use our data, toolkits or baseline models, please kindly cite our papers: 15 | ``` 16 | @inproceedings{han-etal-2018-fewrel, 17 | title = "{F}ew{R}el: A Large-Scale Supervised Few-Shot Relation Classification Dataset with State-of-the-Art Evaluation", 18 | author = "Han, Xu and Zhu, Hao and Yu, Pengfei and Wang, Ziyun and Yao, Yuan and Liu, Zhiyuan and Sun, Maosong", 19 | booktitle = "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing", 20 | month = oct # "-" # nov, 21 | year = "2018", 22 | address = "Brussels, Belgium", 23 | publisher = "Association for Computational Linguistics", 24 | url = "https://www.aclweb.org/anthology/D18-1514", 25 | doi = "10.18653/v1/D18-1514", 26 | pages = "4803--4809" 27 | } 28 | 29 | @inproceedings{gao-etal-2019-fewrel, 30 | title = "{F}ew{R}el 2.0: Towards More Challenging Few-Shot Relation Classification", 31 | author = "Gao, Tianyu and Han, Xu and Zhu, Hao and Liu, Zhiyuan and Li, Peng and Sun, Maosong and Zhou, Jie", 32 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)", 33 | month = nov, 34 | year = "2019", 35 | address = "Hong Kong, China", 36 | publisher = "Association for Computational Linguistics", 37 | url = "https://www.aclweb.org/anthology/D19-1649", 38 | doi = "10.18653/v1/D19-1649", 39 | pages = "6251--6256" 40 | } 41 | ``` 42 | 43 | If you have questions about any part of the paper, submission, leaderboard, code, or data, please e-mail `gaotianyu1350@126.com`.
44 | 45 | ## Contributions 46 | 47 | For FewRel 1.0, Hao Zhu first proposed this problem and the way to build the dataset and the baseline system; Ziyuan Wang built and maintained the crowdsourcing website; Yuan Yao helped download the original data and conducted preprocessing; 48 | Xu Han, Hao Zhu, Pengfei Yu and Ziyun Wang implemented baselines and wrote the paper together; Zhiyuan Liu provided thoughtful advice and funds throughout the whole project. The order of the first four authors was determined by dice rolling. 49 | 50 | ## Dataset and Pretrain files 51 | 52 | The dataset is already included in the GitHub repo. However, due to their large size, the GloVe files (pre-trained word embeddings) and the BERT pretrain checkpoint are not included. Please use the script `download_pretrain.sh` to download these pretrain files. 53 | 54 | We also provide [pid2name.json](https://github.com/thunlp/FewRel/blob/master/data/pid2name.json) to show the Wikidata PID, name and description for each relation. 55 | 56 | **Note: For fair comparison, we did not release the test datasets of FewRel 1.0 or 2.0. We recommend that you first evaluate your models on the validation set, and then submit them to our evaluation websites (which you can find above).** 57 | 58 | ## Training a Model 59 | 60 | To run our baseline models, use the command 61 | 62 | ```bash 63 | python train_demo.py 64 | ``` 65 | 66 | This will start the training and evaluating process of Prototypical Networks in a 5-way 5-shot setting. You can also use different args to start different processes. Some of them are listed here: 67 | 68 | * `train / val / test`: Specify the training / validation / test set. For example, if you use `train_wiki` for `train`, the program will load `data/train_wiki.json` for training. You should always use `train_wiki` for training and `val_wiki` (FewRel 1.0 and FewRel 2.0 NOTA challenge) or `val_pubmed` (FewRel 2.0 DA challenge) for validation. 69 | * `trainN`: N in N-way K-shot.
`trainN` is the specific N used in the training process. 70 | * `N`: N in N-way K-shot. 71 | * `K`: K in N-way K-shot. 72 | * `Q`: Sample Q query instances for each relation. 73 | * `model`: Which model to use. The default one is `proto`, standing for Prototypical Networks. Note that if you use the **PAIR** model from our paper [FewRel 2.0](https://www.aclweb.org/anthology/D19-1649.pdf), you should also use `--encoder bert --pair`. 74 | * `encoder`: Which encoder to use. You can choose `cnn`, `bert` or `roberta`. 75 | * `na_rate`: NA rate for FewRel 2.0 none-of-the-above (NOTA) detection. Note that here `na_rate` specifies the ratio between the number of NOTA queries and the number of positive queries (NOTA queries = Q * `na_rate`). For example, `na_rate=0` means the normal setting, and `na_rate=1,2,5` correspond to NA rates of about 15%, 30% and 50% in 5-way settings. 76 | 77 | There are also many args for training (like `batch_size` and `lr`), and you can find more details in our code. 78 | 79 | ## Inference 80 | 81 | You can evaluate an existing checkpoint by 82 | 83 | ```bash 84 | python train_demo.py --only_test --load_ckpt {CHECKPOINT_PATH} {OTHER_ARGS} 85 | ``` 86 | 87 | Here we provide a BERT-PAIR [checkpoint](https://thunlp.oss-cn-qingdao.aliyuncs.com/fewrel/pair-bert-train_wiki-val_wiki-5-1.pth.tar) (trained on the FewRel 1.0 dataset, 5 way 1 shot). 88 | 89 | ## Reproduction 90 | 91 | **BERT-PAIR for FewRel 1.0** 92 | 93 | ```bash 94 | python train_demo.py \ 95 | --trainN 5 --N 5 --K 1 --Q 1 \ 96 | --model pair --encoder bert --pair --hidden_size 768 --val_step 1000 \ 97 | --batch_size 4 --fp16 \ 98 | ``` 99 | 100 | Note that `--fp16` requires Nvidia's [apex](https://github.com/NVIDIA/apex).
101 | 102 | | | 5 way 1 shot | 5 way 5 shot | 10 way 1 shot | 10 way 5 shot | 103 | | --------------- | ----------- | ------------- | ------------ | ------------- | 104 | | Val | 85.66 | 89.48 | 76.84 | 81.76 | 105 | | Test | 88.32 | 93.22 | 80.63 | 87.02 | 106 | 107 | **BERT-PAIR for Domain Adaptation (FewRel 2.0)** 108 | 109 | ```bash 110 | python train_demo.py \ 111 | --trainN 5 --N 5 --K 1 --Q 1 \ 112 | --model pair --encoder bert --pair --hidden_size 768 --val_step 1000 \ 113 | --batch_size 4 --fp16 --val val_pubmed --test val_pubmed \ 114 | ``` 115 | 116 | | | 5 way 1 shot | 5 way 5 shot | 10 way 1 shot | 10 way 5 shot | 117 | | --------------- | ----------- | ------------- | ------------ | ------------- | 118 | | Val | 70.70 | 80.59 | 59.52 | 70.30 | 119 | | Test | 67.41 | 78.57 | 54.89 | 66.85 | 120 | 121 | **BERT-PAIR for None-of-the-Above (FewRel 2.0)** 122 | 123 | ```bash 124 | python train_demo.py \ 125 | --trainN 5 --N 5 --K 1 --Q 1 \ 126 | --model pair --encoder bert --pair --hidden_size 768 --val_step 1000 \ 127 | --batch_size 4 --fp16 --na_rate 5 \ 128 | ``` 129 | 130 | | | 5 way 1 shot (0% NOTA) | 5 way 1 shot (50% NOTA) | 5 way 5 shot (0% NOTA) | 5 way 5 shot (50% NOTA) | 131 | | --------------- | ----------- | ------------- | ------------ | ------------- | 132 | | Val | 74.56 | 73.09 | 75.01 | 75.38 | 133 | | Test | 76.73 | 80.31 | 83.32 | 84.64 | 134 | 135 | **Proto-CNN + Adversarial Training for Domain Adaptation (FewRel 2.0)** 136 | 137 | ```bash 138 | python train_demo.py \ 139 | --val val_pubmed --adv pubmed_unsupervised --trainN 10 --N {} --K {} \ 140 | --model proto --encoder cnn --val_step 1000 \ 141 | ``` 142 | 143 | | | 5 way 1 shot | 5 way 5 shot | 10 way 1 shot | 10 way 5 shot | 144 | | --------------- | ----------- | ------------- | ------------ | ------------- | 145 | | Val | 48.73 | 64.38 | 34.82 | 50.39 | 146 | | Test | 42.21 | 58.71 | 28.91 | 44.35 | 147 | 
# --------------------------------------------------------------------------------
# /train_demo.py:
# --------------------------------------------------------------------------------
from fewshot_re_kit.data_loader import get_loader, get_loader_pair, get_loader_unsupervised
from fewshot_re_kit.framework import FewShotREFramework
from fewshot_re_kit.sentence_encoder import CNNSentenceEncoder, BERTSentenceEncoder, BERTPAIRSentenceEncoder, RobertaSentenceEncoder, RobertaPAIRSentenceEncoder
import models
from models.proto import Proto
from models.gnn import GNN
from models.snail import SNAIL
from models.metanet import MetaNet
from models.siamese import Siamese
from models.pair import Pair
from models.d import Discriminator
from models.mtb import Mtb
import sys
import torch
from torch import optim, nn
import numpy as np
import json
import argparse
import os


def _parse_args():
    """Define the command-line interface and return the parsed options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki',
            help='train file')
    parser.add_argument('--val', default='val_wiki',
            help='val file')
    parser.add_argument('--test', default='test_wiki',
            help='test file')
    parser.add_argument('--adv', default=None,
            help='adv file')
    parser.add_argument('--trainN', default=10, type=int,
            help='N in train')
    parser.add_argument('--N', default=5, type=int,
            help='N way')
    parser.add_argument('--K', default=5, type=int,
            help='K shot')
    parser.add_argument('--Q', default=5, type=int,
            help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int,
            help='batch size')
    parser.add_argument('--train_iter', default=30000, type=int,
            help='num of iters in training')
    parser.add_argument('--val_iter', default=1000, type=int,
            help='num of iters in validation')
    parser.add_argument('--test_iter', default=10000, type=int,
            help='num of iters in testing')
    parser.add_argument('--val_step', default=2000, type=int,
           help='val after training how many iters')
    parser.add_argument('--model', default='proto',
            help='model name')
    parser.add_argument('--encoder', default='cnn',
            help='encoder: cnn or bert or roberta')
    parser.add_argument('--max_length', default=128, type=int,
           help='max length')
    parser.add_argument('--lr', default=-1, type=float,
           help='learning rate')
    # NOTE(review): --weight_decay is parsed but not forwarded to framework.train
    # below; confirm whether the framework is meant to receive it.
    parser.add_argument('--weight_decay', default=1e-5, type=float,
           help='weight decay')
    parser.add_argument('--dropout', default=0.0, type=float,
           help='dropout rate')
    parser.add_argument('--na_rate', default=0, type=int,
           help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter', default=1, type=int,
           help='accumulate gradient every x iterations')
    parser.add_argument('--optim', default='sgd',
           help='sgd / adam / adamw')
    parser.add_argument('--hidden_size', default=230, type=int,
           help='hidden size')
    parser.add_argument('--load_ckpt', default=None,
           help='load ckpt')
    parser.add_argument('--save_ckpt', default=None,
           help='save ckpt')
    parser.add_argument('--fp16', action='store_true',
           help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true',
           help='only test')
    parser.add_argument('--ckpt_name', type=str, default='',
           help='checkpoint name.')

    # only for bert / roberta
    parser.add_argument('--pair', action='store_true',
           help='use pair model')
    parser.add_argument('--pretrain_ckpt', default=None,
           help='bert / roberta pre-trained checkpoint')
    parser.add_argument('--cat_entity_rep', action='store_true',
           help='concatenate entity representation as sentence rep')

    # only for prototypical networks
    parser.add_argument('--dot', action='store_true',
           help='use dot instead of L2 distance for proto')

    # only for mtb
    parser.add_argument('--no_dropout', action='store_true',
           help='do not use dropout after BERT (still has dropout in BERT).')

    # experiment
    parser.add_argument('--mask_entity', action='store_true',
           help='mask entity names')
    parser.add_argument('--use_sgd_for_bert', action='store_true',
           help='use SGD instead of AdamW for BERT.')

    return parser.parse_args()


def _build_sentence_encoder(opt):
    """Instantiate the sentence encoder selected by ``--encoder`` / ``--pair``.

    Raises:
        Exception: for ``cnn`` when the GloVe files cannot be loaded.
        NotImplementedError: for an unknown encoder name.
    """
    encoder_name = opt.encoder
    max_length = opt.max_length
    if encoder_name == 'cnn':
        # Catch only I/O and decoding errors, and chain them so the root cause
        # stays visible. The original bare ``except:`` also swallowed KeyboardInterrupt.
        try:
            glove_mat = np.load('./pretrain/glove/glove_mat.npy')
            with open('./pretrain/glove/glove_word2id.json', encoding='utf-8') as f:
                glove_word2id = json.load(f)
        except (OSError, ValueError) as err:
            raise Exception("Cannot find glove files. Run download_pretrain.sh to download glove files.") from err
        return CNNSentenceEncoder(
                glove_mat,
                glove_word2id,
                max_length)
    if encoder_name == 'bert':
        pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-uncased'
        if opt.pair:
            return BERTPAIRSentenceEncoder(
                    pretrain_ckpt,
                    max_length)
        return BERTSentenceEncoder(
                pretrain_ckpt,
                max_length,
                cat_entity_rep=opt.cat_entity_rep,
                mask_entity=opt.mask_entity)
    if encoder_name == 'roberta':
        pretrain_ckpt = opt.pretrain_ckpt or 'roberta-base'
        if opt.pair:
            return RobertaPAIRSentenceEncoder(
                    pretrain_ckpt,
                    max_length)
        return RobertaSentenceEncoder(
                pretrain_ckpt,
                max_length,
                cat_entity_rep=opt.cat_entity_rep)
    raise NotImplementedError


def _build_model(opt, sentence_encoder):
    """Instantiate the few-shot model selected by ``--model``.

    NOTE(review): GNN and SNAIL are sized with the evaluation ``--N`` (and SNAIL
    also with ``--K``), while training episodes use ``--trainN``. These models
    presumably need ``--trainN`` equal to ``--N`` — confirm.
    """
    model_name = opt.model
    if model_name == 'proto':
        return Proto(sentence_encoder, dot=opt.dot)
    if model_name == 'gnn':
        return GNN(sentence_encoder, opt.N, hidden_size=opt.hidden_size)
    if model_name == 'snail':
        return SNAIL(sentence_encoder, opt.N, opt.K, hidden_size=opt.hidden_size)
    if model_name == 'metanet':
        return MetaNet(opt.N, opt.K, sentence_encoder.embedding, opt.max_length)
    if model_name == 'siamese':
        return Siamese(sentence_encoder, hidden_size=opt.hidden_size, dropout=opt.dropout)
    if model_name == 'pair':
        return Pair(sentence_encoder, hidden_size=opt.hidden_size)
    if model_name == 'mtb':
        return Mtb(sentence_encoder, use_dropout=not opt.no_dropout)
    raise NotImplementedError


def main():
    """Parse CLI options, build loaders / model, optionally train, then evaluate."""
    opt = _parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    sentence_encoder = _build_sentence_encoder(opt)

    # Training episodes use trainN classes; validation / test episodes use N.
    if opt.pair:
        train_data_loader = get_loader_pair(opt.train, sentence_encoder,
                N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
        val_data_loader = get_loader_pair(opt.val, sentence_encoder,
                N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
        test_data_loader = get_loader_pair(opt.test, sentence_encoder,
                N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
    else:
        train_data_loader = get_loader(opt.train, sentence_encoder,
                N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        val_data_loader = get_loader(opt.val, sentence_encoder,
                N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        test_data_loader = get_loader(opt.test, sentence_encoder,
                N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        if opt.adv:
            adv_data_loader = get_loader_unsupervised(opt.adv, sentence_encoder,
                N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'adamw':
        from transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError
    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader, adv_data_loader, adv=opt.adv, d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)

    prefix = '-'.join([model_name, encoder_name, opt.train, opt.val, str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)
    if opt.dot:
        prefix += '-dot'
    if opt.cat_entity_rep:
        prefix += '-catentity'
    if len(opt.ckpt_name) > 0:
        prefix += '-' + opt.ckpt_name

    model = _build_model(opt, sentence_encoder)

    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs('checkpoint', exist_ok=True)
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        bert_optim = encoder_name in ['bert', 'roberta']

        # -1 is the "not given" sentinel for --lr: pick an encoder-specific default.
        if opt.lr == -1:
            opt.lr = 2e-5 if bert_optim else 1e-1

        opt.train_iter = opt.train_iter * opt.grad_iter
        framework.train(model, prefix, batch_size, trainN, N, K, Q,
                pytorch_optim=pytorch_optim, load_ckpt=opt.load_ckpt, save_ckpt=ckpt,
                na_rate=opt.na_rate, val_step=opt.val_step, fp16=opt.fp16, pair=opt.pair,
                train_iter=opt.train_iter, val_iter=opt.val_iter, bert_optim=bert_optim,
                learning_rate=opt.lr, use_sgd_for_bert=opt.use_sgd_for_bert, grad_iter=opt.grad_iter)
    else:
        ckpt = opt.load_ckpt
        if ckpt is None:
            print("Warning: --load_ckpt is not specified. Will load HuggingFace pre-trained checkpoint.")
            ckpt = 'none'

    # Evaluation runs after training as well, using the checkpoint path chosen above.
    acc = framework.eval(model, batch_size, N, K, Q, opt.test_iter, na_rate=opt.na_rate, ckpt=ckpt, pair=opt.pair)
    print("RESULT: %.2f" % (acc * 100))


if __name__ == "__main__":
    main()
# --------------------------------------------------------------------------------