├── .gitignore ├── .gitmodules ├── README.md ├── configs ├── cT0 │ ├── 1_wiki_auto.json │ ├── 1_wiki_auto_eval.json │ ├── 1_wiki_auto_seen_eval.json │ ├── 2_gigaword_1.json │ ├── 2_gigaword_2.json │ ├── 2_gigaword_3.json │ ├── 2_gigaword_eval.json │ ├── 2_gigaword_seen_eval.json │ ├── 3_haiku.json │ ├── 3_haiku_eval.json │ ├── 3_haiku_seen_eval.json │ ├── 4_covid_qa.json │ ├── 4_covid_qa_eval.json │ ├── 4_covid_qa_seen_eval.json │ ├── 5_eli5.json │ ├── 5_eli5_eval.json │ ├── 5_eli5_seen_eval.json │ ├── 6_emdg.json │ ├── 6_emdg_eval.json │ ├── 6_emdg_seen_eval.json │ ├── 7_esnli.json │ ├── 7_esnli_eval.json │ ├── 7_esnli_seen_eval.json │ ├── 8_tw20.json │ ├── 8_tw20_eval.json │ ├── 8_tw20_seen_eval.json │ ├── cT0_eval.json │ └── t0.json ├── debug.json ├── experiment │ ├── Flan_T5_11B.json │ ├── Flan_T5_3B.json │ ├── adversarial_qa │ │ ├── adversarial_qa.json │ │ ├── adversarial_qa2.json │ │ ├── adversarial_qa3.json │ │ ├── adversarial_qa4.json │ │ ├── adversarial_qa5.json │ │ └── adversarial_qa_all.json │ ├── ag_news │ │ ├── ag_news.json │ │ ├── ag_news2.json │ │ ├── ag_news3.json │ │ ├── ag_news4.json │ │ ├── ag_news5.json │ │ ├── ag_news6.json │ │ ├── ag_news7.json │ │ └── ag_news_all.json │ ├── amazon_polarity │ │ ├── amazon_polarity.json │ │ ├── amazon_polarity2.json │ │ ├── amazon_polarity3.json │ │ ├── amazon_polarity4.json │ │ ├── amazon_polarity5.json │ │ ├── amazon_polarity6.json │ │ ├── amazon_polarity7.json │ │ ├── amazon_polarity8.json │ │ ├── amazon_polarity9.json │ │ └── amazon_polarity_all.json │ ├── app_reviews │ │ ├── app_reviews.json │ │ ├── app_reviews2.json │ │ ├── app_reviews3.json │ │ ├── app_reviews4.json │ │ └── app_reviews_all.json │ ├── cnn_dailymail │ │ ├── cnn_dailymail.json │ │ ├── cnn_dailymail2.json │ │ ├── cnn_dailymail3.json │ │ ├── cnn_dailymail4.json │ │ ├── cnn_dailymail5.json │ │ ├── cnn_dailymail6.json │ │ ├── cnn_dailymail7.json │ │ ├── cnn_dailymail8.json │ │ ├── cnn_dailymail9.json │ │ └── cnn_dailymail_all.json │ ├── 
common_gen │ │ ├── common_gen.json │ │ ├── common_gen2.json │ │ ├── common_gen3.json │ │ ├── common_gen4.json │ │ ├── common_gen5.json │ │ ├── common_gen6.json │ │ ├── common_gen7.json │ │ ├── common_gen8.json │ │ ├── common_gen9.json │ │ └── common_gen_all.json │ ├── commonsense_qa │ │ ├── commonsense_qa.json │ │ ├── commonsense_qa2.json │ │ ├── commonsense_qa3.json │ │ ├── commonsense_qa4.json │ │ ├── commonsense_qa5.json │ │ └── commonsense_qa_all.json │ ├── cos_e │ │ ├── cos_e.json │ │ ├── cos_e10.json │ │ ├── cos_e11.json │ │ ├── cos_e2.json │ │ ├── cos_e3.json │ │ ├── cos_e4.json │ │ ├── cos_e5.json │ │ ├── cos_e6.json │ │ ├── cos_e7.json │ │ ├── cos_e8.json │ │ ├── cos_e9.json │ │ └── cos_e_all.json │ ├── cosmos_qa │ │ ├── cosmosqa.json │ │ ├── cosmosqa10.json │ │ ├── cosmosqa11.json │ │ ├── cosmosqa12.json │ │ ├── cosmosqa13.json │ │ ├── cosmosqa2.json │ │ ├── cosmosqa3.json │ │ ├── cosmosqa4.json │ │ ├── cosmosqa5.json │ │ ├── cosmosqa6.json │ │ ├── cosmosqa7.json │ │ ├── cosmosqa8.json │ │ ├── cosmosqa9.json │ │ └── cosmosqa_all.json │ ├── dbpedia_14 │ │ ├── dbpedia_14.json │ │ ├── dbpedia_142.json │ │ ├── dbpedia_143.json │ │ ├── dbpedia_144.json │ │ └── dbpedia_14_all.json │ ├── dream │ │ ├── dream.json │ │ ├── dream2.json │ │ ├── dream3.json │ │ ├── dream4.json │ │ ├── dream5.json │ │ └── dream_all.json │ ├── duorc │ │ ├── duorc.json │ │ ├── duorc2.json │ │ ├── duorc3.json │ │ ├── duorc4.json │ │ ├── duorc5.json │ │ ├── duorc6.json │ │ ├── duorc7.json │ │ ├── duorc8.json │ │ ├── duorc9.json │ │ └── duorc_all.json │ ├── gigaword │ │ ├── gigaword.json │ │ ├── gigaword2.json │ │ ├── gigaword3.json │ │ ├── gigaword4.json │ │ ├── gigaword5.json │ │ ├── gigaword6.json │ │ ├── gigaword7.json │ │ ├── gigaword8.json │ │ ├── gigaword9.json │ │ └── gigaword_all.json │ ├── hotpot_qa │ │ ├── hotpot_qa.json │ │ ├── hotpot_qa2.json │ │ ├── hotpot_qa3.json │ │ ├── hotpot_qa4.json │ │ ├── hotpot_qa5.json │ │ ├── hotpot_qa6.json │ │ └── hotpot_qa_all.json │ ├── imdb │ │ 
├── imdb.json │ │ ├── imdb10.json │ │ ├── imdb11.json │ │ ├── imdb2.json │ │ ├── imdb3.json │ │ ├── imdb4.json │ │ ├── imdb5.json │ │ ├── imdb6.json │ │ ├── imdb7.json │ │ ├── imdb8.json │ │ ├── imdb9.json │ │ └── imdb_all.json │ ├── mrpc │ │ ├── mrpc.json │ │ ├── mrpc2.json │ │ ├── mrpc3.json │ │ ├── mrpc4.json │ │ ├── mrpc5.json │ │ ├── mrpc6.json │ │ ├── mrpc7.json │ │ └── mrpc_all.json │ ├── mt0.json │ ├── multi_news │ │ ├── multi_news.json │ │ ├── multi_news2.json │ │ ├── multi_news3.json │ │ ├── multi_news4.json │ │ ├── multi_news5.json │ │ ├── multi_news6.json │ │ └── multi_news_all.json │ ├── paws │ │ ├── paws.json │ │ ├── paws10.json │ │ ├── paws11.json │ │ ├── paws12.json │ │ ├── paws2.json │ │ ├── paws3.json │ │ ├── paws4.json │ │ ├── paws5.json │ │ ├── paws6.json │ │ ├── paws7.json │ │ ├── paws8.json │ │ ├── paws9.json │ │ └── paws_all.json │ ├── qasc │ │ ├── qasc.json │ │ ├── qasc2.json │ │ ├── qasc3.json │ │ ├── qasc4.json │ │ ├── qasc5.json │ │ ├── qasc6.json │ │ ├── qasc7.json │ │ ├── qasc8.json │ │ └── qasc_all.json │ ├── qqp │ │ ├── qqp.json │ │ ├── qqp2.json │ │ ├── qqp3.json │ │ ├── qqp4.json │ │ ├── qqp5.json │ │ ├── qqp6.json │ │ └── qqp_all.json │ ├── quail │ │ ├── quail.json │ │ ├── quail10.json │ │ ├── quail11.json │ │ ├── quail12.json │ │ ├── quail13.json │ │ ├── quail2.json │ │ ├── quail3.json │ │ ├── quail4.json │ │ ├── quail5.json │ │ ├── quail6.json │ │ ├── quail7.json │ │ ├── quail8.json │ │ ├── quail9.json │ │ └── quail_all.json │ ├── quarel │ │ ├── quarel.json │ │ ├── quarel2.json │ │ ├── quarel3.json │ │ ├── quarel4.json │ │ ├── quarel5.json │ │ └── quarel_all.json │ ├── quartz │ │ ├── quartz.json │ │ ├── quartz2.json │ │ ├── quartz3.json │ │ ├── quartz4.json │ │ ├── quartz5.json │ │ ├── quartz6.json │ │ ├── quartz7.json │ │ ├── quartz8.json │ │ └── quartz_all.json │ ├── quoref │ │ ├── quoref.json │ │ ├── quoref10.json │ │ ├── quoref11.json │ │ ├── quoref2.json │ │ ├── quoref3.json │ │ ├── quoref4.json │ │ ├── quoref5.json │ │ ├── 
quoref6.json │ │ ├── quoref7.json │ │ ├── quoref8.json │ │ ├── quoref9.json │ │ └── quoref_all.json │ ├── ropes │ │ ├── ropes.json │ │ ├── ropes10.json │ │ ├── ropes11.json │ │ ├── ropes12.json │ │ ├── ropes2.json │ │ ├── ropes3.json │ │ ├── ropes4.json │ │ ├── ropes5.json │ │ ├── ropes6.json │ │ ├── ropes7.json │ │ ├── ropes8.json │ │ ├── ropes9.json │ │ └── ropes_all.json │ ├── rotten_tomatoes │ │ ├── rotten_tomatoes.json │ │ ├── rotten_tomatoes10.json │ │ ├── rotten_tomatoes2.json │ │ ├── rotten_tomatoes3.json │ │ ├── rotten_tomatoes4.json │ │ ├── rotten_tomatoes5.json │ │ ├── rotten_tomatoes6.json │ │ ├── rotten_tomatoes7.json │ │ ├── rotten_tomatoes8.json │ │ ├── rotten_tomatoes9.json │ │ └── rotten_tomatoes_all.json │ ├── samsum │ │ ├── samsum.json │ │ ├── samsum2.json │ │ ├── samsum3.json │ │ ├── samsum4.json │ │ ├── samsum5.json │ │ ├── samsum6.json │ │ ├── samsum7.json │ │ └── samsum_all.json │ ├── sciq │ │ ├── sciq.json │ │ ├── sciq2.json │ │ ├── sciq3.json │ │ ├── sciq4.json │ │ ├── sciq5.json │ │ └── sciq_all.json │ ├── social_i_qa │ │ ├── social_i_qa.json │ │ ├── social_i_qa2.json │ │ ├── social_i_qa3.json │ │ ├── social_i_qa4.json │ │ ├── social_i_qa5.json │ │ ├── social_i_qa6.json │ │ └── social_i_qa_all.json │ ├── t0.json │ ├── t0_11B.json │ ├── t0_3B.json │ ├── t0_target.json │ ├── trec │ │ ├── trec.json │ │ ├── trec10.json │ │ ├── trec11.json │ │ ├── trec12.json │ │ ├── trec13.json │ │ ├── trec14.json │ │ ├── trec15.json │ │ ├── trec16.json │ │ ├── trec17.json │ │ ├── trec18.json │ │ ├── trec2.json │ │ ├── trec3.json │ │ ├── trec4.json │ │ ├── trec5.json │ │ ├── trec6.json │ │ ├── trec7.json │ │ ├── trec8.json │ │ ├── trec9.json │ │ └── trec_all.json │ ├── wiki_bio │ │ └── wiki_bio.json │ ├── wiki_hop │ │ ├── wiki_hop.json │ │ ├── wiki_hop2.json │ │ ├── wiki_hop3.json │ │ ├── wiki_hop4.json │ │ ├── wiki_hop5.json │ │ ├── wiki_hop6.json │ │ ├── wiki_hop7.json │ │ ├── wiki_hop8.json │ │ ├── wiki_hop9.json │ │ └── wiki_hop_all.json │ ├── wiki_qa │ │ 
├── wiki_qa.json │ │ ├── wiki_qa10.json │ │ ├── wiki_qa11.json │ │ ├── wiki_qa2.json │ │ ├── wiki_qa3.json │ │ ├── wiki_qa4.json │ │ ├── wiki_qa5.json │ │ ├── wiki_qa6.json │ │ ├── wiki_qa7.json │ │ ├── wiki_qa8.json │ │ ├── wiki_qa9.json │ │ └── wiki_qa_all.json │ ├── wiqa │ │ ├── wiqa.json │ │ ├── wiqa2.json │ │ ├── wiqa3.json │ │ ├── wiqa4.json │ │ ├── wiqa5.json │ │ ├── wiqa6.json │ │ ├── wiqa7.json │ │ ├── wiqa8.json │ │ └── wiqa_all.json │ ├── xsum │ │ ├── xsum.json │ │ ├── xsum10.json │ │ ├── xsum2.json │ │ ├── xsum3.json │ │ ├── xsum4.json │ │ ├── xsum5.json │ │ ├── xsum6.json │ │ ├── xsum7.json │ │ ├── xsum8.json │ │ ├── xsum9.json │ │ └── xsum_all.json │ └── yelp_review_full │ │ ├── yelp_review_full.json │ │ ├── yelp_review_full2.json │ │ ├── yelp_review_full3.json │ │ ├── yelp_review_full4.json │ │ ├── yelp_review_full5.json │ │ ├── yelp_review_full6.json │ │ ├── yelp_review_full7.json │ │ └── yelp_review_full_all.json └── merge │ ├── ST_label_instance_top1.json │ ├── ST_label_instance_top3.json │ ├── oracle_top1.json │ └── oracle_top3.json ├── figure1.PNG ├── figure3.PNG ├── requirements.txt ├── retrieval ├── additional_code │ ├── get_oracle_list.py │ ├── get_oracle_list_task.py │ ├── run.py │ ├── run_expert_32.py │ ├── run_instance.py │ ├── run_instance_32.py │ ├── run_instance_32_merge.py │ ├── run_selected.py │ └── run_task_32.py ├── configs │ ├── BM25_instance_top1.json │ ├── BM25_label_instance_top1.json │ ├── ST_MP_label_instance_top1.json │ ├── ST_choice_instance_top1.json │ ├── ST_choice_top1.json │ ├── ST_instance_top1.json │ ├── ST_label_instance_top1.json │ ├── ST_label_top1.json │ ├── T0_instance_top1.json │ └── ablations │ │ ├── DIFFCSE_choice_instance_new_top1.json │ │ ├── DIFFCSE_choice_instance_top1.json │ │ ├── DIFFCSE_choice_new_top1.json │ │ ├── DIFFCSE_choice_top1.json │ │ ├── DIFFCSE_instance_new_top1.json │ │ ├── DIFFCSE_instance_top1.json │ │ ├── DIFFCSE_label_instance_new_top1.json │ │ ├── DIFFCSE_label_instance_top1.json │ │ ├── 
DIFFCSE_label_new_top1.json │ │ ├── DIFFCSE_label_top1.json │ │ ├── GTR_Large_choice_instance_new_top1.json │ │ ├── GTR_Large_choice_instance_top1.json │ │ ├── GTR_Large_choice_new_top1.json │ │ ├── GTR_Large_choice_top1.json │ │ ├── GTR_Large_instance_new_top1.json │ │ ├── GTR_Large_instance_top1.json │ │ ├── GTR_Large_label_instance_new_top1.json │ │ ├── GTR_Large_label_instance_top1.json │ │ ├── GTR_Large_label_new_top1.json │ │ ├── GTR_Large_label_top1.json │ │ ├── GTR_XL_choice_instance_new_top1.json │ │ ├── GTR_XL_choice_instance_top1.json │ │ ├── GTR_XL_choice_new_top1.json │ │ ├── GTR_XL_choice_top1.json │ │ ├── GTR_XL_instance_new_top1.json │ │ ├── GTR_XL_instance_top1.json │ │ ├── GTR_XL_label_instance_new_top1.json │ │ ├── GTR_XL_label_instance_top1.json │ │ ├── GTR_XL_label_new_top1.json │ │ ├── GTR_XL_label_top1.json │ │ ├── INSTRUCTOR_Large_choice_instance_new_top1.json │ │ ├── INSTRUCTOR_Large_choice_instance_top1.json │ │ ├── INSTRUCTOR_Large_choice_new_top1.json │ │ ├── INSTRUCTOR_Large_choice_top1.json │ │ ├── INSTRUCTOR_Large_instance_new_top1.json │ │ ├── INSTRUCTOR_Large_instance_top1.json │ │ ├── INSTRUCTOR_Large_label_instance_new_top1.json │ │ ├── INSTRUCTOR_Large_label_instance_top1.json │ │ ├── INSTRUCTOR_Large_label_new_top1.json │ │ ├── INSTRUCTOR_Large_label_top1.json │ │ ├── INSTRUCTOR_XL_choice_instance_new_top1.json │ │ ├── INSTRUCTOR_XL_choice_instance_top1.json │ │ ├── INSTRUCTOR_XL_choice_new_top1.json │ │ ├── INSTRUCTOR_XL_choice_top1.json │ │ ├── INSTRUCTOR_XL_instance_new_top1.json │ │ ├── INSTRUCTOR_XL_instance_top1.json │ │ ├── INSTRUCTOR_XL_label_instance_new_top1.json │ │ ├── INSTRUCTOR_XL_label_instance_top1.json │ │ ├── INSTRUCTOR_XL_label_new_top1.json │ │ ├── INSTRUCTOR_XL_label_top1.json │ │ ├── SIMCSE_SUP_choice_instance_new_top1.json │ │ ├── SIMCSE_SUP_choice_instance_top1.json │ │ ├── SIMCSE_SUP_choice_new_top1.json │ │ ├── SIMCSE_SUP_choice_top1.json │ │ ├── SIMCSE_SUP_instance_new_top1.json │ │ ├── 
SIMCSE_SUP_instance_top1.json │ │ ├── SIMCSE_SUP_label_instance_new_top1.json │ │ ├── SIMCSE_SUP_label_instance_top1.json │ │ ├── SIMCSE_SUP_label_new_top1.json │ │ ├── SIMCSE_SUP_label_top1.json │ │ ├── SIMCSE_UNSUP_choice_instance_new_top1.json │ │ ├── SIMCSE_UNSUP_choice_instance_top1.json │ │ ├── SIMCSE_UNSUP_choice_new_top1.json │ │ ├── SIMCSE_UNSUP_choice_top1.json │ │ ├── SIMCSE_UNSUP_instance_new_top1.json │ │ ├── SIMCSE_UNSUP_instance_top1.json │ │ ├── SIMCSE_UNSUP_label_instance_new_top1.json │ │ ├── SIMCSE_UNSUP_label_instance_top1.json │ │ ├── SIMCSE_UNSUP_label_new_top1.json │ │ ├── SIMCSE_UNSUP_label_top1.json │ │ ├── ST5_Large_choice_instance_new_top1.json │ │ ├── ST5_Large_choice_instance_top1.json │ │ ├── ST5_Large_choice_new_top1.json │ │ ├── ST5_Large_choice_top1.json │ │ ├── ST5_Large_instance_new_top1.json │ │ ├── ST5_Large_instance_top1.json │ │ ├── ST5_Large_label_instance_new_top1.json │ │ ├── ST5_Large_label_instance_top1.json │ │ ├── ST5_Large_label_new_top1.json │ │ ├── ST5_Large_label_top1.json │ │ ├── ST5_XL_choice_instance_new_top1.json │ │ ├── ST5_XL_choice_instance_top1.json │ │ ├── ST5_XL_choice_new_top1.json │ │ ├── ST5_XL_choice_top1.json │ │ ├── ST5_XL_instance_new_top1.json │ │ ├── ST5_XL_instance_top1.json │ │ ├── ST5_XL_label_instance_new_top1.json │ │ ├── ST5_XL_label_instance_top1.json │ │ ├── ST5_XL_label_new_top1.json │ │ ├── ST5_XL_label_top1.json │ │ ├── ST_12_choice_instance_new_top1.json │ │ ├── ST_12_choice_instance_top1.json │ │ ├── ST_12_choice_new_top1.json │ │ ├── ST_12_choice_top1.json │ │ ├── ST_12_instance_new_top1.json │ │ ├── ST_12_instance_top1.json │ │ ├── ST_12_label_instance_new_top1.json │ │ ├── ST_12_label_instance_top1.json │ │ ├── ST_12_label_new_top1.json │ │ ├── ST_12_label_top1.json │ │ ├── ST_MP_NLI_choice_instance_new_top1.json │ │ ├── ST_MP_NLI_choice_instance_top1.json │ │ ├── ST_MP_NLI_choice_new_top1.json │ │ ├── ST_MP_NLI_choice_top1.json │ │ ├── ST_MP_NLI_instance_new_top1.json │ │ ├── 
ST_MP_NLI_instance_top1.json │ │ ├── ST_MP_NLI_label_instance_new_top1.json │ │ ├── ST_MP_NLI_label_instance_top1.json │ │ ├── ST_MP_NLI_label_new_top1.json │ │ ├── ST_MP_NLI_label_top1.json │ │ ├── ST_MP_choice_instance_new_top1.json │ │ ├── ST_MP_choice_instance_top1.json │ │ ├── ST_MP_choice_new_top1.json │ │ ├── ST_MP_choice_top1.json │ │ ├── ST_MP_instance_new_top1.json │ │ ├── ST_MP_instance_top1.json │ │ ├── ST_MP_label_instance_new_top1.json │ │ ├── ST_MP_label_instance_top1.json │ │ ├── ST_MP_label_new_top1.json │ │ ├── ST_MP_label_top1.json │ │ ├── ST_choice_instance_new_top1.json │ │ ├── ST_choice_instance_top1.json │ │ ├── ST_choice_new_top1.json │ │ ├── ST_choice_top1.json │ │ ├── ST_instance_new_top1.json │ │ ├── ST_instance_top1.json │ │ ├── ST_label_instance_new_top1.json │ │ ├── ST_label_instance_top1.json │ │ ├── ST_label_new_top1.json │ │ ├── ST_label_top1.json │ │ ├── T0_choice_instance_new_top1.json │ │ ├── T0_choice_instance_top1.json │ │ ├── T0_choice_new_top1.json │ │ ├── T0_choice_top1.json │ │ ├── T0_instance_new_top1.json │ │ ├── T0_instance_top1.json │ │ ├── T0_label_instance_new_top1.json │ │ ├── T0_label_instance_top1.json │ │ ├── T0_label_new_top1.json │ │ └── T0_label_top1.json ├── make_benchmarks.py ├── make_index.py ├── run_retrieval.py └── story_cloze │ └── 2016.csv ├── seq2seq ├── adapters │ ├── __init__.py │ ├── adapter_configuration.py │ ├── adapter_controller.py │ ├── adapter_modeling.py │ ├── adapter_utils.py │ └── low_rank_layer.py ├── additional_code │ ├── check_data.ipynb │ ├── check_data2.ipynb │ ├── check_target_eval_results.py │ ├── check_target_eval_results_twitter.py │ ├── count.py │ ├── create_eval_bash_files.py │ ├── extract_instance_score.py │ ├── process_csv_seen_tasks.py │ ├── process_csv_seen_tasks_2.py │ ├── process_csv_target_eval.py │ ├── run_cT0.py │ ├── run_eval.py │ ├── run_eval_own.py │ ├── run_eval_own_ct0.py │ ├── run_eval_own_t0.py │ ├── run_seen_eval_with_json_flanT5.py │ ├── 
run_seq2seq_multilingual.py │ └── run_unseen_eval_with_json_multilingual.py ├── args_adapter.py ├── args_data.py ├── args_model.py ├── args_training.py ├── data │ ├── __init__.py │ ├── custom_multi_news │ │ └── custom_multi_news.py │ ├── data_collator.py │ ├── old_tasks.py │ ├── postprocessors.py │ ├── tasks.py │ ├── test.ipynb │ └── utils.py ├── extract_data_subset_seen_eval.py ├── extract_data_subset_seen_train.py ├── extract_data_subset_target_eval.py ├── extract_data_subset_target_train.py ├── extract_data_subset_unseen_eval.py ├── hypercomplex │ ├── __init__.py │ ├── inits.py │ ├── kronecker.py │ └── layers.py ├── merge │ ├── ST_choice_instance_top1_32.json │ ├── oracle_sorted.json │ └── oracle_sorted_dataset_level.json ├── metrics │ ├── __init__.py │ ├── metrics.py │ └── qa_utils.py ├── process_csv.py ├── process_csv_dataset.py ├── projections │ ├── __init__.py │ ├── fwh_cuda │ │ ├── fwh_cpp.cpp │ │ └── fwh_cu.cu │ └── intrinsic.py ├── run_seen_eval_with_json.py ├── run_seen_eval_with_json_T0.py ├── run_seen_eval_with_json_flanT5.py ├── run_seen_eval_with_json_task_vector.py ├── run_seq2seq.py ├── run_seq2seq_multilingual.py ├── run_target_eval_with_json.py ├── run_target_eval_with_json_T0.py ├── run_target_eval_with_json_task_vector.py ├── run_unseen_eval_with_json.py ├── run_unseen_eval_with_json_T0.py ├── run_unseen_eval_with_json_merge.py ├── run_unseen_eval_with_json_multilingual.py ├── run_unseen_eval_with_json_task_vector.py ├── third_party │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ └── t5 │ │ │ ├── __init__.py │ │ │ ├── configuration_t5.py │ │ │ └── modeling_t5.py │ └── trainers │ │ ├── __init__.py │ │ ├── seq2seq_trainer.py │ │ ├── trainer.py │ │ └── trainer_utils.py └── utils │ ├── __init__.py │ └── utils.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C 
extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | source_prompts 133 | wandb/ 134 | outputs/ 135 | expert_weights/ 136 | cT0_checkpoints/ 137 | output_logs_t0/ 138 | output_logs_ct0/ 139 | output_logs_t0_verbalizer/ 140 | output_logs_verbalizer/ 141 | output_logs_t0_unseen/ 142 | output_logs_instance/ 143 | output_logs_t0_target/ 144 | output_logs_300/ 145 | output_logs_dataset_300/ 146 | output_logs_prompt_300/ 147 | output_logs_oracle/ 148 | output_logs_seen_eval/ 149 | output_logs_seen_eval_flanT5/ 150 | output_logs_seen_eval_T0/ 151 | output_logs_target_eval/ 152 | output_logs_test/ 153 | output_logs_unseen_eval_dataset_300/ 154 | output_logs_unseen_eval_prompt_300/ 155 | 156 | twitter_top20.tweet_as+about.author.count_vectorizer.pkl 157 | twitter_top20.tweet_as+about.author.model.pkl 158 | 159 | seq2seq/output_logs 160 | dsi/cache/ 161 | 162 | seq2seq/cache/ 163 | seq2seq/cache2/ 164 | seq2seq/csv_results 165 | seq2seq/csv_results_300 166 | seq2seq/ct0_csv_results 167 | seq2seq/csv_results_dataset_300 168 | seq2seq/csv_results_oracle 169 | seq2seq/csv_results_prev 170 | seq2seq/csv_results_target_eval 171 | seq2seq/data/amazon_review_polarity_csv/ 172 | seq2seq/data/cT0/ 173 | seq2seq/data/dbpedia_csv/ 174 | seq2seq/data/yelp_review_full_csv/ 175 | 176 | seq2seq/data/manual/ 177 | data/amazon_review_polarity_csv/* 178 | data/cT0 179 | data/dbpedia_csv/* 180 | data/yelp_review_full_csv/* 181 | dbpedia_csv.tgz 182 | 
story_cloze_2016_val.csv 183 | dsi/story_cloze/ 184 | yelp_review_full_csv.tgz 185 | 186 | retrieval/retrieval_data 187 | retrieval/results 188 | seq2seq/results 189 | retrieval/indexes -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "promptsource"] 2 | path = promptsource 3 | url = https://github.com/doeyoungkim/promptsource.git -------------------------------------------------------------------------------- /configs/cT0/t0.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 3 | "do_eval": false, 4 | "do_test": false, 5 | "warmup_steps": 0, 6 | "save_steps": 1000, 7 | "model_name_or_path": "bigscience/T0_3B", 8 | "tokenizer_name": "bigscience/T0_3B", 9 | "save_total_limit": 1, 10 | "per_device_train_batch_size": 1, 11 | "per_device_eval_batch_size": 8, 12 | "load_best_model_at_end": false, 13 | "metric_for_best_model": "rougeL", 14 | "greater_is_better": false, 15 | "evaluation_strategy": "epoch", 16 | "non_linearity": "gelu_new", 17 | "max_source_length": 512, 18 | "learning_rate": 1e-4, 19 | "output_dir": "expert_weights/t0", 20 | "split_validation_test": true, 21 | "task_name": ["t0"], 22 | "adapters_cur_training_task":"t0", 23 | "dataset_config_name": ["none"], 24 | "eval_dataset_name": [ 25 | "wiki_auto","wiki_auto","asset","asset", 26 | "ct0_gigaword","ct0_gigaword","ct0_gigaword","ct0_gigaword","ct0_gigaword","ct0_gigaword", 27 | "haiku", 28 | "covid_qa", 29 | "eli5", 30 | "emdg", 31 | "esnli", 32 | "twitter" 33 | ], 34 | "eval_dataset_config_name": [ 35 | "skip","skip","skip","skip", 36 | "skip","skip","skip","skip","skip","skip", 37 | "skip", 38 | "skip", 39 | "skip", 40 | "skip", 41 | "skip", 42 | "skip" 43 | ], 44 | "eval_prompts": [ 45 | "simplification_1","simplification_2","simplification_1","simplification_2", 46 | 
"constrain_contain+make_a_title","constrain_contain+write_its_sentence","constrain_end+make_a_title","constrain_end+write_its_sentence","constrain_start+make_a_title","constrain_start+write_its_sentence", 47 | "do_nothing", 48 | "covid_qa_deepset", 49 | "generate_a_question_1", 50 | "dialogue_with_emotion", 51 | "explain_why", 52 | "tweet_as+about" 53 | ], 54 | "test_dataset_name": ["superglue_glue","superglue_glue"], 55 | "test_dataset_config_name": ["rte","copa"], 56 | "num_train_epochs": 5, 57 | "predict_with_generate": false, 58 | "add_layer_norm_before_adapter": false, 59 | "add_layer_norm_after_adapter": false, 60 | "adapter_config_name": "adapter", 61 | "train_task_adapters": true, 62 | "task_reduction_factor": 32, 63 | "unfreeze_lm_head": false, 64 | "unfreeze_layer_norms": true, 65 | "overwrite_output_dir": true, 66 | "max_train_samples": 10000, 67 | "max_val_samples": 300, 68 | "max_test_samples": 100, 69 | "train_prompts": ["t0"], 70 | "test_prompts": ["can we infer","best_option"] 71 | } -------------------------------------------------------------------------------- /configs/debug.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_eval": true, 3 | "do_test": true, 4 | "do_train": true, 5 | "warmup_steps": 500, 6 | "save_steps": 1000, 7 | "model_name_or_path": "t5-base", 8 | "tokenizer_name": "t5-base", 9 | "save_total_limit": 1, 10 | "load_best_model_at_end": true, 11 | "metric_for_best_model": "average_metrics", 12 | "greater_is_better": true, 13 | "evaluation_strategy": "epoch", 14 | "non_linearity": "gelu_new", 15 | "max_source_length": 8, 16 | "learning_rate": 5e-4, 17 | "output_dir": "outputs/finetuning", 18 | "per_device_train_batch_size": 16, 19 | "per_device_eval_batch_size": 16, 20 | "split_validation_test": true, 21 | "task_name": ["superglue-boolq"], 22 | "eval_dataset_name": ["superglue-boolq"], 23 | "test_dataset_name": ["superglue-boolq"], 24 | "num_train_epochs": 1, 25 | 
"dataset_config_name": ["en"], 26 | "eval_dataset_config_name": ["en"], 27 | "test_dataset_config_name": ["en"], 28 | "predict_with_generate": true, 29 | "wandb_log": true, 30 | "wandb_project": "debug", 31 | "wandb_run_name" : "yolo" 32 | } -------------------------------------------------------------------------------- /configs/experiment/Flan_T5_11B.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 3 | "do_eval": false, 4 | "do_test": false, 5 | "warmup_steps": 0, 6 | "save_steps": 1000, 7 | "model_name_or_path": "google/flan-t5-xxl", 8 | "tokenizer_name": "google/flan-t5-xxl", 9 | "save_total_limit": 1, 10 | "per_device_train_batch_size": 1, 11 | "per_device_eval_batch_size": 64, 12 | "load_best_model_at_end": false, 13 | "metric_for_best_model": "rougeL", 14 | "greater_is_better": false, 15 | "evaluation_strategy": "epoch", 16 | "non_linearity": "gelu_new", 17 | "max_source_length": 512, 18 | "learning_rate": 1e-4, 19 | "output_dir": "expert_weights/flan-t5-xxl", 20 | "split_validation_test": true, 21 | "task_name": ["flan-t5-xxl"], 22 | "adapters_cur_training_task":"flan-t5-xxl", 23 | "dataset_config_name": ["none"], 24 | "eval_dataset_name":[ 25 | "hellaswag","hellaswag","hellaswag","hellaswag","hellaswag", 26 | "story_cloze","story_cloze","story_cloze","story_cloze","story_cloze", 27 | "anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1", 28 | "anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2", 29 | "anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3", 30 | "super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa", 31 | 
"super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb", 32 | "super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte", 33 | "super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed", 34 | "super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic", 35 | "winogrande","winogrande","winogrande","winogrande","winogrande" 36 | ], 37 | "eval_dataset_config_name":[ 38 | "none","none","none","none","none", 39 | "none","none","none","none","none", 40 | "dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1", 41 | "dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2", 42 | "dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3", 43 | "copa","copa","copa","copa","copa","copa","copa","copa", 44 | "cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb", 45 | "rte","rte","rte","rte","rte","rte","rte","rte","rte","rte", 46 | "wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed", 47 | "wic","wic","wic","wic","wic","wic","wic","wic","wic","wic", 48 | "winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl" 49 | ], 50 | "data_args.eval_prompts":[ 51 | 
"complete_first_then","Randomized prompts template","Predict ending with hint","how_ends","if_begins_how_continues", 52 | "Answer Given options","Choose Story Ending","Movie What Happens Next","Story Continuation and Options","Novel Correct Ending", 53 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 54 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 55 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 56 | "exercise","i_am_hesitating","plausible_alternatives","C1 or C2? 
premise, so/because\u2026","best_option","more likely","cause_effect","choose", 57 | "can we infer","based on the previous passage","claim true/false/inconclusive","does it follow that","justified in saying","always/sometimes/never","GPT-3 style","consider always/sometimes/never","guaranteed true","must be true","guaranteed/possible/impossible","does this imply","MNLI crowdsource","should assume","take the following as truth", 58 | "MNLI crowdsource","guaranteed true","can we infer","GPT-3 style","does this imply","should assume","does it follow that","based on the previous passage","justified in saying","must be true", 59 | "does the pronoun refer to","by p they mean","in other words","I think they mean","does p stand for","GPT-3 Style","replaced with","p is/are r","the pronoun refers to","Who or what is/are", 60 | "question-context-meaning-with-label","question-context-meaning","grammar_homework","affirmation_true_or_false","GPT-3-prompt","same_sense","question-context","GPT-3-prompt-with-label","polysemous","similar-sense", 61 | "does underscore refer to","stand for","underscore refer to","fill in the blank","Replace" 62 | ], 63 | "test_dataset_name": ["superglue_glue","superglue_glue"], 64 | "test_dataset_config_name": ["rte","copa"], 65 | "num_train_epochs": 5, 66 | "predict_with_generate": false, 67 | "add_layer_norm_before_adapter": false, 68 | "add_layer_norm_after_adapter": false, 69 | "adapter_config_name": "adapter", 70 | "train_task_adapters": true, 71 | "task_reduction_factor": 32, 72 | "unfreeze_lm_head": false, 73 | "unfreeze_layer_norms": true, 74 | "overwrite_output_dir": true, 75 | "max_train_samples": 10000, 76 | "max_test_samples": 100, 77 | "train_prompts": ["flan-t5-xxl"], 78 | "test_prompts": ["can we infer","best_option"] 79 | } -------------------------------------------------------------------------------- /configs/experiment/Flan_T5_3B.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"do_train": true, 3 | "do_eval": false, 4 | "do_test": false, 5 | "warmup_steps": 0, 6 | "save_steps": 1000, 7 | "model_name_or_path": "google/flan-t5-xl", 8 | "tokenizer_name": "google/flan-t5-xl", 9 | "save_total_limit": 1, 10 | "per_device_train_batch_size": 1, 11 | "per_device_eval_batch_size": 64, 12 | "load_best_model_at_end": false, 13 | "metric_for_best_model": "rougeL", 14 | "greater_is_better": false, 15 | "evaluation_strategy": "epoch", 16 | "non_linearity": "gelu_new", 17 | "max_source_length": 512, 18 | "learning_rate": 1e-4, 19 | "output_dir": "expert_weights/flan-t5-xl", 20 | "split_validation_test": true, 21 | "task_name": ["flan-t5-xl"], 22 | "adapters_cur_training_task":"flan-t5-xl", 23 | "dataset_config_name": ["none"], 24 | "eval_dataset_name":[ 25 | "hellaswag","hellaswag","hellaswag","hellaswag","hellaswag", 26 | "story_cloze","story_cloze","story_cloze","story_cloze","story_cloze", 27 | "anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1", 28 | "anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2", 29 | "anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3", 30 | "super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa", 31 | "super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb", 32 | "super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte", 33 | 
"super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed", 34 | "super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic", 35 | "winogrande","winogrande","winogrande","winogrande","winogrande" 36 | ], 37 | "eval_dataset_config_name":[ 38 | "none","none","none","none","none", 39 | "none","none","none","none","none", 40 | "dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1", 41 | "dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2", 42 | "dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3", 43 | "copa","copa","copa","copa","copa","copa","copa","copa", 44 | "cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb", 45 | "rte","rte","rte","rte","rte","rte","rte","rte","rte","rte", 46 | "wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed", 47 | "wic","wic","wic","wic","wic","wic","wic","wic","wic","wic", 48 | "winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl" 49 | ], 50 | "data_args.eval_prompts":[ 51 | "complete_first_then","Randomized prompts template","Predict ending with hint","how_ends","if_begins_how_continues", 52 | "Answer Given options","Choose Story Ending","Movie What Happens Next","Story Continuation and Options","Novel Correct Ending", 53 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we 
infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 54 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 55 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 56 | "exercise","i_am_hesitating","plausible_alternatives","C1 or C2? premise, so/because\u2026","best_option","more likely","cause_effect","choose", 57 | "can we infer","based on the previous passage","claim true/false/inconclusive","does it follow that","justified in saying","always/sometimes/never","GPT-3 style","consider always/sometimes/never","guaranteed true","must be true","guaranteed/possible/impossible","does this imply","MNLI crowdsource","should assume","take the following as truth", 58 | "MNLI crowdsource","guaranteed true","can we infer","GPT-3 style","does this imply","should assume","does it follow that","based on the previous passage","justified in saying","must be true", 59 | "does the pronoun refer to","by p they mean","in other words","I think they mean","does p stand for","GPT-3 Style","replaced with","p is/are r","the pronoun refers to","Who or what is/are", 60 | "question-context-meaning-with-label","question-context-meaning","grammar_homework","affirmation_true_or_false","GPT-3-prompt","same_sense","question-context","GPT-3-prompt-with-label","polysemous","similar-sense", 61 | "does 
underscore refer to","stand for","underscore refer to","fill in the blank","Replace" 62 | ], 63 | "test_dataset_name": ["superglue_glue","superglue_glue"], 64 | "test_dataset_config_name": ["rte","copa"], 65 | "num_train_epochs": 5, 66 | "predict_with_generate": false, 67 | "add_layer_norm_before_adapter": false, 68 | "add_layer_norm_after_adapter": false, 69 | "adapter_config_name": "adapter", 70 | "train_task_adapters": true, 71 | "task_reduction_factor": 32, 72 | "unfreeze_lm_head": false, 73 | "unfreeze_layer_norms": true, 74 | "overwrite_output_dir": true, 75 | "max_train_samples": 10000, 76 | "max_test_samples": 100, 77 | "train_prompts": ["flan-t5-xl"], 78 | "test_prompts": ["can we infer","best_option"] 79 | } -------------------------------------------------------------------------------- /configs/experiment/mt0.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 3 | "do_eval": false, 4 | "do_test": false, 5 | "warmup_steps": 0, 6 | "save_steps": 1000, 7 | "model_name_or_path": "./mt0-xl", 8 | "tokenizer_name": "bigscience/mt0-xl", 9 | "save_total_limit": 1, 10 | "per_device_train_batch_size": 1, 11 | "per_device_eval_batch_size": 16, 12 | "load_best_model_at_end": false, 13 | "metric_for_best_model": "rougeL", 14 | "greater_is_better": false, 15 | "evaluation_strategy": "epoch", 16 | "non_linearity": "gelu_new", 17 | "max_source_length": 512, 18 | "learning_rate": 1e-4, 19 | "output_dir": "expert_weights/mt0", 20 | "split_validation_test": true, 21 | "task_name": ["mt0"], 22 | "adapters_cur_training_task":"mt0", 23 | "dataset_config_name": ["none"], 24 | "eval_dataset_name": [ 25 | "xlsum","xlsum" 26 | ], 27 | "eval_dataset_config_name": [ 28 | "none","none" 29 | ], 30 | "eval_prompts": [ 31 | "translate to kor and summarize","translate to eng and summarize" 32 | ], 33 | "test_dataset_name": ["xlsum"], 34 | "test_dataset_config_name": ["none"], 35 | "num_train_epochs": 5, 36 | 
"predict_with_generate": false, 37 | "add_layer_norm_before_adapter": false, 38 | "add_layer_norm_after_adapter": false, 39 | "adapter_config_name": "adapter", 40 | "train_task_adapters": false, 41 | "task_reduction_factor": 32, 42 | "unfreeze_lm_head": false, 43 | "unfreeze_layer_norms": true, 44 | "overwrite_output_dir": true, 45 | "max_train_samples": 10000, 46 | "max_val_samples": 300, 47 | "max_test_samples": 100, 48 | "train_prompts": ["mt0"], 49 | "test_prompts": ["translate to kor and summarize"] 50 | } -------------------------------------------------------------------------------- /configs/experiment/t0.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 3 | "do_eval": false, 4 | "do_test": false, 5 | "warmup_steps": 0, 6 | "save_steps": 1000, 7 | "model_name_or_path": "bigscience/T0_3B", 8 | "tokenizer_name": "bigscience/T0_3B", 9 | "save_total_limit": 1, 10 | "per_device_train_batch_size": 1, 11 | "per_device_eval_batch_size": 16, 12 | "load_best_model_at_end": false, 13 | "metric_for_best_model": "rougeL", 14 | "greater_is_better": false, 15 | "evaluation_strategy": "epoch", 16 | "non_linearity": "gelu_new", 17 | "max_source_length": 512, 18 | "learning_rate": 1e-4, 19 | "output_dir": "expert_weights/t0", 20 | "split_validation_test": true, 21 | "task_name": ["t0"], 22 | "adapters_cur_training_task":"t0", 23 | "dataset_config_name": ["none"], 24 | "eval_dataset_name": [ 25 | "hellaswag","hellaswag","hellaswag","hellaswag","hellaswag","hellaswag","hellaswag", 26 | "story_cloze","story_cloze","story_cloze","story_cloze","story_cloze", 27 | "anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1", 28 | "anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2", 29 | 
"anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3", 30 | "super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa", 31 | "super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb", 32 | "super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte", 33 | "super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed", 34 | "super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic", 35 | "winogrande","winogrande","winogrande","winogrande","winogrande","winogrande" 36 | ], 37 | "eval_dataset_config_name": [ 38 | "none","none","none","none","none","none","none", 39 | "none","none","none","none","none", 40 | "dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1", 41 | "dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2", 42 | "dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3", 43 | "copa","copa","copa","copa","copa","copa","copa","copa","copa","copa","copa","copa", 44 | 
"cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb", 45 | "rte","rte","rte","rte","rte","rte","rte","rte","rte","rte", 46 | "wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed", 47 | "wic","wic","wic","wic","wic","wic","wic","wic","wic","wic", 48 | "winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl" 49 | ], 50 | "eval_prompts": [ 51 | "complete_first_then","Randomized prompts template","Appropriate continuation - Yes or No","Predict ending with hint","Reversed appropriate continuation - Yes or No","how_ends","if_begins_how_continues", 52 | "Answer Given options","Choose Story Ending","Movie What Happens Next","Story Continuation and Options","Novel Correct Ending", 53 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 54 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 55 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 56 | "exercise","\u2026What could happen next, C1 or C2?","i_am_hesitating","plausible_alternatives","C1 or C2? 
premise, so/because\u2026","\u2026As a result, C1 or C2?","best_option","\u2026which may be caused by","more likely","cause_effect","\u2026why? C1 or C2","choose", 57 | "can we infer","based on the previous passage","claim true/false/inconclusive","does it follow that","justified in saying","always/sometimes/never","GPT-3 style","consider always/sometimes/never","guaranteed true","must be true","guaranteed/possible/impossible","does this imply","MNLI crowdsource","should assume","take the following as truth", 58 | "MNLI crowdsource","guaranteed true","can we infer","GPT-3 style","does this imply","should assume","does it follow that","based on the previous passage","justified in saying","must be true", 59 | "does the pronoun refer to","by p they mean","in other words","I think they mean","does p stand for","GPT-3 Style","replaced with","p is/are r","the pronoun refers to","Who or what is/are", 60 | "question-context-meaning-with-label","question-context-meaning","grammar_homework","affirmation_true_or_false","GPT-3-prompt","same_sense","question-context","GPT-3-prompt-with-label","polysemous","similar-sense", 61 | "does underscore refer to","stand for","underscore refer to","fill in the blank","True or False","Replace" 62 | ], 63 | "test_dataset_name": ["superglue_glue","superglue_glue"], 64 | "test_dataset_config_name": ["rte","copa"], 65 | "num_train_epochs": 5, 66 | "predict_with_generate": false, 67 | "add_layer_norm_before_adapter": false, 68 | "add_layer_norm_after_adapter": false, 69 | "adapter_config_name": "adapter", 70 | "train_task_adapters": true, 71 | "task_reduction_factor": 32, 72 | "unfreeze_lm_head": false, 73 | "unfreeze_layer_norms": true, 74 | "overwrite_output_dir": true, 75 | "max_train_samples": 10000, 76 | "max_val_samples": 300, 77 | "max_test_samples": 100, 78 | "train_prompts": ["t0"], 79 | "test_prompts": ["can we infer","best_option"] 80 | } -------------------------------------------------------------------------------- 
/configs/experiment/t0_11B.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 3 | "do_eval": false, 4 | "do_test": false, 5 | "warmup_steps": 0, 6 | "save_steps": 1000, 7 | "model_name_or_path": "bigscience/T0", 8 | "tokenizer_name": "bigscience/T0", 9 | "save_total_limit": 1, 10 | "per_device_train_batch_size": 1, 11 | "per_device_eval_batch_size": 64, 12 | "load_best_model_at_end": false, 13 | "metric_for_best_model": "rougeL", 14 | "greater_is_better": false, 15 | "evaluation_strategy": "epoch", 16 | "non_linearity": "gelu_new", 17 | "max_source_length": 512, 18 | "learning_rate": 1e-4, 19 | "output_dir": "expert_weights/t0_11B", 20 | "split_validation_test": true, 21 | "task_name": ["t0_11B"], 22 | "adapters_cur_training_task":"t0_11B", 23 | "dataset_config_name": ["none"], 24 | "eval_dataset_name":[ 25 | "hellaswag","hellaswag","hellaswag","hellaswag","hellaswag", 26 | "story_cloze","story_cloze","story_cloze","story_cloze","story_cloze", 27 | "anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1", 28 | "anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2", 29 | "anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3", 30 | "super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa", 31 | "super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb", 32 | 
"super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte", 33 | "super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed", 34 | "super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic", 35 | "winogrande","winogrande","winogrande","winogrande","winogrande" 36 | ], 37 | "eval_dataset_config_name":[ 38 | "none","none","none","none","none", 39 | "none","none","none","none","none", 40 | "dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1", 41 | "dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2", 42 | "dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3", 43 | "copa","copa","copa","copa","copa","copa","copa","copa", 44 | "cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb", 45 | "rte","rte","rte","rte","rte","rte","rte","rte","rte","rte", 46 | "wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed", 47 | "wic","wic","wic","wic","wic","wic","wic","wic","wic","wic", 48 | "winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl" 49 | ], 50 | "data_args.eval_prompts":[ 51 | "complete_first_then","Randomized prompts template","Predict ending with hint","how_ends","if_begins_how_continues", 52 | "Answer Given options","Choose Story Ending","Movie What Happens Next","Story Continuation and Options","Novel Correct Ending", 53 | "MNLI 
crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 54 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 55 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 56 | "exercise","i_am_hesitating","plausible_alternatives","C1 or C2? 
premise, so/because\u2026","best_option","more likely","cause_effect","choose", 57 | "can we infer","based on the previous passage","claim true/false/inconclusive","does it follow that","justified in saying","always/sometimes/never","GPT-3 style","consider always/sometimes/never","guaranteed true","must be true","guaranteed/possible/impossible","does this imply","MNLI crowdsource","should assume","take the following as truth", 58 | "MNLI crowdsource","guaranteed true","can we infer","GPT-3 style","does this imply","should assume","does it follow that","based on the previous passage","justified in saying","must be true", 59 | "does the pronoun refer to","by p they mean","in other words","I think they mean","does p stand for","GPT-3 Style","replaced with","p is/are r","the pronoun refers to","Who or what is/are", 60 | "question-context-meaning-with-label","question-context-meaning","grammar_homework","affirmation_true_or_false","GPT-3-prompt","same_sense","question-context","GPT-3-prompt-with-label","polysemous","similar-sense", 61 | "does underscore refer to","stand for","underscore refer to","fill in the blank","Replace" 62 | ], 63 | "test_dataset_name": ["superglue_glue","superglue_glue"], 64 | "test_dataset_config_name": ["rte","copa"], 65 | "num_train_epochs": 5, 66 | "predict_with_generate": false, 67 | "add_layer_norm_before_adapter": false, 68 | "add_layer_norm_after_adapter": false, 69 | "adapter_config_name": "adapter", 70 | "train_task_adapters": true, 71 | "task_reduction_factor": 32, 72 | "unfreeze_lm_head": false, 73 | "unfreeze_layer_norms": true, 74 | "overwrite_output_dir": true, 75 | "max_train_samples": 10000, 76 | "max_test_samples": 100, 77 | "train_prompts": ["t0_11B"], 78 | "test_prompts": ["can we infer","best_option"] 79 | } -------------------------------------------------------------------------------- /configs/experiment/t0_3B.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 
3 | "do_eval": false, 4 | "do_test": false, 5 | "warmup_steps": 0, 6 | "save_steps": 1000, 7 | "model_name_or_path": "bigscience/T0_3B", 8 | "tokenizer_name": "bigscience/T0_3B", 9 | "save_total_limit": 1, 10 | "per_device_train_batch_size": 1, 11 | "per_device_eval_batch_size": 64, 12 | "load_best_model_at_end": false, 13 | "metric_for_best_model": "rougeL", 14 | "greater_is_better": false, 15 | "evaluation_strategy": "epoch", 16 | "non_linearity": "gelu_new", 17 | "max_source_length": 512, 18 | "learning_rate": 1e-4, 19 | "output_dir": "expert_weights/t0_3B", 20 | "split_validation_test": true, 21 | "task_name": ["t0_3B"], 22 | "adapters_cur_training_task":"t0_3B", 23 | "dataset_config_name": ["none"], 24 | "eval_dataset_name":[ 25 | "hellaswag","hellaswag","hellaswag","hellaswag","hellaswag", 26 | "story_cloze","story_cloze","story_cloze","story_cloze","story_cloze", 27 | "anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1", 28 | "anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2", 29 | "anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3", 30 | "super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa", 31 | "super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb", 32 | "super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte", 33 | 
"super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed", 34 | "super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic", 35 | "winogrande","winogrande","winogrande","winogrande","winogrande" 36 | ], 37 | "eval_dataset_config_name":[ 38 | "none","none","none","none","none", 39 | "none","none","none","none","none", 40 | "dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1", 41 | "dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2", 42 | "dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3", 43 | "copa","copa","copa","copa","copa","copa","copa","copa", 44 | "cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb", 45 | "rte","rte","rte","rte","rte","rte","rte","rte","rte","rte", 46 | "wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed", 47 | "wic","wic","wic","wic","wic","wic","wic","wic","wic","wic", 48 | "winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl" 49 | ], 50 | "data_args.eval_prompts":[ 51 | "complete_first_then","Randomized prompts template","Predict ending with hint","how_ends","if_begins_how_continues", 52 | "Answer Given options","Choose Story Ending","Movie What Happens Next","Story Continuation and Options","Novel Correct Ending", 53 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we 
infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 54 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 55 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 56 | "exercise","i_am_hesitating","plausible_alternatives","C1 or C2? premise, so/because\u2026","best_option","more likely","cause_effect","choose", 57 | "can we infer","based on the previous passage","claim true/false/inconclusive","does it follow that","justified in saying","always/sometimes/never","GPT-3 style","consider always/sometimes/never","guaranteed true","must be true","guaranteed/possible/impossible","does this imply","MNLI crowdsource","should assume","take the following as truth", 58 | "MNLI crowdsource","guaranteed true","can we infer","GPT-3 style","does this imply","should assume","does it follow that","based on the previous passage","justified in saying","must be true", 59 | "does the pronoun refer to","by p they mean","in other words","I think they mean","does p stand for","GPT-3 Style","replaced with","p is/are r","the pronoun refers to","Who or what is/are", 60 | "question-context-meaning-with-label","question-context-meaning","grammar_homework","affirmation_true_or_false","GPT-3-prompt","same_sense","question-context","GPT-3-prompt-with-label","polysemous","similar-sense", 61 | "does 
underscore refer to","stand for","underscore refer to","fill in the blank","Replace" 62 | ], 63 | "test_dataset_name": ["superglue_glue","superglue_glue"], 64 | "test_dataset_config_name": ["rte","copa"], 65 | "num_train_epochs": 5, 66 | "predict_with_generate": false, 67 | "add_layer_norm_before_adapter": false, 68 | "add_layer_norm_after_adapter": false, 69 | "adapter_config_name": "adapter", 70 | "train_task_adapters": true, 71 | "task_reduction_factor": 32, 72 | "unfreeze_lm_head": false, 73 | "unfreeze_layer_norms": true, 74 | "overwrite_output_dir": true, 75 | "max_train_samples": 10000, 76 | "max_test_samples": 100, 77 | "train_prompts": ["t0_3B"], 78 | "test_prompts": ["can we infer","best_option"] 79 | } -------------------------------------------------------------------------------- /configs/experiment/t0_target.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeljang/ELM/27fc2a326cd8846717a022db1cbd94d393420bce/configs/experiment/t0_target.json -------------------------------------------------------------------------------- /configs/experiment/trec/trec11.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 3 | "do_eval": false, 4 | "do_test": false, 5 | "warmup_steps": 0, 6 | "save_steps": 1000, 7 | "model_name_or_path": "google/t5-xl-lm-adapt", 8 | "tokenizer_name": "google/t5-xl-lm-adapt", 9 | "save_total_limit": 1, 10 | "per_device_train_batch_size": 16, 11 | "per_device_eval_batch_size": 16, 12 | "load_best_model_at_end": false, 13 | "metric_for_best_model": "accuracy", 14 | "greater_is_better": false, 15 | "evaluation_strategy": "epoch", 16 | "non_linearity": "gelu_new", 17 | "max_source_length": 256, 18 | "learning_rate": 1e-4, 19 | "output_dir": "expert_weights/trec/trec1", 20 | "split_validation_test": true, 21 | "task_name": ["trec"], 22 | "adapters_cur_training_task":"trec", 23 | "dataset_config_name": 
["none"], 24 | "eval_dataset_name": ["trec", 25 | "hellaswag","hellaswag","hellaswag","hellaswag","hellaswag","hellaswag","hellaswag", 26 | "story_cloze","story_cloze","story_cloze","story_cloze","story_cloze", 27 | "anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1", 28 | "anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2", 29 | "anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3", 30 | "super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa", 31 | "super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb", 32 | "super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte", 33 | "super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed", 34 | "super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic", 35 | "winogrande","winogrande","winogrande","winogrande","winogrande","winogrande" 36 | ], 37 | "eval_dataset_config_name": ["none", 38 | "none","none","none","none","none","none","none", 39 | "none","none","none","none","none", 40 | 
"dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1", 41 | "dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2", 42 | "dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3", 43 | "copa","copa","copa","copa","copa","copa","copa","copa","copa","copa","copa","copa", 44 | "cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb", 45 | "rte","rte","rte","rte","rte","rte","rte","rte","rte","rte", 46 | "wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed", 47 | "wic","wic","wic","wic","wic","wic","wic","wic","wic","wic", 48 | "winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl" 49 | ], 50 | "eval_prompts": ["trec1", 51 | "complete_first_then","Randomized prompts template","Appropriate continuation - Yes or No","Predict ending with hint","Reversed appropriate continuation - Yes or No","how_ends","if_begins_how_continues", 52 | "Answer Given options","Choose Story Ending","Movie What Happens Next","Story Continuation and Options","Novel Correct Ending", 53 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 54 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim 
true/false/inconclusive","guaranteed true", 55 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 56 | "exercise","\u2026What could happen next, C1 or C2?","i_am_hesitating","plausible_alternatives","C1 or C2? premise, so/because\u2026","\u2026As a result, C1 or C2?","best_option","\u2026which may be caused by","more likely","cause_effect","\u2026why? C1 or C2","choose", 57 | "can we infer","based on the previous passage","claim true/false/inconclusive","does it follow that","justified in saying","always/sometimes/never","GPT-3 style","consider always/sometimes/never","guaranteed true","must be true","guaranteed/possible/impossible","does this imply","MNLI crowdsource","should assume","take the following as truth", 58 | "MNLI crowdsource","guaranteed true","can we infer","GPT-3 style","does this imply","should assume","does it follow that","based on the previous passage","justified in saying","must be true", 59 | "does the pronoun refer to","by p they mean","in other words","I think they mean","does p stand for","GPT-3 Style","replaced with","p is/are r","the pronoun refers to","Who or what is/are", 60 | "question-context-meaning-with-label","question-context-meaning","grammar_homework","affirmation_true_or_false","GPT-3-prompt","same_sense","question-context","GPT-3-prompt-with-label","polysemous","similar-sense", 61 | "does underscore refer to","stand for","underscore refer to","fill in the blank","True or False","Replace" 62 | ], 63 | "test_dataset_name": ["superglue_glue","superglue_glue"], 64 | "test_dataset_config_name": ["rte","copa"], 65 | "num_train_epochs": 10, 66 | "predict_with_generate": false, 67 | "add_layer_norm_before_adapter": false, 68 | 
"add_layer_norm_after_adapter": false, 69 | "adapter_config_name": "adapter", 70 | "train_task_adapters": true, 71 | "task_reduction_factor": 32, 72 | "unfreeze_lm_head": false, 73 | "unfreeze_layer_norms": true, 74 | "overwrite_output_dir": true, 75 | "max_val_samples": 300, 76 | "max_test_samples": 100, 77 | "train_prompts": ["trec1"], 78 | "test_prompts": ["can we infer","best_option"] 79 | } -------------------------------------------------------------------------------- /configs/experiment/trec/trec14.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 3 | "do_eval": false, 4 | "do_test": false, 5 | "warmup_steps": 0, 6 | "save_steps": 1000, 7 | "model_name_or_path": "google/t5-xl-lm-adapt", 8 | "tokenizer_name": "google/t5-xl-lm-adapt", 9 | "save_total_limit": 1, 10 | "per_device_train_batch_size": 16, 11 | "per_device_eval_batch_size": 16, 12 | "load_best_model_at_end": false, 13 | "metric_for_best_model": "accuracy", 14 | "greater_is_better": false, 15 | "evaluation_strategy": "epoch", 16 | "non_linearity": "gelu_new", 17 | "max_source_length": 256, 18 | "learning_rate": 1e-4, 19 | "output_dir": "expert_weights/trec/trec2", 20 | "split_validation_test": true, 21 | "task_name": ["trec"], 22 | "adapters_cur_training_task":"trec", 23 | "dataset_config_name": ["none"], 24 | "eval_dataset_name": ["trec", 25 | "hellaswag","hellaswag","hellaswag","hellaswag","hellaswag","hellaswag","hellaswag", 26 | "story_cloze","story_cloze","story_cloze","story_cloze","story_cloze", 27 | "anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1", 28 | "anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2", 29 | 
"anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3", 30 | "super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa", 31 | "super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb", 32 | "super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte", 33 | "super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed", 34 | "super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic", 35 | "winogrande","winogrande","winogrande","winogrande","winogrande","winogrande" 36 | ], 37 | "eval_dataset_config_name": ["none", 38 | "none","none","none","none","none","none","none", 39 | "none","none","none","none","none", 40 | "dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1", 41 | "dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2", 42 | "dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3", 43 | "copa","copa","copa","copa","copa","copa","copa","copa","copa","copa","copa","copa", 44 | 
"cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb", 45 | "rte","rte","rte","rte","rte","rte","rte","rte","rte","rte", 46 | "wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed", 47 | "wic","wic","wic","wic","wic","wic","wic","wic","wic","wic", 48 | "winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl" 49 | ], 50 | "eval_prompts": ["trec2", 51 | "complete_first_then","Randomized prompts template","Appropriate continuation - Yes or No","Predict ending with hint","Reversed appropriate continuation - Yes or No","how_ends","if_begins_how_continues", 52 | "Answer Given options","Choose Story Ending","Movie What Happens Next","Story Continuation and Options","Novel Correct Ending", 53 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 54 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 55 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 56 | "exercise","\u2026What could happen next, C1 or C2?","i_am_hesitating","plausible_alternatives","C1 or C2? 
premise, so/because\u2026","\u2026As a result, C1 or C2?","best_option","\u2026which may be caused by","more likely","cause_effect","\u2026why? C1 or C2","choose", 57 | "can we infer","based on the previous passage","claim true/false/inconclusive","does it follow that","justified in saying","always/sometimes/never","GPT-3 style","consider always/sometimes/never","guaranteed true","must be true","guaranteed/possible/impossible","does this imply","MNLI crowdsource","should assume","take the following as truth", 58 | "MNLI crowdsource","guaranteed true","can we infer","GPT-3 style","does this imply","should assume","does it follow that","based on the previous passage","justified in saying","must be true", 59 | "does the pronoun refer to","by p they mean","in other words","I think they mean","does p stand for","GPT-3 Style","replaced with","p is/are r","the pronoun refers to","Who or what is/are", 60 | "question-context-meaning-with-label","question-context-meaning","grammar_homework","affirmation_true_or_false","GPT-3-prompt","same_sense","question-context","GPT-3-prompt-with-label","polysemous","similar-sense", 61 | "does underscore refer to","stand for","underscore refer to","fill in the blank","True or False","Replace" 62 | ], 63 | "test_dataset_name": ["superglue_glue","superglue_glue"], 64 | "test_dataset_config_name": ["rte","copa"], 65 | "num_train_epochs": 10, 66 | "predict_with_generate": false, 67 | "add_layer_norm_before_adapter": false, 68 | "add_layer_norm_after_adapter": false, 69 | "adapter_config_name": "adapter", 70 | "train_task_adapters": true, 71 | "task_reduction_factor": 32, 72 | "unfreeze_lm_head": false, 73 | "unfreeze_layer_norms": true, 74 | "overwrite_output_dir": true, 75 | "max_val_samples": 300, 76 | "max_test_samples": 100, 77 | "train_prompts": ["trec2"], 78 | "test_prompts": ["can we infer","best_option"] 79 | } -------------------------------------------------------------------------------- /configs/merge/ST_label_instance_top1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "top_k_merge": 1, 3 | "log_dir": "output_logs_ST_label_instance_top1k", 4 | "expert_mapping_dir": "merge/ST_choice_instance_top1_32.json", 5 | "do_train": true, 6 | "do_eval": false, 7 | "do_test": false, 8 | "warmup_steps": 0, 9 | "save_steps": 1000, 10 | "model_name_or_path": "google/t5-xl-lm-adapt", 11 | "tokenizer_name": "google/t5-xl-lm-adapt", 12 | "save_total_limit": 1, 13 | "per_device_train_batch_size": 4, 14 | "per_device_eval_batch_size": 4, 15 | "load_best_model_at_end": false, 16 | "metric_for_best_model": "rougeL", 17 | "greater_is_better": false, 18 | "evaluation_strategy": "epoch", 19 | "non_linearity": "gelu_new", 20 | "max_source_length": 512, 21 | "learning_rate": 1e-4, 22 | "output_dir": "expert_weights/common_gen/Given concepts - type 2", 23 | "split_validation_test": true, 24 | "task_name": ["common_gen"], 25 | "adapters_cur_training_task":"common_gen", 26 | "dataset_config_name": ["none"], 27 | "eval_dataset_name": [ 28 | "hellaswag","hellaswag","hellaswag","hellaswag","hellaswag", 29 | "story_cloze","story_cloze","story_cloze","story_cloze","story_cloze", 30 | "anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1", 31 | "anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2", 32 | "anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3", 33 | "super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa", 34 | 
"super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb", 35 | "super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte", 36 | "super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed", 37 | "super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic", 38 | "winogrande","winogrande","winogrande","winogrande","winogrande" 39 | ], 40 | "eval_dataset_config_name": [ 41 | "none","none","none","none","none", 42 | "none","none","none","none","none", 43 | "dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1", 44 | "dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2", 45 | "dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3", 46 | "copa","copa","copa","copa","copa","copa","copa","copa", 47 | "cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb", 48 | "rte","rte","rte","rte","rte","rte","rte","rte","rte","rte", 49 | "wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed", 50 | "wic","wic","wic","wic","wic","wic","wic","wic","wic","wic", 51 | "winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl" 52 | ], 53 | "eval_prompts": [ 54 | 
"complete_first_then","Randomized prompts template", "Predict ending with hint", "how_ends","if_begins_how_continues", 55 | "Answer Given options","Choose Story Ending","Movie What Happens Next","Story Continuation and Options","Novel Correct Ending", 56 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 57 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 58 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 59 | "exercise","i_am_hesitating","plausible_alternatives","C1 or C2? 
premise, so/because\u2026","best_option","more likely","cause_effect", "choose", 60 | "can we infer","based on the previous passage","claim true/false/inconclusive","does it follow that","justified in saying","always/sometimes/never","GPT-3 style","consider always/sometimes/never","guaranteed true","must be true","guaranteed/possible/impossible","does this imply","MNLI crowdsource","should assume","take the following as truth", 61 | "MNLI crowdsource","guaranteed true","can we infer","GPT-3 style","does this imply","should assume","does it follow that","based on the previous passage","justified in saying","must be true", 62 | "does the pronoun refer to","by p they mean","in other words","I think they mean","does p stand for","GPT-3 Style","replaced with","p is/are r","the pronoun refers to","Who or what is/are", 63 | "question-context-meaning-with-label","question-context-meaning","grammar_homework","affirmation_true_or_false","GPT-3-prompt","same_sense","question-context","GPT-3-prompt-with-label","polysemous","similar-sense", 64 | "does underscore refer to","stand for","underscore refer to","fill in the blank","Replace" 65 | ], 66 | "test_dataset_name": ["superglue_glue","superglue_glue"], 67 | "test_dataset_config_name": ["rte","copa"], 68 | "num_train_epochs": 5, 69 | "predict_with_generate": false, 70 | "add_layer_norm_before_adapter": false, 71 | "add_layer_norm_after_adapter": false, 72 | "adapter_config_name": "adapter", 73 | "train_task_adapters": true, 74 | "task_reduction_factor": 32, 75 | "unfreeze_lm_head": false, 76 | "unfreeze_layer_norms": true, 77 | "overwrite_output_dir": true, 78 | "max_train_samples": 50000, 79 | "max_val_samples": 8000, 80 | "max_test_samples": 100, 81 | "train_prompts": ["Given concepts - type 2"], 82 | "test_prompts": ["can we infer","best_option"] 83 | } -------------------------------------------------------------------------------- /configs/merge/ST_label_instance_top3.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "top_k_merge": 3, 3 | "log_dir": "output_logs_ST_label_instance_top3", 4 | "expert_mapping_dir": "merge/ST_choice_instance_top1_32.json", 5 | "do_train": true, 6 | "do_eval": false, 7 | "do_test": false, 8 | "warmup_steps": 0, 9 | "save_steps": 1000, 10 | "model_name_or_path": "google/t5-xl-lm-adapt", 11 | "tokenizer_name": "google/t5-xl-lm-adapt", 12 | "save_total_limit": 1, 13 | "per_device_train_batch_size": 4, 14 | "per_device_eval_batch_size": 4, 15 | "load_best_model_at_end": false, 16 | "metric_for_best_model": "rougeL", 17 | "greater_is_better": false, 18 | "evaluation_strategy": "epoch", 19 | "non_linearity": "gelu_new", 20 | "max_source_length": 512, 21 | "learning_rate": 1e-4, 22 | "output_dir": "expert_weights/common_gen/Given concepts - type 2", 23 | "split_validation_test": true, 24 | "task_name": ["common_gen"], 25 | "adapters_cur_training_task":"common_gen", 26 | "dataset_config_name": ["none"], 27 | "eval_dataset_name": [ 28 | "hellaswag","hellaswag","hellaswag","hellaswag","hellaswag", 29 | "story_cloze","story_cloze","story_cloze","story_cloze","story_cloze", 30 | "anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1", 31 | "anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2", 32 | "anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3", 33 | "super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa", 34 | 
"super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb", 35 | "super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte", 36 | "super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed", 37 | "super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic", 38 | "winogrande","winogrande","winogrande","winogrande","winogrande" 39 | ], 40 | "eval_dataset_config_name": [ 41 | "none","none","none","none","none", 42 | "none","none","none","none","none", 43 | "dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1", 44 | "dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2", 45 | "dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3", 46 | "copa","copa","copa","copa","copa","copa","copa","copa", 47 | "cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb", 48 | "rte","rte","rte","rte","rte","rte","rte","rte","rte","rte", 49 | "wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed", 50 | "wic","wic","wic","wic","wic","wic","wic","wic","wic","wic", 51 | "winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl" 52 | ], 53 | "eval_prompts": [ 54 | 
"complete_first_then","Randomized prompts template", "Predict ending with hint", "how_ends","if_begins_how_continues", 55 | "Answer Given options","Choose Story Ending","Movie What Happens Next","Story Continuation and Options","Novel Correct Ending", 56 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 57 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 58 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 59 | "exercise","i_am_hesitating","plausible_alternatives","C1 or C2? 
premise, so/because\u2026","best_option","more likely","cause_effect", "choose", 60 | "can we infer","based on the previous passage","claim true/false/inconclusive","does it follow that","justified in saying","always/sometimes/never","GPT-3 style","consider always/sometimes/never","guaranteed true","must be true","guaranteed/possible/impossible","does this imply","MNLI crowdsource","should assume","take the following as truth", 61 | "MNLI crowdsource","guaranteed true","can we infer","GPT-3 style","does this imply","should assume","does it follow that","based on the previous passage","justified in saying","must be true", 62 | "does the pronoun refer to","by p they mean","in other words","I think they mean","does p stand for","GPT-3 Style","replaced with","p is/are r","the pronoun refers to","Who or what is/are", 63 | "question-context-meaning-with-label","question-context-meaning","grammar_homework","affirmation_true_or_false","GPT-3-prompt","same_sense","question-context","GPT-3-prompt-with-label","polysemous","similar-sense", 64 | "does underscore refer to","stand for","underscore refer to","fill in the blank","Replace" 65 | ], 66 | "test_dataset_name": ["superglue_glue","superglue_glue"], 67 | "test_dataset_config_name": ["rte","copa"], 68 | "num_train_epochs": 5, 69 | "predict_with_generate": false, 70 | "add_layer_norm_before_adapter": false, 71 | "add_layer_norm_after_adapter": false, 72 | "adapter_config_name": "adapter", 73 | "train_task_adapters": true, 74 | "task_reduction_factor": 32, 75 | "unfreeze_lm_head": false, 76 | "unfreeze_layer_norms": true, 77 | "overwrite_output_dir": true, 78 | "max_train_samples": 50000, 79 | "max_val_samples": 8000, 80 | "max_test_samples": 100, 81 | "train_prompts": ["Given concepts - type 2"], 82 | "test_prompts": ["can we infer","best_option"] 83 | } -------------------------------------------------------------------------------- /configs/merge/oracle_top1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "top_k_merge": 1, 3 | "txt_save_dir": "output_logs_oracle", 4 | "expert_mapping_dir": "merge/oracle_sorted.json", 5 | "do_train": true, 6 | "do_eval": false, 7 | "do_test": false, 8 | "warmup_steps": 0, 9 | "save_steps": 1000, 10 | "model_name_or_path": "google/t5-xl-lm-adapt", 11 | "tokenizer_name": "google/t5-xl-lm-adapt", 12 | "save_total_limit": 1, 13 | "per_device_train_batch_size": 4, 14 | "per_device_eval_batch_size": 64, 15 | "load_best_model_at_end": false, 16 | "metric_for_best_model": "rougeL", 17 | "greater_is_better": false, 18 | "evaluation_strategy": "epoch", 19 | "non_linearity": "gelu_new", 20 | "max_source_length": 256, 21 | "learning_rate": 1e-4, 22 | "output_dir": "expert_weights/common_gen/Given concepts - type 2", 23 | "split_validation_test": true, 24 | "task_name": ["common_gen"], 25 | "adapters_cur_training_task":"common_gen", 26 | "dataset_config_name": ["none"], 27 | "eval_dataset_name": [ 28 | "hellaswag","hellaswag","hellaswag","hellaswag","hellaswag", 29 | "story_cloze","story_cloze","story_cloze","story_cloze","story_cloze", 30 | "anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1", 31 | "anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2", 32 | "anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3", 33 | "super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa", 34 | 
"super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb", 35 | "super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte", 36 | "super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed", 37 | "super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic", 38 | "winogrande","winogrande","winogrande","winogrande","winogrande" 39 | ], 40 | "eval_dataset_config_name": [ 41 | "none","none","none","none","none", 42 | "none","none","none","none","none", 43 | "dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1", 44 | "dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2", 45 | "dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3", 46 | "copa","copa","copa","copa","copa","copa","copa","copa", 47 | "cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb", 48 | "rte","rte","rte","rte","rte","rte","rte","rte","rte","rte", 49 | "wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed", 50 | "wic","wic","wic","wic","wic","wic","wic","wic","wic","wic", 51 | "winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl" 52 | ], 53 | "eval_prompts": [ 54 | 
"complete_first_then","Randomized prompts template", "Predict ending with hint", "how_ends","if_begins_how_continues", 55 | "Answer Given options","Choose Story Ending","Movie What Happens Next","Story Continuation and Options","Novel Correct Ending", 56 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 57 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 58 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 59 | "exercise","i_am_hesitating","plausible_alternatives","C1 or C2? 
premise, so/because\u2026","best_option","more likely","cause_effect", "choose", 60 | "can we infer","based on the previous passage","claim true/false/inconclusive","does it follow that","justified in saying","always/sometimes/never","GPT-3 style","consider always/sometimes/never","guaranteed true","must be true","guaranteed/possible/impossible","does this imply","MNLI crowdsource","should assume","take the following as truth", 61 | "MNLI crowdsource","guaranteed true","can we infer","GPT-3 style","does this imply","should assume","does it follow that","based on the previous passage","justified in saying","must be true", 62 | "does the pronoun refer to","by p they mean","in other words","I think they mean","does p stand for","GPT-3 Style","replaced with","p is/are r","the pronoun refers to","Who or what is/are", 63 | "question-context-meaning-with-label","question-context-meaning","grammar_homework","affirmation_true_or_false","GPT-3-prompt","same_sense","question-context","GPT-3-prompt-with-label","polysemous","similar-sense", 64 | "does underscore refer to","stand for","underscore refer to","fill in the blank","Replace" 65 | ], 66 | "test_dataset_name": ["superglue_glue","superglue_glue"], 67 | "test_dataset_config_name": ["rte","copa"], 68 | "num_train_epochs": 5, 69 | "predict_with_generate": false, 70 | "add_layer_norm_before_adapter": false, 71 | "add_layer_norm_after_adapter": false, 72 | "adapter_config_name": "adapter", 73 | "train_task_adapters": true, 74 | "task_reduction_factor": 32, 75 | "unfreeze_lm_head": false, 76 | "unfreeze_layer_norms": true, 77 | "overwrite_output_dir": true, 78 | "max_train_samples": 50000, 79 | "max_val_samples": 300, 80 | "max_test_samples": 100, 81 | "train_prompts": ["Given concepts - type 2"], 82 | "test_prompts": ["can we infer","best_option"] 83 | } -------------------------------------------------------------------------------- /configs/merge/oracle_top3.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "top_k_merge": 3, 3 | "log_dir": "output_logs_oracle_top3", 4 | "expert_mapping_dir": "merge/oracle_sorted.json", 5 | "do_train": true, 6 | "do_eval": false, 7 | "do_test": false, 8 | "warmup_steps": 0, 9 | "save_steps": 1000, 10 | "model_name_or_path": "google/t5-xl-lm-adapt", 11 | "tokenizer_name": "google/t5-xl-lm-adapt", 12 | "save_total_limit": 1, 13 | "per_device_train_batch_size": 4, 14 | "per_device_eval_batch_size": 4, 15 | "load_best_model_at_end": false, 16 | "metric_for_best_model": "rougeL", 17 | "greater_is_better": false, 18 | "evaluation_strategy": "epoch", 19 | "non_linearity": "gelu_new", 20 | "max_source_length": 512, 21 | "learning_rate": 1e-4, 22 | "output_dir": "expert_weights/common_gen/Given concepts - type 2", 23 | "split_validation_test": true, 24 | "task_name": ["common_gen"], 25 | "adapters_cur_training_task":"common_gen", 26 | "dataset_config_name": ["none"], 27 | "eval_dataset_name": [ 28 | "hellaswag","hellaswag","hellaswag","hellaswag","hellaswag", 29 | "story_cloze","story_cloze","story_cloze","story_cloze","story_cloze", 30 | "anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1","anli_r1", 31 | "anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2","anli_r2", 32 | "anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3","anli_r3", 33 | "super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa","super_glue_copa", 34 | 
"super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb","super_glue_cb", 35 | "super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte","super_glue_rte", 36 | "super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed","super_glue_wsc.fixed", 37 | "super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic","super_glue_wic", 38 | "winogrande","winogrande","winogrande","winogrande","winogrande" 39 | ], 40 | "eval_dataset_config_name": [ 41 | "none","none","none","none","none", 42 | "none","none","none","none","none", 43 | "dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1","dev_r1", 44 | "dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2","dev_r2", 45 | "dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3","dev_r3", 46 | "copa","copa","copa","copa","copa","copa","copa","copa", 47 | "cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb","cb", 48 | "rte","rte","rte","rte","rte","rte","rte","rte","rte","rte", 49 | "wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed","wsc.fixed", 50 | "wic","wic","wic","wic","wic","wic","wic","wic","wic","wic", 51 | "winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl","winogrande_xl" 52 | ], 53 | "eval_prompts": [ 54 | 
"complete_first_then","Randomized prompts template", "Predict ending with hint", "how_ends","if_begins_how_continues", 55 | "Answer Given options","Choose Story Ending","Movie What Happens Next","Story Continuation and Options","Novel Correct Ending", 56 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 57 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 58 | "MNLI crowdsource","should assume","does it follow that","GPT-3 style","based on the previous passage","justified in saying","take the following as truth","must be true","can we infer","guaranteed/possible/impossible","always/sometimes/never","does this imply","consider always/sometimes/never","claim true/false/inconclusive","guaranteed true", 59 | "exercise","i_am_hesitating","plausible_alternatives","C1 or C2? 
premise, so/because\u2026","best_option","more likely","cause_effect", "choose", 60 | "can we infer","based on the previous passage","claim true/false/inconclusive","does it follow that","justified in saying","always/sometimes/never","GPT-3 style","consider always/sometimes/never","guaranteed true","must be true","guaranteed/possible/impossible","does this imply","MNLI crowdsource","should assume","take the following as truth", 61 | "MNLI crowdsource","guaranteed true","can we infer","GPT-3 style","does this imply","should assume","does it follow that","based on the previous passage","justified in saying","must be true", 62 | "does the pronoun refer to","by p they mean","in other words","I think they mean","does p stand for","GPT-3 Style","replaced with","p is/are r","the pronoun refers to","Who or what is/are", 63 | "question-context-meaning-with-label","question-context-meaning","grammar_homework","affirmation_true_or_false","GPT-3-prompt","same_sense","question-context","GPT-3-prompt-with-label","polysemous","similar-sense", 64 | "does underscore refer to","stand for","underscore refer to","fill in the blank","Replace" 65 | ], 66 | "test_dataset_name": ["superglue_glue","superglue_glue"], 67 | "test_dataset_config_name": ["rte","copa"], 68 | "num_train_epochs": 5, 69 | "predict_with_generate": false, 70 | "add_layer_norm_before_adapter": false, 71 | "add_layer_norm_after_adapter": false, 72 | "adapter_config_name": "adapter", 73 | "train_task_adapters": true, 74 | "task_reduction_factor": 32, 75 | "unfreeze_lm_head": false, 76 | "unfreeze_layer_norms": true, 77 | "overwrite_output_dir": true, 78 | "max_train_samples": 50000, 79 | "max_val_samples": 8000, 80 | "max_test_samples": 100, 81 | "train_prompts": ["Given concepts - type 2"], 82 | "test_prompts": ["can we infer","best_option"] 83 | } -------------------------------------------------------------------------------- /figure1.PNG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeljang/ELM/27fc2a326cd8846717a022db1cbd94d393420bce/figure1.PNG -------------------------------------------------------------------------------- /figure3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeljang/ELM/27fc2a326cd8846717a022db1cbd94d393420bce/figure3.PNG -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | rouge-score 2 | gitpython 3 | sacrebleu 4 | pyarrow 5 | pandas 6 | dill 7 | urllib3 8 | charset_normalizer 9 | idna 10 | wandb 11 | scikit-learn==0.24.2 12 | datasets==1.6.2 13 | huggingface-hub==0.0.8 14 | attrs 15 | transformers==4.6.0 16 | protobuf==3.20.0 17 | jinja2 18 | rank_bm25 -------------------------------------------------------------------------------- /retrieval/additional_code/get_oracle_list.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | 4 | df = pd.read_csv('experts.csv') 5 | eval_datasets = df.columns[2:] 6 | 7 | train_datasets = [] 8 | train_prompts = [] 9 | train_dataset_rows = [] 10 | for index,row in df.iterrows(): 11 | if index==0: 12 | eval_dataset_prompts = row[2:] 13 | else: 14 | train_dataset = row['Dataset'] 15 | train_dataset_prompt = row['Prompt'] 16 | train_dataset_rows.append(row[2:]) 17 | train_datasets.append(train_dataset) 18 | train_prompts.append(train_dataset_prompt) 19 | 20 | oracle_sorted = {} 21 | 22 | evals = ['hellaswag', 'story_cloze/2016', 'anli_dev_r1', 'anli_dev_r2', 'anli_dev_r3', 'super_glue/copa', 'super_glue/cb', 'super_glue/rte', 'super_glue/wsc.fixed', 'super_glue/wic', 'winogrande/winogrande_xl'] 23 | # Evaluate each evaluation prompt 24 | for i in range(len(eval_datasets)): 25 | eval_dataset = 
eval_datasets[i] 26 | eval_prompt = eval_dataset_prompts[i] 27 | # Getting the exact dataset name: 28 | for name in evals: 29 | if name in eval_dataset: 30 | eval_dataset = name 31 | max_val = 0 32 | expert_results = {} 33 | for j in range(len(train_datasets)): 34 | train_dataset = train_datasets[j] 35 | train_prompt = train_prompts[j] 36 | key = f'{train_dataset}@{train_prompt}' 37 | val = train_dataset_rows[j][i] 38 | expert_results[key] = val 39 | #print(expert_results) 40 | sorted_expert_results = sorted(expert_results, key=lambda k: float(expert_results[k]), reverse=True) 41 | print(f'{eval_dataset}@{eval_prompt}') 42 | oracle_sorted[f'{eval_dataset}@{eval_prompt}'] = sorted_expert_results 43 | 44 | with open(f'oracle_sorted.json', "w") as outfile: 45 | json.dump(oracle_sorted, outfile) -------------------------------------------------------------------------------- /retrieval/additional_code/get_oracle_list_task.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | 4 | df = pd.read_csv('retrieval_data/dataset_experts.csv') 5 | eval_datasets = df.columns[1:] 6 | 7 | train_datasets = [] 8 | train_prompts = [] 9 | train_dataset_rows = [] 10 | for index,row in df.iterrows(): 11 | if index==0: 12 | eval_dataset_prompts = row[1:] 13 | else: 14 | train_dataset = row['Dataset'] 15 | train_dataset_rows.append(row[1:]) 16 | train_datasets.append(train_dataset) 17 | 18 | oracle_sorted = {} 19 | 20 | evals = ['hellaswag', 'story_cloze/2016', 'anli_dev_r1', 'anli_dev_r2', 'anli_dev_r3', 'super_glue/copa', 'super_glue/cb', 'super_glue/rte', 'super_glue/wsc.fixed', 'super_glue/wic', 'winogrande/winogrande_xl'] 21 | # Evaluate each evaluation prompt 22 | for i in range(len(eval_datasets)): 23 | eval_dataset = eval_datasets[i] 24 | eval_prompt = eval_dataset_prompts[i] 25 | # Getting the exact dataset name: 26 | for name in evals: 27 | if name in eval_dataset: 28 | eval_dataset = name 29 | max_val = 0 30 | 
expert_results = {} 31 | for j in range(len(train_datasets)): 32 | train_dataset = train_datasets[j] 33 | key = f'{train_dataset}' 34 | val = train_dataset_rows[j][i] 35 | expert_results[key] = val 36 | #print(expert_results) 37 | sorted_expert_results = sorted(expert_results, key=lambda k: float(expert_results[k]), reverse=True) 38 | print(f'{eval_dataset}@{eval_prompt}') 39 | oracle_sorted[f'{eval_dataset}@{eval_prompt}'] = sorted_expert_results 40 | 41 | with open(f'results/oracle_sorted_dataset_level.json', "w") as outfile: 42 | json.dump(oracle_sorted, outfile) -------------------------------------------------------------------------------- /retrieval/additional_code/run.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, TrainerCallback 3 | from promptsource.promptsource.templates import DatasetTemplates 4 | from sentence_transformers import SentenceTransformer 5 | import numpy as np 6 | import random 7 | from rank_bm25 import BM25Okapi 8 | from statistics import mean 9 | import re 10 | import datasets 11 | 12 | # Setting the method to use 13 | METHOD = 'ST' 14 | #METHOD = 'BM25' 15 | #TEXT_FORMAT = 'prompt' 16 | TEXT_FORMAT = 'label_prompt' 17 | #TEXT_FORMAT = 'label' 18 | #TEXT_FORMAT = 'label@prompt' 19 | 20 | def clease_jinja(jinja): 21 | jinja = re.sub(r'{{.+?}}', '', jinja) 22 | jinja = re.sub(r'{.+?}', '', jinja) 23 | jinja = jinja.replace("|||", "") 24 | jinja = jinja.replace("\n", "") 25 | return jinja 26 | 27 | # Setting the text format to use for retrieval (task level) 28 | def get_text_format(prompt, hard_prompt, answer_choices): 29 | if TEXT_FORMAT=='prompt': 30 | text = f"Prompt: {hard_prompt}" 31 | elif TEXT_FORMAT=='label_prompt': 32 | text = f"Answer Choices: {answer_choices}, Prompt: {hard_prompt}" 33 | elif TEXT_FORMAT=='label': 34 | text = f"Answer Choices: {answer_choices}" 35 | else: 36 | raise 
Exception('Select the correct TEXT_FORMAT') 37 | return text 38 | 39 | from math import sqrt, pow, exp 40 | def squared_sum(x): 41 | """ return 3 rounded square rooted value """ 42 | return round(sqrt(sum([a*a for a in x])),3) 43 | 44 | def euclidean_distance(x,y): 45 | """ return euclidean distance between two lists """ 46 | return sqrt(sum(pow(a-b,2) for a, b in zip(x, y))) 47 | 48 | def cos_similarity(x,y): 49 | """ return cosine similarity between two lists """ 50 | numerator = sum(a*b for a,b in zip(x,y)) 51 | denominator = squared_sum(x)*squared_sum(y) 52 | return round(numerator/float(denominator),3) 53 | 54 | df = pd.read_csv('experts.csv') 55 | eval_datasets = df.columns[2:] 56 | 57 | train_datasets = [] 58 | train_prompts = [] 59 | train_dataset_rows = [] 60 | for index,row in df.iterrows(): 61 | if index==0: 62 | eval_dataset_prompts = row[2:] 63 | else: 64 | train_dataset = row['Dataset'] 65 | train_dataset_prompt = row['Prompt'] 66 | train_dataset_rows.append(row[2:]) 67 | train_datasets.append(train_dataset) 68 | train_prompts.append(train_dataset_prompt) 69 | 70 | sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2') 71 | 72 | print(f'Length of train prompts: {len(train_prompts)}') 73 | print(f'Length of eval prompts: {len(eval_datasets)}') 74 | 75 | # Make the Training Index 76 | index = {} 77 | index_values = {} 78 | text_to_prompt = {} 79 | corpus = [] 80 | for i in range(len(train_prompts)): 81 | dataset = train_datasets[i] 82 | prompt = train_prompts[i] 83 | promptList = DatasetTemplates(dataset)[prompt] 84 | hard_prompt = clease_jinja(promptList.jinja) 85 | answer_choices = promptList.answer_choices 86 | text = get_text_format(prompt, hard_prompt, answer_choices) # setting format for retrieval 87 | corpus.append(text) 88 | embedding = sentence_transformer.encode(text) 89 | index[f'{dataset},{prompt}'] = embedding 90 | index_values[f'{dataset},{prompt}'] = train_dataset_rows[i] 91 | text_to_prompt[text] = f'{dataset},{prompt}' 92 | 93 | 
#index_values['testing,prompt'] = 0 94 | #index['testing,prompt'] = sentence_transformer.encode('Replace the _ in the above sentence with the correct option: - -') 95 | 96 | evals = ['hellaswag', 'story_cloze/2016', 'anli_dev_r1', 'anli_dev_r2', 'anli_dev_r3', 'super_glue/copa', 'super_glue/cb', 'super_glue/rte', 'super_glue/wsc.fixed', 'super_glue/wic', 'winogrande/winogrande_xl'] 97 | results = [] 98 | index_keys = list(index.keys()) 99 | tokenized_corpus = [doc.split(" ") for doc in corpus] 100 | bm25 = BM25Okapi(tokenized_corpus) 101 | 102 | # Evaluate each evaluation prompt 103 | for i in range(len(eval_datasets)): 104 | dataset = eval_datasets[i] 105 | prompt = eval_dataset_prompts[i] 106 | # Getting the exact dataset name: 107 | for name in evals: 108 | if name in dataset: 109 | dataset = name 110 | if 'anli_dev' in dataset: 111 | dataset = 'anli' 112 | promptList = DatasetTemplates(dataset)[prompt] 113 | hard_prompt = clease_jinja(promptList.jinja) 114 | answer_choices = promptList.answer_choices 115 | text = get_text_format(prompt, hard_prompt, answer_choices) # setting format for retrieval 116 | 117 | if METHOD=='ST': 118 | embedding = sentence_transformer.encode(text) 119 | max = 0 120 | products = [] 121 | for key in index: 122 | key_embedding = index[key] 123 | #dot_product = np.dot(key_embedding, embedding) 124 | #dot_product = euclidean_distance(key_embedding, embedding) 125 | dot_product = cos_similarity(key_embedding, embedding) 126 | products.append(dot_product) 127 | if dot_product > max: 128 | max = dot_product 129 | max_prompt = key 130 | 131 | elif METHOD=='BM25': 132 | tokenized_query = text.split(" ") 133 | retrieved_hard_prompt = bm25.get_top_n(tokenized_query, corpus, n=1)[0] 134 | max_prompt = text_to_prompt[retrieved_hard_prompt] 135 | elif METHOD=='random': 136 | max_prompt = random.choice(index_keys) 137 | else: 138 | raise Exception('Choose the correct METHOD') 139 | r_dataset = max_prompt.split(',')[0] 140 | r_prompt = 
max_prompt.split(',')[1:] 141 | r_prompt = "".join(r_prompt) 142 | r_promptList = DatasetTemplates(r_dataset)[r_prompt] 143 | r_hard_prompt = clease_jinja(r_promptList.jinja) 144 | r_answer_choices = r_promptList.answer_choices 145 | print(f'eval dataset: {dataset}, eval prompt: {prompt}, answer choices: {answer_choices}') 146 | print(f'retrieved: {max_prompt}, hard_prompt: {r_answer_choices}\n') 147 | val = index_values[max_prompt][i] 148 | results.append(val) 149 | 150 | pd.DataFrame([results]).to_csv(f'{METHOD}_{TEXT_FORMAT}.csv') 151 | exit() 152 | 153 | print(results) 154 | expert_nums = [5, 5, 15, 15, 15, 8, 15, 10, 10, 10, 5] 155 | eval_results = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 156 | cnt = 0 157 | indx = 0 158 | val_sum = 0 159 | for i in range(len(results)): 160 | val = results[i] 161 | val_sum += float(val) 162 | cnt+=1 163 | if cnt == expert_nums[indx]: 164 | avg = val_sum / cnt 165 | eval_results[indx] = avg 166 | cnt=0 167 | val_sum=0 168 | indx+=1 169 | print(f'{METHOD}_{TEXT_FORMAT}') 170 | print(eval_results) 171 | print(f'AVG: {mean(eval_results)}') 172 | -------------------------------------------------------------------------------- /retrieval/additional_code/run_selected.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, TrainerCallback 3 | from promptsource.promptsource.templates import DatasetTemplates 4 | from sentence_transformers import SentenceTransformer 5 | import numpy as np 6 | import random 7 | from rank_bm25 import BM25Okapi 8 | from statistics import mean 9 | import re 10 | import datasets 11 | import json 12 | 13 | # Setting the method to use 14 | METHOD = 'ST' 15 | #METHOD = 'BM25' 16 | #TEXT_FORMAT = 'prompt' 17 | TEXT_FORMAT = 'label_prompt' 18 | #TEXT_FORMAT = 'label' 19 | #TEXT_FORMAT = 'label@prompt' 20 | 21 | merge_dir = "topk_merge_mappings/ST_choice_instance_top1_32.json" 22 | 
merge_mapping = json.load(open(merge_dir,"r")) 23 | 24 | def clease_jinja(jinja): 25 | jinja = re.sub(r'{{.+?}}', '', jinja) 26 | jinja = re.sub(r'{.+?}', '', jinja) 27 | jinja = jinja.replace("|||", "") 28 | jinja = jinja.replace("\n", "") 29 | return jinja 30 | 31 | # Setting the text format to use for retrieval (task level) 32 | def get_text_format(prompt, hard_prompt, answer_choices): 33 | if TEXT_FORMAT=='prompt': 34 | text = f"Prompt: {hard_prompt}" 35 | elif TEXT_FORMAT=='label_prompt': 36 | text = f"Answer Choices: {answer_choices}, Prompt: {hard_prompt}" 37 | elif TEXT_FORMAT=='label': 38 | text = f"Answer Choices: {answer_choices}" 39 | else: 40 | raise Exception('Select the correct TEXT_FORMAT') 41 | return text 42 | 43 | from math import sqrt, pow, exp 44 | def squared_sum(x): 45 | """ return 3 rounded square rooted value """ 46 | return round(sqrt(sum([a*a for a in x])),3) 47 | 48 | def euclidean_distance(x,y): 49 | """ return euclidean distance between two lists """ 50 | return sqrt(sum(pow(a-b,2) for a, b in zip(x, y))) 51 | 52 | def cos_similarity(x,y): 53 | """ return cosine similarity between two lists """ 54 | numerator = sum(a*b for a,b in zip(x,y)) 55 | denominator = squared_sum(x)*squared_sum(y) 56 | return round(numerator/float(denominator),3) 57 | 58 | df = pd.read_csv('experts.csv') 59 | eval_datasets = df.columns[2:] 60 | 61 | train_datasets = [] 62 | train_prompts = [] 63 | train_dataset_rows = [] 64 | for index,row in df.iterrows(): 65 | if index==0: 66 | eval_dataset_prompts = row[2:] 67 | else: 68 | train_dataset = row['Dataset'] 69 | train_dataset_prompt = row['Prompt'] 70 | train_dataset_rows.append(row[2:]) 71 | train_datasets.append(train_dataset) 72 | train_prompts.append(train_dataset_prompt) 73 | 74 | sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2') 75 | 76 | print(f'Length of train prompts: {len(train_prompts)}') 77 | print(f'Length of eval prompts: {len(eval_datasets)}') 78 | 79 | # Make the Training Index 80 | 
index = {} 81 | index_values = {} 82 | text_to_prompt = {} 83 | corpus = [] 84 | for i in range(len(train_prompts)): 85 | dataset = train_datasets[i] 86 | prompt = train_prompts[i] 87 | promptList = DatasetTemplates(dataset)[prompt] 88 | hard_prompt = clease_jinja(promptList.jinja) 89 | answer_choices = promptList.answer_choices 90 | text = get_text_format(prompt, hard_prompt, answer_choices) # setting format for retrieval 91 | corpus.append(text) 92 | embedding = sentence_transformer.encode(text) 93 | index[f'{dataset},{prompt}'] = embedding 94 | index_values[f'{dataset},{prompt}'] = train_dataset_rows[i] 95 | text_to_prompt[text] = f'{dataset},{prompt}' 96 | 97 | #index_values['testing,prompt'] = 0 98 | #index['testing,prompt'] = sentence_transformer.encode('Replace the _ in the above sentence with the correct option: - -') 99 | 100 | evals = ["hellaswag", "story_cloze", "anli_dev_r1", "anli_dev_r2", "anli_dev_r3", "super_glue/copa", "super_glue/cb", "super_glue/rte", "super_glue/wsc.fixed", "super_glue/wic", "winogrande/winogrande_xl"] 101 | eval_name_mapping = { 102 | "hellaswag": "hellaswag", 103 | "story_cloze": "story_cloze", 104 | "anli_dev_r1": "anli_r1", 105 | "anli_dev_r2": "anli_r2", 106 | "anli_dev_r3": "anli_r3", 107 | "super_glue/copa": "super_glue_copa", 108 | "super_glue/cb": "super_glue_cb", 109 | "super_glue/rte": "super_glue_rte", 110 | "super_glue/wsc.fixed": "super_glue_wsc.fixed", 111 | "super_glue/wic": "super_glue_wic", 112 | "winogrande/winogrande_xl": "winogrande" 113 | } 114 | expert_name_mapping2 = { 115 | "cos_e": "cos_e/v1.11", 116 | "wiki_hop": "wiki_hop/original", 117 | "paws": "paws/labeled_final", 118 | "glue_qqp": "glue/qqp", 119 | "glue_mrpc": "glue/mrpc", 120 | "adversarial_qa": "adversarial_qa/adversarialQA", 121 | "duorc": "duorc/ParaphraseRC", 122 | "hotpot_qa": "hotpot_qa/fullwiki", 123 | "cnn_dailymail": "cnn_dailymail/3.0.0" 124 | } 125 | # Evaluate each evaluation prompt 126 | results = [] 127 | index_keys = 
list(index.keys()) 128 | tokenized_corpus = [doc.split(" ") for doc in corpus] 129 | bm25 = BM25Okapi(tokenized_corpus) 130 | 131 | # Evaluate each evaluation prompt 132 | for i in range(len(eval_datasets)): 133 | dataset = eval_datasets[i] 134 | # Getting the exact dataset name: 135 | print(dataset) 136 | for key in evals: 137 | if key in dataset: 138 | dataset = key 139 | dataset = eval_name_mapping[dataset] 140 | prompt = eval_dataset_prompts[i] 141 | print(dataset, prompt) 142 | 143 | max_prompt = merge_mapping[f'{dataset}@{prompt}'][0] 144 | max_prompt = max_prompt.replace('@', ',') 145 | 146 | r_dataset = max_prompt.split(',')[0] 147 | r_prompt = max_prompt.split(',')[1:] 148 | if r_dataset in expert_name_mapping2: 149 | r_dataset = expert_name_mapping2[r_dataset] 150 | r_prompt = "".join(r_prompt) 151 | max_prompt = f"{r_dataset},{r_prompt}" 152 | print(f'eval dataset: {dataset}, eval prompt: {prompt}, answer choices: {answer_choices}') 153 | #print(f'retrieved: {max_prompt}, hard_prompt: {r_answer_choices}\n') 154 | print(max_prompt) 155 | val = index_values[max_prompt][i] 156 | results.append(val) 157 | 158 | pd.DataFrame([results]).to_csv(f'ST_choice_instance_top1_32_original.csv') -------------------------------------------------------------------------------- /retrieval/configs/BM25_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "BM25", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/BM25_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "BM25", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ST_MP_label_instance_top1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ST_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ST_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ST_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ST_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ST_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "label", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/T0_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "T0", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- 
/retrieval/configs/ablations/DIFFCSE_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "DIFFCSE", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/DIFFCSE_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "DIFFCSE", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/DIFFCSE_choice_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "DIFFCSE", 3 | "text_format": "choice_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/DIFFCSE_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "DIFFCSE", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/DIFFCSE_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "DIFFCSE", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/DIFFCSE_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "DIFFCSE", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/DIFFCSE_label_instance_new_top1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "method": "DIFFCSE", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/DIFFCSE_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "DIFFCSE", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/DIFFCSE_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "DIFFCSE", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/DIFFCSE_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "DIFFCSE", 3 | "text_format": "label", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_Large_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_Large", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_Large_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_Large", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_Large_choice_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_Large", 3 | 
"text_format": "choice_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_Large_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_Large", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_Large_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_Large", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_Large_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_Large", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_Large_label_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_Large", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_Large_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_Large", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_Large_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_Large", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } 
-------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_Large_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_Large", 3 | "text_format": "label", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_XL_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_XL", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_XL_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_XL", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_XL_choice_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_XL", 3 | "text_format": "choice_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_XL_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_XL", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_XL_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_XL", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_XL_instance_top1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_XL", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_XL_label_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_XL", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_XL_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_XL", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_XL_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_XL", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/GTR_XL_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "GTR_XL", 3 | "text_format": "label", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_Large_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_Large", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_Large_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": 
"INSTRUCTOR_Large", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_Large_choice_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_Large", 3 | "text_format": "choice_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_Large_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_Large", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_Large_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_Large", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_Large_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_Large", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_Large_label_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_Large", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_Large_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_Large", 3 | "text_format": 
"label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_Large_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_Large", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_Large_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_Large", 3 | "text_format": "label", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_XL_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_XL", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_XL_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_XL", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_XL_choice_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_XL", 3 | "text_format": "choice_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_XL_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_XL", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } 
-------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_XL_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_XL", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_XL_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_XL", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_XL_label_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_XL", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_XL_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_XL", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_XL_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_XL", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/INSTRUCTOR_XL_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "INSTRUCTOR_XL", 3 | "text_format": "label", 4 | "top_k": 1 5 | } 
-------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_SUP_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_SUP", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_SUP_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_SUP", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_SUP_choice_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_SUP", 3 | "text_format": "choice_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_SUP_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_SUP", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_SUP_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_SUP", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_SUP_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_SUP", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- 
/retrieval/configs/ablations/SIMCSE_SUP_label_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_SUP", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_SUP_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_SUP", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_SUP_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_SUP", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_SUP_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_SUP", 3 | "text_format": "label", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_UNSUP_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_UNSUP", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_UNSUP_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_UNSUP", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_UNSUP_choice_new_top1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_UNSUP", 3 | "text_format": "choice_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_UNSUP_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_UNSUP", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_UNSUP_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_UNSUP", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_UNSUP_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_UNSUP", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_UNSUP_label_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_UNSUP", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_UNSUP_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_UNSUP", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_UNSUP_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 
| "method": "SIMCSE_UNSUP", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/SIMCSE_UNSUP_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "SIMCSE_UNSUP", 3 | "text_format": "label", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_Large_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_Large", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_Large_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_Large", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_Large_choice_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_Large", 3 | "text_format": "choice_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_Large_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_Large", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_Large_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_Large", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } 
-------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_Large_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_Large", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_Large_label_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_Large", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_Large_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_Large", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_Large_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_Large", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_Large_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_Large", 3 | "text_format": "label", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_XL_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_XL", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- 
/retrieval/configs/ablations/ST5_XL_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_XL", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_XL_choice_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_XL", 3 | "text_format": "choice_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_XL_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_XL", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_XL_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_XL", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_XL_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_XL", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_XL_label_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_XL", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_XL_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"method": "ST5_XL", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_XL_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_XL", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST5_XL_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST5_XL", 3 | "text_format": "label", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_12_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_12", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_12_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_12", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_12_choice_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_12", 3 | "text_format": "choice_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_12_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_12", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- 
/retrieval/configs/ablations/ST_12_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_12", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_12_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_12", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_12_label_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_12", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_12_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_12", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_12_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_12", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_12_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_12", 3 | "text_format": "label", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_NLI_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": 
"ST_MP_NLI", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_NLI_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP_NLI", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_NLI_choice_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP_NLI", 3 | "text_format": "choice_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_NLI_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP_NLI", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_NLI_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP_NLI", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_NLI_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP_NLI", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_NLI_label_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP_NLI", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } 
-------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_NLI_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP_NLI", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_NLI_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP_NLI", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_NLI_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP_NLI", 3 | "text_format": "label", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_choice_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP", 3 | "text_format": "choice_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_choice_top1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_label_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_MP_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST_MP", 3 | "text_format": "label", 4 | "top_k": 1 5 | } 
-------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_choice_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "choice_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_label_instance_new_top1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/ST_label_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "ST", 3 | "text_format": "label", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/T0_choice_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "T0", 3 | "text_format": "choice_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/T0_choice_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "T0", 3 | "text_format": "choice_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/T0_choice_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "T0", 3 | "text_format": "choice_new", 4 | "top_k": 1 5 | } 
-------------------------------------------------------------------------------- /retrieval/configs/ablations/T0_choice_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "T0", 3 | "text_format": "choice", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/T0_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "T0", 3 | "text_format": "instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/T0_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "T0", 3 | "text_format": "instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/T0_label_instance_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "T0", 3 | "text_format": "label_instance_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/T0_label_instance_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "T0", 3 | "text_format": "label_instance", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/T0_label_new_top1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "T0", 3 | "text_format": "label_new", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/configs/ablations/T0_label_top1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "method": "T0", 3 | "text_format": "label", 4 | "top_k": 1 5 | } -------------------------------------------------------------------------------- /retrieval/make_benchmarks.py: -------------------------------------------------------------------------------- 1 | # Code fo making Big Bench and MMLU Evaluation data in .json format 2 | import json 3 | import os 4 | 5 | def dataset_prompt_setting(type_path, dataset_name, dataset_config_name, bigbench_path): 6 | if type_path == 'bigbench': 7 | if dataset_config_name==None: 8 | path = os.path.join(bigbench_path, dataset_name, "task.json") 9 | else: 10 | path = os.path.join(bigbench_path, dataset_name, dataset_config_name, "task.json") 11 | with open(path, 'r') as file: 12 | data = json.load(file) 13 | instruction = data["description"] 14 | example = data["examples"] 15 | task_prefix = data.get("task_prefix", "") 16 | input_prefix = data.get("example_input_prefix", "\nQ: ") 17 | output_prefix = data.get("example_output_prefix", "\nA: ") 18 | choice_prefix = data.get("choice_prefix", "\n choice: ") 19 | append_choices_to_input = data.get("append_choices_to_input", True) 20 | 21 | return example, instruction, [task_prefix, input_prefix, output_prefix, choice_prefix, append_choices_to_input] 22 | 23 | def bigbench_input(input, options, task_prefix, input_prefix, choice_prefix, append_choices_to_input, output_prefix): 24 | def choices_string(choices, options, append_choices_to_input): 25 | if append_choices_to_input: 26 | choices_string = choices+choices.join(options) 27 | #choices_string = f'{choices}'+f'{choices}'.join(options) 28 | else: 29 | choices_string="" 30 | return choices_string 31 | input_ = f'{task_prefix}{input_prefix}{input}{choices_string(choice_prefix, options, append_choices_to_input)}{output_prefix}' 32 | return input_ 33 | 34 | bigbench_tasks_all = ['code_line_description', 35 | 
'conceptual_combinations/contradictions', 'conceptual_combinations/emergent_properties', 'conceptual_combinations/fanciful_fictional_combinations', 'conceptual_combinations/homonyms', 'conceptual_combinations/invented_words', 'conceptual_combinations/homonyms', 'conceptual_combinations/surprising_uncommon_combinations', 36 | 'formal_fallacies_syllogisms_negation', 37 | 'hindu_knowledge', 38 | 'known_unknowns', 39 | 'language_identification', 40 | 'logic_grid_puzzle', 41 | 'logical_deduction/five_objects', 'logical_deduction/seven_objects', 'logical_deduction/three_objects', 42 | 'misconceptions', 43 | 'movie_dialog_same_or_different', 44 | 'novel_concepts', 45 | 'strategyqa', 46 | 'vitaminc_fact_verification', 47 | 'winowhy'] 48 | 49 | bigbench_tasks = ['code_line_description', 50 | 'conceptual_combinations/contradictions', 'conceptual_combinations/emergent_properties', 'conceptual_combinations/fanciful_fictional_combinations', 'conceptual_combinations/homonyms', 'conceptual_combinations/invented_words', 'conceptual_combinations/homonyms', 'conceptual_combinations/surprising_uncommon_combinations', 51 | 'formal_fallacies_syllogisms_negation', 52 | 'hindu_knowledge', 53 | 'known_unknowns', 54 | 'language_identification', 55 | 'logic_grid_puzzle', 56 | 'logical_deduction/five_objects', 'logical_deduction/seven_objects', 'logical_deduction/three_objects', 57 | 'misconceptions', 58 | 'movie_dialog_same_or_different', 59 | 'strategyqa', 60 | 'vitaminc_fact_verification', 61 | 'winowhy'] 62 | 63 | 64 | bigbench_path = 'retrieval_data/benchmark_tasks' 65 | evals = {} 66 | for origin_dataset_name in bigbench_tasks: 67 | if '/' in origin_dataset_name: 68 | dataset_name_s = origin_dataset_name.split('/') 69 | dataset_name = dataset_name_s[0] 70 | dataset_config_name = dataset_name_s[1] 71 | else: 72 | dataset_name = origin_dataset_name 73 | dataset_config_name = None 74 | eval_dataset = dataset_prompt_setting('bigbench', dataset_name, dataset_config_name, bigbench_path) 75 | 
print(f'{origin_dataset_name}, number of examples: {len(eval_dataset[0])}') 76 | task_prefix, input_prefix, output_prefix, choice_prefix, append_choices_to_input = eval_dataset[2] 77 | for example in eval_dataset[0]: 78 | eval_instance = {} 79 | if dataset_config_name==None: 80 | eval_instance['config'] = "none" 81 | else: 82 | eval_instance['config'] = dataset_config_name 83 | eval_instance['task'] = dataset_name 84 | input = example['input'] 85 | eval_instance['label_list'] = [] 86 | eval_instance['target'] = None 87 | for label in example['target_scores']: 88 | eval_instance['label_list'].append(label) 89 | if example['target_scores'][label] == 1: 90 | if eval_instance['target']!=None: 91 | print(example['target_scores']) 92 | raise Exception('there are two labels that are correct..!!') 93 | eval_instance['target'] = label 94 | eval_instance['source'] = bigbench_input(input, eval_instance['label_list'], task_prefix, input_prefix, choice_prefix, append_choices_to_input, output_prefix) 95 | if eval_instance['target']==None: 96 | raise Exception('there are zero labels with correct answer!') 97 | if dataset_name not in evals: 98 | evals[dataset_name] = [eval_instance] 99 | else: 100 | evals[dataset_name].append(eval_instance) 101 | 102 | print(len(evals)) 103 | with open("retrieval_data/bigbench_eval.json", "w") as fp: 104 | json.dump(evals,fp) -------------------------------------------------------------------------------- /seq2seq/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | from .adapter_configuration import ADAPTER_CONFIG_MAPPING, AutoAdapterConfig, AdapterConfig 2 | from .adapter_modeling import Adapter, HyperComplexAdapter 3 | from .adapter_controller import AdapterController -------------------------------------------------------------------------------- /seq2seq/adapters/adapter_configuration.py: -------------------------------------------------------------------------------- 1 | """Implements the 
adapters and other parameter-efficient finetuning methods' configurations.""" 2 | 3 | from collections import OrderedDict 4 | from dataclasses import dataclass 5 | 6 | import torch.nn as nn 7 | 8 | 9 | @dataclass 10 | class AdapterConfig(object): 11 | """Implements the adapter configuration proposed by Houlsby et. al, 2019 12 | in https://arxiv.org/abs/1902.00751. 13 | We additionally pass all the configuration of parameter-efficient finetuning 14 | methods with this config.""" 15 | add_layer_norm_before_adapter: bool = False 16 | add_layer_norm_after_adapter: bool = True 17 | non_linearity: str = "swish" 18 | task_reduction_factor: int = 16 19 | add_adapter_in_feed_forward = True 20 | add_adapter_in_self_attention = True 21 | hidden_dim = 128 22 | task_adapter_layers_encoder = None 23 | task_adapter_layers_decoder = None 24 | task_adapter_in_decoder = True 25 | intrinsic_dim = 100 26 | normalize_intrinsic_projections = False 27 | # This can be either random, or fastfood. 28 | intrinsic_projection = "random" 29 | 30 | # Hypercomplex adapters parameters 31 | hypercomplex_adapters = False 32 | hypercomplex_division = 8 33 | learn_phm = True 34 | hypercomplex_nonlinearity = "glorot-uniform" 35 | shared_phm_rule = False 36 | factorized_phm = False 37 | shared_W_phm = False 38 | factorized_phm_rule = False 39 | phm_c_init = "normal" 40 | phm_rank = 1 41 | phm_init_range = 0.01 42 | 43 | # prefix-tuning parameters. 44 | prefix_dim = 100 45 | init_prefix_from_vocab = False 46 | kronecker_prod = False 47 | 48 | # BitFit configuration. 49 | bitfit = False 50 | 51 | # Low-rank adapters. 
52 | low_rank_adapters = False 53 | low_rank_w_init = "glorot-uniform" 54 | low_rank_rank = 1 55 | 56 | 57 | ADAPTER_CONFIG_MAPPING = OrderedDict( 58 | [("adapter", AdapterConfig)]) 59 | 60 | 61 | class AutoAdapterConfig(nn.Module): 62 | """Generic Adapter config class to instantiate different adapter configs.""" 63 | 64 | @classmethod 65 | def get(cls, config_name: str): 66 | if config_name in ADAPTER_CONFIG_MAPPING: 67 | return ADAPTER_CONFIG_MAPPING[config_name]() 68 | raise ValueError( 69 | "Unrecognized adapter config type identifier: {}. Should contain one of {}" 70 | .format(config_name, ", ".join(ADAPTER_CONFIG_MAPPING.keys()))) 71 | -------------------------------------------------------------------------------- /seq2seq/adapters/adapter_controller.py: -------------------------------------------------------------------------------- 1 | """Implements Adapter Controller, a module that keeps multiple 2 | layers of Adapters, and controls which adapter layer to use.""" 3 | import os 4 | import torch.nn as nn 5 | from .adapter_modeling import Adapter, HyperComplexAdapter, LowRankAdapter 6 | 7 | 8 | class AdapterController(nn.Module): 9 | """Implements Adapter controller module which controls the logics of 10 | putting adapter layers within transformer's layers.""" 11 | 12 | def __init__(self, config): 13 | super().__init__() 14 | # low-rank adapters. 
15 | self.low_rank_adapters = config.low_rank_adapters 16 | self.intrinsic_projections_path = os.path.join( 17 | config.output_dir, "intrinsic_projections") 18 | self.config = config 19 | self.adapters = nn.ModuleDict(dict()) 20 | if type(config.tasks[0]) is list: 21 | self.tasks = config.tasks[0] 22 | else: 23 | self.tasks = config.tasks 24 | self.device = config.device 25 | self.shared_phm_rule = config.shared_phm_rule 26 | self.hypercomplex_adapters = config.hypercomplex_adapters 27 | self.adapters = self.construct_adapters(self.tasks) 28 | self.add_layer_norm_before_adapter = config.add_layer_norm_before_adapter 29 | self.add_layer_norm_after_adapter = config.add_layer_norm_after_adapter 30 | if self.add_layer_norm_before_adapter: 31 | self.pre_layer_norm = nn.LayerNorm(config.input_dim) 32 | if self.add_layer_norm_after_adapter: 33 | self.post_layer_norm = nn.LayerNorm(config.input_dim) 34 | 35 | def get_task(self, task): 36 | return task 37 | 38 | def construct_adapters(self, tasks): 39 | """ 40 | Constructs adapter layers and adds them to a dictionary for the given 41 | tasks. 42 | Args: 43 | tasks: A list of string containing the task names. 44 | """ 45 | for task in tasks: 46 | if self.hypercomplex_adapters: 47 | self.adapters[task] = HyperComplexAdapter(self.config) 48 | elif self.low_rank_adapters: 49 | self.adapters[task] = LowRankAdapter(self.config) 50 | else: 51 | self.adapters[task] = Adapter(self.config) 52 | return self.adapters 53 | 54 | def disable_adapters(self, tasks): 55 | """ 56 | Given a list of tasks, it freezes their corresponding adapter layers' 57 | parameters. 58 | Args: 59 | tasks: List of tasks. 
60 | """ 61 | tasks = self.convert_to_list(tasks) 62 | for task in tasks: 63 | adapter = self.get_adapter(task) 64 | for param in adapter.parameters(): 65 | param.requires_grad = False 66 | 67 | def convert_to_list(self, tasks): 68 | if isinstance(tasks, list): 69 | return tasks 70 | return [tasks] 71 | 72 | def enable_adapters(self, tasks): 73 | """ 74 | Given a list of tasks, it unfreezes their corresponding adapter layers. 75 | Args: 76 | tasks: Given list of tasks. 77 | """ 78 | tasks = self.convert_to_list(tasks) 79 | for task in tasks: 80 | adapter = self.get_adapter(task) 81 | for name, param in adapter.named_parameters(): 82 | if self.config.hypercomplex_adapters and not self.config.learn_phm: 83 | if not "phm_rule" in name: 84 | param.requires_grad = True 85 | else: 86 | param.requires_grad = True 87 | 88 | def get_adapter(self, task): 89 | """Given a task returns its corresponding adapter layer. 90 | Args: 91 | task: Input task name. 92 | Returns: 93 | Adapter layer corresponding to the given task. 94 | """ 95 | # TODO : Fix Manual Setting 96 | # if task=="superglue-rte": 97 | # return self.adapters['paws'] 98 | # elif task=="triviaqa": 99 | # return self.adapters['hotpotqa'] 100 | # elif task=="lama": 101 | # return self.adapters['hotpotqa'] 102 | 103 | # TODO : Fix during only evaluation 104 | task = self.config.adapters_cur_training_task 105 | return self.adapters[task] 106 | #return self.adapters['paws'] 107 | 108 | def forward(self, inputs, task): 109 | """ 110 | Retrieves the adapter layer corresponding to the given 111 | task. It freezes the adapter layers for all the other tasks 112 | and call the selected adapter layer. 113 | Args: 114 | task: the name of the current task. 115 | inputs: the inputs to feed in in the adapter layer. 116 | Returns: 117 | outputs of the adapter layer. 118 | """ 119 | task = self.get_task(task) 120 | # Enables the adapter layer for the given task. 121 | self.enable_adapters(task) 122 | # Disable other adapters. 
123 | other_tasks = [x for x in self.tasks if x != task] 124 | self.disable_adapters(other_tasks) 125 | adapter = self.get_adapter(task) 126 | z = self.pre_layer_norm( 127 | inputs) if self.add_layer_norm_before_adapter else inputs 128 | outputs = adapter(z) 129 | if self.add_layer_norm_after_adapter: 130 | outputs = self.post_layer_norm(outputs) 131 | outputs = outputs + inputs 132 | return outputs 133 | -------------------------------------------------------------------------------- /seq2seq/adapters/adapter_modeling.py: -------------------------------------------------------------------------------- 1 | """Implements an Adapter, Low-rank adapters and Hyper-adapter Layers.""" 2 | import torch.nn as nn 3 | from .adapter_utils import Activations 4 | from hypercomplex.layers import PHMLinear 5 | from .low_rank_layer import LowRankLinear 6 | 7 | 8 | class LowRankAdapter(nn.Module): 9 | """This is the low-rank adapter, in which each adapter is composed of two rank-one matrices. 10 | """ 11 | 12 | def __init__(self, config): 13 | super().__init__() 14 | self.config = config 15 | self.input_dim = config.input_dim 16 | self.down_sample_size = self.input_dim // config.reduction_factor 17 | self.activation = Activations(config.non_linearity.lower()) 18 | self.down_sampler = LowRankLinear(self.input_dim, self.down_sample_size, 19 | w_init=config.low_rank_w_init, 20 | rank=config.low_rank_rank) 21 | self.up_sampler = LowRankLinear(self.down_sample_size, self.input_dim, 22 | w_init=config.low_rank_w_init, 23 | rank=config.low_rank_rank) 24 | 25 | def forward(self, x): 26 | z = self.down_sampler(x) 27 | z = self.activation(z) 28 | output = self.up_sampler(z) 29 | return output 30 | 31 | 32 | class Adapter(nn.Module): 33 | """Conventional Adapter layer, in which the weights of up and down sampler modules 34 | are parameters and are optimized.""" 35 | 36 | def __init__(self, config): 37 | super().__init__() 38 | self.config = config 39 | self.input_dim = config.input_dim 40 | 
self.down_sample_size = self.input_dim // config.reduction_factor 41 | self.activation = Activations(config.non_linearity.lower()) 42 | self.down_sampler = nn.Linear(self.input_dim, self.down_sample_size) 43 | self.up_sampler = nn.Linear(self.down_sample_size, self.input_dim) 44 | 45 | def forward(self, x): 46 | z = self.down_sampler(x) 47 | z = self.activation(z) 48 | output = self.up_sampler(z) 49 | return output 50 | 51 | 52 | class HyperComplexAdapter(nn.Module): 53 | """Hypercomplex Adapter layer, in which the weights of up and down sampler modules 54 | are parameters are 1/n times of the conventional adapter layers, where n is 55 | hypercomplex division number.""" 56 | 57 | def __init__(self, config): 58 | super().__init__() 59 | self.config = config 60 | self.input_dim = config.input_dim 61 | self.down_sample_size = self.input_dim // config.reduction_factor 62 | self.activation = Activations(config.non_linearity.lower()) 63 | self.down_sampler = PHMLinear(in_features=self.input_dim, 64 | out_features=self.down_sample_size, 65 | bias=True, 66 | c_init=config.phm_c_init, 67 | phm_dim=config.hypercomplex_division, 68 | learn_phm=config.learn_phm, 69 | w_init=config.hypercomplex_nonlinearity, 70 | shared_phm_rule=config.shared_phm_rule, 71 | factorized_phm=config.factorized_phm, 72 | shared_W_phm=config.shared_W_phm, 73 | factorized_phm_rule=config.factorized_phm_rule, 74 | phm_rank=config.phm_rank, 75 | phm_init_range=config.phm_init_range, 76 | kronecker_prod=config.kronecker_prod) 77 | self.up_sampler = PHMLinear(in_features=self.down_sample_size, 78 | out_features=self.input_dim, 79 | bias=True, 80 | c_init=config.phm_c_init, 81 | phm_dim=config.hypercomplex_division, 82 | learn_phm=config.learn_phm, 83 | w_init=config.hypercomplex_nonlinearity, 84 | shared_phm_rule=config.shared_phm_rule, 85 | factorized_phm=config.factorized_phm, 86 | shared_W_phm=config.shared_W_phm, 87 | factorized_phm_rule=config.factorized_phm_rule, 88 | phm_rank=config.phm_rank, 89 | 
phm_init_range=config.phm_init_range, 90 | kronecker_prod=config.kronecker_prod) 91 | 92 | def forward(self, x): 93 | z = self.down_sampler(x) 94 | z = self.activation(z) 95 | return self.up_sampler(z) 96 | -------------------------------------------------------------------------------- /seq2seq/adapters/adapter_utils.py: -------------------------------------------------------------------------------- 1 | """Implementation of different utility functions for adapter layers.""" 2 | import torch.nn as nn 3 | from transformers.activations import get_activation 4 | 5 | 6 | class Activations(nn.Module): 7 | def __init__(self, activation_type): 8 | super().__init__() 9 | self.f = get_activation(activation_type) 10 | 11 | def forward(self, x): 12 | return self.f(x) 13 | -------------------------------------------------------------------------------- /seq2seq/adapters/low_rank_layer.py: -------------------------------------------------------------------------------- 1 | """This script implements a low-rank linear layer.""" 2 | import torch 3 | import torch.nn as nn 4 | 5 | from hypercomplex.inits import glorot_uniform, glorot_normal 6 | 7 | 8 | class LowRankLinear(torch.nn.Module): 9 | def __init__(self, input_dim: int, output_dim: int, rank: int = 1, 10 | bias: bool = True, w_init: str = "glorot-uniform"): 11 | super(LowRankLinear, self).__init__() 12 | self.input_dim = input_dim 13 | self.output_dim = output_dim 14 | self.rank = rank 15 | self.bias = bias 16 | self.w_init = w_init 17 | self.W_left = nn.Parameter(torch.Tensor( 18 | size=(input_dim, rank)), requires_grad=True) 19 | self.W_right = nn.Parameter(torch.Tensor( 20 | size=(rank, output_dim)), requires_grad=True) 21 | if bias: 22 | self.b = nn.Parameter(torch.Tensor(output_dim)) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self): 26 | if self.bias: 27 | self.b.data = torch.zeros_like(self.b.data) 28 | if self.w_init == "glorot-uniform": 29 | self.W_left.data = glorot_uniform(self.W_left.data) 30 | 
self.W_right.data = glorot_uniform(self.W_right.data) 31 | elif self.w_init == "glorot-normal": 32 | self.W_left.data = glorot_normal(self.W_left.data) 33 | self.W_right.data = glorot_normal(self.W_right.data) 34 | else: 35 | raise ValueError 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | W = self.W_left*self.W_right 39 | output = torch.matmul(input=x, other=W) 40 | if self.bias: 41 | output += self.b 42 | return output 43 | -------------------------------------------------------------------------------- /seq2seq/additional_code/check_data2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "Using custom data configuration default-4bc186b431f1f6e5\n", 20 | "Reusing dataset json (/home/joel_jang/.cache/huggingface/datasets/json/default-4bc186b431f1f6e5/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02)\n" 21 | ] 22 | }, 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "DatasetDict({\n", 28 | " train: Dataset({\n", 29 | " features: ['translation'],\n", 30 | " num_rows: 138000\n", 31 | " })\n", 32 | "})\n" 33 | ] 34 | }, 35 | { 36 | "ename": "KeyError", 37 | "evalue": "'translation'", 38 | "output_type": "error", 39 | "traceback": [ 40 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 41 | "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", 42 | "\u001b[1;32m/home/joel_jang/seungone/RoE/seq2seq/check_data2.ipynb 셀 2\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m wiki_auto \u001b[39m=\u001b[39m 
load_dataset(\u001b[39m\"\u001b[39m\u001b[39mjson\u001b[39m\u001b[39m\"\u001b[39m,data_files\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m./data/cT0/training_files/sequential/train.wiki_auto.continual1000.json\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 3\u001b[0m \u001b[39mprint\u001b[39m(wiki_auto)\n\u001b[0;32m----> 4\u001b[0m wiki_auto[\u001b[39m'\u001b[39m\u001b[39mtraining\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m wiki_auto[\u001b[39m'\u001b[39;49m\u001b[39mtranslation\u001b[39;49m\u001b[39m'\u001b[39;49m]\n\u001b[1;32m 5\u001b[0m \u001b[39mprint\u001b[39m(wiki_auto)\n", 43 | "\u001b[0;31mKeyError\u001b[0m: 'translation'" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "from datasets import load_dataset\n", 49 | "wiki_auto_training = load_dataset(\"json\",data_files=\"./data/cT0/training_files/sequential/train.wiki_auto.continual1000.json\")\n", 50 | "print(wiki_auto_training)\n", 51 | "wiki_auto = \n", 52 | "wiki_auto['training'] = wiki_auto['translation']\n", 53 | "print(wiki_auto)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [] 62 | } 63 | ], 64 | "metadata": { 65 | "kernelspec": { 66 | "display_name": "Python 3.8.13 ('RoE')", 67 | "language": "python", 68 | "name": "python3" 69 | }, 70 | "language_info": { 71 | "codemirror_mode": { 72 | "name": "ipython", 73 | "version": 3 74 | }, 75 | "file_extension": ".py", 76 | "mimetype": "text/x-python", 77 | "name": "python", 78 | "nbconvert_exporter": "python", 79 | "pygments_lexer": "ipython3", 80 | "version": "3.8.13" 81 | }, 82 | "orig_nbformat": 4, 83 | "vscode": { 84 | "interpreter": { 85 | "hash": "75ef641d7212823f7d06bad9f0560d9d6992ac1bf3ac25b9d30b7c3dc223aa68" 86 | } 87 | } 88 | }, 89 | "nbformat": 4, 90 | "nbformat_minor": 2 91 | } 92 | -------------------------------------------------------------------------------- /seq2seq/additional_code/check_target_eval_results_twitter.py: 
--------------------------------------------------------------------------------
import os, json
import numpy as np
from typing import List

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import RidgeClassifier
import _pickle as cPickle

class Clf():

    """
    Usage:
    1) load the clf for a task:
    path_folder_data = f'{GLOBAL_PATH}/data'
    evalset = 'twitter_top20'
    prompt_name = 'tweet_as+about'
    label_name = 'author'
    clf = Clf(path_folder_data, evalset, prompt_name, label_name)

    2) infer:
    print(clf.compute_score(evaluated_predictions))
    """

    def __init__(self, path_folder_data, evalset, prompt_name, label_name):
        self.path_folder_data = path_folder_data
        self.evalset = evalset
        self.prompt_name = prompt_name
        self.label_name = label_name

        # One cache key per (evalset, prompt, label) combination.
        self.key_name = f'{evalset}.{prompt_name}.{label_name}'

        # Model and vectorizer are pickled next to the working directory, so
        # a second run with the same key skips retraining.
        path_model = f'{self.key_name}.model.pkl'
        path_count_vectorizer = f'{self.key_name}.count_vectorizer.pkl'

        if os.path.exists(path_model):
            # load it
            with open(path_model, 'rb') as fid:
                self.model = cPickle.load(fid)
            with open(path_count_vectorizer, 'rb') as fid:
                self.count_vectorizer = cPickle.load(fid)
        else:
            self.model = RidgeClassifier() #GaussianNB()
            self.count_vectorizer = CountVectorizer(binary=True)
            self.train_model()
            # save the classifier
            with open(path_model, 'wb') as fid:
                cPickle.dump(self.model, fid)
            with open(path_count_vectorizer, 'wb') as fid:
                cPickle.dump(self.count_vectorizer, fid)

        #transform test data
        X_test, y_test = self.get_data('test')
        self.y_test = y_test
        predictions = self.get_preds(X_test)
        print("Accuracy clf:", self.accuracy_score(y_test, predictions))

    def get_data(self, eval_mode):
        # Load '{prompt_name}.{eval_mode}.json' and return (outputs, labels):
        # the target texts and the per-example label drawn from src_info.

        path_ex = os.path.join(self.path_folder_data, self.evalset, f'{self.prompt_name}.{eval_mode}.json')

        with open(path_ex, 'r') as f:
            data = json.load(f)

        nb_ex = len(data['src_info'])
        outputs = [data['tgt'][idx] for idx in range(nb_ex)]
        labels = [data['src_info'][idx][self.label_name] for idx in range(nb_ex)]

        assert len(outputs) == len(labels)

        return outputs, labels

    def train_model(self):
        # Fit the binary bag-of-words vectorizer and the classifier on the
        # training split.

        #fit training data
        X_train, y_train = self.get_data('train')
        training_data = self.count_vectorizer.fit_transform(X_train).toarray()
        self.model.fit(training_data, y_train)

    @staticmethod
    def accuracy_score(y_true, y_pred):
        # Plain elementwise accuracy (no sklearn dependency here).
        return np.average([y1 == y2 for y1, y2 in zip(y_true, y_pred)])

    def get_preds(self, X_test):
        testing_data = self.count_vectorizer.transform(X_test).toarray()
        predictions = self.model.predict(testing_data)

        return predictions

    def compute_score(self, outputs):
        # Accuracy of the classifier's predictions on `outputs` against the
        # held-out test labels captured in __init__.

        clf_predictions = self.get_preds(outputs)
        print('*****************************')
        print(len(clf_predictions))
        print(len(self.y_test))
        print('*****************************')
        return {'CLF_acc': self.accuracy_score(self.y_test, clf_predictions)}

path_folder_data = "/home/joel_jang/seungone/RoE/seq2seq/data/manual/ct0_data/twitter_top20"
evalset = "twitter_top20"
prompt_name = "tweet_as+about"
label_name = "author"

clf = Clf(path_folder_data,evalset,prompt_name,label_name)

twitter_top20 = "/home/joel_jang/seungone/RoE/seq2seq/output_logs/twitter/twitter*tweet_as+about-twitter*tweet_as+about.txt"
results = [twitter_top20]

predictions = []
references = []
sources = []

# Each results file stores "prediction | reference | source" lines; marker
# lines containing ##, >> or * are skipped below.
for idx,r in enumerate(results):
    pred = []
    ref = []
    src = []
    with open(r,'r') as f:
        lines = f.readlines()
        for line in lines:
            if ('##' not in line) and
('>>' not in line) and ('*' not in line): 121 | 122 | p = line.split(' | ')[0].strip() 123 | r = line.split(' | ')[1].strip() 124 | s = line.split(' | ')[2].strip() 125 | pred.append(p) 126 | ref.append(r) 127 | src.append(s) 128 | predictions.append(pred) 129 | references.append(ref) 130 | sources.append(src) 131 | 132 | results = clf.compute_score(predictions[0]) 133 | print(results) -------------------------------------------------------------------------------- /seq2seq/additional_code/count.py: -------------------------------------------------------------------------------- 1 | import os 2 | root = 'expert_weights' 3 | lst = os.listdir(root) 4 | cnt = 0 5 | for l in lst: 6 | lst2 = os.listdir(root+'/'+l) 7 | cnt+=len(lst2) 8 | print(f'Total number of adapters: {cnt}') -------------------------------------------------------------------------------- /seq2seq/args_model.py: -------------------------------------------------------------------------------- 1 | from adapters import ADAPTER_CONFIG_MAPPING 2 | from dataclasses import dataclass, field 3 | from typing import Optional, List 4 | 5 | @dataclass 6 | class ModelArguments: 7 | """ 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """
    # --- adapter checkpointing ---
    save_adapter_weights: bool = field(
        default=True,
        metadata={"help": "Save the weights for the task-specific adapter."}
    )
    load_adapter_weights: bool = field(
        default=False,
        # Fixed typo in the help text: "task-sepcific" -> "task-specific".
        metadata={"help": "Load the weights used for task-specific adapters."}
    )
    adapter_dir: str = field(
        default=None,
        metadata={"help": "Path to load task-specific adapters"}
    )
    # --- base model / tokenizer selection ---
    model_name_or_path: str = field(
        default=None,
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
                    "with private models)."
        },
    )
    # --- prefix / prompt embeddings ---
    load_prefix_embeddings: bool = field(
        default=False,
        metadata={"help": "load prefix embeddings or not"},
    )
    save_prefix_only: bool = field(
        default=False,
        metadata={"help": "save prefix embeddings only"},
    )
    prompt_embedding_path: Optional[List[str]] = field(
        default=None,
        metadata={"help": "A list of the paths to prefix embeddings"}
    )
    target_prompt_embedding_path: Optional[str] = field(
        default=None,
        metadata={"help": "a path to the target prompt embedding"}
    )
    # --- ATTEMPT-style attention over prompts ---
    attn_prefix_tuning: bool = field(
        default=False,
        metadata={"help": "Set true if you try ATTEMPT."},
    )
    attn_method: Optional[str] = field(
        default="sub",
        metadata={
            "help": "Attention model for attn_prefix. We currently support the following methods: linear, sub (our main method), and constant (gives the constant and equal weights to all of the prompts.)"
        },
    )
    shared_attn: bool = field(
        default=False,
        metadata={"help": "shared attention"},
    )
    load_attention: bool = field(
        default=False,
        metadata={"help": "Set true if you want to load pre-trained attention weights"},
    )
    attn_path: Optional[str] = field(
        default=None,
        metadata={"help": "path to attention weights (linear attentions). "},
    )
    attn_path_sub: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "list of the path to attention weights (sub attentions). [path_to_down_projection_weights, path_to_up_projection_weights]"
        },
    )
    # --- ablation switches ---
    ignore_target: bool = field(
        default=False,
        metadata={"help": "Whether to ignore the new target tokens. Mainly for ablation."},
    )
    fix_attention: bool = field(
        default=False,
        metadata={"help": "this will make the attention weights frozen during training. Mainly for ablation."},
    )
    temperature: float = field(
        default=2000,
        metadata={"help": "set the soft max temperature of ATTEMPT."},
    )
    attn_learning_rate: float = field(
        default=None,
        metadata={"help": "set the learning rate for the attention modules."},
    )
    # --- layer norm loading ---
    load_layer_norm: bool = field(
        default=False,
        metadata={"help": "Set true if you want to load pre-trained layer-norm weight and biases."},
    )
    layer_norm_dir: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Layer norm dir. [path_to_layer_norm_weight.pt, path_to_layer_norm_bias.pt]"},
    )
    prefix_num: Optional[int] = field(
        default=1, metadata={"help": "the number of prefix"})
@dataclass
class TrainingArguments(Seq2SeqTrainingArguments):
    """Seq2SeqTrainingArguments extended with wandb logging, timing/memory
    profiling, and prefix-tuning knobs used by this repository."""
    wandb_log: Optional[bool] = field(default=False,
                                      metadata={"help": "If set, logs experimental results to wandb"})
    # Fixed doubled word in the help text ("the the").
    wandb_entity: Optional[str] = field(default="lklab_kaist",
                                        metadata={"help": "Set to the wandb entity name"})
    wandb_project: Optional[str] = field(default="retrieval_of_experts",
                                         metadata={"help": "Set to the project name of wandb"})
    # Fixed typo: "Desingate" -> "Designate".
    wandb_run_name: Optional[str] = field(default="default_run_yolo",
                                          metadata={"help": "Designate the wandb run name"})
    print_num_parameters: Optional[bool] = field(default=False, metadata={
        "help": "If set, print the parameters of the model."})
    # Fixed copy-pasted help text: do_train runs training, it does not evaluate.
    do_train: Optional[bool] = field(default=False, metadata={
        "help": "If set, runs training."})
    do_eval: Optional[bool] = field(default=False, metadata={
        "help": "If set, evaluates the eval performance."})
    do_test: Optional[bool] = field(default=False, metadata={
        "help": "If set, evaluates the test performance."})
    # Fixed missing spaces between the concatenated help-string fragments
    # (previously rendered e.g. "do nothave the test set").
    split_validation_test: Optional[bool] = field(default=False,
                                                  metadata={"help": "If set, for the datasets which do not "
                                                            "have the test set, we use validation set as their "
                                                            "test set and make a validation set from either "
                                                            "splitting the validation set into half (for smaller "
                                                            "than 10K samples datasets), or by using 1K examples "
                                                            "from training set as validation set (for larger "
                                                            "datasets)."})
    compute_time: Optional[bool] = field(
        default=False, metadata={"help": "If set measures the time."})
    compute_memory: Optional[bool] = field(
        default=False, metadata={"help": "if set, measures the memory"})
    prefix_length: Optional[int] = field(
        default=100, metadata={"help": "Defines the length for prefix tuning."})
    report_to: Optional[str] = field(default="wandb")
    save_strategy: Optional[str] = field(default="no")
class MultiNews(datasets.GeneratorBasedBuilder):
    """Builder for the Multi-News multi-document summarization dataset.

    Reads the pre-cleaned parallel source/target files listed in the
    module-level `_URLs` mapping and yields (document, summary) string pairs.
    """

    VERSION = datasets.Version("1.0.0")

    def _info(self):
        feature_spec = datasets.Features(
            {_DOCUMENT: datasets.Value("string"), _SUMMARY: datasets.Value("string")}
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=feature_spec,
            supervised_keys=(_DOCUMENT, _SUMMARY),
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators for the train and validation splits."""
        # Files are read from the local `_REPO` directory, so no download step.
        split_specs = [
            (datasets.Split.TRAIN, _URLs["train"]),
            (datasets.Split.VALIDATION, _URLs["val"]),
        ]
        return [
            datasets.SplitGenerator(
                name=split_name,
                gen_kwargs={"src_file": paths[0], "tgt_file": paths[1]},
            )
            for split_name, paths in split_specs
        ]

    def _generate_examples(self, src_file, tgt_file):
        """Yields (index, example) pairs from the parallel files."""
        with open(src_file, encoding="utf-8") as src_fh, open(tgt_file, encoding="utf-8") as tgt_fh:
            for idx, (doc_line, summary_line) in enumerate(zip(src_fh, tgt_fh)):
                # Each line is one example; natural newlines were replaced with
                # "NEWLINE_CHAR" upstream — restore them to avoid a special vocab token.
                yield idx, {
                    _DOCUMENT: doc_line.strip().replace("NEWLINE_CHAR", "\n"),
                    _SUMMARY: summary_line.strip(),
                }
@dataclass
class TaskDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    """Seq2Seq collator that threads the per-example ``task`` name (and the
    optional ``labels_list``) through batching.

    All examples in a batch must belong to the same task; the scalar task name
    is re-attached to the padded batch so downstream code can dispatch on it.
    """

    def check_uniqueness(self, samples):
        # A batch mixing tasks would be silently mis-routed; fail fast instead.
        assert len(np.unique(samples)) == 1

    def __call__(self, features):
        # Removed leftover debug prints and the dead commented-out legacy class.
        # Pop fields the parent collator cannot pad before delegating to it.
        tasks = [d.pop('task') for d in features]
        labels_list = None
        if 'labels_list' in features[0]:
            labels_list = [d.pop('labels_list') for d in features]
        self.check_uniqueness(tasks)
        output = super().__call__(features)
        output["task"] = tasks[0]
        if labels_list is not None:
            output["labels_list"] = labels_list
        return output
import abc
from collections import OrderedDict

import numpy as np

"""Defines functions to process the outputs to make them ready for the evaluation."""


def string_to_float(string, default=-1., **unused_kwargs):
    """Converts string to float, using default when conversion not possible."""
    try:
        return float(string)
    except ValueError:
        return default


class PostProcessor(abc.ABC):
    """Decodes model predictions and labels into stripped text for evaluation."""

    def __init__(self, tokenizer, ignore_pad_token_for_loss):
        self.tokenizer = tokenizer
        self.ignore_pad_token_for_loss = ignore_pad_token_for_loss

    def process(self, preds, labels, data_info=None):
        if isinstance(preds, tuple):
            preds = preds[0]
        if self.ignore_pad_token_for_loss:
            # -100 marks ignored positions; swap in the pad id so decoding works.
            pad_id = self.tokenizer.pad_token_id
            labels = np.where(labels != -100, labels, pad_id)
            preds = np.where(preds != -100, preds, pad_id)
        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
        # Trim surrounding whitespace before metric computation.
        return ([text.strip() for text in decoded_preds],
                [text.strip() for text in decoded_labels])


class MultiRC(PostProcessor):
    """Attaches the MultiRC question-group id to every decoded string."""

    def process(self, preds, labels, data_info):
        preds, labels = super().process(preds, labels, data_info)
        preds = [{"group": info["group"], "value": text}
                 for info, text in zip(data_info, preds)]
        labels = [{"group": info["group"], "value": text}
                  for info, text in zip(data_info, labels)]
        return preds, labels
class Record(PostProcessor):
    """ReCoRD: replace decoded labels with the full gold-answer sets."""

    def process(self, preds, labels, data_info):
        preds, labels = super().process(preds, labels, data_info)
        # ReCoRD has multiple gold answers per example; evaluation needs all of them.
        labels = [info["answers"] for info in data_info]
        return preds, labels


POSTPROCESSOR_MAPPING = OrderedDict(
    [
        ('superglue-record', Record),
        ('superglue-multirc', MultiRC)
    ]
)


class AutoPostProcessor:
    """Factory that returns the task-specific postprocessor, or the generic
    PostProcessor when the task has no dedicated one."""

    @classmethod
    def get(cls, task, tokenizer, ignore_pad_token_for_loss):
        # Fixed: the first parameter of a classmethod is `cls`, not `self`.
        if task in POSTPROCESSOR_MAPPING:
            return POSTPROCESSOR_MAPPING[task](tokenizer, ignore_pad_token_for_loss)
        return PostProcessor(tokenizer, ignore_pad_token_for_loss)
def round_stsb_target(label):
    """Round an STS-B similarity score to the nearest 0.2 increment.

    STSB maps two sentences to a floating point number between 1 and 5
    representing their semantic similarity. Since we are treating all tasks as
    text-to-text tasks we need to convert this floating point number to a string.
    The vast majority of the similarity score labels in STSB are in the set
    [0, 0.2, 0.4, ..., 4.8, 5.0]. So, we first round the number to the closest
    entry in this set, and then we convert the result to a string (literally e.g.
    "3.4"). This converts STSB roughly into a 26-class classification dataset.
    Args:
      label: original label.
    Returns:
      A preprocessed label.
    """
    # Fixed: np.round((label * 5) / 5, decimals=1) reduces to np.round(label, 1),
    # i.e. a 0.1 grid — not the documented 0.2 grid. Round *before* dividing,
    # matching T5's tf.round(label * 5) / 5; the outer round just cleans up the
    # float representation to one decimal place.
    return np.round(np.round(label * 5) / 5, decimals=1)
# Build a fixed-size, seed-shuffled evaluation subset for every task/prompt pair.
for task_name in task_list:
    task = task_list[task_name]
    config = task['config']

    eval_instances[task_name] = {}
    for prompt in task['prompts']:
        data_class = AutoTask.get(task_name, config, prompt)
        eval_instances[task_name].setdefault(prompt, {})
        eval_data = data_class.load_dataset('validation')

        # Fixed: cap the sample count per prompt. Previously a small dataset
        # shrank n_obs and the shrunken value leaked into every later prompt of
        # the same task.
        n_obs = min(len(eval_data), max_num_instances)
        eval_indices = shuffled_indices(eval_data)[:n_obs]
        eval_data = eval_data.select(eval_indices)

        for idx, data in enumerate(eval_data):
            d = data_class.preprocessor(data)
            if 'labels_list' in d:
                e_formatted_d = {
                    "config": config,
                    "task": d['task'],
                    "prompt": prompt,
                    "source": d['source'],
                    "target": d['target'],
                    "labels_list": d['labels_list'],
                }
            else:
                e_formatted_d = {
                    "config": config,
                    "prompt": prompt,
                    "source": d['source'],
                    "target": d['target'],
                }
            eval_instances[task_name][prompt][str(idx)] = e_formatted_d

print(len(eval_instances))
print(len(eval_instances['emdg']))

with open(f'target_eval_{max_num_instances}.json', 'w') as f:
    json.dump(eval_instances, f, indent=4)
# Build a fixed-size, seed-shuffled training subset for every task/prompt pair.
# NOTE(review): like the original, this draws from the 'validation' split —
# confirm that is intentional for the target-train extraction.
for task_name in task_list:
    task = task_list[task_name]
    config = task['config']

    train_instances[task_name] = {}
    for prompt in task['prompts']:
        data_class = AutoTask.get(task_name, config, prompt)
        train_instances[task_name].setdefault(prompt, {})
        train_data = data_class.load_dataset('validation')

        # Fixed: cap the sample count per prompt. Previously a small dataset
        # shrank n_obs and the shrunken value leaked into every later prompt of
        # the same task.
        n_obs = min(len(train_data), max_num_instances)
        train_indices = shuffled_indices(train_data)[:n_obs]
        train_data = train_data.select(train_indices)

        for idx, data in enumerate(train_data):
            d = data_class.preprocessor(data)
            if 'labels_list' in d:
                e_formatted_d = {
                    "config": config,
                    "task": d['task'],
                    "prompt": prompt,
                    "source": d['source'],
                    "target": d['target'],
                    "labels_list": d['labels_list'],
                }
            else:
                e_formatted_d = {
                    "config": config,
                    "prompt": prompt,
                    "source": d['source'],
                    "target": d['target'],
                }
            train_instances[task_name][prompt][str(idx)] = e_formatted_d

print(len(train_instances))
print(len(train_instances['emdg']))

with open(f'target_train_{max_num_instances}.json', 'w') as f:
    json.dump(train_instances, f, indent=4)
# The codes are from https://github.com/bayer-science-for-a-better-life/phc-gnn
import torch
import math

# Glorot initializers with ReLU gain; hoisted so both helpers share one value.
_GLOROT_GAIN = math.sqrt(2)


def glorot_normal(tensor: torch.Tensor):
    """In-place Xavier-normal init with gain sqrt(2); returns the same tensor."""
    return torch.nn.init.xavier_normal_(tensor, gain=_GLOROT_GAIN)


def glorot_uniform(tensor: torch.Tensor):
    """In-place Xavier-uniform init with gain sqrt(2); returns the same tensor."""
    return torch.nn.init.xavier_uniform_(tensor, gain=_GLOROT_GAIN)
# The codes are from https://github.com/bayer-science-for-a-better-life/phc-gnn
import torch

# TODO: change this with torch.kron
"""A part of the pylabyk library: numpytorch.py at https://github.com/yulkang/pylabyk"""


def kronecker_product(a, b):
    """Kronecker product of matrices a and b with leading batch dimensions.

    Batch dimensions are broadcast against each other; only the trailing two
    dimensions are combined.
    :type a: torch.Tensor
    :type b: torch.Tensor
    :rtype: torch.Tensor
    """
    # Trailing output shape is the element-wise product of the two matrix shapes.
    out_mat_shape = torch.Size(torch.tensor(a.shape[-2:]) * torch.tensor(b.shape[-2:]))
    # Interleave singleton axes so broadcasting forms every pairwise product.
    expanded = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
    batch_shape = expanded.shape[:-4]
    return expanded.reshape(batch_shape + out_mat_shape)


def kronecker_product_einsum_batched(A: torch.Tensor, B: torch.Tensor):
    """Batched Version of Kronecker Products
    :param A: has shape (b, a, c)
    :param B: has shape (b, k, p)
    :return: (b, ak, cp)
    """
    assert A.dim() == 3 and B.dim() == 3
    pairwise = torch.einsum('bac,bkp->bakcp', A, B)
    return pairwise.view(A.size(0),
                         A.size(1) * B.size(1),
                         A.size(2) * B.size(2))
def _normalize_answer(text, punc_chars, punc_repl):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def strip_articles(s):
        return re.sub(r"\b(a|an|the)\b", " ", s)

    def swap_punctuation(s):
        removable = set(punc_chars)
        return "".join(punc_repl if ch in removable else ch for ch in s)

    def collapse_whitespace(s):
        return " ".join(s.split())

    # Order matters: lowercase, drop punctuation, drop articles, squeeze spaces.
    return collapse_whitespace(strip_articles(swap_punctuation(text.lower())))


def normalize_trivia_qa(answer):
    """Normalization used in official TriviaQA evaluation script."""
    return _normalize_answer(
        answer, punc_chars=string.punctuation + "‘’´`_", punc_repl=" ").strip()


def normalize_squad(answer):
    """Normalization used in official SQuAD evaluation script."""
    return _normalize_answer(answer, punc_chars=string.punctuation, punc_repl="")
def _metric_max_over_ground_truths(metric_fn, ground_truths, prediction):
    """Computes the maximum of the metric over all ground truths."""
    return max(metric_fn(gt, prediction) for gt in ground_truths)


def _exact_match_score(target, prediction):
    """Binary exact string match."""
    return target == prediction


def _f1_score(target, prediction):
    """Computes token f1 score for a single target and prediction."""
    pred_tokens = prediction.split()
    gold_tokens = target.split()
    overlap = (collections.Counter(pred_tokens) &
               collections.Counter(gold_tokens))
    num_same = sum(overlap.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_tokens)
    recall = 1.0 * num_same / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)


def qa_metrics(targets, predictions):
    """Computes exact match and f1 QA scores, expecting pre-normalized text.

    `targets` is a list of answer lists (multiple golds per example);
    `predictions` is a parallel list of single predicted strings.
    """
    if len(targets) != len(predictions):
        raise ValueError("Number of targets and predictions must match.")
    pairs = list(zip(predictions, targets))
    em = np.mean(
        [_metric_max_over_ground_truths(_exact_match_score, gold, pred)
         for pred, gold in pairs])
    f1 = np.mean(
        [_metric_max_over_ground_truths(_f1_score, gold, pred)
         for pred, gold in pairs])
    # Report as percentages.
    return {"em": em * 100, "f1": f1 * 100}
forward declarations 11 | 12 | void fast_walsh_hadamard_transform_cuda_kernel(const int NN, const int halfLL, torch::Tensor in, torch::Tensor out, bool normalize); 13 | 14 | // C++ interface 15 | 16 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 17 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 18 | #define CHECK_INPUT(x) \ 19 | CHECK_CUDA(x); \ 20 | CHECK_CONTIGUOUS(x) 21 | 22 | torch::Tensor fast_walsh_hadamard_transform(torch::Tensor input, bool normalize) 23 | { 24 | CHECK_INPUT(input); 25 | const int NN = input.numel(); 26 | torch::Tensor output_flat = input.clone(); 27 | int ll = 0; 28 | int LL = 1; 29 | while (LL < NN) 30 | { 31 | ll += 1; 32 | LL *= 2; 33 | } 34 | const int halfLL = LL / 2; 35 | fast_walsh_hadamard_transform_cuda_kernel(NN, halfLL, input, output_flat, normalize); 36 | return output_flat; 37 | } 38 | 39 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 40 | { 41 | m.def("fast_walsh_hadamard_transform", &fast_walsh_hadamard_transform, "Fast Walsh Hadamard Transform (CUDA)"); 42 | } 43 | -------------------------------------------------------------------------------- /seq2seq/projections/fwh_cuda/fwh_cu.cu: -------------------------------------------------------------------------------- 1 | // The codes are from Armen Aghajanyan from facebook, from paper 2 | // Intrinsic Dimensionality Explains the Effectiveness of Language Model Fine-Tuning 3 | // https://arxiv.org/abs/2012.13255 4 | 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | template 13 | __global__ void FastWalshHadamardKernel(const int stride, const scalar_t* in, scalar_t* out) { 14 | const auto idx = (threadIdx.x + blockIdx.x * blockDim.x); 15 | const auto elemIdx = (idx / stride ) * (2 * stride) + (idx % stride); 16 | const auto tmp = in[elemIdx], tmp2 = in[elemIdx + stride]; 17 | out[elemIdx] = tmp + tmp2; 18 | out[elemIdx + stride] = tmp - tmp2; 19 | } 20 | 21 | template 22 | __global__ void 
// The codes are from Armen Aghajanyan from facebook, from paper
// Intrinsic Dimensionality Explains the Effectiveness of Language Model Fine-Tuning
// https://arxiv.org/abs/2012.13255

// NOTE(review): all angle-bracketed text appears to have been stripped by a
// text conversion — the #include targets, the template parameter lists
// (template <typename scalar_t>), the kernel launch configurations
// <<<grid, block>>>, and the data_ptr template arguments are missing, so this
// file will not compile as-is. Recover the originals from the upstream
// repository before building.
#include

#include
#include

#include

// One butterfly stage: each thread combines the pair at (elemIdx, elemIdx+stride)
// into their sum and difference.
template
__global__ void FastWalshHadamardKernel(const int stride, const scalar_t* in, scalar_t* out) {
    const auto idx = (threadIdx.x + blockIdx.x * blockDim.x);
    const auto elemIdx = (idx / stride ) * (2 * stride) + (idx % stride);
    const auto tmp = in[elemIdx], tmp2 = in[elemIdx + stride];
    out[elemIdx] = tmp + tmp2;
    out[elemIdx + stride] = tmp - tmp2;
}

// Scales each element by `scalar` (used for the accumulated 1/sqrt(2)^k
// normalization factor).
template
__global__ void FastWalshHadamardSubKernel(const scalar_t scalar, scalar_t* out) {
    const auto idx = (threadIdx.x + blockIdx.x * blockDim.x);
    out[idx] *= scalar;
}


// Runs the butterfly stages with stride halved each pass (first stage reads
// from `in`, later stages update `out` in place), then an optional final
// scaling pass when `normalize` is set.
void fast_walsh_hadamard_transform_cuda_kernel(const int NN, const int halfLL, torch::Tensor in, torch::Tensor out, bool normalize) {
    // Apply Unnormalized Fast Walsh Hadamard transform
    int stride = halfLL;
    float normalizer = 1.0;
    float sqrt2inv = 0.70710678118654746;

    while (stride >= 1) {
        if(stride == halfLL)
        {
            AT_DISPATCH_FLOATING_TYPES(in.scalar_type(),"fast_walsh_hadamard_transform_in", ([&] {
                FastWalshHadamardKernel<<>>(stride, in.data_ptr(), out.data_ptr());
            }));
        }
        else
        {
            AT_DISPATCH_FLOATING_TYPES(in.scalar_type(),"fast_walsh_hadamard_transform_out", ([&] {
                FastWalshHadamardKernel<<>>(stride, out.data_ptr(), out.data_ptr());
            }));
        }

        stride /= 2;
        normalizer *= sqrt2inv;
    }
    if(normalize){
        AT_DISPATCH_FLOATING_TYPES(in.scalar_type(),"fast_walsh_hadamard_transform_final", ([&] {
            FastWalshHadamardSubKernel<<>>(normalizer, out.data_ptr());
        }));
    }
}
1 | from .modeling_t5 import T5ForConditionalGeneration, T5LayerNorm 2 | from .configuration_t5 import T5Config 3 | -------------------------------------------------------------------------------- /seq2seq/third_party/models/t5/configuration_t5.py: -------------------------------------------------------------------------------- 1 | """ T5 model configuration """ 2 | from transformers.models.t5 import T5Config 3 | 4 | class T5Config(T5Config): 5 | def __init__(self, 6 | train_task_adapters=False, 7 | prefix_tuning=False, 8 | **kwargs): 9 | super().__init__(**kwargs) 10 | self.train_task_adapters = train_task_adapters 11 | self.prefix_tuning = prefix_tuning 12 | -------------------------------------------------------------------------------- /seq2seq/third_party/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import BaseTrainer 2 | from .seq2seq_trainer import Seq2SeqTrainer 3 | -------------------------------------------------------------------------------- /seq2seq/third_party/trainers/seq2seq_trainer.py: -------------------------------------------------------------------------------- 1 | from packaging import version 2 | import torch 3 | from torch import nn 4 | from typing import Any, Dict, List, Optional, Tuple, Union 5 | import wandb 6 | 7 | from torch.utils.data.dataset import Dataset 8 | from transformers import Seq2SeqTrainer 9 | from .trainer import BaseTrainer 10 | 11 | if version.parse(torch.__version__) >= version.parse("1.6"): 12 | from torch.cuda.amp import autocast 13 | 14 | class Seq2SeqTrainer(BaseTrainer, Seq2SeqTrainer): 15 | def __init__(self, train_dataset_sizes=None,shared=False, adapter_config=None, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.adapter_config = adapter_config 18 | self.train_dataset_sizes = train_dataset_sizes 19 | self.shared = shared 20 | 21 | def evaluate( 22 | self, 23 | eval_dataset: Optional[Dict[str, Dataset]] = None, 24 | 
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        max_length: Optional[int] = None,
        num_beams: Optional[int] = None,
        task : str = None
    ) -> Dict[str, float]:
        # TODO: this also needs to be set per dataset
        # Stash the generation settings on the instance so prediction_step()
        # can read them; the parent class drives the actual evaluation loop.
        self._max_length = max_length
        self._num_beams = num_beams
        # `task` is forwarded to the BaseTrainer.evaluate override (not part
        # of the stock HF signature).
        return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix, task=task)


    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
        task: str = None
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
63 | """ 64 | if not self.args.predict_with_generate or prediction_loss_only: 65 | return super().prediction_step( 66 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, task=task 67 | ) 68 | 69 | has_labels = "labels" in inputs 70 | inputs = self._prepare_inputs(inputs) 71 | gen_kwargs = { 72 | "max_length": self._max_length if self._max_length is not None else self.model.config.max_length, 73 | "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams, 74 | #"task": inputs["task"] if "task" in inputs else "all" 75 | } 76 | 77 | generated_tokens = self.model.generate( 78 | inputs["input_ids"], 79 | attention_mask=inputs["attention_mask"], 80 | **gen_kwargs, 81 | ) 82 | 83 | # in case the batch is shorter than max length, the output should be padded 84 | if generated_tokens.shape[-1] < gen_kwargs["max_length"]: 85 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 86 | 87 | with torch.no_grad(): 88 | if self.use_amp: 89 | with autocast(): 90 | outputs = model(**inputs) 91 | else: 92 | inputs.pop('task') 93 | outputs = model(**inputs) 94 | if has_labels: 95 | if self.label_smoother is not None: 96 | loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() 97 | else: 98 | loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() 99 | else: 100 | loss = None 101 | 102 | if self.args.prediction_loss_only: 103 | return (loss, None, None) 104 | 105 | labels = inputs["labels"] 106 | if labels.shape[-1] < gen_kwargs["max_length"]: 107 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 108 | 109 | return (loss, generated_tokens, labels) 110 | -------------------------------------------------------------------------------- /seq2seq/third_party/trainers/trainer_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Union, 
NamedTuple, Tuple, Dict, Any, Optional 3 | 4 | class EvalPrediction(NamedTuple): 5 | """ 6 | Evaluation output (always contains labels), to be used to compute metrics. 7 | Parameters: 8 | predictions (:obj:`np.ndarray`): Predictions of the model. 9 | label_ids (:obj:`np.ndarray`): Targets to be matched. 10 | data_info: (:obj:`Dict[str, Any]`): Extra dataset information, one requires 11 | to performs the evaluation. The data_info is a dictionary with keys from 12 | train, eval, test to specify the data_info for each split of the dataset. 13 | """ 14 | predictions: Union[np.ndarray, Tuple[np.ndarray]] 15 | label_ids: np.ndarray 16 | data_info: Dict[str, Any] 17 | input_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]] 18 | 19 | -------------------------------------------------------------------------------- /seq2seq/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import get_adapter_config, freeze_model_params,\ 2 | get_adapter_params_names, create_dir, get_last_checkpoint,\ 3 | pad_punctuation, modify_model_after_init, save_training_config -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Install Compacter.""" 2 | import os 3 | import setuptools 4 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 5 | 6 | #os.environ['TORCH_CUDA_ARCH_LIST']="3.5;3.7;6.1;7.0;7.5;8.6+PTX" 7 | 8 | def setup_package(): 9 | long_description = "seq2seq" 10 | setuptools.setup( 11 | name='seq2seq', 12 | version='0.0.1', 13 | description='Compacter', 14 | long_description=long_description, 15 | long_description_content_type='text/markdown', 16 | author='Rabeeh Karimi Mahabadi', 17 | license='MIT License', 18 | packages=setuptools.find_packages( 19 | exclude=['docs', 'tests', 'scripts', 'examples']), 20 | dependency_links=[ 21 | 
'https://download.pytorch.org/whl/torch_stable.html', 22 | ], 23 | classifiers=[ 24 | 'Intended Audience :: Developers', 25 | 'Intended Audience :: Science/Research', 26 | 'License :: OSI Approved :: MIT License', 27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 28 | 'Programming Language :: Python :: 3', 29 | 'Programming Language :: Python :: 3.7.10', 30 | ], 31 | keywords='text nlp machinelearning', 32 | ext_modules=[ 33 | CUDAExtension('seq2seq.projections.fwh_cuda', 34 | sources=[ 35 | 'seq2seq/projections/fwh_cuda/fwh_cpp.cpp', 36 | 'seq2seq/projections/fwh_cuda/fwh_cu.cu', 37 | ] 38 | ) 39 | ], 40 | cmdclass={"build_ext": BuildExtension}, 41 | install_requires=[ 42 | 'datasets==1.6.2', 43 | 'scikit-learn==0.24.2', 44 | 'tensorboard==2.5.0', 45 | 'matplotlib==3.4.2', 46 | 'torch==1.10.0+cu113', 47 | 'transformers==4.6.0' 48 | ], 49 | ) 50 | 51 | 52 | if __name__ == '__main__': 53 | setup_package() 54 | --------------------------------------------------------------------------------