├── .gitignore ├── Dockerfile ├── HiCFoundation.ipynb ├── LICENSE ├── README.md ├── Reproducibility.ipynb ├── data_processing ├── __init__.py ├── collate_fn.py ├── finetune_dataset.py ├── inference_dataset.py └── pretrain_dataset.py ├── environment.yml ├── environment_notorch.yml ├── example ├── GSM7006609_ValbB8w1081.pairs ├── __init__.py ├── finetune_example │ ├── train.txt │ ├── train │ │ ├── data_0.pkl │ │ ├── data_1.pkl │ │ ├── data_2.pkl │ │ ├── data_3.pkl │ │ └── data_4.pkl │ ├── val.txt │ └── val │ │ ├── data_0.pkl │ │ ├── data_1.pkl │ │ ├── data_2.pkl │ │ ├── data_3.pkl │ │ └── data_4.pkl └── pretrain_example │ ├── train.txt │ ├── train │ ├── data_1.pkl │ ├── data_2.pkl │ ├── data_3.pkl │ ├── data_4.pkl │ ├── data_5.pkl │ ├── data_6.pkl │ ├── data_7.pkl │ └── data_8.pkl │ └── val.txt ├── finetune.py ├── finetune ├── __init__.py ├── loss.py ├── main_worker.py ├── train_epoch.py └── val_epoch.py ├── hicfoundation_model └── __init__.py ├── imgs └── framework_github.png ├── inference.py ├── inference ├── __init__.py ├── inference_worker.py ├── load_model.py └── main_worker.py ├── model ├── Finetune_Model_Head.py ├── NativeScaler.py ├── SSIM.py ├── Vision_Transformer_count.py ├── __init__.py ├── lr_decay.py ├── lr_sched.py ├── model_utils.py ├── models_hicfoundation.py └── pos_embed.py ├── ops ├── Logger.py ├── __init__.py ├── argparser.py ├── calculate_similarity.py ├── distribute_utils.py ├── file_format_convert.py ├── io_utils.py ├── mean_shift_merge.py ├── smooth_matrix.py ├── sparse_ops.py └── train_utils.py ├── pretrain.py ├── pretrain ├── __init__.py ├── main_worker.py ├── train_epoch.py └── val_epoch.py ├── refactor_pretrain ├── Logger.py ├── SSIM.py ├── argparser.py ├── distribute_utils.py ├── main_worker.py ├── model_funcs.py ├── model_utils.py ├── models_hicfoundation.py ├── pretrain.py ├── pretrain_dataset.py ├── train_epoch.py ├── utils.py └── val_epoch.py ├── requirements.txt ├── test ├── test_convert_rgb.py ├── test_forward.py ├── test_loss.py ├── test_masking.py └── test_patchify.py └── utils ├── __init__.py ├── array2bigwig.py ├── array2cool.py ├── array2hic.py ├── cool2array.py ├── hic2array.py ├── hic_coverage.py └── juicer_tools.jar /.gitignore: -------------------------------------------------------------------------------- 1 | *pycache* 2 | .coverage 3 | htmlcov -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use NVIDIA CUDA base image with Ubuntu 2 | FROM nvidia/cuda:11.1.1-devel-ubuntu20.04 3 | 4 | # Set environment variables 5 | ENV DEBIAN_FRONTEND=noninteractive 6 | ENV CONDA_DIR=/opt/conda 7 | ENV PATH=$CONDA_DIR/bin:$PATH 8 | ENV PYTHONPATH=/app:$PYTHONPATH 9 | 10 | # Install system dependencies 11 | # configure hic-straw dependencies 12 | RUN apt-get update && apt-get install -y \ 13 | wget \ 14 | curl \ 15 | git \ 16 | build-essential \ 17 | cmake \ 18 | libcurl4-openssl-dev \ 19 | zlib1g-dev \ 20 | && rm -rf /var/lib/apt/lists/* 21 | 22 | # Install Miniconda 23 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \ 24 | /bin/bash ~/miniconda.sh -b -p $CONDA_DIR && \ 25 | rm ~/miniconda.sh && \ 26 | conda clean -t -i -p -y 27 | 28 | # Set working directory 29 | WORKDIR /app 30 | 31 | # Create conda environment and install dependencies directly 32 | RUN conda create -n HiCFoundation python=3.8.10 -y && \ 33 | conda clean -afy 34 | 35 | # Make RUN commands use the new 
environment 36 | SHELL ["conda", "run", "-n", "HiCFoundation", "/bin/bash", "-c"] 37 | 38 | # Install conda packages 39 | RUN conda install -c pytorch -c nvidia -c conda-forge -c anaconda -c bioconda -c defaults \ 40 | cudatoolkit=11.1.74 \ 41 | pip=21.1.3 \ 42 | pytorch=1.8.1 \ 43 | torchvision=0.9.1 \ 44 | timm=0.3.2 \ 45 | openjdk \ 46 | pandas \ 47 | matplotlib \ 48 | scipy \ 49 | numba \ 50 | cooler \ 51 | -y && \ 52 | conda clean -afy 53 | 54 | # Install pip packages 55 | RUN pip install \ 56 | easydict \ 57 | opencv-python \ 58 | simplejson \ 59 | lvis \ 60 | "Pillow==9.5.0" \ 61 | pytorch_msssim \ 62 | scikit-image \ 63 | einops \ 64 | tensorboard \ 65 | pyBigWig 66 | 67 | RUN pip install hic-straw 68 | 69 | # Set the default command to activate conda environment 70 | ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "HiCFoundation"] 71 | CMD ["python", "--version"] 72 | -------------------------------------------------------------------------------- /Reproducibility.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "private_outputs": true, 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyOXdz0SpmwfOe1u8Ku2hJho", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "view-in-github", 24 | "colab_type": "text" 25 | }, 26 | "source": [ 27 | "\"Open" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "source": [ 33 | "# HiCFoundation: a generalizable Hi-C foundation model for chromatin architecture, single-cell and multi-omics analysis across species\n", 34 | "**This repo is only for calculating the reproducibility score by HiCFoundation**" 35 | ], 36 | "metadata": { 37 | "id": "S-eNVerqcDbI" 38 | } 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "source": [ 43 | "HiCFoundation is a generalizable Hi-C foundation model for chromatin architecture, single-cell and multi-omics analysis across species.\n", 44 | "\n", 45 | "Copyright (C) 2024 Xiao Wang, Yuanyuan Zhang, Suhita Ray, Anupama Jha, Tangqi Fang, Shengqi Hang, Sergei Doulatov, William Stafford Noble, and Sheng Wang\n", 46 | "\n", 47 | "License: Apache License 2.0\n", 48 | "\n", 49 | "Contact: Sergei Doulatov (doulatov@uw.edu) & William Stafford Noble (wnoble@uw.edu) & Sheng Wang (swang@cs.washington.edu)\n", 50 | "\n", 51 | "For technical problems or questions, please reach out to Xiao Wang (wang3702@uw.edu) and Yuanyuan Zhang (zhang038@purdue.edu).\n", 52 | "\n", 53 | "\n", 54 | "If you are using other browsers, disabling tracking protection may help resolve errors when uploading or downloading files.\n", 55 | "\n", 56 | "For more details, see the **Instructions** section of the notebook and check out the **[HiCFoundation GitHub](https://github.com/Noble-Lab/HiCFoundation)**. If you use HiCFoundation, please cite it: **Citation**." 57 | ], 58 | "metadata": { 59 | "id": "aqhB21yTcFG_" 60 | } 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "source": [ 65 | "# Instructions \n", 66 | "## Steps\n", 67 | "1. Run the HiCFoundation Colab on your two Hi-C maps of interest and download the embedding pickle files for further processing.\n", 68 | "2. Connect to a **cpu machine** by clicking the **\"Connect\"** button at the top right of the notebook.
\n", 69 | "3. Upload the embedding of 1st Hi-C map (.pkl file) in Input file1.\n", 70 | "4. Upload the embedding of 1st Hi-C map (.pkl file) in Input file2.\n", 71 | "5. Running the score calculation by by clicking the left running button in Run.\n", 72 | "6. You can check the output to get the similarity score in the same tab." 73 | ], 74 | "metadata": { 75 | "id": "j62BWrrecdSS" 76 | } 77 | }, 78 | { 79 | "cell_type": "code", 80 | "source": [ 81 | "#@title Input embedding file1\n", 82 | "from google.colab import files\n", 83 | "import os\n", 84 | "import os.path\n", 85 | "import re\n", 86 | "import hashlib\n", 87 | "import random\n", 88 | "import string\n", 89 | "from google.colab import drive\n", 90 | "\n", 91 | "from datetime import datetime\n", 92 | "# Get the current date and time\n", 93 | "current_datetime = datetime.now()\n", 94 | "# Convert to string in desired format\n", 95 | "current_datetime_str = current_datetime.strftime(\"%Y-%m-%d-%H-%M-%S\")\n", 96 | "rand_letters = string.ascii_lowercase\n", 97 | "rand_letters = ''.join(random.choice(rand_letters) for i in range(20))\n", 98 | "output_dir=\"/content/\"\n", 99 | "\n", 100 | "#@markdown ## Upload the calculated embedding file(.pkl) of 1st Hi-C from your local file system\n", 101 | "print(\"Please uploading your input files\")\n", 102 | "os.chdir(\"/content/\")\n", 103 | "root_dir = os.getcwd()\n", 104 | "upload_dir = os.path.join(root_dir,rand_letters)\n", 105 | "if not os.path.exists(upload_dir):\n", 106 | " os.mkdir(upload_dir)\n", 107 | "os.chdir(upload_dir)\n", 108 | "map_input = files.upload()\n", 109 | "for fn in map_input.keys():\n", 110 | " print('User uploaded file \"{name}\" with length {length} bytes'.format(\n", 111 | " name=fn, length=len(map_input[fn])))\n", 112 | " hic_input_path1 = os.path.abspath(fn)\n", 113 | " print(\"The input save to %s\"%hic_input_path1)\n", 114 | "os.chdir(root_dir)\n", 115 | "\n" 116 | ], 117 | "metadata": { 118 | "cellView": "form", 119 | "id": "b-pBMNY3c-9o" 120 | }, 121 | "execution_count": null, 122 | "outputs": [] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "source": [ 127 | "#@title Input embedding file2\n", 128 | "#@markdown ## Upload the calculated embedding file(.pkl) of 2nd Hi-C from your local file system\n", 129 | "os.chdir(upload_dir)\n", 130 | "map_input = files.upload()\n", 131 | "for fn in map_input.keys():\n", 132 | " print('User uploaded file \"{name}\" with length {length} bytes'.format(\n", 133 | " name=fn, length=len(map_input[fn])))\n", 134 | " hic_input_path2 = os.path.abspath(fn)\n", 135 | " print(\"The input save to %s\"%hic_input_path2)\n", 136 | "os.chdir(root_dir)" 137 | ], 138 | "metadata": { 139 | "cellView": "form", 140 | "id": "WGe1QYgcf1Fw" 141 | }, 142 | "execution_count": null, 143 | "outputs": [] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "source": [ 148 | "# @title Reproducibility score calculation\n", 149 | "# This script is to calculate the similarity between two Hi-C using a pre-trained reproducibility model.\n", 150 | "\n", 151 | "import os\n", 152 | "import sys\n", 153 | "import numpy as np\n", 154 | "import pickle\n", 155 | "from collections import defaultdict\n", 156 | "\n", 157 | "input_pickle1 = hic_input_path1\n", 158 | "input_pickle2 = hic_input_path2\n", 159 | "\n", 160 | "def load_pickle(file_path):\n", 161 | " with open(file_path, 'rb') as f:\n", 162 | " data = pickle.load(f)\n", 163 | " return data\n", 164 | "\n", 165 | "input1 = load_pickle(input_pickle1)\n", 166 | "input2 = load_pickle(input_pickle2)\n", 167 | "\n", 168 
| "def find_key(chr,loc,key_list):\n", 169 | " \"\"\"\n", 170 | " Find the key in the list of keys that contains the given chromosome and location.\n", 171 | " \"\"\"\n", 172 | " key1 = chr+\":\"+loc\n", 173 | " if key1 in key_list:\n", 174 | " return key1\n", 175 | " key1 = \"chr\"+chr+\":\"+loc\n", 176 | " if key1 in key_list:\n", 177 | " return key1\n", 178 | " key1 = chr+\"_\"+chr+\":\"+loc\n", 179 | " if key1 in key_list:\n", 180 | " return key1\n", 181 | " key1 = \"chr\"+chr+\"_chr\"+chr+\":\"+loc\n", 182 | " if key1 in key_list:\n", 183 | " return key1\n", 184 | " return None\n", 185 | "\n", 186 | "def calculate_similarity(input1, input2):\n", 187 | " \"\"\"\n", 188 | " Calculate the similarity between two Hi-C matrices using a pre-trained reproducibility model.\n", 189 | " \"\"\"\n", 190 | " similarity_dict = defaultdict(list)\n", 191 | " for key in input1.keys():\n", 192 | " #1_1:1960,1960 format of key\n", 193 | " split_chromosome = key.split(\":\")[0]\n", 194 | " split_loc = key.split(\":\")[1]\n", 195 | " combine_key = split_chromosome + \":\" + split_loc\n", 196 | " chr = split_chromosome.split(\"_\")[0]\n", 197 | " chr = chr.replace(\"chr\",\"\")\n", 198 | " if combine_key not in input2.keys():\n", 199 | " combine_key = find_key(chr,split_loc,input2.keys())\n", 200 | " if combine_key is None:\n", 201 | " continue\n", 202 | "\n", 203 | " embedding1 = input1[key]\n", 204 | " embedding2 = input2[combine_key]\n", 205 | " # Calculate the similarity between the two embeddings\n", 206 | " similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))\n", 207 | " if np.isnan(similarity):\n", 208 | " continue\n", 209 | " similarity_dict[chr].append(similarity)\n", 210 | " #ignore chrY, chrM, Un, Alt cases\n", 211 | " similarity_list=[]\n", 212 | " for chrom in similarity_dict:\n", 213 | " if \"Y\" in chrom or \"M\" in chrom or \"Un\" in chrom or \"Alt\" in chrom:\n", 214 | " continue\n", 215 | " mean_val = np.mean(similarity_dict[chrom])\n", 216 | " similarity_list.append(mean_val)\n", 217 | " similarity = np.mean(similarity_list)\n", 218 | " return similarity\n", 219 | "\n", 220 | "similarity = calculate_similarity(input1, input2)\n", 221 | "print(\"The reproducibility score between the two Hi-C is: \", similarity)\n" 222 | ], 223 | "metadata": { 224 | "cellView": "form", 225 | "id": "v41vpgxbgQoV" 226 | }, 227 | "execution_count": null, 228 | "outputs": [] 229 | } 230 | ] 231 | } -------------------------------------------------------------------------------- /data_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/data_processing/__init__.py -------------------------------------------------------------------------------- /data_processing/collate_fn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def collate_fn(batch): 3 | # Transpose the batch (list of lists) to group elements by position 4 | batch_transposed = list(zip(*batch)) 5 | 6 | # Process each position across the batch 7 | processed_batch = [] 8 | for tensors in batch_transposed: 9 | if all(t is None for t in tensors): # If all are None, keep None 10 | processed_batch.append(None) 11 | else: # Otherwise, stack non-None tensors and replace None with zero tensors 12 | #make sure no None element in the tensors 13 | any_none = any(t is None for t in tensors) 14 | assert not 
any_none, "None element in a list of tensors" 15 | stacked = [ 16 | t for t in tensors 17 | ] 18 | processed_batch.append(torch.stack(stacked)) 19 | 20 | return processed_batch 21 | 22 | -------------------------------------------------------------------------------- /data_processing/finetune_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | import random 6 | from collections import defaultdict 7 | from scipy.sparse import coo_matrix 8 | import pickle 9 | from ops.sparse_ops import array_to_coo 10 | from ops.io_utils import load_pickle 11 | def validate_input_size(input_matrix, window_height, window_width): 12 | """ 13 | Validate the input size is larger than the window size 14 | Args: 15 | input_matrix: the input matrix 16 | window_height: the height of the window 17 | window_width: the width of the window 18 | """ 19 | if isinstance(input_matrix, coo_matrix): 20 | input_matrix = input_matrix.toarray() 21 | input_height, input_width = input_matrix.shape 22 | if input_height==window_height and input_width==window_width: 23 | return True 24 | return False 25 | 26 | def to_tensor(x): 27 | """ 28 | Convert the input to tensor 29 | Args: 30 | x: the input data 31 | """ 32 | if isinstance(x, np.ndarray): 33 | x = torch.from_numpy(x) 34 | elif x is None: 35 | x = None 36 | #if already tensor, do nothing 37 | elif isinstance(x, torch.Tensor): 38 | pass 39 | #if float, convert to tensor 40 | elif isinstance(x, float): 41 | x = torch.tensor(x) 42 | elif isinstance(x, int): 43 | x = torch.tensor(x) 44 | return x 45 | 46 | def list_to_tensor(x): 47 | """ 48 | Convert the list to tensor 49 | Args: 50 | x: the input list 51 | """ 52 | y=[] 53 | for i in x: 54 | y.append(to_tensor(i)) 55 | return y 56 | class Finetune_Dataset(torch.utils.data.Dataset): 57 | def __init__(self,data_list, 58 | transform=None, 59 | window_height= 224, 60 | window_width = 224): 61 | """ 62 | Args: 63 | data_list: list of data directories 64 | transform: the transformation to apply to the data 65 | window_height: the height of the window 66 | window_width: the width of the window 67 | """ 68 | self.data_list = data_list 69 | self.transform = transform 70 | self.window_height = window_height 71 | self.window_width = window_width 72 | self.train_dict=defaultdict(list) 73 | self.train_list=[] 74 | for data_index, data_dir in enumerate(data_list): 75 | cur_dir = data_dir 76 | dataset_name = os.path.basename(cur_dir) 77 | listfiles = os.listdir(cur_dir) 78 | for file_index,file in enumerate(listfiles): 79 | cur_path = os.path.join(cur_dir, file) 80 | if file.endswith('.pkl'): 81 | if file_index==0: 82 | #verify the input pkl file includes the input key 83 | data= load_pickle(cur_path) 84 | data_keys = list(data.keys()) 85 | if 'input' not in data: 86 | print("The input key is not included in the pkl file. The directory is skipped.") 87 | print("The dir is {}".format(cur_dir)) 88 | continue 89 | #check other keys include in the dict 90 | target_exist=False 91 | for key in data_keys: 92 | if "target" in key: 93 | target_exist=True 94 | break 95 | if not target_exist: 96 | print("The target key is not included in the pkl file. 
The directory is skipped.") 97 | print("The dir is {}".format(cur_dir)) 98 | continue 99 | #validate the input size 100 | input_matrix = data['input'] 101 | if not validate_input_size(input_matrix, window_height, window_width): 102 | print("The input size is not matched with the window size. The directory is skipped.") 103 | print("The dir is {}".format(cur_dir)) 104 | print("The input size is {}".format(input_matrix.shape)) 105 | print("The specified window size is {} x {}".format(window_height, window_width)) 106 | print("Please adjust --input_row_size and --input_col_size to match your input.") 107 | continue 108 | self.train_dict[dataset_name].append(cur_path) 109 | self.train_list.append(cur_path) 110 | else: 111 | print("The file {} is not a .pkl file.".format(file),"It is skipped.") 112 | continue 113 | print("The number of samples used in the dataset is {}".format(len(self.train_list))) 114 | #you can either select the train_list or train_dict to do training based on your exprience 115 | def __len__(self): 116 | return len(self.train_list) 117 | 118 | def convert_rgb(self,data_log,max_value): 119 | if len(data_log.shape)==2: 120 | data_log = data_log[np.newaxis,:] 121 | data_red = np.ones(data_log.shape) 122 | data_log1 = (max_value-data_log)/max_value 123 | data_rgb = np.concatenate([data_red,data_log1,data_log1],axis=0,dtype=np.float32)#transform only accept channel last case 124 | data_rgb = data_rgb.transpose(1,2,0) 125 | return data_rgb 126 | 127 | def __getitem__(self, idx): 128 | train_file = self.train_list[idx] 129 | data = load_pickle(train_file) 130 | input_matrix = data['input'] 131 | if isinstance(input_matrix, coo_matrix): 132 | input_matrix = input_matrix.toarray() 133 | #make sure you save the down-diagonal regions if you use the coo_matrix 134 | #to support off-diagonal submatrix, we did not any automatic symmetrical conversion for your input array. 
135 | input_matrix = np.nan_to_num(input_matrix) 136 | input_matrix = input_matrix.astype(np.float32) 137 | input_matrix = np.log10(input_matrix+1) 138 | max_value = np.max(input_matrix) 139 | input_matrix = self.convert_rgb(input_matrix,max_value) 140 | if self.transform: 141 | input_matrix = self.transform(input_matrix) 142 | if "input_count" in data: 143 | total_count = data['input_count'] 144 | else: 145 | total_count = None #indicates the total count is not provided 146 | 147 | if "2d_target" in data: 148 | target_matrix = data['2d_target'] 149 | if isinstance(target_matrix, coo_matrix): 150 | target_matrix = target_matrix.toarray() 151 | target_matrix = np.nan_to_num(target_matrix) 152 | target_matrix = target_matrix.astype(np.float32) 153 | else: 154 | target_matrix = None 155 | 156 | if "embed_target" in data: 157 | embed_target = data['embed_target'] 158 | if isinstance(embed_target, coo_matrix): 159 | embed_target = embed_target.toarray() 160 | embed_target = np.nan_to_num(embed_target) 161 | embed_target = embed_target.astype(np.float32) 162 | else: 163 | embed_target = None 164 | 165 | if "1d_target" in data: 166 | target_vector = data['1d_target'] 167 | target_vector = np.nan_to_num(target_vector) 168 | target_vector = target_vector.astype(np.float32) 169 | else: 170 | target_vector = None 171 | 172 | return list_to_tensor([input_matrix, total_count, target_matrix, embed_target, target_vector]) 173 | 174 | 175 | 176 | 177 | 178 | 179 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: HiCFoundation 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - anaconda 7 | - bioconda 8 | - defaults 9 | dependencies: 10 | - cudatoolkit=11.1.74 11 | - pip=21.1.3 12 | - python=3.8.10 13 | - pytorch=1.8.1 14 | - torchvision=0.9.1 15 | - timm=0.3.2 16 | - openjdk 17 | - pip: 18 | - easydict 19 | - opencv-python 20 | - simplejson 21 | - lvis 22 | - Pillow==9.5.0 23 | - pytorch_msssim 24 | - pandas 25 | - hic-straw 26 | - matplotlib 27 | - scikit-image 28 | - scipy 29 | - einops 30 | - tensorboard 31 | - cooler 32 | - numba 33 | - pyBigWig 34 | 35 | 36 | -------------------------------------------------------------------------------- /environment_notorch.yml: -------------------------------------------------------------------------------- 1 | name: HiCFoundation 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - anaconda 7 | - bioconda 8 | - defaults 9 | dependencies: 10 | - pip=21.1.3 11 | - python=3.8.10 12 | - openjdk 13 | - pip: 14 | - easydict 15 | - opencv-python 16 | - simplejson 17 | - lvis 18 | - Pillow==9.5.0 19 | - pytorch_msssim 20 | - pandas 21 | - hic-straw 22 | - matplotlib 23 | - scikit-image 24 | - scipy 25 | - einops 26 | - tensorboard 27 | - cooler 28 | - numba 29 | - pyBigWig 30 | 31 | 32 | -------------------------------------------------------------------------------- /example/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/__init__.py -------------------------------------------------------------------------------- /example/finetune_example/train.txt: -------------------------------------------------------------------------------- 1 | train 2 | -------------------------------------------------------------------------------- 
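Note: the example data_*.pkl files listed below are binary pickles, so their contents cannot be rendered inline. For reference, here is a minimal sketch (not part of the repository; the file name and random values are hypothetical) of how one fine-tuning sample can be assembled to satisfy the key contract checked by data_processing/finetune_dataset.py above: 'input' is required and must match the window size exactly, at least one key containing 'target' must be present, and 'input_count' is optional.

import pickle
import numpy as np
from scipy.sparse import coo_matrix

window = 224  # must equal --input_row_size and --input_col_size
dense = np.random.poisson(1.0, (window, window)).astype(np.float32)  # stand-in Hi-C submatrix
sample = {
    "input": coo_matrix(dense),         # a dense ndarray is accepted as well
    "input_count": float(dense.sum()),  # optional; when omitted, total_count is passed as None
    "2d_target": np.random.rand(window, window).astype(np.float32),  # e.g. an enhanced contact map
    # "1d_target": np.random.rand(window).astype(np.float32),        # e.g. an epigenomic track
    # "embed_target": np.random.rand(1024).astype(np.float32),       # e.g. a target embedding
}
with open("data_5.pkl", "wb") as f:  # hypothetical new sample for example/finetune_example/train/
    pickle.dump(sample, f)

Since Finetune_Dataset only inspects the first .pkl file in each directory for these checks, it is worth validating one sample per directory before launching finetune.py.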
/example/finetune_example/train/data_0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/train/data_0.pkl -------------------------------------------------------------------------------- /example/finetune_example/train/data_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/train/data_1.pkl -------------------------------------------------------------------------------- /example/finetune_example/train/data_2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/train/data_2.pkl -------------------------------------------------------------------------------- /example/finetune_example/train/data_3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/train/data_3.pkl -------------------------------------------------------------------------------- /example/finetune_example/train/data_4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/train/data_4.pkl -------------------------------------------------------------------------------- /example/finetune_example/val.txt: -------------------------------------------------------------------------------- 1 | val -------------------------------------------------------------------------------- /example/finetune_example/val/data_0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/val/data_0.pkl -------------------------------------------------------------------------------- /example/finetune_example/val/data_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/val/data_1.pkl -------------------------------------------------------------------------------- /example/finetune_example/val/data_2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/val/data_2.pkl -------------------------------------------------------------------------------- /example/finetune_example/val/data_3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/val/data_3.pkl -------------------------------------------------------------------------------- /example/finetune_example/val/data_4.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/val/data_4.pkl -------------------------------------------------------------------------------- /example/pretrain_example/train.txt: -------------------------------------------------------------------------------- 1 | train 2 | -------------------------------------------------------------------------------- /example/pretrain_example/train/data_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_1.pkl -------------------------------------------------------------------------------- /example/pretrain_example/train/data_2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_2.pkl -------------------------------------------------------------------------------- /example/pretrain_example/train/data_3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_3.pkl -------------------------------------------------------------------------------- /example/pretrain_example/train/data_4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_4.pkl -------------------------------------------------------------------------------- /example/pretrain_example/train/data_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_5.pkl -------------------------------------------------------------------------------- /example/pretrain_example/train/data_6.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_6.pkl -------------------------------------------------------------------------------- /example/pretrain_example/train/data_7.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_7.pkl -------------------------------------------------------------------------------- /example/pretrain_example/train/data_8.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_8.pkl -------------------------------------------------------------------------------- /example/pretrain_example/val.txt: -------------------------------------------------------------------------------- 1 | train 2 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | 2 | # My code has references to the 
following repositories: 3 | # DeiT: https://github.com/facebookresearch/deit 4 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 5 | # MAE: https://github.com/facebookresearch/mae 6 | # AdPE: https://github.com/maple-research-lab/AdPE 7 | # -------------------------------------------------------- 8 | import os 9 | from ops.argparser import argparser_finetune 10 | import torch 11 | import torch.multiprocessing as mp 12 | import timm 13 | assert timm.__version__ == "0.3.2" # version check 14 | def main(args): 15 | import socket 16 | hostname = socket.gethostname() 17 | local_ip = socket.gethostbyname(hostname) 18 | print("local ip: ",local_ip) 19 | 20 | ngpus_per_node = torch.cuda.device_count() 21 | args.world_size = args.world_size*ngpus_per_node 22 | from finetune.main_worker import main_worker 23 | if ngpus_per_node==1: 24 | main_worker(args.gpu,ngpus_per_node,args)#if you only have one gpu 25 | else: 26 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 27 | 28 | if __name__ == '__main__': 29 | import resource 30 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 31 | resource.setrlimit(resource.RLIMIT_NOFILE, (4096*2, rlimit[1])) 32 | limit_in_b = 900 * 1024 ** 3 33 | resource.setrlimit(resource.RLIMIT_DATA, (limit_in_b, limit_in_b)) 34 | use_cuda = torch.cuda.is_available() 35 | print("checking CUDA status:",use_cuda) 36 | #assert cuda is available 37 | assert use_cuda == True, "CUDA is not available, fine-tuning requires CUDA support to run!" 38 | parser = argparser_finetune() 39 | args = parser.parse_args() 40 | #If you have many GPUs on your server but only want to use a few of them, 41 | # run command line to configure the environment: 42 | # export CUDA_VISIBLE_DEVICES="0,1,2,3" 43 | # Here you can specify the GPU you want to use 44 | #check the specified input size, which must be a multiple of args.patch_size 45 | if args.input_row_size%args.patch_size!=0 or args.input_col_size%args.patch_size!=0: 46 | print("args configuration error: input_row_size and input_col_size must be a multiple of patch_size") 47 | exit(1) 48 | main(args) 49 | -------------------------------------------------------------------------------- /finetune/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/finetune/__init__.py -------------------------------------------------------------------------------- /finetune/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | #Please check https://pytorch.org/docs/stable/nn.html for more loss functions 3 | 4 | class cosine_distance(nn.Module): 5 | def __init__(self): 6 | super(cosine_distance, self).__init__() 7 | self.cos = nn.CosineSimilarity(dim=-1, eps=1e-08) 8 | def forward(self, x, y): 9 | return 1-self.cos(x,y) 10 | 11 | def configure_loss(args): 12 | if args.loss_type == 1: 13 | return nn.MSELoss() 14 | elif args.loss_type == 2: 15 | return cosine_distance() 16 | else: 17 | raise Exception("Unknown loss type: {}".format(args.loss_type)) 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /finetune/train_epoch.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import sys 4 | import numpy as np 5 | from typing import Iterable 6 | import torch 7 | import torch.nn.functional as F 8 | import time 9 | 10 | from 
ops.Logger import MetricLogger,SmoothedValue 11 | import model.lr_sched as lr_sched 12 | from finetune.loss import configure_loss 13 | from ops.train_utils import list_to_device, to_value, create_image, torch_to_nparray, convert_gray_rgbimage 14 | 15 | 16 | def train_epoch(model, data_loader_train, optimizer, 17 | loss_scaler, epoch, device, 18 | log_writer=None, args=None): 19 | model.train() 20 | metric_logger = MetricLogger(delimiter=" ") 21 | metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}')) 22 | 23 | header = 'Epoch: [{}]'.format(epoch) 24 | print_freq = args.print_freq 25 | 26 | accum_iter = args.accum_iter 27 | 28 | optimizer.zero_grad() 29 | if log_writer is not None: 30 | print('Tensorboard log dir: {}'.format(log_writer.log_dir)) 31 | print("number of iterations: ",len(data_loader_train)) 32 | criterion = configure_loss(args) 33 | 34 | num_iter = len(data_loader_train) 35 | for data_iter_step, train_data in enumerate(metric_logger.log_every(data_loader_train, print_freq, header)): 36 | if data_iter_step % accum_iter == 0: 37 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader_train) + epoch, args) 38 | input_matrix, total_count, target_matrix, embed_target, target_vector = list_to_device(train_data,device=device) 39 | output_embedding, output_2d, output_1d = model(input_matrix, total_count) 40 | 41 | if embed_target is not None: 42 | embedding_loss = criterion(output_embedding, embed_target) 43 | else: 44 | embedding_loss = 0 45 | if target_matrix is not None: 46 | #flatten 2d matrix 47 | output_2d_flatten = torch.flatten(output_2d, start_dim=1,end_dim=-1) 48 | target_matrix_flatten = torch.flatten(target_matrix, start_dim=1,end_dim=-1) 49 | output_2d_loss = criterion(output_2d_flatten, target_matrix_flatten) 50 | else: 51 | output_2d_loss = 0 52 | if target_vector is not None: 53 | output_1d_loss = criterion(output_1d, target_vector) 54 | else: 55 | output_1d_loss = 0 56 | loss = embedding_loss + output_2d_loss + output_1d_loss #you can adjust the loss function based on your fine-tuning purpose 57 | #typically, I think you should only finetune for one of the purposes 58 | metric_logger.update(loss=to_value(loss)) 59 | metric_logger.update(embedding_loss=to_value(embedding_loss)) 60 | metric_logger.update(output_2d_loss=to_value(output_2d_loss)) 61 | metric_logger.update(output_1d_loss=to_value(output_1d_loss)) 62 | if not math.isfinite(to_value(loss)): 63 | print("Loss is {}, skipping this batch".format(to_value(loss))) 64 | #sys.exit(1) 65 | optimizer.zero_grad() 66 | continue 67 | loss = loss / accum_iter 68 | loss_scaler(loss, optimizer, parameters=model.parameters(), 69 | update_grad=(data_iter_step + 1) % accum_iter == 0) 70 | 71 | if (data_iter_step + 1) % accum_iter == 0: 72 | optimizer.zero_grad() 73 | 74 | torch.cuda.synchronize() # Make sure all gradients are finished computing before moving on 75 | lr = optimizer.param_groups[0]["lr"] 76 | metric_logger.update(lr=lr) 77 | 78 | 79 | if log_writer is not None and ((data_iter_step + 1) % accum_iter == 0 or data_iter_step==0): 80 | """ 81 | We use epoch_1000x as the x-axis in tensorboard. 82 | This calibrates different curves when batch size changes. 
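For example, halfway through epoch 3 (data_iter_step = len(data_loader_train) // 2), the point is logged at x = int(3.5 * 1000) = 3500, independent of the batch size.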
83 | """ 84 | epoch_1000x = int((data_iter_step / len(data_loader_train) + epoch) * 1000) 85 | log_writer.add_scalars('Loss/loss', {'train_loss': to_value(loss)}, epoch_1000x) 86 | log_writer.add_scalars('Loss/embedding_loss', {'train_loss': to_value(embedding_loss)}, epoch_1000x) 87 | log_writer.add_scalars('Loss/output_2d_loss', {'train_loss': to_value(output_2d_loss)}, epoch_1000x) 88 | log_writer.add_scalars('Loss/output_1d_loss', {'train_loss': to_value(output_1d_loss)}, epoch_1000x) 89 | log_writer.add_scalars('LR/lr', {'lr': lr}, epoch_1000x) 90 | if ((data_iter_step+1)//accum_iter)%50==0 or data_iter_step==0: 91 | #add visualization for your output and input 92 | new_samples = create_image(input_matrix) 93 | select_num = min(8,len(new_samples)) 94 | sample_image = torch_to_nparray(new_samples.clone().detach()[:select_num]) 95 | log_writer.add_images('Input_%s'%"train", sample_image, epoch_1000x) 96 | output_2d_image = convert_gray_rgbimage(output_2d.clone().detach()[:select_num]) 97 | output_2d_image = torch_to_nparray(output_2d_image) 98 | log_writer.add_images('Output_2d_%s'%"train", output_2d_image, epoch_1000x) 99 | # for name, param in model.named_parameters(): 100 | # log_writer.add_histogram(name, param, epoch_1000x) 101 | #raise errors, see https://github.com/pytorch/pytorch/issues/91516 102 | #If you want to use this, install tensorboardX 103 | #then change the code in main_worker.py to "from tensorboardX import SummaryWriter" 104 | # gather the stats from all processes 105 | metric_logger.synchronize_between_processes() 106 | print("Averaged stats:", metric_logger) 107 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /finetune/val_epoch.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import sys 4 | import numpy as np 5 | from typing import Iterable 6 | import torch 7 | import torch.nn.functional as F 8 | import time 9 | 10 | from ops.Logger import MetricLogger,SmoothedValue 11 | import model.lr_sched as lr_sched 12 | from finetune.loss import configure_loss 13 | from finetune.train_epoch import list_to_device, to_value, \ 14 | create_image, torch_to_nparray, convert_gray_rgbimage 15 | 16 | def val_epoch(model, data_loader_val, device, epoch, 17 | log_writer=None, args=None): 18 | model.eval() 19 | metric_logger = MetricLogger(delimiter=" ") 20 | header="Val Epoch: [{}]".format(epoch) 21 | print_freq = args.print_freq 22 | accum_iter = args.accum_iter 23 | criterion = configure_loss(args) 24 | num_iter = len(data_loader_val) 25 | print("number of iterations: ",num_iter) 26 | for data_iter_step, val_data in enumerate(metric_logger.log_every(data_loader_val, print_freq, header)): 27 | input_matrix, total_count, target_matrix, embed_target, target_vector = list_to_device(val_data,device=device) 28 | with torch.no_grad(): 29 | output_embedding, output_2d, output_1d = model(input_matrix, total_count) 30 | if embed_target is not None: 31 | embedding_loss = criterion(output_embedding, embed_target) 32 | else: 33 | embedding_loss = 0 34 | if target_matrix is not None: 35 | #flatten 2d matrix 36 | output_2d_flatten = torch.flatten(output_2d, start_dim=1,end_dim=-1) 37 | target_matrix_flatten = torch.flatten(target_matrix, start_dim=1,end_dim=-1) 38 | output_2d_loss = criterion(output_2d_flatten, target_matrix_flatten) 39 | else: 40 | output_2d_loss = 0 41 | if target_vector is not None: 42 | output_1d_loss = 
criterion(output_1d, target_vector) 43 | else: 44 | output_1d_loss = 0 45 | loss = embedding_loss + output_2d_loss + output_1d_loss 46 | metric_logger.update(loss=to_value(loss)) 47 | metric_logger.update(embedding_loss=to_value(embedding_loss)) 48 | metric_logger.update(output_2d_loss=to_value(output_2d_loss)) 49 | metric_logger.update(output_1d_loss=to_value(output_1d_loss)) 50 | torch.cuda.synchronize() 51 | if log_writer is not None and ((data_iter_step + 1) % accum_iter == 0 or data_iter_step==0): 52 | """ 53 | We use epoch_1000x as the x-axis in tensorboard. 54 | This calibrates different curves when batch size changes. 55 | """ 56 | epoch_1000x = int((data_iter_step / len(data_loader_val) + epoch) * 1000) 57 | log_writer.add_scalars('Loss/loss', {'val_loss': to_value(loss)}, epoch_1000x) 58 | log_writer.add_scalars('Loss/embedding_loss', {'val_loss': to_value(embedding_loss)}, epoch_1000x) 59 | log_writer.add_scalars('Loss/output_2d_loss', {'val_loss': to_value(output_2d_loss)}, epoch_1000x) 60 | log_writer.add_scalars('Loss/output_1d_loss', {'val_loss': to_value(output_1d_loss)}, epoch_1000x) 61 | if ((data_iter_step+1)//accum_iter)%50==0 or data_iter_step==0: 62 | #add visualization for your output and input 63 | new_samples = create_image(input_matrix) 64 | select_num = min(8,len(new_samples)) 65 | sample_image = torch_to_nparray(new_samples.clone().detach()[:select_num]) 66 | log_writer.add_images('Input_%s'%"val", sample_image, epoch_1000x) 67 | output_2d_image = convert_gray_rgbimage(output_2d.clone().detach()[:select_num]) 68 | output_2d_image = torch_to_nparray(output_2d_image) 69 | log_writer.add_images('Output_2d_%s'%"val", output_2d_image, epoch_1000x) 70 | metric_logger.synchronize_between_processes() 71 | print("Averaged stats:", metric_logger) 72 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 73 | 74 | -------------------------------------------------------------------------------- /hicfoundation_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/hicfoundation_model/__init__.py -------------------------------------------------------------------------------- /imgs/framework_github.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/imgs/framework_github.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import timm 3 | assert timm.__version__ == "0.3.2" # version check for timm 4 | from ops.argparser import argparser_infer 5 | from ops.file_format_convert import convert_to_pkl 6 | 7 | def main(args): 8 | import socket 9 | hostname = socket.gethostname() 10 | local_ip = socket.gethostbyname(hostname) 11 | print("local ip: ",local_ip) 12 | #format processing, convert different formats to .pkl format for further processing 13 | output_dir = os.path.abspath(args.output) 14 | os.makedirs(output_dir,exist_ok=True) 15 | input_file = os.path.abspath(args.input) 16 | config_resolution = args.resolution 17 | input_pkl=convert_to_pkl(input_file, output_dir,config_resolution) 18 | 19 | #for reproducibility analysis, we need to smooth the matrix to generate embeddings. 
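# (The smoothing below runs only for --task 1; the smoothed pickle then replaces the converted input handed to inference.main_worker.)
# Hypothetical invocation, assuming the command-line flags mirror the args attributes used in this script (see ops/argparser.py for the authoritative definitions):
# python inference.py --input input.hic --resolution 10000 --task 1 --output hicfoundation_inference --gpu "0"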
20 | if args.task==1: 21 | from ops.smooth_matrix import smooth_pkl 22 | smooth_pkl_file = os.path.join(output_dir,"input_smoothed.pkl") 23 | input_pkl = smooth_pkl(input_pkl,smooth_pkl_file) 24 | print("Reproducibility analysis smoothed input matrix saved to ",input_pkl) 25 | from inference.main_worker import main_worker 26 | main_worker(args, input_pkl) 27 | 28 | 29 | if __name__ == '__main__': 30 | print("HiCFoundation inference started!") 31 | parser = argparser_infer() 32 | args = parser.parse_args() 33 | if args.gpu is not None: 34 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 35 | #print mode based on --task 36 | if args.task==1: 37 | print("Reproducibility analysis") 38 | elif args.task==2: 39 | print("Loop calling") 40 | elif args.task==3: 41 | print("Resolution enhancement") 42 | elif args.task==4: 43 | print("Epigenomic assay prediction") 44 | elif args.task==5: 45 | print("scHi-C enhancement") 46 | elif args.task==6: 47 | print("Hi-C embedding generation") 48 | embed_depth = args.embed_depth 49 | if embed_depth>8: 50 | print("Error: embed_depth is larger than 8, which is beyond the decoder depth. Please set embed_depth<=8") 51 | print("0 indicates the encoder output, k indicates the k-th decoder layer's output") 52 | exit(1) 53 | else: 54 | print("Unknown task specified ",args.task) 55 | print("Please specify the task using --task with 1,2,3,4,5,6") 56 | exit(1) 57 | #check the specified input size, which must be a multiple of args.patch_size 58 | if args.input_row_size%args.patch_size!=0 or args.input_col_size%args.patch_size!=0: 59 | print("args configuration error: input_row_size and input_col_size must be a multiple of patch_size") 60 | exit(1) 61 | #output the args in a beautiful format 62 | main(args) 63 | 64 | -------------------------------------------------------------------------------- /inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/inference/__init__.py -------------------------------------------------------------------------------- /inference/load_model.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def load_model(model_path,input_row_size,input_col_size, task=6): 4 | """ 5 | Load a model from a file. 6 | 7 | Args: 8 | model_path (str): The path to the model file. 9 | input_row_size (int): The number of rows in the input matrix. 10 | input_col_size (int): The number of columns in the input matrix. 11 | 12 | Returns: 13 | model: The loaded model. 
14 | 15 | Notes: 16 | task 0: fine-tuning setting 17 | task 1: reproducibility analysis 18 | task 2: loop calling 19 | task 3: resolution enhancement 20 | task 4: epigenomic assay prediction 21 | task 5: scHi-C enhancement 22 | task 6: embedding analysis 23 | """ 24 | import torch 25 | import model.Vision_Transformer_count as Vision_Transformer 26 | from model.pos_embed import interpolate_pos_embed_inputsize 27 | import torch.nn as nn 28 | 29 | model_name="vit_large_patch16" 30 | patch_size=16 31 | 32 | patch_wise_size = (input_row_size//patch_size, input_col_size//patch_size) 33 | vit_backbone = Vision_Transformer.__dict__[model_name](img_size=(input_row_size,input_col_size)) 34 | checkpoint = torch.load(model_path, map_location='cpu') 35 | checkpoint_model = checkpoint['model'] 36 | state_dict = vit_backbone.state_dict() 37 | for k in ['head.weight', 'head.bias']: 38 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 39 | print(f"Removing key {k} from pretrained checkpoint") 40 | del checkpoint_model[k] 41 | interpolate_pos_embed_inputsize(vit_backbone, checkpoint_model,input_size=patch_wise_size, 42 | use_decoder=False) 43 | # load pre-trained model 44 | msg = vit_backbone.load_state_dict(checkpoint_model, strict=False) 45 | print("Loading pre-train encoder message!") 46 | 47 | from model.Finetune_Model_Head import Finetune_Model_Head 48 | model = Finetune_Model_Head(vit_backbone, task=task, 49 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 50 | mlp_ratio=4., norm_layer=nn.LayerNorm,pos_embed_size=patch_wise_size) 51 | checkpoint = torch.load(model_path, map_location='cpu') 52 | checkpoint_model = checkpoint['model'] 53 | #loading pre-trained decoder 54 | interpolate_pos_embed_inputsize(model, checkpoint['model'], 55 | input_size=patch_wise_size,use_decoder=True) 56 | msg = model.load_state_dict(checkpoint_model, strict=False) 57 | print("Loading pre-train model decoder!") 58 | return model # return the loaded model 59 | 60 | def to_cuda(x): 61 | """ 62 | Move a tensor to the GPU. 63 | 64 | Args: 65 | x (torch.Tensor): The tensor to move to the GPU. 66 | 67 | Returns: 68 | torch.Tensor: The tensor on the GPU. 69 | """ 70 | import torch 71 | if x is not None: 72 | #if it is float or int, change to tensor 73 | if type(x) is int or type(x) is float: 74 | x = torch.tensor(x) 75 | return x.cuda() 76 | else: 77 | return None 78 | 79 | def to_float(x): 80 | """ 81 | Convert a tensor to float. 82 | 83 | Args: 84 | x (torch.Tensor): The tensor to convert to float. 85 | 86 | Returns: 87 | torch.Tensor: The tensor as float. 88 | """ 89 | import torch 90 | if x is not None: 91 | return x.float() 92 | else: 93 | return None 94 | 95 | def convert_rgb(data_log,max_value): 96 | import torch 97 | if len(data_log.shape)==2: 98 | data_log = data_log[None,:,:] 99 | data_red = torch.ones(data_log.shape) 100 | data_log1 = (max_value-data_log)/max_value 101 | data_rgb = torch.cat([data_red,data_log1,data_log1],dim=0) 102 | return data_rgb 103 | 104 | def format_input(input): 105 | """ 106 | Format the input for the model. 107 | 108 | Args: 109 | input (torch.Tensor): The input tensor. 110 | 111 | Returns: 112 | torch.Tensor: The formatted input tensor. 
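Note: the body below applies the same log10(1+x) scaling and red/green/blue encoding used by the training datasets (red channel fixed at 1), followed by normalization with ImageNet channel statistics.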
113 | """ 114 | import torch 115 | import torchvision.transforms as transforms 116 | transform_input = transforms.Compose([ 117 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 118 | 119 | input = torch.nan_to_num(input) 120 | max_value = torch.max(input) 121 | input = torch.log10(input+1) 122 | max_value = torch.log10(max_value+1) 123 | input = convert_rgb(input,max_value) 124 | 125 | input = transform_input(input) 126 | return input -------------------------------------------------------------------------------- /model/NativeScaler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._six import inf 3 | 4 | class NativeScalerWithGradNormCount: 5 | state_dict_key = "amp_scaler" 6 | 7 | def __init__(self): 8 | self._scaler = torch.cuda.amp.GradScaler() 9 | 10 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 11 | self._scaler.scale(loss).backward(create_graph=create_graph) 12 | if update_grad: 13 | if clip_grad is not None: 14 | assert parameters is not None 15 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 16 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 17 | else: 18 | self._scaler.unscale_(optimizer) 19 | norm = get_grad_norm_(parameters) 20 | self._scaler.step(optimizer) 21 | self._scaler.update() 22 | else: 23 | norm = None 24 | return norm 25 | 26 | def state_dict(self): 27 | return self._scaler.state_dict() 28 | 29 | def load_state_dict(self, state_dict): 30 | self._scaler.load_state_dict(state_dict) 31 | 32 | 33 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 34 | if isinstance(parameters, torch.Tensor): 35 | parameters = [parameters] 36 | parameters = [p for p in parameters if p.grad is not None] 37 | norm_type = float(norm_type) 38 | if len(parameters) == 0: 39 | return torch.tensor(0.) 
40 | device = parameters[0].grad.device 41 | if norm_type == inf: 42 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 43 | else: 44 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 45 | return total_norm 46 | -------------------------------------------------------------------------------- /model/Vision_Transformer_count.py: -------------------------------------------------------------------------------- 1 | #adopted from https://github.com/facebookresearch/mae repo 2 | 3 | 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | import timm.models.vision_transformer 10 | 11 | from model.pos_embed import convert_count_to_pos_embed_cuda 12 | 13 | 14 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 15 | """ Vision Transformer with support for global average pooling 16 | """ 17 | def __init__(self, **kwargs): 18 | super(VisionTransformer, self).__init__(**kwargs) 19 | self.patch_size = kwargs['patch_size'] 20 | self.in_chans = kwargs['in_chans'] 21 | self.embed_dim = kwargs['embed_dim'] 22 | 23 | def forward_features(self, x,total_count): 24 | B = x.shape[0] 25 | x = self.patch_embed(x) 26 | 27 | total_count = torch.log10(total_count) 28 | count_embed = convert_count_to_pos_embed_cuda(total_count, self.embed_dim) 29 | count_embed = count_embed.unsqueeze(1)# (N, 1, D) 30 | 31 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 32 | #cls tokens is simply a learnable parameter and 0 positional embedding is added 33 | x = x + self.pos_embed[:, 1:, :] 34 | 35 | x = torch.cat((cls_tokens,count_embed, x), dim=1) 36 | 37 | x = self.pos_drop(x) 38 | 39 | for blk in self.blocks: 40 | x = blk(x) 41 | 42 | 43 | x = self.norm(x) 44 | 45 | return x 46 | 47 | 48 | def vit_base_patch16(**kwargs): 49 | model = VisionTransformer(in_chans=3, 50 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 51 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 52 | return model 53 | 54 | 55 | def vit_large_patch16(**kwargs): 56 | model = VisionTransformer(in_chans=3, 57 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 58 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 59 | return model 60 | 61 | 62 | def vit_huge_patch14(**kwargs): 63 | model = VisionTransformer(in_chans=3, 64 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 65 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 66 | return model 67 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/model/__init__.py -------------------------------------------------------------------------------- /model/lr_decay.py: -------------------------------------------------------------------------------- 1 | 2 | # -------------------------------------------------------- 3 | # References: 4 | # ELECTRA https://github.com/google-research/electra 5 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 6 | # MAE: https://github.com/facebookresearch/mae 7 | # -------------------------------------------------------- 8 | 9 | import json 10 | 11 | 12 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], 
layer_decay=.75): 13 | """ 14 | Parameter groups for layer-wise lr decay 15 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 16 | """ 17 | param_group_names = {} 18 | param_groups = {} 19 | 20 | num_layers = len(model.blocks) + 1 21 | 22 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 23 | 24 | for n, p in model.named_parameters(): 25 | if not p.requires_grad: 26 | continue 27 | 28 | # no decay: all 1D parameters and model specific ones 29 | if p.ndim == 1 or n in no_weight_decay_list: 30 | g_decay = "no_decay" 31 | this_decay = 0. 32 | else: 33 | g_decay = "decay" 34 | this_decay = weight_decay 35 | 36 | layer_id = get_layer_id_for_vit(n, num_layers) 37 | group_name = "layer_%d_%s" % (layer_id, g_decay) 38 | 39 | if group_name not in param_group_names: 40 | this_scale = layer_scales[layer_id] 41 | 42 | param_group_names[group_name] = { 43 | "lr_scale": this_scale, 44 | "weight_decay": this_decay, 45 | "params": [], 46 | } 47 | param_groups[group_name] = { 48 | "lr_scale": this_scale, 49 | "weight_decay": this_decay, 50 | "params": [], 51 | } 52 | 53 | param_group_names[group_name]["params"].append(n) 54 | param_groups[group_name]["params"].append(p) 55 | 56 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 57 | 58 | return list(param_groups.values()) 59 | 60 | 61 | def get_layer_id_for_vit(name, num_layers): 62 | """ 63 | Assign a parameter with its layer id 64 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 65 | """ 66 | if name in ['cls_token', 'pos_embed']: 67 | return 0 68 | elif name.startswith('patch_embed'): 69 | return 0 70 | elif name.startswith('blocks'): 71 | return int(name.split('.')[1]) + 1 72 | elif name.startswith('decoder_blocks'): 73 | return int(name.split('.')[1]) + 1 74 | else: 75 | return num_layers 76 | 77 | 78 | def param_groups_lrd_decoder(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 79 | """ 80 | Parameter groups for layer-wise lr decay 81 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 82 | """ 83 | param_group_names = {} 84 | param_groups = {} 85 | 86 | num_layers = len(model.decoder_blocks) + 1 87 | 88 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 89 | 90 | for n, p in model.named_parameters(): 91 | if not p.requires_grad: 92 | print("no gradient apply for ",n) 93 | continue 94 | 95 | # no decay: all 1D parameters and model specific ones 96 | if p.ndim == 1 or n in no_weight_decay_list: 97 | g_decay = "no_decay" 98 | this_decay = 0. 
99 | else: 100 | g_decay = "decay" 101 | this_decay = weight_decay 102 | 103 | layer_id = get_layer_id_for_vit(n, num_layers) 104 | group_name = "layer_%d_%s" % (layer_id, g_decay) 105 | 106 | if group_name not in param_group_names: 107 | this_scale = layer_scales[layer_id] 108 | 109 | param_group_names[group_name] = { 110 | "lr_scale": this_scale, 111 | "weight_decay": this_decay, 112 | "params": [], 113 | } 114 | param_groups[group_name] = { 115 | "lr_scale": this_scale, 116 | "weight_decay": this_decay, 117 | "params": [], 118 | } 119 | 120 | param_group_names[group_name]["params"].append(n) 121 | param_groups[group_name]["params"].append(p) 122 | 123 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 124 | 125 | return list(param_groups.values()) 126 | -------------------------------------------------------------------------------- /model/lr_sched.py: -------------------------------------------------------------------------------- 1 | # From MAE: https://github.com/facebookresearch/mae 2 | 3 | import math 4 | 5 | def adjust_learning_rate(optimizer, epoch, args): 6 | """Decay the learning rate with half-cycle cosine after warmup""" 7 | if epoch < args.warmup_epochs: 8 | lr = args.lr * epoch / args.warmup_epochs 9 | else: 10 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 11 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 12 | for param_group in optimizer.param_groups: 13 | if "lr_scale" in param_group: 14 | param_group["lr"] = lr * param_group["lr_scale"] 15 | else: 16 | param_group["lr"] = lr 17 | return lr 18 | 19 | -------------------------------------------------------------------------------- /model/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ops.distribute_utils import is_main_process 3 | from pathlib import Path 4 | import os 5 | 6 | def load_model(resume_path,args,model_without_ddp, optimizer, loss_scaler): 7 | """ 8 | Load the model from the checkpoint 9 | Args: 10 | resume_path: the path to the checkpoint 11 | model_without_ddp: the model 12 | optimizer: the optimizer 13 | loss_scaler: the loss scaler 14 | """ 15 | if os.path.isfile(resume_path): 16 | print("=> loading checkpoint '{}'".format(resume_path)) 17 | if resume_path.startswith('https'): 18 | checkpoint = torch.hub.load_state_dict_from_url( 19 | resume_path, map_location='cpu', check_hash=True) 20 | else: 21 | checkpoint = torch.load(resume_path, map_location='cpu') 22 | msg=model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 23 | print("model resume message:{}".format(msg)) 24 | optimizer.load_state_dict(checkpoint['optimizer']) 25 | loss_scaler.load_state_dict(checkpoint['scaler']) 26 | args.start_epoch = checkpoint['epoch'] + 1 27 | print("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint['epoch'])) 28 | else: 29 | print("=> no checkpoint found at '{}'".format(resume_path)) 30 | 31 | def save_on_master(*args, **kwargs): 32 | if is_main_process(): 33 | torch.save(*args, **kwargs) 34 | 35 | 36 | 37 | def save_checkpoint(output_dir, args,epoch, model_without_ddp, optimizer, loss_scaler): 38 | output_dir = Path(output_dir) 39 | epoch_name = str(epoch) 40 | 41 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 42 | for checkpoint_path in checkpoint_paths: 43 | to_save = { 44 | 'model': model_without_ddp.state_dict(), 45 | 'optimizer': optimizer.state_dict(), 46 | 'epoch': epoch, 47 | 'scaler': 
loss_scaler.state_dict() if loss_scaler is not None else None, 48 | 'args': args, 49 | } 50 | 51 | save_on_master(to_save, checkpoint_path) 52 | 53 | def save_model2path(model_path,args,epoch, model_without_ddp, optimizer, loss_scaler): 54 | to_save={ 55 | 'model': model_without_ddp.state_dict(), 56 | 'optimizer': optimizer.state_dict(), 57 | 'epoch': epoch, 58 | 'scaler': loss_scaler.state_dict() if loss_scaler is not None else None, 59 | 'args': args, 60 | } 61 | save_on_master(to_save, model_path) 62 | 63 | -------------------------------------------------------------------------------- /ops/Logger.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, deque 2 | import torch 3 | from ops.distribute_utils import is_dist_avail_and_initialized 4 | import torch.distributed as dist 5 | import datetime 6 | import os 7 | import time 8 | 9 | class SmoothedValue(object): 10 | """Track a series of values and provide access to smoothed values over a 11 | window or the global series average. 12 | """ 13 | 14 | def __init__(self, window_size=20, fmt=None): 15 | if fmt is None: 16 | fmt = "{median:.4f} ({global_avg:.4f})" 17 | self.deque = deque(maxlen=window_size) 18 | self.total = 0.0 19 | self.count = 0 20 | self.fmt = fmt 21 | 22 | def update(self, value, n=1): 23 | self.deque.append(value) 24 | self.count += n 25 | self.total += value * n 26 | 27 | def synchronize_between_processes(self): 28 | """ 29 | Warning: does not synchronize the deque! 30 | """ 31 | if not is_dist_avail_and_initialized(): 32 | return 33 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 34 | dist.barrier() 35 | dist.all_reduce(t) 36 | t = t.tolist() 37 | self.count = int(t[0]) 38 | self.total = t[1] 39 | 40 | @property 41 | def median(self): 42 | d = torch.tensor(list(self.deque)) 43 | return d.median().item() 44 | 45 | @property 46 | def avg(self): 47 | d = torch.tensor(list(self.deque), dtype=torch.float32) 48 | return d.mean().item() 49 | 50 | @property 51 | def global_avg(self): 52 | return self.total / self.count 53 | 54 | @property 55 | def max(self): 56 | return max(self.deque) 57 | 58 | @property 59 | def value(self): 60 | return self.deque[-1] 61 | 62 | def __str__(self): 63 | return self.fmt.format( 64 | median=self.median, 65 | avg=self.avg, 66 | global_avg=self.global_avg, 67 | max=self.max, 68 | value=self.value) 69 | 70 | 71 | class MetricLogger(object): 72 | def __init__(self, delimiter="\t"): 73 | self.meters = defaultdict(SmoothedValue) 74 | self.delimiter = delimiter 75 | 76 | def update(self, **kwargs): 77 | for k, v in kwargs.items(): 78 | if v is None: 79 | continue 80 | if isinstance(v, torch.Tensor): 81 | v = v.item() 82 | assert isinstance(v, (float, int)) 83 | self.meters[k].update(v) 84 | 85 | def __getattr__(self, attr): 86 | if attr in self.meters: 87 | return self.meters[attr] 88 | if attr in self.__dict__: 89 | return self.__dict__[attr] 90 | raise AttributeError("'{}' object has no attribute '{}'".format( 91 | type(self).__name__, attr)) 92 | 93 | def __str__(self): 94 | loss_str = [] 95 | for name, meter in self.meters.items(): 96 | loss_str.append( 97 | "{}: {}".format(name, str(meter)) 98 | ) 99 | return self.delimiter.join(loss_str) 100 | 101 | def synchronize_between_processes(self): 102 | for meter in self.meters.values(): 103 | meter.synchronize_between_processes() 104 | 105 | def add_meter(self, name, meter): 106 | self.meters[name] = meter 107 | def 
log(self,iteration,print_freq,header=None): 108 | if iteration % print_freq != 0: # only print every print_freq iterations 109 | return 110 | 111 | if not header: 112 | header = '' 113 | #space_fmt = ':' + str(iteration) + 'd' 114 | log_msg = [ 115 | header, 116 | '['+str(iteration)+']', 117 | '{meters}', 118 | ] 119 | if torch.cuda.is_available(): 120 | log_msg.append('max mem: {memory:.0f}') 121 | MB = 1024.0 * 1024.0 122 | log_msg = self.delimiter.join(log_msg) 123 | if torch.cuda.is_available(): 124 | print(log_msg.format( 125 | iteration, 126 | meters=str(self), 127 | memory=torch.cuda.max_memory_allocated() / MB)) 128 | else: 129 | print(log_msg.format( 130 | iteration, 131 | meters=str(self))) 132 | def log_every(self, iterable, print_freq, header=None): 133 | i = 0 134 | if not header: 135 | header = '' 136 | start_time = time.time() 137 | end = time.time() 138 | iter_time = SmoothedValue(fmt='{avg:.4f}') 139 | data_time = SmoothedValue(fmt='{avg:.4f}') 140 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 141 | log_msg = [ 142 | header, 143 | '[{0' + space_fmt + '}/{1}]', 144 | 'eta: {eta}', 145 | '{meters}', 146 | 'time: {time}', 147 | 'data: {data}' 148 | ] 149 | if torch.cuda.is_available(): 150 | log_msg.append('max mem: {memory:.0f}') 151 | log_msg = self.delimiter.join(log_msg) 152 | MB = 1024.0 * 1024.0 153 | for obj in iterable: 154 | data_time.update(time.time() - end) 155 | yield obj 156 | iter_time.update(time.time() - end) 157 | if i % print_freq == 0 or i == len(iterable) - 1: 158 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 159 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 160 | if torch.cuda.is_available(): 161 | print(log_msg.format( 162 | i, len(iterable), eta=eta_string, 163 | meters=str(self), 164 | time=str(iter_time), data=str(data_time), 165 | memory=torch.cuda.max_memory_allocated() / MB),flush=True) 166 | else: 167 | print(log_msg.format( 168 | i, len(iterable), eta=eta_string, 169 | meters=str(self), 170 | time=str(iter_time), data=str(data_time)),flush=True) 171 | i += 1 172 | end = time.time() 173 | total_time = time.time() - start_time 174 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 175 | print('{} Total time: {} ({:.4f} s / it)'.format( 176 | header, total_time_str, total_time / len(iterable)),flush=True) 177 | 178 | 179 | def print_important_info(msg): 180 | print("="*50) 181 | print("Important:"+msg) 182 | print("="*50) 183 | 184 | def print_warning_info(msg): 185 | print("*"*50) 186 | print("Warning:"+msg) 187 | print("*"*50) -------------------------------------------------------------------------------- /ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/ops/__init__.py -------------------------------------------------------------------------------- /ops/calculate_similarity.py: -------------------------------------------------------------------------------- 1 | # This script calculates the similarity between two Hi-C maps from embeddings produced by a pre-trained reproducibility model.
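# Usage sketch (the file names below are illustrative): the two inputs are
# embedding pickles produced by HiCFoundation's reproducibility inference,
# keyed by "chr:loc" submatrix positions:
#   python ops/calculate_similarity.py sample1_embedding.pkl sample2_embedding.pkl
# The score is the cosine similarity cos(e1, e2) = e1.e2 / (|e1||e2|),
# averaged first within each chromosome and then across chromosomes.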
2 | 3 | import os 4 | import sys 5 | import numpy as np 6 | import pickle 7 | from collections import defaultdict 8 | 9 | input_pickle1 = sys.argv[1] 10 | input_pickle2 = sys.argv[2] 11 | 12 | def load_pickle(file_path): 13 | with open(file_path, 'rb') as f: 14 | data = pickle.load(f) 15 | return data 16 | 17 | input1 = load_pickle(input_pickle1) 18 | input2 = load_pickle(input_pickle2) 19 | 20 | def find_key(chrom,loc,key_list): 21 | """ 22 | Find the key in the list of keys that matches the given chromosome and location. 23 | """ 24 | key1 = chrom+":"+loc 25 | if key1 in key_list: 26 | return key1 27 | key1 = "chr"+chrom+":"+loc 28 | if key1 in key_list: 29 | return key1 30 | key1 = chrom+"_"+chrom+":"+loc 31 | if key1 in key_list: 32 | return key1 33 | key1 = "chr"+chrom+"_chr"+chrom+":"+loc 34 | if key1 in key_list: 35 | return key1 36 | return None 37 | 38 | def calculate_similarity(input1, input2): 39 | """ 40 | Compute the reproducibility score between two Hi-C maps from their embeddings: cosine similarity per submatrix, averaged within each chromosome and then across chromosomes. 41 | """ 42 | similarity_dict = defaultdict(list) 43 | for key in input1.keys(): 44 | # key format: 1_1:1960,1960 45 | split_chromosome = key.split(":")[0] 46 | split_loc = key.split(":")[1] 47 | combine_key = split_chromosome + ":" + split_loc 48 | chrom_name = split_chromosome.split("_")[0] 49 | chrom_name = chrom_name.replace("chr","") 50 | if combine_key not in input2.keys(): 51 | combine_key = find_key(chrom_name,split_loc,input2.keys()) 52 | if combine_key is None: 53 | continue 54 | 55 | embedding1 = input1[key] 56 | embedding2 = input2[combine_key] 57 | # Calculate the cosine similarity between the two embeddings 58 | similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2)) 59 | if np.isnan(similarity): 60 | continue 61 | similarity_dict[chrom_name].append(similarity) 62 | #ignore chrY, chrM, Un, Alt cases 63 | similarity_list=[] 64 | for chrom in similarity_dict: 65 | if "Y" in chrom or "M" in chrom or "Un" in chrom or "Alt" in chrom: 66 | continue 67 | mean_val = np.mean(similarity_dict[chrom]) 68 | similarity_list.append(mean_val) 69 | similarity = np.mean(similarity_list) 70 | return similarity 71 | 72 | similarity = calculate_similarity(input1, input2) 73 | print("The reproducibility score between the two Hi-C maps is: ", similarity) 74 | -------------------------------------------------------------------------------- /ops/distribute_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import datetime 4 | import builtins 5 | import numpy as np 6 | import torch.distributed as dist 7 | 8 | def is_dist_avail_and_initialized(): 9 | if not dist.is_available(): 10 | return False 11 | if not dist.is_initialized(): 12 | return False 13 | return True 14 | def get_rank(): 15 | if not is_dist_avail_and_initialized(): 16 | return 0 17 | return dist.get_rank() 18 | def get_world_size(): 19 | if not is_dist_avail_and_initialized(): 20 | return 1 21 | return dist.get_world_size() 22 | def all_reduce_mean(x): 23 | world_size = get_world_size() 24 | if world_size > 1: 25 | x_reduce = torch.tensor(x).cuda() 26 | dist.all_reduce(x_reduce) 27 | x_reduce /= world_size 28 | return x_reduce.item() 29 | else: 30 | return x 31 | def is_main_process(): 32 | return get_rank() == 0 33 | def setup_for_distributed(is_master): 34 | """ 35 | This function disables printing when not in the master process 36 | """ 37 | builtin_print = builtins.print 38 | 39 | def print(*args, **kwargs): 40 | 41 | if is_master: 42 | now =
datetime.datetime.now().time() 43 | builtin_print('[{}] '.format(now), end='') # print with time stamp 44 | builtin_print(*args, **kwargs) 45 | 46 | builtins.print = print 47 | 48 | def init_distributed_mode(gpu,ngpus_per_node,args): 49 | 50 | import resource 51 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 52 | resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1])) 53 | args.gpu = gpu 54 | args.rank = args.rank * ngpus_per_node + gpu 55 | os.environ['LOCAL_RANK'] = str(args.gpu) 56 | os.environ['RANK'] = str(args.rank) 57 | os.environ['WORLD_SIZE'] = str(args.world_size) 58 | print("make sure the distributed mode is ",args.dist_url) 59 | 60 | 61 | 62 | args.distributed = True 63 | 64 | torch.cuda.set_device(args.gpu) 65 | args.dist_backend = 'nccl' 66 | print('| distributed init (rank {}): {}, gpu {}'.format( 67 | args.rank, args.dist_url, args.gpu), flush=True) 68 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 69 | timeout=datetime.timedelta(seconds=36000), 70 | world_size=args.world_size, rank=args.rank) 71 | 72 | setup_for_distributed(args.rank == 0) 73 | 74 | 75 | 76 | # fix the seed for reproducibility 77 | seed = args.seed + get_rank() 78 | torch.manual_seed(seed) 79 | np.random.seed(seed) 80 | -------------------------------------------------------------------------------- /ops/file_format_convert.py: -------------------------------------------------------------------------------- 1 | from utils.hic2array import hic2array 2 | from utils.cool2array import cool2array_intra 3 | from utils.array2hic import array2hic 4 | from utils.array2cool import array2cool 5 | import os 6 | import numpy as np 7 | from ops.sparse_ops import array_to_coo 8 | from scipy.sparse import coo_matrix 9 | from collections import defaultdict 10 | def write_pkl(return_dict,output_pkl_path): 11 | import pickle 12 | with open(output_pkl_path,'wb') as f: 13 | pickle.dump(return_dict,f) 14 | print("finish writing to:",output_pkl_path) 15 | def load_pkl(input_pkl): 16 | import pickle 17 | with open(input_pkl,'rb') as f: 18 | return_dict = pickle.load(f) 19 | return return_dict 20 | 21 | def read_text(input_file,config_resolution): 22 | #records should be readID chr1 pos1 chr2 pos2 23 | #read line by line to get the sparse matrix 24 | final_dict=defaultdict(list) 25 | with open(input_file,'r') as f: 26 | for line in f: 27 | line = line.strip().split() 28 | try: 29 | chr1 = line[1] 30 | chr2 = line[3] 31 | pos1 = int(line[2])//config_resolution 32 | pos2 = int(line[4])//config_resolution 33 | final_dict[(chr1,chr2)].append((pos1,pos2)) 34 | except: 35 | print("*"*40) 36 | print("Skip line in records:",line) 37 | print("The line should be in format of [readID chr1 pos1 chr2 pos2]") 38 | print("*"*40) 39 | return final_dict 40 | 41 | def countlist2coo(input_dict): 42 | final_dict={} 43 | for key in input_dict: 44 | row=[] 45 | col=[] 46 | data=[] 47 | for item in input_dict[key]: 48 | row.append(item[0]) 49 | col.append(item[1]) 50 | data.append(1) 51 | max_size = max(max(row),max(col))+1 52 | cur_array = coo_matrix((data,(row,col)),shape=(max_size,max_size)) 53 | #sum duplicates 54 | cur_array.sum_duplicates() 55 | final_dict[key]=cur_array 56 | return final_dict 57 | def convert_to_pkl(input_file, output_dir,config_resolution): 58 | output_pkl = os.path.join(output_dir, "input.pkl") 59 | #if it is a .hic file 60 | if input_file.endswith('.hic'): 61 | #convert to .pkl format, only keep intra-chromosome regions 62 | hic2array(input_file,output_pkl=output_pkl, 
63 | resolution=config_resolution,normalization="NONE", 64 | tondarray=2) 65 | elif input_file.endswith('.cool'): 66 | return_dict=cool2array_intra(input_file,normalize=False, 67 | tondarray=False,binsize=config_resolution) 68 | write_pkl(return_dict,output_pkl) 69 | elif input_file.endswith('.pkl'): 70 | #load pickle to sanity check 71 | return_dict = load_pkl(input_file) 72 | final_dict = {} 73 | #check if it is dict 74 | if not isinstance(return_dict,dict): 75 | raise ValueError("Input pkl file should be a dictionary") 76 | else: 77 | for key in return_dict: 78 | if isinstance(return_dict[key],np.ndarray): 79 | final_dict[key] = array_to_coo(return_dict[key]) 80 | elif isinstance(return_dict[key],coo_matrix): 81 | final_dict[key] = return_dict[key] 82 | else: 83 | raise ValueError("The value of the dictionary in .pkl should be either numpy array or coo_matrix") 84 | write_pkl(final_dict,output_pkl) 85 | elif input_file.endswith('.txt') or input_file.endswith('.pairs'): 86 | #convert to .pkl format 87 | initial_dict = read_text(input_file,config_resolution) 88 | #filter intra-chromosome regions 89 | final_dict = {} 90 | for key in initial_dict: 91 | if key[0] == key[1]: 92 | final_dict[key[0]] = initial_dict[key] 93 | #then change it to coo_matrix array 94 | return_dict = countlist2coo(final_dict) 95 | write_pkl(return_dict,output_pkl) 96 | elif input_file.endswith('.npy'): 97 | #load numpy array 98 | input_array = np.load(input_file) 99 | #convert to coo_matrix 100 | final_array = array_to_coo(input_array) 101 | #save to pkl 102 | final_dict={"chr_tmp":final_array} 103 | write_pkl(final_dict,output_pkl) 104 | else: 105 | print("Unsupported file format ",input_file) 106 | print("Supported file format: .hic/.cool/.pkl/.txt/.npy") 107 | raise ValueError("Unsupported file format") 108 | 109 | return output_pkl 110 | 111 | def pkl2others(input_pkl, output_file,config_resolution,genome_id): 112 | output_dir = os.path.dirname(output_file) 113 | os.makedirs(output_dir,exist_ok=True) 114 | data=load_pkl(input_pkl) 115 | if output_file.endswith('.txt') or output_file.endswith('.pairs'): 116 | #write to simple txt 117 | # [chr1, pos1, chr2, pos2, count] 118 | with open(output_file,'w') as file: 119 | file.write("#readID\tchr1\tpos1\tchr2\tpos2\tcount\n") 120 | for chrom in data: 121 | if type(data[chrom]) == np.ndarray: 122 | for i in range(data[chrom].shape[0]): 123 | for j in range(data[chrom].shape[1]): 124 | if data[chrom][i,j]>0: 125 | file.write(f".\t{chrom}\t{i*config_resolution}\t{chrom}\t{j*config_resolution}\t{data[chrom][i,j]}\n") 126 | else: 127 | for i in range(data[chrom].nnz): 128 | row = data[chrom].row[i] 129 | col = data[chrom].col[i] 130 | count = data[chrom].data[i] 131 | file.write(f".\t{chrom}\t{row*config_resolution}\t{chrom}\t{col*config_resolution}\t{count}\n") 132 | elif output_file.endswith('.npy'): 133 | if len(data)>1: 134 | print("Warning: multiple chromosomes detected, please check the output in .pkl format:",input_pkl) 135 | print("The format is dict in format of [chr]:[scipy.sparse.coo_matrix]") 136 | return 137 | current_array = data[list(data.keys())[0]] 138 | if isinstance(current_array,coo_matrix): 139 | current_array = current_array.toarray() 140 | np.save(output_file,current_array) 141 | elif output_file.endswith('.hic'): 142 | #https://github.com/aidenlab/juicer/wiki/Pre 143 | cur_py_path = os.path.abspath(__file__) 144 | cur_py_dir = os.path.dirname(cur_py_path) 145 | code_repo_dir=os.path.dirname(cur_py_dir) 146 | juicer_tools= 
os.path.join(code_repo_dir,"utils","juicer_tools.jar") 147 | array2hic(juicer_tools,input_pkl,output_file,config_resolution,genome_id,1) 148 | elif output_file.endswith('.cool'): 149 | 150 | array2cool(input_pkl,output_file,config_resolution,genome_id,1) 151 | elif output_file.endswith('.pkl'): 152 | output_file = input_pkl 153 | else: 154 | print("Unsupported file format ",output_file) 155 | output_file=input_pkl 156 | print("Final output is saved in ",output_file) 157 | return output_file -------------------------------------------------------------------------------- /ops/io_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | import os 4 | import json 5 | def load_pickle(path): 6 | with open(path,'rb') as file: 7 | data=pickle.load(file) 8 | return data 9 | 10 | def write_pickle(data,path): 11 | with open(path,'wb') as file: 12 | pickle.dump(data, file) 13 | 14 | def append_record(bedpe_path_min,min_loc_list,chrom,resolution=10000): 15 | if "_" in chrom: 16 | chrom = chrom.split("_")[0] 17 | 18 | with open(bedpe_path_min,'a') as wfile: 19 | for loc in min_loc_list: 20 | x1,x2 = loc 21 | if x1<x2: 22 | wfile.write("%s\t%d\t%d\t%s\t%d\t%d\n"%(chrom,x1*resolution,(x1+1)*resolution,chrom,x2*resolution,(x2+1)*resolution)) 23 | else: 24 | wfile.write("%s\t%d\t%d\t%s\t%d\t%d\n"%(chrom,x2*resolution,(x2+1)*resolution,chrom,x1*resolution,(x1+1)*resolution)) 25 | 26 | def write_log(log_dir, prefix, stats): 27 | log_path = os.path.join(log_dir, prefix+".log") 28 | with open(log_path,'a') as f: 29 | f.write(json.dumps(stats)+"\n") 30 | -------------------------------------------------------------------------------- /ops/mean_shift_merge.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numba import jit 3 | 4 | @jit(nogil=True,nopython=True) 5 | def carry_shift(location,count_loc,fmaxd,fsiv,xdim,ydim,dens): 6 | #shift each candidate point to the local density peak (mean shift) 7 | point_cd = np.zeros((count_loc,2)) 8 | point_dens = np.zeros(count_loc) 9 | cnt = 1.0 #density normalization constant 10 | for i in range(count_loc): 11 | if i%10000==0: 12 | print(i) 13 | pos = np.zeros(2) 14 | for j in range(2): 15 | pos[j] = float(location[i][j]) 16 | while True: 17 | stp = np.zeros(2) 18 | endp = np.zeros(2) 19 | for j in range(2): 20 | stp[j] = pos[j]-fmaxd 21 | endp[j] = pos[j]+fmaxd+1 22 | if stp[0]<0: 23 | stp[0]=0 24 | if stp[1]<0: 25 | stp[1]=0 26 | if endp[0]>=xdim: 27 | endp[0]=xdim 28 | if endp[1]>=ydim: 29 | endp[1]=ydim 30 | dtotal=0 31 | pos2=np.zeros(2) 32 | for xp in range(int(stp[0]),int(endp[0])): 33 | rx=float(xp-pos[0])**2 34 | for yp in range(int(stp[1]),int(endp[1])): 35 | ry=float(yp-pos[1])**2 36 | 37 | d2=rx+ry 38 | v=np.exp(-1.5*d2*fsiv)*dens[xp,yp] #bottom part of the mean-shift equation, where pos represents y and (xp,yp) represents xi 39 | dtotal+=v 40 | if v>0: 41 | pos2[0]+=v*float(xp) #pos2 accumulates the top part of the equation 42 | pos2[1]+=v*float(yp) 43 | 44 | if dtotal==0: 45 | break 46 | rd=1.00/float(dtotal) 47 | tempcd=np.zeros(2) 48 | for j in range(2): 49 | pos2[j]*=rd #now we get the equation result 50 | tempcd[j]=pos[j]-pos2[j] 51 | pos[j]=pos2[j] #prepare for the next iteration 52 | check_d=tempcd[0]**2+tempcd[1]**2 #iterate until the position is stable 53 | if check_d<0.001: 54 | break 55 | 56 | for j in range(2): 57 | 58 | point_cd[i][j]=pos[j] 59 | point_dens[i]=dtotal/cnt 60 | return point_cd,point_dens 61 | @jit(nogil=True,nopython=True) 62 | def acc_merge_point(Ncd,dens,dmin,rv_range,rdcut,stock,cd,d2cut,member): 63 | if True: 64 | for i in range(Ncd-1): 65 | if i%10000==0: 66 | print(i) 67 | tmp=np.zeros(2) 68 | if (dens[i]-dmin)*rv_range < rdcut: 69 | stock[i]=0 #label the low-density points as unused 70 | if stock[i]==0: 71 | continue 72 | for j in range(i+1,Ncd): 73 | if stock[j]==0: 74 | continue 75 | d2=0 76 | for k in range(2): 77 | tmp[k]=cd[i][k]-cd[j][k] 78 | d2+=tmp[k]**2 79 | if d2<d2cut: 80 | #merge the pair, keeping the denser point 81 | if dens[i]>dens[j]: 82 | stock[j]=0 83 | member[j]=i 84 | else: 85 | stock[i]=0 86 | member[i]=j 87 | break #jump out of the inner loop, since i has been merged 88 | #Update member data, so that son/grandson points point to the original father point 89 | for i in range(Ncd): 90 | now=int(member[i]) 91 | while now!=member[now]: #if it is not a root, iterate to find the father point (the merged point) 92 | now=int(member[now]) 93 | member[i]=now 94 | return stock,member 95 |
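# The member array above acts like union-find parent pointers: each merged
# point stores the index of the point it was merged into, and the final loop
# flattens chains so every point references its surviving cluster root.
# Tiny illustration (indices only): member = [0, 0, 1] means point 2 was
# merged into 1 and 1 into 0; after flattening, member = [0, 0, 0].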
96 | @jit(nogil=True,nopython=True) 97 | def further_merge_point(count_loc, new_density, stock, 98 | new_location, d2cut,member): 99 | #before_count = len(np.argwhere(stock==1)) 100 | for i in range(count_loc-1): 101 | if i%10000==0: 102 | print(i) 103 | member_i= member[i] 104 | for j in range(i+1,count_loc): 105 | member_j = member[j] 106 | if member_i==member_j: 107 | continue 108 | d2=0 109 | for k in range(2): 110 | d2+=(new_location[i][k]-new_location[j][k])**2 111 | if d2<d2cut: 112 | member_i_cluster_dens = new_density[member_i] 113 | member_j_cluster_dens = new_density[member_j] 114 | if member_i_cluster_dens>member_j_cluster_dens: 115 | all_other_index = member==member_j 116 | stock[all_other_index]=0 117 | member[all_other_index]=member_i 118 | else: 119 | all_other_index = member==member_i 120 | stock[all_other_index]=0 121 | member[all_other_index]=member_j 122 | #after_count = len(np.argwhere(stock==1)) 123 | #print("after further merge, before count: %d, after count: %d"%(before_count,after_count)) 124 | return stock,member 125 | 126 | 127 | def mean_shift_merge(predict_array,cutoff=0.1): 128 | #generate mean shift detections 129 | #cutoff=0.1 130 | bandwidth = 2 131 | gstep = 1 132 | fs = (bandwidth / gstep) * 0.5 133 | fs = fs * fs 134 | fsiv = 1 / (float(fs)) 135 | fmaxd = (bandwidth / gstep) * 2.0 136 | if cutoff==1: 137 | location = np.argwhere(predict_array>=cutoff) 138 | else: 139 | location = np.argwhere(predict_array>cutoff) 140 | count_loc = len(location) 141 | tmp_xdim=predict_array.shape[0] 142 | tmp_ydim = predict_array.shape[1] 143 | 144 | new_location, new_density = carry_shift(location, count_loc, fmaxd, fsiv, 145 | tmp_xdim, tmp_ydim, predict_array) 146 | if len(new_location)==0 or len(new_density)==0: 147 | return [] 148 | dmin = np.min(new_density) 149 | dmax = np.max(new_density) 150 | drange = dmax - dmin 151 | print('density range: %f' % drange) 152 | rv_range = 1.0 / drange 153 | rdcut = 0.01 154 | stock = np.ones(count_loc) 155 | d2cut = 2 ** 2 #important merge criterion 156 | member = np.arange(count_loc) 157 | 158 | stock, member = acc_merge_point(count_loc, new_density, dmin, 159 | rv_range, rdcut, stock, 160 | new_location, d2cut,member) 161 | #further merge: if any two subpoints are closer than 2, merge them as well. 162 | stock, member = further_merge_point(count_loc, new_density, stock, 163 | new_location, d2cut,member) 164 | final_loc_list=[] 165 | for i in range(count_loc): 166 | if stock[i] == 1: 167 | final_loc_list.append(new_location[i]) 168 | final_loc_list = np.stack(final_loc_list,axis=0) 169 | return final_loc_list -------------------------------------------------------------------------------- /ops/smooth_matrix.py: -------------------------------------------------------------------------------- 1 | 2 | from ops.sparse_ops import sym_mat 3 | from scipy.sparse import triu 4 | import os 5 | import numpy as np 6 | import scipy.sparse as sp 7 | import math 8 | 9 | def trimDiags(a: sp.coo_matrix, iDiagMax: int, bKeepMain: bool): 10 | """Remove diagonal elements whose diagonal index is >= iDiagMax 11 | or is == 0 12 | 13 | Args: 14 | a: Input scipy coo_matrix 15 | iDiagMax: Diagonal offset cutoff 16 | bKeepMain: If true, keep the elements in the main diagonal; 17 | otherwise remove them 18 | 19 | Returns: 20 | coo_matrix with the specified diagonals removed 21 | """ 22 | gDist = np.abs(a.row - a.col) 23 | idx = np.where((gDist < iDiagMax) & (bKeepMain | (gDist != 0))) 24 | return sp.coo_matrix((a.data[idx], (a.row[idx], a.col[idx])), 25 | shape=a.shape, dtype=a.dtype)
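# Worked example for trimDiags (illustrative values):
#   a = sp.coo_matrix(np.array([[1, 2, 3],
#                               [0, 4, 5],
#                               [0, 0, 6]]))
#   trimDiags(a, iDiagMax=2, bKeepMain=False) keeps entries with
#   0 < |row - col| < 2, so only the offset-1 diagonal (2 and 5) survives;
#   the main diagonal (1, 4, 6) and the offset-2 entry (3) are dropped.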
26 | def meanFilterSparse(a: sp.coo_matrix, h: int): 27 | """Apply a mean filter to an input sparse matrix. This convolves 28 | the input with a kernel of size 2*h + 1 with constant entries and 29 | subsequently reshapes the output to be of the same shape as the input 30 | 31 | Args: 32 | a: `sp.coo_matrix`, Input matrix to be filtered 33 | h: `int` half-size of the filter 34 | 35 | Returns: 36 | `sp.coo_matrix` filtered matrix 37 | """ 38 | assert h > 0, "meanFilterSparse half-size must be greater than 0" 39 | assert sp.issparse(a) and a.getformat() == 'coo',\ 40 | "meanFilterSparse input matrix is not scipy.sparse.coo_matrix" 41 | assert a.shape[0] == a.shape[1],\ 42 | "meanFilterSparse cannot handle non-square matrix" 43 | fSize = 2 * h + 1 44 | # filter is a square matrix of constant 1 of shape (fSize, fSize) 45 | shapeOut = np.array(a.shape) + fSize - 1 46 | mToeplitz = sp.diags(np.ones(fSize), 47 | np.arange(-fSize+1, 1), 48 | shape=(shapeOut[1], a.shape[1]), 49 | format='csr') 50 | ans = sp.coo_matrix((mToeplitz @ a) @ mToeplitz.T) 51 | # remove the edges since we don't care about them if we are smoothing 52 | # the matrix itself 53 | ansNoEdge = ans.tocsr()[h:(h+a.shape[0]), h:(h+a.shape[1])].tocoo() 54 | # Assign different number of neighbors to the edge to better 55 | # match what the original R implementation of HiCRep does 56 | rowDist2Edge = np.minimum(ansNoEdge.row, ansNoEdge.shape[0] - 1 - ansNoEdge.row) 57 | nDim1 = h + 1 + np.minimum(rowDist2Edge, h) 58 | colDist2Edge = np.minimum(ansNoEdge.col, ansNoEdge.shape[1] - 1 - ansNoEdge.col) 59 | nDim2 = h + 1 + np.minimum(colDist2Edge, h) 60 | nNeighbors = nDim1 * nDim2 61 | ansNoEdge.data /= nNeighbors 62 | return ansNoEdge 63 | 64 | 65 | def smooth_matrix(input_dict,dMax=200,hsize=11,max_value=1000): 66 | 67 | new_dict={} 68 | size_limit=224 69 | for key in input_dict: 70 | current_mat = input_dict[key] 71 | current_mat.data = np.minimum(max_value,current_mat.data) 72 | current_mat = sym_mat(current_mat) 73 | if current_mat.shape[0]<=size_limit: 74 | continue 75 | nDiags = current_mat.shape[0] if dMax < 0 else min(dMax, current_mat.shape[0]) 76 | m1 = trimDiags(current_mat, nDiags, False) 77 | 78 | if hsize>0: 79 | # apply smoothing 80 | #m1.data = np.log10(m1.data+1) 81 | m1 = meanFilterSparse(m1, hsize) 82 | #m1.data = np.power(10,m1.data)-1 83 | new_dict[key]=triu(m1,0,format='coo') 84 | return new_dict 85 | from ops.io_utils import load_pickle, write_pickle 86 | def smooth_pkl(input_pkl,output_pkl, 87 | dMax=200,hsize=11,max_value=1000): 88 | input_dict = load_pickle(input_pkl) 89 | new_dict = smooth_matrix(input_dict,dMax,hsize,max_value) 90 | write_pickle(new_dict,output_pkl) 91 | return output_pkl -------------------------------------------------------------------------------- /ops/sparse_ops.py: -------------------------------------------------------------------------------- 1 | from scipy.sparse import triu,coo_matrix 2 | import numpy as np 3 | def sym_mat(input_array): 4 | down_array = triu(input_array,1,format='coo').T 5 | new_row = np.concatenate([input_array.row,down_array.row]) 6 | new_col = np.concatenate([input_array.col,down_array.col]) 7 | new_data = np.concatenate([input_array.data,down_array.data]) 8 | shape=input_array.shape 9 | final_array = coo_matrix((new_data,(new_row,new_col)),shape=shape) 10 | return final_array 11 | 12 | 13 | def array_to_coo(array): 14 | """ 15 | Convert a regular 2D NumPy array to a scipy.sparse.coo_matrix. 16 | 17 | Parameters: 18 | - array (numpy.ndarray): The input 2D array. 19 | 20 | Returns: 21 | - scipy.sparse.coo_matrix: The converted COO matrix. 22 | """ 23 | # Find the non-zero elements in the array 24 | row, col = np.nonzero(array) 25 | 26 | # Get the values of the non-zero elements 27 | data = array[row, col] 28 | 29 | # Create the COO matrix 30 | coo_mat = coo_matrix((data, (row, col)), shape=array.shape) 31 | 32 | return coo_mat
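# Quick usage sketch for array_to_coo:
#   dense = np.array([[0, 3], [2, 0]])
#   mat = array_to_coo(dense)
#   mat.row -> array([0, 1]); mat.col -> array([1, 0]); mat.data -> array([3, 2])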
22 | """ 23 | # Find the non-zero elements in the array 24 | row, col = np.nonzero(array) 25 | 26 | # Get the values of the non-zero elements 27 | data = array[row, col] 28 | 29 | # Create the COO matrix 30 | coo_mat = coo_matrix((data, (row, col)), shape=array.shape) 31 | 32 | return coo_mat 33 | 34 | def filter_sparse_region(input_row,input_col,input_data, 35 | start_index,end_index): 36 | """ 37 | input_row: the row index of the sparse matrix 38 | input_col: the column index of the sparse matrix 39 | input_data: the data of the sparse matrix 40 | start_index: the start index of the region to filter 41 | end_index: the end index of the region to filter 42 | """ 43 | select_index1 = (input_row>=start_index) & (input_row=start_index) & (input_col=start_row) & (input_row=start_col) & (input_colbchw",samples,imagenet_std) 33 | new_samples = torch.clip((new_samples+ imagenet_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)) * 255, 0, 255) 34 | return new_samples 35 | 36 | def torch_to_nparray(data): 37 | #https://github.com/pytorch/pytorch/blob/main/torch/utils/tensorboard/summary.py 38 | #image take n,c,h,w, 39 | """ 40 | 'tensor' can either have values in [0, 1] (float32) or [0, 255] (uint8). 41 | The image() function will scale the image values to [0, 255] by applying 42 | a scale factor of either 1 (uint8) or 255 (float32). Out-of-range values 43 | will be clipped. 44 | 45 | """ 46 | data = data.cpu().numpy() 47 | #data = data.transpose(0,2,3,1) 48 | data=np.array(data,dtype=np.uint8) 49 | return data 50 | 51 | def convert_gray_rgbimage(samples): 52 | """ 53 | input: B,H,W 54 | """ 55 | #add dimension in 1st dim 56 | if len(samples.shape)==3: 57 | samples = samples.unsqueeze(1) 58 | samples = torch.clip(samples, 0, 1) 59 | red_channel = torch.ones(samples.shape,device=samples.device) 60 | gb_channel = 1-samples 61 | new_samples=torch.cat([red_channel,gb_channel,gb_channel],dim=1)*255 62 | return new_samples -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | # My code has references to the following repositories: 2 | # DeiT: https://github.com/facebookresearch/deit 3 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 4 | # MAE: https://github.com/facebookresearch/mae 5 | # AdPE: https://github.com/maple-research-lab/AdPE 6 | # -------------------------------------------------------- 7 | import os 8 | from ops.argparser import argparser_pretrain 9 | import torch 10 | import torch.multiprocessing as mp 11 | import timm 12 | assert timm.__version__ == "0.3.2" # version check 13 | 14 | 15 | def main(args): 16 | import socket 17 | hostname = socket.gethostname() 18 | local_ip = socket.gethostbyname(hostname) 19 | print("local ip: ",local_ip) 20 | 21 | ngpus_per_node = torch.cuda.device_count() 22 | args.world_size = args.world_size*ngpus_per_node 23 | from pretrain.main_worker import main_worker 24 | if ngpus_per_node==1: 25 | main_worker(args.gpu,ngpus_per_node,args)#if you only have one gpu 26 | else: 27 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 28 | 29 | 30 | 31 | 32 | if __name__ == '__main__': 33 | import resource 34 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 35 | resource.setrlimit(resource.RLIMIT_NOFILE, (4096*2, rlimit[1])) 36 | limit_in_b = 900 * 1024 ** 3 37 | resource.setrlimit(resource.RLIMIT_DATA, (limit_in_b, limit_in_b)) 38 | use_cuda = torch.cuda.is_available() 39 | print("starting check cuda 
status",use_cuda) 40 | #assert cuda is available 41 | assert use_cuda == True, "CUDA is not available, pre-training requires CUDA support to run!" 42 | parser = argparser_pretrain() 43 | args = parser.parse_args() 44 | #If you have many GPU on your server, but you only want to use few of them 45 | # run command line to configure the environment: 46 | # export CUDA_VISIBLE_DEVICES="0,1,2,3" 47 | # Here you can specify the GPU you want to use 48 | #check the specied input size, must be a multiple of args.patch_size 49 | if args.input_row_size%args.patch_size!=0 or args.input_col_size%args.patch_size!=0: 50 | print("args configuration error: input_row_size and input_col_size must be a multiple of patch_size") 51 | exit(1) 52 | main(args) 53 | -------------------------------------------------------------------------------- /pretrain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/pretrain/__init__.py -------------------------------------------------------------------------------- /pretrain/main_worker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.backends.cudnn as cudnn 5 | import time 6 | import datetime 7 | import json 8 | import numpy as np 9 | import torchvision.transforms as transforms 10 | import timm.optim.optim_factory as optim_factory 11 | 12 | from ops.distribute_utils import init_distributed_mode,get_world_size,get_rank,is_main_process 13 | from ops.Logger import print_important_info,print_warning_info 14 | from data_processing.pretrain_dataset import Pretrain_Dataset 15 | from data_processing.collate_fn import collate_fn 16 | import model.models_hicfoundation as models_hicfoundation 17 | from model.NativeScaler import NativeScalerWithGradNormCount as NativeScaler 18 | from model.model_utils import load_model,save_checkpoint,save_model2path 19 | from ops.io_utils import write_log 20 | from pretrain.train_epoch import train_epoch 21 | from pretrain.val_epoch import val_epoch 22 | 23 | def parse_text(config_file, data_dir): 24 | train_list=[] 25 | with open(config_file) as f: 26 | for line in f: 27 | line = line.strip() 28 | line = line.replace('\n', '') 29 | if len(line) == 0: 30 | continue 31 | current_path = os.path.join(data_dir, line) 32 | if not os.path.exists(current_path): 33 | print("The sub-directory {} does not exist in the data directory".format(current_path)) 34 | print("Please check the sub-directory name in the {} file".format(config_file)) 35 | continue 36 | train_list.append(current_path) 37 | return train_list 38 | 39 | def configure_data_loader(args): 40 | transform_train = transforms.Compose([ 41 | transforms.ToTensor(), 42 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 43 | data_dir=os.path.abspath(args.data_path) 44 | train_config= os.path.abspath(args.train_config) 45 | train_list = parse_text(train_config, data_dir) 46 | val_config= os.path.abspath(args.valid_config) 47 | val_list = parse_text(val_config, data_dir) 48 | input_row_size = args.input_row_size 49 | input_col_size = args.input_col_size 50 | sparsity_filter = float(args.sparsity_ratio) 51 | patch_size = args.patch_size 52 | 53 | dataset_train = Pretrain_Dataset(train_list,transform=transform_train, 54 | sparsity_filter=sparsity_filter,patch_size=patch_size, 55 | 
window_height=input_row_size,window_width=input_col_size) 56 | dataset_val = Pretrain_Dataset(val_list,transform=transform_train, 57 | sparsity_filter=sparsity_filter, patch_size=patch_size, 58 | window_height=input_row_size,window_width=input_col_size) 59 | 60 | if args.distributed: 61 | num_tasks = get_world_size() 62 | global_rank = get_rank() 63 | sampler_train = torch.utils.data.DistributedSampler( 64 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 65 | ) 66 | print("Sampler_train = %s" % str(sampler_train)) 67 | sampler_val = torch.utils.data.DistributedSampler( 68 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False 69 | ) 70 | else: 71 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 72 | sampler_val = torch.utils.data.RandomSampler(dataset_val) 73 | global_rank = -1 74 | sample_batch_size = args.batch_size 75 | data_loader_train = torch.utils.data.DataLoader( 76 | dataset_train, batch_size=sample_batch_size, sampler=sampler_train, 77 | num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, 78 | collate_fn=collate_fn 79 | ) 80 | data_loader_val = torch.utils.data.DataLoader( 81 | dataset_val, batch_size=sample_batch_size, sampler=sampler_val, 82 | num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False, 83 | collate_fn=collate_fn 84 | ) 85 | return data_loader_train, data_loader_val 86 | 87 | def config_writer(output_dir,tensorboard_log): 88 | tensorboard_dir = os.path.join(output_dir,'tensorboard') 89 | os.makedirs(tensorboard_dir,exist_ok=True) 90 | if tensorboard_log: 91 | from torch.utils.tensorboard import SummaryWriter 92 | log_writer = SummaryWriter(tensorboard_dir) 93 | else: 94 | log_writer = None 95 | return log_writer 96 | def main_worker(gpu, ngpus_per_node,args): 97 | if ngpus_per_node>1: 98 | init_distributed_mode(gpu,ngpus_per_node,args) 99 | else: 100 | args.distributed=False 101 | print_warning_info("The distributed mode is disabled.\n For pre-training, one GPU may take very long to train!") 102 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 103 | print("{}".format(args).replace(', ', ',\n')) 104 | if args.distributed: 105 | num_tasks = get_world_size() 106 | global_rank = get_rank() 107 | else: 108 | global_rank = -1 109 | num_tasks = 1 110 | output_dir = os.path.abspath(args.output) 111 | if global_rank==0: 112 | os.makedirs(output_dir,exist_ok=True) 113 | log_writer =config_writer(output_dir,args.tensorboard) 114 | elif args.distributed: 115 | log_writer = None 116 | else: 117 | os.makedirs(output_dir,exist_ok=True) 118 | log_writer = config_writer(output_dir,args.tensorboard) 119 | 120 | cudnn.benchmark = True 121 | device = torch.device(args.device) 122 | 123 | # Data loading code 124 | data_loader_train, data_loader_val = configure_data_loader(args) 125 | print("Data loader is configured!") 126 | 127 | # Configure the model 128 | patch_wise_size = (args.input_row_size//args.patch_size,args.input_col_size//args.patch_size) 129 | model = models_hicfoundation.__dict__[args.model](img_size=(args.input_row_size,args.input_col_size)) 130 | 131 | model.to(device) 132 | model_without_ddp = model 133 | 134 | if args.distributed: 135 | #not necessary for current setting, since all param with grad 136 | #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 137 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],find_unused_parameters=True) 138 | model_without_ddp = 
model.module 139 | else: 140 | model_without_ddp =model 141 | #print out model information 142 | print("Model is configured!") 143 | 144 | #configure batch size, learning rate, weight decay 145 | eff_batch_size = args.batch_size * args.accum_iter*get_world_size() 146 | args.lr = args.blr * eff_batch_size / 256 147 | print("Learning rate: %.6f"%args.lr) 148 | print("Accumulative grad iteration: %d"%args.accum_iter) 149 | print("Effective batch size: %d"%eff_batch_size) 150 | print("Base learning rate: %.6f"%args.blr) 151 | 152 | #configure optimizer 153 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 154 | #print(param_groups) #too long printing 155 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 156 | loss_scaler = NativeScaler() 157 | 158 | print("Optimizer is configured!") 159 | 160 | #resume model if loading from a checkpoint 161 | resume_path=os.path.abspath(args.resume) 162 | load_model(resume_path,args, model_without_ddp, optimizer, loss_scaler) 163 | 164 | model_dir = os.path.join(output_dir, 'model') 165 | os.makedirs(model_dir,exist_ok=True) 166 | log_dir = os.path.join(output_dir, 'log') 167 | os.makedirs(log_dir,exist_ok=True) 168 | 169 | epochs = int(args.epochs) 170 | start_epoch = int(args.start_epoch) 171 | start_time = time.time() 172 | print("Start pre-training from epoch %d"%start_epoch," to epoch %d"%epochs) 173 | save_freq = args.save_freq 174 | best_loss = 1e9 175 | for epoch in range(start_epoch, epochs): 176 | if args.distributed: 177 | data_loader_train.sampler.set_epoch(epoch) 178 | train_stats = train_epoch( 179 | model, data_loader_train, 180 | optimizer, device, epoch, loss_scaler, 181 | log_writer=log_writer, 182 | args=args 183 | ) 184 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 185 | 'epoch': epoch,} 186 | if is_main_process(): 187 | write_log(log_dir,"train",log_stats) 188 | 189 | #validation run 190 | val_stats = val_epoch( 191 | model, data_loader_val, 192 | device, epoch, 193 | log_writer=log_writer, 194 | args=args 195 | ) 196 | val_loss = val_stats['loss'] 197 | log_stats_val = {**{f'val_{k}': v for k, v in val_stats.items()}, 198 | 'epoch': epoch,} 199 | if is_main_process(): 200 | write_log(log_dir,"val",log_stats_val) 201 | if epoch%save_freq==0 or epoch==epochs-1: 202 | #output_dir, args,epoch, model_without_ddp, optimizer, loss_scaler 203 | save_checkpoint(model_dir, args,epoch, model_without_ddp, optimizer, loss_scaler) 204 | 205 | 206 | if val_loss < best_loss: 207 | best_loss = val_loss 208 | #model_path,args,epoch, model_without_ddp, optimizer, loss_scaler 209 | model_path = os.path.join(model_dir, 'model_best.pth.tar') 210 | save_model2path(model_path,args,epoch, model_without_ddp, optimizer, loss_scaler) 211 | total_time = time.time()-start_time 212 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 213 | print('Training time {}'.format(total_time_str)) 214 | print("Pre-training of HiCFoundation is finished!") 215 | print("The model is saved in {}".format(model_dir)) 216 | print("The log is saved in {}".format(log_dir)) 217 |
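# Checkpoint layout sketch: save_checkpoint/save_model2path (model/model_utils.py)
# store a dict with keys 'model', 'optimizer', 'epoch', 'scaler' and 'args',
# so the best model can be inspected or resumed later, e.g.:
#   ckpt = torch.load("output/model/model_best.pth.tar", map_location="cpu")
#   print(ckpt['epoch'])                       # last finished epoch
#   model_without_ddp.load_state_dict(ckpt['model'])
# (the "output" prefix is a placeholder for the --output directory)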
-------------------------------------------------------------------------------- /pretrain/train_epoch.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import sys 4 | import numpy as np 5 | from typing import Iterable 6 | import torch 7 | import torch.nn.functional as F 8 | import time 9 | 10 | from ops.Logger import MetricLogger,SmoothedValue 11 | import model.lr_sched as lr_sched 12 | 13 | from ops.train_utils import list_to_device, to_value, create_image, torch_to_nparray 14 | 15 | 16 | def train_epoch(model,data_loader,optimizer, 17 | device, epoch, loss_scaler, 18 | log_writer=None,args=None): 19 | model.train() 20 | 21 | metric_logger = MetricLogger(delimiter=" ") 22 | metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}')) 23 | header = 'Epoch: [{}]'.format(epoch) 24 | print_freq = args.print_freq 25 | 26 | accum_iter = args.accum_iter 27 | 28 | optimizer.zero_grad() 29 | 30 | if log_writer is not None: 31 | print('log_dir: {}'.format(log_writer.log_dir)) 32 | print("number of iterations: ",len(data_loader)) 33 | num_iter = len(data_loader) 34 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 35 | if data_iter_step % accum_iter == 0: 36 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 37 | input_matrix, mask_matrix, hic_count, return_diag,matrix_count = list_to_device(data,device=device) 38 | with torch.cuda.amp.autocast(): #to enable mixed precision training 39 | ssim_loss,contrastive_loss, count_pred, pred_image, mask \ 40 | = model(input_matrix, mask_matrix, total_count=hic_count, \ 41 | diag=return_diag,mask_ratio=args.mask_ratio) 42 | 43 | matrix_count = torch.log10(matrix_count+1) 44 | count_pred = count_pred.flatten() 45 | count_loss = torch.nn.functional.mse_loss(count_pred, matrix_count) 46 | loss = args.loss_alpha*(ssim_loss+count_loss) + contrastive_loss 47 | metric_logger.update(loss=to_value(loss)) 48 | metric_logger.update(ssim_loss=to_value(ssim_loss)) 49 | metric_logger.update(count_loss=to_value(count_loss)) 50 | metric_logger.update(contrastive_loss=to_value(contrastive_loss)) 51 | if not math.isfinite(to_value(loss)): 52 | print("Loss is {}, skipping this step".format(to_value(loss))) 53 | #skip the non-finite step and reset gradients instead of exiting (sys.exit(1)) 54 | optimizer.zero_grad() 55 | continue 56 | loss = loss / accum_iter 57 | loss_scaler(loss, optimizer, parameters=model.parameters(), 58 | update_grad=(data_iter_step + 1) % accum_iter == 0) 59 | 60 | if (data_iter_step + 1) % accum_iter == 0: 61 | optimizer.zero_grad() 62 |
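# Gradient accumulation sketch: with batch_size=16, accum_iter=4 and 2 GPUs,
# main_worker.py scales the base lr with an effective batch size of
# 16*4*2 = 128; here the loss is divided by accum_iter so that the gradients
# summed over 4 iterations match one 128-sample step, and update_grad lets
# loss_scaler step the optimizer only on every 4th iteration.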
71 | """ 72 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 73 | log_writer.add_scalars('Loss/loss', {'train_loss': to_value(loss)}, epoch_1000x) 74 | log_writer.add_scalars('Loss/ssim_loss', {'train_loss': to_value(ssim_loss)}, epoch_1000x) 75 | log_writer.add_scalars('Loss/count_loss', {'train_loss': to_value(count_loss)}, epoch_1000x) 76 | log_writer.add_scalars('Loss/contrastive_loss', {'train_loss': to_value(contrastive_loss)}, epoch_1000x) 77 | log_writer.add_scalars('LR/lr', {'lr': lr}, epoch_1000x) 78 | #add visualization 79 | if ((data_iter_step+1)//accum_iter)%50==0 or data_iter_step==0: 80 | new_samples = create_image(input_matrix) 81 | mask_image = new_samples*(1-mask) 82 | pred_image = create_image(pred_image) 83 | 84 | select_num = min(8,len(new_samples)) 85 | new_samples = torch_to_nparray(new_samples.clone().detach()[:select_num]) 86 | mask_image = torch_to_nparray(mask_image.clone().detach()[:select_num]) 87 | pred_image = torch_to_nparray(pred_image.clone().detach()[:select_num]) 88 | log_writer.add_images('Target_%s'%"train", new_samples, epoch_1000x) 89 | log_writer.add_images('Input_%s'%"train", mask_image, epoch_1000x) 90 | log_writer.add_images('Pred_%s'%"train", pred_image, epoch_1000x) 91 | metric_logger.synchronize_between_processes() 92 | print("Averaged stats:", metric_logger) 93 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 94 | 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /pretrain/val_epoch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | import numpy as np 4 | from typing import Iterable 5 | import torch 6 | import torch.nn.functional as F 7 | import time 8 | 9 | from ops.Logger import MetricLogger,SmoothedValue 10 | import model.lr_sched as lr_sched 11 | 12 | from ops.train_utils import list_to_device, to_value, create_image, torch_to_nparray 13 | 14 | def val_epoch(model, data_loader,device, epoch, 15 | log_writer=None, 16 | args=None,flag='val'): 17 | model.eval() 18 | metric_logger = MetricLogger(delimiter=" ") 19 | header = 'Val Epoch: [{}]'.format(epoch) 20 | print_freq = args.print_freq 21 | accum_iter = args.accum_iter 22 | if log_writer is not None: 23 | print('log_dir: {}'.format(log_writer.log_dir)) 24 | print("number of iterations: ",len(data_loader)) 25 | num_iter = len(data_loader) 26 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 27 | input_matrix, mask_matrix, hic_count, return_diag,matrix_count = list_to_device(data,device=device) 28 | with torch.no_grad(): #to enable mixed precision training 29 | ssim_loss,contrastive_loss, count_pred, pred_image, mask \ 30 | = model(input_matrix, mask_matrix, total_count=hic_count, \ 31 | diag=return_diag,mask_ratio=args.mask_ratio) 32 | matrix_count = torch.log10(matrix_count+1) 33 | count_pred = count_pred.flatten() 34 | count_loss = torch.nn.functional.mse_loss(count_pred, matrix_count) 35 | loss = args.loss_alpha*(ssim_loss+count_loss) + contrastive_loss 36 | metric_logger.update(loss=to_value(loss)) 37 | metric_logger.update(ssim_loss=to_value(ssim_loss)) 38 | metric_logger.update(count_loss=to_value(count_loss)) 39 | metric_logger.update(contrastive_loss=to_value(contrastive_loss)) 40 | torch.cuda.synchronize() # Make sure all gradients are finished computing before moving on 41 | 42 | if log_writer is not None and ((data_iter_step + 1) % accum_iter == 0 or 
-------------------------------------------------------------------------------- /pretrain/val_epoch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | import numpy as np 4 | from typing import Iterable 5 | import torch 6 | import torch.nn.functional as F 7 | import time 8 | 9 | from ops.Logger import MetricLogger,SmoothedValue 10 | import model.lr_sched as lr_sched 11 | 12 | from ops.train_utils import list_to_device, to_value, create_image, torch_to_nparray 13 | 14 | def val_epoch(model, data_loader,device, epoch, 15 | log_writer=None, 16 | args=None,flag='val'): 17 | model.eval() 18 | metric_logger = MetricLogger(delimiter=" ") 19 | header = 'Val Epoch: [{}]'.format(epoch) 20 | print_freq = args.print_freq 21 | accum_iter = args.accum_iter 22 | if log_writer is not None: 23 | print('log_dir: {}'.format(log_writer.log_dir)) 24 | print("number of iterations: ",len(data_loader)) 25 | num_iter = len(data_loader) 26 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 27 | input_matrix, mask_matrix, hic_count, return_diag,matrix_count = list_to_device(data,device=device) 28 | with torch.no_grad(): #disable gradient computation for validation 29 | ssim_loss,contrastive_loss, count_pred, pred_image, mask \ 30 | = model(input_matrix, mask_matrix, total_count=hic_count, \ 31 | diag=return_diag,mask_ratio=args.mask_ratio) 32 | matrix_count = torch.log10(matrix_count+1) 33 | count_pred = count_pred.flatten() 34 | count_loss = torch.nn.functional.mse_loss(count_pred, matrix_count) 35 | loss = args.loss_alpha*(ssim_loss+count_loss) + contrastive_loss 36 | metric_logger.update(loss=to_value(loss)) 37 | metric_logger.update(ssim_loss=to_value(ssim_loss)) 38 | metric_logger.update(count_loss=to_value(count_loss)) 39 | metric_logger.update(contrastive_loss=to_value(contrastive_loss)) 40 | torch.cuda.synchronize() # Make sure all computation is finished before moving on 41 | 42 | if log_writer is not None and ((data_iter_step + 1) % accum_iter == 0 or data_iter_step==0): 43 | """ 44 | We use epoch_1000x as the x-axis in tensorboard. 45 | This calibrates different curves when batch size changes. 46 | """ 47 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 48 | log_writer.add_scalars('Loss/loss', {'%s_loss'%flag: to_value(loss)}, epoch_1000x) 49 | log_writer.add_scalars('Loss/ssim_loss', {'%s_loss'%flag: to_value(ssim_loss)}, epoch_1000x) 50 | log_writer.add_scalars('Loss/count_loss', {'%s_loss'%flag: to_value(count_loss)}, epoch_1000x) 51 | log_writer.add_scalars('Loss/contrastive_loss', {'%s_loss'%flag: to_value(contrastive_loss)}, epoch_1000x) 52 | #add visualization 53 | if ((data_iter_step+1)//accum_iter)%50==0 or data_iter_step==0: 54 | new_samples = create_image(input_matrix) 55 | mask_image = new_samples*(1-mask) 56 | pred_image = create_image(pred_image) 57 | 58 | select_num = min(8,len(new_samples)) 59 | new_samples = torch_to_nparray(new_samples.clone().detach()[:select_num]) 60 | mask_image = torch_to_nparray(mask_image.clone().detach()[:select_num]) 61 | pred_image = torch_to_nparray(pred_image.clone().detach()[:select_num]) 62 | log_writer.add_images('Target_%s'%flag, new_samples, epoch_1000x) 63 | log_writer.add_images('Input_%s'%flag, mask_image, epoch_1000x) 64 | log_writer.add_images('Pred_%s'%flag, pred_image, epoch_1000x) 65 | metric_logger.synchronize_between_processes() 66 | print("Averaged stats:", metric_logger) 67 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /refactor_pretrain/Logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code here is borrowed from https://github.com/facebookresearch/deit/blob/main/utils.py 3 | This file includes the two classes SmoothedValue and MetricLogger that are directly copied from the link above 4 | TODO: 5 | print_important_info and print_warning_info are new 6 | move these classes and the new functions to a utils.py file 7 | """ 8 | 9 | 10 | from collections import defaultdict, deque 11 | import torch 12 | from distribute_utils import is_dist_avail_and_initialized 13 | import torch.distributed as dist 14 | import datetime 15 | import os 16 | import time 17 | 18 | class SmoothedValue(object): 19 | """Track a series of values and provide access to smoothed values over a 20 | window or the global series average. 21 | """ 22 | 23 | def __init__(self, window_size=20, fmt=None): 24 | if fmt is None: 25 | fmt = "{median:.4f} ({global_avg:.4f})" 26 | self.deque = deque(maxlen=window_size) 27 | self.total = 0.0 28 | self.count = 0 29 | self.fmt = fmt 30 | 31 | def update(self, value, n=1): 32 | self.deque.append(value) 33 | self.count += n 34 | self.total += value * n 35 | 36 | def synchronize_between_processes(self): 37 | """ 38 | Warning: does not synchronize the deque!
39 | """ 40 | if not is_dist_avail_and_initialized(): 41 | return 42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 43 | dist.barrier() 44 | dist.all_reduce(t) 45 | t = t.tolist() 46 | self.count = int(t[0]) 47 | self.total = t[1] 48 | 49 | @property 50 | def median(self): 51 | d = torch.tensor(list(self.deque)) 52 | return d.median().item() 53 | 54 | @property 55 | def avg(self): 56 | d = torch.tensor(list(self.deque), dtype=torch.float32) 57 | return d.mean().item() 58 | 59 | @property 60 | def global_avg(self): 61 | return self.total / self.count 62 | 63 | @property 64 | def max(self): 65 | return max(self.deque) 66 | 67 | @property 68 | def value(self): 69 | return self.deque[-1] 70 | 71 | def __str__(self): 72 | return self.fmt.format( 73 | median=self.median, 74 | avg=self.avg, 75 | global_avg=self.global_avg, 76 | max=self.max, 77 | value=self.value) 78 | 79 | 80 | class MetricLogger(object): 81 | def __init__(self, delimiter="\t"): 82 | self.meters = defaultdict(SmoothedValue) 83 | self.delimiter = delimiter 84 | 85 | def update(self, **kwargs): 86 | for k, v in kwargs.items(): 87 | if v is None: 88 | continue 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError("'{}' object has no attribute '{}'".format( 100 | type(self).__name__, attr)) 101 | 102 | def __str__(self): 103 | loss_str = [] 104 | for name, meter in self.meters.items(): 105 | loss_str.append( 106 | "{}: {}".format(name, str(meter)) 107 | ) 108 | return self.delimiter.join(loss_str) 109 | 110 | def synchronize_between_processes(self): 111 | for meter in self.meters.values(): 112 | meter.synchronize_between_processes() 113 | 114 | def add_meter(self, name, meter): 115 | self.meters[name] = meter 116 | def log(self,iteration,print_freq,header=None): 117 | if iteration % print_freq == 0: 118 | return 119 | 120 | if not header: 121 | header = '' 122 | #space_fmt = ':' + str(iteration) + 'd' 123 | log_msg = [ 124 | header, 125 | '['+str(iteration)+']', 126 | '{meters}', 127 | ] 128 | if torch.cuda.is_available(): 129 | log_msg.append('max mem: {memory:.0f}') 130 | MB = 1024.0 * 1024.0 131 | log_msg = self.delimiter.join(log_msg) 132 | if torch.cuda.is_available(): 133 | print(log_msg.format( 134 | iteration, 135 | meters=str(self), 136 | memory=torch.cuda.max_memory_allocated() / MB)) 137 | else: 138 | print(log_msg.format( 139 | iteration, 140 | meters=str(self))) 141 | def log_every(self, iterable, print_freq, header=None): 142 | i = 0 143 | if not header: 144 | header = '' 145 | start_time = time.time() 146 | end = time.time() 147 | iter_time = SmoothedValue(fmt='{avg:.4f}') 148 | data_time = SmoothedValue(fmt='{avg:.4f}') 149 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 150 | log_msg = [ 151 | header, 152 | '[{0' + space_fmt + '}/{1}]', 153 | 'eta: {eta}', 154 | '{meters}', 155 | 'time: {time}', 156 | 'data: {data}' 157 | ] 158 | if torch.cuda.is_available(): 159 | log_msg.append('max mem: {memory:.0f}') 160 | log_msg = self.delimiter.join(log_msg) 161 | MB = 1024.0 * 1024.0 162 | for obj in iterable: 163 | data_time.update(time.time() - end) 164 | yield obj 165 | iter_time.update(time.time() - end) 166 | if i % print_freq == 0 or i == len(iterable) - 1: 167 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 168 | 
168 |                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
169 |                 if torch.cuda.is_available():
170 |                     print(log_msg.format(
171 |                         i, len(iterable), eta=eta_string,
172 |                         meters=str(self),
173 |                         time=str(iter_time), data=str(data_time),
174 |                         memory=torch.cuda.max_memory_allocated() / MB), flush=True)
175 |                 else:
176 |                     print(log_msg.format(
177 |                         i, len(iterable), eta=eta_string,
178 |                         meters=str(self),
179 |                         time=str(iter_time), data=str(data_time)), flush=True)
180 |             i += 1
181 |             end = time.time()
182 |         total_time = time.time() - start_time
183 |         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
184 |         print('{} Total time: {} ({:.4f} s / it)'.format(
185 |             header, total_time_str, total_time / len(iterable)), flush=True)
186 | 
187 | 
188 | def print_important_info(msg):
189 |     print("="*50)
190 |     print("Important:" + msg)
191 |     print("="*50)
192 | 
193 | def print_warning_info(msg):
194 |     print("*"*50)
195 |     print("Warning:" + msg)
196 |     print("*"*50)
--------------------------------------------------------------------------------
/refactor_pretrain/distribute_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Code here is borrowed from https://github.com/facebookresearch/deit/blob/main/utils.py
3 | This file includes utility scripts for handling distributed training
4 | """
5 | 
6 | import os
7 | import torch
8 | import resource
9 | import datetime
10 | import builtins
11 | import numpy as np
12 | import torch.distributed as dist
13 | 
14 | def is_dist_avail_and_initialized():
15 |     """
16 |     This function checks whether the distributed process group is initialized properly.
17 |     It returns True if all the requisites (rank, world_size parameters, seeds, etc.) are set up,
18 |     and False otherwise.
19 |     Args:
20 |         None
21 |     Returns:
22 |         bool: True if distributed backend is usable.
23 | 
24 |     """
25 |     if not dist.is_available():
26 |         return False
27 |     if not dist.is_initialized():
28 |         return False
29 |     return True
30 | 
31 | def get_rank():
32 |     """
33 |     This function fetches the rank of a distributed process in a distributed process group.
34 |     Ranks range over [0, world_size - 1]; rank 0 is assigned to the main (master) worker.
35 |     In a non-distributed setting the rank of the main process is also 0.
36 |     Args:
37 |         None
38 |     Returns:
39 |         int: Rank of the process
40 | 
41 |     """
42 |     if not is_dist_avail_and_initialized():
43 |         return 0
44 |     return dist.get_rank()
45 | 
46 | def get_world_size():
47 |     """
48 |     This function fetches the world_size, which is the total number of workers in a distributed group.
49 |     In a non-distributed setting the world size is 1 and the main process's rank is 0.
50 |     Args:
51 |         None
52 |     Returns:
53 |         int: world_size (number of processes in a distributed group)
54 | 
55 |     """
56 |     if not is_dist_avail_and_initialized():
57 |         return 1
58 |     return dist.get_world_size()
59 | 
60 | def all_reduce_mean(x):
61 |     """
62 |     This function fetches the value of a particular variable 'x' from all the distributed processes and
63 |     returns the mean of that value. Specifically, it can be used to fetch the loss for a batch of input
64 |     values across the distributed processes and return the mean of that loss for the subsequent backprop.
65 |     It can be used to gather evaluation metrics from other processes too.
66 |     Args:
67 |         x (float): A scalar value to be averaged.
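        Example (an illustrative sketch; `loss` stands for any per-rank scalar tensor):
            local_loss = loss.item()                 # differs across ranks
            avg_loss = all_reduce_mean(local_loss)   # identical on every rank afterwards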
68 |     Returns:
69 |         float: Average of x across all the distributed processes
70 |     """
71 |     world_size = get_world_size()
72 |     if world_size > 1:
73 |         x_reduce = torch.tensor(x).cuda()
74 |         dist.all_reduce(x_reduce)
75 |         x_reduce /= world_size
76 |         return x_reduce.item()
77 |     else:
78 |         return x
79 | 
80 | def is_main_process():
81 |     """
82 |     A binary check for testing whether the process we are operating on is the main process.
83 |     The check is simple: the rank of the current process must be 0 for it to be considered main.
84 |     Args:
85 |         None
86 |     Returns:
87 |         bool: True if main process
88 |     """
89 | 
90 |     return get_rank() == 0
91 | 
92 | 
93 | def setup_for_distributed(is_master):
94 |     """
95 |     This function overrides the built-in print function so that printing is only
96 |     done from the main process. This convention is followed to keep the logs clean.
97 |     Args:
98 |         is_master (bool): whether the current process is the master
99 |     Returns:
100 |         None
101 |     """
102 |     builtin_print = builtins.print
103 |     def print(*args, **kwargs):
104 |         if is_master:
105 |             now = datetime.datetime.now().time()
106 |             builtin_print('[{}] '.format(now), end='')  # print with time stamp
107 |             builtin_print(*args, **kwargs)
108 |     builtins.print = print
109 | 
110 | def init_distributed_mode(gpu, ngpus_per_node, args):
111 |     """
112 |     Initialize the distributed training environment by setting a few OS parameters, then set up
113 |     distributed training across the provided number of GPUs and nodes.
114 |     Args:
115 |         gpu (int): GPU index for current process.
116 |         ngpus_per_node (int): Number of GPUs per node.
117 |         args (Namespace): Arguments object containing attributes:
118 |             - rank (int): Base rank of the node.
119 |             - world_size (int): Total number of processes.
120 |             - dist_url (str): Initialization URL for distributed training.
121 |             - seed (int): Seed for random number generators.
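        Worked example (illustrative): in a 2-node x 4-GPU job, the process for
        gpu=2 on the node launched with args.rank=1 is assigned global rank
        1 * 4 + 2 = 6, out of world_size = 8.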
122 |     Returns:
123 |         None
124 |     """
125 | 
126 |     rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
127 |     resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
128 |     args.gpu = gpu
129 |     args.rank = args.rank * ngpus_per_node + gpu
130 |     os.environ['LOCAL_RANK'] = str(args.gpu)
131 |     os.environ['RANK'] = str(args.rank)
132 |     os.environ['WORLD_SIZE'] = str(args.world_size)
133 |     print("distributed init url:", args.dist_url)
134 | 
135 | 
136 | 
137 |     args.distributed = True
138 | 
139 |     torch.cuda.set_device(args.gpu)
140 |     args.dist_backend = 'nccl'
141 |     print('| distributed init (rank {}): {}, gpu {}'.format(
142 |         args.rank, args.dist_url, args.gpu), flush=True)
143 |     torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
144 |                                          timeout=datetime.timedelta(seconds=36000),
145 |                                          world_size=args.world_size, rank=args.rank)
146 | 
147 |     setup_for_distributed(args.rank == 0)
148 | 
149 |     # fix the seed for reproducibility
150 |     seed = args.seed + get_rank()
151 |     torch.manual_seed(seed)
152 |     np.random.seed(seed)
153 | 
--------------------------------------------------------------------------------
/refactor_pretrain/model_funcs.py:
--------------------------------------------------------------------------------
1 | # From MAE: https://github.com/facebookresearch/mae
2 | 
3 | import math
4 | import torch
5 | import numpy as np
6 | 
7 | 
8 | def adjust_learning_rate(optimizer, epoch, args):
9 |     """
10 |     Decay the learning rate with half-cycle cosine after warmup
11 |     :param optimizer: torch.optimizer
12 |     :param epoch: int, training epoch
13 |     :param args: argparse arguments
14 |     :return: learning_rate
15 |     """
16 |     if epoch < args.warmup_epochs:
17 |         lr = args.lr * epoch / args.warmup_epochs
18 |     else:
19 |         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
20 |             (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
21 |     for param_group in optimizer.param_groups:
22 |         if "lr_scale" in param_group:
23 |             param_group["lr"] = lr * param_group["lr_scale"]
24 |         else:
25 |             param_group["lr"] = lr
26 |     return lr
27 | 
28 | 
29 | def convert_count_to_pos_embed_cuda(count, embed_dim):
30 |     """
31 |     count should be log-formatted
32 |     :param count: torch.tensor, (N,1)
33 |     :param embed_dim: int, embedding dimension
34 |     :return:
35 |     """
36 |     assert embed_dim % 2 == 0
37 |     omega = torch.arange(embed_dim // 2, dtype=count.dtype, device=count.device)
38 |     omega = omega / (embed_dim / 2.)  # grouped to match the NumPy version below (omega /= embed_dim / 2.)
39 |     omega = 1.
/ 10000 ** omega 40 | 41 | out = torch.einsum('m,d->md', count, omega) # (M, D/2), outer product 42 | emb_sin = torch.sin(out) # (M, D/2) 43 | emb_cos = torch.cos(out) # (M, D/2) 44 | 45 | emb = torch.cat([emb_sin, emb_cos], dim=1) + 1 # enforce different compared to other embeddings 46 | # (M, D) 47 | return emb 48 | 49 | 50 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 51 | """ 52 | grid_size: int of the grid height and width 53 | return: 54 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 55 | """ 56 | grid_h = np.arange(grid_size, dtype=np.float32) 57 | grid_w = np.arange(grid_size, dtype=np.float32) 58 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 59 | grid = np.stack(grid, axis=0) 60 | 61 | grid = grid.reshape([2, 1, grid_size, grid_size]) 62 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 63 | if cls_token: 64 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 65 | return pos_embed 66 | 67 | 68 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 69 | assert embed_dim % 2 == 0 70 | 71 | # use half of dimensions to encode grid_h 72 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 73 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 74 | 75 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 76 | return emb 77 | 78 | 79 | def get_2d_sincos_pos_embed_rectangle(embed_dim, grid_size, cls_token=False): 80 | """ 81 | grid_size, a tuple of height and width 82 | """ 83 | grid_size_h, grid_size_w = grid_size 84 | grid_h = np.arange(grid_size_h, dtype=np.float32) 85 | grid_w = np.arange(grid_size_w, dtype=np.float32) 86 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 87 | grid = np.stack(grid, axis=0) 88 | grid = grid.reshape([2, 1, grid_size_w, grid_size_h]) 89 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 90 | if cls_token: 91 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 92 | return pos_embed 93 | 94 | 95 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 96 | """ 97 | embed_dim: output dimension for each position 98 | pos: a list of positions to be encoded: size (M,) 99 | out: (M, D) 100 | """ 101 | assert embed_dim % 2 == 0 102 | omega = np.arange(embed_dim // 2, dtype=np.float32) 103 | omega /= embed_dim / 2. 104 | omega = 1. 
/ 10000 ** omega # (D/2,) 105 | 106 | pos = pos.reshape(-1) # (M,) 107 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 108 | 109 | emb_sin = np.sin(out) # (M, D/2) 110 | emb_cos = np.cos(out) # (M, D/2) 111 | 112 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 113 | return emb 114 | -------------------------------------------------------------------------------- /refactor_pretrain/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from distribute_utils import is_main_process 3 | from pathlib import Path 4 | import os 5 | from math import inf 6 | 7 | 8 | def load_model(resume_path, args, model_without_ddp, optimizer, loss_scaler): 9 | """ 10 | Load the model from the checkpoint 11 | Args: 12 | resume_path: the path to the checkpoint 13 | model_without_ddp: the model 14 | optimizer: the optimizer 15 | loss_scaler: the loss scaler 16 | """ 17 | if os.path.isfile(resume_path): 18 | print("=> loading checkpoint '{}'".format(resume_path)) 19 | if resume_path.startswith('https'): 20 | checkpoint = torch.hub.load_state_dict_from_url( 21 | resume_path, map_location='cpu', check_hash=True) 22 | else: 23 | checkpoint = torch.load(resume_path, weights_only=False, map_location='cpu') 24 | msg = model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 25 | print("model resume message:{}".format(msg)) 26 | optimizer.load_state_dict(checkpoint['optimizer']) 27 | loss_scaler.load_state_dict(checkpoint['scaler']) 28 | args.start_epoch = checkpoint['epoch'] + 1 29 | print("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint['epoch'])) 30 | else: 31 | print("=> no checkpoint found at '{}'".format(resume_path)) 32 | 33 | 34 | def save_on_master(*args, **kwargs): 35 | """ 36 | Save model 37 | :param args: positional arguments, torch.save arguments 38 | :param kwargs: keyword arguments, torch.save arguments 39 | :return: None 40 | """ 41 | if is_main_process(): 42 | torch.save(*args, **kwargs) 43 | 44 | 45 | def save_checkpoint(output_dir, args, epoch, model_without_ddp, optimizer, loss_scaler): 46 | """ 47 | Save model, optimizer, epoch, scaler and arguments as a checkpoint 48 | :param output_dir: str, output directory for saving 49 | :param args: positional arguments 50 | :param epoch: int, training epoch 51 | :param model_without_ddp: torch.nn.module, model with no distributed components 52 | :param optimizer: torch.optimizer, optimizer for model training 53 | :param loss_scaler: torch.cuda.amp.GradScaler, dynamically scales the loss to prevent underflow when using float16 54 | :return: None. 
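    Example (illustrative): with output_dir="./out" and epoch=10, this writes
    "./out/checkpoint-10.pth"; load_model() above can restore it through the
    saved 'model', 'optimizer', 'scaler' and 'epoch' keys.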
55 | """ 56 | # Save directory 57 | output_dir = Path(output_dir) 58 | # Output epoch 59 | epoch_name = str(epoch) 60 | 61 | # checkpoint path 62 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 63 | for checkpoint_path in checkpoint_paths: 64 | to_save = { 65 | 'model': model_without_ddp.state_dict(), 66 | 'optimizer': optimizer.state_dict(), 67 | 'epoch': epoch, 68 | 'scaler': loss_scaler.state_dict() if loss_scaler is not None else None, 69 | 'args': args, 70 | } 71 | 72 | save_on_master(to_save, checkpoint_path) 73 | 74 | 75 | def save_model2path(model_path, args, epoch, model_without_ddp, optimizer, loss_scaler): 76 | """ 77 | Save model to a certain path 78 | :param model_path: str, model path directory and file name 79 | :param args: positional arguments 80 | :param epoch: int, number of epochs 81 | :param model_without_ddp: torch.nn.module, model with no distributional parameters 82 | :param optimizer: torch.optimizer, optimizer for model training 83 | :param loss_scaler: torch.cuda.amp.GradScaler, dynamically scales the loss to prevent underflow when using float16 84 | :return: None 85 | """ 86 | to_save = { 87 | 'model': model_without_ddp.state_dict(), 88 | 'optimizer': optimizer.state_dict(), 89 | 'epoch': epoch, 90 | 'scaler': loss_scaler.state_dict() if loss_scaler is not None else None, 91 | 'args': args, 92 | } 93 | save_on_master(to_save, model_path) 94 | 95 | 96 | class NativeScalerWithGradNormCount: 97 | state_dict_key = "amp_scaler" 98 | 99 | def __init__(self): 100 | self._scaler = torch.cuda.amp.GradScaler() 101 | 102 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 103 | self._scaler.scale(loss).backward(create_graph=create_graph) 104 | if update_grad: 105 | if clip_grad is not None: 106 | assert parameters is not None 107 | # unscale the gradients of optimizer's assigned params in-place 108 | self._scaler.unscale_(optimizer) 109 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 110 | else: 111 | self._scaler.unscale_(optimizer) 112 | norm = get_grad_norm_(parameters) 113 | self._scaler.step(optimizer) 114 | self._scaler.update() 115 | else: 116 | norm = None 117 | return norm 118 | 119 | def state_dict(self): 120 | return self._scaler.state_dict() 121 | 122 | def load_state_dict(self, state_dict): 123 | self._scaler.load_state_dict(state_dict) 124 | 125 | 126 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 127 | """ 128 | Compute gradient norm 129 | :param parameters: torch.tensor, model parameters 130 | :param norm_type: float, which gradient norm 2.0 -> L2 131 | :return: total norm 132 | """ 133 | if isinstance(parameters, torch.Tensor): 134 | parameters = [parameters] 135 | parameters = [p for p in parameters if p.grad is not None] 136 | norm_type = float(norm_type) 137 | if len(parameters) == 0: 138 | return torch.tensor(0.) 
139 | device = parameters[0].grad.device 140 | if norm_type == inf: 141 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 142 | else: 143 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), 144 | norm_type) 145 | return total_norm 146 | -------------------------------------------------------------------------------- /refactor_pretrain/pretrain.py: -------------------------------------------------------------------------------- 1 | # My code has references to the following repositories: 2 | # DeiT: https://github.com/facebookresearch/deit 3 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 4 | # MAE: https://github.com/facebookresearch/mae 5 | # AdPE: https://github.com/maple-research-lab/AdPE 6 | # -------------------------------------------------------- 7 | import os 8 | from argparser import argparser_pretrain 9 | import torch 10 | import torch.multiprocessing as mp 11 | import timm 12 | assert timm.__version__ == "0.3.2" # version check 13 | 14 | 15 | def main(args): 16 | # collect multi-gpu info 17 | import socket 18 | hostname = socket.gethostname() 19 | local_ip = socket.gethostbyname(hostname) 20 | print("local ip: ",local_ip) 21 | 22 | # number of process per node 23 | ngpus_per_node = torch.cuda.device_count() 24 | # total number of process to run across all gpus 25 | # user gives number of nodes to use through args 26 | # here we multiply process per nodes 27 | args.world_size = args.world_size*ngpus_per_node 28 | from main_worker import main_worker 29 | if ngpus_per_node==1: 30 | main_worker(args.gpu,ngpus_per_node,args)#if you only have one gpu 31 | else: 32 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 33 | 34 | 35 | 36 | 37 | if __name__ == '__main__': 38 | import resource 39 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 40 | resource.setrlimit(resource.RLIMIT_NOFILE, (4096*2, rlimit[1])) 41 | limit_in_b = 900 * 1024 ** 3 42 | resource.setrlimit(resource.RLIMIT_DATA, (limit_in_b, limit_in_b)) 43 | use_cuda = torch.cuda.is_available() 44 | print("starting check cuda status",use_cuda) 45 | #assert cuda is available 46 | assert use_cuda == True, "CUDA is not available, pre-training requires CUDA support to run!" 
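    # Illustrative check mirroring the validation below: with patch_size=16,
    # a 224 x 224 input passes (224 % 16 == 0), while a 220 x 224 input would
    # exit with a configuration error.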
47 |     parser = argparser_pretrain()
48 |     args = parser.parse_args()
49 |     # If you have many GPUs on your server but only want to use a few of them,
50 |     # configure the environment on the command line before running:
51 |     # export CUDA_VISIBLE_DEVICES="0,1,2,3"
52 |     # Here you can specify the GPUs you want to use.
53 |     # Check the specified input size: it must be a multiple of args.patch_size.
54 |     if args.input_row_size%args.patch_size!=0 or args.input_col_size%args.patch_size!=0:
55 |         print("args configuration error: input_row_size and input_col_size must be a multiple of patch_size")
56 |         exit(1)
57 |     main(args)
--------------------------------------------------------------------------------
/refactor_pretrain/train_epoch.py:
--------------------------------------------------------------------------------
1 | import sys  # Unused import
2 | import time  # Unused import
3 | import math
4 | import torch
5 | 
6 | import numpy as np
7 | import torch.nn.functional as F
8 | 
9 | from typing import Iterable
10 | from Logger import MetricLogger,SmoothedValue
11 | from model_funcs import adjust_learning_rate
12 | from utils import list_to_device, to_value, create_image, torch_to_nparray
13 | 
14 | 
15 | def train_epoch(model,data_loader,optimizer,
16 |                 device, epoch, loss_scaler,
17 |                 log_writer=None,args=None):
18 | 
19 |     """
20 |     Runs one full training epoch
21 |     Args:
22 |         model (torch.nn.Module): HiCFoundation model object
23 |         data_loader (Iterable): Training Dataloader object
24 |         optimizer (torch.optim.Optimizer): Optimizer instance for training.
25 |         device (str): Training Device
26 |         epoch (int): Current epoch
27 |         loss_scaler (callable): NativeScalerWithGradNormCount object
28 |         log_writer (optional): Tensorboard object
29 |         args (Namespace or dict):
30 |             Training configuration with at least:
31 |                 print_freq (int) - logging frequency in steps
32 |                 accum_iter (int) - number of steps to accumulate gradients
33 |                 mask_ratio (float) - masking ratio for model input
34 |                 loss_alpha (float) - weight for (ssim_loss + count_loss)
35 |     Returns:
36 |         dict[str, float]: Global averages of the logged metrics for this epoch
37 |     """
38 | 
39 |     model.train()  # Put the model in training mode
40 | 
41 |     metric_logger = MetricLogger(delimiter="  ")
42 |     metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))  # Track the learning rate in the SmoothedValue logger
43 |     header = 'Epoch: [{}]'.format(epoch)
44 |     print_freq = args.print_freq  # Logging frequency (in steps) from the training config
45 | 
46 |     accum_iter = args.accum_iter
47 | 
48 |     optimizer.zero_grad()  # Reset the old gradients
49 | 
50 |     if log_writer is not None:
51 |         print('log_dir: {}'.format(log_writer.log_dir))  # note: printed every epoch so the log location shows up in stdout; once at startup would suffice
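    # Note on the fractional epoch passed to adjust_learning_rate below
    # (illustrative numbers): at step 250 of a 1000-step epoch 3 it receives
    # 250 / 1000 + 3 = 3.25, so the cosine schedule advances smoothly within
    # an epoch rather than once per epoch.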
52 | 
53 |     print("number of iterations: ",len(data_loader))
54 | 
55 |     num_iter = len(data_loader)  # Unused variable
56 | 
57 |     for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):  # logs only every `print_freq` steps
58 | 
59 |         if data_iter_step % accum_iter == 0:  # update the learning rate once per accumulation cycle
60 |             adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
61 | 
62 |         # Move batch to device and unpack fields
63 |         input_matrix, mask_matrix, hic_count, return_diag, matrix_count = list_to_device(data,device=device)
64 | 
65 |         with torch.cuda.amp.autocast():  # to enable mixed precision training
66 |             ssim_loss, contrastive_loss, count_pred, pred_image, mask = model(  # Forward pass
67 |                 input_matrix, mask_matrix,
68 |                 total_count=hic_count,
69 |                 diag=return_diag, mask_ratio=args.mask_ratio,
70 |             )
71 | 
72 |             # Count prediction
73 |             matrix_count = torch.log10(matrix_count+1)
74 |             count_pred = count_pred.flatten()
75 |             count_loss = torch.nn.functional.mse_loss(count_pred, matrix_count)
76 |             loss = args.loss_alpha*(ssim_loss+count_loss) + contrastive_loss  # review note: loss_alpha weights (ssim_loss + count_loss) only; contrastive_loss carries no weight
77 | 
78 |         # Log loss values
79 |         metric_logger.update(loss=to_value(loss))
80 |         metric_logger.update(ssim_loss=to_value(ssim_loss))
81 |         metric_logger.update(count_loss=to_value(count_loss))
82 |         metric_logger.update(contrastive_loss=to_value(contrastive_loss))
83 | 
84 |         # Skip the update for NaN or infinite loss instead of crashing (loss clipping could be an alternative)
85 |         if not math.isfinite(to_value(loss)):
86 |             print("Loss is {}, skipping this step".format(to_value(loss)))
87 |             #sys.exit(1)
88 |             optimizer.zero_grad()
89 |             continue
90 | 
91 |         # loss = loss / accum_iter  # note: disabled, so accumulated gradients are summed rather than averaged over the accum_iter steps
92 |         loss_scaler(
93 |             loss, optimizer,
94 |             parameters=model.parameters(),
95 |             update_grad=(data_iter_step + 1) % accum_iter == 0
96 |         )  # NativeScalerWithGradNormCount object
97 | 
98 |         # After each accumulation cycle, clear grads
99 |         if (data_iter_step + 1) % accum_iter == 0:
100 |             optimizer.zero_grad()
101 | 
102 |         # Synchronization step (across multiple workers -- DDP)
103 |         torch.cuda.synchronize()  # wait for pending CUDA work so the logged timings and metrics are accurate
104 |         lr = optimizer.param_groups[0]["lr"]
105 |         metric_logger.update(lr=lr)
106 | 
107 | 
108 |         # Tensorboard logging
109 |         if log_writer is not None and ((data_iter_step + 1) % accum_iter == 0 or data_iter_step==0):
110 |             """
111 |             We use epoch_1000x as the x-axis in tensorboard.
112 |             This calibrates different curves when batch size changes.
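            Worked example (illustrative): at step 250 of a 1000-step epoch 3,
            epoch_1000x = int((250 / 1000 + 3) * 1000) = 3250.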
113 | """ 114 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 115 | log_writer.add_scalars('Loss/loss', {'train_loss': to_value(loss)}, epoch_1000x) 116 | log_writer.add_scalars('Loss/ssim_loss', {'train_loss': to_value(ssim_loss)}, epoch_1000x) 117 | log_writer.add_scalars('Loss/count_loss', {'train_loss': to_value(count_loss)}, epoch_1000x) 118 | log_writer.add_scalars('Loss/contrastive_loss', {'train_loss': to_value(contrastive_loss)}, epoch_1000x) 119 | log_writer.add_scalars('LR/lr', {'lr': lr}, epoch_1000x) 120 | #add visualization 121 | if ((data_iter_step+1)//accum_iter)%50==0 or data_iter_step==0: 122 | new_samples = create_image(input_matrix) 123 | mask_image = new_samples*(1-mask) 124 | pred_image = create_image(pred_image) 125 | 126 | select_num = min(8,len(new_samples)) 127 | new_samples = torch_to_nparray(new_samples.clone().detach()[:select_num]) 128 | mask_image = torch_to_nparray(mask_image.clone().detach()[:select_num]) 129 | pred_image = torch_to_nparray(pred_image.clone().detach()[:select_num]) 130 | log_writer.add_images('Target_%s'%"train", new_samples, epoch_1000x) 131 | log_writer.add_images('Input_%s'%"train", mask_image, epoch_1000x) 132 | log_writer.add_images('Pred_%s'%"train", pred_image, epoch_1000x) 133 | 134 | # Sync metrics across processes (DDP-safe) 135 | metric_logger.synchronize_between_processes() 136 | print("Averaged stats:", metric_logger) 137 | 138 | # Return global averages across all logged metrics 139 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 140 | 141 | 142 | 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /refactor_pretrain/utils.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | from ops.sparse_ops import array_to_coo 4 | """ 5 | from scipy.sparse import triu,coo_matrix 6 | import numpy as np 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | import pickle 11 | import os 12 | import json 13 | 14 | 15 | def array_to_coo(array): 16 | """ 17 | Convert a regular 2D NumPy array to a scipy.sparse.coo_matrix. 18 | 19 | Parameters: 20 | - array (numpy.ndarray): The input 2D array. 21 | 22 | Returns: 23 | - scipy.sparse.coo_matrix: The converted COO matrix. 
24 | """ 25 | # Find the non-zero elements in the array 26 | row, col = np.nonzero(array) 27 | 28 | # Get the values of the non-zero elements 29 | data = array[row, col] 30 | 31 | # Create the COO matrix 32 | coo_mat = coo_matrix((data, (row, col)), shape=array.shape) 33 | 34 | return coo_mat 35 | 36 | """ 37 | from ops.io_utils import load_pickle 38 | """ 39 | 40 | def load_pickle(path): 41 | with open(path,'rb') as file: 42 | data=pickle.load(file) 43 | return data 44 | 45 | """ 46 | from data_processing.finetune_dataset import to_tensor, list_to_tensor 47 | """ 48 | 49 | 50 | def to_tensor(x): 51 | """ 52 | Convert the input to tensor 53 | Args: 54 | x: the input data 55 | """ 56 | if isinstance(x, np.ndarray): 57 | x = torch.from_numpy(x) 58 | elif x is None: 59 | x = None 60 | #if already tensor, do nothing 61 | elif isinstance(x, torch.Tensor): 62 | pass 63 | #if float, convert to tensor 64 | elif isinstance(x, float): 65 | x = torch.tensor(x) 66 | elif isinstance(x, int): 67 | x = torch.tensor(x) 68 | return x 69 | 70 | def list_to_tensor(x): 71 | """ 72 | Convert the list to tensor 73 | Args: 74 | x: the input list 75 | """ 76 | y=[] 77 | for i in x: 78 | y.append(to_tensor(i)) 79 | return y 80 | 81 | 82 | 83 | 84 | """ 85 | from ops.train_utils import list_to_device, to_value, create_image, torch_to_nparray 86 | """ 87 | 88 | def list_to_device(data_list, device): 89 | 90 | def to_device(data, device): 91 | if data is not None: 92 | new_data = data.to(device,non_blocking=True) 93 | else: 94 | new_data = None 95 | return new_data 96 | 97 | new_data_list = [] 98 | for data in data_list: 99 | data = to_device(data, device) 100 | if data is not None: 101 | data = data.float() 102 | new_data_list.append(data) 103 | return new_data_list 104 | 105 | def to_value(data): 106 | if isinstance(data, torch.Tensor): 107 | return data.item() 108 | else: 109 | return data 110 | 111 | def create_image(samples): 112 | imagenet_mean = np.array([0.485, 0.456, 0.406]) 113 | imagenet_std = np.array([0.229, 0.224, 0.225]) 114 | imagenet_mean = torch.tensor(imagenet_mean,device=samples.device) 115 | imagenet_std = torch.tensor(imagenet_std,device=samples.device) 116 | new_samples = torch.einsum("bchw,c->bchw",samples,imagenet_std) 117 | new_samples = torch.clip((new_samples+ imagenet_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)) * 255, 0, 255) 118 | return new_samples 119 | 120 | def torch_to_nparray(data): 121 | #https://github.com/pytorch/pytorch/blob/main/torch/utils/tensorboard/summary.py 122 | #image take n,c,h,w, 123 | """ 124 | 'tensor' can either have values in [0, 1] (float32) or [0, 255] (uint8). 125 | The image() function will scale the image values to [0, 255] by applying 126 | a scale factor of either 1 (uint8) or 255 (float32). Out-of-range values 127 | will be clipped. 
128 | 
129 |     """
130 |     data = data.cpu().numpy()
131 |     #data = data.transpose(0,2,3,1)
132 |     data = np.array(data, dtype=np.uint8)
133 |     return data
134 | 
135 | 
136 | 
137 | def collate_fn(batch):
138 |     # Transpose the batch (list of lists) to group elements by position
139 |     batch_transposed = list(zip(*batch))
140 | 
141 |     # Process each position across the batch
142 |     processed_batch = []
143 |     for tensors in batch_transposed:
144 |         if all(t is None for t in tensors):  # If all are None, keep None
145 |             processed_batch.append(None)
146 |         else:  # Otherwise every entry must be a tensor; stack them (mixed None/tensor positions are not supported)
147 |             # make sure there is no None element in the tensors
148 |             any_none = any(t is None for t in tensors)
149 |             assert not any_none, "None element in a list of tensors"
150 |             stacked = [
151 |                 t for t in tensors
152 |             ]
153 |             processed_batch.append(torch.stack(stacked))
154 | 
155 |     return processed_batch
156 | 
157 | 
158 | """
159 | from ops.io_utils import write_log
160 | """
161 | def write_log(log_dir,status_flag,log_stats):
162 |     cur_log_path = os.path.join(log_dir,status_flag+".log")
163 |     with open(cur_log_path, mode="a", encoding="utf-8") as f:
164 |         f.write(json.dumps(log_stats) + "\n")
--------------------------------------------------------------------------------
/refactor_pretrain/val_epoch.py:
--------------------------------------------------------------------------------
1 | import sys  # Unused
2 | import math  # Unused
3 | import time  # Unused
4 | import torch
5 | import numpy as np  # Unused
6 | import torch.nn.functional as F
7 | 
8 | from typing import Iterable  # Unused
9 | from Logger import MetricLogger,SmoothedValue  # SmoothedValue unused
10 | from utils import list_to_device, to_value, create_image, torch_to_nparray
11 | 
12 | def val_epoch(model, data_loader,
13 |               device, epoch,
14 |               log_writer=None,
15 |               args=None, flag='val'):  # flag labels the tensorboard curves below (e.g. 'val' or 'test')
16 |     """
17 |     Runs one full validation epoch.
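    Worked example of the total loss computed below (illustrative numbers):
    with loss_alpha = 1.0, ssim_loss = 0.20, count_loss = 0.10 and
    contrastive_loss = 0.05, loss = 1.0 * (0.20 + 0.10) + 0.05 = 0.35.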
18 | 19 | Args: 20 | model (torch.nn.Module): HiCFoundation model object 21 | data_loader (Iterable): Validation Dataloader object 22 | device (str): Training Device 23 | epoch (int): Current epoch 24 | log_writer (optional): Tensorboard object 25 | args (Namespace or dict): 26 | Training configuration with at least: 27 | print_freq (int) - logging frequency in steps 28 | accum_iter (int) - number of steps to accumulate gradients 29 | mask_ratio (float) - masking ratio for model input 30 | loss_alpha (float) - weight for (ssim_loss + count_loss) 31 | Returns: 32 | dict[str, float]: Global metric averages upto this epoch 33 | """ 34 | 35 | model.eval() # Putting model in the evaluation mode 36 | 37 | metric_logger = MetricLogger(delimiter=" ") 38 | header = 'Val Epoch: [{}]'.format(epoch) 39 | 40 | print_freq = args.print_freq 41 | accum_iter = args.accum_iter 42 | 43 | if log_writer is not None: 44 | print('log_dir: {}'.format(log_writer.log_dir)) 45 | 46 | print("number of iterations: ",len(data_loader)) 47 | 48 | num_iter = len(data_loader) # Unused variable 49 | 50 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 51 | 52 | input_matrix, mask_matrix, hic_count, return_diag,matrix_count = list_to_device(data,device=device) 53 | 54 | with torch.no_grad(): 55 | # Forward pass 56 | ssim_loss,contrastive_loss, count_pred, pred_image, mask = model( 57 | input_matrix, mask_matrix, 58 | total_count=hic_count, diag=return_diag, 59 | mask_ratio=args.mask_ratio) 60 | 61 | # Count loss 62 | matrix_count = torch.log10(matrix_count+1) 63 | count_pred = count_pred.flatten() 64 | count_loss = torch.nn.functional.mse_loss(count_pred, matrix_count) 65 | 66 | # Total loss 67 | loss = args.loss_alpha*(ssim_loss+count_loss) + contrastive_loss 68 | 69 | # Logging the losses 70 | metric_logger.update(loss=to_value(loss)) 71 | metric_logger.update(ssim_loss=to_value(ssim_loss)) 72 | metric_logger.update(count_loss=to_value(count_loss)) 73 | metric_logger.update(contrastive_loss=to_value(contrastive_loss)) 74 | 75 | torch.cuda.synchronize() # Make sure all gradients are finished computing before moving on 76 | 77 | # Tensorboard loggings 78 | if log_writer is not None and ((data_iter_step + 1) % accum_iter == 0 or data_iter_step==0): 79 | """ 80 | We use epoch_1000x as the x-axis in tensorboard. 81 | This calibrates different curves when batch size changes. 
82 | """ 83 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 84 | log_writer.add_scalars('Loss/loss', {'%s_loss'%flag: to_value(loss)}, epoch_1000x) 85 | log_writer.add_scalars('Loss/ssim_loss', {'%s_loss'%flag: to_value(ssim_loss)}, epoch_1000x) 86 | log_writer.add_scalars('Loss/count_loss', {'%s_loss'%flag: to_value(count_loss)}, epoch_1000x) 87 | log_writer.add_scalars('Loss/contrastive_loss', {'%s_loss'%flag: to_value(contrastive_loss)}, epoch_1000x) 88 | #add visualization 89 | if ((data_iter_step+1)//accum_iter)%50==0 or data_iter_step==0: 90 | new_samples = create_image(input_matrix) 91 | mask_image = new_samples*(1-mask) 92 | pred_image = create_image(pred_image) 93 | 94 | select_num = min(8,len(new_samples)) 95 | new_samples = torch_to_nparray(new_samples.clone().detach()[:select_num]) 96 | mask_image = torch_to_nparray(mask_image.clone().detach()[:select_num]) 97 | pred_image = torch_to_nparray(pred_image.clone().detach()[:select_num]) 98 | log_writer.add_images('Target_%s'%flag, new_samples, epoch_1000x) 99 | log_writer.add_images('Input_%s'%flag, mask_image, epoch_1000x) 100 | log_writer.add_images('Pred_%s'%flag, pred_image, epoch_1000x) 101 | 102 | # Sync metrics across processes (DDP-safe) 103 | metric_logger.synchronize_between_processes() 104 | print("Averaged stats:", metric_logger) 105 | 106 | # Returns a global average of validation metrics 107 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict 2 | opencv-python 3 | simplejson 4 | lvis 5 | Pillow==9.5.0 6 | pytorch_msssim 7 | pandas 8 | hic-straw 9 | matplotlib 10 | scikit-image 11 | scipy 12 | einops 13 | tensorboard 14 | cooler 15 | numba 16 | pyBigWig 17 | timm==0.3.2 -------------------------------------------------------------------------------- /test/test_convert_rgb.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit test for convert RGB 3 | 4 | Author: Xumeng Zhang (xumzhang@uw.edu) 5 | 6 | This function takes only one np window from one npz file 7 | does not handle batch 8 | 9 | """ 10 | 11 | import pytest 12 | import os 13 | import sys 14 | import inspect 15 | 16 | # add HiCFoundation dir into path 17 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 18 | parentdir = os.path.dirname(currentdir) 19 | sys.path.insert(0, parentdir) 20 | from data_processing.pretrain_dataset import Pretrain_Dataset 21 | 22 | 23 | import numpy as np 24 | import torchvision.transforms as transforms 25 | 26 | batch_size = 2 27 | transform_mean = [0.485, 0.456, 0.406] 28 | transform_std = [0.229, 0.224, 0.225] 29 | 30 | 31 | # pseudo dataset class for preprocessing 32 | @pytest.fixture 33 | def dataset(): 34 | # Create a temporary empty folder 35 | empty_folder_path = os.path.join(os.getcwd(), "temp_empty_folder") 36 | assert not os.path.exists(empty_folder_path) or not os.listdir(empty_folder_path), "temporary empty folder exists and also not empty, try change the folder path for testing" 37 | os.makedirs(empty_folder_path, exist_ok=True) 38 | 39 | transform_train = transforms.Compose([ 40 | transforms.ToTensor(), 41 | transforms.Normalize(mean=transform_mean, std=transform_std), 42 | ]) 43 | 44 | dataset = Pretrain_Dataset(data_list=[empty_folder_path], transform=transform_train) 45 | 46 | # Remove the 
folder after dataset creation (or after using it) 47 | # to be safe, only remove it if the folder is empty 48 | if os.path.exists(empty_folder_path) and not os.listdir(empty_folder_path): 49 | os.rmdir(empty_folder_path) 50 | 51 | return dataset 52 | 53 | 54 | # example hic matrix 55 | @pytest.fixture 56 | def example_hic(): 57 | # should be np matrix [window_size, window_size] 58 | hic = np.array([[0.5, 1.0], [0.0, 0.75]]) 59 | return hic 60 | 61 | 62 | def test_convert_rgb_shape_and_dtype(dataset, example_hic): 63 | out = dataset.convert_rgb(example_hic, np.max(example_hic)) 64 | 65 | assert out.shape == (*example_hic.shape, 3) # B, 66 | assert out.dtype == np.float32 67 | 68 | 69 | def test_convert_rgb_values(dataset, example_hic): 70 | """ 71 | check if the converted values are as expected 72 | """ 73 | out = dataset.convert_rgb(example_hic, np.max(example_hic)) 74 | red = out[:, :, 0] 75 | green = out[:, :, 1] 76 | blue = out[:, :, 2] 77 | 78 | # the theoritical value for both blue and green channel 79 | expected_gb = (np.max(example_hic) - example_hic) / np.max(example_hic) 80 | 81 | assert np.allclose(red, np.ones_like(red)), "red channel should be all 1" 82 | assert np.allclose(green, expected_gb), "Green channel is not as expected" 83 | assert np.allclose(blue, expected_gb), "blue channel is not as expected" 84 | 85 | 86 | def test_log10_convert(dataset, example_hic): 87 | """ 88 | Test is the function handles data after log10 well 89 | """ 90 | log10_input = np.log10(example_hic + 1) 91 | log10_max_val = np.log10(np.max(example_hic) + 1) 92 | out = dataset.convert_rgb(log10_input, log10_max_val) 93 | red = out[:, :, 0] 94 | green = out[:, :, 1] 95 | blue = out[:, :, 2] 96 | 97 | # the theoritical value for both blue and green channel 98 | expected_gb = (log10_max_val - log10_input) / log10_max_val 99 | 100 | assert np.allclose(red, np.ones_like(red)) 101 | assert np.allclose(green, expected_gb) 102 | assert np.allclose(blue, expected_gb) 103 | 104 | 105 | def test_post_convert_transform(dataset, example_hic): 106 | """ 107 | Test if the post-process transformation reproduce the same matrix 108 | """ 109 | log10_input = np.log10(example_hic + 1) 110 | log10_max_val = np.log10(np.max(example_hic) + 1) 111 | 112 | out = dataset.convert_rgb(log10_input, log10_max_val) 113 | out = dataset.transform(out) # [3, window_size, window_size] 114 | assert out.shape[0] == 3, ( 115 | "After transformation, the first dimension should be 3, then the window_size by window_size" 116 | ) 117 | red_inverse = out[0, :, :] * transform_std[0] + transform_mean[0] 118 | green_inverse = out[1, :, :] * transform_std[1] + transform_mean[1] 119 | blue_inverse = out[2, :, :] * transform_std[2] + transform_mean[2] 120 | 121 | expected_gb = (log10_max_val - log10_input) / log10_max_val 122 | 123 | assert np.allclose(red_inverse, np.ones_like(red_inverse), atol=1e-6) 124 | assert np.allclose(green_inverse, expected_gb, atol=1e-6) 125 | assert np.allclose(blue_inverse, expected_gb, atol=1e-6) 126 | -------------------------------------------------------------------------------- /test/test_forward.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit test for model forward 3 | 4 | Author: Xumeng Zhang (xumzhang@uw.edu) 5 | 6 | Test if the forward functions in HiCFoundation model work 7 | Models_HiCFoundation.forward_encoder() 8 | Models_HiCFoundation.forward_decoder() 9 | Models_HiCFoundation.forward() 10 | 11 | """ 12 | 13 | import pytest 14 | import os 15 | import 
sys 16 | import inspect 17 | import numpy as np 18 | import torch 19 | 20 | # add HiCFoundation dir into path 21 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 22 | parentdir = os.path.dirname(currentdir) 23 | sys.path.insert(0, parentdir) 24 | from model.models_hicfoundation import Models_HiCFoundation 25 | 26 | 27 | batch_size = 3 # batch size 28 | in_chans = 3 # channels (1 or 3) 29 | image_size = 4 # image size 30 | patch_size = 2 31 | 32 | 33 | # initialize model 34 | @pytest.fixture 35 | def model(): 36 | return Models_HiCFoundation( 37 | img_size=(image_size, image_size), patch_size=patch_size 38 | ) 39 | 40 | 41 | # 42 | @pytest.fixture 43 | def example_input(): 44 | imgs = torch.randn(batch_size, in_chans, image_size, image_size) 45 | # sample 1 and 3: no sparsity 46 | imgs_mask = torch.ones(batch_size, 1, image_size, image_size) 47 | # sample 2: all 0 matrix 48 | imgs_mask[1] = 0 49 | total_count = torch.tensor([1000000, 1000000, float("nan")]) 50 | diag = torch.tensor([0, 0, 0]).float() 51 | return imgs, imgs_mask, total_count, diag 52 | 53 | 54 | def test_forward_encoder(model, example_input): 55 | imgs, imgs_mask, total_count, diag = example_input 56 | 57 | latent, mask, ids_restore = model.forward_encoder( 58 | imgs, total_count=total_count, diag=diag, mask_ratio=0.6 59 | ) 60 | assert latent.shape[0] == imgs.shape[0] 61 | assert mask.shape[0] == imgs.shape[0] 62 | assert ids_restore.shape[0] == imgs.shape[0] 63 | assert torch.isnan(latent[2]).all(), ( 64 | "input nan total count should result in nan encoder output" 65 | ) 66 | 67 | # test case if total_count=None 68 | # here total_count=None means not inputing any value for total_counts 69 | # different than having nan in total_count tensor 70 | # nan in total_count tensor will lead to nan loss 71 | latent, mask, ids_restore = model.forward_encoder( 72 | imgs, diag=diag, mask_ratio=0.6 73 | ) 74 | assert latent.shape[0] == imgs.shape[0] 75 | assert mask.shape[0] == imgs.shape[0] 76 | assert ids_restore.shape[0] == imgs.shape[0] 77 | assert not torch.isnan(latent).any(), "encoder output contains NaNs" 78 | 79 | 80 | def test_forward_decoder(model, example_input): 81 | imgs, imgs_mask, total_count, diag = example_input 82 | 83 | latent, mask, ids_restore = model.forward_encoder( 84 | imgs, total_count=total_count, diag=diag, mask_ratio=0.6 85 | ) 86 | count_pred, patch_pred = model.forward_decoder(latent, ids_restore) 87 | assert count_pred.shape[0] == imgs.shape[0] 88 | assert patch_pred.shape[0] == imgs.shape[0] 89 | assert patch_pred.shape[1] == (image_size // patch_size) ** 2 90 | assert patch_pred.shape[2] == patch_size * patch_size * in_chans 91 | assert torch.isnan(count_pred[2]), ( 92 | "input nan total count should result in nan count pred" 93 | ) 94 | 95 | 96 | def test_forward(model, example_input): 97 | imgs, imgs_mask, total_count, diag = example_input 98 | 99 | ssim_loss, contrastive_loss, count_pred, pred_img, mask = model.forward( 100 | imgs, 101 | imgs_mask, 102 | total_count=total_count, 103 | diag=diag, 104 | mask_ratio=0.4, 105 | ) 106 | assert torch.isnan(ssim_loss), "input nan total count should result in nan ssimloss" 107 | assert torch.isnan(contrastive_loss), ( 108 | "input nan total count should result in nan contrastive loss" 109 | ) 110 | assert count_pred.shape[0] == imgs.shape[0] 111 | assert pred_img.shape == imgs.shape 112 | 113 | # remove the last sample whose total_count is nan 114 | imgs = imgs[:-1] 115 | imgs_mask = imgs_mask[:-1] 116 | total_count = 
total_count[:-1] 117 | diag = diag[:-1] 118 | ssim_loss, contrastive_loss, count_pred, pred_img, mask = model.forward( 119 | imgs, 120 | imgs_mask, 121 | total_count=total_count, 122 | diag=diag, 123 | mask_ratio=0.4 124 | ) 125 | assert ssim_loss.item() >= 0 126 | assert contrastive_loss.item() >= 0 127 | assert count_pred.shape[0] == imgs.shape[0] 128 | assert pred_img.shape == imgs.shape 129 | assert mask.shape == pred_img.shape 130 | for i in range(batch_size - 1): 131 | assert torch.equal(mask[i, 0], mask[i, 1]) and torch.equal( 132 | mask[i, 1], mask[i, 2] 133 | ), f"mask differ amont channels for sample {i}" 134 | 135 | -------------------------------------------------------------------------------- /test/test_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit test for loss functions 3 | 4 | Author: Xumeng Zhang (xumzhang@uw.edu) 5 | 6 | Models_HiCFoundation.forward_loss(self, imgs, imgs_mask, pred, mask): 7 | imgs: [N, 3, H, W] 8 | imgs_mask: [N, 1, H, W] indicate those 0 regions and mask them in target 9 | pred: [N, L, D], sequence of embeddings 10 | mask: [N, L], binary mask, 0 is keep, 1 is remove 11 | 12 | """ 13 | 14 | import pytest 15 | import os 16 | import sys 17 | import inspect 18 | import numpy as np 19 | 20 | # add HiCFoundation dir into path 21 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 22 | parentdir = os.path.dirname(currentdir) 23 | sys.path.insert(0, parentdir) 24 | from model.models_hicfoundation import Models_HiCFoundation 25 | from data_processing.pretrain_dataset import Pretrain_Dataset 26 | 27 | import torchvision.transforms as transforms 28 | 29 | import torch 30 | 31 | image_size = 4 32 | patch_size = 2 33 | batch_size = 2 34 | mask_ratio = 0.6 35 | transform_mean = [0.485, 0.456, 0.406] 36 | transform_std = [0.229, 0.224, 0.225] 37 | 38 | 39 | # initialize model 40 | @pytest.fixture 41 | def model(): 42 | return Models_HiCFoundation( 43 | img_size=(image_size, image_size), patch_size=patch_size 44 | ) 45 | 46 | 47 | # pseudo dataset class for preprocessing 48 | @pytest.fixture 49 | def dataset(): 50 | # Create a temporary empty folder 51 | empty_folder_path = os.path.join(os.getcwd(), "temp_empty_folder") 52 | assert not os.path.exists(empty_folder_path) or not os.listdir(empty_folder_path), "temporary empty folder exists and also not empty, try change the folder path for testing" 53 | os.makedirs(empty_folder_path, exist_ok=True) 54 | 55 | transform_train = transforms.Compose([ 56 | transforms.ToTensor(), 57 | transforms.Normalize(mean=transform_mean, std=transform_std), 58 | ]) 59 | 60 | dataset = Pretrain_Dataset(data_list=[empty_folder_path], transform=transform_train) 61 | 62 | # Remove the folder after dataset creation (or after using it) 63 | # to be safe, only remove it if the folder is empty 64 | if os.path.exists(empty_folder_path) and not os.listdir(empty_folder_path): 65 | os.rmdir(empty_folder_path) 66 | 67 | return dataset 68 | 69 | 70 | # Test when input is rgb 71 | @pytest.mark.parametrize("in_chans", [3]) 72 | def test_forward_loss_rgb(model, dataset, in_chans): 73 | """ 74 | test ssim loss and contrastive loss, 75 | here the ssim loss include the R channel loss 76 | as input of model.forward_loss, the imgs and pred should both be after normalization 77 | 78 | Args: 79 | model: initialized example model 80 | dataset: pseudo dataset 81 | """ 82 | torch.manual_seed(42) 83 | 84 | B, C, H, W = batch_size, in_chans, image_size, image_size # 
small test case 85 | L = (H // patch_size) * (W // patch_size) 86 | 87 | # simulate data preprocessing steps before inputing into loss func 88 | # random example 89 | imgs = np.random.randint(low=0, high=100, size=(H, W)) 90 | imgs_converted = dataset.convert_rgb(imgs, max_value=np.max(imgs)) # (H, W, C) 91 | imgs_converted = dataset.transform(imgs_converted) # (C, H, W) 92 | imgs_converted = torch.concat( 93 | [imgs_converted.unsqueeze(0)] * batch_size, dim=0 94 | ) # (B, C, H, W) repeat across batch 95 | 96 | # patched used later as true prediction 97 | imgs_patched = model.patchify(imgs_converted) # (B, L, D) 98 | 99 | # sparsity 100 | imgs_mask = (torch.rand(B, 1, H, W) * 0.2).float() 101 | 102 | # random prediction 103 | pred = torch.rand(B, L, patch_size * patch_size * C) 104 | mask = (torch.rand(B, L) > mask_ratio).float() 105 | 106 | ssim_loss, contrastive_loss = model.forward_loss( 107 | imgs_converted, imgs_mask, pred, mask 108 | ) 109 | 110 | # sanity check 111 | assert isinstance(ssim_loss, torch.Tensor), "SSIM loss should be a tensor" 112 | assert isinstance(contrastive_loss, torch.Tensor), ( 113 | "Contrastive loss should be a tensor" 114 | ) 115 | assert not torch.isnan(ssim_loss), "SSIM loss shouldn't be NaN" 116 | assert not torch.isnan(contrastive_loss), "Contrastive loss shouldn't be NaN" 117 | assert ssim_loss.item() >= 0, "SSIM loss should be non-negative" 118 | assert contrastive_loss.item() >= 0, "Contrastive loss should be non-negative" 119 | assert ssim_loss.item() != 0, "SSIM loss with random prediction should not be 0" 120 | assert contrastive_loss.item() != 0, ( 121 | "Contrastive loss with random prediction should not be 0" 122 | ) 123 | 124 | # loss with itself 125 | # the prediction here should be after convert rgb 126 | ssim_loss, contrastive_loss = model.forward_loss( 127 | imgs_converted, imgs_mask, imgs_patched, mask 128 | ) 129 | assert ssim_loss.item() == 0, ( 130 | f"ssim loss with itself should be 0, but it is {ssim_loss}" 131 | ) 132 | # actually in Xiao's contrastive loss, this number will never be 0 133 | # it should be as close as to 1 134 | # assert contrastive_loss.item() == 0, ( 135 | # f"contrastive loss with itself should be 0, but it is {contrastive_loss}" 136 | # ) 137 | 138 | 139 | # Test when input is count matrix 140 | @pytest.mark.parametrize("in_chans", [1]) 141 | def test_forward_loss_count(model, in_chans): 142 | """ 143 | test ssim loss and contrastive loss for HiC count matrix. i.e. 
in_chans=1 144 | here the ssim loss include the R channel loss 145 | 146 | 147 | Args: 148 | model: initialized example model 149 | dataset: pseudo dataset 150 | """ 151 | torch.manual_seed(42) 152 | 153 | B, C, H, W = batch_size, in_chans, image_size, image_size # small test case 154 | L = (H // patch_size) * (W // patch_size) 155 | 156 | # random example 157 | img_np = np.random.randint(low=0, high=100, size=(H, W)) 158 | img_torch = torch.tensor(img_np).unsqueeze(0).float() # (1,H,W) 159 | imgs_torch = torch.concat( 160 | [img_torch.unsqueeze(0)] * batch_size, dim=0 161 | ) # (B, C, H, W) 162 | 163 | # tune the model into 1 in_chans mode 164 | model.in_chans = in_chans 165 | 166 | imgs_patched = model.patchify(imgs_torch) # (B, L, D) 167 | 168 | imgs_mask = (torch.rand(B, 1, H, W) * 0.2).float() 169 | 170 | # random prediction 171 | pred = torch.rand(B, L, patch_size * patch_size * C) 172 | mask = (torch.rand(B, L) > mask_ratio).float() 173 | 174 | ssim_loss, contrastive_loss = model.forward_loss(imgs_torch, imgs_mask, pred, mask) 175 | 176 | # sanity check 177 | assert isinstance(ssim_loss, torch.Tensor), "SSIM loss should be a tensor" 178 | assert isinstance(contrastive_loss, torch.Tensor), ( 179 | "Contrastive loss should be a tensor" 180 | ) 181 | assert not torch.isnan(ssim_loss), "SSIM loss shouldn't be NaN" 182 | assert not torch.isnan(contrastive_loss), "Contrastive loss shouldn't be NaN" 183 | assert ssim_loss.item() >= 0, "SSIM loss should be non-negative" 184 | assert contrastive_loss.item() >= 0, "Contrastive loss should be non-negative" 185 | assert ssim_loss.item() != 0, "SSIM loss with random prediction should not be 0" 186 | assert contrastive_loss.item() != 0, ( 187 | "Contrastive loss with random prediction should not be 0" 188 | ) 189 | -------------------------------------------------------------------------------- /test/test_masking.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit test for masking 3 | 4 | Author: Xumeng Zhang (xumzhang@uw.edu) 5 | 6 | Test random masking for patches. 7 | Focus on diagonal symmetric masking as well. 
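With the constants below (image_size=128, patch_size=16, mask_ratio=0.6), the
illustrative numbers are L = (128 // 16) ** 2 = 64 patches per sample and
len_keep = int(64 * (1 - 0.6)) = 25, so 39 patches are masked.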
8 | 9 | random_masking(self, x, mask_ratio,diag=None): 10 | x: [N, L, D], sequence (here L is without additional tokens) 11 | mask_ratio: float, masking ratio 12 | diag: [N,1] diagonal position to symmetrical masking, if None, then random masking 13 | """ 14 | 15 | import pytest 16 | import os 17 | import sys 18 | import inspect 19 | 20 | # add HiCFoundation dir into path 21 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 22 | parentdir = os.path.dirname(currentdir) 23 | sys.path.insert(0, parentdir) 24 | from model.models_hicfoundation import Models_HiCFoundation, apply_symmectric_noise 25 | 26 | 27 | import torch 28 | 29 | batch_size = 3 30 | mask_ratio = 0.6 31 | image_size = 128 32 | patch_size = 16 33 | 34 | 35 | # initialize model class 36 | @pytest.fixture 37 | def model(): 38 | return Models_HiCFoundation( 39 | img_size=(image_size, image_size), patch_size=patch_size 40 | ) 41 | 42 | 43 | # feed an example noise as input 44 | @pytest.fixture 45 | def noise(): 46 | noise = torch.zeros([image_size // patch_size, image_size // patch_size]) 47 | for i in range(noise.shape[0]): 48 | for j in range(noise.shape[1]): 49 | noise[i, j] = i * 10 + j 50 | noise = torch.stack([noise] * batch_size) 51 | print(noise.shape) 52 | return noise 53 | 54 | 55 | def test_apply_symmetric_noise_shape_and_symmetry(noise): 56 | noise_clone = noise.clone() 57 | diag = torch.tensor([0, -2, 2]) 58 | 59 | output = apply_symmectric_noise(noise_clone.clone(), diag) 60 | 61 | assert output.shape == noise.shape 62 | 63 | # Case 1: diag = 0, i.e. full matrix should be symmetric 64 | symm_part = output[0] 65 | assert torch.allclose(symm_part, symm_part.T, atol=1e-6) 66 | 67 | # Case 2: diag = -2 68 | affected = output[1, 2:, :-2] 69 | transposed = affected.T 70 | assert torch.allclose(affected, transposed, atol=1e-6) 71 | 72 | # Case 3: diag = 2 73 | affected = output[2, :-2, 2:] 74 | transposed = affected.T 75 | assert torch.allclose(affected, transposed, atol=1e-6) 76 | 77 | 78 | def test_random_masking_shape(model): 79 | """ 80 | test Models_HiCFoundation.random_masking 81 | x is the (batch, length, dim) tensor 82 | 83 | """ 84 | x = torch.randn( 85 | batch_size, model.num_patches, model.embed_dim 86 | ) # [batch, tokens, dim] 87 | print(x.shape) 88 | 89 | x_masked, mask, ids_restore = model.random_masking(x, mask_ratio) 90 | 91 | L = x.shape[1] 92 | len_keep = int(L * (1 - mask_ratio)) 93 | 94 | assert x_masked.shape == (batch_size, len_keep, model.embed_dim) 95 | assert mask.shape == (batch_size, L) 96 | assert ids_restore.shape == (batch_size, L) 97 | assert (mask.sum(dim=1) == L - len_keep).all() # confirm number of masked positions 98 | 99 | 100 | def test_random_masking_deterministic_output_on_seed(model): 101 | torch.manual_seed(42) 102 | x = torch.randn(1, model.num_patches, model.embed_dim) 103 | x_masked1, mask1, _ = model.random_masking(x, mask_ratio) 104 | 105 | torch.manual_seed(42) 106 | x = torch.randn(1, model.num_patches, model.embed_dim) 107 | x_masked2, mask2, _ = model.random_masking(x, mask_ratio) 108 | 109 | # They should match since the seed and input match 110 | assert torch.allclose(x_masked1, x_masked2), ( 111 | "The same random seed leads to different masking" 112 | ) 113 | assert torch.allclose(mask1, mask2), ( 114 | "The same random seed leads to different masking" 115 | ) 116 | 117 | torch.manual_seed(42) 118 | x = torch.randn(1, model.num_patches, model.embed_dim) 119 | x_masked1, mask1, _ = model.random_masking(x, mask_ratio) 120 | x_masked2, 
mask2, _ = model.random_masking(x, mask_ratio) 121 | 122 | # They should match since the seed and input match 123 | assert not torch.allclose(x_masked1, x_masked2), ( 124 | "Different sampling leads to the same masking" 125 | ) 126 | assert not torch.allclose(mask1, mask2), ( 127 | "Different sampling leads to the same masking" 128 | ) 129 | -------------------------------------------------------------------------------- /test/test_patchify.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit test for patchify and unpatchify 3 | With a focus on the order of each patchify channels 4 | 5 | Author: Xumeng Zhang (xumzhang@uw.edu) 6 | 7 | Models_HiCFoundation.patchify(self, imgs, in_chans=None) 8 | imgs: (N, 3, H, W) 9 | x: (N, L, H*W *self.in_chans) 10 | 11 | Models_HiCFoundation.unpatchify(self, x, in_chans=None): 12 | x: (N, L, patch_size**2 *self.in_chans) 13 | 14 | """ 15 | 16 | import pytest 17 | import os 18 | import sys 19 | import inspect 20 | 21 | # add HiCFoundation dir into path 22 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 23 | parentdir = os.path.dirname(currentdir) 24 | sys.path.insert(0, parentdir) 25 | from model.models_hicfoundation import Models_HiCFoundation 26 | 27 | 28 | import torch 29 | 30 | image_size = 4 31 | patch_size = 2 32 | batch_size = 2 33 | 34 | # test for in_chans both 1 and 3 each time 35 | in_chans = 3 36 | 37 | 38 | # feed an easily interpretable example image 39 | def make_example_image(in_chans): 40 | """ 41 | make an example image that is easy for as to predict the patchify result 42 | This should be how it look like for count matrix or R channel of rgb image 43 | tensor([[1., 1., 2., 2.], 44 | [1., 1., 2., 2.], 45 | [3., 3., 4., 4.], 46 | [3., 3., 4., 4.]]) 47 | G channel should be all 0 48 | B channel should be negative of R channel 49 | 50 | Returns: 51 | tensor: [B, C, H, W], image of size [batch, channel, height, weight] 52 | """ 53 | n_patch_row = image_size // patch_size 54 | 55 | patches = [] 56 | for i in range(n_patch_row * n_patch_row): 57 | if in_chans == 1: 58 | # make each patch a [patch_size, patch_size] square with all 1's or all 2's or ... 
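        # Illustrative shape math for the module constants above: image_size=4
        # and patch_size=2 give (4 // 2) ** 2 = 4 patches, each flattened to
        # patch_size**2 * in_chans values (4 for in_chans=1, 12 for in_chans=3).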
59 | patch = torch.full( 60 | (1, patch_size, patch_size), fill_value=(i + 1), dtype=torch.float32 61 | ) 62 | elif in_chans == 3: 63 | # make each patch a [patch_size, patch_size] square 64 | # Channel R: i + 1 (e.g., 1, 2, 3, …) 65 | # Channel G: always 0 66 | # Channel B: -(i + 1) (e.g., -1, -2, -3, …) 67 | chan0 = torch.full( 68 | (patch_size, patch_size), fill_value=(i + 1), dtype=torch.float32 69 | ) 70 | chan1 = torch.zeros((patch_size, patch_size), dtype=torch.float32) 71 | chan2 = torch.full( 72 | (patch_size, patch_size), fill_value=-(i + 1), dtype=torch.float32 73 | ) 74 | patch = torch.stack([chan0, chan1, chan2]) 75 | assert patch.shape == (in_chans, patch_size, patch_size), f"{patch.shape}" 76 | patches.append(patch) 77 | 78 | patches = torch.stack(patches) # [num_patches, inchan, patch_size, patch_size] 79 | patches = patches.reshape( 80 | n_patch_row, n_patch_row, in_chans, patch_size, patch_size 81 | ) 82 | img = patches.permute( 83 | 2, 0, 3, 1, 4 84 | ) # [inchan, n_patch_row, patch_size, n_patch_row, patch_size] 85 | img = img.reshape(in_chans, image_size, image_size) # [C, H, W] 86 | # check the first patch is all 1 87 | assert torch.all(img[0, :patch_size, :patch_size] == 1) 88 | if in_chans == 3: 89 | assert torch.all(img[1, :, :] == 0) 90 | assert torch.all(img[2, :patch_size, :patch_size] == -1) 91 | # check the last patch is all n_patch_row**2 92 | assert torch.all(img[0, -patch_size:, -patch_size:] == n_patch_row**2) 93 | if in_chans == 3: 94 | assert torch.all(img[1, :, :] == 0) 95 | assert torch.all(img[2, -patch_size:, -patch_size:] == -(n_patch_row**2)) 96 | 97 | return torch.cat([img.unsqueeze(0), img.unsqueeze(0)], dim=0) # [B, C, H, W] 98 | 99 | 100 | # initialize model class 101 | @pytest.fixture 102 | def model(): 103 | return Models_HiCFoundation( 104 | img_size=(image_size, image_size), patch_size=patch_size 105 | ) 106 | 107 | 108 | @pytest.mark.parametrize("in_chans", [1, 3]) 109 | def test_patchify_output(model, in_chans): 110 | example_image = make_example_image(in_chans=in_chans) 111 | patches = model.patchify(example_image, in_chans=in_chans) 112 | # test output shape 113 | assert patches.shape == ( 114 | batch_size, 115 | model.num_patches, 116 | model.patch_size**2 * in_chans, 117 | ) 118 | assert patches.shape == ( 119 | batch_size, 120 | (image_size / patch_size) ** 2, 121 | patch_size**2 * in_chans, 122 | ) # the same shape computed another way 123 | 124 | # test whether output is as expected 125 | # the two batches are the same 126 | assert torch.all(patches[0] == patches[1]) 127 | 128 | # check the channel values in every patch 129 | for i in range(model.num_patches): 130 | if in_chans == 1: 131 | assert torch.all( 132 | patches[0, i, :] == torch.tensor([i + 1] * patch_size**2) 133 | ), f"patch No. {i} patchify not expected" 134 | else: 135 | assert torch.all( 136 | patches[0, i, :] == torch.tensor([i + 1, 0, -(i + 1)] * patch_size**2) 137 | ), f"patch No. {i} patchify not expected"
138 | 139 | 140 | # unpatchify check: patchify followed by unpatchify should reproduce the original image 141 | @pytest.mark.parametrize("in_chans", [1, 3]) 142 | def test_patchify_unpatchify_equivalence(model, in_chans): 143 | example_image = make_example_image(in_chans=in_chans) 144 | patches = model.patchify(example_image, in_chans=in_chans) 145 | reconstructed = model.unpatchify(patches, in_chans=in_chans) 146 | 147 | assert reconstructed.shape == example_image.shape 148 | assert torch.allclose(example_image, reconstructed, atol=1e-6), ( 149 | "Reconstructed patchified image doesn't match original" 150 | ) 151 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # This dir includes independent functions that are used in the main scripts. 2 | # They can also be run independently. -------------------------------------------------------------------------------- /utils/array2bigwig.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import shutil 4 | from collections import defaultdict 5 | import pyBigWig 6 | import sys 7 | import pickle 8 | # def parse_genome_size(genome_size_file): 9 | # chrom_size = {} 10 | # with open(genome_size_file) as f: 11 | # for line in f: 12 | # chrom, size = line.strip("\n").split() 13 | # chrom_size[chrom] = int(size) 14 | # return chrom_size 15 | 16 | def array2bigwig(input_file,output_bigwig,resolution=1000): 17 | """ 18 | convert the 1D array to a bigwig file 19 | """ 20 | 21 | #read genome info 22 | #chrom_size = parse_genome_size(genome_path) 23 | 24 | # read the input file 25 | # if input_file.endswith(".npy"): 26 | # data = np.load(input_file) 27 | # elif input_file.endswith(".txt") or input_file.endswith(".Ev1") or input_file.endswith(".Ev2"): 28 | # data = np.loadtxt(input_file) 29 | # else: 30 | # print("The input file should be in npy or txt format") 31 | # sys.exit(1) 32 | with open(input_file, 'rb') as f: 33 | data = pickle.load(f) 34 | #generate chromosome size dict 35 | chrom_size={} 36 | for chrom in data: 37 | cur_list=data[chrom] 38 | chrom_size[chrom]=len(cur_list)*resolution 39 | 40 | with pyBigWig.open(output_bigwig, "w") as bw: 41 | chromosomes = [(key,chrom_size[key]) for key in chrom_size] 42 | print(chromosomes) 43 | # Add chromosome information to the BigWig file 44 | bw.addHeader(chromosomes) 45 | for info in chromosomes: 46 | chrom = info[0] 47 | print("start chrom",chrom,"size",chrom_size[chrom]) 48 | #pos_embedding = np.arange(0, len(data), dtype=int) 49 | start_embedding = np.arange(0, chrom_size[chrom], resolution) 50 | start_embedding = start_embedding.astype(int) 51 | end_embedding= start_embedding+resolution 52 | end_embedding = np.clip(end_embedding, 0, chrom_size[chrom]-1) 53 | # bw.addEntries("chr1", [500, 600, 635], values=[-2.0, 150.0, 25.0], span=20) 54 | chrom_data = data[chrom] 55 | chrom_data= np.nan_to_num(chrom_data) 56 | start_embedding = [int(x) for x in start_embedding] 57 | end_embedding = [int(x) for x in end_embedding] 58 | chrom_data= [float(x) for x in chrom_data] 59 | print(set([type(x) for x in chrom_data])) 60 | print(len(chrom_data),len(start_embedding),len(end_embedding)) 61 | assert all(isinstance(c, str) for c in chrom), "Chromosomes must be strings" 62 | assert all(isinstance(s, int) for s in start_embedding), "Start positions must be integers"
63 | assert all(isinstance(e, int) for e in end_embedding), "End positions must be integers" 64 | assert all(isinstance(v, (int, float)) for v in chrom_data), "Values must be int or float" 65 | assert all(s >= 0 for s in start_embedding), "Start positions must be non-negative" 66 | assert all(e >= s for s, e in zip(start_embedding, end_embedding)), "End must be >= Start" 67 | for s, e, v in zip(start_embedding[:10], end_embedding[:10], chrom_data[:10]): 68 | print(f"Chromosome: {chrom}, Start: {s}, End: {e}, Value: {v}") 69 | bw.addEntries([chrom]*len(chrom_data), list(start_embedding), ends=list(end_embedding), values=list(chrom_data)) 70 | print("finished",chrom) 71 | return output_bigwig 72 | 73 | """ 74 | This script converts a pickled 1D-array dict to a bigwig file. 75 | ``` 76 | python3 array2bigwig.py [input_file] [output_bigwig] [resolution] 77 | ``` 78 | [input_file]: the pkl file to be converted, should be a dict in format of [chr]:[array].
79 | [output_bigwig]: the output bigwig file.
80 | [resolution]: the resolution stored in the pkl file.
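Example of preparing a compatible input (a minimal sketch; the chromosome names, bin counts, and file names are illustrative only):
```
import pickle
import numpy as np

# dict keyed by chromosome name, one signal value per resolution-sized bin
signal = {"chr1": np.random.rand(2490), "chr2": np.random.rand(2420)}
with open("signal.pkl", "wb") as f:
    pickle.dump(signal, f)
# then: python3 array2bigwig.py signal.pkl signal.bigwig 100000
```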
81 | 82 | """ 83 | 84 | 85 | if __name__ == '__main__': 86 | 87 | if len(sys.argv) != 4: 88 | print("Usage: python3 array2bigwig.py [input_file] [output_bigwig] [resolution]") 89 | print("input_file: the pkl file to be converted, should be a dict in format of [chr]:[array].") 90 | print("output_bigwig: the output bigwig file.") 91 | print("resolution: the resolution stored in the pkl file.") 92 | sys.exit(1) 93 | input_file = os.path.abspath(sys.argv[1]) 94 | output_bigwig = os.path.abspath(sys.argv[2]) 95 | output_dir = os.path.basename(output_bigwig) 96 | os.makedirs(output_dir, exist_ok=True) 97 | resolution = int(sys.argv[3]) 98 | output_bw = array2bigwig(input_file, output_bigwig, resolution) 99 | 100 | print("Finished converting saved to %s" % output_bw) 101 | -------------------------------------------------------------------------------- /utils/array2cool.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pickle 4 | import numpy as np 5 | import pandas as pd 6 | import cooler 7 | def array2sparse(array): 8 | """ 9 | The array2sparse function converts a numpy array to a scipy sparce array. 10 | 11 | :param array: Specify the numpy array 12 | :return: A scipy sparce array 13 | :doc-author: Trelent 14 | """ 15 | from scipy.sparse import coo_matrix 16 | row, col = np.where(array) 17 | data = array[row, col] 18 | return coo_matrix((data, (row, col)), shape=array.shape) 19 | def array2cool(input_array_pickle,output_cool,resolution,refer_genome_name,mode): 20 | """ 21 | The array2cool function converts a dict of numpy array to cool file. 22 | 23 | :param juicer_tools: Specify the location of the juicer_tools 24 | :param input_array_pickle: Specify the path to the pickle file containing the array 25 | :param output_cool: Specify the name of the output cool file 26 | :param resolution: Set the resolution of the hic file 27 | :param refer_genome_name: Specify the reference genome name 28 | :return: A hic file 29 | :doc-author: Trelent 30 | """ 31 | #load array 32 | with open(input_array_pickle, 'rb') as f: 33 | data = pickle.load(f) 34 | output_dir = os.path.dirname(output_cool) 35 | os.makedirs(output_dir, exist_ok=True) 36 | 37 | #set each chromosome's length 38 | chromsizes={"name":[],"length":[]} 39 | chromosize_add_dict ={}#[chr_name:add_size] 40 | sort_keys = sorted(data.keys()) 41 | accumulate_index = 0 42 | for chrom_name in sort_keys: 43 | 44 | if mode == 0 or mode == 2: 45 | chrom1, chrom2 = chrom_name.split('_') 46 | else: 47 | chrom1 = chrom_name 48 | chrom2 = chrom_name 49 | if chrom1!=chrom2: 50 | continue 51 | if "chr" not in chrom1: 52 | chrom1 = "chr"+chrom1 53 | cur_array = data[chrom_name] 54 | chrom_size_total=resolution*cur_array.shape[0] 55 | #chromsizes={"name":chrom_name,"length":[chrom_size_total]} 56 | chromsizes['name'].append(chrom1) 57 | chromsizes['length'].append(chrom_size_total) 58 | chromosize_add_dict[chrom1]=accumulate_index 59 | accumulate_index += cur_array.shape[0] 60 | print("collecting bin dict size",chromsizes) 61 | chrom_dict=pd.DataFrame.from_dict(chromsizes).set_index("name")['length'] 62 | bins = cooler.binnify(chrom_dict, resolution) 63 | #then convert data array to index raw column, count array 64 | 65 | 66 | data_dict = {"bin1_id":[],"bin2_id":[],"count":[]} 67 | 68 | for key in data: 69 | chrom_name = key 70 | if mode == 0 or mode == 2: 71 | chrom1, chrom2 = chrom_name.split('_') 72 | else: 73 | chrom1 = chrom_name 74 | chrom2 = chrom_name 75 | if "chr" not in chrom1: 
76 | chrom1 = "chr"+chrom1 77 | if "chr" not in chrom2: 78 | chrom2 = "chr"+chrom2 79 | print("processing",chrom1,chrom2,"...") 80 | matrix = data[key] 81 | if mode>=2: 82 | matrix = array2sparse(matrix) 83 | matrix_row = matrix.row 84 | matrix_col = matrix.col 85 | matrix_data = matrix.data 86 | 87 | matrix_row += chromosize_add_dict[chrom1] 88 | matrix_col += chromosize_add_dict[chrom2] 89 | data_dict['bin1_id']+=list(matrix_row) 90 | data_dict["bin2_id"]+=list(matrix_col) 91 | data_dict['count'] +=list(matrix_data) 92 | accumulate_index += matrix.shape[0] 93 | print("creating cool file...") 94 | #cooler.create_cooler(hic_path, bins,data_dict, dtypes={"count":"int"}, assembly="hg38") 95 | cooler.create_cooler(output_cool, bins=pd.DataFrame.from_dict(bins), pixels=pd.DataFrame.from_dict(data_dict), dtypes={'count': float},assembly=refer_genome_name) 96 | """ 97 | Usage 98 | ``` 99 | python3 array2cool.py [input.pkl] [output.cool] [resolution] [refer_genome_name] [mode] 100 | ``` 101 | The input pickle should be in a pickle file as dict: [chrom1_chrom2]:[array] format for common mode. Here array should be scipy sparce array.
102 | For intra-chromosome only, the dict format can be [chrom]:[array] in pickle files.
103 | [output.cool] is the name of the output cool file.
104 | [resolution] is used to specify the resolution of the input array.
105 | [refer_genome_name] is used to specify the reference genome name. For example, "hg38","hg19","mm10" are valid inputs.
106 | [mode]: 0: all chromosome mode (scipy sparse array); 1: intra-chromosome mode (scipy sparse array); 2: all chromosome mode (numpy array); 3: intra-chromosome mode (numpy array).
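Example of preparing a mode-3 input (a minimal sketch; intra-chromosome numpy arrays, with chromosome names, matrix sizes, and file paths that are illustrative only):
```
import pickle
import numpy as np

rng = np.random.default_rng(0)
upper = np.triu(rng.poisson(1.0, size=(500, 500)))
data = {"chr1": upper + np.triu(upper, 1).T}  # symmetric dense contact matrix
with open("contacts.pkl", "wb") as f:
    pickle.dump(data, f)
# then: python3 array2cool.py contacts.pkl output.cool 10000 hg38 3
```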
107 | """ 108 | if __name__ == '__main__': 109 | 110 | if len(sys.argv)!=6: 111 | print('Usage: python3 array2cool.py [input.pkl] [output.cool] [resolution] [refer_genome_name] [mode]') 112 | print("This is the full array2cool script. ") 113 | print("input.pkl: the path to the pickle file containing the array [String].") 114 | print("input.pkl format: [chrom1_chrom2]:[array] format for common mode. Here array should be scipy sparce array. For intra-chromsome only, the dict format can be [chrom]:[array] in pickle files.") 115 | print("output.cool: the name of the output cool file [String].") 116 | print("resolution: resolution of the input array [Integer].") 117 | print("refer_genome_name: the name of the reference genome [String]. Example: hg38, hg19, mm10.") 118 | print("mode: 0: all chromosome mode (scipy sparce array); 1: intra-chromosome mode(scipy sparce array); 2: all chromosome mode (numpy array); 3: intra-chromosome mode(numpy array).") 119 | sys.exit(1) 120 | 121 | script_dir = os.path.dirname(os.path.realpath(__file__)) 122 | input_array_pickle = os.path.abspath(sys.argv[1]) 123 | output_hic = os.path.abspath(sys.argv[2]) 124 | resolution = int(sys.argv[3]) 125 | refer_genome_name = str(sys.argv[4]) 126 | mode = int(sys.argv[5]) 127 | array2cool(input_array_pickle,output_hic,resolution,refer_genome_name,mode) -------------------------------------------------------------------------------- /utils/array2hic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | import numpy as np 5 | #assume at least run on a machine with 8 CPUs+64G memory 6 | 7 | def array2sparse(array): 8 | """ 9 | The array2sparse function converts a numpy array to a scipy sparce array. 10 | 11 | :param array: Specify the numpy array 12 | :return: A scipy sparce array 13 | :doc-author: Trelent 14 | """ 15 | from scipy.sparse import coo_matrix 16 | row, col = np.where(array) 17 | data = array[row, col] 18 | return coo_matrix((data, (row, col)), shape=array.shape) 19 | 20 | def array2hic(juicer_tools,input_array_pickle, 21 | output_hic,resolution,refer_genome_name,mode=0): 22 | """ 23 | The array2hic function converts a numpy array to hic file. 
24 | 25 | :param juicer_tools: Specify the location of the juicer_tools jar 26 | :param input_array_pickle: Specify the path to the pickle file containing the array 27 | :param output_hic: Specify the name of the output hic file 28 | :param resolution: Set the resolution of the hic file 29 | :param refer_genome_name: Specify the reference genome name 30 | :return: A hic file 31 | :doc-author: Trelent 32 | """ 33 | #load array 34 | with open(input_array_pickle, 'rb') as f: 35 | data = pickle.load(f) 36 | output_dir = os.path.dirname(output_hic) 37 | os.makedirs(output_dir, exist_ok=True) 38 | raw_path = output_hic.replace('.hic','.raw') 39 | with open(raw_path, 'w') as wfile: 40 | for key in data: 41 | if mode == 0 or mode == 2: 42 | chrom1, chrom2 = key.split('_') 43 | else: 44 | chrom1 = key 45 | chrom2 = key 46 | matrix = data[key] 47 | if mode>=2: 48 | matrix = array2sparse(matrix) 49 | #merge duplicate records at the same location and drop explicit zeros 50 | matrix.eliminate_zeros() 51 | matrix.sum_duplicates() 52 | matrix_row = matrix.row 53 | matrix_col = matrix.col 54 | matrix_data = matrix.data 55 | if "chr" not in chrom1: 56 | chrom1 = "chr"+chrom1 57 | if "chr" not in chrom2: 58 | chrom2 = "chr"+chrom2 59 | for i in range(len(matrix_row)): 60 | wfile.write(f'{0} {chrom1} {int(matrix_row[i]*resolution+1)} {0} {0} {chrom2} {int(matrix_col[i]*resolution+1)} {1} {matrix_data[i]:.2f}\n') 61 | code_path = os.path.dirname(juicer_tools) 62 | root_path = os.getcwd() 63 | os.chdir(code_path) 64 | os.system(f'java -Xmx64g -jar juicer_tools.jar pre -j 8 -d -r {resolution} "{raw_path}" "{output_hic}" "{refer_genome_name}"') 65 | os.remove(raw_path) 66 | 67 | os.chdir(root_path) 68 | 69 | """ 70 | Usage 71 | ``` 72 | python3 array2hic.py [input.pkl] [output.hic] [resolution] [refer_genome_name] [mode] 73 | ``` 74 | The input pickle should be a dict in [chrom1_chrom2]:[array] format for common mode. Here array should be a scipy sparse array.
75 | For intra-chromosome only, the dict format can be [chrom]:[array] in pickle files.
76 | [output.hic] is the name of the output hic file.
77 | [resolution] is used to specify the resolution of the input array.
78 | [refer_genome_name] is used to specify the reference genome name. For example, "hg38","hg19","mm10" are valid inputs.
79 | [mode]: 0: all chromosome mode (scipy sparse array); 1: intra-chromosome mode (scipy sparse array); 2: all chromosome mode (numpy array); 3: intra-chromosome mode (numpy array).
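Example of preparing a mode-1 input (a minimal sketch; intra-chromosome scipy sparse arrays, with chromosome names, matrix sizes, and file paths that are illustrative only):
```
import pickle
import numpy as np
from scipy.sparse import coo_matrix

rng = np.random.default_rng(0)
dense = np.triu(rng.poisson(1.0, size=(500, 500)))  # upper-triangle counts only
data = {"chr1": coo_matrix(dense)}
with open("contacts.pkl", "wb") as f:
    pickle.dump(data, f)
# then: python3 array2hic.py contacts.pkl output.hic 10000 hg38 1
```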
80 | """ 81 | 82 | if __name__ == '__main__': 83 | 84 | #get current script directory 85 | 86 | if len(sys.argv) != 6: 87 | print('Usage: python3 array2hic.py [input.pkl] [output.hic] [resolution] [refer_genome_name] [mode]') 88 | print("This is the full array2hic script. ") 89 | print("input.pkl: the path to the pickle file containing the array [String].") 90 | print("input.pkl format: [chrom1_chrom2]:[array] format for common mode. Here array should be scipy sparce array. For intra-chromsome only, the dict format can be [chrom]:[array] in pickle files.") 91 | print("output.hic: the name of the output hic file [String].") 92 | print("resolution: resolution of the input array [Integer].") 93 | print("refer_genome_name: the name of the reference genome [String]. Example: hg38, hg19, mm10.") 94 | print("mode: 0: all chromosome mode (scipy sparce array); 1: intra-chromosome mode(scipy sparce array); 2: all chromosome mode (numpy array); 3: intra-chromosome mode(numpy array).") 95 | sys.exit(1) 96 | script_dir = os.path.dirname(os.path.realpath(__file__)) 97 | juicer_tools = os.path.join(script_dir, 'juicer_tools.jar') 98 | input_array_pickle = os.path.abspath(sys.argv[1]) 99 | output_hic = os.path.abspath(sys.argv[2]) 100 | resolution = int(sys.argv[3]) 101 | refer_genome_name = str(sys.argv[4]) 102 | mode = int(sys.argv[5]) 103 | array2hic(juicer_tools,input_array_pickle,output_hic,resolution,refer_genome_name,mode) 104 | 105 | 106 | -------------------------------------------------------------------------------- /utils/cool2array.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cooler 3 | from scipy.sparse import coo_matrix 4 | import pickle 5 | def write_pkl(data, path): 6 | with open(path, 'wb') as f: 7 | pickle.dump(data, f) 8 | def cool2array(cooler_path,normalize=False,tondarray=False): 9 | """ 10 | cooler_path: the path to the cooler file 11 | normalize: if True, the matrix will be normalized by the norm matrix saved in the cooler file 12 | tondarray: if True, the return a numpy array dict 13 | return a numpy/scipy.sparse array dict 14 | [chromosome1_chromsome2]:sparce matrix 15 | """ 16 | c = cooler.Cooler(cooler_path) 17 | binsize= c.info['bin-size'] 18 | chromosomes= c.chromnames 19 | chromosome_sizes = c.chromsizes 20 | bins = c.bins()[:] #including the chromosome staring ending info for each bin 21 | bins['bin_id'] = bins.index 22 | #column of bins[['chrom', 'start', 'end','weight'] 23 | pixels = c.pixels()[:] #including the chromosome staring ending info for each bin 24 | #including bin1_id,bin2_id,count columns 25 | return_dict={} 26 | for k,chromsome in enumerate(chromosomes): 27 | for j,chromsome2 in enumerate(chromosomes): 28 | if j 156 | The output array is saved in a pickle file as dict: [chrom1_chrom2]:[array] format.
157 | Four modes are supported: 158 | ``` 159 | 0: scipy coo_array format output; 160 | 1: numpy array format output; 161 | 2: normed scipy coo_array format output; 162 | 3: normed numpy array format output. 163 | ``` 164 | """ 165 | 166 | if __name__ == '__main__': 167 | import os 168 | import sys 169 | if len(sys.argv) != 4: 170 | print('Usage: python3 cool2array.py [input.cool] [output.pkl] [mode]') 171 | print("This is the full cool2array script. ") 172 | print("mode: 0 for sparse matrix, 1 for dense matrix, 2 for normed sparse matrix, 3 for normed dense matrix") 173 | sys.exit(1) 174 | 175 | cooler_path = os.path.abspath(sys.argv[1]) 176 | output_pkl_path = os.path.abspath(sys.argv[2]) 177 | output_dir = os.path.dirname(output_pkl_path) 178 | os.makedirs(output_dir,exist_ok=True) 179 | mode = int(sys.argv[3]) 180 | if mode not in [0,1,2,3]: 181 | print('mode should be 0,1,2,3') 182 | sys.exit(1) 183 | if mode == 0: 184 | normalize = False 185 | tondarray = False 186 | elif mode == 1: 187 | normalize = False 188 | tondarray = True 189 | elif mode == 2: 190 | normalize = True 191 | tondarray = False 192 | elif mode == 3: 193 | normalize = True 194 | tondarray = True 195 | return_dict = cool2array(cooler_path,normalize=normalize,tondarray=tondarray) 196 | write_pkl(return_dict,output_pkl_path) 197 | -------------------------------------------------------------------------------- /utils/hic2array.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | from scipy.sparse import coo_matrix 5 | import hicstraw 6 | import os 7 | import pickle 8 | def write_pkl(data, path): 9 | with open(path, 'wb') as f: 10 | pickle.dump(data, f) 11 | def read_chrom_array(chr1, chr2, normalization, hic_file, resolution,call_resolution): 12 | chr1_name = chr1.name 13 | chr2_name = chr2.name 14 | infos = [] 15 | infos.append('observed') 16 | infos.append(normalization) 17 | infos.append(hic_file) 18 | infos.append(chr1_name) 19 | infos.append(chr2_name) 20 | infos.append('BP') 21 | infos.append(call_resolution) 22 | print(infos) 23 | row, col, val = [], [], [] 24 | rets = hicstraw.straw(*infos) 25 | print('\tlen(rets): {:3e}'.format(len(rets))) 26 | for ret in rets: 27 | row.append((int)(ret.binX // resolution)) 28 | col.append((int)(ret.binY // resolution)) 29 | val.append(ret.counts) 30 | print('\tsum(val): {:3e}'.format(sum(val))) 31 | if sum(val) == 0: 32 | return None 33 | if chr1_name==chr2_name: 34 | max_shape =max(max(row),max(col))+1 35 | mat_coo = coo_matrix((val, (row, col)), shape = (max_shape,max_shape),dtype=np.float32) 36 | else: 37 | max_row = max(row)+1 38 | max_column = max(col)+1 39 | mat_coo = coo_matrix((val, (row, col)), shape = (max_row,max_column),dtype=np.float32) 40 | 41 | mat_coo = mat_coo #+ triu(mat_coo, 1).T #no below-diagonal records 42 | 43 | return mat_coo 44 | 45 | def validate_resolution(resolution,resolution_list): 46 | for reso in resolution_list: 47 | if resolution%reso==0: 48 | return True 49 | return False 50 | 51 | def find_call_resolution(resolution, resolution_list): 52 | max_best_resolution = min(resolution_list) 53 | resolution_list.sort() 54 | for reso in resolution_list: 55 | if resolution%reso==0: 56 | if reso>max_best_resolution: 57 | max_best_resolution=reso 58 | return max_best_resolution 59 | 60 | def hic2array(input_hic,output_pkl=None, 61 | resolution=25000,normalization="NONE", 62 | tondarray=0): 63 | """ 64 | input_hic: str, input hic file path 65 | output_pkl: str, output pickle file path 66 | resolution: int, resolution of the hic file
67 | """ 68 | 69 | hic = hicstraw.HiCFile(input_hic) 70 | chrom_list=[] 71 | chrom_dict={} 72 | for chrom in hic.getChromosomes(): 73 | print(chrom.name, chrom.length) 74 | if "all" in chrom.name.lower(): 75 | continue 76 | chrom_list.append(chrom) 77 | chrom_dict[chrom.name]=chrom.length 78 | resolution_list = hic.getResolutions() 79 | #max_resolution_candidate = max(resolution_list) 80 | if not validate_resolution(resolution,resolution_list): 81 | print("Resolution not found in the hic file, please choose from the following list:") 82 | print(resolution_list) 83 | print("Any coarser resolution that is divisible by one of these is also supported.") 84 | exit() 85 | output_dict={} 86 | for i in range(len(chrom_list)): 87 | for j in range(i,len(chrom_list)): 88 | if i!=j and tondarray in [2,3]: 89 | #skip inter-chromosome region 90 | continue 91 | 92 | chrom1 = chrom_list[i] 93 | chrom1_name = chrom_list[i].name 94 | chrom2 = chrom_list[j] 95 | chrom2_name = chrom_list[j].name 96 | if 'Un' in chrom1_name or 'Un' in chrom2_name: 97 | continue 98 | if "random" in chrom1_name.lower() or "random" in chrom2_name.lower(): 99 | continue 100 | if "alt" in chrom1_name.lower() or "alt" in chrom2_name.lower(): 101 | continue 102 | read_array=read_chrom_array(chrom1,chrom2, normalization, input_hic, resolution,call_resolution=find_call_resolution(resolution,resolution_list)) 103 | if read_array is None: 104 | print("No data found for",chrom1_name,chrom2_name) 105 | continue 106 | if tondarray in [1,3]: 107 | read_array = read_array.toarray() 108 | if tondarray in [2,3]: 109 | output_dict[chrom1_name]=read_array 110 | else: 111 | output_dict[chrom1_name+"_"+chrom2_name]=read_array 112 | if output_pkl is not None: 113 | output_dir = os.path.dirname(os.path.realpath(output_pkl)) 114 | os.makedirs(output_dir, exist_ok=True) 115 | write_pkl(output_dict,output_pkl) 116 | 117 | return output_dict 118 | """ 119 | 120 | Usage 121 | ``` 122 | python3 hic2array.py [input.hic] [output.pkl] [resolution] [normalization_type] [mode] 123 | ``` 124 | 125 | This is the full hic2array script, converting both intra- and inter-chromosome regions to array format.
126 | The output array is saved in a pickle file as dict: [chrom1_chrom2]:[array] format.
127 | [resolution] is used to specify the resolution stored in the output array.
128 | [normalization_type] supports the following types:
129 | ``` 130 | 0: NONE normalization applied, save the raw data to array. 131 | 1: VC normalization; 132 | 2: VC_SQRT normalization; 133 | 3: KR normalization; 134 | 4: SCALE normalization. 135 | ``` 136 | Four modes are supported for different format saving: 137 | ``` 138 | 0: scipy coo_array format output; 139 | 1: numpy array format output; 140 | 2: scipy coo_array format output (only include intra-chromosome region). 141 | 3: numpy array format output (only include intra-chromosome region). 142 | ``` 143 | 144 | """ 145 | if __name__ == '__main__': 146 | import os 147 | import sys 148 | if len(sys.argv) != 6: 149 | print('Usage: python3 hic2array.py [input.hic] [output.pkl] [resolution] [normalization_type] [mode]') 150 | print("This is the full hic2array script. ") 151 | print("normalization type: 0: None normalization; 1: VC normalization; 2: VC_SQRT normalization; 3: KR normalization; 4: SCALE normalization") 152 | print("mode: 0 for sparse matrix, 1 for dense matrix, 2 for sparse matrix (only cis-contact); 3 for dense matrix (only cis-contact).") 153 | sys.exit(1) 154 | resolution = int(sys.argv[3]) 155 | normalization_type = int(sys.argv[4]) 156 | mode = int(sys.argv[5]) 157 | normalization_dict={0:"NONE",1:"VC",2:"VC_SQRT",3:"KR",4:"SCALE"} 158 | if normalization_type not in normalization_dict: 159 | print('normalization type should be 0,1,2,3,4') 160 | print("normalization type: 0: None normalization; 1: VC normalization; 2: VC_SQRT normalization; 3: KR normalization; 4: SCALE normalization") 161 | sys.exit(1) 162 | normalization_type = normalization_dict[normalization_type] 163 | if mode not in [0,1,2,3]: 164 | print('mode should be in choice of 0/1/2/3') 165 | print("mode: 0 for sparse matrix, 1 for dense matrix, 2 for sparse matrix (only cis-contact); 3 for dense matrix (only cis-contact).") 166 | sys.exit(1) 167 | input_hic_path = os.path.abspath(sys.argv[1]) 168 | output_pkl_path = os.path.abspath(sys.argv[2]) 169 | output_dir = os.path.dirname(output_pkl_path) 170 | os.makedirs(output_dir,exist_ok=True) 171 | hic2array(input_hic_path,output_pkl_path,resolution,normalization_type,mode) 172 | -------------------------------------------------------------------------------- /utils/hic_coverage.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pickle 4 | def calculate_coverage(input_pkl): 5 | data = pickle.load(open(input_pkl, 'rb')) 6 | count_reads=0 7 | count_length=0 8 | for chrom in data: 9 | cur_data = data[chrom] 10 | count_length += cur_data.shape[0] 11 | count_reads += cur_data.sum() 12 | # coverage: average number of contact reads per bin 13 | 14 | coverage = count_reads/count_length 15 | return coverage 16 | """ 17 | This script calculates the coverage of the Hi-C data. 18 | ``` 19 | python3 hic_coverage.py [input.pkl] 20 | ``` 21 | [input.pkl]: the input pkl file containing the Hi-C data
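Example of calling it from Python instead of the command line (a minimal sketch; assumes the repo root is on sys.path and that contacts.pkl was produced by hic2array):
```
from utils.hic_coverage import calculate_coverage

# coverage = total contact counts divided by total number of bins
coverage = calculate_coverage("contacts.pkl")
print("Hi-C Coverage:", coverage)
```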
22 | 23 | """ 24 | 25 | 26 | if __name__ == '__main__': 27 | if len(sys.argv) != 2: 28 | print("Usage: python3 hic_coverage.py [input.pkl]") 29 | print("[input.pkl]: the input pkl file containing the Hi-C data") 30 | # print("[fragment_size]: the size of the fragment for building Hi-C") 31 | sys.exit(1) 32 | input_pkl = os.path.abspath(sys.argv[1]) 33 | #resolution = int(sys.argv[2]) 34 | #fragment_size = int(sys.argv[3]) 35 | coverage = calculate_coverage(input_pkl) 36 | print("Hi-C Coverage: ", coverage) -------------------------------------------------------------------------------- /utils/juicer_tools.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/utils/juicer_tools.jar --------------------------------------------------------------------------------