├── .dockerignore ├── .github └── workflows │ └── build.yaml ├── .gitignore ├── LICENSE ├── README.md ├── autotm ├── __init__.py ├── abstract_params.py ├── algorithms_for_tuning │ ├── __init__.py │ ├── bayesian_optimization │ │ ├── __init__.py │ │ └── bayes_opt.py │ ├── genetic_algorithm │ │ ├── __init__.py │ │ ├── config.yaml │ │ ├── crossover.py │ │ ├── ga.py │ │ ├── genetic_algorithm.py │ │ ├── mutation.py │ │ ├── selection.py │ │ ├── statistics_collector.py │ │ ├── strategy.py │ │ └── surrogate.py │ ├── individuals.py │ └── nelder_mead_optimization │ │ ├── __init__.py │ │ └── nelder_mead.py ├── base.py ├── batch_vect_utils.py ├── clustering.py ├── content_splitter.py ├── data_generator │ └── synthetic_data_generation.ipynb ├── fitness │ ├── __init__.py │ ├── base_score.py │ ├── cluster_tasks.py │ ├── estimator.py │ ├── external_scores.py │ ├── local_tasks.py │ └── tm.py ├── graph_ga.py ├── infer.py ├── main.py ├── main_fitness_worker.py ├── make-recovery-file.py ├── params.py ├── params_logging_utils.py ├── pipeline.py ├── preprocessing │ ├── __init__.py │ ├── cooc.py │ ├── dictionaries_preparation.py │ └── text_preprocessing.py ├── schemas.py ├── utils.py └── visualization │ ├── __init__.py │ └── dynamic_tracker.py ├── cluster ├── bin │ └── fitnessctl ├── charts │ └── autotm │ │ ├── .helmignore │ │ ├── Chart.yaml │ │ ├── templates │ │ ├── _helpers.tpl │ │ ├── configmaps.yaml │ │ ├── deployments.yaml │ │ ├── pvc.yaml │ │ └── services.yaml │ │ └── values.yaml ├── conf │ └── pv.yaml └── docker │ ├── flower.dockerfile │ ├── jupyter.dockerfile │ ├── mlflow-webserver.dockerfile │ └── worker.dockerfile ├── conf └── config.yaml ├── data └── sample_corpora │ ├── clean_docs_v17_gost_only.csv │ ├── dataset_books_stroitelstvo.csv │ ├── imdb_100.csv │ ├── imdb_1000.csv │ ├── partition_df.csv │ └── sample_dataset_lenta.csv ├── distributed ├── README.md ├── autotm_distributed │ ├── __init__.py │ ├── deploy_config_generator.py │ ├── main.py │ ├── metrics.py │ ├── params_logging_utils.py │ ├── preprocessing.py │ ├── schemas.py │ ├── tasks.py │ ├── test_app.py │ ├── tm.py │ └── utils.py ├── bin │ ├── fitnessctl │ ├── kube-fitnessctl │ └── trun.sh ├── deploy │ ├── ess-small-datasets-config.yaml │ ├── file.yaml │ ├── fitness-worker-health-checker.yaml │ ├── kube-fitness-client-job.yaml │ ├── kube-fitness-client-job.yaml.j2 │ ├── kube-fitness-workers.yaml │ ├── kube-fitness-workers.yaml.j2 │ ├── mlflow-minikube-persistent-volumes.yaml │ ├── mlflow-persistent-volumes.yaml │ ├── mlflow.yaml │ ├── mongo_dev │ │ ├── docker-compose.yaml │ │ └── wait-for.sh │ ├── ns.yaml │ └── test-datasets-config.yaml ├── docker │ ├── base.dockerfile │ ├── cli.dockerfile │ ├── client.dockerfile │ ├── flower.dockerfile │ ├── health-checker.dockerfile │ ├── mlflow-webserver.dockerfile │ ├── test.dockerfile │ └── worker.dockerfile ├── poetry.lock ├── pyproject.toml ├── setup.py └── test │ └── topic_model_test.py ├── docs ├── Makefile ├── _static │ └── style.css ├── _templates │ ├── autosummary │ │ ├── class.rst │ │ └── module.rst │ ├── classtemplate.rst │ └── functiontemplate.rst ├── conf.py ├── img │ ├── MyLogo.png │ ├── autotm_arch_v3 (1).png │ ├── img_library_eng.png │ ├── photo_2023-06-29_20-20-57.jpg │ ├── photo_2023-06-30_14-17-08.jpg │ ├── pipeling.png │ └── strategy.png ├── index.rst ├── make.bat └── pages │ ├── algorithms_for_tuning.rst │ ├── api │ ├── algorithms_for_tuning.rst │ ├── fitness.rst │ ├── index.rst │ ├── preprocessing.rst │ └── visualization.rst │ ├── fitness.rst │ ├── installation.rst │ ├── 
preprocessing.rst │ ├── userguide │ ├── index.rst │ ├── metrics.rst │ └── regularizers.rst │ └── visualization.rst ├── examples ├── demo │ ├── autotm_demo_ru_OLD.ipynb │ └── autotm_demo_updated_clustering.ipynb ├── demo_autotm.ipynb ├── examples_autotm_fit_predict.py ├── graph_building_for_stroitelstvo.py ├── graph_profile_example.py └── topic_modeling_of_corporative_data.py ├── logging.config.yml ├── poetry.lock ├── pyproject.toml ├── scripts ├── __init__.py ├── algs │ ├── full_pipeline_example.py │ └── nelder_mead_experiments.py ├── experiments │ ├── __init__.py │ ├── analysis.ipynb │ ├── experiment.py │ ├── plot_boxplot_Final_results.png │ ├── plot_boxplot_Start_results.png │ ├── plot_progress_hotel-reviews_sample.png │ ├── plots.ipynb │ └── statistics │ │ ├── 240602-124458_hotel-reviews_sample_fixed_parameters.txt │ │ ├── 240602-124458_hotel-reviews_sample_fixed_progress.txt │ │ ├── 240602-124539_hotel-reviews_sample_fixed_parameters.txt │ │ ├── 240602-124539_hotel-reviews_sample_fixed_progress.txt │ │ ├── 240602-124633_hotel-reviews_sample_fixed_parameters.txt │ │ ├── 240602-124633_hotel-reviews_sample_fixed_progress.txt │ │ ├── 240602-124719_hotel-reviews_sample_pipeline_parameters.txt │ │ ├── 240602-124719_hotel-reviews_sample_pipeline_progress.txt │ │ ├── 240602-124816_hotel-reviews_sample_pipeline_parameters.txt │ │ ├── 240602-124816_hotel-reviews_sample_pipeline_progress.txt │ │ ├── 240602-124950_hotel-reviews_sample_pipeline_parameters.txt │ │ ├── 240602-124950_hotel-reviews_sample_pipeline_progress.txt │ │ ├── 240603-231051_hotel-reviews_sample_fixed_parameters.txt │ │ ├── 240603-231154_hotel-reviews_sample_fixed_parameters.txt │ │ ├── 240603-231154_hotel-reviews_sample_fixed_progress.txt │ │ ├── 240603-231313_hotel-reviews_sample_pipeline_parameters.txt │ │ ├── 240603-231313_hotel-reviews_sample_pipeline_progress.txt │ │ ├── 240603-232451_hotel-reviews_sample_fixed_parameters.txt │ │ ├── 240603-232451_hotel-reviews_sample_fixed_progress.txt │ │ ├── 240603-232546_hotel-reviews_sample_fixed_parameters.txt │ │ ├── 240603-232546_hotel-reviews_sample_fixed_progress.txt │ │ ├── 240603-232733_hotel-reviews_sample_fixed_parameters.txt │ │ ├── 240603-232733_hotel-reviews_sample_fixed_progress.txt │ │ ├── 240603-232847_hotel-reviews_sample_pipeline_parameters.txt │ │ ├── 240603-232916_hotel-reviews_sample_fixed_parameters.txt │ │ ├── 240603-232916_hotel-reviews_sample_fixed_progress.txt │ │ ├── 240603-233132_hotel-reviews_sample_pipeline_parameters.txt │ │ └── 240603-233132_hotel-reviews_sample_pipeline_progress.txt ├── other │ ├── big_sample_4class_avg.csv │ ├── preparation_pipeline.py │ ├── sample_3class.csv │ ├── sample_4class_avg.csv │ └── toloka_markup_analysis.ipynb └── topic_modeling_of_corporative_data.py ├── tests ├── integration │ ├── __init__.py │ └── test_fit_predict.py └── unit │ ├── __init__.py │ ├── conftest.py │ ├── test_cooc.py │ ├── test_dictionaries_preparation.py │ ├── test_llm_fitness.py │ └── test_preprocessing.py └── toloka ├── Estimate_topics_interpretability ├── input-data.json ├── instructions.md ├── output-data.json ├── task.css ├── task.html └── task.js └── MarkupPreparation.ipynb /.dockerignore: -------------------------------------------------------------------------------- 1 | ./Notebooks 2 | ./src/algorithms_for_tuning/genetic_algorithm/resources/ 3 | ./venv 4 | ./src/irace_config/ga_test 5 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: 
-------------------------------------------------------------------------------- 1 | name: build 2 | run-name: ${{ github.repository }} installation test 3 | 4 | on: [push] 5 | 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | strategy: 10 | fail-fast: false 11 | matrix: 12 | python-version: ["3.9", "3.10", "3.11"] 13 | 14 | steps: 15 | - run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event." 16 | - uses: actions/checkout@v3 17 | - name: List files in the repository 18 | run: | 19 | ls ${{ github.workspace }} 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Using python version 25 | run: python --version 26 | # - name: Install pip 27 | # run: apt install -y python3-pip 28 | - name: Install dependencies for tests 29 | run: | 30 | python -m pip install --upgrade pip 31 | python -m pip install flake8 poetry 32 | poetry install 33 | - name: Download english corpus 34 | run: poetry run python -m spacy download en_core_web_sm 35 | - name: Setting language and locale 36 | run: | 37 | sudo apt-get update 38 | sudo apt-get install -y locales 39 | sudo locale-gen ru_RU.UTF-8 40 | sudo update-locale 41 | export LC_ALL="ru_RU.UTF-8" 42 | export LANG="ru_RU.UTF-8" 43 | export LANGUAGE="ru_RU.UTF-8" 44 | - name: set pythonpath 45 | run: | 46 | echo "PYTHONPATH=." >> $GITHUB_ENV 47 | - name: Lint with flake8 48 | run: | 49 | flake8 autotm --count --select=E9,F63,F7,F82 --show-source --statistics 50 | - name: Run test code 51 | run: | 52 | poetry run pytest tests 53 | - run: echo "🍏 This job's status is ${{ job.status }}." 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Sphinx 132 | docs/generated 133 | 134 | autotm/algorithms_for_tuning/genetic_algorithm/resources/* 135 | .idea 136 | algorithms_for_tuning/genetic_algorithm/logs/* 137 | algorithms_for_tuning/genetic_algorithm/bigartm.* 138 | logs 139 | new_logs 140 | resources 141 | *.whl 142 | surrogate_logs/* 143 | *.orig 144 | *.pickle 145 | Notebooks/ 146 | logs/ 147 | examples/mlruns/ 148 | data/experiment_datasets/ 149 | data/processed_sample_corpora 150 | data/sample_corpora/*.txt 151 | data/sample_corpora/*.pkl 152 | data/sample_corpora/batches/* 153 | examples/tmp 154 | examples/metrics 155 | examples/*.txt 156 | examples/out 157 | bigartm.* 158 | coverage_re/ 159 | tests/integration/mlruns/* 160 | tests/integration/metrics 161 | **/mlruns/* 162 | **/metrics/* 163 | autotm_workdir_* 164 | requirements.txt 165 | *.patch 166 | 167 | mixtures.csv 168 | model.artm 169 | 170 | # sphinx docs 171 | docs/_build 172 | 173 | examples/experiments/*.txt 174 | examples/experiments/*/*.txt 175 | examples/experiments/*.png 176 | examples/experiments/*/*.png 177 | scripts/experiments/metrics 178 | scripts/experiments/mlruns 179 | 180 | rsync-repo.sh 181 | 182 | tmp 183 | 184 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Industrial AI Lab Team 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

3 | Library scheme 4 |

5 | 6 |

7 | AutoTM 8 |

9 | 10 |

11 | 12 | Project Status: Active – The project has reached a stable, usable state and is being actively developed. 13 | 14 | build 15 | 16 | License 17 | 18 | 19 | PyPI version 20 | 21 | 22 | Documentation Status 23 | 24 | 25 | Downloads 26 | 27 |

28 | 29 | :sparkles:**News**:sparkles: We have fully updated our framework to the AutoTM 2.0 version, enriched with new functionality! Stay tuned! 30 | 31 | Automatic parameter selection for topic models (ARTM approach) using evolutionary and Bayesian algorithms. 32 | AutoTM provides the necessary tools to preprocess English and Russian text datasets and to tune additively regularized topic models. 33 | 34 | ## What is AutoTM? 35 | Topic modeling is one of the basic methods for a whole range of tasks: 36 | 37 | * Exploratory data analysis of unlabelled text data 38 | * Extracting interpretable features (topics and their combinations) from text data 39 | * Searching for hidden insights in the data 40 | 41 | While the ARTM (additive regularization for topic models) approach provides significant flexibility and quality comparable to or better than neural 42 | approaches, such models are hard to tune due to the number of hyperparameters and their combinations. That is why we provide optimization pipelines to effortlessly process custom datasets. 43 | 44 | To overcome the tuning problems, AutoTM presents an easy way to define a learning strategy for training models on input corpora. We implement two strategy variants: 45 | 46 | * a fixed-size variant that provides a learning strategy following the best practices collected from the manual tuning history 47 | 48 | Learning strategy representation (fixed-size) 49 | 50 | * a graph-based variant with more flexibility and an unfixed order and number of stages (**New in AutoTM 2.0**). An example of such a pipeline is provided below: 51 | 52 |

53 | Learning strategy representation (graph-based) 54 |

55 | 56 | The optimization procedure is performed by a genetic algorithm (GA) whose operators are specifically tuned for each of the strategy creation variants (GA for the graph-based variant is **New in AutoTM 2.0**). Bayesian optimization is available only for the fixed-size strategy. 57 | 58 | To speed up the procedure, AutoTM also contains a surrogate modeling implementation for fixed-size and graph-based (**New in AutoTM 2.0**) learning strategies that, for some iterations, 59 | approximates the fitness function to reduce the computational cost of training topic models. 60 | 61 |
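Conceptually, the surrogate is a cheap regressor that is fitted on hyperparameter vectors whose fitness has already been computed and is then queried instead of training a full topic model (see `autotm/algorithms_for_tuning/genetic_algorithm/surrogate.py`). A minimal, illustrative sketch of this idea using the repository's `Surrogate` wrapper on synthetic data (not real fitness values) is shown below:

```python
# Illustrative only: approximate an expensive fitness function with the
# Surrogate wrapper defined in surrogate.py (synthetic data, not real fitness).
import numpy as np

from autotm.algorithms_for_tuning.genetic_algorithm.surrogate import Surrogate

rng = np.random.default_rng(0)
X = rng.uniform(size=(32, 13))                        # already evaluated hyperparameter vectors
y = X.sum(axis=1) + rng.normal(scale=0.1, size=32)    # their (synthetic) fitness values

surrogate = Surrogate(
    "random-forest-regressor",
    n_estimators=100,
    # these three keys are always popped in Surrogate.create(), so they must be present
    gpr_kernel=None,
    gpr_alpha=None,
    normalize_y=None,
)
surrogate.fit(X, y)

candidates = rng.uniform(size=(8, 13))
print(surrogate.predict(candidates))                  # cheap fitness estimates for new candidates
```

In the actual optimization loop the surrogate is refitted as new individuals are evaluated, and its predictions replace real fitness evaluations only on selected iterations.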

62 | Library scheme 63 |
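For programmatic use, the same functionality is available through the `AutoTM` estimator from `autotm.base` (see `examples/examples_autotm_fit_predict.py` and the Quickstart below). The following is only a sketch: the keyword arguments mirror the CLI config keys handled in `autotm/main.py`, and the exact constructor and `fit_predict` signatures should be checked against the example script.

```python
# A minimal, assumption-based sketch of programmatic usage.
# Assumptions: AutoTM accepts the CLI config keys from autotm/main.py as
# keyword arguments and exposes fit_predict(df); values below are illustrative.
import pandas as pd

from autotm.base import AutoTM

df = pd.read_csv("data/sample_corpora/sample_dataset_lenta.csv")

autotm = AutoTM(
    topic_count=20,                       # number of topics to learn
    preprocessing_params={"lang": "ru"},  # corpus language ("ru" or "en")
    alg_name="ga",                        # optimization algorithm (illustrative value)
)

# Tune the learning strategy, train the final topic model and
# return per-document topic mixtures for the input corpus.
mixtures = autotm.fit_predict(df)
print(mixtures.head())
```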

64 | 65 | AutoTM also offers a range of metrics that can be used as the fitness function, ranging from classical ones such as coherence to LLM-based metrics (**New in AutoTM 2.0**). 66 | 67 | ## Installation 68 | 69 | ! Note: Topic model training is available only on Linux distributions. 70 | 71 | **Via pip:** 72 | 73 | ```pip install autotm``` 74 | 75 | ```python -m spacy download en_core_web_sm``` 76 | 77 | **From source:** 78 | 79 | ```poetry install``` 80 | 81 | ```python -m spacy download en_core_web_sm``` 82 | 83 | [//]: # (## Dataset and ) 84 | 85 | ## Quickstart 86 | 87 | Start with the notebook [Easy topic modeling with AutoTM](https://github.com/aimclub/AutoTM/blob/main/examples/demo_autotm.ipynb) or with the example script [AutoTM configurations](https://github.com/aimclub/AutoTM/blob/main/examples/examples_autotm_fit_predict.py). 88 | 89 | ## Running from the command line 90 | 91 | To fit a model: 92 | ```autotmctl --verbose fit --config conf/config.yaml --in data/sample_corpora/sample_dataset_lenta.csv``` 93 | 94 | To predict with a fitted model: 95 | ```autotmctl predict --in data/sample_corpora/sample_dataset_lenta.csv --model model.artm``` 96 | 97 | 98 | ## Citation 99 | 100 | ```bibtex 101 | @article{10.1093/jigpal/jzac019, 102 | author = {Khodorchenko, Maria and Butakov, Nikolay and Sokhin, Timur and Teryoshkin, Sergey}, 103 | title = "{ Surrogate-based optimization of learning strategies for additively regularized topic models}", 104 | journal = {Logic Journal of the IGPL}, 105 | year = {2022}, 106 | month = {02}, 107 | issn = {1367-0751}, 108 | doi = {10.1093/jigpal/jzac019}, 109 | url = {https://doi.org/10.1093/jigpal/jzac019},} 110 | 111 | ``` 112 | -------------------------------------------------------------------------------- /autotm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/autotm/__init__.py -------------------------------------------------------------------------------- /autotm/abstract_params.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from typing import List 4 | 5 | 6 | class AbstractParams(ABC): 7 | @property 8 | @abstractmethod 9 | def basic_topics(self) -> int: 10 | ... 11 | 12 | @property 13 | @abstractmethod 14 | def mutation_probability(self): 15 | ... 16 | 17 | @abstractmethod 18 | def make_params_dict(self): 19 | ... 20 | 21 | @abstractmethod 22 | def run_train(self, model): 23 | """ 24 | Trains the topic model 25 | :param model: an instance of TopicModel 26 | :return: 27 | """ 28 | ... 29 | 30 | @abstractmethod 31 | def validate_params(self) -> bool: 32 | ... 33 | 34 | @abstractmethod 35 | def crossover(self, parent2: "AbstractParams", **kwargs) -> List["AbstractParams"]: 36 | ... 37 | 38 | @abstractmethod 39 | def mutate(self, **kwargs) -> "AbstractParams": 40 | ... 41 | 42 | @abstractmethod 43 | def to_vector(self) -> List[float]: 44 | ...
45 | 46 | -------------------------------------------------------------------------------- /autotm/algorithms_for_tuning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/autotm/algorithms_for_tuning/__init__.py -------------------------------------------------------------------------------- /autotm/algorithms_for_tuning/bayesian_optimization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/autotm/algorithms_for_tuning/bayesian_optimization/__init__.py -------------------------------------------------------------------------------- /autotm/algorithms_for_tuning/genetic_algorithm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/autotm/algorithms_for_tuning/genetic_algorithm/__init__.py -------------------------------------------------------------------------------- /autotm/algorithms_for_tuning/genetic_algorithm/config.yaml: -------------------------------------------------------------------------------- 1 | appName: geneticAlgo 2 | logLevel: WARN 3 | 4 | testMode: False 5 | 6 | gaAlgoParams: 7 | numEvals: 20 8 | 9 | -------------------------------------------------------------------------------- /autotm/algorithms_for_tuning/genetic_algorithm/crossover.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from typing import List, Callable 4 | 5 | 6 | def crossover_pmx(parent_1: List[float], parent_2: List[float], **kwargs) -> List[List[float]]: 7 | """ 8 | Pmx crossover 9 | 10 | Exchange chromosome parts 11 | 12 | Parameters 13 | ---------- 14 | parent_1: List[float] 15 | The first individual to be processed 16 | parent_2: List[float] 17 | The second individual to be processed 18 | 19 | Returns 20 | ---------- 21 | Updated individuals with exchanged chromosome parts 22 | """ 23 | points_num = len(parent_1) 24 | while True: 25 | cut_ix = np.random.choice(points_num + 1, 2, replace=False) 26 | min_ix = np.min(cut_ix) 27 | max_ix = np.max(cut_ix) 28 | part = parent_1[min_ix:max_ix] 29 | if len(part) != len(parent_1): 30 | break 31 | parent_1[min_ix:max_ix] = parent_2[min_ix:max_ix] 32 | parent_2[min_ix:max_ix] = part 33 | return [parent_1, parent_2] 34 | 35 | 36 | # discrete crossover 37 | def crossover_one_point(parent_1: List[float], parent_2: List[float], **kwargs) -> List[List[float]]: 38 | """ 39 | One-point crossover 40 | 41 | Exchange points between chromosomes 42 | 43 | Parameters 44 | ---------- 45 | parent_1: List[float] 46 | The first individual to be processed 47 | parent_2: List[float] 48 | The second individual to be processed 49 | 50 | Returns 51 | ---------- 52 | Updated individuals with exchanged chromosome parts 53 | """ 54 | elem_cross_prob = kwargs["elem_cross_prob"] 55 | for i in range(len(parent_1)): 56 | # removed mutation preservation 57 | if random.random() < elem_cross_prob: 58 | parent_1[i], parent_2[i] = parent_2[i], parent_1[i] 59 | return [parent_1, parent_2] 60 | 61 | 62 | def crossover_blend_new(parent_1: List[float], parent_2: List[float], **kwargs) -> List[List[float]]: 63 | """ 64 | Blend crossover 65 | 66 | Making combination of parents solution with coefficient 67 | 68 | Parameters 69 
| ---------- 70 | parent_1: List[float] 71 | The first individual to be processed 72 | parent_2: List[float] 73 | The second individual to be processed 74 | alpha: float 75 | Blending coefficient 76 | 77 | Returns 78 | ---------- 79 | Updated individuals with exchanged chromosome parts 80 | """ 81 | alpha = kwargs["alpha"] 82 | child_1 = [] 83 | child_2 = [] 84 | u = random.random() 85 | gamma = (1.0 + 2.0 * alpha) * u - alpha # fixed (1. + 2. * alpha) * u - alpha 86 | for i in range(len(parent_1)): 87 | child_1.append((1.0 - gamma) * parent_1[i] + gamma * parent_2[i]) 88 | child_2.append(gamma * parent_1[i] + (1.0 - gamma) * parent_2[i]) 89 | 90 | # TODO: reconsider this 91 | child_1[12:15] = parent_1[12:15] 92 | child_2[12:15] = parent_2[12:15] 93 | return [child_1, child_2] 94 | 95 | 96 | def crossover_blend(parent_1: List[float], parent_2: List[float], **kwargs) -> List[List[float]]: 97 | """ 98 | Blend crossover 99 | 100 | Making combination of parents solution with coefficient 101 | 102 | Parameters 103 | ---------- 104 | parent_1: List[float] 105 | The first individual to be processed 106 | parent_2: List[float] 107 | The second individual to be processed 108 | alpha: float 109 | Blending coefficient 110 | 111 | Returns 112 | ---------- 113 | Updated individuals with exchanged chromosome parts 114 | """ 115 | alpha = kwargs["alpha"] 116 | child = [] 117 | u = random.random() 118 | gamma = (1 - 2 * alpha) * u - alpha 119 | for i in range(len(parent_1)): 120 | child.append((1 - gamma) * parent_1[i] + gamma * parent_2[i]) 121 | if random.random() > 0.5: 122 | child[12:15] = parent_1[12:15] 123 | else: 124 | child[12:15] = parent_2[12:15] 125 | return [child] 126 | 127 | 128 | def crossover(crossover_type: str = "crossover_one_point") -> Callable: 129 | """ 130 | Crossover function 131 | 132 | Parameters 133 | ---------- 134 | crossover_type : str, default="crossover_one_point" 135 | Crossover to be used in the genetic algorithm 136 | """ 137 | if crossover_type == "crossover_pmx": 138 | return crossover_pmx 139 | if crossover_type == "crossover_one_point": 140 | return crossover_one_point 141 | if crossover_type == "blend_crossover": 142 | return crossover_blend 143 | -------------------------------------------------------------------------------- /autotm/algorithms_for_tuning/genetic_algorithm/mutation.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List 3 | 4 | import numpy as np 5 | 6 | 7 | def mutation_one_param( 8 | individ: List[float], 9 | low_spb: float, 10 | high_spb: float, 11 | low_spm: float, 12 | high_spm: float, 13 | low_n: int, 14 | high_n: int, 15 | low_back: int, 16 | high_back: int, 17 | low_decor: float, 18 | high_decor: float, 19 | elem_mutation_prob: float = 0.1, 20 | ): 21 | """ 22 | One-point mutation 23 | 24 | Checking the probability of mutation for each of the elements 25 | 26 | Parameters 27 | ---------- 28 | individ: List[float] 29 | Individual to be processed 30 | low_spb: float 31 | The lower possible bound for sparsity regularizer of back topics 32 | high_spb: float 33 | The higher possible bound for sparsity regularizer of back topics 34 | low_spm: float 35 | The lower possible bound for sparsity regularizer of specific topics 36 | high_spm: float 37 | The higher possible bound for sparsity regularizer of specific topics 38 | low_n: int 39 | The lower possible bound for amount of iterations between stages 40 | high_n: int 41 | The higher possible bound for amount of iterations 
between stages 42 | low_back: 43 | The lower possible bound for amount of back topics 44 | high_back: 45 | The higher possible bound for amount of back topics 46 | 47 | 48 | Returns 49 | ---------- 50 | Updated individuals with exchanged chromosome parts 51 | """ 52 | for i in range(len(individ)): 53 | if random.random() <= elem_mutation_prob: 54 | if i in [2, 3]: 55 | individ[i] = np.random.uniform(low=low_spb, high=high_spb, size=1)[0] 56 | for i in [5, 6, 8, 9]: 57 | individ[i] = np.random.uniform(low=low_spm, high=high_spm, size=1)[0] 58 | for i in [1, 4, 7, 10]: 59 | individ[i] = float(np.random.randint(low=low_n, high=high_n, size=1)[0]) 60 | for i in [11]: 61 | individ[i] = float(np.random.randint(low=low_back, high=high_back, size=1)[0]) 62 | for i in [0, 15]: 63 | individ[i] = np.random.uniform(low=low_decor, high=high_decor, size=1)[0] 64 | return individ 65 | 66 | 67 | def positioning_mutation(individ, elem_mutation_prob=0.1, **kwargs): 68 | for i in range(len(individ)): 69 | set_1 = set([i]) 70 | if random.random() <= elem_mutation_prob: 71 | if i in [0, 15]: 72 | set_2 = set([0, 15]) 73 | ix = np.random.choice(list(set_2.difference(set_1))) 74 | tmp = individ[ix] 75 | individ[ix] = individ[i] 76 | individ[i] = tmp 77 | elif i in [1, 4, 7, 10, 11]: 78 | set_2 = set([1, 4, 7, 10, 11]) 79 | ix = np.random.choice(list(set_2.difference(set_1))) 80 | tmp = individ[ix] 81 | individ[ix] = individ[i] 82 | individ[i] = tmp 83 | elif i in [2, 3]: 84 | set_2 = set([2, 3]) 85 | ix = np.random.choice(list(set_2.difference(set_1))) 86 | tmp = individ[ix] 87 | individ[ix] = individ[i] 88 | individ[i] = tmp 89 | elif i in [5, 6, 8, 9]: 90 | set_2 = set([5, 6, 8, 9]) 91 | ix = np.random.choice(list(set_2.difference(set_1))) 92 | tmp = individ[ix] 93 | individ[ix] = individ[i] 94 | individ[i] = tmp 95 | return individ 96 | 97 | 98 | def mutation_combined(individ, elem_mutation_prob=0.1, **kwargs): 99 | if random.random() <= individ[14]: # TODO: check 14th position 100 | return mutation_one_param(individ, elem_mutation_prob) 101 | else: 102 | return positioning_mutation(individ, elem_mutation_prob) 103 | pass 104 | 105 | 106 | def do_swap_in_ranges(individ, i, ranges): 107 | swap_range = next((r for r in ranges if i in r), None) 108 | if swap_range is not None: 109 | swap_range = [j for j in swap_range if i != j] 110 | j = np.random.choice(swap_range) 111 | individ[i], individ[j] = individ[j], individ[i] 112 | 113 | 114 | PSM_NEW_SWAP_RANGES = [[1, 4, 7, 10], [2, 3], [5, 6, 8, 9]] 115 | 116 | 117 | def mutation_psm_new(individ, elem_mutation_prob, **kwargs): 118 | for i in range(len(individ)): 119 | if random.random() < elem_mutation_prob: 120 | if i == 0: 121 | individ[i] = individ[15] 122 | elif i == 15: 123 | individ[i] = individ[0] 124 | else: 125 | do_swap_in_ranges(individ, i, PSM_NEW_SWAP_RANGES) 126 | return individ 127 | 128 | 129 | PSM_SWAP_RANGES = [[1, 4, 7], [2, 5], [3, 6]] 130 | 131 | 132 | def mutation_psm(individ, elem_mutation_prob=0.1, **kwargs): 133 | for i in range(len(individ)): 134 | if random.random() < elem_mutation_prob: 135 | if i == 0: 136 | individ[i] = np.random.uniform(low=1, high=100, size=1)[0] 137 | else: 138 | do_swap_in_ranges(individ, i, PSM_SWAP_RANGES) 139 | return individ 140 | 141 | 142 | def mutation(mutation_type="mutation_one_param"): 143 | if mutation_type == "mutation_one_param": 144 | return mutation_one_param 145 | if mutation_type == "combined": 146 | return mutation_combined 147 | if mutation_type == "psm": 148 | return mutation_psm 149 | if 
mutation_type == "positioning_mutation": 150 | return positioning_mutation 151 | -------------------------------------------------------------------------------- /autotm/algorithms_for_tuning/genetic_algorithm/selection.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import numpy as np 3 | import random 4 | 5 | 6 | # TODO: roulette wheel selection, stochastic universal sampling and tournament selection 7 | 8 | 9 | def yield_matching_pairs(pairs, population): 10 | population.sort(key=operator.attrgetter("fitness_value")) 11 | population_pairs_pool = [] 12 | 13 | while len(population_pairs_pool) < pairs: 14 | chosen = [] 15 | idx = 0 16 | selection_probability = random.random() 17 | for ix, individ in enumerate(population): 18 | if selection_probability <= individ._prob: 19 | idx = ix 20 | chosen.append(individ) 21 | break 22 | 23 | selection_probability = random.random() 24 | for k, individ in enumerate(population): 25 | if k != idx: 26 | if selection_probability <= individ._prob: 27 | elems = frozenset((idx, k)) 28 | if (len(population_pairs_pool) == 0) or ( 29 | (len(population_pairs_pool) > 0) 30 | and (elems not in population_pairs_pool) 31 | ): 32 | chosen.append(individ) 33 | population_pairs_pool.append(elems) 34 | break 35 | else: 36 | continue 37 | if len(chosen) == 1: 38 | selection_idx = np.random.choice( 39 | [m for m in [i for i in range(len(population))] if m != idx] 40 | ) 41 | chosen.append(population[selection_idx]) 42 | if len(chosen) == 0: 43 | yield None, None 44 | else: 45 | yield chosen[0], chosen[1] 46 | 47 | 48 | def selection_fitness_prop(population, best_proc, children_num): 49 | all_fitness = [] 50 | 51 | for individ in population: 52 | all_fitness.append(individ.fitness_value) 53 | fitness_std = np.std(all_fitness) 54 | fitness_mean = np.mean(all_fitness) 55 | cumsum_fitness = 0 56 | # adjust probabilities with sigma scaling 57 | c = 2 58 | for individ in population: 59 | updated_individ_fitness = max(individ.fitness_value - (fitness_mean - c * fitness_std), 0) 60 | cumsum_fitness += updated_individ_fitness 61 | individ._prob = updated_individ_fitness / cumsum_fitness 62 | pairs_count = len(population) * (1 - best_proc) 63 | if children_num == 2: 64 | pairs_count //= 2 65 | return yield_matching_pairs(round(pairs_count), population) 66 | 67 | 68 | def selection_rank_based(population, best_proc, children_num): 69 | population.sort(key=operator.attrgetter("fitness_value")) 70 | for ix, individ in enumerate(population): 71 | individ._prob = 2 * (ix + 1) / (len(population) * (len(population) - 1)) 72 | if children_num == 2: 73 | # new population size 74 | return yield_matching_pairs( 75 | round((len(population) * (1 - best_proc))), population 76 | ) 77 | else: 78 | return yield_matching_pairs( 79 | round((len(population) * (1 - best_proc))), population 80 | ) 81 | 82 | 83 | def stochastic_universal_sampling(): 84 | raise NotImplementedError 85 | 86 | 87 | def selection(selection_type="fitness_prop"): 88 | if selection_type == "fitness_prop": 89 | return selection_fitness_prop 90 | if selection_type == "rank_based": 91 | return selection_rank_based 92 | -------------------------------------------------------------------------------- /autotm/algorithms_for_tuning/genetic_algorithm/statistics_collector.py: -------------------------------------------------------------------------------- 1 | from autotm.algorithms_for_tuning.individuals import Individual 2 | 3 | 4 | class StatisticsCollector: 5 | """ 6 | This 
logger handles collection of statistics 7 | """ 8 | 9 | def log_iteration(self, evaluations: int, best_fitness: float): 10 | """ 11 | :param evaluations: the number of used evaluations 12 | :param best_fitness: the best fitness in the current iteration 13 | """ 14 | pass 15 | 16 | def log_individual(self, individual: Individual): 17 | """ 18 | :param individual: a new evaluated individual 19 | """ 20 | pass 21 | -------------------------------------------------------------------------------- /autotm/algorithms_for_tuning/genetic_algorithm/strategy.py: -------------------------------------------------------------------------------- 1 | # strategies of automatic EA configuration 2 | 3 | # CMA-ES (Covariance Matrix Adaptation Evolution Strategy) 4 | 5 | from math import sqrt, log 6 | 7 | # code source: https://github.com/DEAP/deap/blob/master/deap/cma.py 8 | import numpy 9 | 10 | 11 | class Strategy(object): 12 | def __init__(self, centroid, sigma, **kwargs): 13 | self.params = kwargs 14 | 15 | # Create a centroid as a numpy array 16 | self.centroid = numpy.array(centroid) 17 | 18 | self.dim = len(self.centroid) 19 | self.sigma = sigma 20 | self.pc = numpy.zeros(self.dim) 21 | self.ps = numpy.zeros(self.dim) 22 | self.chiN = sqrt(self.dim) * ( 23 | 1 - 1.0 / (4.0 * self.dim) + 1.0 / (21.0 * self.dim**2) 24 | ) 25 | 26 | self.C = self.params.get("cmatrix", numpy.identity(self.dim)) 27 | self.diagD, self.B = numpy.linalg.eigh(self.C) 28 | 29 | indx = numpy.argsort(self.diagD) 30 | self.diagD = self.diagD[indx] ** 0.5 31 | self.B = self.B[:, indx] 32 | self.BD = self.B * self.diagD 33 | 34 | self.cond = self.diagD[indx[-1]] / self.diagD[indx[0]] 35 | 36 | self.lambda_ = self.params.get("lambda_", int(4 + 3 * log(self.dim))) 37 | self.update_count = 0 38 | self.computeParams(self.params) 39 | 40 | def generate(self, ind_init): 41 | r"""Generate a population of :math:`\lambda` individuals of type 42 | *ind_init* from the current strategy. 43 | :param ind_init: A function object that is able to initialize an 44 | individual from a list. 45 | :returns: A list of individuals. 46 | """ 47 | arz = numpy.random.standard_normal((self.lambda_, self.dim)) 48 | arz = self.centroid + self.sigma * numpy.dot(arz, self.BD.T) 49 | return [ind_init(a) for a in arz] 50 | 51 | def update(self, population): 52 | """Update the current covariance matrix strategy from the 53 | *population*. 54 | :param population: A list of individuals from which to update the 55 | parameters. 
56 | """ 57 | population.sort(key=lambda ind: ind.fitness, reverse=True) 58 | 59 | old_centroid = self.centroid 60 | self.centroid = numpy.dot(self.weights, population[0: self.mu]) 61 | 62 | c_diff = self.centroid - old_centroid 63 | 64 | # Cumulation : update evolution path 65 | self.ps = (1 - self.cs) * self.ps + sqrt( 66 | self.cs * (2 - self.cs) * self.mueff 67 | ) / self.sigma * numpy.dot( 68 | self.B, (1.0 / self.diagD) * numpy.dot(self.B.T, c_diff) 69 | ) 70 | 71 | hsig = float( 72 | ( 73 | numpy.linalg.norm(self.ps) 74 | / sqrt(1.0 - (1.0 - self.cs) ** (2.0 * (self.update_count + 1.0))) 75 | / self.chiN 76 | < (1.4 + 2.0 / (self.dim + 1.0)) 77 | ) 78 | ) 79 | 80 | self.update_count += 1 81 | 82 | self.pc = (1 - self.cc) * self.pc + hsig * sqrt( 83 | self.cc * (2 - self.cc) * self.mueff 84 | ) / self.sigma * c_diff 85 | 86 | # Update covariance matrix 87 | artmp = population[0: self.mu] - old_centroid 88 | self.C = ( 89 | ( 90 | 1 91 | - self.ccov1 92 | - self.ccovmu 93 | + (1 - hsig) * self.ccov1 * self.cc * (2 - self.cc) 94 | ) 95 | * self.C 96 | + self.ccov1 * numpy.outer(self.pc, self.pc) 97 | + self.ccovmu * numpy.dot((self.weights * artmp.T), artmp) / self.sigma**2 98 | ) 99 | 100 | self.sigma *= numpy.exp( 101 | (numpy.linalg.norm(self.ps) / self.chiN - 1.0) * self.cs / self.damps 102 | ) 103 | 104 | self.diagD, self.B = numpy.linalg.eigh(self.C) 105 | indx = numpy.argsort(self.diagD) 106 | 107 | self.cond = self.diagD[indx[-1]] / self.diagD[indx[0]] 108 | 109 | self.diagD = self.diagD[indx] ** 0.5 110 | self.B = self.B[:, indx] 111 | self.BD = self.B * self.diagD 112 | 113 | def computeParams(self, params): 114 | r"""Computes the parameters depending on :math:`\lambda`. It needs to 115 | be called again if :math:`\lambda` changes during evolution. 116 | :param params: A dictionary of the manually set parameters. 
117 | """ 118 | self.mu = params.get("mu", int(self.lambda_ / 2)) 119 | rweights = params.get("weights", "superlinear") 120 | if rweights == "superlinear": 121 | self.weights = log(self.mu + 0.5) - numpy.log(numpy.arange(1, self.mu + 1)) 122 | elif rweights == "linear": 123 | self.weights = self.mu + 0.5 - numpy.arange(1, self.mu + 1) 124 | elif rweights == "equal": 125 | self.weights = numpy.ones(self.mu) 126 | else: 127 | raise RuntimeError("Unknown weights : %s" % rweights) 128 | 129 | self.weights /= sum(self.weights) 130 | self.mueff = 1.0 / sum(self.weights**2) 131 | 132 | self.cc = params.get("ccum", 4.0 / (self.dim + 4.0)) 133 | self.cs = params.get("cs", (self.mueff + 2.0) / (self.dim + self.mueff + 3.0)) 134 | self.ccov1 = params.get("ccov1", 2.0 / ((self.dim + 1.3) ** 2 + self.mueff)) 135 | self.ccovmu = params.get( 136 | "ccovmu", 137 | 2.0 * (self.mueff - 2.0 + 1.0 / self.mueff) / ((self.dim + 2.0) ** 2 + self.mueff), 138 | ) 139 | self.ccovmu = min(1 - self.ccov1, self.ccovmu) 140 | self.damps = ( 141 | 1.0 + 2.0 * max(0.0, sqrt((self.mueff - 1.0) / (self.dim + 1.0)) - 1.0) + self.cs 142 | ) 143 | self.damps = params.get("damps", self.damps) 144 | -------------------------------------------------------------------------------- /autotm/algorithms_for_tuning/genetic_algorithm/surrogate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | from sklearn.ensemble import RandomForestRegressor, BaggingRegressor 5 | from sklearn.gaussian_process import GaussianProcessRegressor 6 | from sklearn.gaussian_process.kernels import RBF, ConstantKernel, Matern, WhiteKernel, ExpSineSquared, RationalQuadratic 7 | from sklearn.metrics import mean_squared_error 8 | from sklearn.neural_network import MLPRegressor 9 | from sklearn.svm import SVR 10 | from sklearn.tree import DecisionTreeRegressor 11 | 12 | logger = logging.getLogger("GA_algo") 13 | 14 | 15 | # TODO: Add fitness type 16 | def set_surrogate_fitness(value, fitness_type="avg_coherence_score"): 17 | npmis = { 18 | "npmi_50": None, 19 | "npmi_15": None, 20 | "npmi_25": None, 21 | "npmi_50_list": None, 22 | } 23 | scores_dict = { 24 | fitness_type: value, 25 | "perplexityScore": None, 26 | "backgroundTokensRatioScore": None, 27 | "contrast": None, 28 | "purity": None, 29 | "kernelSize": None, 30 | "npmi_50_list": [None], # npmi_values_50_list, 31 | "npmi_50": None, 32 | "sparsity_phi": None, 33 | "sparsity_theta": None, 34 | "topic_significance_uni": None, 35 | "topic_significance_vacuous": None, 36 | "topic_significance_back": None, 37 | "switchP_list": [None], 38 | "switchP": None, 39 | "all_topics": None, 40 | # **coherence_scores, 41 | **npmis, 42 | } 43 | return scores_dict 44 | 45 | 46 | class Surrogate: 47 | def __init__(self, surrogate_name, **kwargs): 48 | self.name = surrogate_name 49 | self.kwargs = kwargs 50 | self.surrogate = None 51 | self.br_n_estimators = None 52 | self.br_n_jobs = None 53 | self.gpr_kernel = None 54 | 55 | def create(self): 56 | kernel = self.kwargs["gpr_kernel"] 57 | del self.kwargs["gpr_kernel"] 58 | gpr_alpha = self.kwargs["gpr_alpha"] 59 | del self.kwargs["gpr_alpha"] 60 | normalize_y = self.kwargs["normalize_y"] 61 | del self.kwargs["normalize_y"] 62 | 63 | if self.name == "random-forest-regressor": 64 | self.surrogate = RandomForestRegressor(**self.kwargs) 65 | elif self.name == "mlp-regressor": 66 | if not self.br_n_estimators: 67 | self.br_n_estimators = self.kwargs["br_n_estimators"] 68 | del 
self.kwargs["br_n_estimators"] 69 | self.br_n_jobs = self.kwargs["n_jobs"] 70 | del self.kwargs["n_jobs"] 71 | self.kwargs["alpha"] = self.kwargs["mlp_alpha"] 72 | del self.kwargs["mlp_alpha"] 73 | self.surrogate = BaggingRegressor( 74 | base_estimator=MLPRegressor(**self.kwargs), 75 | n_estimators=self.br_n_estimators, 76 | n_jobs=self.br_n_jobs, 77 | ) 78 | elif self.name == "GPR": # tune ?? 79 | if not self.gpr_kernel: 80 | if kernel == "RBF": 81 | self.gpr_kernel = 1.0 * RBF(1.0) 82 | elif kernel == "RBFwithConstant": 83 | self.gpr_kernel = 1.0 * RBF(1.0) + ConstantKernel() 84 | elif kernel == "Matern": 85 | self.gpr_kernel = 1.0 * Matern(1.0) 86 | elif kernel == "WhiteKernel": 87 | self.gpr_kernel = 1.0 * WhiteKernel(1.0) 88 | elif kernel == "ExpSineSquared": 89 | self.gpr_kernel = ExpSineSquared() 90 | elif kernel == "RationalQuadratic": 91 | self.gpr_kernel = RationalQuadratic(1.0) 92 | self.kwargs["kernel"] = self.gpr_kernel 93 | self.kwargs["alpha"] = gpr_alpha 94 | self.kwargs["normalize_y"] = normalize_y 95 | self.surrogate = GaussianProcessRegressor(**self.kwargs) 96 | elif self.name == "decision-tree-regressor": 97 | try: 98 | if self.kwargs["max_depth"] == 0: 99 | self.kwargs["max_depth"] = None 100 | except KeyError: 101 | logger.error("No max_depth") 102 | self.surrogate = DecisionTreeRegressor(**self.kwargs) 103 | elif self.name == "SVR": 104 | self.surrogate = SVR(**self.kwargs) 105 | # else: 106 | # raise Exception('Undefined surr') 107 | 108 | def fit(self, X, y): 109 | logger.debug(f"X: {X}, y: {y}") 110 | self.create() 111 | self.surrogate.fit(X, y) 112 | 113 | def score(self, X, y): 114 | r_2 = self.surrogate.score(X, y) 115 | y_pred = self.surrogate.predict(X) 116 | mse = mean_squared_error(y, y_pred) 117 | rmse = np.sqrt(mse) 118 | return r_2, mse, rmse 119 | 120 | def predict(self, X): 121 | m = self.surrogate.predict(X) 122 | return m 123 | 124 | 125 | def get_prediction_uncertanty(model, X, surrogate_name, percentile=90): 126 | interval_len = [] 127 | if surrogate_name == "random-forest-regressor": 128 | for x in range(len(X)): 129 | preds = [] 130 | for pred in model.estimators_: 131 | prediction = pred.predict(np.array(X[x]).reshape(1, -1)) 132 | preds.append(prediction[0]) 133 | err_down = np.percentile(preds, (100 - percentile) / 2.0) 134 | err_up = np.percentile(preds, 100 - (100 - percentile) / 2.0) 135 | interval_len.append(err_up - err_down) 136 | elif surrogate_name == "GPR": 137 | y_hat, y_sigma = model.predict(X, return_std=True) 138 | interval_len = list(y_sigma) 139 | elif surrogate_name == "decision-tree-regressor": 140 | raise NotImplementedError 141 | return interval_len 142 | -------------------------------------------------------------------------------- /autotm/algorithms_for_tuning/individuals.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from abc import ABC, abstractmethod 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from autotm.abstract_params import AbstractParams 9 | from autotm.schemas import IndividualDTO 10 | from autotm.utils import AVG_COHERENCE_SCORE, LLM_SCORE 11 | 12 | SPARSITY_PHI = "sparsity_phi" 13 | SPARSITY_THETA = "sparsity_theta" 14 | SWITCHP_SCORE = "switchP" 15 | DF_NAMES = {"20ng": 0, "lentaru": 1, "amazon_food": 2} 16 | 17 | METRICS_COLS = [ 18 | "avg_coherence_score", 19 | "perplexityScore", 20 | "backgroundTokensRatioScore", 21 | "avg_switchp", 22 | "coherence_10", 23 | "coherence_15", 24 | "coherence_20", 25 | "coherence_25", 26 
| "coherence_30", 27 | "coherence_35", 28 | "coherence_40", 29 | "coherence_45", 30 | "coherence_50", 31 | "coherence_55", 32 | "contrast", 33 | "purity", 34 | "kernelSize", 35 | "sparsity_phi", 36 | "sparsity_theta", 37 | "topic_significance_uni", 38 | "topic_significance_vacuous", 39 | "topic_significance_back", 40 | "npmi_15", 41 | "npmi_25", 42 | "npmi_50", 43 | ] 44 | 45 | PATH_TO_LEARNED_SCORING = "./scoring_func" 46 | 47 | 48 | class Individual(ABC): 49 | id: str 50 | 51 | @property 52 | @abstractmethod 53 | def dto(self) -> IndividualDTO: 54 | ... 55 | 56 | @property 57 | @abstractmethod 58 | def fitness_value(self) -> float: 59 | ... 60 | 61 | @property 62 | @abstractmethod 63 | def params(self) -> AbstractParams: 64 | ... 65 | 66 | 67 | class BaseIndividual(Individual, ABC): 68 | def __init__(self, dto: IndividualDTO): 69 | self._dto = dto 70 | 71 | @property 72 | def dto(self) -> IndividualDTO: 73 | return self._dto 74 | 75 | @property 76 | def params(self) -> AbstractParams: 77 | return self.dto.params 78 | 79 | 80 | class RegularFitnessIndividual(BaseIndividual): 81 | @property 82 | def fitness_value(self) -> float: 83 | return self.dto.fitness_value[AVG_COHERENCE_SCORE] 84 | 85 | 86 | class LearnedModel: 87 | def __init__(self, save_path, dataset_name): 88 | dataset_id = DF_NAMES[dataset_name] 89 | general_save_path = os.path.join(save_path, "general") 90 | native_save_path = os.path.join(save_path, "native") 91 | with open( 92 | os.path.join(general_save_path, f"general_automl_{dataset_id}.pickle"), "rb" 93 | ) as f: 94 | self.general_model = pickle.load(f) 95 | self.native_model = [] 96 | for i in range(5): 97 | with open( 98 | os.path.join( 99 | native_save_path, f"native_automl_{dataset_id}_fold_{i}.pickle" 100 | ), 101 | "rb", 102 | ) as f: 103 | self.native_model.append(pickle.load(f)) 104 | 105 | def general_predict(self, df: pd.DataFrame): 106 | y = self.general_model.predict(df[METRICS_COLS]) 107 | return y 108 | 109 | def native_predict(self, df: pd.DataFrame): 110 | y = [] 111 | for k, nm in enumerate(self.native_model): 112 | y.append(nm.predict(df[METRICS_COLS])) 113 | y = np.array(y) 114 | return np.mean(y, axis=0) 115 | 116 | 117 | class LearnedScorerFitnessIndividual(BaseIndividual): 118 | @property 119 | def fitness_value(self) -> float: 120 | # dataset_name = self.dto.dataset # TODO: check namings 121 | # m = LearnedModel(save_path=PATH_TO_LEARNED_SCORING, dataset_name=dataset_name) 122 | # TODO: predict from metrics df 123 | raise NotImplementedError() 124 | 125 | 126 | class SparsityScalerBasedFitnessIndividual(BaseIndividual): 127 | @property 128 | def fitness_value(self) -> float: 129 | # it is a handling of the situation when a fitness-worker wasn't able to correctly calculate this indvidual 130 | # due to some error in the proceess 131 | # and thus the fitness value doesn't have any metrics except dummy AVG_COHERENCE_SCORE equal to zero 132 | if self.dto.fitness_value[AVG_COHERENCE_SCORE] < 0.00000001: 133 | return 0.0 134 | 135 | alpha = 0.7 136 | if 0.2 <= self.dto.fitness_value[SPARSITY_THETA] <= 0.8: 137 | alpha = 1 138 | # if SWITCHP_SCORE in self.dto.fitness_value: 139 | # return alpha * (self.dto.fitness_value[AVG_COHERENCE_SCORE] + self.dto.fitness_value[SWITCHP_SCORE]) 140 | # else: 141 | # return alpha * self.dto.fitness_value[AVG_COHERENCE_SCORE] 142 | return alpha * self.dto.fitness_value[AVG_COHERENCE_SCORE] 143 | 144 | 145 | class LLMBasedFitnessIndividual(BaseIndividual): 146 | @property 147 | def fitness_value(self) -> float: 148 | 
return self.dto.fitness_value.get(LLM_SCORE, 0.0) 149 | 150 | 151 | class IndividualBuilder: 152 | SUPPORTED_IND_TYPES = ["regular", "sparse", "llm"] 153 | 154 | def __init__(self, ind_type: str = "regular"): 155 | self._ind_type = ind_type 156 | 157 | if self._ind_type not in self.SUPPORTED_IND_TYPES: 158 | raise ValueError(f"Unsupported ind type: {self._ind_type}") 159 | 160 | @property 161 | def individual_type(self) -> str: 162 | return self._ind_type 163 | 164 | def make_individual(self, dto: IndividualDTO) -> Individual: 165 | if self._ind_type == "regular": 166 | return RegularFitnessIndividual(dto=dto) 167 | elif self._ind_type == "sparse": 168 | return SparsityScalerBasedFitnessIndividual(dto=dto) 169 | else: 170 | return LLMBasedFitnessIndividual(dto=dto) 171 | -------------------------------------------------------------------------------- /autotm/algorithms_for_tuning/nelder_mead_optimization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/autotm/algorithms_for_tuning/nelder_mead_optimization/__init__.py -------------------------------------------------------------------------------- /autotm/algorithms_for_tuning/nelder_mead_optimization/nelder_mead.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | from scipy.optimize import minimize 5 | 6 | from autotm.fitness.tm import FitnessCalculatorWrapper 7 | 8 | 9 | class NelderMeadOptimization: 10 | def __init__( 11 | self, 12 | dataset, 13 | data_path, 14 | exp_id, 15 | topic_count, 16 | train_option, 17 | low_decor=0, 18 | high_decor=1e5, 19 | low_n=0, 20 | high_n=30, 21 | low_back=0, 22 | high_back=5, 23 | low_spb=0, 24 | high_spb=1e2, 25 | low_spm=-1e-3, 26 | high_spm=1e2, 27 | low_sp_phi=-1e3, 28 | high_sp_phi=1e3, 29 | low_prob=0, 30 | high_prob=1, 31 | ): 32 | self.dataset = (dataset,) 33 | self.data_path = data_path 34 | self.exp_id = exp_id 35 | self.topic_count = topic_count 36 | self.train_option = train_option 37 | self.high_decor = high_decor 38 | self.low_decor = low_decor 39 | self.low_n = low_n 40 | self.high_n = high_n 41 | self.low_back = low_back 42 | self.high_back = high_back 43 | self.high_spb = high_spb 44 | self.low_spb = low_spb 45 | self.low_spm = low_spm 46 | self.high_spm = high_spm 47 | self.low_sp_phi = low_sp_phi 48 | self.high_sp_phi = high_sp_phi 49 | self.low_prob = low_prob 50 | self.high_prob = high_prob 51 | 52 | def initialize_params(self): 53 | val_decor, val_decor_2 = np.random.uniform(low=self.low_decor, high=self.high_decor, size=2) 54 | var_n = np.random.randint(low=self.low_n, high=self.high_n, size=4) 55 | var_back = np.random.randint(low=self.low_back, high=self.high_back, size=1)[0] 56 | var_sm = np.random.uniform(low=self.low_spb, high=self.high_spb, size=2) 57 | var_sp = np.random.uniform(low=self.low_sp_phi, high=self.high_sp_phi, size=4) 58 | params = [ 59 | val_decor, 60 | var_n[0], 61 | var_sm[0], 62 | var_sm[1], 63 | var_n[1], 64 | var_sp[0], 65 | var_sp[1], 66 | var_n[2], 67 | var_sp[2], 68 | var_sp[3], 69 | var_n[3], 70 | var_back, 71 | val_decor_2, 72 | ] 73 | params = [float(i) for i in params] 74 | return params 75 | 76 | def run_algorithm(self, num_iterations: int = 400, ini_point: list = None): 77 | fitness_calculator = FitnessCalculatorWrapper( 78 | self.dataset, self.data_path, self.topic_count, self.train_option 79 | ) 80 | 81 | if ini_point is None: 82 | 
initial_point = self.initialize_params() 83 | else: 84 | assert len(ini_point) == 13 85 | logging.info(ini_point) # TODO: remove this 86 | ini_point = [float(i) for i in ini_point] 87 | initial_point = ini_point 88 | 89 | res = minimize( 90 | fitness_calculator.run, 91 | initial_point, 92 | bounds=[ 93 | (self.low_decor, self.high_decor), 94 | (self.low_n, self.high_n), 95 | (self.low_spb, self.high_spb), 96 | (self.low_spb, self.high_spb), 97 | (self.low_n, self.high_n), 98 | (self.low_sp_phi, self.high_sp_phi), 99 | (self.low_sp_phi, self.high_sp_phi), 100 | (self.low_n, self.high_n), 101 | (self.low_sp_phi, self.high_sp_phi), 102 | (self.low_sp_phi, self.high_sp_phi), 103 | (self.low_n, self.high_n), 104 | (self.low_back, self.high_back), 105 | (self.low_decor, self.high_decor), 106 | ], 107 | method="Nelder-Mead", 108 | options={"return_all": True, "maxiter": num_iterations}, 109 | ) 110 | 111 | return res 112 | -------------------------------------------------------------------------------- /autotm/batch_vect_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from artm.batches_utils import BatchVectorizer, Batch 3 | import glob 4 | import random 5 | 6 | 7 | class SampleBatchVectorizer(BatchVectorizer): 8 | def __init__(self, sample_size=10, **kwargs): 9 | self.sample_size = sample_size 10 | super().__init__(**kwargs) 11 | 12 | def _parse_batches(self, data_weight, batches): 13 | if self._process_in_memory: 14 | self._model.master.import_batches(batches) 15 | self._batches_list = [batch.id for batch in batches] 16 | return 17 | 18 | data_paths, data_weights, target_folders = self._populate_data( 19 | data_weight, True 20 | ) 21 | for data_p, data_w, target_f in zip(data_paths, data_weights, target_folders): 22 | if batches is None: 23 | batch_filenames = glob.glob(os.path.join(data_p, "*.batch")) 24 | if len(batch_filenames) < self.sample_size: 25 | self.sample_size = len(batch_filenames) 26 | batch_filenames = random.sample(batch_filenames, self.sample_size) 27 | self._batches_list += [Batch(filename) for filename in batch_filenames] 28 | 29 | if len(self._batches_list) < 1: 30 | raise RuntimeError("No batches were found") 31 | 32 | self._weights += [data_w for i in range(len(batch_filenames))] 33 | else: 34 | self._batches_list += [ 35 | Batch(os.path.join(data_p, batch)) for batch in batches 36 | ] 37 | self._weights += [data_w for i in range(len(batches))] 38 | -------------------------------------------------------------------------------- /autotm/clustering.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import copy 3 | import warnings 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | from sklearn.cluster import KMeans 9 | from sklearn.manifold import TSNE 10 | from sklearn.preprocessing import StandardScaler 11 | 12 | warnings.filterwarnings('ignore') 13 | 14 | 15 | def cluster_phi(phi_df: pd.DataFrame, n_clusters=10, plot_img=True): 16 | _phi_df = copy.deepcopy(phi_df) 17 | y = _phi_df.index.values 18 | x = _phi_df.values 19 | standardized_x = StandardScaler().fit_transform(x) 20 | y_kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(standardized_x) 21 | 22 | if plot_img: 23 | tsne = TSNE(n_components=2).fit_transform(standardized_x) 24 | plt.scatter(tsne[:, 0], tsne[:, 1], c=y_kmeans.labels_, s=6, cmap='Spectral') 25 | plt.gca().set_aspect('equal', 'datalim') 26 | 
plt.colorbar(boundaries=np.arange(11) - 0.5).set_ticks(np.arange(10)) 27 | plt.title('Scatterplot of lenta.ru data', fontsize=24) 28 | 29 | _phi_df['labels'] = y_kmeans.labels_ 30 | return _phi_df 31 | 32 | 33 | -------------------------------------------------------------------------------- /autotm/content_splitter.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | 4 | class BaseTextSplitter(ABC): 5 | 6 | def text_process(self): 7 | raise NotImplementedError 8 | 9 | def transform(self): 10 | raise NotImplementedError 11 | 12 | def text_split(self): 13 | raise NotImplementedError 14 | 15 | class TextSplitter(BaseTextSplitter): 16 | def content_splitting(self): 17 | raise NotImplementedError 18 | 19 | -------------------------------------------------------------------------------- /autotm/data_generator/synthetic_data_generation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "outputs": [], 7 | "source": [ 8 | "import numpy as np" 9 | ], 10 | "metadata": { 11 | "collapsed": false 12 | } 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "outputs": [], 18 | "source": [ 19 | "def data_generator(\n", 20 | " vocab_size=10000,\n", 21 | " num_documents=5000,\n", 22 | " mean_text_len=15,\n", 23 | " std_text_len=5,\n", 24 | " topics_num=10,\n", 25 | " distribution=\"dirichlet\",\n", 26 | "):\n", 27 | " vocabulary = [i for i in range(vocab_size)]\n", 28 | "\n", 29 | " for doc in num_documents:\n", 30 | " doc_len = int(np.random.normal(loc=mean_text_len, scale=std_text_len))" 31 | ], 32 | "metadata": { 33 | "collapsed": false 34 | } 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "outputs": [], 40 | "source": [], 41 | "metadata": { 42 | "collapsed": false 43 | } 44 | } 45 | ], 46 | "metadata": { 47 | "kernelspec": { 48 | "display_name": "Python 3", 49 | "language": "python", 50 | "name": "python3" 51 | }, 52 | "language_info": { 53 | "codemirror_mode": { 54 | "name": "ipython", 55 | "version": 2 56 | }, 57 | "file_extension": ".py", 58 | "mimetype": "text/x-python", 59 | "name": "python", 60 | "nbconvert_exporter": "python", 61 | "pygments_lexer": "ipython2", 62 | "version": "2.7.6" 63 | } 64 | }, 65 | "nbformat": 4, 66 | "nbformat_minor": 0 67 | } 68 | -------------------------------------------------------------------------------- /autotm/fitness/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # local or cluster 4 | SUPPORTED_EXEC_MODES = ['local', 'cluster'] 5 | AUTOTM_EXEC_MODE = os.environ.get("AUTOTM_EXEC_MODE", "local") 6 | 7 | 8 | # head or worker 9 | SUPPORTED_COMPONENTS = ['head', 'worker'] 10 | AUTOTM_COMPONENT = os.environ.get("AUTOTM_COMPONENT", "head") 11 | -------------------------------------------------------------------------------- /autotm/fitness/external_scores.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # additional scores to calculate 5 | 6 | 7 | # Topic Significance 8 | def kl_divergence(p, q): 9 | return np.sum(np.where(p != 0, p * np.log(p / q), 0)) 10 | 11 | 12 | # Uniform Distribution Over Words (W-Uniform) 13 | def ts_uniform(topic_word_dist): 14 | n_words = topic_word_dist.shape[0] 15 | w_uniform = np.ones(n_words) / n_words 16 | uniform_distances_kl = [kl_divergence(p, w_uniform) for p in 
topic_word_dist.T] 17 | return uniform_distances_kl 18 | 19 | 20 | # Vacuous Semantic Distribution (W-Vacuous) 21 | def ts_vacuous(doc_topic_dist, topic_word_dist, total_tokens): 22 | n_words = topic_word_dist.shape[0] 23 | # n_tokens = np.sum([len(text) for text in texts]) 24 | p_k = np.sum(doc_topic_dist, axis=1) / total_tokens 25 | w_vacauous = np.sum(topic_word_dist * np.tile(p_k, (n_words, 1)), axis=1) 26 | vacauous_distances_kl = [kl_divergence(p, w_vacauous) for p in topic_word_dist.T] 27 | return vacauous_distances_kl 28 | 29 | 30 | # Background Distribution (D-BGround) 31 | def ts_bground(doc_topic_dist): 32 | n_documents = doc_topic_dist.shape[1] 33 | d_bground = np.ones(n_documents) / n_documents 34 | d_bground_distances_kl = [kl_divergence(p.T, d_bground) for p in doc_topic_dist] 35 | return d_bground_distances_kl 36 | 37 | 38 | # SwitchP 39 | def switchp(phi, texts): 40 | words = phi.index.to_list() 41 | max_topic_word_dist = np.argmax(phi.to_numpy(), axis=1) 42 | max_topic_word_dist = dict(zip(words, max_topic_word_dist)) 43 | switchp_scores = [] 44 | for text in texts: 45 | mapped_text = [ 46 | max_topic_word_dist[word] 47 | for word in text.split() 48 | if word in max_topic_word_dist 49 | ] 50 | switches = (np.diff(mapped_text) != 0).sum() 51 | switchp_scores.append(switches / (len(mapped_text) - 1)) 52 | 53 | return switchp_scores 54 | -------------------------------------------------------------------------------- /autotm/fitness/local_tasks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional 3 | 4 | from autotm.algorithms_for_tuning.individuals import Individual, IndividualBuilder 5 | from autotm.fitness.tm import fit_tm_of_individual 6 | from autotm.params_logging_utils import log_params_and_artifacts, log_stats, model_files 7 | from autotm.schemas import IndividualDTO 8 | 9 | logger = logging.getLogger("root") 10 | 11 | 12 | def do_fitness_calculating( 13 | individual: IndividualDTO, 14 | log_artifact_and_parameters: bool = False, 15 | log_run_stats: bool = False, 16 | alg_args: Optional[str] = None, 17 | is_tmp: bool = False, 18 | ) -> IndividualDTO: 19 | # make copy 20 | individual_json = individual.model_dump_json() 21 | individual = IndividualDTO.model_validate_json(individual_json) 22 | logger.info("Doing fitness calculating with individual %s" % individual_json) 23 | 24 | with fit_tm_of_individual( 25 | dataset=individual.dataset, 26 | data_path=individual.data_path, 27 | params=individual.params, 28 | fitness_name=individual.fitness_name, 29 | topic_count=individual.topic_count, 30 | force_dataset_settings_checkout=individual.force_dataset_settings_checkout, 31 | train_option=individual.train_option, 32 | ) as (time_metrics, metrics, tm): 33 | individual.fitness_value = metrics 34 | 35 | with model_files(tm) as tm_files: 36 | if log_artifact_and_parameters: 37 | log_params_and_artifacts( 38 | tm, tm_files, individual, time_metrics, alg_args, is_tmp=is_tmp 39 | ) 40 | 41 | if log_run_stats: 42 | log_stats(tm, tm_files, individual, time_metrics, alg_args) 43 | 44 | return individual 45 | 46 | 47 | def calculate_fitness( 48 | individual: IndividualDTO, 49 | log_artifact_and_parameters: bool = False, 50 | log_run_stats: bool = False, 51 | alg_args: Optional[str] = None, 52 | is_tmp: bool = False, 53 | ) -> IndividualDTO: 54 | return do_fitness_calculating( 55 | individual, log_artifact_and_parameters, log_run_stats, alg_args, is_tmp 56 | ) 57 | 58 | 59 | def 
estimate_fitness(ibuilder: IndividualBuilder, population: List[Individual]) -> List[Individual]: 60 | logger.info("Calculating fitness...") 61 | population_with_fitness = [] 62 | for individual in population: 63 | if individual.dto.fitness_value is not None: 64 | logger.info("Fitness value already calculated") 65 | population_with_fitness.append(individual) 66 | continue 67 | individ_with_fitness = calculate_fitness(individual.dto) 68 | population_with_fitness.append(ibuilder.make_individual(individ_with_fitness)) 69 | logger.info("The fitness results have been obtained") 70 | return population_with_fitness 71 | 72 | 73 | def log_best_solution( 74 | ibuilder: IndividualBuilder, 75 | individual: Individual, 76 | wait_for_result_timeout: Optional[float] = None, 77 | alg_args: Optional[str] = None, 78 | is_tmp: bool = False, 79 | ): 80 | logger.info("Sending a best individual to be logged") 81 | res = ibuilder.make_individual(calculate_fitness(individual.dto, 82 | log_artifact_and_parameters=True, 83 | is_tmp=is_tmp)) 84 | 85 | # TODO: write logging 86 | return res 87 | -------------------------------------------------------------------------------- /autotm/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pprint 4 | import tempfile 5 | from contextlib import contextmanager 6 | from typing import Optional 7 | 8 | import click 9 | import pandas as pd 10 | import yaml 11 | 12 | from autotm.base import AutoTM 13 | 14 | # TODO: add proper logging format initialization 15 | logging.basicConfig(level=logging.INFO) 16 | logger = logging.getLogger() 17 | 18 | 19 | def obtain_autotm_params( 20 | config_path: str, 21 | topic_count: Optional[int], 22 | lang: Optional[str], 23 | alg: Optional[str], 24 | surrogate_alg: Optional[str], 25 | log_file: Optional[str] 26 | ): 27 | # TODO: define clearly format of config 28 | if config_path is not None: 29 | logger.info(f"Reading config from path: {os.path.abspath(config_path)}") 30 | with open(config_path, "r") as f: 31 | config = yaml.safe_load(f) 32 | else: 33 | config = dict() 34 | 35 | if topic_count is not None: 36 | config['topic_count'] = topic_count 37 | if lang is not None: 38 | pp = config.get('preprocessing_params', dict()) 39 | pp['lang'] = lang 40 | config['preprocessing_params'] = pp 41 | if alg is not None: 42 | config['alg_name'] = alg 43 | if surrogate_alg is not None: 44 | config['surrogate_alg_name'] = surrogate_alg 45 | if log_file is not None: 46 | config['log_file_path'] = log_file 47 | 48 | return config 49 | 50 | 51 | @contextmanager 52 | def prepare_working_dir(working_dir: Optional[str] = None): 53 | if working_dir is None: 54 | with tempfile.TemporaryDirectory(prefix="autotm_wd_") as tmp_working_dir: 55 | yield tmp_working_dir 56 | else: 57 | yield working_dir 58 | 59 | 60 | @click.group() 61 | @click.option('-v', '--verbose', is_flag=True, show_default=True, default=False, help="Verbose output") 62 | def cli(verbose: bool): 63 | if verbose: 64 | logging.basicConfig(level=logging.DEBUG, force=True) 65 | 66 | 67 | @cli.command() 68 | @click.option('--config', 'config_path', type=str, help="A path to config for fitting the model") 69 | @click.option( 70 | '--working-dir', 71 | type=str, 72 | help="A path to working directory used by AutoTM for storing intermediate files. " 73 | "If not specified temporary directory will be created in the current directory " 74 | "and will be deleted upon successful finishing." 
75 | ) 76 | @click.option('--in', 'in_', type=str, required=True, help="A file in csv format with text corpus to build model on") 77 | @click.option( 78 | '--out', 79 | type=str, 80 | default="mixtures.csv", 81 | help="A path to a file in csv format that will contain topic mixtures for texts from the incoming corpus" 82 | ) 83 | @click.option('--model', type=str, default='model.artm', help="A path that will contain fitted ARTM model") 84 | @click.option('-t', '--topic-count', type=int, help="Number of topics to fit model with") 85 | @click.option('--lang', type=str, help='Language of the dataset') 86 | @click.option('--alg', type=str, help="Hyperparameters tuning algorithm. Available: ga, bayes") 87 | @click.option('--surrogate-alg', type=str, help="Surrogate algorithm to use.") 88 | @click.option('--log-file', type=str, help="Log file path") 89 | @click.option( 90 | '--overwrite', 91 | is_flag=True, 92 | show_default=True, 93 | default=False, 94 | help="Overwrite if model or/and mixture files already exist" 95 | ) 96 | def fit( 97 | config_path: Optional[str], 98 | working_dir: Optional[str], 99 | in_: str, 100 | out: str, 101 | model: str, 102 | topic_count: Optional[int], 103 | lang: Optional[str], 104 | alg: Optional[str], 105 | surrogate_alg: Optional[str], 106 | log_file: Optional[str], 107 | overwrite: bool 108 | ): 109 | config = obtain_autotm_params(config_path, topic_count, lang, alg, surrogate_alg, log_file) 110 | 111 | logger.debug(f"Running AutoTM with params: {pprint.pformat(config, indent=4)}") 112 | 113 | logger.info(f"Loading data from {os.path.abspath(in_)}") 114 | df = pd.read_csv(in_) 115 | 116 | with prepare_working_dir(working_dir) as work_dir: 117 | logger.info(f"Using working directory {os.path.abspath(work_dir)} for AutoTM") 118 | autotm = AutoTM( 119 | **config, 120 | working_dir_path=work_dir 121 | ) 122 | mixtures = autotm.fit_predict(df) 123 | 124 | logger.info(f"Saving model to {os.path.abspath(model)}") 125 | autotm.save(model, overwrite=overwrite) 126 | logger.info(f"Saving mixtures to {os.path.abspath(out)}") 127 | mixtures.to_csv(out, mode='w' if overwrite else 'x') 128 | 129 | logger.info("Finished AutoTM") 130 | 131 | 132 | @cli.command() 133 | @click.option('--model', type=str, required=True, help="A path to fitted saved ARTM model") 134 | @click.option('--in', 'in_', type=str, required=True, help="A file in csv format with text corpus to build model on") 135 | @click.option( 136 | '--out', 137 | type=str, 138 | default="mixtures.csv", 139 | help="A path to a file in csv format that will contain topic mixtures for texts from the incoming corpus" 140 | ) 141 | @click.option( 142 | '--overwrite', 143 | is_flag=True, 144 | show_default=True, 145 | default=False, 146 | help="Overwrite if the mixture file already exists" 147 | ) 148 | def predict(model: str, in_: str, out: str, overwrite: bool): 149 | logger.info(f"Loading model from {os.path.abspath(model)}") 150 | autotm_loaded = AutoTM.load(model) 151 | 152 | logger.info(f"Predicting mixtures for data from {os.path.abspath(in_)}") 153 | df = pd.read_csv(in_) 154 | mixtures = autotm_loaded.predict(df) 155 | 156 | logger.info(f"Saving mixtures to {os.path.abspath(out)}") 157 | mixtures.to_csv(out, mode='w' if overwrite else 'x') 158 | 159 | 160 | if __name__ == "__main__": 161 | cli() 162 | -------------------------------------------------------------------------------- /autotm/main_fitness_worker.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 
import yaml 4 | 5 | 6 | def main(): 7 | os.environ['AUTOTM_COMPONENT'] = 'worker' 8 | os.environ['AUTOTM_EXEC_MODE'] = 'cluster' 9 | 10 | from autotm.fitness.cluster_tasks import make_celery_app 11 | from autotm.fitness.tm import TopicModelFactory 12 | 13 | if "DATASETS_CONFIG" in os.environ: 14 | with open(os.environ["DATASETS_CONFIG"], "r") as f: 15 | config = yaml.safe_load(f) 16 | dataset_settings = config["datasets"] 17 | else: 18 | dataset_settings = None 19 | 20 | TopicModelFactory.init_factory_settings( 21 | num_processors=os.getenv("NUM_PROCESSORS", None), 22 | dataset_settings=dataset_settings 23 | ) 24 | 25 | celery_app = make_celery_app() 26 | celery_app.worker_main() 27 | 28 | 29 | if __name__ == '__main__': 30 | main() 31 | -------------------------------------------------------------------------------- /autotm/make-recovery-file.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import datetime 3 | import logging 4 | import os 5 | import shutil 6 | import sys 7 | from glob import glob 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | # example: irace-log-2021-07-19T21:48:00-2b5faba3-99a8-47ba-a5cd-2c99e6877160.Rdata 13 | def fpath_to_datetime(fpath: str) -> datetime: 14 | fname = os.path.basename(fpath) 15 | dt_str = fname[len("irace-log-"): len("irace-log-") + len("YYYY-mm-ddTHH:MM:ss")] 16 | dt = datetime.datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%S") 17 | return dt 18 | 19 | 20 | if __name__ == "__main__": 21 | save_dir = sys.argv[1] 22 | logger.info(f"Looking for data files in {save_dir}") 23 | data_files = [ 24 | (fpath, fpath_to_datetime(fpath)) for fpath in glob(f"{save_dir}/*.Rdata") 25 | ] 26 | 27 | if len(data_files) == 0: 28 | logger.info( 29 | f"No data files have been found in {save_dir}. It is a normal situation. Interrupting execution." 30 | ) 31 | sys.exit(0) 32 | 33 | last_data_filepath, _ = max(data_files, key=lambda x: x[1]) 34 | recovery_filepath = os.path.join(save_dir, "recovery_rdata.checkpoint") 35 | 36 | logger.info( 37 | f"Copying last generated data file ({last_data_filepath}) as recovery file ({recovery_filepath})" 38 | ) 39 | shutil.copy(last_data_filepath, recovery_filepath) 40 | logger.info("Copying done") 41 | -------------------------------------------------------------------------------- /autotm/pipeline.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Union 3 | 4 | from pydantic import BaseModel 5 | 6 | 7 | class IntRangeDistribution(BaseModel): 8 | low: int 9 | high: int 10 | 11 | def create_value(self) -> int: 12 | return random.randint(self.low, self.high) 13 | 14 | def clip(self, value) -> int: 15 | value = int(value) 16 | return int(max(self.low, min(self.high, value))) 17 | 18 | 19 | class FloatRangeDistribution(BaseModel): 20 | low: float 21 | high: float 22 | 23 | def create_value(self) -> float: 24 | return random.uniform(self.low, self.high) 25 | 26 | def clip(self, value) -> float: 27 | return max(self.low, min(self.high, value)) 28 | 29 | 30 | class Param(BaseModel): 31 | """ 32 | Single parameter of a stage. 33 | Distribution can be unavailable after serialisation. 34 | """ 35 | name: str 36 | distribution: Union[IntRangeDistribution, FloatRangeDistribution] 37 | 38 | def create_value(self): 39 | if self.distribution is None: 40 | raise ValueError("Distribution is unavailable. 
One must restore the distribution after serialisation.") 41 | return self.distribution.create_value() 42 | 43 | 44 | class StageType(BaseModel): 45 | """ 46 | Stage template that defines params for a stage. 47 | See create_stage generator. 48 | """ 49 | name: str 50 | params: List[Param] 51 | 52 | 53 | class Stage(BaseModel): 54 | """ 55 | Stage instance with parameter values. 56 | """ 57 | stage_type: StageType 58 | values: List 59 | 60 | def __init__(self, **data): 61 | """ 62 | :param stage_type: back reference to the template for parameter mutation 63 | :param values: instance's parameter values 64 | """ 65 | super().__init__(**data) 66 | if len(self.stage_type.params) != len(self.values): 67 | raise ValueError("Number of values does not match number of parameters.") 68 | 69 | def __str__(self): 70 | return f"{self.stage_type.name}{self.values}" 71 | 72 | def clip_values(self): 73 | self.values = [param.distribution.clip(value) for param, value in zip(self.stage_type.params, self.values)] 74 | 75 | 76 | def create_stage(stage_type: StageType) -> Stage: 77 | return Stage(stage_type=stage_type, values=[param.create_value() for param in stage_type.params]) 78 | 79 | 80 | class Pipeline(BaseModel): 81 | """ 82 | List of stages that can be mutated. 83 | """ 84 | stages: List[Stage] 85 | required_params: Stage 86 | 87 | def __str__(self): 88 | return f'{str(self.required_params)} {" ".join(map(str, self.stages))}' 89 | 90 | def random_stage_index(self, with_last: bool = False): 91 | last = len(self.stages) 92 | if with_last: 93 | last += 1 94 | return random.randint(0, last - 1) 95 | 96 | def __lt__(self, other): 97 | # important for sort method usage 98 | return len(self.stages) < len(other.stages) 99 | -------------------------------------------------------------------------------- /autotm/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | # former 'ppp.csv' 2 | PREPOCESSED_DATASET_FILENAME = "dataset_processed.csv" 3 | RESERVED_TUPLE = ("_SERVICE_", "total_pairs_count") 4 | -------------------------------------------------------------------------------- /autotm/schemas.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from typing import Union 3 | 4 | from pydantic import BaseModel 5 | 6 | from autotm.params import PipelineParams, FixedListParams 7 | from autotm.utils import MetricsScores 8 | 9 | AnyParams = Union[FixedListParams, PipelineParams] 10 | 11 | 12 | class IndividualDTO(BaseModel): 13 | id: str 14 | data_path: str 15 | params: AnyParams 16 | fitness_name: str = "regular" 17 | dataset: str = "default" 18 | force_dataset_settings_checkout: bool = False 19 | fitness_value: Optional[MetricsScores] = None 20 | exp_id: Optional[int] = None 21 | alg_id: Optional[str] = None 22 | tag: Optional[str] = None 23 | iteration_id: int = 0 24 | topic_count: Optional[int] = None 25 | train_option: str = "offline" 26 | 27 | class Config: 28 | arbitrary_types_allowed = True 29 | 30 | def make_params_dict(self): 31 | return self.params.make_params_dict() 32 | 33 | -------------------------------------------------------------------------------- /autotm/visualization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/autotm/visualization/__init__.py -------------------------------------------------------------------------------- 
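A note on the pipeline building blocks above: `autotm/pipeline.py` only defines the primitives (the range distributions, `Param`, `StageType`, `Stage`, `create_stage` and `Pipeline`) without showing how they fit together. The sketch below wires them up end to end; the stage names, parameter names and value ranges are illustrative assumptions for this example only and are not taken from AutoTM's own configuration.

```python
from autotm.pipeline import (
    FloatRangeDistribution,
    IntRangeDistribution,
    Param,
    Pipeline,
    StageType,
    create_stage,
)

# Hypothetical stage templates: the names and ranges here are made up for illustration.
sparse_theta = StageType(
    name="SparseTheta",
    params=[
        Param(name="decorrelation_tau", distribution=FloatRangeDistribution(low=0.0, high=100.0)),
        Param(name="n_iterations", distribution=IntRangeDistribution(low=1, high=10)),
    ],
)
required = StageType(
    name="Required",
    params=[Param(name="basic_topics_count", distribution=IntRangeDistribution(low=1, high=5))],
)

# create_stage() samples concrete parameter values from each template.
pipeline = Pipeline(
    stages=[create_stage(sparse_theta)],
    required_params=create_stage(required),
)
print(pipeline)  # e.g. "Required[3] SparseTheta[42.7, 4]"

# If a mutation pushes values out of range, clip_values() pulls them back
# into the bounds declared by the stage template.
pipeline.stages[0].values = [250.0, 42]
pipeline.stages[0].clip_values()
print(pipeline.stages[0])  # "SparseTheta[100.0, 10]"
```

Keeping the back reference from `Stage` to its `StageType` is what makes this clipping (and parameter mutation in general) possible, since the allowed ranges always travel with the stage instance.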
/cluster/bin/fitnessctl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -ex 4 | 5 | if [[ -z "${KUBE_NAMESPACE}" ]] 6 | then 7 | kubectl_args="" 8 | else 9 | kubectl_args="-n ${KUBE_NAMESPACE}" 10 | fi 11 | 12 | release_name="autotm" 13 | 14 | registry="node2.bdcl:5000" 15 | mlflow_image="${registry}/mlflow-webserver:latest" 16 | flower_image="${registry}/flower:latest" 17 | fitness_worker_image="${registry}/fitness-worker:latest" 18 | jupyter_image="${registry}/autotm-jupyter:latest" 19 | 20 | base_dir="./cluster" 21 | docker_files_dir="${base_dir}/docker/" 22 | chart_path="${base_dir}/charts/autotm" 23 | 24 | function build_app() { 25 | echo "Building app..." 26 | 27 | poetry export --without-hashes > requirements.txt 28 | poetry build 29 | 30 | echo "Finished app building" 31 | } 32 | 33 | function build_images(){ 34 | echo "Building images..." 35 | 36 | build_app 37 | 38 | docker build -f "${docker_files_dir}/mlflow-webserver.dockerfile" -t ${mlflow_image} . 39 | docker build -f "${docker_files_dir}/flower.dockerfile" -t ${flower_image} . 40 | docker build -f "${docker_files_dir}/worker.dockerfile" -t ${fitness_worker_image} . 41 | docker build -f "${docker_files_dir}/jupyter.dockerfile" -t ${jupyter_image} . 42 | 43 | echo "Finished images building" 44 | } 45 | 46 | function push_images() { 47 | echo "Pushing images..." 48 | 49 | docker push ${mlflow_image} 50 | docker push ${flower_image} 51 | docker push ${fitness_worker_image} 52 | docker push ${jupyter_image} 53 | 54 | echo "Finished pushing images" 55 | } 56 | 57 | function install() { 58 | echo "Installing images..." 59 | 60 | build_images 61 | push_images 62 | 63 | echo "Finished installing images" 64 | } 65 | 66 | function create_pv() { 67 | kubectl ${kubectl_args} apply -f "${base_dir}/conf/pv.yaml" 68 | } 69 | 70 | function delete_pv() { 71 | kubectl ${kubectl_args} delete -f "${base_dir}/conf/pv.yaml" --ignore-not-found 72 | } 73 | 74 | function recreate_pv() { 75 | delete_pv 76 | create_pv 77 | } 78 | 79 | function install_chart() { 80 | helm install \ 81 | --create-namespace --namespace="${release_name}" \ 82 | --set=autotm_prefix='autotm' --set=mongo_enabled="false" \ 83 | ${release_name} ${chart_path} 84 | } 85 | 86 | function uninstall_chart() { 87 | helm uninstall autotm --ignore-not-found --wait 88 | } 89 | 90 | function upgrade_chart() { 91 | # helm upgrade --namespace="${release_name}" --reuse-values autotm ${chart_path} 92 | helm upgrade --namespace="${release_name}" autotm ${chart_path} 93 | } 94 | 95 | function dry_run_chart() { 96 | helm install \ 97 | --create-namespace --namespace=${release_name} \ 98 | --dry-run=server \ 99 | --set=autotm_prefix='autotm' --set=mongo_enabled="false" \ 100 | ${release_name} ${chart_path} --wait 101 | } 102 | 103 | function help() { 104 | echo " 105 | Supported env variables: 106 | KUBE_NAMESPACE - a kubernetes namespace to make actions in 107 | 108 | List of commands. 
109 | build-app - build the app as a .whl distribution 110 | build-images - build all required docker images 111 | push-images - push all required docker images to the private registry on node2.bdcl 112 | install-images - build-images and push-images 113 | create-pv - creates persistent volumes required for functioning of the chart 114 | install-chart - install autotm deployment 115 | uninstall-chart - uninstall autotm deployment 116 | upgrade-chart - upgrade autotm deployment 117 | help - prints this message 118 | " 119 | } 120 | 121 | function main () { 122 | cmd="$1" 123 | 124 | if [ -z "${cmd}" ] 125 | then 126 | echo "No command is provided." 127 | help 128 | exit 1 129 | fi 130 | 131 | shift 1 132 | 133 | echo "Executing command: ${cmd}" 134 | 135 | case "${cmd}" in 136 | "build-app") 137 | build_app 138 | ;; 139 | 140 | "build-images") 141 | build_images 142 | ;; 143 | 144 | "push-images") 145 | push_images 146 | ;; 147 | 148 | "install-images") 149 | install 150 | ;; 151 | 152 | "create-pv") 153 | create_pv 154 | ;; 155 | 156 | "delete-pv") 157 | delete_pv 158 | ;; 159 | 160 | "recreate-pv") 161 | recreate_pv 162 | ;; 163 | 164 | "install-chart") 165 | install_chart 166 | ;; 167 | 168 | "uninstall-chart") 169 | uninstall_chart 170 | ;; 171 | 172 | "upgrade-chart") 173 | upgrade_chart 174 | ;; 175 | 176 | "dry-run-chart") 177 | dry_run_chart 178 | ;; 179 | 180 | "help") 181 | help 182 | ;; 183 | 184 | *) 185 | echo "Unknown command: ${cmd}" 186 | ;; 187 | 188 | esac 189 | } 190 | 191 | main "${@}" 192 | -------------------------------------------------------------------------------- /cluster/charts/autotm/.helmignore: -------------------------------------------------------------------------------- 1 | # Patterns to ignore when building packages. 2 | # This supports shell glob matching, relative path matching, and 3 | # negation (prefixed with !). Only one pattern per line. 4 | .DS_Store 5 | # Common VCS dirs 6 | .git/ 7 | .gitignore 8 | .bzr/ 9 | .bzrignore 10 | .hg/ 11 | .hgignore 12 | .svn/ 13 | # Common backup files 14 | *.swp 15 | *.bak 16 | *.tmp 17 | *.orig 18 | *~ 19 | # Various IDEs 20 | .project 21 | .idea/ 22 | *.tmproj 23 | .vscode/ 24 | -------------------------------------------------------------------------------- /cluster/charts/autotm/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: autotm 3 | description: A Helm chart for AutoTM 4 | 5 | # A chart can be either an 'application' or a 'library' chart. 6 | # 7 | # Application charts are a collection of templates that can be packaged into versioned archives 8 | # to be deployed. 9 | # 10 | # Library charts provide useful utilities or functions for the chart developer. They're included as 11 | # a dependency of application charts to inject those utilities and functions into the rendering 12 | # pipeline. Library charts do not define any templates and therefore cannot be deployed. 13 | type: application 14 | 15 | # This is the chart version. This version number should be incremented each time you make changes 16 | # to the chart and its templates, including the app version. 17 | # Versions are expected to follow Semantic Versioning (https://semver.org/) 18 | version: 0.1.0 19 | 20 | # This is the version number of the application being deployed. This version number should be 21 | # incremented each time you make changes to the application. Versions are not expected to 22 | # follow Semantic Versioning. 
They should reflect the version the application is using. 23 | # It is recommended to use it with quotes. 24 | appVersion: "1.0.0" 25 | -------------------------------------------------------------------------------- /cluster/charts/autotm/templates/_helpers.tpl: -------------------------------------------------------------------------------- 1 | {{/* 2 | Expand the name of the chart. 3 | */}} 4 | {{- define "st-workspace.name" -}} 5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} 6 | {{- end }} 7 | 8 | {{/* 9 | Create a default fully qualified app name. 10 | We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). 11 | If release name contains chart name it will be used as a full name. 12 | */}} 13 | {{- define "st-workspace.fullname" -}} 14 | {{- if .Values.fullnameOverride }} 15 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} 16 | {{- else }} 17 | {{- $name := default .Chart.Name .Values.nameOverride }} 18 | {{- if contains $name .Release.Name }} 19 | {{- .Release.Name | trunc 63 | trimSuffix "-" }} 20 | {{- else }} 21 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} 22 | {{- end }} 23 | {{- end }} 24 | {{- end }} 25 | 26 | {{/* 27 | Create chart name and version as used by the chart label. 28 | */}} 29 | {{- define "st-workspace.chart" -}} 30 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} 31 | {{- end }} 32 | 33 | {{/* 34 | Common labels 35 | */}} 36 | {{- define "st-workspace.labels" -}} 37 | helm.sh/chart: {{ include "st-workspace.chart" . }} 38 | {{ include "st-workspace.selectorLabels" . }} 39 | {{- if .Chart.AppVersion }} 40 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} 41 | {{- end }} 42 | app.kubernetes.io/managed-by: {{ .Release.Service }} 43 | {{- end }} 44 | 45 | {{/* 46 | Selector labels 47 | */}} 48 | {{- define "st-workspace.selectorLabels" -}} 49 | app.kubernetes.io/name: {{ include "st-workspace.name" . }} 50 | app.kubernetes.io/instance: {{ .Release.Name }} 51 | {{- end }} 52 | 53 | {{/* 54 | Create the name of the service account to use 55 | */}} 56 | {{- define "st-workspace.serviceAccountName" -}} 57 | {{- if .Values.serviceAccount.create }} 58 | {{- default (include "st-workspace.fullname" .) .Values.serviceAccount.name }} 59 | {{- else }} 60 | {{- default "default" .Values.serviceAccount.name }} 61 | {{- end }} 62 | {{- end }} 63 | 64 | {{/* 65 | Helper function to define prefix for all entities 66 | */}} 67 | {{- define "autotm.prefix" -}} 68 | {{- $prefix := default .Values.autotm_prefix .Release.Name | trunc 16 -}} 69 | {{- ternary $prefix (printf "%s-" $prefix) (empty $prefix) -}} 70 | {{- end -}} 71 | 72 | {{/* 73 | Mlflow db url 74 | */}} 75 | {{- define "autotm.mlflow_db_url" -}} 76 | {{- if .Values.mlflow_enabled -}} 77 | {{- printf "mysql+pymysql://%s:%s@%smlflow-db:3306/%s" (.Values.mlflow_mysql_user) (.Values.mlflow_mysql_password) (include "autotm.prefix" . ) (.Values.mlflow_mysql_database) -}} 78 | {{- else -}} 79 | "" 80 | {{- end -}} 81 | {{- end -}} 82 | 83 | {{/* 84 | Mongo url 85 | */}} 86 | {{- define "autotm.mongo_url" -}} 87 | {{- if .Values.mongo_enabled -}} 88 | {{- printf "mongodb://%s:%s@%smongo-tm-experiments-db:27017" (.Values.mongo_user) (.Values.mongo_password) (include "autotm.prefix" .) 
-}} 89 | {{- else -}} 90 | "" 91 | {{- end -}} 92 | {{- end -}} 93 | 94 | {{/* 95 | Celery broker url 96 | */}} 97 | {{- define "autotm.celery_broker_url" -}} 98 | {{- printf "amqp://guest:guest@%srabbitmq-service:5672" (include "autotm.prefix" .) -}} 99 | {{- end -}} 100 | 101 | {{/* 102 | Celery result backend url 103 | */}} 104 | {{- define "autotm.celery_result_backend" -}} 105 | {{- printf "redis://%sredis:6379/1" (include "autotm.prefix" .) -}} 106 | {{- end -}} 107 | 108 | {{/* 109 | Check if required persistent volume exists. 110 | One should invoke it like '{{ include "autotm.find_pv" (list . "pv_mlflow_db") }}', where 'pv_mlflow_db' is a variable from values.yaml 111 | See: 112 | - https://stackoverflow.com/questions/70803242/helm-pass-single-string-to-a-template 113 | - https://austindewey.com/2021/06/02/writing-function-based-templates-in-helm/ 114 | */}} 115 | {{- define "autotm.find_pv" -}} 116 | {{- $root := index . 0 -}} 117 | {{- $pv_var := index . 1 -}} 118 | {{- $pv_name := (printf "%s%s" (include "autotm.prefix" $root) (get $root.Values $pv_var)) -}} 119 | {{- $result := (lookup "v1" "PersistentVolume" "" $pv_name) -}} 120 | {{- if empty $result -}} 121 | {{- (printf "Persistent volume with name '%s' not found" $pv_name) | fail -}} 122 | {{- else -}} 123 | {{- $result.metadata.name -}} 124 | {{- end -}} 125 | {{- end -}} 126 | -------------------------------------------------------------------------------- /cluster/charts/autotm/templates/configmaps.yaml: -------------------------------------------------------------------------------- 1 | {{ if .Values.mlflow_enabled }} 2 | ################## 3 | #### MLFlow 4 | ################## 5 | --- 6 | apiVersion: v1 7 | kind: ConfigMap 8 | metadata: 9 | name: {{ include "autotm.prefix" . }}mlflow-mysql-cnf 10 | labels: 11 | {{- range $key, $val := .Values.required_labels }} 12 | {{ $key }}: {{ $val | quote }} 13 | {{- end }} 14 | data: 15 | my.cnf: | 16 | [mysqld] 17 | max_connections=500 18 | wait_timeout=15 19 | interactive_timeout=15 20 | {{ end }} 21 | ################## 22 | #### Celery & Fitness Workers 23 | ################## 24 | --- 25 | apiVersion: v1 26 | kind: ConfigMap 27 | metadata: 28 | name: {{ include "autotm.prefix" . }}rabbitmq-config 29 | labels: 30 | {{- range $key, $val := .Values.required_labels }} 31 | {{ $key }}: {{ $val | quote }} 32 | {{- end }} 33 | data: 34 | consumer-settings.conf: | 35 | ## Consumer timeout 36 | ## If a message delivered to a consumer has not been acknowledge before this timer 37 | ## triggers the channel will be force closed by the broker. This ensure that 38 | ## faultly consumers that never ack will not hold on to messages indefinitely. 39 | ## 40 | consumer_timeout = 1800000 41 | --- 42 | apiVersion: v1 43 | kind: ConfigMap 44 | metadata: 45 | name: {{ include "autotm.prefix" . }}worker-config 46 | labels: 47 | {{- range $key, $val := .Values.required_labels }} 48 | {{ $key }}: {{ $val | quote }} 49 | {{- end }} 50 | data: 51 | datasets-config.yaml: | 52 | {{ .Values.worker_datasets_config_content | indent 4}} 53 | -------------------------------------------------------------------------------- /cluster/charts/autotm/templates/pvc.yaml: -------------------------------------------------------------------------------- 1 | {{ if .Values.pvc_create_enabled }} 2 | {{ if .Values.mlflow_enabled }} 3 | ################## 4 | #### MLFlow 5 | ################## 6 | --- 7 | apiVersion: v1 8 | kind: PersistentVolumeClaim 9 | metadata: 10 | name: {{ include "autotm.prefix" . 
}}mlflow-db-pvc 11 | labels: 12 | {{- range $key, $val := .Values.required_labels }} 13 | {{ $key }}: {{ $val | quote }} 14 | {{- end }} 15 | spec: 16 | storageClassName: {{ .Values.storage_class }} 17 | volumeName: {{ include "autotm.find_pv" (list . "pv_mlflow_db") }} 18 | accessModes: 19 | - ReadWriteOnce 20 | resources: 21 | requests: 22 | storage: 25Gi 23 | --- 24 | apiVersion: v1 25 | kind: PersistentVolumeClaim 26 | metadata: 27 | name: {{ include "autotm.prefix" . }}mlflow-artifact-store-pvc 28 | labels: 29 | {{- range $key, $val := .Values.required_labels }} 30 | {{ $key }}: {{ $val | quote }} 31 | {{- end }} 32 | spec: 33 | storageClassName: {{ .Values.storage_class }} 34 | volumeName: {{ include "autotm.find_pv" (list . "pv_mlflow_artifact_store") }} 35 | accessModes: 36 | - ReadWriteMany 37 | resources: 38 | requests: 39 | storage: 25Gi 40 | {{ end }} 41 | {{ if .Values.mongo_enabled }} 42 | ################## 43 | #### Mongo 44 | ################## 45 | --- 46 | apiVersion: v1 47 | kind: PersistentVolumeClaim 48 | metadata: 49 | name: {{ include "autotm.prefix" . }}mongo-tm-experiments-pvc 50 | labels: 51 | {{- range $key, $val := .Values.required_labels }} 52 | {{ $key }}: {{ $val | quote }} 53 | {{- end }} 54 | spec: 55 | storageClassName: {{ .Values.storage_class }} 56 | volumeName: {{ include "autotm.find_pv" (list . "pv_mongo_db") }} 57 | accessModes: 58 | - ReadWriteOnce 59 | resources: 60 | requests: 61 | storage: 50Gi 62 | {{ end }} 63 | ################### 64 | ##### Celery & Fitness Workers 65 | ################### 66 | --- 67 | apiVersion: v1 68 | kind: PersistentVolumeClaim 69 | metadata: 70 | name: {{ include "autotm.prefix" . }}datasets 71 | labels: 72 | {{- range $key, $val := .Values.required_labels }} 73 | {{ $key }}: {{ $val | quote }} 74 | {{- end }} 75 | spec: 76 | storageClassName: {{ .Values.storage_class }} 77 | volumeName: {{ include "autotm.find_pv" (list . "pv_dataset_store") }} 78 | accessModes: 79 | - ReadOnlyMany 80 | resources: 81 | requests: 82 | storage: 50Gi 83 | {{ end }} 84 | -------------------------------------------------------------------------------- /cluster/charts/autotm/templates/services.yaml: -------------------------------------------------------------------------------- 1 | {{ if .Values.jupyter_enabled }} 2 | ################## 3 | #### Jupyter 4 | ################## 5 | --- 6 | apiVersion: v1 7 | kind: Service 8 | metadata: 9 | name: {{ include "autotm.prefix" . }}jupyter 10 | # labels: 11 | # {{- range $key, $val := .Values.required_labels }} 12 | # {{ $key }}:{{ $val | quote }} 13 | # {{- end }} 14 | spec: 15 | type: NodePort 16 | ports: 17 | - port: 8888 18 | protocol: TCP 19 | name: jupyter 20 | - port: 4040 21 | protocol: TCP 22 | name: spark 23 | selector: 24 | app: {{ include "autotm.prefix" . }}jupyter 25 | {{ end }} 26 | {{ if .Values.mlflow_enabled }} 27 | ################## 28 | #### MLFlow 29 | ################## 30 | --- 31 | apiVersion: v1 32 | kind: Service 33 | metadata: 34 | name: {{ include "autotm.prefix" . }}mlflow 35 | labels: 36 | {{- range $key, $val := .Values.required_labels }} 37 | {{ $key }}: {{ $val | quote }} 38 | {{- end }} 39 | spec: 40 | type: NodePort 41 | ports: 42 | - port: 5000 43 | selector: 44 | app: {{ include "autotm.prefix" . }}mlflow 45 | --- 46 | apiVersion: v1 47 | kind: Service 48 | metadata: 49 | name: {{ include "autotm.prefix" . 
}}mlflow-db 50 | labels: 51 | {{- range $key, $val := .Values.required_labels }} 52 | {{ $key }}: {{ $val | quote }} 53 | {{- end }} 54 | spec: 55 | ports: 56 | - port: 3306 57 | selector: 58 | app: {{ include "autotm.prefix" . }}mlflow-db 59 | {{ end }} 60 | {{ if .Values.mongo_enabled }} 61 | ################## 62 | #### Mongo 63 | ################## 64 | --- 65 | apiVersion: v1 66 | kind: Service 67 | metadata: 68 | name: {{ include "autotm.prefix" . }}mongo-tm-experiments-db 69 | labels: 70 | {{- range $key, $val := .Values.required_labels }} 71 | {{ $key }}: {{ $val | quote }} 72 | {{- end }} 73 | spec: 74 | type: NodePort 75 | ports: 76 | - port: 27017 77 | selector: 78 | app: {{ include "autotm.prefix" . }}mongo-tm-experiments-db 79 | --- 80 | apiVersion: v1 81 | kind: Service 82 | metadata: 83 | name: {{ include "autotm.prefix" . }}mongo-express 84 | labels: 85 | {{- range $key, $val := .Values.required_labels }} 86 | {{ $key }}: {{ $val | quote }} 87 | {{- end }} 88 | spec: 89 | type: NodePort 90 | ports: 91 | - port: 8081 92 | selector: 93 | app: {{ include "autotm.prefix" . }}mongo-express-tm-experiments 94 | {{ end }} 95 | ################## 96 | #### Celery & Fitness Workers 97 | ################## 98 | --- 99 | apiVersion: v1 100 | kind: Service 101 | metadata: 102 | name: {{ include "autotm.prefix" . }}rabbitmq-service 103 | labels: 104 | {{- range $key, $val := .Values.required_labels }} 105 | {{ $key }}: {{ $val | quote }} 106 | {{- end }} 107 | spec: 108 | type: NodePort 109 | ports: 110 | - port: 5672 111 | selector: 112 | app: {{ include "autotm.prefix" . }}rabbitmq 113 | --- 114 | apiVersion: v1 115 | kind: Service 116 | metadata: 117 | name: {{ include "autotm.prefix" . }}redis 118 | labels: 119 | {{- range $key, $val := .Values.required_labels }} 120 | {{ $key }}: {{ $val | quote }} 121 | {{- end }} 122 | spec: 123 | type: NodePort 124 | ports: 125 | - port: 6379 126 | selector: 127 | app: {{ include "autotm.prefix" . }}redis 128 | --- 129 | apiVersion: v1 130 | kind: Service 131 | metadata: 132 | name: {{ include "autotm.prefix" . }}celery-flower-service 133 | labels: 134 | {{- range $key, $val := .Values.required_labels }} 135 | {{ $key }}: {{ $val | quote }} 136 | {{- end }} 137 | spec: 138 | type: NodePort 139 | ports: 140 | - port: 5555 141 | selector: 142 | app: {{ include "autotm.prefix" . 
}}celery-flower 143 | -------------------------------------------------------------------------------- /cluster/charts/autotm/values.yaml: -------------------------------------------------------------------------------- 1 | ### General section 2 | autotm_prefix: "" 3 | required_labels: 4 | owner: "autotm" 5 | 6 | ### Images 7 | # pull policy 8 | pull_policy: "IfNotPresent" 9 | worker_image_pull_policy: "Always" 10 | # images 11 | mysql_image: "mysql/mysql-server:5.7.28" 12 | phpmyadmin_image: "phpmyadmin:5.1.1" 13 | mongo_image: "mongo:4.4.6-bionic" 14 | mongoexpress_image: "mongo-express:latest" 15 | rabbitmq_image: "node2.bdcl:5000/rabbitmq:3.8-management-alpine" 16 | redis_image: "node2.bdcl:5000/redis:6.2" 17 | mlflow_image: "node2.bdcl:5000/mlflow-webserver:latest" 18 | flower_image: "node2.bdcl:5000/flower:latest" 19 | worker_image: "node2.bdcl:5000/fitness-worker:latest" 20 | jupyter_image: "node2.bdcl:5000/autotm-jupyter:latest" 21 | 22 | ### Volumes 23 | pvc_create_enabled: "true" 24 | storage_class: "manual" 25 | pv_mlflow_db: "mlflow-db-pv" 26 | pv_mlflow_artifact_store: "mlflow-artifact-store-pv" 27 | pv_mongo_db: "mongo-tm-experiments-pv" 28 | pv_dataset_store: "datasets" 29 | 30 | ### Jupyter 31 | jupyter_enabled: "true" 32 | jupyter_cpu_limits: "4" 33 | jupyter_mem_limits: "16Gi" 34 | 35 | ### Mlflow 36 | mlflow_enabled: "true" 37 | mlflow_mysql_database: "mlflow" 38 | mlflow_mysql_user: "mlflow" 39 | mlflow_mysql_password: "mlflow" 40 | mlflow_mysql_root_password: "mlflow" 41 | 42 | ### Mongo 43 | mongo_enabled: "false" 44 | mongo_user: "mongoadmin" 45 | mongo_password: "secret" 46 | 47 | ### Fitness Worker settings 48 | worker_datasets_dir_path: "/storage" 49 | worker_count: "1" 50 | worker_cpu: "4" 51 | worker_mem: "12Gi" 52 | worker_mongo_collection: "tm_stats" 53 | worker_datasets_config_content: | 54 | datasets: 55 | # the rest settings will be the same as for the first dataset but will be added automatically 56 | books_stroyitelstvo_2030: 57 | base_path: "/storage/books_stroyitelstvo_2030_sample" 58 | topic_count: 10 59 | 60 | 20newsgroups: 61 | base_path: "/storage/20newsgroups_sample" 62 | topic_count: 20 63 | 64 | clickhouse_issues: 65 | base_path: "/storage/clickhouse_issues_sample" 66 | labels: yes 67 | topic_count: 50 68 | 69 | # the rest settings will be the same as for the first dataset but will be added automatically 70 | banners: 71 | base_path: "/storage/banners_sample" 72 | topic_count: 20 73 | 74 | amazon_food: 75 | base_path: "/storage/amazon_food_sample" 76 | topic_count: 20 77 | 78 | hotel-reviews: 79 | base_path: "/storage/hotel-reviews_sample" 80 | topic_count: 20 81 | 82 | lenta_ru: 83 | base_path: "/storage/lenta_ru_sample" 84 | topic_count: 20 85 | -------------------------------------------------------------------------------- /cluster/conf/pv.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: v1 3 | kind: PersistentVolume 4 | metadata: 5 | name: autotm-mlflow-db-pv 6 | spec: 7 | storageClassName: manual 8 | capacity: 9 | storage: 25Gi 10 | volumeMode: Filesystem 11 | accessModes: 12 | - ReadWriteOnce 13 | persistentVolumeReclaimPolicy: Retain 14 | local: 15 | path: /data/hdd/autotm-mlflow-db 16 | nodeAffinity: 17 | required: 18 | nodeSelectorTerms: 19 | - matchExpressions: 20 | - key: kubernetes.io/hostname 21 | operator: In 22 | values: 23 | - node5.bdcl 24 | --- 25 | apiVersion: v1 26 | kind: PersistentVolume 27 | metadata: 28 | name: autotm-mlflow-artifact-store-pv 29 | spec: 30 | 
storageClassName: manual 31 | capacity: 32 | storage: 50Gi 33 | volumeMode: Filesystem 34 | accessModes: 35 | - ReadWriteMany 36 | persistentVolumeReclaimPolicy: Retain 37 | hostPath: 38 | path: /mnt/ess_storage/DN_1/storage/home/khodorchenko/mlflow-tm-experiments/autotm-mlflow-artifact-store 39 | type: DirectoryOrCreate 40 | --- 41 | apiVersion: v1 42 | kind: PersistentVolume 43 | metadata: 44 | name: autotm-mongo-tm-experiments-pv 45 | spec: 46 | storageClassName: manual 47 | capacity: 48 | storage: 50Gi 49 | volumeMode: Filesystem 50 | accessModes: 51 | - ReadWriteOnce 52 | persistentVolumeReclaimPolicy: Retain 53 | local: 54 | path: /data/hdd/autotm-mongo-db-tm-experiments 55 | nodeAffinity: 56 | required: 57 | nodeSelectorTerms: 58 | - matchExpressions: 59 | - key: kubernetes.io/hostname 60 | operator: In 61 | values: 62 | - node5.bdcl 63 | --- 64 | apiVersion: v1 65 | kind: PersistentVolume 66 | metadata: 67 | name: autotm-datasets 68 | spec: 69 | storageClassName: manual 70 | capacity: 71 | storage: 50Gi 72 | volumeMode: Filesystem 73 | accessModes: 74 | - ReadOnlyMany 75 | persistentVolumeReclaimPolicy: Retain 76 | hostPath: 77 | path: /mnt/ess_storage/DN_1/storage/home/khodorchenko/GOTM/datasets_TM_scoring 78 | type: Directory 79 | -------------------------------------------------------------------------------- /cluster/docker/flower.dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:alpine 2 | 3 | # Get latest root certificates 4 | RUN apk add --no-cache ca-certificates && update-ca-certificates 5 | 6 | # Install the required packages 7 | RUN pip install --no-cache-dir redis flower==1.0.0 8 | 9 | # PYTHONUNBUFFERED: Force stdin, stdout and stderr to be totally unbuffered. (equivalent to `python -u`) 10 | # PYTHONHASHSEED: Enable hash randomization (equivalent to `python -R`) 11 | # PYTHONDONTWRITEBYTECODE: Do not write byte files to disk, since we maintain it as readonly. 
(equivalent to `python -B`) 12 | ENV PYTHONUNBUFFERED=1 PYTHONHASHSEED=random PYTHONDONTWRITEBYTECODE=1 13 | 14 | # Default port 15 | EXPOSE 5555 16 | 17 | ENV FLOWER_DATA_DIR /data 18 | ENV PYTHONPATH ${FLOWER_DATA_DIR} 19 | 20 | WORKDIR $FLOWER_DATA_DIR 21 | 22 | # Add a user with an explicit UID/GID and create necessary directories 23 | RUN set -eux; \ 24 | addgroup -g 1000 flower; \ 25 | adduser -u 1000 -G flower flower -D; \ 26 | mkdir -p "$FLOWER_DATA_DIR"; \ 27 | chown flower:flower "$FLOWER_DATA_DIR" 28 | USER flower 29 | 30 | VOLUME $FLOWER_DATA_DIR 31 | 32 | # for '-A distributed_fitness' see kube_fitness.tasks.make_celery_app 33 | ENTRYPOINT celery flower --address=0.0.0.0 --port=5555 -------------------------------------------------------------------------------- /cluster/docker/jupyter.dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9 2 | 3 | RUN apt-get update && apt-get install -y bash python3 python3-pip 4 | 5 | COPY requirements.txt /tmp/ 6 | 7 | RUN pip3 install -r /tmp/requirements.txt 8 | 9 | COPY dist/autotm-0.1.0-py3-none-any.whl /tmp 10 | 11 | RUN pip3 install --no-deps /tmp/autotm-0.1.0-py3-none-any.whl 12 | 13 | RUN pip3 install jupyter 14 | 15 | RUN pip3 install datasets 16 | 17 | RUN python -m spacy download en_core_web_sm 18 | 19 | COPY data /root 20 | 21 | COPY examples/autotm_demo_updated.ipynb /root 22 | 23 | ENTRYPOINT ["jupyter", "notebook", "--notebook-dir=/root", "--ip=0.0.0.0", "--port=8888", "--allow-root", "--no-browser", "--NotebookApp.token=''"] 24 | -------------------------------------------------------------------------------- /cluster/docker/mlflow-webserver.dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8 2 | 3 | RUN pip install PyMySQL==1.1.0 sqlalchemy==2.0.23 psycopg2-binary==2.9.9 protobuf==3.20.0 mlflow[extras]==2.9.2 4 | 5 | ENTRYPOINT ["mlflow", "server"] 6 | -------------------------------------------------------------------------------- /cluster/docker/worker.dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | 3 | RUN apt-get update && apt-get install -y bash python3 python3-pip 4 | 5 | #RUN pip3 install 'celery == 4.4.7' 'bigartm == 0.9.2' 'tqdm == 4.50.2' 'numpy == 1.19.2' 'dataclasses-json == 0.5.2' 6 | 7 | COPY requirements.txt /tmp/ 8 | 9 | RUN pip3 install -r /tmp/requirements.txt 10 | 11 | COPY dist/autotm-0.1.0-py3-none-any.whl /tmp 12 | 13 | RUN pip3 install --no-deps /tmp/autotm-0.1.0-py3-none-any.whl 14 | 15 | ENTRYPOINT fitness-worker \ 16 | --concurrency 1 \ 17 | --queues fitness_tasks \ 18 | --loglevel INFO 19 | -------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | preprocessing_params: 3 | lang: "ru" 4 | min_tokens_count: 3 5 | 6 | alg_name: "ga" 7 | alg_params: 8 | num_iterations: 2 9 | num_individuals: 4 10 | use_nelder_mead_in_mutation: no 11 | use_nelder_mead_in_crossover: no 12 | use_nelder_mead_in_selector: no 13 | train_option: "offline" 14 | -------------------------------------------------------------------------------- /distributed/README.md: -------------------------------------------------------------------------------- 1 | Before you start: 2 | 1. 
One should remember that there is a 'help' command that describes various possible operations: 3 | ``` 4 | ./bin/fitnessctl help 5 | ``` 6 | 7 | 2. All operations with 'fitnessctl' can be executed in different namespaces. 8 | One only needs to set the 'KUBE_NAMESPACE' variable. 9 | ``` 10 | KUBE_NAMESPACE=custom_namespace ./bin/fitnessctl deploy 11 | ``` 12 | 13 | To run the kube-fitness app in production with mlflow one should: 14 | 15 | 1. Set up .kube/config so that it can access and manage the remote cluster. 16 | 17 | 2. Build the app's wheel: 18 | ``` 19 | ./bin/fitnessctl build-app 20 | ``` 21 | 22 | 3. Build docker images and install them into the shared private repository 23 | `./bin/fitnessctl install-images`. To accomplish this operation successfully you should have the repository added to your docker settings. 24 | See the example of the settings below. 25 | 26 | 4. Create mlflow volumes and PVCs (the desired namespace must be supplied because of the PVCs) by issuing the command: 27 | ``` 28 | kubectl -n <namespace> apply -f deploy/mlflow-persistent-volumes.yaml 29 | ``` 30 | Do not forget that the directories required by the volumes should exist beforehand. 31 | 32 | 5. Deploy mlflow: 33 | ``` 34 | ./bin/fitnessctl create-mlflow 35 | ``` 36 | 37 | 6. Generate the production config either using the 'gencfg' command with appropriate arguments or using 'prodcfg' alone: 38 | ``` 39 | ./bin/fitnessctl prodcfg 40 | ``` 41 | 42 | 7. Finally, deploy the kube-fitness app with the command: 43 | ``` 44 | ./bin/fitnessctl create 45 | ``` 46 | 47 | Alternatively, one may safely use the 'recreate' command if another deployment may already exist: 48 | ``` 49 | ./bin/fitnessctl recreate 50 | ``` 51 | 52 | 8. To check that everything works fine, one may create a test client and verify that it completes successfully: 53 | ``` 54 | ./bin/fitnessctl recreate-client 55 | ``` 56 | 57 | NOTE: there is a shortcut for the above sequence (EXCEPT p.4 and p.5): 58 | ``` 59 | ./bin/fitnessctl deploy 60 | ``` 61 | 62 | To deploy the kube-fitness app on a kubernetes cluster one should do the following: 63 | 64 | 1. Build docker images and install them into the shared private repository 65 | `./bin/fitnessctl install-images`. To accomplish this operation successfully you should have the repository added to your docker settings. 66 | See the example of the settings below. 67 | 68 | 2. Build and install the wheel into the private PyPI registry (private is the name of the registry). 69 | ``` 70 | python setup.py sdist bdist_wheel register -r private upload -r private 71 | ``` 72 | To perform this operation you need to have a proper config at **~/.pypirc**. 73 | See the example below. 74 | 75 | 76 | 3. Go to the remote server that has access to your kubernetes cluster and has *kubectl* installed. 77 | Install the **kube_fitness** wheel: 78 | ``` 79 | pip3 install --trusted-host node2.bdcl --index http://node2.bdcl:30496 kube_fitness --user --upgrade --force-reinstall 80 | ``` 81 | 82 | The wheel should be installed with the **--user** setting as it puts the necessary files into the **$HOME/.local/share/kube-fitness-data** directory. 83 | The other options are recommended for the second and further reinstalls. 84 | 85 | 4. Add the line `export PATH="$HOME/.local/bin:$PATH"` to your **~/.bash_profile** to be able to call the **kube-fitnessctl** utility. 86 | 87 | 5. Run `kube-fitnessctl gencfg` with the appropriate arguments to create a deployment file for kubernetes. 88 | 89 | 6. Run `kube-fitnessctl create` or `kube-fitnessctl recreate` to deploy fitness workers on your cluster. 90 | 91 | 7. 
To access the app from your jupyter server, you should install the wheel there as well using pip. 92 | 93 | 8. See the example of how to run the client in kube_fitness/test_app.py 94 | 95 | .pypirc example config: 96 | ``` 97 | [distutils] 98 | index-servers = 99 | private 100 | 101 | [private] 102 | repository: http://<host>:<port>/ 103 | username: <username> 104 | password: <password> 105 | ``` 106 | 107 | 108 | Docker daemon settings (**/etc/docker/daemon.json**) to access a private registry: 109 | `"insecure-registries":["node2.bdcl:5000"]` 110 | -------------------------------------------------------------------------------- /distributed/autotm_distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/distributed/autotm_distributed/__init__.py -------------------------------------------------------------------------------- /distributed/autotm_distributed/deploy_config_generator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | from jinja2 import Environment, FileSystemLoader 6 | 7 | logger = logging.getLogger() 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser(description='Generate deploy config with' 12 | ' entities for distributed fitness estimation') 13 | 14 | parser.add_argument("--host-data-dir", 15 | type=str, 16 | required=True, 17 | help="Path to a directory on hosts to be mounted in containers of workers") 18 | 19 | parser.add_argument("--datasets-config", 20 | type=str, 21 | required=True, 22 | help="Path to a YAML file with datasets settings (absolute or relative to --config_template_dir)") 23 | 24 | parser.add_argument("--host-data-dir-mount-path", type=str, default="/storage", 25 | help="Path inside containers of workers where to mount host_data_dir") 26 | 27 | parser.add_argument("--registry", type=str, default=None, 28 | help="Registry to push images to") 29 | 30 | parser.add_argument("--client_image", type=str, default="fitness-client:latest", 31 | help="Image to use for a client pod") 32 | 33 | parser.add_argument("--flower_image", type=str, default="flower:latest", 34 | help="Image to use for Flower - a celery monitoring tool") 35 | 36 | parser.add_argument("--worker_image", type=str, default="fitness-worker:latest", 37 | help="Image to use for worker pods") 38 | 39 | parser.add_argument("--worker_count", type=int, default=1, 40 | help="Count of workers to be launched on the kubernetes cluster") 41 | 42 | parser.add_argument("--worker_cpu", type=int, default=1, 43 | help="Count of cpu (int) to be allocated per worker") 44 | 45 | parser.add_argument("--worker_mem", type=str, default="1G", 46 | help="Amount of memory to be allocated per worker") 47 | 48 | parser.add_argument("--config_template_dir", default="deploy", 49 | help="Path to the templates dir the config will be generated from") 50 | 51 | parser.add_argument("--out_dir", type=str, default=None, 52 | help="Path to a directory for the generated configs") 53 | 54 | parser.add_argument("--mongo_collection", type=str, default='main_tm_stats', 55 | help="Mongo collection") 56 | 57 | args = parser.parse_args() 58 | 59 | worker_template_path = "kube-fitness-workers.yaml.j2" 60 | client_template_path = "kube-fitness-client-job.yaml.j2" 61 | 62 | if args.out_dir: 63 | worker_cfg_out_path = os.path.join(args.out_dir, "kube-fitness-workers.yaml") 64 | client_cfg_out_path = os.path.join(args.out_dir, "kube-fitness-client-job.yaml") 65 | else: 66 | worker_cfg_out_path = 
os.path.join(args.config_template_dir, "kube-fitness-workers.yaml") 67 | client_cfg_out_path = os.path.join(args.config_template_dir, "kube-fitness-client-job.yaml") 68 | 69 | datasets_config = args.datasets_config \ 70 | if os.path.isabs(args.datasets_config) else os.path.join(args.config_template_dir, args.datasets_config) 71 | 72 | logging.info(f"Reading datasets config from {args.datasets_config}") 73 | with open(datasets_config, "r") as f: 74 | datasets_config_content = f.read() 75 | 76 | logging.info(f"Using template dir: {args.config_template_dir}") 77 | logging.info(f"Using template {worker_template_path}") 78 | logging.info(f"Generating config file {worker_cfg_out_path}") 79 | 80 | env = Environment(loader=FileSystemLoader(args.config_template_dir)) 81 | template = env.get_template(worker_template_path) 82 | template.stream( 83 | flower_image=f"{args.registry}/{args.flower_image}" if args.registry else args.flower_image, 84 | image=f"{args.registry}/{args.worker_image}" if args.registry else args.worker_image, 85 | pull_policy="Always" if args.registry else "IfNotPresent", 86 | worker_count=args.worker_count, 87 | worker_cpu=args.worker_cpu, 88 | worker_mem=args.worker_mem, 89 | host_data_dir=args.host_data_dir, 90 | host_data_dir_mount_path=args.host_data_dir_mount_path, 91 | datasets_config_content=datasets_config_content, 92 | mongo_collection=args.mongo_collection 93 | ).dump(worker_cfg_out_path) 94 | 95 | template = env.get_template(client_template_path) 96 | template.stream( 97 | image=f"{args.registry}/{args.client_image}" if args.registry else args.client_image, 98 | pull_policy="Always" if args.registry else "IfNotPresent", 99 | ).dump(client_cfg_out_path) 100 | 101 | 102 | if __name__ == "__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /distributed/autotm_distributed/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import yaml 4 | 5 | from autotm_distributed.tasks import make_celery_app 6 | from autotm_distributed.tm import TopicModelFactory 7 | 8 | if __name__ == '__main__': 9 | if "DATASETS_CONFIG" in os.environ: 10 | with open(os.environ["DATASETS_CONFIG"], "r") as f: 11 | config = yaml.load(f) 12 | dataset_settings = config["datasets"] 13 | else: 14 | dataset_settings = None 15 | 16 | TopicModelFactory.init_factory_settings( 17 | num_processors=os.getenv("NUM_PROCESSORS", None), 18 | dataset_settings=dataset_settings 19 | ) 20 | 21 | celery_app = make_celery_app() 22 | celery_app.worker_main() 23 | -------------------------------------------------------------------------------- /distributed/autotm_distributed/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | MetricsScores = Dict[str, Any] 4 | 5 | TimeMeasurements = Dict[str, float] 6 | 7 | AVG_COHERENCE_SCORE = "avg_coherence_score" 8 | -------------------------------------------------------------------------------- /distributed/autotm_distributed/schemas.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import List, Optional 3 | 4 | from pydantic import BaseModel 5 | 6 | from autotm_distributed import MetricsScores 7 | 8 | PARAM_NAMES = [ 9 | 'val_decor', 'var_n_0', 'var_sm_0', 'var_sm_1', 'var_n_1', 10 | 'var_sp_0', 'var_sp_1', 'var_n_2', 11 | 'var_sp_2', 'var_sp_3', "var_n_3", 12 | 'var_n_4', 13 | 'ext_mutation_prob', 
'ext_elem_mutation_prob', 'ext_mutation_selector', 14 | 'val_decor_2' 15 | ] 16 | 17 | 18 | # @dataclass_json 19 | # @dataclass 20 | class IndividualDTO(BaseModel): 21 | id: str 22 | params: List[object] 23 | fitness_name: str = "default" 24 | dataset: str = "default" 25 | force_dataset_settings_checkout: bool = False 26 | fitness_value: MetricsScores = None 27 | exp_id: Optional[int] = None 28 | alg_id: Optional[str] = None 29 | tag: Optional[str] = None 30 | iteration_id: int = 0 31 | topic_count: Optional[int] = None 32 | train_option: str = 'offline' 33 | 34 | def make_params_dict(self): 35 | if len(self.params) > len(PARAM_NAMES): 36 | len_diff = len(self.params) - len(PARAM_NAMES) 37 | param_names = copy.deepcopy(PARAM_NAMES) + [f"unknown_param_#{i}" for i in range(len_diff)] 38 | else: 39 | param_names = PARAM_NAMES 40 | 41 | return {name: p_val for name, p_val in zip(param_names, self.params)} 42 | -------------------------------------------------------------------------------- /distributed/autotm_distributed/test_app.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pprint import pprint 3 | from typing import Optional 4 | 5 | import click 6 | 7 | from autotm_distributed import AVG_COHERENCE_SCORE 8 | 9 | logging.basicConfig(level=logging.DEBUG) 10 | 11 | from autotm_distributed.tasks import make_celery_app, parallel_fitness, log_best_solution 12 | from autotm_distributed.schemas import IndividualDTO 13 | 14 | logger = logging.getLogger("TEST_APP") 15 | 16 | celery_app = make_celery_app(client_to_broker_max_retries=25) 17 | 18 | population = [ 19 | IndividualDTO(id="1", dataset="first_dataset", topic_count=5, exp_id=0, alg_id="test_alg", iteration_id=100, 20 | tag="just_a_simple_tag", params=[ 21 | 66717.86348968784, 2.0, 98.42286785825902, 22 | 80.18543570807961, 2.0, -19.948347560420373, 23 | -52.28141634493725, 2.0, -92.85597392137976, 24 | -60.49287378084627, 4.0, 3.0, 0.06840138630839943, 25 | 0.556001061599461, 0.9894122432621849, 11679.364068753106]), 26 | 27 | IndividualDTO(id="2", dataset="second_dataset", topic_count=5, params=[ 28 | 42260.028918433134, 1.0, 7.535922806674758, 29 | 92.32509683092258, 3.0, -71.92218883101997, 30 | -56.016307418098386, 1.0, -3.237446735558109, 31 | -20.448661743825, 4.0, 7.0, 0.08031597367372223, 32 | 0.42253357962287563, 0.02249898631530911, 61969.93405537756]), 33 | 34 | IndividualDTO(id="3", dataset="first_dataset", topic_count=5, params=[ 35 | 53787.42158965788, 1.0, 96.00600751876978, 36 | 82.83724552058459, 6.0, -51.625141715571715, 37 | -35.45911388077616, 0.0, -79.6738397452075, 38 | -42.96576312232228, 3.0, 6.0, 0.1842678599143196, 39 | 0.08563438015417912, 0.104280507307428, 91187.26051038165]), 40 | 41 | IndividualDTO(id="4", dataset="second_dataset", topic_count=5, params=[ 42 | 24924.136392296103, 0.0, 98.63988602807903, 43 | 49.03544407815009, 5.0, -8.734591095928806, 44 | -6.99720964952175, 4.0, -32.880078901677265, 45 | -24.61400511189416, 3.0, 0.0, 0.9084621726817743, 46 | 0.6392049950522389, 0.3133878344721094, 39413.00378611856]) 47 | ] 48 | 49 | 50 | @click.group() 51 | def cli(): 52 | pass 53 | 54 | 55 | @cli.command() 56 | def test(): 57 | logger.info("Starting parallel computations...") 58 | 59 | population_with_fitness = parallel_fitness(population) 60 | 61 | assert len(population_with_fitness) == 4, \ 62 | f"Wrong population size (should be == 4): {pprint(population_with_fitness)}" 63 | 64 | for ind in population_with_fitness: 65 | 
logger.info(f"Individual {ind.id} has fitness {ind.fitness_value}") 66 | 67 | assert max((ind.fitness_value[AVG_COHERENCE_SCORE] for ind in population_with_fitness)) > 0, \ 68 | "At least one fitness should be more than zero" 69 | 70 | logger.info("Test run succeded") 71 | click.echo('Initialized the database') 72 | 73 | 74 | @cli.command() 75 | @click.option('--timeout', type=float, default=25.0) 76 | @click.option('--expid', type=int) 77 | def run_and_log(timeout: float, expid: Optional[str] = None): 78 | best_ind = population[0] 79 | if expid: 80 | best_ind.exp_id = expid 81 | logger.info(f"Individual: {best_ind}") 82 | ind = log_best_solution(individual=best_ind, wait_for_result_timeout=timeout, alg_args="--arg 1 --arg 2") 83 | logger.info(f"Logged the best solution. Obtained fitness is {ind.fitness_value[AVG_COHERENCE_SCORE]}") 84 | 85 | 86 | if __name__ == '__main__': 87 | cli() 88 | -------------------------------------------------------------------------------- /distributed/autotm_distributed/utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import logging 3 | from datetime import datetime 4 | from typing import Optional 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class TqdmToLogger(io.StringIO): 11 | """ 12 | Output stream for TQDM which will output to logger module instead of 13 | the StdOut. 14 | """ 15 | def __init__(self, base_logger, level=None): 16 | super(TqdmToLogger, self).__init__() 17 | self.logger = base_logger 18 | self.level = level or logging.INFO 19 | self.buf = '' 20 | 21 | def write(self, buf): 22 | self.buf = buf.strip('\r\n\t ') 23 | 24 | def flush(self): 25 | self.logger.log(self.level, self.buf) 26 | 27 | 28 | class log_exec_timer: 29 | def __init__(self, name: Optional[str] = None): 30 | self.name = name 31 | self._start = None 32 | self._duration = None 33 | 34 | def __enter__(self): 35 | self._start = datetime.now() 36 | return self 37 | 38 | def __exit__(self, typ, value, traceback): 39 | self._duration = (datetime.now() - self._start).total_seconds() 40 | msg = f"Exec time of {self.name}: {self._duration}" if self.name else f"Exec time: {self._duration}" 41 | logger.info(msg) 42 | 43 | @property 44 | def duration(self): 45 | return self._duration 46 | -------------------------------------------------------------------------------- /distributed/bin/kube-fitnessctl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export DIST_MODE=installed 4 | exec fitnessctl "${@}" -------------------------------------------------------------------------------- /distributed/bin/trun.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | docker run -it --rm \ 4 | -v /home/nikolay/wspace/20newsgroups/:/dataset \ 5 | -e "CELERY_BROKER_URL=amqp://guest:guest@rabbitmq-service:5672" \ 6 | -e "CELERY_RESULT_BACKEND=rpc://" \ 7 | -e "DICTIONARY_PATH=/dataset/dictionary.txt" \ 8 | -e "BATCHES_DIR=/dataset/batches" \ 9 | -e "MUTUAL_INFO_DICT_PATH=/dataset/mutual_info_dict.pkl" \ 10 | -e "EXPERIMENTS_PATH=/tmp/tm_experiments" \ 11 | -e "TOPIC_COUNT=10" \ 12 | -e "NUM_PROCESSORS=1" \ 13 | fitness-worker:latest h -------------------------------------------------------------------------------- /distributed/deploy/ess-small-datasets-config.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | # the rest settings will be the same as for 
the first dataset but will be added automatically 3 | books_stroyitelstvo_2030: 4 | base_path: "/storage/books_stroyitelstvo_2030_sample" 5 | topic_count: 10 6 | 7 | 20newsgroups: 8 | base_path: "/storage/20newsgroups_sample" 9 | topic_count: 20 10 | 11 | clickhouse_issues: 12 | base_path: "/storage/clickhouse_issues_sample" 13 | labels: yes 14 | topic_count: 50 15 | 16 | # the rest settings will be the same as for the first dataset but will be added automatically 17 | banners: 18 | base_path: "/storage/banners_sample" 19 | topic_count: 20 20 | 21 | amazon_food: 22 | base_path: "/storage/amazon_food_sample" 23 | topic_count: 20 24 | 25 | hotel-reviews: 26 | base_path: "/storage/hotel-reviews_sample" 27 | topic_count: 20 28 | 29 | lenta_ru: 30 | base_path: "/storage/lenta_ru_sample" 31 | topic_count: 20 32 | -------------------------------------------------------------------------------- /distributed/deploy/file.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: v1 3 | kind: Service 4 | metadata: 5 | name: mlflow-db-phpmyadmin 6 | spec: 7 | type: NodePort 8 | ports: 9 | - port: 80 10 | selector: 11 | app: mlflow-db-phpmyadmin 12 | --- 13 | apiVersion: apps/v1 14 | kind: Deployment 15 | metadata: 16 | name: mlflow-db-phpmyadmin 17 | spec: 18 | replicas: 1 19 | selector: 20 | matchLabels: 21 | app: mlflow-db-phpmyadmin 22 | template: 23 | metadata: 24 | annotations: 25 | "sidecar.istio.io/inject": "false" 26 | labels: 27 | app: mlflow-db-phpmyadmin 28 | spec: 29 | containers: 30 | - name: phpmyadmin 31 | image: phpmyadmin:5.1.1 32 | imagePullPolicy: IfNotPresent 33 | env: 34 | - name: PMA_HOST 35 | value: mlflow-db 36 | - name: PMA_PORT 37 | value: "3306" 38 | - name: PMA_USER 39 | value: mlflow 40 | - name: PMA_PASSWORD 41 | value: mlflow 42 | ports: 43 | - containerPort: 80 44 | -------------------------------------------------------------------------------- /distributed/deploy/fitness-worker-health-checker.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: batch/v1beta1 3 | kind: CronJob 4 | metadata: 5 | name: fitness-worker-health-checker 6 | spec: 7 | schedule: "*/5 * * * *" 8 | concurrencyPolicy: Forbid 9 | jobTemplate: 10 | spec: 11 | template: 12 | metadata: 13 | annotations: 14 | "sidecar.istio.io/inject": "false" 15 | spec: 16 | serviceAccountName: default-editor 17 | containers: 18 | - name: health-checker 19 | image: node2.bdcl:5000/fitness-worker-health-checker:latest 20 | imagePullPolicy: IfNotPresent 21 | command: 22 | - /bin/sh 23 | - -c 24 | - > 25 | kube_ns=$(cat /var/run/secrets/kubernetes.io/serviceaccount/namespace) && 26 | kubectl -n ${kube_ns} logs --prefix --tail=100 -l=app=fitness-worker 2>1 | 27 | grep -i 'Could not create logging file: File exists' | 28 | perl -n -e'/pod\/(.+)\/worker/ && print "$1\n"' | 29 | sort | 30 | uniq | 31 | xargs -n 1 -t -I {} -P8 kubectl -n ${kube_ns} delete pod {} --ignore-not-found 32 | restartPolicy: Never -------------------------------------------------------------------------------- /distributed/deploy/kube-fitness-client-job.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: batch/v1 2 | kind: Job 3 | metadata: 4 | name: fitness-client 5 | spec: 6 | template: 7 | # works even with injection 8 | metadata: 9 | annotations: 10 | "sidecar.istio.io/inject": "false" 11 | spec: 12 | containers: 13 | - name: client 14 | image: node2.bdcl:5000/fitness-client:latest 15 | 
args: ["test"] 16 | imagePullPolicy: Always 17 | env: 18 | - name: CELERY_BROKER_URL 19 | value: "amqp://guest:guest@rabbitmq-service:5672" 20 | - name: CELERY_RESULT_BACKEND 21 | value: "redis://redis:6379/1" # "rpc://" 22 | restartPolicy: Never 23 | backoffLimit: 0 -------------------------------------------------------------------------------- /distributed/deploy/kube-fitness-client-job.yaml.j2: -------------------------------------------------------------------------------- 1 | apiVersion: batch/v1 2 | kind: Job 3 | metadata: 4 | name: fitness-client 5 | spec: 6 | template: 7 | # works even with injection 8 | metadata: 9 | annotations: 10 | "sidecar.istio.io/inject": "false" 11 | spec: 12 | containers: 13 | - name: client 14 | image: {{ image }} 15 | args: ["test"] 16 | imagePullPolicy: {{ pull_policy }} 17 | env: 18 | - name: CELERY_BROKER_URL 19 | value: "amqp://guest:guest@rabbitmq-service:5672" 20 | - name: CELERY_RESULT_BACKEND 21 | value: "redis://redis:6379/1" # "rpc://" 22 | restartPolicy: Never 23 | backoffLimit: 0 -------------------------------------------------------------------------------- /distributed/deploy/kube-fitness-workers.yaml.j2: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: apps/v1 3 | kind: Deployment 4 | metadata: 5 | name: rabbitmq 6 | spec: 7 | replicas: 1 8 | selector: 9 | matchLabels: 10 | app: rabbitmq 11 | template: 12 | metadata: 13 | annotations: 14 | "sidecar.istio.io/inject": "false" 15 | labels: 16 | app: rabbitmq 17 | spec: 18 | volumes: 19 | - name: config-volume 20 | configMap: 21 | name: rabbitmq-config 22 | containers: 23 | - name: rabbitmq 24 | image: node2.bdcl:5000/rabbitmq:3.8-management-alpine 25 | imagePullPolicy: IfNotPresent 26 | ports: 27 | - containerPort: 5672 28 | volumeMounts: 29 | - name: config-volume 30 | mountPath: /etc/rabbitmq/conf.d/consumer-settings.conf 31 | subPath: consumer-settings.conf 32 | --- 33 | apiVersion: v1 34 | kind: Service 35 | metadata: 36 | name: rabbitmq-service 37 | spec: 38 | ports: 39 | - port: 5672 40 | selector: 41 | app: rabbitmq 42 | --- 43 | apiVersion: v1 44 | kind: ConfigMap 45 | metadata: 46 | name: rabbitmq-config 47 | data: 48 | consumer-settings.conf: | 49 | ## Consumer timeout 50 | ## If a message delivered to a consumer has not been acknowledge before this timer 51 | ## triggers the channel will be force closed by the broker. This ensure that 52 | ## faultly consumers that never ack will not hold on to messages indefinitely. 
53 | ## 54 | consumer_timeout = 1800000 55 | --- 56 | apiVersion: apps/v1 # for k8s versions before 1.9.0 use apps/v1beta2 and before 1.8.0 use extensions/v1beta1 57 | kind: Deployment 58 | metadata: 59 | name: redis 60 | spec: 61 | replicas: 1 62 | selector: 63 | matchLabels: 64 | app: redis 65 | template: 66 | metadata: 67 | annotations: 68 | "sidecar.istio.io/inject": "false" 69 | labels: 70 | app: redis 71 | spec: 72 | containers: 73 | - name: redis 74 | image: node2.bdcl:5000/redis:6.2 75 | imagePullPolicy: IfNotPresent 76 | resources: 77 | requests: 78 | cpu: 100m 79 | memory: 100Mi 80 | ports: 81 | - containerPort: 6379 82 | --- 83 | apiVersion: v1 84 | kind: Service 85 | metadata: 86 | name: redis 87 | spec: 88 | ports: 89 | - port: 6379 90 | selector: 91 | app: redis 92 | --- 93 | apiVersion: apps/v1 94 | kind: Deployment 95 | metadata: 96 | name: celery-flower 97 | spec: 98 | replicas: 1 99 | selector: 100 | matchLabels: 101 | app: celery-flower 102 | template: 103 | metadata: 104 | annotations: 105 | "sidecar.istio.io/inject": "false" 106 | labels: 107 | app: celery-flower 108 | spec: 109 | containers: 110 | - name: flower 111 | image: {{ flower_image }} 112 | imagePullPolicy: {{ pull_policy }} 113 | ports: 114 | - containerPort: 5555 115 | env: 116 | - name: CELERY_BROKER_URL 117 | value: "amqp://guest:guest@rabbitmq-service:5672" 118 | - name: CELERY_RESULT_BACKEND 119 | value: "redis://redis:6379/1" # "rpc://" 120 | --- 121 | apiVersion: v1 122 | kind: Service 123 | metadata: 124 | name: celery-flower-service 125 | spec: 126 | type: NodePort 127 | ports: 128 | - port: 5555 129 | selector: 130 | app: celery-flower 131 | --- 132 | apiVersion: v1 133 | kind: ConfigMap 134 | metadata: 135 | name: fitness-worker-config 136 | data: 137 | datasets-config.yaml: | 138 | {{ datasets_config_content | indent( width=4, first=False) }} 139 | --- 140 | apiVersion: apps/v1 141 | kind: Deployment 142 | metadata: 143 | name: fitness-worker 144 | labels: 145 | app: fitness-worker 146 | spec: 147 | replicas: {{ worker_count }} 148 | selector: 149 | matchLabels: 150 | app: fitness-worker 151 | template: 152 | metadata: 153 | annotations: 154 | "sidecar.istio.io/inject": "false" 155 | labels: 156 | app: fitness-worker 157 | spec: 158 | volumes: 159 | - name: dataset 160 | hostPath: 161 | path: {{ host_data_dir }} 162 | type: Directory 163 | - name: config-volume 164 | configMap: 165 | name: fitness-worker-config 166 | - name: mlflow-vol 167 | persistentVolumeClaim: 168 | claimName: mlflow-artifact-store-pvc 169 | containers: 170 | - name: worker 171 | image: {{ image }} 172 | imagePullPolicy: {{ pull_policy }} 173 | volumeMounts: 174 | - name: dataset 175 | mountPath: {{ host_data_dir_mount_path }} 176 | - name: config-volume 177 | mountPath: /etc/fitness/datasets-config.yaml 178 | subPath: datasets-config.yaml 179 | - mountPath: "/var/lib/mlruns" 180 | name: mlflow-vol 181 | env: 182 | - name: CELERY_BROKER_URL 183 | value: "amqp://guest:guest@rabbitmq-service:5672" 184 | - name: CELERY_RESULT_BACKEND 185 | value: "redis://redis:6379/1" # "rpc://" 186 | - name: NUM_PROCESSORS 187 | value: "{{ worker_cpu }}" 188 | - name: DATASETS_CONFIG 189 | value: /etc/fitness/datasets-config.yaml 190 | - name: MLFLOW_TRACKING_URI 191 | value: mysql+pymysql://mlflow:mlflow@mlflow-db:3306/mlflow 192 | # see: https://github.com/mongodb/mongo-python-driver/blob/c8d920a46bfb7b054326b3e983943bfc794cb676/pymongo/mongo_client.py 193 | - name: MONGO_URI 194 | value: 
mongodb://mongoadmin:secret@mongo-tm-experiments-db:27017 195 | - name: MONGO_COLLECTION 196 | value: "{{ mongo_collection or 'tm_stats' }}" 197 | resources: 198 | requests: 199 | memory: "{{ worker_mem }}" 200 | cpu: "{{ worker_cpu }}" 201 | limits: 202 | memory: "{{ worker_mem }}" 203 | cpu: "{{ worker_cpu }}" 204 | -------------------------------------------------------------------------------- /distributed/deploy/mlflow-minikube-persistent-volumes.yaml: -------------------------------------------------------------------------------- 1 | # TODO: need to be fixed in mongo part 2 | #--- 3 | #apiVersion: v1 4 | #kind: PersistentVolume 5 | #metadata: 6 | # name: mlflow-db-pv 7 | #spec: 8 | # storageClassName: manual 9 | # capacity: 10 | # storage: 25Gi 11 | # volumeMode: Filesystem 12 | # accessModes: 13 | # - ReadWriteOnce 14 | # persistentVolumeReclaimPolicy: Retain 15 | # local: 16 | # path: /data/hdd/mlflow-db 17 | # nodeAffinity: 18 | # required: 19 | # nodeSelectorTerms: 20 | # - matchExpressions: 21 | # - key: kubernetes.io/hostname 22 | # operator: In 23 | # values: 24 | # - localhost.localdomain 25 | #--- 26 | #apiVersion: v1 27 | #kind: PersistentVolume 28 | #metadata: 29 | # name: mlflow-artifact-store-pv 30 | #spec: 31 | # storageClassName: manual 32 | # capacity: 33 | # storage: 50Gi 34 | # volumeMode: Filesystem 35 | # accessModes: 36 | # - ReadWriteMany 37 | # persistentVolumeReclaimPolicy: Retain 38 | # hostPath: 39 | # path: /mnt/ess_storage/DN_1/storage/home/khodorchenko/mlflow-tm-experiments/mlflow-artifact-store 40 | # type: Directory 41 | #--- 42 | #apiVersion: v1 43 | #kind: PersistentVolume 44 | #metadata: 45 | # name: mongo-tm-experiments-pv 46 | #spec: 47 | # storageClassName: manual 48 | # capacity: 49 | # storage: 20Gi 50 | # volumeMode: Filesystem 51 | # accessModes: 52 | # - ReadWriteOnce 53 | # persistentVolumeReclaimPolicy: Retain 54 | # hostPath: 55 | # path: /mnt/ess_storage/DN_1/storage/home/khodorchenko/mlflow-tm-experiments/mongodb 56 | # type: Directory 57 | #--- 58 | #apiVersion: v1 59 | #kind: PersistentVolumeClaim 60 | #metadata: 61 | # name: mlflow-db-pvc 62 | #spec: 63 | # storageClassName: manual 64 | # volumeName: mlflow-db-pv 65 | # accessModes: 66 | # - ReadWriteOnce 67 | # resources: 68 | # requests: 69 | # storage: 25Gi 70 | #--- 71 | #apiVersion: v1 72 | #kind: PersistentVolumeClaim 73 | #metadata: 74 | # name: mlflow-artifact-store-pvc 75 | #spec: 76 | # storageClassName: manual 77 | # volumeName: mlflow-artifact-store-pv 78 | # accessModes: 79 | # - ReadWriteMany 80 | # resources: 81 | # requests: 82 | # storage: 25Gi 83 | #--- 84 | #apiVersion: v1 85 | #kind: PersistentVolumeClaim 86 | #metadata: 87 | # name: mongo-tm-experiments-pvc 88 | #spec: 89 | # storageClassName: manual 90 | # volumeName: mongo-tm-experiments-pv 91 | # accessModes: 92 | # - ReadWriteOnce 93 | # resources: 94 | # requests: 95 | # storage: 20Gi 96 | -------------------------------------------------------------------------------- /distributed/deploy/mlflow-persistent-volumes.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: v1 3 | kind: PersistentVolume 4 | metadata: 5 | name: mlflow-db-pv 6 | spec: 7 | storageClassName: manual 8 | capacity: 9 | storage: 25Gi 10 | volumeMode: Filesystem 11 | accessModes: 12 | - ReadWriteOnce 13 | persistentVolumeReclaimPolicy: Retain 14 | local: 15 | path: /data/hdd/mlflow-db 16 | nodeAffinity: 17 | required: 18 | nodeSelectorTerms: 19 | - matchExpressions: 20 | - key: 
kubernetes.io/hostname 21 | operator: In 22 | values: 23 | - node3.bdcl 24 | --- 25 | apiVersion: v1 26 | kind: PersistentVolume 27 | metadata: 28 | name: mlflow-artifact-store-pv 29 | spec: 30 | storageClassName: manual 31 | capacity: 32 | storage: 50Gi 33 | volumeMode: Filesystem 34 | accessModes: 35 | - ReadWriteMany 36 | persistentVolumeReclaimPolicy: Retain 37 | hostPath: 38 | path: /mnt/ess_storage/DN_1/storage/home/khodorchenko/mlflow-tm-experiments/mlflow-artifact-store 39 | type: Directory 40 | #--- 41 | #apiVersion: v1 42 | #kind: PersistentVolume 43 | #metadata: 44 | # name: mongo-tm-experiments-pv 45 | #spec: 46 | # storageClassName: manual 47 | # capacity: 48 | # storage: 50Gi 49 | # volumeMode: Filesystem 50 | # accessModes: 51 | # - ReadWriteOnce 52 | # persistentVolumeReclaimPolicy: Retain 53 | # local: 54 | # path: /data/hdd/mongo-db-tm-experiments 55 | # nodeAffinity: 56 | # required: 57 | # nodeSelectorTerms: 58 | # - matchExpressions: 59 | # - key: kubernetes.io/hostname 60 | # operator: In 61 | # values: 62 | # - node3.bdcl 63 | #z 64 | --- 65 | apiVersion: v1 66 | kind: PersistentVolume 67 | metadata: 68 | name: mongo-tm-experiments-pv-part3 69 | spec: 70 | storageClassName: manual 71 | capacity: 72 | storage: 50Gi 73 | volumeMode: Filesystem 74 | accessModes: 75 | - ReadWriteOnce 76 | persistentVolumeReclaimPolicy: Retain 77 | local: 78 | path: /data/hdd/mongo-db-tm-experiments-part3 79 | nodeAffinity: 80 | required: 81 | nodeSelectorTerms: 82 | - matchExpressions: 83 | - key: kubernetes.io/hostname 84 | operator: In 85 | values: 86 | - node11.bdcl 87 | --- 88 | apiVersion: v1 89 | kind: PersistentVolumeClaim 90 | metadata: 91 | name: mlflow-db-pvc 92 | spec: 93 | storageClassName: manual 94 | volumeName: mlflow-db-pv 95 | accessModes: 96 | - ReadWriteOnce 97 | resources: 98 | requests: 99 | storage: 25Gi 100 | --- 101 | apiVersion: v1 102 | kind: PersistentVolumeClaim 103 | metadata: 104 | name: mlflow-artifact-store-pvc 105 | spec: 106 | storageClassName: manual 107 | volumeName: mlflow-artifact-store-pv 108 | accessModes: 109 | - ReadWriteMany 110 | resources: 111 | requests: 112 | storage: 25Gi 113 | --- 114 | apiVersion: v1 115 | kind: PersistentVolumeClaim 116 | metadata: 117 | name: mongo-tm-experiments-pvc 118 | spec: 119 | storageClassName: manual 120 | volumeName: mongo-tm-experiments-pv-part3 121 | accessModes: 122 | - ReadWriteOnce 123 | resources: 124 | requests: 125 | storage: 50Gi 126 | -------------------------------------------------------------------------------- /distributed/deploy/mongo_dev/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: "3.5" 2 | 3 | services: 4 | mongo: 5 | image: mongo:latest 6 | container_name: mongo 7 | environment: 8 | MONGO_INITDB_ROOT_USERNAME: admin 9 | MONGO_INITDB_ROOT_PASSWORD: admin 10 | ports: 11 | - "0.0.0.0:27017:27017" 12 | networks: 13 | - MONGO 14 | volumes: 15 | - type: volume 16 | source: MONGO_DATA 17 | target: /data/db 18 | - type: volume 19 | source: MONGO_CONFIG 20 | target: /data/configdb 21 | mongo-express: 22 | image: mongo-express:latest 23 | container_name: mongo-express 24 | environment: 25 | ME_CONFIG_MONGODB_ADMINUSERNAME: admin 26 | ME_CONFIG_MONGODB_ADMINPASSWORD: admin 27 | ME_CONFIG_MONGODB_SERVER: mongo 28 | ME_CONFIG_MONGODB_PORT: "27017" 29 | # this script was taken from https://github.com/eficode/wait-for 30 | volumes: 31 | - type: bind 32 | source: ./wait-for.sh 33 | target: /wait-for.sh 34 | entrypoint: 35 | - /bin/sh 36 | 
- /wait-for.sh 37 | - mongo:27017 38 | - -- 39 | - tini 40 | - -- 41 | - /docker-entrypoint.sh 42 | ports: 43 | - "0.0.0.0:8081:8081" 44 | networks: 45 | - MONGO 46 | depends_on: 47 | - mongo 48 | 49 | networks: 50 | MONGO: 51 | name: MONGO 52 | 53 | volumes: 54 | MONGO_DATA: 55 | name: MONGO_DATA 56 | MONGO_CONFIG: 57 | name: MONGO_CONFIG 58 | -------------------------------------------------------------------------------- /distributed/deploy/mongo_dev/wait-for.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2017 Eficode Oy 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | VERSION="2.2.3" 26 | 27 | set -- "$@" -- "$TIMEOUT" "$QUIET" "$PROTOCOL" "$HOST" "$PORT" "$result" 28 | TIMEOUT=15 29 | QUIET=0 30 | # The protocol to make the request with, either "tcp" or "http" 31 | PROTOCOL="tcp" 32 | 33 | echoerr() { 34 | if [ "$QUIET" -ne 1 ]; then printf "%s\n" "$*" 1>&2; fi 35 | } 36 | 37 | usage() { 38 | exitcode="$1" 39 | cat << USAGE >&2 40 | Usage: 41 | $0 host:port|url [-t timeout] [-- command args] 42 | -q | --quiet Do not output any status messages 43 | -t TIMEOUT | --timeout=timeout Timeout in seconds, zero for no timeout 44 | -v | --version Show the version of this tool 45 | -- COMMAND ARGS Execute command with args after the test finishes 46 | USAGE 47 | exit "$exitcode" 48 | } 49 | 50 | wait_for() { 51 | case "$PROTOCOL" in 52 | tcp) 53 | if ! command -v nc >/dev/null; then 54 | echoerr 'nc command is missing!' 55 | exit 1 56 | fi 57 | ;; 58 | http) 59 | if ! command -v wget >/dev/null; then 60 | echoerr 'wget command is missing!' 61 | exit 1 62 | fi 63 | ;; 64 | esac 65 | 66 | TIMEOUT_END=$(($(date +%s) + TIMEOUT)) 67 | 68 | while :; do 69 | case "$PROTOCOL" in 70 | tcp) 71 | nc -w 1 -z "$HOST" "$PORT" > /dev/null 2>&1 72 | ;; 73 | http) 74 | wget --timeout=1 -q "$HOST" -O /dev/null > /dev/null 2>&1 75 | ;; 76 | *) 77 | echoerr "Unknown protocol '$PROTOCOL'" 78 | exit 1 79 | ;; 80 | esac 81 | 82 | result=$? 
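# result holds the exit status of the probe above: 0 means the target answered.
# In that case the block below re-reads the state values that were appended with 'set --'
# near the top of the script and execs any command that was supplied after '--';
# otherwise the loop sleeps and retries until TIMEOUT_END is reached.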
83 | 84 | if [ $result -eq 0 ] ; then 85 | if [ $# -gt 7 ] ; then 86 | for result in $(seq $(($# - 7))); do 87 | result=$1 88 | shift 89 | set -- "$@" "$result" 90 | done 91 | 92 | TIMEOUT=$2 QUIET=$3 PROTOCOL=$4 HOST=$5 PORT=$6 result=$7 93 | shift 7 94 | exec "$@" 95 | fi 96 | exit 0 97 | fi 98 | 99 | if [ $TIMEOUT -ne 0 -a $(date +%s) -ge $TIMEOUT_END ]; then 100 | echo "Operation timed out" >&2 101 | exit 1 102 | fi 103 | 104 | sleep 1 105 | done 106 | } 107 | 108 | while :; do 109 | case "$1" in 110 | http://*|https://*) 111 | HOST="$1" 112 | PROTOCOL="http" 113 | shift 1 114 | ;; 115 | *:* ) 116 | HOST=$(printf "%s\n" "$1"| cut -d : -f 1) 117 | PORT=$(printf "%s\n" "$1"| cut -d : -f 2) 118 | shift 1 119 | ;; 120 | -v | --version) 121 | echo $VERSION 122 | exit 123 | ;; 124 | -q | --quiet) 125 | QUIET=1 126 | shift 1 127 | ;; 128 | -q-*) 129 | QUIET=0 130 | echoerr "Unknown option: $1" 131 | usage 1 132 | ;; 133 | -q*) 134 | QUIET=1 135 | result=$1 136 | shift 1 137 | set -- -"${result#-q}" "$@" 138 | ;; 139 | -t | --timeout) 140 | TIMEOUT="$2" 141 | shift 2 142 | ;; 143 | -t*) 144 | TIMEOUT="${1#-t}" 145 | shift 1 146 | ;; 147 | --timeout=*) 148 | TIMEOUT="${1#*=}" 149 | shift 1 150 | ;; 151 | --) 152 | shift 153 | break 154 | ;; 155 | --help) 156 | usage 0 157 | ;; 158 | -*) 159 | QUIET=0 160 | echoerr "Unknown option: $1" 161 | usage 1 162 | ;; 163 | *) 164 | QUIET=0 165 | echoerr "Unknown argument: $1" 166 | usage 1 167 | ;; 168 | esac 169 | done 170 | 171 | if ! [ "$TIMEOUT" -ge 0 ] 2>/dev/null; then 172 | echoerr "Error: invalid timeout '$TIMEOUT'" 173 | usage 3 174 | fi 175 | 176 | case "$PROTOCOL" in 177 | tcp) 178 | if [ "$HOST" = "" ] || [ "$PORT" = "" ]; then 179 | echoerr "Error: you need to provide a host and port to test." 180 | usage 2 181 | fi 182 | ;; 183 | http) 184 | if [ "$HOST" = "" ]; then 185 | echoerr "Error: you need to provide a host to test." 186 | usage 2 187 | fi 188 | ;; 189 | esac 190 | 191 | wait_for "$@" 192 | -------------------------------------------------------------------------------- /distributed/deploy/ns.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: v1 3 | kind: Namespace 4 | metadata: 5 | name: mkhodorchenko 6 | labels: 7 | istio-injection: disabled 8 | -------------------------------------------------------------------------------- /distributed/deploy/test-datasets-config.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | first_dataset: 3 | base_path: "/storage/tiny_20newsgroups" 4 | dictionary_path: "dictionary.txt" 5 | batches_dir: "batches" 6 | mutual_info_dict_path: "mutual_info_dict.pkl" 7 | experiments_path: "/tmp/tm_experiments" 8 | topic_count: 10 9 | 10 | # the rest settings will be the same as for the first dataset but will be added automatically 11 | second_dataset: 12 | base_path: "/storage/tiny_20newsgroups_2" 13 | topic_count: 10 14 | -------------------------------------------------------------------------------- /distributed/docker/base.dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | RUN apt-get update && apt-get install -y bash python3 python3-pip 3 | #RUN pip3 install 'celery == 4.4.7' 'bigartm == 0.9.2' 'tqdm == 4.50.2' 'numpy == 1.19.2' 'dataclasses-json == 0.5.2' 4 | COPY requirements.txt /tmp/ 5 | RUN pip3 install -r /tmp/requirements.txt 6 | COPY . 
/kube-distributed-fitness 7 | RUN pip3 install /kube-distributed-fitness -------------------------------------------------------------------------------- /distributed/docker/cli.dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | 3 | RUN DEBIAN_FRONTEND=noninteractive apt-get update -y && \ 4 | DEBIAN_FRONTEND=noninteractive apt-get install -y git curl make cmake build-essential libboost-all-dev 5 | 6 | RUN curl -LJO https://github.com/bigartm/bigartm/archive/refs/tags/v0.9.2.tar.gz 7 | 8 | RUN tar -xvf bigartm-0.9.2.tar.gz 9 | 10 | RUN mkdir bigartm-0.9.2/build 11 | 12 | RUN DEBIAN_FRONTEND=noninteractive apt-get update && DEBIAN_FRONTEND=noninteractive apt-get --yes install bash python3.8 13 | 14 | RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py 15 | 16 | RUN python3.8 get-pip.py 17 | 18 | RUN pip3 install protobuf tqdm wheel 19 | 20 | RUN cd bigartm-0.9.2/build && cmake .. -DBUILD_INTERNAL_PYTHON_API=OFF 21 | 22 | RUN cd bigartm-0.9.2/build && make install 23 | 24 | COPY requirements.txt /tmp/ 25 | 26 | RUN pip3 install -r /tmp/requirements.txt 27 | 28 | COPY . /kube-distributed-fitness 29 | 30 | RUN pip3 install /kube-distributed-fitness 31 | -------------------------------------------------------------------------------- /distributed/docker/client.dockerfile: -------------------------------------------------------------------------------- 1 | FROM fitness-base:latest 2 | ENTRYPOINT ["python3", "-u", "/kube-distributed-fitness/kube_fitness/test_app.py"] -------------------------------------------------------------------------------- /distributed/docker/flower.dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:alpine 2 | 3 | # Get latest root certificates 4 | RUN apk add --no-cache ca-certificates && update-ca-certificates 5 | 6 | # Install the required packages 7 | RUN pip install --no-cache-dir redis flower==1.0.0 8 | 9 | # PYTHONUNBUFFERED: Force stdin, stdout and stderr to be totally unbuffered. (equivalent to `python -u`) 10 | # PYTHONHASHSEED: Enable hash randomization (equivalent to `python -R`) 11 | # PYTHONDONTWRITEBYTECODE: Do not write byte files to disk, since we maintain it as readonly. 
(equivalent to `python -B`) 12 | ENV PYTHONUNBUFFERED=1 PYTHONHASHSEED=random PYTHONDONTWRITEBYTECODE=1 13 | 14 | # Default port 15 | EXPOSE 5555 16 | 17 | ENV FLOWER_DATA_DIR /data 18 | ENV PYTHONPATH ${FLOWER_DATA_DIR} 19 | 20 | WORKDIR $FLOWER_DATA_DIR 21 | 22 | # Add a user with an explicit UID/GID and create necessary directories 23 | RUN set -eux; \ 24 | addgroup -g 1000 flower; \ 25 | adduser -u 1000 -G flower flower -D; \ 26 | mkdir -p "$FLOWER_DATA_DIR"; \ 27 | chown flower:flower "$FLOWER_DATA_DIR" 28 | USER flower 29 | 30 | VOLUME $FLOWER_DATA_DIR 31 | 32 | # for '-A distributed_fitness' see kube_fitness.tasks.make_celery_app 33 | ENTRYPOINT celery flower --address=0.0.0.0 --port=5555 -------------------------------------------------------------------------------- /distributed/docker/health-checker.dockerfile: -------------------------------------------------------------------------------- 1 | FROM boxboat/kubectl:1.21.3 2 | 3 | WORKDIR / 4 | USER root:root 5 | 6 | RUN apk add perl 7 | 8 | ENTRYPOINT [ "/usr/local/bin/kubectl"] 9 | -------------------------------------------------------------------------------- /distributed/docker/mlflow-webserver.dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8 2 | 3 | RUN pip install PyMySQL==0.9.3 psycopg2-binary==2.8.5 protobuf==3.20.1 mlflow[extras]==1.18.0 4 | 5 | ENTRYPOINT ["mlflow", "server"] 6 | -------------------------------------------------------------------------------- /distributed/docker/test.dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | RUN apt-get update && apt-get install -y bash python3 python3-pip 3 | RUN pip3 install 'celery == 4.4.7' 'bigartm == 0.9.2' 'tqdm == 4.50.2' 'numpy == 1.19.2' 'dataclasses-json == 0.5.2' -------------------------------------------------------------------------------- /distributed/docker/worker.dockerfile: -------------------------------------------------------------------------------- 1 | FROM fitness-base:latest 2 | ENTRYPOINT python3 /kube-distributed-fitness/kube_fitness/kube_fitness/main.py \ 3 | --concurrency 1 \ 4 | --queues fitness_tasks \ 5 | --loglevel INFO -------------------------------------------------------------------------------- /distributed/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "autotm-distributed" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["fonhorst "] 6 | readme = "README.md" 7 | packages = [{include = "autotm_distributed"}] 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.9" 11 | bigartm = "0.9.2" 12 | 13 | 14 | [build-system] 15 | requires = ["poetry-core"] 16 | build-backend = "poetry.core.masonry.api" 17 | -------------------------------------------------------------------------------- /distributed/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup 5 | 6 | 7 | def deploy_data_files(): 8 | templates = [os.path.join("deploy", f) for f in os.listdir("deploy") if f.endswith(".j2")] 9 | files = templates + ["deploy/test-datasets-config.yaml"] 10 | return files 11 | 12 | 13 | with open('requirements.txt') as f: 14 | strs = f.readlines() 15 | 16 | install_requires = [str(requirement) for requirement in pkg_resources.parse_requirements(strs)] 17 | 18 | setup( 19 | name='kube_fitness', 20 | version='0.1.0', 21 | 
description='kube fitness', 22 | package_dir={"": "kube_fitness"}, 23 | packages=["kube_fitness"], 24 | install_requires=install_requires, 25 | include_package_data=True, 26 | data_files=[ 27 | ('share/kube-fitness-data/deploy', deploy_data_files()) 28 | ], 29 | scripts=['bin/kube-fitnessctl', 'bin/fitnessctl'], 30 | entry_points = { 31 | 'console_scripts': ['deploy-config-generator=kube_fitness.deploy_config_generator:main'], 32 | } 33 | ) -------------------------------------------------------------------------------- /distributed/test/topic_model_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from kube_fitness import AVG_COHERENCE_SCORE 4 | from kube_fitness.tm import TopicModelFactory, calculate_fitness_of_individual 5 | 6 | 7 | class TopicModelCase(unittest.TestCase): 8 | def test_topic_model(self): 9 | topic_count = 10 10 | dataset_name_1 = "first" 11 | dataset_name_2 = "second" 12 | 13 | param_s = [ 14 | 66717.86348968784, 2.0, 98.42286785825902, 15 | 80.18543570807961, 2.0, -19.948347560420373, 16 | -52.28141634493725, 2.0, -92.85597392137976, 17 | -60.49287378084627, 4.0, 3.0, 0.06840138630839943, 18 | 0.556001061599461, 0.9894122432621849, 11679.364068753106 19 | ] 20 | 21 | metrics = [ 22 | AVG_COHERENCE_SCORE, 23 | 'perplexityScore', 'backgroundTokensRatioScore', 'contrast', 24 | 'purity', 'kernelSize', 'npmi_50_list', 25 | 'npmi_50', 'sparsity_phi', 'sparsity_theta', 26 | 'topic_significance_uni', 'topic_significance_vacuous', 'topic_significance_back', 27 | 'switchP_list', 28 | 'switchP', 'all_topics', 29 | *(f'coherence_{i}' for i in range(10, 60, 5)), 30 | *(f'coherence_{i}_list' for i in range(10, 60, 5)), 31 | ] 32 | 33 | TopicModelFactory.init_factory_settings(num_processors=2, dataset_settings={ 34 | dataset_name_1: { 35 | "base_path": '/home/nikolay/wspace/test_tiny_dataset', 36 | "topic_count": topic_count, 37 | }, 38 | dataset_name_2: { 39 | "base_path": '/home/nikolay/wspace/test_tiny_dataset_2', 40 | "topic_count": topic_count, 41 | }, 42 | }) 43 | 44 | print(f"Calculating dataset {dataset_name_1}") 45 | fitness = calculate_fitness_of_individual(dataset_name_1, param_s, topic_count=topic_count) 46 | self.assertSetEqual(set(fitness.keys()), set(metrics)) 47 | for m in metrics: 48 | self.assertIsNotNone(fitness[m]) 49 | 50 | print(f"Calculating dataset {dataset_name_2}") 51 | fitness = calculate_fitness_of_individual(dataset_name_2, param_s, topic_count=topic_count) 52 | self.assertSetEqual(set(fitness.keys()), set(metrics)) 53 | for m in metrics: 54 | self.assertIsNotNone(fitness[m]) 55 | 56 | with self.assertRaises(Exception): 57 | calculate_fitness_of_individual("unknown_dataset", param_s, topic_count=topic_count) 58 | 59 | 60 | if __name__ == '__main__': 61 | unittest.main() 62 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
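# Any other target name (e.g. "make html") falls through to the catch-all rule further down
# and runs "sphinx-build -M html . _build", writing the rendered docs to _build/html.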
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/style.css: -------------------------------------------------------------------------------- 1 | .wy-nav-content { 2 | max-width: none; 3 | } 4 | 5 | .rst-content code.xref { 6 | /* !important prevents the common CSS stylesheets from overriding 7 | this as on RTD they are loaded after this stylesheet */ 8 | color: #E74C3C 9 | } 10 | 11 | html.writer-html4 .rst-content dl:not(.docutils) dl:not(.field-list)>dt, html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) dl:not(.field-list)>dt { 12 | border-left-color: rgb(9, 183, 14) 13 | } 14 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. autoclass:: {{ name }} 9 | :members: 10 | 11 | 12 | .. 13 | autogenerated from source/_templates/autosummary/class.rst 14 | note it does not have :inherited-members: 15 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/module.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | {{ name | underline }} 5 | 6 | .. automodule:: {{ fullname }} 7 | 8 | {% block classes %} 9 | {% if classes %} 10 | .. rubric:: {{ _('Classes') }} 11 | 12 | .. autosummary:: 13 | :toctree: generated 14 | :nosignatures: 15 | :template: classtemplate.rst 16 | {% for item in classes %} 17 | {{ item }} 18 | {%- endfor %} 19 | {% endif %} 20 | {% endblock %} 21 | 22 | {% block functions %} 23 | {% if functions %} 24 | .. rubric:: {{ _('Functions') }} 25 | 26 | .. autosummary:: 27 | :toctree: generated 28 | :nosignatures: 29 | :template: functiontemplate.rst 30 | {% for item in functions %} 31 | {{ item }} 32 | {%- endfor %} 33 | {% endif %} 34 | {% endblock %} 35 | 36 | 37 | {% block modules %} 38 | {% if modules %} 39 | .. rubric:: {{ _('Modules') }} 40 | 41 | .. autosummary:: 42 | :toctree: 43 | :recursive: 44 | {% for item in modules %} 45 | {{ item }} 46 | {%- endfor %} 47 | {% endif %} 48 | {% endblock %} 49 | -------------------------------------------------------------------------------- /docs/_templates/classtemplate.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline }} 7 | 8 | .. autoclass:: {{ name }} 9 | :members: 10 | 11 | 12 | .. 13 | autogenerated from source/_templates/classtemplate.rst 14 | note it does not have :inherited-members: 15 | -------------------------------------------------------------------------------- /docs/_templates/functiontemplate.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. 
currentmodule:: {{ module }} 4 | 5 | {{ name | underline }} 6 | 7 | .. autofunction:: {{ fullname }} 8 | 9 | .. 10 | autogenerated from source/_templates/functiontemplate.rst 11 | note it does not have :inherited-members: 12 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | import os 10 | import sys 11 | 12 | CURR_PATH = os.path.abspath(os.path.dirname(__file__)) 13 | LIB_PATH = os.path.join(CURR_PATH, os.path.pardir) 14 | sys.path.insert(0, LIB_PATH) 15 | 16 | project = "AutoTM" 17 | copyright = "2023, Strong AI Lab" 18 | author = "Khodorchenko Maria, Butakov Nikolay" 19 | release = "0.1.0" 20 | 21 | # -- General configuration --------------------------------------------------- 22 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 23 | 24 | extensions = ["sphinx.ext.napoleon", "sphinx.ext.autosummary"] 25 | 26 | # Delete external references 27 | autosummary_mock_imports = [ 28 | "numpy", 29 | "pandas", 30 | "artm", 31 | "gensim", 32 | "billiard", 33 | "plotly", 34 | "scipy", 35 | "spacy_langdetect", 36 | "sklearn", 37 | "spacy", 38 | "pymystem3", 39 | "tqdm", 40 | "nltk" 41 | ] 42 | 43 | templates_path = ["_templates"] 44 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 45 | 46 | # -- Options for HTML output ------------------------------------------------- 47 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 48 | 49 | 50 | html_theme = "sphinx_rtd_theme" 51 | pygments_style = 'sphinx' 52 | 53 | html_static_path = ['_static'] 54 | 55 | napoleon_google_docstring = True 56 | napoleon_numpy_docstring = False 57 | 58 | # Autosummary true if you want to generate it from very beginning 59 | # autosummary_generate = True 60 | 61 | # set_type_checking_flag = True 62 | -------------------------------------------------------------------------------- /docs/img/MyLogo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/MyLogo.png -------------------------------------------------------------------------------- /docs/img/autotm_arch_v3 (1).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/autotm_arch_v3 (1).png -------------------------------------------------------------------------------- /docs/img/img_library_eng.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/img_library_eng.png -------------------------------------------------------------------------------- /docs/img/photo_2023-06-29_20-20-57.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/photo_2023-06-29_20-20-57.jpg -------------------------------------------------------------------------------- /docs/img/photo_2023-06-30_14-17-08.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/photo_2023-06-30_14-17-08.jpg -------------------------------------------------------------------------------- /docs/img/pipeling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/pipeling.png -------------------------------------------------------------------------------- /docs/img/strategy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/strategy.png -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. AutoTM documentation master file, created by 2 | sphinx-quickstart on Wed Apr 26 11:07:27 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to AutoTM's documentation! 7 | ================================== 8 | 9 | `AutoTM `_ is am open-source Python library for topic modeling on texts. 10 | 11 | The goal of the library is to provide a simple-to-use interface to perform fast EDA or create interpretable text embeddings. 12 | 13 | .. toctree:: 14 | :maxdepth: 1 15 | :caption: Contents: 16 | 17 | 18 | Installation 19 | Python api 20 | User guide 21 | 22 | 23 | Indices and tables 24 | ================== 25 | 26 | * :ref:`genindex` 27 | * :ref:`modindex` 28 | * :ref:`search` 29 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/pages/algorithms_for_tuning.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | autotm.algorithms_for_tuning 5 | ====================== 6 | 7 | Algorithms for tuning. 8 | 9 | .. .. currentmodule:: autotm.algorithms_for_tuning.bayesian_optimization 10 | .. 
currentmodule:: autotm.algorithms_for_tuning.genetic_algorithm 11 | 12 | .. autosummary:: 13 | :toctree: ./generated 14 | :nosignatures: 15 | :template: autosummary/class.rst 16 | 17 | crossover.crossover_pmx 18 | crossover.crossover_one_point 19 | crossover.crossover_blend_new 20 | crossover.crossover_blend 21 | crossover.crossover 22 | 23 | -------------------------------------------------------------------------------- /docs/pages/api/algorithms_for_tuning.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | autotm.algorithms_for_tuning 5 | ====================== 6 | 7 | Provides an internal interface for working with optimization algorithms. 8 | 9 | Genetic Algorithm 10 | --------------------- 11 | 12 | .. currentmodule:: autotm.algorithms_for_tuning.genetic_algorithm 13 | 14 | .. autosummary:: 15 | :toctree: ./generated 16 | :nosignatures: 17 | :template: autosummary/class.rst 18 | 19 | crossover.crossover_pmx 20 | crossover.crossover_one_point 21 | crossover.crossover_blend 22 | crossover.crossover 23 | mutation.mutation_one_param 24 | mutation.positioning_mutation 25 | mutation.mutation_combined 26 | mutation.mutation_psm 27 | mutation.mutation 28 | 29 | .. .. currentmodule:: autotm.algorithms_for_tuning.bayesian_optimization -------------------------------------------------------------------------------- /docs/pages/api/fitness.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | autotm.fitness 5 | ====================== 6 | 7 | Fitness. 8 | 9 | .. currentmodule:: autotm.fitness.tm 10 | 11 | .. autosummary:: 12 | :toctree: ./generated 13 | :nosignatures: 14 | :template: autosummary/class.rst 15 | 16 | Dataset 17 | TopicModelFactory 18 | FitnessCalculatorWrapper 19 | calculate_fitness_of_individual 20 | TopicModel 21 | -------------------------------------------------------------------------------- /docs/pages/api/index.rst: -------------------------------------------------------------------------------- 1 | .. toctree:: 2 | :maxdepth: 2 3 | 4 | algorithms_for_tuning 5 | fitness 6 | preprocessing 7 | visualisation -------------------------------------------------------------------------------- /docs/pages/api/preprocessing.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | autotm.preprocessing 5 | ====================== 6 | 7 | Preprocessing. 8 | 9 | .. currentmodule:: autotm.preprocessing 10 | 11 | .. autosummary:: 12 | :toctree: ./generated 13 | :nosignatures: 14 | 15 | dictionaries_preparation.get_words_dict 16 | -------------------------------------------------------------------------------- /docs/pages/api/visualization.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | autotm.visualization 5 | ====================== 6 | 7 | Visualization. 8 | 9 | .. currentmodule:: autotm.visualization.dynamic_tracker 10 | 11 | .. autosummary:: 12 | :toctree: ./generated 13 | :nosignatures: 14 | :template: autosummary/class.rst 15 | 16 | MetricsCollector 17 | 18 | -------------------------------------------------------------------------------- /docs/pages/fitness.rst: -------------------------------------------------------------------------------- 1 | .. 
role:: hidden 2 | :class: hidden-section 3 | 4 | autotm.fitness 5 | ====================== 6 | 7 | Fitness. 8 | 9 | .. currentmodule:: autotm.fitness.tm 10 | 11 | .. autosummary:: 12 | :toctree: ./generated 13 | :nosignatures: 14 | :template: autosummary/class.rst 15 | 16 | Dataset 17 | TopicModelFactory 18 | FitnessCalculatorWrapper 19 | calculate_fitness_of_individual 20 | TopicModel 21 | -------------------------------------------------------------------------------- /docs/pages/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ================== 3 | 4 | 5 | Pip installation 6 | ----- 7 | 8 | You can install library `AutoTM` from PyPI. 9 | 10 | .. code-block:: bash 11 | 12 | pip install autotm 13 | 14 | 15 | Development 16 | ----------- 17 | 18 | You can also clone repository and install with poetry. 19 | -------------------------------------------------------------------------------- /docs/pages/preprocessing.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | autotm.preprocessing 5 | ====================== 6 | 7 | Preprocessing. 8 | 9 | .. currentmodule:: autotm.preprocessing 10 | 11 | .. autosummary:: 12 | :toctree: ./generated 13 | :nosignatures: 14 | 15 | dictionaries_preparation.get_words_dict 16 | -------------------------------------------------------------------------------- /docs/pages/userguide/index.rst: -------------------------------------------------------------------------------- 1 | .. _python_userguide: 2 | 3 | Python Guide 4 | ============ 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | regularizers 10 | metrics 11 | -------------------------------------------------------------------------------- /docs/pages/userguide/metrics.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/pages/userguide/metrics.rst -------------------------------------------------------------------------------- /docs/pages/userguide/regularizers.rst: -------------------------------------------------------------------------------- 1 | Regularizers 2 | ================================ 3 | 4 | AutoTM library is using a set of regularizers while optimizing learning strategy. -------------------------------------------------------------------------------- /docs/pages/visualization.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | autotm.visualization 5 | ====================== 6 | 7 | Visualization. 8 | 9 | .. currentmodule:: autotm.visualization.dynamic_tracker 10 | 11 | .. 
autosummary:: 12 | :toctree: ./generated 13 | :nosignatures: 14 | :template: autosummary/class.rst 15 | 16 | MetricsCollector 17 | 18 | -------------------------------------------------------------------------------- /examples/examples_autotm_fit_predict.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import os 4 | import uuid 5 | from typing import Dict, Any, Optional 6 | 7 | import pandas as pd 8 | from sklearn.model_selection import train_test_split 9 | 10 | from autotm.base import AutoTM 11 | 12 | logging.basicConfig( 13 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 14 | level=logging.DEBUG, datefmt='%Y-%m-%d %H:%M:%S' 15 | ) 16 | logger = logging.getLogger() 17 | 18 | # "base" configuration is an example of default configuration of the algorithm 19 | # "base_en" is showing how to specify dataset parameters and pass it directly to the algorithm 20 | # "static_chromosome" is a previous version of optimization algorithm, may work faster comparing to the "pipeline" version, but results are slightly worse 21 | # "surrogate" is showing how to start training with surrogates option (increase optimization speed) 22 | # "llm" is an example of how to run optimization with gpt-base quality function 23 | # "bayes" is an option with Bayesian Optimization 24 | CONFIGURATIONS = { 25 | "base": { 26 | "alg_name": "ga", 27 | "num_iterations": 50, 28 | "use_pipeline": True 29 | }, 30 | "base_en": { 31 | "alg_name": "ga", 32 | "dataset": { 33 | "lang": "en", 34 | "dataset_path": "data/sample_corpora/imdb_100.csv", 35 | "dataset_name": "imdb_100" 36 | }, 37 | "num_iterations": 20, 38 | "use_pipeline": True 39 | }, 40 | "static_chromosome": { 41 | "alg_name": "ga", 42 | "num_iterations": 20, 43 | "use_pipeline": False 44 | }, 45 | "gost_example": { 46 | "topic_count": 50, 47 | "alg_name": "ga", 48 | "num_iterations": 10, 49 | "use_pipeline": True, 50 | "dataset": { 51 | "lang": "ru", 52 | "col_to_process": 'paragraph', 53 | "dataset_path": "data/sample_corpora/clean_docs_v17_gost_only.csv", 54 | "dataset_name": "gost" 55 | }, 56 | # "individual_type": "llm" 57 | }, 58 | "surrogate": { 59 | "alg_name": "ga", 60 | "num_iterations": 20, 61 | "use_pipeline": True, 62 | "surrogate_name": "random-forest-regressor" 63 | }, 64 | "llm": { 65 | "alg_name": "ga", 66 | "num_iterations": 20, 67 | "use_pipeline": True, 68 | "individual_type": "llm" 69 | }, 70 | "bayes": { 71 | "alg_name": "bayes", 72 | "num_evaluations": 150, 73 | } 74 | } 75 | 76 | 77 | def run(alg_name: str, 78 | alg_params: Dict[str, Any], 79 | dataset: Optional[Dict[str, Any]] = None, 80 | col_to_process: Optional[str] = None, 81 | topic_count: int = 20): 82 | if not dataset: 83 | dataset = { 84 | "lang": "ru", 85 | "dataset_path": "data/sample_corpora/sample_dataset_lenta.csv", 86 | "dataset_name": "lenta_ru" 87 | } 88 | 89 | df = pd.read_csv(dataset['dataset_path']) 90 | train_df, test_df = train_test_split(df, test_size=0.1) 91 | 92 | working_dir_path = f"./autotm_workdir_{uuid.uuid4()}" 93 | model_path = os.path.join(working_dir_path, "autotm_model") 94 | 95 | autotm = AutoTM( 96 | topic_count=topic_count, 97 | preprocessing_params={ 98 | "lang": dataset['lang'] 99 | }, 100 | texts_column_name=col_to_process, 101 | alg_name=alg_name, 102 | alg_params=alg_params, 103 | working_dir_path=working_dir_path, 104 | exp_dataset_name=dataset["dataset_name"] 105 | ) 106 | mixtures = autotm.fit_predict(train_df) 107 | 108 | logger.info(f"Calculated train mixtures: 
{mixtures.shape}\n\n{mixtures.head(10).to_string()}") 109 | 110 | # saving the model 111 | autotm.save(model_path) 112 | 113 | # loading and checking if everything is fine with predicting 114 | autotm_loaded = AutoTM.load(model_path) 115 | mixtures = autotm_loaded.predict(test_df) 116 | 117 | logger.info(f"Calculated train mixtures: {mixtures.shape}\n\n{mixtures.head(10).to_string()}") 118 | 119 | 120 | def main(conf_name: str = "base"): 121 | if conf_name not in CONFIGURATIONS: 122 | raise ValueError( 123 | f"Unknown configuration {conf_name}. Available configurations: {sorted(CONFIGURATIONS.keys())}" 124 | ) 125 | 126 | conf = CONFIGURATIONS[conf_name] 127 | alg_name = conf['alg_name'] 128 | del conf['alg_name'] 129 | 130 | dataset = None 131 | col_to_process = None 132 | if 'dataset' in conf: 133 | dataset = conf['dataset'] 134 | col_to_process = conf['dataset'].get('col_to_process', None) 135 | del conf['dataset'] 136 | 137 | topic_count = 20 138 | if 'topic_count' in conf: 139 | topic_count = conf['topic_count'] 140 | del conf['topic_count'] 141 | 142 | run( 143 | alg_name=alg_name, 144 | alg_params=conf, 145 | dataset=dataset, 146 | col_to_process=col_to_process, 147 | topic_count=topic_count 148 | ) 149 | 150 | 151 | if __name__ == "__main__": 152 | main(conf_name="gost_example") 153 | -------------------------------------------------------------------------------- /examples/graph_building_for_stroitelstvo.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | import os 4 | import uuid 5 | import pandas as pd 6 | import numpy as np 7 | from datasets import load_dataset 8 | from sklearn.model_selection import train_test_split 9 | from autotm.base import AutoTM 10 | from autotm.ontology.ontology_extractor import build_graph 11 | import networkx as nx 12 | 13 | 14 | df = pd.read_dataset('../data/sample_corpora/dataset_books_stroitelstvo.csv') 15 | 16 | working_dir_path = 'autotm_artifacts' 17 | 18 | autotm = AutoTM( 19 | topic_count=10, 20 | texts_column_name='text', 21 | preprocessing_params={ 22 | "lang": "ru", 23 | }, 24 | alg_params={ 25 | "num_iterations": 10, 26 | }, 27 | working_dir_path=working_dir_path 28 | ) 29 | 30 | mixtures = autotm.fit_predict(df) 31 | 32 | # посмотрим на получаемые темы 33 | autotm.print_topics() 34 | 35 | # Составим словарь наименований тем! 
Note that a new run may produce different mixtures of topics 36 | labels_dict = {'main0':'Кровля', 37 | 'main1': 'Фундамент', 38 | 'main2':'Техника строительства', 39 | 'main3': 'Электросбережение', 40 | 'main4':'Генплан', 41 | 'main5': 'Участок', 42 | 'main6': 'Лестница', 43 | 'main7':'Внешняя отделка', 44 | 'main8': 'Стены', 45 | 'main9': 'Отделка (покраска)', 46 | } 47 | 48 | res_dict, nodes = build_graph(autotm, topic_labels=labels_dict) 49 | 50 | # visualize the graph 51 | g = nx.DiGraph() 52 | g.add_nodes_from(nodes) 53 | for k, v in res_dict.items(): 54 | g.add_edges_from([(k, child) for child in v])  # assuming each value is a collection of nodes linked to k 55 | nx.draw(g, with_labels=True) -------------------------------------------------------------------------------- /examples/graph_profile_example.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | import os 4 | import uuid 5 | import pandas as pd 6 | import numpy as np 7 | from datasets import load_dataset 8 | from sklearn.model_selection import train_test_split 9 | from autotm.base import AutoTM 10 | from autotm.ontology.ontology_extractor import build_graph 11 | from autotm.clustering import cluster_phi 12 | import networkx as nx 13 | 14 | 15 | df = load_dataset('zloelias/lenta-ru') # https://huggingface.co/datasets 16 | text_sample = df['train']['text'] 17 | df = pd.DataFrame({'text': text_sample}) 18 | 19 | working_dir_path = 'autotm_artifacts' 20 | 21 | # Initialize the model and set the name of the surrogate 22 | # model to speed up the computations 23 | autotm = AutoTM( 24 | topic_count=20, 25 | texts_column_name='text', 26 | preprocessing_params={ 27 | "lang": "ru", 28 | }, 29 | # alg_name='ga', 30 | alg_params={ 31 | "num_iterations": 15, 32 | "surrogate_alg_name": "GPR" 33 | }, 34 | working_dir_path=working_dir_path 35 | ) 36 | 37 | mixtures = autotm.fit_predict(df) 38 | 39 | # extract the phi matrix (word-topic distribution) 40 | phi_df = autotm._model.get_phi() 41 | 42 | # cluster the obtained data; the resulting clusters 43 | # turn out to be quite well separated 44 | y_kmeans = cluster_phi(phi_df, n_clusters=10, plot_img=True)  # assuming cluster_phi returns the fitted clustering model 45 | 46 | # look at which documents fall into the clusters 47 | mixtures['labels'] = y_kmeans.labels_ 48 | res_df = df.join(mixtures) 49 | print(res_df[res_df['labels'] == 8].text.tolist()[:3]) 50 | 51 | # now let's build the graph; as we already know, 52 | # a dictionary of topic names makes the result 53 | # easier to interpret 54 | autotm.print_topics() 55 | labels_dict = {'main0':'Устройства', 56 | 'main1': 'Спорт', 57 | 'main2':'Экономика', 58 | 'main3': 'Чемпионат', 59 | 'main4':'Исследование', 60 | 'main5': 'Награждение', 61 | 'main6': 'Суд', 62 | 'main7':'Общее', 63 | 'main8': 'Авиаперелеты', 64 | 'main9': 'Музеи', 65 | 'main10': 'Правительство', 66 | 'main11': 'Интернет', 67 | 'main12': 'Искусство', 68 | 'main13': 'Война', 69 | 'main14': 'Нефть', 70 | 'main15': 'Космос', 71 | 'main16': 'Соревнования', 72 | 'main17': 'Биржа', 73 | 'main18': 'Финансы', 74 | 'main19': 'Концерт' 75 | } 76 | 77 | res_dict, nodes = build_graph(autotm, topic_labels=labels_dict) 78 | 79 | # inspect the processing result: a dictionary with the links 80 | print(res_dict) 81 | 82 | # Visualize the resulting graph 83 | g = nx.DiGraph() 84 | g.add_nodes_from(nodes) 85 | for k, v in res_dict.items(): 86 | g.add_edges_from([(k, child) for child in v]) 87 | nx.draw(g, with_labels=True) 88 | 89 | # Drill down into one of the nodes 90 | mixtures_subset = mixtures.sort_values('main0', ascending=False).head(1000) 91 | 92 | 
subset_df = df.join(mixtures_subset, how='inner')  # inner join keeps only the selected documents 93 | 94 | autotm_subs = AutoTM( 95 | topic_count=6, 96 | texts_column_name='text', 97 | preprocessing_params={ 98 | "lang": "ru", 99 | }, 100 | alg_params={ 101 | "num_iterations": 10, 102 | "surrogate_alg_name": "GPR" 103 | }, 104 | working_dir_path=working_dir_path 105 | ) 106 | 107 | autotm_subs.fit_predict(subset_df) 108 | autotm_subs.print_topics()  # the model has to be fitted before its topics can be printed 109 | subs_labels = { 110 | 'main0': 'Технологии', 111 | 'main1': 'Разработка', 112 | 'main2': 'Игры', 113 | 'main3': 'Компьютеры', 114 | 'main4': 'Мобильные', 115 | 'main5': 'Переферия' 116 | } 117 | 118 | res_dict_subs, nodes_subs = build_graph(autotm_subs, topic_labels=subs_labels) 119 | 120 | g = nx.DiGraph() 121 | g.add_nodes_from(nodes_subs) 122 | for k, v in res_dict_subs.items(): 123 | g.add_edges_from([(k, child) for child in v]) 124 | nx.draw(g, with_labels=True) -------------------------------------------------------------------------------- /examples/topic_modeling_of_corporative_data.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import os 4 | import uuid 5 | from typing import Dict, Any, Optional 6 | 7 | import pandas as pd 8 | from sklearn.model_selection import train_test_split 9 | 10 | from autotm.base import AutoTM 11 | 12 | logging.basicConfig( 13 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 14 | level=logging.DEBUG, datefmt='%Y-%m-%d %H:%M:%S' 15 | ) 16 | logger = logging.getLogger() 17 | 18 | 19 | def main(): 20 | df = pd.read_csv('data/sample_corpora/clean_docs_v17_gost_only.csv') 21 | train_df, test_df = train_test_split(df, test_size=0.1) 22 | 23 | working_dir_path = f"./autotm_workdir_{uuid.uuid4()}" 24 | model_path = os.path.join(working_dir_path, "autotm_model") 25 | 26 | autotm = AutoTM( 27 | topic_count=50, 28 | texts_column_name='paragraph', 29 | preprocessing_params={ 30 | "lang": "ru" 31 | }, 32 | alg_name="ga", 33 | alg_params={ 34 | "num_iterations": 10, 35 | "use_pipeline": True 36 | }, 37 | individual_type="llm", 38 | working_dir_path=working_dir_path, 39 | exp_dataset_name="gost" 40 | ) 41 | mixtures = autotm.fit_predict(train_df) 42 | 43 | logger.info(f"Calculated train mixtures: {mixtures.shape}\n\n{mixtures.head(10).to_string()}") 44 | 45 | # saving the model 46 | autotm.save(model_path) 47 | 48 | # loading and checking if everything is fine with predicting 49 | autotm_loaded = AutoTM.load(model_path) 50 | mixtures = autotm_loaded.predict(test_df) 51 | 52 | logger.info(f"Calculated test mixtures: {mixtures.shape}\n\n{mixtures.head(10).to_string()}") 53 | 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /logging.config.yml: -------------------------------------------------------------------------------- 1 | version: 1 2 | console_log: 3 | level: DEBUG 4 | formatters: 5 | simple: 6 | class: logging.Formatter 7 | format: "%(asctime)s %(name)s %(levelname)s %(message)s" 8 | datefmt: "%Y-%m-%d %H:%M:%S" 9 | handlers: 10 | file_handler: 11 | class: logging.FileHandler 12 | filename: example.log 13 | level: DEBUG 14 | formatter: simple 15 | stream_handler: 16 | class: logging.StreamHandler 17 | stream: ext://sys.stderr 18 | level: cfg://console_log.level 19 | formatter: simple 20 | loggers: 21 | logging_example: 22 | level: DEBUG 23 | handlers: [file_handler] 24 | propagate: yes 25 | root: 26 | level: DEBUG 27 | handlers: [stream_handler]
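The YAML above is a standard logging dictConfig schema (version 1). A minimal sketch, not part of the repository, of how such a file could be applied at start-up; it assumes PyYAML (already a declared dependency) and that the file is read from the repository root as logging.config.yml:

import logging
import logging.config

import yaml

# load the YAML config and hand it to dictConfig, which resolves the
# cfg://console_log.level reference used by stream_handler
with open("logging.config.yml") as f:
    log_conf = yaml.safe_load(f)

logging.config.dictConfig(log_conf)
logging.getLogger("logging_example").debug("logging configured from logging.config.yml")

Keeping the shared level under the custom console_log key means the console verbosity can be changed in one place without editing each handler entry.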
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "autotm" 3 | version = "0.2.3.4" 4 | description = "Automatic hyperparameters tuning for topic models (ARTM approach) using evolutionary algorithms" 5 | authors = [ 6 | "Khodorchenko Maria ", 7 | "Nikolay Butakov alipoov.nb@gmail.com" 8 | ] 9 | readme = "README.md" 10 | license = "Apache-2.0" 11 | homepage = "https://autotm.readthedocs.io/en/latest/" 12 | repository = "https://github.com/ngc436/AutoTM" 13 | packages = [ 14 | {include = "autotm"} 15 | ] 16 | 17 | 18 | [tool.poetry.dependencies] 19 | python = ">=3.9, <3.12" 20 | bigartm = "0.9.2" 21 | protobuf = "<=3.20.0" 22 | tqdm = "4.50.2" 23 | numpy = "*" 24 | PyYAML = "5.3.1" 25 | dataclasses-json = "*" 26 | mlflow = "*" 27 | click = "8.0.1" 28 | scipy = "<=1.10.1" 29 | hyperopt = "*" 30 | pymystem3 = "*" 31 | nltk = "*" 32 | plotly = "*" 33 | spacy = ">=3.5" 34 | spacy-langdetect = "*" 35 | gensim = "^4.1.2" 36 | pandas = "*" 37 | billiard = "*" 38 | dill = "*" 39 | pytest = "*" 40 | celery = ">=4.4.7" 41 | redis = "3.5.3" 42 | jinja2 = "3.0" 43 | PyMySQL = "*" 44 | psycopg2-binary = "*" 45 | pymongo = "3.11.3" 46 | scikit-learn = "^1.1.1" 47 | #pydantic = "1.10.8" 48 | pydantic = "2.6.0" 49 | openai = "^1.31.0" 50 | seaborn = "^0.13.2" 51 | 52 | [tool.poetry.dev-dependencies] 53 | black = "*" 54 | sphinx = "*" 55 | flake8 = "*" 56 | 57 | [tool.pytest.ini_options] 58 | log_cli = true 59 | log_cli_level = "INFO" 60 | log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" 61 | log_cli_date_format = "%Y-%m-%d %H:%M:%S" 62 | 63 | [tool.poetry.scripts] 64 | autotmctl = 'autotm.main:cli' 65 | fitness-worker = 'autotm.main_fitness_worker:main' 66 | 67 | [build-system] 68 | requires = ["poetry-core"] 69 | build-backend = "poetry.core.masonry.api" 70 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/algs/full_pipeline_example.py: -------------------------------------------------------------------------------- 1 | # input example_data padas df with 'text' column 2 | import logging 3 | import time 4 | 5 | import pandas as pd 6 | 7 | from autotm.algorithms_for_tuning.genetic_algorithm.genetic_algorithm import ( 8 | run_algorithm, 9 | ) 10 | from autotm.preprocessing.dictionaries_preparation import prepare_all_artifacts 11 | from autotm.preprocessing.text_preprocessing import process_dataset 12 | 13 | 14 | logging.basicConfig(level=logging.DEBUG) 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | PATH_TO_DATASET = "../../data/sample_corpora/sample_dataset_lenta.csv" # dataset with corpora to be processed 19 | SAVE_PATH = ( 20 | "../data/processed_sample_corpora" # place where all the artifacts will be stored 21 | ) 22 | 23 | dataset = pd.read_csv(PATH_TO_DATASET) 24 | col_to_process = "text" 25 | dataset_name = "lenta_sample" 26 | lang = "ru" # available languages: ru, en 27 | min_tokens_num = 3 # the minimal amount of tokens after processing to save the result 28 | num_iterations = 2 29 | topic_count = 10 30 | exp_id = int(time.time()) 31 | print(exp_id) 32 | 33 | 
use_nelder_mead_in_mutation = False 34 | use_nelder_mead_in_crossover = False 35 | use_nelder_mead_in_selector = False 36 | train_option = "offline" 37 | 38 | if __name__ == "__main__": 39 | logger.info("Stage 1: Dataset preparation") 40 | process_dataset( 41 | PATH_TO_DATASET, 42 | col_to_process, 43 | SAVE_PATH, 44 | lang, 45 | min_tokens_count=min_tokens_num, 46 | ) 47 | 48 | logger.info("Stage 2: Prepare all artefacts") 49 | prepare_all_artifacts(SAVE_PATH) 50 | 51 | logger.info("Stage 3: Tuning the topic model") 52 | # exp_id and dataset_name will be needed further to store results in mlflow 53 | best_result = run_algorithm( 54 | data_path=SAVE_PATH, 55 | dataset=dataset_name, 56 | exp_id=exp_id, 57 | topic_count=topic_count, 58 | log_file="./log_file_test.txt", 59 | num_iterations=num_iterations, 60 | use_nelder_mead_in_mutation=use_nelder_mead_in_mutation, 61 | use_nelder_mead_in_crossover=use_nelder_mead_in_crossover, 62 | use_nelder_mead_in_selector=use_nelder_mead_in_selector, 63 | train_option=train_option, 64 | ) 65 | logger.info("All finished") 66 | -------------------------------------------------------------------------------- /scripts/algs/nelder_mead_experiments.py: -------------------------------------------------------------------------------- 1 | from autotm.algorithms_for_tuning.nelder_mead_optimization.nelder_mead import ( 2 | NelderMeadOptimization, 3 | ) 4 | import time 5 | import pandas as pd 6 | import os 7 | 8 | # 2 9 | 10 | fnames = [ 11 | "20newsgroups_sample", 12 | "amazon_food_sample", 13 | "banners_sample", 14 | "hotel-reviews_sample", 15 | "lenta_ru_sample", 16 | ] 17 | data_lang = ["en", "en", "ru", "en", "ru"] 18 | dataset_id = 0 19 | 20 | DATA_PATH = os.path.join("/ess_data/GOTM/datasets_TM_scoring", fnames[dataset_id]) 21 | 22 | PATH_TO_DATASET = os.path.join( 23 | DATA_PATH, "dataset_processed.csv" 24 | ) # dataset with corpora to be processed 25 | SAVE_PATH = DATA_PATH # place where all the artifacts will be stored 26 | 27 | dataset = pd.read_csv(PATH_TO_DATASET) 28 | dataset_name = fnames[dataset_id] + "sample_with_nm" 29 | lang = data_lang[dataset_id] # available languages: ru, en 30 | min_tokens_num = 3 # the minimal amount of tokens after processing to save the result 31 | num_iterations = 200 32 | topic_count = 10 33 | exp_id = int(time.time()) 34 | print(exp_id) 35 | train_option = "offline" 36 | 37 | if __name__ == "__main__": 38 | nelder_opt = NelderMeadOptimization( 39 | data_path=SAVE_PATH, 40 | dataset=dataset_name, 41 | exp_id=exp_id, 42 | topic_count=topic_count, 43 | train_option=train_option, 44 | ) 45 | 46 | res = nelder_opt.run_algorithm(num_iterations=num_iterations) 47 | print(res) 48 | print(-res.fun) 49 | -------------------------------------------------------------------------------- /scripts/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/experiments/__init__.py -------------------------------------------------------------------------------- /scripts/experiments/plot_boxplot_Final_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/experiments/plot_boxplot_Final_results.png -------------------------------------------------------------------------------- /scripts/experiments/plot_boxplot_Start_results.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/experiments/plot_boxplot_Start_results.png -------------------------------------------------------------------------------- /scripts/experiments/plot_progress_hotel-reviews_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/experiments/plot_progress_hotel-reviews_sample.png -------------------------------------------------------------------------------- /scripts/experiments/statistics/240602-124458_hotel-reviews_sample_fixed_progress.txt: -------------------------------------------------------------------------------- 1 | hotel-reviews_sample,False,4,0.6016707242187753 2 | hotel-reviews_sample,False,5,0.6016707242187753 3 | hotel-reviews_sample,False,6,0.6016707242187753 4 | hotel-reviews_sample,False,7,0.6016707242187753 5 | hotel-reviews_sample,False,8,0.6016707242187753 6 | hotel-reviews_sample,False,9,0.6016707242187753 7 | hotel-reviews_sample,False,10,0.6222956395585306 8 | -------------------------------------------------------------------------------- /scripts/experiments/statistics/240602-124539_hotel-reviews_sample_fixed_progress.txt: -------------------------------------------------------------------------------- 1 | hotel-reviews_sample,False,4,0.6364207556768978 2 | hotel-reviews_sample,False,5,0.6364207556768978 3 | hotel-reviews_sample,False,6,0.6364207556768978 4 | hotel-reviews_sample,False,7,0.6364207556768978 5 | hotel-reviews_sample,False,8,0.6364207556768978 6 | hotel-reviews_sample,False,9,0.6364207556768978 7 | hotel-reviews_sample,False,10,0.6364207556768978 8 | -------------------------------------------------------------------------------- /scripts/experiments/statistics/240602-124633_hotel-reviews_sample_fixed_progress.txt: -------------------------------------------------------------------------------- 1 | hotel-reviews_sample,False,4,0.6364207556768978 2 | hotel-reviews_sample,False,5,0.6364207556768978 3 | hotel-reviews_sample,False,6,0.6364207556768978 4 | hotel-reviews_sample,False,7,0.6364207556768978 5 | hotel-reviews_sample,False,8,0.6364207556768978 6 | hotel-reviews_sample,False,9,0.6364207556768978 7 | hotel-reviews_sample,False,10,0.6364207556768978 8 | -------------------------------------------------------------------------------- /scripts/experiments/statistics/240602-124719_hotel-reviews_sample_pipeline_progress.txt: -------------------------------------------------------------------------------- 1 | hotel-reviews_sample,True,4,0.5368079177080817 2 | hotel-reviews_sample,True,5,0.5368079177080817 3 | hotel-reviews_sample,True,6,0.5368079177080817 4 | hotel-reviews_sample,True,7,0.5368079177080817 5 | hotel-reviews_sample,True,8,0.5368079177080817 6 | hotel-reviews_sample,True,9,0.5368079177080817 7 | hotel-reviews_sample,True,10,0.8378543293795918 8 | -------------------------------------------------------------------------------- /scripts/experiments/statistics/240602-124816_hotel-reviews_sample_pipeline_progress.txt: -------------------------------------------------------------------------------- 1 | hotel-reviews_sample,True,4,0.7293230550311838 2 | hotel-reviews_sample,True,5,0.7293230550311838 3 | hotel-reviews_sample,True,6,0.7293230550311838 4 | hotel-reviews_sample,True,7,0.7293230550311838 5 | 
hotel-reviews_sample,True,8,0.7293230550311838 6 | hotel-reviews_sample,True,9,0.7293230550311838 7 | hotel-reviews_sample,True,10,0.7388121081634285 8 | -------------------------------------------------------------------------------- /scripts/experiments/statistics/240602-124950_hotel-reviews_sample_pipeline_progress.txt: -------------------------------------------------------------------------------- 1 | hotel-reviews_sample,True,4,0.6729334700766527 2 | hotel-reviews_sample,True,5,0.6729334700766527 3 | hotel-reviews_sample,True,6,0.6729334700766527 4 | hotel-reviews_sample,True,7,0.6729334700766527 5 | hotel-reviews_sample,True,8,0.6729334700766527 6 | hotel-reviews_sample,True,9,0.6729334700766527 7 | hotel-reviews_sample,True,10,0.6729334700766527 8 | -------------------------------------------------------------------------------- /scripts/experiments/statistics/240603-231051_hotel-reviews_sample_fixed_parameters.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/experiments/statistics/240603-231051_hotel-reviews_sample_fixed_parameters.txt -------------------------------------------------------------------------------- /scripts/experiments/statistics/240603-231154_hotel-reviews_sample_fixed_progress.txt: -------------------------------------------------------------------------------- 1 | hotel-reviews_sample,False,4,0.6783270373502042 2 | hotel-reviews_sample,False,5,0.6783270373502042 3 | hotel-reviews_sample,False,6,0.6783270373502042 4 | hotel-reviews_sample,False,7,0.6783270373502042 5 | hotel-reviews_sample,False,8,0.682302624412245 6 | hotel-reviews_sample,False,9,0.682302624412245 7 | hotel-reviews_sample,False,10,0.682302624412245 8 | -------------------------------------------------------------------------------- /scripts/experiments/statistics/240603-231313_hotel-reviews_sample_pipeline_progress.txt: -------------------------------------------------------------------------------- 1 | hotel-reviews_sample,True,4,0.7047850565368978 2 | hotel-reviews_sample,True,5,0.7047850565368978 3 | hotel-reviews_sample,True,6,0.7047850565368978 4 | hotel-reviews_sample,True,7,0.7047850565368978 5 | hotel-reviews_sample,True,8,0.7258411930803266 6 | hotel-reviews_sample,True,9,0.7258411930803266 7 | hotel-reviews_sample,True,10,0.7258411930803266 8 | -------------------------------------------------------------------------------- /scripts/experiments/statistics/240603-232451_hotel-reviews_sample_fixed_progress.txt: -------------------------------------------------------------------------------- 1 | hotel-reviews_sample,False,4,0.652324503842612 2 | hotel-reviews_sample,False,5,0.652324503842612 3 | hotel-reviews_sample,False,6,0.652324503842612 4 | hotel-reviews_sample,False,7,0.652324503842612 5 | hotel-reviews_sample,False,8,0.652324503842612 6 | hotel-reviews_sample,False,9,0.652324503842612 7 | hotel-reviews_sample,False,10,0.6837248912219591 8 | -------------------------------------------------------------------------------- /scripts/experiments/statistics/240603-232546_hotel-reviews_sample_fixed_progress.txt: -------------------------------------------------------------------------------- 1 | hotel-reviews_sample,False,4,0.7727277367346939 2 | hotel-reviews_sample,False,5,0.7727277367346939 3 | hotel-reviews_sample,False,6,0.7727277367346939 4 | hotel-reviews_sample,False,7,0.7727277367346939 5 | hotel-reviews_sample,False,8,0.7727277367346939 6 | 
hotel-reviews_sample,False,9,0.7727277367346939 7 | hotel-reviews_sample,False,10,0.7727277367346939 8 | -------------------------------------------------------------------------------- /scripts/experiments/statistics/240603-232733_hotel-reviews_sample_fixed_progress.txt: -------------------------------------------------------------------------------- 1 | hotel-reviews_sample,False,4,0.8807936593885717 2 | hotel-reviews_sample,False,5,0.8807936593885717 3 | hotel-reviews_sample,False,6,0.8807936593885717 4 | hotel-reviews_sample,False,7,0.8807936593885717 5 | hotel-reviews_sample,False,8,0.8807936593885717 6 | hotel-reviews_sample,False,9,0.8807936593885717 7 | hotel-reviews_sample,False,10,0.8807936593885717 8 | -------------------------------------------------------------------------------- /scripts/experiments/statistics/240603-232847_hotel-reviews_sample_pipeline_parameters.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/experiments/statistics/240603-232847_hotel-reviews_sample_pipeline_parameters.txt -------------------------------------------------------------------------------- /scripts/experiments/statistics/240603-232916_hotel-reviews_sample_fixed_progress.txt: -------------------------------------------------------------------------------- 1 | hotel-reviews_sample,False,4,0.8121951826857141 2 | hotel-reviews_sample,False,5,0.8121951826857141 3 | hotel-reviews_sample,False,6,0.8121951826857141 4 | hotel-reviews_sample,False,7,0.8121951826857141 5 | hotel-reviews_sample,False,8,0.8121951826857141 6 | hotel-reviews_sample,False,9,0.8121951826857141 7 | hotel-reviews_sample,False,10,0.8121951826857141 8 | -------------------------------------------------------------------------------- /scripts/experiments/statistics/240603-233132_hotel-reviews_sample_pipeline_progress.txt: -------------------------------------------------------------------------------- 1 | hotel-reviews_sample,True,4,0.6630215941235102 2 | hotel-reviews_sample,True,5,0.6630215941235102 3 | hotel-reviews_sample,True,6,0.6630215941235102 4 | hotel-reviews_sample,True,7,0.6630215941235102 5 | hotel-reviews_sample,True,8,0.6630215941235102 6 | hotel-reviews_sample,True,9,0.6995545635590203 7 | hotel-reviews_sample,True,10,0.6995545635590203 8 | -------------------------------------------------------------------------------- /scripts/other/preparation_pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path 3 | 4 | from autotm.preprocessing import PREPOCESSED_DATASET_FILENAME 5 | from autotm.preprocessing.dictionaries_preparation import prepare_all_artifacts, prepearing_cooc_dict 6 | from autotm.preprocessing.text_preprocessing import process_dataset 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | PATH_TO_DATASET = "../../data/sample_corpora/sample_dataset_lenta.csv" # dataset with corpora to be processed 12 | SAVE_PATH = "../../data/processed_sample_corpora" # place where all the artifacts will be stored 13 | 14 | col_to_process = "text" 15 | dataset_name = "lenta_sample" 16 | lang = "ru" # available languages: ru, en 17 | min_tokens_num = 3 # the minimal amount of tokens after processing to save the result 18 | 19 | 20 | if __name__ == "__main__": 21 | logger.info("Stage 1: Dataset preparation") 22 | if not os.path.exists(SAVE_PATH): 23 | process_dataset( 24 | PATH_TO_DATASET, 25 | col_to_process, 26 | 
SAVE_PATH, 27 | lang, 28 | min_tokens_count=min_tokens_num, 29 | ) 30 | else: 31 | logger.info("The preprocessed dataset already exists. Found files on path: %s" % SAVE_PATH) 32 | 33 | logger.info("Stage 2: Prepare all artefacts") 34 | prepare_all_artifacts(SAVE_PATH) 35 | 36 | logger.info("All finished") 37 | -------------------------------------------------------------------------------- /scripts/topic_modeling_of_corporative_data.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | import os 4 | import uuid 5 | import pandas as pd 6 | import numpy as np 7 | from datasets import load_dataset 8 | from sklearn.model_selection import train_test_split 9 | from autotm.base import AutoTM 10 | from autotm.ontology.ontology_extractor import build_graph 11 | import networkx as nx 12 | 13 | 14 | df = pd.read_csv('../data/sample_corpora/clean_docs_v17_gost_only.csv') 15 | 16 | working_dir_path = 'autotm_artifacts' 17 | 18 | autotm = AutoTM( 19 | topic_count=50, 20 | texts_column_name='paragraph', 21 | preprocessing_params={ 22 | "lang": "ru", 23 | }, 24 | alg_params={ 25 | "num_iterations": 10, 26 | }, 27 | working_dir_path=working_dir_path 28 | ) 29 | 30 | mixtures = autotm.fit_predict(df) 31 | 32 | # let's look at the resulting topics 33 | autotm.print_topics() -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/tests/integration/__init__.py -------------------------------------------------------------------------------- /tests/integration/test_fit_predict.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import tempfile 4 | 5 | import pandas as pd 6 | import pytest 7 | from numpy.typing import ArrayLike 8 | from sklearn.model_selection import train_test_split 9 | 10 | from autotm.base import AutoTM 11 | 12 | logging.basicConfig(level=logging.INFO) 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def check_predictions(autotm: AutoTM, df: pd.DataFrame, mixtures: ArrayLike): 17 | n_samples, n_samples_mixture = df.shape[0], mixtures.shape[0] 18 | n_topics, n_topics_mixture = len(autotm.topics), mixtures.shape[1] 19 | 20 | assert n_samples_mixture == n_samples 21 | assert n_topics_mixture >= n_topics 22 | assert (~mixtures.isna()).all().all() 23 | assert (~mixtures.isnull()).all().all() 24 | 25 | 26 | # two calls to AutoTM in the same process are not supported due to problems with ARTM deadlocks 27 | @pytest.mark.parametrize('lang,dataset_path', [ 28 | ('en', 'data/sample_corpora/imdb_100.csv'), 29 | # ('ru', 'data/sample_corpora/sample_dataset_lenta.csv'), 30 | ], ids=['imdb_100']) 31 | # ], ids=['imdb_100', 'lenta_ru']) 32 | def test_fit_predict(pytestconfig, lang, dataset_path): 33 | # dataset with corpora to be processed 34 | path_to_dataset = os.path.join(pytestconfig.rootpath, dataset_path) 35 | alg_name = "ga" 36 | 37 | df = pd.read_csv(path_to_dataset) 38 | train_df, test_df = train_test_split(df, test_size=0.1) 39 | 40 | with tempfile.TemporaryDirectory(prefix="fp_tmp_working_dir_") as tmp_working_dir: 41 | model_path = os.path.join(tmp_working_dir, "autotm_model") 42 | 43 | autotm = AutoTM( 44 | preprocessing_params={ 45 | "lang": lang 46 | }, 47 | alg_name=alg_name, 48 | alg_params={ 49 | "num_iterations": 2, 50 | 
"num_individuals": 4, 51 | }, 52 | working_dir_path=tmp_working_dir 53 | ) 54 | mixtures = autotm.fit_predict(train_df) 55 | check_predictions(autotm, train_df, mixtures) 56 | 57 | # saving the model 58 | autotm.save(model_path) 59 | 60 | # loading and checking if everything is fine with predicting 61 | autotm_loaded = AutoTM.load(model_path) 62 | mixtures = autotm_loaded.predict(test_df) 63 | check_predictions(autotm, test_df, mixtures) 64 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/conftest.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os.path 3 | from typing import Dict 4 | 5 | import pytest 6 | 7 | from autotm.fitness.tm import ENV_AUTOTM_LLM_API_KEY 8 | 9 | 10 | def parse_vw(path: str) -> Dict[str, Dict[str, float]]: 11 | result = dict() 12 | 13 | with open(path, "r") as f: 14 | for line in f.readlines(): 15 | elements = line.split(" ") 16 | word, tokens = elements[0], elements[1:] 17 | if word in result: 18 | raise ValueError("The word is repeated") 19 | result[word] = {token.split(':')[0]: float(token.split(':')[1]) for token in tokens} 20 | 21 | pairs = ((min(word_1, word_2), max(word_1, word_2), value) for word_1, pairs in result.items() for word_2, value in pairs.items()) 22 | pairs = sorted(pairs, key=lambda x: x[0]) 23 | gpairs = itertools.groupby(pairs, key=lambda x: x[0]) 24 | ordered_result = {word_1: {word_2: value for _, word_2, value in pps} for word_1, pps in gpairs} 25 | return ordered_result 26 | 27 | 28 | @pytest.fixture(scope="session") 29 | def test_corpora_path(pytestconfig: pytest.Config) -> str: 30 | return os.path.join(pytestconfig.rootpath, "data", "processed_lenta_ru_sample_corpora") 31 | 32 | @pytest.fixture(scope="session") 33 | def openai_api_key() -> str: 34 | if ENV_AUTOTM_LLM_API_KEY not in os.environ: 35 | raise ValueError(f"Env var {ENV_AUTOTM_LLM_API_KEY} with openai API key is not set") 36 | return os.environ[ENV_AUTOTM_LLM_API_KEY] 37 | -------------------------------------------------------------------------------- /tests/unit/test_dictionaries_preparation.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | 4 | from autotm.preprocessing.dictionaries_preparation import ( 5 | _add_word_to_dict) 6 | 7 | DATASET_PROCESSED_TINY = pd.DataFrame( 8 | { 9 | "processed_text": [ 10 | "this is text for testing purposes", 11 | "test the testing test to test the test", 12 | "the text is a good example of", 13 | ] 14 | } 15 | ) 16 | 17 | 18 | @pytest.fixture() 19 | def tiny_dataset_(tmpdir): 20 | dataset_processed = tmpdir.join("dataset.txt") 21 | dataset_processed.write(DATASET_PROCESSED_TINY) 22 | return dataset_processed 23 | 24 | 25 | def test__add_word_to_dict(): 26 | test_dict = {} 27 | assert _add_word_to_dict('test', test_dict) == {'test': 1} 28 | -------------------------------------------------------------------------------- /tests/unit/test_llm_fitness.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from autotm.fitness.tm import estimate_topics_with_llm, ENV_AUTOTM_LLM_API_KEY 6 | 
7 | 8 | @pytest.mark.skipif(ENV_AUTOTM_LLM_API_KEY not in os.environ, reason=F"ChatGPT API key is not available. {ENV_AUTOTM_LLM_API_KEY} is not set.") 9 | def test_llm_fitness_estimation(openai_api_key: str): 10 | topics = { 11 | "main0": 'entry launch remark rule satellite european build contest player test francis given author canadian cipher', 12 | "main1": 'engineer newspaper holland chapter staff douglas princeton tempest colostate senior executive jose nixon phoenix chemical', 13 | "main3": 'religion atheist belief religious atheism faith christianity existence strong universe follow theist become accept statement', 14 | "main4": 'population bullet arab israeli border village muslim thousand slaughter policy daughter wife authorities switzerland religious', 15 | "main6": 'unix directory comp package library email linux workstation graphics vendor editor user hardware export product', 16 | "main7": 'woman building city left child home face helmet apartment kill wife azerbaijani live father later', 17 | "main10": 'attack lebanese muslim hernlem israeli left troops peace fire away quite stop religion israel especially', 18 | "main11": 'science holland cult compass study tempest investigation methodology nixon psychology department star left colostate scientific', 19 | "main12": 'создавать мужчина премьер добавлять причина оставаться клуб александр сергей закон идти комитет безопасность национальный предлагать', 20 | "main20": 'победа россиянин чемпион американец ассоциация встречаться завершаться килограмм побеждать карьера поражение состояться всемирный категория боец одерживать поедино суметь соперник проигрывать', 21 | "back0": 'garbage garbage garbage', 22 | "back1": 'trash trash trash' 23 | } 24 | topics = {k: v.split(' ') for k, v in topics.items()} 25 | 26 | fitness = estimate_topics_with_llm( 27 | model_id="test_model", 28 | topics=topics, 29 | api_key=openai_api_key, 30 | max_estimated_topics=4, 31 | estimations_per_topic=3 32 | ) 33 | 34 | print(f"Fitness: {fitness}") 35 | 36 | assert fitness > 0 37 | -------------------------------------------------------------------------------- /tests/unit/test_preprocessing.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from autotm.preprocessing.text_preprocessing import (remove_html, 4 | process_punkt, 5 | get_lemma, lemmatize_text_ru) 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "input,expected_output", 10 | [ 11 | ( 12 | '

<p>Text size using points.</p>' 13 | '<p>Text size using pixels.</p><p>Text size using relative sizes.</p>
', 14 | 'Text size using points.Text size using pixels.Text size using relative sizes.' 15 | ), 16 | ( 17 | '', '' 18 | ) 19 | ] 20 | ) 21 | def test_html_removal(input, expected_output): 22 | """Test html is removed from text""" 23 | assert remove_html(input) == expected_output 24 | 25 | 26 | @pytest.mark.parametrize( 27 | "input,expected_output", 28 | [ 29 | ( 30 | " lots of space ", "lots of space" 31 | ), 32 | ( 33 | "No#, punctuation? id4321 removed", "No punctuation removed" 34 | ) 35 | ] 36 | ) 37 | def test_process_punct(input, expected_output): 38 | """Test punctuation removal is correct""" 39 | assert process_punkt(input) == expected_output 40 | 41 | 42 | @pytest.mark.parametrize( 43 | "input,expected_output", 44 | [ 45 | ("id19898", "id19898"), 46 | ("нормальные", "нормальные"), 47 | ("dogs", "dog"), 48 | ("toughest", "tough") 49 | ] 50 | ) 51 | def test_get_lemma(input, expected_output): 52 | """Test lemmatization works as intended""" 53 | assert get_lemma(input) == expected_output 54 | 55 | 56 | @pytest.mark.parametrize( 57 | "input,expected_output", 58 | [ 59 | ("Эти тексты для пользователя id198890, потому что он пес", "текст пользователь пес"), 60 | ("Оболочка работает с CPython 2.6+/3.3+ и PyPy 1.9+", "оболочка работать cpython pypy") 61 | ] 62 | ) 63 | def test_full_russian_processing(input, expected_output): 64 | """Test preprocessing for russian""" 65 | assert lemmatize_text_ru(input) == expected_output 66 | -------------------------------------------------------------------------------- /toloka/Estimate_topics_interpretability/input-data.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_id": { 3 | "type": "string", 4 | "hidden": true, 5 | "required": true 6 | }, 7 | "wordset": { 8 | "type": "string", 9 | "hidden": false, 10 | "required": true 11 | }, 12 | "model_id": { 13 | "type": "string", 14 | "hidden": true, 15 | "required": true 16 | }, 17 | "topic_id": { 18 | "type": "string", 19 | "hidden": true, 20 | "required": true 21 | }, 22 | "dataset_name": { 23 | "type": "string", 24 | "hidden": true, 25 | "required": true 26 | }, 27 | "correct_bad_words": { 28 | "type": "json", 29 | "hidden": false, 30 | "required": false 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /toloka/Estimate_topics_interpretability/instructions.md: -------------------------------------------------------------------------------- 1 | Read the set of words. Try to uderstand if they correspond to a single common topic (may be with exception of a few words that don't).
 If it is true, give the topic a name with one word or a short sentence.
Optionally, identify the words that don't belong to the common topic among the wordset.

Examples:  

  • Good topic:
    run, swimming, athlete, ski, training, exercises, paddle, referee

    Topic name: sport
  • Good topic (but with "bad" words that don't belong to the common topic):
    greenhouse, chicken, tomato, eggs, 

    Topic name: 
  • Bad topic (cannot be identified)
  • Bad topic (mixing of two or more topics)




  • Advertising and spam. This category includes texts that ask users to go to an external resource or buy a product, offer earnings, or advertise dating sites. Such messages often contain shortened links to various sites.
  • Nonsense. The text is a meaningless set of characters or words. Emoticons and emojis, nicknames and hashtags don't belong to this category.
  • Insults. This category includes insults and threats that are clearly targeted at a user.
  • Violation of the law. The text incites extremist activities, violence, criminal activity, or hatred based on gender, race, nationality, or belonging to a social group; or the text promotes suicide, drugs, or the sale of weapons.
  • Profanity. This category includes comments that contain obscenities or profanity.
2 | -------------------------------------------------------------------------------- /toloka/Estimate_topics_interpretability/output-data.json: -------------------------------------------------------------------------------- 1 | { 2 | "quality": { 3 | "type": "string", 4 | "hidden": false, 5 | "required": true 6 | }, 7 | "bad_words": { 8 | "type": "json", 9 | "hidden": false, 10 | "required": true 11 | }, 12 | "topic_name": { 13 | "type": "string", 14 | "hidden": false, 15 | "required": false 16 | }, 17 | "golden_bad_words": { 18 | "type": "boolean", 19 | "hidden": false, 20 | "required": true 21 | }, 22 | "golden_binary_quality": { 23 | "type": "boolean", 24 | "hidden": false, 25 | "required": true 26 | } 27 | } 28 | 29 | -------------------------------------------------------------------------------- /toloka/Estimate_topics_interpretability/task.css: -------------------------------------------------------------------------------- 1 | /* Task on the page */ 2 | .task { 3 | border: 1px solid #ccc; 4 | width: 500px; 5 | padding: 15px; 6 | display: inline-block; 7 | } 8 | 9 | .tsk-block { 10 | border-radius: 3px; 11 | margin-bottom: 10px; 12 | } 13 | 14 | .obj-text { 15 | border: 1px solid #ccc; 16 | padding: 15px 15px 15px 71px; 17 | position: relative; 18 | background-color: #e6f7dc; 19 | } 20 | 21 | /* Quotation mark */ 22 | .quote-sign { 23 | background-image: url('data:image/svg+xml;utf8,'); 24 | background-position: center center; 25 | background-repeat: no-repeat; 26 | background-size: contain; 27 | width: 36px; 28 | height: 36px; 29 | position: absolute; 30 | top: 7px; 31 | left: 15px; 32 | } 33 | 34 | .tsk-block fieldset { 35 | padding: 10px 20px; 36 | border-radius: 3px; 37 | border: 1px solid #ccc; 38 | margin: 0; 39 | } 40 | 41 | .tsk-block legend { 42 | font-weight: bold; 43 | padding: 0 6px; 44 | } 45 | 46 | .field_type_checkbox { 47 | display: block; 48 | } 49 | 50 | .task__error { 51 | border-radius: 3px; 52 | } 53 | 54 | .second_scale { 55 | display: none; 56 | } 57 | 58 | /* Displaying task content on mobile devices */ 59 | @media screen and (max-width: 600px) { 60 | .task-suite { 61 | padding: 0; 62 | } 63 | 64 | .task { 65 | width: 100%; 66 | margin: 0; 67 | } 68 | 69 | .task-suite div:not(:last-child) { 70 | margin-bottom: 10px; 71 | } 72 | 73 | .hint_label, 74 | .field__hotkey { 75 | display: none; 76 | } 77 | 78 | .field_type_checkbox { 79 | white-space: normal; 80 | } 81 | 82 | .quote-sign { 83 | width: 20px; 84 | height: 20px; 85 | top: 13px; 86 | } 87 | 88 | .obj-text { 89 | padding: 15px 10px 15px 41px; 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /toloka/Estimate_topics_interpretability/task.html: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | {{{wordset}}} 5 |
6 | 7 | 8 |
9 |
10 | Can you name a single topic that is represented by this set of words? 11 | {{field type="radio" name="quality" value="good" size="L" label="Yes" hotkey="1" class="yes"}} 12 | {{field type="radio" name="quality" value="rather_good" size="L" label="Rather yes" hotkey="1" class="yes"}} 13 | {{field type="radio" name="quality" value="rather_bad" size="L" label="Rather no" hotkey="3" class="no"}} 14 | {{field type="radio" name="quality" value="bad" size="L" label="No" hotkey="4" class="no"}} 15 |
16 |
17 | 18 | 19 |
20 | Give a name in one or a few words to the topic described by the wordset 21 | {{field type="input" name="topic_name" value="" placeholder="Topic name" validation-show="right-center"}} 22 | 23 | 
24 | Which words are out of the topic? 25 | {{#each words}} 26 | {{field type="checkbox" name=(concat "bad_words." this.name) label=this.title}} 27 | {{/each}} 28 |
29 |
30 | -------------------------------------------------------------------------------- /toloka/Estimate_topics_interpretability/task.js: -------------------------------------------------------------------------------- 1 | exports.Task = extend(TolokaHandlebarsTask, function(options) { 2 | TolokaHandlebarsTask.call(this, options); 3 | }, { 4 | is_good_topic: function(solution) { 5 | const positives = ['good', 'rather_good'] 6 | return positives.includes(solution.output_values['quality']) 7 | }, 8 | 9 | intersection: function(setA, setB) { 10 | return new Set( 11 | [...setA].filter(element => setB.has(element)) 12 | ); 13 | }, 14 | 15 | union: function(setA, setB) { 16 | return new Set( 17 | [...setA, ...setB] 18 | ); 19 | }, 20 | 21 | bad_words_set: function(obj) { 22 | let badWordsSet = new Set(); 23 | for (var prop in obj) { 24 | if (Object.prototype.hasOwnProperty.call(obj, prop)) { 25 | let word = prop; 26 | let is_bad = obj[word]; 27 | if (is_bad) { 28 | badWordsSet.add(word); 29 | } 30 | } 31 | } 32 | 33 | return badWordsSet; 34 | }, 35 | 36 | setSolution: function(solution) { 37 | TolokaHandlebarsTask.prototype.setSolution.apply(this, arguments); 38 | var workspaceOptions = this.getWorkspaceOptions(); 39 | 40 | var tname = solution.output_values['topic_name'] || ""; 41 | this.setSolutionOutputValue("topic_name", tname); 42 | 43 | if (this.rendered) { 44 | if (!workspaceOptions.isReviewMode && !workspaceOptions.isReadOnly) { 45 | // Show a set of checkboxes if the answer "There are violations" (BAD) is selected. Otherwise, hide it 46 | if (solution.output_values['quality']) { 47 | 48 | var row = this.getDOMElement().querySelector('.second_scale'); 49 | row.style.display = this.is_good_topic(solution) ? 'block' : 'none'; 50 | 51 | if (!this.is_good_topic(solution)) { 52 | let data = this.getTemplateData(); 53 | let words_out = {}; 54 | for (let i = 0; i < data.words.length; i++) { 55 | words_out[data.words[i].name] = false; 56 | } 57 | 58 | this.setSolutionOutputValue("bad_words", words_out); 59 | 60 | this.setSolutionOutputValue("topic_name", ""); 61 | 62 | } 63 | } 64 | } 65 | } 66 | }, 67 | 68 | getTemplateData: function() { 69 | let data = TolokaHandlebarsTask.prototype.getTemplateData.call(this); 70 | 71 | const words = data.wordset.split(" "); 72 | let word_outs = []; 73 | for (let i = 0; i < words.length; i++) { 74 | word_outs.push({'name': words[i], 'title': words[i]}); 75 | } 76 | 77 | data.words = word_outs; 78 | 79 | return data; 80 | }, 81 | 82 | // Error message processing 83 | addError: function(message, field, errors) { 84 | errors || (errors = { 85 | task_id: this.getOptions().task.id, 86 | errors: {} 87 | }); 88 | errors.errors[field] = { 89 | message: message 90 | }; 91 | 92 | return errors; 93 | }, 94 | 95 | // Checking the answers: if the answer "There are violations" is selected, at least one checkbox must be checked 96 | validate: function(solution) { 97 | var errors = null; 98 | var topic_name = solution.output_values.topic_name; 99 | topic_name = typeof topic_name !== 'undefined' ? 
topic_name.trim() : ""; 100 | let bad_topic_name = topic_name.length < 3 || topic_name.length > 50 101 | 102 | if (this.is_good_topic(solution) && bad_topic_name) { 103 | errors = this.addError("Topic name is less than 3 symbols or more than 50", '__TASK__', errors); 104 | } 105 | 106 | var correctBadWords = this.getTask().input_values.correct_bad_words; 107 | var golden; 108 | if (!correctBadWords) { 109 | golden = false; 110 | } else { 111 | var badWords = solution.output_values.bad_words; 112 | 113 | let correctBadWordsSet = this.bad_words_set(correctBadWords); 114 | let badWordsSet = this.bad_words_set(badWords); 115 | 116 | var intersection = this.intersection(correctBadWordsSet, badWordsSet) ; 117 | var union = this.union(correctBadWordsSet, badWordsSet); 118 | var golden = intersection.size / union.size >= 0.8 ? true : false; 119 | } 120 | this.setSolutionOutputValue("golden_bad_words", golden); 121 | 122 | var goldenBinaryQuality = this.is_good_topic(solution); 123 | this.setSolutionOutputValue("golden_binary_quality", goldenBinaryQuality); 124 | 125 | return errors || TolokaHandlebarsTask.prototype.validate.apply(this, arguments); 126 | }, 127 | 128 | // Open the second question block in verification mode to see the checkboxes marked by the performer 129 | onRender: function() { 130 | var workspaceOptions = this.getWorkspaceOptions(); 131 | 132 | if (workspaceOptions.isReviewMode || workspaceOptions.isReadOnly || this.is_good_topic(this.getSolution())){ 133 | var row = this.getDOMElement().querySelector('.second_scale'); 134 | row.style.display = 'block'; 135 | } 136 | 137 | this.rendered = true; 138 | } 139 | }); 140 | 141 | function extend(ParentClass, constructorFunction, prototypeHash) { 142 | constructorFunction = constructorFunction || function() { 143 | }; 144 | prototypeHash = prototypeHash || {}; 145 | if (ParentClass) { 146 | constructorFunction.prototype = Object.create(ParentClass.prototype); 147 | } 148 | for (var i in prototypeHash) { 149 | constructorFunction.prototype[i] = prototypeHash[i]; 150 | } 151 | return constructorFunction; 152 | } 153 | --------------------------------------------------------------------------------