├── .gitignore
├── Object Detection
│   ├── Faster R-CNN.ipynb
│   ├── README.md
│   ├── data
│   │   ├── annotations.xml
│   │   └── images
│   │       ├── 0462.png
│   │       └── 0896.png
│   ├── model.py
│   └── utils.py
├── README.md
├── Transformer
│   ├── .DS_Store
│   ├── pytorch transformer
│   │   └── language translation.ipynb
│   └── the extra annotated transformer
│       ├── .DS_Store
│       ├── inference_test.ipynb
│       └── transformer.py
├── attention
│   ├── basic transformer.ipynb
│   ├── neural translation with attention.ipynb
│   └── self attention.ipynb
├── charRNN
│   ├── charRNN_batching.jpg
│   ├── generated_text_out.txt
│   └── text_generation_with_charRNN.ipynb
├── pytorch basics
│   ├── CIFAR-10 classification with linear layers.ipynb
│   ├── CNNs.ipynb
│   ├── README.md
│   ├── data_representations.ipynb
│   ├── imagenet_classes.txt
│   ├── mechanics_of_learning.ipynb
│   ├── neural_networks.ipynb
│   ├── tensors.ipynb
│   ├── using_pretrained_models.ipynb
│   └── working_with_text_data.ipynb
├── sentiment_analysis
│   ├── README.md
│   ├── __pycache__
│   │   └── utils.cpython-38.pyc
│   ├── sentiment_analysis_basic.ipynb
│   ├── sentiment_analysis_v1.ipynb
│   └── utils.py
└── time series
    ├── sequence_data.jpg
    ├── sine_wave_prediction.ipynb
    ├── temperature_forecasting_lstm.ipynb
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # folders
2 | dlwpt-code/
3 | data/
4 | .ipynb_checkpoints/
5 | **/.ipynb_checkpoints/
6 | images/
7 | **/*.pt
8 | *.pt
9 | 
--------------------------------------------------------------------------------
/Object Detection/README.md:
--------------------------------------------------------------------------------
1 | ## Faster R-CNN Implementation in PyTorch
2 | 
3 | This is an implementation of Faster R-CNN in PyTorch.
4 | Please refer to the [associated blog](https://medium.com/p/11acfff216b0) for more details.
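
A minimal usage sketch of the `TwoStageDetector` defined in `model.py`. The image size, feature-map size, number of classes, ROI size, detection thresholds, and dummy tensors below are illustrative assumptions (not taken from the repo), chosen to match the ResNet-50 backbone used in `model.py`, which downsamples by 32 and outputs 2048 channels:

```python
import torch
from model import TwoStageDetector

# Assumed setup (not from the repo): 480x640 images and the ResNet-50 backbone from
# model.py, which downsamples by 32 -> a 15x20 feature map with 2048 channels.
img_size, out_size = (480, 640), (15, 20)
model = TwoStageDetector(img_size, out_size, out_channels=2048, n_classes=2, roi_size=(2, 2))

# Dummy batch of two images. Ground-truth boxes are (B, max_objects, 4) in xyxy pixel
# coordinates and classes are (B, max_objects); rows of -1 pad images with fewer objects
# (the padded class value never reaches the loss, since a zero-area box yields no positive anchor).
images = torch.randn(2, 3, 480, 640)
gt_bboxes = torch.tensor([[[ 30.,  40., 220., 280.], [300., 120., 560., 400.]],
                          [[ 50.,  60., 240., 300.], [ -1.,  -1.,  -1.,  -1.]]])
gt_classes = torch.tensor([[0., 1.], [0., -1.]])

# Training step: forward() returns the combined RPN + classification loss.
total_loss = model(images, gt_bboxes, gt_classes)
total_loss.backward()

# Inference: per-image proposals, objectness scores, and predicted classes
# (the thresholds here are arbitrary).
model.eval()
proposals, scores, classes = model.inference(images, conf_thresh=0.9, nms_thresh=0.05)
```

Because the backbone stride is 32, `out_size` should be the input height and width divided by 32 so the anchor grid lines up with the feature map, and `out_channels` should match the backbone's output (2048 for ResNet-50).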
-------------------------------------------------------------------------------- /Object Detection/data/annotations.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 1.1 4 | 5 | 6 | 2 7 | article-labelling 8 | 2 9 | annotation 10 | 0 11 | 12 | 2022-10-29 14:45:53.093755+00:00 13 | 2022-10-29 14:56:10.408986+00:00 14 | default 15 | 0 16 | 1 17 | 18 | 19 | 20 | 2 21 | 0 22 | 1 23 | http://localhost:8080/api/jobs/2 24 | 25 | 26 | 27 | wingedrasengan927 28 | ms.neerajkrishna@gmail.com 29 | 30 | 31 | wingedrasengan927 32 | ms.neerajkrishna@gmail.com 33 | 34 | 35 | 42 | 49 | 50 | 51 | 2022-10-29 14:58:01.689218+00:00 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /Object Detection/data/images/0462.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wingedrasengan927/pytorch-tutorials/20ddc1fda3750550c8426799187276619a7f7610/Object Detection/data/images/0462.png -------------------------------------------------------------------------------- /Object Detection/data/images/0896.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wingedrasengan927/pytorch-tutorials/20ddc1fda3750550c8426799187276619a7f7610/Object Detection/data/images/0896.png -------------------------------------------------------------------------------- /Object Detection/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torchvision import ops 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import torch.nn as nn 7 | 8 | from utils import * 9 | 10 | # -------------------- Models ----------------------- 11 | 12 | class FeatureExtractor(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | model = torchvision.models.resnet50(pretrained=True) 16 | req_layers = list(model.children())[:8] 17 | self.backbone = nn.Sequential(*req_layers) 18 | for param in self.backbone.named_parameters(): 19 | param[1].requires_grad = True 20 | 21 | def forward(self, img_data): 22 | return self.backbone(img_data) 23 | 24 | class ProposalModule(nn.Module): 25 | def __init__(self, in_features, hidden_dim=512, n_anchors=9, p_dropout=0.3): 26 | super().__init__() 27 | self.n_anchors = n_anchors 28 | self.conv1 = nn.Conv2d(in_features, hidden_dim, kernel_size=3, padding=1) 29 | self.dropout = nn.Dropout(p_dropout) 30 | self.conf_head = nn.Conv2d(hidden_dim, n_anchors, kernel_size=1) 31 | self.reg_head = nn.Conv2d(hidden_dim, n_anchors * 4, kernel_size=1) 32 | 33 | def forward(self, feature_map, pos_anc_ind=None, neg_anc_ind=None, pos_anc_coords=None): 34 | # determine mode 35 | if pos_anc_ind is None or neg_anc_ind is None or pos_anc_coords is None: 36 | mode = 'eval' 37 | else: 38 | mode = 'train' 39 | 40 | out = self.conv1(feature_map) 41 | out = F.relu(self.dropout(out)) 42 | 43 | reg_offsets_pred = self.reg_head(out) # (B, A*4, hmap, wmap) 44 | conf_scores_pred = self.conf_head(out) # (B, A, hmap, wmap) 45 | 46 | if mode == 'train': 47 | # get conf scores 48 | conf_scores_pos = conf_scores_pred.flatten()[pos_anc_ind] 49 | conf_scores_neg = conf_scores_pred.flatten()[neg_anc_ind] 50 | # get offsets for +ve anchors 51 | offsets_pos = reg_offsets_pred.contiguous().view(-1, 4)[pos_anc_ind] 52 | # generate 
proposals using offsets 53 | proposals = generate_proposals(pos_anc_coords, offsets_pos) 54 | 55 | return conf_scores_pos, conf_scores_neg, offsets_pos, proposals 56 | 57 | elif mode == 'eval': 58 | return conf_scores_pred, reg_offsets_pred 59 | 60 | class RegionProposalNetwork(nn.Module): 61 | def __init__(self, img_size, out_size, out_channels): 62 | super().__init__() 63 | 64 | self.img_height, self.img_width = img_size 65 | self.out_h, self.out_w = out_size 66 | 67 | # downsampling scale factor 68 | self.width_scale_factor = self.img_width // self.out_w 69 | self.height_scale_factor = self.img_height // self.out_h 70 | 71 | # scales and ratios for anchor boxes 72 | self.anc_scales = [2, 4, 6] 73 | self.anc_ratios = [0.5, 1, 1.5] 74 | self.n_anc_boxes = len(self.anc_scales) * len(self.anc_ratios) 75 | 76 | # IoU thresholds for +ve and -ve anchors 77 | self.pos_thresh = 0.7 78 | self.neg_thresh = 0.3 79 | 80 | # weights for loss 81 | self.w_conf = 1 82 | self.w_reg = 5 83 | 84 | self.feature_extractor = FeatureExtractor() 85 | self.proposal_module = ProposalModule(out_channels, n_anchors=self.n_anc_boxes) 86 | 87 | def forward(self, images, gt_bboxes, gt_classes): 88 | batch_size = images.size(dim=0) 89 | feature_map = self.feature_extractor(images) 90 | 91 | # generate anchors 92 | anc_pts_x, anc_pts_y = gen_anc_centers(out_size=(self.out_h, self.out_w)) 93 | anc_base = gen_anc_base(anc_pts_x, anc_pts_y, self.anc_scales, self.anc_ratios, (self.out_h, self.out_w)) 94 | anc_boxes_all = anc_base.repeat(batch_size, 1, 1, 1, 1) 95 | 96 | # get positive and negative anchors amongst other things 97 | gt_bboxes_proj = project_bboxes(gt_bboxes, self.width_scale_factor, self.height_scale_factor, mode='p2a') 98 | 99 | positive_anc_ind, negative_anc_ind, GT_conf_scores, \ 100 | GT_offsets, GT_class_pos, positive_anc_coords, \ 101 | negative_anc_coords, positive_anc_ind_sep = get_req_anchors(anc_boxes_all, gt_bboxes_proj, gt_classes) 102 | 103 | # pass through the proposal module 104 | conf_scores_pos, conf_scores_neg, offsets_pos, proposals = self.proposal_module(feature_map, positive_anc_ind, \ 105 | negative_anc_ind, positive_anc_coords) 106 | 107 | cls_loss = calc_cls_loss(conf_scores_pos, conf_scores_neg, batch_size) 108 | reg_loss = calc_bbox_reg_loss(GT_offsets, offsets_pos, batch_size) 109 | 110 | total_rpn_loss = self.w_conf * cls_loss + self.w_reg * reg_loss 111 | 112 | return total_rpn_loss, feature_map, proposals, positive_anc_ind_sep, GT_class_pos 113 | 114 | def inference(self, images, conf_thresh=0.5, nms_thresh=0.7): 115 | with torch.no_grad(): 116 | batch_size = images.size(dim=0) 117 | feature_map = self.feature_extractor(images) 118 | 119 | # generate anchors 120 | anc_pts_x, anc_pts_y = gen_anc_centers(out_size=(self.out_h, self.out_w)) 121 | anc_base = gen_anc_base(anc_pts_x, anc_pts_y, self.anc_scales, self.anc_ratios, (self.out_h, self.out_w)) 122 | anc_boxes_all = anc_base.repeat(batch_size, 1, 1, 1, 1) 123 | anc_boxes_flat = anc_boxes_all.reshape(batch_size, -1, 4) 124 | 125 | # get conf scores and offsets 126 | conf_scores_pred, offsets_pred = self.proposal_module(feature_map) 127 | conf_scores_pred = conf_scores_pred.reshape(batch_size, -1) 128 | offsets_pred = offsets_pred.reshape(batch_size, -1, 4) 129 | 130 | # filter out proposals based on conf threshold and nms threshold for each image 131 | proposals_final = [] 132 | conf_scores_final = [] 133 | for i in range(batch_size): 134 | conf_scores = torch.sigmoid(conf_scores_pred[i]) 135 | offsets = offsets_pred[i] 136 | 
anc_boxes = anc_boxes_flat[i] 137 | proposals = generate_proposals(anc_boxes, offsets) 138 | # filter based on confidence threshold 139 | conf_idx = torch.where(conf_scores >= conf_thresh)[0] 140 | conf_scores_pos = conf_scores[conf_idx] 141 | proposals_pos = proposals[conf_idx] 142 | # filter based on nms threshold 143 | nms_idx = ops.nms(proposals_pos, conf_scores_pos, nms_thresh) 144 | conf_scores_pos = conf_scores_pos[nms_idx] 145 | proposals_pos = proposals_pos[nms_idx] 146 | 147 | proposals_final.append(proposals_pos) 148 | conf_scores_final.append(conf_scores_pos) 149 | 150 | return proposals_final, conf_scores_final, feature_map 151 | 152 | class ClassificationModule(nn.Module): 153 | def __init__(self, out_channels, n_classes, roi_size, hidden_dim=512, p_dropout=0.3): 154 | super().__init__() 155 | self.roi_size = roi_size 156 | # hidden network 157 | self.avg_pool = nn.AvgPool2d(self.roi_size) 158 | self.fc = nn.Linear(out_channels, hidden_dim) 159 | self.dropout = nn.Dropout(p_dropout) 160 | 161 | # define classification head 162 | self.cls_head = nn.Linear(hidden_dim, n_classes) 163 | 164 | def forward(self, feature_map, proposals_list, gt_classes=None): 165 | 166 | if gt_classes is None: 167 | mode = 'eval' 168 | else: 169 | mode = 'train' 170 | 171 | # apply roi pooling on proposals followed by avg pooling 172 | roi_out = ops.roi_pool(feature_map, proposals_list, self.roi_size) 173 | roi_out = self.avg_pool(roi_out) 174 | 175 | # flatten the output 176 | roi_out = roi_out.squeeze(-1).squeeze(-1) 177 | 178 | # pass the output through the hidden network 179 | out = self.fc(roi_out) 180 | out = F.relu(self.dropout(out)) 181 | 182 | # get the classification scores 183 | cls_scores = self.cls_head(out) 184 | 185 | if mode == 'eval': 186 | return cls_scores 187 | 188 | # compute cross entropy loss 189 | cls_loss = F.cross_entropy(cls_scores, gt_classes.long()) 190 | 191 | return cls_loss 192 | 193 | class TwoStageDetector(nn.Module): 194 | def __init__(self, img_size, out_size, out_channels, n_classes, roi_size): 195 | super().__init__() 196 | self.rpn = RegionProposalNetwork(img_size, out_size, out_channels) 197 | self.classifier = ClassificationModule(out_channels, n_classes, roi_size) 198 | 199 | def forward(self, images, gt_bboxes, gt_classes): 200 | total_rpn_loss, feature_map, proposals, \ 201 | positive_anc_ind_sep, GT_class_pos = self.rpn(images, gt_bboxes, gt_classes) 202 | 203 | # get separate proposals for each sample 204 | pos_proposals_list = [] 205 | batch_size = images.size(dim=0) 206 | for idx in range(batch_size): 207 | proposal_idxs = torch.where(positive_anc_ind_sep == idx)[0] 208 | proposals_sep = proposals[proposal_idxs].detach().clone() 209 | pos_proposals_list.append(proposals_sep) 210 | 211 | cls_loss = self.classifier(feature_map, pos_proposals_list, GT_class_pos) 212 | total_loss = cls_loss + total_rpn_loss 213 | 214 | return total_loss 215 | 216 | def inference(self, images, conf_thresh=0.5, nms_thresh=0.7): 217 | batch_size = images.size(dim=0) 218 | proposals_final, conf_scores_final, feature_map = self.rpn.inference(images, conf_thresh, nms_thresh) 219 | cls_scores = self.classifier(feature_map, proposals_final) 220 | 221 | # convert scores into probability 222 | cls_probs = F.softmax(cls_scores, dim=-1) 223 | # get classes with highest probability 224 | classes_all = torch.argmax(cls_probs, dim=-1) 225 | 226 | classes_final = [] 227 | # slice classes to map to their corresponding image 228 | c = 0 229 | for i in range(batch_size): 230 | n_proposals = 
len(proposals_final[i]) # get the number of proposals for each image 231 | classes_final.append(classes_all[c: c+n_proposals]) 232 | c += n_proposals 233 | 234 | return proposals_final, conf_scores_final, classes_final 235 | 236 | # ------------------- Loss Utils ---------------------- 237 | 238 | def calc_cls_loss(conf_scores_pos, conf_scores_neg, batch_size): 239 | target_pos = torch.ones_like(conf_scores_pos) 240 | target_neg = torch.zeros_like(conf_scores_neg) 241 | 242 | target = torch.cat((target_pos, target_neg)) 243 | inputs = torch.cat((conf_scores_pos, conf_scores_neg)) 244 | 245 | loss = F.binary_cross_entropy_with_logits(inputs, target, reduction='sum') * 1. / batch_size 246 | 247 | return loss 248 | 249 | def calc_bbox_reg_loss(gt_offsets, reg_offsets_pos, batch_size): 250 | assert gt_offsets.size() == reg_offsets_pos.size() 251 | loss = F.smooth_l1_loss(reg_offsets_pos, gt_offsets, reduction='sum') * 1. / batch_size 252 | return loss -------------------------------------------------------------------------------- /Object Detection/utils.py: -------------------------------------------------------------------------------- 1 | import xml.etree.ElementTree as ET 2 | import numpy as np 3 | import os 4 | import matplotlib.pyplot as plt 5 | import matplotlib.patches as patches 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from torchvision import ops 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | 13 | # -------------- Data Untils ------------------- 14 | 15 | def parse_annotation(annotation_path, image_dir, img_size): 16 | ''' 17 | Traverse the xml tree, get the annotations, and resize them to the scaled image size 18 | ''' 19 | img_h, img_w = img_size 20 | 21 | with open(annotation_path, "r") as f: 22 | tree = ET.parse(f) 23 | 24 | root = tree.getroot() 25 | 26 | img_paths = [] 27 | gt_boxes_all = [] 28 | gt_classes_all = [] 29 | # get image paths 30 | for object_ in root.findall('image'): 31 | img_path = os.path.join(image_dir, object_.get("name")) 32 | img_paths.append(img_path) 33 | 34 | # get raw image size 35 | orig_w = int(object_.get("width")) 36 | orig_h = int(object_.get("height")) 37 | 38 | # get bboxes and their labels 39 | groundtruth_boxes = [] 40 | groundtruth_classes = [] 41 | for box_ in object_.findall('box'): 42 | xmin = float(box_.get("xtl")) 43 | ymin = float(box_.get("ytl")) 44 | xmax = float(box_.get("xbr")) 45 | ymax = float(box_.get("ybr")) 46 | 47 | # rescale bboxes 48 | bbox = torch.Tensor([xmin, ymin, xmax, ymax]) 49 | bbox[[0, 2]] = bbox[[0, 2]] * img_w/orig_w 50 | bbox[[1, 3]] = bbox[[1, 3]] * img_h/orig_h 51 | 52 | groundtruth_boxes.append(bbox.tolist()) 53 | 54 | # get labels 55 | label = box_.get("label") 56 | groundtruth_classes.append(label) 57 | 58 | gt_boxes_all.append(torch.Tensor(groundtruth_boxes)) 59 | gt_classes_all.append(groundtruth_classes) 60 | 61 | return gt_boxes_all, gt_classes_all, img_paths 62 | 63 | # -------------- Prepocessing utils ---------------- 64 | 65 | def calc_gt_offsets(pos_anc_coords, gt_bbox_mapping): 66 | pos_anc_coords = ops.box_convert(pos_anc_coords, in_fmt='xyxy', out_fmt='cxcywh') 67 | gt_bbox_mapping = ops.box_convert(gt_bbox_mapping, in_fmt='xyxy', out_fmt='cxcywh') 68 | 69 | gt_cx, gt_cy, gt_w, gt_h = gt_bbox_mapping[:, 0], gt_bbox_mapping[:, 1], gt_bbox_mapping[:, 2], gt_bbox_mapping[:, 3] 70 | anc_cx, anc_cy, anc_w, anc_h = pos_anc_coords[:, 0], pos_anc_coords[:, 1], pos_anc_coords[:, 2], pos_anc_coords[:, 3] 71 | 72 | tx_ = (gt_cx - anc_cx)/anc_w 73 | ty_ = (gt_cy - 
anc_cy)/anc_h 74 | tw_ = torch.log(gt_w / anc_w) 75 | th_ = torch.log(gt_h / anc_h) 76 | 77 | return torch.stack([tx_, ty_, tw_, th_], dim=-1) 78 | 79 | def gen_anc_centers(out_size): 80 | out_h, out_w = out_size 81 | 82 | anc_pts_x = torch.arange(0, out_w) + 0.5 83 | anc_pts_y = torch.arange(0, out_h) + 0.5 84 | 85 | return anc_pts_x, anc_pts_y 86 | 87 | def project_bboxes(bboxes, width_scale_factor, height_scale_factor, mode='a2p'): 88 | assert mode in ['a2p', 'p2a'] 89 | 90 | batch_size = bboxes.size(dim=0) 91 | proj_bboxes = bboxes.clone().reshape(batch_size, -1, 4) 92 | invalid_bbox_mask = (proj_bboxes == -1) # indicating padded bboxes 93 | 94 | if mode == 'a2p': 95 | # activation map to pixel image 96 | proj_bboxes[:, :, [0, 2]] *= width_scale_factor 97 | proj_bboxes[:, :, [1, 3]] *= height_scale_factor 98 | else: 99 | # pixel image to activation map 100 | proj_bboxes[:, :, [0, 2]] /= width_scale_factor 101 | proj_bboxes[:, :, [1, 3]] /= height_scale_factor 102 | 103 | proj_bboxes.masked_fill_(invalid_bbox_mask, -1) # fill padded bboxes back with -1 104 | proj_bboxes.resize_as_(bboxes) 105 | 106 | return proj_bboxes 107 | 108 | def generate_proposals(anchors, offsets): 109 | 110 | # change format of the anchor boxes from 'xyxy' to 'cxcywh' 111 | anchors = ops.box_convert(anchors, in_fmt='xyxy', out_fmt='cxcywh') 112 | 113 | # apply offsets to anchors to create proposals 114 | proposals_ = torch.zeros_like(anchors) 115 | proposals_[:,0] = anchors[:,0] + offsets[:,0]*anchors[:,2] 116 | proposals_[:,1] = anchors[:,1] + offsets[:,1]*anchors[:,3] 117 | proposals_[:,2] = anchors[:,2] * torch.exp(offsets[:,2]) 118 | proposals_[:,3] = anchors[:,3] * torch.exp(offsets[:,3]) 119 | 120 | # change format of proposals back from 'cxcywh' to 'xyxy' 121 | proposals = ops.box_convert(proposals_, in_fmt='cxcywh', out_fmt='xyxy') 122 | 123 | return proposals 124 | 125 | def gen_anc_base(anc_pts_x, anc_pts_y, anc_scales, anc_ratios, out_size): 126 | n_anc_boxes = len(anc_scales) * len(anc_ratios) 127 | anc_base = torch.zeros(1, anc_pts_x.size(dim=0) \ 128 | , anc_pts_y.size(dim=0), n_anc_boxes, 4) # shape - [1, Hmap, Wmap, n_anchor_boxes, 4] 129 | 130 | for ix, xc in enumerate(anc_pts_x): 131 | for jx, yc in enumerate(anc_pts_y): 132 | anc_boxes = torch.zeros((n_anc_boxes, 4)) 133 | c = 0 134 | for i, scale in enumerate(anc_scales): 135 | for j, ratio in enumerate(anc_ratios): 136 | w = scale * ratio 137 | h = scale 138 | 139 | xmin = xc - w / 2 140 | ymin = yc - h / 2 141 | xmax = xc + w / 2 142 | ymax = yc + h / 2 143 | 144 | anc_boxes[c, :] = torch.Tensor([xmin, ymin, xmax, ymax]) 145 | c += 1 146 | 147 | anc_base[:, ix, jx, :] = ops.clip_boxes_to_image(anc_boxes, size=out_size) 148 | 149 | return anc_base 150 | 151 | def get_iou_mat(batch_size, anc_boxes_all, gt_bboxes_all): 152 | 153 | # flatten anchor boxes 154 | anc_boxes_flat = anc_boxes_all.reshape(batch_size, -1, 4) 155 | # get total anchor boxes for a single image 156 | tot_anc_boxes = anc_boxes_flat.size(dim=1) 157 | 158 | # create a placeholder to compute IoUs amongst the boxes 159 | ious_mat = torch.zeros((batch_size, tot_anc_boxes, gt_bboxes_all.size(dim=1))) 160 | 161 | # compute IoU of the anc boxes with the gt boxes for all the images 162 | for i in range(batch_size): 163 | gt_bboxes = gt_bboxes_all[i] 164 | anc_boxes = anc_boxes_flat[i] 165 | ious_mat[i, :] = ops.box_iou(anc_boxes, gt_bboxes) 166 | 167 | return ious_mat 168 | 169 | def get_req_anchors(anc_boxes_all, gt_bboxes_all, gt_classes_all, pos_thresh=0.7, neg_thresh=0.2): 170 | 
''' 171 | Prepare necessary data required for training 172 | 173 | Input 174 | ------ 175 | anc_boxes_all - torch.Tensor of shape (B, w_amap, h_amap, n_anchor_boxes, 4) 176 | all anchor boxes for a batch of images 177 | gt_bboxes_all - torch.Tensor of shape (B, max_objects, 4) 178 | padded ground truth boxes for a batch of images 179 | gt_classes_all - torch.Tensor of shape (B, max_objects) 180 | padded ground truth classes for a batch of images 181 | 182 | Returns 183 | --------- 184 | positive_anc_ind - torch.Tensor of shape (n_pos,) 185 | flattened positive indices for all the images in the batch 186 | negative_anc_ind - torch.Tensor of shape (n_pos,) 187 | flattened positive indices for all the images in the batch 188 | GT_conf_scores - torch.Tensor of shape (n_pos,), IoU scores of +ve anchors 189 | GT_offsets - torch.Tensor of shape (n_pos, 4), 190 | offsets between +ve anchors and their corresponding ground truth boxes 191 | GT_class_pos - torch.Tensor of shape (n_pos,) 192 | mapped classes of +ve anchors 193 | positive_anc_coords - (n_pos, 4) coords of +ve anchors (for visualization) 194 | negative_anc_coords - (n_pos, 4) coords of -ve anchors (for visualization) 195 | positive_anc_ind_sep - list of indices to keep track of +ve anchors 196 | ''' 197 | # get the size and shape parameters 198 | B, w_amap, h_amap, A, _ = anc_boxes_all.shape 199 | N = gt_bboxes_all.shape[1] # max number of groundtruth bboxes in a batch 200 | 201 | # get total number of anchor boxes in a single image 202 | tot_anc_boxes = A * w_amap * h_amap 203 | 204 | # get the iou matrix which contains iou of every anchor box 205 | # against all the groundtruth bboxes in an image 206 | iou_mat = get_iou_mat(B, anc_boxes_all, gt_bboxes_all) 207 | 208 | # for every groundtruth bbox in an image, find the iou 209 | # with the anchor box which it overlaps the most 210 | max_iou_per_gt_box, _ = iou_mat.max(dim=1, keepdim=True) 211 | 212 | # get positive anchor boxes 213 | 214 | # condition 1: the anchor box with the max iou for every gt bbox 215 | positive_anc_mask = torch.logical_and(iou_mat == max_iou_per_gt_box, max_iou_per_gt_box > 0) 216 | # condition 2: anchor boxes with iou above a threshold with any of the gt bboxes 217 | positive_anc_mask = torch.logical_or(positive_anc_mask, iou_mat > pos_thresh) 218 | 219 | positive_anc_ind_sep = torch.where(positive_anc_mask)[0] # get separate indices in the batch 220 | # combine all the batches and get the idxs of the +ve anchor boxes 221 | positive_anc_mask = positive_anc_mask.flatten(start_dim=0, end_dim=1) 222 | positive_anc_ind = torch.where(positive_anc_mask)[0] 223 | 224 | # for every anchor box, get the iou and the idx of the 225 | # gt bbox it overlaps with the most 226 | max_iou_per_anc, max_iou_per_anc_ind = iou_mat.max(dim=-1) 227 | max_iou_per_anc = max_iou_per_anc.flatten(start_dim=0, end_dim=1) 228 | 229 | # get iou scores of the +ve anchor boxes 230 | GT_conf_scores = max_iou_per_anc[positive_anc_ind] 231 | 232 | # get gt classes of the +ve anchor boxes 233 | 234 | # expand gt classes to map against every anchor box 235 | gt_classes_expand = gt_classes_all.view(B, 1, N).expand(B, tot_anc_boxes, N) 236 | # for every anchor box, consider only the class of the gt bbox it overlaps with the most 237 | GT_class = torch.gather(gt_classes_expand, -1, max_iou_per_anc_ind.unsqueeze(-1)).squeeze(-1) 238 | # combine all the batches and get the mapped classes of the +ve anchor boxes 239 | GT_class = GT_class.flatten(start_dim=0, end_dim=1) 240 | GT_class_pos = 
GT_class[positive_anc_ind] 241 | 242 | # get gt bbox coordinates of the +ve anchor boxes 243 | 244 | # expand all the gt bboxes to map against every anchor box 245 | gt_bboxes_expand = gt_bboxes_all.view(B, 1, N, 4).expand(B, tot_anc_boxes, N, 4) 246 | # for every anchor box, consider only the coordinates of the gt bbox it overlaps with the most 247 | GT_bboxes = torch.gather(gt_bboxes_expand, -2, max_iou_per_anc_ind.reshape(B, tot_anc_boxes, 1, 1).repeat(1, 1, 1, 4)) 248 | # combine all the batches and get the mapped gt bbox coordinates of the +ve anchor boxes 249 | GT_bboxes = GT_bboxes.flatten(start_dim=0, end_dim=2) 250 | GT_bboxes_pos = GT_bboxes[positive_anc_ind] 251 | 252 | # get coordinates of +ve anc boxes 253 | anc_boxes_flat = anc_boxes_all.flatten(start_dim=0, end_dim=-2) # flatten all the anchor boxes 254 | positive_anc_coords = anc_boxes_flat[positive_anc_ind] 255 | 256 | # calculate gt offsets 257 | GT_offsets = calc_gt_offsets(positive_anc_coords, GT_bboxes_pos) 258 | 259 | # get -ve anchors 260 | 261 | # condition: select the anchor boxes with max iou less than the threshold 262 | negative_anc_mask = (max_iou_per_anc < neg_thresh) 263 | negative_anc_ind = torch.where(negative_anc_mask)[0] 264 | # sample -ve samples to match the +ve samples 265 | negative_anc_ind = negative_anc_ind[torch.randint(0, negative_anc_ind.shape[0], (positive_anc_ind.shape[0],))] 266 | negative_anc_coords = anc_boxes_flat[negative_anc_ind] 267 | 268 | return positive_anc_ind, negative_anc_ind, GT_conf_scores, GT_offsets, GT_class_pos, \ 269 | positive_anc_coords, negative_anc_coords, positive_anc_ind_sep 270 | 271 | # # -------------- Visualization utils ---------------- 272 | 273 | def display_img(img_data, fig, axes): 274 | for i, img in enumerate(img_data): 275 | if type(img) == torch.Tensor: 276 | img = img.permute(1, 2, 0).numpy() 277 | axes[i].imshow(img) 278 | 279 | return fig, axes 280 | 281 | def display_bbox(bboxes, fig, ax, classes=None, in_format='xyxy', color='y', line_width=3): 282 | if type(bboxes) == np.ndarray: 283 | bboxes = torch.from_numpy(bboxes) 284 | if classes: 285 | assert len(bboxes) == len(classes) 286 | # convert boxes to xywh format 287 | bboxes = ops.box_convert(bboxes, in_fmt=in_format, out_fmt='xywh') 288 | c = 0 289 | for box in bboxes: 290 | x, y, w, h = box.numpy() 291 | # display bounding box 292 | rect = patches.Rectangle((x, y), w, h, linewidth=line_width, edgecolor=color, facecolor='none') 293 | ax.add_patch(rect) 294 | # display category 295 | if classes: 296 | if classes[c] == 'pad': 297 | continue 298 | ax.text(x + 5, y + 20, classes[c], bbox=dict(facecolor='yellow', alpha=0.5)) 299 | c += 1 300 | 301 | return fig, ax 302 | 303 | def display_grid(x_points, y_points, fig, ax, special_point=None): 304 | # plot grid 305 | for x in x_points: 306 | for y in y_points: 307 | ax.scatter(x, y, color="w", marker='+') 308 | 309 | # plot a special point we want to emphasize on the grid 310 | if special_point: 311 | x, y = special_point 312 | ax.scatter(x, y, color="red", marker='+') 313 | 314 | return fig, ax -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-tutorials 2 | My notes, experiments, and implementations from learning pytorch from various resources. 
3 | -------------------------------------------------------------------------------- /Transformer/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wingedrasengan927/pytorch-tutorials/20ddc1fda3750550c8426799187276619a7f7610/Transformer/.DS_Store -------------------------------------------------------------------------------- /Transformer/pytorch transformer/language translation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "49227668-33df-4a3f-9fdf-be65548fc20e", 6 | "metadata": {}, 7 | "source": [ 8 | "## Language Translation using Pytorch Transformer" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "4e4dfca4-87b2-4dca-b5d7-1d224129b400", 14 | "metadata": {}, 15 | "source": [ 16 | "### Data Preparation" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "id": "b6faaf1c-0f7d-427f-ae5c-f643fc735770", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "from torchtext.data.utils import get_tokenizer\n", 27 | "from torchtext.vocab import build_vocab_from_iterator\n", 28 | "from torchtext.datasets import multi30k, Multi30k\n", 29 | "from typing import Iterable, List" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "b5b5a9e0-8c28-4bd5-9c46-f6c9668dab03", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "# modify urls to include new links\n", 40 | "multi30k.URL[\"train\"] = \"https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz\"\n", 41 | "multi30k.URL[\"valid\"] = \"https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz\"\n", 42 | "\n", 43 | "SRC_LANGUAGE = 'de'\n", 44 | "TGT_LANGUAGE = 'en'\n", 45 | "language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}\n", 46 | "\n", 47 | "# define placeholders\n", 48 | "token_transform = {}\n", 49 | "vocab_transform = {}\n", 50 | "\n", 51 | "# define spacy tokenizers\n", 52 | "token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')\n", 53 | "token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')\n", 54 | "\n", 55 | "# Define special symbols and their indices\n", 56 | "UNK_IDX, PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2, 3\n", 57 | "special_symbols = ['', '', '', '']\n", 58 | "\n", 59 | "# define helper function to yield list of tokens\n", 60 | "def yield_tokens(data_iter: Iterable, language: str) -> List[str]:\n", 61 | " lang_idx = language_index[language]\n", 62 | "\n", 63 | " for data_sample in data_iter:\n", 64 | " yield token_transform[language](data_sample[lang_idx])" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "id": "cd6faa94-3897-417e-b564-1be8ec7d18e2", 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stderr", 75 | "output_type": "stream", 76 | "text": [ 77 | "/Users/mmt9876/opt/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/datapipes/iter/combining.py:248: UserWarning: Some child DataPipes are not exhausted when __iter__ is called. We are resetting the buffer and each child DataPipe will read from the start again.\n", 78 | " warnings.warn(\"Some child DataPipes are not exhausted when __iter__ is called. 
We are resetting \"\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "# Create training data Iterator\n", 84 | "train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))\n", 85 | "\n", 86 | "# Create torchtext's Vocab object for each language\n", 87 | "for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:\n", 88 | " vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),\n", 89 | " min_freq=1,\n", 90 | " specials=special_symbols,\n", 91 | " special_first=True)\n", 92 | " \n", 93 | "# Set UNK_IDX as the default index. This index is returned when the token is not found.\n", 94 | "for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:\n", 95 | " vocab_transform[ln].set_default_index(UNK_IDX)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "id": "e7845aac-dff6-4c02-b08a-ecf8d0f46833", 101 | "metadata": {}, 102 | "source": [ 103 | "### Define the Seq2Seq Transformer Architecture" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "id": "3444c376-b79f-4517-a287-7bc996823584", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "from torch import Tensor\n", 114 | "import torch\n", 115 | "import torch.nn as nn\n", 116 | "from torch.nn import Transformer\n", 117 | "import math\n", 118 | "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 5, 124 | "id": "29f5a921-827d-4231-a69c-ac68fafcb682", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "class PositionalEncoding(nn.Module):\n", 129 | " def __init__(self,\n", 130 | " emb_dim: int,\n", 131 | " p_dropout: float,\n", 132 | " maxlen: int = 5000):\n", 133 | " super(PositionalEncoding, self).__init__()\n", 134 | " den = torch.exp(- torch.arange(0, emb_dim, 2)* math.log(10000) / emb_dim)\n", 135 | " pos = torch.arange(0, maxlen).reshape(maxlen, 1)\n", 136 | " pos_embedding = torch.zeros((maxlen, emb_dim))\n", 137 | " pos_embedding[:, 0::2] = torch.sin(pos * den)\n", 138 | " pos_embedding[:, 1::2] = torch.cos(pos * den)\n", 139 | " pos_embedding = pos_embedding.unsqueeze(-2)\n", 140 | "\n", 141 | " self.dropout = nn.Dropout(p_dropout)\n", 142 | " self.register_buffer('pos_embedding', pos_embedding)\n", 143 | "\n", 144 | " def forward(self, token_embedding: Tensor):\n", 145 | " return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 6, 151 | "id": "3e05c0ec-4fd2-4687-980a-01d63ef429e5", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "class TokenEmbedding(nn.Module):\n", 156 | " def __init__(self, vocab_size: int, emb_dim: int):\n", 157 | " super(TokenEmbedding, self).__init__()\n", 158 | " self.embedding = nn.Embedding(vocab_size, emb_dim)\n", 159 | " self.emb_dim = emb_dim\n", 160 | "\n", 161 | " def forward(self, tokens: Tensor):\n", 162 | " return self.embedding(tokens.long()) * math.sqrt(self.emb_dim)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 7, 168 | "id": "e7816d99-ac9f-4655-aff1-d49c89f9d613", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "class Generator(nn.Module):\n", 173 | " def __init__(self, emb_dim: int, vocab_size: int):\n", 174 | " super(Generator, self).__init__()\n", 175 | " self.proj = nn.Linear(emb_dim, vocab_size)\n", 176 | "\n", 177 | " def forward(self, x):\n", 178 | " return F.log_softmax(self.proj(x), dim=-1)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 
183 | "execution_count": 8, 184 | "id": "436ec6c8-66b3-4153-9b75-6fbbef78f084", 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "class Seq2SeqTransformer(nn.Module):\n", 189 | " def __init__(self, \n", 190 | " emb_dim: int,\n", 191 | " src_vocab_size: int,\n", 192 | " tgt_vocab_size: int,\n", 193 | " n_heads: int = 8,\n", 194 | " n_encoder_layers: int = 6,\n", 195 | " n_decoder_layers: int = 6,\n", 196 | " dim_feedforward: int = 2048, \n", 197 | " p_dropout: float = 0.1,\n", 198 | " batch_first: bool = True\n", 199 | " ):\n", 200 | " \n", 201 | " super(Seq2SeqTransformer, self).__init__()\n", 202 | "\n", 203 | " self.transformer = nn.Transformer(\n", 204 | " d_model = emb_dim,\n", 205 | " nhead=n_heads,\n", 206 | " num_encoder_layers=n_encoder_layers,\n", 207 | " num_decoder_layers=n_decoder_layers,\n", 208 | " dim_feedforward=dim_feedforward,\n", 209 | " dropout=p_dropout,\n", 210 | " batch_first = batch_first\n", 211 | " )\n", 212 | "\n", 213 | " self.generator = nn.Linear(emb_dim, tgt_vocab_size)\n", 214 | " self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_dim)\n", 215 | " self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_dim)\n", 216 | " self.positional_encoding = PositionalEncoding(emb_dim, p_dropout=p_dropout)\n", 217 | " \n", 218 | " def forward(self,\n", 219 | " src: Tensor,\n", 220 | " tgt: Tensor,\n", 221 | " src_mask: Tensor,\n", 222 | " tgt_mask: Tensor,\n", 223 | " src_padding_mask: Tensor,\n", 224 | " tgt_padding_mask: Tensor,\n", 225 | " memory_key_padding_mask: Tensor):\n", 226 | " \n", 227 | " src_emb = self.positional_encoding(self.src_tok_emb(src))\n", 228 | " tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))\n", 229 | " \n", 230 | " out = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,\n", 231 | " src_padding_mask, tgt_padding_mask, memory_key_padding_mask)\n", 232 | " \n", 233 | " return self.generator(out)\n", 234 | "\n", 235 | " def encode(self, src: Tensor, src_mask: Tensor):\n", 236 | " src_emb = self.positional_encoding(self.src_tok_emb(src))\n", 237 | " return self.transformer.encoder(src_emb, src_mask)\n", 238 | "\n", 239 | " def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):\n", 240 | " tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))\n", 241 | " return self.transformer.decoder(tgt_emb, memory, tgt_mask)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 9, 247 | "id": "76f75331-2a11-4787-b08b-1c786b6ce987", 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "def create_mask(src: Tensor, tgt: Tensor):\n", 252 | " \n", 253 | " # assume batch first is True\n", 254 | " src_seq_len = src.size(1)\n", 255 | " tgt_seq_len = tgt.size(1)\n", 256 | " \n", 257 | " tgt_mask = Transformer.generate_square_subsequent_mask(tgt_seq_len) # create subsequent\n", 258 | " # mask for target\n", 259 | " src_mask = torch.zeros((src_seq_len, src_seq_len),\n", 260 | " device=DEVICE).type(torch.bool) # no mask for source\n", 261 | "\n", 262 | " src_padding_mask = (src == PAD_IDX)\n", 263 | " tgt_padding_mask = (tgt == PAD_IDX)\n", 264 | " \n", 265 | " return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 10, 271 | "id": "589ebbad-8000-4e9a-b26f-e30f88e2c320", 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "torch.manual_seed(0)\n", 276 | "\n", 277 | "SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])\n", 278 | "TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])\n", 279 
| "EMB_DIM = 512\n", 280 | "N_HEADS = 8\n", 281 | "FFN_HID_DIM = 2048\n", 282 | "BATCH_SIZE = 128\n", 283 | "NUM_ENCODER_LAYERS = 6\n", 284 | "NUM_DECODER_LAYERS = 6\n", 285 | "\n", 286 | "transformer = Seq2SeqTransformer(EMB_DIM, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,\n", 287 | " N_HEADS, NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,\n", 288 | " FFN_HID_DIM)\n", 289 | "\n", 290 | "for p in transformer.parameters():\n", 291 | " if p.dim() > 1:\n", 292 | " nn.init.xavier_uniform_(p)\n", 293 | "\n", 294 | "transformer = transformer.to(DEVICE)\n", 295 | "\n", 296 | "loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)\n", 297 | "\n", 298 | "optimizer = torch.optim.Adam(transformer.parameters(), \n", 299 | " lr=0.0001, \n", 300 | " betas=(0.9, 0.98), \n", 301 | " eps=1e-9)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "id": "c0812726-798f-435c-b6c3-464aea51e65e", 307 | "metadata": {}, 308 | "source": [ 309 | "### Define Data Collation Fuctions" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 11, 315 | "id": "74c47abc-93f8-40d3-aead-ed17c7e632ff", 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "from torch.nn.utils.rnn import pad_sequence\n", 320 | "from torch.utils.data import DataLoader" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 12, 326 | "id": "82cb79ea-ad10-49c3-89b8-cf4e456d695e", 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "# helper function to club together sequential operations\n", 331 | "def sequential_transforms(*transforms):\n", 332 | " def func(txt_input):\n", 333 | " for transform in transforms:\n", 334 | " txt_input = transform(txt_input)\n", 335 | " return txt_input\n", 336 | " return func\n", 337 | "\n", 338 | "# function to add SOS/EOS tokens and create tensor for input sequence indices\n", 339 | "def tensor_transform(token_ids: List[int]):\n", 340 | " return torch.cat((torch.tensor([SOS_IDX]),\n", 341 | " torch.tensor(token_ids),\n", 342 | " torch.tensor([EOS_IDX])))\n", 343 | "\n", 344 | "# src and tgt language text transforms to convert raw strings into tensors indices\n", 345 | "text_transform = {}\n", 346 | "for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:\n", 347 | " text_transform[ln] = sequential_transforms(token_transform[ln], # Tokenization\n", 348 | " vocab_transform[ln], # Numericalization\n", 349 | " tensor_transform) # Add SOS/EOS and create tensor\n", 350 | "\n", 351 | "# function to collate data samples into batch tensors\n", 352 | "def collate_fn(batch):\n", 353 | " src_batch, tgt_batch = [], []\n", 354 | " for src_sample, tgt_sample in batch:\n", 355 | " src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip(\"\\n\")))\n", 356 | " tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip(\"\\n\")))\n", 357 | "\n", 358 | " src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=True)\n", 359 | " tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX, batch_first=True)\n", 360 | " \n", 361 | " return src_batch, tgt_batch\n" 362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "id": "8a4ab823-eb38-423b-a5ae-dc070d982b37", 367 | "metadata": {}, 368 | "source": [ 369 | "### Train the models" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 13, 375 | "id": "7b9fa59f-0e39-4899-b048-4aacc20a44b0", 376 | "metadata": {}, 377 | "outputs": [], 378 | "source": [ 379 | "def train_epoch(model, optimizer):\n", 380 | " total_loss = 0 # total loss for each epoch\n", 381 | " \n", 382 | " # the iterator 
needs to be initialized for each epoch\n", 383 | " train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))\n", 384 | " train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)\n", 385 | " \n", 386 | " n_samples = 0\n", 387 | "\n", 388 | " for src, tgt in train_dataloader:\n", 389 | " n_samples += src.size(0)\n", 390 | " \n", 391 | " src = src.to(DEVICE)\n", 392 | " tgt = tgt.to(DEVICE)\n", 393 | "\n", 394 | " tgt_input = tgt[:, :-1]\n", 395 | "\n", 396 | " src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)\n", 397 | "\n", 398 | " logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)\n", 399 | "\n", 400 | " tgt_out = tgt[:, 1:]\n", 401 | " \n", 402 | " loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))\n", 403 | " \n", 404 | " '''\n", 405 | " Note: we're passing token as tgt input (because of batching and padding)\n", 406 | " but it shouldn't matter because when computing the loss, the corresponding target \n", 407 | " to token would be token and we're ignoring token when \n", 408 | " computing the loss.\n", 409 | " '''\n", 410 | " \n", 411 | " optimizer.zero_grad()\n", 412 | " loss.backward()\n", 413 | " optimizer.step()\n", 414 | " \n", 415 | " total_loss += loss.item()\n", 416 | " \n", 417 | " return total_loss / n_samples" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": null, 423 | "id": "3e45a593-f7e7-4324-ae31-42d9715d3b36", 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "transformer.train()\n", 428 | "\n", 429 | "from timeit import default_timer as timer\n", 430 | "from tqdm import tqdm\n", 431 | "\n", 432 | "NUM_EPOCHS = 1\n", 433 | "\n", 434 | "for epoch in tqdm(range(1, NUM_EPOCHS+1)):\n", 435 | " start_time = timer()\n", 436 | " train_loss = train_epoch(transformer, optimizer)\n", 437 | " end_time = timer()\n", 438 | " print((f\"Epoch: {epoch}, Train loss: {train_loss:.3f}, \"f\"Epoch time = {(end_time - start_time):.3f}s\"))" 439 | ] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "id": "aef3d565-fff4-4e2f-9e67-97ee92f679ad", 444 | "metadata": {}, 445 | "source": [ 446 | "### Inference" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": null, 452 | "id": "5c236847-dd18-4835-827e-46972e8685f4", 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "transformer.eval()" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": null, 462 | "id": "b7b582df-005c-4c18-99ce-6d8e7523ed9d", 463 | "metadata": {}, 464 | "outputs": [], 465 | "source": [ 466 | "# function to generate output sequence using greedy algorithm\n", 467 | "def greedy_decode(model, src, src_mask, max_len, start_symbol):\n", 468 | " src = src.to(DEVICE)\n", 469 | " src_mask = src_mask.to(DEVICE)\n", 470 | "\n", 471 | " memory = model.encode(src, src_mask)\n", 472 | " ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)\n", 473 | " for i in range(max_len-1):\n", 474 | " memory = memory.to(DEVICE)\n", 475 | " tgt_mask = (Transformer.generate_square_subsequent_mask(ys.size(0))\n", 476 | " .type(torch.bool)).to(DEVICE)\n", 477 | " out = model.decode(ys, memory, tgt_mask)\n", 478 | " prob = model.generator(out[:, -1])\n", 479 | " _, next_word = torch.max(prob, dim=1)\n", 480 | " next_word = next_word.item()\n", 481 | "\n", 482 | " ys = torch.cat([ys,\n", 483 | " torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)\n", 484 | " 
if next_word == EOS_IDX:\n", 485 | " break\n", 486 | " return ys" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": null, 492 | "id": "7378d0df-5abb-4647-bf92-65b23b017806", 493 | "metadata": {}, 494 | "outputs": [], 495 | "source": [ 496 | "# actual function to translate input sentence into target language\n", 497 | "def translate(model: torch.nn.Module, src_sentence: str):\n", 498 | " src = text_transform[SRC_LANGUAGE](src_sentence).unsqueeze(0)\n", 499 | " num_tokens = src.size(1)\n", 500 | " src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)\n", 501 | " tgt_tokens = greedy_decode(\n", 502 | " model, src, src_mask, max_len=num_tokens + 5, start_symbol=SOS_IDX).flatten()\n", 503 | " return \" \".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace(\"\", \"\").replace(\"\", \"\")" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": null, 509 | "id": "098f84dd-fb0f-460a-80fe-60b7ce97b676", 510 | "metadata": {}, 511 | "outputs": [], 512 | "source": [ 513 | "print(translate(transformer, \"Eine Gruppe von Menschen steht vor einem Iglu .\"))" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": null, 519 | "id": "d381bd16-e364-4611-8c6f-cf42c9b21fe8", 520 | "metadata": {}, 521 | "outputs": [], 522 | "source": [] 523 | } 524 | ], 525 | "metadata": { 526 | "kernelspec": { 527 | "display_name": "Python 3 (ipykernel)", 528 | "language": "python", 529 | "name": "python3" 530 | }, 531 | "language_info": { 532 | "codemirror_mode": { 533 | "name": "ipython", 534 | "version": 3 535 | }, 536 | "file_extension": ".py", 537 | "mimetype": "text/x-python", 538 | "name": "python", 539 | "nbconvert_exporter": "python", 540 | "pygments_lexer": "ipython3", 541 | "version": "3.9.12" 542 | } 543 | }, 544 | "nbformat": 4, 545 | "nbformat_minor": 5 546 | } 547 | -------------------------------------------------------------------------------- /Transformer/the extra annotated transformer/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wingedrasengan927/pytorch-tutorials/20ddc1fda3750550c8426799187276619a7f7610/Transformer/the extra annotated transformer/.DS_Store -------------------------------------------------------------------------------- /Transformer/the extra annotated transformer/inference_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "\n", 11 | "from transformer import *" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "### Encoder Inference" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "#### Define Parameters" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 3, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "N=6 \n", 35 | "d_model=512 \n", 36 | "d_ff=2048 \n", 37 | "n_heads=8\n", 38 | "p_dropout=0.1\n", 39 | "inpt_vocab_size = 11" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "#### Build Models" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 4, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "attn = MultiHeadedAttention(n_heads, d_model)\n", 56 | "ff = 
PositionwiseFeedForward(d_model, d_ff)\n", 57 | "src_embedding = Embeddings(inpt_vocab_size, d_model)\n", 58 | "postion_embeds = PositionalEncoding(d_model, p_dropout)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 5, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "encoder = Encoder(EncoderBlock(d_model, attn, ff, p_dropout), N)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "#### Inference" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 6, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "src = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unsqueeze(0)\n", 84 | "src_mask = torch.ones(1, 1, 10)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 7, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "x = src_embedding(src)\n", 94 | "x = postion_embeds(x)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 8, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "memory = encoder(x, src_mask)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "### Decoder" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "#### Define Parameters" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 9, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "c = copy.deepcopy\n", 127 | "tgt_vocab_size = 20" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "#### Build Models" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 10, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "tgt_embedding = Embeddings(tgt_vocab_size, d_model)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 11, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "decoder = Decoder(DecoderBlock(d_model, c(attn), c(attn), c(ff), p_dropout), N)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "#### Inference" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 12, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "tgt = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]).unsqueeze(0)\n", 169 | "tgt_mask = subsequent_mask(tgt.size(1))" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 13, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "x = tgt_embedding(tgt)\n", 179 | "x = postion_embeds(x)\n", 180 | "out = decoder(x, memory, src_mask, tgt_mask)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "### Generator" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 14, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "generator = Generator(d_model, tgt_vocab_size)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 22, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "probs = generator(out)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 23, 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "data": { 215 | "text/plain": [ 216 | "torch.Size([1, 15, 20])" 217 | ] 218 | }, 219 | "execution_count": 23, 220 | "metadata": {}, 221 | "output_type": 
"execute_result" 222 | } 223 | ], 224 | "source": [ 225 | "probs.size()" 226 | ] 227 | } 228 | ], 229 | "metadata": { 230 | "interpreter": { 231 | "hash": "3f14271c985fdb52f632c7afa9846b0e350184d57d39e3537b80b6670080e2b2" 232 | }, 233 | "kernelspec": { 234 | "display_name": "Python 3.9.12 ('pytorch')", 235 | "language": "python", 236 | "name": "python3" 237 | }, 238 | "language_info": { 239 | "codemirror_mode": { 240 | "name": "ipython", 241 | "version": 3 242 | }, 243 | "file_extension": ".py", 244 | "mimetype": "text/x-python", 245 | "name": "python", 246 | "nbconvert_exporter": "python", 247 | "pygments_lexer": "ipython3", 248 | "version": "3.9.12" 249 | }, 250 | "orig_nbformat": 4 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 2 254 | } 255 | -------------------------------------------------------------------------------- /Transformer/the extra annotated transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import copy 5 | import math 6 | 7 | class EncoderDecoder(nn.Module): 8 | ''' 9 | A standard encoder-decoder architecture. 10 | ''' 11 | 12 | def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): 13 | super(EncoderDecoder, self).__init__() 14 | self.encoder = encoder 15 | self.decoder = decoder 16 | self.src_embed = src_embed 17 | self.tgt_embed = tgt_embed 18 | self.generator = generator 19 | 20 | def forward(self, src, tgt, src_mask, tgt_mask): 21 | ''' 22 | Process masked source and target sequences 23 | ''' 24 | return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask) 25 | 26 | def encode(self, src, src_mask): 27 | return self.encoder(self.src_embed(src), src_mask) 28 | 29 | def decode(self, memory, src_mask, tgt, tgt_mask): 30 | return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) 31 | 32 | class Generator(nn.Module): 33 | ''' 34 | Define Standard linear + softmax generation step 35 | ''' 36 | 37 | def __init__(self, d_model, vocab_size): 38 | super(Generator, self).__init__() 39 | self.proj = nn.Linear(d_model, vocab_size) 40 | 41 | def forward(self, x): 42 | return F.log_softmax(self.proj(x), dim=-1) 43 | 44 | def make_clones(module, N): 45 | ''' 46 | Produce N identical layers 47 | ''' 48 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 49 | 50 | class Encoder(nn.Module): 51 | ''' 52 | The Encoder is a stack of N Encoder Blocks each with their own weights. 53 | ''' 54 | 55 | def __init__(self, encoder_block, N): 56 | super(Encoder, self).__init__() 57 | self.blocks = make_clones(encoder_block, N) 58 | self.norm = nn.LayerNorm(encoder_block.d_model) 59 | 60 | def forward(self, x, mask): 61 | ''' 62 | Pass the input (and mask) through each layer in turn 63 | ''' 64 | for block in self.blocks: 65 | x = block(x, mask) 66 | # normalize the final output 67 | return self.norm(x) 68 | 69 | class SublayerConnection(nn.Module): 70 | ''' 71 | A skip layer connection. It takes the layer and the input to the layer 72 | and performs the skip layer operation: x + layer(x). 73 | In addition, the input is normalized before feeding into the layer and 74 | dropout is applied to the output of the layer. 75 | 76 | Parameters 77 | ----------- 78 | inpt_size - int 79 | The dimension of the input tensor. 
80 | p_dropout - float 81 | The dropout probability to be applied to the output of the layer 82 | ''' 83 | 84 | def __init__(self, inpt_size, p_dropout): 85 | super(SublayerConnection, self).__init__() 86 | self.norm = nn.LayerNorm(inpt_size) 87 | self.dropout = nn.Dropout(p_dropout) 88 | 89 | def forward(self, x, sublayer): 90 | ''' 91 | Apply residual connection to any sublayer of the same size 92 | ''' 93 | return x + self.dropout(sublayer(self.norm(x))) 94 | 95 | class EncoderBlock(nn.Module): 96 | ''' 97 | An Encoder Block consists of a multi-head self-attention layer followed 98 | by a feed-forward network with skip connections for each layer. 99 | 100 | Parameters 101 | ---------- 102 | d_model - int 103 | Size of the input dimension which is usually the embedding size 104 | self_attn - function which calls the module 105 | The multi-head self-attention layer 106 | feed_forward - nn.Module 107 | The feed-forward network 108 | p_dropout - float 109 | The dropout probability to be applied to the outputs of each layer 110 | ''' 111 | 112 | def __init__(self, d_model, self_attn, feed_forward, p_dropout): 113 | super(EncoderBlock, self).__init__() 114 | self.self_attn = self_attn 115 | self.feed_forward = feed_forward 116 | self.sublayer = make_clones(SublayerConnection(d_model, p_dropout), 2) 117 | self.d_model = d_model 118 | 119 | def forward(self, x, mask): 120 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 121 | return self.sublayer[1](x, self.feed_forward) 122 | 123 | class Decoder(nn.Module): 124 | ''' 125 | The Decoder is a stack of N Decoder Blocks each with their own weights. 126 | ''' 127 | 128 | def __init__(self, block, N): 129 | super(Decoder, self).__init__() 130 | self.blocks = make_clones(block, N) 131 | self.norm = nn.LayerNorm(block.d_model) 132 | 133 | def forward(self, x, memory, src_mask, tgt_mask): 134 | for block in self.blocks: 135 | x = block(x, memory, src_mask, tgt_mask) 136 | return self.norm(x) 137 | 138 | class DecoderBlock(nn.Module): 139 | ''' 140 | The Deocder Block is composed of a multi-head self attention layer followed by a 141 | multi-head encoder attention layer followed by a feed forward network with skip 142 | connections for each layer. 143 | ''' 144 | 145 | def __init__(self, d_model, self_attn, src_attn, feed_forward, p_dropout): 146 | super(DecoderBlock, self).__init__() 147 | self.d_model = d_model 148 | self.self_attn = self_attn 149 | self.src_attn = src_attn 150 | self.feed_forward = feed_forward 151 | self.sublayer = make_clones(SublayerConnection(d_model, p_dropout), 3) 152 | 153 | def forward(self, x, memory, src_mask, tgt_mask): 154 | ''' 155 | x - Floating point tensor of size (batch_size, tgt_sentence_length, embedding_dim) 156 | The target sequence tensor which is input to the decoder 157 | memory - Floating point tensor of size (batch_size, inpt_sentence_length, embedding_dim) 158 | The final output of the encoder. 159 | src_mask - Arbitary Tensor of shape (*, 1, inpt_sentence_length) 160 | mask to be applied on encoder-attention scores to hide certain words in the input sentence. 161 | tgt_mask - Arbitary Tensor of shape (*, tgt_sentence_length, tgt_sentence_length) 162 | Mask to be applied on self-attention scores in the decoder to prevent the current word from 163 | looking at subsequent future words. 
164 | ''' 165 | m = memory 166 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 167 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 168 | return self.sublayer[2](x, self.feed_forward) 169 | 170 | def subsequent_mask(size): 171 | ''' 172 | A mask that prevents the current word from looking at subsequent future words. 173 | Usually applied to the decoder input. 174 | ''' 175 | attn_shape = (1, size, size) 176 | mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8) 177 | return mask == 0 178 | 179 | def attention(query, key, value, mask=None, dropout=None): 180 | ''' 181 | Compute scaled dot product attention. 182 | 183 | Parameters 184 | ----------- 185 | query, key, value - Floating point tensors of shape (batch_size, n_heads, sentence_length, head_dim) 186 | The projected multi-head query, key, and value tensors. 187 | batch_size, no. of heads, and head dim should be the same for all of them. 188 | However, queries and keys can have different lengths, while keys and values 189 | should have the same length. 190 | mask - Tensor of arbitrary size (*, query_length, key_length) or (*, key_length) 191 | The mask will be broadcasted and applied across the raw attention scores computed from queries and keys. 192 | 193 | Returns 194 | -------- 195 | attn weighted values - Floating point tensors of shape (batch_size, n_heads, query_length, head_dim) 196 | The attention-weighted value tensors 197 | attn - Floating point tensors of shape (batch_size, n_heads, query_length, key_length) 198 | The attn probability scores computed for each head for a batch of queries and keys 199 | ''' 200 | d_k = query.size(-1) 201 | # (batch, heads, q_size, d_k) x (batch, heads, d_k, k_size) = (batch, heads, q_size, k_size) 202 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 203 | if mask is not None: 204 | scores = scores.masked_fill(mask == 0, -1e9) 205 | p_attn = F.softmax(scores, dim=-1) 206 | if dropout is not None: 207 | p_attn = dropout(p_attn) 208 | # (batch, heads, q_size, k_size) x (batch, heads, k_size, d_k) = (batch, heads, q_size, d_k) 209 | return torch.matmul(p_attn, value), p_attn 210 | 211 | class MultiHeadedAttention(nn.Module): 212 | ''' 213 | Compute multi-head attention 214 | 215 | Parameters 216 | ----------- 217 | n_heads - int 218 | Number of heads the query, key, and value tensors should be split into 219 | d_model - int 220 | Size of the input dimension which is usually the embedding size 221 | p_dropout - float 222 | Dropout probability to be applied to the attention scores 223 | ''' 224 | 225 | def __init__(self, n_heads, d_model, p_dropout=0.1): 226 | super(MultiHeadedAttention, self).__init__() 227 | assert d_model % n_heads == 0 228 | # we assume d_v always equals d_k 229 | self.d_k = d_model // n_heads # dim of each head 230 | self.n_heads = n_heads 231 | # the first three layers project the input into queries, keys, and values. 232 | # the last layer performs a 1x1 convolution operation on the final concatenated attention vector.
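# note: each of the four projection layers maps d_model -> d_model; the split into n_heads heads of size
# d_k = d_model // n_heads happens later in forward() via view + transpose (for example, with hypothetical
# values d_model=512 and n_heads=8, each head works with d_k = 64).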
233 | self.linear_layers = make_clones(nn.Linear(d_model, d_model), 4) 234 | self.attn = None # placeholder for attention scores 235 | self.dropout = nn.Dropout(p_dropout) 236 | 237 | def forward(self, query, key, value, mask=None): 238 | 239 | if mask is not None: 240 | # broadcast the same mask to all the heads 241 | mask = mask.unsqueeze(1) 242 | 243 | batch_size = query.size(0) 244 | 245 | # 1) project the input into queries, keys, and values 246 | # and split each of them into multiple heads: d_model => n_heads x d_k 247 | query, key, value = [ 248 | lin_layer(x).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) 249 | for lin_layer, x in zip(self.linear_layers, (query, key, value)) 250 | ] 251 | 252 | # 2) Apply attention on all the projected multi-head tensors in batch 253 | x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout) 254 | 255 | # 3) Concat all the heads using a view 256 | x = ( 257 | x.transpose(1, 2) 258 | .contiguous() 259 | .view(batch_size, -1, self.n_heads * self.d_k) 260 | ) 261 | 262 | del query 263 | del key 264 | del value 265 | 266 | # apply a final linear layer (1 x 1 convolution) 267 | return self.linear_layers[-1](x) 268 | 269 | class PositionwiseFeedForward(nn.Module): 270 | ''' 271 | A simple feed-forward network applied on the attention tensor at the end 272 | 273 | Parameters 274 | ----------- 275 | d_model - int 276 | Size of the input dimension which is usually the embedding size 277 | dff - int 278 | size of hidden state in the network 279 | p_dropout - float 280 | dropout probability to be applied to the hidden layer 281 | ''' 282 | def __init__(self, d_model, d_ff, p_dropout=0.1): 283 | super(PositionwiseFeedForward, self).__init__() 284 | self.w_1 = nn.Linear(d_model, d_ff) 285 | self.w_2 = nn.Linear(d_ff, d_model) 286 | self.dropout = nn.Dropout(p_dropout) 287 | 288 | def forward(self, x): 289 | out = F.relu(self.w_1(x)) 290 | out = self.dropout(out) 291 | out = self.w_2(out) 292 | return out 293 | 294 | class Embeddings(nn.Module): 295 | def __init__(self, vocab_size, d_model): 296 | super(Embeddings, self).__init__() 297 | self.embedding = nn.Embedding(vocab_size, d_model) 298 | self.d_model = d_model 299 | 300 | def forward(self, x): 301 | return self.embedding(x) * math.sqrt(self.d_model) 302 | 303 | class PositionalEncoding(nn.Module): 304 | 305 | def __init__(self, d_model, p_dropout, max_len=5000): 306 | super(PositionalEncoding, self).__init__() 307 | self.dropout = nn.Dropout(p_dropout) 308 | 309 | # compute positional encodings once in log space 310 | pe = torch.zeros(max_len, d_model) 311 | position = torch.arange(0, max_len).unsqueeze(1) 312 | div_term = torch.exp( 313 | torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model) 314 | ) 315 | pe[:, 0::2] = torch.sin(position * div_term) 316 | pe[:, 1::2] = torch.cos(position * div_term) 317 | pe = pe.unsqueeze(0) 318 | self.register_buffer("pe", pe) 319 | 320 | def forward(self, x): 321 | x = x + self.pe[:, :x.size(1)].requires_grad_(False) 322 | return self.dropout(x) 323 | -------------------------------------------------------------------------------- /attention/basic transformer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ea7dac6f", 6 | "metadata": {}, 7 | "source": [ 8 | "## Basic Transformer" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 2, 14 | "id": "297db42b", 15 | "metadata": {}, 16 | "outputs": [], 17 | 
"source": [ 18 | "import numpy as np\n", 19 | "\n", 20 | "import torch\n", 21 | "import torch.nn as nn\n", 22 | "import torch.nn.functional as F" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "6a0f7fea", 28 | "metadata": {}, 29 | "source": [ 30 | "In this exercise, we give a network an input sequence of characters (e.g., **aabbccdd**), the network is trained to produce the same sequence in reverse order (**ddccbbaa**)." 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "id": "6fae86b9", 36 | "metadata": {}, 37 | "source": [ 38 | "### Data Preparation" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "id": "ed6a70f2", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# create a dataset of 500 examples using a small vocabulary\n", 49 | "# we will try training on sequences of length 10 and testing on sequences of length 15\n", 50 | "# this setup tests whether the model has actually learned an algorithm to reverse its input\n", 51 | "vocab = {'a': 0, 'b': 1, 'c':2, 'd':3, 'e':4, '':5, '':6}\n", 52 | "idx_to_w = dict((v, k) for (k,v) in vocab.items())\n", 53 | "train_seq_len = 10\n", 54 | "num_train_examples = 5" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "id": "2e62d133", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# generate toy data\n", 65 | "train_inputs = torch.LongTensor(num_train_examples, train_seq_len).random_(0\n", 66 | ", len(vocab)-2) # random sequences\n", 67 | "inv_idx = torch.arange(train_seq_len-1, -1, -1).long()\n", 68 | "train_outputs = train_inputs[:, inv_idx] # outputs are just the reverse of the input\n", 69 | "sos_vec = torch.LongTensor(num_train_examples, 1)\n", 70 | "sos_vec[:] = vocab['']\n", 71 | "eos_vec = torch.LongTensor(num_train_examples, 1)\n", 72 | "eos_vec[:] = vocab['']\n", 73 | "train_encoder_input = torch.cat((train_inputs, eos_vec), 1)\n", 74 | "train_decoder_input = torch.cat((sos_vec, train_outputs), 1)\n", 75 | "train_targets = torch.cat((train_outputs, eos_vec), 1)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 5, 81 | "id": "af8839c1", 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "encoder input : e c a a e e b e a b \n", 89 | "decoder input: b a e b e e a a c e\n", 90 | "decoder target: b a e b e e a a c e \n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "print('encoder input :', ' '.join([idx_to_w[w] for w in train_encoder_input[0].numpy()]))\n", 96 | "print('decoder input:', ' '.join([idx_to_w[w] for w in train_decoder_input[0].numpy()]))\n", 97 | "print('decoder target:', ' '.join([idx_to_w[w] for w in train_targets[0].numpy()]))" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "id": "2a86afc4", 103 | "metadata": {}, 104 | "source": [ 105 | "### Build Model" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "id": "9bbef16a", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "torch.set_printoptions(precision=2, sci_mode=False)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 7, 121 | "id": "615a81ec", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "# class for our vanilla seq2seq\n", 126 | "class Seq2Seq(nn.Module):\n", 127 | " def __init__(self, char_dim, hidden_size, vocab_size):\n", 128 | " super().__init__()\n", 129 | " \n", 130 | " self.char_dim = char_dim\n", 131 | " self.hidden_size = hidden_size\n", 132 | " 
self.vocab_size = vocab_size\n", 133 | "\n", 134 | " # character embeddings\n", 135 | " self.char_embeds = nn.Embedding(vocab_size, char_dim)\n", 136 | " \n", 137 | " # position embeddings\n", 138 | " self.pos_embeds = nn.Embedding(15, char_dim) # add these to enc/dec\n", 139 | " \n", 140 | " # decoder attention\n", 141 | " self.query = nn.Linear(char_dim, hidden_size)\n", 142 | " self.key = nn.Linear(char_dim, hidden_size)\n", 143 | " self.value = nn.Linear(char_dim, hidden_size)\n", 144 | " \n", 145 | " # output layer (softmax will be applied after this)\n", 146 | " self.cls = nn.Linear(hidden_size, vocab_size)\n", 147 | " \n", 148 | " # a vectorized way of computing self attention for all queries efficiently\n", 149 | " def smart_unmasked_attn(self, qs, ks, vs):\n", 150 | " # here, queries are decoder states, keys and values are encoder representations\n", 151 | " scores = qs @ ks.t() # get all dot products at once, N x N\n", 152 | " scores = F.softmax(scores, dim=1)\n", 153 | " return scores @ vs # N x hidden_size \n", 154 | " \n", 155 | " # a vectorized way of computing **target-side self-attention**\n", 156 | " # we need to implement some masking to avoid cheating!\n", 157 | " def smart_masked_attn(self, qs, ks, vs):\n", 158 | " max_len = qs.size(0)\n", 159 | " mask = torch.tril(torch.ones(max_len, max_len))\n", 160 | " scores = qs @ ks.t() # get all UNMASKED dot products at once, max_len X max_len\n", 161 | " scores = scores.masked_fill(mask == 0, -1e9)\n", 162 | " scores = F.softmax(scores, dim=1)\n", 163 | " return scores @ vs\n", 164 | "\n", 165 | " \n", 166 | " def forward(self, inputs, decoder_inputs):\n", 167 | " \n", 168 | " batch_size, max_len = inputs.size()\n", 169 | "\n", 170 | " positions = torch.arange(0, inputs.size(1))\n", 171 | " pos_embeds = self.pos_embeds(positions)\n", 172 | " \n", 173 | " # we'll just consider this the output of our encoder\n", 174 | " # of course in a real transformer this would be computed\n", 175 | " # through multiple self attention blocks\n", 176 | " e_embeds = self.char_embeds(inputs).squeeze(0)\n", 177 | " e_embeds = e_embeds + pos_embeds\n", 178 | " e_keys = self.key(e_embeds)\n", 179 | " e_values = self.value(e_embeds)\n", 180 | " \n", 181 | " # we'll use the same weights to project decoder embeddings to q,k,v\n", 182 | " d_embeds = self.char_embeds(decoder_inputs).squeeze(0)\n", 183 | " d_embeds = d_embeds + pos_embeds\n", 184 | " d_queries = self.query(d_embeds)\n", 185 | " d_keys = self.key(d_embeds)\n", 186 | " d_values = self.value(d_embeds)\n", 187 | "\n", 188 | " # compute target side self attention\n", 189 | " fast_decoder_states = self.smart_masked_attn(d_queries, d_keys, d_values)\n", 190 | "\n", 191 | " # source attention, queries come from decoder, keys/values from encoder\n", 192 | " source_attn = self.smart_unmasked_attn(fast_decoder_states, e_keys, e_values)\n", 193 | " \n", 194 | " # combine decoder self attention w/ source attention\n", 195 | " source_attn = source_attn + fast_decoder_states\n", 196 | "\n", 197 | " # now do prediction over decoder states (reshape to 2d first)\n", 198 | " source_attn = source_attn.transpose(0, 1).contiguous().view(-1, self.hidden_size)\n", 199 | " decoder_preds = self.cls(source_attn)\n", 200 | " decoder_preds = F.log_softmax(decoder_preds, dim=1)\n", 201 | "\n", 202 | " return decoder_preds" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "id": "fc408ff6", 208 | "metadata": {}, 209 | "source": [ 210 | "### Train the model" 211 | ] 212 | }, 213 | { 214 | "cell_type": 
"code", 215 | "execution_count": 8, 216 | "id": "ce8cb377", 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "def training_loop(net):\n", 221 | "\n", 222 | " # set some hyperparameters for training the network\n", 223 | " idx_to_w = dict((v,k) for (k,v) in vocab.items())\n", 224 | " loss_fn = nn.NLLLoss()\n", 225 | " optimizer = torch.optim.Adam(net.parameters(), lr=0.01)\n", 226 | " num_epochs = 10\n", 227 | " \n", 228 | " # okay, let's train the network!\n", 229 | " for ep in range(num_epochs):\n", 230 | " ep_loss = 0.\n", 231 | "\n", 232 | " for start in range(0, len(train_inputs)):\n", 233 | " e_in_batch = train_encoder_input[start].unsqueeze(0)\n", 234 | " d_in_batch = train_decoder_input[start].unsqueeze(0)\n", 235 | " d_targ_batch = train_targets[start].unsqueeze(0)\n", 236 | " \n", 237 | " preds = net(e_in_batch, d_in_batch)\n", 238 | " batch_loss = loss_fn(preds, d_targ_batch.view(-1))\n", 239 | " ep_loss += batch_loss\n", 240 | "\n", 241 | " # compute gradients\n", 242 | " optimizer.zero_grad() # reset the gradients from the last batch\n", 243 | " batch_loss.backward() # does backprop!!!\n", 244 | " optimizer.step() # updates parameters using gradients\n", 245 | "\n", 246 | " print('epoch %d, loss %f\\n' % (ep, ep_loss))" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 9, 252 | "id": "a2160890", 253 | "metadata": {}, 254 | "outputs": [ 255 | { 256 | "name": "stdout", 257 | "output_type": "stream", 258 | "text": [ 259 | "epoch 0, loss 9.770307\n", 260 | "\n", 261 | "epoch 1, loss 5.881337\n", 262 | "\n", 263 | "epoch 2, loss 4.329915\n", 264 | "\n", 265 | "epoch 3, loss 3.894290\n", 266 | "\n", 267 | "epoch 4, loss 2.730859\n", 268 | "\n", 269 | "epoch 5, loss 1.853093\n", 270 | "\n", 271 | "epoch 6, loss 1.146887\n", 272 | "\n", 273 | "epoch 7, loss 0.653309\n", 274 | "\n", 275 | "epoch 8, loss 0.368380\n", 276 | "\n", 277 | "epoch 9, loss 0.222095\n", 278 | "\n" 279 | ] 280 | } 281 | ], 282 | "source": [ 283 | "# build the network\n", 284 | "net = Seq2Seq(32, 64, len(vocab))\n", 285 | "training_loop(net)" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "id": "35cf0ce7", 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [] 295 | } 296 | ], 297 | "metadata": { 298 | "kernelspec": { 299 | "display_name": "Python 3 (ipykernel)", 300 | "language": "python", 301 | "name": "python3" 302 | }, 303 | "language_info": { 304 | "codemirror_mode": { 305 | "name": "ipython", 306 | "version": 3 307 | }, 308 | "file_extension": ".py", 309 | "mimetype": "text/x-python", 310 | "name": "python", 311 | "nbconvert_exporter": "python", 312 | "pygments_lexer": "ipython3", 313 | "version": "3.7.11" 314 | } 315 | }, 316 | "nbformat": 4, 317 | "nbformat_minor": 5 318 | } 319 | -------------------------------------------------------------------------------- /attention/self attention.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "6e0aa209", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import torch.nn as nn\n", 12 | "import torch.nn.functional as F" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "113b7f88", 18 | "metadata": {}, 19 | "source": [ 20 | "### Data Preparation" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "id": "83098fa6", 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": 
"stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "[[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14], [15, 16, 12, 17], [18, 19, 20, 21, 22, 23, 24]]\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "# spam detection!\n", 39 | "data = ['you won a billion dollars , great work !',\n", 40 | " 'click here for cs685 midterm answers',\n", 41 | " 'read important cs685 news',\n", 42 | " 'send me your bank account info asap']\n", 43 | "\n", 44 | "labels = torch.LongTensor([1, 1, 0, 1]) # store ground-truth labels\n", 45 | "\n", 46 | "# let's do some preprocessing\n", 47 | "vocab = {}\n", 48 | "inputs = []\n", 49 | "\n", 50 | "for sentence in data:\n", 51 | " idxs = []\n", 52 | " sentence = sentence.split()\n", 53 | " for word in sentence:\n", 54 | " if word not in vocab:\n", 55 | " vocab[word] = len(vocab)\n", 56 | " idxs.append(vocab[word])\n", 57 | " inputs.append(idxs)\n", 58 | " \n", 59 | "print(inputs)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "93ee681c", 65 | "metadata": {}, 66 | "source": [ 67 | "### Build the model" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "id": "29ed09a0", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "class SelfAttentionNN(nn.Module):\n", 78 | " \n", 79 | " def __init__(self, embedding_dim, vocab_size):\n", 80 | " \n", 81 | " super().__init__()\n", 82 | " self.embedding_dim = embedding_dim\n", 83 | " self.embeddings = nn.Embedding(vocab_size, embedding_dim)\n", 84 | " \n", 85 | " self.Wq = nn.Linear(embedding_dim, embedding_dim) # project to query space\n", 86 | " self.Wk = nn.Linear(embedding_dim, embedding_dim) # project to keys\n", 87 | " self.Wv = nn.Linear(embedding_dim, embedding_dim) # project to values\n", 88 | " \n", 89 | " # the final classification layer\n", 90 | " self.cls = nn.Linear(embedding_dim, 2)\n", 91 | " \n", 92 | " # all three args are T x embedding_dim matrices!\n", 93 | " def dot_product_attn(self, q, k, v):\n", 94 | " scores = q @ k.t() # gets all dot products at once, T X T\n", 95 | " scores = F.softmax(scores, dim=1)\n", 96 | " return scores @ v # T x embedding_dim\n", 97 | " \n", 98 | " # you can implement the three below for fun!\n", 99 | " def bilinear_attn(self, q, k):\n", 100 | " pass\n", 101 | " \n", 102 | " def scaled_dot_product_attn(self, q, k):\n", 103 | " pass\n", 104 | " \n", 105 | " def mlp_attn(self, q, k):\n", 106 | " pass\n", 107 | " \n", 108 | " def forward(self, inpt_sentence):\n", 109 | " T = inpt_sentence.size(0) # number of tokens in input, assume T > 2\n", 110 | " word_embeds = self.embeddings(inpt_sentence) # T x embedding_dim\n", 111 | " \n", 112 | " queries = self.Wq(word_embeds) # T x embedding_dim\n", 113 | " keys = self.Wk(word_embeds) # T x embedding_dim\n", 114 | " values = self.Wv(word_embeds) # T x embedding_dim\n", 115 | "\n", 116 | " # efficient attention computation\n", 117 | " attn_reps = self.dot_product_attn(queries, keys, values)\n", 118 | "\n", 119 | " # compose attn_reps into a single vector\n", 120 | " attn_reps = torch.mean(attn_reps, dim=0)\n", 121 | "\n", 122 | " pred = self.cls(attn_reps) # return logits\n", 123 | " return pred.unsqueeze(0)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "id": "c961e478", 129 | "metadata": {}, 130 | "source": [ 131 | "#### Test Inference" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 4, 137 | "id": "aa1d4b2c", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "embedding_dim = 32\n", 142 | "vocab_size = len(vocab)" 
143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 5, 148 | "id": "cef5dd0f", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "model = SelfAttentionNN(embedding_dim, vocab_size)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 6, 158 | "id": "2207e18b", 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "tensor([[-0.2150, 0.3468]])\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "sample_input = torch.LongTensor([1, 2, 3, 4])\n", 171 | "\n", 172 | "with torch.no_grad():\n", 173 | " out = model(sample_input)\n", 174 | " print(out)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "id": "3bb039b9", 180 | "metadata": {}, 181 | "source": [ 182 | "### Train the model" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 7, 188 | "id": "f78f70fe", 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "num_epochs = 10\n", 193 | "loss_fn = nn.CrossEntropyLoss()\n", 194 | "optim = torch.optim.SGD(model.parameters(), lr = 0.1)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 8, 200 | "id": "ba52a9ee", 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "name": "stdout", 205 | "output_type": "stream", 206 | "text": [ 207 | "0 2.8318784534931183\n", 208 | "1 1.7424476146697998\n", 209 | "2 0.9125520139932632\n", 210 | "3 0.44691064208745956\n", 211 | "4 0.2501957528293133\n", 212 | "5 0.16152603924274445\n", 213 | "6 0.11486340966075659\n", 214 | "7 0.08715852349996567\n", 215 | "8 0.06921625044196844\n", 216 | "9 0.05681771645322442\n" 217 | ] 218 | } 219 | ], 220 | "source": [ 221 | "# training loop\n", 222 | "for epoch in range(num_epochs):\n", 223 | " ep_loss = 0. 
# loss per epoch\n", 224 | " \n", 225 | " for i in range(len(inputs)):\n", 226 | " # get input sentence and target label\n", 227 | " inpt_sentence = torch.LongTensor(inputs[i])\n", 228 | " target = labels[i].unsqueeze(0)\n", 229 | " \n", 230 | " pred = model(inpt_sentence)\n", 231 | " loss = loss_fn(pred, target)\n", 232 | " \n", 233 | " optim.zero_grad()\n", 234 | " loss.backward()\n", 235 | " optim.step()\n", 236 | " \n", 237 | " ep_loss += loss.item()\n", 238 | " \n", 239 | " print(epoch, ep_loss)" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "id": "9914a28c", 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [] 249 | } 250 | ], 251 | "metadata": { 252 | "kernelspec": { 253 | "display_name": "Python 3 (ipykernel)", 254 | "language": "python", 255 | "name": "python3" 256 | }, 257 | "language_info": { 258 | "codemirror_mode": { 259 | "name": "ipython", 260 | "version": 3 261 | }, 262 | "file_extension": ".py", 263 | "mimetype": "text/x-python", 264 | "name": "python", 265 | "nbconvert_exporter": "python", 266 | "pygments_lexer": "ipython3", 267 | "version": "3.9.12" 268 | } 269 | }, 270 | "nbformat": 4, 271 | "nbformat_minor": 5 272 | } 273 | -------------------------------------------------------------------------------- /charRNN/charRNN_batching.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wingedrasengan927/pytorch-tutorials/20ddc1fda3750550c8426799187276619a7f7610/charRNN/charRNN_batching.jpg -------------------------------------------------------------------------------- /charRNN/generated_text_out.txt: -------------------------------------------------------------------------------- 1 | Once upon a time. 2 | 3 | "This were in success." 4 | 5 | "No, this is through the matter to them to me wonse for my feeling. I'm a 6 | pretension of my carriage." 7 | 8 | "Well, I'm not going, and I won't be in the made. 9 | 10 | "I'll say about your hand." But the conversation he went up to his wife. 11 | Here was they would not go away, and winging to all. She was not, as he 12 | could not have said that his wife had been trying to discover him, but 13 | he could not have to drunk in his brother's face, and seemed at once 14 | taken again, she herself went on still more and made her feeling and did not, that 15 | he had back her heart so much. Her feet. He was alroad in a steather in a 16 | shade words and a corture of his heart to herself. 17 | 18 | "Ah, I'm touching to you," she added, and he was sorry and the 19 | paper--shall waiting to share her arrively to his hat. And the 20 | carriage and her head and she went to by somewhere as they he had the 21 | same time he could not ask him; but the point would be desciret on a 22 | plincing, as a pleasure his senth of the subject of the world when 23 | he had so going to have a subject and they should get through all the drove 24 | whe his chief worries of till. 25 | 26 | "Ah!'" said the same thing alwads and working. She was that she 27 | called her husband. 28 | 29 | "Yes, but I have to drive to so married as to horror. They are all 30 | think of it. Here are new to this mind, while they have so much as 31 | it is not, but I can come in to the most intense of his better 32 | to straight in their position," said Levin at the sound of 33 | her heart. "Where worse are you to go?" had the letter said to her that, and 34 | he was at all that she was told him to be absolbed to him towards 35 | them, she did not said a long while. 
36 | 37 | She does always hope of his heart, to see the peasants and at the convincon, 38 | as in the children with who had seen her this steps, and took off 39 | his hair, annerstonded her head on the carriage and his brothing significance. 40 | 41 | "When you weren't there? Yes, I'll go trisitals. If it is all alone. And to 42 | be angry. And it more had been as about them," he said. "Why do -------------------------------------------------------------------------------- /pytorch basics/README.md: -------------------------------------------------------------------------------- 1 | ## Pytorch Basics 2 | My notes, experiments, implementations from the book ["Deep Learning with pytorch"](https://pytorch.org/assets/deep-learning/Deep-Learning-with-PyTorch.pdf) -------------------------------------------------------------------------------- /pytorch basics/imagenet_classes.txt: -------------------------------------------------------------------------------- 1 | tench, Tinca tinca 2 | goldfish, Carassius auratus 3 | great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias 4 | tiger shark, Galeocerdo cuvieri 5 | hammerhead, hammerhead shark 6 | electric ray, crampfish, numbfish, torpedo 7 | stingray 8 | cock 9 | hen 10 | ostrich, Struthio camelus 11 | brambling, Fringilla montifringilla 12 | goldfinch, Carduelis carduelis 13 | house finch, linnet, Carpodacus mexicanus 14 | junco, snowbird 15 | indigo bunting, indigo finch, indigo bird, Passerina cyanea 16 | robin, American robin, Turdus migratorius 17 | bulbul 18 | jay 19 | magpie 20 | chickadee 21 | water ouzel, dipper 22 | kite 23 | bald eagle, American eagle, Haliaeetus leucocephalus 24 | vulture 25 | great grey owl, great gray owl, Strix nebulosa 26 | European fire salamander, Salamandra salamandra 27 | common newt, Triturus vulgaris 28 | eft 29 | spotted salamander, Ambystoma maculatum 30 | axolotl, mud puppy, Ambystoma mexicanum 31 | bullfrog, Rana catesbeiana 32 | tree frog, tree-frog 33 | tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui 34 | loggerhead, loggerhead turtle, Caretta caretta 35 | leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea 36 | mud turtle 37 | terrapin 38 | box turtle, box tortoise 39 | banded gecko 40 | common iguana, iguana, Iguana iguana 41 | American chameleon, anole, Anolis carolinensis 42 | whiptail, whiptail lizard 43 | agama 44 | frilled lizard, Chlamydosaurus kingi 45 | alligator lizard 46 | Gila monster, Heloderma suspectum 47 | green lizard, Lacerta viridis 48 | African chameleon, Chamaeleo chamaeleon 49 | Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis 50 | African crocodile, Nile crocodile, Crocodylus niloticus 51 | American alligator, Alligator mississipiensis 52 | triceratops 53 | thunder snake, worm snake, Carphophis amoenus 54 | ringneck snake, ring-necked snake, ring snake 55 | hognose snake, puff adder, sand viper 56 | green snake, grass snake 57 | king snake, kingsnake 58 | garter snake, grass snake 59 | water snake 60 | vine snake 61 | night snake, Hypsiglena torquata 62 | boa constrictor, Constrictor constrictor 63 | rock python, rock snake, Python sebae 64 | Indian cobra, Naja naja 65 | green mamba 66 | sea snake 67 | horned viper, cerastes, sand viper, horned asp, Cerastes cornutus 68 | diamondback, diamondback rattlesnake, Crotalus adamanteus 69 | sidewinder, horned rattlesnake, Crotalus cerastes 70 | trilobite 71 | harvestman, daddy longlegs, Phalangium opilio 72 | scorpion 73 | black and gold garden 
spider, Argiope aurantia 74 | barn spider, Araneus cavaticus 75 | garden spider, Aranea diademata 76 | black widow, Latrodectus mactans 77 | tarantula 78 | wolf spider, hunting spider 79 | tick 80 | centipede 81 | black grouse 82 | ptarmigan 83 | ruffed grouse, partridge, Bonasa umbellus 84 | prairie chicken, prairie grouse, prairie fowl 85 | peacock 86 | quail 87 | partridge 88 | African grey, African gray, Psittacus erithacus 89 | macaw 90 | sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita 91 | lorikeet 92 | coucal 93 | bee eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | drake 99 | red-breasted merganser, Mergus serrator 100 | goose 101 | black swan, Cygnus atratus 102 | tusker 103 | echidna, spiny anteater, anteater 104 | platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus 105 | wallaby, brush kangaroo 106 | koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus 107 | wombat 108 | jellyfish 109 | sea anemone, anemone 110 | brain coral 111 | flatworm, platyhelminth 112 | nematode, nematode worm, roundworm 113 | conch 114 | snail 115 | slug 116 | sea slug, nudibranch 117 | chiton, coat-of-mail shell, sea cradle, polyplacophore 118 | chambered nautilus, pearly nautilus, nautilus 119 | Dungeness crab, Cancer magister 120 | rock crab, Cancer irroratus 121 | fiddler crab 122 | king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica 123 | American lobster, Northern lobster, Maine lobster, Homarus americanus 124 | spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish 125 | crayfish, crawfish, crawdad, crawdaddy 126 | hermit crab 127 | isopod 128 | white stork, Ciconia ciconia 129 | black stork, Ciconia nigra 130 | spoonbill 131 | flamingo 132 | little blue heron, Egretta caerulea 133 | American egret, great white heron, Egretta albus 134 | bittern 135 | crane 136 | limpkin, Aramus pictus 137 | European gallinule, Porphyrio porphyrio 138 | American coot, marsh hen, mud hen, water hen, Fulica americana 139 | bustard 140 | ruddy turnstone, Arenaria interpres 141 | red-backed sandpiper, dunlin, Erolia alpina 142 | redshank, Tringa totanus 143 | dowitcher 144 | oystercatcher, oyster catcher 145 | pelican 146 | king penguin, Aptenodytes patagonica 147 | albatross, mollymawk 148 | grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus 149 | killer whale, killer, orca, grampus, sea wolf, Orcinus orca 150 | dugong, Dugong dugon 151 | sea lion 152 | Chihuahua 153 | Japanese spaniel 154 | Maltese dog, Maltese terrier, Maltese 155 | Pekinese, Pekingese, Peke 156 | Shih-Tzu 157 | Blenheim spaniel 158 | papillon 159 | toy terrier 160 | Rhodesian ridgeback 161 | Afghan hound, Afghan 162 | basset, basset hound 163 | beagle 164 | bloodhound, sleuthhound 165 | bluetick 166 | black-and-tan coonhound 167 | Walker hound, Walker foxhound 168 | English foxhound 169 | redbone 170 | borzoi, Russian wolfhound 171 | Irish wolfhound 172 | Italian greyhound 173 | whippet 174 | Ibizan hound, Ibizan Podenco 175 | Norwegian elkhound, elkhound 176 | otterhound, otter hound 177 | Saluki, gazelle hound 178 | Scottish deerhound, deerhound 179 | Weimaraner 180 | Staffordshire bullterrier, Staffordshire bull terrier 181 | American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier 182 | Bedlington terrier 183 | Border terrier 184 | Kerry blue terrier 185 | Irish terrier 186 | Norfolk terrier 187 | Norwich terrier 188 | Yorkshire terrier 
189 | wire-haired fox terrier 190 | Lakeland terrier 191 | Sealyham terrier, Sealyham 192 | Airedale, Airedale terrier 193 | cairn, cairn terrier 194 | Australian terrier 195 | Dandie Dinmont, Dandie Dinmont terrier 196 | Boston bull, Boston terrier 197 | miniature schnauzer 198 | giant schnauzer 199 | standard schnauzer 200 | Scotch terrier, Scottish terrier, Scottie 201 | Tibetan terrier, chrysanthemum dog 202 | silky terrier, Sydney silky 203 | soft-coated wheaten terrier 204 | West Highland white terrier 205 | Lhasa, Lhasa apso 206 | flat-coated retriever 207 | curly-coated retriever 208 | golden retriever 209 | Labrador retriever 210 | Chesapeake Bay retriever 211 | German short-haired pointer 212 | vizsla, Hungarian pointer 213 | English setter 214 | Irish setter, red setter 215 | Gordon setter 216 | Brittany spaniel 217 | clumber, clumber spaniel 218 | English springer, English springer spaniel 219 | Welsh springer spaniel 220 | cocker spaniel, English cocker spaniel, cocker 221 | Sussex spaniel 222 | Irish water spaniel 223 | kuvasz 224 | schipperke 225 | groenendael 226 | malinois 227 | briard 228 | kelpie 229 | komondor 230 | Old English sheepdog, bobtail 231 | Shetland sheepdog, Shetland sheep dog, Shetland 232 | collie 233 | Border collie 234 | Bouvier des Flandres, Bouviers des Flandres 235 | Rottweiler 236 | German shepherd, German shepherd dog, German police dog, alsatian 237 | Doberman, Doberman pinscher 238 | miniature pinscher 239 | Greater Swiss Mountain dog 240 | Bernese mountain dog 241 | Appenzeller 242 | EntleBucher 243 | boxer 244 | bull mastiff 245 | Tibetan mastiff 246 | French bulldog 247 | Great Dane 248 | Saint Bernard, St Bernard 249 | Eskimo dog, husky 250 | malamute, malemute, Alaskan malamute 251 | Siberian husky 252 | dalmatian, coach dog, carriage dog 253 | affenpinscher, monkey pinscher, monkey dog 254 | basenji 255 | pug, pug-dog 256 | Leonberg 257 | Newfoundland, Newfoundland dog 258 | Great Pyrenees 259 | Samoyed, Samoyede 260 | Pomeranian 261 | chow, chow chow 262 | keeshond 263 | Brabancon griffon 264 | Pembroke, Pembroke Welsh corgi 265 | Cardigan, Cardigan Welsh corgi 266 | toy poodle 267 | miniature poodle 268 | standard poodle 269 | Mexican hairless 270 | timber wolf, grey wolf, gray wolf, Canis lupus 271 | white wolf, Arctic wolf, Canis lupus tundrarum 272 | red wolf, maned wolf, Canis rufus, Canis niger 273 | coyote, prairie wolf, brush wolf, Canis latrans 274 | dingo, warrigal, warragal, Canis dingo 275 | dhole, Cuon alpinus 276 | African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus 277 | hyena, hyaena 278 | red fox, Vulpes vulpes 279 | kit fox, Vulpes macrotis 280 | Arctic fox, white fox, Alopex lagopus 281 | grey fox, gray fox, Urocyon cinereoargenteus 282 | tabby, tabby cat 283 | tiger cat 284 | Persian cat 285 | Siamese cat, Siamese 286 | Egyptian cat 287 | cougar, puma, catamount, mountain lion, painter, panther, Felis concolor 288 | lynx, catamount 289 | leopard, Panthera pardus 290 | snow leopard, ounce, Panthera uncia 291 | jaguar, panther, Panthera onca, Felis onca 292 | lion, king of beasts, Panthera leo 293 | tiger, Panthera tigris 294 | cheetah, chetah, Acinonyx jubatus 295 | brown bear, bruin, Ursus arctos 296 | American black bear, black bear, Ursus americanus, Euarctos americanus 297 | ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus 298 | sloth bear, Melursus ursinus, Ursus ursinus 299 | mongoose 300 | meerkat, mierkat 301 | tiger beetle 302 | ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle 
303 | ground beetle, carabid beetle 304 | long-horned beetle, longicorn, longicorn beetle 305 | leaf beetle, chrysomelid 306 | dung beetle 307 | rhinoceros beetle 308 | weevil 309 | fly 310 | bee 311 | ant, emmet, pismire 312 | grasshopper, hopper 313 | cricket 314 | walking stick, walkingstick, stick insect 315 | cockroach, roach 316 | mantis, mantid 317 | cicada, cicala 318 | leafhopper 319 | lacewing, lacewing fly 320 | dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk 321 | damselfly 322 | admiral 323 | ringlet, ringlet butterfly 324 | monarch, monarch butterfly, milkweed butterfly, Danaus plexippus 325 | cabbage butterfly 326 | sulphur butterfly, sulfur butterfly 327 | lycaenid, lycaenid butterfly 328 | starfish, sea star 329 | sea urchin 330 | sea cucumber, holothurian 331 | wood rabbit, cottontail, cottontail rabbit 332 | hare 333 | Angora, Angora rabbit 334 | hamster 335 | porcupine, hedgehog 336 | fox squirrel, eastern fox squirrel, Sciurus niger 337 | marmot 338 | beaver 339 | guinea pig, Cavia cobaya 340 | sorrel 341 | zebra 342 | hog, pig, grunter, squealer, Sus scrofa 343 | wild boar, boar, Sus scrofa 344 | warthog 345 | hippopotamus, hippo, river horse, Hippopotamus amphibius 346 | ox 347 | water buffalo, water ox, Asiatic buffalo, Bubalus bubalis 348 | bison 349 | ram, tup 350 | bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis 351 | ibex, Capra ibex 352 | hartebeest 353 | impala, Aepyceros melampus 354 | gazelle 355 | Arabian camel, dromedary, Camelus dromedarius 356 | llama 357 | weasel 358 | mink 359 | polecat, fitch, foulmart, foumart, Mustela putorius 360 | black-footed ferret, ferret, Mustela nigripes 361 | otter 362 | skunk, polecat, wood pussy 363 | badger 364 | armadillo 365 | three-toed sloth, ai, Bradypus tridactylus 366 | orangutan, orang, orangutang, Pongo pygmaeus 367 | gorilla, Gorilla gorilla 368 | chimpanzee, chimp, Pan troglodytes 369 | gibbon, Hylobates lar 370 | siamang, Hylobates syndactylus, Symphalangus syndactylus 371 | guenon, guenon monkey 372 | patas, hussar monkey, Erythrocebus patas 373 | baboon 374 | macaque 375 | langur 376 | colobus, colobus monkey 377 | proboscis monkey, Nasalis larvatus 378 | marmoset 379 | capuchin, ringtail, Cebus capucinus 380 | howler monkey, howler 381 | titi, titi monkey 382 | spider monkey, Ateles geoffroyi 383 | squirrel monkey, Saimiri sciureus 384 | Madagascar cat, ring-tailed lemur, Lemur catta 385 | indri, indris, Indri indri, Indri brevicaudatus 386 | Indian elephant, Elephas maximus 387 | African elephant, Loxodonta africana 388 | lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens 389 | giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca 390 | barracouta, snoek 391 | eel 392 | coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch 393 | rock beauty, Holocanthus tricolor 394 | anemone fish 395 | sturgeon 396 | gar, garfish, garpike, billfish, Lepisosteus osseus 397 | lionfish 398 | puffer, pufferfish, blowfish, globefish 399 | abacus 400 | abaya 401 | academic gown, academic robe, judge's robe 402 | accordion, piano accordion, squeeze box 403 | acoustic guitar 404 | aircraft carrier, carrier, flattop, attack aircraft carrier 405 | airliner 406 | airship, dirigible 407 | altar 408 | ambulance 409 | amphibian, amphibious vehicle 410 | analog clock 411 | apiary, bee house 412 | apron 413 | ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, 
ashbin, dustbin, trash barrel, trash bin 414 | assault rifle, assault gun 415 | backpack, back pack, knapsack, packsack, rucksack, haversack 416 | bakery, bakeshop, bakehouse 417 | balance beam, beam 418 | balloon 419 | ballpoint, ballpoint pen, ballpen, Biro 420 | Band Aid 421 | banjo 422 | bannister, banister, balustrade, balusters, handrail 423 | barbell 424 | barber chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel, cask 429 | barrow, garden cart, lawn cart, wheelbarrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | bathing cap, swimming cap 435 | bath towel 436 | bathtub, bathing tub, bath, tub 437 | beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon 438 | beacon, lighthouse, beacon light, pharos 439 | beaker 440 | bearskin, busby, shako 441 | beer bottle 442 | beer glass 443 | bell cote, bell cot 444 | bib 445 | bicycle-built-for-two, tandem bicycle, tandem 446 | bikini, two-piece 447 | binder, ring-binder 448 | binoculars, field glasses, opera glasses 449 | birdhouse 450 | boathouse 451 | bobsled, bobsleigh, bob 452 | bolo tie, bolo, bola tie, bola 453 | bonnet, poke bonnet 454 | bookcase 455 | bookshop, bookstore, bookstall 456 | bottlecap 457 | bow 458 | bow tie, bow-tie, bowtie 459 | brass, memorial tablet, plaque 460 | brassiere, bra, bandeau 461 | breakwater, groin, groyne, mole, bulwark, seawall, jetty 462 | breastplate, aegis, egis 463 | broom 464 | bucket, pail 465 | buckle 466 | bulletproof vest 467 | bullet train, bullet 468 | butcher shop, meat market 469 | cab, hack, taxi, taxicab 470 | caldron, cauldron 471 | candle, taper, wax light 472 | cannon 473 | canoe 474 | can opener, tin opener 475 | cardigan 476 | car mirror 477 | carousel, carrousel, merry-go-round, roundabout, whirligig 478 | carpenter's kit, tool kit 479 | carton 480 | car wheel 481 | cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM 482 | cassette 483 | cassette player 484 | castle 485 | catamaran 486 | CD player 487 | cello, violoncello 488 | cellular telephone, cellular phone, cellphone, cell, mobile phone 489 | chain 490 | chainlink fence 491 | chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour 492 | chain saw, chainsaw 493 | chest 494 | chiffonier, commode 495 | chime, bell, gong 496 | china cabinet, china closet 497 | Christmas stocking 498 | church, church building 499 | cinema, movie theater, movie theatre, movie house, picture palace 500 | cleaver, meat cleaver, chopper 501 | cliff dwelling 502 | cloak 503 | clog, geta, patten, sabot 504 | cocktail shaker 505 | coffee mug 506 | coffeepot 507 | coil, spiral, volute, whorl, helix 508 | combination lock 509 | computer keyboard, keypad 510 | confectionery, confectionary, candy store 511 | container ship, containership, container vessel 512 | convertible 513 | corkscrew, bottle screw 514 | cornet, horn, trumpet, trump 515 | cowboy boot 516 | cowboy hat, ten-gallon hat 517 | cradle 518 | crane 519 | crash helmet 520 | crate 521 | crib, cot 522 | Crock Pot 523 | croquet ball 524 | crutch 525 | cuirass 526 | dam, dike, dyke 527 | desk 528 | desktop computer 529 | dial telephone, dial phone 530 | diaper, nappy, napkin 531 | digital clock 532 | digital watch 533 | dining table, board 534 | dishrag, dishcloth 535 | dishwasher, dish washer, dishwashing machine 536 | disk brake, disc brake 537 | dock, dockage, docking facility 538 | dogsled, dog sled, dog sleigh 539 | dome 540 | doormat, welcome 
mat 541 | drilling platform, offshore rig 542 | drum, membranophone, tympan 543 | drumstick 544 | dumbbell 545 | Dutch oven 546 | electric fan, blower 547 | electric guitar 548 | electric locomotive 549 | entertainment center 550 | envelope 551 | espresso maker 552 | face powder 553 | feather boa, boa 554 | file, file cabinet, filing cabinet 555 | fireboat 556 | fire engine, fire truck 557 | fire screen, fireguard 558 | flagpole, flagstaff 559 | flute, transverse flute 560 | folding chair 561 | football helmet 562 | forklift 563 | fountain 564 | fountain pen 565 | four-poster 566 | freight car 567 | French horn, horn 568 | frying pan, frypan, skillet 569 | fur coat 570 | garbage truck, dustcart 571 | gasmask, respirator, gas helmet 572 | gas pump, gasoline pump, petrol pump, island dispenser 573 | goblet 574 | go-kart 575 | golf ball 576 | golfcart, golf cart 577 | gondola 578 | gong, tam-tam 579 | gown 580 | grand piano, grand 581 | greenhouse, nursery, glasshouse 582 | grille, radiator grille 583 | grocery store, grocery, food market, market 584 | guillotine 585 | hair slide 586 | hair spray 587 | half track 588 | hammer 589 | hamper 590 | hand blower, blow dryer, blow drier, hair dryer, hair drier 591 | hand-held computer, hand-held microcomputer 592 | handkerchief, hankie, hanky, hankey 593 | hard disc, hard disk, fixed disk 594 | harmonica, mouth organ, harp, mouth harp 595 | harp 596 | harvester, reaper 597 | hatchet 598 | holster 599 | home theater, home theatre 600 | honeycomb 601 | hook, claw 602 | hoopskirt, crinoline 603 | horizontal bar, high bar 604 | horse cart, horse-cart 605 | hourglass 606 | iPod 607 | iron, smoothing iron 608 | jack-o'-lantern 609 | jean, blue jean, denim 610 | jeep, landrover 611 | jersey, T-shirt, tee shirt 612 | jigsaw puzzle 613 | jinrikisha, ricksha, rickshaw 614 | joystick 615 | kimono 616 | knee pad 617 | knot 618 | lab coat, laboratory coat 619 | ladle 620 | lampshade, lamp shade 621 | laptop, laptop computer 622 | lawn mower, mower 623 | lens cap, lens cover 624 | letter opener, paper knife, paperknife 625 | library 626 | lifeboat 627 | lighter, light, igniter, ignitor 628 | limousine, limo 629 | liner, ocean liner 630 | lipstick, lip rouge 631 | Loafer 632 | lotion 633 | loudspeaker, speaker, speaker unit, loudspeaker system, speaker system 634 | loupe, jeweler's loupe 635 | lumbermill, sawmill 636 | magnetic compass 637 | mailbag, postbag 638 | mailbox, letter box 639 | maillot 640 | maillot, tank suit 641 | manhole cover 642 | maraca 643 | marimba, xylophone 644 | mask 645 | matchstick 646 | maypole 647 | maze, labyrinth 648 | measuring cup 649 | medicine chest, medicine cabinet 650 | megalith, megalithic structure 651 | microphone, mike 652 | microwave, microwave oven 653 | military uniform 654 | milk can 655 | minibus 656 | miniskirt, mini 657 | minivan 658 | missile 659 | mitten 660 | mixing bowl 661 | mobile home, manufactured home 662 | Model T 663 | modem 664 | monastery 665 | monitor 666 | moped 667 | mortar 668 | mortarboard 669 | mosque 670 | mosquito net 671 | motor scooter, scooter 672 | mountain bike, all-terrain bike, off-roader 673 | mountain tent 674 | mouse, computer mouse 675 | mousetrap 676 | moving van 677 | muzzle 678 | nail 679 | neck brace 680 | necklace 681 | nipple 682 | notebook, notebook computer 683 | obelisk 684 | oboe, hautboy, hautbois 685 | ocarina, sweet potato 686 | odometer, hodometer, mileometer, milometer 687 | oil filter 688 | organ, pipe organ 689 | oscilloscope, scope, cathode-ray oscilloscope, CRO 690 | 
overskirt 691 | oxcart 692 | oxygen mask 693 | packet 694 | paddle, boat paddle 695 | paddlewheel, paddle wheel 696 | padlock 697 | paintbrush 698 | pajama, pyjama, pj's, jammies 699 | palace 700 | panpipe, pandean pipe, syrinx 701 | paper towel 702 | parachute, chute 703 | parallel bars, bars 704 | park bench 705 | parking meter 706 | passenger car, coach, carriage 707 | patio, terrace 708 | pay-phone, pay-station 709 | pedestal, plinth, footstall 710 | pencil box, pencil case 711 | pencil sharpener 712 | perfume, essence 713 | Petri dish 714 | photocopier 715 | pick, plectrum, plectron 716 | pickelhaube 717 | picket fence, paling 718 | pickup, pickup truck 719 | pier 720 | piggy bank, penny bank 721 | pill bottle 722 | pillow 723 | ping-pong ball 724 | pinwheel 725 | pirate, pirate ship 726 | pitcher, ewer 727 | plane, carpenter's plane, woodworking plane 728 | planetarium 729 | plastic bag 730 | plate rack 731 | plow, plough 732 | plunger, plumber's helper 733 | Polaroid camera, Polaroid Land camera 734 | pole 735 | police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria 736 | poncho 737 | pool table, billiard table, snooker table 738 | pop bottle, soda bottle 739 | pot, flowerpot 740 | potter's wheel 741 | power drill 742 | prayer rug, prayer mat 743 | printer 744 | prison, prison house 745 | projectile, missile 746 | projector 747 | puck, hockey puck 748 | punching bag, punch bag, punching ball, punchball 749 | purse 750 | quill, quill pen 751 | quilt, comforter, comfort, puff 752 | racer, race car, racing car 753 | racket, racquet 754 | radiator 755 | radio, wireless 756 | radio telescope, radio reflector 757 | rain barrel 758 | recreational vehicle, RV, R.V. 759 | reel 760 | reflex camera 761 | refrigerator, icebox 762 | remote control, remote 763 | restaurant, eating house, eating place, eatery 764 | revolver, six-gun, six-shooter 765 | rifle 766 | rocking chair, rocker 767 | rotisserie 768 | rubber eraser, rubber, pencil eraser 769 | rugby ball 770 | rule, ruler 771 | running shoe 772 | safe 773 | safety pin 774 | saltshaker, salt shaker 775 | sandal 776 | sarong 777 | sax, saxophone 778 | scabbard 779 | scale, weighing machine 780 | school bus 781 | schooner 782 | scoreboard 783 | screen, CRT screen 784 | screw 785 | screwdriver 786 | seat belt, seatbelt 787 | sewing machine 788 | shield, buckler 789 | shoe shop, shoe-shop, shoe store 790 | shoji 791 | shopping basket 792 | shopping cart 793 | shovel 794 | shower cap 795 | shower curtain 796 | ski 797 | ski mask 798 | sleeping bag 799 | slide rule, slipstick 800 | sliding door 801 | slot, one-armed bandit 802 | snorkel 803 | snowmobile 804 | snowplow, snowplough 805 | soap dispenser 806 | soccer ball 807 | sock 808 | solar dish, solar collector, solar furnace 809 | sombrero 810 | soup bowl 811 | space bar 812 | space heater 813 | space shuttle 814 | spatula 815 | speedboat 816 | spider web, spider's web 817 | spindle 818 | sports car, sport car 819 | spotlight, spot 820 | stage 821 | steam locomotive 822 | steel arch bridge 823 | steel drum 824 | stethoscope 825 | stole 826 | stone wall 827 | stopwatch, stop watch 828 | stove 829 | strainer 830 | streetcar, tram, tramcar, trolley, trolley car 831 | stretcher 832 | studio couch, day bed 833 | stupa, tope 834 | submarine, pigboat, sub, U-boat 835 | suit, suit of clothes 836 | sundial 837 | sunglass 838 | sunglasses, dark glasses, shades 839 | sunscreen, sunblock, sun blocker 840 | suspension bridge 841 | swab, swob, mop 842 | sweatshirt 843 | swimming trunks, 
bathing trunks 844 | swing 845 | switch, electric switch, electrical switch 846 | syringe 847 | table lamp 848 | tank, army tank, armored combat vehicle, armoured combat vehicle 849 | tape player 850 | teapot 851 | teddy, teddy bear 852 | television, television system 853 | tennis ball 854 | thatch, thatched roof 855 | theater curtain, theatre curtain 856 | thimble 857 | thresher, thrasher, threshing machine 858 | throne 859 | tile roof 860 | toaster 861 | tobacco shop, tobacconist shop, tobacconist 862 | toilet seat 863 | torch 864 | totem pole 865 | tow truck, tow car, wrecker 866 | toyshop 867 | tractor 868 | trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi 869 | tray 870 | trench coat 871 | tricycle, trike, velocipede 872 | trimaran 873 | tripod 874 | triumphal arch 875 | trolleybus, trolley coach, trackless trolley 876 | trombone 877 | tub, vat 878 | turnstile 879 | typewriter keyboard 880 | umbrella 881 | unicycle, monocycle 882 | upright, upright piano 883 | vacuum, vacuum cleaner 884 | vase 885 | vault 886 | velvet 887 | vending machine 888 | vestment 889 | viaduct 890 | violin, fiddle 891 | volleyball 892 | waffle iron 893 | wall clock 894 | wallet, billfold, notecase, pocketbook 895 | wardrobe, closet, press 896 | warplane, military plane 897 | washbasin, handbasin, washbowl, lavabo, wash-hand basin 898 | washer, automatic washer, washing machine 899 | water bottle 900 | water jug 901 | water tower 902 | whiskey jug 903 | whistle 904 | wig 905 | window screen 906 | window shade 907 | Windsor tie 908 | wine bottle 909 | wing 910 | wok 911 | wooden spoon 912 | wool, woolen, woollen 913 | worm fence, snake fence, snake-rail fence, Virginia fence 914 | wreck 915 | yawl 916 | yurt 917 | web site, website, internet site, site 918 | comic book 919 | crossword puzzle, crossword 920 | street sign 921 | traffic light, traffic signal, stoplight 922 | book jacket, dust cover, dust jacket, dust wrapper 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot pot, hotpot 928 | trifle 929 | ice cream, icecream 930 | ice lolly, lolly, lollipop, popsicle 931 | French loaf 932 | bagel, beigel 933 | pretzel 934 | cheeseburger 935 | hotdog, hot dog, red hot 936 | mashed potato 937 | head cabbage 938 | broccoli 939 | cauliflower 940 | zucchini, courgette 941 | spaghetti squash 942 | acorn squash 943 | butternut squash 944 | cucumber, cuke 945 | artichoke, globe artichoke 946 | bell pepper 947 | cardoon 948 | mushroom 949 | Granny Smith 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple, ananas 955 | banana 956 | jackfruit, jak, jack 957 | custard apple 958 | pomegranate 959 | hay 960 | carbonara 961 | chocolate sauce, chocolate syrup 962 | dough 963 | meat loaf, meatloaf 964 | pizza, pizza pie 965 | potpie 966 | burrito 967 | red wine 968 | espresso 969 | cup 970 | eggnog 971 | alp 972 | bubble 973 | cliff, drop, drop-off 974 | coral reef 975 | geyser 976 | lakeside, lakeshore 977 | promontory, headland, head, foreland 978 | sandbar, sand bar 979 | seashore, coast, seacoast, sea-coast 980 | valley, vale 981 | volcano 982 | ballplayer, baseball player 983 | groom, bridegroom 984 | scuba diver 985 | rapeseed 986 | daisy 987 | yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum 988 | corn 989 | acorn 990 | hip, rose hip, rosehip 991 | buckeye, horse chestnut, conker 992 | coral fungus 993 | agaric 994 | gyromitra 995 | stinkhorn, carrion fungus 996 | earthstar 997 | hen-of-the-woods, hen of the woods, Polyporus 
frondosus, Grifola frondosa 998 | bolete 999 | ear, spike, capitulum 1000 | toilet tissue, toilet paper, bathroom tissue -------------------------------------------------------------------------------- /pytorch basics/tensors.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "a = torch.ones(10)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 3, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "data": { 28 | "text/plain": [ 29 | "tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])" 30 | ] 31 | }, 32 | "execution_count": 3, 33 | "metadata": {}, 34 | "output_type": "execute_result" 35 | } 36 | ], 37 | "source": [ 38 | "a" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 4, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "data": { 48 | "text/plain": [ 49 | "tensor(1.)" 50 | ] 51 | }, 52 | "execution_count": 4, 53 | "metadata": {}, 54 | "output_type": "execute_result" 55 | } 56 | ], 57 | "source": [ 58 | "a[1]" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 5, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "a[9] = 2.0" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 6, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "a[8] = int(8)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 7, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "tensor([1., 1., 1., 1., 1., 1., 1., 1., 8., 2.])" 88 | ] 89 | }, 90 | "execution_count": 7, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "a" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "### Named Tensors" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 8, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "# create dummy image\n", 113 | "image_dummy = torch.randn(3, 256, 256) # shape=[channels, rows, columns]\n", 114 | "\n", 115 | "# create a dummy image batch\n", 116 | "image_dummy_batch = torch.randn(10, 3, 256, 256) # shape=[batch, channels, rows, columns]" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 9, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "tensor([[-0.0674, -0.1689, -1.1018, ..., -0.7322, -0.2939, 0.3841],\n", 129 | " [-0.2840, -0.7412, 0.1809, ..., -0.1558, -0.2215, -0.2885],\n", 130 | " [-0.8739, -0.0669, 0.1476, ..., -0.2360, 0.2177, -0.4072],\n", 131 | " ...,\n", 132 | " [-0.0314, -0.6882, -1.7180, ..., -0.2414, -1.1165, 0.8286],\n", 133 | " [-0.0416, 0.4167, 0.2723, ..., 0.5075, 0.2172, -0.2410],\n", 134 | " [ 0.0192, 0.5361, 0.0802, ..., -0.4366, 0.3520, -0.0738]])\n", 135 | "torch.Size([256, 256])\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "# take mean of all the channels to obtain a grayscale image\n", 141 | "print(image_dummy.mean(-3))\n", 142 | "print(image_dummy.mean(-3).shape) # grayscale image; shape=[columns, rows] " 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 10, 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "name": 
"stdout", 152 | "output_type": "stream", 153 | "text": [ 154 | "torch.Size([10, 256, 256])\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "# repeat the same for a batch of images\n", 160 | "print(image_dummy_batch.mean(-3).shape) # batch of grayscale images; shape=[batch, columns, rows]" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 11, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "# take weighted mean - multiply the channels with their respective weights and then take mean\n", 170 | "weights = torch.tensor([0.2126, 0.7152, 0.0722])" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 12, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "# transform weights by adding extra dimensions\n", 180 | "# unsqueeze adds extra dimension\n", 181 | "unsqueezed_weights = weights.unsqueeze(1).unsqueeze(1)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 13, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "# broadcast these weights onto the image\n", 191 | "image_dummy_weighted = image_dummy * unsqueezed_weights" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 14, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "# broadcast these weights onto batch of Images\n", 201 | "image_dummy_batch_weighted = image_dummy_batch * unsqueezed_weights" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 15, 207 | "metadata": {}, 208 | "outputs": [ 209 | { 210 | "name": "stdout", 211 | "output_type": "stream", 212 | "text": [ 213 | "torch.Size([256, 256]) torch.Size([10, 256, 256])\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "# get the weighted grayscale images\n", 219 | "image_dummy_grayscale_weighted = image_dummy_weighted.sum(-3)\n", 220 | "image_dummy_grayscale_weighted_batch = image_dummy_batch_weighted.sum(-3)\n", 221 | "print(image_dummy_grayscale_weighted.shape, image_dummy_grayscale_weighted_batch.shape)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "metadata": {}, 227 | "source": [ 228 | "##### Using Named Tensors" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 25, 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "name": "stdout", 238 | "output_type": "stream", 239 | "text": [ 240 | "named weights torch.Size([3]) ('channels',)\n", 241 | "Image dummy named torch.Size([3, 256, 256]) ('channels', 'columns', 'rows')\n", 242 | "Image dummy batch named torch.Size([10, 3, 256, 256]) ('batch', 'channels', 'columns', 'rows')\n" 243 | ] 244 | } 245 | ], 246 | "source": [ 247 | "named_weights = weights.refine_names(\"channels\")\n", 248 | "image_dummy_named = image_dummy.refine_names(\"channels\", \"columns\", \"rows\")\n", 249 | "image_dummy_batch_named = image_dummy_batch.refine_names(\"batch\", \"channels\", \"columns\", \"rows\")\n", 250 | "print(\"named weights\", named_weights.shape, named_weights.names)\n", 251 | "print(\"Image dummy named\", image_dummy_named.shape, image_dummy_named.names)\n", 252 | "print(\"Image dummy batch named\", image_dummy_batch_named.shape, image_dummy_batch_named.names)" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "##### Performing Operations" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 26, 265 | "metadata": {}, 266 | "outputs": [ 267 | { 268 | "data": { 269 | "text/plain": [ 270 | "(torch.Size([3, 1, 1]), ('channels', 
'columns', 'rows'))" 271 | ] 272 | }, 273 | "execution_count": 26, 274 | "metadata": {}, 275 | "output_type": "execute_result" 276 | } 277 | ], 278 | "source": [ 279 | "# change weights dimensions to align with image dimensions\n", 280 | "weights_aligned = named_weights.align_as(image_dummy_named)\n", 281 | "weights_aligned.shape, weights_aligned.names" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 27, 287 | "metadata": {}, 288 | "outputs": [ 289 | { 290 | "data": { 291 | "text/plain": [ 292 | "(torch.Size([256, 256]), ('columns', 'rows'))" 293 | ] 294 | }, 295 | "execution_count": 27, 296 | "metadata": {}, 297 | "output_type": "execute_result" 298 | } 299 | ], 300 | "source": [ 301 | "# perform operations along dimensions wrt name\n", 302 | "grayscale_named = (image_dummy_named * weights_aligned).sum('channels')\n", 303 | "grayscale_named.shape, grayscale_named.names" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 30, 309 | "metadata": {}, 310 | "outputs": [ 311 | { 312 | "data": { 313 | "text/plain": [ 314 | "torch.Size([3, 256])" 315 | ] 316 | }, 317 | "execution_count": 30, 318 | "metadata": {}, 319 | "output_type": "execute_result" 320 | } 321 | ], 322 | "source": [ 323 | "# this means take mean of each column for all columns for all channels\n", 324 | "image_dummy_named.mean('columns').shape" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 31, 330 | "metadata": {}, 331 | "outputs": [ 332 | { 333 | "data": { 334 | "text/plain": [ 335 | "torch.Size([3, 256])" 336 | ] 337 | }, 338 | "execution_count": 31, 339 | "metadata": {}, 340 | "output_type": "execute_result" 341 | } 342 | ], 343 | "source": [ 344 | "# this means take mean of each row for all rows for all channels\n", 345 | "image_dummy_named.mean('rows').shape" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 32, 351 | "metadata": {}, 352 | "outputs": [ 353 | { 354 | "data": { 355 | "text/plain": [ 356 | "torch.Size([256, 256])" 357 | ] 358 | }, 359 | "execution_count": 32, 360 | "metadata": {}, 361 | "output_type": "execute_result" 362 | } 363 | ], 364 | "source": [ 365 | "# this means combine all three channels into one channel by taking the mean of respective values along the channels\n", 366 | "image_dummy_named.mean('channels').shape" 367 | ] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "metadata": {}, 372 | "source": [ 373 | "##### drop names" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": 33, 379 | "metadata": {}, 380 | "outputs": [ 381 | { 382 | "data": { 383 | "text/plain": [ 384 | "(torch.Size([3, 256, 256]), (None, None, None))" 385 | ] 386 | }, 387 | "execution_count": 33, 388 | "metadata": {}, 389 | "output_type": "execute_result" 390 | } 391 | ], 392 | "source": [ 393 | "image_dummy_named = image_dummy_named.rename(None)\n", 394 | "image_dummy_named.shape, image_dummy_named.names" 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "metadata": {}, 400 | "source": [ 401 | "### Storage" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": 37, 407 | "metadata": {}, 408 | "outputs": [], 409 | "source": [ 410 | "points = torch.ones(size=[3, 4], dtype=torch.int64)" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": 38, 416 | "metadata": {}, 417 | "outputs": [ 418 | { 419 | "data": { 420 | "text/plain": [ 421 | "tensor([[1, 1, 1, 1],\n", 422 | " [1, 1, 1, 1],\n", 423 | " [1, 1, 1, 1]])" 424 | ] 425 | }, 426 | 
"execution_count": 38, 427 | "metadata": {}, 428 | "output_type": "execute_result" 429 | } 430 | ], 431 | "source": [ 432 | "points" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": 41, 438 | "metadata": {}, 439 | "outputs": [], 440 | "source": [ 441 | "storage_object = points.storage()" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": 42, 447 | "metadata": {}, 448 | "outputs": [ 449 | { 450 | "data": { 451 | "text/plain": [ 452 | " 1\n", 453 | " 1\n", 454 | " 1\n", 455 | " 1\n", 456 | " 1\n", 457 | " 1\n", 458 | " 1\n", 459 | " 1\n", 460 | " 1\n", 461 | " 1\n", 462 | " 1\n", 463 | " 1\n", 464 | "[torch.LongStorage of size 12]" 465 | ] 466 | }, 467 | "execution_count": 42, 468 | "metadata": {}, 469 | "output_type": "execute_result" 470 | } 471 | ], 472 | "source": [ 473 | "storage_object" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": 45, 479 | "metadata": {}, 480 | "outputs": [ 481 | { 482 | "data": { 483 | "text/plain": [ 484 | "12" 485 | ] 486 | }, 487 | "execution_count": 45, 488 | "metadata": {}, 489 | "output_type": "execute_result" 490 | } 491 | ], 492 | "source": [ 493 | "storage_object.size()" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 46, 499 | "metadata": {}, 500 | "outputs": [], 501 | "source": [ 502 | "storage_object[2] = 2" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 47, 508 | "metadata": {}, 509 | "outputs": [ 510 | { 511 | "data": { 512 | "text/plain": [ 513 | "tensor([[1, 1, 2, 1],\n", 514 | " [1, 1, 1, 1],\n", 515 | " [1, 1, 1, 1]])" 516 | ] 517 | }, 518 | "execution_count": 47, 519 | "metadata": {}, 520 | "output_type": "execute_result" 521 | } 522 | ], 523 | "source": [ 524 | "points" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": 50, 530 | "metadata": {}, 531 | "outputs": [ 532 | { 533 | "data": { 534 | "text/plain": [ 535 | "tensor([[0, 0, 0, 0],\n", 536 | " [0, 0, 0, 0],\n", 537 | " [0, 0, 0, 0]])" 538 | ] 539 | }, 540 | "execution_count": 50, 541 | "metadata": {}, 542 | "output_type": "execute_result" 543 | } 544 | ], 545 | "source": [ 546 | "points.zero_()" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": 51, 552 | "metadata": {}, 553 | "outputs": [ 554 | { 555 | "data": { 556 | "text/plain": [ 557 | " 0\n", 558 | " 0\n", 559 | " 0\n", 560 | " 0\n", 561 | " 0\n", 562 | " 0\n", 563 | " 0\n", 564 | " 0\n", 565 | " 0\n", 566 | " 0\n", 567 | " 0\n", 568 | " 0\n", 569 | "[torch.LongStorage of size 12]" 570 | ] 571 | }, 572 | "execution_count": 51, 573 | "metadata": {}, 574 | "output_type": "execute_result" 575 | } 576 | ], 577 | "source": [ 578 | "points.storage()" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": 53, 584 | "metadata": {}, 585 | "outputs": [ 586 | { 587 | "data": { 588 | "text/plain": [ 589 | "torch.Size([3, 4])" 590 | ] 591 | }, 592 | "execution_count": 53, 593 | "metadata": {}, 594 | "output_type": "execute_result" 595 | } 596 | ], 597 | "source": [ 598 | "points.shape" 599 | ] 600 | }, 601 | { 602 | "cell_type": "code", 603 | "execution_count": 55, 604 | "metadata": {}, 605 | "outputs": [], 606 | "source": [ 607 | "points += 1" 608 | ] 609 | }, 610 | { 611 | "cell_type": "code", 612 | "execution_count": 56, 613 | "metadata": {}, 614 | "outputs": [ 615 | { 616 | "data": { 617 | "text/plain": [ 618 | "tensor([[1, 1, 1, 1],\n", 619 | " [1, 1, 1, 1],\n", 620 | " [1, 1, 1, 1]])" 621 | ] 622 | }, 623 | "execution_count": 
56, 624 | "metadata": {}, 625 | "output_type": "execute_result" 626 | } 627 | ], 628 | "source": [ 629 | "points" 630 | ] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "execution_count": 57, 635 | "metadata": {}, 636 | "outputs": [ 637 | { 638 | "data": { 639 | "text/plain": [ 640 | "torch.Size([3, 4])" 641 | ] 642 | }, 643 | "execution_count": 57, 644 | "metadata": {}, 645 | "output_type": "execute_result" 646 | } 647 | ], 648 | "source": [ 649 | "points.shape" 650 | ] 651 | }, 652 | { 653 | "cell_type": "code", 654 | "execution_count": 58, 655 | "metadata": {}, 656 | "outputs": [ 657 | { 658 | "data": { 659 | "text/plain": [ 660 | "0" 661 | ] 662 | }, 663 | "execution_count": 58, 664 | "metadata": {}, 665 | "output_type": "execute_result" 666 | } 667 | ], 668 | "source": [ 669 | "points.storage_offset()" 670 | ] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "execution_count": 59, 675 | "metadata": {}, 676 | "outputs": [ 677 | { 678 | "data": { 679 | "text/plain": [ 680 | "(4, 1)" 681 | ] 682 | }, 683 | "execution_count": 59, 684 | "metadata": {}, 685 | "output_type": "execute_result" 686 | } 687 | ], 688 | "source": [ 689 | "points.stride()" 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": 60, 695 | "metadata": {}, 696 | "outputs": [], 697 | "source": [ 698 | "points[1, 1] = 0" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": 61, 704 | "metadata": {}, 705 | "outputs": [ 706 | { 707 | "data": { 708 | "text/plain": [ 709 | "tensor([[1, 1, 1, 1],\n", 710 | " [1, 0, 1, 1],\n", 711 | " [1, 1, 1, 1]])" 712 | ] 713 | }, 714 | "execution_count": 61, 715 | "metadata": {}, 716 | "output_type": "execute_result" 717 | } 718 | ], 719 | "source": [ 720 | "points" 721 | ] 722 | }, 723 | { 724 | "cell_type": "code", 725 | "execution_count": 65, 726 | "metadata": {}, 727 | "outputs": [], 728 | "source": [ 729 | "i, j = 1, 2" 730 | ] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "execution_count": 66, 735 | "metadata": {}, 736 | "outputs": [ 737 | { 738 | "data": { 739 | "text/plain": [ 740 | "tensor(True)" 741 | ] 742 | }, 743 | "execution_count": 66, 744 | "metadata": {}, 745 | "output_type": "execute_result" 746 | } 747 | ], 748 | "source": [ 749 | "points[i, j] == points.storage()[points.storage_offset() + i*points.stride()[0] + j*points.stride()[1]]" 750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": 67, 755 | "metadata": {}, 756 | "outputs": [], 757 | "source": [ 758 | "sub_points = points[2]" 759 | ] 760 | }, 761 | { 762 | "cell_type": "code", 763 | "execution_count": 68, 764 | "metadata": {}, 765 | "outputs": [ 766 | { 767 | "data": { 768 | "text/plain": [ 769 | "tensor([1, 1, 1, 1])" 770 | ] 771 | }, 772 | "execution_count": 68, 773 | "metadata": {}, 774 | "output_type": "execute_result" 775 | } 776 | ], 777 | "source": [ 778 | "sub_points" 779 | ] 780 | }, 781 | { 782 | "cell_type": "code", 783 | "execution_count": 69, 784 | "metadata": {}, 785 | "outputs": [ 786 | { 787 | "data": { 788 | "text/plain": [ 789 | "8" 790 | ] 791 | }, 792 | "execution_count": 69, 793 | "metadata": {}, 794 | "output_type": "execute_result" 795 | } 796 | ], 797 | "source": [ 798 | "sub_points.storage_offset()" 799 | ] 800 | }, 801 | { 802 | "cell_type": "code", 803 | "execution_count": 70, 804 | "metadata": {}, 805 | "outputs": [ 806 | { 807 | "data": { 808 | "text/plain": [ 809 | "(1,)" 810 | ] 811 | }, 812 | "execution_count": 70, 813 | "metadata": {}, 814 | "output_type": "execute_result" 815 | } 816 | ], 817 | "source": [ 
818 | "sub_points.stride()" 819 | ] 820 | }, 821 | { 822 | "cell_type": "code", 823 | "execution_count": 73, 824 | "metadata": {}, 825 | "outputs": [], 826 | "source": [ 827 | "points = torch.Tensor([[1, 2], [3, 4], [5, 6]])" 828 | ] 829 | }, 830 | { 831 | "cell_type": "code", 832 | "execution_count": 76, 833 | "metadata": {}, 834 | "outputs": [ 835 | { 836 | "data": { 837 | "text/plain": [ 838 | "1824715206728" 839 | ] 840 | }, 841 | "execution_count": 76, 842 | "metadata": {}, 843 | "output_type": "execute_result" 844 | } 845 | ], 846 | "source": [ 847 | "id(points.storage)" 848 | ] 849 | }, 850 | { 851 | "cell_type": "code", 852 | "execution_count": 77, 853 | "metadata": {}, 854 | "outputs": [], 855 | "source": [ 856 | "points_t = points.t()" 857 | ] 858 | }, 859 | { 860 | "cell_type": "code", 861 | "execution_count": 79, 862 | "metadata": {}, 863 | "outputs": [ 864 | { 865 | "data": { 866 | "text/plain": [ 867 | "1824750927528" 868 | ] 869 | }, 870 | "execution_count": 79, 871 | "metadata": {}, 872 | "output_type": "execute_result" 873 | } 874 | ], 875 | "source": [ 876 | "id(points_t.storage)" 877 | ] 878 | }, 879 | { 880 | "cell_type": "code", 881 | "execution_count": 80, 882 | "metadata": {}, 883 | "outputs": [ 884 | { 885 | "data": { 886 | "text/plain": [ 887 | " 1.0\n", 888 | " 2.0\n", 889 | " 3.0\n", 890 | " 4.0\n", 891 | " 5.0\n", 892 | " 6.0\n", 893 | "[torch.FloatStorage of size 6]" 894 | ] 895 | }, 896 | "execution_count": 80, 897 | "metadata": {}, 898 | "output_type": "execute_result" 899 | } 900 | ], 901 | "source": [ 902 | "points_t.storage()" 903 | ] 904 | }, 905 | { 906 | "cell_type": "code", 907 | "execution_count": 81, 908 | "metadata": {}, 909 | "outputs": [ 910 | { 911 | "data": { 912 | "text/plain": [ 913 | " 1.0\n", 914 | " 2.0\n", 915 | " 3.0\n", 916 | " 4.0\n", 917 | " 5.0\n", 918 | " 6.0\n", 919 | "[torch.FloatStorage of size 6]" 920 | ] 921 | }, 922 | "execution_count": 81, 923 | "metadata": {}, 924 | "output_type": "execute_result" 925 | } 926 | ], 927 | "source": [ 928 | "points.storage()" 929 | ] 930 | }, 931 | { 932 | "cell_type": "code", 933 | "execution_count": 82, 934 | "metadata": {}, 935 | "outputs": [], 936 | "source": [ 937 | "points.storage()[3] = -10" 938 | ] 939 | }, 940 | { 941 | "cell_type": "code", 942 | "execution_count": 83, 943 | "metadata": {}, 944 | "outputs": [ 945 | { 946 | "data": { 947 | "text/plain": [ 948 | " 1.0\n", 949 | " 2.0\n", 950 | " 3.0\n", 951 | " -10.0\n", 952 | " 5.0\n", 953 | " 6.0\n", 954 | "[torch.FloatStorage of size 6]" 955 | ] 956 | }, 957 | "execution_count": 83, 958 | "metadata": {}, 959 | "output_type": "execute_result" 960 | } 961 | ], 962 | "source": [ 963 | "points.storage()" 964 | ] 965 | }, 966 | { 967 | "cell_type": "code", 968 | "execution_count": 84, 969 | "metadata": {}, 970 | "outputs": [ 971 | { 972 | "data": { 973 | "text/plain": [ 974 | " 1.0\n", 975 | " 2.0\n", 976 | " 3.0\n", 977 | " -10.0\n", 978 | " 5.0\n", 979 | " 6.0\n", 980 | "[torch.FloatStorage of size 6]" 981 | ] 982 | }, 983 | "execution_count": 84, 984 | "metadata": {}, 985 | "output_type": "execute_result" 986 | } 987 | ], 988 | "source": [ 989 | "points_t.storage()" 990 | ] 991 | }, 992 | { 993 | "cell_type": "code", 994 | "execution_count": 85, 995 | "metadata": {}, 996 | "outputs": [ 997 | { 998 | "data": { 999 | "text/plain": [ 1000 | "tensor([[ 1., 3., 5.],\n", 1001 | " [ 2., -10., 6.]])" 1002 | ] 1003 | }, 1004 | "execution_count": 85, 1005 | "metadata": {}, 1006 | "output_type": "execute_result" 1007 | } 1008 | ], 1009 | "source": 
[ 1010 | "points_t" 1011 | ] 1012 | }, 1013 | { 1014 | "cell_type": "code", 1015 | "execution_count": 86, 1016 | "metadata": {}, 1017 | "outputs": [ 1018 | { 1019 | "data": { 1020 | "text/plain": [ 1021 | "tensor([[ 1., 2.],\n", 1022 | " [ 3., -10.],\n", 1023 | " [ 5., 6.]])" 1024 | ] 1025 | }, 1026 | "execution_count": 86, 1027 | "metadata": {}, 1028 | "output_type": "execute_result" 1029 | } 1030 | ], 1031 | "source": [ 1032 | "points" 1033 | ] 1034 | }, 1035 | { 1036 | "cell_type": "code", 1037 | "execution_count": 87, 1038 | "metadata": {}, 1039 | "outputs": [ 1040 | { 1041 | "data": { 1042 | "text/plain": [ 1043 | " 1.0\n", 1044 | " 2.0\n", 1045 | " 3.0\n", 1046 | " -10.0\n", 1047 | " 5.0\n", 1048 | " 6.0\n", 1049 | "[torch.FloatStorage of size 6]" 1050 | ] 1051 | }, 1052 | "execution_count": 87, 1053 | "metadata": {}, 1054 | "output_type": "execute_result" 1055 | } 1056 | ], 1057 | "source": [ 1058 | "points.storage()" 1059 | ] 1060 | }, 1061 | { 1062 | "cell_type": "code", 1063 | "execution_count": 96, 1064 | "metadata": {}, 1065 | "outputs": [ 1066 | { 1067 | "data": { 1068 | "text/plain": [ 1069 | "tensor([[ 1., 3., 5.],\n", 1070 | " [ 2., -10., 6.]])" 1071 | ] 1072 | }, 1073 | "execution_count": 96, 1074 | "metadata": {}, 1075 | "output_type": "execute_result" 1076 | } 1077 | ], 1078 | "source": [ 1079 | "points_t" 1080 | ] 1081 | }, 1082 | { 1083 | "cell_type": "code", 1084 | "execution_count": 89, 1085 | "metadata": {}, 1086 | "outputs": [ 1087 | { 1088 | "data": { 1089 | "text/plain": [ 1090 | " 1.0\n", 1091 | " 2.0\n", 1092 | " 3.0\n", 1093 | " -10.0\n", 1094 | " 5.0\n", 1095 | " 6.0\n", 1096 | "[torch.FloatStorage of size 6]" 1097 | ] 1098 | }, 1099 | "execution_count": 89, 1100 | "metadata": {}, 1101 | "output_type": "execute_result" 1102 | } 1103 | ], 1104 | "source": [ 1105 | "points_t.storage()" 1106 | ] 1107 | }, 1108 | { 1109 | "cell_type": "code", 1110 | "execution_count": 91, 1111 | "metadata": {}, 1112 | "outputs": [ 1113 | { 1114 | "data": { 1115 | "text/plain": [ 1116 | "True" 1117 | ] 1118 | }, 1119 | "execution_count": 91, 1120 | "metadata": {}, 1121 | "output_type": "execute_result" 1122 | } 1123 | ], 1124 | "source": [ 1125 | "points.is_contiguous()" 1126 | ] 1127 | }, 1128 | { 1129 | "cell_type": "code", 1130 | "execution_count": 94, 1131 | "metadata": {}, 1132 | "outputs": [], 1133 | "source": [ 1134 | "points_t_cont = points_t.contiguous()" 1135 | ] 1136 | }, 1137 | { 1138 | "cell_type": "code", 1139 | "execution_count": 95, 1140 | "metadata": {}, 1141 | "outputs": [ 1142 | { 1143 | "data": { 1144 | "text/plain": [ 1145 | "tensor([[ 1., 3., 5.],\n", 1146 | " [ 2., -10., 6.]])" 1147 | ] 1148 | }, 1149 | "execution_count": 95, 1150 | "metadata": {}, 1151 | "output_type": "execute_result" 1152 | } 1153 | ], 1154 | "source": [ 1155 | "points_t_cont" 1156 | ] 1157 | }, 1158 | { 1159 | "cell_type": "code", 1160 | "execution_count": 102, 1161 | "metadata": {}, 1162 | "outputs": [ 1163 | { 1164 | "data": { 1165 | "text/plain": [ 1166 | " 1.0\n", 1167 | " 3.0\n", 1168 | " 5.0\n", 1169 | " 2.0\n", 1170 | " -10.0\n", 1171 | " 6.0\n", 1172 | "[torch.FloatStorage of size 6]" 1173 | ] 1174 | }, 1175 | "execution_count": 102, 1176 | "metadata": {}, 1177 | "output_type": "execute_result" 1178 | } 1179 | ], 1180 | "source": [ 1181 | "points_t_cont.storage()" 1182 | ] 1183 | }, 1184 | { 1185 | "cell_type": "code", 1186 | "execution_count": 103, 1187 | "metadata": {}, 1188 | "outputs": [ 1189 | { 1190 | "data": { 1191 | "text/plain": [ 1192 | " 1.0\n", 1193 | " 2.0\n", 1194 | 
" 3.0\n", 1195 | " -10.0\n", 1196 | " 5.0\n", 1197 | " 6.0\n", 1198 | "[torch.FloatStorage of size 6]" 1199 | ] 1200 | }, 1201 | "execution_count": 103, 1202 | "metadata": {}, 1203 | "output_type": "execute_result" 1204 | } 1205 | ], 1206 | "source": [ 1207 | "points_t.storage()" 1208 | ] 1209 | }, 1210 | { 1211 | "cell_type": "code", 1212 | "execution_count": 101, 1213 | "metadata": {}, 1214 | "outputs": [ 1215 | { 1216 | "data": { 1217 | "text/plain": [ 1218 | "True" 1219 | ] 1220 | }, 1221 | "execution_count": 101, 1222 | "metadata": {}, 1223 | "output_type": "execute_result" 1224 | } 1225 | ], 1226 | "source": [ 1227 | "id(points_t.storage) == id(points_t_cont.storage)" 1228 | ] 1229 | }, 1230 | { 1231 | "cell_type": "code", 1232 | "execution_count": 106, 1233 | "metadata": {}, 1234 | "outputs": [ 1235 | { 1236 | "data": { 1237 | "text/plain": [ 1238 | "" 1239 | ] 1240 | }, 1241 | "execution_count": 106, 1242 | "metadata": {}, 1243 | "output_type": "execute_result" 1244 | } 1245 | ], 1246 | "source": [ 1247 | "points_t.storage" 1248 | ] 1249 | }, 1250 | { 1251 | "cell_type": "code", 1252 | "execution_count": 107, 1253 | "metadata": {}, 1254 | "outputs": [], 1255 | "source": [ 1256 | "x = torch.Tensor([1, 2])" 1257 | ] 1258 | }, 1259 | { 1260 | "cell_type": "code", 1261 | "execution_count": 110, 1262 | "metadata": {}, 1263 | "outputs": [ 1264 | { 1265 | "data": { 1266 | "text/plain": [ 1267 | "True" 1268 | ] 1269 | }, 1270 | "execution_count": 110, 1271 | "metadata": {}, 1272 | "output_type": "execute_result" 1273 | } 1274 | ], 1275 | "source": [ 1276 | "id(x.storage) == id(points_t.storage)" 1277 | ] 1278 | }, 1279 | { 1280 | "cell_type": "code", 1281 | "execution_count": 118, 1282 | "metadata": {}, 1283 | "outputs": [ 1284 | { 1285 | "data": { 1286 | "text/plain": [ 1287 | "1824772856648" 1288 | ] 1289 | }, 1290 | "execution_count": 118, 1291 | "metadata": {}, 1292 | "output_type": "execute_result" 1293 | } 1294 | ], 1295 | "source": [ 1296 | "id(points_t.storage())" 1297 | ] 1298 | }, 1299 | { 1300 | "cell_type": "code", 1301 | "execution_count": 119, 1302 | "metadata": {}, 1303 | "outputs": [ 1304 | { 1305 | "data": { 1306 | "text/plain": [ 1307 | "1824772855880" 1308 | ] 1309 | }, 1310 | "execution_count": 119, 1311 | "metadata": {}, 1312 | "output_type": "execute_result" 1313 | } 1314 | ], 1315 | "source": [ 1316 | "id(x.storage())" 1317 | ] 1318 | }, 1319 | { 1320 | "cell_type": "code", 1321 | "execution_count": 120, 1322 | "metadata": {}, 1323 | "outputs": [ 1324 | { 1325 | "data": { 1326 | "text/plain": [ 1327 | "True" 1328 | ] 1329 | }, 1330 | "execution_count": 120, 1331 | "metadata": {}, 1332 | "output_type": "execute_result" 1333 | } 1334 | ], 1335 | "source": [ 1336 | "id(points_t.storage()) == id(x.storage())" 1337 | ] 1338 | }, 1339 | { 1340 | "cell_type": "code", 1341 | "execution_count": 117, 1342 | "metadata": {}, 1343 | "outputs": [ 1344 | { 1345 | "data": { 1346 | "text/plain": [ 1347 | " 1.0\n", 1348 | " 2.0\n", 1349 | "[torch.FloatStorage of size 2]" 1350 | ] 1351 | }, 1352 | "execution_count": 117, 1353 | "metadata": {}, 1354 | "output_type": "execute_result" 1355 | } 1356 | ], 1357 | "source": [ 1358 | "x.storage()" 1359 | ] 1360 | }, 1361 | { 1362 | "cell_type": "code", 1363 | "execution_count": 121, 1364 | "metadata": {}, 1365 | "outputs": [], 1366 | "source": [ 1367 | "a = torch.Tensor(list(range(9)))" 1368 | ] 1369 | }, 1370 | { 1371 | "cell_type": "code", 1372 | "execution_count": 122, 1373 | "metadata": {}, 1374 | "outputs": [ 1375 | { 1376 | "data": { 
1377 | "text/plain": [ 1378 | "tensor([0., 1., 2., 3., 4., 5., 6., 7., 8.])" 1379 | ] 1380 | }, 1381 | "execution_count": 122, 1382 | "metadata": {}, 1383 | "output_type": "execute_result" 1384 | } 1385 | ], 1386 | "source": [ 1387 | "a" 1388 | ] 1389 | }, 1390 | { 1391 | "cell_type": "code", 1392 | "execution_count": 124, 1393 | "metadata": {}, 1394 | "outputs": [], 1395 | "source": [ 1396 | "b = a.view(3, 3)" 1397 | ] 1398 | }, 1399 | { 1400 | "cell_type": "code", 1401 | "execution_count": 125, 1402 | "metadata": {}, 1403 | "outputs": [ 1404 | { 1405 | "data": { 1406 | "text/plain": [ 1407 | " 0.0\n", 1408 | " 1.0\n", 1409 | " 2.0\n", 1410 | " 3.0\n", 1411 | " 4.0\n", 1412 | " 5.0\n", 1413 | " 6.0\n", 1414 | " 7.0\n", 1415 | " 8.0\n", 1416 | "[torch.FloatStorage of size 9]" 1417 | ] 1418 | }, 1419 | "execution_count": 125, 1420 | "metadata": {}, 1421 | "output_type": "execute_result" 1422 | } 1423 | ], 1424 | "source": [ 1425 | "a.storage()" 1426 | ] 1427 | }, 1428 | { 1429 | "cell_type": "code", 1430 | "execution_count": 126, 1431 | "metadata": {}, 1432 | "outputs": [ 1433 | { 1434 | "data": { 1435 | "text/plain": [ 1436 | " 0.0\n", 1437 | " 1.0\n", 1438 | " 2.0\n", 1439 | " 3.0\n", 1440 | " 4.0\n", 1441 | " 5.0\n", 1442 | " 6.0\n", 1443 | " 7.0\n", 1444 | " 8.0\n", 1445 | "[torch.FloatStorage of size 9]" 1446 | ] 1447 | }, 1448 | "execution_count": 126, 1449 | "metadata": {}, 1450 | "output_type": "execute_result" 1451 | } 1452 | ], 1453 | "source": [ 1454 | "b.storage()" 1455 | ] 1456 | }, 1457 | { 1458 | "cell_type": "code", 1459 | "execution_count": 127, 1460 | "metadata": {}, 1461 | "outputs": [ 1462 | { 1463 | "data": { 1464 | "text/plain": [ 1465 | "(1,)" 1466 | ] 1467 | }, 1468 | "execution_count": 127, 1469 | "metadata": {}, 1470 | "output_type": "execute_result" 1471 | } 1472 | ], 1473 | "source": [ 1474 | "a.stride()" 1475 | ] 1476 | }, 1477 | { 1478 | "cell_type": "code", 1479 | "execution_count": 128, 1480 | "metadata": {}, 1481 | "outputs": [ 1482 | { 1483 | "data": { 1484 | "text/plain": [ 1485 | "(3, 1)" 1486 | ] 1487 | }, 1488 | "execution_count": 128, 1489 | "metadata": {}, 1490 | "output_type": "execute_result" 1491 | } 1492 | ], 1493 | "source": [ 1494 | "b.stride()" 1495 | ] 1496 | }, 1497 | { 1498 | "cell_type": "code", 1499 | "execution_count": null, 1500 | "metadata": {}, 1501 | "outputs": [], 1502 | "source": [] 1503 | } 1504 | ], 1505 | "metadata": { 1506 | "kernelspec": { 1507 | "display_name": "Python 3 (ipykernel)", 1508 | "language": "python", 1509 | "name": "python3" 1510 | }, 1511 | "language_info": { 1512 | "codemirror_mode": { 1513 | "name": "ipython", 1514 | "version": 3 1515 | }, 1516 | "file_extension": ".py", 1517 | "mimetype": "text/x-python", 1518 | "name": "python", 1519 | "nbconvert_exporter": "python", 1520 | "pygments_lexer": "ipython3", 1521 | "version": "3.7.11" 1522 | } 1523 | }, 1524 | "nbformat": 4, 1525 | "nbformat_minor": 4 1526 | } 1527 | -------------------------------------------------------------------------------- /sentiment_analysis/README.md: -------------------------------------------------------------------------------- 1 | ## Sentiment Analysis 2 | Experiments, Notes with different sentiment analysis models taken inspiration from the book "Deep learning with python" by françois chollet -------------------------------------------------------------------------------- /sentiment_analysis/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wingedrasengan927/pytorch-tutorials/20ddc1fda3750550c8426799187276619a7f7610/sentiment_analysis/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /sentiment_analysis/sentiment_analysis_v1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Sentiment Analysis" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "##### Improvement" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "1. Refactor and modularize helper functions and import it from utils\n", 22 | "2. Use pre-trained word embeddings\n", 23 | "3. Build vocab smartly" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "name": "stderr", 33 | "output_type": "stream", 34 | "text": [ 35 | "[nltk_data] Downloading package stopwords to\n", 36 | "[nltk_data] C:\\Users\\Neeraj\\AppData\\Roaming\\nltk_data...\n", 37 | "[nltk_data] Package stopwords is already up-to-date!\n", 38 | "[nltk_data] Downloading package punkt to\n", 39 | "[nltk_data] C:\\Users\\Neeraj\\AppData\\Roaming\\nltk_data...\n", 40 | "[nltk_data] Package punkt is already up-to-date!\n", 41 | "[nltk_data] Downloading package wordnet to\n", 42 | "[nltk_data] C:\\Users\\Neeraj\\AppData\\Roaming\\nltk_data...\n", 43 | "[nltk_data] Package wordnet is already up-to-date!\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "import pandas as pd\n", 49 | "import numpy as np\n", 50 | "import torch\n", 51 | "import torch.nn as nn\n", 52 | "import torch.functional as F\n", 53 | "import torch.optim as optim\n", 54 | "\n", 55 | "import matplotlib.pyplot as plt\n", 56 | "import seaborn as sns\n", 57 | "\n", 58 | "from utils import preprocess_text, pad_sequence, train_test_split_tensors, train_model" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "### Data preparation" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "#### Load the dataset" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 2, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "train = pd.read_csv(\"../data/IMDB Sentiment Analysis/Train.csv\")\n", 82 | "test = pd.read_csv(\"../data/IMDB Sentiment Analysis/Test.csv\")\n", 83 | "val = pd.read_csv(\"../data/IMDB Sentiment Analysis/Valid.csv\")" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 3, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "df_all = pd.concat([train, val, test], axis=0)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 4, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "text/html": [ 103 | "
\n", 104 | "\n", 117 | "\n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | "
textlabel
0I grew up (b. 1965) watching and loving the Th...0
1When I put this movie in my DVD player, and sa...0
2Why do people who do not know what a particula...0
3Even though I have great interest in Biblical ...0
4Im a die hard Dads Army fan and nothing will e...1
\n", 153 | "
" 154 | ], 155 | "text/plain": [ 156 | " text label\n", 157 | "0 I grew up (b. 1965) watching and loving the Th... 0\n", 158 | "1 When I put this movie in my DVD player, and sa... 0\n", 159 | "2 Why do people who do not know what a particula... 0\n", 160 | "3 Even though I have great interest in Biblical ... 0\n", 161 | "4 Im a die hard Dads Army fan and nothing will e... 1" 162 | ] 163 | }, 164 | "execution_count": 4, 165 | "metadata": {}, 166 | "output_type": "execute_result" 167 | } 168 | ], 169 | "source": [ 170 | "df_all.head()" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 5, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "data": { 180 | "text/plain": [ 181 | "50000" 182 | ] 183 | }, 184 | "execution_count": 5, 185 | "metadata": {}, 186 | "output_type": "execute_result" 187 | } 188 | ], 189 | "source": [ 190 | "dataset_size = len(df_all)\n", 191 | "dataset_size" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "#### Preprocess Dataset" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 6, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "df_all[\"word_list\"] = df_all[\"text\"].apply(lambda text: preprocess_text(text))" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 7, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "data": { 217 | "text/html": [ 218 | "
\n", 219 | "\n", 232 | "\n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | "
textlabelword_list
0I grew up (b. 1965) watching and loving the Th...0[grew, b, watching, loving, thunderbird, mate,...
1When I put this movie in my DVD player, and sa...0[put, movie, dvd, player, sat, coke, chip, exp...
2Why do people who do not know what a particula...0[people, know, particular, time, past, like, f...
3Even though I have great interest in Biblical ...0[even, though, great, interest, biblical, movi...
4Im a die hard Dads Army fan and nothing will e...1[im, die, hard, dad, army, fan, nothing, ever,...
\n", 274 | "
" 275 | ], 276 | "text/plain": [ 277 | " text label \\\n", 278 | "0 I grew up (b. 1965) watching and loving the Th... 0 \n", 279 | "1 When I put this movie in my DVD player, and sa... 0 \n", 280 | "2 Why do people who do not know what a particula... 0 \n", 281 | "3 Even though I have great interest in Biblical ... 0 \n", 282 | "4 Im a die hard Dads Army fan and nothing will e... 1 \n", 283 | "\n", 284 | " word_list \n", 285 | "0 [grew, b, watching, loving, thunderbird, mate,... \n", 286 | "1 [put, movie, dvd, player, sat, coke, chip, exp... \n", 287 | "2 [people, know, particular, time, past, like, f... \n", 288 | "3 [even, though, great, interest, biblical, movi... \n", 289 | "4 [im, die, hard, dad, army, fan, nothing, ever,... " 290 | ] 291 | }, 292 | "execution_count": 7, 293 | "metadata": {}, 294 | "output_type": "execute_result" 295 | } 296 | ], 297 | "source": [ 298 | "df_all.head()" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 8, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "df_all[\"word_count\"] = df_all[\"word_list\"].apply(lambda word_list: len(word_list))" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 9, 313 | "metadata": {}, 314 | "outputs": [ 315 | { 316 | "data": { 317 | "text/plain": [ 318 | "" 319 | ] 320 | }, 321 | "execution_count": 9, 322 | "metadata": {}, 323 | "output_type": "execute_result" 324 | }, 325 | { 326 | "data": { 327 | "text/plain": [ 328 | "
" 329 | ] 330 | }, 331 | "metadata": {}, 332 | "output_type": "display_data" 333 | }, 334 | { 335 | "data": { 336 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdUUlEQVR4nO3de5CldX3n8fd3emZ6LsIKsWHJDClIlk0FrBXDSFBj4iWGiZvK4G4QLFcwcR2WxSzGXApiVWJSNbsx3gBdiAQJmKCEeAnEyMQR8UKFgA1BYUDCKCAjU0yTlHGme21mer77x/M7zaE909Pd0+f8+nS/X1WnznN+57l8Z+j58PTv+T2/JzITSVLvLatdgCQtVQawJFViAEtSJQawJFViAEtSJctrF9AtGzduzK1bt9YuQ5IAolPjoj0Dfvrpp2uXIEnTWrQBLEkLnQEsSZUYwJJUiQEsSZUYwJJUiQEsSZUYwJJUiQEsSZUYwJJUiQEsSZUYwJJUiQEsSZUYwJJUiQE8R6Ojo4yOjtYuQ1IfM4AlqRIDWJIqMYAlqRIDWJIq6VoAR8SqiLg7Ir4eEdsj4g9L+9ERsS0iHinvR7Vtc2lE7IiIhyPizLb20yLi/vLdFRHR8flKktRPunkGPA68OjNfBJwKbIyIM4BLgNsy8yTgtvKZiDgZOBc4BdgIXBkRA2VfVwGbgZPKa2MX65aknuhaAGdjb/m4orwS2ARcX9qvB84qy5uAGzNzPDMfBXYAp0fEccCRmXlnZibwsbZtJKlvdbUPOCIGIuI+YDewLTPvAo7NzF0A5f2Ysvo64Im2zXeWtnVleWp7p+NtjojhiBgeGRmZ1z+LJM23rgZwZk5k5qnAepqz2RdOs3qnft2cpr3T8a7OzA2ZuWFoaGjW9UpSL/VkFERmfg/4Ek3f7VOlW4HyvrusthM4vm2z9cCTpX19h3ZJ6mvdHAUxFBHPL8urgV8AvgncApxfVjsfuLks3wKcGxGDEXEizcW2u0s3xZ6IOKOMfjivbRtJ6lvLu7jv44Dry0iGZcBNmfnZiLgTuCki3gp8BzgbIDO3R8RNwIPAfuCizJwo+7oQuA5YDdxaXpLU16IZWLD4bNiwIYeHh7u2/9ZEPGvXru3aMSQtGh3vXfBOOEmqxACWpEoMYEmqxACWpEoMYEmqxACWpEoMYEmqxACWpEoMYEmqxACWpEoMYEmqxACWpEoMYEmqxACepcxkdHSUxTqLnKTeMYBnaWxsjHMu38rY2FjtUiT1OQN4DpavXFW7BEmLgAEsSZUYwJJUiQEsSZUYwJJUiQEsSZUYwJJUiQEsSZUYwHPg3XCS5oMBPAcT+8a54No7vBtO0mExgOdoYHB17RIk9TkDWJIqMYAlqRIDWJIqMYAlqRIDWJIqMYAlqRIDWJIqMYAlqRIDWJIqMYAlqRIDWJIqMYAlqZKuBXBEHB8Rt0fEQxGxPSIuLu3vjojvRsR95fW6tm0ujYgdEfFwRJzZ1n5aRNxfvrsiIqJbdUtSryzv4r73A7+VmfdGxBHAPRGxrXz3wcx8X/vKEXEycC5wCvCjwBci4j9m5gRwFbAZ+Efgc8BG4NYu1i5JXde1M+DM3JWZ95blPcBDwLppNtkE3JiZ45n5KLADOD0ijgOOzMw7s5kB/WPAWd2qW5J6pSd9wBFxAvBi4K7S9PaI+EZEXBsRR5W2dcATbZvtLG3ryvLU9k7H2RwRwxExPDIyMp9/BEmad10P4Ih4HvAp4B2Z+X2a7oSfAE4FdgHvb63aYfOcpv2HGzOvzswNmblhaGjocEuXpK7qagBHxAqa8L0hMz8NkJlPZeZEZh4A/gw4vay+Ezi+bfP1wJOlfX2Hdknqa90cBRHAR4GHMvMDbe3Hta32euCBsnwLcG5EDEbEicBJwN2ZuQvYExFnlH2eB9zcrbolqVe6OQri5cCbgfsj4r7S9nvAGyPiVJpuhMeACwAyc3tE3AQ8SDOC4qIyAgLgQuA6YDXN6AdHQEjqe10L4My8g879t5+bZpstwJYO7cPAC+evOkmqzzvhJKkSA1iSKjGAJakSA1iSKjGAJakSA1iSKjGAD0NmMjo6SjNHkCTNjgF8GMbGxjjn8q2MjY3VLkVSHzKAD9PylatqlyCpTxnAklSJASxJlRjAklSJASxJlRjAc5SZjn6QdFgM4Dma2DfOxTcMM3HAMcCS5sYAPgwDgw5BkzR3BrAkVWIAS1IlBrAkVWIAS1IlBrAkVWIAS1IlBrAkVWIAS1IlBrAkVWIAS1IlBrAkVWIAS1IlBrAkVWIAS1IlBrAkVWIAS1IlBrAkVWIAS1IlBrAkVWIAS1IlBrAkVWIAS1IlXQvgiDg+Im6PiIciYntEXFzaj46IbRHxSHk/qm2bSyNiR0Q8HBFntrWfFhH3l++uiIjoVt1zkZmMjo6SmbVLkdRHunkGvB/4rcz8KeAM4KKIOBm4BLgtM08CbiufKd+dC5wCbASujIiBsq+rgM3ASeW1sYt1z9rY2BjnXL6VsbGx2qVI6iNdC+DM3JWZ95blPcBDwDpgE3B9We164KyyvAm4MTPHM/NRYAdwekQcBxyZmXdmc4r5sbZtFozlK1fVLkFSn+lJH3BEnAC8GLgLODYzd0ET0sAxZbV1wBNtm+0sbevK8tT2TsfZHBHDETE8MjIyr38GSZpvXQ/giHge8CngHZn5/elW7dCW07T/cGPm1Zm5ITM3DA0Nzb5YSeqhrgZwRKygCd8bMvPTpfmp0q1Aed9d2ncCx7dtvh54srSv79AuSX2tm6MgAvgo8FBmfqDtq1uA88vy+cDNbe3nRsRgRJxIc7Ht7tJNsScizij7PK9tG0nqW8u7uO+XA28G7o+I+0rb7wF/DNwUEW8FvgOcDZCZ2yPiJuBBmhEUF2XmRNnuQuA6YDVwa3lJUl/rWgBn5h107r8FeM1BttkCbOnQPgy8cP6qk6T6vBNOkioxgCWpEgNYkioxgCWpEgNYkioxgCWpEgNYkioxgCWpEgNYkioxgCWpEgNYkioxgCWpEgNYkioxgCWpEgNYkioxgCWpEgNYkioxgCWpkhkFcES8fCZtkqSZm+kZ8Idm2CZJmqFpH8oZES8FXgYMRcQ72746EhjoZmGStNgd6qnIK4HnlfWOaGv/PvCr3SpKkpaCaQM4M78MfDkirsvMx3tUU1/JTEZHR2uXIakPHeoMuGUwIq4GTmjfJjNf3Y2i+snEvnEuuPYOBo84qnYpkvrMTAP4r4E/Ba4BJrpXTn8aGFxduwRJfWimAbw/M6/qaiWStMTMdBja30bE/4yI4yLi6Narq5VJ0iI30zPg88v777S1JfDj81uOJC0dMwrgzDyx24VI0lIzowCOiPM6tWfmx+a3nP7WGo62du3aypVI6gcz7YJ4SdvyKuA1wL2AASxJczTTLojfaP8cE
f8O+IuuVCRJS8Rcp6McA06az0IkaamZaR/w39KMeoBmEp6fAm7qVlGStBTMtA/4fW3L+4HHM3NnF+qRpCVjRl0QZVKeb9LMiHYU8Ew3i5KkpWCmT8R4A3A3cDbwBuCuiHA6Skk6DDPtgngX8JLM3A0QEUPAF4BPdqswSVrsZjoKYlkrfIt/mcW2kqQOZhqiWyPi7yPiLRHxFuDvgM9Nt0FEXBsRuyPigba2d0fEdyPivvJ6Xdt3l0bEjoh4OCLObGs/LSLuL99dERExuz+iJC1M0wZwRPyHiHh5Zv4O8BHgPwEvAu4Erj7Evq8DNnZo/2BmnlpenyvHORk4FzilbHNlRLSeOXcVsJlm3PFJB9mnJPWdQ50BXwbsAcjMT2fmOzPzN2nOfi+bbsPM/ArwrzOsYxNwY2aOZ+ajwA7g9Ig4DjgyM+/MzKS59fmsGe5Tkha0QwXwCZn5jamNmTlM83iiuXh7RHyjdFG0nuOzDniibZ2dpW1dWZ7a3lFEbI6I4YgYHhkZmWN5ktQbhwrgVdN8N5fn8FwF/ARwKrALeH9p79Svm9O0d5SZV2fmhszcMDQ0NIfyJKl3DhXAX4uIt01tjIi3AvfM9mCZ+VRmTmTmAeDPgNPLVzuB49tWXQ88WdrXd2iXpL53qHHA7wA+ExFv4tnA3QCsBF4/24NFxHGZuat8fD3QGiFxC/DxiPgA8KM0F9vuzsyJiNgTEWcAdwHnAR+a7XElaSGaNoAz8yngZRHxKuCFpfnvMvOLh9pxRHwCeCXwgojYCfwB8MqIOJWmG+Ex4IJynO0RcRPwIM1cExdlZuvpyxfSjKhYDdxaXpLU92Y6H/DtwO2z2XFmvrFD80enWX8LsKVD+zDPhr8kLRrezSZJlRjAklSJATxPMpPR0VGa+0Uk6dAM4HkysW+cC669g7GxsdqlSOoTBvA8Ghicy70pkpYqA1iSKjGAJakSA1iSKjGAZ6E10kGS5oMBPAtjY2Ocd+U2Jg4cfKiZw9EkzZQBPEvLV043Q2cT0udcvtXhaJIOyQDugkOFtCSBASxJ1RjAklSJASxJlRjAklSJASxJlRjAklSJASxJlRjAklSJASxJlRjAklSJASxJlRjAklSJASxJlRjAklSJASxJlRjA8ygznYhd0owZwPNoYt84F98wzMSB9NFEkg7JAJ5nA4PN0zB8NJGkQzGAu8hHE0majgEsSZUYwJJUiQEsSZUYwJJUiQEsSZUYwF3QGgMsSdMxgLtgYt84F1x7BxMHvAlD0sF1LYAj4tqI2B0RD7S1HR0R2yLikfJ+VNt3l0bEjoh4OCLObGs/LSLuL99dERHRrZqnM9uz2oHB1V2sRtJi0M0z4OuAjVPaLgFuy8yTgNvKZyLiZOBc4JSyzZURMVC2uQrYDJxUXlP32RNjY2Ocd+U2z2olzZuuBXBmfgX41ynNm4Dry/L1wFlt7Tdm5nhmPgrsAE6PiOOAIzPzzmwmVfhY2zY9551tkuZTr/uAj83MXQDl/ZjSvg54om29naVtXVme2t5RRGyOiOGIGB4ZGZnXwudqdHTUC3KSOlooF+E69evmNO0dZebVmbkhMzcMDQ3NW3GS1A29DuCnSrcC5X13ad8JHN+23nrgydK+vkO7JPW9XgfwLcD5Zfl84Oa29nMjYjAiTqS52HZ36abYExFnlNEP57Vt03XO6Supm7o5DO0TwJ3AT0bEzoh4K/DHwGsj4hHgteUzmbkduAl4ENgKXJSZE2VXFwLX0FyY+xZwa7dqnso5fSV10/Ju7Tgz33iQr15zkPW3AFs6tA8DL5zH0mbFkQ+SumWhXISTpCXHAJakSgxgSarEAJakSgxgSarEAJakSgxgSaqka+OAFxMn05HUDZ4BS1IlBnAXOZeEpOkYwF3UejZca05gg1hSOwO4ywYGVzupj6SODOAecVIfSVMZwJJUiQEsSZUYwAfRGsHQjX16MU4SGMAHNTY2xnlXbmPiwPyFpRfjJLUzgKfRjQtnXoyT1GIAS1IlzgXRQXv/b2ayd+/eyhVJWowM4A5a/b+xfBAOjPO2j3yRGFjBijVra5cmaRExgA9i+cpVkxfglg2uYtnAisoVSVps7AOWpEoMYEmqxACWpEoMYEmqxACWpEoM4ApaE7RLWtoM4C7LTOd+kNSRAdxlE/vGufiG4Xmd1EfS4mAA98DAoBPwSPphBrAkVWIAS1IlBnCPTH3Chk/HkGQA98jEvnEuuPYO9k8cmByGds7lWyeXDWJp6TGAe2hgcPVkEI+NjbF85SofUyQtYQZwBQODq5/z2ccUSUuTASxJlVQJ4Ih4LCLuj4j7ImK4tB0dEdsi4pHyflTb+pdGxI6IeDgizqxRsyTNt5pnwK/KzFMzc0P5fAlwW2aeBNxWPhMRJwPnAqcAG4ErI2KgRsHzxduTJcHC6oLYBFxflq8HzmprvzEzxzPzUWAHcHrvy5s/3p4sCeoFcAKfj4h7ImJzaTs2M3cBlPdjSvs64Im2bXeWtr7m7cmSaj2U8+WZ+WREHANsi4hvTrNudGjreOpYwnwzwI/92I8dfpWS1EVVzoAz88nyvhv4DE2XwlMRcRxAed9dVt8JHN+2+XrgyYPs9+rM3JCZG4aGhrpVviTNi54HcESsjYgjWsvALwIPALcA55fVzgduLsu3AOdGxGBEnAicBNzd26olaf7V6II4FvhMRLSO//HM3BoRXwNuioi3At8BzgbIzO0RcRPwILAfuCgzJyrU3VWteSLWrl1buRJJvdLzAM7MbwMv6tD+L8BrDrLNFmBLl0vruakT9EhaWhbSMLQlpzUvhMPRpKXJAK5s6rwQkpYOA3gBaHVFHDhw4DnvTlEpLW4G8ALQ6op4+umnOefyrZPv3q4sLW4G8ALR6opoTU3pFJXS4mcAS1IlBvAC4Qxp0tJjAC8QzpAmLT0G8ALiDGnS0mIAS1IlBvAC1npkvaTFyQBeYKbOD9H67E0Z0uJjAC8wrZsy9k8cmDwD9qYMaXEygBeggcHVk0E8NjbmTRnSImUAL2DtE/XYFSEtPgbwAtZ+c8bY2JhdEdIiYwAvYFNvzrArQlpcDOApFtpTKjrdnGF3hLQ4GMBTjI2Ncd6V2xb0LcF2R0iLgwHcQT/8qt8PNUqaXo2nImuW2rtFOt2ksWbNGspTpiX1Ec+A+0D7zRl79+5l7969gF0RUr/zDLhPtG7OeNtHvkgMrGD1838EaLoiWmfFa9eurVmipFkygPvMssFVLBtYseBGa0iaPbsg+lSrW6I1WmPqk5UdoiYtfAZwHxsYXD0ZvK1Je3yistQ/DOA+12nSHoeoSf3BAF4E2iftAbsjpH5hAC8CrUl7WsE7sW+cN12xlYceeohzLt862UVhEEsLiwG8CLQm7Xlm/AfPXphbtoyLbxgmlg86XlhaoAzgRaI1aU97d0T7RD6t8cIOXZMWDgNYkirxRowpRkdHF/RMaHOVmezdu5cDBw5w4MABIoK1a9c6h4RUkQG8yLVfmGvdxrxsYBkDKwa57oKfZ+3atZNB3LqY5+Q+Um/YBbHItd8xt2xwFQODq5p+
4mXLePOHP89/ec9nGBkZAZrJfd5w2a089thj7Nmzx5ETUpcZwEvA1HHC7e3LVq6anGEtM4kILrj2jufcUee4Yqk7DOAlbmLfOJuv+Spnf+Bzk2fCy1aumryzLjMZGRnhv773b3j88cd5w2W3MjIyMhnCPh5JmjsDWM1wtWXLJrsqWuOK908cYGRkhPOu3AYDKxkbGyMieMtHvjw5pK01B0Vr+cCBZ+cs9oxZmp4B3ObZKR6XZmBMHUPc6j+O5YPPeULzspWrePzxx/nV93+WkZERBlYMMjIywhsuu7Vpf9/NnP2Bz02eMe/evZs9e/ZMdnO0RmQY0lrqol9+8CNiI3A5MABck5l/PN36GzZsyOHh4VkdY3R0lE1bbmRg8HksGxiYbN//zA9YNrDCtra2Z/Z8j+VrjiD3P9NMFj/+/8gIcv8EK9Y0E8O3t7WPvAA4/6ovEMtWcOWbX8LbP34PH3rjT7NmzZrJERlr1qyZvHPPURpaBDr+wPbFMLSIGAD+L/BaYCfwtYi4JTMfnK9jtM5+l69ctUTPf2dn2ZQ77wYGV7P/mR885+679rZWcL/5w59n4plxlq85gijdHgODq39oiNyVb34JF17/D8SyFVx3wc9PBvKvX/NV/vxtP8fQ0JAhrL7XF2fAEfFS4N2ZeWb5fClAZv6fg20z2zPg0dFRztpyIwcmcjIIWiaeGbdtSlsv9g8QAyvIiX20W7Z8BTdc/Es+gkk9dxg/c/17BgysA55o+7wT+JmpK0XEZmBz+bg3Ih6exTFeADw95wq7w5oO4tj//ZyPC6KmKaxpZpZKTVszc+PUxn4J4E7/9/ihU/fMvBq4ek4HiBjOzA1z2bZbrGlmrGlmrGlmellTv4yC2Akc3/Z5PfBkpVokaV70SwB/DTgpIk6MiJXAucAtlWuSpMPSF10Qmbk/It4O/D3NMLRrM3P7PB9mTl0XXWZNM2NNM2NNM9OzmvpiFIQkLUb90gUhSYuOASxJlRjANLc5R8TDEbEjIi7p4XGPj4jbI+KhiNgeEReX9qMjYltEPFLej2rb5tJS58MRcWaX6hqIiH+KiM8uhHrKcZ4fEZ+MiG+Wv6+X1qwrIn6z/Dd7ICI+ERGratQTEddGxO6IeKCtbdZ1RMRpEXF/+e6KOIzbDA9S03vLf7tvRMRnIuL5tWtq++63IyIj4gW9rAlgcnKUpfqiuaj3LeDHgZXA14GTe3Ts44CfLstHAP8MnAz8CXBJab8EeE9ZPrnUNwicWOoe6EJd7wQ+Dny2fK5aTznW9cB/L8srgefXqovmxqBHgdXl803AW2rUA/wc8NPAA21ts64DuBt4Kc2Y+1uBX5rnmn4RWF6W37MQairtx9Nc3H8ceEEva8pMz4CB04EdmfntzHwGuBHY1IsDZ+auzLy3LO8BHqL5x72JJnAo72eV5U3AjZk5npmPAjtK/fMmItYD/xm4pq25Wj2lpiNp/gF9FCAzn8nM71WuazmwOiKWA2toxqX3vJ7M/Arwr1OaZ1VHRBwHHJmZd2aTMh9r22ZeasrMz2fm/vLxH2nG8letqfgg8Ls898auntQEdkFA59uc1/W6iIg4AXgxcBdwbGbugiakgWPKar2o9TKaH8gDbW0164Hmt5MR4M9L18g1EbG2Vl2Z+V3gfcB3gF3Av2Xm52vV08Fs61hXlntV36/TnD1WrSkifgX4bmZ+fcpXPavJAJ7hbc5dLSDiecCngHdk5venW7VD27zVGhG/DOzOzHtmukk362mznObXx6sy88XAKM2v1lXqKn2qm2h+Pf1RYG1E/Lda9czCweroWX0R8S5gP3BDzZoiYg3wLuD3O33dq5oM4Mq3OUfECprwvSEzP12anyq/7lDed/eo1pcDvxIRj9F0xbw6Iv6yYj0tO4GdmXlX+fxJmkCuVdcvAI9m5khm7gM+DbysYj1TzbaOnTzbJdC1+iLifOCXgTeVX+Fr1vQTNP8D/Xr5eV8P3BsR/76nNR1OB/JieNGcXX27/MdoXYQ7pUfHDpp+pMumtL+X515E+ZOyfArPvTjwbbp30euVPHsRbiHU81XgJ8vyu0tNVeqimYlvO03fb9D0s/5GxXpO4LkXvGZdB83t/mfw7MWl181zTRuBB4GhKetVq2nKd4/x7EW43tXUjX8s/fYCXkczAuFbwLt6eNyfpfkV5hvAfeX1OuBHgNuAR8r70W3bvKvU+TCHeQX2ELW9kmcDeCHUcyowXP6u/gY4qmZdwB8C3wQeAP6i/GPteT3AJ2j6offRnKG9dS51ABvKn+VbwIcpd8nOY007aPpVWz/nf1q7pinfP0YJ4F7VlJneiixJtdgHLEmVGMCSVIkBLEmVGMCSVIkBLEmVGMCSVIkBrCUtIt4SER+uePxTI+J1tY6vugxgLSkRMVC7hilOpbn5RkuQAay+ERG/GxH/qyx/MCK+WJZfExF/GRFvLJNlPxAR72nbbm9E/FFE3AW8NCJ+LSL+OSK+TDP/xXTHPLZMIP718npZaX9nOc4DEfGO0nbClEnIfzsi3l2WvxQR74mIu8uxXxHNE77/CDgnIu6LiHPm8+9LC58BrH7yFeAVZXkD8LwymdHP0tx2+x7g1TRnlS+JiLPKumtp5gD4GZpbSP+QJnhfSzP59nSuAL6cmS+imQBoe0ScBvwazZwQZwBvi4gXz6D+5Zl5OvAO4A+ymX/694G/ysxTM/OvZrAPLSIGsPrJPcBpEXEEMA7cSRPErwC+B3wpmxnKWtMd/lzZboJmxjloQrO13jPAoULv1cBVAJk5kZn/RhP4n8nM0czcSzMb2ium2UdLa7a7e2gmhtESZwCrb2Qz9eNjNGef/0AzQ9qraKYW/M40m/4gMyfad3WYpRzsOWD7ee6/qVVTvh8v7xM0s/BpiTOA1W++Avx2ef8q8D9oZtf6R+DnI+IF5ULbG4Evd9j+LuCVEfEjpfvi7EMc7zbgQph8WOmR5dhnRcSa8mSO15dangKOKfsepJn79lD20DwPUEuQAax+81Wah5nemZlPAT8AvprNo3cuBW6nmcv13sy8eerGZb1303RffAG49xDHuxh4VUTcT9N1cEo2z/G7juYBjXcB12TmP5Uz9D8qbZ+lma7yUG4HTvYi3NLkdJSSVIlnwJJUiRcCJCYfFjm1P/ivM3NLjXq0NNgFIUmV2AUhSZUYwJJUiQEsSZUYwJJUyf8HrRsv0HMhJAwAAAAASUVORK5CYII=\n", 337 | "text/plain": [ 338 | "
" 339 | ] 340 | }, 341 | "metadata": { 342 | "needs_background": "light" 343 | }, 344 | "output_type": "display_data" 345 | } 346 | ], 347 | "source": [ 348 | "plt.figure(figsize=(16, 8))\n", 349 | "sns.displot(df_all[\"word_count\"])" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 10, 355 | "metadata": {}, 356 | "outputs": [ 357 | { 358 | "name": "stdout", 359 | "output_type": "stream", 360 | "text": [ 361 | "mean: 121.2643 median: 90.0 std: 91.36838495423598 max: 1438 min: 3\n" 362 | ] 363 | } 364 | ], 365 | "source": [ 366 | "print(f\"mean: {df_all['word_count'].mean()} \\\n", 367 | " median: {df_all['word_count'].median()} \\\n", 368 | " std: {df_all['word_count'].std()} \\\n", 369 | " max: {df_all['word_count'].max()} \\\n", 370 | " min: {df_all['word_count'].min()}\")" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 11, 376 | "metadata": {}, 377 | "outputs": [ 378 | { 379 | "data": { 380 | "text/plain": [ 381 | "272" 382 | ] 383 | }, 384 | "execution_count": 11, 385 | "metadata": {}, 386 | "output_type": "execute_result" 387 | } 388 | ], 389 | "source": [ 390 | "max_sequence_length = int(df_all['word_count'].median() + 2 * df_all['word_count'].std())\n", 391 | "max_sequence_length" 392 | ] 393 | }, 394 | { 395 | "cell_type": "markdown", 396 | "metadata": {}, 397 | "source": [ 398 | "#### Build Vocab" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": 12, 404 | "metadata": {}, 405 | "outputs": [ 406 | { 407 | "data": { 408 | "text/plain": [ 409 | "0 None\n", 410 | "1 None\n", 411 | "2 None\n", 412 | "3 None\n", 413 | "4 None\n", 414 | " ... \n", 415 | "4995 None\n", 416 | "4996 None\n", 417 | "4997 None\n", 418 | "4998 None\n", 419 | "4999 None\n", 420 | "Name: word_list, Length: 50000, dtype: object" 421 | ] 422 | }, 423 | "execution_count": 12, 424 | "metadata": {}, 425 | "output_type": "execute_result" 426 | } 427 | ], 428 | "source": [ 429 | "all_words_list = []\n", 430 | "df_all[\"word_list\"].apply(lambda word_list: all_words_list.extend(word_list))" 431 | ] 432 | }, 433 | { 434 | "cell_type": "markdown", 435 | "metadata": {}, 436 | "source": [ 437 | "##### Get frequency of the words" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 13, 443 | "metadata": {}, 444 | "outputs": [], 445 | "source": [ 446 | "from collections import Counter" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 14, 452 | "metadata": {}, 453 | "outputs": [], 454 | "source": [ 455 | "word_frequency = Counter(all_words_list)" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 15, 461 | "metadata": {}, 462 | "outputs": [ 463 | { 464 | "name": "stdout", 465 | "output_type": "stream", 466 | "text": [ 467 | "mean: 40.14523412257005 median: 1.0 std: 639.2531565816863 max: 114803 min: 1\n" 468 | ] 469 | } 470 | ], 471 | "source": [ 472 | "print(f\"mean: {np.mean(list(word_frequency.values()))} \\\n", 473 | " median: {np.median(list(word_frequency.values()))} \\\n", 474 | " std: {np.std(list(word_frequency.values()))} \\\n", 475 | " max: {max(word_frequency.values())} \\\n", 476 | " min: {min(word_frequency.values())}\")" 477 | ] 478 | }, 479 | { 480 | "cell_type": "markdown", 481 | "metadata": {}, 482 | "source": [ 483 | "We can see that there is a huge standard deviation" 484 | ] 485 | }, 486 | { 487 | "cell_type": "markdown", 488 | "metadata": {}, 489 | "source": [ 490 | "For Convinience we'll take the min_frequency of the word to appear in the 
vocabulary as 20" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": 16, 496 | "metadata": {}, 497 | "outputs": [], 498 | "source": [ 499 | "min_frequency = 20" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": 17, 505 | "metadata": {}, 506 | "outputs": [], 507 | "source": [ 508 | "new_word_frequency = {word: freq for word, freq in word_frequency.items() if freq > min_frequency}" 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": 18, 514 | "metadata": {}, 515 | "outputs": [ 516 | { 517 | "data": { 518 | "text/plain": [ 519 | "17018" 520 | ] 521 | }, 522 | "execution_count": 18, 523 | "metadata": {}, 524 | "output_type": "execute_result" 525 | } 526 | ], 527 | "source": [ 528 | "len(new_word_frequency)" 529 | ] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "metadata": {}, 534 | "source": [ 535 | "#### Load GloVe Vectors" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": 19, 541 | "metadata": {}, 542 | "outputs": [], 543 | "source": [ 544 | "import os" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": 20, 550 | "metadata": {}, 551 | "outputs": [], 552 | "source": [ 553 | "glove_dir = \"../data/glove.6B/\"" 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": 21, 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [ 562 | "embedding_mapping = {}" 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "execution_count": null, 568 | "metadata": {}, 569 | "outputs": [], 570 | "source": [ 571 | "with open(os.path.join(glove_dir, \"glove.6B.100d.txt\"), \"r\", encoding=\"utf-8\") as f:\n", 572 | " for line in f.readlines():\n", 573 | " values = line.split()\n", 574 | " word = values[0]\n", 575 | " coeffs = np.asarray(values[1:], dtype='float32')\n", 576 | " embedding_mapping[word] = coeffs" 577 | ] 578 | }, 579 | { 580 | "cell_type": "markdown", 581 | "metadata": {}, 582 | "source": [ 583 | "##### Create Embedding Matrix" 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": null, 589 | "metadata": {}, 590 | "outputs": [], 591 | "source": [ 592 | "vocab_size = len(new_word_frequency)\n", 593 | "embedding_dim = 100" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": null, 599 | "metadata": {}, 600 | "outputs": [], 601 | "source": [ 602 | "embedding_matrix = torch.randn((vocab_size + 2, embedding_dim), dtype=torch.float32)" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": null, 608 | "metadata": {}, 609 | "outputs": [], 610 | "source": [ 611 | "embedding_matrix[0] = torch.zeros(embedding_dim) # making the padding idx zero" 612 | ] 613 | }, 614 | { 615 | "cell_type": "markdown", 616 | "metadata": {}, 617 | "source": [ 618 | "##### Create Word2Idx" 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": null, 624 | "metadata": {}, 625 | "outputs": [], 626 | "source": [ 627 | "word2idx = {word: idx + 2 for idx, word in enumerate(new_word_frequency)}" 628 | ] 629 | }, 630 | { 631 | "cell_type": "markdown", 632 | "metadata": {}, 633 | "source": [ 634 | "##### Populate embedding matrix" 635 | ] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "execution_count": null, 640 | "metadata": {}, 641 | "outputs": [], 642 | "source": [ 643 | "rare_words = [] # words not in embedding\n", 644 | "for word, idx in word2idx.items():\n", 645 | " vector = embedding_mapping.get(word)\n", 646 | " if vector is None:\n", 647 | " rare_words.append(word)\n", 
648 | " continue\n", 649 | " tensor = torch.from_numpy(vector)\n", 650 | " embedding_matrix[idx] = tensor" 651 | ] 652 | }, 653 | { 654 | "cell_type": "code", 655 | "execution_count": null, 656 | "metadata": {}, 657 | "outputs": [], 658 | "source": [ 659 | "len(rare_words)" 660 | ] 661 | }, 662 | { 663 | "cell_type": "markdown", 664 | "metadata": {}, 665 | "source": [ 666 | "Only 435 words do not have a pretrained embedding, which is good" 667 | ] 668 | }, 669 | { 670 | "cell_type": "markdown", 671 | "metadata": {}, 672 | "source": [ 673 | "#### Preparing data for training" 674 | ] 675 | }, 676 | { 677 | "cell_type": "markdown", 678 | "metadata": {}, 679 | "source": [ 680 | "##### Converting words to indices" 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": null, 686 | "metadata": {}, 687 | "outputs": [], 688 | "source": [ 689 | "unk_idx = 1" 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": null, 695 | "metadata": {}, 696 | "outputs": [], 697 | "source": [ 698 | "df_all[\"word_indices\"] = df_all[\"word_list\"].apply(lambda word_list: [word2idx.get(word, unk_idx) for word in word_list])" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": null, 704 | "metadata": {}, 705 | "outputs": [], 706 | "source": [ 707 | "df_all[\"padded_word_indices\"] = df_all[\"word_indices\"].apply(lambda word_index_list: pad_sequence(word_index_list, max_sequence_length))" 708 | ] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": null, 713 | "metadata": {}, 714 | "outputs": [], 715 | "source": [ 716 | "n_sequences = len(df_all)" 717 | ] 718 | }, 719 | { 720 | "cell_type": "code", 721 | "execution_count": null, 722 | "metadata": {}, 723 | "outputs": [], 724 | "source": [ 725 | "all_data = np.zeros((n_sequences, max_sequence_length), dtype=np.int64)\n", 726 | "for i, _list in enumerate(df_all[\"padded_word_indices\"]):\n", 727 | " all_data[i] = np.asarray(_list)" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": null, 733 | "metadata": {}, 734 | "outputs": [], 735 | "source": [ 736 | "len(df_all[\"word_list\"])" 737 | ] 738 | }, 739 | { 740 | "cell_type": "code", 741 | "execution_count": null, 742 | "metadata": {}, 743 | "outputs": [], 744 | "source": [ 745 | "all_labels = df_all[\"label\"].values" 746 | ] 747 | }, 748 | { 749 | "cell_type": "markdown", 750 | "metadata": {}, 751 | "source": [ 752 | "##### Splitting the data and loading into a dataloader" 753 | ] 754 | }, 755 | { 756 | "cell_type": "code", 757 | "execution_count": null, 758 | "metadata": {}, 759 | "outputs": [], 760 | "source": [ 761 | "train_dataloader, test_dataloader = train_test_split_tensors(all_data, all_labels)" 762 | ] 763 | }, 764 | { 765 | "cell_type": "markdown", 766 | "metadata": {}, 767 | "source": [ 768 | "### Building the Model" 769 | ] 770 | }, 771 | { 772 | "cell_type": "code", 773 | "execution_count": null, 774 | "metadata": {}, 775 | "outputs": [], 776 | "source": [ 777 | "class LinearClassifier(nn.Module):\n", 778 | " def __init__(self, VOCAB_SIZE, EMBEDDING_DIM, SEQUENCE_LENGTH, weights):\n", 779 | " super().__init__()\n", 780 | " self.EMBEDDING_DIM = EMBEDDING_DIM\n", 781 | " self.SEQUENCE_LENGTH = SEQUENCE_LENGTH\n", 782 | " self.embedding = nn.Embedding(VOCAB_SIZE + 2, EMBEDDING_DIM, padding_idx=0, _weight = weights)\n", 783 | " self.linear1 = nn.Linear(EMBEDDING_DIM * SEQUENCE_LENGTH, 64)\n", 784 | " self.linear2 = nn.Linear(64, 2)\n", 785 | " \n", 786 | " def forward(self, inputs):\n", 787 | " 
embedding_vectors = self.embedding(inputs)\n", 788 | " flattened_embeddings = embedding_vectors.view(-1, self.EMBEDDING_DIM * self.SEQUENCE_LENGTH)\n", 789 | " out = torch.tanh(self.linear1(flattened_embeddings))\n", 790 | " out = self.linear2(out)\n", 791 | " \n", 792 | " return out" 793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": null, 798 | "metadata": {}, 799 | "outputs": [], 800 | "source": [ 801 | "model = LinearClassifier(vocab_size, embedding_dim, max_sequence_length, embedding_matrix)" 802 | ] 803 | }, 804 | { 805 | "cell_type": "markdown", 806 | "metadata": {}, 807 | "source": [ 808 | "### Training the model" 809 | ] 810 | }, 811 | { 812 | "cell_type": "code", 813 | "execution_count": null, 814 | "metadata": {}, 815 | "outputs": [], 816 | "source": [ 817 | "loss = nn.CrossEntropyLoss()\n", 818 | "optimizer = optim.RMSprop(model.parameters(), lr=1e-3)\n", 819 | "n_epochs = 50" 820 | ] 821 | }, 822 | { 823 | "cell_type": "code", 824 | "execution_count": null, 825 | "metadata": {}, 826 | "outputs": [], 827 | "source": [ 828 | "device = torch.device('cuda') " 829 | ] 830 | }, 831 | { 832 | "cell_type": "code", 833 | "execution_count": null, 834 | "metadata": {}, 835 | "outputs": [], 836 | "source": [ 837 | "model.to(device=device)" 838 | ] 839 | }, 840 | { 841 | "cell_type": "code", 842 | "execution_count": null, 843 | "metadata": {}, 844 | "outputs": [], 845 | "source": [ 846 | "train_acc_list, val_acc_list, train_loss_list, val_loss_list = train_model(n_epochs, model, train_dataloader, test_dataloader, loss, optimizer, device)" 847 | ] 848 | }, 849 | { 850 | "cell_type": "markdown", 851 | "metadata": {}, 852 | "source": [ 853 | "##### Plot loss" 854 | ] 855 | }, 856 | { 857 | "cell_type": "code", 858 | "execution_count": null, 859 | "metadata": {}, 860 | "outputs": [], 861 | "source": [ 862 | "fig, axes = plt.subplots(1, 1, figsize=(8, 6))\n", 863 | "\n", 864 | "axes.plot(list(range(n_epochs)), train_loss_list, label=\"train loss\")\n", 865 | "axes.plot(list(range(n_epochs)), val_loss_list, color='orange', label=\"val loss\")\n", 866 | "\n", 867 | "axes.set_xlabel(\"Epoch\")\n", 868 | "axes.set_ylabel(\"Loss\")\n", 869 | "\n", 870 | "plt.legend()" 871 | ] 872 | }, 873 | { 874 | "cell_type": "markdown", 875 | "metadata": {}, 876 | "source": [ 877 | "##### Plot accuracy" 878 | ] 879 | }, 880 | { 881 | "cell_type": "code", 882 | "execution_count": null, 883 | "metadata": {}, 884 | "outputs": [], 885 | "source": [ 886 | "fig, axes = plt.subplots(1, 1, figsize=(8, 6))\n", 887 | "\n", 888 | "axes.plot(list(range(n_epochs)), train_acc_list, label=\"train Accuracy\")\n", 889 | "axes.plot(list(range(n_epochs)), val_acc_list, color='orange', label=\"val Accuracy\")\n", 890 | "\n", 891 | "axes.set_xlabel(\"Epoch\")\n", 892 | "axes.set_ylabel(\"Accuracy\")\n", 893 | "\n", 894 | "plt.legend()" 895 | ] 896 | }, 897 | { 898 | "cell_type": "markdown", 899 | "metadata": {}, 900 | "source": [ 901 | "### Thoughts" 902 | ] 903 | }, 904 | { 905 | "cell_type": "markdown", 906 | "metadata": {}, 907 | "source": [ 908 | "1. Even with pre-trained embeddings, we didn't do much better than the previous time where we didn't use any. Reasons:\n", 909 | " - the dataset was sufficiently big in our case\n", 910 | " - This tells us to use pretrained embeddings when the dataset is small\n", 911 | "2. Also observe that the model is not able generalize i.e the validation accuracy doesn't go beyond a certain point while the training accuracy is high. 
Reasons for this:\n", 912 | " - We are using a linear model that treats each word independently, so it cannot capture the context of a sentence and instead tends to memorize the training sentences.\n", 913 | " - This means the model cannot recognize sentences that are similar but not identical. For example, 'I'm going to bed' and 'I'm sleeping on the bed' look completely different to the model even though they mean nearly the same thing.\n", 914 | " - The model also has a large number of parameters, which makes it easy to memorize the training data; hence the high training accuracy." 915 | ] 916 | }, 917 | { 918 | "cell_type": "code", 919 | "execution_count": null, 920 | "metadata": {}, 921 | "outputs": [], 922 | "source": [] 923 | } 924 | ], 925 | "metadata": { 926 | "kernelspec": { 927 | "display_name": "Python 3", 928 | "language": "python", 929 | "name": "python3" 930 | }, 931 | "language_info": { 932 | "codemirror_mode": { 933 | "name": "ipython", 934 | "version": 3 935 | }, 936 | "file_extension": ".py", 937 | "mimetype": "text/x-python", 938 | "name": "python", 939 | "nbconvert_exporter": "python", 940 | "pygments_lexer": "ipython3", 941 | "version": "3.8.5" 942 | } 943 | }, 944 | "nbformat": 4, 945 | "nbformat_minor": 4 946 | } 947 | -------------------------------------------------------------------------------- /sentiment_analysis/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Contains various utilities for data preprocessing and model training 3 | ''' 4 | 5 | import re 6 | import torch 7 | from torch.utils.data import TensorDataset 8 | from torch.utils.data import DataLoader 9 | import nltk 10 | from sklearn.model_selection import train_test_split 11 | from tqdm import tqdm 12 | 13 | nltk.download("stopwords") 14 | nltk.download("punkt") 15 | nltk.download("wordnet") 16 | 17 | from nltk.corpus import stopwords 18 | from nltk.tokenize import word_tokenize 19 | from nltk.stem import WordNetLemmatizer 20 | 21 | stop_words = stopwords.words("english") 22 | lm = WordNetLemmatizer() 23 | 24 | # Helper Functions 25 | 26 | def clean_text(text): 27 | ''' 28 | Remove Punctuation and Numbers 29 | ''' 30 | text = text.lower() 31 | text = text.replace("<br />
", "") 32 | text = re.sub(r"(@[A-Za-z0-9]+)|([^0-9A-Za-z \t])|(\w+:\/\/\S+)|^rt|http.+?", "", text) 33 | # remove numbers 34 | text = re.sub(r"\d+", "", text) 35 | 36 | return text 37 | 38 | def remove_stop_words(word_list): 39 | ''' 40 | Remove stop words like ours, if etc 41 | ''' 42 | new_word_list = [word for word in word_list if word not in stop_words] 43 | 44 | return new_word_list 45 | 46 | def remove_br(word_list): 47 | ''' 48 | Remove br at the end of the word in a word list 49 | ''' 50 | new_word_list = [re.sub("br$", "", word) for word in word_list] 51 | return new_word_list 52 | 53 | def tokenize_text(text): 54 | ''' 55 | Tokenize Text 56 | ''' 57 | return word_tokenize(text) 58 | 59 | def lemmatize_text(word_list): 60 | ''' 61 | Lemmatize text 62 | ''' 63 | new_word_list = [lm.lemmatize(word) for word in word_list] 64 | return new_word_list 65 | 66 | def pad_sequence(sequence, max_sequence_length): 67 | ''' 68 | Pad sequences to max_sequence_length with padding_idx = 0 69 | ''' 70 | padding_idx = 0 71 | sequence_length = len(sequence) 72 | if sequence_length > max_sequence_length: 73 | sequence = sequence[:max_sequence_length] 74 | elif sequence_length < max_sequence_length: 75 | for i in range(max_sequence_length - sequence_length): 76 | sequence.append(padding_idx) 77 | return sequence 78 | 79 | def preprocess_text(text): 80 | ''' 81 | clean sentences, remove punctuation, perform tokenization, stop word removal, lemmatization. 82 | ''' 83 | text = clean_text(text) 84 | word_list = tokenize_text(text) 85 | word_list = remove_br(word_list) 86 | word_list = remove_stop_words(word_list) 87 | word_list = lemmatize_text(word_list) 88 | 89 | return word_list 90 | 91 | def train_test_split_tensors(X, y, train_size=0.8, batch_size=32): 92 | ''' 93 | X - training data 94 | numpy array 95 | shape: (n_sequences, max_sequence_length) 96 | y - target 97 | numpy array 98 | shape: (n_sequences) 99 | train_size - % of the training data 100 | float 101 | ''' 102 | X_train, X_test, Y_train, Y_test = train_test_split(X, y, train_size=0.8, shuffle=True) 103 | print("Data Split in the following way:") 104 | print(f"X train: {X_train.shape}\n X test: {X_test.shape} \n Y train: {Y_train.shape}\n Y test: {Y_test.shape}") 105 | print("Creating dataloaders...") 106 | 107 | X_train_tensor = torch.from_numpy(X_train) 108 | X_test_tensor = torch.from_numpy(X_test) 109 | Y_train_tensor = torch.from_numpy(Y_train) 110 | Y_test_tensor = torch.from_numpy(Y_test) 111 | 112 | train_ds = TensorDataset(X_train_tensor, Y_train_tensor) 113 | test_ds = TensorDataset(X_test_tensor, Y_test_tensor) 114 | 115 | train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True) 116 | test_dataloader = DataLoader(test_ds, batch_size=batch_size, shuffle=True) 117 | 118 | print("Done") 119 | 120 | return train_dataloader, test_dataloader 121 | 122 | def get_accuracy(model, train_dataloader, val_dataloader, device): 123 | model.eval() 124 | result = dict() 125 | for mode, loader in [("train", train_dataloader), ("val", val_dataloader)]: 126 | corrects = 0 127 | total = 0 128 | with torch.no_grad(): 129 | for sequences, labels in loader: 130 | sequences = sequences.to(device=device) 131 | labels = labels.to(device=device) 132 | outputs = model(sequences) 133 | _, preds = torch.max(outputs, dim=1) 134 | total += labels.shape[0] 135 | corrects += int(sum(preds == labels)) 136 | 137 | result[mode] = round(corrects/total, 4) * 100 138 | 139 | print(f"Training Accuracy: {result['train']}") 140 | print(f"Validation 
Accuracy: {result['val']}") 141 | print("-----------") 142 | 143 | return result 144 | 145 | def train_model(n_epochs, model, train_dataloader, test_dataloader, loss, optimizer, device): 146 | train_loss_list = [] 147 | val_loss_list = [] 148 | train_acc_list = [] 149 | val_acc_list = [] 150 | for epoch in tqdm(range(n_epochs)): 151 | model.train() 152 | 153 | # train 154 | cummulative_loss = 0 155 | n_batches = 0 156 | for sequences, labels in train_dataloader: 157 | sequences = sequences.to(device=device) 158 | labels = labels.to(device=device) 159 | outputs = model(sequences) 160 | train_loss = loss(outputs, labels) 161 | 162 | optimizer.zero_grad() 163 | train_loss.backward() 164 | optimizer.step() 165 | 166 | cummulative_loss += train_loss 167 | n_batches += 1 168 | 169 | loss_per_epoch = cummulative_loss / n_batches 170 | train_loss_list.append(loss_per_epoch) 171 | 172 | # val 173 | cummulative_loss_val = 0 174 | n_batches_val = 0 175 | for sequences, labels in test_dataloader: 176 | sequences = sequences.to(device=device) 177 | labels = labels.to(device=device) 178 | with torch.no_grad(): 179 | outputs = model(sequences) 180 | val_loss = loss(outputs, labels) 181 | 182 | cummulative_loss_val += val_loss 183 | n_batches_val += 1 184 | 185 | loss_per_epoch_val = cummulative_loss_val / n_batches_val 186 | val_loss_list.append(loss_per_epoch_val) 187 | 188 | acc = get_accuracy(model, train_dataloader, test_dataloader, device) 189 | train_acc_list.append(acc["train"]) 190 | val_acc_list.append(acc["val"]) 191 | 192 | return train_acc_list, val_acc_list, train_loss_list, val_loss_list 193 | 194 | -------------------------------------------------------------------------------- /time series/sequence_data.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wingedrasengan927/pytorch-tutorials/20ddc1fda3750550c8426799187276619a7f7610/time series/sequence_data.jpg -------------------------------------------------------------------------------- /time series/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Contains various utilities from data preprocessing, and model training 3 | ''' 4 | 5 | import torch 6 | from tqdm import tqdm 7 | 8 | def train_model(n_epochs, model, train_dataloader, test_dataloader, loss, optimizer, device): 9 | train_loss_list = [] 10 | val_loss_list = [] 11 | for epoch in tqdm(range(n_epochs)): 12 | 13 | # train 14 | model.train() 15 | 16 | cummulative_loss = 0 17 | n_batches = 0 18 | for sequences, labels in train_dataloader: 19 | sequences = sequences.to(device=device) 20 | labels = labels.to(device=device) 21 | labels = labels.unsqueeze(1) 22 | 23 | outputs = model(sequences) 24 | train_loss = loss(outputs, labels) 25 | 26 | optimizer.zero_grad() 27 | train_loss.backward() 28 | optimizer.step() 29 | 30 | cummulative_loss += train_loss 31 | n_batches += 1 32 | 33 | loss_per_epoch = cummulative_loss / n_batches 34 | train_loss_list.append(loss_per_epoch) 35 | 36 | # val 37 | model.eval() 38 | 39 | cummulative_loss_val = 0 40 | n_batches_val = 0 41 | for sequences, labels in test_dataloader: 42 | sequences = sequences.to(device=device) 43 | labels = labels.to(device=device) 44 | labels = labels.unsqueeze(1) 45 | 46 | with torch.no_grad(): 47 | outputs = model(sequences) 48 | val_loss = loss(outputs, labels) 49 | 50 | cummulative_loss_val += val_loss 51 | n_batches_val += 1 52 | 53 | loss_per_epoch_val = cummulative_loss_val / n_batches_val 54 | 
val_loss_list.append(loss_per_epoch_val) 55 | 56 | return train_loss_list, val_loss_list 57 | 58 | --------------------------------------------------------------------------------
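Note on the "Thoughts" cell in the sentiment analysis notebook above: it attributes the generalization gap to the linear classifier flattening one embedding per position, so word order and exact wording are all the model can latch onto. The sketch below is not part of the repository; it is a minimal, order-insensitive variant written under the assumption that the embedding_matrix built in the notebook (padding index 0, two output classes) is available. The class name MeanEmbeddingClassifier and the 64-unit hidden layer are illustrative choices, and nn.Embedding.from_pretrained is used in place of the internal _weight keyword.

import torch
import torch.nn as nn

class MeanEmbeddingClassifier(nn.Module):
    def __init__(self, weights, n_classes=2, padding_idx=0):
        super().__init__()
        # from_pretrained copies the pretrained matrix; freeze=False keeps it trainable
        self.embedding = nn.Embedding.from_pretrained(weights, freeze=False, padding_idx=padding_idx)
        self.padding_idx = padding_idx
        self.linear1 = nn.Linear(weights.shape[1], 64)
        self.linear2 = nn.Linear(64, n_classes)

    def forward(self, inputs):
        # inputs: (batch, seq_len) tensor of word indices
        embedded = self.embedding(inputs)                          # (batch, seq_len, dim)
        mask = (inputs != self.padding_idx).unsqueeze(-1).float()  # zero out padding positions
        mean = (embedded * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        out = torch.tanh(self.linear1(mean))
        return self.linear2(out)

# usage (hypothetical): model = MeanEmbeddingClassifier(embedding_matrix)

Averaging the non-padding token embeddings collapses the EMBEDDING_DIM * SEQUENCE_LENGTH flattened input down to a single EMBEDDING_DIM-sized vector, which cuts the parameter count sharply and makes paraphrases that share vocabulary look similar to the model rather than entirely distinct.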