├── nemo_inspector ├── assets │ ├── images │ │ └── icons │ │ │ ├── edit_icon.png │ │ │ ├── save_icon.png │ │ │ └── compare_icon.png │ ├── styles │ │ ├── styles.css │ │ └── ansi_styles.css │ └── scripts │ │ ├── change_element_height.js │ │ └── register_textarea.js ├── tests │ ├── README.md │ └── test_ping.py ├── __init__.py ├── utils │ ├── __init__.py │ ├── decoration │ │ ├── __init__.py │ │ ├── plain_text.py │ │ ├── code.py │ │ ├── common.py │ │ └── latex.py │ └── common.py ├── settings │ ├── __init__.py │ ├── constants │ │ ├── paths.py │ │ ├── templates.py │ │ ├── configurations.py │ │ ├── common.py │ │ └── __init__.py │ └── inspector_config.py ├── layouts │ ├── analyze_page_layouts │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── base_layout.py │ │ ├── modals_layouts.py │ │ └── table_layouts.py │ ├── base_layouts.py │ ├── __init__.py │ └── common_layouts.py ├── __main__.py ├── callbacks │ ├── common │ │ ├── __init__.py │ │ ├── decoration.py │ │ └── navigation.py │ ├── __init__.py │ └── analyze_page │ │ ├── __init__.py │ │ ├── update_dataset.py │ │ ├── short_info_table.py │ │ ├── save_dataset.py │ │ ├── sort_dataset.py │ │ ├── count_stats_dataset.py │ │ ├── filter_dataset.py │ │ ├── label_dataset.py │ │ └── detailed_sample_info.py ├── inspector_app.py └── parse_agruments_helpers.py ├── requirements ├── inspector-tests.txt └── inspector.txt ├── setup.py ├── README.md └── LICENSE /nemo_inspector/assets/images/icons/edit_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/NeMo-Inspector/HEAD/nemo_inspector/assets/images/icons/edit_icon.png -------------------------------------------------------------------------------- /nemo_inspector/assets/images/icons/save_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/NeMo-Inspector/HEAD/nemo_inspector/assets/images/icons/save_icon.png 
-------------------------------------------------------------------------------- /nemo_inspector/assets/images/icons/compare_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/NeMo-Inspector/HEAD/nemo_inspector/assets/images/icons/compare_icon.png -------------------------------------------------------------------------------- /nemo_inspector/assets/styles/styles.css: -------------------------------------------------------------------------------- 1 | .button-class { 2 | line-height: 20px; 3 | font-size: 14px; 4 | height: 40px; 5 | text-overflow: ellipsis; 6 | white-space: nowrap; 7 | overflow: hidden; 8 | margin-left: 2px; 9 | } -------------------------------------------------------------------------------- /nemo_inspector/tests/README.md: -------------------------------------------------------------------------------- 1 | To launch the tests, first install all the requirements: 2 | ``` 3 | pip install -r requirements/inspector.txt 4 | pip install -r requirements/inspector-tests.txt 5 | ``` 6 | Then run the tests from the repository root: 7 | ``` 8 | pytest nemo_inspector/tests 9 | ``` -------------------------------------------------------------------------------- /nemo_inspector/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nemo_inspector/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nemo_inspector/settings/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /nemo_inspector/layouts/analyze_page_layouts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /requirements/inspector-tests.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | dash[testing] 16 | pytest-rerunfailures 17 | webdriver-manager==4.0.2 18 | -------------------------------------------------------------------------------- /nemo_inspector/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from nemo_inspector.inspector_app import main 16 | 17 | if __name__ == "__main__": 18 | main() 19 | -------------------------------------------------------------------------------- /requirements/inspector.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | ansi2html 16 | dash 17 | dash-ace 18 | dash_bootstrap_components 19 | joblib 20 | pandas 21 | pygments 22 | sshtunnel_requests 23 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import nemo_inspector.callbacks.common.decoration as decoration 16 | import nemo_inspector.callbacks.common.navigation as navigation 17 | -------------------------------------------------------------------------------- /nemo_inspector/settings/constants/paths.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import pathlib 15 | 16 | COMPARE_ICON_PATH = "assets/images/icons/compare_icon.png" 17 | EDIT_ICON_PATH = "assets/images/icons/edit_icon.png" 18 | SAVE_ICON_PATH = "assets/images/icons/save_icon.png" 19 | PATH_TO_THE_REPOSITORY = pathlib.Path(__file__).parents[3] 20 | -------------------------------------------------------------------------------- /nemo_inspector/settings/constants/templates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ERROR_MESSAGE_TEMPLATE = "When applying {} function\ngot errors\n{}" 16 | QUERY_INPUT_ID = '{{"type": "{}", "id": "{}"}}' 17 | MODEL_SELECTOR_ID = '{{"type": "model_selector", "id": {}}}' 18 | LABEL_SELECTOR_ID = '{{"type": "label_selector", "id": {}}}' 19 | -------------------------------------------------------------------------------- /nemo_inspector/utils/decoration/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from nemo_inspector.utils.decoration.common import ( 16 | design_text_output, 17 | get_height_adjustment, 18 | ) 19 | from nemo_inspector.utils.decoration.code import highlight_code 20 | from nemo_inspector.utils.decoration.plain_text import color_text_diff 21 | -------------------------------------------------------------------------------- /nemo_inspector/assets/scripts/change_element_height.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | window.addEventListener('message', function(event) { 16 | if (event.data && event.data.frameHeight && event.data.frameId) { 17 | var iframe = document.getElementById(event.data.frameId); 18 | if (iframe) { 19 | iframe.style.height = event.data.frameHeight + 'px'; 20 | } 21 | } 22 | }, false); 23 | -------------------------------------------------------------------------------- /nemo_inspector/settings/constants/configurations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | CODE_SEPARATORS = { 16 | "code_begin": "{code_begin}", 17 | "code_end": "{code_end}", 18 | "code_output_begin": "{code_output_begin}", 19 | "code_output_end": "{code_output_end}", 20 | "code_output_format": "llama", 21 | } 22 | DATA_PAGE_SIZE = 10 23 | EXTRA_FIELDS = ["page_index", "file_name"] 24 | STATS_KEYS = [ 25 | "question_index", 26 | "problem", 27 | ] 28 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Create the shared Dash application and register all page callbacks."""

import os
from pathlib import Path

import dash_bootstrap_components as dbc
from dash import Dash

# Static assets (styles, scripts, icons) live in ``nemo_inspector/assets``,
# one directory above this ``callbacks`` package.
_package_dir = Path(__file__).parents[1]
assets_path = os.path.join(_package_dir, "assets")

_dash_options = {
    # Callbacks reference components that only appear once ``page_content``
    # has been filled in, so Dash must not reject unknown component ids.
    "suppress_callback_exceptions": True,
    "external_stylesheets": [dbc.themes.BOOTSTRAP],
    "assets_folder": assets_path,
}
app = Dash(__name__, **_dash_options)

# These imports must stay *after* ``app`` exists: the callback modules run
# ``from nemo_inspector.callbacks import app`` at import time and register
# their callbacks on it.
import nemo_inspector.callbacks.common as common  # noqa: E402
import nemo_inspector.callbacks.analyze_page as analyze_page  # noqa: E402
from dash import html
from dash.dependencies import Input, Output

from nemo_inspector.callbacks import app
from nemo_inspector.utils.decoration import get_height_adjustment


@app.callback(
    Output(
        component_id="js_container",
        component_property="children",
        allow_duplicate=True,
    ),
    [
        Input(component_id="page_content", component_property="children"),
        Input(component_id="js_trigger", component_property="children"),
    ],
    prevent_initial_call=True,
)
def adjust_text_area_height(content: html.Div, trigger: str) -> html.Iframe:
    """Refresh the height-adjustment helper in ``js_container``.

    Runs whenever ``page_content`` or ``js_trigger`` changes. Not on the
    initial page load.

    Args:
        content: new children of ``page_content``. Only used as a trigger.
        trigger: new children of ``js_trigger``. Only used as a trigger.

    Returns:
        Whatever ``get_height_adjustment()`` builds, placed into
        ``js_container``.
    """
    # The inputs only signal that the page changed; their values are unused.
    del content, trigger
    return get_height_adjustment()
14 | 15 | 16 | import nemo_inspector.callbacks.analyze_page.count_stats_dataset as count_stats_dataset 17 | import nemo_inspector.callbacks.analyze_page.detailed_sample_info as detailed_sample_info 18 | import nemo_inspector.callbacks.analyze_page.filter_dataset as filter_dataset 19 | import nemo_inspector.callbacks.analyze_page.label_dataset as label_dataset 20 | import nemo_inspector.callbacks.analyze_page.save_dataset as save_dataset 21 | import nemo_inspector.callbacks.analyze_page.sort_dataset as sort_dataset 22 | import nemo_inspector.callbacks.analyze_page.short_info_table as short_info_table 23 | import nemo_inspector.callbacks.analyze_page.update_dataset as update_dataset 24 | -------------------------------------------------------------------------------- /nemo_inspector/assets/scripts/register_textarea.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
// Auto-size every <textarea> currently in the document to fit its content.
//
// Each textarea is wired up only once (tracked with a data attribute), so
// calling this function repeatedly does not stack duplicate `input`
// listeners. Every call still recomputes the heights.
//
// Textarea elements fire neither `load` nor `resize`, so resizing on
// viewport changes is handled by a single window-level `resize` listener.
function registerTextarea() {
    function updateHeight(textarea) {
        // Collapse first so scrollHeight measures the content rather than the
        // previously assigned height.
        textarea.style.height = '0px';

        var height = Math.max(textarea.scrollHeight, textarea.offsetHeight,
                              textarea.clientHeight);

        textarea.style.height = height + 'px';
    }

    document.querySelectorAll("textarea").forEach(function(textarea) {
        if (!textarea.dataset.autoHeightRegistered) {
            textarea.dataset.autoHeightRegistered = "true";
            textarea.addEventListener('input', function() {
                updateHeight(textarea);
            });
        }
        updateHeight(textarea);
    });

    // Install the window listener only once, no matter how often this runs.
    if (!registerTextarea.resizeListenerAdded) {
        registerTextarea.resizeListenerAdded = true;
        window.addEventListener('resize', function() {
            document
                .querySelectorAll("textarea[data-auto-height-registered]")
                .forEach(function(textarea) {
                    updateHeight(textarea);
                });
        });
    }
}
import dash_bootstrap_components as dbc
from dash import dcc, html


def get_main_page_layout() -> html.Div:
    """Return the root layout shared by every view of the application.

    The layout contains:
    * the ``url`` location tracker;
    * the top navigation bar;
    * the ``page_content`` container that a navigation callback fills;
    * hidden or empty helper containers (``js_trigger``, ``js_container``,
      ``dummy_output``) that serve as callback inputs and outputs.
    """

    def hidden_container(component_id: str) -> dbc.Container:
        # Invisible, initially empty container that exists only for callbacks.
        return dbc.Container(id=component_id, style={"display": "none"}, children="")

    navbar = dbc.NavbarSimple(
        brand="NeMo Inspector",
        sticky="top",
        color="blue",
        dark=True,
        class_name="mb-2",
    )

    children = [
        dcc.Location(id="url", refresh=False),
        navbar,
        dbc.Container(id="page_content"),
        hidden_container("js_trigger"),
        dbc.Container(id="js_container"),
        hidden_container("dummy_output"),
    ]
    return html.Div(children)
14 | 15 | ANSI = "ansi" 16 | BASE_GENERATION = "base_generation" 17 | CHOOSE_GENERATION = "choose generation" 18 | CHOOSE_LABEL = "choose label" 19 | COMPARE = "compare" 20 | CODE = "code" 21 | CODE_BEGIN = "code_begin" 22 | CODE_END = "code_end" 23 | CODE_OUTPUT_BEGIN = "code_output_begin" 24 | CODE_OUTPUT_END = "code_output_end" 25 | CUSTOM = "custom" 26 | DELETE = "delete" 27 | EXPECTED_ANSWER_FIELD = "expected_answer" 28 | FILE_NAME = "file_name" 29 | FILES_ONLY = "files_only" 30 | FILES_FILTERING = "add_files_filtering" 31 | GENERAL_STATS = "general_stats" 32 | INLINE_STATS = "inline_stats" 33 | LABEL = "labels" 34 | LATEX = "latex" 35 | MARKDOWN = "markdown" 36 | QUESTIONS_FILTERING = "questions_filtering" 37 | QUESTION_FIELD = "problem" 38 | UNDEFINED = "undefined" 39 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/common/navigation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from dash import html
from dash.dependencies import Input, Output
from flask import current_app

from nemo_inspector.callbacks import app
from nemo_inspector.layouts import get_compare_test_layout
from nemo_inspector.settings.constants import (
    CODE_BEGIN,
    CODE_END,
    CODE_OUTPUT_BEGIN,
    CODE_OUTPUT_END,
)
from nemo_inspector.settings.constants.configurations import CODE_SEPARATORS


@app.callback(
    Output("page_content", "children"),
    Input("url", "pathname"),
)
def nav_click(url: str) -> html.Div:
    """Render the analysis page whenever the URL path changes.

    Before building the layout, store derived separator pairs in the shared
    ``nemo_inspector`` config:
    * ``code_separators``: ``(begin, end)`` tags for code.
    * ``code_output_separators``: ``(begin, end)`` tags for code output.

    Args:
        url: the current pathname. Every path renders the same page.

    Returns:
        The layout produced by ``get_compare_test_layout()``.
    """
    config = current_app.config["nemo_inspector"]
    tags = config["code_tags"]

    config["code_separators"] = (tags[CODE_BEGIN], tags[CODE_END])
    config["code_output_separators"] = (
        tags[CODE_OUTPUT_BEGIN],
        tags[CODE_OUTPUT_END],
    )

    return get_compare_test_layout()
14 | 15 | from nemo_inspector.layouts.analyze_page_layouts.base_layout import ( 16 | get_compare_test_layout, 17 | get_filtered_tables_layout, 18 | get_tables_layout, 19 | get_updated_tables_layout, 20 | get_sorted_tables_layout, 21 | ) 22 | from nemo_inspector.layouts.analyze_page_layouts.utils import ( 23 | get_stats_input, 24 | get_stats_text, 25 | get_filter_text, 26 | ) 27 | from nemo_inspector.layouts.base_layouts import ( 28 | get_main_page_layout, 29 | ) 30 | from nemo_inspector.layouts.common_layouts import ( 31 | get_selector_layout, 32 | get_single_prompt_output_layout, 33 | get_switch_layout, 34 | get_text_modes_layout, 35 | ) 36 | 37 | from nemo_inspector.layouts.analyze_page_layouts.table_layouts import ( 38 | get_detailed_info_table_column, 39 | get_filter_modal_layout, 40 | get_table_column_header, 41 | get_detailed_info_table_row_content, 42 | get_single_prompt_output_layout, 43 | get_short_info_table_layout, 44 | get_detailed_info_table_content, 45 | ) 46 | -------------------------------------------------------------------------------- /nemo_inspector/settings/constants/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from nemo_inspector.settings.constants.configurations import ( 17 | CODE_SEPARATORS, 18 | DATA_PAGE_SIZE, 19 | EXTRA_FIELDS, 20 | STATS_KEYS, 21 | ) 22 | from nemo_inspector.settings.constants.common import ( 23 | EXPECTED_ANSWER_FIELD, 24 | ANSI, 25 | BASE_GENERATION, 26 | INLINE_STATS, 27 | GENERAL_STATS, 28 | CHOOSE_GENERATION, 29 | CHOOSE_LABEL, 30 | COMPARE, 31 | CODE, 32 | CUSTOM, 33 | DELETE, 34 | FILE_NAME, 35 | FILES_ONLY, 36 | FILES_FILTERING, 37 | GENERAL_STATS, 38 | CODE_BEGIN, 39 | CODE_END, 40 | CODE_OUTPUT_BEGIN, 41 | CODE_OUTPUT_END, 42 | QUESTIONS_FILTERING, 43 | QUESTION_FIELD, 44 | UNDEFINED, 45 | MARKDOWN, 46 | LABEL, 47 | LATEX, 48 | ) 49 | from nemo_inspector.settings.constants.paths import ( 50 | COMPARE_ICON_PATH, 51 | EDIT_ICON_PATH, 52 | SAVE_ICON_PATH, 53 | ) 54 | from nemo_inspector.settings.constants.templates import ( 55 | ERROR_MESSAGE_TEMPLATE, 56 | QUERY_INPUT_ID, 57 | MODEL_SELECTOR_ID, 58 | LABEL_SELECTOR_ID, 59 | ) 60 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import re
from pathlib import Path

from setuptools import find_packages, setup

# Directory containing this file. Project files are resolved relative to it so
# the build does not depend on the current working directory.
HERE = Path(__file__).resolve().parent

# pip-style comment: a '#' at line start or preceded by whitespace, to the end
# of the line. A '#' inside a URL (e.g. "#egg=") is left alone.
_COMMENT_RE = re.compile(r"(^|\s+)#.*$")


def parse_requirements(filename):
    """Return the requirement specifiers listed in a pip requirements file.

    Args:
        filename: path to the requirements file. Relative paths are resolved
            against the directory that contains ``setup.py``.

    Returns:
        A list of requirement strings. Comments (such as the license header
        at the top of the file) and blank lines are removed.
    """
    path = Path(filename)
    if not path.is_absolute():
        path = HERE / path
    requirements = []
    with path.open(encoding="utf-8") as f:
        for line in f:
            requirement = _COMMENT_RE.sub("", line).strip()
            if requirement:
                requirements.append(requirement)
    return requirements


# Read the runtime requirements from the requirements file.
requirements = parse_requirements("requirements/inspector.txt")

setup(
    name="nemo_inspector",
    version="0.1.0",
    description="NeMo Inspector - a tool for datasets analysis",
    url="https://github.com/NVIDIA/NeMo-Inspector",
    # read_text opens and closes the file; an explicit encoding keeps the
    # build independent of the platform's default locale.
    long_description=(HERE / "README.md").read_text(encoding="utf-8"),
    long_description_content_type="text/markdown",
    packages=find_packages(),
    include_package_data=True,
    install_requires=requirements,
    license="Apache License, Version 2.0",
    classifiers=[
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.10",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
    ],
    entry_points={
        "console_scripts": [
            "nemo_inspector=nemo_inspector.inspector_app:main",
        ],
    },
    python_requires=">=3.10",
)
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
from typing import Dict, Iterable, List

from nemo_inspector.settings.constants.configurations import CODE_SEPARATORS

# Characters that turn a user-supplied path into a glob pattern.
_GLOB_CHARS = ("*", "?")


def _has_glob(text: str) -> bool:
    """Return True if ``text`` contains a glob wildcard character."""
    return any(char in text for char in _GLOB_CHARS)


def _expand_paths(paths: Iterable[str]) -> List[str]:
    """Expand user-supplied paths into a flat list of existing file paths.

    Each entry (``~`` is expanded first) is handled as:

    * a glob pattern (``*``/``?`` in any component, ``**`` for recursion):
      its matches, sorted;
    * a directory: every ``*.jsonl`` file below it, sorted;
    * an existing file: kept as is.

    Entries matching nothing are skipped silently.
    """
    expanded_files: List[str] = []
    for path in paths:
        expanded = Path(path).expanduser()
        if _has_glob(str(expanded)):
            parts = expanded.parts
            # Glob from the longest wildcard-free prefix so patterns with
            # wildcards in directory components (``runs/*/preds.jsonl``) work;
            # Path.glob only accepts a relative pattern.
            first_glob = next(i for i, part in enumerate(parts) if _has_glob(part))
            base = Path(*parts[:first_glob])
            pattern = "/".join(parts[first_glob:])
            expanded_files.extend(sorted(map(str, base.glob(pattern))))
        elif expanded.is_dir():
            expanded_files.extend(sorted(map(str, expanded.rglob("*.jsonl"))))
        elif expanded.exists():
            expanded_files.append(str(expanded))
    return expanded_files


def unroll_files(paths: Iterable[str]) -> List[str]:
    """Expand ``paths`` and drop duplicates, keeping first-seen order."""
    return list(dict.fromkeys(chain.from_iterable(_expand_paths([path]) for path in paths)))


@dataclass(kw_only=True)
class InspectorConfig:
    """Configuration consumed by the NeMo Inspector app."""

    # Model name -> whitespace-separated files / directories / glob patterns.
    # __post_init__ replaces each value with the resolved list of file paths.
    model_prediction: Dict[str, str] = field(default_factory=dict)
    save_generations_path: str = "nemo_inspector/results/saved_generations"
    # Copy the module-level constant so no instance can mutate it.
    code_tags: Dict[str, str] = field(default_factory=lambda: dict(CODE_SEPARATORS))

    def __post_init__(self):
        # str.split() without an argument collapses runs of whitespace.
        # split(" ") yields "" for double spaces; Path("") is ".", which would
        # pull in every *.jsonl file under the working directory.
        self.model_prediction = {
            model_name: unroll_files(file_path.split())
            for model_name, file_path in self.model_prediction.items()
        }

# --------------------------------------------------------------------------------
# /nemo_inspector/inspector_app.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import dataclasses
import signal
import sys
from pathlib import Path

# Make the repository root importable *before* the first ``nemo_inspector``
# import; appending it afterwards could not help that import.
sys.path.append(str(Path(__file__).parents[1]))

from nemo_inspector.parse_agruments_helpers import (  # noqa: E402
    add_arguments_from_dataclass,
    args_postproccessing,
    create_dataclass_from_args,
)
from nemo_inspector.layouts import get_main_page_layout  # noqa: E402
from nemo_inspector.settings.inspector_config import InspectorConfig  # noqa: E402
from nemo_inspector.settings.constants.configurations import (  # noqa: E402
    CODE_SEPARATORS,
)


def main():
    """Parse CLI arguments into an ``InspectorConfig`` and serve the Dash app.

    The resolved config is stored under ``app.server.config["nemo_inspector"]``
    and the app listens on localhost:8080.
    """
    # SIGALRM exists only on POSIX platforms; ignore it where available.
    if hasattr(signal, "SIGALRM"):
        signal.signal(signal.SIGALRM, signal.SIG_IGN)

    parser = argparse.ArgumentParser(description="NeMo Inspector")

    add_arguments_from_dataclass(
        parser,
        InspectorConfig,
        enforce_required=False,
        use_type_defaults=True,
    )

    args = parser.parse_args()
    args_dict = vars(args)

    cfg = dataclasses.asdict(create_dataclass_from_args(InspectorConfig, args_dict))
    cfg.setdefault("code_tags", {})
    # User-supplied tags override the built-in separators.
    cfg["code_tags"] = {**CODE_SEPARATORS, **cfg["code_tags"]}
    cfg = args_postproccessing(cfg)
    # Imported late; presumably so the callback modules are loaded only once
    # the config is ready - verify before moving this import.
    from nemo_inspector.callbacks import app

    app.server.config.update({"nemo_inspector": cfg})
    app.title = "NeMo Inspector"
    app.layout = get_main_page_layout()
    app.run(
        host="localhost",
        port="8080",
    )


if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /nemo_inspector/callbacks/analyze_page/update_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

from dash import ALL, html, no_update
from dash.dependencies import Input, Output, State

from nemo_inspector.callbacks import app
from nemo_inspector.layouts import get_updated_tables_layout
from nemo_inspector.settings.constants import CHOOSE_GENERATION


@app.callback(
    [
        Output("update_dataset_modal", "is_open", allow_duplicate=True),
        Output("js_container", "children", allow_duplicate=True),
        Output("js_trigger", "children", allow_duplicate=True),
    ],
    [
        Input("update_dataset_button", "n_clicks"),
        Input("apply_update_dataset_button", "n_clicks"),
    ],
    [State("update_dataset_modal", "is_open"), State("js_trigger", "children")],
    prevent_initial_call=True,
)
def open_update_dataset_modal(
    n1: int, n2: int, is_open: bool, js_trigger: str
) -> Tuple[bool, str, str]:
    """Toggle the "update dataset" modal when either of its buttons is clicked.

    Returns:
        The new modal state, an empty ``js_container`` and ``js_trigger`` with
        one space appended (a changed value that retriggers ``js_trigger``).
    """
    if n1 or n2:
        is_open = not is_open
    return is_open, "", js_trigger + " "


@app.callback(
    [
        Output("compare_models_rows", "children", allow_duplicate=True),
        Output("loading_container", "children", allow_duplicate=True),
    ],
    Input("apply_update_dataset_button", "n_clicks"),
    [
        State("update_dataset_input", "value"),
        State({"type": "model_selector", "id": ALL}, "value"),
        State("base_model_answers_selector", "value"),
        State("loading_container", "children"),
    ],
    prevent_initial_call=True,
)
def update_dataset(
    n_ckicks: int,
    update_function: str,
    models: List[str],
    base_model: str,
    loading_container: str,
) -> Tuple[List[html.Tr], str]:
    """Rebuild the comparison tables using the user-supplied update function.

    Nothing changes until a base model is chosen and the function text is
    non-empty. ``loading_container`` gets a space appended so its value
    changes on every run.
    """
    if base_model == CHOOSE_GENERATION or not update_function:
        return no_update, no_update
    return (
        get_updated_tables_layout(
            base_model=base_model, update_function=update_function, models=models
        ),
        loading_container + " ",
    )

# --------------------------------------------------------------------------------
# /nemo_inspector/tests/test_ping.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import socket
import subprocess
import sys
import time
from pathlib import Path

import pytest
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
from webdriver_manager.chrome import ChromeDriverManager
from webdriver_manager.core.os_manager import ChromeType

project_root = str(Path(__file__).parents[2])

# Remove this test directory from sys.path. Guarded: the directory is only
# present under some pytest import modes, and list.remove raises ValueError
# when the entry is missing.
_tests_dir = str(Path(__file__).parents[0])
if _tests_dir in sys.path:
    sys.path.remove(_tests_dir)

# Address the app is served on (inspector_app.main runs localhost:8080).
_HOST = "localhost"
_PORT = 8080
_STARTUP_TIMEOUT_S = 60
_SHUTDOWN_TIMEOUT_S = 10


def _wait_until_listening(process, host, port, timeout, log_path):
    """Block until ``host:port`` accepts TCP connections.

    Raises:
        RuntimeError: the server process exited before it started listening.
        TimeoutError: the port did not open within ``timeout`` seconds.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if process.poll() is not None:
            raise RuntimeError(
                f"NeMo Inspector exited early with code {process.returncode}; "
                f"see {log_path}"
            )
        try:
            with socket.create_connection((host, port), timeout=1):
                return
        except OSError:
            time.sleep(0.5)
    raise TimeoutError(
        f"NeMo Inspector did not listen on {host}:{port} within {timeout}s; "
        f"see {log_path}"
    )


