├── .gitignore ├── run_waypoint.bash ├── gen_training_data ├── config.yaml ├── get_images_inputs.py ├── get_nav_dict.py ├── test_twm0.2_obstacle_first.py └── utils.py ├── README.md ├── ddppo_resnet ├── resnet_policy.py ├── running_mean_and_var.py └── resnet.py ├── image_encoders.py ├── dataloader.py ├── TRM_net.py ├── utils.py ├── transformer ├── waypoint_bert.py └── pytorch_transformer │ ├── file_utils.py │ ├── modeling_utils.py │ └── modeling_bert.py ├── eval.py └── waypoint_predictor.py /.gitignore: -------------------------------------------------------------------------------- 1 | training_data 2 | gen_training_data/nav_dicts 3 | gen_training_data/raw_graphs 4 | __pycache__ 5 | checkpoints -------------------------------------------------------------------------------- /run_waypoint.bash: -------------------------------------------------------------------------------- 1 | 2 | flag="--EXP_ID wp-train 3 | 4 | --TRAINEVAL train 5 | --VIS 0 6 | 7 | --ANGLES 120 8 | --NUM_IMGS 12 9 | 10 | --EPOCH 300 11 | --BATCH_SIZE 8 12 | --LEARNING_RATE 1e-6 13 | 14 | --WEIGHT 0 15 | 16 | --TRM_LAYER 2 17 | --TRM_NEIGHBOR 1 18 | --HEATMAP_OFFSET 5 19 | --HIDDEN_DIM 768" 20 | 21 | python waypoint_predictor.py $flag 22 | -------------------------------------------------------------------------------- /gen_training_data/config.yaml: -------------------------------------------------------------------------------- 1 | ENVIRONMENT: 2 | MAX_EPISODE_STEPS: 500 3 | SIMULATOR: 4 | ACTION_SPACE_CONFIG: v0 5 | AGENT_0: 6 | SENSORS: [RGB_SENSOR, DEPTH_SENSOR] 7 | FORWARD_STEP_SIZE: 0.25 8 | TURN_ANGLE: 15 9 | HABITAT_SIM_V0: 10 | GPU_DEVICE_ID: 0 11 | ALLOW_SLIDING: True 12 | RGB_SENSOR: 13 | WIDTH: 224 14 | HEIGHT: 224 15 | HFOV: 90 16 | TYPE: HabitatSimRGBSensor 17 | DEPTH_SENSOR: 18 | WIDTH: 256 # pretrained DDPPO resnet needs 256x256 19 | HEIGHT: 256 20 | TASK: 21 | TYPE: VLN-v0 22 | SUCCESS_DISTANCE: 3.0 23 | SENSORS: [ 24 | INSTRUCTION_SENSOR, 25 | SHORTEST_PATH_SENSOR, 26 | VLN_ORACLE_PROGRESS_SENSOR 27 | ] 28 | INSTRUCTION_SENSOR_UUID: instruction 29 | POSSIBLE_ACTIONS: [STOP, MOVE_FORWARD, TURN_LEFT, TURN_RIGHT] 30 | MEASUREMENTS: [ 31 | DISTANCE_TO_GOAL, 32 | SUCCESS, 33 | SPL, 34 | NDTW, 35 | PATH_LENGTH, 36 | ORACLE_SUCCESS, 37 | STEPS_TAKEN 38 | ] 39 | SUCCESS: 40 | SUCCESS_DISTANCE: 3.0 41 | SPL: 42 | SUCCESS_DISTANCE: 3.0 43 | NDTW: 44 | SUCCESS_DISTANCE: 3.0 45 | GT_PATH: data/datasets/R2R_VLNCE_v1-2_preprocessed/{split}/{split}_gt.json.gz 46 | SDTW: 47 | SUCCESS_DISTANCE: 3.0 48 | GT_PATH: data/datasets/R2R_VLNCE_v1-2_preprocessed/{split}/{split}_gt.json.gz 49 | ORACLE_SUCCESS: 50 | SUCCESS_DISTANCE: 3.0 51 | DATASET: 52 | TYPE: VLN-CE-v1 53 | SPLIT: train 54 | DATA_PATH: data/datasets/R2R_VLNCE_v1-2_preprocessed/{split}/{split}.json.gz 55 | SCENES_DIR: data/scene_datasets/ 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Waypoint Predictor Training for Discrete-Continuous-VLN 2 | 3 | ## Prerequisites 4 | 5 | 1. Please follow [Discrete-Continuous-VLN](https://github.com/YicongHong/Discrete-Continuous-VLN) to set up your environments, prepare scene dataset of MP3D, download the adapted mp3d connectivity graphs, and the pretrained ddppo ResNet encoder. Data and model path should be similar to Discrete-Continuous VLN. Download the adapted mp3d graphs from [here](https://drive.google.com/drive/folders/1wpuGAO-rRalPKt8m1-QIvlb_Pv1rYJ4x?usp=sharing). 6 | 7 | 2. 
Change the data path `/home/vlnce/vln-ce/data/` in the codes to your above data path. 8 | 3. Change the `RAW_GRAPH_PATH` in the codes to your unzipped adapted mp3d connectivity graphs. 9 | 10 | ## Preparing Training Data 11 | 12 | 1. Run `gen_training_data/get_images_inputs.py` to get the RGBD inputs of the waypoint predictor, which will be saved at `training_data/rgbd_fov90`. 13 | 2. Run `gen_training_data/get_nav_dict.py` to get the computed navigability dict of each node, which will be saved at `gen_training_data/nav_dicts`. 14 | 3. Run `gen_training_data/test_twm0.2_obstacle_first.py` to get the direct training data for training waypoint predictor, which will be saved at `training_data`. 15 | 16 | ## Running 17 | 18 | ### Training and Evaluation 19 | 20 | Please run `bash run_waypoint.bash` to train the waypoint predictor. If you only want to evaluate trained model, change `--TRAINEVAL` to `eval`. Modify the `checkpoint_load_path` in `waypoint_predictor.py` to evaluate different models. 21 | 22 | 23 | ## Citation 24 | Please cite our paper: 25 | ``` 26 | @InProceedings{Hong_2022_CVPR, 27 | author = {Hong, Yicong and Wang, Zun and Wu, Qi and Gould, Stephen}, 28 | title = {Bridging the Gap Between Learning in Discrete and Continuous Environments for Vision-and-Language Navigation}, 29 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 30 | month = {June}, 31 | year = {2022} 32 | } 33 | ``` 34 | -------------------------------------------------------------------------------- /ddppo_resnet/resnet_policy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
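As a quick sanity check when wiring the encoder defined below into TRM_net.py, the following sketch (assuming the habitat_baselines install from the README prerequisites) feeds a random depth batch through PNResnetDepthEncoder and prints the feature shape that visual_fc_depth expects:

import torch
from ddppo_resnet.resnet_policy import PNResnetDepthEncoder

encoder = PNResnetDepthEncoder()        # DD-PPO depth backbone defined below
depth = torch.rand(2, 256, 256, 1)      # [BATCH x HEIGHT x WIDTH x CHANNEL]
with torch.no_grad():
    feats = encoder(depth)
print(feats.shape)                      # torch.Size([2, 128, 4, 4]); TRM_net.py flattens np.prod([128, 4, 4])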
6 | 7 | 8 | from typing import Dict, Tuple 9 | 10 | import numpy as np 11 | import torch 12 | from torch import nn as nn 13 | from torch.nn import functional as F 14 | 15 | from habitat_baselines.rl.ddppo.policy import resnet 16 | 17 | 18 | class PNResnetDepthEncoder(nn.Module): 19 | def __init__( 20 | self, 21 | baseplanes: int = 32, 22 | ngroups: int = 16, 23 | spatial_size: int = 128, 24 | make_backbone=getattr(resnet, 'resnet50'), 25 | ): 26 | super().__init__() 27 | 28 | self._n_input_depth = 1 # observation_space.spaces["depth"].shape[2] 29 | spatial_size = 256 // 2 # observation_space.spaces["depth"].shape[0] 30 | 31 | self.running_mean_and_var = nn.Sequential() 32 | 33 | input_channels = self._n_input_depth 34 | self.backbone = make_backbone(input_channels, baseplanes, ngroups) 35 | 36 | final_spatial = int( 37 | spatial_size * self.backbone.final_spatial_compress 38 | ) 39 | after_compression_flat_size = 2048 40 | num_compression_channels = int( 41 | round(after_compression_flat_size / (final_spatial ** 2)) 42 | ) 43 | self.compression = nn.Sequential( 44 | nn.Conv2d( 45 | self.backbone.final_channels, 46 | num_compression_channels, 47 | kernel_size=3, 48 | padding=1, 49 | bias=False, 50 | ), 51 | nn.GroupNorm(1, num_compression_channels), 52 | nn.ReLU(True), 53 | ) 54 | 55 | def layer_init(self): 56 | for layer in self.modules(): 57 | if isinstance(layer, (nn.Conv2d, nn.Linear)): 58 | nn.init.kaiming_normal_( 59 | layer.weight, nn.init.calculate_gain("relu") 60 | ) 61 | if layer.bias is not None: 62 | nn.init.constant_(layer.bias, val=0) 63 | 64 | def forward(self, depth_observations): 65 | cnn_input = [] 66 | 67 | if self._n_input_depth > 0: 68 | # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] 69 | depth_observations = depth_observations.permute(0, 3, 1, 2) 70 | 71 | cnn_input.append(depth_observations) 72 | 73 | x = torch.cat(cnn_input, dim=1) 74 | x = F.avg_pool2d(x, 2) 75 | 76 | x = self.running_mean_and_var(x) 77 | x = self.backbone(x) 78 | x = self.compression(x) 79 | return x 80 | -------------------------------------------------------------------------------- /image_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import numpy as np 5 | 6 | from ddppo_resnet.resnet_policy import PNResnetDepthEncoder 7 | 8 | class RGBEncoder(nn.Module): 9 | def __init__(self, resnet_pretrain=True, trainable=False): 10 | super(RGBEncoder, self).__init__() 11 | if resnet_pretrain: 12 | print('\nLoading Torchvision pre-trained Resnet50 for RGB ...') 13 | rgb_resnet = torchvision.models.resnet50(pretrained=resnet_pretrain) 14 | rgb_modules = list(rgb_resnet.children())[:-2] 15 | rgb_net = torch.nn.Sequential(*rgb_modules) 16 | self.rgb_net = rgb_net 17 | for param in self.rgb_net.parameters(): 18 | param.requires_grad_(trainable) 19 | 20 | # self.scale = 0.5 21 | 22 | def forward(self, rgb_imgs): 23 | rgb_shape = rgb_imgs.size() 24 | rgb_imgs = rgb_imgs.reshape(rgb_shape[0]*rgb_shape[1], 25 | rgb_shape[2], rgb_shape[3], rgb_shape[4]) 26 | rgb_feats = self.rgb_net(rgb_imgs) # * self.scale 27 | 28 | # print('rgb_imgs', rgb_imgs.shape) 29 | # print('rgb_feats', rgb_feats.shape) 30 | 31 | return rgb_feats.squeeze() 32 | 33 | 34 | class DepthEncoder(nn.Module): 35 | def __init__(self, resnet_pretrain=True, trainable=False): 36 | super(DepthEncoder, self).__init__() 37 | 38 | self.depth_net = PNResnetDepthEncoder() 39 | if resnet_pretrain: 40 | print('Loading PointNav 
pre-trained Resnet50 for Depth ...') 41 | ddppo_pn_depth_encoder_weights = torch.load('/home/vlnce/vln-ce/data/ddppo-models/gibson-2plus-resnet50.pth') 42 | weights_dict = {} 43 | for k, v in ddppo_pn_depth_encoder_weights["state_dict"].items(): 44 | split_layer_name = k.split(".")[2:] 45 | if split_layer_name[0] != "visual_encoder": 46 | continue 47 | layer_name = ".".join(split_layer_name[1:]) 48 | weights_dict[layer_name] = v 49 | del ddppo_pn_depth_encoder_weights 50 | self.depth_net.load_state_dict(weights_dict, strict=True) 51 | for param in self.depth_net.parameters(): 52 | param.requires_grad_(trainable) 53 | 54 | def forward(self, depth_imgs): 55 | depth_shape = depth_imgs.size() 56 | depth_imgs = depth_imgs.reshape(depth_shape[0]*depth_shape[1], 57 | depth_shape[2], depth_shape[3], depth_shape[4]) 58 | depth_feats = self.depth_net(depth_imgs) 59 | 60 | # print('depth_imgs', depth_imgs.shape) 61 | # print('depth_feats', depth_feats.shape) 62 | # 63 | # import pdb; pdb.set_trace() 64 | 65 | return depth_feats 66 | -------------------------------------------------------------------------------- /gen_training_data/get_images_inputs.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import utils 4 | import habitat 5 | import os 6 | import pickle 7 | from habitat.sims import make_sim 8 | 9 | 10 | config_path = './gen_training_data/config.yaml' 11 | scene_path = '/home/vlnce/vln-ce/data/scene_datasets/mp3d/{scan}/{scan}.glb' 12 | image_path = './training_data/rgbd_fov90/' 13 | save_path = os.path.join(image_path,'{split}/{scan}/{scan}_{node}_mp3d_imgs.pkl') 14 | RAW_GRAPH_PATH= '/home/vlnce/habitat_connectivity_graph/%s.json' 15 | NUMBER = 12 16 | 17 | SPLIT = 'train' 18 | 19 | with open(RAW_GRAPH_PATH%SPLIT, 'r') as f: 20 | raw_graph_data = json.load(f) 21 | 22 | nav_dict = {} 23 | total_invalids = 0 24 | total = 0 25 | 26 | for scene, data in raw_graph_data.items(): 27 | ''' connectivity dictionary ''' 28 | connect_dict = {} 29 | for edge_id, edge_info in data['edges'].items(): 30 | node_a = edge_info['nodes'][0] 31 | node_b = edge_info['nodes'][1] 32 | 33 | if node_a not in connect_dict: 34 | connect_dict[node_a] = [node_b] 35 | else: 36 | connect_dict[node_a].append(node_b) 37 | if node_b not in connect_dict: 38 | connect_dict[node_b] = [node_a] 39 | else: 40 | connect_dict[node_b].append(node_a) 41 | 42 | '''make sim for obstacle checking''' 43 | config = habitat.get_config(config_path) 44 | config.defrost() 45 | # config.TASK.POSSIBLE_ACTIONS = ['STOP', 'MOVE_FORWARD', 'TURN_LEFT', 'TURN_RIGHT', 'FORWARD_BY_DIS'] 46 | config.TASK.SENSORS = [] 47 | config.SIMULATOR.FORWARD_STEP_SIZE = 0.25 48 | config.SIMULATOR.HABITAT_SIM_V0.ALLOW_SLIDING = False 49 | config.SIMULATOR.SCENE = scene_path.format(scan=scene) 50 | sim = make_sim(id_sim=config.SIMULATOR.TYPE, config=config.SIMULATOR) 51 | 52 | '''save images''' 53 | if not os.path.exists(image_path+'{split}/{scan}'.format(split=SPLIT,scan=scene)): 54 | os.makedirs(image_path+'{split}/{scan}'.format(split=SPLIT,scan=scene)) 55 | navigability_dict = {} 56 | 57 | i = 0 58 | for node_a, neighbors in connect_dict.items(): 59 | navigability_dict[node_a] = utils.init_single_node_dict(number=NUMBER) 60 | rgbs = [] 61 | depths = [] 62 | node_a_pos = np.array(data['nodes'][node_a])[[0, 2]] 63 | 64 | habitat_pos = np.array(data['nodes'][node_a]) 65 | for info in navigability_dict[node_a].values(): 66 | position, heading = habitat_pos, info['heading'] 67 | theta = 
-(heading - np.pi) / 2 68 | rotation = np.quaternion(np.cos(theta), 0, np.sin(theta), 0) 69 | obs = sim.get_observations_at(position, rotation) 70 | rgbs.append(obs['rgb']) 71 | depths.append(obs['depth']) 72 | with open(save_path.format(split=SPLIT, scan=scene, node=node_a), 'wb') as f: 73 | pickle.dump({'rgb': np.array(rgbs), 74 | 'depth': np.array(depths, dtype=np.float16)}, f) 75 | utils.print_progress(i+1,total) 76 | i+=1 77 | 78 | sim.close() 79 | -------------------------------------------------------------------------------- /ddppo_resnet/running_mean_and_var.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor 9 | from torch import distributed as distrib 10 | from torch import nn as nn 11 | 12 | 13 | class RunningMeanAndVar(nn.Module): 14 | def __init__(self, n_channels: int) -> None: 15 | super().__init__() 16 | self.register_buffer("_mean", torch.zeros(1, n_channels, 1, 1)) 17 | self.register_buffer("_var", torch.zeros(1, n_channels, 1, 1)) 18 | self.register_buffer("_count", torch.zeros(())) 19 | self._mean: torch.Tensor = self._mean 20 | self._var: torch.Tensor = self._var 21 | self._count: torch.Tensor = self._count 22 | 23 | def forward(self, x: Tensor) -> Tensor: 24 | if self.training: 25 | n = x.size(0) 26 | # We will need to do reductions (mean) over the channel dimension, 27 | # so moving channels to the first dimension and then flattening 28 | # will make those faster. Further, it makes things more numerically stable 29 | # for fp16 since it is done in a single reduction call instead of 30 | # multiple 31 | x_channels_first = ( 32 | x.transpose(1, 0).contiguous().view(x.size(1), -1) 33 | ) 34 | new_mean = x_channels_first.mean(-1, keepdim=True) 35 | new_count = torch.full_like(self._count, n) 36 | 37 | if distrib.is_initialized(): 38 | distrib.all_reduce(new_mean) 39 | distrib.all_reduce(new_count) 40 | new_mean /= distrib.get_world_size() 41 | 42 | new_var = ( 43 | (x_channels_first - new_mean).pow(2).mean(dim=-1, keepdim=True) 44 | ) 45 | 46 | if distrib.is_initialized(): 47 | distrib.all_reduce(new_var) 48 | new_var /= distrib.get_world_size() 49 | 50 | new_mean = new_mean.view(1, -1, 1, 1) 51 | new_var = new_var.view(1, -1, 1, 1) 52 | 53 | m_a = self._var * (self._count) 54 | m_b = new_var * (new_count) 55 | M2 = ( 56 | m_a 57 | + m_b 58 | + (new_mean - self._mean).pow(2) 59 | * self._count 60 | * new_count 61 | / (self._count + new_count) 62 | ) 63 | 64 | self._var = M2 / (self._count + new_count) 65 | self._mean = (self._count * self._mean + new_count * new_mean) / ( 66 | self._count + new_count 67 | ) 68 | 69 | self._count += new_count 70 | 71 | inv_stdev = torch.rsqrt( 72 | torch.max(self._var, torch.full_like(self._var, 1e-2)) 73 | ) 74 | # This is the same as 75 | # (x - self._mean) * inv_stdev but is faster since it can 76 | # make use of addcmul and is more numerically stable in fp16 77 | return torch.addcmul(-self._mean * inv_stdev, x, inv_stdev) 78 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import numpy as np 4 | from PIL import Image 5 | import pickle as pkl 6 | 7 | import torch 8 | from torchvision 
import transforms 9 | from torch.utils.data import Dataset, DataLoader 10 | 11 | # dataloader and transforms 12 | class RGBDepthPano(Dataset): 13 | def __init__(self, args, img_dir, navigability_dict): 14 | # self.IMG_WIDTH = 256 15 | # self.IMG_HEIGHT = 256 16 | self.RGB_INPUT_DIM = 224 17 | self.DEPTH_INPUT_DIM = 256 18 | self.NUM_IMGS = args.NUM_IMGS 19 | self.navigability_dict = navigability_dict 20 | 21 | self.rgb_transform = torch.nn.Sequential( 22 | # [transforms.Resize((256,341)), 23 | # transforms.CenterCrop(self.RGB_INPUT_DIM), 24 | # transforms.ToTensor(),] 25 | transforms.ConvertImageDtype(torch.float), 26 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 27 | ) 28 | # self.depth_transform = transforms.Compose( 29 | # # [transforms.Resize((self.DEPTH_INPUT_DIM, self.DEPTH_INPUT_DIM)), 30 | # [transforms.ToTensor(), 31 | # ]) 32 | 33 | self.img_dirs = glob.glob(img_dir) 34 | 35 | for img_dir in glob.glob(img_dir): 36 | scan_id = img_dir.split('/')[-1][:11] 37 | waypoint_id = img_dir.split('/')[-1][12:-14] 38 | if waypoint_id not in self.navigability_dict[scan_id]: 39 | self.img_dirs.remove(img_dir) 40 | 41 | def __len__(self): # default name when writing class 42 | return len(self.img_dirs) 43 | 44 | def __getitem__(self, idx): # default name when writing class 45 | 46 | img_dir = self.img_dirs[idx] 47 | sample_id = str(idx) 48 | scan_id = img_dir.split('/')[-1][:11] 49 | waypoint_id = img_dir.split('/')[-1][12:-14] 50 | 51 | ''' rgb and depth images ''' 52 | rgb_depth_img = pkl.load(open(img_dir, "rb")) 53 | rgb_img = torch.from_numpy(rgb_depth_img['rgb']).permute(0, 3, 1, 2) 54 | depth_img = torch.from_numpy(rgb_depth_img['depth']).permute(0, 3, 1, 2) 55 | 56 | # 3 should be the last channel 57 | trans_rgb_imgs = torch.zeros(self.NUM_IMGS, 3, self.RGB_INPUT_DIM, self.RGB_INPUT_DIM) 58 | trans_depth_imgs = torch.zeros(self.NUM_IMGS, self.DEPTH_INPUT_DIM, self.DEPTH_INPUT_DIM) 59 | 60 | no_trans_rgb = torch.zeros(self.NUM_IMGS, 3, self.RGB_INPUT_DIM, self.RGB_INPUT_DIM, dtype=torch.uint8) 61 | no_trans_depth = torch.zeros(self.NUM_IMGS, self.DEPTH_INPUT_DIM, self.DEPTH_INPUT_DIM) 62 | 63 | for ix in range(self.NUM_IMGS): 64 | trans_rgb_imgs[ix] = self.rgb_transform(rgb_img[ix]) 65 | # no_trans_rgb[ix] = rgb_img[ix] 66 | trans_depth_imgs[ix] = depth_img[ix][0] 67 | # no_trans_depth[ix] = depth_img[ix][0] 68 | 69 | sample = {'sample_id': sample_id, 70 | 'scan_id': scan_id, 71 | 'waypoint_id': waypoint_id, 72 | 'rgb': trans_rgb_imgs, 73 | 'depth': trans_depth_imgs.unsqueeze(-1), 74 | # 'no_trans_rgb': no_trans_rgb, 75 | # 'no_trans_depth': no_trans_depth, 76 | } 77 | 78 | # print('------------------------') 79 | # print(trans_rgb_imgs[0][0]) 80 | # print(rgb_img[0].shape, rgb_img[0]) 81 | # anivlrb 82 | 83 | return sample 84 | -------------------------------------------------------------------------------- /gen_training_data/get_nav_dict.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import utils 4 | import habitat 5 | from habitat.sims import make_sim 6 | from utils import Simulator 7 | 8 | config_path = 'gen_training_data/config.yaml' 9 | scene_path = '/home/vlnce/vln-ce/data/scene_datasets/mp3d/{scan}/{scan}.glb' 10 | RAW_GRAPH_PATH= '/home/vlnce/habitat_connectivity_graph/%s.json' 11 | NUMBER = 120 12 | 13 | SPLIT = 'val_unseen' 14 | 15 | with open(RAW_GRAPH_PATH%SPLIT, 'r') as f: 16 | raw_graph_data = json.load(f) 17 | 18 | nav_dict = {} 19 | total_invalids = 0 20 | total = 0 
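The loop below turns each graph edge into an (angle sector, distance bin) pair via utils.edge_vec_to_indexes, whose source is not included here. A minimal sketch of that discretisation, assuming 2*pi/NUMBER angle sectors and twelve 0.25 m distance bins (the convention eval.py inverts with angle = index * 2*pi / ANGLES and distance = (bin + 1) * 0.25); the helper name below is hypothetical:

import numpy as np

def edge_vec_to_indexes_sketch(edge_vec, number=120):
    # heading measured from +z towards +x, matching how eval.py reprojects candidates
    angle = float(np.arctan2(edge_vec[0], edge_vec[1])) % (2 * np.pi)
    angle_index = int(angle // (2 * np.pi / number))
    distance = float(np.linalg.norm(edge_vec))
    distance_index = int(round(distance / 0.25)) - 1   # twelve 0.25 m bins
    if distance_index < 0 or distance_index > 11:
        distance_index = -1                            # too close or too far
    return angle, angle_index, distance, distance_index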
21 | 22 | for scene, data in raw_graph_data.items(): 23 | ''' connectivity dictionary ''' 24 | connect_dict = {} 25 | for edge_id, edge_info in data['edges'].items(): 26 | node_a = edge_info['nodes'][0] 27 | node_b = edge_info['nodes'][1] 28 | 29 | if node_a not in connect_dict: 30 | connect_dict[node_a] = [node_b] 31 | else: 32 | connect_dict[node_a].append(node_b) 33 | if node_b not in connect_dict: 34 | connect_dict[node_b] = [node_a] 35 | else: 36 | connect_dict[node_b].append(node_a) 37 | 38 | 39 | '''make sim for obstacle checking''' 40 | config = habitat.get_config(config_path) 41 | config.defrost() 42 | # config.TASK.POSSIBLE_ACTIONS = ['STOP', 'MOVE_FORWARD', 'TURN_LEFT', 'TURN_RIGHT', 'FORWARD_BY_DIS'] 43 | # config.SIMULATOR.AGENT_0.SENSORS = [] 44 | config.SIMULATOR.FORWARD_STEP_SIZE = 0.25 45 | config.SIMULATOR.HABITAT_SIM_V0.ALLOW_SLIDING = False 46 | config.SIMULATOR.TYPE = 'Sim-v1' 47 | config.SIMULATOR.SCENE = scene_path.format(scan=scene) 48 | sim = make_sim(id_sim=config.SIMULATOR.TYPE, config=config.SIMULATOR) 49 | 50 | ''' process each node to standard data format ''' 51 | navigability_dict = {} 52 | total = len(connect_dict) 53 | for i, pair in enumerate(connect_dict.items()): 54 | node_a, neighbors = pair 55 | navigability_dict[node_a] = utils.init_single_node_dict(number=NUMBER) 56 | node_a_pos = np.array(data['nodes'][node_a])[[0,2]] 57 | 58 | habitat_pos = np.array(data['nodes'][node_a]) 59 | for id, info in navigability_dict[node_a].items(): 60 | obstacle_distance, obstacle_index = utils.get_obstacle_info(habitat_pos,info['heading'],sim) 61 | info['obstacle_distance'] = obstacle_distance 62 | info['obstacle_index'] = obstacle_index 63 | 64 | for node_b in neighbors: 65 | node_b_pos = np.array(data['nodes'][node_b])[[0,2]] 66 | 67 | edge_vec = (node_b_pos - node_a_pos) 68 | angle, angleIndex, distance, distanceIndex = utils.edge_vec_to_indexes(edge_vec,number=NUMBER) 69 | 70 | navigability_dict[node_a][str(angleIndex)]['has_waypoint'] = True 71 | navigability_dict[node_a][str(angleIndex)]['waypoint'].append( 72 | { 73 | 'node_id': node_b, 74 | 'position': node_b_pos.tolist(), 75 | 'angle': angle, 76 | 'angleIndex': angleIndex, 77 | 'distance': distance, 78 | 'distanceIndex': distanceIndex, 79 | }) 80 | utils.print_progress(i+1,total) 81 | 82 | nav_dict[scene] = navigability_dict 83 | sim.close() 84 | 85 | output_path = './gen_training_data/nav_dicts/navigability_dict_%s.json'%SPLIT 86 | with open(output_path, 'w') as fo: 87 | json.dump(nav_dict, fo, ensure_ascii=False, indent=4) 88 | -------------------------------------------------------------------------------- /TRM_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import utils 5 | 6 | from transformer.waypoint_bert import WaypointBert 7 | from pytorch_transformers import BertConfig 8 | 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | 11 | def TRM_predict(mode, args, predictor, rgb_feats, depth_feats): 12 | ''' predicting the waypoint probabilities ''' 13 | vis_logits = predictor(rgb_feats, depth_feats) 14 | # entry-wise probabilities 15 | vis_probs = torch.sigmoid(vis_logits) 16 | 17 | if mode == 'train': 18 | return vis_logits 19 | elif mode == 'eval': 20 | return vis_probs, vis_logits 21 | 22 | 23 | class BinaryDistPredictor_TRM(nn.Module): 24 | def __init__(self, args=None, hidden_dim=768, n_classes=12): 25 | super(BinaryDistPredictor_TRM, self).__init__() 26 | self.args 
= args 27 | self.batchsize = args.BATCH_SIZE 28 | self.num_angles = args.ANGLES 29 | self.num_imgs = args.NUM_IMGS 30 | self.n_classes = n_classes 31 | 32 | # self.visual_1by1conv_rgb = nn.Conv2d( 33 | # in_channels=2048, out_channels=512, kernel_size=1) 34 | self.visual_fc_rgb = nn.Sequential( 35 | nn.Flatten(), 36 | nn.Linear(np.prod([2048,7,7]), hidden_dim), 37 | nn.ReLU(True), 38 | ) 39 | # self.visual_1by1conv_depth = nn.Conv2d( 40 | # in_channels=128, out_channels=512, kernel_size=1) 41 | self.visual_fc_depth = nn.Sequential( 42 | nn.Flatten(), 43 | nn.Linear(np.prod([128,4,4]), hidden_dim), 44 | nn.ReLU(True), 45 | ) 46 | self.visual_merge = nn.Sequential( 47 | nn.Linear(hidden_dim*2, hidden_dim), 48 | nn.ReLU(True), 49 | ) 50 | 51 | config = BertConfig() 52 | config.model_type = 'visual' 53 | config.finetuning_task = 'waypoint_predictor' 54 | config.hidden_dropout_prob = 0.3 55 | config.hidden_size = 768 56 | config.num_attention_heads = 12 57 | config.num_hidden_layers = args.TRM_LAYER 58 | self.waypoint_TRM = WaypointBert(config=config) 59 | 60 | layer_norm_eps = config.layer_norm_eps 61 | # self.mergefeats_LayerNorm = BertLayerNorm( 62 | # hidden_dim, 63 | # eps=layer_norm_eps 64 | # ) 65 | 66 | self.mask = utils.get_attention_mask( 67 | num_imgs=self.num_imgs, 68 | neighbor=args.TRM_NEIGHBOR).to(device) 69 | 70 | self.vis_classifier = nn.Sequential( 71 | nn.Linear(hidden_dim, hidden_dim), 72 | nn.ReLU(), 73 | nn.Linear(hidden_dim, 74 | int(n_classes*(self.num_angles/self.num_imgs))), 75 | ) 76 | 77 | def forward(self, rgb_feats, depth_feats): 78 | bsi = rgb_feats.size(0) // self.num_imgs 79 | 80 | # rgb_x = self.visual_1by1conv_rgb(rgb_feats) 81 | rgb_x = self.visual_fc_rgb(rgb_feats).reshape( 82 | bsi, self.num_imgs, -1) 83 | 84 | # depth_x = self.visual_1by1conv_depth(depth_feats) 85 | depth_x = self.visual_fc_depth(depth_feats).reshape( 86 | bsi, self.num_imgs, -1) 87 | 88 | vis_x = self.visual_merge( 89 | torch.cat((rgb_x, depth_x), dim=-1) 90 | ) 91 | # vis_x = self.mergefeats_LayerNorm(vis_x) 92 | 93 | attention_mask = self.mask.repeat(bsi,1,1,1) 94 | vis_rel_x = self.waypoint_TRM( 95 | vis_x, attention_mask=attention_mask 96 | ) 97 | 98 | vis_logits = self.vis_classifier(vis_rel_x) 99 | vis_logits = vis_logits.reshape( 100 | bsi, self.num_angles, self.n_classes) 101 | 102 | # heatmap offset (each image is pointing at the middle) 103 | vis_logits = torch.cat( 104 | (vis_logits[:,self.args.HEATMAP_OFFSET:,:], vis_logits[:,:self.args.HEATMAP_OFFSET,:]), 105 | dim=1) 106 | 107 | return vis_logits 108 | 109 | 110 | class BertLayerNorm(nn.Module): 111 | def __init__(self, hidden_size, eps=1e-12): 112 | """Construct a layernorm module in the TF style (epsilon inside the square root). 
113 | """ 114 | super(BertLayerNorm, self).__init__() 115 | self.weight = nn.Parameter(torch.ones(hidden_size)) 116 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 117 | self.variance_epsilon = eps 118 | 119 | def forward(self, x): 120 | u = x.mean(-1, keepdim=True) 121 | s = (x - u).pow(2).mean(-1, keepdim=True) 122 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 123 | return self.weight * x + self.bias 124 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | import sys 5 | import glob 6 | import json 7 | 8 | def neighborhoods(mu, x_range, y_range, sigma, circular_x=True, gaussian=False): 9 | """ Generate masks centered at mu of the given x and y range with the 10 | origin in the centre of the output 11 | Inputs: 12 | mu: tensor (N, 2) 13 | Outputs: 14 | tensor (N, y_range, s_range) 15 | """ 16 | x_mu = mu[:,0].unsqueeze(1).unsqueeze(1) 17 | y_mu = mu[:,1].unsqueeze(1).unsqueeze(1) 18 | 19 | # Generate bivariate Gaussians centered at position mu 20 | x = torch.arange(start=0,end=x_range, device=mu.device, dtype=mu.dtype).unsqueeze(0).unsqueeze(0) 21 | y = torch.arange(start=0,end=y_range, device=mu.device, dtype=mu.dtype).unsqueeze(1).unsqueeze(0) 22 | 23 | y_diff = y - y_mu 24 | x_diff = x - x_mu 25 | if circular_x: 26 | x_diff = torch.min(torch.abs(x_diff), torch.abs(x_diff + x_range)) 27 | if gaussian: 28 | output = torch.exp(-0.5 * ((x_diff/sigma[0])**2 + (y_diff/sigma[1])**2 )) 29 | else: 30 | output = torch.logical_and( 31 | torch.abs(x_diff) <= sigma[0], torch.abs(y_diff) <= sigma[1] 32 | ).type(mu.dtype) 33 | 34 | return output 35 | 36 | 37 | def nms(pred, max_predictions=10, sigma=(1.0,1.0), gaussian=False): 38 | ''' Input (batch_size, 1, height, width) ''' 39 | 40 | shape = pred.shape 41 | 42 | output = torch.zeros_like(pred) 43 | flat_pred = pred.reshape((shape[0],-1)) # (BATCH_SIZE, 24*48) 44 | supp_pred = pred.clone() 45 | flat_output = output.reshape((shape[0],-1)) # (BATCH_SIZE, 24*48) 46 | 47 | for i in range(max_predictions): 48 | # Find and save max over the entire map 49 | flat_supp_pred = supp_pred.reshape((shape[0],-1)) 50 | val, ix = torch.max(flat_supp_pred, dim=1) 51 | indices = torch.arange(0,shape[0]) 52 | flat_output[indices,ix] = flat_pred[indices,ix] 53 | 54 | # Suppression 55 | y = ix / shape[-1] 56 | x = ix % shape[-1] 57 | mu = torch.stack([x,y], dim=1).float() 58 | 59 | g = neighborhoods(mu, shape[-1], shape[-2], sigma, gaussian=gaussian) 60 | 61 | supp_pred *= (1-g.unsqueeze(1)) 62 | 63 | output[output < 0] = 0 64 | return output 65 | 66 | 67 | def get_gt_nav_map(num_angles, nav_dict, scan_ids, waypoint_ids): 68 | ''' A manully written random target map and its corresponding value map 69 | Target map: 70 | 1 - keypoint on ground-truth map 71 | 2 - ignore indexes 72 | Weightings map: 73 | 0 - ignore 74 | 1 - waypoint, too far from waypoint, or obstacle 75 | (0,1) - other open space 76 | When building this map, starts with zeros, and then add GT rows, 77 | finally add walls 78 | ''' 79 | # (Pdb) target.shape 80 | # torch.Size([2, 24, 8]) 81 | # (Pdb) weight.shape 82 | # torch.Size([2, 24, 8]) 83 | bs = len(scan_ids) 84 | target = torch.zeros(bs, num_angles, 12) 85 | obstacle = torch.zeros(bs, num_angles, 12) 86 | weight = torch.zeros(bs, num_angles, 12) 87 | source_pos = [] 88 | target_pos = [] 89 | 90 | for i in range(bs): 91 | target[i] = 
torch.tensor(nav_dict[scan_ids[i]][waypoint_ids[i]]['target']) 92 | obstacle[i] = torch.tensor(nav_dict[scan_ids[i]][waypoint_ids[i]]['obstacle']) 93 | weight[i] = torch.tensor(nav_dict[scan_ids[i]][waypoint_ids[i]]['weight']) 94 | source_pos.append(nav_dict[scan_ids[i]][waypoint_ids[i]]['source_pos']) 95 | target_pos.append(nav_dict[scan_ids[i]][waypoint_ids[i]]['target_pos']) 96 | 97 | return target, obstacle, weight, source_pos, target_pos 98 | 99 | 100 | def print_progress(iteration, total, prefix='', suffix='', decimals=1, bar_length=10): 101 | """ 102 | Call in a loop to create terminal progress bar 103 | @params: 104 | iteration - Required : current iteration (Int) 105 | total - Required : total iterations (Int) 106 | prefix - Optional : prefix string (Str) 107 | suffix - Optional : suffix string (Str) 108 | decimals - Optional : positive number of decimals in percent complete (Int) 109 | bar_length - Optional : character length of bar (Int) 110 | """ 111 | str_format = "{0:." + str(decimals) + "f}" 112 | percents = str_format.format(100 * (iteration / float(total))) 113 | filled_length = int(round(bar_length * iteration / float(total))) 114 | bar = '█' * filled_length + '-' * (bar_length - filled_length) 115 | 116 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), 117 | 118 | if iteration == total: 119 | sys.stdout.write('\n') 120 | sys.stdout.flush() 121 | 122 | 123 | def save_checkpoint(epoch, net, net_optimizer, path): 124 | ''' Snapshot models ''' 125 | states = {} 126 | def create_state(name, model, optimizer): 127 | states[name] = { 128 | 'epoch': epoch, 129 | 'state_dict': model.state_dict(), 130 | 'optimizer': optimizer.state_dict(), 131 | } 132 | all_tuple = [("predictor", net, net_optimizer)] 133 | for param in all_tuple: 134 | create_state(*param) 135 | torch.save(states, path) 136 | 137 | 138 | def load_checkpoint(net, net_optimizer, path): 139 | ''' Loads parameters (but not training state) ''' 140 | states = torch.load(path) 141 | def recover_state(name, model, optimizer): 142 | state = model.state_dict() 143 | model_keys = set(state.keys()) 144 | load_keys = set(states[name]['state_dict'].keys()) 145 | if model_keys != load_keys: 146 | print("NOTICE: DIFFERENT KEYS FOUND") 147 | state.update(states[name]['state_dict']) 148 | model.load_state_dict(state) 149 | optimizer.load_state_dict(states[name]['optimizer']) 150 | all_tuple = [("predictor", net, net_optimizer)] 151 | for param in all_tuple: 152 | recover_state(*param) 153 | return states['predictor']['epoch'], all_tuple[0][1], all_tuple[0][2] 154 | 155 | 156 | def get_attention_mask(num_imgs=24, neighbor=2): 157 | assert neighbor <= 5 158 | 159 | mask = np.zeros((num_imgs,num_imgs)) 160 | t = np.zeros(num_imgs) 161 | t[:neighbor+1] = np.ones(neighbor+1) 162 | if neighbor != 0: 163 | t[-neighbor:] = np.ones(neighbor) 164 | for ri in range(num_imgs): 165 | mask[ri] = t 166 | t = np.roll(t, 1) 167 | 168 | return torch.from_numpy(mask).reshape(1,1,num_imgs,num_imgs).long() 169 | 170 | 171 | def load_gt_navigability(path): 172 | ''' waypoint ground-truths ''' 173 | all_scans_nav_map = {} 174 | gt_dir = glob.glob('%s*'%(path)) 175 | for gt_dir_i in gt_dir: 176 | with open(gt_dir_i, 'r') as f: 177 | nav_map = json.load(f) 178 | for scan_id, values in nav_map.items(): 179 | all_scans_nav_map[scan_id] = values 180 | return all_scans_nav_map 181 | -------------------------------------------------------------------------------- /gen_training_data/test_twm0.2_obstacle_first.py: 
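For reference, utils.get_attention_mask above builds the circular neighbourhood mask that TRM_net.py passes to the transformer (num_imgs=12, neighbor=args.TRM_NEIGHBOR): each of the 12 panoramic views attends only to itself and its immediate neighbours, with wrap-around. A small check, assuming the root utils.py is importable:

from utils import get_attention_mask

mask = get_attention_mask(num_imgs=12, neighbor=1)
print(mask.shape)      # torch.Size([1, 1, 12, 12])
print(mask[0, 0, 0])   # tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) -- view 0 attends to views 11, 0, 1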
-------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import numpy as np 4 | import copy 5 | import torch 6 | import os 7 | import utils 8 | from scipy.spatial.distance import cdist 9 | from scipy.ndimage import gaussian_filter 10 | 11 | ANGLES = 120 12 | DISTANCES = 12 13 | RAW_GRAPH_PATH = '/home/vlnce/habitat_connectivity_graph/%s.json' 14 | 15 | RADIUS = 3.25 # corresponding to max forward distance of 2 meters 16 | 17 | print('Running TRM-0.2 !!!!!!!!!!') 18 | 19 | print('\nProcessing navigability and connectivity to GT maps') 20 | print('Using %s ANGLES and %s DISTANCES'%(ANGLES, DISTANCES)) 21 | print('Maximum radius for each waypoint: %s'%(RADIUS)) 22 | print('\nConstraining each angle sector has at most one GT waypoint') 23 | print('For all sectors with edges greater than %s, create a virtual waypoint at %s'%(RADIUS, 2.50)) 24 | print('\nThis script will return the target map, the obstacle map and the weigh map for each environment') 25 | 26 | np.random.seed(7) 27 | 28 | splits = ['train', 'val_unseen'] 29 | for split in splits: 30 | print('\nProcessing %s split data'%(split)) 31 | 32 | with open(RAW_GRAPH_PATH%split, 'r') as f: 33 | data = json.load(f) 34 | if os.path.exists('./gen_training_data/nav_dicts/navigability_dict_%s.json'%split): 35 | with open('./gen_training_data/nav_dicts/navigability_dict_%s.json'%split) as f: 36 | nav_dict = json.load(f) 37 | raw_nav_dict = {} 38 | nodes = {} 39 | edges = {} 40 | obstacles = {} 41 | for k, v in data.items(): 42 | nodes[k] = data[k]['nodes'] 43 | edges[k] = data[k]['edges'] 44 | obstacles[k] = nav_dict[k] 45 | raw_nav_dict['nodes'], raw_nav_dict['edges'], raw_nav_dict['obstacles'] = nodes, edges, obstacles 46 | data_scans = { 47 | 'nodes': raw_nav_dict['nodes'], 48 | 'edges': raw_nav_dict['edges'], 49 | } 50 | obstacle_dict_scans = raw_nav_dict['obstacles'] 51 | scans = list(raw_nav_dict['nodes'].keys()) 52 | 53 | 54 | overall_nav_dict = {} 55 | del_nodes = 0 56 | count_nodes = 0 57 | target_count = 0 # not count target because it is Gaussian 58 | openspace_count = 0; obstacle_count = 0 59 | rawedges_count = 0; postedges_count = 0 60 | 61 | for scan in scans: 62 | ''' connectivity dictionary ''' 63 | obstacle_dict = obstacle_dict_scans[scan] 64 | 65 | connect_dict = {} 66 | for edge_id, edge_info in data_scans['edges'][scan].items(): 67 | node_a = edge_info['nodes'][0] 68 | node_b = edge_info['nodes'][1] 69 | 70 | if node_a not in connect_dict: 71 | connect_dict[node_a] = [node_b] 72 | else: 73 | connect_dict[node_a].append(node_b) 74 | if node_b not in connect_dict: 75 | connect_dict[node_b] = [node_a] 76 | else: 77 | connect_dict[node_b].append(node_a) 78 | 79 | ''' process each node to standard data format ''' 80 | navigability_dict = {} 81 | groundtruth_dict = {} 82 | count_nodes_i = 0 83 | del_nodes_i = 0 84 | for node_a, neighbors in connect_dict.items(): 85 | count_nodes += 1; count_nodes_i += 1 86 | navigability_dict[node_a] = utils.init_node_nav_dict(ANGLES) 87 | groundtruth_dict[node_a] = utils.init_node_gt_dict(ANGLES) 88 | 89 | node_a_pos = np.array(data_scans['nodes'][scan][node_a])[[0,2]] 90 | groundtruth_dict[node_a]['source_pos'] = node_a_pos.tolist() 91 | 92 | for node_b in neighbors: 93 | node_b_pos = np.array(data_scans['nodes'][scan][node_b])[[0,2]] 94 | 95 | edge_vec = (node_b_pos - node_a_pos) 96 | angle, angleIndex, \ 97 | distance, \ 98 | distanceIndex = utils.edge_vec_to_indexes( 99 | edge_vec, ANGLES) 100 | 101 | # remove too far or too close 
viewpoints 102 | if distanceIndex == -1: 103 | continue 104 | # keep the further keypoint in the same sector 105 | if navigability_dict[node_a][str(angleIndex)]['has_waypoint']: 106 | if distance < navigability_dict[node_a][str(angleIndex)]['waypoint']['distance']: 107 | continue 108 | 109 | # if distance <= RADIUS: 110 | navigability_dict[node_a][str(angleIndex)]['has_waypoint'] = True 111 | navigability_dict[node_a][str(angleIndex)]['waypoint'] = { 112 | 'node_id': node_b, 113 | 'position': node_b_pos, 114 | 'angle': angle, 115 | 'angleIndex': angleIndex, 116 | 'distance': distance, 117 | 'distanceIndex': distanceIndex, 118 | } 119 | ''' set target map ''' 120 | groundtruth_dict[node_a]['target'][angleIndex, distanceIndex] = 1 121 | groundtruth_dict[node_a]['target_pos'].append(node_b_pos.tolist()) 122 | 123 | # record the number of raw targets 124 | raw_target_count = groundtruth_dict[node_a]['target'].sum() 125 | 126 | if raw_target_count == 0: 127 | del(groundtruth_dict[node_a]) 128 | del_nodes += 1; del_nodes_i += 1 129 | continue 130 | 131 | ''' a Gaussian target map ''' 132 | gau_peak = 10 133 | gau_sig_angle = 1.0 134 | gau_sig_dist = 2.0 135 | groundtruth_dict[node_a]['target'] = groundtruth_dict[node_a]['target'].astype(np.float32) 136 | 137 | gau_temp_in = np.concatenate( 138 | ( 139 | np.zeros((ANGLES,10)), 140 | groundtruth_dict[node_a]['target'], 141 | np.zeros((ANGLES,10)), 142 | ), axis=1) 143 | 144 | gau_target = gaussian_filter( 145 | gau_temp_in, 146 | sigma=(1,2), 147 | mode='wrap', 148 | ) 149 | gau_target = gau_target[:, 10:10+DISTANCES] 150 | 151 | gau_target_maxnorm = gau_target / gau_target.max() 152 | groundtruth_dict[node_a]['target'] = gau_peak * gau_target_maxnorm 153 | 154 | for k in range(ANGLES): 155 | k_dist = obstacle_dict[node_a][str(k)]['obstacle_distance'] 156 | if k_dist is None: 157 | k_dist = 100 158 | navigability_dict[node_a][str(k)]['obstacle_distance'] = k_dist 159 | 160 | k_dindex = utils.get_obstacle_distanceIndex12(k_dist) 161 | navigability_dict[node_a][str(k)]['obstacle_index'] = k_dindex 162 | 163 | ''' deal with obstacle ''' 164 | if k_dindex != -1: 165 | groundtruth_dict[node_a]['obstacle'][k][:k_dindex] = np.zeros(k_dindex) 166 | else: 167 | groundtruth_dict[node_a]['obstacle'][k] = np.zeros(12) 168 | 169 | 170 | ''' ********** very important ********** ''' 171 | ''' adjust target map ''' 172 | ''' obstacle comes first ''' 173 | 174 | rawt = copy.deepcopy(groundtruth_dict[node_a]['target']) 175 | 176 | groundtruth_dict[node_a]['target'] = \ 177 | groundtruth_dict[node_a]['target'] * (groundtruth_dict[node_a]['obstacle'] == 0) 178 | 179 | # a confidence thresholding 180 | if groundtruth_dict[node_a]['target'].max() < 0.75*gau_peak: 181 | del(groundtruth_dict[node_a]) 182 | del_nodes += 1; del_nodes_i += 1 183 | continue 184 | 185 | postt = copy.deepcopy(groundtruth_dict[node_a]['target']) 186 | rawedges_count += (rawt==gau_peak).sum() 187 | postedges_count += (postt==gau_peak).sum() 188 | 189 | ''' ********** very important ********** ''' 190 | 191 | openspace_count += (groundtruth_dict[node_a]['obstacle'] == 0).sum() 192 | obstacle_count += (groundtruth_dict[node_a]['obstacle'] == 1).sum() 193 | 194 | groundtruth_dict[node_a]['target'] = groundtruth_dict[node_a]['target'].tolist() 195 | groundtruth_dict[node_a]['weight'] = groundtruth_dict[node_a]['weight'].tolist() 196 | groundtruth_dict[node_a]['obstacle'] = groundtruth_dict[node_a]['obstacle'].tolist() 197 | 198 | overall_nav_dict[scan] = groundtruth_dict 199 | 200 | print('Obstacle 
comes before target !!!') 201 | print('Number of deleted nodes: %s/%s'%(del_nodes, count_nodes)) 202 | print('Ratio of obstacle behind target: %s/%s'%(postedges_count,rawedges_count)) 203 | 204 | print('Ratio of openspace %.5f'%(openspace_count/(openspace_count+obstacle_count))) 205 | print('Ratio of obstacle %.5f'%(obstacle_count/(openspace_count+obstacle_count))) 206 | 207 | with open('./training_data/%s_%s_mp3d_waypoint_twm0.2_obstacle_first_withpos.json'%(ANGLES, split), 'w') as f: 208 | json.dump(overall_nav_dict, f) 209 | print('Done') 210 | 211 | # import pdb; pdb.set_trace() 212 | -------------------------------------------------------------------------------- /transformer/waypoint_bert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Microsoft Corporation. Licensed under the MIT license. 2 | # Modified in Recurrent VLN-BERT, 2020, Yicong.Hong@anu.edu.au 3 | 4 | from __future__ import absolute_import, division, print_function, unicode_literals 5 | import logging 6 | import math 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from torch.nn import CrossEntropyLoss, MSELoss 11 | 12 | from .pytorch_transformer.modeling_bert import (BertEmbeddings, 13 | BertSelfAttention, BertAttention, BertEncoder, BertLayer, 14 | BertSelfOutput, BertIntermediate, BertOutput, 15 | BertPooler, BertLayerNorm, BertPreTrainedModel, 16 | BertPredictionHeadTransform) 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | class VisPosEmbeddings(nn.Module): 21 | def __init__(self, config): 22 | super(VisPosEmbeddings, self).__init__() 23 | self.position_embeddings = nn.Embedding(24, config.hidden_size) 24 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 25 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 26 | 27 | def forward(self, input_vis_feats, position_ids=None): 28 | seq_length = input_vis_feats.size(1) 29 | if position_ids is None: 30 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_vis_feats.device) 31 | position_ids = position_ids.unsqueeze(0).repeat(input_vis_feats.size(0), 1) 32 | 33 | vis_embeddings = input_vis_feats 34 | position_embeddings = self.position_embeddings(position_ids) 35 | 36 | embeddings = vis_embeddings + position_embeddings 37 | embeddings = self.LayerNorm(embeddings) 38 | # embeddings = self.dropout(embeddings) 39 | return embeddings 40 | 41 | class CaptionBertSelfAttention(BertSelfAttention): 42 | """ 43 | Modified from BertSelfAttention to add support for output_hidden_states. 
44 | """ 45 | def __init__(self, config): 46 | super(CaptionBertSelfAttention, self).__init__(config) 47 | self.config = config 48 | 49 | def forward(self, hidden_states, attention_mask, head_mask=None, 50 | history_state=None): 51 | if history_state is not None: 52 | x_states = torch.cat([history_state, hidden_states], dim=1) 53 | mixed_query_layer = self.query(hidden_states) 54 | mixed_key_layer = self.key(x_states) 55 | mixed_value_layer = self.value(x_states) 56 | else: 57 | mixed_query_layer = self.query(hidden_states) 58 | mixed_key_layer = self.key(hidden_states) 59 | mixed_value_layer = self.value(hidden_states) 60 | 61 | ''' language feature only provide Keys and Values ''' 62 | query_layer = self.transpose_for_scores(mixed_query_layer) 63 | key_layer = self.transpose_for_scores(mixed_key_layer) 64 | value_layer = self.transpose_for_scores(mixed_value_layer) 65 | 66 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 67 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 68 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 69 | attention_scores = attention_scores + attention_mask 70 | 71 | # Normalize the attention scores to probabilities. 72 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 73 | 74 | # This is actually dropping out entire tokens to attend to, which might 75 | # seem a bit unusual, but is taken from the original Transformer paper. 76 | attention_probs = self.dropout(attention_probs) 77 | 78 | # Mask heads if we want to 79 | if head_mask is not None: 80 | attention_probs = attention_probs * head_mask 81 | 82 | context_layer = torch.matmul(attention_probs, value_layer) 83 | 84 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 85 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 86 | context_layer = context_layer.view(*new_context_layer_shape) 87 | 88 | outputs = (context_layer, attention_scores) 89 | 90 | return outputs 91 | 92 | 93 | class CaptionBertAttention(BertAttention): 94 | """ 95 | Modified from BertAttention to add support for output_hidden_states. 96 | """ 97 | def __init__(self, config): 98 | super(CaptionBertAttention, self).__init__(config) 99 | self.self = CaptionBertSelfAttention(config) 100 | self.output = BertSelfOutput(config) 101 | self.config = config 102 | 103 | def forward(self, input_tensor, attention_mask, head_mask=None, 104 | history_state=None): 105 | ''' transformer processing ''' 106 | self_outputs = self.self(input_tensor, attention_mask, head_mask, history_state) 107 | 108 | ''' feed-forward network with residule ''' 109 | attention_output = self.output(self_outputs[0], input_tensor) 110 | 111 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 112 | 113 | return outputs 114 | 115 | 116 | class CaptionBertLayer(BertLayer): 117 | """ 118 | Modified from BertLayer to add support for output_hidden_states. 
119 | """ 120 | def __init__(self, config): 121 | super(CaptionBertLayer, self).__init__(config) 122 | self.attention = CaptionBertAttention(config) 123 | self.intermediate = BertIntermediate(config) 124 | self.output = BertOutput(config) 125 | 126 | def forward(self, hidden_states, attention_mask, head_mask=None, 127 | history_state=None): 128 | 129 | attention_outputs = self.attention(hidden_states, attention_mask, 130 | head_mask, history_state) 131 | 132 | ''' feed-forward network with residule ''' 133 | attention_output = attention_outputs[0] 134 | intermediate_output = self.intermediate(attention_output) 135 | layer_output = self.output(intermediate_output, attention_output) 136 | outputs = (layer_output,) + attention_outputs[1:] 137 | 138 | return outputs 139 | 140 | 141 | class CaptionBertEncoder(BertEncoder): 142 | """ 143 | Modified from BertEncoder to add support for output_hidden_states. 144 | """ 145 | def __init__(self, config): 146 | super(CaptionBertEncoder, self).__init__(config) 147 | self.output_attentions = config.output_attentions 148 | self.output_hidden_states = config.output_hidden_states 149 | # 12 Bert layers 150 | self.layer = nn.ModuleList([CaptionBertLayer(config) for _ in range(config.num_hidden_layers)]) 151 | self.config = config 152 | 153 | def forward(self, hidden_states, attention_mask, head_mask=None, 154 | encoder_history_states=None): 155 | 156 | for i, layer_module in enumerate(self.layer): 157 | history_state = None if encoder_history_states is None else encoder_history_states[i] # default None 158 | 159 | layer_outputs = layer_module( 160 | hidden_states, attention_mask, head_mask[i], 161 | history_state) 162 | hidden_states = layer_outputs[0] 163 | 164 | if i == self.config.num_hidden_layers - 1: 165 | slang_attention_score = layer_outputs[1] 166 | 167 | outputs = (hidden_states, slang_attention_score) 168 | 169 | return outputs 170 | 171 | 172 | class BertImgModel(nn.Module): 173 | """ Expand from BertModel to handle image region features as input 174 | """ 175 | def __init__(self, config): 176 | super(BertImgModel, self).__init__() 177 | self.config = config 178 | # self.vis_pos_embeds = VisPosEmbeddings(config) 179 | self.encoder = CaptionBertEncoder(config) 180 | 181 | def forward(self, input_x, attention_mask=None): 182 | 183 | extended_attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 184 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 185 | 186 | head_mask = [None] * self.config.num_hidden_layers 187 | 188 | ''' positional encodings ''' 189 | # input_x = self.vis_pos_embeds(input_x) 190 | 191 | ''' pass to the Transformer layers ''' 192 | encoder_outputs = self.encoder(input_x, 193 | extended_attention_mask, head_mask=head_mask) 194 | 195 | outputs = (encoder_outputs[0],) + encoder_outputs[1:] 196 | 197 | return outputs 198 | 199 | 200 | class WaypointBert(nn.Module): 201 | """ 202 | Modified from BertForMultipleChoice to support oscar training. 
203 | """ 204 | def __init__(self, config=None): 205 | super(WaypointBert, self).__init__() 206 | self.config = config 207 | self.bert = BertImgModel(config) 208 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 209 | 210 | def forward(self, input_x, attention_mask=None): 211 | 212 | outputs = self.bert(input_x, attention_mask=attention_mask) 213 | 214 | sequence_output = outputs[0] 215 | sequence_output = self.dropout(sequence_output) 216 | 217 | return sequence_output 218 | -------------------------------------------------------------------------------- /ddppo_resnet/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional, Type, Union, cast 8 | 9 | from torch import Tensor 10 | from torch import nn as nn 11 | from torch.nn.modules.container import Sequential 12 | from torch.nn.modules.conv import Conv2d 13 | 14 | 15 | def conv3x3( 16 | in_planes: int, out_planes: int, stride: int = 1, groups: int = 1 17 | ) -> Conv2d: 18 | """3x3 convolution with padding""" 19 | return nn.Conv2d( 20 | in_planes, 21 | out_planes, 22 | kernel_size=3, 23 | stride=stride, 24 | padding=1, 25 | bias=False, 26 | groups=groups, 27 | ) 28 | 29 | 30 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> Conv2d: 31 | """1x1 convolution""" 32 | return nn.Conv2d( 33 | in_planes, out_planes, kernel_size=1, stride=stride, bias=False 34 | ) 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion = 1 39 | resneXt = False 40 | 41 | def __init__( 42 | self, 43 | inplanes, 44 | planes, 45 | ngroups, 46 | stride=1, 47 | downsample=None, 48 | cardinality=1, 49 | ): 50 | super(BasicBlock, self).__init__() 51 | self.convs = nn.Sequential( 52 | conv3x3(inplanes, planes, stride, groups=cardinality), 53 | nn.GroupNorm(ngroups, planes), 54 | nn.ReLU(True), 55 | conv3x3(planes, planes, groups=cardinality), 56 | nn.GroupNorm(ngroups, planes), 57 | ) 58 | self.downsample = downsample 59 | self.relu = nn.ReLU(True) 60 | 61 | def forward(self, x): 62 | residual = x 63 | 64 | out = self.convs(x) 65 | 66 | if self.downsample is not None: 67 | residual = self.downsample(x) 68 | 69 | return self.relu(out + residual) 70 | 71 | 72 | def _build_bottleneck_branch( 73 | inplanes: int, 74 | planes: int, 75 | ngroups: int, 76 | stride: int, 77 | expansion: int, 78 | groups: int = 1, 79 | ) -> Sequential: 80 | return nn.Sequential( 81 | conv1x1(inplanes, planes), 82 | nn.GroupNorm(ngroups, planes), 83 | nn.ReLU(True), 84 | conv3x3(planes, planes, stride, groups=groups), 85 | nn.GroupNorm(ngroups, planes), 86 | nn.ReLU(True), 87 | conv1x1(planes, planes * expansion), 88 | nn.GroupNorm(ngroups, planes * expansion), 89 | ) 90 | 91 | 92 | class SE(nn.Module): 93 | def __init__(self, planes, r=16): 94 | super().__init__() 95 | self.squeeze = nn.AdaptiveAvgPool2d(1) 96 | self.excite = nn.Sequential( 97 | nn.Linear(planes, int(planes / r)), 98 | nn.ReLU(True), 99 | nn.Linear(int(planes / r), planes), 100 | nn.Sigmoid(), 101 | ) 102 | 103 | def forward(self, x): 104 | b, c, _, _ = x.size() 105 | x = self.squeeze(x) 106 | x = x.view(b, c) 107 | x = self.excite(x) 108 | 109 | return x.view(b, c, 1, 1) 110 | 111 | 112 | def _build_se_branch(planes, r=16): 113 | return SE(planes, r) 114 | 115 | 116 | class Bottleneck(nn.Module): 117 | expansion = 4 
118 | resneXt = False 119 | 120 | def __init__( 121 | self, 122 | inplanes: int, 123 | planes: int, 124 | ngroups: int, 125 | stride: int = 1, 126 | downsample: Optional[Sequential] = None, 127 | cardinality: int = 1, 128 | ) -> None: 129 | super().__init__() 130 | self.convs = _build_bottleneck_branch( 131 | inplanes, 132 | planes, 133 | ngroups, 134 | stride, 135 | self.expansion, 136 | groups=cardinality, 137 | ) 138 | self.relu = nn.ReLU(inplace=True) 139 | self.downsample = downsample 140 | 141 | def _impl(self, x: Tensor) -> Tensor: 142 | identity = x 143 | 144 | out = self.convs(x) 145 | 146 | if self.downsample is not None: 147 | identity = self.downsample(x) 148 | 149 | return self.relu(out + identity) 150 | 151 | def forward(self, x: Tensor) -> Tensor: 152 | return self._impl(x) 153 | 154 | 155 | class SEBottleneck(Bottleneck): 156 | def __init__( 157 | self, 158 | inplanes, 159 | planes, 160 | ngroups, 161 | stride=1, 162 | downsample=None, 163 | cardinality=1, 164 | ): 165 | super().__init__( 166 | inplanes, planes, ngroups, stride, downsample, cardinality 167 | ) 168 | 169 | self.se = _build_se_branch(planes * self.expansion) 170 | 171 | def _impl(self, x): 172 | identity = x 173 | 174 | out = self.convs(x) 175 | out = self.se(out) * out 176 | 177 | if self.downsample is not None: 178 | identity = self.downsample(x) 179 | 180 | return self.relu(out + identity) 181 | 182 | 183 | class SEResNeXtBottleneck(SEBottleneck): 184 | expansion = 2 185 | resneXt = True 186 | 187 | 188 | class ResNeXtBottleneck(Bottleneck): 189 | expansion = 2 190 | resneXt = True 191 | 192 | 193 | Block = Union[Type[Bottleneck], Type[BasicBlock]] 194 | 195 | 196 | class ResNet(nn.Module): 197 | def __init__( 198 | self, 199 | in_channels: int, 200 | base_planes: int, 201 | ngroups: int, 202 | block: Block, 203 | layers: List[int], 204 | cardinality: int = 1, 205 | ) -> None: 206 | super(ResNet, self).__init__() 207 | self.conv1 = nn.Sequential( 208 | nn.Conv2d( 209 | in_channels, 210 | base_planes, 211 | kernel_size=7, 212 | stride=2, 213 | padding=3, 214 | bias=False, 215 | ), 216 | nn.GroupNorm(ngroups, base_planes), 217 | nn.ReLU(True), 218 | ) 219 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 220 | self.cardinality = cardinality 221 | 222 | self.inplanes = base_planes 223 | if block.resneXt: 224 | base_planes *= 2 225 | 226 | self.layer1 = self._make_layer(block, ngroups, base_planes, layers[0]) 227 | self.layer2 = self._make_layer( 228 | block, ngroups, base_planes * 2, layers[1], stride=2 229 | ) 230 | self.layer3 = self._make_layer( 231 | block, ngroups, base_planes * 2 * 2, layers[2], stride=2 232 | ) 233 | self.layer4 = self._make_layer( 234 | block, ngroups, base_planes * 2 * 2 * 2, layers[3], stride=2 235 | ) 236 | 237 | self.final_channels = self.inplanes 238 | self.final_spatial_compress = 1.0 / (2 ** 5) 239 | 240 | def _make_layer( 241 | self, 242 | block: Block, 243 | ngroups: int, 244 | planes: int, 245 | blocks: int, 246 | stride: int = 1, 247 | ) -> Sequential: 248 | downsample = None 249 | if stride != 1 or self.inplanes != planes * block.expansion: 250 | downsample = nn.Sequential( 251 | conv1x1(self.inplanes, planes * block.expansion, stride), 252 | nn.GroupNorm(ngroups, planes * block.expansion), 253 | ) 254 | 255 | layers = [] 256 | layers.append( 257 | block( 258 | self.inplanes, 259 | planes, 260 | ngroups, 261 | stride, 262 | downsample, 263 | cardinality=self.cardinality, 264 | ) 265 | ) 266 | self.inplanes = planes * block.expansion 267 | for _i in range(1, 
blocks): 268 | layers.append(block(self.inplanes, planes, ngroups)) 269 | 270 | return nn.Sequential(*layers) 271 | 272 | def forward(self, x) -> Tensor: 273 | x = self.conv1(x) 274 | x = self.maxpool(x) 275 | x = cast(Tensor, x) 276 | x = self.layer1(x) 277 | x = self.layer2(x) 278 | x = self.layer3(x) 279 | x = self.layer4(x) 280 | 281 | return x 282 | 283 | 284 | def resnet18(in_channels, base_planes, ngroups): 285 | model = ResNet(in_channels, base_planes, ngroups, BasicBlock, [2, 2, 2, 2]) 286 | 287 | return model 288 | 289 | 290 | def resnet50(in_channels: int, base_planes: int, ngroups: int) -> ResNet: 291 | model = ResNet(in_channels, base_planes, ngroups, Bottleneck, [3, 4, 6, 3]) 292 | 293 | return model 294 | 295 | 296 | def resneXt50(in_channels, base_planes, ngroups): 297 | model = ResNet( 298 | in_channels, 299 | base_planes, 300 | ngroups, 301 | ResNeXtBottleneck, 302 | [3, 4, 6, 3], 303 | cardinality=int(base_planes / 2), 304 | ) 305 | 306 | return model 307 | 308 | 309 | def se_resnet50(in_channels, base_planes, ngroups): 310 | model = ResNet( 311 | in_channels, base_planes, ngroups, SEBottleneck, [3, 4, 6, 3] 312 | ) 313 | 314 | return model 315 | 316 | 317 | def se_resneXt50(in_channels, base_planes, ngroups): 318 | model = ResNet( 319 | in_channels, 320 | base_planes, 321 | ngroups, 322 | SEResNeXtBottleneck, 323 | [3, 4, 6, 3], 324 | cardinality=int(base_planes / 2), 325 | ) 326 | 327 | return model 328 | 329 | 330 | def se_resneXt101(in_channels, base_planes, ngroups): 331 | model = ResNet( 332 | in_channels, 333 | base_planes, 334 | ngroups, 335 | SEResNeXtBottleneck, 336 | [3, 4, 23, 3], 337 | cardinality=int(base_planes / 2), 338 | ) 339 | 340 | return model 341 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | import utils 5 | import copy 6 | from scipy.spatial.distance import cdist 7 | import matplotlib.pyplot as plt 8 | from skimage.transform import resize 9 | import os 10 | 11 | def waypoint_eval(args, predictions): 12 | ''' Evaluation of the predicted waypoint map, 13 | notice that the number of candidates is cap at args.MAX_NUM_CANDIDATES, 14 | but the number of GT waypoints could be any value in range [1,args.ANGLES]. 15 | 16 | The preprocessed data is constraining each angle sector has at most 17 | one GT waypoint. 
18 | ''' 19 | 20 | sample_id = predictions['sample_id'] 21 | source_pos = predictions['source_pos'] 22 | target_pos = predictions['target_pos'] 23 | probs = predictions['probs'] 24 | logits = predictions['logits'] 25 | target = predictions['target'] 26 | obstacle = predictions['obstacle'] 27 | sample_loss = predictions['sample_loss'] 28 | 29 | results = { 30 | 'candidates': {}, 31 | 'p_waypoint_openspace': 0.0, 32 | 'p_waypoint_obstacle': 0.0, 33 | 'avg_wayscore': 0.0, 34 | 'avg_pred_distance': 0.0, 35 | 'avg_chamfer_distance': 0.0, 36 | 'avg_hausdorff_distance': 0.0, 37 | 'avg_num_delta': 0.0, 38 | } 39 | 40 | num_candidate = [] # cap at args.MAX_NUM_CANDIDATES 41 | num_waypoint_openspace = [] # % waypoint in open space 42 | num_waypoint_obstacle = [] # % waypoint in obstacle 43 | waypoint_score = [] # scores on target map collected by predictions 44 | pred_distance = [] # distance from targets to predictions 45 | chamfer_distance_all = [] 46 | hausdorff_distance_all = [] 47 | num_delta_all = [] 48 | 49 | ''' output prediction ''' 50 | for i, batch_x in enumerate(logits): 51 | batch_sample_id = sample_id[i] 52 | batch_source_pos = source_pos[i] 53 | batch_target_pos = target_pos[i] 54 | batch_target = target[i] 55 | batch_obstacle = obstacle[i] 56 | batch_sample_loss = sample_loss[i] 57 | 58 | batch_x = torch.tensor(batch_x) 59 | batch_x_norm = torch.softmax( 60 | batch_x.reshape( 61 | batch_x.size(0), args.ANGLES*args.NUM_CLASSES 62 | ), dim=1 63 | ) 64 | batch_x_norm = batch_x_norm.reshape(batch_x.size(0), args.ANGLES, args.NUM_CLASSES) 65 | # batch_x_norm = torch.sigmoid(batch_x) 66 | 67 | # batch_output_map = utils.nms( 68 | # batch_x_norm.unsqueeze(1), max_predictions=args.MAX_NUM_CANDIDATES, 69 | # sigma=(7.0,5.0)) 70 | # batch_output_map = batch_output_map.squeeze() 71 | 72 | batch_x_norm_wrap = torch.cat( 73 | (batch_x_norm[:,-1:,:], batch_x_norm, batch_x_norm[:,:1,:]), 74 | dim=1) 75 | batch_output_map = utils.nms( 76 | batch_x_norm_wrap.unsqueeze(1), max_predictions=5, 77 | sigma=(7.0,5.0)) 78 | batch_output_map = batch_output_map.squeeze()[:,1:-1,:] 79 | 80 | if args.VIS: 81 | # # nms without different sigma 82 | batch_output_map_sig4 = utils.nms( 83 | batch_x_norm_wrap.unsqueeze(1), max_predictions=args.MAX_NUM_CANDIDATES, 84 | sigma=(4.0,4.0)) 85 | batch_output_map_sig4 = batch_output_map_sig4.squeeze()[:,1:-1,:] 86 | batch_output_map_sig5 = utils.nms( 87 | batch_x_norm_wrap.unsqueeze(1), max_predictions=args.MAX_NUM_CANDIDATES, 88 | sigma=(5.0,5.0)) 89 | batch_output_map_sig5 = batch_output_map_sig5.squeeze()[:,1:-1,:] 90 | batch_output_map_sig7_5 = utils.nms( 91 | batch_x_norm_wrap.unsqueeze(1), max_predictions=args.MAX_NUM_CANDIDATES, 92 | sigma=(7.0,5.0)) 93 | batch_output_map_sig7_5 = batch_output_map_sig7_5.squeeze()[:,1:-1,:] 94 | 95 | for j, id in enumerate(batch_sample_id): 96 | # pick one distance from each non-zeros column 97 | candidates = {} 98 | c_openspace = 0 99 | c_obstacle = 0 100 | candidates_pos = [] 101 | 102 | ''' gather predicted candidates and check if candidates are in openspace ''' 103 | for jdx, angle_view in enumerate(batch_output_map[j]): 104 | if angle_view.sum() != 0: 105 | candidates[jdx] = angle_view.argmax().item() 106 | candidates_pos.append( 107 | [jdx * 2 * math.pi / args.ANGLES, 108 | (candidates[jdx]+1) * 0.25]) 109 | # opensapce / obstacle 110 | if batch_obstacle[j][jdx][candidates[jdx]] == 0: 111 | c_openspace += 1 112 | else: 113 | c_obstacle += 1 114 | 115 | # the inferene ouput 116 | results['candidates'][id] = { 117 | # 'loss': 
batch_sample_loss[j], 118 | 'angle_dist': candidates, 119 | } 120 | num_candidate.append(len(candidates)) 121 | num_waypoint_openspace.append(c_openspace) 122 | num_waypoint_obstacle.append(c_obstacle) 123 | 124 | ''' score collected over the target heatmap by predictions ''' 125 | # score = (torch.tensor(batch_target[j])[batch_output_map[j] != 0]).sum() 126 | # waypoint_score.append(score.item()) 127 | score_map = torch.tensor(batch_target[j]) 128 | # using binary selection here doesn't conflict with 129 | # the candidates due to the large sigmas for NMS 130 | score = (score_map[batch_output_map[j] != 0] 131 | ).sum() / (len(candidates)) 132 | waypoint_score.append(score.item()) 133 | 134 | ''' measure target to prediction distance ''' 135 | bsp = np.array(batch_source_pos[j]) 136 | btp = np.array(batch_target_pos[j]) 137 | cp = np.array(candidates_pos) 138 | cp_x = np.sin(cp[:,0]) * cp[:,1] + bsp[0] 139 | cp_y = np.cos(cp[:,0]) * cp[:,1] + bsp[1] 140 | cp = np.concatenate( 141 | (np.expand_dims(cp_x, axis=1), 142 | np.expand_dims(cp_y, axis=1)), axis=1) 143 | # take the minimal distance from each target 144 | # to all predictions 145 | tp_dists = cdist(btp, cp) 146 | tp_dist_min = tp_dists.min(1).mean() 147 | pred_distance.append(tp_dist_min) 148 | 149 | # Chamfer distance 150 | predict_to_gt_0 = tp_dists.min(0).mean() 151 | gt_to_predict_0 = tp_dists.min(1).mean() 152 | chamfer_distance = 0.5 * ( 153 | predict_to_gt_0 + gt_to_predict_0) 154 | chamfer_distance_all.append(chamfer_distance) 155 | 156 | # Hausdorff distance 157 | predict_to_gt_1 = tp_dists.min(0).max() 158 | gt_to_predict_1 = tp_dists.min(1).max() 159 | hausdorff_distance = max( 160 | predict_to_gt_1, gt_to_predict_1) 161 | hausdorff_distance_all.append(hausdorff_distance) 162 | 163 | # prediction-target delta 164 | num_target = len(batch_target_pos[j]) 165 | num_predict = len(candidates_pos) 166 | num_delta = num_predict - num_target 167 | num_delta_all.append(num_delta) 168 | 169 | if args.VIS: 170 | import pdb; pdb.set_trace() 171 | save_img_dir = './visualize/%s-best_avg_wayscore'%(args.EXP_ID.split('-')[1]) 172 | if not os.path.exists(save_img_dir): 173 | os.makedirs(save_img_dir) 174 | 175 | im1 = (np.array(batch_target[j])/np.array(batch_target[j]).max()*255).astype('uint8') 176 | batch_x_pos = copy.deepcopy(batch_x[j].numpy()) 177 | batch_x_pos[batch_x_pos<0]=0.0 178 | im2 = (batch_x_pos/batch_x_pos.max()*255).astype('uint8') 179 | im6 = (batch_output_map_sig4[j].numpy()/batch_output_map_sig4[j].numpy().max()*255).astype('uint8') 180 | im7 = (batch_output_map_sig5[j].numpy()/batch_output_map_sig5[j].numpy().max()*255).astype('uint8') 181 | im8 = (batch_output_map_sig7_5[j].numpy()/batch_output_map_sig7_5[j].numpy().max()*255).astype('uint8') 182 | fig = plt.figure(figsize=(10,14)) 183 | fig.add_subplot(1, 5, 1); plt.imshow(im6); plt.axis('off') 184 | fig.add_subplot(1, 5, 2); plt.imshow(im7); plt.axis('off') 185 | fig.add_subplot(1, 5, 3); plt.imshow(im8); plt.axis('off') 186 | fig.add_subplot(1, 5, 4); plt.imshow(im2); plt.axis('off') 187 | fig.add_subplot(1, 5, 5); plt.imshow(im1); plt.axis('off') 188 | plt.savefig(save_img_dir+'/predict-target-%s-%s.jpeg'%(i,j), 189 | bbox_inches='tight') 190 | plt.close() 191 | 192 | p_waypoint_openspace = sum(num_waypoint_openspace) / sum(num_candidate) 193 | p_waypoint_obstacle = sum(num_waypoint_obstacle) / sum(num_candidate) 194 | avg_wayscore = np.mean(waypoint_score).item() 195 | avg_pred_distance = np.mean(pred_distance).item() 196 | avg_chamfer_distance = 
np.mean(chamfer_distance_all).item() 197 | avg_hausdorff_distance = np.mean(hausdorff_distance_all).item() 198 | avg_num_delta = np.mean(num_delta_all).item() 199 | 200 | results['p_waypoint_openspace'] = p_waypoint_openspace 201 | results['p_waypoint_obstacle'] = p_waypoint_obstacle 202 | results['avg_wayscore'] = avg_wayscore 203 | results['avg_pred_distance'] = avg_pred_distance 204 | results['avg_chamfer_distance'] = avg_chamfer_distance 205 | results['avg_hausdorff_distance'] = avg_hausdorff_distance 206 | results['avg_num_delta'] = avg_num_delta 207 | 208 | return results 209 | -------------------------------------------------------------------------------- /transformer/pytorch_transformer/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import tempfile 14 | import fnmatch 15 | from functools import wraps 16 | from hashlib import sha256 17 | from io import open 18 | 19 | import boto3 20 | import requests 21 | from botocore.exceptions import ClientError 22 | from tqdm import tqdm 23 | 24 | try: 25 | from torch.hub import _get_torch_home 26 | torch_cache_home = _get_torch_home() 27 | except ImportError: 28 | torch_cache_home = os.path.expanduser( 29 | os.getenv('TORCH_HOME', os.path.join( 30 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 31 | default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers') 32 | 33 | try: 34 | from urllib.parse import urlparse 35 | except ImportError: 36 | from urlparse import urlparse 37 | 38 | try: 39 | from pathlib import Path 40 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 41 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)) 42 | except (AttributeError, ImportError): 43 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 44 | default_cache_path) 45 | 46 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 47 | 48 | 49 | def url_to_filename(url, etag=None): 50 | """ 51 | Convert `url` into a hashed filename in a repeatable way. 52 | If `etag` is specified, append its hash to the url's, delimited 53 | by a period. 54 | """ 55 | url_bytes = url.encode('utf-8') 56 | url_hash = sha256(url_bytes) 57 | filename = url_hash.hexdigest() 58 | 59 | if etag: 60 | etag_bytes = etag.encode('utf-8') 61 | etag_hash = sha256(etag_bytes) 62 | filename += '.' + etag_hash.hexdigest() 63 | 64 | return filename 65 | 66 | 67 | def filename_to_url(filename, cache_dir=None): 68 | """ 69 | Return the url and etag (which may be ``None``) stored for `filename`. 70 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 
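    For example (hypothetical values), url_to_filename(url, etag) produces
    '<sha256(url)>.<sha256(etag)>'; a matching '<filename>.json' sidecar written by
    get_from_cache stores {'url': ..., 'etag': ...}, and this function reads that
    sidecar back and returns the (url, etag) pair.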
71 | """ 72 | if cache_dir is None: 73 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 74 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 75 | cache_dir = str(cache_dir) 76 | 77 | cache_path = os.path.join(cache_dir, filename) 78 | if not os.path.exists(cache_path): 79 | raise EnvironmentError("file {} not found".format(cache_path)) 80 | 81 | meta_path = cache_path + '.json' 82 | if not os.path.exists(meta_path): 83 | raise EnvironmentError("file {} not found".format(meta_path)) 84 | 85 | with open(meta_path, encoding="utf-8") as meta_file: 86 | metadata = json.load(meta_file) 87 | url = metadata['url'] 88 | etag = metadata['etag'] 89 | 90 | return url, etag 91 | 92 | 93 | def cached_path(url_or_filename, cache_dir=None): 94 | """ 95 | Given something that might be a URL (or might be a local path), 96 | determine which. If it's a URL, download the file and cache it, and 97 | return the path to the cached file. If it's already a local path, 98 | make sure the file exists and then return the path. 99 | """ 100 | if cache_dir is None: 101 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 102 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 103 | url_or_filename = str(url_or_filename) 104 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 105 | cache_dir = str(cache_dir) 106 | 107 | parsed = urlparse(url_or_filename) 108 | 109 | if parsed.scheme in ('http', 'https', 's3'): 110 | # URL, so get it from the cache (downloading if necessary) 111 | return get_from_cache(url_or_filename, cache_dir) 112 | elif os.path.exists(url_or_filename): 113 | # File, and it exists. 114 | return url_or_filename 115 | elif parsed.scheme == '': 116 | # File, but it doesn't exist. 117 | raise EnvironmentError("file {} not found".format(url_or_filename)) 118 | else: 119 | # Something unknown 120 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 121 | 122 | 123 | def split_s3_path(url): 124 | """Split a full s3 path into the bucket name and path.""" 125 | parsed = urlparse(url) 126 | if not parsed.netloc or not parsed.path: 127 | raise ValueError("bad s3 path {}".format(url)) 128 | bucket_name = parsed.netloc 129 | s3_path = parsed.path 130 | # Remove '/' at beginning of path. 131 | if s3_path.startswith("/"): 132 | s3_path = s3_path[1:] 133 | return bucket_name, s3_path 134 | 135 | 136 | def s3_request(func): 137 | """ 138 | Wrapper function for s3 requests in order to create more helpful error 139 | messages. 
140 | """ 141 | 142 | @wraps(func) 143 | def wrapper(url, *args, **kwargs): 144 | try: 145 | return func(url, *args, **kwargs) 146 | except ClientError as exc: 147 | if int(exc.response["Error"]["Code"]) == 404: 148 | raise EnvironmentError("file {} not found".format(url)) 149 | else: 150 | raise 151 | 152 | return wrapper 153 | 154 | 155 | @s3_request 156 | def s3_etag(url): 157 | """Check ETag on S3 object.""" 158 | s3_resource = boto3.resource("s3") 159 | bucket_name, s3_path = split_s3_path(url) 160 | s3_object = s3_resource.Object(bucket_name, s3_path) 161 | return s3_object.e_tag 162 | 163 | 164 | @s3_request 165 | def s3_get(url, temp_file): 166 | """Pull a file directly from S3.""" 167 | s3_resource = boto3.resource("s3") 168 | bucket_name, s3_path = split_s3_path(url) 169 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 170 | 171 | 172 | def http_get(url, temp_file): 173 | req = requests.get(url, stream=True) 174 | content_length = req.headers.get('Content-Length') 175 | total = int(content_length) if content_length is not None else None 176 | progress = tqdm(unit="B", total=total) 177 | for chunk in req.iter_content(chunk_size=1024): 178 | if chunk: # filter out keep-alive new chunks 179 | progress.update(len(chunk)) 180 | temp_file.write(chunk) 181 | progress.close() 182 | 183 | 184 | def get_from_cache(url, cache_dir=None): 185 | """ 186 | Given a URL, look for the corresponding dataset in the local cache. 187 | If it's not there, download it. Then return the path to the cached file. 188 | """ 189 | if cache_dir is None: 190 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 191 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 192 | cache_dir = str(cache_dir) 193 | if sys.version_info[0] == 2 and not isinstance(cache_dir, str): 194 | cache_dir = str(cache_dir) 195 | 196 | if not os.path.exists(cache_dir): 197 | os.makedirs(cache_dir) 198 | 199 | # Get eTag to add to filename, if it exists. 200 | if url.startswith("s3://"): 201 | etag = s3_etag(url) 202 | else: 203 | try: 204 | response = requests.head(url, allow_redirects=True) 205 | if response.status_code != 200: 206 | etag = None 207 | else: 208 | etag = response.headers.get("ETag") 209 | except EnvironmentError: 210 | etag = None 211 | 212 | if sys.version_info[0] == 2 and etag is not None: 213 | etag = etag.decode('utf-8') 214 | filename = url_to_filename(url, etag) 215 | 216 | # get cache path to put the file 217 | cache_path = os.path.join(cache_dir, filename) 218 | 219 | # If we don't have a connection (etag is None) and can't identify the file 220 | # try to get the last downloaded one 221 | if not os.path.exists(cache_path) and etag is None: 222 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 223 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 224 | if matching_files: 225 | cache_path = os.path.join(cache_dir, matching_files[-1]) 226 | 227 | if not os.path.exists(cache_path): 228 | # Download to temporary file, then copy to cache dir once finished. 229 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
230 | with tempfile.NamedTemporaryFile() as temp_file: 231 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 232 | 233 | # GET file object 234 | if url.startswith("s3://"): 235 | s3_get(url, temp_file) 236 | else: 237 | http_get(url, temp_file) 238 | 239 | # we are copying the file before closing it, so flush to avoid truncation 240 | temp_file.flush() 241 | # shutil.copyfileobj() starts at the current position, so go to the start 242 | temp_file.seek(0) 243 | 244 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 245 | with open(cache_path, 'wb') as cache_file: 246 | shutil.copyfileobj(temp_file, cache_file) 247 | 248 | logger.info("creating metadata file for %s", cache_path) 249 | meta = {'url': url, 'etag': etag} 250 | meta_path = cache_path + '.json' 251 | with open(meta_path, 'w') as meta_file: 252 | output_string = json.dumps(meta) 253 | if sys.version_info[0] == 2 and isinstance(output_string, str): 254 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 255 | meta_file.write(output_string) 256 | 257 | logger.info("removing temp file %s", temp_file.name) 258 | 259 | return cache_path 260 | -------------------------------------------------------------------------------- /gen_training_data/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import sys 4 | import torch 5 | 6 | def init_single_node_dict(number=24): 7 | init_dict = {} 8 | for k in range(number): 9 | init_dict[str(k)] = { 10 | 'heading': k * 2 * math.pi / number, 11 | 'has_waypoint': False, 12 | 'waypoint': [], # could be multiple waypoints in a direction 13 | 'obstacle_distance': None, # maximum 2 meters 14 | 'obstacle_index': None, 15 | } 16 | return init_dict 17 | 18 | def horizontal_distance(start, end): 19 | return np.linalg.norm(np.array(start)[[0,2]] - np.array(end)[[0,2]]) 20 | 21 | def get_viewIndex15(heading,number=24): 22 | viewIndex = heading // (2*math.pi/number) 23 | b = heading % (2*math.pi/number) 24 | if (viewIndex == number-1) and (b >= (math.pi/number)): 25 | viewIndex = 0 26 | elif b >= (math.pi/number): 27 | viewIndex += 1 28 | return int(viewIndex) 29 | 30 | def get_distanceIndex12(dist): 31 | distanceIndex = int(dist // 0.25) 32 | # >12 means greater than 3.25m, <1 means shorter than 0.25m 33 | if distanceIndex > 12 or distanceIndex < 1: 34 | distanceIndex = int(-1) 35 | return distanceIndex - 1 36 | 37 | def get_obstacle_distanceIndex12(dist): 38 | # the obstacle distance is measured as the maximum distance 39 | # agent can travel before collision 40 | distanceIndex = int((dist) // 0.25) 41 | if distanceIndex > 11: 42 | distanceIndex = int(-1) 43 | return distanceIndex 44 | 45 | def get_obstacle_info(position, heading, sim): 46 | theta = -(heading - np.pi)/2 47 | rotation = np.quaternion(np.cos(theta),0,np.sin(theta),0) 48 | sim.set_agent_state(position,rotation) 49 | for i in range(12): 50 | sim.step_without_obs(1) 51 | if sim.previous_step_collided: 52 | break 53 | if not sim.previous_step_collided: 54 | return None, None 55 | collided_at = sim.get_agent_state().position 56 | distance = horizontal_distance(position,collided_at) 57 | index = get_obstacle_distanceIndex12(distance) 58 | 59 | return distance, index 60 | 61 | def edge_vec_to_indexes(edge_vec,number=24): 62 | ''' angle index 63 | {0, 1, ..., 23} for 24 angles, 15 degrees separation 64 | ''' 65 | 66 | 67 | angle = -np.arctan2(1.0, 0.0) + np.arctan2(edge_vec[1], edge_vec[0]) 68 | if 
angle < 0.0: 69 | angle += 2 * math.pi 70 | 71 | angleIndex = get_viewIndex15(angle,number=number) 72 | 73 | ''' distance index 74 | {0, 1, ..., 7} for 8 distances, 0.25 meters separation 75 | {-1} denotes the target waypoint is not in 2 meters range 76 | ''' 77 | distance = np.linalg.norm(edge_vec) 78 | distanceIndex = get_distanceIndex12(distance) 79 | 80 | return angle, angleIndex, distance, distanceIndex 81 | 82 | def print_progress(iteration, total, prefix='', suffix='', decimals=1, bar_length=100): 83 | """ 84 | Call in a loop to create terminal progress bar 85 | @params: 86 | iteration - Required : current iteration (Int) 87 | total - Required : total iterations (Int) 88 | prefix - Optional : prefix string (Str) 89 | suffix - Optional : suffix string (Str) 90 | decimals - Optional : positive number of decimals in percent complete (Int) 91 | bar_length - Optional : character length of bar (Int) 92 | """ 93 | str_format = "{0:." + str(decimals) + "f}" 94 | percents = str_format.format(100 * (iteration / float(total))) 95 | filled_length = int(round(bar_length * iteration / float(total))) 96 | bar = '_' * filled_length + '-' * (bar_length - filled_length) 97 | 98 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), 99 | 100 | if iteration == total: 101 | sys.stdout.write('\n') 102 | sys.stdout.flush() 103 | 104 | def init_node_nav_dict(angles): 105 | init_dict = {} 106 | for k in range(angles): 107 | init_dict[str(k)] = { 108 | 'heading': k * math.pi / (angles/2), 109 | 'has_waypoint': False, 110 | 'waypoint': None, # could be multiple waypoints in a direction, but we only consider one 111 | 'obstacle_distance': None, # maximum 2 meters 112 | 'obstacle_index': None, 113 | } 114 | return init_dict 115 | 116 | 117 | def init_node_gt_dict(angles): 118 | init_dict = { 119 | # 'target': np.zeros((24, 8), dtype=np.int8), 120 | # 'weight': np.ones((24, 8)), 121 | 'target': np.zeros((angles, 12), dtype=np.int8), 122 | 'obstacle': np.ones((angles, 12), dtype=np.int8), 123 | 'weight': np.ones((angles, 12)), 124 | 'source_pos': None, 125 | 'target_pos': [], 126 | } 127 | return init_dict 128 | 129 | 130 | def init_node_gt_dict_twm03(): 131 | init_dict = { 132 | # 'target': np.zeros((24, 8), dtype=np.int8), 133 | # 'weight': np.ones((24, 8)), 134 | 'target': np.zeros((24, 12), dtype=np.int8), 135 | 'obstacle': np.ones((24, 12), dtype=np.int8), 136 | 'weight': np.zeros((24, 12)), 137 | } 138 | return init_dict 139 | 140 | 141 | def k_largest_index_argsort(a, k): 142 | idx = np.argsort(a.ravel())[:-k-1:-1] 143 | return np.column_stack(np.unravel_index(idx, a.shape)) 144 | 145 | 146 | def neighborhoods(mu, x_range, y_range, sigma, circular_x=True, gaussian=False): 147 | """ Generate masks centered at mu of the given x and y range with the 148 | origin in the centre of the output 149 | Inputs: 150 | mu: tensor (N, 2) 151 | Outputs: 152 | tensor (N, y_range, s_range) 153 | """ 154 | x_mu = mu[:,0].unsqueeze(1).unsqueeze(1) 155 | y_mu = mu[:,1].unsqueeze(1).unsqueeze(1) 156 | 157 | # Generate bivariate Gaussians centered at position mu 158 | x = torch.arange(start=0,end=x_range, device=mu.device, dtype=mu.dtype).unsqueeze(0).unsqueeze(0) 159 | y = torch.arange(start=0,end=y_range, device=mu.device, dtype=mu.dtype).unsqueeze(1).unsqueeze(0) 160 | 161 | y_diff = y - y_mu 162 | x_diff = x - x_mu 163 | if circular_x: 164 | x_diff = torch.min(torch.abs(x_diff), torch.abs(x_diff + x_range)) 165 | if gaussian: 166 | output = torch.exp(-0.5 * ((x_diff/sigma)**2 + (y_diff/sigma)**2 )) 
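    # (The Gaussian branch above gives soft, decaying masks; the else branch below
    #  gives a hard box of half-width `sigma` around each mu instead.)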
167 | else: 168 | output = torch.logical_and(torch.abs(x_diff) <= sigma, torch.abs(y_diff) <= sigma).type(mu.dtype) 169 | 170 | return output 171 | 172 | 173 | def nms(pred, max_predictions=10, sigma=1.0, gaussian=False): 174 | ''' Input (batch_size, 1, height, width) ''' 175 | 176 | shape = pred.shape 177 | 178 | output = torch.zeros_like(pred) 179 | flat_pred = pred.reshape((shape[0],-1)) # (BATCH_SIZE, 24*48) 180 | supp_pred = pred.clone() 181 | flat_output = output.reshape((shape[0],-1)) # (BATCH_SIZE, 24*48) 182 | 183 | for i in range(max_predictions): 184 | # Find and save max over the entire map 185 | flat_supp_pred = supp_pred.reshape((shape[0],-1)) 186 | val, ix = torch.max(flat_supp_pred, dim=1) 187 | indices = torch.arange(0,shape[0]) 188 | flat_output[indices,ix] = flat_pred[indices,ix] 189 | 190 | # Suppression 191 | y = ix / shape[-1] 192 | x = ix % shape[-1] 193 | mu = torch.stack([x,y], dim=1).float() 194 | 195 | g = neighborhoods(mu, shape[-1], shape[-2], sigma, gaussian=gaussian) 196 | 197 | supp_pred *= (1-g.unsqueeze(1)) 198 | 199 | output[output < 0] = 0 200 | return output 201 | 202 | #!/usr/bin/env python3 203 | 204 | # Copyright (c) Facebook, Inc. and its affiliates. 205 | # This source code is licensed under the MIT license found in the 206 | # LICENSE file in the root directory of this source tree. 207 | 208 | from typing import ( 209 | TYPE_CHECKING, 210 | Any, 211 | Dict, 212 | List, 213 | Optional, 214 | Sequence, 215 | Set, 216 | Union, 217 | cast, 218 | ) 219 | 220 | import numpy as np 221 | from gym import spaces 222 | from gym.spaces.box import Box 223 | from numpy import ndarray 224 | 225 | if TYPE_CHECKING: 226 | from torch import Tensor 227 | 228 | from habitat_sim.simulator import MutableMapping, MutableMapping_T 229 | from habitat.sims.habitat_simulator.habitat_simulator import HabitatSim 230 | from habitat.core.registry import registry 231 | from habitat.core.simulator import ( 232 | Config, 233 | VisualObservation, 234 | ) 235 | from habitat.core.spaces import Space 236 | 237 | @registry.register_simulator(name="Sim-v1") 238 | class Simulator(HabitatSim): 239 | r"""Simulator wrapper over habitat-sim 240 | 241 | habitat-sim repo: https://github.com/facebookresearch/habitat-sim 242 | 243 | Args: 244 | config: configuration for initializing the simulator. 
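    Example (a sketch; `sim_config`, `position` and `rotation` are placeholders):
        sim = Simulator(sim_config)
        sim.set_agent_state(position, rotation)
        sim.step_without_obs(1)  # action 1 (move forward) without rendering observations
        collided = sim.previous_step_collided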
245 | """ 246 | 247 | def __init__(self, config: Config) -> None: 248 | super().__init__(config) 249 | 250 | def step_without_obs(self, 251 | action: Union[str, int, MutableMapping_T[int, Union[str, int]]], 252 | dt: float = 1.0 / 60.0,): 253 | self._num_total_frames += 1 254 | if isinstance(action, MutableMapping): 255 | return_single = False 256 | else: 257 | action = cast(Dict[int, Union[str, int]], {self._default_agent_id: action}) 258 | return_single = True 259 | collided_dict: Dict[int, bool] = {} 260 | for agent_id, agent_act in action.items(): 261 | agent = self.get_agent(agent_id) 262 | collided_dict[agent_id] = agent.act(agent_act) 263 | self.__last_state[agent_id] = agent.get_state() 264 | 265 | # # step physics by dt 266 | # step_start_Time = time.time() 267 | # super().step_world(dt) 268 | # self._previous_step_time = time.time() - step_start_Time 269 | 270 | multi_observations = {} 271 | for agent_id in action.keys(): 272 | agent_observation = {} 273 | agent_observation["collided"] = collided_dict[agent_id] 274 | multi_observations[agent_id] = agent_observation 275 | 276 | 277 | if return_single: 278 | sim_obs = multi_observations[self._default_agent_id] 279 | else: 280 | sim_obs = multi_observations 281 | 282 | self._prev_sim_obs = sim_obs -------------------------------------------------------------------------------- /waypoint_predictor.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import argparse 4 | from dataloader import RGBDepthPano 5 | 6 | from image_encoders import RGBEncoder, DepthEncoder 7 | from TRM_net import BinaryDistPredictor_TRM, TRM_predict 8 | 9 | from eval import waypoint_eval 10 | 11 | import os 12 | import glob 13 | import utils 14 | import random 15 | from utils import nms 16 | from utils import print_progress 17 | from tensorboardX import SummaryWriter 18 | 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | def setup(args): 22 | torch.manual_seed(0) 23 | random.seed(0) 24 | exp_log_path = './checkpoints/%s/'%(args.EXP_ID) 25 | os.makedirs(exp_log_path, exist_ok=True) 26 | exp_log_path = './checkpoints/%s/snap/'%(args.EXP_ID) 27 | os.makedirs(exp_log_path, exist_ok=True) 28 | 29 | class Param(): 30 | def __init__(self): 31 | self.parser = argparse.ArgumentParser(description='Train waypoint predictor') 32 | 33 | self.parser.add_argument('--EXP_ID', type=str, default='test_0') 34 | self.parser.add_argument('--TRAINEVAL', type=str, default='train', help='trian or eval mode') 35 | self.parser.add_argument('--VIS', type=int, default=0, help='visualize predicted hearmaps') 36 | # self.parser.add_argument('--LOAD_EPOCH', type=int, default=None, help='specific an epoch to load for eval') 37 | 38 | self.parser.add_argument('--ANGLES', type=int, default=24) 39 | self.parser.add_argument('--NUM_IMGS', type=int, default=24) 40 | self.parser.add_argument('--NUM_CLASSES', type=int, default=12) 41 | self.parser.add_argument('--MAX_NUM_CANDIDATES', type=int, default=5) 42 | 43 | self.parser.add_argument('--PREDICTOR_NET', type=str, default='TRM', help='TRM only') 44 | 45 | self.parser.add_argument('--EPOCH', type=int, default=10) 46 | self.parser.add_argument('--BATCH_SIZE', type=int, default=2) 47 | self.parser.add_argument('--LEARNING_RATE', type=float, default=1e-4) 48 | self.parser.add_argument('--WEIGHT', type=int, default=0, help='weight the target map') 49 | 50 | self.parser.add_argument('--TRM_LAYER', default=2, type=int, help='number of TRM hidden layers') 51 | 
self.parser.add_argument('--TRM_NEIGHBOR', default=2, type=int, help='number of attention mask neighbor') 52 | self.parser.add_argument('--HEATMAP_OFFSET', default=2, type=int, help='an offset determined by image FoV and number of images') 53 | self.parser.add_argument('--HIDDEN_DIM', default=768, type=int) 54 | 55 | self.args = self.parser.parse_args() 56 | 57 | def predict_waypoints(args): 58 | 59 | print('\nArguments', args) 60 | log_dir = './checkpoints/%s/tensorboard/'%(args.EXP_ID) 61 | writer = SummaryWriter(log_dir=log_dir) 62 | 63 | ''' networks ''' 64 | rgb_encoder = RGBEncoder(resnet_pretrain=True, trainable=False).to(device) 65 | depth_encoder = DepthEncoder(resnet_pretrain=True, trainable=False).to(device) 66 | if args.PREDICTOR_NET == 'TRM': 67 | print('\nUsing TRM predictor') 68 | print('HIDDEN_DIM default to 768') 69 | args.HIDDEN_DIM = 768 70 | predictor = BinaryDistPredictor_TRM(args=args, 71 | hidden_dim=args.HIDDEN_DIM, n_classes=args.NUM_CLASSES).to(device) 72 | 73 | ''' load navigability (gt waypoints, obstacles and weights) ''' 74 | navigability_dict = utils.load_gt_navigability( 75 | './training_data/%s_*_mp3d_waypoint_twm0.2_obstacle_first_withpos.json'%(args.ANGLES)) 76 | 77 | ''' dataloader for rgb and depth images ''' 78 | train_img_dir = './gen_training_data/rgbd_fov90/train/*/*.pkl' 79 | traindataloader = RGBDepthPano(args, train_img_dir, navigability_dict) 80 | eval_img_dir = './gen_training_data/rgbd_fov90/val_unseen/*/*.pkl' 81 | evaldataloader = RGBDepthPano(args, eval_img_dir, navigability_dict) 82 | if args.TRAINEVAL == 'train': 83 | trainloader = torch.utils.data.DataLoader(traindataloader, 84 | batch_size=args.BATCH_SIZE, shuffle=True, num_workers=4) 85 | evalloader = torch.utils.data.DataLoader(evaldataloader, 86 | batch_size=args.BATCH_SIZE, shuffle=False, num_workers=4) 87 | 88 | ''' optimization ''' 89 | criterion_bcel = torch.nn.BCEWithLogitsLoss(reduction='none') 90 | criterion_mse = torch.nn.MSELoss(reduction='none') 91 | 92 | params = list(predictor.parameters()) 93 | optimizer = torch.optim.AdamW(params, lr=args.LEARNING_RATE) 94 | 95 | ''' training loop ''' 96 | if args.TRAINEVAL == 'train': 97 | print('\nTraining starts') 98 | best_val_1 = {"avg_wayscore": 0.0, "log_string": '', "update":False} 99 | best_val_2 = {"avg_pred_distance": 10.0, "log_string": '', "update":False} 100 | 101 | for epoch in range(args.EPOCH): # loop over the dataset multiple times 102 | sum_loss = 0.0 103 | 104 | rgb_encoder.eval() 105 | depth_encoder.eval() 106 | predictor.train() 107 | 108 | for i, data in enumerate(trainloader): 109 | scan_ids = data['scan_id'] 110 | waypoint_ids = data['waypoint_id'] 111 | rgb_imgs = data['rgb'].to(device) 112 | depth_imgs = data['depth'].to(device) 113 | 114 | ''' checking image orientation ''' 115 | # from PIL import Image 116 | # from matplotlib import pyplot 117 | # import numpy as np 118 | # # import pdb; pdb.set_trace() 119 | # out_img = np.swapaxes( 120 | # np.swapaxes( 121 | # data['no_trans_rgb'][0].cpu().numpy(), 1,2), 122 | # 2, 3) 123 | # for kk, out_img_i in enumerate(out_img): 124 | # im = Image.fromarray(out_img_i) 125 | # im.save("./play/%s.png"%(kk)) 126 | # pyplot.imsave("./play/mpl_%s.png"%(kk), out_img_i) 127 | # out_depth = data['no_trans_depth'][0].cpu().numpy() * 255 128 | # out_depth = out_depth.astype(np.uint8) 129 | # for kk, out_depth_i in enumerate(out_depth): 130 | # im = Image.fromarray(out_depth_i) 131 | # im.save("./play/depth_%s.png"%(kk)) 132 | 133 | ''' processing observations ''' 134 | rgb_feats 
= rgb_encoder(rgb_imgs) # (BATCH_SIZE*ANGLES, 2048) 135 | depth_feats = depth_encoder(depth_imgs) # (BATCH_SIZE*ANGLES, 128, 4, 4) 136 | 137 | ''' learning objectives ''' 138 | target, obstacle, weight, _, _ = utils.get_gt_nav_map( 139 | args.ANGLES, navigability_dict, scan_ids, waypoint_ids) 140 | target = target.to(device) 141 | obstacle = obstacle.to(device) 142 | weight = weight.to(device) 143 | 144 | if args.PREDICTOR_NET == 'TRM': 145 | vis_logits = TRM_predict('train', args, 146 | predictor, rgb_feats, depth_feats) 147 | 148 | loss_vis = criterion_mse(vis_logits, target) 149 | if args.WEIGHT: 150 | loss_vis = loss_vis * weight 151 | total_loss = loss_vis.sum() / vis_logits.size(0) / args.ANGLES 152 | 153 | total_loss.backward() 154 | optimizer.step() 155 | sum_loss += total_loss.item() 156 | 157 | print_progress(i+1, len(trainloader), prefix='Epoch: %d/%d'%((epoch+1),args.EPOCH)) 158 | writer.add_scalar("Train/Loss", sum_loss/(i+1), epoch) 159 | print('Train Loss: %.5f' % (sum_loss/(i+1))) # (epoch+1),args.EPOCH 160 | 161 | ''' evaluation - inference ''' 162 | # print('Evaluation ...') 163 | sum_loss = 0.0 164 | predictions = {'sample_id': [], 165 | 'source_pos': [], 'target_pos': [], 166 | 'probs': [], 'logits': [], 167 | 'target': [], 'obstacle': [], 'sample_loss': []} 168 | 169 | rgb_encoder.eval() 170 | depth_encoder.eval() 171 | predictor.eval() 172 | 173 | for i, data in enumerate(evalloader): 174 | scan_ids = data['scan_id'] 175 | waypoint_ids = data['waypoint_id'] 176 | sample_id = data['sample_id'] 177 | rgb_imgs = data['rgb'].to(device) 178 | depth_imgs = data['depth'].to(device) 179 | 180 | target, obstacle, weight, \ 181 | source_pos, target_pos = utils.get_gt_nav_map( 182 | args.ANGLES, navigability_dict, scan_ids, waypoint_ids) 183 | target = target.to(device) 184 | obstacle = obstacle.to(device) 185 | weight = weight.to(device) 186 | 187 | ''' processing observations ''' 188 | rgb_feats = rgb_encoder(rgb_imgs) # (BATCH_SIZE*ANGLES, 2048) 189 | depth_feats = depth_encoder(depth_imgs) # (BATCH_SIZE*ANGLES, 128, 4, 4) 190 | 191 | if args.PREDICTOR_NET == 'TRM': 192 | vis_probs, vis_logits = TRM_predict('eval', args, 193 | predictor, rgb_feats, depth_feats) 194 | overall_probs = vis_probs 195 | overall_logits = vis_logits 196 | loss_vis = criterion_mse(vis_logits, target) 197 | if args.WEIGHT: 198 | loss_vis = loss_vis * weight 199 | sample_loss = loss_vis.sum(-1).sum(-1) / args.ANGLES 200 | total_loss = loss_vis.sum() / vis_logits.size(0) / args.ANGLES 201 | 202 | sum_loss += total_loss.item() 203 | predictions['sample_id'].append(sample_id) 204 | predictions['source_pos'].append(source_pos) 205 | predictions['target_pos'].append(target_pos) 206 | predictions['probs'].append(overall_probs.tolist()) 207 | predictions['logits'].append((overall_logits.tolist())) 208 | predictions['target'].append(target.tolist()) 209 | predictions['obstacle'].append(obstacle.tolist()) 210 | predictions['sample_loss'].append(target.tolist()) 211 | 212 | print('Eval Loss: %.5f' % (sum_loss/(i+1))) 213 | results = waypoint_eval(args, predictions) 214 | writer.add_scalar("Evaluation/Loss", sum_loss/(i+1), epoch) 215 | writer.add_scalar("Evaluation/p_waypoint_openspace", results['p_waypoint_openspace'], epoch) 216 | writer.add_scalar("Evaluation/p_waypoint_obstacle", results['p_waypoint_obstacle'], epoch) 217 | writer.add_scalar("Evaluation/avg_wayscore", results['avg_wayscore'], epoch) 218 | writer.add_scalar("Evaluation/avg_pred_distance", results['avg_pred_distance'], epoch) 219 | log_string 
= 'Epoch %s '%(epoch) 220 | for key, value in results.items(): 221 | if key != 'candidates': 222 | log_string += '{} {:.5f} | '.format(str(key), value) 223 | print(log_string) 224 | 225 | # save checkpoint 226 | if results['avg_wayscore'] > best_val_1['avg_wayscore']: 227 | checkpoint_save_path = './checkpoints/%s/snap/check_val_best_avg_wayscore'%(args.EXP_ID) #, epoch+1 228 | utils.save_checkpoint(epoch+1, predictor, optimizer, checkpoint_save_path) 229 | print('New best avg_wayscore result found, checkpoint saved to %s'%(checkpoint_save_path)) 230 | best_val_1['avg_wayscore'] = results['avg_wayscore'] 231 | best_val_1['log_string'] = log_string 232 | checkpoint_reg_save_path = './checkpoints/%s/snap/check_latest'%(args.EXP_ID) #, epoch+1 233 | utils.save_checkpoint(epoch+1, predictor, optimizer, checkpoint_reg_save_path) 234 | print('Best avg_wayscore result til now: ', best_val_1['log_string']) 235 | 236 | if results['avg_pred_distance'] < best_val_2['avg_pred_distance']: 237 | checkpoint_save_path = './checkpoints/%s/snap/check_val_best_avg_pred_distance'%(args.EXP_ID) #, epoch+1 238 | utils.save_checkpoint(epoch+1, predictor, optimizer, checkpoint_save_path) 239 | print('New best avg_pred_distance result found, checkpoint saved to %s'%(checkpoint_save_path)) 240 | best_val_2['avg_pred_distance'] = results['avg_pred_distance'] 241 | best_val_2['log_string'] = log_string 242 | checkpoint_reg_save_path = './checkpoints/%s/snap/check_latest'%(args.EXP_ID) #, epoch+1 243 | utils.save_checkpoint(epoch+1, predictor, optimizer, checkpoint_reg_save_path) 244 | print('Best avg_pred_distance result til now: ', best_val_2['log_string']) 245 | 246 | elif args.TRAINEVAL == 'eval': 247 | ''' evaluation - inference (with a bit mixture-of-experts) ''' 248 | print('\nEvaluation mode, please doublecheck EXP_ID and LOAD_EPOCH') 249 | checkpoint_load_path = './checkpoints/%s/snap/check_val_best_avg_wayscore'%(args.EXP_ID) #args.LOAD_EPOCH 250 | epoch, predictor, optimizer = utils.load_checkpoint( 251 | predictor, optimizer, checkpoint_load_path) 252 | 253 | sum_loss = 0.0 254 | predictions = {'sample_id': [], 255 | 'source_pos': [], 'target_pos': [], 256 | 'probs': [], 'logits': [], 257 | 'target': [], 'obstacle': [], 'sample_loss': []} 258 | 259 | rgb_encoder.eval() 260 | depth_encoder.eval() 261 | predictor.eval() 262 | 263 | for i, data in enumerate(evalloader): 264 | if args.VIS and i == 5: 265 | break 266 | 267 | scan_ids = data['scan_id'] 268 | waypoint_ids = data['waypoint_id'] 269 | sample_id = data['sample_id'] 270 | rgb_imgs = data['rgb'].to(device) 271 | depth_imgs = data['depth'].to(device) 272 | 273 | target, obstacle, weight, \ 274 | source_pos, target_pos = utils.get_gt_nav_map( 275 | args.ANGLES, navigability_dict, scan_ids, waypoint_ids) 276 | target = target.to(device) 277 | obstacle = obstacle.to(device) 278 | weight = weight.to(device) 279 | 280 | ''' processing observations ''' 281 | rgb_feats = rgb_encoder(rgb_imgs) # (BATCH_SIZE*ANGLES, 2048) 282 | depth_feats = depth_encoder(depth_imgs) # (BATCH_SIZE*ANGLES, 128, 4, 4) 283 | 284 | ''' predicting the waypoint probabilities ''' 285 | if args.PREDICTOR_NET == 'TRM': 286 | vis_probs, vis_logits = TRM_predict('eval', args, 287 | predictor, rgb_feats, depth_feats) 288 | overall_probs = vis_probs 289 | overall_logits = vis_logits 290 | loss_vis = criterion_mse(vis_logits, target) 291 | 292 | if args.WEIGHT: 293 | loss_vis = loss_vis * weight 294 | sample_loss = loss_vis.sum(-1).sum(-1) / args.ANGLES 295 | total_loss = loss_vis.sum() / 
vis_logits.size(0) / args.ANGLES 296 | 297 | sum_loss += total_loss.item() 298 | predictions['sample_id'].append(sample_id) 299 | predictions['source_pos'].append(source_pos) 300 | predictions['target_pos'].append(target_pos) 301 | predictions['probs'].append(overall_probs.tolist()) 302 | predictions['logits'].append(overall_logits.tolist()) 303 | predictions['target'].append(target.tolist()) 304 | predictions['obstacle'].append(obstacle.tolist()) 305 | predictions['sample_loss'].append(target.tolist()) 306 | 307 | print('Eval Loss: %.5f' % (sum_loss/(i+1))) 308 | results = waypoint_eval(args, predictions) 309 | log_string = 'Epoch %s '%(epoch) 310 | for key, value in results.items(): 311 | if key != 'candidates': 312 | log_string += '{} {:.5f} | '.format(str(key), value) 313 | print(log_string) 314 | print('Evaluation Done') 315 | 316 | else: 317 | RunningModeError 318 | 319 | if __name__ == "__main__": 320 | param = Param() 321 | args = param.args 322 | setup(args) 323 | 324 | if args.VIS: 325 | assert args.TRAINEVAL == 'eval' 326 | 327 | predict_waypoints(args) 328 | -------------------------------------------------------------------------------- /transformer/pytorch_transformer/modeling_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import copy 22 | import json 23 | import logging 24 | import os 25 | from io import open 26 | 27 | import six 28 | import torch 29 | from torch import nn 30 | from torch.nn import CrossEntropyLoss 31 | from torch.nn import functional as F 32 | 33 | from .file_utils import cached_path 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | CONFIG_NAME = "config.json" 38 | WEIGHTS_NAME = "pytorch_model.bin" 39 | TF_WEIGHTS_NAME = 'model.ckpt' 40 | 41 | 42 | try: 43 | from torch.nn import Identity 44 | except ImportError: 45 | # Older PyTorch compatibility 46 | class Identity(nn.Module): 47 | r"""A placeholder identity operator that is argument-insensitive. 48 | """ 49 | def __init__(self, *args, **kwargs): 50 | super(Identity, self).__init__() 51 | 52 | def forward(self, input): 53 | return input 54 | 55 | 56 | if not six.PY2: 57 | def add_start_docstrings(*docstr): 58 | def docstring_decorator(fn): 59 | fn.__doc__ = ''.join(docstr) + fn.__doc__ 60 | return fn 61 | return docstring_decorator 62 | else: 63 | # Not possible to update class docstrings on python2 64 | def add_start_docstrings(*docstr): 65 | def docstring_decorator(fn): 66 | return fn 67 | return docstring_decorator 68 | 69 | 70 | class PretrainedConfig(object): 71 | """ Base class for all configuration classes. 
72 | Handle a few common parameters and methods for loading/downloading/saving configurations. 73 | """ 74 | pretrained_config_archive_map = {} 75 | 76 | def __init__(self, **kwargs): 77 | self.finetuning_task = kwargs.pop('finetuning_task', None) 78 | self.num_labels = kwargs.pop('num_labels', 2) 79 | self.output_attentions = kwargs.pop('output_attentions', False) 80 | self.output_hidden_states = kwargs.pop('output_hidden_states', False) 81 | self.torchscript = kwargs.pop('torchscript', False) 82 | 83 | def save_pretrained(self, save_directory): 84 | """ Save a configuration object to a directory, so that it 85 | can be re-loaded using the `from_pretrained(save_directory)` class method. 86 | """ 87 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" 88 | 89 | # If we save using the predefined names, we can load using `from_pretrained` 90 | output_config_file = os.path.join(save_directory, CONFIG_NAME) 91 | 92 | self.to_json_file(output_config_file) 93 | 94 | @classmethod 95 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 96 | r""" Instantiate a PretrainedConfig from a pre-trained model configuration. 97 | 98 | Params: 99 | **pretrained_model_name_or_path**: either: 100 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache 101 | or download and cache if not already stored in cache (e.g. 'bert-base-uncased'). 102 | - a path to a `directory` containing a configuration file saved 103 | using the `save_pretrained(save_directory)` method. 104 | - a path or url to a saved configuration `file`. 105 | **cache_dir**: (`optional`) string: 106 | Path to a directory in which a downloaded pre-trained model 107 | configuration should be cached if the standard cache should not be used. 108 | **return_unused_kwargs**: (`optional`) bool: 109 | - If False, then this function returns just the final configuration object. 110 | - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` 111 | is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: 112 | ie the part of kwargs which has not been used to update `config` and is otherwise ignored. 113 | **kwargs**: (`optional`) dict: 114 | Dictionary of key/value pairs with which to update the configuration object after loading. 115 | - The values in kwargs of any keys which are configuration attributes will be used 116 | to override the loaded values. 117 | - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled 118 | by the `return_unused_kwargs` keyword parameter. 119 | 120 | Examples:: 121 | 122 | >>> config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. 123 | >>> config = BertConfig.from_pretrained('./test/saved_model/') # E.g. 
config (or model) was saved using `save_pretrained('./test/saved_model/')` 124 | >>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') 125 | >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) 126 | >>> assert config.output_attention == True 127 | >>> config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, 128 | >>> foo=False, return_unused_kwargs=True) 129 | >>> assert config.output_attention == True 130 | >>> assert unused_kwargs == {'foo': False} 131 | 132 | """ 133 | cache_dir = kwargs.pop('cache_dir', None) 134 | return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) 135 | 136 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map: 137 | config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] 138 | elif os.path.isdir(pretrained_model_name_or_path): 139 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) 140 | else: 141 | config_file = pretrained_model_name_or_path 142 | # redirect to the cache, if necessary 143 | try: 144 | resolved_config_file = cached_path(config_file, cache_dir=cache_dir) 145 | except EnvironmentError: 146 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map: 147 | logger.error( 148 | "Couldn't reach server at '{}' to download pretrained model configuration file.".format( 149 | config_file)) 150 | else: 151 | logger.error( 152 | "Model name '{}' was not found in model name list ({}). " 153 | "We assumed '{}' was a path or url but couldn't find any file " 154 | "associated to this path or url.".format( 155 | pretrained_model_name_or_path, 156 | ', '.join(cls.pretrained_config_archive_map.keys()), 157 | config_file)) 158 | return None 159 | if resolved_config_file == config_file: 160 | logger.info("loading configuration file {}".format(config_file)) 161 | else: 162 | logger.info("loading configuration file {} from cache at {}".format( 163 | config_file, resolved_config_file)) 164 | 165 | # Load config 166 | config = cls.from_json_file(resolved_config_file) 167 | 168 | # Update config with kwargs if needed 169 | to_remove = [] 170 | for key, value in kwargs.items(): 171 | if hasattr(config, key): 172 | setattr(config, key, value) 173 | to_remove.append(key) 174 | for key in to_remove: 175 | kwargs.pop(key, None) 176 | 177 | logger.info("Model config %s", config) 178 | if return_unused_kwargs: 179 | return config, kwargs 180 | else: 181 | return config 182 | 183 | @classmethod 184 | def from_dict(cls, json_object): 185 | """Constructs a `Config` from a Python dictionary of parameters.""" 186 | config = cls(vocab_size_or_config_json_file=-1) 187 | for key, value in json_object.items(): 188 | config.__dict__[key] = value 189 | return config 190 | 191 | @classmethod 192 | def from_json_file(cls, json_file): 193 | """Constructs a `BertConfig` from a json file of parameters.""" 194 | with open(json_file, "r", encoding='utf-8') as reader: 195 | text = reader.read() 196 | return cls.from_dict(json.loads(text)) 197 | 198 | def __eq__(self, other): 199 | return self.__dict__ == other.__dict__ 200 | 201 | def __repr__(self): 202 | return str(self.to_json_string()) 203 | 204 | def to_dict(self): 205 | """Serializes this instance to a Python dictionary.""" 206 | output = copy.deepcopy(self.__dict__) 207 | return output 208 | 209 | def to_json_string(self): 210 | """Serializes this instance to a JSON string.""" 211 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 
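    # A round-trip sketch (BertConfig from modeling_bert.py is one such subclass):
    #   config = BertConfig.from_pretrained('bert-base-uncased')
    #   config.to_json_file('./saved_config.json')
    #   config = BertConfig.from_json_file('./saved_config.json')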
212 | 213 | def to_json_file(self, json_file_path): 214 | """ Save this instance to a json file.""" 215 | with open(json_file_path, "w", encoding='utf-8') as writer: 216 | writer.write(self.to_json_string()) 217 | 218 | 219 | class PreTrainedModel(nn.Module): 220 | """ Base class for all models. Handle loading/storing model config and 221 | a simple interface for dowloading and loading pretrained models. 222 | """ 223 | config_class = PretrainedConfig 224 | pretrained_model_archive_map = {} 225 | load_tf_weights = lambda model, config, path: None 226 | base_model_prefix = "" 227 | input_embeddings = None 228 | 229 | def __init__(self, config, *inputs, **kwargs): 230 | super(PreTrainedModel, self).__init__() 231 | if not isinstance(config, PretrainedConfig): 232 | raise ValueError( 233 | "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. " 234 | "To create a model from a pretrained model use " 235 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 236 | self.__class__.__name__, self.__class__.__name__ 237 | )) 238 | # Save config in model 239 | self.config = config 240 | 241 | def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None): 242 | """ Build a resized Embedding Module from a provided token Embedding Module. 243 | Increasing the size will add newly initialized vectors at the end 244 | Reducing the size will remove vectors from the end 245 | 246 | Args: 247 | new_num_tokens: (`optional`) int 248 | New number of tokens in the embedding matrix. 249 | Increasing the size will add newly initialized vectors at the end 250 | Reducing the size will remove vectors from the end 251 | If not provided or None: return the provided token Embedding Module. 252 | Return: ``torch.nn.Embeddings`` 253 | Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None 254 | """ 255 | if new_num_tokens is None: 256 | return old_embeddings 257 | 258 | old_num_tokens, old_embedding_dim = old_embeddings.weight.size() 259 | if old_num_tokens == new_num_tokens: 260 | return old_embeddings 261 | 262 | # Build new embeddings 263 | new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim) 264 | new_embeddings.to(old_embeddings.weight.device) 265 | 266 | # initialize all new embeddings (in particular added tokens) 267 | self.init_weights(new_embeddings) 268 | 269 | # Copy word embeddings from the previous weights 270 | num_tokens_to_copy = min(old_num_tokens, new_num_tokens) 271 | new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :] 272 | 273 | return new_embeddings 274 | 275 | def _tie_or_clone_weights(self, first_module, second_module): 276 | """ Tie or clone module weights depending of weither we are using TorchScript or not 277 | """ 278 | if self.config.torchscript: 279 | first_module.weight = nn.Parameter(second_module.weight.clone()) 280 | else: 281 | first_module.weight = second_module.weight 282 | 283 | def resize_token_embeddings(self, new_num_tokens=None): 284 | """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. 285 | Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. 286 | 287 | Args: 288 | new_num_tokens: (`optional`) int 289 | New number of tokens in the embedding matrix. 
290 | Increasing the size will add newly initialized vectors at the end 291 | Reducing the size will remove vectors from the end 292 | If not provided or None: does nothing and just returns a pointer to the input tokens Embedding Module of the model. 293 | 294 | Return: ``torch.nn.Embeddings`` 295 | Pointer to the input tokens Embedding Module of the model 296 | """ 297 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed 298 | model_embeds = base_model._resize_token_embeddings(new_num_tokens) 299 | if new_num_tokens is None: 300 | return model_embeds 301 | 302 | # Update base model and current model config 303 | self.config.vocab_size = new_num_tokens 304 | base_model.vocab_size = new_num_tokens 305 | 306 | # Tie weights again if needed 307 | if hasattr(self, 'tie_weights'): 308 | self.tie_weights() 309 | 310 | return model_embeds 311 | 312 | def prune_heads(self, heads_to_prune): 313 | """ Prunes heads of the base model. 314 | Args: 315 | heads_to_prune: dict of {layer_num (int): list of heads to prune in this layer (list of int)} 316 | """ 317 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed 318 | base_model._prune_heads(heads_to_prune) 319 | 320 | def save_pretrained(self, save_directory): 321 | """ Save a model with its configuration file to a directory, so that it 322 | can be re-loaded using the `from_pretrained(save_directory)` class method. 323 | """ 324 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" 325 | 326 | # Only save the model it-self if we are using distributed training 327 | model_to_save = self.module if hasattr(self, 'module') else self 328 | 329 | # Save configuration file 330 | model_to_save.config.save_pretrained(save_directory) 331 | 332 | # If we save using the predefined names, we can load using `from_pretrained` 333 | output_model_file = os.path.join(save_directory, WEIGHTS_NAME) 334 | 335 | torch.save(model_to_save.state_dict(), output_model_file) 336 | 337 | @classmethod 338 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): 339 | r"""Instantiate a pretrained pytorch model from a pre-trained model configuration. 340 | 341 | The model is set in evaluation mode by default using `model.eval()` (Dropout modules are desactivated) 342 | To train the model, you should first set it back in training mode with `model.train()` 343 | 344 | Params: 345 | **pretrained_model_name_or_path**: either: 346 | - a string with the `shortcut name` of a pre-trained model to load from cache 347 | or download and cache if not already stored in cache (e.g. 'bert-base-uncased'). 348 | - a path to a `directory` containing a configuration file saved 349 | using the `save_pretrained(save_directory)` method. 350 | - a path or url to a tensorflow index checkpoint `file` (e.g. `./tf_model/model.ckpt.index`). 351 | In this case, ``from_tf`` should be set to True and a configuration object should be 352 | provided as `config` argument. This loading option is slower than converting the TensorFlow 353 | checkpoint in a PyTorch model using the provided conversion scripts and loading 354 | the PyTorch model afterwards. 355 | **model_args**: (`optional`) Sequence: 356 | All remaning positional arguments will be passed to the underlying model's __init__ function 357 | **config**: an optional configuration for the model to use instead of an automatically loaded configuation. 
358 | Configuration can be automatically loaded when: 359 | - the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or 360 | - the model was saved using the `save_pretrained(save_directory)` (loaded by suppling the save directory). 361 | **state_dict**: an optional state dictionnary for the model to use instead of a state dictionary loaded 362 | from saved weights file. 363 | This option can be used if you want to create a model from a pretrained configuraton but load your own weights. 364 | In this case though, you should check if using `save_pretrained(dir)` and `from_pretrained(save_directory)` is not 365 | a simpler option. 366 | **cache_dir**: (`optional`) string: 367 | Path to a directory in which a downloaded pre-trained model 368 | configuration should be cached if the standard cache should not be used. 369 | **output_loading_info**: (`optional`) boolean: 370 | Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. 371 | **kwargs**: (`optional`) dict: 372 | Dictionary of key, values to update the configuration object after loading. 373 | Can be used to override selected configuration parameters. E.g. ``output_attention=True``. 374 | 375 | - If a configuration is provided with `config`, **kwargs will be directly passed 376 | to the underlying model's __init__ method. 377 | - If a configuration is not provided, **kwargs will be first passed to the pretrained 378 | model configuration class loading function (`PretrainedConfig.from_pretrained`). 379 | Each key of **kwargs that corresponds to a configuration attribute 380 | will be used to override said attribute with the supplied **kwargs value. 381 | Remaining keys that do not correspond to any configuration attribute will 382 | be passed to the underlying model's __init__ function. 383 | 384 | Examples:: 385 | 386 | >>> model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. 387 | >>> model = BertModel.from_pretrained('./test/saved_model/') # E.g. 
model was saved using `save_pretrained('./test/saved_model/')` 388 | >>> model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading 389 | >>> assert model.config.output_attention == True 390 | >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) 391 | >>> config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json') 392 | >>> model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config) 393 | 394 | """ 395 | config = kwargs.pop('config', None) 396 | state_dict = kwargs.pop('state_dict', None) 397 | cache_dir = kwargs.pop('cache_dir', None) 398 | from_tf = kwargs.pop('from_tf', False) 399 | output_loading_info = kwargs.pop('output_loading_info', False) 400 | 401 | # Load config 402 | if config is None: 403 | config, model_kwargs = cls.config_class.from_pretrained( 404 | pretrained_model_name_or_path, *model_args, 405 | cache_dir=cache_dir, return_unused_kwargs=True, 406 | **kwargs 407 | ) 408 | else: 409 | model_kwargs = kwargs 410 | 411 | # Load model 412 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map: 413 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path] 414 | elif os.path.isdir(pretrained_model_name_or_path): 415 | if from_tf: 416 | # Directly load from a TensorFlow checkpoint 417 | archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index") 418 | else: 419 | archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) 420 | else: 421 | if from_tf: 422 | # Directly load from a TensorFlow checkpoint 423 | archive_file = pretrained_model_name_or_path + ".index" 424 | else: 425 | archive_file = pretrained_model_name_or_path 426 | # redirect to the cache, if necessary 427 | try: 428 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 429 | except EnvironmentError: 430 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map: 431 | logger.error( 432 | "Couldn't reach server at '{}' to download pretrained weights.".format( 433 | archive_file)) 434 | else: 435 | logger.error( 436 | "Model name '{}' was not found in model name list ({}). " 437 | "We assumed '{}' was a path or url but couldn't find any file " 438 | "associated to this path or url.".format( 439 | pretrained_model_name_or_path, 440 | ', '.join(cls.pretrained_model_archive_map.keys()), 441 | archive_file)) 442 | return None 443 | if resolved_archive_file == archive_file: 444 | logger.info("loading weights file {}".format(archive_file)) 445 | else: 446 | logger.info("loading weights file {} from cache at {}".format( 447 | archive_file, resolved_archive_file)) 448 | 449 | # Instantiate model. 
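# The remaining steps below: (1) build the model from the resolved config with freshly
# initialized weights, (2) load the checkpoint state_dict, remapping legacy TensorFlow-style
# 'gamma'/'beta' parameter names to 'weight'/'bias' and adding or stripping the base-model
# prefix so both bare and task-specific models load cleanly, (3) re-tie shared embeddings
# via tie_weights() when the class defines it, and put the model in eval() mode.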
450 | model = cls(config, *model_args, **model_kwargs) 451 | 452 | if state_dict is None and not from_tf: 453 | state_dict = torch.load(resolved_archive_file, map_location='cpu') 454 | if from_tf: 455 | # Directly load from a TensorFlow checkpoint 456 | return cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' 457 | 458 | # Convert old format to new format if needed from a PyTorch state_dict 459 | old_keys = [] 460 | new_keys = [] 461 | for key in state_dict.keys(): 462 | new_key = None 463 | if 'gamma' in key: 464 | new_key = key.replace('gamma', 'weight') 465 | if 'beta' in key: 466 | new_key = key.replace('beta', 'bias') 467 | if new_key: 468 | old_keys.append(key) 469 | new_keys.append(new_key) 470 | for old_key, new_key in zip(old_keys, new_keys): 471 | state_dict[new_key] = state_dict.pop(old_key) 472 | 473 | # Load from a PyTorch state_dict 474 | missing_keys = [] 475 | unexpected_keys = [] 476 | error_msgs = [] 477 | # copy state_dict so _load_from_state_dict can modify it 478 | metadata = getattr(state_dict, '_metadata', None) 479 | state_dict = state_dict.copy() 480 | if metadata is not None: 481 | state_dict._metadata = metadata 482 | 483 | def load(module, prefix=''): 484 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 485 | module._load_from_state_dict( 486 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 487 | for name, child in module._modules.items(): 488 | if child is not None: 489 | load(child, prefix + name + '.') 490 | 491 | # Make sure we are able to load base models as well as derived models (with heads) 492 | start_prefix = '' 493 | model_to_load = model 494 | if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): 495 | start_prefix = cls.base_model_prefix + '.' 496 | if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): 497 | model_to_load = getattr(model, cls.base_model_prefix) 498 | 499 | load(model_to_load, prefix=start_prefix) 500 | if len(missing_keys) > 0: 501 | logger.info("Weights of {} not initialized from pretrained model: {}".format( 502 | model.__class__.__name__, missing_keys)) 503 | 504 | print(" Weights of {} not initialized from pretrained model: {}".format( 505 | model.__class__.__name__, missing_keys)) 506 | 507 | if len(unexpected_keys) > 0: 508 | logger.info("Weights from pretrained model not used in {}: {}".format( 509 | model.__class__.__name__, unexpected_keys)) 510 | 511 | print(" Weights from pretrained model not used in {}: {}".format( 512 | model.__class__.__name__, unexpected_keys)) 513 | 514 | if len(error_msgs) > 0: 515 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 516 | model.__class__.__name__, "\n\t".join(error_msgs))) 517 | 518 | if hasattr(model, 'tie_weights'): 519 | model.tie_weights() # make sure word embedding weights are still tied 520 | 521 | # Set model in evaluation mode to desactivate DropOut modules by default 522 | model.eval() 523 | 524 | if output_loading_info: 525 | loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs} 526 | return model, loading_info 527 | 528 | return model 529 | 530 | 531 | class Conv1D(nn.Module): 532 | def __init__(self, nf, nx): 533 | """ Conv1D layer as defined by Radford et al. 
for OpenAI GPT (and also used in GPT-2) 534 | Basically works like a Linear layer but the weights are transposed 535 | """ 536 | super(Conv1D, self).__init__() 537 | self.nf = nf 538 | w = torch.empty(nx, nf) 539 | nn.init.normal_(w, std=0.02) 540 | self.weight = nn.Parameter(w) 541 | self.bias = nn.Parameter(torch.zeros(nf)) 542 | 543 | def forward(self, x): 544 | size_out = x.size()[:-1] + (self.nf,) 545 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 546 | x = x.view(*size_out) 547 | return x 548 | 549 | 550 | class PoolerStartLogits(nn.Module): 551 | """ Compute SQuAD start_logits from sequence hidden states. """ 552 | def __init__(self, config): 553 | super(PoolerStartLogits, self).__init__() 554 | self.dense = nn.Linear(config.hidden_size, 1) 555 | 556 | def forward(self, hidden_states, p_mask=None): 557 | """ Args: 558 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)` 559 | invalid position mask such as query and special symbols (PAD, SEP, CLS) 560 | 1.0 means token should be masked. 561 | """ 562 | x = self.dense(hidden_states).squeeze(-1) 563 | 564 | if p_mask is not None: 565 | x = x * (1 - p_mask) - 1e30 * p_mask 566 | 567 | return x 568 | 569 | 570 | class PoolerEndLogits(nn.Module): 571 | """ Compute SQuAD end_logits from sequence hidden states and start token hidden state. 572 | """ 573 | def __init__(self, config): 574 | super(PoolerEndLogits, self).__init__() 575 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) 576 | self.activation = nn.Tanh() 577 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 578 | self.dense_1 = nn.Linear(config.hidden_size, 1) 579 | 580 | def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None): 581 | """ Args: 582 | One of ``start_states``, ``start_positions`` should be not None. 583 | If both are set, ``start_positions`` overrides ``start_states``. 584 | 585 | **start_states**: ``torch.LongTensor`` of shape identical to hidden_states 586 | hidden states of the first tokens for the labeled span. 587 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 588 | position of the first token for the labeled span: 589 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` 590 | Mask of invalid position such as query and special symbols (PAD, SEP, CLS) 591 | 1.0 means token should be masked. 592 | """ 593 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" 594 | if start_positions is not None: 595 | slen, hsz = hidden_states.shape[-2:] 596 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 597 | start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) 598 | start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) 599 | 600 | x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) 601 | x = self.activation(x) 602 | x = self.LayerNorm(x) 603 | x = self.dense_1(x).squeeze(-1) 604 | 605 | if p_mask is not None: 606 | x = x * (1 - p_mask) - 1e30 * p_mask 607 | 608 | return x 609 | 610 | 611 | class PoolerAnswerClass(nn.Module): 612 | """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. 
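Concatenates the hidden state of the labeled span's start token with the hidden state of the
classification token (or of the last token when ``cls_index`` is None) and projects the pair
through two Linear layers (with a tanh in between) to a single answerability logit of shape ``(batch_size,)``.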
""" 613 | def __init__(self, config): 614 | super(PoolerAnswerClass, self).__init__() 615 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) 616 | self.activation = nn.Tanh() 617 | self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) 618 | 619 | def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None): 620 | """ 621 | Args: 622 | One of ``start_states``, ``start_positions`` should be not None. 623 | If both are set, ``start_positions`` overrides ``start_states``. 624 | 625 | **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``. 626 | hidden states of the first tokens for the labeled span. 627 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 628 | position of the first token for the labeled span. 629 | **cls_index**: torch.LongTensor of shape ``(batch_size,)`` 630 | position of the CLS token. If None, take the last token. 631 | 632 | note(Original repo): 633 | no dependency on end_feature so that we can obtain one single `cls_logits` 634 | for each sample 635 | """ 636 | hsz = hidden_states.shape[-1] 637 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" 638 | if start_positions is not None: 639 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 640 | start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) 641 | 642 | if cls_index is not None: 643 | cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 644 | cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) 645 | else: 646 | cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) 647 | 648 | x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) 649 | x = self.activation(x) 650 | x = self.dense_1(x).squeeze(-1) 651 | 652 | return x 653 | 654 | 655 | class SQuADHead(nn.Module): 656 | r""" A SQuAD head inspired by XLNet. 657 | 658 | Parameters: 659 | config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model. 660 | 661 | Inputs: 662 | **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)`` 663 | hidden states of sequence tokens 664 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 665 | position of the first token for the labeled span. 666 | **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` 667 | position of the last token for the labeled span. 668 | **cls_index**: torch.LongTensor of shape ``(batch_size,)`` 669 | position of the CLS token. If None, take the last token. 670 | **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)`` 671 | Whether the question has a possible answer in the paragraph or not. 672 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` 673 | Mask of invalid position such as query and special symbols (PAD, SEP, CLS) 674 | 1.0 means token should be masked. 675 | 676 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 677 | **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: 678 | Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses. 
679 | **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 680 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)`` 681 | Log probabilities for the top config.start_n_top start token possibilities (beam-search). 682 | **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 683 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)`` 684 | Indices for the top config.start_n_top start token possibilities (beam-search). 685 | **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 686 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` 687 | Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). 688 | **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 689 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` 690 | Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). 691 | **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) 692 | ``torch.FloatTensor`` of shape ``(batch_size,)`` 693 | Log probabilities for the ``is_impossible`` label of the answers. 694 | """ 695 | def __init__(self, config): 696 | super(SQuADHead, self).__init__() 697 | self.start_n_top = config.start_n_top 698 | self.end_n_top = config.end_n_top 699 | 700 | self.start_logits = PoolerStartLogits(config) 701 | self.end_logits = PoolerEndLogits(config) 702 | self.answer_class = PoolerAnswerClass(config) 703 | 704 | def forward(self, hidden_states, start_positions=None, end_positions=None, 705 | cls_index=None, is_impossible=None, p_mask=None): 706 | outputs = () 707 | 708 | start_logits = self.start_logits(hidden_states, p_mask=p_mask) 709 | 710 | if start_positions is not None and end_positions is not None: 711 | # If we are on multi-GPU, let's remove the dimension added by batch splitting 712 | for x in (start_positions, end_positions, cls_index, is_impossible): 713 | if x is not None and x.dim() > 1: 714 | x.squeeze_(-1) 715 | 716 | # during training, compute the end logits based on the ground truth of the start position 717 | end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) 718 | 719 | loss_fct = CrossEntropyLoss() 720 | start_loss = loss_fct(start_logits, start_positions) 721 | end_loss = loss_fct(end_logits, end_positions) 722 | total_loss = (start_loss + end_loss) / 2 723 | 724 | if cls_index is not None and is_impossible is not None: 725 | # Predict answerability from the representation of CLS and START 726 | cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) 727 | loss_fct_cls = nn.BCEWithLogitsLoss() 728 | cls_loss = loss_fct_cls(cls_logits, is_impossible) 729 | 730 | # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss 731 | total_loss += cls_loss * 0.5 732 | 733 | outputs = (total_loss,) + outputs 734 | 735 | else: 736 | # during inference, compute the end logits based on beam search 737 | bsz, slen, hsz = hidden_states.size() 738 | start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) 739 | 740 | start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, 
start_n_top) 741 | start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) 742 | start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) 743 | start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) 744 | 745 | hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz) 746 | p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None 747 | end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) 748 | end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) 749 | 750 | end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top) 751 | end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) 752 | end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) 753 | 754 | start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) 755 | cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) 756 | 757 | outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs 758 | 759 | # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits 760 | # or (if labels are provided) (total_loss,) 761 | return outputs 762 | 763 | 764 | class SequenceSummary(nn.Module): 765 | r""" Compute a single vector summary of a sequence hidden states according to various possibilities: 766 | Args of the config class: 767 | summary_type: 768 | - 'last' => [default] take the last token hidden state (like XLNet) 769 | - 'first' => take the first token hidden state (like Bert) 770 | - 'mean' => take the mean of all tokens hidden states 771 | - 'token_ids' => supply a Tensor of classification token indices (GPT/GPT-2) 772 | - 'attn' => Not implemented now, use multi-head attention 773 | summary_use_proj: Add a projection after the vector extraction 774 | summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False. 775 | summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default 776 | summary_first_dropout: Add a dropout before the projection and activation 777 | summary_last_dropout: Add a dropout after the projection and activation 778 | """ 779 | def __init__(self, config): 780 | super(SequenceSummary, self).__init__() 781 | 782 | self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last' 783 | if config.summary_type == 'attn': 784 | # We should use a standard multi-head attention module with absolute positional embedding for that. 785 | # Cf. 
https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 786 | # We can probably just use the multi-head attention module of PyTorch >=1.1.0 787 | raise NotImplementedError 788 | 789 | self.summary = Identity() 790 | if hasattr(config, 'summary_use_proj') and config.summary_use_proj: 791 | if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0: 792 | num_classes = config.num_labels 793 | else: 794 | num_classes = config.hidden_size 795 | self.summary = nn.Linear(config.hidden_size, num_classes) 796 | 797 | self.activation = Identity() 798 | if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh': 799 | self.activation = nn.Tanh() 800 | 801 | self.first_dropout = Identity() 802 | if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0: 803 | self.first_dropout = nn.Dropout(config.summary_first_dropout) 804 | 805 | self.last_dropout = Identity() 806 | if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0: 807 | self.last_dropout = nn.Dropout(config.summary_last_dropout) 808 | 809 | def forward(self, hidden_states, token_ids=None): 810 | """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer. 811 | token_ids: [optional] index of the classification token if summary_type == 'token_ids', 812 | shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states. 813 | if summary_type == 'token_ids' and token_ids is None: 814 | we take the last token of the sequence as classification token 815 | """ 816 | if self.summary_type == 'last': 817 | output = hidden_states[:, -1] 818 | elif self.summary_type == 'first': 819 | output = hidden_states[:, 0] 820 | elif self.summary_type == 'mean': 821 | output = hidden_states.mean(dim=1) 822 | elif self.summary_type == 'token_ids': 823 | if token_ids is None: 824 | token_ids = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long) 825 | else: 826 | token_ids = token_ids.unsqueeze(-1).unsqueeze(-1) 827 | token_ids = token_ids.expand((-1,) * (token_ids.dim()-1) + (hidden_states.size(-1),)) 828 | # shape of token_ids: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states 829 | output = hidden_states.gather(-2, token_ids).squeeze(-2) # shape (bsz, XX, hidden_size) 830 | elif self.summary_type == 'attn': 831 | raise NotImplementedError 832 | 833 | output = self.first_dropout(output) 834 | output = self.summary(output) 835 | output = self.activation(output) 836 | output = self.last_dropout(output) 837 | 838 | return output 839 | 840 | 841 | def prune_linear_layer(layer, index, dim=0): 842 | """ Prune a linear layer (a model parameters) to keep only entries in index. 843 | Return the pruned layer as a new layer with requires_grad=True. 844 | Used to remove heads. 
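Example (illustrative): pruning a 768 -> 768 projection down to its first 512 output features::

    >>> layer = nn.Linear(768, 768)
    >>> pruned = prune_linear_layer(layer, torch.arange(512), dim=0)
    >>> pruned.weight.shape
    torch.Size([512, 768])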
845 | """ 846 | index = index.to(layer.weight.device) 847 | W = layer.weight.index_select(dim, index).clone().detach() 848 | if layer.bias is not None: 849 | if dim == 1: 850 | b = layer.bias.clone().detach() 851 | else: 852 | b = layer.bias[index].clone().detach() 853 | new_size = list(layer.weight.size()) 854 | new_size[dim] = len(index) 855 | new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) 856 | new_layer.weight.requires_grad = False 857 | new_layer.weight.copy_(W.contiguous()) 858 | new_layer.weight.requires_grad = True 859 | if layer.bias is not None: 860 | new_layer.bias.requires_grad = False 861 | new_layer.bias.copy_(b.contiguous()) 862 | new_layer.bias.requires_grad = True 863 | return new_layer 864 | 865 | 866 | def prune_conv1d_layer(layer, index, dim=1): 867 | """ Prune a Conv1D layer (a model parameters) to keep only entries in index. 868 | A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed. 869 | Return the pruned layer as a new layer with requires_grad=True. 870 | Used to remove heads. 871 | """ 872 | index = index.to(layer.weight.device) 873 | W = layer.weight.index_select(dim, index).clone().detach() 874 | if dim == 0: 875 | b = layer.bias.clone().detach() 876 | else: 877 | b = layer.bias[index].clone().detach() 878 | new_size = list(layer.weight.size()) 879 | new_size[dim] = len(index) 880 | new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device) 881 | new_layer.weight.requires_grad = False 882 | new_layer.weight.copy_(W.contiguous()) 883 | new_layer.weight.requires_grad = True 884 | new_layer.bias.requires_grad = False 885 | new_layer.bias.copy_(b.contiguous()) 886 | new_layer.bias.requires_grad = True 887 | return new_layer 888 | 889 | 890 | def prune_layer(layer, index, dim=None): 891 | """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index. 892 | Return the pruned layer as a new layer with requires_grad=True. 893 | Used to remove heads. 894 | """ 895 | if isinstance(layer, nn.Linear): 896 | return prune_linear_layer(layer, index, dim=0 if dim is None else dim) 897 | elif isinstance(layer, Conv1D): 898 | return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) 899 | else: 900 | raise ValueError("Can't prune layer of class {}".format(layer.__class__)) 901 | -------------------------------------------------------------------------------- /transformer/pytorch_transformer/modeling_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model. 
""" 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import math 23 | import os 24 | import sys 25 | from io import open 26 | 27 | import torch 28 | from torch import nn 29 | from torch.nn import CrossEntropyLoss, MSELoss 30 | 31 | from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, 32 | prune_linear_layer, add_start_docstrings) 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | BERT_PRETRAINED_MODEL_ARCHIVE_MAP = { 37 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin", 38 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin", 39 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin", 40 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin", 41 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin", 42 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin", 43 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin", 44 | 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin", 45 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin", 46 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin", 47 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin", 48 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin", 49 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin", 50 | } 51 | 52 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 53 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 54 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 55 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 56 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 57 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 58 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 59 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 60 | 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 61 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 62 | 
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 63 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 64 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 65 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 66 | } 67 | 68 | 69 | def load_tf_weights_in_bert(model, config, tf_checkpoint_path): 70 | """ Load tf checkpoints in a pytorch model. 71 | """ 72 | try: 73 | import re 74 | import numpy as np 75 | import tensorflow as tf 76 | except ImportError: 77 | logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " 78 | "https://www.tensorflow.org/install/ for installation instructions.") 79 | raise 80 | tf_path = os.path.abspath(tf_checkpoint_path) 81 | logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) 82 | # Load weights from TF model 83 | init_vars = tf.train.list_variables(tf_path) 84 | names = [] 85 | arrays = [] 86 | for name, shape in init_vars: 87 | logger.info("Loading TF weight {} with shape {}".format(name, shape)) 88 | array = tf.train.load_variable(tf_path, name) 89 | names.append(name) 90 | arrays.append(array) 91 | 92 | for name, array in zip(names, arrays): 93 | name = name.split('/') 94 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 95 | # which are not required for using pretrained model 96 | if any(n in ["adam_v", "adam_m", "global_step"] for n in name): 97 | logger.info("Skipping {}".format("/".join(name))) 98 | continue 99 | pointer = model 100 | for m_name in name: 101 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 102 | l = re.split(r'_(\d+)', m_name) 103 | else: 104 | l = [m_name] 105 | if l[0] == 'kernel' or l[0] == 'gamma': 106 | pointer = getattr(pointer, 'weight') 107 | elif l[0] == 'output_bias' or l[0] == 'beta': 108 | pointer = getattr(pointer, 'bias') 109 | elif l[0] == 'output_weights': 110 | pointer = getattr(pointer, 'weight') 111 | elif l[0] == 'squad': 112 | pointer = getattr(pointer, 'classifier') 113 | else: 114 | try: 115 | pointer = getattr(pointer, l[0]) 116 | except AttributeError: 117 | logger.info("Skipping {}".format("/".join(name))) 118 | continue 119 | if len(l) >= 2: 120 | num = int(l[1]) 121 | pointer = pointer[num] 122 | if m_name[-11:] == '_embeddings': 123 | pointer = getattr(pointer, 'weight') 124 | elif m_name == 'kernel': 125 | array = np.transpose(array) 126 | try: 127 | assert pointer.shape == array.shape 128 | except AssertionError as e: 129 | e.args += (pointer.shape, array.shape) 130 | raise 131 | logger.info("Initialize PyTorch weight {}".format(name)) 132 | pointer.data = torch.from_numpy(array) 133 | return model 134 | 135 | 136 | def gelu(x): 137 | """Implementation of the gelu activation function. 
138 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 139 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 140 | Also see https://arxiv.org/abs/1606.08415 141 | """ 142 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 143 | 144 | 145 | def swish(x): 146 | return x * torch.sigmoid(x) 147 | 148 | 149 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 150 | 151 | 152 | class BertConfig(PretrainedConfig): 153 | r""" 154 | :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a 155 | `BertModel`. 156 | 157 | 158 | Arguments: 159 | vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`. 160 | hidden_size: Size of the encoder layers and the pooler layer. 161 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 162 | num_attention_heads: Number of attention heads for each attention layer in 163 | the Transformer encoder. 164 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 165 | layer in the Transformer encoder. 166 | hidden_act: The non-linear activation function (function or string) in the 167 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 168 | hidden_dropout_prob: The dropout probability for all fully connected 169 | layers in the embeddings, encoder, and pooler. 170 | attention_probs_dropout_prob: The dropout ratio for the attention 171 | probabilities. 172 | max_position_embeddings: The maximum sequence length that this model might 173 | ever be used with. Typically set this to something large just in case 174 | (e.g., 512 or 1024 or 2048). 175 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 176 | `BertModel`. 177 | initializer_range: The stddev of the truncated_normal_initializer for 178 | initializing all weight matrices. 179 | layer_norm_eps: The epsilon used by LayerNorm.
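Example (illustrative)::

    >>> config = BertConfig(vocab_size_or_config_json_file=30522)  # keyword defaults give a BERT-base sized model
    >>> config.hidden_size, config.num_hidden_layers, config.num_attention_heads
    (768, 12, 12)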
180 | """ 181 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP 182 | 183 | def __init__(self, 184 | vocab_size_or_config_json_file=30522, 185 | hidden_size=768, 186 | num_hidden_layers=12, 187 | num_attention_heads=12, 188 | intermediate_size=3072, 189 | hidden_act="gelu", 190 | hidden_dropout_prob=0.1, 191 | attention_probs_dropout_prob=0.1, 192 | max_position_embeddings=512, 193 | type_vocab_size=2, 194 | initializer_range=0.02, 195 | layer_norm_eps=1e-12, 196 | **kwargs): 197 | super(BertConfig, self).__init__(**kwargs) 198 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 199 | and isinstance(vocab_size_or_config_json_file, unicode)): 200 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 201 | json_config = json.loads(reader.read()) 202 | for key, value in json_config.items(): 203 | self.__dict__[key] = value 204 | elif isinstance(vocab_size_or_config_json_file, int): 205 | self.vocab_size = vocab_size_or_config_json_file 206 | self.hidden_size = hidden_size 207 | self.num_hidden_layers = num_hidden_layers 208 | self.num_attention_heads = num_attention_heads 209 | self.hidden_act = hidden_act 210 | self.intermediate_size = intermediate_size 211 | self.hidden_dropout_prob = hidden_dropout_prob 212 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 213 | self.max_position_embeddings = max_position_embeddings 214 | self.type_vocab_size = type_vocab_size 215 | self.initializer_range = initializer_range 216 | self.layer_norm_eps = layer_norm_eps 217 | else: 218 | raise ValueError("First argument must be either a vocabulary size (int)" 219 | "or the path to a pretrained model config file (str)") 220 | 221 | 222 | 223 | try: 224 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 225 | except ImportError: 226 | logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .") 227 | class BertLayerNorm(nn.Module): 228 | def __init__(self, hidden_size, eps=1e-12): 229 | """Construct a layernorm module in the TF style (epsilon inside the square root). 230 | """ 231 | super(BertLayerNorm, self).__init__() 232 | self.weight = nn.Parameter(torch.ones(hidden_size)) 233 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 234 | self.variance_epsilon = eps 235 | 236 | def forward(self, x): 237 | u = x.mean(-1, keepdim=True) 238 | s = (x - u).pow(2).mean(-1, keepdim=True) 239 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 240 | return self.weight * x + self.bias 241 | 242 | class BertEmbeddings(nn.Module): 243 | """Construct the embeddings from word, position and token_type embeddings. 
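The word, position and token_type embeddings are summed element-wise, then passed through
LayerNorm and dropout, producing a tensor of shape ``(batch_size, sequence_length, hidden_size)``.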
244 | """ 245 | def __init__(self, config): 246 | super(BertEmbeddings, self).__init__() 247 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) 248 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 249 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 250 | 251 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 252 | # any TensorFlow checkpoint file 253 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 254 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 255 | 256 | def forward(self, input_ids, token_type_ids=None, position_ids=None): 257 | seq_length = input_ids.size(1) 258 | if position_ids is None: 259 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 260 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 261 | if token_type_ids is None: 262 | token_type_ids = torch.zeros_like(input_ids) 263 | 264 | words_embeddings = self.word_embeddings(input_ids) 265 | position_embeddings = self.position_embeddings(position_ids) 266 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 267 | 268 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 269 | embeddings = self.LayerNorm(embeddings) 270 | embeddings = self.dropout(embeddings) 271 | return embeddings 272 | 273 | 274 | class BertSelfAttention(nn.Module): 275 | def __init__(self, config): 276 | super(BertSelfAttention, self).__init__() 277 | if config.hidden_size % config.num_attention_heads != 0: 278 | raise ValueError( 279 | "The hidden size (%d) is not a multiple of the number of attention " 280 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 281 | self.output_attentions = config.output_attentions 282 | 283 | self.num_attention_heads = config.num_attention_heads 284 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 285 | self.all_head_size = self.num_attention_heads * self.attention_head_size 286 | 287 | self.query = nn.Linear(config.hidden_size, self.all_head_size) # [768 * 768] 288 | self.key = nn.Linear(config.hidden_size, self.all_head_size) # [768 * 768] 289 | self.value = nn.Linear(config.hidden_size, self.all_head_size) # [768 * 768] 290 | 291 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 292 | 293 | def transpose_for_scores(self, x): 294 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 295 | x = x.view(*new_x_shape) 296 | return x.permute(0, 2, 1, 3) 297 | 298 | def forward(self, hidden_states, attention_mask, head_mask=None): 299 | mixed_query_layer = self.query(hidden_states) 300 | mixed_key_layer = self.key(hidden_states) 301 | mixed_value_layer = self.value(hidden_states) 302 | 303 | query_layer = self.transpose_for_scores(mixed_query_layer) 304 | key_layer = self.transpose_for_scores(mixed_key_layer) 305 | value_layer = self.transpose_for_scores(mixed_value_layer) 306 | 307 | # Take the dot product between "query" and "key" to get the raw attention scores. 308 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 309 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 310 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 311 | attention_scores = attention_scores + attention_mask 312 | 313 | # Normalize the attention scores to probabilities. 
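# attention_scores has shape (batch_size, num_attention_heads, seq_length, seq_length);
# the softmax below is taken over the last (key) dimension, giving one attention
# distribution per query position in every head.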
314 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 315 | 316 | # This is actually dropping out entire tokens to attend to, which might 317 | # seem a bit unusual, but is taken from the original Transformer paper. 318 | attention_probs = self.dropout(attention_probs) 319 | 320 | # Mask heads if we want to 321 | if head_mask is not None: 322 | attention_probs = attention_probs * head_mask 323 | 324 | context_layer = torch.matmul(attention_probs, value_layer) 325 | 326 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 327 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 328 | context_layer = context_layer.view(*new_context_layer_shape) 329 | 330 | outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) 331 | return outputs 332 | 333 | 334 | class BertSelfOutput(nn.Module): 335 | def __init__(self, config): 336 | super(BertSelfOutput, self).__init__() 337 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 338 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 339 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 340 | 341 | def forward(self, hidden_states, input_tensor): 342 | hidden_states = self.dense(hidden_states) 343 | hidden_states = self.dropout(hidden_states) 344 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 345 | return hidden_states 346 | 347 | 348 | class BertAttention(nn.Module): 349 | def __init__(self, config): 350 | super(BertAttention, self).__init__() 351 | self.self = BertSelfAttention(config) 352 | self.output = BertSelfOutput(config) 353 | 354 | def prune_heads(self, heads): 355 | if len(heads) == 0: 356 | return 357 | mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) 358 | for head in heads: 359 | mask[head] = 0 360 | mask = mask.view(-1).contiguous().eq(1) 361 | index = torch.arange(len(mask))[mask].long() 362 | # Prune linear layers 363 | self.self.query = prune_linear_layer(self.self.query, index) 364 | self.self.key = prune_linear_layer(self.self.key, index) 365 | self.self.value = prune_linear_layer(self.self.value, index) 366 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 367 | # Update hyper params 368 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 369 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 370 | 371 | def forward(self, input_tensor, attention_mask, head_mask=None): 372 | self_outputs = self.self(input_tensor, attention_mask, head_mask) 373 | attention_output = self.output(self_outputs[0], input_tensor) 374 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 375 | return outputs 376 | 377 | 378 | class BertIntermediate(nn.Module): 379 | def __init__(self, config): 380 | super(BertIntermediate, self).__init__() 381 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 382 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 383 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 384 | else: 385 | self.intermediate_act_fn = config.hidden_act 386 | 387 | def forward(self, hidden_states): 388 | hidden_states = self.dense(hidden_states) 389 | hidden_states = self.intermediate_act_fn(hidden_states) 390 | return hidden_states 391 | 392 | 393 | class BertOutput(nn.Module): 394 | def __init__(self, config): 395 | super(BertOutput, self).__init__() 396 | self.dense = 
nn.Linear(config.intermediate_size, config.hidden_size) 397 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 398 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 399 | 400 | def forward(self, hidden_states, input_tensor): 401 | hidden_states = self.dense(hidden_states) 402 | hidden_states = self.dropout(hidden_states) 403 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 404 | return hidden_states 405 | 406 | 407 | class BertLayer(nn.Module): 408 | def __init__(self, config): 409 | super(BertLayer, self).__init__() 410 | self.attention = BertAttention(config) 411 | self.intermediate = BertIntermediate(config) 412 | self.output = BertOutput(config) 413 | 414 | def forward(self, hidden_states, attention_mask, head_mask=None): 415 | attention_outputs = self.attention(hidden_states, attention_mask, head_mask) 416 | attention_output = attention_outputs[0] 417 | intermediate_output = self.intermediate(attention_output) 418 | layer_output = self.output(intermediate_output, attention_output) 419 | outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them 420 | return outputs 421 | 422 | 423 | class BertEncoder(nn.Module): 424 | def __init__(self, config): 425 | super(BertEncoder, self).__init__() 426 | self.output_attentions = config.output_attentions 427 | self.output_hidden_states = config.output_hidden_states 428 | self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) 429 | 430 | def forward(self, hidden_states, attention_mask, head_mask=None): 431 | all_hidden_states = () 432 | all_attentions = () 433 | for i, layer_module in enumerate(self.layer): 434 | if self.output_hidden_states: 435 | all_hidden_states = all_hidden_states + (hidden_states,) 436 | 437 | layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i]) 438 | hidden_states = layer_outputs[0] 439 | 440 | if self.output_attentions: 441 | all_attentions = all_attentions + (layer_outputs[1],) 442 | 443 | # Add last layer 444 | if self.output_hidden_states: 445 | all_hidden_states = all_hidden_states + (hidden_states,) 446 | 447 | outputs = (hidden_states,) 448 | if self.output_hidden_states: 449 | outputs = outputs + (all_hidden_states,) 450 | if self.output_attentions: 451 | outputs = outputs + (all_attentions,) 452 | return outputs # outputs, (hidden states), (attentions) 453 | 454 | 455 | class BertPooler(nn.Module): 456 | def __init__(self, config): 457 | super(BertPooler, self).__init__() 458 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 459 | self.activation = nn.Tanh() 460 | 461 | def forward(self, hidden_states): 462 | # We "pool" the model by simply taking the hidden state corresponding 463 | # to the first token. 
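# For BERT-style inputs the first token is [CLS]; its hidden state is projected by a Linear
# layer and passed through tanh to give a (batch_size, hidden_size) summary vector.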
464 | first_token_tensor = hidden_states[:, 0] 465 | pooled_output = self.dense(first_token_tensor) 466 | pooled_output = self.activation(pooled_output) 467 | return pooled_output 468 | 469 | 470 | class BertPredictionHeadTransform(nn.Module): 471 | def __init__(self, config): 472 | super(BertPredictionHeadTransform, self).__init__() 473 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 474 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 475 | self.transform_act_fn = ACT2FN[config.hidden_act] 476 | else: 477 | self.transform_act_fn = config.hidden_act 478 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 479 | 480 | def forward(self, hidden_states): 481 | hidden_states = self.dense(hidden_states) 482 | hidden_states = self.transform_act_fn(hidden_states) 483 | hidden_states = self.LayerNorm(hidden_states) 484 | return hidden_states 485 | 486 | 487 | class BertLMPredictionHead(nn.Module): 488 | def __init__(self, config): 489 | super(BertLMPredictionHead, self).__init__() 490 | self.transform = BertPredictionHeadTransform(config) 491 | 492 | # The output weights are the same as the input embeddings, but there is 493 | # an output-only bias for each token. 494 | self.decoder = nn.Linear(config.hidden_size, 495 | config.vocab_size, 496 | bias=False) 497 | 498 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 499 | 500 | def forward(self, hidden_states): 501 | hidden_states = self.transform(hidden_states) 502 | hidden_states = self.decoder(hidden_states) + self.bias 503 | return hidden_states 504 | 505 | 506 | class BertOnlyMLMHead(nn.Module): 507 | def __init__(self, config): 508 | super(BertOnlyMLMHead, self).__init__() 509 | self.predictions = BertLMPredictionHead(config) 510 | 511 | def forward(self, sequence_output): 512 | prediction_scores = self.predictions(sequence_output) 513 | return prediction_scores 514 | 515 | 516 | class BertOnlyNSPHead(nn.Module): 517 | def __init__(self, config): 518 | super(BertOnlyNSPHead, self).__init__() 519 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 520 | 521 | def forward(self, pooled_output): 522 | seq_relationship_score = self.seq_relationship(pooled_output) 523 | return seq_relationship_score 524 | 525 | 526 | class BertPreTrainingHeads(nn.Module): 527 | def __init__(self, config): 528 | super(BertPreTrainingHeads, self).__init__() 529 | self.predictions = BertLMPredictionHead(config) 530 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 531 | 532 | def forward(self, sequence_output, pooled_output): 533 | prediction_scores = self.predictions(sequence_output) 534 | seq_relationship_score = self.seq_relationship(pooled_output) 535 | return prediction_scores, seq_relationship_score 536 | 537 | 538 | class BertPreTrainedModel(PreTrainedModel): 539 | """ An abstract class to handle weights initialization and 540 | a simple interface for dowloading and loading pretrained models. 541 | """ 542 | config_class = BertConfig 543 | pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP 544 | load_tf_weights = load_tf_weights_in_bert 545 | base_model_prefix = "bert" 546 | 547 | def __init__(self, *inputs, **kwargs): 548 | super(BertPreTrainedModel, self).__init__(*inputs, **kwargs) 549 | 550 | def init_weights(self, module): 551 | """ Initialize the weights. 
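Linear and Embedding weights are drawn from a normal distribution with std ``config.initializer_range``;
LayerNorm weights are set to 1 and LayerNorm biases to 0; Linear biases are zeroed.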
552 | """ 553 | if isinstance(module, (nn.Linear, nn.Embedding)): 554 | # Slightly different from the TF version which uses truncated_normal for initialization 555 | # cf https://github.com/pytorch/pytorch/pull/5617 556 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 557 | elif isinstance(module, BertLayerNorm): 558 | module.bias.data.zero_() 559 | module.weight.data.fill_(1.0) 560 | if isinstance(module, nn.Linear) and module.bias is not None: 561 | module.bias.data.zero_() 562 | 563 | 564 | BERT_START_DOCSTRING = r""" The BERT model was proposed in 565 | `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ 566 | by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer 567 | pre-trained using a combination of masked language modeling objective and next sentence prediction 568 | on a large corpus comprising the Toronto Book Corpus and Wikipedia. 569 | 570 | This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and 571 | refer to the PyTorch documentation for all matter related to general usage and behavior. 572 | 573 | .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`: 574 | https://arxiv.org/abs/1810.04805 575 | 576 | .. _`torch.nn.Module`: 577 | https://pytorch.org/docs/stable/nn.html#module 578 | 579 | Parameters: 580 | config (:class:`~pytorch_transformers.BertConfig`): Model configuration class with all the parameters of the model. 581 | """ 582 | 583 | BERT_INPUTS_DOCSTRING = r""" 584 | Inputs: 585 | **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: 586 | Indices of input sequence tokens in the vocabulary. 587 | To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows: 588 | 589 | (a) For sequence pairs: 590 | 591 | ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]`` 592 | 593 | ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1`` 594 | 595 | (b) For single sequences: 596 | 597 | ``tokens: [CLS] the dog is hairy . [SEP]`` 598 | 599 | ``token_type_ids: 0 0 0 0 0 0 0`` 600 | 601 | Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`. 602 | See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and 603 | :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. 604 | **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: 605 | Indices of positions of each input sequence tokens in the position embeddings. 606 | Selected in the range ``[0, config.max_position_embeddings - 1[``. 607 | **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: 608 | Segment token indices to indicate first and second portions of the inputs. 609 | Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` 610 | corresponds to a `sentence B` token 611 | (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details). 612 | **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: 613 | Mask to avoid performing attention on padding token indices. 614 | Mask values selected in ``[0, 1]``: 615 | ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. 
616 | **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: 617 | Mask to nullify selected heads of the self-attention modules. 618 | Mask values selected in ``[0, 1]``: 619 | ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. 620 | """ 621 | 622 | @add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.", 623 | BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) 624 | class BertModel(BertPreTrainedModel): 625 | r""" 626 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 627 | **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` 628 | Sequence of hidden-states at the output of the last layer of the model. 629 | **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` 630 | Last layer hidden-state of the first token of the sequence (classification token) 631 | further processed by a Linear layer and a Tanh activation function. The Linear 632 | layer weights are trained from the next sentence prediction (classification) 633 | objective during Bert pretraining. This output is usually *not* a good summary 634 | of the semantic content of the input, you're often better with averaging or pooling 635 | the sequence of hidden-states for the whole input sequence. 636 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 637 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 638 | of shape ``(batch_size, sequence_length, hidden_size)``: 639 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 640 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 641 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 642 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 643 | 644 | Examples:: 645 | 646 | >>> config = BertConfig.from_pretrained('bert-base-uncased') 647 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 648 | >>> model = BertModel(config) 649 | >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 650 | >>> outputs = model(input_ids) 651 | >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple 652 | 653 | """ 654 | def __init__(self, config): 655 | super(BertModel, self).__init__(config) 656 | 657 | self.embeddings = BertEmbeddings(config) 658 | self.encoder = BertEncoder(config) 659 | self.pooler = BertPooler(config) 660 | 661 | self.apply(self.init_weights) 662 | 663 | def _resize_token_embeddings(self, new_num_tokens): 664 | old_embeddings = self.embeddings.word_embeddings 665 | new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) 666 | self.embeddings.word_embeddings = new_embeddings 667 | return self.embeddings.word_embeddings 668 | 669 | def _prune_heads(self, heads_to_prune): 670 | """ Prunes heads of the model. 
671 |             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
672 |             See base class PreTrainedModel
673 |         """
674 |         for layer, heads in heads_to_prune.items():
675 |             self.encoder.layer[layer].attention.prune_heads(heads)
676 | 
677 |     def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
678 |         if attention_mask is None:
679 |             attention_mask = torch.ones_like(input_ids)
680 |         if token_type_ids is None:
681 |             token_type_ids = torch.zeros_like(input_ids)
682 | 
683 |         # We create a 3D attention mask from a 2D tensor mask.
684 |         # Sizes are [batch_size, 1, 1, to_seq_length]
685 |         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
686 |         # This attention mask is simpler than the triangular masking of causal attention
687 |         # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
688 |         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
689 | 
690 |         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
691 |         # masked positions, this operation will create a tensor which is 0.0 for
692 |         # positions we want to attend and -10000.0 for masked positions.
693 |         # Since we are adding it to the raw scores before the softmax, this is
694 |         # effectively the same as removing these entirely.
695 |         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
696 |         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
697 | 
698 |         # Prepare head mask if needed
699 |         # 1.0 in head_mask indicates we keep the head
700 |         # attention_probs has shape bsz x n_heads x N x N
701 |         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
702 |         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
703 |         if head_mask is not None:
704 |             if head_mask.dim() == 1:
705 |                 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
706 |                 head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
707 |             elif head_mask.dim() == 2:
708 |                 head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
709 |             head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
710 |         else:
711 |             head_mask = [None] * self.config.num_hidden_layers
712 | 
713 |         embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
714 |         encoder_outputs = self.encoder(embedding_output,
715 |                                        extended_attention_mask,
716 |                                        head_mask=head_mask)
717 |         sequence_output = encoder_outputs[0]
718 |         pooled_output = self.pooler(sequence_output)
719 | 
720 |         outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
721 |         return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
722 | 
723 | 
724 | @add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
725 |     a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
726 |     BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
727 | class BertForPreTraining(BertPreTrainedModel):
728 |     r"""
729 |         **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
730 |             Labels for computing the masked language modeling loss.
731 |             Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
732 |             Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
733 |             in ``[0, ..., config.vocab_size]``
734 |         **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
735 |             Labels for computing the next sentence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
736 |             Indices should be in ``[0, 1]``.
737 |             ``0`` indicates sequence B is a continuation of sequence A,
738 |             ``1`` indicates sequence B is a random sequence.
739 | 
740 |     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
741 |         **loss**: (`optional`, returned when both ``masked_lm_labels`` and ``next_sentence_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
742 |             Total loss as the sum of the masked language modeling loss and the next sentence prediction (classification) loss.
743 |         **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
744 |             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
745 |         **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
746 |             Prediction scores of the next sentence prediction (classification) head (scores of True/False continuation before SoftMax).
747 |         **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
748 |             list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
749 |             of shape ``(batch_size, sequence_length, hidden_size)``:
750 |             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
751 |         **attentions**: (`optional`, returned when ``config.output_attentions=True``)
752 |             list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
753 |             Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
754 | 
755 |     Examples::
756 | 
757 |         >>> config = BertConfig.from_pretrained('bert-base-uncased')
758 |         >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
759 |         >>>
760 |         >>> model = BertForPreTraining(config)
761 |         >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
762 |         >>> outputs = model(input_ids)
763 |         >>> prediction_scores, seq_relationship_scores = outputs[:2]
764 | 
765 |     """
766 |     def __init__(self, config):
767 |         super(BertForPreTraining, self).__init__(config)
768 | 
769 |         self.bert = BertModel(config)
770 |         self.cls = BertPreTrainingHeads(config)
771 | 
772 |         self.apply(self.init_weights)
773 |         self.tie_weights()
774 | 
775 |     def tie_weights(self):
776 |         """ Make sure we are sharing the input and output embeddings.
777 |             Export to TorchScript can't handle parameter sharing so we are cloning them instead.
778 | """ 779 | self._tie_or_clone_weights(self.cls.predictions.decoder, 780 | self.bert.embeddings.word_embeddings) 781 | 782 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, 783 | next_sentence_label=None, position_ids=None, head_mask=None): 784 | outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, 785 | attention_mask=attention_mask, head_mask=head_mask) 786 | 787 | sequence_output, pooled_output = outputs[:2] 788 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 789 | 790 | outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here 791 | 792 | if masked_lm_labels is not None and next_sentence_label is not None: 793 | loss_fct = CrossEntropyLoss(ignore_index=-1) 794 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 795 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 796 | total_loss = masked_lm_loss + next_sentence_loss 797 | outputs = (total_loss,) + outputs 798 | 799 | return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions) 800 | 801 | 802 | @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, 803 | BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) 804 | class BertForMaskedLM(BertPreTrainedModel): 805 | r""" 806 | **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: 807 | Labels for computing the masked language modeling loss. 808 | Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) 809 | Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels 810 | in ``[0, ..., config.vocab_size]`` 811 | 812 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 813 | **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: 814 | Masked language modeling loss. 815 | **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` 816 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 817 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 818 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 819 | of shape ``(batch_size, sequence_length, hidden_size)``: 820 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 821 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 822 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 823 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
824 | 
825 |     Examples::
826 | 
827 |         >>> config = BertConfig.from_pretrained('bert-base-uncased')
828 |         >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
829 |         >>>
830 |         >>> model = BertForMaskedLM(config)
831 |         >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
832 |         >>> outputs = model(input_ids, masked_lm_labels=input_ids)
833 |         >>> loss, prediction_scores = outputs[:2]
834 | 
835 |     """
836 |     def __init__(self, config):
837 |         super(BertForMaskedLM, self).__init__(config)
838 | 
839 |         self.bert = BertModel(config)
840 |         self.cls = BertOnlyMLMHead(config)
841 | 
842 |         self.apply(self.init_weights)
843 |         self.tie_weights()
844 | 
845 |     def tie_weights(self):
846 |         """ Make sure we are sharing the input and output embeddings.
847 |             Export to TorchScript can't handle parameter sharing so we are cloning them instead.
848 |         """
849 |         self._tie_or_clone_weights(self.cls.predictions.decoder,
850 |                                    self.bert.embeddings.word_embeddings)
851 | 
852 |     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
853 |                 position_ids=None, head_mask=None):
854 |         outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
855 |                             attention_mask=attention_mask, head_mask=head_mask)
856 | 
857 |         sequence_output = outputs[0]
858 |         prediction_scores = self.cls(sequence_output)
859 | 
860 |         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
861 |         if masked_lm_labels is not None:
862 |             loss_fct = CrossEntropyLoss(ignore_index=-1)
863 |             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
864 |             outputs = (masked_lm_loss,) + outputs
865 | 
866 |         return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
867 | 
868 | 
869 | @add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
870 |     BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
871 | class BertForNextSentencePrediction(BertPreTrainedModel):
872 |     r"""
873 |         **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
874 |             Labels for computing the next sentence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
875 |             Indices should be in ``[0, 1]``.
876 |             ``0`` indicates sequence B is a continuation of sequence A,
877 |             ``1`` indicates sequence B is a random sequence.
878 | 
879 |     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
880 |         **loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
881 |             Next sentence prediction (classification) loss.
882 |         **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
883 |             Prediction scores of the next sentence prediction (classification) head (scores of True/False continuation before SoftMax).
884 |         **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
885 |             list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
886 |             of shape ``(batch_size, sequence_length, hidden_size)``:
887 |             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
888 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 889 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 890 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 891 | 892 | Examples:: 893 | 894 | >>> config = BertConfig.from_pretrained('bert-base-uncased') 895 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 896 | >>> 897 | >>> model = BertForNextSentencePrediction(config) 898 | >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 899 | >>> outputs = model(input_ids) 900 | >>> seq_relationship_scores = outputs[0] 901 | 902 | """ 903 | def __init__(self, config): 904 | super(BertForNextSentencePrediction, self).__init__(config) 905 | 906 | self.bert = BertModel(config) 907 | self.cls = BertOnlyNSPHead(config) 908 | 909 | self.apply(self.init_weights) 910 | 911 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, 912 | position_ids=None, head_mask=None): 913 | outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, 914 | attention_mask=attention_mask, head_mask=head_mask) 915 | pooled_output = outputs[1] 916 | 917 | seq_relationship_score = self.cls(pooled_output) 918 | 919 | outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here 920 | if next_sentence_label is not None: 921 | loss_fct = CrossEntropyLoss(ignore_index=-1) 922 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 923 | outputs = (next_sentence_loss,) + outputs 924 | 925 | return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions) 926 | 927 | 928 | @add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of 929 | the pooled output) e.g. for GLUE tasks. """, 930 | BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) 931 | class BertForSequenceClassification(BertPreTrainedModel): 932 | r""" 933 | **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: 934 | Labels for computing the sequence classification/regression loss. 935 | Indices should be in ``[0, ..., config.num_labels]``. 936 | If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), 937 | If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). 938 | 939 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 940 | **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: 941 | Classification (or regression if config.num_labels==1) loss. 942 | **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` 943 | Classification (or regression if config.num_labels==1) scores (before SoftMax). 944 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 945 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 946 | of shape ``(batch_size, sequence_length, hidden_size)``: 947 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
948 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 949 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 950 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 951 | 952 | Examples:: 953 | 954 | >>> config = BertConfig.from_pretrained('bert-base-uncased') 955 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 956 | >>> 957 | >>> model = BertForSequenceClassification(config) 958 | >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 959 | >>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 960 | >>> outputs = model(input_ids, labels=labels) 961 | >>> loss, logits = outputs[:2] 962 | 963 | """ 964 | def __init__(self, config): 965 | super(BertForSequenceClassification, self).__init__(config) 966 | self.num_labels = config.num_labels 967 | 968 | self.bert = BertModel(config) 969 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 970 | self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) 971 | 972 | self.apply(self.init_weights) 973 | 974 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, 975 | position_ids=None, head_mask=None): 976 | outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, 977 | attention_mask=attention_mask, head_mask=head_mask) 978 | pooled_output = outputs[1] 979 | 980 | pooled_output = self.dropout(pooled_output) 981 | logits = self.classifier(pooled_output) 982 | 983 | outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here 984 | 985 | if labels is not None: 986 | if self.num_labels == 1: 987 | # We are doing regression 988 | loss_fct = MSELoss() 989 | loss = loss_fct(logits.view(-1), labels.view(-1)) 990 | else: 991 | loss_fct = CrossEntropyLoss() 992 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 993 | outputs = (loss,) + outputs 994 | 995 | return outputs # (loss), logits, (hidden_states), (attentions) 996 | 997 | 998 | @add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of 999 | the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, 1000 | BERT_START_DOCSTRING) 1001 | class BertForMultipleChoice(BertPreTrainedModel): 1002 | r""" 1003 | Inputs: 1004 | **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``: 1005 | Indices of input sequence tokens in the vocabulary. 1006 | The second dimension of the input (`num_choices`) indicates the number of choices to score. 1007 | To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows: 1008 | 1009 | (a) For sequence pairs: 1010 | 1011 | ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]`` 1012 | 1013 | ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1`` 1014 | 1015 | (b) For single sequences: 1016 | 1017 | ``tokens: [CLS] the dog is hairy . [SEP]`` 1018 | 1019 | ``token_type_ids: 0 0 0 0 0 0 0`` 1020 | 1021 | Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`. 1022 | See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and 1023 | :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. 
1024 | **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``: 1025 | Segment token indices to indicate first and second portions of the inputs. 1026 | The second dimension of the input (`num_choices`) indicates the number of choices to score. 1027 | Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` 1028 | corresponds to a `sentence B` token 1029 | (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details). 1030 | **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``: 1031 | Mask to avoid performing attention on padding token indices. 1032 | The second dimension of the input (`num_choices`) indicates the number of choices to score. 1033 | Mask values selected in ``[0, 1]``: 1034 | ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. 1035 | **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: 1036 | Mask to nullify selected heads of the self-attention modules. 1037 | Mask values selected in ``[0, 1]``: 1038 | ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. 1039 | **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: 1040 | Labels for computing the multiple choice classification loss. 1041 | Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension 1042 | of the input tensors. (see `input_ids` above) 1043 | 1044 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 1045 | **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: 1046 | Classification loss. 1047 | **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension 1048 | of the input tensors. (see `input_ids` above). 1049 | Classification scores (before SoftMax). 1050 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 1051 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 1052 | of shape ``(batch_size, sequence_length, hidden_size)``: 1053 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 1054 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 1055 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 1056 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
1057 | 1058 | Examples:: 1059 | 1060 | >>> config = BertConfig.from_pretrained('bert-base-uncased') 1061 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 1062 | >>> 1063 | >>> model = BertForMultipleChoice(config) 1064 | >>> choices = ["Hello, my dog is cute", "Hello, my cat is amazing"] 1065 | >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices 1066 | >>> labels = torch.tensor(1).unsqueeze(0) # Batch size 1 1067 | >>> outputs = model(input_ids, labels=labels) 1068 | >>> loss, classification_scores = outputs[:2] 1069 | 1070 | """ 1071 | def __init__(self, config): 1072 | super(BertForMultipleChoice, self).__init__(config) 1073 | 1074 | self.bert = BertModel(config) 1075 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1076 | self.classifier = nn.Linear(config.hidden_size, 1) 1077 | 1078 | self.apply(self.init_weights) 1079 | 1080 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, 1081 | position_ids=None, head_mask=None): 1082 | num_choices = input_ids.shape[1] 1083 | 1084 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) 1085 | flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None 1086 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 1087 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 1088 | outputs = self.bert(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids, 1089 | attention_mask=flat_attention_mask, head_mask=head_mask) 1090 | pooled_output = outputs[1] 1091 | 1092 | pooled_output = self.dropout(pooled_output) 1093 | logits = self.classifier(pooled_output) 1094 | reshaped_logits = logits.view(-1, num_choices) 1095 | 1096 | outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here 1097 | 1098 | if labels is not None: 1099 | loss_fct = CrossEntropyLoss() 1100 | loss = loss_fct(reshaped_logits, labels) 1101 | outputs = (loss,) + outputs 1102 | 1103 | return outputs # (loss), reshaped_logits, (hidden_states), (attentions) 1104 | 1105 | 1106 | @add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of 1107 | the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, 1108 | BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) 1109 | class BertForTokenClassification(BertPreTrainedModel): 1110 | r""" 1111 | **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: 1112 | Labels for computing the token classification loss. 1113 | Indices should be in ``[0, ..., config.num_labels]``. 1114 | 1115 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 1116 | **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: 1117 | Classification loss. 1118 | **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)`` 1119 | Classification scores (before SoftMax). 1120 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 1121 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 1122 | of shape ``(batch_size, sequence_length, hidden_size)``: 1123 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
1124 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 1125 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 1126 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 1127 | 1128 | Examples:: 1129 | 1130 | >>> config = BertConfig.from_pretrained('bert-base-uncased') 1131 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 1132 | >>> 1133 | >>> model = BertForTokenClassification(config) 1134 | >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 1135 | >>> labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1 1136 | >>> outputs = model(input_ids, labels=labels) 1137 | >>> loss, scores = outputs[:2] 1138 | 1139 | """ 1140 | def __init__(self, config): 1141 | super(BertForTokenClassification, self).__init__(config) 1142 | self.num_labels = config.num_labels 1143 | 1144 | self.bert = BertModel(config) 1145 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1146 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 1147 | 1148 | self.apply(self.init_weights) 1149 | 1150 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, 1151 | position_ids=None, head_mask=None): 1152 | outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, 1153 | attention_mask=attention_mask, head_mask=head_mask) 1154 | sequence_output = outputs[0] 1155 | 1156 | sequence_output = self.dropout(sequence_output) 1157 | logits = self.classifier(sequence_output) 1158 | 1159 | outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here 1160 | if labels is not None: 1161 | loss_fct = CrossEntropyLoss() 1162 | # Only keep active parts of the loss 1163 | if attention_mask is not None: 1164 | active_loss = attention_mask.view(-1) == 1 1165 | active_logits = logits.view(-1, self.num_labels)[active_loss] 1166 | active_labels = labels.view(-1)[active_loss] 1167 | loss = loss_fct(active_logits, active_labels) 1168 | else: 1169 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1170 | outputs = (loss,) + outputs 1171 | 1172 | return outputs # (loss), scores, (hidden_states), (attentions) 1173 | 1174 | 1175 | @add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of 1176 | the hidden-states output to compute `span start logits` and `span end logits`). """, 1177 | BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) 1178 | class BertForQuestionAnswering(BertPreTrainedModel): 1179 | r""" 1180 | **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: 1181 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1182 | Positions are clamped to the length of the sequence (`sequence_length`). 1183 | Position outside of the sequence are not taken into account for computing the loss. 1184 | **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: 1185 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 1186 | Positions are clamped to the length of the sequence (`sequence_length`). 1187 | Position outside of the sequence are not taken into account for computing the loss. 
1188 | 
1189 |     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
1190 |         **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
1191 |             Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
1192 |         **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
1193 |             Span-start scores (before SoftMax).
1194 |         **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
1195 |             Span-end scores (before SoftMax).
1196 |         **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
1197 |             list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
1198 |             of shape ``(batch_size, sequence_length, hidden_size)``:
1199 |             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1200 |         **attentions**: (`optional`, returned when ``config.output_attentions=True``)
1201 |             list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
1202 |             Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
1203 | 
1204 |     Examples::
1205 | 
1206 |         >>> config = BertConfig.from_pretrained('bert-base-uncased')
1207 |         >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1208 |         >>>
1209 |         >>> model = BertForQuestionAnswering(config)
1210 |         >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
1211 |         >>> start_positions = torch.tensor([1])
1212 |         >>> end_positions = torch.tensor([3])
1213 |         >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
1214 |         >>> loss, start_scores, end_scores = outputs[:3]
1215 | 
1216 |     """
1217 |     def __init__(self, config):
1218 |         super(BertForQuestionAnswering, self).__init__(config)
1219 |         self.num_labels = config.num_labels
1220 | 
1221 |         self.bert = BertModel(config)
1222 |         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1223 | 
1224 |         self.apply(self.init_weights)
1225 | 
1226 |     def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
1227 |                 end_positions=None, position_ids=None, head_mask=None):
1228 |         outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
1229 |                             attention_mask=attention_mask, head_mask=head_mask)
1230 |         sequence_output = outputs[0]
1231 | 
1232 |         logits = self.qa_outputs(sequence_output)
1233 |         start_logits, end_logits = logits.split(1, dim=-1)
1234 |         start_logits = start_logits.squeeze(-1)
1235 |         end_logits = end_logits.squeeze(-1)
1236 | 
1237 |         outputs = (start_logits, end_logits,) + outputs[2:]
1238 |         if start_positions is not None and end_positions is not None:
1239 |             # If we are on multi-GPU, split adds a dimension
1240 |             if len(start_positions.size()) > 1:
1241 |                 start_positions = start_positions.squeeze(-1)
1242 |             if len(end_positions.size()) > 1:
1243 |                 end_positions = end_positions.squeeze(-1)
1244 |             # sometimes the start/end positions are outside our model inputs; we ignore these terms
1245 |             ignored_index = start_logits.size(1)
1246 |             start_positions.clamp_(0, ignored_index)
1247 |             end_positions.clamp_(0, ignored_index)
1248 | 
1249 |             loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1250 |             start_loss = loss_fct(start_logits, start_positions)
1251 |             end_loss = loss_fct(end_logits, end_positions)
1252 |             total_loss = (start_loss + end_loss) / 2
1253 |             outputs = (total_loss,) + outputs
1254 | 
1255 |         return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
1256 | 
--------------------------------------------------------------------------------
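A quick way to sanity-check the vendored `BertModel` defined in `transformer/pytorch_transformer/modeling_bert.py` is to build a small, randomly initialised model and push a dummy batch through it; no pretrained checkpoint or download is needed. The sketch below is not part of the repository files: the import path mirrors this repo's directory layout, and it assumes the module also exposes `BertConfig` (as upstream `pytorch_transformers` does), so adjust both to your setup.

```
import torch

# Assumed import path based on this repository's layout; adjust if the package is vendored elsewhere.
# BertConfig is assumed to be defined alongside BertModel, as in upstream pytorch_transformers.
from transformer.pytorch_transformer.modeling_bert import BertConfig, BertModel

# A deliberately small configuration so the check runs quickly on CPU.
config = BertConfig(hidden_size=768, num_hidden_layers=2, num_attention_heads=12)
model = BertModel(config)  # randomly initialised, no checkpoint required
model.eval()

# Fake batch: one sequence of 16 token ids drawn from the default vocabulary.
input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    sequence_output, pooled_output = model(input_ids)[:2]

print(sequence_output.shape)  # expected: torch.Size([1, 16, 768])
print(pooled_output.shape)    # expected: torch.Size([1, 768])
```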