├── mu-former ├── src │ ├── __init__.py │ ├── protein │ │ ├── tasks │ │ │ ├── __init__.py │ │ │ └── pmlm.py │ │ ├── models │ │ │ └── __init__.py │ │ ├── criterions │ │ │ ├── __init__.py │ │ │ └── pmlm.py │ │ ├── __init__.py │ │ └── dict.txt │ ├── utils.py │ ├── vocab.py │ └── criterion.py ├── data │ └── example │ │ └── IF1_ECOLI_Kelsic_2016.fasta ├── requirements.txt ├── README.md └── .gitignore ├── mu-search ├── src │ └── flexs │ │ ├── flexs │ │ ├── landscapes │ │ │ ├── src │ │ │ │ ├── __init__.py │ │ │ │ ├── protein │ │ │ │ │ ├── tasks │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── mlm.py │ │ │ │ │ │ └── pmlm.py │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── models │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── encoder │ │ │ │ │ │ │ └── transformer_sentence_encoder_layer.py │ │ │ │ │ └── criterions │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── mlm.py │ │ │ │ │ │ └── pmlm.py │ │ │ │ ├── criterion.py │ │ │ │ ├── vocab.py │ │ │ │ └── utils.py │ │ │ ├── landscape │ │ │ │ ├── muformer │ │ │ │ │ └── muformer_landscape │ │ │ │ │ │ ├── src │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── protein │ │ │ │ │ │ │ ├── tasks │ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ │ ├── mlm.py │ │ │ │ │ │ │ │ └── pmlm.py │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ ├── models │ │ │ │ │ │ │ │ └── __init__.py │ │ │ │ │ │ │ └── criterions │ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ │ ├── mlm.py │ │ │ │ │ │ │ │ └── pmlm.py │ │ │ │ │ │ ├── criterion.py │ │ │ │ │ │ ├── vocab.py │ │ │ │ │ │ └── utils.py │ │ │ │ │ │ ├── CFX_test_from_pan_above_median.tsv │ │ │ │ │ │ └── CFX_test_from_pan.tsv │ │ │ │ └── __init__.py │ │ │ ├── __init__.py │ │ │ ├── tf_binding.py │ │ │ └── bert_gfp.py │ │ ├── baselines │ │ │ ├── explorers │ │ │ │ ├── stable_baselines3 │ │ │ │ │ ├── stable_baselines3 │ │ │ │ │ │ ├── py.typed │ │ │ │ │ │ ├── common │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ ├── sb2_compat │ │ │ │ │ │ │ │ └── __init__.py │ │ │ │ │ │ │ ├── envs │ │ │ │ │ │ │ │ └── __init__.py │ │ │ │ │ │ │ ├── vec_env │ │ │ │ │ │ │ │ ├── vec_extract_dict_obs.py │ │ │ │ │ │ │ │ ├── vec_frame_stack.py │ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ │ ├── util.py │ │ │ │ │ │ │ │ ├── vec_check_nan.py │ │ │ │ │ │ │ │ ├── vec_monitor.py │ │ │ │ │ │ │ │ ├── vec_video_recorder.py │ │ │ │ │ │ │ │ └── vec_transpose.py │ │ │ │ │ │ │ ├── running_mean_std.py │ │ │ │ │ │ │ ├── type_aliases.py │ │ │ │ │ │ │ └── results_plotter.py │ │ │ │ │ │ ├── a2c │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── policies.py │ │ │ │ │ │ ├── dqn │ │ │ │ │ │ │ └── __init__.py │ │ │ │ │ │ ├── ppo │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── policies.py │ │ │ │ │ │ ├── sac │ │ │ │ │ │ │ └── __init__.py │ │ │ │ │ │ ├── td3 │ │ │ │ │ │ │ └── __init__.py │ │ │ │ │ │ ├── ddpg │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── policies.py │ │ │ │ │ │ ├── her │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── goal_selection_strategy.py │ │ │ │ │ │ └── __init__.py │ │ │ │ │ ├── .coveragerc │ │ │ │ │ ├── CITATION.bib │ │ │ │ │ ├── .readthedocs.yml │ │ │ │ │ ├── .gitlab-ci.yml │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── NOTICE │ │ │ │ │ ├── Makefile │ │ │ │ │ └── setup.cfg │ │ │ │ ├── environments │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── env.py │ │ │ │ ├── __init__.py │ │ │ │ ├── pure_random.py │ │ │ │ ├── random.py │ │ │ │ ├── evoplay_utils │ │ │ │ │ └── env_model.py │ │ │ │ └── cmaes.py │ │ │ ├── __init__.py │ │ │ └── models │ │ │ │ ├── __init__.py │ │ │ │ ├── mlp.py │ │ │ │ ├── global_epistasis_model.py │ │ │ │ ├── keras_model.py │ │ │ │ ├── sklearn_models.py │ │ │ │ ├── cnn.py │ │ │ │ ├── noisy_abstract_model.py │ 
│ │ │ ├── adaptive_ensemble.py │ │ │ │ └── basecnn.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── os_utils.py │ │ │ ├── eval_utils.py │ │ │ ├── seq_utils.py │ │ │ └── sequence_utils.py │ │ ├── types.py │ │ ├── __init__.py │ │ ├── landscape.py │ │ ├── model.py │ │ ├── ensemble.py │ │ └── evaluate.py │ │ ├── readthedocs.yml │ │ ├── Makefile │ │ ├── setup.cfg │ │ ├── setup.py │ │ └── .gitignore ├── readthedocs.yml ├── pyrosetta_distributed-0.0.3-py3-none-any.whl ├── Makefile ├── setup.cfg ├── examples │ └── test.py ├── setup.py ├── .gitignore └── tests │ ├── test_landscapes.py │ ├── test_models.py │ └── test_explorers.py ├── pmlm ├── src │ └── protein │ │ ├── models │ │ ├── __init__.py │ │ └── layer.py │ │ ├── tasks │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── dict.txt │ │ └── criterions │ │ └── __init__.py ├── requirements.txt ├── README.md └── script │ └── pretrain.sh ├── README.md └── .gitignore /mu-former/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .protein import * 2 | -------------------------------------------------------------------------------- /mu-former/src/protein/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .pmlm import * 2 | -------------------------------------------------------------------------------- /mu-former/src/protein/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .pmlm import ProtPairwiseModel 2 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .protein import * 2 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mu-former/src/protein/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | from .pmlm import ProtPairwiseMaskedLMCriterion 2 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/src/protein/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .pmlm import * 2 | from .mlm import * -------------------------------------------------------------------------------- /mu-former/src/protein/__init__.py: -------------------------------------------------------------------------------- 1 | from .criterions import * 2 | from .models import * 3 | from .tasks import * -------------------------------------------------------------------------------- /mu-search/readthedocs.yml: -------------------------------------------------------------------------------- 1 | python: 2 | version: 3.7 3 | install: 4 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- 
/mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/sb2_compat/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .protein import * 2 | -------------------------------------------------------------------------------- /mu-search/src/flexs/readthedocs.yml: -------------------------------------------------------------------------------- 1 | python: 2 | version: 3.7 3 | install: 4 | - requirements: requirements.txt -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/src/protein/__init__.py: -------------------------------------------------------------------------------- 1 | from .criterions import * 2 | from .models import * 3 | from .tasks import * -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/src/protein/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .pmlm import ProtPairwiseModel 2 | from .mlm import ProtBaseModel 3 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/environments/__init__.py: -------------------------------------------------------------------------------- 1 | """Reinforcement learning environments for DynaPPO and PPO explorers.""" 2 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/src/protein/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .pmlm import * 2 | from .mlm import * -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility modules. 3 | 4 | `utils.sequence_utils` is the most important and useful. 
5 | """ 6 | -------------------------------------------------------------------------------- /pmlm/src/protein/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlm import ProtBaseModel 2 | from .pmlm import ProtPairwiseModel 3 | from .pmlmx import ProtPairwiseXModel 4 | -------------------------------------------------------------------------------- /pmlm/src/protein/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlm import * 2 | from .pmlm import * 3 | from .tape import * 4 | from .legacy import * 5 | from .pmlmx import * -------------------------------------------------------------------------------- /mu-search/pyrosetta_distributed-0.0.3-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MutonAI/Mu-Protein/HEAD/mu-search/pyrosetta_distributed-0.0.3-py3-none-any.whl -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/src/protein/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | from .pmlm import ProtPairwiseMaskedLMCriterion 2 | from .mlm import ProteinMaskedLMCriterion 3 | -------------------------------------------------------------------------------- /mu-former/data/example/IF1_ECOLI_Kelsic_2016.fasta: -------------------------------------------------------------------------------- 1 | >IF1_ECOLI_Kelsic_2016|P69222|IF1_ECOLI 2 | MAKEDNIEMQGTVLETLPNTMFRVELENGHVVTAHISGKMRKNYIRILTGDKVTVELTPY 3 | DLSKGRIVFRSR 4 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/src/protein/__init__.py: -------------------------------------------------------------------------------- 1 | from .criterions import * 2 | from .models import * 3 | from .tasks import * -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/src/protein/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .pmlm import ProtPairwiseModel 2 | from .mlm import ProtBaseModel 3 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/utils/os_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_arg_parser(): 4 | return argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 5 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/types.py: -------------------------------------------------------------------------------- 1 | """Types definitions for the flexs package.""" 2 | from typing import List, Union 3 | 4 | import numpy as np 5 | 6 | SEQUENCES_TYPE = Union[List[str], np.ndarray] 7 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/src/protein/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | from .pmlm import ProtPairwiseMaskedLMCriterion 2 | from .mlm import ProteinMaskedLMCriterion 3 | -------------------------------------------------------------------------------- 
/mu-search/src/flexs/flexs/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | """Baselines module containing robust implementations 2 | of various models and explorers. 3 | """ 4 | from flexs.baselines import explorers, models # noqa: F401 5 | -------------------------------------------------------------------------------- /pmlm/src/protein/__init__.py: -------------------------------------------------------------------------------- 1 | from .criterions import ProteinMaskedLMCriterion 2 | from .datasets import * 3 | from .models import * 4 | from .tasks import * 5 | 6 | from .tokenizers import TAPETokenizer, Uniprot21Dict 7 | -------------------------------------------------------------------------------- /pmlm/src/protein/dict.txt: -------------------------------------------------------------------------------- 1 | A 0 2 | R 0 3 | N 0 4 | D 0 5 | C 0 6 | Q 0 7 | E 0 8 | G 0 9 | H 0 10 | I 0 11 | L 0 12 | K 0 13 | M 0 14 | F 0 15 | P 0 16 | S 0 17 | T 0 18 | W 0 19 | Y 0 20 | V 0 21 | X 0 22 | - 0 23 | -------------------------------------------------------------------------------- /mu-former/src/protein/dict.txt: -------------------------------------------------------------------------------- 1 | A 0 2 | R 0 3 | N 0 4 | D 0 5 | C 0 6 | Q 0 7 | E 0 8 | G 0 9 | H 0 10 | I 0 11 | L 0 12 | K 0 13 | M 0 14 | F 0 15 | P 0 16 | S 0 17 | T 0 18 | W 0 19 | Y 0 20 | V 0 21 | X 0 22 | - 0 23 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/a2c/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.a2c.a2c import A2C 2 | from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 3 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.dqn.dqn import DQN 2 | from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 3 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 2 | from stable_baselines3.ppo.ppo import PPO 3 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/sac/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 2 | from stable_baselines3.sac.sac import SAC 3 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/td3/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 2 | from stable_baselines3.td3.td3 import TD3 3 | -------------------------------------------------------------------------------- 
/mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/ddpg/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.ddpg.ddpg import DDPG 2 | from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 3 | -------------------------------------------------------------------------------- /mu-search/Makefile: -------------------------------------------------------------------------------- 1 | format: 2 | python -m black . 3 | python -m isort --profile black . 4 | 5 | lint: 6 | python -m flake8 flexs 7 | 8 | test: 9 | python -m pytest tests 10 | 11 | .PHONY: docs 12 | docs: 13 | cd ./docs && make html -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/ddpg/policies.py: -------------------------------------------------------------------------------- 1 | # DDPG can be viewed as a special case of TD3 2 | from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy # noqa:F401 3 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/her/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy 2 | from stable_baselines3.her.her_replay_buffer import HerReplayBuffer 3 | -------------------------------------------------------------------------------- /pmlm/requirements.txt: -------------------------------------------------------------------------------- 1 | lmdb 2 | numpy==1.23.2 3 | fairseq==0.10.2 4 | scipy 5 | pyarrow 6 | pandas 7 | scikit-learn 8 | cython 9 | biopython==1.74 10 | tensorboard 11 | tensorboardX 12 | tensorboard_logger 13 | pathlib2 14 | tqdm 15 | matplotlib -------------------------------------------------------------------------------- /mu-search/src/flexs/Makefile: -------------------------------------------------------------------------------- 1 | format: 2 | python -m black . 3 | python -m isort --profile black .
4 | 5 | lint: 6 | python -m flake8 flexs 7 | 8 | test: 9 | python -m pytest tests 10 | 11 | .PHONY: docs 12 | docs: 13 | cd ./docs && make html -------------------------------------------------------------------------------- /pmlm/src/protein/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlm import ProteinMaskedLMCriterion 2 | from .pmlm import ProtPairwiseMaskedLMCriterion 3 | from .tape import ProtTapeEvalCriterion 4 | from .legacy import ProtLegacyCriterion 5 | from .pmlmx import ProtPairwiseXMaskedLMCriterion -------------------------------------------------------------------------------- /mu-former/src/protein/criterions/pmlm.py: -------------------------------------------------------------------------------- 1 | from fairseq.criterions import FairseqCriterion, register_criterion 2 | 3 | @register_criterion('prot_pmlm') 4 | class ProtPairwiseMaskedLMCriterion(FairseqCriterion): 5 | def __init__(self, task, tpu=False): 6 | super().__init__(task) 7 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/src/protein/criterions/mlm.py: -------------------------------------------------------------------------------- 1 | 2 | from fairseq.criterions import FairseqCriterion, register_criterion 3 | 4 | @register_criterion('prot_mlm') 5 | class ProteinMaskedLMCriterion(FairseqCriterion): 6 | def __init__(self, task, tpu=False): 7 | super().__init__(task) 8 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/src/protein/criterions/pmlm.py: -------------------------------------------------------------------------------- 1 | from fairseq.criterions import FairseqCriterion, register_criterion 2 | 3 | @register_criterion('prot_pmlm') 4 | class ProtPairwiseMaskedLMCriterion(FairseqCriterion): 5 | def __init__(self, task, tpu=False): 6 | super().__init__(task) 7 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/src/protein/criterions/mlm.py: -------------------------------------------------------------------------------- 1 | 2 | from fairseq.criterions import FairseqCriterion, register_criterion 3 | 4 | @register_criterion('prot_mlm') 5 | class ProteinMaskedLMCriterion(FairseqCriterion): 6 | def __init__(self, task, tpu=False): 7 | super().__init__(task) 8 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/src/protein/criterions/pmlm.py: -------------------------------------------------------------------------------- 1 | from fairseq.criterions import FairseqCriterion, register_criterion 2 | 3 | @register_criterion('prot_pmlm') 4 | class ProtPairwiseMaskedLMCriterion(FairseqCriterion): 5 | def __init__(self, task, tpu=False): 6 | super().__init__(task) 7 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/CFX_test_from_pan_above_median.tsv: -------------------------------------------------------------------------------- 1 | mutation score 2 | G236S 1.4 3 | M180T;G236S 32.0 4 | E102K;G236S 360.0 5 | E102K;M180T;G236S 360.0 6 | A40G;G236S 23.0 7 | A40G;M180T 1.4 8 | A40G;M180T;G236S 360.0 9 | A40G;E102K 1.4 10 | A40G;E102K;G236S 2100.0 11 | A40G;E102K;M180T;G236S 2900.0 12 | 
-------------------------------------------------------------------------------- /mu-former/requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | numba>=0.45.1 3 | tape_proteins>=0.4 4 | tqdm>=4.51.0 5 | numpy==1.20.1 6 | biopython>=1.78 7 | msgpack_python>=0.5.6 8 | scikit_learn>=0.24.1 9 | PyYAML>=5.2 10 | joblib 11 | fairseq==0.10.2 12 | lmdb 13 | pyarrow 14 | cython 15 | pathlib2 16 | matplotlib 17 | sentencepiece 18 | boto3 19 | protobuf==3.19.0 20 | fair-esm 21 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/a2c/policies.py: -------------------------------------------------------------------------------- 1 | # This file is here just to define MlpPolicy/CnnPolicy 2 | # that work for A2C 3 | from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy 4 | 5 | MlpPolicy = ActorCriticPolicy 6 | CnnPolicy = ActorCriticCnnPolicy 7 | MultiInputPolicy = MultiInputActorCriticPolicy 8 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/ppo/policies.py: -------------------------------------------------------------------------------- 1 | # This file is here just to define MlpPolicy/CnnPolicy 2 | # that work for PPO 3 | from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy 4 | 5 | MlpPolicy = ActorCriticPolicy 6 | CnnPolicy = ActorCriticCnnPolicy 7 | MultiInputPolicy = MultiInputActorCriticPolicy 8 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = False 3 | omit = 4 | tests/* 5 | setup.py 6 | # Require graphical interface 7 | stable_baselines3/common/results_plotter.py 8 | # Require ffmpeg 9 | stable_baselines3/common/vec_env/vec_video_recorder.py 10 | 11 | [report] 12 | exclude_lines = 13 | pragma: no cover 14 | raise NotImplementedError() 15 | if typing.TYPE_CHECKING: 16 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.common.envs.bit_flipping_env import BitFlippingEnv 2 | from stable_baselines3.common.envs.identity_env import ( 3 | FakeImageEnv, 4 | IdentityEnv, 5 | IdentityEnvBox, 6 | IdentityEnvMultiBinary, 7 | IdentityEnvMultiDiscrete, 8 | ) 9 | from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv 10 | -------------------------------------------------------------------------------- /mu-search/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # Recommend matching the black line length (default 88), 3 | # rather than using the flake8 default of 79: 4 | max-line-length = 88 5 | extend-ignore = 6 | # See https://github.com/PyCQA/pycodestyle/issues/373 7 | E203, # Whitespace before ‘:’ 8 | D205, # 1 blank line required between summary line and description 9 | D400, # First line should end with a period 10 | E731 # Do not assign a lambda expression, use a def 
11 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/CFX_test_from_pan.tsv: -------------------------------------------------------------------------------- 1 | mutation score 2 | WT 0.088 3 | G236S 1.4 4 | M180T 0.063 5 | M180T;G236S 32.0 6 | E102K 0.13 7 | E102K;G236S 360.0 8 | E102K;M180T 0.18 9 | E102K;M180T;G236S 360.0 10 | A40G 0.08800000000000001 11 | A40G;G236S 23.0 12 | A40G;M180T 1.4 13 | A40G;M180T;G236S 360.0 14 | A40G;E102K 1.4 15 | A40G;E102K;G236S 2100.0 16 | A40G;E102K;M180T 0.8 17 | A40G;E102K;M180T;G236S 2900.0 18 | -------------------------------------------------------------------------------- /mu-search/src/flexs/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # Recommend matching the black line length (default 88), 3 | # rather than using the flake8 default of 79: 4 | max-line-length = 88 5 | extend-ignore = 6 | # See https://github.com/PyCQA/pycodestyle/issues/373 7 | E203, # Whitespace before ‘:’ 8 | D205, # 1 blank line required between summary line and description 9 | D400, # First line should end with a period 10 | E731 # Do not assign a lambda expression, use a def 11 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/CITATION.bib: -------------------------------------------------------------------------------- 1 | @article{stable-baselines3, 2 | author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann}, 3 | title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations}, 4 | journal = {Journal of Machine Learning Research}, 5 | year = {2021}, 6 | volume = {22}, 7 | number = {268}, 8 | pages = {1-8}, 9 | url = {http://jmlr.org/papers/v22/20-1364.html} 10 | } 11 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Build documentation in the docs/ directory with Sphinx 8 | sphinx: 9 | configuration: docs/conf.py 10 | 11 | # Optionally build your docs in additional formats such as PDF and ePub 12 | formats: all 13 | 14 | # Set requirements using conda env 15 | conda: 16 | environment: docs/conda_env.yml 17 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/__init__.py: -------------------------------------------------------------------------------- 1 | """The FLEXS (Fitness Landscape EXploration Sandbox) package.""" 2 | 3 | from flexs import types # isort:skip # noqa: F401 4 | 5 | from flexs.landscape import Landscape # isort:skip # noqa: F401 6 | from flexs.model import Model, LandscapeAsModel # isort:skip # noqa: F401 7 | from flexs.ensemble import Ensemble # isort:skip # noqa: F401 8 | from flexs.explorer import Explorer # isort:skip # noqa: F401 9 | 10 | 11 | from flexs import baselines, evaluate, landscapes # isort:skip # noqa: F401 12 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/__init__.py: 
-------------------------------------------------------------------------------- 1 | """FLEXS landscapes module.""" 2 | from flexs.landscapes import rna # noqa: F401 3 | from flexs.landscapes.additive_aav_packaging import ( # noqa: F401 4 | AdditiveAAVPackaging, 5 | ) 6 | from flexs.landscapes.bert_gfp import BertGFPBrightness # noqa: F401 7 | from flexs.landscapes.rna import RNABinding # noqa: F401 8 | from flexs.landscapes.rosetta import RosettaFolding # noqa: F401 9 | from flexs.landscapes.tf_binding import TFBinding # noqa: F401 10 | from flexs.landscapes.muformer import MuformerLandscape # noqa: F401 11 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | image: stablebaselines/stable-baselines3-cpu:1.4.1a0 2 | 3 | type-check: 4 | script: 5 | - pip install pytype --upgrade 6 | - make type 7 | 8 | pytest: 9 | script: 10 | - pip install tqdm rich # for progress bar 11 | - python --version 12 | # MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error 13 | - MKL_THREADING_LAYER=GNU make pytest 14 | coverage: '/^TOTAL.+?(\d+\%)$/' 15 | 16 | doc-build: 17 | script: 18 | - make doc 19 | 20 | lint-check: 21 | script: 22 | - make check-codestyle 23 | - make lint 24 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/her/goal_selection_strategy.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class GoalSelectionStrategy(Enum): 5 | """ 6 | The strategies for selecting new goals when 7 | creating artificial transitions. 
8 | """ 9 | 10 | # Select a goal that was achieved 11 | # after the current step, in the same episode 12 | FUTURE = 0 13 | # Select the goal that was achieved 14 | # at the end of the episode 15 | FINAL = 1 16 | # Select a goal that was achieved in the episode 17 | EPISODE = 2 18 | 19 | 20 | # For convenience 21 | # that way, we can use string to select a strategy 22 | KEY_TO_GOAL_STRATEGY = { 23 | "future": GoalSelectionStrategy.FUTURE, 24 | "final": GoalSelectionStrategy.FINAL, 25 | "episode": GoalSelectionStrategy.EPISODE, 26 | } 27 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/models/__init__.py: -------------------------------------------------------------------------------- 1 | """`baselines.models` module.""" 2 | from flexs.baselines.models.adaptive_ensemble import ( # noqa: F401 3 | AdaptiveEnsemble, 4 | ) 5 | from flexs.baselines.models.cnn import CNN # noqa: F401 6 | from flexs.baselines.models.global_epistasis_model import ( # noqa: F401 7 | GlobalEpistasisModel, 8 | ) 9 | from flexs.baselines.models.keras_model import KerasModel # noqa: F401 10 | from flexs.baselines.models.mlp import MLP # noqa: F401 11 | from flexs.baselines.models.noisy_abstract_model import ( # noqa: F401 12 | NoisyAbstractModel, 13 | ) 14 | from flexs.baselines.models.sklearn_models import ( # noqa: F401 15 | LinearRegression, 16 | LogisticRegression, 17 | RandomForest, 18 | SklearnClassifier, 19 | SklearnRegressor, 20 | ) 21 | from flexs.baselines.models.basecnn import ( 22 | BaseCNN, 23 | MaskedConv1d, 24 | ToyMLP 25 | ) 26 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/vec_env/vec_extract_dict_obs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper 4 | 5 | 6 | class VecExtractDictObs(VecEnvWrapper): 7 | """ 8 | A vectorized wrapper for extracting dictionary observations. 
9 | 10 | :param venv: The vectorized environment 11 | :param key: The key of the dictionary observation 12 | """ 13 | 14 | def __init__(self, venv: VecEnv, key: str): 15 | self.key = key 16 | super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key]) 17 | 18 | def reset(self) -> np.ndarray: 19 | obs = self.venv.reset() 20 | return obs[self.key] 21 | 22 | def step_wait(self) -> VecEnvStepReturn: 23 | obs, reward, done, info = self.venv.step_wait() 24 | return obs[self.key], reward, done, info 25 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from stable_baselines3.a2c import A2C 4 | from stable_baselines3.common.utils import get_system_info 5 | from stable_baselines3.ddpg import DDPG 6 | from stable_baselines3.dqn import DQN 7 | from stable_baselines3.her.her_replay_buffer import HerReplayBuffer 8 | from stable_baselines3.ppo import PPO 9 | from stable_baselines3.sac import SAC 10 | from stable_baselines3.td3 import TD3 11 | 12 | # Read version from file 13 | version_file = os.path.join(os.path.dirname(__file__), "version.txt") 14 | with open(version_file) as file_handler: 15 | __version__ = file_handler.read().strip() 16 | 17 | 18 | def HER(*args, **kwargs): 19 | raise ImportError( 20 | "Since Stable Baselines 2.1.0, `HER` is now a replay buffer class `HerReplayBuffer`.\n " 21 | "Please check the documentation for more information: https://stable-baselines3.readthedocs.io/" 22 | ) 23 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/__init__.py: -------------------------------------------------------------------------------- 1 | # import os 2 | # import importlib 3 | # from muformer_landscape.landscape import MuformerLandscape, MuformerLandscapeList 4 | 5 | # protein_alphabet = 'ACDEFGHIKLMNPQRSTVWY' 6 | 7 | # task_alphabet_dict = { 8 | # 'TEM': protein_alphabet, 9 | # } 10 | 11 | # task_wild_type_dict = { 12 | # 'TEM': 'MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW', 13 | # } 14 | 15 | # landscape_collection = {'muformer': MuformerLandscapeList} 16 | 17 | # def get_landscape(args): 18 | # if args.landscape == 'muformer': 19 | # landscape = landscape_collection[args.landscape](args, task_wild_type_dict[args.task]) 20 | # else: 21 | # raise NotImplementedError 22 | # return landscape, task_alphabet_dict[args.task], task_wild_type_dict[args.task] 23 | 24 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2019 Antonin Raffin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons 
to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /mu-search/examples/test.py: -------------------------------------------------------------------------------- 1 | from flexs.baselines.explorers import Encoder 2 | import flexs 3 | import flexs.utils.sequence_utils as s_utils 4 | from flexs.baselines.explorers import GwgPairSampler, GWG 5 | import pandas as pd 6 | 7 | 8 | if __name__=="__main__": 9 | problem = flexs.landscapes.additive_aav_packaging.registry()['blood'] 10 | landscape = flexs.landscapes.additive_aav_packaging.AdditiveAAVPackaging(**problem['params']) 11 | starting_sequence = landscape.wild_type 12 | alphabet = s_utils.AAS 13 | encoder = Encoder(alphabet) 14 | print(encoder.vocab_size) 15 | encoded = encoder.encode(alphabet) 16 | onehot = encoder.onehotize(encoded) 17 | # print(onehot) 18 | sampler = GwgPairSampler(encoder, 10, sequences_batch_size=10, model_queries_per_batch=10, temperature=0.1, starting_sequence=starting_sequence, alphabet=alphabet, log_file='efficiency/ggs/blood/10_10.csv') 19 | explorer = GWG(sampler=sampler, rounds=3, sequences_batch_size=10, model_queries_per_batch=10, temperature=0.1, starting_sequence=starting_sequence, alphabet=alphabet) 20 | output = sampler([starting_sequence]) 21 | # print(output[0]['source_sequence'].drop_duplicates()) 22 | explorer.run(landscape) -------------------------------------------------------------------------------- /mu-search/src/flexs/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md") as f: 4 | long_description = f.read() 5 | 6 | setuptools.setup( 7 | name="flexs", 8 | version="0.2.1", 9 | description=( 10 | "FLEXS: an open simulation environment for developing and comparing " 11 | "model-guided biological sequence design algorithms." 
12 | ), 13 | url="https://github.com/samsinai/FLEXS", 14 | author="Stewart Slocum", 15 | author_email="slocumstewy@gmail.com", 16 | license="Apache 2.0", 17 | long_description=long_description, 18 | long_description_content_type="text/markdown", 19 | packages=setuptools.find_packages(), 20 | python_requires=">=3.5", 21 | install_requires=[ 22 | "cma", 23 | "editdistance", 24 | "numpy>=1.16", 25 | "pandas>=0.23", 26 | "torch>=0.4", 27 | "scikit-learn>=0.20", 28 | "tape-proteins", 29 | "tensorflow>=2", 30 | "tf-agents>=0.3", 31 | ], 32 | classifiers=[ 33 | "Programming Language :: Python :: 3", 34 | "Programming Language :: Python :: 3.5", 35 | "Programming Language :: Python :: 3.6", 36 | "Programming Language :: Python :: 3.7", 37 | "Programming Language :: Python :: 3.8", 38 | ], 39 | ) 40 | -------------------------------------------------------------------------------- /mu-former/src/protein/tasks/pmlm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from fairseq import utils 5 | from fairseq.data import Dictionary 6 | 7 | from fairseq.tasks import FairseqTask, register_task 8 | logger = logging.getLogger(__name__) 9 | 10 | @register_task('prot_pmlm') 11 | class ProtPairwiseMaskedLMTask(FairseqTask): 12 | 13 | def __init__(self, args, dictionary): 14 | super().__init__(args) 15 | self.args = args 16 | self.dictionary = dictionary 17 | if not hasattr(args, 'max_positions'): 18 | self._max_positions = ( 19 | args.max_source_positions, 20 | args.max_target_positions, 21 | ) 22 | else: 23 | self._max_positions = args.max_positions 24 | args.tokens_per_sample = self._max_positions 25 | # add mask token 26 | self.mask_idx = dictionary.add_symbol("<mask>") 27 | 28 | @classmethod 29 | def setup_task(cls, args, **kwargs): 30 | paths = utils.split_paths(args.data) 31 | assert len(paths) > 0 32 | dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) 33 | logger.info("dictionary: {} types".format(len(dictionary))) 34 | return cls(args, dictionary) 35 | 36 | @property 37 | def source_dictionary(self): 38 | return self.dictionary -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/__init__.py: -------------------------------------------------------------------------------- 1 | """FLEXS `explorers` module.""" 2 | from flexs.baselines.explorers import environments # noqa: F401 3 | from flexs.baselines.explorers.adalead import Adalead # noqa: F401 4 | from flexs.baselines.explorers.bo import BO, GPR_BO # noqa: F401 5 | from flexs.baselines.explorers.cbas_dbas import VAE, CbAS # noqa: F401 6 | from flexs.baselines.explorers.cmaes import CMAES # noqa: F401 7 | from flexs.baselines.explorers.dqn import DQN # noqa: F401 8 | from flexs.baselines.explorers.dyna_ppo import ( # noqa: F401 9 | DynaPPO, 10 | DynaPPOMutative, 11 | ) 12 | from flexs.baselines.explorers.genetic_algorithm import ( # noqa: F401 13 | GeneticAlgorithm, 14 | ) 15 | from flexs.baselines.explorers.ppo import PPO # noqa: F401 16 | from flexs.baselines.explorers.random import Random # noqa: F401 17 | from flexs.baselines.explorers.ggs import Encoder # noqa: F401 18 | from flexs.baselines.explorers.ggs import GwgPairSampler # noqa: F401 19 | from flexs.baselines.explorers.ggs import GWG # noqa: F401 20 | from flexs.baselines.explorers.dirichlet_ppo import MuSearch # noqa: F401 21 | from flexs.baselines.explorers.pex import ProximalExploration # noqa: F401 22 | from
flexs.baselines.explorers.evoplay import Evoplay # noqa: F401 23 | from flexs.baselines.explorers.pure_random import PureRandom # noqa: F401 24 | 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ! This repo is now deprecated; please refer to the official repo: [https://github.com/microsoft/Mu-Protein](https://github.com/microsoft/Mu-Protein) 2 | 3 | --- 4 | 5 | # Introduction 6 | This repository hosts the code for μProtein (also written Mu-Protein, uProtein, or MuProtein for readability), a tool for predicting the effects of protein mutations and navigating the protein fitness landscape. It is configured to facilitate replication of the models presented in the paper *Accelerating protein engineering with fitness landscape modeling and reinforcement learning*, available at [this link](https://www.biorxiv.org/content/10.1101/2023.11.16.565910v5). The official release of this repository is expected soon and will receive ongoing updates and maintenance; once it is available, this repository will be retired. 7 | 8 | This repository consists of three main components: 9 | 10 | - **`pmlm/`** – Protein language model pretraining 11 | - **`mu-former/`** – Fitness landscape modeling using the pretrained protein language model 12 | - **`mu-search/`** – Navigating the constructed fitness landscape oracle with reinforcement learning 13 | 14 | For more details, refer to the respective README files: 15 | 16 | - [PMLM Pretraining](pmlm/README.md) 17 | - [μFormer](mu-former/README.md) 18 | - [μSearch](mu-search/README.md) 19 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/models/mlp.py: -------------------------------------------------------------------------------- 1 | """Define a baseline multilayer perceptron model.""" 2 | import tensorflow as tf 3 | 4 | from . import keras_model 5 | 6 | 7 | class MLP(keras_model.KerasModel): 8 | """A baseline MLP with three dense layers and relu activations.""" 9 | 10 | def __init__( 11 | self, 12 | seq_len, 13 | hidden_size, 14 | alphabet, 15 | loss="MSE", 16 | name=None, 17 | batch_size=256, 18 | epochs=20, 19 | ): 20 | """Create an MLP.""" 21 | model = tf.keras.models.Sequential( 22 | [ 23 | tf.keras.layers.Flatten(), 24 | tf.keras.layers.Dense( 25 | hidden_size, input_shape=(seq_len, len(alphabet)), activation="relu" 26 | ), 27 | tf.keras.layers.Dense(hidden_size, activation="relu"), 28 | tf.keras.layers.Dense(hidden_size, activation="relu"), 29 | tf.keras.layers.Dense(1), 30 | ] 31 | ) 32 | 33 | model.compile(loss=loss, optimizer="adam", metrics=["mse"]) 34 | 35 | if name is None: 36 | name = f"MLP_hidden_size_{hidden_size}" 37 | 38 | super().__init__( 39 | model, 40 | alphabet=alphabet, 41 | name=name, 42 | batch_size=batch_size, 43 | epochs=epochs, 44 | ) 45 | -------------------------------------------------------------------------------- /mu-search/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md") as f: 4 | long_description = f.read() 5 | 6 | setuptools.setup( 7 | name="flexs", 8 | version="0.2.7", 9 | description=( 10 | "FLEXS: an open simulation environment for developing and comparing " 11 | "model-guided biological sequence design algorithms."
12 | ), 13 | url="https://github.com/samsinai/FLEXS", 14 | author="Stewart Slocum", 15 | author_email="slocumstewy@gmail.com", 16 | license="Apache 2.0", 17 | long_description=long_description, 18 | long_description_content_type="text/markdown", 19 | packages=setuptools.find_packages(), 20 | python_requires=">=3.5", 21 | install_requires=[ 22 | "cma", 23 | "editdistance", 24 | "numpy>=1.16", 25 | "pandas>=0.23", 26 | "torch>=0.4", 27 | "scikit-learn>=0.20", 28 | "tape-proteins", 29 | "tensorflow>=2", 30 | "tf-agents>=0.7.1", 31 | ], 32 | include_package_data=True, 33 | package_data={ 34 | "": [ 35 | "landscapes/data/additive_aav_packaging/*", 36 | "landscapes/data/rosetta/*", 37 | "landscapes/data/tf_binding/*", 38 | ] 39 | }, 40 | classifiers=[ 41 | "Programming Language :: Python :: 3", 42 | "Programming Language :: Python :: 3.5", 43 | "Programming Language :: Python :: 3.6", 44 | "Programming Language :: Python :: 3.7", 45 | ], 46 | ) 47 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/NOTICE: -------------------------------------------------------------------------------- 1 | Large portion of the code of Stable-Baselines3 (in `common/`) were ported from Stable-Baselines, a fork of OpenAI Baselines, 2 | both licensed under the MIT License: 3 | 4 | before the fork (June 2018): 5 | Copyright (c) 2017 OpenAI (http://openai.com) 6 | 7 | after the fork (June 2018): 8 | Copyright (c) 2018-2019 Stable-Baselines Team 9 | 10 | 11 | Permission is hereby granted, free of charge, to any person obtaining a copy 12 | of this software and associated documentation files (the "Software"), to deal 13 | in the Software without restriction, including without limitation the rights 14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | copies of the Software, and to permit persons to whom the Software is 16 | furnished to do so, subject to the following conditions: 17 | 18 | The above copyright notice and this permission notice shall be included in 19 | all copies or substantial portions of the Software. 20 | 21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | THE SOFTWARE. 28 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscape.py: -------------------------------------------------------------------------------- 1 | """Defines the Landscape class.""" 2 | import abc 3 | 4 | import numpy as np 5 | 6 | from flexs.types import SEQUENCES_TYPE 7 | 8 | 9 | class Landscape(abc.ABC): 10 | """ 11 | Base class for all landscapes and for `flexs.Model`. 12 | 13 | Attributes: 14 | cost (int): Number of sequences whose fitness has been evaluated. 15 | name (str): A human-readable name for the landscape (often contains 16 | parameter values in the name) which is used when logging explorer runs. 
17 | 18 | """ 19 | 20 | def __init__(self, name: str): 21 | """Create Landscape, setting `name` and setting `cost` to zero.""" 22 | self.cost = 0 23 | self.name = name 24 | 25 | @abc.abstractmethod 26 | def _fitness_function(self, sequences: SEQUENCES_TYPE) -> np.ndarray: 27 | pass 28 | 29 | def get_fitness(self, sequences: SEQUENCES_TYPE) -> np.ndarray: 30 | """ 31 | Score a list/numpy array of sequences. 32 | 33 | This public method should not be overridden – new landscapes should 34 | override the private `_fitness_function` method instead. This method 35 | increments `self.cost` and then calls and returns `_fitness_function`. 36 | 37 | Args: 38 | sequences: A list/numpy array of sequence strings to be scored. 39 | 40 | Returns: 41 | Scores for each sequence. 42 | 43 | """ 44 | self.cost += len(sequences) 45 | return self._fitness_function(sequences) 46 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/src/protein/tasks/mlm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from fairseq import utils 5 | from fairseq.data import Dictionary 6 | from fairseq.tasks import FairseqTask, register_task 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | @register_task('prot_mlm') 11 | class ProtMaskedLMTask(FairseqTask): 12 | """ 13 | Args: 14 | dictionary (Dictionary): the dictionary for the input of the task 15 | """ 16 | 17 | def __init__(self, args, dictionary): 18 | super().__init__(args) 19 | self.args = args 20 | self.dictionary = dictionary 21 | self.seed = args.seed 22 | self.preds = [] 23 | self.targets = [] 24 | if not hasattr(args, 'max_positions'): 25 | self._max_positions = ( 26 | args.max_source_positions, 27 | args.max_target_positions, 28 | ) 29 | else: 30 | self._max_positions = args.max_positions 31 | args.tokens_per_sample = self._max_positions 32 | # add mask token 33 | self.mask_idx = dictionary.add_symbol("<mask>") 34 | 35 | @classmethod 36 | def setup_task(cls, args, **kwargs): 37 | paths = utils.split_paths(args.data) 38 | assert len(paths) > 0 39 | dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) 40 | logger.info("dictionary: {} types".format(len(dictionary))) 41 | return cls(args, dictionary) 42 | 43 | @property 44 | def source_dictionary(self): 45 | return self.dictionary -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/src/protein/tasks/pmlm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from fairseq import utils 5 | from fairseq.data import Dictionary 6 | 7 | from fairseq.tasks import FairseqTask, register_task 8 | logger = logging.getLogger(__name__) 9 | 10 | @register_task('prot_pmlm') 11 | class ProtPairwiseMaskedLMTask(FairseqTask): 12 | """ 13 | Args: 14 | dictionary (Dictionary): the dictionary for the input of the task 15 | """ 16 | def __init__(self, args, dictionary): 17 | super().__init__(args) 18 | self.args = args 19 | self.dictionary = dictionary 20 | self.seed = args.seed 21 | self.preds = [] 22 | self.targets = [] 23 | if not hasattr(args, 'max_positions'): 24 | self._max_positions = ( 25 | args.max_source_positions, 26 | args.max_target_positions, 27 | ) 28 | else: 29 | self._max_positions = args.max_positions 30 | args.tokens_per_sample = self._max_positions 31 | # add mask token 32 | self.mask_idx = dictionary.add_symbol("<mask>") 33 | 34 | @classmethod
35 | def setup_task(cls, args, **kwargs): 36 | paths = utils.split_paths(args.data) 37 | assert len(paths) > 0 38 | dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) 39 | logger.info("dictionary: {} types".format(len(dictionary))) 40 | return cls(args, dictionary) 41 | 42 | @property 43 | def source_dictionary(self): 44 | return self.dictionary -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/src/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def pearson_loss(x, y): 5 | mean_x = torch.mean(x) 6 | mean_y = torch.mean(y) 7 | xm = x.sub(mean_x) 8 | ym = y.sub(mean_y) 9 | r_num = xm.dot(ym) 10 | r_den = torch.norm(xm, 2) * torch.norm(ym, 2) 11 | r_val = r_num / r_den 12 | return 1 - r_val 13 | 14 | def pearson_correlation_loss(y_true, y_pred, normalized=False): 15 | """ 16 | Calculate Pearson correlation loss. 17 | :param y_true: distance matrix tensor of size (batch_size, batch_size) 18 | :param y_pred: distance matrix tensor of size (batch_size, batch_size) 19 | :param normalized: if True, Softmax is applied to the distance matrix 20 | :return: loss tensor 21 | """ 22 | if normalized: 23 | y_true = F.softmax(y_true, dim=-1) 24 | y_pred = F.softmax(y_pred, dim=-1) 25 | 26 | sum_true = torch.sum(y_true) 27 | sum2_true = torch.sum(torch.pow(y_true, 2)) # elementwise square, like np.power(a, 2) 28 | 29 | sum_pred = torch.sum(y_pred) 30 | sum2_pred = torch.sum(torch.pow(y_pred, 2)) 31 | 32 | prod = torch.sum(y_true * y_pred) 33 | n = y_true.shape[0] # batch size 34 | 35 | corr = n * prod - sum_true * sum_pred 36 | corr /= torch.sqrt(n * sum2_true - sum_true * sum_true + torch.finfo(torch.float32).eps) # K.epsilon() == 1e-7 37 | corr /= torch.sqrt(n * sum2_pred - sum_pred * sum_pred + torch.finfo(torch.float32).eps) 38 | 39 | return 1 - corr 40 | 41 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/models/global_epistasis_model.py: -------------------------------------------------------------------------------- 1 | """Define a global epistasis model.""" 2 | import tensorflow as tf 3 | 4 | from . import keras_model 5 | 6 | 7 | class GlobalEpistasisModel(keras_model.KerasModel): 8 | """ 9 | Global epistasis model. 10 | 11 | Weighted sum of input features followed by several dense layers. 12 | A simple, but relatively ineffective nonlinear model.
13 | """ 14 | 15 | def __init__( 16 | self, 17 | seq_len: int, 18 | hidden_size: int, 19 | alphabet: int, 20 | loss="MSE", 21 | name: str = None, 22 | batch_size: int = 256, 23 | epochs: int = 20, 24 | ): 25 | """Create a global epistasis model.""" 26 | model = tf.keras.models.Sequential( 27 | [ 28 | tf.keras.layers.Flatten(), 29 | tf.keras.layers.Dense( 30 | 1, input_shape=(seq_len, len(alphabet)), activation="relu" 31 | ), 32 | tf.keras.layers.Dense(hidden_size, activation="relu"), 33 | tf.keras.layers.Dense(hidden_size, activation="relu"), 34 | tf.keras.layers.Dense(1), 35 | ] 36 | ) 37 | model.compile(loss=loss, optimizer="adam", metrics=["mse"]) 38 | 39 | if name is None: 40 | name = f"MLP_hidden_size_{hidden_size}" 41 | 42 | super().__init__( 43 | model, 44 | alphabet=alphabet, 45 | name=name, 46 | batch_size=batch_size, 47 | epochs=epochs, 48 | ) 49 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/src/protein/tasks/mlm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from fairseq import utils 5 | from fairseq.data import Dictionary 6 | from fairseq.tasks import FairseqTask, register_task 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | @register_task('prot_mlm') 11 | class ProtMaskedLMTask(FairseqTask): 12 | """ 13 | Args: 14 | dictionary (Dictionary): the dictionary for the input of the task 15 | """ 16 | 17 | def __init__(self, args, dictionary): 18 | super().__init__(args) 19 | self.args = args 20 | self.dictionary = dictionary 21 | self.seed = args.seed 22 | self.preds = [] 23 | self.targets = [] 24 | if not hasattr(args, 'max_positions'): 25 | self._max_positions = ( 26 | args.max_source_positions, 27 | args.max_target_positions, 28 | ) 29 | else: 30 | self._max_positions = args.max_positions 31 | args.tokens_per_sample = self._max_positions 32 | # add mask token 33 | self.mask_idx = dictionary.add_symbol("") 34 | 35 | @classmethod 36 | def setup_task(cls, args, **kwargs): 37 | paths = utils.split_paths(args.data) 38 | assert len(paths) > 0 39 | dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) 40 | logger.info("dictionary: {} types".format(len(dictionary))) 41 | return cls(args, dictionary) 42 | 43 | @property 44 | def source_dictionary(self): 45 | return self.dictionary -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/src/protein/tasks/pmlm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from fairseq import utils 5 | from fairseq.data import Dictionary 6 | 7 | from fairseq.tasks import FairseqTask, register_task 8 | logger = logging.getLogger(__name__) 9 | 10 | @register_task('prot_pmlm') 11 | class ProtPairwiseMaskedLMTask(FairseqTask): 12 | """ 13 | Args: 14 | dictionary (Dictionary): the dictionary for the input of the task 15 | """ 16 | def __init__(self, args, dictionary): 17 | super().__init__(args) 18 | self.args = args 19 | self.dictionary = dictionary 20 | self.seed = args.seed 21 | self.preds = [] 22 | self.targets = [] 23 | if not hasattr(args, 'max_positions'): 24 | self._max_positions = ( 25 | args.max_source_positions, 26 | args.max_target_positions, 27 | ) 28 | else: 29 | self._max_positions = args.max_positions 30 | args.tokens_per_sample = self._max_positions 31 | # add 
mask token 32 | self.mask_idx = dictionary.add_symbol("<mask>") 33 | 34 | @classmethod 35 | def setup_task(cls, args, **kwargs): 36 | paths = utils.split_paths(args.data) 37 | assert len(paths) > 0 38 | dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) 39 | logger.info("dictionary: {} types".format(len(dictionary))) 40 | return cls(args, dictionary) 41 | 42 | @property 43 | def source_dictionary(self): 44 | return self.dictionary -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/src/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def pearson_loss(x, y): 5 | mean_x = torch.mean(x) 6 | mean_y = torch.mean(y) 7 | xm = x.sub(mean_x) 8 | ym = y.sub(mean_y) 9 | r_num = xm.dot(ym) 10 | r_den = torch.norm(xm, 2) * torch.norm(ym, 2) 11 | r_val = r_num / r_den 12 | return 1 - r_val 13 | 14 | def pearson_correlation_loss(y_true, y_pred, normalized=False): 15 | """ 16 | Calculate the Pearson correlation loss. 17 | :param y_true: distance matrix tensor of size (batch_size, batch_size) 18 | :param y_pred: distance matrix tensor of size (batch_size, batch_size) 19 | :param normalized: if True, softmax is applied to the distance matrix 20 | :return: loss tensor 21 | """ 22 | if normalized: 23 | y_true = F.softmax(y_true, dim=-1) # F.softmax takes `dim`, not `axis` 24 | y_pred = F.softmax(y_pred, dim=-1) 25 | 26 | sum_true = torch.sum(y_true) 27 | sum2_true = torch.sum(torch.pow(y_true, 2)) # sum of elementwise squares 28 | 29 | sum_pred = torch.sum(y_pred) 30 | sum2_pred = torch.sum(torch.pow(y_pred, 2)) 31 | 32 | prod = torch.sum(y_true * y_pred) 33 | n = y_true.shape[0] # batch size 34 | 35 | corr = n * prod - sum_true * sum_pred 36 | corr /= torch.sqrt(n * sum2_true - sum_true * sum_true + torch.finfo(torch.float32).eps) # eps for numerical stability (cf. K.epsilon() == 1e-7) 37 | corr /= torch.sqrt(n * sum2_pred - sum_pred * sum_pred + torch.finfo(torch.float32).eps) 38 | 39 | return 1 - corr 40 | 41 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/Makefile: -------------------------------------------------------------------------------- 1 | SHELL=/bin/bash 2 | LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py 3 | 4 | pytest: 5 | ./scripts/run_tests.sh 6 | 7 | type: 8 | pytype -j auto 9 | 10 | lint: 11 | # stop the build if there are Python syntax errors or undefined names 12 | # see https://lintlyci.github.io/Flake8Rules/ 13 | flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics 14 | # exit-zero treats all errors as warnings. 
15 | flake8 ${LINT_PATHS} --count --exit-zero --statistics 16 | 17 | format: 18 | # Sort imports 19 | isort ${LINT_PATHS} 20 | # Reformat using black 21 | black -l 127 ${LINT_PATHS} 22 | 23 | check-codestyle: 24 | # Sort imports 25 | isort --check ${LINT_PATHS} 26 | # Reformat using black 27 | black --check -l 127 ${LINT_PATHS} 28 | 29 | commit-checks: format type lint 30 | 31 | doc: 32 | cd docs && make html 33 | 34 | spelling: 35 | cd docs && make spelling 36 | 37 | clean: 38 | cd docs && make clean 39 | 40 | # Build docker images 41 | # If you do export RELEASE=True, it will also push them 42 | docker: docker-cpu docker-gpu 43 | 44 | docker-cpu: 45 | ./scripts/build_docker.sh 46 | 47 | docker-gpu: 48 | USE_GPU=True ./scripts/build_docker.sh 49 | 50 | # PyPi package release 51 | release: 52 | python setup.py sdist 53 | python setup.py bdist_wheel 54 | twine upload dist/* 55 | 56 | # Test PyPi package release 57 | test-release: 58 | python setup.py sdist 59 | python setup.py bdist_wheel 60 | twine upload --repository-url https://test.pypi.org/legacy/ dist/* 61 | 62 | .PHONY: clean spelling doc lint format check-codestyle commit-checks 63 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import pandas as pd 4 | 5 | 6 | def sequence_to_mutation(sequence, starting_sequence): 7 | ''' 8 | Parameters 9 | ---------- 10 | sequence: str 11 | return: ';'.join(WiM) (wild-type W at position i mutated to M), plus the number of mutated sites 12 | ''' 13 | mutant_points = [] 14 | assert len(sequence) == len(starting_sequence) 15 | for i in range(len(sequence)): 16 | if sequence[i] != starting_sequence[i]: 17 | mutant_points += ["%s%s%s"%(starting_sequence[i], i+1, sequence[i])] 18 | return ';'.join(mutant_points), len(mutant_points) 19 | 20 | 21 | def print_and_save_sequences(sequences, starting_sequence, save_name="results", save_json=True): 22 | # sequences: dict, key: sequence, value: [score, embedding, ensemble_uncertainty] 23 | # results: list, [mutation, score, embedding] 24 | 25 | from collections import Counter 26 | mutant_sites_counter = Counter() 27 | 28 | results = [] 29 | for sequence, value in sequences.items(): 30 | mutation, mutant_sites_num = sequence_to_mutation(sequence, starting_sequence) 31 | results.append([mutation, value[0], value[1]]) 32 | mutant_sites_counter[mutant_sites_num] += 1 33 | 34 | sorted_results = sorted(results, key=lambda x:x[1], reverse=True) 35 | print(mutant_sites_counter) 36 | 37 | with open(save_name + ".txt", "w") as f: 38 | for item in sorted_results: 39 | f.write("{0} {1}\n".format(item[0], item[1])) 40 | 41 | df = pd.DataFrame(sorted_results, columns=['mutation', 'score', 'embedding']) 42 | 43 | if save_json: 44 | df.to_json(save_name + ".json", orient='records', lines=True) -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/model.py: -------------------------------------------------------------------------------- 1 | """Defines base Model class.""" 2 | import abc 3 | from typing import Any, List 4 | 5 | import numpy as np 6 | 7 | import flexs 8 | from flexs.types import SEQUENCES_TYPE 9 | 10 | 11 | class Model(flexs.Landscape, abc.ABC): 12 | """ 13 | Base model class. Inherits from `flexs.Landscape` and adds an additional 14 | `train` method. 
15 | 16 | """ 17 | 18 | @abc.abstractmethod 19 | def train(self, sequences: SEQUENCES_TYPE, labels: List[Any]): 20 | """ 21 | Train model. 22 | 23 | This function is called whenever you would want your model to update itself 24 | based on the set of sequences it has measurements for. 25 | 26 | """ 27 | pass 28 | 29 | 30 | class LandscapeAsModel(Model): 31 | """ 32 | This simple class wraps a `flexs.Landscape` in a `flexs.Model` to allow running 33 | experiments against a perfect model. 34 | 35 | This class's `_fitness_function` simply calls the landscape's `_fitness_function`. 36 | """ 37 | 38 | def __init__(self, landscape: flexs.Landscape): 39 | """ 40 | Create a `flexs.Model` out of a `flexs.Landscape`. 41 | 42 | Args: 43 | landscape: Landscape to wrap in a model. 44 | 45 | """ 46 | super().__init__(f"LandscapeAsModel={landscape.name}") 47 | self.landscape = landscape 48 | 49 | def _fitness_function(self, sequences: SEQUENCES_TYPE) -> np.ndarray: 50 | return self.landscape._fitness_function(sequences) 51 | 52 | def gradient_function(self, sequences: SEQUENCES_TYPE) -> np.ndarray: 53 | return self.landscape.gradient_function(sequences) 54 | 55 | def train(self, sequences: SEQUENCES_TYPE, labels: List[Any]): 56 | """No-op.""" 57 | pass 58 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/ensemble.py: -------------------------------------------------------------------------------- 1 | """Defines the Ensemble class.""" 2 | from typing import Callable, List 3 | 4 | import numpy as np 5 | 6 | import flexs 7 | from flexs.types import SEQUENCES_TYPE 8 | 9 | 10 | class Ensemble(flexs.Model): 11 | """ 12 | Class to ensemble models or landscapes together. 13 | 14 | Attributes: 15 | models (List[flexs.Landscape]): List of landscapes/models being ensembled. 16 | combine_with (Callable[[np.ndarray], np.ndarray]): Function to combine ensemble 17 | predictions. 18 | 19 | """ 20 | 21 | def __init__( 22 | self, 23 | models: List[flexs.Landscape], 24 | combine_with: Callable[[np.ndarray], np.ndarray] = lambda x: np.mean(x, axis=1), 25 | ): 26 | """ 27 | Create ensemble. 28 | 29 | Args: 30 | models: List of landscapes/models to ensemble. 31 | combine_with: A function that takes in a matrix of scores 32 | (num_seqs, num_models) and combines ensembled model scores into an 33 | array (num_seqs,). 34 | 35 | """ 36 | name = f"Ens({'|'.join(model.name for model in models)})" 37 | super().__init__(name) 38 | 39 | self.models = models 40 | self.combine_with = combine_with 41 | 42 | def train(self, sequences: SEQUENCES_TYPE, labels: np.ndarray): 43 | """ 44 | Train each model in `self.models`. 45 | 46 | Args: 47 | sequences: Training sequences 48 | labels: Training labels 49 | 50 | """ 51 | for model in self.models: 52 | model.train(sequences, labels) 53 | 54 | def _fitness_function(self, sequences): 55 | scores = np.stack( 56 | [model.get_fitness(sequences) for model in self.models], axis=1 57 | ) 58 | 59 | return self.combine_with(scores) 60 | -------------------------------------------------------------------------------- /pmlm/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This folder primarily hosts the code for pre-training the pairwise masked language model for proteins (PMLM). 
4 | 5 | # Environment 6 | 7 | Follow the steps below to set up the Conda environment: 8 | ``` 9 | conda create -n pmlm python==3.8 10 | conda activate pmlm 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | Additionally, you need to install PyTorch. The version to be installed is dependent on your GPU driver version. For instance: 15 | ``` 16 | pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 17 | ``` 18 | Or, for a CPU-only version: 19 | ``` 20 | pip install torch==1.12.0+cpu torchvision==0.13.0+cpu torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu 21 | ``` 22 | 23 | # Getting Started 24 | 25 | ## Data Preparation 26 | 27 | Download and preprocess the protein sequence data into two text files: one for training and one for validation. Each sequence should: 28 | 29 | - Be **written in uppercase letters** 30 | - Have **spaces inserted between residues** 31 | 32 | For example, each line should be formatted as: 33 | ``` 34 | C A S E F W S A W F ... C A D 35 | ``` 36 | 37 | After formatting, use `fairseq-preprocess` to convert these files into Fairseq binary format: 38 | 39 | ``` 40 | fairseq-preprocess \ 41 | --only-source \ 42 | --trainpref /uniref50.train.seqs \ 43 | --validpref /uniref50.valid.seqs \ 44 | --destdir /generated/uniref50 \ 45 | --workers 120 \ 46 | --srcdict pmlm/src/protein/dict.txt 47 | ``` 48 | 49 | ## Pretraining the Model 50 | To pretrain a model using PMLM with the preprocessed files, run: 51 | ``` 52 | cd pmlm 53 | bash script/pretrain.sh 54 | ``` 55 | 56 | ## Pretrained Model 57 | 58 | A pretrained model is publicly available on [figshare](https://doi.org/10.6084/m9.figshare.26892355). 59 | -------------------------------------------------------------------------------- /pmlm/script/pretrain.sh: -------------------------------------------------------------------------------- 1 | HERE=$(cd "$(dirname "$0")";pwd) 2 | 3 | USER_DIR=$HERE/../src/protein 4 | 5 | DATA_DIR=/mnt/data/generated_data/uniref50_2018_03/ # Path to the binarized data files produced by fairseq-preprocess 6 | 7 | MAX_EPOCH=1000 # change this 8 | WARMUP_UPDATES=20000 9 | TOTAL_UPDATES=2000000 10 | PEAK_LR=0.0001 11 | TOKENS_PER_SAMPLE=768 12 | MAX_POSITIONS=768 13 | MAX_TOKENS=768 14 | UPDATE_FREQ=16 15 | LOG_INTERVAL=10 16 | SAVE_INTERVAL=1 17 | NUM_WORKERS=0 18 | VALID_SUBSET=valid 19 | DDP_BACKEND=no_c10d 20 | BEST_CHECKPOINT_METRIC=loss 21 | 22 | TASK=prot_pmlm 23 | CRIT=prot_pmlm 24 | ARCH=prot_pmlm_1b 25 | 26 | PRETRAIN_TASK=pcomb 27 | 28 | LOG_DIR=$HERE/../log/ 29 | 30 | CHECKPOINT_DIR=$HERE/../ckpt/ur50-$PRETRAIN_TASK-$ARCH-ckpt/ 31 | 32 | rm -rf $LOG_DIR 33 | mkdir -p $LOG_DIR 34 | 35 | fairseq-train $DATA_DIR --fp16 \ 36 | --fix-batches-to-gpus \ 37 | --distributed-no-spawn \ 38 | --task $TASK \ 39 | --criterion $CRIT \ 40 | --arch $ARCH \ 41 | --optimizer adam \ 42 | --adam-betas '(0.9,0.98)' \ 43 | --adam-eps 1e-6 --clip-norm 1.0 \ 44 | --lr-scheduler polynomial_decay \ 45 | --lr $PEAK_LR \ 46 | --total-num-update $TOTAL_UPDATES \ 47 | --warmup-updates $WARMUP_UPDATES \ 48 | --update-freq $UPDATE_FREQ \ 49 | --dropout 0.1 \ 50 | --weight-decay 0.01 \ 51 | --tokens-per-sample $TOKENS_PER_SAMPLE \ 52 | --max-positions $MAX_POSITIONS \ 53 | --max-tokens $MAX_TOKENS \ 54 | --max-epoch $MAX_EPOCH \ 55 | --log-format simple \ 56 | --log-interval $LOG_INTERVAL \ 57 | --valid-subset $VALID_SUBSET \ 58 | --save-interval $SAVE_INTERVAL \ 59 | --save-interval-updates 1000 \ 60 | --keep-interval-updates 3 \ 
61 | --best-checkpoint-metric $BEST_CHECKPOINT_METRIC \ 62 | --ddp-backend=$DDP_BACKEND \ 63 | --tensorboard-logdir $LOG_DIR \ 64 | --num-workers $NUM_WORKERS \ 65 | --save-dir $CHECKPOINT_DIR \ 66 | --sample-break-mode eos \ 67 | --skip-invalid-size-inputs-valid-test \ 68 | --pretrain-task $PRETRAIN_TASK \ 69 | --user-dir $USER_DIR | tee $LOG_DIR/log.txt 70 | 71 | -------------------------------------------------------------------------------- /mu-search/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | plotting_data_subset/ 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | 110 | fluorescence-model/ 111 | cloud/runs/ 112 | events.* 113 | *.rst 114 | *.csv 115 | logs/ 116 | efficiency_trash 117 | efficiency 118 | docs 119 | *.bin 120 | *.pdf 121 | *.png 122 | *.json 123 | paper_code 124 | .github 125 | examples/mean_plot.py 126 | examples/efficiency 127 | examples/efficiency_100_5000.pdf 128 | examples/robustness 129 | examples/figs 130 | *.sh 131 | examples/all_plot.py 132 | -------------------------------------------------------------------------------- /mu-search/src/flexs/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | plotting_data_subset/ 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | 110 | fluorescence-model/ 111 | cloud/runs/ 112 | *.pt 113 | *.rst 114 | *.txt 115 | *.pdf 116 | *.png 117 | *.csv 118 | *.json 119 | flexs/baselines/explorers/stable_baselines3/.dockerignore 120 | flexs/baselines/explorers/stable_baselines3/Dockerfile 121 | *.bat 122 | efficiency_trash 123 | restore 124 | robustness 125 | tests 126 | run.sh 127 | mu.sh 128 | dirichlet_ppo.sh 129 | copy.sh 130 | output.png 131 | output.txt 132 | *.md -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | # This includes the license file in the wheel. 3 | license_files = LICENSE 4 | project_urls = 5 | Code = https://github.com/DLR-RM/stable-baselines3 6 | Documentation = https://stable-baselines3.readthedocs.io/ 7 | 8 | [tool:pytest] 9 | # Deterministic ordering for tests; useful for pytest-xdist. 
10 | env = 11 | PYTHONHASHSEED=0 12 | filterwarnings = 13 | # Tensorboard warnings 14 | ignore::DeprecationWarning:tensorboard 15 | # Gym warnings 16 | ignore:Parameters to load are deprecated.:DeprecationWarning 17 | ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning 18 | ignore::UserWarning:gym 19 | ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning 20 | ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning 21 | markers = 22 | expensive: marks tests as expensive (deselect with '-m "not expensive"') 23 | 24 | [pytype] 25 | inputs = stable_baselines3 26 | disable = pyi-error 27 | 28 | [flake8] 29 | ignore = W503,W504,E203,E231 # line breaks before and after binary operators 30 | # Ignore import not used when aliases are defined 31 | per-file-ignores = 32 | ./stable_baselines3/__init__.py:F401 33 | ./stable_baselines3/common/__init__.py:F401 34 | ./stable_baselines3/common/envs/__init__.py:F401 35 | ./stable_baselines3/a2c/__init__.py:F401 36 | ./stable_baselines3/ddpg/__init__.py:F401 37 | ./stable_baselines3/dqn/__init__.py:F401 38 | ./stable_baselines3/her/__init__.py:F401 39 | ./stable_baselines3/ppo/__init__.py:F401 40 | ./stable_baselines3/sac/__init__.py:F401 41 | ./stable_baselines3/td3/__init__.py:F401 42 | ./stable_baselines3/common/vec_env/__init__.py:F401 43 | exclude = 44 | # No need to traverse our git directory 45 | .git, 46 | # There's no value in checking cache directories 47 | __pycache__, 48 | # Don't check the doc 49 | docs/ 50 | # This contains our built documentation 51 | build, 52 | # This contains builds of flake8 that we don't want to check 53 | dist 54 | *.egg-info 55 | max-complexity = 15 56 | # The GitHub editor is 127 chars wide 57 | max-line-length = 127 58 | 59 | [isort] 60 | profile = black 61 | line_length = 127 62 | src_paths = stable_baselines3 63 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/running_mean_std.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import numpy as np 4 | 5 | 6 | class RunningMeanStd: 7 | def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()): 8 | """ 9 | Calculates the running mean and std of a data stream 10 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 11 | 12 | :param epsilon: helps with arithmetic issues 13 | :param shape: the shape of the data stream's output 14 | """ 15 | self.mean = np.zeros(shape, np.float64) 16 | self.var = np.ones(shape, np.float64) 17 | self.count = epsilon 18 | 19 | def copy(self) -> "RunningMeanStd": 20 | """ 21 | :return: Return a copy of the current object. 22 | """ 23 | new_object = RunningMeanStd(shape=self.mean.shape) 24 | new_object.mean = self.mean.copy() 25 | new_object.var = self.var.copy() 26 | new_object.count = float(self.count) 27 | return new_object 28 | 29 | def combine(self, other: "RunningMeanStd") -> None: 30 | """ 31 | Combine stats from another ``RunningMeanStd`` object. 32 | 33 | :param other: The other object to combine with. 
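        Example (a minimal sketch with toy data):

            a, b = RunningMeanStd(), RunningMeanStd()
            a.update(np.array([1.0, 2.0]))
            b.update(np.array([3.0, 4.0]))
            a.combine(b)  # `a` now tracks the stats of all four samples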
34 | """ 35 | self.update_from_moments(other.mean, other.var, other.count) 36 | 37 | def update(self, arr: np.ndarray) -> None: 38 | batch_mean = np.mean(arr, axis=0) 39 | batch_var = np.var(arr, axis=0) 40 | batch_count = arr.shape[0] 41 | self.update_from_moments(batch_mean, batch_var, batch_count) 42 | 43 | def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: Union[int, float]) -> None: 44 | delta = batch_mean - self.mean 45 | tot_count = self.count + batch_count 46 | 47 | new_mean = self.mean + delta * batch_count / tot_count 48 | m_a = self.var * self.count 49 | m_b = batch_var * batch_count 50 | m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) 51 | new_var = m_2 / (self.count + batch_count) 52 | 53 | new_count = batch_count + self.count 54 | 55 | self.mean = new_mean 56 | self.var = new_var 57 | self.count = new_count 58 | -------------------------------------------------------------------------------- /mu-search/tests/test_landscapes.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import warnings 3 | 4 | import flexs 5 | from flexs.utils import sequence_utils as s_utils 6 | 7 | 8 | def test_additive_aav_packaging(): 9 | problem = flexs.landscapes.additive_aav_packaging.registry()["heart"] 10 | landscape = flexs.landscapes.AdditiveAAVPackaging(**problem["params"]) 11 | 12 | test_seqs = s_utils.generate_random_sequences(90, 100, s_utils.AAS) 13 | landscape.get_fitness(test_seqs) 14 | 15 | 16 | def test_rna(): 17 | # Since ViennaRNA is an optional dependency, only test if installed 18 | try: 19 | problem = flexs.landscapes.rna.registry()["C20_L100_RNA1+2"] 20 | landscape = flexs.landscapes.RNABinding(**problem["params"]) 21 | 22 | test_seqs = s_utils.generate_random_sequences(100, 100, s_utils.RNAA) 23 | landscape.get_fitness(test_seqs) 24 | 25 | except ImportError: 26 | warnings.warn( 27 | "Skipping RNABinding landscape test since" "ViennaRNA not installed." 28 | ) 29 | 30 | 31 | def test_rosetta(): 32 | # Since PyRosetta is an optional dependency, only test if installed 33 | try: 34 | problem = flexs.landscapes.rosetta.registry()["3msi"] 35 | landscape = flexs.landscapes.RosettaFolding(**problem["params"]) 36 | 37 | seq_length = len(landscape.wt_pose.sequence()) 38 | test_seqs = s_utils.generate_random_sequences(seq_length, 100, s_utils.AAS) 39 | landscape.get_fitness(test_seqs) 40 | 41 | except ImportError: 42 | warnings.warn( 43 | "Skipping RosettaFolding landscape test since PyRosetta not installed." 44 | ) 45 | 46 | 47 | def test_tf_binding(): 48 | problem = flexs.landscapes.tf_binding.registry()["SIX6_REF_R1"] 49 | landscape = flexs.landscapes.TFBinding(**problem["params"]) 50 | 51 | test_seqs = s_utils.generate_random_sequences(8, 100, s_utils.DNAA) 52 | landscape.get_fitness(test_seqs) 53 | 54 | 55 | # TODO: This test takes too long for github actions. Needs further investigation. 
56 | """ 57 | def test_bert_gfp(): 58 | landscape = flexs.landscapes.BertGFPBrightness() 59 | 60 | seq_length = len(landscape.gfp_wt_sequence) 61 | test_seqs = s_utils.generate_random_sequences(seq_length, 100, s_utils.AAS) 62 | landscape.get_fitness(test_seqs) 63 | 64 | # Clean up downloaded model 65 | shutil.rmtree("fluorescence-model") 66 | """ 67 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/models/keras_model.py: -------------------------------------------------------------------------------- 1 | """Define the base KerasModel class.""" 2 | from typing import Callable 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | import flexs 8 | from flexs.types import SEQUENCES_TYPE 9 | from flexs.utils import sequence_utils as s_utils 10 | 11 | 12 | class KerasModel(flexs.Model): 13 | """A wrapper around tensorflow/keras models.""" 14 | 15 | def __init__( 16 | self, 17 | model, 18 | alphabet, 19 | name, 20 | batch_size=256, 21 | epochs=20, 22 | custom_train_function: Callable[[tf.Tensor, tf.Tensor], None] = None, 23 | custom_predict_function: Callable[[tf.Tensor], np.ndarray] = None, 24 | ): 25 | """ 26 | Wrap a tensorflow/keras model. 27 | 28 | Args: 29 | model: A callable and fittable keras model. 30 | alphabet: Alphabet string. 31 | name: Human readable description of model (used for logging). 32 | batch_size: Batch size for `model.fit` and `model.predict`. 33 | epochs: Number of epochs to train for. 34 | custom_train_function: A function that receives a tensor of one-hot 35 | sequences and labels and trains `model`. 36 | custom_predict_function: A function that receives a tensor of one-hot 37 | sequences and predictions. 38 | 39 | """ 40 | super().__init__(name) 41 | 42 | self.model = model 43 | self.alphabet = alphabet 44 | 45 | self.name = name 46 | self.epochs = epochs 47 | self.batch_size = batch_size 48 | 49 | def train( 50 | self, sequences: SEQUENCES_TYPE, labels: np.ndarray, verbose: bool = False 51 | ): 52 | """Train keras model.""" 53 | one_hots = tf.convert_to_tensor( 54 | np.array( 55 | [s_utils.string_to_one_hot(seq, self.alphabet) for seq in sequences] 56 | ), 57 | dtype=tf.float32, 58 | ) 59 | labels = tf.convert_to_tensor(labels) 60 | 61 | self.model.fit( 62 | one_hots, 63 | labels, 64 | batch_size=self.batch_size, 65 | epochs=self.epochs, 66 | verbose=verbose, 67 | ) 68 | 69 | def _fitness_function(self, sequences): 70 | one_hots = tf.convert_to_tensor( 71 | np.array( 72 | [s_utils.string_to_one_hot(seq, self.alphabet) for seq in sequences] 73 | ), 74 | dtype=tf.float32, 75 | ) 76 | 77 | return self.model.predict(one_hots, batch_size=self.batch_size).squeeze(axis=1) 78 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/pure_random.py: -------------------------------------------------------------------------------- 1 | """Defines the Random explorer class.""" 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import flexs 8 | from flexs.utils import sequence_utils as s_utils 9 | from copy import deepcopy 10 | 11 | 12 | class PureRandom(flexs.Explorer): 13 | """A simple random explorer. 14 | 15 | Chooses a random previously measured sequence and mutates it. 16 | 17 | A good baseline to compare other search strategies against. 18 | 19 | Since random search is not data-driven, the model is only used to score 20 | sequences, but not to guide the search strategy. 
21 | """ 22 | 23 | def __init__( 24 | self, 25 | model: flexs.Model, 26 | rounds: int, 27 | starting_sequence: str, 28 | sequences_batch_size: int, 29 | model_queries_per_batch: int, 30 | alphabet: str, 31 | n: float = 3, 32 | elitist: bool = False, 33 | seed: Optional[int] = None, 34 | log_file: Optional[str] = None, 35 | ): 36 | """ 37 | Create a random search explorer. 38 | 39 | Args: 40 | mu: Average number of residue mutations from parent for generated sequences. 41 | elitist: If true, will propose the top `sequences_batch_size` sequences 42 | generated according to `model`. If false, randomly proposes 43 | `sequences_batch_size` sequences without taking model score into 44 | account (true random search). 45 | seed: Integer seed for random number generator. 46 | 47 | """ 48 | name = f"Random_points={n}" 49 | 50 | super().__init__( 51 | model, 52 | name, 53 | rounds, 54 | sequences_batch_size, 55 | model_queries_per_batch, 56 | starting_sequence, 57 | log_file, 58 | ) 59 | self.n = n 60 | self.rng = np.random.default_rng(seed) 61 | self.alphabet = alphabet 62 | 63 | def propose_sequences( 64 | self, measured_sequences: pd.DataFrame 65 | ) -> Tuple[np.ndarray, np.ndarray]: 66 | """Propose top `sequences_batch_size` sequences for evaluation.""" 67 | old_sequences = measured_sequences["sequence"] 68 | old_sequence_set = set(old_sequences) 69 | new_seqs = set() 70 | 71 | while len(new_seqs) <= self.model_queries_per_batch: 72 | seq = self.starting_sequence 73 | new_seq = s_utils.generate_random_n_points_mutants( 74 | seq, self.n, alphabet=self.alphabet 75 | ) 76 | if new_seq not in old_sequence_set: 77 | new_seqs.add(new_seq) 78 | 79 | new_seqs = np.array(list(new_seqs)) 80 | preds = self.model.get_fitness(new_seqs) 81 | 82 | idxs = self.rng.integers(0, len(new_seqs), size=self.sequences_batch_size) 83 | 84 | return new_seqs[idxs], preds[idxs] 85 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/utils/seq_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | def hamming_distance(seq_1, seq_2): 6 | return sum([x!=y for x, y in zip(seq_1, seq_2)]) 7 | 8 | def random_mutation(sequence, alphabet, num_mutations): 9 | wt_seq = list(sequence) 10 | for _ in range(num_mutations): 11 | idx = np.random.randint(len(sequence)) 12 | wt_seq[idx] = alphabet[np.random.randint(len(alphabet))] 13 | new_seq = ''.join(wt_seq) 14 | return new_seq 15 | 16 | def sequence_to_one_hot(sequence, alphabet): 17 | # Input: - sequence: [sequence_length] 18 | # - alphabet: [alphabet_size] 19 | # Output: - one_hot: [sequence_length, alphabet_size] 20 | 21 | alphabet_dict = {x: idx for idx, x in enumerate(alphabet)} 22 | one_hot = F.one_hot(torch.tensor([alphabet_dict[x] for x in sequence]).long(), num_classes=len(alphabet)) 23 | return one_hot 24 | 25 | def sequences_to_tensor(sequences, alphabet): 26 | # Input: - sequences: [batch_size, sequence_length] 27 | # - alphabet: [alphabet_size] 28 | # Output: - one_hots: [batch_size, alphabet_size, sequence_length] 29 | 30 | one_hots = torch.stack([sequence_to_one_hot(seq, alphabet) for seq in sequences], dim=0) 31 | one_hots = torch.permute(one_hots, [0, 2, 1]).float() 32 | return one_hots 33 | 34 | def sequences_to_mutation_sets(sequences, alphabet, wt_sequence, context_radius): 35 | # Input: - sequences: [batch_size, sequence_length] 36 | # - alphabet: [alphabet_size] 37 | # - wt_sequence: 
[sequence_length] 38 | # - context_radius: integer 39 | # Output: - mutation_sets: [batch_size, max_mutation_num, alphabet_size, 2*context_radius+1] 40 | # - mutation_sets_mask: [batch_size, max_mutation_num] 41 | 42 | context_width = 2 * context_radius + 1 43 | max_mutation_num = max(1, np.max([hamming_distance(seq, wt_sequence) for seq in sequences])) 44 | 45 | mutation_set_List, mutation_set_mask_List = [], [] 46 | for seq in sequences: 47 | one_hot = sequence_to_one_hot(seq, alphabet).numpy() 48 | one_hot_padded = np.pad(one_hot, ((context_radius, context_radius), (0, 0)), mode='constant', constant_values=0.0) 49 | 50 | mutation_set = [one_hot_padded[i:i+context_width] for i in range(len(seq)) if seq[i]!=wt_sequence[i]] 51 | padding_len = max_mutation_num - len(mutation_set) 52 | mutation_set_mask = [1.0] * len(mutation_set) + [0.0] * padding_len 53 | mutation_set += [np.zeros(shape=(context_width, len(alphabet)))] * padding_len 54 | 55 | mutation_set_List.append(mutation_set) 56 | mutation_set_mask_List.append(mutation_set_mask) 57 | 58 | mutation_sets = torch.tensor(np.array(mutation_set_List)).permute([0, 1, 3, 2]).float() 59 | mutation_sets_mask = torch.tensor(np.array(mutation_set_mask_List)).float() 60 | return mutation_sets, mutation_sets_mask 61 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/vec_env/vec_frame_stack.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | from gym import spaces 5 | 6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper 7 | from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations 8 | 9 | 10 | class VecFrameStack(VecEnvWrapper): 11 | """ 12 | Frame stacking wrapper for vectorized environment. Designed for image observations. 13 | 14 | Uses the StackedObservations class, or StackedDictObservations depending on the observations space 15 | 16 | :param venv: the vectorized environment to wrap 17 | :param n_stack: Number of frames to stack 18 | :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. 19 | If None, automatically detect channel to stack over in case of image observation or default to "last" (default). 
20 | Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces 21 | """ 22 | 23 | def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None): 24 | self.venv = venv 25 | self.n_stack = n_stack 26 | 27 | wrapped_obs_space = venv.observation_space 28 | 29 | if isinstance(wrapped_obs_space, spaces.Box): 30 | assert not isinstance( 31 | channels_order, dict 32 | ), f"Expected None or string for channels_order but received {channels_order}" 33 | self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) 34 | 35 | elif isinstance(wrapped_obs_space, spaces.Dict): 36 | self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) 37 | 38 | else: 39 | raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces") 40 | 41 | observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space) 42 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 43 | 44 | def step_wait( 45 | self, 46 | ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]: 47 | 48 | observations, rewards, dones, infos = self.venv.step_wait() 49 | 50 | observations, infos = self.stackedobs.update(observations, dones, infos) 51 | 52 | return observations, rewards, dones, infos 53 | 54 | def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: 55 | """ 56 | Reset all environments 57 | """ 58 | observation = self.venv.reset() # pytype:disable=annotation-type-mismatch 59 | 60 | observation = self.stackedobs.reset(observation) 61 | return observation 62 | 63 | def close(self) -> None: 64 | self.venv.close() 65 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa F401 2 | import typing 3 | from copy import deepcopy 4 | from typing import Optional, Type, Union 5 | 6 | from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper 7 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv 8 | from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan 11 | from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs 12 | from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack 13 | from stable_baselines3.common.vec_env.vec_monitor import VecMonitor 14 | from stable_baselines3.common.vec_env.vec_normalize import VecNormalize 15 | from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage 16 | from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder 17 | 18 | # Avoid circular import 19 | if typing.TYPE_CHECKING: 20 | from stable_baselines3.common.type_aliases import GymEnv 21 | 22 | 23 | def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]: 24 | """ 25 | Retrieve a ``VecEnvWrapper`` object by recursively searching. 
26 | 27 | :param env: 28 | :param vec_wrapper_class: 29 | :return: 30 | """ 31 | env_tmp = env 32 | while isinstance(env_tmp, VecEnvWrapper): 33 | if isinstance(env_tmp, vec_wrapper_class): 34 | return env_tmp 35 | env_tmp = env_tmp.venv 36 | return None 37 | 38 | 39 | def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]: 40 | """ 41 | :param env: 42 | :return: 43 | """ 44 | return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type 45 | 46 | 47 | def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool: 48 | """ 49 | Check if an environment is already wrapped by a given ``VecEnvWrapper``. 50 | 51 | :param env: 52 | :param vec_wrapper_class: 53 | :return: 54 | """ 55 | return unwrap_vec_wrapper(env, vec_wrapper_class) is not None 56 | 57 | 58 | # Define here to avoid circular import 59 | def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: 60 | """ 61 | Sync eval env and train env when using VecNormalize 62 | 63 | :param env: 64 | :param eval_env: 65 | """ 66 | env_tmp, eval_env_tmp = env, eval_env 67 | while isinstance(env_tmp, VecEnvWrapper): 68 | if isinstance(env_tmp, VecNormalize): 69 | # Only synchronize if observation normalization exists 70 | if hasattr(env_tmp, "obs_rms"): 71 | eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) 72 | eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms) 73 | env_tmp = env_tmp.venv 74 | eval_env_tmp = eval_env_tmp.venv 75 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/random.py: -------------------------------------------------------------------------------- 1 | """Defines the Random explorer class.""" 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import flexs 8 | from flexs.utils import sequence_utils as s_utils 9 | 10 | 11 | class Random(flexs.Explorer): 12 | """A simple random explorer. 13 | 14 | Chooses a random previously measured sequence and mutates it. 15 | 16 | A good baseline to compare other search strategies against. 17 | 18 | Since random search is not data-driven, the model is only used to score 19 | sequences, but not to guide the search strategy. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | model: flexs.Model, 25 | rounds: int, 26 | starting_sequence: str, 27 | sequences_batch_size: int, 28 | model_queries_per_batch: int, 29 | alphabet: str, 30 | mu: float = 1, 31 | elitist: bool = False, 32 | seed: Optional[int] = None, 33 | log_file: Optional[str] = None, 34 | ): 35 | """ 36 | Create a random search explorer. 37 | 38 | Args: 39 | mu: Average number of residue mutations from parent for generated sequences. 40 | elitist: If true, will propose the top `sequences_batch_size` sequences 41 | generated according to `model`. If false, randomly proposes 42 | `sequences_batch_size` sequences without taking model score into 43 | account (true random search). 44 | seed: Integer seed for random number generator. 
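        Example (a sketch; assumes an already-constructed `flexs.Model` named `model`
        and a measured-sequences DataFrame with a "sequence" column):

            explorer = Random(
                model, rounds=10, starting_sequence="ATGCATGC",
                sequences_batch_size=16, model_queries_per_batch=200,
                alphabet="ATGC", mu=1,
            )
            seqs, preds = explorer.propose_sequences(measured_df)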
45 | 46 | """ 47 | name = f"Random_mu={mu}" 48 | 49 | super().__init__( 50 | model, 51 | name, 52 | rounds, 53 | sequences_batch_size, 54 | model_queries_per_batch, 55 | starting_sequence, 56 | log_file, 57 | ) 58 | self.mu = mu 59 | self.rng = np.random.default_rng(seed) 60 | self.alphabet = alphabet 61 | self.elitist = elitist 62 | 63 | def propose_sequences( 64 | self, measured_sequences: pd.DataFrame 65 | ) -> Tuple[np.ndarray, np.ndarray]: 66 | """Propose top `sequences_batch_size` sequences for evaluation.""" 67 | old_sequences = measured_sequences["sequence"] 68 | old_sequence_set = set(old_sequences) 69 | new_seqs = set() 70 | 71 | while len(new_seqs) <= self.model_queries_per_batch: 72 | seq = self.rng.choice(old_sequences) 73 | new_seq = s_utils.generate_random_mutant( 74 | seq, self.mu / len(seq), alphabet=self.alphabet 75 | ) 76 | 77 | if new_seq not in old_sequence_set: 78 | new_seqs.add(new_seq) 79 | 80 | new_seqs = np.array(list(new_seqs)) 81 | preds = self.model.get_fitness(new_seqs) 82 | 83 | if self.elitist: 84 | idxs = np.argsort(preds)[: -self.sequences_batch_size : -1] 85 | else: 86 | idxs = self.rng.integers(0, len(new_seqs), size=self.sequences_batch_size) 87 | 88 | return new_seqs[idxs], preds[idxs] 89 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/type_aliases.py: -------------------------------------------------------------------------------- 1 | """Common aliases for type hints""" 2 | 3 | from enum import Enum 4 | from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union 5 | 6 | import gym 7 | import numpy as np 8 | import torch as th 9 | 10 | from stable_baselines3.common import callbacks, vec_env 11 | 12 | GymEnv = Union[gym.Env, vec_env.VecEnv] 13 | GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] 14 | GymStepReturn = Tuple[GymObs, float, bool, Dict] 15 | TensorDict = Dict[Union[str, int], th.Tensor] 16 | OptimizerStateDict = Dict[str, Any] 17 | MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] 18 | 19 | # A schedule takes the remaining progress as input 20 | # and outputs a scalar (e.g. learning rate, clip range, ...) 
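# For example, a linear decay from an initial value (a sketch mirroring the
# pattern used in the SB3 docs; `progress_remaining` goes from 1 at the start
# of training to 0 at the end):
#   def linear_schedule(initial_value: float) -> Schedule:
#       return lambda progress_remaining: progress_remaining * initial_value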
21 | Schedule = Callable[[float], float] 22 | 23 | 24 | class RolloutBufferSamples(NamedTuple): 25 | observations: th.Tensor 26 | actions: th.Tensor 27 | old_values: th.Tensor 28 | old_log_prob: th.Tensor 29 | advantages: th.Tensor 30 | returns: th.Tensor 31 | 32 | 33 | class RolloutBufferForProteinSamples(NamedTuple): 34 | observations: th.Tensor 35 | actions: th.Tensor 36 | observations_2: th.Tensor # Concatenate 'observations_1' with 'actions_1' (Additional property compared with RolloutBufferSamples) 37 | actions_2: th.Tensor # Second subaction (Additional property compared with RolloutBufferSamples) 38 | old_values: th.Tensor 39 | old_values_2: th.Tensor # Value function based on 'observations_2' (Additional property compared with RolloutBufferSamples) 40 | old_log_prob: th.Tensor 41 | old_log_prob_2: th.Tensor # Log probability based on 'observations_2' (Additional property compared with RolloutBufferSamples) 42 | advantages: th.Tensor 43 | advantages_2: th.Tensor # Advantages based on 'old_values_2' (Additional property compared with RolloutBufferSamples) 44 | returns: th.Tensor 45 | returns_2: th.Tensor # Returns based on 'old_values_2' and 'advantages_2' (Additional property compared with RolloutBufferSamples) 46 | 47 | 48 | class DictRolloutBufferSamples(RolloutBufferSamples): 49 | observations: TensorDict 50 | actions: th.Tensor 51 | old_values: th.Tensor 52 | old_log_prob: th.Tensor 53 | advantages: th.Tensor 54 | returns: th.Tensor 55 | 56 | 57 | class ReplayBufferSamples(NamedTuple): 58 | observations: th.Tensor 59 | actions: th.Tensor 60 | next_observations: th.Tensor 61 | dones: th.Tensor 62 | rewards: th.Tensor 63 | 64 | 65 | class DictReplayBufferSamples(ReplayBufferSamples): 66 | observations: TensorDict 67 | actions: th.Tensor 68 | next_observations: TensorDict 69 | dones: th.Tensor 70 | rewards: th.Tensor 71 | 72 | 73 | class RolloutReturn(NamedTuple): 74 | episode_timesteps: int 75 | n_episodes: int 76 | continue_training: bool 77 | 78 | 79 | class TrainFrequencyUnit(Enum): 80 | STEP = "step" 81 | EPISODE = "episode" 82 | 83 | 84 | class TrainFreq(NamedTuple): 85 | frequency: int 86 | unit: TrainFrequencyUnit # either "step" or "episode" 87 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/models/sklearn_models.py: -------------------------------------------------------------------------------- 1 | """Define scikit-learn model wrappers as well as a few convenient pre-wrapped models.""" 2 | import abc 3 | 4 | import numpy as np 5 | import sklearn.ensemble 6 | import sklearn.linear_model 7 | 8 | import flexs 9 | from flexs.utils import sequence_utils as s_utils 10 | 11 | 12 | class SklearnModel(flexs.Model, abc.ABC): 13 | """Base sklearn model wrapper.""" 14 | 15 | def __init__(self, model, alphabet, name): 16 | """ 17 | Args: 18 | model: sklearn model to wrap. 19 | alphabet: Alphabet string. 20 | name: Human-readable short model description (for logging). 
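        Example (a minimal sketch using the `LinearRegression` wrapper defined
        below; toy DNA data for illustration):

            model = LinearRegression(alphabet="ACGT")
            model.train(["ACGT", "AGGT", "ACCT"], [0.2, 0.8, 0.5])
            scores = model.get_fitness(["ACTT"])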
21 | 22 | """ 23 | super().__init__(name) 24 | 25 | self.model = model 26 | self.alphabet = alphabet 27 | 28 | def train(self, sequences, labels): 29 | """Flatten one-hot sequences and train model using `model.fit`.""" 30 | one_hots = np.array( 31 | [s_utils.string_to_one_hot(seq, self.alphabet) for seq in sequences] 32 | ) 33 | flattened = one_hots.reshape( 34 | one_hots.shape[0], one_hots.shape[1] * one_hots.shape[2] 35 | ) 36 | self.model.fit(flattened, labels) 37 | 38 | 39 | class SklearnRegressor(SklearnModel, abc.ABC): 40 | """Class for sklearn regressors (uses `model.predict`).""" 41 | 42 | def _fitness_function(self, sequences): 43 | one_hots = np.array( 44 | [s_utils.string_to_one_hot(seq, self.alphabet) for seq in sequences] 45 | ) 46 | flattened = one_hots.reshape( 47 | one_hots.shape[0], one_hots.shape[1] * one_hots.shape[2] 48 | ) 49 | 50 | return self.model.predict(flattened) 51 | 52 | 53 | class SklearnClassifier(SklearnModel, abc.ABC): 54 | """Class for sklearn classifiers (uses `model.predict_proba(...)[:, 1]`).""" 55 | 56 | def _fitness_function(self, sequences): 57 | one_hots = np.array( 58 | [s_utils.string_to_one_hot(seq, self.alphabet) for seq in sequences] 59 | ) 60 | flattened = one_hots.reshape( 61 | one_hots.shape[0], one_hots.shape[1] * one_hots.shape[2] 62 | ) 63 | 64 | return self.model.predict_proba(flattened)[:, 1] 65 | 66 | 67 | class LinearRegression(SklearnRegressor): 68 | """Sklearn linear regression.""" 69 | 70 | def __init__(self, alphabet, **kwargs): 71 | """Create linear regression model.""" 72 | model = sklearn.linear_model.LinearRegression(**kwargs) 73 | super().__init__(model, alphabet, "linear_regression") 74 | 75 | 76 | class LogisticRegression(SklearnRegressor): 77 | """Sklearn logistic regression.""" 78 | 79 | def __init__(self, alphabet, **kwargs): 80 | """Create logistic regression model.""" 81 | model = sklearn.linear_model.LogisticRegression(**kwargs) 82 | super().__init__(model, alphabet, "logistic_regression") 83 | 84 | 85 | class RandomForest(SklearnRegressor): 86 | """Sklearn random forest regressor.""" 87 | 88 | def __init__(self, alphabet, **kwargs): 89 | """Create random forest regressor.""" 90 | model = sklearn.ensemble.RandomForestRegressor(**kwargs) 91 | super().__init__(model, alphabet, "random_forest") 92 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/tf_binding.py: -------------------------------------------------------------------------------- 1 | """Define TFBinding landscape and problem registry.""" 2 | import os 3 | from typing import Dict 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | import flexs 9 | from flexs.types import SEQUENCES_TYPE 10 | 11 | 12 | class TFBinding(flexs.Landscape): 13 | """ 14 | A landscape of binding affinity of proposed 8-mer DNA sequences to a 15 | particular transcription factor. 16 | 17 | We use experimental data from Barrera et al. (2016), a survey of the binding 18 | affinity of more than one hundred and fifty transcription factors (TF) to all 19 | possible DNA sequences of length 8. 20 | """ 21 | 22 | def __init__(self, landscape_file: str): 23 | """ 24 | Create a TFBinding landscape from experimental data .csv file. 25 | 26 | See https://github.com/samsinai/FLSD-Sandbox/tree/stewy-redesign/flexs/landscapes/data/tf_binding # noqa: E501 27 | for examples. 
28 | """ 29 | super().__init__(name="TF_Binding") 30 | 31 | # Load TF pairwise TF binding measurements from file 32 | data = pd.read_csv(landscape_file, sep="\t") 33 | score = data["E-score"] # "E-score" is enrichment score 34 | norm_score = (score - score.min()) / (score.max() - score.min()) 35 | 36 | # The csv file keeps one DNA strand's sequence in "8-mer" and the other in 37 | # "8-mer.1". 38 | # Since it doesn't really matter which strand we have, we will map the sequences 39 | # of both strands to the same normalized enrichment score. 40 | self.sequences = dict(zip(data["8-mer"], norm_score)) 41 | self.sequences.update(zip(data["8-mer.1"], norm_score)) 42 | 43 | def _fitness_function(self, sequences: SEQUENCES_TYPE) -> np.ndarray: 44 | return np.array([self.sequences[seq] for seq in sequences]) 45 | 46 | 47 | def registry() -> Dict[str, Dict]: 48 | """ 49 | Return a dictionary of problems of the form: 50 | 51 | ```python 52 | { 53 | "problem name": { 54 | "params": ..., 55 | }, 56 | ... 57 | } 58 | ``` 59 | 60 | where `flexs.landscapes.TFBinding(**problem["params"])` instantiates the 61 | transcription factor binding landscape for the given set of parameters. 62 | 63 | Returns: 64 | Problems in the registry. 65 | 66 | """ 67 | tf_binding_data_dir = os.path.join(os.path.dirname(__file__), "data/tf_binding") 68 | 69 | problems = {} 70 | for fname in os.listdir(tf_binding_data_dir): 71 | problem_name = fname.replace("_8mers.txt", "") 72 | 73 | problems[problem_name] = { 74 | "params": {"landscape_file": os.path.join(tf_binding_data_dir, fname)}, 75 | "starts": [ 76 | "GCTCGAGC", 77 | "GCGCGCGC", 78 | "TGCGCGCC", 79 | "ATATAGCC", 80 | "GTTTGGTA", 81 | "ATTATGTT", 82 | "CAGTTTTT", 83 | "AAAAATTT", 84 | "AAAAACGC", 85 | "GTTGTTTT", 86 | "TGCTTTTT", 87 | "AAAGATAG", 88 | "CCTTCTTT", 89 | "AAAGAGAG", 90 | ], 91 | } 92 | 93 | return problems 94 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/vec_env/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for dealing with vectorized environments. 3 | """ 4 | from collections import OrderedDict 5 | from typing import Any, Dict, List, Tuple 6 | 7 | import gym 8 | import numpy as np 9 | 10 | from stable_baselines3.common.preprocessing import check_for_nested_spaces 11 | from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs 12 | 13 | 14 | def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: 15 | """ 16 | Deep-copy a dict of numpy arrays. 17 | 18 | :param obs: a dict of numpy arrays. 19 | :return: a dict of copied numpy arrays. 20 | """ 21 | assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" 22 | return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) 23 | 24 | 25 | def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: 26 | """ 27 | Convert an internal representation raw_obs into the appropriate type 28 | specified by space. 29 | 30 | :param obs_space: an observation space. 31 | :param obs_dict: a dict of numpy arrays. 32 | :return: returns an observation of the same type as space. 33 | If space is Dict, function is identity; if space is Tuple, converts dict to Tuple; 34 | otherwise, space is unstructured and returns the value raw_obs[None]. 
35 | """ 36 | if isinstance(obs_space, gym.spaces.Dict): 37 | return obs_dict 38 | elif isinstance(obs_space, gym.spaces.Tuple): 39 | assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space" 40 | return tuple(obs_dict[i] for i in range(len(obs_space.spaces))) 41 | else: 42 | assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" 43 | return obs_dict[None] 44 | 45 | 46 | def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]: 47 | """ 48 | Get dict-structured information about a gym.Space. 49 | 50 | Dict spaces are represented directly by their dict of subspaces. 51 | Tuple spaces are converted into a dict with keys indexing into the tuple. 52 | Unstructured spaces are represented by {None: obs_space}. 53 | 54 | :param obs_space: an observation space 55 | :return: A tuple (keys, shapes, dtypes): 56 | keys: a list of dict keys. 57 | shapes: a dict mapping keys to shapes. 58 | dtypes: a dict mapping keys to dtypes. 59 | """ 60 | check_for_nested_spaces(obs_space) 61 | if isinstance(obs_space, gym.spaces.Dict): 62 | assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" 63 | subspaces = obs_space.spaces 64 | elif isinstance(obs_space, gym.spaces.Tuple): 65 | subspaces = {i: space for i, space in enumerate(obs_space.spaces)} 66 | else: 67 | assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" 68 | subspaces = {None: obs_space} 69 | keys = [] 70 | shapes = {} 71 | dtypes = {} 72 | for key, box in subspaces.items(): 73 | keys.append(key) 74 | shapes[key] = box.shape 75 | dtypes[key] = box.dtype 76 | return keys, shapes, dtypes 77 | -------------------------------------------------------------------------------- /mu-former/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | The folder primarily hosts the code for μFormer, or Mu-Former, uFormer, Muformer for readability, a potent tool tailored for predicting the effects of protein mutations. 3 | 4 | 5 | # Environment 6 | 7 | To ensure optimal functioning of the uFormer application, a specific Conda environment should be set up. We've tested this setup using CUDA Version 12.2. Follow the steps below to set up the Conda environment: 8 | ``` 9 | conda create -n mutation python==3.8 10 | conda activate mutation 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | Additionally, you need to install PyTorch. The version to be installed is dependent on your GPU driver version. For instance: 15 | ``` 16 | pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 17 | ``` 18 | Or for a cpu-only version: 19 | ``` 20 | pip install torch==1.12.0+cpu torchvision==0.13.0+cpu torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu 21 | ``` 22 | 23 | # Getting Started 24 | 25 | The pre-trained encoder and sample datasets for fine-tuning are publicly accessible on [figshare](https://doi.org/10.6084/m9.figshare.26892355). Begin by downloading the checkpoint and sample datasets to your local storage. Subsequently, you can follow the provided command lines to fine-tune the model for your data, using the pre-prepared encoder. 
26 | 27 | If you're using a single GPU card, you can run the application using the following command (angle-bracketed values are placeholders for your own paths): 28 | 29 | ``` 30 | python main.py --decoder-name siamese --encoder-lr 1e-6 --decoder-lr 1e-4 \ 31 | --epochs 300 --warmup-epochs 10 --batch-size 8 \ 32 | --pretrained-model <path_to_pretrained_encoder> \ 33 | --fasta <path_to_fasta_file> \ 34 | --train <path_to_train_tsv> \ 35 | --valid <path_to_valid_tsv> \ 36 | --test <path_to_test_tsv> \ 37 | --output-dir <path_to_output_dir> 38 | ``` 39 | 40 | In case you're running the program on a node with multiple GPU cards (4, for example), the command can be adjusted as follows (note the per-GPU `--batch-size` is reduced to 2, keeping the effective batch size at 8): 41 | ``` 42 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=4 --node_rank=0 --master_port=6005 \ 43 | main.py --decoder-name siamese --encoder-lr 1e-6 --decoder-lr 1e-4 --batch-size 2 \ 44 | --epochs 300 --warmup-epochs 10 \ 45 | --pretrained-model <path_to_pretrained_encoder> \ 46 | --fasta <path_to_fasta_file> \ 47 | --train <path_to_train_tsv> \ 48 | --valid <path_to_valid_tsv> \ 49 | --test <path_to_test_tsv> \ 50 | --output-dir <path_to_output_dir> 51 | ``` 52 | 53 | ### Running an Example 54 | 55 | After downloading the files from **[Figshare](https://doi.org/10.6084/m9.figshare.26892355)**, place the checkpoint file in a subfolder named **`ckpt/`** within this directory. 56 | 57 | You can then run the following command to test the setup: 58 | 59 | ``` 60 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=4 --node_rank=0 --master_port=6005 \ 61 | main.py --decoder-name siamese --encoder-lr 1e-6 --decoder-lr 1e-4 --batch-size 2 \ 62 | --epochs 300 --warmup-epochs 10 \ 63 | --pretrained-model ckpt/uformer_l_encoder.pt \ 64 | --fasta data/example/IF1_ECOLI_Kelsic_2016.fasta \ 65 | --train data/example/IF1_ECOLI_Kelsic_2016_train.tsv \ 66 | --valid data/example/IF1_ECOLI_Kelsic_2016_valid.tsv \ 67 | --test data/example/IF1_ECOLI_Kelsic_2016_test.tsv \ 68 | --output-dir output/ 69 | ``` 70 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/models/cnn.py: -------------------------------------------------------------------------------- 1 | """Define a baseline CNN Model.""" 2 | import tensorflow as tf 3 | 4 | from .
import keras_model
 5 | import numpy as np
 6 | import flexs.utils.sequence_utils as s_utils
 7 | import torch
 8 | 
 9 | 
10 | class CNN(keras_model.KerasModel):
11 |     """A baseline CNN model with 3 conv layers and 2 hidden dense layers."""
12 | 
13 |     def __init__(
14 |         self,
15 |         seq_len: int,
16 |         num_filters: int,
17 |         hidden_size: int,
18 |         alphabet: str,
19 |         loss="MSE",
20 |         name: str = None,
21 |         batch_size: int = 256,
22 |         epochs: int = 20,
23 |     ):
24 |         """Create the CNN."""
25 |         model = tf.keras.models.Sequential(
26 |             [
27 |                 tf.keras.layers.Conv1D(
28 |                     num_filters,
29 |                     len(alphabet) - 1,
30 |                     padding="valid",
31 |                     strides=1,
32 |                     input_shape=(seq_len, len(alphabet)),
33 |                 ),
34 |                 tf.keras.layers.Conv1D(
35 |                     num_filters, 20, padding="same", activation="relu", strides=1
36 |                 ),
37 |                 tf.keras.layers.MaxPooling1D(1),
38 |                 tf.keras.layers.Conv1D(
39 |                     num_filters,
40 |                     len(alphabet) - 1,
41 |                     padding="same",
42 |                     activation="relu",
43 |                     strides=1,
44 |                 ),
45 |                 tf.keras.layers.GlobalMaxPooling1D(),
46 |                 tf.keras.layers.Dense(hidden_size, activation="relu"),
47 |                 tf.keras.layers.Dense(hidden_size, activation="relu"),
48 |                 tf.keras.layers.Dropout(0.25),
49 |                 tf.keras.layers.Dense(1),
50 |             ]
51 |         )
52 | 
53 |         model.compile(loss=loss, optimizer="adam", metrics=["mse"])
54 | 
55 |         if name is None:
56 |             name = f"CNN_hidden_size_{hidden_size}_num_filters_{num_filters}"
57 | 
58 |         super().__init__(
59 |             model,
60 |             alphabet=alphabet,
61 |             name=name,
62 |             batch_size=batch_size,
63 |             epochs=epochs,
64 |         )
65 | 
66 |     # @tf.function
67 |     def gradient_function(self, sequences):
68 |         one_hots = tf.convert_to_tensor(
69 |             np.array(
70 |                 [s_utils.string_to_one_hot(seq, self.alphabet) for seq in sequences]
71 |             ),
72 |             dtype=tf.float32,
73 |         )
74 | 
75 |         # Use GradientTape to track operations for gradient computation
76 |         with tf.GradientTape() as tape:
77 |             # Tell TensorFlow to watch this tensor for gradient computation
78 |             tape.watch(one_hots)
79 |             # Get predictions from the model
80 |             predictions = self.model(one_hots, training=False)
81 |             predictions = tf.squeeze(predictions, axis=1)
82 | 
83 |         # Gradient of the predictions w.r.t. the one-hot inputs
84 |         gradients = tape.gradient(predictions, one_hots)
85 |         # Gradient of the character currently present at each position
86 |         gradients_cur = tf.reduce_sum(gradients * one_hots, axis=-1, keepdims=True)
87 |         # First-order estimate of the fitness change from substituting character j at position i
88 |         delta_ij = gradients - gradients_cur
89 |         # Convert to torch tensors only after all TensorFlow ops are done;
90 |         # mixing torch tensors into tf ops (as before) breaks at runtime
91 |         gradients = torch.tensor(gradients.numpy())
92 |         delta_ij = torch.tensor(delta_ij.numpy())
93 | 
94 |         return gradients, delta_ij
--------------------------------------------------------------------------------
/mu-search/src/flexs/flexs/baselines/explorers/environments/env.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from functools import reduce
 3 | from gym import Env
 4 | from gym.spaces import MultiDiscrete
 5 | 
 6 | 
 7 | class LandscapeEnv(Env):
 8 |     """RL environment for protein optimization.
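    Observations are integer-encoded characters: the starting (wild-type) sequence
    concatenated with the current working sequence, so an observation has length
    2 * len(starting_seq). An action is a (position, new_character) index pair; the
    model's fitness for the final sequence is returned as the reward once `horizon`
    mutations have been applied.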
""" 9 | def __init__(self, model, alphabet, horizon, starting_seq, for_rna=True) -> None: 10 | super().__init__() 11 | self.model = model 12 | self.horizon = horizon 13 | self.starting_seq = starting_seq 14 | self.alphabet = alphabet 15 | self.char2int = {} 16 | for i, a in enumerate(alphabet): 17 | self.char2int[a] = i 18 | self.array2str = lambda arr: reduce(lambda x,y:x+y, map(lambda x:self.alphabet[x], arr)) 19 | self.str2array = lambda s: np.array(list(map(lambda x: self.char2int[x], s))) 20 | 21 | num_alphabets = len(alphabet) 22 | sequence_length = len(self.starting_seq) 23 | 24 | observation_space_list = [num_alphabets] * 2 * sequence_length 25 | action_space_list = [sequence_length, num_alphabets] 26 | observation_space1_list = observation_space_list 27 | observation_space2_list = observation_space1_list + [action_space_list[0]] 28 | action_space1_list = action_space_list[:1] 29 | action_space2_list = action_space_list[1:] 30 | 31 | self.observation_space = MultiDiscrete(observation_space_list) 32 | self.action_space = MultiDiscrete(action_space_list) 33 | self.observation_space1 = MultiDiscrete(observation_space1_list) 34 | self.observation_space2 = MultiDiscrete(observation_space2_list) 35 | self.action_space1 = MultiDiscrete(action_space1_list) 36 | self.action_space2 = MultiDiscrete(action_space2_list) 37 | 38 | self.reset() 39 | 40 | def step(self, action): 41 | loc = action[0] 42 | mutate_to = action[1] 43 | self.cur_mutations += 1 44 | self.cur_seq = self.cur_seq[:loc] + self.alphabet[mutate_to] + self.cur_seq[loc+1:] 45 | 46 | done = self.cur_mutations == self.horizon 47 | # print(f"wild type: {self.starting_seq}") 48 | # print(f"wild type fitness: {self.model.get_fitness([self.starting_seq])}") 49 | # print(f"self.cur_seq: {self.cur_seq}, fitness: {self.model.get_fitness([self.cur_seq])}, mutations: {self.cur_mutations}, done: {done}") 50 | if done: 51 | # try: 52 | # reward, embedding = self.model.get_fitness([self.cur_seq]) 53 | # return self.obs, reward[0], done, {"current_sequence": self.str2array(self.cur_seq), "embedding": embedding[0], "ensemble_uncertainty": ensemble_uncertainty[0]} 54 | # except: 55 | reward = self.model.get_fitness([self.cur_seq]) 56 | embedding = None 57 | ensemble_uncertainty = None 58 | return self.obs, reward[0], done, {"current_sequence": self.str2array(self.cur_seq), "embedding": None, "ensemble_uncertainty": None} 59 | else: 60 | return self.obs, 0, done, {} 61 | 62 | def reset(self): 63 | self.cur_mutations = 0 64 | self.cur_seq = self.starting_seq 65 | return self.obs 66 | 67 | def change_starting_sequence(self, new_starting_seq): 68 | self.starting_seq = new_starting_seq 69 | self.reset() 70 | 71 | @property 72 | def obs(self): 73 | return self.str2array(self.starting_seq + self.cur_seq) 74 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/vec_env/vec_check_nan.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | 5 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper 6 | 7 | 8 | class VecCheckNan(VecEnvWrapper): 9 | """ 10 | NaN and inf checking wrapper for vectorized environment, will raise a warning by default, 11 | allowing you to know from what the NaN of inf originated from. 
12 | 13 | :param venv: the vectorized environment to wrap 14 | :param raise_exception: Whether or not to raise a ValueError, instead of a UserWarning 15 | :param warn_once: Whether or not to only warn once. 16 | :param check_inf: Whether or not to check for +inf or -inf as well 17 | """ 18 | 19 | def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True): 20 | VecEnvWrapper.__init__(self, venv) 21 | self.raise_exception = raise_exception 22 | self.warn_once = warn_once 23 | self.check_inf = check_inf 24 | self._actions = None 25 | self._observations = None 26 | self._user_warned = False 27 | 28 | def step_async(self, actions: np.ndarray) -> None: 29 | self._check_val(async_step=True, actions=actions) 30 | 31 | self._actions = actions 32 | self.venv.step_async(actions) 33 | 34 | def step_wait(self) -> VecEnvStepReturn: 35 | observations, rewards, news, infos = self.venv.step_wait() 36 | 37 | self._check_val(async_step=False, observations=observations, rewards=rewards, news=news) 38 | 39 | self._observations = observations 40 | return observations, rewards, news, infos 41 | 42 | def reset(self) -> VecEnvObs: 43 | observations = self.venv.reset() 44 | self._actions = None 45 | 46 | self._check_val(async_step=False, observations=observations) 47 | 48 | self._observations = observations 49 | return observations 50 | 51 | def _check_val(self, *, async_step: bool, **kwargs) -> None: 52 | # if warn and warn once and have warned once: then stop checking 53 | if not self.raise_exception and self.warn_once and self._user_warned: 54 | return 55 | 56 | found = [] 57 | for name, val in kwargs.items(): 58 | has_nan = np.any(np.isnan(val)) 59 | has_inf = self.check_inf and np.any(np.isinf(val)) 60 | if has_inf: 61 | found.append((name, "inf")) 62 | if has_nan: 63 | found.append((name, "nan")) 64 | 65 | if found: 66 | self._user_warned = True 67 | msg = "" 68 | for i, (name, type_val) in enumerate(found): 69 | msg += f"found {type_val} in {name}" 70 | if i != len(found) - 1: 71 | msg += ", " 72 | 73 | msg += ".\r\nOriginated from the " 74 | 75 | if not async_step: 76 | if self._actions is None: 77 | msg += "environment observation (at reset)" 78 | else: 79 | msg += f"environment, Last given value was: \r\n\taction={self._actions}" 80 | else: 81 | msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}" 82 | 83 | if self.raise_exception: 84 | raise ValueError(msg) 85 | else: 86 | warnings.warn(msg, UserWarning) 87 | -------------------------------------------------------------------------------- /mu-former/src/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import torch 4 | import pathlib 5 | import numpy as np 6 | 7 | import os 8 | 9 | class Logger(object): 10 | def __init__(self, logfile=None, level=logging.INFO): 11 | ''' 12 | logfile: pathlib object 13 | ''' 14 | self.logger = logging.getLogger() 15 | self.logger.setLevel(level) 16 | formatter = logging.Formatter("[%(asctime)s] %(message)s", "%Y-%m-%d %H:%M:%S") 17 | 18 | for hd in self.logger.handlers[:]: 19 | self.logger.removeHandler(hd) 20 | 21 | sh = logging.StreamHandler(sys.stdout) 22 | sh.setFormatter(formatter) 23 | self.logger.addHandler(sh) 24 | 25 | if logfile is not None: 26 | logfile.parent.mkdir(exist_ok=True, parents=True) 27 | fh = logging.FileHandler(logfile, 'w') 28 | fh.setFormatter(formatter) 29 | self.logger.addHandler(fh) 30 | 31 | if 'RANK' in os.environ and 'WORLD_SIZE' in 
os.environ: 32 | local_rank = int(os.environ["RANK"]) 33 | if local_rank == 0: 34 | self.should_log = True 35 | else: 36 | self.should_log = False 37 | else: 38 | self.should_log = True 39 | 40 | def debug(self, msg): 41 | if self.should_log: 42 | self.logger.debug(msg) 43 | 44 | def info(self, msg): 45 | if self.should_log: 46 | self.logger.info(msg) 47 | 48 | def warning(self, msg): 49 | self.logger.warning(msg) 50 | 51 | def error(self, msg): 52 | self.logger.error(msg) 53 | 54 | 55 | class Saver(object): 56 | def __init__(self, output_dir): 57 | self.save_dir = pathlib.Path(output_dir) 58 | 59 | def save_ckp(self, pt, filename='checkpoint.pt'): 60 | self.save_dir.mkdir(exist_ok=True, parents=True) 61 | # _use_new_zipfile_serialization=True 62 | torch.save(pt, str(self.save_dir/filename)) 63 | 64 | def save_df(self, df, filename): 65 | self.save_dir.mkdir(exist_ok=True, parents=True) 66 | df.to_csv(self.save_dir/filename, float_format='%.8f', index=False, sep='\t') 67 | 68 | 69 | class EarlyStopping(object): 70 | def __init__(self, 71 | patience=100, eval_freq=1, best_score=None, 72 | delta=1e-9, higher_better=True): 73 | self.patience = patience 74 | self.eval_freq = eval_freq 75 | self.best_score = best_score 76 | self.delta = delta 77 | self.higher_better = higher_better 78 | self.counter = 0 79 | self.early_stop = False 80 | 81 | def not_improved(self, val_score): 82 | if np.isnan(val_score): 83 | return True 84 | if self.higher_better: 85 | return val_score < self.best_score + self.delta 86 | else: 87 | return val_score > self.best_score - self.delta 88 | 89 | def update(self, val_score): 90 | if self.best_score is None: 91 | self.best_score = val_score 92 | is_best = True 93 | elif self.not_improved(val_score): 94 | self.counter += self.eval_freq 95 | if (self.patience is not None) and (self.counter > self.patience): 96 | self.early_stop = True 97 | is_best = False 98 | else: 99 | self.best_score = val_score 100 | self.counter = 0 101 | is_best = True 102 | return is_best 103 | -------------------------------------------------------------------------------- /mu-search/tests/test_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import sklearn 4 | 5 | import flexs 6 | from flexs import baselines 7 | 8 | rng = np.random.default_rng() 9 | 10 | 11 | class FakeModel(flexs.Model): 12 | def _fitness_function(self, sequences): 13 | return rng.random(size=len(sequences)) 14 | 15 | def train(self, *args, **kwargs): 16 | pass 17 | 18 | 19 | class FakeLandscape(flexs.Landscape): 20 | def _fitness_function(self, sequences): 21 | return rng.random(size=len(sequences)) 22 | 23 | 24 | class FakeConstantModel(flexs.Model): 25 | def __init__(self, constant): 26 | super().__init__(name="ConstantModel") 27 | self.constant = constant 28 | 29 | def _fitness_function(self, sequences): 30 | return np.ones(len(sequences)) * self.constant 31 | 32 | def train(self, *args, **kwargs): 33 | pass 34 | 35 | 36 | def test_adaptive_ensemble(): 37 | models = [FakeConstantModel(1), FakeConstantModel(2)] 38 | ens = baselines.models.AdaptiveEnsemble(models) 39 | 40 | assert np.sum(ens.weights) == 1 41 | 42 | assert ens.get_fitness(["ATC"]) == 1.5 43 | 44 | models = [FakeModel(name="FakeModel") for _ in range(2)] 45 | ens = baselines.models.AdaptiveEnsemble(models) 46 | 47 | ens.train(["ATC"] * 15, list(range(15))) 48 | 49 | print(ens.weights) 50 | assert np.any(ens.weights != np.ones(len(models)) / len(models)) 51 | # Possible 
floating-point error from summation 52 | assert np.isclose(np.sum(ens.weights), 1) 53 | 54 | 55 | def test_keras_models(): 56 | cnn = baselines.models.CNN( 57 | seq_len=3, 58 | num_filters=1, 59 | hidden_size=1, 60 | kernel_size=2, 61 | alphabet=flexs.utils.sequence_utils.DNAA, 62 | ) 63 | cnn.get_fitness(["ATC"]) 64 | 65 | gem = baselines.models.GlobalEpistasisModel( 66 | seq_len=3, 67 | hidden_size=1, 68 | alphabet=flexs.utils.sequence_utils.DNAA, 69 | ) 70 | gem.get_fitness(["ATC"]) 71 | 72 | mlp = baselines.models.MLP( 73 | seq_len=3, 74 | hidden_size=1, 75 | alphabet=flexs.utils.sequence_utils.DNAA, 76 | ) 77 | mlp.get_fitness(["ATC"]) 78 | 79 | 80 | def test_noisy_abstract_model(): 81 | nam = baselines.models.NoisyAbstractModel( 82 | landscape=FakeLandscape(name="FakeLandscape") 83 | ) 84 | assert len(nam.cache) == 0 85 | fitness = nam.get_fitness(["ATC"]) 86 | assert len(nam.cache) == 1 87 | assert nam.get_fitness(["ATC"]) == fitness 88 | 89 | nam = baselines.models.NoisyAbstractModel( 90 | landscape=FakeConstantModel(2), signal_strength=1 91 | ) 92 | assert nam.get_fitness(["ATC"]) == [2] 93 | 94 | nam = baselines.models.NoisyAbstractModel( 95 | landscape=FakeConstantModel(2), signal_strength=0 96 | ) 97 | nam.get_fitness(["ATC"]) 98 | # Flaky, but extremely unlikely to fail 99 | assert nam.get_fitness(["ATG"]) != [2] 100 | 101 | 102 | def test_sklearn_models(): 103 | sklearn_models = [ 104 | baselines.models.LinearRegression, 105 | baselines.models.LogisticRegression, 106 | baselines.models.RandomForest, 107 | ] 108 | for model in sklearn_models: 109 | m = model( 110 | alphabet=flexs.utils.sequence_utils.DNAA, 111 | ) 112 | with pytest.raises(sklearn.exceptions.NotFittedError): 113 | m.get_fitness(["ATC"]) 114 | m.train(["ATC", "ATG"], [1, 2]) 115 | m.get_fitness(["ATC"]) 116 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/src/vocab.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from io import StringIO 3 | import pandas as pd 4 | 5 | ''' 6 | Amino acide encoding modified from 7 | https://github.com/openvax/mhcflurry/blob/74b751e6d72605eef4a49641d364066193541b5a/mhcflurry/amino_acid.py 8 | ''' 9 | COMMON_AMINO_ACIDS_INDEX = collections.OrderedDict( 10 | {'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 11 | 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 12 | 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 13 | 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19, '-': 20}) 14 | AMINO_ACIDS = list(COMMON_AMINO_ACIDS_INDEX.keys()) 15 | 16 | AMINO_ACID_INDEX = collections.OrderedDict( 17 | {'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 18 | 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 19 | 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 20 | 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19, 21 | 'X': 20, 'Z': 20, 'B': 20, 'J': 20, '-': 20}) 22 | 23 | ''' 24 | CCMPred index of amino acid 25 | https://github.com/soedinglab/CCMpred/blob/2b2f9a0747a5e53035c33636d430f2f11dc186dd/src/sequence.c 26 | ''' 27 | CCMPRED_AMINO_ACID_INDEX1 = collections.OrderedDict( 28 | {'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 29 | 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9, 30 | 'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14, 31 | 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, '-': 20}) 32 | CCMPRED_AMINO_ACID_INDEX = collections.OrderedDict( 33 | {'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 34 | 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9, 35 | 'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14, 36 | 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, 
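    # Ambiguous or rare residue codes below share index 20 with the gap character: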
'B': 20, 37 | 'J': 20, 'O': 20, 'U': 20, 'X': 20, 'Z': 20, '-': 20, '_': 20}) 38 | CCMPRED_AMINO_ACIDS = list(CCMPRED_AMINO_ACID_INDEX.keys()) 39 | 40 | BLOSUM62_MATRIX = pd.read_csv(StringIO(""" 41 | A R N D C Q E G H I L K M F P S T W Y V - 42 | A 4 -1 -2 -2 0 -1 -1 0 -2 -1 -1 -1 -1 -2 -1 1 0 -3 -2 0 0 43 | R -1 5 0 -2 -3 1 0 -2 0 -3 -2 2 -1 -3 -2 -1 -1 -3 -2 -3 0 44 | N -2 0 6 1 -3 0 0 0 1 -3 -3 0 -2 -3 -2 1 0 -4 -2 -3 0 45 | D -2 -2 1 6 -3 0 2 -1 -1 -3 -4 -1 -3 -3 -1 0 -1 -4 -3 -3 0 46 | C 0 -3 -3 -3 9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1 0 47 | Q -1 1 0 0 -3 5 2 -2 0 -3 -2 1 0 -3 -1 0 -1 -2 -1 -2 0 48 | E -1 0 0 2 -4 2 5 -2 0 -3 -3 1 -2 -3 -1 0 -1 -3 -2 -2 0 49 | G 0 -2 0 -1 -3 -2 -2 6 -2 -4 -4 -2 -3 -3 -2 0 -2 -2 -3 -3 0 50 | H -2 0 1 -1 -3 0 0 -2 8 -3 -3 -1 -2 -1 -2 -1 -2 -2 2 -3 0 51 | I -1 -3 -3 -3 -1 -3 -3 -4 -3 4 2 -3 1 0 -3 -2 -1 -3 -1 3 0 52 | L -1 -2 -3 -4 -1 -2 -3 -4 -3 2 4 -2 2 0 -3 -2 -1 -2 -1 1 0 53 | K -1 2 0 -1 -3 1 1 -2 -1 -3 -2 5 -1 -3 -1 0 -1 -3 -2 -2 0 54 | M -1 -1 -2 -3 -1 0 -2 -3 -2 1 2 -1 5 0 -2 -1 -1 -1 -1 1 0 55 | F -2 -3 -3 -3 -2 -3 -3 -3 -1 0 0 -3 0 6 -4 -2 -2 1 3 -1 0 56 | P -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4 7 -1 -1 -4 -3 -2 0 57 | S 1 -1 1 0 -1 0 0 0 -1 -2 -2 0 -1 -2 -1 4 1 -3 -2 -2 0 58 | T 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 1 5 -2 -2 0 0 59 | W -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1 1 -4 -3 -2 11 2 -3 0 60 | Y -2 -2 -2 -3 -2 -1 -2 -3 2 -1 -1 -2 -1 3 -3 -2 -2 2 7 -1 0 61 | V 0 -3 -3 -3 -1 -2 -2 -3 -3 3 1 -2 1 -1 -2 -2 0 -3 -1 4 0 62 | - 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 63 | """), sep='\s+').loc[AMINO_ACIDS, AMINO_ACIDS] 64 | 65 | ENCODING_DATA_FRAMES = { 66 | "BLOSUM62": BLOSUM62_MATRIX, 67 | "one-hot": pd.DataFrame([ 68 | [1 if i == j else 0 for i in range(len(AMINO_ACIDS))] 69 | for j in range(len(AMINO_ACIDS)) 70 | ], index=AMINO_ACIDS, columns=AMINO_ACIDS) 71 | } -------------------------------------------------------------------------------- /mu-former/src/vocab.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from io import StringIO 3 | import pandas as pd 4 | 5 | ''' 6 | Amino acide encoding modified from 7 | https://github.com/openvax/mhcflurry/blob/74b751e6d72605eef4a49641d364066193541b5a/mhcflurry/amino_acid.py 8 | ''' 9 | COMMON_AMINO_ACIDS_INDEX = collections.OrderedDict( 10 | { 11 | 'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 12 | 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 13 | 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 14 | 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19, '-': 20 15 | }) 16 | 17 | AMINO_ACIDS = list(COMMON_AMINO_ACIDS_INDEX.keys()) 18 | 19 | AMINO_ACID_INDEX = collections.OrderedDict({ 20 | 'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 21 | 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 22 | 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 23 | 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19, 24 | 'X': 20, 'Z': 20, 'B': 20, 'J': 20, '-': 20 25 | }) 26 | 27 | ''' 28 | CCMPred index of amino acid 29 | https://github.com/soedinglab/CCMpred/blob/2b2f9a0747a5e53035c33636d430f2f11dc186dd/src/sequence.c 30 | ''' 31 | CCMPRED_AMINO_ACID_INDEX1 = collections.OrderedDict({ 32 | 'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 33 | 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9, 34 | 'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14, 35 | 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, '-': 20 36 | }) 37 | 38 | CCMPRED_AMINO_ACID_INDEX = collections.OrderedDict({ 39 | 'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 40 | 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9, 41 | 'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 
14, 42 | 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, 'B': 20, 43 | 'J': 20, 'O': 20, 'U': 20, 'X': 20, 'Z': 20, '-': 20, '_': 20 44 | }) 45 | 46 | CCMPRED_AMINO_ACIDS = list(CCMPRED_AMINO_ACID_INDEX.keys()) 47 | 48 | BLOSUM62_MATRIX = pd.read_csv(StringIO(""" 49 | A R N D C Q E G H I L K M F P S T W Y V - 50 | A 4 -1 -2 -2 0 -1 -1 0 -2 -1 -1 -1 -1 -2 -1 1 0 -3 -2 0 0 51 | R -1 5 0 -2 -3 1 0 -2 0 -3 -2 2 -1 -3 -2 -1 -1 -3 -2 -3 0 52 | N -2 0 6 1 -3 0 0 0 1 -3 -3 0 -2 -3 -2 1 0 -4 -2 -3 0 53 | D -2 -2 1 6 -3 0 2 -1 -1 -3 -4 -1 -3 -3 -1 0 -1 -4 -3 -3 0 54 | C 0 -3 -3 -3 9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1 0 55 | Q -1 1 0 0 -3 5 2 -2 0 -3 -2 1 0 -3 -1 0 -1 -2 -1 -2 0 56 | E -1 0 0 2 -4 2 5 -2 0 -3 -3 1 -2 -3 -1 0 -1 -3 -2 -2 0 57 | G 0 -2 0 -1 -3 -2 -2 6 -2 -4 -4 -2 -3 -3 -2 0 -2 -2 -3 -3 0 58 | H -2 0 1 -1 -3 0 0 -2 8 -3 -3 -1 -2 -1 -2 -1 -2 -2 2 -3 0 59 | I -1 -3 -3 -3 -1 -3 -3 -4 -3 4 2 -3 1 0 -3 -2 -1 -3 -1 3 0 60 | L -1 -2 -3 -4 -1 -2 -3 -4 -3 2 4 -2 2 0 -3 -2 -1 -2 -1 1 0 61 | K -1 2 0 -1 -3 1 1 -2 -1 -3 -2 5 -1 -3 -1 0 -1 -3 -2 -2 0 62 | M -1 -1 -2 -3 -1 0 -2 -3 -2 1 2 -1 5 0 -2 -1 -1 -1 -1 1 0 63 | F -2 -3 -3 -3 -2 -3 -3 -3 -1 0 0 -3 0 6 -4 -2 -2 1 3 -1 0 64 | P -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4 7 -1 -1 -4 -3 -2 0 65 | S 1 -1 1 0 -1 0 0 0 -1 -2 -2 0 -1 -2 -1 4 1 -3 -2 -2 0 66 | T 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 1 5 -2 -2 0 0 67 | W -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1 1 -4 -3 -2 11 2 -3 0 68 | Y -2 -2 -2 -3 -2 -1 -2 -3 2 -1 -1 -2 -1 3 -3 -2 -2 2 7 -1 0 69 | V 0 -3 -3 -3 -1 -2 -2 -3 -3 3 1 -2 1 -1 -2 -2 0 -3 -1 4 0 70 | - 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 71 | """), sep='\s+').loc[AMINO_ACIDS, AMINO_ACIDS] 72 | 73 | ENCODING_DATA_FRAMES = { 74 | "BLOSUM62": BLOSUM62_MATRIX, 75 | "one-hot": pd.DataFrame([ 76 | [1 if i == j else 0 for i in range(len(AMINO_ACIDS))] 77 | for j in range(len(AMINO_ACIDS)) 78 | ], index=AMINO_ACIDS, columns=AMINO_ACIDS) 79 | } -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/src/vocab.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from io import StringIO 3 | import pandas as pd 4 | 5 | ''' 6 | Amino acide encoding modified from 7 | https://github.com/openvax/mhcflurry/blob/74b751e6d72605eef4a49641d364066193541b5a/mhcflurry/amino_acid.py 8 | ''' 9 | COMMON_AMINO_ACIDS_INDEX = collections.OrderedDict( 10 | {'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 11 | 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 12 | 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 13 | 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19, '-': 20}) 14 | AMINO_ACIDS = list(COMMON_AMINO_ACIDS_INDEX.keys()) 15 | 16 | AMINO_ACID_INDEX = collections.OrderedDict( 17 | {'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 18 | 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 19 | 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 20 | 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19, 21 | 'X': 20, 'Z': 20, 'B': 20, 'J': 20, '-': 20}) 22 | 23 | ''' 24 | CCMPred index of amino acid 25 | https://github.com/soedinglab/CCMpred/blob/2b2f9a0747a5e53035c33636d430f2f11dc186dd/src/sequence.c 26 | ''' 27 | CCMPRED_AMINO_ACID_INDEX1 = collections.OrderedDict( 28 | {'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 29 | 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9, 30 | 'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14, 31 | 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, '-': 20}) 32 | CCMPRED_AMINO_ACID_INDEX = collections.OrderedDict( 33 | {'A': 0, 'R': 1, 'N': 2, 'D': 3, 
'C': 4, 34 | 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9, 35 | 'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14, 36 | 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, 'B': 20, 37 | 'J': 20, 'O': 20, 'U': 20, 'X': 20, 'Z': 20, '-': 20, '_': 20}) 38 | CCMPRED_AMINO_ACIDS = list(CCMPRED_AMINO_ACID_INDEX.keys()) 39 | 40 | BLOSUM62_MATRIX = pd.read_csv(StringIO(""" 41 | A R N D C Q E G H I L K M F P S T W Y V - 42 | A 4 -1 -2 -2 0 -1 -1 0 -2 -1 -1 -1 -1 -2 -1 1 0 -3 -2 0 0 43 | R -1 5 0 -2 -3 1 0 -2 0 -3 -2 2 -1 -3 -2 -1 -1 -3 -2 -3 0 44 | N -2 0 6 1 -3 0 0 0 1 -3 -3 0 -2 -3 -2 1 0 -4 -2 -3 0 45 | D -2 -2 1 6 -3 0 2 -1 -1 -3 -4 -1 -3 -3 -1 0 -1 -4 -3 -3 0 46 | C 0 -3 -3 -3 9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1 0 47 | Q -1 1 0 0 -3 5 2 -2 0 -3 -2 1 0 -3 -1 0 -1 -2 -1 -2 0 48 | E -1 0 0 2 -4 2 5 -2 0 -3 -3 1 -2 -3 -1 0 -1 -3 -2 -2 0 49 | G 0 -2 0 -1 -3 -2 -2 6 -2 -4 -4 -2 -3 -3 -2 0 -2 -2 -3 -3 0 50 | H -2 0 1 -1 -3 0 0 -2 8 -3 -3 -1 -2 -1 -2 -1 -2 -2 2 -3 0 51 | I -1 -3 -3 -3 -1 -3 -3 -4 -3 4 2 -3 1 0 -3 -2 -1 -3 -1 3 0 52 | L -1 -2 -3 -4 -1 -2 -3 -4 -3 2 4 -2 2 0 -3 -2 -1 -2 -1 1 0 53 | K -1 2 0 -1 -3 1 1 -2 -1 -3 -2 5 -1 -3 -1 0 -1 -3 -2 -2 0 54 | M -1 -1 -2 -3 -1 0 -2 -3 -2 1 2 -1 5 0 -2 -1 -1 -1 -1 1 0 55 | F -2 -3 -3 -3 -2 -3 -3 -3 -1 0 0 -3 0 6 -4 -2 -2 1 3 -1 0 56 | P -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4 7 -1 -1 -4 -3 -2 0 57 | S 1 -1 1 0 -1 0 0 0 -1 -2 -2 0 -1 -2 -1 4 1 -3 -2 -2 0 58 | T 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 1 5 -2 -2 0 0 59 | W -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1 1 -4 -3 -2 11 2 -3 0 60 | Y -2 -2 -2 -3 -2 -1 -2 -3 2 -1 -1 -2 -1 3 -3 -2 -2 2 7 -1 0 61 | V 0 -3 -3 -3 -1 -2 -2 -3 -3 3 1 -2 1 -1 -2 -2 0 -3 -1 4 0 62 | - 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 63 | """), sep='\s+').loc[AMINO_ACIDS, AMINO_ACIDS] 64 | 65 | ENCODING_DATA_FRAMES = { 66 | "BLOSUM62": BLOSUM62_MATRIX, 67 | "one-hot": pd.DataFrame([ 68 | [1 if i == j else 0 for i in range(len(AMINO_ACIDS))] 69 | for j in range(len(AMINO_ACIDS)) 70 | ], index=AMINO_ACIDS, columns=AMINO_ACIDS) 71 | } -------------------------------------------------------------------------------- /mu-search/tests/test_explorers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import flexs 4 | from flexs import baselines 5 | 6 | 7 | class FakeModel(flexs.Model): 8 | def _fitness_function(self, sequences): 9 | return np.random.random(size=len(sequences)) 10 | 11 | def train(self, *args, **kwargs): 12 | pass 13 | 14 | 15 | class FakeLandscape(flexs.Landscape): 16 | def _fitness_function(self, sequences): 17 | return np.random.random(size=len(sequences)) 18 | 19 | 20 | starting_sequence = "ATCATCAT" 21 | fakeModel = FakeModel(name="FakeModel") 22 | fakeLandscape = FakeLandscape(name="FakeLandscape") 23 | 24 | 25 | def test_random(): 26 | explorer = baselines.explorers.Random( 27 | model=fakeModel, 28 | rounds=3, 29 | sequences_batch_size=5, 30 | model_queries_per_batch=20, 31 | starting_sequence=starting_sequence, 32 | alphabet="ATCG", 33 | ) 34 | explorer.run(fakeLandscape) 35 | 36 | 37 | def test_adalead(): 38 | explorer = baselines.explorers.Adalead( 39 | model=fakeModel, 40 | rounds=3, 41 | sequences_batch_size=5, 42 | model_queries_per_batch=20, 43 | eval_batch_size=1, 44 | starting_sequence=starting_sequence, 45 | alphabet="ATCG", 46 | ) 47 | explorer.run(fakeLandscape) 48 | 49 | 50 | def test_bo(): 51 | explorer = baselines.explorers.BO( 52 | model=fakeModel, 53 | rounds=3, 54 | sequences_batch_size=5, 55 | model_queries_per_batch=20, 56 | 
starting_sequence=starting_sequence, 57 | alphabet="ATCG", 58 | ) 59 | explorer.run(fakeLandscape) 60 | 61 | 62 | def test_gpr_bo(): 63 | explorer = baselines.explorers.GPR_BO( 64 | model=fakeModel, 65 | rounds=3, 66 | sequences_batch_size=5, 67 | model_queries_per_batch=20, 68 | starting_sequence=starting_sequence, 69 | alphabet="ATCG", 70 | ) 71 | explorer.run(fakeLandscape) 72 | 73 | 74 | def test_dqn(): 75 | explorer = baselines.explorers.DQN( 76 | model=fakeModel, 77 | rounds=3, 78 | sequences_batch_size=5, 79 | model_queries_per_batch=20, 80 | starting_sequence=starting_sequence, 81 | alphabet="ATCG", 82 | ) 83 | explorer.run(fakeLandscape) 84 | 85 | 86 | def test_dynappo(): 87 | explorer = baselines.explorers.DynaPPO( 88 | landscape=fakeLandscape, 89 | rounds=3, 90 | sequences_batch_size=5, 91 | model_queries_per_batch=20, 92 | starting_sequence=starting_sequence, 93 | alphabet="ATCG", 94 | num_experiment_rounds=1, 95 | num_model_rounds=1, 96 | ) 97 | explorer.run(fakeLandscape) 98 | 99 | 100 | def test_cmaes(): 101 | explorer = baselines.explorers.CMAES( 102 | fakeModel, 103 | population_size=15, 104 | max_iter=200, 105 | initial_variance=0.3, 106 | rounds=3, 107 | starting_sequence=starting_sequence, 108 | sequences_batch_size=5, 109 | model_queries_per_batch=20, 110 | alphabet="ATCG", 111 | ) 112 | explorer.run(fakeLandscape) 113 | 114 | 115 | def test_cbas(): 116 | vae = baselines.explorers.VAE( 117 | len(starting_sequence), "ATCG", epochs=2, verbose=False 118 | ) 119 | explorer = baselines.explorers.CbAS( 120 | fakeModel, 121 | vae, 122 | rounds=3, 123 | starting_sequence=starting_sequence, 124 | sequences_batch_size=5, 125 | model_queries_per_batch=20, 126 | alphabet="ATCG", 127 | ) 128 | explorer.run(fakeLandscape) 129 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/src/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import torch 4 | import pathlib 5 | import numpy as np 6 | 7 | class Logger(object): 8 | def __init__(self, logfile=None, level=logging.INFO): 9 | ''' 10 | logfile: pathlib object 11 | ''' 12 | self.logger = logging.getLogger() 13 | self.logger.setLevel(level) 14 | formatter = logging.Formatter("%(asctime)s\t%(message)s", "%Y-%m-%d %H:%M:%S") 15 | 16 | for hd in self.logger.handlers[:]: 17 | self.logger.removeHandler(hd) 18 | 19 | sh = logging.StreamHandler(sys.stdout) 20 | sh.setFormatter(formatter) 21 | self.logger.addHandler(sh) 22 | 23 | if logfile is not None: 24 | logfile.parent.mkdir(exist_ok=True, parents=True) 25 | fh = logging.FileHandler(logfile, 'w') 26 | fh.setFormatter(formatter) 27 | self.logger.addHandler(fh) 28 | 29 | def debug(self, msg): 30 | self.logger.debug(msg) 31 | 32 | def info(self, msg): 33 | self.logger.info(msg) 34 | 35 | def warning(self, msg): 36 | self.logger.warning(msg) 37 | 38 | def error(self, msg): 39 | self.logger.error(msg) 40 | 41 | 42 | class Saver(object): 43 | def __init__(self, output_dir): 44 | self.save_dir = pathlib.Path(output_dir) 45 | 46 | def save_ckp(self, pt, filename='checkpoint.pt'): 47 | self.save_dir.mkdir(exist_ok=True, parents=True) 48 | # _use_new_zipfile_serialization=True 49 | torch.save(pt, str(self.save_dir/filename)) 50 | 51 | def save_df(self, df, filename): 52 | self.save_dir.mkdir(exist_ok=True, parents=True) 53 | df.to_csv(self.save_dir/filename, float_format='%.8f', index=False, sep='\t') 54 | 55 | 56 | class EarlyStopping(object): 57 | def 
__init__(self,
58 |                  patience=100, eval_freq=1, best_score=None,
59 |                  delta=1e-9, higher_better=True):
60 |         self.patience = patience
61 |         self.eval_freq = eval_freq
62 |         self.best_score = best_score
63 |         self.delta = delta
64 |         self.higher_better = higher_better
65 |         self.counter = 0
66 |         self.early_stop = False
67 | 
68 |     def not_improved(self, val_score):
69 |         if np.isnan(val_score):
70 |             return True
71 |         if self.higher_better:
72 |             return abs(val_score) < abs(self.best_score) + self.delta
73 |         else:
74 |             return abs(val_score) > abs(self.best_score) - self.delta
75 | 
76 |     def update(self, val_score):
77 |         if self.best_score is None:
78 |             self.best_score = val_score
79 |             is_best = True
80 |         elif self.not_improved(val_score):
81 |             self.counter += self.eval_freq
82 |             if (self.patience is not None) and (self.counter > self.patience):
83 |                 self.early_stop = True
84 |             is_best = False
85 |         else:
86 |             self.best_score = val_score
87 |             self.counter = 0
88 |             is_best = True
89 |         return is_best
90 | 
91 | def save_read_listxt(path, list=None):
92 |     '''
93 |     :param path: where to store the list
94 |     :param list: the list data
95 |     :return: None in save mode (both path and list given); in read mode (only
96 |              path given), the list parsed back from the txt file
97 |     '''
98 |     if list is not None:
99 |         with open(path, 'w') as file:
100 |             file.write(str(list))
101 |         return None
102 |     else:
103 |         with open(path, 'r') as file:
104 |             # Note: eval() trusts the file contents; only read files written by this function
105 |             rdlist = eval(file.read())
106 |         return rdlist
--------------------------------------------------------------------------------
/mu-search/src/flexs/flexs/landscapes/bert_gfp.py:
--------------------------------------------------------------------------------
 1 | """Defines the BertGFPBrightness landscape."""
 2 | import os
 3 | 
 4 | import numpy as np
 5 | import requests
 6 | import tape
 7 | import torch
 8 | 
 9 | import flexs
10 | 
11 | 
12 | class BertGFPBrightness(flexs.Landscape):
13 |     r"""
14 |     Green fluorescent protein (GFP) brightness landscape.
15 | 
16 |     The oracle used in this landscape is the transformer model
17 |     from TAPE (https://github.com/songlab-cal/tape).
18 | 
19 |     To create the transformer model used here, run the command:
20 | 
21 |     ```tape-train transformer fluorescence --from_pretrained bert-base \
22 |         --batch_size 128 \
23 |         --gradient_accumulation_steps 10 \
24 |         --data_dir .```
25 | 
26 |     Note that the output of this landscape is not normalized to be between 0 and 1.
27 | 
28 |     Attributes:
29 |         gfp_wt_sequence (str): Wild-type sequence for jellyfish
30 |             green fluorescence protein.
31 | 
32 |     """
33 | 
34 |     gfp_wt_sequence = (
35 |         "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVT"
36 |         "TLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIE"
37 |         "LKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNT"
38 |         "PIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
39 |     )
40 | 
41 |     def __init__(self):
42 |         """
43 |         Create GFP landscape.
44 | 
45 |         Downloads the model into `./fluorescence-model` if not already cached there.
46 |         If the download is interrupted, you may have to delete this folder and try again.
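
        Example (illustrative; constructing the landscape triggers the weight download):
            >>> landscape = BertGFPBrightness()
            >>> landscape.get_fitness([BertGFPBrightness.gfp_wt_sequence])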
47 | """ 48 | super().__init__(name="GFP") 49 | 50 | # Download GFP model weights and config info 51 | if not os.path.exists("fluorescence-model"): 52 | os.mkdir("fluorescence-model") 53 | 54 | # URL for BERT GFP fluorescence model 55 | gfp_model_path = "https://fluorescence-model.s3.amazonaws.com/fluorescence_transformer_20-05-25-03-49-06_184764/" # noqa: E501 56 | for file_name in [ 57 | "args.json", 58 | "checkpoint.bin", 59 | "config.json", 60 | "pytorch_model.bin", 61 | ]: 62 | print("Downloading", file_name) 63 | response = requests.get(gfp_model_path + file_name) 64 | with open(f"fluorescence-model/{file_name}", "wb") as f: 65 | f.write(response.content) 66 | 67 | self.tokenizer = tape.TAPETokenizer(vocab="iupac") 68 | 69 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 70 | self.model = tape.ProteinBertForValuePrediction.from_pretrained( 71 | "fluorescence-model" 72 | ).to(self.device) 73 | 74 | def _fitness_function(self, sequences, batch_size=20): 75 | sequences = np.array(sequences) 76 | num_batches = (len(sequences) + batch_size - 1) // batch_size # 计算需要的批次数 77 | 78 | results = [] 79 | 80 | for i in range(num_batches): 81 | batch_sequences = sequences[i * batch_size : (i + 1) * batch_size] 82 | 83 | # 将序列编码并移动到指定设备 84 | encoded_seqs = torch.tensor( 85 | [self.tokenizer.encode(seq) for seq in batch_sequences] 86 | ).to(self.device) 87 | 88 | try: 89 | batch_result = self.model(encoded_seqs)[0].detach().numpy().astype(float).reshape(-1) 90 | except: 91 | batch_result = self.model(encoded_seqs)[0].detach().cpu().numpy().astype(float).reshape(-1) 92 | 93 | results.extend(batch_result) 94 | 95 | return np.array(results) 96 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/models/noisy_abstract_model.py: -------------------------------------------------------------------------------- 1 | """Define the noisy abstract model class.""" 2 | import editdistance 3 | import numpy as np 4 | 5 | import flexs 6 | from flexs.types import SEQUENCES_TYPE 7 | 8 | 9 | class NoisyAbstractModel(flexs.Model): 10 | r""" 11 | Behaves like a ground truth model. 12 | 13 | It corrupts a ground truth model with noise, which is modulated by distance 14 | to already measured sequences. 15 | 16 | Specifically, $\hat{f}(x) = \alpha^d f(x) + (1 - \alpha^d) \epsilon$ where 17 | $\epsilon$ is drawn from an exponential distribution with mean $f(x)$ 18 | $d$ is the edit distance to the closest measured neighbor, 19 | and $\alpha$ is the signal strength. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | landscape: flexs.Landscape, 25 | signal_strength: float = 0.9, 26 | ): 27 | """ 28 | Create a noisy abstract model. 29 | 30 | Args: 31 | landscape: The ground truth landscape. 32 | signal_strength: A value between 0 and 1 representing the 33 | true signal strength. 
34 | 35 | """ 36 | super().__init__(f"NAMb_ss{signal_strength}") 37 | 38 | self.landscape = landscape 39 | self.ss = signal_strength 40 | self.cache = {} 41 | 42 | def _get_min_distance(self, sequence): 43 | # Special case if cache is empty 44 | if len(self.cache) == 0: 45 | return 0, sequence 46 | 47 | new_dist = np.inf 48 | closest = None 49 | 50 | for seq in self.cache: 51 | dist = editdistance.eval(sequence, seq) 52 | 53 | if dist == 1: 54 | return dist, seq 55 | 56 | if dist < new_dist: 57 | new_dist = dist 58 | closest = seq 59 | 60 | return new_dist, closest 61 | 62 | def train(self, sequences: SEQUENCES_TYPE, labels: np.ndarray): 63 | """ 64 | Training step simply stores sequences and labels in a 65 | dictionary for future lookup. 66 | """ 67 | self.cache.update(zip(sequences, labels)) 68 | 69 | def _fitness_function(self, sequences): 70 | sequences = np.array(sequences) 71 | fitnesses = np.empty(len(sequences)) 72 | 73 | # We use cached evaluations so that the model gives deterministic outputs 74 | cached = np.array([seq in self.cache for seq in sequences]) 75 | fitnesses[cached] = np.array([self.cache[seq] for seq in sequences[cached]]) 76 | 77 | new_fitnesses = [] 78 | for seq in sequences[~cached]: 79 | 80 | # Otherwise, fitness = alpha * true_fitness + (1 - alpha) * noise 81 | # where alpha = signal_strength ^ (dist to nearest neighbor) 82 | # and noise is the nearest neighbor's fitness plus exponentially 83 | # distributed noise 84 | distance, neighbor_seq = self._get_min_distance(seq) 85 | 86 | signal = self.landscape.get_fitness([seq]).item() 87 | neighbor_fitness = self.landscape.get_fitness([neighbor_seq]).item() 88 | if neighbor_fitness >= 0: 89 | noise = np.random.exponential(scale=neighbor_fitness) 90 | else: 91 | noise = np.random.choice(list(self.cache.values())) 92 | 93 | alpha = self.ss ** distance 94 | new_fitnesses.append(alpha * signal + (1 - alpha) * noise) 95 | 96 | fitnesses[~cached] = new_fitnesses 97 | 98 | # Update cache with new sequences and their predicted fitnesses 99 | self.cache.update(zip(sequences[~cached], fitnesses[~cached])) 100 | 101 | return np.array(fitnesses) 102 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /mu-former/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/models/adaptive_ensemble.py: -------------------------------------------------------------------------------- 1 | """Defines the AdaptiveEnsemble model.""" 2 | from typing import List 3 | 4 | import numpy as np 5 | import scipy.stats 6 | import sklearn.model_selection 7 | 8 | import flexs 9 | from flexs.types import SEQUENCES_TYPE 10 | 11 | 12 | def r2_weights(model_preds: np.ndarray, labels: np.ndarray) -> np.ndarray: 13 | """ 14 | Args: 15 | model_preds: A numpy array of shape (num_models, num_samples) containing 16 | model predictions. 17 | labels: A numpy array of true labels. 18 | 19 | Returns: 20 | A numpy array of shape (num_models,) containing $r^2$ scores for models. 21 | 22 | """ 23 | r2s = np.array( 24 | [scipy.stats.pearsonr(preds, labels)[0] ** 2 for preds in model_preds] 25 | ) 26 | return r2s / r2s.sum() 27 | 28 | 29 | class AdaptiveEnsemble(flexs.Model): 30 | """ 31 | Ensemble class that weights individual model predictions adaptively, 32 | according to some reweighting function. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | models: List[flexs.Model], 38 | combine_with="sum", 39 | adapt_weights_with="r2_weights", 40 | adaptive_val_size: float = 0.2, 41 | ): 42 | """ 43 | Args: 44 | models: Models to ensemble 45 | combine_with: A function taking in weight vector and model outputs and 46 | returning an aggregate output per sample. `np.sum(weights * outputs)` 47 | by default. 48 | adapt_weights_with: A function taking in a numpy array of shape 49 | (num_models, num_samples) containing model predictions, and a numpy 50 | array of true labels (num_samples,) that returns an array of 51 | shape (num_models,) containing model_weights. `r2_weights` by default. 52 | adaptive_val_size: Portion of model training data to go into validation 53 | split used for computing adaptive weight values. 
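
        Example (illustrative sketch; members may be any flexs.Model subclasses):
            >>> ens = AdaptiveEnsemble([model_a, model_b])
            >>> ens.train(seqs, labels)        # also refits the weights (for >= 10 sequences)
            >>> scores = ens.get_fitness(new_seqs)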
54 | """ 55 | name = f"AdaptiveEns({'|'.join(model.name for model in models)})" 56 | super().__init__(name) 57 | 58 | self.models = models 59 | self.weights = np.ones(len(models)) / len(models) 60 | 61 | if combine_with == "sum": 62 | combine_with = lambda w, x: np.sum(w * x, axis=1) 63 | self.combine_with = combine_with 64 | 65 | if adapt_weights_with == "r2_weights": 66 | adapt_weights_with = r2_weights 67 | self.adapt_weights_with = adapt_weights_with 68 | 69 | self.adaptive_val_size = adaptive_val_size 70 | 71 | def train(self, sequences: SEQUENCES_TYPE, labels: np.ndarray): 72 | """ 73 | Train each model in the ensemble and then adaptively reweight them 74 | according to `adapt_weights_with`. 75 | 76 | Args: 77 | sequences: Training sequences. 78 | lables: Training sequence labels. 79 | 80 | """ 81 | # If very few sequences, don't bother with reweighting 82 | if len(sequences) < 10: 83 | for model in self.models: 84 | model.train(sequences, labels) 85 | return 86 | 87 | (train_X, test_X, train_y, test_y,) = sklearn.model_selection.train_test_split( 88 | np.array(sequences), np.array(labels), test_size=self.adaptive_val_size 89 | ) 90 | 91 | for model in self.models: 92 | model.train(train_X, train_y) 93 | 94 | preds = np.stack([model.get_fitness(test_X) for model in self.models], axis=0) 95 | self.weights = self.adapt_weights_with(preds, test_y) 96 | 97 | def _fitness_function(self, sequences: SEQUENCES_TYPE) -> np.ndarray: 98 | scores = np.stack( 99 | [model.get_fitness(sequences) for model in self.models], axis=1 100 | ) 101 | 102 | return self.combine_with(self.weights, scores) 103 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/landscape/muformer/muformer_landscape/src/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import torch 4 | import pathlib 5 | import numpy as np 6 | 7 | class Logger(object): 8 | def __init__(self, logfile=None, level=logging.INFO): 9 | ''' 10 | logfile: pathlib object 11 | ''' 12 | self.logger = logging.getLogger() 13 | self.logger.setLevel(level) 14 | formatter = logging.Formatter("%(asctime)s\t%(message)s", "%Y-%m-%d %H:%M:%S") 15 | 16 | for hd in self.logger.handlers[:]: 17 | self.logger.removeHandler(hd) 18 | 19 | sh = logging.StreamHandler(sys.stdout) 20 | sh.setFormatter(formatter) 21 | self.logger.addHandler(sh) 22 | 23 | if logfile is not None: 24 | logfile.parent.mkdir(exist_ok=True, parents=True) 25 | fh = logging.FileHandler(logfile, 'w') 26 | fh.setFormatter(formatter) 27 | self.logger.addHandler(fh) 28 | 29 | def debug(self, msg): 30 | self.logger.debug(msg) 31 | 32 | def info(self, msg): 33 | self.logger.info(msg) 34 | 35 | def warning(self, msg): 36 | self.logger.warning(msg) 37 | 38 | def error(self, msg): 39 | self.logger.error(msg) 40 | 41 | 42 | class Saver(object): 43 | def __init__(self, output_dir): 44 | self.save_dir = pathlib.Path(output_dir) 45 | 46 | def save_ckp(self, pt, filename='checkpoint.pt'): 47 | self.save_dir.mkdir(exist_ok=True, parents=True) 48 | torch.save(pt, str(self.save_dir/filename)) 49 | 50 | def save_df(self, df, filename): 51 | self.save_dir.mkdir(exist_ok=True, parents=True) 52 | df.to_csv(self.save_dir/filename, float_format='%.8f', index=False, sep='\t') 53 | 54 | 55 | class EarlyStopping(object): 56 | def __init__(self, 57 | patience=100, eval_freq=1, best_score=None, 58 | least_loss=None, delta=1e-9, higher_better=True): 59 | self.patience = patience 60 | 
self.eval_freq = eval_freq 61 | self.best_score = best_score 62 | self.least_loss = least_loss 63 | self.delta = delta 64 | self.higher_better = higher_better 65 | self.counter = 0 66 | self.early_stop = False 67 | 68 | def not_improved(self, val_score, val_loss): 69 | if np.isnan(val_score): 70 | return True 71 | if self.higher_better: 72 | # return (val_score < self.best_score + self.delta) 73 | return (val_score < self.best_score - self.delta) or ((val_score > self.best_score - self.delta) and (val_loss > self.least_loss - self.delta)) 74 | else: 75 | # return val_score > self.best_score - self.delta 76 | return (val_score > self.best_score + self.delta) or ((val_score < self.best_score + self.delta) and (val_loss > self.least_loss - self.delta)) 77 | 78 | def update(self, val_score, val_loss): 79 | # val_score = abs(val_score) 80 | if self.best_score is None: 81 | self.best_score = val_score 82 | self.least_loss = val_loss 83 | is_best = True 84 | elif self.not_improved(val_score,val_loss): 85 | self.counter += self.eval_freq 86 | if (self.patience is not None) and (self.counter > self.patience): 87 | self.early_stop = True 88 | is_best = False 89 | else: 90 | self.best_score = val_score 91 | self.least_loss = val_loss 92 | self.counter = 0 93 | is_best = True 94 | return is_best 95 | 96 | def save_read_listxt(path, list=None): 97 | if list != None: 98 | file = open(path, 'w') 99 | file.write(str(list)) 100 | file.close() 101 | return None 102 | else: 103 | file = open(path, 'r') 104 | rdlist = eval(file.read()) 105 | file.close() 106 | return rdlist -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/evoplay_utils/env_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | import numpy as np 10 | 11 | AAS = "ILVAGMFYWEDQNHCRKSTP" 12 | 13 | def string_to_one_hot(sequence: str, alphabet: str) -> np.ndarray: 14 | 15 | out = np.zeros((len(sequence), len(alphabet))) 16 | for i in range(len(sequence)): 17 | out[i, alphabet.index(sequence[i])] = 1 18 | return out 19 | 20 | def string_to_feature(string): 21 | seq_list = [] 22 | seq_list.append(string) 23 | seq_np = np.array( 24 | [string_to_one_hot(seq, AAS) for seq in seq_list] 25 | ) 26 | one_hots = torch.from_numpy(seq_np) 27 | one_hots = one_hots.to(torch.float32) 28 | return one_hots 29 | 30 | def predict(model,inputs): 31 | one_hots_0= string_to_one_hot(inputs, AAS) 32 | one_hots = torch.from_numpy(one_hots_0) 33 | one_hots = one_hots.unsqueeze(0) 34 | one_hots = one_hots.to(torch.float32) 35 | with torch.no_grad(): 36 | inputs = one_hots 37 | inputs = inputs.permute(0, 2, 1) 38 | outputs = model(inputs) 39 | outputs = outputs.squeeze() 40 | return outputs 41 | class CNN(nn.Module): 42 | """predictor network module""" 43 | 44 | def __init__( 45 | self, 46 | seq_len, 47 | alphabet_len, 48 | ): 49 | super(CNN, self).__init__() 50 | self.board_width = seq_len 51 | self.board_height = alphabet_len 52 | # conv layers 53 | self.conv1 = nn.Conv1d(20, 32, kernel_size=3) # 54 | self.conv2 = nn.Conv1d(32, 32, kernel_size=3, padding=1) #, padding=0 55 | self.conv3 = nn.Conv1d(32, 32, kernel_size=3, padding=1) # , padding=0 56 | self.maxpool1 = nn.MaxPool1d(kernel_size=1, stride=1) 57 | self.maxpool2 = nn.MaxPool1d(kernel_size=1, 
stride=1) 58 | 59 | self.val_fc1 = nn.Linear(7520, 100) 60 | 61 | self.val_fc2 = nn.Linear(100, 100) # * alphabet_len 62 | self.dropout = nn.Dropout(p=0.25) 63 | self.val_fc3 = nn.Linear(100, 1) 64 | 65 | def forward(self, input): 66 | x = F.relu(self.conv1(input)) 67 | x = F.relu(self.conv2(x)) 68 | x = F.relu(self.maxpool1(x)) 69 | 70 | x = F.relu(self.conv3(x)) 71 | x_act = F.relu(self.maxpool2(x)) 72 | x_score_1 = x_act.view(x_act.shape[0], -1) 73 | 74 | x_score_2 = F.relu(self.val_fc1(x_score_1)) 75 | x_score_2 = F.relu(self.val_fc2(x_score_2)) 76 | x_score_2 = self.dropout(x_score_2) 77 | x_score_3 = self.val_fc3(x_score_2) 78 | return x_score_3 79 | 80 | 81 | class CNN2(nn.Module): 82 | """predictor network module""" 83 | 84 | def __init__( 85 | self, 86 | seq_len, 87 | alphabet_len, 88 | ): 89 | super(CNN2, self).__init__() 90 | self.board_width = seq_len 91 | self.board_height = alphabet_len 92 | # conv layers 93 | self.conv1 = nn.Conv1d(20, 32, kernel_size=3) # 94 | self.conv2 = nn.Conv1d(32, 32, kernel_size=3, padding=1) #, padding=0 95 | self.conv3 = nn.Conv1d(32, 32, kernel_size=3, padding=1) # , padding=0 96 | 97 | self.maxpool1 = nn.MaxPool1d(kernel_size=1, stride=1) 98 | self.maxpool2 = nn.MaxPool1d(kernel_size=1, stride=1) 99 | 100 | self.val_fc1 = nn.Linear(2336, 100)# * alphabet_len 101 | self.val_fc2 = nn.Linear(100, 100) # * alphabet_len 102 | self.dropout = nn.Dropout(p=0.25) 103 | self.val_fc3 = nn.Linear(100, 1) 104 | 105 | def forward(self, input): 106 | x = F.relu(self.conv1(input)) 107 | x = F.relu(self.conv2(x)) 108 | x = F.relu(self.maxpool1(x)) 109 | 110 | x = F.relu(self.conv3(x)) 111 | x_act = F.relu(self.maxpool2(x)) 112 | 113 | x_score_1 = x_act.view(x_act.shape[0], -1) 114 | 115 | x_score_2 = F.relu(self.val_fc1(x_score_1)) 116 | x_score_2 = F.relu(self.val_fc2(x_score_2)) 117 | x_score_2 = self.dropout(x_score_2) 118 | x_score_3 = self.val_fc3(x_score_2) 119 | return x_score_3 120 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/vec_env/vec_monitor.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | from typing import Optional, Tuple 4 | 5 | import numpy as np 6 | 7 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper 8 | 9 | 10 | class VecMonitor(VecEnvWrapper): 11 | """ 12 | A vectorized monitor wrapper for *vectorized* Gym environments, 13 | it is used to record the episode reward, length, time and other data. 14 | 15 | Some environments like `openai/procgen `_ 16 | or `gym3 `_ directly initialize the 17 | vectorized environments, without giving us a chance to use the ``Monitor`` 18 | wrapper. So this class simply does the job of the ``Monitor`` wrapper on 19 | a vectorized level. 20 | 21 | :param venv: The vectorized environment 22 | :param filename: the location to save a log file, can be None for no log 23 | :param info_keywords: extra information to log, from the information return of env.step() 24 | """ 25 | 26 | def __init__( 27 | self, 28 | venv: VecEnv, 29 | filename: Optional[str] = None, 30 | info_keywords: Tuple[str, ...] 
--------------------------------------------------------------------------------
/mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/vec_env/vec_monitor.py:
--------------------------------------------------------------------------------
1 | import time
2 | import warnings
3 | from typing import Optional, Tuple
4 | 
5 | import numpy as np
6 | 
7 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
8 | 
9 | 
10 | class VecMonitor(VecEnvWrapper):
11 |     """
12 |     A vectorized monitor wrapper for *vectorized* Gym environments;
13 |     it is used to record the episode reward, length, time and other data.
14 | 
15 |     Some environments like `openai/procgen <https://github.com/openai/procgen>`_
16 |     or `gym3 <https://github.com/openai/gym3>`_ directly initialize the
17 |     vectorized environments, without giving us a chance to use the ``Monitor``
18 |     wrapper. So this class simply does the job of the ``Monitor`` wrapper on
19 |     a vectorized level.
20 | 
21 |     :param venv: The vectorized environment
22 |     :param filename: the location to save a log file, can be None for no log
23 |     :param info_keywords: extra information to log, from the information return of env.step()
24 |     """
25 | 
26 |     def __init__(
27 |         self,
28 |         venv: VecEnv,
29 |         filename: Optional[str] = None,
30 |         info_keywords: Tuple[str, ...] = (),
31 |     ):
32 |         # Avoid circular import
33 |         from stable_baselines3.common.monitor import Monitor, ResultsWriter
34 | 
35 |         # This check is not valid for special `VecEnv`
36 |         # like the ones created by Procgen, that do not completely follow
37 |         # the `VecEnv` interface
38 |         try:
39 |             is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
40 |         except AttributeError:
41 |             is_wrapped_with_monitor = False
42 | 
43 |         if is_wrapped_with_monitor:
44 |             warnings.warn(
45 |                 "The environment is already wrapped with a `Monitor` wrapper "
46 |                 "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be "
47 |                 "overwritten by the `VecMonitor` ones.",
48 |                 UserWarning,
49 |             )
50 | 
51 |         VecEnvWrapper.__init__(self, venv)
52 |         self.episode_returns = None
53 |         self.episode_lengths = None
54 |         self.episode_count = 0
55 |         self.t_start = time.time()
56 | 
57 |         env_id = None
58 |         if hasattr(venv, "spec") and venv.spec is not None:
59 |             env_id = venv.spec.id
60 | 
61 |         if filename:
62 |             self.results_writer = ResultsWriter(
63 |                 filename, header={"t_start": self.t_start, "env_id": env_id}, extra_keys=info_keywords
64 |             )
65 |         else:
66 |             self.results_writer = None
67 |         self.info_keywords = info_keywords
68 | 
69 |     def reset(self) -> VecEnvObs:
70 |         obs = self.venv.reset()
71 |         self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
72 |         self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
73 |         return obs
74 | 
75 |     def step_wait(self) -> VecEnvStepReturn:
76 |         obs, rewards, dones, infos = self.venv.step_wait()
77 |         self.episode_returns += rewards
78 |         self.episode_lengths += 1
79 |         new_infos = list(infos[:])
80 |         for i in range(len(dones)):
81 |             if dones[i]:
82 |                 info = infos[i].copy()
83 |                 episode_return = self.episode_returns[i]
84 |                 episode_length = self.episode_lengths[i]
85 |                 episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
86 |                 for key in self.info_keywords:
87 |                     episode_info[key] = info[key]
88 |                 info["episode"] = episode_info
89 |                 self.episode_count += 1
90 |                 self.episode_returns[i] = 0
91 |                 self.episode_lengths[i] = 0
92 |                 if self.results_writer:
93 |                     self.results_writer.write_row(episode_info)
94 |                 new_infos[i] = info
95 |         return obs, rewards, dones, new_infos
96 | 
97 |     def close(self) -> None:
98 |         if self.results_writer:
99 |             self.results_writer.close()
100 |         return self.venv.close()
101 | 
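102 | # Usage sketch: wrap an existing VecEnv so per-episode statistics show up in the
103 | # step infos. `venv` is assumed to be any VecEnv built elsewhere.
104 | def _example_vec_monitor(venv):
105 |     monitored = VecMonitor(venv, filename=None)
106 |     # After a step where dones[i] is True, infos[i]["episode"] holds r (return),
107 |     # l (length) and t (seconds since the monitor started).
108 |     return monitored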
--------------------------------------------------------------------------------
/mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/vec_env/vec_video_recorder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Callable
3 | 
4 | from gym.wrappers.monitoring import video_recorder
5 | 
6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
7 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
8 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
9 | 
10 | 
11 | class VecVideoRecorder(VecEnvWrapper):
12 |     """
13 |     Wraps a VecEnv or VecEnvWrapper object to record rendered images as an mp4 video.
14 |     It requires ffmpeg or avconv to be installed on the machine.
15 | 
16 |     :param venv: The vectorized environment to wrap
17 |     :param video_folder: Where to save videos
18 |     :param record_video_trigger: Function that defines when to start recording.
19 |         The function takes the current step number,
20 |         and returns whether we should start recording or not.
21 |     :param video_length: Length of recorded videos
22 |     :param name_prefix: Prefix to the video name
23 |     """
24 | 
25 |     def __init__(
26 |         self,
27 |         venv: VecEnv,
28 |         video_folder: str,
29 |         record_video_trigger: Callable[[int], bool],
30 |         video_length: int = 200,
31 |         name_prefix: str = "rl-video",
32 |     ):
33 | 
34 |         VecEnvWrapper.__init__(self, venv)
35 | 
36 |         self.env = venv
37 |         # Temp variable to retrieve metadata
38 |         temp_env = venv
39 | 
40 |         # Unwrap to retrieve metadata dict
41 |         # that will be used by gym recorder
42 |         while isinstance(temp_env, VecEnvWrapper):
43 |             temp_env = temp_env.venv
44 | 
45 |         if isinstance(temp_env, (DummyVecEnv, SubprocVecEnv)):
46 |             metadata = temp_env.get_attr("metadata")[0]
47 |         else:
48 |             metadata = temp_env.metadata
49 | 
50 |         self.env.metadata = metadata
51 | 
52 |         self.record_video_trigger = record_video_trigger
53 |         self.video_recorder = None
54 | 
55 |         self.video_folder = os.path.abspath(video_folder)
56 |         # Create output folder if needed
57 |         os.makedirs(self.video_folder, exist_ok=True)
58 | 
59 |         self.name_prefix = name_prefix
60 |         self.step_id = 0
61 |         self.video_length = video_length
62 | 
63 |         self.recording = False
64 |         self.recorded_frames = 0
65 | 
66 |     def reset(self) -> VecEnvObs:
67 |         obs = self.venv.reset()
68 |         self.start_video_recorder()
69 |         return obs
70 | 
71 |     def start_video_recorder(self) -> None:
72 |         self.close_video_recorder()
73 | 
74 |         video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}"
75 |         base_path = os.path.join(self.video_folder, video_name)
76 |         self.video_recorder = video_recorder.VideoRecorder(
77 |             env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
78 |         )
79 | 
80 |         self.video_recorder.capture_frame()
81 |         self.recorded_frames = 1
82 |         self.recording = True
83 | 
84 |     def _video_enabled(self) -> bool:
85 |         return self.record_video_trigger(self.step_id)
86 | 
87 |     def step_wait(self) -> VecEnvStepReturn:
88 |         obs, rews, dones, infos = self.venv.step_wait()
89 | 
90 |         self.step_id += 1
91 |         if self.recording:
92 |             self.video_recorder.capture_frame()
93 |             self.recorded_frames += 1
94 |             if self.recorded_frames > self.video_length:
95 |                 print(f"Saving video to {self.video_recorder.path}")
96 |                 self.close_video_recorder()
97 |         elif self._video_enabled():
98 |             self.start_video_recorder()
99 | 
100 |         return obs, rews, dones, infos
101 | 
102 |     def close_video_recorder(self) -> None:
103 |         if self.recording:
104 |             self.video_recorder.close()
105 |         self.recording = False
106 |         self.recorded_frames = 1
107 | 
108 |     def close(self) -> None:
109 |         VecEnvWrapper.close(self)
110 |         self.close_video_recorder()
111 | 
112 |     def __del__(self):
113 |         self.close()
114 | 
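115 | # Usage sketch: record the first 200 steps of every 1000-step window. `venv` is
116 | # assumed to render RGB frames, and ffmpeg must be available on the machine.
117 | def _example_video_recorder(venv):
118 |     return VecVideoRecorder(
119 |         venv,
120 |         video_folder="./videos",  # created automatically if missing
121 |         record_video_trigger=lambda step: step % 1000 == 0,
122 |         video_length=200,
123 |     )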
--------------------------------------------------------------------------------
/mu-former/src/criterion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | import numpy as np
6 | 
7 | def pearson_loss(x, y):
8 |     mean_x = torch.mean(x)
9 |     mean_y = torch.mean(y)
10 |     xm = x.sub(mean_x)
11 |     ym = y.sub(mean_y)
12 |     r_num = xm.dot(ym)
13 |     r_den = torch.norm(xm, 2) * torch.norm(ym, 2)
14 |     r_val = r_num / r_den
15 |     return 1 - r_val
16 | 
17 | def pearson_correlation_loss(y_pred, y_true, normalized=False):
18 |     """
19 |     Calculate Pearson correlation loss.
20 |     :param y_pred: distance matrix tensor of size (batch_size, batch_size)
21 |     :param y_true: distance matrix tensor of size (batch_size, batch_size)
22 |     :param normalized: if True, Softmax is applied to the distance matrix
23 |     :return: loss tensor
24 |     """
25 |     if normalized:
26 |         y_true = F.softmax(y_true, dim=-1)
27 |         y_pred = F.softmax(y_pred, dim=-1)
28 | 
29 |     sum_true = torch.sum(y_true)
30 |     sum2_true = torch.sum(torch.pow(y_true, 2))  # sum of squares
31 | 
32 |     sum_pred = torch.sum(y_pred)
33 |     sum2_pred = torch.sum(torch.pow(y_pred, 2))
34 | 
35 |     prod = torch.sum(y_true * y_pred)
36 |     n = y_true.shape[0]  # batch size
37 | 
38 |     corr = n * prod - sum_true * sum_pred
39 | 
40 |     corr /= torch.sqrt(n * sum2_true - sum_true * sum_true + torch.finfo(torch.float32).eps)
41 |     corr /= torch.sqrt(torch.clamp((n * sum2_pred - sum_pred * sum_pred + torch.finfo(torch.float32).eps), min=0.000001))
42 | 
43 |     return 1 - corr
44 | 
45 | ### Reference: https://github.com/technicolor-research/sodeep/
46 | class SpearmanLoss(nn.Module):
47 |     """ Loss function inspired by Spearman correlation.
48 |     Requires the trained model to have a good initialization.
49 | 
50 |     Set beta to 1 for a few epochs to help with the initialization.
51 |     """
52 |     def __init__(self, beta=0.0):
53 |         super(SpearmanLoss, self).__init__()
54 | 
55 |         self.criterion_rank = nn.MSELoss()
56 |         self.criterion_score = nn.L1Loss()
57 |         # self.criterion_score = nn.MSELoss()
58 | 
59 |         self.beta = beta
60 | 
61 |     def get_rank(self, batch_score, dim=0):
62 |         rank = torch.argsort(batch_score, dim=dim)
63 |         rank = torch.argsort(rank, dim=dim)
64 |         rank = (rank * -1) + batch_score.size(dim)
65 |         rank = rank.float()
66 |         rank = rank / batch_score.size(dim)
67 |         return rank
68 | 
69 |     def comp(self, inp):
70 |         in_mat1 = torch.triu(inp.repeat(inp.size(0), 1), diagonal=1)
71 |         in_mat2 = torch.triu(inp.repeat(inp.size(0), 1).t(), diagonal=1)
72 | 
73 |         comp_first = (in_mat1 - in_mat2)
74 |         comp_second = (in_mat2 - in_mat1)
75 | 
76 |         std1 = torch.std(comp_first).item()
77 |         std2 = torch.std(comp_second).item()
78 | 
79 |         std1 = torch.finfo(torch.float32).eps if np.isnan(std1) else std1
80 |         std2 = torch.finfo(torch.float32).eps if np.isnan(std2) else std2
81 | 
82 |         comp_first = torch.sigmoid(comp_first * (6.8 / (std1 + torch.finfo(torch.float32).eps)))
83 |         comp_second = torch.sigmoid(comp_second * (6.8 / (std2 + torch.finfo(torch.float32).eps)))
84 | 
85 |         comp_first = torch.triu(comp_first, diagonal=1)
86 |         comp_second = torch.triu(comp_second, diagonal=1)
87 | 
88 |         return (torch.sum(comp_first, 1) + torch.sum(comp_second, 0) + 1) / inp.size(0)
89 | 
90 |     def sort(self, input_):
91 |         out = [self.comp(input_[d]) for d in range(input_.size(0))]
92 |         out = torch.stack(out)
93 | 
94 |         return out.view(input_.size(0), -1)
95 | 
96 |     def forward(self, mem_pred, mem_gt):
97 |         rank_gt = self.get_rank(mem_gt)
98 | 
99 |         rank_pred = self.sort(mem_pred.unsqueeze(0)).view(-1)
100 | 
101 |         return self.criterion_rank(rank_pred, rank_gt) + self.beta * self.criterion_score(mem_pred, mem_gt)
102 | 
103 | CRITERION = {
104 |     'mae': nn.L1Loss(),
105 |     'mse': nn.MSELoss(),
106 |     'pearson': pearson_correlation_loss,
107 |     'spearman': SpearmanLoss(),
108 | }
109 | 
110 | def get_criterion(name):
111 |     if name in CRITERION:
112 |         return CRITERION[name]
113 |     else:
114 |         return nn.L1Loss()
115 | 
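116 | # Usage sketch: each loss maps two 1-D score tensors to a scalar; unknown names in
117 | # get_criterion fall back to L1. The toy tensors only demonstrate the call
118 | # signature, since SpearmanLoss assumes reasonably initialized model outputs.
119 | def _example_losses():
120 |     pred, target = torch.randn(32), torch.randn(32)
121 |     return get_criterion('mae')(pred, target), get_criterion('spearman')(pred, target)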
--------------------------------------------------------------------------------
/mu-search/src/flexs/flexs/evaluate.py:
--------------------------------------------------------------------------------
1 | """A small set of evaluation metrics to benchmark explorers."""
2 | from typing import Callable, List, Tuple
3 | 
4 | import flexs
5 | from flexs import baselines
6 | 
7 | 
8 | def robustness(
9 |     landscape: flexs.Landscape,
10 |     make_explorer: Callable[[flexs.Model, float], flexs.Explorer],
11 |     signal_strengths: List[float] = [0, 0.5, 0.75, 0.9, 1],
12 |     verbose: bool = True,
13 | ):
14 |     """
15 |     Evaluate explorer outputs as a function of the noisiness of its model.
16 | 
17 |     It runs the same explorer with `flexs.NoisyAbstractModel`'s of different
18 |     signal strengths.
19 | 
20 |     Args:
21 |         landscape: The landscape to run on.
22 |         make_explorer: A function that takes in a model and signal strength
23 |             (for potential bookkeeping/logging purposes) and returns an explorer.
24 |         signal_strengths: A list of signal strengths between 0 and 1.
25 | 
26 |     """
27 |     results = []
28 |     for ss in signal_strengths:
29 |         print(f"Evaluating for robustness with model accuracy; signal_strength: {ss}")
30 | 
31 |         model = baselines.models.NoisyAbstractModel(landscape, signal_strength=ss)
32 |         explorer = make_explorer(model, ss)
33 |         res = explorer.run(landscape, verbose=verbose)
34 | 
35 |         results.append((ss, res))
36 | 
37 |     return results
38 | 
39 | 
40 | def efficiency(
41 |     landscape: flexs.Landscape,
42 |     make_explorer: Callable[[int, int], flexs.Explorer],
43 |     budgets: List[Tuple[int, int]] = [
44 |         (100, 500),
45 |         (100, 5000),
46 |         (1000, 5000),
47 |         (1000, 10000),
48 |     ],
49 | ):
50 |     """
51 |     Evaluate explorer outputs as a function of the number of allowed ground truth
52 |     measurements and model queries per round.
53 | 
54 |     Args:
55 |         landscape: Ground truth fitness landscape.
56 |         make_explorer: A function that takes in a `sequences_batch_size` and
57 |             a `model_queries_per_batch` and returns an explorer.
58 |         budgets: A list of tuples (`sequences_batch_size`, `model_queries_per_batch`).
59 | 
60 |     """
61 |     results = []
62 |     for sequences_batch_size, model_queries_per_batch in budgets:
63 |         print(
64 |             f"Evaluating for sequences_batch_size: {sequences_batch_size}, "
65 |             f"model_queries_per_batch: {model_queries_per_batch}"
66 |         )
67 |         explorer = make_explorer(sequences_batch_size, model_queries_per_batch)
68 |         print("-----------------")
69 |         print(explorer)
70 |         print(explorer.model)
71 |         print("-----------------")
72 |         res = explorer.run(
73 |             landscape
74 |         )  # TODO: is this being logged? bc the last budget pair would take very long
75 | 
76 |         results.append(((sequences_batch_size, model_queries_per_batch), res))
77 | 
78 |     return results
79 | 
80 | 
81 | def adaptivity(
82 |     landscape: flexs.Landscape,
83 |     make_explorer: Callable[[int, int, int], flexs.Explorer],
84 |     num_rounds: List[int] = [1, 10, 100],
85 |     total_ground_truth_measurements: int = 1000,
86 |     total_model_queries: int = 10000,
87 | ):
88 |     """
89 |     For a fixed total budget of ground truth measurements and model queries,
90 |     run with different numbers of rounds.
91 | 
92 |     Args:
93 |         landscape: Ground truth fitness landscape.
94 |         make_explorer: A function that takes in a number of rounds, a
95 |             `sequences_batch_size` and a `model_queries_per_batch` and returns an
96 |             explorer.
97 |         num_rounds: A list of number of rounds to run the explorer with.
98 |         total_ground_truth_measurements: Total number of ground truth measurements
99 |             across all rounds (`sequences_batch_size * rounds`).
100 |         total_model_queries: Total number of model queries across all rounds
101 |             (`model_queries_per_round * rounds`).
102 | 
103 |     """
104 |     results = []
105 |     for rounds in num_rounds:
106 |         print(f"Evaluating for num_rounds: {rounds}")
107 |         explorer = make_explorer(
108 |             rounds,
109 |             int(total_ground_truth_measurements / rounds),
110 |             int(total_model_queries / rounds),
111 |         )
112 |         res = explorer.run(landscape)
113 | 
114 |         results.append((rounds, res))
115 | 
116 |     return results
117 | 
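118 | # Usage sketch: each metric pairs a ground-truth landscape with an explorer factory
119 | # and returns a list of (setting, run output) tuples. Both arguments are assumed
120 | # to be built elsewhere with the flexs package.
121 | def _example_efficiency_sweep(landscape, make_explorer):
122 |     return efficiency(landscape, make_explorer, budgets=[(100, 500), (1000, 5000)])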
--------------------------------------------------------------------------------
/pmlm/src/protein/models/layer.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from fairseq import utils
6 | from fairseq.modules import LayerNorm
7 | from fairseq.modules.fairseq_dropout import FairseqDropout
8 | from fairseq.modules.quant_noise import quant_noise
9 | 
10 | from .multihead_attention import MultiheadAttention
11 | 
12 | class TransformerProteinEncoderLayer(nn.Module):
13 |     """Post-LN Transformer encoder layer for protein sequence encoders."""
14 |     def __init__(
15 |         self,
16 |         embedding_dim: int = 768,
17 |         ffn_embedding_dim: int = 3072,
18 |         num_attention_heads: int = 8,
19 |         dropout: float = 0.1,
20 |         attention_dropout: float = 0.1,
21 |         activation_dropout: float = 0.1,
22 |         activation_fn: str = "relu",
23 |         export: bool = False,
24 |         q_noise: float = 0.0,
25 |         qn_block_size: int = 8,
26 |         init_fn: Optional[Callable] = None,
27 |     ) -> None:
28 |         super().__init__()
29 | 
30 |         if init_fn is not None:
31 |             init_fn()
32 | 
33 |         # Initialize parameters
34 |         self.embedding_dim = embedding_dim
35 |         self.dropout_module = FairseqDropout(
36 |             dropout, module_name=self.__class__.__name__
37 |         )
38 |         self.activation_dropout_module = FairseqDropout(
39 |             activation_dropout, module_name=self.__class__.__name__
40 |         )
41 | 
42 |         # Initialize blocks
43 |         self.activation_fn = utils.get_activation_fn(activation_fn)
44 |         self.self_attn = self.build_self_attention(
45 |             self.embedding_dim,
46 |             num_attention_heads,
47 |             dropout=attention_dropout,
48 |             self_attention=True,
49 |             q_noise=q_noise,
50 |             qn_block_size=qn_block_size,
51 |         )
52 | 
53 |         # layer norm associated with the self attention layer
54 |         self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export)
55 | 
56 |         self.fc1 = self.build_fc1(
57 |             self.embedding_dim,
58 |             ffn_embedding_dim,
59 |             q_noise=q_noise,
60 |             qn_block_size=qn_block_size,
61 |         )
62 |         self.fc2 = self.build_fc2(
63 |             ffn_embedding_dim,
64 |             self.embedding_dim,
65 |             q_noise=q_noise,
66 |             qn_block_size=qn_block_size,
67 |         )
68 | 
69 |         # layer norm associated with the position wise feed-forward NN
70 |         self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)
71 | 
72 |     def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
73 |         return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
74 | 
75 |     def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
76 |         return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
77 | 
78 |     def build_self_attention(
79 |         self,
80 |         embed_dim,
81 |         num_attention_heads,
82 |         dropout,
83 |         self_attention,
84 |         q_noise,
85 |         qn_block_size,
86 |     ):
87 |         return MultiheadAttention(
88 |             embed_dim,
89 |             num_attention_heads,
90 |             dropout=dropout,
91 |             self_attention=self_attention,
92 |             q_noise=q_noise,
93 |             qn_block_size=qn_block_size,
94 |         )
95 | 
96 |     def forward(
97 |         self,
98 |         x: torch.Tensor,
99 |         self_attn_mask: Optional[torch.Tensor] = None,
100 |         self_attn_padding_mask: Optional[torch.Tensor] = None,
101 |     ):
102 |         """
103 |         LayerNorm is applied after the self-attention and feed-forward blocks
104 |         (post-LN), following the original Transformer implementation.
105 |         """
106 |         residual = x
107 |         x, attn = self.self_attn(
108 |             query=x,
109 |             key=x,
110 |             value=x,
111 |             key_padding_mask=self_attn_padding_mask,
112 |             need_weights=True,
113 |             attn_mask=self_attn_mask,
114 |             need_head_weights=True,
115 |         )
116 |         x = self.dropout_module(x)
117 |         x = residual + x
118 |         x = self.self_attn_layer_norm(x)
119 | 
120 |         residual = x
121 |         x = self.activation_fn(self.fc1(x))
122 |         x = self.activation_dropout_module(x)
123 |         x = self.fc2(x)
124 |         x = self.dropout_module(x)
125 |         x = residual + x
126 |         x = self.final_layer_norm(x)
127 | 
128 |         return x, attn  # attn: (n_heads, batch_size, seq_length, seq_length)
129 | 
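130 | # Usage sketch: fairseq-style layers take (seq_len, batch, embed_dim) inputs; the
131 | # shapes below are illustrative, and fairseq must be importable for this module.
132 | def _example_layer_forward():
133 |     layer = TransformerProteinEncoderLayer(embedding_dim=768, num_attention_heads=8)
134 |     x = torch.randn(16, 2, 768)  # (seq_len, batch, embed_dim)
135 |     out, attn = layer(x)
136 |     return out.shape, attn.shape  # attn: (n_heads, batch, seq_len, seq_len)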
--------------------------------------------------------------------------------
/mu-search/src/flexs/flexs/baselines/models/basecnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class MaskedConv1d(nn.Conv1d):
6 |     """ A masked 1-dimensional convolution layer.
7 | 
8 |     Takes the same arguments as torch.nn.Conv1d, except that the padding is set automatically.
9 | 
10 |     Shape:
11 |         Input: (N, L, in_channels)
12 |         input_mask: (N, L, 1), optional
13 |         Output: (N, L, out_channels)
14 |     """
15 | 
16 |     def __init__(self, in_channels: int, out_channels: int,
17 |                  kernel_size: int, stride: int = 1, dilation: int = 1, groups: int = 1,
18 |                  bias: bool = True):
19 |         """
20 |         :param in_channels: input channels
21 |         :param out_channels: output channels
22 |         :param kernel_size: the kernel width
23 |         :param stride: filter shift
24 |         :param dilation: dilation factor
25 |         :param groups: perform depth-wise convolutions
26 |         :param bias: adds learnable bias to output
27 |         """
28 |         padding = dilation * (kernel_size - 1) // 2
29 |         super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
30 |                          groups=groups, bias=bias, padding=padding)
31 | 
32 |     def forward(self, x, input_mask=None):
33 |         if input_mask is not None:
34 |             x = x * input_mask
35 |         return super().forward(x.transpose(1, 2)).transpose(1, 2)
36 | 
37 | 
38 | class LengthMaxPool1D(nn.Module):
39 |     def __init__(self, in_dim, out_dim, linear=False, activation='relu'):
40 |         super().__init__()
41 |         self.linear = linear
42 |         if self.linear:
43 |             self.layer = nn.Linear(in_dim, out_dim)
44 | 
45 |         if activation == 'swish':
46 |             # note: the steep gate (beta=100) makes this closer to ReLU than standard SiLU
47 |             self.act_fn = lambda x: x * torch.sigmoid(100.0 * x)
48 |         elif activation == 'softplus':
49 |             self.act_fn = nn.Softplus()
50 |         elif activation == 'sigmoid':
51 |             self.act_fn = nn.Sigmoid()
52 |         elif activation == 'leakyrelu':
53 |             self.act_fn = nn.LeakyReLU()
54 |         elif activation == 'relu':
55 |             self.act_fn = F.relu
56 |         else:
57 |             raise NotImplementedError
58 | 
59 |     def forward(self, x):
60 |         if self.linear:
61 |             x = self.act_fn(self.layer(x))
62 |         x = torch.max(x, dim=1)[0]
63 |         return x
64 | 
65 | 
66 | class BaseCNN(nn.Module):
67 |     def __init__(
68 |         self,
69 |         n_tokens: int = 20,
70 |         kernel_size: int = 5,
71 |         input_size: int = 256,
72 |         dropout: float = 0.0,
73 |         make_one_hot=True,
74 |         activation: str = 'relu',
75 |         linear: bool = True,
76 |         **kwargs):
77 |         super(BaseCNN, self).__init__()
78 |         self.encoder = nn.Conv1d(n_tokens, input_size, kernel_size=kernel_size)
79 |         self.embedding = LengthMaxPool1D(
80 |             linear=linear,
81 |             in_dim=input_size,
82 |             out_dim=input_size*2,
83 |             activation=activation,
84 |         )
85 |         self.decoder = nn.Linear(input_size*2, 1)
86 |         self.n_tokens = n_tokens
87 |         self.dropout = nn.Dropout(dropout)  # applied after the encoder in forward()
88 |         self.input_size = input_size
89 |         self._make_one_hot = make_one_hot
90 | 
91 |     def forward(self, x):
92 |         # one-hot encode integer token indices
93 |         if self._make_one_hot:
94 |             x = F.one_hot(x.long(), num_classes=self.n_tokens)
95 |         x = x.permute(0, 2, 1).float()
96 |         # encoder
97 |         x = self.encoder(x).permute(0, 2, 1)
98 |         x = self.dropout(x)
99 |         # embed
100 |         x = self.embedding(x)
101 |         # decoder
102 |         output = self.decoder(x).squeeze(1)
103 |         return output
104 | 
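105 | # Usage sketch: score a batch of integer-encoded sequences with BaseCNN. Token
106 | # values index a 20-letter alphabet; one-hot encoding happens inside forward().
107 | def _example_basecnn_scores():
108 |     model = BaseCNN(n_tokens=20, kernel_size=5, input_size=256)
109 |     tokens = torch.randint(0, 20, (8, 50))  # 8 sequences of length 50
110 |     return model(tokens)  # shape: (8,)
111 | 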
112 | # A dense MLP with n_layers hidden layers that takes in a binary vector and outputs a scalar
113 | class ToyMLP(nn.Module):
114 |     def __init__(self, seq_len, n_layers=1, hidden_size=256, **kwargs):
115 |         super(ToyMLP, self).__init__()
116 |         self.layers = nn.ModuleList()
117 |         self.layers.append(nn.Linear(seq_len, hidden_size))
118 |         for _ in range(n_layers - 1):
119 |             self.layers.append(nn.Linear(hidden_size, hidden_size))
120 |         self.layers.append(nn.Linear(hidden_size, 1))
121 |         self.act_fn = nn.ReLU()
122 | 
123 |     def forward(self, x):
124 |         for layer in self.layers[:-1]:
125 |             x = self.act_fn(layer(x))
126 |         x = self.layers[-1](x)
127 |         return x
128 | 
129 | 
130 | 
--------------------------------------------------------------------------------
/mu-search/src/flexs/flexs/utils/sequence_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for manipulating sequences."""
2 | import random
3 | from typing import List, Union
4 | 
5 | import numpy as np
6 | from copy import deepcopy
7 | 
8 | AAS = "ILVAGMFYWEDQNHCRKSTP"
9 | """str: Amino acid alphabet for proteins (length 20 - no stop codon)."""
10 | 
11 | RNAA = "UGCA"
12 | """str: RNA alphabet (4 bases)."""
13 | 
14 | DNAA = "TGCA"
15 | """str: DNA alphabet (4 bases)."""
16 | 
17 | BA = "01"
18 | """str: Binary alphabet '01'."""
19 | 
20 | 
21 | def construct_mutant_from_sample(
22 |     pwm_sample: np.ndarray, one_hot_base: np.ndarray
23 | ) -> np.ndarray:
24 |     """Return one hot mutant, a utility function for some explorers."""
25 |     one_hot = np.zeros(one_hot_base.shape)
26 |     one_hot += one_hot_base
27 |     i, j = np.nonzero(pwm_sample)  # this can be problematic for non-positive fitnesses
28 |     one_hot[i, :] = 0
29 |     one_hot[i, j] = 1
30 |     return one_hot
31 | 
32 | 
33 | def string_to_one_hot(sequence: str, alphabet: str) -> np.ndarray:
34 |     """
35 |     Return the one-hot representation of a sequence string according to an alphabet.
36 | 
37 |     Args:
38 |         sequence: Sequence string to convert to one_hot representation.
39 |         alphabet: Alphabet string (assigns each character an index).
40 | 
41 |     Returns:
42 |         One-hot numpy array of shape `(len(sequence), len(alphabet))`.
43 | 
44 |     """
45 |     out = np.zeros((len(sequence), len(alphabet)))
46 |     for i in range(len(sequence)):
47 |         # mark the index of this residue in the alphabet
48 |         out[i, alphabet.index(sequence[i])] = 1
49 |     return out
50 | 
51 | 
52 | def one_hot_to_string(
53 |     one_hot: Union[List[List[int]], np.ndarray], alphabet: str
54 | ) -> str:
55 |     """
56 |     Return the sequence string representing a one-hot vector according to an alphabet.
57 | 
58 |     Args:
59 |         one_hot: One-hot of shape `(len(sequence), len(alphabet))` representing
60 |             a sequence.
61 |         alphabet: Alphabet string (assigns each character an index).
62 | 
63 |     Returns:
64 |         Sequence string representation of `one_hot`.
65 | 66 | """ 67 | # print(f"sequence_utils.py 66 one_hot: {one_hot}") 68 | residue_idxs = np.argmax(one_hot, axis=1) 69 | # print(f"sequence_utils.py 67 residue_idxs: {residue_idxs}") 70 | return "".join([alphabet[idx] for idx in residue_idxs]) 71 | 72 | 73 | def generate_single_mutants(wt: str, alphabet: str) -> List[str]: 74 | """Generate all single mutants of `wt`.""" 75 | sequences = [wt] 76 | for i in range(len(wt)): 77 | tmp = list(wt) 78 | for j in range(len(alphabet)): 79 | tmp[i] = alphabet[j] 80 | sequences.append("".join(tmp)) 81 | return sequences 82 | 83 | 84 | def generate_random_n_points_mutants(sequence: str, n: int, alphabet: str) -> List[str]: 85 | """ 86 | First select n random points in the sequence, then randomly mutate each point. Return the mutant sequence which has n points mutated. 87 | 88 | Args: 89 | sequence: Sequence to mutate. 90 | n: Number of points to mutate. 91 | alphabet: Alphabet string. 92 | 93 | Returns: 94 | n points mutated sequence. 95 | 96 | """ 97 | sequences_to_mutate = list(deepcopy(sequence)) 98 | # select n points 99 | points = random.sample(range(len(sequence)), n) 100 | # mutate each point 101 | for i in points: 102 | sequences_to_mutate[i] = random.choice(alphabet) 103 | sequences_to_mutate = "".join(sequences_to_mutate) 104 | return sequences_to_mutate 105 | 106 | 107 | def generate_random_sequences(length: int, number: int, alphabet: str) -> List[str]: 108 | """Generate random sequences of particular length.""" 109 | return [ 110 | "".join([random.choice(alphabet) for _ in range(length)]) for _ in range(number) 111 | ] 112 | 113 | def generate_random_mutant(sequence: str, mu: float, alphabet: str) -> str: 114 | """ 115 | Generate a mutant of `sequence` where each residue mutates with probability `mu`. 116 | 117 | So the expected value of the total number of mutations is `len(sequence) * mu`. 118 | 119 | Args: 120 | sequence: Sequence that will be mutated from. 121 | mu: Probability of mutation per residue. 122 | alphabet: Alphabet string. 123 | 124 | Returns: 125 | Mutant sequence string. 126 | 127 | """ 128 | mutant = [] 129 | for s in sequence: 130 | if random.random() < mu: 131 | mutant.append(random.choice(alphabet)) 132 | else: 133 | mutant.append(s) 134 | return "".join(mutant) 135 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/vec_env/vec_transpose.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Dict, Union 3 | 4 | import numpy as np 5 | from gym import spaces 6 | 7 | from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first 8 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper 9 | 10 | 11 | class VecTransposeImage(VecEnvWrapper): 12 | """ 13 | Re-order channels, from HxWxC to CxHxW. 14 | It is required for PyTorch convolution layers. 15 | 16 | :param venv: 17 | :param skip: Skip this wrapper if needed as we rely on heuristic to apply it or not, 18 | which may result in unwanted behavior, see GH issue #671. 
19 | """ 20 | 21 | def __init__(self, venv: VecEnv, skip: bool = False): 22 | assert is_image_space(venv.observation_space) or isinstance( 23 | venv.observation_space, spaces.dict.Dict 24 | ), "The observation space must be an image or dictionary observation space" 25 | 26 | self.skip = skip 27 | # Do nothing 28 | if skip: 29 | super().__init__(venv) 30 | return 31 | 32 | if isinstance(venv.observation_space, spaces.dict.Dict): 33 | self.image_space_keys = [] 34 | observation_space = deepcopy(venv.observation_space) 35 | for key, space in observation_space.spaces.items(): 36 | if is_image_space(space): 37 | # Keep track of which keys should be transposed later 38 | self.image_space_keys.append(key) 39 | observation_space.spaces[key] = self.transpose_space(space, key) 40 | else: 41 | observation_space = self.transpose_space(venv.observation_space) 42 | super().__init__(venv, observation_space=observation_space) 43 | 44 | @staticmethod 45 | def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box: 46 | """ 47 | Transpose an observation space (re-order channels). 48 | 49 | :param observation_space: 50 | :param key: In case of dictionary space, the key of the observation space. 51 | :return: 52 | """ 53 | # Sanity checks 54 | assert is_image_space(observation_space), "The observation space must be an image" 55 | assert not is_image_space_channels_first( 56 | observation_space 57 | ), f"The observation space {key} must follow the channel last convention" 58 | height, width, channels = observation_space.shape 59 | new_shape = (channels, height, width) 60 | return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype) 61 | 62 | @staticmethod 63 | def transpose_image(image: np.ndarray) -> np.ndarray: 64 | """ 65 | Transpose an image or batch of images (re-order channels). 66 | 67 | :param image: 68 | :return: 69 | """ 70 | if len(image.shape) == 3: 71 | return np.transpose(image, (2, 0, 1)) 72 | return np.transpose(image, (0, 3, 1, 2)) 73 | 74 | def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]: 75 | """ 76 | Transpose (if needed) and return new observations. 
77 | 78 | :param observations: 79 | :return: Transposed observations 80 | """ 81 | # Do nothing 82 | if self.skip: 83 | return observations 84 | 85 | if isinstance(observations, dict): 86 | # Avoid modifying the original object in place 87 | observations = deepcopy(observations) 88 | for k in self.image_space_keys: 89 | observations[k] = self.transpose_image(observations[k]) 90 | else: 91 | observations = self.transpose_image(observations) 92 | return observations 93 | 94 | def step_wait(self) -> VecEnvStepReturn: 95 | observations, rewards, dones, infos = self.venv.step_wait() 96 | 97 | # Transpose the terminal observations 98 | for idx, done in enumerate(dones): 99 | if not done: 100 | continue 101 | if "terminal_observation" in infos[idx]: 102 | infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"]) 103 | 104 | return self.transpose_observations(observations), rewards, dones, infos 105 | 106 | def reset(self) -> Union[np.ndarray, Dict]: 107 | """ 108 | Reset all environments 109 | """ 110 | return self.transpose_observations(self.venv.reset()) 111 | 112 | def close(self) -> None: 113 | self.venv.close() 114 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/baselines/explorers/cmaes.py: -------------------------------------------------------------------------------- 1 | """CMAES explorer.""" 2 | from typing import Optional, Tuple 3 | 4 | import cma 5 | import numpy as np 6 | import pandas as pd 7 | 8 | import flexs 9 | from flexs.utils import sequence_utils as s_utils 10 | 11 | 12 | class CMAES(flexs.Explorer): 13 | """ 14 | An explorer which implements the covariance matrix adaptation evolution 15 | strategy (CMAES). 16 | 17 | Optimizes a continuous relaxation of the one-hot sequence that we use to 18 | construct a normal distribution around, sample from, and then argmax to get 19 | sequences for the objective function. 20 | 21 | http://blog.otoro.net/2017/10/29/visual-evolution-strategies/ is a helpful guide. 22 | """ 23 | 24 | def __init__( 25 | self, 26 | model: flexs.Model, 27 | rounds: int, 28 | sequences_batch_size: int, 29 | model_queries_per_batch: int, 30 | starting_sequence: str, 31 | alphabet: str, 32 | population_size: int = 15, 33 | max_iter: int = 400, 34 | initial_variance: float = 0.2, 35 | log_file: Optional[str] = None, 36 | ): 37 | """ 38 | Args: 39 | population_size: Number of proposed solutions per iteration. 40 | max_iter: Maximum number of iterations. 41 | initial_variance: Initial variance passed into cma. 
42 | """ 43 | name = f"CMAES_popsize{population_size}" 44 | 45 | super().__init__( 46 | model, 47 | name, 48 | rounds, 49 | sequences_batch_size, 50 | model_queries_per_batch, 51 | starting_sequence, 52 | log_file, 53 | ) 54 | 55 | self.alphabet = alphabet 56 | self.population_size = population_size 57 | self.max_iter = max_iter 58 | self.initial_variance = initial_variance 59 | self.round = 0 60 | 61 | def _soln_to_string(self, soln): 62 | x = soln.reshape((len(self.starting_sequence), len(self.alphabet))) 63 | 64 | one_hot = np.zeros(x.shape) 65 | one_hot[np.arange(len(one_hot)), np.argmax(x, axis=1)] = 1 66 | 67 | return s_utils.one_hot_to_string(one_hot, self.alphabet) 68 | 69 | def propose_sequences( 70 | self, measured_sequences: pd.DataFrame 71 | ) -> Tuple[np.ndarray, np.ndarray]: 72 | """Propose top `sequences_batch_size` sequences for evaluation.""" 73 | measured_sequence_dict = dict( 74 | zip(measured_sequences["sequence"], measured_sequences["true_score"]) 75 | ) 76 | 77 | # Keep track of new sequences generated this round 78 | top_idx = measured_sequences["true_score"].argmax() 79 | top_seq = measured_sequences["sequence"].to_numpy()[top_idx] 80 | top_val = measured_sequences["true_score"].to_numpy()[top_idx] 81 | sequences = {top_seq: top_val} 82 | 83 | def objective_function(soln): 84 | seq = self._soln_to_string(soln) 85 | 86 | if seq in sequences: 87 | return sequences[seq] 88 | if seq in measured_sequence_dict: 89 | return measured_sequence_dict[seq] 90 | 91 | return self.model.get_fitness([seq]).item() 92 | 93 | # Starting solution gives equal weight to all residues at all positions 94 | x0 = s_utils.string_to_one_hot(top_seq, self.alphabet).flatten() 95 | opts = {"popsize": self.population_size, "verbose": -9, "verb_log": 0} 96 | es = cma.CMAEvolutionStrategy(x0, np.sqrt(self.initial_variance), opts) 97 | 98 | # Explore until we reach `self.max_iter` or run out of model queries 99 | initial_cost = self.model.cost 100 | for _ in range(self.max_iter): 101 | 102 | # Stop exploring if we will run out of model queries 103 | current_cost = self.model.cost - initial_cost 104 | if current_cost + self.population_size > self.model_queries_per_batch: 105 | break 106 | 107 | # `ask_and_eval` generates a new population of sequences 108 | solutions, fitnesses = es.ask_and_eval(objective_function) 109 | # `tell` updates model parameters 110 | es.tell(solutions, fitnesses) 111 | 112 | # Store scores of generated sequences 113 | for soln, f in zip(solutions, fitnesses): 114 | sequences[self._soln_to_string(soln)] = f 115 | 116 | # We propose the top `self.sequences_batch_size` new sequences we have generated 117 | new_seqs = np.array(list(sequences.keys())) 118 | # Negate `objective_function` scores 119 | preds = np.array(list(sequences.values())) 120 | sorted_order = np.argsort(preds)[: -self.sequences_batch_size : -1] 121 | 122 | return new_seqs[sorted_order], preds[sorted_order] 123 | -------------------------------------------------------------------------------- /mu-search/src/flexs/flexs/landscapes/src/protein/models/encoder/transformer_sentence_encoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from typing import Callable, Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | from fairseq import utils 11 | from fairseq.modules import LayerNorm 12 | from fairseq.modules.fairseq_dropout import FairseqDropout 13 | from fairseq.modules.quant_noise import quant_noise 14 | from .multihead_attention import MultiheadAttention 15 | 16 | class TransformerSentenceEncoderLayer(nn.Module): 17 | """ 18 | Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained 19 | models. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | embedding_dim: int = 768, 25 | ffn_embedding_dim: int = 3072, 26 | num_attention_heads: int = 8, 27 | dropout: float = 0.1, 28 | attention_dropout: float = 0.1, 29 | activation_dropout: float = 0.1, 30 | activation_fn: str = "relu", 31 | export: bool = False, 32 | q_noise: float = 0.0, 33 | qn_block_size: int = 8, 34 | init_fn: Callable = None, 35 | ) -> None: 36 | super().__init__() 37 | 38 | if init_fn is not None: 39 | init_fn() 40 | 41 | # Initialize parameters 42 | self.embedding_dim = embedding_dim 43 | self.dropout_module = FairseqDropout( 44 | dropout, module_name=self.__class__.__name__ 45 | ) 46 | self.activation_dropout_module = FairseqDropout( 47 | activation_dropout, module_name=self.__class__.__name__ 48 | ) 49 | 50 | # Initialize blocks 51 | self.activation_fn = utils.get_activation_fn(activation_fn) 52 | self.self_attn = self.build_self_attention( 53 | self.embedding_dim, 54 | num_attention_heads, 55 | dropout=attention_dropout, 56 | self_attention=True, 57 | q_noise=q_noise, 58 | qn_block_size=qn_block_size, 59 | ) 60 | 61 | # layer norm associated with the self attention layer 62 | self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export) 63 | 64 | self.fc1 = self.build_fc1( 65 | self.embedding_dim, 66 | ffn_embedding_dim, 67 | q_noise=q_noise, 68 | qn_block_size=qn_block_size, 69 | ) 70 | self.fc2 = self.build_fc2( 71 | ffn_embedding_dim, 72 | self.embedding_dim, 73 | q_noise=q_noise, 74 | qn_block_size=qn_block_size, 75 | ) 76 | 77 | # layer norm associated with the position wise feed-forward NN 78 | self.final_layer_norm = LayerNorm(self.embedding_dim, export=export) 79 | 80 | def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): 81 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) 82 | 83 | def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): 84 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) 85 | 86 | def build_self_attention( 87 | self, 88 | embed_dim, 89 | num_attention_heads, 90 | dropout, 91 | self_attention, 92 | q_noise, 93 | qn_block_size, 94 | ): 95 | return MultiheadAttention( 96 | embed_dim, 97 | num_attention_heads, 98 | dropout=dropout, 99 | self_attention=True, 100 | q_noise=q_noise, 101 | qn_block_size=qn_block_size, 102 | ) 103 | 104 | def forward( 105 | self, 106 | x: torch.Tensor, 107 | self_attn_mask: Optional[torch.Tensor] = None, 108 | self_attn_padding_mask: Optional[torch.Tensor] = None, 109 | attn_bias = None, 110 | ): 111 | """ 112 | LayerNorm is applied either before or after the self-attention/ffn 113 | modules similar to the original Transformer implementation. 
114 |         """
115 |         residual = x
116 | 
117 |         x, attn = self.self_attn(
118 |             query=x,
119 |             key=x,
120 |             value=x,
121 |             key_padding_mask=self_attn_padding_mask,
122 |             need_weights=True,
123 |             attn_mask=self_attn_mask,
124 |             attn_bias=attn_bias,
125 |         )
126 |         x = self.dropout_module(x)
127 |         x = residual + x
128 |         x = self.self_attn_layer_norm(x)
129 | 
130 |         residual = x
131 |         x = self.activation_fn(self.fc1(x))
132 |         x = self.activation_dropout_module(x)
133 |         x = self.fc2(x)
134 |         x = self.dropout_module(x)
135 |         x = residual + x
136 |         x = self.final_layer_norm(x)
137 |         return x, attn
138 | 
--------------------------------------------------------------------------------
/mu-search/src/flexs/flexs/baselines/explorers/stable_baselines3/stable_baselines3/common/results_plotter.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Tuple
2 | 
3 | import numpy as np
4 | import pandas as pd
5 | 
6 | # import matplotlib
7 | # matplotlib.use('TkAgg')  # Can change to 'Agg' for non-interactive mode
8 | from matplotlib import pyplot as plt
9 | 
10 | from stable_baselines3.common.monitor import load_results
11 | 
12 | X_TIMESTEPS = "timesteps"
13 | X_EPISODES = "episodes"
14 | X_WALLTIME = "walltime_hrs"
15 | POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]
16 | EPISODES_WINDOW = 100
17 | 
18 | 
19 | def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
20 |     """
21 |     Apply a rolling window to a np.ndarray
22 | 
23 |     :param array: the input Array
24 |     :param window: length of the rolling window
25 |     :return: rolling window on the input array
26 |     """
27 |     shape = array.shape[:-1] + (array.shape[-1] - window + 1, window)
28 |     strides = array.strides + (array.strides[-1],)
29 |     return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides)
30 | 
31 | 
32 | def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]:
33 |     """
34 |     Apply a function to the rolling window of 2 arrays
35 | 
36 |     :param var_1: variable 1
37 |     :param var_2: variable 2
38 |     :param window: length of the rolling window
39 |     :param func: function to apply on the rolling window on variable 2 (such as np.mean)
40 |     :return: the rolling output with applied function
41 |     """
42 |     var_2_window = rolling_window(var_2, window)
43 |     function_on_var2 = func(var_2_window, axis=-1)
44 |     return var_1[window - 1 :], function_on_var2
45 | 
46 | 
47 | def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
48 |     """
49 |     Decompose a data frame variable to x and y values
50 | 
51 |     :param data_frame: the input data
52 |     :param x_axis: the axis for the x and y output
53 |         (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
54 |     :return: the x and y output
55 |     """
56 |     if x_axis == X_TIMESTEPS:
57 |         x_var = np.cumsum(data_frame.l.values)
58 |         y_var = data_frame.r.values
59 |     elif x_axis == X_EPISODES:
60 |         x_var = np.arange(len(data_frame))
61 |         y_var = data_frame.r.values
62 |     elif x_axis == X_WALLTIME:
63 |         # Convert to hours
64 |         x_var = data_frame.t.values / 3600.0
65 |         y_var = data_frame.r.values
66 |     else:
67 |         raise NotImplementedError
68 |     return x_var, y_var
69 | 
70 | 
71 | def plot_curves(
72 |     xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2)
73 | ) -> None:
74 |     """
75 |     Plot the curves
76 | 
77 |     :param xy_list: the x and y coordinates to plot
78 |     :param x_axis: the axis for
the x and y output 79 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs') 80 | :param title: the title of the plot 81 | :param figsize: Size of the figure (width, height) 82 | """ 83 | 84 | plt.figure(title, figsize=figsize) 85 | max_x = max(xy[0][-1] for xy in xy_list) 86 | min_x = 0 87 | for (_, (x, y)) in enumerate(xy_list): 88 | plt.scatter(x, y, s=2) 89 | # Do not plot the smoothed curve at all if the timeseries is shorter than window size. 90 | if x.shape[0] >= EPISODES_WINDOW: 91 | # Compute and plot rolling mean with window of size EPISODE_WINDOW 92 | x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean) 93 | plt.plot(x, y_mean) 94 | plt.xlim(min_x, max_x) 95 | plt.title(title) 96 | plt.xlabel(x_axis) 97 | plt.ylabel("Episode Rewards") 98 | plt.tight_layout() 99 | 100 | 101 | def plot_results( 102 | dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2) 103 | ) -> None: 104 | """ 105 | Plot the results using csv files from ``Monitor`` wrapper. 106 | 107 | :param dirs: the save location of the results to plot 108 | :param num_timesteps: only plot the points below this value 109 | :param x_axis: the axis for the x and y output 110 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs') 111 | :param task_name: the title of the task to plot 112 | :param figsize: Size of the figure (width, height) 113 | """ 114 | 115 | data_frames = [] 116 | for folder in dirs: 117 | data_frame = load_results(folder) 118 | if num_timesteps is not None: 119 | data_frame = data_frame[data_frame.l.cumsum() <= num_timesteps] 120 | data_frames.append(data_frame) 121 | xy_list = [ts2xy(data_frame, x_axis) for data_frame in data_frames] 122 | plot_curves(xy_list, x_axis, task_name, figsize) 123 | --------------------------------------------------------------------------------