@pytest.fixture(scope="module")
def nemo_inspector_process(tmp_path_factory):
    """Run NeMo Inspector in a subprocess for the duration of the module."""
    # Log to a file: PIPEs that are never read can fill up and block the server.
    log_path = tmp_path_factory.mktemp("nemo_inspector") / "server.log"
    with open(log_path, "w") as log_file:
        # sys.executable keeps the server on the same interpreter/venv as pytest.
        process = subprocess.Popen(
            [sys.executable, "nemo_inspector"],
            cwd=project_root,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            text=True,
        )
        try:
            # Don't hand the process to tests before the server accepts requests.
            _wait_until_listening(
                process, _HOST, _PORT, _STARTUP_TIMEOUT_S, log_path
            )
            yield process
        finally:
            # Terminate the process after the tests; escalate if it hangs.
            process.terminate()
            try:
                process.wait(timeout=_SHUTDOWN_TIMEOUT_S)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()


@pytest.fixture
def chrome_driver():
    """Yield a headless Chrome WebDriver and quit it afterwards."""
    chrome_driver_path = ChromeDriverManager(chrome_type=ChromeType.GOOGLE).install()
    options = Options()
    options.page_load_strategy = "normal"
    options.add_argument("--headless")
    options.add_argument("--disable-gpu")
    options.add_argument("--no-sandbox")
    options.add_argument("--disable-dev-shm-usage")

    service = Service(chrome_driver_path)
    driver = webdriver.Chrome(service=service, options=options)
    os.environ["PATH"] += os.pathsep + "/".join(chrome_driver_path.split("/")[:-1])
    yield driver
    driver.quit()


@pytest.mark.parametrize(
    ("element_id", "url"),
    [("add_model", "/analyze")],
)
def test_dash_app_launch(chrome_driver, nemo_inspector_process, element_id, url):
    """The page at ``url`` loads and shows the element ``element_id``."""
    full_url = f"http://{_HOST}:{_PORT}{url}"

    chrome_driver.get(full_url)

    element = WebDriverWait(chrome_driver, 10).until(
        EC.presence_of_element_located((By.ID, element_id))
    )
    assert element.is_displayed()

# --------------------------------------------------------------------------------
# /nemo_inspector/assets/styles/ansi_styles.css:
# --------------------------------------------------------------------------------
# /* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/ 14 | 15 | 16 | /* Foreground Colors */ 17 | .ansi30 { color: black; } /* Black */ 18 | .ansi31 { color: red ; } /* Red */ 19 | .ansi32 { color: green; } /* Green */ 20 | .ansi33 { color: yellow; } /* Yellow */ 21 | .ansi34 { color: blue; } /* Blue */ 22 | .ansi35 { color: magenta; } /* Magenta */ 23 | .ansi36 { color: cyan; } /* Cyan */ 24 | .ansi37 { color: white; } /* White */ 25 | .ansi90 { color: grey; } /* Bright Black (grey) */ 26 | .ansi91 { color: #FFCCCB; } /* Bright Red */ 27 | .ansi92 { color: lightgreen; } /* Bright Green */ 28 | .ansi93 { color: lightyellow; } /* Bright Yellow */ 29 | .ansi94 { color: lightblue; } /* Bright Blue */ 30 | .ansi95 { color: #ff80ff;} /* Bright Magenta */ 31 | .ansi96 { color: lightcyan; } /* Bright Cyan */ 32 | .ansi97 { color: #FFFFF7; } /* Bright White */ 33 | 34 | /* Background Colors */ 35 | .ansi40 { background-color: black; } /* Black */ 36 | .ansi41 { background-color: red; } /* Red */ 37 | .ansi42 { background-color: green; } /* Green */ 38 | .ansi43 { background-color: yellow; } /* Yellow */ 39 | .ansi44 { background-color: blue; } /* Blue */ 40 | .ansi45 { background-color: magenta; } /* Magenta */ 41 | .ansi46 { background-color: cyan; } /* Cyan */ 42 | .ansi47 { background-color: white; } /* White */ 43 | .ansi100 { background-color: grey; } /* Bright Black (grey) */ 44 | .ansi101 { background-color: #FFCCCB; } /* Bright Red */ 45 | .ansi102 { background-color: lightgreen; } /* Bright Green */ 46 | .ansi103 { background-color: lightyellow;} /* Bright Yellow */ 47 | .ansi104 { background-color: lightblue; } /* Bright Blue */ 48 | .ansi105 { background-color: #ff80ff;}/* Bright Magenta */ 49 | .ansi106 { background-color: lightcyan; } /* Bright Cyan */ 50 | .ansi107 { background-color: #FFFFF7; } /* Bright White */ 51 | 52 | /* Styles */ 53 | .ansi1 { font-weight: bold; } /* Bold */ 54 | .ansi3 { font-style: italic; } /* Italic */ 55 | .ansi4 { text-decoration: underline; } /* Underline */ 56 | .ansi9 { 
text-decoration: line-through; } /* Strikethrough */ 57 | .ansi24 { text-decoration: none; } /* No underline */ 58 | .ansi39 { color: initial; } /* Default foreground color */ 59 | .ansi49 { background-color: initial; } /* Default background color */ 60 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/analyze_page/short_info_table.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from typing import Dict, List, Tuple

from dash import no_update
from dash.dependencies import Input, Output, State

from nemo_inspector.callbacks import app
from nemo_inspector.layouts import (
    get_tables_layout,
    get_stats_input,
)
from nemo_inspector.settings.constants import CHOOSE_GENERATION
from nemo_inspector.utils.common import get_excluded_row, get_table_data


@app.callback(
    [
        Output("compare_models_rows", "children", allow_duplicate=True),
        Output("loading_container", "children", allow_duplicate=True),
    ],
    Input("base_model_answers_selector", "value"),
    State("loading_container", "children"),
    prevent_initial_call=True,
)
def choose_base_model(
    base_model: str,
    loading_container: str,
) -> Tuple[List, str]:
    """Rebuild the comparison tables when a new base model is selected.

    Clears the excluded-row collection first. ``loading_container`` gets a
    space appended so its value changes on every call.
    """
    if base_model == CHOOSE_GENERATION:
        return no_update, no_update
    get_excluded_row().clear()
    return (
        get_tables_layout(
            base_model=base_model,
        ),
        loading_container + " ",
    )


@app.callback(
    Output("datatable", "data"),
    [
        Input("datatable", "page_current"),
        Input("datatable", "page_size"),
    ],
    State("base_model_answers_selector", "value"),
)
def change_page(page_current: int, page_size: int, base_model: str) -> List[Dict]:
    """Return the base model's first entry for every row on the current page.

    Rows that have no data for ``base_model`` are skipped.
    """
    table_data = get_table_data()
    if not table_data:
        return no_update
    start = page_current * page_size
    return [
        data[base_model][0]
        for data in table_data[start : start + page_size]
        if base_model in data
    ]


@app.callback(
    [
        Output("stats_input_container", "children", allow_duplicate=True),
        Output("js_container", "children", allow_duplicate=True),
        Output("js_trigger", "children", allow_duplicate=True),
    ],
    Input("stats_modes", "value"),
    State("js_trigger", "children"),
    prevent_initial_call=True,
)
def change_stats_mode(modes: List[str], js_trigger: str) -> Tuple:
    """Re-render the stats input area for the selected stats modes."""
    if modes is None:
        return no_update, no_update, no_update
    return get_stats_input(modes), "", js_trigger + " "

# --------------------------------------------------------------------------------
# /nemo_inspector/utils/decoration/plain_text.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from difflib import SequenceMatcher
import re

# Tokens in order of appearance:
#   (\s+)      => one or more whitespace chars
#   (\w+)      => one or more word chars
#   ([^\w\s]+) => one or more chars that are neither (punctuation)
# The three classes cover every character, so "".join(tokens) == text.
_TOKEN_PATTERN = re.compile(r"(\s+|\w+|[^\w\s]+)")

# CSS applied to tokens found only in text1 / only in text2.
_TEXT1_ONLY_STYLE = {"background-color": "#c8e6c9"}
_TEXT2_ONLY_STYLE = {
    "background-color": "#ffcdd2",
    "text-decoration": "line-through",
}


def tokenize(text: str):
    r"""
    Tokenize the text into separate tokens for:
    - Whitespace sequences (\s+)
    - Word sequences (\w+)
    - Punctuation sequences ([^\w\s]+)

    The regex ensures we capture all text in order.
    """
    return _TOKEN_PATTERN.findall(text)


def color_text_diff(text1: str, text2: str):
    """Return ``(token, css_style)`` pairs describing a token diff of two texts.

    Equal texts give a single unstyled chunk. Otherwise:

    * tokens common to both are unstyled;
    * tokens only in ``text1`` get a green background;
    * tokens only in ``text2`` get a red background with a strike-through.

    For replaced spans the ``text1`` tokens come before the ``text2`` tokens.
    Each style dict is a fresh copy, so callers may mutate it.
    """
    if text1 == text2:
        return [(text1, {})]

    tokens1 = tokenize(text1)
    tokens2 = tokenize(text2)

    result = []
    for tag, i1, i2, j1, j2 in SequenceMatcher(None, tokens1, tokens2).get_opcodes():
        if tag == "equal":
            result.extend((token, {}) for token in tokens2[j1:j2])
            continue
        # "replace" and "delete" spans carry tokens present only in text1...
        if tag in ("replace", "delete"):
            result.extend(
                (token, dict(_TEXT1_ONLY_STYLE)) for token in tokens1[i1:i2]
            )
        # ...and "replace" and "insert" spans carry tokens present only in text2.
        if tag in ("replace", "insert"):
            result.extend(
                (token, dict(_TEXT2_ONLY_STYLE)) for token in tokens2[j1:j2]
            )

    return result

# --------------------------------------------------------------------------------
# /nemo_inspector/callbacks/analyze_page/save_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from typing import List, Tuple

from dash import callback_context, html, no_update
from dash.dependencies import Input, Output, State

from nemo_inspector.callbacks import app
from nemo_inspector.settings.constants import (
    EXTRA_FIELDS,
    FILE_NAME,
)
from nemo_inspector.settings.constants.paths import PATH_TO_THE_REPOSITORY
from nemo_inspector.utils.common import get_table_data


@app.callback(
    Output("save_dataset_modal", "is_open", allow_duplicate=True),
    Input("save_dataset", "n_clicks"),
    prevent_initial_call=True,
)
def open_save_dataset_modal(n1: int) -> bool:
    """Open the save-dataset modal when the save button is clicked."""
    ctx = callback_context
    if not ctx.triggered:
        return no_update

    return True


@app.callback(
    [
        Output("save_dataset_modal", "is_open", allow_duplicate=True),
        Output("error_message", "children"),
    ],
    Input("save_dataset_button", "n_clicks"),
    [
        State("base_model_answers_selector", "value"),
        State("save_path", "value"),
    ],
    prevent_initial_call=True,
)
def save_dataset(n_click: int, base_model: str, save_path: str) -> Tuple[bool, object]:
    """Write the base model's table data to ``<save_path>/<file name>.jsonl``.

    Rows are grouped by their ``FILE_NAME`` field, and ``EXTRA_FIELDS`` keys
    are dropped. A ``save_path`` starting with ``nemo_inspector`` is resolved
    against the repository root.

    Returns:
        ``(False, "")`` on success, which closes the modal. On failure, returns
        ``(True, html.Pre(...))``: the modal stays open and shows an error.
    """
    if not n_click or not save_path or not base_model:
        return no_update, no_update
    if save_path.startswith("nemo_inspector"):
        save_path = os.path.join(PATH_TO_THE_REPOSITORY, save_path)

    def _error():
        return True, html.Pre(f"could not save generations by path {save_path}")

    try:
        # makedirs creates missing parents, e.g. ".../results/saved_generations".
        os.makedirs(save_path, exist_ok=True)
    except OSError:
        return _error()

    new_data = {}
    for data in get_table_data():
        for file_data in data[base_model]:
            new_data.setdefault(file_data[FILE_NAME], []).append(
                {
                    key: value
                    for key, value in file_data.items()
                    if key not in EXTRA_FIELDS
                }
            )

    try:
        for file_name, rows in new_data.items():
            target = os.path.join(save_path, file_name + ".jsonl")
            with open(target, "w", encoding="utf-8") as file:
                file.write("\n".join(json.dumps(row) for row in rows))
    except OSError:
        # Report write failures in the modal instead of crashing the callback.
        return _error()

    return False, ""

# --------------------------------------------------------------------------------
# /nemo_inspector/callbacks/analyze_page/sort_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import List, Tuple

from dash import ALL, callback_context, html, no_update
from dash.dependencies import Input, Output, State

from nemo_inspector.callbacks import app
from nemo_inspector.layouts import get_sorted_tables_layout
from nemo_inspector.settings.constants import CHOOSE_GENERATION


@app.callback(
    [
        Output({"type": "sorting", "id": ALL}, "is_open"),
        Output("js_container", "children", allow_duplicate=True),
        Output("js_trigger", "children", allow_duplicate=True),
    ],
    [
        Input({"type": "set_sorting_button", "id": ALL}, "n_clicks"),
        Input({"type": "apply_sorting_button", "id": ALL}, "n_clicks"),
    ],
    [
        State({"type": "sorting", "id": ALL}, "is_open"),
        State("js_trigger", "children"),
    ],
    prevent_initial_call=True,
)
def toggle_modal_sorting(
    n1: List[int], n2: List[int], is_open: List[bool], js_trigger: str
) -> Tuple[List, str, str]:
    """Toggle the sorting modal whose set/apply button was clicked.

    ``n1``, ``n2`` and ``is_open`` are the ALL-pattern lists of the set
    buttons, apply buttons and modals. They are indexed by the clicked
    button's pattern id plus one.
    """
    ctx = callback_context
    if not ctx.triggered:
        return [no_update] * len(is_open), no_update, no_update

    # Read the id and the click value from the same trigger entry so they
    # always describe the same button.
    trigger = ctx.triggered[-1]
    if not trigger["value"]:
        return [no_update] * len(is_open), no_update, no_update

    # prop_id is "<json component id>.<property>"; split off only the property.
    component_id = json.loads(trigger["prop_id"].rsplit(".", 1)[0])
    # Ids start at -1 (sorting_data below listens on id -1); shift to a list index.
    button_id = component_id["id"] + 1

    if n1[button_id] or n2[button_id]:
        is_open[button_id] = not is_open[button_id]
    return is_open, "", js_trigger + " "


@app.callback(
    [
        Output("compare_models_rows", "children", allow_duplicate=True),
        Output("sorting_container", "children"),
        Output("loading_container", "children", allow_duplicate=True),
    ],
    Input({"type": "apply_sorting_button", "id": -1}, "n_clicks"),
    [
        State({"type": "sorting_function_input", "id": -1}, "value"),
        State({"type": "model_selector", "id": ALL}, "value"),
        State("base_model_answers_selector", "value"),
        State("loading_container", "children"),
    ],
    prevent_initial_call=True,
)
def sorting_data(
    n_ckicks: int,
    sorting_function: str,
    models: List[str],
    base_model: str,
    loading_container: str,
) -> Tuple[List[html.Tr], html.Pre, str]:
    """Re-sort the comparison tables with the user-supplied sorting function.

    Also shows the applied function text in ``sorting_container``.
    """
    if base_model == CHOOSE_GENERATION or not sorting_function:
        return no_update, no_update, no_update
    return (
        get_sorted_tables_layout(
            base_model=base_model,
            sorting_function=sorting_function,
            models=models,
        ),
        html.Pre(f"Sorting function:\n{sorting_function}"),
        loading_container + " ",
    )

# --------------------------------------------------------------------------------
# /nemo_inspector/utils/decoration/code.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | 15 | from html import escape 16 | from io import StringIO 17 | from typing import Dict, List, Tuple 18 | 19 | from dash import html 20 | from pygments.formatters import HtmlFormatter 21 | from pygments.lexers import PythonLexer 22 | 23 | from nemo_inspector.utils.decoration.common import ( 24 | get_random_id, 25 | iframe_template, 26 | update_height_js, 27 | ) 28 | 29 | 30 | def highlight_code(codes: List[Tuple[str, Dict[str, str]]], **kwargs) -> html.Iframe: 31 | 32 | full_code = "".join([code for code, style in codes]) 33 | 34 | # Track positions and styles 35 | positions = [] 36 | current_pos = 0 37 | for code, style in codes: 38 | start_pos = current_pos 39 | end_pos = current_pos + len(code) 40 | if style: 41 | positions.append((start_pos, end_pos, style)) 42 | current_pos = end_pos 43 | 44 | # Custom formatter to apply styles at correct positions 45 | class CustomHtmlFormatter(HtmlFormatter): 46 | def __init__(self, positions, **options): 47 | super().__init__(**options) 48 | self.positions = positions 49 | self.current_pos = 0 50 | 51 | def format(self, tokensource, outfile): 52 | style_starts = {start: style for start, _, style in self.positions} 53 | style_ends = {end: style for _, end, style in self.positions} 54 | active_styles = [] 55 | 56 | for ttype, value in tokensource: 57 | token_length = len(value) 58 | token_start = self.current_pos 59 | 60 | # Apply styles character by character 61 | result = "" 62 | for i, char in enumerate(value): 63 | char_pos = token_start + i 64 | 65 | # Check if a style starts or ends here 66 | if char_pos in style_starts: 67 | style = style_starts[char_pos] 68 | active_styles.append(style) 69 | if char_pos in style_ends: 70 | style = style_ends[char_pos] 71 | if style in active_styles: 72 | active_styles.remove(style) 73 | 74 | # Get CSS class for syntax highlighting 75 | css_class = self._get_css_class(ttype) 76 | char_html = escape(char) 77 | if css_class: 78 | char_html = f'{char_html}' 79 | 80 | # Apply active 
styles 81 | if active_styles: 82 | combined_style = {} 83 | for style_dict in active_styles: 84 | combined_style.update(style_dict) 85 | style_str = "; ".join( 86 | f"{k}: {v}" for k, v in combined_style.items() 87 | ) 88 | char_html = f'{char_html}' 89 | 90 | result += char_html 91 | 92 | outfile.write(result) 93 | self.current_pos += token_length 94 | 95 | # Use the custom formatter to highlight the code 96 | lexer = PythonLexer() 97 | formatter = CustomHtmlFormatter(positions, nowrap=True) 98 | style_defs = formatter.get_style_defs(".highlight") 99 | style_defs += """ 100 | .highlight { 101 | font-family: monospace; 102 | } 103 | """ 104 | 105 | output = StringIO() 106 | formatter.format(lexer.get_tokens(full_code), output) 107 | highlighted_code = output.getvalue() 108 | 109 | # Build the iframe content 110 | iframe_id = get_random_id() 111 | content = f""" 112 |
{full_text}",
46 | ),
47 | style=style,
48 | )
49 | return html.Div(
50 | (
51 | dcc.Markdown(
52 | preprocess_latex(full_text, escape=MARKDOWN not in text_modes),
53 | mathjax=True,
54 | dangerously_allow_html=True,
55 | )
56 | if LATEX in text_modes and COMPARE not in text_modes
57 | else (
58 | dcc.Markdown(full_text)
59 | if MARKDOWN in text_modes and COMPARE not in text_modes
60 | else [
61 | html.Span(text, style={**inner_style, "whiteSpace": "pre-wrap"})
62 | for text, inner_style in texts
63 | ]
64 | )
65 | ),
66 | style=style,
67 | ) # TODO make classes
68 |
69 |
def get_height_adjustment() -> html.Iframe:
    """Return a hidden ``html.Iframe`` with id ``query_params_iframe``.

    The frame is made invisible via ``visibility: hidden``; its document
    comes from the ``srcDoc`` literal below.
    """
    return html.Iframe(
        id="query_params_iframe",
        # NOTE(review): in this copy srcDoc holds only blank lines; the markup
        # (presumably a height-adjustment script, going by the function name)
        # looks stripped - verify against upstream.
        srcDoc="""








""",
        style={"visibility": "hidden"},
    )
