├── image └── figure_pae_0.png ├── README.md ├── .gitignore └── vis_pae.py /image/figure_pae_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zuricho/AlphaFold3_Result_Visualize/HEAD/image/figure_pae_0.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AlphaFold3 Result Visualize 2 | 3 | ## Introduction 4 | 5 | A convenient tool to visualize AlphaFold 3 prediction results with PyMOL and matplotlib. 6 | 7 | What it can do: 8 | - Generate a `.pml` script to help visualize: (Work in progress) 9 | - pLDDT 10 | - Chain 11 | - Molecule type 12 | - Low-PAE contacts 13 | - Generate figures of: (Finished) 14 | - PAE 15 | 16 | ## Usage 17 | 18 | ```bash 19 | python vis_pae.py <alphafold_prediction_name> 20 | ``` 21 | 22 | ## Example 23 | 24 | ![image](image/figure_pae_0.png) 25 | 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | .ipynb_checkpoints 4 | .coverage 5 | *.egg-info 6 | 7 | *.csv 8 | *.tsv 9 | *.pk 10 | *.pt 11 | *.fasta 12 | *.pickle 13 | *.pyc 14 | *.mp4 15 | *.ipynb 16 | 17 | *.constraints 18 | *.movemap 19 | *.resfile 20 | test_cmd.sh 21 | 22 | __mmtf__ 23 | __pycache__/ 24 | public 25 | htmlcov 26 | make.bat 27 | examples 28 | chroma/layers/structure/params/centering_2g3n.params 29 | wandb 30 | config.json 31 | 32 | # IDEs 33 | .vscode 34 | 35 | 36 | # slurm 37 | *.slurm 38 | task_file/ 39 | 40 | 41 | 42 | # alphafold predictions 43 | *.json 44 | *.pdb 45 | *.cif 46 | */terms_of_use.md 47 | fold* 48 | 49 | -------------------------------------------------------------------------------- /vis_pae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import json 4 |
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable

mpl.rcParams["font.size"] = 12
mpl.rcParams["font.family"] = "Arial"

# Number of models (indices 0..N-1) plotted per prediction folder by default.
DEFAULT_NUM_MODELS = 5
# Residue numbers that are multiples of this value get an x tick (plus residue 1).
RESIDUE_TICK_INTERVAL = 200


# helper functions

def hide_axes_frame(ax):
    """
    Hide the frame and the ticks of the given matplotlib axis.

    Parameters:
        ax (matplotlib.axes.Axes): The axis whose spines are made invisible and
            whose x/y ticks are removed.
    """
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])


def process_token_chains(token_chain_ids):
    """
    Map each token's chain ID to an integer and locate each chain's token span.

    Parameters:
        token_chain_ids (list): Chain ID of every token, in token order.

    Returns:
        dict, dict, dict, np.ndarray: Four values --
            chain_to_num: chain ID -> integer index (IDs in sorted order, as
                produced by np.unique);
            chain_to_start_index: chain ID -> position of the chain's first token;
            chain_to_end_index: chain ID -> position of the chain's last token;
            token_chain_nums: array of shape (1, n_tokens) holding each token's
                chain index (row shape suits imshow as a colour bar).
    """
    token_chain_ids = np.asarray(token_chain_ids)

    # np.unique sorts the IDs; the inverse maps every token to its position in
    # that sorted list, which is exactly its chain number.
    unique_chains, inverse = np.unique(token_chain_ids, return_inverse=True)
    chain_to_num = {chain: i for i, chain in enumerate(unique_chains)}
    token_chain_nums = inverse.reshape(1, -1)

    chain_to_start_index = {}
    chain_to_end_index = {}
    for i, chain in enumerate(unique_chains):
        # Only the first and last occurrence are kept for each chain.
        positions = np.flatnonzero(inverse == i)
        chain_to_start_index[chain] = positions[0]
        chain_to_end_index[chain] = positions[-1]

    return chain_to_num, chain_to_start_index, chain_to_end_index, token_chain_nums


