├── scripts ├── arabidopsis_halleri │ ├── config.yaml │ ├── .gitignore │ ├── envs │ │ ├── repeatmasker.yaml │ │ ├── umap.yaml │ │ ├── expand_annotation.yaml │ │ ├── gpn.yaml │ │ ├── run_classification.yaml │ │ ├── global.yaml │ │ └── notebook.yaml │ ├── profiles │ │ ├── default │ │ │ └── config.v8+.yaml │ │ └── slurm │ │ │ └── config.v8+.yaml │ ├── scripts │ │ ├── umap.py │ │ ├── expand_annotation.py │ │ └── run_classification.py │ ├── README.md │ └── Snakefile ├── tribolium_castaneum │ ├── .gitignore │ ├── profiles │ │ ├── default │ │ │ └── config.v8+.yaml │ │ └── slurm │ │ │ └── config.v8+.yaml │ ├── envs │ │ ├── repeatmasker.yaml │ │ ├── umap.yaml │ │ ├── gpn.yaml │ │ ├── run_classification.yaml │ │ ├── expand_annotation.yaml │ │ ├── global.yaml │ │ └── notebook.yaml │ ├── input │ │ └── id_mapping.tsv │ ├── config.yaml │ ├── scripts │ │ ├── umap.py │ │ ├── run_classification.py │ │ └── expand_annotation.py │ └── Snakefile ├── data_preparation │ ├── .gitignore │ ├── profiles │ │ └── default │ │ │ └── config.v8+.yaml │ ├── README.md │ ├── upload_to_hf.py │ ├── workflow │ │ ├── envs │ │ │ └── global.yaml │ │ ├── Snakefile │ │ └── rules │ │ │ ├── download.smk │ │ │ ├── intervals.smk │ │ │ └── dataset.smk │ └── config │ │ ├── assemblies.tsv │ │ ├── annotated_tenebrionidae.tsv │ │ ├── config.yaml │ │ └── annotated_chrom+_ncbi_refseq_cucujiformia_assemblies.tsv ├── high_throughput_gpn_computation │ ├── resources │ │ └── .gitignore │ ├── results │ │ └── .gitignore │ ├── .gitignore │ ├── profiles │ │ ├── default │ │ │ └── config.v8+.yaml │ │ └── slurm │ │ │ └── config.v8+.yaml │ ├── workflow │ │ ├── envs │ │ │ └── gpn.yaml │ │ ├── rules │ │ │ └── compute_gpn.smk │ │ ├── Snakefile │ │ └── scripts │ │ │ └── compute_gpn.py │ └── config │ │ └── config.yaml ├── nucleotide_dependency_maps │ ├── .gitignore │ ├── profiles │ │ └── default │ │ │ └── config.v8+.yaml │ ├── config.yaml │ ├── scripts │ │ ├── display_dependency_map.py │ │ ├── compute_dependency_map.py │ │ └── nucleotide_dependency_map_helpers.py │ └── Snakefile ├── path_distribution.png ├── path_distribution_models.png ├── parquet_to_csv.py └── model_evaluation.ipynb ├── experiments └── context_length_prediction_impact │ ├── .gitignore │ ├── profiles │ ├── default │ │ └── config.v8+.yaml │ └── slurm │ │ └── config.v8+.yaml │ ├── envs │ ├── global.yaml │ ├── compute_statistics.yaml │ └── scripts.yaml │ ├── scripts │ ├── create_annotation_db.py │ ├── generate_random_sample_positions.py │ ├── display_chromosome_boxplot.py │ ├── compute_gpn_per_context_length.py │ ├── compute_probabilities_per_context_length.py │ ├── display_distribution_shift_chart.py │ ├── display_stacked_variance_chart.py │ ├── stats_helpers.py │ ├── display_stacked_probability_chart.py │ ├── compute_distribution_shift_over_chromosome.py │ ├── compute_prediction_variance_over_chromosome copy.py │ ├── compute_gpn_statistics_over_chromosome.py │ ├── compute_all_sliding_window_for_sequence.py │ └── helpers.py │ ├── config.yaml │ └── Snakefile ├── .gitignore ├── environment.yaml ├── devenv.def ├── global_training.yaml ├── gpn_gpu.yaml ├── README.md ├── requirements.txt └── NOTES.md /scripts/arabidopsis_halleri/config.yaml: -------------------------------------------------------------------------------- 1 | window_size: 512 -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/.gitignore: -------------------------------------------------------------------------------- 1 | output 2 | .snakemake 
-------------------------------------------------------------------------------- /scripts/data_preparation/.gitignore: -------------------------------------------------------------------------------- 1 | .snakemake 2 | results 3 | tmp -------------------------------------------------------------------------------- /scripts/arabidopsis_halleri/.gitignore: -------------------------------------------------------------------------------- 1 | results 2 | resources 3 | .snakemake -------------------------------------------------------------------------------- /scripts/high_throughput_gpn_computation/resources/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /scripts/high_throughput_gpn_computation/results/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /scripts/nucleotide_dependency_maps/.gitignore: -------------------------------------------------------------------------------- 1 | .snakemake 2 | output 3 | __pycache__ -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/.gitignore: -------------------------------------------------------------------------------- 1 | .snakemake 2 | output 3 | __pycache__ 4 | data -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # any build apptainers images 2 | *.sif 3 | 4 | # jupyter checkpoints 5 | **/.ipynb_checkpoints/ 6 | 7 | data -------------------------------------------------------------------------------- /scripts/high_throughput_gpn_computation/.gitignore: -------------------------------------------------------------------------------- 1 | .cache 2 | .conda 3 | .config 4 | .nv 5 | .snakemake 6 | .condarc 7 | -------------------------------------------------------------------------------- /scripts/path_distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SilvanCodes/masterthesis-ramses/main/scripts/path_distribution.png -------------------------------------------------------------------------------- /scripts/path_distribution_models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SilvanCodes/masterthesis-ramses/main/scripts/path_distribution_models.png -------------------------------------------------------------------------------- /scripts/nucleotide_dependency_maps/profiles/default/config.v8+.yaml: -------------------------------------------------------------------------------- 1 | software-deployment-method: conda 2 | conda-prefix: /scratch/sbuedenb/snakemake 3 | -------------------------------------------------------------------------------- /scripts/arabidopsis_halleri/envs/repeatmasker.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | - bioconda 4 | - defaults 5 | dependencies: 6 | - repeatmasker 7 | -------------------------------------------------------------------------------- /scripts/arabidopsis_halleri/profiles/default/config.v8+.yaml: -------------------------------------------------------------------------------- 1 | jobs: 10 2 
| 3 | software-deployment-method: conda 4 | conda-prefix: /scratch/sbuedenb/snakemake 5 | -------------------------------------------------------------------------------- /scripts/data_preparation/profiles/default/config.v8+.yaml: -------------------------------------------------------------------------------- 1 | jobs: 10 2 | 3 | software-deployment-method: conda 4 | conda-prefix: /scratch/sbuedenb/snakemake 5 | -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/profiles/default/config.v8+.yaml: -------------------------------------------------------------------------------- 1 | jobs: 10 2 | 3 | software-deployment-method: conda 4 | conda-prefix: /scratch/sbuedenb/snakemake 5 | -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/profiles/default/config.v8+.yaml: -------------------------------------------------------------------------------- 1 | jobs: 10 2 | 3 | conda-prefix: /scratch/sbuedenb/snakemake 4 | software-deployment-method: conda -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/envs/repeatmasker.yaml: -------------------------------------------------------------------------------- 1 | name: repeatmasker_env 2 | channels: 3 | - conda-forge 4 | - bioconda 5 | - defaults 6 | dependencies: 7 | - repeatmasker 8 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: snakemake 2 | channels: 3 | - bioconda 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - snakemake=9.5.1 8 | - snakemake-executor-plugin-slurm=1.3.6 9 | -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/envs/global.yaml: -------------------------------------------------------------------------------- 1 | name: global_env 2 | channels: 3 | - conda-forge 4 | - bioconda 5 | - defaults 6 | dependencies: 7 | - ncbi-datasets-cli -------------------------------------------------------------------------------- /scripts/arabidopsis_halleri/envs/umap.yaml: -------------------------------------------------------------------------------- 1 | 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - scikit-learn=1.5.2 7 | - pyarrow=19.0.1 8 | - pandas=2.2.3 9 | - umap-learn=0.5.7 -------------------------------------------------------------------------------- /scripts/arabidopsis_halleri/envs/expand_annotation.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | - bioconda 4 | - defaults 5 | dependencies: 6 | - pandas=2.2.3 7 | - bioframe=0.7.2 8 | - gtfparse=2.5.0 9 | # - pyarrow=19.0.1 -------------------------------------------------------------------------------- /scripts/arabidopsis_halleri/envs/gpn.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | - bioconda 4 | - defaults 5 | dependencies: 6 | - ninja=1.12.1 7 | - pip 8 | - pip: 9 | - git+https://github.com/songlab-cal/gpn.git@main -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/envs/umap.yaml: -------------------------------------------------------------------------------- 1 | 2 | name: umap_env 3 | channels: 4 | - conda-forge 5 | - defaults 6 | 
dependencies: 7 | - scikit-learn=1.5.2 8 | - pyarrow=19.0.1 9 | - pandas=2.2.3 10 | - umap-learn=0.5.7 -------------------------------------------------------------------------------- /scripts/high_throughput_gpn_computation/profiles/default/config.v8+.yaml: -------------------------------------------------------------------------------- 1 | jobs: 10 2 | 3 | software-deployment-method: 4 | - conda 5 | - apptainer 6 | apptainer-args: "--bind /scratch/sbuedenb --nv" 7 | conda-prefix: /scratch/sbuedenb/snakemake 8 | -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/scripts/create_annotation_db.py: -------------------------------------------------------------------------------- 1 | import gffutils 2 | 3 | db = gffutils.create_db( 4 | snakemake.input[0], snakemake.output[0], merge_strategy="create_unique" 5 | ) 6 | db.update(list(db.create_introns())) 7 | -------------------------------------------------------------------------------- /scripts/arabidopsis_halleri/envs/run_classification.yaml: -------------------------------------------------------------------------------- 1 | name: run_classification_env 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - pandas=2.2.3 7 | - scikit-learn 8 | - pyarrow 9 | - joblib 10 | - numpy 11 | -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/envs/gpn.yaml: -------------------------------------------------------------------------------- 1 | name: gpn_env 2 | channels: 3 | - conda-forge 4 | - bioconda 5 | - defaults 6 | dependencies: 7 | - ninja=1.12.1 8 | - pip 9 | - pip: 10 | - git+https://github.com/songlab-cal/gpn.git@main -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/envs/run_classification.yaml: -------------------------------------------------------------------------------- 1 | name: run_classification_env 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - pandas=2.2.3 7 | - scikit-learn 8 | - pyarrow 9 | - joblib 10 | - numpy 11 | -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/input/id_mapping.tsv: -------------------------------------------------------------------------------- 1 | NC_087394.1,1 2 | NC_087395.1,2 3 | NC_087396.1,3 4 | NC_087397.1,4 5 | NC_087398.1,5 6 | NC_087399.1,6 7 | NC_087400.1,7 8 | NC_087401.1,8 9 | NC_087402.1,9 10 | NC_087403.1,10 11 | NC_087404.1,11 12 | NC_003081.2,MT -------------------------------------------------------------------------------- /scripts/arabidopsis_halleri/envs/global.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | - bioconda 4 | - defaults 5 | dependencies: 6 | - pandas=2.2.3 7 | - biopython=1.85 8 | - bioframe=0.7.2 9 | - more-itertools=10.6.0 10 | - pip 11 | - pip: 12 | - git+https://github.com/songlab-cal/gpn.git@main -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/envs/compute_statistics.yaml: -------------------------------------------------------------------------------- 1 | name: compute_statistics_env 2 | channels: 3 | - conda-forge 4 | - bioconda 5 | - defaults 6 | dependencies: 7 | - gffutils 8 | - pandas 9 | - tqdm 10 | - pip 11 | - pip: 12 | - git+https://github.com/songlab-cal/gpn.git@main 
-------------------------------------------------------------------------------- /scripts/tribolium_castaneum/envs/expand_annotation.yaml: -------------------------------------------------------------------------------- 1 | name: expand_annotation_env 2 | channels: 3 | - conda-forge 4 | - bioconda 5 | - defaults 6 | dependencies: 7 | - pandas=2.2.3 8 | - bioframe=0.7.2 9 | - gtfparse=2.5.0 10 | # - pyarrow=19.0.1 11 | - pip: 12 | - git+https://github.com/songlab-cal/gpn.git@main -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/envs/global.yaml: -------------------------------------------------------------------------------- 1 | name: global_env 2 | channels: 3 | - conda-forge 4 | - bioconda 5 | - defaults 6 | dependencies: 7 | - pandas=2.2.3 8 | - biopython=1.85 9 | - bioframe=0.7.2 10 | - more-itertools=10.6.0 11 | - pip 12 | - pip: 13 | - git+https://github.com/songlab-cal/gpn.git@0.6 14 | -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/config.yaml: -------------------------------------------------------------------------------- 1 | FASTA_URL: "https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/031/307/605/GCF_031307605.1_icTriCast1.1/GCF_031307605.1_icTriCast1.1_genomic.fna.gz" 2 | GTF_URL: "https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/031/307/605/GCF_031307605.1_icTriCast1.1/GCF_031307605.1_icTriCast1.1_genomic.gff.gz" 3 | 4 | window_size: 512 -------------------------------------------------------------------------------- /scripts/data_preparation/README.md: -------------------------------------------------------------------------------- 1 | # Data preparation 2 | 3 | This workflow is a copy from https://github.com/songlab-cal/gpn/tree/main/workflow/make_dataset. 4 | 5 | The `config` folder contains `annotated_chrom+_ncbi_refseq_cucujiformia_assemblies.tsv` which is the selection of species forming the https://huggingface.co/datasets/sbuedenb/big_beetle_dataset. 
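6 | 
7 | ## Running (sketch)
8 | 
9 | A minimal invocation could look like the sketch below. It is illustrative only: it assumes Snakemake picks up `workflow/Snakefile` from this directory, that the `conda-prefix` in `profiles/default/config.v8+.yaml` has been adjusted to your system, and the exact targets come from `rule all` in the Snakefile.
10 | 
11 | ```bash
12 | # point the workflow at an assembly list via assemblies_path in config/config.yaml,
13 | # e.g. config/annotated_chrom+_ncbi_refseq_cucujiformia_assemblies.tsv
14 | snakemake --profile profiles/default
15 | 
16 | # once results/dataset has been built, it can be pushed to the Hugging Face Hub
17 | python upload_to_hf.py
18 | ```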
-------------------------------------------------------------------------------- /scripts/high_throughput_gpn_computation/workflow/envs/gpn.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - pandas 7 | - numpy 8 | - pytorch 9 | - pytorch-cuda 10 | - transformers 11 | - biopython 12 | - tqdm 13 | - pip 14 | - pip: 15 | - git+https://github.com/songlab-cal/gpn.git@main 16 | -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/envs/scripts.yaml: -------------------------------------------------------------------------------- 1 | name: scripts_env 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - pandas 8 | - numpy 9 | - pytorch 10 | - pytorch-cuda 11 | - transformers 12 | - tqdm 13 | - matplotlib 14 | - seaborn 15 | - pip 16 | - pip: 17 | - git+https://github.com/songlab-cal/gpn.git@main -------------------------------------------------------------------------------- /devenv.def: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: continuumio/miniconda3 3 | 4 | %files 5 | ./requirements.txt /opt 6 | 7 | %post 8 | # conda config --set channel_priority strict 9 | apt-get update && apt-get install -y build-essential vcftools graphviz 10 | conda install python=3.11 11 | pip install -r /opt/requirements.txt 12 | 13 | %startscript 14 | jupyter notebook --no-browser --port 9999 15 | -------------------------------------------------------------------------------- /scripts/data_preparation/upload_to_hf.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import HfApi 2 | 3 | api = HfApi() 4 | 5 | private = False 6 | repo_id = "sbuedenb/big_beetle_dataset-1024" # replace with your username, dataset name 7 | folder_path = "results/dataset" 8 | api.create_repo(repo_id=repo_id, repo_type="dataset", private=private) 9 | api.upload_folder(repo_id=repo_id, folder_path=folder_path, repo_type="dataset") 10 | -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/profiles/slurm/config.v8+.yaml: -------------------------------------------------------------------------------- 1 | # https://snakemake.github.io/snakemake-plugin-catalog/plugins/executor/slurm.html 2 | executor: slurm 3 | jobs: 10 4 | 5 | software-deployment-method: conda 6 | conda-prefix: /scratch/sbuedenb/snakemake 7 | 8 | 9 | default-resources: 10 | slurm_partition: "gpu" 11 | slurm_account: "ag-wiehe" 12 | mem: "64gb" 13 | runtime: 60 14 | -------------------------------------------------------------------------------- /scripts/arabidopsis_halleri/envs/notebook.yaml: -------------------------------------------------------------------------------- 1 | 2 | name: notebook_env 3 | channels: 4 | - conda-forge 5 | - bioconda 6 | - defaults 7 | dependencies: 8 | - python=3.11 9 | - bioframe 10 | - matplotlib 11 | - pyarrow 12 | - pandas 13 | - numpy 14 | - seaborn 15 | - tqdm 16 | - umap-learn[plot]=0.5.7 17 | - jupyterlab 18 | - pip 19 | - pip: 20 | - git+https://github.com/songlab-cal/gpn.git@main -------------------------------------------------------------------------------- /scripts/data_preparation/workflow/envs/global.yaml: -------------------------------------------------------------------------------- 1 | name: global_env 2 | 
channels: 3 | - pytorch 4 | - conda-forge 5 | - bioconda 6 | - defaults 7 | dependencies: 8 | - transformers=4.26.1 9 | - huggingface_hub 10 | - pandas 11 | - numpy 12 | - pytorch 13 | - pytorch-cuda 14 | - tqdm 15 | - ncbi-datasets-cli 16 | - pip 17 | - pip: 18 | - git+https://github.com/SilvanCodes/gpn.git@main 19 | -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/envs/notebook.yaml: -------------------------------------------------------------------------------- 1 | 2 | name: notebook_env 3 | channels: 4 | - conda-forge 5 | - bioconda 6 | - defaults 7 | dependencies: 8 | - python=3.10 9 | - bioframe 10 | - matplotlib 11 | - pyarrow 12 | - pandas 13 | - numpy 14 | - seaborn 15 | - tqdm 16 | - umap-learn[plot]=0.5.7 17 | - jupyterlab 18 | - pip 19 | - pip: 20 | - git+https://github.com/songlab-cal/gpn.git@main -------------------------------------------------------------------------------- /global_training.yaml: -------------------------------------------------------------------------------- 1 | name: global_env_training 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - bioconda 6 | - defaults 7 | dependencies: 8 | - python=3.10 9 | - transformers=4.26.1 10 | - huggingface_hub 11 | - safetensors 12 | - pandas 13 | - numpy 14 | - pytorch 15 | - pytorch-cuda 16 | - tqdm 17 | - ncbi-datasets-cli 18 | - pip 19 | - pip: 20 | - git+https://github.com/SilvanCodes/gpn.git@main 21 | -------------------------------------------------------------------------------- /scripts/high_throughput_gpn_computation/workflow/rules/compute_gpn.smk: -------------------------------------------------------------------------------- 1 | rule compute_gpn: 2 | input: 3 | "resources/{accession}.HMA4_region.fa", 4 | output: 5 | "results/{accession}/{chromosome}/{start_position}_{stop_position}_{reverse_complement}/gpn_scores.parquet", 6 | conda: 7 | "../envs/gpn.yaml" 8 | resources: 9 | slurm_extra="-G h100_1g.12gb:1", 10 | script: 11 | "../scripts/compute_gpn.py" 12 | -------------------------------------------------------------------------------- /scripts/high_throughput_gpn_computation/profiles/slurm/config.v8+.yaml: -------------------------------------------------------------------------------- 1 | # https://snakemake.github.io/snakemake-plugin-catalog/plugins/executor/slurm.html 2 | executor: slurm 3 | jobs: 10 4 | 5 | software-deployment-method: 6 | - conda 7 | - apptainer 8 | apptainer-args: "--bind /scratch/sbuedenb --nv" 9 | conda-prefix: /scratch/sbuedenb/snakemake 10 | 11 | 12 | default-resources: 13 | slurm_partition: "gpu" 14 | slurm_account: "ag-wiehe" 15 | mem: "64gb" 16 | runtime: 60 17 | -------------------------------------------------------------------------------- /scripts/data_preparation/config/assemblies.tsv: -------------------------------------------------------------------------------- 1 | Assembly Accession Assembly Name Organism Name Organism Infraspecific Names Breed Organism Infraspecific Names Strain Organism Infraspecific Names Cultivar Organism Infraspecific Names Ecotype Organism Infraspecific Names Isolate Organism Infraspecific Names Sex Annotation Name Assembly Level Assembly Release Date WGS project accession Assembly Stats Number of Scaffolds 2 | GCF_031307605.1 icTriCast1.1 Tribolium castaneum GA2 male GCF_031307605.1-RS_2024_04 Chromosome 2023-09-14 JANKOB01 148 3 | -------------------------------------------------------------------------------- /scripts/arabidopsis_halleri/scripts/umap.py: 
-------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from umap import UMAP 3 | from sklearn.pipeline import Pipeline 4 | from sklearn.preprocessing import StandardScaler 5 | 6 | embeddings = pd.read_parquet(snakemake.input[0]) 7 | proj = Pipeline( 8 | [ 9 | ("scaler", StandardScaler()), 10 | ("umap", UMAP(metric='cosine', random_state=42, verbose=True)), 11 | ] 12 | ).fit_transform(embeddings) 13 | proj = pd.DataFrame(proj, columns=["UMAP1", "UMAP2"]) 14 | proj.to_parquet(snakemake.output[0], index=False) -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/scripts/umap.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from umap import UMAP 3 | from sklearn.pipeline import Pipeline 4 | from sklearn.preprocessing import StandardScaler 5 | 6 | embeddings = pd.read_parquet(snakemake.input[0]) 7 | proj = Pipeline( 8 | [ 9 | ("scaler", StandardScaler()), 10 | ("umap", UMAP(metric='cosine', random_state=42, verbose=True)), 11 | ] 12 | ).fit_transform(embeddings) 13 | proj = pd.DataFrame(proj, columns=["UMAP1", "UMAP2"]) 14 | proj.to_parquet(snakemake.output[0], index=False) -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/config.yaml: -------------------------------------------------------------------------------- 1 | MODEL: "gonzalobenegas/gpn-brassicales" 2 | 3 | ACCESSION: "GCF_000001735.4" 4 | CHROMOSOME: "NC_003076.8" 5 | # MASKED_POSITION: 18867726 6 | 7 | MAXIMUM_CONTEXT_LENGTH: 2000 8 | 9 | # rather large window covering 5% of different context length to ensure 10 | # predictions have stabilized long term 11 | PREDICTION_VARIANCE_WINDOW_SIZE: 100 12 | 13 | PREDICTION_VARIANCE_THRESHOLD: 0.0001 14 | 15 | # number of uniform random samples to compute genome statistics 16 | RANDOM_POSITION_COUNT: 10000 17 | 18 | DISTRIBUTION_SHIFT_WINDOW_SIZE: 10 19 | DISTRIBUTION_SHIFT_THRESHOLD: 0.01 -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/scripts/generate_random_sample_positions.py: -------------------------------------------------------------------------------- 1 | from gpn.data import load_fasta 2 | import random 3 | import json 4 | 5 | chromosome = snakemake.config["CHROMOSOME"] 6 | random_position_count = snakemake.config["RANDOM_POSITION_COUNT"] 7 | max_context_length = snakemake.config["MAXIMUM_CONTEXT_LENGTH"] 8 | 9 | sequence_path = snakemake.input[0] 10 | genome = load_fasta(sequence_path) 11 | sequence = genome[chromosome] 12 | 13 | # treat as 1-indexed 14 | gap = max_context_length // 2 + 1 15 | 16 | random_positions = random.sample(range(gap, len(sequence) - gap), random_position_count) 17 | 18 | with open(snakemake.output[0], 'w') as f: 19 | json.dump(random_positions, f) -------------------------------------------------------------------------------- /scripts/data_preparation/workflow/Snakefile: -------------------------------------------------------------------------------- 1 | configfile: "config/config.yaml" 2 | 3 | 4 | conda: "envs/global.yaml" 5 | 6 | import pandas as pd 7 | 8 | 9 | print(config) 10 | 11 | assemblies = pd.read_csv(config["assemblies_path"], sep="\t", index_col=0) 12 | splits = ["train", "validation", "test"] 13 | 14 | 15 | # comment out if you have your own fasta files 16 | # and make sure you have genomes (and annotations, if 
applicable) in the right place
17 | # results/genome/{assembly}.fa.gz (and results/annotation/{assembly}.gff.gz)
18 | include: "rules/download.smk"
19 | include: "rules/intervals.smk"
20 | include: "rules/dataset.smk"
21 | 
22 | 
23 | rule all:
24 |     input:
25 |         expand("results/dataset/data/{split}", split=splits),
26 | 
--------------------------------------------------------------------------------
/scripts/arabidopsis_halleri/README.md:
--------------------------------------------------------------------------------
1 | # _Arabidopsis halleri_
2 | 
3 | This workflow computes a UMAP of embeddings of the Lan3.1 genome of _Arabidopsis halleri_.
4 | 
5 | It is adapted from https://github.com/songlab-cal/gpn/blob/main/analysis/gpn_arabidopsis/Snakefile
6 | 
7 | ## Preparation
8 | 
9 | This workflow needs the two files `Lan3.1.fna.gz` and `Lan3.1.genomic.gff3.gz` in a folder named `resources`.
10 | The first file is the genome, the second file is the annotation.
11 | They can be downloaded here, for example: https://phytozome-next.jgi.doe.gov/info/Ahalleri_v2_1_0
12 | 
13 | ## Required programs
14 | 
15 | `Conda` and `Snakemake` need to be installed.
16 | 
17 | ## Addons
18 | 
19 | If you have SLURM available, install `snakemake-executor-plugin-slurm` and adapt `profiles/slurm`.
--------------------------------------------------------------------------------
/scripts/arabidopsis_halleri/profiles/slurm/config.v8+.yaml:
--------------------------------------------------------------------------------
1 | # https://snakemake.github.io/snakemake-plugin-catalog/plugins/executor/slurm.html
2 | executor: slurm
3 | jobs: 10
4 | 
5 | software-deployment-method: conda
6 | conda-prefix: /scratch/sbuedenb/snakemake
7 | 
8 | 
9 | default-resources:
10 |   slurm_partition: "smp"
11 |   slurm_account: "ag-wiehe"
12 |   mem: "32gb"
13 |   tasks: 1
14 |   cpus_per_task: 8
15 |   runtime: 15
16 | 
17 | set-resources:
18 |   get_embeddings:
19 |     slurm_partition: "gpu" # deviating partition for this rule
20 |     mem: "64gb"
21 |     runtime: 90 # 1.5 hour
22 | 
23 | 
24 |   run_umap:
25 |     mem: "128gb"
26 |     tasks: 1
27 |     cpus_per_task: 16
28 |     runtime: 90 # 1.5 hour
29 | 
30 | 
31 |   run_classification:
32 |     mem: "256gb"
33 |     tasks: 4
34 |     cpus_per_task: 16
35 |     runtime: 1440 # 24 hours
36 | 
--------------------------------------------------------------------------------
/experiments/context_length_prediction_impact/scripts/display_chromosome_boxplot.py:
--------------------------------------------------------------------------------
1 | import helpers
2 | import pandas as pd
3 | from snakemake.script import snakemake
4 | 
5 | # load statistics_over_chromosome
6 | statistics_over_chromosome_path = snakemake.input[0]
7 | statistics_over_chromosome = pd.read_parquet(statistics_over_chromosome_path)
8 | 
9 | category_counts = statistics_over_chromosome["feature"].value_counts()
10 | 
11 | # we only plot categories with at least 100 datapoints in them
12 | valid_categories = category_counts[category_counts >= 100].index
13 | 
14 | filtered_df = statistics_over_chromosome[
15 |     statistics_over_chromosome["feature"].isin(valid_categories)
16 | ]
17 | 
18 | helpers.boxplot(
19 |     filtered_df,
20 |     snakemake.output[0],
21 |     snakemake.wildcards.format,
22 |     title=snakemake.params.title,
23 |     ylabel="Influential Context Size",
24 |     marker=512
25 |     # xlabel="Genomic Annotation"
26 | )
27 | 
--------------------------------------------------------------------------------
/scripts/nucleotide_dependency_maps/config.yaml:
-------------------------------------------------------------------------------- 1 | # look at ncbi ftp links to find new urls 2 | FASTA_URL: { 3 | "GCF_000001735.4": "https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/735/GCF_000001735.4_TAIR10.1/GCF_000001735.4_TAIR10.1_genomic.fna.gz" 4 | } 5 | 6 | MODELS: [ 7 | "gonzalobenegas/gpn-brassicales", 8 | ] 9 | 10 | SEQUENCES: [ 11 | "NC_003076.8" 12 | # "GCF_000001735.4", 13 | # "Lan3.1" 14 | ] 15 | 16 | DISPLAY_VMAX: 1.5 17 | 18 | CHROMOSOME: "NC_003076.8" 19 | START_POSITION: 10000 20 | END_POSITION: 11000 21 | 22 | 23 | # fuer yannick 24 | # HMA4-1 chr3 22023391-22031473 25 | # HMA4-2 chr3 21941657-21950203 26 | # HMA4-3 chr3 21886677-21896021 27 | # CHROMOSOME: "chr3" 28 | # START_POSITION: 22023391 29 | # END_POSITION: 22031473 30 | # START_POSITION: 21941657 31 | # END_POSITION: 21950203 32 | # START_POSITION: 21886677 33 | # END_POSITION: 21896021 34 | 35 | 36 | DISPLAY_MAP_START_RELATIVE: 0 37 | DISPLAY_MAP_END_RELATIVE: -1 -------------------------------------------------------------------------------- /scripts/data_preparation/config/annotated_tenebrionidae.tsv: -------------------------------------------------------------------------------- 1 | Assembly Accession Assembly Name Organism Name Organism Infraspecific Names Breed Organism Infraspecific Names Strain Organism Infraspecific Names Cultivar Organism Infraspecific Names Ecotype Organism Infraspecific Names Isolate Organism Infraspecific Names Sex Annotation Name Assembly Level Assembly Release Date WGS project accession Assembly Stats Number of Scaffolds 2 | GCF_031307605.1 icTriCast1.1 Tribolium castaneum GA2 male GCF_031307605.1-RS_2024_04 Chromosome 2023-09-14 JANKOB01 148 3 | GCF_963966145.1 icTenMoli1.1 Tenebrio molitor GCF_963966145.1-RS_2024_10 Chromosome 2024-02-12 CAWYQD01 236 4 | GCF_036711695.1 CSIRO_AGI_Zmor_V1 Zophobas morio UQ-CSIRO AGI-343-1 GCF_036711695.1-RS_2024_03 Contig 2024-02-20 JAVQLY01 1279 5 | GCF_015345945.1 Tmad_KSU_1.1 Tribolium madens KSU strain multiple individuals pooled male and female NCBI Tribolium madens Annotation Release 100 Scaffold 2020-11-10 JADDYM01 112 6 | -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/scripts/compute_gpn_per_context_length.py: -------------------------------------------------------------------------------- 1 | import helpers 2 | import pandas as pd 3 | from gpn.data import load_fasta 4 | from snakemake.script import snakemake 5 | 6 | chromosome = snakemake.wildcards.chromosome 7 | 8 | # load sequence 9 | sequence_path = snakemake.input[0] 10 | genome = load_fasta(sequence_path) 11 | sequence = genome[chromosome] 12 | 13 | probabilities_per_context_length_df_path = snakemake.input[1] 14 | 15 | position = int(probabilities_per_context_length_df_path.split("/")[-2]) 16 | 17 | results = pd.read_parquet(probabilities_per_context_length_df_path) 18 | 19 | # arrays are zero indexed, genomes not 20 | reference_nucleotide = sequence[position - 1] 21 | 22 | # skip position when reference is unknown 23 | # if reference_nucleotide in ["n", "N"]: 24 | # continue 25 | 26 | gpn_scores = helpers.compute_gpn_score(reference_nucleotide, results) 27 | 28 | 29 | gpn_scores.to_parquet(snakemake.output[0]) 30 | -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/profiles/slurm/config.v8+.yaml: -------------------------------------------------------------------------------- 1 | # 
https://snakemake.github.io/snakemake-plugin-catalog/plugins/executor/slurm.html 2 | executor: slurm 3 | jobs: 10 4 | 5 | software-deployment-method: conda 6 | conda-prefix: /scratch/sbuedenb/snakemake 7 | 8 | # mail-user: sbuedenb@smail.uni-koeln.de 9 | # mail-type: END,FAIL 10 | 11 | default-resources: 12 | slurm_partition: "smp" 13 | slurm_account: "ag-wiehe" 14 | mem: "32gb" 15 | tasks: 1 16 | cpus_per_task: 8 17 | runtime: 15 18 | 19 | set-resources: 20 | get_embeddings: 21 | slurm_partition: "gpu" # deviating partition for this rule 22 | mem: "64gb" 23 | runtime: 180 # 3 hour 24 | 25 | run_umap: 26 | mem: "128gb" 27 | tasks: 1 28 | cpus_per_task: 16 29 | runtime: 90 # 1.5 hour 30 | 31 | run_classification: 32 | mem: "256gb" 33 | tasks: 1 34 | cpus_per_task: 8 35 | runtime: 1440 # 24 hours 36 | -------------------------------------------------------------------------------- /scripts/nucleotide_dependency_maps/scripts/display_dependency_map.py: -------------------------------------------------------------------------------- 1 | import nucleotide_dependency_map_helpers as ndm 2 | import pandas as pd 3 | 4 | from gpn.data import load_fasta 5 | 6 | chromosome = snakemake.config["CHROMOSOME"] 7 | seq_start = snakemake.config["START_POSITION"] 8 | seq_end = snakemake.config["END_POSITION"] 9 | 10 | map_start = snakemake.config["DISPLAY_MAP_START_RELATIVE"] 11 | map_end = snakemake.config["DISPLAY_MAP_END_RELATIVE"] 12 | 13 | 14 | # load sequence 15 | sequence_path = snakemake.input[0] 16 | genome = load_fasta(sequence_path) 17 | sequence = genome[chromosome][seq_start - 1 : seq_end] 18 | 19 | # load dep_map 20 | dependency_map_path = snakemake.input[1] 21 | dependency_map = pd.read_parquet(dependency_map_path).values 22 | 23 | print(dependency_map.shape) 24 | 25 | title = f"{chromosome} {seq_start + map_start} - {map_end + seq_end + 1}" 26 | 27 | ndm.map_seq_to_file( 28 | dependency_map[map_start:map_end, map_start:map_end], 29 | sequence[map_start:map_end], 30 | snakemake.output[0], 31 | snakemake.wildcards.format, 32 | vmax=snakemake.config["DISPLAY_VMAX"], 33 | title=title, 34 | ) 35 | -------------------------------------------------------------------------------- /gpn_gpu.yaml: -------------------------------------------------------------------------------- 1 | name: gpn_gpu 2 | channels: 3 | - conda-forge 4 | - bioconda 5 | - defaults 6 | dependencies: 7 | - python=3.10 8 | # - pytorch=2.6.0 9 | # - torchvision 10 | # - torchaudio 11 | # - pytorch-cuda=12.6 12 | # - triton=3.1.0 13 | - transformers=4.52.1 14 | - huggingface_hub 15 | - safetensors 16 | - pandas 17 | - numpy 18 | - tqdm 19 | - ncbi-datasets-cli 20 | - pip 21 | - pip: 22 | - git+https://github.com/SilvanCodes/gpn.git@336feddce479b7356693575638d8db17041eab5f 23 | - git+https://github.com/SilvanCodes/datasets.git@9f42abd72c1a27ee3f15f37ab94d56b5a985c9c0 24 | - torch 25 | - torchvision 26 | - torchaudio 27 | 28 | # channels: 29 | # - pytorch 30 | # - nvidia 31 | # - conda-forge 32 | # - bioconda 33 | # - defaults 34 | # dependencies: 35 | # - python=3.10 36 | # - pytorch=2.5.1 37 | # - pytorch-cuda=12.4 38 | # - torchvision=0.20.1 39 | # - torchaudio=2.5.1 40 | # - transformers>=4.48 41 | # - datasets>=3.6 42 | # - huggingface_hub 43 | # - safetensors 44 | # - pandas 45 | # - numpy 46 | # - tqdm 47 | # - ncbi-datasets-cli 48 | # - pip 49 | # - pip: 50 | # - git+https://github.com/SilvanCodes/gpn.git@main 51 | -------------------------------------------------------------------------------- /scripts/parquet_to_csv.py: 
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | 
4 | 
5 | def convert_parquet_to_csv_and_delete(root_dir):
6 |     for dirpath, _, filenames in os.walk(root_dir):
7 |         for file in filenames:
8 |             if file.endswith(".parquet"):
9 |                 parquet_path = os.path.join(dirpath, file)
10 |                 csv_path = os.path.splitext(parquet_path)[0] + ".csv"
11 | 
12 |                 try:
13 |                     print(f"Converting: {parquet_path}")
14 |                     df = pd.read_parquet(parquet_path)
15 |                     df.to_csv(csv_path, index=False)
16 |                     print(f"Saved: {csv_path}")
17 | 
18 |                     # Delete original parquet file
19 |                     os.remove(parquet_path)
20 |                     print(f"Deleted: {parquet_path}")
21 |                 except Exception as e:
22 |                     print(f"❌ Failed to convert {parquet_path}: {e}")
23 | 
24 | 
25 | if __name__ == "__main__":
26 |     import sys
27 | 
28 |     if len(sys.argv) != 2:
29 |         print("Usage: python parquet_to_csv.py <root_dir>")
30 |         sys.exit(1)
31 | 
32 |     directory = sys.argv[1]
33 |     convert_parquet_to_csv_and_delete(directory)
34 | 
--------------------------------------------------------------------------------
/scripts/high_throughput_gpn_computation/config/config.yaml:
--------------------------------------------------------------------------------
1 | MODEL_PATH: "gonzalobenegas/gpn-brassicales"
2 | 
3 | BATCH_SIZE: 128
4 | INCLUDED_CONTEXT: 512 # total count of bp symmetrically around each predicted position
5 | 
6 | # workflow assumes an .fna.gz file inside resources
7 | # ACCESSION: "Lan3.1"
8 | # CHROMOSOME: "chr3"
9 | 
10 | REVERSE_COMPLEMENT:
11 |   [
12 |     false,
13 |     true,
14 |     false,
15 |     true,
16 |     false,
17 |     false,
18 |     true,
19 |     false,
20 |     false,
21 |     false]
22 | 
23 | # start computation from
24 | START_POSITION:
25 | 
26 | # stop computation at
27 | STOP_POSITION: 999999
28 | 
29 | ACCESSION:
30 |   [
31 |     "BAC",
32 |     "Bors_12",
33 |     "Goli_08",
34 |     "Lan3.1",
35 |     "Lan5",
36 |     "Lan5_hap2",
37 |     "Pais_09",
38 |     "Rund_05_S1",
39 |     "Ukan_25",
40 |     "Wall_10",
41 |   ]
42 | 
43 | CHROMOSOME:
44 |   [
45 |     "EU382073.1",
46 |     "h1tg000001l:1539155-1704685",
47 |     "h1tg000022l:224290-393930",
48 |     "chr3:21866677-22051473",
49 |     "Chr3:26654299-26831685",
50 |     "Chr3:23465017-23625889",
51 |     "h1tg000006l:1479490-1622324",
52 |     "tig00000673_chr3:1356930-1527379",
53 |     "h2tg000015l:6652465-6810532",
54 |     "tig00000401_chr3:6197730-6356479",
55 |   ]
56 | 
--------------------------------------------------------------------------------
/experiments/context_length_prediction_impact/scripts/compute_probabilities_per_context_length.py:
--------------------------------------------------------------------------------
1 | import helpers
2 | import pandas as pd
3 | import torch
4 | from transformers import AutoModelForMaskedLM, AutoTokenizer
5 | from snakemake.script import snakemake
6 | 
7 | # gpn specific model configuration
8 | import gpn.model
9 | from gpn.data import load_fasta
10 | 
11 | print(f"GPU Model: {torch.cuda.get_device_name(0)}")
12 | 
13 | model_path = snakemake.wildcards.model
14 | 
15 | # load tokenizer
16 | tokenizer = AutoTokenizer.from_pretrained(model_path)
17 | print(f"tokenizer vocabulary: {tokenizer.get_vocab()}")
18 | 
19 | # load model
20 | model = AutoModelForMaskedLM.from_pretrained(model_path)
21 | device = "cuda"
22 | model.to(device)
23 | model.eval()
24 | 
25 | # load config
26 | max_context_length = snakemake.config["MAXIMUM_CONTEXT_LENGTH"]
27 | 
28 | chromosome = snakemake.wildcards.chromosome
29 | masked_position = int(snakemake.wildcards.position)
30 | 
31 | # load sequence
32 | sequence_path = snakemake.input[0]
33 | genome = load_fasta(sequence_path) 34 | sequence = genome[chromosome] 35 | 36 | print(len(sequence)) 37 | 38 | df = helpers.compute_context_length_dependency( 39 | model, 40 | tokenizer, 41 | sequence, 42 | masked_position, 43 | max_context_length=max_context_length, 44 | ) 45 | 46 | df.to_parquet(snakemake.output[0]) 47 | -------------------------------------------------------------------------------- /scripts/nucleotide_dependency_maps/scripts/compute_dependency_map.py: -------------------------------------------------------------------------------- 1 | import nucleotide_dependency_map_helpers as ndm 2 | import pandas as pd 3 | import torch 4 | from transformers import AutoModelForMaskedLM, AutoTokenizer 5 | 6 | # gpn specific model configuration 7 | import gpn.model 8 | from gpn.data import load_fasta 9 | 10 | print(f"GPU Model: {torch.cuda.get_device_name(0)}") 11 | 12 | model_path = snakemake.params.model 13 | 14 | # load tokenizer 15 | tokenizer = AutoTokenizer.from_pretrained(model_path) 16 | print(f"tokenizer vocabulary: {tokenizer.get_vocab()}") 17 | 18 | # load model 19 | model = AutoModelForMaskedLM.from_pretrained(model_path) 20 | device = "cuda" 21 | model.to(device) 22 | model.eval() 23 | 24 | chromosome = snakemake.config["CHROMOSOME"] 25 | # subtract due to zero based indexing in arrays 26 | seq_start = snakemake.config["START_POSITION"] - 1 27 | seq_end = snakemake.config["END_POSITION"] 28 | 29 | print(f"start: {seq_start}, end: {seq_end}, end-start: {seq_end - seq_start}") 30 | 31 | 32 | # load sequence 33 | sequence_path = snakemake.input[0] 34 | genome = load_fasta(sequence_path) 35 | sequence = genome[chromosome][seq_start:seq_end] 36 | 37 | print(len(sequence)) 38 | 39 | dependency_map = ndm.compute_dependency_map(sequence, model, tokenizer) 40 | 41 | df = pd.DataFrame(dependency_map) 42 | df.to_parquet(snakemake.output[0]) 43 | -------------------------------------------------------------------------------- /scripts/data_preparation/config/config.yaml: -------------------------------------------------------------------------------- 1 | # assumes the first column contains the assembly name 2 | assemblies_path: "config/annotated_chrom+_ncbi_refseq_cucujiformia_assemblies.tsv" 3 | 4 | # Intervals from fasta file used for training: 5 | # - "all": all positions 6 | # - "defined": positions with defined nucleotides (not N) 7 | # - "annotation_{feature}": only positions from annotation, e.g. CDS, exon 8 | # - "balanced_v1": recipe used in original paper 9 | target_intervals: "balanced_v1" 10 | 11 | # window_size: 512 12 | # step_size: 256 13 | window_size: 1025 14 | step_size: 512 15 | add_rc: False # random rc is now done on-the-fly during training 16 | 17 | # chroms will be randomly assigned to splits 18 | split_proportion: 19 | train: 0.99 20 | validation: 0.005 21 | test: 0.005 22 | 23 | # this chroms are forced to be in validation set 24 | whitelist_validation_chroms: 25 | # - "NC_087403.1" # Tribolium Castaneum chr10 26 | - "NC_087401.1" # Tribolium Castaneum chr8 27 | # this chroms are forced to be in test set 28 | whitelist_test_chroms: 29 | # - "NC_087404.1" # Tribolium Castaneum chr11 30 | - "NC_087402.1" # Tribolium Castaneum chr9 31 | 32 | # We want to split data into shards of e.g. 
~100MB each
33 | # It's good to have at least num_cpus shards to increase parallel loading speed
34 | # of iterable datasets from HF hub
35 | # samples_per_file: 500_000
36 | samples_per_file: 250_000
37 | 
--------------------------------------------------------------------------------
/scripts/data_preparation/workflow/rules/download.smk:
--------------------------------------------------------------------------------
1 | conda: "../envs/global.yaml"
2 | 
3 | assemblies["Assembly Name"] = assemblies["Assembly Name"].str.replace(" ", "_")
4 | assemblies["genome_path"] = (
5 |     "tmp/"
6 |     + assemblies.index
7 |     + "/ncbi_dataset/data/"
8 |     + assemblies.index
9 |     + "/"
10 |     + assemblies.index
11 |     + "_"
12 |     + assemblies["Assembly Name"]
13 |     + "_genomic.fna"
14 | )
15 | assemblies["annotation_path"] = (
16 |     "tmp/"
17 |     + assemblies.index
18 |     + "/ncbi_dataset/data/"
19 |     + assemblies.index
20 |     + "/genomic.gff"
21 | )
22 | 
23 | 
24 | rule download_genome:
25 |     output:
26 |         "results/genome/{assembly}.fa.gz",
27 |         "results/annotation/{assembly}.gff.gz",
28 |     params:
29 |         tmp_dir=directory("tmp/{assembly}"),
30 |         genome_path=lambda wildcards: assemblies.loc[wildcards.assembly, "genome_path"],
31 |         annotation_path=lambda wildcards: assemblies.loc[
32 |             wildcards.assembly, "annotation_path"
33 |         ],
34 |     shell:
35 |         """
36 |         mkdir -p {params.tmp_dir} && cd {params.tmp_dir} &&
37 |         datasets download genome accession {wildcards.assembly} --include genome,gff3 \
38 |         && unzip ncbi_dataset.zip && cd - && gzip -c {params.genome_path} > {output[0]}\
39 |         && gzip -c {params.annotation_path} > {output[1]} && rm -r {params.tmp_dir}
40 |         """
41 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Welcome to the codebase behind the master's thesis "Analysis on GPN and its application to distant species"
2 | 
3 | Most of the work happens in `scripts` and `experiments`.
4 | 
5 | 
6 | 
7 | `scripts/data_preparation` is pretty much a copy of https://github.com/songlab-cal/gpn/tree/main/workflow/make_dataset with different species as input.
8 | 
9 | 
10 | `scripts/arabidopsis_halleri` and `scripts/tribolium_castaneum` are adaptations of https://github.com/songlab-cal/gpn/blob/main/analysis/gpn_arabidopsis/Snakefile to the respective species.
11 | 
12 | 
13 | `scripts/nucleotide_dependency_maps` is an adaptation of https://github.com/gagneurlab/dependencies_DNALM/blob/main/compute_and_visualize_dep_maps.ipynb into a Snakemake workflow.
14 | 
15 | 
16 | `scripts/high_throughput_gpn_computation` is an original pipeline to quickly compute GPN scores for genomes.
17 | 
18 | 
19 | `experiments/context_length_prediction_impact` is the code for the analysis of the context size utilized by the GPN.
20 | 
21 | 
22 | `NOTES.md` is a kind of dev-log and also contains the command used to start the training of new GPN models.
23 | 
24 | 
25 | https://huggingface.co/sbuedenb has all published models and datasets.
26 | 
27 | https://api.wandb.ai/links/sbuedenb-university-of-cologne/wp5omxcc has the training and evaluation graphs of all 26 runs.
28 | 
29 | https://www.semanticscholar.org/shared/library/folder/10788499 contains the researched literature.
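30 | 
31 | All of these workflows are driven by Snakemake. As a rough sketch of how one of them can be run (not a tested recipe: the `conda-prefix`, SLURM partition and account in the various `profiles/` directories are specific to the original cluster and need to be adjusted):
32 | 
33 | ```bash
34 | # base environment with Snakemake and the SLURM executor plugin
35 | conda env create -f environment.yaml
36 | conda activate snakemake
37 | 
38 | # run a workflow locally with its default profile ...
39 | cd scripts/tribolium_castaneum
40 | snakemake --profile profiles/default
41 | 
42 | # ... or submit jobs via SLURM where a profiles/slurm directory exists
43 | snakemake --profile profiles/slurm
44 | ```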
-------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/scripts/display_distribution_shift_chart.py: -------------------------------------------------------------------------------- 1 | import helpers 2 | import pandas as pd 3 | from snakemake.script import snakemake 4 | from gpn.data import load_fasta 5 | 6 | # load probabilities_per_context_length 7 | probabilities_per_context_length_path = snakemake.input[0] 8 | probabilities_per_context_length = pd.read_parquet( 9 | probabilities_per_context_length_path 10 | ) 11 | 12 | print(probabilities_per_context_length.shape) 13 | 14 | # load config 15 | chromosome = snakemake.wildcards.chromosome 16 | masked_position = int(snakemake.wildcards.position) 17 | window_size = snakemake.config["DISTRIBUTION_SHIFT_WINDOW_SIZE"] 18 | distribution_shift_threshold = float(snakemake.config["DISTRIBUTION_SHIFT_THRESHOLD"]) 19 | 20 | title = f"Moved Probability Mass over Context Length \n (chr: {chromosome}, pos: {masked_position})" 21 | 22 | 23 | distribution_shift = helpers.get_distribution_shift(probabilities_per_context_length) 24 | 25 | threshold_step = ( 26 | helpers.find_context_size_step_with_distribution_shift_below_threshold( 27 | probabilities_per_context_length, 28 | window_size=window_size, 29 | threshold=distribution_shift_threshold, 30 | ) 31 | ) 32 | 33 | helpers.plot_line( 34 | distribution_shift, 35 | snakemake.output[0], 36 | snakemake.wildcards.format, 37 | title=title, 38 | xlabel="Context Length (bp)", 39 | ylabel="Moved Probability Mass", 40 | marker=threshold_step, 41 | ) 42 | -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/scripts/display_stacked_variance_chart.py: -------------------------------------------------------------------------------- 1 | import helpers 2 | import pandas as pd 3 | from snakemake.script import snakemake 4 | from gpn.data import load_fasta 5 | 6 | # load probabilities_per_context_length 7 | probabilities_per_context_length_path = snakemake.input[1] 8 | probabilities_per_context_length = pd.read_parquet( 9 | probabilities_per_context_length_path 10 | ) 11 | 12 | print(probabilities_per_context_length.shape) 13 | 14 | # load config 15 | chromosome = snakemake.wildcards.chromosome 16 | masked_position = int(snakemake.wildcards.position) 17 | window_size = snakemake.config["PREDICTION_VARIANCE_WINDOW_SIZE"] 18 | prediction_variance_threshold = float(snakemake.config["PREDICTION_VARIANCE_THRESHOLD"]) 19 | 20 | # load sequence 21 | sequence_path = snakemake.input[0] 22 | genome = load_fasta(sequence_path) 23 | sequence = genome[chromosome] 24 | 25 | title = ( 26 | f"(chr: {chromosome}, pos: {masked_position}, ref: {sequence[masked_position - 1]})" 27 | ) 28 | 29 | rolling_var = helpers.rolling_variance( 30 | probabilities_per_context_length, window_size=window_size 31 | ) 32 | 33 | threshold_step = rolling_var.index[ 34 | rolling_var.sum(axis=1) < prediction_variance_threshold 35 | ].min() 36 | 37 | 38 | helpers.plot_stacked_area( 39 | rolling_var, 40 | snakemake.output[0], 41 | snakemake.wildcards.format, 42 | title=title, 43 | xlabel="Total Context (bp)", 44 | ylabel="Variance of Prediction", 45 | marker=threshold_step, 46 | ) 47 | -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/scripts/stats_helpers.py: -------------------------------------------------------------------------------- 1 | 
import numpy as np
2 | 
3 | 
4 | def rolling_variance(results, window_size=100):
5 |     # we compute the variance over the forward looking window, i.e. if in the future nothing changes, we don't need more context size
6 |     return results[::-1].rolling(window=window_size, min_periods=1).var(ddof=0)[::-1]
7 | 
8 | 
9 | def find_context_size_step_with_total_prediction_variance_below_threshold(
10 |     results, window_size=100, threshold=1e-5
11 | ):
12 |     rolling_var = rolling_variance(results, window_size)
13 |     return rolling_var.index[rolling_var.sum(axis=1) < threshold].min()
14 | 
15 | 
16 | def compute_gpn_score(reference_nucleotide, probabilities_per_context_length):
17 | 
18 |     reference_nucleotide = reference_nucleotide.lower()
19 | 
20 |     nucleotides = ["a", "c", "g", "t"]
21 |     alternatives = [n for n in nucleotides if n != reference_nucleotide]
22 | 
23 |     gpn_scores = probabilities_per_context_length
24 | 
25 |     for alt in alternatives:
26 |         gpn_scores[f"gpn_{alt}"] = gpn_scores[alt] / gpn_scores[reference_nucleotide]
27 | 
28 |     gpn_scores = np.log2(gpn_scores).drop(nucleotides, axis=1)
29 |     return gpn_scores
30 | 
31 | def find_context_size_step_with_distribution_shift_below_threshold(results, window_size=10, threshold=0.01):
32 |     distribution_shift = results.diff().abs().sum(axis=1).div(2)[1:]
33 |     roll_avg_diff = distribution_shift[::-1].rolling(window=window_size, min_periods=1).mean()[::-1]
34 |     return roll_avg_diff.index[roll_avg_diff < threshold].min()
35 | 
--------------------------------------------------------------------------------
/experiments/context_length_prediction_impact/scripts/display_stacked_probability_chart.py:
--------------------------------------------------------------------------------
1 | import helpers
2 | import pandas as pd
3 | from snakemake.script import snakemake
4 | from gpn.data import load_fasta
5 | 
6 | # load probabilities_per_context_length
7 | probabilities_per_context_length_path = snakemake.input[1]
8 | probabilities_per_context_length = pd.read_parquet(
9 |     probabilities_per_context_length_path
10 | )
11 | 
12 | print(probabilities_per_context_length.shape)
13 | 
14 | # load config
15 | chromosome = snakemake.wildcards.chromosome
16 | masked_position = int(snakemake.wildcards.position)
17 | window_size = snakemake.config["PREDICTION_VARIANCE_WINDOW_SIZE"]
18 | prediction_variance_threshold = float(snakemake.config["PREDICTION_VARIANCE_THRESHOLD"])
19 | 
20 | # load sequence
21 | sequence_path = snakemake.input[0]
22 | genome = load_fasta(sequence_path)
23 | sequence = genome[chromosome]
24 | 
25 | title = f"Influence of context length \n (chr: {chromosome}, pos: {masked_position}, ref: {sequence[masked_position - 1]})"
26 | 
27 | threshold_step = (
28 |     helpers.find_context_size_step_with_total_prediction_variance_below_threshold(
29 |         probabilities_per_context_length,
30 |         window_size=window_size,
31 |         threshold=prediction_variance_threshold,
32 |     )
33 | )
34 | 
35 | helpers.plot_stacked_area(
36 |     probabilities_per_context_length,
37 |     snakemake.output[0],
38 |     snakemake.wildcards.format,
39 |     title=title,
40 |     xlabel="Total Context (bp)",
41 |     ylabel="Predicted Distribution",
42 |     marker=threshold_step,
43 | )
44 | 
--------------------------------------------------------------------------------
/experiments/context_length_prediction_impact/scripts/compute_distribution_shift_over_chromosome.py:
--------------------------------------------------------------------------------
1 | import stats_helpers
2 | import pandas as pd
3 | from tqdm import tqdm
4 | from snakemake.script import snakemake
5 | import gffutils 6 | 7 | # load config 8 | distribution_shift_window_size = snakemake.config["DISTRIBUTION_SHIFT_WINDOW_SIZE"] 9 | distribution_shift_threshold = float(snakemake.config["DISTRIBUTION_SHIFT_THRESHOLD"]) 10 | 11 | chromosome = snakemake.wildcards.chromosome 12 | 13 | # load annotation db 14 | db = gffutils.FeatureDB(snakemake.input.annotation_db) 15 | 16 | 17 | def get_position_feature_type(position, chrom=chromosome): 18 | overlapping_features = list(db.region(seqid=chrom, start=position, end=position)) 19 | if not overlapping_features: 20 | return "unknown" 21 | overlapping_features.sort(key=lambda f: f.end - f.start + 1) 22 | return overlapping_features[0].featuretype 23 | 24 | 25 | threshold_steps = [] 26 | positions = [] 27 | 28 | for df_path in tqdm(snakemake.input.position_data, desc="Random Position", position=0): 29 | results = pd.read_parquet(df_path) 30 | 31 | position = int(df_path.split("/")[-2]) 32 | 33 | positions.append(position) 34 | 35 | threshold_steps.append( 36 | stats_helpers.find_context_size_step_with_distribution_shift_below_threshold( 37 | results, 38 | window_size=distribution_shift_window_size, 39 | threshold=distribution_shift_threshold, 40 | ) 41 | ) 42 | 43 | df = pd.DataFrame({"position": positions, "threshold_steps": threshold_steps}) 44 | 45 | df["feature"] = df["position"].map(get_position_feature_type) 46 | 47 | df.to_parquet(snakemake.output[0]) 48 | -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/scripts/compute_prediction_variance_over_chromosome copy.py: -------------------------------------------------------------------------------- 1 | import stats_helpers 2 | import pandas as pd 3 | from tqdm import tqdm 4 | from snakemake.script import snakemake 5 | import gffutils 6 | 7 | # load config 8 | prediction_variance_window_size = snakemake.config["PREDICTION_VARIANCE_WINDOW_SIZE"] 9 | prediction_variance_threshold = float(snakemake.config["PREDICTION_VARIANCE_THRESHOLD"]) 10 | 11 | chromosome = snakemake.wildcards.chromosome 12 | 13 | # load annotation db 14 | db = gffutils.FeatureDB(snakemake.input.annotation_db) 15 | 16 | 17 | def get_position_feature_type(position, chrom=chromosome): 18 | overlapping_features = list(db.region(seqid=chrom, start=position, end=position)) 19 | if not overlapping_features: 20 | return "unknown" 21 | overlapping_features.sort(key=lambda f: f.end - f.start + 1) 22 | return overlapping_features[0].featuretype 23 | 24 | 25 | threshold_steps = [] 26 | positions = [] 27 | 28 | for df_path in tqdm(snakemake.input.position_data, desc="Random Position", position=0): 29 | results = pd.read_parquet(df_path) 30 | 31 | position = int(df_path.split("/")[-2]) 32 | 33 | positions.append(position) 34 | 35 | threshold_steps.append( 36 | stats_helpers.find_context_size_step_with_total_prediction_variance_below_threshold( 37 | results, 38 | window_size=prediction_variance_window_size, 39 | threshold=prediction_variance_threshold, 40 | ) 41 | ) 42 | 43 | df = pd.DataFrame({"position": positions, "threshold_steps": threshold_steps}) 44 | 45 | df["feature"] = df["position"].map(get_position_feature_type) 46 | 47 | df.to_parquet(snakemake.output[0]) 48 | -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/scripts/run_classification.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from sklearn.linear_model 
import LogisticRegression, LogisticRegressionCV 4 | from sklearn.model_selection import cross_val_predict, LeaveOneGroupOut 5 | from sklearn.pipeline import Pipeline 6 | from sklearn.preprocessing import StandardScaler 7 | from joblib import Memory 8 | 9 | memory = Memory(location="/scratch/sbuedenb/cache/.sk_cache", verbose=0) 10 | 11 | windows = pd.read_parquet(snakemake.input[0]) 12 | features = pd.read_parquet(snakemake.input[1]) 13 | 14 | want = [ 15 | # "NC_087395.1", # chr2 16 | # "NC_087396.1", # chr3 17 | # "NC_087397.1", # chr4 18 | # "NC_087398.1", # chr5 19 | # "NC_087399.1", # chr6 20 | "NC_087403.1", # chr10 21 | "NC_087404.1" # chr11 22 | ] 23 | 24 | # subset to selected chromosomes 25 | windows = windows[windows.chrom.isin(want)] 26 | features = features.loc[windows.index] 27 | 28 | X = features.to_numpy(dtype=np.float32) 29 | y = windows.Region.values 30 | groups = windows.chrom.values 31 | 32 | clf = Pipeline( 33 | [ 34 | ("scaler", StandardScaler(copy=False)), 35 | ( 36 | "linear", 37 | LogisticRegressionCV( 38 | solver="saga", 39 | random_state=42, 40 | verbose=2, 41 | max_iter=500, 42 | class_weight="balanced", 43 | n_jobs=1, 44 | ), 45 | ), 46 | ], 47 | memory=memory, 48 | ) 49 | 50 | preds = cross_val_predict( 51 | clf, 52 | X, 53 | y, 54 | groups=groups, 55 | cv=LeaveOneGroupOut(), 56 | verbose=2, 57 | n_jobs=-1, 58 | ) 59 | 60 | pd.DataFrame({"pred_Region": preds}).to_parquet(snakemake.output[0], index=False) 61 | -------------------------------------------------------------------------------- /scripts/nucleotide_dependency_maps/Snakefile: -------------------------------------------------------------------------------- 1 | configfile: "config.yaml" 2 | 3 | 4 | assert config["START_POSITION"] < config["END_POSITION"] 5 | assert (config["END_POSITION"] - config["START_POSITION"]) - config[ 6 | "DISPLAY_MAP_START_RELATIVE" 7 | ] + config["DISPLAY_MAP_END_RELATIVE"] + 1 > 0 8 | 9 | 10 | rule all: 11 | input: 12 | expand( 13 | "output/display_dependency_map/{model}/{sequence}/{chromosome}/seq_{seq_start}-{seq_end}/dp_{map_start}-{map_end}.png", 14 | model=config["MODELS"], 15 | sequence=config["SEQUENCES"], 16 | chromosome=config["CHROMOSOME"], 17 | seq_start=config["START_POSITION"], 18 | seq_end=config["END_POSITION"], 19 | map_start=config["DISPLAY_MAP_START_RELATIVE"], 20 | map_end=config["DISPLAY_MAP_END_RELATIVE"], 21 | ), 22 | 23 | 24 | rule download_fasta: 25 | output: 26 | "output/download_fasta/{sequence}.fna", 27 | params: 28 | id=lambda wildcards: wildcards.sequence, 29 | db="nuccore", 30 | format="fasta", 31 | wrapper: 32 | "v5.8.2/bio/entrez/efetch" 33 | 34 | 35 | rule compute_dependency_map: 36 | input: 37 | "output/download_fasta/{sequence}.fna", 38 | params: 39 | model=lambda wildcards: wildcards.model, 40 | output: 41 | "output/compute_dependency_map/{model}/{sequence}/{chromosome}/seq_{seq_start}-{seq_end}.parquet", 42 | script: 43 | "scripts/compute_dependency_map.py" 44 | 45 | 46 | rule display_dependency_map: 47 | input: 48 | "output/download_fasta/{sequence}.fna", 49 | "output/compute_dependency_map/{model}/{sequence}/{chromosome}/seq_{seq_start}-{seq_end}.parquet", 50 | output: 51 | "output/display_dependency_map/{model}/{sequence}/{chromosome}/seq_{seq_start}-{seq_end}/dp_{map_start}-{map_end}.{format}", 52 | script: 53 | "scripts/display_dependency_map.py" 54 | -------------------------------------------------------------------------------- /scripts/data_preparation/config/annotated_chrom+_ncbi_refseq_cucujiformia_assemblies.tsv: 
-------------------------------------------------------------------------------- 1 | Assembly Accession Assembly Name Organism Name Organism Infraspecific Names Breed Organism Infraspecific Names Strain Organism Infraspecific Names Cultivar Organism Infraspecific Names Ecotype Organism Infraspecific Names Isolate Organism Infraspecific Names Sex Annotation Name Assembly Level Assembly Release Date WGS project accession Assembly Stats Number of Scaffolds 2 | GCF_917563875.1 PGI_DIABVI_V3a Diabrotica virgifera virgifera NCBI Diabrotica virgifera virgifera Annotation Release 101 Chromosome 2022-07-29 10 3 | GCF_024364675.1 icAetTumi1.1 Aethina tumida Nest 87 male NCBI Aethina tumida Annotation Release 101 Chromosome 2022-07-22 JALKMD01 8 4 | GCF_031307605.1 icTriCast1.1 Tribolium castaneum GA2 male GCF_031307605.1-RS_2024_04 Chromosome 2023-09-14 JANKOB01 148 5 | GCF_914767665.1 icHarAxyr1.1 Harmonia axyridis NCBI Harmonia axyridis Annotation Release 100 Chromosome 2021-09-16 CAJZBN01 13 6 | GCF_963966145.1 icTenMoli1.1 Tenebrio molitor GCF_963966145.1-RS_2024_10 Chromosome 2024-02-12 CAWYQD01 236 7 | GCF_907165205.1 icCocSept1.1 Coccinella septempunctata NCBI Coccinella septempunctata Annotation Release 100 Chromosome 2021-05-17 CAJRAZ01 24 8 | GCF_040115645.1 ASM4011564v1 Euwallacea fornicatus EFF26 female GCF_040115645.1-RS_2024_07 Chromosome 2024-06-11 JBBMRQ01 185 9 | GCF_022605725.1 icAntGran1.3 Anthonomus grandis grandis male NCBI Anthonomus grandis grandis Annotation Release 100 Chromosome 2022-08-03 JAKYJU02 304 10 | GCF_026250575.1 icDioCari1.1 Diorhabda carinulata Delta male GCF_026250575.1-RS_2023_06 Chromosome 2022-11-21 JAOVUX01 68 11 | GCF_026230105.1 icDioSubl1.1 Diorhabda sublineata icDioSubl1.1 male GCF_026230105.1-RS_2023_05 Chromosome 2022-11-18 JAPHNJ01 310 12 | GCF_040954645.1 icDiaUnde3 Diabrotica undecimpunctata CICGRU male GCF_040954645.1-RS_2025_03 Chromosome 2024-08-01 JBDMBP01 2867 13 | 14 | GCF_039881205.1 ESF131.1 Euwallacea similis ESF13 female GCF_039881205.1-RS_2024_07 Chromosome 2024-05-29 JBBMRO01 128 15 | -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/scripts/compute_gpn_statistics_over_chromosome.py: -------------------------------------------------------------------------------- 1 | import stats_helpers 2 | import pandas as pd 3 | from tqdm import tqdm 4 | from gpn.data import load_fasta 5 | from snakemake.script import snakemake 6 | import gffutils 7 | 8 | # load config 9 | prediction_variance_window_size = snakemake.config["PREDICTION_VARIANCE_WINDOW_SIZE"] 10 | prediction_variance_threshold = float(snakemake.config["PREDICTION_VARIANCE_THRESHOLD"]) 11 | 12 | chromosome = snakemake.wildcards.chromosome 13 | 14 | # load annotation db 15 | db = gffutils.FeatureDB(snakemake.input.annotation_db) 16 | 17 | 18 | def get_position_feature_type(position, chrom=chromosome): 19 | overlapping_features = list(db.region(seqid=chrom, start=position, end=position)) 20 | if not overlapping_features: 21 | return "unknown" 22 | overlapping_features.sort(key=lambda f: f.end - f.start + 1) 23 | return overlapping_features[0].featuretype 24 | 25 | 26 | # load sequence 27 | sequence_path = snakemake.input.sequence 28 | genome = load_fasta(sequence_path) 29 | sequence = genome[chromosome] 30 | 31 | threshold_steps = [] 32 | positions = [] 33 | 34 | for df_path in tqdm(snakemake.input.position_data, desc="Random Position", position=0): 35 | results = pd.read_parquet(df_path) 36 | 37 | position = 
int(df_path.split("/")[-2]) 38 | 39 | reference_nucleotide = sequence[position - 1] 40 | 41 | # skip position when reference is unknown 42 | if reference_nucleotide in ["n", "N"]: 43 | continue 44 | 45 | gpn_scores = stats_helpers.compute_gpn_score(reference_nucleotide, results) 46 | 47 | positions.append(position) 48 | 49 | threshold_steps.append( 50 | stats_helpers.find_context_size_step_with_total_prediction_variance_below_threshold( 51 | results, 52 | window_size=prediction_variance_window_size, 53 | threshold=prediction_variance_threshold, 54 | ) 55 | ) 56 | 57 | df = pd.DataFrame({"position": positions, "threshold_steps": threshold_steps}) 58 | 59 | df["feature"] = df["position"].map(get_position_feature_type) 60 | 61 | df.to_parquet(snakemake.output[0]) 62 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # see: https://github.com/songlab-cal/gpn 2 | gpn @ https://github.com/songlab-cal/gpn/archive/refs/tags/0.6.zip 3 | 4 | # see: https://matplotlib.org 5 | matplotlib == 3.9.2 6 | 7 | # see here: https://seaborn.pydata.org 8 | seaborn == 0.13.2 9 | 10 | # see here: https://pandas.pydata.org 11 | pandas == 2.2.3 12 | 13 | # see here: https://scikit-learn.org 14 | scikit-learn == 1.5.2 15 | 16 | # see here: https://pytorch.org 17 | torch == 2.5.1 18 | 19 | # see here: https://huggingface.co/docs/transformers/index 20 | transformers == 4.51.3 21 | 22 | evaluate == 0.4.3 23 | 24 | # see here: https://docs.jupyter.org/en/latest/ 25 | jupyter == 1.1.1 26 | 27 | # see here: https://pytorch.org/vision/stable/index.html 28 | torchvision == 0.20.1 29 | 30 | # see here: https://snakemake.github.io 31 | snakemake == 8.25.5 32 | 33 | # see here: https://github.com/althonos/pymemesuite 34 | pymemesuite == 0.1.0a3 35 | 36 | # see here: https://github.com/jbkinney/logomaker 37 | logomaker == 0.8 38 | 39 | # see here: https://sgkit-dev.github.io/sgkit/latest/ 40 | sgkit == 0.9.0 41 | 42 | # see here: https://github.com/brentp/cyvcf2 43 | cyvcf2 == 0.31.1 44 | 45 | # see here: https://github.com/h5py/h5py 46 | h5py == 3.12.1 47 | 48 | # see here: https://github.com/more-itertools/more-itertools 49 | more-itertools == 10.5.0 50 | 51 | # see here: https://scipy.org 52 | scipy == 1.14.1 53 | 54 | # see here: https://www.statsmodels.org/stable/index.html 55 | statsmodels == 0.14.0 56 | 57 | # see here: https://umap-learn.readthedocs.io/en/latest/basic_usage.html 58 | umap-learn[plot] == 0.5.7 59 | 60 | # see here: https://scanpy.readthedocs.io/en/stable/ 61 | scanpy == 1.10.4 62 | 63 | # see here: https://python.igraph.org/en/latest/index.html 64 | igraph == 0.11.8 65 | 66 | # see here: https://github.com/vtraag/leidenalg 67 | leidenalg == 0.10.2 68 | 69 | # - pytorch-cuda ? 
# - snakemake
# - pymemesuite
# - logomaker
# - sgkit
# - cyvcf2
# - h5py
# - more-itertools
# - scipy
# - statsmodels
# - umap-learn
# - scanpy
# - python-igraph
# - leidenalg
# - ncbi-datasets-cli => standalone cli tool

onnx

onnxscript
--------------------------------------------------------------------------------
/scripts/data_preparation/workflow/rules/intervals.smk:
--------------------------------------------------------------------------------
conda: "../envs/global.yaml"

import pandas as pd  # used in make_annotation_intervals below
from gpn.data import (
    Genome,
    load_table,
    get_balanced_intervals,
    filter_length,
    filter_annotation_features,
)


rule make_all_intervals:
    input:
        "results/genome/{assembly}.fa.gz",
    output:
        "results/intervals/{assembly}/all.parquet",
    threads: 2
    run:
        I = Genome(input[0]).get_all_intervals()
        I = filter_length(I, config["window_size"])
        I.to_parquet(output[0], index=False)


rule make_defined_intervals:
    input:
        "results/genome/{assembly}.fa.gz",
    output:
        "results/intervals/{assembly}/defined.parquet",
    threads: 2
    run:
        I = Genome(input[0]).get_defined_intervals()
        I = filter_length(I, config["window_size"])
        I.to_parquet(output[0], index=False)


rule make_annotation_intervals:
    input:
        "results/intervals/{assembly}/defined.parquet",
        "results/annotation/{assembly}.gff.gz",
    output:
        "results/intervals/{assembly}/annotation_{feature}.parquet",
    run:
        I = pd.read_parquet(input[0])
        annotation = load_table(input[1])
        include_flank = config.get(
            "annotation_features_include_flank", config["window_size"] // 2
        )
        add_jitter = config.get("annotation_features_add_jitter", 100)
        I = filter_annotation_features(
            I,
            annotation,
            wildcards.feature,
            include_flank=include_flank,
            jitter=add_jitter,
        )
        I = filter_length(I, config["window_size"])
        I.to_parquet(output[0], index=False)


rule make_balanced_v1_intervals:
    input:
        "results/intervals/{assembly}/defined.parquet",
        "results/annotation/{assembly}.gff.gz",
    output:
        "results/intervals/{assembly}/balanced_v1.parquet",
    run:
        defined_intervals = load_table(input[0])
        annotation = load_table(input[1])
        intervals = get_balanced_intervals(
            defined_intervals,
            annotation,
            config["window_size"],
            config.get("promoter_upstream", 1000),
        )
        intervals.to_parquet(output[0], index=False)
--------------------------------------------------------------------------------
/scripts/arabidopsis_halleri/scripts/expand_annotation.py:
--------------------------------------------------------------------------------
import pandas as pd
import bioframe as bf
from snakemake.script import snakemake

# from gtfparse import read_gtf


# gtf = read_gtf(snakemake.input[0], expand_attribute_column=False)

gtf = pd.read_csv(
    snakemake.input[0],
    sep="\t",
    header=None,
    comment="#",
    dtype={"chrom": str},
    names=[
        "chrom",
        "source",
        "feature",
        "start",
        "end",
        "score",
        "strand",
        "frame",
        "attribute",
    ],
)

# convert 1-based GFF starts to 0-based, matching the convention of gpn.data.load_table
gtf.start -= 1

# add missing region entries
for chrom in ["chr1", "chr2", "chr3", "chr4", "chr5", "chr6", "chr7", "chr8"]:
    start = 0
    end = gtf[gtf["chrom"] == chrom]["end"].max()
    region_entry = {
        "chrom": chrom,
        "source": ".",
        "feature": "region",
        "start": start,
        "end": end,
        "score": ".",
        "strand": "+",
        "frame": ".",
        "attribute": ".",
    }
    gtf.loc[len(gtf)] = region_entry


genic_features = [
    "gene",
    "transcript",
]

chrom_regions = gtf[gtf.feature == "region"][["chrom", "start", "end"]]

genic_intervals = gtf[gtf.feature.isin(genic_features)][["chrom", "start", "end"]]

genic_intervals = bf.merge(genic_intervals)


intergenic = bf.subtract(chrom_regions, genic_intervals)
# subtract uses the end of the subtracted interval as the new start, it seems
intergenic["start"] = intergenic["start"] + 1
intergenic["feature"] = "intergenic"

gtf = pd.concat([gtf, intergenic], ignore_index=True)


gtf_exon = gtf[gtf.feature == "exon"]
exonic_intervals = bf.merge(gtf_exon)[["chrom", "start", "end"]]
intronic_intervals = bf.subtract(genic_intervals, exonic_intervals)
# subtract uses the end of the subtracted interval as the new start, it seems
intronic_intervals["start"] = intronic_intervals["start"] + 1

intronic_intervals["feature"] = "intron"

gtf = pd.concat([gtf, intronic_intervals], ignore_index=True)


gtf_cds = gtf[gtf.feature == "CDS"]
gene_cds_overlap = bf.overlap(genic_intervals, gtf_cds)
non_coding_genes = gene_cds_overlap[gene_cds_overlap["chrom_"].isnull()][
    ["chrom", "start", "end"]
]
non_coding_genes["feature"] = "ncRNA_gene"


gtf = pd.concat([gtf, non_coding_genes], ignore_index=True)

gtf = gtf.drop_duplicates(subset=["chrom", "start", "end", "feature"])

gtf.to_parquet(snakemake.output[0], index=False)
--------------------------------------------------------------------------------
/scripts/arabidopsis_halleri/scripts/run_classification.py:
--------------------------------------------------------------------------------
# import pandas as pd
# import numpy as np
# from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
# from sklearn.model_selection import cross_val_predict, LeaveOneGroupOut
# from sklearn.pipeline import Pipeline
# from sklearn.preprocessing import StandardScaler
# from joblib import Memory

import pandas as pd
from sklearn.linear_model import LogisticRegressionCV
from sklearn.model_selection import cross_val_predict, LeaveOneGroupOut
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

windows = pd.read_parquet(snakemake.input[0])
features = pd.read_parquet(snakemake.input[1])


clf = Pipeline(
    [
        ("scaler", StandardScaler()),
        (
            "linear",
            LogisticRegressionCV(
                random_state=42,
                verbose=True,
                max_iter=1000,
                class_weight="balanced",
                n_jobs=-1,
            ),
        ),
    ]
)
preds = cross_val_predict(
    clf,
    features,
    windows.Region,
    groups=windows.chrom,
    cv=LeaveOneGroupOut(),
    verbose=True,
)
pd.DataFrame({"pred_Region": preds}).to_parquet(snakemake.output[0], index=False)

# memory = Memory(location="/scratch/sbuedenb/cache/.sk_cache", verbose=0)

# windows = pd.read_parquet(snakemake.input[0]) 47 | # features = pd.read_parquet(snakemake.input[1]) 48 | 49 | # want = [ 50 | # # "NC_087395.1", # chr2 51 | # # "NC_087396.1", # chr3 52 | # # "NC_087397.1", # chr4 53 | # # "NC_087398.1", # chr5 54 | # # "NC_087399.1", # chr6 55 | # "NC_087403.1", # chr10 56 | # "NC_087404.1" # chr11 57 | # ] 58 | 59 | # # subset to selected chromosomes 60 | # windows = windows[windows.chrom.isin(want)] 61 | # features = features.loc[windows.index] 62 | 63 | # X = features.to_numpy(dtype=np.float32) 64 | # y = windows.Region.values 65 | # groups = windows.chrom.values 66 | 67 | # clf = Pipeline( 68 | # [ 69 | # ("scaler", StandardScaler(copy=False)), 70 | # ( 71 | # "linear", 72 | # LogisticRegressionCV( 73 | # solver="saga", 74 | # random_state=42, 75 | # verbose=2, 76 | # max_iter=500, 77 | # class_weight="balanced", 78 | # n_jobs=1, 79 | # ), 80 | # ), 81 | # ], 82 | # memory=memory, 83 | # ) 84 | 85 | # preds = cross_val_predict( 86 | # clf, 87 | # X, 88 | # y, 89 | # groups=groups, 90 | # cv=LeaveOneGroupOut(), 91 | # verbose=2, 92 | # n_jobs=-1, 93 | # ) 94 | 95 | # pd.DataFrame({"pred_Region": preds}).to_parquet(snakemake.output[0], index=False) 96 | -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/scripts/expand_annotation.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import bioframe as bf 3 | import re 4 | from gpn.data import load_fasta, load_table 5 | # from gtfparse import read_gtf 6 | 7 | 8 | # gtf = read_gtf(snakemake.input[0], expand_attribute_column=False) 9 | 10 | gtf = load_table(snakemake.input[0]) 11 | 12 | genome = load_fasta(snakemake.input[1]) 13 | 14 | # extract repeat regions from ref. 
genome 15 | for chrom in gtf.chrom.unique(): 16 | matches = list(re.finditer(r'[a-z]+', genome[chrom])) 17 | stretches = [(m.start(), m.end()) for m in matches] 18 | repeats = pd.DataFrame( 19 | stretches, columns=["start", "end"], dtype="int64" 20 | ) 21 | repeats["chrom"] = chrom 22 | 23 | repeats["feature"] = "Repeat" 24 | repeats["source"] = "Derived" 25 | repeats["strand"] = "+" 26 | gtf = pd.concat([gtf, repeats], ignore_index=True) 27 | 28 | 29 | genic_features = [ 30 | "gene", 31 | ### entirely covered from gene 32 | # "mRNA", 33 | # "CDS", 34 | # "ncRNA", 35 | # "transcript", 36 | # "lnc_RNA", 37 | # "primary_transcript", 38 | # "tRNA", 39 | # "snRNA", 40 | # "miRNA", 41 | # "rRNA", 42 | # "snoRNA", 43 | # "piRNA", 44 | # "cDNA_match", 45 | 46 | ### partially covered from gene 47 | "five_prime_UTR", 48 | "three_prime_UTR", 49 | "exon", 50 | "pseudogene", 51 | 52 | ### not conclusively genic 53 | # "Repeat" 54 | ] 55 | 56 | chrom_regions = gtf[gtf.feature == "region"][["chrom", "start", "end"]] 57 | 58 | genic_intervals = gtf[gtf.feature.isin(genic_features)][ 59 | ["chrom", "start", "end"] 60 | ] 61 | 62 | genic_intervals = bf.merge(genic_intervals) 63 | 64 | 65 | intergenic = bf.subtract(chrom_regions, genic_intervals) 66 | # subtract uses end of subtracted interval as start, it seems 67 | intergenic['start'] = intergenic['start'] + 1 68 | intergenic["feature"] = "intergenic" 69 | 70 | gtf = pd.concat([gtf, intergenic], ignore_index=True) 71 | 72 | 73 | gtf_exon = gtf[gtf.feature == "exon"] 74 | exonic_intervals = bf.merge(gtf_exon)[["chrom", "start", "end"]] 75 | intronic_intervals = bf.subtract(genic_intervals, exonic_intervals) 76 | # subtract uses end of subtracted interval as start, it seems 77 | intronic_intervals['start'] = intronic_intervals['start'] + 1 78 | 79 | intronic_intervals["feature"] = "intron" 80 | 81 | gtf = pd.concat([gtf, intronic_intervals], ignore_index=True) 82 | 83 | 84 | gtf_cds = gtf[gtf.feature=="CDS"] 85 | gene_cds_overlap = bf.overlap(genic_intervals, gtf_cds) 86 | non_coding_genes = gene_cds_overlap[gene_cds_overlap['chrom_'].isnull()][["chrom", "start", "end"]] 87 | non_coding_genes["feature"] = "ncRNA_gene" 88 | 89 | 90 | gtf = pd.concat([gtf, non_coding_genes], ignore_index=True) 91 | 92 | gtf = gtf.drop_duplicates(subset=["chrom", "start", "end", "feature"]) 93 | 94 | gtf.to_parquet(snakemake.output[0], index=False) -------------------------------------------------------------------------------- /scripts/data_preparation/workflow/rules/dataset.smk: -------------------------------------------------------------------------------- 1 | conda: "../envs/global.yaml" 2 | 3 | from gpn.data import Genome, make_windows, get_seq 4 | import math 5 | import numpy as np 6 | import os 7 | import pandas as pd 8 | from tqdm import tqdm 9 | 10 | 11 | split_proportions = [config["split_proportion"][split] for split in splits] 12 | assert np.isclose(sum(split_proportions), 1) 13 | 14 | 15 | rule make_dataset_assembly: 16 | input: 17 | lambda wildcards: f"results/intervals/{wildcards['assembly']}/{config['target_intervals']}.parquet", 18 | "results/genome/{assembly}.fa.gz", 19 | output: 20 | temp( 21 | expand( 22 | "results/dataset_assembly/{{assembly}}/{split}.parquet", split=splits 23 | ) 24 | ), 25 | threads: 2 26 | run: 27 | intervals = pd.read_parquet(input[0]) 28 | genome = Genome(input[1]) 29 | intervals = make_windows( 30 | intervals, 31 | config["window_size"], 32 | config["step_size"], 33 | config["add_rc"], 34 | ) 35 | print(intervals) 36 | intervals = 
intervals.sample(frac=1.0, random_state=42) 37 | intervals["assembly"] = wildcards["assembly"] 38 | intervals = intervals[["assembly", "chrom", "start", "end", "strand"]] 39 | intervals = get_seq(intervals, genome) 40 | print(intervals) 41 | 42 | chroms = intervals.chrom.unique() 43 | chrom_split = np.random.choice( 44 | splits, 45 | p=split_proportions, 46 | size=len(chroms), 47 | ) 48 | chrom_split[np.isin(chroms, config["whitelist_validation_chroms"])] = ( 49 | "validation" 50 | ) 51 | chrom_split[np.isin(chroms, config["whitelist_test_chroms"])] = "test" 52 | chrom_split = pd.Series(chrom_split, index=chroms) 53 | 54 | intervals_split = chrom_split[intervals.chrom] 55 | 56 | for path, split in zip(output, splits): 57 | print(path, split) 58 | # to parquet to be able to load faster later 59 | intervals[(intervals_split == split).values].to_parquet( 60 | path, 61 | index=False, 62 | ) 63 | 64 | 65 | # before uploading to HF Hub, remove data/split/.snakemake_timestamp files 66 | rule merge_datasets: 67 | input: 68 | expand( 69 | "results/dataset_assembly/{assembly}/{{split}}.parquet", 70 | assembly=assemblies.index, 71 | ), 72 | output: 73 | directory("results/dataset/data/{split}"), 74 | threads: workflow.cores 75 | run: 76 | intervals = pd.concat( 77 | tqdm((pd.read_parquet(path) for path in input), total=len(input)), 78 | ignore_index=True, 79 | ).sample(frac=1, random_state=42) 80 | print(intervals) 81 | 82 | if config.get("subsample_to_target", False) and wildcards.split == "train": 83 | n_target = (intervals.assembly == config["target_assembly"]).sum() 84 | intervals = ( 85 | intervals.groupby("assembly") 86 | .sample(n=n_target, random_state=42) 87 | .sample(frac=1, random_state=42) 88 | ) 89 | print(wildcards.split, intervals.assembly.value_counts()) 90 | print(intervals) 91 | 92 | n_shards = math.ceil(len(intervals) / config["samples_per_file"]) 93 | assert n_shards < 10000 94 | os.makedirs(output[0]) 95 | for i in tqdm(range(n_shards)): 96 | path = Path(output[0]) / f"shard_{i:05}.jsonl.zst" 97 | intervals.iloc[i::n_shards].to_json( 98 | path, 99 | orient="records", 100 | lines=True, 101 | compression={"method": "zstd", "threads": -1}, 102 | ) 103 | -------------------------------------------------------------------------------- /scripts/high_throughput_gpn_computation/workflow/Snakefile: -------------------------------------------------------------------------------- 1 | configfile: "config/config.yaml" 2 | 3 | 4 | container: "docker://continuumio/miniconda3" 5 | 6 | 7 | include: "rules/compute_gpn.smk" 8 | 9 | 10 | configurations = [ 11 | { 12 | "accession": "BAC", 13 | "chromosome": "EU382073.1", 14 | "reverse_complement": False, 15 | "positions": [ 16 | {"start": 64809, "stop": 72891}, 17 | {"start": 145542, "stop": 154625}, 18 | {"start": 200642, "stop": 210233}, 19 | ], 20 | }, 21 | { 22 | "accession": "Bors_12", 23 | "chromosome": "h1tg000001l:1539155-1704685", 24 | "reverse_complement": True, 25 | "positions": [ 26 | {"start": 19296, "stop": 28595}, 27 | {"start": 81970, "stop": 92289}, 28 | {"start": 143820, "stop": 151878}, 29 | ], 30 | }, 31 | { 32 | "accession": "Goli_08", 33 | "chromosome": "h1tg000022l:224290-393930", 34 | "reverse_complement": False, 35 | "positions": [ 36 | {"start": 11995, "stop": 21725}, 37 | {"start": 81338, "stop": 90143}, 38 | {"start": 140172, "stop": 150343}, 39 | ], 40 | }, 41 | { 42 | "accession": "Lan3.1", 43 | "chromosome": "chr3:21866677-22051473", 44 | "reverse_complement": True, 45 | "positions": [ 46 | {"start": 20001, "stop": 
29622}, 47 | {"start": 74981, "stop": 84046}, 48 | {"start": 156715, "stop": 164797}, 49 | ], 50 | },{ 51 | "accession": "Lan3.1", 52 | "chromosome": "AHB2_Lan3.1_chr3:4241036-4244757", 53 | "reverse_complement": False, 54 | "positions": [ 55 | {"start": 0, "stop": 9999} 56 | ], 57 | },{ 58 | "accession": "Lan3.1", 59 | "chromosome": "AHB1_Lan3.1_chr3:19529095-19532314", 60 | "reverse_complement": False, 61 | "positions": [ 62 | {"start": 0, "stop": 9999} 63 | ], 64 | }, 65 | { 66 | "accession": "Lan5", 67 | "chromosome": "Chr3:26654299-26831685", 68 | "reverse_complement": False, 69 | "positions": [ 70 | {"start": 20001, "stop": 28083}, 71 | {"start": 100746, "stop": 109829}, 72 | {"start": 147765, "stop": 157387}, 73 | ], 74 | }, 75 | { 76 | "accession": "Lan5_hap2", 77 | "chromosome": "Chr3:23465017-23625889", 78 | "reverse_complement": False, 79 | "positions": [ 80 | {"start": 20001, "stop": 28542}, 81 | {"start": 76736, "stop": 85696}, 82 | {"start": 131258, "stop": 140873}, 83 | ], 84 | }, 85 | # { 86 | # "accession": "Noss_08", 87 | # "chromosome": "h1tg000006l:1479490-1622324", 88 | # "reverse_complement": False, 89 | # "positions": [ 90 | # ], 91 | # }, 92 | { 93 | "accession": "Pais_09", 94 | "chromosome": "tig00000673_chr3:1356930-1527379", 95 | "reverse_complement": True, 96 | "positions": [ 97 | {"start": 20001, "stop": 28957}, 98 | {"start": 66436, "stop": 76055}, 99 | {"start": 142436, "stop": 150450}, 100 | ], 101 | }, 102 | { 103 | "accession": "Rund_05_S1", 104 | "chromosome": "h2tg000015l:6652465-6810532", 105 | "reverse_complement": False, 106 | "positions": [ 107 | {"start": 13385, "stop": 21713}, 108 | {"start": 74557, "stop": 83507}, 109 | {"start": 129767, "stop": 138783}, 110 | ], 111 | }, 112 | { 113 | "accession": "Ukan_25", 114 | "chromosome": "h1tg000003l:24225942-24390488", 115 | "reverse_complement": False, 116 | "positions": [ 117 | {"start": 13390, "stop": 21722}, 118 | {"start": 83142, "stop": 92694}, 119 | {"start": 136322, "stop": 145258}, 120 | ], 121 | }, 122 | { 123 | "accession": "Wall_10", 124 | "chromosome": "tig00000401_chr3:6197730-6356479", 125 | "reverse_complement": False, 126 | "positions": [ 127 | {"start": 20001, "stop": 28083}, 128 | {"start": 74745, "stop": 83690}, 129 | {"start": 129160, "stop": 138750}, 130 | ], 131 | }, 132 | # Add more configurations as needed 133 | ] 134 | 135 | 136 | rule all: 137 | input: 138 | [ 139 | "results/{accession}/{chromosome}/{start_position}_{stop_position}_{reverse_complement}/gpn_scores.parquet".format( 140 | accession=con["accession"], 141 | chromosome=con["chromosome"], 142 | reverse_complement="rev" if con["reverse_complement"] else "fwd", 143 | start_position=position["start"], 144 | stop_position=position["stop"], 145 | ) 146 | for con in configurations 147 | for position in con["positions"] 148 | ], 149 | 150 | 151 | # expand( 152 | # "results/{accession}/{chromosome}/{start_position}_{stop_position}/gpn_scores.parquet", 153 | # accession=config["ACCESSION"], 154 | # chromosome=config["CHROMOSOME"], 155 | # start_position=config["START_POSITION"], 156 | # stop_position=config["STOP_POSITION"], 157 | # ), 158 | -------------------------------------------------------------------------------- /scripts/arabidopsis_halleri/Snakefile: -------------------------------------------------------------------------------- 1 | configfile: "config.yaml" 2 | 3 | 4 | conda: "envs/global.yaml" 5 | 6 | 7 | localrules: 8 | # download_reference, 9 | # download_annotation, 10 | download_utr_script, 11 | 
add_utr_to_annotation, 12 | expand_annotation, 13 | define_embedding_windows, 14 | all, 15 | 16 | 17 | # set to WANDB_MODE=disabled when not using it 18 | envvars: 19 | "WANDB_MODE", 20 | 21 | 22 | import pandas as pd 23 | from Bio import SeqIO 24 | import gzip 25 | import bioframe as bf 26 | from gpn.data import load_table, Genome, filter_length, make_windows 27 | import more_itertools 28 | 29 | 30 | WINDOW_SIZE = config["window_size"] 31 | EMBEDDING_WINDOW_SIZE = 100 32 | 33 | 34 | models = [ 35 | "gonzalobenegas/gpn-brassicales", 36 | ] 37 | 38 | 39 | rule all: 40 | input: 41 | expand("results/embedding/umap/{model}.parquet", model=models), 42 | # expand("results/embedding/classification/{model}.parquet", model=models), 43 | 44 | 45 | # rule download_reference: 46 | # output: 47 | # "output/genome.fa.gz", 48 | # shell: 49 | # "wget --no-check-certificate {config[FASTA_URL]} -O {output}" 50 | 51 | 52 | # rule download_annotation: 53 | # output: 54 | # "output/annotation.gtf.gz", 55 | # shell: 56 | # "wget --no-check-certificate {config[GTF_URL]} -O {output}" 57 | 58 | 59 | rule download_utr_script: 60 | params: 61 | url="https://ftp.ncbi.nlm.nih.gov/genomes/TOOLS/add_utrs_to_gff/add_utrs_to_gff.py", 62 | output: 63 | "resources/add_utrs_to_gff.py", 64 | shell: 65 | "wget --no-check-certificate {params.url} -O {output}" 66 | 67 | 68 | rule add_utr_to_annotation: 69 | input: 70 | "resources/add_utrs_to_gff.py", 71 | "resources/Lan3.1.genomic.gff3.gz", 72 | output: 73 | "results/annotation_utr.gtf", 74 | shell: 75 | "python {input} > {output}" 76 | 77 | 78 | rule expand_annotation: 79 | input: 80 | "results/annotation_utr.gtf", 81 | output: 82 | "results/annotation.expanded.parquet", 83 | conda: 84 | "envs/expand_annotation.yaml" 85 | script: 86 | "scripts/expand_annotation.py" 87 | 88 | 89 | rule define_embedding_windows: 90 | input: 91 | "results/annotation.expanded.parquet", 92 | "resources/Lan3.1.fna.gz", 93 | output: 94 | "results/embedding/windows.parquet", 95 | run: 96 | gtf = pd.read_parquet(input[0]) 97 | genome = Genome(input[1]) 98 | genome.filter_chroms( 99 | ["chr1", "chr2", "chr3", "chr4", "chr5", "chr6", "chr7", "chr8"] 100 | ) 101 | defined_intervals = genome.get_defined_intervals() 102 | defined_intervals = filter_length(defined_intervals, WINDOW_SIZE) 103 | windows = make_windows(defined_intervals, WINDOW_SIZE, EMBEDDING_WINDOW_SIZE) 104 | windows.rename(columns={"start": "full_start", "end": "full_end"}, inplace=True) 105 | 106 | windows["start"] = ( 107 | windows.full_start + windows.full_end 108 | ) // 2 - EMBEDDING_WINDOW_SIZE // 2 109 | windows["end"] = windows.start + EMBEDDING_WINDOW_SIZE 110 | 111 | features_of_interest = [ 112 | "intergenic", 113 | "CDS", 114 | "intron", 115 | # "three_prime_UTR", 116 | # "five_prime_UTR", 117 | "ncRNA_gene", 118 | # "Repeat", 119 | ] 120 | 121 | for f in features_of_interest: 122 | print(f) 123 | windows = bf.coverage(windows, gtf[gtf.feature == f]) 124 | windows.rename(columns=dict(coverage=f), inplace=True) 125 | 126 | # we keep if the center 100 bp are exactly covered by just on of the region of interest 127 | windows = windows[ 128 | (windows[features_of_interest] == EMBEDDING_WINDOW_SIZE).sum(axis=1) == 1 129 | ] 130 | windows["Region"] = windows[features_of_interest].idxmax(axis=1) 131 | windows.drop(columns=features_of_interest, inplace=True) 132 | 133 | windows.rename( 134 | columns={"start": "center_start", "end": "center_end"}, inplace=True 135 | ) 136 | windows.rename(columns={"full_start": "start", "full_end": "end"}, 
inplace=True) 137 | print(windows) 138 | windows.to_parquet(output[0], index=False) 139 | 140 | 141 | rule get_embeddings: 142 | input: 143 | "results/embedding/windows.parquet", 144 | "resources/Lan3.1.fna.gz", 145 | output: 146 | "results/embedding/embeddings/{model}.parquet", 147 | conda: 148 | "envs/gpn.yaml" 149 | threads: workflow.cores 150 | resources: 151 | slurm_extra="-G h100:1", 152 | shell: 153 | """ 154 | python -m gpn.ss.get_embeddings {input} {EMBEDDING_WINDOW_SIZE} \ 155 | {wildcards.model} {output} --per_device_batch_size 4000 --is_file \ 156 | --dataloader_num_workers {threads} 157 | """ 158 | 159 | 160 | rule run_umap: 161 | input: 162 | "{anything}/embeddings/{model}.parquet", 163 | output: 164 | "{anything}/umap/{model}.parquet", 165 | conda: 166 | "envs/umap.yaml" 167 | script: 168 | "scripts/umap.py" 169 | 170 | 171 | rule run_classification: 172 | input: 173 | "{anything}/windows.parquet", 174 | "{anything}/embeddings/{model}.parquet", 175 | output: 176 | "{anything}/classification/{model}.parquet", 177 | threads: workflow.cores 178 | resources: 179 | slurm_extra="--mail-user=sbuedenb@smail.uni-koeln.de --mail-type=END,FAIL" 180 | conda: 181 | "envs/run_classification.yaml" 182 | script: 183 | "scripts/run_classification.py" -------------------------------------------------------------------------------- /scripts/tribolium_castaneum/Snakefile: -------------------------------------------------------------------------------- 1 | configfile: "config.yaml" 2 | conda: "envs/global.yaml" 3 | 4 | 5 | localrules: 6 | download_reference, 7 | download_annotation, 8 | download_utr_script, 9 | add_utr_to_annotation, 10 | expand_annotation, 11 | all, 12 | 13 | # set to WANDB_MODE=disabled when not using it 14 | envvars: "WANDB_MODE" 15 | 16 | 17 | import pandas as pd 18 | from Bio import SeqIO 19 | import gzip 20 | import bioframe as bf 21 | from gpn.data import load_table, Genome, filter_length, make_windows 22 | import more_itertools 23 | 24 | 25 | WINDOW_SIZE = config["window_size"] 26 | EMBEDDING_WINDOW_SIZE = 100 27 | 28 | 29 | models = [ 30 | "gonzalobenegas/gpn-brassicales", 31 | "sbuedenb/beetle-gpn", 32 | # "sbuedenb/beetle-gpn-wide", 33 | "sbuedenb/beetle-gpn-wide-reduced" 34 | ] 35 | 36 | 37 | rule all: 38 | input: 39 | expand("output/embedding/umap/{model}.parquet", model=models), 40 | # expand("output/embedding/classification/{model}.parquet", model=models), 41 | 42 | 43 | rule download_reference: 44 | output: 45 | "output/genome.fa.gz", 46 | shell: 47 | "wget --no-check-certificate {config[FASTA_URL]} -O {output}" 48 | 49 | 50 | rule download_annotation: 51 | output: 52 | "output/annotation.gtf.gz", 53 | shell: 54 | "wget --no-check-certificate {config[GTF_URL]} -O {output}" 55 | 56 | rule download_utr_script: 57 | params: 58 | url = "https://ftp.ncbi.nlm.nih.gov/genomes/TOOLS/add_utrs_to_gff/add_utrs_to_gff.py" 59 | output: 60 | "output/add_utrs_to_gff.py", 61 | shell: 62 | "wget --no-check-certificate {params.url} -O {output}" 63 | 64 | rule add_utr_to_annotation: 65 | input: 66 | "output/add_utrs_to_gff.py", 67 | "output/annotation.gtf.gz", 68 | output: 69 | "output/annotation_utr.gtf", 70 | shell: 71 | "python {input} > {output}" 72 | 73 | 74 | rule expand_annotation: 75 | input: 76 | "output/annotation_utr.gtf", 77 | "output/genome.fa.gz", 78 | output: 79 | "output/annotation.expanded.parquet", 80 | conda: 81 | "envs/expand_annotation.yaml" 82 | script: 83 | "scripts/expand_annotation.py" 84 | 85 | 86 | rule define_embedding_windows: 87 | input: 88 | 
"output/annotation.expanded.parquet", 89 | "output/genome.fa.gz", 90 | output: 91 | "output/embedding/windows.parquet", 92 | run: 93 | gtf = pd.read_parquet(input[0]) 94 | genome = Genome(input[1]) 95 | genome.filter_chroms([ 96 | "NC_087394.1", 97 | "NC_087395.1", 98 | "NC_087396.1", 99 | "NC_087397.1", 100 | "NC_087398.1", 101 | "NC_087399.1", 102 | "NC_087400.1", 103 | "NC_087401.1", 104 | "NC_087402.1", 105 | "NC_087403.1", 106 | "NC_087404.1", 107 | # "NC_003081.2" 108 | ]) 109 | defined_intervals = genome.get_defined_intervals() 110 | defined_intervals = filter_length(defined_intervals, WINDOW_SIZE) 111 | windows = make_windows(defined_intervals, WINDOW_SIZE, EMBEDDING_WINDOW_SIZE) 112 | windows.rename(columns={"start": "full_start", "end": "full_end"}, inplace=True) 113 | 114 | windows["start"] = ( 115 | windows.full_start + windows.full_end 116 | ) // 2 - EMBEDDING_WINDOW_SIZE // 2 117 | windows["end"] = windows.start + EMBEDDING_WINDOW_SIZE 118 | 119 | # acts also as priority list if multiple annotations overlap 120 | features_of_interest = [ 121 | "three_prime_UTR", 122 | "five_prime_UTR", 123 | "CDS", 124 | "ncRNA_gene", 125 | "intron", 126 | "Repeat", 127 | "intergenic", 128 | ] 129 | 130 | for f in features_of_interest: 131 | print(f) 132 | windows = bf.coverage(windows, gtf[gtf.feature == f]) 133 | windows.rename(columns=dict(coverage=f), inplace=True) 134 | 135 | # we keep if the center 100 bp are exactly covered by at least one of the region of interest 136 | windows = windows[ 137 | (windows[features_of_interest] == EMBEDDING_WINDOW_SIZE).sum(axis=1) >= 1 138 | ] 139 | windows["Region"] = windows[features_of_interest].idxmax(axis=1) 140 | windows.drop(columns=features_of_interest, inplace=True) 141 | 142 | windows.rename( 143 | columns={"start": "center_start", "end": "center_end"}, inplace=True 144 | ) 145 | windows.rename(columns={"full_start": "start", "full_end": "end"}, inplace=True) 146 | print(windows) 147 | windows.to_parquet(output[0], index=False) 148 | 149 | 150 | rule get_embeddings: 151 | input: 152 | "output/embedding/windows.parquet", 153 | "output/genome.fa.gz", 154 | output: 155 | "output/embedding/embeddings/{model}.parquet", 156 | conda: 157 | "envs/gpn.yaml" 158 | threads: workflow.cores 159 | resources: 160 | slurm_extra="-G h100:1" 161 | shell: 162 | """ 163 | python -m gpn.ss.get_embeddings {input} {EMBEDDING_WINDOW_SIZE} \ 164 | {wildcards.model} {output} --per_device_batch_size 4000 --is_file \ 165 | --dataloader_num_workers {threads} 166 | """ 167 | 168 | rule run_umap: 169 | input: 170 | "{anything}/embeddings/{model}.parquet", 171 | output: 172 | "{anything}/umap/{model}.parquet", 173 | conda: 174 | "envs/umap.yaml" 175 | script: 176 | "scripts/umap.py" 177 | 178 | rule run_classification: 179 | input: 180 | "{anything}/windows.parquet", 181 | "{anything}/embeddings/{model}.parquet", 182 | output: 183 | "{anything}/classification/{model}.parquet", 184 | threads: workflow.cores 185 | resources: 186 | slurm_extra="--mail-user=sbuedenb@smail.uni-koeln.de --mail-type=END,FAIL" 187 | conda: 188 | "envs/run_classification.yaml" 189 | script: 190 | "scripts/run_classification.py" -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/scripts/compute_all_sliding_window_for_sequence.py: -------------------------------------------------------------------------------- 1 | # goal: take a DNA sequcene and compute for each position 2 | # - gpn predicted distribution 3 | # - gpn scores 
with given sequence as assumed reference nucleotide 4 | 5 | # result file layout: 6 | # index: position on chromosome 7 | # column: ref, p_a, p_c, p_g, p_t, gpn_a, gpa_c, gpn_g, gpn_t 8 | # p_x denotes the resulting gpn predicted probabilities 9 | # gpn_x denotes the gpn score w.r.t ref, the reference_nucleotide 10 | 11 | import numpy as np 12 | import pandas as pd 13 | from snakemake.script import snakemake 14 | from transformers import AutoModelForMaskedLM, AutoTokenizer 15 | from tqdm import tqdm 16 | from datasets import Dataset 17 | from torch.utils.data import DataLoader 18 | import torch 19 | 20 | # gpn specific model configuration 21 | import gpn.model 22 | from gpn.data import load_fasta 23 | 24 | chromosome = snakemake.wildcards.chromosome 25 | 26 | # load sequence 27 | sequence_path = snakemake.input[0] 28 | genome = load_fasta(sequence_path) 29 | sequence = genome[chromosome] 30 | 31 | model_path = snakemake.wildcards.model 32 | 33 | # load tokenizer 34 | tokenizer = AutoTokenizer.from_pretrained(model_path) 35 | print(f"tokenizer vocabulary: {tokenizer.get_vocab()}") 36 | 37 | # load model 38 | model = AutoModelForMaskedLM.from_pretrained(model_path) 39 | device = "cuda" 40 | model.to(device) 41 | model.eval() 42 | 43 | # start of HMA4-3 44 | start_position = 21886677 45 | # end of HMA4-1 46 | stop_position = 22031473 47 | 48 | 49 | def sliding_window_generator( 50 | sequence, start_position, stop_position, tokenizer, window_size=513, step_size=1 51 | ): 52 | """ 53 | Generate sliding windows over a DNA sequence. 54 | 55 | Args: 56 | fasta_path (str): Path to the FASTA file 57 | window_size (int): Size of the sliding window 58 | step_size (int): Step size for sliding the window 59 | 60 | Yields: 61 | dict: A dictionary with the sequence window 62 | """ 63 | seq_len = len(sequence) 64 | 65 | window_size_half = window_size // 2 66 | 67 | assert start_position - window_size_half - 1 >= 0 68 | assert stop_position + window_size_half - 1 <= seq_len 69 | 70 | for position in range(start_position, stop_position + 1, step_size): 71 | # arrays are 0-indexed, genomes 1-indexed 72 | position = position - 1 73 | 74 | start = int(position - window_size_half) 75 | end = int(position + window_size_half + 1) 76 | 77 | sequence_window = sequence[start:end] 78 | 79 | center = len(sequence_window) // 2 80 | 81 | tokenized_input = tokenizer( 82 | sequence_window, 83 | return_tensors="pt", 84 | return_attention_mask=False, 85 | return_token_type_ids=False, 86 | ) 87 | 88 | # Remove the batch dimension for dataset compatibility 89 | tokenized_data = {k: v.squeeze(0) for k, v in tokenized_input.items()} 90 | 91 | # mask the center nucleotide 92 | tokenized_data["input_ids"][center] = tokenizer.mask_token_id 93 | tokenized_data["reference"] = sequence_window[center].lower() 94 | 95 | # Add position information 96 | tokenized_data["position"] = position 97 | tokenized_data["sequence"] = ( 98 | sequence_window # Keep the original sequence for reference 99 | ) 100 | 101 | yield tokenized_data 102 | 103 | 104 | dataset = Dataset.from_generator( 105 | lambda: sliding_window_generator( 106 | sequence, start_position, stop_position, tokenizer 107 | ), 108 | # features={ 109 | # "input_ids": "sequence", 110 | # "reference": "string", 111 | # "sequence": "string", 112 | # "position": "int32", 113 | # }, 114 | ) 115 | 116 | 117 | def collate_fn(batch): 118 | return { 119 | "input_ids": torch.tensor([item["input_ids"] for item in batch]), 120 | "reference": [item["reference"] for item in batch], 121 | 
"sequence": [item["sequence"] for item in batch], 122 | "position": [item["position"] for item in batch], 123 | } 124 | 125 | 126 | batch_size = 64 127 | dataloader = DataLoader( 128 | dataset, 129 | batch_size=batch_size, 130 | collate_fn=collate_fn, 131 | shuffle=False, # For sliding windows, keep in order 132 | ) 133 | 134 | # Process batches 135 | all_predictions = [] 136 | acgt_idxs = [tokenizer.get_vocab()[nuc] for nuc in ["a", "c", "g", "t"]] 137 | 138 | window_size = 513 139 | center = window_size // 2 140 | results = [] 141 | for batch in tqdm(dataloader, desc="Batch"): 142 | # print(batch) 143 | # batch_results = [] 144 | # for i in range(len(batch["input_ids"])): 145 | current_input = batch["input_ids"] 146 | 147 | # print(current_input) 148 | # print(current_input.shape) 149 | with torch.no_grad(): 150 | all_logits = ( 151 | model(input_ids=current_input.to(device)).logits.cpu().to(torch.float32) 152 | ) 153 | 154 | nucleotide_logits = all_logits[:, :, acgt_idxs] 155 | output_probs = torch.nn.functional.softmax(nucleotide_logits, dim=-1) 156 | 157 | all_predictions.append(output_probs) 158 | 159 | for i in range(len(batch["input_ids"])): 160 | results.append( 161 | { 162 | "position": batch["position"][i], 163 | "reference": batch["reference"][i], 164 | "p_a": output_probs[i][center][0], 165 | "p_c": output_probs[i][center][1], 166 | "p_g": output_probs[i][center][2], 167 | "p_t": output_probs[i][center][3], 168 | } 169 | ) 170 | 171 | # print(batch_results) 172 | # print(model(**batch)) 173 | # print(model(input_ids=batch.to(device)).logits.cpu().to(torch.float32)) 174 | 175 | # print(all_predictions) 176 | results = pd.DataFrame(results) 177 | 178 | # convert all tensors to floats 179 | results = results.map( 180 | lambda x: x.item() if torch.is_tensor(x) and x.numel() == 1 else x 181 | ) 182 | 183 | p_reference = [ 184 | row[col] 185 | for row, col in zip(results.to_dict("records"), "p_" + results["reference"]) 186 | ] 187 | 188 | for alt in ["a", "c", "g", "t"]: 189 | results["gpn_" + alt] = results["p_" + alt] / p_reference 190 | 191 | results[["gpn_a", "gpn_c", "gpn_g", "gpn_t"]] = np.log2( 192 | results[["gpn_a", "gpn_c", "gpn_g", "gpn_t"]] 193 | ) 194 | 195 | results.to_parquet(snakemake.output[0]) 196 | -------------------------------------------------------------------------------- /scripts/nucleotide_dependency_maps/scripts/nucleotide_dependency_map_helpers.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | import seaborn as sns 5 | import torch 6 | from datasets import Dataset 7 | from matplotlib.colors import LogNorm 8 | from transformers import DefaultDataCollator 9 | 10 | 11 | def mutate_sequence(seq): 12 | nuc_table = {"A": 0, "C": 1, "G": 2, "T": 3} 13 | 14 | seq = seq.upper() 15 | mutated_sequences = {"seq": [], "mutation_pos": [], "nuc": [], "var_nt_idx": []} 16 | mutated_sequences["seq"].append(seq) 17 | mutated_sequences["mutation_pos"].append(-1) 18 | mutated_sequences["nuc"].append("real sequence") 19 | mutated_sequences["var_nt_idx"].append(-1) 20 | 21 | mutate_until_position = len(seq) 22 | 23 | for i in range(mutate_until_position): 24 | for nuc in ["A", "C", "G", "T"]: 25 | if nuc != seq[i]: 26 | mutated_sequences["seq"].append(seq[:i] + nuc + seq[i + 1 :]) 27 | mutated_sequences["mutation_pos"].append(i) 28 | mutated_sequences["nuc"].append(nuc) 29 | mutated_sequences["var_nt_idx"].append(nuc_table[nuc]) 30 | 31 | mutations_df = 
pd.DataFrame(mutated_sequences) 32 | 33 | return mutations_df 34 | 35 | 36 | def create_dataloader(dataset, tokenizer, batch_size=64, rolling_masking=False): 37 | 38 | ds = Dataset.from_pandas(dataset[["seq"]]) 39 | 40 | # print(ds.shape) 41 | 42 | tok_ds = ds.map( 43 | lambda x: tokenizer( 44 | list(x["seq"]), 45 | return_tensors="pt", 46 | return_attention_mask=False, 47 | return_token_type_ids=False, 48 | ), 49 | batched=False, 50 | num_proc=20, 51 | ) 52 | 53 | # print(tok_ds.shape) 54 | 55 | rem_tok_ds = tok_ds.remove_columns("seq") 56 | 57 | # print(rem_tok_ds.shape) 58 | 59 | data_collator = DefaultDataCollator() 60 | 61 | data_loader = torch.utils.data.DataLoader( 62 | rem_tok_ds, 63 | batch_size=batch_size, 64 | num_workers=4, 65 | shuffle=False, 66 | collate_fn=data_collator, 67 | ) 68 | 69 | return data_loader 70 | 71 | 72 | def model_inference(model, tokenizer, data_loader, device="cuda"): 73 | acgt_idxs = [tokenizer.get_vocab()[nuc] for nuc in ["a", "c", "g", "t"]] 74 | 75 | output_arrays = [] 76 | for i, batch in enumerate(data_loader): 77 | # get some tokenized sequences (B, L_in) 78 | 79 | tokens = batch["input_ids"] 80 | 81 | print(i) 82 | 83 | tokens = torch.squeeze(tokens, dim=2) 84 | 85 | # print(tokens.shape) 86 | # predict 87 | with torch.autocast(device): 88 | with torch.no_grad(): 89 | # ORIGINAL 90 | # outputs = model(tokens.to(device)).prediction_logits.cpu().to(torch.float32) 91 | # APATED 92 | outputs = ( 93 | model(input_ids=tokens.to(device)).logits.cpu().to(torch.float32) 94 | ) 95 | # calculate probability distribution only over the logits of nucleotides, not all tokens 96 | nucleotide_logits = outputs[:, :, acgt_idxs] 97 | output_probs = torch.nn.functional.softmax(nucleotide_logits, dim=-1) 98 | 99 | # this calculates softmax over all tokens 100 | # output_probs = torch.nn.functional.softmax(outputs, dim=-1)[ 101 | # :, :, acgt_idxs 102 | # ] # B, L_seq, 4 103 | output_arrays.append(output_probs) 104 | 105 | # rebuild to B, L_seq, 4 106 | snp_reconstruct = torch.concat(output_arrays, axis=0) 107 | 108 | return snp_reconstruct.to(torch.float32).numpy() 109 | 110 | 111 | def compute_dependency_map(seq, model, tokenizer, epsilon=1e-10): 112 | 113 | dataset = mutate_sequence(seq) 114 | data_loader = create_dataloader(dataset, tokenizer) 115 | snp_reconstruct = model_inference(model, tokenizer, data_loader) 116 | 117 | # those tokens do not esist for GPN 118 | # snp_reconstruct = snp_reconstruct[:,2:-1,:] # discard the beginning of sentence token, species token and end of sentence token 119 | 120 | # for the logit add a small value epsilon and renormalize such that every prob in one position sums to 1 121 | snp_reconstruct = snp_reconstruct + epsilon 122 | snp_reconstruct = snp_reconstruct / snp_reconstruct.sum(axis=-1)[:, :, np.newaxis] 123 | 124 | seq_len = snp_reconstruct.shape[1] 125 | snp_effect = np.zeros((seq_len, seq_len, 4, 4)) 126 | reference_probs = snp_reconstruct[ 127 | dataset[dataset["nuc"] == "real sequence"].index[0] 128 | ] 129 | 130 | snp_effect[ 131 | dataset.iloc[1:]["mutation_pos"].values, 132 | :, 133 | dataset.iloc[1:]["var_nt_idx"].values, 134 | :, 135 | ] = ( 136 | np.log2(snp_reconstruct[1:]) 137 | - np.log2(1 - snp_reconstruct[1:]) 138 | - np.log2(reference_probs) 139 | + np.log2(1 - reference_probs) 140 | ) 141 | 142 | dep_map = np.max(np.abs(snp_effect), axis=(2, 3)) 143 | # zero main diagonal values 144 | dep_map[np.arange(dep_map.shape[0]), np.arange(dep_map.shape[0])] = 0 145 | 146 | return dep_map 147 | 148 | 149 | ## 
Visualization functions 150 | def map_seq_to_file( 151 | matrix, 152 | dna_sequence, 153 | path, 154 | format, 155 | plot_size=10, 156 | vmax=5, 157 | tick_label_fontsize=8, 158 | title=None, 159 | ): 160 | 161 | fig, ax = plt.subplots(figsize=(plot_size, plot_size)) 162 | 163 | if title is not None: 164 | ax.set_title(title) 165 | 166 | sns.heatmap( 167 | matrix, 168 | cmap="coolwarm", 169 | vmax=vmax, 170 | ax=ax, 171 | xticklabels=False, 172 | yticklabels=False, 173 | # norm=LogNorm(), 174 | ) 175 | ax.set_aspect("equal") 176 | 177 | tick_positions = np.arange(len(dna_sequence)) + 0.5 # Center the ticks 178 | 179 | ax.set_xticks(tick_positions) 180 | ax.set_yticks(tick_positions) 181 | ax.set_xticklabels(list(dna_sequence), fontsize=tick_label_fontsize, rotation=0) 182 | ax.set_yticklabels(list(dna_sequence), fontsize=tick_label_fontsize) 183 | 184 | # plt.show() 185 | plt.savefig(path, format=format) 186 | 187 | 188 | def map_to_file( 189 | matrix, path, format, vmax=None, display_values=False, annot_size=8, fig_size=10 190 | ): 191 | 192 | plt.figure(figsize=(fig_size, fig_size)) 193 | 194 | ax = sns.heatmap( 195 | matrix, 196 | cmap="coolwarm", 197 | vmax=vmax, 198 | annot=display_values, 199 | fmt=".2f", 200 | annot_kws={"size": annot_size}, 201 | ) 202 | 203 | ax.set_aspect("equal") 204 | 205 | plt.savefig(path, format=format) 206 | -------------------------------------------------------------------------------- /experiments/context_length_prediction_impact/scripts/helpers.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | import numpy as np 4 | import seaborn as sns 5 | import torch 6 | from tqdm import tqdm 7 | 8 | 9 | def get_gpn_probabilities(model, tokenizer, context_tokens, device="cuda"): 10 | acgt_idxs = [tokenizer.get_vocab()[nuc] for nuc in ["a", "c", "g", "t"]] 11 | with torch.no_grad(): 12 | all_logits = ( 13 | model(input_ids=context_tokens.to(device)).logits.cpu().to(torch.float32) 14 | ) 15 | nucleotide_logits = all_logits[:, :, acgt_idxs] 16 | output_probs = torch.nn.functional.softmax(nucleotide_logits, dim=-1) 17 | return output_probs 18 | 19 | 20 | def compute_context_length_dependency( 21 | model, 22 | tokenizer, 23 | chromosome_sequence, 24 | position, 25 | max_context_length=1000, 26 | step_size=10, 27 | ): 28 | # shift to zero based indexing in arrays vs. 
chromosomes 29 | position = position - 1 30 | 31 | half = max_context_length // 2 32 | start = int(position - half) 33 | end = int(position + half + 1) 34 | 35 | sequence = chromosome_sequence[start:end] 36 | 37 | # print(f"masked: {sequence[half]}") 38 | 39 | input_ids = tokenizer( 40 | sequence, 41 | return_tensors="pt", 42 | return_attention_mask=False, 43 | return_token_type_ids=False, 44 | )["input_ids"] 45 | 46 | results_dict = {} 47 | 48 | steps = max_context_length // step_size 49 | center = len(input_ids[0]) // 2 50 | 51 | # mask the center nucleotide 52 | input_ids[0, center] = tokenizer.mask_token_id 53 | 54 | for i in tqdm( 55 | range(0, steps * step_size + 1, step_size), 56 | desc="Predicting", 57 | position=1, 58 | leave=False, 59 | ): 60 | half = i // 2 61 | start = int(center - half) 62 | end = int(center + half + 1) 63 | 64 | current_context = input_ids[:, start:end] 65 | 66 | # print(current_context) 67 | 68 | results_dict[i] = get_gpn_probabilities(model, tokenizer, current_context)[ 69 | 0, half, : 70 | ] 71 | 72 | results = pd.DataFrame(results_dict, index=["a", "c", "g", "t"]) 73 | 74 | return results.T 75 | 76 | 77 | def rolling_variance(results, window_size=100): 78 | # we compute the variance over the forward looking window, i.e. if in the future nothing changes, we dont need more context size 79 | return results[::-1].rolling(window=window_size, min_periods=1).var(ddof=0)[::-1] 80 | 81 | 82 | def find_context_size_step_with_total_prediction_variance_below_threshold( 83 | results, window_size=100, threshold=1e-5 84 | ): 85 | rolling_var = rolling_variance(results, window_size=window_size) 86 | return rolling_var.index[rolling_var.sum(axis=1) < threshold].min() 87 | 88 | 89 | def get_distribution_shift(df): 90 | return df.diff().abs().sum(axis=1).div(2)[1:] 91 | 92 | 93 | def find_context_size_step_with_distribution_shift_below_threshold( 94 | results, window_size=10, threshold=0.01 95 | ): 96 | distribution_shift = get_distribution_shift(results) 97 | roll_avg_diff = ( 98 | distribution_shift[::-1].rolling(window=window_size, min_periods=1).mean()[::-1] 99 | ) 100 | return roll_avg_diff.index[roll_avg_diff < threshold].min() 101 | 102 | 103 | def compute_gpn_score(reference_nucleotide, probabilities_per_context_length): 104 | 105 | reference_nucleotide = reference_nucleotide.lower() 106 | 107 | nucleotides = ["a", "c", "g", "t"] 108 | alternatives = [n for n in nucleotides if n != reference_nucleotide] 109 | 110 | gpn_scores = probabilities_per_context_length 111 | 112 | for alt in alternatives: 113 | gpn_scores[f"gpn_{alt}"] = gpn_scores[alt] / gpn_scores[reference_nucleotide] 114 | 115 | gpn_scores = np.log2(gpn_scores).drop(nucleotides, axis=1) 116 | 117 | 118 | def plot_stacked_area( 119 | df, 120 | path, 121 | format, 122 | title="Stacked Area Chart", 123 | xlabel="X-axis", 124 | ylabel="Y-axis", 125 | marker=None, 126 | ): 127 | """ 128 | Plots a stacked area chart from a Pandas DataFrame. 129 | 130 | Parameters: 131 | - df: Pandas DataFrame (index is x-axis, columns are categories to be stacked). 132 | - title: Title of the chart. 133 | - xlabel: Label for the x-axis. 134 | - ylabel: Label for the y-axis. 135 | - colors: Optional list of colors for the areas. 
136 | """ 137 | plt.figure(figsize=(10, 6)) 138 | df.plot(kind="area", stacked=True, alpha=0.7, colormap="viridis") 139 | 140 | print(marker) 141 | if marker: 142 | plt.axvline( 143 | x=marker, color="red", linestyle="-", linewidth=1, label="threshold" 144 | ) 145 | 146 | plt.title(title, fontsize=14) 147 | plt.xlabel(xlabel, fontsize=12) 148 | plt.ylabel(ylabel, fontsize=12) 149 | plt.grid(alpha=0.3) 150 | plt.legend(title="Nucleotides", bbox_to_anchor=(1.05, 1), loc="upper left") 151 | plt.tight_layout() 152 | 153 | plt.savefig(path, format=format) 154 | 155 | 156 | def plot_line( 157 | df, 158 | path, 159 | format, 160 | title="Line Chart", 161 | xlabel="X-axis", 162 | ylabel="Y-axis", 163 | marker=None, 164 | ): 165 | plt.figure(figsize=(10, 6)) 166 | df.plot(kind="line", alpha=0.7, colormap="viridis", ylim=(0, 1)) 167 | 168 | # print(marker) 169 | if marker: 170 | plt.axvline( 171 | x=marker, color="red", linestyle="-", linewidth=1, label="threshold" 172 | ) 173 | 174 | plt.title(title, fontsize=14) 175 | plt.xlabel(xlabel, fontsize=12) 176 | plt.ylabel(ylabel, fontsize=12) 177 | plt.grid(alpha=0.3) 178 | plt.tight_layout() 179 | 180 | plt.savefig(path, format=format) 181 | 182 | 183 | def boxplot( 184 | data, 185 | path, 186 | format, 187 | title="Boxplot", 188 | xlabel="X-axis", 189 | ylabel="Y-axis", 190 | y_col="threshold_steps", 191 | x_col="feature", 192 | marker=None, 193 | ): 194 | # Create the box plot 195 | plt.figure(figsize=(12, 5)) 196 | sns.boxplot(data=data, y=y_col, x=x_col, color="skyblue") 197 | 198 | sample_sizes = data[x_col].value_counts().sort_index() 199 | 200 | for i, label in enumerate(sample_sizes): 201 | plt.annotate( 202 | f"n={label}", 203 | (i, plt.gca().get_ylim()[0]), 204 | xytext=(0, -20), 205 | textcoords="offset points", 206 | ha="center", 207 | va="top", 208 | ) 209 | 210 | if marker: 211 | plt.axhline( 212 | y=marker, 213 | color="red", 214 | linestyle="-", 215 | linewidth=1, 216 | label="training sample size", 217 | ) 218 | 219 | # Add labels and title 220 | plt.title(title, fontsize=14) 221 | # plt.xlabel(xlabel, fontsize=12) 222 | plt.ylabel(ylabel, fontsize=12) 223 | 224 | plt.savefig(path, format=format) 225 | -------------------------------------------------------------------------------- /scripts/high_throughput_gpn_computation/workflow/scripts/compute_gpn.py: -------------------------------------------------------------------------------- 1 | # goal: take a DNA sequcene and compute for each position 2 | # - gpn predicted distribution 3 | # - gpn scores with given sequence as assumed reference nucleotide 4 | 5 | # result file layout: 6 | # index: position on chromosome 7 | # column: ref, p_a, p_c, p_g, p_t, gpn_a, gpa_c, gpn_g, gpn_t 8 | # p_x denotes the resulting gpn predicted probabilities 9 | # gpn_x denotes the gpn score w.r.t ref, the reference_nucleotide 10 | 11 | import numpy as np 12 | import pandas as pd 13 | from snakemake.script import snakemake 14 | from transformers import AutoModelForMaskedLM, AutoTokenizer 15 | from tqdm import tqdm 16 | from datasets import Dataset 17 | from torch.utils.data import DataLoader 18 | import torch 19 | from Bio.Seq import Seq 20 | import warnings 21 | 22 | 23 | # gpn specific model configuration 24 | import gpn.model 25 | from gpn.data import load_fasta 26 | 27 | chromosome = snakemake.wildcards.chromosome 28 | 29 | # load sequence 30 | sequence_path = snakemake.input[0] 31 | genome = load_fasta(sequence_path) 32 | sequence = genome[chromosome] 33 | 34 | model_path = 
snakemake.config["MODEL_PATH"] 35 | 36 | # load tokenizer 37 | tokenizer = AutoTokenizer.from_pretrained(model_path) 38 | print(f"tokenizer vocabulary: {tokenizer.get_vocab()}") 39 | 40 | # load model 41 | model = AutoModelForMaskedLM.from_pretrained(model_path) 42 | device = "cuda" 43 | model.to(device) 44 | model.eval() 45 | 46 | # start of HMA4-3 47 | start_position = int(snakemake.wildcards.start_position) 48 | # end of HMA4-1 49 | stop_position = int(snakemake.wildcards.stop_position) 50 | 51 | print(f"start_position: {start_position}") 52 | print(f"stop_position: {stop_position}") 53 | 54 | # Adjust start and stop positions to be within the sequence length 55 | start_position = max(1, start_position) 56 | stop_position = min(len(sequence), stop_position) 57 | 58 | print(f"start_position: {start_position}") 59 | print(f"stop_position: {stop_position}") 60 | 61 | if start_position != int(snakemake.wildcards.start_position): 62 | warnings.warn( 63 | f"start_position was adjusted to be within the sequence length, now: {start_position}" 64 | ) 65 | if stop_position != int(snakemake.wildcards.stop_position): 66 | warnings.warn( 67 | f"stop_position was adjusted to be within the sequence length, now: {stop_position}" 68 | ) 69 | 70 | if snakemake.config["REVERSE_COMPLEMENT"]: 71 | warnings.warn( 72 | "REVERSE_COMPLEMENT is set to True, the sequence will be reverse complemented" 73 | ) 74 | 75 | context_length = int(snakemake.config["INCLUDED_CONTEXT"]) 76 | 77 | window_size = context_length + 1 78 | 79 | 80 | # print(f"start_position: {start_position}") 81 | # print(f"stop_position: {stop_position}") 82 | 83 | # print(f"sequence length: {len(sequence)}") 84 | # print(f"context_length: {context_length}") 85 | # print(f"window_size: {window_size}") 86 | 87 | assert start_position < stop_position 88 | 89 | # would ensure all positions have equal context available 90 | # assert stop_position - start_position < len(sequence) - context_length 91 | 92 | 93 | def sliding_window_generator( 94 | sequence, start_position, stop_position, tokenizer, window_size=513, step_size=1 95 | ): 96 | """ 97 | Generate sliding windows over a DNA sequence. 
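Each window is centered on one genome position in [start_position, stop_position]
(1-based; converted to 0-based indices internally). Windows reaching past the
chromosome ends are padded with 'n', and the center token is replaced by the mask
token before yielding (the window is reverse complemented first when the
`reverse_complement` wildcard is "rev").

Illustrative usage (assumes `sequence` and `tokenizer` are loaded as above; the
positions are made up):

    gen = sliding_window_generator(
        sequence, start_position=1000, stop_position=1010, tokenizer=tokenizer, window_size=513
    )
    item = next(gen)
    # item["input_ids"]: token ids for the 513-base window, with the center id masked
    # item["reference"]: original center nucleotide (lowercase)
    # item["position"]: 0-based chromosome position of the window center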
98 | 99 | Args: 100 | fasta_path (str): Path to the FASTA file 101 | window_size (int): Size of the sliding window 102 | step_size (int): Step size for sliding the window 103 | 104 | Yields: 105 | dict: A dictionary with the sequence window 106 | """ 107 | seq_len = len(sequence) 108 | 109 | window_size_half = window_size // 2 110 | 111 | # condition has been relaxed by clipping the start and end positions inside loop 112 | # assert start_position - window_size_half - 1 >= 0 113 | # assert stop_position + window_size_half - 1 <= seq_len 114 | 115 | for position in range(start_position, stop_position + 1, step_size): 116 | # arrays are 0-indexed, genomes 1-indexed 117 | position = position - 1 118 | 119 | start = position - window_size_half 120 | end = position + window_size_half + 1 121 | 122 | # Slice the actual sequence 123 | sequence_window = sequence[max(start, 0) : min(end, seq_len)] 124 | 125 | # Calculate how much padding is needed 126 | left_pad = max(0, -start) 127 | right_pad = max(0, end - seq_len) 128 | 129 | # Pad with 'n' (ambiguous base) as needed 130 | sequence_window = ("n" * left_pad) + sequence_window + ("n" * right_pad) 131 | 132 | assert ( 133 | len(sequence_window) == window_size 134 | ), f"Expected {window_size}, got {len(sequence_window)}" 135 | 136 | # if snakemake.config["REVERSE_COMPLEMENT"]: 137 | 138 | if snakemake.wildcards.reverse_complement == "rev": 139 | sequence_window = str(Seq(sequence_window).reverse_complement()) 140 | 141 | center = len(sequence_window) // 2 142 | 143 | tokenized_input = tokenizer( 144 | sequence_window, 145 | return_tensors="pt", 146 | return_attention_mask=False, 147 | return_token_type_ids=False, 148 | ) 149 | 150 | # Remove the batch dimension for dataset compatibility 151 | tokenized_data = {k: v.squeeze(0) for k, v in tokenized_input.items()} 152 | 153 | # mask the center nucleotide 154 | tokenized_data["input_ids"][center] = tokenizer.mask_token_id 155 | tokenized_data["reference"] = sequence_window[center].lower() 156 | 157 | # Add position information 158 | tokenized_data["position"] = position 159 | tokenized_data["sequence"] = ( 160 | sequence_window # Keep the original sequence for reference 161 | ) 162 | 163 | yield tokenized_data 164 | 165 | 166 | dataset = Dataset.from_generator( 167 | lambda: sliding_window_generator( 168 | sequence, start_position, stop_position, tokenizer, window_size=window_size 169 | ), 170 | ) 171 | 172 | 173 | def collate_fn(batch): 174 | return { 175 | "input_ids": torch.tensor([item["input_ids"] for item in batch]), 176 | "reference": [item["reference"] for item in batch], 177 | "sequence": [item["sequence"] for item in batch], 178 | "position": [item["position"] for item in batch], 179 | } 180 | 181 | 182 | batch_size = int(snakemake.config["BATCH_SIZE"]) 183 | 184 | dataloader = DataLoader( 185 | dataset, 186 | batch_size=batch_size, 187 | collate_fn=collate_fn, 188 | shuffle=False, # For sliding windows, keep in order 189 | ) 190 | 191 | # Process batches 192 | # all_predictions = [] 193 | acgt_idxs = [tokenizer.get_vocab()[nuc] for nuc in ["a", "c", "g", "t"]] 194 | 195 | center = window_size // 2 196 | results = [] 197 | for batch in tqdm(dataloader, desc="Batch"): 198 | current_input = batch["input_ids"] 199 | 200 | with torch.no_grad(): 201 | all_logits = ( 202 | model(input_ids=current_input.to(device)).logits.cpu().to(torch.float32) 203 | ) 204 | 205 | nucleotide_logits = all_logits[:, :, acgt_idxs] 206 | output_probs = torch.nn.functional.softmax(nucleotide_logits, dim=-1) 207 | 208 
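# output_probs has shape (batch_size, window_size, 4): softmax probabilities over
# a/c/g/t at every window position; only the masked center position (window_size // 2)
# is extracted into the per-position results below.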
| # all_predictions.append(output_probs) 209 | 210 | for i in range(len(batch["input_ids"])): 211 | results.append( 212 | { 213 | "position": batch["position"][i], 214 | "reference": batch["reference"][i], 215 | "p_a": output_probs[i][center][0], 216 | "p_c": output_probs[i][center][1], 217 | "p_g": output_probs[i][center][2], 218 | "p_t": output_probs[i][center][3], 219 | } 220 | ) 221 | 222 | results = pd.DataFrame(results) 223 | 224 | # we kinda compute back to front when flipping to reverse complement so data is nicer to be understood when reverted here 225 | # if snakemake.config["REVERSE_COMPLEMENT"]: 226 | if snakemake.wildcards.reverse_complement == "rev": 227 | results = results[::-1] 228 | 229 | # convert all tensors to floats 230 | results = results.map( 231 | lambda x: x.item() if torch.is_tensor(x) and x.numel() == 1 else x 232 | ) 233 | 234 | p_reference = [ 235 | row[col] if col != "p_n" else 0.0 236 | for row, col in zip(results.to_dict("records"), "p_" + results["reference"]) 237 | ] 238 | 239 | for alt in ["a", "c", "g", "t"]: 240 | results["gpn_" + alt] = results["p_" + alt] / p_reference 241 | 242 | results[["gpn_a", "gpn_c", "gpn_g", "gpn_t"]] = np.log2( 243 | results[["gpn_a", "gpn_c", "gpn_g", "gpn_t"]] 244 | ) 245 | 246 | results.to_parquet(snakemake.output[0]) 247 | -------------------------------------------------------------------------------- /scripts/model_evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "60f85605-2488-49f1-8db7-86b6b3ca1939", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from datasets import load_dataset\n", 11 | "from torch.utils.data import DataLoader\n", 12 | "import torch\n", 13 | "from tqdm.auto import tqdm\n", 14 | "from evaluate import load # Hugging Face’s metrics hub\n", 15 | "\n", 16 | "import gpn.model\n", 17 | "from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling\n", 18 | "from pathlib import Path\n", 19 | "import os" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 23, 25 | "id": "3e267077-2285-4b1e-b22a-958c2964ff13", 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "dilation_schedule=[1, 3, 9, 27, 81, 243, 1, 3, 9, 27, 81, 243, 1, 3, 9, 27, 81, 243]\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "# dataset_name = \"sbuedenb/small_beetle_dataset\"\n", 38 | "# model_name = \"sbuedenb/beetle-gpn\" # v1: Top-1 accuracy: 51.8759% (validation) , v2: 53.0181%\n", 39 | "# model_name = \"sbuedenb/beetle-gpn-wide\" # Top-1 accuracy: 53.3793% (validation)\n", 40 | "\n", 41 | "\n", 42 | "# model_name = \"sbuedenb/beetle-gpn-wide-reduced\" # Top-1 accuracy: 51.8314%\n", 43 | "model_name = \"/home/sbuedenb/models/long-wide-cosine/\"\n", 44 | "dataset_name = \"sbuedenb/big_beetle_dataset\"\n", 45 | "# model_name = \"songlab/gpn-brassicales\"\n", 46 | "# (on brassicales) Top-1 accuracy: 53.8563% (validation), Top-1 accuracy: 53.2370% (test)\n", 47 | "# (on cucujiformia) Top-1 accuracy: 42.8384%\n", 48 | "\n", 49 | "# dataset_name = \"songlab/genomes-brassicales-balanced-v1\"\n", 50 | "\n", 51 | "\n", 52 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 53 | "model = AutoModelForMaskedLM.from_pretrained(model_name, local_files_only=True).eval()\n", 54 | "dataset = load_dataset(dataset_name, split=\"validation\") # or \"validation\"" 55 | ] 56 
| }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 24, 60 | "id": "6433c893-2a80-4cbb-8de2-ca6d9e5146c0", 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "'cuda'" 67 | ] 68 | }, 69 | "execution_count": 24, 70 | "metadata": {}, 71 | "output_type": "execute_result" 72 | } 73 | ], 74 | "source": [ 75 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 76 | "model.to(device);\n", 77 | "device" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "id": "f106fb0b-a6d3-4293-a01a-4f1d0bf8f495", 83 | "metadata": {}, 84 | "source": [ 85 | "# Top-1 accuracy on sbuedenb/big_beetle_dataset\n", 86 | "Model | Accuracy (eval) | Accuracy (test)\n", 87 | "-|-|-\n", 88 | "songlab/gpn-brassicales| 42.7848% | 42.9517%\n", 89 | "sbuedenb/beetle-gpn | 51.4824% | 56.0279%\n", 90 | "sbuedenb/beetle-gpn-wide-reduced | **51.8868%** | **56.2513%**\n", 91 | "sbuedenb/long-wide-cosine | 52.01% +- 0.63% | 55.70 +- 1.26%" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 25, 97 | "id": "c325a509-25c5-4745-b5ef-397972653aec", 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "def tokenize_function(batch):\n", 102 | " res = tokenizer(\n", 103 | " batch[\"seq\"],\n", 104 | " return_special_tokens_mask=True,\n", 105 | " padding=False,\n", 106 | " truncation=False,\n", 107 | " return_token_type_ids=False,\n", 108 | " )\n", 109 | " return res\n", 110 | "\n", 111 | "tokenized = dataset.map(tokenize_function, batched=True, remove_columns=[\"seq\", \"assembly\", \"chrom\", \"strand\"])" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 26, 117 | "id": "49e582e6-570f-427f-8ed1-40e0f6348211", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "data_collator = DataCollatorForLanguageModeling(\n", 122 | " tokenizer=tokenizer,\n", 123 | " mlm=True,\n", 124 | " mlm_probability=0.15, # standard BERT mask-ratio\n", 125 | " seed=42,\n", 126 | ")" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 27, 132 | "id": "c84a8ebc-d945-4f9f-b1ad-b7f532c83cb4", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "loader = DataLoader(\n", 137 | " tokenized,\n", 138 | " batch_size=256,\n", 139 | " shuffle=False,\n", 140 | " collate_fn=data_collator,\n", 141 | ")" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 28, 147 | "id": "c215eb9e-190c-4423-a43b-0ace78fe2fac", 148 | "metadata": { 149 | "scrolled": true 150 | }, 151 | "outputs": [ 152 | { 153 | "data": { 154 | "application/vnd.jupyter.widget-view+json": { 155 | "model_id": "7efebdf72b6341eb856fcc1bce647463", 156 | "version_major": 2, 157 | "version_minor": 0 158 | }, 159 | "text/plain": [ 160 | "evaluating: 0%| | 0/166 [00:00 {output}" 120 | 121 | 122 | rule create_annotation_db: 123 | input: 124 | "output/unpack_ncbi_dataset/{accession}/annotation.gff", 125 | output: 126 | "output/create_annotation_db/{accession}/annotation.db", 127 | conda: 128 | "envs/compute_statistics.yaml" 129 | script: 130 | "scripts/create_annotation_db.py" 131 | 132 | 133 | rule compute_probabilities_per_context_length: 134 | input: 135 | "output/unpack_ncbi_dataset/{accession}/genome.fna", 136 | output: 137 | "output/compute_probabilities_per_context_length/{model}/{accession}/{chromosome}/{position}/probabilities_per_context_length.parquet", 138 | conda: 139 | "envs/scripts.yaml" 140 | resources: 141 | slurm_extra="-G 1", 142 | script: 143 | "scripts/compute_probabilities_per_context_length.py" 144 | 145 
| 146 | # rule compute_gpn_per_context_length: 147 | # input: 148 | # "output/unpack_ncbi_dataset/{accession}/genome.fna", 149 | # "output/compute_probabilities_per_context_length/{model}/{accession}/{chromosome}/{position}/probabilities_per_context_length.parquet", 150 | # output: 151 | # "output/compute_gpn_per_context_length/{model}/{accession}/{chromosome}/{position}/gpn_per_context_length.parquet", 152 | # conda: 153 | # "envs/scripts.yaml" 154 | # script: 155 | # "scripts/compute_gpn_per_context_length.py" 156 | 157 | 158 | rule generate_random_sample_positions: 159 | input: 160 | "output/unpack_ncbi_dataset/{accession}/genome.fna", 161 | output: 162 | "output/generate_random_sample_positions/{accession}/{chromosome}/positions.json", 163 | conda: 164 | "envs/scripts.yaml" 165 | script: 166 | "scripts/generate_random_sample_positions.py" 167 | 168 | 169 | rule compute_prediction_variance_over_chromosome: 170 | input: 171 | annotation_db="output/create_annotation_db/{accession}/annotation.db", 172 | position_data=expand( 173 | "output/compute_probabilities_per_context_length/{{model}}/{{accession}}/{{chromosome}}/{position}/probabilities_per_context_length.parquet", 174 | position=all_random_positions, 175 | ), 176 | output: 177 | "output/compute_prediction_variance_over_chromosome/{model}/{accession}/{chromosome}/prediction_variance_over_chromosome.parquet", 178 | conda: 179 | "envs/compute_statistics.yaml" 180 | script: 181 | "scripts/compute_prediction_variance_over_chromosome.py" 182 | 183 | 184 | rule compute_distribution_shift_over_chromosome: 185 | input: 186 | annotation_db="output/create_annotation_db/{accession}/annotation.db", 187 | position_data=expand( 188 | "output/compute_probabilities_per_context_length/{{model}}/{{accession}}/{{chromosome}}/{position}/probabilities_per_context_length.parquet", 189 | position=all_random_positions, 190 | ), 191 | output: 192 | "output/compute_distribution_shift_over_chromosome/{model}/{accession}/{chromosome}/distribution_shift_over_chromosome.parquet", 193 | conda: 194 | "envs/compute_statistics.yaml" 195 | script: 196 | "scripts/compute_distribution_shift_over_chromosome.py" 197 | 198 | 199 | rule compute_gpn_statistics_over_chromosome: 200 | input: 201 | sequence="output/unpack_ncbi_dataset/{accession}/genome.fna", 202 | annotation_db="output/create_annotation_db/{accession}/annotation.db", 203 | position_data=expand( 204 | "output/compute_probabilities_per_context_length/{{model}}/{{accession}}/{{chromosome}}/{position}/probabilities_per_context_length.parquet", 205 | position=all_random_positions, 206 | ), 207 | output: 208 | "output/compute_gpn_statistics_over_chromosome/{model}/{accession}/{chromosome}/gpn_statistics_over_chromosome.parquet", 209 | conda: 210 | "envs/compute_statistics.yaml" 211 | script: 212 | "scripts/compute_gpn_statistics_over_chromosome.py" 213 | 214 | 215 | rule compute_all_for_sequence: 216 | input: 217 | "data/Lan3.1.fna.gz", 218 | output: 219 | "output/compute_all_for_sequence/{model}/Lan3.1/{chromosome}/gpn_scores.parquet", 220 | conda: 221 | "envs/scripts.yaml" 222 | resources: 223 | slurm_extra="-G h100:1", 224 | script: 225 | "scripts/compute_all_sliding_window_for_sequence.py" 226 | 227 | 228 | rule display_chromosome_boxplot: 229 | params: 230 | title="Context size with prediction influence", 231 | input: 232 | # "output/compute_distribution_shift_over_chromosome/{model}/{accession}/{chromosome}/distribution_shift_over_chromosome.parquet", 233 | 
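# the commented-out distribution-shift parquet above is an alternative input for this
# boxplot; the prediction-variance parquet below is the input currently in use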
"output/compute_prediction_variance_over_chromosome/{model}/{accession}/{chromosome}/prediction_variance_over_chromosome.parquet", 234 | output: 235 | "output/display_chromosome_boxplot/{model}/{accession}/{chromosome}/chromosome_boxplot.{format}", 236 | conda: 237 | "envs/scripts.yaml" 238 | script: 239 | "scripts/display_chromosome_boxplot.py" 240 | 241 | 242 | rule display_gpn_chromosome_boxplot: 243 | params: 244 | title="Context size with GPN-score influence", 245 | input: 246 | "output/compute_gpn_statistics_over_chromosome/{model}/{accession}/{chromosome}/gpn_statistics_over_chromosome.parquet", 247 | output: 248 | "output/display_gpn_chromosome_boxplot/{model}/{accession}/{chromosome}/gpn_chromosome_boxplot.{format}", 249 | conda: 250 | "envs/scripts.yaml" 251 | script: 252 | "scripts/display_chromosome_boxplot.py" 253 | 254 | 255 | rule display_stacked_probability_chart: 256 | input: 257 | "output/unpack_ncbi_dataset/{accession}/genome.fna", 258 | "output/compute_probabilities_per_context_length/{model}/{accession}/{chromosome}/{position}/probabilities_per_context_length.parquet", 259 | output: 260 | "output/display_stacked_probability_chart/{model}/{accession}/{chromosome}/{position}/stacked_probability_chart.{format}", 261 | conda: 262 | "envs/scripts.yaml" 263 | script: 264 | "scripts/display_stacked_probability_chart.py" 265 | 266 | 267 | rule display_distribution_shift_chart: 268 | input: 269 | "output/compute_probabilities_per_context_length/{model}/{accession}/{chromosome}/{position}/probabilities_per_context_length.parquet", 270 | output: 271 | "output/display_distribution_shift_chart/{model}/{accession}/{chromosome}/{position}/distribution_shift_chart.{format}", 272 | conda: 273 | "envs/scripts.yaml" 274 | script: 275 | "scripts/display_distribution_shift_chart.py" 276 | 277 | 278 | rule display_stacked_variance_chart: 279 | input: 280 | "output/unpack_ncbi_dataset/{accession}/genome.fna", 281 | "output/compute_probabilities_per_context_length/{model}/{accession}/{chromosome}/{position}/probabilities_per_context_length.parquet", 282 | output: 283 | "output/display_stacked_variance_chart/{model}/{accession}/{chromosome}/{position}/stacked_variance_chart.{format}", 284 | conda: 285 | "envs/scripts.yaml" 286 | script: 287 | "scripts/display_stacked_variance_chart.py" 288 | 289 | 290 | rule display_stacked_gpn_variance_chart: 291 | input: 292 | "output/unpack_ncbi_dataset/{accession}/genome.fna", 293 | "output/compute_gpn_per_context_length/{model}/{accession}/{chromosome}/{position}/gpn_per_context_length.parquet", 294 | output: 295 | "output/display_stacked_gpn_variance_chart/{model}/{accession}/{chromosome}/{position}/stacked_gpn_variance_chart.{format}", 296 | conda: 297 | "envs/scripts.yaml" 298 | script: 299 | "scripts/display_stacked_variance_chart.py" 300 | -------------------------------------------------------------------------------- /NOTES.md: -------------------------------------------------------------------------------- 1 | # Documentation of what I do on this machine 2 | 3 | ## Preamble 4 | Added ssh-keys of desktop an laptop machines. 5 | 6 | ## Setup development environment 7 | 1. Connect to machine via Remote-SSH plugin of vscode 8 | 2. Prepare devenv via [apptainer](https://apptainer.org/docs/user/1.3/quick_start.html) 9 | 3. 
Load cuda module: `module load cuda/12.5` 10 | 11 | ### apptainer usage 12 | Build .sif 13 | ```sh 14 | apptainer build --nv devenv.sif devenv.def 15 | ``` 16 | 17 | Run .sif interactively 18 | ```sh 19 | apptainer shell --nv devenv.sif 20 | ``` 21 | 22 | Run .sif as jupyter server in background 23 | ```sh 24 | apptainer instance start devenv.sif devenv-instance-1 25 | ``` 26 | > available on http://localhost:8888 27 | 28 | ## Get interactive node with some GPUs 29 | 30 | ```sh 31 | srun -p interactive --time=2:00:00 --mem=100gb -G 2 --ntasks=2 --cpus-per-task=8 --pty /bin/bash 32 | ``` 33 | 34 | Run apptainer with [`--nv` flag](https://apptainer.org/docs/user/main/gpu.html) to make cuda and graphics cards accessible. 35 | 36 | ```sh 37 | apptainer shell --nv devenv.sif 38 | ``` 39 | 40 | Install flash-attention. 41 | > see here: https://github.com/Dao-AILab/flash-attention 42 | 43 | ```sh 44 | pip install flash-attn --no-build-isolation 45 | ``` 46 | 47 | Start jupyter notebook server. 48 | ```sh 49 | jupyter notebook --no-browser --port 9999 50 | ``` 51 | 52 | Set up port forwarding of the compute node via the login node to the local machine in a new shell. 53 | (https://people.cs.umass.edu/~kexiao/posts/jupyter_notebook_remote_slurm.html) 54 | ```sh 55 | ssh -t -t sbuedenb@ramses4.itcc.uni-koeln.de -L 9999:localhost:8008 ssh ramses15233 -L 8008:localhost:9999 56 | ``` 57 | 58 | ## Spike: Make a NDM from GPN 59 | > see: scripts/nucleotide_dependency_maps_for_gpn.ipynb 60 | 61 | ## Spike: Re-run UMAP plot from GPN paper 62 | 63 | Clone gpn repository 64 | 65 | Had to install many new python libs, see requirements.txt 66 | 67 | Had to install vcftools 68 | 69 | Had to fix `rename_reference` rule in Snakefile 70 | 71 | Had to get a node with proper GPU (more than ~4GB VRAM) to run `run_umap` rule 72 | 73 | ## Workaround for ncbi-datasets-cli (authority error) 74 | 75 | - check NCBI page of genome ([Tribolium](https://www.ncbi.nlm.nih.gov/datasets/genome/GCF_031307605.1/)) 76 | - handpick files from FTP server 77 | 78 | ```bash 79 | wget https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/031/307/605/GCF_031307605.1_icTriCast1.1/GCF_031307605.1_icTriCast1.1_genomic.fna.gz 80 | 81 | wget https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/031/307/605/GCF_031307605.1_icTriCast1.1/GCF_031307605.1_icTriCast1.1_genomic.gff.gz 82 | ``` 83 | 84 | ## Setup without apptainer 85 | 86 | ```bash 87 | module load lang/Miniconda3/23.9.0-0 && \ 88 | conda env create --file environment.yaml && \ 89 | conda activate snakemake 90 | ``` 91 | 92 | ## snakemake good to know 93 | 94 | Build conda environments: 95 | ``` 96 | snakemake --sdm conda --conda-create-envs-only 97 | ``` 98 | 99 | Run with conda and specific profile: 100 | ``` 101 | snakemake --sdm conda --profile=profiles/slurm 102 | ``` 103 | 104 | ## General stuff 105 | 106 | List default AG: 107 | `sacctmgr show assoc -n user=$USER format=Account` 108 | 109 | Put ninja in path in case flash-attn has to be built. 110 | 111 | `export PATH='/home/sbuedenb/.local/bin':$PATH` 112 | 113 | 114 | ```bash 115 | salloc -p gpu --ntasks=1 --cpus-per-task=16 --time=100:00 --mem=40gb -G 2 116 | srun --pty bash 117 | ``` 118 | > --ntasks => "cores" 119 | > --cpus-per-task => "threads" 120 | > -G 1 => "GPUs" 121 | 122 | 123 | Check time left on a slurm job 124 | ```bash 125 | squeue -h -j $SLURM_JOBID -o %L 126 | ``` 127 | 128 | 129 | ## Problems 130 | 131 | ```bash 132 | salloc -p gpu --ntasks=1 --cpus-per-task=16 --time=300:00 --mem=40gb -G 1 133 | ...
134 | 47%|██████████████████████████████████████ | 1258/2675 [1:31:09<1:42:45, 4.35s/it] 135 | salloc: Job 255505 has exceeded its time limit and its allocation has been revoked. 136 | srun: forcing job termination 137 | Hangup 138 | [sbuedenb@ramses4 masterthesis]$ srun: Job step aborted: Waiting up to 32 seconds for job step to finish. 139 | slurmstepd: error: *** STEP 255506.0 ON ramses16304 CANCELLED AT 2024-12-20T14:27:46 *** 140 | srun: error: ramses16304: task 0: Killed 141 | srun: Terminating StepId=255506.0 142 | tcsetattr: Input/output error 143 | 144 | --- 145 | 146 | salloc -p gpu --ntasks=1 --cpus-per-task=16 --time=300:00 --mem=40gb -G 2 147 | ... 148 | 66%|█████████████████████████████████████████████████▍ | 882/1338 [2:43:55<1:24:51, 11.17s/it] 149 | salloc: Job 255506 has exceeded its time limit and its allocation has been revoked. 150 | srun: forcing job termination 151 | Hangup 152 | [sbuedenb@ramses4 masterthesis]$ srun: Job step aborted: Waiting up to 32 seconds for job step to finish. 153 | slurmstepd: error: *** STEP 255532.0 ON ramses16301 CANCELLED AT 2024-12-20T17:55:18 *** 154 | srun: error: ramses16301: task 0: Killed 155 | srun: Terminating StepId=255532.0 156 | tcsetattr: Input/output error 157 | 158 | --- 159 | 160 | salloc -p gpu --ntasks=1 --cpus-per-task=16 --time=6:00:00 --mem=100gb -G 4 161 | ... 162 | 99%|█████████████████████████████████████████████████████████████████████████████▎| 663/669 [2:03:40<01:07, 11.19s/it] 163 | salloc: Job 255532 has exceeded its time limit and its allocation has been revoked. 164 | srun: forcing job termination 165 | Hangup 166 | [sbuedenb@ramses4 masterthesis]$ srun: Job step aborted: Waiting up to 32 seconds for job step to finish. 167 | slurmstepd: error: *** STEP 255604.0 ON ramses16304 CANCELLED AT 2024-12-20T20:08:49 *** 168 | srun: error: ramses16304: task 0: Killed 169 | srun: Terminating StepId=255604.0 170 | tcsetattr: Input/output error 171 | 172 | --- 173 | 174 | salloc -p gpu --ntasks=1 --cpus-per-task=16 --time=6:00:00 --mem=100gb -G 4 175 | ... 
176 | 177 | 178 | 179 | ``` 180 | 181 | ## Train a GPN 182 | 183 | ``` 184 | export WANDB_API_KEY= 185 | export WANDB_ENTITY=sbuedenb-university-of-cologne 186 | export WANDB_PROJECT=beetle-gpn 187 | 188 | export RUN_NAME=wide-cosine-1024 189 | export OUTPUT_DIR="/scratch/sbuedenb/gpn-training/$RUN_NAME/" 190 | ``` 191 | then 192 | ``` 193 | conda activate gpn_gpu 194 | ``` 195 | then 196 | ``` 197 | torchrun --nproc_per_node=$(echo $CUDA_VISIBLE_DEVICES | awk -F',' '{print NF}') -m gpn.ss.run_mlm \ 198 | --do_train \ 199 | --do_eval \ 200 | --report_to wandb \ 201 | --prediction_loss_only True \ 202 | --remove_unused_columns False \ 203 | --dataset_name sbuedenb/big_beetle_dataset-1024 \ 204 | --tokenizer_name gonzalobenegas/tokenizer-dna-mlm \ 205 | --soft_masked_loss_weight_train 0.1 \ 206 | --soft_masked_loss_weight_evaluation 0.0 \ 207 | --total_batch_size 2048 \ 208 | --weight_decay 0.01 \ 209 | --optim adamw_torch \ 210 | --dataloader_num_workers 4 \ 211 | --seed 42 \ 212 | --save_strategy steps \ 213 | --save_steps 5000 \ 214 | --eval_strategy steps \ 215 | --eval_steps 1000 \ 216 | --logging_steps 100 \ 217 | --max_steps 180000 \ 218 | --warmup_steps 1000 \ 219 | --learning_rate 4e-3 \ 220 | --lr_scheduler_type cosine \ 221 | --run_name $RUN_NAME \ 222 | --output_dir $OUTPUT_DIR \ 223 | --model_type GPN \ 224 | --per_device_train_batch_size 256 \ 225 | --per_device_eval_batch_size 256 \ 226 | --gradient_accumulation_steps 1 \ 227 | --ddp_find_unused_parameters False \ 228 | --bf16 \ 229 | --bf16_full_eval \ 230 | --ignore_data_skip \ 231 | --config_overrides "first_kernel_size=9,rest_kernel_size=9,dilation_max=243,dilation_cycle=6,dilation_base=3,num_hidden_layers=18" 232 | ``` 233 | --ignore_data_skip \ 234 | --torch_compile \ 235 | --lr_scheduler_type constant_with_warmup \ 236 | 237 | 238 | 239 | ``` 240 | --nnodes=1:2 \ 241 | --nproc-per-node=4 \ 242 | --max-restarts=6 \ 243 | --standalone \ 244 | ``` 245 | ### fuck up counter: 7 246 | 247 | - space in command from repo 248 | - python 3.12 too new (3.10 should be stable) 249 | - imports `is_torch_tpu_available` which is deprecated without even using it 250 | - imports `from scipy.stats import geom` without documenting or using 251 | - dataset is not local ? 252 | - SegFault 253 | - random crash 254 | 255 | ### 256 | salloc \ 257 | --mail-user=sbuedenb@smail.uni-koeln.de \ 258 | --mail-type=BEGIN \ 259 | --nodes=1 \ 260 | --ntasks=4 \ 261 | --cpus-per-task=16 \ 262 | --time=48:00:00 \ 263 | --mem=64gb \ 264 | -p gpu \ 265 | -G h100:4 266 | > srun --pty bash 267 | 268 | scontrol update JobId=661168 StartTime=10:00:00 269 | 270 | ### training notes 271 | #### "GPN" current gpn architecture (first-try:small_data, third-try-big-data) #! 
continue RUN_ID: 9lptl80n 272 | Number of trainable parameters = 93.091.328 273 | --per_device_train_batch_size 256 \ 274 | --per_device_eval_batch_size 256 \ 275 | 276 | third-try-big-data: v1 is 240.000 277 | 278 | #### "GPN" current gpn architecture with dropout (second-try) 279 | Number of trainable parameters = 93.091.328 280 | --config_overrides "hidden_dropout_prob=0.1" 281 | --per_device_train_batch_size 256 \ 282 | --per_device_eval_batch_size 256 \ 283 | 284 | #### "GPN" with wide flat pyramid kernel config (wide_flat_pyramid_kernel_k9b3:small_data, wide-net-big-data) 285 | Number of trainable parameters = 118.257.152 286 | --config_overrides "first_kernel_size=9,rest_kernel_size=9,dilation_max=81,dilation_cycle=5,dilation_base=3" 287 | --per_device_train_batch_size 256 \ 288 | --per_device_eval_batch_size 256 \ 289 | 290 | #### "GPN" with wide flat pyramid kernel config (long-wide:long-data) #! fresh start 291 | --dataset_name sbuedenb/big_beetle_dataset-2048 \ 292 | --config_overrides "first_kernel_size=9,rest_kernel_size=9,dilation_max=243,dilation_cycle=6,dilation_base=3,num_hidden_layers=18" 293 | --per_device_train_batch_size 128 \ 294 | --per_device_eval_batch_size 128 \ 295 | --learning_rate 2e-3 \ 296 | --lr_scheduler_type cosine \ 297 | --max_steps 180000 298 | 299 | Total batch size: 4 GPU * 128/GPU = 512 with 2048 tokens per sample => 1.048.576 tokens per batch 300 | Number of trainable parameters = 85,219,840 301 | 302 | Try: --learning_rate 2e-3 \ 303 | --lr_scheduler_type cosine \ 304 | -- 305 | 306 | #### "GPN" less deep with wide flat pyramid kernel config (wide-big-reduced) #! continue RUN_ID: g5dyoby9 307 | --config_overrides "first_kernel_size=9,rest_kernel_size=9,dilation_max=81,dilation_cycle=5,dilation_base=3,num_hidden_layers=20" 308 | --per_device_train_batch_size 256 \ 309 | --per_device_eval_batch_size 256 \ 310 | 311 | Total batch size: 4 GPU * 256/GPU = 1024 with 512 tokens per sample => 524.288 tokens ber batch 312 | Number of trainable parameters = 94.659.072 313 | 314 | wide-big-reduced: v1 is 120.000, v2 is 240.000 315 | 316 | 317 | torchrun --nproc_per_node=$(echo $CUDA_VISIBLE_DEVICES | awk -F',' '{print NF}') -m gpn.ss.run_mlm \ 318 | --do_train \ 319 | --do_eval \ 320 | --report_to wandb \ 321 | --prediction_loss_only True \ 322 | --remove_unused_columns False \ 323 | --dataset_name sbuedenb/big_beetle_dataset \ 324 | --tokenizer_name gonzalobenegas/tokenizer-dna-mlm \ 325 | --soft_masked_loss_weight_train 0.1 \ 326 | --soft_masked_loss_weight_evaluation 0.0 \ 327 | --total_batch_size 2048 \ 328 | --weight_decay 0.01 \ 329 | --optim adamw_torch \ 330 | --dataloader_num_workers 9 \ 331 | --seed 42 \ 332 | --save_strategy steps \ 333 | --save_steps 5000 \ 334 | --eval_strategy steps \ 335 | --eval_steps 1000 \ 336 | --logging_steps 100 \ 337 | --num_train_epochs 26 \ 338 | --learning_rate 1e-3 \ 339 | --lr_scheduler_type cosine_with_restarts \ 340 | --lr_scheduler_kwargs '{"num_cycles":3}' \ 341 | --max_steps 114169 \ 342 | --warmup_steps 1000 \ 343 | --load_best_model_at_end True \ 344 | --metric_for_best_model loss \ 345 | --run_name $RUN_NAME \ 346 | --output_dir $OUTPUT_DIR \ 347 | --model_type GPN \ 348 | --per_device_train_batch_size 256 \ 349 | --per_device_eval_batch_size 256 \ 350 | --gradient_accumulation_steps 1 \ 351 | --ddp_find_unused_parameters False \ 352 | --bf16 \ 353 | --bf16_full_eval \ 354 | --ignore_data_skip \ 355 | --config_overrides 
"first_kernel_size=9,rest_kernel_size=9,dilation_max=81,dilation_cycle=5,dilation_base=3,num_hidden_layers=20" 356 | 357 | --lr_scheduler_type cosine 358 | --max_steps 30000 359 | --resume_from_checkpoint 360 | 361 | 362 | #### "GPN" smol 363 | --config_overrides "first_kernel_size=9,rest_kernel_size=5,dilation_max=64,dilation_cycle=7,dilation_base=2,num_hidden_layers=21,hidden_size=256,intermediate_size=1024,hidden_dropout_prob=0.1" 364 | Number of trainable parameters = 19.608.320 365 | 366 | #### "ConvNet" gpn legacy 367 | Number of trainable parameters = 65880071 368 | 369 | ### resume WANDB run 370 | ``` 371 | export WANDB_RESUME=allow 372 | export WANDB_RUN_ID="9inaq395" 373 | ``` 374 | 375 | > RUN_ID must be explicit id not name 376 | 377 | ### upload to huggingface 378 | 379 | > put stuff into new folder `v2` 380 | 381 | huggingface-cli upload sbuedenb/beetle-gpn ./v2 . --commit-message "v2 – retrained with dropout 0.1" 382 | 383 | 384 | https://www.ncbi.nlm.nih.gov/datasets/docs/v2/policies-annotation/genomeftp/#are-repetitive-sequences-in-eukaryotic-genomes-masked 385 | 386 | 387 | 388 | --------------------------------------------------------------------------------