89 |
90 |
def update_height_js(iframe_id: str) -> str:
    """Build the JavaScript snippet that reports the document height.

    The returned code defines ``updateHeight``, which computes the largest of
    the body/html scroll, offset and client heights. It sends
    ``{frameHeight, frameId}`` to ``parent`` via ``postMessage`` and runs on
    ``window.onload`` and ``window.onresize``. ``iframe_id`` is embedded as
    the ``frameId`` so the parent can tell frames apart.
    """
    script = f"""
    function updateHeight() {{
        var body = document.body,
            html = document.documentElement;

        var height = Math.max(body.scrollHeight, body.offsetHeight,
                              html.clientHeight, html.scrollHeight, html.offsetHeight);

        parent.postMessage({{ frameHeight: height, frameId: '{iframe_id}' }}, '*');
    }}
    window.onload = updateHeight;
    window.onresize = updateHeight;
    """
    return script
105 |
106 |
def get_random_id() -> str:
    """Return a random 20-character identifier of ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    picked = random.choices(alphabet, k=20)
    return "".join(picked)
109 |
110 |
def iframe_template(
    header: str, content: str, style: Dict | None = None, iframe_id: str | None = None
) -> html.Iframe:
    """Wrap ``header`` and ``content`` into an ``html.Iframe`` document.

    Args:
        header: Text inserted before ``content`` in ``srcDoc`` (presumably
            head markup such as styles/scripts - confirm with callers).
        content: Text inserted as the document body in ``srcDoc``.
        style: Extra CSS properties merged over the defaults (full width,
            no border, hidden overflow). ``None`` applies only the defaults.
        iframe_id: Component id; a random id is generated when falsy.

    Returns:
        The configured ``html.Iframe``.
    """
    if not iframe_id:
        iframe_id = get_random_id()

    # Build a fresh dict on every call (no shared mutable default);
    # caller-supplied keys override the defaults.
    iframe_style = {
        "width": "100%",
        "border": "none",
        "overflow": "hidden",
        **(style or {}),
    }

    # NOTE(review): srcDoc here has only whitespace around {header}/{content};
    # if this copy lost its HTML tags, restore them from upstream.
    return html.Iframe(
        id=iframe_id,
        srcDoc=f"""
    
    
    
        {header}
    
    
        {content}
    
    
        """,
        style=iframe_style,
    )
140 |
--------------------------------------------------------------------------------
/nemo_inspector/callbacks/analyze_page/count_stats_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | from typing import List
17 |
18 | from dash import no_update
19 | from dash.dependencies import Input, Output, State
20 |
21 | from nemo_inspector.callbacks import app
22 | from nemo_inspector.layouts import (
23 | get_tables_layout,
24 | get_stats_input,
25 | )
26 | from nemo_inspector.settings.constants import (
27 | CHOOSE_GENERATION,
28 | DELETE,
29 | ERROR_MESSAGE_TEMPLATE,
30 | GENERAL_STATS,
31 | INLINE_STATS,
32 | )
33 | from nemo_inspector.utils.common import (
34 | calculate_metrics_for_whole_data,
35 | get_custom_stats,
36 | get_deleted_stats,
37 | get_general_custom_stats,
38 | get_stats_raw,
39 | get_table_data,
40 | )
41 |
42 |
@app.callback(
    [
        Output("new_stats", "is_open"),
        Output("stats_input_container", "children", allow_duplicate=True),
        Output("js_container", "children", allow_duplicate=True),
        Output("js_trigger", "children", allow_duplicate=True),
    ],
    [
        Input("set_new_stats_button", "n_clicks"),
        Input("apply_new_stats", "n_clicks"),
    ],
    [
        State("new_stats", "is_open"),
        State("stats_modes", "value"),
        State("js_trigger", "children"),
    ],
    prevent_initial_call=True,
)
def toggle_modal_stats(
    n1: int, n2: int, is_open: bool, modes: List[str], js_trigger: str
) -> tuple:
    """Open or close the "new statistics" modal and re-render its code input.

    Args:
        n1: Click count of ``set_new_stats_button``.
        n2: Click count of ``apply_new_stats``.
        is_open: Current open state of the ``new_stats`` modal.
        modes: Selected values of the ``stats_modes`` switch, forwarded to
            ``get_stats_input``.
        js_trigger: Current text of the ``js_trigger`` element.

    Returns:
        ``(is_open, stats input layout, js_container children, js_trigger)``.
        When neither button has been clicked yet, all four are ``no_update``.
    """
    if not n1 and not n2:
        return no_update, no_update, no_update, no_update

    # At least one button was clicked, so flip the modal. js_container is
    # cleared and a space is appended to js_trigger. NOTE(review): this
    # presumably forces the client-side JS hook to re-run; confirm.
    return not is_open, get_stats_input(modes), "", js_trigger + " "
71 |
72 |
@app.callback(
    Output("compare_models_rows", "children", allow_duplicate=True),
    Input("apply_new_stats", "n_clicks"),
    [
        State("stats_input", "value"),
        State("base_model_answers_selector", "value"),
        State("stats_modes", "value"),
    ],
    prevent_initial_call=True,
)
def apply_new_stat(
    n_click: int,
    code_raw: str,
    base_model: str,
    stats_modes: List[str],
) -> List:
    """Add or delete custom statistics from the editor code, then redraw tables.

    Add mode (``DELETE`` not selected): every line but the last is executed
    as setup code. The last line must evaluate to a dict mapping statistic
    names to functions. That dict is merged into the general or the inline
    (per-question) custom statistics, depending on ``GENERAL_STATS``. The raw
    code is remembered in ``get_stats_raw()`` under the space-joined names.

    Delete mode (``DELETE`` selected): the last line is taken verbatim as the
    name of the statistic to remove.

    Args:
        n_click: Click count of ``apply_new_stats``.
        code_raw: Text of the ``stats_input`` editor.
        base_model: Selected base generation.
        stats_modes: Selected values of the ``stats_modes`` switch (may be None).

    Returns:
        The new table layout. Returns ``[]`` when no generation is chosen, or
        ``no_update`` on empty input or failing user code.
    """
    if not n_click or not code_raw or not code_raw.strip():
        return no_update
    stats_modes = stats_modes or []
    is_general = GENERAL_STATS in stats_modes
    code_raw_lines = code_raw.strip().split("\n")

    if DELETE in stats_modes:
        # Use the name as-is. Splicing it into exec() code broke on quotes.
        stat_name = code_raw_lines[-1]
        if is_general:
            get_general_custom_stats().pop(stat_name, None)
        else:
            get_custom_stats().pop(stat_name, None)
            # NOTE(review): assumes get_deleted_stats() returns a set. The
            # previous `.update(stat_name)` stored the name's characters.
            get_deleted_stats().add(stat_name)
    else:
        code = "\n".join(code_raw_lines[:-1]) + "\nnew_stats = " + code_raw_lines[-1]
        namespace = {}
        try:
            # Runs user-written code from the UI editor; only acceptable
            # because the inspector executes the local user's own code.
            exec(code, namespace)
        except Exception as e:
            logging.error(ERROR_MESSAGE_TEMPLATE.format(code, str(e)))
            return no_update
        new_stats = namespace["new_stats"]
        if not isinstance(new_stats, dict):
            logging.error(
                ERROR_MESSAGE_TEMPLATE.format(
                    code, "the last line must evaluate to a dict of statistics"
                )
            )
            return no_update
        target_stats = get_general_custom_stats() if is_general else get_custom_stats()
        target_stats.update(new_stats)
        raw_mode = GENERAL_STATS if is_general else INLINE_STATS
        get_stats_raw()[raw_mode][" ".join(new_stats.keys())] = code_raw

    if base_model == CHOOSE_GENERATION:
        return []
    calculate_metrics_for_whole_data(get_table_data(), base_model)
    return get_tables_layout(base_model=base_model)
123 |
124 |
@app.callback(
    [
        Output("stats_input", "value", allow_duplicate=True),
        Output("js_container", "children", allow_duplicate=True),
        Output("js_trigger", "children", allow_duplicate=True),
    ],
    Input("stats_extractor", "value"),
    [
        State("stats_modes", "value"),
        State("js_trigger", "children"),
    ],
    prevent_initial_call=True,
)
def apply_new_stat(stat: str, stats_modes: List[str], js_trigger: str) -> tuple:
    """Load the stored source code of a previously defined statistic into the editor.

    NOTE(review): this function shares its name with the "apply new stats"
    callback above. Both are registered with Dash at decoration time, but
    only this one is reachable through the module attribute.

    Args:
        stat: Key selected in ``stats_extractor``.
        stats_modes: Selected values of the ``stats_modes`` switch (may be None).
        js_trigger: Current text of the ``js_trigger`` element.

    Returns:
        ``(editor code, js_container children, js_trigger)``. All three are
        ``no_update`` when no code is stored for ``stat``.
    """
    mode = (
        GENERAL_STATS if stats_modes and GENERAL_STATS in stats_modes else INLINE_STATS
    )
    try:
        stats_code = get_stats_raw()[mode][stat]
    except KeyError:
        return no_update, no_update, no_update
    return stats_code, " ", js_trigger + " "
141 |
--------------------------------------------------------------------------------
/nemo_inspector/parse_agruments_helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import dataclasses
17 | from typing import Any, Optional, Type
18 |
19 | from nemo_inspector.utils.common import get_type_default, resolve_union_or_any
20 |
21 |
class ParseDict(argparse.Action):
    """argparse action that collects ``KEY=VALUE`` tokens into a dict.

    Each token is split on its first ``=`` only, so values may themselves
    contain ``=``. A token without ``=`` is reported via ``parser.error``.
    Later duplicates of a key overwrite earlier ones.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        parsed = {}
        for item in values or []:
            key, separator, value = item.partition("=")
            if not separator:
                parser.error(
                    f"{option_string or self.dest}: expected KEY=VALUE, got {item!r}"
                )
            parsed[key] = value
        setattr(namespace, self.dest, parsed)
28 |
29 |
def add_arguments_from_dataclass(
    parser: argparse.ArgumentParser,
    dataclass_type: Type[Any],
    prefix: str = "",
    use_default: Optional[str] = None,
    enforce_required: bool = True,
    use_type_defaults: bool = True,
):
    """Register one ``--<prefix><field>`` option per field of ``dataclass_type``.

    Nested dataclass fields are flattened recursively with a dotted prefix
    (``--outer.inner``).

    Option kinds:
        * ``bool`` fields become ``store_true`` flags.
        * ``dict`` fields take ``KEY=VALUE`` tokens parsed by ``ParseDict``.
        * All other fields use the field type as the converter.

    Args:
        parser: Parser to add options to.
        dataclass_type: Dataclass whose fields define the options.
        prefix: Prefix prepended to every field name.
        use_default: If truthy, used as the default for every option.
        enforce_required: Allow options to be marked required at all.
        use_type_defaults: For fields without a plain ``default``, use
            ``get_type_default(field_type)`` as the argparse default.
    """
    for field in dataclasses.fields(dataclass_type):
        field_name = f"{prefix}{field.name}"
        field_type = field.type

        if dataclasses.is_dataclass(field_type):
            add_arguments_from_dataclass(
                parser,
                field_type,
                prefix=f"{field_name}.",
                use_default=use_default,
                enforce_required=enforce_required,
                use_type_defaults=use_type_defaults,
            )
            continue

        # Optional/Union/Any annotations are mapped to a usable type by the
        # project helper resolve_union_or_any.
        field_type = resolve_union_or_any(field_type)

        # Compare with the MISSING sentinel by identity: `!=` would invoke the
        # default object's __ne__, which can misbehave for exotic types.
        has_default = field.default is not dataclasses.MISSING
        has_default_factory = field.default_factory is not dataclasses.MISSING

        # NOTE(review): this is truthy only when `use_default` is set, which
        # looks inverted for "required when there is no default". It is kept
        # as-is so the set of mandatory CLI options does not change.
        is_required = not has_default and not has_default_factory and use_default

        default_message = f"(default: {str(field.default)})" if has_default else ""
        arg_kwargs: dict = {"help": f"{field_name} flag {default_message}"}

        if field_type is bool:
            arg_kwargs["action"] = "store_true"
        elif field_type is dict:
            arg_kwargs["nargs"] = "*"
            arg_kwargs["action"] = ParseDict
        else:
            arg_kwargs["type"] = field_type

        # Default precedence: explicit `use_default`, then the type default for
        # fields without a plain default, then the field's own default.
        if use_default:
            arg_kwargs["default"] = use_default
        elif use_type_defaults and not has_default:
            arg_kwargs["default"] = get_type_default(field_type)
        elif has_default:
            arg_kwargs["default"] = field.default

        arg_kwargs["required"] = is_required and enforce_required
        parser.add_argument(f"--{field_name}", **arg_kwargs)
96 |
97 |
def create_dataclass_from_args(
    dataclass_type: Type[Any], args_dict: dict, prefix: str = ""
):
    """Instantiate ``dataclass_type`` from a flat dict of dotted argument names.

    Nested dataclass fields are built recursively from keys prefixed with
    ``"<field>."``. For a field missing from ``args_dict``, the value comes
    from the first available of:

    1. the field's ``default``;
    2. its ``default_factory()``;
    3. ``get_type_default(field.type)``.

    Args:
        dataclass_type: Dataclass to build.
        args_dict: Mapping such as ``vars(parser.parse_args())``.
        prefix: Dotted prefix of this (possibly nested) dataclass.

    Returns:
        A new ``dataclass_type`` instance.
    """
    init_kwargs = {}
    for field in dataclasses.fields(dataclass_type):
        field_name = f"{prefix}{field.name}"
        field_type = field.type

        if dataclasses.is_dataclass(field_type):
            init_kwargs[field.name] = create_dataclass_from_args(
                field_type,
                args_dict,
                prefix=f"{field_name}.",
            )
            continue

        if field_name in args_dict:
            value = args_dict[field_name]
        elif field.default is not dataclasses.MISSING:
            value = field.default
        elif field.default_factory is not dataclasses.MISSING:
            value = field.default_factory()
        else:
            value = get_type_default(field_type)
        init_kwargs[field.name] = value
    return dataclass_type(**init_kwargs)
127 |
128 |
def args_postproccessing(args):
    """Normalize parsed CLI arguments in place and return them.

    ``input_file`` is coerced to ``str``. When it is missing or ``None`` it
    becomes ``""`` instead of the literal path ``"None"``.

    Args:
        args: Mutable mapping of argument names to values.

    Returns:
        The same ``args`` mapping.
    """
    input_file = args.get("input_file")
    args["input_file"] = "" if input_file is None else str(input_file)

    return args
133 |
--------------------------------------------------------------------------------
/nemo_inspector/callbacks/analyze_page/filter_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import json
16 | from typing import List, Tuple
17 |
18 | from dash import ALL, callback_context, html, no_update
19 | from dash.dependencies import Input, Output, State
20 |
21 | from nemo_inspector.callbacks import app
22 | from nemo_inspector.layouts import (
23 | get_filtered_tables_layout,
24 | get_filter_text,
25 | get_sorted_tables_layout,
26 | )
27 | from nemo_inspector.settings.constants import (
28 | CHOOSE_GENERATION,
29 | FILES_FILTERING,
30 | QUESTIONS_FILTERING,
31 | )
32 | from nemo_inspector.utils.common import get_table_data
33 |
34 |
@app.callback(
    [
        Output({"type": "filter", "id": ALL}, "is_open"),
        Output("js_container", "children", allow_duplicate=True),
        Output("js_trigger", "children", allow_duplicate=True),
    ],
    [
        Input({"type": "set_filter_button", "id": ALL}, "n_clicks"),
        Input({"type": "apply_filter_button", "id": ALL}, "n_clicks"),
    ],
    [
        State({"type": "filter", "id": ALL}, "is_open"),
        State("js_trigger", "children"),
    ],
    prevent_initial_call=True,
)
def toggle_modal_filter(
    n1: List[int], n2: List[int], is_open: List[bool], js_trigger: str
) -> tuple:
    """Toggle the filter modal whose set/apply button was clicked.

    Args:
        n1: Click counts of all ``set_filter_button`` components.
        n2: Click counts of all ``apply_filter_button`` components.
        is_open: Open states of all ``filter`` modals.
        js_trigger: Current text of the ``js_trigger`` element.

    Returns:
        ``(is_open list, js_container children, js_trigger)``.
    """
    ctx = callback_context
    if not ctx.triggered:
        return [no_update] * len(is_open), no_update, no_update

    # Read both the component id and the value from the same trigger entry.
    trigger = ctx.triggered[-1]
    # prop_id has the form '<json id>.<property>'. Split on the last dot so
    # the JSON part stays intact.
    component_id = json.loads(trigger["prop_id"].rsplit(".", 1)[0])["id"]
    # Filter ids start at -1, hence the +1 shift. NOTE(review): this assumes
    # the ALL-pattern lists are ordered by id (-1, 0, 1, ...); confirm.
    button_index = component_id + 1

    if not trigger["value"]:
        return [no_update] * len(is_open), "", js_trigger + " "

    if n1[button_index] or n2[button_index]:
        is_open[button_index] = not is_open[button_index]
    return is_open, "", js_trigger + " "
63 |
64 |
@app.callback(
    [
        Output("compare_models_rows", "children", allow_duplicate=True),
        Output("filtering_container", "children"),
        Output("loading_container", "children", allow_duplicate=True),
    ],
    [
        Input({"type": "apply_filter_button", "id": -1}, "n_clicks"),
    ],
    [
        State({"type": "filter_function_input", "id": -1}, "value"),
        State({"type": "apply_on_filtered_data", "id": -1}, "value"),
        State({"type": "filter_mode", "id": -1}, "value"),
        State({"type": "sorting_function_input", "id": -1}, "value"),
        State({"type": "model_selector", "id": ALL}, "value"),
        State("base_model_answers_selector", "value"),
        State("filtering_container", "children"),
        State("loading_container", "children"),
    ],
    prevent_initial_call=True,
)
def filter_data(
    n_ckicks: int,
    filter_function: str,
    apply_on_filtered_data: int,
    filter_mode: List[str],
    sorting_function: str,
    models: List[str],
    base_model: str,
    filtering_functions: dict,
    loading_container: str,
) -> tuple:
    """Apply the user's filter, re-sort, and redraw the comparison tables.

    Args:
        n_ckicks: Click count of the global apply-filter button.
        filter_function: Filter expression from the editor.
        apply_on_filtered_data: Truthy to stack this filter on earlier ones.
        filter_mode: Non-empty selects files filtering, empty selects
            questions filtering.
        sorting_function: Sorting expression re-applied after filtering.
        models: Selected generations.
        base_model: Selected base generation.
        filtering_functions: Serialized component currently shown in
            ``filtering_container``. The code expects a dict with
            ``["props"]["children"]``.
        loading_container: Current text of ``loading_container``.

    Returns:
        ``(table rows, filtering_container children, loading_container text)``.
    """
    if not n_ckicks:
        return no_update, no_update, no_update

    # When stacking filters, append to the already displayed filter list
    # instead of replacing it.
    stack_filters = bool(apply_on_filtered_data and filtering_functions)
    if stack_filters:
        filtering_functions["props"]["children"] += f"\n{filter_function}"
    if base_model == CHOOSE_GENERATION:
        return [], no_update, no_update
    if len(get_table_data()) == 0 and models:  # TODO fix
        models = models[:1]

    get_filtered_tables_layout(
        base_model=base_model,
        filtering_function=filter_function,
        filter_mode=FILES_FILTERING if filter_mode else QUESTIONS_FILTERING,
        apply_on_filtered_data=apply_on_filtered_data or 0,
        models=models,
    )
    rows = get_sorted_tables_layout(
        base_model=base_model,
        sorting_function=sorting_function,
        models=models,
    )
    filters_view = (
        filtering_functions
        if stack_filters
        else html.Pre(f"Filtering function:\n{filter_function}")
    )
    # Appending a space changes loading_container so dependents update.
    return rows, filters_view, (loading_container or "") + " "
128 |
129 |
@app.callback(
    [
        Output(
            {"type": "filter_text", "id": -1},
            "children",
            allow_duplicate=True,
        ),
        Output("js_container", "children", allow_duplicate=True),
        Output("js_trigger", "children", allow_duplicate=True),
    ],
    Input({"type": "filter_mode", "id": -1}, "value"),
    State("js_trigger", "children"),
    prevent_initial_call=True,
)
def change_filter_mode(modes: List[str], js_trigger: str) -> tuple:
    """Refresh the filter help text when the filter-mode switch changes.

    Args:
        modes: Checked values of the filter-mode switch. Non-empty selects
            ``FILES_FILTERING`` and empty selects ``QUESTIONS_FILTERING``.
        js_trigger: Current text of the ``js_trigger`` element.

    Returns:
        ``(help text, js_container children, js_trigger)``. All three are
        ``no_update`` while the switch value is still ``None``.
    """
    if modes is None:
        return no_update, no_update, no_update
    mode = FILES_FILTERING if modes else QUESTIONS_FILTERING
    return get_filter_text(mode=mode), "", js_trigger + " "
154 |
--------------------------------------------------------------------------------
/nemo_inspector/layouts/common_layouts.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import itertools
16 | from typing import Dict, Iterable, List, Optional, Union
17 |
18 | import dash_bootstrap_components as dbc
19 | from dash import html
20 |
21 | from nemo_inspector.settings.constants import (
22 | ANSI,
23 | CODE,
24 | COMPARE,
25 | LATEX,
26 | MARKDOWN,
27 | )
28 | from nemo_inspector.utils.common import parse_model_answer
29 | from nemo_inspector.utils.decoration import (
30 | color_text_diff,
31 | design_text_output,
32 | highlight_code,
33 | )
34 |
35 |
def get_switch_layout(
    id: Union[Dict, str],
    labels: List[str],
    values: Optional[List[str]] = None,
    disabled: Optional[List[bool]] = None,
    is_active: bool = False,
    chosen_values: Optional[List[str]] = None,
    additional_params: Optional[Dict] = None,
) -> dbc.Checklist:
    """Build a ``dbc.Checklist`` of switches.

    Args:
        id: Component id (plain string or pattern-matching dict).
        labels: Option labels.
        values: Option values; defaults to ``labels``.
        disabled: Per-option disabled flags; defaults to all enabled.
            ``labels``/``values``/``disabled`` are combined with
            ``itertools.zip_longest(..., fillvalue=False)``, so shorter
            sequences are padded with ``False``.
        is_active: When ``chosen_values`` is empty, pre-select the first value.
        chosen_values: Values to pre-select.
        additional_params: Extra keyword arguments forwarded to ``dbc.Checklist``.

    Returns:
        The configured ``dbc.Checklist``.
    """
    if values is None:
        values = labels
    if disabled is None:
        disabled = [False]

    if chosen_values:
        selected = chosen_values
    elif is_active and values:
        selected = [values[0]]
    else:
        selected = []

    options = [
        {
            "label": label,
            "value": value,
            "disabled": is_disabled,
        }
        for label, value, is_disabled in itertools.zip_longest(
            labels, values, disabled, fillvalue=False
        )
    ]
    return dbc.Checklist(
        id=id,
        options=options,
        value=selected,
        **(additional_params or {}),
    )