def _load_json(path):
    """Read and parse one JSON file, closing the handle afterwards."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


def _residue_ticks(token_res_ids, interval=RESIDUE_TICK_INTERVAL):
    """
    Pick x-tick positions/labels for tokens whose residue number is 1 or a
    multiple of ``interval``.

    Returns:
        (list, list): token positions and the residue numbers to show there.
    """
    positions = []
    labels = []
    for position, res_id in enumerate(token_res_ids):
        # Compare the element, not the whole list (the old check
        # `token_res_ids == 1` was always False).
        if res_id == 1 or res_id % interval == 0:
            positions.append(position)
            labels.append(res_id)
    return positions, labels


def _plot_pae(full_data, output_path, show_residue_ticks=False):
    """
    Draw one PAE heatmap with chain bars on top/left and a colorbar, then save it.

    Parameters:
        full_data (dict): Parsed ``*_full_data_<i>.json``; the keys read are
            "pae", "token_chain_ids" and (only if show_residue_ticks) "token_res_ids".
        output_path (str): Destination PNG path.
        show_residue_ticks (bool): Label the x axis with residue numbers instead
            of matplotlib's default token-index ticks.
    """
    _, chain_to_start_index, chain_to_end_index, token_chain_nums = process_token_chains(
        full_data["token_chain_ids"]
    )

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        pae_array = np.array(full_data["pae"])
        image = ax.imshow(pae_array, cmap="Greens_r", vmin=0, vmax=30)

        if show_residue_ticks:
            ax.set_xticks(*_residue_ticks(full_data["token_res_ids"]))
        ax.set_yticks([])
        # Dashed black frame around the heatmap.
        for spine in ax.spines.values():
            spine.set_linestyle("--")
            spine.set_linewidth(1)
            spine.set_color("k")

        # Side axes sized to match the heatmap.
        divider = make_axes_locatable(ax)
        ax_colorbar = divider.append_axes("right", size="5%", pad=0.2)
        ax_topbar = divider.append_axes("top", size="8%", pad=0.03)
        ax_leftbar = divider.append_axes("left", size="8%", pad=0.03)

        # Chain colour bars (one colour per chain index).
        ax_topbar.imshow(token_chain_nums, cmap="tab10", aspect="auto", alpha=0.7)
        hide_axes_frame(ax_topbar)
        ax_leftbar.imshow(token_chain_nums.T, cmap="tab10", aspect="auto", alpha=0.7)
        hide_axes_frame(ax_leftbar)

        fig.colorbar(image, cax=ax_colorbar, label="PAE (Å)")

        # Separator lines at every chain start except the first; the -0.5 puts
        # the line on the pixel edge rather than the pixel centre.
        for start_index in chain_to_start_index.values():
            if start_index == 0:
                continue
            boundary = start_index - 0.5
            ax.axhline(boundary, color="k", linewidth=1, linestyle="--")
            ax.axvline(boundary, color="k", linewidth=1, linestyle="--")
            ax_topbar.axvline(boundary, color="w", linewidth=1, linestyle="-")
            ax_leftbar.axhline(boundary, color="w", linewidth=1, linestyle="-")

        # Chain ID label at the centre of each chain's span on both bars.
        for chain, start_index in chain_to_start_index.items():
            center_index = (start_index + chain_to_end_index[chain]) / 2
            ax_topbar.text(center_index, 0, chain, color='#222222', ha='center', va='center')
            ax_leftbar.text(0, center_index, chain, color='#222222', ha='center', va='center')

        fig.savefig(output_path, dpi=300, bbox_inches="tight", transparent=False)
    finally:
        # Release the figure; otherwise one stays open per model.
        plt.close(fig)


def main(alphafold_prediction_name, num_models=DEFAULT_NUM_MODELS, show_residue_ticks=False):
    """
    Save a PAE figure for every model of one AlphaFold 3 prediction folder.

    Reads ``<folder>/<prefix>_full_data_<i>.json`` for i in range(num_models),
    where <prefix> is the folder's base name, and writes
    ``<folder>/figure_pae_<i>.png``. Models are processed one at a time, so an
    error on model i leaves the figures for earlier models in place.

    Parameters:
        alphafold_prediction_name (str): Path to the prediction folder.
        num_models (int): How many models to plot (default 5).
        show_residue_ticks (bool): Use residue-number x ticks (default off).

    Raises:
        OSError: if a full_data JSON file cannot be opened.
        json.JSONDecodeError: if a file is not valid JSON.
    """
    # Use the folder's base name as the file prefix so that paths such as
    # "results/fold_x" or "fold_x/" resolve to the files inside the folder.
    prefix = os.path.basename(os.path.normpath(alphafold_prediction_name))

    for model_num in range(num_models):
        json_path = os.path.join(alphafold_prediction_name, f"{prefix}_full_data_{model_num}.json")
        # Full-data keys include: atom_chain_ids, atom_plddts, contact_probs,
        # pae, token_chain_ids, token_res_ids.
        full_data = _load_json(json_path)
        output_path = os.path.join(alphafold_prediction_name, f"figure_pae_{model_num}.png")
        _plot_pae(full_data, output_path, show_residue_ticks=show_residue_ticks)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Usage: python vis_pae.py <alphafold_prediction_name>", file=sys.stderr)
        sys.exit(1)
    alphafold_prediction_name = sys.argv[1]
    main(alphafold_prediction_name)