├── xai
│   ├── core
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── dataloader
│   │   │   │   ├── __init__.py
│   │   │   │   ├── DataLoaderText.py
│   │   │   │   ├── DataLoaderImage.py
│   │   │   │   └── DataLoaderAbstract.py
│   │   │   ├── datawriter
│   │   │   │   ├── __init__.py
│   │   │   │   └── ResultWriter.py
│   │   │   ├── DataLoaderFactory.py
│   │   │   └── DataManager.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   └── ModelLoader.py
│   │   ├── algorithm
│   │   │   ├── __init__.py
│   │   │   ├── AlgAbstract.py
│   │   │   └── Lime.py
│   │   └── XAIProcessor.py
│   ├── info
│   │   ├── __init__.py
│   │   ├── DatasetInfo.py
│   │   ├── FieldInfo.py
│   │   └── XAIJobInfo.py
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── Common.py
│   │   └── Constants.py
│   └── AutoAPEXAI.py
├── .gitattributes
├── .gitignore
├── README.md
├── requirements.txt
├── .editorconfig
├── xai.sh
├── pyproject.toml
├── Dockerfile
├── setup.py
├── conf
│   └── xai-conf.xml
└── LICENSE

/xai/core/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/xai/info/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/xai/core/data/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/xai/core/model/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/xai/core/algorithm/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/xai/core/data/dataloader/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/xai/core/data/datawriter/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
* text=auto eol=lf
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
/.idea/
target/
/xai/tests/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AutoAPE(Advanced Perceptron Engine)-xai
AutoAPE(Advanced Perceptron Engine) - XAI(eXplainable AI)
--------------------------------------------------------------------------------
/xai/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jin.kim@seculayer.com
# Powered by Seculayer © 2021 AI Service Model Team, R&D Center.
--------------------------------------------------------------------------------
/xai/common/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jin.kim@seculayer.com
# Powered by Seculayer © 2021 AI Service Model Team, R&D Center.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.20
matplotlib==3.1.1
scikit-image==0.19.2
joblib==1.2.0
tensorflow==2.7.0
scikit-learn==0.23.1
lime==0.2.0.1
torch==1.9.1
gensim==3.7.3
xgboost==1.3.3
lightgbm==3.3.3
pdpbox==0.2.1
seaborn==0.12.2
--------------------------------------------------------------------------------
/xai/common/Common.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jin.kim@seculayer.com
# Powered by Seculayer © 2021 Service Model Team, R&D Center.

import json

# ---- automl packages
from pycmmn.Singleton import Singleton
from pycmmn.logger.MPLogger import MPLogger
from pycmmn.utils.FileUtils import FileUtils
from xai.common.Constants import Constants


# class : Common
class Common(metaclass=Singleton):
    # make directories
    FileUtils.mkdir(Constants.DIR_DATA_ROOT)
    FileUtils.mkdir(Constants.DIR_LOG)

    # LOGGER
    LOGGER: MPLogger = MPLogger(log_dir=Constants.DIR_LOG, log_level=Constants.LOG_LEVEL,
                                log_name=Constants.LOG_NAME)
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
root = true

[*]
charset = utf-8
indent_style = space
trim_trailing_whitespace = true
insert_final_newline = true
end_of_line = lf

[*.py]
charset = utf-8
end_of_line = lf
indent_style = space
indent_size = 4
trim_trailing_whitespace = true
insert_final_newline = true

[{Dockerfile,*.dockerfile,Dockerfile.*}]
charset = utf-8
end_of_line = lf
indent_style = space
indent_size = 4
trim_trailing_whitespace = true
insert_final_newline = true

[pom.xml]
charset = utf-8
end_of_line = lf
indent_style = space
indent_size = 2
trim_trailing_whitespace = true
insert_final_newline = true

[{*.yml,*.yaml}]
charset = utf-8
end_of_line = lf
indent_style = space
indent_size = 2
trim_trailing_whitespace = true
insert_final_newline = true
--------------------------------------------------------------------------------
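Common (above) is the process-wide service locator: the Singleton metaclass from pycmmn guarantees a single instance, and merely importing it triggers the directory creation and logger construction as side effects. A minimal usage sketch, assuming pycmmn is importable as declared in pyproject.toml further below:

from xai.common.Common import Common

# every module that does this receives the same MPLogger instance
LOGGER = Common.LOGGER.getLogger()
LOGGER.info("shared logger obtained")
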
/xai/AutoAPEXAI.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jin.kim@seculayer.com
# Powered by Seculayer © 2021-2022 AI Service Model Team, R&D Center.

from xai.common.Common import Common
from xai.common.Constants import Constants
from xai.core.XAIProcessor import XAIProcessor


class AutoAPEXAI(object):
    LOGGER = Common.LOGGER.getLogger()

    def __init__(self, key, task_idx):
        self.key: str = key
        self.task_idx: str = task_idx

        self.LOGGER.info(Constants.VERSION_MANAGER.print_version())
        self.processor = XAIProcessor(key, task_idx, Constants.JOB_TYPE)

    def run(self) -> None:
        self.processor.run()


if __name__ == '__main__':
    import sys
    _key = sys.argv[1]
    _task_idx = sys.argv[2]

    xai = AutoAPEXAI(_key, _task_idx)
    xai.run()
--------------------------------------------------------------------------------
/xai.sh:
--------------------------------------------------------------------------------
#!/bin/bash
######################################################################################
# eyeCloudAI 3.1 XAI Run Script
# Author : Jin Kim
# e-mail : jinkim@seculayer.com
# Powered by Seculayer © 2021 Service Model Team, R&D Center.
######################################################################################

APP_PATH=/eyeCloudAI/app/ape

export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
export CUDA_HOME=/usr/local/cuda

if [ -x "${APP_PATH}/xai/.venv/bin/python3" ]
then
    PYTHON_BIN="$APP_PATH/xai/.venv/bin/python3"
else
    PYTHON_BIN="$(command -v python3.7)"
    export PYTHONPATH=$PYTHONPATH:$APP_PATH/xai/lib:$APP_PATH/xai
    export PYTHONPATH=$PYTHONPATH:$APP_PATH/pycmmn/lib:$APP_PATH/pycmmn
    export PYTHONPATH=$PYTHONPATH:$APP_PATH/dataconverter/lib:$APP_PATH/dataconverter
fi

KEY=${1}
WORKER_IDX=${2}

$PYTHON_BIN -m xai.AutoAPEXAI ${KEY} ${WORKER_IDX}
--------------------------------------------------------------------------------
/xai/core/data/DataLoaderFactory.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jinkim@seculayer.co.kr
# Powered by Seculayer © 2021 Service Model Team

from xai.common.Common import Common
from xai.common.Constants import Constants
from pycmmn.sftp.SFTPClientManager import SFTPClientManager
from xai.info.XAIJobInfo import XAIJobInfo
from xai.core.data.dataloader.DataLoaderImage import DataLoaderImage
from xai.core.data.dataloader.DataLoaderText import DataLoaderText


class DataloaderFactory(object):
    LOGGER = Common.LOGGER.getLogger()

    @staticmethod
    def create(dataset_format: str, job_info: XAIJobInfo, sftp_client: SFTPClientManager):
        # map the dataset-format code to a concrete loader class
        # (table data reuses the text loader)
        case = {
            Constants.DATASET_FORMAT_TEXT: DataLoaderText,
            Constants.DATASET_FORMAT_IMAGE: DataLoaderImage,
            Constants.DATASET_FORMAT_TABLE: DataLoaderText
        }.get(dataset_format)
        return case(job_info, sftp_client)


if __name__ == '__main__':
    pass
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[tool.poetry]
name = "xai"
version = "1.0.0"
description = ""
authors = ["Your Name <you@example.com>"]

[tool.poetry.dependencies]
python = "^3.7, <3.11"
lime = "^0.2.0"
scikit-image = "^0.19.2"
matplotlib = "^3.5.2"
pycmmn = { git = "https://ssdlc-bitbucket.seculayer.com:8443/scm/slaism/autoape-pycmmn.git", rev = "main" }
dataconverter = { git = "https://ssdlc-bitbucket.seculayer.com:8443/scm/slaism/autoape-dataconverter.git", branch = "main" }
tensorflow = "^2.7"
joblib = "^1.1.0"

[tool.poetry.dev-dependencies]
black = "^22"
isort = "^5.10.1"
pytest = "^7.1.1"
mypy = "^0.942"
hypothesis = "^6.43.3"
pytest-xdist = { extras = ["psutil"], version = "^2.5.0" }
pytest-cov = "^3.0.0"
prospector = { extras = [
    "with_mypy",
    "with_vulture",
    "with_bandit",
], version = "^1.7.7" }
coverage = "^6.3.3"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.isort]
profile = "black"

[tool.pytest.ini_options]
minversion = "7.0"
addopts = "-ra -q --failed-first -n auto"
testpaths = ["tests"]

[tool.pylint.messages_control]
disable = "C0330, C0326"

[tool.pylint.format]
max-line-length = "88"
--------------------------------------------------------------------------------
/xai/core/data/dataloader/DataLoaderText.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jinkim@seculayer.com
# Powered by Seculayer © 2021 Service Model Team, R&D Center.

from xai.common.Constants import Constants
from xai.core.data.dataloader.DataLoaderAbstract import DataLoaderAbstract


class DataLoaderText(DataLoaderAbstract):

    def __init__(self, job_info, sftp_client):
        super().__init__(job_info, sftp_client)

    def read(self, file_list, fields):
        features = list()
        labels = list()
        origin_data = list()

        for file in file_list:
            self.LOGGER.info("read text file : {}".format(file))
            generator = self.sftp_client.load_json_oneline(
                filename=file,
                dataset_format=Constants.DATASET_FORMAT_TEXT
            )
            while True:
                line: str = next(generator)
                if line == "#file_end#":
                    break
                feature, label, data = self._convert(line, fields, self.functions)

                features.append(feature), labels.append(label), origin_data.append(data)

        self.is_exception = False

        self.make_inout_units(features, fields)
        return [features, labels, origin_data]
--------------------------------------------------------------------------------
/xai/core/data/dataloader/DataLoaderImage.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jinkim@seculayer.com
# Powered by Seculayer © 2021 Service Model Team, R&D Center.

from typing import List

from xai.common.Constants import Constants
from xai.core.data.dataloader.DataLoaderAbstract import DataLoaderAbstract


class DataLoaderImage(DataLoaderAbstract):

    def __init__(self, job_info, sftp_client):
        super().__init__(job_info, sftp_client)

    def read(self, file_list, fields):
        features = list()
        labels = list()
        origin_data = list()

        for file in file_list:
            self.LOGGER.info("read image file : {}".format(file))
            generator = self.sftp_client.load_json_oneline(
                filename=file,
                dataset_format=Constants.DATASET_FORMAT_IMAGE
            )
            while True:
                line = next(generator)
                if line == "#file_end#":
                    break
                feature, label, data = self._convert(line, fields, self.functions)
                features.append(feature), labels.append(label), origin_data.append(data)

        self.is_exception = False

        self.make_inout_units(features, fields)
        return [features, labels, origin_data]
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# syntax=docker/dockerfile:1.3
FROM seculayer/python:3.7 AS builder
LABEL maintainer="jinkim jinkim@seculayer.com"

ARG APP_DIR="/opt/app"
ARG POETRY_VERSION=1.1.13

ENV POETRY_VIRTUALENVS_IN_PROJECT=1 \
    PATH="/root/.local/bin:$PATH"

RUN --mount=type=cache,target=/root/.cache/pip \
    pip install pipx
RUN pipx ensurepath
RUN pipx install "poetry==$POETRY_VERSION"

WORKDIR ${APP_DIR}

COPY pyproject.toml poetry.lock ${APP_DIR}/

RUN --mount=type=secret,id=gitconfig,target=/root/.gitconfig,required=true \
    --mount=type=secret,id=cert,required=true \
    # --mount=type=cache,target=/root/.cache/pypoetry/cache \
    # --mount=type=cache,target=/root/.cache/pypoetry/artifacts \
    poetry install --no-dev --no-root --no-interaction --no-ansi


FROM seculayer/python:3.7 AS app
ARG APP_DIR="/opt/app"
ARG CLOUD_AI_DIR="/eyeCloudAI/app/ape/xai/"
ENV CLOUD_AI_DIR ${CLOUD_AI_DIR}
ENV LANG=en_US.UTF-8 LANGUAGE=en_US:en LC_ALL=en_US.UTF-8

RUN mkdir -p ${CLOUD_AI_DIR}
WORKDIR ${CLOUD_AI_DIR}

RUN groupadd -g 1000 aiuser
RUN useradd -r -u 1000 -g aiuser aiuser
RUN chown -R aiuser:aiuser /eyeCloudAI
USER aiuser

COPY --chown=aiuser:aiuser --from=builder ${APP_DIR}/.venv ${CLOUD_AI_DIR}/.venv
COPY --chown=aiuser:aiuser xai ${CLOUD_AI_DIR}/xai
COPY --chown=aiuser:aiuser xai.sh ${CLOUD_AI_DIR}
RUN chmod +x ${CLOUD_AI_DIR}/xai.sh

ENV PATH="${CLOUD_AI_DIR}/.venv/bin:$PATH"

CMD []
--------------------------------------------------------------------------------
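DatasetInfo (next file) consumes the "datasets" block of the job file. The keys below are taken from its accessors; the nesting and the values are illustrative placeholders only:

from xai.info.DatasetInfo import DatasetInfo

dataset_dict = {
    "dist_file_cnt": "1",
    "label_yn": "Y",
    "fields": [
        {"field_sn": "0", "name": "label", "field_type": "string", "statistic": {}}
    ],
    "metadata_json": {
        "meta": [{}],
        "file_list": ["/storage/part-00000.done"]  # illustrative path
    },
}
info = DatasetInfo(dataset_dict, target_field="label")
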
/xai/info/DatasetInfo.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jinkim@seculayer.com
# Powered by Seculayer © 2021 Service Model Team, R&D Center.

from typing import List, Dict

from xai.info.FieldInfo import FieldInfo


class DatasetInfo(object):
    def __init__(self, dataset_dict: dict, target_field: str):
        self.total_dist_file_cnt: int = int(dataset_dict.get("dist_file_cnt", "1"))
        self.metadata: List[Dict] = dataset_dict.get("metadata_json", {}).get("meta", [])
        self.data_analysis_json: List[Dict] = dataset_dict.get("fields", [])
        self.fields: List[FieldInfo] = self.set_fields(
            self.data_analysis_json,
            self.metadata,
            target_field
        )
        self.label_yn: str = dataset_dict.get("label_yn", "N")
        self.file_list: List[str] = dataset_dict.get("metadata_json", {}).get("file_list", [])

    @staticmethod
    def set_fields(data_analysis_json, metadata, target_field):
        fields = list()
        for field_dict in data_analysis_json:
            field_sn: int = int(field_dict.get("field_sn"))
            try:
                meta_dict = metadata[field_sn]
            except IndexError:
                meta_dict = {}
            field = FieldInfo(field_dict, meta_dict, target_field)
            fields.append(field)

        return fields

    def get_fields(self) -> List[FieldInfo]:
        return self.fields
--------------------------------------------------------------------------------
/xai/core/algorithm/AlgAbstract.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Manki Baek
# e-mail : manki.baek@seculayer.com
# Powered by Seculayer © 2021 Service Model Team, R&D Center.

from typing import Dict, List
import numpy as np

from xai.common.Common import Common
from xai.common.Constants import Constants
from xai.info.XAIJobInfo import XAIJobInfo
from dataconverter.core.ConvertAbstract import ConvertAbstract
from xai.core.data.dataloader.DataLoaderAbstract import DataLoaderAbstract


class AlgAbstract(object):
    LOGGER = Common.LOGGER.getLogger()

    def __init__(self, model, job_info: XAIJobInfo):
        self.model = model
        self.job_info = job_info
        self.functions: List[List[ConvertAbstract]] = DataLoaderAbstract.build_functions(
            fields=self.job_info.get_dataset_info().get_fields()
        )
        self.fields = self.job_info.get_dataset_info().get_fields()

    def run(self, data: Dict, json_data: Dict):
        raise NotImplementedError

    def model_inference(self, x: list):
        start = 0
        results = None
        len_x = len(x)

        while start < len_x:
            end = start + Constants.BATCH_SIZE
            batch_x = x[start: end]
            try:
                # case tensorflow : predict() returns class probabilities
                tmp_results = self.model.predict(batch_x)
                tmp_results = np.argmax(tmp_results, axis=1)
            except Exception:
                # case sklearn : predict() already returns class labels
                tmp_results = self.model.predict(np.array(batch_x))
            if start == 0:
                results = tmp_results
            else:
                results = np.concatenate((results, tmp_results), axis=0)

            start += Constants.BATCH_SIZE

        return results
--------------------------------------------------------------------------------
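model_inference (above) dispatches between Keras-style and scikit-learn-style models by catching the argmax failure. An equivalent sketch that keys on output dimensionality instead of exceptions — illustrative only, not the shipped implementation:

import numpy as np

def predict_in_batches(model, x, batch_size=512):
    out = []
    for start in range(0, len(x), batch_size):
        pred = model.predict(np.array(x[start:start + batch_size]))
        if getattr(pred, "ndim", 1) > 1:
            # probability matrix (e.g. Keras) -> class indices
            pred = np.argmax(pred, axis=1)
        out.append(pred)
    return np.concatenate(out, axis=0)
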
/setup.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jin.kim@seculayer.com
# Powered by Seculayer © 2021 AI Service Model Team, R&D Center.

# ----------------------------------------------------------------------------------------------
# AutoML - XAI(Explainable Artificial Intelligence) Setup Script
# ----------------------------------------------------------------------------------------------

from typing import List

from setuptools import setup, find_packages


class APEPythonSetup(object):
    def __init__(self):
        self.module_nm = "xai"
        self.version = "1.0.0"

    @staticmethod
    def get_require_packages() -> List[str]:
        with open("./requirements.txt", "r") as f:
            require_packages = f.readlines()
        return require_packages

    @staticmethod
    def get_packages() -> List[str]:
        return find_packages(
            exclude=[
                "build", "tests", "scripts", "dists"
            ],
        )

    def setup(self) -> None:
        setup(
            name=self.module_nm,
            version=self.version,
            description="SecuLayer Inc. AutoML Project \n"
                        "Module : XAI(Explainable Artificial Intelligence)",
            author="Jin Kim",
            author_email="jin.kim@seculayer.com",
            packages=self.get_packages(),
            package_dir={
                "conf": "conf",
                "resources": "resources"
            },
            python_requires='>=3.7',
            package_data={
                # self.module_nm: FILE_LIST
            },
            install_requires=self.get_require_packages(),
            zip_safe=False,
        )


if __name__ == '__main__':
    print(" __ ______ ____ _____")
    print(" / |/ / | / __ \/ ___/")
    print(" / /|_/ / /| | / /_/ /\__ \ ")
    print(" / / / / ___ |/ _, _/___/ / ")
    print("/_/ /_/_/ |_/_/ |_|/____/ ")
    print(" ")
    APEPythonSetup().setup()
--------------------------------------------------------------------------------
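FieldInfo (next file) pulls per-field conversion specs out of a single string with the regex _REGEX_FN_STR; the spec format is a run of [[@fn(args)]] tokens. A quick demonstration — the function names here are made up:

import re

pattern = re.compile(r"(\[\[@[\w\d_]+\([^\]]*\)\]\])")
print(pattern.findall("[[@min_max(0,1)]][[@one_hot_encode()]]"))
# -> ['[[@min_max(0,1)]]', '[[@one_hot_encode()]]']
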
/xai/info/FieldInfo.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jinkim@seculayer.com
# Powered by Seculayer © 2021 Service Model Team, R&D Center.

import re
from typing import List, Dict

from pycmmn.utils.StringUtil import StringUtil
from dataconverter.core.ConvertFunctionInfo import ConvertFunctionInfo, ConvertFunctionInfoBuilder


class FieldInfo(object):
    def __init__(self, field_dict: dict, metadata_dict: dict, target_field: str):
        self.field_sn = StringUtil.get_int(field_dict.get("field_sn", 0))
        self.field_name = field_dict.get("name", "")
        self.target_field = target_field
        self.stat_dict = field_dict.get("statistic", dict())
        self.field_type = field_dict.get("field_type")

        if self.target_field == self.field_name:
            self.is_label = True
        else:
            self.is_label = False
        self.is_multiple = True if len(self.field_name.split("@COMMA@")) >= 2 else False
        self.function: List[ConvertFunctionInfo] = self._create_functions(field_dict.get("functions", ""))

    def __str__(self) -> str:
        return "name : {}".format(self.field_name)

    def label(self) -> bool:
        return self.is_label

    def multiple(self) -> bool:
        return self.is_multiple

    # --- static variables
    _REGEX_FN_STR = "(\\[\\[@[\\w\\d_]+\\([^\\]]*\\)\\]\\])"
    _PATTERN_REGEX_FN_STR = re.compile(_REGEX_FN_STR)

    @classmethod
    def _get_function_str_list(cls, functions) -> List[str]:
        return cls._PATTERN_REGEX_FN_STR.findall(functions)

    def _create_functions(self, full_fn_str: str) -> List[ConvertFunctionInfo]:
        functions: List[ConvertFunctionInfo] = list()
        for fn_str in self._get_function_str_list(full_fn_str):
            fn_info = ConvertFunctionInfoBuilder() \
                .set_fn_str(fn_str) \
                .set_stat_dict(self.stat_dict) \
                .build()
            functions.append(fn_info)
        return functions

    def get_function(self) -> List[ConvertFunctionInfo]:
        return self.function

    def get_field_name(self) -> str:
        return self.field_name

    def get_statistic(self) -> Dict:
        return self.stat_dict
--------------------------------------------------------------------------------
/xai/core/data/datawriter/ResultWriter.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jinkim@seculayer.co.kr
# Powered by Seculayer © 2017 AI-TF Team
######################################################################################
import os
import json
import socket

from pycmmn.utils.Utils import Utils
from pycmmn.utils.FileUtils import FileUtils
######################################################################################


# class : ResultWriter
class ResultWriter(object):
    @staticmethod
    def result_file_write(**kwargs):
        result_path = kwargs["result_path"]
        results = kwargs["results"]
        result_type = kwargs["result_type"]
        host_name = socket.gethostname()

        file_name = f"{result_path}/{result_type}_{Utils.get_current_time_with_mili_sec()}_{host_name}"

        len_results = len(results)
        # Common.LOGGER.get_logger().info("[{}] : {} rows".format(file_name, len_results))
        if len_results == 0:
            return

        # write json
        start = 0
        batch_size = 20000
        idx = 0
        while start < len_results:
            if start + batch_size < len_results:
                batch_result = results[start: start + batch_size]
            else:
                batch_result = results[start:]

            # Common.LOGGER.get_logger().info("[{}_{}] : {} rows".format(file_name, idx, len(batch_result)))
            f = FileUtils.file_pointer("{}_{}.tmp".format(file_name, idx), "w")
            line_idx = 0
            for line in batch_result:
                # json.dump(line, codecs.getwriter("utf-8")(f) , ensure_ascii=False)
                json.dump(line, f, ensure_ascii=False)
                f.write("\n")

                if line_idx % 5000 == 0:
                    f.flush()
                line_idx += 1

            f.close()
            # rename
            # os.rename("{}_{}.tmp".format(file_name, idx), "{}_{}.{}".format(file_name, idx, ext))
            os.rename("{}_{}.tmp".format(file_name, idx), "{}_{}.{}".format(file_name, idx, "done"))
            # Common.LOGGER.get_logger().info("{} result file write complete - {}_{}.{}".format(ext, file_name, idx, "done"))
            start += batch_size
            idx += 1
--------------------------------------------------------------------------------
/conf/xai-conf.xml:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<configuration>
    <property>
        <name>dir_data_root</name>
        <value>/eyeCloudAI/data</value>
        <description>Data root directory</description>
    </property>
    <property>
        <name>ape.features.dir</name>
        <value>/processing/ape/division</value>
    </property>
    <property>
        <name>ape_web_file_dir</name>
        <value>/DATA/eyeCloudAI/app/www/store/upload</value>
    </property>
    <property>
        <name>dir_log</name>
        <value>/eyeCloudAI/logs</value>
        <description>None</description>
    </property>
    <property>
        <name>log_name</name>
        <value>DataProcessRecommender</value>
        <description>None</description>
    </property>
    <property>
        <name>log_level</name>
        <value>INFO</value>
        <description>[INFO, DEBUG, WARN, ERROR, CRITICAL]</description>
    </property>
    <property>
        <name>lime_image_sample_cnt</name>
        <value>50</value>
        <description>lime image explainer parameter</description>
    </property>
    <property>
        <name>lime_tabular_sample_cnt</name>
        <value>1000</value>
        <description>lime tabular explainer parameter</description>
    </property>
    <property>
        <name>lime_text_sample_cnt</name>
        <value>1000</value>
        <description>lime text explainer parameter</description>
    </property>
    <property>
        <name>storage_svc</name>
        <value>10.1.35.230</value>
    </property>
    <property>
        <name>storage_sftp_port</name>
        <value>30122</value>
    </property>
    <property>
        <name>storage_username</name>
        <value>HE12RmzKHQtH3bL7tTRqCg==</value>
    </property>
    <property>
        <name>storage_password</name>
        <value>jTf6XrqcYX1SAhv9JUPq+w==</value>
    </property>
    <property>
        <name>mrms_svc</name>
        <value>mrms-svc</value>
    </property>
    <property>
        <name>mrms_sftp_port</name>
        <value>10022</value>
    </property>
    <property>
        <name>mrms_rest_port</name>
        <value>9200</value>
    </property>
    <property>
        <name>mrms_username</name>
        <value>HE12RmzKHQtH3bL7tTRqCg==</value>
    </property>
    <property>
        <name>mrms_password</name>
        <value>jTf6XrqcYX1SAhv9JUPq+w==</value>
    </property>
</configuration>
--------------------------------------------------------------------------------
/xai/core/model/ModelLoader.py:
--------------------------------------------------------------------------------
from typing import Callable
import joblib
import tensorflow as tf
from xgboost import XGBClassifier
import pickle

from xai.common.Common import Common
from xai.common.Constants import Constants
from pycmmn.sftp.SFTPClientManager import SFTPClientManager
from pycmmn.utils.FileUtils import FileUtils


class ModelLoader(object):
    LOGGER = Common.LOGGER.getLogger()
    MRMS_SFTP_MANAGER: SFTPClientManager = SFTPClientManager(
        "{}:{}".format(Constants.MRMS_SVC, Constants.MRMS_SFTP_PORT),
        Constants.MRMS_USER, Constants.MRMS_PASSWD, LOGGER
    )

    @classmethod
    def load(cls, lib_type, model_id):
        case_fn: Callable = {
            Constants.LIB_TYPE_TF: ModelLoader._get_tf_model,
            Constants.LIB_TYPE_SKL: ModelLoader._get_skl_model,
            Constants.LIB_TYPE_XGB: ModelLoader._get_xgb_model,
            Constants.LIB_TYPE_LGBM: ModelLoader._get_lgbm_model
        }.get(lib_type)

        ModelLoader._scp_model_from_storage(model_id)
        dir_model = '{}/{}/0'.format(
            Constants.DIR_TEMP, model_id
        )
        if FileUtils.is_exist(dir_model):
            try:
                cls.LOGGER.info("model load ....")
                cls.LOGGER.info("model dir : {}".format(dir_model))

                return case_fn(dir_model)
            except Exception as e:
                cls.LOGGER.error(e, exc_info=True)
        else:
            cls.LOGGER.warn("MODEL FILE IS NOT EXIST : [{}]".format(dir_model))

    @classmethod
    def _get_xgb_model(cls, dir_model):
        model = XGBClassifier()
        try:
            model.load_model(dir_model + "/model.h5")
            return model
        except Exception as e:
            cls.LOGGER.error(e, exc_info=True)
            raise e

    @classmethod
    def _get_lgbm_model(cls, dir_model):
        try:
            with open(dir_model + "/apeflow.model", "rb") as f:
                return pickle.load(f)
        except Exception as e:
            cls.LOGGER.error(e, exc_info=True)
            raise e

    @classmethod
    def _get_tf_model(cls, dir_model):
        try:
            return tf.keras.models.load_model(dir_model)
        except Exception as e:
            cls.LOGGER.error(e, exc_info=True)
            raise e

    @classmethod
    def _get_skl_model(cls, dir_model):
        try:
            return joblib.load("{}/skl_model.joblib".format(dir_model))
        except Exception as e:
            cls.LOGGER.error(e, exc_info=True)
            raise e

    @classmethod
    def _scp_model_from_storage(cls, model_id) -> None:
        remote_path = f"{Constants.DIR_STORAGE}/{model_id}"
        try:
            cls.MRMS_SFTP_MANAGER.scp_from_storage(
                remote_path, Constants.DIR_TEMP
            )
        except Exception as e:
            cls.LOGGER.error(e, exc_info=True)
--------------------------------------------------------------------------------
/xai/core/data/DataManager.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jinkim@seculayer.com
# Powered by Seculayer © 2021 Service Model Team, R&D Center.

from multiprocessing import Queue
from typing import List

from pycmmn.Singleton import Singleton
from xai.info.XAIJobInfo import XAIJobInfo
from xai.info.FieldInfo import FieldInfo

from xai.common.Common import Common
from pycmmn.decorator.CalTimeDecorator import CalTimeDecorator
from xai.info.DatasetInfo import DatasetInfo
from pycmmn.sftp.SFTPClientManager import SFTPClientManager
from xai.core.data.DataLoaderFactory import DataloaderFactory


class DataManager(object, metaclass=Singleton):
    LOGGER = Common.LOGGER.getLogger()

    def __init__(self, job_info: XAIJobInfo, sftp_client: SFTPClientManager) -> None:
        # threading.Thread.__init__(self)
        self.job_info: XAIJobInfo = job_info
        self.data_queue: Queue = Queue()
        self.sftp_client = sftp_client

        self.dataset_info: DatasetInfo = self.job_info.get_dataset_info()
        self.dataset = {}

    @CalTimeDecorator("Data Manager", LOGGER)
    def run(self) -> None:
        try:
            self.LOGGER.info("DataManager Start.")

            # ---- data load
            self.dataset = self.read_files(self.dataset_info.get_fields())

            self.LOGGER.info("DataManager End.")
        except Exception as e:
            self.LOGGER.error(e, exc_info=True)

            raise e

    def read_files(self, fields: List[FieldInfo]) -> List:
        # ---- prepare
        # if the job is distributed, each worker reads one file;
        # otherwise a single worker reads every file
        file_list = list()
        if self.job_info.get_dist_yn() \
                and (len(self.job_info.get_file_list()) == self.job_info.get_num_worker()):
            idx = int(self.job_info.get_task_idx())
            file_list.append(self.job_info.get_file_list()[idx])
        else:
            file_list = self.job_info.get_file_list()

        # data_list = self.read_subproc(file_list, fields)
        data_list = DataloaderFactory.create(
            dataset_format=self.job_info.get_dataset_format(),
            job_info=self.job_info,
            sftp_client=self.sftp_client
        ).read(file_list, fields)

        return data_list

    def get_learn_data(self) -> dict:
        return {"x": self.dataset[0][0], "y": self.dataset[0][1]}

    def get_eval_data(self) -> dict:
        return {"x": self.dataset[1][0], "y": self.dataset[1][1]}

    def get_json_data(self) -> list:
        return self.dataset[2]

    def get_inference_data(self) -> dict:
        return {"x": self.dataset[0], "y": self.dataset[1]}


# ---- builder Pattern
class DataManagerBuilder(object):
    def __init__(self):
        self.job_info = None
        self.sftp_client = None

    def set_job_info(self, job_info):
        self.job_info = job_info
        return self

    def set_sftp_client(self, sftp_client):
        self.sftp_client = sftp_client
        return self

    def build(self) -> DataManager:
        return DataManager(job_info=self.job_info, sftp_client=self.sftp_client)
--------------------------------------------------------------------------------
/xai/common/Constants.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jin.kim@seculayer.com
# Powered by Seculayer © 2021 Service Model Team, R&D Center.

from pycmmn.Singleton import Singleton
from pycmmn.utils.ConfUtils import ConfUtils
from pycmmn.utils.FileUtils import FileUtils
from pycmmn.tools.VersionManagement import VersionManagement

import os


# class : Constants
class Constants(metaclass=Singleton):
    _working_dir = os.getcwd()
    _data_cvt_dir = _working_dir + "/../xai"
    _conf_xml_filename = _data_cvt_dir + "/conf/xai-conf.xml"

    _MODE = "deploy"

    if not FileUtils.is_exist(_conf_xml_filename):
        _MODE = "dev"

        if _working_dir != "/eyeCloudAI/app/ape/xai":
            os.chdir(FileUtils.get_realpath(__file__) + "/../../")

        _working_dir = os.getcwd()
        _data_cvt_dir = _working_dir + "/../xai"
        _conf_xml_filename = _working_dir + "/conf/xai-conf.xml"

    # load config xml file
    _CONFIG = ConfUtils.load(filename=os.getcwd() + "/conf/xai-conf.xml")

    try:
        VERSION_MANAGER = VersionManagement(app_path=_working_dir)
    except Exception as e:
        # DEFAULT
        VersionManagement.generate(
            version="1.0.0",
            app_path=_working_dir,
            module_nm="xai",
        )
        VERSION_MANAGER = VersionManagement(app_path=_working_dir)
    VERSION = VERSION_MANAGER.VERSION
    MODULE_NM = VERSION_MANAGER.MODULE_NM

    # Directories
    DIR_DATA_ROOT = _CONFIG.get("dir_data_root", "/eyeCloudAI/data")
    DIR_PROCESSING = DIR_DATA_ROOT + _CONFIG.get("dir_processing", "/processing/ape")
    DIR_JOB = DIR_PROCESSING + _CONFIG.get("dir_job", "/jobs")
    DIR_STORAGE = DIR_DATA_ROOT + _CONFIG.get("dir_storage", "/storage/ape")
    DIR_TEMP = DIR_DATA_ROOT + "/processing/ape/temp"
    DIR_RESOURCES = (
        FileUtils.get_realpath(file=__file__) + "/resources"
    )
    DIR_RESULT = DIR_PROCESSING + _CONFIG.get("dir_result", "/results_xai")
    DIR_WEB_FILE = _CONFIG.get("ape_web_file_dir", "/eyeCloudAI/app/www/store/upload")

    # Logs
    DIR_LOG = _CONFIG.get("dir_log", "./logs")
    LOG_LEVEL = _CONFIG.get("log_level", "INFO")
    LOG_NAME = _CONFIG.get("log_name", "XAI")

    # Hosts
    MRMS_SVC = _CONFIG.get("mrms_svc", "mrms-svc")
    MRMS_SFTP_PORT = int(_CONFIG.get("mrms_sftp_port", "10022"))
    MRMS_REST_PORT = int(_CONFIG.get("mrms_rest_port", "9200"))
    MRMS_USER = _CONFIG.get("mrms_username", "HE12RmzKHQtH3bL7tTRqCg==")
    MRMS_PASSWD = _CONFIG.get("mrms_password", "jTf6XrqcYX1SAhv9JUPq+w==")

    STORAGE_SVC = _CONFIG.get("storage_svc", "ape-storage-svc")
    STORAGE_SFTP_PORT = int(_CONFIG.get("storage_sftp_port", "10122"))
    STORAGE_USER = _CONFIG.get("storage_username", "HE12RmzKHQtH3bL7tTRqCg==")
    STORAGE_PASSWD = _CONFIG.get("storage_password", "jTf6XrqcYX1SAhv9JUPq+w==")

    REST_URL_ROOT = "http://{}:{}".format(
        MRMS_SVC, MRMS_REST_PORT
    )

    JOB_TYPE = "xai"

    LIB_TYPE_TF_SINGLE = "TF_SINGLE"
    LIB_TYPE_TF = "TF"
    LIB_TYPE_GS = "GS"
    LIB_TYPE_SKL = "SKL"
    LIB_TYPE_LGBM = "LGBM"
    LIB_TYPE_XGB = "XGB"

    DATASET_FORMAT_TEXT = "1"
    DATASET_FORMAT_IMAGE = "2"
    DATASET_FORMAT_TABLE = "3"

    XAI_ALG_GRAD_SCAM = "gram_scam"
    XAI_ALG_LIME = "lime"

    LIME_IMAGE_SAMPLE_CNT = int(_CONFIG.get("lime_image_sample_cnt", "50"))
    LIME_TABULAR_SAMPLE_CNT = int(_CONFIG.get("lime_tabular_sample_cnt", "1000"))
    LIME_TEXT_SAMPLE_CNT = int(_CONFIG.get("lime_text_sample_cnt", "1000"))

    STATUS_XAI_COMPLETE = "6"
    STATUS_XAI_ERROR = "7"

    BATCH_SIZE = int(_CONFIG.get("inference_batch_size", "512"))

    # TABLE FIELD TYPE
    FIELD_TYPE_NULL = "null"
    FIELD_TYPE_INT = "int"
    FIELD_TYPE_FLOAT = "float"
    FIELD_TYPE_STRING = "string"
    FIELD_TYPE_IMAGE = "image"
    FIELD_TYPE_DATE = "date"
    FIELD_TYPE_LIST = "list"


if __name__ == '__main__':
    print(Constants.DIR_DATA_ROOT)
--------------------------------------------------------------------------------
/xai/core/data/dataloader/DataLoaderAbstract.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jinkim@seculayer.com
# Powered by Seculayer © 2021 Service Model Team, R&D Center.

from typing import Tuple, List
import numpy as np
import json

from pycmmn.rest.RestManager import RestManager
from xai.common.Common import Common
from xai.common.Constants import Constants
from xai.info.FieldInfo import FieldInfo
from dataconverter.core.ConvertAbstract import ConvertAbstract
from dataconverter.core.ConvertFactory import ConvertFactory
from pycmmn.utils.ListParser import ListParser


class DataLoaderAbstract(object):
    LOGGER = Common.LOGGER.getLogger()

    def __init__(self, job_info, sftp_client):
        self.job_info = job_info
        self.sftp_client = sftp_client
        self.functions: List[List[ConvertAbstract]] = self.build_functions(
            self.job_info.get_dataset_info().get_fields()
        )
        self.LOGGER.info(self.functions)

        self.is_exception = False

    @classmethod
    def build_functions(cls, fields: List[FieldInfo]) -> List[List[ConvertAbstract]]:
        functions: List[List[ConvertAbstract]] = list()
        for field in fields:
            cvt_fn_list: List[ConvertAbstract] = list()
            for fn_info in field.get_function():
                cvt_fn_list.append(ConvertFactory.create_cvt_fn(
                    cvt_fn_info=fn_info,
                    logger=cls.LOGGER,
                    cvt_dict=RestManager.get_cnvr_dict(
                        rest_url_root=Constants.REST_URL_ROOT, logger=cls.LOGGER
                    )
                ))
            functions.append(cvt_fn_list)
        return functions

    def _convert(self, line, fields: List[FieldInfo], functions) -> Tuple[list, list, dict]:
        features = list()
        labels = list()
        line_error = False

        for idx, field in enumerate(fields):
            name = field.field_name
            if field.field_type == Constants.FIELD_TYPE_LIST:
                value = ListParser.parse(line.get(name, "[]"))
            elif not field.multiple():
                value = line.get(name, "")
            else:
                value = list()
                for _name in name.split("@COMMA@"):
                    value.append(line.get(_name, ""))

            # TODO : verify this behaves correctly when one field has two conversion functions
            for fn in functions[idx]:
                try:
                    value = fn.apply(value)
                except Exception as e:
                    if not self.is_exception:
                        self.LOGGER.error(e, exc_info=True)
                    value = self.get_dummy(fn)
                    line_error = True

            if field.label():
                labels += value
            else:
                if name == "image":
                    features = value[0]
                else:
                    features += value

        if not self.is_exception and line_error:
            self.is_exception = line_error

        return features, labels, line

    def make_inout_units(self, features, fields: List[FieldInfo]):
        input_units = np.shape(features)[1:]
        output_units = self.get_output_units(fields)
        self.job_info.set_input_units(input_units)
        self.job_info.set_output_units(output_units)
        self.LOGGER.info("input_units : {}".format(input_units))
        self.LOGGER.info("output_units : {}".format(output_units))

    def get_output_units(self, fields: List[FieldInfo]):
        for field_info in fields:
            self.LOGGER.info(field_info.is_label)
            self.LOGGER.info(field_info.stat_dict)
            if field_info.is_label:
                try:
                    return len(field_info.stat_dict.get("unique", {}))
                except Exception as e:
                    self.LOGGER.error(e, exc_info=True)
                    return 1

    def write_dp_result(self, features, labels, file_path):
        rst_dict = dict()
        save_path = file_path.rsplit('/', 2)[0]

        self.LOGGER.info("features[0]: {}".format(features[0]))
        self.LOGGER.info("labels[0]: {}".format(labels[0]))

        rst_dict['features'] = features
        rst_dict['targets'] = labels

        f = self.sftp_client.get_client().open(
            f"{save_path}/{self.job_info.get_hist_no()}_{self.job_info.get_task_idx()}.dp",
            'w'
        )

        try:
            f.write(json.dumps(rst_dict, indent=2))
        except Exception as e:
            self.LOGGER.error(e, exc_info=True)
        finally:
            f.close()

    def read(self, file_list: List[str], fields: List[FieldInfo]) -> List:
        raise NotImplementedError

    @staticmethod
    def get_dummy(fn):
        # NOTE: get_return_type() is assumed to be a method, matching get_num_feat() below
        if fn.get_return_type() == "str":
            dummy_val = ""
        elif fn.get_return_type() == "float":
            dummy_val = 0.
        else:
            dummy_val = 0

        return [dummy_val] * fn.get_num_feat()
--------------------------------------------------------------------------------
/xai/info/XAIJobInfo.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Jin Kim
# e-mail : jinkim@seculayer.com
# Powered by Seculayer © 2021 Service Model Team, R&D Center.

import logging
from typing import Dict

from pycmmn.Singleton import Singleton
from xai.common.Constants import Constants
from pycmmn.exceptions.FileLoadError import FileLoadError
from pycmmn.exceptions.JsonParsingError import JsonParsingError
from xai.info.DatasetInfo import DatasetInfo
from pycmmn.sftp.SFTPClientManager import SFTPClientManager
from pycmmn.utils.StringUtil import StringUtil


class XAIJobInfo(object, metaclass=Singleton):
    def __init__(self, hist_no, task_idx, job_type, job_dir, logger, sftp_client):
        self.job_type: str = job_type
        self.hist_no: str = hist_no
        self.task_idx: str = task_idx
        self.job_dir: str = job_dir
        self.LOGGER = logger
        self.sftp_client: SFTPClientManager = sftp_client

        self.info_dict: dict = self._load()
        self.LOGGER.debug(self.info_dict)

        self.dataset_info: DatasetInfo = self._create_dataset(self.info_dict.get("datasets"))

    # ---- loading
    def _create_job_filename(self) -> str:
        return self.job_type + "_" + self.hist_no + ".job"

    def _load(self) -> dict:
        filename = self._create_job_filename()
        try:
            path = f"{self.job_dir}/xai/{filename}"
            job_dict: Dict = self.sftp_client.load_json_data(path)
            self.LOGGER.info(f"--------JOB INFO(dataset excluded)----------")
            for key, value in job_dict.items():
                if key == "datasets":
                    continue
                self.LOGGER.info(f"{key} : {value}")
            self.LOGGER.info(f"job load...")
        except FileNotFoundError as e:
            self.LOGGER.error(str(e), exc_info=True)
            raise FileLoadError(file_name=filename)
        except Exception as e:
            self.LOGGER.error(str(e), exc_info=True)
            raise JsonParsingError()

        return job_dict

    def _create_dataset(self, dataset_dict) -> DatasetInfo:
        dataset = DatasetInfo(dataset_dict, self.get_target_field())
        self.LOGGER.debug(str(dataset))

        return dataset

    # ---- get
    def get_hist_no(self) -> str:
        return self.hist_no

    def get_dataset_info(self) -> DatasetInfo:
        return self.dataset_info

    def get_task_idx(self) -> str:
        return self.task_idx

    def get_fields(self):
        return self.dataset_info.get_fields()

    def get_key(self) -> str:
        # key format : jobType_HistNo
        return self.info_dict.get("key", "")

    def get_param_dict_list(self) -> list:
        return [self.info_dict.get("algorithms", dict())]

    def set_input_units(self, input_units):
        for param_dict in self.get_param_dict_list():
            param_dict["params"]["input_units"] = input_units

    def set_output_units(self, output_units):
        for param_dict in self.get_param_dict_list():
            param_dict["params"]["output_units"] = output_units

    def get_num_worker(self) -> int:
        return int(self.info_dict.get("num_worker", "1"))

    def get_project_id(self) -> str:
        return self.info_dict.get("project_id")

    def get_target_field(self) -> str:
        return self.info_dict.get("target_field")

    def get_file_list(self) -> list:
        return self.info_dict.get("datasets", {}).get("metadata_json", {}).get("file_list")

    def get_dataset_lines(self) -> list:
        if self.get_dataset_format() == Constants.DATASET_FORMAT_TEXT:
            return self.info_dict.get("datasets", {}).get("metadata_json", {}).get("file_num_line")
        elif self.get_dataset_format() == Constants.DATASET_FORMAT_IMAGE:
            return self.info_dict.get("datasets", {}).get("metadata_json", {}).get("file_num")

    def get_dist_yn(self) -> bool:
        return StringUtil.get_boolean(self.info_dict.get("algorithms", {}).get("dist_yn", "").lower())

    def get_dataset_format(self) -> str:
        return self.info_dict.get("dataset_format")

    def get_xai_alg(self):
        # TODO : branch point when more algorithms are added
        return Constants.XAI_ALG_LIME

    def get_lib_type(self):
        rst = {
            "1": Constants.LIB_TYPE_TF_SINGLE,
            "2": Constants.LIB_TYPE_TF,
            "4": Constants.LIB_TYPE_GS,
            "5": Constants.LIB_TYPE_SKL
        }.get(self.info_dict["algorithms"]["lib_type"])

        if rst == Constants.LIB_TYPE_TF_SINGLE:
            rst = {
                "XGBoost": Constants.LIB_TYPE_XGB,
                "LightGBM": Constants.LIB_TYPE_LGBM
            }.get(self.info_dict["algorithms"]["algorithm_code"])

        return rst

    def get_model_id(self):
        return self.info_dict.get("learn_hist_no", None)

    def get_infr_hist_no(self):
        return self.info_dict.get("infr_hist_no", None)


class XAIJobInfoBuilder(object):
    def __init__(self):
        self.job_type = None
        self.hist_no = None
        self.task_idx = None
        self.job_dir = None
        self.logger = logging.getLogger()
        self.sftp_client = None

    def set_hist_no(self, hist_no):
        self.hist_no = hist_no
        return self

    def set_job_dir(self, job_dir):
        self.job_dir = job_dir
        return self

    def set_job_type(self, job_type):
        self.job_type = job_type
        return self

    def set_task_idx(self, task_idx):
        self.task_idx = task_idx
        return self

    def set_logger(self, logger):
        self.logger = logger
        return self

    def set_sftp_client(self, sftp_client):
        self.sftp_client = sftp_client
        return self

    def build(self) -> XAIJobInfo:
        return XAIJobInfo(
            hist_no=self.hist_no, task_idx=self.task_idx,
            job_type=self.job_type, job_dir=self.job_dir,
            logger=self.logger, sftp_client=self.sftp_client
        )
--------------------------------------------------------------------------------
/xai/core/XAIProcessor.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author : Manki Baek
# e-mail : manki.baek@seculayer.com
# Powered by Seculayer © 2021 Service Model Team, R&D Center.

from typing import Union, List, Dict
from datetime import datetime
import os
import tensorflow as tf

from pycmmn.sftp.SFTPClientManager import SFTPClientManager
from pycmmn.rest.RestManager import RestManager
from xai.common.Common import Common
from xai.common.Constants import Constants
from xai.info.XAIJobInfo import XAIJobInfo, XAIJobInfoBuilder
from xai.core.data.DataManager import DataManager, DataManagerBuilder
from xai.core.model.ModelLoader import ModelLoader
from xai.core.algorithm.Lime import Lime
from xai.core.data.datawriter.ResultWriter import ResultWriter


class XAIProcessor(object):
    LOGGER = Common.LOGGER.getLogger()

    def __init__(self, hist_no: str, task_idx: str, job_type: str) -> None:
        self.mrms_sftp_manager: SFTPClientManager = SFTPClientManager(
            "{}:{}".format(Constants.MRMS_SVC, Constants.MRMS_SFTP_PORT),
            Constants.MRMS_USER, Constants.MRMS_PASSWD, self.LOGGER
        )
        self.storage_sftp_manager: SFTPClientManager = SFTPClientManager(
            "{}:{}".format(Constants.STORAGE_SVC, Constants.STORAGE_SFTP_PORT),
            Constants.STORAGE_USER, Constants.STORAGE_PASSWD, self.LOGGER
        )

        self.job_info: XAIJobInfo = XAIJobInfoBuilder() \
            .set_hist_no(hist_no=hist_no) \
            .set_task_idx(task_idx) \
            .set_job_dir(Constants.DIR_JOB) \
            .set_job_type(job_type=job_type) \
            .set_logger(self.LOGGER) \
            .set_sftp_client(self.mrms_sftp_manager) \
            .build()

        self.job_key: str = self.job_info.get_key()
        self.job_type: str = job_type
        self.task_idx: str = task_idx

        self.lib_type = self.job_info.get_lib_type()
        self._set_backend(task_idx)

        self.model = None
        self.data_loader_manager: DataManager = DataManagerBuilder() \
            .set_job_info(job_info=self.job_info) \
            .set_sftp_client(self.mrms_sftp_manager) \
            .build()
        self.xai_cls: Union[Lime, None] = None

    def run(self) -> None:
        try:
            self.LOGGER.info(f"-- XAI start. [{self.job_key}]")

            self.data_loader_manager.run()
            self.model = self.model_load()
            self.xai_cls = self.set_xai_cls()
            data = self.data_loader_manager.get_inference_data()
            json_data = self.data_loader_manager.get_json_data()
            self.LOGGER.info(f"Data Length : [{len(data['x'])}]")
            result: list = self.xai_cls.run(data, json_data)
            self.result_write(result)

            RestManager.update_xai_status_cd(
                Constants.REST_URL_ROOT,
                self.LOGGER,
                Constants.STATUS_XAI_COMPLETE,
                self.job_info.get_hist_no(),
                "0", "-"
            )
            self.LOGGER.info("-- XAI end. [{}]".format(self.job_key))

        except Exception as e:
            self.LOGGER.error(e, exc_info=True)
            RestManager.update_xai_status_cd(
                Constants.REST_URL_ROOT,
                self.LOGGER,
                Constants.STATUS_XAI_ERROR,
                self.job_info.get_hist_no(),
                "0", "-"
            )

    def set_xai_cls(self):
        return {
            # Constants.XAI_ALG_GRAD_SCAM: GradScam(self.model, self.job_info),
            Constants.XAI_ALG_LIME: Lime(self.model, self.job_info, self.storage_sftp_manager)
        }.get(self.job_info.get_xai_alg(), Constants.XAI_ALG_GRAD_SCAM)

    def model_load(self):
        return ModelLoader.load(
            lib_type=self.lib_type, model_id=self.job_info.get_model_id()
        )

    def result_write(self, result_list):
        json_data = self.data_loader_manager.get_json_data()
        json_data = self._insert_xai_info(json_data, result_list)

        ResultWriter.result_file_write(
            result_path=Constants.DIR_RESULT,
            results=json_data,
            result_type=Constants.JOB_TYPE
        )

    def _insert_xai_info(self, json_data, result_dict_list: List[Dict]):
        curr_time = datetime.now().strftime('%Y%m%d%H%M%S')

        unique_keys = None
        label_field_info = self.job_info.get_dataset_info().get_fields()[0]
        if label_field_info.label():
            unique_keys = list(label_field_info.get_statistic().get("unique", {}).get("unique").keys())

        for line_idx, jsonline in enumerate(json_data):
            result_dict_keys = result_dict_list[line_idx].keys()
            for key in result_dict_keys:
                # convert original label
                if key == "inference_result" and unique_keys is not None:
                    jsonline[key] = unique_keys[int(result_dict_list[line_idx][key])]
                else:
                    jsonline[key] = result_dict_list[line_idx][key]
            jsonline["eqp_dt"] = curr_time
            jsonline["xai_hist_no"] = self.job_key
            jsonline["infr_hist_no"] = self.job_info.get_infr_hist_no()
            if "image" in jsonline:
                jsonline.pop("image")
            json_data[line_idx] = jsonline

        if len(result_dict_list) > len(json_data):
            result_dict_list[-1]["eqp_dt"] = curr_time
            result_dict_list[-1]["xai_hist_no"] = self.job_key
            result_dict_list[-1]["infr_hist_no"] = self.job_info.get_infr_hist_no()
            json_data.append(result_dict_list[-1])

        return json_data

    def _set_backend(self, task_idx):
        if self.lib_type == Constants.LIB_TYPE_TF:
            self.LOGGER.info(f"TF_CONFIG : {os.environ['TF_CONFIG']}")

            if os.environ.get("CUDA_VISIBLE_DEVICES", None) == "-1":
                self.LOGGER.info("Running CPU MODE")
            else:
                physical_devices = tf.config.list_physical_devices('GPU')

                if os.environ.get("CUDA_VISIBLE_DEVICES", None) is None:
                    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("NVIDIA_COM_GPU_MEM_IDX", "0")

                if len(physical_devices) != 0:
                    # allow growth GPU memory
                    tf.config.set_visible_devices(physical_devices[0], 'GPU')

                    self.LOGGER.info(
                        f"gpu_no : {os.environ['CUDA_VISIBLE_DEVICES']}, task_idx : {task_idx}, \
                        physical devices: {physical_devices}, \
                        NVIDIA_COM_GPU_MEM_IDX : {os.environ.get('NVIDIA_COM_GPU_MEM_IDX', 'no variable!')}"
                    )

                    # limit GPU memory
                    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "false"
                    mem_limit = int(int(os.environ.get("NVIDIA_COM_GPU_MEM_POD", 1024)) * 0.35)
                    self.LOGGER.info("GPU Memory Limit Size : {}".format(mem_limit))
                    tf.config.experimental.set_memory_growth(physical_devices[0], False)
                    tf.config.set_logical_device_configuration(
                        physical_devices[0],
                        [tf.config.LogicalDeviceConfiguration(memory_limit=mem_limit)])

                else:
                    self.LOGGER.debug("Physical Devices(GPU) are None")
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE.
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /xai/core/algorithm/Lime.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author : Manki Baek 3 | # e-mail : manki.baek@seculayer.com 4 | # Powered by Seculayer © 2021 Service Model Team, R&D Center. 
5 | 
6 | from lime import lime_image, lime_tabular, lime_text
7 | from lime.wrappers.scikit_image import SegmentationAlgorithm
8 | from typing import Dict, Callable, List, AnyStr, Union
9 | import numpy as np
10 | from sklearn.pipeline import make_pipeline
11 | from datetime import datetime
12 | import pandas as pd
13 | import matplotlib.pyplot as plt
14 | import seaborn as sns
15 | from pdpbox import pdp
16 | 
17 | from pycmmn.sftp.SFTPClientManager import SFTPClientManager
18 | from pycmmn.rest.RestManager import RestManager
19 | from pycmmn.utils.CV2Utils import CV2Utils
20 | from pycmmn.utils.FileUtils import FileUtils
21 | from xai.common.Common import Common
22 | from xai.common.Constants import Constants
23 | from xai.info.XAIJobInfo import XAIJobInfo
24 | from xai.core.algorithm.AlgAbstract import AlgAbstract
25 | 
26 | 
27 | class Lime(AlgAbstract):
28 |     LOGGER = Common.LOGGER.getLogger()
29 | 
30 |     def __init__(self, model, job_info: XAIJobInfo, sftp_manager: SFTPClientManager):
31 |         super().__init__(model, job_info)
32 |         self.data_type = self.get_dataset_type()
33 |         self.sftp_manager = sftp_manager
34 | 
35 |         self.predict_fn: Union[Callable, str] = self.define_predict_fn()
36 | 
37 |         self.explainer = None
38 |         self.segmenter = None
39 | 
40 |     def get_dataset_type(self) -> str:
41 |         dataset_format = self.job_info.get_dataset_format()
42 | 
43 |         # currently, every non-image dataset format starts out classified as text
44 |         if dataset_format == Constants.DATASET_FORMAT_TEXT:
45 |             str_flag = False
46 |             for field in self.fields:
47 |                 if field.is_label:
48 |                     continue
49 |                 if field.field_type == "string":
50 |                     str_flag = True
51 |                     break
52 | 
53 |             if not str_flag:  # no string feature fields: treat the dataset as tabular
54 |                 dataset_format = Constants.DATASET_FORMAT_TABLE
55 | 
56 |         return dataset_format
57 | 
58 |     def define_predict_fn(self) -> Union[Callable, str]:
59 |         predict_fn: Union[Callable, str, None] = None
60 |         if self.job_info.get_lib_type() == Constants.LIB_TYPE_TF:
61 |             if self.data_type == Constants.DATASET_FORMAT_TEXT:
62 |                 predict_fn = "predict"
63 |             else:
64 |                 predict_fn = self.model.predict
65 |         else:  # Constants.LIB_TYPE_SKL, LIB_TYPE_LGBM, LIB_TYPE_XGB
66 |             if self.data_type == Constants.DATASET_FORMAT_TEXT:
67 |                 predict_fn = "predict_proba"
68 |             else:
69 |                 predict_fn = self.model.predict_proba
70 | 
71 |         return predict_fn
72 | 
73 |     def define_explainer(self, cvt_data) -> None:
74 |         if self.data_type == Constants.DATASET_FORMAT_IMAGE:
75 |             try:
76 |                 self.explainer = lime_image.LimeImageExplainer()
77 |                 # available segmentation algorithms: slic, quickshift, felzenszwalb
78 |                 self.segmenter = SegmentationAlgorithm(
79 |                     'slic',          # segmentation algorithm name
80 |                     n_segments=30,   # number of segments to split the image into
81 |                     compactness=3,   # higher values produce more compact (square-ish) segments
82 |                     sigma=1          # width of the Gaussian smoothing kernel
83 |                 )
84 |             except Exception as e:
85 |                 self.LOGGER.error(e, exc_info=True)
86 | 
87 |         elif self.data_type == Constants.DATASET_FORMAT_TEXT:
88 |             try:
89 |                 self.explainer = lime_text.LimeTextExplainer(random_state=42, split_expression=r' ')
90 |                 # self.explainer = lime_text.LimeTextExplainer(random_state=42, split_expression=r'[\W_0-9]+')
91 |             except Exception as e:
92 |                 self.LOGGER.error(e, exc_info=True)
93 | 
94 |         elif self.data_type == Constants.DATASET_FORMAT_TABLE:
95 |             column_list = list()
96 |             try:
97 |                 for field_idx, field in enumerate(self.fields):
98 |                     if field.is_label:
99 |                         continue
100 |                     column_list.append(field.field_name)
101 | 
102 |                 self.explainer = lime_tabular.LimeTabularExplainer(np.array(cvt_data), feature_names=column_list)
103 |             except Exception as e:
104 |                 self.LOGGER.error(f"column_list : {column_list}, cvt_data_shape : {np.shape(np.array(cvt_data))}")
105 |                 self.LOGGER.error(e, exc_info=True)
106 | 
107 |     def run(self, data: Dict, json_data: List):
108 |         x = data['x']
109 |         self.define_explainer(x)
110 | 
111 |         result_list = list()
112 | 
113 |         total_start_time = datetime.now()
114 |         for idx, line in enumerate(x):
115 |             loop_start_time = datetime.now()
116 | 
117 |             try:
118 |                 case: Callable = {
119 |                     Constants.DATASET_FORMAT_TEXT: self.text_data_run,
120 |                     Constants.DATASET_FORMAT_IMAGE: self.image_data_run,
121 |                     Constants.DATASET_FORMAT_TABLE: self.tabular_data_run
122 |                 }.get(self.data_type, self.text_data_run)
123 | 
124 |                 line_rst_dict = case(x[idx], json_data[idx], idx)
125 | 
126 |                 result_list.append(line_rst_dict)
127 | 
128 |                 self.LOGGER.info(f"Line [{idx}] is finished..")
129 | 
130 |                 # report progress
131 |                 progress_pct = (idx + 1) / len(x) * 100
132 |                 self.LOGGER.info(f"{progress_pct:.1f} % completed...")
133 | 
134 |                 RestManager.send_xai_progress(
135 |                     Constants.REST_URL_ROOT, self.LOGGER, self.job_info.get_hist_no(), progress_pct
136 |                 )
137 | 
138 |             except Exception as e:
139 |                 self.LOGGER.error(e, exc_info=True)
140 |                 self.LOGGER.error(f"idx : {idx}, data : {line}")
141 |                 self.LOGGER.error("appending an empty dict to result_list")
142 |                 result_list.append({})
143 | 
144 |             self.LOGGER.info(f"Loop execution time : [{datetime.now() - loop_start_time}]")
145 | 
146 |         if self.data_type == Constants.DATASET_FORMAT_TABLE:
147 |             result_list.append(self.make_table_statistics(json_data, x))
148 | 
149 |         self.scp_image_rst()
150 | 
151 |         self.LOGGER.info(f"Total execution time : [{datetime.now() - total_start_time}]")
152 | 
153 |         RestManager.send_xai_progress(
154 |             Constants.REST_URL_ROOT, self.LOGGER, self.job_info.get_hist_no(), 100.0, "delete"
155 |         )
156 | 
157 |         return result_list
158 | 
159 |     def make_table_statistics(self, json_data: List[Dict], cvt_data: List) -> Dict:
160 |         df_data = pd.DataFrame(json_data)
161 |         rst_dict = dict()
162 | 
163 |         # drop bookkeeping columns when present (a single try/except here would
164 |         # stop at the first missing key and leave the later columns in place)
165 |         for col in ('key', 'dataset_id', 'proc_dt'):
166 |             if col in df_data.columns:
167 |                 df_data.pop(col)
168 | 
169 | 
170 |         FileUtils.mkdir(f"{self.job_info.get_hist_no()}")
171 | 
172 |         line_plot_list = self.create_line_plot(df_data)
173 |         rst_dict["line_plot_list"] = line_plot_list
174 | 
175 |         pdp_isolate_list = self.create_pdp_isolate(cvt_data)
176 |         rst_dict["pdp_isolate_list"] = pdp_isolate_list
177 | 
178 |         pdp_interact_list = self.create_pdp_interact(cvt_data)
179 |         rst_dict["pdp_interact_list"] = pdp_interact_list
180 | 
181 |         return rst_dict
182 | 
183 |     def create_line_plot(self, df_data) -> List[str]:
184 |         features_list = [field.field_name for field in self.fields]
185 |         target_field = features_list.pop(0)  # the first field is the label
186 | 
187 |         color_list = ["blue", "green", "orange", "red"]
188 |         file_name_list = list()
189 | 
190 |         for idx, feature in enumerate(features_list):
191 |             plt.figure(figsize=(16, 9))
192 |             sns.set()
193 |             sns.lineplot(data=df_data, x=target_field, y=feature, errorbar=('ci', 95), color=color_list[idx % len(color_list)])
194 |             plt.subplots_adjust(top=0.93)
195 |             plt.suptitle(f"{feature} and {target_field} plot", fontsize=16)
196 |             file_name = f"{feature}_{target_field}_line_plot.png"
197 |             file_name_list.append(f"/xai/table_statistics/{self.job_info.get_hist_no()}/{file_name}")
198 |             plt.savefig(f"{self.job_info.get_hist_no()}/{file_name}", bbox_inches='tight', dpi=400)
199 | 
200 |         return file_name_list
201 | 
202 |     def create_pdp_isolate(self, cvt_data: List, cluster_flag=False, nb_clusters=None, lines_flag=False) -> List[str]:
203 |         if self.job_info.get_lib_type() not in [Constants.LIB_TYPE_XGB, Constants.LIB_TYPE_LGBM, Constants.LIB_TYPE_SKL]:
204 |             return []
205 | 
206 |         features_list = [field.field_name for field in self.fields]
207 |         features_list.pop(0)  # drop the label field (assumed to be the first field)
208 |         x = pd.DataFrame(cvt_data, columns=features_list)
209 |         file_name_list = list()
210 | 
211 |         for feature in features_list:
212 |             # Create the data that we will plot
213 |             pdp_goals = pdp.pdp_isolate(model=self.model, dataset=x, model_features=features_list, feature=feature)
214 |             pdp.pdp_plot(pdp_goals, feature, cluster=cluster_flag, n_cluster_centers=nb_clusters, plot_lines=lines_flag)
215 |             file_name = f"{feature}_pdp_isolate.png"
216 |             plt.savefig(f"{self.job_info.get_hist_no()}/{file_name}", dpi=400, bbox_inches="tight")
217 |             file_name_list.append(f"/xai/table_statistics/{self.job_info.get_hist_no()}/{file_name}")
218 | 
219 |         return file_name_list
220 | 
221 |     def create_pdp_interact(self, cvt_data: List) -> List[str]:
222 |         if self.job_info.get_lib_type() not in [Constants.LIB_TYPE_XGB, Constants.LIB_TYPE_LGBM, Constants.LIB_TYPE_SKL]:
223 |             return []
224 | 
225 |         features_list = [field.field_name for field in self.fields]
226 |         features_list.pop(0)  # drop the label field
227 |         x = pd.DataFrame(cvt_data, columns=features_list)
228 |         file_name_list = list()
229 | 
230 |         for start in range(len(features_list)):
231 |             if start == len(features_list) - 1 or len(features_list) < 2:
232 |                 break
233 |             for end in range(start + 1, len(features_list)):
234 |                 features_to_plot = [features_list[start], features_list[end]]
235 |                 inter = pdp.pdp_interact(model=self.model, dataset=x, model_features=features_list, features=features_to_plot)
236 |                 pdp.pdp_interact_plot(pdp_interact_out=inter, feature_names=features_to_plot, plot_type='contour')
237 |                 file_name = f"{'_'.join(features_to_plot)}_pdp_interact.png"
238 |                 plt.savefig(f"{self.job_info.get_hist_no()}/{file_name}", dpi=400, bbox_inches="tight")
239 |                 file_name_list.append(f"/xai/table_statistics/{self.job_info.get_hist_no()}/{file_name}")
240 | 
241 |         return file_name_list
242 | 
243 |     def scp_image_rst(self) -> None:
244 |         if self.data_type == Constants.DATASET_FORMAT_IMAGE:
245 |             self.sftp_manager.scp_to_storage(
246 |                 local_path=f"{self.job_info.get_hist_no()}",
247 |                 remote_path=f"{Constants.DIR_WEB_FILE}/xai/masked_image"
248 |             )
249 |             self.sftp_manager.scp_to_storage(
250 |                 local_path=f"{self.job_info.get_hist_no()}_thumbnail",
251 |                 remote_path=f"{Constants.DIR_WEB_FILE}/xai/masked_thumbnail_image"
252 |             )
253 |             FileUtils.remove_dir(f"{self.job_info.get_hist_no()}")
254 |             FileUtils.remove_dir(f"{self.job_info.get_hist_no()}_thumbnail")
255 |         elif self.data_type == Constants.DATASET_FORMAT_TABLE:
256 |             self.sftp_manager.scp_to_storage(
257 |                 local_path=f"{self.job_info.get_hist_no()}",
258 |                 remote_path=f"{Constants.DIR_WEB_FILE}/xai/table_statistics"
259 |             )
260 |             FileUtils.remove_dir(f"{self.job_info.get_hist_no()}")
261 | 
262 |     def text_data_run(self, cvt_line_data, json_line_data, line_idx):
263 |         inferenced_y = self.model_inference([cvt_line_data])
264 | 
265 |         pipe = make_pipeline(self.functions[-1][-1], self.model)
266 |         line_rst_dict = dict()
267 |         original_idx_dict = dict()
268 |         cvt_dict = dict()
269 | 
270 |         s_idx = 0
271 |         reversed_data = list()
272 |         for field_idx, field in enumerate(self.fields):
273 |             if field.is_label:
274 |                 continue
275 |             max_len = self.functions[field_idx][-1].get_num_feat()
276 |             e_idx = s_idx + max_len  # note: s_idx is never advanced, so a single non-label field is assumed
277 | 
278 |             reversed_data = cvt_line_data[s_idx:e_idx]
279 |             for cvt_idx in range(len(self.functions[field_idx]) - 1, -1, -1):  # walk the conversion chain in reverse
280 |                 reversed_data = self.functions[field_idx][cvt_idx].reverse(reversed_data, json_line_data[field.field_name])
281 |             tmp_idx_list, cvt_origin = self.functions[field_idx][-1].get_original_idx(
282 |                 cvt_data=cvt_line_data[s_idx:e_idx], original_data=json_line_data[field.field_name]
283 |             )
284 |             original_idx_dict[field.field_name] = tmp_idx_list
285 |             cvt_dict[field.field_name] = cvt_origin
286 |         exp = self.explainer.explain_instance(
287 |             " ".join(reversed_data), getattr(pipe, self.predict_fn), num_samples=Constants.LIME_TEXT_SAMPLE_CNT  # getattr instead of eval: predict_fn holds a method name for text data
288 |         )
289 | 
290 |         # store result
291 |         temp_effect_val = dict()
292 |         for key in exp.local_exp.keys():
293 |             temp_effect_val[key] = [(int(f_idx), val) for (f_idx, val) in exp.local_exp[key]]
294 |         line_rst_dict["lime_effect_val"] = temp_effect_val
295 |         line_rst_dict["lime_class_names"] = exp.class_names
296 |         line_rst_dict["lime_predict_proba"] = exp.predict_proba.tolist()
297 |         line_rst_dict["inference_result"] = int(inferenced_y[0])
298 |         line_rst_dict["origin_idx_dict"] = original_idx_dict
299 |         line_rst_dict["cvt_dict"] = cvt_dict
300 |         line_rst_dict["lime_cvt_text"] = " ".join(reversed_data)
301 |         line_rst_dict["lime_text_highlight_idx"] = exp.domain_mapper.indexed_string.positions
302 |         line_rst_dict["lime_inverse_vocab"] = exp.domain_mapper.indexed_string.inverse_vocab
303 | 
304 |         return line_rst_dict
305 | 
306 |     def text_data_run_deprecated(self, cvt_line_data, json_line_data, line_idx):
307 | 
308 |         inferenced_y = self.model_inference([cvt_line_data])
309 | 
310 |         column_list = list()
311 |         original_idx_dict = dict()
312 |         cvt_dict = dict()
313 |         line_rst_dict = dict()
314 |         s_idx = 0
315 |         unique_labels = 0
316 |         for field_idx, field in enumerate(self.fields):
317 |             if field.is_label:
318 |                 unique_labels = field.stat_dict["unique_count"]
319 |                 continue
320 | 
321 |             max_len = self.functions[field_idx][-1].get_num_feat()
322 |             e_idx = s_idx + max_len
323 | 
324 |             reversed_data = cvt_line_data[s_idx:e_idx]
325 |             for cvt_idx in range(len(self.functions[field_idx]) - 1, -1, -1):  # walk the conversion chain in reverse
326 |                 reversed_data = self.functions[field_idx][cvt_idx].reverse(reversed_data, json_line_data[field.field_name])
327 |             tmp_idx_list, cvt_origin = self.functions[field_idx][-1].get_original_idx(
328 |                 cvt_data=cvt_line_data[s_idx:e_idx], original_data=json_line_data[field.field_name]
329 |             )
330 |             original_idx_dict[field.field_name] = tmp_idx_list
331 |             cvt_dict[field.field_name] = cvt_origin
332 |             column_list.extend(reversed_data)
333 | 
334 |         explainer = lime_tabular.LimeTabularExplainer(np.array(cvt_line_data), feature_names=column_list)
335 |         exp = explainer.explain_instance(
336 |             np.array(cvt_line_data), predict_fn=self.predict_fn,
337 |             num_samples=1000, labels=range(unique_labels), num_features=6
338 |         )
339 |         """
340 |         exp.domain_mapper.feature_names : feature name
341 |         exp.local_exp : (feature idx, score)
342 |         exp.domain_mapper.feature_values : feature value
343 |         exp.domain_mapper.discretized_feature_names : description shown above the score(?)
344 |         exp.class_names : class name
345 |         """
346 |         # line_rst_dict["feature_name"] = exp.domain_mapper.feature_names
347 |         temp_effect_val = dict()
348 |         for key in exp.local_exp.keys():
349 |             temp_effect_val[key] = [(exp.domain_mapper.feature_names[f_idx], val) for (f_idx, val) in exp.local_exp[key]]
350 |         line_rst_dict["effect_val"] = temp_effect_val
351 |         line_rst_dict["class_names"] = exp.class_names
352 |         line_rst_dict["predict_proba"] = exp.predict_proba.tolist()
353 |         line_rst_dict["inference_result"] = int(inferenced_y[0])
354 |         line_rst_dict["origin_idx_dict"] = original_idx_dict
355 |         line_rst_dict["cvt_dict"] = cvt_dict
356 | 
357 |         # exp.save_to_file(f"./temp_data/rrmm_{line_idx}.html")
358 | 
359 |         return line_rst_dict
360 | 
361 |     def tabular_data_run(self, cvt_line_data, json_line_data, line_idx):
362 |         inferenced_y = self.model_inference([cvt_line_data])
363 |         line_rst_dict = dict()
364 | 
365 |         exp = self.explainer.explain_instance(
366 |             np.array(cvt_line_data), predict_fn=self.predict_fn,
367 |             num_samples=Constants.LIME_TABULAR_SAMPLE_CNT, num_features=6
368 |         )
369 |         """
370 |         exp.domain_mapper.feature_names : feature name
371 |         exp.local_exp : (feature idx, score)
372 |         exp.domain_mapper.feature_values : feature value
373 |         exp.domain_mapper.discretized_feature_names : value ranges
374 |         exp.class_names : class name
375 |         """
376 |         temp_effect_val = dict()
377 |         for key in exp.local_exp.keys():
378 |             temp_effect_val[key] = [(int(f_idx), val) for (f_idx, val) in exp.local_exp[key]]
379 |         line_rst_dict["lime_effect_val"] = temp_effect_val
380 |         line_rst_dict["lime_discretized_feature_names"] = exp.domain_mapper.discretized_feature_names
381 |         line_rst_dict["lime_class_names"] = exp.class_names
382 |         line_rst_dict["lime_predict_proba"] = exp.predict_proba.tolist()
383 |         line_rst_dict["lime_feature_names"] = exp.domain_mapper.feature_names
384 |         line_rst_dict["lime_feature_values"] = exp.domain_mapper.feature_values
385 |         line_rst_dict["inference_result"] = int(inferenced_y[0])
386 |         line_rst_dict["cvt_data"] = cvt_line_data
387 | 
388 |         return line_rst_dict
389 | 
390 |     def image_data_run(self, cvt_line_data, json_line_data, line_idx):
391 |         inferenced_y = self.model_inference([cvt_line_data])
392 | 
393 |         line_rst_dict = dict()
394 | 
395 |         exp = self.explainer.explain_instance(
396 |             np.array(cvt_line_data),
397 |             classifier_fn=self.predict_fn,                 # returns per-class probabilities
398 |             num_samples=Constants.LIME_IMAGE_SAMPLE_CNT,   # size of the perturbation sample space
399 |             segmentation_fn=self.segmenter                 # image segmentation algorithm
400 |         )
401 |         img, mask = exp.get_image_and_mask(exp.top_labels[0])
402 |         self.LOGGER.info(mask)
403 |         self.LOGGER.info(f"inference rst : [{exp.top_labels[0]}]")
404 | 
405 |         mask = np.expand_dims(mask, axis=2)
406 |         masked_img: np.ndarray = img * mask
407 |         masked_thumbnail_img = CV2Utils.resize(masked_img.astype("float32"), (256, 256))
408 |         line_json = json_line_data
409 | 
410 |         self.make_rst_image(line_idx, line_json, masked_img, masked_thumbnail_img)
411 | 
412 |         file_path: AnyStr = line_json["file_path"]
413 |         thumbnail_path: AnyStr = line_json["thumbnail_path"]
414 |         line_rst_dict["file_path"] = file_path.replace(Constants.DIR_WEB_FILE, "")
415 |         line_rst_dict["thumbnail_path"] = thumbnail_path.replace(Constants.DIR_WEB_FILE, "")
416 |         line_rst_dict["masked_file_path"] = f"/xai/masked_image/{self.job_info.get_hist_no()}"
417 |         line_rst_dict["masked_thumbnail_path"] = f"/xai/masked_thumbnail_image/{self.job_info.get_hist_no()}_thumbnail"
418 | 
419 |         if inferenced_y.shape[-1] >= 2:
420 |             inferenced_y = inferenced_y.argmax(axis=1)  # convert class probabilities to a class index
421 |         line_rst_dict["inference_result"] = int(inferenced_y[0])
422 |         # line_rst_dict["mask"] = mask.tolist()
423 | 
424 |         return line_rst_dict
425 | 
426 |     def make_rst_image(self, line_idx: int, line_json: Dict, masked_image, masked_thumbnail_image) -> None:
427 |         if line_idx == 0:
428 |             FileUtils.mkdir(f"{self.job_info.get_hist_no()}")
429 |             FileUtils.mkdir(f"{self.job_info.get_hist_no()}_thumbnail")
430 | 
431 |         file_nm = line_json["file_ori_nm"]
432 |         CV2Utils.imwrite(f"{self.job_info.get_hist_no()}/{file_nm}", masked_image)
433 |         CV2Utils.imwrite(f"{self.job_info.get_hist_no()}_thumbnail/{file_nm}", masked_thumbnail_image)
--------------------------------------------------------------------------------
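
Usage sketch: the LIME tabular flow in Lime.py (a LimeTabularExplainer built over the converted training matrix in define_explainer, then explain_instance per row with a probability-returning predict_fn in tabular_data_run) can be exercised in isolation. The snippet below is not part of the repository; the iris dataset and RandomForestClassifier are illustrative stand-ins, and only public lime and scikit-learn APIs are used.

    from lime import lime_tabular
    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier

    # stand-in dataset and model (illustrative only, not part of this repo)
    iris = load_iris()
    model = RandomForestClassifier(random_state=42).fit(iris.data, iris.target)

    # mirrors define_explainer(): the explainer is fit on the training
    # matrix with the non-label feature names
    explainer = lime_tabular.LimeTabularExplainer(
        iris.data,
        feature_names=iris.feature_names,
        class_names=list(iris.target_names),
    )

    # mirrors tabular_data_run(): explain a single row; predict_fn must
    # return per-class probabilities (cf. define_predict_fn() choosing
    # predict_proba for non-TF models)
    exp = explainer.explain_instance(
        iris.data[0], predict_fn=model.predict_proba, num_features=4
    )

    # exp.local_exp maps a label index to (feature idx, weight) pairs --
    # the structure tabular_data_run() serializes as "lime_effect_val"
    effect_val = {
        key: [(int(f_idx), val) for f_idx, val in exp.local_exp[key]]
        for key in exp.local_exp
    }
    print(effect_val)
    print(exp.predict_proba.tolist())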