62 |
63 |
def get_selector_layout(options: Iterable, id: str, value: str = "") -> dbc.Select:
    """Build a ``dbc.Select`` over ``options`` with ``value`` pre-selected.

    If ``value`` is not among ``options``, it is prepended as the first option.
    ``options`` is materialized into a list first, so one-shot iterables
    (generators) are not consumed by the membership test.

    Args:
        options: Selectable values; each label is ``str(option)``.
        id: Component id.
        value: Initially selected value.

    Returns:
        The configured ``dbc.Select``.
    """
    options = list(options)
    if value not in options:
        options.insert(0, value)
    return dbc.Select(
        id=id,
        options=[
            {
                "label": str(option),
                "value": option,
            }
            for option in options
        ],
        value=str(value),
    )
78 |
79 |
def get_single_prompt_output_layout(
    answer: str,
    text_modes: Optional[List[str]] = None,
    compare_to: str = "",
) -> List[html.Div]:
    """Render one model answer as a list of formatted explanation/code/output items.

    Args:
        answer: Raw generation text to render.
        text_modes: Enabled rendering modes; defaults to ``[CODE, LATEX, ANSI]``.
            With ``CODE``, texts are split by ``parse_model_answer``.
            Otherwise a text is treated as a single explanation part.
            With ``COMPARE``, each part is diffed against the matching part
            of ``compare_to``; otherwise the answer is diffed against itself.
        compare_to: Text to diff against when ``COMPARE`` is enabled.

    Returns:
        One rendered item per part index and key (explanation, code, output)
        where either side has non-empty content.
    """
    if text_modes is None:
        text_modes = [CODE, LATEX, ANSI]

    def parse(text: str) -> List[Dict]:
        # Split into explanation/code/output parts only when code mode is on.
        if CODE in text_modes:
            return parse_model_answer(text)
        return [{"explanation": text, "code": None, "output": None}]

    parsed_answers = parse(answer)
    parsed_compared_answers = (
        parse(compare_to) if COMPARE in text_modes else parsed_answers
    )

    styles = {
        "explanation": {"default": {}, "wrong": {}},
        "code": {"default": {}, "wrong": {}},
        "output": {
            "default": {
                "border": "1px solid black",
                "background-color": "#cdd4f1c8",
                "marginBottom": "10px",
                "marginTop": "-6px",
            },
            "wrong": {
                "border": "1px solid red",
                "marginBottom": "10px",
                "marginTop": "-6px",
            },
        },
    }

    functions = {
        "explanation": design_text_output,
        "code": highlight_code,
        "output": design_text_output,
    }

    def get_part(parts: List[Dict], i: int, key: str):
        # The part's value when present and truthy, otherwise "".
        if i < len(parts) and key in parts[i] and parts[i][key]:
            return parts[i][key]
        return ""

    items = []
    for i in range(max(len(parsed_answers), len(parsed_compared_answers))):
        for key in ["explanation", "code", "output"]:
            answer_part = get_part(parsed_answers, i, key)
            compared_part = get_part(parsed_compared_answers, i, key)
            if not (answer_part or compared_part):
                continue
            diff = color_text_diff(answer_part, compared_part)
            # Parts containing the "wrong_code_block" marker get the "wrong"
            # style (only "output" defines one: a red border).
            style_type = (
                "wrong"
                if answer_part and "wrong_code_block" in answer_part
                else "default"
            )
            items.append(
                functions[key](
                    diff, style=styles[key][style_type], text_modes=text_modes
                )
            )
    return items
153 |
154 |
def get_text_modes_layout(id: str, is_formatted: bool = True):
    """Build the inline switch row that toggles text rendering modes.

    The offered modes are CODE, LATEX, MARKDOWN and ANSI. When
    ``is_formatted`` is true, CODE, LATEX and ANSI start checked; otherwise
    nothing does.

    Args:
        id: Value for the ``"id"`` key of the pattern-matching component id
            ``{"type": "text_modes", "id": id}``.
        is_formatted: Whether the formatting modes start enabled.
    """
    initially_checked = [CODE, LATEX, ANSI] if is_formatted else []
    checklist_params = {
        "style": {"display": "inline-flex", "flex-wrap": "wrap"},
        "inputStyle": {"margin-left": "-10px"},
        "labelStyle": {"margin-left": "3px"},
    }
    component_id = {"type": "text_modes", "id": id}
    return get_switch_layout(
        id=component_id,
        labels=[CODE, LATEX, MARKDOWN, ANSI],
        chosen_values=initially_checked,
        additional_params=checklist_params,
    )
172 |
--------------------------------------------------------------------------------
/nemo_inspector/utils/decoration/latex.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Callable, Optional, Tuple
16 |
17 |
def get_starts_with_latex_tag_function(
    tag: str, default_index_move: int
) -> Callable[[str, int], Tuple[bool, int]]:
    r"""Build a detector reporting whether ``text`` has ``tag`` at ``index``.

    The returned callable gives ``(matched, next_index)``:

    * no match: ``next_index = index + default_index_move``;
    * match on a tag without ``{``: ``next_index`` is just past the tag;
    * match on a tag with ``{`` (e.g. ``\begin{``): ``next_index`` is the
      position of the first ``}`` at or after ``index``, or ``len(text)`` if
      there is none.
    """
    tag_takes_argument = "{" in tag

    def starts_with_tag_func_templ(text: str, index: int):
        if not text.startswith(tag, index):
            return False, index + default_index_move
        if not tag_takes_argument:
            return True, index + len(tag)
        closing_brace = text.find("}", index)
        # An unterminated argument extends to the end of the text.
        return True, (len(text) if closing_brace == -1 else closing_brace)

    return starts_with_tag_func_templ
33 |
34 |
def proccess_latex_tag(
    text: str,
    start_index: int,
    detect_start_token: Callable[[str, int], Tuple[bool, int]],
    detect_end_token: Callable[[str, int], Tuple[bool, int]],
    end_sign: Optional[str],
    last_block_only: bool = False,
) -> Tuple[int, int]:
    """Find the span of a LaTeX block that begins at ``start_index``.

    Opening tokens increment a nesting counter and closing tokens decrement
    it. Scanning stops once the counter returns to zero.

    Args:
        text: Full text being scanned.
        start_index: Index where an opening token was detected.
        detect_start_token: ``(text, index) -> (is_start, next_index)``.
        detect_end_token: ``(text, index) -> (is_end, next_index)``.
        end_sign: Optional character that aborts the scan when seen.
        last_block_only: If true, each new opening token resets the block
            start to its position and caps the nesting depth at one.

    Returns:
        ``(block_start, block_end)``, to be used as ``text[block_start:block_end]``.
        If ``end_sign`` is hit, ``(start_index, start_index + 1)``. The end may
        exceed ``len(text)`` when the block is never closed.
    """
    count = 0
    index = start_index
    while index < len(text):
        if end_sign and text[index] == end_sign:
            return start_index, start_index + 1
        is_start_token, new_index = detect_start_token(text, index)
        count += is_start_token
        if last_block_only and is_start_token:
            start_index = index
            count = min(1, count)
        index = new_index
        is_end_token, index = detect_end_token(text, index)
        count -= is_end_token
        if count == 0:
            break
    return start_index, index + 1
59 |
60 |
def get_single_dollar_detection_functions(
    direction: int, default_index_move: int
) -> Callable[[str, int], Tuple[bool, int]]:
    """Build a detector for an inline ``$`` math delimiter.

    The detector reports a match when ``text[index]`` is ``$`` and the
    neighbouring character at ``index + direction`` exists and is not
    whitespace. Use ``direction=1`` for an opening ``$`` (content follows)
    and ``direction=-1`` for a closing one. A missing neighbour (past either
    end of ``text``) counts as no match instead of raising ``IndexError`` or
    wrapping around to ``text[-1]``.

    Args:
        direction: Offset of the neighbour to inspect (+1 or -1).
        default_index_move: Amount added to ``index`` in the returned next index.

    Returns:
        A callable ``(text, index) -> (is_delimiter, index + default_index_move)``.
    """

    def detect(text: str, index: int) -> Tuple[bool, int]:
        neighbour = index + direction
        is_delimiter = (
            text[index] == "$"
            and 0 <= neighbour < len(text)
            and not text[neighbour].isspace()
        )
        return is_delimiter, index + default_index_move

    return detect
68 |
69 |
def get_latex_detection_functions(text, index) -> tuple[
    Callable[[str, int], Tuple[bool, int]],
    Callable[[str, int], Tuple[bool, int]],
    Optional[str],
    bool,
    bool,
]:
    r"""Choose start/end detectors for the math block that may begin at ``index``.

    Returns a 5-tuple
    ``(detect_start, detect_end, end_sign, add_dollars, last_block_only)``:

    * ``\begin{`` ... ``\end{``: ``add_dollars`` is True, so the caller
      wraps the block in ``$$``.
    * ``$$`` ... ``$$``: already display math, so no extra dollars.
    * single ``$`` followed by non-whitespace: inline math, aborted at a
      newline (``end_sign="\n"``), with ``last_block_only`` set.

    When nothing matches, all five entries are ``None``.
    """
    multiline_tags = (("\\begin{", "\\end{", True), ("$$", "$$", False))
    for opening, closing, wrap_in_dollars in multiline_tags:
        if not text.startswith(opening, index):
            continue
        opener = get_starts_with_latex_tag_function(opening, 1)
        closer = get_starts_with_latex_tag_function(closing, 0)
        return opener, closer, None, wrap_in_dollars, False

    inline_opener = get_single_dollar_detection_functions(1, 1)
    inline_closer = get_single_dollar_detection_functions(-1, 0)
    opens_inline_math, _ = inline_opener(text, index)
    if not opens_inline_math:
        return None, None, None, None, None
    return inline_opener, inline_closer, "\n", False, True
94 |
95 |
def proccess_plain_text(text: str) -> str:
    r"""Backslash-escape Markdown special characters in ``text``.

    Each of ``* _ { } [ ] ( ) # + - . !`` and the backtick is prefixed with
    a single ``\``. This is done in one ``str.translate`` pass. Backslashes
    themselves are left unchanged.
    """
    escape_table = str.maketrans(
        {character: "\\" + character for character in r"*_{}[]()#+-.!`"}
    )
    return text.translate(escape_table)
101 |
102 |
def preprocess_latex(text: str, escape: bool = True) -> str:
    r"""Normalize LaTeX math delimiters in ``text`` and escape the prose around them.

    Steps:
      1. ``\[`` / ``\]`` become ``$$`` on their own lines, and ``\(`` / ``\)``
         become `` $`` / ``$ `` (single-dollar inline math).
      2. A space is inserted between ``$`` and an adjacent operator: any of
         ``- = + * /`` directly before a ``$``, and any of ``= + * /``
         directly after one.
      3. The text is scanned left to right. Spans recognized by
         ``get_latex_detection_functions`` are copied through as math, and
         ``\begin{...}`` blocks are additionally wrapped in ``$$`` lines. The
         plain text between math spans goes through ``proccess_plain_text``
         when ``escape`` is true.
      4. Every newline is doubled and the result is stripped.

    Args:
        text: Raw model output.
        escape: Whether to backslash-escape Markdown special characters in
            the non-math text.

    Returns:
        The processed string.
    """
    # A leading "\n" is added here and a trailing one below. NOTE(review):
    # they presumably serve as sentinels so neighbour lookups at index +/- 1
    # stay in bounds during the scan (index runs from 1 to len - 2).
    text = "\n" + text.replace("\\[", "\n$$\n").replace("\\]", "\n$$\n").replace(
        "\\(", " $"
    ).replace("\\)", "$ ")

    # NOTE(review): intent is not visible here; presumably this keeps `$`
    # from being glued to operators when rendered. Confirm.
    right_side_operations = ["-", "=", "+", "*", "/"]
    left_side_operations = ["=", "+", "*", "/"]
    for op in right_side_operations:
        text = text.replace(op + "$", op + " $")

    for op in left_side_operations:
        text = text.replace("$" + op, "$ " + op)

    text += "\n"
    index = 1
    texts = []
    # Start of the pending plain-text run, or -1 when there is none.
    start_plain_text_index = -1
    while index < len(text) - 1:
        (
            detect_start_token,
            detect_end_token,
            end_sign,
            add_dollars,
            use_last_block_only,
        ) = get_latex_detection_functions(text, index)
        if detect_start_token is not None:
            # A math block starts here: flush the pending plain-text run.
            if start_plain_text_index != -1:
                texts.append(
                    proccess_plain_text(text[start_plain_text_index:index])
                    if escape
                    else text[start_plain_text_index:index]
                )
                start_plain_text_index = -1

            start_index, new_index = proccess_latex_tag(
                text,
                index,
                detect_start_token,
                detect_end_token,
                end_sign,
                use_last_block_only,
            )
            # With last_block_only the block start can move forward. The
            # skipped prefix is treated as plain text.
            texts.append(
                proccess_plain_text(text[index:start_index])
                if escape
                else text[index:start_index]
            )
            if add_dollars:
                texts.append("\n$$\n")
                texts.append(text[start_index:new_index].strip())
                texts.append("\n$$\n")
            else:
                texts.append(text[start_index:new_index])
            index = new_index
        elif start_plain_text_index == -1:
            start_plain_text_index = index
            index += 1
        else:
            index += 1
    if start_plain_text_index != -1:
        texts.append(
            proccess_plain_text(text[start_plain_text_index:])
            if escape
            else text[start_plain_text_index:]
        )
    # NOTE(review): doubling newlines presumably makes each line its own
    # Markdown paragraph. Confirm against the renderer.
    return "".join(texts).replace("\n", "\n\n").strip()
169 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NeMo Inspector
2 |
3 | NeMo Inspector is a tool designed to help you analyze Large Language Model (LLM) generations. It lets you explore and manipulate existing generations, apply filters, sorting criteria, and compute statistics.
4 |
5 | ## Prerequisites
6 |
7 | 1. **Clone and Install the Tool:**
8 | ```shell
9 | git clone git@github.com:NVIDIA/NeMo-Inspector.git
10 | cd NeMo-Inspector
11 | pip install .
12 | ```
13 |
14 | 2. **Launch the Tool:**
15 | ```shell
16 | nemo_inspector
17 | ```
18 |
19 | This will start a local server that you can access through your browser.
20 |
21 | ## Analyze Page
22 |
23 | The Analyze page helps you work with pre-generated outputs. To use it, provide paths to the generation files using command-line arguments. For example:
24 |
25 | ```shell
26 | nemo_inspector --model_prediction \
27 | generation1='/path/to/generation1/output-greedy.jsonl' \
28 | generation2='/path/to/generation2/output-rs*.jsonl'
29 | ```
30 | Once loaded, the Analyze page lets you:
31 |
32 | - **Sort and Filter Results:** Apply custom filtering and sorting functions to refine the displayed data.
33 | - **Compare Generations:** View outputs from multiple generation runs side-by-side.
34 | - **Modify and Label Data:** Update or annotate samples and save the changes for future reference.
35 | - **Compute Statistics:** Generate both custom and general statistics to summarize your data.
36 |
37 | ### Filtering
38 |
39 | The tool supports two filtering modes: **Filter Files** mode and **Filter Questions** mode. You can define custom filtering functions in Python and run them directly in the UI.
40 |
41 | #### Filter Files Mode
42 |
43 | - In this mode, the filtering function will be run on each sample across different files simultaneously.
44 | - The input to the filtering function is a dictionary where keys represent generation names and values are JSON objects for that sample.
45 | - The custom function should return a Boolean value (`True` to keep the sample, `False` to filter it out).
46 |
47 | Example of a custom filtering function:
48 |
49 | ```python
50 | def custom_filtering_function(error_message: str) -> bool:
51 | # Implement your logic here
52 | return 'timeout' not in error_message
53 |
54 | # This line will be used for the filtering:
55 | custom_filtering_function(data['generation1']['error_message'])
56 | ```
57 |
**Note:** The last line of the code block is the expression used for filtering. All preceding lines only define helpers or compute intermediate values.
59 |
60 | To apply multiple conditions to multiple generations, use the `&&` separator. For instance:
61 |
62 | ```python
63 | data['generation1']['is_correct'] && not data['generation2']['is_correct']
64 | ```
65 |
**Important:** In Filter Files mode, do not combine conditions on different generations with Python's `and`/`or`. Separate each generation's condition with `&&` instead.
67 |
68 | #### Filter Questions Mode
69 |
70 | - In this mode, the function filters each question across multiple files without filtering out entire files.
71 | - The input is a dictionary of generation names mapping to **lists** of JSON data for that question.
72 |
73 | In this mode, you write conditions without the `&&` operator. For example:
74 |
75 | ```python
76 | data['generation1'][0]['is_correct'] and not data['generation2'][0]['is_correct']
77 | ```
78 |
This example keeps only the questions where the first sample of `generation1` is correct and the first sample of `generation2` is incorrect. You can also compare fields directly:
80 |
81 | ```python
82 | data['generation1'][0]['is_correct'] != data['generation2'][0]['is_correct']
83 | ```
84 |
85 | **Note:** These examples cannot be used in Filter Files mode.
86 |
87 | ### Sorting
88 |
89 | Sorting functions are similar to filtering functions, but there are key differences:
90 |
91 | 1. **Scope:** Sorting functions operate on individual data entries (not dictionaries with multiple generations).
92 | 2. **Cross-Generations:** Sorting cannot be applied across multiple generations at once. You must sort one generation at a time.
93 |
94 | A correct sorting function might look like this:
95 |
96 | ```python
97 | def custom_sorting_function(generation: str):
98 | # Sort by the length of the generation text
99 | return len(generation)
100 |
101 | # This line will be used for the sorting:
102 | custom_sorting_function(data['generation'])
103 | ```
104 |
105 | ### Statistics
106 |
107 | NeMo Inspector supports two types of statistics:
108 |
109 | 1. **Custom Statistics:** Applied to the samples of a single question (for each generation).
110 |
111 | Default custom statistics include:
112 | - `correct_responses`
113 | - `wrong_responses`
114 | - `no_responses`
115 |
116 | 2. **General Custom Statistics:** Applied across all questions and all generations.
117 |
118 | Default general custom statistics include:
119 | - `dataset size`
120 | - `overall number of samples`
121 | - `generations per sample`
122 |
You can modify the existing statistics or define your own Custom and General Custom Statistics functions.
124 |
125 | **Custom Statistics Example:**
126 |
127 | ```python
128 | def unique_error_counter(datas):
129 | # `datas` is a list of JSONs (one per file) for a single question
130 | unique_errors = set()
131 | for data in datas:
132 | unique_errors.add(data.get('error_message'))
133 | return len(unique_errors)
134 |
135 | def number_of_runs(datas):
136 | return len(datas)
137 |
138 | # Map function names to functions
139 | {'unique_errors': unique_error_counter, 'number_of_runs': number_of_runs}
140 | ```
141 |
142 | **General Custom Statistics Example:**
143 |
144 | ```python
145 | def overall_unique_error_counter(datas):
146 | # `datas` is a list of lists of dictionaries,
147 | # where datas[question_index][file_index] is a JSON record
148 | unique_errors = set()
149 | for question_data in datas:
150 | for file_data in question_data:
151 | unique_errors.add(file_data.get('error_message'))
152 | return len(unique_errors)
153 |
154 | # Map function names to functions
155 | {'unique_errors': overall_unique_error_counter}
156 | ```
157 |
158 | **Note:** The final line in both the Custom and General Custom Statistics code blocks should be a dictionary mapping function names to their corresponding functions.
159 |
160 | ### Modifications
161 |
You can update each sample in the dataset programmatically. The last line of the code block must be an expression that evaluates to the updated sample dictionary:
163 |
164 | ```python
165 | # For example, strip leading and trailing whitespace from the "generation" field
166 | {**data, 'generation': data['generation'].strip()}
167 | ```
168 |
--------------------------------------------------------------------------------
/nemo_inspector/layouts/analyze_page_layouts/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import List
16 |
17 | from dash import html
18 | import dash_ace
19 |
20 | from nemo_inspector.layouts.common_layouts import get_selector_layout
21 | from nemo_inspector.settings.constants import STATS_KEYS
22 | from nemo_inspector.utils.common import get_metrics, get_table_data
23 | from nemo_inspector.settings.constants.common import (
24 | CUSTOM,
25 | DELETE,
26 | FILES_FILTERING,
27 | FILES_ONLY,
28 | GENERAL_STATS,
29 | INLINE_STATS,
30 | QUESTIONS_FILTERING,
31 | )
32 | from nemo_inspector.utils.common import (
33 | get_custom_stats,
34 | get_general_custom_stats,
35 | get_stats_raw,
36 | )
37 |
38 |
def get_filter_text(
    available_filters: List[str] = [], mode: str = FILES_FILTERING
) -> str:
    """Build the help text shown next to the data-filtering code editor.

    Args:
        available_filters: Field names to list as filterable parameters. When
            empty (the default), the keys of the first record of the first
            generation in the current table data are listed; if no such record
            exists, ``STATS_KEYS``, the metric names returned by
            ``get_metrics([])`` and a "+ all fields in json" hint are listed.
        mode: ``FILES_ONLY``, ``FILES_FILTERING`` or ``QUESTIONS_FILTERING``;
            selects the explanation and the example expression.

    Returns:
        The help text as a single string, or ``None`` if ``mode`` is not one
        of the supported values.
    """
    # The argument is only read, never mutated, so the shared ``[]`` default
    # is harmless. A caller-supplied list is now honoured instead of being
    # silently replaced by the defaults.
    if available_filters:
        filters = list(available_filters)
    else:
        table_data = get_table_data()
        first_question = table_data[0] if len(table_data) else {}
        # Guard against a question without generations or a generation
        # without records, which previously raised IndexError.
        first_generation = next(iter(first_question.values()), [])
        if first_generation:
            filters = list(first_generation[0].keys())
        else:
            filters = (
                STATS_KEYS + list(get_metrics([]).keys()) + ["+ all fields in json"]
            )

    # Five parameter names per line keeps the help text readable.
    filters_listing = "\n".join(
        ", ".join(filters[start : start + 5]) for start in range(0, len(filters), 5)
    )
    footer = (
        "The expression has to return bool.\n\n"
        + "Available parameters to filter data:\n"
        + filters_listing
    )

    if mode == FILES_ONLY:
        intro = (
            "Write an expression to filter the data\n\n"
            + "For example:\ndata['is_correct'] and not data['error_message']\n\n"
        )
    elif mode == FILES_FILTERING:
        intro = (
            "Write an expression to filter the data\n"
            + "Separate expressions for different generations with &&\n"
            + "You can use base_generation variable to access data from the current generation\n\n"
            + "For example:\ndata['generation1']['correct_responses'] > 0.5 && data[base_generation]['no_response'] < 0.2\n\n"
        )
    elif mode == QUESTIONS_FILTERING:
        intro = (
            "Write an expression to filter the data\n"
            + "You can operate with a dictionary containing keys representing generation names\n"
            + "and a list of values as JSON data from your generation from each file.\n"
            + "You can use base_generation variable to access data from the current generation\n\n"
            + "For example:\ndata['generation1'][0]['is_correct'] != data[base_generation][0]['is_correct']\n\n"
        )
    else:
        return None

    # Return a plain string; the previous ``return (..., )`` produced a
    # 1-tuple because of a stray trailing comma.
    return intro + footer
