├── .gitignore ├── API ├── api_service_main.py ├── server_start.bat └── server_start.sh ├── LICENSE ├── README.md ├── examples ├── Annotation_Visualization.ipynb ├── Embd_Interp.ipynb ├── FET_Projection.ipynb ├── GUI_TEST.ipynb ├── KNN_RAG.ipynb ├── KNN_RAG_interp.ipynb ├── Metric_calculation_test.ipynb ├── RL_Simalation_With_Interpretation_Ebeddings.ipynb ├── SBertDistilTest.ipynb ├── SBertSVDDistil.ipynb └── example_data │ ├── clusters.json │ └── metrix.csv ├── explainitall ├── QA │ ├── extractive_qa_sbert │ │ ├── QABotsBase.py │ │ └── SVDBert.py │ └── interp_qa │ │ └── KNNWithGenerative.py ├── __init__.py ├── clusters.py ├── embedder_interp │ └── embd_interpret.py ├── fast_tuning │ ├── Embedder.py │ ├── ExpertBase.py │ ├── SimpleModelCreator.py │ ├── generators │ │ ├── GeneratorWithExpert.py │ │ ├── MCExpert.py │ │ ├── MCModel.py │ │ └── SimpleGenerator.py │ └── trainers │ │ ├── DenceKerasTrainer.py │ │ ├── HMMTrainer.py │ │ └── ProjectionTrainer.py ├── gpt_like_interp │ ├── __init__.py │ ├── downloader.py │ ├── inseq_helpers.py │ ├── interp.py │ └── viz.py ├── gui │ ├── df_to_heatmap_plot.py │ └── interface.py ├── metrics │ ├── CheckingForHallucinations.py │ └── RougeAndPPL │ │ ├── Metrics.py │ │ ├── Metrics_calculator.py │ │ ├── create_database.py │ │ ├── helpers.py │ │ ├── metric_calculation_interface.py │ │ ├── rouge_L.py │ │ └── rouge_N.py ├── nlp.py └── stat_helpers.py ├── main.py ├── pytest.ini ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── test_inseq_helpers.py └── test_stat_helpers.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | 162 | /exclude_* 163 | /examples/dataset 164 | /examples/database.sqlite 165 | *cache* 166 | /examples/trained_model 167 | /examples/s 168 | /examples/new_gpt 169 | -------------------------------------------------------------------------------- /API/api_service_main.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from typing import List 4 | 5 | import gensim 6 | import numpy as np 7 | from fastapi import FastAPI 8 | from inseq import load_model 9 | from pydantic import BaseModel, constr 10 | from sentence_transformers import SentenceTransformer 11 | from sklearn.neighbors import KNeighborsClassifier 12 | from starlette.responses import RedirectResponse 13 | 14 | from explainitall.QA.extractive_qa_sbert.QABotsBase import cos_dist 15 | from explainitall.QA.extractive_qa_sbert.SVDBert import SVDBertModel 16 | from explainitall.QA.interp_qa.KNNWithGenerative import FredStruct, PromptBot 17 | from explainitall.gpt_like_interp import interp 18 | from explainitall.gpt_like_interp.downloader import DownloadManager 19 | 20 | 21 | class GetAnswerItem(BaseModel): 22 | question: constr(min_length=1) 23 | top_k: int 24 | 25 | 26 | class LoadDatasetItem(BaseModel): 27 | questions: List[str] 28 | answers: List[str] 29 | 30 | 31 | class ClusterItem(BaseModel): 32 | name: constr(min_length=1) 33 | centroid: List[str] 34 | top_k: int 35 | 36 | 37 | class EvaluationItem(BaseModel): 38 | nlp_model_path: constr(min_length=1) 39 | nn_model_path: constr(min_length=1) 40 | clusters: List[ClusterItem] 41 | prompt: constr(min_length=1) 42 | generated_text: constr(min_length=1) 43 | 44 | 45 | class ApiServerInit: 46 | def __init__(self, sbert_path='FractalGPT/SbertSVDDistil', device='cpu'): 47 | self.clusters_r = None 48 | self.bot = None 49 | self.prompt = None 50 | self.generated_text = None 51 | self.explainer = None 52 | self.clusters = None 53 | self.device = device 54 | self.app = FastAPI() 55 | self.sbert = SentenceTransformer(sbert_path) 56 | self.sbert[0].auto_model = SVDBertModel.from_pretrained(sbert_path) 57 | # self.fred = FredStruct('SiberiaSoft/SiberianFredT5-instructor') 58 | 59 | if os.getenv('TEST_MODE_ON_LOW_SPEC_PC') == 'True': 60 | self.fred = FredStruct('ai-forever/FRED-T5-large') 61 | else: 62 | self.fred = FredStruct('FractalGPT/FRED-T5-Interp') 63 | 64 | def load_dataset(self, questions, answers): 65 | self.__init_knn__(questions, answers) 66 | self.bot = PromptBot(self.knn, self.sbert, self.fred, answers, device=self.device) 67 | return True 68 | 69 | @staticmethod 70 | def df_to_dict(data_frame): 71 | df_copy = data_frame.copy(deep=True) 72 | 73 | def make_columns_unique(df): 74 | new_columns = {} 75 | for column in df.columns: 76 | if column in new_columns: 77 | new_columns[column] += 1 78 | new_name = f"{column}_{new_columns[column]}" 79 | else: 80 | new_columns[column] = 0 81 | new_name = column 82 | yield new_name 83 | 84 | df_copy.columns = list(make_columns_unique(df_copy)) 85 | return df_copy.replace([np.nan, np.inf, -np.inf], ["nan", "inf", "-inf"]).to_dict(orient="split") 86 | 87 | def evaluate(self, nlp_model_path, nn_model_path, clusters, prompt, generated_text): 88 | self.clusters_r = [{ 89 | "name": cluster.name, 90 | "centroid": cluster.centroid, 91 | "top_k": cluster.top_k 92 | } for cluster in clusters] 93 | self.prompt = prompt 94 | self.generated_text = generated_text 95 | self.__load_nlp_model__(nlp_model_path) 96 | self.__load_nn_model__(nn_model_path) 97 | self.explainer = 
interp.ExplainerGPT2(gpt_model=self.nn_model, nlp_model=self.nlp_model) 98 | expl_data = self.explainer.interpret( 99 | input_texts=self.prompt, 100 | generated_texts=self.generated_text, 101 | clusters_description=self.clusters_r, 102 | batch_size=50, 103 | steps=34, 104 | # max_new_tokens=19 105 | ) 106 | return {"word_importance_map": self.df_to_dict(expl_data.word_imp_df), 107 | "word_importance_map_normalized": self.df_to_dict(expl_data.word_imp_norm_df), 108 | "cluster_importance_map": self.df_to_dict(expl_data.cluster_imp_df), 109 | "cluster_importance_map_normalized": self.df_to_dict(expl_data.cluster_imp_aggr_df)} 110 | 111 | def get_answer(self, q, top_k): 112 | return self.bot.get_answers(q, top_k=top_k) 113 | 114 | def __init_knn__(self, questions, answers): 115 | vects_questions = self.sbert.encode(questions) 116 | m = vects_questions.mean(axis=0) 117 | s = vects_questions.std(axis=0) 118 | knn_vects_questions = (vects_questions - m) / s 119 | 120 | self.knn = KNeighborsClassifier(metric=cos_dist) 121 | self.knn.fit(knn_vects_questions, answers) 122 | 123 | def __load_nlp_model__(self, url): 124 | self.nlp_model_url = url 125 | nlp_model_path = DownloadManager.load_zip(url) 126 | self.nlp_model = gensim.models.KeyedVectors.load_word2vec_format(nlp_model_path, binary=True) 127 | return True 128 | 129 | def __load_nn_model__(self, model_name_or_path): 130 | path = os.path.normpath(model_name_or_path) 131 | path_list = path.split(os.sep) 132 | self.nn_model_name = path_list[-1] 133 | self.nn_model = load_model(model=model_name_or_path, attribution_method="integrated_gradients") 134 | return True 135 | 136 | 137 | api_server_init = ApiServerInit() 138 | app = api_server_init.app 139 | 140 | 141 | @app.get("/") 142 | async def redirect_to_docs(): 143 | return RedirectResponse(url="/docs") 144 | 145 | 146 | @app.post( 147 | "/load_dataset", 148 | summary="Load a dataset for the Q&A model", 149 | response_description="Indicates success of dataset loading", 150 | openapi_extra={ 151 | "requestBody": { 152 | "content": { 153 | "application/json": { 154 | "examples": { 155 | "example1": { 156 | "summary": "Загрузка базового набора данных вопросов и ответов", 157 | "value": { 158 | "questions": ["Что такое коала?", 159 | "Опишите африканского слона"], 160 | 161 | "answers": ["Это вид медведей, обитающих в Австралии.", 162 | "Это крупное млекопитающее с длинным хоботом."] 163 | } 164 | } 165 | } 166 | } 167 | } 168 | } 169 | } 170 | ) 171 | async def load_dataset(item: LoadDatasetItem): 172 | """Загружает набор данных, состоящий из вопросов и ответов о животных, в модель вопросов и ответов""" 173 | result = await asyncio.to_thread(api_server_init.load_dataset, item.questions, item.answers) 174 | return {"result": result} 175 | 176 | 177 | @app.post( 178 | "/get_answer", 179 | summary="Получить ответы на вопросы", 180 | response_description="Полученный(е) ответ(ы) на указанные вопросы", 181 | openapi_extra={ 182 | "requestBody": { 183 | "content": { 184 | "application/json": { 185 | "examples": { 186 | "example1": { 187 | "summary": "Получить ответ на вопрос о животном", 188 | "value": { 189 | "question": "Что за животное коала?", 190 | "top_k": 1 191 | } 192 | } 193 | } 194 | } 195 | } 196 | } 197 | } 198 | ) 199 | async def get_answer(item: GetAnswerItem): 200 | """ 201 | Получает ответы на указанный вопрос с использованием загруженной модели вопросов и ответов 202 | """ 203 | result = await asyncio.to_thread(api_server_init.get_answer, item.question, item.top_k) 204 | return {"result": 
result} 205 | 206 | 207 | @app.post( 208 | "/evaluate", 209 | summary="Оценить сгенерированный текст с использованием модели", 210 | response_description="Результаты оценки", 211 | openapi_extra={ 212 | "requestBody": { 213 | "content": { 214 | "application/json": { 215 | "examples": { 216 | "example1": { 217 | "summary": "Оценить интерпретируемость текста", 218 | "value": { 219 | "nlp_model_path": 'http://vectors.nlpl.eu/repository/20/180.zip', 220 | "nn_model_path": "ai-forever/rugpt3small_based_on_gpt2", 221 | "clusters": [ 222 | {'name': 'Животные', 'centroid': ['собака', 'кошка', 'заяц'], 'top_k': 140}, 223 | {'name': 'Лекарства', 'centroid': ['уколы', 'таблетки', 'противовирусное'], 224 | 'top_k': 160}, 225 | {'name': 'Болезни', 'centroid': ['простуда', 'орви', 'орз', 'грипп'], 'top_k': 20}, 226 | {'name': 'Симптомы', 'centroid': ['температура', 'насморк'], 'top_k': 20} 227 | ], 228 | "prompt": "я думаю что у моей кошки простуда, у нее температура, постоянный кашель: чем мне лечить мою кошку? ответ:", 229 | "generated_text": "На сегодняшний день существует специальное противовирусное лечение для кошек, так же можно применять антибиотики" 230 | } 231 | } 232 | } 233 | } 234 | } 235 | } 236 | } 237 | ) 238 | async def evaluate(item: EvaluationItem): 239 | """Оценивает интерпретируемость сгенерированного текста относительно входного подсказки и предоставленных кластеров""" 240 | data = await asyncio.to_thread(api_server_init.evaluate, item.nlp_model_path, item.nn_model_path, item.clusters, 241 | item.prompt, item.generated_text) 242 | return data 243 | 244 | 245 | if __name__ == "__main__": 246 | import uvicorn 247 | 248 | uvicorn.run(app, host="localhost", port=8000) 249 | -------------------------------------------------------------------------------- /API/server_start.bat: -------------------------------------------------------------------------------- 1 | title API_EXPLAINITALL_PYTHON 2 | uvicorn api_service_main:app --port 8000 --reload -------------------------------------------------------------------------------- /API/server_start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "API_EXPLAINITALL_PYTHON" 3 | uvicorn api_service_main:app --port 8000 --reload -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![Downloads](https://img.shields.io/pypi/dm/explainitall.svg)](https://pypi.org/project/explainitall/)
2 | 
3 | # **ExplainitAll**
4 | 
5 | **ExplainitAll** is a library for interpretable AI, designed for interpreting generative (
6 | GPT-like) models and vectorizers such as Sbert. The library gives users tools for analyzing and understanding
7 | how these complex models work. It also includes RAG QA and fast_tuning modules and a graphical user interface.
8 | 
9 | ---
10 | 
11 | * [Usage examples](https://github.com/Bots-Avatar/ExplainitAll/tree/main/examples)
12 | * [Library source code](https://github.com/Bots-Avatar/ExplainitAll/tree/main/explainitall)
13 | * [Documentation](https://github.com/Bots-Avatar/ExplainitAll/wiki)
14 | 
15 | ## Models:
16 | 
17 | * Distilled [Sbert](https://huggingface.co/FractalGPT/SbertDistil)
18 | * Distilled [Sbert](https://huggingface.co/FractalGPT/SbertSVDDistil) with SVD decomposition, to speed up
19 | inference and training
20 | * [FRED T5](https://huggingface.co/FractalGPT/FRED-T5-Interp), fine-tuned for the RAG task of answering questions
21 | about the interpretation of a generative GPT-like network.
22 | * [FRED T5](https://huggingface.co/FractalGPT/FredT5-Large-Instruct-Context), a small T5 trained for context-aware instruct tasks
23 | ---
24 | 
25 | ## Application areas:
26 | 
27 | The results can be applied in any question-answering systems or classifiers for critical
28 | industries (medicine, construction, aerospace, law, etc.). A typical usage scenario, in medicine for example, looks like this:
29 | the developer of an end product, such as a system that searches for drug contraindications, works closely with
30 | the customer (a doctor, a clinic, etc.) to create a set of domain clusters and fine-tunes a transformer model (
31 | GPT-like: e.g., the ruGPT3 and GPT2 families) on question-answer texts. Then, once this
32 | ready-made question-answering system is in production, the developer plugs in the ExplainitAll library so that it provides
33 | an analytical assessment of how "reliable" and trustworthy the system's answers are, based on the interpretation result:
34 | did the system, when answering the user's questions, actually attend to the clusters that matter for the industry.
35 | 
36 | The library can be adapted as a module of an end product: an assistant for a doctor,
37 | a design engineer, a lawyer, or an accountant. For the public sector the library is useful because it helps
38 | build trust in RAG systems answering questions about taxes, procurement regulations, user manuals
39 | of information systems, and regulatory legal acts. For industrial enterprises the library is applicable to
40 | work with regulations and operation and maintenance manuals for complex technical equipment, since it makes it possible
41 | to assess whether the answers of QA systems take into account special, industry-critical abbreviations, names, acronyms,
42 | and nomenclature designations.
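43 | 
44 | ## Quick example
45 | 
46 | A minimal end-to-end sketch assembled from the `examples/` notebooks and `API/api_service_main.py`; the model name, embedding archive URL, and cluster values below are the ones used in those examples and can be swapped for your own:
47 | 
48 | ```python
49 | import gensim
50 | from inseq import load_model
51 | 
52 | from explainitall.gpt_like_interp import interp
53 | from explainitall.gpt_like_interp.downloader import DownloadManager
54 | 
55 | # word2vec embeddings (RusVectores archive used in the examples)
56 | nlp_model_path = DownloadManager.load_zip('http://vectors.nlpl.eu/repository/20/180.zip')
57 | nlp_model = gensim.models.KeyedVectors.load_word2vec_format(nlp_model_path, binary=True)
58 | 
59 | # GPT-like model wrapped with an attribution method
60 | gpt_model = load_model("sberbank-ai/rugpt3small_based_on_gpt2", "integrated_gradients")
61 | 
62 | explainer = interp.ExplainerGPT2(gpt_model=gpt_model, nlp_model=nlp_model)
63 | 
64 | clusters = [
65 |     {'name': 'Животные', 'centroid': ['собака', 'кошка', 'заяц'], 'top_k': 140},
66 |     {'name': 'Болезни', 'centroid': ['простуда', 'орви', 'орз', 'грипп'], 'top_k': 20},
67 | ]
68 | 
69 | expl_data = explainer.interpret(
70 |     input_texts="чем лечить кошку? ответ:",
71 |     generated_texts="лечите ее уколами",
72 |     clusters_description=clusters,
73 |     batch_size=50,
74 |     steps=34,
75 | )
76 | print(expl_data.cluster_imp_df)  # cluster importance table
77 | ```
78 | 
79 | The same pipeline is also exposed as a FastAPI service (`API/api_service_main.py`, endpoints `/load_dataset`, `/get_answer`, `/evaluate`), started with `server_start.sh` or `server_start.bat`.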
80 | 
81 | ## Specifications:
82 | 
83 | * Operating system: Ubuntu 22.04.3 LTS
84 | * Driver: NVIDIA version 535.104.05
85 | * CUDA version 12.2
86 | * Python 3.10.12
87 | 
88 | 
89 | * CPU: AMD Ryzen 3 3200G OEM (clock: 3600 MHz, cores: 4)
90 | * RAM: 16 GB
91 | 
92 | * GPU
93 | * Model: nVidia TU104GL [Tesla T4]
94 | * VRAM: 16 GB
95 | 
96 | ## License
97 | 
98 | * License: [Apache-2.0 license](https://github.com/Bots-Avatar/ExplainitAll/tree/main#Apache-2.0-1-ov-file)
--------------------------------------------------------------------------------
/examples/GUI_TEST.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "id": "26b6dccfbfdf3023",
7 |    "metadata": {
8 |     "ExecuteTime": {
9 |      "end_time": "2024-03-31T23:35:27.783149Z",
10 |      "start_time": "2024-03-31T23:35:27.780481Z"
11 |     },
12 |     "collapsed": false,
13 |     "jupyter": {
14 |      "outputs_hidden": false
15 |     }
16 |    },
17 |    "outputs": [
18 |     {
19 |      "name": "stdout",
20 |      "output_type": "stream",
21 |      "text": [
22 |       "\u001b[33mWARNING: typer 0.12.0 does not provide the extra 'all'\u001b[0m\u001b[33m\n",
23 |       "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
24 |       "\u001b[0m"
25 |      ]
26 |     }
27 |    ],
28 |    "source": [
29 |     "!pip install git+https://github.com/Bots-Avatar/ExplainitAll -q"
30 |    ]
31 |   },
32 |   {
33 |    "cell_type": "code",
34 |    "execution_count": 2,
35 |    "id": "eEJyA-XRmLUx",
36 |    "metadata": {
37 |     "ExecuteTime": {
38 |      "end_time": "2024-04-05T18:22:43.620616Z",
39 |      "start_time": "2024-04-05T18:22:43.613855Z"
40 |     },
41 |     "id": "eEJyA-XRmLUx"
42 |    },
43 |    "outputs": [],
44 |    "source": [
45 |     "import os\n",
46 |     "# os.environ['TEST_MODE_ON_LOW_SPEC_PC'] = 'True'"
47 |    ]
48 |   },
49 |   {
50 |    "cell_type": "code",
51 |    "execution_count": 3,
52 |    "id": "eb3267fe",
53 |    "metadata": {
54 |     "ExecuteTime": {
55 |      "end_time": "2024-04-05T18:22:48.980250Z",
56 |      "start_time": "2024-04-05T18:22:45.467270Z"
57 |     },
58 |     "id": "eb3267fe"
59 |    },
60 |    "outputs": [
61 |     {
62 |      "name": "stderr",
63 |      "output_type": "stream",
64 |      "text": [
65 |       "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
66 |       "[nltk_data]   Package punkt is already up-to-date!\n"
67 |      ]
68 |     }
69 |    ],
70 |    "source": [
71 |     "from explainitall.gui.interface import DemoInterface\n",
72 |     "from explainitall.gui.interface import set_verbosity_error\n",
73 |     "set_verbosity_error()"
74 |    ]
75 |   },
76 |   {
77 |    "cell_type": "code",
78 |    "execution_count": 4,
79 |    "id": "a63b1664",
80 |    "metadata": {
81 |     "ExecuteTime": {
82 |      "end_time": "2024-04-05T18:22:59.430666Z",
83 |      "start_time": "2024-04-05T18:22:50.245045Z"
84 |     },
85 |     "id": "a63b1664"
86 |    },
87 |    "outputs": [
88 |     {
89 |      "data": {
90 |       "application/vnd.jupyter.widget-view+json": {
91 |        "model_id": "bb93946321c54dcb881a78a165fb0457",
92 |        "version_major": 2,
93 |        "version_minor": 0
94 |       },
95 |       "text/plain": [
96 |        "modules.json:   0%|          | 0.00/341 [00:00" 460 | ],
461 |       "text/plain": [
462 |        ""
463 |       ]
464 |      },
465 |      "metadata": {},
466 |      "output_type": "display_data"
467 |     }
468 |    ],
469 |    "source": [
470 |     "interface.launch()"
471 |    ]
472 |   },
473 |   {
474 |    "cell_type": "markdown",
475 |    "id": "b6sQtIG7rK7t",
476 |    "metadata": { 
477 |     "id": "b6sQtIG7rK7t"
478 |    },
479 |    "source": [
480 |     "### The same can be done in code (example below)\n",
481 |     ""
482 |    ]
483 |   },
484 |   {
485 |    "cell_type": "code",
486 |    "execution_count": 9,
487 |    "id": "f1794464",
488 |    "metadata": {
489 |     "ExecuteTime": {
490 |      "end_time": "2024-03-31T23:37:08.072609Z",
491 |      "start_time": "2024-03-31T23:37:08.068446Z"
492 |     },
493 |     "id": "f1794464"
494 |    },
495 |    "outputs": [],
496 |    "source": [
497 |     "interface.context_ = 'у кошки грипп и аллергия на антибиотики вопрос: чем лечить кошку? ответ:'\n",
498 |     "interface.generated_text_ = 'лечите ее уколами'\n",
499 |     "\n",
500 |     "interface.clusters_ = [\n",
501 |     "    {'name': 'Животные', 'centroid': ['собака', 'кошка', 'заяц'], 'top_k': 140},\n",
502 |     "    {'name': 'Лекарства', 'centroid': ['уколы', 'таблетки', 'микстуры'], 'top_k': 160},\n",
503 |     "    {'name': 'Болезни', 'centroid': ['простуда', 'орви', 'орз', 'грипп'], 'top_k': 20},\n",
504 |     "    {'name': 'Аллергия', 'centroid': ['аллергия'], 'top_k': 20}\n",
505 |     "]"
506 |    ]
507 |   },
508 |   {
509 |    "cell_type": "code",
510 |    "execution_count": 10,
511 |    "id": "1d12205d",
512 |    "metadata": {
513 |     "ExecuteTime": {
514 |      "end_time": "2024-03-31T23:37:09.287562Z",
515 |      "start_time": "2024-03-31T23:37:08.409795Z"
516 |     },
517 |     "id": "1d12205d"
518 |    },
519 |    "outputs": [
520 |     {
521 |      "data": {
522 |       "text/plain": [
523 |        "(True, True)"
524 |       ]
525 |      },
526 |      "execution_count": 10,
527 |      "metadata": {},
528 |      "output_type": "execute_result"
529 |     }
530 |    ],
531 |    "source": [
532 |     "interface.load_nlp_model_('http://vectors.nlpl.eu/repository/20/180.zip')"
533 |    ]
534 |   },
535 |   {
536 |    "cell_type": "code",
537 |    "execution_count": 11,
538 |    "id": "3fe69def",
539 |    "metadata": {
540 |     "ExecuteTime": {
541 |      "end_time": "2024-03-31T23:37:12.280943Z",
542 |      "start_time": "2024-03-31T23:37:09.288664Z"
543 |     },
544 |     "id": "3fe69def"
545 |    },
546 |    "outputs": [
547 |     {
548 |      "data": {
549 |       "text/plain": [
550 |        "(True, True)"
551 |       ]
552 |      },
553 |      "execution_count": 11,
554 |      "metadata": {},
555 |      "output_type": "execute_result"
556 |     }
557 |    ],
558 |    "source": [
559 |     "interface.load_nn_model_(\"sberbank-ai/rugpt3small_based_on_gpt2\")"
560 |    ]
561 |   }
562 |  ],
563 |  "metadata": {
564 |   "colab": {
565 |    "provenance": []
566 |   },
567 |   "kernelspec": {
568 |    "display_name": "Python 3 (ipykernel)",
569 |    "language": "python",
570 |    "name": "python3"
571 |   },
572 |   "language_info": {
573 |    "codemirror_mode": {
574 |     "name": "ipython",
575 |     "version": 3
576 |    },
577 |    "file_extension": ".py",
578 |    "mimetype": "text/x-python",
579 |    "name": "python",
580 |    "nbconvert_exporter": "python",
581 |    "pygments_lexer": "ipython3",
582 |    "version": "3.10.13"
583 |   }
584 |  },
585 |  "nbformat": 4,
586 |  "nbformat_minor": 5
587 | }
588 | 
--------------------------------------------------------------------------------
/examples/KNN_RAG_interp.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 2,
6 |    "metadata": {
7 |     "ExecuteTime": {
8 |      "end_time": "2024-03-31T23:37:41.442738Z",
9 |      "start_time": "2024-03-31T23:37:41.440452Z"
10 |     },
11 |     "collapsed": false,
12 |     "jupyter": {
13 |      "outputs_hidden": false
14 |     }
15 |    },
16 |    "outputs": [
17 |     {
18 |      "name": "stdout",
19 |      "output_type": "stream",
20 |      "text": [
21 |       "\u001b[33mWARNING: typer 0.12.0 does not provide the extra 'all'\u001b[0m\u001b[33m\n",
22 | 
"\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", 23 | "\u001b[0m" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "!pip install git+https://github.com/Bots-Avatar/ExplainitAll -q" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 10, 34 | "metadata": { 35 | "ExecuteTime": { 36 | "end_time": "2024-03-31T23:37:38.390677Z", 37 | "start_time": "2024-03-31T23:37:38.384997Z" 38 | }, 39 | "id": "aXqbhzKyh3I7" 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "import os\n", 44 | "# os.environ['TEST_MODE_ON_LOW_SPEC_PC'] = 'True'" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 1, 50 | "metadata": { 51 | "ExecuteTime": { 52 | "end_time": "2024-04-05T18:27:05.551805Z", 53 | "start_time": "2024-04-05T18:27:00.941860Z" 54 | }, 55 | "id": "aba3N6uRpmy0" 56 | }, 57 | "outputs": [ 58 | { 59 | "name": "stderr", 60 | "output_type": "stream", 61 | "text": [ 62 | "[nltk_data] Downloading package punkt to /root/nltk_data...\n", 63 | "[nltk_data] Package punkt is already up-to-date!\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "from explainitall.QA.interp_qa.KNNWithGenerative import FredStruct, PromptBot\n", 69 | "from explainitall.QA.extractive_qa_sbert.SVDBert import SVDBertModel\n", 70 | "from explainitall.QA.extractive_qa_sbert.QABotsBase import cos_dist\n", 71 | "from explainitall.gpt_like_interp.downloader import DownloadManager\n", 72 | "\n", 73 | "from sklearn.neighbors import KNeighborsClassifier\n", 74 | "from sentence_transformers import SentenceTransformer\n", 75 | "import gensim\n", 76 | "from inseq import load_model\n", 77 | "from explainitall.gpt_like_interp import interp\n", 78 | "from explainitall.gui.interface import set_verbosity_error\n", 79 | "set_verbosity_error()" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 2, 85 | "metadata": { 86 | "ExecuteTime": { 87 | "end_time": "2024-04-05T18:28:06.598120Z", 88 | "start_time": "2024-04-05T18:28:05.680831Z" 89 | }, 90 | "id": "IPrg310VtVNy" 91 | }, 92 | "outputs": [ 93 | { 94 | "name": "stderr", 95 | "output_type": "stream", 96 | "text": [ 97 | "Downloading: /root/.cache/180_zip: 100%|██████████| 462M/462M [00:26<00:00, 18.0MiB/s] \n", 98 | "Extracting: /root/.cache/180_zip_data: 100%|██████████| 4/4 [00:05<00:00, 1.26s/it]\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "def load_nlp_model(nlp_model_url):\n", 104 | " nlp_model_path = DownloadManager.load_zip(nlp_model_url)\n", 105 | " return gensim.models.KeyedVectors.load_word2vec_format(nlp_model_path, binary=True)\n", 106 | "\n", 107 | "# 'ID': 180\n", 108 | "# 'Размер вектора': 300\n", 109 | "# 'Корпус': 'Russian National Corpus'\n", 110 | "# 'Размер словаря': 189193\n", 111 | "# 'Алгоритм': 'Gensim Continuous Bag-of-Words'\n", 112 | "# 'Лемматизация': True\n", 113 | "\n", 114 | "nlp_model = load_nlp_model ('http://vectors.nlpl.eu/repository/20/180.zip')" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 4, 120 | "metadata": { 121 | "ExecuteTime": { 122 | "end_time": "2024-04-05T18:28:07.239780Z", 123 | "start_time": "2024-04-05T18:28:07.237071Z" 124 | }, 125 | "id": "Z_WDkdpKtcd8" 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "model_path = \"sberbank-ai/rugpt3small_based_on_gpt2\"" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 5, 135 | 
"metadata": { 136 | "ExecuteTime": { 137 | "end_time": "2024-04-05T18:28:12.419667Z", 138 | "start_time": "2024-04-05T18:28:08.710629Z" 139 | }, 140 | "id": "E7jniFg2thv4" 141 | }, 142 | "outputs": [ 143 | { 144 | "data": { 145 | "application/vnd.jupyter.widget-view+json": { 146 | "model_id": "652f3a84abc04530b5e87140d7366f31", 147 | "version_major": 2, 148 | "version_minor": 0 149 | }, 150 | "text/plain": [ 151 | "config.json: 0%| | 0.00/720 [00:00" 104 | ], 105 | "text/plain": [ 106 | "" 107 | ] 108 | }, 109 | "metadata": {}, 110 | "output_type": "display_data" 111 | } 112 | ], 113 | "source": [ 114 | "metric_calculation_interface.launch()" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 8, 120 | "id": "d6c66fd6", 121 | "metadata": { 122 | "ExecuteTime": { 123 | "end_time": "2024-03-31T23:39:37.429481Z", 124 | "start_time": "2024-03-31T23:39:35.850815Z" 125 | } 126 | }, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "application/vnd.jupyter.widget-view+json": { 131 | "model_id": "2ce462cc87e240c5b51d09f4cb40108d", 132 | "version_major": 2, 133 | "version_minor": 0 134 | }, 135 | "text/plain": [ 136 | "tokenizer_config.json: 0%| | 0.00/26.0 [00:00= confidence: 171 | return answer['answer'] 172 | else: 173 | return '' 174 | 175 | def get_prompt(self, q, confidence=0.3, top_k_search=7): 176 | text = self.retr.get_answers(q, top_k_search) 177 | return self.search(text, q, confidence) 178 | 179 | 180 | class SimpleTransformer(): 181 | """ Класс-заглушка """ 182 | 183 | def __init__(self): 184 | pass 185 | 186 | def transform(self, vects): 187 | return vects 188 | 189 | 190 | def cos(x, y): 191 | return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)) 192 | 193 | 194 | def cos_dist(x, y): 195 | return -cos(x, y) 196 | 197 | 198 | class KnnBot: 199 | """ 200 | Поисковый бот на базе модели векторизации и метода ближайших соседей 201 | с установкой максимального радиуса для детекции аномалий. 202 | """ 203 | 204 | def __init__(self, knn=None, sbert=None, mean=None, std=None, vect_transformer=None, dim=None, n_neighbors=3, eps=1e-200): 205 | self.knn = knn 206 | self.model = sbert 207 | self.mean = np.zeros((dim,)) if mean is None else mean 208 | self.std = np.ones((dim,)) if std is None else std + eps 209 | self.vect_transformer = SimpleTransformer() if vect_transformer is None else vect_transformer 210 | 211 | self.r_char = re.compile('[^A-zА-яЁё0-9": ]') 212 | self.r_spaces = re.compile(r"\s+") 213 | 214 | def clean_string(self, text): 215 | """ 216 | Очистка и нормализация строки. 217 | """ 218 | seq = self.r_char.sub(' ', text.replace('\n', ' ')) 219 | seq = self.r_spaces.sub(' ', seq).strip() 220 | return seq.lower() 221 | 222 | def get_vect(self, q): 223 | """ 224 | Преобразование текста в вектор и его нормализация. 225 | """ 226 | vect_q = self.model.encode(q, convert_to_tensor=False) 227 | vect_q = self.vect_transformer.transform([vect_q])[0] 228 | return (vect_q - self.mean) / self.std 229 | 230 | def __get_answer_text(self, text_q, n_neighbors=1): 231 | """ 232 | Получение ответов на основе векторизованного текста. 233 | 234 | :param text_q: Векторизованный запрос. 235 | :param n_neighbors: Количество ближайших соседей для поиска. 236 | :return: Список ответов. 
237 | """ 238 | vect = self.get_vect(text_q) 239 | 240 | # Получаем индексы ближайших соседей 241 | distances, indices = self.knn.kneighbors([vect], n_neighbors=n_neighbors) 242 | 243 | # Возвращаем список ответов на основе индексов ближайших соседей 244 | return list(set([self.answers[idx] for idx in indices[0]])) 245 | 246 | def get_answer(self, q, n_neighbors=1): 247 | """ 248 | Получение ответа на входящий запрос. 249 | 250 | :param q: Вопрос. 251 | :param n_neighbors: Количество ближайших соседей, которые нужно вернуть. 252 | :return: Список ответов. 253 | """ 254 | text_q = self.clean_string(q) 255 | return self.__get_answer_text(text_q, n_neighbors) 256 | 257 | def train(self, csv_path, embedder, knn_neighbors=5): 258 | """ 259 | Метод для обучения бота на основе CSV с вопросами и ответами с расчетом среднего и стандартного отклонения. 260 | 261 | :param csv_path: Путь к CSV файлу с вопросами и ответами. 262 | :param embedder: Модель эмбеддинга для преобразования текста в векторы. 263 | :param knn_neighbors: Количество соседей для метода KNN. 264 | """ 265 | # Шаг 1: Загрузка данных из CSV 266 | data = pd.read_csv(csv_path) 267 | if 'question' not in data.columns or 'answer' not in data.columns: 268 | raise ValueError("CSV файл должен содержать колонки 'question' и 'answer'") 269 | 270 | questions = data['question'].tolist() 271 | answers = data['answer'].tolist() 272 | 273 | # Шаг 2: Преобразование вопросов в векторы 274 | question_vectors = np.array([embedder.encode(q, convert_to_tensor=False) for q in questions]) 275 | 276 | # Шаг 3: Вычисление среднего и стандартного отклонения по векторам вопросов 277 | self.mean = np.mean(question_vectors, axis=0) 278 | self.std = np.std(question_vectors, axis=0) 279 | 280 | # Чтобы избежать деления на ноль, добавляем небольшое значение eps 281 | eps = 1e-10 282 | self.std += eps 283 | 284 | # Шаг 4: Нормализация векторов вопросов 285 | normalized_vectors = (question_vectors - self.mean) / self.std 286 | 287 | # Шаг 5: Инициализация и обучение KNN 288 | self.knn = NearestNeighbors(n_neighbors=knn_neighbors, metric='cosine') 289 | self.knn.fit(normalized_vectors) 290 | 291 | # Сохранение ответов 292 | self.answers = answers 293 | 294 | def get_normalized_vector(self, question): 295 | """ 296 | Получение нормализованного вектора вопроса на основе сохраненных среднего и std. 297 | 298 | :param question: Вопрос для преобразования. 299 | :return: Нормализованный вектор. 300 | """ 301 | vect_q = self.model.encode(question, convert_to_tensor=False) 302 | return (vect_q - self.mean) / self.std 303 | -------------------------------------------------------------------------------- /explainitall/QA/extractive_qa_sbert/SVDBert.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import BertModel 3 | 4 | 5 | class SVDLinearLayer(nn.Module): 6 | """ 7 | Линейный слой, использующий пространство уменьшенной размерности, использующий SVD. 8 | """ 9 | 10 | def __init__(self, in_features, out_features, h_dim, bias=True): 11 | super(SVDLinearLayer, self).__init__() 12 | self.encoder = nn.Linear(in_features, h_dim, bias=False) 13 | self.decoder = nn.Linear(h_dim, out_features, bias=bias) 14 | 15 | def forward(self, x): 16 | x = self.encoder(x) 17 | x = self.decoder(x) 18 | return x 19 | 20 | 21 | class SVDBertModel(BertModel): 22 | """ 23 | Модель BERT с уменьшенной размерностью в определенных слоях с использованием подхода SVD. 
24 |     """
25 | 
26 |     def __init__(self, config, svd_dim=5):
27 |         super(SVDBertModel, self).__init__(config)
28 | 
29 |         for i, layer in enumerate(self.encoder.layer):
30 |             intermediate_size = layer.intermediate.dense.out_features
31 |             output_size = layer.output.dense.out_features
32 | 
33 |             if i > 0:
34 |                 layer.intermediate.dense = SVDLinearLayer(
35 |                     layer.intermediate.dense.in_features,
36 |                     intermediate_size,
37 |                     svd_dim
38 |                 )
39 |                 layer.output.dense = SVDLinearLayer(
40 |                     layer.output.dense.in_features,
41 |                     output_size,
42 |                     svd_dim
43 |                 )
44 |             else:
45 |                 layer.intermediate.dense = nn.Linear(
46 |                     layer.intermediate.dense.in_features,
47 |                     intermediate_size,
48 |                     bias=True
49 |                 )
50 |                 layer.output.dense = nn.Linear(
51 |                     layer.output.dense.in_features,
52 |                     output_size,
53 |                     bias=True
54 |                 )
55 | 
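56 | 
57 | # Набросок использования (предположение, повторяет API/api_service_main.py):
58 | # SVD-модель подставляется как backbone дистиллированного Sbert,
59 | # например FractalGPT/SbertSVDDistil:
60 | #
61 | #   from sentence_transformers import SentenceTransformer
62 | #   sbert = SentenceTransformer('FractalGPT/SbertSVDDistil')
63 | #   sbert[0].auto_model = SVDBertModel.from_pretrained('FractalGPT/SbertSVDDistil')
64 | #   vectors = sbert.encode(['пример текста'])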
--------------------------------------------------------------------------------
/explainitall/QA/interp_qa/KNNWithGenerative.py:
--------------------------------------------------------------------------------
1 | import re
2 | 
3 | import nltk
4 | import numpy as np
5 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6 | 
7 | from explainitall.QA.extractive_qa_sbert.QABotsBase import SimpleTransformer
8 | 
9 | nltk.download('punkt')
10 | 
11 | 
12 | class FredStruct:
13 | 
14 |     def __init__(self, path='Ponimash/FredInterpreter', device='cpu'):
15 |         self.tokenizer = AutoTokenizer.from_pretrained(path)
16 |         self.fred_model = AutoModelForSeq2SeqLM.from_pretrained(path)
17 |         self.fred_model = self.fred_model.to(device)
18 |         self.fred_model.eval()
19 | 
20 |     def get_model(self):
21 |         return self.fred_model
22 | 
23 |     def get_tokenizer(self):
24 |         return self.tokenizer
25 | 
26 | 
27 | 
28 | class PromptBot:
29 |     """
30 |     Поисковый бот на базе модели векторизации и метода ближ. соседа с установкой максимального радиуса для детекции аномалий
31 |     """
32 | 
33 |     def __init__(self, knn, sbert, fred, texts, max_words=50, mean=None, std=None, vect_transformer=None, dim=None,
34 |                  n_neighbors=3, eps=1e-200, device='cuda'):
35 |         self.max_words = max_words
36 |         self.knn = knn
37 |         self.knn.n_neighbors = n_neighbors
38 |         self.texts = texts
39 |         self.sbert = sbert
40 |         self.mean = np.zeros((dim,)) if mean is None else mean
41 |         self.std = np.ones((dim,)) if std is None else std + eps
42 |         self.vect_transformer = SimpleTransformer() if vect_transformer is None else vect_transformer
43 |         self.fred = fred
44 |         self.device = device
45 | 
46 |     def __clean_string(self, text):
47 |         """
48 |         Очистка строки
49 |         """
50 |         seq = text.replace('\n', ' ')
51 |         r_char = re.compile('[^A-zА-яЁё0-9": ]')
52 |         r_spaces = re.compile(r"\s+")
53 |         seq = r_char.sub(' ', seq)
54 |         seq = r_spaces.sub(' ', seq).strip()
55 |         return seq.lower()
56 | 
57 | 
58 | 
59 |     def __qa__(self, doc, q):
60 |         doc = doc.replace('\n', ' ')
61 |         q = q.replace('\n', ' ')
62 |         q_pr = f"Опираясь только на информацию: {doc}.\n Подробно и полно ответь на вопрос указав все детали и числовые значения: \"{q}\" "
63 |         data_inp = self.fred.get_tokenizer()(q_pr, return_tensors="pt").to(self.device)
64 |         return data_inp
65 | 
66 |     def __generate__(self, doc, q):
67 |         t = self.__qa__(doc, q)
68 |         output_ids = self.fred.get_model().generate(
69 |             **t, do_sample=True, temperature=0.2, max_new_tokens=256, top_p=0.95, top_k=15, repetition_penalty=1.03,
70 |             no_repeat_ngram_size=3
71 |         )[0]
72 |         out = self.fred.get_tokenizer().decode(output_ids.tolist(), skip_special_tokens=True)
73 |         return out.replace("<extra_id_0>", "")  # убираем стартовый токен декодера FRED-T5
74 | 
75 |     def get_vect(self, q):
76 |         """вектор из текста"""
77 |         vect_q = self.sbert.encode(q, convert_to_tensor=False)
78 |         vect_q = self.vect_transformer.transform([vect_q])[0]
79 |         return (vect_q - self.mean) / self.std
80 | 
81 |     @staticmethod
82 |     def cut(text, max_len=15):
83 |         words = text.split(' ')[:max_len]
84 |         ret_text = ''
85 |         for word in words:
86 |             ret_text += word + ' '
87 | 
88 |         return ret_text
89 | 
90 |     def get_answers(self, q, top_k=7):
91 |         """ответ по преобразованному тексту"""
92 |         vect_q = self.get_vect(q)
93 |         answer_ids = self.knn.kneighbors([vect_q], top_k)
94 |         doc = ''
95 |         for i in range(answer_ids[0].shape[1]):
96 |             doc += self.texts[answer_ids[1][0][i]] + '\n'
97 |         return self.__generate__(doc, q)
98 | 
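99 | 
100 | # Набросок использования (предположение, по аналогии с API/api_service_main.py):
101 | # вопросы кодируются Sbert-моделью, нормируются, индексируются KNN с косинусной
102 | # метрикой, а top-k найденных ответов подаются как контекст генератору FRED-T5:
103 | #
104 | #   from sklearn.neighbors import KNeighborsClassifier
105 | #   from explainitall.QA.extractive_qa_sbert.QABotsBase import cos_dist
106 | #
107 | #   vects = sbert.encode(questions)
108 | #   vects = (vects - vects.mean(axis=0)) / vects.std(axis=0)
109 | #   knn = KNeighborsClassifier(metric=cos_dist)
110 | #   knn.fit(vects, answers)
111 | #   bot = PromptBot(knn, sbert, FredStruct('FractalGPT/FRED-T5-Interp'), answers, device='cpu')
112 | #   print(bot.get_answers('Что такое коала?', top_k=2))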
--------------------------------------------------------------------------------
/explainitall/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Bots-Avatar/ExplainitAll/0339ea5c09c3cd309d53c23b403465c821a778d0/explainitall/__init__.py
--------------------------------------------------------------------------------
/explainitall/clusters.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import List, Callable, Dict
3 | 
4 | import gensim
5 | from gensim.models import KeyedVectors
6 | from pandas import DataFrame
7 | 
8 | from explainitall.gpt_like_interp.inseq_helpers import AttrObj
9 | from . import nlp
10 | from .gpt_like_interp import viz, inseq_helpers
11 | 
12 | 
13 | class Cluster:
14 |     def __init__(self, name: str,
15 |                  sim_thresh: float,
16 |                  words: list, matching_func: Callable):
17 |         self.name = name  # Имя кластера
18 |         self.sim_thresh = sim_thresh  # Порог схожести
19 |         self.words = words  # Слова
20 |         self.matching_func = matching_func  # Функция сравнения
21 | 
22 | 
23 | class ClusterBuilder:
24 |     def __init__(self, name: str, seed_words: list,
25 |                  embeddings: KeyedVectors,
26 |                  word_processor,
27 |                  num_similar_words: int = 200):
28 |         """
29 |         name: Имя кластера
30 |         seed_words: Начальный список слов
31 |         embeddings: векторные представления слов
32 |         num_similar_words: Количество наиболее похожих слов для поиска
33 |         """
34 |         self.name = name
35 |         self.embeddings = embeddings
36 |         self.num_similar_words = num_similar_words
37 |         self.word_processor = word_processor
38 | 
39 |         self.seed_words = seed_words
40 | 
41 |         self.embeddable_seed_words = self.word_processor.get_embeddable_words_batch(self.seed_words)
42 | 
43 |     def build(self, matching_func: Callable) -> Cluster:
44 |         """
45 |         Создает и возвращает объект Cluster
46 |         matching_func: Функция для сравнения именованной сущности со словом
47 |         """
48 |         return Cluster(
49 |             name=self.name,
50 |             sim_thresh=0,
51 |             words=self.find_similar_words(),
52 |             matching_func=matching_func
53 |         )
54 | 
55 |     def get_embeddable_word_from_most_similar(self, value) -> str:
56 |         word_and_postfix, likelihood = value
57 |         word, postfix = word_and_postfix.rsplit("_", 1)  # rsplit: само слово может содержать '_'
58 |         return self.word_processor.get_embeddable_word_or_none(word)
59 | 
60 |     def filter_and_clean_words_postfix(self, word_list: list, postfix: str = "_NOUN") -> list:
61 |         filtered_w = [w for w in word_list if w and w.endswith(postfix)]
62 |         return [w[:-len(postfix)] for w in filtered_w]
63 | 
64 |     def find_similar_words(self) -> list:
65 |         if not self.embeddable_seed_words:
66 |             return []
67 | 
68 |         try:
69 |             similar_words = self.embeddings.most_similar(
70 |                 positive=self.embeddable_seed_words,
71 |                 topn=self.num_similar_words
72 |             )
73 |         except KeyError as e:
74 |             raise Exception("embeddable_seed_words должны быть вида ['cat_NOUN','run_VERB']") from e
75 | 
76 |         extracted_words = [
77 |             self.get_embeddable_word_from_most_similar(result) for result in similar_words
78 |         ]
79 |         extracted_words = self.embeddable_seed_words + extracted_words
80 |         extracted_words = self.filter_and_clean_words_postfix(extracted_words)
81 | 
82 |         return extracted_words
83 | 
84 | 
85 | class ClusterManager:
86 |     def __init__(self, embeddings: gensim.models.keyedvectors.KeyedVectors):
87 |         self.embeddings = embeddings
88 |         self.word_processor = nlp.WordProcessor(embeddings)
89 | 
90 |     def _is_same_normalized_word(self, word1: str, word2: str) -> bool:
91 |         """
92 |         Проверяет, одинаковы ли нормализованные формы двух слов.
93 |         """
94 |         return self.word_processor.get_normal_form_or_none(word1) == self.word_processor.get_normal_form_or_none(word2)
95 | 
96 |     def find_cluster_name(self, word: str, clusters: List[Cluster]) -> str:
97 |         """
98 |         Преобразует слово в имя кластера.
99 |         """
100 |         normalized_word = self.word_processor.get_normal_form_or_none(word)
101 |         if normalized_word:
102 |             for cluster in clusters:
103 |                 if normalized_word in cluster.words:
104 |                     return cluster.name
105 |         return "unnamed"
106 | 
107 |     def create_clusters(self, clusters_descr: List[Dict]):
108 |         """
109 |         Создает кластеры на основе их описаний. 
110 |         """
111 |         clusters = []
112 | 
113 |         for descr in clusters_descr:
114 |             clusters.append(ClusterBuilder(
115 |                 name=descr['name'],
116 |                 seed_words=descr['centroid'],
117 |                 embeddings=self.embeddings,
118 |                 num_similar_words=descr['top_k'],
119 |                 word_processor=self.word_processor
120 |             ).build(matching_func=self._is_same_normalized_word))
121 |         return clusters
122 | 
123 | 
124 | class ClusterInterpreter:
125 |     """Интерпретация Кластеров"""
126 | 
127 |     def __init__(self,
128 |                  clusters_discr: List[Dict[str, object]],
129 |                  cluster_manager: ClusterManager
130 |                  ):
131 |         self.cluster_manager = cluster_manager
132 |         self.clusters = cluster_manager.create_clusters(clusters_discr)
133 | 
134 |     def set_link_with_clusters(self, attribute):
135 |         """Устанавливает связь между семантическими кластерами."""
136 |         grouped_attribute = inseq_helpers.group_by(attribute, gmm_norm=True)
137 |         return self._create_parsed_attribution(grouped_attribute)
138 | 
139 |     def get_cluster_importance_df(self, attribute):
140 |         """Преобразует атрибуты в dataframe."""
141 |         attribute_with_clusters = self.set_link_with_clusters(attribute)
142 |         return inseq_helpers.attr_to_df(attribute_with_clusters)
143 | 
144 |     def display_attr(self, attribute):
145 |         """Отображает атрибуты в виде тепловой карты."""
146 |         attribute_with_clusters = self.set_link_with_clusters(attribute)
147 |         return viz.attr_to_heatmap(attribute_with_clusters)
148 | 
149 |     def _create_parsed_attribution(self, grouped_attribute: AttrObj):
150 |         """Создает разобранные атрибуции с сгенерированными метками."""
151 |         tokens_generated_cl = [self.cluster_manager.find_cluster_name(word=x, clusters=self.clusters)
152 |                                for x in grouped_attribute.tokens_generated]
153 | 
154 |         tokens_input_cl = [self.cluster_manager.find_cluster_name(word=x, clusters=self.clusters)
155 |                            for x in grouped_attribute.tokens_input]
156 | 
157 |         grouped_attribute = copy.deepcopy(grouped_attribute)
158 |         grouped_attribute.tokens_input = tokens_input_cl
159 |         grouped_attribute.tokens_generated = tokens_generated_cl
160 |         return grouped_attribute
161 | 
162 | 
163 | def aggregate_cluster_df(cluster_df: DataFrame,
164 |                          aggr_f: str,  # одно из: 'max', 'min', 'sum', 'mean', 'median', 'std', 'var', 'sem', 'skew'
165 |                          drop_condition: str = 'unnamed',
166 |                          cl_names_col: str = 'Tokens',
167 |                          ) -> DataFrame:
168 |     cluster_df = cluster_df[~cluster_df[cl_names_col].str.contains(drop_condition)]
169 |     cluster_df = cluster_df.loc[:, ~cluster_df.columns.str.contains(drop_condition)]
170 |     aggregation_functions = {col: aggr_f for col in cluster_df.columns if col != cl_names_col}
171 |     return cluster_df.groupby(cl_names_col).agg(aggregation_functions).reset_index()
172 | 
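173 | 
174 | # Набросок использования (предположение, собрано по коду выше и ноутбукам examples/):
175 | # кластеры строятся по word2vec-эмбеддингам, привязываются к атрибуции inseq,
176 | # затем таблица важности агрегируется по кластерам:
177 | #
178 | #   manager = ClusterManager(embeddings=nlp_model)  # gensim KeyedVectors
179 | #   interpreter = ClusterInterpreter([{'name': 'Животные', 'centroid': ['кошка'], 'top_k': 140}], manager)
180 | #   cluster_df = interpreter.get_cluster_importance_df(attribute)  # attribute: результат атрибуции inseq
181 | #   aggregate_cluster_df(cluster_df, aggr_f='mean')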
CosRelu.dot_(v1, v2) / (norm(v1) * norm(v2)) 27 | 28 | @staticmethod 29 | def cos_relu(y_orig, y_without_ner) -> float: 30 | """Рассчет ReLU от косинуса между векторам эмбеддингов""" 31 | r = CosRelu.cos(y_orig, y_without_ner) 32 | r = r if r >= 0 else 0 33 | return r 34 | 35 | 36 | class ModelInterp: 37 | 38 | def __init__(self, model): 39 | self.model = model 40 | self.mask_token = model.tokenizer.mask_token 41 | 42 | def seq_interp(self, sent): 43 | refer = self.model.encode(sent) 44 | 45 | words = sent.split(' ') 46 | imp = [] 47 | y_mod = [] 48 | 49 | for k in range(len(words)): 50 | repl_word = ' '.join([word if i != k else self.mask_token for i, word in enumerate(words)]) 51 | y_mod.append(repl_word) 52 | 53 | targ = self.model.encode(y_mod) 54 | 55 | for vect in targ: 56 | imp.append(1 - CosRelu.cos_relu(refer, vect)) 57 | 58 | imp = np.array(imp) 59 | imp_sum = imp.sum() 60 | imp /= imp_sum 61 | 62 | return {'imp': imp, 'words': words} 63 | 64 | def dataset_interp(self, texts): 65 | ret_data = [] 66 | for text in texts: 67 | interp_seq = self.seq_interp(text) 68 | ret_data.append(interp_seq) 69 | 70 | return ret_data 71 | 72 | def __claster_energy(self, cluster_data): 73 | elements = cluster_data['elements'] 74 | vectors = self.model.encode(elements) 75 | cl_energe = np.array([sum(vector ** 2) for vector in vectors]) 76 | return {'name': cluster_data['name'], 'sensitivity': sen_a(self.model, elements), 'mean': cl_energe.mean()} 77 | 78 | def clusters_interp(self, clusters_data): 79 | return [self.__claster_energy(cluster_data) for cluster_data in clusters_data] 80 | -------------------------------------------------------------------------------- /explainitall/fast_tuning/Embedder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from transformers import GPT2Model, GPT2Config 4 | 5 | 6 | class GPTEmbedder: 7 | """Эмбеддер для gpt""" 8 | 9 | def __init__(self, tokenizer_gpt, gpt, device=None): 10 | 11 | if device is None: 12 | device = 'cpu' 13 | if torch.cuda.is_available(): 14 | device = 'cuda:0' 15 | 16 | self.model = gpt 17 | self.model.to(device) 18 | self.tokenizer = tokenizer_gpt 19 | self.device = device 20 | 21 | # 22 | def get_emb_from_gpt(self, inp_str, n_layer_index=-1, token=-1, is_attention=False): 23 | """ 24 | Данные(эмбеддинги) со скрытых слоев 25 | Вернуть скрытое состояние или данные внимания 26 | """ 27 | att_hidden = 0 28 | if is_attention: 29 | att_hidden = 1 30 | 31 | inp_tokens = self.tokenizer.encode(inp_str) 32 | context = torch.tensor(inp_tokens, dtype=torch.long, device=self.device) 33 | generated = context.unsqueeze(0) 34 | inputs = {'input_ids': generated} 35 | 36 | with torch.no_grad(): 37 | outputs = self.model(**inputs) 38 | 39 | if n_layer_index == 'all': 40 | return outputs.last_hidden_state[0, -1, :].reshape(self.model.config.n_embd).to('cpu').detach().numpy() 41 | else: 42 | return_obj = outputs[1][n_layer_index][att_hidden][0, :, token, :].to('cpu') # 0 - т.к. 
батч ожидается 1 43 | 44 | shape = return_obj.shape 45 | return return_obj.reshape((shape[0] * shape[1])).detach().numpy() 46 | 47 | def get_embs_from_gpt(self, inp_str, n_layer_index=-1, head_index=0, is_attention=False): 48 | """ 49 | Данные(эмбеддинги) со скрытых слоев(По всем токенам) 50 | Вернуть скрытое состояние или данные внимания 51 | """ 52 | 53 | att_hidden = 0 54 | if is_attention: 55 | att_hidden = 1 56 | 57 | inp_tokens = self.tokenizer.encode(inp_str) 58 | context = torch.tensor(inp_tokens, dtype=torch.long, device=self.device) 59 | generated = context.unsqueeze(0) 60 | inputs = {'input_ids': generated} 61 | 62 | with torch.no_grad(): 63 | outputs = self.model(**inputs) 64 | 65 | # Пройти всю сеть включая слой нормализации 66 | if n_layer_index == 'all': 67 | return outputs.last_hidden_state[0, :, :].reshape((len(inp_tokens), self.model.config.n_embd)).to( 68 | 'cpu').detach().numpy() 69 | 70 | # Пройти заданное число gpt блоков 71 | else: 72 | # Вернуть все головы внимания или только 1 73 | if head_index == 'all': 74 | return_obj = outputs[1][n_layer_index][att_hidden][0, :, :, :].to('cpu').detach().numpy() 75 | out_len = return_obj.shape[0] * return_obj.shape[2] 76 | return_obj = np.transpose(return_obj, (1, 0, 2)).reshape(return_obj.shape[1], out_len) 77 | else: 78 | return_obj = outputs[1][n_layer_index][att_hidden][0, head_index, :, :].to( 79 | 'cpu').detach().numpy() # 0 - т.к. батч ожидается 1 80 | 81 | return return_obj 82 | 83 | def _get_k_layer(self, num_layers=3, name='gpt-embeder'): 84 | """Создание модели на базе первых слоев gpt2 донора""" 85 | base_model = self.model.base_model # модель-донор 86 | config_base = base_model.config # Конфиг модели-донора 87 | config = GPT2Config.from_dict(config_base.to_dict()) # Копирование конфига 88 | 89 | if num_layers < 0: 90 | num_layers = config.n_layer + num_layers + 1 91 | 92 | config.name_or_path = name # Имя сети 93 | config.n_layer = num_layers # Установка нужного числа слоев 94 | 95 | gpt_emb = GPT2Model(config) # Создание модели 96 | gpt_emb.wte.weight = self.model.transformer.wte.weight # Эмбединги слов 97 | gpt_emb.wpe.weight = self.model.transformer.wpe.weight # Эмбединги позиций 98 | gpt_emb.ln_f.weight = self.model.transformer.ln_f.weight # Слой нормализации 99 | 100 | for n_layer in range(num_layers): 101 | gpt_emb.base_model.h[n_layer] = base_model.h[n_layer] # Копирование слоев 102 | 103 | return gpt_emb 104 | 105 | def get_new_model(self, num_layers=3, name='gpt-embeder', save_path='gpt/model_emb'): 106 | """Создание модели на базе первых слоев gpt2 донора с перезаписью""" 107 | m = self._get_k_layer(num_layers, name) 108 | m.save_pretrained(save_path) 109 | return GPT2Model.from_pretrained(save_path) 110 | -------------------------------------------------------------------------------- /explainitall/fast_tuning/ExpertBase.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | class ExpertModel(abc.ABC): 4 | @abc.abstractmethod 5 | def get_bias(self, tokens): 6 | """Вычисление bias из вероятностной модели""" 7 | -------------------------------------------------------------------------------- /explainitall/fast_tuning/SimpleModelCreator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model 5 | 6 | from .Embedder import GPTEmbedder 7 | from .trainers.DenceKerasTrainer import GPTFastTrainer 8 
| 9 | 10 | def get_dataset_dense(txts, embedder: GPTEmbedder, tokenizer: GPT2Tokenizer, n_layer_index='all'): 11 | """ 12 | Создает датасет из текстов с помощью заданного embedder и tokenizer. 13 | """ 14 | list_x, list_y = [], [] 15 | for txt in txts: 16 | words = txt.split(' ') 17 | for i in range(0, len(words), 25): 18 | text = ' '.join(words[i:]) 19 | emb = embedder.get_embs_from_gpt(text, n_layer_index=n_layer_index)[:-1][:1024] 20 | ids = np.array(tokenizer(text)['input_ids'])[1:] 21 | list_x.append(emb) 22 | list_y.append(ids) 23 | 24 | return np.concatenate(list_x), np.concatenate(list_y) 25 | 26 | 27 | def gpt_build(trainer: GPTFastTrainer, gpt_emb: GPT2Model, tokenizer: GPT2Tokenizer, y_set, 28 | path_to_save='gpt_model_new'): 29 | """ 30 | Создает и сохраняет новую модель GPT. 31 | """ 32 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 33 | gpt_emb.to(device) 34 | 35 | w_out_matr = trainer.adapter_layer.get_weights()[0] @ trainer.keras_out_weight.T 36 | config_gpt = gpt_emb.config 37 | new_model = GPT2LMHeadModel(config_gpt) 38 | new_model.transformer = gpt_emb 39 | new_model.lm_head = nn.Linear(config_gpt.n_embd, config_gpt.vocab_size, bias=False) 40 | new_model.lm_head.weight = nn.Parameter(torch.tensor(w_out_matr, device=device)) 41 | 42 | new_model.save_pretrained(path_to_save) 43 | tokenizer.save_pretrained(path_to_save) 44 | np.save(f'{path_to_save}/set.data', np.array(y_set)) 45 | 46 | 47 | class SimpleCreator: 48 | def __init__(self, model: GPT2Model, tokenizer: GPT2Tokenizer): 49 | self.tokenizer = tokenizer 50 | main_embedder = GPTEmbedder(tokenizer, model) 51 | self.gpt_emb = main_embedder.get_new_model(num_layers=-1) 52 | self.cut_embedder = GPTEmbedder(self.tokenizer, self.gpt_emb) 53 | self.trainer = GPTFastTrainer(model) 54 | 55 | def train(self, data, lr=0.0003, bs=64, epochs=6, val_split=0.0, save_path='new_model'): 56 | x, y = get_dataset_dense(data, self.cut_embedder, self.tokenizer) 57 | net = self.trainer.create_net() 58 | self.trainer.train(net, x, y, lr=lr, bs=bs, epochs=epochs, val_split=val_split) 59 | y_set = list(set(y)) 60 | gpt_build(self.trainer, self.gpt_emb, self.tokenizer, y_set, save_path) 61 | return net 62 | -------------------------------------------------------------------------------- /explainitall/fast_tuning/generators/GeneratorWithExpert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | class GPTGenerator: 7 | 8 | def __init__(self, model, tokenizer, expert, device = None): 9 | self.infer_device = device 10 | self.model_gpt = model 11 | self.model_gpt.to(device) 12 | self.tokenizer_gpt = tokenizer 13 | self.expert = expert 14 | self.model_gpt.eval() 15 | 16 | # Штраф за повтор 17 | def __repeat_penalty(self, generated, logits, rp): 18 | hist_tokens = generated[0].cpu().numpy() 19 | unique_tokens, counts = np.unique(hist_tokens, return_counts=True) 20 | for token, count in zip(unique_tokens, counts): 21 | logits[token] -= rp * count 22 | 23 | #top p, top k фильтрация 24 | def _get_token(self, logist, top_k = 30, top_p = 0.9, del_simbols = None): 25 | 26 | filter_value = float('-inf') 27 | 28 | # Фильтруем нежелательные токены 29 | if del_simbols != None: 30 | for del_s in del_simbols: 31 | logist[del_s] = filter_value 32 | 33 | #top_k 34 | topk = torch.topk(logist, top_k) 35 | topk_v = topk[0] # значения top_k 36 | topk_i = topk[1] # индексы top_k 37 | 38 | #top_p 39 | probs = F.softmax(topk_v, dim = -1) 
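        # Nucleus (top-p) filtering, spelled out: the cumulative sum below walks the
        # sorted top-k probabilities and keeps only the tokens whose running total
        # stays within top_p; everything past that point is zeroed, so the
        # multinomial draw can never land on a low-probability tail token
        # (if nothing survives, the most probable token is restored below).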
40 | cumulative_probs = torch.cumsum(probs, dim = -1) 41 | probs[top_p < cumulative_probs] = 0 42 | 43 | if sum(probs) == 0: 44 | probs[0] = 1 45 | 46 | token_ind = torch.multinomial(probs, 1) 47 | token = topk_i[token_ind] 48 | return token 49 | 50 | # Генерация последовательности 51 | def _sample_sequence(self, length, context_tokens, temperature=1, top_k=30, expert_w = 0.2, rp = 0.1, del_simbols = None): 52 | inp_len = len(context_tokens) 53 | context = torch.tensor(context_tokens, dtype=torch.long, device=self.infer_device) 54 | generated = context.unsqueeze(0) 55 | 56 | with torch.no_grad(): 57 | decoded = '' 58 | for _ in range(length): 59 | inputs = {'input_ids': generated[:, -1023:]} # Входы 60 | outputs = self.model_gpt(**inputs) # Прямой проход gpt 61 | 62 | g_with_start = list(generated[0].cpu().numpy())# Затравка для эксперта 63 | bias_expert = self.expert.get_bias(g_with_start) # bias на базе эксперта или их смеси 64 | bias_expert = torch.tensor(bias_expert).to(self.infer_device) # bias в виде тензора 65 | next_token_logits = ((1-expert_w)*outputs[0][0, -1, :]+expert_w*bias_expert) / temperature 66 | self.__repeat_penalty(generated, next_token_logits, rp) 67 | next_tokens = self._get_token(next_token_logits, top_k = top_k, del_simbols = del_simbols).to(self.infer_device) # Генерация из распределения 68 | 69 | if next_tokens == 0: 70 | break 71 | 72 | generated = torch.cat((generated, next_tokens.unsqueeze(-1)), dim=1) 73 | 74 | out = generated[0, len(context_tokens):].tolist() 75 | new_decoded = self.tokenizer_gpt.decode(out) 76 | if len(new_decoded) > len(decoded): 77 | decoded = new_decoded 78 | return decoded 79 | 80 | 81 | # Генерация текста из текста 82 | def _generate(self, raw_text, length=250, temperature=1., top_k=30, rp = 0.03, expert_w =0.2, del_simbols = None): 83 | context_tokens = self.tokenizer_gpt.encode(raw_text) 84 | out = self._sample_sequence( 85 | length, context_tokens, 86 | rp = rp, 87 | expert_w=expert_w, 88 | temperature=temperature, 89 | top_k=top_k, 90 | del_simbols = del_simbols, 91 | ) 92 | return out 93 | 94 | # Генерация нескольких последовательностей из начального текста 95 | def Generate(self, text, max_len = 200, num_seq = 2, temperature = 1.0, topk = 30, rp = 0.03, expert_w =0.2, del_simbols = None): 96 | ret = [] 97 | for i in range(num_seq): 98 | ret.append( 99 | self._generate(text, max_len, temperature=temperature, top_k=topk, rp=rp, expert_w = expert_w, del_simbols=del_simbols) 100 | ) 101 | return ret 102 | 103 | # Генерация с использованием Bias 104 | class GenerationWithProbs: 105 | 106 | def __init__(self, model, tokenizer, bias_mask, device = 'cpu'): 107 | self.model = model 108 | self.b_mask = bias_mask 109 | self.tokenizer = tokenizer 110 | self.model.to(device) 111 | 112 | def generate(self, text, top_p = 0.9, max_len = 30, rp = 1.15, temperature=0.7, variety = 0.3): 113 | self.__set_variety(variety) 114 | do_sample = temperature > 0 115 | 116 | input_ids = self.tokenizer.encode(text, return_tensors='pt').to(self.model.device) 117 | output_sequences = self.model.generate(input_ids=input_ids, 118 | max_length=max_len, 119 | temperature= None if temperature == 0 else temperature, 120 | top_p= None if temperature == 0 else top_p, 121 | repetition_penalty=rp, 122 | num_return_sequences=1, 123 | do_sample=do_sample) 124 | 125 | generated_text = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True) 126 | return generated_text 127 | 128 | def __set_variety(self, variety=0.0): 129 | bias = 
np.zeros((self.model.lm_head.out_features,)) 130 | coef_mask = np.log2(variety + 3e-3) / np.log2(np.e) 131 | bias += coef_mask 132 | 133 | for token in self.b_mask: 134 | bias[token] = 0 135 | 136 | b_tensor = torch.tensor(bias, dtype=torch.float32) 137 | out_gpt_layer = torch.nn.Linear(in_features=self.model.lm_head.in_features, out_features=self.model.lm_head.out_features, bias=True) 138 | out_gpt_layer.weight = self.model.lm_head.weight 139 | out_gpt_layer.bias.data.copy_(b_tensor) 140 | self.model.lm_head = out_gpt_layer 141 | -------------------------------------------------------------------------------- /explainitall/fast_tuning/generators/MCExpert.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | 3 | from explainitall.fast_tuning.ExpertBase import ExpertModel 4 | 5 | 6 | class MCExpert(ExpertModel, metaclass=ABCMeta): 7 | 8 | def __init__(self, mc): 9 | self.mc = mc 10 | 11 | def get_bias(self, tokens): 12 | start = [1, 1] + tokens 13 | return self.mc.get_bias(*start[-2:]) -------------------------------------------------------------------------------- /explainitall/fast_tuning/generators/MCModel.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | 4 | 5 | # преобразование 2х токенов в строку 6 | def token_pair_2_str(tokens): 7 | list_x = list(tokens) 8 | return str(list_x) 9 | 10 | # Кодирование конца 11 | def encode_end(tokens_gpt, x_encode): 12 | str_x = token_pair_2_str(tokens_gpt[-2:]) 13 | return x_encode[str_x] if str_x in x_encode else -1 14 | 15 | # Markov 16 | class MarkovModel(): 17 | def __init__(self, len_vect, x_e=None, y_d = None, model = None, path = None, dep = 4): 18 | if path == None: 19 | self.x_encoder = x_e 20 | self.y_decoder = y_d 21 | else: 22 | with open(path + 'x_enc.dat', 'rb') as f: 23 | self.x_encoder = pickle.load(f) 24 | with open(path + 'y_dec.dat', 'rb') as f: 25 | self.y_decoder = pickle.load(f) 26 | with open(path + 'model.dat', 'rb') as f: 27 | model = pickle.load(f) 28 | 29 | self.dep = dep 30 | self.model_log = self.__get_log_model(model) 31 | self.len_vect = len_vect 32 | 33 | #----------------------- Генрерация bias --------------# 34 | def get_bias(self, token_1 = 1, token_2 = 1): 35 | '''Генерация bias''' 36 | b = [token_1, token_2] 37 | key = encode_end(b, self.x_encoder) 38 | bias = np.zeros((self.len_vect)) 39 | 40 | if key != -1: 41 | tokens, logs = self.model_log[key]['tokens'], self.model_log[key]['logists'] 42 | tokens = self.__get_tokens(tokens) 43 | bias -= self.dep 44 | 45 | for i, token in enumerate(tokens): 46 | bias[token] = logs[i] 47 | 48 | return bias 49 | 50 | # Получение нормированных логарифмов 51 | def __get_logist(self, key, model): 52 | logist = np.log(model[key]['probs']) 53 | logist-=max(logist) 54 | return logist 55 | 56 | # Получение модели с логарифмами вероятностей 57 | def __get_log_model(self, model): 58 | m_l ={} 59 | for key in range(len(model)): 60 | m_l_semple = {} 61 | m_l_semple.update({'tokens': model[key]['tokens']}) 62 | m_l_semple.update({'logists':self.__get_logist(key, model)}) 63 | m_l.update({key:m_l_semple}) 64 | return m_l 65 | 66 | 67 | def __get_tokens(self, tokens): 68 | true_tokens = [] 69 | for t in tokens: 70 | true_tokens.append(self.y_decoder[t]) 71 | return true_tokens 72 | 73 | 74 | -------------------------------------------------------------------------------- /explainitall/fast_tuning/generators/SimpleGenerator.py: 
--------------------------------------------------------------------------------
 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline
 2 | import numpy as np
 3 | import torch
 4 | 
 5 | class TextGenerator:
 6 |     def __init__(self, path):
 7 |         self.tokenizer = AutoTokenizer.from_pretrained(path)
 8 |         self.model = AutoModelForCausalLM.from_pretrained(path)
 9 | 
10 |         self.y_set = np.load(f'{path}/set.data.npy')
11 |         self.set_variety_of_answers(0.0)
12 | 
13 |         self.pipeline = TextGenerationPipeline(model=self.model, tokenizer=self.tokenizer)
14 | 
15 |     def set_variety_of_answers(self, variety=0, min_prob=3e-3):
16 |         bias = np.zeros((self.model.lm_head.out_features,))
17 |         coef_mask = np.log2(variety + min_prob) / np.log2(np.e)
18 |         bias += coef_mask
19 | 
20 |         for token in self.y_set:
21 |             bias[token] = 0
22 | 
23 |         b_tensor = torch.tensor(bias, dtype=torch.float32)
24 |         # lm_head of a causal LM is created with bias=False, so .bias is None;
25 |         # rebuild it once as a biased layer (same pattern as ProjectionTrainer.set_variety)
26 |         if self.model.lm_head.bias is None:
27 |             out_layer = torch.nn.Linear(in_features=self.model.lm_head.in_features,
28 |                                         out_features=self.model.lm_head.out_features, bias=True)
29 |             out_layer.weight = self.model.lm_head.weight
30 |             self.model.lm_head = out_layer
31 |         self.model.lm_head.bias.data.copy_(b_tensor)
32 | 
33 |     def generate(self, start_text, args=None):
34 |         if args is None:
35 |             args = {
36 |                 "temperature": 0.7,
37 |                 "no_repeat_ngram_size": 2,
38 |                 "num_beams": 12,
39 |                 "top_k": 30,
40 |             }
41 |         # the pipeline returns a list with one dict per generated sequence
42 |         return self.pipeline(start_text, **args)[0]["generated_text"]
43 | 
--------------------------------------------------------------------------------
/explainitall/fast_tuning/trainers/DenceKerasTrainer.py:
--------------------------------------------------------------------------------
 1 | import keras.layers as L
 2 | import numpy as np
 3 | import torch
 4 | from keras.layers import Dense
 5 | from keras.models import Sequential
 6 | from tensorflow import keras as K
 7 | 
 8 | 
 9 | class GPTFastTrainer:
10 |     """Builds the trainable final layer"""
11 | 
12 |     def __init__(self, gpt_model):
13 |         self.inp_dim = gpt_model.config.n_embd
14 |         self.outp_dim = gpt_model.config.vocab_size
15 | 
16 |         if torch.cuda.is_available():
17 |             gpt_model.to('cpu')
18 | 
19 |         self.keras_out_weight = gpt_model.lm_head.weight.detach().numpy().transpose()  # grab the output weight matrix
20 |         self.keras_adapter_weight = np.eye(self.inp_dim, dtype=float)  # identity matrix (initial adapter-layer weights)
21 | 
22 |         self.adapter_layer = Dense(self.inp_dim, use_bias=False, activation='linear')  # adapter layer
23 |         self.out_layer = Dense(use_bias=True, units=self.outp_dim, activation='linear',
24 |                                trainable=False)  # output layer (frozen, not trained)
25 | 
26 |         if torch.cuda.is_available():
27 |             gpt_model.to('cuda:0')
28 | 
29 |     # Build the network for fast tuning #
30 |     def create_net(self):
31 |         net = Sequential()
32 |         net.add(L.Input(self.inp_dim))
33 |         net.add(self.adapter_layer)
34 |         net.add(self.out_layer)
35 |         net.add(L.Activation(activation='softmax'))
36 |         net.compile()
37 |         # Load the weights into the output layer
38 |         self.out_layer.set_weights([self.keras_out_weight, np.zeros(self.outp_dim)])
39 |         # Load the weights into the adapter layer
40 |         self.adapter_layer.set_weights([self.keras_adapter_weight])
41 |         return net
42 | 
43 |     def set_variety_of_answers(self, y, variety=0, min_prob=1e-300):
44 |         """Rebuild the output layer, setting the generation variety"""
45 |         set_tokens = set(y)
46 |         bias = np.zeros(self.outp_dim)
47 |         coef_mask = np.log2(variety + min_prob) / np.log2(np.e)
48 |         bias += coef_mask
49 | 
50 |         for token in set_tokens:
51 |             bias[token] = 0
52 | 
53 |         self.out_layer.set_weights([self.keras_out_weight, bias])
54 | 
55 |     def train(self, net, x, y, lr=0.0003, bs=64, epochs=3, val_split=0.0):
56 |         """Train the network"""
57 |         self.set_variety_of_answers(y)
58 |         opt = K.optimizers.Adamax(learning_rate=lr)
59 |         net.compile(loss='sparse_categorical_crossentropy', optimizer=opt)
60 |         net.fit(x, y, batch_size=bs, epochs=epochs, validation_split=val_split)
61 | 
--------------------------------------------------------------------------------
/explainitall/fast_tuning/trainers/HMMTrainer.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | 
 4 | class GPT2HMMDataProcessor:
 5 |     def __init__(self, tokenizer):
 6 |         self.tokenizer = tokenizer
 7 | 
 8 |     def get_data_1(self, texts):
 9 |         list_y = []
10 |         for text in texts:
11 |             list_tok = [1, 1] + self.tokenizer(text)['input_ids'] + [2]
12 |             list_y.append(np.array(list_tok))
13 |         return np.concatenate(list_y)
14 | 
15 |     @staticmethod
16 |     def token_pair_2_str(tokens):
17 |         list_x = list(tokens)
18 |         return str(list_x)
19 | 
20 |     def createXY(self, tokens_gpt):
21 |         x = []
22 |         y = []
23 | 
24 |         for i in range(len(tokens_gpt) - 2):
25 |             str_x = self.token_pair_2_str(tokens_gpt[i:i + 2])
26 |             x.append(str_x)
27 |             y.append(tokens_gpt[i + 2])
28 | 
29 |         return x, y
30 | 
31 |     @staticmethod
32 |     def encode_samples_x(x, y_encode):
33 |         return [y_encode[i] for i in x]
34 | 
35 |     @staticmethod
36 |     def encode_samples_y(y, y_encode):
37 |         return [y_encode[i] for i in y]
38 | 
39 |     @staticmethod
40 |     def encode_end(tokens_gpt, x_encode):
41 |         str_x = GPT2HMMDataProcessor.token_pair_2_str(tokens_gpt[-2:])
42 |         return x_encode[str_x] if str_x in x_encode else -1
43 | 
44 |     def create_data(self, tokens_gpt):
45 |         x, y = self.createXY(tokens_gpt)
46 |         x_set = list(set(x))
47 |         mask_y = list(set(y))
48 | 
49 |         x_encode = {}
50 |         x_decode = []
51 | 
52 |         y_encode = {}
53 |         y_decode = []
54 | 
55 |         for i, xi in enumerate(x_set):
56 |             x_encode.update({xi: i})
57 |             x_decode.append(xi)
58 | 
59 |         for i, yi in enumerate(mask_y):
60 |             y_encode.update({yi: i})
61 |             y_decode.append(yi)
62 | 
63 |         x_enc = self.encode_samples_x(x, x_encode)
64 |         y_enc = self.encode_samples_y(y, y_encode)
65 | 
66 |         return {'x': x_enc, 'y': y_enc, 'x_encoder': x_encode, 'y_encoder': y_encode, 'x_decoder': x_decode,
67 |                 'y_decoder': y_decode}
68 | 
69 |     
@staticmethod 70 | def train(data): 71 | states = {} 72 | n = len(data['y']) 73 | x = data['x'] 74 | y = data['y'] 75 | 76 | for i in range(n): 77 | if x[i] in states: 78 | if y[i] in states[x[i]]['tokens']: 79 | states[x[i]]['probs'][y[i]] += 1 80 | else: 81 | states[x[i]]['probs'].update({y[i]: 1}) 82 | states[x[i]]['tokens'].update({y[i]: 0}) 83 | else: 84 | states.update({x[i]: {'tokens': {}, 'probs': {}}}) 85 | states[x[i]]['probs'].update({y[i]: 1}) 86 | states[x[i]]['tokens'].update({y[i]: 0}) 87 | 88 | n = max(states.keys()) + 1 89 | n_states = [] 90 | 91 | for i in range(n): 92 | tokens = list(states[i]['tokens'].keys()) 93 | total_count = 0 94 | probs = [] 95 | 96 | for t in tokens: 97 | count = states[i]['probs'][t] 98 | total_count += count 99 | probs.append(count) 100 | 101 | for j, p in enumerate(probs): 102 | probs[j] = p / total_count 103 | 104 | n_states.append({'tokens': tokens, 'probs': probs}) 105 | 106 | return n_states 107 | -------------------------------------------------------------------------------- /explainitall/fast_tuning/trainers/ProjectionTrainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from transformers import Trainer, TrainingArguments, PreTrainedTokenizer 4 | from transformers import DataCollatorForLanguageModeling 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class StringDataset(Dataset): 9 | def __init__(self, tokenizer: PreTrainedTokenizer, texts: list, block_size=256): 10 | block_size = block_size - (tokenizer.model_max_length - tokenizer.max_len_single_sentence) 11 | self.examples = [] 12 | 13 | for text in texts: 14 | if len(text)==0: 15 | continue 16 | tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) 17 | if len(tokenized_text) >= block_size: 18 | for i in range(0, len(tokenized_text) - block_size + 1, block_size): 19 | self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i + block_size])) 20 | else: 21 | self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text)) 22 | 23 | def __len__(self): 24 | return len(self.examples) 25 | 26 | def __getitem__(self, i): 27 | return torch.tensor(self.examples[i], dtype=torch.long) 28 | 29 | class GPTProjectionTrainer: 30 | def __init__(self, model, tokenizer): 31 | self.model = model 32 | self.tokenizer = tokenizer 33 | 34 | def load_dataset(self, texts, block_size=256): 35 | dataset = StringDataset( 36 | tokenizer=self.tokenizer, 37 | texts=texts, 38 | block_size=block_size, 39 | ) 40 | return dataset 41 | 42 | def create_data_collator(self, mlm=False): 43 | data_collator = DataCollatorForLanguageModeling( 44 | tokenizer=self.tokenizer, 45 | mlm=mlm, 46 | ) 47 | return data_collator 48 | 49 | def set_variety(self, bias_mask, variety=0., min_prob=3e-3): 50 | bias = np.zeros((self.model.lm_head.out_features,)) 51 | coef_mask = np.log2(variety + min_prob) / np.log2(np.e) 52 | bias += coef_mask 53 | 54 | for token in bias_mask: 55 | bias[token] = 0 56 | 57 | b_tensor = torch.tensor(bias, dtype=torch.float32) 58 | out_gpt_layer = torch.nn.Linear(in_features=self.model.lm_head.in_features, 59 | out_features=self.model.lm_head.out_features, bias=True) 60 | out_gpt_layer.weight = self.model.lm_head.weight 61 | out_gpt_layer.bias.data.copy_(b_tensor) 62 | self.model.lm_head = out_gpt_layer 63 | 64 | def train(self, train_texts, output_dir="new_gpt", last_k=10, 65 | per_device_train_batch_size=2, num_train_epochs=3, save_steps=1000, device=None, lr = 
2e-4): 66 | 67 | train_dataset = self.load_dataset(train_texts) 68 | 69 | self.train_with_dataset(train_dataset, output_dir, last_k, 70 | per_device_train_batch_size, num_train_epochs, save_steps, device, lr) 71 | 72 | 73 | def train_with_dataset(self, train_dataset, output_dir="new_gpt", last_k='all', 74 | per_device_train_batch_size=2, num_train_epochs=1, save_steps=10000, device=None, lr = 2e-4): 75 | if device is None: 76 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 77 | 78 | self.model.to(device) 79 | 80 | params = [] 81 | 82 | # Обучение всех слоев 83 | if last_k == 'all': 84 | for name, param in self.model.named_parameters(): 85 | param.requires_grad = True 86 | 87 | # Только проекции 88 | else: 89 | for name, param in self.model.named_parameters(): 90 | param.requires_grad = False 91 | if "c_proj.weight" in name: # and "mlp" in name 92 | params.append(param) 93 | 94 | for param in params[-last_k:]: 95 | param.requires_grad = True 96 | 97 | 98 | if len(train_dataset) == 0: 99 | raise ValueError("Dataset is empty. Ensure that the input texts are not empty and of sufficient length.") 100 | 101 | data_collator = self.create_data_collator() 102 | 103 | training_args = TrainingArguments( 104 | output_dir=output_dir, 105 | overwrite_output_dir=True, 106 | per_device_train_batch_size=per_device_train_batch_size, 107 | num_train_epochs=num_train_epochs, 108 | save_steps=save_steps, 109 | learning_rate=lr, 110 | ) 111 | 112 | trainer = Trainer( 113 | model=self.model, 114 | args=training_args, 115 | data_collator=data_collator, 116 | train_dataset=train_dataset, 117 | ) 118 | 119 | trainer.train() 120 | -------------------------------------------------------------------------------- /explainitall/gpt_like_interp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bots-Avatar/ExplainitAll/0339ea5c09c3cd309d53c23b403465c821a778d0/explainitall/gpt_like_interp/__init__.py -------------------------------------------------------------------------------- /explainitall/gpt_like_interp/downloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import shutil 4 | import zipfile 5 | 6 | import requests 7 | from tqdm import tqdm 8 | 9 | 10 | class DownloadManager: 11 | base_directory = os.path.join(os.path.expanduser("~"), ".cache") 12 | 13 | def _create_directory(self, directory_name): 14 | directory_path = os.path.join(self.base_directory, directory_name) 15 | os.makedirs(directory_path, exist_ok=True) 16 | return directory_path 17 | 18 | @staticmethod 19 | def _clean_string(text): 20 | return re.sub(r'[^a-zA-Z0-9]', '_', text) 21 | 22 | @staticmethod 23 | def _delete_existing_file(file_path): 24 | if os.path.exists(file_path): 25 | os.unlink(file_path) 26 | 27 | @staticmethod 28 | def _download_file(url: str, filepath: str, verbose: bool = True): 29 | if os.path.exists(filepath): 30 | return 31 | 32 | response = requests.get(url, stream=True) 33 | total_size = int(response.headers.get('content-length', 0)) 34 | with open(filepath, 'wb') as file: 35 | if verbose: 36 | with tqdm( 37 | desc=f"Downloading: {filepath}", 38 | total=total_size, 39 | unit='iB', 40 | unit_scale=True, 41 | unit_divisor=1024, 42 | ) as progress_bar: 43 | for data in response.iter_content(chunk_size=1024): 44 | written_size = file.write(data) 45 | progress_bar.update(written_size) 46 | else: 47 | for data in response.iter_content(chunk_size=1024): 48 | 
file.write(data) 49 | 50 | @staticmethod 51 | def _delete_existing_folder(path: str): 52 | if os.path.exists(path): 53 | shutil.rmtree(path) 54 | 55 | @staticmethod 56 | def _extract_zip_file(file_path, destination_path): 57 | if not os.path.exists(destination_path): 58 | 59 | with zipfile.ZipFile(file_path, "r") as zip_file: 60 | for file in tqdm(iterable=zip_file.namelist(), 61 | total=len(zip_file.namelist()), 62 | desc=f"Extracting: {destination_path}"): 63 | zip_file.extract(member=file, path=destination_path) 64 | 65 | @classmethod 66 | def load_zip(cls, url, remove_existing=False, model_file_name='model.bin', verbose=True): 67 | zip_filename = cls._clean_string(url.split("/")[-1]) 68 | zip_file_path = os.path.join(cls.base_directory, zip_filename) 69 | 70 | if remove_existing: 71 | cls._delete_existing_file(zip_file_path) 72 | 73 | cls._download_file(url, zip_file_path, verbose) 74 | 75 | extracted_data_directory_path = os.path.join(cls.base_directory, cls._clean_string(zip_filename) + "_data") 76 | if remove_existing: 77 | cls._delete_existing_folder(extracted_data_directory_path) 78 | 79 | cls._extract_zip_file(zip_file_path, extracted_data_directory_path) 80 | 81 | return os.path.join(extracted_data_directory_path, model_file_name) 82 | 83 | 84 | if __name__ == "__main__": 85 | download_path = DownloadManager.load_zip( 86 | 'http://vectors.nlpl.eu/repository/20/180.zip', 87 | remove_existing=True) 88 | print("Download 1 path:", download_path) 89 | 90 | download_path = DownloadManager.load_zip( 91 | 'http://vectors.nlpl.eu/repository/20/180.zip', 92 | remove_existing=False) 93 | print("Download 2 path:", download_path) 94 | 95 | download_path = DownloadManager.load_zip( 96 | 'http://vectors.nlpl.eu/repository/20/180.zip', 97 | remove_existing=True, verbose=False) 98 | print("Download 2 path:", download_path) 99 | -------------------------------------------------------------------------------- /explainitall/gpt_like_interp/inseq_helpers.py: -------------------------------------------------------------------------------- 1 | import re 2 | from copy import deepcopy 3 | from typing import Tuple, Union, Optional 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from inseq import FeatureAttributionOutput 9 | from inseq.utils.typing import GranularSequenceAttributionTensor as Gast 10 | from inseq.utils.typing import TokenSequenceAttributionTensor as Tsat 11 | from torch.linalg import vector_norm 12 | 13 | from explainitall import stat_helpers 14 | 15 | 16 | def sum_normalize_attributions( 17 | attributions: Union[Gast, Tuple[Gast, Gast]], 18 | cat_dim: int = 0, 19 | norm_dim: Optional[int] = 0, 20 | ) -> Tsat: 21 | """ 22 | Суммаризация и нормализация тензоров по dim_sum 23 | РезультатЖ матрица векторов строк 24 | """ 25 | concat = False 26 | if isinstance(attributions, tuple): 27 | concat = True 28 | orig_sizes = [a.shape[cat_dim] for a in attributions] 29 | attributions = torch.cat(attributions, dim=cat_dim) 30 | else: 31 | orig_sizes = [attributions.shape[cat_dim]] 32 | attributions = vector_norm(attributions, ord=2, dim=-1) 33 | if norm_dim is not None: 34 | attributions = attributions / attributions.nansum(dim=norm_dim, keepdim=True) 35 | if len(attributions.shape) == 1: 36 | attributions = attributions.unsqueeze(0) 37 | if concat: 38 | attributions = attributions.split(orig_sizes, dim=cat_dim) 39 | return attributions[0], attributions[1] 40 | return attributions 41 | 42 | 43 | def fix_ig_tokens(feature_attr: FeatureAttributionOutput): 44 | from transformers 
import GPT2Tokenizer 45 | 46 | feature_attr_conv = deepcopy(feature_attr) 47 | dt = GPT2Tokenizer.from_pretrained(feature_attr_conv.info['model_name']) 48 | 49 | for attr in feature_attr_conv.sequence_attributions: 50 | for token_holder in (attr.source, attr.target): 51 | for s in token_holder: 52 | try: 53 | s.token = dt.convert_tokens_to_string(s.token) 54 | except KeyError: 55 | pass 56 | return feature_attr_conv 57 | 58 | 59 | def get_ig_tokens(feature_attr: FeatureAttributionOutput): 60 | rez = [] 61 | for attr in feature_attr.sequence_attributions: 62 | rez.append((tuple(s.token for s in attr.source), 63 | tuple(t.token for t in attr.target))) 64 | return rez 65 | 66 | 67 | def get_ig_phrases(feature_attr: FeatureAttributionOutput): 68 | return tuple(zip(feature_attr.info['input_texts'], 69 | feature_attr.info['generated_texts'])) 70 | 71 | 72 | def get_g_arrays(feature_attr: FeatureAttributionOutput): 73 | target_arrays = [] 74 | for attr in feature_attr.sequence_attributions: 75 | ta = attr.target_attributions 76 | ta2 = sum_normalize_attributions(ta) 77 | ta2 = np.array(ta2, dtype=float) 78 | 79 | target_arrays.append(np.array(ta2, dtype=float)) 80 | return target_arrays 81 | 82 | 83 | class AttrObj: 84 | def __init__(self, 85 | phrase_input: str, 86 | phrase_generated: str, 87 | tokens_input: Tuple[str, ...], 88 | tokens_generated: Tuple[str, ...], 89 | array: np.ndarray): 90 | self.phrase_input = phrase_input 91 | self.phrase_generated = phrase_generated 92 | self.tokens_input = tokens_input 93 | self.tokens_generated = tokens_generated 94 | self.array = array 95 | 96 | def __repr__(self): 97 | return (f"AttrObj({self.phrase_input=}, " 98 | f"{self.phrase_generated=}, " 99 | f"{self.tokens_input=}, " 100 | f"{self.tokens_generated=}, " 101 | f"{self.array.shape=})") 102 | 103 | 104 | def get_first_attribute(feature_attr: FeatureAttributionOutput): 105 | fixed_attr = fix_ig_tokens(feature_attr) 106 | phrase_input, phrase_generated_full = get_ig_phrases(fixed_attr)[0] 107 | phrase_generated = phrase_generated_full[len(phrase_input):] 108 | tokens_input, tokens_generated_full = get_ig_tokens(fixed_attr)[0] 109 | tokens_generated = tokens_generated_full[len(tokens_input):] 110 | array = get_g_arrays(fixed_attr)[0] 111 | 112 | return AttrObj(phrase_input=phrase_input, 113 | phrase_generated=phrase_generated, 114 | tokens_input=tokens_input, 115 | tokens_generated=tokens_generated, 116 | array=array) 117 | 118 | 119 | def attr_to_df(attr: AttrObj): 120 | """Преобразует атрибуты в DataFrame""" 121 | df = pd.DataFrame(attr.array) 122 | df.columns = attr.tokens_generated 123 | df = df.sort_index() 124 | df.insert(0, 'Tokens', attr.tokens_input + attr.tokens_generated) 125 | return df 126 | 127 | 128 | def squash_arr(arr, squash_row_mask, squash_col_mask, aggr_f=np.max): 129 | # Apply the mask to the rows 130 | row_result = np.array([aggr_f(arr[start:end], axis=0) 131 | for start, end in squash_row_mask]) 132 | # Apply the mask to the columns 133 | col_result = np.array([aggr_f(row_result[:, start:end], axis=1) 134 | for start, end in squash_col_mask]).T 135 | return col_result 136 | 137 | 138 | class Detokenizer: 139 | """ 140 | Класс для детокенизации (приведения токенов к словам). 141 | """ 142 | # список символов-тире 143 | dash_chars = list(map(chr, (45, 8211, 8212, 8722, 9472, 9473, 9476))) 144 | # регулярное выражение для поиска тире между буквами 145 | dash_regex = re.compile(r'(? 
AttrObj: 209 | tokens_input_grouped = Detokenizer(attr.phrase_input, attr.tokens_input).group_text() 210 | tokens_generated_grouped = Detokenizer(attr.phrase_generated, attr.tokens_generated).group_text() 211 | tokens_input_generated_mask = calculate_mask(tokens_input_grouped + tokens_generated_grouped) 212 | tokens_generated_mask = calculate_mask(tokens_generated_grouped) 213 | 214 | squashed_array = squash_arr(attr.array, 215 | squash_col_mask=tokens_generated_mask, 216 | squash_row_mask=tokens_input_generated_mask) 217 | 218 | tokens_input_grouped_flatten = tuple(["".join(x) for x in tokens_input_grouped]) 219 | tokens_generated_grouped_flatten = tuple("".join(x) for x in tokens_generated_grouped) 220 | 221 | if gmm_norm: 222 | squashed_array = stat_helpers.calc_gmm_stat_params(squashed_array)['new_arr'] 223 | 224 | return AttrObj(phrase_input=attr.phrase_input, 225 | phrase_generated=attr.phrase_generated, 226 | tokens_input=tokens_input_grouped_flatten, 227 | tokens_generated=tokens_generated_grouped_flatten, 228 | array=squashed_array) 229 | -------------------------------------------------------------------------------- /explainitall/gpt_like_interp/interp.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | import gensim 4 | import pandas as pd 5 | from inseq import AttributionModel 6 | 7 | from explainitall import clusters 8 | from explainitall.clusters import ClusterManager, aggregate_cluster_df 9 | from explainitall.gpt_like_interp import viz 10 | from . import inseq_helpers 11 | from .inseq_helpers import AttrObj 12 | 13 | 14 | class ExplainerGPT2Output: 15 | def __init__(self, attributions: AttrObj, 16 | attributions_grouped: AttrObj, 17 | attributions_grouped_norm: AttrObj, 18 | 19 | cluster_imp_df: pd.DataFrame, 20 | cluster_imp_aggr_df: pd.DataFrame, 21 | word_imp_df: pd.DataFrame, 22 | word_imp_norm_df: pd.DataFrame): 23 | self.attributions = attributions 24 | self.attributions_grouped = attributions_grouped 25 | self.attributions_grouped_norm = attributions_grouped_norm 26 | 27 | self.cluster_imp_df = cluster_imp_df 28 | self.cluster_imp_aggr_df = cluster_imp_aggr_df 29 | self.word_imp_df = word_imp_df 30 | self.word_imp_norm_df = word_imp_norm_df 31 | 32 | def show_word_imp_heatmap(self): 33 | viz.df_to_heatmap(self.word_imp_df, title="Карта важности слов") 34 | 35 | def show_word_imp_norm_heatmap(self): 36 | viz.df_to_heatmap(self.word_imp_norm_df, title="Карта важности слов, нормированная") 37 | 38 | def show_cluster_imp_heatmap(self): 39 | viz.df_to_heatmap(self.cluster_imp_df, title="Карта важности слов") 40 | 41 | def show_cluster_imp_aggr_heatmap(self): 42 | viz.df_to_heatmap(self.cluster_imp_aggr_df, title="Карта важности слов группированная") 43 | 44 | 45 | class ExplainerGPT2: 46 | def __init__(self, gpt_model: AttributionModel, 47 | nlp_model: gensim.models.keyedvectors.KeyedVectors): 48 | self.gpt_model = gpt_model 49 | self.nlp_model = nlp_model 50 | self._cluster_manager = ClusterManager(embeddings=self.nlp_model) 51 | self.attributions = None 52 | 53 | def interpret(self, 54 | input_texts: str, 55 | generated_texts: str, 56 | clusters_description: List[Dict[str, object]], 57 | batch_size: int = 50, 58 | steps: int = 34, 59 | max_new_tokens: int = None, 60 | aggr_f='mean') -> ExplainerGPT2Output: 61 | self._attribute(input_texts, generated_texts, max_new_tokens, steps, batch_size) 62 | return self._run_pipeline(clusters_description, aggr_f) 63 | 64 | @staticmethod 65 | def 
calc_max_tokes(input_texts, generated_texts): 66 | from gensim.utils import tokenize 67 | tokens = list(tokenize(input_texts + generated_texts, lowercase=True)) 68 | num_tokens = len(tokens) 69 | return num_tokens + 10 # buffer 70 | 71 | def _attribute(self, input_texts: str, 72 | generated_texts: str, 73 | max_new_tokens: int, 74 | steps: int, batch_size: int): 75 | 76 | generation_args = None 77 | if not generated_texts: 78 | if max_new_tokens is None: 79 | max_new_tokens = self.calc_max_tokes(input_texts, generated_texts) 80 | generation_args = {"max_new_tokens": max_new_tokens} 81 | 82 | out = self.gpt_model.attribute( 83 | input_texts=input_texts, generated_texts=input_texts + generated_texts, 84 | n_steps=steps, generation_args=generation_args, 85 | show_progress=True, pretty_progress=False, internal_batch_size=batch_size 86 | ) 87 | self.attributions = inseq_helpers.get_first_attribute(out) 88 | 89 | def _run_pipeline(self, cluster_desc, aggr_f): 90 | group_attr = inseq_helpers.group_by(self.attributions) 91 | norm_attr = inseq_helpers.group_by(self.attributions, gmm_norm=True) 92 | 93 | word_imp_df = inseq_helpers.attr_to_df(group_attr) 94 | word_imp_norm_df = inseq_helpers.attr_to_df(norm_attr) 95 | 96 | cluster_interpreter = clusters.ClusterInterpreter(clusters_discr=cluster_desc, 97 | cluster_manager=self._cluster_manager) 98 | 99 | cluster_imp_df = cluster_interpreter.get_cluster_importance_df(self.attributions) 100 | try: 101 | cluster_imp_aggr_df = aggregate_cluster_df(cluster_imp_df, aggr_f=aggr_f) 102 | except Exception as e: 103 | print(f"Неверно заданы выходные кластеры: {e}") 104 | cluster_imp_aggr_df = pd.DataFrame() 105 | 106 | return ExplainerGPT2Output( 107 | attributions=self.attributions, attributions_grouped=group_attr, attributions_grouped_norm=norm_attr, 108 | cluster_imp_df=cluster_imp_df, cluster_imp_aggr_df=cluster_imp_aggr_df, 109 | word_imp_df=word_imp_df, word_imp_norm_df=word_imp_norm_df) 110 | -------------------------------------------------------------------------------- /explainitall/gpt_like_interp/viz.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import seaborn as sns 6 | from statsmodels.graphics.gofplots import qqplot 7 | 8 | from explainitall import stat_helpers 9 | 10 | 11 | def show_distribution_histogram(data): 12 | """ 13 | Строит гистограмму распределения 14 | """ 15 | stat_params = stat_helpers.calc_gmm_stat_params(data) 16 | data = data[~np.isnan(data)] 17 | data.sort() 18 | 19 | gaussian_prob = stat_helpers.compute_gaussian_integral(data, stat_params["mean"], stat_params["std"]) 20 | rayleigh_prob = stat_helpers.rayleigh_integral(data, data.var()) 21 | 22 | plt.hist(data, bins=15, cumulative=True, density=True, label='Распределение реальных данных') 23 | plt.plot(data, gaussian_prob, color='red', label='Функция Гаусса') 24 | plt.plot(data, rayleigh_prob, label='Функция Рэлея') 25 | 26 | plt.legend(title='Функции распределения') 27 | 28 | ax = plt.gca() 29 | ax.set_xlabel("Значения", fontsize=10, color='black', labelpad=10) 30 | ax.set_ylabel("Распределение P(X" 126 | "Контекст и сгенерированный текст - это входные данные для анализа.
") 127 | context_text = gr.Text(label='Context', 128 | info="**Контекст** это ваш исходный текст, который служит отправной точкой для генерации ответа " 129 | "пример: 'у кошки грипп и аллергия на антибиотбиотики вопрос: чем лечить кошку? ответ:'", 130 | lines=1, placeholder="Enter context here...") 131 | 132 | output_text = gr.Text(label='Generated text', 133 | info="**Сгенерированный текст** - это результат работы системы, ответ на ваш контекст " 134 | " 'лечите ее уколами'", 135 | lines=1, placeholder="Enter generated text here...") 136 | 137 | with gr.Row(): 138 | texts_load_button = gr.Button("Set texts") 139 | texts_set_checkbox = gr.Checkbox(label='Texts are set', interactive=False, 140 | info="Если тексты установлены, можно переходить к шагу Clusters.") 141 | 142 | with gr.TabItem("Clusters"): 143 | with gr.Tabs(): 144 | with gr.TabItem("Load clusters from file"): 145 | gr.Markdown(""" 146 | Загрузите файл с описанием кластеров. Файл должен быть в формате JSON и содержать список кластеров, где каждый кластер описывается следующими ключами: 147 | - `name`: Название кластера (строка). 148 | - `centroid`: Центроид кластера, представленный списком строк (например, список терминов, характеризующих кластер). 149 | - `top_k`: Количество топовых элементов кластера (целое число). 150 | 151 | 152 | пример файла: example_data/clusters.json 153 | 154 | Пример структуры файла: 155 | ```json 156 | [ 157 | { 158 | "name": "Кластер 1", 159 | "centroid": ["термин1", "термин2", "термин3"], 160 | "top_k": 5 161 | }, 162 | { 163 | "name": "Кластер 2", 164 | "centroid": ["термин4", "термин5", "термин6"], 165 | "top_k": 3 166 | } 167 | ] 168 | ``` 169 | 170 | Этот файл используется для анализа и визуализации влияния различных групп слов (кластеров) на сгенерированный текст. Загрузка подходящего файла позволит провести более глубокий анализ и понять, какие тематические группы наиболее важны в контексте генерации текста. 171 | """) 172 | with gr.Column(): 173 | clusters_file = gr.File(label='Clusters\' file') 174 | with gr.Row(): 175 | clusters_load_from_file_button = gr.Button("Load clusters") 176 | with gr.TabItem("Set clusters manually"): 177 | with gr.Column(): 178 | set_clusters_table = gr.Dataframe(label='Clusters\' table', 179 | headers=['name', 'centroid', 'top_k'], 180 | # max_rows=None, 181 | height=None, 182 | # overflow_row_behaviour='paginate', 183 | wrap=False, 184 | interactive=True) 185 | with gr.Row(): 186 | set_clusters_from_dataframe_button = gr.Button("Set clusters") 187 | 188 | clusters_set_checkbox = gr.Checkbox( 189 | label='Clusters are set Если кластеры и NLP модель установлены, можно переходить к шагу LLM model', 190 | interactive=False) 191 | 192 | cluster_table = gr.Dataframe(label='Clusters', 193 | headers=['name', 'centroid', 'top_k'], 194 | # max_rows=None, 195 | height=None, 196 | # overflow_row_behaviour='paginate', 197 | wrap=False, 198 | interactive=False) 199 | 200 | clusters_file_path_text = gr.Text(label='Save clusters to file') 201 | with gr.Row(): 202 | clusters_save_button = gr.Button("Save clusters") 203 | clusters_save_checkbox = gr.Checkbox(label='Clusters are saved', interactive=False) 204 | 205 | nlp_model_url = gr.Text(label='Model url', 206 | info="Введите URL для загрузки предварительно обученной NLP модели. Модель должна быть в формате, совместимом с библиотекой gensim, например," 207 | " Word2Vec, FastText или любой другой векторной модели слов. 
например http://vectors.nlpl.eu/repository/20/180.zip", 208 | placeholder="http://vectors.nlpl.eu/repository/20/180.zip", 209 | lines=1) 210 | with gr.Row(): 211 | load_nlp_model_button = gr.Button("Load model") 212 | nlp_model_set_checkbox = gr.Checkbox(label='NLP model loaded', interactive=False) 213 | with gr.TabItem("LLM interpretation model"): 214 | nn_model_name_or_path = gr.Text(label='Model name or path', 215 | info="Введите название модели или путь к ней для использования в качестве модели интерпретации." 216 | " Это должна быть модель на основе GPT или другой современной трансформерной модели, например sberbank-ai/rugpt3small_based_on_gpt2", 217 | placeholder="sberbank-ai/rugpt3small_based_on_gpt2", 218 | lines=1) 219 | with gr.Row(): 220 | load_nn_model_button = gr.Button("Load model") 221 | nn_model_set_checkbox = gr.Checkbox( 222 | label='NN model loaded, если модель загружена можно переходить к шагу Results', 223 | interactive=False) 224 | with gr.TabItem("Results"): 225 | with gr.Row(): 226 | with gr.Column(): 227 | result_texts_set_checkbox = gr.Checkbox(label='Texts are set', interactive=False) 228 | result_clusters_set_checkbox = gr.Checkbox(label='Clusters are set', interactive=False) 229 | result_nlp_model_set_checkbox = gr.Checkbox(label='NLP model loaded', interactive=False) 230 | result_nn_model_set_checkbox = gr.Checkbox(label='NN model loaded', interactive=False) 231 | launch_button = gr.Button("Launch") 232 | with gr.Column(): 233 | with gr.Tabs(): 234 | with gr.TabItem("Word importance"): 235 | # word_importance_image = gr.Image().style(height=600) 236 | word_importance_image = gr.Image(height=600) 237 | with gr.TabItem("Word importance normalized"): 238 | # word_importance_norm_image = gr.Image().style(height=600) 239 | word_importance_norm_image = gr.Image(height=600) 240 | with gr.TabItem("Cluster importance"): 241 | # cluster_importance_image = gr.Image().style(height=600) 242 | cluster_importance_image = gr.Image(height=600) 243 | with gr.TabItem("Cluster importance grouped"): 244 | # cluster_importance_norm_image = gr.Image().style(height=600) 245 | cluster_importance_norm_image = gr.Image(height=600) 246 | with gr.TabItem("Chatbot"): 247 | chat = gr.Chatbot(label="Chatbot", layout="bubble") 248 | msg = gr.Text(label="Message", interactive=True) 249 | send_message = gr.Button("Send", interactive=True) 250 | clear = gr.ClearButton([msg, chat]) 251 | # chat.change(fn=chatbot_function, inputs=msg, outputs=chat) 252 | 253 | texts_load_button.click(self.load_context_and_generated_text_, 254 | inputs=[context_text, output_text], 255 | outputs=[texts_set_checkbox, result_texts_set_checkbox]) 256 | 257 | clusters_load_from_file_button.click(self.load_clusters_from_file_, 258 | inputs=[clusters_file], 259 | outputs=[clusters_set_checkbox, result_clusters_set_checkbox, 260 | cluster_table, set_clusters_table]) 261 | set_clusters_from_dataframe_button.click(self.set_clusters_from_dataframe_, 262 | inputs=[set_clusters_table], 263 | outputs=[clusters_set_checkbox, result_clusters_set_checkbox, 264 | cluster_table]) 265 | clusters_save_button.click(self.save_new_clusters_to_file_, 266 | inputs=[cluster_table, clusters_file_path_text], 267 | outputs=[clusters_save_checkbox]) 268 | 269 | load_nn_model_button.click(self.load_nn_model_, 270 | inputs=[nn_model_name_or_path], 271 | outputs=[nn_model_set_checkbox, result_nn_model_set_checkbox]) 272 | 273 | load_nlp_model_button.click(self.load_nlp_model_, 274 | inputs=[nlp_model_url], 275 | outputs=[nlp_model_set_checkbox, 
result_nlp_model_set_checkbox]) 276 | 277 | launch_button.click(self.show_results, inputs=[], outputs=[word_importance_image, 278 | word_importance_norm_image, 279 | cluster_importance_image, 280 | cluster_importance_norm_image]) 281 | 282 | send_message.click(self.respond_, inputs=[msg, chat], outputs=[msg, chat]) 283 | 284 | def __del__(self): 285 | pass 286 | 287 | # FUNCTIONALITY: 288 | 289 | def launch(self): 290 | self.demo_.launch(share=True, debug=False, server_name="127.0.0.1", inbrowser=True) 291 | 292 | def show_results(self): 293 | self.explainer_ = interp.ExplainerGPT2(gpt_model=self.nn_model_, nlp_model=self.nlp_model_) 294 | expl_data = self.explainer_.interpret(input_texts=self.context_, 295 | generated_texts=self.generated_text_, 296 | clusters_description=self.clusters_, 297 | batch_size=50, 298 | steps=34, 299 | # max_new_tokens=19 300 | ) 301 | 302 | # Результат интерпретации 303 | imp_df_cl = expl_data.cluster_imp_aggr_df 304 | cl_desc = interp_cl(imp_df_cl) 305 | 306 | clean = [clean_string(cl_data) for cl_data in cl_desc] 307 | vects_x = self.sbert_.encode(clean) 308 | m = vects_x.mean(axis=0) 309 | s = vects_x.std(axis=0) 310 | try: 311 | knn_vects_x = (vects_x - m) / s 312 | knn = KNeighborsClassifier(metric=cos_dist) 313 | knn.fit(knn_vects_x, cl_desc) 314 | 315 | self.interp_bot_ = PromptBot(knn, self.sbert_, self.fred_, cl_desc, device='cpu') 316 | except: 317 | print("Err") 318 | self.interp_bot_ = None 319 | word_importance_plt = df_to_heatmap_plot(expl_data.word_imp_df, title="Карта важности слов") 320 | word_importance_norm_plt = df_to_heatmap_plot(expl_data.word_imp_norm_df, 321 | title="Карта важности слов, нормированная") 322 | 323 | cluster_importance_plt = df_to_heatmap_plot(expl_data.cluster_imp_df, title="Карта важности кластеров") 324 | cluster_importance_norm_plt = df_to_heatmap_plot(expl_data.cluster_imp_aggr_df, 325 | title="Карта важности кластеров, группированная") 326 | 327 | return word_importance_plt, word_importance_norm_plt, cluster_importance_plt, cluster_importance_norm_plt 328 | 329 | # PRIVATE FUNCTIONS: 330 | 331 | def respond_(self, message, chat_history): 332 | ans = self.interp_bot_.get_answers(message, top_k=3) 333 | bot_reply = 'Кластер' + ans.split('.')[0].split('Кластер')[1] 334 | chat_history.append((message, bot_reply)) 335 | 336 | return "", chat_history 337 | 338 | def load_context_and_generated_text_(self, context, generated_text): 339 | self.context_ = context 340 | self.generated_text_ = generated_text 341 | 342 | return True, True 343 | 344 | def load_clusters_from_file_(self, jsonfile: tempfile._TemporaryFileWrapper): 345 | with open(jsonfile.name, 'r') as fp: 346 | self.clusters_ = json.load(fp) 347 | df = make_dataframe_from_clusters(self.clusters_) 348 | 349 | return True, True, df, df 350 | 351 | def set_clusters_from_dataframe_(self, df): 352 | self.clusters_ = make_clusters_from_dataframe(df) 353 | 354 | return True, True, df 355 | 356 | def save_new_clusters_to_file_(self, df, filename): 357 | clusters = make_clusters_from_dataframe(df) 358 | with open(filename, 'w') as fp: 359 | json.dump(clusters, fp) 360 | 361 | return True 362 | 363 | def load_nlp_model_(self, url): 364 | self.npl_model_url_ = url 365 | nlp_model_path = DownloadManager.load_zip(url) 366 | self.nlp_model_ = gensim.models.KeyedVectors.load_word2vec_format(nlp_model_path, binary=True) 367 | 368 | return True, True 369 | 370 | def load_nn_model_(self, model_name_or_path): 371 | path = os.path.normpath(model_name_or_path) 372 | path_list = 
path.split(os.sep) 373 | self.nn_model_name_ = path_list[-1] 374 | 375 | self.nn_model_ = load_model(model=model_name_or_path, 376 | attribution_method="integrated_gradients") 377 | 378 | return True, True 379 | -------------------------------------------------------------------------------- /explainitall/metrics/CheckingForHallucinations.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import numpy as np 3 | from nltk.tokenize import sent_tokenize 4 | from sentence_transformers import SentenceTransformer, util 5 | 6 | 7 | def sim_cosine(embedding_1, embedding_2): 8 | s = util.pytorch_cos_sim(embedding_1, embedding_2) 9 | return s.detach().numpy() 10 | 11 | def sim_euclidean(embedding_1, embedding_2): 12 | s = np.linalg.norm(embedding_1 - embedding_2, axis=1) 13 | return s 14 | 15 | def sim_manhattan(embedding_1, embedding_2): 16 | s = np.sum(np.abs(embedding_1 - embedding_2), axis=1) 17 | return s 18 | 19 | def sim_cross_encoder(encoder, sentences1, sentences2): 20 | scores = encoder.predict([sentences1, sentences2]) 21 | return scores 22 | 23 | class RAGHallucinationsChecker: 24 | def __init__(self, sbert_model: SentenceTransformer, cross_encoder=None, language = 'russian'): 25 | nltk.download('punkt') 26 | self.sbert_model = sbert_model 27 | self.cross_encoder = cross_encoder 28 | self.language = language 29 | 30 | def load_doc(self, text, block_size=3): 31 | seqs = sent_tokenize(text, language=self.language) 32 | count_block = max([1, len(seqs) // block_size]) 33 | seqs = [list(x) for x in np.array_split(seqs, count_block)] 34 | snippets = [' '.join(x) for x in seqs] 35 | 36 | return snippets 37 | 38 | def get_support_seq(self, doc_snp, ans, prob=0.6, top_k=1, sim_metric='cosine'): 39 | 40 | docs = [] 41 | sn_ans = self.load_doc(ans, 1) 42 | 43 | for d in doc_snp: 44 | docs += self.load_doc(d, 1) 45 | 46 | top_k = min(top_k, len(docs)) 47 | 48 | ans_v = self.sbert_model.encode(sn_ans) 49 | doc_v = self.sbert_model.encode(docs) 50 | 51 | if sim_metric == 'cosine': 52 | matrix_ = sim_cosine(doc_v, ans_v) 53 | elif sim_metric == 'euclidean': 54 | matrix_ = sim_euclidean(doc_v, ans_v) 55 | elif sim_metric == 'manhattan': 56 | matrix_ = sim_manhattan(doc_v, ans_v) 57 | elif sim_metric == 'cross_encoder' and self.cross_encoder is not None: 58 | matrix_ = sim_cross_encoder(self.cross_encoder, docs, sn_ans) 59 | else: 60 | raise ValueError("Unsupported similarity metric or cross-encoder is not provided.") 61 | 62 | res = [] 63 | for i in range(matrix_.shape[1]): 64 | slice_ = matrix_[:, i] 65 | top_indexes = np.argpartition(slice_, -top_k)[-top_k:] 66 | top_probs = matrix_[top_indexes, i] 67 | top_indexes[top_probs < prob] = -1 68 | 69 | reference_texts = [] 70 | indexes = [] 71 | for j, d in enumerate(top_indexes): 72 | if d >= 0: 73 | reference_texts += [docs[d]] 74 | indexes += [d] 75 | 76 | if len(indexes) > 0: 77 | res.append({'answer': sn_ans[i], 'reference_texts': reference_texts, 'indexes': indexes}) 78 | 79 | return res 80 | 81 | def get_conf(self, doc_snp, ans, prob=0.6, sim_metric='cosine'): 82 | answer_a = self.get_support_seq(ans=ans, doc_snp=doc_snp, prob=prob, sim_metric=sim_metric) 83 | len_all = len(ans) 84 | len_with_out_h = len(answer_a) - 1 # Учитываем пробелы 85 | 86 | for s in answer_a: 87 | len_with_out_h += len(s['answer']) 88 | 89 | return len_with_out_h / len_all 90 | 91 | def get_hallucinations_prob(self, doc_snp, ans, prob=0.6, sim_metric='cosine'): 92 | return 1 - self.get_conf(doc_snp, ans, prob, sim_metric) 
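# Usage sketch (illustrative, not part of the API): any sentence-transformers
# checkpoint works; the model name below is an assumption, not pinned by this repo.
#
#     from sentence_transformers import SentenceTransformer
#     sbert = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
#     checker = RAGHallucinationsChecker(sbert)
#     snippets = ['У кошки грипп, ветеринар назначил лечение уколами.']
#     answer = 'Лечите кошку уколами.'
#     print(checker.get_hallucinations_prob(snippets, answer, prob=0.6))
#
# Caveat: 'euclidean' and 'manhattan' return distances (smaller means more similar),
# while get_support_seq keeps the largest scores and drops values below `prob`,
# so those two metrics only make sense with their sign inverted; prefer
# 'cosine' or 'cross_encoder' here.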
93 | 
94 | 
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/Metrics.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from transformers import PreTrainedModel, PreTrainedTokenizer
3 | from typing import List, Dict
4 | 
5 | import torch
6 | 
7 | from explainitall.metrics.RougeAndPPL.helpers import get_all_words_from_text
8 | from explainitall.metrics.RougeAndPPL.rouge_L import rouge_L
9 | from explainitall.metrics.RougeAndPPL.rouge_N import rouge_N
10 | 
11 | 
12 | class Metric(ABC):
13 |     @staticmethod
14 |     def preprocess(contexts, references, candidates, tokenizer):
15 |         raise NotImplementedError
16 | 
17 |     @abstractmethod
18 |     def calculate(self, references_encodings, candidates_encodings):
19 |         raise NotImplementedError
20 | 
21 | 
22 | class Metric_rouge(Metric):
23 | 
24 |     def __init__(self, n):
25 |         self.n = n
26 | 
27 |     @staticmethod
28 |     def preprocess(contexts, references, candidates, tokenizer):
29 |         res = {'references': [], 'candidates': []}
30 |         for i, context in enumerate(contexts):
31 |             res['references'].append(get_all_words_from_text(references[i][len(contexts[i]):]))
32 |             res['candidates'].append(get_all_words_from_text(candidates[i][len(contexts[i]):]))
33 |         return res
34 | 
35 | 
36 | class MetricRougeL(Metric_rouge):
37 | 
38 |     def calculate(self, references_encodings, candidates_encodings):
39 |         res = [rouge_L(reference, candidate) for reference, candidate in
40 |                zip(references_encodings, candidates_encodings)]
41 |         return res
42 | 
43 | 
44 | class MetricRougeN(Metric_rouge):
45 | 
46 |     def calculate(self, references_encodings, candidates_encodings):
47 |         res = [rouge_N(reference, candidate, self.n) for reference, candidate in
48 |                zip(references_encodings, candidates_encodings)]
49 |         return res
50 | 
51 | 
52 | class MetricStandard(Metric):
53 | 
54 |     def __init__(self, metric_function):
55 |         self.metric_function = metric_function
56 | 
57 |     def calculate(self, references_encodings, candidates_encodings):
58 |         res = [self.metric_function(reference, candidate) for reference, candidate in
59 |                zip(references_encodings, candidates_encodings)]
60 |         return res
61 | 
62 | 
63 | class Metric_ppl(Metric):
64 | 
65 |     def __init__(self, model: PreTrainedModel, stride: int):
66 | 
67 |         self.model = model
68 |         self.stride = stride
69 | 
70 |     @staticmethod
71 |     def preprocess(contexts: List[str], references: List[str], candidates: List[str], tokenizer: PreTrainedTokenizer) -> Dict[str, List[List[int]]]:
72 |         """
73 |         Preprocess the data
74 |         """
75 |         tokenized_data = {'references': [], 'candidates': []}
76 | 
77 |         for context, reference in zip(contexts, references):
78 |             encoded_reference = tokenizer(reference)
79 |             tokenized_data['references'].append(encoded_reference.input_ids)
80 | 
81 |         return tokenized_data
82 | 
83 |     def calculate(self, reference_encodings: List[List[int]], candidate_encodings: List[List[int]]) -> List[Dict[str, float]]:
84 |         """
85 |         Compute perplexity for a list of tokenized texts
86 |         """
87 |         perplexities = [self._calculate_perplexity(encodings) for encodings in reference_encodings]
88 |         return perplexities
89 | 
90 |     def _calculate_perplexity(self, encodings: List[int]) -> Dict[str, float]:
91 |         """
92 |         Compute perplexity for a single tokenized text with a sliding window
93 |         """
94 |         max_length = self.model.config.n_positions
95 |         sequence_length = len(encodings)
96 | 
97 |         neg_log_likelihoods = []
98 |         prev_end_loc = 0
99 | 
100 |         for start_loc in range(0, sequence_length, self.stride):
101 |             end_loc = min(start_loc + max_length, sequence_length)
102 |             target_length = end_loc - prev_end_loc
103 | 
104 |             input_ids = torch.tensor(encodings[start_loc:end_loc], device=self.model.device).unsqueeze(0)
105 |             target_ids = input_ids.clone()
106 |             target_ids[:, :-target_length] = -100  # -100 is the ignore index of the HF causal-LM loss
107 | 
108 |             with torch.no_grad():
109 |                 outputs = self.model(input_ids, labels=target_ids)
110 |                 neg_log_likelihood = outputs.loss * target_length
111 | 
112 |             neg_log_likelihoods.append(neg_log_likelihood)
113 |             prev_end_loc = end_loc
114 | 
115 |             if end_loc == sequence_length:
116 |                 break
117 | 
118 |         total_neg_log_likelihood = torch.stack(neg_log_likelihoods).sum()
119 |         perplexity = torch.exp(total_neg_log_likelihood / sequence_length)
120 |         return {'value': perplexity.item()}
121 | 
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/Metrics_calculator.py:
--------------------------------------------------------------------------------
1 | class Metrics_calculator:
2 |     metrics_ = None
3 |     preprocessing_functions_ = None
4 |     tokenizer_ = None
5 | 
6 |     def __init__(self, tokenizer):
7 |         self.metrics_ = {}
8 |         self.preprocessing_functions_ = {}
9 |         self.tokenizer_ = tokenizer
10 | 
11 |     def __del__(self):
12 |         del self.metrics_
13 |         del self.tokenizer_
14 | 
15 |     def calculate(self, contexts, references, candidates):
16 |         res = {}
17 | 
18 |         assert len(contexts) == len(references) == len(candidates)
19 | 
20 |         preprocessed_sentences = {}
21 |         for preproc in self.preprocessing_functions_:
22 |             preprocessed = preproc(contexts, references, candidates, self.tokenizer_)
23 |             preprocessed_sentences[frozenset(self.preprocessing_functions_[preproc])] = preprocessed
24 | 
25 |         for metric_name in self.metrics_:
26 |             for metric_group in preprocessed_sentences:
27 |                 if metric_name in metric_group:
28 |                     prep = preprocessed_sentences[metric_group]
29 |                     prep_references = prep['references']
30 |                     prep_candidates = prep['candidates']
31 |                     res[metric_name] = self.metrics_[metric_name].calculate(prep_references, prep_candidates)
32 |                     break
33 |         return res
34 | 
35 |     def add_metric(self, name, metric):
36 |         if name in self.metrics_:
37 |             raise Exception('Metric name ' + name + ' already exists!')
38 |         self.metrics_[name] = metric
39 | 
40 |         preproc = metric.preprocess
41 |         if preproc not in self.preprocessing_functions_:
42 |             self.preprocessing_functions_[preproc] = set()
43 |         self.preprocessing_functions_[preproc].add(name)
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/create_database.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import create_engine, Table, Column, Float, Integer, String, MetaData
2 | 
3 | 
4 | ENGINE = create_engine('sqlite:///database.sqlite', echo=True)
5 | 
6 | 
7 | metadata = MetaData()
8 | 
9 | 
10 | data_table = Table('data', metadata,
11 |                    Column('id', Integer, primary_key=True, autoincrement=True),
12 |                    Column('model_name', String),
13 |                    Column('timestamp', Integer),
14 |                    Column('dataset_name', String),
15 |                    Column('dataset_version', Integer),
16 |                    Column('PPL', Float),
17 |                    Column('R3', Float),
18 |                    Column('R5', Float),
19 |                    Column('R-L', Float))
20 | 
21 | 
22 | metadata.create_all(ENGINE)
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/helpers.py:
-------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | from datetime import datetime 4 | 5 | import pandas as pd 6 | import sqlalchemy 7 | 8 | 9 | def calculate_average_metric_values(calculated_metrics): 10 | res = {} 11 | 12 | for metric in calculated_metrics: 13 | sub_metric_average_values = {} 14 | for text_evaluation_result in calculated_metrics[metric]: 15 | for sub_metric in text_evaluation_result: 16 | if sub_metric not in sub_metric_average_values: 17 | sub_metric_average_values[sub_metric] = 0.0 18 | sub_metric_average_values[sub_metric] += text_evaluation_result[sub_metric] 19 | for sub in sub_metric_average_values: 20 | sub_metric_average_values[sub] = sub_metric_average_values[sub] / len(calculated_metrics[metric]) 21 | if 'f1' in sub_metric_average_values: 22 | res[metric] = sub_metric_average_values['f1'] 23 | elif 'value' in sub_metric_average_values: 24 | res[metric] = sub_metric_average_values['value'] 25 | else: 26 | res[metric] = 0.0 27 | 28 | return res 29 | 30 | 31 | def fbeta_score(precision, recall, beta=1): 32 | if precision == 0.0 and recall == 0.0: 33 | return 0.0 34 | 35 | if beta == 1: 36 | return 2 * precision * recall / (precision + recall) 37 | return (1 + beta * beta) * precision * recall / (beta * beta * precision + recall) 38 | 39 | 40 | def generate_candidates(model, tokenizer, sentences, max_length=128, max_new_tokens=100): 41 | candidates = [] 42 | for sentence in sentences: 43 | encoded_input = tokenizer(sentence, truncation=True, max_length=max_length, return_tensors='pt') 44 | encoded_input = encoded_input.to(model.device) 45 | res = model.generate(**encoded_input, max_new_tokens=max_new_tokens) 46 | candidate = tokenizer.decode(res[0], skip_special_tokens=True) 47 | candidates.append(candidate) 48 | return candidates 49 | 50 | 51 | def get_all_words_from_text(text): 52 | words = re.findall(r"[\w]+", text) 53 | return words 54 | 55 | 56 | def get_max_dataset_version(dataset_name, conn, data_table): 57 | statement = sqlalchemy.select(sqlalchemy.func.max(data_table.c.dataset_version)).where( 58 | data_table.c.dataset_name == dataset_name) 59 | 60 | records = [] 61 | 62 | for row in conn.execute(statement): 63 | records.append(row) 64 | 65 | res = records[0][0] 66 | 67 | if res is None: 68 | return -1 69 | 70 | return records[0][0] 71 | 72 | 73 | def get_records_from_database(data_table, conn, specific_column_value=None): 74 | records = [] 75 | 76 | statement = data_table.select() 77 | 78 | if specific_column_value is not None: 79 | for k in specific_column_value: 80 | col = sqlalchemy.sql.column(k) 81 | statement = statement.where(col == specific_column_value[k]) 82 | 83 | for row in conn.execute(statement): 84 | records.append(row) 85 | 86 | return records 87 | 88 | 89 | def insert_new_record(data_table, conn, model_name, dataset_name, dataset_version, metric_values): 90 | record = {'model_name': model_name, 91 | 'timestamp': int(time.time()), 92 | 'dataset_name': dataset_name, 93 | 'dataset_version': dataset_version} 94 | for m in metric_values: 95 | record[m] = metric_values[m] 96 | 97 | statement = data_table.insert().values(**record) 98 | conn.execute(statement) 99 | conn.commit() 100 | 101 | 102 | def make_dataframe_from_history_records(records): 103 | columns = ['model_name', 'date', 'dataset_name', 'dataset_version', 'PPL', 'R3', 'R5', 'R-L'] 104 | res_records = [] 105 | 106 | for rec in records: 107 | r = list(rec[1:]) 108 | d = datetime.fromtimestamp(r[1]) 109 | r[1] = str(d.day) + 
'.' + str(d.month) + '.' + str(d.year)
110 |         for i in range(len(columns[4:])):
111 |             r[i + 4] = round(r[i + 4], 2)
112 |         res_records.append(r)
113 | 
114 |     df = pd.DataFrame(res_records, columns=columns)
115 | 
116 |     return df
117 | 
118 | 
119 | def split_text_by_whitespaces(text):
120 |     text = text.replace('\n', ' ')
121 |     text = re.sub(r'\W+', ' ', text)
122 |     text = re.sub(r'\s+', ' ', text)
123 |     tokens = re.split(r"\s", text)
124 |     return tokens
125 | 
126 | 
127 | def words_n_gramm(text, n_gramm=3):
128 |     tx = text.replace('\n', ' ')
129 |     tx = re.sub(r' +', ' ', tx)
130 |     w = tx.split(' ')
131 | 
132 |     ng = []
133 | 
134 |     for i, word in enumerate(w):
135 |         ng.append(' '.join(w[i:i + n_gramm]))
136 | 
137 |     return list(set(ng))
138 | 
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/metric_calculation_interface.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import pandas as pd
3 | from transformers import AutoTokenizer, AutoModelForCausalLM  # only these two are actually used below
4 | 
5 | from explainitall.metrics.RougeAndPPL.Metrics import Metric_ppl, MetricRougeL, MetricRougeN
6 | from explainitall.metrics.RougeAndPPL.Metrics_calculator import Metrics_calculator
7 | from explainitall.metrics.RougeAndPPL.create_database import data_table, ENGINE
8 | from explainitall.metrics.RougeAndPPL.helpers import (
9 |     get_max_dataset_version, generate_candidates, calculate_average_metric_values,
10 |     insert_new_record, get_records_from_database, make_dataframe_from_history_records)
11 | 
12 | 
13 | class MetricCalculationInterface:
14 |     demo_ = None
15 | 
16 |     model_checkbox_ = None
17 |     metrics_checkbox_ = None
18 | 
19 |     calculator_ = None
20 | 
21 |     model_name_ = None
22 |     model_ = None
23 |     tokenizer_ = None
24 |     model_successfully_loaded_ = None
25 | 
26 |     dataset_name_ = None
27 |     dataset_version_ = None
28 |     contexts_ = None
29 |     references_ = None
30 |     dataset_successfully_loaded_ = None
31 | 
32 |     conn_ = None
33 | 
34 |     def __init__(self):
35 |         self.conn_ = ENGINE.connect()
36 | 
37 |         self.demo_ = gr.Blocks()
38 | 
39 |         with self.demo_:
40 |             with gr.Tabs():
41 |                 with gr.TabItem("Load"):
42 |                     gr.Markdown("""
43 |                     Upload a CSV file with the dataset that will be used to analyse and evaluate the quality of the text generation model
44 |                     The dataset must contain at least two columns:
45 |                     - `context`: the textual context or input from which the model should generate text
46 |                     - `reference`: the reference text, i.e. the expected model answer for the given context
47 |                     This dataset makes it possible to evaluate how well the model generates text that matches the given context and the reference answers
48 |                     An example file: example_data/metrix.csv
49 |                     """)
50 |                     with gr.Column():
51 |                         with gr.Row():
52 |                             with gr.Column():
53 |                                 with gr.Row():
54 |                                     with gr.Column():
55 |                                         dataset_file = gr.File(label='Dataset (CSV)')
56 |                                         dataset_title = gr.Text(label='Dataset title',
57 |                                                                 info="This field is required: without a dataset title the run will not start.",
58 |                                                                 placeholder="Enter a dataset title")
59 | 
60 |                                         dataset_visualization = gr.Dataframe(label='Dataset',
61 |                                                                              headers=['context', 'reference'],
62 |                                                                              wrap=False,
63 |                                                                              height=500)
64 |                                 with gr.Row():
65 |                                     dataset_load_button = gr.Button("Load dataset")
66 |                                     dataset_checkbox = gr.Checkbox(label='dataset loaded', interactive=False)
67 | 
68 |                         with gr.Row():
69 |                             with gr.Column():
70 |                                 model_name_or_path = gr.Text(label='Model name or path',
71 |                                                              placeholder="distilgpt2",
72 |                                                              info="Enter the name of a pre-trained model (e.g. distilgpt2, gpt2 or sberbank-ai/rugpt3small_based_on_gpt2) or a path to your own model")
73 | 
74 |                         with gr.Row():
75 |                             model_load_button = gr.Button("Load model")
76 |                             model_checkbox = gr.Checkbox(label='model loaded', interactive=False)
77 |                         with gr.Row():
78 |                             launch_button = gr.Button("Launch")
79 |                             metrics_checkbox = gr.Checkbox(label='metrics calculated, open the Result tab and refresh the data', interactive=False)
80 | 
81 |                 with gr.TabItem("Result"):
82 |                     with gr.Row():
83 |                         filter_field_dropdown = gr.Dropdown(["None", "model_name", "dataset_name"],
84 |                                                             label='Filter by field')
85 |                         filter_value_text = gr.Text(label='Filter value')
86 |                         refresh_button = gr.Button("Refresh")
87 |                     history_table = gr.Dataframe(
88 |                         headers=['model_name', 'date', 'dataset_name', 'dataset_version', 'PPL', 'R3', 'R5', 'R-L'],
89 |                         label='All Data')
90 | 
91 |             dataset_load_button.click(self.load_dataset_,
92 |                                       inputs=[dataset_file, dataset_title],
93 |                                       outputs=[dataset_visualization, dataset_checkbox])
94 |             model_load_button.click(self.load_model_,
95 |                                     inputs=[model_name_or_path],
96 |                                     outputs=[model_checkbox])
97 |             launch_button.click(self.calculate_metrics_,
98 |                                 inputs=None,
99 |                                 outputs=[metrics_checkbox])
100 |             refresh_button.click(self.refresh_history_,
101 |                                  inputs=[filter_field_dropdown, filter_value_text],
102 |                                  outputs=[history_table])
103 | 
104 |         self.model_checkbox_ = model_checkbox
105 |         self.dataset_checkbox_ = dataset_checkbox
106 |         self.metrics_checkbox_ = metrics_checkbox
107 | 
108 |     def launch(self):
109 |         self.demo_.launch(share=True, debug=False, server_name="127.0.0.1", inbrowser=True)
110 | 
111 |     def load_model_(self, model_name_or_path):
112 |         if not model_name_or_path:
113 |             print("Model name or path is required.")
114 |             return False  # bail out instead of crashing in from_pretrained below
115 |         self.model_successfully_loaded_ = False
116 | 
117 |         self.tokenizer_ = AutoTokenizer.from_pretrained(model_name_or_path)
118 |         if self.tokenizer_.pad_token is None:
119 |             self.tokenizer_.pad_token = self.tokenizer_.eos_token
120 |         self.model_ = AutoModelForCausalLM.from_pretrained(model_name_or_path)
121 | 
122 |         self.model_name_ = str(model_name_or_path)
123 | 
124 |         self.calculator_ = Metrics_calculator(self.tokenizer_)
125 | 
126 |         self.calculator_.add_metric('PPL', Metric_ppl(self.model_, stride=512))
127 | 
128 |         self.calculator_.add_metric('R3', MetricRougeN(3))
129 |         self.calculator_.add_metric('R5', MetricRougeN(5))
130 | 
131 |         self.calculator_.add_metric('R-L', MetricRougeL(3))  # the n argument is unused by Rouge-L
132 | 
133 |         self.model_successfully_loaded_ = True
134 |         return True
135 | 
136 |     def load_dataset_(self, csvfile, title):
137 |         self.dataset_successfully_loaded_ = False
138 | 
139 |         print("csvfile", csvfile)
140 |         print("title", title)
141 | 
142 |         if csvfile is None or title == '':
143 |             return pd.DataFrame(), False
144 |         dataframe = pd.read_csv(csvfile.name, delimiter=',', encoding='utf-8')
145 | 
146 |         self.dataset_name_ = title
147 |         self.dataset_version_ = get_max_dataset_version(self.dataset_name_, self.conn_, data_table) + 1
148 | 
149 |         self.contexts_ = list(dataframe['context'].values)
150 |         self.references_ = list(dataframe['reference'].values)
151 | 
152 |         self.dataset_successfully_loaded_ = True
153 |         return dataframe.head(10), True
154 | 
155 |     def calculate_metrics_(self):
156 |         if self.dataset_successfully_loaded_ is False or self.model_successfully_loaded_ is False:
157 |             return False
158 | 
159 |         model = self.model_
160 |         candidates = generate_candidates(model,
161 |                                          self.tokenizer_,
162 |                                          self.contexts_,
163 |                                          model.config.n_positions,
164 |                                          max_new_tokens=10)
165 | 
166 |         res = self.calculator_.calculate(self.contexts_, self.references_, candidates)
167 |         metric_values = calculate_average_metric_values(res)
168 | 
169 |         insert_new_record(data_table=data_table,
170 |                           conn=self.conn_,
171 |                           model_name=self.model_name_,
172 |                           dataset_name=self.dataset_name_,
173 |                           dataset_version=self.dataset_version_,
174 |                           metric_values=metric_values)
175 | 
176 | 
177 |         return True
178 | 
179 |     def refresh_history_(self, filter_field, filter_value):
180 |         specific_column_value = None
181 |         if filter_field != [] and filter_field is not None and filter_field != 'None' and filter_field != '':
182 |             specific_column_value = {filter_field: filter_value}
183 | 
184 |         res = get_records_from_database(data_table, self.conn_, specific_column_value)
185 |         df = make_dataframe_from_history_records(res)
186 | 
187 |         return df
188 | 
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/rouge_L.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Optional
2 | 
3 | def find_lowest_position_higher_then_current(pos_list: List[int], current_position: Optional[int]) -> Optional[int]:
4 |     """Finds the lowest position in the list that is greater than the current one"""
5 |     for pos in pos_list:
6 |         if current_position is None or pos > current_position:
7 |             return pos
8 |     return None
9 | 
10 | def get_element_positions(sequence: List[int], element: int) -> List[int]:
11 |     """Returns the positions at which the element occurs in the sequence"""
12 |     return [i for i, v in enumerate(sequence) if v == element]
13 | 
14 | def is_seq1_better_then_seq2(seq1: List[int], seq2: List[int]) -> bool:
15 |     """Determines whether the first sequence is better than the second"""
16 |     if seq1[1] < seq2[1]:
17 |         return False
18 |     if seq1[0] < seq2[0]:
19 |         return True
20 |     return False
21 | 
22 | def unite_sequencies(existing_sequencies: List[List[int]], new_sequencies: List[List[int]]) -> List[List[int]]:
23 |     """Merges the existing and new sequences, keeping only the best ones"""
24 |     i = len(existing_sequencies) - 1
25 |     while i >= 0:
26 |         g = len(new_sequencies) - 1
27 |         while g >= 0:
28 |             if is_seq1_better_then_seq2(existing_sequencies[i], new_sequencies[g]):
29 |                 del new_sequencies[g]
30 |             elif is_seq1_better_then_seq2(new_sequencies[g], existing_sequencies[i]):
31 |                 del existing_sequencies[i]
32 |                 break
33 |             g -= 1
34 |         i -= 1
35 |     return existing_sequencies + new_sequencies
36 | 
37 | def update_existing_sequencies_with_next_element(existing_sequencies: List[List[int]], element: List[int]) -> List[List[int]]:
38 |     """Updates the existing sequences with the positions of the next element"""
39 |     new_sequencies = [[element[0], 1]]
40 |     for existing_seq in existing_sequencies:
41 |         appropriate_ind = find_lowest_position_higher_then_current(element, existing_seq[0])
42 |         if appropriate_ind is not None:
43 |             new_sequencies = unite_sequencies(new_sequencies, [[appropriate_ind, existing_seq[1] + 1]])
44 |     return unite_sequencies(existing_sequencies, new_sequencies)
45 | 
46 | def fbeta_score(precision: float, recall: float, beta: float = 1.0) -> float:
47 |     """Computes the F-beta score (F1 by default) from precision and recall"""
48 |     if precision == 0.0 and recall == 0.0:
49 |         return 0.0
50 |     if beta == 1.0:
51 |         return 2 * precision * recall / (precision + recall)
52 |     return (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)
53 | 
54 | def rouge_L(reference: List[int], candidate: List[int]) -> Dict[str, float]:
55 |     """Computes the Rouge-L score between the reference and candidate sequences."""
56 |     reference_length = len(reference)
57 |     candidate_length = len(candidate)
58 | 
59 |     if candidate_length == 0 or reference_length == 0:
60 |         return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
61 | 
62 |     candidate_positions_in_reference = [get_element_positions(reference, el) for el in candidate]
63 | 
64 |     candidate_positions_in_reference = [positions for positions in candidate_positions_in_reference if positions]
65 | 
66 |     # Nothing is left after filtering out candidate tokens absent from the reference
67 |     if not candidate_positions_in_reference:
68 |         return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
69 | 
70 |     existing_sequences = []
71 |     for element_positions in candidate_positions_in_reference:
72 |         existing_sequences = update_existing_sequencies_with_next_element(existing_sequences, element_positions)
73 | 
74 |     # No common subsequences were built
75 |     if not existing_sequences:
76 |         return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
77 | 
78 |     max_sequence_length = max(existing_sequences, key=lambda seq: seq[1])[1]
79 | 
80 |     precision = max_sequence_length / candidate_length
81 |     recall = max_sequence_length / reference_length
82 |     f1 = fbeta_score(precision, recall)
83 | 
84 |     return {'precision': precision, 'recall': recall, 'f1': f1}
85 | 
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/rouge_N.py:
--------------------------------------------------------------------------------
1 | from itertools import islice
2 | from typing import List, Tuple, Iterator, Dict
3 | 
4 | def split_into_overlapping_chunks(iterable: List[int], chunk_size: int) -> Iterator[Tuple[int, ...]]:
5 |     """
6 |     Splits a sequence into overlapping chunks of the given size
7 |     """
8 |     iterator = iter(iterable)
9 |     res = tuple(islice(iterator, chunk_size))
10 |     if len(res) == chunk_size:
11 |         yield res
12 |     for el in iterator:
13 |         res = res[1:] + (el,)
14 |         yield res
15 | 
16 | def list_intersection(l1: List[int], l2: List[int]) -> int:
17 |     """
18 |     Counts and returns the number of elements shared by two lists (l1 and l2),
19 |     removing each found match from the second list
20 |     """
21 |     l1_copy = l1[:]
22 |     l2_copy = l2[:]
23 | 
24 |     res = 0
25 |     for el in l1_copy:
26 |         i = len(l2_copy) - 1
27 |         while i >= 0:
28 |             if el == l2_copy[i]:
29 |                 del l2_copy[i]
30 |                 res += 1
31 |                 break
32 |             i -= 1
33 | 
34 |     return res
35 | 
36 | def rouge_N(reference: List[int], candidate: List[int], n: int) -> Dict[str, float]:
37 |     """
38 |     Computes the Rouge-N score between a reference sequence and a candidate
39 | 
40 |     :param reference: The reference sequence
41 |     :param candidate: The candidate sequence
42 |     :param n: The n-gram size
43 |     """
44 |     reference_chunks = list(split_into_overlapping_chunks(reference, n))
45 |     reference_length = len(reference_chunks)
46 | 
47 |     candidate_chunks = list(split_into_overlapping_chunks(candidate, n))
48 |     candidate_length = len(candidate_chunks)
49 | 
50 |     if candidate_length == 0 or reference_length == 0:
51 |         return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
52 | 
53 |     intersection = list_intersection(reference_chunks, candidate_chunks)
54 | 
55 |     precision = intersection / candidate_length
56 |     recall = intersection / reference_length
57 |     f1 = 0.0 if intersection == 0 else 2 * precision * recall / (precision + recall)  # avoid 0/0 when no n-grams match
58 | 
59 |     return {'precision': precision, 'recall': recall, 'f1': f1}
60 | 
--------------------------------------------------------------------------------
/explainitall/nlp.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 | 
3 | import gensim
4 | import pymorphy2
5 | 
6 | 
7 | class WordProcessor:
8 | 
9 |     def __init__(self, gensim_nlp_embeddings: gensim.models.keyedvectors.KeyedVectors):
10 |         self._morph = pymorphy2.MorphAnalyzer()
11 |         self.gensim_nlp_embeddings = gensim_nlp_embeddings
12 | 
13 |     @lru_cache(maxsize=None)
14 |     def get_clean_word(self, word: str):
15 |         return word.lower().strip()
16 | 
17 |     # @lru_cache(maxsize=None)
18 |     def get_embeddable_words_batch(self, words: list):
19 |         cleaned = [self.get_clean_word(w) for w in words]
20 |         embeddable = [self.get_embeddable_word_or_none(w) for w in cleaned if w]
21 |         return [w for w in embeddable if w]
22 | 
23 |     @lru_cache(maxsize=None)
24 |     def get_morph_or_none(self, word: str):
25 |         morphed = self._morph.parse(word)
26 |         known_words = [x for x in morphed if x.is_known]
27 |         if not known_words:
28 |             return None
29 |         return known_words[0]
30 | 
31 |     @lru_cache(maxsize=None)
32 |     def get_normal_form_or_none(self, word: str):
33 |         clean_word = self.get_clean_word(word)
34 |         if not clean_word:
35 |             return None
36 |         word_morph = self.get_morph_or_none(clean_word)
37 |         if not word_morph:
38 |             return None
39 |         return word_morph.normal_form
40 | 
41 |     @lru_cache(maxsize=None)
42 |     def get_grammeme_or_none(self, word: str):
43 |         clean_word = self.get_clean_word(word)
44 |         if not clean_word:
45 |             return None
46 |         word_morph = self.get_morph_or_none(clean_word)
47 |         if not word_morph:
48 |             return None
49 |         tag = word_morph.tag.POS
50 |         tag = 'VERB' if word_morph and tag == 'INFN' else tag
51 |         return tag
52 | 
53 |     def get_embeddable_word_or_none(self, word: str):
54 |         normal_form = self.get_normal_form_or_none(word)
55 |         grammeme = self.get_grammeme_or_none(word)
56 |         if not normal_form or not grammeme:
57 |             return None
58 |         word_tagged = f'{normal_form}_{grammeme}'
59 |         if word_tagged not in self.gensim_nlp_embeddings:
60 |             return None
61 |         return word_tagged
62 | 
--------------------------------------------------------------------------------
/explainitall/stat_helpers.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional, Dict
3 | 
4 | import numpy as np
5 | from sklearn import mixture
6 | from sklearn.mixture import GaussianMixture
7 | 
8 | 
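# Parameterization note: `disp` below stands for 2*sigma^2 of the standard Rayleigh
# distribution, so rayleigh_el_integral(x, disp) = 1 - exp(-x^2 / disp) is exactly the
# Rayleigh CDF, while rayleigh_el(x, disp) = (x / disp) * exp(-x^2 / disp) is
# proportional to (one half of) the corresponding density. Sanity values, consistent
# with tests/test_stat_helpers.py:
#
#     rayleigh_el(1.0, 2.0)           # 0.5 * exp(-0.5) ~= 0.303
#     rayleigh_el_integral(1.0, 2.0)  # 1 - exp(-0.5)   ~= 0.393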
9 | def rayleigh_el(el: float, disp: float) -> float:
10 |     return (el / disp) * math.exp(-el ** 2 / disp)
11 | 
12 | 
13 | def rayleigh(arr: np.ndarray, dispersion: float) -> np.ndarray:
14 |     return np.array(list(map(lambda x: rayleigh_el(x, dispersion), arr)))
15 | 
16 | 
17 | def rayleigh_el_integral(el: float, disp: float) -> float:
18 |     return 1 - math.exp(-el ** 2 / disp)
19 | 
20 | 
21 | def rayleigh_integral(arr: np.ndarray, dispersion: float) -> np.ndarray:
22 |     return np.array(list(map(lambda x: rayleigh_el_integral(x, dispersion), arr)))
23 | 
24 | 
25 | def gaussian_integral_single(element, mean, std, sqrt2):
26 |     return 0.5 + 0.5 * math.erf((element - mean) / (sqrt2 * std))
27 | 
28 | 
29 | def gaussian_mixture_integral_single(element: float, gaussian_mixture_model: GaussianMixture, sqrt2: float) -> float:
30 |     means = gaussian_mixture_model.means_
31 |     weights = gaussian_mixture_model.weights_
32 |     variances = gaussian_mixture_model.covariances_[:, 0]
33 | 
34 |     integral_element = np.sum([
35 |         weight * gaussian_integral_single(mean=mean[0], std=np.sqrt(variance[0]), element=element, sqrt2=sqrt2)
36 |         for weight, mean, variance in zip(weights, means, variances)
37 |     ])
38 |     return float(integral_element)
39 | 
40 | 
41 | def gaussian_mixture_integral(arr: np.ndarray, gmm: GaussianMixture) -> np.ndarray:
42 |     sqrt2 = math.sqrt(2)
43 |     if len(arr.shape) == 1:
44 |         return np.array([gaussian_mixture_integral_single(x, gmm, sqrt2) for x in arr])
45 | 
46 |     if len(arr.shape) == 2:
47 |         dim_1, dim_2 = arr.shape
48 |         reshaped_arr = arr.reshape((dim_1 * dim_2))
49 |         reshaped_arr = np.array([gaussian_mixture_integral_single(x, gmm, sqrt2) for x in reshaped_arr])
50 |         return reshaped_arr.reshape((dim_1, dim_2))
51 | 
52 | 
53 | def denormalize_array(array: np.ndarray) -> np.ndarray:
54 |     edited = array.copy()
55 |     non_zero_counts = np.count_nonzero(edited, axis=0)
56 |     edited *= non_zero_counts
57 |     return edited
58 | 
59 | 
60 | def calc_gauss_mixture_stat_params(array: np.ndarray,
61 |                                    num_components: int = 3,
62 |                                    seed: Optional[int] = None) -> np.ndarray:
63 |     "Computes a new array from the input via a fitted Gaussian mixture"
64 |     d_array_2d = denormalize_array(array)
65 | 
66 |     d_array_1d = d_array_2d[~np.isnan(array)]
67 |     d_array_1d = d_array_1d.reshape(len(d_array_1d), 1)
68 | 
69 |     gmm = mixture.GaussianMixture(n_components=num_components, random_state=seed if seed is not None else 0)  # honor `seed`
70 |     gmm.fit(d_array_1d)
71 |     return gaussian_mixture_integral(d_array_2d, gmm)
72 | 
73 | 
74 | def calc_gmm_stat_params(array: np.ndarray) -> Dict:
75 |     array_1d = array[np.logical_not(np.isnan(array))]
76 | 
77 |     mean = float(np.mean(array_1d))
78 |     std = float(np.std(array_1d))
79 |     new_arr = calc_gauss_mixture_stat_params(array)
80 | 
81 |     return {'new_arr': new_arr, "mean": mean, "std": std}
82 | 
83 | 
84 | def compute_gaussian_integral(array: np.ndarray, mean: float, std_dev: float) -> np.ndarray:
85 |     """Integral Gaussian function for a 1-D or 2-D array"""
86 |     sqrt2 = math.sqrt(2)
87 |     vectorized_gaussian = np.vectorize(gaussian_integral_single)
88 |     array_1d = array.flatten()
89 |     result_array = vectorized_gaussian(array_1d, mean, std_dev, sqrt2)
90 |     return result_array.reshape(array.shape)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import gensim
2 | from inseq import load_model
3 | 
4 | from explainitall.gpt_like_interp import viz, interp
5 | from explainitall.gpt_like_interp.downloader import DownloadManager
6 | 
7 | 
8 | def load_nlp_model(nlp_model_url):
9 |     nlp_model_path = DownloadManager.load_zip(nlp_model_url)
10 |     return gensim.models.KeyedVectors.load_word2vec_format(nlp_model_path, binary=True)
11 | 
12 | 
13 | # 'ID': 180
14 | # 'Vector size': 300
15 | # 'Corpus': 'Russian National Corpus'
16 | # 'Vocabulary size': 189193
17 | # 'Algorithm': 'Gensim Continuous Bag-of-Words'
18 | # 'Lemmatization': True
19 | 
20 | nlp_model = load_nlp_model('http://vectors.nlpl.eu/repository/20/180.zip')
21 | 
22 | 
23 | def load_gpt_model(gpt_model_name):
24 |     return load_model(model=gpt_model_name,
25 |                       attribution_method="integrated_gradients")
26 | 
27 | 
28 | # 'Framework': 'transformers'
29 | # 'Training tokens': '80 billion'
30 | # 'Context size': 2048
31 | gpt_model = load_gpt_model("sberbank-ai/rugpt3small_based_on_gpt2")
32 | 
33 | clusters_discr = [  # centroid words stay in Russian: they are looked up in the Russian word2vec model above
34 |     {'name': 'Animals', 'centroid': ['собака', 'кошка', 'заяц'], 'top_k': 140},
35 |     {'name': 'Medications', 'centroid': ['уколы', 'таблетки', 'микстуры'], 'top_k': 160},
36 |     {'name': 'Diseases', 'centroid': ['простуда', 'орви', 'орз', 'грипп'], 'top_k': 20},
37 |     {'name': 'Allergy', 'centroid': ['аллергия'], 'top_k': 20}
38 | ]
39 | 
40 | explainer = interp.ExplainerGPT2(gpt_model=gpt_model, nlp_model=nlp_model)
41 | 
42 | expl_data = explainer.interpret(
43 |     input_texts='у кошки грипп и аллергия на антибиотбиотики вопрос: чем лечить кошку? ответ:',
44 |     generated_texts='лечичичичите ее уколами',
45 |     clusters_description=clusters_discr,
46 |     batch_size=50,
47 |     steps=34,
48 |     # max_new_tokens=19
49 | )
50 | 
51 | print("\nCluster importance map")
52 | print(expl_data.cluster_imp_df)
53 | 
54 | print("\nCluster importance heatmap")
55 | expl_data.show_cluster_imp_heatmap()
56 | 
57 | print("\nCluster importance map, aggregated")
58 | print(expl_data.cluster_imp_aggr_df)
59 | 
60 | print("\nCluster importance heatmap, aggregated")
61 | expl_data.show_cluster_imp_aggr_heatmap()
62 | 
63 | print("\nWord importance map")
64 | print(expl_data.word_imp_df)
65 | 
66 | print("\nWord importance heatmap")
67 | expl_data.show_word_imp_heatmap()
68 | 
69 | print("\nWord importance map, normalized")
70 | print(expl_data.word_imp_norm_df)
71 | 
72 | print("\nWord importance heatmap, normalized")
73 | expl_data.show_word_imp_norm_heatmap()
74 | 
75 | print("\nDistribution histogram")
76 | viz.show_distribution_histogram(expl_data.attributions.array)
77 | print("\nDistribution plot")
78 | viz.show_distribution_plot(expl_data.attributions.array)
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | addopts = -v
3 | python_files = test_*.py
4 | python_functions = test_*
5 | testpaths = tests
6 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gensim==4.3.2
2 | gradio==4.4.1
3 | gradio_client==0.7.0
4 | inseq==0.5.0
5 | matplotlib==3.8.4
6 | nltk==3.8.1
7 | pandas==2.2.1
8 | pymorphy2==0.9.1
9 | scikit-learn==1.4.1.post1
10 | seaborn==0.13.2
11 | sentence_transformers==2.6.1
12 | SQLAlchemy==2.0.29
13 | statsmodels==0.14.1
14 | torch>=2.0
15 | transformers>=4.39
16 | uvicorn==0.29
17 | scipy==1.12.0
18 | numpy==1.26.4
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # setuptools is required here: distutils.core has no find_namespace_packages,
2 | # so the old "except ImportError: from distutils.core import ..." fallback could never work
3 | from setuptools import setup, find_namespace_packages
4 | 
5 | 
6 | REQUIREMENTS = [i.strip() for i in open("requirements.txt").readlines()]
7 | 
8 | setup(
9 |     name='explainitall',
10 |     version='1.0.2',
11 |     long_description=open('README.md', encoding='utf-8', errors='ignore').read(),
12 |     long_description_content_type='text/markdown',
13 |     packages=find_namespace_packages(include=['explainitall', 'explainitall.*']),
14 |     install_requires=REQUIREMENTS
15 | )
16 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Bots-Avatar/ExplainitAll/0339ea5c09c3cd309d53c23b403465c821a778d0/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_inseq_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | 
4 | from explainitall.gpt_like_interp import inseq_helpers
5 | 
6 | test_data = [
7 |     (['Reference! site - a-bout Lorem-Ipsum,',
8 |       [' ref', '!ere_', 'nce((', '__site', 'a-bout__', '@Lorem', '*Ipsum'],
9 |       [['Ref', 'ere', 'nce'], ['site'], ['a-bout'], ['Lorem', ('-', 0), 'Ipsum']]]),
10 |     (['Reference! q321 - a-bout Lorem-Ipsum,',
11 |       [' ref', '!ere_', 'nce((', '__q321', 'a-bout__', '@Lorem', '*Ipsum'],
12 |       [['Ref', 'ere', 'nce'], ['q321'], ['a-bout'], ['Lorem', ('-', 0), 'Ipsum']]]),
13 |     (['giving information on its origins, as well ',
14 |       ['🃯🃯giving🃯', 'ÐinformationÐ ', 'on__', 'its', '__origin', '&&&s', 'as', 'well'],
15 |       [['giving'], ['information'], ['on'], ['its'], ['origin', 's'], ['as'], ['well']]]),
16 |     ([' as a random Lipsum generator.',
17 |       ['as🃯🃯', 'a', 'Ðrandom', '🃯lipsum', '###generator!'],
18 |       [['as'], ['a'], ['random'], ['Lipsum'], ['generator']]]),
19 |     (['Папа у Васи чуть-чуть силён в Математике!',
20 |       ['папа🃯🃯', 'у', 'Ðва', '🃯с', '###и!', '!чуть', '-', 'чуть', 'силён', '#в!', 'математике'],
21 |       [['Папа'], ['у'], ['Ва', 'с', 'и'], ['чуть', ('-', 0), '', 'чуть'], ['силён'], ['в'], ['Математике']]])]
22 | 
23 | 
24 | # Apply parametrization
25 | @pytest.mark.parametrize("inp_text,inp_pairs,expected_rez", test_data)
26 | def test_detokenizer(inp_text, inp_pairs, expected_rez):
27 |     fact_rez = inseq_helpers.Detokenizer(inp_text, inp_pairs).group_text()
28 |     assert fact_rez == expected_rez
29 | 
30 | 
31 | def test_squash_arr():
32 |     array = np.array([[4., 7., 3., 8., 6.],
33 |                       [4., 9., 7., 4., 1.],
34 |                       [2., 6., 0., 0., 7.],
35 |                       [2., 5., 5., 4., 8.],
36 |                       [1., 4., 6., 10., 4.]])
37 | 
38 |     squashed_arr = inseq_helpers.squash_arr(arr=array,
39 |                                             squash_row_mask=[[0, 1], [1, 5]],
40 |                                             squash_col_mask=[[0, 3], [3, 5]],
41 |                                             aggr_f=np.max)
42 | 
43 |     expected = np.array([[7., 8.],
44 |                          [9., 10.]])
45 |     np.testing.assert_array_equal(squashed_arr, expected)
46 | 
47 | 
48 | def test_calculate_mask():
49 |     grouped_data = [[1, 2, 3], [11, 22], [333]]
50 |     expected_output = [[0, 3], [3, 5], [5, 6]]
51 |     assert inseq_helpers.calculate_mask(grouped_data) == expected_output
52 | 
--------------------------------------------------------------------------------
/tests/test_stat_helpers.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import numpy as np
4 | import pytest
5
| from sklearn import mixture 6 | 7 | from explainitall import (stat_helpers) 8 | 9 | 10 | def test_rayleigh_el(): 11 | el = 1.0 12 | disp = 2.0 13 | expected = 0.303 14 | result = stat_helpers.rayleigh_el(el, disp) 15 | assert math.isclose(result, expected, rel_tol=1e-3) 16 | 17 | 18 | @pytest.mark.parametrize("arr, dispersion, expected", [ 19 | (np.array([1, 2, 3]), 2.0, np.array([0.303, 0.135, 0.016])), 20 | ]) 21 | def test_rayleigh(arr, dispersion, expected): 22 | result = stat_helpers.rayleigh(arr, dispersion) 23 | np.testing.assert_almost_equal(result, expected, decimal=3) 24 | 25 | 26 | @pytest.mark.parametrize("el, disp, expected", [ 27 | (1.0, 2.0, 0.3934), 28 | (2.0, 1.5, 0.9305), 29 | (0.5, 3.0, 0.0799) 30 | ]) 31 | def test_rayleigh_el_int(el, disp, expected): 32 | result = stat_helpers.rayleigh_el_integral(el, disp) 33 | assert math.isclose(result, expected, rel_tol=1e-3) 34 | 35 | 36 | @pytest.mark.parametrize("arr, dispersion, expected", [ 37 | (np.array([1, 2, 3]), 1.0, np.array([0.6321, 0.9816, 0.99987])), 38 | (np.array([0, 0.5, 1]), 0.5, np.array([0.0, 0.39346, 0.8646]))]) 39 | def test_rayleigh_int(arr, dispersion, expected): 40 | result = stat_helpers.rayleigh_integral(arr, dispersion) 41 | np.testing.assert_almost_equal(result, expected, decimal=4) 42 | 43 | 44 | @pytest.mark.parametrize("el, mu, std, sqrt2, expected", [ 45 | (1.0, 2.0, 0.3934, math.sqrt(2), 0.0055119), 46 | (2.0, 1.5, 0.9305, math.sqrt(2), 0.7044855), 47 | (0.5, 3.0, 0.0799, math.sqrt(2), 0.0) 48 | ]) 49 | def test_gauss_integral_element(el, mu, std, sqrt2, expected): 50 | result = stat_helpers.gaussian_integral_single(el, mu, std, sqrt2) 51 | assert math.isclose(result, expected, rel_tol=1e-3) 52 | 53 | 54 | def test_gauss_m_integral_element(): 55 | arr = np.array([1, 2, 3, 4, 5]) 56 | gmm = mixture.GaussianMixture(n_components=2, random_state=0) 57 | gmm.fit(arr.reshape(-1, 1), 1) 58 | el = 1 59 | sqrt2 = math.sqrt(2) 60 | result = stat_helpers.gaussian_mixture_integral_single(el, gmm, sqrt2) 61 | expected_result = 0.07186 62 | assert math.isclose(result, expected_result, rel_tol=1e-3) 63 | 64 | 65 | def test_gauss_m_integral_1D(): 66 | arr = np.array([1, 2, 3, 4, 5]) 67 | gmm = mixture.GaussianMixture(n_components=2, random_state=0) 68 | gmm.fit(arr.reshape(-1, 1), 1) 69 | expected_result = np.array([0.0719, 0.2886, 0.5261, 0.6753, 0.9324]) 70 | result = stat_helpers.gaussian_mixture_integral(arr=arr, gmm=gmm) 71 | np.testing.assert_almost_equal(result, expected_result, decimal=4) 72 | 73 | 74 | def test_gauss_m_integral_2D(): 75 | arr = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) 76 | gmm = mixture.GaussianMixture(n_components=2, random_state=0) 77 | gmm.fit(arr.reshape(-1, 1), 1) 78 | expected_result = np.array([[0.1676, 0.4391, 0.8942], 79 | [0.1676, 0.4391, 0.8942], 80 | [0.1676, 0.4391, 0.8942]]) 81 | result = stat_helpers.gaussian_mixture_integral(arr=arr, gmm=gmm) 82 | np.testing.assert_almost_equal(result, expected_result, decimal=4) 83 | 84 | 85 | def test_gauss_integral_1D(): 86 | arr = np.array([1, 2, 3, 4, 5]) 87 | mu = 3 88 | std = 1 89 | expected_result = np.array([0.0227, 0.1586, 0.5000, 0.8413, 0.9772]) 90 | result = stat_helpers.compute_gaussian_integral(arr, mu, std) 91 | np.testing.assert_almost_equal(result, expected_result, decimal=4) 92 | 93 | def test_gauss_integral_2D(): 94 | arr = np.array([[1, 2], [3, 4]]) 95 | mu = 2 96 | std = 1 97 | expected_result = np.array([[0.15865525, 0.5], [0.84134475, 0.97724987]]) 98 | result = stat_helpers.compute_gaussian_integral(arr, mu, 
std) 99 | np.testing.assert_almost_equal(result, expected_result, decimal=4) 100 | 101 | 102 | def test_de_normalizy(): 103 | # Test case 1: Normalized array with non-zero values 104 | arr = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 105 | expected = np.array([[2., 4., 6.], [8., 10., 12.]]) 106 | np.testing.assert_almost_equal(stat_helpers.denormalize_array(arr), expected, decimal=1) 107 | 108 | 109 | def test_calc_gauss_mixture_stat_params(): 110 | arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 111 | result = stat_helpers.calc_gauss_mixture_stat_params(arr, num_components=1, seed=0) 112 | expected = np.array([[0.061, 0.123, 0.219], [0.349, 0.5, 0.651], [0.781, 0.877, 0.939]]) 113 | np.testing.assert_almost_equal(result, expected, decimal=3) 114 | --------------------------------------------------------------------------------