├── common ├── __init__.py ├── log.py ├── center_loss.py ├── utils.py ├── clustering.py ├── sinkhorn_knopp.py ├── convert_pred_for_eval.py ├── predictions.py ├── b3.py ├── metrics.py └── layers.py ├── figures └── framework.jpg ├── .gitignore ├── configs ├── rel │ ├── fewrel │ │ ├── rsn.yaml │ │ ├── rocore.yaml │ │ ├── pretrain.yaml │ │ └── train.yaml │ └── tacred │ │ ├── rsn.yaml │ │ ├── rocore.yaml │ │ ├── pretrain.yaml │ │ └── train.yaml └── event │ └── ace │ ├── pretrain.yaml │ ├── pretrain_e2e.yaml │ ├── spherical.yaml │ ├── ssvqvae.yaml │ ├── train.yaml │ └── train_e2e.yaml ├── LICENSE ├── data_sample ├── tacred │ └── relation_description.csv └── fewrel │ └── relation_description.csv ├── README.md ├── baselines ├── RSN.py ├── RoCORE_layers.py ├── vae.py ├── vqvae_model.py ├── RoCORE_model.py ├── etypeclus_model.py ├── RSN_model.py └── latent_space_clustering.py └── src ├── multiview_layers.py └── main.py /common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raspberryice/type-discovery-abs/HEAD/figures/framework.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | scripts/ 3 | checkpoints/ 4 | wandb/ 5 | *.ipynb 6 | __pycache__/ 7 | *.pkl 8 | *.pdf 9 | cache/* 10 | ETypeClus/* 11 | eval/**/*.json -------------------------------------------------------------------------------- /configs/rel/fewrel/rsn.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "rsn" 3 | task: "rel" 4 | # data related 5 | known_types: 64 6 | unknown_types: 16 7 | dataset_name: "fewrel" 8 | dataset_dir: "data/fewrel" 9 | test_ratio: 
0.15 10 | # model 11 | feature: "all" 12 | hidden_dim: 512 13 | perturb_scale: 0.02 14 | vat_loss_weight: 1.0 15 | p_cond: 0.03 16 | # runtime 17 | num_train_epochs: 20 18 | num_pretrain_epochs: 0 19 | train_batch_size: 16 20 | accumulate_grad_batches: 2 21 | num_workers: 2 22 | eval_batch_size: 32 23 | -------------------------------------------------------------------------------- /configs/rel/tacred/rsn.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "rsn" 3 | task: "rel" 4 | # data related 5 | known_types: 31 6 | unknown_types: 10 7 | dataset_name: "tacred" 8 | dataset_dir: "data/tacred" 9 | test_ratio: 0.15 10 | # model 11 | feature: "all" 12 | hidden_dim: 512 13 | perturb_scale: 0.02 14 | vat_loss_weight: 1.0 15 | p_cond: 0.03 16 | # runtime 17 | num_train_epochs: 20 18 | num_pretrain_epochs: 0 19 | train_batch_size: 16 20 | accumulate_grad_batches: 2 21 | num_workers: 2 22 | eval_batch_size: 32 23 | -------------------------------------------------------------------------------- /configs/rel/fewrel/rocore.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "rocore" 3 | task: "rel" 4 | # data related 5 | known_types: 64 6 | unknown_types: 16 7 | dataset_name: "fewrel" 8 | dataset_dir: "data/fewrel" 9 | test_ratio: 0.15 10 | # model 11 | feature: "all" 12 | hidden_dim: 512 13 | kmeans_dim: 256 14 | # runtime 15 | num_train_epochs: 100 16 | num_pretrain_epochs: 10 17 | train_batch_size: 100 18 | num_workers: 4 19 | eval_batch_size: 100 20 | learning_rate: 1e-4 21 | center_loss: 0.005 22 | sigmoid: 2.0 23 | layer: 8 24 | 25 | -------------------------------------------------------------------------------- /configs/rel/tacred/rocore.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "rocore" 3 | task: "rel" 4 | # data related 5 | known_types: 31 6 | unknown_types: 10 7 | dataset_name: "tacred" 8 | 
dataset_dir: "data/tacred" 9 | test_ratio: 0.15 10 | # model 11 | feature: "all" 12 | hidden_dim: 512 13 | kmeans_dim: 256 14 | # runtime 15 | num_train_epochs: 100 16 | num_pretrain_epochs: 10 17 | train_batch_size: 100 18 | num_workers: 4 19 | eval_batch_size: 100 20 | learning_rate: 1e-4 21 | center_loss: 0.001 22 | sigmoid: 2.0 23 | layer: 8 24 | 25 | -------------------------------------------------------------------------------- /configs/event/ace/pretrain.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "tabs" 3 | task: "event" 4 | # data 5 | known_types: 10 6 | unknown_types: 23 7 | dataset_name: "ace" 8 | dataset_dir: "data/ace" 9 | test_ratio: 0.3 10 | # model 11 | model_type: "bert" 12 | model_name_or_path: "bert-base-uncased" 13 | cache_dir: "./cache" 14 | supervised_pretrain: true 15 | label_smoothing_alpha: 0.1 16 | token_pooling: "first" 17 | hidden_dim: 256 18 | classifier_layers: 2 19 | # runtime 20 | num_train_epochs: 3 21 | train_batch_size: 16 22 | accumulate_grad_batches: 2 23 | num_workers: 2 24 | eval_batch_size: 3 -------------------------------------------------------------------------------- /configs/event/ace/pretrain_e2e.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "tabs" 3 | task: "event" 4 | # data 5 | known_types: 10 6 | dataset_name: "ace" 7 | dataset_dir: "data/ace" 8 | test_ratio: 0.3 9 | e2e: true 10 | # model 11 | model_type: "bert" 12 | model_name_or_path: "bert-base-uncased" 13 | cache_dir: "./cache" 14 | supervised_pretrain: true 15 | label_smoothing_alpha: 0.1 16 | token_pooling: "first" 17 | hidden_dim: 256 18 | classifier_layers: 2 19 | # runtime 20 | num_train_epochs: 3 21 | train_batch_size: 16 22 | accumulate_grad_batches: 2 23 | num_workers: 2 24 | eval_batch_size: 3 -------------------------------------------------------------------------------- /configs/event/ace/spherical.yaml: 
-------------------------------------------------------------------------------- 1 | --- 2 | model: "etypeclus" 3 | task: "event" 4 | # data 5 | dataset_name: "ace" 6 | dataset_dir: "data/ace" 7 | test_ratio: 0.3 8 | known_types: 10 9 | unknown_types: 23 10 | # model 11 | model_type: "bert" 12 | model_name_or_path: "bert-base-uncased" 13 | cache_dir: "./cache" 14 | temperature: 0.1 15 | distribution: "softmax" 16 | gamma: 5 17 | # runtime 18 | num_train_epochs: 50 19 | num_pretrain_epochs: 10 20 | train_batch_size: 64 21 | accumulate_grad_batches: 1 22 | num_workers: 2 23 | eval_batch_size: 64 24 | learning_rate: 1e-4 25 | 26 | -------------------------------------------------------------------------------- /configs/event/ace/ssvqvae.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "vqvae" 3 | task: "event" 4 | # data related 5 | known_types: 10 6 | unknown_types: 23 7 | dataset_name: "ace" 8 | dataset_dir: "data/ace" 9 | test_ratio: 0.3 10 | # model related 11 | model_type: "bert" 12 | model_name_or_path: "bert-base-uncased" 13 | cache_dir: "./cache" 14 | beta: 1.0 15 | gamma: 0.0 16 | hybrid: true # hybrid vqvae + vae 17 | recon_loss: 0.0 18 | num_train_epochs: 15 19 | num_pretrain_epochs: 0 20 | train_batch_size: 32 21 | accumulate_grad_batches: 1 22 | num_workers: 2 23 | eval_batch_size: 32 24 | learning_rate: 3e-5 25 | 26 | -------------------------------------------------------------------------------- /configs/rel/fewrel/pretrain.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "tabs" 3 | task: "rel" 4 | # data related 5 | known_types: 64 6 | unknown_types: 16 7 | dataset_name: "fewrel" 8 | dataset_dir: "data/fewrel" 9 | test_ratio: 0.15 10 | # model related 11 | model_type: "bert" 12 | model_name_or_path: "bert-base-uncased" 13 | cache_dir: "./cache" 14 | supervised_pretrain: true 15 | label_smoothing_alpha: 0.1 16 | token_pooling: "first" 17 | 
hidden_dim: 256 18 | classifier_layers: 2 19 | # runtime 20 | num_train_epochs: 3 21 | train_batch_size: 16 22 | accumulate_grad_batches: 2 23 | num_workers: 2 24 | eval_batch_size: 32 25 | 26 | -------------------------------------------------------------------------------- /configs/rel/tacred/pretrain.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "tabs" 3 | task: "rel" 4 | # data related 5 | known_types: 31 6 | unknown_types: 10 7 | dataset_name: "tacred" 8 | dataset_dir: "data/tacred" 9 | test_ratio: 0.15 10 | # model related 11 | model_type: "bert" 12 | model_name_or_path: "bert-base-uncased" 13 | cache_dir: "./cache" 14 | supervised_pretrain: true 15 | label_smoothing_alpha: 0.1 16 | token_pooling: "first" 17 | hidden_dim: 256 18 | classifier_layers: 2 19 | # runtime 20 | num_train_epochs: 3 21 | train_batch_size: 16 22 | accumulate_grad_batches: 2 23 | num_workers: 2 24 | eval_batch_size: 16 25 | 26 | -------------------------------------------------------------------------------- /configs/event/ace/train.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "tabs" 3 | task: "event" 4 | # data related 5 | known_types: 10 6 | unknown_types: 23 7 | dataset_name: "ace" 8 | dataset_dir: "data/ace" 9 | test_ratio: 0.3 10 | # model related 11 | model_type: "bert" 12 | model_name_or_path: "bert-base-uncased" 13 | cache_dir: "./cache" 14 | label_smoothing_alpha: 0.1 15 | label_smoothing_ramp: 3 16 | token_pooling: "first" 17 | hidden_dim: 256 18 | kmeans_dim: 256 19 | classifier_layers: 2 20 | pairwise_loss: true 21 | clustering: "kmeans" 22 | kmeans_outlier_alpha: 0.0 23 | consistency_loss: 0.2 24 | contrastive_loss: 0.0 25 | recon_loss: 0.0 26 | # runtime 27 | num_train_epochs: 30 28 | num_pretrain_epochs: 0 29 | train_batch_size: 16 30 | accumulate_grad_batches: 2 31 | num_workers: 2 32 | eval_batch_size: 32 33 | check_pl: true 34 | 35 | 
-------------------------------------------------------------------------------- /configs/event/ace/train_e2e.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "tabs" 3 | task: "event" 4 | # data related 5 | known_types: 10 6 | dataset_name: "ace" 7 | dataset_dir: "data/ace" 8 | test_ratio: 0.3 9 | e2e: true 10 | # model related 11 | model_type: "bert" 12 | model_name_or_path: "bert-base-uncased" 13 | cache_dir: "./cache" 14 | label_smoothing_alpha: 0.1 15 | label_smoothing_ramp: 3 16 | token_pooling: "first" 17 | hidden_dim: 256 18 | kmeans_dim: 256 19 | classifier_layers: 2 20 | pairwise_loss: true 21 | clustering: "kmeans" 22 | kmeans_outlier_alpha: 0.0 23 | consistency_loss: 0.2 24 | contrastive_loss: 0.0 25 | recon_loss: 0.0 26 | # runtime 27 | num_train_epochs: 30 28 | num_pretrain_epochs: 0 29 | train_batch_size: 16 30 | accumulate_grad_batches: 2 31 | num_workers: 2 32 | eval_batch_size: 32 33 | check_pl: true 34 | 35 | -------------------------------------------------------------------------------- /configs/rel/tacred/train.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "tabs" 3 | task: "rel" 4 | # data related 5 | known_types: 31 6 | unknown_types: 10 7 | dataset_name: "tacred" 8 | dataset_dir: "data/tacred" 9 | test_ratio: 0.15 10 | # model related 11 | model_type: "bert" 12 | model_name_or_path: "bert-base-uncased" 13 | cache_dir: "./cache" 14 | label_smoothing_alpha: 0.1 15 | label_smoothing_ramp: 3 16 | token_pooling: "first" 17 | hidden_dim: 256 18 | kmeans_dim: 256 19 | classifier_layers: 2 20 | pairwise_loss: true 21 | clustering: "kmeans" 22 | kmeans_outlier_alpha: 0.0 23 | consistency_loss: 0.2 24 | contrastive_loss: 0.0 25 | recon_loss: 0.0 26 | # runtime 27 | num_train_epochs: 20 28 | num_pretrain_epochs: 0 29 | train_batch_size: 16 30 | accumulate_grad_batches: 2 31 | num_workers: 2 32 | eval_batch_size: 32 33 | check_pl: true 34 
| -------------------------------------------------------------------------------- /configs/rel/fewrel/train.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "tabs" 3 | task: "rel" 4 | # data related 5 | known_types: 64 6 | unknown_types: 16 7 | dataset_name: "fewrel" 8 | dataset_dir: "data/fewrel" 9 | test_ratio: 0.15 10 | # model related 11 | model_type: "bert" 12 | model_name_or_path: "bert-base-uncased" 13 | cache_dir: "./cache" 14 | label_smoothing_alpha: 0.1 15 | label_smoothing_ramp: 3 16 | token_pooling: "first" 17 | hidden_dim: 256 18 | kmeans_dim: 256 19 | classifier_layers: 2 20 | pairwise_loss: true 21 | clustering: "kmeans" 22 | kmeans_outlier_alpha: 0.0 23 | consistency_loss: 0.2 24 | contrastive_loss: 0.0 25 | recon_loss: 0.0 26 | # runtime 27 | num_train_epochs: 20 28 | num_pretrain_epochs: 0 29 | train_batch_size: 32 30 | accumulate_grad_batches: 1 31 | num_workers: 2 32 | eval_batch_size: 32 33 | check_pl: true 34 | 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Zoey Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /common/log.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | """ 14 | This file contains basic logging logic. 
15 | """ 16 | import logging 17 | 18 | names = set() 19 | 20 | 21 | def __setup_custom_logger(name: str) -> logging.Logger: 22 | root_logger = logging.getLogger() 23 | root_logger.handlers.clear() 24 | 25 | formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(module)s - %(message)s') 26 | 27 | names.add(name) 28 | 29 | handler = logging.StreamHandler() 30 | handler.setFormatter(formatter) 31 | 32 | logger = logging.getLogger(name) 33 | logger.setLevel(logging.INFO) 34 | logger.addHandler(handler) 35 | return logger 36 | 37 | 38 | def get_logger(name: str) -> logging.Logger: 39 | if name in names: 40 | return logging.getLogger(name) 41 | else: 42 | return __setup_custom_logger(name) 43 | -------------------------------------------------------------------------------- /data_sample/tacred/relation_description.csv: -------------------------------------------------------------------------------- 1 | org:founded_by 268 organization person 2 | per:employee_of 2163 person organization 3 | org:alternate_names 1359 organization organization 4 | per:cities_of_residence 742 person location/city 5 | per:children 347 person person 6 | per:title 3862 person title 7 | per:siblings 250 person person 8 | per:religion 153 person religion 9 | per:age 833 person number 10 | org:website 223 organization website 11 | per:stateorprovinces_of_residence 484 person location/state_or_province 12 | org:member_of 171 organization organization 13 | org:top_members/employees 2770 organization person 14 | per:countries_of_residence 819 person location/country 15 | org:city_of_headquarters 573 organization location/city 16 | org:members 286 organization country 17 | org:country_of_headquarters 753 organization country 18 | per:spouse 483 person person 19 | org:stateorprovince_of_headquarters 350 organization location/state_or_province 20 | org:number_of_employees/members 121 organization number 21 | org:parents 444 person person 22 | org:subsidiaries 453 organization organization 23 | 
per:origin 667 person location/country 24 | org:political/religious_affiliation 125 organization religion 25 | per:other_family 319 person person 26 | per:stateorprovince_of_birth 72 person location/state_or_province 27 | org:dissolved 33 organization date 28 | per:date_of_death 394 person date 29 | org:shareholders 144 organization person 30 | per:alternate_names 153 person person 31 | per:parents 296 person person 32 | per:schools_attended 229 person organization/school 33 | per:cause_of_death 337 person cause 34 | per:city_of_death 227 person location/city 35 | per:stateorprovince_of_death 104 person location/state_or_province 36 | org:founded 166 organization date 37 | per:country_of_birth 53 person location/country 38 | per:date_of_birth 103 person date 39 | per:city_of_birth 103 person location/city 40 | per:charges 280 person crime 41 | per:country_of_death 61 person location/country 42 | -------------------------------------------------------------------------------- /common/center_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class CenterLoss(nn.Module): 5 | ''' 6 | L2 loss for pushing representations close to their centroid. 7 | ''' 8 | def __init__(self, dim_hidden: int , num_classes: int, lambda_c: float = 1.0, alpha: float=1.0, weight_by_prob: bool =False): 9 | super().__init__() 10 | self.dim_hidden = dim_hidden 11 | self.num_classes = num_classes 12 | self.lambda_c = lambda_c 13 | self.alpha= alpha 14 | self.weight_by_prob = weight_by_prob 15 | 16 | self.centers = self.register_buffer('centers', torch.zeros((num_classes, dim_hidden), dtype=torch.float), persistent=False) 17 | 18 | def _compute_prob(self, distance_centers: torch.FloatTensor, y: torch.LongTensor): 19 | ''' 20 | compute the probability according to student-t distribution 21 | Bug in original RoCORE code, added (-1) to power operation. 
22 | ''' 23 | 24 | q = 1.0/(1.0+distance_centers/self.alpha) # (batch_size, num_class) 25 | q = q**((-1) * (self.alpha+1.0)/2.0) 26 | q = q / torch.sum(q, dim=1, keepdim=True) 27 | prob = q.gather(1, y.unsqueeze(1)).squeeze() # (batch_size) 28 | return prob 29 | 30 | 31 | def forward(self, y: torch.LongTensor, hidden: torch.FloatTensor) -> torch.FloatTensor: 32 | ''' 33 | :param y: (batch_size, ) 34 | :param hidden: (batch_size, dim_hidden) 35 | ''' 36 | batch_size = hidden.size(0) 37 | expanded_hidden = hidden.expand(self.num_classes, -1, -1).transpose(1, 0) # (num_class, batch_size, hid_dim) => (batch_size, num_class, hid_dim) 38 | expanded_centers = self.centers.expand(batch_size, -1, -1) # (batch_size, num_class, hid_dim) 39 | distance_centers = (expanded_hidden - expanded_centers).pow(2).sum(dim=-1) # (batch_size, num_class, hid_dim) => (batch_size, num_class) 40 | intra_distances = distance_centers.gather(1, y.unsqueeze(1)).squeeze() # (batch_size, num_class) => (batch_size, 1) => (batch_size) 41 | 42 | if self.weight_by_prob: 43 | prob = self._compute_prob(distance_centers, y) 44 | loss = 0.5 * self.lambda_c * torch.mean(intra_distances*prob) # (batch_size) => scalar 45 | 46 | else: 47 | loss = 0.5 * self.lambda_c * torch.mean(intra_distances) # (batch_size) => scalar 48 | return loss -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import re 4 | from typing import List 5 | from copy import deepcopy 6 | 7 | 8 | import torch 9 | import numpy as np 10 | from scipy.optimize import linear_sum_assignment 11 | 12 | def clean_text(text: List[str]) -> List[str]: 13 | ret = [] 14 | for word in text: 15 | normalized_word = re.sub(u"([^\u0020-\u007f])", "", word) 16 | if normalized_word == '' or normalized_word == ' ' or normalized_word == ' ': 17 | normalized_word = '[UNK]' 18 | ret.append(normalized_word) 
19 | return ret 20 | 21 | 22 | 23 | def onedim_gather(src: torch.Tensor, dim: int, index: torch.LongTensor) -> torch.Tensor: 24 | ''' 25 | src: (batch, M, L) 26 | index: (batch, M) or (batch, L) or (batch, 1) 27 | 28 | A version of the torch.gather function where the index is only along 1 dim. 29 | ''' 30 | for i in range(len(src.shape)): 31 | if i!=0 and i!=dim: 32 | # index is missing this dimension 33 | index = index.unsqueeze(dim=i) 34 | target_index_size = deepcopy(list(src.shape)) 35 | target_index_size[dim] = 1 36 | index = index.expand(target_index_size) 37 | output = torch.gather(src, dim, index) 38 | return output 39 | 40 | 41 | def get_label_mapping(predicted_logits, labels): 42 | """ 43 | Compute an assignment from predicted to labels. 44 | :param predicted_logits (N, n_classes) 45 | :param labels (N) 46 | """ 47 | M = predicted_logits.size(1) 48 | predicted_numpy = torch.max(predicted_logits, dim=1)[1].detach().cpu().numpy() 49 | labels_numpy = labels.detach().cpu().numpy() 50 | 51 | w = np.zeros((M, M)) # cost matrix 52 | for i in range(labels_numpy.size): 53 | w[ predicted_numpy[i], labels_numpy[i]] += 1 54 | 55 | mapping = linear_sum_assignment(w) 56 | map_matrix = torch.zeros((M, M), dtype=torch.float, device=labels.device) 57 | for i, j in np.transpose(np.asarray(mapping)): 58 | map_matrix[i,j] = 1 59 | return mapping, map_matrix 60 | 61 | 62 | 63 | # From UNO/utils/eval.py 64 | def cluster_acc(y_true:np.array, y_pred: np.array, reassign: bool=False): 65 | """ 66 | Calculate clustering accuracy with assigment 67 | 68 | # Arguments 69 | y: true labels, numpy.array with shape `(n_samples,)` 70 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 71 | 72 | # Return 73 | accuracy, in [0,1] 74 | """ 75 | y_true = y_true.astype(np.int64) # N*K 76 | y_pred = y_pred.astype(np.int64) # N*K 77 | assert y_pred.size == y_true.size # same number of clusters 78 | 79 | D = max(y_pred.max(), y_true.max()) + 1 80 | w = np.zeros((D, D), dtype=np.int64) 
# cost matrix 81 | for i in range(y_pred.size): 82 | w[y_pred[i], y_true[i]] += 1 83 | 84 | 85 | if reassign: 86 | mapping = compute_best_mapping(w) 87 | 88 | return sum([w[i, j] for i, j in mapping]) * 1.0 / y_pred.size 89 | else: 90 | acc= sum([w[i,i] for i in range(D)]) * 1.0/y_pred.size 91 | return acc 92 | 93 | def compute_best_mapping(w): 94 | return np.transpose(np.asarray(linear_sum_assignment(w.max() - w))) 95 | 96 | 97 | -------------------------------------------------------------------------------- /common/clustering.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Union, Tuple, Optional 2 | from math import floor 3 | from copy import deepcopy 4 | 5 | import torch 6 | import numpy as np 7 | from sklearn.cluster import KMeans, SpectralClustering, AgglomerativeClustering, DBSCAN 8 | from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity 9 | 10 | 11 | 12 | def spectral_clustering(X: np.array, n_clusters: int)->np.array: 13 | ''' 14 | wrapper for spectral clustering 15 | ''' 16 | clf = SpectralClustering(n_clusters=n_clusters, random_state=0, affinity='rbf') 17 | # sim = cosine_similarity(X, X) 18 | label_pred = clf.fit_predict(X) 19 | return label_pred 20 | 21 | def dbscan(X: np.ndarray, n_clusters: int, eps: Optional[float] = None)-> np.ndarray: 22 | ''' 23 | wrapper for dbscan clustering algorithm. 24 | eps: the max distance for the points to be considered a neighbor. 
25 | 26 | ''' 27 | distances = euclidean_distances(X) # (n, n) 28 | eps2cluster = {} 29 | 30 | if eps == None: 31 | for eps in range(5, 20): 32 | clf = DBSCAN(eps=eps, min_samples=5, metric='euclidean', n_jobs=4) 33 | label_pred = clf.fit_predict(X) 34 | found_clusters = np.max(label_pred) 35 | eps2cluster[eps] = found_clusters 36 | 37 | min_diff = 1e6 38 | best_eps = None 39 | for eps in eps2cluster: 40 | if eps2cluster[eps] != -1: 41 | diff = abs(eps2cluster[eps] - n_clusters) 42 | if diff < min_diff or not best_eps: 43 | best_eps = eps 44 | min_diff = diff 45 | 46 | 47 | eps = best_eps 48 | 49 | clf = DBSCAN(eps=eps, min_samples=5, metric='euclidean', n_jobs=4) 50 | label_pred = clf.fit_predict(X) 51 | 52 | # searched_eps = set([eps, ]) 53 | # while found_clusters == 0 or found_clusters == -1: 54 | # if found_clusters == 0: 55 | # # eps is too large 56 | # eps -=1 57 | # else: 58 | # eps +=1 59 | # if eps in searched_eps: 60 | # break 61 | # clf = DBSCAN(eps=eps, min_samples=5, metric='euclidean', n_jobs=4) 62 | # label_pred = clf.fit_predict(X) 63 | # found_clusters = np.max(label_pred) 64 | # searched_eps.add(eps) 65 | 66 | 67 | 68 | idx2cluster = { i: cluster for i, cluster in enumerate(label_pred) if (cluster!=-1 or cluster>=n_clusters)} 69 | # assign -1 instances to their nearest neighbors 70 | neighbors = np.argsort(distances, axis=1) 71 | new_label_pred = deepcopy(label_pred) 72 | for i, cluster in enumerate(label_pred): 73 | if cluster == -1: 74 | for n in neighbors[i,1:]:# exclude self 75 | if n in idx2cluster: 76 | new_label_pred[i] = idx2cluster[n] 77 | break 78 | 79 | 80 | return new_label_pred, eps 81 | 82 | def agglomerative_clustering(X:np.array, n_clusters: int)-> np.array: 83 | clf = AgglomerativeClustering(n_clusters=n_clusters, affinity='cosine', linkage='average') 84 | label_pred = clf.fit_predict(X) 85 | return label_pred 86 | 87 | def agglomerative_ward(X: np.array, n_clusters:int)-> np.array: 88 | clf = 
AgglomerativeClustering(n_clusters=n_clusters, affinity='euclidean', linkage='ward') 89 | label_pred = clf.fit_predict(X) 90 | return label_pred 91 | 92 | -------------------------------------------------------------------------------- /common/sinkhorn_knopp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import numpy as np 4 | 5 | class SinkhornKnopp(torch.nn.Module): 6 | def __init__(self, num_iters=3, epsilon=0.05, queue_len:int=1024, classes_n: int=10, delta=0.0): 7 | super().__init__() 8 | self.num_iters = num_iters 9 | self.epsilon = epsilon 10 | self.delta = delta 11 | 12 | self.classes_n = classes_n 13 | self.queue_len = queue_len 14 | self.register_buffer(name='logit_queue', tensor=torch.zeros((queue_len, classes_n)), persistent=False) 15 | self.cur_len = 0 16 | 17 | def add_to_queue(self, logits: torch.FloatTensor)-> None: 18 | ''' 19 | :param logits: (N, K) 20 | ''' 21 | batch_size = logits.size(0) 22 | classes_n = logits.size(1) 23 | assert (classes_n == self.classes_n) 24 | 25 | new_queue = torch.concat([logits, self.logit_queue], dim=0) 26 | self.logit_queue = new_queue[:self.queue_len, :] 27 | 28 | self.cur_len += batch_size 29 | 30 | self.cur_len = min(self.cur_len, self.queue_len) 31 | 32 | return 33 | 34 | def queue_full(self)-> bool: 35 | 36 | return self.cur_len == self.queue_len 37 | 38 | 39 | 40 | @torch.no_grad() 41 | def forward(self, logits: torch.FloatTensor): 42 | ''' 43 | :param logits: (N, K) 44 | ''' 45 | batch_size = logits.size(0) 46 | all_logits = self.logit_queue 47 | 48 | initial_Q = torch.softmax(all_logits/self.epsilon, dim=1) 49 | # Q = torch.exp(logits / self.epsilon).t() # (K, N) 50 | Q = initial_Q.clone().t() 51 | N = Q.shape[1] 52 | K = Q.shape[0] # how many prototypes 53 | 54 | # make the matrix sums to 1 55 | sum_Q = torch.sum(Q) 56 | assert (torch.any(torch.isinf(sum_Q)) == False), "sum_Q is too large" 57 | Q /= sum_Q 58 | 59 | for it in range(self.num_iters): 
60 | # normalize each row: total weight per prototype must be 1/K 61 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 62 | sum_of_rows += self.delta # for numerical stability 63 | Q /= sum_of_rows 64 | Q /= K 65 | 66 | # normalize each column: total weight per sample must be 1/B 67 | sum_of_cols = torch.sum(Q, dim=0, keepdim=True) 68 | Q /= sum_of_cols 69 | Q /= N 70 | 71 | Q *= N # the colomns must sum to 1 so that Q is an assignment 72 | 73 | batch_assignments = Q.t()[:batch_size, :] 74 | return batch_assignments, sum_of_rows.squeeze(), sum_of_cols.squeeze() 75 | 76 | # https://github.com/yukimasano/self-label/blob/master/sinkhornknopp.py 77 | def optimize_L_sk(self,PS: np.array, L: np.array, lamb: int=25 ): 78 | ''' 79 | :param PS: (N, K) probability matrix, K is the number of clusters 80 | :param L: (N,) labels 81 | 82 | ''' 83 | K = PS.size(1) 84 | N = PS.size(0) 85 | 86 | r = np.ones((K, 1), dtype=self.dtype) / K 87 | c = np.ones((N, 1), dtype=self.dtype) / N 88 | inv_K = self.dtype(1./K) 89 | inv_N = self.dtype(1./N) 90 | 91 | PS = PS.T # now it is K x N 92 | PS **= lamb # K x N 93 | 94 | err = 1e6 95 | _counter = 0 96 | while err > 1e-1: 97 | r = inv_K / (PS @ c) # (KxN)@(N,1) = K x 1 98 | c_new = inv_N / (r.T @ PS).T # ((1,K)@(KxN)).t() = N x 1 99 | if _counter % 10 == 0: 100 | err = np.nansum(np.abs(c / c_new - 1)) 101 | c = c_new 102 | _counter += 1 103 | print("error: ", err, 'step ', _counter, flush=True) # " nonneg: ", sum(I), flush=True) 104 | # inplace calculations. 
105 | PS *= np.squeeze(c) 106 | PS = PS.T 107 | PS *= np.squeeze(r) 108 | PS = PS.T 109 | 110 | # produce hard labels 111 | argmaxes = np.nanargmax(self.PS, 0) # size N 112 | newL = torch.LongTensor(argmaxes) 113 | 114 | return newL 115 | -------------------------------------------------------------------------------- /common/convert_pred_for_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import random 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--checkpoint_dir',type=str) 9 | parser.add_argument('--output_dir', type=str) 10 | parser.add_argument('--eval_k', type=int, default=10, help='number of instances to display') 11 | parser.add_argument('--intrusion_k', type=int, default=4, help='the number of instances to serve as background.') 12 | parser.add_argument('--max_instances', type=int, default=400) 13 | parser.add_argument('--n_splits', type=int, default=2) 14 | # parser.add_argument('--min_clus', type=int, default=10) 15 | args = parser.parse_args() 16 | 17 | with open(os.path.join(args.checkpoint_dir, 'test_unknown_clusters.json'),'r') as f: 18 | clusters = json.load(f) 19 | 20 | cluster_eval = [] 21 | cluster2showninstances = {} 22 | 23 | print(f'{len(clusters)} clusters predicted') 24 | for cluster_id, instances in clusters.items(): 25 | instances_to_show = instances[:args.eval_k] + instances[-args.eval_k:] 26 | 27 | ins_as_text = [] 28 | for ins in instances_to_show: 29 | ins_as_text.append(f"{ins['trigger']}: {ins['sentence']}") 30 | 31 | cluster_dict = { 32 | "id": cluster_id, 33 | "text": "\n\n".join(ins_as_text), 34 | "label" : [] 35 | } 36 | cluster2showninstances[cluster_id] = instances_to_show # List 37 | cluster_eval.append(cluster_dict) 38 | 39 | with open(os.path.join(args.output_dir,'cluster_eval.json'),'w') as f: 40 | json.dump(cluster_eval, f, indent=2) 41 | 42 | 43 | instance_eval = [] 44 | 
ins2label = {} 45 | ins_id = 0 46 | for cluster_id in cluster2showninstances: 47 | for ins in cluster2showninstances[cluster_id]: 48 | text = [] 49 | text.append(f"{ins['trigger']}: {ins['sentence']}") 50 | text.append("Background Instances:") 51 | pool = [] 52 | if random.random() < 0.5: 53 | # sample from same cluster 54 | while len(pool) < args.intrusion_k: 55 | background_ins = random.sample(clusters[cluster_id], k=1)[0] 56 | if background_ins != ins: 57 | pool.append(background_ins) 58 | text.append(f"{background_ins['trigger']}: {background_ins['sentence']}") 59 | label = True 60 | else: 61 | while len(pool) < args.intrusion_k: 62 | sampled_clus = random.sample(list(cluster2showninstances.keys()),k=1)[0] 63 | if sampled_clus != cluster_id: 64 | background_ins = random.sample(clusters[sampled_clus],k=1)[0] 65 | text.append(f"{background_ins['trigger']}: {background_ins['sentence']}") 66 | pool.append(background_ins) 67 | 68 | label=False 69 | 70 | ins2label[ins_id] = label 71 | instance_eval.append({ 72 | 'id': ins_id, 73 | "text": "\n\n".join(text), 74 | "label": [] 75 | }) 76 | ins_id +=1 77 | 78 | if ins_id >= args.max_instances: 79 | break 80 | 81 | 82 | with open(os.path.join(args.output_dir, 'instance_gold.json'),'w') as f: 83 | json.dump(ins2label, f, indent=2) 84 | 85 | if args.n_splits> 1: 86 | print(f'splitting into {args.n_splits} for evaluation') 87 | samples_in_split = args.max_instances // args.n_splits 88 | for split_i in range(args.n_splits): 89 | if split_i < args.n_splits -1:# not last split 90 | split_instances = instance_eval[split_i * samples_in_split: (split_i+1)* samples_in_split] 91 | else: 92 | split_instances = instance_eval[split_i * samples_in_split: ] 93 | 94 | with open(os.path.join(args.output_dir, f'instance_eval_split{split_i}.json'), 'w') as f: 95 | json.dump(split_instances, f, indent=2) 96 | 97 | 98 | else: 99 | with open(os.path.join(args.output_dir, 'instance_eval.json'), 'w') as f: 100 | json.dump(instance_eval, f, 
indent=2) 101 | 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Open Relation and Event Type Discovery with Type Abstraction 2 | 3 | This is the official implementation of the paper published at EMNLP 2022. 4 | 5 | ## TL;DR 6 | 7 | ![Our multi-view co-training model framework](figures/framework.jpg) 8 | 9 | Type discovery is the task of identifying new relation/event types from unlabeled data. 10 | 11 | This task is more challenging than zero-shot learning since we don't even have access to type names. 12 | 13 | Two ideas are important in making type discovery work: (1) transfer learning from known types (2) abstraction of type names. 14 | 15 | Having access to some known types helps the model learn "what is a relation/event". This is an idea that has been pursued by previous work like RSN and RoCORE. 16 | 17 | Abstraction of type names works by trying to align the granularity of the type to some human concept. Even with a perfect metric space, only certain granularities make sense to humans. 18 | 19 | 20 | ## Requirements 21 | 22 | ```bash 23 | pytorch=1.10.2 24 | pytorch-lightning=1.5.10 25 | torchmetrics=0.7.2 26 | transformers=4.11.3 27 | sklearn 28 | numpy 29 | networkx 30 | python-louvain=0.16 31 | pyyaml 32 | wandb 33 | ``` 34 | 35 | ## Datasets 36 | 37 | - The TACRED dataset is available [here](https://nlp.stanford.edu/projects/tacred/). We used the original version of the TACRED dataset because we needed to resplit the dataset by relation types. 38 | The model assumes that the data is organized in a `data/tacred` directory as follows: 39 | 40 | ```bash 41 | | 42 | |----train.json 43 | |----dev.json 44 | |----test.json 45 | |----relation_description.csv 46 | ``` 47 | 48 | This `relation_description.csv` file was adapted from RoCORE and we provide it under `data_sample/tacred`. 
49 | 50 | - The FewRel dataset is available [here](https://github.com/thunlp/FewRel). The directory `data/fewrel` should contain the following files: 51 | 52 | ```bash 53 | | 54 | |----train_wiki.json 55 | |----val_wiki.json 56 | |----relation_description.csv 57 | ``` 58 | 59 | You can find the `relation_description.csv` file under `data_sample/fewrel/`. 60 | These `relation_description.csv` files control which types are considered known and which are unknown. We use the same split as RoCORE, but you can edit these files to experiment with other splits of the relation types. 61 | 62 | - The ACE dataset is distributed by the LDC ([LDC2006T06](https://catalog.ldc.upenn.edu/LDC2006T06)). We follow the OneIE preprocessing, which preserves multi-token trigger words. 63 | The code assumes that the data paths have a prefix such as `pro_mttrig_id/json/train.oneie.json`. This might not be the case on your system, so please change the setting in `src/event_data_module.py` line 396 as necessary. 64 | On the ACE dataset, we determine which types are known by sorting the types by frequency. 65 | 66 | ## Running the Model 67 | 68 | The entry point of the model is `src/main.py`. Most of the hyperparameters are set in `configs/`. 69 | A sample script for running the model on TACRED is provided below. (Be sure to specify your checkpoint names, random seed and GPU id.) 70 | 71 | Please change line 222 of `src/main.py` to your own project name and username if you would like to use Weights & Biases (wandb) for model monitoring. 72 | 73 | ```bash 74 | PRETRAIN_CKPT_NAME= 75 | CKPT_NAME= 76 | SEED= 77 | GPU= 78 | 79 | if [ ! -d checkpoints/${PRETRAIN_CKPT_NAME} ]; then 80 | python src/main.py \ 81 | --load_configs='configs/rel/tacred/pretrain.yaml' \ 82 | --ckpt_name=${PRETRAIN_CKPT_NAME} \ 83 | --psuedo_label="other" \ 84 | --feature="all" \ 85 | --seed=${SEED} \ 86 | --gpus="${GPU}," 87 | 88 | 89 | echo "finished pretraining!"
90 | fi 91 | 92 | 93 | rm -rf checkpoints/${CKPT_NAME} 94 | python src/main.py \ 95 | --load_configs='configs/rel/tacred/train.yaml' \ 96 | --ckpt_name=${CKPT_NAME} \ 97 | --load_pretrained=checkpoints/${PRETRAIN_CKPT_NAME}/last.ckpt \ 98 | --psuedo_label="other" \ 99 | --feature="all" \ 100 | --seed=${SEED} \ 101 | --gpus="${GPU}," 102 | 103 | echo "finished model training!" 104 | 105 | 106 | for ckpt_file in checkpoints/${CKPT_NAME}/epoch*.ckpt; do 107 | echo "running test for ${ckpt_file}" 108 | python src/main.py \ 109 | --load_configs='configs/rel/tacred/train.yaml' \ 110 | --ckpt_name=${CKPT_NAME} \ 111 | --load_ckpt=${ckpt_file} \ 112 | --psuedo_label="other" \ 113 | --feature="all" \ 114 | --eval_only \ 115 | --seed=${SEED} \ 116 | --gpus="${GPU}," 117 | done 118 | ``` 119 | 120 | Your checkpoints will be saved under `checkpoints/${CKPT_NAME}`. The `test_unknown_clusters.json` file there lists the instances assigned to each cluster, and the `test_unknown_metrics.json` file contains the evaluation metrics. 121 | 122 | You can also run the baselines by changing the configuration path. 123 | 124 | ### Sanity checking 125 | 126 | Due to the size of the datasets, results can vary noticeably across random seeds. 127 | On TACRED, an accuracy between 0.87 and 0.91 is considered normal.
128 | 129 | -------------------------------------------------------------------------------- /baselines/RSN.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Implementation following https://github.com/thunlp/RSN/blob/master/RSN/model/siamodel.py 3 | ''' 4 | from typing import List, Dict, Tuple 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from common.utils import onedim_gather 12 | 13 | # def _cnn_(cnn_input_shape,name=None): 14 | 15 | # convnet = Sequential() 16 | # convnet.add(Conv1D(230, 3, 17 | # input_shape = cnn_input_shape, 18 | # kernel_initializer = W_init, 19 | # bias_initializer = b_init_conv, 20 | # kernel_regularizer=l2(2e-4) 21 | # )) 22 | # convnet.add(MaxPooling1D(pool_size=cnn_input_shape[0]-4)) 23 | # convnet.add(Activation('relu')) 24 | 25 | # convnet.add(Flatten()) 26 | # convnet.add(Dense(cnn_input_shape[-1]*230, activation = 'sigmoid', 27 | # kernel_initializer = W_init, 28 | # bias_initializer = b_init_dense, 29 | # kernel_regularizer=l2(1e-3) 30 | # )) 31 | # return convnet 32 | 33 | class ConvNet(nn.Module): 34 | def __init__(self, input_dim: int, output_dim: int) -> None: 35 | super().__init__() 36 | self.conv = nn.Conv1d(in_channels=input_dim, out_channels=output_dim, 37 | kernel_size=3) 38 | self.linear = nn.Linear(output_dim, output_dim) 39 | # TODO: kernel regularizer? 
40 | 41 | 42 | def forward(self, x: torch.FloatTensor)-> torch.FloatTensor: 43 | ''' 44 | :params x: (batch, seq_len, input_dim) 45 | :return (batch, output_dim) 46 | ''' 47 | seq_len = x.size(1) 48 | x = x.transpose(1,2) 49 | conv_output = self.conv(x) #(batch, output_dim, seq_len-2) 50 | 51 | pooled_output = torch.max_pool1d(conv_output, seq_len-2).squeeze(2) # (batch, output_dim) 52 | pooled_output = torch.relu(pooled_output) 53 | output = torch.sigmoid(self.linear(pooled_output)) 54 | 55 | return output 56 | 57 | 58 | 59 | 60 | class RSNLayer(nn.Module): 61 | def __init__(self, input_dim, rel_dim, use_cnn:bool=True, dropout_p:float=0.1, max_len=300, pos_emb_dim=5) -> None: 62 | super().__init__() 63 | 64 | self.use_cnn = use_cnn 65 | self.max_len = max_len 66 | 67 | self.dropout = nn.Dropout(p=dropout_p) 68 | self.pos_emb = nn.Embedding(num_embeddings=max_len*2, embedding_dim= pos_emb_dim) 69 | 70 | if use_cnn: 71 | self.conv_net = ConvNet(input_dim +2*pos_emb_dim, rel_dim) 72 | else: 73 | self.proj = nn.Linear(input_dim*2, rel_dim) 74 | self.p = nn.Linear(rel_dim, 1) 75 | 76 | 77 | def get_pos_embedding(self, spans, indexes): 78 | ''' 79 | :param spans: (B, 2) 80 | :param indexes: (B, seq_len) 81 | 82 | :return pos_embed: (B, seq_len, pos_emb_dim) 83 | ''' 84 | pos1 = indexes - spans[:, 0].unsqueeze(1) 85 | pos_embed = self.pos_emb(pos1+self.max_len) 86 | return pos_embed 87 | 88 | def embed(self, head_spans, tail_spans, seq_output): 89 | ''' 90 | Used for prediction. 
91 | 92 | :return x: (B, rel_dim) 93 | ''' 94 | seq_len = seq_output.size(1) 95 | batch_size = seq_output.size(0) 96 | seq_output = self.dropout(seq_output) 97 | 98 | if self.use_cnn: 99 | # position embeddings 100 | indexes = torch.arange(0, seq_len, dtype=torch.long, device=head_spans.device).repeat((batch_size, 1)) # B, seq_len 101 | head_pos_embed = self.get_pos_embedding(head_spans, indexes) 102 | tail_pos_embed = self.get_pos_embedding(tail_spans, indexes) 103 | seq_output_with_pos = torch.concat([seq_output, head_pos_embed, tail_pos_embed], dim=2) 104 | 105 | x = self.conv_net(seq_output_with_pos) #B, rel_dim 106 | else: 107 | head_rep = onedim_gather(seq_output, dim=1, index=head_spans[:, 0].unsqueeze(1)) 108 | tail_rep = onedim_gather(seq_output, dim=1, index=tail_spans[:, 0].unsqueeze(1)) 109 | feat = torch.cat([head_rep, tail_rep], dim = 2).squeeze(1) 110 | x = self.proj(feat) # B, rel_dim 111 | 112 | return x 113 | 114 | def compute_distance(self, x1, x2): 115 | dis = torch.abs(x1 - x2) # (rel_dim) 116 | prob = torch.sigmoid(self.p(dis)) 117 | return prob 118 | 119 | 120 | def forward(self, head_spans: torch.LongTensor, tail_spans:torch.LongTensor, 121 | seq_output: torch.FloatTensor, perturb:torch.FloatTensor=None) ->torch.FloatTensor: 122 | ''' 123 | :param seq_output: (B, seq_len, input_dim) 124 | :return: (B, B) 125 | ''' 126 | seq_len = seq_output.size(1) 127 | batch_size = seq_output.size(0) 128 | seq_output = self.dropout(seq_output) 129 | if perturb!=None: 130 | seq_output += perturb 131 | 132 | if self.use_cnn: 133 | # position embeddings 134 | indexes = torch.arange(0, seq_len, device=head_spans.device).repeat((batch_size, 1)) # B, seq_len 135 | head_pos_embed = self.get_pos_embedding(head_spans, indexes) 136 | tail_pos_embed = self.get_pos_embedding(tail_spans, indexes) 137 | seq_output_with_pos = torch.concat([seq_output, head_pos_embed, tail_pos_embed], dim=2) 138 | 139 | x = self.conv_net(seq_output_with_pos) #B, rel_dim 140 | else: 141 | 
head_rep = onedim_gather(seq_output, dim=1, index=head_spans[:, 0].unsqueeze(1)) 142 | tail_rep = onedim_gather(seq_output, dim=1, index=tail_spans[:, 0].unsqueeze(1)) 143 | feat = torch.cat([head_rep, tail_rep], dim = 2).squeeze(1) 144 | x = self.proj(feat) # B, rel_dim 145 | 146 | dis = torch.abs(x.unsqueeze(1) - x.unsqueeze(0)) # (B, B, rel_dim) 147 | pair_prob = torch.sigmoid(self.p(dis).squeeze(-1)) 148 | 149 | return pair_prob 150 | -------------------------------------------------------------------------------- /common/predictions.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, Counter 2 | from typing import Dict, Tuple, List,Optional 3 | import json 4 | import os 5 | 6 | import torch 7 | 8 | 9 | 10 | class PairwiseClusterPredictionsWrapper(torch.nn.Module): 11 | def __init__(self, prefix: str='', task: str='rel') -> None: 12 | super().__init__() 13 | self.prefix = prefix 14 | self.task = task 15 | 16 | self.predictions = {} 17 | 18 | def on_epoch_end(self, pred_cluster:Dict, metadata:Dict)->None: 19 | 20 | for uid in pred_cluster: 21 | cluster_idx = pred_cluster[uid] 22 | target_idx = metadata[uid]['label_idx'] 23 | sentence = ' '.join(metadata[uid]['tokens']) 24 | label = metadata[uid]['label'] # type: str 25 | self.predictions[uid] = { 26 | 'uid': uid, 27 | 'sentence': sentence, 28 | 'label': label, 29 | 'cluster_idx': cluster_idx, 30 | 'target_idx': target_idx 31 | } 32 | if self.task == 'rel': 33 | self.predictions[uid]['subj'] = metadata[uid]['subj'] 34 | self.predictions[uid]['obj'] = metadata[uid]['obj'] 35 | elif self.task == 'event': 36 | self.predictions[uid]['trigger'] = metadata[uid]['trigger'] 37 | 38 | 39 | # index by cluster 40 | self.clusters = defaultdict(list) 41 | for uid, d in self.predictions.items(): 42 | cluster_idx = d['cluster_idx'] 43 | self.clusters[cluster_idx].append(d) 44 | return 45 | 46 | 47 | def save(self, ckpt_dir: str): 48 | with 
open(os.path.join(ckpt_dir, f'{self.prefix}_predictions.json'),'w') as f: 49 | json.dump(self.predictions, f, indent=2) 50 | 51 | with open(os.path.join(ckpt_dir, f'{self.prefix}_clusters.json'),'w') as f: 52 | json.dump(self.clusters, f, indent=2) 53 | 54 | 55 | 56 | class ClusterPredictionsWrapper(torch.nn.Module): 57 | def __init__(self, reassign: bool=False, prefix: str='', known_classes: int=31, 58 | task: str='rel', save_names: bool =False) -> None: 59 | super().__init__() 60 | self.reassign = reassign 61 | self.prefix = prefix 62 | self.known_classes=known_classes 63 | self.task = task 64 | self.save_names = save_names 65 | 66 | self.pred_cluster_cache = [] 67 | self.target_cluster_cache = [] 68 | self.pred_prob_cache = [] 69 | 70 | self.predictions = {} # uid -> {tokens, label, cluster_idx} 71 | 72 | def update_batch(self, meta: List[Dict], logits: torch.FloatTensor, 73 | targets: torch.LongTensor, incremental:bool=False, names: Optional[List[str]]=None)->None: 74 | ''' 75 | :param logits: (batch, n_cluster) 76 | :param targets: (batch) 77 | ''' 78 | 79 | if not incremental: 80 | assert (targets.min() >= self.known_classes) 81 | # consider only unknown classes 82 | prob = torch.softmax(logits[:, self.known_classes:], dim=1) 83 | pred_prob, pred_cluster = torch.max(prob, dim=1) 84 | self.pred_cluster_cache.append(pred_cluster) 85 | self.target_cluster_cache.append(targets-self.known_classes) 86 | self.pred_prob_cache.append(pred_prob) 87 | else: 88 | prob = torch.softmax(logits, dim=1) 89 | pred_prob, pred_cluster = torch.max(prob, dim=1) 90 | self.target_cluster_cache.append(targets) 91 | self.pred_prob_cache.append(pred_prob) 92 | 93 | batch_size = logits.size(0) 94 | for i in range(batch_size): 95 | uid = meta[i]['uid'] 96 | sentence = ' '.join(meta[i]['tokens']) 97 | cluster_idx = pred_cluster[i].item() 98 | target_idx = targets[i].item() 99 | self.predictions[uid] = { 100 | 'uid': uid, 101 | 'sentence': sentence, 102 | 'label': meta[i]['label'], 103 | 
'cluster_idx': cluster_idx, 104 | 'target_idx': target_idx, 105 | 'prob': pred_prob[i].item() 106 | } 107 | 108 | if names!= None: 109 | self.predictions[uid]['names'] = names[i].split() 110 | else: 111 | self.predictions[uid]['names'] = [] 112 | if self.task == 'rel': 113 | self.predictions[uid]['subj'] = meta[i]['subj'] 114 | self.predictions[uid]['obj'] = meta[i]['obj'] 115 | elif self.task == 'event': 116 | self.predictions[uid]['trigger'] = meta[i]['trigger'] 117 | 118 | 119 | return 120 | 121 | def on_epoch_end(self): 122 | # index by cluster 123 | self.clusters = defaultdict(list) 124 | cluster_scores = defaultdict(list) # cluster_id -> List[(uid, prob)] 125 | cluster_names = defaultdict(Counter) # cluster_id -> counter 126 | for uid, d in self.predictions.items(): 127 | cluster_idx = d['cluster_idx'] 128 | cluster_scores[cluster_idx].append((uid, d['prob'])) 129 | cluster_names[cluster_idx].update(d['names']) 130 | 131 | self.cluster_freq_names = {} 132 | # sort by score 133 | for cluster_idx in cluster_scores.keys(): 134 | uid_sorted = sorted(cluster_scores[cluster_idx], key=lambda x: x[1], reverse=True) 135 | for tup in uid_sorted: 136 | uid = tup[0] 137 | d = self.predictions[uid] 138 | self.clusters[cluster_idx].append(d) 139 | self.cluster_freq_names[cluster_idx] = cluster_names[cluster_idx].most_common(n=10) # List (str, int) 140 | 141 | def save(self, ckpt_dir: str): 142 | with open(os.path.join(ckpt_dir, f'{self.prefix}_predictions.json'),'w') as f: 143 | json.dump(self.predictions, f, indent=2) 144 | 145 | with open(os.path.join(ckpt_dir, f'{self.prefix}_clusters.json'),'w') as f: 146 | json.dump(self.clusters, f, indent=2) 147 | 148 | 149 | if self.save_names: 150 | with open(os.path.join(ckpt_dir, f'{self.prefix}_cluster_names.json'),'w') as f: 151 | json.dump(self.cluster_freq_names, f, indent=2) -------------------------------------------------------------------------------- /baselines/RoCORE_layers.py: 
-------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from common.center_loss import CenterLoss 8 | 9 | ''' 10 | Code from RoCORE/model.py 11 | ''' 12 | 13 | def L2Reg(net): 14 | reg_loss = 0 15 | for name, params in net.named_parameters(): 16 | if name[-4:] != 'bias': 17 | reg_loss += torch.sum(torch.pow(params, 2)) 18 | return reg_loss 19 | 20 | 21 | def compute_kld(p_logit, q_logit): 22 | p = F.softmax(p_logit, dim = -1) # (B, B, n_class) 23 | q = F.softmax(q_logit, dim = -1) # (B, B, n_class) 24 | return torch.sum(p * (torch.log(p + 1e-16) - torch.log(q + 1e-16)), dim = -1) # (B, B) 25 | 26 | 27 | 28 | class ZeroShotModel(nn.Module): 29 | def __init__(self, args, known_types: int, unknown_types: int, model_config, pretrained_model, unfreeze_layers = []): 30 | super().__init__() 31 | # self.IL = args.IL 32 | self.known_types = known_types 33 | self.unknown_types = unknown_types 34 | self.hidden_dim = args.hidden_dim 35 | self.kmeans_dim = args.kmeans_dim 36 | self.initial_dim = model_config.hidden_size 37 | self.unfreeze_layers = unfreeze_layers 38 | self.pretrained_model = self.finetune(pretrained_model, self.unfreeze_layers) # fix bert weights 39 | self.layer = args.layer 40 | 41 | 42 | self.similarity_encoder = nn.Sequential( 43 | nn.Linear(2 * self.initial_dim, self.hidden_dim), 44 | nn.LeakyReLU(), 45 | nn.Linear(self.hidden_dim, self.hidden_dim), 46 | nn.LeakyReLU(), 47 | nn.Linear(self.hidden_dim, self.kmeans_dim) 48 | ) 49 | self.similarity_decoder = nn.Sequential( 50 | nn.Linear(self.kmeans_dim, self.hidden_dim), 51 | nn.LeakyReLU(), 52 | nn.Linear(self.hidden_dim, self.hidden_dim), 53 | nn.LeakyReLU(), 54 | nn.Linear(self.hidden_dim, 2 * self.initial_dim) 55 | ) 56 | self.ct_loss_u = CenterLoss(dim_hidden = self.kmeans_dim, num_classes = self.unknown_types, alpha=1.0, weight_by_prob=True) 57 | 
self.ct_loss_l = CenterLoss(dim_hidden = self.kmeans_dim, num_classes = self.known_types) 58 | # if self.IL: 59 | # self.labeled_head = nn.Linear(2 * self.initial_dim, self.known_types + self.unknown_types) 60 | self.labeled_head = nn.Linear(2 * self.initial_dim, self.known_types) 61 | self.unlabeled_head = nn.Linear(2 * self.initial_dim, self.unknown_types) 62 | self.bert_params = [] 63 | for name, param in self.pretrained_model.named_parameters(): 64 | if param.requires_grad is True: 65 | self.bert_params.append(param) 66 | 67 | @staticmethod 68 | def finetune(model, unfreeze_layers): 69 | params_name_mapping = ['embeddings', 'layer.0', 'layer.1', 'layer.2', 'layer.3', 'layer.4', 'layer.5', 'layer.6', 'layer.7', 'layer.8', 'layer.9', 'layer.10', 'layer.11', 'layer.12'] 70 | for name, param in model.named_parameters(): 71 | param.requires_grad = False 72 | for ele in unfreeze_layers: 73 | if params_name_mapping[ele] in name: 74 | param.requires_grad = True 75 | break 76 | return model 77 | 78 | def get_pretrained_feature(self, input_id, input_mask, head_span, tail_span): 79 | outputs = self.pretrained_model(input_id, token_type_ids=None, attention_mask=input_mask) # (13 * [batch_size, seq_len, bert_embedding_len]) 80 | all_encoder_layers = outputs[2] 81 | encoder_layers = all_encoder_layers[self.layer] # (batch_size, seq_len, bert_embedding) 82 | batch_size = encoder_layers.size(0) 83 | head_entity_rep = torch.stack([torch.max(encoder_layers[i, head_span[i][0]:head_span[i][1], :], dim = 0)[0] for i in range(batch_size)], dim = 0) 84 | tail_entity_rep = torch.stack([torch.max(encoder_layers[i, tail_span[i][0]:tail_span[i][1], :], dim = 0)[0] for i in range(batch_size)], dim = 0) # (batch_size, bert_embedding) 85 | pretrained_feat = torch.cat([head_entity_rep, tail_entity_rep], dim = 1) # (batch_size, 2 * bert_embedding) 86 | return pretrained_feat 87 | 88 | def forward(self, batch: Dict[str, torch.Tensor], mask: Optional[torch.BoolTensor]=None, msg: 
str='similarity', cut_gradient:bool=False): 89 | input_ids = batch['token_ids'] 90 | input_mask = batch['attn_mask'] 91 | head_span = batch['head_spans'] 92 | tail_span = batch['tail_spans'] 93 | 94 | if mask!= None: 95 | input_ids = input_ids[mask] 96 | input_mask = input_mask[mask] 97 | head_span = head_span[mask] 98 | tail_span = tail_span[mask] 99 | 100 | if msg == 'similarity':# used for centroid update 101 | with torch.no_grad(): 102 | pretrained_feat = self.get_pretrained_feature(input_ids, input_mask, head_span, tail_span) # (batch_size, 2 * bert_embedding) 103 | commonspace_rep = self.similarity_encoder(pretrained_feat) # (batch_size, keamns_dim) 104 | return commonspace_rep # (batch_size, keamns_dim) 105 | 106 | elif msg == 'reconstruct': 107 | with torch.no_grad(): 108 | pretrained_feat = self.get_pretrained_feature(input_ids, input_mask, head_span, tail_span) # (batch_size, 2 * bert_embedding) 109 | commonspace_rep = self.similarity_encoder(pretrained_feat) # (batch_size, kmeans_dim) 110 | rec_rep = self.similarity_decoder(commonspace_rep) # (batch_size, 2 * bert_embedding) 111 | rec_loss = (rec_rep - pretrained_feat).pow(2).mean(-1) 112 | return commonspace_rep, rec_loss 113 | 114 | 115 | elif msg == 'labeled': 116 | pretrained_feat = self.get_pretrained_feature(input_ids, input_mask, head_span, tail_span) # (batch_size, 2 * bert_embedding) 117 | if cut_gradient: 118 | pretrained_feat = pretrained_feat.detach() 119 | logits = self.labeled_head(pretrained_feat) 120 | return logits # (batch_size, num_class) 121 | 122 | elif msg == 'unlabeled': 123 | pretrained_feat = self.get_pretrained_feature(input_ids, input_mask, head_span, tail_span) # (batch_size, 2 * bert_embedding) 124 | if cut_gradient: 125 | pretrained_feat = pretrained_feat.detach() 126 | logits = self.unlabeled_head(pretrained_feat) 127 | return logits # (batch_size, new_class) 128 | 129 | else: 130 | raise NotImplementedError('not implemented!') 131 | 
-------------------------------------------------------------------------------- /common/b3.py: -------------------------------------------------------------------------------- 1 | ''' 2 | B-cube function from: 3 | https://github.com/m-wiesner/BCUBED/blob/master/B3score/b3.py 4 | ''' 5 | 6 | # ========================== B3 =============================================== 7 | # This function determines the extrinsic clustering quality b-cubed measure 8 | # using a set of known labels for a data set, and cluster assignmnets of 9 | # each data point stored as either vector or cell arrays where each 10 | # cell/row represets a data point in which an array containing class or cluster 11 | # assignments is stored. The multi-class variant should not yet be used and 12 | # does not scale well. 13 | # 14 | # Inputs 15 | # ------------------ 16 | # L: An NxM matrix containing class labels for each data point. Each 17 | # row represents the ith data point and each column contains a 0, or 1 18 | # in the jth column indicating membership of label j. 19 | # Alternatively, if hard class labels are available, L can be input 20 | # as an Nx1 vector where each entry is its class label 21 | # 22 | # K: Defined identically to L except for this variable stores cluster 23 | # assignments for each data point 24 | # 25 | # Outputs 26 | #-------------------------------- 27 | # f_measure: This F-measure using the b-cubed metric 28 | # precision: The b-cubed precision 29 | # recall: The b-cubed recall 30 | # 31 | #----------------------------------------------- 32 | # Author: Matthew Wiesner 33 | # Email : wiesner@jhu.edu 34 | # Institute: Johns Hopkins University Electrical and Computer Engineering 35 | # 36 | # Refences: DESCRIPTION OF THE UPENN CAMP SYSTEM AS USED FOR COREFERENCE, 37 | # Breck Baldwin, Tom Morton, Amit Bagga, Jason Baldridge, 38 | # Raman Chandraseker, Alexis Dimitriadis, Kieran Snyder, 39 | # Magdalena Wolska, Institute for Research in Cognitive Science 40 | # 41 | # A.A. 
Aroch-Villarruel Pattern Recognition 6th Mexican Conference, 42 | # MCPR 2014 Proceedings Paper, p.115 43 | # ---------------------------------------------- 44 | import sys 45 | import numpy as np 46 | 47 | # Calculate precision and recall for each class 48 | def compute_class_precision_recall(L,K): 49 | ''' 50 | Compute the partitions matrix P which stores the size 51 | of the intersection of elements belong to Label i and Cluster j 52 | in the (i,j)-th entry of P 53 | 54 | Input: 55 | L -- Numpy array of Labels or numpy 2d array with shape (1,N_L) or (N_L,1) 56 | K -- Numpy array of Clusters or numpy 2d array with shape (1,N_K) or (N_K,1) 57 | 58 | Output: 59 | P -- Numpy ndarray |\L| x |\K| Partitions Matrix, where |\L| is the 60 | size of the label set, and \K is the number of clusters 61 | ''' 62 | # Make everything nicely formatted. Ignore skipped labels or clusters 63 | _,L = np.unique(np.array(L),return_inverse=True) 64 | _,K = np.unique(np.array(K),return_inverse=True) 65 | 66 | # Check that there are the same number of labels and clusters 67 | if(len(L) != len(K)): 68 | sys.stderr.write("Labels and clusters are not of the same length.") 69 | sys.exit(1) 70 | 71 | # Extract some useful variables that will make things easier to read. 72 | # 1. Number of total elements to cluster 73 | # 2. Number of distinct labels 74 | # 3. Number of distinct clusters 75 | num_elements = len(L) 76 | num_labels = L.max() + 1 77 | num_clusters = K.max() + 1 78 | 79 | # Create binary num_elements x num_labels / num_clusters assignment matrices. 80 | X_L = np.tile(L, (num_labels,1) ).T 81 | X_K = np.tile(K, (num_clusters,1) ).T 82 | 83 | L_j = np.equal( np.tile(np.arange(num_labels),(num_elements,1)) , X_L ).astype(float) 84 | K_j = np.equal( np.tile(np.arange(num_clusters),(num_elements,1)) , X_K ).astype(float) 85 | 86 | # Create the partitions matrix which has an element for the 87 | # intersection of label i, and cluster j. 
The element of the matrix is the 88 | # Number of elements in that partition. 89 | P_ij = np.dot(L_j.T,K_j) 90 | 91 | # Summing over the appropriate axes gives the total number of elements 92 | # in each class label (S_i) or cluster T_i 93 | S_i = P_ij.sum(axis=1) 94 | T_i = P_ij.sum(axis=0) 95 | 96 | # Calculate Class recall and precision 97 | R_i = ( P_ij * P_ij ).sum(axis=1) / ( S_i * S_i ) 98 | P_i = ( P_ij.T * P_ij.T ).sum(axis=1) / ( T_i * T_i ) 99 | 100 | return [(P_i , R_i) , (S_i , T_i)] 101 | 102 | # Calculate b3 metrics 103 | def calc_b3(L , K , class_norm=False, beta=1.0): 104 | ''' 105 | Implements the BCUBED algorithm according to the DESCRIPTION OF THE UPENN 106 | CAMP SYSTEM AS USED FOR COREFERENCE, Breck Baldwin, Tom Morton, 107 | Amit Bagga, Jason Baldridge, Raman Chandraseker, Alexis Dimitriadis, 108 | Kieran Snyder, Magdalena Wolska, Institute for Research in Cognitive 109 | Science. 110 | 111 | Usage: 112 | from B3 import B3 113 | import numpy as np 114 | 115 | score = B3() 116 | L = np.array([1,3,3,3,3,4,2,2,2,3,3]) 117 | K = np.array([1,2,3,4,5,5,5,6,2,1,1]) 118 | 119 | # Standard BCUBED (Weight each element equally) 120 | [fmeasure, precision, recall] = score.calc_b3(L,K) 121 | 122 | # Equivalence class normalization (Weight each class equally) 123 | [fmeasure, precision, recall] = score.calc_b3(L,K,class_norm=True) 124 | 125 | # Different weighting schemes for fmeasure 126 | [fmeasure, precision, recall] = score.calc_b3(L,K,beta=2.0) 127 | [fmeasure, precision, recall] = score.calc_b3(L,K,beta=0.5) 128 | 129 | 130 | Computes the precision, recall, and fmeasure from the Class level 131 | precision, and recall arrays. Two types are possible. One weights all 132 | classes equally while the other weights each element equally. 
133 | 134 | Input: 135 | L -- Numpy array of Labels or numpy 2d array with shape (1,N_L) or (N_L,1) 136 | K -- Numpy array of Clusters or numpy 2d array with shape (1,N_K) or (N_K,1) 137 | 138 | options: 139 | class_norm: Decides whether to weight the precision by class or by entity 140 | beta: Harmonic mean weighting 141 | ''' 142 | 143 | # Compute per equivalence class precision and recall 144 | precision_recall , class_sizes = compute_class_precision_recall(L,K) 145 | 146 | # Two methods of obtaining overall precision and recall 147 | if(class_norm == True): 148 | precision = precision_recall[0].sum() / class_sizes[1].size 149 | recall = precision_recall[1].sum() / class_sizes[0].size 150 | else: 151 | precision = ( precision_recall[0] * class_sizes[1] ).sum() / class_sizes[1].sum() 152 | recall = ( precision_recall[1] * class_sizes[0] ).sum() / class_sizes[0].sum() 153 | 154 | # f_measure with option beta to weight the precision and recall asymmetrically. 155 | f_measure = (1 + beta**2) * (precision * recall) /( (beta**2) * precision + recall ) 156 | 157 | return [f_measure,precision,recall] 158 | 159 | 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /src/multiview_layers.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple, Optional 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.distributions.categorical import Categorical 7 | from torch.distributions.kl import kl_divergence 8 | 9 | from common.layers import ClassifierHead, MLP, CommonSpaceCache, ReconstructionNet 10 | from common.center_loss import CenterLoss 11 | from common.utils import onedim_gather 12 | 13 | 14 | 15 | import common.log as log 16 | logger = log.get_logger('root') 17 | 18 | class MultiviewModel(nn.Module): 19 | def __init__(self, args, model_config, pretrained_model, unfreeze_layers: List=[]) -> None: 20 | 
super().__init__() 21 | self.args = args 22 | 23 | self.layer = args.layer 24 | if args.freeze_pretrain: 25 | self.pretrained_model = self.finetune(pretrained_model, unfreeze_layers) 26 | else: 27 | self.pretrained_model = pretrained_model 28 | 29 | self.views = nn.ModuleList() 30 | if args.feature == 'all': feature_types = ['token','mask'] 31 | elif args.feature == 'mask': feature_types = ['mask', 'mask'] 32 | elif args.feature == 'token': feature_types = ['token', 'token'] 33 | else: 34 | raise NotImplementedError 35 | 36 | if self.args.rev_ratio >0: 37 | known_head_types = 2 * args.known_types 38 | else: 39 | known_head_types = args.known_types 40 | 41 | for view_idx, ft in enumerate(feature_types): 42 | if ft == 'mask': 43 | view_model = nn.ModuleDict( 44 | { 45 | 'common_space_proj': MLP(model_config.hidden_size, args.hidden_dim, args.kmeans_dim, 46 | norm=True, norm_type='batch', layers_n=2, dropout_p =0.1), 47 | 'known_type_center_loss': CenterLoss(args.kmeans_dim, args.known_types, weight_by_prob=False), 48 | 'unknown_type_center_loss': CenterLoss(args.kmeans_dim, args.unknown_types, weight_by_prob=False), 49 | 'known_type_classifier': ClassifierHead(args, args.kmeans_dim, 50 | known_head_types, layers_n=args.classifier_layers, n_heads=1, dropout_p=0.0, hidden_size=args.kmeans_dim), 51 | 'unknown_type_classifier': ClassifierHead(args, args.kmeans_dim, 52 | args.unknown_types, layers_n=args.classifier_layers, n_heads=1, dropout_p=0.0, hidden_size=args.kmeans_dim) 53 | } 54 | ) 55 | else: 56 | if self.args.task == 'rel': 57 | input_size = 2 * model_config.hidden_size # head, tail 58 | else: 59 | input_size = model_config.hidden_size # trigger 60 | 61 | view_model = nn.ModuleDict( 62 | { 63 | 'common_space_proj': MLP(input_size, args.hidden_dim, args.kmeans_dim, 64 | norm=True, norm_type='batch', layers_n=2, dropout_p=0.1), 65 | 'known_type_center_loss': CenterLoss(args.kmeans_dim, args.known_types, weight_by_prob=False), 66 | 'unknown_type_center_loss': 
CenterLoss(args.kmeans_dim, args.unknown_types, weight_by_prob=False), 67 | 'known_type_classifier': ClassifierHead(args, args.kmeans_dim, 68 | known_head_types, layers_n=args.classifier_layers, n_heads=1, dropout_p=0.0, hidden_size=args.kmeans_dim), 69 | 'unknown_type_classifier': ClassifierHead(args, args.kmeans_dim, 70 | args.unknown_types, layers_n=args.classifier_layers, n_heads=1, dropout_p=0.0, hidden_size=args.kmeans_dim) 71 | } 72 | ) 73 | self.views.append(view_model) 74 | 75 | # this commonspace means that known classes and unknown classes are projected into the same space 76 | self.commonspace_cache = nn.ModuleList([ 77 | CommonSpaceCache(feature_size=args.kmeans_dim, known_cache_size=512, unknown_cache_size=256, sim_thres=0.8), 78 | CommonSpaceCache(feature_size=args.kmeans_dim, known_cache_size=512, unknown_cache_size=256, sim_thres=0.8) 79 | ]) 80 | 81 | return 82 | 83 | 84 | # FIXME: this function is taken from ROCORE, will use layer.7 instead of layer.8 as described in the paper. 
85 | @staticmethod 86 | def finetune(model, unfreeze_layers): 87 | params_name_mapping = ['embeddings', 'layer.0', 'layer.1', 'layer.2', 'layer.3', 'layer.4', 'layer.5', 'layer.6', 'layer.7', 'layer.8', 'layer.9', 'layer.10', 'layer.11', 'layer.12'] 88 | for name, param in model.named_parameters(): 89 | param.requires_grad = False 90 | for ele in unfreeze_layers: 91 | if params_name_mapping[ele] in name: 92 | param.requires_grad = True 93 | break 94 | return model 95 | 96 | 97 | def generate_default_inputs(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 98 | """Generate the default inputs required by almost every language model.""" 99 | inputs = {'input_ids': batch['token_ids'], 'attention_mask': batch['attn_mask']} 100 | return inputs 101 | 102 | def predict_name(self, batch:Dict, topk: int=10) -> torch.LongTensor: 103 | lm_inputs = self.generate_default_inputs(batch) 104 | outputs = self.pretrained_model(**lm_inputs) 105 | vocab_logits = outputs[0] #(batch_size, seq, vocab_size) 106 | mask_bpe_idx = batch['mask_bpe_idx'] 107 | mask_logits = onedim_gather(vocab_logits, dim=1, index=mask_bpe_idx.unsqueeze(1)).squeeze(1) # (batch_size, vocab_size) 108 | predicted_token_ids = mask_logits.argsort(dim=-1, descending=True)[:, :topk] 109 | return predicted_token_ids 110 | 111 | 112 | def _compute_features(self, batch:Dict, method: str='token', pooling: str='first')-> torch.FloatTensor: 113 | lm_inputs = self.generate_default_inputs(batch) 114 | outputs = self.pretrained_model(**lm_inputs) 115 | # seq_output = outputs[0] # last layer output, only works with AutoModel, does not work with AutoModelForMaskedLM 116 | 117 | all_encoder_layers = outputs[-1] 118 | seq_output = all_encoder_layers[-1] 119 | 120 | if method == 'token': 121 | if batch['task'] == 'rel': 122 | head_spans = batch['head_spans'] # (batch, 2) 123 | tail_spans = batch['tail_spans'] # (batch,2 ) 124 | batch_size = head_spans.size(0) 125 | 126 | if pooling == 'first': 127 | # taking the 
first token as the representation for the entity 128 | head_rep = onedim_gather(seq_output, dim=1, index=head_spans[:, 0].unsqueeze(1)) 129 | tail_rep = onedim_gather(seq_output, dim=1, index=tail_spans[:, 0].unsqueeze(1)) 130 | feat = torch.cat([head_rep, tail_rep], dim = 2).squeeze(1) 131 | elif pooling == 'max': 132 | # max pooling over the tokens in the entity 133 | head_rep_list = [] 134 | tail_rep_list = [] 135 | for i in range(batch_size): 136 | head_ent = seq_output[i, head_spans[i, 0]: head_spans[i, 1], :] #(ent_len, hidden_dim) 137 | head_ent_max, _ = torch.max(head_ent, dim=0) 138 | head_rep_list.append(head_ent_max) 139 | 140 | tail_ent = seq_output[i, tail_spans[i, 0]: tail_spans[i, 1], :] #(ent_len, hidden_dim) 141 | tail_ent_max, _ = torch.max(tail_ent, dim=0) 142 | tail_rep_list.append(tail_ent_max) 143 | head_rep = torch.stack(head_rep_list, dim=0) 144 | tail_rep = torch.stack(tail_rep_list, dim=0) 145 | feat= torch.cat([head_rep, tail_rep], dim=1) 146 | elif batch['task'] == 'event': 147 | trigger_spans = batch['trigger_spans'] 148 | batch_size = trigger_spans.size(0) 149 | if pooling == 'first': 150 | feat = onedim_gather(seq_output, dim=1, index=trigger_spans[:, 0].unsqueeze(1)).squeeze(1) 151 | elif pooling == 'max': 152 | rep_list = [] 153 | for i in range(batch_size): 154 | tgr = seq_output[i, trigger_spans[i, 0]: trigger_spans[i, 1], :] 155 | tgr_max, _ = torch.max(tgr, dim=0) 156 | rep_list.append(tgr_max) 157 | feat = torch.stack(rep_list, dim=0) 158 | 159 | elif method == 'mask': 160 | mask_bpe_idx = batch['mask_bpe_idx'] # (batch) 161 | seq_len = seq_output.size(1) 162 | assert (mask_bpe_idx.max() < seq_len), "mask token out of bounds" 163 | feat = onedim_gather(seq_output, dim=1, index=mask_bpe_idx.unsqueeze(1)).squeeze(1) 164 | 165 | else: 166 | raise NotImplementedError 167 | return feat 168 | 169 | 170 | def _compute_prediction_logits(self, batch: Dict, 171 | method:str ='token', pooling: str ='first', 172 | view_idx: int =0)-> 
Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 173 | ''' 174 | :param seq_output: (batch, seq_len, hidden_dim) 175 | :param method: str, one of 'token', 'mask' 176 | :param pooling: str, one of first, max 177 | :param view_idx: int 178 | ''' 179 | feat = self._compute_features(batch, method, pooling) 180 | 181 | view_model = self.views[view_idx] 182 | common_space_feat = view_model['common_space_proj'](feat) 183 | known_head_logits = view_model['known_type_classifier'](common_space_feat) 184 | unknown_head_logits = view_model['unknown_type_classifier'](common_space_feat) 185 | 186 | predicted_logits = torch.cat([known_head_logits, unknown_head_logits], dim=1) 187 | 188 | 189 | return predicted_logits, common_space_feat, feat 190 | 191 | def _on_train_batch_start(self): 192 | # normalize all centroids 193 | for view_model in self.views: 194 | view_model['known_type_classifier'].update_centroid() 195 | view_model['unknown_type_classifier'].update_centroid() 196 | 197 | return 198 | 199 | 200 | def update_centers(self, centers: torch.FloatTensor, known:bool=True, view_idx: int=0): 201 | if known: 202 | self.views[view_idx]['known_type_center_loss'].centers = centers 203 | else: 204 | self.views[view_idx]['unknown_type_center_loss'].centers = centers 205 | return 206 | -------------------------------------------------------------------------------- /baselines/vae.py: -------------------------------------------------------------------------------- 1 | ''' 2 | File from https://github.com/ritheshkumar95/pytorch-vqvae/blob/master/modules.py 3 | ''' 4 | from typing import List, Tuple 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.distributions.normal import Normal 10 | from torch.distributions import kl_divergence 11 | from torch.autograd import Function 12 | 13 | 14 | from common.layers import MLP 15 | 16 | class VectorQuantization(Function): 17 | @staticmethod 18 | def forward(ctx, inputs, codebook): 19 
| with torch.no_grad(): 20 | embedding_size = codebook.size(1) 21 | inputs_size = inputs.size() 22 | inputs_flatten = inputs.view(-1, embedding_size) 23 | 24 | codebook_sqr = torch.sum(codebook ** 2, dim=1) 25 | inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True) 26 | 27 | # Compute the distances to the codebook 28 | distances = torch.addmm(codebook_sqr + inputs_sqr, 29 | inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0) 30 | 31 | _, indices_flatten = torch.min(distances, dim=1) 32 | indices = indices_flatten.view(*inputs_size[:-1]) 33 | ctx.mark_non_differentiable(indices) 34 | 35 | return indices 36 | 37 | @staticmethod 38 | def backward(ctx, grad_output): 39 | raise RuntimeError('Trying to call `.grad()` on graph containing ' 40 | '`VectorQuantization`. The function `VectorQuantization` ' 41 | 'is not differentiable. Use `VectorQuantizationStraightThrough` ' 42 | 'if you want a straight-through estimator of the gradient.') 43 | 44 | class VectorQuantizationStraightThrough(Function): 45 | @staticmethod 46 | def forward(ctx, inputs, codebook): 47 | # indices = vq(inputs, codebook) 48 | y = torch.softmax(torch.matmul(inputs, codebook.transpose(0,1)), dim=1) 49 | # (batch, K) 50 | _, indices = torch.max(y, dim=1) 51 | 52 | indices_flatten = indices.view(-1) 53 | ctx.save_for_backward(indices_flatten, codebook) 54 | ctx.mark_non_differentiable(indices_flatten) 55 | 56 | codes_flatten = torch.index_select(codebook, dim=0, 57 | index=indices_flatten) 58 | codes = codes_flatten.view_as(inputs) 59 | 60 | return (codes, indices_flatten) 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output, grad_indices): 64 | grad_inputs, grad_codebook = None, None 65 | 66 | if ctx.needs_input_grad[0]: 67 | # Straight-through estimator 68 | grad_inputs = grad_output.clone() 69 | if ctx.needs_input_grad[1]: 70 | # Gradient wrt. 
the codebook 71 | indices, codebook = ctx.saved_tensors 72 | embedding_size = codebook.size(1) 73 | 74 | grad_output_flatten = (grad_output.contiguous() 75 | .view(-1, embedding_size)) 76 | grad_codebook = torch.zeros_like(codebook) 77 | grad_codebook.index_add_(0, indices, grad_output_flatten) 78 | 79 | return (grad_inputs, grad_codebook) 80 | 81 | vq = VectorQuantization.apply 82 | vq_st = VectorQuantizationStraightThrough.apply 83 | 84 | def to_scalar(arr): 85 | if type(arr) == list: 86 | return [x.item() for x in arr] 87 | else: 88 | return arr.item() 89 | 90 | 91 | def weights_init(m): 92 | classname = m.__class__.__name__ 93 | if classname.find('Conv') != -1: 94 | try: 95 | nn.init.xavier_uniform_(m.weight.data) 96 | m.bias.data.fill_(0) 97 | except AttributeError: 98 | print("Skipping initialization of ", classname) 99 | 100 | 101 | class VAE(nn.Module): 102 | def __init__(self, input_dim, dim, z_dim): 103 | super().__init__() 104 | self.encoder = nn.Sequential( 105 | nn.Conv2d(input_dim, dim, 4, 2, 1), 106 | nn.BatchNorm2d(dim), 107 | nn.ReLU(True), 108 | nn.Conv2d(dim, dim, 4, 2, 1), 109 | nn.BatchNorm2d(dim), 110 | nn.ReLU(True), 111 | nn.Conv2d(dim, dim, 5, 1, 0), 112 | nn.BatchNorm2d(dim), 113 | nn.ReLU(True), 114 | nn.Conv2d(dim, z_dim * 2, 3, 1, 0), 115 | nn.BatchNorm2d(z_dim * 2) 116 | ) 117 | 118 | self.decoder = nn.Sequential( 119 | nn.ConvTranspose2d(z_dim, dim, 3, 1, 0), 120 | nn.BatchNorm2d(dim), 121 | nn.ReLU(True), 122 | nn.ConvTranspose2d(dim, dim, 5, 1, 0), 123 | nn.BatchNorm2d(dim), 124 | nn.ReLU(True), 125 | nn.ConvTranspose2d(dim, dim, 4, 2, 1), 126 | nn.BatchNorm2d(dim), 127 | nn.ReLU(True), 128 | nn.ConvTranspose2d(dim, input_dim, 4, 2, 1), 129 | nn.Tanh() 130 | ) 131 | 132 | self.apply(weights_init) 133 | 134 | def forward(self, x): 135 | mu, logvar = self.encoder(x).chunk(2, dim=1) 136 | 137 | q_z_x = Normal(mu, logvar.mul(.5).exp()) 138 | p_z = Normal(torch.zeros_like(mu), torch.ones_like(logvar)) 139 | kl_div = kl_divergence(q_z_x, 
p_z).sum(1).mean() 140 | 141 | x_tilde = self.decoder(q_z_x.rsample()) 142 | return x_tilde, kl_div 143 | 144 | 145 | class VQEmbedding(nn.Module): 146 | def __init__(self, K: int, D: int): 147 | super().__init__() 148 | self.embedding = nn.Embedding(K, D) 149 | self.embedding.weight.data.uniform_(-1./K, 1./K) 150 | 151 | def forward(self, z_e_x: torch.FloatTensor): 152 | logits = torch.matmul(z_e_x, self.embedding.weight.transpose(0,1)) 153 | y = torch.softmax(logits, dim=1) 154 | # (batch, K) 155 | _, indices = torch.max(y, dim=1) 156 | return logits, indices 157 | 158 | def straight_through(self, z_e_x): 159 | z_q_x, indices = vq_st(z_e_x, self.embedding.weight.detach()) 160 | 161 | z_q_x_bar_flatten = torch.index_select(self.embedding.weight, 162 | dim=0, index=indices) 163 | z_q_x_bar = z_q_x_bar_flatten.view_as(z_e_x) 164 | 165 | return z_q_x, z_q_x_bar 166 | 167 | 168 | class ResBlock(nn.Module): 169 | def __init__(self, dim): 170 | super().__init__() 171 | self.block = nn.Sequential( 172 | nn.ReLU(True), 173 | nn.Conv2d(dim, dim, 3, 1, 1), 174 | nn.BatchNorm2d(dim), 175 | nn.ReLU(True), 176 | nn.Conv2d(dim, dim, 1), 177 | nn.BatchNorm2d(dim) 178 | ) 179 | 180 | def forward(self, x): 181 | return x + self.block(x) 182 | 183 | 184 | class VectorQuantizedVAE(nn.Module): 185 | def __init__(self, input_dim, dim, K=512): 186 | super().__init__() 187 | self.encoder = nn.Sequential( 188 | nn.Conv2d(input_dim, dim, 4, 2, 1), 189 | nn.BatchNorm2d(dim), 190 | nn.ReLU(True), 191 | nn.Conv2d(dim, dim, 4, 2, 1), 192 | ResBlock(dim), 193 | ResBlock(dim), 194 | ) 195 | 196 | self.codebook = VQEmbedding(K, dim) 197 | 198 | self.decoder = nn.Sequential( 199 | ResBlock(dim), 200 | ResBlock(dim), 201 | nn.ReLU(True), 202 | nn.ConvTranspose2d(dim, dim, 4, 2, 1), 203 | nn.BatchNorm2d(dim), 204 | nn.ReLU(True), 205 | nn.ConvTranspose2d(dim, input_dim, 4, 2, 1), 206 | nn.Tanh() 207 | ) 208 | 209 | self.apply(weights_init) 210 | 211 | def encode(self, x): 212 | z_e_x = 
self.encoder(x) 213 | latents = self.codebook(z_e_x) 214 | return latents 215 | 216 | def decode(self, latents): 217 | z_q_x = self.codebook.embedding(latents).permute(0, 3, 1, 2) # (B, D, H, W) 218 | x_tilde = self.decoder(z_q_x) 219 | return x_tilde 220 | 221 | def forward(self, x): 222 | z_e_x = self.encoder(x) 223 | z_q_x_st, z_q_x = self.codebook.straight_through(z_e_x) 224 | x_tilde = self.decoder(z_q_x_st) 225 | return x_tilde, z_e_x, z_q_x 226 | 227 | class EventVQVAE(nn.Module): 228 | ''' 229 | This class follows the descriptions in Semi-supervised New Event Type Induction and Event Detection paper. 230 | ''' 231 | def __init__(self, input_dim: int, dim: int, known_types: int, unknown_types: int, layers_n: int=1, 232 | use_vae: bool=False, vae_dim: int=1024) -> None: 233 | 234 | super().__init__() 235 | self.known_types = known_types 236 | self.unknown_types = unknown_types 237 | self.types_n = known_types + unknown_types 238 | 239 | # f_c 240 | self.encoder = MLP(input_dim, dim, dim, norm=False, layers_n=layers_n) 241 | self.codebook = VQEmbedding(K=self.types_n, D=dim) 242 | self.decoder = nn.Linear(in_features=dim, out_features=input_dim) 243 | 244 | self.use_vae = use_vae 245 | if use_vae: 246 | # f_e 247 | self.vae_encoder = MLP(input_dim, vae_dim, 2 * vae_dim, norm=False, layers_n=2) 248 | # f_r 249 | self.vae_decoder = nn.Linear(self.types_n + vae_dim, input_dim) 250 | 251 | self.apply(weights_init) 252 | 253 | def encode(self, x: torch.FloatTensor) ->Tuple[torch.FloatTensor, torch.LongTensor]: 254 | ''' 255 | :param x: (batch, input_dim) 256 | ''' 257 | z_e_x = self.encoder(x) # (batch, dim) 258 | logits, indexes = self.codebook(z_e_x) # (batch) 259 | return logits, indexes 260 | 261 | def decode(self, latents: torch.LongTensor) -> torch.FloatTensor: 262 | z_q_x = self.codebook.embedding(latents) 263 | x_tilde = self.decoder(z_q_x) 264 | return x_tilde 265 | 266 | def forward(self, x: torch.FloatTensor, known_mask: torch.LongTensor, labels=None): 
267 | z_e_x = self.encoder(x) 268 | logits, _ = self.codebook(z_e_x) 269 | # z_q_x is backpropped to embedding 270 | z_q_x_st, z_q_x = self.codebook.straight_through(z_e_x) 271 | if self.use_vae: 272 | mu, logvar = self.vae_encoder(x).chunk(2, dim=1) 273 | q_z_x = Normal(mu, logvar.mul(.5).exp()) 274 | p_z = Normal(torch.zeros_like(mu), torch.ones_like(logvar)) 275 | kl_div = kl_divergence(q_z_x, p_z).sum(1).mean() 276 | sampled_z_x = q_z_x.rsample() 277 | 278 | y_tilde = torch.softmax(logits, dim=1) 279 | y = y_tilde.clone() 280 | y[known_mask] = F.one_hot(labels[known_mask], num_classes=self.types_n).float() # replace known types with gold labels 281 | x_tilde = self.vae_decoder(torch.concat([sampled_z_x, y], dim=1)) 282 | else: 283 | x_tilde = self.decoder(z_q_x_st) 284 | kl_div = 0.0 285 | # for a pure VQ-VQE, the prior is an uniform distribution and thus the kl term is constant 286 | 287 | return x_tilde, z_e_x, z_q_x, logits, kl_div 288 | -------------------------------------------------------------------------------- /baselines/vqvae_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Dict, Tuple, Optional, Union 4 | from collections import defaultdict 5 | import pickle as pkl 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import pytorch_lightning as pl 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | from transformers import AutoConfig, AutoModel, AdamW, get_linear_schedule_with_warmup 16 | 17 | from common.metrics import ClusteringMetricsWrapper 18 | from common.predictions import ClusterPredictionsWrapper 19 | from common.utils import cluster_acc, get_label_mapping, onedim_gather 20 | from baselines.vae import EventVQVAE 21 | 22 | 23 | 24 | import common.log as log 25 | logger = log.get_logger('root') 26 | 27 | 28 | class VQVAEModel(pl.LightningModule): 29 | def __init__(self, args, tokenizer, 
train_len:int = 1000) -> None: 30 | super().__init__() 31 | self.config = args 32 | 33 | self.tokenizer = tokenizer 34 | self.model_config = AutoConfig.from_pretrained(args.model_name_or_path, output_hidden_states = True) 35 | self.pretrained_model = AutoModel.from_pretrained(args.model_name_or_path, config = self.model_config) 36 | embeddings = self.pretrained_model.resize_token_embeddings(len(self.tokenizer)) # when adding new tokens, the tokenizer.vocab_size is not changed! 37 | 38 | self.train_len=train_len 39 | 40 | self.model = EventVQVAE(self.model_config.hidden_size, dim=500, 41 | known_types=args.known_types, unknown_types=args.unknown_types, use_vae=True if args.hybrid else False) 42 | 43 | self.train_unknown_metrics_wrapper = ClusteringMetricsWrapper(stage='train', 44 | known=False, 45 | prefix='train_unknown', known_classes=args.known_types) 46 | self.val_unknown_metrics_wrapper = ClusteringMetricsWrapper(stage='val', 47 | known=False, 48 | prefix='val_unknown', known_classes=args.known_types) 49 | self.test_unknown_metrics_wrapper = ClusteringMetricsWrapper(stage='test', 50 | known=False, 51 | prefix='test_unknown', known_classes=args.known_types) 52 | 53 | self.train_known_metrics_wrapper = ClusteringMetricsWrapper(stage='train', 54 | known=True, 55 | prefix='train_known', 56 | known_classes=args.known_types) 57 | self.val_known_metrics_wrapper = ClusteringMetricsWrapper(stage='val', 58 | known=True, 59 | prefix='val_known', known_classes=args.known_types) 60 | self.test_known_metrics_wrapper = ClusteringMetricsWrapper(stage='test', 61 | known=True, 62 | prefix='test_known', known_classes=args.known_types) 63 | 64 | if args.eval_only: 65 | self.predictions_wrapper = ClusterPredictionsWrapper(reassign=True, prefix='test_unknown', 66 | known_classes=args.known_types, task=args.task) 67 | 68 | 69 | def on_validation_epoch_start(self) -> None: 70 | # reset all metrics 71 | self.train_unknown_metrics_wrapper.reset() 72 | 
self.val_unknown_metrics_wrapper.reset() 73 | self.test_unknown_metrics_wrapper.reset() 74 | 75 | self.train_known_metrics_wrapper.reset() 76 | self.val_known_metrics_wrapper.reset() 77 | self.test_known_metrics_wrapper.reset() 78 | 79 | return 80 | 81 | def on_train_epoch_start(self) -> None: 82 | return 83 | 84 | 85 | 86 | def training_step(self, batch: List[Dict[str, torch.Tensor]], batch_idx: int): 87 | ''' 88 | batch = { 89 | 'meta':List[Dict], 90 | 'token_ids': torch.LongTensor (batch, seq_len), 91 | 'attn_mask': torch.BoolTensor (batch, seq_len) 92 | 'labels': torch.LongTensor([x['label'] for x in batch]) 93 | 'head_spans': , 94 | 'tail_spans': , 95 | 'mask_bpe_idx': , 96 | 'known_mask' torch.BoolTensor 97 | } 98 | ''' 99 | view_n = len(batch) 100 | batch_size = len(batch[0]['meta']) 101 | labels = batch[0]['labels'] 102 | 103 | known_mask = batch[0]['known_mask'] # (batch, ) 104 | view = batch[0] 105 | outputs = self.pretrained_model(input_ids=view['token_ids'],attention_mask=view['attn_mask']) 106 | seq_output = outputs[0] 107 | feat = onedim_gather(seq_output, dim=1, index=view['trigger_spans'][:, 0].unsqueeze(1)).squeeze(1) 108 | # (batch, hidden_dim) 109 | x_tilde, z_e_x, z_q_x, logits, kl_div = self.model(feat, known_mask, labels) 110 | 111 | 112 | # Reconstruction loss 113 | loss_recons = F.mse_loss(x_tilde, feat) + kl_div 114 | # Vector quantization objective 115 | loss_vq = F.mse_loss(z_q_x, z_e_x.detach()) 116 | # Commitment objective 117 | loss_commit = F.mse_loss(z_e_x, z_q_x.detach()) 118 | 119 | 120 | # supervised loss 121 | loss_supervised = F.cross_entropy(logits[known_mask, :], labels[known_mask]) 122 | # unsupervised margin loss 123 | y = torch.softmax(logits, dim=1) 124 | known_prob = y[~known_mask, :self.config.known_types] 125 | unknown_prob = y[~known_mask, self.config.known_types:] 126 | diff = torch.max(known_prob, dim=1)[0] - torch.max(unknown_prob, dim=1)[0] 127 | loss_unsupervised = torch.sum(torch.clamp(diff, min=0))/diff.size(0) 
128 | 129 | loss = self.config.beta * (loss_vq + loss_commit) + loss_supervised + self.config.gamma* loss_unsupervised +\ 130 | self.config.recon_loss * loss_recons 131 | self.log('train/recon_loss', loss_recons) 132 | self.log('train/vq_loss', loss_vq+loss_commit) 133 | self.log('train/ss_loss', loss_supervised + loss_unsupervised) 134 | self.log('train/supervised_loss', loss_supervised) 135 | self.log('train/unsupervised_loss', loss_unsupervised) 136 | 137 | return loss 138 | 139 | 140 | def validation_step(self, batch: List[Dict[str, torch.Tensor]], batch_idx: int, dataloader_idx: int)-> Dict[str, torch.Tensor]: 141 | ''' 142 | :param dataloader_idx: 0 for unknown_train, 1 for unknown_test, 2 for known_test 143 | ''' 144 | view_n = len(batch) 145 | batch_size = len(batch[0]['meta']) 146 | view = batch[0] 147 | known_mask = batch[0]['known_mask'] # (batch, ) 148 | 149 | outputs = self.pretrained_model(input_ids=view['token_ids'],attention_mask=view['attn_mask']) 150 | seq_output = outputs[0] 151 | feat = onedim_gather(seq_output, dim=1, index=view['trigger_spans'][:, 0].unsqueeze(1)).squeeze(1) 152 | # (batch, hidden_dim) 153 | logits, indexes = self.model.encode(feat) 154 | 155 | # use the gt labels to compute metrics 156 | # for this model, do not evaluation on known types 157 | # setting incremental to True will not subtract the number of known types 158 | if dataloader_idx == 0: 159 | self.train_unknown_metrics_wrapper.update_batch(logits, batch[0]['labels'], incremental=False) 160 | elif dataloader_idx == 1: 161 | self.val_unknown_metrics_wrapper.update_batch(logits, batch[0]['labels'], incremental=False) 162 | else: 163 | self.val_known_metrics_wrapper.update_batch(logits, batch[0]['labels'], incremental=False) 164 | return {} 165 | 166 | def validation_epoch_end(self, outputs: List[List[Dict]]) -> None: 167 | val_unknown_metrics = self.val_unknown_metrics_wrapper.on_epoch_end() 168 | train_unknown_metrics = self.train_unknown_metrics_wrapper.on_epoch_end() 
169 | val_known_metrics = self.val_known_metrics_wrapper.on_epoch_end() 170 | 171 | for k,v in val_unknown_metrics.items(): 172 | self.log(f'val/unknown_{k}',value=v, logger=True, on_step=False, on_epoch=True) 173 | 174 | for k,v in train_unknown_metrics.items(): 175 | self.log(f'train/unknown_{k}', value=v, logger=True, on_step=False, on_epoch=True) 176 | 177 | for k,v in val_known_metrics.items(): 178 | self.log(f'val/known_{k}',value=v, logger=True, on_step=False, on_epoch=True) 179 | return 180 | 181 | 182 | def test_step(self, batch: List[Dict], batch_idx: int) -> Dict: 183 | ''' 184 | :param dataloader_idx: 0 for unknown_test 185 | ''' 186 | view_n = len(batch) 187 | batch_size = len(batch[0]['meta']) 188 | view = batch[0] 189 | known_mask = batch[0]['known_mask'] # (batch, ) 190 | 191 | outputs = self.pretrained_model(input_ids=view['token_ids'],attention_mask=view['attn_mask']) 192 | seq_output = outputs[0] 193 | feat = onedim_gather(seq_output, dim=1, index=view['trigger_spans'][:, 0].unsqueeze(1)).squeeze(1) 194 | logits, indexes = self.model.encode(feat) 195 | # use the gt labels to compute metrics 196 | self.test_unknown_metrics_wrapper.update_batch(logits, batch[0]['labels'], incremental=False) 197 | 198 | self.predictions_wrapper.update_batch(batch[0]['meta'], logits, batch[0]['labels'], incremental=False) 199 | 200 | return {} 201 | 202 | 203 | def test_epoch_end(self, outputs: List[Dict]) -> None: 204 | test_unknown_metrics = self.test_unknown_metrics_wrapper.on_epoch_end() 205 | for k,v in test_unknown_metrics.items(): 206 | self.log(f'test/unknown_{k}',value=v, logger=True, on_step=False, on_epoch=True) 207 | 208 | self.test_unknown_metrics_wrapper.save(self.config.ckpt_dir) 209 | 210 | self.predictions_wrapper.on_epoch_end() 211 | self.predictions_wrapper.save(self.config.ckpt_dir) 212 | 213 | return 214 | 215 | 216 | def configure_optimizers(self): 217 | no_decay = ["bias", "LayerNorm.weight"] 218 | optimizer_grouped_parameters = [ 219 | { 220 | 
"params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)], 221 | "weight_decay": self.config.weight_decay, 222 | }, 223 | {"params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 224 | ] 225 | 226 | 227 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.config.learning_rate, eps=self.config.adam_epsilon) 228 | 229 | if self.config.max_steps > 0: 230 | t_total = self.config.max_steps 231 | self.config.num_train_epochs = self.config.max_steps // self.train_len // self.config.accumulate_grad_batches + 1 232 | else: 233 | t_total = self.train_len // self.config.accumulate_grad_batches * self.config.num_train_epochs 234 | 235 | logger.info('{} training steps in total.. '.format(t_total)) 236 | 237 | 238 | # scheduler is called only once per epoch by default 239 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.config.warmup_steps, num_training_steps=t_total) 240 | scheduler_dict = { 241 | 'scheduler': scheduler, 242 | 'interval': 'step', 243 | 'name': 'linear-schedule', 244 | } 245 | 246 | return [optimizer, ], [scheduler_dict,] 247 | 248 | -------------------------------------------------------------------------------- /common/metrics.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict, Tuple, List 3 | import json 4 | import os 5 | 6 | import torch 7 | import numpy as np 8 | 9 | from sklearn.metrics.cluster import homogeneity_completeness_v_measure, adjusted_rand_score, normalized_mutual_info_score,fowlkes_mallows_score 10 | from sklearn.cluster import SpectralClustering 11 | from torchmetrics import Metric, MetricCollection 12 | import networkx as nx 13 | import community 14 | 15 | from common.utils import cluster_acc 16 | from common.b3 import calc_b3 17 | import common.log as log 18 | logger = log.get_logger('root') 19 | ''' 20 | Clustering metrics: 
21 | B-cube 22 | V-measure 23 | Adjusted Rand Index 24 | ''' 25 | 26 | 27 | class PairwiseClusteringMetricsWrapper(torch.nn.Module): 28 | def __init__(self, stage: str='train', prefix: str='', 29 | cluster_n: int=10, clustering_method: str='louvain') -> None: 30 | ''' 31 | :param clustering method: str, one of louvain, hac 32 | ''' 33 | super().__init__() 34 | self.stage = stage 35 | self.prefix = prefix 36 | 37 | self.target_cluster_cache = [] 38 | self.metrics = None 39 | self.clustering_method = clustering_method 40 | self.cluster_n = cluster_n 41 | 42 | def get_pred_cluster(self, pred_metric:Dict, labels: Dict, 43 | clustering_method: str, merge_iso:bool=True, iso_thres:int=5 )->Dict: 44 | ''' 45 | :param pred_metric: This is actually a distance, the smaller the metric, the closer 46 | ''' 47 | N = len(labels) 48 | uid2idx = {k:idx for idx, k in enumerate(labels.keys())} 49 | idx2uid = {v:k for k,v in uid2idx.items()} 50 | if clustering_method == 'louvain': 51 | g = nx.Graph() 52 | g.add_nodes_from(labels.keys()) 53 | 54 | for k, v in pred_metric.items(): 55 | if round(v) == 0: 56 | g.add_edge(k[0], k[1]) 57 | 58 | partition = community.best_partition(g) 59 | pred_cluster = {k: partition[k] for k in labels} 60 | elif clustering_method == 'spectral': 61 | sim_matrix = np.zeros((N, N)) 62 | for idx in range(N): 63 | for other_idx in range(idx, N): 64 | key = (idx2uid[idx], idx2uid[other_idx]) 65 | rev_key = (idx2uid[other_idx], idx2uid[idx]) 66 | sim_matrix[idx, other_idx] = 1- pred_metric[key] # dist to similarity 67 | sim_matrix[other_idx, idx] = 1- pred_metric[rev_key] 68 | 69 | clustering = SpectralClustering(n_clusters=self.cluster_n, affinity='precomputed', assign_labels='discretize').fit(sim_matrix) 70 | pred_cluster = {idx2uid[idx]: int(clustering.labels_[idx]) for idx in range(N)} 71 | 72 | else: 73 | raise NotImplementedError(f'clustering method {clustering_method} not supported.') 74 | if merge_iso: 75 | cluster2uid = defaultdict(list) 76 | for k,v in 
pred_cluster.items(): 77 | cluster2uid[v].append(k) 78 | cluster_sizes = {k: len(v) for k,v in cluster2uid.items()} 79 | iso_clus = [k for k in cluster_sizes if cluster_sizes[k] <=iso_thres] 80 | 81 | # reassign the data points in the small clusters 82 | for clus_idx in iso_clus: 83 | for uid in cluster2uid[clus_idx]: 84 | all_dist = np.zeros((N,), dtype=np.float32) 85 | for other_uid in labels.keys(): 86 | all_dist[uid2idx[other_uid]] = pred_metric[(uid, other_uid)] 87 | search_idx_list = np.argsort(all_dist) 88 | for other_idx in search_idx_list: 89 | other_uid = idx2uid[other_idx] 90 | if pred_cluster[other_uid] not in iso_clus: 91 | pred_cluster[uid] = pred_cluster[other_uid] 92 | break 93 | 94 | 95 | return pred_cluster 96 | 97 | def on_epoch_end(self, pred_metric, labels)->Tuple[Dict[str, float], Dict[str, int]]: 98 | pred_cluster = self.get_pred_cluster(pred_metric, labels, self.clustering_method) 99 | # convert dict to numpy array 100 | pred_cluster_list = [] 101 | target_cluster_list = [] 102 | for k,v in pred_cluster.items(): 103 | pred_cluster_list.append(v) 104 | target_cluster_list.append(labels[k]) 105 | 106 | all_pred_cluster = np.array(pred_cluster_list) 107 | all_target_cluster = np.array(target_cluster_list) 108 | 109 | 110 | b3_metrics = calc_b3(np.expand_dims(all_pred_cluster,axis=0), np.expand_dims(all_target_cluster, axis=0)) 111 | ari = adjusted_rand_score(all_target_cluster, all_pred_cluster) 112 | v_hom, v_comp, v_f1 = homogeneity_completeness_v_measure(all_target_cluster, all_pred_cluster) 113 | nmi = normalized_mutual_info_score(all_target_cluster, all_pred_cluster) 114 | fm = fowlkes_mallows_score(all_target_cluster, all_pred_cluster) 115 | metrics = { 116 | 'b3_f1': b3_metrics[0], 117 | 'b3_prec': b3_metrics[1], 118 | 'b3_recall': b3_metrics[2], 119 | 'ARI': ari, 120 | 'homogeneity': v_hom, 121 | 'completeness': v_comp, 122 | 'v_measure': v_f1, 123 | 'NMI': nmi, 124 | 'fowlkes_mallows': fm 125 | } 126 | 127 | acc = 
cluster_acc(all_target_cluster, all_pred_cluster, reassign=True) 128 | 129 | metrics['acc'] = float(acc) 130 | self.metrics = metrics 131 | 132 | return metrics, pred_cluster 133 | 134 | 135 | 136 | def save(self, ckpt_dir: str): 137 | if self.metrics: 138 | metrics = self.metrics 139 | else: 140 | metrics = self.on_epoch_end() 141 | 142 | output_dict = {k: v for k,v in metrics.items()} 143 | with open(os.path.join(ckpt_dir, f'{self.prefix}_metrics.json'),'w') as f: 144 | json.dump(output_dict, f, indent=2) 145 | 146 | return 147 | 148 | class ClusteringMetricsWrapper(torch.nn.Module): 149 | def __init__(self, stage: str='train', known: bool=False, prefix: str='', known_classes:int=31) -> None: 150 | ''' 151 | :param reassign: whether to compute the cluster assignment. Set to true for unknown classes. 152 | ''' 153 | super().__init__() 154 | self.stage = stage 155 | self.prefix = prefix 156 | self.known_classes = known_classes 157 | self.known = known 158 | self.reassign = False if known else True 159 | 160 | self.pred_cluster_cache = [] 161 | self.target_cluster_cache = [] 162 | 163 | self.metrics = None 164 | 165 | def reset(self): 166 | self.pred_cluster_cache = [] 167 | self.target_cluster_cache = [] 168 | 169 | def update_batch(self, logits: torch.FloatTensor, targets: torch.LongTensor, incremental:bool=False )->None: 170 | ''' 171 | :param logits: (batch, n_cluster) 172 | :param targets: (batch) 173 | ''' 174 | if not incremental: 175 | if not self.known: 176 | # consider only unknown classes 177 | _, pred_cluster = torch.max(logits[:, self.known_classes:], dim=1) 178 | self.pred_cluster_cache.append(pred_cluster) 179 | self.target_cluster_cache.append(targets-self.known_classes) 180 | else: 181 | # consider only known classes 182 | _, pred_cluster = torch.max(logits[:, :self.known_classes], dim=1) 183 | self.pred_cluster_cache.append(pred_cluster) 184 | self.target_cluster_cache.append(targets) 185 | else: 186 | _, pred_cluster = torch.max(logits, dim=1) 187 
| self.pred_cluster_cache.append(pred_cluster) 188 | self.target_cluster_cache.append(targets) 189 | return 190 | 191 | 192 | def on_epoch_end(self)-> Dict[str, float]: 193 | all_pred_cluster = torch.concat(self.pred_cluster_cache).cpu().numpy() 194 | all_target_cluster = torch.concat(self.target_cluster_cache).cpu().numpy() 195 | b3_metrics = calc_b3(np.expand_dims(all_pred_cluster,axis=0), np.expand_dims(all_target_cluster, axis=0)) 196 | ari = adjusted_rand_score(all_target_cluster, all_pred_cluster) 197 | v_hom, v_comp, v_f1 = homogeneity_completeness_v_measure(all_target_cluster, all_pred_cluster) 198 | nmi = normalized_mutual_info_score(all_target_cluster, all_pred_cluster) 199 | fm = fowlkes_mallows_score(all_target_cluster, all_pred_cluster) 200 | metrics = { 201 | 'b3_f1': b3_metrics[0], 202 | 'b3_prec': b3_metrics[1], 203 | 'b3_recall': b3_metrics[2], 204 | 'ARI': ari, 205 | 'homogeneity': v_hom, 206 | 'completeness': v_comp, 207 | 'v_measure': v_f1, 208 | 'NMI': nmi, 209 | 'fowlkes_mallows': fm 210 | } 211 | 212 | acc = cluster_acc(all_target_cluster, all_pred_cluster, reassign=self.reassign) 213 | 214 | metrics['acc'] = float(acc) 215 | self.metrics = metrics 216 | 217 | return metrics 218 | 219 | 220 | 221 | def save(self, ckpt_dir: str): 222 | if self.metrics: 223 | metrics = self.metrics 224 | else: 225 | metrics = self.on_epoch_end() 226 | 227 | output_dict = {k: v for k,v in metrics.items()} 228 | with open(os.path.join(ckpt_dir, f'{self.prefix}_metrics.json'),'w') as f: 229 | json.dump(output_dict, f, indent=2) 230 | 231 | return 232 | 233 | class PsuedoLabelMetricWrapper(torch.nn.Module): 234 | def __init__(self, prefix: str='', cache_size:int = 1024, known_classes: int = 31, unknown_classes: int =10) -> None: 235 | super().__init__() 236 | 237 | self.prefix = prefix 238 | self.cache_size = cache_size 239 | self.known_classes= known_classes 240 | self.unknown_classes = unknown_classes 241 | 242 | self.register_buffer('predicted_cache', 
torch.zeros((cache_size), dtype=torch.long), persistent=False) 243 | self.register_buffer('label_cache', torch.zeros((cache_size), dtype=torch.long), persistent=False) 244 | self.cur_len =0 245 | 246 | 247 | def compute_metric(self): 248 | # this operation will slow down the training, use only for diagnosis. 249 | label_array = self.label_cache[:self.cur_len].cpu().numpy() 250 | predicted_array = self.predicted_cache[:self.cur_len].cpu().numpy() 251 | acc = cluster_acc(label_array,predicted_array, reassign=True) 252 | ari = adjusted_rand_score(label_array, predicted_array) 253 | return acc, ari 254 | 255 | 256 | def update_batch(self, psuedo_labels: torch.FloatTensor, labels: torch.LongTensor): 257 | ''' 258 | :param pl: (batch, M+N) 259 | :param labels: (batch) over M+N 260 | ''' 261 | batch_size = psuedo_labels.size(0) 262 | 263 | _, predicted = torch.max(psuedo_labels[:, self.known_classes:], dim=1) 264 | new_cache = torch.concat([predicted, self.predicted_cache], dim=0) 265 | self.predicted_cache = new_cache[:self.cache_size] 266 | self.cur_len += batch_size 267 | self.cur_len = min(self.cur_len, self.cache_size) 268 | 269 | new_label_cache = torch.concat([labels - self.known_classes, self.label_cache], dim=0) 270 | self.label_cache = new_label_cache[:self.cache_size] 271 | 272 | return 273 | 274 | -------------------------------------------------------------------------------- /data_sample/fewrel/relation_description.csv: -------------------------------------------------------------------------------- 1 | P6:head_of_gov location person/politician head of government is head of the executive power of this town, city, municipality, state, country, or other governmental body 2 | P17:country location location/country country is sovereign state of this item; don’t use on humans 3 | P22:father person person/man father is male parent of the subject. 
For stepfather, use ”stepparent” 4 | P27:citizenship person location/country country of citizenship is the object is a country that recognizes the subject as its citizen 5 | P31:instance_of object object instance of is that class of which this subject is a particular example and member. (Subject typically an individual member with Proper Name label.) 6 | P39:position person title position held is subject currently or formerly holds the object position or public office 7 | P57:director work_of_art/movie person/director director is director(s) of this motion picture, TV-series, stageplay, video game or similar 8 | P58:screenwriter work_of_art/movie person/screenwriter screenwriter is author(s) of the screenplay or script for this work 9 | P84:architect facility/building person/architect architect is person or architectural firm that designed this building 10 | P86:composer work_of_art/music person/composer composer is person(s) who wrote the music; also use P676 for lyricist 11 | P101:field_of_work person title/specialization field of work is specialization of a person or organization; see P106 for the occupation 12 | P102:member_of_party person organization/party member of political party is the political party of which this politician is or has been a member 13 | P105:taxon_rank category rank taxon rank is level in a taxonomic hierarchy 14 | P106:occupation person title/occupation occupation is occupation of a person; see also ”field of work” (Property:P101), ”position held” (Property:P39) 15 | P118:league person/athlete league league is league in which team or player plays or has played in 16 | P123:publisher work_of_art/book organization/publisher publisher is organization or person responsible for publishing books, periodicals, games or software 17 | P127:owned_by object object owned by is owner of the subject 18 | P131:located_in location location located in the administrative territorial entity is the item is located on the territory of the following 
administrative entity. Use P276 (location) for specifying the location of non-administrative places and for items about events 19 | P135:movement person event/movement movement is literary, artistic, scientific or philosophical movement associated with this person or work 20 | P136:genre person/artist category/genre genre is creative work’s genre or an artist’s field of work (P101). Use main subject (P921) to relate creative works to their topic 21 | P137:operator facility person|organization operator is person or organization that operates the equipment, facility, or service; use country for diplomatic missions 22 | P140:religion person|organization religion religion is religion of a person, organization or religious building, or associated with this subject 23 | P150:contains location location contains administrative territorial entity is (list of) direct subdivisions of an administrative territorial entity 24 | P156:followed_by object object followed by is immediately following item in some series of which the subject is part. Use P1366 (replaced by) if the item is replaced, e.g. political offices, states 25 | P159:headquarters organization location headquarters location is specific location where an organization’s headquarters is or has been situated 26 | P175:performer work_of_art person/artist performer is performer involved in the performance or the recording of a work 27 | P176:manufacturer product organization/company manufacturer is manufacturer or producer of this product 28 | P178:developer product person|organization developer is organisation or person that developed this item 29 | P241:military_branch person organization/military military branch is branch to which this military unit, award, office, or person belongs, e.g. 
Royal Navy 30 | P264:record_label person/musician organization/record_label record label is brand and trademark associated with the marketing of subject music recordings and music videos 31 | P276:location event location location is location of the item, physical object or event is within. In case of an administrative entity use P131. In case of a distinct terrain feature use P706. 32 | P306:operating_system product/software product/operating_system operating system is operating system (OS) on which a software works or the OS installed on hardware 33 | P355:subsidiary organization organization subsidiary is subsidiary of a company or organization, opposite of parent company (P749) 34 | P400:platform product/software product/platform platform is platform for which a work has been developed or released / specific platform version of a software developed 35 | P403:mouth_of_watercourse location/body_of_water location/body_of_water mouth of the watercourse is the body of water to which the watercourse drains 36 | P407:language work language language of work or name is language associated with this work or name (for persons use P103 and P1412) 37 | P449:original_network work_of_art/show media/television_network original network is network(s) the radio or television show was originally aired on, including 38 | P460:same_as object object said to be the same as is this item is said to be the same as that item, but the statement is disputed 39 | P466:occupant facility person|organization occupant is person or organization occupying a building or facility 40 | P495:country_of_origin location/country work_of_art country of origin is country of origin of the creative work or subject item 41 | P527:has_part object object has part is part of this subject. Inverse property of ”part of” (P361). 
42 | P551:residence person location residence is the place where the person is, or has been, resident 43 | P674:characters work_of_art person/character characters is characters which appear in this item (like plays, operas, operettas, books, comics, films, TV series, video games) 44 | P706:located_on_terrain location location located on terrain feature is located on the specified landform. Should not be used when the value is only political/administrative (provinces, states, countries, etc.). Use P131 for administrative entity. 45 | P710:participant participant event participant is person, group of people or organization (object) that actively takes/took part in the event (subject). Preferably qualify with ”object has role” (P3831). Use P1923 for team participants. 46 | P740:location_of_formation location organization location of formation is location where a group or organization was formed 47 | P750:distributor work_of_art organization/distributor distributor is distributor of a creative work 48 | P800:notable_work person work notable work is notable scientific, artistic or literary work, or other work of significance among subject’s works 49 | P931:place_served facility/transport_hub location place served by transport hub is city or region served by this transport hub (airport, train station, etc.) 50 | P937:work_location person location work location is location where persons were active 51 | P974:tributary location/body_of_water location/body_of_water tributary is stream or river that flows into this main stem (or parent) river 52 | P991:successful_candidate event/election person/candidate successful candidate is person(s) elected after the election 53 | P1001:jurisdiction legal_entity location applies to jurisdiction is the item (an institution, law, public office ...) belongs to or has power over or applies to the value (a territorial jurisdiction: a country, state, municipality, ...) 
54 | P1303:instrument person instrument instrument is musical instrument that a person plays 55 | P1344:participant person event participant of is event a person or an organization was a participant in, inverse of P710 or P1923 56 | P1346:winner event person winner is winner of an event - do not use for wars or battles 57 | P1408:licensed media location licensed to broadcast to is place that a radio/TV station is licensed/required to broadcast to 58 | P1411:nominated_for candidate title/award nominated for is award nomination received by a person, organisation or creative work (inspired from ”award received” Property:P166)) 59 | P1435:heritage location/heritage_site title/designation heritage designation is heritage designation of a historical site 60 | P1877:after_work work_of_art person/author after a work by is artist whose work strongly inspired/ was copied in this item 61 | P1923:participating_teams event/competition person/team participating teams is Like ’Participant’ (P710) but for teams. For an event like a cycle race or a football match you can use this property to list the teams and P710 to list the individuals (with ’member of sports team’ (P54)’ as a qualifier for the individuals) 62 | P3373:sibling person person sibling is the subject has the object as their sibling (brother, sister, etc.). Use ”relative” (P1038) for siblings-in-law (brother-in-law, sister-in-law, etc.) and step-siblings (step-brothers, step-sisters, etc.) 63 | P3450:sports_season date event/sports_competition sports season of is property that shows the competition of which the item is a season 64 | P4552:mountain_range location/peak location/mountain mountain range is range or subrange to which the geographical item belongs 65 | P25:mother person person/woman mother is female parent of the subject. For stepmother, use ”stepparent” (P3448) 66 | P26:spouse person person spouse is the subject has the object as their spouse (husband, wife, partner, etc.). 
Use ”unmarried partner” (P451) for non-married companions 67 | P40:child person person child is subject has the object in their family as their offspring son or daughter (independently of their age) 68 | P59:constellation space_object/star space_object/constellation constellation is the area of the celestial sphere of which the subject is a part (from a scientific standpoint, not an astrological one) 69 | P155:follows successor predecessor follows is immediately prior item in some series of which the subject is part. Use P1365 (replaces) if the preceding item was replaced, e.g. political offices, states and there is no identity between precedent and following geographic unit 70 | P177:crosses facility/bridge location/body_of_water crosses is obstacle (body of water, road, ...) which this bridge crosses over or this tunnel goes under 71 | P206:located_near_water location location/body_of_water located in or next to body of water is sea, lake or river 72 | P361:part_of part whole part of is object of which the subject is a part. Inverse property of ”has part” (P527). See also ”has parts of the class” (P2670). 73 | P364:original_language work language original language of work is language in which a film or a performance work was originally created. Deprecated for written works; use P407 (”language of work or name”) instead. 74 | P410:military_rank person title/military_rank military rank is military rank achieved by a person (should usually have a ”start date” qualifier), or military rank associated with a position 75 | P412:voice_type person/singer category/voice_type voice type is person’s voice type. expected values: soprano, mezzo-soprano, contralto, countertenor, tenor, baritone, bass (and derivatives) 76 | P413:position_played person title/sports_position position played on team / speciality is position or specialism of a player on a team, e.g. Small Forward 77 | P463:member_org member organization member of is organization or club to which the subject belongs. 
Do not use for membership in ethnic or social groups, nor for holding a position such as a member of parliament (use P39 for that). 78 | P641:sport person/athlete sport sport is sport in which the entity participates or belongs to 79 | P921:subject work topic main subject is primary topic of a work (see also P180: depicts) 80 | P2094:competition_class person/athlete category/competition_class competition class is official classification by a regulating body under which the subject (events, teams, participants, or equipment) qualifies for inclusion -------------------------------------------------------------------------------- /baselines/RoCORE_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Dict, Tuple, Optional, Union 4 | from collections import defaultdict 5 | import pickle as pkl 6 | 7 | from tqdm import tqdm 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import pytorch_lightning as pl 12 | from sklearn.cluster import KMeans 13 | import numpy as np 14 | 15 | from transformers import AutoConfig, AutoModel, AdamW, get_linear_schedule_with_warmup 16 | 17 | 18 | 19 | from common.metrics import ClusteringMetricsWrapper, PsuedoLabelMetricWrapper 20 | from common.predictions import ClusterPredictionsWrapper 21 | from .RoCORE_layers import ZeroShotModel, L2Reg, compute_kld 22 | 23 | import common.log as log 24 | logger = log.get_logger('root') 25 | 26 | class RoCOREModel(pl.LightningModule): 27 | def __init__(self, args, tokenizer, train_len:int = 1000) -> None: 28 | super().__init__() 29 | self.config = args 30 | 31 | self.tokenizer = tokenizer 32 | self.model_config = AutoConfig.from_pretrained(args.model_name_or_path, output_hidden_states = True) 33 | pretrained_model = AutoModel.from_pretrained(args.model_name_or_path, config = self.model_config) 34 | embeddings = pretrained_model.resize_token_embeddings(len(self.tokenizer)) 35 
| self.train_len=train_len # this is required to set up the optimizer 36 | 37 | self.net = ZeroShotModel(args, args.known_types, args.unknown_types, self.model_config, pretrained_model, unfreeze_layers = [args.layer]) 38 | self.train_unknown_metrics_wrapper = ClusteringMetricsWrapper(stage='train', 39 | known=False, 40 | prefix='train_unknown', known_classes=0) # no shift 41 | self.val_unknown_metrics_wrapper = ClusteringMetricsWrapper(stage='val', 42 | known=False, 43 | prefix='val_unknown', known_classes=0) 44 | self.test_unknown_metrics_wrapper = ClusteringMetricsWrapper(stage='test', 45 | known=False, 46 | prefix='test_unknown', known_classes=0) 47 | 48 | self.train_known_metrics_wrapper = ClusteringMetricsWrapper(stage='train', 49 | known=True, 50 | prefix='train_known', 51 | known_classes=args.known_types) 52 | self.val_known_metrics_wrapper = ClusteringMetricsWrapper(stage='val', 53 | known=True, 54 | prefix='val_known', known_classes=args.known_types) 55 | self.test_known_metrics_wrapper = ClusteringMetricsWrapper(stage='test', 56 | known=True, 57 | prefix='test_known', known_classes=args.known_types) 58 | 59 | if args.eval_only: 60 | self.predictions_wrapper = ClusterPredictionsWrapper(reassign=True, prefix='test_unknown', 61 | known_classes=0) 62 | 63 | def forward(self, inputs): 64 | pass 65 | 66 | def on_validation_epoch_start(self) -> None: 67 | # reset all metrics 68 | self.train_unknown_metrics_wrapper.reset() 69 | self.val_unknown_metrics_wrapper.reset() 70 | self.test_unknown_metrics_wrapper.reset() 71 | 72 | self.train_known_metrics_wrapper.reset() 73 | self.val_known_metrics_wrapper.reset() 74 | self.test_known_metrics_wrapper.reset() 75 | 76 | return 77 | 78 | def on_train_epoch_start(self): 79 | logger.info('updating cluster centers....') 80 | train_dl = self.trainer.train_dataloader.loaders 81 | known_centers = torch.zeros(self.config.known_types, self.config.kmeans_dim, device = self.device) 82 | num_samples = [0] * self.config.known_types 83 
| with torch.no_grad(): 84 | uid2pl = {} # pseudo labels 85 | unknown_uid_list = [] 86 | unknown_vec_list = [] 87 | seen_uid = set() # oversampling for the unknown part, so we remove them here 88 | for batch in tqdm(iter(train_dl)): 89 | labels = batch[0]['labels'] 90 | known_mask = batch[0]['known_mask'] 91 | metadata = batch[0]['meta'] 92 | batch_size = len(metadata) 93 | 94 | # move batch to gpu 95 | for key in ['token_ids', 'attn_mask','head_spans','tail_spans']: 96 | batch[0][key] = batch[0][key].to(self.device) 97 | 98 | commonspace_rep = self.net.forward(batch[0], msg = 'similarity') # (batch_size, hidden_dim) 99 | for i in range(batch_size): 100 | if known_mask[i] == True: 101 | l = labels[i] 102 | known_centers[l] += commonspace_rep[i] 103 | num_samples[l] += 1 104 | else: 105 | uid = metadata[i]['uid'] 106 | if uid not in seen_uid: 107 | seen_uid.add(uid) 108 | unknown_uid_list.append(uid) 109 | unknown_vec_list.append(commonspace_rep[i].cpu().numpy()) 110 | 111 | # cluster unknown classes 112 | clf = KMeans(n_clusters=self.config.unknown_types,random_state=0,algorithm='full') 113 | rep = np.stack(unknown_vec_list, axis=0) 114 | label_pred = clf.fit_predict(rep)# from 0 to args.new_class - 1 115 | self.net.ct_loss_u.centers = torch.from_numpy(clf.cluster_centers_).to(self.device)# (num_class, kmeans_dim) 116 | for i in range(len(unknown_vec_list)): 117 | uid = unknown_uid_list[i] 118 | pseudo = label_pred[i] 119 | uid2pl[uid] = pseudo + self.config.known_types 120 | 121 | 122 | train_dl.dataset.update_pseudo_labels(uid2pl) 123 | logger.info('updating pseudo labels...') 124 | pl_acc = train_dl.dataset.check_pl_acc() 125 | self.log('train/pl_acc', pl_acc, on_epoch=True) 126 | 127 | 128 | # update center for known types 129 | for c in range(self.config.known_types): 130 | known_centers[c] /= num_samples[c] 131 | self.net.ct_loss_l.centers = known_centers 132 | return 133 | 134 | 135 | 136 | 137 | def _compute_unknown_margin_loss(self, batch: Dict[str, 
torch.Tensor], pseudo_labels: torch.LongTensor, known_mask: torch.BoolTensor) -> torch.FloatTensor: 138 | # convert 1d pseudo label into 2d pairwise pseudo label 139 | assert (torch.min(pseudo_labels) >= self.config.known_types) 140 | 141 | pair_label = (pseudo_labels.unsqueeze(0) == pseudo_labels.unsqueeze(1)).float() 142 | logits = self.net.forward(batch, mask=~known_mask, msg = 'unlabeled') # (batch_size, new_class) 143 | # this only predicts over new classes 144 | unknown_batch_size = pseudo_labels.size(0) 145 | expanded_logits = logits.expand(unknown_batch_size, -1, -1) 146 | expanded_logits2 = expanded_logits.transpose(0, 1) 147 | kl1 = compute_kld(expanded_logits.detach(), expanded_logits2) 148 | kl2 = compute_kld(expanded_logits2.detach(), expanded_logits) # (batch_size, batch_size) 149 | assert kl1.requires_grad 150 | unknown_class_loss = torch.mean(pair_label * (kl1 + kl2) + (1 - pair_label) * (torch.relu(self.config.sigmoid - kl1) + torch.relu(self.config.sigmoid - kl2))) 151 | return unknown_class_loss 152 | 153 | 154 | def _compute_reconstruction_loss(self, batch: Dict[str, torch.Tensor], known_mask: torch.BoolTensor, labels: torch.LongTensor) -> torch.FloatTensor: 155 | # recon loss for known classes 156 | commonspace_rep_known, rec_loss_known = self.net.forward(batch, mask=known_mask, msg = 'reconstruct') # (batch_size, kmeans_dim) 157 | # recon loss for unknown classes 158 | _ , rec_loss_unknown = self.net.forward(batch,mask=~known_mask, msg = 'reconstruct') # (batch_size, kmeans_dim) 159 | reconstruction_loss = (rec_loss_known.mean() + rec_loss_unknown.mean()) / 2 160 | # center loss for known classes 161 | center_loss = self.config.center_loss * self.net.ct_loss_l(labels[known_mask], commonspace_rep_known) 162 | l2_reg = 1e-5 * (L2Reg(self.net.similarity_encoder) + L2Reg(self.net.similarity_decoder)) 163 | loss = reconstruction_loss + center_loss + l2_reg 164 | return loss 165 | 166 | 167 | def _compute_ce_loss(self, batch: Dict[str, 
torch.Tensor], known_mask: torch.BoolTensor, labels: torch.LongTensor) -> torch.FloatTensor: 168 | ''' 169 | Cross entropy loss for known classes. 170 | ''' 171 | known_logits = self.net.forward(batch, mask=known_mask, msg = 'labeled') # single layer labeled head 172 | _, label_pred = torch.max(known_logits, dim = -1) 173 | known_label = labels[known_mask] 174 | acc = 1.0 * torch.sum(label_pred == known_label) / len(label_pred) 175 | ce_loss = F.cross_entropy(input = known_logits, target = known_label) 176 | return ce_loss, acc 177 | 178 | 179 | def training_step(self, batch: List[Dict[str, torch.Tensor]], batch_idx:int): 180 | ''' 181 | batch = { 182 | 'meta':List[Dict], 183 | 'token_ids': torch.LongTensor (batch, seq_len), 184 | 'attn_mask': torch.BoolTensor (batch, seq_len) 185 | 'labels': torch.LongTensor([x['label'] for x in batch]) 186 | 'head_spans': , 187 | 'tail_spans': , 188 | 'mask_bpe_idx': , 189 | 'known_mask' torch.BoolTensor 190 | } 191 | 192 | For RoCORE: data = (input_ids, input_mask, label, head_span, tail_span) 193 | ''' 194 | view_n = len(batch) 195 | batch_size = len(batch[0]['meta']) 196 | labels = batch[0]['labels'] 197 | known_mask = batch[0]['known_mask'] # (batch, ) 198 | psuedo_labels = batch[0]['pseudo_labels'] 199 | 200 | loss = self._compute_reconstruction_loss(batch[0], known_mask, labels) 201 | margin_loss = self._compute_unknown_margin_loss(batch[0], psuedo_labels[~known_mask], known_mask) 202 | ce_loss, acc = self._compute_ce_loss(batch[0], known_mask, labels) 203 | if self.current_epoch >= self.config.num_pretrain_epochs: 204 | loss += margin_loss 205 | loss += ce_loss 206 | 207 | self.log('train/unknown_margin_loss', margin_loss) 208 | self.log('train/known_ce_loss', ce_loss) 209 | 210 | 211 | self.log('train/known_acc', acc) 212 | self.log('train/loss', loss) 213 | return loss 214 | 215 | def validation_step(self, batch: List[Dict[str, torch.Tensor]], batch_idx:int, dataloader_idx:int): 216 | 217 | view_n = len(batch) 218 | 
batch_size = len(batch[0]['meta']) 219 | 220 | 221 | # use the gt labels to compute metrics 222 | if dataloader_idx == 0: 223 | logits = self.net.forward(batch[0],msg = 'unlabeled') 224 | self.train_unknown_metrics_wrapper.update_batch(logits, batch[0]['labels']- self.config.known_types , incremental=False) 225 | elif dataloader_idx == 1: 226 | logits = self.net.forward(batch[0], msg = 'unlabeled') 227 | self.val_unknown_metrics_wrapper.update_batch(logits, batch[0]['labels'] - self.config.known_types, incremental=False) 228 | else: 229 | logits = self.net.forward(batch[0], msg = 'labeled') 230 | self.val_known_metrics_wrapper.update_batch(logits, batch[0]['labels'], incremental=False) 231 | return 232 | 233 | def validation_epoch_end(self, outputs: List[List[Dict]]) -> None: 234 | val_unknown_metrics = self.val_unknown_metrics_wrapper.on_epoch_end() 235 | train_unknown_metrics = self.train_unknown_metrics_wrapper.on_epoch_end() 236 | val_known_metrics = self.val_known_metrics_wrapper.on_epoch_end() 237 | 238 | for k,v in val_unknown_metrics.items(): 239 | self.log(f'val/unknown_{k}',value=v, logger=True, on_step=False, on_epoch=True) 240 | 241 | for k,v in train_unknown_metrics.items(): 242 | self.log(f'train/unknown_{k}', value=v, logger=True, on_step=False, on_epoch=True) 243 | 244 | for k,v in val_known_metrics.items(): 245 | self.log(f'val/known_{k}',value=v, logger=True, on_step=False, on_epoch=True) 246 | return 247 | 248 | 249 | def test_step(self, batch: List[Dict], batch_idx: int) -> Dict: 250 | ''' 251 | :param dataloader_idx: 0 for unknown_test 252 | ''' 253 | view_n = len(batch) 254 | batch_size = len(batch[0]['meta']) 255 | 256 | target_logits = self.net.forward(batch[0], msg = 'unlabeled') 257 | # use the gt labels to compute metrics 258 | self.test_unknown_metrics_wrapper.update_batch(target_logits, batch[0]['labels'] - self.config.known_types , incremental=False) 259 | 260 | self.predictions_wrapper.update_batch(batch[0]['meta'], target_logits, 
batch[0]['labels'] - self.config.known_types , incremental=False) 261 | 262 | return {} 263 | 264 | def test_epoch_end(self, outputs: List[Dict]) -> None: 265 | test_unknown_metrics = self.test_unknown_metrics_wrapper.on_epoch_end() 266 | for k,v in test_unknown_metrics.items(): 267 | self.log(f'test/unknown_{k}',value=v, logger=True, on_step=False, on_epoch=True) 268 | 269 | self.test_unknown_metrics_wrapper.save(self.config.ckpt_dir) 270 | 271 | self.predictions_wrapper.on_epoch_end() 272 | self.predictions_wrapper.save(self.config.ckpt_dir) 273 | 274 | return 275 | 276 | 277 | def configure_optimizers(self): 278 | no_decay = ["bias", "LayerNorm.weight"] 279 | optimizer_grouped_parameters = [ 280 | { 281 | "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)], 282 | "weight_decay": self.config.weight_decay, 283 | }, 284 | {"params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 285 | ] 286 | 287 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.config.learning_rate, eps=self.config.adam_epsilon) 288 | return [optimizer, ] 289 | 290 | 291 | 292 | 293 | 294 | 295 | -------------------------------------------------------------------------------- /common/layers.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Dict 2 | from collections import OrderedDict 3 | from math import ceil 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | class Prototypes(nn.Module): 9 | def __init__(self, feat_dim, num_prototypes, norm:bool=False): 10 | super().__init__() 11 | 12 | if norm: 13 | self.norm = nn.LayerNorm(feat_dim) 14 | else: 15 | self.norm = lambda x: x 16 | self.prototypes = nn.Linear(feat_dim, num_prototypes, bias=False) 17 | 18 | @torch.no_grad() 19 | def initialize_prototypes(self, centers): 20 | self.prototypes.weight.copy_(centers) 21 | return 22 | 23 | 24 | 
@torch.no_grad() 25 | def normalize_prototypes(self): 26 | w = self.prototypes.weight.data.clone() 27 | w = F.normalize(w, dim=1, p=2) 28 | self.prototypes.weight.copy_(w) 29 | 30 | def freeze_prototypes(self): 31 | self.prototypes.requires_grad_(False) 32 | 33 | def unfreeze_prototypes(self): 34 | self.prototypes.requires_grad_(True) 35 | 36 | 37 | def forward(self, x): 38 | x = self.norm(x) 39 | return self.prototypes(x) 40 | 41 | 42 | class MLP(nn.Module): 43 | ''' 44 | Simple n layer MLP with ReLU activation and batch norm. 45 | The order is Linear, Norm, ReLU 46 | ''' 47 | def __init__(self, feat_dim: int, hidden_dim: int, latent_dim:int, 48 | norm:bool=False, norm_type:str='batch', layers_n: int = 1, dropout_p: float =0.1 ): 49 | ''' 50 | :param norm_type: one of layer, batch 51 | ''' 52 | super().__init__() 53 | self.feat_dim= feat_dim 54 | self._hidden_dim= hidden_dim 55 | self.latent_dim = latent_dim 56 | self.input2hidden = nn.Linear(feat_dim, hidden_dim) 57 | self.dropout = nn.Dropout(p=dropout_p) 58 | layers = [self.dropout, ] 59 | for i in range(layers_n): 60 | if i==0: 61 | layers.append(nn.Linear(feat_dim, hidden_dim)) 62 | out_dim = hidden_dim 63 | elif i==1: 64 | layers.append(nn.Linear(hidden_dim, latent_dim)) 65 | out_dim = latent_dim 66 | else: 67 | layers.append(nn.Linear(latent_dim, latent_dim)) 68 | out_dim = latent_dim 69 | if norm: 70 | if norm_type == 'batch': 71 | layers.append(nn.BatchNorm1d(out_dim)) 72 | else: 73 | layers.append(nn.LayerNorm(out_dim)) 74 | if i < layers_n -1: # last layer has no relu 75 | layers.append(nn.ReLU()) 76 | 77 | self.net = nn.Sequential(*layers) 78 | 79 | def forward(self, input): 80 | ''' 81 | :param input: torch.FloatTensor (batch, ..., feat_dim) 82 | 83 | :return output: torch.FloatTensor (batch, ..., hidden_dim) 84 | ''' 85 | output = self.net(input.reshape(-1, self.feat_dim)) 86 | 87 | original_shape = input.shape 88 | new_shape = tuple(list(input.shape[:-1]) + [self.latent_dim]) 89 | 90 | output = 
output.reshape(new_shape) 91 | return output 92 | 93 | class ReconstructionNet(nn.Module): 94 | ''' 95 | projection from hidden_size back to feature_size. 96 | ''' 97 | def __init__(self, feature_size:int, hidden_size: int, latent_size:int ) -> None: 98 | super().__init__() 99 | self.feature_size = feature_size 100 | self.hidden_size = hidden_size 101 | assert (feature_size > hidden_size) 102 | self.latent_size = latent_size 103 | self.net = nn.Sequential( 104 | nn.Linear(in_features=self.latent_size, out_features=self.hidden_size), 105 | nn.BatchNorm1d(self.hidden_size), 106 | nn.ReLU(), 107 | nn.Linear(self.hidden_size, self.feature_size) 108 | ) 109 | 110 | def forward(self, inputs: torch.FloatTensor): 111 | 112 | output = self.net(inputs.reshape(-1, self.hidden_size)) 113 | new_shape = tuple(list(inputs.shape[:-1]) + [self.feature_size]) 114 | 115 | output = output.reshape(new_shape) 116 | return output 117 | 118 | class CommonSpaceCache(nn.Module): 119 | ''' 120 | A cache for saving common space embeddings and using it to compute contrastive loss. 
121 | ''' 122 | def __init__(self, feature_size:int, known_cache_size: int, unknown_cache_size: int, metric_type: str='cosine', sim_thres: float=0.8) -> None: 123 | super().__init__() 124 | self.feature_size = feature_size 125 | self.known_cache_size = known_cache_size 126 | self.unknown_cache_size = unknown_cache_size 127 | self.known_len=0 128 | self.unknown_len =0 129 | 130 | self.metric_type =metric_type 131 | self.metric = nn.CosineSimilarity(dim=2, eps=1e-8) 132 | self.sim_thres=sim_thres 133 | 134 | self.temp = 0.1 # temperature for softmax 135 | 136 | self.register_buffer("known_cache", torch.zeros((known_cache_size, feature_size), dtype=torch.float), persistent=False) 137 | self.register_buffer("unknown_cache", torch.zeros((unknown_cache_size, feature_size), dtype=torch.float), persistent=False) 138 | 139 | self.register_buffer("known_labels", torch.zeros((known_cache_size,), dtype=torch.long), persistent=False) 140 | self.register_buffer("unknown_labels", torch.zeros((unknown_cache_size, ), dtype=torch.long), persistent=False) 141 | 142 | 143 | def cache_full(self)-> bool: 144 | if (self.known_len == self.known_cache_size) and (self.unknown_len == self.unknown_cache_size): 145 | return True 146 | else: 147 | return False 148 | 149 | @torch.no_grad() 150 | def update_batch(self, embeddings: torch.FloatTensor, known_mask: torch.BoolTensor, labels: Optional[torch.LongTensor]=None) -> None: 151 | ''' 152 | Add embeddings to cache. 
153 | :param embeddings: (batch, feature_size) 154 | ''' 155 | embeddings_detached = embeddings.detach() 156 | 157 | known_embeddings = embeddings_detached[known_mask,:] 158 | known_size = known_embeddings.size(0) 159 | new_known_cache = torch.concat([known_embeddings, self.known_cache], dim=0) 160 | self.known_cache = new_known_cache[:self.known_cache_size] 161 | self.known_len = min(self.known_len + known_size, self.known_cache_size) 162 | if labels!=None: 163 | known_labels = labels[known_mask] 164 | self.known_labels = torch.concat([known_labels, self.known_labels], dim=0)[:self.known_cache_size] 165 | unknown_labels = labels[~known_mask] 166 | self.unknown_labels = torch.concat([unknown_labels, self.unknown_labels], dim=0)[:self.unknown_cache_size] 167 | 168 | 169 | unknown_embeddings = embeddings_detached[~known_mask,: ] 170 | unknown_size = unknown_embeddings.size(0) 171 | new_unknown_cache = torch.concat([unknown_embeddings, self.unknown_cache], dim=0) 172 | self.unknown_cache = new_unknown_cache[:self.unknown_cache_size] 173 | self.unknown_len = min(self.unknown_len + unknown_size, self.unknown_cache_size) 174 | return 175 | 176 | @torch.no_grad() 177 | def get_positive_example(self, embedding: torch.FloatTensor, known: bool =False) -> Tuple[torch.FloatTensor, torch.BoolTensor]: 178 | ''' 179 | :param embeddings (N, feature_dim) 180 | 181 | :returns (N, feature_dim) 182 | ''' 183 | embedding_detached = embedding.detach() 184 | if known: 185 | cache = self.known_cache 186 | label_cache = self.known_labels 187 | else: 188 | cache = self.unknown_cache 189 | label_cache = self.unknown_labels 190 | 191 | if self.metric_type == 'cosine': 192 | similarity = self.metric(embedding_detached.unsqueeze(dim=1), cache.unsqueeze(dim=0)) # N, cache_size 193 | else: 194 | similarity = torch.einsum("ik,jk->ij", embedding_detached, cache) 195 | 196 | max_sim, max_idx = torch.max(similarity, dim=1) #(N, ) 197 | min_thres = self.sim_thres 198 | valid_pos_mask = (max_sim > 
min_thres) #(N, ) 199 | pos_embeddings = cache[max_idx, :] # (N, feature_dim) 200 | pos_labels = label_cache[max_idx] # (N, ) 201 | 202 | return pos_embeddings, valid_pos_mask, pos_labels 203 | 204 | @torch.no_grad() 205 | def get_negative_example_for_unknown(self, embedding: torch.FloatTensor, k: int=3) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 206 | ''' 207 | Take half of the negative examples from the unknown cache and half from the known cache. 208 | :param embeddings (N, feature_dim) 209 | ''' 210 | embedding_detached= embedding.detach() 211 | N = embedding_detached.size(0) 212 | if self.metric_type == 'cosine': 213 | unknown_similarity = self.metric(embedding_detached.unsqueeze(dim=1), self.unknown_cache.unsqueeze(dim=0)) # N, cache_size 214 | else: 215 | unknown_similarity = torch.einsum('ik,jk->ij', embedding_detached, self.unknown_cache) 216 | 217 | sorted_unk_idx = torch.argsort(unknown_similarity, dim=1) # N, cache_size 218 | unk_n = ceil(sorted_unk_idx.size(1) /2) 219 | candidate_neg_unk_idx = sorted_unk_idx[:, :unk_n] # N, cache_size/2 220 | # this is used for generating indexes 221 | neg_unk_list = [] 222 | for i in range(N): 223 | random_idx = torch.randperm(n=unk_n, dtype=torch.long, device=embedding.device)[:k] 224 | chosen_neg_unk_idx = candidate_neg_unk_idx[i, :][random_idx] 225 | chosen_neg_unk = self.unknown_cache[chosen_neg_unk_idx, :] # K, feature_size 226 | neg_unk_list.append(chosen_neg_unk) 227 | 228 | if self.metric_type == 'cosine': 229 | known_similarity = self.metric(embedding_detached.unsqueeze(dim=1), self.known_cache.unsqueeze(dim=0)) # (N, cache_size) 230 | else: 231 | known_similarity = torch.einsum("ik,jk->ij", embedding_detached, self.known_cache) 232 | 233 | sorted_known_idx = torch.argsort(known_similarity, dim=1, descending=True) # choose hard examples (N, cache_size) 234 | neg_known_list = [] 235 | chosen_neg_known_idx = sorted_known_idx[:, :k] 236 | for i in range(N): 237 | chosen_neg_known = 
self.known_cache[chosen_neg_known_idx[i], :] 238 | neg_known_list.append(chosen_neg_known) 239 | 240 | neg_unk = torch.stack(neg_unk_list, dim=0) 241 | neg_known = torch.stack(neg_known_list, dim=0) # (N, K, feature_size) 242 | 243 | return neg_unk, neg_known 244 | 245 | def get_contrastive_candidates(self, embeddings: torch.FloatTensor, neg_n: int=6, labels: Optional[torch.LongTensor]=None): 246 | N = embeddings.size(0) 247 | if labels!=None: assert (labels.size(0) == N) 248 | 249 | pos_embeddings, valid_pos_mask, pos_labels = self.get_positive_example(embeddings, known=False) # (N, hidden_dim) 250 | assert (pos_embeddings.shape == embeddings.shape ) 251 | # report positive sample accuracy 252 | pos_acc = self.compute_accuracy(labels[valid_pos_mask], pos_labels[valid_pos_mask]) 253 | 254 | neg_unk_embeddings, neg_known_embeddings = self.get_negative_example_for_unknown(embeddings, k=ceil(neg_n/2)) # (N, K, hidden_dim) 255 | candidates = torch.concat([pos_embeddings.unsqueeze(dim=1), neg_unk_embeddings, neg_known_embeddings], dim=1) # (N, 2K+1, hidden_dim) 256 | # scores = torch.einsum('ik,ijk->ij', embeddings, candidates) # (N, 2K+1 ) 257 | # targets = torch.zeros((N,), dtype=torch.long, device=scores.device) 258 | # loss = F.cross_entropy(scores/self.temp, targets) 259 | return candidates, valid_pos_mask, pos_acc 260 | 261 | def compute_accuracy(self, labels, other_labels): 262 | # consider moving average 263 | assert (labels.shape == other_labels.shape) 264 | acc = torch.sum(labels == other_labels)*1.0 / labels.size(0) 265 | return acc 266 | 267 | 268 | class ClassifierHead(nn.Module): 269 | def __init__(self, args, feature_size: int, 270 | n_classes: int, layers_n: int = 1, 271 | n_heads: int =1, dropout_p: float =0.2, hidden_size: Optional[int]=None) -> None: 272 | super().__init__() 273 | self.args = args 274 | 275 | self.feature_size = feature_size 276 | self.n_classes = n_classes 277 | self.n_heads = n_heads 278 | if hidden_size: 279 | self.hidden_size = 
hidden_size 280 | else: 281 | self.hidden_size = feature_size 282 | 283 | if layers_n == 1: 284 | self.classifier = nn.Sequential(OrderedDict( 285 | [('dropout',nn.Dropout(p=dropout_p)), 286 | ('centroids', Prototypes(feat_dim=self.hidden_size, num_prototypes=self.n_classes))] 287 | )) 288 | elif layers_n > 1: 289 | self.classifier = nn.Sequential(OrderedDict( 290 | [('mlp', MLP(feat_dim=self.feature_size, hidden_dim=self.hidden_size, latent_dim=self.hidden_size, norm=True,layers_n=layers_n-1)), 291 | ('dropout',nn.Dropout(p=dropout_p)), 292 | ('centroids', Prototypes(feat_dim=self.hidden_size, num_prototypes=self.n_classes))] 293 | )) 294 | 295 | def initialize_centroid(self, centers): 296 | for n, module in self.classifier.named_modules(): 297 | if n=='centroids': 298 | module.initialize_prototypes(centers) 299 | 300 | return 301 | 302 | def update_centroid(self): 303 | ''' 304 | The centroids are essentially just the vectors in the final Linear layer. Here we normalize them. they are trained along with the model. 305 | ''' 306 | for n, module in self.classifier.named_modules(): 307 | if n=='centroids': 308 | module.normalize_prototypes() 309 | 310 | return 311 | 312 | def freeze_centroid(self): 313 | ''' 314 | From Swav paper, freeze the prototypes to help with initial optimization. 
315 | ''' 316 | for n, module in self.classifier.named_modules(): 317 | if n=='centroids': 318 | module.freeze_prototypes() 319 | 320 | return 321 | 322 | def unfreeze_centroid(self): 323 | for n, module in self.classifier.named_modules(): 324 | if n=='centroids': 325 | module.unfreeze_prototypes() 326 | 327 | return 328 | 329 | 330 | def forward(self, inputs: torch.FloatTensor): 331 | ''' 332 | :params inputs: (batch, feat_dim) 333 | 334 | :returns logits: (batch, n_classes) 335 | ''' 336 | outputs = self.classifier(inputs) 337 | return outputs 338 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from datetime import datetime 4 | import json 5 | 6 | import torch 7 | import yaml 8 | from pytorch_lightning import Trainer 9 | from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping, ModelCheckpoint 10 | from pytorch_lightning.loggers import WandbLogger 11 | from pytorch_lightning.utilities.seed import seed_everything 12 | from transformers import AutoTokenizer, PreTrainedTokenizer 13 | 14 | from data_module import OpenTypeDataModule 15 | from event_data_module import OpenTypeEventDataModule 16 | from event_e2e_data_module import E2EEventDataModule 17 | 18 | from model import TypeDiscoveryModel 19 | from baselines.RoCORE_model import RoCOREModel 20 | from baselines.etypeclus_model import ETypeClusModel 21 | from baselines.vqvae_model import VQVAEModel 22 | from baselines.RSN_model import RSNModel 23 | 24 | import common.log as log 25 | logger = log.get_logger('root') 26 | 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser() 31 | 32 | parser.add_argument('--load_configs', type=str, help='yaml file to load configurations from') 33 | # Parameters for the model 34 | model_arg_group = parser.add_argument_group('model') 35 | model_arg_group.add_argument('--task', type=str, 
default='rel', choices=['rel', 'event']) 36 | model_arg_group.add_argument('--model', type=str, default='tabs', choices=['tabs', 'rocore','etypeclus', 'vqvae','rsn']) 37 | model_arg_group.add_argument('--collect_features', action='store_true') 38 | model_arg_group.add_argument('--predict_names', action='store_true') 39 | model_arg_group.add_argument('--supervised_training', action='store_true', help='train with true labels to check upper bound.') 40 | model_arg_group.add_argument('--known_types', type=int, default=64, 41 | help='the number of types during training. The remaining will be unseen and used for testing.') 42 | model_arg_group.add_argument('--unknown_types', type=int, default=16) 43 | 44 | model_arg_group.add_argument('--test_ratio', type=float, default=0.15, 45 | help='percentage of instances that are used for testing.') 46 | model_arg_group.add_argument('--incremental', action='store_true', default=False, help='whether or not the test set is mixture of known and unknown.') 47 | model_arg_group.add_argument('--dataset_name', 48 | type=str, 49 | default='fewrel', 50 | choices=['tacred', 'fewrel', 'ace']) 51 | 52 | 53 | model_arg_group.add_argument('--feature', type=str, default='all', choices=['token','mask', 'all']) 54 | model_arg_group.add_argument('--e2e', action='store_true') 55 | model_arg_group.add_argument('--token_pooling', type=str, default='first', choices=['max','first']) 56 | model_arg_group.add_argument('--regularization', type=str, default='temp', choices=['sk', 'temp']) 57 | model_arg_group.add_argument('--temp', type=float, default=0.2, help='value between 0 and 1') 58 | model_arg_group.add_argument('--sk_epsilon', type=float, default=0.05) 59 | model_arg_group.add_argument('--psuedo_label', type=str, default='combine', choices=['other','combine','self']) 60 | model_arg_group.add_argument('--rev_ratio', type=float, default=0.0) 61 | model_arg_group.add_argument( 62 | "--model_type", 63 | default='bert', 64 | type=str, 65 | 
choices=['bert', 'roberta','albert','gpt2'] 66 | ) 67 | model_arg_group.add_argument('--model_name_or_path', 68 | default='bert-base-uncased', 69 | type=str, 70 | help="the model name (e.g., 'roberta-large') or path to a pretrained model") 71 | model_arg_group.add_argument('--cache_dir', 72 | type=str, 73 | help='the cache location for transformers library') 74 | model_arg_group.add_argument('--hidden_dim', type=int, default=256) 75 | model_arg_group.add_argument('--classifier_layers',type=int, default=2) 76 | 77 | 78 | model_arg_group.add_argument('--check_pl', action='store_true', help='compute psuedo label accuracy for diagnosis') 79 | model_arg_group.add_argument('--supervised_pretrain', action='store_true', default=False) 80 | model_arg_group.add_argument('--label_smoothing_alpha', type=float, default=0.1) 81 | model_arg_group.add_argument('--label_smoothing_ramp', type=int, default=0) 82 | model_arg_group.add_argument('--consistency_loss', type=float, default=0.0) 83 | model_arg_group.add_argument('--pairwise_loss', action='store_true') 84 | model_arg_group.add_argument('--clustering', type=str, default='kmeans', choices=['online','kmeans','spectral','agglomerative','ward', 'dbscan']) 85 | model_arg_group.add_argument('--freeze_pretrain', default=False, action='store_true') 86 | 87 | 88 | # parameters for RoCORE 89 | model_arg_group.add_argument('--center_loss', type=float, default=0.005) 90 | model_arg_group.add_argument('--sigmoid', type=float, default=2.0) 91 | model_arg_group.add_argument("--layer", type = int, default = 8) 92 | model_arg_group.add_argument('--kmeans_dim', type=int, default=256) 93 | 94 | 95 | # parameters for etypeclus 96 | model_arg_group.add_argument('--temperature', type=float, default=0.1) 97 | model_arg_group.add_argument('--distribution', default='softmax', choices=['softmax', 'student']) 98 | model_arg_group.add_argument('--hidden_dims', default='[768, 500, 1000, 100]', type=str) 99 | model_arg_group.add_argument('--gamma', 
default=5, type=float, help='weight of clustering loss') 100 | 101 | # parameter for vqvae 102 | # note that the gamma parameter is overloaded in this model for the unsupervised loss weight 103 | model_arg_group.add_argument('--beta', type=float, default=1.0, help='weight for the commitment loss') 104 | model_arg_group.add_argument('--hybrid', action='store_true', help='whether to use a hybrid vae + vqvae') 105 | 106 | # parameters for RSN 107 | model_arg_group.add_argument('--perturb_scale', type=float, default=0.02) 108 | model_arg_group.add_argument('--vat_loss_weight', type=float, default=1.0) 109 | model_arg_group.add_argument('--use_cnn', action='store_true') 110 | model_arg_group.add_argument('--p_cond', type=float, default=0.03, help='weight of the conditional ce loss on unknown classes.') 111 | model_arg_group.add_argument('--clustering_method', type=str, default='spectral', choices=['louvain', 'spectral'], help='clustering method after pairwise metric.') 112 | # Parameters for IO 113 | io_arg_group = parser.add_argument_group('io') 114 | io_arg_group.add_argument( 115 | "--ckpt_name", 116 | default=None, 117 | type=str, 118 | help="The output directory where the model checkpoints and predictions will be written.", 119 | ) 120 | io_arg_group.add_argument( 121 | "--load_ckpt", 122 | default=None, 123 | type=str, 124 | help='whether to load an existing model. Required for eval.' 125 | ) 126 | io_arg_group.add_argument('--load_pretrained', type=str, help='load a pretrained model.') 127 | io_arg_group.add_argument('--dataset_dir', type=str, default='data/fewrel') 128 | 129 | 130 | # Parameters for runtime 131 | runtime_arg_group = parser.add_argument_group('runtime') 132 | runtime_arg_group.add_argument("--train_batch_size", 133 | default=100, type=int, 134 | help="Batch size per GPU/CPU for training.") 135 | runtime_arg_group.add_argument( 136 | "--eval_batch_size", 137 | default=100, type=int, 138 | help="Batch size per GPU/CPU for evaluation." 
139 | ) 140 | runtime_arg_group.add_argument( 141 | "--eval_only", action="store_true", 142 | ) 143 | runtime_arg_group.add_argument('--num_workers', type=int, default=0) 144 | runtime_arg_group.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 145 | runtime_arg_group.add_argument( 146 | "--accumulate_grad_batches", 147 | type=int, 148 | default=1, 149 | help="Number of updates steps to accumulate before performing a backward/update pass.", 150 | ) 151 | runtime_arg_group.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 152 | runtime_arg_group.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 153 | runtime_arg_group.add_argument("--gradient_clip_val", default=1.0, type=float, help="Max gradient norm.") 154 | runtime_arg_group.add_argument( 155 | "--num_train_epochs", default=20, type=int, help="Total number of training epochs to perform." 156 | ) 157 | runtime_arg_group.add_argument( 158 | '--num_pretrain_epochs', default=0, type=int 159 | ) 160 | runtime_arg_group.add_argument( 161 | "--max_steps", 162 | default=-1, 163 | type=int, 164 | help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.", 165 | ) 166 | runtime_arg_group.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 167 | 168 | runtime_arg_group.add_argument("--gpus", default='6,', help='-1 means train on all gpus') 169 | runtime_arg_group.add_argument("--seed", type=int, default=42, help="random seed for initialization") 170 | runtime_arg_group.add_argument( 171 | "--fp16", 172 | action="store_true", 173 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 174 | ) 175 | 176 | args = parser.parse_args() 177 | 178 | # Set seed 179 | seed_everything(args.seed) 180 | 181 | if not args.ckpt_name: 182 | d = datetime.now() 183 | time_str = d.strftime('%m-%dT%H%M') 184 | args.ckpt_name = '{}_{}_{}_{}'.format(args.model, args.dataset_name, args.unknown_types, time_str) 185 | 186 | 187 | args.ckpt_dir = os.path.join(f'./checkpoints/{args.ckpt_name}') 188 | 189 | os.makedirs(args.ckpt_dir, exist_ok=True) 190 | 191 | # read base arguments from yaml 192 | if args.load_configs: 193 | with open(args.load_configs, 'r') as f: 194 | yaml_configs = yaml.load(f, Loader=yaml.FullLoader) 195 | 196 | for k, v in yaml_configs.items(): 197 | if k=='learning_rate': 198 | v = eval(v) # convert string to float 199 | args.__dict__[k] = v 200 | 201 | # save the arguments to file 202 | arg_dict = vars(args) 203 | with open(os.path.join(args.ckpt_dir, 'params.json'),'w') as f: 204 | json.dump(arg_dict, f, indent=2) 205 | 206 | logger.info("Training/evaluation parameters %s", args) 207 | 208 | checkpoint_callback = ModelCheckpoint( 209 | dirpath=args.ckpt_dir, 210 | save_top_k=1, 211 | save_last=True, 212 | monitor='val/unknown_acc', # metric name 213 | mode='max', 214 | save_weights_only=True, 215 | filename='{epoch}', # this cannot contain slashes 216 | 217 | ) 218 | 219 | 220 | lr_logger = LearningRateMonitor() 221 | # TODO: change this to your own project name and username 222 | wb_logger = 
WandbLogger(project='open-type', name=args.ckpt_name, entity='shali0173') 223 | 224 | 225 | if args.max_steps < 0 : 226 | args.max_epochs = args.min_epochs = args.num_train_epochs 227 | 228 | 229 | 230 | tokenizer = AutoTokenizer.from_pretrained( 231 | args.model_name_or_path, 232 | cache_dir=args.cache_dir if args.cache_dir else None) # type: PreTrainedTokenizer 233 | tokenizer.add_tokens(['', '', '','','','']) 234 | # tokenizer.add_tokens(['', '', '','']) # quick fix for old checkpoints 235 | vocab_size = len(tokenizer) 236 | 237 | if args.task == 'rel': 238 | dm = OpenTypeDataModule(args, tokenizer, args.dataset_dir) 239 | elif args.task == 'event': 240 | if args.e2e: 241 | dm = E2EEventDataModule(args,tokenizer, args.dataset_dir) 242 | else: 243 | dm = OpenTypeEventDataModule(args, tokenizer, args.dataset_dir) 244 | 245 | dm.setup() 246 | train_dm = dm.train_dataloader() 247 | train_len = len(train_dm) 248 | if args.model == 'tabs': 249 | model = TypeDiscoveryModel(args, tokenizer, train_len) 250 | elif args.model == 'rocore': 251 | model = RoCOREModel(args, tokenizer, train_len) 252 | elif args.model == 'etypeclus': 253 | model = ETypeClusModel(args, tokenizer, train_len) 254 | elif args.model == 'vqvae': 255 | model = VQVAEModel(args, tokenizer, train_len) 256 | elif args.model == 'rsn': 257 | model = RSNModel(args, tokenizer, train_len) 258 | else: 259 | raise ValueError(f"model name {args.model} not recognized.") 260 | 261 | 262 | trainer = Trainer( 263 | logger=wb_logger, 264 | min_epochs=args.num_train_epochs, 265 | max_epochs=args.num_train_epochs, 266 | gpus=str(args.gpus), # use string or list to specify gpu id 267 | accumulate_grad_batches=args.accumulate_grad_batches, 268 | gradient_clip_val=args.gradient_clip_val, 269 | num_sanity_val_steps=10, 270 | val_check_interval=1.0, # use float to check every n epochs 271 | precision=16 if args.fp16 else 32, 272 | callbacks = [lr_logger, checkpoint_callback], 273 | strategy="ddp", 274 | ) 275 | 276 | if 
args.load_ckpt: 277 | checkpoint = torch.load(args.load_ckpt,map_location=model.device) 278 | # enlarge the embeddings based on the state dict 279 | for k, v in checkpoint['state_dict'].items(): 280 | if 'embeddings.word_embeddings.weight' in k: 281 | load_vocab_size = v.size(0) 282 | 283 | if load_vocab_size < vocab_size: 284 | # this is due to using less special tokens (old checkpoints) 285 | # resize the old embedding 286 | embed_dim = v.size(1) 287 | new_embeddings = torch.nn.Embedding(vocab_size, embed_dim) 288 | new_embeddings.to(model.device) 289 | new_embeddings.weight.data[:load_vocab_size, :] = v.data[:load_vocab_size, :] 290 | checkpoint['state_dict'][k] = new_embeddings.weight 291 | # embeddings = model.pretrained_model.resize_token_embeddings(vocab_size) 292 | 293 | 294 | model.load_state_dict(checkpoint['state_dict'], strict=True) 295 | model.on_load_checkpoint(checkpoint) 296 | 297 | elif args.load_pretrained: 298 | logger.info(f'loading pretrained model from {args.load_pretrained}') 299 | checkpoint = torch.load(args.load_pretrained, map_location=model.device) 300 | model.load_pretrained_model(checkpoint['state_dict']) 301 | 302 | if args.collect_features: 303 | model.collect_features(train_dm, known=False, raw=False) 304 | 305 | elif args.eval_only: 306 | trainer.test(model, datamodule=dm) #also loads training dataloader 307 | else: 308 | # model.initialize_centroids(train_dm) 309 | trainer.fit(model, datamodule=dm) 310 | 311 | if __name__ == "__main__": 312 | main() -------------------------------------------------------------------------------- /baselines/etypeclus_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Dict, Tuple, Optional, Union 4 | from collections import defaultdict 5 | import pickle as pkl 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import pytorch_lightning as pl 12 | from 
sklearn.cluster import KMeans 13 | import numpy as np 14 | from tqdm import tqdm 15 | 16 | from transformers import AutoConfig, AutoModel, AdamW, get_linear_schedule_with_warmup 17 | 18 | 19 | from .latent_space_clustering import AutoEncoder, cosine_dist 20 | from common.metrics import ClusteringMetricsWrapper 21 | from common.predictions import ClusterPredictionsWrapper 22 | from common.utils import cluster_acc, get_label_mapping, onedim_gather 23 | 24 | 25 | 26 | import common.log as log 27 | logger = log.get_logger('root') 28 | 29 | class ETypeClusModel(pl.LightningModule): 30 | ''' 31 | This is a wrapper of the TopicCluster class. 32 | ''' 33 | def __init__(self, args, tokenizer, train_len) -> None: 34 | super().__init__() 35 | self.config = args 36 | self.tokenizer = tokenizer 37 | self.model_config = AutoConfig.from_pretrained(args.model_name_or_path, output_hidden_states = True) 38 | self.pretrained_model = AutoModel.from_pretrained(args.model_name_or_path, config = self.model_config) 39 | embeddings = self.pretrained_model.resize_token_embeddings(len(self.tokenizer)) # when adding new tokens, the tokenizer.vocab_size is not changed! 
40 | 41 | self.train_len=train_len # this is required to set up the optimizer,make optional 42 | self.temperature = args.temperature 43 | self.distribution = args.distribution 44 | 45 | input_dim = self.model_config.hidden_size 46 | hidden_dims = eval(args.hidden_dims) 47 | self.topic_emb = nn.Parameter(torch.Tensor(args.unknown_types, hidden_dims[-1])) 48 | 49 | 50 | self.q_dict = {} # uid -> target distribution 51 | 52 | self.model = AutoEncoder(input_dim, hidden_dims) 53 | 54 | 55 | torch.nn.init.xavier_normal_(self.topic_emb.data) 56 | 57 | self.freeze_model() 58 | 59 | 60 | self.train_unknown_metrics_wrapper = ClusteringMetricsWrapper(stage='train', 61 | known=False, 62 | prefix='train_unknown', known_classes=args.known_types) 63 | self.val_unknown_metrics_wrapper = ClusteringMetricsWrapper(stage='val', 64 | known=False, 65 | prefix='val_unknown', known_classes=args.known_types) 66 | self.test_unknown_metrics_wrapper = ClusteringMetricsWrapper(stage='test', 67 | known=False, 68 | prefix='test_unknown', known_classes=args.known_types) 69 | 70 | 71 | if args.eval_only: 72 | self.predictions_wrapper = ClusterPredictionsWrapper(reassign=True, prefix='test_unknown', 73 | known_classes=args.known_types, task=args.task) 74 | 75 | 76 | def on_validation_epoch_start(self) -> None: 77 | # reset all metrics 78 | self.train_unknown_metrics_wrapper.reset() 79 | self.val_unknown_metrics_wrapper.reset() 80 | self.test_unknown_metrics_wrapper.reset() 81 | return 82 | 83 | 84 | def freeze_model(self): 85 | self.pretrained_model.requires_grad_(False) 86 | return 87 | 88 | def on_train_epoch_start(self) -> None: 89 | train_dl = self.trainer.train_dataloader.loaders 90 | with torch.no_grad(): 91 | z_list = [] 92 | uid_list = [] 93 | 94 | for batch in tqdm(iter(train_dl)): 95 | known_mask = batch[0]['known_mask'] 96 | metadata = batch[0]['meta'] 97 | batch_size = len(metadata) 98 | 99 | view = batch[0] 100 | # move batch to gpu 101 | for key in ['token_ids', 
'attn_mask','head_spans','tail_spans','mask_bpe_idx','trigger_spans']: 102 | if key in view: view[key] = view[key].to(self.device) 103 | # get features by pretrained_model 104 | outputs = self.pretrained_model(input_ids=view['token_ids'],attention_mask=view['attn_mask']) 105 | seq_output = outputs[0] 106 | feat = onedim_gather(seq_output, dim=1, index=view['trigger_spans'][:, 0].unsqueeze(1)).squeeze(1) 107 | x = F.normalize(feat[~known_mask], dim=1) 108 | x_bar, z = self.model(x) 109 | z = F.normalize(z, dim=1) 110 | z_list.append(z) 111 | uid_list.extend([it['uid'] for idx, it in enumerate(metadata) if known_mask[idx] == False]) 112 | 113 | if self.current_epoch ==0: 114 | # initialize the clusters by kmeans 115 | logger.info('initializing by kmeans...') 116 | kmeans = KMeans(n_clusters=self.config.unknown_types, n_init=5) 117 | rep = torch.concat(z_list, dim=0).cpu().numpy() 118 | y_pred = kmeans.fit_predict(rep) # y_pred is used to determine end of training 119 | self.topic_emb.data = torch.tensor(kmeans.cluster_centers_).to(self.device) 120 | 121 | else: 122 | logger.info('updating target distribution q') 123 | all_z = torch.concat(z_list, dim=0) # (N, hidden_dim) 124 | freq = torch.ones((all_z.size(0)), dtype=torch.long, device=self.device) 125 | p , q = self.target_distribution(all_z, freq, method='all', top_num=self.current_epoch+1) 126 | assert (q.size(0) == len(uid_list)) 127 | for idx, uid in enumerate(uid_list): 128 | self.q_dict[uid] = q[idx,:] 129 | 130 | return 131 | 132 | def collect_features(self, dataloader, known:bool=False, max_batch=50): 133 | ''' 134 | collect features for visualization. 
135 | ''' 136 | feat_arrays = [] 137 | 138 | with torch.no_grad(): 139 | all_feats = [[],] 140 | all_labels = [[],] # List[str] 141 | 142 | for batch_idx, batch in enumerate(iter(dataloader)): 143 | view_n = len(batch) 144 | known_mask = batch[0]['known_mask'] # (batch, ) 145 | view = batch[0] 146 | outputs = self.pretrained_model(input_ids=view['token_ids'],attention_mask=view['attn_mask']) 147 | seq_output = outputs[0] 148 | feat = onedim_gather(seq_output, dim=1, index=view['trigger_spans'][:, 0].unsqueeze(1)).squeeze(1) 149 | if known: 150 | feat = feat[known_mask] 151 | else: 152 | feat = feat[~known_mask] 153 | labels = [x['label'] for x in batch[0]['meta'] if x['known'] == known] 154 | all_feats[0].append(feat.cpu().numpy()) 155 | all_labels[0].extend(labels) 156 | if batch_idx == max_batch: break 157 | 158 | for i in range(len(all_feats)): 159 | all_feats_array = np.concatenate(all_feats[i], axis=0)# (n_instances, hidden_dim) 160 | feat_arrays.append(all_feats_array) 161 | 162 | with open(os.path.join(self.config.ckpt_dir, 'view_features.pkl'), 'wb') as f: 163 | pkl.dump(feat_arrays,f ) 164 | with open(os.path.join(self.config.ckpt_dir, 'labels.pkl'),'wb') as f: 165 | pkl.dump(all_labels,f ) 166 | 167 | return 168 | def cluster_assign(self, z: torch.FloatTensor) -> torch.FloatTensor: 169 | ''' 170 | :param z: (batch, hidden_dim) 171 | :returns p: (batch, n_clusters) 172 | ''' 173 | if self.distribution == 'student': 174 | p = 1.0 / (1.0 + torch.sum( 175 | torch.pow(z.unsqueeze(1) - self.topic_emb, 2), 2) / self.alpha) 176 | p = p.pow((self.alpha + 1.0) / 2.0) 177 | p = (p.t() / torch.sum(p, 1)).t() 178 | else: 179 | self.topic_emb.data = F.normalize(self.topic_emb.data, dim=-1) 180 | z = F.normalize(z, dim=-1) 181 | sim = torch.matmul(z, self.topic_emb.t()) / self.temperature 182 | p = F.softmax(sim, dim=-1) 183 | return p 184 | 185 | def forward(self, x: torch.FloatTensor): 186 | x_bar, z = self.model(x) 187 | p = self.cluster_assign(z) 188 | return x_bar, 
z, p 189 | 190 | def target_distribution(self, z: torch.FloatTensor, 191 | freq: torch.LongTensor, method='all', top_num=0): 192 | ''' 193 | :param x: (batch, hidden_dim) 194 | :param freq: (batch) 195 | ''' 196 | p = self.cluster_assign(z).detach() 197 | if method == 'all': 198 | q = p**2 / (p * freq.unsqueeze(-1)).sum(dim=0) 199 | q = (q.t() / q.sum(dim=1)).t() 200 | elif method == 'top': 201 | assert top_num > 0 202 | q = p.clone() 203 | sim = torch.matmul(self.topic_emb, z.t()) 204 | _, selected_idx = sim.topk(k=top_num, dim=-1) 205 | for i, topic_idx in enumerate(selected_idx): 206 | q[topic_idx] = 0 207 | q[topic_idx, i] = 1 208 | return p, q 209 | 210 | 211 | def training_step(self, batch, batch_idx): 212 | ''' 213 | batch = { 214 | 'meta':List[Dict], 215 | 'token_ids': torch.LongTensor (batch, seq_len), 216 | 'attn_mask': torch.BoolTensor (batch, seq_len) 217 | 'labels': torch.LongTensor([x['label'] for x in batch]) 218 | 'head_spans': , 219 | 'tail_spans': , 220 | 'mask_bpe_idx': , 221 | 'known_mask' torch.BoolTensor 222 | } 223 | ''' 224 | view_n = len(batch) 225 | batch_size = len(batch[0]['meta']) 226 | labels = batch[0]['labels'] 227 | 228 | known_mask = batch[0]['known_mask'] # (batch, ) 229 | view = batch[0] 230 | outputs = self.pretrained_model(input_ids=view['token_ids'],attention_mask=view['attn_mask']) 231 | seq_output = outputs[0] 232 | feat = onedim_gather(seq_output, dim=1, index=view['trigger_spans'][:, 0].unsqueeze(1)).squeeze(1) 233 | x = F.normalize(feat[~known_mask]) 234 | 235 | x_bar, z , p = self.forward(x) 236 | reconstr_loss = cosine_dist(x_bar, x) 237 | self.log('train/recon_loss', reconstr_loss) 238 | 239 | if self.current_epoch < self.config.num_pretrain_epochs: 240 | loss = reconstr_loss 241 | return loss 242 | 243 | q_batch = torch.zeros((batch_size, self.config.unknown_types), dtype=torch.float, device=self.device) 244 | for i in range(batch_size): 245 | uid = batch[0]['meta'][i]['uid'] 246 | if known_mask[i] == False: 247 | 
q_batch[i, : ] = self.q_dict[uid] 248 | kl_loss = F.kl_div(p.log(), q_batch[~known_mask], reduction='none').sum() 249 | loss = self.config.gamma * kl_loss + reconstr_loss 250 | self.log('train/kl_loss',kl_loss) 251 | 252 | return loss 253 | 254 | def validation_step(self, batch: List[Dict[str, torch.Tensor]], batch_idx: int, dataloader_idx: int)-> Dict[str, torch.Tensor]: 255 | ''' 256 | :param dataloader_idx: 0 for unknown_train, 1 for unknown_test, 2 for known_test 257 | ''' 258 | view_n = len(batch) 259 | batch_size = len(batch[0]['meta']) 260 | view = batch[0] 261 | known_mask = batch[0]['known_mask'] # (batch, ) 262 | 263 | outputs = self.pretrained_model(input_ids=view['token_ids'],attention_mask=view['attn_mask']) 264 | seq_output = outputs[0] 265 | feat = onedim_gather(seq_output, dim=1, index=view['trigger_spans'][:, 0].unsqueeze(1)).squeeze(1) 266 | x = F.normalize(feat[~known_mask]) 267 | 268 | x_bar, _ , p = self.forward(x) 269 | 270 | # use the gt labels to compute metrics 271 | # for this model, do not evaluation on known types 272 | # setting incremental to True will not subtract the number of known types 273 | if dataloader_idx == 0: 274 | self.train_unknown_metrics_wrapper.update_batch(p, batch[0]['labels'], incremental=True) 275 | elif dataloader_idx == 1: 276 | self.val_unknown_metrics_wrapper.update_batch(p, batch[0]['labels'], incremental=True) 277 | 278 | return {} 279 | 280 | 281 | def validation_epoch_end(self, outputs: List[List[Dict]]) -> None: 282 | val_unknown_metrics = self.val_unknown_metrics_wrapper.on_epoch_end() 283 | train_unknown_metrics = self.train_unknown_metrics_wrapper.on_epoch_end() 284 | 285 | for k,v in val_unknown_metrics.items(): 286 | self.log(f'val/unknown_{k}',value=v, logger=True, on_step=False, on_epoch=True) 287 | 288 | for k,v in train_unknown_metrics.items(): 289 | self.log(f'train/unknown_{k}', value=v, logger=True, on_step=False, on_epoch=True) 290 | 291 | return 292 | 293 | 294 | def test_step(self, batch: 
List[Dict], batch_idx: int) -> Dict: 295 | ''' 296 | :param dataloader_idx: 0 for unknown_test 297 | ''' 298 | view_n = len(batch) 299 | batch_size = len(batch[0]['meta']) 300 | view = batch[0] 301 | known_mask = batch[0]['known_mask'] # (batch, ) 302 | 303 | outputs = self.pretrained_model(input_ids=view['token_ids'],attention_mask=view['attn_mask']) 304 | seq_output = outputs[0] 305 | feat = onedim_gather(seq_output, dim=1, index=view['trigger_spans'][:, 0].unsqueeze(1)).squeeze(1) 306 | x = F.normalize(feat[~known_mask]) 307 | 308 | x_bar, _ , p = self.forward(x) 309 | 310 | # use the gt labels to compute metrics 311 | self.test_unknown_metrics_wrapper.update_batch(p, batch[0]['labels'], incremental=True) 312 | 313 | self.predictions_wrapper.update_batch(batch[0]['meta'], p, batch[0]['labels'], incremental=True) 314 | 315 | return {} 316 | 317 | 318 | def test_epoch_end(self, outputs: List[Dict]) -> None: 319 | test_unknown_metrics = self.test_unknown_metrics_wrapper.on_epoch_end() 320 | for k,v in test_unknown_metrics.items(): 321 | self.log(f'test/unknown_{k}',value=v, logger=True, on_step=False, on_epoch=True) 322 | 323 | self.test_unknown_metrics_wrapper.save(self.config.ckpt_dir) 324 | 325 | self.predictions_wrapper.on_epoch_end() 326 | self.predictions_wrapper.save(self.config.ckpt_dir) 327 | 328 | return 329 | 330 | 331 | 332 | 333 | def configure_optimizers(self): 334 | no_decay = ["bias", "LayerNorm.weight"] 335 | optimizer_grouped_parameters = [ 336 | { 337 | "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)], 338 | "weight_decay": self.config.weight_decay, 339 | }, 340 | {"params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 341 | ] 342 | 343 | 344 | 345 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.config.learning_rate, eps=self.config.adam_epsilon) 346 | 347 | if self.config.max_steps > 0: 348 | t_total = self.config.max_steps 349 | 
self.config.num_train_epochs = self.config.max_steps // self.train_len // self.config.accumulate_grad_batches + 1 350 | else: 351 | t_total = self.train_len // self.config.accumulate_grad_batches * self.config.num_train_epochs 352 | 353 | logger.info('{} training steps in total.. '.format(t_total)) 354 | 355 | 356 | # scheduler is called only once per epoch by default 357 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.config.warmup_steps, num_training_steps=t_total) 358 | scheduler_dict = { 359 | 'scheduler': scheduler, 360 | 'interval': 'step', 361 | 'name': 'linear-schedule', 362 | } 363 | 364 | return [optimizer, ], [scheduler_dict,] 365 | 366 | 367 | 368 | 369 | -------------------------------------------------------------------------------- /baselines/RSN_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Dict, Tuple, Optional, Union 4 | from collections import defaultdict 5 | import pickle as pkl 6 | import random 7 | from tqdm import tqdm 8 | 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import pytorch_lightning as pl 14 | 15 | from transformers import AutoConfig, AutoModel, AdamW, get_linear_schedule_with_warmup 16 | from torchmetrics import F1Score 17 | 18 | from common.metrics import PairwiseClusteringMetricsWrapper 19 | from common.predictions import PairwiseClusterPredictionsWrapper 20 | from .RSN import RSNLayer 21 | 22 | import common.log as log 23 | logger = log.get_logger('root') 24 | 25 | class RSNModel(pl.LightningModule): 26 | def __init__(self, args, tokenizer, train_len:int = 1000) -> None: 27 | super().__init__() 28 | self.config = args 29 | self.tokenizer = tokenizer 30 | self.model_config = AutoConfig.from_pretrained(args.model_name_or_path, output_hidden_states = True) 31 | self.pretrained_model = AutoModel.from_pretrained(args.model_name_or_path, config = 
self.model_config) 32 | embeddings = self.pretrained_model.resize_token_embeddings(len(self.tokenizer)) # when adding new tokens, the tokenizer.vocab_size is not changed! 33 | 34 | self.train_len=train_len 35 | 36 | self.rsn_head = RSNLayer(self.model_config.hidden_size, rel_dim=self.config.hidden_dim, use_cnn=self.config.use_cnn) 37 | 38 | self.f1_metric = F1Score(num_classes=1, threshold=0.5, multiclass=False) 39 | 40 | self.train_unknown_metrics_wrapper = PairwiseClusteringMetricsWrapper(stage='train', 41 | prefix='train_unknown', cluster_n=args.unknown_types, clustering_method=self.config.clustering_method) 42 | self.val_unknown_metrics_wrapper = PairwiseClusteringMetricsWrapper(stage='val', 43 | prefix='val_unknown', cluster_n=args.unknown_types, clustering_method=self.config.clustering_method) 44 | self.test_unknown_metrics_wrapper = PairwiseClusteringMetricsWrapper(stage='test', 45 | prefix='test_unknown', cluster_n=args.unknown_types, clustering_method=self.config.clustering_method) 46 | 47 | self.val_known_metrics_wrapper = PairwiseClusteringMetricsWrapper(stage='val', 48 | prefix='val_known', cluster_n=args.known_types, clustering_method=self.config.clustering_method) 49 | 50 | if args.eval_only: 51 | self.predictions_wrapper = PairwiseClusterPredictionsWrapper(prefix='test_unknown', 52 | task=args.task) 53 | 54 | # def get_v_adv_loss(self, ul_left_input, ul_right_input, p_mult, power_iterations=1): 55 | # bernoulli = tf.distributions.Bernoulli 56 | # prob, left_word_emb, right_word_emb = self(ul_left_input, ul_right_input)[0:3] 57 | # prob = tf.clip_by_value(prob, 1e-7, 1.-1e-7) 58 | # prob_dist = bernoulli(probs=prob) 59 | # #generate virtual adversarial perturbation 60 | # left_d = tf.random_uniform(shape=tf.shape(left_word_emb), dtype=tf.float32) 61 | # right_d = tf.random_uniform(shape=tf.shape(right_word_emb), dtype=tf.float32) 62 | # for _ in range(power_iterations): 63 | # left_d = (0.02) * tf.nn.l2_normalize(left_d, dim=1) 64 | # right_d = (0.02) 
* tf.nn.l2_normalize(right_d, dim=1) 65 | # p_prob = tf.clip_by_value(self(ul_left_input, ul_right_input, left_d, right_d)[0], 1e-7, 1.-1e-7) 66 | # kl = tf.distributions.kl_divergence(prob_dist, bernoulli(probs=p_prob), allow_nan_stats=False) 67 | # left_gradient,right_gradient = tf.gradients(kl, [left_d,right_d], 68 | # aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N) 69 | # left_d = tf.stop_gradient(left_gradient) 70 | # right_d = tf.stop_gradient(right_gradient) 71 | # left_d = p_mult * tf.nn.l2_normalize(left_d, dim=1) 72 | # right_d = p_mult * tf.nn.l2_normalize(right_d, dim=1) 73 | # tf.stop_gradient(prob) 74 | # #virtual adversarial loss 75 | # p_prob = tf.clip_by_value(self(ul_left_input, ul_right_input, left_d, right_d)[0], 1e-7, 1.-1e-7) 76 | # v_adv_losses = tf.distributions.kl_divergence(prob_dist, bernoulli(probs=p_prob), allow_nan_stats=False) 77 | # return tf.reduce_mean(v_adv_losses) 78 | 79 | 80 | def compute_vat_loss(self, head_spans, tail_spans, seq_output, 81 | perturb_scale:float=0.02, power_iterations: int=1): 82 | seq_output = seq_output.detach() 83 | prob = self.rsn_head(head_spans, tail_spans, seq_output) # B, B 84 | prob = torch.clamp(prob, min=1e-7, max=1.0-1e-7) 85 | prob_dist = torch.distributions.Bernoulli(probs=prob) 86 | prob = prob.detach() 87 | 88 | # generate perturbation 89 | d = torch.rand_like(seq_output, dtype=torch.float, requires_grad=True, device=self.device) 90 | for _ in range(power_iterations): 91 | d = perturb_scale * F.normalize(d, p=2, dim=2) 92 | p_prob = self.rsn_head(head_spans, tail_spans, seq_output, perturb=d) 93 | p_prob = torch.clamp(p_prob, min=1e-7, max=1.0-1e-7) 94 | kl = torch.distributions.kl_divergence(prob_dist, torch.distributions.Bernoulli(probs=p_prob)) 95 | kl = torch.mean(kl) 96 | kl.backward(inputs=d) 97 | 98 | d_grad = torch.clone(d.grad) 99 | d.grad.zero_() 100 | d = perturb_scale * F.normalize(d_grad, p=2, dim=2) 101 | 102 | p_prob = self.rsn_head(head_spans, tail_spans, 
seq_output, perturb=d) 103 | p_prob = torch.clamp(p_prob, min=1e-7, max=1.0-1e-7) 104 | vat_loss = torch.mean(torch.distributions.kl_divergence(prob_dist, torch.distributions.Bernoulli(probs=p_prob))) 105 | 106 | return vat_loss 107 | 108 | 109 | def training_step(self, batch: List[Dict[str, torch.Tensor]], batch_idx: int): 110 | ''' 111 | batch = { 112 | 'meta':List[Dict], 113 | 'token_ids': torch.LongTensor (batch, seq_len), 114 | 'attn_mask': torch.BoolTensor (batch, seq_len) 115 | 'labels': torch.LongTensor([x['label'] for x in batch]) 116 | 'head_spans': , 117 | 'tail_spans': , 118 | 'mask_bpe_idx': , 119 | 'known_mask' torch.BoolTensor 120 | } 121 | ''' 122 | view_n = len(batch) 123 | batch_size = len(batch[0]['meta']) 124 | labels = batch[0]['labels'] 125 | 126 | known_mask = batch[0]['known_mask'] # (batch, ) 127 | view = batch[0] 128 | outputs = self.pretrained_model(input_ids=view['token_ids'],attention_mask=view['attn_mask']) 129 | seq_output = outputs[0] 130 | 131 | head_spans = view['head_spans'] # (batch, 2) 132 | tail_spans = view['tail_spans'] # (batch,2 ) 133 | 134 | pair_predicted_known = self.rsn_head(head_spans[known_mask], tail_spans[known_mask], seq_output[known_mask]) # (B, B) 135 | # 0 for same class and 1 for different class 136 | pair_labels_known = (labels[known_mask].unsqueeze(1) != labels[known_mask].unsqueeze(0)).float() # (B, B) 137 | # supervised bce loss 138 | bce_loss = F.binary_cross_entropy(pair_predicted_known, pair_labels_known) 139 | 140 | if known_mask.long().min() < 1: # has unknown elements 141 | pair_predicted_unknown = self.rsn_head(head_spans[~known_mask], tail_spans[~known_mask],seq_output[~known_mask, :]) # (B, B) 142 | # unsupervised conditional entropy loss 143 | pair_predicted_unknown = torch.clamp(pair_predicted_unknown, min=1e-7, max=1-1e-7) 144 | cond_loss = F.binary_cross_entropy(pair_predicted_unknown, pair_predicted_unknown) 145 | else: 146 | cond_loss = 0.0 147 | 148 | if self.config.vat_loss_weight >0: 149 | 
sup_vat_loss = self.compute_vat_loss(head_spans[known_mask], tail_spans[known_mask], seq_output[known_mask], 150 | perturb_scale=self.config.perturb_scale, power_iterations=1) 151 | unsup_vat_loss= self.compute_vat_loss(head_spans[~known_mask], tail_spans[~known_mask], seq_output[~known_mask], 152 | perturb_scale=self.config.perturb_scale, power_iterations=1) 153 | else: 154 | sup_vat_loss = 0.0 155 | unsup_vat_loss = 0.0 156 | supervised_loss = bce_loss + self.config.vat_loss_weight * sup_vat_loss 157 | unsupervised_loss = cond_loss + self.config.vat_loss_weight * unsup_vat_loss 158 | loss = supervised_loss + self.config.p_cond * unsupervised_loss 159 | 160 | self.log('train/bce_loss', bce_loss) 161 | self.log('train/cond_ent_loss', cond_loss) 162 | self.log('train/loss', loss) 163 | 164 | 165 | return loss 166 | 167 | 168 | 169 | def validation_step(self, batch: List[Dict[str, torch.Tensor]], batch_idx: int, dataloader_idx: int)-> Dict[str, torch.Tensor]: 170 | ''' 171 | :param dataloader_idx: 0 for unknown_train, 1 for unknown_test, 2 for known_test 172 | ''' 173 | view_n = len(batch) 174 | batch_size = len(batch[0]['meta']) 175 | view = batch[0] 176 | known_mask = batch[0]['known_mask'] # (batch, ) 177 | labels = batch[0]['labels'] 178 | 179 | outputs = self.pretrained_model(input_ids=view['token_ids'],attention_mask=view['attn_mask']) 180 | seq_output = outputs[0] 181 | 182 | head_spans = view['head_spans'] # (batch, 2) 183 | tail_spans = view['tail_spans'] # (batch,2 ) 184 | rel_embed = self.rsn_head.embed(head_spans, tail_spans, seq_output) # (B, rel_dim) 185 | return { 186 | 'meta': batch[0]['meta'], 187 | 'labels': batch[0]['labels'], 188 | 'embed': rel_embed 189 | } 190 | 191 | # pair_predicted = self.rsn_head(head_spans, tail_spans, seq_output) 192 | # pair_labels = (labels.unsqueeze(1) != labels.unsqueeze(0)) # (B, B) 193 | 194 | # acc = torch.sum((pair_predicted > 0.5) == pair_labels)/ (batch_size * batch_size) 195 | 196 | # if dataloader_idx == 0: 197 
| # self.log('train/unknown_pair_acc', acc, on_epoch=True, add_dataloader_idx=False) 198 | # elif dataloader_idx ==1: 199 | # self.log('val/unknown_pair_acc', acc, on_epoch=True, add_dataloader_idx=False) 200 | # else: 201 | # self.log('val/known_pair_acc', acc, on_epoch=True, add_dataloader_idx=False) 202 | 203 | # return {} 204 | 205 | def validation_epoch_end(self, all_outputs: List[List[Dict]]) -> None: 206 | VAL_SAMPLE=1000 207 | for dataloader_idx, outputs in enumerate(all_outputs): 208 | # flatten list 209 | if dataloader_idx != 1: continue # only use val_unknown 210 | embeddings = {} # uid -> tensor 211 | metadata = {} 212 | for output in outputs: 213 | batch_size = len(output['meta']) 214 | for i in range(batch_size): 215 | uid = output['meta'][i]['uid'] 216 | embeddings[uid] = output['embed'][i] 217 | metadata[uid] = output['meta'][i] 218 | metadata[uid]['label_idx'] = output['labels'][i].item() 219 | # compute similarity 220 | distance = {} 221 | labels = {} 222 | seen = set() 223 | if len(embeddings) > VAL_SAMPLE: 224 | sampled_uids = random.sample(list(embeddings.keys()), k=VAL_SAMPLE) 225 | else: 226 | sampled_uids = embeddings.keys() 227 | 228 | 229 | for uid in sampled_uids: 230 | x1= embeddings[uid] 231 | labels[uid] = metadata[uid]['label_idx'] 232 | for uid_other in sampled_uids: 233 | x2 = embeddings[uid_other] 234 | if (uid, uid_other) not in seen: 235 | dis = self.rsn_head.compute_distance(x1, x2).item() 236 | distance[(uid, uid_other)] = dis 237 | distance[(uid_other, uid)] = dis 238 | seen.add((uid, uid_other)) 239 | seen.add((uid_other, uid)) 240 | 241 | sampled_labels = {k:labels[k] for k in sampled_uids} 242 | logger.info(f'running clustering on {len(labels)} data points...') 243 | val_unknown_metrics, _ = self.val_unknown_metrics_wrapper.on_epoch_end(distance, sampled_labels) 244 | 245 | for k,v in val_unknown_metrics.items(): 246 | self.log(f'val/unknown_{k}',value=v, logger=True, on_step=False, on_epoch=True) 247 | 248 | return 249 | 
250 | 251 | def test_step(self, batch: List[Dict], batch_idx: int) -> Dict: 252 | ''' 253 | :param dataloader_idx: 0 for unknown_test 254 | ''' 255 | view_n = len(batch) 256 | batch_size = len(batch[0]['meta']) 257 | view = batch[0] 258 | outputs = self.pretrained_model(input_ids=view['token_ids'],attention_mask=view['attn_mask']) 259 | seq_output = outputs[0] 260 | head_spans = view['head_spans'] # (batch, 2) 261 | tail_spans = view['tail_spans'] # (batch,2 ) 262 | rel_embed = self.rsn_head.embed(head_spans, tail_spans, seq_output) # (B, rel_dim) 263 | return { 264 | 'meta': batch[0]['meta'], 265 | 'labels': batch[0]['labels'], 266 | 'embed': rel_embed 267 | } 268 | 269 | 270 | def test_epoch_end(self, outputs: List[Dict]) -> None: 271 | # flatten list 272 | embeddings = {} # uid -> tensor 273 | metadata = {} 274 | for output in outputs: 275 | batch_size = len(output['meta']) 276 | for i in range(batch_size): 277 | uid = output['meta'][i]['uid'] 278 | embeddings[uid] = output['embed'][i] 279 | metadata[uid] = output['meta'][i] 280 | metadata[uid]['label_idx'] = output['labels'][i].item() 281 | # compute similarity 282 | distance = {} 283 | labels = {} 284 | seen = set() 285 | for uid in tqdm(embeddings): 286 | x1= embeddings[uid] 287 | labels[uid] = metadata[uid]['label_idx'] 288 | for uid_other in embeddings: 289 | x2 = embeddings[uid_other] 290 | if (uid, uid_other) not in seen: 291 | dis = self.rsn_head.compute_distance(x1, x2).item() 292 | distance[(uid, uid_other)] = dis 293 | distance[(uid_other, uid)] = dis 294 | seen.add((uid, uid_other)) 295 | seen.add((uid_other, uid)) 296 | 297 | logger.info(f'running clustering on {len(labels)} data points...') 298 | 299 | test_unknown_metrics, pred_cluster = self.test_unknown_metrics_wrapper.on_epoch_end(distance, labels) 300 | for k,v in test_unknown_metrics.items(): 301 | self.log(f'test/unknown_{k}',value=v, logger=True, on_step=False, on_epoch=True) 302 | 303 | 
self.test_unknown_metrics_wrapper.save(self.config.ckpt_dir) 304 | 305 | self.predictions_wrapper.on_epoch_end(pred_cluster, metadata) 306 | self.predictions_wrapper.save(self.config.ckpt_dir) 307 | 308 | return 309 | 310 | 311 | def configure_optimizers(self): 312 | no_decay = ["bias", "LayerNorm.weight"] 313 | optimizer_grouped_parameters = [ 314 | { 315 | "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)], 316 | "weight_decay": self.config.weight_decay, 317 | }, 318 | {"params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 319 | ] 320 | 321 | 322 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.config.learning_rate, eps=self.config.adam_epsilon) 323 | 324 | if self.config.max_steps > 0: 325 | t_total = self.config.max_steps 326 | self.config.num_train_epochs = self.config.max_steps // self.train_len // self.config.accumulate_grad_batches + 1 327 | else: 328 | t_total = self.train_len // self.config.accumulate_grad_batches * self.config.num_train_epochs 329 | 330 | logger.info('{} training steps in total.. 
'.format(t_total)) 331 | 332 | 333 | # scheduler is called only once per epoch by default 334 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.config.warmup_steps, num_training_steps=t_total) 335 | scheduler_dict = { 336 | 'scheduler': scheduler, 337 | 'interval': 'step', 338 | 'name': 'linear-schedule', 339 | } 340 | 341 | return [optimizer, ], [scheduler_dict,] 342 | 343 | 344 | 345 | -------------------------------------------------------------------------------- /baselines/latent_space_clustering.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle as pk 3 | import os 4 | from typing import List,Tuple 5 | 6 | import numpy as np 7 | from sklearn.cluster import KMeans 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn.parameter import Parameter 12 | from torch.optim import Adam 13 | from torch.utils.data import DataLoader, TensorDataset 14 | 15 | 16 | def cosine_dist(x_bar: torch.FloatTensor, x: torch.FloatTensor, weight=None): 17 | if weight is None: 18 | weight = torch.ones(x.size(0), device=x.device) 19 | cos_sim = (x_bar * x).sum(-1) 20 | cos_dist = 1 - cos_sim 21 | cos_dist = (cos_dist * weight).sum() / weight.sum() 22 | return cos_dist 23 | 24 | class AutoEncoder(nn.Module): 25 | ''' 26 | An autoencoder class with single input 27 | ''' 28 | 29 | def __init__(self, input_dim: int, hidden_dims: List[int]): 30 | ''' 31 | :param input_dim1: size of first input 32 | :param hidden_dims: list of dimension sizes 33 | ''' 34 | super().__init__() 35 | 36 | print("hidden_dims:", hidden_dims) 37 | self.encoder_layers = [] 38 | dims = [input_dim, ] + hidden_dims 39 | for i in range(len(dims) - 1): 40 | if i == 0: 41 | layer = nn.Sequential(nn.Linear(dims[i], dims[i+1]), nn.ReLU()) 42 | elif i != 0 and i < len(dims) - 2: 43 | layer = nn.Sequential(nn.Linear(dims[i], dims[i+1]), nn.ReLU()) 44 | else: 45 | layer = 
nn.Linear(dims[i], dims[i+1]) 46 | self.encoder_layers.append(layer) 47 | self.encoder = nn.Sequential(*self.encoder_layers) 48 | 49 | self.decoder_layers = [] 50 | hidden_dims.reverse() 51 | dims = hidden_dims + [input_dim,] 52 | for i in range(len(dims) - 1): 53 | if i < len(dims) - 2: 54 | layer = nn.Sequential(nn.Linear(dims[i], dims[i+1]), nn.ReLU()) 55 | else: 56 | layer = nn.Linear(dims[i], dims[i+1]) 57 | self.decoder_layers.append(layer) 58 | self.decoder = nn.Sequential(*self.decoder_layers) 59 | 60 | def forward(self, x: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 61 | z = self.encoder(x) 62 | x_bar = self.decoder(z) 63 | x_bar = F.normalize(x_bar, dim=-1) 64 | return x_bar, z 65 | 66 | 67 | 68 | class TwinAutoEncoder(nn.Module): 69 | ''' 70 | An autoencoder class with two inputs 71 | ''' 72 | 73 | def __init__(self, input_dim1: int, input_dim2: int, hidden_dims: List[int], agg: str, sep_decode: bool): 74 | ''' 75 | :param input_dim1: size of first input 76 | :param input_dim2: size of second input 77 | :param hidden_dims: list of dimension sizes 78 | :param agg: str, one of 'max','multi', 'sum', 'concat' 79 | :param sep_decode: bool, whether or not to decode the two inputs separately. 
80 | ''' 81 | super().__init__() 82 | 83 | self.agg = agg 84 | self.sep_decode = sep_decode 85 | 86 | print("hidden_dims:", hidden_dims) 87 | self.encoder_layers = [] 88 | self.encoder2_layers = [] 89 | dims = [[input_dim1, input_dim2]] + hidden_dims 90 | for i in range(len(dims) - 1): 91 | if i == 0: 92 | layer = nn.Sequential(nn.Linear(dims[i][0], dims[i+1]), nn.ReLU()) 93 | layer2 = nn.Sequential(nn.Linear(dims[i][1], dims[i+1]), nn.ReLU()) 94 | elif i != 0 and i < len(dims) - 2: 95 | layer = nn.Sequential(nn.Linear(dims[i], dims[i+1]), nn.ReLU()) 96 | layer2 = nn.Sequential(nn.Linear(dims[i], dims[i+1]), nn.ReLU()) 97 | else: 98 | layer = nn.Linear(dims[i], dims[i+1]) 99 | layer2 = nn.Linear(dims[i], dims[i+1]) 100 | self.encoder_layers.append(layer) 101 | self.encoder2_layers.append(layer2) 102 | self.encoder = nn.Sequential(*self.encoder_layers) 103 | self.encoder2 = nn.Sequential(*self.encoder2_layers) 104 | 105 | self.decoder_layers = [] 106 | self.decoder2_layers = [] 107 | hidden_dims.reverse() 108 | dims = hidden_dims + [[input_dim1, input_dim2]] 109 | if self.agg == "concat" and not self.sep_decode: 110 | dims[0] = 2 * dims[0] 111 | for i in range(len(dims) - 1): 112 | if i < len(dims) - 2: 113 | layer = nn.Sequential(nn.Linear(dims[i], dims[i+1]), nn.ReLU()) 114 | layer2 = nn.Sequential(nn.Linear(dims[i], dims[i+1]), nn.ReLU()) 115 | else: 116 | layer = nn.Linear(dims[i], dims[i+1][0]) 117 | layer2 = nn.Linear(dims[i], dims[i+1][1]) 118 | self.decoder_layers.append(layer) 119 | self.decoder2_layers.append(layer2) 120 | self.decoder = nn.Sequential(*self.decoder_layers) 121 | self.decoder2 = nn.Sequential(*self.decoder2_layers) 122 | 123 | def forward(self, x1: torch.FloatTensor, x2: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 124 | z1 = self.encoder(x1) 125 | z2 = self.encoder2(x2) 126 | 127 | if self.agg == "max": 128 | z = torch.max(z1, z2) 129 | elif self.agg == "multi": 130 | z = z1 * z2 131 | elif self.agg 
== "sum": 132 | z = z1 + z2 133 | elif self.agg == "concat": 134 | z = torch.cat([z1, z2], dim=1) 135 | 136 | if self.sep_decode: 137 | x_bar1 = self.decoder(z1) 138 | x_bar1 = F.normalize(x_bar1, dim=-1) 139 | x_bar2 = self.decoder2(z2) 140 | x_bar2 = F.normalize(x_bar2, dim=-1) 141 | else: 142 | x_bar1 = self.decoder(z) 143 | x_bar1 = F.normalize(x_bar1, dim=-1) 144 | x_bar2 = self.decoder2(z) 145 | x_bar2 = F.normalize(x_bar2, dim=-1) 146 | 147 | return x_bar1, x_bar2, z 148 | 149 | 150 | class TopicCluster(nn.Module): 151 | 152 | def __init__(self, args): 153 | super(TopicCluster, self).__init__() 154 | self.alpha = 1.0 155 | self.dataset_path = args.dataset_path 156 | self.args = args 157 | self.device = args.device 158 | self.temperature = args.temperature 159 | self.distribution = args.distribution 160 | self.agg_method = args.agg_method 161 | self.sep_decode = (args.sep_decode == 1) 162 | 163 | input_dim1 = args.input_dim1 164 | input_dim2 = args.input_dim2 165 | hidden_dims = eval(args.hidden_dims) 166 | self.model = TwinAutoEncoder(input_dim1, input_dim2, hidden_dims, self.agg_method, self.sep_decode) 167 | if self.agg_method == "concat": 168 | self.topic_emb = Parameter(torch.Tensor(args.n_clusters, 2*hidden_dims[-1])) 169 | else: 170 | self.topic_emb = Parameter(torch.Tensor(args.n_clusters, hidden_dims[-1])) 171 | torch.nn.init.xavier_normal_(self.topic_emb.data) 172 | 173 | def pretrain(self, input_data, pretrain_epoch=200): 174 | pretrained_path = os.path.join(self.dataset_path, f"pretrained_{args.suffix}.pt") 175 | if os.path.exists(pretrained_path) and self.args.load_pretrain: 176 | # load pretrain weights 177 | print(f"loading pretrained model from {pretrained_path}") 178 | self.model.load_state_dict(torch.load(pretrained_path)) 179 | else: 180 | train_loader = DataLoader(input_data, batch_size=self.args.batch_size, shuffle=True) 181 | optimizer = Adam(self.model.parameters(), lr=self.args.lr) 182 | for epoch in range(pretrain_epoch): 183 | 
total_loss = 0 184 | for batch_idx, (x1, x2, _, weight) in enumerate(train_loader): 185 | x1 = x1.to(self.device) 186 | x2 = x2.to(self.device) 187 | weight = weight.to(self.device) 188 | optimizer.zero_grad() 189 | x_bar1, x_bar2, z = self.model(x1, x2) 190 | loss = cosine_dist(x_bar1, x1) + cosine_dist(x_bar2, x2) #, weight) 191 | total_loss += loss.item() 192 | loss.backward() 193 | optimizer.step() 194 | print(f"epoch {epoch}: loss = {total_loss / (batch_idx+1):.4f}") 195 | torch.save(self.model.state_dict(), pretrained_path) 196 | print(f"model saved to {pretrained_path}") 197 | 198 | def cluster_assign(self, z): 199 | if self.distribution == 'student': 200 | p = 1.0 / (1.0 + torch.sum( 201 | torch.pow(z.unsqueeze(1) - self.topic_emb, 2), 2) / self.alpha) 202 | p = p.pow((self.alpha + 1.0) / 2.0) 203 | p = (p.t() / torch.sum(p, 1)).t() 204 | else: 205 | self.topic_emb.data = F.normalize(self.topic_emb.data, dim=-1) 206 | z = F.normalize(z, dim=-1) 207 | sim = torch.matmul(z, self.topic_emb.t()) / self.temperature 208 | p = F.softmax(sim, dim=-1) 209 | return p 210 | 211 | def forward(self, x1, x2): 212 | x_bar1, x_bar2, z = self.model(x1, x2) 213 | p = self.cluster_assign(z) 214 | return x_bar1, x_bar2, z, p 215 | 216 | def target_distribution(self, x1, x2, freq, method='all', top_num=0): 217 | _, _, z = self.model(x1, x2) 218 | p = self.cluster_assign(z).detach() 219 | if method == 'all': 220 | q = p**2 / (p * freq.unsqueeze(-1)).sum(dim=0) 221 | q = (q.t() / q.sum(dim=1)).t() 222 | elif method == 'top': 223 | assert top_num > 0 224 | q = p.clone() 225 | sim = torch.matmul(self.topic_emb, z.t()) 226 | _, selected_idx = sim.topk(k=top_num, dim=-1) 227 | for i, topic_idx in enumerate(selected_idx): 228 | q[topic_idx] = 0 229 | q[topic_idx, i] = 1 230 | return p, q 231 | 232 | 233 | def train(args, emb_dict): 234 | # ipdb.set_trace() 235 | inv_vocab = {k: " ".join(v) for k, v in emb_dict["inv_vocab"].items()} 236 | vocab = {" ".join(k):v for k, v in 
emb_dict["vocab"].items()} 237 | print(f"Vocab size: {len(vocab)}") 238 | embs = F.normalize(torch.tensor(emb_dict["vs_emb"]), dim=-1) 239 | embs2 = F.normalize(torch.tensor(emb_dict["oh_emb"]), dim=-1) 240 | freq = np.array(emb_dict["tuple_freq"]) 241 | if not args.use_freq: 242 | freq = np.ones_like(freq) 243 | 244 | input_data = TensorDataset(embs, embs2, torch.arange(embs.size(0)), torch.tensor(freq)) 245 | topic_cluster = TopicCluster(args).to(args.device) 246 | topic_cluster.pretrain(input_data, args.pretrain_epoch) 247 | train_loader = DataLoader(input_data, batch_size=args.batch_size, shuffle=False) 248 | optimizer = Adam(topic_cluster.parameters(), lr=args.lr) 249 | 250 | # topic embedding initialization 251 | embs = embs.to(args.device) 252 | embs2 = embs2.to(args.device) 253 | x_bar1, x_bar2, z = topic_cluster.model(embs, embs2) 254 | z = F.normalize(z, dim=-1) 255 | 256 | print(f"Running K-Means for initialization") 257 | kmeans = KMeans(n_clusters=args.n_clusters, n_init=5) 258 | if args.use_freq: 259 | y_pred = kmeans.fit_predict(z.data.cpu().numpy(), sample_weight=freq) 260 | else: 261 | y_pred = kmeans.fit_predict(z.data.cpu().numpy()) 262 | print(f"Finish K-Means") 263 | 264 | freq = torch.tensor(freq).to(args.device) 265 | 266 | y_pred_last = y_pred 267 | topic_cluster.topic_emb.data = torch.tensor(kmeans.cluster_centers_).to(args.device) 268 | 269 | topic_cluster.train() 270 | i = 0 271 | for epoch in range(50): 272 | if epoch % 5 == 0: 273 | # save results 274 | _, _, z, p = topic_cluster(embs, embs2) 275 | z = F.normalize(z, dim=-1) 276 | topic_cluster.topic_emb.data = F.normalize(topic_cluster.topic_emb.data, dim=-1) 277 | if not os.path.exists(os.path.join(args.dataset_path, f"clusters_{args.suffix}")): 278 | os.makedirs(os.path.join(args.dataset_path, f"clusters_{args.suffix}")) 279 | embed_save_path = os.path.join(args.dataset_path, f"clusters_{args.suffix}/embed_{epoch}.pt") 280 | torch.save({ 281 | "inv_vocab": emb_dict['inv_vocab'], 282 
| "embed": z.detach().cpu().numpy(), 283 | "topic_embed": topic_cluster.topic_emb.detach().cpu().numpy(), 284 | }, embed_save_path) 285 | f = open(os.path.join(args.dataset_path, f"clusters_{args.suffix}/{epoch}.txt"), 'w') 286 | pred_cluster = p.argmax(-1) 287 | result_strings = [] 288 | for j in range(args.n_clusters): 289 | if args.sort_method == 'discriminative': 290 | word_idx = torch.arange(embs.size(0))[pred_cluster == j] 291 | sorted_idx = torch.argsort(p[pred_cluster == j][:, j], descending=True) 292 | word_idx = word_idx[sorted_idx] 293 | else: 294 | sim = torch.matmul(topic_cluster.topic_emb[j], z.t()) 295 | _, word_idx = sim.topk(k=30, dim=-1) 296 | word_cluster = [] 297 | freq_sum = 0 298 | for idx in word_idx: 299 | freq_sum += freq[idx].item() 300 | if inv_vocab[idx.item()] not in word_cluster: 301 | word_cluster.append(inv_vocab[idx.item()]) 302 | if len(word_cluster) >= 10: 303 | break 304 | result_strings.append((freq_sum, f"Topic {j} ({freq_sum}): " + ', '.join(word_cluster)+'\n')) 305 | result_strings = sorted(result_strings, key=lambda x: x[0], reverse=True) 306 | for result_string in result_strings: 307 | f.write(result_string[1]) 308 | 309 | for x1, x2, idx, weight in train_loader: 310 | 311 | if i % args.update_interval == 0: 312 | p, q = topic_cluster.target_distribution(embs, embs2, freq.clone().fill_(1), method='all', top_num=epoch+1) 313 | 314 | y_pred = p.cpu().numpy().argmax(1) 315 | delta_label = np.sum(y_pred != y_pred_last).astype(np.float32) / y_pred.shape[0] 316 | y_pred_last = y_pred 317 | 318 | if i > 0 and delta_label < args.tol: 319 | print(f'delta_label {delta_label:.4f} < tol ({args.tol})') 320 | print('Reached tolerance threshold. 
Stopping training.') 321 | return None 322 | 323 | i += 1 324 | x1 = x1.to(args.device) 325 | x2 = x2.to(args.device) 326 | idx = idx.to(args.device) 327 | weight = weight.to(args.device) 328 | 329 | x_bar1, x_bar2, _, p = topic_cluster(x1, x2) 330 | reconstr_loss = cosine_dist(x_bar1, x1) + cosine_dist(x_bar2, x2) #, weight) 331 | kl_loss = F.kl_div(p.log(), q[idx], reduction='none').sum(-1) 332 | kl_loss = (kl_loss * weight).sum() / weight.sum() 333 | loss = args.gamma * kl_loss + reconstr_loss 334 | if i % args.update_interval == 0: 335 | print(f"KL loss: {kl_loss}; Reconstruction loss: {reconstr_loss}") 336 | 337 | optimizer.zero_grad() 338 | loss.backward() 339 | optimizer.step() 340 | return None 341 | 342 | 343 | 344 | if __name__ == "__main__": 345 | # CUDA_VISIBLE_DEVICES=0 python3 latent_space_clustering.py --dataset_path ./pandemic --input_emb_name po_tuple_features_all_svos.pk 346 | parser = argparse.ArgumentParser( 347 | description='train', 348 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 349 | 350 | parser.add_argument('--dataset_path', type=str) 351 | parser.add_argument('--input_emb_name', type=str) 352 | parser.add_argument('--lr', type=float, default=5e-4) 353 | parser.add_argument('--n_clusters', default=30, type=int) 354 | parser.add_argument('--input_dim1', default=1000, type=int) 355 | parser.add_argument('--input_dim2', default=1000, type=int) 356 | parser.add_argument('--agg_method', default="multi", choices=["sum", "multi", "concat", "attend"], type=str) 357 | parser.add_argument('--sep_decode', default=0, choices=[0, 1], type=int) 358 | parser.add_argument('--pretrain_epoch', default=100, type=int) 359 | parser.add_argument('--load_pretrain', default=False, action='store_true') 360 | parser.add_argument('--temperature', default=0.1, type=float) 361 | parser.add_argument('--sort_method', default='generative', choices=['generative', 'discriminative']) 362 | parser.add_argument('--distribution', default='softmax', 
choices=['softmax', 'student']) 363 | parser.add_argument('--batch_size', default=256, type=int) 364 | parser.add_argument('--use_freq', default=False, action='store_true') 365 | parser.add_argument('--hidden_dims', default='[1000, 2000, 1000, 100]', type=str) 366 | parser.add_argument('--suffix', type=str, default='') 367 | parser.add_argument('--gamma', default=5, type=float, help='weight of clustering loss') 368 | parser.add_argument('--update_interval', default=100, type=int) 369 | parser.add_argument('--tol', default=0.001, type=float) 370 | args = parser.parse_args() 371 | args.cuda = torch.cuda.is_available() 372 | print("use cuda: {}".format(args.cuda)) 373 | args.device = torch.device("cuda" if args.cuda else "cpu") 374 | print(args) 375 | with open(os.path.join(args.dataset_path, args.input_emb_name), "rb") as fin: 376 | emb_dict = pk.load(fin) 377 | 378 | candidate_idx = train(args, emb_dict) 379 | print(candidate_idx) --------------------------------------------------------------------------------