├── .flake8 ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── README.md ├── cryosim ├── README.md ├── add_ctf.py ├── add_noise.py ├── pdb2mrc.py ├── project3d.py └── subsample_ctf.py ├── metrics ├── fsc │ ├── README.md │ ├── cdrgn_per_img_fsc.py │ ├── old │ │ └── per_conf │ │ │ ├── cdrgn.py │ │ │ ├── cryosparc_3dcls.py │ │ │ ├── cryosparc_3dflex.py │ │ │ ├── cryosparc_3dva.py │ │ │ ├── cryosparc_abinitio.py │ │ │ ├── drgnai.py │ │ │ ├── opusdsd.py │ │ │ ├── per_conf_calc.py │ │ │ └── re_covar.py │ ├── plot_fsc.py │ └── utils │ │ ├── conformations.py │ │ ├── interface.py │ │ └── volumes.py ├── information_imbalance │ ├── compute_information_imbalance.py │ ├── figures_for_paper.ipynb │ └── submission.sh ├── methods │ └── recovar_scripts │ │ ├── Per_img_gen.ipynb │ │ ├── Reorder_to_original_idx.ipynb │ │ ├── angle_refinement_test.ipynb │ │ ├── gen_avg_latent_vol.py │ │ ├── gen_med_latent_vol.py │ │ ├── gen_vol_for_per_conf_fsc.py │ │ ├── gen_vol_for_per_conf_fsc_ribosembly.py │ │ ├── make_dataset.py │ │ ├── per_image_generation.py │ │ ├── per_img_fsc.py │ │ └── tests_linear_interp.ipynb ├── neighborhood_similarity │ ├── README.md │ ├── cal_neighb_hit_werror.py │ ├── conf-het-1_wrangled_latents.npz │ └── make_plot.py ├── pose.py ├── pose_error │ ├── rot_error.py │ └── trans_error.py ├── utils │ ├── align.py │ ├── align.slurm │ ├── align_multi.py │ └── calculate_GT_metrics.py └── visualization │ ├── .ipynb_checkpoints │ ├── 20231222_visualization_new_ver-checkpoint.ipynb │ ├── 20240118_umap_visualization-checkpoint.ipynb │ ├── 20240409_conf_v1_umap_visualization-checkpoint.ipynb │ ├── conf-het_latent_visualize-checkpoint.ipynb │ ├── dynamight-checkpoint.ipynb │ └── mixhet_latent_visualize-checkpoint.ipynb │ ├── README.md │ ├── calculate_IgG-RL_gt_latents.py │ ├── conf-het-2_CV_dihedral_distance.npy │ ├── visualize_umap_IgG-1D.py │ ├── visualize_umap_IgG-RL.py │ ├── visualize_umap_Ribosembly.py │ └── visualize_umap_Tomotwin-100.py └── requirements.txt /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | extend-ignore = E203,E402,E501,F821 3 | max-complexity = 99 4 | max-line-length = 88 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .idea/ 3 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "metrics/methods/recovar"] 2 | path = metrics/methods/recovar 3 | url = git@github.com:ma-gilles/recovar.git 4 | [submodule "metrics/methods/opusDSD"] 5 | path = metrics/methods/opusDSD 6 | url = git@github.com:alncat/opusDSD.git 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | 4 | exclude: '.cs$|.star$' 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v3.2.0 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: end-of-file-fixer 12 | 13 | - repo: https://github.com/pycqa/flake8 14 | rev: '4.0.1' 15 | hooks: 16 | - id: flake8 17 | 18 | - repo: https://github.com/psf/black 19 | rev: 22.10.0 20 | hooks: 21 | - id: black 22 | language_version: python3 23 | 24 | - 
repo: https://github.com/MarcoGorelli/absolufy-imports
25 |     rev: v0.3.1
26 |     hooks:
27 |     -   id: absolufy-imports
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CryoBench: Diverse and challenging datasets for the heterogeneity problem in cryo-EM
2 |
3 | ## Documentation
4 |
5 | The latest documentation for CryoBench is available at our [homepage](https://cryobench.cs.princeton.edu/) and also at
6 | our [manual](https://ez-lab.gitbook.io/cryobench).
7 |
8 | For any feedback, questions, or bugs, please file a GitHub issue, start a GitHub discussion, or email us.
9 |
10 | ## Installation
11 | To run the metrics, you first need to install `cryodrgn`.
12 | `cryodrgn` may be installed via `pip`, and we recommend installing it in a clean conda environment.
13 |
14 |     # Create and activate conda environment
15 |     (base) $ conda create --name cryodrgn python=3.9
16 |     (base) $ conda activate cryodrgn
17 |
18 |     # install cryodrgn
19 |     (cryodrgn) $ pip install cryodrgn
20 |
21 | More installation instructions can be found in the [documentation](https://ez-lab.gitbook.io/cryodrgn/installation).
22 |
23 | Datasets are available for download at Zenodo.
24 |
25 | 1. Conf-het (IgG-1D, IgG-RL): [https://zenodo.org/records/11629428](https://zenodo.org/records/11629428).
26 | 2. Comp-het (Ribosembly, Tomotwin-100): [https://zenodo.org/records/12528292](https://zenodo.org/records/12528292).
27 | 3. Spike-MD: [https://zenodo.org/records/14941494](https://zenodo.org/records/14941494).
28 |
29 | ## Image Formation
30 | See the [cryosim](https://github.com/ml-struct-bio/CryoBench/tree/main/cryosim) directory.
31 |
32 | ## Metrics
33 |
34 | ### 1. Per-image FSCs
35 | See the [metrics/fsc](https://github.com/ml-struct-bio/CryoBench/tree/main/metrics/fsc) directory.
36 |
37 | ### 2. UMAP visualization
38 | See the [metrics/visualization](https://github.com/ml-struct-bio/CryoBench/tree/main/metrics/visualization) directory.
39 |
40 | ## Contact
41 |
42 | Please submit any bug reports, feature requests, or general usage feedback as a GitHub issue or discussion.
43 |
44 | ## Reference:
45 |
46 | Jeon, Minkyu, et al. "CryoBench: Diverse and challenging datasets for the heterogeneity problem in cryo-EM." NeurIPS 2024 Spotlight. [paper](https://arxiv.org/abs/2408.05526).
47 |
48 | ```
49 | @inproceedings{jeon2024cryobench,
50 |   author = {Jeon, Minkyu and Raghu, Rishwanth and Astore, Miro and Woollard, Geoffrey and Feathers, Ryan and Kaz, Alkin and Hanson, Sonya M. and Cossio, Pilar and Zhong, Ellen D.},
51 |   booktitle = {Advances in Neural Information Processing Systems},
52 |   title = {CryoBench: Diverse and challenging datasets for the heterogeneity problem in cryo-EM},
53 |   year = {2024}
54 | }
55 | ```
56 |
--------------------------------------------------------------------------------
/cryosim/README.md:
--------------------------------------------------------------------------------
1 | # cryosim: Tools for generating synthetic cryo-EM images
2 | This repository is built upon https://github.com/ml-struct-bio/cryosim/tree/main
3 | ### Dependencies:
4 | * cryodrgn version 3.4.0
5 | * ChimeraX, if generating volumes for the Spike-MD dataset
6 |
7 | ### Generating volumes
8 | For converting a large number of atomic structures, saved as a trajectory, to cryo-EM volumes, the _pdb2mrc.py_ script can be used.
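Under the hood, `pdb2mrc.py` drives ChimeraX in batch mode: it writes the commands to execute into a ChimeraX command (`.cxc`) script and then runs ChimeraX headlessly on that script (see `pdb2mrc.py` below for the full logic, including padding volumes to the requested box size and resampling onto a common grid). The following is a minimal sketch of this pattern; the PDB filename, output names, and ChimeraX location are illustrative placeholders:

```python
import os
import subprocess

chimerax = "~/chimerax-1.6.1/bin/ChimeraX"  # adjust to your installation

# ChimeraX commands: open the structure (model #1), simulate density from it at
# 3 A resolution on a 1.5 A grid (creating model #2), save the map, and quit.
commands = [
    "open seed_structure.pdb",
    "molmap #1 3.0 gridSpacing 1.5",
    "save vol_00000.mrc #2",
    "exit",
]
with open("commands.cxc", "w") as f:
    f.write("\n".join(commands))

# Run ChimeraX without a GUI and have it execute the generated script
subprocess.run(
    [os.path.expanduser(chimerax), "--nogui", "--cmd", "open commands.cxc"],
    check=True,
)
```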
For the Spike-MD dataset, the command is
9 | ```
10 | python pdb2mrc.py seed_structure.pdb sampled_pdbs.xtc 100000 --Apix 1.5 -D 256 --res 3 -c ~/chimerax-1.6.1/bin/ChimeraX -o volumes
11 | ```
12 |
13 | ### Generating CTF parameters
14 | ```
15 | # Subsample 100k CTF parameters from an experimental dataset without replacement
16 | $ python subsample_ctf.py experimental_ctf.pkl -o ctf.pkl -N 100000 --Apix 1.5 -D 256 --seed 0
17 |
18 | # Or create 100 separate files for each GT conformation and combine
19 | $ for i in {0..99}; do python subsample_ctf.py experimental_ctf.pkl -N 1000 -D 256 --Apix 1.5 --seed $i -o ctf.${i}.pkl; done
20 | $ cryodrgn_utils concat_pkls $(for i in {0..99}; do echo ctf.${i}.pkl; done) -o ctf.combined.pkl
21 | ```
22 |
23 | ### Generate projection images of a volume
24 | ```
25 | # Generate 1k projection images from a volume
26 | $ python project3d.py input.mrc -N 1000 -o output_projections.mrcs --out-pose poses.pkl --t-extent 20
27 |
28 | # Or generate 1k projection images for 100 volumes (Total 100k images)
29 | $ for i in {0..99}; do python project3d.py input.${i}.mrc -N 1000 -o output_projections.${i}.mrcs --out-pose poses.${i}.pkl --t-extent 20; done
30 | $ for i in {0..99}; do echo output_projections.${i}.mrcs >> projection_combined.txt; done
31 |
32 | # Integrate all poses to make one pkl file
33 | $ cryodrgn_utils concat_pkls $(for i in {0..99}; do echo poses.${i}.pkl; done) -o poses.combined.pkl
34 | ```
35 |
36 | ### Add CTF
37 | ```
38 | # Add CTF to one mrcs
39 | $ python add_ctf.py input_particles.mrcs --Apix 1.5 --ctf-pkl ctf.pkl --s1 0 --s2 0 -o output_particles_w_ctf.mrcs
40 |
41 | # Or generate 100 CTF added mrcs files
42 | $ for i in {0..99}; do python add_ctf.py input_particles.${i}.mrcs --Apix 1.5 --ctf-pkl ctf.${i}.pkl --s1 0 --s2 0 -o output_particles_w_ctf.${i}.mrcs; done
43 | $ for i in {0..99}; do echo output_particles_w_ctf.${i}.mrcs >> CTF_added_combined.txt; done
44 | ```
45 |
46 | ### Add Gaussian noise at an SNR of 0.01
47 | ```
48 | # Add Gaussian noise at an SNR of 0.01
49 | $ python add_noise.py input_noiseless_particles.mrcs -o output_noisy_particles.mrcs --snr 0.01
50 |
51 | # Or generate 100 noise added mrcs files
52 | $ for i in {0..99}; do python add_noise.py input_noiseless_particles.${i}.mrcs -o output_noisy_particles.${i}.mrcs --snr 0.01; done
53 | $ for i in {0..99}; do echo output_noisy_particles.${i}.mrcs >> noise_added_combined.txt; done
54 | ```
55 |
56 | ### Downsample particles
57 | ```
58 | # Downsample from D = 256 to D = 128
59 | $ for i in {0..99}
60 | do
61 |     formatted_i=$(printf "%03d" "$i")
62 |     cryodrgn downsample output_noisy_particles.${formatted_i}.mrcs -D 128 -o output_noisy_particles.${formatted_i}.128.mrcs
63 | done
64 | ```
65 |
--------------------------------------------------------------------------------
/cryosim/add_noise.py:
--------------------------------------------------------------------------------
1 | """Add noise to a particle stack at a desired SNR"""
2 |
3 | import argparse
4 | import numpy as np
5 | import sys, os
6 | import matplotlib.pyplot as plt
7 |
8 | from cryodrgn import utils
9 | from cryodrgn import mrcfile
10 | from cryodrgn.lattice import EvenLattice
11 |
12 | log = print
13 |
14 |
15 | def parse_args():
16 |     parser = argparse.ArgumentParser(description=__doc__)
17 |     parser.add_argument("mrcs", help="Input particles (.mrcs, .star, or .txt)")
18 |     parser.add_argument("--snr", type=float)
19 |     parser.add_argument("--sigma", type=float)
20 |     parser.add_argument(
21 |         "--invert", action="store_true", help="invert data (mult by -1)"
22 |     )
23 |
parser.add_argument( 24 | "--mask", 25 | choices=("none", "strict", "circular"), 26 | help="Type of mask for computing signal variance", 27 | ) 28 | parser.add_argument("--mask-r", type=int, help="Radius for circular mask") 29 | parser.add_argument( 30 | "--datadir", 31 | help="Optionally overwrite path to starfile .mrcs if loading from a starfile", 32 | ) 33 | parser.add_argument( 34 | "-o", type=os.path.abspath, required=True, help="Output particle stack" 35 | ) 36 | parser.add_argument("--out-png") 37 | return parser 38 | 39 | 40 | def plot_projections(out_png, imgs): 41 | fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(10, 10)) 42 | axes = axes.ravel() 43 | for i in range(min(len(imgs), 9)): 44 | axes[i].imshow(imgs[i]) 45 | plt.savefig(out_png) 46 | 47 | 48 | def mkbasedir(out): 49 | if not os.path.exists(os.path.dirname(out)): 50 | os.makedirs(os.path.dirname(out)) 51 | 52 | 53 | def warnexists(out): 54 | if os.path.exists(out): 55 | log("Warning: {} already exists. Overwriting.".format(out)) 56 | 57 | 58 | def main(args): 59 | assert (args.snr is None) != (args.sigma is None) # xor 60 | 61 | mkbasedir(args.o) 62 | warnexists(args.o) 63 | 64 | # load particles 65 | particles = mrcfile.parse_mrc(args.mrcs)[0] 66 | log(particles.shape) 67 | Nimg, D, D = particles.shape 68 | 69 | # compute noise variance 70 | if args.sigma: 71 | sigma = args.sigma 72 | else: 73 | Nstd = min(10000, Nimg) 74 | if args.mask == "strict": 75 | mask = np.where(particles[:Nstd] > 0) 76 | std = np.std(particles[mask]) 77 | elif args.mask == "circular": 78 | lattice = EvenLattice(D) 79 | mask = lattice.get_circular_mask(args.mask_r) 80 | mask = np.where(mask)[0] # convert from torch uint mask to array index 81 | std = np.std(particles[:Nstd].reshape(Nstd, -1)[:, mask]) 82 | else: 83 | std = np.std(particles[:Nstd]) 84 | sigma = std / np.sqrt(args.snr) 85 | 86 | # add noise 87 | log("Adding noise with std {}".format(sigma)) 88 | particles += np.random.normal(0, sigma, particles.shape) 89 | if args.invert: 90 | print("invert data!") 91 | particles *= -1 92 | # save particles 93 | mrcfile.write_mrc(args.o, particles.astype(np.float32)) 94 | 95 | if args.out_png: 96 | plot_projections(args.out_png, particles[:9]) 97 | 98 | log("Done") 99 | 100 | 101 | if __name__ == "__main__": 102 | main(parse_args().parse_args()) 103 | -------------------------------------------------------------------------------- /cryosim/pdb2mrc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import subprocess 4 | import numpy as np 5 | from cryodrgn import mrcfile 6 | 7 | CHUNK = 10000 # restart ChimeraX session periodically to avoid OOM errors 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='Generate mrc volumes from atomic model trajectory') 11 | parser.add_argument('pdb', help='Path to seed PDB file') 12 | parser.add_argument('traj', help='Path to trajectory file') 13 | parser.add_argument('num_models', type=int, help='Number of structures in the trajectory to generate volumes for') 14 | parser.add_argument('--Apix', type=float, default=1.5, help='Pixel size of volumes') 15 | parser.add_argument('-D', type=int, default=256, help='Box size of volumes') 16 | parser.add_argument('--res', type=float, default=3.0, help='Resolution to simulate density') 17 | parser.add_argument('-c', required=True, help='Path to ChimeraX binary, e.g. 
~/chimerax-1.6.1/bin/ChimeraX') 18 | parser.add_argument('-o', required=True, help='Path to directory where volumes will be stored') 19 | return parser.parse_args() 20 | 21 | 22 | class CXCFile: 23 | def __init__(self): 24 | self.commands = [] 25 | 26 | def add(self, command: str): 27 | self.commands.append(command) 28 | 29 | def save(self, file_path: os.path.abspath): 30 | with open(file_path, "w") as file: 31 | file.writelines('\n'.join(self.commands)) 32 | 33 | def execute(self, chimerax_path: os.path.abspath, cxc_path: os.path.abspath): 34 | self.save(cxc_path) 35 | chimerax_command = [chimerax_path, "--nogui", "--cmd", f"open {cxc_path}"] 36 | try: 37 | subprocess.run(chimerax_command, check=True) 38 | except subprocess.CalledProcessError as e: 39 | print(f"Error: {e}") 40 | os.remove(cxc_path) 41 | 42 | 43 | def pad_vol(path, Apix, D): 44 | data, header = mrcfile.parse_mrc(path) 45 | x,y,z = data.shape 46 | new_data = np.zeros((D,D,D), dtype=np.float32) 47 | i, j, k = (D-x)//2, (D-y)//2, (D-z)//2 48 | new_data[i:(i+x),j:(j+y),k:(k+z)] = data 49 | orig_x, orig_y, orig_z = header.origin 50 | new_header = mrcfile.get_mrc_header( 51 | new_data, True, 52 | Apix=Apix, 53 | xorg=(orig_x-k*Apix), 54 | yorg=(orig_y-j*Apix), 55 | zorg=(orig_z-i*Apix) 56 | ) 57 | mrcfile.write_mrc(path, new_data, new_header) 58 | 59 | 60 | def center_all_vols(num_models, outdir): 61 | for i in range(num_models): 62 | path = os.path.join(outdir, f'vol_{i:05d}.mrc') 63 | data, header = mrcfile.parse_mrc(path) 64 | header.origin = (0., 0., 0.) 65 | mrcfile.write_mrc(path, data, header) 66 | 67 | 68 | def generate_ref_vol(pdb_path, outdir, chimerax_path, res, Apix, D): 69 | cxc = CXCFile() 70 | cxc.add(f"open {os.path.abspath(pdb_path)}") 71 | cxc.add(f"molmap #1 {res} gridSpacing {Apix}") 72 | cxc.add(f"save {os.path.abspath(os.path.join(outdir, 'ref.mrc'))} #2") 73 | cxc.add("exit") 74 | cxc.execute(chimerax_path, os.path.abspath(os.path.join(outdir, 'commands.cxc'))) 75 | pad_vol(os.path.abspath(os.path.join(outdir, 'ref.mrc')), Apix, D) 76 | 77 | 78 | def generate_all_vols(pdb_path, traj_path, num_models, outdir, chimerax_path, res, Apix): 79 | for start in range(0, num_models, CHUNK): 80 | cxc = CXCFile() 81 | cxc.add(f"open {os.path.abspath(pdb_path)}") 82 | cxc.add(f"open {os.path.abspath(traj_path)}") 83 | cxc.add(f"open {os.path.abspath(os.path.join(outdir, 'ref.mrc'))}") 84 | for i in range(start, min(start+CHUNK, num_models)): 85 | cxc.add(f"coordset #1 {i+1}") 86 | cxc.add(f"molmap #1 {res} gridSpacing {Apix}") 87 | cxc.add("vol resample #3 onGrid #2") 88 | cxc.add(f"save {os.path.abspath(os.path.join(outdir, f'vol_{i:05d}.mrc'))} #4") 89 | cxc.add("close #3-4") 90 | cxc.add("exit") 91 | cxc.execute(chimerax_path, os.path.abspath(os.path.join(outdir, 'commands.cxc'))) 92 | center_all_vols(num_models, outdir) 93 | os.remove(os.path.join(outdir, 'ref.mrc')) 94 | 95 | 96 | if __name__=="__main__": 97 | args = parse_args() 98 | os.makedirs(args.o) 99 | generate_ref_vol(args.pdb, args.o, args.c, args.res, args.Apix, args.D) 100 | generate_all_vols(args.pdb, args.traj, args.num_models, args.o, args.c, args.res, args.Apix) 101 | -------------------------------------------------------------------------------- /cryosim/project3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate projections of a 3D volume 3 | """ 4 | 5 | import argparse 6 | import numpy as np 7 | import sys, os 8 | import time 9 | import pickle 10 | from scipy.ndimage.fourier import 
fourier_shift
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import torch.utils.data as data
16 |
17 | from cryodrgn import utils
18 | from cryodrgn import mrcfile
19 | from cryodrgn import lie_tools
20 | from cryodrgn import so3_grid
21 |
22 | import matplotlib
23 |
24 | matplotlib.use("Agg")
25 | import matplotlib.pyplot as plt
26 |
27 | log = print
28 |
29 |
30 | def parse_args():
31 |     parser = argparse.ArgumentParser(description=__doc__)
32 |     parser.add_argument("mrc", help="Input volume")
33 |     parser.add_argument(
34 |         "-o",
35 |         type=os.path.abspath,
36 |         required=True,
37 |         help="Output projection stack (.mrcs)",
38 |     )
39 |     parser.add_argument(
40 |         "--out-pose", type=os.path.abspath, required=True, help="Output poses (.pkl)"
41 |     )
42 |     parser.add_argument(
43 |         "--out-png", type=os.path.abspath, help="Montage of first 9 projections"
44 |     )
45 |     parser.add_argument(
46 |         "--in-pose",
47 |         type=os.path.abspath,
48 |         help="Optionally provide input poses instead of random poses (.pkl)",
49 |     )
50 |     parser.add_argument("-N", type=int, help="Number of random projections")
51 |     parser.add_argument(
52 |         "-b", type=int, default=100, help="Minibatch size (default: %(default)s)"
53 |     )
54 |     parser.add_argument(
55 |         "--t-extent",
56 |         type=float,
57 |         default=5,
58 |         help="Extent of image translation in pixels (default: +/-%(default)s)",
59 |     )
60 |     parser.add_argument(
61 |         "--grid",
62 |         type=int,
63 |         help="Generate projections on a uniform deterministic grid on SO3. Specify resolution level",
64 |     )
65 |     parser.add_argument(
66 |         "--tilt", type=float, help="Right-handed x-axis tilt offset in degrees"
67 |     )
68 |     parser.add_argument("--seed", type=int, help="Random seed")
69 |     parser.add_argument(
70 |         "-v", "--verbose", action="store_true", help="Increase verbosity"
71 |     )
72 |     return parser
73 |
74 |
75 | class Projector:
76 |     def __init__(self, vol, tilt=None):
77 |         nz, ny, nx = vol.shape
78 |         assert nz == ny == nx, "Volume must be cubic"
79 |         x2, x1, x0 = np.meshgrid(
80 |             np.linspace(-1, 1, nz, endpoint=True),
81 |             np.linspace(-1, 1, ny, endpoint=True),
82 |             np.linspace(-1, 1, nx, endpoint=True),
83 |             indexing="ij",
84 |         )
85 |
86 |         lattice = np.stack([x0.ravel(), x1.ravel(), x2.ravel()], 1).astype(np.float32)
87 |         self.lattice = torch.from_numpy(lattice)
88 |
89 |         self.vol = torch.from_numpy(vol.astype(np.float32))
90 |         self.vol = self.vol.unsqueeze(0)
91 |         self.vol = self.vol.unsqueeze(0)
92 |
93 |         self.nz = nz
94 |         self.ny = ny
95 |         self.nx = nx
96 |
97 |         # FT is not symmetric around origin
98 |         D = nz
99 |         c = 2 / (D - 1) * (D / 2) - 1
100 |         self.center = torch.tensor([c, c, c])  # pixel coordinate for vol[D/2,D/2,D/2]
101 |
102 |         if tilt is not None:
103 |             assert tilt.shape == (3, 3)
104 |             tilt = torch.tensor(tilt)
105 |         self.tilt = tilt
106 |
107 |     def rotate(self, rot):
108 |         B = rot.size(0)
109 |         if self.tilt is not None:
110 |             rot = self.tilt @ rot
111 |         grid = self.lattice @ rot  # B x D^3 x 3
112 |         grid = grid.view(-1, self.nz, self.ny, self.nx, 3)
113 |         offset = (
114 |             self.center - grid[:, int(self.nz / 2), int(self.ny / 2), int(self.nx / 2)]
115 |         )
116 |         grid += offset[:, None, None, None, :]
117 |         grid = grid.view(1, -1, self.ny, self.nx, 3)
118 |         vol = F.grid_sample(self.vol, grid)
119 |         vol = vol.view(B, self.nz, self.ny, self.nx)
120 |         return vol
121 |
122 |     def project(self, rot):
123 |         return self.rotate(rot).sum(dim=1)
124 |
125 |
126 | class Poses(data.Dataset):
127 |     def __init__(self, pose_pkl):
128 |         poses = utils.load_pkl(pose_pkl)
129 |         self.rots = torch.tensor(poses[0]).float()
130 |         self.trans = poses[1]
131 |         self.N = len(poses[0])
132 |         assert self.rots.shape == (self.N, 3, 3)
133 |         assert self.trans.shape == (self.N, 2)
134 |         assert self.trans.max() < 1
135 |
136 |     def __len__(self):
137 |         return self.N
138 |
139 |     def __getitem__(self, index):
140 |         return self.rots[index]
141 |
142 |
143 | class RandomRot(data.Dataset):
144 |     def __init__(self, N):
145 |         self.N = N
146 |         self.rots = lie_tools.random_SO3(N)
147 |
148 |     def __len__(self):
149 |         return self.N
150 |
151 |     def __getitem__(self, index):
152 |         return self.rots[index]
153 |
154 |
155 | class GridRot(data.Dataset):
156 |     def __init__(self, resol):
157 |         quats = so3_grid.grid_SO3(resol)
158 |         self.rots = lie_tools.quaternions_to_SO3(torch.tensor(quats))
159 |         self.N = len(self.rots)
160 |
161 |     def __len__(self):
162 |         return self.N
163 |
164 |     def __getitem__(self, index):
165 |         return self.rots[index]
166 |
167 |
168 | def plot_projections(out_png, imgs):
169 |     fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(10, 10))
170 |     axes = axes.ravel()
171 |     for i in range(min(len(imgs), 9)):
172 |         axes[i].imshow(imgs[i])
173 |     plt.savefig(out_png)
174 |
175 |
176 | def mkbasedir(out):
177 |     if not os.path.exists(os.path.dirname(out)):
178 |         os.makedirs(os.path.dirname(out))
179 |
180 |
181 | def warnexists(out):
182 |     if os.path.exists(out):
183 |         log("Warning: {} already exists. Overwriting.".format(out))
184 |
185 |
186 | def translate_img(img, t):
187 |     """
188 |     img: BxYxX real space image
189 |     t: Bx2 shift in pixels
190 |     """
191 |     ff = np.fft.fft2(np.fft.fftshift(img))
192 |     ff = fourier_shift(ff, t)
193 |     return np.fft.fftshift(np.fft.ifft2(ff)).real
194 |
195 |
196 | def main(args):
197 |     for out in (args.o, args.out_png, args.out_pose):
198 |         if not out:
199 |             continue
200 |         mkbasedir(out)
201 |         warnexists(out)
202 |
203 |     if args.in_pose is None and args.t_extent == 0.0:
204 |         log("Not shifting images")
205 |     elif args.in_pose is None:
206 |         assert args.t_extent > 0
207 |
208 |     if args.seed is not None:
209 |         np.random.seed(args.seed)
210 |         torch.manual_seed(args.seed)
211 |
212 |     use_cuda = torch.cuda.is_available()
213 |     log("Use cuda {}".format(use_cuda))
214 |     if use_cuda:
215 |         torch.set_default_tensor_type(torch.cuda.FloatTensor)
216 |
217 |     t1 = time.time()
218 |     vol, _ = mrcfile.parse_mrc(args.mrc)
219 |     log("Loaded {} volume".format(vol.shape))
220 |
221 |     if args.tilt:
222 |         theta = args.tilt * np.pi / 180
223 |         args.tilt = np.array(
224 |             [
225 |                 [1.0, 0.0, 0.0],
226 |                 [0, np.cos(theta), -np.sin(theta)],
227 |                 [0, np.sin(theta), np.cos(theta)],
228 |             ]
229 |         ).astype(np.float32)
230 |
231 |     projector = Projector(vol, args.tilt)
232 |     if use_cuda:
233 |         projector.lattice = projector.lattice.cuda()
234 |         projector.vol = projector.vol.cuda()
235 |
236 |     if args.grid is not None:
237 |         rots = GridRot(args.grid)
238 |         log(
239 |             "Generating {} rotations at resolution level {}".format(
240 |                 len(rots), args.grid
241 |             )
242 |         )
243 |     elif args.in_pose is not None:
244 |         rots = Poses(args.in_pose)
245 |         log("Generating {} rotations from {}".format(len(rots), args.in_pose))
246 |     else:
247 |         log("Generating {} random rotations".format(args.N))
248 |         rots = RandomRot(args.N)
249 |
250 |     log("Projecting...")
251 |     imgs = []
252 |     iterator = data.DataLoader(rots, batch_size=args.b)
253 |     for i, rot in enumerate(iterator):
254 |         log("Projecting {}/{}".format((i + 1) * len(rot), rots.N))
255 |         projections = projector.project(rot)
256 |         projections = projections.cpu().numpy()
257 |         imgs.append(projections)
258 |
259 |     td = time.time() - t1
260 |     log("Projected {} images in {}s ({}s per image)".format(rots.N, td, td / rots.N))
261 |     imgs = np.vstack(imgs)
262 |
263 |     if args.in_pose is None and args.t_extent:
264 |         log("Shifting images between +/- {} pixels".format(args.t_extent))
265 |         trans = np.random.rand(args.N, 2) * 2 * args.t_extent - args.t_extent
266 |     elif args.in_pose is not None:
267 |         log("Shifting images by input poses")
268 |         D = imgs.shape[-1]
269 |         trans = rots.trans * D  # convert to pixels
270 |         trans = -trans[:, ::-1]  # convention for scipy
271 |     else:
272 |         trans = None
273 |
274 |     if trans is not None:
275 |         imgs = np.asarray([translate_img(img, t) for img, t in zip(imgs, trans)])
276 |         # convention: we want the first column to be x shift and second column to be y shift
277 |         # reverse columns since current implementation of translate_img uses scipy's
278 |         # fourier_shift, which is flipped the other way
279 |         # convention: save the translation that centers the image
280 |         trans = -trans[:, ::-1]
281 |         # convert translation from pixel to fraction
282 |         D = imgs.shape[-1]
283 |         assert D % 2 == 0
284 |         trans /= D
285 |
286 |     log("Saving {}".format(args.o))
287 |     mrcfile.write_mrc(args.o, imgs.astype(np.float32))
288 |     log("Saving {}".format(args.out_pose))
289 |     rots = rots.rots.cpu().numpy()
290 |     with open(args.out_pose, "wb") as f:
291 |         if args.t_extent:
292 |             pickle.dump((rots, trans), f)
293 |         else:
294 |             pickle.dump(rots, f)
295 |     if args.out_png:
296 |         log("Saving {}".format(args.out_png))
297 |         plot_projections(args.out_png, imgs[:9])
298 |
299 |
300 | if __name__ == "__main__":
301 |     args = parse_args().parse_args()
302 |     utils._verbose = args.verbose
303 |     main(args)
304 |
--------------------------------------------------------------------------------
/cryosim/subsample_ctf.py:
--------------------------------------------------------------------------------
1 | """
2 | Subsample CTF parameters from existing ctf.pkl
3 | """
4 |
5 | import pickle
6 | import os
7 | import re
8 | import numpy as np
9 | import argparse
10 | from cryodrgn import utils
11 | from cryodrgn import ctf
12 |
13 |
14 | def parse_args():
15 |     parser = argparse.ArgumentParser(description=__doc__)
16 |     parser.add_argument(
17 |         "ctf_file", type=os.path.abspath, help="Input ctf.pkl to subsample"
18 |     )
19 |     parser.add_argument(
20 |         "-o",
21 |         "--out-ctf",
22 |         type=os.path.abspath,
23 |         required=True,
24 |         help="Output ctf.pkl file",
25 |     )
26 |     parser.add_argument("-N", type=int, required=True, help="Number of CTFs")
27 |     parser.add_argument(
28 |         "--Apix",
29 |         type=float,
30 |         default=1.5,
31 |         help="Overwrite pixel size (A/pix) (default: %(default)s)",
32 |     )
33 |     parser.add_argument(
34 |         "-D",
35 |         type=int,
36 |         default=256,
37 |         help="Overwrite image size (pixels) (default: %(default)s)",
38 |     )
39 |     parser.add_argument("--seed", type=int, help="Random seed")
40 |     return parser
41 |
42 |
43 | def mkbasedir(out):
44 |     if not os.path.exists(os.path.dirname(out)):
45 |         os.makedirs(os.path.dirname(out))
46 |
47 |
48 | def main(args):
49 |     if args.seed is None:
50 |         args.seed = np.random.randint(0, 2**32 - 1)
51 |     # seed the RNG with either the given or the randomly drawn seed
52 |     np.random.seed(args.seed)
53 |     mkbasedir(args.out_ctf)
54 |
55 |     data = utils.load_pkl(args.ctf_file)
56 |     print(f"Loaded {data.shape} ctf.pkl")
57 |
58 |     sampled_indices = np.random.choice(data.shape[0], size=args.N, replace=False)
59 |     new_ctf = data[sampled_indices]
60 |     if args.D:
61 |         new_ctf[:, 0] = args.D
62 |     if args.Apix:
63 |         new_ctf[:, 1] = args.Apix
64 |
65 |     ctf.print_ctf_params(new_ctf[0])
66 |     utils.save_pkl(new_ctf, args.out_ctf)
67 |
68 |
69 | if __name__ == "__main__":
70 |     args = parse_args().parse_args()
71 |     main(args)
72 |
--------------------------------------------------------------------------------
/metrics/fsc/README.md:
--------------------------------------------------------------------------------
1 | # FSCs: Tools for comparing similarity of models' reconstructed volumes to ground truth
2 |
3 | This folder contains scripts to calculate Fourier shell correlations between volumes at reconstruction model latent
4 | space coordinates corresponding to individual particle images, as well as to visualize these FSC results.
5 | It also contains a folder `old/per_conf` for the same analyses done using an older method of reconstructing volumes at
6 | class average latent space points.
7 |
8 |
9 | ## Installation instructions
10 |
11 | We recommend using conda environments to install the dependencies for calculating these metrics, and to use a separate
12 | environment for each method **as well as for running the reconstruction model and for CryoBench analyses**.
13 | This is necessary as many of the methods have overlapping dependencies — especially cryoDRGN, which forms
14 | the basis for several of the example methods and is also used by CryoBench itself.
15 |
16 | We show here how to install these environments for benchmarking cryoDRGN; instructions for the other example methods
17 | can be found at
18 | [our manual](https://app.gitbook.com/o/gYlX75MBAfjzRuXIYbKH/s/QwtxcduDAIdbCB0vBNnT/getting-started/installation-instructions).
19 |
20 | Start by cloning the CryoBench git repository; note that we also have to fetch the codebases for the
21 | example methods through their submodules:
22 | ```bash
23 | $ git clone --recurse-submodules git@github.com:ml-struct-bio/CryoBench.git --branch='refactor'
24 | ```
25 |
26 | You will also have to install ChimeraX, which can be done by downloading the correct version for your operating system
27 | from [their website](https://www.cgl.ucsf.edu/chimerax/download.html). Also create an environment variable pointing to
28 | this installation:
29 | ```bash
30 | $ export CHIMERAX_PATH="/myhome/software/chimerax-1.6.1/bin/ChimeraX"
31 | ```
32 | This variable has to be re-defined every time the environment is loaded unless it is saved in e.g. your `.bashrc` file.
33 |
34 | Create an environment for running cryoDRGN models.
35 | Here we specify a recent version to use for producing reconstruction output:
36 | ```bash
37 | $ conda create --name cryodrgn_model python=3.10
38 | $ conda activate cryodrgn_model
39 | (cryodrgn_model)$ pip install 'cryodrgn==3.4.1'
40 | ```
41 |
42 | Next, create an environment for running CryoBench analyses on cryoDRGN output.
43 | Here we instead install an older version of cryoDRGN, and also downgrade its dependencies to account for updates
44 | since this older version of cryoDRGN was released:
45 | ```bash
46 | $ conda create --name cryodrgn_bench python=3.10
47 | $ conda activate cryodrgn_bench
48 | (cryodrgn_bench)$ pip install git+https://github.com/ml-struct-bio/cryodrgn.git@2.0.0-beta
49 | (cryodrgn_bench)$ pip install 'numpy<1.27'
50 | ```
51 |
52 |
53 | ## Generating conformations, calculating FSCs, and plotting results
54 |
55 | For CryoDRGN, we show here how to:
56 | 1) Download the IgG-1D Conf-het CryoBench dataset
57 | 2) Run both the fixed-pose and ab-initio versions of the CryoDRGN reconstruction model on the IgG-1D dataset
58 | 3) Apply the CryoBench tools to calculate FSCs between CryoDRGN model conformation volumes and IgG-1D ground truth
59 | 4) Visualize the results of the FSC analysis
60 |
61 | Corresponding instructions for the other example methods can be found at
62 | [our manual](https://app.gitbook.com/o/gYlX75MBAfjzRuXIYbKH/s/QwtxcduDAIdbCB0vBNnT/~/changes/3/getting-started/running-reconstruction-models)
63 |
64 |
65 | ### Downloading a CryoBench dataset
66 | Although you can also download the IgG-1D dataset through e.g. a web browser by navigating to the
67 | [Zenodo portal](https://zenodo.org/records/11629428), the below demonstrates how to download the data via the
68 | command-line:
69 | ```bash
70 | $ curl "https://zenodo.org/records/11629428/files/IgG-1D.zip?download=1" --output IgG-1D.zip
71 | $ unzip IgG-1D.zip
72 | ```
73 |
74 | The commands below are assumed to be run from the same directory in which the `IgG-1D/` folder created by `unzip` above is located.
75 |
76 | ### cryoDRGN with fixed poses
77 |
78 | We first run the reconstruction algorithm. This command took 3h 20min using 4 Tesla V100 GPUs:
79 | ```bash
80 | $ conda activate cryodrgn_model
81 |
82 | (cryodrgn_model)$ cryodrgn train_vae IgG-1D/images/snr0.01/sorted_particles.128.txt -n 20 --zdim 8 \
83 |                     --ctf IgG-1D/combined_ctfs.pkl --poses IgG-1D/combined_poses.pkl \
84 |                     -o cBench_input/IgG-1D/cryodrgn_fixed/
85 | ```
86 |
87 | We then run the CryoBench script for generating image volumes and comparing them to ground truth volumes; the sketch below recaps the Fourier shell correlation these scripts compute.
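As a brief refresher on the metric itself: the Fourier shell correlation between two volumes is the normalized cross-correlation of their Fourier coefficients, accumulated over shells of constant spatial frequency, and resolution is conventionally reported where the curve crosses a threshold (0.5 against a ground truth map, 0.143 between half-maps). Below is a minimal NumPy sketch of the computation, for illustration only; CryoBench's actual implementation (with masking and related conveniences) lives in `metrics/fsc/utils/volumes.py`:

```python
import numpy as np

def fsc_curve(vol1: np.ndarray, vol2: np.ndarray) -> np.ndarray:
    """Fourier shell correlation between two cubic volumes, one value per shell."""
    assert vol1.shape == vol2.shape, "volumes must be on the same grid"
    D = vol1.shape[0]
    f1 = np.fft.fftshift(np.fft.fftn(vol1))
    f2 = np.fft.fftshift(np.fft.fftn(vol2))

    # integer radius of every Fourier voxel from the zero-frequency origin
    coords = np.arange(D) - D // 2
    x, y, z = np.meshgrid(coords, coords, coords, indexing="ij")
    radii = np.round(np.sqrt(x**2 + y**2 + z**2)).astype(int)

    fsc = np.zeros(D // 2)
    for r in range(D // 2):
        shell = radii == r
        corr = np.sum(f1[shell] * np.conj(f2[shell]))
        norm = np.sqrt(np.sum(np.abs(f1[shell]) ** 2) * np.sum(np.abs(f2[shell]) ** 2))
        fsc[r] = np.real(corr) / norm
    return fsc  # fsc[r] is the correlation at spatial frequency r / (D * Apix)
```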
Because
88 | cryoDRGN output volumes are 0-indexed, the last volume from our twenty-epoch model is numbered `19`:
89 | ```bash
90 | $ conda activate cryodrgn_bench
91 |
92 | # Compute per image FSC
93 | (cryodrgn_bench)$ python metrics/fsc/cdrgn_per_img_fsc.py cBench_input/IgG-1D/cryodrgn_fixed/ IgG-1D/gt_latents.pkl --epoch 19 --Apix 3.0 -n 100 \
94 |                     --gt-dir IgG-1D/vols/128_org/ \
95 |                     --mask IgG-1D/init_mask/mask.mrc \
96 |                     -o cBench_output/IgG-1D/cryodrgn_fixed/
97 |
98 | # Plot FSCs
99 | (cryodrgn_bench)$ python metrics/fsc/plot_fsc.py cBench_output/IgG-1D/cryodrgn_fixed/
100 | ```
101 |
102 | ### cryoDRGN with ab-initio poses
103 | ```bash
104 | $ conda activate cryodrgn_model
105 |
106 | (cryodrgn_model)$ cryodrgn abinit_het IgG-1D/images/snr0.005/sorted_particles.128.txt -n 30 --zdim 8 \
107 |                     --ctf IgG-1D/combined_ctfs.pkl -o cBench_input/IgG-1D/cryodrgn_abinit/
108 | ```
109 |
110 | ```bash
111 | $ conda activate cryodrgn_bench
112 |
113 | # Compute per image FSC
114 | (cryodrgn_bench)$ python metrics/fsc/cdrgn_per_img_fsc.py cBench_input/IgG-1D/cryodrgn_abinit/ IgG-1D/gt_latents.pkl --epoch 29 --Apix 3.0 -n 100 \
115 |                     --gt-dir IgG-1D/vols/128_org/ \
116 |                     --mask IgG-1D/init_mask/mask.mrc \
117 |                     -o cBench_output/IgG-1D/cryodrgn_abinit/
118 |
119 | # Plot FSCs
120 | (cryodrgn_bench)$ python metrics/fsc/plot_fsc.py cBench_output/IgG-1D/cryodrgn_abinit/
121 | ```
122 |
--------------------------------------------------------------------------------
/metrics/fsc/cdrgn_per_img_fsc.py:
--------------------------------------------------------------------------------
1 | """Calculate FSCs between volumes from images' mappings in a cryoDRGN latent space.
2 |
3 | Example usage
4 | -------------
5 | # See zenodo.org/records/11629428 for the Conf-het dataset used in these examples
6 |
7 | $ python metrics/fsc/cdrgn_per_img_fsc.py cryodrgn_output/train_vae/001_IgG-1D/ \
8 |     IgG-1D/gt_latents.pkl --gt-dir IgG-1D/vols/128_org/ \
9 |     -o cryobench_output/cdrgn_train-vae_001/ \
10 |     -n 100 --Apix 3.0
11 |
12 | # Sample more volumes and align before computing FSCs in parallel using compute cluster
13 | $ python metrics/fsc/cdrgn_per_img_fsc.py cryodrgn_output/abinit_het/001_IgG-1D/ \
14 |     IgG-1D/gt_latents.pkl --gt-dir IgG-1D/vols/128_org/ \
15 |     -o cryobench_output/cdrgn_abinit-het_001/ \
16 |     -n 1000 --Apix 3.0 --parallel-align
17 |
18 | """
19 | import os
20 | import argparse
21 | import subprocess
22 | import pickle
23 | from glob import glob
24 | from time import time
25 | import logging
26 | import numpy as np
27 | from sklearn.metrics import auc
28 | from utils import volumes
29 | from cryodrgn import mrc, utils
30 |
31 | logging.basicConfig(
32 |     level=logging.INFO,
33 |     format="(%(levelname)s) (%(filename)s) (%(asctime)s) %(message)s",
34 |     datefmt="%d-%b-%y %H:%M:%S",
35 | )
36 | logger = logging.getLogger(__name__)
37 |
38 | CHIMERAX_PATH = os.environ["CHIMERAX_PATH"]
39 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
40 | ALIGN_PATH = os.path.join(ROOT_DIR, "utils", "align.py")
41 |
42 |
43 | def parse_args() -> argparse.Namespace:
44 |     """Create and parse command line arguments for this script (see `main` below)."""
45 |     parser = argparse.ArgumentParser()
46 |
47 |     parser.add_argument(
48 |         "traindir",
49 |         help="Path to folder with output from a cryoDRGN train_vae or abinit_het model",
50 |     )
51 |     parser.add_argument(
52 |         "labels",
53 |         type=os.path.abspath,
54 |         help=".pkl file with ground truth class index per particle",
55 |     )
56 |     parser.add_argument(
57 |         "-o",
58 |         type=os.path.abspath,
59 |         required=True,
60 |         help="Path to folder where output will be saved",
61 |     )
62 |     parser.add_argument("-n", required=True, type=int, help="Number of vols to sample")
63 |     parser.add_argument("--Apix", required=True, type=float)
64 |
65 |     parser.add_argument(
66 |         "--gt-paths", help=".pkl file with path to ground truth volume per particle"
67 |     )
68 |     parser.add_argument(
69 |         "--gt-dir", help="path to folder with ground truth .mrc volumes per particle"
70 |     )
71 |
72 |     parser.add_argument(
73 |         "--mask",
74 |         default=None,
75 |         type=os.path.abspath,
76 |         help="Path to mask .mrc to compute the masked metric",
77 |     )
78 |     parser.add_argument(
79 |         "--overwrite",
80 |         action="store_true",
81 |         help="Overwrite already-generated volumes instead of reusing them",
82 |     )
83 |
84 |     parser.add_argument("--epoch", type=int, default=-1, help="epoch (default: last)")
85 |     parser.add_argument("--cuda-device", default=0, type=int)
86 |     parser.add_argument("--no-fscs", action="store_false", dest="calc_fsc_vals")
87 |
88 |     parser.add_argument(
89 |         "--serial-align",
90 |         action="store_true",
91 |         help="Align volumes one after the other on the local machine.",
92 |     )
93 |     parser.add_argument(
94 |         "--parallel-align",
95 |         action="store_true",
96 |         help="Align volumes in parallel using a compute cluster",
97 |     )
98 |
99 |     return parser.parse_args()
100 |
101 |
102 | def align_volumes(
103 |     vol_path: str, ref_path: str, apix: float = 1.0, flip: bool = True
104 | ) -> None:
105 |     """Align a volume in a .mrc file to another .mrc volume using ChimeraX."""
106 |
107 |     data, header = mrc.parse_mrc(vol_path)
108 |     header.update_origin(0.0, 0.0, 0.0)
109 |     header.update_apix(apix)
110 |     mrc.write(vol_path, data, header)
111 |
112 |     flip_str = "--flip" if flip else ""
113 |     log_file = os.path.splitext(vol_path)[0] + ".txt"
114 |     cmd = f'{CHIMERAX_PATH} --nogui --script "{ALIGN_PATH} {ref_path} {vol_path} '
115 |     cmd += f'{flip_str} -o {vol_path} -f {log_file} " > {log_file} '
116 |
117 |     subprocess.check_call(cmd, shell=True)
118 |
119 |
120 | def main(args: argparse.Namespace) -> None:
121 |     """Running the script to get FSCs across cryoDRGN image-wise conformations."""
122 |
123 |     cfg_file = os.path.join(args.traindir, "config.yaml")
124 |     if not os.path.exists(cfg_file):
125 |         raise ValueError(
126 |             f"Could not find cryoDRGN config file {cfg_file} "
127 |             f"— is {args.traindir=} a cryoDRGN output folder?"
128 |         )
129 |     cfg = os.path.join(args.traindir, "config.yaml")
130 |
131 |     if not (args.gt_paths is None) ^ (args.gt_dir is None):
132 |         raise ValueError("Must provide exactly one of --gt-paths or --gt-dir!")
133 |     if args.serial_align and args.parallel_align:
134 |         raise ValueError(
135 |             "Cannot use parallelized volume alignment when using --serial-align!"
136 |         )
137 |
138 |     labels = pickle.load(open(args.labels, "rb"))
139 |     if args.gt_paths is not None:
140 |         gt_paths = pickle.load(open(args.gt_paths, "rb"))
141 |
142 |         if len(labels) != len(gt_paths):
143 |             raise ValueError(
144 |                 f"Mismatch between size of labels {len(labels)} "
145 |                 f"and volume paths {len(gt_paths)}!"
146 |             )
147 |
148 |     else:
149 |         gt_files = sorted(
150 |             glob(os.path.join(args.gt_dir, "*.mrc")), key=volumes.numfile_sortkey
151 |         )
152 |         gt_paths = [gt_files[i] for i in labels]
153 |
154 |     N = len(labels)
155 |     particle_idxs = np.arange(0, N, N // args.n)
156 |     os.makedirs(args.o, exist_ok=True)
157 |
158 |     epoch_str = "" if args.epoch == -1 else f".{args.epoch}"
159 |     checkpoint = os.path.join(args.traindir, f"weights{epoch_str}.pkl")
160 |     if not os.path.exists(checkpoint):
161 |         raise ValueError(
162 |             f"Could not find cryoDRGN model weights for epoch {args.epoch} "
163 |             f"in output folder {args.traindir=} — did the model finish running?"
164 |         )
165 |     z_path = os.path.join(args.traindir, f"z{epoch_str}.pkl")
166 |     if not os.path.exists(z_path):
167 |         raise ValueError(
168 |             f"Could not find cryoDRGN latent space coordinates for epoch {args.epoch} "
169 |             f"in output folder {args.traindir=} — did the model finish running?"
170 |         )
171 |
172 |     z = utils.load_pkl(z_path)
173 |     generator = volumes.get_volume_generator(cfg, checkpoint)
174 |     log_interval = max(round((len(particle_idxs) // 1000), -2), 5)
175 |     gen_paths = list()
176 |
177 |     for vol_i, particle_i in enumerate(particle_idxs):
178 |         gen_file = os.path.join(args.o, f"vol_{vol_i:03d}.mrc")
179 |         gen_paths.append(gen_file)
180 |
181 |         if os.path.exists(gen_file) and not args.overwrite:
182 |             continue
183 |
184 |         if vol_i % log_interval == 0:
185 |             logger.info(
186 |                 f"Generating volume {vol_i + 1}/{len(particle_idxs)} "
187 |                 f"(vol_{vol_i:03d}.mrc) ..."
188 |             )
189 |
190 |         gt_path = gt_paths[particle_i]
191 |         gen_vol = generator(z[particle_i, :])
192 |         mrc.write(gen_paths[-1], gen_vol.astype(np.float32))
193 |         if not os.path.isabs(gt_path) and args.gt_paths is not None:
194 |             gt_path = os.path.join(os.path.dirname(args.gt_paths), gt_path)
195 |
196 |         if args.serial_align:
197 |             if vol_i % log_interval == 0:
198 |                 logger.info(f"Aligning volume vol_{vol_i:03d} ...")
199 |             align_volumes(gen_paths[-1], gt_path, args.Apix)
200 |
201 |     if args.parallel_align:
202 |         volumes.align_volumes_multi(gen_paths, gt_paths)
203 |
204 |     if args.calc_fsc_vals:
205 |         gt_paths_sel = [gt_paths[i] for i in particle_idxs]
206 |         fsc_curves = volumes.get_fsc_curves(
207 |             gt_paths_sel, gen_paths, mask_file=args.mask, outdir=args.o
208 |         )
209 |
210 |         auc_vals = {
211 |             particle_idxs[i]: auc(fsc_df.pixres, fsc_df.fsc.abs())
212 |             for i, fsc_df in fsc_curves.items()
213 |         }
214 |         aucs = {class_idx: [] for class_idx in np.unique(labels)}
215 |         for i, auc_val in auc_vals.items():
216 |             aucs[labels[i]].append(auc_val)
217 |
218 |         logger.info(
219 |             "\n".join(
220 |                 [""]
221 |                 + [
222 |                     f"No Images in Class {class_idx} "
223 |                     if len(class_aucs) == 0
224 |                     else f"Class {class_idx} ({len(class_aucs)} image) — "
225 |                     f"AU-FSC: {np.mean(class_aucs):.5f}"
226 |                     if len(class_aucs) == 1
227 |                     else f"Class {class_idx} ({len(class_aucs)} images) — "
228 |                     f"AU-FSC: {np.mean(class_aucs):.5f} +/- {np.std(class_aucs):.3f}"
229 |                     for class_idx, class_aucs in aucs.items()
230 |                 ]
231 |             )
232 |         )
233 |         all_aucs = [a for auc_list in aucs.values() for a in auc_list]
234 |         logger.info(
235 |             f"AU-FSC Overall: {np.mean(all_aucs):.5f} +/- {np.std(all_aucs):.3f}"
236 |         )
237 |
238 |
239 | if __name__ == "__main__":
240 |     args = parse_args()
241 |     s = time()
242 |     main(args)
243 |     logger.info(f"Completed in {(time()-s):.5g} seconds")
244 |
--------------------------------------------------------------------------------
/metrics/fsc/old/per_conf/cdrgn.py:
--------------------------------------------------------------------------------
1 | """Calculate FSCs between conformations matched across a cryoDRGN model latent space.
2 |
3 | Example usage
4 | -------------
5 | $ python metrics/fsc/old/per_conf/cdrgn.py results/cryodrgn --epoch 19 --Apix 3.0 \
6 |     -o output --gt-dir ./gt_vols --mask ./mask.mrc
7 |
8 | # Also align output volumes to ground truth volumes with ChimeraX before computing FSCs
9 | $ python metrics/fsc/old/per_conf/cdrgn.py results/cryodrgn --epoch 19 --Apix 3.0 \
10 |     -o output --gt-dir ./gt_vols --mask ./mask.mrc --align-vols
11 |
12 | """
13 | import os
14 | import sys
15 | import argparse
16 | import logging
17 | import numpy as np
18 | import torch
19 | import cryodrgn.utils
20 | from cryodrgn import mrc
21 |
22 | sys.path.append(
23 |     os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
24 | )
25 | from utils import volumes, conformations, interface
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 |
30 | def main(args: argparse.Namespace) -> None:
31 |     """Running the script to get FSCs across conformations produced by cryoDRGN."""
32 |
33 |     cfg_file = os.path.join(args.input_dir, "config.yaml")
34 |     if not os.path.exists(cfg_file):
35 |         raise ValueError(
36 |             f"Could not find cryoDRGN config file {cfg_file} "
37 |             f"— is {args.input_dir=} a cryoDRGN output folder?"
38 |         )
39 |
40 |     epoch_str = "" if args.epoch == -1 else f".{args.epoch}"
41 |     weights_fl = os.path.join(args.input_dir, f"weights{epoch_str}.pkl")
42 |     if not os.path.exists(weights_fl):
43 |         raise ValueError(
44 |             f"Could not find cryoDRGN model weights for epoch {args.epoch} "
45 |             f"in output folder {args.input_dir=} — did the model finish running?"
46 |         )
47 |     z_path = os.path.join(args.input_dir, f"z{epoch_str}.pkl")
48 |     if not os.path.exists(z_path):
49 |         raise ValueError(
50 |             f"Could not find cryoDRGN latent space coordinates for epoch {args.epoch} "
51 |             f"in output folder {args.input_dir=} — did the model finish running?"
52 | ) 53 | 54 | logger.info(f"Putting output under: {args.outdir} ...") 55 | voldir = os.path.join(args.outdir, "vols") 56 | os.makedirs(voldir, exist_ok=True) 57 | z = cryodrgn.utils.load_pkl(z_path) 58 | num_imgs = int(args.num_imgs) if z.shape[0] == 100000 else "ribo" 59 | nearest_z_array = conformations.get_nearest_z_array(z, args.num_vols, num_imgs) 60 | 61 | # Generate volumes at these cryoDRGN latent space coordinates using the model tool 62 | out_zfile = os.path.join(args.outdir, "zfile.txt") 63 | generator = volumes.get_volume_generator(cfg_file, weights_fl) 64 | logger.info(out_zfile) 65 | 66 | if os.path.exists(out_zfile) and not args.overwrite: 67 | logger.info("Z file exists, skipping...") 68 | else: 69 | if not args.dry_run: 70 | np.savetxt(out_zfile, nearest_z_array) 71 | 72 | for i, zval in enumerate(nearest_z_array): 73 | gen_file = os.path.join(voldir, f"vol_{i:03d}.mrc") 74 | if os.path.exists(gen_file) and not args.overwrite: 75 | continue 76 | 77 | ztensor = torch.tensor(zval, device=f"cuda:{args.cuda_device}") 78 | mrc.write( 79 | gen_file, generator(ztensor).astype(np.float32), Apix=args.Apix 80 | ) 81 | 82 | # Align output conformation volumes to ground truth volumes using ChimeraX 83 | if args.align_vols: 84 | volumes.align_volumes_multi(voldir, args.gt_dir, flip=args.flip_align) 85 | 86 | if args.calc_fsc_vals: 87 | volumes.get_fsc_curves( 88 | voldir, 89 | args.gt_dir, 90 | mask_file=args.mask, 91 | fast=args.fast, 92 | overwrite=args.overwrite, 93 | ) 94 | 95 | if args.align_vols: 96 | aligndir = "flipped_aligned" if args.flip_align else "aligned" 97 | volumes.get_fsc_curves( 98 | os.path.join(voldir, aligndir), 99 | args.gt_dir, 100 | mask_file=args.mask, 101 | fast=args.fast, 102 | overwrite=args.overwrite, 103 | ) 104 | 105 | 106 | if __name__ == "__main__": 107 | main(interface.add_calc_args().parse_args()) 108 | -------------------------------------------------------------------------------- /metrics/fsc/old/per_conf/cryosparc_3dcls.py: -------------------------------------------------------------------------------- 1 | """Calculate FSCs between conformations matched across cryoSPARC 3D classifications. 2 | 3 | Example usage 4 | ------------- 5 | $ python metrics/fsc/old/per_conf/cryosparc_3dcls.py results/CS-cryobench/J5 \ 6 | -o cBench/cBench-out_3Dcls/ --gt-dir vols/128_org/ --mask bproj_0.005.mrc \ 7 | --num-classes 10 8 | 9 | """ 10 | import os 11 | import sys 12 | import argparse 13 | import json 14 | import numpy as np 15 | from glob import glob 16 | import logging 17 | 18 | sys.path.append( 19 | os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) 20 | ) 21 | from utils import volumes, interface 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def add_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 27 | parser.add_argument("--num-classes", default=10, type=int) 28 | 29 | return parser 30 | 31 | 32 | def main(args: argparse.Namespace) -> None: 33 | """Running the script to get FSCs across conformations produced by cryoSPARC.""" 34 | 35 | cfg_file = os.path.join(args.input_dir, "job.json") 36 | if not os.path.exists(cfg_file): 37 | raise ValueError( 38 | f"Could not find cryoSPARC job info file `job.json` in given folder " 39 | f"{args.input_dir=} — did a cryoSPARC job use this as the output path?" 
40 |     )
41 |     with open(cfg_file) as f:
42 |         configs = json.load(f)
43 |
44 |     if configs["type"] != "class_3D":
45 |         raise ValueError(
46 |             f"Given folder {args.input_dir=} contains cryoSPARC job type "
47 |             f"`{configs['type']=}`; this script is for 3D classification jobs (`class_3D`)!"
48 |         )
49 |
50 |     file_pattern = "*.mrc"
51 |     files = [
52 |         f for f in glob(os.path.join(args.input_dir, file_pattern)) if "mask" not in f
53 |     ]
54 |     pred_dir = sorted(files, key=volumes.numfile_sortkey)
55 |     print("pred_dir[0]:", pred_dir[0])
56 |     csparc_num = pred_dir[0].split("/")[-1].split(".")[0].split("_")[3]
57 |     csparc_job = pred_dir[0].split("/")[-1].split(".")[0].split("_")[0]
58 |     print("cryosparc_num:", csparc_num)
59 |     print("cryosparc_job:", csparc_job)
60 |
61 |     lst = []
62 |     for cls in range(args.num_classes):
63 |         class_fl = f"{csparc_job}_passthrough_particles_class_{cls}.cs"
64 |         cs = np.load(os.path.join(args.input_dir, class_fl))
65 |         cs_new = cs[:: args.num_imgs]
66 |         print(f"class {cls}: {len(cs_new)}")
67 |
68 |         for cs_i in range(len(cs_new)):
69 |             path = cs_new[cs_i]["blob/path"].decode("utf-8")
70 |             gt = path.split("/")[-1].split("_")[1]
71 |             lst.append((int(cls), int(gt)))
72 |
73 |     if args.calc_fsc_vals:
74 |         volumes.get_fsc_curves(
75 |             args.input_dir,
76 |             args.gt_dir,
77 |             outdir=args.outdir,
78 |             mask_file=args.mask,
79 |             fast=args.fast,
80 |             overwrite=args.overwrite,
81 |             vol_fl_function=(
82 |                 lambda i: f"{csparc_job}_class_{lst[i][0]:02d}_{csparc_num}_volume"
83 |             ),
84 |         )
85 |
86 |
87 | if __name__ == "__main__":
88 |     main(add_args(interface.add_calc_args()).parse_args())
89 |
--------------------------------------------------------------------------------
/metrics/fsc/old/per_conf/cryosparc_3dflex.py:
--------------------------------------------------------------------------------
1 | """Calculate FSCs across cryoSPARC 3D Flex Train model conformations.
2 |
3 | Example usage
4 | -------------
5 | $ python metrics/fsc/old/per_conf/cryosparc_3dflex.py results/CS-cryobench/J9 \
6 |     -o cBench/cBench-out_3Dflex/ --gt-dir vols/128_org/ --mask bproj_0.005.mrc \
7 |     --project-num P564 --job-num J13
8 |
9 | """
10 | import numpy as np
11 | import os
12 | import sys
13 | import argparse
14 | import zipfile
15 | from cryosparc.tools import CryoSPARC
16 | from cryodrgn import analysis
17 |
18 | sys.path.append(
19 |     os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
20 | )
21 | from utils import volumes, interface
22 |
23 | # replace these as necessary with your CryoSPARC credentials
24 | license_id = None
25 | email = None
26 | password = None
27 | host = None
28 | run_lane = None
29 |
30 |
31 | def add_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
32 |     parser.add_argument("--project-num", required=True)
33 |     parser.add_argument("--job-num", required=True)
34 |     parser.add_argument("--base-port", default=39000, type=int)
35 |
36 |     return parser
37 |
38 |
39 | def main(args):
40 |     """Script to get FSCs across conformations output by cryoSPARC 3D Flex Generate."""
41 |     num_vols = 100
42 |
43 |     cs = CryoSPARC(
44 |         license=license_id,
45 |         email=email,
46 |         password=password,
47 |         host=host,
48 |         base_port=args.base_port,
49 |     )
50 |     project = cs.find_project(args.project_num)
51 |     flex_job = cs.find_job(args.project_num, args.job_num)  # Flex train
52 |     particles = flex_job.load_output("particles")
53 |     latents_job = project.create_external_job("W1", "Custom Latents")
54 |     latents_job.connect("particles", args.job_num, "particles", slots=["components"])
55 |
56 |     v = np.empty((len(particles), 2))
57 |     for i in range(2):
58 |         v[:, i] = particles[f"components_mode_{i}/value"]
59 |
60 |     z_lst = []
61 |     z_mean_lst = []
62 |     for i in range(num_vols):
63 |         z_nth = v[i * args.num_imgs : (i + 1) * args.num_imgs]
64 |         z_nth_avg = z_nth.mean(axis=0)
65 |         z_nth_avg = z_nth_avg.reshape(1, -1)
66 |         z_lst.append(z_nth)
67 |         z_mean_lst.append(z_nth_avg)
68 |
69 |     nearest_z_lst = []
70 |     centers_ind_lst = []
71 |     for i in range(num_vols):
72 |         nearest_z, centers_ind = analysis.get_nearest_point(z_lst[i], z_mean_lst[i])
73 |         nearest_z_lst.append(nearest_z.reshape(nearest_z.shape[-1]))
74 |         centers_ind_lst.append(centers_ind)
75 |     latent_pts = np.array(nearest_z_lst)
76 |
77 |     slots = [
78 |         {"prefix": "components_mode_%d" % k, "dtype": "components", "required": True}
79 |         for k in range(2)
80 |     ]
81 |     latents_dset = latents_job.add_output(
82 |         type="particle",
83 |         name="latents",
84 |         slots=slots,
85 |         title="Latents",
86 |         alloc=len(latent_pts),
87 |     )
88 |
89 |     for k in range(2):
90 |         latents_dset["components_mode_%d/component" % k] = k
91 |         latents_dset["components_mode_%d/value" % k] = latent_pts[:, k]
92 |
93 |     with latents_job.run():
94 |         latents_job.save_output("latents", latents_dset)
95 |
96 |     gen_job = project.create_job("W1", "flex_generate")
97 |     gen_job.connect("flex_model", args.job_num, "flex_model")
98 |     gen_job.connect("volume", args.job_num, "volume")
99 |     gen_job.connect("latents", latents_job.uid, "latents")
100 |
101 |     gen_job.queue(lane=run_lane)
102 |     gen_job.wait_for_done(error_on_incomplete=True)
103 |     zip_outs = [fl for fl in gen_job.list_files() if os.path.splitext(fl)[1] == ".zip"]
104 |     assert len(zip_outs) == 1
105 |     cryosparc_dir = os.path.join(os.path.split(os.path.realpath(args.input_dir))[0])
106 |     zip_outs = os.path.join(cryosparc_dir, gen_job.uid, zip_outs[0])
107 |
108 |     voldir =
os.path.join(args.outdir, "vols") 109 | os.makedirs(voldir, exist_ok=True) 110 | with zipfile.ZipFile(zip_outs, "r") as zip_ref: 111 | zip_ref.extractall(voldir) 112 | 113 | if args.calc_fsc_vals: 114 | volumes.get_fsc_curves( 115 | voldir, 116 | args.gt_dir, 117 | outdir=args.outdir, 118 | mask_file=args.mask, 119 | fast=args.fast, 120 | overwrite=args.overwrite, 121 | vol_fl_function=lambda i: f"{gen_job.uid}_series_000_frame_{i:03d}", 122 | ) 123 | 124 | 125 | if __name__ == "__main__": 126 | main(add_args(interface.add_calc_args()).parse_args()) 127 | -------------------------------------------------------------------------------- /metrics/fsc/old/per_conf/cryosparc_3dva.py: -------------------------------------------------------------------------------- 1 | """Calculate FSCs across cryoSPARC 3D Variability model conformations. 2 | 3 | Example usage 4 | ------------- 5 | $ python metrics/fsc/old/per_conf/cryosparc_3dva.py results/CS-cryobench/J11 \ 6 | -o cBench/cBench-out_3Dvar/ --gt-dir vols/128_org/ --mask bproj_0.005.mrc 7 | 8 | """ 9 | import os 10 | import sys 11 | import argparse 12 | import json 13 | import numpy as np 14 | from glob import glob 15 | import logging 16 | from cryodrgn import analysis, mrc 17 | 18 | sys.path.append( 19 | os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) 20 | ) 21 | from utils import volumes, interface 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def main(args: argparse.Namespace) -> None: 27 | """Script to get FSCs across conformations produced by cryoSPARC 3D Variability.""" 28 | 29 | cfg_file = os.path.join(args.input_dir, "job.json") 30 | if not os.path.exists(cfg_file): 31 | raise ValueError( 32 | f"Could not find cryoSPARC job info file `job.json` in given folder " 33 | f"{args.input_dir=} — did a cryoSPARC job use this as the output path?" 34 | ) 35 | with open(cfg_file) as f: 36 | configs = json.load(f) 37 | 38 | if configs["type"] != "var_3D": 39 | raise ValueError( 40 | f"Given folder {args.input_dir=} contains cryoSPARC job type " 41 | f"`{configs['type']=}`; this script is for 3D Variability jobs (`var_3D`)!" 
42 | ) 43 | 44 | voldir = os.path.join(args.outdir, "vols") 45 | os.makedirs(voldir, exist_ok=True) 46 | file_pattern = "*.mrc" 47 | files = [ 48 | f for f in glob(os.path.join(args.input_dir, file_pattern)) if "mask" not in f 49 | ] 50 | pred_dir = sorted(files, key=volumes.numfile_sortkey) 51 | print("pred_dir[0]:", pred_dir[0]) 52 | csparc_job = pred_dir[0].split("/")[-1].split(".")[0].split("_")[0] 53 | print("cryosparc_job:", csparc_job) 54 | 55 | # weights z_ik 56 | cs_path = os.path.join(args.input_dir, f"{csparc_job}_particles.cs") 57 | map_mrc_path = os.path.join(args.input_dir, f"{csparc_job}_map.mrc") 58 | 59 | # reference 60 | v_0 = mrc.parse_mrc(map_mrc_path)[0] 61 | x = np.load(cs_path) 62 | component_mrc_path = os.path.join(args.input_dir, f"{csparc_job}_component_0.mrc") 63 | v_k1 = mrc.parse_mrc(component_mrc_path)[0] # [128 128 128] 64 | component_mrc_path = os.path.join(args.input_dir, f"{csparc_job}_component_1.mrc") 65 | v_k2 = mrc.parse_mrc(component_mrc_path)[0] # [128 128 128] 66 | component_mrc_path = os.path.join(args.input_dir, f"{csparc_job}_component_2.mrc") 67 | v_k3 = mrc.parse_mrc(component_mrc_path)[0] # [128 128 128] 68 | 69 | for i in range(args.num_vols): 70 | components_1 = x["components_mode_0/value"] 71 | components_2 = x["components_mode_1/value"] 72 | components_3 = x["components_mode_2/value"] 73 | 74 | start_i, end_i = i * args.num_imgs, (i + 1) * args.num_imgs 75 | z_1 = components_1[start_i:end_i].reshape(args.num_imgs, 1) 76 | z_2 = components_2[start_i:end_i].reshape(args.num_imgs, 1) 77 | z_3 = components_3[start_i:end_i].reshape(args.num_imgs, 1) 78 | 79 | z1_nth_avg = z_1.mean(axis=0) 80 | z1_nth_avg = z1_nth_avg.reshape(1, -1) 81 | z2_nth_avg = z_2.mean(axis=0) 82 | z2_nth_avg = z2_nth_avg.reshape(1, -1) 83 | z3_nth_avg = z_3.mean(axis=0) 84 | z3_nth_avg = z3_nth_avg.reshape(1, -1) 85 | 86 | nearest_z1, centers_ind1 = analysis.get_nearest_point(z_1, z1_nth_avg) 87 | nearest_z2, centers_ind2 = analysis.get_nearest_point(z_2, z2_nth_avg) 88 | nearest_z3, centers_ind3 = analysis.get_nearest_point(z_3, z3_nth_avg) 89 | vol = v_0 + (nearest_z1 * (v_k1) + nearest_z2 * (v_k2) + nearest_z3 * (v_k3)) 90 | mrc.write(os.path.join(voldir, f"vol_{i:03d}.mrc"), vol.astype(np.float32)) 91 | 92 | # Align output conformation volumes to ground truth volumes using ChimeraX 93 | if args.align_vols: 94 | volumes.align_volumes_multi(voldir, args.gt_dir) 95 | 96 | if args.calc_fsc_vals: 97 | volumes.get_fsc_curves( 98 | voldir, 99 | args.gt_dir, 100 | mask_file=args.mask, 101 | fast=args.fast, 102 | overwrite=args.overwrite, 103 | ) 104 | 105 | if args.align_vols: 106 | volumes.get_fsc_curves( 107 | os.path.join(voldir, "aligned"), 108 | args.gt_dir, 109 | mask_file=args.mask, 110 | fast=args.fast, 111 | overwrite=args.overwrite, 112 | ) 113 | 114 | 115 | if __name__ == "__main__": 116 | main(interface.add_calc_args().parse_args()) 117 | -------------------------------------------------------------------------------- /metrics/fsc/old/per_conf/cryosparc_abinitio.py: -------------------------------------------------------------------------------- 1 | """Calculate FSCs across cryoSPARC 3D Ab-Initio model conformations. 
2 | 3 | Example usage 4 | ------------- 5 | $ python metrics/fsc/old/per_conf/cryosparc_abinitio.py results/CS-cryobench/J5 \ 6 | -o cBench/cBench-out_3Dcls/ --gt-dir vols/128_org/ --mask bproj_0.005.mrc \ 7 | --num-classes 10 8 | 9 | """ 10 | import os 11 | import sys 12 | import argparse 13 | import json 14 | import numpy as np 15 | from glob import glob 16 | import logging 17 | 18 | sys.path.append( 19 | os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) 20 | ) 21 | from utils import volumes, interface 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def add_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 27 | parser.add_argument("--num-classes", default=10, type=int) 28 | 29 | return parser 30 | 31 | 32 | def main(args: argparse.Namespace) -> None: 33 | """Script to get FSCs across conformations produced by cryoSPARC Ab-Initio.""" 34 | 35 | cfg_file = os.path.join(args.input_dir, "job.json") 36 | if not os.path.exists(cfg_file): 37 | raise ValueError( 38 | f"Could not find cryoSPARC job info file `job.json` in given folder " 39 | f"{args.input_dir=} — did a cryoSPARC job use this as the output path?" 40 | ) 41 | with open(cfg_file) as f: 42 | configs = json.load(f) 43 | 44 | if configs["type"] != "homo_abinit": 45 | raise ValueError( 46 | f"Given folder {args.input_dir=} contains cryoSPARC job type " 47 | f"`{configs['type']=}`; this script is for ab-initio jobs (`homo_abinit`)!" 48 | ) 49 | 50 | file_pattern = "*.mrc" 51 | files = [ 52 | f for f in glob(os.path.join(args.input_dir, file_pattern)) if "mask" not in f 53 | ] 54 | pred_dir = sorted(files, key=volumes.numfile_sortkey) 55 | print("pred_dir[0]:", pred_dir[0]) 56 | csparc_num = pred_dir[0].split("/")[-1].split(".")[0].split("_")[3] 57 | csparc_job = pred_dir[0].split("/")[-1].split(".")[0].split("_")[0] 58 | print("cryosparc_num:", csparc_num) 59 | print("cryosparc_job:", csparc_job) 60 | 61 | lst = [] 62 | for cls in range(args.num_classes): 63 | class_fl = f"{csparc_job}_class_{cls:02d}_final_particles.cs" 64 | cs = np.load(os.path.join(args.input_dir, class_fl)) 65 | cs_new = cs[:: args.num_imgs] 66 | print(f"class {cls}: {len(cs_new)}") 67 | 68 | for cs_i in range(len(cs_new)): 69 | path = cs_new[cs_i]["blob/path"].decode("utf-8") 70 | gt = path.split("/")[-1].split("_")[1] 71 | lst.append((int(cls), int(gt))) 72 | 73 | if args.calc_fsc_vals: 74 | volumes.get_fsc_curves( 75 | args.input_dir, 76 | args.gt_dir, 77 | outdir=args.outdir, 78 | mask_file=args.mask, 79 | fast=args.fast, 80 | overwrite=args.overwrite, 81 | vol_fl_function=( 82 | lambda i: f"{csparc_job}_class_{lst[i][0]:02d}_final_volume" 83 | ), 84 | ) 85 | 86 | 87 | if __name__ == "__main__": 88 | main(add_args(interface.add_calc_args()).parse_args()) 89 | -------------------------------------------------------------------------------- /metrics/fsc/old/per_conf/drgnai.py: -------------------------------------------------------------------------------- 1 | """Calculate FSCs between conformations matched across DRGN-AI model latent spaces. 
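For each ground-truth conformation, the per-image latent coordinates are averaged over that conformation's images, the latent nearest to this average is selected, and the DRGN-AI decoder generates a volume at that point for FSC comparison against the corresponding ground-truth volume.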
2 | 3 | Example usage 4 | ------------- 5 | $ python metrics/fsc/old/per_conf/drgnai.py results/drgnai_fixed \ 6 | --epoch 19 --Apix 3.0 -o output/drgnai_fixed --gt-dir ./gt_vols \ 7 | --mask ./mask.mrc --num-imgs 1000 --num-vols 100 8 | 9 | """ 10 | import os 11 | import sys 12 | import argparse 13 | import yaml 14 | import logging 15 | import numpy as np 16 | import torch 17 | import cryodrgn.utils 18 | from cryodrgnai.analyze import VolumeGenerator 19 | from cryodrgnai.lattice import Lattice 20 | from cryodrgnai import models 21 | 22 | sys.path.append( 23 | os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) 24 | ) 25 | from utils import volumes, conformations, interface 26 | 27 | logging.basicConfig( 28 | level=logging.INFO, 29 | format="(%(levelname)s) (%(filename)s) (%(asctime)s) %(message)s", 30 | datefmt="%d-%b-%y %H:%M:%S", 31 | ) 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def main(args: argparse.Namespace) -> None: 36 | """Running the script to get FSCs across conformations produced by DRGN-AI.""" 37 | 38 | cfg_file = os.path.join(args.input_dir, "out", "drgnai-configs.yaml") 39 | if not os.path.exists(cfg_file): 40 | raise ValueError( 41 | f"Could not find DRGN-AI configuration parameter file 'out/drgnai-configs.yaml' " 42 | f"in given folder {args.input_dir=} — is this a DRGN-AI output folder?" 43 | ) 44 | 45 | epoch_str = "" if args.epoch == -1 else f".{args.epoch}" 46 | weights_fl = os.path.join(args.input_dir, "out", f"weights{epoch_str}.pkl") 47 | if not os.path.exists(weights_fl): 48 | raise ValueError( 49 | f"Could not find DRGN-AI model weights for epoch {args.epoch} " 50 | f"in output folder {args.input_dir=} — did the model finish running?" 51 | ) 52 | z_path = os.path.join(args.input_dir, "out", f"conf{epoch_str}.pkl") 53 | if not os.path.exists(z_path): 54 | raise ValueError( 55 | f"Could not find DRGN-AI latent space coordinates for epoch {args.epoch} " 56 | f"in output folder {args.input_dir=} — did the model finish running?"
57 | ) 58 | 59 | logger.info(f"Putting output under: {args.outdir} ...") 60 | os.makedirs(args.outdir, exist_ok=True) 61 | z = cryodrgn.utils.load_pkl(z_path) 62 | num_imgs = int(args.num_imgs) if z.shape[0] == 100000 else "ribo" 63 | nearest_z_array = conformations.get_nearest_z_array(z, args.num_vols, num_imgs) 64 | 65 | with open(cfg_file, "r") as f: 66 | configs = yaml.safe_load(f) 67 | checkpoint = torch.load(weights_fl) 68 | hypervolume_params = checkpoint["hypervolume_params"] 69 | hypervolume = models.HyperVolume(**hypervolume_params) 70 | hypervolume.load_state_dict(checkpoint["hypervolume_state_dict"]) 71 | hypervolume.eval() 72 | hypervolume.to(args.cuda_device) 73 | 74 | lattice = Lattice( 75 | checkpoint["hypervolume_params"]["resolution"], 76 | extent=0.5, 77 | device=args.cuda_device, 78 | ) 79 | z_dim = checkpoint["hypervolume_params"]["z_dim"] 80 | radius_mask = ( 81 | checkpoint["output_mask_radius"] if "output_mask_radius" in checkpoint else None 82 | ) 83 | vol_generator = VolumeGenerator( 84 | hypervolume, 85 | lattice, 86 | z_dim, 87 | True, 88 | radius_mask, 89 | data_norm=(configs["data_norm_mean"], configs["data_norm_std"]), 90 | ) 91 | 92 | out_zfile = os.path.join(args.outdir, "zfile.txt") 93 | logger.info(out_zfile) 94 | if os.path.exists(out_zfile) and not args.overwrite: 95 | logger.info("Z file exists, skipping...") 96 | else: 97 | if not args.dry_run: 98 | np.savetxt(out_zfile, nearest_z_array) 99 | vol_generator.gen_volumes(args.outdir, nearest_z_array) 100 | 101 | # Align output conformation volumes to ground truth volumes using ChimeraX 102 | if args.align_vols: 103 | volumes.align_volumes_multi( 104 | args.outdir, args.gt_dir, flip=args.flip_align, random_seed=args.align_seed 105 | ) 106 | 107 | if args.calc_fsc_vals: 108 | volumes.get_fsc_curves( 109 | args.outdir, 110 | args.gt_dir, 111 | mask_file=args.mask, 112 | fast=args.fast, 113 | overwrite=args.overwrite, 114 | ) 115 | 116 | if args.align_vols: 117 | aligndir = "flipped_aligned" if args.flip_align else "aligned" 118 | volumes.get_fsc_curves( 119 | os.path.join(args.outdir, aligndir), 120 | args.gt_dir, 121 | mask_file=args.mask, 122 | fast=args.fast, 123 | overwrite=args.overwrite, 124 | ) 125 | 126 | 127 | if __name__ == "__main__": 128 | main(interface.add_calc_args().parse_args()) 129 | -------------------------------------------------------------------------------- /metrics/fsc/old/per_conf/opusdsd.py: -------------------------------------------------------------------------------- 1 | """Calculate FSCs between conformations matched across an OPUS-DSD model latent space. 2 | 3 | See github.com/alncat/opusDSD for the source code for this method, and 4 | www.nature.com/articles/s41592-023-02031-6 for its publication. 
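This script reads the OPUS-DSD composition latents (the "mu" field of z.<epoch>.pkl), selects one representative latent per ground-truth conformation, and shells out to the opusDSD fork's eval_vol.py to decode the corresponding volumes before computing FSCs.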
5 | 6 | Example usage 7 | ------------- 8 | $ python metrics/fsc/old/per_conf/opusdsd.py opusdsd-outputs/001_base/ \ 9 | --epoch 19 -o opusdsd-outputs/001_base/cryobench.10/ \ 10 | --gt-dir IgG-1D/vols/128_org/ --mask IgG-1D/init_mask/mask.mrc \ 11 | --num-imgs 1000 --num-vols 100 --Apix=3.0 12 | 13 | # We sometimes need to pad the opusDSD volumes to a larger box size 14 | $ python metrics/fsc/old/per_conf/opusdsd.py opusdsd-outputs/001_base/ \ 15 | --epoch 19 -o opusdsd-outputs/001_base/cryobench.10/ \ 16 | --gt-dir IgG-1D/vols/128_org/ --mask IgG-1D/init_mask/mask.mrc \ 17 | --num-imgs 1000 --num-vols 100 --Apix=3.0 -D 256 18 | 19 | """ 20 | import os 21 | import sys 22 | import argparse 23 | import subprocess 24 | from glob import glob 25 | import logging 26 | import numpy as np 27 | import torch 28 | 29 | ROOTDIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) 30 | sys.path.append(os.path.join(ROOTDIR, "fsc")) 31 | from utils import volumes, conformations, interface 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | def add_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 37 | parser.add_argument("-D", default=128, type=int, help="Box size (in pixels) to pad output volumes to") 38 | 39 | return parser 40 | 41 | 42 | def main(args: argparse.Namespace) -> None: 43 | """Running the script to get FSCs across conformations produced by OPUS-DSD.""" 44 | 45 | cfg_file = os.path.join(args.input_dir, "config.pkl") 46 | if not os.path.exists(cfg_file): 47 | raise ValueError( 48 | f"Could not find opusDSD config file {cfg_file} " 49 | f"— is {args.input_dir=} an opusDSD output folder?" 50 | ) 51 | 52 | epoch_str = "" if args.epoch == -1 else f".{args.epoch}" 53 | weights_fl = os.path.join(args.input_dir, f"weights{epoch_str}.pkl") 54 | if not os.path.exists(weights_fl): 55 | raise ValueError( 56 | f"Could not find opusDSD model weights for epoch {args.epoch} " 57 | f"in output folder {args.input_dir=} — did the model finish running?" 58 | ) 59 | z_path = os.path.join(args.input_dir, f"z{epoch_str}.pkl") 60 | if not os.path.exists(z_path): 61 | raise ValueError( 62 | f"Could not find opusDSD latent space coordinates for epoch {args.epoch} " 63 | f"in output folder {args.input_dir=} — did the model finish running?"
64 | ) 65 | 66 | logger.info(f"Putting output under: {args.outdir} ...") 67 | voldir = os.path.join(args.outdir, "vols") 68 | os.makedirs(voldir, exist_ok=True) 69 | z = torch.load(z_path)["mu"].cpu().numpy() 70 | num_imgs = int(args.num_imgs) if z.shape[0] == 100000 else "ribo" 71 | nearest_z_array = conformations.get_nearest_z_array(z, args.num_vols, num_imgs) 72 | 73 | eval_vol_cmd = os.path.join( 74 | ROOTDIR, 75 | "methods", 76 | "opusDSD", 77 | "cryodrgn", 78 | "commands", 79 | "eval_vol.py", 80 | ) 81 | out_zfile = os.path.join(args.outdir, "zfile.txt") 82 | logger.info(out_zfile) 83 | cmd = f"CUDA_VISIBLE_DEVICES={args.cuda_device} "  # env-var prefix (no ';') so the child process inherits the device setting 84 | cmd += f"python {eval_vol_cmd} --load {weights_fl} -c {cfg_file} " 85 | cmd += f"--zfile {out_zfile} -o {voldir} --Apix {args.Apix}; " 86 | 87 | logging.basicConfig(level=logging.INFO) 88 | logger.info(cmd) 89 | if os.path.exists(out_zfile) and not args.overwrite: 90 | logger.info("Z file exists, skipping...") 91 | else: 92 | if not args.dry_run: 93 | np.savetxt(out_zfile, nearest_z_array) 94 | subprocess.check_call(cmd, shell=True) 95 | 96 | # Align output conformation volumes to ground truth volumes using ChimeraX 97 | if args.align_vols: 98 | volumes.align_volumes_multi(voldir, args.gt_dir) 99 | 100 | conformations.pad_mrc_vols(sorted(glob(os.path.join(voldir, "*.mrc"))), args.D) 101 | if args.align_vols: 102 | volumes.align_volumes_multi(voldir, args.gt_dir) 103 | 104 | if args.calc_fsc_vals: 105 | volumes.get_fsc_curves( 106 | voldir, 107 | args.gt_dir, 108 | mask_file=args.mask, 109 | fast=args.fast, 110 | overwrite=args.overwrite, 111 | vol_fl_function=lambda i: f"reference{i}", 112 | ) 113 | 114 | if args.align_vols: 115 | volumes.get_fsc_curves( 116 | os.path.join(voldir, "aligned"), 117 | args.gt_dir, 118 | mask_file=args.mask, 119 | fast=args.fast, 120 | overwrite=args.overwrite, 121 | vol_fl_function=lambda i: f"reference{i}", 122 | ) 123 | 124 | 125 | if __name__ == "__main__": 126 | main(add_args(interface.add_calc_args()).parse_args()) 127 | -------------------------------------------------------------------------------- /metrics/fsc/old/per_conf/per_conf_calc.py: -------------------------------------------------------------------------------- 1 | """Calculate FSCs between conformations matched across a volume model's latent space. 2 | 3 | This script is an alternative to the method-specific FSC calculation scripts found in 4 | this folder; it can automatically detect the method used to generate 5 | the output folder given.
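The method is inferred from marker files in the given folder: config.yaml (cryoDRGN), config.pkl (opusDSD), reordered_z.npy (RECOVAR), job.json (cryoSPARC, with the job type read from that file), or an out/ subfolder containing a DRGN-AI config (DRGN-AI).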
6 | 7 | Example usage 8 | ------------- 9 | $ python metrics/fsc/old/per_conf/per_conf_calc.py results/cryodrgn \ 10 | --epoch 19 --Apix 3.0 -o output --gt-dir ./gt_vols --mask ./mask.mrc 11 | 12 | """ 13 | import os 14 | import sys 15 | import argparse 16 | import json 17 | from cdrgn import main as run_cdrgn 18 | from drgnai import main as run_drgnai 19 | from opusdsd import main as run_opusdsd 20 | from re_covar import main as run_recovar 21 | from cryosparc_3dcls import main as run_cryosparc_3dcls 22 | from cryosparc_abinitio import main as run_cryosparc_abinitio 23 | from cryosparc_3dva import main as run_cryosparc_3dva 24 | from cryosparc_3dflex import main as run_cryosparc_3dflex 25 | 26 | sys.path.append( 27 | os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) 28 | ) 29 | from utils import interface 30 | 31 | 32 | def main(args: argparse.Namespace): 33 | input_files = os.listdir(args.input_dir) 34 | 35 | if "config.yaml" in input_files: 36 | run_cdrgn(args) 37 | elif "config.pkl" in input_files: 38 | run_opusdsd(args) 39 | elif "reordered_z.npy" in input_files: 40 | run_recovar(args) 41 | 42 | elif "job.json" in input_files: 43 | with open(os.path.join(args.input_dir, "job.json")) as f: 44 | configs = json.load(f) 45 | 46 | if configs["type"] == "var_3D": 47 | run_cryosparc_3dva(args) 48 | elif configs["type"] == "homo_abinit": 49 | run_cryosparc_abinitio(args) 50 | elif configs["type"] == "flex_test": 51 | run_cryosparc_3dflex(args) 52 | elif configs["type"] == "class_3D": 53 | run_cryosparc_3dcls(args) 54 | else: 55 | raise RuntimeError(f"Unrecognized cryoSPARC job type `{configs['type']}` !") 56 | 57 | elif ( 58 | "out" in input_files 59 | and os.path.isdir(os.path.join(args.input_dir, "out")) 60 | and "drgnai-configs.yaml" in os.listdir(os.path.join(args.input_dir, "out")) 61 | ): 62 | run_drgnai(args) 63 | 64 | else: 65 | raise ValueError( 66 | f"Unrecognized output folder format found in `{args.input_dir}`! " 67 | f"Does not match for any known methods: " 68 | "cryoDRGN, DRGN-AI, opusDSD, 3DFlex, 3DVA, RECOVAR" 69 | ) 70 | 71 | 72 | if __name__ == "__main__": 73 | main(interface.add_calc_args().parse_args()) 74 | -------------------------------------------------------------------------------- /metrics/fsc/old/per_conf/re_covar.py: -------------------------------------------------------------------------------- 1 | """Calculate FSCs across RECOVAR model conformations.
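The RECOVAR embeddings are first reordered from the pipeline's half-set ordering back to the original particle indexing; one representative latent is then chosen per ground-truth conformation (the embedding nearest the per-conformation mean), and reweighted volumes are reconstructed at those latents via recovar's compute_and_save_reweighted.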
2 | 3 | Example usage 4 | ------------- 5 | $ python metrics/fsc/old/per_conf/re_covar.py results/recovar/001/ \ 6 | -o cryobench-outputs/recovar/ --gt-dir IgG-1D/vols/128_org/ \ 7 | --mask IgG-1D/init_mask/mask.mrc --num-imgs 1000 --num-vols 100 8 | 9 | """ 10 | import os 11 | import sys 12 | import argparse 13 | import logging 14 | import numpy as np 15 | from cryodrgn import analysis 16 | 17 | ROOTDIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) 18 | sys.path.append(os.path.join(ROOTDIR, "methods", "recovar")) 19 | from recovar import dataset, embedding, output 20 | 21 | sys.path.append(os.path.join(ROOTDIR, "fsc")) 22 | from utils import volumes, conformations, interface 23 | 24 | logging.basicConfig( 25 | level=logging.INFO, 26 | format="(%(levelname)s) (%(filename)s) (%(asctime)s) %(message)s", 27 | datefmt="%d-%b-%y %H:%M:%S", 28 | ) 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | def add_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 33 | parser.add_argument("--zdim", default=10, type=int, help="latent space dimensionality to use from the RECOVAR output") 34 | parser.add_argument( 35 | "--n-bins", 36 | type=float, 37 | default=50, 38 | dest="n_bins", 39 | help="number of bins for reweighting", 40 | ) 41 | parser.add_argument("--Bfactor", type=float, default=0, help="B-factor sharpening to apply (default: 0)") 42 | 43 | return parser 44 | 45 | 46 | def main(args: argparse.Namespace) -> None: 47 | """Running the script to get FSCs across conformations produced by RECOVAR.""" 48 | 49 | pipeline_output = output.PipelineOutput(args.input_dir) 50 | cryos = pipeline_output.get("lazy_dataset") 51 | zs = pipeline_output.get("zs")[args.zdim] 52 | zs_reordered = dataset.reorder_to_original_indexing(zs, cryos) 53 | 54 | latent_path = os.path.join(args.input_dir, "reordered_z.npy") 55 | umap_path = os.path.join(args.input_dir, "reordered_z_umap.npy") 56 | if os.path.exists(latent_path) and not args.overwrite: 57 | logger.info("latent coordinates file already exists, skipping...") 58 | else: 59 | np.save(latent_path, zs_reordered) 60 | 61 | if os.path.exists(umap_path) and not args.overwrite: 62 | logger.info("latent UMAP clustering already exists, skipping...") 63 | else: 64 | umap_pkl = analysis.run_umap(zs_reordered) 65 | np.save(umap_path, umap_pkl) 66 | 67 | cryos = pipeline_output.get("dataset") 68 | embedding.set_contrasts_in_cryos(cryos, pipeline_output.get("contrasts")[args.zdim]) 69 | zs = pipeline_output.get("zs")[args.zdim] 70 | cov_zs = pipeline_output.get("cov_zs")[args.zdim] 71 | noise_variance = pipeline_output.get("noise_var_used") 72 | zs_reordered = dataset.reorder_to_original_indexing(zs, cryos) 73 | num_imgs = int(args.num_imgs) if zs.shape[0] == 100000 else "ribo" 74 | nearest_z_array = conformations.get_nearest_z_array( 75 | zs_reordered, args.num_vols, num_imgs 76 | ) 77 | 78 | output.mkdir_safe(args.outdir) 79 | log_file = os.path.join(args.outdir, "run.log") 80 | if os.path.exists(log_file) and not args.overwrite: 81 | logger.info("run.log file exists, skipping...") 82 | else: 83 | logger.addHandler(logging.FileHandler(log_file)) 84 | logger.info(args) 85 | 86 | output.compute_and_save_reweighted( 87 | cryos, 88 | nearest_z_array, 89 | zs, 90 | cov_zs, 91 | noise_variance, 92 | args.outdir, 93 | args.Bfactor, 94 | args.n_bins, 95 | ) 96 | 97 | # Align output conformation volumes to ground truth volumes using ChimeraX 98 | if args.align_vols: 99 | volumes.align_volumes_multi(args.outdir, args.gt_dir, flip=args.flip_align) 100 | 101 | if args.calc_fsc_vals: 102 | volumes.get_fsc_curves( 103 | args.outdir, 104 |
args.gt_dir, 105 | mask_file=args.mask, 106 | fast=args.fast, 107 | overwrite=args.overwrite, 108 | vol_fl_function=lambda i: os.path.join( 109 | f"vol{i:03d}", "ml_optimized_locres_filtered" 110 | ), 111 | ) 112 | 113 | if args.align_vols: 114 | volumes.get_fsc_curves( 115 | os.path.join(args.outdir, "flipped_aligned" if args.flip_align else "aligned"), 116 | args.gt_dir, 117 | mask_file=args.mask, 118 | fast=args.fast, 119 | overwrite=args.overwrite, 120 | vol_fl_function=lambda i: os.path.join( 121 | f"vol{i:03d}", "ml_optimized_locres_filtered" 122 | ), 123 | ) 124 | 125 | 126 | if __name__ == "__main__": 127 | main(add_args(interface.add_calc_args()).parse_args()) 128 | -------------------------------------------------------------------------------- /metrics/fsc/plot_fsc.py: -------------------------------------------------------------------------------- 1 | """Visualize FSCs between conformations matched across a volume model's latent space. 2 | 3 | The CryoBench output directory used as the argument to this script should contain at 4 | least one folder with the prefix "fsc_" already produced by an FSC analysis script such 5 | as `cdrgn.py`. 6 | 7 | Example usage 8 | ------------- 9 | $ python metrics/fsc/plot_fsc.py cryobench_output/ 10 | 11 | """ 12 | import os 13 | import argparse 14 | from glob import glob 15 | import matplotlib.pyplot as plt 16 | import seaborn as sns 17 | import numpy as np 18 | import pandas as pd 19 | from sklearn.metrics import auc 20 | from utils import volumes 21 | 22 | 23 | def create_args() -> argparse.ArgumentParser: 24 | parser = argparse.ArgumentParser(description=__doc__) 25 | parser.add_argument( 26 | "outdir", 27 | type=os.path.abspath, 28 | help="Input directory containing outputs of FSC per conformation analysis.", 29 | ) 30 | parser.add_argument("--Apix", type=float, default=3.0, help="pixel size") 31 | 32 | return parser 33 | 34 | 35 | def main(args: argparse.Namespace) -> None: 36 | fsc_dirs = [ 37 | d 38 | for d in os.listdir(args.outdir) 39 | if d.startswith("fsc_") and os.path.isdir(os.path.join(args.outdir, d)) 40 | ] 41 | 42 | for fsc_lbl in fsc_dirs: 43 | subdir = os.path.join(args.outdir, fsc_lbl) 44 | fsc_files = sorted( 45 | glob(os.path.join(subdir, "*.txt")), key=volumes.numfile_sortkey 46 | ) 47 | 48 | fsc_list = list() 49 | auc_lst = list() 50 | for i, fsc_file in enumerate(fsc_files): 51 | fsc = pd.read_csv(fsc_file, sep=" ") 52 | plt.plot(fsc.pixres, fsc.fsc, label=i) 53 | fsc_list.append(fsc.assign(vol=i)) 54 | auc_lst.append(auc(fsc.pixres, fsc.fsc)) 55 | 56 | freq = np.arange(0, 6) * 0.1 57 | res_text = [ 58 | f"{int(k / (2. * args.Apix) * 1000)/1000. 
if k > 0 else 'DC'}" 59 | for k in np.linspace(0, 1, 6) 60 | ] 61 | plt.xlim((0, 0.5)) 62 | plt.ylim((0, 1)) 63 | plt.grid(True, linewidth=0.53) 64 | plt.xticks(freq, res_text, fontsize=11) 65 | plt.yticks(fontsize=11) 66 | plt.xlabel("Spatial frequency (1/Å)", fontsize=14, weight="semibold") 67 | plt.ylabel("Fourier shell correlation", fontsize=14, weight="semibold") 68 | 69 | auc_avg_np = np.nanmean(auc_lst) 70 | auc_std_np = np.nanstd(auc_lst) 71 | auc_med_np = np.nanmedian(auc_lst, 0) 72 | plt.title( 73 | f"AUC: {auc_avg_np:.3f}\u00B1{auc_std_np:.3f}; median: {auc_med_np:.3f}", 74 | fontsize=15, 75 | ) 76 | 77 | auc_str = str() 78 | for i, auc_val in enumerate(auc_lst): 79 | auc_str += f"{i:>7}: AUC {auc_val:.4f}" 80 | if i < (len(auc_lst) - 1) and i % 4 == 3: 81 | auc_str += "\n" 82 | 83 | auc_str += "\n-------------------------------------------------\n" 84 | auc_str += f"AUC_avg: {auc_avg_np:.5f}, std: {auc_std_np:.5f}, " 85 | auc_str += f"AUC_med: {auc_med_np:.5f}" 86 | print(auc_str) 87 | 88 | plt.tight_layout() 89 | pltfile = os.path.join(args.outdir, f"{fsc_lbl}.png") 90 | plt.savefig(pltfile, dpi=1200, bbox_inches="tight") 91 | plt.clf() 92 | 93 | fsc_df = pd.concat(fsc_list).reset_index(drop=True) 94 | sns.set_style("ticks") 95 | g = sns.lineplot(data=fsc_df, x="pixres", y="fsc", color="red", ci="sd") 96 | 97 | plt.xticks(freq, res_text, fontsize=11) 98 | plt.yticks(fontsize=11) 99 | g.figure.axes[0].set(xlim=(0, 0.5)) 100 | g.figure.axes[0].set(ylim=(0, 1.0)) 101 | plt.grid(True, linewidth=0.53) 102 | plt.xlabel("Spatial frequency (1/Å)", fontsize=14, weight="semibold") 103 | plt.ylabel("Fourier shell correlation", fontsize=14, weight="semibold") 104 | 105 | pltfile = os.path.join(args.outdir, f"{fsc_lbl}_means.png") 106 | plt.savefig(pltfile, dpi=1200, bbox_inches="tight") 107 | plt.clf() 108 | 109 | print(f"`{fsc_lbl}` plots saved!\n") 110 | 111 | 112 | if __name__ == "__main__": 113 | main(create_args().parse_args()) 114 | -------------------------------------------------------------------------------- /metrics/fsc/utils/conformations.py: -------------------------------------------------------------------------------- 1 | """Utility functions used across FSC pipelines for handling conformation outputs.""" 2 | 3 | from collections.abc import Iterable 4 | import logging 5 | from glob import glob 6 | from typing import Union 7 | import numpy as np 8 | from cryodrgn import analysis, mrc 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | # Ribosembly number of images per Ribosembly structure (total 16 structures) 14 | RIBOSEMBLY_NUM_IMGS = [ 15 | 9076, 16 | 14378, 17 | 23547, 18 | 44366, 19 | 30647, 20 | 38500, 21 | 3915, 22 | 3980, 23 | 12740, 24 | 11975, 25 | 17988, 26 | 5001, 27 | 35367, 28 | 37448, 29 | 40540, 30 | 5772, 31 | ] 32 | 33 | 34 | def get_nearest_z_array( 35 | zmat: np.ndarray, num_vols: int, num_imgs: Union[int, str] 36 | ) -> np.ndarray: 37 | z_lst = [] 38 | z_mean_lst = [] 39 | for i in range(num_vols): 40 | if isinstance(num_imgs, int): 41 | z_nth = zmat[(i * num_imgs) : ((i + 1) * num_imgs)] 42 | elif num_imgs == "ribo": 43 | z_nth = zmat[ 44 | sum(RIBOSEMBLY_NUM_IMGS[:i]) : sum(RIBOSEMBLY_NUM_IMGS[: (i + 1)]) 45 | ] 46 | else: 47 | raise ValueError(f"{num_imgs=}") 48 | 49 | z_nth_avg = z_nth.mean(axis=0) 50 | z_nth_avg = z_nth_avg.reshape(1, -1) 51 | z_lst.append(z_nth) 52 | z_mean_lst.append(z_nth_avg) 53 | 54 | nearest_z_lst = [] 55 | centers_ind_lst = [] 56 | num_img_for_centers = 0 57 | for i in range(num_vols): 58 | nearest_z, centers_ind = 
analysis.get_nearest_point(z_lst[i], z_mean_lst[i]) 59 | nearest_z_lst.append(nearest_z.reshape(nearest_z.shape[-1])) 60 | centers_ind_lst.append(centers_ind + num_img_for_centers) 61 | 62 | if num_imgs == "ribo": 63 | num_img_for_centers += RIBOSEMBLY_NUM_IMGS[i] 64 | 65 | return np.array(nearest_z_lst) 66 | 67 | 68 | def pad_mrc_vols(mrc_volfiles: Iterable[str], new_D: int) -> None: 69 | for mrc_file in mrc_volfiles: 70 | v, header = mrc.parse_mrc(mrc_file) 71 | x, y, z = v.shape 72 | assert new_D >= x 73 | assert new_D >= y 74 | assert new_D >= z 75 | 76 | new = np.zeros((new_D, new_D, new_D), dtype=np.float32) 77 | i = (new_D - x) // 2 78 | j = (new_D - y) // 2 79 | k = (new_D - z) // 2 80 | new[i : (i + x), j : (j + y), k : (k + z)] = v 81 | 82 | # adjust origin 83 | apix = header.get_apix() 84 | xorg, yorg, zorg = header.get_origin() 85 | xorg -= apix * k 86 | yorg -= apix * j 87 | zorg -= apix * i 88 | 89 | mrc.write(mrc_file, new, mrc.MRCHeader.make_default_header(new, Apix=apix)) 90 | 91 | 92 | def parse_csparc_dir(workdir): 93 | x = glob("{}/*particles.cs".format(workdir)) 94 | y = [xx for xx in x if "class" not in xx] 95 | y = sorted(y) 96 | cs_info = y[-1].split("_") 97 | it = cs_info[-2] 98 | cs_job = cs_info[-3] 99 | cs_proj = cs_info[-4] 100 | logger.info("Found alignments files: {}".format(y)) 101 | logger.info("Using {} {} iteration {}".format(cs_proj, cs_job, it)) 102 | 103 | return y[-1], cs_proj, cs_job, it 104 | 105 | 106 | def get_csparc_pi(particles_cs, K): 107 | p = np.load(particles_cs) 108 | post = [p["alignments_class_{}/class_posterior".format(i)] for i in range(K)] 109 | post = np.asarray(post) 110 | post = post.T 111 | 112 | return post 113 | -------------------------------------------------------------------------------- /metrics/fsc/utils/interface.py: -------------------------------------------------------------------------------- 1 | """Command-line interfaces shared across FSC per conformation commands.""" 2 | 3 | import argparse 4 | import os 5 | 6 | 7 | def add_calc_args() -> argparse.ArgumentParser: 8 | """Command-line interface used in commands calculating FSCs per conformation.""" 9 | 10 | parser = argparse.ArgumentParser(description=__doc__) 11 | parser.add_argument("input_dir", help="directory containing the method's output (weights, config, latent coordinates)") 12 | parser.add_argument( 13 | "-o", 14 | "--outdir", 15 | default="output_fsc", 16 | type=os.path.abspath, 17 | help="Output directory", 18 | ) 19 | parser.add_argument( 20 | "--epoch", default=19, type=int, help="Training epoch of the model to evaluate (-1 for the final checkpoint)" 21 | ) 22 | parser.add_argument( 23 | "--num-vols", 24 | type=int, 25 | help="Only use the first N reconstructed volumes instead of all of them", 26 | ) 27 | parser.add_argument("--Apix", default=3.0, type=float, help="pixel size (Å/pixel)") 28 | parser.add_argument( 29 | "--num-imgs", 30 | default=1000, 31 | type=int, 32 | help="Number of images per model (structure)", 33 | ) 34 | parser.add_argument( 35 | "--mask", 36 | default=None, 37 | type=os.path.abspath, 38 | help="Path to mask .mrc to compute the masked metric", 39 | ) 40 | parser.add_argument("--gt-dir", help="Directory of ground-truth volumes") 41 | parser.add_argument("--overwrite", action="store_true") 42 | parser.add_argument("--dry-run", action="store_true") 43 | parser.add_argument("--fast", type=int, default=1, help="Only compute FSCs for every Nth conformation") 44 | parser.add_argument("--cuda-device", default=0, type=int) 45 | parser.add_argument("--no-fscs", action="store_false", dest="calc_fsc_vals") 46 | parser.add_argument("--align-vols", action="store_true") 47 | parser.add_argument("--flip-align", action="store_true") 
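    # --align-vols, --flip-align, and --align-seed control the optional rigid
    # alignment of output volumes to the ground-truth maps (submitted as Slurm
    # jobs running ChimeraX; see utils/volumes.py) before FSCs are computed.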
49 | parser.add_argument( 50 | "--align-seed", 51 | type=int, 52 | help="random seed to use for alignment initialization selection", 53 | ) 54 | 55 | return parser 56 | -------------------------------------------------------------------------------- /metrics/fsc/utils/volumes.py: -------------------------------------------------------------------------------- 1 | """Utility functions used across pipelines for calculating FSCs across conformations. 2 | 3 | Many of these functions for calculating FSCs were originally copied from cryoDRGN v3.4.1, 4 | as CryoBench methods depend on older versions of cryoDRGN that don't have these methods! 5 | 6 | """ 7 | import os 8 | import subprocess 9 | import time 10 | import yaml 11 | import re 12 | from glob import glob 13 | import logging 14 | from typing import Optional, Callable, Union 15 | import numpy as np 16 | import pandas as pd 17 | import torch 18 | from cryodrgn import fft, models, mrc 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | CHIMERAX_PATH = os.environ["CHIMERAX_PATH"]  # path to the ChimeraX executable used for volume alignment 24 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 25 | 26 | 27 | def numfile_sortkey(s: str) -> list: 28 | """Split a filepath into a list that can be used to sort files by numeric order.""" 29 | parts = re.split("([0-9]+)", s) 30 | parts[1::2] = map(int, parts[1::2]) 31 | 32 | return parts 33 | 34 | 35 | def get_volume_generator( 36 | config_path: str, checkpoint_path: str 37 | ) -> Callable[[torch.Tensor], np.ndarray]: 38 | """Create a latent space volume generator using a saved cryoDRGN model.""" 39 | with open(config_path, "r") as f: 40 | cfg = yaml.safe_load(f) 41 | 42 | norm = [float(x) for x in cfg["dataset_args"]["norm"]] 43 | model, lattice = models.HetOnlyVAE.load(cfg, checkpoint_path, device="cuda:0") 44 | model.eval() 45 | 46 | return lambda z: model.decoder.eval_volume( 47 | lattice.coords, lattice.D, lattice.extent, norm, z 48 | ) 49 | 50 | 51 | def align_volumes_multi( 52 | vol_paths: Union[str, list[str]], 53 | gt_paths: Union[str, list[str]], 54 | outdir: Optional[str] = None, 55 | flip: bool = False, 56 | random_seed: Optional[int] = None, 57 | ) -> None: 58 | if isinstance(vol_paths, str): 59 | if os.path.isdir(vol_paths): 60 | matching_vols = sorted( 61 | glob(os.path.join(vol_paths, "*.mrc")), key=numfile_sortkey 62 | ) 63 | else: 64 | raise ValueError( 65 | "Single value given for `vol_paths` must be a path to a directory " 66 | "containing .mrc volumes to be aligned!" 67 | ) 68 | elif isinstance(vol_paths, list): 69 | matching_vols = vol_paths 70 | else: 71 | raise ValueError( 72 | f"Unrecognized type given for argument " 73 | f"`vol_paths`: {type(vol_paths).__name__} !" 74 | ) 75 | 76 | if isinstance(gt_paths, str): 77 | if os.path.isdir(gt_paths): 78 | gt_vols = sorted(glob(os.path.join(gt_paths, "*.mrc")), key=numfile_sortkey) 79 | else: 80 | raise ValueError( 81 | "Single value given for `gt_paths` must be a path to a directory " 82 | "containing .mrc volumes to be aligned against!" 83 | ) 84 | elif isinstance(gt_paths, list): 85 | gt_vols = gt_paths 86 | else: 87 | raise ValueError( 88 | f"Unrecognized type given for argument " 89 | f"`gt_paths`: {type(gt_paths).__name__} !" 
90 | ) 91 | 92 | if outdir is None: 93 | if isinstance(vol_paths, str): 94 | outdir = vol_paths 95 | else: 96 | outdir = os.path.dirname(vol_paths[0]) 97 | 98 | aligndir = "flipped_aligned" if flip else "aligned" 99 | os.makedirs(os.path.join(outdir, aligndir), exist_ok=True) 100 | flip_str = "--flip" if flip else "" 101 | seed_str = f" --seed {random_seed}" if random_seed is not None else "" 102 | align_jobs = list() 103 | 104 | for i, file_path in enumerate(matching_vols): 105 | base_filename = os.path.splitext(os.path.basename(file_path))[0] 106 | new_filename = base_filename + ".mrc" 107 | destination_path = os.path.join(outdir, aligndir, new_filename) 108 | ref_path = gt_vols[i] 109 | tmp_file = os.path.join(outdir, aligndir, f"temp_{i:03d}.txt") 110 | 111 | align_cmd = ( 112 | f"sbatch -t 61 -J align_{i} -o {tmp_file} --wrap='{CHIMERAX_PATH} --nogui " 113 | f"--script \" {os.path.join(ROOT_DIR, 'utils', 'align.py')} {ref_path} " 114 | f"{os.path.join(outdir, new_filename)} -o {destination_path} " 115 | f"{flip_str}{seed_str} -f {tmp_file} \" ' " 116 | ) 117 | if i % 20 == 0: 118 | print(align_cmd) 119 | 120 | align_out = subprocess.run(align_cmd, shell=True, capture_output=True) 121 | assert align_out.stderr.decode("utf8") == "", align_out.stderr.decode("utf8") 122 | align_out = align_out.stdout.decode("utf8") 123 | align_jobs.append(align_out.strip().split("Submitted batch job ")[1]) 124 | 125 | jobs_left = len(align_jobs) 126 | while jobs_left > 0: 127 | if jobs_left > 1: 128 | print(f"Waiting for {jobs_left} volume alignment jobs to finish...") 129 | else: 130 | print( 131 | f"Waiting for one volume alignment job " 132 | f"(Slurm ID: {align_jobs[0]}) to finish..." 133 | ) 134 | 135 | time.sleep(max(10, jobs_left / 1.7)) 136 | jobs_left = ( 137 | subprocess.run( 138 | f"squeue -h -j {','.join(align_jobs)}", shell=True, capture_output=True 139 | ) 140 | .stdout.decode("utf8") 141 | .count("\n") 142 | ) 143 | 144 | 145 | def get_fsc_cutoff(fsc_curve: pd.DataFrame, t: float) -> float: 146 | """Find the resolution at which the FSC curve first crosses a given threshold.""" 147 | fsc_indx = np.where(fsc_curve.fsc < t)[0] 148 | return fsc_curve.pixres[fsc_indx[0]] ** -1 if len(fsc_indx) > 0 else 2.0 149 | 150 | 151 | def get_fftn_center_dists(box_size: int) -> np.array: 152 | """Get distances from the center (and hence the resolution) for FFT co-ordinates.""" 153 | 154 | x = np.arange(-box_size // 2, box_size // 2) 155 | x2, x1, x0 = np.meshgrid(x, x, x, indexing="ij") 156 | coords = np.stack((x0, x1, x2), -1) 157 | dists = (coords**2).sum(-1) ** 0.5 158 | assert dists[box_size // 2, box_size // 2, box_size // 2] == 0.0 159 | 160 | return dists 161 | 162 | 163 | def calculate_fsc( 164 | v1: Union[np.ndarray, torch.Tensor], v2: Union[np.ndarray, torch.Tensor] 165 | ) -> float: 166 | """Calculate the Fourier Shell Correlation between two complex vectors.""" 167 | var = (np.vdot(v1, v1) * np.vdot(v2, v2)) ** 0.5 168 | 169 | return float((np.vdot(v1, v2) / var).real) if var else 1.0 170 | 171 | 172 | def get_fsc_curve( 173 | vol1: torch.Tensor, 174 | vol2: torch.Tensor, 175 | mask_file: Optional[str] = None, 176 | ) -> pd.DataFrame: 177 | """Calculate the FSCs between two volumes across all available resolutions.""" 178 | 179 | maskvol = None 180 | if mask_file is not None: 181 | maskvol = torch.tensor(mrc.parse_mrc(mask_file)[0]) 182 | 183 | # Apply the given mask before applying the Fourier transform 184 | maskvol1 = vol1 * maskvol if maskvol is not None else vol1.clone() 185 | maskvol2 = 
vol2 * maskvol if maskvol is not None else vol2.clone() 186 | box_size = vol1.shape[0] 187 | dists = get_fftn_center_dists(box_size) 188 | maskvol1 = fft.fftn_center(maskvol1) 189 | maskvol2 = fft.fftn_center(maskvol2) 190 | 191 | prev_mask = np.zeros((box_size, box_size, box_size), dtype=bool) 192 | fsc = [1.0] 193 | for i in range(1, box_size // 2): 194 | mask = dists < i 195 | shell = np.where(mask & np.logical_not(prev_mask)) 196 | fsc.append(calculate_fsc(maskvol1[shell], maskvol2[shell])) 197 | prev_mask = mask 198 | 199 | return pd.DataFrame( 200 | dict(pixres=np.arange(box_size // 2) / box_size, fsc=fsc), dtype=float 201 | ) 202 | 203 | 204 | def get_fsc_curves( 205 | vol_paths: Union[str, list[str]], 206 | gt_paths: Union[str, list[str]], 207 | outdir: Optional[str] = None, 208 | mask_file: Optional[str] = None, 209 | fast: int = 1, 210 | overwrite: bool = False, 211 | vol_fl_function: Callable[[int], str] = lambda i: f"vol_{i:03d}", 212 | ) -> dict[int, pd.DataFrame]: 213 | """Calculate FSC curves across conformations compared to ground truth volumes.""" 214 | 215 | if isinstance(gt_paths, str): 216 | if os.path.isdir(gt_paths): 217 | gt_vols = sorted(glob(os.path.join(gt_paths, "*.mrc")), key=numfile_sortkey) 218 | else: 219 | raise ValueError( 220 | "Single value given for `gt_paths` must be a path to a directory " 221 | "containing .mrc volumes to be aligned against!" 222 | ) 223 | elif isinstance(gt_paths, list): 224 | gt_vols = gt_paths 225 | else: 226 | raise ValueError( 227 | f"Unrecognized type given for argument " 228 | f"`gt_paths`: {type(gt_paths).__name__} !" 229 | ) 230 | 231 | if isinstance(vol_paths, str): 232 | if os.path.isdir(vol_paths): 233 | vol_files = [ 234 | os.path.join(vol_paths, f"{vol_fl_function(i)}.mrc") 235 | for i in range(len(gt_vols)) 236 | ] 237 | else: 238 | raise ValueError( 239 | "Single value given for `vol_paths` must be a path to a directory " 240 | "containing .mrc volumes to be aligned!" 241 | ) 242 | elif isinstance(vol_paths, list): 243 | vol_files = vol_paths 244 | else: 245 | raise ValueError( 246 | f"Unrecognized type given for argument " 247 | f"`vol_paths`: {type(vol_paths).__name__} !" 
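        # Each FSC curve computed below is saved as a space-delimited table with
        # columns `pixres` (spatial frequency, 1/pixel) and `fsc` under
        # <outdir>/fsc_<mask name>/<iii>.txt; plot_fsc.py scans for these "fsc_"
        # folders when visualizing results.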
248 | ) 249 | 250 | if mask_file is not None: 251 | outlbl = f"fsc_{os.path.splitext(os.path.basename(mask_file))[0]}" 252 | else: 253 | outlbl = "fsc_no_mask" 254 | 255 | if outdir is None: 256 | if isinstance(vol_paths, str): 257 | outdir = vol_paths 258 | else: 259 | outdir = os.path.dirname(vol_paths[0]) 260 | 261 | os.makedirs(os.path.join(outdir, outlbl), exist_ok=True) 262 | fsc_curves = dict() 263 | for ii, gt_volfile in enumerate(gt_vols): 264 | if ii % fast != 0: 265 | continue 266 | 267 | out_fsc = os.path.join(outdir, outlbl, f"{ii:03d}.txt") 268 | vol1 = torch.tensor(mrc.parse_mrc(gt_volfile)[0]) 269 | vol2 = torch.tensor(mrc.parse_mrc(vol_files[ii])[0]) 270 | 271 | if os.path.exists(out_fsc) and not overwrite: 272 | if ii % 20 == 0: 273 | logger.info(f"FSC {ii} exists, loading from file...") 274 | fsc_curves[ii] = pd.read_csv(out_fsc, sep=" ") 275 | else: 276 | fsc_curves[ii] = get_fsc_curve(vol1, vol2, mask_file) 277 | if ii % 20 == 0: 278 | logger.info(f"Saving FSC {ii} values to {out_fsc}") 279 | fsc_curves[ii].round(7).clip(0, 1).to_csv( 280 | out_fsc, sep=" ", header=True, index=False 281 | ) 282 | 283 | # Print summary statistics on max resolutions satisfying particular FSC thresholds 284 | fsc143 = [get_fsc_cutoff(x, 0.143) for x in fsc_curves.values()] 285 | logger.info( 286 | f"cryoDRGN FSC=0.143 — " 287 | f"Mean: {np.mean(fsc143):.4g} \t Median {np.median(fsc143):.4g}" 288 | ) 289 | fsc5 = [get_fsc_cutoff(x, 0.5) for x in fsc_curves.values()] 290 | logger.info( 291 | f"cryoDRGN FSC=0.5 — " 292 | f"Mean: {np.mean(fsc5):.4g} \t Median {np.median(fsc5):.4g}" 293 | ) 294 | 295 | return fsc_curves 296 | -------------------------------------------------------------------------------- /metrics/information_imbalance/submission.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=information_imbalance 3 | #SBATCH --output=%j.out 4 | #SBATCH --error=%j.err 5 | #SBATCH --time=24:00:00 6 | 7 | python compute_information_imbalance.py --method conf_het_1 --subset_size 2000 --input_latents_fname confhet1_wrangled_latents.npz --output_information_imbalance_fname confhet1_information_imbalance.csv --uniq_ks 1,3,10,30,100,300 1> stdout.txt 2> stderr.txt 8 | python compute_information_imbalance.py --method conf_het_2 --subset_size 2000 --input_latents_fname confhet2_wrangled_latents.npz --output_information_imbalance_fname confhet2_information_imbalance.csv --uniq_ks 1,3,10,30,100,300 1> stdout.txt 2> stderr.txt 9 | python compute_information_imbalance.py --method assemble_het --subset_size 2000 --input_latents_fname assemblehet_wrangled_latents.npz --output_information_imbalance_fname assemblehet_information_imbalance.csv --uniq_ks 1,3,10,30,100,300 1> stdout.txt 2> stderr.txt 10 | python compute_information_imbalance.py --method mix_het --subset_size 2000 --input_latents_fname mixhet_wrangled_latents.npz --output_information_imbalance_fname mixhet_information_imbalance.csv --uniq_ks 1,3,10,30,100,300 1> stdout.txt 2> stderr.txt 11 | python compute_information_imbalance.py --method md --subset_size 2000 --input_latents_fname MD_wrangled_latents.npz --output_information_imbalance_fname MD_information_imbalance.csv --uniq_ks 1,3,10,30,100,300 1> stdout.txt 2> stderr.txt 12 | -------------------------------------------------------------------------------- /metrics/methods/recovar_scripts/Per_img_gen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | 
"cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c0e39d38", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import pickle\n", 11 | "import numpy as np" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 11, 17 | "id": "e185ce34", 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "name": "stderr", 22 | "output_type": "stream", 23 | "text": [ 24 | "/home/mj7341/.conda/envs/recovar/lib/python3.11/site-packages/jax/_src/api_util.py:174: SyntaxWarning: Jitted function has static_argnums=(3, 4, 5, 6, 7, 8), but only accepts 8 positional arguments. This warning will be replaced by an error after 2022-08-20 at the earliest.\n", 25 | " warnings.warn(f\"Jitted function has {argnums_name}={argnums}, \"\n" 26 | ] 27 | } 28 | ], 29 | "source": [ 30 | "from recovar import output as o\n", 31 | "from recovar import dataset, utils, latent_density, embedding" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 12, 37 | "id": "78cde65d", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "recovar_result_dir = '/scratch/gpfs/ZHONGE/mj7341/NeurIPS/results/conf-het/dihedral/snr001/recovar'" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 13, 47 | "id": "5d0f5079", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "pipeline_output = o.PipelineOutput(recovar_result_dir +'/')" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 14, 57 | "id": "2b50b796", 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "(INFO) (dataset.py) (19-May-24 11:36:22) Loading halfset from file\n", 65 | "(INFO) (ctf.py) (19-May-24 11:36:23) Image size (pix) : 128\n", 66 | "(INFO) (ctf.py) (19-May-24 11:36:23) A/pix : 3.0\n", 67 | "(INFO) (ctf.py) (19-May-24 11:36:23) DefocusU (A) : 13628.021484375\n", 68 | "(INFO) (ctf.py) (19-May-24 11:36:23) DefocusV (A) : 12750.6298828125\n", 69 | "(INFO) (ctf.py) (19-May-24 11:36:23) Dfang (deg) : 100.841064453125\n", 70 | "(INFO) (ctf.py) (19-May-24 11:36:23) voltage (kV) : 300.0\n", 71 | "(INFO) (ctf.py) (19-May-24 11:36:23) cs (mm) : 2.700000047683716\n", 72 | "(INFO) (ctf.py) (19-May-24 11:36:23) w : 0.10000000149011612\n", 73 | "(INFO) (ctf.py) (19-May-24 11:36:23) Phase shift (deg) : 0.0\n", 74 | "(INFO) (ctf.py) (19-May-24 11:36:23) Image size (pix) : 128\n", 75 | "(INFO) (ctf.py) (19-May-24 11:36:23) A/pix : 3.0\n", 76 | "(INFO) (ctf.py) (19-May-24 11:36:23) DefocusU (A) : 13628.021484375\n", 77 | "(INFO) (ctf.py) (19-May-24 11:36:23) DefocusV (A) : 12750.6298828125\n", 78 | "(INFO) (ctf.py) (19-May-24 11:36:23) Dfang (deg) : 100.841064453125\n", 79 | "(INFO) (ctf.py) (19-May-24 11:36:23) voltage (kV) : 300.0\n", 80 | "(INFO) (ctf.py) (19-May-24 11:36:23) cs (mm) : 2.700000047683716\n", 81 | "(INFO) (ctf.py) (19-May-24 11:36:23) w : 0.10000000149011612\n", 82 | "(INFO) (ctf.py) (19-May-24 11:36:23) Phase shift (deg) : 0.0\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "cryos = pipeline_output.get('lazy_dataset')" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 15, 93 | "id": "d00de065", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "zs = pipeline_output.get('zs')[zdim]" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 16, 103 | "id": "8b94e9b6", 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "zs_reordered = dataset.reorder_to_original_indexing(zs, cryos )" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 
112 | "execution_count": 17, 113 | "id": "dbefe01a", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "num_images = zs.shape[0]" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 18, 123 | "id": "059e1652", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "indices = (np.linspace(0, num_images, 100)).astype(int)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 31, 133 | "id": "8e9bdd69", 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "target_zs = zs_reordered[::1000]" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 32, 143 | "id": "2652e3d9", 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "data": { 148 | "text/plain": [ 149 | "(100, 20)" 150 | ] 151 | }, 152 | "execution_count": 32, 153 | "metadata": {}, 154 | "output_type": "execute_result" 155 | } 156 | ], 157 | "source": [ 158 | "target_zs.shape" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "id": "9e5a2586", 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [] 168 | } 169 | ], 170 | "metadata": { 171 | "kernelspec": { 172 | "display_name": "recovar [~/.conda/envs/recovar/]", 173 | "language": "python", 174 | "name": "conda_recovar" 175 | }, 176 | "language_info": { 177 | "codemirror_mode": { 178 | "name": "ipython", 179 | "version": 3 180 | }, 181 | "file_extension": ".py", 182 | "mimetype": "text/x-python", 183 | "name": "python", 184 | "nbconvert_exporter": "python", 185 | "pygments_lexer": "ipython3", 186 | "version": "3.11.9" 187 | } 188 | }, 189 | "nbformat": 4, 190 | "nbformat_minor": 5 191 | } 192 | -------------------------------------------------------------------------------- /metrics/methods/recovar_scripts/Reorder_to_original_idx.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "077b9e20", 6 | "metadata": {}, 7 | "source": [ 8 | "### Recovar " 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "b21f1ad6", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stderr", 19 | "output_type": "stream", 20 | "text": [ 21 | "2024-05-19 22:50:22.401914: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", 22 | "2024-05-19 22:50:23.001492: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", 23 | "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" 24 | ] 25 | }, 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "(INFO) (xla_bridge.py) (19-May-24 22:50:38) Unable to initialize backend 'cuda': Unable to load cuSOLVER. Is it installed?\n", 31 | "(INFO) (xla_bridge.py) (19-May-24 22:50:38) Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: \"rocm\". 
Available platform names are: CUDA\n", 32 | "(INFO) (xla_bridge.py) (19-May-24 22:50:38) Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n", 33 | "(WARNING) (xla_bridge.py) (19-May-24 22:50:38) CUDA backend failed to initialize: Unable to load cuSOLVER. Is it installed? (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n", 34 | "(INFO) (config.py) (19-May-24 22:50:38) Devices found: cpu\n" 35 | ] 36 | }, 37 | { 38 | "name": "stderr", 39 | "output_type": "stream", 40 | "text": [ 41 | "/home/mj7341/.conda/envs/recovar/lib/python3.11/site-packages/jax/_src/api_util.py:174: SyntaxWarning: Jitted function has static_argnums=(3, 4, 5, 6, 7, 8), but only accepts 8 positional arguments. This warning will be replaced by an error after 2022-08-20 at the earliest.\n", 42 | " warnings.warn(f\"Jitted function has {argnums_name}={argnums}, \"\n" 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "import recovar.config \n", 48 | "import logging\n", 49 | "import numpy as np\n", 50 | "from recovar import output as o\n", 51 | "from recovar import dataset, utils, latent_density, embedding\n", 52 | "from scipy.spatial import distance_matrix\n", 53 | "import pickle\n", 54 | "import os, argparse\n", 55 | "\n", 56 | "from cryodrgn import analysis" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 26, 62 | "id": "e21c45d7", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "recovar_result_dir = '/scratch/gpfs/ZHONGE/mj7341/NeurIPS/results/conf-het/dihedral/snr0001/recovar'\n", 67 | "zdim = 10" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 27, 73 | "id": "13dd8c42", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "(INFO) (dataset.py) (20-May-24 00:04:40) Loading halfset from file\n", 81 | "(INFO) (ctf.py) (20-May-24 00:04:41) Image size (pix) : 128\n", 82 | "(INFO) (ctf.py) (20-May-24 00:04:41) A/pix : 3.0\n", 83 | "(INFO) (ctf.py) (20-May-24 00:04:41) DefocusU (A) : 13628.021484375\n", 84 | "(INFO) (ctf.py) (20-May-24 00:04:41) DefocusV (A) : 12750.6298828125\n", 85 | "(INFO) (ctf.py) (20-May-24 00:04:41) Dfang (deg) : 100.841064453125\n", 86 | "(INFO) (ctf.py) (20-May-24 00:04:41) voltage (kV) : 300.0\n", 87 | "(INFO) (ctf.py) (20-May-24 00:04:41) cs (mm) : 2.700000047683716\n", 88 | "(INFO) (ctf.py) (20-May-24 00:04:41) w : 0.10000000149011612\n", 89 | "(INFO) (ctf.py) (20-May-24 00:04:41) Phase shift (deg) : 0.0\n", 90 | "(INFO) (ctf.py) (20-May-24 00:04:41) Image size (pix) : 128\n", 91 | "(INFO) (ctf.py) (20-May-24 00:04:41) A/pix : 3.0\n", 92 | "(INFO) (ctf.py) (20-May-24 00:04:41) DefocusU (A) : 13628.021484375\n", 93 | "(INFO) (ctf.py) (20-May-24 00:04:41) DefocusV (A) : 12750.6298828125\n", 94 | "(INFO) (ctf.py) (20-May-24 00:04:41) Dfang (deg) : 100.841064453125\n", 95 | "(INFO) (ctf.py) (20-May-24 00:04:41) voltage (kV) : 300.0\n", 96 | "(INFO) (ctf.py) (20-May-24 00:04:41) cs (mm) : 2.700000047683716\n", 97 | "(INFO) (ctf.py) (20-May-24 00:04:41) w : 0.10000000149011612\n", 98 | "(INFO) (ctf.py) (20-May-24 00:04:41) Phase shift (deg) : 0.0\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "pipeline_output = o.PipelineOutput(recovar_result_dir + '/')\n", 104 | "cryos = pipeline_output.get('lazy_dataset')\n", 105 | "zs = pipeline_output.get('zs')[zdim]\n", 106 | "zs_reordered = dataset.reorder_to_original_indexing(zs, cryos)\n", 107 | "\n", 108 | "latent_path = 
os.path.join(recovar_result_dir, 'reordered_z.npy')\n", 109 | "np.save(latent_path, zs_reordered)\n", 110 | "umap_pkl = analysis.run_umap(zs_reordered)\n", 111 | "umap_path = os.path.join(recovar_result_dir, 'reordered_z_umap.npy')\n", 112 | "np.save(umap_path, umap_pkl)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "id": "ed544498", 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [] 122 | } 123 | ], 124 | "metadata": { 125 | "kernelspec": { 126 | "display_name": "recovar [~/.conda/envs/recovar/]", 127 | "language": "python", 128 | "name": "conda_recovar" 129 | }, 130 | "language_info": { 131 | "codemirror_mode": { 132 | "name": "ipython", 133 | "version": 3 134 | }, 135 | "file_extension": ".py", 136 | "mimetype": "text/x-python", 137 | "name": "python", 138 | "nbconvert_exporter": "python", 139 | "pygments_lexer": "ipython3", 140 | "version": "3.11.9" 141 | } 142 | }, 143 | "nbformat": 4, 144 | "nbformat_minor": 5 145 | } 146 | -------------------------------------------------------------------------------- /metrics/methods/recovar_scripts/gen_avg_latent_vol.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import logging 5 | 6 | sys.path.append("recovar") 7 | import recovar 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def add_args(parser: argparse.ArgumentParser): 13 | 14 | parser.add_argument( 15 | "result_dir", 16 | type=os.path.abspath, 17 | help="result dir (output dir of pipeline)", 18 | ) 19 | 20 | parser.add_argument( 21 | "-o", 22 | "--outdir", 23 | type=os.path.abspath, 24 | required=False, 25 | help="Output directory to save reconstructed volumes", 26 | ) 27 | 28 | # parser.add_argument( 29 | # "--latent-points", type=os.path.abspath, 30 | # required=True, 31 | # help="path to latent points (.txt file)", 32 | # ) 33 | 34 | parser.add_argument("--Bfactor", type=float, default=0, help="B-factor sharpening to apply (default: 0)") 35 | 36 | parser.add_argument( 37 | "--n-bins", 38 | type=float, 39 | default=50, 40 | dest="n_bins", 41 | help="number of bins for reweighting", 42 | ) 43 | 44 | parser.add_argument( 45 | "--zdim", type=int, default=20, help="z dim of the latent space" 46 | ) 47 | parser.add_argument( 48 | "--num-imgs", type=int, default=1000, help="number of images per ground-truth structure" 49 | ) 50 | parser.add_argument( 51 | "--vol-num", default=0, type=int, help="n-th vol to reconstruct" 52 | ) 53 | 54 | return parser 55 | 56 | 57 | def compute_state(args): 58 | 59 | po = recovar.output.PipelineOutput(args.result_dir + "/") 60 | # target_zs = np.loadtxt(args.latent_points) 61 | output_folder = args.outdir 62 | 63 | # if args.zdim1: 64 | # zdim =1 65 | # target_zs = target_zs[:,None] 66 | # else: 67 | # zdim = target_zs.shape[-1] 68 | # if target_zs.ndim ==1: 69 | # logger.warning("Did you mean to use --zdim1?") 70 | # target_zs = target_zs[None] 71 | 72 | # if zdim not in po.get('zs'): 73 | # logger.error("z-dim not found in results. 
Options are:" + ','.join(str(e) for e in po.get('zs').keys())) 74 | cryos = po.get("dataset") 75 | recovar.embedding.set_contrasts_in_cryos(cryos, po.get("contrasts")[args.zdim]) 76 | zs = po.get("zs")[args.zdim] 77 | cov_zs = po.get("cov_zs")[args.zdim] 78 | noise_variance = po.get("noise_var_used") 79 | n_bins = args.n_bins 80 | zs_reordered = recovar.dataset.reorder_to_original_indexing(zs, cryos) 81 | z = zs_reordered[args.vol_num * args.num_imgs : (args.vol_num + 1) * args.num_imgs] 82 | target_zs = z.mean(axis=0) 83 | target_zs = target_zs.reshape(1, -1) 84 | recovar.output.mkdir_safe(output_folder) 85 | logger.addHandler(logging.FileHandler(f"{output_folder}/run.log")) 86 | logger.info(args) 87 | recovar.output.compute_and_save_reweighted( 88 | cryos, 89 | target_zs, 90 | zs, 91 | cov_zs, 92 | noise_variance, 93 | output_folder, 94 | args.Bfactor, 95 | n_bins, 96 | ) 97 | 98 | 99 | if __name__ == "__main__": 100 | parser = argparse.ArgumentParser(description=__doc__) 101 | args = add_args(parser).parse_args() 102 | compute_state(args) 103 | -------------------------------------------------------------------------------- /metrics/methods/recovar_scripts/gen_med_latent_vol.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import numpy as np 5 | import logging 6 | 7 | sys.path.append("recovar") 8 | import recovar 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def add_args(parser: argparse.ArgumentParser): 16 | 17 | parser.add_argument( 18 | "result_dir", 19 | type=os.path.abspath, 20 | help="result dir (output dir of pipeline)", 21 | ) 22 | 23 | parser.add_argument( 24 | "-o", 25 | "--outdir", 26 | type=os.path.abspath, 27 | required=False, 28 | help="Output directory to save model", 29 | ) 30 | 31 | # parser.add_argument( 32 | # "--latent-points", type=os.path.abspath, 33 | # required=True, 34 | # help="path to latent points (.txt file)", 35 | # ) 36 | 37 | parser.add_argument("--Bfactor", type=float, default=0, help="0") 38 | 39 | parser.add_argument( 40 | "--n-bins", 41 | type=float, 42 | default=50, 43 | dest="n_bins", 44 | help="number of bins for reweighting", 45 | ) 46 | 47 | parser.add_argument( 48 | "--zdim", type=int, default=20, help="z dim of the latent space" 49 | ) 50 | parser.add_argument( 51 | "--num-imgs", type=int, default=1000, help="z dim of the latent space" 52 | ) 53 | parser.add_argument( 54 | "--vol-num", default=0, type=int, help="n-th vol to reconstruct" 55 | ) 56 | 57 | return parser 58 | 59 | 60 | def compute_state(args): 61 | 62 | po = recovar.output.PipelineOutput(args.result_dir + "/") 63 | # target_zs = np.loadtxt(args.latent_points) 64 | output_folder = args.outdir 65 | 66 | # if args.zdim1: 67 | # zdim =1 68 | # target_zs = target_zs[:,None] 69 | # else: 70 | # zdim = target_zs.shape[-1] 71 | # if target_zs.ndim ==1: 72 | # logger.warning("Did you mean to use --zdim1?") 73 | # target_zs = target_zs[None] 74 | 75 | # if zdim not in po.get('zs'): 76 | # logger.error("z-dim not found in results. 
Options are:" + ','.join(str(e) for e in po.get('zs').keys())) 77 | cryos = po.get("dataset") 78 | recovar.embedding.set_contrasts_in_cryos(cryos, po.get("contrasts")[args.zdim]) 79 | zs = po.get("zs")[args.zdim] 80 | cov_zs = po.get("cov_zs")[args.zdim] 81 | noise_variance = po.get("noise_var_used") 82 | n_bins = args.n_bins 83 | zs_reordered = recovar.dataset.reorder_to_original_indexing(zs, cryos) 84 | z = zs_reordered[args.vol_num * args.num_imgs : (args.vol_num + 1) * args.num_imgs] 85 | # target_zs = z.mean(axis=0) 86 | target_zs = np.median(z, axis=0) 87 | target_zs = target_zs.reshape(1, -1) 88 | recovar.output.mkdir_safe(output_folder) 89 | logger.addHandler(logging.FileHandler(f"{output_folder}/run.log")) 90 | logger.info(args) 91 | recovar.output.compute_and_save_reweighted( 92 | cryos, 93 | target_zs, 94 | zs, 95 | cov_zs, 96 | noise_variance, 97 | output_folder, 98 | args.Bfactor, 99 | n_bins, 100 | ) 101 | 102 | 103 | if __name__ == "__main__": 104 | parser = argparse.ArgumentParser(description=__doc__) 105 | args = add_args(parser).parse_args() 106 | compute_state(args) 107 | -------------------------------------------------------------------------------- /metrics/methods/recovar_scripts/gen_vol_for_per_conf_fsc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | def add_args(parser: argparse.ArgumentParser): 6 | 7 | parser.add_argument( 8 | "result_dir", 9 | type=os.path.abspath, 10 | help="result dir (output dir of pipeline)", 11 | ) 12 | 13 | parser.add_argument( 14 | "-o", 15 | "--outdir", 16 | type=os.path.abspath, 17 | required=False, 18 | help="Output directory to save model", 19 | ) 20 | parser.add_argument("--Bfactor", type=float, default=0, help="0") 21 | 22 | parser.add_argument( 23 | "--n-bins", 24 | type=float, 25 | default=50, 26 | dest="n_bins", 27 | help="number of bins for reweighting", 28 | ) 29 | 30 | parser.add_argument( 31 | "--zdim", type=int, default=20, help="z dim of the latent space" 32 | ) 33 | parser.add_argument( 34 | "--num-imgs", type=int, default=1000, help="z dim of the latent space" 35 | ) 36 | parser.add_argument( 37 | "--num-vols", default=100, type=int, help="number of G.T Volumes" 38 | ) 39 | 40 | return parser 41 | -------------------------------------------------------------------------------- /metrics/methods/recovar_scripts/gen_vol_for_per_conf_fsc_ribosembly.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | from recovar import output as o 4 | from recovar import dataset, embedding 5 | import os 6 | import argparse 7 | 8 | logger = logging.getLogger(__name__) 9 | from cryodrgn import analysis 10 | 11 | 12 | def add_args(parser: argparse.ArgumentParser): 13 | 14 | parser.add_argument( 15 | "result_dir", 16 | type=os.path.abspath, 17 | help="result dir (output dir of pipeline)", 18 | ) 19 | 20 | parser.add_argument( 21 | "-o", 22 | "--outdir", 23 | type=os.path.abspath, 24 | required=False, 25 | help="Output directory to save model", 26 | ) 27 | 28 | parser.add_argument("--Bfactor", type=float, default=0, help="0") 29 | 30 | parser.add_argument( 31 | "--n-bins", 32 | type=float, 33 | default=50, 34 | dest="n_bins", 35 | help="number of bins for reweighting", 36 | ) 37 | 38 | parser.add_argument( 39 | "--zdim", type=int, default=20, help="z dim of the latent space" 40 | ) 41 | parser.add_argument( 42 | "--num-vols", default=100, type=int, help="number of G.T Volumes" 43 | ) 44 | 
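    # NOTE: unlike gen_avg_latent_vol.py, this Ribosembly variant takes no
    # --num-imgs flag; the per-state image counts are hardcoded in
    # compute_state() below (16 states), so --num-vols is expected to be set
    # to the number of Ribosembly states rather than left at its default.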
45 | return parser 46 | 47 | 48 | def compute_state(args): 49 | 50 | po = o.PipelineOutput(args.result_dir + "/") 51 | output_folder = args.outdir 52 | cryos = po.get("dataset") 53 | embedding.set_contrasts_in_cryos(cryos, po.get("contrasts")[args.zdim]) 54 | zs = po.get("zs")[args.zdim] 55 | cov_zs = po.get("cov_zs")[args.zdim] 56 | noise_variance = po.get("noise_var_used") 57 | n_bins = args.n_bins 58 | zs_reordered = dataset.reorder_to_original_indexing(zs, cryos) 59 | 60 | num_imgs = [ 61 | 9076, 62 | 14378, 63 | 23547, 64 | 44366, 65 | 30647, 66 | 38500, 67 | 3915, 68 | 3980, 69 | 12740, 70 | 11975, 71 | 17988, 72 | 5001, 73 | 35367, 74 | 37448, 75 | 40540, 76 | 5772, 77 | ] 78 | 79 | z_lst = [] 80 | z_mean_lst = [] 81 | for i in range(args.num_vols): 82 | z_nth = zs_reordered[sum(num_imgs[:i]) : sum(num_imgs[: i + 1])] 83 | z_nth_avg = z_nth.mean(axis=0) 84 | z_nth_avg = z_nth_avg.reshape(1, -1) 85 | z_lst.append(z_nth) 86 | z_mean_lst.append(z_nth_avg) 87 | nearest_z_lst = [] 88 | centers_ind_lst = [] 89 | 90 | num_img_for_centers = 0 91 | for i in range(args.num_vols): 92 | nearest_z, centers_ind = analysis.get_nearest_point(z_lst[i], z_mean_lst[i]) 93 | nearest_z_lst.append(nearest_z.reshape(nearest_z.shape[-1])) 94 | centers_ind_lst.append(centers_ind + num_img_for_centers) 95 | num_img_for_centers += num_imgs[i] 96 | centers_ind_array = np.array(centers_ind_lst) 97 | target_zs = zs_reordered[centers_ind_array].reshape( 98 | len(centers_ind_array), zs_reordered.shape[-1] 99 | ) 100 | 101 | o.mkdir_safe(output_folder) 102 | logger.addHandler(logging.FileHandler(f"{output_folder}/run.log")) 103 | logger.info(args) 104 | o.compute_and_save_reweighted( 105 | cryos, 106 | target_zs, 107 | zs, 108 | cov_zs, 109 | noise_variance, 110 | output_folder, 111 | args.Bfactor, 112 | n_bins, 113 | ) 114 | 115 | 116 | if __name__ == "__main__": 117 | parser = argparse.ArgumentParser(description=__doc__) 118 | args = add_args(parser).parse_args() 119 | compute_state(args) 120 | -------------------------------------------------------------------------------- /metrics/methods/recovar_scripts/make_dataset.py: -------------------------------------------------------------------------------- 1 | from importlib import reload 2 | import numpy as np 3 | from scipy.stats import norm 4 | from recovar import output 5 | from recovar import simulator 6 | 7 | reload(simulator) 8 | 9 | 10 | # warnings.filterwarnings("error") 11 | grid_size = 128 12 | for log_n in [6]: 13 | # output_folder ='/home/mg6942/mytigress/spike256/../' 14 | volume_folder_input = f"/tigress/CRYOEM/singerlab/mg6942/simulated_empiar10180/volumes_{grid_size}_small/vol" 15 | output_folder = volume_folder_input + f"/dataset_{log_n}_bump_3/" 16 | outlier_file_input = "/home/mg6942/mytigress/6vxx_256.mrc" 17 | n_images = int(10 ** (log_n)) 18 | voxel_size = ( 19 | 4.25 * 128 / grid_size 20 | ) # f"{output_folder}../spike{grid_size}_small/0000.mrc" 21 | output.mkdir_safe(output_folder) 22 | volume_distribution = np.zeros(300) 23 | first_k = 300 24 | volume_distribution[:first_k] = 1 / first_k 25 | 26 | # Triple bump 27 | grid = np.linspace(0, 1, 300) 28 | means = [0.25, 0.5, 0.75] 29 | std_used = 0.05 30 | weights = [0.5, 0.25, 0.5] 31 | volume_distribution = norm.pdf(grid, means[0], std_used) * weights[0] 32 | volume_distribution += norm.pdf(grid, means[1], std_used) * weights[1] 33 | volume_distribution += norm.pdf(grid, means[2], std_used) * weights[2] 34 | volume_distribution = volume_distribution / volume_distribution.sum() 35 | 36 | 
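    # The triple-bump mixture above overwrites the uniform distribution over
    # the first 300 volumes built just before it, and its weights
    # (0.5, 0.25, 0.5) need not sum to one since the mixture is renormalized.
    # A minimal sanity check before simulating:
    assert np.isclose(volume_distribution.sum(), 1.0)
    assert (volume_distribution >= 0).all()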
image_stack, sim_info = simulator.generate_synthetic_dataset(
37 |         output_folder,
38 |         voxel_size,
39 |         volume_folder_input,
40 |         n_images,
41 |         outlier_file_input=outlier_file_input,
42 |         grid_size=grid_size,
43 |         volume_distribution=volume_distribution,
44 |         dataset_params_option="uniform",
45 |         noise_level=3,
46 |         noise_model="radial1",
47 |         put_extra_particles=False,
48 |         percent_outliers=0.00,
49 |         volume_radius=0.7,
50 |         trailing_zero_format_in_vol_name=True,
51 |         noise_scale_std=0.2 * 0,
52 |         contrast_std=0.2 * 0,
53 |         disc_type="nufft",
54 |     )
55 |     print(f"Finished generating dataset {output_folder}")
56 | -------------------------------------------------------------------------------- /metrics/methods/recovar_scripts/per_image_generation.py: --------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | 
4 | 
5 | def main():
6 |     recovar_result_dir = "/home/mg6942/mytigress/10345/recovar_data/radial/"
7 | 
8 |     embeddings = pickle.load(open(recovar_result_dir + "model/embeddings.pkl", "rb"))
9 |     zdim = 10
10 |     zs = embeddings["zs"][zdim]
11 | 
12 |     # Sort by the first latent coordinate
13 |     sorted_idx = np.argsort(zs[:, 0])
14 |     zs = zs[sorted_idx]
15 | 
16 |     num_images = zs.shape[0]
17 |     # Take one hundred images, evenly spaced across the sorted zs
18 |     indices = (np.linspace(0, num_images - 1, 100)).astype(int)  # -1 keeps the last index in range
19 |     target_zs = zs[indices]
20 | 
21 |     # If, in the original indexing, particles are sorted by the ground-truth labels and you want to sample evenly in ground-truth order, you can run this instead:
22 | 
23 |     # NOTE: the block below imports recovar, which is not pip-installed, so you should copy this file into the recovar folder to run it.
24 | 
25 |     # from recovar import output, dataset
26 |     # pipeline_output = output.PipelineOutput(recovar_result_dir + '/')
27 |     # cryos = pipeline_output.get('lazy_dataset')
28 |     # zs = pipeline_output.get('zs')[zdim]
29 |     # zs_reordered = dataset.reorder_to_original_indexing(zs, cryos)
30 |     # num_images = zs.shape[0]
31 |     # # Take one hundred images, evenly spaced in ground-truth order
32 |     # indices = (np.linspace(0, num_images - 1, 100)).astype(int)
33 |     # target_zs = zs_reordered[indices]
34 | 
35 |     output_dir = "zs_to_eval.txt"
36 |     np.savetxt(output_dir, target_zs)
37 | 
38 |     return
39 | 
40 | 
41 | if __name__ == "__main__":
42 |     main()
43 | -------------------------------------------------------------------------------- /metrics/methods/recovar_scripts/per_img_fsc.py: --------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import jax.numpy as jnp
4 | import recovar
5 | import logging
6 | import plotly.offline as py
7 | import numpy as np
8 | from scipy.spatial import distance_matrix
9 | 
10 | ftu = recovar.fourier_transform_utils.fourier_transform_utils(jnp)
11 | logger = logging.getLogger(__name__)
12 | py.init_notebook_mode()
13 | 
14 | 
15 | def add_args(parser: argparse.ArgumentParser):
16 | 
17 |     parser.add_argument(
18 |         "result_dir",
19 |         # dest="result_dir",
20 |         type=os.path.abspath,
21 |         help="result dir (output dir of pipeline)",
22 |     )
23 | 
24 |     parser.add_argument(
25 |         "-o",
26 |         "--outdir",
27 |         type=os.path.abspath,
28 |         required=False,
29 |         help="Output directory to save model",
30 |     )
31 |     parser.add_argument(
32 |         "--zdim",
33 |         type=int,
34 |         help="Dimension of latent variable (a single int, not a list)",
35 |     )
36 | 
37 |     parser.add_argument(
38 |         "--n-clusters",
39 |         dest="n_clusters",
40 |         type=int,
41 |         default=40,
42 |         help="number of
k-means clusters (default 40)",
43 |     )
44 | 
45 |     parser.add_argument(
46 |         "--n-trajectories",
47 |         type=int,
48 |         default=6,
49 |         dest="n_trajectories",
50 |         help="number of trajectories to compute between k-means clusters (default 6)",
51 |     )
52 | 
53 |     parser.add_argument(
54 |         "--skip-umap",
55 |         dest="skip_umap",
56 |         action="store_true",
57 |         help="whether to skip the UMAP embedding (can be slow for large datasets)",
58 |     )
59 | 
60 |     parser.add_argument(
61 |         "--skip-centers",
62 |         dest="skip_centers",
63 |         action="store_true",
64 |         help="whether to skip generating the volumes of the k-means centers",
65 |     )
66 | 
67 |     parser.add_argument(
68 |         "--adaptive",
69 |         action="store_true",
70 |         help="whether to use the adaptive discretization scheme in reweighting to compute trajectory volumes",
71 |     )
72 | 
73 |     parser.add_argument(
74 |         "--n-vols-along-path",
75 |         type=int,
76 |         default=6,
77 |         dest="n_vols_along_path",
78 |         help="number of volumes to compute along each trajectory (default 6)",
79 |     )
80 | 
81 |     parser.add_argument(
82 |         "--q",
83 |         type=float,
84 |         default=None,
85 |         help="quantile used for reweighting (0.95 is used when unset)",
86 |     )
87 | 
88 |     parser.add_argument("--Bfactor", type=float, default=0, help="B-factor sharpening to apply (default 0)")
89 | 
90 |     parser.add_argument(
91 |         "--n-bins",
92 |         type=float,
93 |         default=30,
94 |         dest="n_bins",
95 |         help="number of bins for reweighting",
96 |     )
97 | 
98 |     parser.add_argument(
99 |         "--n-std",
100 |         metavar="N_STD",
101 |         type=float,
102 |         default=None,
103 |         help="number of standard deviations to use for reweighting (set either --q or this parameter, not both)",
104 |     )
105 | 
106 |     return parser
107 | 
108 | 
109 | def pick_pairs(centers, n_pairs):
110 |     # We try to pick some pairs that cover the latent space in some way.
111 |     # This probably could be improved.
112 |     #
113 |     # Pick some pairs that are far away from each other.
114 |     pairs = []
115 |     X = distance_matrix(centers[:, :], centers[:, :])
116 | 
117 |     for _ in range(n_pairs // 2):
118 | 
119 |         i_idx, j_idx = np.unravel_index(np.argmax(X), X.shape)
120 |         X[i_idx, :] = 0
121 |         X[:, i_idx] = 0
122 |         X[j_idx, :] = 0
123 |         X[:, j_idx] = 0
124 |         pairs.append([i_idx, j_idx])
125 | 
126 |     # Pick some pairs that are far in the first few principal components.
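    # (This assumes the latent coordinates are roughly ordered by decreasing
    # variance, as in a PCA-like embedding, so extremes along the leading
    # coordinates span the largest modes of variability.)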
127 |     zdim = centers.shape[-1]
128 |     max_k = np.min([n_pairs // 2, zdim])
129 |     for k in range(max_k):
130 |         i_idx = np.argmax(centers[:, k])
131 |         j_idx = np.argmin(centers[:, k])
132 |         pairs.append([i_idx, j_idx])
133 | 
134 |     return pairs
135 | 
136 | 
137 | def mkdir_safe(folder):
138 |     os.makedirs(folder, exist_ok=True)
139 | 
140 | 
141 | def main(
142 |     result_dir,
143 |     output_folder,
144 |     zdim,
145 |     n_clusters,
146 |     n_paths,
147 |     skip_umap,
148 |     q,
149 |     n_std,
150 |     adaptive,
151 |     B_factor,
152 |     n_bins,
153 |     n_vols_along_path,
154 |     skip_centers,
155 | ):
156 |     po = recovar.output.PipelineOutput(result_dir + "/")  # was args.recovar_result_dir, which is not a defined attribute
157 |     cryos = po.get("dataset")
158 |     zdim = zdim if zdim is not None else 10  # fall back to z-dim 10 when --zdim is not given
159 |     recovar.embedding.set_contrasts_in_cryos(cryos, po.get("contrasts")[zdim])
160 |     zs = po.get("zs")[zdim]
161 |     cov_zs = po.get("cov_zs")[zdim]
162 |     noise_variance = po.get("noise_var_used")
163 |     # B_factor and n_bins are taken from the CLI arguments
164 |     # (argparse defaults: --Bfactor 0, --n-bins 30)
165 |     num_imgs = 1000  # images per ground-truth volume in the conf-het datasets
166 | 
167 |     new_zs = zs[::num_imgs]
168 |     new_cov_zs = cov_zs[::num_imgs]
169 |     # NOTE: this script currently stops after subsampling the latents.
170 |     print(noise_variance, B_factor, n_bins, new_zs, new_cov_zs)
171 | 
172 | 
173 | if __name__ == "__main__":
174 |     parser = argparse.ArgumentParser(description=__doc__)
175 |     args = add_args(parser).parse_args()
176 |     main(
177 |         args.result_dir,
178 |         output_folder=args.outdir,
179 |         zdim=args.zdim,
180 |         n_clusters=args.n_clusters,
181 |         n_paths=args.n_trajectories,
182 |         skip_umap=args.skip_umap,
183 |         q=args.q,
184 |         n_std=args.n_std,
185 |         adaptive=args.adaptive,
186 |         B_factor=args.Bfactor,
187 |         n_bins=args.n_bins,
188 |         n_vols_along_path=args.n_vols_along_path,
189 |         skip_centers=args.skip_centers,
190 |     )
191 | -------------------------------------------------------------------------------- /metrics/neighborhood_similarity/README.md: --------------------------------------------------------------------------------
1 | # CryoBench: Diverse and challenging datasets for the heterogeneity problem in cryo-EM
2 | 
3 | ## Documentation:
4 | 
5 | The latest documentation for CryoBench is available at our [homepage](https://cryobench.cs.princeton.edu/).
6 | 
7 | For any feedback, questions, or bugs, please file a Github issue, start a Github discussion, or email.
8 | 
9 | ## Installation:
10 | 
11 | To run the script that calculates the neighborhood similarity, please first install [JAX](https://jax.readthedocs.io/en/latest/installation.html).
12 | 
13 |     # install jax
14 |     $ pip install -U jax
15 | 
16 | 
17 | ## Neighborhood Similarity
18 | The neighborhood similarity quantifies the percentage of matching neighbors with respect to the ground truth that are found within a neighborhood radius `k`.
19 | 
20 | 
21 | ### Example
22 | In the repo [metrics/neighborhood_similarity](https://github.com/ml-struct-bio/CryoBench/tree/main/metrics/neighborhood_similarity), you can find `cal_neighb_hit_werror.py`, which calculates the neighborhood similarity between the ground-truth embeddings and several sets of embeddings from reconstruction algorithms found in `conf-het-1_wrangled_latents.npz`:
23 | 
24 |     $ python cal_neighb_hit_werror.py
25 | 
26 | The output files give the neighborhood similarity as a function of the neighborhood radius for each reconstruction algorithm.
27 | 
28 | ## References:
29 | 
30 | Jeon, Minkyu, et al. "CryoBench: Diverse and challenging datasets for the heterogeneity problem in cryo-EM." arXiv preprint arXiv:2408.05526 (2024). [paper](https://arxiv.org/abs/2408.05526)
31 | 
32 | Boggust, Angie, et al.
"Embedding comparator: Visualizing differences in global structure and local neighborhoods via small multiples." In 27th international 33 | conference on intelligent user interfaces, pages 746–766, 2022. 34 | 35 | ## Contact 36 | 37 | Please submit any bug reports, feature requests, or general usage feedback as a github issue or discussion. 38 | -------------------------------------------------------------------------------- /metrics/neighborhood_similarity/cal_neighb_hit_werror.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code to compute the neighborhood similarity 3 | 4 | """ 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import sys 9 | 10 | import pickle 11 | 12 | import jax.numpy as jnp 13 | import jax 14 | 15 | 16 | def compute_loss(n_neighs, dist_points1_int, dist_points2_int): 17 | """ 18 | Computes product of nearest neighbors matrices for two sets of points. 19 | 20 | Arguments 21 | ---------- 22 | dist_points1_int: jax Array 23 | int 1 if neighbor 0 otherwise for the first set of points. 24 | dist_points2_int: jax Array 25 | int 1 if neighbor 0 otherwise for the second set of points. 26 | 27 | Returns 28 | -------- 29 | loss: int 30 | Number of neighbors that match for the two sets of points 31 | 32 | """ 33 | loss = jnp.sum(dist_points1_int * dist_points2_int) 34 | return loss 35 | 36 | 37 | def compare_points_index(index, n_neighs, points1, points2): 38 | """ 39 | Auxliary function for vmapping `compare_points`. This function computes the euclidean distances for each set of points with the reference point being points[index]. The distances are arg_sorted twice to obtain the rank. Neighbor matrix are used to computed_loss. 40 | 41 | If the number of neighbors is an Array, then it computes the loss for each number of neighbors by vmapping over compute_loss. 42 | 43 | Arguments 44 | --------- 45 | index: int 46 | index of the reference point (to compute distances) 47 | n_neighs int | Array: 48 | number of neighbors to consider (can be an iterable) 49 | points1: Array 50 | First set of points 51 | points2: Array 52 | Second set of points 53 | 54 | Returns 55 | -------- 56 | loss: array with the loss for each number of neighbors 57 | 58 | """ 59 | dist_points1 = ( 60 | jnp.argsort(jnp.argsort(jnp.linalg.norm(points1 - points1[index], axis=1))) 61 | < n_neighs[0] + 1 62 | ) 63 | dist_points2 = ( 64 | jnp.argsort(jnp.argsort(jnp.linalg.norm(points2 - points2[index], axis=1))) 65 | < n_neighs[0] + 1 66 | ) 67 | # print(dist_points1) 68 | dist_points1_int = dist_points1.astype(int) 69 | dist_points2_int = dist_points2.astype(int) 70 | 71 | compute_loss_ = jax.vmap(compute_loss, in_axes=(0, None, None)) 72 | 73 | return compute_loss_(n_neighs, dist_points1_int, dist_points2_int) 74 | 75 | 76 | def compare_points(n_neighs, points1, points2): 77 | """ 78 | Computes the losses between points1 and points2. 79 | The loss is defined as the number of neighbors that match between the sets of points. 
80 | 81 | Arguments 82 | ---------- 83 | n_neighs: Array 84 | Number of neighbors to be considered 85 | points1: Array 86 | First set of points 87 | points2: Array 88 | Second set of points 89 | 90 | Returns 91 | -------- 92 | loss: Array["number of neighbors", "number of points"] 93 | The loss for each point, for different number of neighbors considered 94 | """ 95 | 96 | assert points1.shape[0] == points2.shape[0], "The number of points must be equal" 97 | 98 | comp_points_map_index = jax.vmap( 99 | compare_points_index, in_axes=(0, None, None, None) 100 | ) 101 | 102 | return comp_points_map_index( 103 | np.arange(points1.shape[0]), n_neighs, points1, points2 104 | ).T 105 | 106 | 107 | def calculate_neigh_hits_k(start, points1, points2, k_neigh_range): 108 | """ 109 | Computes the number of matching neighbors as a function of k 110 | 111 | Arguments 112 | ---------- 113 | start: int 114 | Starting point to take embedding subsets. This is to calculate the error. 115 | points1: Array 116 | First set of points 117 | points2: Array 118 | Second set of points 119 | 120 | Returns 121 | -------- 122 | neigh_hit_k: Array 123 | The mean number of matched neighbors for k 124 | """ 125 | points_gt = points1[start::5] 126 | points_embd = points2[start::5] 127 | 128 | neigh_hit_k = [] 129 | 130 | for k in k_neigh_range: 131 | k_neighs = jnp.array([k]) 132 | losses = compare_points(k_neighs, points_gt, points_embd) 133 | mean = losses.mean(1)[0] 134 | neigh_hit_k.append(mean) 135 | 136 | return neigh_hit_k 137 | 138 | 139 | data = np.load("conf-het-1_wrangled_latents.npz") 140 | 141 | embd_names = data.files 142 | 143 | print(embd_names) 144 | 145 | points_gt = data["gt_s1_embeddings"] 146 | points_gt = jnp.array(points_gt.copy()) 147 | 148 | k_neigh_range = np.arange(200, 4100, 200) 149 | 150 | for name in embd_names: 151 | 152 | points_embd = data[name] 153 | print(name, points_gt.shape, points_embd.shape) 154 | 155 | points_embd = jnp.array(points_embd.copy()) 156 | 157 | neigh_hit_diff_start_k = [] 158 | 159 | for start in range(0, 5): 160 | neigh_hit = calculate_neigh_hits_k(start, points_gt, points_embd, k_neigh_range) 161 | neigh_hit_diff_start_k.append(neigh_hit) 162 | 163 | mean_neigh_hit_k = np.array(neigh_hit_diff_start_k) 164 | i = -1 165 | 166 | with open(f"{name}_output.txt", "w") as file: 167 | for k in k_neigh_range: 168 | i += 1 169 | file.write( 170 | f"{k} {mean_neigh_hit_k.mean(0)[i]} {mean_neigh_hit_k.std(0)[i]}\n" 171 | ) 172 | -------------------------------------------------------------------------------- /metrics/neighborhood_similarity/conf-het-1_wrangled_latents.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-struct-bio/CryoBench/20931c235a7c78626ed169d00d4534f80c63bc86/metrics/neighborhood_similarity/conf-het-1_wrangled_latents.npz -------------------------------------------------------------------------------- /metrics/neighborhood_similarity/make_plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | # Function to read data from text file 5 | def read_data(filename): 6 | data = np.loadtxt(filename) 7 | x = data[:, 0] 8 | y = data[:, 1] 9 | error = data[:, 2] 10 | return x, y, error 11 | 12 | 13 | embd_names = [ 14 | "cryosparc_3dflex_embeddings", 15 | "cryosparc_3dva_embeddings", 16 | "recovar_embeddings", 17 | "cryodrgn_embeddings", 18 | "cryodrgn2_embeddings", 19 | "drgnai_abinit_embeddings", 
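    # (each name must match both a key in conf-het-1_wrangled_latents.npz and a
    # <name>_output.txt file written by cal_neighb_hit_werror.py)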
20 | "drgnai_fixed_embeddings", 21 | "opusdsd_mu_embeddings", 22 | ] 23 | 24 | # embd_names = ['CryoDRGN', 'DrgnAI-fixed','Opus-DSD','3DFlex', '3DVA', 'RECOVAR','CryoDRGN2','DrgnAI-abinit'] 25 | 26 | 27 | mapping = { 28 | "cryosparc_3dflex_embeddings": "3DFlex", 29 | "cryosparc_3dva_embeddings": "3DVA", 30 | "drgnai_fixed_embeddings": "DrgnAI-fixed", 31 | "recovar_embeddings": "RECOVAR", 32 | "cryodrgn_embeddings": "CryoDRGN", 33 | "cryodrgn2_embeddings": "CryoDRGN2", 34 | "drgnai_abinit_embeddings": "DrgnAI-abinit", 35 | "opusdsd_mu_embeddings": "Opus-DSD", 36 | } 37 | 38 | # Transforming the list using the mapping dictionary 39 | 40 | color_map = { 41 | "cryodrgn_embeddings": "#6190e6", 42 | "drgnai_fixed_embeddings": "#88B4E6", 43 | "opusdsd_mu_embeddings": "#b0e0e6", 44 | "cryosparc_3dflex_embeddings": "#98fb98", 45 | "cryosparc_3dva_embeddings": "#f4a460", 46 | "recovar_embeddings": "#f08080", 47 | "cryodrgn2_embeddings": "#7b68ee", 48 | "drgnai_abinit_embeddings": "#a569bd", 49 | # '3D Class': '#d8bfd8', 50 | # '3D Class (abinit)': '#da70d6', 51 | # 'G.T': '#bfbfbf' 52 | } 53 | 54 | # Read data from each file and create plots 55 | for name in embd_names: 56 | # Read data from text file 57 | filename = f"{name}_output.txt" 58 | x, y, error = read_data(filename) 59 | 60 | # Create plot 61 | plt.errorbar( 62 | x / 200.0, 63 | y / x * 100, 64 | yerr=error / x * 100, 65 | fmt="o", 66 | markersize=8, 67 | label=name, 68 | color=color_map.get(name, "black"), 69 | ) 70 | plt.plot( 71 | x / 200.0, 72 | y / x * 100, 73 | linestyle="-", 74 | color=color_map.get(name, "black"), 75 | linewidth=2.5, 76 | ) 77 | 78 | # Set plot title and labels 79 | # plt.title('Embedding Neighborhood Similarity',fontsize=20) 80 | plt.xlabel("Neighborhood Radius [%]", fontsize=20) 81 | plt.ylabel("% of Matching Neighbors", fontsize=20) 82 | # plt.legend(fontsize=8) 83 | plt.legend().set_visible(False) 84 | plt.xlim(0, 10) 85 | plt.ylim(0, 100) 86 | plt.xticks(fontsize=15) 87 | plt.yticks(fontsize=15) 88 | 89 | # Set tab20 colormap for the plot 90 | plt.set_cmap("tab20") 91 | 92 | # Save plot as a high-resolution PDF 93 | plt.tight_layout() 94 | plt.savefig("neighbor_conensus-conf-het-1.pdf", dpi=1200, bbox_inches="tight") 95 | 96 | # Show the plot 97 | # plt.show() 98 | -------------------------------------------------------------------------------- /metrics/pose_error/rot_error.py: -------------------------------------------------------------------------------- 1 | """Align and compute distance between two series of rotation matrices. 
2 | 3 | Example usage 4 | ------------- 5 | $ python metrics/pose_error/rot_error.py cryobench_input/003_IgG-1D_cdrgn2/ \ 6 | datasets/IgG-1D/combined_poses.pkl --labels datasets/IgG-1D/gt_latents.pkl \ 7 | --save-err cryobench_output/cdrgn2_003_rot-error/ 8 | 9 | """ 10 | import os 11 | import argparse 12 | import logging 13 | import numpy as np 14 | from datetime import datetime as dt 15 | import torch 16 | import cryodrgn.utils 17 | from cryodrgn import lie_tools 18 | from scipy.linalg import logm 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def parse_args() -> argparse.ArgumentParser: 24 | parser = argparse.ArgumentParser(description=__doc__) 25 | 26 | parser.add_argument("traindir", help="Results directory") 27 | parser.add_argument("true_poses", help=".pkl file with GT poses") 28 | parser.add_argument("--save-err", help="save error with .npy") 29 | parser.add_argument("--epoch", type=int, default=-1, help="epoch (default: last)") 30 | parser.add_argument( 31 | "--labels", help=".pkl file with ground truth class index per particle" 32 | ) 33 | parser.add_argument("--ind", type=str) 34 | parser.add_argument( 35 | "--pred-labels", 36 | help="Provide if best class pose needs to be selected from predicted poses", 37 | ) 38 | parser.add_argument( 39 | "-N", type=int, default=30, help="Number of particles to attempt to align on" 40 | ) 41 | parser.add_argument("--data", type=str, help="name of the data") 42 | parser.add_argument("--seed", type=int, default=0) 43 | 44 | return parser 45 | 46 | 47 | # will give very close result to ang_dist() 48 | def ang_dist_2(A, B): 49 | diff_rot = np.zeros(len(A)) 50 | for i in range(len(diff_rot)): 51 | diff_rot[i] = np.sum(logm(np.dot(A[i].T, B[i])) ** 2) ** 0.5 52 | 53 | return np.rad2deg(diff_rot) / np.sqrt(2) 54 | 55 | 56 | def ang_dist_oop(A, B): 57 | unitvec_gt = np.array([0, 0, 1], dtype=np.float32).reshape(3, 1) 58 | out_of_planes_gt = np.sum(A * unitvec_gt, axis=-2) 59 | out_of_planes_gt /= np.linalg.norm(out_of_planes_gt, axis=-1, keepdims=True) 60 | out_of_planes_pred = np.sum(B * unitvec_gt, axis=-2) 61 | out_of_planes_pred /= np.linalg.norm(out_of_planes_pred, axis=-1, keepdims=True) 62 | diff_angle = ( 63 | np.arccos(np.clip(np.sum(out_of_planes_gt * out_of_planes_pred, -1), -1.0, 1.0)) 64 | * 180.0 65 | / np.pi 66 | ) 67 | 68 | return diff_angle 69 | 70 | 71 | def ang_dist(A, B): 72 | diff_rot = np.zeros_like(A) 73 | for i in range(len(diff_rot)): 74 | diff_rot[i, ...] = A[i, ...].T @ B[i, ...] 
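    # Geodesic angle from the rotation trace: tr(R) = 1 + 2*cos(theta), hence
    # theta = arccos((tr(R) - 1) / 2); the clip below guards against round-off
    # pushing the argument outside [-1, 1].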
75 | diff_angle = np.arccos( 76 | np.clip((np.trace(diff_rot, axis1=1, axis2=2) - 1) / 2, -1, 1) 77 | ) 78 | diff_angle = np.abs(np.rad2deg(diff_angle)) 79 | 80 | return diff_angle 81 | 82 | 83 | def rot_diff(A, B): 84 | return np.matmul(np.swapaxes(B, -1, -2), A) 85 | 86 | 87 | def _flip(rot): 88 | return np.matmul(np.diag([1, 1, -1]).astype(rot.dtype), rot) 89 | 90 | 91 | def align_rot(rotA, rotB, N, flip=False): 92 | if flip: 93 | rotB = _flip(rotB) 94 | 95 | best_rot, best_medse = None, 1e9 96 | 97 | for i in np.random.choice(len(rotA), min(len(rotA), N), replace=False): 98 | mean_rot = np.dot(rotB[i].T, rotA[i]) 99 | rotB_hat = np.matmul(rotB, mean_rot) 100 | medse = np.median(np.sum((rotB_hat - rotA) ** 2, axis=(1, 2))) 101 | if medse < best_medse: 102 | best_medse = medse 103 | best_rot = mean_rot 104 | 105 | # align B into A's reference frame 106 | rotA_hat = np.matmul(rotA, best_rot.T).astype(rotA.dtype) 107 | rotB_hat = np.matmul(rotB, best_rot).astype(rotB.dtype) 108 | dist2 = np.sum((rotB_hat - rotA) ** 2, axis=(1, 2)) 109 | if flip: 110 | rotA_hat = _flip(rotA_hat) 111 | 112 | return rotA_hat, rotB_hat, best_rot, dist2 113 | 114 | 115 | def align_rot_flip(rotA, rotB, N): 116 | ret1 = align_rot(rotA, rotB, N, flip=False) 117 | ret2 = align_rot(rotA, rotB, N, flip=True) 118 | 119 | if np.median(ret1[-1]) < np.median(ret2[-1]): 120 | return ret1 121 | else: 122 | return ret2 123 | 124 | 125 | def err_format(err: float, digits: int = 5) -> str: 126 | if err < 1: 127 | return format(err, f".{digits}f") 128 | else: 129 | return format(format(err, f"{digits}g"), f"<0{digits + 2}") 130 | 131 | 132 | def main(args: argparse.Namespace) -> None: 133 | np.random.seed(args.seed) 134 | 135 | if not os.path.exists(os.path.join(args.save_err)): 136 | os.makedirs(os.path.join(args.save_err)) 137 | 138 | t1 = dt.now() 139 | 140 | basedir, job = os.path.split(os.path.normpath(args.traindir)) 141 | 142 | def train_path(f: str) -> str: 143 | return os.path.join(args.traindir, f) 144 | 145 | if os.path.exists(train_path(f"{job}_final_particles.cs")): 146 | method = "cryosparc" 147 | particle_info = np.load(train_path(f"{job}_final_particles.cs")) 148 | highest_prob = np.zeros(len(particle_info)) 149 | rot1 = np.zeros((len(particle_info), 3)) 150 | cl_idx = 0 151 | 152 | while f"alignments_class_{cl_idx}/class_posterior" in particle_info.dtype.names: 153 | new_best = ( 154 | particle_info[f"alignments_class_{cl_idx}/class_posterior"] 155 | > highest_prob 156 | ) 157 | highest_prob[new_best] = particle_info[ 158 | f"alignments_class_{cl_idx}/class_posterior" 159 | ][new_best] 160 | rot1[new_best] = particle_info[f"alignments_class_{cl_idx}/pose"][new_best] 161 | cl_idx += 1 162 | 163 | rot1 = lie_tools.expmap(torch.tensor(rot1)) 164 | rot1 = rot1.cpu().numpy() 165 | rot1 = np.array([x.T for x in rot1]) 166 | 167 | elif os.path.isfile(args.traindir) and os.path.splitext(args.traindir)[1] == ".pkl": 168 | method = ".pkl" 169 | rot1 = cryodrgn.utils.load_pkl(args.traindir) 170 | 171 | else: 172 | # if args.data == "Tomotwin-100": 173 | # rot1 = load_pkl(args.traindir) 174 | # else: 175 | if os.path.isdir(train_path("out")): 176 | method = "drgnai" 177 | args.traindir = os.path.join(args.traindir, "out") 178 | else: 179 | method = "cryodrgn" 180 | 181 | if args.epoch == -1: 182 | args.epoch = max( 183 | int(f.split(".")[1]) 184 | for f in os.listdir(args.traindir) 185 | if f.startswith("pose.") and len(f.split(".")) == 3 186 | ) 187 | 188 | rot1 = 
cryodrgn.utils.load_pkl(train_path(f"pose.{args.epoch}.pkl")) 189 | 190 | rot2 = cryodrgn.utils.load_pkl(args.true_poses) 191 | 192 | if isinstance(rot1, tuple): 193 | rot1 = rot1[0] 194 | if isinstance(rot2, tuple): 195 | rot2 = rot2[0] 196 | 197 | if args.pred_labels: 198 | pred_labels = cryodrgn.utils.load_pkl(args.pred_labels) 199 | rot1 = np.take_along_axis(rot1, pred_labels[:, None, None, None], 1).squeeze() 200 | if args.ind: 201 | ind = cryodrgn.utils.load_pkl(args.ind) 202 | rot2 = rot2[ind] 203 | 204 | assert rot1.shape == rot2.shape 205 | logger.info(f"data and method: {args.data}, {method}") 206 | errors_lst = [] 207 | 208 | if args.labels: 209 | labels = np.array(cryodrgn.utils.load_pkl(args.labels), dtype=int) 210 | if args.ind: 211 | labels = labels[ind] 212 | 213 | uniq_lbls = np.unique(labels) 214 | cls_space = int(np.log10(len(uniq_lbls) - 1)) + 1 215 | print("\n[mean; median] rotation errors:") 216 | print(f"{' ' * (7 + cls_space)}| Frobenius | Geodesic") 217 | print("-" * (45 + cls_space)) 218 | 219 | frob_means, frob_meds, counts = list(), list(), list() 220 | geo_means, geo_meds = list(), list() 221 | for i in np.unique(labels): 222 | mask = labels == i 223 | # print('mask:',mask.shape) 224 | counts.append(mask.sum()) 225 | rot1_i = rot1[mask] 226 | rot2_i = rot2[mask] 227 | # print('rot1_i:',rot1_i.shape) 228 | # print('rot2_i:',rot2_i.shape) 229 | r1, r2, rot, dist2 = align_rot_flip(rot1_i, rot2_i, args.N) 230 | # print('r1:',r1.shape) 231 | # print('dist2:',dist2.shape) 232 | fmean, fmed = np.mean(dist2), np.median(dist2) 233 | frob_means.append(fmean) 234 | frob_meds.append(fmed) 235 | 236 | ang_dists = ang_dist(r1, rot2_i) 237 | # print('ang_dists:',ang_dists.shape) 238 | gmean, gmed = np.mean(ang_dists), np.median(ang_dists) 239 | geo_means.append(gmean) 240 | geo_meds.append(gmed) 241 | errors_lst.append(ang_dists) 242 | 243 | fstr = f"{err_format(fmean)}; {err_format(fmed)}" 244 | gstr = f"{err_format(gmean)}; {err_format(gmed)}" 245 | print(f"Class {i:<{cls_space}} | {fstr} | {gstr}") 246 | 247 | logger.info(f"Class average Mean squared error: {np.mean(geo_means)}") 248 | logger.info(f"Class average Median squared error: {np.mean(geo_meds)}") 249 | w_mean = np.sum(np.array(geo_means) * np.array(counts)) / len(labels) 250 | w_med = np.sum(np.array(geo_meds) * np.array(counts)) / len(labels) 251 | logger.info(f"Weighted class average Mean squared error: {w_mean}") 252 | logger.info(f"Weighted class average Median squared error: {w_med}") 253 | 254 | else: 255 | r1, r2, rot, dist2 = align_rot_flip(rot1, rot2, args.N) 256 | 257 | logger.info(f"Mean squared error: {np.mean(dist2)}") 258 | logger.info(f"Median squared error: {np.median(dist2)}") 259 | 260 | ang_dists = ang_dist(r1, rot2) 261 | mean, med = np.mean(ang_dists), np.median(ang_dists) 262 | logger.info(f"Mean Geodesic: {mean}") 263 | logger.info(f"Median Geodesic: {med}") 264 | errors_lst.append(ang_dists) 265 | 266 | npy_name = f"errs_{method}_rot.npy" 267 | err_np_path = os.path.join(args.save_err, npy_name) 268 | errors_npy = np.array(errors_lst) 269 | 270 | with open(err_np_path, "wb") as f: 271 | np.save(f, errors_npy) 272 | 273 | tottime = dt.now() - t1 274 | logger.info(f"Finished in {tottime} ({tottime / args.N} per particle) ") 275 | 276 | 277 | if __name__ == "__main__": 278 | main(parse_args().parse_args()) 279 | -------------------------------------------------------------------------------- /metrics/pose_error/trans_error.py: 
-------------------------------------------------------------------------------- 1 | """Compute error between two series of particle shifts""" 2 | 3 | import os 4 | import argparse 5 | import numpy as np 6 | import torch 7 | import matplotlib.pyplot as plt 8 | import cryodrgn.utils 9 | 10 | 11 | def parse_args() -> argparse.ArgumentParser: 12 | parser = argparse.ArgumentParser(description=__doc__) 13 | 14 | parser.add_argument("trans1", help="Input translations") 15 | parser.add_argument("trans2", help="Input translations") 16 | parser.add_argument("--save-err", help="save error with .npy") 17 | parser.add_argument("--rot-pred", action="store_true") 18 | parser.add_argument( 19 | "--labels", help=".pkl file with ground truth class index per particle" 20 | ) 21 | parser.add_argument("--ind1", help="Index filter for trans1") 22 | parser.add_argument("--ind2", help="Index filter for trans2") 23 | parser.add_argument("--ind-rot", help="Index filter for rot") 24 | parser.add_argument( 25 | "--rot", help="Input rotations, to adjust for translation shift between models" 26 | ) 27 | parser.add_argument("--s1", type=float, default=1.0, help="Scale for trans1") 28 | parser.add_argument("--s2", type=float, default=1.0, help="Scale for trans2") 29 | parser.add_argument("--show", action="store_true", help="Show histogram") 30 | parser.add_argument("-v", "--verbose", action="store_true", help="Verbosity") 31 | 32 | return parser 33 | 34 | 35 | def adjust_translations_for_offset(args, rot, trans1, trans2): 36 | rot = torch.from_numpy(rot).float() 37 | trans1 = torch.from_numpy(trans1).float() 38 | trans2 = torch.from_numpy(trans2).float() 39 | model_offset = torch.nn.Parameter(torch.zeros((1, 3), dtype=torch.float32)) 40 | optimizer = torch.optim.Adam([model_offset], lr=1e0) 41 | 42 | # normalize the translations to avoid ill-conditioning 43 | # FIXME: didn't deal with mean 44 | std = float((trans1.std() + trans2.std()) / 2) 45 | trans1 = trans1 / std 46 | trans2 = trans2 / std 47 | 48 | for i in range(512): 49 | optimizer.zero_grad() 50 | rotated_offset = (rot @ model_offset.unsqueeze(-1)).squeeze(-1)[:, :2] 51 | loss = ((trans2 - trans1 + rotated_offset) ** 2).sum(-1).mean() 52 | loss.backward() 53 | optimizer.step() 54 | 55 | if i & (i - 1) == 0 and args.verbose: 56 | print( 57 | f"epoch {i} RMSE: {loss**0.5:.4f} Offset: {model_offset.data.numpy()}" 58 | ) 59 | 60 | adj_trans2 = trans2 + rotated_offset.detach() 61 | return adj_trans2.numpy() * std, model_offset.detach().numpy() * std 62 | 63 | 64 | def main(args: argparse.Namespace) -> None: 65 | trans1 = cryodrgn.utils.load_pkl(args.trans1) 66 | if isinstance(trans1, tuple): 67 | trans1 = trans1[1] 68 | trans1 *= args.s1 69 | 70 | if args.verbose: 71 | print(trans1.shape) 72 | print(trans1) 73 | 74 | trans2 = cryodrgn.utils.load_pkl(args.trans2) 75 | if isinstance(trans2, tuple): 76 | trans2 = trans2[1] 77 | 78 | trans2 *= args.s2 79 | if args.verbose: 80 | print(trans2.shape) 81 | print(trans2) 82 | 83 | if args.ind1: 84 | trans1 = trans1[cryodrgn.utils.load_pkl(args.ind1).astype(int)] 85 | if args.ind2: 86 | trans2 = trans2[cryodrgn.utils.load_pkl(args.ind2).astype(int)] 87 | 88 | assert trans1.shape == trans2.shape 89 | if args.verbose: 90 | print(np.mean(trans1, axis=0)) 91 | print(np.mean(trans2, axis=0)) 92 | 93 | errors_lst = [] 94 | if args.labels: 95 | labels = np.array(cryodrgn.utils.load_pkl(args.labels), dtype=int) 96 | means, meds, counts = [], [], [] 97 | if args.rot: 98 | rot = cryodrgn.utils.load_pkl(args.rot) 99 | 100 | if 
isinstance(rot, tuple): 101 | rot = rot[0] 102 | if args.ind_rot: 103 | rot = rot[cryodrgn.utils.load_pkl(args.ind_rot).astype(int)] 104 | 105 | for i in np.unique(labels): 106 | mask = labels == i 107 | counts.append(mask.sum()) 108 | 109 | trans1_i = trans1[mask] 110 | trans2_i = trans2[mask] 111 | dists_i = np.sum((trans1_i - trans2_i) ** 2, axis=1) ** 0.5 112 | 113 | print(dists_i.shape) 114 | print(f"Class {i} Mean error: {np.mean(dists_i)}") 115 | print(f"Class {i} Median error: {np.median(dists_i)}") 116 | 117 | if args.rot: 118 | rot_i = rot[mask] 119 | trans2_i, offset_3d = adjust_translations_for_offset( 120 | args, rot_i, trans1_i, trans2_i 121 | ) 122 | dists_i = np.sum((trans1_i - trans2_i) ** 2, axis=1) ** 0.5 123 | mean, med = np.mean(dists_i), np.median(dists_i) 124 | if args.verbose: 125 | print("offset3d: {}".format(offset_3d)) 126 | print(f"Class {i} Mean error after adjustment: {mean}") 127 | print(f"Class {i} Median error after adjustment: {med}") 128 | 129 | means.append(mean) 130 | meds.append(med) 131 | errors_lst.append(dists_i) 132 | 133 | dists = np.array(errors_lst) 134 | 135 | else: 136 | dists = np.sum((trans1 - trans2) ** 2, axis=1) ** 0.5 137 | if args.verbose: 138 | print(dists.shape) 139 | 140 | print(f"Mean error: {np.mean(dists):.7g}") 141 | print(f"Median error: {np.median(dists):.7g}") 142 | 143 | if args.rot: 144 | rot = cryodrgn.utils.load_pkl(args.rot) 145 | if isinstance(rot, tuple): 146 | rot = rot[0] 147 | if args.ind_rot: 148 | rot = rot[cryodrgn.utils.load_pkl(args.ind_rot).astype(int)] 149 | 150 | trans2, offset_3d = adjust_translations_for_offset( 151 | args, rot, trans1, trans2 152 | ) 153 | dists = np.sum((trans1 - trans2) ** 2, axis=1) ** 0.5 154 | print(f"offset3d: {np.round(offset_3d, 6)}") 155 | print(f"Mean error after adjustment: {np.mean(dists):.7g}") 156 | print(f"Median error after adjustment: {np.median(dists):.7g}") 157 | 158 | if args.rot_pred: 159 | npy_name = "errs_trans_rot_pred_new.npy" 160 | else: 161 | npy_name = "errs_trans_rot_gt_new.npy" 162 | 163 | err_np_path = os.path.join(args.save_err, npy_name) 164 | os.makedirs(args.save_err, exist_ok=True) 165 | # errors_npy = np.array(dists) 166 | with open(err_np_path, "wb") as f: 167 | np.save(f, dists) 168 | 169 | if args.show: 170 | plt.figure(1) 171 | plt.hist(dists) 172 | plt.figure(2) 173 | plt.scatter(trans1[:, 0], trans1[:, 1], s=1, alpha=0.1) 174 | plt.figure(3) 175 | plt.scatter(trans2[:, 0], trans2[:, 1], s=1, alpha=0.1) 176 | plt.figure(4) 177 | d = trans1 - trans2 178 | plt.scatter(d[:, 0], d[:, 1], s=1, alpha=0.1) 179 | plt.show() 180 | 181 | 182 | if __name__ == "__main__": 183 | args = parse_args().parse_args() 184 | verbose = args.verbose 185 | main(args) 186 | -------------------------------------------------------------------------------- /metrics/utils/align.py: -------------------------------------------------------------------------------- 1 | """A Python wrapper for aligning two .mrc volumes using ChimeraX command-line tools. 
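The script opens the reference and query maps in a headless ChimeraX session,
runs `fitmap` with multiple random initializations, and resamples the best fit
onto the reference grid before saving it.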
2 | 3 | Example usage 4 | ------------- 5 | chimerax --nogui --script "metrics/utils/align.py ref.mrc vol.mrc -o out.mrc" 6 | chimerax --nogui --script \ 7 | "metrics/utils/align.py ref.mrc vol.mrc -o out.mrc -f out.txt --flip" 8 | 9 | """ 10 | import os 11 | import argparse 12 | from chimerax.core.commands import run 13 | 14 | parser = argparse.ArgumentParser(description="Aligns two volumes") 15 | parser.add_argument("ref", help="Input volume to align on") 16 | parser.add_argument("vol", help="Input volume to align") 17 | parser.add_argument("-o", type=os.path.abspath, required=True, help="Aligned mrc") 18 | parser.add_argument( 19 | "-f", 20 | type=os.path.abspath, 21 | required=True, 22 | help="Text file that this program's output is being piped to (required if flip=True)", 23 | ) 24 | parser.add_argument( 25 | "--ninits", type=int, default=50, help="Number of alignments to try" 26 | ) 27 | parser.add_argument( 28 | "--flip", 29 | action="store_true", 30 | help="Run an additional ninits alignments after flipping handedness of vol", 31 | ) 32 | parser.add_argument( 33 | "--seed", type=int, help="random seed to use for alignment initializations" 34 | ) 35 | args = parser.parse_args() 36 | 37 | # Open the two volumes in ChimeraX as `#1` and `#2` 38 | run(session, f"open {args.ref}") 39 | run(session, f"open {args.vol}") 40 | seed_str = f" seed {args.seed}" if args.seed is not None else "" 41 | 42 | if not args.flip: 43 | run(session, f"fitmap #2 inMap #1 search {args.ninits}{seed_str}") 44 | run(session, "volume resample #2 onGrid #1 modelId #3") 45 | run(session, f"save {args.o} #3") 46 | else: 47 | run(session, "volume flip #2") 48 | run(session, f"fitmap #2 inMap #1 search {args.ninits}{seed_str}") 49 | run(session, f"fitmap #3 inMap #1 search {args.ninits}{seed_str}") 50 | 51 | corrs = [] 52 | f = open(args.f, "r") 53 | for line in f: 54 | if line.startswith(" correlation"): 55 | corrs.append(float(line.split(",")[0][16:])) 56 | f.close() 57 | 58 | print(corrs) 59 | if corrs[0] > corrs[1]: 60 | run(session, "volume resample #2 onGrid #1 modelId #4") 61 | else: 62 | run(session, "volume resample #3 onGrid #1 modelId #4") 63 | run(session, f"save {args.o} #4") 64 | 65 | run(session, "exit") 66 | -------------------------------------------------------------------------------- /metrics/utils/align.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=job # create a short name for your job 3 | #SBATCH -p cryoem 4 | #SBATCH --nodes=1 # node count 5 | #SBATCH --ntasks=1 # total number of tasks across all nodes 6 | #SBATCH --cpus-per-task=1 # cpu-cores per task (>1 if multi-threaded tasks) 7 | #SBATCH --mem-per-cpu=16G # memory per cpu-core (4G per cpu-core is default) 8 | #SBATCH --time=00:10:00 # total run time limit (HH:MM:SS) 9 | 10 | ref=$1 11 | vol=$2 12 | outvol=$3 13 | tmpfile=$4 14 | 15 | /scratch/gpfs/mj7341/chimerax-1.6.1/bin/ChimeraX --nogui --script "align.py $ref $vol -o $outvol -f $tmpfile" > $tmpfile 16 | -------------------------------------------------------------------------------- /metrics/utils/align_multi.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | import glob, re 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("dir") 8 | parser.add_argument("--apix", type=float) 9 | parser.add_argument("--org-vol", required=True) 10 | args = parser.parse_args() 11 | 12 | 13 | def 
natural_sort_key(s): 14 | # Convert the string to a list of text and numbers 15 | parts = re.split("([0-9]+)", s) 16 | 17 | # Convert numeric parts to integers for proper numeric comparison 18 | parts[1::2] = map(int, parts[1::2]) 19 | 20 | return parts 21 | 22 | 23 | file_pattern = "*.mrc" 24 | files = glob.glob(os.path.join(args.org_vol, file_pattern)) 25 | gt_dir = sorted(files, key=natural_sort_key) 26 | 27 | file_pattern = "*.mrc" 28 | matching_files = glob.glob(os.path.join(args.dir, file_pattern)) 29 | matching_files = sorted(matching_files, key=natural_sort_key) 30 | 31 | os.makedirs(os.path.join(args.dir, "aligned"), exist_ok=True) 32 | os.makedirs(os.path.join(args.dir, "flipped_aligned"), exist_ok=True) 33 | 34 | for i, file_path in enumerate(matching_files): 35 | base_filename = os.path.splitext(os.path.basename(file_path))[0] 36 | new_filename = base_filename + ".mrc" 37 | destination_path = os.path.join(args.dir, "aligned", new_filename) 38 | ref_path = gt_dir[i] 39 | 40 | align_cmd = f"sbatch metrics/align.slurm \ 41 | {ref_path} \ 42 | {os.path.join(args.dir, new_filename)} \ 43 | {destination_path} \ 44 | {os.path.join(args.dir, 'aligned', f'temp_{i:03d}.txt')}" 45 | print(align_cmd) 46 | subprocess.check_call(align_cmd, shell=True) 47 | -------------------------------------------------------------------------------- /metrics/utils/calculate_GT_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from MDAnalysis.lib.distances import calc_dihedrals 4 | 5 | import MDAnalysis as mda 6 | 7 | 8 | universe = mda.Universe("prot.psf", "conf-het-2.dcd") 9 | sel1 = universe.select_atoms("resid 1-213 and segid H or segid L") 10 | hinge1 = universe.select_atoms("resid 213 and segid H and name CA") 11 | hinge2 = universe.select_atoms("resid 244 and segid H and name CA and resname CYS") 12 | sel2 = universe.select_atoms("not ( resid 1-213 and segid H or segid L)") 13 | dihs = np.zeros(universe.trajectory.n_frames) 14 | dists = np.zeros(universe.trajectory.n_frames) 15 | 16 | for i in range(universe.trajectory.n_frames): 17 | universe.trajectory[i] 18 | coord0 = sel1.center_of_mass() 19 | coord1 = hinge1.center_of_mass() 20 | coord2 = hinge2.center_of_mass() 21 | coord3 = sel2.center_of_mass() 22 | dihs[i] = calc_dihedrals(coord0, coord1, coord2, coord3) 23 | dists[i] = np.linalg.norm((coord0 - coord3)) 24 | 25 | dihs = np.array([angle + 2 * np.pi if angle < 0 else angle for angle in dihs]) 26 | dihs = dihs / np.pi 27 | write_array = np.array([dists, dihs]).T 28 | 29 | 30 | # np.save('conf-het-2_CV_dihedral_distance.npy',write_array) 31 | # $plt.xlabel('Center of mass distance d ($\AA$)',fontsize=20) 32 | # plt.ylabel('Dihedral angle $\phi$ (radians/$\pi$)',fontsize=20) 33 | plt.figure(figsize=(7, 5)) 34 | plt.xlabel("Center of mass distance ($\AA$)", fontsize=20) 35 | plt.ylabel("Dihedral angle (radians/ $\pi$)", fontsize=20) 36 | plt.xticks(fontsize=16) 37 | plt.yticks(fontsize=16) 38 | plt.ylim([0, 2]) 39 | plt.scatter(dists, dihs, s=200, edgecolors="black", c=dists, cmap="viridis") 40 | plt.tight_layout() 41 | plt.savefig("dihedral_distance.pdf", dpi=92) 42 | plt.show() 43 | -------------------------------------------------------------------------------- /metrics/visualization/.ipynb_checkpoints/conf-het_latent_visualize-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | 
"execution_count": 1, 6 | "id": "8e18e83d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import pandas as pd\n", 11 | "import numpy as np\n", 12 | "from cryodrgn.starfile import Starfile\n", 13 | "from cryodrgn import analysis, utils, config\n", 14 | "\n", 15 | "import pickle\n", 16 | "import os, sys\n", 17 | "import re\n", 18 | "import argparse\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "log = print\n", 21 | "%matplotlib inline" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "id": "3aab88d4", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "def natural_sort_key(s):\n", 32 | " # Convert the string to a list of text and numbers\n", 33 | " parts = re.split('([0-9]+)', s)\n", 34 | " \n", 35 | " # Convert numeric parts to integers for proper numeric comparison\n", 36 | " parts[1::2] = map(int, parts[1::2])\n", 37 | " \n", 38 | " return parts" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "id": "6bab7c94", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "def plt_umap_labels():\n", 49 | " plt.xticks([])\n", 50 | " plt.yticks([])" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "id": "5825361d", 56 | "metadata": {}, 57 | "source": [ 58 | "# Conf-het-1" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "id": "73b90d9f", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "dihed_angles = np.linspace(-180, 176.4, 100)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 5, 74 | "id": "f8489fd0", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "umap_pkl = \"/scratch/gpfs/ZHONGE/mj7341/NeurIPS/results/conf-het/dihedral/snr001/cryodrgn/analyze.19/umap.pkl\"\n", 79 | "umap_pkl = open(umap_pkl, 'rb')\n", 80 | "umap = pickle.load(umap_pkl) " 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "4d80f11b", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "mix_label = np.arange(100)\n", 91 | "fig, ax = plt.subplots(figsize=(4,4))\n", 92 | "plot_dim = (0,1)\n", 93 | "# Dihedral angles\n", 94 | "c_all = np.repeat(mix_label,1000)\n", 95 | "c = c_all\n", 96 | "plot_args = dict(alpha=.1, s=1, cmap='gist_rainbow', vmin=-180,vmax=176.4)\n", 97 | "plt.scatter(new_umap[:,plot_dim[0]], new_umap[:,plot_dim[1]], c=c_all, **plot_args)" 98 | ] 99 | } 100 | ], 101 | "metadata": { 102 | "kernelspec": { 103 | "display_name": "tomodrgn [~/.conda/envs/tomodrgn/]", 104 | "language": "python", 105 | "name": "conda_tomodrgn" 106 | }, 107 | "language_info": { 108 | "codemirror_mode": { 109 | "name": "ipython", 110 | "version": 3 111 | }, 112 | "file_extension": ".py", 113 | "mimetype": "text/x-python", 114 | "name": "python", 115 | "nbconvert_exporter": "python", 116 | "pygments_lexer": "ipython3", 117 | "version": "3.10.13" 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 5 122 | } 123 | -------------------------------------------------------------------------------- /metrics/visualization/.ipynb_checkpoints/dynamight-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "fcb7b08c", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "id": "ae33e00d", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "path 
='/scratch/gpfs/ZHONGE/mj7341/NeurIPS/results/conf-het/dihedral/snr001/dynamight/inverse_deformations/inv_chkpt.pth'" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 3, 26 | "id": "cb2b77ce", 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "ename": "ModuleNotFoundError", 31 | "evalue": "No module named 'dynamight'", 32 | "output_type": "error", 33 | "traceback": [ 34 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 35 | "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", 36 | "Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m d \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m)\u001b[49m\n", 37 | "File \u001b[0;32m~/.conda/envs/recovar/lib/python3.11/site-packages/torch/serialization.py:1014\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)\u001b[0m\n\u001b[1;32m 1012\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1013\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m pickle\u001b[38;5;241m.\u001b[39mUnpicklingError(UNSAFE_MESSAGE \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mstr\u001b[39m(e)) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 1014\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_load\u001b[49m\u001b[43m(\u001b[49m\u001b[43mopened_zipfile\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1015\u001b[0m \u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1016\u001b[0m \u001b[43m \u001b[49m\u001b[43mpickle_module\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1017\u001b[0m \u001b[43m \u001b[49m\u001b[43moverall_storage\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moverall_storage\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1018\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mpickle_load_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1019\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m mmap:\n\u001b[1;32m 1020\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmmap can only be used with files saved with \u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1021\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`torch.save(_use_new_zipfile_serialization=True), \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1022\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mplease torch.save your checkpoint with this option in order to use mmap.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", 38 | "File \u001b[0;32m~/.conda/envs/recovar/lib/python3.11/site-packages/torch/serialization.py:1422\u001b[0m, in \u001b[0;36m_load\u001b[0;34m(zip_file, map_location, pickle_module, pickle_file, overall_storage, **pickle_load_args)\u001b[0m\n\u001b[1;32m 1420\u001b[0m unpickler \u001b[38;5;241m=\u001b[39m UnpicklerWrapper(data_file, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpickle_load_args)\n\u001b[1;32m 1421\u001b[0m unpickler\u001b[38;5;241m.\u001b[39mpersistent_load \u001b[38;5;241m=\u001b[39m persistent_load\n\u001b[0;32m-> 1422\u001b[0m result \u001b[38;5;241m=\u001b[39m 
\u001b[43munpickler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1424\u001b[0m torch\u001b[38;5;241m.\u001b[39m_utils\u001b[38;5;241m.\u001b[39m_validate_loaded_sparse_tensors()\n\u001b[1;32m 1425\u001b[0m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_log_api_usage_metadata(\n\u001b[1;32m 1426\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch.load.metadata\u001b[39m\u001b[38;5;124m\"\u001b[39m, {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mserialization_id\u001b[39m\u001b[38;5;124m\"\u001b[39m: zip_file\u001b[38;5;241m.\u001b[39mserialization_id()}\n\u001b[1;32m 1427\u001b[0m )\n", 39 | "File \u001b[0;32m~/.conda/envs/recovar/lib/python3.11/site-packages/torch/serialization.py:1415\u001b[0m, in \u001b[0;36m_load..UnpicklerWrapper.find_class\u001b[0;34m(self, mod_name, name)\u001b[0m\n\u001b[1;32m 1413\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m 1414\u001b[0m mod_name \u001b[38;5;241m=\u001b[39m load_module_mapping\u001b[38;5;241m.\u001b[39mget(mod_name, mod_name)\n\u001b[0;32m-> 1415\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfind_class\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmod_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n", 40 | "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'dynamight'" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "d = torch.load(path)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "a0bdf0d9", 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [] 55 | } 56 | ], 57 | "metadata": { 58 | "kernelspec": { 59 | "display_name": "recovar [~/.conda/envs/recovar/]", 60 | "language": "python", 61 | "name": "conda_recovar" 62 | }, 63 | "language_info": { 64 | "codemirror_mode": { 65 | "name": "ipython", 66 | "version": 3 67 | }, 68 | "file_extension": ".py", 69 | "mimetype": "text/x-python", 70 | "name": "python", 71 | "nbconvert_exporter": "python", 72 | "pygments_lexer": "ipython3", 73 | "version": "3.11.9" 74 | } 75 | }, 76 | "nbformat": 4, 77 | "nbformat_minor": 5 78 | } 79 | -------------------------------------------------------------------------------- /metrics/visualization/README.md: -------------------------------------------------------------------------------- 1 | # Latent visualization: Tools for visualizing latent space colored by ground truth 2 | ### Dependencies: 3 | * cryodrgn version 3.4.0 4 | * recovar 5 | 6 | ### Example usage (IgG-1D): 7 | * `result-path`: A path to the folder that contains UMAP and latent files before the method name (e.g., /scratch/gpfs/ZHONGE/mj7341/CryoBench/results/IgG-1D/snr0.01). 
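* In the commands below, `JXX` is a placeholder for the cryoSPARC job number of the corresponding run; the UMAP plots are written to the folder passed via `-o` (here `output/visualize_umap_igg1d`).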
8 | ``` 9 | $ conda activate cryodrgn 10 | 11 | # 3D Class 12 | # Copy the 3dcls job output to `results/3dcls/cls_{num classes}` (e.g., JXX_class_XX_XXXXX_volume.mrc) 13 | $ python metrics/visualization/visualize_umap_IgG-1D.py --method 3dcls --is_cryosparc -o output/visualize_umap_igg1d --cryosparc_path cryosparc/CS-IgG1D --result-path results --num_imgs 1000 --num_classes 10 --num_vols 100 14 | 15 | # 3D Class abinit 16 | # Copy the 3dcls_abinit job output to `results/3dcls_abinit/cls_{num classes}` (e.g., JXX_class_XX_final_volume.mrc) 17 | $ python metrics/visualization/visualize_umap_IgG-1D.py --method 3dcls_abinit --is_cryosparc -o output/visualize_umap_igg1d --cryosparc_path cryosparc/CS-IgG1D --result-path results --num_imgs 1000 --num_classes 20 --num_vols 100 18 | 19 | # 3DVA 20 | $ python metrics/visualization/visualize_umap_IgG-1D.py --method 3dva --cryosparc_job_num JXX --is_cryosparc -o output/visualize_umap_igg1d --cryosparc_path cryosparc/CS-IgG1D --result-path results --num_imgs 1000 --num_vols 100 21 | 22 | # 3DFlex 23 | $ python metrics/visualization/visualize_umap_IgG-1D.py --method 3dflex --cryosparc_job_num JXX --is_cryosparc -o output/visualize_umap_igg1d --cryosparc_path cryosparc/CS-IgG1D --result-path results --num_imgs 1000 --num_vols 100 24 | 25 | # Get the reordered latents of recovar 26 | $ conda activate recovar 27 | $ python metrics/methods/recovar/gen_reordered_z.py --recovar-result-dir results/recovar 28 | 29 | # Other methods 30 | $ for method in cryodrgn cryodrgn2 drgnai_fixed drgnai_abinit opus-dsd recovar 31 | do 32 | python metrics/visualization/visualize_umap_IgG-1D.py --method ${method} -o output/visualize_umap_igg1d --result-path results --num_imgs 1000 --num_vols 100 33 | done 34 | ``` 35 | -------------------------------------------------------------------------------- /metrics/visualization/calculate_IgG-RL_gt_latents.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from MDAnalysis.lib.distances import calc_dihedrals 4 | 5 | import MDAnalysis as mda 6 | 7 | 8 | 9 | file_root = 'IgG-RL/pdbs/' 10 | n_structures = 100 11 | pdb_list = [''] * n_structures 12 | for i in range(n_structures): 13 | pdb_list[i] = file_root + f"{i:03}" + '.pdb' 14 | 15 | 16 | # the first PDB supplies the topology; the full list is read as the trajectory 17 | universe = mda.Universe(pdb_list[0], pdb_list) 18 | sel1 = universe.select_atoms('resid 1-213 and segid H or segid L') 19 | hinge1 = universe.select_atoms('resid 213 and segid H and name CA') 20 | hinge2 = universe.select_atoms('resid 244 and segid H and name CA and resname CYS') 21 | sel2 = universe.select_atoms('not ( resid 1-213 and segid H or segid L)') 22 | dihs = np.zeros(universe.trajectory.n_frames) 23 | dists = np.zeros(universe.trajectory.n_frames) 24 | 25 | for i in range(universe.trajectory.n_frames): 26 | universe.trajectory[i] 27 | coord0 = sel1.center_of_mass() 28 | coord1 = hinge1.center_of_mass() 29 | coord2 = hinge2.center_of_mass() 30 | coord3 = sel2.center_of_mass() 31 | dihs[i] = calc_dihedrals(coord0, coord1, coord2, coord3) 32 | dists[i] = np.linalg.norm(coord0 - coord3) 33 | 
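# The two collective variables used as IgG-RL ground-truth labels are computed
# above: a hinge dihedral from the four centers of mass, and the center-of-mass
# distance between the two selections. Negative dihedrals are wrapped into
# [0, 2*pi) and rescaled to units of pi below, which is why the scatter plot's
# y-axis spans [0, 2].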
34 | dihs = np.array([angle + 2*np.pi if angle < 0 else angle for angle in dihs]) 35 | dihs = dihs/np.pi 36 | write_array = np.array([dists,dihs]).T 37 | 38 | 39 | #np.save('conf-het-2_CV_dihedral_distance.npy',write_array) 40 | #plt.xlabel('Center of mass distance d ($\AA$)',fontsize=20) 41 | #plt.ylabel('Dihedral angle $\phi$ (radians/$\pi$)',fontsize=20) 42 | plt.figure(figsize=(7,5)) 43 | plt.xlabel('Center of mass distance ($\AA$)',fontsize=20) 44 | plt.ylabel('Dihedral angle (radians/$\pi$)',fontsize=20) 45 | plt.xticks(fontsize=16) 46 | plt.yticks(fontsize=16) 47 | plt.ylim([0,2]) 48 | plt.scatter(dists,dihs,s=200,edgecolors='black',c = dists, cmap = 'viridis' ) 49 | plt.tight_layout() 50 | plt.savefig('dihedral_distance.pdf',dpi=92) 51 | plt.show() 52 | -------------------------------------------------------------------------------- /metrics/visualization/conf-het-2_CV_dihedral_distance.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-struct-bio/CryoBench/20931c235a7c78626ed169d00d4534f80c63bc86/metrics/visualization/conf-het-2_CV_dihedral_distance.npy -------------------------------------------------------------------------------- /metrics/visualization/visualize_umap_IgG-1D.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import pickle 5 | import json 6 | import glob 7 | import numpy as np 8 | from cryodrgn import analysis 9 | import matplotlib 10 | import matplotlib.pyplot as plt 11 | 12 | sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), "fsc")) 13 | from utils import volumes 14 | 15 | 16 | def parse_args() -> argparse.ArgumentParser: 17 | parser = argparse.ArgumentParser(description=__doc__) 18 | parser.add_argument( 19 | "result_path", 20 | type=os.path.abspath, 21 | help="folder containing each method's UMAP/latent outputs, i.e. the path up to the method name (e.g. /scratch/gpfs/ZHONGE/mj7341/CryoBench/results/IgG-1D/snr0.01)", 22 | ) 23 | parser.add_argument( 24 | "-o", type=os.path.abspath, required=True, help="Output folder to save the UMAP plot" 25 | ) 26 | # cryoSPARC-related options that main() below relies on 27 | parser.add_argument("--method", type=str, help="type of method") 28 | parser.add_argument("--is_cryosparc", action="store_true", help="cryosparc or not") 29 | parser.add_argument("--cryosparc_path", type=os.path.abspath, help="cryosparc folder path") 30 | parser.add_argument("--cryosparc_job_num", type=str, help="cryosparc job number") 31 | parser.add_argument("--num_imgs", type=int, default=1000, help="number of images per volume") 32 | parser.add_argument("--num_classes", type=int, default=20, help="number of classes") 33 | parser.add_argument("--num_vols", type=int, default=100, help="number of volumes") 34 | 35 | return parser 36 | 37 | 38 | def parse_class_assignments(path_for_label, path_for_model, K): 39 | # for labels 40 | cs = np.load(path_for_label) 41 | keys = ["alignments_class3D_{}/class_posterior".format(i) for i in range(K)] 42 | classes = [[x[k] for k in keys] for x in cs] 43 | classes = np.asarray(classes) 44 | class_id = classes.argmax(axis=1) 45 | 46 | # for models 47 | model_lst = [] 48 | # path_for_model ends with "/"; its last component is the cryoSPARC job name 49 | cs_model_path = [ 50 | path_for_model 51 | + path_for_model.split("/")[-2] 52 | + "_passthrough_particles_class_{}.cs".format(i) 53 | for i in range(K) 54 | ] 55 | for i in range(K): 56 | cs_model = np.load(cs_model_path[i]) 57 | for j in range(len(cs_model)): 58 | num_vol = cs_model[j][1].split(b"_")[1].decode("utf-8") 59 | model_lst.append(num_vol) 60 | model_id = np.asarray(model_lst) 61 | 62 | return class_id, model_id 63 | 64 | 65 | def parse_class_abinit_assignments(path_for_label, path_for_model, K): 66 | # for labels 67 | cs = np.load(path_for_label, allow_pickle=True) 68 | keys = ["alignments_class_{}/class_posterior".format(i) for i in range(K)] 69 | classes = [[x[k] for k in keys] for x in cs] 70 | classes = np.asarray(classes) 71 | class_id = classes.argmax(axis=1) 72 | 73 | # for models 74 | model_lst = [] 75 | # path_for_model ends with "/"; its last component is the cryoSPARC job name 76 | cs_model_path = [ 77 | path_for_model 78 | + path_for_model.split("/")[-2] 79 | + "_class_{:02d}_final_particles.cs".format(i) 80 | for i in range(K) 81 | ] 82 | for i in range(K): 83 | cs_model = np.load(cs_model_path[i]) 84 | for j in range(len(cs_model)): 85 | num_vol = cs_model[j][1].split(b"_")[1].decode("utf-8") 86 | model_lst.append(num_vol) 87 | model_id = np.asarray(model_lst) 88 | 89 | return class_id, model_id 90 | 91 | 
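# plot_3dcls maps the 1-D ground-truth dihedral onto a circle (cos/sin of the
# angle) and jitters the points so that discrete 3D classification labels stay
# visible as colors; plot_methods below instead scatters a 2-D embedding (UMAP
# or latent coordinates) colored by the continuous dihedral angle.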
92 | def plot_3dcls(args, dihedral_angles, labels_3dcls, jitter=0.04): 93 | os.makedirs(os.path.join(args.o, args.method), exist_ok=True) 94 | x = np.repeat(dihedral_angles, args.num_imgs) 95 | xx = np.cos(x / 180 * np.pi) 96 | yy = np.sin(x / 180 * np.pi) 97 | 98 | xx_jittered = xx + np.random.randn(len(xx)) * jitter 99 | yy_jittered = yy + np.random.randn(len(yy)) * jitter 100 | 101 | colorList = plt.cm.tab20.colors 102 | cmap = matplotlib.colors.ListedColormap(colorList[: args.num_classes]) 103 | 104 | plt.scatter( 105 | xx_jittered, 106 | yy_jittered, 107 | cmap=cmap, 108 | s=1, 109 | alpha=0.1, 110 | c=labels_3dcls, 111 | vmin=0, 112 | vmax=args.num_classes, 113 | rasterized=True, 114 | ) 115 | plt_umap_labels() 116 | plt.savefig( 117 | f"{args.o}/{args.method}/{args.method}_{args.num_classes}.pdf", 118 | bbox_inches="tight", 119 | ) 120 | plt.show() 121 | plt.close() 122 | 123 | 124 | def plt_umap_labels(): 125 | plt.xticks([]) 126 | plt.yticks([]) 127 | 128 | 129 | def plot_methods(args, method, v, dihedral_angles, is_umap=True, use_axis=False): 130 | fig, ax = plt.subplots(figsize=(4, 4)) 131 | # Whole 132 | plot_dim = (0, 1) 133 | # Dihedral angles 134 | c_all = np.repeat(dihedral_angles, args.num_imgs) 135 | c = c_all 136 | plot_args = dict(alpha=0.1, s=1, cmap="gist_rainbow", vmin=0, vmax=356.4) 137 | ax.scatter(v[:, plot_dim[0]], v[:, plot_dim[1]], c=c, rasterized=True, **plot_args) 138 | 139 | if is_umap: 140 | if use_axis: 141 | fig.savefig(os.path.join(args.o, f"{method}_umap.pdf"), bbox_inches="tight") 142 | else: 143 | plt_umap_labels() 144 | fig.savefig( 145 | os.path.join(args.o, f"{method}_umap_no_axis.pdf"), 146 | bbox_inches="tight", 147 | ) 148 | 149 | else: 150 | if use_axis: 151 | fig.savefig( 152 | os.path.join(args.o, f"{method}_latent.pdf"), bbox_inches="tight" 153 | ) 154 | else: 155 | plt_umap_labels() 156 | fig.savefig( 157 | os.path.join(args.o, f"{method}_latent_no_axis.pdf"), 158 | bbox_inches="tight", 159 | ) 160 | 161 | plt.close() 162 | 163 | 164 | def main(args): 165 | dihedral_angles = np.linspace(0, 356.4, 100) 166 | input_files = os.listdir(args.result_path) 167 | os.makedirs(args.o, exist_ok=True) 168 | 
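    # main() auto-detects which method produced result_path from marker files:
    # config.yaml -> cryoDRGN, config.pkl -> OPUS-DSD, reordered_z.npy -> RECOVAR,
    # job.json -> a cryoSPARC job directory (3D Class / 3DVA / 3DFlex), and
    # out/config.yaml -> DRGN-AI; anything else raises a ValueError.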
169 | if "config.yaml" in input_files: 170 | umap_pkl = os.path.join(args.result_path, "analyze.19", "umap.pkl") 171 | umap_pkl = open(umap_pkl, "rb") 172 | umap_pkl = pickle.load(umap_pkl) 173 | plot_methods(args, "cryoDRGN", umap_pkl, dihedral_angles, is_umap=True) 174 | 175 | elif "config.pkl" in input_files: 176 | umap_pkl = f"{args.result_path}/{args.method}/analyze.19/umap.pkl" 177 | 178 | umap_pkl = open(umap_pkl, "rb") 179 | umap_pkl = pickle.load(umap_pkl) 180 | # UMap 181 | plot_methods(args, "OPUS-DSD", umap_pkl, dihedral_angles, is_umap=True) 182 | 183 | elif "reordered_z.npy" in input_files: 184 | latent_path = os.path.join(args.result_path, args.method, "reordered_z.npy") 185 | latent_z = np.load(latent_path) 186 | 187 | umap_path = f"{args.result_path}/{args.method}/reordered_z_umap.npy" 188 | umap_pkl = analysis.run_umap(latent_z) # v: latent space 189 | np.save(umap_path, umap_pkl) 190 | 191 | plot_methods(args, "RECOVAR", umap_pkl, dihedral_angles, is_umap=True) 192 | 193 | elif "job.json" in input_files: 194 | cryosparc_path, cryosparc_job = os.path.split( 195 | os.path.normpath(args.result_path) 196 | ) 197 | with open(os.path.join(args.result_path, "job.json")) as f: 198 | configs = json.load(f) 199 | 200 | if configs["type"] == "class_3D": 201 | file_pattern = "*.mrc" 202 | files = glob.glob( 203 | os.path.join( 204 | args.result_path, 205 | args.method, 206 | "cls_" + str(args.num_classes), 207 | file_pattern, 208 | ) 209 | ) 210 | pred_dir = sorted(files, key=volumes.natural_sort_key) 211 | cryosparc_num = os.path.split(pred_dir[0])[-1].split(".")[0].split("_")[3] 212 | print("cryosparc_num:", cryosparc_num) 213 | print("cryosparc_job:", cryosparc_job) 214 | path_for_label = os.path.join( 215 | args.result_path, f"{cryosparc_job}_{cryosparc_num}_particles.cs" 216 | ) 217 | path_for_model = os.path.join(cryosparc_path, cryosparc_job, "") # keep the trailing "/" expected by parse_class_assignments 218 | labels_3dcls, models_3dcls = parse_class_assignments( 219 | path_for_label, path_for_model, args.num_classes 220 | ) 221 | plot_3dcls(args, dihedral_angles, labels_3dcls, jitter=0.04) 222 | 223 | elif args.method == "3dcls_abinit": 224 | file_pattern = "*.mrc" 225 | files = glob.glob( 226 | os.path.join( 227 | args.result_path, 228 | args.method, 229 | "cls_" + str(args.num_classes), 230 | file_pattern, 231 | ) 232 | ) 233 | pred_dir = sorted(files, key=volumes.natural_sort_key) 234 | cryosparc_job = pred_dir[0].split("/")[-1].split(".")[0].split("_")[0] 235 | print("cryosparc_job:", cryosparc_job) 236 | path_for_label = f"{args.cryosparc_path}/{cryosparc_job}/{cryosparc_job}_final_particles.cs" 237 | path_for_model = f"{args.cryosparc_path}/{cryosparc_job}/" 238 | labels_3dcls, models_3dcls = parse_class_abinit_assignments( 239 | path_for_label, path_for_model, args.num_classes 240 | ) 241 | plot_3dcls(args, dihedral_angles, labels_3dcls, jitter=0.04) 242 | 
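    # cryoSPARC stores per-particle latent coordinates in the job's .cs record
    # arrays as "components_mode_{i}/value" fields: the 3DVA latents are gathered
    # into an (N, 3) array and embedded with UMAP, while the 2-D 3DFlex latents
    # are plotted directly.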
286 | f"Does not match for any known methods: " 287 | "cryoDRGN, DRGN-AI, OPUS-DSD, 3dflex, 3DVA, RECOVAR" 288 | ) 289 | 290 | 291 | if __name__ == "__main__": 292 | args = parse_args().parse_args() 293 | main(args) 294 | print("done!") 295 | -------------------------------------------------------------------------------- /metrics/visualization/visualize_umap_IgG-RL.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from cryodrgn import analysis 4 | 5 | import pickle 6 | import os 7 | import re 8 | import argparse 9 | import matplotlib.pyplot as plt 10 | 11 | log = print 12 | 13 | import glob, re 14 | 15 | print("loaded") 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description=__doc__) 20 | parser.add_argument("--method", type=str, help="type of methods") 21 | parser.add_argument( 22 | "--cv_num", type=int, help="cv number to use for color", required=True 23 | ) 24 | parser.add_argument("--is_cryosparc", action="store_true", help="cryosparc or not") 25 | parser.add_argument("--num_imgs", type=int, help="number of images") 26 | parser.add_argument("--num_classes", type=int, default=20, help="number of classes") 27 | parser.add_argument("--num_vols", type=int, default=100, help="number of classes") 28 | parser.add_argument( 29 | "-o", 30 | type=os.path.abspath, 31 | required=True, 32 | help="Output projection stack (.mrcs)", 33 | ) 34 | parser.add_argument( 35 | "--result-path", 36 | type=os.path.abspath, 37 | required=True, 38 | help="umap & latent folder before method name (e.g. /scratch/gpfs/ZHONGE/mj7341/cryosim/results/conf_het_v1/snr01)", 39 | ) 40 | parser.add_argument( 41 | "--cryosparc_path", 42 | type=os.path.abspath, 43 | default="/scratch/gpfs/ZHONGE/mj7341/cryosparc/CS-new-mask-for-confhet-v1", 44 | help="cryosparc folder path", 45 | ) 46 | parser.add_argument("--cryosparc_job_num", type=str, help="cryosparc job number") 47 | 48 | return parser 49 | 50 | 51 | def natural_sort_key(s): 52 | # Convert the string to a list of text and numbers 53 | parts = re.split("([0-9]+)", s) 54 | 55 | # Convert numeric parts to integers for proper numeric comparison 56 | parts[1::2] = map(int, parts[1::2]) 57 | 58 | return parts 59 | 60 | 61 | def plt_umap_labels(): 62 | plt.xticks([]) 63 | plt.yticks([]) 64 | 65 | 66 | def parse_class_assignments(path_for_label, path_for_model, K): 67 | # for labels 68 | cs = np.load(path_for_label) 69 | keys = ["alignments_class3D_{}/class_posterior".format(i) for i in range(K)] 70 | classes = [[x[k] for k in keys] for x in cs] 71 | classes = np.asarray(classes) 72 | class_id = classes.argmax(axis=1) 73 | 74 | # for models 75 | model_lst = [] 76 | path_for_model.split("/")[-2] 77 | cs_model_path = [ 78 | path_for_model 79 | + path_for_model.split("/")[-2] 80 | + "_passthrough_particles_class_{}.cs".format(i) 81 | for i in range(K) 82 | ] 83 | for i in range(K): 84 | cs_model = np.load(cs_model_path[i]) 85 | for j in range(len(cs_model)): 86 | num_vol = cs_model[j][1].split(b"_")[1].decode("utf-8") 87 | model_lst.append(num_vol) 88 | model_id = np.asarray(model_lst) 89 | 90 | return class_id, model_id 91 | 92 | 93 | def parse_class_abinit_assignments(path_for_label, path_for_model, K): 94 | # for labels 95 | cs = np.load(path_for_label, allow_pickle=True) 96 | keys = ["alignments_class_{}/class_posterior".format(i) for i in range(K)] 97 | classes = [[x[k] for k in keys] for x in cs] 98 | classes = np.asarray(classes) 99 | class_id = 
classes.argmax(axis=1) 100 | 101 | # for models 102 | model_lst = [] 103 | path_for_model.split("/")[-2] 104 | cs_model_path = [ 105 | path_for_model 106 | + path_for_model.split("/")[-2] 107 | + "_class_{:02d}_final_particles.cs".format(i) 108 | for i in range(K) 109 | ] 110 | for i in range(K): 111 | cs_model = np.load(cs_model_path[i]) 112 | for j in range(len(cs_model)): 113 | num_vol = cs_model[j][1].split(b"_")[1].decode("utf-8") 114 | model_lst.append(num_vol) 115 | model_id = np.asarray(model_lst) 116 | 117 | return class_id, model_id 118 | 119 | 120 | def plot_3dcls(args, dihedral_angles, labels_3dcls, jitter=0.04): 121 | 122 | x = np.repeat(dihedral_angles, 1000) 123 | xx = np.cos(x / 180 * np.pi) 124 | yy = np.sin(x / 180 * np.pi) 125 | 126 | xx_jittered = xx + np.random.randn(len(xx)) * jitter 127 | yy_jittered = yy + np.random.randn(len(yy)) * jitter 128 | 129 | plt.scatter( 130 | xx_jittered, 131 | yy_jittered, 132 | cmap="viridis", 133 | s=1, 134 | alpha=0.1, 135 | c=labels_3dcls, 136 | vmin=0, 137 | vmax=args.num_classes, 138 | ) 139 | plt_umap_labels() 140 | plt.savefig( 141 | f"{args.o}/{args.method}/{args.method}_{args.num_classes}_{args.cv_num}.png", 142 | dpi=1200, 143 | bbox_inches="tight", 144 | ) 145 | plt.show() 146 | plt.close() 147 | 148 | 149 | def plot_methods(args, v, dihedral_angles, is_umap=True, use_axis=False): 150 | fig, ax = plt.subplots(figsize=(4, 4)) 151 | # Whole 152 | plot_dim = (0, 1) 153 | c_all = np.repeat(dihedral_angles, 1000) 154 | c = c_all 155 | plot_args = dict( 156 | alpha=0.1, 157 | s=1, 158 | cmap="gist_rainbow", 159 | vmin=np.amin(dihedral_angles), 160 | vmax=np.amax(dihedral_angles), 161 | ) 162 | plt.scatter(v[:, plot_dim[0]], v[:, plot_dim[1]], c=c, **plot_args) 163 | use_axis = True 164 | plt.xlabel("UMap1") 165 | plt.ylabel("UMap2") 166 | ax.get_xaxis().set_visible(False) 167 | ax.get_yaxis().set_visible(False) 168 | if use_axis: 169 | plt.savefig( 170 | f"{args.o}/{args.method}/{args.method}_umap_{args.cv_num}.png", 171 | dpi=1200, 172 | bbox_inches="tight", 173 | ) 174 | else: 175 | plt_umap_labels() 176 | plt.savefig( 177 | f"{args.o}/{args.method}/{args.method}_umap_no_axis_{args.cv_num}.png", 178 | dpi=1200, 179 | bbox_inches="tight", 180 | ) 181 | plt.close() 182 | 183 | 184 | def main(args): 185 | if args.cv_num == 0: 186 | dihedral_angles = np.load("conf-het-2_CV_dihedral_distance.npy")[:, 0] 187 | if args.cv_num == 1: 188 | dihedral_angles = np.load("conf-het-2_CV_dihedral_distance.npy")[:, 1] 189 | print(dihedral_angles) 190 | 191 | if args.is_cryosparc: 192 | if args.method == "3dcls": 193 | file_pattern = "*.mrc" 194 | files = glob.glob( 195 | os.path.join( 196 | args.result_path, 197 | args.method, 198 | "cls_" + str(args.num_classes), 199 | file_pattern, 200 | ) 201 | ) 202 | pred_dir = sorted(files, key=natural_sort_key) 203 | cryosparc_num = pred_dir[0].split("/")[-1].split(".")[0].split("_")[3] 204 | cryosparc_job = pred_dir[0].split("/")[-1].split(".")[0].split("_")[0] 205 | print("cryosparc_num:", cryosparc_num) 206 | print("cryosparc_job:", cryosparc_job) 207 | path_for_label = f"{args.cryosparc_path}/{cryosparc_job}/{cryosparc_job}_{cryosparc_num}_particles.cs" 208 | path_for_model = f"{args.cryosparc_path}/{cryosparc_job}/" 209 | labels_3dcls, models_3dcls = parse_class_assignments( 210 | path_for_label, path_for_model, args.num_classes 211 | ) 212 | plot_3dcls(args, dihedral_angles, labels_3dcls, jitter=0.04) 213 | 214 | elif args.method == "3dcls_abinit": 215 | path_for_label = 
f"{args.cryosparc_path}/{args.cryosparc_job_num}/{args.cryosparc_job_num}_final_particles.cs" 216 | path_for_model = f"{args.cryosparc_path}/{args.cryosparc_job_num}/" 217 | labels_3dcls, models_3dcls = parse_class_abinit_assignments( 218 | path_for_label, path_for_model, args.num_classes 219 | ) 220 | plot_3dcls(args, dihedral_angles, labels_3dcls, jitter=0.04) 221 | 222 | elif args.method == "3dva": 223 | path = f"/mnt/ceph/users/mastore/cryobench_2024_ellen_zhong_colab/conf-het-2/pdbs/snr001/3dva/3dva_latents.npy" 224 | 225 | x = np.load(path) 226 | v = x 227 | 228 | latent_path = f"{args.o}/{args.method}/{args.method}_latents.npy" 229 | np.save(latent_path, v) 230 | umap_path = f"{args.o}/{args.method}/{args.method}_umap.npy" 231 | if not os.path.exists(umap_path): 232 | umap_latent = analysis.run_umap(v) # v: latent space 233 | np.save(umap_path, umap_latent) 234 | else: 235 | umap_latent = np.load(umap_path) 236 | plot_methods(args, umap_latent, dihedral_angles, is_umap=True) 237 | 238 | elif args.method == "3dflex": 239 | path = f"/mnt/ceph/users/mastore/cryobench_2024_ellen_zhong_colab/conf-het-2/pdbs/snr001/3dflex/3dflex_latents.npy" 240 | x = np.load(path) 241 | 242 | v = np.empty((len(x), 2)) 243 | latent_path = f"{args.o}/{args.method}/{args.method}_latents.npy" 244 | latent_path = "temp.txt" 245 | np.savetxt(latent_path, v) 246 | 247 | v[np.isinf(v)] = 0 248 | v[v > 10**30] = 0 249 | v = np.array(v, dtype=np.float32) 250 | umap_path = f"{args.o}/{args.method}/{args.method}_umap.npy" 251 | plot_methods(args, v, dihedral_angles, is_umap=False) 252 | 253 | else: 254 | print("chooseing method") 255 | if args.method == "cryodrgn": 256 | umap_pkl = f"{args.result_path}/{args.method}/analyze.19/umap.pkl" 257 | 258 | umap_pkl = open(umap_pkl, "rb") 259 | umap_pkl = pickle.load(umap_pkl) 260 | plot_methods(args, umap_pkl, dihedral_angles, is_umap=True) 261 | 262 | elif args.method == "cryodrgn2": 263 | latent_z = f"{args.result_path}/{args.method}/z.29.pkl" 264 | latent_z = open(latent_z, "rb") 265 | latent_z = pickle.load(latent_z) 266 | umap_pkl = f"{args.result_path}/{args.method}/analyze.29/umap.pkl" 267 | 268 | umap_pkl = open(umap_pkl, "rb") 269 | umap_pkl = pickle.load(umap_pkl) 270 | plot_methods(args, umap_pkl, dihedral_angles, is_umap=True) 271 | 272 | elif args.method == "drgnai_fixed": 273 | umap_pkl = f"{args.result_path}/{args.method}/analysis_100/umap.pkl" 274 | 275 | umap_pkl = open(umap_pkl, "rb") 276 | umap_pkl = pickle.load(umap_pkl) 277 | plot_methods(args, umap_pkl, dihedral_angles, is_umap=True) 278 | 279 | elif args.method == "drgnai_abinit": 280 | umap_pkl = f"{args.result_path}/{args.method}/analysis_100/umap.pkl" 281 | 282 | umap_pkl = open(umap_pkl, "rb") 283 | umap_pkl = pickle.load(umap_pkl) 284 | plot_methods(args, umap_pkl, dihedral_angles, is_umap=True) 285 | 286 | elif args.method == "opus-dsd": 287 | umap_pkl = f"{args.result_path}/{args.method}/analyze.19/umap.pkl" 288 | 289 | umap_pkl = open(umap_pkl, "rb") 290 | umap_pkl = pickle.load(umap_pkl) 291 | plot_methods(args, umap_pkl, dihedral_angles, is_umap=True) 292 | 293 | elif args.method == "recovar": 294 | print("loading recovar") 295 | latent_path = os.path.join(args.result_path, args.method, "reordered_z.npy") 296 | latent_z = np.load(latent_path) 297 | 298 | umap_path = f"{args.result_path}/{args.method}/reordered_z_umap.npy" 299 | umap_pkl = np.load(umap_path) 300 | plot_methods(args, umap_pkl, dihedral_angles, is_umap=True) 301 | 302 | 303 | if __name__ == "__main__": 304 | args = 
14 | def parse_args(): 15 | parser = argparse.ArgumentParser(description=__doc__) 16 | parser.add_argument("--method", type=str, help="type of method") 17 | parser.add_argument("--is_cryosparc", action="store_true", help="cryosparc or not") 18 | parser.add_argument("--num_imgs", type=int, help="number of images") 19 | parser.add_argument("--num_classes", type=int, default=20, help="number of classes") 20 | parser.add_argument("--num_vols", type=int, default=100, help="number of volumes") 21 | parser.add_argument( 22 | "-o", 23 | type=os.path.abspath, 24 | required=True, 25 | help="Output folder to save the UMAP plot", 26 | ) 27 | parser.add_argument( 28 | "--result-path", 29 | type=os.path.abspath, 30 | required=True, 31 | help="umap & latent folder before method name (e.g. /scratch/gpfs/ZHONGE/mj7341/CryoBench/results/IgG-1D/snr0.01)", 32 | ) 33 | parser.add_argument( 34 | "--cryosparc_path", type=os.path.abspath, help="cryosparc folder path" 35 | ) 36 | parser.add_argument("--cryosparc_job_num", type=str, help="cryosparc job number") 37 | 38 | return parser 39 | 40 | 41 | def natural_sort_key(s): 42 | # Convert the string to a list of text and numbers 43 | parts = re.split("([0-9]+)", s) 44 | 45 | # Convert numeric parts to integers for proper numeric comparison 46 | parts[1::2] = map(int, parts[1::2]) 47 | 48 | return parts 49 | 50 | 51 | def gt_colors(num_imgs): 52 | c_lst = [] 53 | c_num = 0 54 | for num_img in num_imgs: 55 | for i in range(num_img): 56 | c_lst.append(c_num) 57 | c_num += 1 58 | c_all = np.array(c_lst) 59 | return c_all 60 | 61 | 62 | def plt_umap_labels(): 63 | plt.xticks([]) 64 | plt.yticks([]) 65 | 66 | 67 | def plot_methods(args, v, is_umap=True, use_axis=False): 68 | fig, ax = plt.subplots(figsize=(4, 4)) 69 | # Whole 70 | colorList = [ 71 | "#3182bd", 72 | "#6baed6", 73 | "#9ecae1", 74 | "#e6550d", 75 | "#fd8d3c", 76 | "#fdae6b", 77 | "#fdd0a2", 78 | "#e377c2", 79 | "#f7b6d2", 80 | "#31a354", 81 | "#74c476", 82 | "#a1d99b", 83 | "#756bb1", 84 | "#9e9ac8", 85 | "#bcbddc", 86 | "#dadaeb", 87 | ] 88 | cmap = colors.ListedColormap(colorList[: args.num_vols]) 89 | print("cmap.N:", cmap.N) 90 | 91 | num_imgs = [ 92 | 9076, 93 | 14378, 94 | 23547, 95 | 44366, 96 | 30647, 97 | 38500, 98 | 3915, 99 | 3980, 100 | 12740, 101 | 11975, 102 | 17988, 103 | 5001, 104 | 35367, 105 | 37448, 106 | 40540, 107 | 5772, 108 | ] 109 | c_lst = [] # (equivalent to the gt_colors() helper above) 110 | c_num = 0 111 | for num_img in num_imgs: 112 | for i in range(num_img): 113 | c_lst.append(c_num) 114 | c_num += 1 115 | c_all = np.array(c_lst) 116 | plt.scatter( 117 | v[:, 0], 118 | v[:, 1], 119 | alpha=0.1, 120 | s=1, 121 | cmap=cmap, 122 | c=c_all, 123 | label=c_all, 124 | rasterized=True, 125 | ) 126 | 127 | if is_umap: 128 | if use_axis: 129 | plt.savefig( 130 | f"{args.o}/{args.method}/{args.method}_umap.pdf", bbox_inches="tight" 131 | ) 132 | else: 133 | plt_umap_labels() 134 | plt.savefig( 
f"{args.o}/{args.method}/{args.method}_umap_no_axis.pdf", 136 | bbox_inches="tight", 137 | ) 138 | else: 139 | if use_axis: 140 | plt.savefig( 141 | f"{args.o}/{args.method}/{args.method}_latent.pdf", bbox_inches="tight" 142 | ) 143 | else: 144 | plt_umap_labels() 145 | plt.savefig( 146 | f"{args.o}/{args.method}/{args.method}_latent_no_axis.pdf", 147 | bbox_inches="tight", 148 | ) 149 | plt.close() 150 | 151 | 152 | def main(args): 153 | if args.is_cryosparc: 154 | if args.method == "3dva": 155 | path = f"{args.cryosparc_path}/{args.cryosparc_job_num}/{args.cryosparc_job_num}_particles.cs" 156 | 157 | x = np.load(path) 158 | v = np.empty((len(x), 3)) # component_0,1,2 159 | for i in range(3): 160 | v[:, i] = x[f"components_mode_{i}/value"] 161 | latent_path = f"{args.o}/{args.method}/{args.method}_latents.npy" 162 | np.save(latent_path, v) 163 | 164 | umap_path = f"{args.o}/{args.method}/{args.method}_umap.npy" 165 | if not os.path.exists(umap_path): 166 | umap_latent = analysis.run_umap(v) # v: latent space 167 | np.save(umap_path, umap_latent) 168 | else: 169 | umap_latent = np.load(umap_path) 170 | plot_methods(args, umap_latent, is_umap=True) 171 | 172 | else: 173 | if args.method == "cryodrgn": 174 | umap_pkl = f"{args.result_path}/{args.method}/analyze.49/umap.pkl" 175 | 176 | umap_pkl = open(umap_pkl, "rb") 177 | umap_pkl = pickle.load(umap_pkl) 178 | plot_methods(args, umap_pkl, is_umap=True) 179 | 180 | elif args.method == "cryodrgn2": 181 | umap_pkl = f"{args.result_path}/{args.method}/analyze.29/umap.pkl" 182 | 183 | umap_pkl = open(umap_pkl, "rb") 184 | umap_pkl = pickle.load(umap_pkl) 185 | plot_methods(args, umap_pkl, is_umap=True) 186 | 187 | elif args.method == "drgnai_fixed": 188 | umap_pkl = f"{args.result_path}/{args.method}/out/analysis_100/umap.pkl" 189 | 190 | umap_pkl = open(umap_pkl, "rb") 191 | umap_pkl = pickle.load(umap_pkl) 192 | plot_methods(args, umap_pkl, is_umap=True) 193 | 194 | elif args.method == "drgnai_abinit": 195 | umap_pkl = f"{args.result_path}/{args.method}/out/analysis_100/umap.pkl" 196 | 197 | umap_pkl = open(umap_pkl, "rb") 198 | umap_pkl = pickle.load(umap_pkl) 199 | plot_methods(args, umap_pkl, is_umap=True) 200 | 201 | elif args.method == "opus-dsd": 202 | umap_pkl = f"{args.result_path}/{args.method}/analyze.19/umap.pkl" 203 | 204 | umap_pkl = open(umap_pkl, "rb") 205 | umap_pkl = pickle.load(umap_pkl) 206 | plot_methods(args, umap_pkl, is_umap=True) 207 | 208 | elif args.method == "recovar": 209 | latent_path = os.path.join(args.result_path, args.method, "reordered_z.npy") 210 | latent_z = np.load(latent_path) 211 | umap_path = f"{args.result_path}/{args.method}/reordered_z_umap.npy" 212 | umap_pkl = analysis.run_umap(latent_z) 213 | np.save(umap_path, umap_pkl) 214 | plot_methods(args, umap_pkl, is_umap=True) 215 | 216 | 217 | if __name__ == "__main__": 218 | args = parse_args().parse_args() 219 | if not os.path.exists(args.o + "/" + args.method): 220 | os.makedirs(args.o + "/" + args.method) 221 | 222 | main(args) 223 | print("done!") 224 | -------------------------------------------------------------------------------- /metrics/visualization/visualize_umap_Tomotwin-100.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from cryodrgn import analysis 3 | 4 | import pickle 5 | import os 6 | import re 7 | import argparse 8 | import matplotlib.pyplot as plt 9 | 10 | log = print 11 | 12 | import glob 13 | 14 | 15 | def parse_args(): 16 | parser = 
argparse.ArgumentParser(description=__doc__) 17 | parser.add_argument("--method", type=str, help="type of methods") 18 | parser.add_argument("--is_cryosparc", action="store_true", help="cryosparc or not") 19 | parser.add_argument("--num_classes", type=int, default=20, help="number of classes") 20 | parser.add_argument("--num_vols", type=int, default=100, help="number of classes") 21 | parser.add_argument("--num_imgs", type=int, default=1000, help="number of images") 22 | parser.add_argument( 23 | "-o", 24 | type=os.path.abspath, 25 | required=True, 26 | help="Output folder to save the UMAP plot", 27 | ) 28 | parser.add_argument( 29 | "--result-path", 30 | type=os.path.abspath, 31 | required=True, 32 | help="umap & latent folder before method name (e.g. /scratch/gpfs/ZHONGE/mj7341/CryoBench/results/IgG-1D/snr0.01)", 33 | ) 34 | parser.add_argument( 35 | "--cryosparc_path", type=os.path.abspath, help="cryosparc folder path" 36 | ) 37 | parser.add_argument("--cryosparc_job_num", type=str, help="cryosparc job number") 38 | 39 | return parser 40 | 41 | 42 | def natural_sort_key(s): 43 | # Convert the string to a list of text and numbers 44 | parts = re.split("([0-9]+)", s) 45 | 46 | # Convert numeric parts to integers for proper numeric comparison 47 | parts[1::2] = map(int, parts[1::2]) 48 | 49 | return parts 50 | 51 | 52 | def plt_umap_labels(): 53 | plt.xticks([]) 54 | plt.yticks([]) 55 | 56 | 57 | def plot_methods(args, v, is_umap=True, use_axis=False): 58 | mass_for_class = [ 59 | 156, 60 | 162, 61 | 168, 62 | 174, 63 | 179, 64 | 180, 65 | 191, 66 | 193, 67 | 197, 68 | 200, 69 | 204, 70 | 205, 71 | 206, 72 | 206, 73 | 215, 74 | 226, 75 | 231, 76 | 233, 77 | 240, 78 | 247, 79 | 249, 80 | 250, 81 | 251, 82 | 257, 83 | 258, 84 | 260, 85 | 266, 86 | 280, 87 | 280, 88 | 285, 89 | 291, 90 | 296, 91 | 313, 92 | 313, 93 | 316, 94 | 331, 95 | 333, 96 | 359, 97 | 375, 98 | 377, 99 | 382, 100 | 383, 101 | 393, 102 | 399, 103 | 410, 104 | 422, 105 | 424, 106 | 439, 107 | 456, 108 | 464, 109 | 468, 110 | 478, 111 | 488, 112 | 490, 113 | 491, 114 | 493, 115 | 502, 116 | 511, 117 | 515, 118 | 518, 119 | 518, 120 | 529, 121 | 547, 122 | 551, 123 | 556, 124 | 574, 125 | 588, 126 | 590, 127 | 591, 128 | 595, 129 | 597, 130 | 602, 131 | 607, 132 | 622, 133 | 629, 134 | 632, 135 | 633, 136 | 652, 137 | 663, 138 | 671, 139 | 681, 140 | 727, 141 | 732, 142 | 838, 143 | 847, 144 | 864, 145 | 865, 146 | 877, 147 | 881, 148 | 881, 149 | 899, 150 | 921, 151 | 956, 152 | 957, 153 | 1014, 154 | 1023, 155 | 1023, 156 | 1057, 157 | 1066, 158 | 1131, 159 | ] 160 | mass_for_classes = np.repeat(mass_for_class, args.num_imgs) 161 | fig, ax = plt.subplots(figsize=(4, 4)) 162 | plot_dim = (0, 1) 163 | 164 | plt.scatter( 165 | v[:, plot_dim[0]], 166 | v[:, plot_dim[1]], 167 | alpha=0.1, 168 | s=1, 169 | c=mass_for_classes, 170 | cmap="rainbow", 171 | rasterized=True, 172 | ) 173 | 174 | if is_umap: 175 | if use_axis: 176 | plt.savefig( 177 | f"{args.o}/{args.method}/{args.method}_umap.pdf", bbox_inches="tight" 178 | ) 179 | else: 180 | plt_umap_labels() 181 | plt.savefig( 182 | f"{args.o}/{args.method}/{args.method}_umap_no_axis_rainbow.pdf", 183 | bbox_inches="tight", 184 | ) 185 | else: 186 | if use_axis: 187 | plt.savefig( 188 | f"{args.o}/{args.method}/{args.method}_latent.pdf", bbox_inches="tight" 189 | ) 190 | else: 191 | plt_umap_labels() 192 | plt.savefig( 193 | f"{args.o}/{args.method}/{args.method}_latent_no_axis_rainbow.pdf", 194 | bbox_inches="tight", 195 | ) 196 | plt.close() 197 | 198 | 199 | def main(args): 200 
| if args.is_cryosparc: 201 | if args.method == "3dva": 202 | path = f"{args.cryosparc_path}/{args.cryosparc_job_num}/{args.cryosparc_job_num}_particles.cs" 203 | 204 | x = np.load(path) 205 | v = np.empty((len(x), 3)) # component_0,1,2 206 | for i in range(3): 207 | v[:, i] = x[f"components_mode_{i}/value"] 208 | latent_path = f"{args.o}/{args.method}/{args.method}_latents.npy" 209 | np.save(latent_path, v) 210 | 211 | # UMap 212 | umap_path = f"{args.o}/{args.method}/{args.method}_umap.npy" 213 | if not os.path.exists(umap_path): 214 | umap_latent = analysis.run_umap(v) # v: latent space 215 | np.save(umap_path, umap_latent) 216 | else: 217 | umap_latent = np.load(umap_path) 218 | plot_methods(args, umap_latent, is_umap=True) 219 | 220 | else: 221 | if args.method == "cryodrgn": 222 | 223 | umap_pkl = f"{args.result_path}/{args.method}/analyze.19/umap.pkl" 224 | 225 | umap_pkl = open(umap_pkl, "rb") 226 | umap_pkl = pickle.load(umap_pkl) 227 | plot_methods(args, umap_pkl, is_umap=True) 228 | 229 | elif args.method == "cryodrgn2": 230 | umap_pkl = f"{args.result_path}/{args.method}/analyze.29/umap.pkl" 231 | 232 | umap_pkl = open(umap_pkl, "rb") 233 | umap_pkl = pickle.load(umap_pkl) 234 | # UMap 235 | plot_methods(args, umap_pkl, is_umap=True) 236 | 237 | elif args.method == "drgnai_fixed": 238 | umap_pkl = f"{args.result_path}/{args.method}/out/analysis_100/umap.pkl" 239 | 240 | umap_pkl = open(umap_pkl, "rb") 241 | umap_pkl = pickle.load(umap_pkl) 242 | plot_methods(args, umap_pkl, is_umap=True) 243 | 244 | elif args.method == "drgnai_abinit": 245 | umap_pkl = f"{args.result_path}/{args.method}/out/analysis_100/umap.pkl" 246 | 247 | umap_pkl = open(umap_pkl, "rb") 248 | umap_pkl = pickle.load(umap_pkl) 249 | # UMap 250 | plot_methods(args, umap_pkl, is_umap=True) 251 | 252 | elif args.method == "opus-dsd": 253 | umap_pkl = f"{args.result_path}/{args.method}/analyze.19/umap.pkl" 254 | 255 | umap_pkl = open(umap_pkl, "rb") 256 | umap_pkl = pickle.load(umap_pkl) 257 | # UMap 258 | plot_methods(args, umap_pkl, is_umap=True) 259 | 260 | elif args.method == "recovar": 261 | latent_path = os.path.join(args.result_path, args.method, "reordered_z.npy") 262 | latent_z = np.load(latent_path) 263 | 264 | umap_path = f"{args.result_path}/{args.method}/reordered_z_umap.npy" 265 | umap_pkl = analysis.run_umap(latent_z) # v: latent space 266 | np.save(umap_path, umap_pkl) 267 | plot_methods(args, umap_pkl, is_umap=True) 268 | 269 | 270 | if __name__ == "__main__": 271 | args = parse_args().parse_args() 272 | if not os.path.exists(args.o + "/" + args.method): 273 | os.makedirs(args.o + "/" + args.method) 274 | 275 | main(args) 276 | print("done!") 277 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy<1.27 2 | pandas<2 3 | matplotlib<3.7 4 | seaborn<0.12 5 | scikit-learn 6 | torch>1.0.0 7 | --------------------------------------------------------------------------------