├── .gitignore
├── Dockerfile
├── HiCFoundation.ipynb
├── LICENSE
├── README.md
├── Reproducibility.ipynb
├── data_processing
├── __init__.py
├── collate_fn.py
├── finetune_dataset.py
├── inference_dataset.py
└── pretrain_dataset.py
├── environment.yml
├── environment_notorch.yml
├── example
├── GSM7006609_ValbB8w1081.pairs
├── __init__.py
├── finetune_example
│ ├── train.txt
│ ├── train
│ │ ├── data_0.pkl
│ │ ├── data_1.pkl
│ │ ├── data_2.pkl
│ │ ├── data_3.pkl
│ │ └── data_4.pkl
│ ├── val.txt
│ └── val
│ │ ├── data_0.pkl
│ │ ├── data_1.pkl
│ │ ├── data_2.pkl
│ │ ├── data_3.pkl
│ │ └── data_4.pkl
└── pretrain_example
│ ├── train.txt
│ ├── train
│ ├── data_1.pkl
│ ├── data_2.pkl
│ ├── data_3.pkl
│ ├── data_4.pkl
│ ├── data_5.pkl
│ ├── data_6.pkl
│ ├── data_7.pkl
│ └── data_8.pkl
│ └── val.txt
├── finetune.py
├── finetune
├── __init__.py
├── loss.py
├── main_worker.py
├── train_epoch.py
└── val_epoch.py
├── hicfoundation_model
└── __init__.py
├── imgs
└── framework_github.png
├── inference.py
├── inference
├── __init__.py
├── inference_worker.py
├── load_model.py
└── main_worker.py
├── model
├── Finetune_Model_Head.py
├── NativeScaler.py
├── SSIM.py
├── Vision_Transformer_count.py
├── __init__.py
├── lr_decay.py
├── lr_sched.py
├── model_utils.py
├── models_hicfoundation.py
└── pos_embed.py
├── ops
├── Logger.py
├── __init__.py
├── argparser.py
├── calculate_similarity.py
├── distribute_utils.py
├── file_format_convert.py
├── io_utils.py
├── mean_shift_merge.py
├── smooth_matrix.py
├── sparse_ops.py
└── train_utils.py
├── pretrain.py
├── pretrain
├── __init__.py
├── main_worker.py
├── train_epoch.py
└── val_epoch.py
├── refactor_pretrain
├── Logger.py
├── SSIM.py
├── argparser.py
├── distribute_utils.py
├── main_worker.py
├── model_funcs.py
├── model_utils.py
├── models_hicfoundation.py
├── pretrain.py
├── pretrain_dataset.py
├── train_epoch.py
├── utils.py
└── val_epoch.py
├── requirements.txt
├── test
├── test_convert_rgb.py
├── test_forward.py
├── test_loss.py
├── test_masking.py
└── test_patchify.py
└── utils
├── __init__.py
├── array2bigwig.py
├── array2cool.py
├── array2hic.py
├── cool2array.py
├── hic2array.py
├── hic_coverage.py
└── juicer_tools.jar
/.gitignore:
--------------------------------------------------------------------------------
1 | *pycache*
2 | .coverage
3 | htmlcov
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use NVIDIA CUDA base image with Ubuntu
2 | FROM nvidia/cuda:11.1.1-devel-ubuntu20.04
3 |
4 | # Set environment variables
5 | ENV DEBIAN_FRONTEND=noninteractive
6 | ENV CONDA_DIR=/opt/conda
7 | ENV PATH=$CONDA_DIR/bin:$PATH
8 | ENV PYTHONPATH=/app:$PYTHONPATH
9 |
10 | # Install system dependencies
11 | # configure hic-straw dependencies
12 | RUN apt-get update && apt-get install -y \
13 | wget \
14 | curl \
15 | git \
16 | build-essential \
17 | cmake \
18 | libcurl4-openssl-dev \
19 | zlib1g-dev \
20 | && rm -rf /var/lib/apt/lists/*
21 |
22 | # Install Miniconda
23 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
24 | /bin/bash ~/miniconda.sh -b -p $CONDA_DIR && \
25 | rm ~/miniconda.sh && \
26 | conda clean -t -i -p -y
27 |
28 | # Set working directory
29 | WORKDIR /app
30 |
31 | # Create conda environment and install dependencies directly
32 | RUN conda create -n HiCFoundation python=3.8.10 -y && \
33 | conda clean -afy
34 |
35 | # Make RUN commands use the new environment
36 | SHELL ["conda", "run", "-n", "HiCFoundation", "/bin/bash", "-c"]
37 |
38 | # Install conda packages
39 | RUN conda install -c pytorch -c nvidia -c conda-forge -c anaconda -c bioconda -c defaults \
40 | cudatoolkit=11.1.74 \
41 | pip=21.1.3 \
42 | pytorch=1.8.1 \
43 | torchvision=0.9.1 \
44 | timm=0.3.2 \
45 | openjdk \
46 | pandas \
47 | matplotlib \
48 | scipy \
49 | numba \
50 | cooler \
51 | -y && \
52 | conda clean -afy
53 |
54 | # Install pip packages
55 | RUN pip install \
56 | easydict \
57 | opencv-python \
58 | simplejson \
59 | lvis \
60 | "Pillow==9.5.0" \
61 | pytorch_msssim \
62 | scikit-image \
63 | einops \
64 | tensorboard \
65 | pyBigWig
66 |
67 | RUN pip install hic-straw
68 |
69 | # Set the default command to activate conda environment
70 | ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "HiCFoundation"]
71 | CMD ["python", "--version"]
72 |
--------------------------------------------------------------------------------
/Reproducibility.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "private_outputs": true,
7 | "provenance": [],
8 | "authorship_tag": "ABX9TyOXdz0SpmwfOe1u8Ku2hJho",
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "language_info": {
16 | "name": "python"
17 | }
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {
23 | "id": "view-in-github",
24 | "colab_type": "text"
25 | },
26 | "source": [
27 | "
"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "source": [
33 | "# HiCFoundation: a generalizable Hi-C foundation model for chromatin architecture, single-cell and multi-omics analysis across species\n",
34 | "**This repo is only for calculating the reproducibility score with HiCFoundation**"
35 | ],
36 | "metadata": {
37 | "id": "S-eNVerqcDbI"
38 | }
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "source": [
43 | "HiCFoundation is a generalizable Hi-C foundation model for chromatin architecture, single-cell and multi-omics analysis across species.\n",
44 | "\n",
45 | "Copyright (C) 2024 Xiao Wang, Yuanyuan Zhang, Suhita Ray, Anupama Jha, Tangqi Fang, Shengqi Hang, Sergei Doulatov, William Stafford Noble, and Sheng Wang\n",
46 | "\n",
47 | "License: Apache License 2.0\n",
48 | "\n",
49 | "Contact: Sergei Doulatov (doulatov@uw.edu) & William Stafford Noble (wnoble@uw.edu) & Sheng Wang (swang@cs.washington.edu)\n",
50 | "\n",
51 | "For technical problems or questions, please reach out to Xiao Wang (wang3702@uw.edu) and Yuanyuan Zhang (zhang038@purdue.edu).\n",
52 | "\n",
53 | "\n",
54 | "If you are using other browsers, disabling tracking protection may help resolve the errors when uploading or downloading files.\n",
55 | "\n",
56 | "For more details, see **Instructions** of the notebook and check out the **[HiCFoundation GitHub](https://github.com/Noble-Lab/HiCFoundation)**. If you use HiCFoundation, please cite it: **Citation**."
57 | ],
58 | "metadata": {
59 | "id": "aqhB21yTcFG_"
60 | }
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "source": [
65 | "# Instructions \n",
66 | "## Steps\n",
67 | "1. Run the HiCFoundation Colab on the two Hi-C maps of interest and download the embedding pickle files for further processing.\n",
68 | "2. Connect to a **cpu machine** by clicking the right top button **\"connect\"** in the notebook.
\n",
69 | "3. Upload the embedding of the 1st Hi-C map (.pkl file) in Input file1.\n",
70 | "4. Upload the embedding of the 2nd Hi-C map (.pkl file) in Input file2.\n",
71 | "5. Run the score calculation by clicking the run button on the left in Run.\n",
72 | "6. You can check the output to get the similarity score in the same tab."
73 | ],
74 | "metadata": {
75 | "id": "j62BWrrecdSS"
76 | }
77 | },
78 | {
79 | "cell_type": "code",
80 | "source": [
81 | "#@title Input embedding file1\n",
82 | "from google.colab import files\n",
83 | "import os\n",
84 | "import os.path\n",
85 | "import re\n",
86 | "import hashlib\n",
87 | "import random\n",
88 | "import string\n",
89 | "from google.colab import drive\n",
90 | "\n",
91 | "from datetime import datetime\n",
92 | "# Get the current date and time\n",
93 | "current_datetime = datetime.now()\n",
94 | "# Convert to string in desired format\n",
95 | "current_datetime_str = current_datetime.strftime(\"%Y-%m-%d-%H-%M-%S\")\n",
96 | "rand_letters = string.ascii_lowercase\n",
97 | "rand_letters = ''.join(random.choice(rand_letters) for i in range(20))\n",
98 | "output_dir=\"/content/\"\n",
99 | "\n",
100 | "#@markdown ## Upload the calculated embedding file(.pkl) of 1st Hi-C from your local file system\n",
101 | "print(\"Please uploading your input files\")\n",
102 | "os.chdir(\"/content/\")\n",
103 | "root_dir = os.getcwd()\n",
104 | "upload_dir = os.path.join(root_dir,rand_letters)\n",
105 | "if not os.path.exists(upload_dir):\n",
106 | " os.mkdir(upload_dir)\n",
107 | "os.chdir(upload_dir)\n",
108 | "map_input = files.upload()\n",
109 | "for fn in map_input.keys():\n",
110 | " print('User uploaded file \"{name}\" with length {length} bytes'.format(\n",
111 | " name=fn, length=len(map_input[fn])))\n",
112 | " hic_input_path1 = os.path.abspath(fn)\n",
113 | " print(\"The input save to %s\"%hic_input_path1)\n",
114 | "os.chdir(root_dir)\n",
115 | "\n"
116 | ],
117 | "metadata": {
118 | "cellView": "form",
119 | "id": "b-pBMNY3c-9o"
120 | },
121 | "execution_count": null,
122 | "outputs": []
123 | },
124 | {
125 | "cell_type": "code",
126 | "source": [
127 | "#@title Input embedding file2\n",
128 | "#@markdown ## Upload the calculated embedding file(.pkl) of 2nd Hi-C from your local file system\n",
129 | "os.chdir(upload_dir)\n",
130 | "map_input = files.upload()\n",
131 | "for fn in map_input.keys():\n",
132 | " print('User uploaded file \"{name}\" with length {length} bytes'.format(\n",
133 | " name=fn, length=len(map_input[fn])))\n",
134 | " hic_input_path2 = os.path.abspath(fn)\n",
135 | " print(\"The input save to %s\"%hic_input_path2)\n",
136 | "os.chdir(root_dir)"
137 | ],
138 | "metadata": {
139 | "cellView": "form",
140 | "id": "WGe1QYgcf1Fw"
141 | },
142 | "execution_count": null,
143 | "outputs": []
144 | },
145 | {
146 | "cell_type": "code",
147 | "source": [
148 | "# @title Reproducibility score calculation\n",
149 | "# This script is to calculate the similarity between two Hi-C using a pre-trained reproducibility model.\n",
150 | "\n",
151 | "import os\n",
152 | "import sys\n",
153 | "import numpy as np\n",
154 | "import pickle\n",
155 | "from collections import defaultdict\n",
156 | "\n",
157 | "input_pickle1 = hic_input_path1\n",
158 | "input_pickle2 = hic_input_path2\n",
159 | "\n",
160 | "def load_pickle(file_path):\n",
161 | " with open(file_path, 'rb') as f:\n",
162 | " data = pickle.load(f)\n",
163 | " return data\n",
164 | "\n",
165 | "input1 = load_pickle(input_pickle1)\n",
166 | "input2 = load_pickle(input_pickle2)\n",
167 | "\n",
168 | "def find_key(chr,loc,key_list):\n",
169 | " \"\"\"\n",
170 | " Find the key in the list of keys that contains the given chromosome and location.\n",
171 | " \"\"\"\n",
172 | " key1 = chr+\":\"+loc\n",
173 | " if key1 in key_list:\n",
174 | " return key1\n",
175 | " key1 = \"chr\"+chr+\":\"+loc\n",
176 | " if key1 in key_list:\n",
177 | " return key1\n",
178 | " key1 = chr+\"_\"+chr+\":\"+loc\n",
179 | " if key1 in key_list:\n",
180 | " return key1\n",
181 | " key1 = \"chr\"+chr+\"_chr\"+chr+\":\"+loc\n",
182 | " if key1 in key_list:\n",
183 | " return key1\n",
184 | " return None\n",
185 | "\n",
186 | "def calculate_similarity(input1, input2):\n",
187 | " \"\"\"\n",
188 | " Calculate the similarity between two Hi-C matrices using a pre-trained reproducibility model.\n",
189 | " \"\"\"\n",
190 | " similarity_dict = defaultdict(list)\n",
191 | " for key in input1.keys():\n",
192 | " #1_1:1960,1960 format of key\n",
193 | " split_chromosome = key.split(\":\")[0]\n",
194 | " split_loc = key.split(\":\")[1]\n",
195 | " combine_key = split_chromosome + \":\" + split_loc\n",
196 | " chr = split_chromosome.split(\"_\")[0]\n",
197 | " chr = chr.replace(\"chr\",\"\")\n",
198 | " if combine_key not in input2.keys():\n",
199 | " combine_key = find_key(chr,split_loc,input2.keys())\n",
200 | " if combine_key is None:\n",
201 | " continue\n",
202 | "\n",
203 | " embedding1 = input1[key]\n",
204 | " embedding2 = input2[combine_key]\n",
205 | " # Calculate the similarity between the two embeddings\n",
206 | " similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))\n",
207 | " if np.isnan(similarity):\n",
208 | " continue\n",
209 | " similarity_dict[chr].append(similarity)\n",
210 | " #ignore chrY, chrM, Un, Alt cases\n",
211 | " similarity_list=[]\n",
212 | " for chrom in similarity_dict:\n",
213 | " if \"Y\" in chrom or \"M\" in chrom or \"Un\" in chrom or \"Alt\" in chrom:\n",
214 | " continue\n",
215 | " mean_val = np.mean(similarity_dict[chrom])\n",
216 | " similarity_list.append(mean_val)\n",
217 | " similarity = np.mean(similarity_list)\n",
218 | " return similarity\n",
219 | "\n",
220 | "similarity = calculate_similarity(input1, input2)\n",
221 | "print(\"The reproducibility score between the two Hi-C is: \", similarity)\n"
222 | ],
223 | "metadata": {
224 | "cellView": "form",
225 | "id": "v41vpgxbgQoV"
226 | },
227 | "execution_count": null,
228 | "outputs": []
229 | }
230 | ]
231 | }
--------------------------------------------------------------------------------
/data_processing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/data_processing/__init__.py
--------------------------------------------------------------------------------
/data_processing/collate_fn.py:
--------------------------------------------------------------------------------
1 | import torch
def collate_fn(batch):
    """
    Collate a batch of samples field by field.
    Args:
        batch: a sequence of samples, each a sequence of fields where every
            field is either a tensor or None
    Returns:
        A list with one entry per field: None when that field is None in
        every sample, otherwise the field's tensors stacked with torch.stack.
    Raises:
        AssertionError: if a field is None in some samples but not in others.
    """
    collated = []
    # zip(*batch) regroups the samples so each iteration sees one field
    # across the whole batch.
    for column in zip(*batch):
        missing = [item is None for item in column]
        if all(missing):
            # Field is absent in every sample: propagate None unchanged.
            collated.append(None)
            continue
        # A partially-missing field cannot be stacked.
        assert not any(missing), "None element in a list of tensors"
        collated.append(torch.stack(list(column)))
    return collated
21 |
22 |
--------------------------------------------------------------------------------
/data_processing/finetune_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.utils.data
5 | import random
6 | from collections import defaultdict
7 | from scipy.sparse import coo_matrix
8 | import pickle
9 | from ops.sparse_ops import array_to_coo
10 | from ops.io_utils import load_pickle
def validate_input_size(input_matrix, window_height, window_width):
    """
    Check that the input matrix has exactly the window size.
    Args:
        input_matrix: the input matrix (dense array or scipy sparse matrix)
        window_height: the height of the window
        window_width: the width of the window
    Returns:
        True when the matrix shape equals (window_height, window_width),
        False otherwise.
    """
    # Dense arrays and scipy sparse matrices both expose ``.shape``; reading it
    # directly avoids allocating a dense copy just to inspect the size.
    input_height, input_width = input_matrix.shape
    return input_height==window_height and input_width==window_width
25 |
def to_tensor(x):
    """
    Convert a single value to a torch tensor where a conversion is defined.
    Args:
        x: the input data
    Returns:
        A tensor for numpy arrays and Python float/int values; any other
        value (None, an existing tensor, ...) is returned unchanged.
    """
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x)
    if isinstance(x, (float, int)):
        return torch.tensor(x)
    # None, tensors and unrecognised types pass through untouched.
    return x
45 |
def list_to_tensor(x):
    """
    Apply to_tensor to every element of a list.
    Args:
        x: the input list
    Returns:
        A new list holding the converted elements in the same order.
    """
    return [to_tensor(item) for item in x]
class Finetune_Dataset(torch.utils.data.Dataset):
    """
    Dataset over pickled fine-tuning samples stored in one or more directories.

    Every ``.pkl`` file is read with ``load_pickle`` and is expected to hold a
    dict with an ``input`` matrix and at least one key containing ``target``.
    ``__getitem__`` reads the optional keys ``input_count``, ``2d_target``,
    ``embed_target`` and ``1d_target``.
    """

    def __init__(self,data_list,
                transform=None,
                window_height= 224,
                window_width = 224):
        """
        Args:
            data_list: list of data directories
            transform: the transformation to apply to the data
            window_height: the height of the window
            window_width: the width of the window
        """
        self.data_list = data_list
        self.transform = transform
        self.window_height = window_height
        self.window_width = window_width
        # train_dict groups sample paths per directory name, train_list is the
        # flat list used for indexing; either can be used to drive training.
        self.train_dict=defaultdict(list)
        self.train_list=[]
        for cur_dir in data_list:
            dataset_name = os.path.basename(cur_dir)
            sample_paths = []
            for file in os.listdir(cur_dir):
                if file.endswith('.pkl'):
                    sample_paths.append(os.path.join(cur_dir, file))
                else:
                    print("The file {} is not a .pkl file.".format(file),"It is skipped.")
            if not sample_paths:
                continue
            # Validate the first .pkl sample (not merely the first directory
            # entry) and drop the whole directory when it fails, as the
            # printed messages promise.
            if not self._check_first_sample(cur_dir, sample_paths[0]):
                continue
            self.train_dict[dataset_name].extend(sample_paths)
            self.train_list.extend(sample_paths)
        print("The number of samples used in the dataset is {}".format(len(self.train_list)))

    def _check_first_sample(self, cur_dir, sample_path):
        """Return True when the sample at sample_path has the keys and input size the dataset needs."""
        data = load_pickle(sample_path)
        if 'input' not in data:
            print("The input key is not included in the pkl file. The directory is skipped.")
            print("The dir is {}".format(cur_dir))
            return False
        # Any key containing "target" counts as a supervision signal.
        if not any("target" in key for key in data):
            print("The target key is not included in the pkl file. The directory is skipped.")
            print("The dir is {}".format(cur_dir))
            return False
        input_matrix = data['input']
        if not validate_input_size(input_matrix, self.window_height, self.window_width):
            print("The input size is not matched with the window size. The directory is skipped.")
            print("The dir is {}".format(cur_dir))
            print("The input size is {}".format(input_matrix.shape))
            print("The specified window size is {} x {}".format(self.window_height, self.window_width))
            print("Please adjust --input_row_size and --input_col_size to match your input.")
            return False
        return True

    @staticmethod
    def _dense_float32(matrix):
        """Densify a coo_matrix if needed, replace NaNs and cast to float32."""
        if isinstance(matrix, coo_matrix):
            matrix = matrix.toarray()
        return np.nan_to_num(matrix).astype(np.float32)

    def __len__(self):
        return len(self.train_list)

    def convert_rgb(self,data_log,max_value):
        """
        Map a matrix to a 3-channel float32 image in channel-last layout.
        Args:
            data_log: the matrix to render, 2D or already with a leading channel axis
            max_value: the value used to normalise data_log
        Returns:
            Array with channels last: the first channel is all ones and the
            other two hold (max_value - data_log) / max_value.
        """
        if len(data_log.shape)==2:
            data_log = data_log[np.newaxis,:]
        data_red = np.ones(data_log.shape)
        if max_value == 0:
            # Avoid 0/0 -> NaN: with nothing to normalise by, use 1, the same
            # value zero entries get when max_value is non-zero.
            data_log1 = np.ones(data_log.shape)
        else:
            data_log1 = (max_value-data_log)/max_value
        data_rgb = np.concatenate([data_red,data_log1,data_log1],axis=0,dtype=np.float32)
        # transform only accepts the channel-last case
        return data_rgb.transpose(1,2,0)

    def __getitem__(self, idx):
        """
        Load one sample.
        Returns:
            list_to_tensor applied to
            [input image, input_count, 2d_target, embed_target, 1d_target];
            an entry is None when its key is missing from the pickle.
        """
        data = load_pickle(self.train_list[idx])
        # No automatic symmetrisation is applied, so that off-diagonal
        # submatrices are supported; a coo_matrix input must therefore already
        # store every entry that is needed.
        input_matrix = self._dense_float32(data['input'])
        input_matrix = np.log10(input_matrix+1)
        max_value = np.max(input_matrix)
        input_matrix = self.convert_rgb(input_matrix,max_value)
        if self.transform:
            input_matrix = self.transform(input_matrix)

        # None indicates that the total count is not passed.
        total_count = data['input_count'] if "input_count" in data else None
        target_matrix = self._dense_float32(data['2d_target']) if "2d_target" in data else None
        embed_target = self._dense_float32(data['embed_target']) if "embed_target" in data else None
        target_vector = self._dense_float32(data['1d_target']) if "1d_target" in data else None

        return list_to_tensor([input_matrix, total_count, target_matrix, embed_target, target_vector])
173 |
174 |
175 |
176 |
177 |
178 |
179 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: HiCFoundation
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - anaconda
7 | - bioconda
8 | - defaults
9 | dependencies:
10 | - cudatoolkit=11.1.74
11 | - pip=21.1.3
12 | - python=3.8.10
13 | - pytorch=1.8.1
14 | - torchvision=0.9.1
15 | - timm=0.3.2
16 | - openjdk
17 | - pip:
18 | - easydict
19 | - opencv-python
20 | - simplejson
21 | - lvis
22 | - Pillow==9.5.0
23 | - pytorch_msssim
24 | - pandas
25 | - hic-straw
26 | - matplotlib
27 | - scikit-image
28 | - scipy
29 | - einops
30 | - tensorboard
31 | - cooler
32 | - numba
33 | - pyBigWig
34 |
35 |
36 |
--------------------------------------------------------------------------------
/environment_notorch.yml:
--------------------------------------------------------------------------------
1 | name: HiCFoundation
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - anaconda
7 | - bioconda
8 | - defaults
9 | dependencies:
10 | - pip=21.1.3
11 | - python=3.8.10
12 | - openjdk
13 | - pip:
14 | - easydict
15 | - opencv-python
16 | - simplejson
17 | - lvis
18 | - Pillow==9.5.0
19 | - pytorch_msssim
20 | - pandas
21 | - hic-straw
22 | - matplotlib
23 | - scikit-image
24 | - scipy
25 | - einops
26 | - tensorboard
27 | - cooler
28 | - numba
29 | - pyBigWig
30 |
31 |
32 |
--------------------------------------------------------------------------------
/example/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/__init__.py
--------------------------------------------------------------------------------
/example/finetune_example/train.txt:
--------------------------------------------------------------------------------
1 | train
2 |
--------------------------------------------------------------------------------
/example/finetune_example/train/data_0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/train/data_0.pkl
--------------------------------------------------------------------------------
/example/finetune_example/train/data_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/train/data_1.pkl
--------------------------------------------------------------------------------
/example/finetune_example/train/data_2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/train/data_2.pkl
--------------------------------------------------------------------------------
/example/finetune_example/train/data_3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/train/data_3.pkl
--------------------------------------------------------------------------------
/example/finetune_example/train/data_4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/train/data_4.pkl
--------------------------------------------------------------------------------
/example/finetune_example/val.txt:
--------------------------------------------------------------------------------
1 | val
--------------------------------------------------------------------------------
/example/finetune_example/val/data_0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/val/data_0.pkl
--------------------------------------------------------------------------------
/example/finetune_example/val/data_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/val/data_1.pkl
--------------------------------------------------------------------------------
/example/finetune_example/val/data_2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/val/data_2.pkl
--------------------------------------------------------------------------------
/example/finetune_example/val/data_3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/val/data_3.pkl
--------------------------------------------------------------------------------
/example/finetune_example/val/data_4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/finetune_example/val/data_4.pkl
--------------------------------------------------------------------------------
/example/pretrain_example/train.txt:
--------------------------------------------------------------------------------
1 | train
2 |
--------------------------------------------------------------------------------
/example/pretrain_example/train/data_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_1.pkl
--------------------------------------------------------------------------------
/example/pretrain_example/train/data_2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_2.pkl
--------------------------------------------------------------------------------
/example/pretrain_example/train/data_3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_3.pkl
--------------------------------------------------------------------------------
/example/pretrain_example/train/data_4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_4.pkl
--------------------------------------------------------------------------------
/example/pretrain_example/train/data_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_5.pkl
--------------------------------------------------------------------------------
/example/pretrain_example/train/data_6.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_6.pkl
--------------------------------------------------------------------------------
/example/pretrain_example/train/data_7.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_7.pkl
--------------------------------------------------------------------------------
/example/pretrain_example/train/data_8.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/example/pretrain_example/train/data_8.pkl
--------------------------------------------------------------------------------
/example/pretrain_example/val.txt:
--------------------------------------------------------------------------------
1 | train
2 |
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 |
2 | # My code has references to the following repositories:
3 | # DeiT: https://github.com/facebookresearch/deit
4 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
5 | # MAE: https://github.com/facebookresearch/mae
6 | # AdPE: https://github.com/maple-research-lab/AdPE
7 | # --------------------------------------------------------
8 | import os
9 | from ops.argparser import argparser_finetune
10 | import torch
11 | import torch.multiprocessing as mp
12 | import timm
13 | assert timm.__version__ == "0.3.2" # version check
def main(args):
    """
    Launch fine-tuning on every visible GPU.
    Args:
        args: parsed command-line arguments; ``world_size`` is multiplied in
            place by the number of visible GPUs and ``gpu`` is read when
            exactly one GPU is visible.
    """
    import socket
    hostname = socket.gethostname()
    try:
        local_ip = socket.gethostbyname(hostname)
    except OSError:
        # The address is only printed for information; an unresolvable
        # hostname must not abort the run.
        local_ip = "unknown"
    print("local ip: ",local_ip)

    ngpus_per_node = torch.cuda.device_count()
    args.world_size = args.world_size*ngpus_per_node
    from finetune.main_worker import main_worker
    if ngpus_per_node==1:
        # With a single GPU, run the worker in this process.
        main_worker(args.gpu,ngpus_per_node,args)
    else:
        # mp.spawn passes the process index as the worker's first argument.
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
27 |
if __name__ == '__main__':
    import resource
    # Request 8192 open files, clamped to the hard limit: setrlimit raises
    # ValueError when the soft limit exceeds the hard limit.
    _, nofile_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    nofile_soft = 4096*2
    if nofile_hard != resource.RLIM_INFINITY:
        nofile_soft = min(nofile_soft, nofile_hard)
    resource.setrlimit(resource.RLIMIT_NOFILE, (nofile_soft, nofile_hard))
    # Cap the data segment at 900 GiB, or at the existing hard limit when that
    # is lower (an unprivileged process cannot raise its hard limit).
    limit_in_b = 900 * 1024 ** 3
    _, data_hard = resource.getrlimit(resource.RLIMIT_DATA)
    if data_hard != resource.RLIM_INFINITY:
        limit_in_b = min(limit_in_b, data_hard)
    resource.setrlimit(resource.RLIMIT_DATA, (limit_in_b, limit_in_b))
    use_cuda = torch.cuda.is_available()
    print("starting check cuda status",use_cuda)
    # Explicit check instead of assert so it still runs under ``python -O``.
    if not use_cuda:
        raise RuntimeError("CUDA is not available, fine-tuning requires CUDA support to run!")
    parser = argparser_finetune()
    args = parser.parse_args()
    # If the server has many GPUs but only some should be used, restrict them
    # before launching, e.g.:
    #   export CUDA_VISIBLE_DEVICES="0,1,2,3"
    # The specified input size must be a multiple of args.patch_size.
    if args.input_row_size%args.patch_size!=0 or args.input_col_size%args.patch_size!=0:
        print("args configuration error: input_row_size and input_col_size must be a multiple of patch_size")
        raise SystemExit(1)
    main(args)
49 |
--------------------------------------------------------------------------------
/finetune/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/finetune/__init__.py
--------------------------------------------------------------------------------
/finetune/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | #Please check https://pytorch.org/docs/stable/nn.html for more loss functions
3 |
class cosine_distance(nn.Module):
    """Cosine distance ``1 - cos(x, y)`` taken along the last dimension.

    No reduction is applied here: the result has one entry per vector
    compared by ``nn.CosineSimilarity(dim=-1)``.
    """

    def __init__(self):
        super().__init__()
        self.cos = nn.CosineSimilarity(dim=-1, eps=1e-08)

    def forward(self, x, y):
        """Return ``1 - cosine_similarity(x, y)`` over the last dimension."""
        similarity = self.cos(x, y)
        return 1 - similarity
10 |
def configure_loss(args):
    """Build the fine-tuning criterion selected by ``args.loss_type``.

    Args:
        args: Namespace with a ``loss_type`` attribute
            (1 selects ``nn.MSELoss``, 2 selects ``cosine_distance``).

    Returns:
        The instantiated loss module.

    Raises:
        Exception: If ``args.loss_type`` is neither 1 nor 2.
    """
    if args.loss_type == 1:
        return nn.MSELoss()
    if args.loss_type == 2:
        return cosine_distance()
    raise Exception("Unknown loss type: {}".format(args.loss_type))
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/finetune/train_epoch.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import sys
4 | import numpy as np
5 | from typing import Iterable
6 | import torch
7 | import torch.nn.functional as F
8 | import time
9 |
10 | from ops.Logger import MetricLogger,SmoothedValue
11 | import model.lr_sched as lr_sched
12 | from finetune.loss import configure_loss
13 | from ops.train_utils import list_to_device, to_value, create_image, torch_to_nparray, convert_gray_rgbimage
14 |
15 |
def train_epoch(model, data_loader_train, optimizer,
                loss_scaler, epoch, device,
                log_writer=None, args=None):
    """Run one fine-tuning epoch with gradient accumulation.

    Each batch is unpacked by ``list_to_device`` into
    ``(input_matrix, total_count, target_matrix, embed_target, target_vector)``.
    The model returns ``(output_embedding, output_2d, output_1d)``; every
    output whose target is not ``None`` contributes one criterion term and
    the terms are summed into the training loss.

    Args:
        model: Module called as ``model(input_matrix, total_count)``.
        data_loader_train: Sized iterable of training batches.
        optimizer: Optimizer whose learning rate is rescheduled every
            ``args.accum_iter`` steps.
        loss_scaler: Callable performing backward and, when ``update_grad``
            is true, the optimizer step.
        epoch: Index of the current epoch (used for the schedule and logs).
        device: Device the batch is moved to.
        log_writer: Optional TensorBoard-style writer.
        args: Namespace providing ``print_freq``, ``accum_iter`` and the
            fields read by ``configure_loss`` / ``adjust_learning_rate``.

    Returns:
        dict: Meter name mapped to its global average over the epoch.
    """
    model.train()
    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

    header = 'Epoch: [{}]'.format(epoch)
    print_freq = args.print_freq
    accum_iter = args.accum_iter

    optimizer.zero_grad()
    if log_writer is not None:
        print('Tensorboard log dir: {}'.format(log_writer.log_dir))
    num_iter = len(data_loader_train)
    print("number of iterations: ",num_iter)
    criterion = configure_loss(args)

    for data_iter_step, train_data in enumerate(metric_logger.log_every(data_loader_train, print_freq, header)):
        # The schedule is per-iteration (fractional epoch), refreshed at the
        # start of every accumulation window.
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / num_iter + epoch, args)
        input_matrix, total_count, target_matrix, embed_target, target_vector = list_to_device(train_data,device=device)
        output_embedding, output_2d, output_1d = model(input_matrix, total_count)

        if embed_target is not None:
            embedding_loss = criterion(output_embedding, embed_target)
        else:
            embedding_loss = 0
        if target_matrix is not None:
            # Flatten everything after the batch dimension before comparing.
            output_2d_flatten = torch.flatten(output_2d, start_dim=1,end_dim=-1)
            target_matrix_flatten = torch.flatten(target_matrix, start_dim=1,end_dim=-1)
            output_2d_loss = criterion(output_2d_flatten, target_matrix_flatten)
        else:
            output_2d_loss = 0
        if target_vector is not None:
            output_1d_loss = criterion(output_1d, target_vector)
        else:
            output_1d_loss = 0
        # You can adjust the combination below for your fine-tuning purpose;
        # typically only one of the three targets is supplied.
        loss = embedding_loss + output_2d_loss + output_1d_loss

        # Capture the un-normalised loss once, before it is divided by
        # accum_iter, so console meters and TensorBoard report the same value.
        loss_value = to_value(loss)
        metric_logger.update(loss=loss_value)
        metric_logger.update(embedding_loss=to_value(embedding_loss))
        metric_logger.update(output_2d_loss=to_value(output_2d_loss))
        metric_logger.update(output_1d_loss=to_value(output_1d_loss))
        if not math.isfinite(loss_value):
            # Despite the wording of the message, training is not stopped:
            # accumulated gradients are dropped and this batch is skipped.
            print("Loss is {}, stopping training".format(loss_value))
            optimizer.zero_grad()
            continue
        loss = loss / accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)

        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()  # Make sure all gradients are finished computing before moving on
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        if log_writer is not None and ((data_iter_step + 1) % accum_iter == 0 or data_iter_step==0):
            # epoch_1000x is the TensorBoard x-axis; it keeps curves
            # comparable when the batch size changes.
            epoch_1000x = int((data_iter_step / num_iter + epoch) * 1000)
            log_writer.add_scalars('Loss/loss', {'train_loss': loss_value}, epoch_1000x)
            log_writer.add_scalars('Loss/embedding_loss', {'train_loss': to_value(embedding_loss)}, epoch_1000x)
            log_writer.add_scalars('Loss/output_2d_loss', {'train_loss': to_value(output_2d_loss)}, epoch_1000x)
            log_writer.add_scalars('Loss/output_1d_loss', {'train_loss': to_value(output_1d_loss)}, epoch_1000x)
            log_writer.add_scalars('LR/lr', {'lr': lr}, epoch_1000x)
            if ((data_iter_step+1)//accum_iter)%50==0 or data_iter_step==0:
                # Visualise up to eight inputs and 2D outputs.
                new_samples = create_image(input_matrix)
                select_num = min(8,len(new_samples))
                sample_image = torch_to_nparray(new_samples.clone().detach()[:select_num])
                log_writer.add_images('Input_%s'%"train", sample_image, epoch_1000x)
                output_2d_image = convert_gray_rgbimage(output_2d.clone().detach()[:select_num])
                output_2d_image = torch_to_nparray(output_2d_image)
                log_writer.add_images('Output_2d_%s'%"train", output_2d_image, epoch_1000x)
                # Parameter histograms are intentionally not logged: add_histogram
                # raises, see https://github.com/pytorch/pytorch/issues/91516.
                # To enable them, install tensorboardX and import SummaryWriter
                # from tensorboardX in main_worker.py.
    # Gather the stats from all processes.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
/finetune/val_epoch.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import sys
4 | import numpy as np
5 | from typing import Iterable
6 | import torch
7 | import torch.nn.functional as F
8 | import time
9 |
10 | from ops.Logger import MetricLogger,SmoothedValue
11 | import model.lr_sched as lr_sched
12 | from finetune.loss import configure_loss
13 | from finetune.train_epoch import list_to_device, to_value, \
14 | create_image, torch_to_nparray, convert_gray_rgbimage
15 |
def val_epoch(model, data_loader_val, device, epoch,
              log_writer=None, args=None):
    """Evaluate the model on the validation loader for one epoch.

    Mirrors the training loss computation (one criterion term per target
    that is not ``None``) without any optimisation step; the forward pass
    runs under ``torch.no_grad()``.

    Args:
        model: Module called as ``model(input_matrix, total_count)``.
        data_loader_val: Sized iterable of validation batches.
        device: Device the batch is moved to.
        epoch: Index of the current epoch (used in headers and log steps).
        log_writer: Optional TensorBoard-style writer.
        args: Namespace providing ``print_freq``, ``accum_iter`` and the
            fields read by ``configure_loss``.

    Returns:
        dict: Meter name mapped to its global average over the epoch.
    """
    model.eval()
    metric_logger = MetricLogger(delimiter=" ")
    header = "Val Epoch: [{}]".format(epoch)
    print_freq = args.print_freq
    accum_iter = args.accum_iter
    criterion = configure_loss(args)
    num_iter = len(data_loader_val)
    print("number of iterations: ",num_iter)

    def _loss_or_zero(prediction, target, flatten=False):
        # A missing target contributes a plain 0 to the total loss.
        if target is None:
            return 0
        if flatten:
            prediction = torch.flatten(prediction, start_dim=1, end_dim=-1)
            target = torch.flatten(target, start_dim=1, end_dim=-1)
        return criterion(prediction, target)

    for step, val_data in enumerate(metric_logger.log_every(data_loader_val, print_freq, header)):
        input_matrix, total_count, target_matrix, embed_target, target_vector = list_to_device(val_data,device=device)
        with torch.no_grad():
            output_embedding, output_2d, output_1d = model(input_matrix, total_count)

        embedding_loss = _loss_or_zero(output_embedding, embed_target)
        output_2d_loss = _loss_or_zero(output_2d, target_matrix, flatten=True)
        output_1d_loss = _loss_or_zero(output_1d, target_vector)
        loss = embedding_loss + output_2d_loss + output_1d_loss

        metric_logger.update(loss=to_value(loss))
        metric_logger.update(embedding_loss=to_value(embedding_loss))
        metric_logger.update(output_2d_loss=to_value(output_2d_loss))
        metric_logger.update(output_1d_loss=to_value(output_1d_loss))
        torch.cuda.synchronize()

        at_log_step = (step + 1) % accum_iter == 0 or step == 0
        if log_writer is None or not at_log_step:
            continue

        # epoch_1000x is the TensorBoard x-axis; it keeps curves comparable
        # when the batch size changes.
        epoch_1000x = int((step / num_iter + epoch) * 1000)
        scalar_logs = (
            ('Loss/loss', loss),
            ('Loss/embedding_loss', embedding_loss),
            ('Loss/output_2d_loss', output_2d_loss),
            ('Loss/output_1d_loss', output_1d_loss),
        )
        for tag, quantity in scalar_logs:
            log_writer.add_scalars(tag, {'val_loss': to_value(quantity)}, epoch_1000x)

        if ((step + 1) // accum_iter) % 50 == 0 or step == 0:
            # Visualise up to eight inputs and 2D outputs.
            new_samples = create_image(input_matrix)
            select_num = min(8, len(new_samples))
            sample_image = torch_to_nparray(new_samples.clone().detach()[:select_num])
            log_writer.add_images('Input_%s' % "val", sample_image, epoch_1000x)
            output_2d_image = torch_to_nparray(
                convert_gray_rgbimage(output_2d.clone().detach()[:select_num]))
            log_writer.add_images('Output_2d_%s' % "val", output_2d_image, epoch_1000x)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}
73 |
74 |
--------------------------------------------------------------------------------
/hicfoundation_model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/hicfoundation_model/__init__.py
--------------------------------------------------------------------------------
/imgs/framework_github.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/imgs/framework_github.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import timm
3 | assert timm.__version__ == "0.3.2" # version check for timm
4 | from ops.argparser import argparser_infer
5 | from ops.file_format_convert import convert_to_pkl
6 |
def main(args):
    """Convert the input to ``.pkl`` and run HiCFoundation inference on it.

    Steps: print the local IP (diagnostic only), create the output
    directory, convert ``args.input`` with ``convert_to_pkl``, smooth the
    matrix when ``args.task == 1``, then hand over to
    ``inference.main_worker.main_worker``.

    Args:
        args: Parsed command-line namespace; ``output``, ``input``,
            ``resolution`` and ``task`` are read here.
    """
    import socket
    hostname = socket.gethostname()
    try:
        local_ip = socket.gethostbyname(hostname)
    except OSError:
        # The address is only printed; a hostname that does not resolve
        # must not abort inference.
        local_ip = "unknown"
    print("local ip: ",local_ip)
    # Format processing: convert different formats to .pkl for further processing.
    output_dir = os.path.abspath(args.output)
    os.makedirs(output_dir,exist_ok=True)
    input_file = os.path.abspath(args.input)
    config_resolution = args.resolution
    input_pkl=convert_to_pkl(input_file, output_dir,config_resolution)

    # For reproducibility analysis the matrix is smoothed before embeddings
    # are generated.
    if args.task==1:
        from ops.smooth_matrix import smooth_pkl
        smooth_pkl_file = os.path.join(output_dir,"input_smoothed.pkl")
        input_pkl = smooth_pkl(input_pkl,smooth_pkl_file)
        print("Reproducibility analysis smoothed input matrix saved to ",input_pkl)
    from inference.main_worker import main_worker
    main_worker(args, input_pkl)
27 |
28 |
if __name__ == '__main__':
    print("HiCFoundation inference started!")
    parser = argparser_infer()
    args = parser.parse_args()
    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Banner printed for each supported --task value.
    task_banners = {
        1: "Reproducibility analysis",
        2: "Loop calling",
        3: "Resolution enhancement",
        4: "Epigenomic assay prediction",
        5: "scHi-C enhancement",
        6: "Hi-C embedding generation",
    }
    if args.task not in task_banners:
        print("Unknown task specified ",args.task)
        print("Please specify the task using --task with 1,2,3,4,5,6")
        exit(1)
    print(task_banners[args.task])

    if args.task == 6:
        # Embedding generation additionally validates the requested depth.
        embed_depth = args.embed_depth
        if embed_depth > 8:
            print("Error: embed_depth is larger than 8, that is beyond decoder depth. Please set embed_depth<=8")
            print("0 indicates the encoder output, k indicates the k-th decoder layer's output")
            exit(1)

    # The specified input size must be a multiple of args.patch_size on both axes.
    if not (args.input_row_size % args.patch_size == 0 and args.input_col_size % args.patch_size == 0):
        print("args configuration error: input_row_size and input_col_size must be a multiple of patch_size")
        exit(1)
    main(args)
63 |
64 |
--------------------------------------------------------------------------------
/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/inference/__init__.py
--------------------------------------------------------------------------------
/inference/load_model.py:
--------------------------------------------------------------------------------
1 |
2 |
def load_model(model_path,input_row_size,input_col_size, task=6):
    """
    Build the ViT-Large backbone plus fine-tuning head and load weights.

    The checkpoint at ``model_path`` is read twice: once to initialise the
    encoder backbone and once more, from a fresh copy, to initialise the
    full model including the decoder.

    Args:
        model_path (str): The path to the model file.
        input_row_size (int): The number of rows in the input matrix.
        input_col_size (int): The number of columns in the input matrix.
        task (int): Task identifier forwarded to ``Finetune_Model_Head``
            (see Notes).

    Returns:
        model: The loaded model.

    Notes:
        task 0: fine-tuning setting
        task 1: reproducibility analysis
        task 2: loop calling
        task 3: resolution enhancement
        task 4: epigenomic assay prediction
        task 5: scHi-C enhancement
        task 6: embedding analysis
    """
    import torch
    import torch.nn as nn
    import model.Vision_Transformer_count as Vision_Transformer
    from model.pos_embed import interpolate_pos_embed_inputsize
    from model.Finetune_Model_Head import Finetune_Model_Head

    model_name = "vit_large_patch16"
    patch_size = 16
    # Size of the patch grid, in patches per axis.
    patch_wise_size = (input_row_size // patch_size, input_col_size // patch_size)

    # --- encoder -----------------------------------------------------------
    build_backbone = getattr(Vision_Transformer, model_name)
    vit_backbone = build_backbone(img_size=(input_row_size, input_col_size))
    encoder_weights = torch.load(model_path, map_location='cpu')['model']
    backbone_state = vit_backbone.state_dict()
    for k in ('head.weight', 'head.bias'):
        # Drop classifier weights whose shape does not fit this backbone.
        if k in encoder_weights and encoder_weights[k].shape != backbone_state[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del encoder_weights[k]
    interpolate_pos_embed_inputsize(vit_backbone, encoder_weights,
                                    input_size=patch_wise_size, use_decoder=False)
    vit_backbone.load_state_dict(encoder_weights, strict=False)
    print("Loading pre-train encoder message!")

    # --- decoder / task head -----------------------------------------------
    full_model = Finetune_Model_Head(vit_backbone, task=task,
                                     decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                                     mlp_ratio=4., norm_layer=nn.LayerNorm, pos_embed_size=patch_wise_size)
    # Re-read the checkpoint so the encoder-stage edits above do not leak in.
    full_weights = torch.load(model_path, map_location='cpu')['model']
    interpolate_pos_embed_inputsize(full_model, full_weights,
                                    input_size=patch_wise_size, use_decoder=True)
    full_model.load_state_dict(full_weights, strict=False)
    print("Loading pre-train model decoder!")
    return full_model
59 |
def to_cuda(x):
    """
    Move a value to the GPU, passing ``None`` through unchanged.

    Args:
        x (torch.Tensor | int | float | None): Value to move. Plain
            ``int``/``float`` values are wrapped in a tensor first.

    Returns:
        torch.Tensor | None: The value on the GPU, or ``None`` if ``x`` is ``None``.
    """
    import torch
    if x is None:
        return None
    # Exact type match on purpose: only plain int/float scalars are wrapped.
    if type(x) in (int, float):
        x = torch.tensor(x)
    return x.cuda()
78 |
def to_float(x):
    """
    Cast a tensor to float, passing ``None`` through unchanged.

    Args:
        x (torch.Tensor | None): The tensor to convert.

    Returns:
        torch.Tensor | None: ``x.float()``, or ``None`` if ``x`` is ``None``.
    """
    import torch
    return None if x is None else x.float()
94 |
def convert_rgb(data_log,max_value):
    """
    Turn a single-channel matrix into a three-channel red-scale image.

    The first channel is all ones; the other two hold
    ``(max_value - data_log) / max_value``, so a value of 0 maps to 1 and
    ``max_value`` maps to 0.

    Args:
        data_log (torch.Tensor): Input matrix; a 2D tensor gets a leading
            channel dimension added.
        max_value: Scalar (number or one-element tensor) used for scaling.

    Returns:
        torch.Tensor: The three channel groups concatenated along dim 0.
    """
    import torch
    if len(data_log.shape)==2:
        data_log = data_log[None,:,:]
    # An all-zero matrix has max_value == 0; dividing by it would fill the
    # image with NaN. Scaling by 1 instead maps such input to 1 everywhere,
    # the same value zeros get for any positive max_value.
    if max_value == 0:
        max_value = 1
    data_log1 = (max_value-data_log)/max_value
    # ones_like keeps the constant channel on the same device and dtype as
    # the data, so the concatenation also works for non-CPU input.
    data_red = torch.ones_like(data_log1)
    data_rgb = torch.cat([data_red,data_log1,data_log1],dim=0)
    return data_rgb
103 |
def format_input(input):
    """
    Prepare a raw matrix for the model.

    NaNs are replaced via ``torch.nan_to_num``, values are mapped to
    ``log10(x + 1)``, the result is rendered as a three-channel image by
    ``convert_rgb`` and finally normalised with fixed per-channel mean/std.

    Args:
        input (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The formatted input tensor.
    """
    import torch
    import torchvision.transforms as transforms

    normalize = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    cleaned = torch.nan_to_num(input)
    # The maximum is taken on the cleaned raw values, then log-scaled the
    # same way as the data.
    raw_max = torch.max(cleaned)
    log_values = torch.log10(cleaned + 1)
    log_max = torch.log10(raw_max + 1)
    rgb_image = convert_rgb(log_values, log_max)
    return normalize(rgb_image)
--------------------------------------------------------------------------------
/model/NativeScaler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch._six import inf
3 |
class NativeScalerWithGradNormCount:
    """Wrapper around ``torch.cuda.amp.GradScaler`` that reports the gradient norm."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        """Backpropagate the scaled loss and optionally step the optimizer.

        Args:
            loss: Loss to scale and backpropagate.
            optimizer: Optimizer to unscale and step.
            clip_grad: Max gradient norm; when ``None`` the norm is only
                measured, not clipped.
            parameters: Parameters whose gradients are clipped or measured.
            create_graph: Forwarded to ``backward``.
            update_grad: When false, only the backward pass runs.

        Returns:
            The gradient norm, or ``None`` when ``update_grad`` is false.
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            return None

        if clip_grad is None:
            self._scaler.unscale_(optimizer)
            grad_norm = get_grad_norm_(parameters)
        else:
            assert parameters is not None
            # Unscale the gradients of the optimizer's params in place
            # before clipping.
            self._scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        self._scaler.step(optimizer)
        self._scaler.update()
        return grad_norm

    def state_dict(self):
        """Return the wrapped scaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped scaler from ``state_dict``."""
        self._scaler.load_state_dict(state_dict)
31 |
32 |
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Compute the total gradient norm over ``parameters``.

    Args:
        parameters: A single tensor or an iterable of tensors; entries
            whose ``.grad`` is ``None`` are ignored.
        norm_type: Order of the norm; ``float("inf")`` selects the maximum
            absolute gradient entry.

    Returns:
        torch.Tensor: The norm as a scalar tensor (``tensor(0.)`` when no
        parameter has a gradient).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    device = grads[0].device
    # Compare against float("inf") directly so this function does not rely
    # on `inf` re-exported by the private torch._six module.
    if norm_type == float("inf"):
        return max(g.abs().max().to(device) for g in grads)
    per_param_norms = [torch.norm(g, norm_type).to(device) for g in grads]
    return torch.norm(torch.stack(per_param_norms), norm_type)
46 |
--------------------------------------------------------------------------------
/model/Vision_Transformer_count.py:
--------------------------------------------------------------------------------
1 | #adopted from https://github.com/facebookresearch/mae repo
2 |
3 |
4 | from functools import partial
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | import timm.models.vision_transformer
10 |
11 | from model.pos_embed import convert_count_to_pos_embed_cuda
12 |
13 |
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """timm Vision Transformer whose token sequence carries an extra count token."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Keep the construction settings available on the instance.
        self.patch_size = kwargs['patch_size']
        self.in_chans = kwargs['in_chans']
        self.embed_dim = kwargs['embed_dim']

    def forward_features(self, x,total_count):
        """Encode ``x`` together with an embedding of ``log10(total_count)``.

        The token order fed to the transformer blocks is
        ``[cls, count, patches...]``; positional embeddings are added to
        the patch tokens only.

        Returns:
            The normalised token sequence after all blocks.
        """
        batch_size = x.shape[0]
        patch_tokens = self.patch_embed(x)

        log_count = torch.log10(total_count)
        # unsqueeze(1) turns the per-sample embedding into a one-token sequence.
        count_token = convert_count_to_pos_embed_cuda(log_count, self.embed_dim).unsqueeze(1)

        # The cls token is a learnable parameter; no positional embedding is
        # added to it.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        patch_tokens = patch_tokens + self.pos_embed[:, 1:, :]

        tokens = torch.cat((cls_tokens, count_token, patch_tokens), dim=1)
        tokens = self.pos_drop(tokens)

        for block in self.blocks:
            tokens = block(tokens)

        return self.norm(tokens)
46 |
47 |
def vit_base_patch16(**kwargs):
    """Build the 12-layer, 768-dim VisionTransformer with 16x16 patches.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    config = dict(
        in_chans=3,
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**config, **kwargs)
53 |
54 |
def vit_large_patch16(**kwargs):
    """Build the 24-layer, 1024-dim VisionTransformer with 16x16 patches.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    config = dict(
        in_chans=3,
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**config, **kwargs)
60 |
61 |
def vit_huge_patch14(**kwargs):
    """Build the 32-layer, 1280-dim VisionTransformer with 14x14 patches.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    config = dict(
        in_chans=3,
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**config, **kwargs)
67 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/model/__init__.py
--------------------------------------------------------------------------------
/model/lr_decay.py:
--------------------------------------------------------------------------------
1 |
2 | # --------------------------------------------------------
3 | # References:
4 | # ELECTRA https://github.com/google-research/electra
5 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
6 | # MAE: https://github.com/facebookresearch/mae
7 | # --------------------------------------------------------
8 |
9 | import json
10 |
11 |
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
    """
    Build optimizer parameter groups with layer-wise lr decay (encoder).

    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58

    Trainable parameters are grouped by (layer id, decay / no_decay). Each
    group carries ``lr_scale = layer_decay ** (num_layers - layer_id)`` where
    ``num_layers = len(model.blocks) + 1``. 1D parameters and names listed in
    ``no_weight_decay_list`` get a weight decay of 0.

    Returns:
        list[dict]: Groups with keys ``lr_scale``, ``weight_decay``, ``params``.
    """
    num_layers = len(model.blocks) + 1
    layer_scales = [layer_decay ** (num_layers - depth) for depth in range(num_layers + 1)]

    groups = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # No decay: all 1D parameters and model specific ones.
        skip_decay = param.ndim == 1 or name in no_weight_decay_list
        decay_tag = "no_decay" if skip_decay else "decay"
        decay_value = 0. if skip_decay else weight_decay

        layer_id = get_layer_id_for_vit(name, num_layers)
        group_key = "layer_%d_%s" % (layer_id, decay_tag)

        if group_key not in groups:
            groups[group_key] = {
                "lr_scale": layer_scales[layer_id],
                "weight_decay": decay_value,
                "params": [],
            }
        groups[group_key]["params"].append(param)

    return list(groups.values())
59 |
60 |
def get_layer_id_for_vit(name, num_layers):
    """
    Map a parameter name to its layer id for layer-wise lr decay.

    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33

    ``cls_token``, ``pos_embed`` and ``patch_embed*`` map to 0; ``blocks.<i>``
    and ``decoder_blocks.<i>`` map to ``i + 1``; everything else maps to
    ``num_layers``.
    """
    if name in ('cls_token', 'pos_embed') or name.startswith('patch_embed'):
        return 0
    if name.startswith(('blocks', 'decoder_blocks')):
        block_index = name.split('.')[1]
        return int(block_index) + 1
    return num_layers
76 |
77 |
def param_groups_lrd_decoder(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
    """
    Build optimizer parameter groups with layer-wise lr decay (decoder).

    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58

    Same grouping as the encoder variant, except that the layer count is
    ``len(model.decoder_blocks) + 1`` and every parameter skipped because it
    does not require gradients is reported on stdout.

    Returns:
        list[dict]: Groups with keys ``lr_scale``, ``weight_decay``, ``params``.
    """
    num_layers = len(model.decoder_blocks) + 1
    layer_scales = [layer_decay ** (num_layers - depth) for depth in range(num_layers + 1)]

    groups = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            print("no gradient apply for ",name)
            continue

        # No decay: all 1D parameters and model specific ones.
        skip_decay = param.ndim == 1 or name in no_weight_decay_list
        decay_tag = "no_decay" if skip_decay else "decay"
        decay_value = 0. if skip_decay else weight_decay

        layer_id = get_layer_id_for_vit(name, num_layers)
        group_key = "layer_%d_%s" % (layer_id, decay_tag)

        if group_key not in groups:
            groups[group_key] = {
                "lr_scale": layer_scales[layer_id],
                "weight_decay": decay_value,
                "params": [],
            }
        groups[group_key]["params"].append(param)

    return list(groups.values())
126 |
--------------------------------------------------------------------------------
/model/lr_sched.py:
--------------------------------------------------------------------------------
1 | # From MAE: https://github.com/facebookresearch/mae
2 |
3 | import math
4 |
def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate: linear warmup, then half-cycle cosine decay.

    Args:
        optimizer: Object exposing ``param_groups`` (list of dicts).
        epoch: Current (possibly fractional) epoch.
        args: Namespace with ``lr``, ``min_lr``, ``warmup_epochs``, ``epochs``.

    Returns:
        The base learning rate before any per-group ``lr_scale`` is applied.
    """
    warmup = args.warmup_epochs
    if epoch < warmup:
        lr = args.lr * epoch / warmup
    else:
        cosine = math.cos(math.pi * (epoch - warmup) / (args.epochs - warmup))
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1. + cosine)

    for group in optimizer.param_groups:
        # Groups built for layer-wise decay carry their own multiplier.
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
    return lr
18 |
19 |
--------------------------------------------------------------------------------
/model/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ops.distribute_utils import is_main_process
3 | from pathlib import Path
4 | import os
5 |
def load_model(resume_path,args,model_without_ddp, optimizer, loss_scaler):
    """
    Resume model, optimizer and loss-scaler state from a checkpoint.

    Args:
        resume_path: Local path or ``https`` URL of the checkpoint.
        args: Namespace; ``args.start_epoch`` is set to the checkpoint's
            epoch + 1 on success.
        model_without_ddp: The model (loaded with ``strict=False``).
        optimizer: The optimizer.
        loss_scaler: The loss scaler, or ``None``.

    If ``resume_path`` is neither an ``https`` URL nor an existing file, a
    message is printed and nothing is modified.
    """
    # Decide URL vs. file first: an https URL never passes os.path.isfile,
    # so testing isfile first made the download branch unreachable.
    is_url = str(resume_path).startswith('https')
    if not is_url and not os.path.isfile(resume_path):
        print("=> no checkpoint found at '{}'".format(resume_path))
        return

    print("=> loading checkpoint '{}'".format(resume_path))
    if is_url:
        checkpoint = torch.hub.load_state_dict_from_url(
            resume_path, map_location='cpu', check_hash=True)
    else:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(resume_path, map_location='cpu')
    msg=model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
    print("model resume message:{}".format(msg))
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Checkpoints written without a loss scaler store None under 'scaler'.
    scaler_state = checkpoint.get('scaler')
    if loss_scaler is not None and scaler_state is not None:
        loss_scaler.load_state_dict(scaler_state)
    args.start_epoch = checkpoint['epoch'] + 1
    print("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint['epoch']))
30 |
def save_on_master(*args, **kwargs):
    """Forward all arguments to ``torch.save`` when ``is_main_process()`` is true; otherwise do nothing."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
34 |
35 |
36 |
def save_checkpoint(output_dir, args,epoch, model_without_ddp, optimizer, loss_scaler):
    """Write ``checkpoint-<epoch>.pth`` into ``output_dir`` via ``save_on_master``.

    The checkpoint holds the model and optimizer state dicts, the epoch,
    the loss-scaler state (``None`` when no scaler is given) and ``args``.
    """
    checkpoint_path = Path(output_dir) / ('checkpoint-%s.pth' % str(epoch))
    state = {
        'model': model_without_ddp.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': None if loss_scaler is None else loss_scaler.state_dict(),
        'args': args,
    }
    save_on_master(state, checkpoint_path)
52 |
def save_model2path(model_path,args,epoch, model_without_ddp, optimizer, loss_scaler):
    """Write a checkpoint to the exact path ``model_path`` via ``save_on_master``.

    The checkpoint holds the model and optimizer state dicts, the epoch,
    the loss-scaler state (``None`` when no scaler is given) and ``args``.
    """
    state = {
        'model': model_without_ddp.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': None if loss_scaler is None else loss_scaler.state_dict(),
        'args': args,
    }
    save_on_master(state, model_path)
62 |
63 |
--------------------------------------------------------------------------------
/ops/Logger.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, deque
2 | import torch
3 | from ops.distribute_utils import is_dist_avail_and_initialized
4 | import torch.distributed as dist
5 | import datetime
6 | import os
7 | import time
8 |
class SmoothedValue:
    """Running statistic: a bounded window of recent values plus global totals.

    The window (``deque``) feeds ``median``, ``avg``, ``max`` and ``value``;
    ``total`` and ``count`` feed ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value`` once in the window and ``n`` times in the totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(packed)
        synced_count, synced_total = packed.tolist()
        self.count = int(synced_count)
        self.total = synced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Mean over everything recorded so far (``total / count``)."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)
69 |
70 |
class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and print them periodically.

    Meters are created on first use by ``update`` and can be read back as
    attributes (``logger.loss`` returns the meter stored under ``"loss"``).
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword argument.

        ``None`` values are skipped; tensors are converted with ``.item()``.
        """
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup failed. Read 'meters'
        # through __dict__: touching self.meters here would re-enter
        # __getattr__ and recurse forever on an instance that has no
        # 'meters' yet (e.g. one created without running __init__).
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize every meter's count/total across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log(self, iteration, print_freq, header=None):
        """Print the current meters every ``print_freq`` iterations.

        Nothing is printed unless ``iteration`` is a multiple of
        ``print_freq``. (The previous condition was inverted and printed on
        every iteration *except* those multiples.)
        """
        if iteration % print_freq != 0:
            return

        if not header:
            header = ''
        # Join the pieces directly instead of going through str.format so
        # that braces inside the header or meter text cannot break the call.
        fields = [header, '[' + str(iteration) + ']', str(self)]
        if torch.cuda.is_available():
            MB = 1024.0 * 1024.0
            fields.append('max mem: {:.0f}'.format(
                torch.cuda.max_memory_allocated() / MB))
        print(self.delimiter.join(fields))

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from ``iterable`` while printing progress lines.

        A line with ETA, meters, iteration time and data-loading time is
        printed every ``print_freq`` items and on the last item; a total-time
        summary is printed once the iterable is exhausted.
        """
        i = 0
        if not header:
            header = ''
        num_iters = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the running index to the width of the total count.
        space_fmt = ':' + str(len(str(num_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the item vs. time for the whole step.
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_iters - 1:
                eta_seconds = iter_time.global_avg * (num_iters - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB), flush=True)
                else:
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)), flush=True)
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) keeps the summary from dividing by zero on an empty iterable.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(num_iters, 1)), flush=True)
177 |
178 |
def print_important_info(str):
    """Print ``"Important:" + str`` framed by two rules of 50 '=' characters."""
    rule = "=" * 50
    print(rule)
    print("Important:" + str)
    print(rule)
183 |
def print_warning_info(str):
    """Print ``"Warning:" + str`` framed by two rules of 50 '*' characters."""
    rule = "*" * 50
    print(rule)
    print("Warning:" + str)
    print(rule)
--------------------------------------------------------------------------------
/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/ops/__init__.py
--------------------------------------------------------------------------------
/ops/calculate_similarity.py:
--------------------------------------------------------------------------------
1 | # This script is to calculate the similarity between two Hi-C using a pre-trained reproducibility model.
2 |
3 | import os
4 | import sys
5 | import numpy as np
6 | import pickle
7 | from collections import defaultdict
8 |
# Command-line arguments: paths of the two pickle files to compare.
input_pickle1 = sys.argv[1]
input_pickle2 = sys.argv[2]
11 |
def load_pickle(file_path):
    """Unpickle and return the object stored in ``file_path``."""
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)
16 |
# Load both pickles when the script starts; calculate_similarity() below
# indexes them by string keys that it splits on ":".
input1 = load_pickle(input_pickle1)
input2 = load_pickle(input_pickle2)
19 |
def find_key(chr,loc,key_list):
    """
    Return the first spelling of the (chromosome, location) key that is
    present in ``key_list``, or None when none of them is.

    Spellings are tried in this order: "<chr>:<loc>", "chr<chr>:<loc>",
    "<chr>_<chr>:<loc>", "chr<chr>_chr<chr>:<loc>".
    """
    candidates = (
        chr + ":" + loc,
        "chr" + chr + ":" + loc,
        chr + "_" + chr + ":" + loc,
        "chr" + chr + "_chr" + chr + ":" + loc,
    )
    for candidate in candidates:
        if candidate in key_list:
            return candidate
    return None
37 |
def calculate_similarity(input1, input2):
    """
    Average cosine similarity between matching embeddings of two inputs.

    Each key of ``input1`` (e.g. "1_1:1960,1960") is matched to a key of
    ``input2``, first verbatim and then through find_key(). Cosine
    similarities are averaged per chromosome, and the per-chromosome means
    are averaged again; chromosomes whose name contains "Y", "M", "Un" or
    "Alt" are left out of the final mean.
    """
    per_chrom_scores = defaultdict(list)
    for key in input1.keys():
        # 1_1:1960,1960 format of key
        key_parts = key.split(":")
        split_chromosome = key_parts[0]
        split_loc = key_parts[1]
        chrom = split_chromosome.split("_")[0].replace("chr", "")
        match_key = split_chromosome + ":" + split_loc
        if match_key not in input2.keys():
            # Fall back to the alternative chromosome spellings.
            match_key = find_key(chrom, split_loc, input2.keys())
        if match_key is None:
            continue

        vec_a = input1[key]
        vec_b = input2[match_key]
        cosine = np.dot(vec_a, vec_b) / (np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
        if np.isnan(cosine):
            continue
        per_chrom_scores[chrom].append(cosine)

    # ignore chrY, chrM, Un, Alt cases
    excluded_tags = ("Y", "M", "Un", "Alt")
    chrom_means = [
        np.mean(scores)
        for chrom, scores in per_chrom_scores.items()
        if not any(tag in chrom for tag in excluded_tags)
    ]
    return np.mean(chrom_means)
71 |
# Score the two loaded inputs and report the result on stdout.
similarity = calculate_similarity(input1, input2)
print("The reproducibility score between the two Hi-C is: ", similarity)
74 |
--------------------------------------------------------------------------------
/ops/distribute_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import datetime
4 | import builtins
5 | import numpy as np
6 | import torch.distributed as dist
7 |
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and a process
    group has been initialized."""
    return dist.is_available() and dist.is_initialized()
def get_rank():
    """Return this process's distributed rank, or 0 when no process group
    is initialized."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def get_world_size():
    """Return the number of distributed processes, or 1 when no process
    group is initialized."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def all_reduce_mean(x):
    """Return the mean of ``x`` across all processes.

    With a single process ``x`` is returned untouched. Otherwise ``x`` is
    moved to the GPU, summed with all_reduce, divided by the world size and
    returned as a Python scalar.
    """
    world_size = get_world_size()
    if world_size <= 1:
        return x
    x_reduce = torch.tensor(x).cuda()
    dist.all_reduce(x_reduce)
    # NOTE(review): in-place division presumably expects a floating-point x.
    x_reduce /= world_size
    return x_reduce.item()
def is_main_process():
    """Return True when get_rank() reports rank 0."""
    rank = get_rank()
    return rank == 0
def setup_for_distributed(is_master):
    """
    Replace ``builtins.print`` with a version that is silent on non-master
    processes and prefixes each call with the current time on the master.
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        if not is_master:
            return
        now = datetime.datetime.now().time()
        builtin_print('[{}] '.format(now), end='')  # print with time stamp
        builtin_print(*args, **kwargs)

    builtins.print = print
47 |
def init_distributed_mode(gpu,ngpus_per_node,args):
    """Initialize the NCCL process group for one worker process.

    Sets ``args.gpu``, turns ``args.rank`` into the global rank
    (``args.rank * ngpus_per_node + gpu``), exports LOCAL_RANK / RANK /
    WORLD_SIZE, selects the CUDA device, initializes torch.distributed,
    silences ``print`` on non-zero ranks and seeds torch and numpy with
    ``args.seed + rank``.

    Args:
        gpu: local GPU index of this worker.
        ngpus_per_node: number of GPUs (workers) on this node.
        args: namespace providing ``rank``, ``world_size``, ``dist_url`` and
            ``seed``; it is updated in place.
    """
    import resource
    # Make sure at least 2048 file descriptors are allowed, but never lower a
    # soft limit that is already higher (the launcher may have raised it) and
    # never ask for more than the hard limit, which would raise ValueError.
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    wanted = 2048
    if hard != resource.RLIM_INFINITY:
        wanted = min(wanted, hard)
    if soft != resource.RLIM_INFINITY and soft < wanted:
        resource.setrlimit(resource.RLIMIT_NOFILE, (wanted, hard))

    args.gpu = gpu
    args.rank = args.rank * ngpus_per_node + gpu
    os.environ['LOCAL_RANK'] = str(args.gpu)
    os.environ['RANK'] = str(args.rank)
    os.environ['WORLD_SIZE'] = str(args.world_size)
    print("make sure the distributed mode is ",args.dist_url)

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         timeout=datetime.timedelta(seconds=36000),
                                         world_size=args.world_size, rank=args.rank)

    # From here on only rank 0 prints.
    setup_for_distributed(args.rank == 0)

    # fix the seed for reproducibility (offset by rank so workers differ)
    seed = args.seed + get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
80 |
--------------------------------------------------------------------------------
/ops/file_format_convert.py:
--------------------------------------------------------------------------------
1 | from utils.hic2array import hic2array
2 | from utils.cool2array import cool2array_intra
3 | from utils.array2hic import array2hic
4 | from utils.array2cool import array2cool
5 | import os
6 | import numpy as np
7 | from ops.sparse_ops import array_to_coo
8 | from scipy.sparse import coo_matrix
9 | from collections import defaultdict
def write_pkl(return_dict,output_pkl_path):
    """Pickle ``return_dict`` to ``output_pkl_path`` and print a confirmation."""
    import pickle
    with open(output_pkl_path, 'wb') as handle:
        pickle.dump(return_dict, handle)
    print("finish writing to:", output_pkl_path)
def load_pkl(input_pkl):
    """Unpickle and return the object stored in ``input_pkl``."""
    import pickle
    with open(input_pkl, 'rb') as handle:
        return pickle.load(handle)
20 |
def read_text(input_file,config_resolution):
    """Read pairwise contact records from a whitespace-separated text file.

    Each record is expected as ``readID chr1 pos1 chr2 pos2``. Positions are
    converted to bin indices with ``pos // config_resolution``.

    Args:
        input_file: path of the text file.
        config_resolution: bin size used for the integer division.

    Returns:
        defaultdict(list) mapping ``(chr1, chr2)`` to a list of
        ``(bin1, bin2)`` tuples, one per accepted record.

    Lines that are too short or whose positions are not integers are
    reported and skipped. Any other error (for example a zero
    ``config_resolution``) is no longer swallowed and propagates.
    """
    #records should be readID chr1 pos1 chr2 pos2
    #read line by line to get the sparse matrix
    final_dict = defaultdict(list)
    with open(input_file, 'r') as f:
        for line in f:
            line = line.strip().split()
            try:
                chr1 = line[1]
                chr2 = line[3]
                pos1 = int(line[2]) // config_resolution
                pos2 = int(line[4]) // config_resolution
            except (IndexError, ValueError):
                # Malformed record: too few columns or a non-integer position.
                print("*"*40)
                print("Skip line in records:",line)
                print("The line should be in format of [readID chr1 pos1 chr2 pos2]")
                print("*"*40)
                continue
            final_dict[(chr1, chr2)].append((pos1, pos2))
    return final_dict
40 |
def countlist2coo(input_dict):
    """Turn lists of (row, col) hits into square COO count matrices.

    For every key, each pair contributes a count of 1; the matrix side is
    the largest index seen plus one, and repeated pairs are summed.
    """
    final_dict = {}
    for key, pairs in input_dict.items():
        row = [pair[0] for pair in pairs]
        col = [pair[1] for pair in pairs]
        data = [1] * len(row)
        max_size = max(max(row), max(col)) + 1
        cur_array = coo_matrix((data, (row, col)), shape=(max_size, max_size))
        #sum duplicates
        cur_array.sum_duplicates()
        final_dict[key] = cur_array
    return final_dict
def convert_to_pkl(input_file, output_dir,config_resolution):
    """Convert a supported input file into ``<output_dir>/input.pkl``.

    The branch is chosen from the file extension (.hic, .cool, .pkl,
    .txt/.pairs, .npy); any other extension raises ValueError.

    Returns:
        Path of the written pickle.
    """
    output_pkl = os.path.join(output_dir, "input.pkl")
    if input_file.endswith('.hic'):
        # convert to .pkl format, only keep intra-chromosome regions
        hic2array(input_file, output_pkl=output_pkl,
                  resolution=config_resolution, normalization="NONE",
                  tondarray=2)
    elif input_file.endswith('.cool'):
        cool_dict = cool2array_intra(input_file, normalize=False,
                                     tondarray=False, binsize=config_resolution)
        write_pkl(cool_dict, output_pkl)
    elif input_file.endswith('.pkl'):
        # load the pickle for a sanity check before re-writing it
        loaded = load_pkl(input_file)
        if not isinstance(loaded, dict):
            raise ValueError("Input pkl file should be a dictionary")
        final_dict = {}
        for key, value in loaded.items():
            if isinstance(value, np.ndarray):
                final_dict[key] = array_to_coo(value)
            elif isinstance(value, coo_matrix):
                final_dict[key] = value
            else:
                raise ValueError("The value of the dictionary in .pkl should be either numpy array or coo_matrix")
        write_pkl(final_dict, output_pkl)
    elif input_file.endswith(('.txt', '.pairs')):
        initial_dict = read_text(input_file, config_resolution)
        # keep intra-chromosome records only, keyed by the chromosome name
        intra_dict = {
            key[0]: initial_dict[key]
            for key in initial_dict
            if key[0] == key[1]
        }
        # then change it to coo_matrix array
        write_pkl(countlist2coo(intra_dict), output_pkl)
    elif input_file.endswith('.npy'):
        # a single dense array is stored under a placeholder chromosome name
        dense = np.load(input_file)
        write_pkl({"chr_tmp": array_to_coo(dense)}, output_pkl)
    else:
        print("Unsupported file format ", input_file)
        print("Supported file format: .hic/.cool/.pkl/.txt/.npy")
        raise ValueError("Unsupported file format")

    return output_pkl
110 |
def pkl2others(input_pkl, output_file,config_resolution,genome_id):
    """Export the matrices stored in ``input_pkl`` to another file format.

    The format is chosen from the extension of ``output_file``:
    .txt/.pairs, .npy, .hic, .cool or .pkl.

    Args:
        input_pkl: pickle holding a dict of per-chromosome matrices
            (numpy arrays or scipy sparse matrices).
        output_file: destination path.
        config_resolution: bin size used to turn bin indices into positions.
        genome_id: forwarded to the .hic / .cool writers.

    Returns:
        The path the data finally lives in: ``output_file`` on success, or
        ``input_pkl`` when the requested format is .pkl or unsupported.
        ``None`` is returned when .npy is requested for more than one
        chromosome (nothing is written in that case).
    """
    output_dir = os.path.dirname(output_file)
    # dirname is '' for a bare file name, and os.makedirs('') raises.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    data = load_pkl(input_pkl)
    if output_file.endswith(('.txt', '.pairs')):
        #write to simple txt
        # [chr1, pos1, chr2, pos2, count]
        with open(output_file, 'w') as file:
            file.write("#readID\tchr1\tpos1\tchr2\tpos2\tcount\n")
            for chrom in data:
                matrix = data[chrom]
                if isinstance(matrix, np.ndarray):
                    # Positive entries only, visited in row-major order.
                    rows, cols = np.nonzero(matrix > 0)
                    counts = matrix[rows, cols]
                else:
                    # tocoo() lets other scipy sparse formats through as well.
                    matrix = matrix.tocoo()
                    rows, cols, counts = matrix.row, matrix.col, matrix.data
                for row, col, count in zip(rows, cols, counts):
                    file.write(f".\t{chrom}\t{row*config_resolution}\t{chrom}\t{col*config_resolution}\t{count}\n")
    elif output_file.endswith('.npy'):
        if len(data) > 1:
            print("Warning: multiple chromosomes detected, please check the output in .pkl format:",input_pkl)
            print("The format is dict in format of [chr]:[scipy.sparse.coo_matrix]")
            return
        current_array = data[list(data.keys())[0]]
        if isinstance(current_array, coo_matrix):
            current_array = current_array.toarray()
        np.save(output_file, current_array)
    elif output_file.endswith('.hic'):
        #https://github.com/aidenlab/juicer/wiki/Pre
        # juicer_tools.jar is looked up in <repo>/utils, one level above this file's directory.
        cur_py_path = os.path.abspath(__file__)
        cur_py_dir = os.path.dirname(cur_py_path)
        code_repo_dir = os.path.dirname(cur_py_dir)
        juicer_tools = os.path.join(code_repo_dir, "utils", "juicer_tools.jar")
        array2hic(juicer_tools, input_pkl, output_file, config_resolution, genome_id, 1)
    elif output_file.endswith('.cool'):
        array2cool(input_pkl, output_file, config_resolution, genome_id, 1)
    elif output_file.endswith('.pkl'):
        # Already in pickle form: report the input pickle as the result.
        output_file = input_pkl
    else:
        print("Unsupported file format ",output_file)
        output_file = input_pkl
    print("Final output is saved in ",output_file)
    return output_file
--------------------------------------------------------------------------------
/ops/io_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import pickle
3 | import os
4 | import json
def load_pickle(path):
    """Unpickle and return the object stored at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)
9 |
def write_pickle(data,path):
    """Pickle ``data`` into the file at ``path``, overwriting it."""
    with open(path, 'wb') as handle:
        pickle.dump(data, handle)
13 |
14 | def append_record(bedpe_path_min,min_loc_list,chrom,resolution=10000):
15 | if "_" in chrom:
16 | chrom = chrom.split("_")[0]
17 |
18 | with open(bedpe_path_min,'a') as wfile:
19 | for loc in min_loc_list:
20 | x1,x2 = loc
21 | if x1=xdim:
27 | endp[0]=xdim
28 | if endp[1]>=ydim:
29 | endp[1]=ydim
30 | dtotal=0
31 | pos2=np.zeros(2)
32 | for xp in range(int(stp[0]),int(endp[0])):
33 | rx=float(xp-pos[0])**2
34 | for yp in range(int(stp[1]),int(endp[1])):
35 | ry=float(yp-pos[1])**2
36 |
37 | d2=rx+ry
38 | v=np.exp(-1.5*d2*fsiv)*dens[xp,yp]#This is the bottom part of the equation, where pos represents y, (xp,yp,zp) represents xi
39 | dtotal+=v
40 | if v>0:
41 | pos2[0]+=v*(float)(xp)#pos2 is for the top part of the equation
42 | pos2[1]+=v*(float)(yp)
43 |
44 | if dtotal==0:
45 | break
46 | rd=1.00/float(dtotal)
47 | tempcd=np.zeros(2)
48 | for j in range(2):
49 | pos2[j]*=rd#Now we get the equation result
50 | tempcd[j]=pos[j]-pos2[j]
51 | pos[j]=pos2[j]#Prepare for iteration
52 | check_d=tempcd[0]**2+tempcd[1]**2#Iteration until you find the place is stable
53 | if check_d<0.001:
54 | break
55 |
56 | for j in range(2):
57 |
58 | point_cd[i][j]=pos[j]
59 | point_dens[i]=dtotal/cnt
60 | return point_cd,point_dens
61 | @jit(nogil=True,nopython=True)
62 | def acc_merge_point(Ncd,dens,dmin,rv_range,rdcut,stock,cd,d2cut,member):
63 | if True:
64 | for i in range(Ncd-1):
65 | if i%10000==0:
66 | print(i)
67 | tmp=np.zeros(2)
68 | if (dens[i]-dmin)*rv_range < rdcut:
69 | stock[i]=0#Label the small density parts as unused parts
70 | if stock[i]==0:
71 | continue
72 | for j in range(i+1,Ncd):
73 | if stock[j]==0:
74 | continue
75 | d2=0
76 | for k in range(2):
77 | tmp[k]=cd[i][k]-cd[j][k]
78 | d2+=tmp[k]**2
79 | if d2dens[j]:
82 | stock[j]=0
83 | member[j]=i
84 | else:
85 | stock[i]=0
86 | member[i]=j
87 | break#jump out of the second rotation, since i has been merged
88 | #Update member data, to updata some son/grandson points to original father point
89 | for i in range(Ncd):
90 | now=int(member[i])
91 | while now!=member[now]:#If it's not merged points, it will totates to find the father point(merged point)
92 | now=int(member[now])
93 | member[i]=now
94 | return stock,member
95 |
96 | @jit(nogil=True,nopython=True)
97 | def further_merge_point(count_loc, new_density, stock,
98 | new_location, d2cut,member):
99 | #before_count = len(np.argwhere(stock==1))
100 | for i in range(count_loc-1):
101 | if i%10000==0:
102 | print(i)
103 | member_i= member[i]
104 | for j in range(i+1,count_loc):
105 | member_j = member[j]
106 | if member_i==member_j:
107 | continue
108 | d2=0
109 | for k in range(2):
110 | d2+=(new_location[i][k]-new_location[j][k])**2
111 | if d2member_j_cluster_dens:
115 | all_other_index = member==member_j
116 | stock[all_other_index]=0
117 | member[all_other_index]=member_i
118 | else:
119 | all_other_index = member==member_i
120 | stock[all_other_index]=0
121 | member[all_other_index]=member_j
122 | #after_count = len(np.argwhere(stock==1))
123 | #print("after further merge, before count: %d, after count: %d"%(before_count,after_count))
124 | return stock,member
125 |
126 |
def mean_shift_merge(predict_array,cutoff=0.1):
    """Run mean-shift on the above-cutoff pixels and merge nearby results.

    Pixels strictly above ``cutoff`` (or ``>= cutoff`` when cutoff is 1) are
    shifted with carry_shift(), then merged by acc_merge_point() and
    further_merge_point() using a squared-distance threshold of 4.

    Returns:
        Stacked array of the surviving shifted locations, or an empty list
        when carry_shift() yields nothing.
    """
    # mean-shift kernel settings
    bandwidth = 2
    gstep = 1
    half_ratio = (bandwidth / gstep) * 0.5
    fsiv = 1 / (float(half_ratio * half_ratio))
    fmaxd = (bandwidth / gstep) * 2.0

    selected = predict_array >= cutoff if cutoff == 1 else predict_array > cutoff
    location = np.argwhere(selected)
    count_loc = len(location)
    tmp_xdim, tmp_ydim = predict_array.shape[0], predict_array.shape[1]

    new_location, new_density = carry_shift(location, count_loc, fmaxd, fsiv,
                                            tmp_xdim, tmp_ydim, predict_array)
    if len(new_location) == 0 or len(new_density) == 0:
        return []

    dmin = np.min(new_density)
    drange = np.max(new_density) - dmin
    print('here we get the density range %f' % drange)
    rv_range = 1.0 / drange
    rdcut = 0.01
    d2cut = 2 ** 2  # important merge criteria
    stock = np.ones(count_loc)
    member = np.arange(count_loc)

    stock, member = acc_merge_point(count_loc, new_density, dmin,
                                    rv_range, rdcut, stock,
                                    new_location, d2cut, member)
    # further merge: any two sub-points closer than the threshold are still merged
    stock, member = further_merge_point(count_loc, new_density, stock,
                                        new_location, d2cut, member)
    # keep only the points still flagged as in use
    survivors = [new_location[i] for i in range(count_loc) if stock[i] == 1]
    return np.stack(survivors, axis=0)
--------------------------------------------------------------------------------
/ops/smooth_matrix.py:
--------------------------------------------------------------------------------
1 |
2 | from ops.sparse_ops import sym_mat
3 | from scipy.sparse import triu
4 | import os
5 | import numpy as np
6 | import scipy.sparse as sp
7 | import math
8 |
def trimDiags(a: sp.coo_matrix, iDiagMax: int, bKeepMain: bool):
    """Keep only the entries whose diagonal offset is below a cutoff.

    Args:
        a: Input scipy coo_matrix
        iDiagMax: Entries with ``|row - col| >= iDiagMax`` are dropped
        bKeepMain: If true, main-diagonal entries are kept; otherwise they
            are dropped as well

    Returns:
        coo_matrix of the same shape and dtype holding the kept entries
    """
    offsets = np.abs(a.row - a.col)
    keep = (offsets < iDiagMax) & (bKeepMain | (offsets != 0))
    kept_coords = (a.row[keep], a.col[keep])
    return sp.coo_matrix((a.data[keep], kept_coords),
                         shape=a.shape, dtype=a.dtype)
def meanFilterSparse(a: sp.coo_matrix, h: int):
    """Mean-filter a square sparse matrix with a (2*h + 1)-wide window.

    Window sums are obtained by multiplying with a banded matrix of ones on
    both sides; the result is cropped back to the input shape and each
    entry is divided by the number of in-bounds neighbours of its position.

    Args:
        a: `sp.coo_matrix`, Input matrix to be filtered
        h: `int` half-size of the filter

    Returns:
        `sp.coo_matrix` filtered matrix with the same shape as ``a``
    """
    assert h > 0, "meanFilterSparse half-size must be greater than 0"
    assert sp.issparse(a) and a.getformat() == 'coo',\
        "meanFilterSparse input matrix is not scipy.sparse.coo_matrix"
    assert a.shape[0] == a.shape[1],\
        "meanFilterSparse cannot handle non-square matrix"
    n_rows, n_cols = a.shape
    kernel_width = 2 * h + 1
    # the filter is a square block of ones of shape (kernel_width, kernel_width)
    padded_shape = np.array(a.shape) + kernel_width - 1
    band = sp.diags(np.ones(kernel_width),
                    np.arange(-kernel_width+1, 1),
                    shape=(padded_shape[1], n_cols),
                    format='csr')
    window_sums = sp.coo_matrix((band @ a) @ band.T)
    # drop the padding rows/columns; only the part aligned with the input matters
    cropped = window_sums.tocsr()[h:(h+n_rows), h:(h+n_cols)].tocoo()
    # Positions near the border have fewer neighbours, to better match what
    # the original R implementation of HiCRep does
    row_edge_dist = np.minimum(cropped.row, cropped.shape[0] - 1 - cropped.row)
    col_edge_dist = np.minimum(cropped.col, cropped.shape[1] - 1 - cropped.col)
    neighbours = (h + 1 + np.minimum(row_edge_dist, h)) * \
        (h + 1 + np.minimum(col_edge_dist, h))
    cropped.data /= neighbours
    return cropped
63 |
64 |
def smooth_matrix(input_dict,dMax=200,hsize=11,max_value=1000):
    """Clip, symmetrize, band-limit and mean-filter every matrix in a dict.

    Matrices with 224 rows or fewer are left out of the result. The
    remaining ones keep only off-diagonal entries with offset below ``dMax``
    (all offsets when ``dMax`` is negative), are mean-filtered with
    half-size ``hsize`` when it is positive, and are returned as their
    upper triangle in COO format.

    NOTE: the ``data`` of each input matrix is clipped to ``max_value`` in
    place, so the caller's matrices are modified.
    """
    size_limit = 224
    smoothed = {}
    for key, current_mat in input_dict.items():
        current_mat.data = np.minimum(max_value, current_mat.data)
        current_mat = sym_mat(current_mat)
        n_bins = current_mat.shape[0]
        if n_bins <= size_limit:
            continue
        nDiags = min(dMax, n_bins) if dMax >= 0 else n_bins
        m1 = trimDiags(current_mat, nDiags, False)
        if hsize > 0:
            # apply smoothing
            m1 = meanFilterSparse(m1, hsize)
        smoothed[key] = triu(m1, 0, format='coo')
    return smoothed
85 | from ops.io_utils import load_pickle, write_pickle
def smooth_pkl(input_pkl,output_pkl,
               dMax=200,hsize=11,max_value=1000):
    """Load a pickled dict of matrices, smooth it with smooth_matrix() and
    pickle the result to ``output_pkl``, whose path is returned."""
    smoothed = smooth_matrix(load_pickle(input_pkl), dMax, hsize, max_value)
    write_pickle(smoothed, output_pkl)
    return output_pkl
--------------------------------------------------------------------------------
/ops/sparse_ops.py:
--------------------------------------------------------------------------------
1 | from scipy.sparse import triu,coo_matrix
2 | import numpy as np
def sym_mat(input_array):
    """Mirror the strictly-upper triangle of a COO matrix below the diagonal.

    All original entries are kept; the transposed strict upper triangle is
    appended to them, and the result has the input's shape.
    """
    mirrored = triu(input_array, 1, format='coo').T
    rows = np.concatenate((input_array.row, mirrored.row))
    cols = np.concatenate((input_array.col, mirrored.col))
    values = np.concatenate((input_array.data, mirrored.data))
    return coo_matrix((values, (rows, cols)), shape=input_array.shape)
11 |
12 |
def array_to_coo(array):
    """
    Build a scipy.sparse.coo_matrix from the non-zero entries of a 2D array.

    Parameters:
    - array (numpy.ndarray): The input 2D array.

    Returns:
    - scipy.sparse.coo_matrix: Matrix of the same shape holding only the
      non-zero values.
    """
    nz_rows, nz_cols = np.nonzero(array)
    nz_values = array[nz_rows, nz_cols]
    return coo_matrix((nz_values, (nz_rows, nz_cols)), shape=array.shape)
33 |
34 | def filter_sparse_region(input_row,input_col,input_data,
35 | start_index,end_index):
36 | """
37 | input_row: the row index of the sparse matrix
38 | input_col: the column index of the sparse matrix
39 | input_data: the data of the sparse matrix
40 | start_index: the start index of the region to filter
41 | end_index: the end index of the region to filter
42 | """
43 | select_index1 = (input_row>=start_index) & (input_row=start_index) & (input_col=start_row) & (input_row=start_col) & (input_colbchw",samples,imagenet_std)
33 | new_samples = torch.clip((new_samples+ imagenet_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)) * 255, 0, 255)
34 | return new_samples
35 |
def torch_to_nparray(data):
    """Move a tensor to the CPU and return it as a uint8 numpy array.

    Values are cast to uint8 as-is; no scaling or axis reordering is done.
    """
    #https://github.com/pytorch/pytorch/blob/main/torch/utils/tensorboard/summary.py
    #image take n,c,h,w,
    # That summary code accepts [0, 1] float32 or [0, 255] uint8 images;
    # this helper produces the uint8 form.
    host_array = data.cpu().numpy()
    return np.array(host_array, dtype=np.uint8)
50 |
def convert_gray_rgbimage(samples):
    """
    Map intensities to a white-to-red image with values in [0, 255].

    input: B,H,W (a channel axis is inserted at dim 1 when the input is 3-D)

    Intensities are clipped to [0, 1]. The red channel is constant 255 and
    the green and blue channels are ``(1 - intensity) * 255``.
    """
    if len(samples.shape) == 3:
        # add the channel dimension
        samples = samples.unsqueeze(1)
    clipped = torch.clip(samples, 0, 1)
    faded = 1 - clipped
    full_red = torch.ones(clipped.shape, device=clipped.device)
    return torch.cat([full_red, faded, faded], dim=1) * 255
--------------------------------------------------------------------------------
/pretrain.py:
--------------------------------------------------------------------------------
1 | # My code has references to the following repositories:
2 | # DeiT: https://github.com/facebookresearch/deit
3 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
4 | # MAE: https://github.com/facebookresearch/mae
5 | # AdPE: https://github.com/maple-research-lab/AdPE
6 | # --------------------------------------------------------
7 | import os
8 | from ops.argparser import argparser_pretrain
9 | import torch
10 | import torch.multiprocessing as mp
11 | import timm
12 | assert timm.__version__ == "0.3.2" # version check
13 |
14 |
def main(args):
    """Launch pre-training on every visible GPU of this node.

    ``args.world_size`` is multiplied by the local GPU count. With exactly
    one GPU the worker runs in this process; otherwise one process per GPU
    is spawned.
    """
    import socket
    hostname = socket.gethostname()
    # The address is only printed for diagnostics, so a host name that does
    # not resolve must not abort the run.
    try:
        local_ip = socket.gethostbyname(hostname)
    except OSError as err:
        local_ip = "unknown ({})".format(err)
    print("local ip: ",local_ip)

    ngpus_per_node = torch.cuda.device_count()
    args.world_size = args.world_size*ngpus_per_node
    from pretrain.main_worker import main_worker
    if ngpus_per_node==1:
        main_worker(args.gpu,ngpus_per_node,args)#if you only have one gpu
    else:
        # mp.spawn passes the process index as the first argument (the gpu id).
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
28 |
29 |
30 |
31 |
if __name__ == '__main__':
    import resource
    import sys

    def _clamp_to_hard(wanted, hard):
        """Return ``wanted`` capped at ``hard`` unless the hard limit is unlimited."""
        if hard == resource.RLIM_INFINITY:
            return wanted
        return min(wanted, hard)

    # Raise the open-file limit; asking for more than the hard limit would
    # make setrlimit raise, so the request is capped.
    nofile_hard = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
    resource.setrlimit(resource.RLIMIT_NOFILE,
                       (_clamp_to_hard(4096*2, nofile_hard), nofile_hard))
    # Limit the data segment to 900 GiB (or the current hard limit if lower).
    limit_in_b = 900 * 1024 ** 3
    data_hard = resource.getrlimit(resource.RLIMIT_DATA)[1]
    limit_in_b = _clamp_to_hard(limit_in_b, data_hard)
    resource.setrlimit(resource.RLIMIT_DATA, (limit_in_b, limit_in_b))
    use_cuda = torch.cuda.is_available()
    print("starting check cuda status",use_cuda)
    # Explicit check instead of assert, which is removed under ``python -O``.
    if not use_cuda:
        sys.exit("CUDA is not available, pre-training requires CUDA support to run!")
    parser = argparser_pretrain()
    args = parser.parse_args()
    #If you have many GPU on your server, but you only want to use few of them
    # run command line to configure the environment:
    # export CUDA_VISIBLE_DEVICES="0,1,2,3"
    # Here you can specify the GPU you want to use
    #check the specified input size, must be a multiple of args.patch_size
    if args.input_row_size%args.patch_size!=0 or args.input_col_size%args.patch_size!=0:
        print("args configuration error: input_row_size and input_col_size must be a multiple of patch_size")
        sys.exit(1)
    main(args)
53 |
--------------------------------------------------------------------------------
/pretrain/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/pretrain/__init__.py
--------------------------------------------------------------------------------
/pretrain/main_worker.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.backends.cudnn as cudnn
5 | import time
6 | import datetime
7 | import json
8 | import numpy as np
9 | import torchvision.transforms as transforms
10 | import timm.optim.optim_factory as optim_factory
11 |
12 | from ops.distribute_utils import init_distributed_mode,get_world_size,get_rank,is_main_process
13 | from ops.Logger import print_important_info,print_warning_info
14 | from data_processing.pretrain_dataset import Pretrain_Dataset
15 | from data_processing.collate_fn import collate_fn
16 | import model.models_hicfoundation as models_hicfoundation
17 | from model.NativeScaler import NativeScalerWithGradNormCount as NativeScaler
18 | from model.model_utils import load_model,save_checkpoint,save_model2path
19 | from ops.io_utils import write_log
20 | from pretrain.train_epoch import train_epoch
21 | from pretrain.val_epoch import val_epoch
22 |
def parse_text(config_file, data_dir):
    """Resolve the entries of a config file against ``data_dir``.

    Every non-blank line is joined onto ``data_dir``; paths that exist are
    returned in file order, missing ones are reported and skipped.
    """
    existing_paths = []
    with open(config_file) as f:
        for raw_line in f:
            entry = raw_line.strip().replace('\n', '')
            if not entry:
                continue
            current_path = os.path.join(data_dir, entry)
            if os.path.exists(current_path):
                existing_paths.append(current_path)
                continue
            print("The sub-directory {} does not exist in the data directory".format(current_path))
            print("Please check the sub-directory name in the {} file".format(config_file))
    return existing_paths
38 |
def configure_data_loader(args):
    """Build the training and validation data loaders for pre-training.

    Both datasets read the sub-directories listed in ``args.train_config``
    and ``args.valid_config`` (resolved against ``args.data_path``) and
    share the same transform, sparsity filter, patch size and window size.
    Distributed runs use DistributedSampler (shuffled for training only);
    otherwise both loaders use RandomSampler.

    Returns:
        (data_loader_train, data_loader_val)
    """
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    data_dir = os.path.abspath(args.data_path)
    train_list = parse_text(os.path.abspath(args.train_config), data_dir)
    val_list = parse_text(os.path.abspath(args.valid_config), data_dir)

    # settings shared by the two datasets
    dataset_kwargs = dict(
        transform=transform_train,
        sparsity_filter=float(args.sparsity_ratio),
        patch_size=args.patch_size,
        window_height=args.input_row_size,
        window_width=args.input_col_size,
    )
    dataset_train = Pretrain_Dataset(train_list, **dataset_kwargs)
    dataset_val = Pretrain_Dataset(val_list, **dataset_kwargs)

    if args.distributed:
        num_tasks = get_world_size()
        global_rank = get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False
        )
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.RandomSampler(dataset_val)

    # settings shared by the two loaders; only drop_last differs
    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        collate_fn=collate_fn,
    )
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train, drop_last=True, **loader_kwargs
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val, drop_last=False, **loader_kwargs
    )
    return data_loader_train, data_loader_val
86 |
def config_writer(output_dir,tensorboard_log):
    """
    Create the tensorboard directory and, if requested, a summary writer.

    ``<output_dir>/tensorboard`` is always created, even when logging is off.
    Args:
        output_dir (str): directory that receives the ``tensorboard`` folder.
        tensorboard_log: truthy to create a ``SummaryWriter``.
    Returns:
        A ``SummaryWriter`` writing to the tensorboard directory, or ``None``
        when ``tensorboard_log`` is falsy.
    """
    board_dir = os.path.join(output_dir, 'tensorboard')
    os.makedirs(board_dir, exist_ok=True)
    if not tensorboard_log:
        return None
    # Imported lazily so tensorboard is only required when logging is enabled.
    from torch.utils.tensorboard import SummaryWriter
    return SummaryWriter(board_dir)
def main_worker(gpu, ngpus_per_node,args):
    """
    Run HiCFoundation pre-training in the current process.

    Steps: optional distributed initialisation, data loaders, model, AdamW
    optimizer and loss scaler, optional resume via ``load_model``, then one
    training and one validation pass per epoch. Logs, periodic checkpoints
    and the model with the lowest validation loss are written below
    ``args.output``.
    Args:
        gpu: GPU index forwarded to ``init_distributed_mode``.
        ngpus_per_node (int): distributed mode is initialised only when > 1.
        args (Namespace): parsed arguments. This function sets
            ``args.distributed = False`` in the single-GPU case and always
            sets ``args.lr`` from ``args.blr`` and the effective batch size.
    Returns:
        None
    """
    if ngpus_per_node>1:
        init_distributed_mode(gpu,ngpus_per_node,args)
    else:
        args.distributed=False
        print_warning_info("The distributed mode is disabled.\n For pre-training, one GPU may take very long to train!")
    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    global_rank = get_rank() if args.distributed else -1
    output_dir = os.path.abspath(args.output)
    # Only rank 0 (or the single non-distributed process) owns the output
    # directory and the tensorboard writer.
    if not args.distributed or global_rank == 0:
        os.makedirs(output_dir,exist_ok=True)
        log_writer = config_writer(output_dir,args.tensorboard)
    else:
        log_writer = None

    cudnn.benchmark = True
    device = torch.device(args.device)

    # Data loading code
    data_loader_train, data_loader_val = configure_data_loader(args)
    print("Data loader is configured!")

    # Configure the model
    model = models_hicfoundation.__dict__[args.model](img_size=(args.input_row_size,args.input_col_size))
    model.to(device)
    model_without_ddp = model

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],find_unused_parameters=True)
        # Checkpoints and the optimizer work on the unwrapped module.
        model_without_ddp = model.module
    print("Model is configured!")

    #configure batch size, learning rate, weight decay
    eff_batch_size = args.batch_size * args.accum_iter*get_world_size()
    # Linear scaling: the base learning rate is defined for a batch of 256.
    args.lr = args.blr * eff_batch_size / 256
    print("Learning rate: %.6f"%args.lr)
    print("Accumulative grad iteration: %d"%args.accum_iter)
    print("Effective batch size: %d"%eff_batch_size)
    print("Base learning rate: %.6f"%args.blr)

    #configure optimizer
    param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    loss_scaler = NativeScaler()
    print("Optimizer is configured!")

    #resume model if it is loading from checkpoint
    # NOTE(review): load_model presumably updates args.start_epoch when a
    # checkpoint is found - it is read right below; confirm in model_utils.
    resume_path=os.path.abspath(args.resume)
    load_model(resume_path,args, model_without_ddp, optimizer, loss_scaler)

    model_dir = os.path.join(output_dir, 'model')
    os.makedirs(model_dir,exist_ok=True)
    log_dir = os.path.join(output_dir, 'log')
    os.makedirs(log_dir,exist_ok=True)

    epochs = int(args.epochs)
    start_epoch = int(args.start_epoch)
    start_time = time.time()
    print("Start pre-training from epoch %d"%start_epoch," to epoch %d"%epochs)
    save_freq = args.save_freq
    best_loss = 1e9
    for epoch in range(start_epoch, epochs):
        if args.distributed:
            # Gives the distributed sampler a different shuffle every epoch.
            data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_epoch(
            model, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer,
            args=args
        )
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch,}
        if is_main_process():
            write_log(log_dir,"train",log_stats)

        #validation run
        val_stats = val_epoch(
            model, data_loader_val,
            device, epoch,
            log_writer=log_writer,
            args=args
        )
        val_loss = val_stats['loss']
        log_stats_val = {**{f'val_{k}': v for k, v in val_stats.items()},
                         'epoch': epoch,}
        if is_main_process():
            write_log(log_dir,"val",log_stats_val)

        # Periodic checkpoint, plus one at the very last epoch.
        if epoch%save_freq==0 or epoch==epochs-1:
            save_checkpoint(model_dir, args,epoch, model_without_ddp, optimizer, loss_scaler)

        # Keep a separate copy of the model with the lowest validation loss.
        if val_loss < best_loss:
            best_loss = val_loss
            model_path = os.path.join(model_dir, 'model_best.pth.tar')
            save_model2path(model_path,args,epoch, model_without_ddp, optimizer, loss_scaler)
    total_time = time.time()-start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    # This worker runs pre-training; the message previously said "Fine-tuning".
    print("Pre-training of HiCFoundation is finished!")
    print("The model is saved in {}".format(model_dir))
    print("The log is saved in {}".format(log_dir))
217 |
--------------------------------------------------------------------------------
/pretrain/train_epoch.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import sys
4 | import numpy as np
5 | from typing import Iterable
6 | import torch
7 | import torch.nn.functional as F
8 | import time
9 |
10 | from ops.Logger import MetricLogger,SmoothedValue
11 | import model.lr_sched as lr_sched
12 |
13 | from ops.train_utils import list_to_device, to_value, create_image, torch_to_nparray
14 |
15 |
def train_epoch(model,data_loader,optimizer,
                device, epoch, loss_scaler,
                log_writer=None,args=None):
    """
    Train ``model`` for one epoch with gradient accumulation.

    The total loss is ``args.loss_alpha * (ssim_loss + count_loss) +
    contrastive_loss``, where ``count_loss`` is the MSE between the model's
    count prediction and ``log10(matrix_count + 1)``. Batches whose loss is
    not finite are skipped and excluded from the averaged statistics.
    Args:
        model: network returning ``(ssim_loss, contrastive_loss, count_pred,
            pred_image, mask)``.
        data_loader: yields batches that ``list_to_device`` unpacks into five
            items (input matrix, mask matrix, Hi-C count, diagonal info,
            matrix count).
        optimizer: optimizer whose learning rate is adjusted per iteration.
        device: device the batch is moved to.
        epoch (int): current epoch, used for the schedule and the log header.
        loss_scaler: callable performing backward and, when ``update_grad``
            is true, the optimizer step.
        log_writer: optional tensorboard writer.
        args (Namespace): reads ``print_freq``, ``accum_iter``,
            ``mask_ratio`` and ``loss_alpha``.
    Returns:
        dict: meter name -> global average over the epoch.
    """
    model.train()

    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = args.print_freq

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))
    print("number of iterations: ",len(data_loader))
    for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # Per-iteration schedule: the "epoch" passed on is fractional and is
        # refreshed at the start of every accumulation cycle.
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
        input_matrix, mask_matrix, hic_count, return_diag,matrix_count = list_to_device(data,device=device)
        with torch.cuda.amp.autocast(): #to enable mixed precision training
            ssim_loss,contrastive_loss, count_pred, pred_image, mask \
                = model(input_matrix, mask_matrix, total_count=hic_count, \
                        diag=return_diag,mask_ratio=args.mask_ratio)

            # The count head is trained against the log-scaled matrix count.
            matrix_count = torch.log10(matrix_count+1)
            count_pred = count_pred.flatten()
            count_loss = torch.nn.functional.mse_loss(count_pred, matrix_count)
            loss = args.loss_alpha*(ssim_loss+count_loss) + contrastive_loss

        # Keep the unscaled value: the division by accum_iter below is only
        # for the backward pass and must not leak into the logs.
        loss_value = to_value(loss)
        if not math.isfinite(loss_value):
            # The batch is skipped (training continues) and is not recorded,
            # so one NaN/inf batch cannot poison the epoch averages.
            # NOTE(review): under DistributedDataParallel a rank that skips
            # its backward pass may fall out of step with the other ranks -
            # confirm this is acceptable.
            print("Loss is {}, skipping this batch".format(loss_value))
            optimizer.zero_grad()
            continue
        metric_logger.update(loss=loss_value)
        metric_logger.update(ssim_loss=to_value(ssim_loss))
        metric_logger.update(count_loss=to_value(count_loss))
        metric_logger.update(contrastive_loss=to_value(contrastive_loss))

        loss = loss / accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)

        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize() # Make sure all gradients are finished computing before moving on
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        if log_writer is not None and ((data_iter_step + 1) % accum_iter == 0 or data_iter_step==0):
            # epoch_1000x is the x-axis in tensorboard; it keeps curves
            # comparable when the batch size changes.
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalars('Loss/loss', {'train_loss': loss_value}, epoch_1000x)
            log_writer.add_scalars('Loss/ssim_loss', {'train_loss': to_value(ssim_loss)}, epoch_1000x)
            log_writer.add_scalars('Loss/count_loss', {'train_loss': to_value(count_loss)}, epoch_1000x)
            log_writer.add_scalars('Loss/contrastive_loss', {'train_loss': to_value(contrastive_loss)}, epoch_1000x)
            log_writer.add_scalars('LR/lr', {'lr': lr}, epoch_1000x)
            #add visualization
            if ((data_iter_step+1)//accum_iter)%50==0 or data_iter_step==0:
                new_samples = create_image(input_matrix)
                mask_image = new_samples*(1-mask)
                pred_image = create_image(pred_image)

                # At most eight samples per logged step.
                select_num = min(8,len(new_samples))
                new_samples = torch_to_nparray(new_samples.clone().detach()[:select_num])
                mask_image = torch_to_nparray(mask_image.clone().detach()[:select_num])
                pred_image = torch_to_nparray(pred_image.clone().detach()[:select_num])
                log_writer.add_images('Target_%s'%"train", new_samples, epoch_1000x)
                log_writer.add_images('Input_%s'%"train", mask_image, epoch_1000x)
                log_writer.add_images('Pred_%s'%"train", pred_image, epoch_1000x)
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
94 |
95 |
96 |
97 |
98 |
99 |
100 |
--------------------------------------------------------------------------------
/pretrain/val_epoch.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | import numpy as np
4 | from typing import Iterable
5 | import torch
6 | import torch.nn.functional as F
7 | import time
8 |
9 | from ops.Logger import MetricLogger,SmoothedValue
10 | import model.lr_sched as lr_sched
11 |
12 | from ops.train_utils import list_to_device, to_value, create_image, torch_to_nparray
13 |
def val_epoch(model, data_loader,device, epoch,
              log_writer=None,
              args=None,flag='val'):
    """
    Evaluate ``model`` on ``data_loader`` without gradient tracking.

    The loss is composed exactly as in training:
    ``args.loss_alpha * (ssim_loss + count_loss) + contrastive_loss``.
    Args:
        model: network returning ``(ssim_loss, contrastive_loss, count_pred,
            pred_image, mask)``.
        data_loader: yields batches that ``list_to_device`` unpacks into five
            items.
        device: device the batch is moved to.
        epoch (int): current epoch, used for the log header and x-axis.
        log_writer: optional tensorboard writer.
        args (Namespace): reads ``print_freq``, ``accum_iter``,
            ``mask_ratio`` and ``loss_alpha``.
        flag (str): label used in the tensorboard scalar and image names.
    Returns:
        dict: meter name -> global average over the validation set.
    """
    model.eval()
    metric_logger = MetricLogger(delimiter=" ")
    header = 'Val Epoch: [{}]'.format(epoch)
    print_freq = args.print_freq
    accum_iter = args.accum_iter
    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))
    total_steps = len(data_loader)
    print("number of iterations: ",total_steps)

    scalar_key = '%s_loss'%flag
    for step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        input_matrix, mask_matrix, hic_count, return_diag,matrix_count = list_to_device(batch,device=device)
        with torch.no_grad(): # evaluation only, no gradients are recorded
            ssim_loss,contrastive_loss, count_pred, pred_image, mask \
                = model(input_matrix, mask_matrix, total_count=hic_count, \
                        diag=return_diag,mask_ratio=args.mask_ratio)
            # Same log-scaled count target as in training.
            log_count = torch.log10(matrix_count+1)
            count_loss = torch.nn.functional.mse_loss(count_pred.flatten(), log_count)
            loss = args.loss_alpha*(ssim_loss+count_loss) + contrastive_loss

        # Meter name -> tensor, in the order the meters are registered.
        loss_terms = {
            'loss': loss,
            'ssim_loss': ssim_loss,
            'count_loss': count_loss,
            'contrastive_loss': contrastive_loss,
        }
        for term_name, term in loss_terms.items():
            metric_logger.update(**{term_name: to_value(term)})
        torch.cuda.synchronize() # wait for the queued GPU work of this step

        should_log = (step + 1) % accum_iter == 0 or step == 0
        if log_writer is None or not should_log:
            continue

        # epoch_1000x is the x-axis in tensorboard; it keeps curves
        # comparable when the batch size changes.
        epoch_1000x = int((step / total_steps + epoch) * 1000)
        for term_name, term in loss_terms.items():
            log_writer.add_scalars('Loss/%s'%term_name, {scalar_key: to_value(term)}, epoch_1000x)

        #add visualization
        if ((step+1)//accum_iter)%50==0 or step==0:
            target_image = create_image(input_matrix)
            masked_input = target_image*(1-mask)
            predicted = create_image(pred_image)

            # At most eight samples per logged step.
            keep = min(8,len(target_image))
            target_image = torch_to_nparray(target_image.clone().detach()[:keep])
            masked_input = torch_to_nparray(masked_input.clone().detach()[:keep])
            predicted = torch_to_nparray(predicted.clone().detach()[:keep])
            log_writer.add_images('Target_%s'%flag, target_image, epoch_1000x)
            log_writer.add_images('Input_%s'%flag, masked_input, epoch_1000x)
            log_writer.add_images('Pred_%s'%flag, predicted, epoch_1000x)
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
/refactor_pretrain/Logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Code here is borrowed from the https://github.com/facebookresearch/deit/blob/main/utils.py
3 | This file includes two classes SmoothValue and MetricLoggers that are directly copied from the link above
4 | TODO:
5 | print_important_info and print_warning_info are new
6 | move these classes and the new functions to a utils.py file
7 | """
8 |
9 |
10 | from collections import defaultdict, deque
11 | import torch
12 | from distribute_utils import is_dist_avail_and_initialized
13 | import torch.distributed as dist
14 | import datetime
15 | import os
16 | import time
17 |
class SmoothedValue(object):
    """Keep a sliding window of recent values plus a running total, and
    expose window statistics (median, avg, max, value) and the global average.
    """

    def __init__(self, window_size=20, fmt=None):
        # Default rendering: window median followed by the global average.
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        # The window stores the value once; the running totals weight it by n.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` over all processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(packed)
        summed_count, summed_total = packed.tolist()
        self.count = int(summed_count)
        self.total = summed_total

    @property
    def median(self):
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        # Most recently recorded value.
        return self.deque[-1]

    def __str__(self):
        stats = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**stats)
78 |
79 |
class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and print them while iterating."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names get a default SmoothedValue on first update.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; ``None`` values are ignored and
        tensors are converted with ``.item()``."""
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only called when normal lookup fails, so meters can be read as
        # attributes. ``meters`` is fetched through __dict__ so that an
        # instance without it (e.g. while being copied or unpickled) raises
        # AttributeError instead of recursing forever.
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log(self,iteration,print_freq,header=None):
        """Print all meters once every ``print_freq`` iterations.

        Nothing is printed unless ``iteration`` is a multiple of
        ``print_freq`` (the condition used to be inverted).
        """
        if iteration % print_freq != 0:
            return

        if not header:
            header = ''
        log_msg = [
            header,
            '['+str(iteration)+']',
            '{meters}',
        ]
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            log_msg.append('max mem: {memory:.0f}')
        MB = 1024.0 * 1024.0
        log_msg = self.delimiter.join(log_msg)
        if cuda_available:
            print(log_msg.format(
                iteration,
                meters=str(self),
                memory=torch.cuda.max_memory_allocated() / MB))
        else:
            print(log_msg.format(
                iteration,
                meters=str(self)))

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` and print progress (ETA, meters,
        iteration and data-loading time) every ``print_freq`` items and on
        the last item. ``iterable`` must support ``len()``.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the step counter to the width of the total number of steps.
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the next item.
            data_time.update(time.time() - end)
            yield obj
            # Full iteration time, including the caller's work on the item.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB),flush=True)
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)),flush=True)
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) avoids a ZeroDivisionError for an empty iterable.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(len(iterable), 1)),flush=True)
186 |
187 |
def print_important_info(str):
    """Print ``str`` prefixed with "Important:" between two rules of 50 "="."""
    # The parameter name shadows the builtin ``str``; it is kept because it
    # is part of the call signature.
    rule = "=" * 50
    print(rule)
    print("Important:" + str)
    print(rule)
192 |
def print_warning_info(str):
    """Print ``str`` prefixed with "Warning:" between two rules of 50 "*"."""
    # The parameter name shadows the builtin ``str``; it is kept because it
    # is part of the call signature.
    rule = "*" * 50
    print(rule)
    print("Warning:" + str)
    print(rule)
--------------------------------------------------------------------------------
/refactor_pretrain/distribute_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Code here is borrowed from the https://github.com/facebookresearch/deit/blob/main/utils.py
3 | This file includes utility scripts for handling distributed training
4 | """
5 |
6 | import os
7 | import torch
8 | import resource
9 | import datetime
10 | import builtins
11 | import numpy as np
12 | import torch.distributed as dist
13 |
def is_dist_avail_and_initialized():
    """
    Tell whether ``torch.distributed`` can be used by this process.
    Args:
        None
    Returns:
        bool: True only if the distributed package is available and the
        default process group has been initialized.

    """
    # Short-circuits: is_initialized() is only queried when the package exists.
    return dist.is_available() and dist.is_initialized()
30 |
def get_rank():
    """
    Return the rank of the current process.
    Args:
        None
    Returns:
        int: ``dist.get_rank()`` when distributed training is initialized,
        otherwise 0.

    """
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
45 |
def get_world_size():
    """
    Return the number of processes taking part in training.
    Args:
        None
    Returns:
        int: ``dist.get_world_size()`` when distributed training is
        initialized, otherwise 1.

    """
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
59 |
def all_reduce_mean(x):
    """
    Average a scalar over all distributed processes.

    With a single process the input is returned unchanged. Otherwise the
    value is moved to the GPU, summed with ``all_reduce`` and divided by the
    world size.
    Args:
        x (float): A scalar value to be averaged.
    Returns:
        float: Average of x across all the distributed processes
    """
    n_procs = get_world_size()
    if n_procs <= 1:
        return x
    reduced = torch.tensor(x).cuda()
    dist.all_reduce(reduced)
    reduced /= n_procs
    return reduced.item()
79 |
def is_main_process():
    """
    Tell whether the current process is the main one, i.e. has rank 0.
    Args:
        None
    Returns:
        bool: True if main process
    """
    rank = get_rank()
    return rank == 0
91 |
92 |
def setup_for_distributed(is_master):
    """
    Replace the builtin ``print`` for the whole process.

    After this call ``print`` is silent unless ``is_master`` is true; on the
    master every message is prefixed with the current time in brackets.
    Args:
        is_master (bool): whether this process is allowed to print.
    Returns:
        None
    """
    original_print = builtins.print

    def gated_print(*args, **kwargs):
        if not is_master:
            return
        stamp = datetime.datetime.now().time()
        original_print('[{}] '.format(stamp), end='')  # print with time stamp
        original_print(*args, **kwargs)

    builtins.print = gated_print
109 |
def init_distributed_mode(gpu, ngpus_per_node, args):
    """
    Initialize distributed training for the process that drives ``gpu``.

    Raises the open-file soft limit to 2048 if it is lower, derives the
    global rank, exports ``LOCAL_RANK`` / ``RANK`` / ``WORLD_SIZE``, selects
    the CUDA device, creates the NCCL process group, silences ``print`` on
    non-zero ranks and seeds torch and numpy with ``args.seed + rank``.
    Args:
        gpu (int): GPU index for current process.
        ngpus_per_node (int): Number of GPUs per node.
        args (Namespace): Arguments object containing attributes:
            - rank (int): Base rank of the node; overwritten with the
              global rank of this process.
            - world_size (int): Total number of processes.
            - dist_url (str): Initialization URL for distributed training.
            - seed (int): Seed for random number generators.
            ``gpu``, ``distributed`` and ``dist_backend`` are set on it.
    Returns:
        None
    """
    # Make sure at least 2048 files may be open. Only ever raise the soft
    # limit: an existing higher limit is kept, and the request is capped at
    # the hard limit so setrlimit cannot fail with ValueError.
    soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
    wanted_limit = 2048
    if hard_limit != resource.RLIM_INFINITY:
        wanted_limit = min(wanted_limit, hard_limit)
    if soft_limit != resource.RLIM_INFINITY and soft_limit < wanted_limit:
        resource.setrlimit(resource.RLIMIT_NOFILE, (wanted_limit, hard_limit))

    args.gpu = gpu
    # Global rank = node rank * GPUs per node + local GPU index.
    args.rank = args.rank * ngpus_per_node + gpu
    os.environ['LOCAL_RANK'] = str(args.gpu)
    os.environ['RANK'] = str(args.rank)
    os.environ['WORLD_SIZE'] = str(args.world_size)
    print("make sure the distributed mode is ",args.dist_url)

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         timeout=datetime.timedelta(seconds=36000),
                                         world_size=args.world_size, rank=args.rank)

    setup_for_distributed(args.rank == 0)

    # fix the seed for reproducibility; every rank gets a different seed
    seed = args.seed + get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
153 |
--------------------------------------------------------------------------------
/refactor_pretrain/model_funcs.py:
--------------------------------------------------------------------------------
1 | # From MAE: https://github.com/facebookresearch/mae
2 |
3 | import math
4 | import torch
5 | import numpy as np
6 |
7 |
def adjust_learning_rate(optimizer, epoch, args):
    """
    Set the learning rate: linear warmup, then half-cycle cosine decay
    :param optimizer: torch.optimizer, its param_groups get the new "lr"
        (multiplied by the group's "lr_scale" when that key is present)
    :param epoch: current (possibly fractional) training epoch
    :param args: argparse arguments; reads lr, min_lr, warmup_epochs, epochs
    :return: the unscaled learning rate
    """
    warmup = args.warmup_epochs
    if epoch < warmup:
        lr = args.lr * epoch / warmup
    else:
        # Angle runs from 0 at the end of warmup to pi at args.epochs.
        angle = math.pi * (epoch - warmup) / (args.epochs - warmup)
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1. + math.cos(angle))
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
    return lr
27 |
28 |
def convert_count_to_pos_embed_cuda(count, embed_dim):
    """
    Encode counts as a sine/cosine embedding shifted by +1
    :param count: torch.tensor, 1-D (the einsum below takes a single axis);
        the frequencies are created with its dtype and on its device
    :param embed_dim: int, embedding dimension, must be even
    :return: torch.tensor of shape (len(count), embed_dim)
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    freqs = torch.arange(half_dim, dtype=count.dtype, device=count.device)
    # NOTE(review): this divides by embed_dim and then by 2, whereas
    # get_1d_sincos_pos_embed_from_grid divides by (embed_dim / 2). Confirm
    # the difference is intended before touching it - trained weights
    # presumably depend on the current values.
    freqs = freqs / embed_dim / 2.
    freqs = 1. / 10000 ** freqs

    angles = torch.einsum('m,d->md', count, freqs)  # (M, D/2), outer product
    sin_part = torch.sin(angles)  # (M, D/2)
    cos_part = torch.cos(angles)  # (M, D/2)

    # +1 enforces a difference compared to other embeddings
    return torch.cat([sin_part, cos_part], dim=1) + 1  # (M, D)
48 |
49 |
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a 2-D sine/cosine position embedding for a square grid.
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # meshgrid takes the width axis first; both axes are identical here
    mesh = np.stack(np.meshgrid(axis, axis), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    embedding = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return embedding
    # the class token gets an all-zero row in front
    return np.concatenate([np.zeros([1, embed_dim]), embedding], axis=0)
66 |
67 |
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a two-channel grid: each channel gets half of embed_dim.
    embed_dim: output dimension per position, must be even
    grid: array whose first axis holds the two coordinate channels
    return: the two 1-D embeddings concatenated along axis 1
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # one 1-D embedding per coordinate channel, each (H*W, D/2)
    per_channel = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[0]),
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[1]),
    ]
    return np.concatenate(per_channel, axis=1)  # (H*W, D)
77 |
78 |
def get_2d_sincos_pos_embed_rectangle(embed_dim, grid_size, cls_token=False):
    """
    Build a 2-D sine/cosine position embedding for a rectangular grid.
    grid_size, a tuple of height and width
    return:
    pos_embed: [height*width, embed_dim] or [1+height*width, embed_dim] (w/ or w/o cls_token)
    """
    grid_size_h, grid_size_w = grid_size
    grid_h = np.arange(grid_size_h, dtype=np.float32)
    grid_w = np.arange(grid_size_w, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)  # (2, H, W) with the default 'xy' indexing
    # Name the axes in the order meshgrid produced them, (H, W). The element
    # order is unchanged, so the embedding is identical to the former
    # (W, H) reshape: the grid is flattened before it is encoded.
    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # the class token gets an all-zero row in front
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
93 |
94 |
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine embedding of a set of positions.
    embed_dim: output dimension for each position, must be even
    pos: array of positions to be encoded; it is flattened to size (M,)
    out: (M, D), sines in the first half of the columns, cosines in the second
    """
    assert embed_dim % 2 == 0
    freqs = np.arange(embed_dim // 2, dtype=np.float32)
    freqs /= embed_dim / 2.
    freqs = 1. / 10000 ** freqs  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.einsum('m,d->md', flat_pos, freqs)  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
114 |
--------------------------------------------------------------------------------
/refactor_pretrain/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from distribute_utils import is_main_process
3 | from pathlib import Path
4 | import os
5 | from math import inf
6 |
7 |
def load_model(resume_path, args, model_without_ddp, optimizer, loss_scaler):
    """
    Load the model from the checkpoint

    The model weights are always loaded (non-strictly). Optimizer state,
    scaler state and ``args.start_epoch`` are restored only when the
    checkpoint contains both 'optimizer' and 'epoch', so weights-only
    checkpoints can be loaded too. If ``resume_path`` is neither an https
    URL nor an existing file, a message is printed and nothing is loaded.
    Args:
        resume_path: the path (or https URL) to the checkpoint
        args: arguments object; ``start_epoch`` is set to the stored epoch + 1
        model_without_ddp: the model
        optimizer: the optimizer
        loss_scaler: the loss scaler
    """
    # The URL test has to come first: a URL never passes os.path.isfile.
    is_url = resume_path.startswith('https')
    if not is_url and not os.path.isfile(resume_path):
        print("=> no checkpoint found at '{}'".format(resume_path))
        return

    print("=> loading checkpoint '{}'".format(resume_path))
    if is_url:
        checkpoint = torch.hub.load_state_dict_from_url(
            resume_path, map_location='cpu', check_hash=True)
    else:
        # NOTE: weights_only=False unpickles arbitrary objects (the saved
        # checkpoints include 'args'); only load checkpoints you trust.
        checkpoint = torch.load(resume_path, weights_only=False, map_location='cpu')
    msg = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
    print("model resume message:{}".format(msg))

    if 'optimizer' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        # The scaler entry is stored as None when no loss scaler was used.
        if loss_scaler is not None and checkpoint.get('scaler') is not None:
            loss_scaler.load_state_dict(checkpoint['scaler'])
        args.start_epoch = checkpoint['epoch'] + 1
        print("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint['epoch']))
    else:
        print("=> loaded model weights only from '{}'".format(resume_path))
32 |
33 |
def save_on_master(*args, **kwargs):
    """
    Call torch.save on the main process only; other processes do nothing
    :param args: positional arguments, torch.save arguments
    :param kwargs: keyword arguments, torch.save arguments
    :return: None
    """
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
43 |
44 |
def save_checkpoint(output_dir, args, epoch, model_without_ddp, optimizer, loss_scaler):
    """
    Save model, optimizer, epoch, scaler and arguments to
    ``<output_dir>/checkpoint-<epoch>.pth`` (written by the main process only)
    :param output_dir: str, output directory for saving
    :param args: arguments object stored in the checkpoint under 'args'
    :param epoch: int, training epoch
    :param model_without_ddp: torch.nn.module, model with no distributed components
    :param optimizer: torch.optimizer, optimizer for model training
    :param loss_scaler: loss scaler whose state_dict is stored; may be None,
        in which case None is stored under 'scaler'
    :return: None.
    """
    target = Path(output_dir) / ('checkpoint-%s.pth' % str(epoch))
    payload = {
        'model': model_without_ddp.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': None if loss_scaler is None else loss_scaler.state_dict(),
        'args': args,
    }
    save_on_master(payload, target)
73 |
74 |
def save_model2path(model_path, args, epoch, model_without_ddp, optimizer, loss_scaler):
    """
    Save a checkpoint dictionary to an explicit file path
    :param model_path: str, destination path including the file name
    :param args: training arguments, stored in the checkpoint under 'args'
    :param epoch: int, epoch number stored in the checkpoint
    :param model_without_ddp: torch.nn.module whose state_dict is saved
    :param optimizer: optimizer whose state_dict is saved
    :param loss_scaler: loss scaler whose state_dict is saved; None is stored when it is None
    :return: None
    """
    snapshot = dict(
        model=model_without_ddp.state_dict(),
        optimizer=optimizer.state_dict(),
        epoch=epoch,
        scaler=None if loss_scaler is None else loss_scaler.state_dict(),
        args=args,
    )
    # The write itself only happens on the main process.
    save_on_master(snapshot, model_path)
94 |
95 |
class NativeScalerWithGradNormCount:
    """Thin wrapper around torch.cuda.amp.GradScaler that also reports the gradient norm."""

    # Key name exposed as a class attribute; not referenced inside this class.
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        """
        Back-propagate the scaled loss and, when ``update_grad`` is set, step the optimizer.

        Returns the gradient norm (the value returned by clip_grad_norm_ when
        ``clip_grad`` is given, otherwise the result of get_grad_norm_), or
        None when ``update_grad`` is False.
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            return None

        use_clipping = clip_grad is not None
        if use_clipping:
            assert parameters is not None
        # unscale the gradients of optimizer's assigned params in-place
        self._scaler.unscale_(optimizer)
        if use_clipping:
            grad_norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            grad_norm = get_grad_norm_(parameters)
        self._scaler.step(optimizer)
        self._scaler.update()
        return grad_norm

    def state_dict(self):
        """Return the state dict of the wrapped GradScaler."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped GradScaler from ``state_dict``."""
        self._scaler.load_state_dict(state_dict)
124 |
125 |
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """
    Compute the norm of the gradients of ``parameters``
    :param parameters: a single tensor or an iterable of tensors; entries without a gradient are ignored
    :param norm_type: float, order of the norm (2.0 -> L2, inf -> largest absolute value)
    :return: total norm as a tensor; a zero tensor when no entry has a gradient
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)

    # All partial results are moved to the device of the first gradient.
    device = grads[0].device
    if norm_type == inf:
        return max(g.detach().abs().max().to(device) for g in grads)

    per_tensor = [torch.norm(g.detach(), norm_type).to(device) for g in grads]
    return torch.norm(torch.stack(per_tensor), norm_type)
146 |
--------------------------------------------------------------------------------
/refactor_pretrain/pretrain.py:
--------------------------------------------------------------------------------
1 | # My code has references to the following repositories:
2 | # DeiT: https://github.com/facebookresearch/deit
3 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
4 | # MAE: https://github.com/facebookresearch/mae
5 | # AdPE: https://github.com/maple-research-lab/AdPE
6 | # --------------------------------------------------------
7 | import os
8 | from argparser import argparser_pretrain
9 | import torch
10 | import torch.multiprocessing as mp
11 | import timm
12 | assert timm.__version__ == "0.3.2" # version check
13 |
14 |
def main(args):
    """
    Launch pre-training on every visible GPU of this node.

    ``args.world_size`` comes in as the number of nodes and is multiplied by
    the number of GPUs per node before the workers are started.
    """
    # collect multi-gpu info
    import socket
    local_ip = socket.gethostbyname(socket.gethostname())
    print("local ip: ",local_ip)

    # one process per visible GPU on this node
    ngpus_per_node = torch.cuda.device_count()
    # total number of processes across all nodes
    args.world_size = args.world_size*ngpus_per_node
    from main_worker import main_worker
    if ngpus_per_node != 1:
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # single GPU: run the worker directly in this process
        main_worker(args.gpu,ngpus_per_node,args)
33 |
34 |
35 |
36 |
if __name__ == '__main__':
    import resource
    import sys

    # Raise the soft open-file limit, but never above the hard limit:
    # setrlimit raises ValueError when soft > hard.
    soft_nofile = 4096*2
    hard_nofile = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
    if hard_nofile != resource.RLIM_INFINITY:
        soft_nofile = min(soft_nofile, hard_nofile)
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_nofile, hard_nofile))

    # Cap the data segment at 900 GiB; an unprivileged process cannot raise
    # the hard limit, so clamp to the current one when it is already lower.
    limit_in_b = 900 * 1024 ** 3
    hard_data = resource.getrlimit(resource.RLIMIT_DATA)[1]
    if hard_data != resource.RLIM_INFINITY:
        limit_in_b = min(limit_in_b, hard_data)
    resource.setrlimit(resource.RLIMIT_DATA, (limit_in_b, limit_in_b))

    use_cuda = torch.cuda.is_available()
    print("starting check cuda status",use_cuda)
    # Explicit check instead of assert, which is stripped under `python -O`.
    if not use_cuda:
        sys.exit("CUDA is not available, pre-training requires CUDA support to run!")
    parser = argparser_pretrain()
    args = parser.parse_args()
    # If the server has many GPUs but only a few should be used,
    # configure the environment before launching, e.g.:
    # export CUDA_VISIBLE_DEVICES="0,1,2,3"
    # The input size must be a multiple of args.patch_size.
    if args.input_row_size%args.patch_size!=0 or args.input_col_size%args.patch_size!=0:
        print("args configuration error: input_row_size and input_col_size must be a multiple of patch_size")
        sys.exit(1)
    main(args)
58 |
--------------------------------------------------------------------------------
/refactor_pretrain/train_epoch.py:
--------------------------------------------------------------------------------
1 | import sys # Unused import
2 | import time # Unused import
3 | import math
4 | import torch
5 |
6 | import numpy as np
7 | import torch.nn.functional as F
8 |
9 | from typing import Iterable
10 | from Logger import MetricLogger,SmoothedValue
11 | from model_funcs import adjust_learning_rate
12 | from utils import list_to_device, to_value, create_image, torch_to_nparray
13 |
14 |
def train_epoch(model,data_loader,optimizer,
                device, epoch, loss_scaler,
                log_writer=None,args=None):
    """
    Runs one full training epoch
    Args:
        model (torch.nn.Module): HiCFoundation model object
        data_loader (Iterable): Training Dataloader object
        optimizer (torch.optim.Optimizer): Optimizer instance for training.
        device (str): Training Device
        epoch (int): Current epoch
        loss_scaler (callable): NativeScalerWithGradNormCount object
        log_writer (optional): Tensorboard object
        args (Namespace or dict):
            Training configuration with at least:
                print_freq (int) - logging frequency in steps
                accum_iter (int) - number of steps to accumulate gradients
                mask_ratio (float) - masking ratio for model input
                loss_alpha (float) - weight for (ssim_loss + count_loss)
    Returns:
        dict[str, float]: global average of every logged metric for this epoch
    """
    model.train()

    metric_logger = MetricLogger(delimiter=" ")
    # window_size=1 so the logged lr is always the most recent value
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = args.print_freq
    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    num_batches = len(data_loader)
    print("number of iterations: ",num_batches)

    for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # True on the last micro-batch of each gradient-accumulation cycle
        is_update_step = (data_iter_step + 1) % accum_iter == 0

        # The learning rate is adjusted at the start of every accumulation
        # cycle, using a fractional epoch as the schedule position.
        if data_iter_step % accum_iter == 0:
            adjust_learning_rate(optimizer, data_iter_step / num_batches + epoch, args)

        # Move batch to device and unpack fields
        input_matrix, mask_matrix, hic_count, return_diag, matrix_count = list_to_device(data,device=device)

        with torch.cuda.amp.autocast():  # mixed precision forward pass
            ssim_loss, contrastive_loss, count_pred, pred_image, mask = model(
                input_matrix, mask_matrix,
                total_count=hic_count,
                diag=return_diag, mask_ratio=args.mask_ratio,
            )

            # Count prediction is regressed against log10(count + 1)
            matrix_count = torch.log10(matrix_count+1)
            count_pred = count_pred.flatten()
            count_loss = torch.nn.functional.mse_loss(count_pred, matrix_count)
            # loss_alpha weights only the ssim and count terms
            loss = args.loss_alpha*(ssim_loss+count_loss) + contrastive_loss

        # Convert each loss to a Python number once and reuse it for the
        # console logger, the finiteness check and TensorBoard.
        loss_value = to_value(loss)
        ssim_value = to_value(ssim_loss)
        count_value = to_value(count_loss)
        contrastive_value = to_value(contrastive_loss)

        metric_logger.update(loss=loss_value)
        metric_logger.update(ssim_loss=ssim_value)
        metric_logger.update(count_loss=count_value)
        metric_logger.update(contrastive_loss=contrastive_value)

        # A NaN/inf loss is not back-propagated: the gradients accumulated so
        # far are dropped and training continues with the next batch.
        if not math.isfinite(loss_value):
            print("Loss is {}, skipping this batch".format(loss_value))
            optimizer.zero_grad()
            continue

        # NOTE(review): the loss is not divided by accum_iter before backward
        # (that line was commented out in the original) - confirm intended.
        loss_scaler(
            loss, optimizer,
            parameters=model.parameters(),
            update_grad=is_update_step
        )

        # After each accumulation cycle, clear grads
        if is_update_step:
            optimizer.zero_grad()

        # Wait for queued CUDA work to finish before reading/logging
        torch.cuda.synchronize()
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        # Tensorboard logging
        if log_writer is not None and (is_update_step or data_iter_step==0):
            # We use epoch_1000x as the x-axis in tensorboard.
            # This calibrates different curves when batch size changes.
            epoch_1000x = int((data_iter_step / num_batches + epoch) * 1000)
            log_writer.add_scalars('Loss/loss', {'train_loss': loss_value}, epoch_1000x)
            log_writer.add_scalars('Loss/ssim_loss', {'train_loss': ssim_value}, epoch_1000x)
            log_writer.add_scalars('Loss/count_loss', {'train_loss': count_value}, epoch_1000x)
            log_writer.add_scalars('Loss/contrastive_loss', {'train_loss': contrastive_value}, epoch_1000x)
            log_writer.add_scalars('LR/lr', {'lr': lr}, epoch_1000x)
            # Image previews: first step and every 50th accumulation cycle
            if ((data_iter_step+1)//accum_iter)%50==0 or data_iter_step==0:
                new_samples = create_image(input_matrix)
                mask_image = new_samples*(1-mask)
                pred_image = create_image(pred_image)

                select_num = min(8,len(new_samples))
                new_samples = torch_to_nparray(new_samples.clone().detach()[:select_num])
                mask_image = torch_to_nparray(mask_image.clone().detach()[:select_num])
                pred_image = torch_to_nparray(pred_image.clone().detach()[:select_num])
                log_writer.add_images('Target_%s'%"train", new_samples, epoch_1000x)
                log_writer.add_images('Input_%s'%"train", mask_image, epoch_1000x)
                log_writer.add_images('Pred_%s'%"train", pred_image, epoch_1000x)

    # Sync metrics across processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    # Return global averages across all logged metrics
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
140 |
141 |
142 |
143 |
144 |
145 |
146 |
--------------------------------------------------------------------------------
/refactor_pretrain/utils.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | from ops.sparse_ops import array_to_coo
4 | """
5 | from scipy.sparse import triu,coo_matrix
6 | import numpy as np
7 | import torch
8 | import numpy as np
9 | import torch.nn as nn
10 | import pickle
11 | import os
12 | import json
13 |
14 |
def array_to_coo(array):
    """
    Build a scipy.sparse.coo_matrix from the non-zero entries of a 2D array.

    Parameters:
    - array (numpy.ndarray): The input 2D array.

    Returns:
    - scipy.sparse.coo_matrix: COO matrix with the same shape as ``array``.
    """
    # Coordinates of every non-zero entry, then the values at those positions.
    nz_rows, nz_cols = np.nonzero(array)
    values = array[nz_rows, nz_cols]
    return coo_matrix((values, (nz_rows, nz_cols)), shape=array.shape)
35 |
36 | """
37 | from ops.io_utils import load_pickle
38 | """
39 |
def load_pickle(path):
    """
    Read and return the object pickled in the file at ``path``.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(path,'rb') as handle:
        return pickle.load(handle)
44 |
45 | """
46 | from data_processing.finetune_dataset import to_tensor, list_to_tensor
47 | """
48 |
49 |
def to_tensor(x):
    """
    Convert the input to tensor
    Args:
        x: numpy arrays and Python float/int values are converted to tensors;
           tensors, None and every other type are returned unchanged
    """
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x)
    if isinstance(x, (float, int)):
        return torch.tensor(x)
    # None, existing tensors and unsupported types pass through untouched
    return x
69 |
def list_to_tensor(x):
    """
    Apply to_tensor to every element of an iterable
    Args:
        x: the input list
    Returns:
        a new list holding the converted elements in the same order
    """
    return [to_tensor(item) for item in x]
80 |
81 |
82 |
83 |
84 | """
85 | from ops.train_utils import list_to_device, to_value, create_image, torch_to_nparray
86 | """
87 |
def list_to_device(data_list, device):
    """
    Move every non-None entry of ``data_list`` to ``device`` and cast it to float.

    None entries are kept as None; the result is a new list in the same order.
    """
    moved = []
    for item in data_list:
        if item is None:
            moved.append(None)
            continue
        moved.append(item.to(device,non_blocking=True).float())
    return moved
104 |
def to_value(data):
    """Return ``data.item()`` for a tensor; any other object is returned unchanged."""
    return data.item() if isinstance(data, torch.Tensor) else data
110 |
def create_image(samples):
    """
    Map normalised samples back to pixel values.

    Each channel is multiplied by the ImageNet std constant, shifted by the
    ImageNet mean constant, scaled by 255 and clipped to [0, 255].
    Args:
        samples: tensor indexed as (batch, channel, height, width) with 3 channels
    """
    channel_mean = torch.tensor(np.array([0.485, 0.456, 0.406]), device=samples.device)
    channel_std = torch.tensor(np.array([0.229, 0.224, 0.225]), device=samples.device)
    # per-channel scaling
    scaled = torch.einsum("bchw,c->bchw", samples, channel_std)
    # reshape the mean to (1, C, 1, 1) so it broadcasts over batch and pixels
    shifted = scaled + channel_mean.view(1, -1, 1, 1)
    return torch.clip(shifted * 255, 0, 255)
119 |
def torch_to_nparray(data):
    """
    Convert a tensor to a uint8 NumPy array for TensorBoard image logging.

    See https://github.com/pytorch/pytorch/blob/main/torch/utils/tensorboard/summary.py:
    images are taken as (n, c, h, w); 'tensor' can either have values in
    [0, 1] (float32) or [0, 255] (uint8), and out-of-range values are clipped
    by the image() function.
    """
    # The axis order is left as-is; only device and dtype change.
    host_array = data.cpu().numpy()
    return np.array(host_array, dtype=np.uint8)
134 |
135 |
136 | import torch
def collate_fn(batch):
    """
    Collate a batch of per-sample lists into one entry per field.

    For each field position the result holds None when every sample has None
    there, otherwise the samples' tensors stacked along a new first dimension.
    A field that mixes None with tensors triggers an AssertionError.
    """
    collated = []
    # zip(*batch) regroups the samples by field position
    for field_values in zip(*batch):
        is_missing = [value is None for value in field_values]
        if all(is_missing):
            collated.append(None)
            continue
        # a partially missing field cannot be stacked
        assert not any(is_missing), "None element in a list of tensors"
        collated.append(torch.stack(list(field_values)))
    return collated
156 |
157 |
158 | """
159 | from ops.io_utils import write_log
160 | """
def write_log(log_dir,status_flag,log_stats):
    """
    Append ``log_stats`` as one JSON line to ``<log_dir>/<status_flag>.log``.
    """
    log_file = os.path.join(log_dir, status_flag+".log")
    record = json.dumps(log_stats) + "\n"
    with open(log_file, mode="a", encoding="utf-8") as handle:
        handle.write(record)
--------------------------------------------------------------------------------
/refactor_pretrain/val_epoch.py:
--------------------------------------------------------------------------------
1 | import sys # Unused
2 | import math # Unused
3 | import time # Unused
4 | import torch
5 | import numpy as np # Unused
6 | import torch.nn.functional as F
7 |
8 | from typing import Iterable # Unused
9 | from Logger import MetricLogger,SmoothedValue # SmoothedValue Unused
10 | from utils import list_to_device, to_value, create_image, torch_to_nparray
11 |
def val_epoch(model, data_loader,
              device, epoch,
              log_writer=None,
              args=None, flag='val'):
    """
    Runs one full validation epoch.

    Args:
        model (torch.nn.Module): HiCFoundation model object
        data_loader (Iterable): Validation Dataloader object
        device (str): Device the batches are moved to
        epoch (int): Current epoch
        log_writer (optional): Tensorboard object
        args (Namespace or dict):
            Configuration with at least:
                print_freq (int) - logging frequency in steps
                accum_iter (int) - controls how often TensorBoard entries are written
                mask_ratio (float) - masking ratio for model input
                loss_alpha (float) - weight for (ssim_loss + count_loss)
        flag (str): suffix used in the TensorBoard scalar and image tags
    Returns:
        dict[str, float]: global average of every logged metric for this epoch
    """
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    header = 'Val Epoch: [{}]'.format(epoch)

    print_freq = args.print_freq
    accum_iter = args.accum_iter

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    num_batches = len(data_loader)
    print("number of iterations: ",num_batches)

    for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        input_matrix, mask_matrix, hic_count, return_diag,matrix_count = list_to_device(data,device=device)

        with torch.no_grad():
            # Forward pass
            ssim_loss,contrastive_loss, count_pred, pred_image, mask = model(
                input_matrix, mask_matrix,
                total_count=hic_count, diag=return_diag,
                mask_ratio=args.mask_ratio)

            # Count loss: regress against log10(count + 1)
            matrix_count = torch.log10(matrix_count+1)
            count_pred = count_pred.flatten()
            count_loss = torch.nn.functional.mse_loss(count_pred, matrix_count)

            # Total loss
            loss = args.loss_alpha*(ssim_loss+count_loss) + contrastive_loss

        # Convert each loss to a Python number once and reuse it for both
        # the console logger and TensorBoard.
        loss_value = to_value(loss)
        ssim_value = to_value(ssim_loss)
        count_value = to_value(count_loss)
        contrastive_value = to_value(contrastive_loss)

        metric_logger.update(loss=loss_value)
        metric_logger.update(ssim_loss=ssim_value)
        metric_logger.update(count_loss=count_value)
        metric_logger.update(contrastive_loss=contrastive_value)

        # Wait for queued CUDA work to finish before logging
        torch.cuda.synchronize()

        # Tensorboard logging
        if log_writer is not None and ((data_iter_step + 1) % accum_iter == 0 or data_iter_step==0):
            # We use epoch_1000x as the x-axis in tensorboard.
            # This calibrates different curves when batch size changes.
            epoch_1000x = int((data_iter_step / num_batches + epoch) * 1000)
            log_writer.add_scalars('Loss/loss', {'%s_loss'%flag: loss_value}, epoch_1000x)
            log_writer.add_scalars('Loss/ssim_loss', {'%s_loss'%flag: ssim_value}, epoch_1000x)
            log_writer.add_scalars('Loss/count_loss', {'%s_loss'%flag: count_value}, epoch_1000x)
            log_writer.add_scalars('Loss/contrastive_loss', {'%s_loss'%flag: contrastive_value}, epoch_1000x)
            # Image previews: first step and every 50th logging cycle
            if ((data_iter_step+1)//accum_iter)%50==0 or data_iter_step==0:
                new_samples = create_image(input_matrix)
                mask_image = new_samples*(1-mask)
                pred_image = create_image(pred_image)

                select_num = min(8,len(new_samples))
                new_samples = torch_to_nparray(new_samples.clone().detach()[:select_num])
                mask_image = torch_to_nparray(mask_image.clone().detach()[:select_num])
                pred_image = torch_to_nparray(pred_image.clone().detach()[:select_num])
                log_writer.add_images('Target_%s'%flag, new_samples, epoch_1000x)
                log_writer.add_images('Input_%s'%flag, mask_image, epoch_1000x)
                log_writer.add_images('Pred_%s'%flag, pred_image, epoch_1000x)

    # Sync metrics across processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    # Returns a global average of validation metrics
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | easydict
2 | opencv-python
3 | simplejson
4 | lvis
5 | Pillow==9.5.0
6 | pytorch_msssim
7 | pandas
8 | hic-straw
9 | matplotlib
10 | scikit-image
11 | scipy
12 | einops
13 | tensorboard
14 | cooler
15 | numba
16 | pyBigWig
17 | timm==0.3.2
--------------------------------------------------------------------------------
/test/test_convert_rgb.py:
--------------------------------------------------------------------------------
1 | """
2 | Unit test for convert RGB
3 |
4 | Author: Xumeng Zhang (xumzhang@uw.edu)
5 |
6 | This function takes only one np window from one npz file
7 | does not handle batch
8 |
9 | """
10 |
11 | import pytest
12 | import os
13 | import sys
14 | import inspect
15 |
16 | # add HiCFoundation dir into path
17 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
18 | parentdir = os.path.dirname(currentdir)
19 | sys.path.insert(0, parentdir)
20 | from data_processing.pretrain_dataset import Pretrain_Dataset
21 |
22 |
23 | import numpy as np
24 | import torchvision.transforms as transforms
25 |
26 | batch_size = 2
27 | transform_mean = [0.485, 0.456, 0.406]
28 | transform_std = [0.229, 0.224, 0.225]
29 |
30 |
31 | # pseudo dataset class for preprocessing
@pytest.fixture
def dataset(tmp_path):
    """
    Pseudo dataset used only to reach Pretrain_Dataset's preprocessing methods.

    The dataset is built on an empty directory. The directory lives under
    pytest's per-test ``tmp_path`` instead of the current working directory,
    so the test neither touches the checkout nor collides with other runs.
    """
    empty_folder_path = tmp_path / "temp_empty_folder"
    empty_folder_path.mkdir()

    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=transform_mean, std=transform_std),
    ])

    return Pretrain_Dataset(data_list=[str(empty_folder_path)], transform=transform_train)
52 |
53 |
54 | # example hic matrix
@pytest.fixture
def example_hic():
    """Small example Hi-C window as a [window_size, window_size] NumPy matrix."""
    rows = [[0.5, 1.0], [0.0, 0.75]]
    return np.array(rows)
60 |
61 |
def test_convert_rgb_shape_and_dtype(dataset, example_hic):
    """convert_rgb output should be the input shape plus a trailing axis of 3, in float32."""
    rgb = dataset.convert_rgb(example_hic, np.max(example_hic))
    expected_shape = (*example_hic.shape, 3)

    assert rgb.shape == expected_shape
    assert rgb.dtype == np.float32
67 |
68 |
def test_convert_rgb_values(dataset, example_hic):
    """
    Check that the converted channel values are as expected
    """
    peak = np.max(example_hic)
    out = dataset.convert_rgb(example_hic, peak)
    red, green, blue = out[:, :, 0], out[:, :, 1], out[:, :, 2]

    # the theoretical value for both the green and the blue channel
    expected_gb = (peak - example_hic) / peak

    assert np.allclose(red, np.ones_like(red)), "red channel should be all 1"
    assert np.allclose(green, expected_gb), "Green channel is not as expected"
    assert np.allclose(blue, expected_gb), "blue channel is not as expected"
84 |
85 |
def test_log10_convert(dataset, example_hic):
    """
    Test that convert_rgb handles log10-transformed input correctly
    """
    log10_input = np.log10(example_hic + 1)
    log10_max_val = np.log10(np.max(example_hic) + 1)
    out = dataset.convert_rgb(log10_input, log10_max_val)
    red, green, blue = out[:, :, 0], out[:, :, 1], out[:, :, 2]

    # the theoretical value for both the green and the blue channel
    expected_gb = (log10_max_val - log10_input) / log10_max_val

    assert np.allclose(red, np.ones_like(red))
    assert np.allclose(green, expected_gb)
    assert np.allclose(blue, expected_gb)
103 |
104 |
def test_post_convert_transform(dataset, example_hic):
    """
    Test that undoing the normalisation transform reproduces the converted matrix
    """
    log10_input = np.log10(example_hic + 1)
    log10_max_val = np.log10(np.max(example_hic) + 1)

    out = dataset.convert_rgb(log10_input, log10_max_val)
    out = dataset.transform(out)  # [3, window_size, window_size]
    assert out.shape[0] == 3, (
        "After transformation, the first dimension should be 3, then the window_size by window_size"
    )

    # invert the per-channel normalisation: x * std + mean
    restored = [
        out[channel, :, :] * transform_std[channel] + transform_mean[channel]
        for channel in range(3)
    ]
    red_inverse, green_inverse, blue_inverse = restored

    expected_gb = (log10_max_val - log10_input) / log10_max_val

    assert np.allclose(red_inverse, np.ones_like(red_inverse), atol=1e-6)
    assert np.allclose(green_inverse, expected_gb, atol=1e-6)
    assert np.allclose(blue_inverse, expected_gb, atol=1e-6)
126 |
--------------------------------------------------------------------------------
/test/test_forward.py:
--------------------------------------------------------------------------------
1 | """
2 | Unit test for model forward
3 |
4 | Author: Xumeng Zhang (xumzhang@uw.edu)
5 |
6 | Test if the forward functions in HiCFoundation model work
7 | Models_HiCFoundation.forward_encoder()
8 | Models_HiCFoundation.forward_decoder()
9 | Models_HiCFoundation.forward()
10 |
11 | """
12 |
13 | import pytest
14 | import os
15 | import sys
16 | import inspect
17 | import numpy as np
18 | import torch
19 |
20 | # add HiCFoundation dir into path
21 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
22 | parentdir = os.path.dirname(currentdir)
23 | sys.path.insert(0, parentdir)
24 | from model.models_hicfoundation import Models_HiCFoundation
25 |
26 |
27 | batch_size = 3 # batch size
28 | in_chans = 3 # channels (1 or 3)
29 | image_size = 4 # image size
30 | patch_size = 2
31 |
32 |
33 | # initialize model
@pytest.fixture
def model():
    """Models_HiCFoundation instance built with the module-level image and patch sizes."""
    square_size = (image_size, image_size)
    return Models_HiCFoundation(img_size=square_size, patch_size=patch_size)
39 |
40 |
41 | #
@pytest.fixture
def example_input():
    """
    Random batch plus masks, total counts and diagonal offsets.

    Samples 1 and 3 have an all-ones mask (no sparsity), sample 2 has an
    all-zero mask, and the last sample's total count is NaN.
    """
    imgs = torch.randn(batch_size, in_chans, image_size, image_size)
    imgs_mask = torch.ones(batch_size, 1, image_size, image_size)
    imgs_mask[1] = 0  # sample 2: all 0 matrix
    total_count = torch.tensor([1000000, 1000000, float("nan")])
    diag = torch.zeros(3, dtype=torch.float32)
    return imgs, imgs_mask, total_count, diag
52 |
53 |
def test_forward_encoder(model, example_input):
    """forward_encoder keeps the batch dimension, with and without total_count."""
    imgs, imgs_mask, total_count, diag = example_input
    n_samples = imgs.shape[0]

    latent, mask, ids_restore = model.forward_encoder(
        imgs, total_count=total_count, diag=diag, mask_ratio=0.6
    )
    assert latent.shape[0] == n_samples
    assert mask.shape[0] == n_samples
    assert ids_restore.shape[0] == n_samples
    assert torch.isnan(latent[2]).all(), (
        "input nan total count should result in nan encoder output"
    )

    # Second case: total_count is not passed at all (left at its default).
    # This differs from a NaN inside the total_count tensor, which leads to
    # a NaN loss.
    latent, mask, ids_restore = model.forward_encoder(
        imgs, diag=diag, mask_ratio=0.6
    )
    assert latent.shape[0] == n_samples
    assert mask.shape[0] == n_samples
    assert ids_restore.shape[0] == n_samples
    assert not torch.isnan(latent).any(), "encoder output contains NaNs"
78 |
79 |
def test_forward_decoder(model, example_input):
    """forward_decoder returns a per-sample count and per-patch pixel predictions."""
    imgs, imgs_mask, total_count, diag = example_input
    n_samples = imgs.shape[0]
    n_patches = (image_size // patch_size) ** 2
    patch_dim = patch_size * patch_size * in_chans

    latent, mask, ids_restore = model.forward_encoder(
        imgs, total_count=total_count, diag=diag, mask_ratio=0.6
    )
    count_pred, patch_pred = model.forward_decoder(latent, ids_restore)

    assert count_pred.shape[0] == n_samples
    assert patch_pred.shape[0] == n_samples
    assert patch_pred.shape[1] == n_patches
    assert patch_pred.shape[2] == patch_dim
    assert torch.isnan(count_pred[2]), (
        "input nan total count should result in nan count pred"
    )
94 |
95 |
def test_forward(model, example_input):
    """Full forward pass: NaN propagation first, then a batch without the NaN sample."""
    imgs, imgs_mask, total_count, diag = example_input

    outputs = model.forward(
        imgs,
        imgs_mask,
        total_count=total_count,
        diag=diag,
        mask_ratio=0.4,
    )
    ssim_loss, contrastive_loss, count_pred, pred_img, mask = outputs
    assert torch.isnan(ssim_loss), "input nan total count should result in nan ssimloss"
    assert torch.isnan(contrastive_loss), (
        "input nan total count should result in nan contrastive loss"
    )
    assert count_pred.shape[0] == imgs.shape[0]
    assert pred_img.shape == imgs.shape

    # drop the last sample, whose total_count is nan
    imgs, imgs_mask = imgs[:-1], imgs_mask[:-1]
    total_count, diag = total_count[:-1], diag[:-1]
    outputs = model.forward(
        imgs,
        imgs_mask,
        total_count=total_count,
        diag=diag,
        mask_ratio=0.4
    )
    ssim_loss, contrastive_loss, count_pred, pred_img, mask = outputs
    assert ssim_loss.item() >= 0
    assert contrastive_loss.item() >= 0
    assert count_pred.shape[0] == imgs.shape[0]
    assert pred_img.shape == imgs.shape
    assert mask.shape == pred_img.shape
    # every channel of a sample must carry the same mask
    for i in range(batch_size - 1):
        same_01 = torch.equal(mask[i, 0], mask[i, 1])
        assert same_01 and torch.equal(
            mask[i, 1], mask[i, 2]
        ), f"mask differ amont channels for sample {i}"
134 |
135 |
--------------------------------------------------------------------------------
/test/test_loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Unit test for loss functions
3 |
4 | Author: Xumeng Zhang (xumzhang@uw.edu)
5 |
6 | Models_HiCFoundation.forward_loss(self, imgs, imgs_mask, pred, mask):
7 | imgs: [N, 3, H, W]
8 | imgs_mask: [N, 1, H, W] indicate those 0 regions and mask them in target
9 | pred: [N, L, D], sequence of embeddings
10 | mask: [N, L], binary mask, 0 is keep, 1 is remove
11 |
12 | """
13 |
14 | import pytest
15 | import os
16 | import sys
17 | import inspect
18 | import numpy as np
19 |
20 | # add HiCFoundation dir into path
21 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
22 | parentdir = os.path.dirname(currentdir)
23 | sys.path.insert(0, parentdir)
24 | from model.models_hicfoundation import Models_HiCFoundation
25 | from data_processing.pretrain_dataset import Pretrain_Dataset
26 |
27 | import torchvision.transforms as transforms
28 |
29 | import torch
30 |
31 | image_size = 4
32 | patch_size = 2
33 | batch_size = 2
34 | mask_ratio = 0.6
35 | transform_mean = [0.485, 0.456, 0.406]
36 | transform_std = [0.229, 0.224, 0.225]
37 |
38 |
# model under test, sized by the module-level constants
@pytest.fixture
def model():
    """Return a ``Models_HiCFoundation`` built for a square test image."""
    side = image_size
    return Models_HiCFoundation(img_size=(side, side), patch_size=patch_size)
45 |
46 |
# pseudo dataset class for preprocessing
@pytest.fixture
def dataset():
    """Build a ``Pretrain_Dataset`` over an empty directory.

    The dataset is only needed for its preprocessing helpers
    (``convert_rgb`` / ``transform`` are what the loss tests call), so it is
    pointed at a directory that contains no samples.
    """
    import tempfile  # local import: only this fixture needs it

    transform_train = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=transform_mean, std=transform_std),
        ]
    )

    # A unique temporary directory cannot collide with an existing
    # "temp_empty_folder" in the working directory (which used to abort the
    # fixture) and is removed even when the dataset constructor raises.
    with tempfile.TemporaryDirectory() as empty_folder_path:
        return Pretrain_Dataset(
            data_list=[empty_folder_path], transform=transform_train
        )
68 |
69 |
# Test when input is rgb
@pytest.mark.parametrize("in_chans", [3])
def test_forward_loss_rgb(model, dataset, in_chans):
    """
    Check the SSIM and contrastive losses for a 3-channel (RGB-converted) input.

    Both ``imgs`` and ``pred`` handed to ``model.forward_loss`` are expected
    to be in normalized space, so the raw count matrix is pushed through
    ``dataset.convert_rgb`` and ``dataset.transform`` first.

    Args:
        model: initialized example model
        dataset: pseudo dataset providing the preprocessing helpers
        in_chans: number of image channels (3 for this test)
    """
    torch.manual_seed(42)
    # torch.manual_seed does not seed NumPy; draw the synthetic counts from a
    # seeded generator so the whole test is reproducible.
    rng = np.random.default_rng(42)

    B, C, H, W = batch_size, in_chans, image_size, image_size  # small test case
    L = (H // patch_size) * (W // patch_size)

    # simulate the data preprocessing steps that run before the loss function
    imgs = rng.integers(low=0, high=100, size=(H, W))
    imgs_converted = dataset.convert_rgb(imgs, max_value=np.max(imgs))  # (H, W, C)
    imgs_converted = dataset.transform(imgs_converted)  # (C, H, W)
    imgs_converted = torch.concat(
        [imgs_converted.unsqueeze(0)] * batch_size, dim=0
    )  # (B, C, H, W) repeat across batch

    # patchified target, used later as a "perfect" prediction
    imgs_patched = model.patchify(imgs_converted)  # (B, L, D)

    # sparsity
    imgs_mask = (torch.rand(B, 1, H, W) * 0.2).float()

    # random prediction
    pred = torch.rand(B, L, patch_size * patch_size * C)
    mask = (torch.rand(B, L) > mask_ratio).float()

    ssim_loss, contrastive_loss = model.forward_loss(
        imgs_converted, imgs_mask, pred, mask
    )

    # sanity check
    assert isinstance(ssim_loss, torch.Tensor), "SSIM loss should be a tensor"
    assert isinstance(contrastive_loss, torch.Tensor), (
        "Contrastive loss should be a tensor"
    )
    assert not torch.isnan(ssim_loss), "SSIM loss shouldn't be NaN"
    assert not torch.isnan(contrastive_loss), "Contrastive loss shouldn't be NaN"
    assert ssim_loss.item() >= 0, "SSIM loss should be non-negative"
    assert contrastive_loss.item() >= 0, "Contrastive loss should be non-negative"
    assert ssim_loss.item() != 0, "SSIM loss with random prediction should not be 0"
    assert contrastive_loss.item() != 0, (
        "Contrastive loss with random prediction should not be 0"
    )

    # loss of the image against itself; the prediction must already be in
    # the converted (rgb, normalized) space
    ssim_loss, contrastive_loss = model.forward_loss(
        imgs_converted, imgs_mask, imgs_patched, mask
    )
    assert ssim_loss.item() == 0, (
        f"ssim loss with itself should be 0, but it is {ssim_loss}"
    )
    # The contrastive loss against itself is deliberately not asserted: per
    # the original author's note it never reaches 0 with this formulation.
137 |
138 |
# Test when input is count matrix
@pytest.mark.parametrize("in_chans", [1])
def test_forward_loss_count(model, in_chans):
    """
    Check the SSIM and contrastive losses for a Hi-C count matrix (in_chans=1).

    The raw count matrix is fed to ``model.forward_loss`` directly, with the
    model switched to single-channel mode.

    Args:
        model: initialized example model
        in_chans: number of image channels (1 for this test)
    """
    torch.manual_seed(42)
    # torch.manual_seed does not seed NumPy; draw the synthetic counts from a
    # seeded generator so the whole test is reproducible.
    rng = np.random.default_rng(42)

    B, C, H, W = batch_size, in_chans, image_size, image_size  # small test case
    L = (H // patch_size) * (W // patch_size)

    # random example
    img_np = rng.integers(low=0, high=100, size=(H, W))
    img_torch = torch.tensor(img_np).unsqueeze(0).float()  # (1,H,W)
    imgs_torch = torch.concat(
        [img_torch.unsqueeze(0)] * batch_size, dim=0
    )  # (B, C, H, W)

    # tune the model into 1 in_chans mode
    model.in_chans = in_chans

    # patchify is only exercised here; its result is not used further
    imgs_patched = model.patchify(imgs_torch)  # (B, L, D)

    imgs_mask = (torch.rand(B, 1, H, W) * 0.2).float()

    # random prediction
    pred = torch.rand(B, L, patch_size * patch_size * C)
    mask = (torch.rand(B, L) > mask_ratio).float()

    ssim_loss, contrastive_loss = model.forward_loss(imgs_torch, imgs_mask, pred, mask)

    # sanity check
    assert isinstance(ssim_loss, torch.Tensor), "SSIM loss should be a tensor"
    assert isinstance(contrastive_loss, torch.Tensor), (
        "Contrastive loss should be a tensor"
    )
    assert not torch.isnan(ssim_loss), "SSIM loss shouldn't be NaN"
    assert not torch.isnan(contrastive_loss), "Contrastive loss shouldn't be NaN"
    assert ssim_loss.item() >= 0, "SSIM loss should be non-negative"
    assert contrastive_loss.item() >= 0, "Contrastive loss should be non-negative"
    assert ssim_loss.item() != 0, "SSIM loss with random prediction should not be 0"
    assert contrastive_loss.item() != 0, (
        "Contrastive loss with random prediction should not be 0"
    )
189 |
--------------------------------------------------------------------------------
/test/test_masking.py:
--------------------------------------------------------------------------------
1 | """
2 | Unit test for masking
3 |
4 | Author: Xumeng Zhang (xumzhang@uw.edu)
5 |
6 | Test random masking for patches.
7 | Focus on diagonal symmetric masking as well.
8 |
9 | random_masking(self, x, mask_ratio,diag=None):
10 | x: [N, L, D], sequence (here L is without additional tokens)
11 | mask_ratio: float, masking ratio
12 | diag: [N,1] diagonal position to symmetrical masking, if None, then random masking
13 | """
14 |
15 | import pytest
16 | import os
17 | import sys
18 | import inspect
19 |
20 | # add HiCFoundation dir into path
21 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
22 | parentdir = os.path.dirname(currentdir)
23 | sys.path.insert(0, parentdir)
24 | from model.models_hicfoundation import Models_HiCFoundation, apply_symmectric_noise
25 |
26 |
27 | import torch
28 |
29 | batch_size = 3
30 | mask_ratio = 0.6
31 | image_size = 128
32 | patch_size = 16
33 |
34 |
# model instance shared by the masking tests below
@pytest.fixture
def model():
    """Return a ``Models_HiCFoundation`` for a square image of ``image_size``."""
    img_shape = (image_size, image_size)
    return Models_HiCFoundation(img_size=img_shape, patch_size=patch_size)
41 |
42 |
# deterministic example noise: entry (i, j) equals i * 10 + j
@pytest.fixture
def noise():
    """Return a ``[batch_size, n, n]`` tensor with ``n = image_size // patch_size``.

    Every batch element holds the same grid whose value at row ``i`` and
    column ``j`` is ``i * 10 + j``.
    """
    n_side = image_size // patch_size
    row_idx = torch.arange(n_side).unsqueeze(1)
    col_idx = torch.arange(n_side).unsqueeze(0)
    grid = torch.zeros([n_side, n_side])
    # copy_ casts the integer index arithmetic into the zeros tensor's dtype
    grid.copy_(row_idx * 10 + col_idx)
    batch = torch.stack([grid for _ in range(batch_size)])
    print(batch.shape)
    return batch
53 |
54 |
def test_apply_symmetric_noise_shape_and_symmetry(noise):
    """Check that ``apply_symmectric_noise`` keeps the shape and symmetrizes.

    One sample is tested per diagonal offset (0, -2 and 2); for each, the
    sub-block selected below must equal its own transpose.
    """
    diag = torch.tensor([0, -2, 2])

    # work on a copy so the fixture tensor is never modified
    output = apply_symmectric_noise(noise.clone(), diag)

    assert output.shape == noise.shape

    symmetric_blocks = (
        output[0],  # Case 1: diag = 0, i.e. full matrix should be symmetric
        output[1, 2:, :-2],  # Case 2: diag = -2
        output[2, :-2, 2:],  # Case 3: diag = 2
    )
    for block in symmetric_blocks:
        assert torch.allclose(block, block.T, atol=1e-6)
76 |
77 |
def test_random_masking_shape(model):
    """
    Check the output shapes of ``Models_HiCFoundation.random_masking``.

    The input is a ``(batch, tokens, dim)`` tensor; the test also confirms
    that every sample has exactly ``tokens - kept`` masked positions.
    """
    tokens = torch.randn(batch_size, model.num_patches, model.embed_dim)
    print(tokens.shape)

    x_masked, mask, ids_restore = model.random_masking(tokens, mask_ratio)

    n_tokens = tokens.shape[1]
    n_kept = int(n_tokens * (1 - mask_ratio))
    n_masked = n_tokens - n_kept

    assert x_masked.shape == (batch_size, n_kept, model.embed_dim)
    assert mask.shape == (batch_size, n_tokens)
    assert ids_restore.shape == (batch_size, n_tokens)
    # confirm number of masked positions
    masked_per_sample = mask.sum(dim=1)
    assert (masked_per_sample == n_masked).all()
98 |
99 |
def test_random_masking_deterministic_output_on_seed(model):
    """Masking must repeat under the same seed and differ between draws."""

    def seeded_input():
        # re-seed, then draw the input so both the input and the masking
        # that follows start from the same RNG state
        torch.manual_seed(42)
        return torch.randn(1, model.num_patches, model.embed_dim)

    first_masked, first_mask, _ = model.random_masking(seeded_input(), mask_ratio)
    second_masked, second_mask, _ = model.random_masking(seeded_input(), mask_ratio)

    # Same seed and same input: the two results should match
    assert torch.allclose(first_masked, second_masked), (
        "The same random seed leads to different masking"
    )
    assert torch.allclose(first_mask, second_mask), (
        "The same random seed leads to different masking"
    )

    x = seeded_input()
    first_masked, first_mask, _ = model.random_masking(x, mask_ratio)
    second_masked, second_mask, _ = model.random_masking(x, mask_ratio)

    # Two consecutive draws without re-seeding: the results should differ
    assert not torch.allclose(first_masked, second_masked), (
        "Different sampling leads to the same masking"
    )
    assert not torch.allclose(first_mask, second_mask), (
        "Different sampling leads to the same masking"
    )
129 |
--------------------------------------------------------------------------------
/test/test_patchify.py:
--------------------------------------------------------------------------------
1 | """
2 | Unit test for patchify and unpatchify
3 | With a focus on the order of each patchify channels
4 |
5 | Author: Xumeng Zhang (xumzhang@uw.edu)
6 |
7 | Models_HiCFoundation.patchify(self, imgs, in_chans=None)
8 | imgs: (N, 3, H, W)
9 | x: (N, L, patch_size**2 *self.in_chans)
10 |
11 | Models_HiCFoundation.unpatchify(self, x, in_chans=None):
12 | x: (N, L, patch_size**2 *self.in_chans)
13 |
14 | """
15 |
16 | import pytest
17 | import os
18 | import sys
19 | import inspect
20 |
21 | # add HiCFoundation dir into path
22 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
23 | parentdir = os.path.dirname(currentdir)
24 | sys.path.insert(0, parentdir)
25 | from model.models_hicfoundation import Models_HiCFoundation
26 |
27 |
28 | import torch
29 |
30 | image_size = 4
31 | patch_size = 2
32 | batch_size = 2
33 |
34 | # test for in_chans both 1 and 3 each time
35 | in_chans = 3
36 |
37 |
# feed an easily interpretable example image
def make_example_image(in_chans):
    """
    Make an example batch whose patchify result is easy to predict.

    Patches are numbered row by row starting at 1 and every patch is
    constant. For the count matrix, or the R channel of an rgb image, a
    4x4 image with 2x2 patches looks like
        tensor([[1., 1., 2., 2.],
                [1., 1., 2., 2.],
                [3., 3., 4., 4.],
                [3., 3., 4., 4.]])
    For rgb input the G channel is all 0 and the B channel is the negative
    of the R channel.

    Args:
        in_chans: number of channels, 1 (count matrix) or 3 (rgb)

    Returns:
        tensor: [B, C, H, W], ``batch_size`` identical copies of the image

    Raises:
        ValueError: if ``in_chans`` is neither 1 nor 3
    """
    if in_chans not in (1, 3):
        raise ValueError(f"in_chans must be 1 or 3, got {in_chans}")

    n_patch_row = image_size // patch_size

    patches = []
    for i in range(n_patch_row * n_patch_row):
        # R channel (or the only channel): constant i + 1
        red = torch.full(
            (patch_size, patch_size), fill_value=(i + 1), dtype=torch.float32
        )
        if in_chans == 1:
            patch = red.unsqueeze(0)
        else:
            # G channel: always 0, B channel: -(i + 1)
            patch = torch.stack([red, torch.zeros_like(red), -red])
        assert patch.shape == (in_chans, patch_size, patch_size), f"{patch.shape}"
        patches.append(patch)

    patches = torch.stack(patches)  # [num_patches, inchan, patch_size, patch_size]
    patches = patches.reshape(
        n_patch_row, n_patch_row, in_chans, patch_size, patch_size
    )
    img = patches.permute(
        2, 0, 3, 1, 4
    )  # [inchan, n_patch_row, patch_size, n_patch_row, patch_size]
    img = img.reshape(in_chans, image_size, image_size)  # [C, H, W]

    # check the first patch is all 1 and the last is all n_patch_row**2
    assert torch.all(img[0, :patch_size, :patch_size] == 1)
    assert torch.all(img[0, -patch_size:, -patch_size:] == n_patch_row**2)
    if in_chans == 3:
        assert torch.all(img[1, :, :] == 0)
        assert torch.all(img[2, :patch_size, :patch_size] == -1)
        assert torch.all(img[2, -patch_size:, -patch_size:] == -(n_patch_row**2))

    # Follow the module-level batch_size instead of hard-coding two copies,
    # so the helper stays in sync with the shape assertions in the tests.
    return img.unsqueeze(0).repeat(batch_size, 1, 1, 1)  # [B, C, H, W]
98 |
99 |
# model whose patchify / unpatchify methods are under test
@pytest.fixture
def model():
    """Return a ``Models_HiCFoundation`` matching the example image size."""
    size = (image_size, image_size)
    return Models_HiCFoundation(img_size=size, patch_size=patch_size)
106 |
107 |
@pytest.mark.parametrize("in_chans", [1, 3])
def test_patchify_output(model, in_chans):
    """Check the shape and the per-patch channel layout of ``model.patchify``."""
    example_image = make_example_image(in_chans=in_chans)
    patches = model.patchify(example_image, in_chans=in_chans)

    # output shape, computed from the model attributes ...
    shape_from_model = (
        batch_size,
        model.num_patches,
        model.patch_size**2 * in_chans,
    )
    assert patches.shape == shape_from_model
    # ... and from the module constants (another way calculating)
    shape_from_constants = (
        batch_size,
        (image_size / patch_size) ** 2,
        patch_size**2 * in_chans,
    )
    assert patches.shape == shape_from_constants

    # the two batch elements are the same image
    assert torch.all(patches[0] == patches[1])

    # every patch holds its per-pixel channel values repeated patch_size**2
    # times: [i + 1] for one channel, [i + 1, 0, -(i + 1)] otherwise
    channel_signs = [1] if in_chans == 1 else [1, 0, -1]
    for idx in range(model.num_patches):
        per_pixel = [(idx + 1) * sign for sign in channel_signs]
        expected = torch.tensor(per_pixel * patch_size**2)
        assert torch.all(patches[0, idx, :] == expected), (
            f"patch No. {idx} patchify not expected"
        )
138 |
139 |
# unpatchify is verified by a round trip: patchify followed by unpatchify
# must give back the original image
@pytest.mark.parametrize("in_chans", [1, 3])
def test_patchify_unpatchify_equivalence(model, in_chans):
    """Check that ``unpatchify(patchify(img))`` reproduces ``img``."""
    original = make_example_image(in_chans=in_chans)
    round_trip = model.unpatchify(
        model.patchify(original, in_chans=in_chans), in_chans=in_chans
    )

    assert round_trip.shape == original.shape
    images_match = torch.allclose(original, round_trip, atol=1e-6)
    assert images_match, "Reconstructed patchified image doesn't match original"
151 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # This dir includes independent functions that are used in the main scripts.
2 | # They can also be run independently.
--------------------------------------------------------------------------------
/utils/array2bigwig.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import shutil
4 | from collections import defaultdict
5 | import pyBigWig
6 | import sys
7 | import pickle
8 | # def parse_genome_size(genome_size_file):
9 | # chrom_size = {}
10 | # with open(genome_size_file) as f:
11 | # for line in f:
12 | # chrom, size = line.strip("\n").split()
13 | # chrom_size[chrom] = int(size)
14 | # return chrom_size
15 |
def array2bigwig(input_file,output_bigwig,resolution=1000):
    """
    Convert a pickled ``{chrom: 1D array}`` dict into a bigWig file.

    Every array element becomes one interval of ``resolution`` bp, and the
    chromosome length written to the header is ``len(array) * resolution``.
    Values go through ``np.nan_to_num`` first, so NaN is written as 0.

    :param input_file: path to the pickle file holding the dict
    :param output_bigwig: path of the bigWig file to create
    :param resolution: bin width (bp) of the stored arrays
    :return: ``output_bigwig``
    :raises TypeError: if a chromosome key is not a string
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only use this on trusted inputs.
    with open(input_file, 'rb') as f:
        data = pickle.load(f)

    # Validate the keys before the output file is created. The previous
    # check iterated over the characters of each name and therefore never
    # validated the key itself.
    for chrom in data:
        if not isinstance(chrom, str):
            raise TypeError(f"Chromosomes must be strings, got {chrom!r}")

    # chromosome sizes are inferred from the number of bins
    chrom_size = {chrom: len(values) * resolution for chrom, values in data.items()}

    with pyBigWig.open(output_bigwig, "w") as bw:
        chromosomes = list(chrom_size.items())
        print(chromosomes)
        # Add chromosome information to the BigWig file
        bw.addHeader(chromosomes)
        for chrom, size in chromosomes:
            print("start chrom",chrom,"size",size)
            starts = np.arange(0, size, resolution).astype(int)
            # NOTE(review): ends are capped at size - 1 as in the original,
            # which shortens the last interval by one base - confirm whether
            # capping at size was intended.
            ends = np.clip(starts + resolution, 0, size - 1)
            # plain Python ints/floats are handed to pyBigWig
            start_embedding = starts.tolist()
            end_embedding = ends.tolist()
            chrom_data = [float(x) for x in np.nan_to_num(data[chrom])]
            print({type(x) for x in chrom_data})
            print(len(chrom_data),len(start_embedding),len(end_embedding))
            for s, e, v in zip(start_embedding[:10], end_embedding[:10], chrom_data[:10]):
                print(f"Chromosome: {chrom}, Start: {s}, End: {e}, Value: {v}")
            if chrom_data:
                bw.addEntries([chrom]*len(chrom_data), start_embedding, ends=end_embedding, values=chrom_data)
            else:
                # nothing to write for an empty array
                print("skip empty chrom",chrom)
            print("finished",chrom)
    return output_bigwig
72 |
73 | """
74 | This script converts a pickled dict of per-chromosome 1D arrays into a bigwig file.
75 | ```
76 | python3 array2bigwig.py [input_file] [output_bigwig] [resolution]
77 | ```
78 | [input_file]: the pkl file to be converted, should be a dict in format of [chr]:[array].
79 | [output_bigwig]: the output bigwig file.
80 | [resolution]: the resolution stored in the pkl file.
81 |
82 | """
83 |
84 |
if __name__ == '__main__':
    # Command-line entry point: array2bigwig.py [input_file] [output_bigwig] [resolution]
    if len(sys.argv) != 4:
        print("Usage: python3 array2bigwig.py [input_file] [output_bigwig] [resolution]")
        print("input_file: the pkl file to be converted, should be a dict in format of [chr]:[array].")
        print("output_bigwig: the output bigwig file.")
        print("resolution: the resolution stored in the pkl file.")
        sys.exit(1)
    input_file = os.path.abspath(sys.argv[1])
    output_bigwig = os.path.abspath(sys.argv[2])
    # Create the directory that will contain the output file. basename()
    # returned the file name itself, so a directory named after the output
    # file was created in the current working directory instead.
    output_dir = os.path.dirname(output_bigwig)
    os.makedirs(output_dir, exist_ok=True)
    resolution = int(sys.argv[3])
    output_bw = array2bigwig(input_file, output_bigwig, resolution)

    print("Finished converting saved to %s" % output_bw)
101 |
--------------------------------------------------------------------------------
/utils/array2cool.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import pickle
4 | import numpy as np
5 | import pandas as pd
6 | import cooler
def array2sparse(array):
    """
    Build a scipy COO sparse matrix from the non-zero entries of a 2D array.

    :param array: dense 2D numpy array
    :return: ``scipy.sparse.coo_matrix`` with the same shape as ``array``
    """
    from scipy.sparse import coo_matrix
    nonzero_rows, nonzero_cols = np.nonzero(array)
    nonzero_values = array[nonzero_rows, nonzero_cols]
    return coo_matrix(
        (nonzero_values, (nonzero_rows, nonzero_cols)), shape=array.shape
    )
def array2cool(input_array_pickle,output_cool,resolution,refer_genome_name,mode):
    """
    Convert a pickled dict of contact matrices to a cool file.

    Bins are laid out chromosome by chromosome in sorted key order; the
    number of bins of a chromosome is taken from the first dimension of its
    intra-chromosomal matrix.

    :param input_array_pickle: path to the pickle file containing the dict
    :param output_cool: path of the output cool file
    :param resolution: bin size of the stored matrices
    :param refer_genome_name: reference genome name stored as the assembly
    :param mode: 0 or 2: keys are "chrom1_chrom2"; 1 or 3: keys are single
        chromosome names. For mode >= 2 the values are dense numpy arrays
        (converted with array2sparse); otherwise they must expose
        ``row`` / ``col`` / ``data`` like a scipy COO matrix.
    :return: None, the cool file is written to ``output_cool``
    :raises KeyError: if a matrix refers to a chromosome that has no
        intra-chromosomal matrix (its bin offset is unknown)
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only use this on trusted inputs.
    with open(input_array_pickle, 'rb') as f:
        data = pickle.load(f)
    output_dir = os.path.dirname(output_cool)
    if output_dir:
        # a bare file name has no directory part and os.makedirs('') raises
        os.makedirs(output_dir, exist_ok=True)

    pair_keys = mode == 0 or mode == 2

    def split_key(key):
        # "chrom1_chrom2" keys in all-chromosome modes, plain names otherwise
        if pair_keys:
            first, second = key.split('_')
            return first, second
        return key, key

    def with_chr(name):
        return name if "chr" in name else "chr" + name

    # set each chromosome's length and its first bin index
    chromsizes = {"name": [], "length": []}
    chromosize_add_dict = {}  # [chr_name: index of the chromosome's first bin]
    accumulate_index = 0
    for chrom_name in sorted(data.keys()):
        chrom1, chrom2 = split_key(chrom_name)
        if chrom1 != chrom2:
            # only intra-chromosomal matrices define chromosome sizes
            continue
        chrom1 = with_chr(chrom1)
        n_bins = data[chrom_name].shape[0]
        chromsizes['name'].append(chrom1)
        chromsizes['length'].append(resolution * n_bins)
        chromosize_add_dict[chrom1] = accumulate_index
        accumulate_index += n_bins
    print("collecting bin dict size",chromsizes)
    chrom_dict = pd.DataFrame.from_dict(chromsizes).set_index("name")['length']
    bins = cooler.binnify(chrom_dict, resolution)

    # then convert every matrix to (bin1_id, bin2_id, count) records
    data_dict = {"bin1_id": [], "bin2_id": [], "count": []}
    for key in data:
        chrom1, chrom2 = (with_chr(name) for name in split_key(key))
        print("processing",chrom1,chrom2,"...")
        for chrom in (chrom1, chrom2):
            if chrom not in chromosize_add_dict:
                raise KeyError(
                    f"no intra-chromosomal matrix for {chrom}; its bin offset is unknown"
                )
        matrix = data[key]
        if mode >= 2:
            matrix = array2sparse(matrix)
        # shift local bin indices to genome-wide bin ids (on new arrays, the
        # matrix's own index arrays are left untouched)
        bin1_ids = matrix.row + chromosize_add_dict[chrom1]
        bin2_ids = matrix.col + chromosize_add_dict[chrom2]
        data_dict['bin1_id'].extend(bin1_ids.tolist())
        data_dict["bin2_id"].extend(bin2_ids.tolist())
        data_dict['count'].extend(matrix.data.tolist())
    print("creating cool file...")
    cooler.create_cooler(output_cool, bins=pd.DataFrame.from_dict(bins), pixels=pd.DataFrame.from_dict(data_dict), dtypes={'count': float},assembly=refer_genome_name)
96 | """
97 | Usage
98 | ```
99 | python3 array2cool.py [input.pkl] [output.cool] [resolution] [refer_genome_name] [mode]
100 | ```
101 | The input pickle should be in a pickle file as dict: [chrom1_chrom2]:[array] format for common mode. Here array should be scipy sparse array.
102 | For intra-chromosome only, the dict format can be [chrom]:[array] in pickle files.
103 | [output.cool] is the name of the output cool file.
104 | [resolution] is used to specify the resolution that stored in the output array.
105 | [refer_genome_name] is used to specify the reference genome name. For example, "hg38","hg19","mm10" are valid inputs.
106 | [mode]: 0: all chromosome mode (scipy sparce array); 1: intra-chromosome mode(scipy sparce array); 2: all chromosome mode (numpy array); 3: intra-chromosome mode(numpy array).
107 | """
if __name__ == '__main__':
    # Command-line entry point: validate the argument count, then convert.
    if len(sys.argv)!=6:
        usage_lines = (
            'Usage: python3 array2cool.py [input.pkl] [output.cool] [resolution] [refer_genome_name] [mode]',
            "This is the full array2cool script. ",
            "input.pkl: the path to the pickle file containing the array [String].",
            "input.pkl format: [chrom1_chrom2]:[array] format for common mode. Here array should be scipy sparce array. For intra-chromsome only, the dict format can be [chrom]:[array] in pickle files.",
            "output.cool: the name of the output cool file [String].",
            "resolution: resolution of the input array [Integer].",
            "refer_genome_name: the name of the reference genome [String]. Example: hg38, hg19, mm10.",
            "mode: 0: all chromosome mode (scipy sparce array); 1: intra-chromosome mode(scipy sparce array); 2: all chromosome mode (numpy array); 3: intra-chromosome mode(numpy array).",
        )
        for usage_line in usage_lines:
            print(usage_line)
        sys.exit(1)

    script_dir = os.path.dirname(os.path.realpath(__file__))
    pkl_arg, cool_arg, resolution_arg, genome_arg, mode_arg = sys.argv[1:]
    input_array_pickle = os.path.abspath(pkl_arg)
    output_hic = os.path.abspath(cool_arg)  # path of the output cool file
    resolution = int(resolution_arg)
    refer_genome_name = str(genome_arg)
    mode = int(mode_arg)
    array2cool(input_array_pickle,output_hic,resolution,refer_genome_name,mode)
--------------------------------------------------------------------------------
/utils/array2hic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pickle
4 | import numpy as np
5 | #assume at least run on a machine with 8 CPUs+64G memory
6 |
def array2sparse(array):
    """
    Convert a dense 2D numpy array to a scipy COO sparse matrix.

    Only the non-zero entries are stored; the shape is preserved.

    :param array: dense 2D numpy array
    :return: ``scipy.sparse.coo_matrix`` holding the non-zero entries
    """
    from scipy.sparse import coo_matrix
    row_idx, col_idx = np.asarray(array).nonzero()
    entries = array[row_idx, col_idx]
    coordinates = (row_idx, col_idx)
    return coo_matrix((entries, coordinates), shape=array.shape)
19 |
def array2hic(juicer_tools,input_array_pickle,
              output_hic,resolution,refer_genome_name,mode=0):
    """
    Convert a pickled dict of contact matrices to a hic file.

    The contacts are first written to a temporary text file
    (``<output_hic>.raw``) which is then handed to ``juicer_tools pre``.
    The temporary file is always removed afterwards.

    :param juicer_tools: location of the juicer_tools jar
    :param input_array_pickle: path to the pickle file containing the dict
    :param output_hic: path of the output hic file
    :param resolution: resolution of the stored matrices
    :param refer_genome_name: reference genome name passed to juicer_tools
    :param mode: 0 or 2: keys are "chrom1_chrom2"; 1 or 3: keys are single
        chromosome names. For mode >= 2 the values are dense numpy arrays
        (converted with array2sparse); otherwise scipy COO matrices.
    :return: None, the hic file is written to ``output_hic``
    """
    import subprocess  # local import: only needed to launch juicer_tools

    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only use this on trusted inputs.
    with open(input_array_pickle, 'rb') as f:
        data = pickle.load(f)

    # Absolute paths stay valid although java runs in the jar's directory.
    output_hic = os.path.abspath(output_hic)
    os.makedirs(os.path.dirname(output_hic), exist_ok=True)
    # Append the suffix rather than str.replace('.hic', '.raw'): for a name
    # without ".hic" the replace left raw_path equal to output_hic, so the
    # final cleanup deleted the output itself.
    raw_path = output_hic + '.raw'

    # The original always ran "juicer_tools.jar" from the directory of
    # `juicer_tools`; keep that as the fallback when the given path is not
    # itself an existing file.
    jar_path = os.path.abspath(juicer_tools)
    code_path = os.path.dirname(jar_path)
    if not os.path.isfile(jar_path):
        jar_path = os.path.join(code_path, 'juicer_tools.jar')

    try:
        with open(raw_path, 'w') as wfile:
            for key in data:
                if mode == 0 or mode == 2:
                    chrom1, chrom2 = key.split('_')
                else:
                    chrom1 = key
                    chrom2 = key
                matrix = data[key]
                if mode>=2:
                    matrix = array2sparse(matrix)
                # merge records in the same loc and drop explicit zeros
                matrix.eliminate_zeros()
                matrix.sum_duplicates()
                if "chr" not in chrom1:
                    chrom1 = "chr"+chrom1
                if "chr" not in chrom2:
                    chrom2 = "chr"+chrom2
                # one line per contact:
                # str1 chr1 pos1 frag1 str2 chr2 pos2 frag2 score
                wfile.writelines(
                    f'0 {chrom1} {int(int(r) * resolution + 1)} 0 0 '
                    f'{chrom2} {int(int(c) * resolution + 1)} 1 {v:.2f}\n'
                    for r, c, v in zip(matrix.row, matrix.col, matrix.data)
                )
        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters intact; cwd= replaces the os.chdir round trip.
        command = [
            'java', '-Xmx64g', '-jar', jar_path, 'pre', '-j', '8', '-d',
            '-r', str(resolution), raw_path, output_hic, str(refer_genome_name),
        ]
        result = subprocess.run(command, cwd=code_path)
        if result.returncode != 0:
            # best effort as before: report the failure, do not raise
            print("juicer_tools exited with code", result.returncode)
    finally:
        if os.path.exists(raw_path):
            os.remove(raw_path)
68 |
69 | """
70 | Usage
71 | ```
72 | python3 array2hic.py [input.pkl] [output.hic] [resolution] [refer_genome_name] [mode]
73 | ```
74 | The input pickle should be in a pickle file as dict: [chrom1_chrom2]:[array] format for common mode. Here array should be scipy sparse array.
75 | For intra-chromosome only, the dict format can be [chrom]:[array] in pickle files.
76 | [output.hic] is the name of the output hic file.
77 | [resolution] is used to specify the resolution that stored in the output array.
78 | [refer_genome_name] is used to specify the reference genome name. For example, "hg38","hg19","mm10" are valid inputs.
79 | [mode]: 0: all chromosome mode (scipy sparce array); 1: intra-chromosome mode(scipy sparce array); 2: all chromosome mode (numpy array); 3: intra-chromosome mode(numpy array).
80 | """
81 |
if __name__ == '__main__':
    # Command-line entry point: validate the argument count, then convert
    # using the juicer_tools.jar that sits next to this script.
    if len(sys.argv) != 6:
        usage_lines = (
            'Usage: python3 array2hic.py [input.pkl] [output.hic] [resolution] [refer_genome_name] [mode]',
            "This is the full array2hic script. ",
            "input.pkl: the path to the pickle file containing the array [String].",
            "input.pkl format: [chrom1_chrom2]:[array] format for common mode. Here array should be scipy sparce array. For intra-chromsome only, the dict format can be [chrom]:[array] in pickle files.",
            "output.hic: the name of the output hic file [String].",
            "resolution: resolution of the input array [Integer].",
            "refer_genome_name: the name of the reference genome [String]. Example: hg38, hg19, mm10.",
            "mode: 0: all chromosome mode (scipy sparce array); 1: intra-chromosome mode(scipy sparce array); 2: all chromosome mode (numpy array); 3: intra-chromosome mode(numpy array).",
        )
        for usage_line in usage_lines:
            print(usage_line)
        sys.exit(1)

    # get current script directory
    script_dir = os.path.dirname(os.path.realpath(__file__))
    juicer_tools = os.path.join(script_dir, 'juicer_tools.jar')
    pkl_arg, hic_arg, resolution_arg, genome_arg, mode_arg = sys.argv[1:]
    input_array_pickle = os.path.abspath(pkl_arg)
    output_hic = os.path.abspath(hic_arg)
    resolution = int(resolution_arg)
    refer_genome_name = str(genome_arg)
    mode = int(mode_arg)
    array2hic(juicer_tools,input_array_pickle,output_hic,resolution,refer_genome_name,mode)
104 |
105 |
106 |
--------------------------------------------------------------------------------
/utils/cool2array.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cooler
3 | from scipy.sparse import coo_matrix
4 | import pickle
def write_pkl(data, path):
    """Pickle ``data`` into the file at ``path``.

    The file is opened in binary write mode, so an existing file at
    ``path`` is overwritten.
    """
    with open(path, mode='wb') as handle:
        # Same as pickle.dump(data, handle) with the default protocol.
        pickle.Pickler(handle).dump(data)
8 | def cool2array(cooler_path,normalize=False,tondarray=False):
9 | """
10 | cooler_path: the path to the cooler file
11 | normalize: if True, the matrix will be normalized by the norm matrix saved in the cooler file
12 | tondarray: if True, the return a numpy array dict
13 | return a numpy/scipy.sparse array dict
14 | [chromosome1_chromsome2]:sparce matrix
15 | """
16 | c = cooler.Cooler(cooler_path)
17 | binsize= c.info['bin-size']
18 | chromosomes= c.chromnames
19 | chromosome_sizes = c.chromsizes
20 | bins = c.bins()[:] #including the chromosome staring ending info for each bin
21 | bins['bin_id'] = bins.index
22 | #column of bins[['chrom', 'start', 'end','weight']
23 | pixels = c.pixels()[:] #including the chromosome staring ending info for each bin
24 | #including bin1_id,bin2_id,count columns
25 | return_dict={}
26 | for k,chromsome in enumerate(chromosomes):
27 | for j,chromsome2 in enumerate(chromosomes):
28 | if j
156 | The output array is saved in a pickle file as dict: [chrom1_chrom2]:[array] format.
157 | Two modes are supported:
158 | ```
159 | 0: scipy coo_array format output;
160 | 1: numpy array format output;
161 | 2: normed scipy coo_array format output;
162 | 3: normed numpy array format output.
163 | ```
164 | """
165 |
if __name__ == '__main__':
    import os
    import sys

    # Three positional arguments are required: input cooler, output pickle, mode.
    if len(sys.argv) != 4:
        for help_line in (
            'Usage: python3 cool2array.py [input.cool] [output.pkl] [mode]',
            "This is the full cool2array script. ",
            "mode: 0 for sparse matrix, 1 for dense matrix, 2 for normed sparse matrix, 3 for normed dense matrix",
        ):
            print(help_line)
        sys.exit(1)

    cooler_path = os.path.abspath(sys.argv[1])
    output_pkl_path = os.path.abspath(sys.argv[2])

    # The output directory is created before the mode argument is parsed.
    output_dir = os.path.dirname(output_pkl_path)
    os.makedirs(output_dir, exist_ok=True)

    # mode -> (normalize, tondarray)
    mode_flags = {
        0: (False, False),
        1: (False, True),
        2: (True, False),
        3: (True, True),
    }
    mode = int(sys.argv[3])
    if mode not in mode_flags:
        print('mode should be 0,1,2,3')
        sys.exit(1)
    normalize, tondarray = mode_flags[mode]

    return_dict = cool2array(cooler_path, normalize=normalize, tondarray=tondarray)
    write_pkl(return_dict, output_pkl_path)
197 |
--------------------------------------------------------------------------------
/utils/hic2array.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import numpy as np
4 | from scipy.sparse import coo_matrix
5 | import hicstraw
6 | import os
7 | import pickle
def write_pkl(data, path):
    """Serialize ``data`` with pickle and store it at ``path``.

    ``path`` is opened for binary writing; any previous content is replaced.
    """
    with open(path, mode='wb') as out_file:
        pickle.dump(data, out_file)
def read_chrom_array(chr1, chr2, normalization, hic_file, resolution,call_resolution):
    """Read the contacts of one chromosome pair into a sparse COO matrix.

    Args:
        chr1, chr2: chromosome objects; only their ``name`` attribute is used.
        normalization: normalization name forwarded to ``hicstraw.straw``.
        hic_file: path of the .hic file to query.
        resolution: bin size of the returned matrix; every returned
            coordinate is floor-divided by this value to get a bin index.
        call_resolution: resolution passed to ``hicstraw.straw``
            (presumably one stored in the file that divides ``resolution``;
            confirm against the caller).

    Returns:
        A float32 ``coo_matrix``, or ``None`` when the queried counts sum to
        zero (which includes the case of no records at all). For a pair of
        identical chromosome names the matrix is square; otherwise its shape
        is (max row index + 1, max column index + 1).
    """
    chr1_name, chr2_name = chr1.name, chr2.name

    # Positional arguments of hicstraw.straw, in the order it expects them.
    infos = [
        'observed',
        normalization,
        hic_file,
        chr1_name,
        chr2_name,
        'BP',
        call_resolution,
    ]
    print(infos)

    rets = hicstraw.straw(*infos)
    print('\tlen(rets): {:3e}'.format(len(rets)))

    # Convert the returned coordinates to bin indices at the target resolution.
    row = [int(record.binX // resolution) for record in rets]
    col = [int(record.binY // resolution) for record in rets]
    val = [record.counts for record in rets]

    total = sum(val)
    print('\tsum(val): {:3e}'.format(total))
    if total == 0:
        return None

    n_rows = max(row) + 1
    n_cols = max(col) + 1
    if chr1_name == chr2_name:
        # Intra-chromosome maps are kept square.
        side = max(n_rows, n_cols)
        shape = (side, side)
    else:
        shape = (n_rows, n_cols)

    # The matrix is returned as read; it is not symmetrized here.
    return coo_matrix((val, (row, col)), shape=shape, dtype=np.float32)
44 |
def validate_resolution(resolution,resolution_list):
    """Return True if ``resolution`` is a multiple of any listed resolution.

    An empty ``resolution_list`` yields False.
    """
    return any(resolution % candidate == 0 for candidate in resolution_list)
50 |
def find_call_resolution(resolution, resolution_list):
    """Pick the coarsest listed resolution that evenly divides ``resolution``.

    Args:
        resolution: the target resolution.
        resolution_list: candidate resolutions. The sequence is only read,
            never reordered, so lists and tuples are both accepted.

    Returns:
        The largest entry of ``resolution_list`` that divides ``resolution``
        without remainder, or the smallest entry when none divides it.

    Raises:
        ValueError: if ``resolution_list`` is empty.
    """
    # Computed up front so an empty candidate list raises ValueError.
    fallback = min(resolution_list)
    return max(
        (reso for reso in resolution_list if resolution % reso == 0),
        default=fallback,
    )
59 |
def _is_excluded_chrom(chrom_name):
    """Return True for chromosome names that hic2array leaves out of the output."""
    lowered = chrom_name.lower()
    return 'Un' in chrom_name or "random" in lowered or "alt" in lowered


def hic2array(input_hic,output_pkl=None,
              resolution=25000,normalization="NONE",
              tondarray=0):
    """Convert a .hic file into a dict of per-chromosome-pair contact matrices.

    Args:
        input_hic: str, input hic file path.
        output_pkl: str or None, output pickle file path. When given, its
            parent directory is created if needed and the result dict is
            pickled there.
        resolution: int, bin size of the output matrices. It must be a
            multiple of at least one resolution stored in the file.
        normalization: str, normalization name forwarded to the reader
            (default "NONE").
        tondarray: int output mode.
            0: sparse matrices for all chromosome pairs;
            1: dense numpy arrays for all chromosome pairs;
            2: sparse matrices, intra-chromosome only;
            3: dense numpy arrays, intra-chromosome only.

    Returns:
        dict mapping "chrom1_chrom2" (modes 0/1) or "chrom" (modes 2/3) to
        the matrix of that pair. Pairs with no data are omitted. Chromosomes
        whose name contains "all", "Un", "random" or "alt" are skipped.

    Raises:
        SystemExit: with status 1 when ``resolution`` is not supported by
            the file; the available resolutions are printed first.
    """
    hic = hicstraw.HiCFile(input_hic)
    chrom_list = []
    for chrom in hic.getChromosomes():
        print(chrom.name, chrom.length)
        if "all" in chrom.name.lower():
            continue
        chrom_list.append(chrom)

    resolution_list = hic.getResolutions()
    if not validate_resolution(resolution,resolution_list):
        print("Resolution not found in the hic file, please choose from the following list:")
        print(resolution_list)
        print("Any other resolution is coarse than this and dividable by one of them can also be supported.")
        # Raised directly so the failure yields a non-zero exit status and
        # does not depend on the site-provided ``exit`` helper.
        raise SystemExit(1)

    # The resolution used to query the file is the same for every pair.
    call_resolution = find_call_resolution(resolution,resolution_list)
    intra_only = tondarray in [2,3]
    to_dense = tondarray in [1,3]

    # Filter once instead of re-checking the names for every pair.
    kept_chroms = [chrom for chrom in chrom_list
                   if not _is_excluded_chrom(chrom.name)]

    output_dict = {}
    for i, chrom1 in enumerate(kept_chroms):
        # Only the diagonal and upper-triangle pairs (j >= i) are read.
        partners = [chrom1] if intra_only else kept_chroms[i:]
        for chrom2 in partners:
            chrom1_name = chrom1.name
            chrom2_name = chrom2.name
            read_array = read_chrom_array(chrom1,chrom2, normalization, input_hic,
                                          resolution,call_resolution=call_resolution)
            if read_array is None:
                print("No data found for",chrom1_name,chrom2_name)
                continue
            if to_dense:
                read_array = read_array.toarray()
            if intra_only:
                output_dict[chrom1_name] = read_array
            else:
                output_dict[chrom1_name+"_"+chrom2_name] = read_array

    if output_pkl is not None:
        output_dir = os.path.dirname(os.path.realpath(output_pkl))
        os.makedirs(output_dir, exist_ok=True)
        write_pkl(output_dict,output_pkl)

    return output_dict
118 | """
119 |
120 | Usage
121 | ```
122 | python3 hic2array_simple.py [input.hic] [output.pkl] [resolution] [normalization_type] [mode]
123 | ```
124 |
125 | This is the full hic2array script, converting both intra- and inter-chromosome regions to array format.
126 | The output array is saved in a pickle file as dict: [chrom1_chrom2]:[array] format.
127 | [resolution] is used to specify the resolution that stored in the output array.
128 | [normalization_type] supports the following type:
129 | ```
130 | 0: NONE normalization applied, save the raw data to array.
131 | 1: VC normalization;
132 | 2: VC_SQRT normalization;
133 | 3: KR normalization;
134 | 4: SCALE normalization.
135 | ```
136 | Four modes are supported for different format saving:
137 | ```
138 | 0: scipy coo_array format output;
139 | 1: numpy array format output;
140 | 2: scipy coo_array format output (only include intra-chromosome region).
141 | 3: numpy array format output (only include intra-chromosome region).
142 | ```
143 |
144 | """
if __name__ == '__main__':
    import os
    import sys

    # Help texts shared by the usage message and the validation errors.
    normalization_help = "normalization type: 0: None normalization; 1: VC normalization; 2: VC_SQRT normalization; 3: KR normalization; 4: SCALE normalization"
    mode_help = "mode: 0 for sparse matrix, 1 for dense matrix, 2 for sparce matrix (only cis-contact); 3 for dense matrix (only cis-contact)."

    if len(sys.argv) != 6:
        print('Usage: python3 hic2array_simple.py [input.hic] [output.pkl] [resolution] [normalization_type] [mode]')
        print("This is the full hic2array script. ")
        print(normalization_help)
        print(mode_help)
        sys.exit(1)

    # All three numeric arguments are parsed before any validation happens.
    resolution = int(sys.argv[3])
    normalization_type = int(sys.argv[4])
    mode = int(sys.argv[5])

    # Integer code -> normalization name handed to hic2array.
    normalization_dict = {0: "NONE", 1: "VC", 2: "VC_SQRT", 3: "KR", 4: "SCALE"}
    if normalization_type not in normalization_dict:
        print('normalization type should be 0,1,2,3,4')
        print(normalization_help)
        sys.exit(1)
    normalization_type = normalization_dict[normalization_type]

    if mode not in [0, 1, 2, 3]:
        print('mode should be in choice of 0/1/2/3')
        print(mode_help)
        sys.exit(1)

    input_hic_path = os.path.abspath(sys.argv[1])
    output_pkl_path = os.path.abspath(sys.argv[2])
    output_dir = os.path.dirname(output_pkl_path)
    os.makedirs(output_dir, exist_ok=True)

    hic2array(input_hic_path, output_pkl_path, resolution, normalization_type, mode)
172 |
--------------------------------------------------------------------------------
/utils/hic_coverage.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import pickle
def calculate_coverage(input_pkl):
    """Compute the Hi-C coverage stored in a pickle file.

    The pickle is expected to hold a dict whose values expose ``shape`` and
    ``sum()`` (for example numpy arrays or scipy sparse matrices). Coverage
    is the sum of all values divided by the summed first dimension of every
    entry.

    Args:
        input_pkl: path to the pickle file.

    Returns:
        Total count divided by total number of rows.

    Raises:
        ZeroDivisionError: if the pickle contains no entries.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the file handle once loading is done.
    with open(input_pkl, 'rb') as handle:
        data = pickle.load(handle)

    count_reads = 0
    count_length = 0
    for cur_data in data.values():
        count_length += cur_data.shape[0]
        count_reads += cur_data.sum()
    return count_reads / count_length
16 | """
17 | This script calculates the coverage of the Hi-C data.
18 | ```
19 | python3 hic_coverage.py [input.pkl]
20 | ```
21 | [input.pkl]: the input pkl file containing the Hi-C data
22 |
23 | """
24 |
25 |
if __name__ == '__main__':
    # The script takes exactly one argument: the pickle file to analyse.
    if len(sys.argv) == 2:
        input_pkl = os.path.abspath(sys.argv[1])
        coverage = calculate_coverage(input_pkl)
        print("Hi-C Coverage: ", coverage)
    else:
        print("Usage: python3 hic_coverage.py [input.pkl]")
        print("[input.pkl]: the input pkl file containing the Hi-C data")
        sys.exit(1)
--------------------------------------------------------------------------------
/utils/juicer_tools.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Noble-Lab/HiCFoundation/c733ffa6a3de071ba4b3bf6afca6bc9a0c741910/utils/juicer_tools.jar
--------------------------------------------------------------------------------