91 |
92 |
def get_stats_text(general_stats: bool = False, delete: bool = False):
    """Return the instructions displayed in the statistics modal.

    Args:
        general_stats: Describe general (dataset-wide) custom statistics
            instead of per-question custom statistics.
        delete: Return the short prompt used when deleting a statistic;
            takes precedence over ``general_stats``.

    Returns:
        The instruction text as a single string.
    """
    if delete:
        return "Choose the name of the statistic you want to delete"

    # Pieces shared verbatim by both instruction variants.
    dictionary_step = (
        "1. Create a dictionary where keys are the names of your custom stats.\n"
    )
    example_header = (
        "Example:\n\n" "Define a custom function to integrate into your stats:\n\n"
    )

    if not general_stats:
        introduction = (
            "Creating Custom Statistics:\n\n" "To introduce new custom statistics:\n"
        )
        functions_step = (
            "2. Assign functions as values. These functions should accept arrays containing data\n"
            "from all relevant files.\n\n"
            "Note: Do not use names that already exist in the current stats or JSON fields\n"
            "to avoid conflicts.\n\n"
        )
        example_code = (
            "def unique_error_counter(datas):\n"
            " unique_errors = set()\n"
            " for data in datas:\n"
            " unique_errors.add(data.get('error_message'))\n"
            " return len(unique_errors)\n\n"
            "{'unique_error_count': unique_error_counter}"
        )
        return (
            introduction + dictionary_step + functions_step + example_header + example_code
        )

    introduction = (
        "Creating General Custom Statistics:\n\n"
        "To introduce new general custom statistics:\n"
    )
    functions_step = (
        "2. Assign functions as values. These functions should accept arrays where first dimension\n"
        "is a question index and second is a file number (both sorted and filtered).\n\n"
    )
    example_code = (
        "def my_func(datas):\n"
        " correct_responses = 0\n"
        " for question_data in datas:\n"
        " for file_data in question_data:\n"
        " correct_responses += file_data['is_correct']\n"
        " return correct_responses\n"
        "{'correct_responses': my_func}"
    )
    return introduction + dictionary_step + functions_step + example_header + example_code
132 |
133 |
def get_code_text_area_layout(id) -> dash_ace.DashAceEditor:
    """Create the Ace editor used to type Python snippets.

    Args:
        id: Dash component id assigned to the editor.

    Returns:
        A ``DashAceEditor`` in Python mode with the "tomorrow_night" theme,
        4-space tabs, basic and live autocompletion enabled, an empty initial
        value and a full-width, 300px-high box.
    """
    editor_settings = {
        "theme": "tomorrow_night",
        "mode": "python",
        "tabSize": 4,
        "enableBasicAutocompletion": True,
        "enableLiveAutocompletion": True,
        "placeholder": "Write your code here...",
        "value": "",
        "style": {"width": "100%", "height": "300px"},
    }
    return dash_ace.DashAceEditor(id=id, **editor_settings)
146 |
147 |
def get_stats_input(modes: List[str] = []) -> List:
    """Build the body of the add/delete statistics modal.

    Args:
        modes: Flags selecting the modal variant. ``DELETE`` shows a selector
            of existing custom statistics to remove; otherwise an extractor
            selector plus a code editor are shown. ``GENERAL_STATS`` switches
            from inline (per-question) to general statistics.

    Returns:
        A list whose first element is the ``html.Pre`` instruction text
        (id ``stats_text``), followed by the input components.
    """
    use_general = GENERAL_STATS in modes
    delete_mode = DELETE in modes

    if delete_mode:
        registered = get_general_custom_stats() if use_general else get_custom_stats()
        stat_names = list(registered.keys())
        default_choice = stat_names[0] if stat_names else ""
        controls = [get_selector_layout(stat_names, "stats_input", default_choice)]
    else:
        stats_mode = GENERAL_STATS if use_general else INLINE_STATS
        extractor_names = list(get_stats_raw()[stats_mode].keys())
        controls = [
            get_selector_layout(extractor_names, "stats_extractor", CUSTOM),
            get_code_text_area_layout(id="stats_input"),
        ]

    instructions = html.Pre(
        get_stats_text(use_general, delete_mode), id="stats_text"
    )
    return [instructions, *controls]
175 |
--------------------------------------------------------------------------------
/nemo_inspector/callbacks/analyze_page/label_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import json
16 | from typing import List, Tuple
17 |
18 | from dash import ALL, callback_context, no_update
19 | from dash.dependencies import Input, Output, State
20 |
21 | from nemo_inspector.callbacks import app
22 | from nemo_inspector.settings.constants import (
23 | CHOOSE_LABEL,
24 | FILE_NAME,
25 | LABEL,
26 | LABEL_SELECTOR_ID,
27 | )
28 | from nemo_inspector.utils.common import (
29 | get_labels,
30 | get_table_data,
31 | )
32 |
33 |
@app.callback(
    Output({"type": "label", "id": ALL}, "is_open"),
    [
        Input({"type": "set_file_label_button", "id": ALL}, "n_clicks"),
        Input({"type": "apply_label_button", "id": ALL}, "n_clicks"),
        Input({"type": "delete_label_button", "id": ALL}, "n_clicks"),
    ],
    [State({"type": "label", "id": ALL}, "is_open")],
)
def toggle_modal_label(
    n1: List[int], n2: List[int], n3: List[int], is_open: List[bool]
) -> List[bool]:
    """Open or close the label modal whose button was clicked.

    All four arguments are pattern-matching (``ALL``) lists, one entry per
    matching component.

    Args:
        n1: ``n_clicks`` of the "set_file_label_button" components.
        n2: ``n_clicks`` of the "apply_label_button" components.
        n3: ``n_clicks`` of the "delete_label_button" components.
        is_open: Current ``is_open`` state of every "label" modal.

    Returns:
        The ``is_open`` list with the triggering modal toggled, or a list of
        ``no_update`` when nothing was triggered or the trigger value is falsy.
    """
    ctx = callback_context
    if not ctx.triggered:
        return [no_update] * len(is_open)

    # The component id is shifted by one to get the list position.
    # NOTE(review): this presumably relies on the general (all-files) modal
    # having id -1 and sitting at position 0 — confirm against the layout.
    # It also assumes n1/n2/n3/is_open are all ordered the same way.
    button_id = json.loads(ctx.triggered[-1]["prop_id"].split(".")[0])["id"] + 1
    if not ctx.triggered[0]["value"]:
        return [no_update] * len(is_open)

    if n1[button_id] or n2[button_id] or n3[button_id]:
        is_open[button_id] = not is_open[button_id]
        return is_open
    return is_open
56 |
57 |
@app.callback(
    Output(
        "dummy_output",
        "children",
        allow_duplicate=True,
    ),
    [
        Input({"type": "apply_label_button", "id": ALL}, "n_clicks"),
        Input({"type": "delete_label_button", "id": ALL}, "n_clicks"),
    ],
    [
        # NOTE: "aplly" (sic) must match the component id used by the layout.
        State(
            {"type": "aplly_for_all_files", "id": ALL},
            "value",
        ),
        State({"type": "label_selector", "id": ALL}, "value"),
        State({"type": "label_selector", "id": ALL}, "id"),
        State("datatable", "page_current"),
        State("datatable", "page_size"),
        State("datatable", "selected_rows"),
        State({"type": "model_selector", "id": ALL}, "value"),
        State("base_model_answers_selector", "value"),
        State({"type": "file_selector", "id": ALL}, "value"),
        State({"type": "file_selector", "id": ALL}, "options"),
        State(
            "dummy_output",
            "children",
        ),
    ],
    prevent_initial_call=True,
)
def change_label(
    n_click_apply: List[int],
    n_click_del: List[int],
    apply_for_all: List[bool],
    labels: List[str],
    label_ids: List[int],
    current_page: int,
    page_size: int,
    idx: List[int],
    models: List[str],
    base_model: str,
    file_names: List[str],
    file_options: List[str],
    dummy_data: str,
) -> str:
    """Add or remove the chosen label on records of the table data.

    The triggering button's label selector decides which label is used.
    Position 0 in the selector list is the general modal: it labels every
    file of ``base_model`` for every question. Any other position ``k``
    labels the files of comparison model ``models[k - 1]`` for the currently
    selected table row. The row's own file is labelled, or all of its files
    when the "apply for all files" checkbox of that modal is ticked.

    Labels are stored in each record's ``LABEL`` list inside
    ``get_table_data()``. That list is mutated in place.

    Returns:
        ``dummy_data + "1"`` to signal other callbacks that labels changed,
        or ``no_update`` when there is nothing to do.
    """
    ctx = callback_context
    if not ctx.triggered:
        return no_update

    triggered_component = json.loads(ctx.triggered[-1]["prop_id"].split(".")[0])
    button_id = label_ids.index(
        json.loads(LABEL_SELECTOR_ID.format(triggered_component["id"]))
    )
    is_apply = triggered_component["type"] == "apply_label_button"
    if not ctx.triggered[0]["value"] or labels[button_id] == CHOOSE_LABEL:
        return no_update
    label = labels[button_id]

    ALL_FILES = "ALL_FILES"
    if button_id == 0:
        models_to_process = [(base_model, [ALL_FILES], [ALL_FILES])]
        # The general modal always targets every file. Previously this was
        # derived from len(models), so nothing was labelled when no
        # comparison model was open.
        apply_for_all_files = True
        question_ids = range(len(get_table_data()))
    else:
        if not idx:
            return no_update
        models_to_process = [
            (
                models[button_id - 1],
                file_options[button_id - 1],
                file_names[button_id - 1],
            )
        ]
        # Checklist value: a non-empty list means the box is ticked; bool()
        # also tolerates None, which len() did not.
        apply_for_all_files = bool(apply_for_all[button_id - 1])
        question_ids = [current_page * page_size + idx[0]]

    for question_id in question_ids:
        for model, current_file_options, current_file in models_to_process:
            model_files = get_table_data()[question_id][model]
            options = (
                current_file_options
                if button_id != 0
                else [{"value": model_file[FILE_NAME]} for model_file in model_files]
            )
            for option in options:
                if not apply_for_all_files and option["value"] != current_file:
                    continue

                target = next(
                    (
                        model_file
                        for model_file in model_files
                        if model_file[FILE_NAME] == option["value"]
                    ),
                    None,
                )
                if target is None:
                    # The old code fell back to file index 0 here and
                    # labelled an unrelated file.
                    continue

                file_labels = target[LABEL]
                if is_apply and label not in file_labels:
                    file_labels.append(label)
                elif not is_apply and label in file_labels:
                    file_labels.remove(label)

    return dummy_data + "1"
177 |
178 |
@app.callback(
    [
        Output({"type": "new_label_input", "id": ALL}, "value"),
        Output({"type": "label_selector", "id": ALL}, "options"),
        Output({"type": "label_selector", "id": ALL}, "value"),
    ],
    Input({"type": "add_new_label_button", "id": ALL}, "n_clicks"),
    [
        State({"type": "new_label_input", "id": ALL}, "value"),
        State({"type": "label_selector", "id": ALL}, "options"),
        State({"type": "label_selector", "id": ALL}, "value"),
        State({"type": "label_selector", "id": ALL}, "id"),
    ],
)
def add_new_label(
    n_click: int,
    new_labels: List[str],
    options: List[List[str]],
    values: List[str],
    label_ids: List[int],
) -> Tuple[List[str], List[List[dict]], List[str]]:
    """Register a user-typed label and add it to every label selector.

    The triggering button identifies which input/selector pair is used. When
    the typed label is non-empty and not yet an option of that selector, it
    is appended to the options of every selector where it is missing,
    selected in the triggering selector, and added to ``get_labels()``
    (unless already present). The text input is then cleared.

    Returns:
        A tuple ``(new_labels, options, values)`` for the three outputs, or
        three lists of ``no_update`` when nothing should change.
    """
    ctx = callback_context
    no_updates = [no_update] * len(new_labels)
    if not ctx.triggered:
        return no_updates, no_updates, no_updates

    trigger_id = json.loads(ctx.triggered[-1]["prop_id"].split(".")[0])["id"]
    button_id = label_ids.index(json.loads(LABEL_SELECTOR_ID.format(trigger_id)))

    if not ctx.triggered[0]["value"]:
        return no_updates, no_updates, no_updates

    def option_values(selector_options) -> set:
        # Dropdown options may be plain strings or {"label", "value"} dicts.
        # Comparing a string against dict options (as before) never matched,
        # so the duplicate check never fired.
        return {
            option["value"] if isinstance(option, dict) else option
            for option in selector_options
        }

    new_label = new_labels[button_id]
    if not new_label or new_label in option_values(options[button_id]):
        return no_updates, no_updates, no_updates

    for selector_options in options:
        if new_label not in option_values(selector_options):
            selector_options.append({"label": new_label, "value": new_label})
    values[button_id] = new_label

    # Avoid accumulating duplicates in the shared label registry.
    if new_label not in get_labels():
        get_labels().append(new_label)
    new_labels[button_id] = ""

    return new_labels, options, values
231 |
232 |
@app.callback(
    Output({"type": "chosen_label", "id": ALL}, "children"),
    Input({"type": "label_selector", "id": ALL}, "value"),
    [
        State({"type": "label_selector", "id": ALL}, "id"),
        State({"type": "chosen_label", "id": ALL}, "children"),
    ],
)
def choose_label(
    label: List[str], label_ids: List[int], chosen_labels: List[str]
) -> List[str]:
    """Show the currently selected label next to each triggered selector.

    Args:
        label: Current value of every label selector.
        label_ids: Component ids of the label selectors, in the same order.
        chosen_labels: Current text of every "chosen_label" component.

    Returns:
        ``chosen_labels`` with each triggered entry set to
        ``"chosen label: <label>"``, or to ``""`` when the selector is empty
        or shows the ``CHOOSE_LABEL`` placeholder.
    """
    ctx = callback_context
    if not ctx.triggered:
        return [no_update] * len(chosen_labels)

    for trigger in ctx.triggered:
        trigger_id = json.loads(trigger["prop_id"].split(".")[0])["id"]
        selector_index = label_ids.index(
            json.loads(LABEL_SELECTOR_ID.format(trigger_id))
        )

        # Use this trigger's own value. The previous code read
        # ctx.triggered[0] for every trigger, so all selectors followed the
        # first one.
        if not trigger["value"] or label[selector_index] == CHOOSE_LABEL:
            chosen_labels[selector_index] = ""
        else:
            chosen_labels[selector_index] = f"chosen label: {label[selector_index]}"

    return chosen_labels
263 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/nemo_inspector/layouts/analyze_page_layouts/base_layout.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | from typing import List
17 |
18 | import dash_bootstrap_components as dbc
19 | from dash import dcc, html
20 |
21 | from nemo_inspector.layouts.analyze_page_layouts.table_layouts import (
22 | get_detailed_info_table_layout,
23 | get_general_stats_layout,
24 | get_short_info_table_layout,
25 | )
26 | from nemo_inspector.layouts.common_layouts import (
27 | get_selector_layout,
28 | )
29 | from nemo_inspector.settings.constants import (
30 | BASE_GENERATION,
31 | CHOOSE_GENERATION,
32 | ERROR_MESSAGE_TEMPLATE,
33 | FILES_FILTERING,
34 | )
35 | from nemo_inspector.utils.common import (
36 | catch_eval_exception,
37 | clear_table_data,
38 | custom_deepcopy,
39 | get_available_models,
40 | get_data_from_files,
41 | get_eval_function,
42 | get_metrics,
43 | get_table_data,
44 | is_detailed_answers_rows_key,
45 | )
46 | from nemo_inspector.layouts.analyze_page_layouts.modals_layouts import (
47 | get_add_stats_modal_layout,
48 | get_change_label_modal_layout,
49 | get_filter_modal_layout,
50 | get_sorting_modal_layout,
51 | get_update_dataset_modal_layout,
52 | get_save_dataset_modal_layout,
53 | )
54 |
55 |
def get_compare_test_layout() -> html.Div:
    """Build the top-level layout of the analyze page.

    The layout stacks the following parts in order:
    - a toolbar holding the sorting, filtering, statistics, labelling,
      dataset-update and dataset-save modals, a "+" button
      (id ``add_model``) and the base-generation selector;
    - two ``html.Pre`` containers (``filtering_container`` and
      ``sorting_container``);
    - a loading indicator wrapping a hidden ``loading_container``;
    - an initially empty ``compare_models_rows`` div.
    """
    toolbar_items = [
        get_sorting_modal_layout(),
        get_filter_modal_layout(),
        get_add_stats_modal_layout(),
        get_change_label_modal_layout(apply_for_all_files=False),
        get_update_dataset_modal_layout(),
        get_save_dataset_modal_layout(),
    ]
    toolbar_items.append(
        dbc.Button(
            "+",
            id="add_model",
            outline=True,
            color="primary",
            className="me-1",
            class_name="button-class",
            style={"margin-left": "1px"},
        )
    )
    toolbar_items.append(
        get_selector_layout(
            get_available_models().keys(),
            "base_model_answers_selector",
            value=CHOOSE_GENERATION,
        )
    )

    page_sections = [
        dbc.InputGroup(toolbar_items),
        html.Pre(id="filtering_container"),
        html.Pre(id="sorting_container"),
        dcc.Loading(
            dbc.Container(id="loading_container", style={"display": "none"}, children=""),
            type="circle",
            style={"margin-top": "50px"},
        ),
        html.Div(children=[], id="compare_models_rows"),
    ]
    return html.Div(page_sections)
98 |
99 |
def get_updated_tables_layout(
    base_model: str, update_function: str, models: List[str]
) -> List[html.Tr]:
    """Apply a user-written update expression and rebuild the analysis tables.

    When ``update_function`` is non-empty, it is compiled with
    ``get_eval_function``. It is then evaluated for every record of
    ``base_model`` in every question of ``get_table_data()``. Each record is
    rewritten in place so its contents equal the returned dict: keys are
    set, and keys missing from the result are removed. When evaluation
    fails, ``catch_eval_exception`` falls back to the record itself, leaving
    it unchanged. Any collected errors are logged.

    Returns:
        The short-info table rows, then the general-stats rows for
        ``base_model``, then the detailed-info rows for ``models``.
    """
    errors_dict = {}
    if update_function:
        update_eval_function = get_eval_function(update_function.strip())
        available_models = {
            name: info["file_paths"] for name, info in get_available_models().items()
        }

        for question_data in get_table_data():
            records = question_data[base_model]
            # Evaluate every record first, then write the results back.
            updated_records = [
                catch_eval_exception(
                    available_models,
                    update_eval_function,
                    record,
                    record,
                    errors_dict,
                )
                for record in records
            ]
            for record, updated in zip(records, updated_records):
                for key, value in updated.items():
                    record[key] = value
                stale_keys = [key for key in record if key not in updated]
                for key in stale_keys:
                    record.pop(key)

    if len(errors_dict):
        logging.error(ERROR_MESSAGE_TEMPLATE.format("update_dataset", errors_dict))

    short_info_rows = get_short_info_table_layout()
    general_stats_rows = get_general_stats_layout(base_model)
    table_data = get_table_data()
    if len(table_data) and len(table_data[0][base_model]):
        detailed_keys = list(
            filter(is_detailed_answers_rows_key, table_data[0][base_model][0].keys())
        )
    else:
        detailed_keys = []
    return (
        short_info_rows
        + general_stats_rows
        + get_detailed_info_table_layout(models, detailed_keys)
    )
153 |
154 |
def get_sorted_tables_layout(
    base_model: str, sorting_function: str, models: List[str]
) -> List[html.Tr]:
    """Sort the table data with a user-written key expression and rebuild tables.

    When ``sorting_function`` is non-empty, it is compiled with
    ``get_eval_function`` and used as a sort key. Records that fail to
    evaluate get the key ``0``. The sorting is applied at two levels:
    - within each question, every generation's file list is sorted in place;
    - the list of questions is sorted by the tuple of keys of its
      ``base_model`` records.

    Any collected errors are logged.

    Returns:
        The short-info table rows, then the general-stats rows for
        ``base_model``, then the detailed-info rows for ``models``.
    """
    errors_dict = {}
    if sorting_function:
        sorting_eval_function = get_eval_function(sorting_function.strip())
        available_models = {
            name: info["file_paths"] for name, info in get_available_models().items()
        }

        def record_key(data):
            return catch_eval_exception(
                available_models,
                sorting_eval_function,
                data,
                0,
                errors_dict,
            )

        for question_data in get_table_data():
            for generation_files in question_data.values():
                generation_files.sort(key=record_key)

        get_table_data().sort(
            key=lambda question_data: tuple(
                record_key(data) for data in question_data[base_model]
            )
        )
    if len(errors_dict):
        logging.error(ERROR_MESSAGE_TEMPLATE.format("sorting", errors_dict))

    short_info_rows = get_short_info_table_layout()
    general_stats_rows = get_general_stats_layout(base_model)
    table_data = get_table_data()
    if len(table_data) and len(table_data[0][base_model]):
        detailed_keys = list(
            filter(is_detailed_answers_rows_key, table_data[0][base_model][0].keys())
        )
    else:
        detailed_keys = []
    return (
        short_info_rows
        + general_stats_rows
        + get_detailed_info_table_layout(models, detailed_keys)
    )
