├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── NOTICE
├── README.md
└── src
    └── instruction_tuning_for_few_shot_absa
        └── data
            ├── data_utils.py
            ├── laptop_data_conversion.py
            ├── prepare_kshot_data_cat.py
            └── prepare_kshot_data_sent.py


/.gitignore:
--------------------------------------------------------------------------------
*.DS_Store

--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# Contributing Guidelines

Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
documentation, we greatly value feedback and contributions from our community.

Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
information to effectively respond to your bug report or contribution.


## Reporting Bugs/Feature Requests

We welcome you to use the GitHub issue tracker to report bugs or suggest features.

When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:

* A reproducible test case or series of steps
* The version of our code being used
* Any modifications you've made relevant to the bug
* Anything unusual about your environment or deployment


## Contributing via Pull Requests
Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:

1. You are working against the latest source on the *main* branch.
2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
3. You open an issue to discuss any significant work - we would hate for your time to be wasted.

To send us a pull request, please:

1. Fork the repository.
2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
3. Ensure local tests pass.
4. Commit to your fork using clear commit messages.
5. Send us a pull request, answering any default questions in the pull request interface.
6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.

GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
[creating a pull request](https://help.github.com/articles/creating-a-pull-request/).


## Finding contributions to work on
Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.


## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.


## Security issue notifications
If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.


## Licensing

See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Instruction Tuning for Few-Shot Aspect-Based Sentiment Analysis

The package contains sources to construct the subsampled datasets for the few-shot experiments
used in the paper.

-------------------------------------------------------

Links to data:

REST15:

https://github.com/IsakZhang/ABSA-QUAD/tree/master/data/rest15

REST16:

https://github.com/IsakZhang/ABSA-QUAD/tree/master/data/rest16

LAP14:

https://github.com/xuuuluuu/Position-Aware-Tagging-for-ASTE/tree/master/data/ASTE-Data-V2/14lap

-------------------------------------------------------

Command to create the k-shot subsets for REST15/REST16

To create train subsamples for K=5:

python prepare_kshot_data_cat.py --input_file /train.txt --output_dir --num_shot 5 --num_repeat 1

For dev subsamples, replace train.txt with dev.txt

-------------------------------------------------------

Command to create the k-shot subsets for LAP14

Convert the laptop14 data above to quad format as follows:

python laptop_data_conversion.py --train_file /train_triplets.txt --dev_file /dev_triplets.txt --test_file /test_triplets.txt --output_dir

Now create train subsamples from the quad-format data (for K=5):

python prepare_kshot_data_sent.py --input_file /train.txt --output_dir --num_shot 5 --num_repeat 1

For dev subsamples, replace train.txt with dev.txt

# Citation
If you find the sources useful, please consider citing our work:

```
@inproceedings{varia-etal-2023-instruction,
    title={Instruction Tuning for Few-Shot Aspect-Based Sentiment Analysis},
    author={Varia, Siddharth and Wang, Shuai and Halder, Kishaloy and
        Vacareanu, Robert and Ballesteros, Miguel and Benajiba, Yassine and John, Neha Anna and Anubhai, Rishita and Muresan, Smaranda and Roth, Dan},
    year={2023},
    month = "jul",
    booktitle = "Proceedings of the 13th Workshop on Computational Approaches to Subjectivity, Sentiment and Social Media Analysis",
    publisher = "Association for Computational Linguistics"
}
```


--------------------------------------------------------------------------------
/src/instruction_tuning_for_few_shot_absa/data/data_utils.py:
--------------------------------------------------------------------------------

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0


import ast
from collections import defaultdict


# Maps the lower-case category text used in the data files to CATEGORY#ATTRIBUTE labels
absa_quad_text2category = {
    'location general':'LOCATION#GENERAL',
    'food prices':'FOOD#PRICES',
    'food quality':'FOOD#QUALITY',
    'food general':'FOOD#GENERAL',
    'ambience general':'AMBIENCE#GENERAL',
    'service general':'SERVICE#GENERAL',
    'restaurant prices':'RESTAURANT#PRICES',
    'drinks prices':'DRINKS#PRICES',
    'restaurant miscellaneous':'RESTAURANT#MISCELLANEOUS',
    'drinks quality':'DRINKS#QUALITY',
    'drinks style_options':'DRINKS#STYLE_OPTIONS',
    'restaurant general':'RESTAURANT#GENERAL',
    'food style_options':'FOOD#STYLE_OPTIONS',
    "laptop": "laptop",
    "LAPTOP": "laptop",

}


def read_absa_quad_from_file(data_path):
    """
    Read data from file, each line is: sent####labels
    Return List[List[str]], List[List[Tuple]], Dict
    """
    all_sents, all_labels = [], []
    unique_labels = defaultdict(int)
    with open(data_path, 'r', encoding='UTF-8') as fp:
        for line in fp:
            line = line.strip()
            if line != '':
                words, tuples = line.split('####')
                all_sents.append(words.split())
                # literal_eval parses the label list without executing arbitrary code
                tmp_labels = ast.literal_eval(tuples)
                new_labels = []
                for label in tmp_labels:
                    # aspect term, aspect category, sentiment polarity, opinion term
                    at,ac,sp,ot = label
                    if at == 'NULL':
                        at = 'none'
                    if ot == 'NULL':
                        ot = 'none'
                    if '#' not in ac:
                        ac = absa_quad_text2category[ac]
                    unique_labels[ac] += 1
                    new_labels.append((at.lower(),ac,sp,ot.lower()))
                all_labels.append(new_labels)
    return all_sents, all_labels, unique_labels

--------------------------------------------------------------------------------
/src/instruction_tuning_for_few_shot_absa/data/laptop_data_conversion.py:
--------------------------------------------------------------------------------

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Converts the data from the triplet format used by https://github.com/xuuuluuu/Position-Aware-Tagging-for-ASTE/ to
the format used in this code (quad)
Appends "laptop" as the category
"""

import ast
import os
import argparse

sentiment_map = {
    'POS': 'positive',
    'NEG': 'negative',
    'NEU': 'neutral',
}

def convert_line_to_quad(line: str) -> str:
    sentence, tuples = line.split('####')
    words = sentence.split(' ')
    # Each triplet is (aspect word indices, opinion word indices, sentiment tag)
    labels = ast.literal_eval(tuples.strip())
    new_labels = []
    for (aspect_term_indices, opinion_term_indices, sentiment) in labels:
        aspect_term = ' '.join([words[x] for x in aspect_term_indices])
        opinion_term = ' '.join([words[x] for x in opinion_term_indices])
        new_labels.append([aspect_term, 'laptop', sentiment_map[sentiment], opinion_term])
    return '####'.join([sentence, str(new_labels)])


def get_args():
    parser = argparse.ArgumentParser("Create laptop 14 data in quad format")
    parser.add_argument("--train_file", required=True, type=str, help="train file")
    parser.add_argument("--dev_file", required=True, type=str, help="dev file")
    parser.add_argument("--test_file", required=True, type=str, help="test file")
    parser.add_argument("--output_dir", required=True, type=str, help="Output directory to save new files")
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    # Convert each split and save it under the file name the k-shot scripts expect
    splits = [
        (args.train_file, 'train.txt'),
        (args.dev_file, 'dev.txt'),
        (args.test_file, 'test.txt'),
    ]
    for input_file, output_name in splits:
        with open(input_file) as fin:
            lines = [convert_line_to_quad(line) for line in fin]
        with open(os.path.join(args.output_dir, output_name), 'w') as fout:
            for line in lines:
                fout.write(f'{line}\n')

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/src/instruction_tuning_for_few_shot_absa/data/prepare_kshot_data_cat.py:
--------------------------------------------------------------------------------

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from collections import defaultdict
import os
import random
import argparse
from data_utils import read_absa_quad_from_file
from typing import List, Tuple, Set

import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)



# GLOBAL_SEED = 12347
SEED_INCREMENT = 1237


def _select_k_points_per_class(all_sents: List[List[str]], all_labels: List[List[Tuple]], unique_labels: Set[str], k: int):
    """Greedily keep sentences until every coarse category has at least k labels."""

    def label_not_covered(label_count, k):
        # Return the first category that still has fewer than k labels, if any
        for label in label_count:
            if label_count[label] < k:
                return label
        return None

    def update_label_count(labels, label_count):
        # Count every label of a selected sentence under its coarse category
        for label in labels:
            at,ac,sp,ot = label
            if '#' in ac:
                ac = ac.split('#')[0]
            label_count[ac] += 1

    subsampled_sents = []
    subsampled_labels = []
    label_count = {label: 0 for label in unique_labels}

    for i, labels in enumerate(all_labels):
        for label in labels:
            at,ac,sp,ot = label
            if '#' in ac:
                ac = ac.split('#')[0]
            if label_count[ac] >= k:
                continue
            else:
                # Keep the sentence as soon as one of its categories is still below k
                subsampled_sents.append(all_sents[i])
                subsampled_labels.append(labels)
                update_label_count(labels, label_count)
                break
        if label_not_covered(label_count, k) is None:
            break

    not_covered = label_not_covered(label_count, k)

    if not_covered:
        logger.info(f"Not enough labels to fulfil {k} samples for {not_covered}")
        # raise ValueError(f"Not enough labels to fulfil {k} samples for {not_covered}")

    return subsampled_sents, subsampled_labels, label_count


def write_subsampled_data(output_file: str, all_sents: List[List[str]], all_labels: List[List[Tuple]]):
    with open(output_file, 'w+') as fhw:
        for sent,labels in zip(all_sents, all_labels):
            sent = ' '.join(sent)
            line = f'{sent}####{repr(labels)}'
            fhw.write(line)
            fhw.write('\n')


def get_parser():
    parser = argparse.ArgumentParser("prepare kshot data")

    parser.add_argument("--input_file", required=True, type=str, help="Input file")
    parser.add_argument("--output_dir", required=True, type=str, help="Output directory")
    parser.add_argument("--num_shot", required=True, type=int, help="number of shots (k) for each category")
    parser.add_argument("--num_repeat", type=int, default=1, help="number of times to repeat the same sampling strategy")
    parser.add_argument("--seed", type=int, default=12347, help="random seed")
    return parser

def main():
    parser = get_parser()
    config = parser.parse_args()
    all_sents_original, all_labels_original, unique_labels_original = read_absa_quad_from_file(config.input_file)
    unique_labels = defaultdict(int)
    GLOBAL_SEED = config.seed
    # Collapse CATEGORY#ATTRIBUTE labels to their coarse category
    for k,v in unique_labels_original.items():
        k = k.split('#')[0]
        unique_labels[k] += v
    print('unique_labels:', unique_labels)
    for i in range(config.num_repeat):
        # Each repetition shuffles with its own seed so that the subsamples differ
        current_seed = GLOBAL_SEED+(i*SEED_INCREMENT)
        random.seed(current_seed)
        combined = list(zip(all_sents_original, all_labels_original))
        random.shuffle(combined)
        all_sents, all_labels= zip(*combined)
        subsampled_sents, subsampled_labels, label_count = _select_k_points_per_class(all_sents, all_labels, unique_labels, config.num_shot)
        output_file = os.path.join(config.output_dir, os.path.basename(config.input_file).rsplit('.',1)[0]+f'_k_{config.num_shot}_seed_{current_seed}.txt')
        write_subsampled_data(output_file, subsampled_sents, subsampled_labels)
        print(f'Iteration={i+1}, k={config.num_shot}, seed={current_seed}, label_count={label_count}')

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
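The subsampled files written by `write_subsampled_data` keep the `sentence####labels` layout that `read_absa_quad_from_file` reads. The round trip can be sketched as below, assuming the labels are plain Python literals. `format_quad_line` and `parse_quad_line` are hypothetical helpers for illustration and are not part of the repository; the parser relies on `ast.literal_eval`, which only accepts literals.

```python
import ast
from typing import List, Tuple

# (aspect term, category, sentiment, opinion term)
Quad = Tuple[str, str, str, str]


def format_quad_line(words: List[str], labels: List[Quad]) -> str:
    """Serialize one example as 'sentence####labels' (hypothetical helper)."""
    return f"{' '.join(words)}####{repr(labels)}"


def parse_quad_line(line: str) -> Tuple[List[str], List[Quad]]:
    """Parse one 'sentence####labels' line (hypothetical helper)."""
    sentence, raw_labels = line.strip().split('####')
    # literal_eval only accepts Python literals, so code embedded in a data file is rejected
    labels = [tuple(label) for label in ast.literal_eval(raw_labels)]
    return sentence.split(), labels


# Toy example, not taken from any of the datasets
example_words = ['the', 'battery', 'life', 'is', 'great']
example_labels = [('battery life', 'laptop', 'positive', 'great')]
line = format_quad_line(example_words, example_labels)
parsed_words, parsed_labels = parse_quad_line(line)
```

Converting each label to a tuple makes the parsed result independent of whether the file stores labels as lists or as tuples.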
/src/instruction_tuning_for_few_shot_absa/data/prepare_kshot_data_sent.py:
--------------------------------------------------------------------------------

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from collections import Counter, defaultdict
import os
import random
import argparse
from data_utils import read_absa_quad_from_file
from typing import List, Tuple, Set

import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)



# GLOBAL_SEED = 12347
SEED_INCREMENT = 1237


def _select_k_points_per_class(all_sents: List[List[str]], all_labels: List[List[Tuple]], unique_labels: Set[str], k: int):
    """Greedily keep sentences until every sentiment has at least k labels."""

    def label_not_covered(label_count, k):
        # Return the first sentiment that still has fewer than k labels, if any
        for label in label_count:
            if label_count[label] < k:
                return label
        return None

    def update_label_count(labels, label_count):
        # Count every label of a selected sentence under its sentiment
        for label in labels:
            at,ac,sp,ot = label
            label_count[sp] += 1

    subsampled_sents = []
    subsampled_labels = []
    label_count = {label: 0 for label in unique_labels}

    for i, labels in enumerate(all_labels):
        for label in labels:
            at,ac,sp,ot = label
            if label_count[sp] >= k:
                continue
            else:
                # Keep the sentence as soon as one of its sentiments is still below k
                subsampled_sents.append(all_sents[i])
                subsampled_labels.append(labels)
                update_label_count(labels, label_count)
                break
        if label_not_covered(label_count, k) is None:
            break

    not_covered = label_not_covered(label_count, k)

    if not_covered:
        logger.info(f"Not enough labels to fulfil {k} samples for {not_covered}")
        # raise ValueError(f"Not enough labels to fulfil {k} samples for {not_covered}")

    return subsampled_sents, subsampled_labels, label_count


def write_subsampled_data(output_file: str, all_sents: List[List[str]], all_labels: List[List[Tuple]]):
    with open(output_file, 'w+') as fhw:
        for sent,labels in zip(all_sents, all_labels):
            sent = ' '.join(sent)
            line = f'{sent}####{repr(labels)}'
            fhw.write(line)
            fhw.write('\n')


def get_parser():
    parser = argparse.ArgumentParser("prepare kshot data")

    parser.add_argument("--input_file", required=True, type=str, help="Input file")
    parser.add_argument("--output_dir", required=True, type=str, help="Output directory")
    parser.add_argument("--num_shot", required=True, type=int, help="number of shots (k) for each sentiment")
    parser.add_argument("--num_repeat", type=int, default=1, help="number of times to repeat the same sampling strategy")
    parser.add_argument("--seed", type=int, default=12347, help="random seed")
    return parser

# Similar to the original prepare_kshot_data.py, but adapted for LAP14. For REST15 and REST16 we split based on category
# (to make sure we cover every category), but LAP14 either has no categories or has categories that are too fine-grained
# to be useful, so we split based on sentiment.
# NOTE: This script could have been more generic, letting the caller choose what to split on
def main():
    parser = get_parser()
    config = parser.parse_args()
    all_sents_original, all_labels_original, unique_labels_original = read_absa_quad_from_file(config.input_file)
    # print(unique_labels_original)
    # Count sentiments (third element of each quad) instead of categories
    unique_labels_original = Counter([y[2] for x in all_labels_original for y in x])
    print(unique_labels_original)
    unique_labels = defaultdict(int)
    GLOBAL_SEED = config.seed
    for k,v in unique_labels_original.items():
        k = k.split('#')[0]
        unique_labels[k] += v
    print('unique_labels:', unique_labels)
    for i in range(config.num_repeat):
        # Each repetition shuffles with its own seed so that the subsamples differ
        current_seed = GLOBAL_SEED+(i*SEED_INCREMENT)
        random.seed(current_seed)
        combined = list(zip(all_sents_original, all_labels_original))
        random.shuffle(combined)
        all_sents, all_labels= zip(*combined)
        subsampled_sents, subsampled_labels, label_count = _select_k_points_per_class(all_sents, all_labels, unique_labels, config.num_shot)
        output_file = os.path.join(config.output_dir, os.path.basename(config.input_file).rsplit('.',1)[0]+f'_k_{config.num_shot}_seed_{current_seed}.txt')
        write_subsampled_data(output_file, subsampled_sents, subsampled_labels)
        print(f'Iteration={i+1}, k={config.num_shot}, seed={current_seed}, label_count={label_count}')

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
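The NOTE in `prepare_kshot_data_sent.py` suggests that the sampler could take the split criterion as a parameter. One way this might look is sketched below, assuming the same greedy strategy as `_select_k_points_per_class`. The names `select_k_per_class`, `key_fn`, `by_sentiment`, and `by_coarse_category` are hypothetical and not part of the repository, and the toy data is made up for illustration.

```python
from collections import Counter
from typing import Callable, Dict, List, Sequence, Tuple

# (aspect term, category, sentiment, opinion term)
Quad = Tuple[str, str, str, str]


def select_k_per_class(
    all_sents: Sequence[List[str]],
    all_labels: Sequence[List[Quad]],
    k: int,
    key_fn: Callable[[Quad], str],
) -> Tuple[List[List[str]], List[List[Quad]], Dict[str, int]]:
    """Greedy k-shot selection keyed by key_fn (hypothetical generalization)."""
    classes = {key_fn(label) for labels in all_labels for label in labels}
    counts = Counter({c: 0 for c in classes})
    picked_sents, picked_labels = [], []
    for sent, labels in zip(all_sents, all_labels):
        # Keep the sentence if any of its labels belongs to a class still below k
        if any(counts[key_fn(label)] < k for label in labels):
            picked_sents.append(sent)
            picked_labels.append(labels)
            counts.update(key_fn(label) for label in labels)
        # Stop once every class has reached k
        if all(counts[c] >= k for c in classes):
            break
    return picked_sents, picked_labels, dict(counts)


def by_sentiment(label: Quad) -> str:
    """Split criterion used for LAP14: the sentiment polarity."""
    return label[2]


def by_coarse_category(label: Quad) -> str:
    """Split criterion used for REST15/REST16: the part of the category before '#'."""
    return label[1].split('#')[0]


# Toy data, already in the order a seeded shuffle might have produced
sents = [['good', 'screen'], ['bad', 'keyboard'], ['fine', 'speakers'], ['great', 'battery']]
labels = [
    [('screen', 'laptop', 'positive', 'good')],
    [('keyboard', 'laptop', 'negative', 'bad')],
    [('speakers', 'laptop', 'neutral', 'fine')],
    [('battery', 'laptop', 'positive', 'great')],
]
picked_sents, picked_labels, counts = select_k_per_class(sents, labels, k=1, key_fn=by_sentiment)
```

With a key function as the only varying part, the category and sentiment scripts would differ in a single argument; shuffling and file writing could stay as they are.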