├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── THIRD-PARTY-LICENSES ├── anollm ├── __init__.py ├── anollm.py ├── anollm_dataset.py ├── anollm_trainer.py └── anollm_utils.py ├── evaluate_anollm.py ├── evaluate_baselines.py ├── figs └── overview.png ├── requirements.txt ├── scripts ├── exp1-mixed_benchmark │ ├── run_anollm.sh │ └── run_baselines.sh ├── exp2-odds │ ├── run_anollm.sh │ └── run_baselines.sh ├── exp3-binning_effect │ └── run_binning_odds.sh └── exp4-model_size │ ├── run_anollm_1.7B_mixed.sh │ └── run_anollm_1.7B_odds.sh ├── src ├── __init__.py ├── baselines │ ├── dte.py │ └── icl.py ├── data_utils.py ├── get_avg_results.py └── get_results.py └── train_anollm.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 
38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AnoLLM: Large Language Models for Tabular Anomaly Detection (ICLR 2025) 2 | 3 |

4 | 5 | GitHub License 6 | 7 | 8 | Openreview 9 | 10 |

11 | 12 | This repository contains the implementation of the paper: 13 | > **AnoLLM: Large Language Models for Tabular Anomaly Detection**
14 | > International Conference on Learning Representations (ICLR 2025)
15 | > Che-Ping Tsai, Ganyu Teng, Phil Wallis, Wei Ding.
16 | 17 | ## Introduction 18 | 19 |
20 | ![Model Logo](figs/overview.png) 21 |
22 |
23 |
24 | AnoLLM is a novel framework that leverages large language models (LLMs) for unsupervised tabular anomaly detection. It can effectively handle mixed-type tabular data (e.g., continuous/numerical, discrete/categorical, and textual features) by adapting a pre-trained LLM to tabular data serialized as text. During inference, AnoLLM assigns anomaly scores based on the negative log-likelihood computed by the LLM. Our empirical results indicate that AnoLLM delivers the best performance on six benchmark datasets with mixed feature types. 25 | 26 | 27 | ## Installing Dependencies 28 | 29 | Python version: 3.10 30 | 31 | 32 | Create the environment: 33 | 34 | ``` 35 | conda create -n anollm python=3.10 36 | conda activate anollm 37 | ``` 38 | 39 | Install the required packages: 40 | 41 | ``` 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | Install PyTorch, ensuring that the version you choose is compatible with your CUDA version: 46 | ``` 47 | pip install torch==2.3.1 48 | ``` 49 | 50 | Pin the pyod version to avoid known bugs: 51 | ``` 52 | pip install pyod==2.0.1 53 | ``` 54 | 55 | ## Rerun our experiments 56 | 57 | 1. Download the following datasets from Kaggle and place them in ``data/[dataset_name]/``: 58 | - [vifd](https://www.kaggle.com/datasets/khusheekapoor/vehicle-insurance-fraud-detection/data) (Vehicle Insurance Fraud Detection) 59 | - [fraudecom](https://www.kaggle.com/datasets/vbinh002/fraud-ecommerce/data) (Fraud E-commerce) 60 | 2. Run the corresponding scripts for each experiment: 61 | ``` 62 | bash scripts/exp1-mixed_benchmark/run_anollm.sh 63 | bash scripts/exp1-mixed_benchmark/run_baselines.sh 64 | bash scripts/exp2-odds/run_anollm.sh 65 | bash scripts/exp2-odds/run_baselines.sh 66 | bash scripts/exp3-binning_effect/run_binning_odds.sh 67 | bash scripts/exp4-model_size/run_anollm_1.7B_mixed.sh 68 | bash scripts/exp4-model_size/run_anollm_1.7B_odds.sh 69 | ``` 70 | 71 | ## Using your own datasets 72 | 73 | To use a custom dataset, create a dataframe with the structure ``{feature_name: feature_values}``. Please refer to the ``load_dataset()`` function in ``src/data_utils.py`` for further guidance; a minimal programmatic scoring sketch is also provided at the end of the Evaluation section below. 74 | 75 | ### Training Models 76 | 77 | For AnoLLM, we use the following command: 78 | 79 | ``` 80 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --binning standard --setting semi_supervised --max_steps 2000 --batch_size $batch_size --model $model 81 | ``` 82 | Check the argument parser in ``train_anollm.py`` for the available dataset and model options. 83 | 84 | For baselines, we use the following command: 85 | 86 | ``` 87 | CUDA_VISIBLE_DEVICES=0 python evaluate_baselines.py --dataset $dataset --n_splits $n_splits --normalize --setting semi_supervised --split_idx $split_idx 88 | ``` 89 | 90 | Check the argument parser in ``evaluate_baselines.py`` for the available dataset options. 91 | 92 | ### Evaluation 93 | 94 | To evaluate AnoLLM, we use the following command: 95 | ``` 96 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting semi_supervised --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 97 | ``` 98 | 99 | Finally, we aggregate the saved anomaly scores and report evaluation metrics such as AUC-ROC.
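For custom datasets, the snippet below is a minimal sketch of how the pieces fit together programmatically, using the `AnoLLM` class from ``anollm/anollm.py``. It assumes a checkpoint that has already been fine-tuned and saved by ``train_anollm.py`` (training itself is launched with ``torchrun`` as shown above, since the trainer relies on a distributed sampler); the dataset, column names, and checkpoint path are hypothetical placeholders.

```
# Minimal scoring sketch (illustrative only): the data, column names, and
# checkpoint path are hypothetical; adapt them to your own setup.
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score

from anollm import AnoLLM

# Hypothetical path to a checkpoint saved by AnoLLM.save_state_dict
# (see train_anollm.py / evaluate_anollm.py for the actual directory layout).
MODEL_PATH = "exp/my_dataset/models/anollm.pt"

# Tabular data as {feature_name: feature_values}; mixed feature types are supported.
df_test = pd.DataFrame({
    "age": [25, 41, 37],
    "occupation": ["teacher", "engineer", "chef"],
    "description": ["likes hiking", "works remotely", "owns a food truck"],
})
y_test = np.array([0, 0, 1])  # 1 = anomaly; only needed to compute metrics

model = AnoLLM(
    "HuggingFaceTB/SmolLM-135M",       # base LLM checkpoint
    max_length_dict={},                # use the default per-column truncation length
    textual_columns=["description"],   # free-text columns are length-normalized during scoring
)
model.load_from_state_dict(MODEL_PATH)

# Anomaly score = negative log-likelihood of the serialized row,
# averaged over random column permutations.
scores = model.decision_function(df_test, n_permutations=16, batch_size=32, device="cuda")
mean_scores = scores.mean(axis=1)      # shape: (n_test,)
print("AUC-ROC:", roc_auc_score(y_test, mean_scores))
```

To compute the reported metrics across splits, run the bundled aggregation script: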
100 | ``` 101 | python src/get_results.py --dataset $dataset --n_splits $n_splits --setting semi_supervised 102 | ``` 103 | 104 | ## License 105 | 106 | This project is licensed under the Apache-2.0 License. 107 | 108 | ## Acknowledgement 109 | Baselines were adapted from https://github.com/vicliv/DTE. Part of the code was adapted from https://github.com/kathrinse/be_great. Thanks to all the authors for their great works! 110 | 111 | ## Reference 112 | 113 | ``` 114 | @inproceedings{tsai2025anollm, 115 | title={AnoLLM: Large Language Models for Tabular Anomaly Detection}, 116 | author={Tsai, Che-Ping and Teng, Ganyu and Wallis, Phil and Ding, Wei}, 117 | booktitle={The thirteenth International Conference on Learning Representations}, 118 | year={2025}, 119 | note={Accepted, to appear}, 120 | } 121 | ``` 122 | 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /THIRD-PARTY-LICENSES: -------------------------------------------------------------------------------- 1 | ** DTE repository -- https://github.com/vicliv/DTE/tree/main 2 | 3 | MIT License 4 | 5 | Copyright (c) 2024 Victor Livernoche 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | 25 | ** GReaT repository -- https://github.com/kathrinse/be_great/tree/main 26 | 27 | MIT License 28 | 29 | Copyright (c) 2022 Kathrin Seßler and Vadim Borisov 30 | 31 | Permission is hereby granted, free of charge, to any person obtaining a copy 32 | of this software and associated documentation files (the "Software"), to deal 33 | in the Software without restriction, including without limitation the rights 34 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 35 | copies of the Software, and to permit persons to whom the Software is 36 | furnished to do so, subject to the following conditions: 37 | 38 | The above copyright notice and this permission notice shall be included in all 39 | copies or substantial portions of the Software. 40 | 41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 47 | SOFTWARE. 
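Before the package sources below, a brief illustration of the text format that AnoLLM scores: each row is rendered as comma-separated `column is value` pairs in a random column order, and the fine-tuned LLM's negative log-likelihood of that string (averaged over several permutations) is the anomaly score. The sketch mirrors ``AnoLLMDataset.get_item_test`` in ``anollm/anollm_dataset.py``; the column names are made up for illustration.

```
# Illustration of AnoLLM's row-to-text serialization (not part of the package).
import random
import pandas as pd

row = pd.Series({"age": 25, "occupation": "teacher", "description": "likes hiking"})

for _ in range(2):  # two random column permutations of the same row
    cols = list(row.index)
    random.shuffle(cols)
    text = ",".join(" %s is %s " % (c, str(row[c]).strip()) for c in cols)
    print(text)
# e.g. " occupation is teacher , age is 25 , description is likes hiking "
```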
-------------------------------------------------------------------------------- /anollm/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | from .anollm import AnoLLM 3 | from .anollm_dataset import AnoLLMDataset -------------------------------------------------------------------------------- /anollm/anollm.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Original Copyright (c) 2022 Kathrin Seßler and Vadim Borisov. Licensed under the MIT License. 3 | Part of code is adapted from the GReaT repository (https://github.com/kathrinse/be_great/tree/main) 4 | Modifications Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 5 | ''' 6 | import os 7 | import warnings 8 | 9 | import logging 10 | import numpy as np 11 | import pandas as pd 12 | import torch 13 | from tqdm import tqdm 14 | from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, AutoConfig 15 | from torch.nn import CrossEntropyLoss 16 | import typing as tp 17 | from transformers import Trainer 18 | from collections import OrderedDict 19 | from pathlib import Path 20 | 21 | from anollm.anollm_trainer import AnoLLMTrainer 22 | from anollm.anollm_utils import _array_to_dataframe 23 | from anollm.anollm_dataset import AnoLLMDataset, AnoLLMDataCollator 24 | 25 | from safetensors.torch import save_model, load_model 26 | 27 | class AnoLLM: 28 | """AnoLLM Class 29 | 30 | The AnoLLM class handles the whole anomaly detection flow. It is used to fine-tune a large language model on tabular data 31 | and to assign anomaly scores to tabular samples. 32 | 33 | Attributes: 34 | llm (str): HuggingFace checkpoint of a pretrained large language model, used as the basis of our model 35 | tokenizer (AutoTokenizer): Tokenizer, automatically downloaded from llm-checkpoint 36 | model (AutoModelForCausalLM): Large language model, automatically downloaded from llm-checkpoint 37 | experiment_dir (str): Directory where the training checkpoints will be saved 38 | batch_size (int): Batch size used for fine-tuning 39 | train_hyperparameters (dict): Additional hyperparameters added to the TrainingArguments used by the 40 | HuggingFace library, see here the full list of all possible values 41 | https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments 42 | textual_columns (list): List of free-text columns whose token-level scores are length-normalized 43 | no_random_permutation (bool): If True, the column order is kept fixed instead of randomly permuted 44 | """ 45 | 46 | def __init__( 47 | self, 48 | llm: str, 49 | experiment_dir: str = "models", 50 | batch_size: int = 8, 51 | efficient_finetuning: str = "", 52 | max_length_dict: tp.Optional[tp.Dict[str, int]] = None, 53 | textual_columns: tp.List[str] = [], # columns whose scores are length-normalized, e.g. free-text columns 54 | random_init: bool = False, # if True, the model will be initialized with random weights. 55 | no_random_permutation: bool = False, # if True, columns will not be permuted randomly 56 | **train_kwargs, 57 | ): 58 | """ 59 | 60 | Args: 61 | llm: HuggingFace checkpoint of a pretrained large language model, used as the basis of our model 62 | experiment_dir: Directory where the training checkpoints will be saved 63 | batch_size: Batch size used for fine-tuning 64 | efficient_finetuning: if efficient_finetuning is 'lora', the model will be fine-tuned with LoRA 65 | max_length_dict: Dictionary that contains the maximum length of each textual feature.
66 | train_kwargs: Additional hyperparameters added to the TrainingArguments used by the HuggingFaceLibrary, 67 | see here the full list of all possible values 68 | https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments 69 | """ 70 | # Load Model and Tokenizer from HuggingFace 71 | self.efficient_finetuning = efficient_finetuning 72 | self.llm = llm 73 | self.tokenizer = AutoTokenizer.from_pretrained(self.llm) 74 | self.tokenizer.pad_token = self.tokenizer.eos_token 75 | if not random_init: 76 | self.model = AutoModelForCausalLM.from_pretrained(self.llm, torch_dtype=torch.bfloat16) 77 | else: 78 | config = AutoConfig.from_pretrained(self.llm) 79 | self.model = AutoModelForCausalLM.from_config(config) 80 | 81 | if self.efficient_finetuning == "lora": 82 | # Lazy importing 83 | try: 84 | from peft import ( 85 | LoraConfig, 86 | get_peft_model, 87 | ) 88 | except ImportError: 89 | raise ImportError( 90 | "This function requires the 'peft' package. Please install it with - pip install peft" 91 | ) 92 | 93 | # Define LoRA Config 94 | lora_config = LoraConfig( 95 | r=8, 96 | lora_alpha=32, 97 | target_modules=[ 98 | "q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj" 99 | ], # this is specific for smolLM, to be adapted 100 | lora_dropout=0.1, 101 | bias="none", 102 | ) 103 | # add LoRA adaptor 104 | self.model = get_peft_model(self.model, lora_config) 105 | self.model.print_trainable_parameters() 106 | 107 | # Set the training hyperparameters 108 | self.experiment_dir = experiment_dir 109 | self.batch_size = batch_size 110 | self.max_length_dict = max_length_dict 111 | self.textual_columns = textual_columns 112 | self.no_random_permutation = no_random_permutation 113 | self.train_hyperparameters = train_kwargs 114 | 115 | def fit( 116 | self, 117 | data: tp.Union[pd.DataFrame, np.ndarray], 118 | column_names: tp.Optional[tp.List[str]] = None, 119 | resume_from_checkpoint: tp.Union[bool, str] = False, 120 | use_wandb: bool = False, 121 | data_val: tp.Union[pd.DataFrame, np.ndarray] = None, 122 | label_val: np.ndarray = None, 123 | eval_steps: int = 400, 124 | processed_data_dir: str = None 125 | ) -> Trainer: 126 | """Fine-tune AnoLLM using tabular data. 127 | 128 | Args: 129 | data: Pandas DataFrame that contains the tabular data 130 | column_names: If data is Numpy Array, the feature names have to be defined. 
If data is Pandas 131 | DataFrame, the value is ignored 132 | 133 | Returns: 134 | AnoLLM Trainer used for the fine-tuning process 135 | """ 136 | df = _array_to_dataframe(data, columns=column_names) 137 | 138 | # Convert DataFrame into HuggingFace dataset object 139 | logging.info("Convert data into HuggingFace dataset object...") 140 | dataset = AnoLLMDataset.from_pandas(df, preserve_index=False) 141 | dataset.set_tokenizer(self.tokenizer) 142 | dataset.set_textual_columns(self.textual_columns) 143 | if self.no_random_permutation: 144 | dataset.fix_column_order() 145 | 146 | processed_data_path = Path(processed_data_dir) / "train_data.pkl" if processed_data_dir is not None else None 147 | dataset.prepare(is_eval = False, max_length_dict=self.max_length_dict, 148 | data_path=processed_data_path) 149 | print("Data 0:", self.tokenizer.decode(dataset[0]['input_ids'] )) 150 | # Set training hyperparameters 151 | logging.info("Create AnoLLM Trainer...") 152 | trainer_args = {} 153 | 154 | if data_val is not None: 155 | df_val = _array_to_dataframe(data_val, columns=column_names) 156 | dataset_val = AnoLLMDataset.from_pandas(df_val, preserve_index=False) 157 | dataset_val.set_tokenizer(self.tokenizer) 158 | dataset_val.set_anomaly_label(label_val) 159 | dataset_val.set_textual_columns(self.textual_columns) 160 | if self.no_random_permutation: 161 | dataset_val.fix_column_order() 162 | 163 | processed_data_path = Path(processed_data_dir) / "val_data.pkl" if processed_data_dir is not None else None 164 | dataset_val.prepare(is_eval = True, max_length_dict=self.max_length_dict, 165 | data_path = processed_data_path) 166 | 167 | self.train_hyperparameters["eval_strategy"] = "steps" 168 | self.train_hyperparameters["eval_steps"] = eval_steps 169 | trainer_args["eval_dataset"] = dataset_val 170 | 171 | if use_wandb: 172 | self.train_hyperparameters["report_to"] = ["wandb"] 173 | self.train_hyperparameters["logging_strategy"] = "steps" 174 | self.train_hyperparameters["logging_dir"] = "./logs" 175 | self.train_hyperparameters["logging_steps"] = 50 176 | self.train_hyperparameters["log_level"] = 'info' 177 | 178 | training_args = TrainingArguments( 179 | self.experiment_dir, 180 | per_device_train_batch_size=self.batch_size, 181 | per_device_eval_batch_size=self.batch_size * 2, 182 | save_strategy = 'no', 183 | max_grad_norm = 0.7, 184 | **self.train_hyperparameters, 185 | ) 186 | 187 | #optimizer = bnb.optim.PagedAdamW32bit(self.model.parameters(), betas=(0.9, 0.95), eps=1e-5) 188 | trainer = AnoLLMTrainer( 189 | self.model, 190 | training_args, 191 | train_dataset=dataset, 192 | tokenizer=self.tokenizer, 193 | data_collator=AnoLLMDataCollator(self.tokenizer), 194 | **trainer_args, 195 | ) 196 | 197 | if data_val is not None: 198 | trainer.set_eval_setting(n_permutations=1) 199 | 200 | # Start training 201 | logging.info("Start training...") 202 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 203 | 204 | return trainer 205 | 206 | def decision_function( 207 | self, 208 | df_test: pd.DataFrame, 209 | n_permutations: int = 16, 210 | batch_size: int = 32, 211 | device: str = "cuda", 212 | feature_wise: bool = False, 213 | ) -> np.ndarray: 214 | ''' Obtain anomaly scores for each sample in the test data 215 | df_test: pandas dataframe of test data 216 | n_permutations: number of permutations to calculate the anomaly score 217 | batch_size: batch size for prediction 218 | device: device to run the model 219 | feature_wise: get anomaly scores for each features. 
If True, returns anomaly scores for each feature in the test data. Size: (n_test, n_features, n_permutation) 220 | # Returns: 221 | # np.ndarray: Anomaly scores for each sample in the test data. Size: (n_test, n_permutation) or (n_test, n_features, n_permutation) if feature_wise is True 222 | ''' 223 | # Convert DataFrame into HuggingFace dataset object 224 | logging.info("Convert data into HuggingFace dataset object...") 225 | dataset = AnoLLMDataset.from_pandas(df_test, preserve_index=False) 226 | dataset.set_tokenizer(self.tokenizer) 227 | dataset.set_textual_columns(self.textual_columns) 228 | 229 | if self.no_random_permutation: 230 | dataset.fix_column_order() 231 | 232 | dataset.prepare(is_eval = True, max_length_dict=self.max_length_dict) 233 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle = False, 234 | collate_fn = AnoLLMDataCollator(self.tokenizer)) 235 | 236 | self.model.to(device) 237 | comma_id = self.tokenizer.convert_tokens_to_ids(',') 238 | n_col = len(df_test.columns) 239 | column_names = dataset.get_column_names() 240 | if feature_wise: 241 | anomaly_scores = np.zeros((len(df_test), n_col, n_permutations)) 242 | else: 243 | anomaly_scores = np.zeros((len(df_test), n_permutations)) 244 | 245 | loss_fct = CrossEntropyLoss(reduction="none") 246 | 247 | 248 | for perm_idx in tqdm(range(n_permutations)): 249 | start_idx = 0 250 | dataset.shuffle_column_order() 251 | for data in dataloader: 252 | encoded_batch = data["input_ids"].to(device) 253 | attn_mask = data["attention_mask"].to(device) 254 | end_idx = start_idx + len(encoded_batch) 255 | labels = encoded_batch 256 | 257 | start_pos_batch = data["feature_value_start"] 258 | end_pos_batch = data["feature_value_end"] 259 | col_indices_batch = data["col_indices"] 260 | 261 | with torch.no_grad(): 262 | out_logits = self.model(encoded_batch, attention_mask=attn_mask).logits 263 | 264 | shift_logits = out_logits[..., :-1, :].contiguous() 265 | shift_labels = labels[..., 1:].contiguous() 266 | shift_attention_mask_batch = attn_mask[..., 1:].contiguous() 267 | 268 | if feature_wise: 269 | score_batch = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).cpu().to(torch.float32).numpy() # batch * (ori_seq_len -1) 270 | 271 | for i in range(len(encoded_batch)): 272 | for j in range(n_col): 273 | start_pos = start_pos_batch[i][j] 274 | end_pos = end_pos_batch[i][j] 275 | col_idx = col_indices_batch[i][j] 276 | anomaly_scores[start_idx+i, col_idx, perm_idx] = score_batch[i, start_pos:end_pos].sum() 277 | elif len(self.textual_columns) > 0: 278 | score_batch = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).cpu().to(torch.float32).numpy() # batch * (ori_seq_len -1) 279 | for i in range(len(encoded_batch)): 280 | score_single = 0 281 | for j in range(n_col): 282 | start_pos = start_pos_batch[i][j] 283 | end_pos = end_pos_batch[i][j] 284 | col_idx = col_indices_batch[i][j] 285 | if column_names[col_idx] in self.textual_columns: 286 | score_single += score_batch[i, start_pos:end_pos].sum() / (end_pos - start_pos) 287 | else: 288 | score_single += score_batch[i, start_pos:end_pos].sum() 289 | anomaly_scores[start_idx+i, perm_idx] = score_single 290 | else: 291 | score_batch = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).to(torch.float32).sum(1) # remove normalization 292 | anomaly_scores[start_idx:end_idx, perm_idx] = score_batch.cpu().numpy() 293 | start_idx = end_idx 294 | 295 | return 
anomaly_scores #(len(df_test), n_permutations) 296 | 297 | def save_state_dict(self, path: str): 298 | """Save AnoLLM Model 299 | 300 | Saves the model weights and a configuration file in the given directory. 301 | Warning: Only works in DDP setting! 302 | 303 | Args: 304 | path: Path where to save the model 305 | """ 306 | directory = os.path.dirname(path) 307 | # Make directory 308 | if os.path.isdir(directory): 309 | warnings.warn(f"Directory {path} already exists and is overwritten now.") 310 | else: 311 | os.mkdir(directory) 312 | 313 | model_to_save = self.model.module 314 | save_model(model_to_save, path) 315 | 316 | def load_from_state_dict(self, path: str): 317 | """Load AnoLLM model from state_dict 318 | 319 | Args: 320 | path: path where AnoLLM model is saved 321 | """ 322 | load_model(self.model, path) 323 | 324 | -------------------------------------------------------------------------------- /anollm/anollm_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import typing as tp 3 | import os 4 | 5 | from datasets import Dataset 6 | from dataclasses import dataclass 7 | from transformers import DataCollatorWithPadding 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | import pickle as pkl 11 | MAX_COL_LENGTH = 128 12 | 13 | class AnoLLMDataset(Dataset): 14 | """AnoLLM Dataset 15 | 16 | The AnoLLM overwrites the _getitem function of the HuggingFace Dataset Class to include the permutation step. 17 | 18 | Attributes: 19 | tokenizer (AutoTokenizer): Tokenizer from HuggingFace 20 | """ 21 | 22 | def set_tokenizer(self, tokenizer): 23 | """Set the Tokenizer 24 | 25 | Args: 26 | tokenizer: Tokenizer from HuggingFace 27 | """ 28 | self.tokenizer = tokenizer 29 | 30 | def set_anomaly_label(self, labels): 31 | assert len(labels) == len(self._data) 32 | self.anomaly_labels = labels 33 | 34 | def set_textual_columns(self, columns: tp.List[str]): 35 | col_list = self.get_column_names() 36 | for col in columns: 37 | if col not in col_list: 38 | raise ValueError("Column {} not in the dataset.".format(col)) 39 | self.textual_columns = columns 40 | 41 | def get_n_columns(self): 42 | row = self._data.fast_slice(0, 1) 43 | return row.num_columns 44 | 45 | def get_column_names(self): 46 | row = self._data.fast_slice(0, 1) 47 | return row.column_names 48 | 49 | def shuffle_column_order(self): 50 | # used in evalutaion. the order of the columns is shuffled and then fixed for all data 51 | row = self._data.fast_slice(0, 1) 52 | self.shuffle_idx = list(range(row.num_columns)) 53 | random.shuffle(self.shuffle_idx) 54 | 55 | def fix_column_order(self): 56 | # set the column order to be default column order. Do not shuffle the columns. 57 | row = self._data.fast_slice(0, 1) 58 | self.shuffle_idx = list(range(row.num_columns)) 59 | 60 | def prepare( 61 | self, 62 | is_eval: bool = True, 63 | max_length_dict: tp.Optional[tp.Dict[str, int]] = {}, 64 | data_path = None, 65 | ): 66 | ''' 67 | Preprocess the data by tokenizing each column and truncating the columns to max_length 68 | Inputs: 69 | max_length_dict specifies the maximum length of each column. 
If None, all columns are truncated to max length 70 | pad_columns specifies whether to pad the columns to the same length according to max_length of a each column 71 | ''' 72 | self.is_eval = is_eval 73 | n_col = self.get_n_columns() 74 | column_names = self.get_column_names() 75 | self.processed_data = [] 76 | self.tokenized_feature_names = [] 77 | bos_token_id = self.tokenizer.bos_token_id 78 | 79 | for col_idx in range(n_col): 80 | feature_names = ' ' + column_names[col_idx] + ' ' 81 | tokenized_feature_names = self.tokenizer(feature_names) 82 | tokenized_is = self.tokenizer('is ') 83 | if bos_token_id and tokenized_feature_names['input_ids'][0] == bos_token_id: 84 | tokenized_feature_names['input_ids'] = tokenized_feature_names['input_ids'][1:] 85 | tokenized_is['input_ids'] = tokenized_is['input_ids'][1:] 86 | 87 | self.tokenized_feature_names.append(tokenized_feature_names["input_ids"] + tokenized_is["input_ids"]) 88 | 89 | if data_path is not None and os.path.exists(data_path): 90 | self.processed_data = pkl.load(open(data_path, 'rb')) 91 | else: 92 | for key in tqdm(range(len(self._data))): 93 | row = self._data.fast_slice(key, 1) 94 | tokenized_texts = [] 95 | for col_idx in range(n_col): 96 | feature_values = str(row.columns[col_idx].to_pylist()[0]).strip() 97 | if len(feature_values) == 0: 98 | feature_values = "None" 99 | data = self.tokenizer(feature_values) 100 | if bos_token_id and data['input_ids'][0] == bos_token_id: 101 | data['input_ids'] = data['input_ids'][1:] 102 | 103 | tokenized_texts.append(data["input_ids"]) 104 | if len(data["input_ids"]) == 0: 105 | print("Warning: tokenized text is empty.", column_names[col_idx],len( feature_values),feature_values) 106 | self.processed_data.append(tokenized_texts) 107 | 108 | # truncate the columns that are too long 109 | for col_idx in range(n_col): 110 | name = column_names[col_idx] 111 | if name not in max_length_dict: 112 | max_length = MAX_COL_LENGTH 113 | else: 114 | max_length = max_length_dict[name] 115 | assert isinstance(max_length, int) 116 | 117 | for data_idx in range(len(self.processed_data)): 118 | length = len(self.processed_data[data_idx][col_idx]) + len(self.tokenized_feature_names[col_idx]) 119 | if length >= max_length: 120 | self.processed_data[data_idx][col_idx] = self.processed_data[data_idx][col_idx][:max_length - len(self.tokenized_feature_names[col_idx])] 121 | if data_path is not None: 122 | pkl.dump(self.processed_data, open(data_path, 'wb')) 123 | print("Preprocessing done.") 124 | 125 | def _getitem( 126 | self, 127 | key: tp.Union[int, slice, str], 128 | decoded: bool = True, 129 | **kwargs 130 | ) -> tp.Union[tp.Dict, tp.List]: 131 | """ 132 | Get one instance of the tabular data, permuted, converted to text and tokenized. 
133 | """ 134 | row = self._data.fast_slice(key, 1) 135 | 136 | 137 | # get shuffle_idx 138 | if "shuffle_idx" in self.__dict__: 139 | shuffle_idx = self.shuffle_idx 140 | else: 141 | shuffle_idx = list(range(row.num_columns)) 142 | random.shuffle(shuffle_idx) 143 | 144 | # get tokenized text 145 | comma_id = self.tokenizer.convert_tokens_to_ids(',') 146 | eos_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.eos_token) 147 | bos_token_id = self.tokenizer.bos_token_id 148 | if self.is_eval: 149 | tokenized_text = {"input_ids": [], "attention_mask": [], "feature_value_start":[], 150 | "feature_value_end":[],'col_indices':shuffle_idx} 151 | else: 152 | tokenized_text = {"input_ids": [], "attention_mask": []} 153 | if bos_token_id: 154 | tokenized_text["input_ids"] = [bos_token_id] 155 | 156 | if hasattr(self, "processed_data"): 157 | start_idx = 0 158 | for idx, col_idx in enumerate(shuffle_idx): 159 | tokenized_feature_names = self.tokenized_feature_names[col_idx] 160 | tokenized_feature_values = self.processed_data[key][col_idx] 161 | tokenized_col = tokenized_feature_names + tokenized_feature_values 162 | if idx == len(shuffle_idx) - 1: 163 | tokenized_text["input_ids"] += tokenized_col + [eos_id] 164 | else: 165 | tokenized_text["input_ids"] += tokenized_col + [comma_id] 166 | if self.is_eval: 167 | tokenized_text["feature_value_start"].append(start_idx + len(tokenized_feature_names) -1 ) 168 | tokenized_text["feature_value_end"].append(start_idx + len(tokenized_col) ) 169 | start_idx += len(tokenized_col) + 1 170 | else: 171 | raise ValueError("processed_data is not found. Please run prepare function first.") 172 | tokenized_text["attention_mask"] += [1] * len(tokenized_text["input_ids"]) 173 | return tokenized_text 174 | 175 | def get_item_test(self, key): 176 | row = self._data.fast_slice(key, 1) 177 | shuffle_idx = list(range(row.num_columns)) 178 | random.shuffle(shuffle_idx) 179 | 180 | shuffled_text = ",".join( 181 | [ 182 | " %s is %s " 183 | % (row.column_names[i], str(row.columns[i].to_pylist()[0]).strip() ) 184 | for i in shuffle_idx 185 | ] 186 | ) 187 | tokenized_text = self.tokenizer(shuffled_text, padding=True) 188 | 189 | return shuffled_text, tokenized_text 190 | 191 | def __getitems__(self, keys: tp.Union[int, slice, str, list]): 192 | if isinstance(keys, list): 193 | return [self._getitem(key) for key in keys] 194 | else: 195 | return self._getitem(keys) 196 | 197 | #def add_gaussian_noise(self, value): 198 | # return value + np.random.normal(0, 0.1) 199 | 200 | @dataclass 201 | class AnoLLMDataCollator(DataCollatorWithPadding): 202 | """ 203 | 204 | Overwrites the DataCollatorWithPadding to also pad the labels and not only the input_ids 205 | """ 206 | 207 | def __call__(self, features: tp.List[tp.Dict[str, tp.Any]]): 208 | batch = self.tokenizer.pad( 209 | features, 210 | padding=self.padding, 211 | max_length=self.max_length, 212 | pad_to_multiple_of=self.pad_to_multiple_of, 213 | return_tensors=self.return_tensors, 214 | ) 215 | batch["labels"] = batch["input_ids"].clone() 216 | return batch 217 | 218 | class AnoLLMDataLoader(DataLoader): 219 | ''' 220 | Add set_epoch function so that huggingface trainer can call it 221 | ''' 222 | def set_epoch(self, epoch): 223 | if hasattr(self.sampler, "set_epoch"): 224 | self.sampler.set_epoch(epoch) 225 | print("Set epoch", epoch) 226 | -------------------------------------------------------------------------------- /anollm/anollm_trainer.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | Original Copyright (c) 2022 Kathrin Seßler and Vadim Borisov. Licensed under the MIT License. 3 | Part of code is adapted from the GReaT repository (https://github.com/kathrinse/be_great/tree/main) 4 | Modifications Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 5 | ''' 6 | 7 | import os 8 | import random 9 | import numpy as np 10 | import torch 11 | import typing as tp 12 | from torch.utils.data import DataLoader 13 | from transformers import Trainer 14 | from sklearn import metrics 15 | from anollm.anollm_dataset import AnoLLMDataCollator, AnoLLMDataLoader 16 | from torch.nn import CrossEntropyLoss 17 | import torch.distributed as dist 18 | from torch.utils.data.distributed import DistributedSampler 19 | 20 | 21 | class AnoLLMTrainer(Trainer): 22 | """ 23 | Overwrites the get_train_dataloader methode of the HuggingFace Trainer to not remove the "unused" columns - 24 | they are needed later! 25 | """ 26 | 27 | def get_train_dataloader(self) -> DataLoader: 28 | if self.train_dataset is None: 29 | raise ValueError("Trainer: training requires a train_dataset.") 30 | 31 | data_collator = self.data_collator 32 | train_dataset = ( 33 | self.train_dataset 34 | ) # self._remove_unused_columns(self.train_dataset, description="training") 35 | local_rank = int(os.environ["LOCAL_RANK"]) 36 | world_size = dist.get_world_size() 37 | train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=local_rank, shuffle=False, drop_last=True) 38 | 39 | return DataLoader( 40 | train_dataset, 41 | batch_size=self._train_batch_size, 42 | sampler=train_sampler, 43 | collate_fn=data_collator, 44 | drop_last=self.args.dataloader_drop_last, 45 | num_workers=self.args.dataloader_num_workers, 46 | pin_memory=self.args.dataloader_pin_memory, 47 | worker_init_fn=_seed_worker, 48 | ) 49 | 50 | # 2025-02-12: Amazon addition. 
51 | def set_eval_setting(self, n_permutations): 52 | self.n_permutations = n_permutations 53 | 54 | def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"): 55 | eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset 56 | # do not use distributed sampler 57 | dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=self.args.eval_batch_size, shuffle = False, 58 | collate_fn = AnoLLMDataCollator(self.tokenizer)) 59 | 60 | 61 | perplexities = np.zeros((len(eval_dataset), self.n_permutations)) 62 | eval_losses = np.zeros((len(eval_dataset), self.n_permutations)) 63 | 64 | loss_fct = CrossEntropyLoss(reduction="none") 65 | 66 | # for conditional columns 67 | comma_id = eval_dataset.tokenizer.convert_tokens_to_ids(',') 68 | n_col = eval_dataset.get_n_columns() 69 | column_names = eval_dataset.get_column_names() 70 | 71 | for perm_idx in range(self.n_permutations): 72 | start_idx = 0 73 | eval_dataset.shuffle_column_order() 74 | for data in dataloader: 75 | encoded_batch = data["input_ids"].to(self.model.device) 76 | attn_mask = data["attention_mask"].to(self.model.device) 77 | end_idx = start_idx + len(encoded_batch) 78 | labels = encoded_batch 79 | 80 | start_pos_batch = data["feature_value_start"] 81 | end_pos_batch = data["feature_value_end"] 82 | col_indices_batch = data["col_indices"] 83 | 84 | with torch.no_grad(): 85 | out_logits = self.model(encoded_batch, attention_mask=attn_mask).logits 86 | 87 | shift_logits = out_logits[..., :-1, :].contiguous() 88 | shift_labels = labels[..., 1:].contiguous() 89 | shift_attention_mask_batch = attn_mask[..., 1:].contiguous() 90 | eval_loss_batch = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1) / shift_attention_mask_batch.sum(1) 91 | 92 | if len(eval_dataset.textual_columns) > 0: 93 | perplexity_batch = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).cpu().numpy() # batch * (ori_seq_len -1) 94 | for i in range(len(encoded_batch)): 95 | perplexity_single = 0 96 | for j in range(n_col): 97 | start_pos = start_pos_batch[i][j] 98 | end_pos = end_pos_batch[i][j] 99 | col_idx = col_indices_batch[i][j] 100 | if column_names[col_idx] in eval_dataset.textual_columns: 101 | perplexity_single += perplexity_batch[i, start_pos:end_pos].sum() / (end_pos - start_pos) 102 | else: 103 | perplexity_single += perplexity_batch[i, start_pos:end_pos].sum() 104 | if np.isnan(perplexity_single): 105 | print(start_pos, end_pos, perplexity_batch[i, start_pos:end_pos].sum()) 106 | print(perplexity_batch[i, start_pos:end_pos].sum() / (end_pos - start_pos)) 107 | print(perplexity_single) 108 | perplexities[start_idx+i, perm_idx] = perplexity_single 109 | else: 110 | perplexity_batch = (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1) 111 | perplexities[start_idx:end_idx, perm_idx] = perplexity_batch.cpu().numpy() 112 | 113 | eval_losses[start_idx:end_idx, perm_idx] = eval_loss_batch.cpu().numpy() 114 | start_idx = end_idx 115 | 116 | local_rank = int(os.environ["LOCAL_RANK"]) 117 | world_size = dist.get_world_size() 118 | 119 | all_perplexity = [None for _ in range(world_size)] 120 | dist.all_gather_object(all_perplexity, perplexities) 121 | perplexities = np.concatenate(all_perplexity, axis = 1) 122 | 123 | all_eval_loss = [None for _ in range(world_size)] 124 | dist.all_gather_object(all_eval_loss, eval_losses) 125 | eval_losses = np.concatenate(all_eval_loss, axis = 1) 126 | 127 | labels = 
eval_dataset.anomaly_labels 128 | 129 | mean_perplexity = np.mean(perplexities) 130 | normal_indices = np.where(labels == 0)[0] 131 | anomaly_indices = np.where(labels == 1)[0] 132 | perplexity_normal = np.mean(perplexities[normal_indices]) 133 | eval_loss_normal = np.mean(eval_losses[normal_indices]) 134 | perplexity_anomaly = np.mean(perplexities[anomaly_indices]) 135 | eval_loss_anomaly = np.mean(eval_losses[anomaly_indices]) 136 | 137 | #print("is nan:", np.isnan(eval_dataset.anomaly_labels).sum(), np.isnan(perplexities).sum()) 138 | auc_roc = metrics.roc_auc_score(eval_dataset.anomaly_labels, np.mean(perplexities, axis = 1)) 139 | 140 | metric = {"eval_loss": np.mean(eval_losses), "eval_perplexity": mean_perplexity, "eval_auc_roc": auc_roc, \ 141 | "eval_loss_normal": eval_loss_normal, "eval_perplexity_normal": perplexity_normal, 142 | "eval_loss_anomaly": eval_loss_anomaly, "eval_perplexity_anomaly": perplexity_anomaly} 143 | 144 | if local_rank == 0: 145 | self.log(metric) 146 | self._memory_tracker.stop_and_update_metrics(metric) 147 | 148 | return metric 149 | # End of Amazon addition. 150 | 151 | def _seed_worker(_): 152 | """ 153 | Helper function to set worker seed during Dataloader initialization. 154 | """ 155 | worker_seed = torch.initial_seed() % 2**32 156 | random.seed(worker_seed) 157 | np.random.seed(worker_seed) 158 | torch.manual_seed(worker_seed) 159 | torch.cuda.manual_seed_all(worker_seed) 160 | 161 | -------------------------------------------------------------------------------- /anollm/anollm_utils.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | import numpy as np 3 | import pandas as pd 4 | 5 | def _array_to_dataframe( 6 | data: tp.Union[pd.DataFrame, np.ndarray], columns=None 7 | ) -> pd.DataFrame: 8 | """Converts a Numpy Array to a Pandas DataFrame 9 | 10 | Args: 11 | data: Pandas DataFrame or Numpy NDArray 12 | columns: If data is a Numpy Array, columns needs to be a list of all column names 13 | 14 | Returns: 15 | Pandas DataFrame with the given data 16 | """ 17 | if isinstance(data, pd.DataFrame): 18 | return data 19 | 20 | assert isinstance( 21 | data, np.ndarray 22 | ), "Input needs to be a Pandas DataFrame or a Numpy NDArray" 23 | assert ( 24 | columns 25 | ), "To convert the data into a Pandas DataFrame, a list of column names has to be given!" 26 | assert len(columns) == len( 27 | data[0] 28 | ), "%d column names are given, but array has %d columns!" 
% ( 29 | len(columns), 30 | len(data[0]), 31 | ) 32 | 33 | return pd.DataFrame(data=data, columns=columns) 34 | 35 | -------------------------------------------------------------------------------- /evaluate_anollm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import argparse 4 | 5 | import numpy as np 6 | import torch 7 | import time 8 | 9 | import torch.distributed as dist 10 | 11 | from anollm import AnoLLM 12 | from src.data_utils import load_data, DATA_MAP, get_text_columns, get_max_length_dict 13 | from train_anollm import get_run_name 14 | 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--dataset", type = str, default='wine', choices = [d.lower() for d in DATA_MAP.keys()], 19 | help="Name of datasets in the ODDS benchmark") 20 | parser.add_argument("--exp_dir", type = str, default=None) 21 | parser.add_argument("--setting", type = str, default='semi_supervised', choices = ['semi_supervised', 'unsupervised'], help="semi_supervised:an uncontaminated, unsupervised setting; unsupervised:a contaminated, unsupervised setting") 22 | 23 | #dataset hyperparameters 24 | parser.add_argument("--data_dir", type = str, default='data') 25 | parser.add_argument("--n_splits", type = int, default=5) 26 | parser.add_argument("--split_idx", type = int, default=None) # 0 to n_split-1 27 | # binning 28 | parser.add_argument("--binning", type = str, choices=['quantile', 'equal_width', 'language', 'none', 'standard'], default='standard') 29 | parser.add_argument("--n_buckets", type = int, default=10) 30 | parser.add_argument("--remove_feature_name", action = 'store_true') 31 | 32 | # model hyperparameters (for getting the model name) 33 | parser.add_argument("--model", type = str, choices = ['gpt2', 'distilgpt2', 'smol', 'smol-360', 'smol-1.7b'], default='smol') 34 | parser.add_argument("--lora", action='store_true', default=False) 35 | parser.add_argument("--lr", type = float, default=5e-5) 36 | parser.add_argument("--random_init", action='store_true', default=False) 37 | parser.add_argument("--no_random_permutation", action='store_true', default=False) 38 | 39 | #testing 40 | parser.add_argument("--batch_size", type = int, default=128) # per gpu 41 | parser.add_argument("--n_permutations", type = int, default=100) # per gpu 42 | args = parser.parse_args() 43 | 44 | if args.model == 'smol': 45 | args.model = 'HuggingFaceTB/SmolLM-135M' 46 | elif args.model == 'smol-360': 47 | args.model = 'HuggingFaceTB/SmolLM-360M' 48 | elif args.model == 'smol-1.7b': 49 | args.model = 'HuggingFaceTB/SmolLM-1.7B' 50 | 51 | return args 52 | 53 | def main(): 54 | # Set CUDA devices for each process 55 | local_rank = int(os.environ["LOCAL_RANK"]) 56 | world_size = dist.get_world_size() 57 | torch.cuda.set_device(local_rank) 58 | 59 | args = get_args() 60 | 61 | if args.exp_dir is None: 62 | args.exp_dir = Path('exp') / args.dataset / args.setting / "split{}".format(args.n_splits) / "split{}".format(args.split_idx) 63 | 64 | if not os.path.exists(args.exp_dir): 65 | raise ValueError("Experiment directory {} does not exist".format(args.exp_dir)) 66 | 67 | score_dir = args.exp_dir / 'scores' 68 | run_name = get_run_name(args) 69 | 70 | score_path = score_dir / "{}.npy".format(run_name) 71 | print("score_path:", score_path) 72 | if dist.get_rank() == 0: 73 | os.makedirs(score_dir, exist_ok = True) 74 | 75 | remainder = args.n_permutations % world_size 76 | 77 | X_train, X_test, y_train, y_test = 
load_data(args) 78 | 79 | if not os.path.exists(score_path): 80 | model_dir = args.exp_dir / 'models' 81 | model_path = model_dir / '{}.pt'.format(run_name) 82 | 83 | efficient_finetuning = 'lora' if args.lora else '' 84 | max_length_dict = get_max_length_dict(args.dataset) 85 | text_columns = get_text_columns(args.dataset) 86 | model = AnoLLM(args.model, 87 | efficient_finetuning = efficient_finetuning, 88 | model_path = model_path, 89 | max_length_dict=max_length_dict, 90 | textual_columns = text_columns, 91 | no_random_permutation=args.no_random_permutation, 92 | bp16=True, 93 | ) 94 | print(text_columns, max_length_dict) 95 | 96 | model.load_from_state_dict(model_path) 97 | model.model.to(local_rank) 98 | 99 | # Move the model to the appropriate GPU 100 | # Wrap the model for distributed training 101 | model.model = torch.nn.parallel.DistributedDataParallel( 102 | model.model, device_ids=[local_rank], output_device=local_rank 103 | ) 104 | n_perm = int(args.n_permutations / world_size) 105 | n_perm = n_perm + 1 if local_rank < remainder else n_perm 106 | 107 | start_time = time.time() 108 | scores = model.decision_function(X_test, 109 | n_permutations = n_perm, 110 | batch_size = args.batch_size, 111 | device = "cuda", 112 | ) 113 | end_time = time.time() 114 | 115 | all_scores = [None for _ in range(world_size)] 116 | dist.all_gather_object(all_scores, scores) 117 | 118 | if dist.get_rank() == 0: 119 | 120 | print("Inference time:", end_time - start_time) 121 | 122 | run_time_dir = args.exp_dir / "run_time" / "test" 123 | os.makedirs(run_time_dir, exist_ok = True) 124 | run_time_path = run_time_dir / "{}.txt".format(run_name) 125 | with open(run_time_path, 'w') as f: 126 | f.write(str(end_time - start_time)) 127 | 128 | all_scores = np.concatenate(all_scores, axis = 1) 129 | mean_scores = np.mean(scores, axis = 1) 130 | np.save(score_path, mean_scores) 131 | raw_score_path = score_dir / "raw_{}.npy".format(run_name) 132 | np.save(raw_score_path, all_scores) 133 | 134 | dist.destroy_process_group() 135 | 136 | if __name__ == '__main__': 137 | dist.init_process_group(backend="nccl") 138 | main() -------------------------------------------------------------------------------- /evaluate_baselines.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import argparse 4 | import numpy as np 5 | import pandas as pd 6 | import time 7 | 8 | from deepod.models.tabular import DeepSVDD, REPEN, RDP, RCA, GOAD, NeuTraL, SLAD, DeepIsolationForest 9 | from src.baselines.icl import ICL 10 | from src.baselines.dte import DTECategorical 11 | from pyod.models.ecod import ECOD 12 | from pyod.models.pca import PCA 13 | from pyod.models.knn import KNN 14 | from pyod.models.iforest import IForest 15 | from src.data_utils import load_data, DATA_MAP, df_to_numpy, get_text_columns 16 | 17 | DEEPOD_METHODS= { 18 | 'ecod': ECOD(), 19 | 'knn': KNN(), 20 | 'iforest': IForest(), 21 | 'pca': PCA(), 22 | 'deepsvdd': DeepSVDD(), 23 | 'repen': REPEN(), 24 | 'rdp': RDP(), 25 | 'rca': RCA(), 26 | 'goad': GOAD(), 27 | 'neutural': NeuTraL(), 28 | 'icl': ICL(), 29 | 'dif': DeepIsolationForest(), 30 | 'slad': SLAD(), 31 | 'dte': DTECategorical(), 32 | } 33 | 34 | def get_args(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--dataset", type = str, default='wine', choices = [d.lower() for d in DATA_MAP.keys()], 37 | help="Name of datasets in the ODDS benchmark") 38 | parser.add_argument("--exp_dir", type = str, default=None) 39 | 
parser.add_argument("--setting", type = str, default='semi_supervised', choices = ['semi_supervised', 'unsupervised'], help="semi_supervised:an uncontaminated, unsupervised setting; unsupervised:a contaminated, unsupervised setting") 40 | parser.add_argument("--normalize", action='store_true', default=False) #normalize the numerical features to have zero mean and unit variance 41 | parser.add_argument("--text_encoding", type = str, default='word2vec', choices = ['tfidf', 'word2vec', 'bag_of_words']) 42 | #dataset hyperparameters 43 | parser.add_argument("--data_dir", type = str, default='data') 44 | parser.add_argument("--n_splits", type = int, default=5) 45 | parser.add_argument("--split_idx", type = int, default=None) # 0 to n_split-1 46 | parser.add_argument("--cat_encoding", type = str, default="one_hot", choices = ["one_hot", "ordinal"]) 47 | args = parser.parse_args() 48 | 49 | return args 50 | 51 | def get_run_name(args, method): 52 | if args.normalize: 53 | run_name = '{}_normalized'.format(method) 54 | else: 55 | run_name = method 56 | 57 | if args.cat_encoding == 'ordinal': 58 | run_name += "_ordinal" 59 | 60 | if args.dataset in ['fakejob', 'fakenews']: 61 | run_name += "_{}".format(args.text_encoding) 62 | run_name += "_test_run_time" 63 | 64 | return run_name 65 | 66 | def benchmark(args): 67 | X_train, X_test, y_train, y_test = load_data(args) 68 | 69 | if args.exp_dir is None: 70 | args.exp_dir = Path('exp') / args.dataset / args.setting / "split{}".format(args.n_splits) / "split{}".format(args.split_idx) 71 | 72 | if not os.path.exists(args.exp_dir): 73 | os.makedirs(args.exp_dir, exist_ok = True) 74 | #raise ValueError("Experiment directory {} does not exist".format(args.exp_dir)) 75 | 76 | score_dir = args.exp_dir / 'scores' 77 | os.makedirs(score_dir, exist_ok = True) 78 | 79 | train_run_time_dir = args.exp_dir / 'run_time' / 'train' 80 | os.makedirs(train_run_time_dir, exist_ok = True) 81 | 82 | test_run_time_dir = args.exp_dir / 'run_time' / 'test' 83 | os.makedirs(test_run_time_dir, exist_ok = True) 84 | 85 | # transform dataframe to numpy array 86 | n_train = X_train.shape[0] 87 | X = pd.concat([X_train, X_test], axis = 0) 88 | 89 | textual_columns = get_text_columns(args.dataset) 90 | 91 | X_np = df_to_numpy(X, args.dataset, method = args.cat_encoding, 92 | normalize_numbers=args.normalize, textual_encoding = args.text_encoding, 93 | textual_columns = textual_columns) 94 | X_train = X_np[:n_train] 95 | X_test = X_np[n_train:] 96 | 97 | # ordinal encoding for slad: 98 | X_ord = df_to_numpy(X, args.dataset, method = 'ordinal', 99 | normalize_numbers=args.normalize, textual_encoding = args.text_encoding, 100 | textual_columns = textual_columns) 101 | X_train_ord = X_ord[:n_train] 102 | X_test_ord = X_ord[n_train:] 103 | 104 | 105 | for name, clf in DEEPOD_METHODS.items(): 106 | score_path = score_dir / '{}.npy'.format(get_run_name(args, name)) 107 | train_time_path = train_run_time_dir / '{}.txt'.format(get_run_name(args, name)) 108 | test_time_path = test_run_time_dir / '{}.txt'.format(get_run_name(args, name)) 109 | 110 | if name == 'icl' and args.dataset in ['mulcross', 'covertype', 'http', 'smtp' ]: 111 | clf = ICL(epochs=80) 112 | if name == 'icl' and args.dataset in [ 'covertype', 'http', 'smtp' ]: 113 | clf = ICL(epochs=40) 114 | 115 | #if not os.path.exists(score_path): 116 | if True: 117 | try: 118 | print("Training {} on {} (Split {})...".format(name, args.dataset, args.split_idx)) 119 | if name == 'slad': 120 | start_time = time.time() 121 | 
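# SLAD is fit and scored on the ordinal-encoded view of the features (X_train_ord / X_test_ord); training and inference are timed separately and written under run_time/train and run_time/test.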
clf.fit(X_train_ord, y=None) 122 | end_time = time.time() 123 | train_time = end_time - start_time 124 | 125 | start_time = time.time() 126 | scores = clf.decision_function(X_test_ord) 127 | end_time = time.time() 128 | test_time = end_time - start_time 129 | elif name == "icl" or name == "dte": 130 | start_time = time.time() 131 | clf.fit(X_train_ord, y_train=None) 132 | end_time = time.time() 133 | train_time = end_time - start_time 134 | 135 | start_time = time.time() 136 | scores = clf.decision_function(X_test_ord) 137 | end_time = time.time() 138 | test_time = end_time - start_time 139 | else: 140 | start_time = time.time() 141 | clf.fit(X_train, y=None) 142 | end_time = time.time() 143 | train_time = end_time - start_time 144 | 145 | start_time = time.time() 146 | scores = clf.decision_function(X_test) 147 | end_time = time.time() 148 | test_time = end_time - start_time 149 | if name == 'pca' and np.isinf(scores).any(): 150 | print("Inf in training {}".format(name)) 151 | clf = PCA(n_components = 'mle') 152 | 153 | start_time = time.time() 154 | clf.fit(X_train, y=None) 155 | end_time = time.time() 156 | train_time = end_time - start_time 157 | 158 | start_time = time.time() 159 | scores = clf.decision_function(X_test) 160 | end_time = time.time() 161 | test_time = end_time - start_time 162 | np.save(score_path, scores) 163 | 164 | with open(train_time_path, 'w') as f: 165 | f.write(str(train_time)) 166 | 167 | with open(test_time_path, 'w') as f: 168 | f.write(str(test_time)) 169 | 170 | except: 171 | print("Error in training {}".format(name)) 172 | continue 173 | 174 | def main(): 175 | args = get_args() 176 | 177 | if args.split_idx is None: 178 | for i in range(args.n_splits): 179 | args.split_idx = i 180 | args.exp_dir = None 181 | benchmark(args) 182 | else: 183 | benchmark(args) 184 | 185 | if __name__ == '__main__': 186 | main() -------------------------------------------------------------------------------- /figs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/AnoLLM-large-language-models-for-tabular-anomaly-detection/a051ba450743ea5a57175be305212464fb7bdc16/figs/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | adbench==0.1.11 2 | deepod==0.4.1 3 | feature_engine==1.8.3 4 | gensim==4.3.3 5 | numpy==1.26.4 6 | pandas==2.2.2 7 | scikit_learn==1.6.1 8 | scipy==1.13.1 9 | tqdm==4.66.4 10 | ucimlrepo==0.0.7 11 | peft==0.11.1 12 | datasets==2.20.0 13 | wandb==0.17.4 14 | tf-keras==2.16.0 15 | transformers==4.48.2 -------------------------------------------------------------------------------- /scripts/exp1-mixed_benchmark/run_anollm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | n_splits=5 3 | setting=semi_supervised 4 | 5 | TRAIN_GPUS="0,1,2,3" 6 | INFERENCE_GPUS="1,2,3" 7 | n_train_node=4 8 | n_test_node=3 9 | n_permutations=21 10 | 11 | for model in 'smol' 'smol-360'; do 12 | batch_size=32 13 | eval_batch_size=$((batch_size*2)) 14 | for dataset in 'vifd' 'fraudecom' 'lymphography'; do 15 | expdir=exp/$dataset/$setting/split$n_splits 16 | wandb online 17 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 18 | --batch_size $batch_size --model $model 
--binning standard --wandb 19 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 20 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 21 | wandb offline 22 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 23 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 24 | --batch_size $batch_size --model $model --binning standard 25 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 26 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 27 | done 28 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 29 | 30 | done 31 | 32 | batch_size=16 33 | eval_batch_size=$((batch_size*2)) 34 | for dataset in 'seismic' ; do 35 | expdir=exp/$dataset/$setting/split$n_splits 36 | wandb online 37 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 38 | --batch_size $batch_size --model $model --binning standard --wandb 39 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 40 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 41 | wandb offline 42 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 43 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 44 | --batch_size $batch_size --model $model --binning standard 45 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 46 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 47 | done 48 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 49 | done 50 | 51 | batch_size=16 52 | eval_batch_size=$((batch_size*2)) 53 | for ((idx = 0 ; idx < 6 ; idx++ )); do 54 | dataset=20news-$idx 55 | expdir=exp/$dataset/$setting/split$n_splits 56 | wandb offline 57 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 58 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 59 | --batch_size $batch_size --model $model --binning standard --lr 0.0005 60 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 61 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lr 0.0005 62 | done 63 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 64 | done 65 | 66 | batch_size=4 67 
| eval_batch_size=$((batch_size*2)) 68 | for dataset in 'fakejob'; do 69 | expdir=exp/$dataset/$setting/split$n_splits 70 | wandb online 71 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 20000 \ 72 | --batch_size $batch_size --model $model --binning standard --wandb 73 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 74 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 75 | wandb offline 76 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 77 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 20000\ 78 | --batch_size $batch_size --model $model --binning standard 79 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 80 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 81 | done 82 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 83 | done 84 | 85 | done 86 | -------------------------------------------------------------------------------- /scripts/exp1-mixed_benchmark/run_baselines.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | n_splits=5 3 | setting='semi_supervised' 4 | for dataset in 'vifd' 'fraudecom' 'lymphography' 'seismic' 'fakejob'; do 5 | expdir=exp/$dataset/$setting/split$n_splits 6 | for ((split_idx = 0 ; split_idx < $n_splits ; split_idx++ )); do 7 | CUDA_VISIBLE_DEVICES=0 python evaluate_baselines.py --dataset $dataset --n_splits $n_splits --normalize --setting $setting --split_idx $split_idx & 8 | done 9 | wait 10 | expdir=exp/$dataset/$setting/split$n_splits 11 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 12 | done 13 | 14 | for ((idx = 0 ; idx < 5 ; idx++ )); do 15 | dataset=20news-$idx 16 | expdir=exp/$dataset/$setting/split$n_splits 17 | for ((split_idx = 0 ; split_idx < $n_splits ; split_idx++ )); do 18 | CUDA_VISIBLE_DEVICES=$split_idx python evaluate_baselines.py --dataset $dataset --n_splits $n_splits --normalize --setting $setting --split_idx $split_idx & 19 | done 20 | wait 21 | expdir=exp/$dataset/$setting/split$n_splits 22 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 23 | done -------------------------------------------------------------------------------- /scripts/exp2-odds/run_anollm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | n_splits=5 3 | setting=semi_supervised 4 | TRAIN_GPUS="0,1,2,3" 5 | INFERENCE_GPUS="1,2,3" 6 | n_train_node=4 7 | n_test_node=3 8 | n_permutations=21 9 | 10 | for model in 'smol' 'smol-360'; do 11 | batch_size=32 12 | eval_batch_size=$((batch_size*2)) 13 | for dataset in 'wine' 'breastw' 'cardio' 'ecoli' 'lymphography' 'vertebral' 'wbc' 'yeast' 'heart' 'glass' 'ionosphere' \ 14 | 'letter_recognition' 'mammography' 'pendigits' 'pima' 'satellite' 'satimage-2' 'thyroid' 'vowels' 'shuttle'; do 15 | 
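# For each dataset: train and score split 0 with wandb logging enabled, repeat for the remaining splits with wandb offline, then aggregate the per-split results with src/get_results.py.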
expdir=exp/$dataset/$setting/split$n_splits 16 | wandb online 17 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 18 | --batch_size $batch_size --model $model --binning standard --wandb 19 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 20 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 21 | wandb offline 22 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 23 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 24 | --batch_size $batch_size --model $model --binning standard 25 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 26 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 27 | done 28 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 29 | done 30 | 31 | batch_size=16 32 | eval_batch_size=$((batch_size*2)) 33 | for dataset in 'seismic' 'optdigits' 'annthyroid'; do 34 | expdir=exp/$dataset/$setting/split$n_splits 35 | wandb online 36 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 37 | --batch_size $batch_size --model $model --binning standard --wandb 38 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 39 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 40 | wandb offline 41 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 42 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 43 | --batch_size $batch_size --model $model --binning standard 44 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 45 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 46 | done 47 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 48 | done 49 | 50 | batch_size=128 51 | eval_batch_size=$((batch_size*2)) 52 | for dataset in 'http' 'smtp'; do 53 | expdir=exp/$dataset/$setting/split$n_splits 54 | wandb online 55 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 56 | --batch_size $batch_size --model $model --binning standard --wandb 57 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 58 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model 
--binning standard 59 | wandb offline 60 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 61 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 62 | --batch_size $batch_size --model $model --binning standard 63 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 64 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 65 | done 66 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 67 | done 68 | 69 | batch_size=96 70 | eval_batch_size=$((batch_size*2)) 71 | for dataset in 'mulcross'; do 72 | expdir=exp/$dataset/$setting/split$n_splits 73 | wandb online 74 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 75 | --batch_size $batch_size --model $model --binning standard --wandb 76 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 77 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 78 | wandb offline 79 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 80 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 81 | --batch_size $batch_size --model $model --binning standard 82 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 83 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 84 | done 85 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 86 | done 87 | 88 | 89 | batch_size=48 90 | eval_batch_size=$((batch_size*2)) 91 | for dataset in 'covertype'; do 92 | expdir=exp/$dataset/$setting/split$n_splits 93 | wandb online 94 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 95 | --batch_size $batch_size --model $model --binning standard --wandb 96 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 97 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 98 | wandb offline 99 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 100 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 101 | --batch_size $batch_size --model $model --binning standard 102 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 103 | --batch_size $eval_batch_size --n_permutations 
$n_permutations --model $model --binning standard 104 | done 105 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 106 | done 107 | 108 | 109 | batch_size=8 110 | eval_batch_size=$((batch_size*2)) 111 | for dataset in 'musk'; do 112 | expdir=exp/$dataset/$setting/split$n_splits 113 | wandb online 114 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 115 | --batch_size $batch_size --model $model --binning standard --wandb 116 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 117 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 118 | wandb offline 119 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 120 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 121 | --batch_size $batch_size --model $model --binning standard 122 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 123 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 124 | done 125 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 126 | done 127 | 128 | batch_size=2 129 | eval_batch_size=$((batch_size*2)) 130 | for dataset in 'arrhythmia' 'speech'; do 131 | expdir=exp/$dataset/$setting/split$n_splits 132 | wandb online 133 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 134 | --batch_size $batch_size --model $model --binning standard --wandb 135 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 136 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 137 | wandb offline 138 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 139 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 140 | --batch_size $batch_size --model $model --binning standard 141 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 142 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard 143 | done 144 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 145 | done 146 | done 147 | 148 | -------------------------------------------------------------------------------- /scripts/exp2-odds/run_baselines.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | n_splits=5 3 | setting='semi_supervised' 4 | for dataset in 'wine' 'breastw' 'cardio' 'ecoli' 'lymphography' 'vertebral' 'wbc' 
'yeast' \ 5 | 'heart' 'annthyroid' 'glass' 'ionosphere' 'letter_recognition' 'mammography' 'pendigits' 'pima' 'satellite' 'satimage-2' 'thyroid' 'vowels'\ 6 | 'seismic' 'optdigits' 'http' 'smtp' 'mulcross' 'covertype' 'shuttle' 'musk' 'arrhythmia' 'speech'; do 7 | expdir=exp/$dataset/$setting/split$n_splits 8 | for ((split_idx = 0 ; split_idx < $n_splits ; split_idx++ )); do 9 | CUDA_VISIBLE_DEVICES=0 python evaluate_baselines.py --dataset $dataset --n_splits $n_splits --normalize --setting $setting --split_idx $split_idx & 10 | done 11 | wait 12 | expdir=exp/$dataset/$setting/split$n_splits 13 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 14 | done 15 | -------------------------------------------------------------------------------- /scripts/exp3-binning_effect/run_binning_odds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | n_splits=5 3 | setting=semi_supervised 4 | n_buckets=10 5 | model='smol' 6 | TRAIN_GPUS="0,1,2,3" 7 | INFERENCE_GPUS="1,2,3" 8 | n_train_node=4 9 | n_test_node=3 10 | n_permutations=21 11 | 12 | for binning in 'equal_width' 'quantile' 'standard' 'language' 'none'; do 13 | batch_size=32 14 | eval_batch_size=$((batch_size*2)) 15 | for dataset in 'wine' 'breastw' 'cardio' 'ecoli' 'lymphography' 'vertebral' 'wbc' 'yeast' 'heart' 'glass' 'ionosphere' \ 16 | 'letter_recognition' 'mammography' 'pendigits' 'pima' 'satellite' 'satimage-2' 'thyroid' 'vowels' 'shuttle'; do 17 | expdir=exp/$dataset/$setting/split$n_splits 18 | wandb online 19 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 20 | --batch_size $batch_size --model $model --binning $binning --wandb 21 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 22 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 23 | wandb offline 24 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 25 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 26 | --batch_size $batch_size --model $model --binning $binning 27 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 28 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 29 | done 30 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 31 | done 32 | 33 | batch_size=16 34 | eval_batch_size=$((batch_size*2)) 35 | for dataset in 'seismic' 'optdigits' 'annthyroid'; do 36 | expdir=exp/$dataset/$setting/split$n_splits 37 | wandb online 38 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 39 | --batch_size $batch_size --model $model --binning $binning --wandb 40 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 41 | --batch_size 
$eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 42 | wandb offline 43 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 44 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 45 | --batch_size $batch_size --model $model --binning $binning 46 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 47 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 48 | done 49 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 50 | done 51 | 52 | batch_size=128 53 | eval_batch_size=$((batch_size*2)) 54 | for dataset in 'http' 'smtp'; do 55 | expdir=exp/$dataset/$setting/split$n_splits 56 | wandb online 57 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 58 | --batch_size $batch_size --model $model --binning $binning --wandb 59 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 60 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 61 | wandb offline 62 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 63 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 64 | --batch_size $batch_size --model $model --binning $binning 65 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 66 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 67 | done 68 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 69 | done 70 | 71 | batch_size=96 72 | eval_batch_size=$((batch_size*2)) 73 | for dataset in 'mulcross'; do 74 | expdir=exp/$dataset/$setting/split$n_splits 75 | wandb online 76 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 77 | --batch_size $batch_size --model $model --binning $binning --wandb 78 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 79 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 80 | wandb offline 81 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 82 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 83 | --batch_size $batch_size --model $model --binning $binning 84 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting 
$setting\ 85 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 86 | done 87 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 88 | done 89 | 90 | 91 | batch_size=48 92 | eval_batch_size=$((batch_size*2)) 93 | for dataset in 'covertype'; do 94 | expdir=exp/$dataset/$setting/split$n_splits 95 | wandb online 96 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 97 | --batch_size $batch_size --model $model --binning $binning --wandb 98 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 99 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 100 | wandb offline 101 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 102 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 103 | --batch_size $batch_size --model $model --binning $binning 104 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 105 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 106 | done 107 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 108 | done 109 | 110 | 111 | batch_size=8 112 | for dataset in 'musk'; do 113 | expdir=exp/$dataset/$setting/split$n_splits 114 | wandb online 115 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 116 | --batch_size $batch_size --model $model --binning $binning --wandb 117 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 118 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 119 | wandb offline 120 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 121 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 122 | --batch_size $batch_size --model $model --binning $binning 123 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 124 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 125 | done 126 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 127 | done 128 | 129 | batch_size=2 130 | for dataset in 'arrhythmia' 'speech'; do 131 | expdir=exp/$dataset/$setting/split$n_splits 132 | wandb online 133 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 134 | --batch_size $batch_size 
--model $model --binning $binning --wandb 135 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 136 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 137 | wandb offline 138 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 139 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 140 | --batch_size $batch_size --model $model --binning $binning 141 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 142 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning $binning 143 | done 144 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 145 | done 146 | done 147 | 148 | -------------------------------------------------------------------------------- /scripts/exp4-model_size/run_anollm_1.7B_mixed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | n_splits=5 3 | setting=semi_supervised 4 | TRAIN_GPUS="0,1,2,3" 5 | INFERENCE_GPUS="1,2,3" 6 | n_train_node=4 7 | n_test_node=3 8 | n_permutations=21 9 | 10 | for model in 'smol-1.7b'; do 11 | batch_size=32 12 | eval_batch_size=$((batch_size*2)) 13 | for dataset in 'vifd' 'fraudecom' 'lymphography'; do 14 | expdir=exp/$dataset/$setting/split$n_splits 15 | wandb online 16 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 17 | --batch_size $batch_size --model $model --binning standard --wandb --lora 18 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 19 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 20 | wandb offline 21 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 22 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 23 | --batch_size $batch_size --model $model --binning standard --lora 24 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 25 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 26 | done 27 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 28 | 29 | done 30 | 31 | batch_size=16 32 | eval_batch_size=$((batch_size*2)) 33 | for dataset in 'seismic'; do 34 | expdir=exp/$dataset/$setting/split$n_splits 35 | wandb online 36 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 37 | --batch_size $batch_size --model $model --binning standard --wandb --lora 38 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun 
--nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 39 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 40 | wandb offline 41 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 42 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 43 | --batch_size $batch_size --model $model --binning standard --lora 44 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 45 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 46 | done 47 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 48 | done 49 | 50 | batch_size=4 51 | eval_batch_size=$((batch_size*2)) 52 | for dataset in 'fakejob'; do 53 | expdir=exp/$dataset/$setting/split$n_splits 54 | wandb online 55 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 20000 \ 56 | --batch_size $batch_size --model $model --binning standard --wandb --lora --lr 1e-3 57 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 58 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lr 1e-3 --lora 59 | wandb offline 60 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 61 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 20000\ 62 | --batch_size $batch_size --model $model --binning standard --lora --lr 1e-3 63 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 64 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lr 1e-3 --lora 65 | done 66 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 67 | done 68 | 69 | batch_size=4 70 | eval_batch_size=$((batch_size*2)) 71 | for ((idx = 0 ; idx < 6 ; idx++ )); do 72 | dataset=20news-$idx 73 | expdir=exp/$dataset/$setting/split$n_splits 74 | wandb offline 75 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 76 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 77 | --batch_size $batch_size --model $model --binning standard --lr 0.0005 --lora 78 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 79 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lr 0.0005 --lora 80 | done 81 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 82 | done 83 | 84 
| done -------------------------------------------------------------------------------- /scripts/exp4-model_size/run_anollm_1.7B_odds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | n_splits=5 3 | setting=semi_supervised 4 | TRAIN_GPUS="0,1,2,3" 5 | INFERENCE_GPUS="1,2,3" 6 | n_train_node=4 7 | n_test_node=3 8 | n_permutations=21 9 | 10 | for model in 'smol-1.7b'; do 11 | batch_size=32 12 | eval_batch_size=$((batch_size*2)) 13 | for dataset in 'wine' 'breastw' 'cardio' 'ecoli' 'lymphography' 'vertebral' 'wbc' 'yeast'; do 14 | expdir=exp/$dataset/$setting/split$n_splits 15 | wandb online 16 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 17 | --batch_size $batch_size --model $model --binning standard --wandb --lora 18 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 19 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 20 | wandb offline 21 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 22 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 23 | --batch_size $batch_size --model $model --binning standard --lora 24 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 25 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 26 | done 27 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 28 | done 29 | 30 | 31 | batch_size=32 32 | eval_batch_size=$((batch_size*2)) 33 | for dataset in 'heart' 'glass' 'ionosphere' 'letter_recognition' 'mammography' 'pendigits' 'pima' 'satellite' 'satimage-2' 'thyroid' 'vowels' 'shuttle'; do 34 | expdir=exp/$dataset/$setting/split$n_splits 35 | wandb online 36 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 37 | --batch_size $batch_size --model $model --binning standard --wandb --lora 38 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 39 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 40 | wandb offline 41 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 42 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 43 | --batch_size $batch_size --model $model --binning standard --lora 44 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 45 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 46 | done 47 | python -u src/get_results.py --dataset $dataset --n_splits 
$n_splits --setting $setting | tee $expdir/evaluate.log 48 | done 49 | 50 | 51 | batch_size=16 52 | eval_batch_size=$((batch_size*2)) 53 | for dataset in 'seismic' 'optdigits' 'annthyroid'; do 54 | expdir=exp/$dataset/$setting/split$n_splits 55 | wandb online 56 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 57 | --batch_size $batch_size --model $model --binning standard --wandb --lora 58 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 59 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 60 | wandb offline 61 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 62 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 63 | --batch_size $batch_size --model $model --binning standard --lora 64 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 65 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 66 | done 67 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 68 | done 69 | 70 | batch_size=128 71 | eval_batch_size=$((batch_size*2)) 72 | for dataset in 'http' 'smtp'; do 73 | expdir=exp/$dataset/$setting/split$n_splits 74 | wandb online 75 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 76 | --batch_size $batch_size --model $model --binning standard --wandb --lora 77 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 78 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 79 | wandb offline 80 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 81 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 82 | --batch_size $batch_size --model $model --binning standard --lora 83 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 84 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 85 | done 86 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 87 | done 88 | 89 | batch_size=96 90 | eval_batch_size=$((batch_size*2)) 91 | for dataset in 'mulcross'; do 92 | expdir=exp/$dataset/$setting/split$n_splits 93 | wandb online 94 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 95 | --batch_size $batch_size --model $model --binning standard --wandb --lora 96 | 
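# Scoring reuses the checkpoint from the training command above; evaluate_anollm.py divides the 21 permutations across the inference GPUs.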
CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 97 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 98 | wandb offline 99 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 100 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 101 | --batch_size $batch_size --model $model --binning standard --lora 102 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 103 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 104 | done 105 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 106 | done 107 | 108 | 109 | batch_size=48 110 | eval_batch_size=$((batch_size*2)) 111 | for dataset in 'covertype'; do 112 | expdir=exp/$dataset/$setting/split$n_splits 113 | wandb online 114 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 115 | --batch_size $batch_size --model $model --binning standard --wandb --lora 116 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 117 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 118 | wandb offline 119 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 120 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 121 | --batch_size $batch_size --model $model --binning standard --lora 122 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 123 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 124 | done 125 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 126 | done 127 | 128 | 129 | batch_size=8 130 | for dataset in 'musk'; do 131 | expdir=exp/$dataset/$setting/split$n_splits 132 | wandb online 133 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 134 | --batch_size $batch_size --model $model --binning standard --wandb --lora 135 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 136 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 137 | wandb offline 138 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 139 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting 
--max_steps 2000\ 140 | --batch_size $batch_size --model $model --binning standard --lora 141 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 142 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 143 | done 144 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 145 | done 146 | 147 | batch_size=2 148 | for dataset in 'arrhythmia' 'speech'; do 149 | expdir=exp/$dataset/$setting/split$n_splits 150 | wandb online 151 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting --max_steps 2000 \ 152 | --batch_size $batch_size --model $model --binning standard --wandb --lora 153 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx 0 --setting $setting\ 154 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 155 | wandb offline 156 | for ((split_idx = 1 ; split_idx < $n_splits ; split_idx++ )); do 157 | CUDA_VISIBLE_DEVICES=$TRAIN_GPUS torchrun --nproc_per_node=$n_train_node train_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting --max_steps 2000\ 158 | --batch_size $batch_size --model $model --binning standard --lora 159 | CUDA_VISIBLE_DEVICES=$INFERENCE_GPUS torchrun --nproc_per_node=$n_test_node evaluate_anollm.py --dataset $dataset --n_splits $n_splits --split_idx $split_idx --setting $setting\ 160 | --batch_size $eval_batch_size --n_permutations $n_permutations --model $model --binning standard --lora 161 | done 162 | python -u src/get_results.py --dataset $dataset --n_splits $n_splits --setting $setting | tee $expdir/evaluate.log 163 | done 164 | done 165 | 166 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py -------------------------------------------------------------------------------- /src/baselines/dte.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 Victor Livernoche. Licensed under the MIT License. 
3 | On Diffusion Modeling for Anomaly Detection - Diffusion Time Estimation (https://github.com/vicliv/DTE/tree/main) 4 | @Author: Victor Livernoche 5 | """ 6 | 7 | import torch.nn.functional as F 8 | from torch import nn 9 | import torch 10 | import sklearn.metrics as skm 11 | from torch.optim import Adam 12 | from torch.utils.data import DataLoader 13 | import numpy as np 14 | 15 | from sklearn.metrics import roc_auc_score 16 | 17 | class MLP(nn.Module): 18 | def __init__(self, hidden_sizes, num_bins = 7): 19 | super().__init__() 20 | self.hidden_sizes = hidden_sizes # hidden layers sizes 21 | self.activation = nn.ReLU() # activation to use in the network 22 | 23 | layers = [] 24 | for i in range(1, len(self.hidden_sizes)): 25 | layers.append(nn.Linear(hidden_sizes[i-1], hidden_sizes[i])) 26 | 27 | if num_bins > 1: 28 | # if we have the classification model 29 | layers.append(nn.Linear(hidden_sizes[-1], num_bins)) 30 | self.softmax = nn.Softmax(dim = 1) 31 | else: 32 | # if we have the regression model 33 | layers.append(nn.Linear(hidden_sizes[-1], 1)) 34 | self.softmax = lambda x : x # ignore softmaxt 35 | 36 | self.layers = nn.ModuleList(layers) 37 | 38 | self.drop = torch.nn.Dropout(p=0.5, inplace=False) # dropout 39 | 40 | def forward(self, x): 41 | x = self.activation(self.layers[0](x)) 42 | 43 | for layer in self.layers[1:-1]: 44 | x = self.activation(layer(x)) 45 | x = self.drop(x) 46 | 47 | return self.softmax(self.layers[-1](x)) 48 | 49 | def binning(t, T= 300, num_bins = 30, device = 'cpu'): 50 | """ 51 | Gives the bin number for a given t based on T (maximum) and the number of bins 52 | This is floor(t*num_bins/T) bounded by 0 and T-1 53 | """ 54 | return torch.maximum(torch.minimum(torch.floor(t*num_bins/T).to(device), torch.tensor(num_bins-1).to(device)), torch.tensor(0).to(device)).long() 55 | 56 | class DTE(): 57 | def __init__(self, seed = 0, model_name = "DTE", hidden_size = [256, 512, 256], epochs = 400, batch_size = 64, lr = 1e-4, weight_decay = 5e-4, T=400, num_bins=7, device = None): 58 | self.hidden_size = hidden_size 59 | self.epochs = epochs 60 | self.batch_size = batch_size 61 | self.lr = lr 62 | self.weight_decay = weight_decay 63 | 64 | self.T = T 65 | self.num_bins = num_bins 66 | 67 | if device is None: 68 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 69 | else: 70 | self.device = device 71 | self.seed = seed 72 | 73 | betas = torch.linspace(0.0001, 0.01, T) # linear beta scheduling 74 | 75 | # Pre-calculate different terms for closed form of diffusion process 76 | alphas = 1. - betas 77 | alphas_cumprod = torch.cumprod(alphas, axis=0) 78 | self.alphas_cumprod = alphas_cumprod 79 | 80 | sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) 81 | sqrt_one_minus_alphas_cumprod = torch.sqrt(1. 
- alphas_cumprod) 82 | 83 | def forward_noise(x_0, t, drift = False): 84 | """ 85 | Takes data point and a timestep as input and 86 | returns the noisy version of it 87 | """ 88 | noise = torch.randn_like(x_0) # epsilon 89 | 90 | noise.requires_grad_() # for the backward propagation of the NN 91 | sqrt_alphas_cumprod_t = torch.take(sqrt_alphas_cumprod, t.cpu()).to(self.device).unsqueeze(1) 92 | sqrt_one_minus_alphas_cumprod_t = torch.take(sqrt_one_minus_alphas_cumprod, t.cpu()).to(self.device).unsqueeze(1) 93 | 94 | # mean + variance 95 | if drift: 96 | return (sqrt_alphas_cumprod_t.to(self.device) * x_0.to(self.device) + sqrt_one_minus_alphas_cumprod_t.to(self.device) * noise.to(self.device)).to(torch.float32) 97 | else: # variance only 98 | return (x_0.to(self.device) + sqrt_one_minus_alphas_cumprod_t.to(self.device) * noise.to(self.device)).to(torch.float32) 99 | 100 | self.forward_noise = forward_noise 101 | self.model = None 102 | 103 | def compute_loss(self, x, t): 104 | pass 105 | 106 | def fit(self, X_train, y_train = None, X_test = None, y_test = None, verbose=False): 107 | if self.model is None: # allows retraining 108 | self.model = MLP([X_train.shape[-1]] + self.hidden_size, num_bins = self.num_bins).to(self.device) 109 | 110 | optimizer = Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay) 111 | train_loader = DataLoader(torch.from_numpy(X_train).float(), batch_size=self.batch_size, shuffle=True, drop_last=False) 112 | 113 | train_losses = [] 114 | for epoch in range(self.epochs): 115 | self.model.train() 116 | loss_ = [] 117 | 118 | for x in train_loader: 119 | x = x.to(self.device) 120 | optimizer.zero_grad() 121 | 122 | # sample t uniformly 123 | t = torch.randint(0, self.T, (x.shape[0],), device=self.device).long() 124 | 125 | # compute the loss 126 | loss = self.compute_loss(x, t) 127 | 128 | loss.backward() 129 | optimizer.step() 130 | loss_.append(loss.item()) 131 | 132 | train_losses.append(np.mean(np.array(loss_))) 133 | 134 | if epoch % 1 == 0 and verbose: 135 | if X_test is not None and y_test is not None: 136 | print(roc_auc_score(y_true=y_test, y_score=self.decision_function(X_test))) 137 | print(f"Epoch {epoch} Train Loss: {train_losses[len(train_losses)-1]}") 138 | 139 | return self 140 | 141 | def decision_function(self, X): 142 | test_loader = DataLoader(torch.from_numpy(X).float(), batch_size=100, shuffle=False, drop_last=False) 143 | preds = [] 144 | self.model.eval() 145 | for x in test_loader: 146 | # predict the timestep based on x, or the probability of each class for the classification 147 | pred_t = self.model(x.to(self.device).to(torch.float32)) 148 | preds.append(pred_t.cpu().detach().numpy()) 149 | 150 | preds = np.concatenate(preds, axis=0) 151 | 152 | if self.num_bins > 1: 153 | #preds = np.argmax(preds, axis=1) 154 | 155 | # compute mean prediction over all bins 156 | preds = np.matmul(preds, np.arange(0, preds.shape[-1])) 157 | else: 158 | preds = preds.squeeze() 159 | 160 | return preds 161 | 162 | class DTECategorical(DTE): 163 | def __init__(self, seed = 0, model_name = "DTE_categorical", hidden_size = [256, 512, 256], epochs = 400, batch_size = 64, lr = 1e-4, weight_decay = 5e-4, T=400, num_bins=7, device=None): 164 | if num_bins < 2: 165 | raise ValueError("num_bins must be greater than or equal to 2") 166 | 167 | super().__init__(seed, model_name, hidden_size, epochs, batch_size, lr, weight_decay, T, num_bins, device) 168 | 169 | 170 | def compute_loss(self, x_0, t): 171 | # get the loss based on the input and 
timestep 172 | 173 | # get noisy sample 174 | x_noisy = self.forward_noise(x_0, t) 175 | 176 | # predict the timestep 177 | t_pred = self.model(x_noisy) 178 | 179 | # For the categorical model, the target is the binned t with cross entropy loss 180 | target = binning(t, T = self.T, device = self.device, num_bins = self.num_bins) 181 | 182 | loss = nn.CrossEntropyLoss()(t_pred, target) 183 | 184 | return loss 185 | 186 | class DTEInverseGamma(DTE): 187 | def __init__(self, seed = 0, model_name = "DTE_inverse_gamma", hidden_size = [256, 512, 256], epochs = 400, batch_size = 64, lr = 1e-4, weight_decay = 5e-4, T=400, device=None): 188 | super().__init__(seed, model_name, hidden_size, epochs, batch_size, lr, weight_decay, T, 0, device) 189 | 190 | def compute_loss(self, x_0, t): 191 | # get the loss based on the input and timestep 192 | _, dim = x_0.shape 193 | eps = 1e-5 194 | # get noisy sample 195 | x_noisy = self.forward_noise(x_0, t) 196 | 197 | # predict the inv gamma parameter 198 | sqrt_beta_pred = self.model(x_noisy) 199 | beta_pred = torch.pow(sqrt_beta_pred, 2).squeeze() 200 | 201 | var_target = (1. - self.alphas_cumprod[t.cpu()]).to(self.device) 202 | log_likelihood = (0.5 * dim - 1) * torch.log(beta_pred + eps) - beta_pred / (var_target) 203 | loss = -log_likelihood.mean() 204 | 205 | return loss 206 | 207 | def decision_function(self, X): 208 | N, dim = X.shape 209 | test_loader = DataLoader(torch.from_numpy(X).float(), batch_size=100, shuffle=False, drop_last=False) 210 | preds = [] 211 | self.model.eval() 212 | for x in test_loader: 213 | # predict the timestep based on x, or the probability of each class for the classification 214 | pred_t = self.model(x.to(self.device).to(torch.float32)) 215 | pred_t = torch.pow(pred_t, 2).squeeze() / ((0.5 * dim - 1)) # mode of the inverse gamma distribution 216 | preds.append(pred_t.cpu().detach().numpy()) 217 | 218 | preds = np.concatenate(preds, axis=0) 219 | 220 | return preds 221 | 222 | 223 | class DTEGaussian(DTE): 224 | def __init__(self, seed = 0, model_name = "DTE_gaussian", hidden_size = [256, 512, 256], epochs = 400, batch_size = 64, lr = 1e-4, weight_decay = 5e-4, T=400, device=None): 225 | super().__init__(seed, model_name, hidden_size, epochs, batch_size, lr, weight_decay, T, 0, device) 226 | 227 | def compute_loss(self, x_0, t): 228 | # get the loss based on the input and timestep 229 | 230 | # get noisy sample 231 | x_noisy = self.forward_noise(x_0, t) 232 | 233 | # predict the timestep 234 | t_pred = self.model(x_noisy) 235 | 236 | t_pred = t_pred.squeeze() 237 | target = t.float() 238 | 239 | loss = nn.MSELoss()(t_pred, target) 240 | 241 | return loss 242 | -------------------------------------------------------------------------------- /src/baselines/icl.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 Victor Livernoche. Licensed under the MIT License. 
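A minimal, hedged usage sketch for the ICL detector defined later in this file (the import path, data shapes, and the shortened epoch count are assumptions made for the example):

    import numpy as np
    from src.baselines.icl import ICL  # assumed import path for this repo layout

    X_train = np.random.randn(500, 16).astype(np.float32)   # normal-only training matrix
    X_test = np.random.randn(100, 16).astype(np.float32)

    detector = ICL(num_epochs=50)                  # far fewer epochs than the default, purely for the sketch
    detector.fit(X_train)
    scores = detector.decision_function(X_test)    # per-row contrastive loss; larger values flag likely anomalies
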
3 | On Diffusion Modeling for Anomaly Detection - Diffusion Time Estimation (https://github.com/vicliv/DTE/tree/main) 4 | @Author: Victor Livernoche 5 | """ 6 | 7 | import torch 8 | import numpy as np 9 | import random 10 | import pandas as pd 11 | from torch import nn 12 | from torch.utils.data import Dataset, DataLoader 13 | 14 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 15 | 16 | 17 | def scores_calc_internal(query, positive,no_negatives,tau): 18 | pos_multiplication = (query * positive).sum(dim=2).unsqueeze(2).to(device) 19 | if no_negatives <= query.shape[1]: 20 | negative_index = random.sample(range(0, query.shape[1]), no_negatives) 21 | else: 22 | negative_index = random.sample(range(0, query.shape[1]), query.shape[1]) 23 | neg_multiplication = torch.matmul(query, positive.permute(0, 2, 1)[:, :,negative_index]) 24 | # Removal of the diagonals 25 | identity_matrix = torch.eye(np.shape(query)[1]).unsqueeze(0).repeat(np.shape(query)[0], 1, 26 | 1)[:, :, negative_index].to(device) 27 | neg_multiplication.masked_fill_(identity_matrix == 1, -float('inf')) # exp of -inf=0 28 | logits = torch.cat((pos_multiplication, neg_multiplication), dim=2).to(device) 29 | logits=logits/tau 30 | return (logits) 31 | 32 | 33 | def take_per_row_complement(A, indx, num_elem=3): 34 | all_indx = indx[:,None] + np.arange(num_elem) 35 | 36 | all_indx_complement=[] 37 | for row in all_indx: 38 | complement=a_minus_b(np.arange(A.shape[2]),row) 39 | all_indx_complement.append(complement) 40 | all_indx_complement=np.array(all_indx_complement) 41 | return (A[:,np.arange(all_indx.shape[0])[:,None],all_indx],A[:,np.arange(all_indx.shape[0])[:,None],all_indx_complement]) 42 | 43 | def positive_matrice_builder(dataset, kernel_size): 44 | dataset = torch.squeeze(dataset, 2) 45 | if kernel_size != 1: 46 | indices = np.array((range(dataset.shape[1])))[:-kernel_size + 1] 47 | else: 48 | indices = np.array((range(dataset.shape[1]))) 49 | dataset = torch.unsqueeze(dataset, 1) 50 | dataset = dataset.repeat(1, dataset.shape[2], 1) 51 | 52 | matrice,complement_matrice = take_per_row_complement(dataset, indices, num_elem=kernel_size) 53 | return (matrice,complement_matrice) 54 | 55 | 56 | def take_per_row(A, indx, num_elem=2): 57 | all_indx = indx[:, None] + np.arange(num_elem) 58 | return A[:, np.arange(all_indx.shape[0])[:, None], all_indx] 59 | 60 | 61 | def f1_calculator(classes, losses): 62 | df_version_classes = pd.DataFrame(data=classes) 63 | df_version_losses = pd.DataFrame(losses).astype(np.float64) 64 | Na = df_version_classes[df_version_classes.iloc[:, 0] == 1].shape[0] 65 | anomaly_indices = df_version_losses.nlargest(Na, 0).index.values 66 | picked_anomalies = df_version_classes.iloc[anomaly_indices] 67 | true_pos = picked_anomalies[picked_anomalies.iloc[:, 0] == 1].shape[0] 68 | false_pos = picked_anomalies[picked_anomalies.iloc[:, 0] == 0].shape[0] 69 | f1 = true_pos / (true_pos + false_pos) 70 | return (f1) 71 | 72 | 73 | def a_minus_b (a,b): 74 | sidx = b.argsort() 75 | idx = np.searchsorted(b, a, sorter=sidx) 76 | idx[idx == len(b)] = 0 77 | out = a[b[sidx[idx]] != a] 78 | return out 79 | 80 | class DatasetBuilder(Dataset): 81 | def __init__(self, data): 82 | self.data = data 83 | 84 | def __len__(self): 85 | return (self.data.shape[0]) 86 | 87 | def __getitem__(self, idx): 88 | if torch.is_tensor(idx): 89 | idx = idx.tolist() 90 | sample = {'data': self.data[idx], 'index': idx} 91 | return sample 92 | 93 | class encoder_a(nn.Module): 94 | def __init__(self, 
kernel_size,hdn_size,d): 95 | super(encoder_a, self).__init__() 96 | self.fc1 = nn.Linear(d-kernel_size, hdn_size) #F network 97 | self.activation1 = nn.Tanh() 98 | self.fc2 = nn.Linear(hdn_size, hdn_size*2) 99 | self.activation2 = nn.LeakyReLU(0.2) 100 | self.fc3 = nn.Linear(hdn_size*2, hdn_size) 101 | self.activation3 = nn.LeakyReLU(0.2) 102 | self.batchnorm_1 = nn.BatchNorm1d(d-kernel_size+1) 103 | self.batchnorm_2 = nn.BatchNorm1d(d-kernel_size+1) 104 | self.fc1_y = nn.Linear(kernel_size, int(hdn_size/4)) #G network 105 | self.activation1_y = nn.LeakyReLU(0.2) 106 | self.fc2_y = nn.Linear(int(hdn_size/4), int(hdn_size/2)) 107 | self.activation2_y = nn.LeakyReLU(0.2) 108 | self.fc3_y = nn.Linear(int(hdn_size/2), hdn_size) 109 | self.activation3_y = nn.LeakyReLU(0.2) 110 | self.kernel_size = kernel_size 111 | self.batchnorm1_y=nn.BatchNorm1d(d-kernel_size+1) 112 | def forward(self, x): 113 | x = x.permute(0, 2, 1) 114 | y,x = positive_matrice_builder(x, self.kernel_size) 115 | x = self.activation1(self.fc1(x)) 116 | x=self.batchnorm_1(x) 117 | x = self.activation2(self.fc2(x)) 118 | x=self.batchnorm_2(x) 119 | x = self.activation3(self.fc3(x)) 120 | y = self.activation1_y(self.fc1_y(y)) 121 | y=self.batchnorm1_y(y) 122 | y = self.activation2_y(self.fc2_y(y)) 123 | y = self.activation3_y(self.fc3_y(y)) 124 | x=nn.functional.normalize(x,dim=1) 125 | y=nn.functional.normalize(y,dim=1) 126 | x=nn.functional.normalize(x,dim=2) 127 | y=nn.functional.normalize(y,dim=2) 128 | return (x, y) 129 | 130 | class ICL(): 131 | def __init__(self, seed=0, model_name="ICL", num_epochs = 2000, no_batchs = 3000, no_negatives=1000, temperature=0.1, lr=0.001, device=None): 132 | self.seed = seed 133 | 134 | if device is None: 135 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 136 | else: 137 | device = device 138 | 139 | self.num_epochs = num_epochs 140 | self.no_btchs = no_batchs 141 | self.no_negatives=no_negatives 142 | self.temperature=temperature 143 | self.lr=lr 144 | self.faster_version="no" 145 | 146 | self.models = [] 147 | 148 | self.perms = [] 149 | 150 | def fit(self,X_train, y_train=None): 151 | train = X_train 152 | train = torch.as_tensor(train, dtype=torch.float) 153 | d = train.shape[1] 154 | n = train.shape[0] 155 | if self.faster_version=='yes': 156 | num_permutations = min(int(np.floor(100 / (np.log(n) + d)) + 1),2) 157 | else: 158 | num_permutations=int(np.floor(100/(np.log(n)+d))+1) 159 | num_permutations = 1 160 | print("going to run for: ", num_permutations, ' permutations') 161 | hiddensize = 200 162 | if d <= 40: 163 | kernel_size = 2 164 | stop_crteria = 0.001 165 | if 40 < d and d <= 160: 166 | kernel_size = 10 167 | stop_crteria = 0.01 168 | if 160 < d: 169 | kernel_size = d - 150 170 | stop_crteria = 0.01 171 | for permutations in range(num_permutations): 172 | if num_permutations > 1: 173 | random_idx = torch.randperm(train.shape[1]) 174 | self.perms.append(random_idx) 175 | train = train[:, random_idx] 176 | 177 | dataset_train = DatasetBuilder(train) 178 | model_a = encoder_a(kernel_size, hiddensize, d).to(device) 179 | self.models.append(model_a) 180 | criterion = nn.CrossEntropyLoss() 181 | optimizer_a = torch.optim.Adam(model_a.parameters(), lr=self.lr) 182 | trainloader = DataLoader(dataset_train, batch_size=self.no_btchs, 183 | shuffle=True, num_workers=0, pin_memory=True) 184 | ### training 185 | for epoch in range(self.num_epochs): 186 | model_a.train() 187 | running_loss = 0 188 | for i, sample in enumerate(trainloader, 0): 189 | 
model_a.zero_grad() 190 | pre_query = sample['data'].to(device) 191 | pre_query = torch.unsqueeze(pre_query, 1) 192 | pre_query, positives_matrice = model_a(pre_query) 193 | scores_internal = scores_calc_internal(pre_query, positives_matrice,self.no_negatives,self.temperature).to(device) 194 | scores_internal = scores_internal.permute(0, 2, 1) 195 | correct_class = torch.zeros((np.shape(scores_internal)[0], np.shape(scores_internal)[2]), 196 | dtype=torch.long).to(device) 197 | loss = criterion(scores_internal, correct_class).to(device) 198 | loss.backward() 199 | optimizer_a.step() 200 | running_loss += loss.item() 201 | if (running_loss / (i + 1) < stop_crteria): 202 | break 203 | if n<2000: 204 | if (epoch + 1) % 100 == 0: 205 | print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / (i + 1))) 206 | else: 207 | if (epoch + 1) % 10 == 0: 208 | print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / (i + 1))) 209 | return self 210 | 211 | def decision_function(self, test_X): 212 | test = torch.as_tensor(test_X, dtype=torch.float) 213 | 214 | test_losses_contrastloss = torch.zeros(test.shape[0],dtype=torch.float).to(device) 215 | 216 | 217 | for i, model in enumerate(self.models): 218 | if len(self.perms) > 0: 219 | test = test[:, self.perms[i]] 220 | dataset_test = DatasetBuilder(test) 221 | testloader = DataLoader(dataset_test, batch_size=self.no_btchs, 222 | shuffle=True, num_workers=0, pin_memory=True) 223 | 224 | model.eval() 225 | criterion_test = nn.CrossEntropyLoss(reduction='none') 226 | with torch.no_grad(): 227 | for i, sample in enumerate(testloader, 0): 228 | pre_query = sample['data'].to(device) 229 | indexes = sample['index'].to(device) 230 | pre_query_test = torch.unsqueeze(pre_query, 1) # batch X feature X 1 231 | pre_query_test, positives_matrice_test = model(pre_query_test) 232 | scores_internal_test = scores_calc_internal(pre_query_test, positives_matrice_test,self.no_negatives,self.temperature).to(device) 233 | scores_internal_test = scores_internal_test.permute(0, 2, 1) 234 | correct_class = torch.zeros((np.shape(scores_internal_test)[0], np.shape(scores_internal_test)[2]), 235 | dtype=torch.long).to(device) 236 | loss_test = criterion_test(scores_internal_test, correct_class).to(device) 237 | test_losses_contrastloss[indexes] += loss_test.mean(dim=1).to(device) 238 | return test_losses_contrastloss.cpu().detach().numpy() 239 | 240 | 241 | 242 | -------------------------------------------------------------------------------- /src/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from typing import Optional 5 | from pathlib import Path 6 | import copy 7 | import scipy.io 8 | import pickle as pkl 9 | 10 | import ucimlrepo 11 | import adbench 12 | import pickle 13 | # Preprocessing 14 | import string 15 | from string import ascii_uppercase 16 | 17 | import re 18 | from sklearn.preprocessing import LabelEncoder 19 | from sklearn.preprocessing import StandardScaler 20 | from sklearn.preprocessing import OneHotEncoder 21 | from feature_engine.encoding import RareLabelEncoder 22 | from sklearn.feature_extraction.text import CountVectorizer 23 | from sklearn.feature_extraction.text import TfidfVectorizer 24 | from sklearn.datasets import fetch_20newsgroups 25 | import gensim.downloader as api 26 | 27 | MIXED = [ 'vifd', 'fraudecom', 'fakejob', 'seismic', 'lymphography', '20news-0', '20news-1','20news-2','20news-3','20news-4','20news-5'] 28 | ODDS = 
['breastw', 'cardio', 'ecoli', 'lymphography', 'vertebral', 'wbc', 'wine', 'yeast', 'heart', 'arrhythmia', 29 | 'mulcross', 'annthyroid', 'covertype', 'glass', 'http', 'ionosphere', 'letter_recognition', 'mammography', 'musk', 30 | 'optdigits', 'pendigits', 'pima', 'satellite', 'satimage-2', 'seismic', 'shuttle', 'smtp', 'speech', 'thyroid', 'vowels'] 31 | 32 | # Map of dataset names to their corresponding dataset IDs in the UCI ML repository 33 | DATA_MAP ={ 34 | # ucimlrepo 35 | 'breastw':15, 36 | 'cardio':193, 37 | 'ecoli': 39, 38 | 'lymphography': 63, 39 | 'vertebral': 212, 40 | 'wbc':17, 41 | 'wine': 109, 42 | 'yeast':110, 43 | # fraud detection 44 | 'vifd': None, 45 | 'fraudecom': None, 46 | 'fakejob': None, 47 | 'fakenews': None, 48 | # without feature names 49 | 'heart': 96, 50 | 'arrhythmia': None, # download from https://odds.cs.stonybrook.edu/arrhythmia-dataset/ 51 | 'mulcross': None, # download from https://www.openml.org/search?type=data&sort=runs&id=40897&status=active 52 | # adbench datasets: 53 | 'annthyroid': 2, 54 | 'covertype':31, 55 | 'glass': 14, 56 | 'http': 16, 57 | 'ionosphere': 18, 58 | 'letter_recognition':20, 59 | 'mammography': 23, 60 | 'mulcross': None, 61 | 'musk': 25, 62 | 'optdigits':26, 63 | 'pendigits':28, 64 | 'pima':29, 65 | 'satellite':30, 66 | 'satimage-2':31, 67 | 'seismic': None, 68 | 'shuttle':32, 69 | 'smtp':34, 70 | 'speech':36, 71 | 'thyroid':38, 72 | 'vowels':40, 73 | #20news: 74 | '20news-0': None, 75 | '20news-1': None, 76 | '20news-2': None, 77 | '20news-3': None, 78 | '20news-4': None, 79 | '20news-5': None, 80 | } 81 | 82 | def load_dataset(dataset_name, data_dir): 83 | dataset_dir = Path(data_dir) / dataset_name 84 | os.makedirs(dataset_dir, exist_ok = True) 85 | pkl_file = dataset_dir / 'data.pkl' 86 | if os.path.exists(pkl_file): 87 | with open(pkl_file, 'rb') as f: 88 | X, y= pickle.load(f) 89 | return X, y 90 | 91 | if dataset_name == 'wine': 92 | dataset_id = DATA_MAP[dataset_name] 93 | df = ucimlrepo.fetch_ucirepo(id=dataset_id).data['original'] 94 | np_data = load_adbench_data(dataset_name) 95 | columns = [name.replace('_', ' ') for name in df.columns[:-1] ] 96 | 97 | X = pd.DataFrame(data = np_data['X'], columns = columns) 98 | y = np_data['y'] 99 | elif dataset_name == 'breastw': 100 | dataset_id = DATA_MAP[dataset_name] 101 | df = ucimlrepo.fetch_ucirepo(id=dataset_id).data['original'] 102 | columns = [name.replace('_', ' ') for name in df.columns[1:-1] ] 103 | np_data = load_adbench_data(dataset_name) 104 | 105 | X = pd.DataFrame(data = np_data['X'], columns = columns) 106 | y = np_data['y'] 107 | 108 | elif dataset_name == 'cardio': 109 | dataset_id = DATA_MAP[dataset_name] 110 | uci_dataset = ucimlrepo.fetch_ucirepo(id=dataset_id) 111 | # get columns descriptions 112 | var_info = uci_dataset['metadata']['additional_info']['variable_info'] 113 | L = [ k.split(' - ') for k in var_info.split('\n') ] 114 | column_dict = {} 115 | for k, v in L: 116 | column_dict[k] = v.strip('\r') 117 | 118 | df = uci_dataset.data['original'] 119 | df = df[df['NSP'] != 2].reset_index(drop=True) 120 | y = df['NSP'].map({3:1, 1:0}) # map pathologic to 1, normal to 0 121 | y = y.to_numpy() 122 | 123 | df.drop(['CLASS','NSP'], inplace = True, axis = 1) 124 | new_columns = [ column_dict[c] for c in df.columns] 125 | df.columns = new_columns 126 | X = df 127 | elif dataset_name == 'ecoli': 128 | dataset_id = DATA_MAP[dataset_name] 129 | uci_dataset = ucimlrepo.fetch_ucirepo(id=dataset_id) 130 | columns = uci_dataset['variables']['description'][:8] 131 | 
X = uci_dataset.data['original'].drop(['class'], axis = 1) 132 | X.columns = columns 133 | X = X.drop(X.columns[0], axis=1)# drop id column 134 | y = uci_dataset.data['original']['class'].map({'omL':1,'imL':1,'imS':1, 'cp':0, 'im':0, 'pp':0, 'imU':0, 'om':0}) 135 | y = y.to_numpy() 136 | elif dataset_name == 'lymphography': 137 | dataset_id = DATA_MAP[dataset_name] 138 | uci_dataset = ucimlrepo.fetch_ucirepo(id=dataset_id) 139 | df = uci_dataset.data['original'] 140 | y = df['class'].map({1:1,2:0,3:0,4:1}) # 142 normal, 6 anomalies 141 | y = y.to_numpy() 142 | 143 | df.drop('class', inplace = True, axis = 1) 144 | df.drop('no. of nodes in', inplace = True, axis = 1) 145 | 146 | var_info = uci_dataset['metadata']['additional_info']['variable_info'] 147 | df['lymphatics'] = df['lymphatics'].map({1:'normal', 2:'arched', 3:'deformed', 4:'displaced'}).astype('object') 148 | df['defect in node'] = df['defect in node'].map({1:'no',2:'lacunar', 3:'lac. marginal', 4:'lac. central'}).astype('object') 149 | df['changes in lym'] = df['changes in lym'].map({1:'bean',2:'oval', 3:'round'}).astype('object') 150 | df['changes in node'] = df['changes in node'].map({1:'no',2:'lacunar', 3:'lac. marginal', 4:'lac. central'}).astype('object') 151 | df['changes in stru'] = df['changes in stru'].map({1:'no',2:'grainy', 3:'drop-like', 4:'coarse', 5:'diluted', 6: 'reticular', 7:'stripped', 8:'faint'}).astype('object') 152 | df['special forms'] = df['special forms'].map({1:'no',2:'chalices', 3:'vesicles'}).astype('object') 153 | 154 | for k in ['block of affere', 'bl. of lymph. c', 'bl. of lymph. s', 'by pass', 'extravasates', 'regeneration of', 'early uptake in', 'dislocation of', 'exclusion of no']: 155 | df[k] = df[k].map({1:'no',2:'yes'}).astype('object') 156 | 157 | X = df 158 | 159 | elif dataset_name == 'vertebral': 160 | dataset_id = DATA_MAP[dataset_name] 161 | uci_dataset = ucimlrepo.fetch_ucirepo(id=dataset_id) 162 | df = uci_dataset.data['original'] 163 | 164 | df_anomaly = df[df['class'] == 'Normal'] # 100 normal data is treated as abnormal 165 | df_normal = df[df['class'] != 'Normal'] # 210 166 | df_anomaly = df_anomaly.sample(n=30, random_state = 42) 167 | df = pd.concat([df_anomaly, df_normal], axis = 0, ignore_index=True) 168 | 169 | y = df['class'].map({'Spondylolisthesis':0, 'Normal':1, 'Hernia': 0}) # 210 normal, 30 anomalies 170 | y = y.to_numpy() 171 | df.drop('class', inplace = True, axis = 1) 172 | df.columns = [name.replace('_', ' ') for name in df.columns ] 173 | X = df 174 | elif dataset_name == 'covertype': 175 | dataset_id = DATA_MAP[dataset_name] 176 | uci_dataset = ucimlrepo.fetch_ucirepo(id=dataset_id) 177 | df = uci_dataset.data['original'] 178 | 179 | for column in df.columns: 180 | if 'Soil' in column or 'Wilderness' in column: 181 | df.drop(column, axis =1 , inplace = True) 182 | df_normal = df[df['Cover_Type'] == 2] 183 | df_anomaly = df[df['Cover_Type'] == 4] 184 | df = pd.concat([df_anomaly, df_normal], axis = 0, ignore_index=True) 185 | 186 | y = df['Cover_Type'].map({2:0, 4:1}) 187 | y = y.to_numpy() 188 | df.drop('Cover_Type', inplace = True, axis = 1) 189 | 190 | df.columns = [name.replace('_', ' ') for name in df.columns ] 191 | X = df 192 | elif dataset_name == 'heart': 193 | dataset_id = DATA_MAP[dataset_name] 194 | uci_dataset = ucimlrepo.fetch_ucirepo(id=dataset_id) 195 | df = uci_dataset.data['original'] 196 | 197 | y = df['diagnosis'] 198 | y = y.to_numpy() 199 | 200 | X = uci_dataset.data['original'].drop(['diagnosis'], axis = 1) 201 | 202 | elif dataset_name == 
'wbc': 203 | dataset_id = DATA_MAP[dataset_name] 204 | uci_dataset = ucimlrepo.fetch_ucirepo(id=dataset_id) 205 | df = uci_dataset.data['original'] 206 | # downsample anomaly to 21 samples 207 | df_anomaly = df[df['Diagnosis'] == 'M'] 208 | df_normal = df[df['Diagnosis'] == 'B'] 209 | df_anomaly = df_anomaly.sample(n=21, random_state = 42) 210 | df = pd.concat([df_anomaly, df_normal], axis = 0, ignore_index=True) 211 | 212 | y = df['Diagnosis'].map({'M':1, 'B':0}) # 142 normal, 6 anomalies 213 | y = y.to_numpy() 214 | df.drop('Diagnosis', inplace = True, axis = 1) 215 | df.drop('ID', inplace = True, axis = 1) 216 | 217 | X = df 218 | elif dataset_name == 'yeast': 219 | # the split is different than the one in the ADbench 220 | dataset_id = DATA_MAP[dataset_name] 221 | uci_dataset = ucimlrepo.fetch_ucirepo(id=dataset_id) 222 | df = uci_dataset.data['original'] 223 | columns = [ s.rstrip('.') for s in uci_dataset['variables']['description'][1:9] ] 224 | 225 | y = df['localization_site'].map({'CYT':0, 'NUC':0, 'MIT':0,'ME3':0, 'ME2':1, 'ME1':1, 'EXC':0, 'VAC':0, 'POX':0, 'ERL':0}) 226 | y = y.to_numpy() 227 | df.drop('localization_site', inplace = True, axis = 1) 228 | df.drop('Sequence_Name', inplace = True, axis = 1) 229 | df.columns = columns 230 | 231 | X = df 232 | 233 | elif dataset_name == 'vifd': 234 | # dataset can be downloaded from https://www.kaggle.com/datasets/khusheekapoor/vehicle-insurance-fraud-detection/data 235 | 236 | df = pd.read_csv( Path(data_dir) / 'vifd'/ 'carclaims.csv') 237 | y = df['FraudFound'].map({"Yes":1, "No":0}) 238 | y = y.to_numpy() 239 | 240 | df.drop('FraudFound', axis = 1, inplace = True) 241 | def split_on_uppercase(s): 242 | return ''.join(' ' + i if i.isupper() else i for i in s).lower().strip() 243 | columns = [ split_on_uppercase(c) for c in df.columns] 244 | 245 | df.columns = columns 246 | X = df 247 | 248 | elif dataset_name == 'arrhythmia': 249 | data_path = Path(data_dir) / 'arrhythmia' / 'arrhythmia.mat' 250 | if not os.path.exists(data_path): 251 | print("Please download the dataset from https://odds.cs.stonybrook.edu/arrhythmia-dataset/ and put it to data/arrhythmia") 252 | raise ValueError('arrhythmia.mat is not found in {}'.format(data_path)) 253 | data = scipy.io.loadmat(data_path) 254 | X_np, y = data['X'], data['y'] 255 | X = convert_np_to_df(X_np) 256 | 257 | elif dataset_name == 'mulcross': 258 | data_path = Path(data_dir) / 'mulcross' / 'mulcross.arff' 259 | if not os.path.exists(data_path): 260 | print("Please download the dataset from https://www.openml.org/search?type=data&sort=runs&id=40897&status=active and put it to data/mulcross") 261 | raise ValueError('mulcross.arff is not found in {}'.format(data_path)) 262 | data, meta = scipy.io.arff.loadarff(data_path) 263 | X = [ [x[i] for i in range(4)] for x in data] 264 | X_np = np.array(X) 265 | y = [ x[4] for x in data] 266 | y = [ 0 if y == b'Normal' else 1 for y in y] 267 | y = np.array(y) 268 | X = convert_np_to_df(X_np) 269 | elif dataset_name == 'seismic': 270 | # downloaded from https://archive.ics.uci.edu/ml/machine-learning-databases/00266/seismic-bumps.arff 271 | data_path = Path(data_dir) / 'seismic' / 'seismic-bumps.arff' 272 | if not os.path.exists(data_path): 273 | print("Please dwnload the dataset from https://archive.ics.uci.edu/ml/machine-learning-databases/00266/seismic-bumps.arff and put it to data/seismic") 274 | raise ValueError('mulcross.arff is not found in {}'.format(data_path)) 275 | data, meta = scipy.io.arff.loadarff(data_path) 276 | df = 
pd.DataFrame(data) 277 | 278 | column_replacement = { 279 | 'seismic': 'result of shift seismic hazard assessment in the mine working obtained by the seismic method', 280 | 'seismoacoustic': 'result of shift seismic hazard assessment in the mine working obtained by the seismoacoustic method', 281 | 'shift': 'information about type of a shift', 282 | 'genergy': 'seismic energy recorded within previous shift by the most active geophone (GMax) out of geophones monitoring the longwall', 283 | 'gpuls': 'a number of pulses recorded within previous shift by GMax', 284 | 'gdenergy': 'a deviation of energy recorded within previous shift by GMax from average energy recorded during eight previous shifts', 285 | 'gdpuls': 'a deviation of a number of pulses recorded within previous shift by GMax from average number of pulses recorded during eight previous shifts', 286 | 'ghazard': 'result of shift seismic hazard assessment in the mine working obtained by the seismoacoustic method based on registration coming from GMax only', 287 | 'nbumps': 'the number of seismic bumps recorded within previous shift', 288 | 'nbumps2': 'the number of seismic bumps (in energy range [10^2,10^3)) registered within previous shift', 289 | 'nbumps3': 'the number of seismic bumps (in energy range [10^3,10^4)) registered within previous shift', 290 | 'nbumps4': 'the number of seismic bumps (in energy range [10^4,10^5)) registered within previous shift', 291 | 'nbumps5': 'the number of seismic bumps (in energy range [10^5,10^6)) registered within the last shift', 292 | 'nbumps6': 'the number of seismic bumps (in energy range [10^6,10^7)) registered within previous shift', 293 | 'nbumps7': 'the number of seismic bumps (in energy range [10^7,10^8)) registered within previous shift', 294 | 'nbumps89': 'the number of seismic bumps (in energy range [10^8,10^10)) registered within previous shift', 295 | 'energy': 'total energy of seismic bumps registered within previous shift', 296 | 'maxenergy': 'the maximum energy of the seismic bumps registered within previous shift', 297 | } 298 | # take log on magnitude columns 299 | df['maxenergy'] = np.log(df['maxenergy'].replace(0, 1e-6)) 300 | df['energy'] = np.log(df['energy'].replace(0, 1e-6)) 301 | # Rename the columns 302 | df.rename(columns=column_replacement, inplace=True) 303 | 304 | # Replace categorical values in the columns 305 | df['result of shift seismic hazard assessment in the mine working obtained by the seismic method'] = df['result of shift seismic hazard assessment in the mine working obtained by the seismic method'].replace({b'a': 'lack of hazard', b'b': 'low hazard', b'c': 'high hazard', b'd': 'danger state'}) 306 | df['result of shift seismic hazard assessment in the mine working obtained by the seismoacoustic method'] = df['result of shift seismic hazard assessment in the mine working obtained by the seismoacoustic method'].replace({b'a': 'lack of hazard', b'b': 'low hazard', b'c': 'high hazard', b'd': 'danger state'}) 307 | df['result of shift seismic hazard assessment in the mine working obtained by the seismoacoustic method based on registration coming from GMax only'] = \ 308 | df['result of shift seismic hazard assessment in the mine working obtained by the seismoacoustic method based on registration coming from GMax only'].replace({b'a': 'lack of hazard', b'b': 'low hazard', b'c': 'high hazard', b'd': 'danger state'}) 309 | df['information about type of a shift'] = df['information about type of a shift'].replace({'W': 'coal-getting', 'N': 'preparation shift'}) 310 
| 311 | y = df['class'].map({b'0':0,b'1':1}) 312 | y = y.to_numpy() 313 | 314 | df.drop('class', inplace = True, axis = 1) 315 | X = df 316 | 317 | elif dataset_name == 'fraudecom': 318 | # data downloaded from https://www.kaggle.com/datasets/vbinh002/fraud-ecommerce/data 319 | # add one index for device id that only appears once 320 | # preprocessing code adapted from https://www.kaggle.com/code/pa4494/catch-the-bad-guys-with-feature-engineering 321 | # remove device id 322 | import calendar 323 | 324 | data_path = Path(data_dir) / 'fraudecom' 325 | dataset = pd.read_csv(data_path / "Fraud_Data.csv") # Users information 326 | IP_table = pd.read_csv(data_path / "IpAddress_to_Country.csv") # Country from IP in 327 | 328 | IP_table.upper_bound_ip_address.astype("float") 329 | IP_table.lower_bound_ip_address.astype("float") 330 | dataset.ip_address.astype("float") 331 | 332 | # function that takes an IP address as argument and returns country associated based on IP_table 333 | 334 | def IP_to_country(ip) : 335 | try : 336 | return IP_table.country[(IP_table.lower_bound_ip_address < ip) 337 | & 338 | (IP_table.upper_bound_ip_address > ip)].iloc[0] 339 | except IndexError : 340 | return "Unknown" 341 | 342 | # To affect a country to each IP : 343 | dataset["IP_country"] = dataset.ip_address.apply(IP_to_country) 344 | # We convert signup_time and purchase_time en datetime 345 | #dataset = pd.read_csv(data_path / "Fraud_data_with_country.csv") 346 | dataset.signup_time = pd.to_datetime(dataset.signup_time, format = '%Y-%m-%d %H:%M:%S') 347 | dataset.purchase_time = pd.to_datetime(dataset.purchase_time, format = '%Y-%m-%d %H:%M:%S') 348 | 349 | # --- 2 --- 350 | # Column month 351 | dataset["month_purchase"] = dataset.purchase_time.apply(lambda x: calendar.month_name[x.month]) 352 | 353 | # --- 3 --- 354 | # Column week 355 | dataset["weekday_purchase"] = dataset.purchase_time.apply(lambda x: calendar.day_name[x.weekday()]) 356 | # --- 4 --- 357 | # map the device id that appears only once to 0 358 | device_duplicates = pd.DataFrame(dataset.groupby(by = "device_id").device_id.count()) # at this moment, index column name and first column name both are equal to "device_id" 359 | device_duplicates.rename(columns={"device_id": "freq_device"}, inplace=True) # hence we need to replace the "device_id" column name 360 | device_duplicates.reset_index(level=0, inplace= True) # and then we turn device_id from index to column 361 | 362 | dataset = dataset.merge(device_duplicates, on= "device_id") 363 | indices = dataset[dataset.freq_device == 1].index 364 | dataset.loc[indices, "device_id"]= "0" 365 | 366 | le = LabelEncoder() 367 | dataset['device_id'] = le.fit_transform(dataset['device_id']).astype('object') 368 | for column in ['user_id', 'signup_time', 'purchase_time', 'ip_address', 'freq_device']: 369 | dataset.drop(column, axis=1, inplace = True) 370 | 371 | dataset.columns = [name.replace('_', ' ') for name in dataset.columns ] 372 | y = dataset['class'].to_numpy() 373 | X = dataset.drop("class", axis = 1) 374 | X = dataset.drop("device id", axis = 1) 375 | 376 | elif dataset_name == 'fakejob': 377 | # data download link: https://www.kaggle.com/datasets/shivamb/real-or-fake-fake-jobposting-prediction?select=fake_job_postings.csv 378 | df = pd.read_csv( Path(data_dir) / 'fakejob'/ 'fake_job_postings.csv') 379 | 380 | # deal with Nan values 381 | df['location'].fillna('Unknown', inplace=True) 382 | df['department'].fillna('Unknown', inplace=True) 383 | df['salary_range'].fillna('Not Specified', 
inplace=True) 384 | df['employment_type'].fillna('Not Specified', inplace=True) 385 | df['required_experience'].fillna('Not Specified', inplace=True) 386 | df['required_education'].fillna('Not Specified', inplace=True) 387 | df['industry'].fillna('Not Specified', inplace=True) 388 | df['function'].fillna('Not Specified', inplace=True) 389 | df.drop('job_id', inplace=True, axis=1) 390 | 391 | text_columns = ['title', 'company_profile', 'description', 'requirements', 'benefits'] 392 | df[text_columns] = df[text_columns].fillna('NaN') 393 | 394 | y = df['fraudulent'].to_numpy() 395 | X = df.drop('fraudulent', axis=1) 396 | X.columns = [name.replace('_', ' ') for name in X.columns ] 397 | 398 | 399 | elif dataset_name.startswith('20news-'): 400 | def data_generator(subsample=None, target_label=None): 401 | dataset = fetch_20newsgroups(subset='train') 402 | groups = [['comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x'], 403 | ['rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey'], 404 | ['sci.crypt', 'sci.electronics', 'sci.med', 'sci.space'], 405 | ['misc.forsale'], 406 | ['talk.politics.misc', 'talk.politics.guns', 'talk.politics.mideast'], 407 | ['talk.religion.misc', 'alt.atheism', 'soc.religion.christian']] 408 | 409 | def flatten(l): 410 | return [item for sublist in l for item in sublist] 411 | label_list = dataset['target_names'] 412 | label = [] 413 | for _ in dataset['target']: 414 | _ = label_list[_] 415 | if _ not in flatten(groups): 416 | raise NotImplementedError 417 | 418 | for i, g in enumerate(groups): 419 | if _ in g: 420 | label.append(i) 421 | break 422 | label = np.array(label) 423 | print("Number of labels", len(label)) 424 | idx_n = np.where(label==target_label)[0] 425 | idx_a = np.where(label!=target_label)[0] 426 | label[idx_n] = 0 427 | label[idx_a] = 1 428 | # subsample 429 | if int(subsample * 0.95) > sum(label == 0): 430 | pts_n = sum(label == 0) 431 | pts_a = int(0.05 * pts_n / 0.95) 432 | else: 433 | pts_n = int(subsample * 0.95) 434 | pts_a = int(subsample * 0.05) 435 | 436 | idx_n = np.random.choice(idx_n, pts_n, replace=False) 437 | idx_a = np.random.choice(idx_a, pts_a, replace=False) 438 | idx = np.append(idx_n, idx_a) 439 | np.random.shuffle(idx) 440 | 441 | text = [dataset['data'][i] for i in idx] 442 | label = label[idx] 443 | del dataset 444 | 445 | text = [_.strip().replace('
', '') for _ in text] 446 | 447 | print(f'number of normal samples: {sum(label==0)}, number of anomalies: {sum(label==1)}') 448 | 449 | return text, label 450 | target_label = int(dataset_name.split('-')[1]) 451 | text, label = data_generator(subsample=10000, target_label=target_label) 452 | y = label 453 | X = pd.DataFrame(data = text, columns = ['text']) 454 | 455 | elif dataset_name in DATA_MAP.keys(): 456 | # datasets from ADBench 457 | dataset_root = Path(adbench.__file__).parent.absolute() / "datasets/Classical" 458 | n = DATA_MAP[dataset_name] 459 | for npz_file in os.listdir(dataset_root): 460 | if npz_file.startswith(str(n) + '_'): 461 | print(dataset_name, npz_file) 462 | data = np.load(dataset_root / npz_file, allow_pickle=False) 463 | break 464 | else: 465 | ValueError('{} is not found.'.format(dataset_name)) 466 | X_np, y = data['X'], data['y'] 467 | X = convert_np_to_df(X_np) 468 | else: 469 | raise ValueError('Invalid dataset name {}'.format(dataset_name)) 470 | 471 | assert len(X) == len(y) 472 | 473 | with open(pkl_file, 'wb') as f: 474 | pickle.dump((X,y), f) 475 | 476 | return X, y 477 | 478 | def load_adbench_data(dataset): 479 | dataset_root = Path(adbench.__file__).parent.absolute() / "datasets/Classical" 480 | if not os.path.exists(dataset_root): 481 | from adbench.myutils import Utils 482 | Utils().download_datasets(repo='jihulab') 483 | 484 | if dataset == 'cardio': 485 | return np.load(dataset_root / '6_cardio.npz', allow_pickle=False) 486 | 487 | for npz_file in os.listdir(dataset_root): 488 | if dataset in npz_file.lower(): 489 | return np.load(dataset_root / npz_file, allow_pickle=False) 490 | else: 491 | ValueError('{} is not found.'.format(dataset)) 492 | 493 | def split_data( 494 | X: pd.DataFrame, 495 | dataset_name: str, 496 | n_splits: int, 497 | data_dir: str, 498 | train_ratio: Optional[float] = 0.5, 499 | y: Optional[np.ndarray] = None, # should be provided in semi-supervised settinig 500 | seed: Optional[int] = 42, 501 | setting: Optional[str] = 'semi_supervised' 502 | ) -> tuple: # list of train indices and test indices 503 | np.random.seed(seed) 504 | #save path 505 | split_dir = Path(data_dir) / dataset_name / setting / 'split{}'.format(n_splits) 506 | os.makedirs(split_dir, exist_ok = True) 507 | 508 | train_indices, test_indices = [], [] 509 | for i in range(n_splits): 510 | pkl_file = split_dir / 'index{}.pkl'.format(i) 511 | if os.path.exists(pkl_file): 512 | with open(pkl_file, 'rb') as f: 513 | train_index, test_index = pickle.load(f) 514 | else: 515 | if setting == 'unsupervised': 516 | normal_data_indices = np.where(y==0)[0] 517 | anormal_data_indices = np.where(y==1)[0] 518 | normal_index = np.random.permutation(normal_data_indices) 519 | anormal_index = np.random.permutation(anormal_data_indices) 520 | 521 | train_index = np.concatenate([normal_index[:int(train_ratio * len(normal_index))], anormal_index[:int(train_ratio * len(anormal_index))]]) 522 | test_index = np.concatenate([normal_index[int(train_ratio * len(normal_index)):], anormal_index[int(train_ratio * len(anormal_index)):]]) 523 | elif setting == 'semi_supervised': 524 | normal_data_indices = np.where(y==0)[0] 525 | anormal_data_indices = np.where(y==1)[0] 526 | data_length = len(normal_data_indices) 527 | index = np.random.permutation(normal_data_indices) 528 | 529 | train_index = index[:int(train_ratio * data_length)] 530 | test_index = index[int(train_ratio * data_length):] 531 | test_index = np.concatenate([test_index, anormal_data_indices]) 532 | else: 533 | raise 
ValueError('Invalid setting. Choose either unsupervised or semi_supervised') 534 | train_index = np.random.permutation(train_index) 535 | test_index = np.random.permutation(test_index) 536 | with open(pkl_file, 'wb') as f: 537 | pickle.dump((train_index, test_index), f) 538 | train_indices.append(train_index) 539 | test_indices.append(test_index) 540 | return train_indices, test_indices 541 | 542 | def convert_np_to_df(X_np): 543 | n_train, n_cols = X_np.shape 544 | # Add missing column names 545 | L = list(string.ascii_uppercase) + [letter1+letter2 for letter1 in string.ascii_uppercase for letter2 in string.ascii_uppercase] 546 | columns = [ L[i] for i in range(n_cols) ] 547 | df = pd.DataFrame(data = X_np, columns = columns) 548 | return df 549 | 550 | def load_data(args): 551 | dataset_dir = Path(args.data_dir) / args.dataset 552 | X, y = load_dataset(args.dataset, args.data_dir) 553 | 554 | if 'binning' in args and args.binning != 'none': 555 | X = normalize(X, args.binning, args.n_buckets) 556 | if 'remove_feature_name' in args and args.remove_feature_name: 557 | print("Removing column names and category names.") 558 | L = list(ascii_uppercase) + [letter1+letter2 for letter1 in ascii_uppercase for letter2 in ascii_uppercase] 559 | X.columns = [ L[i] for i in range(len(X.columns))] 560 | 561 | categorical_data = X.select_dtypes(include = ['object']) 562 | categorical_columns = categorical_data.columns.tolist() 563 | le = LabelEncoder() 564 | for i in categorical_data.columns: 565 | categorical_data[i] = le.fit_transform(categorical_data[i]) 566 | 567 | X_prime = X.drop(categorical_columns, axis = 1) 568 | X = pd.concat([X_prime, categorical_data], axis = 1) 569 | 570 | if 'train_ratio' not in args: 571 | args.train_ratio = 0.5 572 | if 'seed' not in args: 573 | args.seed = 42 574 | train_indices, test_indices = split_data(X, args.dataset, args.n_splits, args.data_dir, 575 | args.train_ratio, y = y, seed = args.seed, setting = args.setting ) 576 | train_index, test_index = train_indices[args.split_idx], test_indices[args.split_idx] 577 | X_train, X_test = X.loc[train_index], X.loc[test_index] 578 | y_train, y_test = y[train_index], y[test_index] 579 | 580 | return X_train, X_test, y_train, y_test 581 | 582 | def normalize(X, method, n_buckets): 583 | # method: ['quantile', 'equal_width', 'language', 'none', 'standard'] 584 | # n_buckets: 0-100 585 | X = copy.deepcopy(X) 586 | def ordinal(n): 587 | if np.isnan(n): 588 | return 'NaN' 589 | n = int(n) 590 | if 10 <= n % 100 <= 20: 591 | suffix = 'th' 592 | else: 593 | suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th') 594 | return 'the ' + str(n) + suffix + ' percentile' 595 | 596 | word_list = ['Minimal', 'Slight', 'Moderate', 'Noticeable', 'Considerable', 'Significant', 'Substantial', 'Major', 'Extensive', 'Maximum'] 597 | def get_word(n): 598 | n = int(n) 599 | if n == 10: 600 | return word_list[-1] 601 | return word_list[n] 602 | 603 | if method == 'quantile': 604 | for column in X.columns: 605 | if X[column].dtype in ['float64', 'int64', 'uint8', 'int16'] and X[column].nunique() > 1: 606 | ranks = X[column].rank(method='min') 607 | X[column] = ranks / len(X[column]) * 100 608 | X[column] = X[column].apply(ordinal) 609 | 610 | elif method == 'equal_width': 611 | for column in X.columns: 612 | if X[column].dtype in ['float64', 'int64', 'uint8', 'int16']: 613 | if X[column].nunique() > 1: 614 | X[column] = X[column].astype('float64') 615 | X[column] = (X[column] - X[column].min()) / (X[column].max() - X[column].min()) * n_buckets 616 
| 617 | if 10 % n_buckets == 0: 618 | X[column] = X[column].round(0) / 10 619 | X[column] = X[column].round(1) 620 | else: 621 | X[column] = X[column].round(0) / 100 622 | X[column] = X[column].round(2) 623 | elif method == 'standard': 624 | for column in X.columns: 625 | if X[column].dtype in ['float64', 'int64', 'uint8', 'int16']: 626 | scaler = StandardScaler() 627 | scaler.fit(X[column].values.reshape(-1,1)) 628 | X[column] = scaler.transform(X[column].values.reshape(-1,1)) 629 | X[column] = X[column].round(1) 630 | 631 | elif method == 'language': 632 | for column in X.columns: 633 | if X[column].dtype in ['float64', 'int64', 'uint8', 'int16'] and X[column].nunique() > 1: 634 | X[column] = X[column].astype('float64') 635 | X[column] = (X[column] - X[column].min()) / (X[column].max() - X[column].min()) * 10 636 | X[column] = X[column].apply(get_word) 637 | else: 638 | raise ValueError('Invalid method. Choose either percentile, language or decimal') 639 | return X 640 | 641 | def get_text_columns(dataset_name): 642 | text_columns = [] 643 | if dataset_name == 'fakejob': 644 | text_columns = ['title', 'company profile', 'description', 'requirements', 'benefits'] 645 | elif 'fakenews' == dataset_name: 646 | text_columns = ['title', 'text'] 647 | elif '20news' in dataset_name: 648 | text_columns = ['text'] 649 | return text_columns 650 | 651 | def get_max_length_dict(dataset_name): 652 | max_length_dict = {} 653 | if dataset_name == 'fakejob': 654 | max_length_dict['title'] = 20 655 | text_columns = ['company profile', 'description', 'requirements', 'benefits'] 656 | for col in text_columns: 657 | max_length_dict[col] = 700 658 | elif 'fakenews' == dataset_name: 659 | max_length_dict['title'] = 30 660 | max_length_dict['text'] = 500 661 | elif '20news' in dataset_name: 662 | max_length_dict['text'] = 1000 663 | return max_length_dict 664 | 665 | def df_to_numpy( 666 | X: pd.DataFrame, 667 | dataset_name: Optional[str] = None, 668 | method: Optional[str] = 'ordinal', 669 | normalize_numbers: Optional[bool] = False, 670 | verbose: Optional[bool] = False, 671 | textual_encoding: Optional[str] = 'word2vec', # bag_of_words, tfidf, word2vec, or none 672 | textual_columns: Optional[list] = None 673 | ) -> np.ndarray: 674 | if dataset_name == 'ecoli': 675 | X_np = X.drop(X.columns[0], axis=1).to_numpy() 676 | return X_np 677 | 678 | numeric_data = X.select_dtypes(include = ['float64', 'int64', 'uint8', 'int16', 'float32']) 679 | numeric_columns = numeric_data.columns.tolist() 680 | categorical_data = X.select_dtypes(include = ['object', 'category']) 681 | categorical_columns = categorical_data.columns.tolist() 682 | 683 | if verbose: 684 | print("Number of categorical data", len(categorical_columns)) 685 | print("Categorical columns:", categorical_columns) 686 | 687 | # fill na 688 | if len(numeric_columns) > 0: 689 | for numeric_col in numeric_columns: 690 | X[numeric_col] = X[numeric_col].fillna(X[numeric_col].mean()) 691 | 692 | if normalize_numbers: 693 | # normalize it to have zero mean and unit variance 694 | scaler = StandardScaler() 695 | X[numeric_columns] = scaler.fit_transform(X[numeric_columns]) 696 | 697 | # Handle textual data 698 | if textual_encoding == 'none' and len(textual_columns) > 0: 699 | for col in textual_columns: 700 | categorical_columns.remove(col) 701 | X = X.drop(columns = textual_columns) 702 | textual_columns = [] 703 | 704 | if len(textual_columns) > 0: 705 | if textual_encoding == 'word2vec': 706 | model = api.load('word2vec-google-news-300') 707 | tmp = 
X[textual_columns].agg(' '.join, axis=1) 708 | X_vecs = [] 709 | for i in range(len(X)): 710 | words = [] 711 | for word in tmp[i].split(): 712 | if word in model.key_to_index: 713 | words.append(word) 714 | # Compute the average word embedding 715 | if words: # Ensure there are valid words left 716 | word_vectors = [model[word] for word in words] 717 | X_vec = np.mean(word_vectors, axis=0) 718 | else: 719 | X_vec = np.zeros(model.vector_size) # Handle the case where no words are in the vocabulary 720 | X_vecs.append(X_vec) 721 | X_vecs = np.array(X_vecs) 722 | for col in textual_columns: 723 | categorical_columns.remove(col) 724 | 725 | elif textual_encoding == 'bag_of_words': 726 | corpus = [] 727 | for col in textual_columns: 728 | for i in range(len(X)): 729 | corpus.append(X[col][i]) 730 | vectorization = CountVectorizer(max_features = 300) 731 | vectorization.fit(corpus) 732 | tmp = X[textual_columns].agg(' '.join, axis=1) 733 | X_vecs = vectorization.transform(tmp).todense() 734 | 735 | for col in textual_columns: 736 | categorical_columns.remove(col) 737 | 738 | elif textual_encoding == 'tfidf': 739 | corpus = [] 740 | for col in textual_columns: 741 | for i in range(len(X)): 742 | corpus.append(X[col][i]) 743 | vectorization = TfidfVectorizer(max_features = 300) 744 | vectorization.fit(corpus) 745 | tmp = X[textual_columns].agg(' '.join, axis=1) 746 | X_vecs = vectorization.transform(tmp).todense() 747 | 748 | for col in textual_columns: 749 | categorical_columns.remove(col) 750 | 751 | else: 752 | raise ValueError('Invalid textual encoding. Choose either bag_of_words, tf-idf or word2vec') 753 | X = X.drop(columns = textual_columns) 754 | X = pd.concat([X, pd.DataFrame(X_vecs)], axis = 1) 755 | 756 | 757 | if len(categorical_columns) > 0: 758 | # categorical features: 759 | # group categories with low frequency into a single category 760 | encoder = RareLabelEncoder( 761 | tol=0.01, # Minimum frequency to be considered as a separate class 762 | max_n_categories=None, # Maximum number of categories to keep 763 | replace_with='Rare', # Value to replace rare categories with 764 | variables=categorical_columns , # Columns to encode 765 | missing_values='ignore', 766 | ) 767 | X = encoder.fit_transform(X) 768 | 769 | # Remove columns that contain identical values 770 | X = X.loc[:, (X != X.iloc[0]).any()] 771 | 772 | # remove categories that have only one value 773 | for column in categorical_columns: 774 | if X[column].nunique() == 1: 775 | X.drop(column, inplace = True, axis = 1) 776 | 777 | if method == 'ordinal': 778 | le = LabelEncoder() 779 | for i in categorical_data.columns: 780 | categorical_data[i] = le.fit_transform(categorical_data[i]) 781 | elif method == 'one_hot': 782 | enc = OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop='first') 783 | one_hot_encoded = enc.fit_transform(X[categorical_columns]) 784 | categorical_data = pd.DataFrame(one_hot_encoded, columns=enc.get_feature_names_out(categorical_columns)) 785 | else: 786 | raise ValueError('Invalid method. 
Choose either ordinal or one_hot') 787 | X_prime = X.drop(categorical_columns, axis = 1) 788 | X = pd.concat([X_prime, categorical_data], axis = 1) 789 | # remove columns that contain identical values 790 | print(X.shape) 791 | X = X.loc[:, (X != X.iloc[0]).any()] 792 | X_np = X.to_numpy() 793 | return X_np 794 | 795 | def print_dataset_information(dataset, data_dir): 796 | print("-"*100) 797 | print("Dataset: {}".format(dataset)) 798 | X, y = load_dataset(dataset, data_dir) 799 | #print(X['company profile'][:3]) 800 | print(X.columns) 801 | train_indices, test_indices = split_data(X, dataset, 5, data_dir, 802 | 0.5, y = y, seed = 42, setting = 'semi_supervised' ) 803 | print("Dtypes of columns:", X.dtypes) 804 | X_np = df_to_numpy(X, dataset_name = dataset, method = 'one_hot', verbose = True, textual_encoding='word2vec') 805 | print("Number of training samples:", len(train_indices[0])) 806 | print("Number of testing samples:", len(test_indices[0])) 807 | print("Number of anomalies: {:f} ({:.2f}%)".format(np.sum(y), np.sum(y)/len(y) * 100)) 808 | print("Number of features:", len(X.columns)) 809 | print("Number of feature dimensions", X_np.shape[1]) 810 | 811 | def filter_anomalies(X_test, y_test): 812 | X_test = X_test[y_test == 0] 813 | y_test = y_test[y_test == 0] 814 | return X_test, y_test 815 | 816 | 817 | if __name__ == '__main__': 818 | #print_dataset_information('fakejob', 'data') 819 | print_dataset_information('20news-6', 'data') 820 | exit() -------------------------------------------------------------------------------- /src/get_avg_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from data_utils import DATA_MAP, MIXED, ODDS 7 | from get_results import get_metrics, aggregate_results, filter_results 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--dataset", type = str, default='all', choices = ['all', 'mixed', 'odds']) # subset: 10 datasets that contain feature names 12 | parser.add_argument("--setting", type = str, default='semi_supervised', choices = ['semi_supervised', 'unsupervised']) 13 | parser.add_argument("--exp_dir", type = str, default=None) 14 | parser.add_argument("--metric", type = str, choices =["AUC-ROC", "F1", "AUC-PR"], default='AUC-ROC') 15 | #dataset hyperparameters 16 | parser.add_argument("--data_dir", type = str, default='data') 17 | parser.add_argument("--n_splits", type = int, default=5) 18 | parser.add_argument("--only_normalized", action='store_true', default=False) 19 | parser.add_argument("--only_ordinal", action='store_true', default=False) 20 | args = parser.parse_args() 21 | 22 | return args 23 | 24 | def main(): 25 | args = get_args() 26 | roc_scores = {} 27 | ranking_dict = {} 28 | std_scores = {} 29 | if args.dataset == 'all': 30 | DATASETS = [k for k in DATA_MAP.keys()] 31 | elif args.dataset == 'odds': 32 | DATASETS = ODDS 33 | elif args.dataset == 'mixed': 34 | DATASETS = MIXED 35 | # sorted datasets by alphabetic order 36 | DATASETS = sorted(DATASETS) 37 | all_rocs = {} 38 | for dataset_idx, dataset in enumerate(DATASETS): 39 | try: 40 | print("*"*100) 41 | print(dataset) 42 | args.split_idx = None 43 | args.dataset = dataset 44 | L = [] 45 | for i in range(args.n_splits): 46 | args.split_idx = i 47 | args.exp_dir = None 48 | results = get_metrics(args, only_normalized=args.only_normalized, only_ordinal=args.only_ordinal) 49 | L.append(results) 50 | metrics, rankings = 
aggregate_results(L) 51 | except: 52 | print("Error in dataset: ", dataset) 53 | continue 54 | 55 | metrics = filter_results(metrics) 56 | for k in metrics.keys(): 57 | 58 | if k not in all_rocs: 59 | all_rocs[k] = np.zeros((len(DATASETS), args.n_splits)) 60 | 61 | for i in range(args.n_splits): 62 | all_rocs[k][dataset_idx] = metrics[k][args.metric] 63 | 64 | roc_scores[dataset] = { k: np.mean(metrics[k][args.metric]) for k in metrics.keys()} 65 | ranking_dict[dataset] = {k: int(rankings[args.metric][idx]) for idx, k in enumerate(metrics.keys())} 66 | std_scores[dataset] = { k: np.std(metrics[k][args.metric]) for k in metrics.keys()} 67 | 68 | df = pd.DataFrame(roc_scores).T 69 | avg_row = df.mean(axis=0) 70 | df.loc['avg'] = avg_row 71 | df = df.round(3) 72 | print(df) 73 | df.to_csv('exp/{}_avg_{}.csv'.format(args.setting, args.metric)) 74 | 75 | ''' 76 | ranking_df = pd.DataFrame(ranking_dict).T 77 | avg_row = ranking_df.mean(axis=0) 78 | ranking_df.loc['avg'] = avg_row 79 | ranking_df = ranking_df.round(3) 80 | print(ranking_df) 81 | ranking_df.to_csv('exp/{}_avg_ranking.csv'.format(args.setting)) 82 | ''' 83 | # std 84 | std_df = pd.DataFrame(std_scores).T 85 | avg_std = [] 86 | for c in std_df.columns: 87 | if c not in all_rocs: 88 | avg_std.append(0) 89 | continue 90 | std = np.std( np.mean(all_rocs[c], axis=0)) 91 | avg_std.append(std) 92 | std_df.loc['avg'] = avg_std 93 | std_df = std_df.round(3) 94 | print(std_df) 95 | std_df.to_csv('exp/{}_std_{}.csv'.format(args.setting,args.metric)) 96 | 97 | if __name__ == '__main__': 98 | main() 99 | -------------------------------------------------------------------------------- /src/get_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import argparse 4 | 5 | from sklearn import metrics 6 | import numpy as np 7 | import pandas as pd 8 | from data_utils import load_data, DATA_MAP 9 | 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--dataset", type = str, default='wine', choices = [d.lower() for d in DATA_MAP.keys()], 14 | help="Name of datasets in the ODDS benchmark") 15 | parser.add_argument("--exp_dir", type = str, default=None) 16 | parser.add_argument("--setting", type = str, default='semi_supervised', choices = ['semi_supervised', 'unsupervised']) 17 | 18 | #dataset hyperparameters 19 | parser.add_argument("--data_dir", type = str, default='data') 20 | parser.add_argument("--n_splits", type = int, default=5) 21 | parser.add_argument("--split_idx", type = int, default=None) # 0 to n_split-1 22 | 23 | args = parser.parse_args() 24 | 25 | return args 26 | 27 | def tabular_metrics(y_true, y_score): 28 | """ 29 | Calculates evaluation metrics for tabular anomaly detection. 30 | Adapted from https://github.com/xuhongzuo/DeepOD/blob/main/deepod/metrics/_anomaly_detection.py 31 | Args: 32 | 33 | y_true (np.array, required): 34 | Data label, 0 indicates normal timestamp, and 1 is anomaly. 35 | 36 | y_score (np.array, required): 37 | Predicted anomaly scores, higher score indicates higher likelihoods to be anomaly. 38 | 39 | Returns: 40 | tuple: A tuple containing: 41 | 42 | - auc_roc (float): 43 | The score of area under the ROC curve. 44 | 45 | - auc_pr (float): 46 | The score of area under the precision-recall curve. 47 | 48 | - f1 (float): 49 | The score of F1-score. 50 | 51 | - precision (float): 52 | The score of precision. 53 | 54 | - recall (float): 55 | The score of recall. 
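Example (made-up scores, shown only to illustrate the return order; here every anomaly outranks every normal row, so auc_roc comes out as 1.0):

    y_true = np.array([0, 0, 0, 1, 1])              # 1 marks an anomaly
    y_score = np.array([0.1, 0.2, 0.15, 0.9, 0.4])  # higher = more anomalous
    auc_roc, auc_pr, f1, p, r = tabular_metrics(y_true, y_score)
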
56 | 57 | """ 58 | # F1@k, using real percentage to calculate F1-score 59 | n_test = len(y_true) 60 | new_index = np.random.permutation(n_test) # shuffle y to prevent bias of ordering (argpartition may discard entries with same value) 61 | y_true = y_true[new_index] 62 | y_score = y_score[new_index] 63 | 64 | #ratio = 100.0 * len(np.where(y_true == 0)[0]) / len(y_true) 65 | #thresh = np.percentile(y_score, ratio) 66 | #y_pred = (y_score >= thresh).astype(int) 67 | 68 | top_k = len(np.where(y_true == 1)[0]) 69 | indices = np.argpartition(y_score, -top_k)[-top_k:] 70 | y_pred = np.zeros_like(y_true) 71 | y_pred[indices] = 1 72 | 73 | y_true = y_true.astype(int) 74 | p, r, f1, support = metrics.precision_recall_fscore_support(y_true, y_pred, average='binary') 75 | 76 | return metrics.roc_auc_score(y_true, y_score), metrics.average_precision_score(y_true, y_score), f1, p, r 77 | 78 | def get_metrics(args, only_raw = False, only_normalized = False, only_ordinal = False): 79 | X_train, X_test, y_train, y_test = load_data(args) 80 | if isinstance(y_test, pd.Series): 81 | y_test = np.array(y_test) 82 | #y_test = y_test.to_numpy() 83 | if args.exp_dir is None: 84 | args.exp_dir = Path('exp') / args.dataset / args.setting / "split{}".format(args.n_splits) / "split{}".format(args.split_idx) 85 | score_dir = args.exp_dir / 'scores' 86 | if not os.path.exists(score_dir): 87 | raise ValueError("Score directory {} does not exist".format(score_dir)) 88 | 89 | 90 | method_dict = {} 91 | for score_npy in os.listdir(score_dir): 92 | if '.npy' in score_npy: 93 | if score_npy.startswith('raw'): 94 | continue 95 | if is_baseline(score_npy) and only_normalized: 96 | if 'normalized' not in score_npy: 97 | continue 98 | 99 | elif is_baseline(score_npy) and only_ordinal: 100 | if 'ordinal' not in score_npy: 101 | continue 102 | 103 | 104 | method = '.'.join(score_npy.split('.')[:-1]) 105 | if method == 'rdp': 106 | continue 107 | scores = np.load(score_dir / score_npy) 108 | if np.isnan(scores).any(): 109 | print("NaNs in scores for {}".format(method)) 110 | method_dict[method] = [0, 0, 0, 0, 0] 111 | elif np.isinf(scores).any(): 112 | print("Infs in scores for {}".format(method)) 113 | method_dict[method] = [0, 0, 0, 0, 0] 114 | else: 115 | auc_roc, auc_pr, f1, p, r = tabular_metrics(y_test, scores) 116 | method_dict[method] = [auc_roc, auc_pr, f1, p, r] 117 | # get ranking info for all methods 118 | rankings = [] 119 | method = list(method_dict.keys())[0] 120 | for i in range(len(method_dict[method])): 121 | scores = [-method_dict[k][i] for k in method_dict.keys()] 122 | ranking = np.argsort(scores).argsort() + 1 123 | rankings.append(ranking) 124 | 125 | print("-"*100) 126 | for idx, (k, v) in enumerate(method_dict.items()): 127 | #print("{:30s}: AUC-ROC: {:.4f}, AUC-PR: {:.4f}, F1: {:.4f}".format(k, v[0], v[1], v[2])) 128 | print("{:30s}: AUC-ROC: {:.4f} ({:2d}), AUC-PR: {:.4f} ({:2d}), F1: {:.4f} ({:2d}), P: {:.4f} ({:2d}), R: {:.4f} ({:2d})".format(k, 129 | v[0], rankings[0][idx], 130 | v[1], rankings[1][idx], 131 | v[2], rankings[2][idx], 132 | v[3], rankings[3][idx], 133 | v[4], rankings[4][idx], 134 | )) 135 | 136 | return method_dict 137 | def is_baseline(s): 138 | if 'anollm' in s: 139 | return False 140 | return True 141 | 142 | def filter_results(d:dict): 143 | d2 = {} 144 | for k in d.keys(): 145 | new_key = k 146 | if is_baseline(k): 147 | #baselines 148 | d2[k] = d[k] 149 | else: 150 | if '_lora' in k: 151 | temp = k.replace('_lora', '') 152 | if temp in d: 153 | continue 154 | else: 155 | new_key = 
new_key.replace('_lora', '') 156 | d2[new_key] = d[k] 157 | else: 158 | d2[new_key] = d[k] 159 | return d2 160 | 161 | def aggregate_results(m_dicts): 162 | aggregate_results = {k: {'AUC-ROC':[], 'AUC-PR': [], 'F1': [], 'P': [], 'R':[]} for k in m_dicts[0].keys()} 163 | for i in range(len(m_dicts)): 164 | all_keys = list(m_dicts[0].keys()) 165 | for k in all_keys: 166 | try: 167 | aggregate_results[k]['AUC-ROC'] += [m_dicts[i][k][0]] 168 | aggregate_results[k]['AUC-PR'] += [m_dicts[i][k][1]] 169 | aggregate_results[k]['F1'] += [m_dicts[i][k][2]] 170 | aggregate_results[k]['P'] += [m_dicts[i][k][3]] 171 | aggregate_results[k]['R'] += [m_dicts[i][k][4]] 172 | except: 173 | print("Incomplete results for ", k) 174 | if k in aggregate_results: 175 | del aggregate_results[k] 176 | for i in range(len(m_dicts)): 177 | if k in m_dicts[i]: 178 | del m_dicts[i][k] 179 | continue 180 | 181 | print("-"*100) 182 | 183 | # get ranking info for all methods 184 | rankings = {} 185 | key = list(m_dicts[0].keys())[0] 186 | for metric_name in aggregate_results[key].keys(): 187 | scores = [-np.mean(aggregate_results[k][metric_name]) for k in aggregate_results.keys()] 188 | ranking = np.argsort(scores).argsort() + 1 189 | rankings[metric_name] = ranking 190 | 191 | for idx, k in enumerate(aggregate_results.keys()): 192 | print("{:30s}: AUC-ROC: {:.4f} +- {:.4f} ({:2d}), AUC-PR: {:.4f} +- {:.4f} ({:2d}), F1: {:.4f} +- {:.4f} ({:2d}) P: {:.4f} +- {:.4f} ({:2d}) R: {:.4f} +- {:.4f} ({:2d})".format(k, 193 | np.mean(aggregate_results[k]['AUC-ROC']), np.std(aggregate_results[k]['AUC-ROC']), rankings['AUC-ROC'][idx], 194 | np.mean(aggregate_results[k]['AUC-PR']), np.std(aggregate_results[k]['AUC-PR']), rankings['AUC-PR'][idx], 195 | np.mean(aggregate_results[k]['F1']), np.std(aggregate_results[k]['F1']), rankings['F1'][idx], 196 | np.mean(aggregate_results[k]['P']), np.std(aggregate_results[k]['P']), rankings['P'][idx], 197 | np.mean(aggregate_results[k]['R']), np.std(aggregate_results[k]['R']), rankings['R'][idx], 198 | )) 199 | return aggregate_results, rankings 200 | 201 | def main(): 202 | args = get_args() 203 | if args.split_idx is None: 204 | L = [] 205 | for i in range(args.n_splits): 206 | args.split_idx = i 207 | args.exp_dir = None 208 | results = get_metrics(args) 209 | L.append(results) 210 | aggregate_results(L) 211 | else: 212 | print(args) 213 | scores = get_metrics(args) 214 | 215 | 216 | if __name__ == '__main__': 217 | main() -------------------------------------------------------------------------------- /train_anollm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import argparse 6 | import torch.distributed as dist 7 | import torch 8 | import wandb 9 | import time 10 | 11 | from anollm import AnoLLM 12 | from src.data_utils import load_data, DATA_MAP, get_text_columns, get_max_length_dict 13 | 14 | #run by torchrun --nproc_per_node=8 train_llm.py 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--dataset", type = str, default='wine', choices = [d.lower() for d in DATA_MAP.keys()], 19 | help="Name of datasets in the ODDS benchmark") 20 | parser.add_argument("--exp_dir", type = str, default=None) 21 | parser.add_argument("--setting", type = str, default='semi_supervised', choices = ['semi_supervised', 'unsupervised'], help="semi_supervised:an uncontaminated, unsupervised setting; unsupervised:a contaminated, unsupervised setting") 22 | 23 | # 
wandb 24 | parser.add_argument("--wandb", action='store_true') 25 | parser.add_argument("--entity", type = str, default = None) 26 | parser.add_argument("--project", type = str, default = 'AnoLLM') 27 | 28 | #dataset hyperparameters 29 | parser.add_argument("--data_dir", type = str, default='data') 30 | parser.add_argument("--n_splits", type = int, default=5) 31 | parser.add_argument("--split_idx", type = int, default=0) # 0 to n_split-1 32 | parser.add_argument("--train_ratio", type = float, default=0.5) 33 | parser.add_argument("--seed", type = int, default=42) 34 | 35 | # preprocessing 36 | parser.add_argument("--binning", type = str, choices=['quantile', 'equal_width', 'language', 'none', 'standard'], default='standard') 37 | parser.add_argument("--n_buckets", type = int, default=10) 38 | parser.add_argument("--remove_feature_name", action = 'store_true') 39 | 40 | #training 41 | parser.add_argument("--model", type = str, choices = ['gpt2', 'distilgpt2', 'smol', 'smol-360', 'smol-1.7b'], default='smol') 42 | parser.add_argument("--batch_size", type = int, default=32) # per gpu, eval_batch_size = 2*batch_size 43 | parser.add_argument("--lr", type = float, default=5e-5) 44 | parser.add_argument("--lora", action='store_true', default=False) 45 | parser.add_argument("--max_steps", type = int, default=2000) 46 | parser.add_argument("--eval_steps", type = int, default = 1000) 47 | parser.add_argument("--random_init", action='store_true', default=False) 48 | parser.add_argument("--no_random_permutation", action='store_true', default=False) 49 | 50 | args = parser.parse_args() 51 | if args.exp_dir is None: 52 | args.exp_dir = Path('exp') / args.dataset / args.setting / "split{}".format(args.n_splits) / "split{}".format(args.split_idx) 53 | 54 | if args.model == 'smol': 55 | args.model = 'HuggingFaceTB/SmolLM-135M' 56 | elif args.model == 'smol-360': 57 | args.model = 'HuggingFaceTB/SmolLM-360M' 58 | elif args.model == 'smol-1.7b': 59 | args.model = 'HuggingFaceTB/SmolLM-1.7B' 60 | 61 | args.save_dir = Path(args.exp_dir) / 'models' # save to save models 62 | os.makedirs(args.save_dir, exist_ok = True) 63 | 64 | return args 65 | 66 | def get_run_name(args): 67 | name = 'anollm' 68 | name += '_lr{}'.format(args.lr) 69 | name += '_{}'.format(args.binning) 70 | 71 | if args.model == 'HuggingFaceTB/SmolLM-135M': 72 | name += '_smolLM' 73 | elif args.model == 'HuggingFaceTB/SmolLM-360M': 74 | name += '_smolLM360' 75 | elif args.model == 'HuggingFaceTB/SmolLM-1.7B': 76 | name += '_smolLM1.7B' 77 | else: 78 | name += '_' + args.model 79 | 80 | if args.random_init: 81 | name += '_random_init' 82 | 83 | if args.no_random_permutation: 84 | name += '_no_random_permutation' 85 | 86 | if args.lora: 87 | name += '_lora' 88 | name += "_test" 89 | return name 90 | 91 | 92 | def main(): 93 | # Set CUDA devices for each process 94 | local_rank = int(os.environ["LOCAL_RANK"]) 95 | torch.cuda.set_device(local_rank) 96 | 97 | args = get_args() 98 | if dist.get_rank() == 0: 99 | X_train, X_test, y_train, y_test = load_data(args) 100 | dist.barrier() 101 | if dist.get_rank() != 0: 102 | X_train, X_test, y_train, y_test = load_data(args) 103 | dist.barrier() 104 | 105 | run_name = get_run_name(args) 106 | efficient_finetuning = 'lora' if args.lora else '' 107 | model_path = args.save_dir / '{}.pt'.format(run_name) 108 | dataset_tmp_path = args.save_dir / (run_name + '_data') 109 | 110 | os.makedirs(dataset_tmp_path, exist_ok= True) 111 | print("Model path:", model_path) 112 | #if False: 113 | if 
os.path.exists(model_path): 114 | print("Model exists, skip training") 115 | return 116 | 117 | max_length_dict = get_max_length_dict(args.dataset) 118 | text_columns = get_text_columns(args.dataset) 119 | def get_model(): 120 | model = AnoLLM(args.model, 121 | batch_size=args.batch_size, 122 | max_steps = args.max_steps, 123 | efficient_finetuning = efficient_finetuning, 124 | max_length_dict=max_length_dict, 125 | textual_columns = text_columns, 126 | random_init=args.random_init, 127 | no_random_permutation=args.no_random_permutation, 128 | bf16=True, 129 | adam_beta2=0.99, 130 | adam_epsilon=1e-7, 131 | learning_rate=args.lr, 132 | ) 133 | return model 134 | # Initialize the LLM 135 | if dist.get_rank() == 0: 136 | anollm = get_model() 137 | dist.barrier() 138 | if dist.get_rank() != 0: 139 | anollm = get_model() 140 | dist.barrier() 141 | # Move the model to the appropriate GPU 142 | anollm.model.to(local_rank) 143 | 144 | # Wrap the model for distributed training 145 | anollm.model = torch.nn.parallel.DistributedDataParallel( 146 | anollm.model, device_ids=[local_rank], output_device=local_rank 147 | ) 148 | if args.wandb and dist.get_rank() == 0: 149 | run = wandb.init( 150 | entity=args.entity, 151 | project=args.project, 152 | name = "{}_splits{}_{}_{}".format(args.dataset, args.split_idx, args.n_splits, run_name), 153 | ) 154 | if len(X_test) > 3000: 155 | np.random.seed(args.seed) 156 | X_test.reset_index(drop = True, inplace = True) 157 | indices = np.random.choice(len(X_test), 3000, replace = False) 158 | X_test = X_test.loc[indices].reset_index(drop = True) 159 | y_test = y_test[indices] 160 | if not args.wandb: 161 | X_test, y_test = None, None 162 | 163 | # Train the model 164 | start_time = time.time() 165 | trainer = anollm.fit(X_train, X_train.columns.to_list(), 166 | use_wandb = args.wandb, 167 | data_val=X_test, 168 | label_val = y_test, 169 | eval_steps = args.eval_steps, 170 | processed_data_dir = dataset_tmp_path, 171 | ) 172 | end_time = time.time() 173 | 174 | # Save the model only from rank 0 process 175 | if dist.get_rank() == 0: 176 | 177 | print("Training time:", end_time - start_time) 178 | run_time_dir = args.exp_dir / "run_time" / "train" 179 | os.makedirs(run_time_dir, exist_ok = True) 180 | run_time_path = run_time_dir / "{}.txt".format(run_name) 181 | with open(run_time_path, 'w') as f: 182 | f.write(str(end_time - start_time)) 183 | 184 | print("Save model to ", model_path) 185 | anollm.save_state_dict(model_path) 186 | 187 | 188 | dist.destroy_process_group() 189 | 190 | if __name__ == "__main__": 191 | # Initialize the distributed process group 192 | dist.init_process_group(backend="nccl") 193 | main() --------------------------------------------------------------------------------
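Taken together, the files in this section form one pipeline: train_anollm.py fine-tunes the backbone LLM for a single data split, a separate scoring step is expected to write per-method anomaly-score .npy files under exp/<dataset>/<setting>/split<n_splits>/split<split_idx>/scores/, and src/get_results.py plus src/get_avg_results.py turn those scores into AUC-ROC, AUC-PR, and F1 tables with per-method rankings. The following is a minimal end-to-end sketch, not a prescribed recipe: it assumes a single node with 8 GPUs and the default 'wine' dataset, and the scoring step's flags are omitted because that script is not part of this section.

# 1. Fine-tune AnoLLM on split 0 of the 'wine' dataset (semi-supervised setting).
#    torchrun is needed because the script initializes an NCCL process group; the
#    8-GPU count mirrors the comment at the top of train_anollm.py and is an assumption.
#    Weights are saved to exp/wine/semi_supervised/split5/split0/models/<run_name>.pt
torchrun --nproc_per_node=8 train_anollm.py \
    --dataset wine --setting semi_supervised --n_splits 5 --split_idx 0 \
    --model smol --binning standard --lr 5e-5 --batch_size 32 --max_steps 2000

# 2. (Scoring step, not shown here.) Each method's anomaly scores are expected as
#    .npy files under exp/wine/semi_supervised/split5/split0/scores/ before step 3.

# 3. Per-dataset metrics. Omitting --split_idx makes get_results.py loop over all
#    n_splits splits and print mean +- std of AUC-ROC / AUC-PR / F1 with rankings.
python src/get_results.py --dataset wine --setting semi_supervised --n_splits 5

# 4. Benchmark-wide averages; writes exp/semi_supervised_avg_AUC-ROC.csv and the
#    matching std table exp/semi_supervised_std_AUC-ROC.csv.
python src/get_avg_results.py --dataset odds --setting semi_supervised --metric AUC-ROC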