212 |
213 |
def get_filtered_tables_layout(
    base_model: str,
    filtering_function: str,
    apply_on_filtered_data: bool,
    models: List[str],
    filter_mode: str,
) -> List[html.Tr]:
    """Filter the table data with a user expression and rebuild the tables.

    Args:
        base_model: Model name bound to ``BASE_GENERATION`` inside the filter
            code. It is also used for the general stats and the detailed-table
            row keys.
        filtering_function: User-written filter. All lines except the last are
            shared statements and the last line is the condition. In
            ``FILES_FILTERING`` mode the condition may be split with ``&&``
            into several conditions, and an answer is kept only if all of
            them hold.
        apply_on_filtered_data: When False, the table data is first reset to
            a fresh copy of the loaded files, with per-model stats recomputed.
        models: Models shown as columns in the detailed table.
        filter_mode: ``FILES_FILTERING`` filters individual answers per model
            and drops questions where any model is left with no answers.
            Any other value filters whole question entries.

    Returns:
        Short info table + general stats + detailed table layouts.
    """
    clean_table_data = []
    if not apply_on_filtered_data:
        # Start again from the unfiltered data and refresh the per-model stats.
        clear_table_data()
        get_table_data().extend(custom_deepcopy(get_data_from_files()))
        for question_data in get_table_data():
            for model_id, files_data in question_data.items():
                stats = get_metrics(files_data)
                question_data[model_id] = [{**data, **stats} for data in files_data]

    errors_dict = {}
    if filtering_function:
        available_models = {
            model_name: model_info["file_paths"]
            for model_name, model_info in get_available_models().items()
        }
        # repr() quotes and escapes the name, so model names containing quotes
        # or backslashes no longer produce broken generated code.
        base_assignment = f"{BASE_GENERATION} = {base_model!r}\n"
        # NOTE: the filter text is executed via exec() in get_eval_function,
        # so it must come from a trusted user of this local tool.

        if filter_mode == FILES_FILTERING:
            filter_lines = filtering_function.strip().split("\n")
            common_expressions = "\n".join(filter_lines[:-1])
            # Compiled only in this mode. Splitting on "&&" would otherwise
            # raise for whole-question filters with "&&" inside a string literal.
            filtering_functions = [
                get_eval_function(
                    base_assignment + common_expressions + "\n" + single_filter
                )
                for single_filter in filter_lines[-1].split("&&")
            ]

            def passes_all_filters(model_id, file_dict):
                data = {model_id: file_dict}
                # A list rather than a generator, so every filter runs and
                # records its own errors.
                return all(
                    [
                        catch_eval_exception(
                            available_models,
                            filter_function,
                            data,
                            True,
                            errors_dict,
                        )
                        for filter_function in filtering_functions
                    ]
                )

            for question_data in get_table_data():
                good_data = True
                for model_id in question_data.keys():
                    kept = [
                        file_dict
                        for file_dict in question_data[model_id]
                        if passes_all_filters(model_id, file_dict)
                    ]
                    stats = get_metrics(kept)
                    question_data[model_id] = [{**data, **stats} for data in kept]
                    if not question_data[model_id]:
                        good_data = False
                if good_data:
                    clean_table_data.append(question_data)
        else:
            func = get_eval_function(base_assignment + filtering_function.strip())
            clean_table_data = [
                data
                for data in get_table_data()
                if catch_eval_exception(
                    available_models=[],
                    eval_func=func,
                    data=data,
                    default_answer=True,
                    errors_dict=errors_dict,
                )
            ]
        clear_table_data()
        get_table_data().extend(clean_table_data)
        if len(errors_dict):
            logging.error(ERROR_MESSAGE_TEMPLATE.format("filtering", errors_dict))

    short_info_layout = get_short_info_table_layout()
    general_stats_layout = get_general_stats_layout(base_model)
    table_data = get_table_data()
    row_keys = (
        table_data[0][base_model][0].keys()
        if len(table_data) and len(table_data[0][base_model])
        else []
    )
    return (
        short_info_layout
        + general_stats_layout
        + get_detailed_info_table_layout(
            models, list(filter(is_detailed_answers_rows_key, row_keys))
        )
    )
337 |
338 |
def get_tables_layout(base_model: str) -> List:
    """Build the analyze-page tables for ``base_model`` only.

    On first use (empty table data) the table data is filled with a copy of
    the data loaded from files. The detailed table gets a single column for
    ``base_model``. Its rows are the visible keys of the first answer of the
    first question.
    """
    if not get_table_data():
        get_table_data().extend(custom_deepcopy(get_data_from_files()))

    layouts = get_short_info_table_layout() + get_general_stats_layout(base_model)

    table_data = get_table_data()
    has_answers = len(table_data) and len(table_data[0][base_model])
    row_keys = table_data[0][base_model][0].keys() if has_answers else []
    detailed_layout = get_detailed_info_table_layout(
        [base_model], list(filter(is_detailed_answers_rows_key, row_keys))
    )
    return layouts + detailed_layout
359 |
--------------------------------------------------------------------------------
/nemo_inspector/layouts/analyze_page_layouts/modals_layouts.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
import os
from typing import List, Optional

import dash_bootstrap_components as dbc
from dash import html
from flask import current_app

from nemo_inspector.layouts.common_layouts import (
    get_selector_layout,
    get_switch_layout,
)
from nemo_inspector.layouts.analyze_page_layouts.utils import (
    get_code_text_area_layout,
    get_filter_text,
    get_stats_input,
)
from nemo_inspector.settings.constants import (
    DELETE,
    FILES_ONLY,
    FILES_FILTERING,
    GENERAL_STATS,
)
from nemo_inspector.settings.constants.configurations import STATS_KEYS
from nemo_inspector.utils.common import get_labels, get_metrics, get_table_data
39 |
40 |
def get_filter_modal_layout(
    id: int = -1,
    available_filters: Optional[List[str]] = None,
    mode: str = FILES_FILTERING,
) -> html.Div:
    """Build the "Filters" button and its modal for pattern id ``id``.

    Args:
        id: Value used in the pattern-matching ids of all components.
        available_filters: Names passed to ``get_filter_text`` for the help
            text. Defaults to an empty list. ``None`` replaces the former
            shared mutable ``[]`` default.
        mode: Filter mode passed to ``get_filter_text``. Unless it is
            ``FILES_ONLY``, a "filter files" switch is added to the header.

    Returns:
        A Div holding the trigger button and the (initially closed) modal.
    """
    if available_filters is None:
        available_filters = []
    text = get_filter_text(available_filters, mode)

    filter_mode = (
        [
            get_switch_layout(
                id={"type": "filter_mode", "id": id},
                labels=["filter files"],
                is_active=True,
                additional_params={
                    "inline": True,
                    "style": {"margin-left": "10px"},
                },
            )
        ]
        if mode != FILES_ONLY
        else []
    )

    header = dbc.ModalHeader(
        [dbc.ModalTitle("Set Up Your Filter")] + filter_mode,
        close_button=True,
    )
    body = dbc.ModalBody(
        html.Div(
            [
                html.Pre(text, id={"type": "filter_text", "id": id}),
                get_code_text_area_layout(
                    id={"type": "filter_function_input", "id": id},
                ),
            ]
        )
    )
    switch = get_switch_layout(
        {"type": "apply_on_filtered_data", "id": id},
        ["Apply for filtered data"],
        additional_params={"style": {"margin-left": "10px"}},
    )
    footer = dbc.ModalFooter(
        dbc.Button(
            "Apply",
            id={"type": "apply_filter_button", "id": id},
            className="ms-auto",
            n_clicks=0,
        )
    )
    return html.Div(
        [
            dbc.Button(
                "Filters",
                id={"type": "set_filter_button", "id": id},
                class_name="button-class",
            ),
            dbc.Modal(
                [header, body, switch, footer],
                size="lg",
                id={"type": "filter", "id": id},
                centered=True,
                is_open=False,
            ),
        ],
        style={"display": "inline-block"},
    )
124 |
125 |
def get_sorting_modal_layout(
    id: int = -1, available_params: Optional[List[str]] = None
) -> html.Div:
    """Build the "Sort" button and its modal for pattern id ``id``.

    Args:
        id: Value used in the pattern-matching ids of all components.
        available_params: Parameter names listed in the help text. When it is
            omitted or empty, the listed names come from the keys of the first
            answer in the table data. If there is no usable data, a generic
            list of stat keys is shown instead.

    Returns:
        A Div holding the trigger button and the (initially closed) modal.
    """
    if available_params:
        # Previously a caller-supplied list was ignored and replaced by the
        # fallback list; it is now honoured.
        available_params = list(available_params)
    else:
        table_data = get_table_data()
        first_question = table_data[0] if len(table_data) else {}
        first_answers = next(iter(first_question.values()), [])
        if first_answers:
            available_params = list(first_answers[0].keys())
        else:
            # No data, or an empty first answer list (which used to raise IndexError).
            available_params = list(
                STATS_KEYS + list(get_metrics([]).keys()) + ["+ all fields in json"]
            )

    text = (
        "Write an expression to sort the data\n\n"
        "For example: len(data['question'])\n\n"
        "The function has to return sortable type\n\n"
        "Available parameters to sort data:\n"
        + "\n".join(
            [
                ", ".join(available_params[start : start + 5])
                for start in range(0, len(available_params), 5)
            ]
        )
    )
    header = dbc.ModalHeader(
        dbc.ModalTitle("Set Up Your Sorting Parameters"),
        close_button=True,
    )
    body = dbc.ModalBody(
        html.Div(
            [
                html.Pre(text),
                get_code_text_area_layout(
                    id={"type": "sorting_function_input", "id": id},
                ),
            ],
        )
    )
    footer = dbc.ModalFooter(
        dbc.Button(
            "Apply",
            id={"type": "apply_sorting_button", "id": id},
            className="ms-auto",
            n_clicks=0,
        )
    )
    return html.Div(
        [
            dbc.Button(
                "Sort",
                id={"type": "set_sorting_button", "id": id},
                class_name="button-class",
            ),
            dbc.Modal(
                [header, body, footer],
                size="lg",
                id={"type": "sorting", "id": id},
                centered=True,
                is_open=False,
            ),
        ],
        style={"display": "inline-block"},
    )
190 |
191 |
def get_update_dataset_modal_layout() -> html.Div:
    """Build the "Update dataset" button plus its modal.

    The modal shows a help text, a code input (id ``update_dataset_input``) and
    an "Apply" button (id ``apply_update_dataset_button``).
    """
    help_text = (
        "Write an expression to modify the data\n\n"
        "For example: {**data, 'generation': data['generation'].strip()}\n\n"
        "The function has to return a new dict"
    )
    code_input = get_code_text_area_layout(id="update_dataset_input")
    apply_button = dbc.Button(
        "Apply",
        id="apply_update_dataset_button",
        className="ms-auto",
        n_clicks=0,
    )
    modal_parts = [
        dbc.ModalHeader(dbc.ModalTitle("Update Dataset"), close_button=True),
        dbc.ModalBody(html.Div([html.Pre(help_text), code_input])),
        dbc.ModalFooter(apply_button),
    ]
    trigger = dbc.Button(
        "Update dataset", id="update_dataset_button", class_name="button-class"
    )
    modal = dbc.Modal(
        modal_parts,
        size="lg",
        id="update_dataset_modal",
        centered=True,
        is_open=False,
    )
    return html.Div([trigger, modal], style={"display": "inline-block"})
241 |
242 |
def get_change_label_modal_layout(
    id: int = -1, apply_for_all_files: bool = True
) -> html.Div:
    """Build the "Labels" button and its label-management modal.

    Args:
        id: Value used in the pattern-matching ids of all components.
        apply_for_all_files: When True, an "Apply for all files" switch is
            added to the modal body.

    Returns:
        A Div holding the trigger button and the (initially closed) modal.
    """

    def pattern_id(kind):
        return {"type": kind, "id": id}

    header = dbc.ModalHeader(dbc.ModalTitle("Manage labels"), close_button=True)

    switch_layout = []
    if apply_for_all_files:
        # NOTE: "aplly" is the id the callbacks listen to; do not rename it here.
        switch_layout.append(
            get_switch_layout(
                pattern_id("aplly_for_all_files"),
                ["Apply for all files"],
                additional_params={"style": {"margin-left": "10px"}},
            )
        )

    label_selector = get_selector_layout(
        options=get_labels(),
        id=pattern_id("label_selector"),
        value="choose label",
    )
    new_label_group = dbc.InputGroup(
        [
            dbc.Input(
                id=pattern_id("new_label_input"),
                placeholder="Enter new label",
                type="text",
            ),
            dbc.Button("Add", id=pattern_id("add_new_label_button")),
        ]
    )
    body = dbc.ModalBody(
        html.Div(
            [
                label_selector,
                new_label_group,
                *switch_layout,
                html.Pre("", id=pattern_id("chosen_label")),
            ],
        )
    )

    footer_buttons = html.Div(
        [
            dbc.Button(
                children="Delete",
                id=pattern_id("delete_label_button"),
                className="ms-auto",
                n_clicks=0,
            ),
            html.Pre(" ", style={"display": "inline-block", "font-size": "5px"}),
            dbc.Button(
                children="Apply",
                id=pattern_id("apply_label_button"),
                className="ms-auto",
                n_clicks=0,
            ),
        ],
    )
    footer = dbc.ModalFooter(footer_buttons, style={"display": "inline-block"})

    trigger = dbc.Button(
        "Labels", id=pattern_id("set_file_label_button"), class_name="button-class"
    )
    modal = dbc.Modal(
        [header, body, footer],
        size="lg",
        id=pattern_id("label"),
        centered=True,
        is_open=False,
    )
    return html.Div([trigger, modal], style={"display": "inline-block"})
339 |
340 |
def get_save_dataset_modal_layout() -> html.Div:
    """Build the "Save dataset" button and its modal with a save-path input.

    The path input (id ``save_path``) is pre-filled with
    ``<save_generations_path>/default_name``, where ``save_generations_path``
    is read from the ``nemo_inspector`` app config.
    """
    default_save_path = os.path.join(
        current_app.config["nemo_inspector"]["save_generations_path"],
        "default_name",
    )
    path_group = dbc.InputGroup(
        [
            dbc.InputGroupText("save_path"),
            dbc.Input(value=default_save_path, id="save_path", type="text"),
        ],
        className="mb-3",
    )
    save_button = dbc.Button(
        "Save",
        id="save_dataset_button",
        className="ms-auto",
        n_clicks=0,
    )
    modal = dbc.Modal(
        [
            dbc.ModalBody([path_group, dbc.Container(id="error_message")]),
            dbc.ModalFooter(save_button),
        ],
        id="save_dataset_modal",
        is_open=False,
        style={
            "text-align": "center",
            "margin-top": "10px",
            "margin-bottom": "10px",
        },
    )
    trigger = dbc.Button("Save dataset", id="save_dataset", class_name="button-class")
    return html.Div([trigger, modal])
385 |
386 |
def get_add_stats_modal_layout() -> html.Div:
    """Build the "Stats" button and its custom-statistics modal.

    The header carries a ``stats_modes`` switch ("general stats" /
    "delete mode"). The body wraps ``get_stats_input()`` in
    ``stats_input_container``, and the footer has the ``apply_new_stats`` button.
    """
    mode_switch = get_switch_layout(
        id="stats_modes",
        labels=["general stats", "delete mode"],
        values=[GENERAL_STATS, DELETE],
        additional_params={"inline": True, "style": {"margin-left": "10px"}},
    )
    modal_header = dbc.ModalHeader(
        [dbc.ModalTitle("Set Up Your Stats"), mode_switch],
        close_button=True,
    )
    modal_body = dbc.ModalBody(
        html.Div(get_stats_input(), id="stats_input_container")
    )
    apply_button = dbc.Button(
        "Apply",
        id="apply_new_stats",
        className="ms-auto",
        n_clicks=0,
    )
    modal = dbc.Modal(
        [modal_header, modal_body, dbc.ModalFooter(apply_button)],
        size="lg",
        id="new_stats",
        centered=True,
        is_open=False,
    )
    trigger = dbc.Button(
        "Stats", id="set_new_stats_button", class_name="button-class"
    )
    return html.Div([trigger, modal], style={"display": "inline-block"})
435 |
--------------------------------------------------------------------------------
/nemo_inspector/utils/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import functools
16 | import json
17 | import logging
18 | import os
19 | import re
20 | from collections import defaultdict
21 | from typing import Callable, Dict, List, Optional, Set, Tuple, Union, get_origin
22 |
23 | from flask import current_app
24 | from joblib import Parallel, delayed
25 |
26 | from nemo_inspector.settings.constants import (
27 | EXPECTED_ANSWER_FIELD,
28 | CUSTOM,
29 | ERROR_MESSAGE_TEMPLATE,
30 | FILE_NAME,
31 | GENERAL_STATS,
32 | INLINE_STATS,
33 | QUESTION_FIELD,
34 | STATS_KEYS,
35 | UNDEFINED,
36 | )
37 |
# Module-level UI state. Other modules reach it through the getter functions
# below, which return these same objects (not copies).

# stat name -> callable applied to a list of answers (iterated in get_metrics).
custom_stats = {}
# Presumably dataset-wide ("general") custom stats; populated elsewhere — TODO confirm.
general_custom_stats = {}
# Keys hidden from the detailed table (checked in is_detailed_answers_rows_key).
deleted_stats = set()
# Row-level UI state; mutated by callbacks outside this view.
excluded_rows = set()
editable_rows = set()
compared_rows = set()
# Raw source text of custom stats per mode (INLINE_STATS / GENERAL_STATS).
stats_raw = {INLINE_STATS: {CUSTOM: ""}, GENERAL_STATS: {CUSTOM: ""}}

# Current table data: one dict per question mapping model name -> list of
# answer dicts (the shape built by get_data_from_files).
dataset_data = []
labels = []
48 |
49 |
def get_editable_rows() -> Set:
    """Return the shared module-level ``editable_rows`` set (not a copy)."""
    shared_rows = editable_rows
    return shared_rows
52 |
53 |
def get_excluded_row() -> Set:
    """Return the shared module-level ``excluded_rows`` set (not a copy)."""
    shared_rows = excluded_rows
    return shared_rows
56 |
57 |
def get_deleted_stats() -> Set:
    """Return the shared set of keys hidden from the detailed answers table."""
    hidden_keys = deleted_stats
    return hidden_keys
60 |
61 |
def get_custom_stats() -> Dict:
    """Return the shared mapping of custom stat name -> stat function."""
    registry = custom_stats
    return registry
64 |
65 |
def get_compared_rows() -> Set:
    """Return the shared module-level ``compared_rows`` set (not a copy).

    The return annotation used to say ``Dict``, but ``compared_rows`` is a set.
    """
    return compared_rows
68 |
69 |
def get_general_custom_stats() -> Dict:
    """Return the shared module-level ``general_custom_stats`` dict (not a copy)."""
    registry = general_custom_stats
    return registry
72 |
73 |
def get_stats_raw() -> Dict:
    """Return the shared dict of raw custom-stat source text, keyed by stats mode."""
    raw_sources = stats_raw
    return raw_sources
76 |
77 |
def clear_table_data() -> None:
    """Reset the table data to a brand-new empty list.

    The module global is rebound rather than cleared in place. A list obtained
    earlier from ``get_table_data()`` therefore keeps its contents, so call
    ``get_table_data()`` again after clearing.
    """
    global dataset_data
    dataset_data = list()
81 |
82 |
def get_table_data() -> List:
    """Return the current module-level table data list (the live object, not a copy)."""
    current = dataset_data
    return current
85 |
86 |
def get_labels() -> List:
    """Return the shared module-level list of labels (not a copy)."""
    known_labels = labels
    return known_labels
89 |
90 |
def _split_model_answer(
    answer: str,
    code_separators: Tuple[str, str],
    code_output_separators: Tuple[str, str],
) -> List[Dict]:
    """Split ``answer`` into explanation/code/output segments using the given separators."""
    raw_code_start, raw_code_end = code_separators
    raw_output_start, raw_output_end = code_output_separators
    code_pattern = re.compile(
        rf"{re.escape(raw_code_start)}(.*?){re.escape(raw_code_end)}", re.DOTALL
    )
    # Matched only right after a code block, so an output is attributed to the
    # code block it immediately follows. The old whole-text search could span
    # from an output-less block into the next block's output.
    output_pattern = re.compile(
        rf"\s*{re.escape(raw_output_start)}(.*?){re.escape(raw_output_end)}",
        re.DOTALL,
    )

    parsed_results = []
    last_index = 0
    while True:
        code_match = code_pattern.search(answer, last_index)
        if code_match is None:
            break
        output_match = output_pattern.match(answer, code_match.end())
        parsed_results.append(
            {
                "explanation": answer[last_index : code_match.start()].strip(),
                "code": code_match.group(1).strip(),
                "output": output_match.group(1).strip() if output_match else None,
            }
        )
        next_index = output_match.end() if output_match else code_match.end()
        if next_index <= last_index:
            # Zero-width match (e.g. empty separators): stop instead of looping.
            break
        last_index = next_index

    if last_index < len(answer):
        trailing_text = answer[last_index:].strip()
        if raw_code_start and raw_code_start in trailing_text:
            # A code block was opened but never closed.
            code_start_index = trailing_text.find(raw_code_start)
            parsed_results.append(
                {
                    "explanation": trailing_text[0:code_start_index].strip(),
                    "code": trailing_text[code_start_index + len(raw_code_start) :],
                    "output": "code_block was not finished",
                    "wrong_code_block": True,
                }
            )
            trailing_text = None
        if trailing_text:
            parsed_results.append(
                {"explanation": trailing_text, "code": None, "output": None}
            )
    return parsed_results


def parse_model_answer(answer: str) -> List[Dict]:
    """
    Parses a model answer and extracts code blocks, explanations, and outputs preserving their sequence.

    The separators are read from ``current_app.config["nemo_inspector"]``
    (``code_separators`` and ``code_output_separators``, each a start/end pair).

    Args:
        answer (str): The model answer to parse.

    Returns:
        List[Dict]: A list of dictionaries containing the parsed results. Each dictionary
        contains the following keys:
            - 'explanation': The explanation text before the code block.
            - 'code': The code block (None for a trailing text-only segment).
            - 'output': The output of the code block, or None if it has none.
        An unterminated trailing code block is reported with
        ``"wrong_code_block": True``.
    """
    config = current_app.config["nemo_inspector"]
    return _split_model_answer(
        answer, config["code_separators"], config["code_output_separators"]
    )
161 |
162 |
@functools.lru_cache()
def get_dataset_sample(index: int, dataset: str) -> Tuple[Dict, int]:
    """Return the ``index``-th (1-based) JSON line of ``dataset`` and the clamped index.

    ``index`` is clamped to ``[1, number_of_lines]``. An empty placeholder
    sample with index 0 is returned when ``dataset`` is empty, ``UNDEFINED``,
    not a file, or an empty file. The empty-file case previously raised
    IndexError.

    NOTE: results are memoised with ``lru_cache``. Later changes to the file
    are not seen, and callers share (and must not mutate) the returned dict.
    """
    empty_sample = {QUESTION_FIELD: "", EXPECTED_ANSWER_FIELD: ""}
    if not dataset or dataset == UNDEFINED or not os.path.isfile(dataset):
        return empty_sample, 0
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform default.
    with open(dataset, encoding="utf-8") as file:
        tests = file.readlines()
    if not tests:
        return empty_sample, 0
    index = max(min(len(tests), index), 1)
    return json.loads(tests[index - 1]), index
176 |
177 |
def get_stats(all_files_data: List[Dict]) -> Tuple[float, float, float]:
    """Return the fractions of correct, wrong and no-response answers.

    An answer counts as "no response" when its ``predicted_answer`` is
    missing or None. Otherwise it counts as correct when ``is_correct`` is
    truthy, and as wrong otherwise. If no data is provided, returns -1 for all
    values.
    """
    tally = {"correct": 0, "wrong": 0, "none": 0}
    for entry in all_files_data:
        if entry.get("predicted_answer") is None:
            bucket = "none"
        elif entry.get("is_correct", False):
            bucket = "correct"
        else:
            bucket = "wrong"
        tally[bucket] += 1

    total = len(all_files_data)
    if not total:
        return -1, -1, -1
    return tally["correct"] / total, tally["wrong"] / total, tally["none"] / total
201 |
202 |
def get_metrics(all_files_data: List[Dict], errors_dict: Optional[Dict] = None) -> Dict:
    """Compute built-in and custom statistics for a list of answers.

    Args:
        all_files_data: Answer dicts (one model, one question).
        errors_dict: Optional sink for evaluation errors, filled as
            ``{stat_name: {error_message: count}}``. A fresh dict is used
            when omitted. The previous ``{}`` default was shared across all
            calls and kept accumulating counts.

    Returns:
        ``correct_responses``, ``wrong_responses`` and ``no_response``
        rounded to 2 decimals, followed by one entry per custom stat. Failing
        custom stats contribute the string "Got error when applying function".
    """
    if errors_dict is None:
        errors_dict = {}
    correct_responses, wrong_responses, no_response = get_stats(all_files_data)

    custom_values = {}
    for name, func in get_custom_stats().items():
        custom_values[name] = catch_eval_exception(
            [],
            func,
            all_files_data,
            "Got error when applying function",
            errors_dict.setdefault(name, {}),
        )

    return {
        "correct_responses": round(correct_responses, 2),
        "wrong_responses": round(wrong_responses, 2),
        "no_response": round(no_response, 2),
        **custom_values,
    }
