├── test-folder ├── data │ ├── idl-pretrain-dataset │ │ ├── dataset_dict.json │ │ └── train │ │ │ ├── data-00000-of-00001.arrow │ │ │ ├── cache-03c63454d8d0c442.arrow │ │ │ ├── cache-04b6caffbd677b49.arrow │ │ │ ├── cache-1b7d50c892e7d266.arrow │ │ │ ├── cache-3bbe6122452c2ec1.arrow │ │ │ ├── cache-7c68162c665bbdff.arrow │ │ │ ├── cache-bb43aa8c7ec63c3d.arrow │ │ │ ├── cache-cbf1d5fc1918aa19.arrow │ │ │ ├── state.json │ │ │ └── dataset_info.json │ ├── ocr_files.txt │ ├── pdf_files.txt │ ├── sample_rvl_cdip_dataset │ │ ├── RVL-CDIP Invoice Class Dataset │ │ │ ├── 0001139724.tif │ │ │ ├── 0001139827.tif │ │ │ ├── 0001139841.tif │ │ │ ├── 0001139848.tif │ │ │ ├── 0001140924.tif │ │ │ ├── 0001140927.tif │ │ │ ├── 0001140933.tif │ │ │ ├── 0001140965.tif │ │ │ ├── 0001140968.tif │ │ │ └── 0001140975.tif │ │ └── IDL Pre-training dataset │ │ │ └── pdfs │ │ │ └── f │ │ │ └── f │ │ │ └── b │ │ │ └── b │ │ │ └── ffbb0000 │ │ │ └── ffbb0000.pdf │ ├── ocr_pdf.csv │ └── true_data_path.csv └── code │ ├── __pycache__ │ ├── modeling.cpython-310.pyc │ └── utils_modeling.cpython-310.pyc │ ├── utils_modeling.py │ ├── utils_data_prep.py │ ├── modeling.py │ └── pretraining-modeling.ipynb ├── README.md ├── all-requirements.txt ├── src ├── utils_modeling.py ├── utils_data_prep.py └── modeling.py └── LICENSE /test-folder/data/idl-pretrain-dataset/dataset_dict.json: -------------------------------------------------------------------------------- 1 | {"splits": ["train"]} -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DocFormerV2 [work in progress] 2 | This repo consists of my implementation of DocFormerV2 3 | -------------------------------------------------------------------------------- /test-folder/data/ocr_files.txt: -------------------------------------------------------------------------------- 1 | ../data/sample_rvl_cdip_dataset/IDL Pre-training dataset/ocrs/f/f/b/b/ffbb0000/ffbb0000.json 2 | -------------------------------------------------------------------------------- /test-folder/data/pdf_files.txt: -------------------------------------------------------------------------------- 1 | ../data/sample_rvl_cdip_dataset/IDL Pre-training dataset/pdfs/f/f/b/b/ffbb0000/ffbb0000.pdf 2 | -------------------------------------------------------------------------------- /test-folder/code/__pycache__/modeling.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/code/__pycache__/modeling.cpython-310.pyc -------------------------------------------------------------------------------- /test-folder/code/__pycache__/utils_modeling.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/code/__pycache__/utils_modeling.cpython-310.pyc -------------------------------------------------------------------------------- /test-folder/data/idl-pretrain-dataset/train/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/idl-pretrain-dataset/train/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /test-folder/data/idl-pretrain-dataset/train/cache-03c63454d8d0c442.arrow: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/idl-pretrain-dataset/train/cache-03c63454d8d0c442.arrow -------------------------------------------------------------------------------- /test-folder/data/idl-pretrain-dataset/train/cache-04b6caffbd677b49.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/idl-pretrain-dataset/train/cache-04b6caffbd677b49.arrow -------------------------------------------------------------------------------- /test-folder/data/idl-pretrain-dataset/train/cache-1b7d50c892e7d266.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/idl-pretrain-dataset/train/cache-1b7d50c892e7d266.arrow -------------------------------------------------------------------------------- /test-folder/data/idl-pretrain-dataset/train/cache-3bbe6122452c2ec1.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/idl-pretrain-dataset/train/cache-3bbe6122452c2ec1.arrow -------------------------------------------------------------------------------- /test-folder/data/idl-pretrain-dataset/train/cache-7c68162c665bbdff.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/idl-pretrain-dataset/train/cache-7c68162c665bbdff.arrow -------------------------------------------------------------------------------- /test-folder/data/idl-pretrain-dataset/train/cache-bb43aa8c7ec63c3d.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/idl-pretrain-dataset/train/cache-bb43aa8c7ec63c3d.arrow -------------------------------------------------------------------------------- /test-folder/data/idl-pretrain-dataset/train/cache-cbf1d5fc1918aa19.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/idl-pretrain-dataset/train/cache-cbf1d5fc1918aa19.arrow -------------------------------------------------------------------------------- /test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001139724.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001139724.tif -------------------------------------------------------------------------------- /test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001139827.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001139827.tif -------------------------------------------------------------------------------- /test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001139841.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001139841.tif -------------------------------------------------------------------------------- /test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001139848.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001139848.tif -------------------------------------------------------------------------------- /test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001140924.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001140924.tif -------------------------------------------------------------------------------- /test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001140927.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001140927.tif -------------------------------------------------------------------------------- /test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001140933.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001140933.tif -------------------------------------------------------------------------------- /test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001140965.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001140965.tif -------------------------------------------------------------------------------- /test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001140968.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001140968.tif -------------------------------------------------------------------------------- /test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001140975.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset/0001140975.tif -------------------------------------------------------------------------------- /test-folder/data/ocr_pdf.csv: -------------------------------------------------------------------------------- 1 | pdf_path,ocr_path 2 | ../data/sample_rvl_cdip_dataset/IDL Pre-training dataset/pdfs/f/f/b/b/ffbb0000/ffbb0000.pdf,../data/sample_rvl_cdip_dataset/IDL Pre-training dataset/ocrs/f/f/b/b/ffbb0000/ffbb0000.json 3 | -------------------------------------------------------------------------------- /test-folder/data/sample_rvl_cdip_dataset/IDL Pre-training dataset/pdfs/f/f/b/b/ffbb0000/ffbb0000.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/uakarsh/docformerv2/HEAD/test-folder/data/sample_rvl_cdip_dataset/IDL Pre-training dataset/pdfs/f/f/b/b/ffbb0000/ffbb0000.pdf -------------------------------------------------------------------------------- /all-requirements.txt: -------------------------------------------------------------------------------- 1 | ipykernel 2 | requests 3 | tqdm 4 | joblib 5 | PyPDF2 6 | pandas 7 | pillow==9.2.0 8 | datasets 9 | transformers 10 | torchvision 11 | # pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu -------------------------------------------------------------------------------- /test-folder/data/idl-pretrain-dataset/train/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "e8ea4505c1298550", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": null 13 | } -------------------------------------------------------------------------------- /test-folder/data/true_data_path.csv: -------------------------------------------------------------------------------- 1 | img_list,ocr_list,pg_num 2 | ../data/sample_rvl_cdip_dataset/IDL Pre-training dataset/pdfs/f/f/b/b/ffbb0000/ffbb0000.pdf,../data/sample_rvl_cdip_dataset/IDL Pre-training dataset/ocrs/f/f/b/b/ffbb0000/ffbb0000.json,1 3 | ../data/sample_rvl_cdip_dataset/IDL Pre-training dataset/pdfs/f/f/b/b/ffbb0000/ffbb0000.pdf,../data/sample_rvl_cdip_dataset/IDL Pre-training dataset/ocrs/f/f/b/b/ffbb0000/ffbb0000.json,2 4 | -------------------------------------------------------------------------------- /test-folder/data/idl-pretrain-dataset/train/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "citation": "", 3 | "description": "", 4 | "features": { 5 | "img": { 6 | "_type": "Image" 7 | }, 8 | "bbox": { 9 | "feature": { 10 | "feature": { 11 | "dtype": "int64", 12 | "_type": "Value" 13 | }, 14 | "_type": "Sequence" 15 | }, 16 | "_type": "Sequence" 17 | }, 18 | "words": { 19 | "feature": { 20 | "dtype": "string", 21 | "_type": "Value" 22 | }, 23 | "_type": "Sequence" 24 | } 25 | }, 26 | "homepage": "", 27 | "license": "" 28 | } -------------------------------------------------------------------------------- /src/utils_modeling.py: -------------------------------------------------------------------------------- 1 | ## Pre-processing bounding boxes, ref: https://github.com/uakarsh/latr/blob/1e73c1a99f9a0db4d85177259226148a65556069/src/new_latr/dataset.py#L34 2 | 3 | def normalize_box(box, width, height, size=1000): 4 | """ 5 | Takes a bounding box and normalizes it to a thousand pixels. If you notice it is 6 | just like calculating percentage except takes 1000 instead of 100. 
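For example, a box [10, 20, 30, 40] on a 100 x 200 image becomes [100, 100, 300, 200].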
7 | 8 | Arguments: 9 | box: A list of bounding box coordinates 10 | width: The width of the image 11 | height: The height of the image 12 | size: The size to normalize to 13 | Returns: 14 | A list of normalized bounding box coordinates 15 | """ 16 | return [ 17 | int(size * (box[0] / width)), 18 | int(size * (box[1] / height)), 19 | int(size * (box[2] / width)), 20 | int(size * (box[3] / height)), 21 | ] 22 | 23 | def get_tokens_with_boxes(bounding_boxes, list_of_words, tokenizer, pad_token_box=[0, 0, 0, 0], max_seq_len=-1, eos_token_box=[0, 0, 1000, 1000]): 24 | 25 | ''' 26 | A function to get the tokens with the bounding boxes 27 | Arguments: 28 | bounding_boxes: A list of bounding boxes 29 | list_of_words: A list of words 30 | tokenizer: The tokenizer to use 31 | pad_token_box: The padding token box 32 | max_seq_len: The maximum sequence length, not padded if max_seq_len is -1 33 | eos_token_box: The end of sequence token box 34 | Returns: 35 | A list of input_ids, bbox_according_to_tokenizer, attention_mask 36 | ''' 37 | 38 | # 2. Performing the semantic pre-processing 39 | encoding = tokenizer(list_of_words, is_split_into_words=True, 40 | add_special_tokens=False) 41 | 42 | input_ids = encoding['input_ids'] 43 | attention_mask = encoding['attention_mask'] 44 | 45 | # Note that, there is no need for bboxes, since the model does not use bbox as feature, so no pre-processing of that 46 | bbox_according_to_tokenizer = [bounding_boxes[i] 47 | for i in encoding.word_ids()] 48 | 49 | # Truncation of token_boxes + token_labels 50 | special_tokens_count = 1 51 | if max_seq_len != -1 and len(input_ids) > max_seq_len - special_tokens_count: 52 | bbox_according_to_tokenizer = bbox_according_to_tokenizer[: ( 53 | max_seq_len - special_tokens_count)] 54 | input_ids = input_ids[: (max_seq_len - special_tokens_count)] 55 | attention_mask = attention_mask[: (max_seq_len - special_tokens_count)] 56 | 57 | ## Adding End of sentence token 58 | input_ids = input_ids + [tokenizer.eos_token_id] 59 | bbox_according_to_tokenizer = bbox_according_to_tokenizer + [eos_token_box] 60 | attention_mask = attention_mask + [1] 61 | 62 | # Padding 63 | if max_seq_len != -1 and len(input_ids) < max_seq_len: 64 | pad_length = max_seq_len - len(input_ids) 65 | 66 | input_ids = input_ids + [tokenizer.pad_token_id] * (pad_length) 67 | bbox_according_to_tokenizer = bbox_according_to_tokenizer + \ 68 | [pad_token_box] * (pad_length) 69 | attention_mask = attention_mask + [0] * (pad_length) 70 | 71 | return dict(input_ids = input_ids, bboxes = bbox_according_to_tokenizer, attention_mask = attention_mask) -------------------------------------------------------------------------------- /test-folder/code/utils_modeling.py: -------------------------------------------------------------------------------- 1 | ## Pre-processing bounding boxes, ref: https://github.com/uakarsh/latr/blob/1e73c1a99f9a0db4d85177259226148a65556069/src/new_latr/dataset.py#L34 2 | 3 | def normalize_box(box, width, height, size=1000): 4 | """ 5 | Takes a bounding box and normalizes it to a thousand pixels. If you notice it is 6 | just like calculating percentage except takes 1000 instead of 100. 
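Coordinates end up on a fixed 0-1000 scale, the convention LayoutLM-style layout models expect, so boxes from pages of different sizes stay comparable.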
7 | 8 | Arguments: 9 | box: A list of bounding box coordinates 10 | width: The width of the image 11 | height: The height of the image 12 | size: The size to normalize to 13 | Returns: 14 | A list of normalized bounding box coordinates 15 | """ 16 | return [ 17 | int(size * (box[0] / width)), 18 | int(size * (box[1] / height)), 19 | int(size * (box[2] / width)), 20 | int(size * (box[3] / height)), 21 | ] 22 | 23 | def get_tokens_with_boxes(bounding_boxes, list_of_words, tokenizer, pad_token_box=[0, 0, 0, 0], max_seq_len=-1, eos_token_box=[0, 0, 1000, 1000]): 24 | 25 | ''' 26 | A function to get the tokens with the bounding boxes 27 | Arguments: 28 | bounding_boxes: A list of bounding boxes 29 | list_of_words: A list of words 30 | tokenizer: The tokenizer to use 31 | pad_token_box: The padding token box 32 | max_seq_len: The maximum sequence length, not padded if max_seq_len is -1 33 | eos_token_box: The end of sequence token box 34 | Returns: 35 | A list of input_ids, bbox_according_to_tokenizer, attention_mask 36 | ''' 37 | 38 | # 2. Performing the semantic pre-processing 39 | encoding = tokenizer(list_of_words, is_split_into_words=True, 40 | add_special_tokens=False) 41 | 42 | input_ids = encoding['input_ids'] 43 | attention_mask = encoding['attention_mask'] 44 | 45 | # Note that, there is no need for bboxes, since the model does not use bbox as feature, so no pre-processing of that 46 | bbox_according_to_tokenizer = [bounding_boxes[i] 47 | for i in encoding.word_ids()] 48 | 49 | # Truncation of token_boxes + token_labels 50 | special_tokens_count = 1 51 | if max_seq_len != -1 and len(input_ids) > max_seq_len - special_tokens_count: 52 | bbox_according_to_tokenizer = bbox_according_to_tokenizer[: ( 53 | max_seq_len - special_tokens_count)] 54 | input_ids = input_ids[: (max_seq_len - special_tokens_count)] 55 | attention_mask = attention_mask[: (max_seq_len - special_tokens_count)] 56 | 57 | ## Adding End of sentence token 58 | input_ids = input_ids + [tokenizer.eos_token_id] 59 | bbox_according_to_tokenizer = bbox_according_to_tokenizer + [eos_token_box] 60 | attention_mask = attention_mask + [1] 61 | 62 | # Padding 63 | if max_seq_len != -1 and len(input_ids) < max_seq_len: 64 | pad_length = max_seq_len - len(input_ids) 65 | 66 | input_ids = input_ids + [tokenizer.pad_token_id] * (pad_length) 67 | bbox_according_to_tokenizer = bbox_according_to_tokenizer + \ 68 | [pad_token_box] * (pad_length) 69 | attention_mask = attention_mask + [0] * (pad_length) 70 | 71 | return dict(input_ids = input_ids, bboxes = bbox_according_to_tokenizer, attention_mask = attention_mask) -------------------------------------------------------------------------------- /src/utils_data_prep.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | from pathlib import Path 4 | import io 5 | from PyPDF2 import PdfReader 6 | import json 7 | from typing import List 8 | from PIL import Image 9 | 10 | def download_and_store_pdf(file_path) -> int: 11 | ''' 12 | args: file_path: file path corresponding to the OCR stored in json 13 | returns: None, and stores the corresponding pdf in the same folder structure as that of OCR 14 | ''' 15 | 16 | url = "https://download.industrydocuments.ucsf.edu" 17 | sample_file = file_path 18 | actual_path = sample_file.replace('ocrs', 'pdfs').split("/")[:-1] 19 | file_name = sample_file.split("/")[-1].split(".")[0] + ".pdf" 20 | dest_path = "/".join(actual_path) if file_path[0] == '/' else os.path.join(*actual_path) 21 | 
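    # e.g. an OCR path like .../ocrs/f/f/b/b/ffbb0000/ffbb0000.json maps to a PDF destination
    # .../pdfs/f/f/b/b/ffbb0000/ffbb0000.pdf; create that destination folder before downloading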
os.makedirs(dest_path, exist_ok=True) 22 | dest_path = os.path.join(dest_path, file_name) 23 | 24 | if os.path.exists(dest_path): 25 | return 1 26 | 27 | idx = sample_file.split("/").index("ocrs") 28 | for i in actual_path[idx+1:]: 29 | url = url + "/" + i 30 | url = url + "/" + actual_path[-1] 31 | url = url + ".pdf" 32 | 33 | try: 34 | response = requests.get(url) 35 | filename = Path(dest_path) 36 | filename.write_bytes(response.content) 37 | return 1 38 | except: 39 | return 0 40 | 41 | 42 | # Image property 43 | resize_scale = (500, 500) 44 | 45 | def normalize_box(box: List[int], width: int, height: int, size: tuple = resize_scale): 46 | """ 47 | Takes a bounding box and normalizes it to a thousand pixels. If you notice it is 48 | just like calculating percentage except takes 1000 instead of 100. 49 | """ 50 | return [ 51 | int(size[0] * (box[0] / width)), 52 | int(size[1] * (box[1] / height)), 53 | int(size[0] * (box[2] / width)), 54 | int(size[1] * (box[3] / height)), 55 | ] 56 | 57 | 58 | # Function to get the images from the PDFs as well as the OCRs for the corresponding images 59 | def get_image_ocrs_from_path(pdf_file_path: str, ocr_file_path: str, resize_scale=resize_scale, 60 | save_folder_img: str = "../data/images", save_folder_ocr: str = "../data/ocrs"): 61 | 62 | # Making folder to save the images 63 | if not os.path.exists(save_folder_img): 64 | os.mkdir(save_folder_img) 65 | 66 | # Making folder to save the OCRs 67 | if not os.path.exists(save_folder_ocr): 68 | os.mkdir(save_folder_ocr) 69 | 70 | try: 71 | 72 | # Getting the image list, since the pdfs can contain many image 73 | reader = PdfReader(pdf_file_path) 74 | img_list = {} 75 | pg_count = 1 76 | for i in range(len(reader.pages)): 77 | page = reader.pages[i] 78 | for image_file_object in page.images: 79 | 80 | stream = io.BytesIO(image_file_object.data) 81 | img = Image.open(stream).convert("RGB").resize(resize_scale) 82 | path_name = os.path.join( 83 | save_folder_img, f"{pdf_file_path.split('/')[-1].split('.')[0]}_{pg_count}.png") 84 | pg_count += 1 85 | img.save(path_name) 86 | img_list[pg_count - 1] = path_name 87 | 88 | json_entry = json.load(open(ocr_file_path))[1] 89 | json_entry = [x for x in json_entry["Blocks"] if "Text" in x] 90 | 91 | pages = [x["Page"] for x in json_entry] 92 | ocrs = {pg: [] for pg in set(pages)} 93 | 94 | for entry in json_entry: 95 | bbox = entry["Geometry"]["BoundingBox"] 96 | x, y, w, h = bbox['Left'], bbox['Top'], bbox["Width"], bbox["Height"] 97 | bbox = [x, y, x + w, y + h] 98 | # bbox = normalize_box(bbox, width=1, height=1, size=resize_scale) 99 | ocrs[entry["Page"]].append({"word": entry["Text"], "bbox": bbox}) 100 | 101 | ocr_path = {} 102 | for pg in set(pages): 103 | path_name = os.path.join( 104 | save_folder_ocr, f"{pdf_file_path.split('/')[-1].split('.')[0]}_{pg}.json") 105 | with open(path_name, "w") as f: 106 | json.dump(ocrs[pg], f) 107 | ocr_path[pg] = path_name 108 | 109 | return img_list, ocr_path 110 | 111 | except: 112 | return {}, {} 113 | 114 | 115 | # Function to get the OCRs for the corresponding images 116 | def get_ocrs_from_path(ocr_file_path: str, save_folder_ocr: str = "./ocrs"): 117 | 118 | # Making folder to save the OCRs 119 | if not os.path.exists(save_folder_ocr): 120 | os.mkdir(save_folder_ocr) 121 | 122 | try: 123 | 124 | json_entry = json.load(open(ocr_file_path))[1] 125 | json_entry = [x for x in json_entry["Blocks"] if "Text" in x] 126 | 127 | pages = [x["Page"] for x in json_entry] 128 | ocrs = {pg: [] for pg in set(pages)} 129 | 130 | for 
entry in json_entry: 131 | bbox = entry["Geometry"]["BoundingBox"] 132 | x, y, w, h = bbox['Left'], bbox['Top'], bbox["Width"], bbox["Height"] 133 | bbox = [x, y, x + w, y + h] 134 | # bbox = normalize_box(bbox, width=1, height=1, size=resize_scale) 135 | ocrs[entry["Page"]].append({"word": entry["Text"], "bbox": bbox}) 136 | 137 | ocr_path = {} 138 | for pg in set(pages): 139 | path_name = os.path.join( 140 | save_folder_ocr, f"{ocr_file_path.split('/')[-1].split('.')[0]}_{pg}.json") 141 | with open(path_name, "w") as f: 142 | json.dump(ocrs[pg], f) 143 | ocr_path[pg] = path_name 144 | 145 | return ocr_path 146 | 147 | except: 148 | return {} 149 | 150 | # Function to get the images from the PDFs as well as the OCRs for the corresponding images without saving the image 151 | def get_image_ocrs_dict_from_path(pdf_file_path: str, ocr_file_path: str): 152 | 153 | #try: 154 | # Getting the image list, since the pdfs can contain many image 155 | reader = PdfReader(pdf_file_path) 156 | img_list = {} 157 | pg_count = 1 158 | for i in range(len(reader.pages)): 159 | page = reader.pages[i] 160 | for _ in page.images: 161 | pg_count += 1 162 | img_list[pg_count - 1] = pdf_file_path 163 | 164 | json_entry = json.load(open(ocr_file_path, 'r'))[1] 165 | json_entry = [x for x in json_entry["Blocks"] if "Text" in x] 166 | 167 | pages = [x["Page"] for x in json_entry] 168 | ocrs = {pg: [] for pg in set(pages)} 169 | 170 | for entry in json_entry: 171 | bbox = entry["Geometry"]["BoundingBox"] 172 | x, y, w, h = bbox['Left'], bbox['Top'], bbox["Width"], bbox["Height"] 173 | bbox = [x, y, x + w, y + h] 174 | ocrs[entry["Page"]].append({"word": entry["Text"], "bbox": bbox}) 175 | 176 | ocr_path = {} 177 | for pg in set(pages): 178 | ocr_path[pg-1] = ocr_file_path 179 | 180 | #print(img_list) 181 | return img_list, ocr_path 182 | 183 | # except: 184 | # return {}, {} 185 | 186 | -------------------------------------------------------------------------------- /test-folder/code/utils_data_prep.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | from pathlib import Path 4 | import io 5 | from PyPDF2 import PdfReader 6 | import json 7 | from typing import List 8 | from PIL import Image 9 | 10 | def download_and_store_pdf(file_path) -> int: 11 | ''' 12 | args: file_path: file path corresponding to the OCR stored in json 13 | returns: None, and stores the corresponding pdf in the same folder structure as that of OCR 14 | ''' 15 | 16 | url = "https://download.industrydocuments.ucsf.edu" 17 | sample_file = file_path 18 | actual_path = sample_file.replace('ocrs', 'pdfs').split("/")[:-1] 19 | file_name = sample_file.split("/")[-1].split(".")[0] + ".pdf" 20 | dest_path = "/".join(actual_path) if file_path[0] == '/' else os.path.join(*actual_path) 21 | os.makedirs(dest_path, exist_ok=True) 22 | dest_path = os.path.join(dest_path, file_name) 23 | 24 | if os.path.exists(dest_path): 25 | return 1 26 | 27 | idx = sample_file.split("/").index("ocrs") 28 | for i in actual_path[idx+1:]: 29 | url = url + "/" + i 30 | url = url + "/" + actual_path[-1] 31 | url = url + ".pdf" 32 | 33 | try: 34 | response = requests.get(url) 35 | filename = Path(dest_path) 36 | filename.write_bytes(response.content) 37 | return 1 38 | except: 39 | return 0 40 | 41 | 42 | # Image property 43 | resize_scale = (500, 500) 44 | 45 | def normalize_box(box: List[int], width: int, height: int, size: tuple = resize_scale): 46 | """ 47 | Takes a bounding box and normalizes it to a 
thousand pixels. If you notice it is 48 | just like calculating percentage except takes 1000 instead of 100. 49 | """ 50 | return [ 51 | int(size[0] * (box[0] / width)), 52 | int(size[1] * (box[1] / height)), 53 | int(size[0] * (box[2] / width)), 54 | int(size[1] * (box[3] / height)), 55 | ] 56 | 57 | 58 | # Function to get the images from the PDFs as well as the OCRs for the corresponding images 59 | def get_image_ocrs_from_path(pdf_file_path: str, ocr_file_path: str, resize_scale=resize_scale, 60 | save_folder_img: str = "../data/images", save_folder_ocr: str = "../data/ocrs"): 61 | 62 | # Making folder to save the images 63 | if not os.path.exists(save_folder_img): 64 | os.mkdir(save_folder_img) 65 | 66 | # Making folder to save the OCRs 67 | if not os.path.exists(save_folder_ocr): 68 | os.mkdir(save_folder_ocr) 69 | 70 | try: 71 | 72 | # Getting the image list, since the pdfs can contain many image 73 | reader = PdfReader(pdf_file_path) 74 | img_list = {} 75 | pg_count = 1 76 | for i in range(len(reader.pages)): 77 | page = reader.pages[i] 78 | for image_file_object in page.images: 79 | 80 | stream = io.BytesIO(image_file_object.data) 81 | img = Image.open(stream).convert("RGB").resize(resize_scale) 82 | path_name = os.path.join( 83 | save_folder_img, f"{pdf_file_path.split('/')[-1].split('.')[0]}_{pg_count}.png") 84 | pg_count += 1 85 | img.save(path_name) 86 | img_list[pg_count - 1] = path_name 87 | 88 | json_entry = json.load(open(ocr_file_path))[1] 89 | json_entry = [x for x in json_entry["Blocks"] if "Text" in x] 90 | 91 | pages = [x["Page"] for x in json_entry] 92 | ocrs = {pg: [] for pg in set(pages)} 93 | 94 | for entry in json_entry: 95 | bbox = entry["Geometry"]["BoundingBox"] 96 | x, y, w, h = bbox['Left'], bbox['Top'], bbox["Width"], bbox["Height"] 97 | bbox = [x, y, x + w, y + h] 98 | # bbox = normalize_box(bbox, width=1, height=1, size=resize_scale) 99 | ocrs[entry["Page"]].append({"word": entry["Text"], "bbox": bbox}) 100 | 101 | ocr_path = {} 102 | for pg in set(pages): 103 | path_name = os.path.join( 104 | save_folder_ocr, f"{pdf_file_path.split('/')[-1].split('.')[0]}_{pg}.json") 105 | with open(path_name, "w") as f: 106 | json.dump(ocrs[pg], f) 107 | ocr_path[pg] = path_name 108 | 109 | return img_list, ocr_path 110 | 111 | except: 112 | return {}, {} 113 | 114 | 115 | # Function to get the OCRs for the corresponding images 116 | def get_ocrs_from_path(ocr_file_path: str, save_folder_ocr: str = "./ocrs"): 117 | 118 | # Making folder to save the OCRs 119 | if not os.path.exists(save_folder_ocr): 120 | os.mkdir(save_folder_ocr) 121 | 122 | try: 123 | 124 | json_entry = json.load(open(ocr_file_path))[1] 125 | json_entry = [x for x in json_entry["Blocks"] if "Text" in x] 126 | 127 | pages = [x["Page"] for x in json_entry] 128 | ocrs = {pg: [] for pg in set(pages)} 129 | 130 | for entry in json_entry: 131 | bbox = entry["Geometry"]["BoundingBox"] 132 | x, y, w, h = bbox['Left'], bbox['Top'], bbox["Width"], bbox["Height"] 133 | bbox = [x, y, x + w, y + h] 134 | # bbox = normalize_box(bbox, width=1, height=1, size=resize_scale) 135 | ocrs[entry["Page"]].append({"word": entry["Text"], "bbox": bbox}) 136 | 137 | ocr_path = {} 138 | for pg in set(pages): 139 | path_name = os.path.join( 140 | save_folder_ocr, f"{ocr_file_path.split('/')[-1].split('.')[0]}_{pg}.json") 141 | with open(path_name, "w") as f: 142 | json.dump(ocrs[pg], f) 143 | ocr_path[pg] = path_name 144 | 145 | return ocr_path 146 | 147 | except: 148 | return {} 149 | 150 | # Function to get the images from the 
PDFs as well as the OCRs for the corresponding images without saving the image 151 | def get_image_ocrs_dict_from_path(pdf_file_path: str, ocr_file_path: str): 152 | 153 | #try: 154 | # Getting the image list, since the pdfs can contain many image 155 | reader = PdfReader(pdf_file_path) 156 | img_list = {} 157 | pg_count = 1 158 | for i in range(len(reader.pages)): 159 | page = reader.pages[i] 160 | for image_file_object in page.images: 161 | pdf_name = pdf_file_path.split('/')[-1].split('.')[0] 162 | pg_count += 1 163 | img_list[pg_count - 1] = pdf_file_path 164 | 165 | json_entry = json.load(open(ocr_file_path, 'r'))[1] 166 | json_entry = [x for x in json_entry["Blocks"] if "Text" in x] 167 | 168 | pages = [x["Page"] for x in json_entry] 169 | ocrs = {pg: [] for pg in set(pages)} 170 | 171 | for entry in json_entry: 172 | bbox = entry["Geometry"]["BoundingBox"] 173 | x, y, w, h = bbox['Left'], bbox['Top'], bbox["Width"], bbox["Height"] 174 | bbox = [x, y, x + w, y + h] 175 | ocrs[entry["Page"]].append({"word": entry["Text"], "bbox": bbox}) 176 | 177 | ocr_path = {} 178 | for pg in set(pages): 179 | ocr_path[pg-1] = ocr_file_path 180 | 181 | #print(img_list) 182 | return img_list, ocr_path 183 | 184 | # except: 185 | # return {}, {} 186 | 187 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/modeling.py: -------------------------------------------------------------------------------- 1 | from transformers import T5ForConditionalGeneration 2 | from torch.nn import CrossEntropyLoss 3 | from transformers.modeling_outputs import ( 4 | Seq2SeqLMOutput, 5 | ) 6 | from transformers.utils import is_torch_fx_proxy 7 | import torch 8 | 9 | import torch.nn as nn 10 | import collections.abc 11 | 12 | ## Ref:https://github.com/huggingface/transformers/blob/ff841900e45763114d2417fb24ce29d950c6c956/src/transformers/models/vit/modeling_vit.py#L146 13 | class PatchEmbeddings(nn.Module): 14 | """ 15 | This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial 16 | `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a 17 | Transformer. 18 | """ 19 | 20 | def __init__(self, config): 21 | super().__init__() 22 | image_size, patch_size = config.image_size, config.patch_size 23 | num_channels, hidden_size = config.num_channels, config.d_model 24 | 25 | image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) 26 | patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) 27 | num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) 28 | self.image_size = image_size 29 | self.patch_size = patch_size 30 | self.num_channels = num_channels 31 | self.num_patches = num_patches 32 | self.max_image_tokens = config.max_image_tokens ## If we limit the max_image_tokens to a number, would it capture the global context? 33 | ## Should we keep the convolution kernel's size to be (16, 16) rather than just (2, 2), so sequence length can be reduced and we are 34 | ## able to capture global context? 35 | 36 | self.conv_projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) 37 | self.linear_projection = nn.Linear(hidden_size, hidden_size) 38 | 39 | self.positional_embedding = nn.Embedding(self.max_image_tokens, hidden_size) 40 | 41 | 42 | def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: 43 | _, num_channels, height, width = pixel_values.shape 44 | if num_channels != self.num_channels: 45 | raise ValueError( 46 | "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 47 | f" Expected {self.num_channels} but got {num_channels}." 48 | ) 49 | if not interpolate_pos_encoding: 50 | if height != self.image_size[0] or width != self.image_size[1]: 51 | raise ValueError( 52 | f"Input image size ({height}*{width}) doesn't match model" 53 | f" ({self.image_size[0]}*{self.image_size[1]})." 
54 | ) 55 | embeddings = self.conv_projection(pixel_values).flatten(2).transpose(1, 2) 56 | embeddings = self.linear_projection(embeddings)[:, :self.max_image_tokens, :] 57 | 58 | positions = torch.arange(0, self.max_image_tokens).unsqueeze(0).to(embeddings.device) 59 | position_embedding = self.positional_embedding(positions) 60 | 61 | return embeddings + position_embedding 62 | 63 | 64 | ## Ref: https://github.com/uakarsh/latr/blob/1e73c1a99f9a0db4d85177259226148a65556069/src/new_latr/modeling.py#L11 65 | 66 | class SpatialModule(nn.Module): 67 | def __init__(self, config): 68 | super().__init__() 69 | self.top_left_x = nn.Embedding( 70 | config.max_2d_position_embeddings, config.d_model // 2) 71 | self.bottom_right_x = nn.Embedding( 72 | config.max_2d_position_embeddings, config.d_model // 2) 73 | self.top_left_y = nn.Embedding( 74 | config.max_2d_position_embeddings, config.d_model // 2) 75 | self.bottom_right_y = nn.Embedding( 76 | config.max_2d_position_embeddings, config.d_model // 2) 77 | self.width_emb = nn.Embedding(config.max_2d_position_embeddings, config.d_model) 78 | self.height_emb = nn.Embedding( 79 | config.max_2d_position_embeddings, config.d_model) 80 | 81 | def forward(self, coordinates): 82 | 83 | top_left_x_feat = self.top_left_x(coordinates[:, :, 0]) 84 | top_left_y_feat = self.top_left_y(coordinates[:, :, 1]) 85 | bottom_right_x_feat = self.bottom_right_x(coordinates[:, :, 2]) 86 | bottom_right_y_feat = self.bottom_right_y(coordinates[:, :, 3]) 87 | width_feat = self.width_emb(coordinates[:, :, 2] - coordinates[:, :, 0]) 88 | height_feat = self.height_emb(coordinates[:, :, 3] - coordinates[:, :, 1]) 89 | 90 | layout_feature = torch.cat([top_left_x_feat, bottom_right_x_feat], axis = -1) + torch.cat([top_left_y_feat, bottom_right_y_feat], axis = -1) + \ 91 | width_feat + height_feat 92 | return layout_feature 93 | 94 | class DocFormerV2(T5ForConditionalGeneration): 95 | def __init__(self, config): 96 | super().__init__(config=config) 97 | self.spatial_feat_extractor = SpatialModule(config) 98 | self.img_feature_extractor = PatchEmbeddings(config) 99 | self.modality_embedding = nn.Embedding(2, config.d_model) 100 | 101 | def forward( 102 | self, 103 | input_ids=None, 104 | bbox=None, 105 | attention_mask=None, 106 | decoder_input_ids=None, 107 | decoder_attention_mask=None, 108 | encoder_outputs=None, 109 | past_key_values=None, 110 | pixel_values=None, 111 | labels=None, 112 | head_mask=None, 113 | inputs_embeds=None, 114 | decoder_inputs_embeds=None, 115 | decoder_head_mask=None, 116 | cross_attn_head_mask=None, 117 | use_cache=True, 118 | output_attentions=None, 119 | output_hidden_states=None, 120 | return_dict=None, 121 | **kwargs,) : 122 | 123 | use_cache = use_cache if use_cache is not None else self.config.use_cache 124 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 125 | 126 | if decoder_input_ids is None and labels is not None: 127 | decoder_input_ids = self._shift_right(labels) 128 | 129 | # Encode if needed (training, first prediction pass) 130 | if encoder_outputs is None: 131 | inputs_embeds, attention_mask = self.calculate_embedding( 132 | pixel_values, bbox, input_ids, attention_mask) 133 | encoder_outputs = self.encoder( 134 | attention_mask=attention_mask, 135 | inputs_embeds=inputs_embeds, 136 | head_mask=head_mask, 137 | output_attentions=output_attentions, 138 | output_hidden_states=output_hidden_states, 139 | return_dict=return_dict, 140 | ) 141 | hidden_states = encoder_outputs[0] 142 | 143 | if 
decoder_input_ids == None: 144 | decoder_input_ids = self._shift_right(input_ids) 145 | 146 | # Decode 147 | decoder_outputs = self.decoder( 148 | input_ids=decoder_input_ids, 149 | attention_mask=decoder_attention_mask, 150 | inputs_embeds=decoder_inputs_embeds, 151 | past_key_values=past_key_values, 152 | encoder_hidden_states=hidden_states, 153 | encoder_attention_mask=attention_mask, 154 | head_mask=decoder_head_mask, 155 | cross_attn_head_mask=cross_attn_head_mask, 156 | use_cache=use_cache, 157 | output_attentions=output_attentions, 158 | output_hidden_states=output_hidden_states, 159 | return_dict=return_dict, 160 | ) 161 | 162 | sequence_output = decoder_outputs[0] 163 | 164 | if self.config.tie_word_embeddings: 165 | # Rescale output before projecting on vocab 166 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 167 | sequence_output = sequence_output * (self.config.d_model**-0.5) 168 | 169 | lm_logits = self.lm_head(sequence_output) 170 | 171 | loss = None 172 | if labels is not None: 173 | loss_fct = CrossEntropyLoss(ignore_index=-100) 174 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) 175 | 176 | if not return_dict: 177 | output = (lm_logits,) + \ 178 | decoder_outputs[2:] + (encoder_outputs[0],) + encoder_outputs[2:] 179 | return ((loss,) + output) if loss is not None else output 180 | 181 | return Seq2SeqLMOutput( 182 | loss=loss, 183 | logits=lm_logits, 184 | past_key_values=decoder_outputs.past_key_values, 185 | decoder_hidden_states=decoder_outputs.hidden_states, 186 | decoder_attentions=decoder_outputs.attentions, 187 | cross_attentions=decoder_outputs.cross_attentions, 188 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 189 | encoder_hidden_states=encoder_outputs.hidden_states, 190 | encoder_attentions=encoder_outputs.attentions, 191 | ) 192 | 193 | def _shift_right(self, input_ids): 194 | decoder_start_token_id = self.config.decoder_start_token_id 195 | pad_token_id = self.config.pad_token_id 196 | 197 | if decoder_start_token_id is None: 198 | raise ValueError( 199 | "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. " 200 | "See T5 docs for more information." 201 | ) 202 | 203 | # shift inputs to the right 204 | if is_torch_fx_proxy(input_ids): 205 | # Item assignment is not supported natively for proxies. 
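    # so build a separate column filled with decoder_start_token_id and concatenate it in front of the
    # input ids shifted right by one (e.g. labels [a, b, c] become decoder inputs [start, a, b])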
206 | shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) 207 | shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) 208 | else: 209 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 210 | shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() 211 | shifted_input_ids[..., 0] = decoder_start_token_id 212 | 213 | if pad_token_id is None: 214 | raise ValueError("self.model.config.pad_token_id has to be defined.") 215 | # replace possible -100 values in labels by `pad_token_id` 216 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 217 | 218 | return shifted_input_ids 219 | 220 | def prepare_inputs_for_generation( 221 | self, 222 | input_ids, 223 | past_key_values=None, 224 | attention_mask=None, 225 | head_mask=None, 226 | decoder_head_mask=None, 227 | cross_attn_head_mask=None, 228 | use_cache=None, 229 | encoder_outputs=None, 230 | **kwargs, 231 | ): 232 | # cut decoder_input_ids if past is used 233 | if past_key_values is not None: 234 | input_ids = input_ids[:, -1:] 235 | 236 | return { 237 | "decoder_input_ids": input_ids, 238 | "past_key_values": past_key_values, 239 | "encoder_outputs": encoder_outputs, 240 | "attention_mask": attention_mask, 241 | "head_mask": head_mask, 242 | "decoder_head_mask": decoder_head_mask, 243 | "cross_attn_head_mask": cross_attn_head_mask, 244 | "use_cache": use_cache, 245 | "bbox": kwargs.get("bbox", None), 246 | "pixel_values": kwargs.get("pixel_values", None), 247 | } 248 | 249 | def calculate_embedding(self, img, bbox, input_ids, attention_mask): 250 | img_feat = self.img_feature_extractor(img) 251 | spatial_feat = self.spatial_feat_extractor(bbox) 252 | language_feat = self.shared(input_ids) 253 | 254 | layout_feat = spatial_feat + language_feat 255 | img_modality_token = self.modality_embedding(torch.zeros(1, img_feat.shape[1]).long().to(self.device)) 256 | lang_modality_token = self.modality_embedding(torch.ones(1, language_feat.shape[1]).long().to(self.device)) 257 | 258 | img_feat += img_modality_token 259 | layout_feat += lang_modality_token 260 | 261 | multi_modal_feat = torch.cat([img_feat, layout_feat], axis=1) 262 | input_attention_mask = torch.cat( 263 | [torch.ones(img_feat.shape[:2]).to(img_feat.device), attention_mask], axis=1) 264 | 265 | return multi_modal_feat, input_attention_mask -------------------------------------------------------------------------------- /test-folder/code/modeling.py: -------------------------------------------------------------------------------- 1 | from transformers import T5ForConditionalGeneration 2 | from torch.nn import CrossEntropyLoss 3 | from transformers.modeling_outputs import ( 4 | Seq2SeqLMOutput, 5 | ) 6 | from transformers.utils import is_torch_fx_proxy 7 | import torch 8 | 9 | import torch.nn as nn 10 | import collections.abc 11 | 12 | ## Ref:https://github.com/huggingface/transformers/blob/ff841900e45763114d2417fb24ce29d950c6c956/src/transformers/models/vit/modeling_vit.py#L146 13 | class PatchEmbeddings(nn.Module): 14 | """ 15 | This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial 16 | `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a 17 | Transformer. 
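Patches come from a strided Conv2d projection followed by a linear projection; for example, a 224 x 224 input with a 16 x 16 patch size yields 196 patch tokens, which are then truncated to `config.max_image_tokens` and given learned positional embeddings (the exact numbers depend on the config).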
18 | """ 19 | 20 | def __init__(self, config): 21 | super().__init__() 22 | image_size, patch_size = config.image_size, config.patch_size 23 | num_channels, hidden_size = config.num_channels, config.d_model 24 | 25 | image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) 26 | patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) 27 | num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) 28 | self.image_size = image_size 29 | self.patch_size = patch_size 30 | self.num_channels = num_channels 31 | self.num_patches = num_patches 32 | self.max_image_tokens = config.max_image_tokens ## If we limit the max_image_tokens to a number, would it capture the global context? 33 | ## Should we keep the convolution kernel's size to be (16, 16) rather than just (2, 2), so sequence length can be reduced and we are 34 | ## able to capture global context? 35 | 36 | self.conv_projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) 37 | self.linear_projection = nn.Linear(hidden_size, hidden_size) 38 | 39 | self.positional_embedding = nn.Embedding(self.max_image_tokens, hidden_size) 40 | 41 | 42 | def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: 43 | _, num_channels, height, width = pixel_values.shape 44 | if num_channels != self.num_channels: 45 | raise ValueError( 46 | "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 47 | f" Expected {self.num_channels} but got {num_channels}." 48 | ) 49 | if not interpolate_pos_encoding: 50 | if height != self.image_size[0] or width != self.image_size[1]: 51 | raise ValueError( 52 | f"Input image size ({height}*{width}) doesn't match model" 53 | f" ({self.image_size[0]}*{self.image_size[1]})." 
54 | ) 55 | embeddings = self.conv_projection(pixel_values).flatten(2).transpose(1, 2) 56 | embeddings = self.linear_projection(embeddings)[:, :self.max_image_tokens, :] 57 | 58 | positions = torch.arange(0, self.max_image_tokens).unsqueeze(0).to(embeddings.device) 59 | position_embedding = self.positional_embedding(positions) 60 | 61 | return embeddings + position_embedding 62 | 63 | 64 | ## Ref: https://github.com/uakarsh/latr/blob/1e73c1a99f9a0db4d85177259226148a65556069/src/new_latr/modeling.py#L11 65 | 66 | class SpatialModule(nn.Module): 67 | def __init__(self, config): 68 | super().__init__() 69 | self.top_left_x = nn.Embedding( 70 | config.max_2d_position_embeddings, config.d_model // 2) 71 | self.bottom_right_x = nn.Embedding( 72 | config.max_2d_position_embeddings, config.d_model // 2) 73 | self.top_left_y = nn.Embedding( 74 | config.max_2d_position_embeddings, config.d_model // 2) 75 | self.bottom_right_y = nn.Embedding( 76 | config.max_2d_position_embeddings, config.d_model // 2) 77 | self.width_emb = nn.Embedding(config.max_2d_position_embeddings, config.d_model) 78 | self.height_emb = nn.Embedding( 79 | config.max_2d_position_embeddings, config.d_model) 80 | 81 | def forward(self, coordinates): 82 | 83 | top_left_x_feat = self.top_left_x(coordinates[:, :, 0]) 84 | top_left_y_feat = self.top_left_y(coordinates[:, :, 1]) 85 | bottom_right_x_feat = self.bottom_right_x(coordinates[:, :, 2]) 86 | bottom_right_y_feat = self.bottom_right_y(coordinates[:, :, 3]) 87 | width_feat = self.width_emb(coordinates[:, :, 2] - coordinates[:, :, 0]) 88 | height_feat = self.height_emb(coordinates[:, :, 3] - coordinates[:, :, 1]) 89 | 90 | layout_feature = torch.cat([top_left_x_feat, bottom_right_x_feat], axis = -1) + torch.cat([top_left_y_feat, bottom_right_y_feat], axis = -1) + \ 91 | width_feat + height_feat 92 | return layout_feature 93 | 94 | class DocFormerV2(T5ForConditionalGeneration): 95 | def __init__(self, config): 96 | super().__init__(config=config) 97 | self.spatial_feat_extractor = SpatialModule(config) 98 | self.img_feature_extractor = PatchEmbeddings(config) 99 | self.modality_embedding = nn.Embedding(2, config.d_model) 100 | 101 | def forward( 102 | self, 103 | input_ids=None, 104 | bbox=None, 105 | attention_mask=None, 106 | decoder_input_ids=None, 107 | decoder_attention_mask=None, 108 | encoder_outputs=None, 109 | past_key_values=None, 110 | pixel_values=None, 111 | labels=None, 112 | head_mask=None, 113 | inputs_embeds=None, 114 | decoder_inputs_embeds=None, 115 | decoder_head_mask=None, 116 | cross_attn_head_mask=None, 117 | use_cache=True, 118 | output_attentions=None, 119 | output_hidden_states=None, 120 | return_dict=None, 121 | **kwargs,) : 122 | 123 | use_cache = use_cache if use_cache is not None else self.config.use_cache 124 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 125 | 126 | if decoder_input_ids is None and labels is not None: 127 | decoder_input_ids = self._shift_right(labels) 128 | 129 | # Encode if needed (training, first prediction pass) 130 | if encoder_outputs is None: 131 | inputs_embeds, attention_mask = self.calculate_embedding( 132 | pixel_values, bbox, input_ids, attention_mask) 133 | encoder_outputs = self.encoder( 134 | attention_mask=attention_mask, 135 | inputs_embeds=inputs_embeds, 136 | head_mask=head_mask, 137 | output_attentions=output_attentions, 138 | output_hidden_states=output_hidden_states, 139 | return_dict=return_dict, 140 | ) 141 | hidden_states = encoder_outputs[0] 142 | 143 | if 
decoder_input_ids is None: 144 | decoder_input_ids = self._shift_right(input_ids) 145 | 146 | # Decode 147 | decoder_outputs = self.decoder( 148 | input_ids=decoder_input_ids, 149 | attention_mask=decoder_attention_mask, 150 | inputs_embeds=decoder_inputs_embeds, 151 | past_key_values=past_key_values, 152 | encoder_hidden_states=hidden_states, 153 | encoder_attention_mask=attention_mask, 154 | head_mask=decoder_head_mask, 155 | cross_attn_head_mask=cross_attn_head_mask, 156 | use_cache=use_cache, 157 | output_attentions=output_attentions, 158 | output_hidden_states=output_hidden_states, 159 | return_dict=return_dict, 160 | ) 161 | 162 | sequence_output = decoder_outputs[0] 163 | 164 | if self.config.tie_word_embeddings: 165 | # Rescale output before projecting on vocab 166 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 167 | sequence_output = sequence_output * (self.config.d_model**-0.5) 168 | 169 | lm_logits = self.lm_head(sequence_output) 170 | 171 | loss = None 172 | if labels is not None: 173 | loss_fct = CrossEntropyLoss(ignore_index=-100) 174 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) 175 | 176 | if not return_dict: 177 | output = (lm_logits,) + \ 178 | decoder_outputs[2:] + (encoder_outputs[0],) + encoder_outputs[2:] 179 | return ((loss,) + output) if loss is not None else output 180 | 181 | return Seq2SeqLMOutput( 182 | loss=loss, 183 | logits=lm_logits, 184 | past_key_values=decoder_outputs.past_key_values, 185 | decoder_hidden_states=decoder_outputs.hidden_states, 186 | decoder_attentions=decoder_outputs.attentions, 187 | cross_attentions=decoder_outputs.cross_attentions, 188 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 189 | encoder_hidden_states=encoder_outputs.hidden_states, 190 | encoder_attentions=encoder_outputs.attentions, 191 | ) 192 | 193 | def _shift_right(self, input_ids): 194 | decoder_start_token_id = self.config.decoder_start_token_id 195 | pad_token_id = self.config.pad_token_id 196 | 197 | if decoder_start_token_id is None: 198 | raise ValueError( 199 | "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. " 200 | "See T5 docs for more information." 201 | ) 202 | 203 | # shift inputs to the right 204 | if is_torch_fx_proxy(input_ids): 205 | # Item assignment is not supported natively for proxies. 
206 | shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) 207 | shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) 208 | else: 209 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 210 | shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() 211 | shifted_input_ids[..., 0] = decoder_start_token_id 212 | 213 | if pad_token_id is None: 214 | raise ValueError("self.model.config.pad_token_id has to be defined.") 215 | # replace possible -100 values in labels by `pad_token_id` 216 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 217 | 218 | return shifted_input_ids 219 | 220 | def prepare_inputs_for_generation( 221 | self, 222 | input_ids, 223 | past_key_values=None, 224 | attention_mask=None, 225 | head_mask=None, 226 | decoder_head_mask=None, 227 | cross_attn_head_mask=None, 228 | use_cache=None, 229 | encoder_outputs=None, 230 | **kwargs, 231 | ): 232 | # cut decoder_input_ids if past is used 233 | if past_key_values is not None: 234 | input_ids = input_ids[:, -1:] 235 | 236 | return { 237 | "decoder_input_ids": input_ids, 238 | "past_key_values": past_key_values, 239 | "encoder_outputs": encoder_outputs, 240 | "attention_mask": attention_mask, 241 | "head_mask": head_mask, 242 | "decoder_head_mask": decoder_head_mask, 243 | "cross_attn_head_mask": cross_attn_head_mask, 244 | "use_cache": use_cache, 245 | "bbox": kwargs.get("bbox", None), 246 | "pixel_values": kwargs.get("pixel_values", None), 247 | } 248 | 249 | def calculate_embedding(self, img, bbox, input_ids, attention_mask): 250 | img_feat = self.img_feature_extractor(img) 251 | spatial_feat = self.spatial_feat_extractor(bbox) 252 | language_feat = self.shared(input_ids) 253 | 254 | layout_feat = spatial_feat + language_feat 255 | img_modality_token = self.modality_embedding(torch.zeros(1, img_feat.shape[1]).long().to(self.device)) 256 | lang_modality_token = self.modality_embedding(torch.ones(1, language_feat.shape[1]).long().to(self.device)) 257 | 258 | img_feat += img_modality_token 259 | layout_feat += lang_modality_token 260 | 261 | multi_modal_feat = torch.cat([img_feat, layout_feat], axis=1) 262 | input_attention_mask = torch.cat( 263 | [torch.ones(img_feat.shape[:2]).to(img_feat.device), attention_mask], axis=1) 264 | 265 | return multi_modal_feat, input_attention_mask -------------------------------------------------------------------------------- /test-folder/code/pretraining-modeling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/Users/akarsh.upadhyay@zomato.com/anaconda3/envs/docformer-v2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "from datasets import load_from_disk\n", 19 | "import torch\n", 20 | "from utils_modeling import get_tokens_with_boxes, normalize_box\n", 21 | "# from transformers.utils import is_torch_fx_proxy\n", 22 | "from modeling import PatchEmbeddings, SpatialModule, DocFormerV2" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "raw_datasets = load_from_disk(\"../data/idl-pretrain-dataset\")" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# # create rectangle image\n", 41 | "# img = item['img']\n", 42 | "# bbox = item['bbox'][:-1] ## Removing the website\n", 43 | "# words = item['words'][:-1] ## Removing the website\n", 44 | "# draw_on_img = ImageDraw.Draw(img) \n", 45 | "\n", 46 | "# for it in bbox:\n", 47 | "# draw_on_img.rectangle(it, outline =\"violet\")\n", 48 | "# img" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 4, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# ## Pre-processing bounding boxes, ref: https://github.com/uakarsh/latr/blob/1e73c1a99f9a0db4d85177259226148a65556069/src/new_latr/dataset.py#L34\n", 58 | "\n", 59 | "# def normalize_box(box, width, height, size=1000):\n", 60 | "# \"\"\"\n", 61 | "# Takes a bounding box and normalizes it to a thousand pixels. If you notice it is\n", 62 | "# just like calculating percentage except takes 1000 instead of 100.\n", 63 | "\n", 64 | "# Arguments:\n", 65 | "# box: A list of bounding box coordinates\n", 66 | "# width: The width of the image\n", 67 | "# height: The height of the image\n", 68 | "# size: The size to normalize to\n", 69 | "# Returns:\n", 70 | "# A list of normalized bounding box coordinates\n", 71 | "# \"\"\"\n", 72 | "# return [\n", 73 | "# int(size * (box[0] / width)),\n", 74 | "# int(size * (box[1] / height)),\n", 75 | "# int(size * (box[2] / width)),\n", 76 | "# int(size * (box[3] / height)),\n", 77 | "# ]\n", 78 | "\n", 79 | "# def get_tokens_with_boxes(bounding_boxes, list_of_words, tokenizer, pad_token_box=[0, 0, 0, 0], max_seq_len=-1, eos_token_box=[0, 0, 1000, 1000]):\n", 80 | "\n", 81 | "# '''\n", 82 | "# A function to get the tokens with the bounding boxes\n", 83 | "# Arguments:\n", 84 | "# bounding_boxes: A list of bounding boxes\n", 85 | "# list_of_words: A list of words\n", 86 | "# tokenizer: The tokenizer to use\n", 87 | "# pad_token_box: The padding token box\n", 88 | "# max_seq_len: The maximum sequence length, not padded if max_seq_len is -1\n", 89 | "# eos_token_box: The end of sequence token box\n", 90 | "# Returns:\n", 91 | "# A list of input_ids, bbox_according_to_tokenizer, attention_mask\n", 92 | "# '''\n", 93 | "\n", 94 | "# # 2. 
Performing the semantic pre-processing\n", 95 | "# encoding = tokenizer(list_of_words, is_split_into_words=True,\n", 96 | "# add_special_tokens=False)\n", 97 | "\n", 98 | "# input_ids = encoding['input_ids']\n", 99 | "# attention_mask = encoding['attention_mask']\n", 100 | "\n", 101 | "# # Note that, there is no need for bboxes, since the model does not use bbox as feature, so no pre-processing of that\n", 102 | "# bbox_according_to_tokenizer = [bounding_boxes[i]\n", 103 | "# for i in encoding.word_ids()]\n", 104 | "\n", 105 | "# # Truncation of token_boxes + token_labels\n", 106 | "# special_tokens_count = 1\n", 107 | "# if max_seq_len != -1 and len(input_ids) > max_seq_len - special_tokens_count:\n", 108 | "# bbox_according_to_tokenizer = bbox_according_to_tokenizer[: (\n", 109 | "# max_seq_len - special_tokens_count)]\n", 110 | "# input_ids = input_ids[: (max_seq_len - special_tokens_count)]\n", 111 | "# attention_mask = attention_mask[: (max_seq_len - special_tokens_count)]\n", 112 | "\n", 113 | "# ## Adding End of sentence token\n", 114 | "# input_ids = input_ids + [tokenizer.eos_token_id]\n", 115 | "# bbox_according_to_tokenizer = bbox_according_to_tokenizer + [eos_token_box]\n", 116 | "# attention_mask = attention_mask + [1]\n", 117 | "\n", 118 | "# # Padding\n", 119 | "# if max_seq_len != -1 and len(input_ids) < max_seq_len:\n", 120 | "# pad_length = max_seq_len - len(input_ids)\n", 121 | "\n", 122 | "# input_ids = input_ids + [tokenizer.pad_token_id] * (pad_length)\n", 123 | "# bbox_according_to_tokenizer = bbox_according_to_tokenizer + \\\n", 124 | "# [pad_token_box] * (pad_length)\n", 125 | "# attention_mask = attention_mask + [0] * (pad_length)\n", 126 | "\n", 127 | "# return dict(input_ids = input_ids, bboxes = bbox_according_to_tokenizer, attention_mask = attention_mask)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 5, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "from transformers import AutoTokenizer, PretrainedConfig, AutoConfig\n", 137 | "from torchvision.transforms import Compose, Normalize, Resize, ToTensor\n", 138 | "\n", 139 | "config = {\n", 140 | " 'model_name' : 'google-t5/t5-small', ## can be 'google-t5/t5-small', 'google-t5/t5-base', 'google-t5/t5-large'\n", 141 | " 'image_model' : 'microsoft/resnet-50',\n", 142 | " 'image_mean' : [0.485, 0.456, 0.406], ## resnet-50 configuration\n", 143 | " 'image_std' : [0.229, 0.224, 0.225],\n", 144 | " 'image_size' : (224, 224),\n", 145 | " 'patch_size' : (2, 2),\n", 146 | " 'num_channels' : 3,\n", 147 | " 'max_image_tokens' : 128, ## Found from ablation\n", 148 | " 'max_2d_position_embeddings' : 1024, \n", 149 | " \n", 150 | "}\n", 151 | "\n", 152 | "t5_config = AutoConfig.from_pretrained(config['model_name'])\n", 153 | "t5_config.update(config)\n", 154 | "config = t5_config\n", 155 | "tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", 156 | "image_transform = Compose([\n", 157 | " Resize(config.image_size if type(config.image_size) == tuple else (config.image_size, config.image_size)),\n", 158 | " ToTensor(),\n", 159 | " Normalize(mean=config.image_mean, std=config.image_std),\n", 160 | " ])" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 6, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "orig_size = (1000, 1000) ## We kept the original scale to 1000, 1000 image size\n", 170 | "\n", 171 | "## Normalizing the bounding boxes between 0 to 1000\n", 172 | "raw_datasets['train'] = raw_datasets['train'].map(lambda 
x : {'bbox' : [normalize_box(a, orig_size[0], orig_size[1]) for a in x['bbox']]}, batched=False)\n", 173 | "\n", 174 | "raw_datasets['train'] = raw_datasets['train'].map(lambda x : get_tokens_with_boxes(x['bbox'], x['words'], tokenizer=tokenizer,\n", 175 | " max_seq_len=tokenizer.model_max_length), batched=False, \n", 176 | " remove_columns = ['bbox','words'])" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 7, 182 | "metadata": {}, 183 | "outputs": [ 184 | { 185 | "data": { 186 | "text/plain": [ 187 | "DatasetDict({\n", 188 | " train: Dataset({\n", 189 | " features: ['img', 'input_ids', 'bboxes', 'attention_mask'],\n", 190 | " num_rows: 2\n", 191 | " })\n", 192 | "})" 193 | ] 194 | }, 195 | "execution_count": 7, 196 | "metadata": {}, 197 | "output_type": "execute_result" 198 | } 199 | ], 200 | "source": [ 201 | "raw_datasets" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 8, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "def preprocess(example_batch):\n", 211 | " batch = {}\n", 212 | " batch[\"pixel_values\"] = [\n", 213 | " image_transform(img) for img in example_batch[\"img\"]\n", 214 | " ]\n", 215 | "\n", 216 | " batch['input_ids'] = [torch.tensor(ids).long() for ids in example_batch[\"input_ids\"]]\n", 217 | " batch['bbox'] = [torch.tensor(box).long() for box in example_batch[\"bboxes\"]]\n", 218 | " batch['attention_mask'] = [torch.tensor(mask).long() for mask in example_batch[\"attention_mask\"]]\n", 219 | "\n", 220 | " return batch" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 9, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "raw_datasets.set_transform(preprocess)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": {}, 235 | "source": [ 236 | "## Modeling\n", 237 | "\n", 238 | "1. 
Image Embedding" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 10, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "# import torch.nn as nn\n", 248 | "# import collections.abc\n", 249 | "\n", 250 | "# ## Ref:https://github.com/huggingface/transformers/blob/ff841900e45763114d2417fb24ce29d950c6c956/src/transformers/models/vit/modeling_vit.py#L146\n", 251 | "# class PatchEmbeddings(nn.Module):\n", 252 | "# \"\"\"\n", 253 | "# This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial\n", 254 | "# `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a\n", 255 | "# Transformer.\n", 256 | "# \"\"\"\n", 257 | "\n", 258 | "# def __init__(self, config):\n", 259 | "# super().__init__()\n", 260 | "# image_size, patch_size = config.image_size, config.patch_size\n", 261 | "# num_channels, hidden_size = config.num_channels, config.d_model\n", 262 | "\n", 263 | "# image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)\n", 264 | "# patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)\n", 265 | "# num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])\n", 266 | "# self.image_size = image_size\n", 267 | "# self.patch_size = patch_size\n", 268 | "# self.num_channels = num_channels\n", 269 | "# self.num_patches = num_patches\n", 270 | "# self.max_image_tokens = config.max_image_tokens ## If we limit the max_image_tokens to a number, would it capture the global context?\n", 271 | "# ## Should we keep the convolution kernel's size to be (16, 16) rather than just (2, 2), so sequence length can be reduced and we are\n", 272 | "# ## able to capture global context?\n", 273 | "\n", 274 | "# self.conv_projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)\n", 275 | "# self.linear_projection = nn.Linear(hidden_size, hidden_size)\n", 276 | "\n", 277 | "# self.positional_embedding = nn.Embedding(self.max_image_tokens, hidden_size)\n", 278 | "\n", 279 | "\n", 280 | "# def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:\n", 281 | "# _, num_channels, height, width = pixel_values.shape\n", 282 | "# if num_channels != self.num_channels:\n", 283 | "# raise ValueError(\n", 284 | "# \"Make sure that the channel dimension of the pixel values match with the one set in the configuration.\"\n", 285 | "# f\" Expected {self.num_channels} but got {num_channels}.\"\n", 286 | "# )\n", 287 | "# if not interpolate_pos_encoding:\n", 288 | "# if height != self.image_size[0] or width != self.image_size[1]:\n", 289 | "# raise ValueError(\n", 290 | "# f\"Input image size ({height}*{width}) doesn't match model\"\n", 291 | "# f\" ({self.image_size[0]}*{self.image_size[1]}).\"\n", 292 | "# )\n", 293 | "# embeddings = self.conv_projection(pixel_values).flatten(2).transpose(1, 2)\n", 294 | "# embeddings = self.linear_projection(embeddings)[:, :self.max_image_tokens, :]\n", 295 | "\n", 296 | "# positions = torch.arange(0, self.max_image_tokens).unsqueeze(0).to(embeddings.device)\n", 297 | "# position_embedding = self.positional_embedding(positions)\n", 298 | "\n", 299 | "# return embeddings + position_embedding" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 11, 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "img_feature_extractor = 
PatchEmbeddings(config)\n", 309 | "sample_img_emb = img_feature_extractor(raw_datasets['train'][0]['pixel_values'].unsqueeze(0))" 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "metadata": {}, 315 | "source": [ 316 | "2. Text and Spatial Embedding" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 12, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "# ## Ref: https://github.com/uakarsh/latr/blob/1e73c1a99f9a0db4d85177259226148a65556069/src/new_latr/modeling.py#L11\n", 326 | "\n", 327 | "# class SpatialModule(nn.Module):\n", 328 | "# def __init__(self, config):\n", 329 | "# super().__init__()\n", 330 | "# self.top_left_x = nn.Embedding(\n", 331 | "# config.max_2d_position_embeddings, config.d_model // 2)\n", 332 | "# self.bottom_right_x = nn.Embedding(\n", 333 | "# config.max_2d_position_embeddings, config.d_model // 2)\n", 334 | "# self.top_left_y = nn.Embedding(\n", 335 | "# config.max_2d_position_embeddings, config.d_model // 2)\n", 336 | "# self.bottom_right_y = nn.Embedding(\n", 337 | "# config.max_2d_position_embeddings, config.d_model // 2)\n", 338 | "# self.width_emb = nn.Embedding(config.max_2d_position_embeddings, config.d_model)\n", 339 | "# self.height_emb = nn.Embedding(\n", 340 | "# config.max_2d_position_embeddings, config.d_model)\n", 341 | "\n", 342 | "# def forward(self, coordinates):\n", 343 | "\n", 344 | "# top_left_x_feat = self.top_left_x(coordinates[:, :, 0])\n", 345 | "# top_left_y_feat = self.top_left_y(coordinates[:, :, 1])\n", 346 | "# bottom_right_x_feat = self.bottom_right_x(coordinates[:, :, 2])\n", 347 | "# bottom_right_y_feat = self.bottom_right_y(coordinates[:, :, 3])\n", 348 | "# width_feat = self.width_emb(coordinates[:, :, 2] - coordinates[:, :, 0])\n", 349 | "# height_feat = self.height_emb(coordinates[:, :, 3] - coordinates[:, :, 1])\n", 350 | "\n", 351 | "# layout_feature = torch.cat([top_left_x_feat, bottom_right_x_feat], axis = -1) + torch.cat([top_left_y_feat, bottom_right_y_feat], axis = -1) + \\\n", 352 | "# width_feat + height_feat\n", 353 | "# return layout_feature" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 13, 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [ 362 | "spatial_feature_extractor = SpatialModule(config)\n", 363 | "spatial_feat = spatial_feature_extractor(raw_datasets['train'][0]['bbox'].unsqueeze(0))" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 14, 369 | "metadata": {}, 370 | "outputs": [ 371 | { 372 | "data": { 373 | "text/plain": [ 374 | "dict_keys(['pixel_values', 'input_ids', 'bbox', 'attention_mask'])" 375 | ] 376 | }, 377 | "execution_count": 14, 378 | "metadata": {}, 379 | "output_type": "execute_result" 380 | } 381 | ], 382 | "source": [ 383 | "raw_datasets['train'][0].keys()" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 15, 389 | "metadata": {}, 390 | "outputs": [ 391 | { 392 | "data": { 393 | "text/plain": [ 394 | "torch.Size([512, 4])" 395 | ] 396 | }, 397 | "execution_count": 15, 398 | "metadata": {}, 399 | "output_type": "execute_result" 400 | } 401 | ], 402 | "source": [ 403 | "raw_datasets['train'][0]['bbox'].shape" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 16, 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "# from transformers import T5ForConditionalGeneration\n", 413 | "# from torch.nn import CrossEntropyLoss\n", 414 | "# from transformers.modeling_outputs import (\n", 415 | "# 
Seq2SeqLMOutput,\n", 416 | "# )\n", 417 | "\n", 418 | "# class DocFormerV2(T5ForConditionalGeneration):\n", 419 | "# def __init__(self, config):\n", 420 | "# super().__init__(config=config)\n", 421 | "# self.spatial_feat_extractor = SpatialModule(config)\n", 422 | "# self.img_feature_extractor = PatchEmbeddings(config)\n", 423 | "# self.modality_embedding = nn.Embedding(2, config.d_model)\n", 424 | "\n", 425 | "# def forward(\n", 426 | "# self,\n", 427 | "# input_ids=None,\n", 428 | "# bbox=None,\n", 429 | "# attention_mask=None,\n", 430 | "# decoder_input_ids=None,\n", 431 | "# decoder_attention_mask=None,\n", 432 | "# encoder_outputs=None,\n", 433 | "# past_key_values=None,\n", 434 | "# pixel_values=None,\n", 435 | "# labels=None,\n", 436 | "# head_mask=None,\n", 437 | "# inputs_embeds=None,\n", 438 | "# decoder_inputs_embeds=None,\n", 439 | "# decoder_head_mask=None,\n", 440 | "# cross_attn_head_mask=None,\n", 441 | "# use_cache=True,\n", 442 | "# output_attentions=None,\n", 443 | "# output_hidden_states=None,\n", 444 | "# return_dict=None,\n", 445 | "# **kwargs,) :\n", 446 | "\n", 447 | "# use_cache = use_cache if use_cache is not None else self.config.use_cache\n", 448 | "# return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n", 449 | "\n", 450 | "# if decoder_input_ids is None and labels is not None:\n", 451 | "# decoder_input_ids = self._shift_right(labels)\n", 452 | "\n", 453 | "# # Encode if needed (training, first prediction pass)\n", 454 | "# if encoder_outputs is None:\n", 455 | "# inputs_embeds, attention_mask = self.calculate_embedding(\n", 456 | "# pixel_values, bbox, input_ids, attention_mask)\n", 457 | "# encoder_outputs = self.encoder(\n", 458 | "# attention_mask=attention_mask,\n", 459 | "# inputs_embeds=inputs_embeds,\n", 460 | "# head_mask=head_mask,\n", 461 | "# output_attentions=output_attentions,\n", 462 | "# output_hidden_states=output_hidden_states,\n", 463 | "# return_dict=return_dict,\n", 464 | "# )\n", 465 | "# hidden_states = encoder_outputs[0]\n", 466 | "\n", 467 | "# if decoder_input_ids == None:\n", 468 | "# decoder_input_ids = self._shift_right(input_ids)\n", 469 | "\n", 470 | "# # Decode\n", 471 | "# decoder_outputs = self.decoder(\n", 472 | "# input_ids=decoder_input_ids,\n", 473 | "# attention_mask=decoder_attention_mask,\n", 474 | "# inputs_embeds=decoder_inputs_embeds,\n", 475 | "# past_key_values=past_key_values,\n", 476 | "# encoder_hidden_states=hidden_states,\n", 477 | "# encoder_attention_mask=attention_mask,\n", 478 | "# head_mask=decoder_head_mask,\n", 479 | "# cross_attn_head_mask=cross_attn_head_mask,\n", 480 | "# use_cache=use_cache,\n", 481 | "# output_attentions=output_attentions,\n", 482 | "# output_hidden_states=output_hidden_states,\n", 483 | "# return_dict=return_dict,\n", 484 | "# )\n", 485 | "\n", 486 | "# sequence_output = decoder_outputs[0]\n", 487 | "\n", 488 | "# if self.config.tie_word_embeddings:\n", 489 | "# # Rescale output before projecting on vocab\n", 490 | "# # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586\n", 491 | "# sequence_output = sequence_output * (self.config.d_model**-0.5)\n", 492 | "\n", 493 | "# lm_logits = self.lm_head(sequence_output)\n", 494 | "\n", 495 | "# loss = None\n", 496 | "# if labels is not None:\n", 497 | "# loss_fct = CrossEntropyLoss(ignore_index=-100)\n", 498 | "# loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))\n", 499 | "\n", 500 | "# if not 
return_dict:\n", 501 | "# output = (lm_logits,) + \\\n", 502 | "# decoder_outputs[2:] + (encoder_outputs[0],) + encoder_outputs[2:]\n", 503 | "# return ((loss,) + output) if loss is not None else output\n", 504 | "\n", 505 | "# return Seq2SeqLMOutput(\n", 506 | "# loss=loss,\n", 507 | "# logits=lm_logits,\n", 508 | "# past_key_values=decoder_outputs.past_key_values,\n", 509 | "# decoder_hidden_states=decoder_outputs.hidden_states,\n", 510 | "# decoder_attentions=decoder_outputs.attentions,\n", 511 | "# cross_attentions=decoder_outputs.cross_attentions,\n", 512 | "# encoder_last_hidden_state=encoder_outputs.last_hidden_state,\n", 513 | "# encoder_hidden_states=encoder_outputs.hidden_states,\n", 514 | "# encoder_attentions=encoder_outputs.attentions,\n", 515 | "# )\n", 516 | " \n", 517 | "# def _shift_right(self, input_ids):\n", 518 | "# decoder_start_token_id = self.config.decoder_start_token_id\n", 519 | "# pad_token_id = self.config.pad_token_id\n", 520 | "\n", 521 | "# if decoder_start_token_id is None:\n", 522 | "# raise ValueError(\n", 523 | "# \"self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. \"\n", 524 | "# \"See T5 docs for more information.\"\n", 525 | "# )\n", 526 | "\n", 527 | "# # shift inputs to the right\n", 528 | "# if is_torch_fx_proxy(input_ids):\n", 529 | "# # Item assignment is not supported natively for proxies.\n", 530 | "# shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)\n", 531 | "# shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)\n", 532 | "# else:\n", 533 | "# shifted_input_ids = input_ids.new_zeros(input_ids.shape)\n", 534 | "# shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()\n", 535 | "# shifted_input_ids[..., 0] = decoder_start_token_id\n", 536 | "\n", 537 | "# if pad_token_id is None:\n", 538 | "# raise ValueError(\"self.model.config.pad_token_id has to be defined.\")\n", 539 | "# # replace possible -100 values in labels by `pad_token_id`\n", 540 | "# shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)\n", 541 | "\n", 542 | "# return shifted_input_ids\n", 543 | "\n", 544 | "# def prepare_inputs_for_generation(\n", 545 | "# self,\n", 546 | "# input_ids,\n", 547 | "# past_key_values=None,\n", 548 | "# attention_mask=None,\n", 549 | "# head_mask=None,\n", 550 | "# decoder_head_mask=None,\n", 551 | "# cross_attn_head_mask=None,\n", 552 | "# use_cache=None,\n", 553 | "# encoder_outputs=None,\n", 554 | "# **kwargs,\n", 555 | "# ):\n", 556 | "# # cut decoder_input_ids if past is used\n", 557 | "# if past_key_values is not None:\n", 558 | "# input_ids = input_ids[:, -1:]\n", 559 | "\n", 560 | "# return {\n", 561 | "# \"decoder_input_ids\": input_ids,\n", 562 | "# \"past_key_values\": past_key_values,\n", 563 | "# \"encoder_outputs\": encoder_outputs,\n", 564 | "# \"attention_mask\": attention_mask,\n", 565 | "# \"head_mask\": head_mask,\n", 566 | "# \"decoder_head_mask\": decoder_head_mask,\n", 567 | "# \"cross_attn_head_mask\": cross_attn_head_mask,\n", 568 | "# \"use_cache\": use_cache,\n", 569 | "# \"bbox\": kwargs.get(\"bbox\", None),\n", 570 | "# \"pixel_values\": kwargs.get(\"pixel_values\", None),\n", 571 | "# }\n", 572 | "\n", 573 | "# def calculate_embedding(self, img, bbox, input_ids, attention_mask):\n", 574 | "# img_feat = self.img_feature_extractor(img)\n", 575 | "# spatial_feat = self.spatial_feat_extractor(bbox)\n", 576 | "# language_feat = self.shared(input_ids)\n", 577 | "\n", 578 | "# layout_feat = 
spatial_feat + language_feat\n", 579 | "# img_modality_token = self.modality_embedding(torch.zeros(1, img_feat.shape[1]).long().to(self.device))\n", 580 | "# lang_modality_token = self.modality_embedding(torch.ones(1, language_feat.shape[1]).long().to(self.device))\n", 581 | "\n", 582 | "# img_feat += img_modality_token\n", 583 | "# layout_feat += lang_modality_token\n", 584 | "\n", 585 | "# multi_modal_feat = torch.cat([img_feat, layout_feat], axis=1)\n", 586 | "# input_attention_mask = torch.cat(\n", 587 | "# [torch.ones(img_feat.shape[:2]).to(img_feat.device), attention_mask], axis=1)\n", 588 | " \n", 589 | "# return multi_modal_feat, input_attention_mask" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": 17, 595 | "metadata": {}, 596 | "outputs": [], 597 | "source": [ 598 | "sample = raw_datasets['train'][0]\n", 599 | "for key, _ in list(sample.items()):\n", 600 | " sample[key] = sample[key].unsqueeze(0)" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 18, 606 | "metadata": {}, 607 | "outputs": [], 608 | "source": [ 609 | "docformer_v2 = DocFormerV2(config)" 610 | ] 611 | }, 612 | { 613 | "cell_type": "code", 614 | "execution_count": 19, 615 | "metadata": {}, 616 | "outputs": [ 617 | { 618 | "data": { 619 | "text/plain": [ 620 | "torch.Size([1, 3, 224, 224])" 621 | ] 622 | }, 623 | "execution_count": 19, 624 | "metadata": {}, 625 | "output_type": "execute_result" 626 | } 627 | ], 628 | "source": [ 629 | "sample['pixel_values'].shape" 630 | ] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "execution_count": 20, 635 | "metadata": {}, 636 | "outputs": [], 637 | "source": [ 638 | "output = docformer_v2(**sample)" 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "execution_count": 23, 644 | "metadata": {}, 645 | "outputs": [ 646 | { 647 | "data": { 648 | "text/plain": [ 649 | "Seq2SeqLMOutput(loss=None, logits=tensor([[[ 6.2223, -0.0186, -0.4801, ..., -0.6275, 0.9068, -0.2347],\n", 650 | " [ 0.9536, -1.4644, 2.2863, ..., 0.5846, 0.4761, -0.5474],\n", 651 | " [ 0.0564, -0.3024, 1.2863, ..., -0.5176, 0.5124, -1.1346],\n", 652 | " ...,\n", 653 | " [ 0.4013, 0.9506, 1.5823, ..., 1.9127, -0.8710, -0.0522],\n", 654 | " [ 0.6935, -0.1235, 2.1049, ..., -0.7052, 1.2970, -0.6487],\n", 655 | " [ 0.5439, -1.3241, 1.1074, ..., -0.2089, 0.6928, -1.1619]]],\n", 656 | " grad_fn=), past_key_values=((tensor([[[[-0.0593, 0.2877, -0.2666, ..., 1.1694, 0.3735, 0.1573],\n", 657 | " [ 1.3052, -1.8031, -0.9118, ..., -1.8054, -0.3311, -0.1553],\n", 658 | " [ 1.9435, 0.6398, 0.7937, ..., -0.2507, 0.7002, 0.6433],\n", 659 | " ...,\n", 660 | " [ 0.6455, 0.8375, -1.2233, ..., -0.0106, -0.5837, 0.5068],\n", 661 | " [ 1.0143, -0.9986, -1.8679, ..., 0.4936, -0.5922, 0.8664],\n", 662 | " [ 0.9427, -0.4008, 2.6253, ..., -0.6578, -0.3355, -1.6953]],\n", 663 | "\n", 664 | " [[ 0.1454, -0.6221, 1.5433, ..., -0.8987, 1.1462, 0.2721],\n", 665 | " [-0.0974, -0.6311, 0.8437, ..., -0.0550, -0.8665, -0.4734],\n", 666 | " [-0.9041, -1.2471, 0.2143, ..., 0.0528, -0.6683, -1.8630],\n", 667 | " ...,\n", 668 | " [-0.2288, 0.6528, -1.0433, ..., -0.6467, 0.0361, 1.1604],\n", 669 | " [-0.4237, 1.3213, 1.1616, ..., 0.5205, -0.2358, 0.9347],\n", 670 | " [-1.0581, 1.1252, 0.5662, ..., -0.5435, 1.5726, -0.3251]],\n", 671 | "\n", 672 | " [[-0.7464, 2.2378, -0.7501, ..., 0.1104, -0.1755, -0.5310],\n", 673 | " [ 0.1698, 0.5626, 0.5747, ..., -0.9974, 0.0882, 0.1471],\n", 674 | " [ 0.6634, -2.9765, -0.3719, ..., 1.6095, 1.1911, -0.3457],\n", 675 | " ...,\n", 676 | " [ 0.2708, 
-0.5521, 0.8154, ..., 0.5676, 0.4663, 0.2805],\n", 677 | " [ 0.2735, -0.0652, 0.8132, ..., 0.8427, -0.5900, -0.7943],\n", 678 | " [ 0.3308, 1.8184, -0.9341, ..., 0.7379, 0.0276, -0.1161]],\n", 679 | "\n", 680 | " ...,\n", 681 | "\n", 682 | " [[-0.2820, 0.5099, -0.0183, ..., 0.3365, 0.4139, 1.0931],\n", 683 | " [-2.7147, -2.1513, -1.0231, ..., 1.5720, 1.9800, -0.2389],\n", 684 | " [-1.4485, 1.5938, -0.5504, ..., 0.0366, 0.4515, 1.4280],\n", 685 | " ...,\n", 686 | " [-1.5314, -0.0964, 0.4245, ..., 0.9561, -0.7141, 0.2874],\n", 687 | " [ 0.5410, -0.7713, 0.8211, ..., -0.2017, -0.4230, 1.1753],\n", 688 | " [ 0.3459, -1.0430, -1.5762, ..., -1.8713, -0.2878, -0.1242]],\n", 689 | "\n", 690 | " [[ 0.2111, 0.3830, -0.9356, ..., -1.0819, 0.7509, -1.2613],\n", 691 | " [-0.0277, 0.2949, -0.2651, ..., 1.0139, 0.0210, -0.9773],\n", 692 | " [-1.4889, 0.5036, 0.7720, ..., -0.2182, 1.1019, 0.3626],\n", 693 | " ...,\n", 694 | " [-0.3904, 0.6469, -0.9039, ..., 1.2906, 0.3594, -1.0282],\n", 695 | " [ 0.6755, 1.1791, -0.3607, ..., 0.5006, -1.2662, 0.5094],\n", 696 | " [ 1.4437, -0.3035, 1.7333, ..., 1.1045, 1.2157, -1.2609]],\n", 697 | "\n", 698 | " [[-1.4155, -0.0085, -0.2453, ..., 0.9253, -0.4391, -0.5035],\n", 699 | " [ 0.2787, -0.6304, -2.0616, ..., -0.9024, -0.9042, 1.4136],\n", 700 | " [ 0.4760, -0.2602, 0.4933, ..., -2.2250, -1.1745, 0.3813],\n", 701 | " ...,\n", 702 | " [-1.2263, 0.5764, 0.1533, ..., 2.5216, 0.9590, 0.5763],\n", 703 | " [ 0.9124, -0.3126, -0.2478, ..., 1.3816, 1.8564, -0.3897],\n", 704 | " [-2.2102, 0.7775, -1.4069, ..., 0.4870, 0.1930, -0.6602]]]],\n", 705 | " grad_fn=), tensor([[[[ 1.4614e+00, 2.6421e-01, -1.3680e+00, ..., 1.4169e+00,\n", 706 | " -5.7289e-01, -1.2572e+00],\n", 707 | " [-3.1274e+00, 4.4716e-01, 2.5781e-01, ..., -2.9186e-01,\n", 708 | " -2.1492e-02, 1.7767e-01],\n", 709 | " [-6.5700e-03, -2.3189e-01, 5.6596e-01, ..., 6.6354e-02,\n", 710 | " -1.7831e+00, 1.0403e-01],\n", 711 | " ...,\n", 712 | " [-2.1058e+00, -1.0511e+00, 1.2126e+00, ..., -8.8670e-01,\n", 713 | " 1.8839e+00, -3.1118e-01],\n", 714 | " [ 2.8340e-01, 8.8827e-01, 9.3949e-02, ..., -3.4579e-01,\n", 715 | " -4.1511e-01, -1.5941e+00],\n", 716 | " [-9.5013e-01, -7.7862e-01, 1.1353e+00, ..., 3.5013e-01,\n", 717 | " 3.4780e-01, -5.2310e-01]],\n", 718 | "\n", 719 | " [[-2.4981e-01, 1.0723e+00, 5.5596e-01, ..., -1.7591e+00,\n", 720 | " 2.0452e+00, 3.5227e-02],\n", 721 | " [ 2.9537e-01, -8.8496e-01, 2.5376e-01, ..., 7.3158e-02,\n", 722 | " -1.1627e+00, 5.9921e-01],\n", 723 | " [-2.2584e-01, -1.3669e+00, -4.6135e-01, ..., -2.4335e-01,\n", 724 | " 1.0478e-01, 1.2539e+00],\n", 725 | " ...,\n", 726 | " [ 3.5619e-01, -1.2489e+00, 4.7297e-02, ..., 9.2704e-03,\n", 727 | " 1.2211e-01, -5.3558e-01],\n", 728 | " [ 6.0729e-01, 5.2458e-04, 5.6767e-01, ..., -8.2656e-01,\n", 729 | " 3.6458e-02, -1.0444e+00],\n", 730 | " [-2.2625e+00, 1.0393e+00, -5.7954e-01, ..., -5.4839e-01,\n", 731 | " 8.8222e-01, -3.8751e-01]],\n", 732 | "\n", 733 | " [[-2.2939e+00, -1.6600e+00, 5.9477e-01, ..., -9.3079e-01,\n", 734 | " -8.6245e-01, -2.5887e-01],\n", 735 | " [-5.4984e-01, -6.1888e-01, -1.3028e+00, ..., 1.2452e+00,\n", 736 | " 1.4431e+00, -7.5765e-01],\n", 737 | " [ 3.5154e-01, 9.7401e-01, 2.3115e+00, ..., -4.8683e-01,\n", 738 | " 2.9393e-01, 9.5430e-01],\n", 739 | " ...,\n", 740 | " [ 9.0311e-01, 3.0920e-01, -1.1898e+00, ..., 1.2898e+00,\n", 741 | " -7.8601e-01, -1.6391e-01],\n", 742 | " [ 4.0611e-01, -7.3411e-01, -1.1854e-01, ..., 7.2030e-01,\n", 743 | " 2.6172e-01, -9.3276e-01],\n", 744 | " [-2.8767e-01, -2.2262e-01, 1.6130e+00, ..., 
5.2591e-01,\n", 745 | " 1.4270e+00, 8.7006e-02]],\n", 746 | "\n", 747 | " ...,\n", 748 | "\n", 749 | " [[ 1.7032e-02, -1.9356e+00, -9.4754e-02, ..., 9.0679e-01,\n", 750 | " -1.4046e+00, 1.7679e+00],\n", 751 | " [ 4.7346e-01, -1.3400e+00, -6.4838e-01, ..., 3.9076e-01,\n", 752 | " -1.1330e+00, -4.9633e-01],\n", 753 | " [ 4.9744e-01, 5.8535e-01, 1.6768e+00, ..., -6.5851e-01,\n", 754 | " 2.2386e+00, 9.6393e-02],\n", 755 | " ...,\n", 756 | " [ 3.9765e-01, 1.0863e+00, 4.8739e-01, ..., -6.2229e-01,\n", 757 | " 9.8109e-01, 6.7772e-01],\n", 758 | " [ 7.2959e-01, 1.1064e+00, 2.7164e-01, ..., 7.7526e-01,\n", 759 | " -8.7584e-01, -1.3682e+00],\n", 760 | " [ 4.7901e-01, -1.0109e+00, -8.6849e-01, ..., -1.2524e+00,\n", 761 | " 7.1824e-01, 1.7993e+00]],\n", 762 | "\n", 763 | " [[ 5.4820e-01, -8.5637e-01, 3.3024e-01, ..., -7.6496e-01,\n", 764 | " -9.4125e-01, -2.9034e-01],\n", 765 | " [-7.9766e-01, 8.6263e-01, -1.1209e+00, ..., -4.5686e-01,\n", 766 | " 3.2732e-01, 9.6346e-01],\n", 767 | " [-5.8102e-01, -4.6601e-01, 8.4692e-01, ..., 1.0654e+00,\n", 768 | " -4.0217e-01, -2.8820e+00],\n", 769 | " ...,\n", 770 | " [-6.1251e-01, -1.6670e+00, -1.1108e+00, ..., 1.1542e+00,\n", 771 | " 2.4419e+00, 1.4342e+00],\n", 772 | " [-8.8380e-01, -1.1298e+00, -2.6068e-01, ..., -5.7450e-01,\n", 773 | " -5.4383e-01, 7.3172e-01],\n", 774 | " [-6.7370e-01, 1.4037e+00, 1.1303e+00, ..., 6.1029e-01,\n", 775 | " -2.7471e-01, 1.8447e+00]],\n", 776 | "\n", 777 | " [[-9.4917e-01, -3.0915e-01, 5.8592e-01, ..., -1.8398e-01,\n", 778 | " 1.9881e+00, -1.2244e+00],\n", 779 | " [-3.2779e-01, -2.0070e+00, -4.0755e-02, ..., 1.0777e+00,\n", 780 | " -4.2478e-01, 3.8624e-01],\n", 781 | " [-9.2764e-01, 5.7944e-01, 2.0888e-01, ..., 1.3288e-01,\n", 782 | " -6.9058e-01, 5.5958e-01],\n", 783 | " ...,\n", 784 | " [-1.9358e+00, -5.9562e-01, -1.7891e+00, ..., 1.8458e-01,\n", 785 | " -1.4226e+00, 1.5566e-01],\n", 786 | " [-7.7846e-01, 2.7922e-01, 9.6240e-01, ..., -4.0436e-01,\n", 787 | " 5.1693e-01, 2.0855e-01],\n", 788 | " [-1.4111e-01, 1.3559e+00, 1.7939e+00, ..., 3.7332e-01,\n", 789 | " 7.0182e-01, 1.3282e+00]]]], grad_fn=), tensor([[[[-7.6957e-02, 1.3467e+00, -2.2772e+00, ..., -1.3127e-01,\n", 790 | " -2.5787e+00, 2.0870e-01],\n", 791 | " [-6.7614e-01, 7.5614e-01, -3.7543e-01, ..., 6.3095e-01,\n", 792 | " -1.0680e+00, 1.1346e+00],\n", 793 | " [ 5.2638e-01, 3.2473e-01, -1.6300e+00, ..., 4.6405e-01,\n", 794 | " -1.5912e+00, 5.2696e-02],\n", 795 | " ...,\n", 796 | " [ 6.6602e-01, 1.5955e+00, 6.0958e-01, ..., -7.8513e-01,\n", 797 | " 3.0099e-01, 1.1625e+00],\n", 798 | " [-6.2095e-01, 1.7021e+00, 1.5416e+00, ..., -8.1195e-02,\n", 799 | " -2.8998e-01, 1.2772e+00],\n", 800 | " [ 1.8570e-01, 1.7788e+00, 1.1997e+00, ..., 2.7613e+00,\n", 801 | " 6.1724e-01, 4.3249e-01]],\n", 802 | "\n", 803 | " [[-9.3422e-02, 3.1692e-01, -2.1002e+00, ..., 9.6361e-01,\n", 804 | " 6.4749e-01, 1.3585e+00],\n", 805 | " [ 2.6087e-01, 9.5155e-02, -2.3111e+00, ..., 2.3682e+00,\n", 806 | " 1.3874e+00, -3.0096e-01],\n", 807 | " [-1.9618e-01, 9.1673e-01, -2.9299e+00, ..., 1.2490e+00,\n", 808 | " -1.1998e-02, -2.3149e-01],\n", 809 | " ...,\n", 810 | " [-5.1294e-01, 1.6511e+00, 8.1088e-02, ..., 1.3903e+00,\n", 811 | " -3.5206e-01, 4.3983e-01],\n", 812 | " [ 4.5217e-01, 1.7012e+00, 9.0029e-03, ..., 5.1455e-01,\n", 813 | " 1.0435e+00, 4.6214e-02],\n", 814 | " [-1.0089e+00, -1.2108e-01, -2.5462e+00, ..., -3.3143e-02,\n", 815 | " -8.5829e-01, -1.1167e-01]],\n", 816 | "\n", 817 | " [[-2.4930e-01, -1.6055e-01, 8.8977e-02, ..., 1.1330e+00,\n", 818 | " -3.6374e-01, 4.1711e-01],\n", 819 | " [ 
7.6840e-01, 9.6556e-01, 1.4596e-01, ..., 4.2453e-01,\n", 820 | " 5.8496e-01, -1.0509e+00],\n", 821 | " [-2.0196e-01, -1.5585e-02, -1.1686e+00, ..., 4.9474e-01,\n", 822 | " -9.0701e-01, -7.4988e-01],\n", 823 | " ...,\n", 824 | " [-1.2256e+00, 1.9173e+00, 1.1561e+00, ..., 4.6376e-01,\n", 825 | " 1.6046e-01, -9.8343e-01],\n", 826 | " [-7.1568e-01, -3.1049e-01, -1.3215e+00, ..., 5.6664e-02,\n", 827 | " -1.6692e+00, 1.7667e-01],\n", 828 | " [ 3.3574e-01, -1.0546e+00, -1.1884e+00, ..., 1.2593e+00,\n", 829 | " -4.6818e-01, 7.3632e-01]],\n", 830 | "\n", 831 | " ...,\n", 832 | "\n", 833 | " [[-1.4360e+00, 7.0257e-01, -7.7453e-01, ..., 9.9753e-02,\n", 834 | " -5.0381e-01, 2.4821e-01],\n", 835 | " [-1.0479e+00, 8.8262e-01, 8.7385e-01, ..., -1.2081e+00,\n", 836 | " 2.1635e-01, 5.5342e-01],\n", 837 | " [-5.3290e-01, 8.2664e-01, -1.6155e-01, ..., -3.4853e-01,\n", 838 | " 1.9901e-01, 1.8356e-01],\n", 839 | " ...,\n", 840 | " [-7.2076e-01, 1.1281e+00, 8.6445e-02, ..., -1.1269e-02,\n", 841 | " -7.4731e-01, -1.2162e+00],\n", 842 | " [-1.9803e+00, 7.4582e-01, 4.4803e-01, ..., 2.5625e-01,\n", 843 | " -8.3720e-01, 6.3939e-01],\n", 844 | " [-1.3989e+00, 2.4391e-01, -1.0030e-01, ..., -9.0857e-01,\n", 845 | " 1.6697e+00, 1.8484e-01]],\n", 846 | "\n", 847 | " [[ 2.3007e-01, -2.4493e-01, 5.6382e-01, ..., -5.8551e-01,\n", 848 | " -9.0174e-01, -1.7688e-01],\n", 849 | " [ 1.6782e-01, -9.9314e-01, -1.3224e+00, ..., -7.5531e-01,\n", 850 | " -1.4299e+00, -4.8854e-01],\n", 851 | " [ 1.8034e-01, 3.8624e-02, -4.6893e-01, ..., -3.3213e-01,\n", 852 | " -1.9771e-01, -3.1655e-01],\n", 853 | " ...,\n", 854 | " [ 1.0759e+00, -4.7448e-01, 1.3226e+00, ..., -1.1421e+00,\n", 855 | " -7.6845e-02, 2.2165e+00],\n", 856 | " [ 6.6337e-01, 9.8278e-01, 2.4586e-01, ..., -4.1138e-01,\n", 857 | " -5.5127e-01, 2.6354e+00],\n", 858 | " [-5.1592e-01, 9.2007e-01, 1.1842e+00, ..., -4.6586e-01,\n", 859 | " -6.4577e-01, 1.5162e+00]],\n", 860 | "\n", 861 | " [[ 1.2828e+00, -7.5898e-01, -1.0110e+00, ..., 7.0950e-02,\n", 862 | " 2.1994e+00, -2.3048e+00],\n", 863 | " [-3.4147e-01, 9.1635e-01, 6.1921e-01, ..., -1.3381e+00,\n", 864 | " 1.3300e+00, -2.8113e+00],\n", 865 | " [-8.2520e-01, 4.4047e-01, 5.2577e-01, ..., -2.1698e+00,\n", 866 | " 8.8754e-01, -1.2260e+00],\n", 867 | " ...,\n", 868 | " [-1.9247e-01, 1.1727e+00, 1.0352e-03, ..., -6.0290e-02,\n", 869 | " 8.1698e-01, 4.0958e-01],\n", 870 | " [-3.3347e-01, 6.3272e-01, 7.8529e-01, ..., 1.4966e-01,\n", 871 | " 8.2079e-01, -2.0603e+00],\n", 872 | " [ 1.1174e+00, 3.0452e-01, -1.3509e-01, ..., 2.0221e+00,\n", 873 | " 5.8850e-01, -2.6924e+00]]]], grad_fn=), tensor([[[[ 0.8688, 0.5364, -0.7254, ..., -1.0840, 2.3967, -0.2252],\n", 874 | " [ 1.7472, 0.0919, 2.6840, ..., -1.6227, 0.7445, -0.9784],\n", 875 | " [ 1.5110, -1.0991, 1.1170, ..., -0.8159, 2.0815, -1.1410],\n", 876 | " ...,\n", 877 | " [ 0.4057, -0.7315, 1.9774, ..., -0.2823, 1.4043, -0.3605],\n", 878 | " [ 0.1900, -0.9438, 1.0402, ..., -0.5438, 1.1754, -0.5636],\n", 879 | " [ 2.1515, -0.5214, 1.8893, ..., -1.4933, 2.0962, 1.0430]],\n", 880 | "\n", 881 | " [[ 0.5881, -0.5765, 0.2702, ..., 2.1152, -2.2780, 0.5873],\n", 882 | " [ 1.5734, 0.8787, 1.3521, ..., 2.5689, -1.1332, 1.8520],\n", 883 | " [ 1.1005, 0.5830, 0.4053, ..., 1.7859, -0.2911, 0.1786],\n", 884 | " ...,\n", 885 | " [ 1.0846, -1.6650, -0.1877, ..., 0.0122, 1.0750, 0.0898],\n", 886 | " [ 1.0181, -1.5172, -0.0514, ..., -0.3602, 0.9145, 0.2771],\n", 887 | " [ 1.1614, 0.2066, 0.2224, ..., 1.6964, 0.7498, 1.8952]],\n", 888 | "\n", 889 | " [[ 2.7685, -0.9358, -0.8984, ..., 2.5253, -1.1156, 
0.1228],\n", 890 | " [ 1.8086, -1.5474, -0.2396, ..., 0.6646, -2.3186, 0.0773],\n", 891 | " [ 2.3644, -1.2871, -1.3978, ..., -0.1867, 0.5909, -0.0646],\n", 892 | " ...,\n", 893 | " [ 1.9479, -0.2236, -0.8505, ..., 0.6196, 0.1497, -1.1519],\n", 894 | " [ 0.3514, 0.3656, -0.5232, ..., 0.5709, 0.7508, -1.9574],\n", 895 | " [-0.2761, -0.4440, -0.7157, ..., 0.9986, 0.1114, -1.6255]],\n", 896 | "\n", 897 | " ...,\n", 898 | "\n", 899 | " [[-0.0064, -0.7360, 2.4913, ..., 0.5837, 2.4793, 1.9118],\n", 900 | " [ 0.2170, -1.4111, 1.2690, ..., -0.4135, 2.1203, 1.5432],\n", 901 | " [ 0.5159, -1.1182, 1.5664, ..., 0.0886, 2.9707, 1.4184],\n", 902 | " ...,\n", 903 | " [ 0.5881, -0.7907, 0.4915, ..., 1.0885, 0.8471, 1.3585],\n", 904 | " [ 0.2222, 0.1144, -0.9330, ..., -0.5350, 1.4225, 1.4420],\n", 905 | " [-0.5992, -1.6092, 0.4180, ..., 0.5187, 1.3839, -0.0053]],\n", 906 | "\n", 907 | " [[ 2.7917, -0.9076, 1.5639, ..., -1.2175, 0.2903, 0.6959],\n", 908 | " [ 0.7850, 0.0552, 1.4944, ..., -0.2021, 0.6601, 0.7340],\n", 909 | " [ 1.5903, 0.2466, 1.1094, ..., -0.2971, -0.7853, -0.0699],\n", 910 | " ...,\n", 911 | " [ 0.2741, 0.0100, 2.3455, ..., 0.7821, -2.3987, -2.0818],\n", 912 | " [ 0.8170, 1.3557, -0.2326, ..., 0.2625, -1.3975, -1.9373],\n", 913 | " [ 0.5239, -0.0389, -0.3231, ..., 0.1644, -1.7381, -2.8768]],\n", 914 | "\n", 915 | " [[-0.8328, -0.5694, -1.2800, ..., -1.3780, -0.2132, 0.1676],\n", 916 | " [-1.4498, 0.7050, -0.9369, ..., -1.5966, 0.6326, 0.8342],\n", 917 | " [-1.8059, -1.4517, -1.8759, ..., -1.9688, 1.1654, 2.0732],\n", 918 | " ...,\n", 919 | " [-1.3286, 1.5237, -1.2488, ..., -1.3609, 0.1855, -0.8612],\n", 920 | " [-1.1255, 1.5864, -0.6492, ..., -1.7942, -0.5066, -0.2912],\n", 921 | " [-1.6661, -0.1065, -0.4862, ..., -0.4511, -0.3101, 0.4935]]]],\n", 922 | " grad_fn=)), (tensor([[[[ 1.4972, 0.0548, -0.8492, ..., 0.7725, 0.7540, 0.3482],\n", 923 | " [ 0.1677, -0.7811, -0.6212, ..., -1.1105, -0.7886, -1.0756],\n", 924 | " [ 1.6930, 0.1337, -0.2441, ..., 1.2957, 1.4646, -1.2376],\n", 925 | " ...,\n", 926 | " [ 0.7339, -1.4922, -1.1063, ..., -1.2953, -0.7443, -1.4983],\n", 927 | " [ 2.2015, 0.0601, -0.6939, ..., 0.9228, -0.3546, -1.1624],\n", 928 | " [ 1.7473, 0.3134, 0.7153, ..., 0.3966, 0.5337, -0.0977]],\n", 929 | "\n", 930 | " [[ 1.4532, 0.0120, -1.8781, ..., 0.1758, 0.8710, 0.0357],\n", 931 | " [ 0.4232, 0.0848, -1.4846, ..., 1.3118, 1.0841, 0.8072],\n", 932 | " [-0.3530, 0.5586, -0.4459, ..., 0.7448, 0.7825, 1.5626],\n", 933 | " ...,\n", 934 | " [-1.6030, 0.9190, -0.8973, ..., 0.0339, 0.3640, -0.4783],\n", 935 | " [-1.8310, -0.7030, 0.5529, ..., 0.5007, -1.1818, 1.0585],\n", 936 | " [-2.7480, 0.3721, 0.2619, ..., 0.3845, 0.6930, -1.1113]],\n", 937 | "\n", 938 | " [[-2.0657, -0.0258, -1.1721, ..., -0.5406, -2.4868, 0.2531],\n", 939 | " [-1.0266, 0.5884, -0.2476, ..., -2.2111, 0.6811, 0.3764],\n", 940 | " [ 0.2181, -0.2988, -0.3583, ..., -1.7328, -1.0936, -0.6777],\n", 941 | " ...,\n", 942 | " [-0.1842, 0.6049, -0.2303, ..., 0.2667, 1.0310, -0.5046],\n", 943 | " [-1.3931, -0.0641, -0.4162, ..., -1.2164, -0.7951, -1.2568],\n", 944 | " [ 0.5488, 0.1479, -0.6829, ..., 0.5100, -1.2281, -0.6044]],\n", 945 | "\n", 946 | " ...,\n", 947 | "\n", 948 | " [[ 1.5960, -0.6297, 2.0505, ..., -1.3070, 0.4580, -0.2301],\n", 949 | " [ 0.8438, 1.1838, 0.2349, ..., -0.2540, 0.4186, -2.7262],\n", 950 | " [ 1.3582, 1.9834, 0.2390, ..., -0.0939, -0.5920, -0.4031],\n", 951 | " ...,\n", 952 | " [-0.4578, 0.6615, 1.2834, ..., -0.9606, 0.8916, 0.8069],\n", 953 | " [ 1.6657, -0.9052, 0.7405, ..., -2.0862, -1.6826, 
-1.3867],\n", 954 | " [-1.7419, -1.7181, -0.4886, ..., -2.0708, -1.3262, 0.3652]],\n", 955 | "\n", 956 | " [[ 0.8021, 0.1249, 0.7070, ..., -0.6813, -0.0800, 2.3084],\n", 957 | " [ 0.4512, 0.5156, 1.2959, ..., 0.0935, -0.7216, 0.1942],\n", 958 | " [ 0.2148, -0.7512, 0.9053, ..., 1.6215, 1.3450, 1.0556],\n", 959 | " ...,\n", 960 | " [ 0.6467, -0.7630, 2.1952, ..., 0.0064, 0.9875, 0.6260],\n", 961 | " [-0.0549, -0.8935, 2.0113, ..., 0.0311, -0.1896, -0.6099],\n", 962 | " [ 1.5192, -0.5745, 0.3256, ..., -0.2528, -1.6037, 0.2922]],\n", 963 | "\n", 964 | " [[-1.0550, -0.3205, 1.8077, ..., -1.3200, -1.4548, 1.5339],\n", 965 | " [-0.2584, -0.7082, 1.3633, ..., -0.8250, -1.1663, 2.0331],\n", 966 | " [ 0.1330, 0.9049, 1.3391, ..., -1.2514, -1.7474, -1.1612],\n", 967 | " ...,\n", 968 | " [ 0.1905, 0.6870, -0.2597, ..., -0.2404, 0.5366, 0.6617],\n", 969 | " [ 0.5768, -0.8065, -0.0583, ..., -0.9052, -0.6492, 0.2074],\n", 970 | " [-1.1593, 0.8707, -0.5213, ..., -0.4319, -0.4374, 0.9552]]]],\n", 971 | " grad_fn=), tensor([[[[ 1.2982, -0.9972, -0.3269, ..., -0.2488, 0.2613, 0.4454],\n", 972 | " [ 0.4781, -1.2991, 0.9245, ..., 0.1087, -0.5226, -1.7714],\n", 973 | " [-0.3153, -1.2818, 0.5826, ..., -0.9816, -0.4208, -0.5874],\n", 974 | " ...,\n", 975 | " [ 1.2261, -0.2078, 1.2430, ..., -1.0384, -0.4323, -1.8366],\n", 976 | " [-0.4600, -0.8842, 0.8262, ..., -0.1638, -0.9615, -2.0290],\n", 977 | " [-0.3516, -0.4009, 0.7691, ..., -0.4747, -0.5075, 0.1586]],\n", 978 | "\n", 979 | " [[ 0.2218, 0.3748, 1.2652, ..., 0.7203, 0.5564, 1.1997],\n", 980 | " [ 0.6081, 1.9555, -0.0429, ..., -1.1584, -0.4135, 1.3643],\n", 981 | " [-0.1917, 0.7643, -0.6769, ..., 0.0663, 0.0180, -1.0422],\n", 982 | " ...,\n", 983 | " [ 0.2401, 1.1992, 0.2265, ..., -0.0689, -1.2312, 0.8386],\n", 984 | " [ 0.1670, 1.0325, 0.3731, ..., -0.0377, 0.3169, -0.0356],\n", 985 | " [ 0.4813, -0.0289, 0.2292, ..., -0.2011, -0.3846, 1.3926]],\n", 986 | "\n", 987 | " [[-0.6230, -0.7001, -0.0511, ..., -0.4740, 0.1516, 1.8745],\n", 988 | " [-0.8654, 0.8075, -0.6455, ..., 1.2727, 0.1258, 0.0364],\n", 989 | " [ 0.6277, -1.3749, -1.8215, ..., 1.5821, 0.9678, 0.6904],\n", 990 | " ...,\n", 991 | " [ 2.0298, 0.4327, -1.0653, ..., 1.3597, -0.5206, -0.5428],\n", 992 | " [-0.2377, 0.4516, 0.8573, ..., 1.7334, 0.1379, 0.1815],\n", 993 | " [ 0.2135, -0.4674, 0.8017, ..., 1.1495, -0.0649, -0.5759]],\n", 994 | "\n", 995 | " ...,\n", 996 | "\n", 997 | " [[-1.2359, -1.1578, -1.0298, ..., 0.4045, 0.0210, -0.1862],\n", 998 | " [-1.1205, 0.6622, -0.9978, ..., 0.7161, -0.4713, -0.0663],\n", 999 | " [-0.0878, 0.4926, -0.0563, ..., -0.4914, 0.7963, 1.2020],\n", 1000 | " ...,\n", 1001 | " [ 0.7048, 0.8549, 0.0156, ..., 1.5859, 0.3024, -0.2348],\n", 1002 | " [-1.2799, -0.4781, 0.2954, ..., 1.2974, -1.3769, 1.0408],\n", 1003 | " [-0.6115, 2.5117, -0.0040, ..., 1.1728, 0.2376, -0.7232]],\n", 1004 | "\n", 1005 | " [[-0.6604, 1.3118, -2.1030, ..., 0.8193, -0.2627, -0.8209],\n", 1006 | " [-0.9965, -0.5900, -1.5980, ..., 0.0397, -0.1314, -0.2116],\n", 1007 | " [-0.9856, 0.7732, -1.5089, ..., 1.0963, 0.2897, 0.1653],\n", 1008 | " ...,\n", 1009 | " [ 0.8940, -1.6817, -0.2987, ..., 0.2403, -0.5766, -0.4909],\n", 1010 | " [ 1.4078, -0.1563, -0.9025, ..., -0.8853, 0.2263, -0.2672],\n", 1011 | " [-1.6737, -0.8676, 0.2286, ..., -0.8342, 1.1494, 0.1946]],\n", 1012 | "\n", 1013 | " [[ 1.2522, 0.5929, -0.7568, ..., 0.3075, 3.2335, -0.8140],\n", 1014 | " [ 0.6025, 1.7070, -0.2293, ..., 0.5756, 1.8738, 1.1110],\n", 1015 | " [ 1.2706, 1.6969, -1.1226, ..., 0.5002, 0.7001, -0.0096],\n", 1016 | " 
...,\n", 1017 | " [ 1.3437, -0.3139, -0.4875, ..., 0.6793, 0.4243, -0.9666],\n", 1018 | " [ 0.9501, 2.0200, -1.0568, ..., -2.2206, 2.3502, -1.2019],\n", 1019 | " [ 1.0499, 0.8757, -1.9000, ..., -2.6846, 0.6073, 0.6173]]]],\n", 1020 | " grad_fn=), tensor([[[[-0.8133, 2.3131, -0.6237, ..., 1.8601, 0.0450, 0.2721],\n", 1021 | " [ 0.6199, 0.7449, -0.4622, ..., 0.4508, -0.7734, 1.3844],\n", 1022 | " [ 1.1796, 0.5065, -1.1026, ..., 1.8926, -0.8349, 1.5607],\n", 1023 | " ...,\n", 1024 | " [ 0.1642, 1.0497, -0.0843, ..., -0.1039, -0.2775, 0.8779],\n", 1025 | " [ 1.0267, 1.3056, -0.1584, ..., 0.3172, -1.6172, 0.6704],\n", 1026 | " [-0.8565, 1.9017, -0.2574, ..., 0.4322, -1.2465, 0.3324]],\n", 1027 | "\n", 1028 | " [[ 0.5507, -0.6337, 0.4014, ..., -1.5466, -1.9956, 0.5846],\n", 1029 | " [ 0.7735, -2.6480, 1.6007, ..., -0.9854, -1.4292, -1.6044],\n", 1030 | " [ 0.7047, -1.0891, 1.3642, ..., -0.0228, -1.5030, -0.7269],\n", 1031 | " ...,\n", 1032 | " [-0.5708, -0.2313, -0.2430, ..., -1.5256, -1.0448, -0.6344],\n", 1033 | " [-0.8687, 0.2726, -0.0084, ..., -0.3110, -1.5450, -0.8091],\n", 1034 | " [ 0.3754, -1.8367, 0.0987, ..., -1.4098, -0.0887, 0.0829]],\n", 1035 | "\n", 1036 | " [[-0.3976, -0.2779, -0.4185, ..., -0.2048, 2.2570, 1.2822],\n", 1037 | " [ 0.1871, -0.9892, -0.8770, ..., 0.2097, 1.0988, 3.1225],\n", 1038 | " [-1.0281, -1.9130, 0.4286, ..., 1.0727, 1.3678, 1.3618],\n", 1039 | " ...,\n", 1040 | " [-0.1614, -1.0152, -0.0461, ..., 0.2102, -0.8302, 1.7495],\n", 1041 | " [ 0.0331, 1.0859, -0.4850, ..., -0.9841, -1.4613, 0.6022],\n", 1042 | " [-0.2613, 2.8318, -1.3642, ..., 0.1756, 0.4956, -0.4052]],\n", 1043 | "\n", 1044 | " ...,\n", 1045 | "\n", 1046 | " [[ 0.2311, 1.3527, 0.6797, ..., -0.7439, -0.7342, 3.2949],\n", 1047 | " [-1.0338, -0.0236, 1.3021, ..., -0.2475, -1.0406, 3.3090],\n", 1048 | " [-0.5765, 0.6739, 1.2303, ..., 0.3393, 0.0555, 1.3480],\n", 1049 | " ...,\n", 1050 | " [ 0.4424, -0.4704, -1.5654, ..., 1.0121, 0.0781, -0.3446],\n", 1051 | " [-0.6506, 0.6219, -1.8603, ..., 0.8730, -0.5954, 0.2033],\n", 1052 | " [ 0.6416, -1.7000, -0.5156, ..., 1.7490, 0.9951, 0.1511]],\n", 1053 | "\n", 1054 | " [[-0.9009, 1.2925, 0.5161, ..., 1.7396, 0.6485, -1.8368],\n", 1055 | " [-1.0451, -1.1681, -2.4656, ..., -0.3525, -0.0661, -0.6220],\n", 1056 | " [-1.3485, -0.7776, -0.4497, ..., -0.0853, -0.4253, 1.2824],\n", 1057 | " ...,\n", 1058 | " [-0.6166, 0.5124, 0.0270, ..., 0.9915, 1.1851, -0.2333],\n", 1059 | " [-0.6399, -0.2243, 1.3538, ..., 0.5791, 0.0939, 0.4213],\n", 1060 | " [ 1.1871, 0.5937, 0.5900, ..., 0.1595, 0.1583, 0.5446]],\n", 1061 | "\n", 1062 | " [[-0.1095, 2.4405, -0.3235, ..., 0.0165, -0.8736, -1.8667],\n", 1063 | " [-0.4798, 1.8698, 0.5644, ..., 0.8450, -0.4490, -0.0879],\n", 1064 | " [-0.1944, 2.4533, 0.7569, ..., 0.4661, -2.1294, -0.7186],\n", 1065 | " ...,\n", 1066 | " [-0.0828, 1.2430, 2.5390, ..., 0.3106, -0.1020, -1.2724],\n", 1067 | " [ 0.4107, 3.6248, 1.7379, ..., 1.0112, 0.2110, -1.6212],\n", 1068 | " [ 0.2735, 1.5502, 1.8762, ..., 0.3584, 0.6244, -0.2083]]]],\n", 1069 | " grad_fn=), tensor([[[[ 0.1085, -0.1901, 0.7732, ..., -1.5325, -0.8262, 0.7557],\n", 1070 | " [-0.6599, -0.6945, 0.8944, ..., -0.0404, -0.2255, -0.1132],\n", 1071 | " [-1.1583, -1.7888, 0.9448, ..., 0.4826, -0.2198, 0.1881],\n", 1072 | " ...,\n", 1073 | " [ 0.6431, -1.6765, -0.6248, ..., 0.4865, -1.5438, 0.7231],\n", 1074 | " [ 0.9629, -0.9633, -1.4301, ..., -0.3133, -1.4580, -0.0413],\n", 1075 | " [-0.6191, -0.6393, -1.2475, ..., -0.9843, -0.7119, 0.5108]],\n", 1076 | "\n", 1077 | " [[-0.8243, 0.5345, 
0.4964, ..., 0.5066, 0.6636, -0.1276],\n", 1078 | " [-1.2929, -2.7970, 1.1191, ..., -0.2642, 0.6227, 0.6867],\n", 1079 | " [-0.2780, -2.6756, -0.7292, ..., 0.1363, 1.4578, 0.5750],\n", 1080 | " ...,\n", 1081 | " [ 0.9285, 0.3331, 0.7417, ..., 1.1403, 0.6361, 0.5883],\n", 1082 | " [ 1.7821, 0.6418, -0.1684, ..., 0.2316, 0.6381, 1.1744],\n", 1083 | " [-0.0433, 0.6234, -1.0403, ..., -0.0614, 2.8693, 0.1322]],\n", 1084 | "\n", 1085 | " [[-0.8548, -0.2886, -0.2150, ..., 0.6713, 0.4732, -0.4427],\n", 1086 | " [ 0.6549, 0.5047, -1.1564, ..., 1.1409, -2.1023, 0.8922],\n", 1087 | " [-0.5133, 0.4339, -1.8949, ..., 0.7632, -1.7318, 0.8244],\n", 1088 | " ...,\n", 1089 | " [-1.1167, 0.7593, -2.9044, ..., 1.6787, -1.2371, 0.2394],\n", 1090 | " [-0.4085, -0.1975, -1.6200, ..., 0.4523, -0.5311, 0.1977],\n", 1091 | " [-0.5352, 0.2527, 0.0671, ..., 0.4205, 0.2510, -0.3969]],\n", 1092 | "\n", 1093 | " ...,\n", 1094 | "\n", 1095 | " [[ 0.8030, -0.5910, -0.3730, ..., 1.1750, -1.9027, 1.9609],\n", 1096 | " [-0.3013, -0.2259, -2.1837, ..., -0.8806, -1.4036, -1.2320],\n", 1097 | " [ 0.1819, 0.2404, -2.1868, ..., 0.9590, -0.8307, -0.2578],\n", 1098 | " ...,\n", 1099 | " [-0.6214, 0.0472, 0.5830, ..., 0.0867, 1.6352, 0.5018],\n", 1100 | " [-0.4400, -0.2661, 0.6328, ..., 0.0781, 1.1012, 1.3644],\n", 1101 | " [-0.2058, -0.6648, 1.1162, ..., -0.6451, 1.5585, 0.4018]],\n", 1102 | "\n", 1103 | " [[-0.7288, 0.0258, 0.3674, ..., -0.6008, -0.2705, 0.2849],\n", 1104 | " [-0.3574, 0.1402, 0.8523, ..., -1.1311, -0.0499, -0.7463],\n", 1105 | " [-0.1369, 0.9946, -0.4872, ..., -0.7923, 0.8450, 1.4764],\n", 1106 | " ...,\n", 1107 | " [-1.9335, 0.8847, 1.6469, ..., -2.2084, -0.6117, -0.6171],\n", 1108 | " [-1.0596, 2.0017, 1.1226, ..., -1.2155, 0.6042, 0.0387],\n", 1109 | " [-1.1470, 0.5434, 0.9675, ..., 0.0203, 0.1381, -0.4362]],\n", 1110 | "\n", 1111 | " [[-0.6805, -1.1255, -0.1703, ..., -1.6133, 0.5021, -1.0495],\n", 1112 | " [-0.9012, -1.3933, 0.5381, ..., -1.2507, -1.4188, -1.3195],\n", 1113 | " [-0.7038, -0.6248, 0.5233, ..., -1.2481, -0.9742, -1.3469],\n", 1114 | " ...,\n", 1115 | " [ 1.0825, 0.3802, 1.0406, ..., -0.3246, 0.2143, 0.5941],\n", 1116 | " [-0.4559, -0.4408, 1.0708, ..., -1.3772, 0.4639, 0.3717],\n", 1117 | " [ 1.1691, -0.1498, -0.4476, ..., -1.0607, -0.7118, -1.3441]]]],\n", 1118 | " grad_fn=)), (tensor([[[[-1.2185e+00, -9.9815e-01, 1.9104e+00, ..., 1.2167e+00,\n", 1119 | " 1.0378e+00, -1.0581e-02],\n", 1120 | " [-1.4640e+00, -2.7529e-01, 4.4344e-01, ..., 1.3074e+00,\n", 1121 | " 1.0477e+00, -3.7739e-01],\n", 1122 | " [-6.9174e-01, -6.4124e-01, 9.1028e-02, ..., 1.8960e+00,\n", 1123 | " 9.6028e-01, -4.5290e-01],\n", 1124 | " ...,\n", 1125 | " [-1.3568e+00, 4.3083e-01, 5.9462e-01, ..., -2.5880e-01,\n", 1126 | " -6.8505e-01, 3.6925e-01],\n", 1127 | " [-1.1006e+00, -3.6085e-01, -7.2609e-01, ..., 1.8191e-01,\n", 1128 | " -1.1100e+00, -9.7719e-01],\n", 1129 | " [-1.0325e+00, -7.4450e-01, 5.7649e-01, ..., 8.8636e-01,\n", 1130 | " -1.8957e-02, 1.2330e+00]],\n", 1131 | "\n", 1132 | " [[-1.5560e+00, -5.7987e-01, -1.3682e+00, ..., -4.8995e-01,\n", 1133 | " 1.3291e+00, 1.0491e-01],\n", 1134 | " [-1.7694e+00, 4.7057e-01, 1.8213e-01, ..., 4.2604e-01,\n", 1135 | " 2.6254e-02, -6.2608e-01],\n", 1136 | " [-1.2043e+00, 1.3721e+00, -1.6274e-01, ..., -4.9929e-01,\n", 1137 | " 1.3951e+00, -3.7089e-01],\n", 1138 | " ...,\n", 1139 | " [-6.0726e-01, 1.6715e-01, 1.2856e+00, ..., -9.0832e-02,\n", 1140 | " -5.3774e-01, 2.3115e-01],\n", 1141 | " [ 2.4984e-01, -3.2600e-01, -5.0830e-02, ..., 4.4040e-02,\n", 1142 | " 4.6672e-01, 
-9.7391e-01],\n", 1143 | " [-3.7426e-01, -2.1384e-01, 3.2299e-01, ..., -1.9327e-01,\n", 1144 | " -4.7217e-01, 3.7367e-01]],\n", 1145 | "\n", 1146 | " [[-5.4768e-01, -8.6922e-02, -2.3597e+00, ..., -6.5527e-01,\n", 1147 | " -7.9475e-01, 9.7313e-01],\n", 1148 | " [-1.4755e+00, -7.4756e-01, -2.1650e+00, ..., 1.3291e+00,\n", 1149 | " -3.7727e-01, 1.0992e+00],\n", 1150 | " [-5.8730e-02, -1.5283e+00, -2.1648e+00, ..., -9.2384e-02,\n", 1151 | " 2.0341e+00, 1.1478e+00],\n", 1152 | " ...,\n", 1153 | " [-8.8310e-01, -3.1663e-01, -7.0109e-01, ..., -8.0721e-02,\n", 1154 | " 1.2573e-01, 8.8213e-01],\n", 1155 | " [-6.3314e-01, 7.6389e-01, -2.2311e+00, ..., -7.4142e-01,\n", 1156 | " -9.0397e-02, -3.5889e-01],\n", 1157 | " [ 8.9602e-01, 2.4813e-01, -8.5474e-01, ..., -1.3538e-01,\n", 1158 | " -8.7224e-01, 7.8633e-01]],\n", 1159 | "\n", 1160 | " ...,\n", 1161 | "\n", 1162 | " [[ 4.2700e-01, -7.0779e-01, -1.6204e+00, ..., -5.8012e-01,\n", 1163 | " 1.6771e+00, -4.2902e-01],\n", 1164 | " [ 1.9616e+00, -1.5372e-01, -6.0260e-01, ..., -5.9763e-01,\n", 1165 | " 8.9153e-03, 8.9344e-02],\n", 1166 | " [ 2.3815e-01, -1.7201e+00, -1.5573e+00, ..., -1.0446e+00,\n", 1167 | " 9.0044e-01, -1.0920e+00],\n", 1168 | " ...,\n", 1169 | " [ 1.3247e+00, -1.6274e+00, 1.1333e+00, ..., 3.6685e-01,\n", 1170 | " 1.3327e+00, -8.0685e-01],\n", 1171 | " [ 2.7654e-02, -2.3357e+00, 1.9328e-02, ..., -1.5136e+00,\n", 1172 | " -6.6051e-05, -2.3928e-01],\n", 1173 | " [-7.1701e-01, -1.7551e+00, -7.1057e-01, ..., -4.0799e-01,\n", 1174 | " -5.7211e-01, -1.0299e+00]],\n", 1175 | "\n", 1176 | " [[ 8.9901e-01, -7.3133e-01, -1.7865e-01, ..., 1.6582e+00,\n", 1177 | " 1.3053e+00, -1.3414e-01],\n", 1178 | " [ 4.7168e-01, 8.1973e-01, 1.4725e+00, ..., 6.0384e-01,\n", 1179 | " 8.8135e-01, -1.5466e+00],\n", 1180 | " [ 1.0666e+00, 4.3672e-01, 2.2894e+00, ..., 1.3774e+00,\n", 1181 | " 4.0296e-01, -1.6471e+00],\n", 1182 | " ...,\n", 1183 | " [ 1.4741e+00, 2.3786e-01, 1.4417e+00, ..., 8.5788e-01,\n", 1184 | " -5.0899e-01, 4.3582e-01],\n", 1185 | " [-1.2204e+00, 9.2238e-01, 2.1288e+00, ..., 7.2993e-01,\n", 1186 | " -4.5158e-01, -1.9127e-01],\n", 1187 | " [ 8.6269e-01, 2.2572e-01, 2.0845e-01, ..., 2.6308e+00,\n", 1188 | " 7.9442e-01, 2.1597e-01]],\n", 1189 | "\n", 1190 | " [[-8.0424e-02, 9.8001e-01, 8.7290e-02, ..., -1.3556e+00,\n", 1191 | " -5.6701e-01, -4.5394e-01],\n", 1192 | " [ 5.1915e-01, 4.7460e-01, -6.2491e-01, ..., -4.4040e-01,\n", 1193 | " -7.1202e-01, 5.9684e-01],\n", 1194 | " [ 1.2885e+00, 8.4012e-01, -6.3085e-01, ..., -8.8185e-01,\n", 1195 | " -1.8821e-01, 4.9784e-01],\n", 1196 | " ...,\n", 1197 | " [ 4.5184e-01, 1.9999e+00, -9.1064e-01, ..., -6.8581e-01,\n", 1198 | " 4.8856e-02, 5.8514e-01],\n", 1199 | " [-6.0370e-01, 1.0453e+00, -3.1292e-01, ..., -5.8249e-01,\n", 1200 | " -1.8893e+00, 8.0134e-01],\n", 1201 | " [-1.2738e-01, 3.8033e-01, -1.2482e+00, ..., -8.6938e-01,\n", 1202 | " -9.9258e-01, -7.9568e-01]]]], grad_fn=), tensor([[[[-0.4724, 1.0839, 0.8063, ..., 1.1750, -0.4685, -1.6946],\n", 1203 | " [-1.3082, -0.8983, -0.7880, ..., 0.2600, -0.9099, -1.2397],\n", 1204 | " [-0.8279, -1.3326, -1.9180, ..., -0.3496, 0.3952, -0.6045],\n", 1205 | " ...,\n", 1206 | " [-0.2693, -0.9078, 1.1241, ..., -0.0778, 1.3822, -2.3358],\n", 1207 | " [-0.4133, -1.4295, -0.5930, ..., 0.2488, -0.1754, -1.0187],\n", 1208 | " [ 0.2094, -0.7800, 1.1726, ..., 0.0368, 1.1228, -0.8926]],\n", 1209 | "\n", 1210 | " [[ 0.4937, -0.5046, 0.5110, ..., -1.0513, 0.6154, -0.5309],\n", 1211 | " [ 0.9847, 1.5285, 2.0509, ..., 0.2309, 0.8909, 0.8726],\n", 1212 | " [ 1.0879, 0.0904, 2.1371, 
..., -0.5120, 1.8586, -0.6129],\n", 1213 | " ...,\n", 1214 | " [ 1.4883, 1.1961, 0.4841, ..., -2.3308, 0.5332, 0.0039],\n", 1215 | " [-0.3977, -0.3255, 2.2729, ..., 0.0708, 1.7200, 0.8302],\n", 1216 | " [ 0.8156, -0.9793, 1.0732, ..., -1.5901, 1.0048, 0.1883]],\n", 1217 | "\n", 1218 | " [[-0.3255, 1.4261, -1.5275, ..., 0.3556, 0.4650, 0.4654],\n", 1219 | " [ 1.1065, 0.9324, -1.4242, ..., -0.0504, 0.6465, 0.4618],\n", 1220 | " [-0.5205, 2.4605, -1.3291, ..., -0.2079, -1.3804, 0.2259],\n", 1221 | " ...,\n", 1222 | " [-1.2338, -0.0747, -1.3979, ..., -1.3914, 0.3787, 0.6734],\n", 1223 | " [-0.4880, 0.3470, -0.8381, ..., -2.3530, 0.3590, 0.7542],\n", 1224 | " [ 1.5746, 1.3896, -0.0514, ..., -0.6396, 0.9713, 0.2103]],\n", 1225 | "\n", 1226 | " ...,\n", 1227 | "\n", 1228 | " [[-0.9752, 1.5632, -1.1736, ..., -0.0720, -0.3175, -1.5312],\n", 1229 | " [-0.6887, 0.1593, -0.6501, ..., 0.3012, -0.0565, -1.1415],\n", 1230 | " [ 0.0201, 1.9441, -1.7071, ..., -0.9138, -0.8830, -1.9346],\n", 1231 | " ...,\n", 1232 | " [-0.7668, 0.1580, -0.8772, ..., 0.1377, 1.2481, -0.7234],\n", 1233 | " [-0.9813, 1.0504, -2.2115, ..., 0.0964, 1.3161, -0.6810],\n", 1234 | " [-0.4299, -0.0724, -1.3412, ..., 0.0778, 1.9269, 0.0768]],\n", 1235 | "\n", 1236 | " [[ 0.8825, 0.4658, -0.8383, ..., 2.3657, 1.2150, -0.1660],\n", 1237 | " [-0.6860, -1.2839, 0.0251, ..., -0.2746, 1.0535, -1.0681],\n", 1238 | " [ 1.2790, 0.1710, -0.2096, ..., 0.4369, 0.4076, -0.4182],\n", 1239 | " ...,\n", 1240 | " [-0.0316, 1.2719, -0.5905, ..., 1.5969, -0.3511, 0.9447],\n", 1241 | " [ 0.4255, 2.1268, -0.5760, ..., 0.4775, 0.0180, 0.4869],\n", 1242 | " [-0.2970, -0.5588, -0.2625, ..., 1.5835, 0.0854, 0.4130]],\n", 1243 | "\n", 1244 | " [[-0.4890, -1.2588, -0.7160, ..., -0.2877, -0.4075, 0.2001],\n", 1245 | " [-0.7699, -0.1912, -0.1076, ..., -1.0989, 0.0496, 0.2020],\n", 1246 | " [ 0.3435, -2.0022, -1.6633, ..., -0.6451, 0.5222, 2.5624],\n", 1247 | " ...,\n", 1248 | " [-0.6927, -0.1054, -1.2639, ..., -1.6255, -0.4256, -0.2608],\n", 1249 | " [-0.0618, 0.1150, -0.7448, ..., -0.9205, -0.4802, 2.0752],\n", 1250 | " [-2.6176, 0.5340, -0.6606, ..., -0.3093, -0.0139, 0.2577]]]],\n", 1251 | " grad_fn=), tensor([[[[ 3.4533e-01, 3.6772e-01, 1.4009e-01, ..., -1.1620e+00,\n", 1252 | " 9.9274e-02, 4.6244e-01],\n", 1253 | " [ 1.3212e+00, -1.4218e+00, 1.5194e+00, ..., -1.2160e+00,\n", 1254 | " 9.2627e-01, 9.5089e-01],\n", 1255 | " [-6.7195e-01, -5.3049e-01, 1.3918e+00, ..., 4.7604e-03,\n", 1256 | " 1.1245e+00, 2.1688e+00],\n", 1257 | " ...,\n", 1258 | " [ 1.4063e-01, -3.0661e-02, 6.0490e-01, ..., 7.9748e-01,\n", 1259 | " 2.1934e+00, -4.0737e-01],\n", 1260 | " [-6.1399e-01, 5.8239e-01, 1.8883e-01, ..., -2.1115e-01,\n", 1261 | " 2.9063e-01, -4.2740e-01],\n", 1262 | " [-4.9458e-01, -1.3269e+00, 5.2839e-01, ..., 4.8899e-01,\n", 1263 | " 7.8337e-01, 3.8324e-02]],\n", 1264 | "\n", 1265 | " [[ 5.8023e-01, 1.9755e-01, -2.4368e+00, ..., -1.2456e-01,\n", 1266 | " 1.9437e-01, -7.6228e-01],\n", 1267 | " [-2.9588e-01, 1.5515e+00, -1.5377e+00, ..., -7.5390e-01,\n", 1268 | " 1.8940e-01, -1.1484e+00],\n", 1269 | " [ 2.4831e-01, 7.0364e-01, -2.5617e+00, ..., -6.6360e-01,\n", 1270 | " 1.2217e+00, -9.0076e-01],\n", 1271 | " ...,\n", 1272 | " [ 3.9211e-01, -1.0672e+00, -1.3936e+00, ..., 2.5601e-01,\n", 1273 | " -2.2418e+00, -4.7488e-01],\n", 1274 | " [ 2.0644e-01, 3.2161e-01, -7.0632e-01, ..., 1.1398e+00,\n", 1275 | " -6.2195e-01, -3.9433e-01],\n", 1276 | " [ 7.0337e-01, -1.0694e+00, 1.1198e+00, ..., -8.0356e-01,\n", 1277 | " -6.6937e-01, 4.6251e-02]],\n", 1278 | "\n", 1279 | " 
[[-1.3317e+00, 5.3823e-01, 1.6805e+00, ..., 5.6313e-01,\n", 1280 | " 7.1131e-01, -6.5080e-01],\n", 1281 | " [-1.8674e+00, -3.7794e-01, -7.3050e-02, ..., -7.8069e-02,\n", 1282 | " 9.2026e-01, 8.6108e-02],\n", 1283 | " [-1.2166e+00, 1.8130e-01, 2.8462e-01, ..., -1.1308e+00,\n", 1284 | " -4.1265e-02, 5.8018e-01],\n", 1285 | " ...,\n", 1286 | " [-9.0016e-01, 1.3675e+00, 1.8459e-01, ..., -7.8941e-01,\n", 1287 | " 4.1652e-02, 2.3315e-01],\n", 1288 | " [-1.5644e+00, 9.0743e-02, -2.1541e-01, ..., 1.1666e-01,\n", 1289 | " -6.9557e-02, -3.8731e-01],\n", 1290 | " [-8.5084e-01, 6.1508e-01, 6.9345e-02, ..., -6.8922e-01,\n", 1291 | " -1.0270e+00, 3.4642e-01]],\n", 1292 | "\n", 1293 | " ...,\n", 1294 | "\n", 1295 | " [[ 7.0962e-02, -2.0825e-01, 1.6565e-02, ..., -3.1266e-01,\n", 1296 | " 1.5602e+00, -3.9295e-01],\n", 1297 | " [-4.4136e-01, 1.1288e+00, 1.7785e-01, ..., -7.4895e-01,\n", 1298 | " 1.8970e+00, -1.3966e-01],\n", 1299 | " [ 3.5275e-01, -7.7085e-02, 9.1507e-01, ..., -7.7748e-01,\n", 1300 | " 1.0957e+00, -2.7197e-01],\n", 1301 | " ...,\n", 1302 | " [-9.5285e-01, 9.4714e-01, -1.0834e+00, ..., -1.0783e+00,\n", 1303 | " 8.0203e-01, 2.7291e+00],\n", 1304 | " [ 5.3636e-04, 1.1620e+00, -6.7941e-01, ..., 4.0094e-01,\n", 1305 | " 4.4432e-01, 1.2459e+00],\n", 1306 | " [-1.9921e+00, 1.9394e+00, 1.1929e+00, ..., 1.3809e+00,\n", 1307 | " 1.5925e+00, -6.6769e-02]],\n", 1308 | "\n", 1309 | " [[ 1.2273e+00, 1.5531e+00, 6.8027e-01, ..., -1.2246e+00,\n", 1310 | " 1.3223e+00, 5.2772e-01],\n", 1311 | " [ 1.4708e+00, 1.3777e+00, -5.5454e-01, ..., -2.2135e+00,\n", 1312 | " 3.2608e-01, 5.2835e-01],\n", 1313 | " [ 5.8191e-01, 1.6842e+00, 6.7932e-01, ..., -1.4999e+00,\n", 1314 | " 7.3916e-01, 9.5131e-01],\n", 1315 | " ...,\n", 1316 | " [ 1.1816e+00, 7.8476e-02, 4.2217e-01, ..., -1.4079e+00,\n", 1317 | " 2.3544e+00, -1.6094e+00],\n", 1318 | " [ 1.5487e+00, -3.0230e-01, 3.4758e-01, ..., -9.6911e-01,\n", 1319 | " 2.5784e+00, -5.3619e-01],\n", 1320 | " [ 3.5795e-01, -1.6161e-02, 3.0310e-01, ..., -3.4414e-01,\n", 1321 | " 6.4354e-01, -4.1566e-01]],\n", 1322 | "\n", 1323 | " [[-6.1810e-01, 5.3711e-01, -1.5970e+00, ..., -2.9981e-01,\n", 1324 | " 1.0876e+00, 9.1181e-01],\n", 1325 | " [-7.5442e-01, -5.0500e-01, -1.4781e+00, ..., -1.8957e+00,\n", 1326 | " 1.0061e+00, 5.5207e-01],\n", 1327 | " [ 1.6183e+00, -5.9132e-01, -1.3302e+00, ..., -2.4762e+00,\n", 1328 | " 1.0373e+00, 1.2813e+00],\n", 1329 | " ...,\n", 1330 | " [-3.2647e-01, -8.3790e-01, 1.7109e-01, ..., -3.6162e-01,\n", 1331 | " 1.3515e+00, 4.9251e-01],\n", 1332 | " [-1.2880e+00, -6.1401e-01, -9.2909e-01, ..., -9.7167e-01,\n", 1333 | " -1.1015e-01, -7.4352e-01],\n", 1334 | " [ 1.9131e+00, 1.8491e-01, 9.8593e-01, ..., -7.2909e-01,\n", 1335 | " -1.4237e+00, -1.8890e+00]]]], grad_fn=), tensor([[[[ 7.2166e-01, 9.2356e-02, -2.4815e-01, ..., -6.4101e-01,\n", 1336 | " -4.4150e-01, 6.0808e-01],\n", 1337 | " [-6.1065e-02, -9.6135e-02, -1.2675e+00, ..., -2.2511e+00,\n", 1338 | " -3.3434e-01, 4.5125e-01],\n", 1339 | " [-4.9817e-01, 1.2921e-01, -5.3448e-01, ..., 4.9934e-01,\n", 1340 | " -1.4995e+00, 9.6123e-01],\n", 1341 | " ...,\n", 1342 | " [-3.3240e-01, 4.3268e-01, 9.6952e-01, ..., -4.7847e-01,\n", 1343 | " -5.6016e-01, 3.0073e-01],\n", 1344 | " [-9.3579e-01, 5.5315e-01, 1.2958e+00, ..., -1.9586e-02,\n", 1345 | " -4.4452e-01, 1.1003e+00],\n", 1346 | " [-1.4901e+00, -6.1557e-01, 8.5143e-01, ..., -4.3471e-01,\n", 1347 | " -5.0050e-01, 1.4610e+00]],\n", 1348 | "\n", 1349 | " [[ 2.0977e+00, 3.0291e-01, 1.1418e+00, ..., -1.2034e+00,\n", 1350 | " -4.6665e-01, 1.6795e-01],\n", 1351 | " [ 
1.3980e+00, -5.4897e-02, 1.6844e-01, ..., -8.8923e-01,\n", 1352 | " 1.5860e-01, 1.9334e+00],\n", 1353 | " [ 2.6188e+00, 8.8978e-01, -6.3999e-01, ..., -9.9829e-01,\n", 1354 | " 7.1889e-01, 1.6423e+00],\n", 1355 | " ...,\n", 1356 | " [ 1.9562e+00, 3.5817e-01, 6.5640e-01, ..., 1.1465e+00,\n", 1357 | " -7.4524e-01, 4.7925e-01],\n", 1358 | " [ 1.6226e+00, -1.0050e+00, 3.0921e-01, ..., 1.4272e+00,\n", 1359 | " -1.9355e-01, 2.1481e+00],\n", 1360 | " [-1.4425e+00, 1.9237e+00, 5.1812e-01, ..., -1.6134e+00,\n", 1361 | " 4.2979e-01, 4.5610e-01]],\n", 1362 | "\n", 1363 | " [[ 2.8024e-01, 3.2364e+00, 4.3362e-01, ..., -7.6633e-01,\n", 1364 | " -1.4491e+00, 9.2567e-02],\n", 1365 | " [ 5.7283e-01, 2.1395e+00, -1.0499e-01, ..., 1.6081e+00,\n", 1366 | " -2.6995e-01, 3.8909e-01],\n", 1367 | " [ 7.6129e-01, 2.8020e+00, -1.1295e+00, ..., -1.0405e+00,\n", 1368 | " -2.4585e-01, -4.9060e-01],\n", 1369 | " ...,\n", 1370 | " [-1.0129e+00, 1.3191e+00, 8.2721e-01, ..., 2.8205e-01,\n", 1371 | " -1.8482e+00, -1.5369e+00],\n", 1372 | " [ 1.1092e-01, 1.0381e+00, 9.5222e-02, ..., 3.7420e-01,\n", 1373 | " -6.0539e-01, -1.3162e-01],\n", 1374 | " [ 1.3017e-01, -6.1051e-01, -3.5774e-02, ..., -4.2809e-01,\n", 1375 | " -4.4291e-01, -1.6395e-01]],\n", 1376 | "\n", 1377 | " ...,\n", 1378 | "\n", 1379 | " [[-5.1135e-03, -2.9618e-01, -1.1733e+00, ..., -1.4387e+00,\n", 1380 | " -7.3877e-01, 5.7792e-01],\n", 1381 | " [-4.5224e-01, -5.6453e-01, -6.1608e-01, ..., -3.6164e-01,\n", 1382 | " -9.0108e-01, -1.4809e-01],\n", 1383 | " [ 1.1111e-01, 1.0815e+00, -1.1741e+00, ..., -2.9468e-01,\n", 1384 | " -4.5984e-01, -4.7170e-01],\n", 1385 | " ...,\n", 1386 | " [-8.2575e-01, -1.1125e+00, 8.0271e-01, ..., 2.3293e+00,\n", 1387 | " -5.1809e-01, 9.4189e-01],\n", 1388 | " [-6.1359e-01, 4.0746e-01, 1.3906e+00, ..., 2.4095e-01,\n", 1389 | " 1.4703e+00, 4.5296e-01],\n", 1390 | " [ 1.6233e-01, -5.6581e-01, 1.1542e+00, ..., -1.0676e+00,\n", 1391 | " 6.9428e-01, -7.7902e-01]],\n", 1392 | "\n", 1393 | " [[-1.0683e+00, -1.3030e+00, 1.8886e-01, ..., -4.6987e-01,\n", 1394 | " -7.5394e-01, 4.5390e-01],\n", 1395 | " [-2.3283e+00, -1.0419e+00, -7.2467e-02, ..., -1.1481e+00,\n", 1396 | " 1.3052e+00, 9.1777e-01],\n", 1397 | " [-2.5894e+00, -2.1335e+00, 1.5768e-01, ..., -9.6493e-01,\n", 1398 | " -1.0538e-01, 1.6921e+00],\n", 1399 | " ...,\n", 1400 | " [ 8.3145e-01, -2.1815e+00, 7.0426e-01, ..., 6.9153e-01,\n", 1401 | " 8.3484e-01, 1.0936e+00],\n", 1402 | " [ 5.8033e-01, -1.3552e+00, -6.7996e-03, ..., -2.0882e-02,\n", 1403 | " -8.3338e-01, 1.3986e+00],\n", 1404 | " [ 1.6030e-01, 6.3602e-01, -2.4299e-02, ..., 8.3047e-01,\n", 1405 | " 2.4757e+00, 7.8426e-01]],\n", 1406 | "\n", 1407 | " [[ 1.9515e+00, 7.8886e-01, 1.3471e+00, ..., 2.6180e-01,\n", 1408 | " -7.4249e-01, 4.0844e-01],\n", 1409 | " [-3.5427e-01, -3.2629e-01, 5.5673e-01, ..., -8.1221e-01,\n", 1410 | " -8.3073e-01, -1.2931e+00],\n", 1411 | " [ 1.0102e+00, 3.1299e-01, 2.5735e-01, ..., -7.9141e-01,\n", 1412 | " -1.4085e+00, 1.6370e-01],\n", 1413 | " ...,\n", 1414 | " [ 1.8661e-01, -9.6310e-01, -2.0772e-03, ..., -2.2752e+00,\n", 1415 | " -2.3609e+00, -8.7359e-01],\n", 1416 | " [-6.1659e-02, 7.7845e-02, -5.8011e-01, ..., -2.2967e+00,\n", 1417 | " -1.4195e+00, -3.4466e-01],\n", 1418 | " [ 1.5605e+00, -1.1121e+00, -6.2791e-02, ..., -5.9877e-01,\n", 1419 | " -3.3908e-01, -7.9289e-01]]]], grad_fn=)), (tensor([[[[ 0.6375, -1.0607, -0.0506, ..., 0.7781, 0.6974, 0.3603],\n", 1420 | " [ 0.4155, -0.8113, 0.4818, ..., 0.4538, -0.5418, -0.6881],\n", 1421 | " [ 0.7690, -0.3836, 0.1099, ..., 0.1694, 0.2906, 0.4316],\n", 1422 
| " ...,\n", 1423 | " [ 0.0523, -0.2854, -0.2641, ..., 0.9500, -0.1795, 0.6425],\n", 1424 | " [ 0.0485, -0.2620, -0.5422, ..., 2.7049, 0.0583, -0.5140],\n", 1425 | " [ 1.5136, -1.5457, -0.7203, ..., 0.2318, -1.2807, 0.5886]],\n", 1426 | "\n", 1427 | " [[-0.4264, -0.8909, -1.2484, ..., -1.2550, -0.3939, 0.3064],\n", 1428 | " [ 1.0063, -0.6523, -0.7345, ..., 0.7056, -1.3964, 0.9859],\n", 1429 | " [-1.0270, -0.9746, -0.8832, ..., 0.6635, -0.0391, 1.1462],\n", 1430 | " ...,\n", 1431 | " [ 0.6167, -1.4391, -0.0119, ..., 1.5079, -1.4389, 0.9198],\n", 1432 | " [ 0.1429, -0.7633, -0.6261, ..., 1.6468, -0.6130, 1.7652],\n", 1433 | " [ 0.4483, -0.1122, -0.0800, ..., 1.1904, -0.6574, 1.0487]],\n", 1434 | "\n", 1435 | " [[-0.8241, -0.9378, 0.6394, ..., -0.9060, -0.3517, -0.7206],\n", 1436 | " [-1.4427, -0.9361, -0.4451, ..., -1.2586, -1.9313, -0.4450],\n", 1437 | " [-1.3042, -0.1520, -0.7851, ..., -0.7048, -0.9187, -0.2479],\n", 1438 | " ...,\n", 1439 | " [-0.7560, -1.0044, -0.1770, ..., -0.8120, 0.5236, 1.3849],\n", 1440 | " [-0.2566, -0.8949, -0.0375, ..., -0.5776, -0.4621, -0.2088],\n", 1441 | " [ 0.1377, -0.7452, 0.2887, ..., -0.5330, 0.8383, 0.3879]],\n", 1442 | "\n", 1443 | " ...,\n", 1444 | "\n", 1445 | " [[-0.3708, 0.0447, -1.7704, ..., -0.1384, 1.2523, -0.9554],\n", 1446 | " [ 0.3907, -0.0706, -0.2819, ..., 0.8475, 2.6240, 0.0236],\n", 1447 | " [-0.3695, -1.1259, -0.8637, ..., -0.2341, 0.8855, -0.2281],\n", 1448 | " ...,\n", 1449 | " [ 0.2710, -0.2403, -0.8422, ..., 1.1316, -0.2313, -0.1778],\n", 1450 | " [ 0.0365, -1.1763, -0.3787, ..., -0.0527, 1.3224, 0.2266],\n", 1451 | " [-0.2912, -0.7162, -2.0500, ..., -0.8947, -0.7955, 0.3733]],\n", 1452 | "\n", 1453 | " [[ 0.0516, 0.2699, -0.7761, ..., 1.0632, 0.8367, -1.5930],\n", 1454 | " [ 0.5593, 0.2786, 1.7272, ..., -0.4225, 1.6544, -1.6445],\n", 1455 | " [-0.0693, 0.3745, 1.9922, ..., -0.9634, 0.9971, -1.1735],\n", 1456 | " ...,\n", 1457 | " [ 0.6697, 0.9678, 0.3182, ..., -0.4548, 0.3080, 0.4878],\n", 1458 | " [-0.4492, 1.6786, -0.0268, ..., 0.2867, 1.1661, -0.7598],\n", 1459 | " [ 0.7795, 1.6853, 1.8191, ..., -0.3550, 1.1653, 0.3651]],\n", 1460 | "\n", 1461 | " [[ 1.1410, 0.3139, 0.7882, ..., 0.1924, -0.8318, -0.0871],\n", 1462 | " [-1.5768, -0.5493, 0.5358, ..., -1.3087, -0.0706, 1.0150],\n", 1463 | " [ 0.4230, -0.9425, 0.2153, ..., 1.3946, -0.3100, 0.1395],\n", 1464 | " ...,\n", 1465 | " [-1.5951, -1.1437, 0.8828, ..., 0.0509, -0.7027, -0.4917],\n", 1466 | " [-0.4828, -0.5360, -0.3505, ..., -0.5672, -1.9420, -0.5707],\n", 1467 | " [-0.8982, 0.2975, 0.0532, ..., 0.8858, -0.9248, -0.5093]]]],\n", 1468 | " grad_fn=), tensor([[[[ 2.1604e-01, -3.2716e-01, -6.0016e-02, ..., -3.8815e-01,\n", 1469 | " 2.1307e+00, 9.5119e-01],\n", 1470 | " [-5.6246e-01, 1.4467e-01, -8.7067e-01, ..., 9.9591e-02,\n", 1471 | " 3.2224e+00, -2.2483e-01],\n", 1472 | " [-1.5745e+00, 1.2240e-01, -5.4912e-01, ..., 1.0691e+00,\n", 1473 | " 1.8557e+00, 9.8569e-01],\n", 1474 | " ...,\n", 1475 | " [-1.1203e+00, 2.3769e-01, -2.2947e-01, ..., 1.8408e-02,\n", 1476 | " 1.5588e+00, -3.9610e-01],\n", 1477 | " [-9.2727e-01, -8.5821e-01, -1.3439e-01, ..., -1.4810e-01,\n", 1478 | " 1.8561e+00, -1.3696e-01],\n", 1479 | " [ 6.1069e-01, -1.4279e+00, 9.6974e-01, ..., 7.7596e-01,\n", 1480 | " 1.9152e+00, 1.8959e-01]],\n", 1481 | "\n", 1482 | " [[-9.8456e-01, 1.7759e+00, -4.3084e-01, ..., -2.1178e+00,\n", 1483 | " 4.4537e-01, -5.9616e-01],\n", 1484 | " [ 4.0403e-02, 9.6833e-01, 1.0305e+00, ..., -1.1369e+00,\n", 1485 | " 8.3554e-01, 1.0103e+00],\n", 1486 | " [ 4.9437e-01, 2.2574e+00, -2.0949e-01, 
..., -6.8047e-01,\n", 1487 | " -7.2000e-01, 2.0187e-01],\n", 1488 | " ...,\n", 1489 | " [-1.0229e-01, 1.1636e+00, 6.4511e-01, ..., 2.4976e-01,\n", 1490 | " 3.2884e-01, 9.1565e-01],\n", 1491 | " [-1.7407e-01, 2.5496e+00, -1.3417e+00, ..., -8.8954e-01,\n", 1492 | " -1.0207e-01, 1.5707e+00],\n", 1493 | " [ 8.2192e-01, 1.0633e+00, -1.2077e+00, ..., -2.9693e-01,\n", 1494 | " -4.7066e-02, 1.5686e+00]],\n", 1495 | "\n", 1496 | " [[-2.0310e+00, -2.1590e+00, 1.2248e-01, ..., -5.0193e-01,\n", 1497 | " -3.5183e-01, -6.8041e-01],\n", 1498 | " [-2.3571e+00, -8.9402e-01, 1.2303e+00, ..., -1.0642e+00,\n", 1499 | " -1.2017e+00, -4.7249e-01],\n", 1500 | " [-1.3213e+00, 5.9326e-01, 4.7095e-01, ..., -2.1528e+00,\n", 1501 | " -1.5252e+00, 2.8977e-01],\n", 1502 | " ...,\n", 1503 | " [-2.5647e+00, 7.1942e-01, -1.2089e-01, ..., -6.6379e-01,\n", 1504 | " -4.5275e-01, 7.1632e-01],\n", 1505 | " [-1.8439e+00, 1.5855e+00, 5.4205e-01, ..., -5.5856e-01,\n", 1506 | " -7.2752e-01, 1.0312e-01],\n", 1507 | " [-2.8934e+00, 4.3890e-01, 6.5109e-01, ..., -4.1851e-02,\n", 1508 | " -3.0042e-01, 6.1471e-02]],\n", 1509 | "\n", 1510 | " ...,\n", 1511 | "\n", 1512 | " [[ 3.7702e-01, -1.0278e+00, -1.2773e+00, ..., 1.7233e-02,\n", 1513 | " 8.5850e-01, -9.9762e-01],\n", 1514 | " [ 2.1208e-01, 1.2807e-01, 3.6567e-01, ..., 5.0373e-01,\n", 1515 | " 8.9837e-01, -9.6280e-01],\n", 1516 | " [ 7.7939e-01, -1.5258e+00, 3.0297e-01, ..., 8.0966e-01,\n", 1517 | " -3.4588e-01, -1.9134e-01],\n", 1518 | " ...,\n", 1519 | " [ 9.0818e-01, -1.2282e+00, -1.4150e+00, ..., 5.0955e-01,\n", 1520 | " -9.4857e-01, 3.0390e-01],\n", 1521 | " [ 9.9823e-01, -1.1528e-01, 1.1589e+00, ..., 3.6995e-01,\n", 1522 | " 1.6192e-01, 9.6493e-01],\n", 1523 | " [-3.5578e-01, -1.2186e+00, -4.7192e-01, ..., -1.2560e+00,\n", 1524 | " 6.4249e-01, -2.2464e-01]],\n", 1525 | "\n", 1526 | " [[-3.5366e-01, 5.6603e-01, -6.3872e-01, ..., 3.0472e-01,\n", 1527 | " 1.2589e-01, -7.9563e-01],\n", 1528 | " [ 4.1194e-01, 2.8440e+00, -7.9065e-01, ..., 5.0764e-01,\n", 1529 | " -7.4191e-01, 4.3073e-01],\n", 1530 | " [-2.0034e-01, 2.0893e+00, 4.4662e-01, ..., -4.5047e-01,\n", 1531 | " -5.6854e-03, 3.4939e-01],\n", 1532 | " ...,\n", 1533 | " [ 2.1475e+00, -1.8466e-01, 1.3608e+00, ..., -1.1318e-01,\n", 1534 | " -1.4896e+00, -8.0938e-01],\n", 1535 | " [-1.2008e+00, 7.2882e-01, 7.7512e-04, ..., -4.6787e-02,\n", 1536 | " -1.5426e+00, 8.7835e-02],\n", 1537 | " [ 1.2643e+00, 1.1919e+00, -1.0268e+00, ..., 1.2820e+00,\n", 1538 | " -6.1119e-01, -7.5837e-02]],\n", 1539 | "\n", 1540 | " [[ 5.5421e-01, 2.3666e+00, 1.7582e-01, ..., 1.1197e-01,\n", 1541 | " -2.0041e+00, 1.2128e-01],\n", 1542 | " [-7.2182e-01, 1.2761e+00, -3.5762e-03, ..., -2.8744e-01,\n", 1543 | " -5.3568e-01, 7.1353e-01],\n", 1544 | " [-7.6959e-01, 1.6598e+00, 1.9997e-01, ..., -1.2437e+00,\n", 1545 | " -9.6927e-01, 2.0338e+00],\n", 1546 | " ...,\n", 1547 | " [-9.8093e-01, 2.7564e+00, 6.6994e-02, ..., -9.9922e-01,\n", 1548 | " -4.7275e-01, -5.4256e-01],\n", 1549 | " [-1.4753e-01, 3.8425e-01, -6.1143e-01, ..., -2.2525e+00,\n", 1550 | " 5.7475e-01, 1.0959e+00],\n", 1551 | " [ 1.0442e-01, 1.1969e+00, -5.7404e-02, ..., -2.0286e+00,\n", 1552 | " 2.1710e-01, 1.3723e-02]]]], grad_fn=), tensor([[[[-0.8190, 0.5142, -1.4268, ..., 0.3575, -0.0051, 1.3459],\n", 1553 | " [-1.1209, 0.2459, -1.2252, ..., 0.5464, 0.7709, 2.7390],\n", 1554 | " [-0.4105, 0.2990, -1.9476, ..., -0.4617, 0.7841, 3.2897],\n", 1555 | " ...,\n", 1556 | " [ 0.5893, -1.2044, -2.7092, ..., -0.3797, -1.2444, 1.4551],\n", 1557 | " [ 1.6152, -1.2261, -0.7019, ..., 0.3614, -0.7336, 
0.3843],\n", 1558 | " [ 0.8203, -0.5307, -2.3602, ..., 1.6997, -2.4419, -0.6140]],\n", 1559 | "\n", 1560 | " [[-1.7908, 0.4819, -0.8003, ..., 0.7554, 0.4614, -0.2918],\n", 1561 | " [-0.7962, 0.2361, -0.4922, ..., -0.0065, -0.2105, 1.4508],\n", 1562 | " [-1.2806, -0.5569, 0.1745, ..., 1.0426, 1.6348, -0.6355],\n", 1563 | " ...,\n", 1564 | " [-1.2447, 0.0556, -0.8745, ..., 0.0491, 0.2776, -0.8309],\n", 1565 | " [-0.8902, 0.4068, -1.6391, ..., -1.2868, 0.4114, 0.4718],\n", 1566 | " [ 0.3315, -1.2472, 0.8930, ..., 0.4795, -1.2464, 1.0640]],\n", 1567 | "\n", 1568 | " [[ 1.5506, -0.0655, -0.4002, ..., -0.3001, 0.5147, 0.1180],\n", 1569 | " [ 0.6105, 0.8292, -1.4999, ..., 0.2363, 0.5737, 1.1308],\n", 1570 | " [ 0.6017, -0.9027, -1.1870, ..., -0.0386, 0.4509, 2.1129],\n", 1571 | " ...,\n", 1572 | " [ 0.0670, 0.9083, -0.2928, ..., 0.8511, 0.0886, 1.7198],\n", 1573 | " [-0.7883, 0.3971, -1.0123, ..., 0.3074, -0.4264, 1.1632],\n", 1574 | " [ 1.6792, 1.2156, 0.8388, ..., -1.4789, 2.1182, 2.1061]],\n", 1575 | "\n", 1576 | " ...,\n", 1577 | "\n", 1578 | " [[-1.8800, -0.8865, 1.1234, ..., -1.4806, 0.8737, 0.7305],\n", 1579 | " [-0.2667, -1.1813, 0.7270, ..., -0.9530, 0.7119, -0.5268],\n", 1580 | " [-1.0485, -1.2319, -1.3041, ..., -2.5358, -0.8150, -0.9273],\n", 1581 | " ...,\n", 1582 | " [-0.1511, 1.3015, 0.9410, ..., -0.8009, 0.2178, -0.6611],\n", 1583 | " [ 1.5630, -1.0691, 0.8386, ..., -0.7984, 1.2859, -0.5159],\n", 1584 | " [-1.5367, 1.1079, 0.6445, ..., 2.2334, 0.6768, 0.9514]],\n", 1585 | "\n", 1586 | " [[ 1.3583, -0.1197, 0.5999, ..., -1.3974, -0.7783, 1.0389],\n", 1587 | " [-0.0692, -0.9591, -0.6754, ..., -0.2088, 0.1375, 0.4700],\n", 1588 | " [-0.1723, 0.7199, 0.1524, ..., 0.0225, -1.5446, 0.5900],\n", 1589 | " ...,\n", 1590 | " [-0.3791, 0.5778, 0.3094, ..., 0.8343, -0.5934, 0.4007],\n", 1591 | " [ 0.0423, -0.2357, 0.2266, ..., 0.9670, -1.0640, -0.9361],\n", 1592 | " [-1.7368, -0.3133, 0.8204, ..., -0.1409, -1.2703, -0.8337]],\n", 1593 | "\n", 1594 | " [[-1.3691, 0.0807, -0.8878, ..., 0.9156, -0.6222, -1.1881],\n", 1595 | " [ 0.1363, -1.8240, -0.3501, ..., -0.2703, 0.0168, -2.1925],\n", 1596 | " [-0.6527, -0.3778, -1.0638, ..., -0.0099, 1.2982, -1.5682],\n", 1597 | " ...,\n", 1598 | " [ 1.6455, -0.6602, -1.2317, ..., 1.0569, -0.2511, -1.9743],\n", 1599 | " [ 1.5495, -0.9669, -1.3757, ..., 1.5837, -0.3012, -2.1403],\n", 1600 | " [ 1.3665, 0.0142, -0.4172, ..., 1.5167, 0.0500, 1.7271]]]],\n", 1601 | " grad_fn=), tensor([[[[-1.2461, -0.1124, 0.1706, ..., 0.5002, -1.4231, -0.4108],\n", 1602 | " [-1.0301, -0.1877, 1.5421, ..., 1.2170, -1.0959, -1.1301],\n", 1603 | " [ 0.8284, 0.8393, 0.3205, ..., -0.5026, -3.0008, -1.2326],\n", 1604 | " ...,\n", 1605 | " [-0.7213, -0.2704, 1.3067, ..., -0.7311, -1.3085, -0.4617],\n", 1606 | " [-0.2169, 1.2115, 1.3550, ..., -0.8301, 0.5378, -1.0403],\n", 1607 | " [ 0.3591, -0.0347, 1.0585, ..., -0.2107, 0.2073, -0.4698]],\n", 1608 | "\n", 1609 | " [[ 0.7583, -0.9865, 1.1101, ..., 0.1311, 0.0418, -0.4637],\n", 1610 | " [ 1.6207, -1.3941, 1.0369, ..., 0.5490, 1.5714, -0.4359],\n", 1611 | " [ 0.7278, -0.1013, 1.5101, ..., 0.5439, 0.1141, 0.2020],\n", 1612 | " ...,\n", 1613 | " [ 0.2584, -0.3256, -1.4017, ..., -1.0581, 1.1026, 0.9981],\n", 1614 | " [-0.5833, 1.7367, 0.5675, ..., -0.3452, 0.1426, 0.4089],\n", 1615 | " [-0.0910, 0.6142, 0.4169, ..., 0.0888, -0.2498, -2.6906]],\n", 1616 | "\n", 1617 | " [[-1.5102, 1.0117, -0.6059, ..., 1.1354, 0.7957, 0.0131],\n", 1618 | " [-0.6197, 0.6972, -1.0852, ..., 1.1895, 1.1030, -0.2739],\n", 1619 | " [-0.9637, 0.5988, -1.1823, ..., 
1.7660, -0.4396, -0.2999],\n", 1620 | " ...,\n", 1621 | " [ 0.5646, 0.2658, -0.8236, ..., 0.8394, -0.0283, -1.1298],\n", 1622 | " [ 1.2788, 0.0284, -0.2393, ..., 0.8382, 1.6024, -0.7725],\n", 1623 | " [-1.1333, 1.7624, -1.1471, ..., 0.9239, 1.8088, -1.4005]],\n", 1624 | "\n", 1625 | " ...,\n", 1626 | "\n", 1627 | " [[-0.7944, -1.0263, 0.7782, ..., -1.5292, -0.6570, 1.1802],\n", 1628 | " [-0.2381, -0.0353, -0.0869, ..., -1.1211, -1.4598, 1.3562],\n", 1629 | " [-0.6267, -2.5943, 1.0319, ..., 0.3867, -1.1524, 1.4396],\n", 1630 | " ...,\n", 1631 | " [ 0.5348, -0.7972, 1.3008, ..., 1.0386, 0.1470, 1.2778],\n", 1632 | " [-0.2981, 0.3088, -0.1870, ..., -0.2927, -0.1638, 0.2573],\n", 1633 | " [ 1.1004, 0.9219, 0.3092, ..., -0.0740, -0.8648, -0.7523]],\n", 1634 | "\n", 1635 | " [[-0.6090, -0.0152, -0.6405, ..., 0.2399, -0.4228, 0.9649],\n", 1636 | " [ 0.1717, -0.3525, 0.3055, ..., -1.4199, -0.3662, 1.1744],\n", 1637 | " [ 1.5141, -0.6191, -1.1831, ..., 0.0293, -1.3887, 0.5819],\n", 1638 | " ...,\n", 1639 | " [ 0.7900, 0.2947, -0.0464, ..., 0.8441, -1.8202, 1.8005],\n", 1640 | " [ 1.3701, 0.1209, 0.4358, ..., 0.0942, -2.0341, 2.3638],\n", 1641 | " [-0.3664, 1.2977, -0.2643, ..., -0.7570, -1.6616, 0.8313]],\n", 1642 | "\n", 1643 | " [[-1.2750, 1.3808, 0.3599, ..., 0.4753, 2.0567, 0.2973],\n", 1644 | " [-1.1786, 1.2513, -0.5062, ..., 0.8810, -0.0739, 2.7183],\n", 1645 | " [-1.3242, -0.5974, -1.0374, ..., -0.1789, 1.8800, 1.5143],\n", 1646 | " ...,\n", 1647 | " [-0.2064, 1.9762, -0.2209, ..., 0.8421, 0.2804, 1.4169],\n", 1648 | " [-0.5639, 0.3187, -1.0188, ..., 0.6459, 0.2337, 1.4745],\n", 1649 | " [ 0.6266, -0.0963, -0.5175, ..., -1.3442, 1.1582, 1.2530]]]],\n", 1650 | " grad_fn=)), (tensor([[[[ 1.0653, -1.4460, -1.6103, ..., -1.5770, 0.8314, 1.8001],\n", 1651 | " [ 1.1820, -0.4309, -1.5622, ..., -0.5059, 0.2069, 1.3095],\n", 1652 | " [ 0.9034, 0.6384, 0.2626, ..., -2.3903, 1.5276, 1.3996],\n", 1653 | " ...,\n", 1654 | " [ 0.9142, -0.2006, -1.7885, ..., -1.3486, 2.0745, 1.5819],\n", 1655 | " [ 2.1092, 0.4930, -1.1592, ..., -1.3361, 1.0135, 1.8803],\n", 1656 | " [ 1.7846, -0.7889, -0.6335, ..., -1.6189, 0.3028, 2.0726]],\n", 1657 | "\n", 1658 | " [[-0.9327, 0.5232, -0.3953, ..., 0.8283, -0.7681, 1.4122],\n", 1659 | " [-1.2663, 1.2217, -0.4597, ..., -0.8704, -1.3138, 2.2978],\n", 1660 | " [-0.3304, 0.4148, -0.1473, ..., -0.1019, -1.2333, 0.6905],\n", 1661 | " ...,\n", 1662 | " [-0.4582, 0.3173, -0.6386, ..., -0.7278, 0.3289, 1.9039],\n", 1663 | " [-0.3959, -0.7087, -0.2956, ..., -0.2347, -0.4133, 0.0551],\n", 1664 | " [-0.3392, -1.5320, -0.1130, ..., -0.4031, -1.8364, 1.8707]],\n", 1665 | "\n", 1666 | " [[-0.8630, -0.7352, -1.0010, ..., -1.3779, 1.3046, -1.6086],\n", 1667 | " [-1.8078, -0.0864, -1.4167, ..., -0.4013, 0.1239, -1.2503],\n", 1668 | " [-2.2500, -0.3860, -1.8416, ..., -1.6914, 0.5219, -0.1487],\n", 1669 | " ...,\n", 1670 | " [-0.6478, -0.2467, -1.8465, ..., -0.4889, -0.3073, -0.8373],\n", 1671 | " [-0.0635, 0.5079, -2.1032, ..., -0.1159, -0.2508, -0.1902],\n", 1672 | " [-0.4074, -1.1365, -0.3983, ..., -2.0198, 0.2234, 0.2942]],\n", 1673 | "\n", 1674 | " ...,\n", 1675 | "\n", 1676 | " [[ 0.0859, 0.5364, 0.2800, ..., -0.2320, -1.1693, -0.5954],\n", 1677 | " [ 0.4458, 0.9618, -0.6096, ..., -1.0331, 0.0398, 0.4012],\n", 1678 | " [ 1.3557, 0.8665, 0.5443, ..., -0.7382, -1.3548, -0.9165],\n", 1679 | " ...,\n", 1680 | " [ 1.4020, 0.0607, -0.0089, ..., -0.7473, 0.2765, 0.5297],\n", 1681 | " [-0.2126, -0.1436, -0.0338, ..., -0.4384, 0.3610, -0.6730],\n", 1682 | " [ 0.2582, -1.0557, -0.0875, ..., 
-0.9816, 0.2733, -0.7688]],\n", 1683 | "\n", 1684 | " [[-1.7211, -2.3619, 1.2535, ..., -2.6890, -1.3875, -0.1377],\n", 1685 | " [-1.6555, -3.9286, -0.0142, ..., -0.8403, 0.3068, -0.5605],\n", 1686 | " [-2.5341, -2.1949, 0.8729, ..., -0.5987, -0.7153, -1.7104],\n", 1687 | " ...,\n", 1688 | " [-2.3736, -1.9334, 0.9578, ..., 0.3520, -1.2120, -0.8045],\n", 1689 | " [-0.7035, -2.1713, -0.3638, ..., -0.7670, 0.2852, -0.2776],\n", 1690 | " [-1.0691, -2.9295, 0.5920, ..., -1.0829, -0.2905, -0.7756]],\n", 1691 | "\n", 1692 | " [[-0.8136, -0.8202, -0.0355, ..., 0.9911, -0.0250, 0.7227],\n", 1693 | " [-0.4566, 0.4636, -0.3604, ..., -0.1232, 1.0978, 0.9319],\n", 1694 | " [-0.0643, -1.2891, -0.1833, ..., 0.9718, 0.7277, 0.9564],\n", 1695 | " ...,\n", 1696 | " [-0.7778, -0.0555, -1.2350, ..., 1.5345, 0.1168, -0.2632],\n", 1697 | " [-0.9038, -0.3976, -2.0099, ..., 0.7643, -0.2821, 0.3548],\n", 1698 | " [-1.1800, -0.7831, 0.4600, ..., 0.1622, -0.7997, -0.8573]]]],\n", 1699 | " grad_fn=), tensor([[[[ 0.1050, -1.9242, 0.1407, ..., -0.1992, -1.2297, -0.7038],\n", 1700 | " [-0.8176, -2.2313, -0.0209, ..., -0.6400, 0.6857, -0.9931],\n", 1701 | " [ 1.2991, -0.5896, -1.5006, ..., -1.1200, 0.5632, -2.2305],\n", 1702 | " ...,\n", 1703 | " [ 0.3802, -1.9609, 0.3652, ..., -1.7267, 0.1589, -0.4516],\n", 1704 | " [-0.6999, -0.6149, -0.3793, ..., -0.8012, -0.6511, -1.0071],\n", 1705 | " [-0.4425, -0.2534, 0.1809, ..., -1.5360, -0.6365, -0.3229]],\n", 1706 | "\n", 1707 | " [[ 1.2169, 1.3213, -0.1206, ..., -0.6916, -0.9117, -1.8586],\n", 1708 | " [-0.6221, 1.1419, 1.0351, ..., -1.0688, 0.4392, -1.7376],\n", 1709 | " [ 0.3915, 0.6196, -0.5389, ..., -0.1850, -1.2881, 0.0610],\n", 1710 | " ...,\n", 1711 | " [ 0.1092, 0.1168, -0.5731, ..., 0.3604, -1.1335, -0.0610],\n", 1712 | " [ 0.0986, -0.6574, -1.0962, ..., -0.3127, -0.0960, -0.7501],\n", 1713 | " [-0.9653, -0.9011, -0.5328, ..., 0.3781, -0.3899, -0.8949]],\n", 1714 | "\n", 1715 | " [[-0.5427, -0.7087, -1.0783, ..., -0.8471, 1.3731, 0.3029],\n", 1716 | " [ 0.7566, 0.6164, -1.4856, ..., 0.2284, 0.3789, 1.3227],\n", 1717 | " [ 0.5542, 0.2869, -0.3346, ..., 0.2194, 0.5483, 0.8676],\n", 1718 | " ...,\n", 1719 | " [ 0.5289, -0.8943, -1.2453, ..., -0.7373, 1.2142, -1.2769],\n", 1720 | " [-0.5359, 0.4493, -0.3293, ..., -0.5608, 2.1688, -0.1618],\n", 1721 | " [-0.4045, 0.9089, -0.5781, ..., -0.3523, 1.1257, -0.8119]],\n", 1722 | "\n", 1723 | " ...,\n", 1724 | "\n", 1725 | " [[-0.3212, 1.4467, 2.3679, ..., 2.5355, -1.6730, -1.4317],\n", 1726 | " [-0.0684, 0.9853, 0.6572, ..., 0.8578, -0.7068, -1.6395],\n", 1727 | " [-0.1489, 1.4896, 0.9486, ..., 0.8751, -1.0524, -1.3877],\n", 1728 | " ...,\n", 1729 | " [ 0.4564, 1.7123, 0.7416, ..., 1.0830, -0.2973, -1.2881],\n", 1730 | " [ 0.1984, 2.3722, 0.1560, ..., 0.4196, -0.8805, -0.5682],\n", 1731 | " [ 0.0104, 0.7799, 1.4276, ..., 1.0256, -1.2522, -0.2617]],\n", 1732 | "\n", 1733 | " [[-1.2435, 0.2645, 0.5274, ..., 1.8309, -0.1966, -0.8077],\n", 1734 | " [-0.3734, 1.2999, -0.0471, ..., 1.3048, -0.1147, -1.6087],\n", 1735 | " [-0.3049, 0.9447, -0.3186, ..., 2.6523, -0.2159, -0.1191],\n", 1736 | " ...,\n", 1737 | " [ 0.0324, -0.5511, -0.1347, ..., 2.3630, -0.6616, -1.6717],\n", 1738 | " [ 0.5933, -0.2154, -0.2313, ..., 0.3746, -0.4208, -0.3312],\n", 1739 | " [-0.5182, -0.1256, -0.1940, ..., 1.5194, -0.3323, -0.6570]],\n", 1740 | "\n", 1741 | " [[-0.5339, 1.7173, -0.5979, ..., 0.3987, 1.1125, -0.1459],\n", 1742 | " [-0.1437, 1.6214, 0.6224, ..., 0.9200, 1.0099, 0.6049],\n", 1743 | " [-0.3922, 0.9844, 1.0659, ..., 1.0976, -0.8634, 
0.3848],\n", 1744 | " ...,\n", 1745 | " [-0.1447, 0.5206, 1.3969, ..., 1.0450, 1.7714, -0.4603],\n", 1746 | " [-0.3109, 2.5772, 0.3873, ..., 0.0242, 1.3412, -1.8854],\n", 1747 | " [-0.5794, 1.4301, -0.3009, ..., -0.4938, -0.2519, -0.8620]]]],\n", 1748 | " grad_fn=), tensor([[[[ 3.5884e-01, 1.5489e+00, 4.6702e-01, ..., 2.3582e-01,\n", 1749 | " -1.0350e+00, 2.5012e+00],\n", 1750 | " [-2.2513e-01, 1.7231e+00, 1.0232e-01, ..., 3.4621e-01,\n", 1751 | " -1.6062e+00, 1.8949e+00],\n", 1752 | " [-5.0264e-01, 2.1998e+00, 1.5580e+00, ..., 9.2758e-01,\n", 1753 | " 6.1412e-01, 2.0489e+00],\n", 1754 | " ...,\n", 1755 | " [ 2.5172e+00, 1.0693e-01, 8.1028e-01, ..., 3.5955e-02,\n", 1756 | " -6.9835e-01, 8.9405e-01],\n", 1757 | " [ 2.5367e+00, 1.4478e-01, -3.4576e-02, ..., 6.4797e-01,\n", 1758 | " 1.1206e-01, 9.8735e-01],\n", 1759 | " [ 3.8430e-01, 6.1497e-02, -1.0919e+00, ..., -6.6796e-01,\n", 1760 | " -7.9716e-01, -1.5557e-01]],\n", 1761 | "\n", 1762 | " [[ 5.1218e-01, 1.0114e+00, -9.9042e-02, ..., -1.3558e-02,\n", 1763 | " 6.0653e-02, -4.8712e-01],\n", 1764 | " [ 1.4386e-01, 1.3943e-01, 1.0370e+00, ..., 4.9184e-01,\n", 1765 | " -2.2624e+00, -1.1498e+00],\n", 1766 | " [-1.9003e+00, 7.8038e-01, 4.4414e-01, ..., -3.2482e-01,\n", 1767 | " -3.2954e-01, -1.4257e+00],\n", 1768 | " ...,\n", 1769 | " [ 1.8552e-02, -1.1063e+00, 3.0218e-01, ..., 2.5299e-01,\n", 1770 | " -4.1471e-01, -1.3068e+00],\n", 1771 | " [-5.7917e-01, -1.4744e+00, 1.3458e+00, ..., 1.3185e-02,\n", 1772 | " -1.2953e+00, -1.9154e-01],\n", 1773 | " [ 1.1593e-03, 1.1048e+00, -1.3300e+00, ..., -8.0349e-01,\n", 1774 | " 9.8072e-01, -9.7399e-01]],\n", 1775 | "\n", 1776 | " [[ 5.8680e-02, 1.6701e+00, -1.7631e+00, ..., -1.4910e+00,\n", 1777 | " 6.3294e-01, 5.9995e-01],\n", 1778 | " [ 8.1318e-01, 1.5598e+00, -5.1998e-01, ..., -1.2019e+00,\n", 1779 | " -3.2634e-01, 6.4214e-01],\n", 1780 | " [ 9.6681e-01, 2.0173e+00, -1.4964e+00, ..., -1.3121e+00,\n", 1781 | " 6.4348e-01, 7.1184e-01],\n", 1782 | " ...,\n", 1783 | " [ 7.0589e-01, 6.7730e-01, -1.3638e+00, ..., -2.2973e+00,\n", 1784 | " 3.8681e-01, 2.0984e+00],\n", 1785 | " [-8.4985e-01, 9.6681e-01, -6.8335e-01, ..., -1.6955e+00,\n", 1786 | " 1.8612e-01, 1.8483e+00],\n", 1787 | " [-1.3403e-01, 8.9765e-01, -5.5302e-01, ..., -6.3939e-01,\n", 1788 | " -3.0584e-01, 9.9148e-01]],\n", 1789 | "\n", 1790 | " ...,\n", 1791 | "\n", 1792 | " [[ 4.8541e-02, -1.3876e+00, 1.1273e+00, ..., 2.7735e-01,\n", 1793 | " 7.7170e-01, 1.1573e+00],\n", 1794 | " [-1.5975e+00, -9.7451e-02, 1.9423e+00, ..., -6.4860e-01,\n", 1795 | " 1.0416e+00, 1.0531e+00],\n", 1796 | " [-7.2138e-01, -8.8940e-02, 3.9397e-01, ..., -1.0589e+00,\n", 1797 | " 1.3877e+00, 4.5738e-01],\n", 1798 | " ...,\n", 1799 | " [-4.0014e-01, 4.1524e-02, -4.7060e-01, ..., 7.7953e-01,\n", 1800 | " 7.4016e-01, 2.9395e-01],\n", 1801 | " [-1.2465e-01, 7.3254e-01, 8.9152e-01, ..., 3.9121e-01,\n", 1802 | " -2.4828e-01, 1.0548e+00],\n", 1803 | " [ 2.5064e-02, -3.5634e-01, 2.1089e+00, ..., 3.0338e-01,\n", 1804 | " -1.7687e+00, 8.3406e-01]],\n", 1805 | "\n", 1806 | " [[ 7.0031e-01, 5.4852e-01, -2.1815e+00, ..., -1.9994e+00,\n", 1807 | " -6.8685e-01, 5.9131e-01],\n", 1808 | " [ 4.4750e-01, -1.2232e-01, -1.9971e+00, ..., -1.1358e+00,\n", 1809 | " -6.5962e-01, 1.5463e-01],\n", 1810 | " [ 3.6172e-01, 2.3036e-01, -2.3862e-01, ..., -5.9104e-01,\n", 1811 | " -6.8759e-02, 2.1094e+00],\n", 1812 | " ...,\n", 1813 | " [ 1.0439e+00, 8.3086e-01, -1.4174e+00, ..., -2.7300e-01,\n", 1814 | " 1.2749e+00, 6.4334e-01],\n", 1815 | " [-1.5498e-01, 6.4605e-01, -3.0370e+00, ..., 5.5453e-01,\n", 1816 | " 
3.5098e-01, 1.0418e+00],\n", 1817 | " [ 8.4133e-01, 6.3494e-01, -1.5256e+00, ..., 5.6155e-01,\n", 1818 | " 2.2164e-01, 1.3303e+00]],\n", 1819 | "\n", 1820 | " [[ 3.5086e-01, -2.6934e+00, 1.8900e+00, ..., 2.7706e-01,\n", 1821 | " -1.7713e+00, -5.4896e-01],\n", 1822 | " [-4.9080e-01, -1.3102e+00, 1.1192e+00, ..., 8.1749e-01,\n", 1823 | " -2.6997e-01, -7.8137e-02],\n", 1824 | " [ 3.9024e-01, -1.6240e+00, 5.4428e-02, ..., 1.0305e+00,\n", 1825 | " 1.1196e+00, 8.5720e-01],\n", 1826 | " ...,\n", 1827 | " [ 1.5235e+00, -4.2211e-01, 1.7898e+00, ..., -1.7621e+00,\n", 1828 | " -7.1187e-01, -7.1535e-01],\n", 1829 | " [ 2.3102e+00, 2.5573e-01, 7.9558e-01, ..., -4.3334e-01,\n", 1830 | " -7.5122e-01, -9.9735e-01],\n", 1831 | " [ 9.2497e-01, 1.1978e-02, 8.4687e-01, ..., 6.4733e-01,\n", 1832 | " 8.3767e-01, -4.7028e-01]]]], grad_fn=), tensor([[[[-0.4291, 0.2797, 0.2939, ..., -0.0246, -1.8422, 0.6730],\n", 1833 | " [-0.7417, -0.1985, -0.7843, ..., -0.5578, -2.1673, 0.3723],\n", 1834 | " [-0.2529, 1.0896, -0.0478, ..., 0.1942, -1.8731, -0.6244],\n", 1835 | " ...,\n", 1836 | " [-1.5328, 0.5242, 0.6020, ..., 0.7233, -1.0762, 0.2737],\n", 1837 | " [-1.3199, -0.1735, -0.0152, ..., -0.6772, -0.2726, 0.2997],\n", 1838 | " [-0.4179, 0.8824, -0.8403, ..., -0.9326, -1.8911, 0.5420]],\n", 1839 | "\n", 1840 | " [[ 0.1178, 0.3802, 1.0362, ..., 2.0339, 1.7753, 0.3305],\n", 1841 | " [-0.9190, 0.8598, -0.3032, ..., 2.6564, -0.1650, 0.2351],\n", 1842 | " [ 0.2287, 0.2513, 0.7019, ..., 1.9874, 1.1115, -0.9159],\n", 1843 | " ...,\n", 1844 | " [-2.0236, -0.9571, 0.7523, ..., 1.7443, 0.3429, -1.2633],\n", 1845 | " [-0.4839, -0.7939, 0.6668, ..., 1.2476, 0.9814, -0.8654],\n", 1846 | " [ 0.7459, -0.9426, 1.9026, ..., 1.4149, 0.6527, -0.7114]],\n", 1847 | "\n", 1848 | " [[ 0.2265, -0.0559, -1.6443, ..., 0.3455, -1.7621, 0.4488],\n", 1849 | " [ 1.5075, -0.4946, -0.6771, ..., 1.4164, -1.0055, -0.6555],\n", 1850 | " [ 0.7072, -0.0352, -0.1821, ..., 0.2864, -1.7405, 0.1624],\n", 1851 | " ...,\n", 1852 | " [ 0.7808, 0.4483, -0.1570, ..., -1.1798, -2.0781, -0.1590],\n", 1853 | " [ 0.0759, 2.0654, -0.1245, ..., -0.1025, -2.4559, -0.1774],\n", 1854 | " [ 0.1027, -0.9628, -0.8547, ..., 0.0381, -1.3827, 0.3422]],\n", 1855 | "\n", 1856 | " ...,\n", 1857 | "\n", 1858 | " [[ 1.4165, -0.7402, -0.6242, ..., -0.0344, 0.4237, -0.9502],\n", 1859 | " [ 0.5335, 0.5424, -0.4183, ..., -0.8887, 0.1928, -0.1583],\n", 1860 | " [-0.9727, 0.7735, -0.7910, ..., 0.8576, 0.5522, -0.9837],\n", 1861 | " ...,\n", 1862 | " [ 0.9750, -0.2545, -0.5713, ..., 1.8841, 1.7273, -0.7189],\n", 1863 | " [ 2.1350, -0.6275, -0.1482, ..., 0.4217, -0.9026, -0.4498],\n", 1864 | " [ 0.6508, -2.6836, 0.8756, ..., 1.9014, -1.3769, 0.4723]],\n", 1865 | "\n", 1866 | " [[ 0.2372, 0.2647, -1.6349, ..., 1.4684, -0.5107, -0.4961],\n", 1867 | " [-0.3152, 0.3920, -0.1082, ..., 1.2658, 0.1975, 0.7052],\n", 1868 | " [-0.2919, 0.3785, 0.1558, ..., 2.4242, -0.6418, 0.1311],\n", 1869 | " ...,\n", 1870 | " [ 1.8682, -0.1374, -0.0183, ..., 2.1553, 2.0670, 0.6348],\n", 1871 | " [ 0.9390, -0.4737, -0.2526, ..., 1.8633, 1.1035, 0.2287],\n", 1872 | " [ 0.6928, -0.8184, -1.9767, ..., 1.2096, 0.7671, 0.9799]],\n", 1873 | "\n", 1874 | " [[ 1.6203, -0.4663, -0.8119, ..., -0.1802, 0.5213, 0.1037],\n", 1875 | " [ 0.1444, -0.0541, 0.4377, ..., -0.8731, -0.1312, 0.3210],\n", 1876 | " [ 0.1108, -0.5713, 0.0554, ..., -1.3224, 0.2828, 0.1313],\n", 1877 | " ...,\n", 1878 | " [-0.1666, -1.0314, -0.0089, ..., -0.2270, -0.8355, 0.6647],\n", 1879 | " [-0.9226, -1.5954, -0.9909, ..., -0.0871, -1.4400, 0.5527],\n", 
1880 | " [-1.8308, -2.2926, -1.8992, ..., -1.0797, -1.3639, 0.2125]]]],\n", 1881 | " grad_fn=)), (tensor([[[[-1.4054, 0.0716, 0.2665, ..., -0.1182, -0.9591, -0.4294],\n", 1882 | " [-1.4375, 2.3722, 0.7419, ..., 0.2924, 0.9828, 0.5136],\n", 1883 | " [-1.9794, 1.9216, -0.6382, ..., 0.0448, 0.8565, 0.1160],\n", 1884 | " ...,\n", 1885 | " [-1.6712, 1.6570, -0.7641, ..., -1.4405, 0.8521, 0.8725],\n", 1886 | " [-1.8452, 0.5742, -1.8533, ..., -0.4928, 1.1915, 0.7424],\n", 1887 | " [-1.5364, 1.3765, -0.0041, ..., -1.7990, 0.9832, 0.2839]],\n", 1888 | "\n", 1889 | " [[ 1.8146, 1.2591, 1.3408, ..., 2.0454, 0.2764, -0.0131],\n", 1890 | " [ 0.2781, 0.7022, 0.1330, ..., 1.1280, -0.2750, -0.3423],\n", 1891 | " [ 0.6432, 1.3591, 0.9497, ..., 1.2362, -0.8214, -0.5320],\n", 1892 | " ...,\n", 1893 | " [ 1.2585, 0.8809, 0.1666, ..., 1.8216, -0.0751, -0.1766],\n", 1894 | " [ 0.4529, -0.0451, 0.0377, ..., 0.3237, -0.1416, -0.9294],\n", 1895 | " [ 0.6558, -0.5827, 0.9018, ..., -0.1717, 0.1311, -1.0932]],\n", 1896 | "\n", 1897 | " [[-0.4996, -0.6197, 0.4406, ..., -0.7244, 0.6313, -0.3556],\n", 1898 | " [-0.9703, 1.3462, -0.7029, ..., 0.3204, -0.8503, -0.4265],\n", 1899 | " [-0.5130, 1.2367, -0.1760, ..., -0.3051, -1.1487, -1.4351],\n", 1900 | " ...,\n", 1901 | " [ 0.0585, 0.6956, 0.2807, ..., 0.8546, -2.4390, -1.0533],\n", 1902 | " [-0.3201, 0.7599, 0.9090, ..., -0.1490, -1.4038, -0.4477],\n", 1903 | " [-1.3182, -0.0718, 1.7005, ..., -0.7101, -1.8313, -0.1907]],\n", 1904 | "\n", 1905 | " ...,\n", 1906 | "\n", 1907 | " [[-2.0253, -1.4171, 1.6486, ..., -1.5917, -1.0295, -0.0449],\n", 1908 | " [-2.1077, -1.8998, 0.8948, ..., -0.5774, -2.1879, 1.2028],\n", 1909 | " [-1.9595, -1.1914, 1.4414, ..., -1.0022, -0.5618, -0.4454],\n", 1910 | " ...,\n", 1911 | " [-0.5643, -0.9385, -0.5579, ..., 0.4425, -1.1162, -0.1750],\n", 1912 | " [-1.3826, 0.4407, -0.9555, ..., -1.4241, -1.4659, 0.5416],\n", 1913 | " [-0.4747, -0.5000, -0.1089, ..., -1.7240, -2.2166, -0.0581]],\n", 1914 | "\n", 1915 | " [[-0.6875, 0.4304, -0.1366, ..., -1.8034, -0.8562, -2.4171],\n", 1916 | " [ 1.5846, 0.0102, -0.2936, ..., -1.4009, -0.8968, -1.6206],\n", 1917 | " [ 0.1738, -0.4552, -1.2048, ..., -0.2263, -0.7861, -2.4550],\n", 1918 | " ...,\n", 1919 | " [ 0.4288, 0.4469, 0.0859, ..., -1.1072, -0.7467, -1.7147],\n", 1920 | " [ 1.3774, -0.0117, -0.0119, ..., 0.4211, 0.5387, -1.5320],\n", 1921 | " [ 0.7774, 0.3030, -2.2248, ..., 0.1310, -0.3114, -1.7080]],\n", 1922 | "\n", 1923 | " [[-0.6575, -0.0282, -0.8773, ..., 0.4285, -1.9353, 1.2520],\n", 1924 | " [-0.7403, -0.8275, 0.3297, ..., -0.0899, -1.8216, 0.3093],\n", 1925 | " [-0.9039, -0.5990, -0.1958, ..., -0.7971, -2.1345, 0.5924],\n", 1926 | " ...,\n", 1927 | " [ 0.5450, -0.9100, 0.7460, ..., 0.7405, -1.7860, 1.6998],\n", 1928 | " [ 0.1575, -0.8660, -0.0961, ..., -0.7950, 0.0662, 1.8594],\n", 1929 | " [ 0.7890, 0.2270, 0.4538, ..., 1.7889, 0.1651, 1.6824]]]],\n", 1930 | " grad_fn=), tensor([[[[ 1.3048, -0.1944, -0.6868, ..., 0.1061, 0.7535, -0.7652],\n", 1931 | " [ 0.3602, 1.2934, -0.6324, ..., 1.1210, 1.4997, -0.0821],\n", 1932 | " [ 0.8313, 1.5722, -1.0181, ..., 0.0622, 0.8999, 0.2330],\n", 1933 | " ...,\n", 1934 | " [ 1.6891, 2.4188, -0.4573, ..., -0.3524, 0.8685, -0.0108],\n", 1935 | " [ 1.2398, 2.0613, -1.7916, ..., -0.2155, 0.8270, 0.7364],\n", 1936 | " [ 0.4756, 1.9869, -1.7419, ..., -0.3314, -0.0552, -0.8455]],\n", 1937 | "\n", 1938 | " [[-2.3166, -0.4643, 0.1283, ..., 0.1000, -1.2126, -0.6845],\n", 1939 | " [-0.7724, -2.4609, -0.4131, ..., -0.0260, 0.1481, -1.4285],\n", 1940 | " [ 0.3064, 
-1.6924, -0.9233, ..., -1.4981, -0.9328, -0.6733],\n", 1941 | " ...,\n", 1942 | " [-1.2718, -0.8587, -0.1554, ..., -2.0287, -0.9810, 0.8588],\n", 1943 | " [-0.5451, -1.5843, 0.2066, ..., -1.5006, -1.4305, 0.5464],\n", 1944 | " [-0.5885, -1.6579, 0.8426, ..., -0.9352, -0.7217, -0.2090]],\n", 1945 | "\n", 1946 | " [[-0.7246, -1.2406, 0.7494, ..., 1.0922, -0.8172, -0.8970],\n", 1947 | " [ 0.4402, -1.1580, -0.5777, ..., 0.7488, -1.5574, -0.4369],\n", 1948 | " [ 0.1570, -1.6807, 0.6517, ..., -0.2079, 0.2623, -0.4745],\n", 1949 | " ...,\n", 1950 | " [ 0.1862, -2.4092, 0.7655, ..., 1.5929, -0.0506, 0.2315],\n", 1951 | " [ 1.4994, -1.3696, 0.2104, ..., 0.7831, 0.0617, 1.3932],\n", 1952 | " [ 0.5736, -2.5027, -0.9550, ..., 1.4433, -0.4241, 0.4531]],\n", 1953 | "\n", 1954 | " ...,\n", 1955 | "\n", 1956 | " [[-1.1456, 1.4274, -0.0971, ..., -0.0332, 0.7760, -0.5553],\n", 1957 | " [-1.4641, 1.2281, -2.4168, ..., 0.0912, 0.4071, 0.5747],\n", 1958 | " [-1.7980, 0.8296, -1.9186, ..., 0.6895, 1.2305, -0.4543],\n", 1959 | " ...,\n", 1960 | " [-1.8191, -0.1602, -1.6428, ..., 0.6077, 0.9583, 0.5012],\n", 1961 | " [-1.4964, -1.1509, -1.8888, ..., -0.6018, -1.0186, 0.3179],\n", 1962 | " [-1.2357, -0.7655, -0.7183, ..., 0.1021, -0.4915, 1.2447]],\n", 1963 | "\n", 1964 | " [[-0.5546, -0.5167, 0.3910, ..., 0.8552, 0.8968, -1.1778],\n", 1965 | " [-0.3948, -0.1500, 0.2442, ..., -0.5085, 1.0728, -1.1938],\n", 1966 | " [-1.1487, -0.2539, 0.9514, ..., 0.6955, 0.8614, -0.1414],\n", 1967 | " ...,\n", 1968 | " [-0.1534, 0.9755, 0.6401, ..., 0.5958, 1.5979, -0.6915],\n", 1969 | " [ 0.0292, 0.6798, 0.0423, ..., 0.5948, 0.6160, -1.0440],\n", 1970 | " [ 0.0530, 0.2272, 1.0485, ..., 0.2091, 1.1705, 0.1666]],\n", 1971 | "\n", 1972 | " [[ 0.1936, 1.9948, -0.1061, ..., -1.2181, -1.7451, 2.0267],\n", 1973 | " [ 0.0182, 0.6760, -1.7794, ..., 0.2023, 0.7849, 1.4166],\n", 1974 | " [-0.5239, 0.6800, -1.1097, ..., -0.1999, -1.3890, 3.0478],\n", 1975 | " ...,\n", 1976 | " [-1.1398, 1.1164, -0.8705, ..., -0.4026, -0.7764, 1.7004],\n", 1977 | " [-0.6619, 0.9795, 0.2409, ..., 0.7580, -0.9340, 2.4769],\n", 1978 | " [-0.3709, 1.9305, -2.3658, ..., 0.0259, -0.6072, 1.0998]]]],\n", 1979 | " grad_fn=), tensor([[[[-1.6982, 2.3730, 0.3344, ..., -0.2368, -1.9284, 0.3737],\n", 1980 | " [-1.6568, -0.6589, 0.7648, ..., -0.2621, -1.1332, 0.9346],\n", 1981 | " [-2.3218, 0.4707, -0.6835, ..., -0.3221, -0.8742, 0.8575],\n", 1982 | " ...,\n", 1983 | " [ 0.3609, 0.8015, -1.0954, ..., 1.1052, -0.9942, 1.1413],\n", 1984 | " [ 0.4414, 0.6029, 0.9387, ..., 1.2829, -1.0927, 0.7384],\n", 1985 | " [-1.4575, 0.0817, 0.5124, ..., 1.1680, -0.6402, -0.6056]],\n", 1986 | "\n", 1987 | " [[-1.1110, -0.4406, 1.3181, ..., -1.2183, -0.4206, -1.3432],\n", 1988 | " [ 0.0820, -0.0948, 1.3882, ..., -1.1034, -0.3267, -1.9535],\n", 1989 | " [ 0.4780, 0.0892, -0.0702, ..., 0.0464, 0.6126, -0.5235],\n", 1990 | " ...,\n", 1991 | " [-0.5134, -0.1886, 0.7232, ..., -1.5733, 0.8252, -1.0080],\n", 1992 | " [-0.6431, 0.4673, -0.6759, ..., -1.8117, 0.2816, -0.6746],\n", 1993 | " [-1.7421, 1.1500, -0.9018, ..., -0.1793, -0.8702, -1.6551]],\n", 1994 | "\n", 1995 | " [[-0.4298, -1.6662, -0.6432, ..., 0.9355, -0.0912, -0.0464],\n", 1996 | " [-0.0353, -2.2176, -0.3661, ..., 0.4481, 0.2588, 0.8578],\n", 1997 | " [ 0.1809, -1.6518, -1.4989, ..., 0.1565, 0.4638, 1.1074],\n", 1998 | " ...,\n", 1999 | " [-0.3225, -0.2007, 0.0663, ..., 0.0906, 0.5708, 0.7598],\n", 2000 | " [ 1.2155, -0.6403, -0.6246, ..., 0.2112, 1.0863, 0.8600],\n", 2001 | " [-1.5486, -0.6468, -0.1580, ..., -1.1927, -1.1401, 
0.9541]],\n", 2002 | "\n", 2003 | " ...,\n", 2004 | "\n", 2005 | " [[ 0.3736, -1.3817, 0.0828, ..., -0.5443, 1.3230, 1.5176],\n", 2006 | " [ 2.1579, -1.3366, 1.2812, ..., 1.2565, -0.9004, 1.4949],\n", 2007 | " [ 2.0221, -1.6790, 1.1706, ..., -0.7868, -0.0580, 1.4813],\n", 2008 | " ...,\n", 2009 | " [ 0.6954, -1.4955, 0.5084, ..., 0.9527, 1.8826, -0.4720],\n", 2010 | " [ 1.1418, -2.1720, 0.7583, ..., -0.1777, 1.1263, -0.2029],\n", 2011 | " [-0.4081, -0.7484, -0.2834, ..., 0.6386, 0.2259, 0.2769]],\n", 2012 | "\n", 2013 | " [[-2.3668, -2.2414, 0.2678, ..., -0.9640, 2.1260, -1.2007],\n", 2014 | " [-1.4113, -2.4533, -0.0571, ..., -0.6361, 1.8075, -0.5678],\n", 2015 | " [-2.3656, -1.6844, 0.1143, ..., -0.2386, 2.2545, -1.8532],\n", 2016 | " ...,\n", 2017 | " [-0.5174, -1.3068, 0.7565, ..., -0.4265, 1.1074, -0.1870],\n", 2018 | " [-1.3409, -0.1959, 0.6487, ..., -0.6916, 0.2856, 0.1339],\n", 2019 | " [-1.6683, -1.6219, 1.8971, ..., -0.1236, 0.4171, 0.2172]],\n", 2020 | "\n", 2021 | " [[-0.4976, -1.0679, -1.9941, ..., 0.1339, -0.7753, -0.4244],\n", 2022 | " [-1.4589, -0.9864, -2.2164, ..., -0.1309, -2.4254, -0.1335],\n", 2023 | " [-0.1325, -1.3582, -1.1422, ..., 1.3284, -0.7869, -0.3660],\n", 2024 | " ...,\n", 2025 | " [-2.3823, 1.2326, 0.0479, ..., -1.0104, -0.5868, 0.5917],\n", 2026 | " [-0.9012, 0.4950, 0.0202, ..., -0.2959, -0.5830, -1.3058],\n", 2027 | " [ 0.6635, 0.2663, 0.0735, ..., 0.6158, -0.5385, 0.1823]]]],\n", 2028 | " grad_fn=), tensor([[[[-2.2034e+00, 1.7162e-01, -2.5471e-01, ..., 2.4659e-01,\n", 2029 | " -1.1306e+00, 1.5210e+00],\n", 2030 | " [-1.8214e+00, 9.1974e-01, 1.0524e+00, ..., 4.6712e-01,\n", 2031 | " 7.1585e-03, 1.7431e+00],\n", 2032 | " [-1.7944e+00, 1.2111e+00, 1.7746e+00, ..., -8.1263e-02,\n", 2033 | " -7.2423e-01, 2.3455e+00],\n", 2034 | " ...,\n", 2035 | " [-9.1744e-01, 1.6548e+00, 1.6892e+00, ..., -4.1003e-01,\n", 2036 | " -2.6266e-01, 1.1222e+00],\n", 2037 | " [ 8.4884e-01, 1.7337e+00, 1.3689e+00, ..., -6.5879e-02,\n", 2038 | " -1.8267e+00, 7.5135e-02],\n", 2039 | " [ 3.5726e-01, 4.9505e-01, 1.2454e+00, ..., -9.6814e-01,\n", 2040 | " -2.9233e-01, 8.4373e-01]],\n", 2041 | "\n", 2042 | " [[-8.3438e-01, 1.5580e+00, 6.8650e-01, ..., -5.4564e-01,\n", 2043 | " 8.9563e-01, 1.2806e+00],\n", 2044 | " [ 4.9233e-03, 1.3172e+00, 5.3151e-02, ..., -6.8951e-01,\n", 2045 | " 2.8390e-01, 7.1242e-01],\n", 2046 | " [-3.2568e-01, 4.4483e-01, -1.7829e-01, ..., -1.4768e+00,\n", 2047 | " 1.7114e+00, 4.1612e-01],\n", 2048 | " ...,\n", 2049 | " [ 1.6133e+00, 1.0262e+00, -2.5962e-01, ..., -8.7402e-01,\n", 2050 | " 1.0925e+00, 1.2186e-01],\n", 2051 | " [ 3.0942e-01, 1.9314e-03, -6.0236e-02, ..., 1.0720e-01,\n", 2052 | " 3.7913e-01, 9.5514e-01],\n", 2053 | " [ 1.1943e+00, 4.4783e-02, 1.6539e-01, ..., -3.4771e-01,\n", 2054 | " -7.5826e-01, -1.4490e-01]],\n", 2055 | "\n", 2056 | " [[ 7.4118e-01, 7.2590e-01, -5.1478e-01, ..., 1.0937e+00,\n", 2057 | " -4.1723e-01, -7.7690e-01],\n", 2058 | " [ 8.4377e-01, 3.0361e-01, -1.5830e+00, ..., 1.4365e+00,\n", 2059 | " 7.4840e-01, -6.4476e-01],\n", 2060 | " [ 8.0190e-01, -5.0738e-01, -1.2730e+00, ..., 1.0945e+00,\n", 2061 | " 7.6041e-02, -1.5358e+00],\n", 2062 | " ...,\n", 2063 | " [-2.5955e-01, -5.6978e-01, -1.0390e+00, ..., 1.4319e+00,\n", 2064 | " 3.3910e-02, 8.1193e-01],\n", 2065 | " [ 9.6365e-01, -9.1730e-01, -5.9856e-01, ..., 3.2434e-01,\n", 2066 | " 2.6688e-01, 1.6109e+00],\n", 2067 | " [ 3.6250e-01, 4.4351e-01, 1.9714e+00, ..., 1.2220e+00,\n", 2068 | " 1.4319e-01, 9.7502e-01]],\n", 2069 | "\n", 2070 | " ...,\n", 2071 | "\n", 2072 | " [[ 2.9864e+00, 
-1.2450e-01, 1.1634e+00, ..., -2.2415e+00,\n", 2073 | " 2.9488e-01, -1.1858e+00],\n", 2074 | " [ 1.5365e+00, 1.8788e-01, 3.2609e-01, ..., -1.3012e+00,\n", 2075 | " -6.5461e-01, -9.6830e-01],\n", 2076 | " [ 2.0923e+00, -2.8300e-01, 1.7896e+00, ..., -1.2972e+00,\n", 2077 | " -1.1943e+00, -1.6408e+00],\n", 2078 | " ...,\n", 2079 | " [ 5.6653e-01, -8.6010e-01, 1.4121e+00, ..., 1.4983e+00,\n", 2080 | " -9.8427e-01, -2.6941e+00],\n", 2081 | " [ 9.7308e-01, -3.2721e-02, 1.0578e+00, ..., 2.0820e+00,\n", 2082 | " -4.9865e-01, -2.7064e+00],\n", 2083 | " [ 1.8758e+00, -3.2500e-03, 8.9719e-01, ..., -4.9562e-01,\n", 2084 | " -1.5096e-01, -1.1720e+00]],\n", 2085 | "\n", 2086 | " [[ 1.3627e+00, 1.0259e+00, -1.5550e+00, ..., -1.7845e-01,\n", 2087 | " -8.1842e-01, 7.4781e-01],\n", 2088 | " [ 4.5596e-01, -8.1178e-01, -7.6237e-01, ..., -5.1961e-01,\n", 2089 | " 6.6398e-01, 6.9452e-01],\n", 2090 | " [ 5.9729e-01, 8.7549e-01, -1.8603e+00, ..., 6.8863e-01,\n", 2091 | " 4.5351e-01, 1.2662e+00],\n", 2092 | " ...,\n", 2093 | " [ 3.0204e-01, 6.9973e-01, -1.7269e-01, ..., 2.5652e-01,\n", 2094 | " -7.0275e-01, 1.3234e-01],\n", 2095 | " [ 1.6700e+00, -9.5115e-01, 1.1192e-02, ..., 1.4692e-01,\n", 2096 | " -1.1089e+00, -1.3222e+00],\n", 2097 | " [-1.1018e+00, 6.4857e-01, 1.0247e-02, ..., 6.0237e-01,\n", 2098 | " -6.5679e-01, 8.8029e-01]],\n", 2099 | "\n", 2100 | " [[-1.0748e+00, -4.2517e-01, 2.7494e+00, ..., 1.4266e+00,\n", 2101 | " -2.0459e+00, -1.3956e+00],\n", 2102 | " [-3.7759e-01, -1.3627e+00, 1.0223e+00, ..., 1.4736e+00,\n", 2103 | " -7.9030e-01, -4.5627e-01],\n", 2104 | " [-7.0518e-01, -1.5484e+00, 1.1089e+00, ..., 1.9142e+00,\n", 2105 | " -2.6990e-01, 6.8551e-02],\n", 2106 | " ...,\n", 2107 | " [ 7.7728e-01, -6.3474e-01, 2.0541e+00, ..., 3.0278e-01,\n", 2108 | " -3.2087e-02, 6.1043e-01],\n", 2109 | " [ 1.4528e-01, 5.8034e-01, 8.2100e-01, ..., 1.3352e+00,\n", 2110 | " -2.0064e+00, 3.2114e-01],\n", 2111 | " [-4.0202e-02, -1.1691e+00, 1.8921e+00, ..., 5.1390e-01,\n", 2112 | " -1.5008e-01, -7.9119e-01]]]], grad_fn=))), decoder_hidden_states=None, decoder_attentions=None, cross_attentions=None, encoder_last_hidden_state=tensor([[[ 0.4917, -0.2846, 0.2701, ..., -0.8250, 0.8985, 2.9372],\n", 2113 | " [ 0.0000, -0.3958, 0.7851, ..., -0.1896, 0.9054, 1.5924],\n", 2114 | " [ 0.6328, -0.5367, -0.0000, ..., 0.2772, -0.0000, 1.9880],\n", 2115 | " ...,\n", 2116 | " [-0.0158, -0.5645, 0.0859, ..., 1.2770, -1.3071, 0.9423],\n", 2117 | " [-0.7459, -0.8548, -0.7462, ..., 0.7182, -0.9767, 1.1183],\n", 2118 | " [-0.7619, -0.4506, -0.7055, ..., -1.1454, -0.0000, 0.0000]]],\n", 2119 | " grad_fn=), encoder_hidden_states=None, encoder_attentions=None)" 2120 | ] 2121 | }, 2122 | "execution_count": 23, 2123 | "metadata": {}, 2124 | "output_type": "execute_result" 2125 | } 2126 | ], 2127 | "source": [ 2128 | "output" 2129 | ] 2130 | }, 2131 | { 2132 | "cell_type": "code", 2133 | "execution_count": null, 2134 | "metadata": {}, 2135 | "outputs": [], 2136 | "source": [] 2137 | } 2138 | ], 2139 | "metadata": { 2140 | "kernelspec": { 2141 | "display_name": "docformer-v2", 2142 | "language": "python", 2143 | "name": "docformer-v2" 2144 | }, 2145 | "language_info": { 2146 | "codemirror_mode": { 2147 | "name": "ipython", 2148 | "version": 3 2149 | }, 2150 | "file_extension": ".py", 2151 | "mimetype": "text/x-python", 2152 | "name": "python", 2153 | "nbconvert_exporter": "python", 2154 | "pygments_lexer": "ipython3", 2155 | "version": "3.10.13" 2156 | } 2157 | }, 2158 | "nbformat": 4, 2159 | "nbformat_minor": 2 2160 | } 2161 | 
--------------------------------------------------------------------------------