├── .gitignore ├── LICENSE ├── README.md ├── data ├── __init__.py ├── data_utils.py ├── digit_seq_rewards.py ├── generate_digit_data.py ├── imessage_chat_data.py ├── increasing_mult_1_test.jsonl ├── increasing_mult_1_train.jsonl ├── increasing_mult_1_valid.jsonl ├── increasing_mult_2_test.jsonl ├── increasing_mult_2_train.jsonl ├── increasing_mult_2_valid.jsonl ├── reward_function_increasing_mult_2_train.jsonl └── reward_function_increasing_mult_2_valid.jsonl ├── imessage_bot.md ├── mlx_ppo_trainer.py ├── mlx_reward.png ├── models ├── __init__.py ├── base.py ├── config.py ├── convert.py ├── fuse.py ├── llama.py ├── lora.py ├── mixtral.py └── prompt_tuning.py ├── ppo_training.py ├── pytorch_baseline ├── README.md ├── __init__.py ├── pytorch_ppo_trainer.py ├── pytorch_ppo_training.py ├── pytorch_sft.py ├── pytorch_talk_to_model.py └── utils.py ├── pytorch_reward.png ├── requirements.txt ├── sequential_digits.md ├── sft.py ├── talk_to_model.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/.gitignore -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/README.md -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/data/data_utils.py -------------------------------------------------------------------------------- /data/digit_seq_rewards.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/data/digit_seq_rewards.py -------------------------------------------------------------------------------- /data/generate_digit_data.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/data/generate_digit_data.py -------------------------------------------------------------------------------- /data/imessage_chat_data.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/data/imessage_chat_data.py -------------------------------------------------------------------------------- /data/increasing_mult_1_test.jsonl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/data/increasing_mult_1_test.jsonl -------------------------------------------------------------------------------- /data/increasing_mult_1_train.jsonl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/data/increasing_mult_1_train.jsonl -------------------------------------------------------------------------------- /data/increasing_mult_1_valid.jsonl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/data/increasing_mult_1_valid.jsonl -------------------------------------------------------------------------------- /data/increasing_mult_2_test.jsonl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/data/increasing_mult_2_test.jsonl -------------------------------------------------------------------------------- /data/increasing_mult_2_train.jsonl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/data/increasing_mult_2_train.jsonl -------------------------------------------------------------------------------- /data/increasing_mult_2_valid.jsonl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/data/increasing_mult_2_valid.jsonl -------------------------------------------------------------------------------- /data/reward_function_increasing_mult_2_train.jsonl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/data/reward_function_increasing_mult_2_train.jsonl -------------------------------------------------------------------------------- /data/reward_function_increasing_mult_2_valid.jsonl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/data/reward_function_increasing_mult_2_valid.jsonl -------------------------------------------------------------------------------- /imessage_bot.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/imessage_bot.md -------------------------------------------------------------------------------- /mlx_ppo_trainer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/mlx_ppo_trainer.py -------------------------------------------------------------------------------- /mlx_reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/mlx_reward.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/models/base.py -------------------------------------------------------------------------------- /models/config.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/models/config.py -------------------------------------------------------------------------------- /models/convert.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/models/convert.py -------------------------------------------------------------------------------- /models/fuse.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/models/fuse.py -------------------------------------------------------------------------------- /models/llama.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/models/llama.py -------------------------------------------------------------------------------- /models/lora.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/models/lora.py -------------------------------------------------------------------------------- /models/mixtral.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/models/mixtral.py -------------------------------------------------------------------------------- /models/prompt_tuning.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/models/prompt_tuning.py -------------------------------------------------------------------------------- /ppo_training.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/ppo_training.py -------------------------------------------------------------------------------- /pytorch_baseline/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/pytorch_baseline/README.md -------------------------------------------------------------------------------- /pytorch_baseline/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_baseline/pytorch_ppo_trainer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/pytorch_baseline/pytorch_ppo_trainer.py -------------------------------------------------------------------------------- /pytorch_baseline/pytorch_ppo_training.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/pytorch_baseline/pytorch_ppo_training.py -------------------------------------------------------------------------------- /pytorch_baseline/pytorch_sft.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/pytorch_baseline/pytorch_sft.py -------------------------------------------------------------------------------- /pytorch_baseline/pytorch_talk_to_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/pytorch_baseline/pytorch_talk_to_model.py -------------------------------------------------------------------------------- /pytorch_baseline/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/pytorch_baseline/utils.py -------------------------------------------------------------------------------- /pytorch_reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/pytorch_reward.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/requirements.txt -------------------------------------------------------------------------------- /sequential_digits.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/sequential_digits.md -------------------------------------------------------------------------------- /sft.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/sft.py -------------------------------------------------------------------------------- /talk_to_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/talk_to_model.py -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrew-silva/mlx-rlhf/HEAD/utils.py --------------------------------------------------------------------------------