224 |
225 |
def get_eval_function(text):
    """Compile a user-written snippet into a one-argument function of ``data``.

    After stripping, every line of ``text`` except the last becomes a statement
    of the function body. The last line becomes the ``return`` expression.
    Relative indentation inside the snippet is preserved.

    NOTE: the snippet is executed with ``exec``; only pass trusted input.
    """
    *statements, result_expression = text.strip().split("\n")
    source_lines = ["", "def eval_function(data):", "    "]
    source_lines.extend("    " + statement for statement in statements)
    source_lines.append("    return " + result_expression)
    source_lines.append("")
    namespace = {}
    exec("\n".join(source_lines), namespace)
    return namespace["eval_function"]
240 |
241 |
def calculate_metrics_for_whole_data(table_data: List, model_id: str) -> None:
    """Recompute ``model_id``'s stats for every question of ``table_data`` in place.

    Each question's answer list for ``model_id`` is replaced by copies of the
    answer dicts merged with that list's metrics. Custom-stat errors are
    collected across all questions and logged once per stat. Nothing is
    returned; the old ``-> Dict`` annotation was wrong.

    Raises:
        KeyError: if a question has no entry for ``model_id``.
    """
    errors_dict = {}
    for question_data in table_data:
        answers = question_data[model_id]
        stats = get_metrics(answers, errors_dict)
        question_data[model_id] = [{**data, **stats} for data in answers]
    for name, error_dict in errors_dict.items():
        if len(error_dict):
            logging.error(ERROR_MESSAGE_TEMPLATE.format(name, error_dict))
256 |
257 |
def catch_eval_exception(
    available_models: List[str],
    eval_func: Callable[[Dict], bool],
    data: Dict,
    default_answer: Union[bool, str],
    errors_dict: Optional[Dict] = None,
) -> bool:
    """Evaluate ``eval_func(data)`` and return ``default_answer`` on failure.

    Args:
        available_models: Model names, or any container supporting ``in``.
            An error is not counted when the last space-separated token of its
            message, with quotes removed, is in this container. For example,
            ``str(KeyError('m'))`` is ``"'m'"``, so a KeyError for a known
            model name is ignored.
        eval_func: Callable to evaluate; None yields ``default_answer``.
        data: Argument passed to ``eval_func``.
        default_answer: Value returned when ``eval_func`` is None or raises.
        errors_dict: Optional ``{error_message: count}`` sink. When omitted,
            errors go to a throwaway dict. The previous ``{}`` default was
            shared by every call and grew without bound.

    Returns:
        Whatever ``eval_func`` returns, or ``default_answer``.
    """
    if errors_dict is None:
        errors_dict = {}
    if eval_func is None:
        return default_answer
    try:
        return eval_func(data)
    except Exception as e:  # user-written expressions may raise anything
        message = str(e)
        if message.split(" ")[-1].replace("'", "") not in available_models:
            errors_dict[message] = errors_dict.get(message, 0) + 1
        return default_answer
275 |
276 |
def custom_deepcopy(data) -> List:
    """Copy a list of ``{key: list_of_answers}`` dicts two levels deep.

    The outer list, each dict and each list value are new objects. The answer
    dicts inside the lists are still shared. Previously the lists themselves
    were shared, so in-place operations such as ``list.sort`` on the copy also
    reordered the ``lru_cache``-d source data returned by
    ``get_data_from_files``. Non-list values are kept as-is.
    """
    return [
        {
            key: list(value) if isinstance(value, list) else value
            for key, value in item.items()
        }
        for item in data
    ]
285 |
286 |
@functools.lru_cache(maxsize=1)
def get_data_from_files() -> List:
    """Load every model's prediction files into a per-question structure.

    Returns a list indexed by question position. Each entry maps model name
    to the list of answer dicts for that question, one per prediction file
    line at that position. Every answer dict is merged with the matching
    ``input_file`` dataset row, bookkeeping fields (file name, question/page
    index, empty labels) and the metrics from ``get_metrics``.

    NOTE: cached with ``lru_cache(maxsize=1)``, so all callers share the same
    object. Callers copy it (``custom_deepcopy``) before modifying.
    """
    base_config = current_app.config["nemo_inspector"]
    dataset = None
    if os.path.isfile(base_config["input_file"]):
        # One JSON object per line; NOTE(review): blank lines would raise here.
        with open(base_config["input_file"]) as f:
            dataset = [json.loads(line) for line in f]

    available_models = {
        model_name: model_info["file_paths"]
        for model_name, model_info in get_available_models().items()
    }

    all_models_data_array = []

    def process_model_files(model_id, results_files, dataset):
        # Returns (model_id, {question_index: [answer dict per file]}).
        model_data = defaultdict(list)
        file_names = {}
        for file_id, path in enumerate(results_files):
            # Basename up to the first "."; repeated names get "_2", "_3", ...
            file_name = path.split("/")[-1].split(".")[0]
            if file_name in file_names:
                file_names[file_name] += 1
                file_name += f"_{file_names[file_name]}"
            else:
                file_names[file_name] = 1
            with open(path) as f:
                answers = map(json.loads, f)
                for question_index, answer in enumerate(answers):
                    # Later keys win: answer fields override dataset fields
                    # and the bookkeeping fields listed before them.
                    result = {
                        FILE_NAME: file_name,
                        **(
                            dataset[question_index]
                            if dataset and len(dataset) > question_index
                            else {}
                        ),
                        "question_index": question_index + 1,
                        "page_index": file_id,
                        "labels": [],
                        **answer,
                    }
                    model_data[question_index].append(result)
        return model_id, model_data

    # joblib: n_jobs=-1 uses all available CPUs; one job per model.
    num_cores = -1
    model_data_list = Parallel(n_jobs=num_cores)(
        delayed(process_model_files)(model_id, results_files, dataset)
        for model_id, results_files in available_models.items()
    )

    for model_id, model_data in model_data_list:
        for question_index, results in model_data.items():
            # Question indices arrive in ascending order, so appending keeps
            # list position == question_index.
            if len(all_models_data_array) <= question_index:
                all_models_data_array.append({})
            all_models_data_array[question_index][model_id] = results
            stats = get_metrics(all_models_data_array[question_index][model_id])
            all_models_data_array[question_index][model_id] = list(
                map(
                    lambda data: {**data, **stats},
                    all_models_data_array[question_index][model_id],
                )
            )
    return all_models_data_array
349 |
350 |
def get_filtered_files(
    filter_function: str,
    sorting_function: str,
    array_to_filter: List,
) -> List:
    """Filter, and optionally sort, a flat list of answer dicts.

    ``filter_function`` (``"True"`` when empty) is split on ``"&&"``. Each
    piece is compiled and applied independently to ``array_to_filter``. The
    result is the first piece's output that is non-empty.
    NOTE(review): the pieces are not combined with AND; confirm this is intended.

    When nothing matches, the placeholder ``[{FILE_NAME: ""}]`` is returned.
    Otherwise the result is sorted in place by ``sorting_function``, if one is
    given, with key 0 on evaluation errors.
    """
    expressions = (filter_function or "True").split("&&")
    predicates = [get_eval_function(expression.strip()) for expression in expressions]
    available_models = get_available_models()

    non_empty_results = []
    for predicate in predicates:
        matched = [
            item
            for item in array_to_filter
            if catch_eval_exception(available_models, predicate, item, False)
        ]
        if matched:
            non_empty_results.append(matched)

    placeholder = [{FILE_NAME: ""}]
    filtered_data = non_empty_results[0] if non_empty_results else placeholder
    if sorting_function and filtered_data != placeholder:
        sort_key_function = get_eval_function(sorting_function.strip())
        filtered_data.sort(
            key=lambda item: catch_eval_exception(
                available_models, sort_key_function, item, 0
            )
        )

    return filtered_data
384 |
385 |
def is_detailed_answers_rows_key(key: str) -> bool:
    """Tell whether ``key`` should be shown as a row of the detailed answers table.

    ``QUESTION_FIELD`` is always shown. Any other key is hidden when it was
    deleted by the user, contains the substring ``"index"``, or is one of the
    stats/metric keys.
    """
    if key in get_deleted_stats() or "index" in key:
        return key == QUESTION_FIELD
    stats_like_keys = STATS_KEYS + list(get_metrics([]).keys())
    return key not in stats_like_keys or key == QUESTION_FIELD
393 |
394 |
@functools.lru_cache(maxsize=1)
def get_available_models() -> Dict:
    """Return ``{model_name: {"file_paths": files}}`` built from the app config.

    ``files`` is taken as-is from ``config["nemo_inspector"]["model_prediction"]``.
    The result is cached, so every call returns the same dict.
    """
    model_prediction = current_app.config["nemo_inspector"]["model_prediction"]
    return {
        model_name: {"file_paths": files}
        for model_name, files in model_prediction.items()
    }
405 |
406 |
def get_file_id(file_names: List[str], files: List[Dict], column_id: str):
    """Return the position in ``files`` of the file selected for ``column_id``.

    ``file_names[column_id]`` is either a plain name or a dict holding it under
    ``"value"``. Returns the index of the first entry whose ``FILE_NAME`` equals
    that name, or 0 when none does.
    """
    selection = file_names[column_id]
    wanted_name = selection["value"] if isinstance(selection, dict) else selection
    return next(
        (
            position
            for position, file_data in enumerate(files)
            if file_data[FILE_NAME] == wanted_name
        ),
        0,
    )
419 |
420 |
def resolve_type(field_type):
    """Map a type hint to something callable as a constructor.

    A parameterised generic returns its origin (``List[int]`` -> ``list``).
    A plain class is returned unchanged. Anything else falls back to ``str``.
    """
    generic_origin = get_origin(field_type)
    if generic_origin is None:
        return field_type if isinstance(field_type, type) else str
    return generic_origin
431 |
432 |
def get_type_default(field_type):
    """Return ``resolve_type(field_type)()``, or None if resolving or calling it fails."""
    try:
        constructor = resolve_type(field_type)
        default_value = constructor()
    except Exception:
        default_value = None
    return default_value
438 |
439 |
def resolve_union_or_any(field_type):
    """Resolve a type hint via ``resolve_type``.

    NOTE(review): despite the name, Union hints are not narrowed. For them,
    ``resolve_type`` returns the ``get_origin`` result (the Union form itself).
    """
    resolved = resolve_type(field_type)
    return resolved
443 |
--------------------------------------------------------------------------------
/nemo_inspector/layouts/analyze_page_layouts/table_layouts.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import json
16 | import logging
17 | import math
18 | from typing import Dict, List
19 |
20 |
21 | import dash_bootstrap_components as dbc
22 | from dash import dash_table, html
23 |
24 | from nemo_inspector.layouts.analyze_page_layouts.modals_layouts import (
25 | get_change_label_modal_layout,
26 | get_filter_modal_layout,
27 | get_sorting_modal_layout,
28 | )
29 | from nemo_inspector.layouts.common_layouts import (
30 | get_selector_layout,
31 | get_single_prompt_output_layout,
32 | get_text_modes_layout,
33 | )
34 | from nemo_inspector.settings.constants import (
35 | ANSI,
36 | CODE,
37 | COMPARE,
38 | COMPARE_ICON_PATH,
39 | DATA_PAGE_SIZE,
40 | EDIT_ICON_PATH,
41 | ERROR_MESSAGE_TEMPLATE,
42 | FILE_NAME,
43 | FILES_ONLY,
44 | LABEL,
45 | LATEX,
46 | MODEL_SELECTOR_ID,
47 | STATS_KEYS,
48 | )
49 | from nemo_inspector.utils.common import (
50 | catch_eval_exception,
51 | get_available_models,
52 | get_compared_rows,
53 | get_editable_rows,
54 | get_excluded_row,
55 | get_filtered_files,
56 | get_general_custom_stats,
57 | get_metrics,
58 | get_table_data,
59 | is_detailed_answers_rows_key,
60 | )
61 |
62 |
def get_short_info_table_layout() -> List[dbc.Row]:
    """Build the one-row layout holding the short info ``datatable``.

    Columns are ``STATS_KEYS`` plus the metric names from ``get_metrics([])``.
    Paging is custom (``page_action="custom"``), with the page count derived
    from the number of questions in the table data.
    """
    column_names = STATS_KEYS + list(get_metrics([]).keys())
    columns = [{"name": name, "id": name, "hideable": True} for name in column_names]
    page_count = math.ceil(len(get_table_data()) / DATA_PAGE_SIZE)
    table_css = [
        {
            "selector": ".dash-spreadsheet-menu",
            "rule": "position:absolute; bottom: 8px",
        },
        {"selector": ".dash-filter--case", "rule": "display: none"},
        {"selector": ".column-header--hide", "rule": "display: none"},
    ]
    table = dash_table.DataTable(
        id="datatable",
        columns=columns,
        row_selectable="single",
        cell_selectable=False,
        page_action="custom",
        page_current=0,
        page_size=DATA_PAGE_SIZE,
        page_count=page_count,
        style_cell={
            "overflow": "hidden",
            "textOverflow": "ellipsis",
            "maxWidth": 0,
            "textAlign": "center",
        },
        style_header={
            "color": "text-primary",
            "text_align": "center",
            "height": "auto",
            "whiteSpace": "normal",
        },
        css=table_css,
    )
    return [dbc.Row(dbc.Col(table))]
113 |
114 |
def get_table_column_header(
    models: List[str], name: str, id: int, add_del_button: bool = False
) -> dbc.Col:
    """Build the header cell of one model column in the detailed table.

    The cell contains the model selector and the sort/filter/label modals.
    An optional "-" remove button follows, then the text-mode controls.
    """
    controls = [
        html.Div(
            get_selector_layout(
                models,
                json.loads(MODEL_SELECTOR_ID.format(id)),
                name,
            ),
        ),
        get_sorting_modal_layout(id),
        get_filter_modal_layout(id, mode=FILES_ONLY),
        get_change_label_modal_layout(id),
    ]
    if add_del_button:
        controls.append(
            dbc.Button(
                "-",
                id={"type": "del_model", "id": id},
                outline=True,
                color="primary",
                className="me-1",
                style={"height": "40px"},
            )
        )
    controls.append(get_text_modes_layout(id))

    return dbc.Col(
        html.Div(controls, style={"display": "inline-flex"}),
        class_name="mt-1 bg-light font-monospace text-break small rounded border",
        id={"type": "column_header", "id": id},
    )
153 |
154 |
def get_table_header(models: List[str]) -> List[dbc.Row]:
    """Build the detailed-table header: an empty first cell plus one cell per model.

    Every model column except the first gets a remove button.
    """
    first_column = dbc.Col(
        html.Div(""),
        width=2,
        class_name="mt-1 bg-light font-monospace text-break small rounded border",
        id="first_column",
    )
    model_columns = [
        get_table_column_header(get_available_models(), model_name, position, position != 0)
        for position, model_name in enumerate(models)
    ]
    return [dbc.Row([first_column, *model_columns], id="detailed_answers_header")]
175 |
176 |
def get_detailed_info_table_column(id: int, file_id=None) -> dbc.Col:
    """Build one answer cell of the detailed-info table.

    Args:
        id: Id of the inner ``detailed_models_answers`` div.
        file_id: When not None, the cell is pre-filled with an empty file
            selector whose id is ``{"type": "file_selector", "id": file_id}``.
            Otherwise the cell starts empty.
    """
    if file_id is None:
        cell_content = ""
    else:
        cell_content = get_selector_layout(
            [], {"type": "file_selector", "id": file_id}, ""
        )

    answers_div = html.Div(
        children=cell_content,
        id={"type": "detailed_models_answers", "id": id},
    )
    return dbc.Col(
        answers_div,
        class_name="mt-1 bg-light font-monospace text-break small rounded border",
    )
192 |
193 |
def _get_row_name_column(key: str, row_id: int) -> dbc.Col:
    """Build the leftmost cell of a detailed-info row: row name plus its buttons."""
    row_name = html.Div(
        key,
        id={"type": "row_name", "id": row_id},
        style={"display": "inline-block"},
    )
    edit_button = dbc.Button(
        html.Img(
            src=EDIT_ICON_PATH,
            id={"type": "edit_row_image", "id": row_id},
            style={
                "height": "15px",
                "display": "inline-block",
            },
        ),
        id={"type": "edit_row_button", "id": row_id},
        outline=True,
        color="primary",
        className="me-1",
        style={
            "border": "none",
            "line-height": "1.2",
            "margin-left": "1px",
            # The file-name and label rows have no edit button.
            "display": "none" if key in (FILE_NAME, LABEL) else "inline-block",
        },
    )
    compare_button = dbc.Button(
        html.Img(
            src=COMPARE_ICON_PATH,
            id={"type": "compare_texts", "id": row_id},
            style={
                "height": "15px",
                "display": "inline-block",
            },
        ),
        id={"type": "compare_texts_button", "id": row_id},
        outline=True,
        color="primary",
        className="me-1",
        style={
            "border": "none",
            "line-height": "1.2",
            # Presumably compensates the preceding button's spacing; on the
            # label row the edit button is hidden, so no negative offset.
            "margin-left": "-10px" if key != LABEL else "1px",
            "display": "none" if key == FILE_NAME else "inline-block",
        },
    )
    delete_button = dbc.Button(
        "-",
        id={"type": "del_row", "id": row_id},
        outline=True,
        color="primary",
        className="me-1",
        style={
            "border": "none",
            "display": "inline-block",
            # On the file-name row both icon buttons are hidden.
            "margin-left": "-9px" if key != FILE_NAME else "1px",
        },
    )
    return dbc.Col(
        html.Div(
            html.Div(
                [row_name, edit_button, compare_button, delete_button],
                style={"display": "inline-block"},
            ),
        ),
        width=2,
        class_name="mt-1 bg-light font-monospace text-break small rounded border",
    )


def get_detailed_info_table_rows(keys: List[str], colums_number: int) -> List[dbc.Row]:
    """Build one table row per key.

    Each row holds a name cell (with its edit / compare / delete buttons)
    followed by ``colums_number`` answer cells.

    Answer cells are numbered column by column: the cell in column ``j`` of
    row ``i`` gets id ``j * len(keys) + i``.

    Args:
        keys: Row names, in display order. Row ``i`` gets id ``i``.
        colums_number: Number of model columns.

    Returns:
        The list of ``dbc.Row`` objects with ids
        ``{"type": "detailed_answers_row", "id": i}``.
    """
    rows_count = len(keys)
    return [
        dbc.Row(
            [_get_row_name_column(key, i)]
            + [
                get_detailed_info_table_column(j * rows_count + i)
                for j in range(colums_number)
            ],
            id={"type": "detailed_answers_row", "id": i},
        )
        for i, key in enumerate(keys)
    ]
285 |
286 |
def get_detailed_info_table_layout(
    models: List[str],
    keys: List[str],
) -> List[dbc.Row]:
    """Assemble the detailed-info table: the header row, then one row per key.

    Args:
        models: Model shown in each column; its length sets the column count.
        keys: Row names.
    """
    header_rows = get_table_header(models)
    body_rows = get_detailed_info_table_rows(keys, len(models))
    return header_rows + body_rows
292 |
293 |
def get_detailed_info_table_row_content(
    question_id: int,
    model: str,
    rows_names: List[str],
    files_names: List[str],
    file_id: int,
    col_id: int,
    compare_to: Dict = None,
    text_modes: List[str] = None,
) -> List:
    """Build the cell contents of one model column for one question.

    Only rows accepted by ``is_detailed_answers_rows_key`` produce a cell.

    Per row:
    - Excluded rows, and every row when ``file_id`` is out of range, are empty.
    - The ``FILE_NAME`` row holds a file selector.
    - Other rows are empty when the selected file is not among ``files_names``.
    - Editable rows hold the raw string value; all other rows hold a rendered
      output layout.
    - Editable rows are finally wrapped in a ``dbc.Textarea``.

    Args:
        question_id: Index of the question in the table data.
        model: Model whose generations fill the column.
        rows_names: Names of the table rows, in order.
        files_names: File names currently offered by this column's selector.
        file_id: Index of the selected file in the model's data. Negative or
            too-large values render an empty column.
        col_id: Column position, used as the file selector's id.
        compare_to: Row values to diff against for compared rows.
            Defaults to ``{}``.
        text_modes: Rendering modes. Defaults to ``[CODE, LATEX, ANSI]``.

    Returns:
        One cell per displayed row.
    """
    if compare_to is None:
        compare_to = {}
    if text_modes is None:
        text_modes = [CODE, LATEX, ANSI]

    table_data = get_table_data()[question_id].get(model, [])
    # Check the bounds before touching table_data[file_id]. Indexing first would
    # raise on an empty list and silently wrap for negative ids.
    has_file = 0 <= file_id < len(table_data)
    current_file = table_data[file_id] if has_file else {}
    # The selected file has been filtered out of the selector's options.
    empty_list = not has_file or current_file.get(FILE_NAME) not in files_names

    row_data = []
    for key in filter(is_detailed_answers_rows_key, rows_names):
        if not has_file or key in get_excluded_row():
            value = ""
        elif key == FILE_NAME:
            value = get_selector_layout(
                files_names,
                {"type": "file_selector", "id": col_id},
                current_file.get(key) if not empty_list else "",
            )
        elif empty_list:
            value = ""
        elif key in get_editable_rows():
            value = str(current_file.get(key))
        else:
            value = get_single_prompt_output_layout(
                str(current_file.get(key)),
                text_modes + ([COMPARE] if key in get_compared_rows() else []),
                str(compare_to.get(key, "")),
            )
        if key in get_editable_rows():
            value = dbc.Textarea(
                id={"type": "editable_row", "id": key, "model_name": model},
                value=value,
            )
        row_data.append(value)
    return row_data
339 |
340 |
def get_detailed_info_table_content(
    question_id: int,
    rows_names: List[str],
    models: List[str],
    files_id: List[int],
    filter_functions: List[str],
    sorting_functions: List[str],
    text_modes: List[List[str]],
) -> List:
    """Build the answer cells of every model column for one question.

    The ``i``-th entries of ``models``, ``files_id``, ``filter_functions``,
    ``sorting_functions`` and ``text_modes`` describe column ``i``. Cells are
    concatenated column after column.

    The selected file of the first column is the comparison baseline for
    every column. When its file id is out of range, nothing is compared.

    Returns:
        A flat list of cell contents, column-major.
    """
    all_data = get_table_data()
    question_data = all_data[question_id] if len(all_data) else {}

    # Loop-invariant baseline. Check its bounds instead of raising (or wrapping
    # around for a negative id) when the first column has no valid file.
    base_files = question_data.get(models[0], []) if models else []
    base_file_id = files_id[0] if files_id else -1
    compare_to = (
        base_files[base_file_id] if 0 <= base_file_id < len(base_files) else {}
    )

    table_data = []
    for col_id, (model, file_id, filter_function, sorting_function, modes) in enumerate(
        zip(models, files_id, filter_functions, sorting_functions, text_modes)
    ):
        files_names = [
            file[FILE_NAME]
            for file in get_filtered_files(
                filter_function,
                sorting_function,
                question_data.get(model, []),
            )
        ]
        table_data.extend(
            get_detailed_info_table_row_content(
                question_id=question_id,
                model=model,
                rows_names=rows_names,
                files_names=files_names,
                file_id=file_id,
                col_id=col_id,
                text_modes=modes,
                compare_to=compare_to,
            )
        )
    return table_data
