├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── examples ├── diffusion_train_generate_example.ipynb ├── new.ipynb ├── rr ├── test.py ├── vONTSS_example.ipynb ├── vONT_example.ipynb └── vONT_example_visualization.ipynb ├── main.py ├── requirements.txt └── src ├── diffusion ├── __init__.py ├── diffuser_training.py ├── diffusion.py └── diffusion_generate.py ├── huggingface ├── __pycache__ │ ├── configuration_detime.cpython-310.pyc │ └── modeling_detime.cpython-310.pyc ├── detime │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── configuration_detime.cpython-310.pyc │ │ └── modeling_detime.cpython-310.pyc │ ├── configuration_detime.py │ └── modeling_detime.py └── test.ipynb ├── src └── topicmodeling │ └── preprocess.py └── topicmodeling ├── CNN_Encoder.py ├── T5_Encoder └── __pycache__ │ ├── CNN_Encoder.cpython-310.pyc │ ├── flanT5_cnn_lighting.cpython-310.pyc │ └── sentence_encoder.cpython-310.pyc ├── __init__.py ├── __pycache__ ├── CNN_Encoder.cpython-310.pyc ├── flanT5_cnn_lighting.cpython-310.pyc ├── model.cpython-310.pyc ├── model.cpython-38.pyc ├── preprocess.cpython-310.pyc ├── sentence_encoder.cpython-310.pyc └── utils.cpython-310.pyc ├── flanT5_cnn_lighting.py ├── hyperspherical_vae ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ └── __init__.cpython-38.pyc ├── distributions │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── hyperspherical_uniform.cpython-310.pyc │ │ ├── hyperspherical_uniform.cpython-38.pyc │ │ ├── von_mises_fisher.cpython-310.pyc │ │ └── von_mises_fisher.cpython-38.pyc │ ├── hyperspherical_uniform.py │ └── von_mises_fisher.py └── ops │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ ├── ive.cpython-310.pyc │ └── ive.cpython-38.pyc │ └── ive.py ├── model.py ├── model_evaluate.ipynb ├── preprocess.ipynb ├── preprocess.py ├── sentence_encoder.py ├── sentence_eval.ipynb ├── utils.py └── vontss.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. 
Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity.
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # DeTiME: Diffusion-Enhanced Topic Modeling using Encoder-decoder based LLM (Accepted by EMNLP 2023 as Findings) 4 | 5 | This repository is the official implementation of [DeTiME: Diffusion-Enhanced Topic Modeling using Encoder-decoder based LLM](https://aclanthology.org/2023.findings-emnlp.606.pdf). 
6 | 7 | DeTiME can generate embeddings, denoise them through diffusion, and generate text from the denoised embeddings. 8 | 9 | 10 | ## Installation 11 | 12 | 13 | 14 | To install requirements: 15 | 16 | ```setup 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | 21 | ## Training and Evaluation 22 | 23 | To train and evaluate the model, follow these steps: 24 | 25 | 26 | 27 | Step 1: If the data is hosted on Hugging Face, pass the repository name via --data_source. 28 | If the data is a CSV file, set --data_source csv and point --data_path to the file. 29 | Step 2: Define the number of topics; for 10 topics, use --numb_embeddings 10. 30 | Step 3: Define the metrics to evaluate; diversity, c_v, c_npmi, c_uci, and u_mass are currently supported. 31 | 32 | Then run: 33 | ```train 34 | python3 main.py --data_source xwjzds/ag_news --metrics diversity --topk 20 35 | ``` 36 | It will output the diversity metric on the data in xwjzds/ag_news. 37 | 38 | 39 | 40 | ## Diffusion 41 | 42 | After obtaining embeddings from the DeTiME encoders, diffusion can be leveraged to denoise them. The denoised embeddings can then be passed to the DeTiME decoders to generate text. 43 | 44 | Training the diffuser involves two steps. 45 | 46 | Step 1: Generate embeddings for the dataset using the DeTiME encoders. The code below shows how; it assumes tokenizer, models, dataset, and args have already been initialized. 47 | 48 | ```python import gc from tqdm import tqdm 49 | outputs = [] 50 | 51 | text_ls = dataset['summary'] 52 | 53 | batch_size = 2 54 | 55 | batch_ls = [text_ls[ind: ind + batch_size] for ind in range(0, len(text_ls), batch_size)] 56 | 57 | 58 | 59 | for text in tqdm(batch_ls): 60 | 61 | # optionally prepend an instruction to each document, e.g. 62 | # text = ['repeat: ' + t for t in text] 63 | 64 | # tokenize the batch to a fixed length 65 | 66 | 67 | inputs = tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=args.max_length) 68 | 69 | # move the token ids and attention mask to the model device 70 | inputs_id = inputs.input_ids.to(models.device) 71 | attention = inputs.attention_mask.to(models.device) 72 | 73 | output = models.model.encoder(inputs_id, attention).last_hidden_state # batch size * seq length * embedding size 74 | output = models.encoder(output) # compress into the DeTiME latent embedding 75 | outputs.append(output.detach().cpu()) 76 | 77 | gc.collect() 78 | ``` 79 | 80 | Step 2: Train a diffuser on the embeddings via `python diffuser_training.py --embedding_input './example/embed_vectors_base_7_1000_prefix.pt' --model_name 'UNet_Conv' --output_dir './example'`. Here, --embedding_input is the embedding file location, --model_name is the diffuser model to train, and --output_dir is where the trained diffuser is saved. 81 | 82 | 83 | Generating text from the denoised embeddings involves three steps. 84 | 85 | Step 1: Generate embeddings for the dataset using the DeTiME encoders, as above. 86 | 87 | Step 2: Denoise the embeddings. Here, model is the trained diffuser, and num_images, latent_dim, and device are assumed to be defined to match the embedding setup. 88 | ```python 89 | import torch from diffusion.diffusion_generate import generate_diffused_embed, generate_text 90 | # sample a pure noise vector and iteratively denoise it 91 | sampling_turn = 2 92 | timesteps = 1000 93 | 94 | x_noise = torch.randn((num_images, 4, latent_dim // 4), device=device) 95 | x_track_ls_ls_noise, x_0_track_ls_ls_noise = generate_diffused_embed(x_noise, model, timesteps, device, batch_size=2, 96 | num_generated_sample=2, return_all_time_embed=True) 97 | ``` 98 | 99 | Step 3: Generate text from the denoised embeddings, as sketched below.
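A minimal sketch of Step 3 is shown below. It is illustrative only: `generate_text` is the helper imported in Step 2, but its exact signature lives in `src/diffusion/diffusion_generate.py`; the argument names used here (`models`, `tokenizer`, `max_length`) and the choice of the last entry of `x_0_track_ls_ls_noise` as the fully denoised embedding are assumptions, not the verified API.

```python
# Illustrative sketch only -- check src/diffusion/diffusion_generate.py for the
# actual signature of generate_text; the arguments below are assumed.
denoised_embeds = x_0_track_ls_ls_noise[-1]  # assumed: last entry = fully denoised embedding

generated = generate_text(denoised_embeds, models, tokenizer, max_length=args.max_length)
for sentence in generated:
    print(sentence)
```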
100 | 101 | 102 | ## Interactive Code 103 | 104 | Example using a dataset from OCTIS: 105 | 106 | 107 | 108 | ```python 109 | from octis.dataset.dataset import Dataset 110 | import sys 111 | sys.path.insert(0, '../src/topicmodeling') 112 | from model import TopicModel 113 | 114 | from octis.evaluation_metrics.diversity_metrics import TopicDiversity 115 | from octis.evaluation_metrics.coherence_metrics import Coherence 116 | 117 | 118 | dataset = Dataset() 119 | dataset.fetch_dataset("20NewsGroup") # supported datasets include 20NewsGroup, BBC_News, DBLP, DBPedia_IT 120 | tm = TopicModel(numb_embeddings = 10) 121 | texts = [' '.join(i) for i in dataset.get_corpus()] 122 | model_output = tm.train_model(texts) 123 | metric = TopicDiversity(topk=10) 124 | topic_diversity_score = metric.score(model_output) # compute diversity score 125 | cmetric = Coherence(texts = tm.tp.lemmas, measure='c_npmi') 126 | coherence = cmetric.score(model_output) # compute coherence score 127 | ``` 128 | 129 | Example using a dataset from Hugging Face: 130 | ```python 131 | import sys 132 | sys.path.insert(0, '../src/topicmodeling') 133 | from model import TopicModel 134 | from datasets import load_dataset 135 | from octis.evaluation_metrics.diversity_metrics import TopicDiversity 136 | from octis.evaluation_metrics.coherence_metrics import Coherence 137 | 138 | 139 | df = load_dataset('xwjzds/ag_news') 140 | tm = TopicModel(numb_embeddings = 10) 141 | 142 | model_output = tm.train_model(df['train']['text']) 143 | metric = TopicDiversity(topk=10) 144 | topic_diversity_score = metric.score(model_output) # compute diversity score 145 | cmetric = Coherence(texts = tm.tp.lemmas, measure='c_npmi') 146 | coherence = cmetric.score(model_output) # compute coherence score 147 | 148 | ``` 149 | 150 | ## Arguments Explained 151 | 152 | 153 | 154 | --numb_embeddings: Number of embeddings, i.e. topics (default is 10). 155 | 156 | --epochs: Number of epochs for training (default is 20). 157 | 158 | --batch_size: Batch size for training (default is 256). 159 | 160 | --gpu_num: GPU number to use (default is 1). 161 | 162 | --learning_rate: Learning rate (default is 0.002). 163 | 164 | --weight_decay: Weight decay (default is 1.2e-6). 165 | 166 | --penalty: Penalty term (default is 1). 167 | 168 | --beta: Beta value (default is 1). 169 | 170 | --temp: Temperature (default is 10). 171 | 172 | --data_source: Data source type (default is 'huggingface'). Can be 'huggingface', 'csv', or 'txt'. 173 | 174 | --data_path: Path to the data file for 'csv' or 'txt' (default is ''). 175 | 176 | --metrics: List of metrics to report (default is ['diversity', 'c_v', 'c_npmi', 'c_uci', 'u_mass']). 177 | 178 | --topk: Top k words to report for diversity (default is 10). 179 | An illustrative invocation combining these arguments is shown after the Results section below. 180 | 181 | 182 | ## Results 183 | 184 | Our model achieves the following performance on AG News: 185 | 186 | 187 | 188 | | Model name | Diversity | C_v | C_npmi | 189 | | ------------------ |---------------- | -------------- | -------------- | 190 | | vONT | 0.865 | 0.618 | 0.115 | 191 | | DeTiME | 0.93 | 0.645 | 0.113 | 192 | 193 | 194 | We use existing embeddings in this code release instead of spherical embeddings, since training spherical embeddings takes time. We also note that the performance reported here is better than the performance in our paper.
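Putting the arguments above together, an illustrative run on a local CSV file could look like the following (the path is hypothetical; per main.py, the CSV is expected to contain a text column):

```train
python3 main.py --data_source csv --data_path ./data/my_corpus.csv --numb_embeddings 20 --metrics diversity c_v c_npmi --topk 10
```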
195 | 196 | 197 | ## Citation 198 | ``` 199 | @inproceedings{xu-etal-2023-vontss, 200 | title = "v{ONTSS}: v{MF} based semi-supervised neural topic modeling with optimal transport", 201 | author = "Xu, Weijie and 202 | Jiang, Xiaoyu and 203 | Sengamedu Hanumantha Rao, Srinivasan and 204 | Iannacci, Francis and 205 | Zhao, Jinjin", 206 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2023", 207 | month = jul, 208 | year = "2023", 209 | address = "Toronto, Canada", 210 | publisher = "Association for Computational Linguistics", 211 | url = "https://aclanthology.org/2023.findings-acl.271", 212 | doi = "10.18653/v1/2023.findings-acl.271", 213 | pages = "4433--4457", 214 | abstract = "Recently, Neural Topic Models (NTM), inspired by variational autoencoders, have attracted a lot of research interest; however, these methods have limited applications in the real world due to the challenge of incorporating human knowledge. This work presents a semi-supervised neural topic modeling method, vONTSS, which uses von Mises-Fisher (vMF) based variational autoencoders and optimal transport. When a few keywords per topic are provided, vONTSS in the semi-supervised setting generates potential topics and optimizes topic-keyword quality and topic classification. Experiments show that vONTSS outperforms existing semi-supervised topic modeling methods in classification accuracy and diversity. vONTSS also supports unsupervised topic modeling. Quantitative and qualitative experiments show that vONTSS in the unsupervised setting outperforms recent NTMs on multiple aspects: vONTSS discovers highly clustered and coherent topics on benchmark datasets. It is also much faster than the state-of-the-art weakly supervised text classification method while achieving similar classification performance. We further prove the equivalence of optimal transport loss and cross-entropy loss at the global minimum.", 215 | } 216 | @inproceedings{xu-etal-2023-detime, 217 | title = "{D}e{T}i{ME}: Diffusion-Enhanced Topic Modeling using Encoder-decoder based {LLM}", 218 | author = "Xu, Weijie and 219 | Hu, Wenxiang and 220 | Wu, Fanyou and 221 | Sengamedu, Srinivasan", 222 | editor = "Bouamor, Houda and 223 | Pino, Juan and 224 | Bali, Kalika", 225 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2023", 226 | month = dec, 227 | year = "2023", 228 | address = "Singapore", 229 | publisher = "Association for Computational Linguistics", 230 | url = "https://aclanthology.org/2023.findings-emnlp.606", 231 | doi = "10.18653/v1/2023.findings-emnlp.606", 232 | pages = "9040--9057", 233 | abstract = "In the burgeoning field of natural language processing, Neural Topic Models (NTMs) and Large Language Models (LLMs) have emerged as areas of significant research interest. Despite this, NTMs primarily utilize contextual embeddings from LLMs, which are not optimal for clustering or capable for topic generation. Our study addresses this gap by introducing a novel framework named Diffusion-Enhanced Topic Modeling using Encoder-Decoder-based LLMs (DeTiME). DeTiME leverages Encoder-Decoder-based LLMs to produce highly clusterable embeddings that could generate topics that exhibit both superior clusterability and enhanced semantic coherence compared to existing methods. Additionally, by exploiting the power of diffusion, our framework also provides the capability to generate content relevant to the identified topics. 
This dual functionality allows users to efficiently produce highly clustered topics and related content simultaneously. DeTiME{'}s potential extends to generating clustered embeddings as well. Notably, our proposed framework proves to be efficient to train and exhibits high adaptability, demonstrating its potential for a wide array of applications.", 234 | } 235 | ``` -------------------------------------------------------------------------------- /examples/new.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 10, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "[nltk_data] Downloading package punkt to /home/weijiexu/nltk_data...\n", 13 | "[nltk_data] Package punkt is already up-to-date!\n", 14 | "[nltk_data] Downloading package wordnet to /home/weijiexu/nltk_data...\n", 15 | "[nltk_data] Package wordnet is already up-to-date!\n", 16 | "[nltk_data] Downloading package averaged_perceptron_tagger to\n", 17 | "[nltk_data] /home/weijiexu/nltk_data...\n", 18 | "[nltk_data] Package averaged_perceptron_tagger is already up-to-\n", 19 | "[nltk_data] date!\n", 20 | "[nltk_data] Downloading package stopwords to\n", 21 | "[nltk_data] /home/weijiexu/nltk_data...\n", 22 | "[nltk_data] Package stopwords is already up-to-date!\n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "import sys\n", 28 | "sys.path.insert(0, '../src/topicmodeling')\n", 29 | "from model import TopicModel\n", 30 | "from datasets import load_dataset\n", 31 | "from octis.evaluation_metrics.diversity_metrics import TopicDiversity\n", 32 | "from octis.evaluation_metrics.coherence_metrics import Coherence\n", 33 | "\n", 34 | "\n", 35 | "df = load_dataset('xwjzds/ag_news')\n", 36 | "tm = TopicModel(numb_embeddings = 10)\n", 37 | "\n", 38 | "model_output = tm.train_model(df['train']['text'])\n", 39 | "metric = TopicDiversity(topk=10)\n", 40 | "topic_diversity_score = metric.score(model_output) # Compute score of diversity\n", 41 | "cmetric = Coherence(texts = tm.tp.lemmas, measure='c_npmi')\n", 42 | "coherence1 = cmetric.score(model_output) # Compute score of coherence\n", 43 | "cmetric = Coherence(texts = tm.tp.lemmas, measure='c_v')\n", 44 | "coherence2 = cmetric.score(model_output) # Compute score of coherence\n", 45 | "print(topic_diversity_score, coherence1, coherence2)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 1, 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "ename": "NameError", 55 | "evalue": "name 'df' is not defined", 56 | "output_type": "error", 57 | "traceback": [ 58 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 59 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 60 | "\u001b[1;32m/home/weijiexu/workspace/Vontss/examples/new.ipynb Cell 2\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mlen\u001b[39m(df[\u001b[39m'\u001b[39m\u001b[39mtrain\u001b[39m\u001b[39m'\u001b[39m][\u001b[39m'\u001b[39m\u001b[39mtext\u001b[39m\u001b[39m'\u001b[39m])\n", 61 | "\u001b[0;31mNameError\u001b[0m: name 'df' is not defined" 62 | ] 63 | } 64 | ], 65 | "source": [ 66 | "len(df['train']['text'])" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 7, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 8, 79 | "metadata": {}, 80 | 
"outputs": [ 81 | { 82 | "data": { 83 | "text/plain": [ 84 | "0.4406901587312282" 85 | ] 86 | }, 87 | "execution_count": 8, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "coherence" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 21, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "Collecting gensim==4.3.1\n", 106 | " Downloading gensim-4.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.4 MB)\n", 107 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m26.4/26.4 MB\u001b[0m \u001b[31m63.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", 108 | "\u001b[?25hRequirement already satisfied: numpy>=1.18.5 in /local/home/weijiexu/anaconda3/envs/env-01/lib/python3.10/site-packages (from gensim==4.3.1) (1.23.0)\n", 109 | "Requirement already satisfied: scipy>=1.7.0 in /local/home/weijiexu/anaconda3/envs/env-01/lib/python3.10/site-packages (from gensim==4.3.1) (1.10.1)\n", 110 | "Requirement already satisfied: smart-open>=1.8.1 in /local/home/weijiexu/anaconda3/envs/env-01/lib/python3.10/site-packages (from gensim==4.3.1) (6.3.0)\n", 111 | "Installing collected packages: gensim\n", 112 | " Attempting uninstall: gensim\n", 113 | " Found existing installation: gensim 4.2.0\n", 114 | " Uninstalling gensim-4.2.0:\n", 115 | " Successfully uninstalled gensim-4.2.0\n", 116 | "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", 117 | "octis 1.13.1 requires gensim==4.2.0, but you have gensim 4.3.1 which is incompatible.\u001b[0m\u001b[31m\n", 118 | "\u001b[0mSuccessfully installed gensim-4.3.1\n", 119 | "\n", 120 | "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n", 121 | "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "!pip install gensim==4.3.1" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 14, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "data": { 136 | "text/plain": [ 137 | "array([[4.1902823e-07, 1.3587736e-11, 5.3726155e-03, ..., 1.9312135e-06,\n", 138 | " 1.2757047e-02, 5.9493249e-07],\n", 139 | " [7.3549999e-10, 1.0243623e-11, 8.0542906e-09, ..., 4.1704459e-07,\n", 140 | " 3.0640180e-08, 1.8471485e-11],\n", 141 | " [5.7667028e-04, 2.3967863e-07, 3.3781704e-04, ..., 4.3794357e-10,\n", 142 | " 1.9988370e-04, 6.9555554e-08],\n", 143 | " ...,\n", 144 | " [9.4333009e-05, 6.0200409e-06, 1.6074823e-04, ..., 1.8656671e-06,\n", 145 | " 1.5472580e-07, 9.5643941e-03],\n", 146 | " [4.6894229e-06, 1.4929505e-08, 6.0049203e-08, ..., 2.1408705e-01,\n", 147 | " 1.2392248e-09, 1.7268887e-08],\n", 148 | " [2.8513581e-03, 3.7061946e-07, 3.7983521e-03, ..., 1.7031522e-04,\n", 149 | " 5.1026487e-05, 6.4088619e-07]], dtype=float32)" 150 | ] 151 | }, 152 | "execution_count": 14, 153 | "metadata": {}, 154 | "output_type": "execute_result" 155 | } 156 | ], 157 | "source": [ 158 | "model_output['topic-word-matrix']" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 13, 164 | 
"metadata": {}, 165 | "outputs": [ 166 | { 167 | "data": { 168 | "text/plain": [ 169 | "array([[0.0994494 , 0.0990864 , 0.09963817, ..., 0.09905903, 0.09858026,\n", 170 | " 0.09912238],\n", 171 | " [0.06421395, 0.06424627, 0.06424221, ..., 0.06428529, 0.06446692,\n", 172 | " 0.06432256],\n", 173 | " [0.10132632, 0.10125542, 0.10153911, ..., 0.10124898, 0.10130309,\n", 174 | " 0.10128608],\n", 175 | " ...,\n", 176 | " [0.10297655, 0.10317902, 0.10267786, ..., 0.10305715, 0.10250283,\n", 177 | " 0.10282135],\n", 178 | " [0.07066783, 0.07073216, 0.07061214, ..., 0.07076373, 0.07085302,\n", 179 | " 0.0706471 ],\n", 180 | " [0.15995942, 0.1605845 , 0.16014957, ..., 0.1603556 , 0.16008826,\n", 181 | " 0.16038543]], dtype=float32)" 182 | ] 183 | }, 184 | "execution_count": 13, 185 | "metadata": {}, 186 | "output_type": "execute_result" 187 | } 188 | ], 189 | "source": [ 190 | "model_output['topic-document-matrix']" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [] 199 | } 200 | ], 201 | "metadata": { 202 | "kernelspec": { 203 | "display_name": "env-01", 204 | "language": "python", 205 | "name": "python3" 206 | }, 207 | "language_info": { 208 | "codemirror_mode": { 209 | "name": "ipython", 210 | "version": 3 211 | }, 212 | "file_extension": ".py", 213 | "mimetype": "text/x-python", 214 | "name": "python", 215 | "nbconvert_exporter": "python", 216 | "pygments_lexer": "ipython3", 217 | "version": "3.10.10" 218 | }, 219 | "orig_nbformat": 4 220 | }, 221 | "nbformat": 4, 222 | "nbformat_minor": 2 223 | } 224 | -------------------------------------------------------------------------------- /examples/rr: -------------------------------------------------------------------------------- 1 | torch.Size([120000, 4096]) 2 | 64 1875 3 | None 4 | torch.Size([120000, 4096]) 5 | (14696, 100) 6 | 270.9416328869395 0.03138272997055417 7 | 200.71236691342742 0.03416816332637628 8 | 185.7224679389742 0.031684735313312076 9 | 179.90556302914487 0.02988790799734562 10 | 176.64112255364847 0.02890335287906722 11 | 174.24899301752606 0.02881038779897222 12 | 172.38616240609176 0.02895136079443162 13 | 170.91855489572228 0.029198516851295032 14 | 169.66865441946587 0.029653682363535296 15 | 168.85879405983476 0.029893301236731156 16 | 168.19305732255296 0.030158104621239308 17 | 167.5825169610062 0.03047366635695195 18 | 166.94266248105177 0.030853741045699698 19 | 166.3755670478349 0.03127345895922896 20 | 165.971091687298 0.03151722982335192 21 | 165.7004096350436 0.03154783026337115 22 | 165.44454559130963 0.03170995326883503 23 | 165.27525988418157 0.031727322859804764 24 | 165.1246534888424 0.03170219998655797 25 | 165.12682871218684 0.031954723265347706 26 | 0.93 0.11319229471256073 0.6448654289335599 27 | -------------------------------------------------------------------------------- /examples/test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '../src/topicmodeling') 3 | from model import TopicModel 4 | from datasets import load_dataset 5 | from octis.evaluation_metrics.diversity_metrics import TopicDiversity 6 | from octis.evaluation_metrics.coherence_metrics import Coherence 7 | 8 | 9 | df = load_dataset('xwjzds/ag_news') 10 | tm = TopicModel(numb_embeddings = 10) 11 | 12 | model_output = tm.train_model(df['train']['text'], name = 'agnews') 13 | metric = TopicDiversity(topk=10) 14 | topic_diversity_score = metric.score(model_output) # Compute 
score of diversity 15 | cmetric = Coherence(texts = tm.tp.lemmas, measure='c_npmi') 16 | coherence1 = cmetric.score(model_output) # Compute score of coherence 17 | cmetric = Coherence(texts = tm.tp.lemmas, measure='c_v') 18 | coherence2 = cmetric.score(model_output) # Compute score of coherence 19 | print(topic_diversity_score, coherence1, coherence2) -------------------------------------------------------------------------------- /examples/vONT_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 13 | " warnings.warn(\n", 14 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 15 | " from .autonotebook import tqdm as notebook_tqdm\n", 16 | "Found cached dataset parquet (/Users/weijiexu/.cache/huggingface/datasets/xwjzds___parquet/xwjzds--ag_news-14abcf7379d2aad0/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n", 17 | "100%|██████████| 2/2 [00:00<00:00, 82.91it/s]\n", 18 | "[nltk_data] Downloading package punkt to /Users/weijiexu/nltk_data...\n", 19 | "[nltk_data] Package punkt is already up-to-date!\n", 20 | "[nltk_data] Downloading package wordnet to\n", 21 | "[nltk_data] /Users/weijiexu/nltk_data...\n", 22 | "[nltk_data] Package wordnet is already up-to-date!\n", 23 | "[nltk_data] Downloading package averaged_perceptron_tagger to\n", 24 | "[nltk_data] /Users/weijiexu/nltk_data...\n", 25 | "[nltk_data] Package averaged_perceptron_tagger is already up-to-\n", 26 | "[nltk_data] date!\n", 27 | "[nltk_data] Downloading package stopwords to\n", 28 | "[nltk_data] /Users/weijiexu/nltk_data...\n", 29 | "[nltk_data] Package stopwords is already up-to-date!\n", 30 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 31 | " warnings.warn(\n", 32 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 33 | " warnings.warn(\n", 34 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 35 | " warnings.warn(\n", 36 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. 
See: https://github.com/urllib3/urllib3/issues/3020\n", 37 | " warnings.warn(\n", 38 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 39 | " warnings.warn(\n", 40 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 41 | " warnings.warn(\n", 42 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 43 | " warnings.warn(\n", 44 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 45 | " warnings.warn(\n", 46 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 47 | " warnings.warn(\n", 48 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 49 | " warnings.warn(\n" 50 | ] 51 | }, 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "10 12000\n", 57 | "None\n", 58 | "(14696, 100)\n" 59 | ] 60 | }, 61 | { 62 | "name": "stderr", 63 | "output_type": "stream", 64 | "text": [ 65 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/torch/distributions/distribution.py:51: UserWarning: does not define `arg_constraints`. Please set `arg_constraints = {}` or initialize the distribution with `validate_args=False` to turn off validation.\n", 66 | " warnings.warn(f'{self.__class__} does not define `arg_constraints`. ' +\n", 67 | "/Users/weijiexu/workspace/newd/vONTSS/examples/../src/topicmodeling/model.py:248: UserWarning: Implicit dimension choice for softmax has been deprecated. 
Change the call to include dim=X as an argument.\n", 68 | " z = self.h_to_z(self.temp * z)\n" 69 | ] 70 | }, 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "301.445813990351 0.018776538698419706\n", 76 | "201.43519367795508 0.03673993603094046\n", 77 | "185.00539409580514 0.036067669945103784\n", 78 | "180.48459939204298 0.03485209299430156\n", 79 | "178.31570714369002 0.0337027379238148\n", 80 | "176.72348936686893 0.03266826551208999\n", 81 | "175.25629810585397 0.031764034884793166\n", 82 | "173.74914131083213 0.030792366328047535\n", 83 | "171.8219049775016 0.03059555485701637\n", 84 | "169.97334719009237 0.030550426138298854\n", 85 | "168.5936137282772 0.030986411234876242\n", 86 | "167.6509848718704 0.031058786317769652\n", 87 | "166.92787546186304 0.031178461113718272\n", 88 | "166.36131999131712 0.031415080544409724\n", 89 | "165.9475578194234 0.031622917119310354\n", 90 | "165.6068212025201 0.03163114861289321\n", 91 | "165.34470294228495 0.03168722147197484\n", 92 | "165.18317086356026 0.031795562771178766\n", 93 | "165.04680537567464 0.031804229738488635\n", 94 | "165.03599343371036 0.031809317936965904\n" 95 | ] 96 | }, 97 | { 98 | "name": "stderr", 99 | "output_type": "stream", 100 | "text": [ 101 | "/Users/weijiexu/workspace/newd/vONTSS/examples/../src/topicmodeling/model.py:444: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", 102 | " self.z = self.model.h_to_z(z_mean).detach().numpy()\n" 103 | ] 104 | }, 105 | { 106 | "ename": "SyntaxError", 107 | "evalue": "keyword argument repeated: numb_embeddings (1751582241.py, line 13)", 108 | "output_type": "error", 109 | "traceback": [ 110 | "\u001b[0;36m Cell \u001b[0;32mIn[1], line 13\u001b[0;36m\u001b[0m\n\u001b[0;31m tm = TopicModel(numb_embeddings = 10, epochs=20, batch_size=256, gpu_num=1, numb_embeddings=20,\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m keyword argument repeated: numb_embeddings\n" 111 | ] 112 | } 113 | ], 114 | "source": [ 115 | "\n", 116 | "import sys\n", 117 | "sys.path.insert(0, '../src/topicmodeling')\n", 118 | "from model import TopicModel\n", 119 | "from datasets import load_dataset\n", 120 | "from octis.evaluation_metrics.diversity_metrics import TopicDiversity\n", 121 | "from octis.evaluation_metrics.coherence_metrics import Coherence\n", 122 | "\n", 123 | "\n", 124 | "df = load_dataset('xwjzds/ag_news')\n", 125 | "tm = TopicModel(numb_embeddings = 10)\n", 126 | "\n", 127 | "model_output = tm.train_model(df['train']['text'])\n", 128 | "\n", 129 | "\n" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 2, 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "data": { 139 | "text/plain": [ 140 | "\"Wall St. 
Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\\\band of ultra-cynics, are seeing green again.\"" 141 | ] 142 | }, 143 | "execution_count": 2, 144 | "metadata": {}, 145 | "output_type": "execute_result" 146 | } 147 | ], 148 | "source": [ 149 | "df['train']['text'][0]" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 15, 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "ename": "Exception", 159 | "evalue": "Words in topics are less than 19", 160 | "output_type": "error", 161 | "traceback": [ 162 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 163 | "\u001b[0;31mException\u001b[0m Traceback (most recent call last)", 164 | "Cell \u001b[0;32mIn[15], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39moctis\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mevaluation_metrics\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mcoherence_metrics\u001b[39;00m \u001b[39mimport\u001b[39;00m Coherence\n\u001b[1;32m 3\u001b[0m metric \u001b[39m=\u001b[39m TopicDiversity(topk\u001b[39m=\u001b[39m\u001b[39m19\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m topic_diversity_score \u001b[39m=\u001b[39m metric\u001b[39m.\u001b[39;49mscore(model_output) \u001b[39m# Compute score of the metric\u001b[39;00m\n", 165 | "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/octis/evaluation_metrics/diversity_metrics.py:47\u001b[0m, in \u001b[0;36mTopicDiversity.score\u001b[0;34m(self, model_output)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39m0\u001b[39m\n\u001b[1;32m 46\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtopk \u001b[39m>\u001b[39m \u001b[39mlen\u001b[39m(topics[\u001b[39m0\u001b[39m]):\n\u001b[0;32m---> 47\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mException\u001b[39;00m(\u001b[39m'\u001b[39m\u001b[39mWords in topics are less than \u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m \u001b[39mstr\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtopk))\n\u001b[1;32m 48\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 49\u001b[0m unique_words \u001b[39m=\u001b[39m \u001b[39mset\u001b[39m()\n", 166 | "\u001b[0;31mException\u001b[0m: Words in topics are less than 19" 167 | ] 168 | } 169 | ], 170 | "source": [ 171 | "from octis.evaluation_metrics.diversity_metrics import TopicDiversity\n", 172 | "from octis.evaluation_metrics.coherence_metrics import Coherence\n", 173 | "metric = TopicDiversity(topk=10)\n", 174 | "topic_diversity_score = metric.score(model_output) # Compute score of the metric\n", 175 | "cmetric = Coherence(texts = tm.tp.lemmas, measure='c_npmi')\n", 176 | "topic_diversity_score = metric.score(model_output) # Compute score of the metri" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 10, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 11, 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "data": { 193 | "text/plain": [ 194 | "0.10958606182477335" 195 | ] 196 | }, 197 | "execution_count": 11, 198 | "metadata": {}, 199 | "output_type": "execute_result" 200 | } 201 | ], 202 | "source": [ 203 | "topic_diversity_score" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 12, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "metric = Coherence(texts = tm.tp.lemmas, measure='c_v')\n", 213 | "topic_diversity_score = 
metric.score(model_output) # Compute score of the metri" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 13, 219 | "metadata": {}, 220 | "outputs": [ 221 | { 222 | "data": { 223 | "text/plain": [ 224 | "0.6116887794085797" 225 | ] 226 | }, 227 | "execution_count": 13, 228 | "metadata": {}, 229 | "output_type": "execute_result" 230 | } 231 | ], 232 | "source": [ 233 | "topic_diversity_score" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [] 242 | } 243 | ], 244 | "metadata": { 245 | "kernelspec": { 246 | "display_name": "Python 3", 247 | "language": "python", 248 | "name": "python3" 249 | }, 250 | "language_info": { 251 | "codemirror_mode": { 252 | "name": "ipython", 253 | "version": 3 254 | }, 255 | "file_extension": ".py", 256 | "mimetype": "text/x-python", 257 | "name": "python", 258 | "nbconvert_exporter": "python", 259 | "pygments_lexer": "ipython3", 260 | "version": "3.9.6" 261 | }, 262 | "orig_nbformat": 4 263 | }, 264 | "nbformat": 4, 265 | "nbformat_minor": 2 266 | } 267 | -------------------------------------------------------------------------------- /examples/vONT_example_visualization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 20, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "[nltk_data] Downloading package punkt to /Users/weijiexu/nltk_data...\n", 13 | "[nltk_data] Package punkt is already up-to-date!\n", 14 | "[nltk_data] Downloading package wordnet to\n", 15 | "[nltk_data] /Users/weijiexu/nltk_data...\n", 16 | "[nltk_data] Package wordnet is already up-to-date!\n", 17 | "[nltk_data] Downloading package averaged_perceptron_tagger to\n", 18 | "[nltk_data] /Users/weijiexu/nltk_data...\n", 19 | "[nltk_data] Package averaged_perceptron_tagger is already up-to-\n", 20 | "[nltk_data] date!\n", 21 | "[nltk_data] Downloading package stopwords to\n", 22 | "[nltk_data] /Users/weijiexu/nltk_data...\n", 23 | "[nltk_data] Package stopwords is already up-to-date!\n", 24 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 25 | " warnings.warn(\n", 26 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 27 | " warnings.warn(\n", 28 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 29 | " warnings.warn(\n", 30 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. 
See: https://github.com/urllib3/urllib3/issues/3020\n", 31 | " warnings.warn(\n", 32 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 33 | " warnings.warn(\n", 34 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 35 | " warnings.warn(\n", 36 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 37 | " warnings.warn(\n", 38 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 39 | " warnings.warn(\n", 40 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 41 | " warnings.warn(\n", 42 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 43 | " warnings.warn(\n" 44 | ] 45 | }, 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "10 1631\n", 51 | "None\n", 52 | "(1369, 100)\n" 53 | ] 54 | }, 55 | { 56 | "name": "stderr", 57 | "output_type": "stream", 58 | "text": [ 59 | "/Users/weijiexu/Library/Python/3.9/lib/python/site-packages/torch/distributions/distribution.py:51: UserWarning: does not define `arg_constraints`. Please set `arg_constraints = {}` or initialize the distribution with `validate_args=False` to turn off validation.\n", 60 | " warnings.warn(f'{self.__class__} does not define `arg_constraints`. ' +\n", 61 | "/Users/weijiexu/workspace/newd/vONTSS/examples/../src/topicmodeling/model.py:248: UserWarning: Implicit dimension choice for softmax has been deprecated. 
Change the call to include dim=X as an argument.\n", 62 | " z = self.h_to_z(self.temp * z)\n" 63 | ] 64 | }, 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "354.1831285953522 0.013702268421184272\n", 70 | "314.8430812358856 0.013719041293370537\n", 71 | "260.29898738861084 0.013856412741006352\n", 72 | "226.1326608657837 0.014872243540594354\n", 73 | "209.88615798950195 0.019702297504409216\n", 74 | "202.78133893013 0.025992500450229272\n", 75 | "200.1721682548523 0.02665090240770951\n", 76 | "198.93756580352783 0.025329961848910898\n", 77 | "197.56569409370422 0.024662632873514667\n", 78 | "196.31796407699585 0.024598620628239587\n", 79 | "195.56219696998596 0.024520813080016524\n", 80 | "194.72087502479553 0.024929481995059177\n", 81 | "193.97954988479614 0.024810771603370085\n", 82 | "193.29074358940125 0.024626206315588206\n", 83 | "192.86098980903625 0.024751044693402946\n", 84 | "192.43597888946533 0.024925754143623635\n", 85 | "192.1820297241211 0.02504920243518427\n", 86 | "191.86085605621338 0.025005116767715663\n", 87 | "191.89765119552612 0.024898336792830378\n", 88 | "191.8501913547516 0.024892392160836607\n" 89 | ] 90 | }, 91 | { 92 | "name": "stderr", 93 | "output_type": "stream", 94 | "text": [ 95 | "/Users/weijiexu/workspace/newd/vONTSS/examples/../src/topicmodeling/model.py:444: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", 96 | " self.z = self.model.h_to_z(z_mean).detach().numpy()\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "from octis.dataset.dataset import Dataset\n", 102 | "import sys\n", 103 | "sys.path.insert(0, '../src/topicmodeling')\n", 104 | "from model import TopicModel\n", 105 | "from datasets import load_dataset\n", 106 | "from octis.evaluation_metrics.diversity_metrics import TopicDiversity\n", 107 | "from octis.evaluation_metrics.coherence_metrics import Coherence\n", 108 | "\n", 109 | "\n", 110 | "dataset = Dataset()\n", 111 | "dataset.fetch_dataset(\"20NewsGroup\")\n", 112 | "tm = TopicModel(numb_embeddings = 10)\n", 113 | "texts = [' '.join(i) for i in dataset.get_corpus()]\n", 114 | "model_output = tm.train_model(texts)\n", 115 | "\n" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 27, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "from octis.evaluation_metrics.diversity_metrics import TopicDiversity\n", 125 | "from octis.evaluation_metrics.coherence_metrics import Coherence\n", 126 | "metric = TopicDiversity(topk=10)\n", 127 | "topic_diversity_score = metric.score(model_output) # Compute score of the metric\n", 128 | "cmetric = Coherence(texts = tm.tp.lemmas, measure='c_npmi')\n", 129 | "coherence = cmetric.score(model_output) # Compute score of the metri" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 28, 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "data": { 139 | "text/plain": [ 140 | "0.85" 141 | ] 142 | }, 143 | "execution_count": 28, 144 | "metadata": {}, 145 | "output_type": "execute_result" 146 | } 147 | ], 148 | "source": [ 149 | "topic_diversity_score" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 29, 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "data": { 159 | "text/plain": [ 160 | "-0.025114232519858597" 161 | ] 162 | }, 163 | "execution_count": 29, 164 | "metadata": {}, 165 | "output_type": "execute_result" 166 | } 167 | ], 168 | "source": [ 169 | "coherence " 170 | ] 171 | }, 172 | { 173 | 
"cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [] 178 | } 179 | ], 180 | "metadata": { 181 | "kernelspec": { 182 | "display_name": ".venv", 183 | "language": "python", 184 | "name": "python3" 185 | }, 186 | "language_info": { 187 | "codemirror_mode": { 188 | "name": "ipython", 189 | "version": 3 190 | }, 191 | "file_extension": ".py", 192 | "mimetype": "text/x-python", 193 | "name": "python", 194 | "nbconvert_exporter": "python", 195 | "pygments_lexer": "ipython3", 196 | "version": "3.9.6" 197 | }, 198 | "orig_nbformat": 4 199 | }, 200 | "nbformat": 4, 201 | "nbformat_minor": 2 202 | } 203 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | sys.path.insert(0, 'src/topicmodeling') 4 | from model import TopicModel 5 | from datasets import load_dataset 6 | from octis.evaluation_metrics.diversity_metrics import TopicDiversity 7 | from octis.evaluation_metrics.coherence_metrics import Coherence 8 | import pandas as pd 9 | # Function to load data from different sources 10 | def load_data(source, path=''): 11 | if source == 'huggingface': 12 | return load_dataset(path)['train']['text'][:1000] 13 | elif source == 'csv': 14 | df = pd.read_csv(path) 15 | return df['text'].tolist() 16 | elif source == 'txt': 17 | with open(path, 'r') as f: 18 | return f.readlines() 19 | else: 20 | raise ValueError("Invalid data source") 21 | 22 | def evaluate_model(model_output, texts, metric, topk = 10): 23 | if metric != 'diversity': 24 | metric = Coherence(texts=texts, measure=metric) 25 | metric_result = metric.score(model_output) 26 | else: 27 | metric = TopicDiversity(topk=topk) 28 | metric_result = metric.score(model_output) 29 | return metric_result 30 | 31 | # Initialize argparse 32 | parser = argparse.ArgumentParser(description='Topic Modeling Script') 33 | 34 | # Add arguments 35 | parser.add_argument('--numb_embeddings', type=int, default=10, help='Number of embeddings') 36 | parser.add_argument('--epochs', type=int, default=20, help='Number of epochs') 37 | parser.add_argument('--batch_size', type=int, default=256, help='Batch size') 38 | parser.add_argument('--gpu_num', type=int, default=1, help='GPU number') 39 | parser.add_argument('--learning_rate', type=float, default=0.002, help='Learning rate') 40 | parser.add_argument('--weight_decay', type=float, default=1.2e-6, help='Weight decay') 41 | parser.add_argument('--penalty', type=int, default=1, help='Penalty') 42 | parser.add_argument('--beta', type=int, default=1, help='Beta') 43 | parser.add_argument('--temp', type=int, default=10, help='Temperature') 44 | parser.add_argument('--top_n_words', type=int, default=20, help='Top N words') 45 | parser.add_argument('--num_representative_docs', type=int, default=5, help='Number of representative documents') 46 | parser.add_argument('--top_n_topics', type=int, default=100, help='Top N topics') 47 | parser.add_argument('--embedding_dim', type=int, default=100, help='Embedding dimension') 48 | parser.add_argument('--data_source', type=str, default='huggingface', help='Data source type: huggingface, csv, txt') 49 | parser.add_argument('--data_path', type=str, default='', help='Path to the data file for csv or txt') 50 | parser.add_argument('--metrics', nargs='+', default=['diversity', 'c_v', 'c_npmi', 'c_uci', 'u_mass'], help='List of metrics to report') 51 | parser.add_argument('--topk', 
type=int, default=10, help='top k words to report for diversity')
52 | 
53 | if __name__ == '__main__':
54 |     # Parse arguments
55 |     args = parser.parse_args()
56 | 
57 | 
58 | 
59 |     # Load dataset
60 |     print(args.data_source)
61 |     df_text = load_data(args.data_source, args.data_path)
62 | 
63 |     # Initialize and train the model
64 |     tm = TopicModel(
65 |         numb_embeddings=args.numb_embeddings,
66 |         epochs=args.epochs,
67 |         batch_size=args.batch_size,
68 |         gpu_num=args.gpu_num,
69 |         learning_rate=args.learning_rate,
70 |         weight_decay=args.weight_decay,
71 |         penalty=args.penalty,
72 |         beta=args.beta,
73 |         temp=args.temp,
74 |         top_n_words=args.top_n_words,
75 |         num_representative_docs=args.num_representative_docs,
76 |         top_n_topics=args.top_n_topics,
77 |         embedding_dim=args.embedding_dim
78 |     )
79 | 
80 |     model_output = tm.train_model(df_text, args.data_path.replace('/', '_'))
81 | 
82 |     scores = []
83 |     # evaluation
84 |     for metric in args.metrics:
85 |         score = evaluate_model(model_output, tm.tp.lemmas, metric)
86 |         scores.append(score)
87 |         print(metric + ' is ' + str(score))
88 | 
89 | 
90 |     print(scores)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets==2.11.0
2 | torch==2.2.0
3 | scikit-learn==1.5.0
4 | nltk==3.9
5 | gensim==4.3.1
6 | torchvision==0.17.0
7 | octis==1.13.1
8 | # cudatoolkit is a conda package and cannot be installed with pip
9 | 
10 | huggingface-hub==0.16.4
11 | 
12 | # additional packages imported by main.py and src/:
13 | pandas
14 | transformers
15 | diffusers
16 | accelerate
17 | tqdm
--------------------------------------------------------------------------------
/src/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/diffusion/__init__.py
--------------------------------------------------------------------------------
/src/diffusion/diffuser_training.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 | sys.path.append('../src/diffusion')
4 | 
5 | from diffusion.diffusion import Autoencoder, UNet, UNetConv
6 | from dataclasses import dataclass
7 | import torch
8 | import math
9 | 
10 | import numpy as np
11 | 
12 | from torch.utils.data import Dataset, DataLoader
13 | from diffusers import DDPMScheduler
14 | from diffusers.optimization import get_cosine_schedule_with_warmup
15 | # Step 3: optimize reverse sampling
16 | from accelerate import Accelerator
17 | from huggingface_hub import HfFolder, Repository, whoami
18 | import argparse
19 | from tqdm.auto import tqdm
20 | from pathlib import Path
21 | import os
22 | from accelerate import notebook_launcher
23 | import torch.nn.functional as F
24 | 
25 | class NumpyArrayDataset(Dataset):
26 |     # Wraps the embedding array so a DataLoader can batch it
27 |     def __init__(self, numpy_array):
28 |         self.data = numpy_array
29 | 
30 |         self.shape = self.data.shape[1]
31 |     def __len__(self):
32 |         return len(self.data)
33 | 
34 |     def __getitem__(self, idx):
35 |         return self.data[idx]
36 | 
37 | def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
38 |     # Initialize accelerator and tensorboard logging
39 |     print(os.path.join(config.output_dir, "logs"))
40 |     accelerator = Accelerator(
41 |         mixed_precision=config.mixed_precision,
42 |         gradient_accumulation_steps=config.gradient_accumulation_steps,
43 |         log_with="tensorboard",
44 | 
project_dir=os.path.join(config.output_dir, "logs") 45 | 46 | ) 47 | if accelerator.is_main_process: 48 | accelerator.init_trackers("train_example") 49 | 50 | # Prepare everything 51 | # There is no specific order to remember, you just need to unpack the 52 | # objects in the same order you gave them to the prepare method. 53 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 54 | model, optimizer, train_dataloader, lr_scheduler 55 | ) 56 | 57 | global_step = 0 58 | 59 | # Now you train the model 60 | for epoch in range(config.num_epochs): 61 | progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process) 62 | progress_bar.set_description(f"Epoch {epoch}") 63 | 64 | for step, batch in enumerate(train_dataloader): 65 | clean_images = batch 66 | # Sample noise to add to the images 67 | noise = torch.randn(clean_images.shape).to(clean_images.device) 68 | bs = clean_images.shape[0] 69 | 70 | # Sample a random timestep for each image 71 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long() 72 | 73 | # Add noise to the clean images according to the noise magnitude at each timestep 74 | # (this is the forward diffusion process) 75 | noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps) 76 | 77 | with accelerator.accumulate(model): 78 | # Predict the noise residual 79 | noise_pred = model(noisy_images, timesteps) 80 | loss = F.mse_loss(noise_pred, noise) 81 | accelerator.backward(loss) 82 | 83 | accelerator.clip_grad_norm_(model.parameters(), 1.0) 84 | optimizer.step() 85 | lr_scheduler.step() 86 | optimizer.zero_grad() 87 | 88 | progress_bar.update(1) 89 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} 90 | progress_bar.set_postfix(**logs) 91 | accelerator.log(logs, step=global_step) 92 | global_step += 1 93 | return model 94 | 95 | 96 | 97 | if __name__ == "__main__": 98 | 99 | parser = argparse.ArgumentParser() 100 | 101 | parser.add_argument("--train_batch_size", type=int, default=10) 102 | parser.add_argument("--eval_batch_size", type=int, default=10) 103 | parser.add_argument("--num_epochs", type=int, default=30) 104 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 105 | parser.add_argument("--learning_rate", type=float, default=1e-4) 106 | parser.add_argument("--lr_warmup_steps", type=int, default=500) 107 | parser.add_argument("--mixed_precision", type=str, default='fp16') 108 | parser.add_argument("--model_name", type=str, default='VAE') 109 | parser.add_argument("--hidden_dim", type=int, default=256) 110 | parser.add_argument("--embedding_input", type=str, default='test') 111 | parser.add_argument("--output_dir", type=str, default='.') 112 | 113 | args, _ = parser.parse_known_args() 114 | 115 | if args.embedding_input == 'test': 116 | generated_embeddings = torch.ones(100000, 768 * 2) #The shape (# of samples, # of dimensions) 117 | else: 118 | generated_embeddings = torch.tensor(torch.load(args.embedding_input)) 119 | 120 | if args.model_name == 'VAE': 121 | model = Autoencoder(generated_embeddings.shape[1], args.hidden_dim) 122 | elif args.model_name == 'UNet': 123 | model = UNet(generated_embeddings.shape[1], args.hidden_dim) 124 | elif args.model_name == 'UNet_Conv': 125 | model = UNetConv(generated_embeddings.shape[2], generated_embeddings.shape[1], generated_embeddings.shape[1]) 126 | 127 | 128 | 129 | dataset = NumpyArrayDataset(generated_embeddings) 130 | # create a 
dataloader for the dataset
131 | dataloader = DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
132 | 
133 | # build the training config from the parsed arguments
134 | class TrainingConfig:
135 |     image_size = generated_embeddings.shape[1]  # based on generated embeddings
136 |     train_batch_size = args.train_batch_size
137 |     eval_batch_size = args.eval_batch_size  # how many images to sample during evaluation
138 |     num_epochs = args.num_epochs
139 |     gradient_accumulation_steps = args.gradient_accumulation_steps
140 |     learning_rate = args.learning_rate
141 |     lr_warmup_steps = args.lr_warmup_steps
142 |     mixed_precision = args.mixed_precision  # `no` for float32, `fp16` for automatic mixed precision
143 |     overwrite_output_dir = True  # overwrite the old model when re-running the notebook
144 |     seed = 0
145 |     output_dir = args.output_dir
146 | 
147 | 
148 | config = TrainingConfig()
149 | 
150 | 
151 | noise_scheduler = DDPMScheduler(num_train_timesteps=100)
152 | # Define the optimizer
153 | optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
154 | 
155 | # Learning rate scheduler
156 | lr_scheduler = get_cosine_schedule_with_warmup(
157 |     optimizer=optimizer,
158 |     num_warmup_steps=config.lr_warmup_steps,
159 |     num_training_steps=(len(dataloader) * config.num_epochs),
160 | )
161 | 
162 | arg = (config, model, noise_scheduler, optimizer, dataloader, lr_scheduler)
163 | 
164 | notebook_launcher(train_loop, arg, num_processes=1)
165 | # Save the trained model's state dict
166 | output_dir = args.output_dir
167 | embedding_input = args.embedding_input
168 | model_name = args.model_name
169 | torch.save(model.state_dict(), f'{output_dir}/diffusion_model.pt')
170 | 
--------------------------------------------------------------------------------
/src/diffusion/diffusion.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | 
5 | 
6 | 
7 | from torch.utils.data import Dataset, DataLoader
8 | # Models that take the noised embedding plus a timestep as input and predict the noise
9 | 
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | import math
13 | 
14 | class SinusoidalPositionEmbeddings(nn.Module):
15 |     def __init__(self, dim):
16 |         super().__init__()
17 |         self.dim = dim
18 | 
19 |     def forward(self, time):
20 |         device = time.device
21 |         half_dim = self.dim // 2
22 |         embeddings = math.log(10000) / (half_dim - 1)
23 |         embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
24 |         embeddings = time[:, None] * embeddings[None, :]
25 |         embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
26 |         return embeddings
27 | 
28 | # The basic autoencoder adds a sinusoidal time embedding to its input before encoding
29 | class Autoencoder(nn.Module):
30 |     def __init__(self, input_dim, hidden_dim):
31 |         super(Autoencoder, self).__init__()
32 |         self.embedding = SinusoidalPositionEmbeddings(input_dim)  # sinusoidal embedding of the integer timestep
33 |         self.encoder = nn.Sequential(
34 |             nn.Linear(input_dim, hidden_dim),
35 | 
36 |             nn.Linear(hidden_dim, hidden_dim//2),
37 | 
38 |             nn.Linear(hidden_dim//2, hidden_dim//4),
39 | 
40 |         )
41 |         self.decoder = nn.Sequential(
42 | 
43 |             nn.Linear(hidden_dim//4, hidden_dim//2),
44 | 
45 |             nn.Linear(hidden_dim//2, hidden_dim),
46 | 
47 |             nn.Linear(hidden_dim, input_dim),
48 | 
49 |         )
50 | 
51 |     def forward(self, x, t):
52 |         t_embed = self.embedding(t)
53 |         #print(x.shape, t.shape, t_embed.shape)
54 |         x = x + t_embed
55 |         x = self.encoder(x)
56 |         x = self.decoder(x)
57 |         return x
58 | 
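# A minimal shape sketch for the Autoencoder above, kept as a comment so that
# importing this module stays side-effect free. The sizes here are
# illustrative assumptions, not values taken from the training script:
#
#   model = Autoencoder(input_dim=1536, hidden_dim=256)
#   x = torch.randn(8, 1536)                 # a batch of noised embeddings
#   t = torch.randint(0, 100, (8,))          # one diffusion timestep per sample
#   assert model(x, t).shape == (8, 1536)    # predicted noise, same shape as x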
59 | class UNet(nn.Module):
60 |     def __init__(self, embedding_dim, hidden_dim):
61 |         super(UNet, self).__init__()
62 |         self.embedding = SinusoidalPositionEmbeddings(embedding_dim)  # sinusoidal embedding of the integer timestep
63 |         self.down1 = nn.Linear(embedding_dim , hidden_dim)
64 |         self.down2 = nn.Linear(hidden_dim, hidden_dim // 2)
65 |         self.bottom = nn.Linear(hidden_dim // 2, hidden_dim // 2)
66 |         self.up1 = nn.Linear(hidden_dim // 2, hidden_dim)
67 |         self.up2 = nn.Linear(hidden_dim, embedding_dim)
68 | 
69 |     def forward(self, x, t):
70 |         t_embed = self.embedding(t)
71 |         #print(x.shape, t.shape, t_embed.shape)
72 |         #x = torch.cat([x, t_embed], dim = 1)
73 |         x = x + t_embed
74 |         x1 = self.down1(x)
75 |         x2 = self.down2(x1)
76 |         x3 = self.bottom(x2)
77 |         x4 = self.up1(x3 + x2)  # Skip connection
78 |         x5 = self.up2(x4 + x1)  # Skip connection
79 |         return x5
80 | 
81 | 
82 | class FFC(nn.Module):
83 |     """Linear => BatchNorm => ReLU block"""
84 |     def __init__(self, input_dim, output_dim, **kwargs) -> None:
85 |         super().__init__(**kwargs)
86 |         self.ffc = nn.Sequential(
87 |             nn.Linear(input_dim, output_dim),
88 |             nn.BatchNorm1d(output_dim),
89 |             nn.ReLU(inplace=True),
90 |         )
91 | 
92 |     def forward(self, x):
93 |         x = self.ffc(x)
94 |         return x
95 | 
96 | class UNet_1(nn.Module):
97 |     def __init__(self, embedding_dim, hidden_dim):
98 |         super(UNet_1, self).__init__()
99 |         self.embedding = SinusoidalPositionEmbeddings(embedding_dim)  # sinusoidal embedding of the integer timestep
100 |         # self.down1 = nn.Linear(embedding_dim , hidden_dim)
101 |         # self.down2 = nn.Linear(hidden_dim, hidden_dim // 2)
102 |         # self.bottom = nn.Linear(hidden_dim // 2, hidden_dim // 2)
103 |         # self.up1 = nn.Linear(hidden_dim // 2, hidden_dim)
104 |         # self.up2 = nn.Linear(hidden_dim, embedding_dim)
105 | 
106 |         self.down1 = FFC(embedding_dim, hidden_dim)
107 |         self.down2 = FFC(hidden_dim, hidden_dim // 2)
108 |         self.bottom = FFC(hidden_dim // 2, hidden_dim // 2)
109 |         self.up1 = FFC(hidden_dim // 2, hidden_dim)
110 |         self.up2 = FFC(hidden_dim, embedding_dim)
111 | 
112 |     def forward(self, x, t):
113 |         t_embed = self.embedding(t)
114 |         #print(x.shape, t.shape, t_embed.shape)
115 |         #x = torch.cat([x, t_embed], dim = 1)
116 |         x = x + t_embed
117 |         x1 = self.down1(x)
118 |         x2 = self.down2(x1)
119 |         x3 = self.bottom(x2)
120 |         x4 = self.up1(x3 + x2)  # Skip connection
121 |         x5 = self.up2(x4 + x1)  # Skip connection
122 |         return x5
123 | 
124 | # based on the https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py
125 | 
126 | class DoubleConv(nn.Module):
127 |     """(convolution => [BN] => ReLU) * 2"""
128 | 
129 |     def __init__(self, in_channels, out_channels, mid_channels=None):
130 |         super().__init__()
131 |         if not mid_channels:
132 |             mid_channels = out_channels
133 |         self.double_conv = nn.Sequential(
134 |             nn.Conv1d(in_channels, mid_channels, kernel_size=3, padding=1, stride=1, bias=False),
135 |             nn.BatchNorm1d(mid_channels),
136 |             nn.ReLU(inplace=True),
137 |             nn.Conv1d(mid_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
138 |             nn.BatchNorm1d(out_channels),
139 |             nn.ReLU(inplace=True)
140 |         )
141 | 
142 |     def forward(self, x):
143 |         return self.double_conv(x)
144 | 
145 | 
146 | class Down(nn.Module):
147 |     """Downscaling with maxpool then double conv"""
148 | 
149 |     def __init__(self, in_channels, out_channels):
150 |         super().__init__()
151 |         self.maxpool_conv = nn.Sequential(
152 |             nn.MaxPool1d(2),
153 |             DoubleConv(in_channels, out_channels)
154 |         )
155 | 
156 |     def forward(self, x):
157 |         return self.maxpool_conv(x)
158 | 
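# A minimal shape sketch for the 1-D building blocks above (a comment-only
# example; the channel counts and the length of 32 are illustrative assumptions):
#
#   block = DoubleConv(in_channels=64, out_channels=128)
#   assert block(torch.randn(8, 64, 32)).shape == (8, 128, 32)   # length preserved
#   down = Down(64, 128)                                         # MaxPool1d halves the length
#   assert down(torch.randn(8, 64, 32)).shape == (8, 128, 16)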
159 | 
160 | class Up(nn.Module):
161 |     """Upscaling then double conv"""
162 | 
163 |     def __init__(self, in_channels, out_channels, bilinear=True):
164 |         super().__init__()
165 | 
166 |         # if bilinear, use the normal convolutions to reduce the number of channels
167 |         if bilinear:
168 |             self.up = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)
169 |             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
170 |         else:
171 |             self.up = nn.ConvTranspose1d(in_channels, in_channels // 2, kernel_size=2, stride=2)
172 |             self.conv = DoubleConv(in_channels, out_channels)
173 | 
174 |     def forward(self, x1, x2):
175 |         x1 = self.up(x1)
176 |         # inputs are (N, C, L); pad the length dimension so the skip connection lines up
177 |         diffY = x2.size()[2] - x1.size()[2]
178 |         # diffX = x2.size()[3] - x1.size()[3]
179 | 
180 |         x1 = F.pad(x1, [diffY // 2, diffY - diffY // 2])
181 |         # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
182 |         #                 diffY // 2, diffY - diffY // 2])
183 |         # if you have padding issues, see
184 |         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
185 |         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
186 |         x = torch.cat([x2, x1], dim=1)
187 |         return self.conv(x)
188 | 
189 | 
190 | class OutConv(nn.Module):
191 |     def __init__(self, in_channels, out_channels):
192 |         super(OutConv, self).__init__()
193 |         self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
194 | 
195 |     def forward(self, x):
196 |         return self.conv(x)
197 | 
198 | 
199 | class UNet_1_0(nn.Module):
200 |     def __init__(self, embedding_dim, hidden_dim):
201 |         super(UNet_1_0, self).__init__()
202 |         self.embedding = SinusoidalPositionEmbeddings(embedding_dim)  # sinusoidal embedding of the integer timestep
203 |         # self.down1 = nn.Linear(embedding_dim , hidden_dim)
204 |         # self.down2 = nn.Linear(hidden_dim, hidden_dim // 2)
205 |         # self.bottom = nn.Linear(hidden_dim // 2, hidden_dim // 2)
206 |         # self.up1 = nn.Linear(hidden_dim // 2, hidden_dim)
207 |         # self.up2 = nn.Linear(hidden_dim, embedding_dim)
208 | 
209 |         self.down1 = DoubleConv(embedding_dim, hidden_dim)
210 |         self.down2 = DoubleConv(hidden_dim, hidden_dim // 2)
211 |         self.bottom = DoubleConv(hidden_dim // 2, hidden_dim // 2)
212 |         self.up1 = DoubleConv(hidden_dim // 2, hidden_dim)
213 |         self.up2 = DoubleConv(hidden_dim, embedding_dim)
214 | 
215 |     def forward(self, x, t):
216 |         t_embed = self.embedding(t)
217 |         #print(x.shape, t.shape, t_embed.shape)
218 |         #x = torch.cat([x, t_embed], dim = 1)
219 |         x = x + t_embed
220 |         x1 = self.down1(x)
221 |         x2 = self.down2(x1)
222 |         x3 = self.bottom(x2)
223 |         x4 = self.up1(x3 + x2)  # Skip connection
224 |         x5 = self.up2(x4 + x1)  # Skip connection
225 |         return x5
226 | 
227 | class UNetConv(nn.Module):
228 |     def __init__(self, embedding_dim, n_channels, n_classes, bilinear=False):
229 |         super(UNetConv, self).__init__()
230 | 
231 |         self.embedding = SinusoidalPositionEmbeddings(embedding_dim)  # sinusoidal embedding of the integer timestep
232 | 
233 |         self.n_channels = n_channels
234 |         self.n_classes = n_classes
235 |         self.bilinear = bilinear
236 | 
237 |         self.inc = (DoubleConv(n_channels, 64))
238 |         self.down1 = (Down(64, 128))
239 |         self.down2 = (Down(128, 256))
240 |         self.down3 = (Down(256, 512))
241 |         factor = 2 if bilinear else 1
242 |         self.down4 = (Down(512, 1024 // factor))
243 |         self.up1 = (Up(1024, 512 // factor, bilinear))
244 |         self.up2 = (Up(512, 256 // factor, bilinear))
245 |         self.up3 = (Up(256, 128 // factor, bilinear))
246 |         self.up4 = (Up(128, 64, bilinear))
247 |         self.outc = (OutConv(64, n_classes))
248 | 
249 |     def forward(self, x, t):
250 | 
251 |         t_embed = self.embedding(t)
252 |         # add time embed for each channel
253 |         t_embed = t_embed.unsqueeze(1)
254 | 
255 |         x = x + t_embed
256 | 
257 |         x1 = self.inc(x)
258 |         x2 = self.down1(x1)
259 |         x3 = self.down2(x2)
260 |         x4 = self.down3(x3)
261 |         x5 = self.down4(x4)
262 |         x = self.up1(x5, x4)
263 |         x = self.up2(x, x3)
264 |         x = self.up3(x, x2)
265 |         x = self.up4(x, x1)
266 |         logits = self.outc(x)
267 | 
268 |         return logits
269 | 
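# A small smoke test for the fully connected UNet variant. It only runs when
# this file is executed directly, and the dimensions below are illustrative
# assumptions rather than values taken from the training script.
if __name__ == "__main__":
    model = UNet(embedding_dim=1536, hidden_dim=256)
    x = torch.randn(4, 1536)           # a batch of noised embeddings
    t = torch.randint(0, 100, (4,))    # a diffusion timestep per sample
    noise_pred = model(x, t)
    print(noise_pred.shape)            # torch.Size([4, 1536]), same shape as the input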
--------------------------------------------------------------------------------
/src/diffusion/diffusion_generate.py:
--------------------------------------------------------------------------------
1 | '''
2 | This module covers the functions used to denoise embeddings.
3 | '''
4 | 
5 | import sys
6 | sys.path.append('..')
7 | sys.path.append('../src/')
8 | 
9 | import torch
10 | from tqdm import tqdm
11 | from diffusion.diffuser_training import DDPMScheduler
12 | 
13 | 
14 | def denoise(x, model, timesteps, noise_scheduler, device):
15 | 
16 |     x = x.to(device)
17 |     model = model.to(device)
18 | 
19 |     num_sample = x.shape[0]
20 | 
21 |     x_track_ls = []
22 |     x_0_track_ls = []
23 | 
24 |     for time_step in tqdm(iterable=reversed(range(0, timesteps)),
25 |                           total=timesteps, dynamic_ncols=False,
26 |                           desc="Sampling :: ", position=0):
27 | 
28 |         ts = torch.ones(num_sample, dtype=torch.long, device=device) * time_step
29 |         # z = torch.randn_like(x) if time_step > 1 else torch.zeros_like(x)
30 |         # print(get_gpu_memory()) # in bytes
31 |         predicted_noise = model(x, ts)
32 |         # print(get_gpu_memory()) # in bytes
33 | 
34 |         # detach x and clean the cache
35 |         x = x.detach().cpu()
36 |         predicted_noise = predicted_noise.detach().cpu()
37 | 
38 |         # gc.collect()
39 |         # torch.cuda.empty_cache()
40 | 
41 |         res = noise_scheduler.step(predicted_noise, time_step, x)
42 | 
43 |         x, x_original = res['prev_sample'], res['pred_original_sample']
44 | 
45 |         # alternative: ask the scheduler to only return the sample
46 |         # res = noise_scheduler.step(predicted_noise, time_step, x, return_dict=False)
47 |         # x = res[0]
48 | 
49 |         # print(get_gpu_memory()) # in bytes
50 |         #print(process.memory_info().rss)
51 |         # gc.collect()
52 |         # torch.cuda.empty_cache()
53 |         #print(process.memory_info().rss) # in bytes
54 |         x_0_track_ls.append(x_original.detach().cpu())
55 |         x_track_ls.append(x.detach().cpu())
56 | 
57 |         x = x.to(device)
58 | 
59 |     return x_track_ls, x_0_track_ls
60 | 
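# A minimal usage sketch for denoise() (comment-only; the model choice and
# the shapes are illustrative assumptions):
#
#   from diffusion.diffusion import UNet
#   scheduler = DDPMScheduler(num_train_timesteps=100)
#   model = UNet(embedding_dim=1536, hidden_dim=256)
#   x = torch.randn(4, 1536)   # start the reverse process from pure noise
#   x_track, x_0_track = denoise(x, model, 100, scheduler, device="cpu")
#   len(x_track)               # == 100, one partially denoised batch per timestep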
61 | def generate_diffused_embed(embed_ls, diffuser, timesteps, device, batch_size=4,
62 |                             num_generated_sample=10, return_all_time_embed=False):
63 | 
64 |     # for each embed, we are going to generate several diffused embeddings
65 | 
66 |     noise_scheduler = DDPMScheduler(num_train_timesteps=timesteps)
67 | 
68 |     # embed in batch
69 |     embed_ls_batch = [embed_ls[i:i+batch_size] for i in range(0, len(embed_ls), batch_size)]
70 | 
71 |     x_track_final_ls, x_0_track_final_ls = [], []
72 | 
73 |     for _ in range(num_generated_sample):
74 | 
75 |         x_track_ls_ls, x_0_track_ls_ls = [], []
76 | 
77 |         for i in range(len(embed_ls_batch)):
78 | 
79 |             x_diff_input = embed_ls_batch[i]
80 | 
81 |             # x_diff_input = torch.concat(x_diff_input, dim=0)
82 |             x_track_ls, x_0_track_ls = denoise(x_diff_input, diffuser, timesteps, noise_scheduler, device=device)
83 | 
84 |             x_track_ls = [x.unsqueeze(0) for x in x_track_ls]
85 |             x_0_track_ls = [x.unsqueeze(0) for x in x_0_track_ls]
86 | 
87 |             x_track_ls = torch.concat(x_track_ls, dim=0)  # t * sub_bz * dim
88 |             x_0_track_ls = torch.concat(x_0_track_ls, dim=0)
89 | 
90 |             x_track_ls_ls.append(x_track_ls)  # t * sub_bz * dim
91 |             x_0_track_ls_ls.append(x_0_track_ls)
92 | 
93 |         # get the full batch embed
94 |         x_track_ls_ls = torch.concat(x_track_ls_ls, dim=1)  # t * bz * dim
95 |         x_0_track_ls_ls = torch.concat(x_0_track_ls_ls, dim=1)  # t * bz * dim
96 | 
97 |         x_track_final_ls.append(x_track_ls_ls.unsqueeze(dim=0))  # 1 * t * bz * dim
98 |         x_0_track_final_ls.append(x_0_track_ls_ls.unsqueeze(dim=0))  # 1 * t * bz * dim
99 | 
100 |     x_track_final_ls = torch.concat(x_track_final_ls, dim=0)  # num_generated_sample * t * bz * dim
101 |     x_0_track_final_ls = torch.concat(x_0_track_final_ls, dim=0)  # num_generated_sample * t * bz * dim
102 | 
103 |     # get the embed after the full diffusion
104 |     x_track_time_aft_diff = x_track_final_ls[:, -1, :, :]
105 | 
106 |     if return_all_time_embed:
107 |         return x_track_final_ls, x_0_track_final_ls
108 | 
109 |     return x_track_time_aft_diff
110 | 
111 | 
112 | def generate_text(x, models, tokenizer, max_length, device, sub_batch, **generate_kwargs):
113 | 
114 |     # split the batch into sub-batches
115 |     x_list = x
116 |     if sub_batch:
117 |         x_list = [x[i:i+sub_batch] for i in range(0, len(x), sub_batch)]
118 | 
119 |     output_text_ls = []
120 | 
121 |     for x_latent in x_list:
122 |         # decode the latent back to encoder-state space
123 |         # x_latent = x
124 |         if len(x_latent.shape) == 2:
125 |             x_latent = x_latent.reshape(-1, 4, 768)  # (batch, hidden_size3, d_model); 4 and 768 are hard-coded here
126 | 
127 |         # encoded vector
128 |         x_latent = x_latent.to(device)
129 |         output = models.decoder.to(device)(x_latent)
130 | 
131 | 
132 |         # generate dummy input ids with the same shape as the encoder output
133 |         inputs = torch.ones(output.shape, dtype=torch.int).to(device)
134 | 
135 |         sent_outputs = models.model.generate(
136 |             input_ids=inputs,
137 |             encoder_outputs={0: output, },  # in order to use other decoder strategy
138 |             max_length=max_length, **generate_kwargs)
139 | 
140 |         outputs_text = tokenizer.batch_decode(
141 |             sent_outputs.detach().cpu(), skip_special_tokens=True)
142 | 
143 |         output_text_ls.extend(outputs_text)
144 | 
145 |     return output_text_ls
--------------------------------------------------------------------------------
/src/huggingface/__pycache__/configuration_detime.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/huggingface/__pycache__/configuration_detime.cpython-310.pyc
--------------------------------------------------------------------------------
/src/huggingface/__pycache__/modeling_detime.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/huggingface/__pycache__/modeling_detime.cpython-310.pyc
--------------------------------------------------------------------------------
/src/huggingface/detime/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/huggingface/detime/__init__.py
--------------------------------------------------------------------------------
/src/huggingface/detime/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/huggingface/detime/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/huggingface/detime/__pycache__/configuration_detime.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/huggingface/detime/__pycache__/configuration_detime.cpython-310.pyc -------------------------------------------------------------------------------- /src/huggingface/detime/__pycache__/modeling_detime.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/huggingface/detime/__pycache__/modeling_detime.cpython-310.pyc -------------------------------------------------------------------------------- /src/huggingface/detime/configuration_detime.py: -------------------------------------------------------------------------------- 1 | 2 | from transformers import T5Config, PretrainedConfig 3 | from typing import List 4 | 5 | 6 | # define Flan-T5 nest CNN autoencoder here 7 | class DeTiMEAutoConfig(T5Config): 8 | model_type = "detime" 9 | 10 | def __init__( 11 | self, 12 | hidden_size1: int = 512, 13 | hidden_size3: int = 512, 14 | num_layer: int = 1, 15 | dropout: float = 0.1, 16 | max_length: int = 512, 17 | model_name: str = None, 18 | **kwargs, 19 | ): 20 | self.hidden_size1 = hidden_size1 21 | self.hidden_size3 = hidden_size3 22 | self.num_layer = num_layer 23 | self.dropout = dropout 24 | self.max_length = max_length 25 | self.model_name = model_name 26 | super().__init__(**kwargs) -------------------------------------------------------------------------------- /src/huggingface/detime/modeling_detime.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module includes all the classes and functions for the nested autoencoder. 
3 | """ 4 | 5 | from transformers import PreTrainedModel 6 | from transformers import T5ForConditionalGeneration, AutoModelForSeq2SeqLM 7 | import datasets 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn 11 | import random 12 | import os 13 | from configuration_detime import DeTiMEAutoConfig 14 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 15 | 16 | 17 | 18 | # Define the CNN encoder and decoder model 19 | class CNNEncoder(nn.Module): 20 | def __init__(self, hidden_size1, hidden_size3): 21 | super().__init__() 22 | # Define the encoder 23 | self.encoder = nn.Sequential( 24 | nn.Conv1d(in_channels=hidden_size1, out_channels=128, kernel_size=3, stride=1, padding=1), 25 | nn.ReLU(), 26 | nn.Conv1d(in_channels=128, out_channels=16, kernel_size=3, stride=1, padding=1), 27 | nn.ReLU(), 28 | # nn.Conv1d(in_channels=16, out_channels=4, kernel_size=3, stride=1, padding=1), 29 | # nn.ReLU(), 30 | nn.Conv1d(in_channels=16, out_channels=hidden_size3, kernel_size=3, stride=1, padding=1) 31 | ) 32 | 33 | def forward(self, x): 34 | # x = x.permute(0, 2, 1) 35 | # Encode the input 36 | encoded = self.encoder(x) 37 | return encoded 38 | 39 | class CNNDecoder(nn.Module): 40 | def __init__(self, hidden_size1, hidden_size3) -> None: 41 | super().__init__() 42 | 43 | # Define the decoder 44 | self.decoder = nn.Sequential( 45 | nn.Conv1d(in_channels=hidden_size3, out_channels=16, kernel_size=3, stride=1, padding=1), 46 | nn.ReLU(), 47 | nn.Conv1d(in_channels=16, out_channels=128, kernel_size=3, stride=1, padding=1), 48 | nn.ReLU(), 49 | # nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 50 | # nn.ReLU(), 51 | nn.Conv1d(in_channels=128, out_channels=hidden_size1, kernel_size=3, stride=1, padding=1), 52 | nn.Sigmoid() 53 | ) 54 | 55 | def forward(self, x): 56 | # Decode the encoding 57 | decoded = self.decoder(x) 58 | # decoded = decoded.permute(0, 2, 1) 59 | return decoded 60 | 61 | 62 | 63 | class DeTiME(PreTrainedModel): 64 | config_class = DeTiMEAutoConfig 65 | 66 | def __init__(self, config): 67 | super().__init__(config) 68 | #change t5-small to config 69 | model_name_or_path = config.model 70 | # peft_config = PrefixTuningConfig(peft_type="PREFIX_TUNING", task_type=TaskType.SEQ_2_SEQ_LM, 71 | # inference_mode=False, num_virtual_tokens=10) 72 | # model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) 73 | # model = get_peft_model(model, peft_config) 74 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) 75 | 76 | #model.print_trainable_parameters() 77 | self.model = model 78 | self.config_model = 'CNN' 79 | if self.config_model == 'CNN': 80 | # self.model = T5ForConditionalGeneration.from_pretrained("t5-small") 81 | self.encoder = CNNEncoder( 82 | config.hidden_size1, config.hidden_size3) 83 | self.decoder = CNNDecoder( 84 | config.hidden_size1, config.hidden_size3) 85 | self.encoder.main_input_name = self.model.main_input_name 86 | 87 | 88 | self.encoder.main_input_name = self.model.main_input_name 89 | self.main_input_name = self.model.main_input_name 90 | 91 | def forward(self, input_ids, attention_mask, labels, **kwargs): 92 | output = self.model.encoder( 93 | input_ids=input_ids, attention_mask=attention_mask).last_hidden_state #batch size * seq length * embedding size, 94 | #print(output.shape) 95 | if self.config_model == 'CNN': 96 | encoder_output = self.encoder(output) #batch size * seq length * embedding size, 1 * batch size * hidden_size 97 | #print(encoder_output.shape) 98 | 99 | output = 
--------------------------------------------------------------------------------
/src/huggingface/test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [
8 |     {
9 |      "name": "stderr",
10 |      "output_type": "stream",
11 |      "text": [
12 |       "/home/weijiexu/anaconda3/envs/env-01/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "#register the model to huggingface\n", 19 | "from transformers import AutoConfig, AutoModel, AutoModelForImageClassification\n", 20 | "from configuration_detime import DeTiMEAutoConfig\n", 21 | "from modeling_detime import DeTiME\n", 22 | "AutoConfig.register(\"detime\", DeTiMEAutoConfig)\n", 23 | "AutoModel.register(DeTiMEAutoConfig, DeTiME)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "from detime.configuration_detime import DeTiMEAutoConfig\n", 33 | "from detime.modeling_detime import DeTiME\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "DeTiMEAutoConfig.register_for_auto_class()\n", 43 | "DeTiME.register_for_auto_class(\"AutoModel\")" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 5, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "data": { 53 | "text/plain": [ 54 | "" 55 | ] 56 | }, 57 | "execution_count": 5, 58 | "metadata": {}, 59 | "output_type": "execute_result" 60 | } 61 | ], 62 | "source": [ 63 | "import torch\n", 64 | "DetimeConfig = DeTiMEAutoConfig(\n", 65 | " hidden_size1=512, hidden_size3=4, \n", 66 | " hidden_size2=768 , output_size=4 * 768, \n", 67 | " model = 'google/flan-t5-large')\n", 68 | "model = DeTiME(DetimeConfig)\n", 69 | "model.load_state_dict(torch.load( '/home/weijiexu/WXWH/WXWH/p4/p4/model/flant5_nest_peftflan-t5-large8'))\n" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 10, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stderr", 79 | "output_type": "stream", 80 | "text": [ 81 | "/home/weijiexu/anaconda3/envs/env-01/lib/python3.10/site-packages/huggingface_hub/_commit_api.py:288: UserWarning: About to update multiple times the same file in the same commit: 'modeling_detime.py'. This can cause undesired inconsistencies in your repo.\n", 82 | " warnings.warn(\n", 83 | "/home/weijiexu/anaconda3/envs/env-01/lib/python3.10/site-packages/huggingface_hub/_commit_api.py:288: UserWarning: About to update multiple times the same file in the same commit: 'configuration_detime.py'. This can cause undesired inconsistencies in your repo.\n", 84 | " warnings.warn(\n", 85 | "/home/weijiexu/anaconda3/envs/env-01/lib/python3.10/site-packages/huggingface_hub/_commit_api.py:288: UserWarning: About to update multiple times the same file in the same commit: 'config.json'. This can cause undesired inconsistencies in your repo.\n", 86 | " warnings.warn(\n", 87 | "/home/weijiexu/anaconda3/envs/env-01/lib/python3.10/site-packages/huggingface_hub/_commit_api.py:288: UserWarning: About to update multiple times the same file in the same commit: 'pytorch_model.bin'. 
This can cause undesired inconsistencies in your repo.\n",
88 |      " warnings.warn(\n",
89 |      "pytorch_model.bin:   0%|          | 0.00/3.13G [00:00
--------------------------------------------------------------------------------
/src/topicmodeling/CNN_Encoder.py:
--------------------------------------------------------------------------------
38 | class CNNDecoder(nn.Module):
39 |     def __init__(self, hidden_size1, hidden_size3) -> None:
40 |         super().__init__()
41 | 
42 |         # Define the decoder
43 |         self.decoder = nn.Sequential(
44 |             nn.Conv1d(in_channels=hidden_size3, out_channels=16, kernel_size=3, stride=1, padding=1),
45 |             nn.ReLU(),
46 |             nn.Conv1d(in_channels=16, out_channels=128, kernel_size=3, stride=1, padding=1),
47 |             nn.ReLU(),
48 |             # nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
49 |             # nn.ReLU(),
50 |             nn.Conv1d(in_channels=128, out_channels=hidden_size1, kernel_size=3, stride=1, padding=1),
51 |             nn.Sigmoid()
52 |         )
53 | 
54 | 
55 |     def forward(self, x):
56 |         # Decode the encoding
57 |         decoded = self.decoder(x)
58 |         # decoded = decoded.permute(0, 2, 1)
59 |         return decoded
60 | 
61 | 
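# A minimal shape sketch for CNNDecoder (comment-only; the d_model of 1024
# matches flan-t5-large, the other sizes are illustrative assumptions):
#
#   decoder = CNNDecoder(hidden_size1=512, hidden_size3=4)
#   z = torch.rand(2, 4, 1024)                   # (batch, hidden_size3, d_model) latent
#   assert decoder(z).shape == (2, 512, 1024)    # back to (batch, seq_len, d_model)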
62 | # class Decoder(torch.nn.Module):
63 | #     def __init__(self, hidden_size1, hidden_size2, output_size, num_layers=1, dropout=0.1):
64 | #         super().__init__()
65 | #         self.fc = torch.nn.Linear(output_size, hidden_size1 * hidden_size2)
66 | #         self.dropout = torch.nn.Dropout(dropout)
67 | #         # self.batch_size = batch_size
68 | #         self.hidden_size1 = hidden_size1
69 | #         self.hidden_size2 = hidden_size2
70 | 
71 | #     def forward(self, x):
72 | #         x = self.dropout(x)
73 | #         x = self.fc(x)
74 | #         x = torch.reshape(x, (-1, self.hidden_size1, self.hidden_size2))
75 | 
76 | #         return x
77 | 
78 | 
79 | # class Encoder(torch.nn.Module):
80 | #     def __init__(self, hidden_size1, hidden_size2, output_size, num_layers=1, dropout=0.1):
81 | #         super().__init__()
82 | #         self.fc = torch.nn.Linear(hidden_size1 * hidden_size2, output_size)
83 | #         self.dropout = torch.nn.Dropout(dropout)
84 | #         self.hidden_size1 = hidden_size1
85 | #         self.hidden_size2 = hidden_size2
86 | #         # self.batch_size = batch_size
87 | 
88 | #     def forward(self, x):
89 | #         # to do: Add softmax
90 | 
91 | #         # pad last dim by 1 on each side
92 | #         p1d = (0, 0, 0, self.hidden_size1 - x.shape[1])
93 | #         x = F.pad(x, p1d, "constant", 0)
94 | #         x = torch.reshape(x, (-1, self.hidden_size1 * self.hidden_size2))
95 | #         x = self.dropout(x)
96 | #         x = self.fc(x)
97 | #         # Add softmax to represent topic distribution
98 | #         x = F.softmax(x, dim=1)
99 | #         return x
100 | 
101 | 
102 | # class lstm_decoder(nn.Module):
103 | #     ''' Decodes hidden state output by encoder '''
104 | 
105 | #     def __init__(self, hidden_size2, output_size, num_layers = 1):
106 | 
107 | #         '''
108 | #         : param input_size:     the number of features in the input X
109 | #         : param hidden_size:    the number of features in the hidden state h
110 | #         : param num_layers:     number of recurrent layers (i.e., 2 means there are
111 | #         :                       2 stacked LSTMs)
112 | #         '''
113 | 
114 | #         super(lstm_decoder, self).__init__()
115 | #         self.hidden_size2 = hidden_size2
116 | #         self.output_size = output_size
117 | #         self.num_layers = num_layers
118 | 
119 | #         self.lstm = nn.GRU(input_size = hidden_size2, hidden_size = output_size,
120 | #                            num_layers = num_layers, batch_first = True)
121 | #         self.linear = nn.Linear(output_size, hidden_size2)
122 | #         self.init_input = nn.Parameter(torch.randn(1, hidden_size2))  # Learned constant vector
123 | 
124 | 
125 | #     def forward(self, x_input, encoder_hidden_states):
126 | 
127 | #         '''
128 | #         : param x_input:                 should be 2D (batch_size, input_size)
129 | #         : param encoder_hidden_states:   hidden states
130 | #         : return output, hidden:         output gives all the hidden states in the sequence;
131 | #         :                                hidden gives the hidden state and cell state for the last
132 | #         :                                element in the sequence
133 | 
134 | #         '''
135 | #         #print(x_input.shape, encoder_hidden_states.shape)
136 | #         lstm_out, self.hidden = self.lstm(x_input, encoder_hidden_states)
137 | #         output = self.linear(lstm_out)
138 | 
139 | #         return output, self.hidden
140 | 
141 | #     def init_hidden(self, encoder_hidden):
142 | #         # Set the initial hidden state of the decoder to be the last hidden state of the encoder
143 | #         return encoder_hidden
144 | 
145 | #     def init_input_vector(self, batch_size):
146 | #         # Create a batch of learned constant vectors for the initial input
147 | #         return self.init_input.expand(batch_size, 1, -1)  #Batch_size, sequence length, embeddings size
148 | 
149 | 
150 | # class lstm_encoder(nn.Module):
151 | #     ''' Encodes time-series sequence '''
152 | 
153 | #     def __init__(self, hidden_size2, output_size, num_layers = 1):
154 | 
155 | #         '''
156 | #         : param input_size:     the number of features in the input X
157 | #         : param input_size:     the number of features in the input X
158 | #         : param hidden_size:    the number of features in the hidden state h
159 | #         : param num_layers:     number of recurrent layers (i.e., 2 means there are
160 | #         :                       2 stacked LSTMs)
161 | #         '''
162 | 
163 | #         super(lstm_encoder, self).__init__()
164 | 
165 | #         self.hidden_size2 = hidden_size2
166 | #         self.output_size = output_size
167 | #         self.num_layers = num_layers
168 | 
169 | #         # define GRU layer
170 | #         self.lstm = nn.GRU(input_size = hidden_size2, hidden_size = output_size,
171 | #                            num_layers = num_layers, batch_first = True)
172 | #         self.linear = nn.Linear(output_size, hidden_size2)
173 | 
174 | #     def forward(self, x):
175 | 
176 | #         '''
177 | #         : param x_input:             input of shape (# in batch, seq_len, input_size)
178 | #         : return lstm_out, hidden:   lstm_out gives all the hidden states in the sequence;
179 | #         :                            hidden gives the hidden state and cell state for the last
180 | #         :                            element in the sequence
181 | #         '''
182 | 
183 | #         lstm_out, self.hidden = self.lstm(x)
184 | #         lstm_out = self.linear(lstm_out)
185 | #         return lstm_out, self.hidden
186 | # NOTE: RNNEncoder below depends on the lstm_encoder/lstm_decoder classes that are commented out above, so it cannot be instantiated as this file stands.
187 | class RNNEncoder(torch.nn.Module):
188 |     def __init__(self, config):
189 |         super().__init__()
190 | 
191 |         # self.encoder = Encoder(config['hidden_size1'],
192 |         #                        config['hidden_size2'], config['output_size'])
193 |         # self.decoder = Decoder(config['hidden_size1'],
194 |         #                        config['hidden_size2'], config['output_size'])
195 |         self.encoder = lstm_encoder(
196 |             config.hidden_size2, config.output_size)
197 |         self.decoder = lstm_decoder(
198 |             config.hidden_size2, config.output_size)
199 |         # self.config = T5Config(
200 | 
201 |         #     vocab_size=self.model.config.vocab_size,
202 |         #     d_model=self.model.config.d_model,
203 |         #     d_ff=self.model.config.d_ff,
204 |         #     num_heads=self.model.config.num_heads
205 |         # )
206 | 
207 |     def forward(self, output, **kwargs):
208 |         # Get the output of the T5 encoder
209 |         # with torch.no_grad():
210 |         encoder_output, encoder_hidden = self.encoder(output)  #batch size * seq length * embedding size, 1 * batch size * hidden_size
211 |         batch_size = output.size(0)
212 |         decoder_input = self.decoder.init_input_vector(batch_size)  #batch_size, 1, embedding size
213 | 
214 |         decoder_hidden = self.decoder.init_hidden(encoder_hidden)  #1 batch_size, hidden_size
215 |         target_length = encoder_output.size(1)
216 |         outputs = []
217 |         use_teacher_forcing = True if random.random() < 0 else False  # the < 0 comparison means teacher forcing is always disabled
218 | 
219 |         if use_teacher_forcing:
220 |             # Use teacher forcing: feed the target sequence to the decoder one token at a time
221 |             for i in range(target_length):
222 | 
#print(decoder_input.shape) 223 | decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden) 224 | outputs.append(decoder_output.squeeze(1)) 225 | decoder_input = output[:, i:i+1] 226 | #print(decoder_input.shape) 227 | 228 | else: 229 | # Use the previous output as the input to the decoder 230 | for i in range(target_length): 231 | #print(decoder_input.shape, decoder_hidden.shape) 232 | #batch size * seq length * embedding size, 1 * batch size * hidden_size 233 | decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden) 234 | outputs.append(decoder_output.squeeze(1)) 235 | decoder_input = decoder_output 236 | 237 | output = torch.stack(outputs, dim=1) 238 | return output 239 | 240 | 241 | 242 | 243 | 244 | 245 | # class T5NestAutoencoder(PreTrainedModel): 246 | # def __init__(self, config): 247 | # super().__init__(config) 248 | # self.model = T5ForConditionalGeneration.from_pretrained("t5-small") 249 | # self.encoder = Encoder(config.hidden_size1, 250 | # config.hidden_size2, config.output_size) 251 | # self.decoder = Decoder(config.hidden_size1, 252 | # config.hidden_size2, config.output_size) 253 | # self.encoder.main_input_name = self.model.main_input_name 254 | # self.main_input_name = self.model.main_input_name 255 | 256 | # # self.config = T5Config( 257 | 258 | # # vocab_size=self.model.config.vocab_size, 259 | # # d_model=self.model.config.d_model, 260 | # # d_ff=self.model.config.d_ff, 261 | # # num_heads=self.model.config.num_heads 262 | # # ) 263 | 264 | # def forward(self, input_ids, attention_mask, **kwargs): 265 | # # Get the output of the T5 encoder 266 | # # with torch.no_grad(): 267 | # output = self.model.encoder( 268 | # input_ids=input_ids, attention_mask=attention_mask).last_hidden_state 269 | 270 | # output = self.encoder(output) 271 | # output = self.decoder(output) 272 | 273 | # return self.model.forward(input_ids=input_ids, encoder_outputs=(output, ), labels=input_ids, **kwargs) 274 | 275 | # # def generate(self, input_ids, attention_mask, max_length=20, num_return_sequences=1.0, 276 | # # length_penalty=1.0, do_sample=True, early_stopping=True, num_beams = 1, num_beam_groups = 1, 277 | # # synced_gpus=False): 278 | # # output = self.model.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state 279 | 280 | # # output = self.encoder(output) 281 | # # output = self.decoder(output) 282 | 283 | # # return self.model.generate(input_ids=input_ids, encoder_outputs = (output, ), max_length = max_length, 284 | # # num_return_sequences = num_return_sequences, length_penalty=length_penalty, 285 | # # do_sample=do_sample, early_stopping=early_stopping, num_beams = num_beams, 286 | # # num_beam_groups = num_beam_groups, synced_gpus=synced_gpus) 287 | # def generate(self, input_ids, attention_mask, **kwargs): 288 | # output = self.model.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state 289 | 290 | # output = self.encoder(output) 291 | # output = self.decoder(output) 292 | 293 | # return self.model.generate(input_ids=input_ids, encoder_outputs=(output,), **kwargs) 294 | 295 | 296 | # class T5NestAutoencoderTrainT5Decoder(PreTrainedModel): 297 | # def __init__(self, model_train, config): 298 | 299 | # super().__init__(config) 300 | # # load the our trained model 301 | # self.model_train = model_train 302 | # self.main_input_name = self.model_train.main_input_name 303 | # # load another t5 model for summarization decoder 304 | # self.model_sum = 
T5ForConditionalGeneration.from_pretrained("t5-small") 305 | 306 | # # get the trained t5 encoder and the following auto-encoder 307 | # self.T5_encoder = self.model_train.model.encoder 308 | # self.encoder = self.model_train.encoder 309 | # self.decoder = self.model_train.decoder 310 | 311 | # def forward(self, input_ids, attention_mask, labels, **kwargs): 312 | # # Get the output of the T5 encoder 313 | # # with torch.no_grad(): 314 | # output = self.T5_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state 315 | 316 | # # the following auto-encoder 317 | # output = self.encoder(output) 318 | # output = self.decoder(output) 319 | 320 | # # no need to calculate the gradient above this point 321 | # output = output.detach() 322 | 323 | # # add the decoder for summarization 324 | # # note that in current preprocess_function, we just assign the input_ids of target to the 325 | # # labels of the dataset 326 | # output = self.model_sum.forward( 327 | # input_ids=input_ids, encoder_outputs=(output,), labels=labels, **kwargs) 328 | 329 | # return output 330 | 331 | # def generate(self, input_ids, attention_mask, **kwargs): 332 | # # t5 encoder 333 | # output = self.T5_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state 334 | 335 | # # the following auto-encoder 336 | # output = self.encoder(output) 337 | # output = self.decoder(output) 338 | 339 | # # use another decoder to generate 340 | # output = self.model_sum.generate(input_ids=input_ids, encoder_outputs=(output,), **kwargs) 341 | 342 | # return output -------------------------------------------------------------------------------- /src/topicmodeling/T5_Encoder/__pycache__/CNN_Encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/T5_Encoder/__pycache__/CNN_Encoder.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/T5_Encoder/__pycache__/flanT5_cnn_lighting.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/T5_Encoder/__pycache__/flanT5_cnn_lighting.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/T5_Encoder/__pycache__/sentence_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/T5_Encoder/__pycache__/sentence_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/__init__.py -------------------------------------------------------------------------------- /src/topicmodeling/__pycache__/CNN_Encoder.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/__pycache__/CNN_Encoder.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/__pycache__/flanT5_cnn_lighting.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/__pycache__/flanT5_cnn_lighting.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /src/topicmodeling/__pycache__/preprocess.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/__pycache__/preprocess.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/__pycache__/sentence_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/__pycache__/sentence_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/__init__.py: -------------------------------------------------------------------------------- 1 | import hyperspherical_vae.ops 2 | import hyperspherical_vae.distributions 3 | -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/hyperspherical_vae/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/hyperspherical_vae/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | from hyperspherical_vae.distributions.von_mises_fisher import VonMisesFisher 2 | from hyperspherical_vae.distributions.hyperspherical_uniform import ( 3 | HypersphericalUniform, 4 | ) 5 | -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/distributions/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/hyperspherical_vae/distributions/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/distributions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/hyperspherical_vae/distributions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/distributions/__pycache__/hyperspherical_uniform.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/hyperspherical_vae/distributions/__pycache__/hyperspherical_uniform.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/distributions/__pycache__/hyperspherical_uniform.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/hyperspherical_vae/distributions/__pycache__/hyperspherical_uniform.cpython-38.pyc -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/distributions/__pycache__/von_mises_fisher.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/hyperspherical_vae/distributions/__pycache__/von_mises_fisher.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/distributions/__pycache__/von_mises_fisher.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/hyperspherical_vae/distributions/__pycache__/von_mises_fisher.cpython-38.pyc -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/distributions/hyperspherical_uniform.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | class HypersphericalUniform(torch.distributions.Distribution): 6 | 7 | support = torch.distributions.constraints.real 8 | has_rsample = False 9 | _mean_carrier_measure = 0 10 | 11 | @property 12 | def dim(self): 13 | return self._dim 14 | 15 | @property 16 | def device(self): 17 | return self._device 18 | 19 | @device.setter 20 | def device(self, val): 21 | self._device = val if isinstance(val, torch.device) else torch.device(val) 22 | 23 | def __init__(self, dim, validate_args=None, device="cpu"): 24 | super(HypersphericalUniform, self).__init__( 25 | torch.Size([dim]), validate_args=validate_args 26 | ) 27 | self._dim = dim 28 | self.device = device 29 | 30 | def sample(self, shape=torch.Size()): 31 | output = ( 32 | torch.distributions.Normal(0, 1) 33 | .sample( 34 | (shape if isinstance(shape, torch.Size) else torch.Size([shape])) 35 | + torch.Size([self._dim + 1]) 36 | ) 37 | .to(self.device) 38 | ) 39 | 40 | return output / output.norm(dim=-1, keepdim=True) 41 | 42 | def entropy(self): 43 | return self.__log_surface_area() 44 | 45 | def log_prob(self, x): 46 | return -torch.ones(x.shape[:-1], device=self.device) * self.__log_surface_area() 47 | 48 | def __log_surface_area(self): 49 | if torch.__version__ >= "1.0.0": 50 | lgamma = torch.lgamma(torch.tensor([(self._dim + 1) / 2]).to(self.device)) 51 | else: 52 | lgamma = torch.lgamma( 53 | torch.Tensor([(self._dim + 1) / 2], device=self.device) 54 | ) 55 | return math.log(2) + ((self._dim + 1) / 2) * math.log(math.pi) - lgamma 56 | -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/distributions/von_mises_fisher.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.distributions.kl import register_kl 4 | 5 | from hyperspherical_vae.ops.ive import ive, ive_fraction_approx, ive_fraction_approx2 6 | from hyperspherical_vae.distributions.hyperspherical_uniform import ( 7 | HypersphericalUniform, 8 | ) 9 | 10 | 11 | class VonMisesFisher(torch.distributions.Distribution): 12 | 13 | arg_constraints = { 14 | "loc": torch.distributions.constraints.real, 15 | "scale": torch.distributions.constraints.positive, 16 | } 17 | support = torch.distributions.constraints.real 18 | has_rsample = True 19 | _mean_carrier_measure = 0 20 | 21 | @property 22 | def mean(self): 23 | # option 1: 24 | # return self.loc * ( 25 | # ive(self.__m / 2, self.scale) / ive(self.__m / 2 - 1, self.scale) 26 | # ) 27 | # option 2: 28 | return self.loc * ive_fraction_approx(torch.tensor(self.__m / 2), self.scale) 29 | # options 3: 30 | # return self.loc * ive_fraction_approx2(torch.tensor(self.__m / 2), self.scale) 31 | 32 | @property 33 | def stddev(self): 34 | return self.scale 35 | 36 | def __init__(self, loc, scale, validate_args=None, k=1): 37 | self.dtype = loc.dtype 38 | self.loc = loc 39 | self.scale = scale 40 | self.device = loc.device 41 | self.__m = loc.shape[-1] 42 | self.__e1 = (torch.Tensor([1.0] + [0] * (loc.shape[-1] 
- 1))).to(self.device)
43 |         self.k = k
44 | 
45 |         super().__init__(self.loc.size(), validate_args=validate_args)
46 | 
47 |     def sample(self, shape=torch.Size()):
48 |         with torch.no_grad():
49 |             return self.rsample(shape)
50 | 
51 |     def rsample(self, shape=torch.Size()):
52 |         shape = shape if isinstance(shape, torch.Size) else torch.Size([shape])
53 | 
54 |         w = (
55 |             self.__sample_w3(shape=shape)
56 |             if self.__m == 3
57 |             else self.__sample_w_rej(shape=shape)
58 |         )
59 | 
60 |         v = (
61 |             torch.distributions.Normal(0, 1)
62 |             .sample(shape + torch.Size(self.loc.shape))
63 |             .to(self.device)
64 |             .transpose(0, -1)[1:]
65 |         ).transpose(0, -1)
66 |         v = v / v.norm(dim=-1, keepdim=True)
67 | 
68 |         w_ = torch.sqrt(torch.clamp(1 - (w ** 2), 1e-10))
69 |         x = torch.cat((w, w_ * v), -1)
70 |         z = self.__householder_rotation(x)
71 | 
72 |         return z.type(self.dtype)
73 | 
74 |     def __sample_w3(self, shape):
75 |         shape = shape + torch.Size(self.scale.shape)
76 |         u = torch.distributions.Uniform(0, 1).sample(shape).to(self.device)
77 |         self.__w = (
78 |             1
79 |             + torch.stack(
80 |                 [torch.log(u), torch.log(1 - u) - 2 * self.scale], dim=0
81 |             ).logsumexp(0)
82 |             / self.scale
83 |         )
84 |         return self.__w
85 | 
86 |     def __sample_w_rej(self, shape):
87 |         c = torch.sqrt((4 * (self.scale ** 2)) + (self.__m - 1) ** 2)
88 |         b_true = (-2 * self.scale + c) / (self.__m - 1)
89 | 
90 |         # using Taylor approximation with a smooth shift from 10 < scale < 11
91 |         # to avoid numerical errors for large scale
92 |         b_app = (self.__m - 1) / (4 * self.scale)
93 |         s = torch.min(
94 |             torch.max(
95 |                 torch.tensor([0.0], dtype=self.dtype, device=self.device),
96 |                 self.scale - 10,
97 |             ),
98 |             torch.tensor([1.0], dtype=self.dtype, device=self.device),
99 |         )
100 |         b = b_app * s + b_true * (1 - s)
101 | 
102 |         a = (self.__m - 1 + 2 * self.scale + c) / 4
103 |         d = (4 * a * b) / (1 + b) - (self.__m - 1) * math.log(self.__m - 1)
104 | 
105 |         self.__b, (self.__e, self.__w) = b, self.__while_loop(b, a, d, shape, k=self.k)
106 |         return self.__w
107 | 
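    # --- Illustrative sketch (added commentary, not part of the original file) ---
    # __sample_w_rej above draws the scalar component w along the mean direction
    # by rejection sampling: candidates e_ ~ Beta((m-1)/2, (m-1)/2) are mapped to
    # w_ = (1 - (1+b)e_) / (1 - (1-b)e_) and accepted against a log-uniform test
    # in __while_loop below. rsample() then pairs w with a uniform tangent
    # direction v and applies a Householder reflection so the sample points
    # around `loc`. A hypothetical usage (shapes and values are made up):
    #
    #     loc = torch.nn.functional.normalize(torch.randn(4, 16), dim=-1)
    #     scale = torch.full((4, 1), 20.0)
    #     z = VonMisesFisher(loc, scale).rsample()   # (4, 16), differentiable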
108 |     @staticmethod
109 |     def first_nonzero(x, dim, invalid_val=-1):
110 |         mask = x > 0
111 |         idx = torch.where(
112 |             mask.any(dim=dim),
113 |             mask.float().argmax(dim=dim).squeeze(),
114 |             torch.tensor(invalid_val, device=x.device),
115 |         )
116 |         return idx
117 | 
118 |     def __while_loop(self, b, a, d, shape, k=20, eps=1e-20):
119 |         # matrix while loop: samples a matrix of [A, k] candidates at once, to avoid looping one sample at a time
120 |         b, a, d = [
121 |             e.repeat(*shape, *([1] * len(self.scale.shape))).reshape(-1, 1)
122 |             for e in (b, a, d)
123 |         ]
124 |         w, e, bool_mask = (
125 |             torch.zeros_like(b).to(self.device),
126 |             torch.zeros_like(b).to(self.device),
127 |             (torch.ones_like(b) == 1).to(self.device),
128 |         )
129 | 
130 |         sample_shape = torch.Size([b.shape[0], k])
131 |         shape = shape + torch.Size(self.scale.shape)
132 | 
133 |         while bool_mask.sum() != 0:
134 |             con1 = torch.tensor((self.__m - 1) / 2, dtype=torch.float64)
135 |             con2 = torch.tensor((self.__m - 1) / 2, dtype=torch.float64)
136 |             e_ = (
137 |                 torch.distributions.Beta(con1, con2)
138 |                 .sample(sample_shape)
139 |                 .to(self.device)
140 |                 .type(self.dtype)
141 |             )
142 | 
143 |             u = (
144 |                 torch.distributions.Uniform(0 + eps, 1 - eps)
145 |                 .sample(sample_shape)
146 |                 .to(self.device)
147 |                 .type(self.dtype)
148 |             )
149 | 
150 |             w_ = (1 - (1 + b) * e_) / (1 - (1 - b) * e_)
151 |             t = (2 * a * b) / (1 - (1 - b) * e_)
152 | 
153 |             accept = ((self.__m - 1.0) * t.log() - t + d) > torch.log(u)
154 |             accept_idx = self.first_nonzero(accept, dim=-1, invalid_val=-1).unsqueeze(1)
155 |             accept_idx_clamped = accept_idx.clamp(0)
156 |             # clamp(0) keeps gather() from seeing a -1 index; rejected rows (accept_idx < 0) are filtered out via the mask below
157 |             w_ = w_.gather(1, accept_idx_clamped.view(-1, 1))
158 |             e_ = e_.gather(1, accept_idx_clamped.view(-1, 1))
159 | 
160 |             reject = accept_idx < 0
161 |             accept = ~reject if torch.__version__ >= "1.2.0" else 1 - reject
162 | 
163 |             w[bool_mask * accept] = w_[bool_mask * accept]
164 |             e[bool_mask * accept] = e_[bool_mask * accept]
165 | 
166 |             bool_mask[bool_mask * accept] = reject[bool_mask * accept]
167 | 
168 |         return e.reshape(shape), w.reshape(shape)
169 | 
170 |     def __householder_rotation(self, x):
171 |         u = self.__e1 - self.loc
172 |         u = u / (u.norm(dim=-1, keepdim=True) + 1e-5)
173 |         z = x - 2 * (x * u).sum(-1, keepdim=True) * u
174 |         return z
175 | 
176 |     def entropy(self):
177 |         # option 1:
178 |         # output = (
179 |         #     -self.scale
180 |         #     * ive(self.__m / 2, self.scale)
181 |         #     / ive((self.__m / 2) - 1, self.scale)
182 |         # )
183 |         # option 2:
184 |         output = - self.scale * ive_fraction_approx(torch.tensor(self.__m / 2), self.scale)
185 |         # option 3:
186 |         # output = - self.scale * ive_fraction_approx2(torch.tensor(self.__m / 2), self.scale)
187 | 
188 |         return output.view(*(output.shape[:-1])) #+ self._log_normalization()
189 | 
190 |     def log_prob(self, x):
191 |         return self._log_unnormalized_prob(x) - self._log_normalization()
192 | 
193 |     def _log_unnormalized_prob(self, x):
194 |         output = self.scale * (self.loc * x).sum(-1, keepdim=True)
195 | 
196 |         return output.view(*(output.shape[:-1]))
197 | 
198 |     def _log_normalization(self):
199 |         output = -(
200 |             (self.__m / 2 - 1) * torch.log(self.scale)
201 |             - (self.__m / 2) * math.log(2 * math.pi)
202 |             - (self.scale + torch.log(ive(self.__m / 2 - 1, self.scale)))
203 |         )
204 | 
205 |         return output.view(*(output.shape[:-1]))
206 | 
207 | 
208 | @register_kl(VonMisesFisher, HypersphericalUniform)
209 | def _kl_vmf_uniform(vmf, hyu):
210 |     #print(vmf.entropy() , hyu.entropy())
211 |     return -vmf.entropy() + hyu.entropy()
--------------------------------------------------------------------------------
/src/topicmodeling/hyperspherical_vae/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from hyperspherical_vae.ops.ive import ive
2 | 
--------------------------------------------------------------------------------
/src/topicmodeling/hyperspherical_vae/ops/__pycache__/ive.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/hyperspherical_vae/ops/__pycache__/ive.cpython-310.pyc -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/ops/__pycache__/ive.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/hyperspherical_vae/ops/__pycache__/ive.cpython-38.pyc -------------------------------------------------------------------------------- /src/topicmodeling/hyperspherical_vae/ops/ive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import scipy.special 4 | from numbers import Number 5 | 6 | 7 | class IveFunction(torch.autograd.Function): 8 | @staticmethod 9 | def forward(self, v, z): 10 | 11 | assert isinstance(v, Number), "v must be a scalar" 12 | 13 | self.save_for_backward(z) 14 | self.v = v 15 | z_cpu = z.data.cpu().numpy() 16 | 17 | if np.isclose(v, 0): 18 | output = scipy.special.i0e(z_cpu, dtype=z_cpu.dtype) 19 | elif np.isclose(v, 1): 20 | output = scipy.special.i1e(z_cpu, dtype=z_cpu.dtype) 21 | else: # v > 0 22 | output = scipy.special.ive(v, z_cpu, dtype=z_cpu.dtype) 23 | # else: 24 | # print(v, type(v), np.isclose(v, 0)) 25 | # raise RuntimeError('v must be >= 0, it is {}'.format(v)) 26 | 27 | return torch.Tensor(output).to(z.device) 28 | 29 | @staticmethod 30 | def backward(self, grad_output): 31 | z = self.saved_tensors[-1] 32 | return ( 33 | None, 34 | grad_output * (ive(self.v - 1, z) - ive(self.v, z) * (self.v + z) / z), 35 | ) 36 | 37 | 38 | class Ive(torch.nn.Module): 39 | def __init__(self, v): 40 | super(Ive, self).__init__() 41 | self.v = v 42 | 43 | def forward(self, z): 44 | return ive(self.v, z) 45 | 46 | 47 | ive = IveFunction.apply 48 | 49 | 50 | ########## 51 | # The below provided approximations were provided in the 52 | # respective source papers, to improve the stability of 53 | # the Bessel fractions. 
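# --- Illustrative sketch (added commentary, not part of the original file) ---
# A hypothetical sanity check for the two approximations defined below,
# comparing against the exact exponentially-scaled Bessel ratio
# (scipy.special and torch are already imported at the top of this module):
#
#     v, z = 8.0, torch.tensor([5.0])
#     exact = scipy.special.ive(v, z.item()) / scipy.special.ive(v - 1, z.item())
#     lower_bound = ive_fraction_approx(torch.tensor(v), z)
#     # per the source comment below, ive_fraction_approx is a lower bound
#     # on the exact ratio, so lower_bound <= exact should hold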
54 | # I_(v/2)(k) / I_(v/2 - 1)(k)
55 | 
56 | # source: https://arxiv.org/pdf/1606.02008.pdf
57 | def ive_fraction_approx(v, z):
58 |     # I_(v/2)(k) / I_(v/2 - 1)(k) >= z / (v-1 + ((v+1)^2 + z^2)^0.5)
59 |     return z / (v - 1 + torch.pow(torch.pow(v + 1, 2) + torch.pow(z, 2), 0.5))
60 | 
61 | 
62 | # source: https://arxiv.org/pdf/1902.02603.pdf
63 | def ive_fraction_approx2(v, z, eps=1e-20):
64 |     def delta_a(a):
65 |         lamb = v + (a - 1.0) / 2.0
66 |         return (v - 0.5) + lamb / (
67 |             2 * torch.sqrt((torch.pow(lamb, 2) + torch.pow(z, 2)).clamp(eps))
68 |         )
69 | 
70 |     delta_0 = delta_a(0.0)
71 |     delta_2 = delta_a(2.0)
72 |     B_0 = z / (
73 |         delta_0 + torch.sqrt((torch.pow(delta_0, 2) + torch.pow(z, 2))).clamp(eps)
74 |     )
75 |     B_2 = z / (
76 |         delta_2 + torch.sqrt((torch.pow(delta_2, 2) + torch.pow(z, 2))).clamp(eps)
77 |     )
78 | 
79 |     return (B_0 + B_2) / 2.0
--------------------------------------------------------------------------------
/src/topicmodeling/model.py:
--------------------------------------------------------------------------------
1 | # Organizing the imports
2 | # Standard libraries
3 | import string
4 | import pickle
5 | from collections import defaultdict
6 | 
7 | # Libraries for Deep Learning and ML
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 | from torch.nn import init
13 | from sklearn.cluster import KMeans
14 | from sklearn import metrics
15 | from scipy import sparse
16 | from hyperspherical_vae.distributions import VonMisesFisher, HypersphericalUniform
17 | 
18 | # Libraries for NLP
19 | import nltk
20 | from nltk.corpus import stopwords, wordnet
21 | from nltk.stem import WordNetLemmatizer
22 | from nltk.tokenize import word_tokenize
23 | from nltk import pos_tag
24 | import gensim.downloader
25 | import gensim
26 | 
27 | # Other utilities
28 | import pandas as pd
29 | import numpy as np
30 | import ot
31 | import matplotlib.pyplot as plt
32 | import seaborn as sns
33 | from datasets import Dataset
34 | from utils import kld_normal
35 | from preprocess import TextProcessor
36 | 
37 | 
38 | class EmbTopic(nn.Module):
39 |     """
40 |     The decoder for Embedded Topic Modeling.
41 | 
42 |     Attributes
43 |     ----------
44 |     topic_emb: nn.Parameter
45 |         the learned topic embeddings, one row per topic
46 | 
47 | 
48 |     Methods:
49 |     --------
50 |     forward(logit)
51 |         return the log-probability over the vocabulary for a batch of topic mixtures
52 |     get_topics
53 |         the topic-word distribution (probabilities, before the log)
54 | 
55 | 
56 |     """
57 |     def __init__(self, embedding, k, normalize = False):
58 |         super(EmbTopic, self).__init__()
59 |         self.embedding = embedding
60 |         n_vocab, topic_dim = embedding.weight.size()
61 |         self.k = k
62 |         self.topic_emb = nn.Parameter(torch.Tensor(k, topic_dim))
63 |         self.reset_parameters()
64 |         self.normalize = normalize
65 | 
66 |     def forward(self, logit):
67 |         # return the log_prob of vocab distribution
68 |         # if normalize:
69 |         #     self.topic_emb = torch.nn.Parameter(normalize(self.topic_emb))
70 |         if self.normalize:
71 |             val = F.normalize(self.topic_emb, dim=-1) @ self.embedding.weight.transpose(0, 1)  # unit-normalize topic vectors before scoring
72 |         else:
73 |             val = self.topic_emb @ self.embedding.weight.transpose(0, 1)
74 |         # print(val.shape)
75 |         beta = F.softmax(val, dim=1)
76 |         # print(beta.shape)
77 |         # return beta
78 |         return torch.log(torch.matmul(logit, beta) + 1e-10)
79 | 
80 |     def get_topics(self):
81 |         return F.softmax(self.topic_emb @ self.embedding.weight.transpose(0, 1), dim=1)
82 | 
83 | 
84 |     def get_rank(self):
85 |         #self.topic_emb = torch.nn.Parameter(normalize(self.topic_emb))
86 |         return F.normalize(self.topic_emb, dim=-1) @ self.embedding.weight.transpose(0, 1)
87 | 
88 |     def reset_parameters(self):
89 |         init.normal_(self.topic_emb)
90 |         # init.kaiming_uniform_(self.topic_emb, a=math.sqrt(5))
91 |         # init.normal_(self.embedding.weight, std=0.01)
92 | 
93 |     def extra_repr(self):
94 |         k, d = self.topic_emb.size()
95 |         return 'topic_emb: Parameter({}, {})'.format(k, d)
96 | 
97 | 
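# --- Illustrative sketch (added commentary, not part of the original file) ---
# EmbTopic scores every topic against every word through a shared embedding
# table: beta = softmax(T @ E^T) is the (k x n_vocab) topic-word matrix, and
# forward() returns log(theta @ beta + 1e-10) for a batch of topic mixtures
# theta. A hypothetical shape walk-through:
#
#     embedding = nn.Embedding(5000, 100)      # n_vocab=5000, topic_dim=100
#     decoder = EmbTopic(embedding, k=20)
#     theta = torch.softmax(torch.randn(8, 20), dim=-1)
#     log_word_probs = decoder(theta)          # (8, 5000)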
98 | 
99 | def topic_covariance_penalty(topic_emb, EPS=1e-12):
100 |     """topic_emb: T x topic_dim."""
101 |     # normalize the topic embeddings
102 |     normalized_topic = topic_emb / (torch.norm(topic_emb, dim=-1, keepdim=True) + EPS)
103 |     # absolute cosine similarity between every pair of topics
104 |     cosine = (normalized_topic @ normalized_topic.transpose(0, 1)).abs()
105 |     # average similarity
106 |     mean = cosine.mean()
107 |     # variance
108 |     var = ((cosine - mean) ** 2).mean()
109 |     return mean - var, var, mean
110 | 
111 | class NormalParameter(nn.Module):
112 |     def __init__(self, in_features, out_features):
113 |         super(NormalParameter, self).__init__()
114 |         self.in_features = in_features
115 |         self.out_features = out_features
116 |         self.mu = nn.Linear(in_features, out_features)
117 |         self.log_sigma = nn.Linear(in_features, out_features)
118 |         self.reset_parameters()
119 | 
120 |     def forward(self, h):
121 |         return self.mu(h), self.log_sigma(h)
122 | 
123 |     def reset_parameters(self):
124 |         init.zeros_(self.log_sigma.weight)
125 |         init.zeros_(self.log_sigma.bias)
126 | 
127 | class NTM(nn.Module):
128 |     """Gaussian neural topic model that keeps track of its decoder output.
129 |     """
130 |     def __init__(self, hidden, normal, h_to_z, topics):
131 |         super(NTM, self).__init__()
132 |         self.hidden = hidden
133 |         self.normal = normal
134 |         self.h_to_z = h_to_z
135 |         self.topics = topics
136 |         self.output = None
137 |         self.drop = nn.Dropout(p=0.5)
138 |     def forward(self, x, n_sample=1):
139 |         h = self.hidden(x)
140 |         h = self.drop(h)
141 |         mu, log_sigma = self.normal(h)
142 |         # KL term: how far the posterior is from the standard normal prior
143 |         kld = kld_normal(mu, log_sigma)
144 |         #print(kld.shape)
145 |         rec_loss = 0
146 |         for i in range(n_sample):
147 |             # reparameterization trick
148 |             z = torch.zeros_like(mu).normal_() * torch.exp(0.5*log_sigma) + mu
149 |             # decode
150 | 
151 |             z = self.h_to_z(z)
152 |             self.output = z
153 |             #print(z)
154 |             #z = self.drop(z)
155 |             # get log probability for the reconstruction loss
156 |             log_prob = self.topics(z)
157 |             rec_loss = rec_loss - (log_prob * x).sum(dim=-1)
158 |         # average reconstruction loss
159 |         rec_loss = rec_loss / n_sample
160 |         #print(rec_loss.shape)
161 |         minus_elbo = rec_loss + kld
162 | 
163 |         return {
164 |             'loss': minus_elbo,
165 |             'minus_elbo': minus_elbo,
166 |             'rec_loss': rec_loss,
167 |             'kld': kld
168 |         }
169 | 
170 |     def get_topics(self):
171 |         return self.topics.get_topics()
172 | 
173 | def optimal_transport_prior(softmax_top, index,
174 |                             lambda_sh = 1):
175 |     """ add keyword guidance as a semi-supervised prior loss
176 | 
177 |     parameters
178 |     ----------
179 |     softmax_top: tensor, topic-word probabilities from the decoder,
180 |         shape (n_topics, vocab_size)
181 |     index: list of lists, vocabulary indices of the anchor words for
182 |         each seeded topic
183 |     lambda_sh: float, Sinkhorn regularization strength; low values
184 |         mean high entropy in the transport plan
185 | 
186 |     Returns:
187 |     --------
188 |     tensor
189 |         scalar Sinkhorn loss that pulls each seeded topic toward its
190 |         anchor words
191 | 
192 | 
193 |     """
194 | 
195 |     m = - torch.log(softmax_top + 1e-12)
196 |     loss = torch.cat([m[:, i].mean(axis = 1).reshape(1, -1) for i in index]).to(m.device)
197 |     #print(loss.shape)
198 |     b = torch.ones(loss.shape[1]).to(m.device)
199 |     a = torch.ones(loss.shape[0]).to(m.device)
200 | 
201 |     return ot.sinkhorn(a, b, loss, lambda_sh).sum()
202 | 
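# --- Illustrative sketch (added commentary, not part of the original file) ---
# Hypothetical use of the Sinkhorn prior above: seed two topics with anchor
# words and add the transport cost to the ELBO loss (values are made up):
#
#     softmax_top = model.get_topics()          # (n_topics, vocab_size)
#     index = [[12, 87, 301], [5, 44]]          # vocab ids of anchor words
#     prior = optimal_transport_prior(softmax_top, index, lambda_sh=1)
#     loss = stat['loss'] + beta * prior        # beta weighs the guidance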
203 | class VNTM(nn.Module):
204 |     """Hyperspherical (von Mises-Fisher) neural topic model that keeps track of its decoder output.
205 |     """
206 |     def __init__(self, hidden, normal, h_to_z, topics, layer, top_number, penalty, beta = 1, index = None, temp=10):
207 |         super(VNTM, self).__init__()
208 |         self.hidden = hidden
209 |         #self.normal = normal
210 |         self.h_to_z = h_to_z
211 |         self.topics = topics
212 |         self.output = None
213 |         self.index = index
214 |         self.drop = nn.Dropout(p=0.3)
215 |         self.fc_mean = nn.Linear(layer, top_number)
216 |         self.fc_var = nn.Linear(layer, 1)
217 |         self.num = top_number
218 |         self.penalty = penalty
219 |         self.temp = temp
220 |         self.beta = beta
221 | 
222 |     #self.dirichlet = torch.distributions.dirichlet.Dirichlet((torch.ones(self.topics.k)/self.topics.k).cuda())
223 |     def forward(self, x, input, device, n_sample=1, epoch = 0):
224 |         h = self.hidden(input)
225 |         h = self.drop(h)
226 |         z_mean = self.fc_mean(h)
227 |         z_mean = z_mean / z_mean.norm(dim=-1, keepdim=True)
228 |         # the `+ 1` prevents collapsing behaviors
229 |         z_var = F.softplus(self.fc_var(h)) + 1
230 | 
231 |         q_z = VonMisesFisher(z_mean, z_var)
232 |         p_z = HypersphericalUniform(self.num - 1, device=device)
233 |         kld = torch.distributions.kl.kl_divergence(q_z, p_z).mean().to(device)
234 |         #print(q_z)
235 |         #mu, log_sigma = self.normal(h)
236 |         # KL term: how far the posterior is from the uniform prior on the sphere
237 | 
238 |         #print(kld.shape)
239 |         rec_loss = 0
240 |         for i in range(n_sample):
241 |             # reparameterization trick
242 |             z = q_z.rsample()
243 |             #z = nn.Softmax()(z)
244 |             # decode
245 |             #print(z)
246 | 
247 |             z = self.h_to_z(self.temp * z)
248 |             self.output = z
249 |             #print(z)
250 | 
251 |             # get log probability for the reconstruction loss
252 |             log_prob = self.topics(z)
253 |             rec_loss = rec_loss - (log_prob * x).sum(dim=-1)
254 |         # average reconstruction loss
255 |         rec_loss = rec_loss / n_sample
256 |         #print(rec_loss.shape)
257 |         minus_elbo = rec_loss + kld
258 |         penalty, var, mean = topic_covariance_penalty(self.topics.topic_emb)
259 |         if self.index is not None:
260 |             sinkhorn = optimal_transport_prior(self.topics.get_topics(), self.index)
261 |         else:
262 |             sinkhorn = 0
263 | 
264 |         return {
265 |             'loss': minus_elbo + penalty * self.penalty + sinkhorn * self.beta,
266 |             'minus_elbo': minus_elbo,
267 |             'rec_loss': rec_loss,
268 |             'kld': kld
269 |         }
270 | 
271 |     def get_topics(self):
272 |         return self.topics.get_topics()
273 | 
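# --- Illustrative sketch (added commentary, not part of the original file) ---
# In VNTM the document encoding is mapped to a point on the unit sphere
# (z_mean) plus a concentration (z_var); q_z.rsample() draws a differentiable
# vMF sample, and softmax(temp * z) turns it into a topic mixture. Hypothetical
# shapes for a 20-topic model (temp defaults to 10):
#
#     h = encoder_output                  # (batch, layer)
#     z_mean = fc_mean(h)                 # (batch, 20), then unit-normalized
#     z_var = F.softplus(fc_var(h)) + 1   # (batch, 1), concentration kappa
#     theta = torch.softmax(10 * VonMisesFisher(z_mean, z_var).rsample(), dim=-1)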
274 | def get_mlp(features, activate):
275 |     """features: mlp size of each layer; an activation is appended after every linear layer."""
276 |     if isinstance(activate, str):
277 |         activate = getattr(nn, activate)
278 |     layers = []
279 |     for in_f, out_f in zip(features[:-1], features[1:]):
280 |         layers.append(nn.Linear(in_f, out_f))
281 |         layers.append(activate())
282 |     return nn.Sequential(*layers)
283 | 
284 | class GSM(NTM):
285 |     def __init__(self, hidden, normal, h_to_z, topics, penalty):
286 |         # h_to_z will output probabilities over topics
287 |         super(GSM, self).__init__(hidden, normal, h_to_z, topics)
288 |         self.penalty = penalty
289 | 
290 |     def forward(self, x, device, n_sample=1):
291 |         stat = super(GSM, self).forward(x, n_sample)
292 |         loss = stat['loss'].to(device)
293 |         penalty, var, mean = topic_covariance_penalty(self.topics.topic_emb)
294 | 
295 |         stat.update({
296 |             'loss': loss #+ penalty.to(device) * self.penalty,
297 |             # 'penalty_mean': mean,
298 |             # 'penalty_var': var,
299 |             # 'penalty': penalty.to(device) * self.penalty,
300 |         })
301 | 
302 |         return stat
303 | 
304 | class Topics(nn.Module):
305 |     def __init__(self, k, vocab_size, bias=True):
306 |         super(Topics, self).__init__()
307 |         self.k = k
308 |         self.vocab_size = vocab_size
309 |         self.topic = nn.Linear(k, vocab_size, bias=bias)
310 | 
311 |     def forward(self, logit):
312 |         # return the log_prob of vocab distribution
313 |         return torch.log_softmax(self.topic(logit), dim=-1)
314 | 
315 |     def get_topics(self):
316 |         return torch.softmax(self.topic.weight.data.transpose(0, 1), dim=-1)
317 | 
318 |     def get_topic_word_logit(self):
319 |         """topic x V.
320 |         Return the logits instead of probability distribution
321 |         """
322 |         return self.topic.weight.transpose(0, 1)
323 | 
324 | 
325 | class TopicModel:
326 |     def __init__(self, epochs=20, batch_size=256, gpu_num=1, numb_embeddings=20,
327 |                  learning_rate=0.002, weight_decay=1.2e-6, penalty=1, beta = 1, temp = 10,
328 |                  top_n_words=20, num_representative_docs=5, top_n_topics=100, embedding_dim=100):
329 | 
330 |         self.dataset = None
331 |         self.epochs = epochs
332 |         self.batch_size = batch_size
333 |         self.gpu_num = gpu_num
334 |         self.numb_embeddings = numb_embeddings
335 |         self.learning_rate = learning_rate
336 |         self.weight_decay = weight_decay
337 |         self.penalty = penalty
338 |         self.top_n_words = top_n_words
339 |         self.num_representative_docs = num_representative_docs
340 |         self.top_n_topics = top_n_topics
341 |         self.embedding_dim = embedding_dim
342 |         self.device = torch.device(f"cuda:{gpu_num}" if torch.cuda.is_available() else "cpu")
343 |         self.beta = beta
344 |         self.temp = temp
345 |         self.z = None
346 |         self.model = None
347 | 
348 |     def train(self, X, Input, batch_size):
349 |         self.model.train()
350 |         total_nll = 0.0
351 |         total_kld = 0.0
352 | 
353 |         indices = torch.randperm(X.shape[0])
354 |         indices = torch.split(indices, batch_size)
355 |         length = len(indices)
356 |         for idx, ind in enumerate(indices):
357 |             data_batch = X[ind].to(self.device).float()
358 |             emb_batch = Input[ind].to(self.device).float()
359 |             d = self.model(x = data_batch, input = emb_batch, device = self.device)
360 | 
361 |             total_nll += d['rec_loss'].sum().item() / batch_size
362 |             total_kld += d['kld'].sum().item() / batch_size
363 |             loss = d['loss']
364 | 
365 |             self.optimizer.zero_grad()
366 |             loss.sum().backward()
367 |             self.optimizer.step()
368 |             self.scheduler.step()
369 | 
370 |         print(total_nll/length, total_kld/length)
371 | 
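    # --- Illustrative sketch (added commentary, not part of the original file) ---
    # Hypothetical end-to-end use of this class (the dataset and cache name are
    # made up; fit_transform below handles preprocessing, training and topic
    # extraction):
    #
    #     tm = TopicModel(epochs=20, numb_embeddings=20)
    #     topics, doc_topic = tm.fit_transform(list_of_docs, name='agnews_cache')
    #     tm.get_topics(0)                  # top (word, score) pairs for topic 0
    #     tm.get_representative_docs(0)     # most representative documents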
372 |     def fit_transform(self, dataset, name, index = []):
373 |         self.dataset = dataset
374 |         self.name = name
375 |         self.tp = TextProcessor(self.dataset, self.name)
376 |         self.tp.process()
377 |         bag_of_words = torch.tensor(self.tp.bow)
378 |         embedding_text = torch.tensor(self.tp.embeddings)
379 |         if index != []:
380 |             index_words = [[self.tp.word_to_index[word] for word in ind if word in self.tp.word_to_index] for ind in index]
381 |         else:
382 |             index_words = None
383 |         print(index_words)
384 |         print(embedding_text.shape)
385 |         #print(bag_of_words.shape)
386 |         # build the encoder, the variational parameters and the decoder
387 |         layer = embedding_text.shape[1]//16
388 |         hidden = get_mlp([embedding_text.shape[1], embedding_text.shape[1]//4, layer], nn.GELU)
389 |         normal = NormalParameter(layer, self.numb_embeddings)
390 |         h_to_z = nn.Softmax(dim=-1)
391 |         embedding = nn.Embedding(bag_of_words.shape[1], 100)  # 100-d to match the GloVe vectors below
392 |         # p1d = (0, 0, 0, 10000 - company1.embeddings.shape[0]) # pad last dim by 1 on each side
393 |         # out = F.pad(company1.embeddings, p1d, "constant", 0) # effectively zero padding
394 | 
395 |         glove_vectors = gensim.downloader.load('glove-wiki-gigaword-100')
396 |         embed = np.asarray([glove_vectors[self.tp.index_to_word[i]] if self.tp.index_to_word[i] in glove_vectors else np.asarray([1]*100) for i in self.tp.index_to_word ])
397 |         print(embed.shape)
398 |         embedding.weight = torch.nn.Parameter(torch.from_numpy(embed).float())
399 |         embedding.weight.requires_grad = True
400 | 
401 | 
402 | 
403 |         topics = EmbTopic(embedding = embedding,
404 |                           k = self.numb_embeddings, normalize = False)
405 | 
406 | 
407 | 
408 | 
409 |         self.model = VNTM(hidden = hidden,
410 |                           normal = normal,
411 |                           h_to_z = h_to_z,
412 |                           topics = topics,
413 |                           layer = layer,
414 |                           top_number = self.numb_embeddings,
415 |                           index = index_words,
416 |                           penalty = self.penalty,
417 |                           beta = self.beta,
418 |                           temp = self.temp,
419 |                           ).to(self.device).float()
420 | 
421 |         #batch_size = 256
422 |         self.optimizer = optim.Adam(self.model.parameters(),
423 |                                     lr=self.learning_rate,
424 |                                     weight_decay=self.weight_decay)
425 | 
426 | 
427 | 
428 | 
429 |         self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=0.002, steps_per_epoch=int(bag_of_words.shape[0]/self.batch_size) + 1, epochs=self.epochs)
430 | 
431 |         # training loop
432 |         for epoch in range(self.epochs):
433 |             self.train(bag_of_words, embedding_text, self.batch_size)
434 | 
435 |         # Store the topics
436 |         emb = self.model.topics.get_topics().cpu().detach().numpy()
437 |         self.topics = [[self.tp.index_to_word[ind] for ind in np.argsort(emb[i])[::-1][:self.top_n_topics]] for i in range(self.numb_embeddings)] # truncated to top_n_topics
438 |         self.topics_score = [[score for score in np.sort(emb[i])[::-1][:self.top_n_topics]] for i in range(self.numb_embeddings)]
439 |         # Compute and store the documents-topics distributions
440 |         data_batch = bag_of_words.float()
441 |         self.model.cpu()
442 | 
443 |         z = self.model.hidden(embedding_text.float())
444 |         z_mean = self.model.fc_mean(z)
445 |         z_mean = z_mean / z_mean.norm(dim=-1, keepdim=True)
446 |         self.z = self.model.h_to_z(z_mean).detach().numpy()
447 |         self.topic_doc = [[ind for ind in np.argsort(self.z[:, i])[::-1][:100] ] for i in range(self.numb_embeddings)] # top 100 documents per topic
448 |         self.topic_doc_score = [[score for score in np.sort(self.z[:, i])[::-1][:100] ] for i in range(self.numb_embeddings)] # top 100 documents per topic
449 | 
450 | 
451 |         return self.topics, self.z
452 | 
453 |     def get_topics(self, index):
454 |         return [(i, j) for i, j in zip(self.topics[index], self.topics_score[index])][:self.top_n_words]
455 | 
456 |     def get_representative_docs(self, index):
457 |         return [(self.dataset[i], j) for i, j in zip(self.topic_doc[index], self.topic_doc_score[index])][:self.num_representative_docs]
458 | 
459 |     def topic_word_matrix(self):
460 |         return self.model.topics.get_topics().cpu().detach().numpy()
461 | 
462 |     def topic_keywords(self):
463 |         return self.topics
464 | 
465 |     def visualize_topic_similarity(self):
466 |         # Compute topic similarity matrix
467 |         topic_word_matrix = self.model.topics.topic_emb.detach().numpy()
468 |         similarity_matrix = np.matmul(topic_word_matrix, topic_word_matrix.T)
469 | 
470 |         # Plot the similarity matrix as a heatmap
471 |         plt.figure(figsize=(10, 10))
472 |         sns.heatmap(similarity_matrix, cmap="YlGnBu", square=True)
473 |         plt.title('Topic Similarity Heatmap')
474 |         plt.xlabel('Topic IDs')
475 |         plt.ylabel('Topic IDs')
476 |         plt.show()
477 | 
478 |     def visualize_topic_keywords(self, topic_id, num_keywords=10):
479 |         # Get top keywords for the given topic
480 |         topic_keywords = 
self.get_topics(topic_id)[:num_keywords] 481 | words, scores = zip(*topic_keywords) 482 | 483 | # Generate the bar plot 484 | plt.figure(figsize=(10, 5)) 485 | plt.barh(words, scores, color='skyblue') 486 | plt.xlabel("Keyword Importance") 487 | plt.title(f"Top {num_keywords} Keywords for Topic {topic_id}") 488 | plt.gca().invert_yaxis() 489 | plt.show() 490 | 491 | def get_document_info(self, top_n_words=10): 492 | data = [] 493 | for topic_id in range(self.numb_embeddings): 494 | topic_keywords = self.get_topics(topic_id)[:top_n_words] 495 | topic_keywords_str = "_".join([word for word, _ in topic_keywords[:3]]) 496 | 497 | # Get the document that has the highest probability for this topic 498 | doc_indices = np.argsort(self.z[:, topic_id])[::-1] 499 | representative_doc_index = doc_indices[0] 500 | representative_doc = self.dataset[representative_doc_index] 501 | 502 | # Count the number of documents that have this topic as their dominant topic 503 | dominant_topics = np.argmax(self.z, axis=1) 504 | num_docs = np.sum(dominant_topics == topic_id) 505 | 506 | data.append([topic_id, f"{topic_id}_{topic_keywords_str}", topic_keywords_str, representative_doc, num_docs]) 507 | 508 | df = pd.DataFrame(data, columns=["Topic", "Name", "Top_n_words", "Representative_Doc", "Num_Docs"]) 509 | return df 510 | 511 | def train_model(self, dataset, name, hyperparameters={}, top_words=10): 512 | 513 | self.top_n_words = top_words 514 | # Extract hyperparameters and set them as attributes 515 | if 'epochs' in hyperparameters: 516 | self.epochs = hyperparameters['epochs'] 517 | if 'batch_size' in hyperparameters: 518 | self.batch_size = hyperparameters['batch_size'] 519 | if 'gpu_num' in hyperparameters: 520 | self.gpu_num = hyperparameters['gpu_num'] 521 | if 'numb_embeddings' in hyperparameters: 522 | self.numb_embeddings = hyperparameters['numb_embeddings'] 523 | if 'learning_rate' in hyperparameters: 524 | self.learning_rate = hyperparameters['learning_rate'] 525 | if 'weight_decay' in hyperparameters: 526 | self.weight_decay = hyperparameters['weight_decay'] 527 | if 'penalty' in hyperparameters: 528 | self.penalty = hyperparameters['penalty'] 529 | if 'beta' in hyperparameters: 530 | self.beta = hyperparameters['beta'] 531 | if 'temp' in hyperparameters: 532 | self.temp = hyperparameters['temp'] 533 | 534 | if 'num_representative_docs' in hyperparameters: 535 | self.num_representative_docs = hyperparameters['num_representative_docs'] 536 | if 'top_n_topics' in hyperparameters: 537 | self.top_n_topics = hyperparameters['top_n_topics'] 538 | if 'embedding_dim' in hyperparameters: 539 | self.embedding_dim = hyperparameters['embedding_dim'] 540 | 541 | # Check if the model has been trained 542 | if self.z is None: 543 | self.fit_transform(dataset, name) 544 | 545 | # Create the model output 546 | model_output = {} 547 | model_output['topics'] = [i[:top_words] for i in self.topics] 548 | model_output['topic-word-matrix'] = self.model.topics.get_topics().cpu().detach().numpy() 549 | model_output['topic-document-matrix'] = self.z.T 550 | 551 | return model_output -------------------------------------------------------------------------------- /src/topicmodeling/model_evaluate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, 
T5ForConditionalGeneration\n", 10 | "# model = AutoModelForSeq2SeqLM.from_pretrained('model/flant5_nest_peft_PREFIX_TUNING_SEQ_2_SEQ_LMflan-t5-base1', local_files_only=True)\n", 11 | "tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-large')\n", 12 | "\n" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 24, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "data": { 22 | "text/plain": [ 23 | "FlanT5NestCNNAutoencoder(\n", 24 | " (model): T5ForConditionalGeneration(\n", 25 | " (shared): Embedding(32128, 1024)\n", 26 | " (encoder): T5Stack(\n", 27 | " (embed_tokens): Embedding(32128, 1024)\n", 28 | " (block): ModuleList(\n", 29 | " (0): T5Block(\n", 30 | " (layer): ModuleList(\n", 31 | " (0): T5LayerSelfAttention(\n", 32 | " (SelfAttention): T5Attention(\n", 33 | " (q): Linear(in_features=1024, out_features=1024, bias=False)\n", 34 | " (k): Linear(in_features=1024, out_features=1024, bias=False)\n", 35 | " (v): Linear(in_features=1024, out_features=1024, bias=False)\n", 36 | " (o): Linear(in_features=1024, out_features=1024, bias=False)\n", 37 | " (relative_attention_bias): Embedding(32, 16)\n", 38 | " )\n", 39 | " (layer_norm): T5LayerNorm()\n", 40 | " (dropout): Dropout(p=0.1, inplace=False)\n", 41 | " )\n", 42 | " (1): T5LayerFF(\n", 43 | " (DenseReluDense): T5DenseGatedActDense(\n", 44 | " (wi_0): Linear(in_features=1024, out_features=2816, bias=False)\n", 45 | " (wi_1): Linear(in_features=1024, out_features=2816, bias=False)\n", 46 | " (wo): Linear(in_features=2816, out_features=1024, bias=False)\n", 47 | " (dropout): Dropout(p=0.1, inplace=False)\n", 48 | " (act): NewGELUActivation()\n", 49 | " )\n", 50 | " (layer_norm): T5LayerNorm()\n", 51 | " (dropout): Dropout(p=0.1, inplace=False)\n", 52 | " )\n", 53 | " )\n", 54 | " )\n", 55 | " (1-23): 23 x T5Block(\n", 56 | " (layer): ModuleList(\n", 57 | " (0): T5LayerSelfAttention(\n", 58 | " (SelfAttention): T5Attention(\n", 59 | " (q): Linear(in_features=1024, out_features=1024, bias=False)\n", 60 | " (k): Linear(in_features=1024, out_features=1024, bias=False)\n", 61 | " (v): Linear(in_features=1024, out_features=1024, bias=False)\n", 62 | " (o): Linear(in_features=1024, out_features=1024, bias=False)\n", 63 | " )\n", 64 | " (layer_norm): T5LayerNorm()\n", 65 | " (dropout): Dropout(p=0.1, inplace=False)\n", 66 | " )\n", 67 | " (1): T5LayerFF(\n", 68 | " (DenseReluDense): T5DenseGatedActDense(\n", 69 | " (wi_0): Linear(in_features=1024, out_features=2816, bias=False)\n", 70 | " (wi_1): Linear(in_features=1024, out_features=2816, bias=False)\n", 71 | " (wo): Linear(in_features=2816, out_features=1024, bias=False)\n", 72 | " (dropout): Dropout(p=0.1, inplace=False)\n", 73 | " (act): NewGELUActivation()\n", 74 | " )\n", 75 | " (layer_norm): T5LayerNorm()\n", 76 | " (dropout): Dropout(p=0.1, inplace=False)\n", 77 | " )\n", 78 | " )\n", 79 | " )\n", 80 | " )\n", 81 | " (final_layer_norm): T5LayerNorm()\n", 82 | " (dropout): Dropout(p=0.1, inplace=False)\n", 83 | " )\n", 84 | " (decoder): T5Stack(\n", 85 | " (embed_tokens): Embedding(32128, 1024)\n", 86 | " (block): ModuleList(\n", 87 | " (0): T5Block(\n", 88 | " (layer): ModuleList(\n", 89 | " (0): T5LayerSelfAttention(\n", 90 | " (SelfAttention): T5Attention(\n", 91 | " (q): Linear(in_features=1024, out_features=1024, bias=False)\n", 92 | " (k): Linear(in_features=1024, out_features=1024, bias=False)\n", 93 | " (v): Linear(in_features=1024, out_features=1024, bias=False)\n", 94 | " (o): Linear(in_features=1024, out_features=1024, bias=False)\n", 95 | " 
(relative_attention_bias): Embedding(32, 16)\n", 96 | " )\n", 97 | " (layer_norm): T5LayerNorm()\n", 98 | " (dropout): Dropout(p=0.1, inplace=False)\n", 99 | " )\n", 100 | " (1): T5LayerCrossAttention(\n", 101 | " (EncDecAttention): T5Attention(\n", 102 | " (q): Linear(in_features=1024, out_features=1024, bias=False)\n", 103 | " (k): Linear(in_features=1024, out_features=1024, bias=False)\n", 104 | " (v): Linear(in_features=1024, out_features=1024, bias=False)\n", 105 | " (o): Linear(in_features=1024, out_features=1024, bias=False)\n", 106 | " )\n", 107 | " (layer_norm): T5LayerNorm()\n", 108 | " (dropout): Dropout(p=0.1, inplace=False)\n", 109 | " )\n", 110 | " (2): T5LayerFF(\n", 111 | " (DenseReluDense): T5DenseGatedActDense(\n", 112 | " (wi_0): Linear(in_features=1024, out_features=2816, bias=False)\n", 113 | " (wi_1): Linear(in_features=1024, out_features=2816, bias=False)\n", 114 | " (wo): Linear(in_features=2816, out_features=1024, bias=False)\n", 115 | " (dropout): Dropout(p=0.1, inplace=False)\n", 116 | " (act): NewGELUActivation()\n", 117 | " )\n", 118 | " (layer_norm): T5LayerNorm()\n", 119 | " (dropout): Dropout(p=0.1, inplace=False)\n", 120 | " )\n", 121 | " )\n", 122 | " )\n", 123 | " (1-23): 23 x T5Block(\n", 124 | " (layer): ModuleList(\n", 125 | " (0): T5LayerSelfAttention(\n", 126 | " (SelfAttention): T5Attention(\n", 127 | " (q): Linear(in_features=1024, out_features=1024, bias=False)\n", 128 | " (k): Linear(in_features=1024, out_features=1024, bias=False)\n", 129 | " (v): Linear(in_features=1024, out_features=1024, bias=False)\n", 130 | " (o): Linear(in_features=1024, out_features=1024, bias=False)\n", 131 | " )\n", 132 | " (layer_norm): T5LayerNorm()\n", 133 | " (dropout): Dropout(p=0.1, inplace=False)\n", 134 | " )\n", 135 | " (1): T5LayerCrossAttention(\n", 136 | " (EncDecAttention): T5Attention(\n", 137 | " (q): Linear(in_features=1024, out_features=1024, bias=False)\n", 138 | " (k): Linear(in_features=1024, out_features=1024, bias=False)\n", 139 | " (v): Linear(in_features=1024, out_features=1024, bias=False)\n", 140 | " (o): Linear(in_features=1024, out_features=1024, bias=False)\n", 141 | " )\n", 142 | " (layer_norm): T5LayerNorm()\n", 143 | " (dropout): Dropout(p=0.1, inplace=False)\n", 144 | " )\n", 145 | " (2): T5LayerFF(\n", 146 | " (DenseReluDense): T5DenseGatedActDense(\n", 147 | " (wi_0): Linear(in_features=1024, out_features=2816, bias=False)\n", 148 | " (wi_1): Linear(in_features=1024, out_features=2816, bias=False)\n", 149 | " (wo): Linear(in_features=2816, out_features=1024, bias=False)\n", 150 | " (dropout): Dropout(p=0.1, inplace=False)\n", 151 | " (act): NewGELUActivation()\n", 152 | " )\n", 153 | " (layer_norm): T5LayerNorm()\n", 154 | " (dropout): Dropout(p=0.1, inplace=False)\n", 155 | " )\n", 156 | " )\n", 157 | " )\n", 158 | " )\n", 159 | " (final_layer_norm): T5LayerNorm()\n", 160 | " (dropout): Dropout(p=0.1, inplace=False)\n", 161 | " )\n", 162 | " (lm_head): Linear(in_features=1024, out_features=32128, bias=False)\n", 163 | " )\n", 164 | " (encoder): CNNEncoder(\n", 165 | " (encoder): Sequential(\n", 166 | " (0): Conv1d(512, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", 167 | " (1): ReLU()\n", 168 | " (2): Conv1d(128, 16, kernel_size=(3,), stride=(1,), padding=(1,))\n", 169 | " (3): ReLU()\n", 170 | " (4): Conv1d(16, 4, kernel_size=(3,), stride=(1,), padding=(1,))\n", 171 | " )\n", 172 | " )\n", 173 | " (decoder): CNNDecoder(\n", 174 | " (decoder): Sequential(\n", 175 | " (0): Conv1d(4, 16, kernel_size=(3,), stride=(1,), 
padding=(1,))\n", 176 | " (1): ReLU()\n", 177 | " (2): Conv1d(16, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", 178 | " (3): ReLU()\n", 179 | " (4): Conv1d(128, 512, kernel_size=(3,), stride=(1,), padding=(1,))\n", 180 | " (5): Sigmoid()\n", 181 | " )\n", 182 | " )\n", 183 | ")" 184 | ] 185 | }, 186 | "execution_count": 24, 187 | "metadata": {}, 188 | "output_type": "execute_result" 189 | } 190 | ], 191 | "source": [ 192 | "from flanT5_cnn_lighting import T5AutoConfig, FlanT5NestCNNAutoencoder\n", 193 | "import torch\n", 194 | "config = T5AutoConfig(hidden_size1=512, hidden_size3=4, \n", 195 | " hidden_size2=768, output_size=4 * 768, \n", 196 | " model = 'google/flan-t5-large') # Replace with your config\n", 197 | "model = FlanT5NestCNNAutoencoder(config)\n", 198 | "model.load_state_dict(torch.load('/home/weijiexu/flant5_nest_peftflan-t5-large8'))\n", 199 | "model.eval() # Set the model to inference mode" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 47, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "model.eval()\n", 209 | "strings = \"I insist on this point because the market cannot regulate all and that, accordingly, it does not have there nor cannot have there of competition without limits.\"\n", 210 | "inputs = tokenizer(strings, return_tensors=\"pt\", padding='max_length', max_length = 512).input_ids.cuda()\n", 211 | "am = tokenizer(strings, return_tensors=\"pt\", padding='max_length', max_length = 512).attention_mask.cuda()\n", 212 | "outputs = model.cuda().generate(inputs, am, max_length = 512)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 48, 218 | "metadata": {}, 219 | "outputs": [ 220 | { 221 | "name": "stdout", 222 | "output_type": "stream", 223 | "text": [ 224 | "['I stress this point because the market cannot regulate everything, which is why competition is not completely unrestricted, nor should it be.']\n" 225 | ] 226 | } 227 | ], 228 | "source": [ 229 | "print(tokenizer.batch_decode(outputs, skip_special_tokens=True)) " 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 5, 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "name": "stderr", 239 | "output_type": "stream", 240 | "text": [ 241 | "Downloading readme: 100%|██████████| 396/396 [00:00<00:00, 289kB/s]\n" 242 | ] 243 | }, 244 | { 245 | "name": "stdout", 246 | "output_type": "stream", 247 | "text": [ 248 | "Downloading and preparing dataset None/None to /home/weijiexu/.cache/huggingface/datasets/xwjzds___parquet/xwjzds--sentence_paraphase-3dff1f881395cfc3/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...\n" 249 | ] 250 | }, 251 | { 252 | "name": "stderr", 253 | "output_type": "stream", 254 | "text": [ 255 | "Downloading data: 100%|██████████| 21.4M/21.4M [00:00<00:00, 32.9MB/s]\n", 256 | "Downloading data files: 100%|██████████| 1/1 [00:02<00:00, 2.04s/it]\n", 257 | "Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 909.83it/s]\n", 258 | " \r" 259 | ] 260 | }, 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "Dataset parquet downloaded and prepared to /home/weijiexu/.cache/huggingface/datasets/xwjzds___parquet/xwjzds--sentence_paraphase-3dff1f881395cfc3/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. 
Subsequent calls will reuse this data.\n" 266 | ] 267 | }, 268 | { 269 | "name": "stderr", 270 | "output_type": "stream", 271 | "text": [ 272 | "100%|██████████| 1/1 [00:00<00:00, 226.28it/s]\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "# from datasets import load_dataset\n", 278 | "\n", 279 | "# dataset = load_dataset(\"xwjzds/sentence_paraphase\")" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 8, 285 | "metadata": {}, 286 | "outputs": [ 287 | { 288 | "data": { 289 | "text/plain": [ 290 | "{'input': 'U.S. prosecutors have arrested more than 130 individuals and have seized more than $17 million in a continuing crackdown on Internet fraud and abuse.',\n", 291 | " 'output': 'More than 130 people have been arrested and $17 million worth of property seized in an Internet fraud sweep announced Friday by three U.S. government agencies.'}" 292 | ] 293 | }, 294 | "execution_count": 8, 295 | "metadata": {}, 296 | "output_type": "execute_result" 297 | } 298 | ], 299 | "source": [ 300 | "# dataset['train'][0]" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [] 309 | } 310 | ], 311 | "metadata": { 312 | "kernelspec": { 313 | "display_name": "env-01", 314 | "language": "python", 315 | "name": "python3" 316 | }, 317 | "language_info": { 318 | "codemirror_mode": { 319 | "name": "ipython", 320 | "version": 3 321 | }, 322 | "file_extension": ".py", 323 | "mimetype": "text/x-python", 324 | "name": "python", 325 | "nbconvert_exporter": "python", 326 | "pygments_lexer": "ipython3", 327 | "version": "3.10.10" 328 | }, 329 | "orig_nbformat": 4 330 | }, 331 | "nbformat": 4, 332 | "nbformat_minor": 2 333 | } 334 | -------------------------------------------------------------------------------- /src/topicmodeling/preprocess.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/weijiexu/anaconda3/envs/env-01/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n", 14 | "/local/home/weijiexu/workspace/Vontss/src/topicmodeling/T5_Encoder/CNN_Encoder.py:15: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. 
Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n", 15 | " metric = datasets.load_metric('sacrebleu')\n", 16 | "Found cached dataset parquet (/home/weijiexu/.cache/huggingface/datasets/xwjzds___parquet/xwjzds--ag_news-14abcf7379d2aad0/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n", 17 | "100%|██████████| 2/2 [00:00<00:00, 329.24it/s]\n", 18 | "[nltk_data] Downloading package punkt to /home/weijiexu/nltk_data...\n", 19 | "[nltk_data] Package punkt is already up-to-date!\n", 20 | "[nltk_data] Downloading package wordnet to /home/weijiexu/nltk_data...\n", 21 | "[nltk_data] Package wordnet is already up-to-date!\n", 22 | "[nltk_data] Downloading package averaged_perceptron_tagger to\n", 23 | "[nltk_data] /home/weijiexu/nltk_data...\n", 24 | "[nltk_data] Package averaged_perceptron_tagger is already up-to-\n", 25 | "[nltk_data] date!\n", 26 | "[nltk_data] Downloading package stopwords to\n", 27 | "[nltk_data] /home/weijiexu/nltk_data...\n", 28 | "[nltk_data] Package stopwords is already up-to-date!\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "from datasets import load_dataset\n", 34 | "from preprocess import TextProcessor\n", 35 | "df = load_dataset('xwjzds/ag_news')\n", 36 | "\n", 37 | "model_output = TextProcessor(df['train']['text'][:10])" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "name": "stdout", 47 | "output_type": "stream", 48 | "text": [ 49 | "[\"Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\\\band of ultra-cynics, are seeing green again.\"]\n", 50 | "['Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\\\which has a reputation for making well-timed and occasionally\\\\controversial plays in the defense industry, has quietly placed\\\\its bets on another part of the market.']\n", 51 | "[\"Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\\\\about the economy and the outlook for earnings are expected to\\\\hang over the stock market next week during the depth of the\\\\summer doldrums.\"]\n", 52 | "['Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\\\\flows from the main pipeline in southern Iraq after\\\\intelligence showed a rebel militia could strike\\\\infrastructure, an oil official said on Saturday.']\n", 53 | "['Oil prices soar to all-time record, posing new menace to US economy (AFP) AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.']\n", 54 | "['Stocks End Up, But Near Year Lows (Reuters) Reuters - Stocks ended slightly higher on Friday\\\\but stayed near lows for the year as oil prices surged past #36;46\\\\a barrel, offsetting a positive outlook from computer maker\\\\Dell Inc. 
(DELL.O)']\n", 55 | "[\"Money Funds Fell in Latest Week (AP) AP - Assets of the nation's retail money market mutual funds fell by #36;1.17 billion in the latest week to #36;849.98 trillion, the Investment Company Institute said Thursday.\"]\n", 56 | "['Fed minutes show dissent over inflation (USATODAY.com) USATODAY.com - Retail sales bounced back a bit in July, and new claims for jobless benefits fell last week, the government said Thursday, indicating the economy is improving from a midsummer slump.']\n", 57 | "['Safety Net (Forbes.com) Forbes.com - After earning a PH.D. in Sociology, Danny Bazil Riley started to work as the general manager at a commercial real estate firm at an annual base salary of #36;70,000. Soon after, a financial planner stopped by his desk to drop off brochures about insurance benefits available through his employer. But, at 32, \"buying insurance was the furthest thing from my mind,\" says Riley.']\n", 58 | "[\"Wall St. Bears Claw Back Into the Black NEW YORK (Reuters) - Short-sellers, Wall Street's dwindling band of ultra-cynics, are seeing green again.\"]\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "model_output.embedding_generation()" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "data": { 73 | "text/plain": [ 74 | "[array([-1.6417611, -2.2995715, -1.9510227, ..., 0.7497782, 2.988791 ,\n", 75 | " 1.4079323], dtype=float32),\n", 76 | " array([-2.3591247, -2.4123752, -2.8603513, ..., 0.3443652, 1.8753424,\n", 77 | " 1.6522243], dtype=float32),\n", 78 | " array([-1.8146471 , -1.4705577 , -2.005217 , ..., -0.06101079,\n", 79 | " 2.519045 , 1.517622 ], dtype=float32),\n", 80 | " array([-2.1401658, -1.7492014, -2.4797866, ..., -1.1438857, 3.1832254,\n", 81 | " 2.0104322], dtype=float32),\n", 82 | " array([-1.604986 , -1.6979145, -2.130483 , ..., -0.823496 , 3.3996437,\n", 83 | " 0.7275871], dtype=float32),\n", 84 | " array([-2.014131 , -1.7336566 , -2.1519225 , ..., -0.40943134,\n", 85 | " 1.4521998 , 1.5252116 ], dtype=float32),\n", 86 | " array([-1.6848204, -2.1754618, -2.0295758, ..., 1.8960128, 2.2492247,\n", 87 | " -0.1974375], dtype=float32),\n", 88 | " array([-1.9613777 , -1.8956141 , -2.031622 , ..., 0.69780046,\n", 89 | " 1.278298 , 2.0490847 ], dtype=float32),\n", 90 | " array([-2.2047985 , -1.6887155 , -2.8186448 , ..., 0.48909074,\n", 91 | " 0.4389767 , 2.0311885 ], dtype=float32),\n", 92 | " array([-1.5503036 , -2.0949159 , -1.6868635 , ..., 0.23605314,\n", 93 | " 2.8458438 , 1.5262506 ], dtype=float32)]" 94 | ] 95 | }, 96 | "execution_count": 3, 97 | "metadata": {}, 98 | "output_type": "execute_result" 99 | } 100 | ], 101 | "source": [ 102 | "model_output.embeddings" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [] 111 | } 112 | ], 113 | "metadata": { 114 | "kernelspec": { 115 | "display_name": "env-01", 116 | "language": "python", 117 | "name": "python3" 118 | }, 119 | "language_info": { 120 | "codemirror_mode": { 121 | "name": "ipython", 122 | "version": 3 123 | }, 124 | "file_extension": ".py", 125 | "mimetype": "text/x-python", 126 | "name": "python", 127 | "nbconvert_exporter": "python", 128 | "pygments_lexer": "ipython3", 129 | "version": "3.10.10" 130 | }, 131 | "orig_nbformat": 4 132 | }, 133 | "nbformat": 4, 134 | "nbformat_minor": 2 135 | } 136 | -------------------------------------------------------------------------------- /src/topicmodeling/preprocess.py: 
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import defaultdict
3 | from datasets import load_dataset, Dataset
4 | from itertools import chain
5 | import multiprocessing
6 | from multiprocessing import Pool
7 | import nltk
8 | from nltk.corpus import stopwords, wordnet
9 | from nltk.stem import WordNetLemmatizer
10 | from nltk.tokenize import word_tokenize
11 | from nltk import pos_tag
12 | import numpy as np
13 | import pandas as pd
14 | from sklearn.feature_extraction.text import TfidfVectorizer
15 | import string
16 | import pickle
17 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
18 | from flanT5_cnn_lighting import FlanT5NestCNNAutoencoder, T5AutoConfig
19 | from sentence_encoder import Encode_Sentence
20 | import torch
21 | 
22 | class TextProcessor:
23 | 
24 | 
25 |     def __init__(self, data, name):
26 |         self.data = data
27 |         self.name = name
28 |         self.bow = None
29 |         self.word_to_index = None
30 |         self.index_to_word = None
31 |         self.lemmatized_sentences = None
32 |         self.embeddings = None
33 | 
34 |         nltk.download('punkt')
35 |         nltk.download('wordnet')
36 |         nltk.download('averaged_perceptron_tagger')
37 |         nltk.download('stopwords')
38 | 
39 |     def __str__(self):
40 |         """String representation of TextProcessor"""
41 |         return f'TextProcessor(len(data)={len(self.data)})'
42 | 
43 | 
44 |     def get_wordnet_pos(self, word):
45 |         tag = pos_tag([word])[0][1][0].upper()
46 |         tag_dict = {"J": wordnet.ADJ,
47 |                     "N": wordnet.NOUN,
48 |                     "V": wordnet.VERB,
49 |                     "R": wordnet.ADV}
50 | 
51 |         return tag_dict.get(tag, wordnet.NOUN)
52 | 
53 |     def convert_to_bag_of_words(self, list_of_lists, min_freq=10, max_freq_ratio=0.05):
54 |         # count word frequencies over all documents
55 |         word_freq = defaultdict(int)
56 |         for lst in list_of_lists:
57 |             for word in lst:
58 |                 word_freq[word] += 1
59 |         max_freq = len(list_of_lists) * max_freq_ratio
60 |         vocabulary = {word for word, count in word_freq.items() if min_freq <= count < max_freq}
61 |         word_to_index = {word: i for i, word in enumerate(vocabulary)}
62 |         index_to_word = {i: word for word, i in word_to_index.items()}
63 |         num_lists = len(list_of_lists)
64 |         vocab_size = len(vocabulary)
65 |         bag_of_words = [[0] * vocab_size for _ in range(num_lists)]
66 |         self.lemmas = []
67 |         for i, lst in enumerate(list_of_lists):
68 |             lemma = [word for word in lst if word in word_to_index]
69 |             for word in lst:
70 |                 # count only in-vocabulary words
71 |                 if word in word_to_index:
72 |                     index = word_to_index[word]
73 |                     bag_of_words[i][index] += 1
74 | 
75 |             self.lemmas.append(lemma)
76 | 
77 |         self.bow, self.word_to_index, self.index_to_word = bag_of_words, word_to_index, index_to_word
78 | 
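    # --- Illustrative sketch (added commentary, not part of the original file) ---
    # Hypothetical behavior of convert_to_bag_of_words on toy input (the
    # frequency band is widened so nothing is filtered out):
    #
    #     tp.convert_to_bag_of_words([['cat', 'dog'], ['dog', 'dog']],
    #                                min_freq=1, max_freq_ratio=2.0)
    #
    # tp.bow becomes a 2 x 2 count matrix, tp.word_to_index maps
    # {'cat': i, 'dog': j}, and words outside the frequency band would be
    # dropped both from the vocabulary and from tp.lemmas.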
WordNetLemmatizer() 104 | # lemmatized_sentences = [] 105 | # stop_words = set(stopwords.words('english')) 106 | # table = str.maketrans(string.punctuation, ' ' * len(string.punctuation)) 107 | 108 | # for index, sentence in enumerate(self.data): 109 | # if index % 100 == 0: 110 | # print(index) 111 | # sentence = sentence.translate(table).lower().replace(" ", " ") 112 | 113 | # words = word_tokenize(sentence) 114 | # lemmatized_words = [lemmatizer.lemmatize(word, self.get_wordnet_pos(word)) for word in words if word not in stop_words and lemmatizer.lemmatize(word, self.get_wordnet_pos(word)) != '' ] 115 | # lemmatized_words = [i for i in lemmatized_words if len(i) >= 3 and (not i.isdigit()) and ' ' not in i] 116 | # lemmatized_sentences.append(lemmatized_words) 117 | 118 | # self.lemmatized_sentences = lemmatized_sentences 119 | 120 | 121 | def worker(self, data_chunk): 122 | lemmatizer = WordNetLemmatizer() 123 | stop_words = set(stopwords.words('english')) 124 | table = str.maketrans(string.punctuation, ' ' * len(string.punctuation)) 125 | lemmatized_chunk = [] 126 | 127 | for sentence in data_chunk: 128 | sentence = sentence.translate(table).lower().replace(" ", " ") 129 | words = word_tokenize(sentence) 130 | lemmatized_words = [lemmatizer.lemmatize(word, self.get_wordnet_pos(word)) for word in words if word not in stop_words and lemmatizer.lemmatize(word, self.get_wordnet_pos(word)) != '' ] 131 | lemmatized_words = [i for i in lemmatized_words if len(i) >= 3 and (not i.isdigit()) and ' ' not in i] 132 | lemmatized_chunk.append(lemmatized_words) 133 | 134 | return lemmatized_chunk 135 | 136 | def lemmatize_sentences(self): 137 | num_processes = multiprocessing.cpu_count() 138 | pool = Pool(num_processes) 139 | data_chunks = np.array_split(self.data, num_processes) 140 | results = pool.map(self.worker, data_chunks) 141 | pool.close() 142 | pool.join() 143 | #print(results) 144 | self.lemmatized_sentences = [j for i in results for j in i] 145 | 146 | print(len(results), len(results[0])) 147 | 148 | def embedding_generation(self): 149 | 150 | 151 | 152 | config = T5AutoConfig(hidden_size1=512, hidden_size3=4, 153 | hidden_size2=768, output_size=4 * 768, 154 | 155 | model='google/flan-t5-large', 156 | 157 | ) # Replace with your config 158 | models = Encode_Sentence(config, max_length=512, 159 | pretrain_model = '/home/weijiexu/flant5_nest_peftflan-t5-large8', 160 | pretrain_model_token = '/home/weijiexu/flant5_nest_peft_PREFIX_TUNING_SEQ_2_SEQ_LMflan-t5-large8', 161 | ) 162 | self.embeddings = models.encode(self.data) 163 | print(self.embeddings.shape) 164 | 165 | def process(self): 166 | try: 167 | self.embeddings = torch.load(self.name + '.pt') 168 | 169 | except: 170 | self.embedding_generation() 171 | torch.save(self.embeddings, self.name) 172 | self.lemmatize_sentences() 173 | self.convert_to_bag_of_words(self.lemmatized_sentences) -------------------------------------------------------------------------------- /src/topicmodeling/sentence_encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 5 | from flanT5_cnn_lighting import FlanT5NestCNNAutoencoder, T5AutoConfig 6 | 7 | from sentence_transformers import SentenceTransformer 8 | 9 | from torch.utils.data import DataLoader, TensorDataset 10 | 11 | from transformers import AutoTokenizer 12 | 13 | 14 | 15 | class Encode_Sentence(FlanT5NestCNNAutoencoder): 16 | def 
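For orientation, here is a hypothetical usage sketch of the TextProcessor class above. The documents, cache name, and loosened frequency thresholds are invented for illustration; process() additionally needs the fine-tuned Flan-T5 checkpoints whose paths are hardcoded in embedding_generation, so this sketch exercises only the lemmatization and bag-of-words steps.

from preprocess import TextProcessor

if __name__ == "__main__":  # guard needed: lemmatize_sentences spawns a Pool
    docs = [
        "Wall St. bears claw back into the black as short sellers see green.",
        "Retail money market fund assets fell in the latest week.",
    ] * 20

    processor = TextProcessor(data=docs, name="demo_corpus")
    processor.lemmatize_sentences()  # parallel lemmatization across CPU cores
    # The defaults (min_freq=10, max_freq_ratio=0.05) assume a large corpus;
    # they are loosened here so the tiny demo corpus keeps a vocabulary.
    processor.convert_to_bag_of_words(processor.lemmatized_sentences,
                                      min_freq=1, max_freq_ratio=0.9)
    print(processor.word_to_index)  # vocabulary kept by the frequency filter
    print(processor.bow[0])         # term counts for the first document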
--------------------------------------------------------------------------------
/src/topicmodeling/sentence_encoder.py:
--------------------------------------------------------------------------------
import numpy as np
import torch

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from flanT5_cnn_lighting import FlanT5NestCNNAutoencoder, T5AutoConfig

from sentence_transformers import SentenceTransformer

from torch.utils.data import DataLoader, TensorDataset


class Encode_Sentence(FlanT5NestCNNAutoencoder):
    def __init__(self, config, pretrain_model_token, pretrain_model, *args, **kwargs):
        super().__init__(config)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrain_model_token)
        self.load_state_dict(torch.load(pretrain_model))
        self.max_length = kwargs.get('max_length', 256)
        self.device_name = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device_name)

    def encode(self, sentences, batch_size=32, **kwargs):
        """Encode a list of sentences into flattened latent embeddings."""
        # Tokenize all sentences and create a TensorDataset
        sent_token = self.tokenizer(sentences, truncation=True, padding="max_length",
                                    max_length=self.max_length, return_tensors="pt")
        dataset = TensorDataset(sent_token.input_ids, sent_token.attention_mask)

        # Create a DataLoader for efficient batch loading
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

        embeddings = []
        for input_ids, attention_mask in loader:
            # Use the device_name attribute set in __init__; nn.Module has no
            # built-in `device` attribute, which the previous code assumed.
            input_ids = input_ids.to(self.device_name)
            attention_mask = attention_mask.to(self.device_name)

            with torch.no_grad():
                # The T5 encoder's last hidden state feeds the CNN encoder head
                output = self.model.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
                output = self.encoder(output)

            embeddings.append(output.cpu())

        # Concatenate all batches and flatten each sentence's token grid into
        # a single fixed-size vector
        embeddings = torch.cat(embeddings, dim=0)
        embeddings = embeddings.reshape(embeddings.shape[0], -1)
        return embeddings
--------------------------------------------------------------------------------
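How the encoder is meant to be instantiated can be read off embedding_generation in preprocess.py above; the following sketch restates it with placeholder checkpoint paths, since the real paths are user-specific and not shipped with the repository.

import torch
from flanT5_cnn_lighting import T5AutoConfig
from sentence_encoder import Encode_Sentence

config = T5AutoConfig(hidden_size1=512, hidden_size3=4,
                      hidden_size2=768, output_size=4 * 768,
                      model='google/flan-t5-large')

# Placeholder paths: a saved state_dict file and a tokenizer directory.
model = Encode_Sentence(config,
                        pretrain_model='path/to/finetuned_state_dict.pt',
                        pretrain_model_token='path/to/tokenizer_dir',
                        max_length=512)

embeddings = model.encode(["A short sentence to embed.",
                           "Another sentence."], batch_size=2)
print(embeddings.shape)  # (num_sentences, flattened latent dimension)

Each sentence comes back as one fixed-size vector because encode flattens the CNN-compressed token representation; this is the shape that preprocess.py stores in TextProcessor.embeddings.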
--------------------------------------------------------------------------------
/src/topicmodeling/utils.py:
--------------------------------------------------------------------------------
import torch

def kld_normal(mu, log_sigma):
    """KL divergence from N(mu, sigma^2) to the standard normal distribution.
    mu: batch_size x dim
    log_sigma: batch_size x dim
    """
    # Closed-form KL between two Gaussians, summed over the latent dimensions:
    # KL = -0.5 * sum(1 + 2*log_sigma - mu^2 - exp(2*log_sigma))
    # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
    return -0.5 * (1 - mu ** 2 + 2 * log_sigma - torch.exp(2 * log_sigma)).sum(dim=-1)
--------------------------------------------------------------------------------
/src/topicmodeling/vontss.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/text_generation_diffusion_llm_topic/501b981b8987f13ce73a1e24627caeeb790813b6/src/topicmodeling/vontss.py
--------------------------------------------------------------------------------
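As a quick sanity check of kld_normal in src/topicmodeling/utils.py: the closed form KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + 2*log_sigma - mu^2 - exp(2*log_sigma)) is zero exactly when mu = 0 and log_sigma = 0, and positive otherwise. A minimal numeric sketch, assuming it is run from src/topicmodeling so the import resolves:

import torch
from utils import kld_normal

# A posterior equal to the prior N(0, I) should give zero KL for every row.
mu = torch.zeros(2, 3)         # batch_size x dim
log_sigma = torch.zeros(2, 3)
print(kld_normal(mu, log_sigma))  # tensor([0., 0.])

# Shifting one mean coordinate to 1.0 adds 0.5 * mu^2 = 0.5 to the KL.
mu = torch.tensor([[1.0, 0.0, 0.0]])
print(kld_normal(mu, torch.zeros(1, 3)))  # tensor([0.5000])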