├── .gitignore ├── README.md ├── aria ├── algorithms │ ├── __init__.py │ ├── offpolicy_train.py │ ├── onpolicy_train.py │ └── trainer │ │ ├── __init__.py │ │ ├── online_reinforce_trainer.py │ │ └── reinforce_trainer.py ├── data │ ├── __init__.py │ └── utils.py ├── environment │ ├── __init__.py │ ├── env_utils.py │ ├── guess_my_city.py │ ├── llm_base.py │ └── twenty_questions.py ├── merge_lora.py ├── models │ ├── __init__.py │ ├── online_reinforce_agent.py │ └── reinforce_agent.py └── utils.py ├── dataset ├── actor_reinforce_llama3-8b_multi.json ├── actor_reinforce_llama3-8b_single.json ├── llama3-8b_guess_embedding_msgs.jsonl └── llama3-8b_negotiation_embedding_msgs.jsonl ├── evaluation ├── eval.sh ├── guess_my_city.py ├── llm_base.py ├── main.py ├── prompt_base.py └── twenty_questions.py ├── figures ├── main.pdf └── main.png ├── main.png ├── pointwise_rm ├── __init__.py ├── pointwiserm_trainer.py ├── reward_config.py └── run_rm.py ├── requirements.txt ├── reward_aggregation ├── bargaining_clustering │ ├── clustering.py │ ├── postprocess.py │ └── preprocess.py ├── clustering_multi.py ├── clustering_single.py ├── game_data_processor.py ├── gen_reinforce_multi.py ├── gen_reinforce_single.py ├── guess_my_city_clustering │ ├── clustering.py │ ├── postprocess.py │ └── preprocess.py ├── negotiation_clustering │ ├── clustering.py │ ├── postprocess.py │ └── preprocess.py ├── prompt_base.py ├── reward_utils.py └── twenty_question_clustering │ ├── clustering.py │ ├── postprocess.py │ └── preprocess.py ├── scripts ├── config │ ├── accelerate_config │ │ └── default_config.yaml │ ├── default.yaml │ ├── ds_config.json │ ├── onlinereinforce_llm.yaml │ └── reinforce_llm.yaml ├── run_offline.py ├── run_online.py ├── train_online.sh └── train_rm.sh └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/README.md -------------------------------------------------------------------------------- /aria/algorithms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/algorithms/__init__.py -------------------------------------------------------------------------------- /aria/algorithms/offpolicy_train.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/algorithms/offpolicy_train.py -------------------------------------------------------------------------------- /aria/algorithms/onpolicy_train.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/algorithms/onpolicy_train.py -------------------------------------------------------------------------------- /aria/algorithms/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/algorithms/trainer/__init__.py -------------------------------------------------------------------------------- /aria/algorithms/trainer/online_reinforce_trainer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/algorithms/trainer/online_reinforce_trainer.py -------------------------------------------------------------------------------- /aria/algorithms/trainer/reinforce_trainer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/algorithms/trainer/reinforce_trainer.py -------------------------------------------------------------------------------- /aria/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import DummyDataset, ReplayBuffer -------------------------------------------------------------------------------- /aria/data/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/data/utils.py -------------------------------------------------------------------------------- /aria/environment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/environment/__init__.py -------------------------------------------------------------------------------- /aria/environment/env_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/environment/env_utils.py -------------------------------------------------------------------------------- /aria/environment/guess_my_city.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/environment/guess_my_city.py -------------------------------------------------------------------------------- /aria/environment/llm_base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/environment/llm_base.py -------------------------------------------------------------------------------- /aria/environment/twenty_questions.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/environment/twenty_questions.py -------------------------------------------------------------------------------- /aria/merge_lora.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/merge_lora.py -------------------------------------------------------------------------------- /aria/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/models/__init__.py -------------------------------------------------------------------------------- /aria/models/online_reinforce_agent.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/models/online_reinforce_agent.py -------------------------------------------------------------------------------- /aria/models/reinforce_agent.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/models/reinforce_agent.py -------------------------------------------------------------------------------- /aria/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/aria/utils.py -------------------------------------------------------------------------------- /dataset/actor_reinforce_llama3-8b_multi.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/dataset/actor_reinforce_llama3-8b_multi.json -------------------------------------------------------------------------------- /dataset/actor_reinforce_llama3-8b_single.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/dataset/actor_reinforce_llama3-8b_single.json -------------------------------------------------------------------------------- /dataset/llama3-8b_guess_embedding_msgs.jsonl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/dataset/llama3-8b_guess_embedding_msgs.jsonl -------------------------------------------------------------------------------- /dataset/llama3-8b_negotiation_embedding_msgs.jsonl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/dataset/llama3-8b_negotiation_embedding_msgs.jsonl -------------------------------------------------------------------------------- /evaluation/eval.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/evaluation/eval.sh -------------------------------------------------------------------------------- /evaluation/guess_my_city.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/evaluation/guess_my_city.py -------------------------------------------------------------------------------- /evaluation/llm_base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/evaluation/llm_base.py -------------------------------------------------------------------------------- /evaluation/main.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/evaluation/main.py -------------------------------------------------------------------------------- /evaluation/prompt_base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/evaluation/prompt_base.py -------------------------------------------------------------------------------- /evaluation/twenty_questions.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/evaluation/twenty_questions.py -------------------------------------------------------------------------------- /figures/main.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/figures/main.pdf -------------------------------------------------------------------------------- /figures/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/figures/main.png -------------------------------------------------------------------------------- /main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/main.png -------------------------------------------------------------------------------- /pointwise_rm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pointwise_rm/pointwiserm_trainer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/pointwise_rm/pointwiserm_trainer.py -------------------------------------------------------------------------------- /pointwise_rm/reward_config.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/pointwise_rm/reward_config.py -------------------------------------------------------------------------------- /pointwise_rm/run_rm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/pointwise_rm/run_rm.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/requirements.txt -------------------------------------------------------------------------------- /reward_aggregation/bargaining_clustering/clustering.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/bargaining_clustering/clustering.py -------------------------------------------------------------------------------- /reward_aggregation/bargaining_clustering/postprocess.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/bargaining_clustering/postprocess.py -------------------------------------------------------------------------------- /reward_aggregation/bargaining_clustering/preprocess.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/bargaining_clustering/preprocess.py -------------------------------------------------------------------------------- /reward_aggregation/clustering_multi.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/clustering_multi.py -------------------------------------------------------------------------------- /reward_aggregation/clustering_single.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/clustering_single.py -------------------------------------------------------------------------------- /reward_aggregation/game_data_processor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/game_data_processor.py -------------------------------------------------------------------------------- /reward_aggregation/gen_reinforce_multi.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/gen_reinforce_multi.py -------------------------------------------------------------------------------- /reward_aggregation/gen_reinforce_single.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/gen_reinforce_single.py -------------------------------------------------------------------------------- /reward_aggregation/guess_my_city_clustering/clustering.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/guess_my_city_clustering/clustering.py -------------------------------------------------------------------------------- /reward_aggregation/guess_my_city_clustering/postprocess.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/guess_my_city_clustering/postprocess.py -------------------------------------------------------------------------------- /reward_aggregation/guess_my_city_clustering/preprocess.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/guess_my_city_clustering/preprocess.py -------------------------------------------------------------------------------- /reward_aggregation/negotiation_clustering/clustering.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/negotiation_clustering/clustering.py -------------------------------------------------------------------------------- /reward_aggregation/negotiation_clustering/postprocess.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/negotiation_clustering/postprocess.py -------------------------------------------------------------------------------- /reward_aggregation/negotiation_clustering/preprocess.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/negotiation_clustering/preprocess.py -------------------------------------------------------------------------------- /reward_aggregation/prompt_base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/prompt_base.py -------------------------------------------------------------------------------- /reward_aggregation/reward_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/reward_utils.py -------------------------------------------------------------------------------- /reward_aggregation/twenty_question_clustering/clustering.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/twenty_question_clustering/clustering.py -------------------------------------------------------------------------------- /reward_aggregation/twenty_question_clustering/postprocess.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/twenty_question_clustering/postprocess.py -------------------------------------------------------------------------------- /reward_aggregation/twenty_question_clustering/preprocess.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/reward_aggregation/twenty_question_clustering/preprocess.py -------------------------------------------------------------------------------- /scripts/config/accelerate_config/default_config.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/scripts/config/accelerate_config/default_config.yaml -------------------------------------------------------------------------------- /scripts/config/default.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/scripts/config/default.yaml -------------------------------------------------------------------------------- /scripts/config/ds_config.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/scripts/config/ds_config.json -------------------------------------------------------------------------------- /scripts/config/onlinereinforce_llm.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/scripts/config/onlinereinforce_llm.yaml -------------------------------------------------------------------------------- /scripts/config/reinforce_llm.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/scripts/config/reinforce_llm.yaml -------------------------------------------------------------------------------- /scripts/run_offline.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/scripts/run_offline.py -------------------------------------------------------------------------------- /scripts/run_online.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/scripts/run_online.py -------------------------------------------------------------------------------- /scripts/train_online.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/scripts/train_online.sh -------------------------------------------------------------------------------- /scripts/train_rm.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/scripts/train_rm.sh -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhyang2021/ARIA/HEAD/setup.py --------------------------------------------------------------------------------