373 |
374 |
def get_general_stats_layout(
    base_model: str,
) -> List[html.Div]:
    """Render overall dataset statistics for ``base_model``.

    Built-in statistics:
    - dataset size: questions with at least one sample of the model.
    - overall number of samples.
    - generations per sample: overall samples divided by dataset size, or 0.

    These are followed by every statistic from ``get_general_custom_stats()``.
    Each custom function is run through ``catch_eval_exception``. Any errors it
    collects are logged with ``logging.error``.

    Returns:
        A single-element list holding an ``html.Div`` with one ``html.Pre``
        line per statistic.
    """
    data_for_base_model = [data.get(base_model, []) for data in get_table_data()]

    custom_stats = {}
    for name, func in get_general_custom_stats().items():
        errors_dict = {}
        custom_stats[name] = catch_eval_exception(
            [],
            func,
            data_for_base_model,
            "Got error when applying function",
            errors_dict,
        )
        if errors_dict:
            logging.error(ERROR_MESSAGE_TEMPLATE.format(name, errors_dict))

    overall_samples = sum(map(len, data_for_base_model))
    dataset_size = sum(1 for question_data in data_for_base_model if question_data)
    stats = {
        "dataset size": dataset_size,
        "overall number of samples": overall_samples,
        "generations per sample": (overall_samples / dataset_size if dataset_size else 0),
        **custom_stats,
    }
    return [html.Div([html.Pre(f"{name}: {value}") for name, value in stats.items()])]
401 |
--------------------------------------------------------------------------------
/nemo_inspector/callbacks/analyze_page/detailed_sample_info.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import json
16 | from typing import Dict, List, Tuple
17 |
18 | from dash import ALL, callback_context, no_update
19 | from dash.dependencies import Input, Output, State
20 | from dash.exceptions import PreventUpdate
21 |
22 | from nemo_inspector.callbacks import app
23 | from nemo_inspector.layouts import (
24 | get_detailed_info_table_column,
25 | get_table_column_header,
26 | get_detailed_info_table_row_content,
27 | get_detailed_info_table_content,
28 | )
29 | from nemo_inspector.settings.constants import (
30 | EDIT_ICON_PATH,
31 | FILE_NAME,
32 | MODEL_SELECTOR_ID,
33 | SAVE_ICON_PATH,
34 | )
35 | from nemo_inspector.utils.common import (
36 | get_available_models,
37 | get_compared_rows,
38 | get_editable_rows,
39 | get_excluded_row,
40 | get_file_id,
41 | get_filtered_files,
42 | get_table_data,
43 | )
44 |
45 |
@app.callback(
    [
        Output("js_container", "children", allow_duplicate=True),
        Output("js_trigger", "children", allow_duplicate=True),
    ],
    Input({"type": "editable_row", "id": ALL, "model_name": ALL}, "value"),
    [
        State({"type": "model_selector", "id": ALL}, "value"),
        State("datatable", "selected_rows"),
        State("datatable", "page_current"),
        State("datatable", "page_size"),
        State({"type": "editable_row", "id": ALL, "model_name": ALL}, "id"),
        State({"type": "file_selector", "id": ALL}, "value"),
        State("js_trigger", "children"),
    ],
    prevent_initial_call=True,
)
def update_data_table(
    new_rows_values: List[str],
    models: List[str],
    idx: List[int],
    current_page: int,
    page_size: int,
    new_rows_ids: List[Dict],
    file_names: List[str],
    js_trigger: str,
) -> Tuple[str, str]:
    """Write the values typed into editable text areas back to the table data.

    For each model column, find the index of the file named by the column's
    file selector. Store each edited value in that file's record under the
    edited row name.

    NOTE: editable areas are keyed by model name only. If two columns show the
    same model, the last column's file wins.

    Returns:
        ``""`` for ``js_container`` and ``js_trigger`` with one more space
        appended, so that dependent client-side code re-runs.
    """
    ctx = callback_context
    if not ctx.triggered or not idx:
        return no_update, no_update

    question_id = current_page * page_size + idx[0]
    all_data = get_table_data()
    question_data = all_data[question_id] if len(all_data) else {}

    file_ids = {}
    # zip stops at the shorter list instead of raising IndexError.
    for model, name in zip(models, file_names):
        for file_id, file in enumerate(question_data.get(model, [])):
            if file[FILE_NAME] == name:
                file_ids[model] = file_id

    for new_rows_id, new_rows_value in zip(new_rows_ids, new_rows_values):
        updated_model = new_rows_id["model_name"]
        if updated_model not in file_ids:
            # No file of this model matches its selector, so there is nowhere to write.
            continue
        question_data[updated_model][file_ids[updated_model]][
            new_rows_id["id"]
        ] = new_rows_value

    return "", (js_trigger or "") + " "
96 |
97 |
@app.callback(
    [
        Output(
            {"type": "detailed_models_answers", "id": ALL},
            "children",
            allow_duplicate=True,
        ),
        Output(
            {"type": "filter_function_input", "id": ALL},
            "value",
            allow_duplicate=True,
        ),
        Output(
            {"type": "sorting_function_input", "id": ALL},
            "value",
            allow_duplicate=True,
        ),
    ],
    [
        Input("datatable", "selected_rows"),
        Input(
            "dummy_output",
            "children",
        ),
    ],
    [
        State({"type": "model_selector", "id": ALL}, "value"),
        State({"type": "sorting_function_input", "id": ALL}, "value"),
        State({"type": "filter_function_input", "id": ALL}, "value"),
        State({"type": "row_name", "id": ALL}, "children"),
        State("datatable", "page_current"),
        State("datatable", "page_size"),
        State({"type": "file_selector", "id": ALL}, "value"),
        State({"type": "text_modes", "id": ALL}, "value"),
    ],
    prevent_initial_call=True,
)
def show_item(
    idx: List[int],
    dummmy_trigger: str,
    models: List[str],
    sorting_functions: List[str],
    filter_functions: List[str],
    rows_names: List[str],
    current_page: int,
    page_size: int,
    file_names: List[str],
    text_modes: List[List[str]],
) -> List[str]:
    """Render the detailed answers of the selected question for every column.

    Entry 0 of the filter / sorting inputs is not a model column; the column
    entries start at index 1. When a new row is selected, the column entries
    are reset ("" / None) and entry 0 is kept.

    The filter outputs now mirror the sorting outputs. Previously every filter
    input, including entry 0, was blanked, while the function still used the
    un-blanked values.

    Returns:
        ``[cell contents, filter input values, sorting input values]``.
    """
    if not idx:
        raise PreventUpdate
    ctx = callback_context
    if not ctx.triggered:
        return [no_update, no_update, no_update]
    if ctx.triggered[0]["prop_id"] == "datatable.selected_rows":
        # Slicing keeps this safe for empty lists ([""] * -1 == []).
        filter_functions = filter_functions[:1] + [""] * (len(filter_functions) - 1)
        sorting_functions = sorting_functions[:1] + [None] * (len(sorting_functions) - 1)

    question_id = current_page * page_size + idx[0]
    all_data = get_table_data()
    file_ids = [0] * len(models)
    for column_id in range(len(file_names)):
        files = all_data[question_id][models[column_id]] if len(all_data) else []
        file_ids[column_id] = get_file_id(file_names, files, column_id)

    return [
        get_detailed_info_table_content(
            question_id=question_id,
            rows_names=rows_names,
            models=models,
            files_id=file_ids,
            filter_functions=filter_functions[1:],
            sorting_functions=sorting_functions[1:],
            text_modes=text_modes,
        ),
        filter_functions,
        sorting_functions,
    ]
177 |
178 |
@app.callback(
    Output(
        "dummy_output",
        "children",
        allow_duplicate=True,
    ),
    Input({"type": "compare_texts_button", "id": ALL}, "n_clicks"),
    [
        State("dummy_output", "children"),
        State({"type": "row_name", "id": ALL}, "children"),
        State({"type": "compare_texts_button", "id": ALL}, "n_clicks"),
    ],
    prevent_initial_call=True,
)
def compare(
    n_clicks: List[int], dummy_data: str, row_names: List[str], button_ids: List[int]
):
    """Toggle whether the clicked row is rendered in "compare" mode.

    The row name is added to or removed from ``get_compared_rows()``.
    ``dummy_output`` is then bumped so the answers are re-rendered.
    ``button_ids`` actually receives the buttons' ``n_clicks`` (see the State
    list) and is unused.

    The button id is the row's position in the table, which is also its index
    in ``row_names``.
    """
    ctx = callback_context
    if not ctx.triggered or not n_clicks:
        return no_update
    trigger = ctx.triggered[0]
    # Like edit_row / del_row / del_model, ignore triggers that carry no click
    # count (presumably fired when the buttons are (re)created). Otherwise
    # comparison would be toggled without a click.
    if not trigger["value"]:
        return no_update
    button_id = json.loads(trigger["prop_id"].split(".")[0])["id"]
    row_name = row_names[button_id]

    compared_rows = get_compared_rows()
    if row_name in compared_rows:
        compared_rows.remove(row_name)
    else:
        compared_rows.add(row_name)
    return dummy_data + "1"
203 |
204 |
@app.callback(
    [
        Output(
            "dummy_output",
            "children",
            allow_duplicate=True,
        ),
        Output({"type": "edit_row_image", "id": ALL}, "src"),
    ],
    Input({"type": "edit_row_button", "id": ALL}, "n_clicks"),
    [
        State({"type": "row_name", "id": ALL}, "children"),
        State({"type": "edit_row_image", "id": ALL}, "id"),
        State({"type": "edit_row_image", "id": ALL}, "src"),
        State({"type": "model_selector", "id": ALL}, "value"),
        State("datatable", "selected_rows"),
        State("datatable", "page_current"),
        State("datatable", "page_size"),
        State({"type": "file_selector", "id": ALL}, "value"),
        State(
            "dummy_output",
            "children",
        ),
    ],
    prevent_initial_call=True,
)
def edit_row(
    n_clicks: List[int],
    rows: List[str],
    button_ids: List[Dict],
    edit_row_labels: List[str],
    models: List[str],
    idx: List[int],
    current_page: int,
    page_size: int,
    file_names: List[str],
    dummy_data: str,
) -> Tuple[str, List[str]]:
    """Toggle whether the clicked row is editable and swap its icon.

    Editable rows show the save icon, read-only rows the edit icon.
    ``models``, ``current_page``, ``page_size`` and ``file_names`` come from
    the callback's State list but are not needed. The previous body computed
    file ids from them and never used the result.

    Returns:
        The bumped ``dummy_output`` (so the answers re-render) and the updated
        icon sources.
    """
    ctx = callback_context
    if not ctx.triggered or not n_clicks or not idx:
        return no_update, [no_update] * len(button_ids)

    button_id = json.loads(ctx.triggered[0]["prop_id"].split(".")[0])["id"]
    row_index = next(
        (i for i, current_id in enumerate(button_ids) if current_id["id"] == button_id),
        0,
    )
    # A trigger without a click count is not a real click.
    if not n_clicks[row_index]:
        return no_update, [no_update] * len(button_ids)

    row_name = rows[row_index]
    editable_rows = get_editable_rows()
    if row_name in editable_rows:
        editable_rows.remove(row_name)
        edit_row_labels[row_index] = EDIT_ICON_PATH
    else:
        editable_rows.add(row_name)
        edit_row_labels[row_index] = SAVE_ICON_PATH

    return dummy_data + "1", edit_row_labels
274 |
275 |
@app.callback(
    [
        Output(
            "dummy_output",
            "children",
            allow_duplicate=True,
        ),
        Output({"type": "del_row", "id": ALL}, "children"),
    ],
    Input({"type": "del_row", "id": ALL}, "n_clicks"),
    [
        State({"type": "row_name", "id": ALL}, "children"),
        State({"type": "del_row", "id": ALL}, "id"),
        State({"type": "del_row", "id": ALL}, "children"),
        State(
            "dummy_output",
            "children",
        ),
    ],
    prevent_initial_call=True,
)
def del_row(
    n_clicks: List[int],
    rows: List[str],
    button_ids: List[Dict],
    del_row_labels: List[str],
    dummy_data: str,
) -> Tuple[str, List[str]]:
    """Toggle whether the clicked row is excluded from the answers table.

    An excluded row has its name in ``get_excluded_row()`` and shows a "+"
    button. An included row shows "-".

    Returns:
        The bumped ``dummy_output`` and the updated button labels, or
        ``no_update`` for both when there was no real click.
    """
    ctx = callback_context
    if not (ctx.triggered and n_clicks):
        return no_update, [no_update] * len(button_ids)

    clicked_id = json.loads(ctx.triggered[0]["prop_id"].split(".")[0])["id"]
    position = next(
        (pos for pos, candidate in enumerate(button_ids) if candidate["id"] == clicked_id),
        0,
    )
    if not n_clicks[position]:
        return no_update, [no_update] * len(button_ids)

    row_name = rows[position]
    if row_name not in get_excluded_row():
        get_excluded_row().add(row_name)
        del_row_labels[position] = "+"
    else:
        get_excluded_row().remove(row_name)
        del_row_labels[position] = "-"

    return dummy_data + "1", del_row_labels
323 |
324 |
@app.callback(
    [
        Output(
            "detailed_answers_header",
            "children",
            allow_duplicate=True,
        ),
        Output(
            {"type": "detailed_answers_row", "id": ALL},
            "children",
            allow_duplicate=True,
        ),
    ],
    Input("add_model", "n_clicks"),
    [
        State("detailed_answers_header", "children"),
        State({"type": "detailed_answers_row", "id": ALL}, "children"),
        State({"type": "model_selector", "id": ALL}, "id"),
        State("datatable", "selected_rows"),
    ],
    prevent_initial_call=True,
)
def add_model(
    n_clicks: int,
    header: List,
    rows: List,
    selectors_ids: List[Dict],
    idx: List[int],
) -> Tuple[List, List]:
    """Append a new model column to the detailed answers table.

    The new header cell uses the first available model and a header id one
    above the largest existing model-selector id.

    Every row gets a new answer cell. Cells are numbered column by column, so
    the new ids continue after the last row's last cell. The first row's cell
    gets a file selector when a question is selected.
    """
    if not n_clicks:
        return no_update, [no_update] * len(rows)
    available_models = list(get_available_models().keys())
    if not available_models:
        # Nothing to show in the new column's model selector.
        return no_update, [no_update] * len(rows)

    new_header_id = max((selector["id"] for selector in selectors_ids), default=-1) + 1
    header.append(
        get_table_column_header(available_models, available_models[0], new_header_id, True)
    )
    if rows:
        # Path to the id of the last row's last answer cell (Col -> Div -> id).
        last_cell_id = rows[-1][-1]["props"]["children"]["props"]["id"]["id"]
        for i, row in enumerate(rows):
            row.append(
                get_detailed_info_table_column(
                    last_cell_id + i + 1,
                    file_id=new_header_id if i == 0 and idx else None,
                )
            )

    return header, rows
373 |
374 |
@app.callback(
    [
        Output("detailed_answers_header", "children"),
        Output({"type": "detailed_answers_row", "id": ALL}, "children"),
    ],
    Input({"type": "del_model", "id": ALL}, "n_clicks"),
    [
        State("detailed_answers_header", "children"),
        State({"type": "detailed_answers_row", "id": ALL}, "children"),
        State({"type": "del_model", "id": ALL}, "id"),
    ],
    prevent_initial_call=True,
)
def del_model(
    n_clicks: List[int],
    header: List,
    rows: List,
    id_del: List[Dict],
) -> Tuple[List, List]:
    """Remove the model column whose "-" button was clicked.

    The matching cell is dropped from the header and from every row. If the
    clicked id matches no delete button, nothing changes; previously this
    raised UnboundLocalError.
    """
    ctx = callback_context
    if not ctx.triggered:
        return no_update, [no_update] * len(rows)

    button_id = json.loads(ctx.triggered[0]["prop_id"].split(".")[0])["id"]
    if not ctx.triggered[0]["value"]:
        return no_update, [no_update] * len(rows)

    index = None
    for i, del_id in enumerate(id_del):
        if del_id["id"] == button_id:
            # Offset 2: position 0 is the row-name column and position 1 the
            # first model column, which has no delete button.
            index = i + 2
    if index is None:
        return no_update, [no_update] * len(rows)

    header.pop(index)
    for row in rows:
        row.pop(index)

    return header, rows
412 |
413 |
@app.callback(
    [
        Output(
            {"type": "detailed_models_answers", "id": ALL},
            "children",
            allow_duplicate=True,
        ),
        Output(
            "dummy_output",
            "children",
            allow_duplicate=True,
        ),
    ],
    [
        Input({"type": "file_selector", "id": ALL}, "value"),
        Input({"type": "text_modes", "id": ALL}, "value"),
    ],
    [
        State("datatable", "selected_rows"),
        State({"type": "file_selector", "id": ALL}, "options"),
        State({"type": "model_selector", "id": ALL}, "value"),
        State({"type": "model_selector", "id": ALL}, "id"),
        State({"type": "row_name", "id": ALL}, "children"),
        State("datatable", "page_current"),
        State("datatable", "page_size"),
        State(
            {"type": "detailed_models_answers", "id": ALL},
            "children",
        ),
        State("dummy_output", "children"),
    ],
    prevent_initial_call=True,
)
def change_file(
    file_names: List[str],
    text_modes: List[List[str]],
    idx: List[int],
    file_options: List[List[Dict]],
    models: List[str],
    model_ids: List[Dict],
    rows_names: List[str],
    current_page: int,
    page_size: int,
    table_data: List[str],
    dummy_data: str,
) -> List[str]:
    """Re-render the columns whose file selector or text modes changed.

    For each triggering column, its slice of ``table_data``
    (``len(rows_names)`` cells per column) is rebuilt. When column 0 (the
    comparison baseline) changes, ``dummy_output`` is bumped so ``show_item``
    re-renders every column.

    If no trigger maps to a current model column, nothing is updated;
    previously this raised UnboundLocalError.
    """
    if not idx:
        raise PreventUpdate

    ctx = callback_context
    if not ctx.triggered:
        return [no_update] * len(table_data), no_update

    question_id = current_page * page_size + idx[0]
    rows_count = len(rows_names)
    base_changed = False
    any_updated = False
    for trigger in ctx.triggered:
        try:
            column = model_ids.index(
                json.loads(
                    MODEL_SELECTOR_ID.format(
                        json.loads(trigger["prop_id"].split(".")[0])["id"]
                    )
                )
            )
        except ValueError:
            # The trigger does not belong to any current model column.
            continue

        question_data = get_table_data()[question_id]
        model = models[column]
        file_id = get_file_id(file_names, question_data[model], column)
        base_file_id = get_file_id(file_names, question_data[models[0]], 0)

        table_data[column * rows_count : (column + 1) * rows_count] = (
            get_detailed_info_table_row_content(
                question_id=question_id,
                model=model,
                file_id=file_id,
                rows_names=rows_names,
                files_names=[option["value"] for option in file_options[column]],
                col_id=column,
                text_modes=text_modes[column],
                compare_to=question_data[models[0]][base_file_id],
            )
        )
        any_updated = True
        base_changed = base_changed or column == 0

    if not any_updated:
        return [no_update] * len(table_data), no_update
    return table_data, dummy_data + "1" if base_changed else dummy_data
501 |
502 |
@app.callback(
    [
        Output({"type": "file_selector", "id": ALL}, "options"),
        Output({"type": "file_selector", "id": ALL}, "value"),
    ],
    [
        Input({"type": "apply_filter_button", "id": ALL}, "n_clicks"),
        Input({"type": "apply_sorting_button", "id": ALL}, "n_clicks"),
        Input({"type": "model_selector", "id": ALL}, "value"),
    ],
    [
        State({"type": "model_selector", "id": ALL}, "id"),
        State({"type": "sorting_function_input", "id": ALL}, "value"),
        State({"type": "filter_function_input", "id": ALL}, "value"),
        State({"type": "apply_on_filtered_data", "id": ALL}, "value"),
        State("datatable", "page_current"),
        State("datatable", "page_size"),
        State("datatable", "selected_rows"),
        State({"type": "file_selector", "id": ALL}, "options"),
        State({"type": "file_selector", "id": ALL}, "value"),
    ],
    prevent_initial_call=True,
)
def change_files_order(
    filter_n_click: List[int],
    sorting_n_click: List[int],
    models: List[str],
    model_ids: List[Dict],
    sorting_functions: List[str],
    filter_functions: List[str],
    apply_on_filtered_data: List[int],
    current_page: int,
    page_size: int,
    idx: List[int],
    file_selector_options: List[List[Dict]],
    file_selector_values: List[str],
) -> Tuple[List[List[Dict]], List[str]]:
    """Re-filter and re-sort the file options of one model column.

    The column is the one whose filter / sorting button or model selector
    fired. Its filter and sorting functions sit at index ``column + 1`` (index
    0 is not a column).

    When ``apply_on_filtered_data`` is set for the column, only the files
    currently offered by that column's selector are considered. The selector
    is then set to the first remaining file name, or None if none remain.
    """
    no_updates = [no_update] * len(file_selector_options)
    if not filter_n_click and not sorting_n_click:
        return no_updates, no_updates
    if not idx:
        raise PreventUpdate
    ctx = callback_context
    if not ctx.triggered:
        return no_updates, no_updates

    # Use the same trigger for both the column lookup and the value check.
    trigger = ctx.triggered[-1]
    try:
        button_id = model_ids.index(
            json.loads(
                MODEL_SELECTOR_ID.format(
                    json.loads(trigger["prop_id"].split(".")[0])["id"]
                )
            )
        )
    except ValueError:
        return no_updates, no_updates
    if not trigger["value"]:
        return no_updates, no_updates

    model = models[button_id]
    question_id = current_page * page_size + idx[0]
    question_files = get_table_data()[question_id][model]
    if apply_on_filtered_data and apply_on_filtered_data[button_id]:
        # Only this column's options. The old code iterated over every
        # selector's option list and indexed a list with "label".
        shown_files = {option["label"] for option in file_selector_options[button_id]}
        array_to_filter = [
            data for data in question_files if data[FILE_NAME] in shown_files
        ]
    else:
        array_to_filter = question_files

    options = [
        {"label": data[FILE_NAME], "value": data[FILE_NAME]}
        for data in get_filtered_files(
            filter_functions[button_id + 1],
            sorting_functions[button_id + 1],
            array_to_filter,
        )
    ]
    file_selector_options[button_id] = options
    # A selector value is a file name (other callbacks compare it with
    # file[FILE_NAME]), not the option dict.
    file_selector_values[button_id] = options[0]["value"] if options else None

    return file_selector_options, file_selector_values
585 |
--------------------------------------------------------------------------------