├── nemo_inspector ├── assets │ ├── images │ │ └── icons │ │ │ ├── edit_icon.png │ │ │ ├── save_icon.png │ │ │ └── compare_icon.png │ ├── styles │ │ ├── styles.css │ │ └── ansi_styles.css │ └── scripts │ │ ├── change_element_height.js │ │ └── register_textarea.js ├── tests │ ├── README.md │ └── test_ping.py ├── __init__.py ├── utils │ ├── __init__.py │ ├── decoration │ │ ├── __init__.py │ │ ├── plain_text.py │ │ ├── code.py │ │ ├── common.py │ │ └── latex.py │ └── common.py ├── settings │ ├── __init__.py │ ├── constants │ │ ├── paths.py │ │ ├── templates.py │ │ ├── configurations.py │ │ ├── common.py │ │ └── __init__.py │ └── inspector_config.py ├── layouts │ ├── analyze_page_layouts │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── base_layout.py │ │ ├── modals_layouts.py │ │ └── table_layouts.py │ ├── base_layouts.py │ ├── __init__.py │ └── common_layouts.py ├── __main__.py ├── callbacks │ ├── common │ │ ├── __init__.py │ │ ├── decoration.py │ │ └── navigation.py │ ├── __init__.py │ └── analyze_page │ │ ├── __init__.py │ │ ├── update_dataset.py │ │ ├── short_info_table.py │ │ ├── save_dataset.py │ │ ├── sort_dataset.py │ │ ├── count_stats_dataset.py │ │ ├── filter_dataset.py │ │ ├── label_dataset.py │ │ └── detailed_sample_info.py ├── inspector_app.py └── parse_agruments_helpers.py ├── requirements ├── inspector-tests.txt └── inspector.txt ├── setup.py ├── README.md └── LICENSE /nemo_inspector/assets/images/icons/edit_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/NeMo-Inspector/HEAD/nemo_inspector/assets/images/icons/edit_icon.png -------------------------------------------------------------------------------- /nemo_inspector/assets/images/icons/save_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/NeMo-Inspector/HEAD/nemo_inspector/assets/images/icons/save_icon.png 
-------------------------------------------------------------------------------- /nemo_inspector/assets/images/icons/compare_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/NeMo-Inspector/HEAD/nemo_inspector/assets/images/icons/compare_icon.png -------------------------------------------------------------------------------- /nemo_inspector/assets/styles/styles.css: -------------------------------------------------------------------------------- 1 | .button-class { 2 | line-height: 20px; 3 | font-size: 14px; 4 | height: 40px; 5 | text-overflow: ellipsis; 6 | white-space: nowrap; 7 | overflow: hidden; 8 | margin-left: 2px; 9 | } -------------------------------------------------------------------------------- /nemo_inspector/tests/README.md: -------------------------------------------------------------------------------- 1 | To launch the tests, first install all the requirements from the repository root: 2 | ``` 3 | pip install -r requirements/inspector.txt 4 | pip install -r requirements/inspector-tests.txt 5 | ``` 6 | Then launch the tests: 7 | ``` 8 | pytest nemo_inspector/tests 9 | ``` -------------------------------------------------------------------------------- /nemo_inspector/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nemo_inspector/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nemo_inspector/settings/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /nemo_inspector/layouts/analyze_page_layouts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /requirements/inspector-tests.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | dash[testing] 16 | pytest-rerunfailures 17 | webdriver-manager==4.0.2 18 | -------------------------------------------------------------------------------- /nemo_inspector/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from nemo_inspector.inspector_app import main 16 | 17 | if __name__ == "__main__": 18 | main() 19 | -------------------------------------------------------------------------------- /requirements/inspector.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | ansi2html 16 | dash 17 | dash-ace 18 | dash_bootstrap_components 19 | joblib 20 | pandas 21 | pygments 22 | sshtunnel_requests 23 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import nemo_inspector.callbacks.common.decoration as decoration 16 | import nemo_inspector.callbacks.common.navigation as navigation 17 | -------------------------------------------------------------------------------- /nemo_inspector/settings/constants/paths.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import pathlib 15 | 16 | COMPARE_ICON_PATH = "assets/images/icons/compare_icon.png"  # file lives in nemo_inspector/assets/images/icons/ 17 | EDIT_ICON_PATH = "assets/images/icons/edit_icon.png"  # file lives in nemo_inspector/assets/images/icons/ 18 | SAVE_ICON_PATH = "assets/images/icons/save_icon.png"  # file lives in nemo_inspector/assets/images/icons/ 19 | PATH_TO_THE_REPOSITORY = pathlib.Path(__file__).parents[3]  # parents of this file: constants -> settings -> nemo_inspector -> repository root 20 | -------------------------------------------------------------------------------- /nemo_inspector/settings/constants/templates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ERROR_MESSAGE_TEMPLATE = "When applying {} function\ngot errors\n{}"  # str.format args: function name, then the errors 16 | QUERY_INPUT_ID = '{{"type": "{}", "id": "{}"}}'  # JSON-like id string; doubled braces are literal braces under str.format 17 | MODEL_SELECTOR_ID = '{{"type": "model_selector", "id": {}}}'  # id is inserted unquoted: the caller must pass a JSON-ready value (e.g. an int) 18 | LABEL_SELECTOR_ID = '{{"type": "label_selector", "id": {}}}'  # same unquoted-id convention as MODEL_SELECTOR_ID 19 | -------------------------------------------------------------------------------- /nemo_inspector/utils/decoration/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from nemo_inspector.utils.decoration.common import ( 16 | design_text_output, 17 | get_height_adjustment, 18 | ) 19 | from nemo_inspector.utils.decoration.code import highlight_code 20 | from nemo_inspector.utils.decoration.plain_text import color_text_diff 21 | -------------------------------------------------------------------------------- /nemo_inspector/assets/scripts/change_element_height.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | window.addEventListener('message', function(event) { 16 | if (event.data && event.data.frameHeight && event.data.frameId) { 17 | var iframe = document.getElementById(event.data.frameId); 18 | if (iframe) { 19 | iframe.style.height = event.data.frameHeight + 'px'; 20 | } 21 | } 22 | }, false); 23 | -------------------------------------------------------------------------------- /nemo_inspector/settings/constants/configurations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | CODE_SEPARATORS = { 16 | "code_begin": "{code_begin}", 17 | "code_end": "{code_end}", 18 | "code_output_begin": "{code_output_begin}", 19 | "code_output_end": "{code_output_end}", 20 | "code_output_format": "llama", 21 | } 22 | DATA_PAGE_SIZE = 10 23 | EXTRA_FIELDS = ["page_index", "file_name"] 24 | STATS_KEYS = [ 25 | "question_index", 26 | "problem", 27 | ] 28 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from pathlib import Path 17 | import dash_bootstrap_components as dbc 18 | from dash import Dash 19 | 20 | assets_path = os.path.join(Path(__file__).parents[1], "assets")  # nemo_inspector/assets: parents[1] of callbacks/__init__.py is the package dir 21 | 22 | app = Dash( 23 | __name__, 24 | suppress_callback_exceptions=True,  # skip Dash's check that callback ids exist in the initial layout; the page body is rendered later by the nav_click callback 25 | external_stylesheets=[dbc.themes.BOOTSTRAP], 26 | assets_folder=assets_path, 27 | ) 28 | 29 | import nemo_inspector.callbacks.common as common  # must stay below `app`: callback modules (e.g. common.decoration) import `app` from this package to register callbacks 30 | import nemo_inspector.callbacks.analyze_page as analyze_page 31 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/common/decoration.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | from dash import html 16 | from dash.dependencies import Input, Output 17 | 18 | from nemo_inspector.callbacks import app 19 | from nemo_inspector.utils.decoration import get_height_adjustment 20 | 21 | 22 | @app.callback( 23 | Output("js_container", "children", allow_duplicate=True), 24 | [ 25 | Input("page_content", "children"), 26 | Input("js_trigger", "children"), 27 | ], 28 | prevent_initial_call=True, 29 | ) 30 | def adjust_text_area_height(content: html.Div, trigger: str) -> html.Iframe: 31 | return get_height_adjustment() 32 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/analyze_page/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import nemo_inspector.callbacks.analyze_page.count_stats_dataset as count_stats_dataset 17 | import nemo_inspector.callbacks.analyze_page.detailed_sample_info as detailed_sample_info 18 | import nemo_inspector.callbacks.analyze_page.filter_dataset as filter_dataset 19 | import nemo_inspector.callbacks.analyze_page.label_dataset as label_dataset 20 | import nemo_inspector.callbacks.analyze_page.save_dataset as save_dataset 21 | import nemo_inspector.callbacks.analyze_page.sort_dataset as sort_dataset 22 | import nemo_inspector.callbacks.analyze_page.short_info_table as short_info_table 23 | import nemo_inspector.callbacks.analyze_page.update_dataset as update_dataset 24 | -------------------------------------------------------------------------------- /nemo_inspector/assets/scripts/register_textarea.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | function registerTextarea() { 16 | var textareas = document.querySelectorAll("textarea"); 17 | textareas.forEach(function(textarea) { 18 | function updateHeight() { 19 | textarea.style.height = 0 + 'px'; 20 | 21 | var height = Math.max(textarea.scrollHeight, textarea.offsetHeight, 22 | textarea.clientHeight); 23 | 24 | textarea.style.height = height + 'px'; 25 | }; 26 | textarea.onload = updateHeight; 27 | textarea.onresize = updateHeight; 28 | textarea.addEventListener('input', updateHeight); 29 | updateHeight() 30 | }); 31 | }; 32 | -------------------------------------------------------------------------------- /nemo_inspector/layouts/base_layouts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import dash_bootstrap_components as dbc 16 | from dash import dcc, html 17 | 18 | 19 | def get_main_page_layout() -> html.Div: 20 | return html.Div( 21 | [ 22 | dcc.Location(id="url", refresh=False), 23 | dbc.NavbarSimple( 24 | brand="NeMo Inspector", 25 | sticky="top", 26 | color="blue", 27 | dark=True, 28 | class_name="mb-2", 29 | ), 30 | dbc.Container(id="page_content"), 31 | dbc.Container(id="js_trigger", style={"display": "none"}, children=""), 32 | dbc.Container(id="js_container"), 33 | dbc.Container(id="dummy_output", style={"display": "none"}, children=""), 34 | ] 35 | ) 36 | -------------------------------------------------------------------------------- /nemo_inspector/settings/constants/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | ANSI = "ansi" 16 | BASE_GENERATION = "base_generation" 17 | CHOOSE_GENERATION = "choose generation" 18 | CHOOSE_LABEL = "choose label" 19 | COMPARE = "compare" 20 | CODE = "code" 21 | CODE_BEGIN = "code_begin"  # CODE_BEGIN..CODE_OUTPUT_END are keys of the code_tags config (defaults in CODE_SEPARATORS) 22 | CODE_END = "code_end" 23 | CODE_OUTPUT_BEGIN = "code_output_begin" 24 | CODE_OUTPUT_END = "code_output_end" 25 | CUSTOM = "custom" 26 | DELETE = "delete" 27 | EXPECTED_ANSWER_FIELD = "expected_answer" 28 | FILE_NAME = "file_name"  # same key as the "file_name" entry of EXTRA_FIELDS 29 | FILES_ONLY = "files_only" 30 | FILES_FILTERING = "add_files_filtering" 31 | GENERAL_STATS = "general_stats" 32 | INLINE_STATS = "inline_stats" 33 | LABEL = "labels" 34 | LATEX = "latex" 35 | MARKDOWN = "markdown" 36 | QUESTIONS_FILTERING = "questions_filtering" 37 | QUESTION_FIELD = "problem"  # same key as the "problem" entry of STATS_KEYS 38 | UNDEFINED = "undefined" 39 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/common/navigation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | from dash import html 16 | from dash.dependencies import Input, Output 17 | from flask import current_app 18 | 19 | from nemo_inspector.callbacks import app 20 | from nemo_inspector.layouts import get_compare_test_layout 21 | from nemo_inspector.settings.constants import ( 22 | CODE_BEGIN, 23 | CODE_END, 24 | CODE_OUTPUT_BEGIN, 25 | CODE_OUTPUT_END, 26 | ) 27 | from nemo_inspector.settings.constants.configurations import CODE_SEPARATORS 28 | 29 | 30 | @app.callback( 31 | Output("page_content", "children"), 32 | Input("url", "pathname"), 33 | ) 34 | def nav_click(url: str) -> html.Div: 35 | config = current_app.config["nemo_inspector"] 36 | config["code_separators"] = ( 37 | config["code_tags"][CODE_BEGIN], 38 | config["code_tags"][CODE_END], 39 | ) 40 | config["code_output_separators"] = ( 41 | config["code_tags"][CODE_OUTPUT_BEGIN], 42 | config["code_tags"][CODE_OUTPUT_END], 43 | ) 44 | 45 | return get_compare_test_layout() 46 | -------------------------------------------------------------------------------- /nemo_inspector/layouts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from nemo_inspector.layouts.analyze_page_layouts.base_layout import ( 16 | get_compare_test_layout, 17 | get_filtered_tables_layout, 18 | get_tables_layout, 19 | get_updated_tables_layout, 20 | get_sorted_tables_layout, 21 | ) 22 | from nemo_inspector.layouts.analyze_page_layouts.utils import ( 23 | get_stats_input, 24 | get_stats_text, 25 | get_filter_text, 26 | ) 27 | from nemo_inspector.layouts.base_layouts import ( 28 | get_main_page_layout, 29 | ) 30 | from nemo_inspector.layouts.common_layouts import ( 31 | get_selector_layout, 32 | get_single_prompt_output_layout, 33 | get_switch_layout, 34 | get_text_modes_layout, 35 | ) 36 | 37 | from nemo_inspector.layouts.analyze_page_layouts.table_layouts import ( 38 | get_detailed_info_table_column, 39 | get_filter_modal_layout, 40 | get_table_column_header, 41 | get_detailed_info_table_row_content, 42 | get_single_prompt_output_layout, 43 | get_short_info_table_layout, 44 | get_detailed_info_table_content, 45 | ) 46 | -------------------------------------------------------------------------------- /nemo_inspector/settings/constants/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from nemo_inspector.settings.constants.configurations import ( 17 | CODE_SEPARATORS, 18 | DATA_PAGE_SIZE, 19 | EXTRA_FIELDS, 20 | STATS_KEYS, 21 | ) 22 | from nemo_inspector.settings.constants.common import ( 23 | EXPECTED_ANSWER_FIELD, 24 | ANSI, 25 | BASE_GENERATION, 26 | INLINE_STATS, 27 | GENERAL_STATS, 28 | CHOOSE_GENERATION, 29 | CHOOSE_LABEL, 30 | COMPARE, 31 | CODE, 32 | CUSTOM, 33 | DELETE, 34 | FILE_NAME, 35 | FILES_ONLY, 36 | FILES_FILTERING, 37 | GENERAL_STATS, 38 | CODE_BEGIN, 39 | CODE_END, 40 | CODE_OUTPUT_BEGIN, 41 | CODE_OUTPUT_END, 42 | QUESTIONS_FILTERING, 43 | QUESTION_FIELD, 44 | UNDEFINED, 45 | MARKDOWN, 46 | LABEL, 47 | LATEX, 48 | ) 49 | from nemo_inspector.settings.constants.paths import ( 50 | COMPARE_ICON_PATH, 51 | EDIT_ICON_PATH, 52 | SAVE_ICON_PATH, 53 | ) 54 | from nemo_inspector.settings.constants.templates import ( 55 | ERROR_MESSAGE_TEMPLATE, 56 | QUERY_INPUT_ID, 57 | MODEL_SELECTOR_ID, 58 | LABEL_SELECTOR_ID, 59 | ) 60 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from setuptools import setup, find_packages 16 | 17 | 18 | def parse_requirements(filename): 19 | with open(filename) as f: 20 | return f.read().splitlines() 21 | 22 | 23 | # Read the requirements from the requirements.txt file 24 | requirements = parse_requirements("requirements/inspector.txt") 25 | 26 | setup( 27 | name="nemo_inspector", 28 | version="0.1.0", 29 | description="NeMo Inspector - a tool for datasets analysis", 30 | url="", 31 | long_description=open("README.md").read(), 32 | long_description_content_type="text/markdown", 33 | packages=find_packages(), 34 | include_package_data=True, 35 | install_requires=requirements, 36 | license="Apache License, Version 2.0", 37 | classifiers=[ 38 | "Programming Language :: Python :: 3", 39 | "Programming Language :: Python :: 3.10", 40 | "License :: OSI Approved :: Apache Software License", 41 | "Operating System :: OS Independent", 42 | ], 43 | entry_points={ 44 | "console_scripts": [ 45 | "nemo_inspector=nemo_inspector.inspector_app:main", 46 | ], 47 | }, 48 | python_requires=">=3.10", 49 | ) 50 | -------------------------------------------------------------------------------- /nemo_inspector/settings/inspector_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from dataclasses import dataclass, field 16 | from itertools import chain 17 | from pathlib import Path 18 | from typing import Dict, Iterable, List 19 | 20 | from nemo_inspector.settings.constants.configurations import CODE_SEPARATORS 21 | 22 | 23 | def _expand_paths(paths: Iterable[str]) -> List[str]: 24 | expanded_files: List[str] = [] 25 | for path in paths: 26 | expanded = Path(path).expanduser() 27 | if any(char in str(expanded) for char in ["*", "?"]): 28 | expanded_files.extend(sorted(map(str, expanded.parent.glob(expanded.name)))) 29 | elif expanded.is_dir(): 30 | expanded_files.extend(sorted(map(str, expanded.rglob("*.jsonl")))) 31 | elif expanded.exists(): 32 | expanded_files.append(str(expanded)) 33 | return expanded_files 34 | 35 | 36 | def unroll_files(paths: Iterable[str]) -> List[str]: 37 | return list(dict.fromkeys(chain.from_iterable(_expand_paths([path]) for path in paths))) 38 | 39 | 40 | @dataclass(kw_only=True) 41 | class InspectorConfig: 42 | model_prediction: Dict[str, str] = field(default_factory=dict) 43 | save_generations_path: str = "nemo_inspector/results/saved_generations" 44 | code_tags: Dict[str, str] = field(default_factory=lambda: CODE_SEPARATORS) 45 | 46 | def __post_init__(self): 47 | self.model_prediction = { 48 | model_name: unroll_files(file_path.split(" ")) 49 | for model_name, file_path in self.model_prediction.items() 50 | } 51 | -------------------------------------------------------------------------------- /nemo_inspector/inspector_app.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import sys 16 | from pathlib import Path 17 | import argparse 18 | import dataclasses 19 | import signal 20 | 21 | from nemo_inspector.parse_agruments_helpers import ( 22 | add_arguments_from_dataclass, 23 | args_postproccessing, 24 | create_dataclass_from_args, 25 | ) 26 | 27 | sys.path.append(str(Path(__file__).parents[1])) 28 | 29 | from nemo_inspector.layouts import get_main_page_layout 30 | 31 | from nemo_inspector.settings.inspector_config import InspectorConfig 32 | from nemo_inspector.settings.constants.configurations import CODE_SEPARATORS 33 | 34 | 35 | def main(): 36 | signal.signal(signal.SIGALRM, signal.SIG_IGN) 37 | 38 | parser = argparse.ArgumentParser(description="NeMo Inspector") 39 | 40 | add_arguments_from_dataclass( 41 | parser, 42 | InspectorConfig, 43 | enforce_required=False, 44 | use_type_defaults=True, 45 | ) 46 | 47 | args = parser.parse_args() 48 | args_dict = vars(args) 49 | 50 | cfg = dataclasses.asdict(create_dataclass_from_args(InspectorConfig, args_dict)) 51 | cfg.setdefault("code_tags", {}) 52 | cfg["code_tags"] = {**CODE_SEPARATORS, **cfg["code_tags"]} 53 | cfg = args_postproccessing(cfg) 54 | from nemo_inspector.callbacks import app 55 | 56 | app.server.config.update({"nemo_inspector": cfg}) 57 | app.title = "NeMo Inspector" 58 | app.layout = get_main_page_layout() 59 | app.run( 60 | host="localhost", 61 | port="8080", 62 | ) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- 
/nemo_inspector/callbacks/analyze_page/update_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import List, Tuple 16 | 17 | from dash import ALL, html, no_update 18 | from dash.dependencies import Input, Output, State 19 | 20 | from nemo_inspector.callbacks import app 21 | from nemo_inspector.layouts import get_updated_tables_layout 22 | from nemo_inspector.settings.constants import CHOOSE_GENERATION 23 | 24 | 25 | @app.callback( 26 | [ 27 | Output("update_dataset_modal", "is_open", allow_duplicate=True), 28 | Output("js_container", "children", allow_duplicate=True), 29 | Output("js_trigger", "children", allow_duplicate=True), 30 | ], 31 | [ 32 | Input("update_dataset_button", "n_clicks"), 33 | Input("apply_update_dataset_button", "n_clicks"), 34 | ], 35 | [State("update_dataset_modal", "is_open"), State("js_trigger", "children")], 36 | prevent_initial_call=True, 37 | ) 38 | def open_update_dataset_modal(n1: int, n2: int, is_open: bool, js_trigger: str) -> bool: 39 | if n1 or n2: 40 | is_open = not is_open 41 | return is_open, "", js_trigger + " " 42 | return is_open, "", js_trigger + " " 43 | 44 | 45 | @app.callback( 46 | [ 47 | Output("compare_models_rows", "children", allow_duplicate=True), 48 | Output("loading_container", "children", 
allow_duplicate=True), 49 | ], 50 | Input("apply_update_dataset_button", "n_clicks"), 51 | [ 52 | State("update_dataset_input", "value"), 53 | State({"type": "model_selector", "id": ALL}, "value"), 54 | State("base_model_answers_selector", "value"), 55 | State("loading_container", "children"), 56 | ], 57 | prevent_initial_call=True, 58 | ) 59 | def update_dataset( 60 | n_ckicks: int, 61 | update_function: str, 62 | models: List[str], 63 | base_model: str, 64 | loading_container: str, 65 | ) -> Tuple[List[html.Tr], bool]: 66 | if base_model == CHOOSE_GENERATION or not update_function: 67 | return no_update, no_update 68 | return ( 69 | get_updated_tables_layout( 70 | base_model=base_model, update_function=update_function, models=models 71 | ), 72 | loading_container + " ", 73 | ) 74 | -------------------------------------------------------------------------------- /nemo_inspector/tests/test_ping.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import subprocess 17 | import sys 18 | from pathlib import Path 19 | 20 | import pytest 21 | from selenium import webdriver 22 | from selenium.webdriver.chrome.options import Options 23 | from selenium.webdriver.chrome.service import Service 24 | from selenium.webdriver.common.by import By 25 | from selenium.webdriver.support import expected_conditions as EC 26 | from selenium.webdriver.support.ui import WebDriverWait 27 | from webdriver_manager.chrome import ChromeDriverManager 28 | from webdriver_manager.core.os_manager import ChromeType 29 | 30 | project_root = str(Path(__file__).parents[2]) 31 | sys.path.remove(str(Path(__file__).parents[0])) 32 | 33 | 34 | @pytest.fixture(scope="module") 35 | def nemo_inspector_process(): 36 | # Start the NeMo Inspector as a subprocess 37 | 38 | process = subprocess.Popen( 39 | ["python", "nemo_inspector"], 40 | cwd=project_root, 41 | stdout=subprocess.PIPE, 42 | stderr=subprocess.PIPE, 43 | text=True, 44 | ) 45 | 46 | yield process 47 | 48 | # Terminate the process after the tests 49 | process.terminate() 50 | process.wait() 51 | 52 | 53 | @pytest.fixture 54 | def chrome_driver(): 55 | chrome_driver_path = ChromeDriverManager(chrome_type=ChromeType.GOOGLE).install() 56 | options = Options() 57 | options.page_load_strategy = "normal" 58 | options.add_argument("--headless") 59 | options.add_argument("--disable-gpu") 60 | options.add_argument("--no-sandbox") 61 | options.add_argument("--disable-dev-shm-usage") 62 | 63 | service = Service(chrome_driver_path) 64 | driver = webdriver.Chrome(service=service, options=options) 65 | os.environ["PATH"] += os.pathsep + "/".join(chrome_driver_path.split("/")[:-1]) 66 | yield driver 67 | driver.quit() 68 | 69 | 70 | @pytest.mark.parametrize( 71 | ("element_id", "url"), 72 | [("add_model", "/analyze")], 73 | ) 74 | def test_dash_app_launch(chrome_driver, nemo_inspector_process, element_id, url): 75 | full_url = f"http://localhost:8080{url}" 76 | 77 | 
chrome_driver.get(full_url) 78 | 79 | element = WebDriverWait(chrome_driver, 10).until( 80 | EC.presence_of_element_located((By.ID, element_id)) 81 | ) 82 | assert element.is_displayed() 83 | -------------------------------------------------------------------------------- /nemo_inspector/assets/styles/ansi_styles.css: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
*/ 14 | 15 | 16 | /* Foreground Colors */ 17 | .ansi30 { color: black; } /* Black */ 18 | .ansi31 { color: red ; } /* Red */ 19 | .ansi32 { color: green; } /* Green */ 20 | .ansi33 { color: yellow; } /* Yellow */ 21 | .ansi34 { color: blue; } /* Blue */ 22 | .ansi35 { color: magenta; } /* Magenta */ 23 | .ansi36 { color: cyan; } /* Cyan */ 24 | .ansi37 { color: white; } /* White */ 25 | .ansi90 { color: grey; } /* Bright Black (grey) */ 26 | .ansi91 { color: #FFCCCB; } /* Bright Red */ 27 | .ansi92 { color: lightgreen; } /* Bright Green */ 28 | .ansi93 { color: lightyellow; } /* Bright Yellow */ 29 | .ansi94 { color: lightblue; } /* Bright Blue */ 30 | .ansi95 { color: #ff80ff;} /* Bright Magenta */ 31 | .ansi96 { color: lightcyan; } /* Bright Cyan */ 32 | .ansi97 { color: #FFFFF7; } /* Bright White */ 33 | 34 | /* Background Colors */ 35 | .ansi40 { background-color: black; } /* Black */ 36 | .ansi41 { background-color: red; } /* Red */ 37 | .ansi42 { background-color: green; } /* Green */ 38 | .ansi43 { background-color: yellow; } /* Yellow */ 39 | .ansi44 { background-color: blue; } /* Blue */ 40 | .ansi45 { background-color: magenta; } /* Magenta */ 41 | .ansi46 { background-color: cyan; } /* Cyan */ 42 | .ansi47 { background-color: white; } /* White */ 43 | .ansi100 { background-color: grey; } /* Bright Black (grey) */ 44 | .ansi101 { background-color: #FFCCCB; } /* Bright Red */ 45 | .ansi102 { background-color: lightgreen; } /* Bright Green */ 46 | .ansi103 { background-color: lightyellow;} /* Bright Yellow */ 47 | .ansi104 { background-color: lightblue; } /* Bright Blue */ 48 | .ansi105 { background-color: #ff80ff;}/* Bright Magenta */ 49 | .ansi106 { background-color: lightcyan; } /* Bright Cyan */ 50 | .ansi107 { background-color: #FFFFF7; } /* Bright White */ 51 | 52 | /* Styles */ 53 | .ansi1 { font-weight: bold; } /* Bold */ 54 | .ansi3 { font-style: italic; } /* Italic */ 55 | .ansi4 { text-decoration: underline; } /* Underline */ 56 | .ansi9 { 
text-decoration: line-through; } /* Strikethrough */ 57 | .ansi24 { text-decoration: none; } /* No underline */ 58 | .ansi39 { color: initial; } /* Default foreground color */ 59 | .ansi49 { background-color: initial; } /* Default background color */ 60 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/analyze_page/short_info_table.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import Dict, List, Tuple 16 | 17 | from dash import no_update 18 | from dash.dependencies import Input, Output, State 19 | 20 | from nemo_inspector.callbacks import app 21 | from nemo_inspector.layouts import ( 22 | get_tables_layout, 23 | get_stats_input, 24 | ) 25 | from nemo_inspector.settings.constants import CHOOSE_GENERATION 26 | from nemo_inspector.utils.common import get_excluded_row, get_table_data 27 | 28 | 29 | @app.callback( 30 | [ 31 | Output("compare_models_rows", "children", allow_duplicate=True), 32 | Output("loading_container", "children", allow_duplicate=True), 33 | ], 34 | Input("base_model_answers_selector", "value"), 35 | State("loading_container", "children"), 36 | prevent_initial_call=True, 37 | ) 38 | def choose_base_model( 39 | base_model: str, 40 | loading_container: str, 41 | ) -> Tuple[List, bool]: 42 | if base_model == CHOOSE_GENERATION: 43 | return no_update, no_update 44 | get_excluded_row().clear() 45 | return ( 46 | get_tables_layout( 47 | base_model=base_model, 48 | ), 49 | loading_container + " ", 50 | ) 51 | 52 | 53 | @app.callback( 54 | Output("datatable", "data"), 55 | [ 56 | Input("datatable", "page_current"), 57 | Input("datatable", "page_size"), 58 | ], 59 | State("base_model_answers_selector", "value"), 60 | ) 61 | def change_page(page_current: int, page_size: int, base_model: str) -> List[Dict]: 62 | if not get_table_data(): 63 | return no_update 64 | return [ 65 | data[base_model][0] 66 | for data in get_table_data()[ 67 | page_current * page_size : (page_current + 1) * page_size 68 | ] 69 | if base_model in data.keys() 70 | ] 71 | 72 | 73 | @app.callback( 74 | [ 75 | Output("stats_input_container", "children", allow_duplicate=True), 76 | Output("js_container", "children", allow_duplicate=True), 77 | Output("js_trigger", "children", allow_duplicate=True), 78 | ], 79 | Input("stats_modes", "value"), 80 | State("js_trigger", "children"), 81 | prevent_initial_call=True, 82 | ) 83 | def 
change_stats_mode(modes: List[str], js_trigger: str) -> str: 84 | if modes is None: 85 | return no_update, no_update, no_update 86 | return get_stats_input(modes), "", js_trigger + " " 87 | -------------------------------------------------------------------------------- /nemo_inspector/utils/decoration/plain_text.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from difflib import SequenceMatcher 16 | import re 17 | 18 | 19 | def tokenize(text: str): 20 | """ 21 | Tokenize the text into separate tokens for: 22 | - Whitespace sequences (\s+) 23 | - Word sequences (\w+) 24 | - Punctuation sequences ([^\w\s]+) 25 | 26 | The regex ensures we capture all text in order. 
27 | """ 28 | # This pattern will capture tokens in the order they appear: 29 | # (\s+) => one or more whitespace chars 30 | # (\w+) => one or more word chars 31 | # ([^\w\s]+) => one or more chars that are not word chars or whitespace (punctuation) 32 | pattern = re.compile(r"(\s+|\w+|[^\w\s]+)") 33 | tokens = pattern.findall(text) 34 | return tokens 35 | 36 | 37 | def color_text_diff(text1: str, text2: str): 38 | if text1 == text2: 39 | return [(text1, {})] 40 | 41 | tokens1 = tokenize(text1) 42 | tokens2 = tokenize(text2) 43 | 44 | matcher = SequenceMatcher(None, tokens1, tokens2) 45 | result = [] 46 | 47 | for tag, i1, i2, j1, j2 in matcher.get_opcodes(): 48 | if tag == "equal": 49 | for k in range(j1, j2): 50 | result.append((tokens2[k], {})) 51 | 52 | elif tag == "replace": 53 | # Tokens from text1 (deleted) 54 | for k in range(i1, i2): 55 | result.append((tokens1[k], {"background-color": "#c8e6c9"})) 56 | # Tokens from text2 (inserted) 57 | for k in range(j1, j2): 58 | result.append( 59 | ( 60 | tokens2[k], 61 | { 62 | "background-color": "#ffcdd2", 63 | "text-decoration": "line-through", 64 | }, 65 | ) 66 | ) 67 | 68 | elif tag == "insert": 69 | for k in range(j1, j2): 70 | result.append( 71 | ( 72 | tokens2[k], 73 | { 74 | "background-color": "#ffcdd2", 75 | "text-decoration": "line-through", 76 | }, 77 | ) 78 | ) 79 | 80 | elif tag == "delete": 81 | for k in range(i1, i2): 82 | result.append((tokens1[k], {"background-color": "#c8e6c9"})) 83 | 84 | return result 85 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/analyze_page/save_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | import os 17 | from typing import List, Tuple 18 | 19 | from dash import callback_context, html, no_update 20 | from dash.dependencies import Input, Output, State 21 | 22 | from nemo_inspector.callbacks import app 23 | from nemo_inspector.settings.constants import ( 24 | EXTRA_FIELDS, 25 | FILE_NAME, 26 | ) 27 | from nemo_inspector.settings.constants.paths import PATH_TO_THE_REPOSITORY 28 | from nemo_inspector.utils.common import get_table_data 29 | 30 | 31 | @app.callback( 32 | Output("save_dataset_modal", "is_open", allow_duplicate=True), 33 | Input("save_dataset", "n_clicks"), 34 | prevent_initial_call=True, 35 | ) 36 | def open_save_dataset_modal(n1: int) -> bool: 37 | ctx = callback_context 38 | if not ctx.triggered: 39 | return no_update 40 | 41 | return True 42 | 43 | 44 | @app.callback( 45 | [ 46 | Output("save_dataset_modal", "is_open", allow_duplicate=True), 47 | Output("error_message", "children"), 48 | ], 49 | Input("save_dataset_button", "n_clicks"), 50 | [ 51 | State("base_model_answers_selector", "value"), 52 | State("save_path", "value"), 53 | ], 54 | prevent_initial_call=True, 55 | ) 56 | def save_dataset(n_click: int, base_model: str, save_path: str) -> Tuple[List, bool]: 57 | if not n_click or not save_path or not base_model: 58 | return no_update, no_update 59 | if save_path.startswith("nemo_inspector"): 60 | save_path = os.path.join(PATH_TO_THE_REPOSITORY, save_path) 61 | if not os.path.exists(save_path): 62 | try: 63 | os.mkdir(save_path) 64 | except: 65 | return True, 
html.Pre(f"could not save generations by path {save_path}") 66 | 67 | new_data = {} 68 | 69 | for data in get_table_data(): 70 | for file_data in data[base_model]: 71 | file_name = file_data[FILE_NAME] 72 | if file_name not in new_data: 73 | new_data[file_name] = [] 74 | new_data[file_name].append( 75 | { 76 | key: value 77 | for key, value in file_data.items() 78 | if key not in EXTRA_FIELDS 79 | } 80 | ) 81 | 82 | for file_name, data in new_data.items(): 83 | with open(os.path.join(save_path, file_name + ".jsonl"), "w") as file: 84 | file.write("\n".join([json.dumps(line) for line in data])) 85 | 86 | return False, "" 87 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/analyze_page/sort_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import json 16 | from typing import List, Tuple 17 | 18 | from dash import ALL, callback_context, html, no_update 19 | from dash.dependencies import Input, Output, State 20 | 21 | from nemo_inspector.callbacks import app 22 | from nemo_inspector.layouts import get_sorted_tables_layout 23 | from nemo_inspector.settings.constants import CHOOSE_GENERATION 24 | 25 | 26 | @app.callback( 27 | [ 28 | Output({"type": "sorting", "id": ALL}, "is_open"), 29 | Output("js_container", "children", allow_duplicate=True), 30 | Output("js_trigger", "children", allow_duplicate=True), 31 | ], 32 | [ 33 | Input({"type": "set_sorting_button", "id": ALL}, "n_clicks"), 34 | Input({"type": "apply_sorting_button", "id": ALL}, "n_clicks"), 35 | ], 36 | [ 37 | State({"type": "sorting", "id": ALL}, "is_open"), 38 | State("js_trigger", "children"), 39 | ], 40 | prevent_initial_call=True, 41 | ) 42 | def toggle_modal_sorting(n1: int, n2: int, is_open: bool, js_trigger: str) -> bool: 43 | ctx = callback_context 44 | if not ctx.triggered: 45 | return [no_update] * len(is_open), no_update, no_update 46 | 47 | button_id = json.loads(ctx.triggered[-1]["prop_id"].split(".")[0])["id"] + 1 48 | 49 | if not ctx.triggered[0]["value"]: 50 | return [no_update] * len(is_open), no_update, no_update 51 | 52 | if n1[button_id] or n2[button_id]: 53 | is_open[button_id] = not is_open[button_id] 54 | return is_open, "", js_trigger + " " 55 | return is_open, "", js_trigger + " " 56 | 57 | 58 | @app.callback( 59 | [ 60 | Output("compare_models_rows", "children", allow_duplicate=True), 61 | Output("sorting_container", "children"), 62 | Output("loading_container", "children", allow_duplicate=True), 63 | ], 64 | Input({"type": "apply_sorting_button", "id": -1}, "n_clicks"), 65 | [ 66 | State({"type": "sorting_function_input", "id": -1}, "value"), 67 | State({"type": "model_selector", "id": ALL}, "value"), 68 | State("base_model_answers_selector", "value"), 69 | State("loading_container", "children"), 70 | ], 
71 | prevent_initial_call=True, 72 | ) 73 | def sorting_data( 74 | n_ckicks: int, 75 | sorting_function: str, 76 | models: List[str], 77 | base_model: str, 78 | loading_container: str, 79 | ) -> Tuple[List[html.Tr], bool]: 80 | if base_model == CHOOSE_GENERATION or not sorting_function: 81 | return no_update, no_update, no_update 82 | return ( 83 | get_sorted_tables_layout( 84 | base_model=base_model, 85 | sorting_function=sorting_function, 86 | models=models, 87 | ), 88 | html.Pre(f"Sorting function:\n{sorting_function}"), 89 | loading_container + " ", 90 | ) 91 | -------------------------------------------------------------------------------- /nemo_inspector/utils/decoration/code.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from html import escape 16 | from io import StringIO 17 | from typing import Dict, List, Tuple 18 | 19 | from dash import html 20 | from pygments.formatters import HtmlFormatter 21 | from pygments.lexers import PythonLexer 22 | 23 | from nemo_inspector.utils.decoration.common import ( 24 | get_random_id, 25 | iframe_template, 26 | update_height_js, 27 | ) 28 | 29 | 30 | def highlight_code(codes: List[Tuple[str, Dict[str, str]]], **kwargs) -> html.Iframe: 31 | 32 | full_code = "".join([code for code, style in codes]) 33 | 34 | # Track positions and styles 35 | positions = [] 36 | current_pos = 0 37 | for code, style in codes: 38 | start_pos = current_pos 39 | end_pos = current_pos + len(code) 40 | if style: 41 | positions.append((start_pos, end_pos, style)) 42 | current_pos = end_pos 43 | 44 | # Custom formatter to apply styles at correct positions 45 | class CustomHtmlFormatter(HtmlFormatter): 46 | def __init__(self, positions, **options): 47 | super().__init__(**options) 48 | self.positions = positions 49 | self.current_pos = 0 50 | 51 | def format(self, tokensource, outfile): 52 | style_starts = {start: style for start, _, style in self.positions} 53 | style_ends = {end: style for _, end, style in self.positions} 54 | active_styles = [] 55 | 56 | for ttype, value in tokensource: 57 | token_length = len(value) 58 | token_start = self.current_pos 59 | 60 | # Apply styles character by character 61 | result = "" 62 | for i, char in enumerate(value): 63 | char_pos = token_start + i 64 | 65 | # Check if a style starts or ends here 66 | if char_pos in style_starts: 67 | style = style_starts[char_pos] 68 | active_styles.append(style) 69 | if char_pos in style_ends: 70 | style = style_ends[char_pos] 71 | if style in active_styles: 72 | active_styles.remove(style) 73 | 74 | # Get CSS class for syntax highlighting 75 | css_class = self._get_css_class(ttype) 76 | char_html = escape(char) 77 | if css_class: 78 | char_html = f'{char_html}' 79 | 80 | # Apply active 
styles 81 | if active_styles: 82 | combined_style = {} 83 | for style_dict in active_styles: 84 | combined_style.update(style_dict) 85 | style_str = "; ".join( 86 | f"{k}: {v}" for k, v in combined_style.items() 87 | ) 88 | char_html = f'{char_html}' 89 | 90 | result += char_html 91 | 92 | outfile.write(result) 93 | self.current_pos += token_length 94 | 95 | # Use the custom formatter to highlight the code 96 | lexer = PythonLexer() 97 | formatter = CustomHtmlFormatter(positions, nowrap=True) 98 | style_defs = formatter.get_style_defs(".highlight") 99 | style_defs += """ 100 | .highlight { 101 | font-family: monospace; 102 | } 103 | """ 104 | 105 | output = StringIO() 106 | formatter.format(lexer.get_tokens(full_code), output) 107 | highlighted_code = output.getvalue() 108 | 109 | # Build the iframe content 110 | iframe_id = get_random_id() 111 | content = f""" 112 |
{highlighted_code}
113 | 114 | """ 115 | 116 | return html.Div( 117 | iframe_template( 118 | header=f"", 119 | content=content, 120 | iframe_id=iframe_id, 121 | style={"border": "black 1px solid", "background-color": "#ebecf0d8"}, 122 | ) 123 | ) 124 | -------------------------------------------------------------------------------- /nemo_inspector/utils/decoration/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import random 16 | import re 17 | import string 18 | from typing import Dict, List, Union 19 | 20 | from ansi2html import Ansi2HTMLConverter 21 | from dash import dcc, html 22 | from nemo_inspector.settings.constants import ANSI, COMPARE, LATEX, MARKDOWN 23 | from nemo_inspector.utils.decoration.latex import preprocess_latex 24 | 25 | 26 | def design_text_output( 27 | texts: List[Union[str, str]], style={}, text_modes: List[str] = [LATEX, ANSI] 28 | ) -> html.Div: 29 | conv = Ansi2HTMLConverter() 30 | ansi_escape = re.compile(r"\x1b\[[0-9;]*m") 31 | full_text = "".join(map(lambda x: x[0], texts)) 32 | if ANSI in text_modes: 33 | if ( 34 | bool(ansi_escape.search(full_text)) 35 | or "ipython-input" in full_text 36 | or "Traceback" in full_text 37 | ): 38 | if bool(ansi_escape.search(full_text)): 39 | full_text = conv.convert(full_text, full=False) 40 | else: 41 | full_text = conv.convert(full_text.replace("[", "\u001b["), full=False) 42 | return html.Div( 43 | iframe_template( 44 | '', 45 | f"
{full_text}
", 46 | ), 47 | style=style, 48 | ) 49 | return html.Div( 50 | ( 51 | dcc.Markdown( 52 | preprocess_latex(full_text, escape=MARKDOWN not in text_modes), 53 | mathjax=True, 54 | dangerously_allow_html=True, 55 | ) 56 | if LATEX in text_modes and COMPARE not in text_modes 57 | else ( 58 | dcc.Markdown(full_text) 59 | if MARKDOWN in text_modes and COMPARE not in text_modes 60 | else [ 61 | html.Span(text, style={**inner_style, "whiteSpace": "pre-wrap"}) 62 | for text, inner_style in texts 63 | ] 64 | ) 65 | ), 66 | style=style, 67 | ) # TODO make classes 68 | 69 | 70 | def get_height_adjustment() -> html.Iframe: 71 | return html.Iframe( 72 | id="query_params_iframe", 73 | srcDoc=""" 74 | 75 | 76 | 77 | 78 | 79 | 84 | 85 | 86 | """, 87 | style={"visibility": "hidden"}, 88 | ) 89 | 90 | 91 | def update_height_js(iframe_id: str) -> str: 92 | return f""" 93 | function updateHeight() {{ 94 | var body = document.body, 95 | html = document.documentElement; 96 | 97 | var height = Math.max(body.scrollHeight, body.offsetHeight, 98 | html.clientHeight, html.scrollHeight, html.offsetHeight); 99 | 100 | parent.postMessage({{ frameHeight: height, frameId: '{iframe_id}' }}, '*'); 101 | }} 102 | window.onload = updateHeight; 103 | window.onresize = updateHeight; 104 | """ 105 | 106 | 107 | def get_random_id() -> str: 108 | return "".join(random.choices(string.ascii_letters + string.digits, k=20)) 109 | 110 | 111 | def iframe_template( 112 | header: str, content: str, style: Dict = {}, iframe_id: str = None 113 | ) -> html.Iframe: 114 | if not iframe_id: 115 | iframe_id = get_random_id() 116 | 117 | iframe_style = { 118 | "width": "100%", 119 | "border": "none", 120 | "overflow": "hidden", 121 | } 122 | 123 | iframe_style.update(style) 124 | 125 | return html.Iframe( 126 | id=iframe_id, 127 | srcDoc=f""" 128 | 129 | 130 | 131 | {header} 132 | 133 | 134 | {content} 135 | 136 | 137 | """, 138 | style=iframe_style, 139 | ) 140 | 
-------------------------------------------------------------------------------- /nemo_inspector/callbacks/analyze_page/count_stats_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | from typing import List 17 | 18 | from dash import no_update 19 | from dash.dependencies import Input, Output, State 20 | 21 | from nemo_inspector.callbacks import app 22 | from nemo_inspector.layouts import ( 23 | get_tables_layout, 24 | get_stats_input, 25 | ) 26 | from nemo_inspector.settings.constants import ( 27 | CHOOSE_GENERATION, 28 | DELETE, 29 | ERROR_MESSAGE_TEMPLATE, 30 | GENERAL_STATS, 31 | INLINE_STATS, 32 | ) 33 | from nemo_inspector.utils.common import ( 34 | calculate_metrics_for_whole_data, 35 | get_custom_stats, 36 | get_deleted_stats, 37 | get_general_custom_stats, 38 | get_stats_raw, 39 | get_table_data, 40 | ) 41 | 42 | 43 | @app.callback( 44 | [ 45 | Output("new_stats", "is_open"), 46 | Output("stats_input_container", "children", allow_duplicate=True), 47 | Output("js_container", "children", allow_duplicate=True), 48 | Output("js_trigger", "children", allow_duplicate=True), 49 | ], 50 | [ 51 | Input("set_new_stats_button", "n_clicks"), 52 | Input("apply_new_stats", "n_clicks"), 53 | ], 54 | [ 55 | State("new_stats", "is_open"), 56 | 
State("stats_modes", "value"), 57 | State("js_trigger", "children"), 58 | ], 59 | prevent_initial_call=True, 60 | ) 61 | def toggle_modal_stats( 62 | n1: int, n2: int, is_open: bool, modes: List[str], js_trigger: str 63 | ) -> bool: 64 | if not n1 and not n2: 65 | return no_update, no_update, no_update, no_update 66 | 67 | if n1 or n2: 68 | is_open = not is_open 69 | return is_open, get_stats_input(modes), "", js_trigger + " " 70 | return is_open, get_stats_input(modes), "", js_trigger + " " 71 | 72 | 73 | @app.callback( 74 | Output("compare_models_rows", "children", allow_duplicate=True), 75 | Input("apply_new_stats", "n_clicks"), 76 | [ 77 | State("stats_input", "value"), 78 | State("base_model_answers_selector", "value"), 79 | State("stats_modes", "value"), 80 | ], 81 | prevent_initial_call=True, 82 | ) 83 | def apply_new_stat( 84 | n_click: int, 85 | code_raw: str, 86 | base_model: str, 87 | stats_modes: List[str], 88 | ) -> List: 89 | if not n_click or code_raw == "": 90 | return no_update 91 | code_raw_lines = code_raw.strip().split("\n") 92 | if not stats_modes or DELETE not in stats_modes: 93 | code = "\n".join(code_raw_lines[:-1]) + "\nnew_stats = " + code_raw_lines[-1] 94 | else: 95 | code = "delete_stats = " + f"'{code_raw_lines[-1]}'" 96 | namespace = {} 97 | try: 98 | exec(code, namespace) 99 | except Exception as e: 100 | logging.error(ERROR_MESSAGE_TEMPLATE.format(code, str(e))) 101 | return no_update 102 | if stats_modes and GENERAL_STATS in stats_modes: 103 | if DELETE in stats_modes: 104 | get_general_custom_stats().pop(namespace["delete_stats"], None) 105 | else: 106 | get_general_custom_stats().update(namespace["new_stats"]) 107 | get_stats_raw()[GENERAL_STATS][ 108 | " ".join(namespace["new_stats"].keys()) 109 | ] = code_raw 110 | else: 111 | if stats_modes and DELETE in stats_modes: 112 | get_custom_stats().pop(namespace["delete_stats"], None) 113 | get_deleted_stats().update(namespace["delete_stats"]) 114 | else: 115 | 
get_custom_stats().update(namespace["new_stats"]) 116 | get_stats_raw()[INLINE_STATS][ 117 | " ".join(namespace["new_stats"].keys()) 118 | ] = code_raw 119 | if base_model == CHOOSE_GENERATION: 120 | return [] 121 | calculate_metrics_for_whole_data(get_table_data(), base_model) 122 | return get_tables_layout(base_model=base_model) 123 | 124 | 125 | @app.callback( 126 | [ 127 | Output("stats_input", "value", allow_duplicate=True), 128 | Output("js_container", "children", allow_duplicate=True), 129 | Output("js_trigger", "children", allow_duplicate=True), 130 | ], 131 | Input("stats_extractor", "value"), 132 | [ 133 | State("stats_modes", "value"), 134 | State("js_trigger", "children"), 135 | ], 136 | prevent_initial_call=True, 137 | ) 138 | def apply_new_stat(stat: str, stats_modes: List[str], js_trigger: str) -> List: 139 | mode = GENERAL_STATS if GENERAL_STATS in stats_modes else INLINE_STATS 140 | return get_stats_raw()[mode][stat], " ", js_trigger + " " 141 | -------------------------------------------------------------------------------- /nemo_inspector/parse_agruments_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import argparse 16 | import dataclasses 17 | from typing import Any, Optional, Type 18 | 19 | from nemo_inspector.utils.common import get_type_default, resolve_union_or_any 20 | 21 | 22 | class ParseDict(argparse.Action): 23 | def __call__(self, parser, namespace, values, option_string=None): 24 | setattr(namespace, self.dest, dict()) 25 | for value in values: 26 | key, value = value.split("=") 27 | getattr(namespace, self.dest)[key] = value 28 | 29 | 30 | def add_arguments_from_dataclass( 31 | parser: argparse.ArgumentParser, 32 | dataclass_type: Type[Any], 33 | prefix: str = "", 34 | use_default: Optional[str] = None, 35 | enforce_required: bool = True, 36 | use_type_defaults: bool = True, 37 | ): 38 | for field in dataclasses.fields(dataclass_type): 39 | field_name = f"{prefix}{field.name}" 40 | field_type = field.type 41 | # Handle default values and default_factory 42 | has_default = field.default != dataclasses.MISSING 43 | has_default_factory = field.default_factory != dataclasses.MISSING 44 | 45 | # Handle nested dataclasses 46 | if dataclasses.is_dataclass(field_type): 47 | add_arguments_from_dataclass( 48 | parser, 49 | field_type, 50 | prefix=f"{field_name}.", 51 | use_default=use_default, 52 | enforce_required=enforce_required, 53 | use_type_defaults=use_type_defaults, 54 | ) 55 | continue 56 | 57 | arg_name = f"--{field_name}" 58 | 59 | # Handle Optional and Union types 60 | field_type = resolve_union_or_any(field_type) 61 | 62 | # Determine if the argument is required 63 | is_required = not has_default and not has_default_factory and use_default 64 | 65 | # Helper function to add the argument to the parser 66 | def add_argument(**kwargs): 67 | # Ensure default is passed only once 68 | if "default" not in kwargs: 69 | if use_type_defaults and field.default == dataclasses.MISSING: 70 | kwargs["default"] = get_type_default(field_type) 71 | elif has_default: 72 | kwargs["default"] = field.default 73 | kwargs["required"] = kwargs.get("required", 
is_required and enforce_required) 74 | parser.add_argument(arg_name, **kwargs) 75 | 76 | # Add argument based on the type of the field 77 | default_message = f"(default: {str(field.default)})" if has_default else "" 78 | kwargs = {"default": use_default} if use_default else {} 79 | if field_type == bool: 80 | add_argument( 81 | action="store_true", 82 | help=f"{field_name} flag {default_message}", 83 | **kwargs, 84 | ) 85 | elif field_type == dict: 86 | add_argument( 87 | nargs="*", 88 | action=ParseDict, 89 | help=f"{field_name} flag {default_message}", 90 | **kwargs, 91 | ) 92 | else: 93 | add_argument( 94 | type=field_type, help=f"{field_name} flag {default_message}", **kwargs 95 | ) 96 | 97 | 98 | def create_dataclass_from_args( 99 | dataclass_type: Type[Any], args_dict: dict, prefix: str = "" 100 | ): 101 | init_kwargs = {} 102 | for field in dataclasses.fields(dataclass_type): 103 | field_name = f"{prefix}{field.name}" 104 | field_type = field.type 105 | # Handle nested dataclasses 106 | if dataclasses.is_dataclass(field_type): 107 | value = create_dataclass_from_args( 108 | field_type, 109 | args_dict, 110 | prefix=f"{field_name}.", 111 | ) 112 | init_kwargs[field.name] = value 113 | continue 114 | 115 | arg_name = field_name 116 | if arg_name in args_dict: 117 | value = args_dict[arg_name] 118 | else: 119 | if field.default != dataclasses.MISSING: 120 | value = field.default 121 | elif field.default_factory != dataclasses.MISSING: 122 | value = field.default_factory() 123 | else: 124 | value = get_type_default(field_type) 125 | init_kwargs[field.name] = value 126 | return dataclass_type(**init_kwargs) 127 | 128 | 129 | def args_postproccessing(args): 130 | args["input_file"] = str(args.get("input_file", "")) 131 | 132 | return args 133 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/analyze_page/filter_dataset.py: -------------------------------------------------------------------------------- 1 | 
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | from typing import List, Tuple 17 | 18 | from dash import ALL, callback_context, html, no_update 19 | from dash.dependencies import Input, Output, State 20 | 21 | from nemo_inspector.callbacks import app 22 | from nemo_inspector.layouts import ( 23 | get_filtered_tables_layout, 24 | get_filter_text, 25 | get_sorted_tables_layout, 26 | ) 27 | from nemo_inspector.settings.constants import ( 28 | CHOOSE_GENERATION, 29 | FILES_FILTERING, 30 | QUESTIONS_FILTERING, 31 | ) 32 | from nemo_inspector.utils.common import get_table_data 33 | 34 | 35 | @app.callback( 36 | [ 37 | Output({"type": "filter", "id": ALL}, "is_open"), 38 | Output("js_container", "children", allow_duplicate=True), 39 | Output("js_trigger", "children", allow_duplicate=True), 40 | ], 41 | [ 42 | Input({"type": "set_filter_button", "id": ALL}, "n_clicks"), 43 | Input({"type": "apply_filter_button", "id": ALL}, "n_clicks"), 44 | ], 45 | [ 46 | State({"type": "filter", "id": ALL}, "is_open"), 47 | State("js_trigger", "children"), 48 | ], 49 | prevent_initial_call=True, 50 | ) 51 | def toggle_modal_filter(n1: int, n2: int, is_open: bool, js_trigger: str) -> bool: 52 | ctx = callback_context 53 | if not ctx.triggered: 54 | return [no_update] * len(is_open), no_update, no_update 55 | button_id = 
json.loads(ctx.triggered[-1]["prop_id"].split(".")[0])["id"] + 1 56 | if not ctx.triggered[0]["value"]: 57 | return [no_update] * len(is_open), "", js_trigger + " " 58 | 59 | if n1[button_id] or n2[button_id]: 60 | is_open[button_id] = not is_open[button_id] 61 | return is_open, "", js_trigger + " " 62 | return is_open, "", js_trigger + " " 63 | 64 | 65 | @app.callback( 66 | [ 67 | Output("compare_models_rows", "children", allow_duplicate=True), 68 | Output("filtering_container", "children"), 69 | Output("loading_container", "children", allow_duplicate=True), 70 | ], 71 | [ 72 | Input({"type": "apply_filter_button", "id": -1}, "n_clicks"), 73 | ], 74 | [ 75 | State({"type": "filter_function_input", "id": -1}, "value"), 76 | State({"type": "apply_on_filtered_data", "id": -1}, "value"), 77 | State({"type": "filter_mode", "id": -1}, "value"), 78 | State({"type": "sorting_function_input", "id": -1}, "value"), 79 | State({"type": "model_selector", "id": ALL}, "value"), 80 | State("base_model_answers_selector", "value"), 81 | State("filtering_container", "children"), 82 | State("loading_container", "children"), 83 | ], 84 | prevent_initial_call=True, 85 | ) 86 | def filter_data( 87 | n_ckicks: int, 88 | filter_function: str, 89 | apply_on_filtered_data: int, 90 | filter_mode: List[str], 91 | sorting_function: str, 92 | models: List[str], 93 | base_model: str, 94 | filtering_functions: str, 95 | loading_container: str, 96 | ) -> Tuple[List[html.Tr], bool]: 97 | if not n_ckicks: 98 | return no_update, no_update, no_update 99 | if apply_on_filtered_data and filtering_functions: 100 | filtering_functions["props"]["children"] += f"\n{filter_function}" 101 | if base_model == CHOOSE_GENERATION: 102 | return [], no_update, no_update 103 | if len(get_table_data()) == 0: # TODO fix 104 | models = [models[0]] 105 | get_filtered_tables_layout( 106 | base_model=base_model, 107 | filtering_function=filter_function, 108 | filter_mode=( 109 | FILES_FILTERING if filter_mode and 
len(filter_mode) else QUESTIONS_FILTERING 110 | ), 111 | apply_on_filtered_data=(apply_on_filtered_data if apply_on_filtered_data else 0), 112 | models=models, 113 | ) 114 | rows = get_sorted_tables_layout( 115 | base_model=base_model, 116 | sorting_function=sorting_function, 117 | models=models, 118 | ) 119 | return ( 120 | rows, 121 | ( 122 | html.Pre(f"Filtering function:\n{filter_function}") 123 | if not apply_on_filtered_data or not filtering_functions 124 | else filtering_functions 125 | ), 126 | loading_container + " ", 127 | ) 128 | 129 | 130 | @app.callback( 131 | [ 132 | Output( 133 | {"type": "filter_text", "id": -1}, 134 | "children", 135 | allow_duplicate=True, 136 | ), 137 | Output("js_container", "children", allow_duplicate=True), 138 | Output("js_trigger", "children", allow_duplicate=True), 139 | ], 140 | Input({"type": "filter_mode", "id": -1}, "value"), 141 | State("js_trigger", "children"), 142 | prevent_initial_call=True, 143 | ) 144 | def change_filter_mode(modes: List[str], js_trigger: str) -> str: 145 | if modes is None: 146 | return no_update, no_update, no_update 147 | mode = FILES_FILTERING if modes and len(modes) else QUESTIONS_FILTERING 148 | text = get_filter_text(mode=mode) 149 | return ( 150 | text, 151 | "", 152 | js_trigger + " ", 153 | ) 154 | -------------------------------------------------------------------------------- /nemo_inspector/layouts/common_layouts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import itertools 16 | from typing import Dict, Iterable, List, Optional, Union 17 | 18 | import dash_bootstrap_components as dbc 19 | from dash import html 20 | 21 | from nemo_inspector.settings.constants import ( 22 | ANSI, 23 | CODE, 24 | COMPARE, 25 | LATEX, 26 | MARKDOWN, 27 | ) 28 | from nemo_inspector.utils.common import parse_model_answer 29 | from nemo_inspector.utils.decoration import ( 30 | color_text_diff, 31 | design_text_output, 32 | highlight_code, 33 | ) 34 | 35 | 36 | def get_switch_layout( 37 | id: Union[Dict, str], 38 | labels: List[str], 39 | values: Optional[List[str]] = None, 40 | disabled: List[bool] = [False], 41 | is_active: bool = False, 42 | chosen_values: Optional[List[str]] = None, 43 | additional_params: Dict = {}, 44 | ) -> dbc.Checklist: 45 | if values is None: 46 | values = labels 47 | return dbc.Checklist( 48 | id=id, 49 | options=[ 50 | { 51 | "label": label, 52 | "value": value, 53 | "disabled": is_disabled, 54 | } 55 | for label, value, is_disabled in itertools.zip_longest( 56 | labels, values, disabled, fillvalue=False 57 | ) 58 | ], 59 | value=(chosen_values if chosen_values else [values[0]] if is_active else []), 60 | **additional_params, 61 | ) 62 | 63 | 64 | def get_selector_layout(options: Iterable, id: str, value: str = "") -> dbc.Select: 65 | if value not in options: 66 | options = [value] + list(options) 67 | return dbc.Select( 68 | id=id, 69 | options=[ 70 | { 71 | "label": str(value), 72 | "value": value, 73 | } 74 | for value in options 75 | ], 76 | value=str(value), 77 
| ) 78 | 79 | 80 | def get_single_prompt_output_layout( 81 | answer: str, text_modes: List[str] = [CODE, LATEX, ANSI], compare_to: str = "" 82 | ) -> List[html.Div]: 83 | parsed_answers = ( 84 | parse_model_answer(answer) 85 | if CODE in text_modes 86 | else [{"explanation": answer, "code": None, "output": None}] 87 | ) 88 | parsed_compared_answers = ( 89 | ( 90 | parse_model_answer(compare_to) 91 | if CODE in text_modes 92 | else [{"explanation": compare_to, "code": None, "output": None}] 93 | ) 94 | if COMPARE in text_modes 95 | else parsed_answers 96 | ) 97 | 98 | items = [] 99 | styles = { 100 | "explanation": {"default": {}, "wrong": {}}, 101 | "code": {"default": {}, "wrong": {}}, 102 | "output": { 103 | "default": { 104 | "border": "1px solid black", 105 | "background-color": "#cdd4f1c8", 106 | "marginBottom": "10px", 107 | "marginTop": "-6px", 108 | }, 109 | "wrong": { 110 | "border": "1px solid red", 111 | "marginBottom": "10px", 112 | "marginTop": "-6px", 113 | }, 114 | }, 115 | } 116 | 117 | functions = { 118 | "explanation": design_text_output, 119 | "code": highlight_code, 120 | "output": design_text_output, 121 | } 122 | 123 | def check_existence(array: List[Dict[str, str]], i: int, key: str): 124 | return i < len(array) and key in array[i] and array[i][key] 125 | 126 | for i in range(max(len(parsed_answers), len(parsed_compared_answers))): 127 | for key in ["explanation", "code", "output"]: 128 | if check_existence(parsed_answers, i, key) or check_existence( 129 | parsed_compared_answers, i, key 130 | ): 131 | diff = color_text_diff( 132 | ( 133 | parsed_answers[i][key] 134 | if check_existence(parsed_answers, i, key) 135 | else "" 136 | ), 137 | ( 138 | parsed_compared_answers[i][key] 139 | if check_existence(parsed_compared_answers, i, key) 140 | else "" 141 | ), 142 | ) 143 | style_type = ( 144 | "default" 145 | if not check_existence(parsed_answers, i, key) 146 | or "wrong_code_block" not in parsed_answers[i][key] 147 | else "wrong" 148 | ) 149 | 
style = styles[key][style_type] 150 | item = functions[key](diff, style=style, text_modes=text_modes) 151 | items.append(item) 152 | return items 153 | 154 | 155 | def get_text_modes_layout(id: str, is_formatted: bool = True): 156 | return get_switch_layout( 157 | id={ 158 | "type": "text_modes", 159 | "id": id, 160 | }, 161 | labels=[CODE, LATEX, MARKDOWN, ANSI], 162 | chosen_values=[CODE, LATEX, ANSI] if is_formatted else [], 163 | additional_params={ 164 | "style": { 165 | "display": "inline-flex", 166 | "flex-wrap": "wrap", 167 | }, 168 | "inputStyle": {"margin-left": "-10px"}, 169 | "labelStyle": {"margin-left": "3px"}, 170 | }, 171 | ) 172 | -------------------------------------------------------------------------------- /nemo_inspector/utils/decoration/latex.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from typing import Callable, Optional, Tuple


def get_starts_with_latex_tag_function(
    tag: str, default_index_move: int
) -> Callable[[str, int], Tuple[bool, int]]:
    """Build a detector for ``tag`` at a given position of a text.

    The returned function maps ``(text, index)`` to ``(matched, next_index)``:

    * no match: ``next_index = index + default_index_move``;
    * match of a tag without ``{``: ``next_index`` is just past the tag;
    * match of a tag containing ``{`` (e.g. ``\\begin{``): ``next_index`` is the
      position of the first ``}`` at or after ``index``, or ``len(text)`` if
      there is none.
    """

    def starts_with_tag_func_templ(text: str, index: int) -> Tuple[bool, int]:
        is_starts_with_tag = text.startswith(tag, index)
        if not is_starts_with_tag:
            returning_index = index + default_index_move
        elif "{" not in tag:
            returning_index = index + len(tag)
        else:
            # str.find gives -1 when no "}" follows; -1 % (len(text) + 1) turns
            # that into len(text), which ends the scan in proccess_latex_tag.
            returning_index = text.find("}", index) % (len(text) + 1)

        return is_starts_with_tag, returning_index

    return starts_with_tag_func_templ


def proccess_latex_tag(
    text: str,
    start_index: int,
    detect_start_token: Callable[[str, int], Tuple[bool, int]],
    detect_end_token: Callable[[str, int], Tuple[bool, int]],
    end_sign: Optional[str],
    last_block_only: bool = False,
) -> Tuple[int, int]:
    """Find the extent of the LaTeX block opening at ``start_index``.

    Scans forward, increasing a counter on every opening token and decreasing
    it on every closing token, and stops once the counter is back to zero.
    Nested blocks are therefore handled.

    Args:
        text: text being scanned.
        start_index: position of the opening token.
        detect_start_token: opener detector returning ``(matched, next_index)``.
        detect_end_token: closer detector returning ``(matched, next_index)``.
        end_sign: if set and met before the block closes, the scan is abandoned
            and the returned span covers only the character at the block start.
        last_block_only: if True, every new opener moves the block start to its
            position (the counter is capped at 1), so only the last opener
            before the closer is kept.

    Returns:
        ``(block_start, block_end)`` where ``block_end`` is one past the
        position reported by the closing detector.
    """
    count = 0
    index = start_index
    while index < len(text):
        if end_sign and text[index] == end_sign:
            return start_index, start_index + 1
        is_start_token, new_index = detect_start_token(text, index)
        count += is_start_token
        if last_block_only and is_start_token:
            start_index = index
            count = min(1, count)
        index = new_index
        # The closer is checked at the position the opener check advanced to,
        # so a closer right after an opener is seen in the same iteration.
        is_end_token, index = detect_end_token(text, index)
        count -= is_end_token
        if count == 0:
            break
    return start_index, index + 1


def get_single_dollar_detection_functions(
    direction: int, default_index_move: int
) -> Callable[[str, int], Tuple[bool, int]]:
    """Build a detector for a single ``$`` delimiter.

    It matches when ``text[index]`` is ``$`` and the neighbour at
    ``index + direction`` is not whitespace (``1`` looks after the dollar, ``-1``
    before it). It always advances by ``default_index_move``. The neighbour is
    read only after ``text[index] == "$"`` succeeds (``and`` short-circuits).
    """
    return lambda text, index: (
        text[index] == "$" and not text[index + direction].isspace(),
        index + default_index_move,
    )


def get_latex_detection_functions(text: str, index: int) -> tuple[
    Callable[[str, int], Tuple[bool, int]],
    Callable[[str, int], Tuple[bool, int]],
    Optional[str],
    bool,
    bool,
]:
    """Choose detectors for the LaTeX construct starting at ``index``.

    Returns ``(detect_start, detect_end, end_sign, add_dollars, last_block_only)``:

    * ``\\begin{`` environment: begin/end detectors; the block is later wrapped
      in ``$$`` (``add_dollars=True``);
    * ``$$`` display math: ``$$`` start/end detectors, no extra wrapping;
    * ``$`` followed by a non-space: single-dollar detectors, abandoned at a
      newline (``end_sign="\\n"``) and keeping only the last opener;
    * anything else: five ``None`` values.
    """
    multiline_tags = [("\\begin{", "\\end{", True), ("$$", "$$", False)]
    for start_tag, end_tag, add_dollars in multiline_tags:
        if text.startswith(start_tag, index):
            return (
                get_starts_with_latex_tag_function(start_tag, 1),
                get_starts_with_latex_tag_function(end_tag, 0),
                None,
                add_dollars,
                False,
            )

    starts_with_dollar_func = get_single_dollar_detection_functions(1, 1)
    ends_with_dollar_func = get_single_dollar_detection_functions(-1, 0)
    if starts_with_dollar_func(text, index)[0]:
        return starts_with_dollar_func, ends_with_dollar_func, "\n", False, True

    return None, None, None, None, None


def proccess_plain_text(text: str) -> str:
    """Backslash-escape the Markdown characters ``*_{}[]()#+-.!` `` in ``text``."""
    special_chars = r"*_{}[]()#+-.!`"
    for character in special_chars:
        text = text.replace(character, "\\" + character)
    return text


def preprocess_latex(text: str, escape: bool = True) -> str:
    """Turn text containing LaTeX into Markdown for a MathJax-enabled renderer.

    Steps:

    1. ``\\[`` / ``\\]`` become ``$$`` on their own lines; ``\\(`` / ``\\)``
       become inline ``$`` delimiters.
    2. A space is inserted between ``$`` and a directly adjacent ``-=+*/``
       operator (``-`` only when it precedes the ``$``).
    3. The text is scanned for LaTeX blocks (see
       ``get_latex_detection_functions``). Text outside blocks is escaped with
       ``proccess_plain_text`` when ``escape`` is True, and ``\\begin{``
       environments are wrapped in ``$$``.
    4. Every newline is doubled and the result is stripped.

    Args:
        text: raw text to convert.
        escape: whether to escape Markdown characters outside LaTeX blocks.

    Returns:
        The converted text.
    """
    # Pad with a newline on each side: the scan below runs over indices
    # 1 .. len(text) - 2, so single-dollar look-behind/look-ahead stays in range.
    text = "\n" + text.replace("\\[", "\n$$\n").replace("\\]", "\n$$\n").replace(
        "\\(", " $"
    ).replace("\\)", "$ ")

    right_side_operations = ["-", "=", "+", "*", "/"]
    left_side_operations = ["=", "+", "*", "/"]
    for op in right_side_operations:
        text = text.replace(op + "$", op + " $")

    for op in left_side_operations:
        text = text.replace("$" + op, "$ " + op)

    text += "\n"
    index = 1
    texts = []
    # Start of the current run of plain text, or -1 when not inside one.
    start_plain_text_index = -1
    while index < len(text) - 1:
        (
            detect_start_token,
            detect_end_token,
            end_sign,
            add_dollars,
            use_last_block_only,
        ) = get_latex_detection_functions(text, index)
        if detect_start_token is not None:
            # Flush the plain text collected before this block.
            if start_plain_text_index != -1:
                texts.append(
                    proccess_plain_text(text[start_plain_text_index:index])
                    if escape
                    else text[start_plain_text_index:index]
                )
                start_plain_text_index = -1

            start_index, new_index = proccess_latex_tag(
                text,
                index,
                detect_start_token,
                detect_end_token,
                end_sign,
                use_last_block_only,
            )
            # With use_last_block_only the block may start after ``index``; the
            # skipped prefix is treated as plain text.
            texts.append(
                proccess_plain_text(text[index:start_index])
                if escape
                else text[index:start_index]
            )
            if add_dollars:
                texts.append("\n$$\n")
                texts.append(text[start_index:new_index].strip())
                texts.append("\n$$\n")
            else:
                texts.append(text[start_index:new_index])
            index = new_index
        elif start_plain_text_index == -1:
            start_plain_text_index = index
            index += 1
        else:
            index += 1
    if start_plain_text_index != -1:
        texts.append(
            proccess_plain_text(text[start_plain_text_index:])
            if escape
            else text[start_plain_text_index:]
        )
    return "".join(texts).replace("\n", "\n\n").strip()

-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeMo Inspector 2 | 3 | NeMo Inspector is a tool designed to help you analyze Large Language Model (LLM) generations. It lets you explore and manipulate existing generations, apply filters, sorting criteria, and compute statistics. 4 | 5 | ## Prerequisites 6 | 7 | 1. **Clone and Install the Tool:** 8 | ```shell 9 | git clone git@github.com:NVIDIA/NeMo-Inspector.git 10 | cd NeMo-Inspector 11 | pip install . 12 | ``` 13 | 14 | 2. **Launch the Tool:** 15 | ```shell 16 | nemo_inspector 17 | ``` 18 | 19 | This will start a local server that you can access through your browser. 20 | 21 | ## Analyze Page 22 | 23 | The Analyze page helps you work with pre-generated outputs. To use it, provide paths to the generation files using command-line arguments.
For example: 24 | 25 | ```shell 26 | nemo_inspector --model_prediction \ 27 | generation1='/path/to/generation1/output-greedy.jsonl' \ 28 | generation2='/path/to/generation2/output-rs*.jsonl' 29 | ``` 30 | Once loaded, the Analyze page lets you: 31 | 32 | - **Sort and Filter Results:** Apply custom filtering and sorting functions to refine the displayed data. 33 | - **Compare Generations:** View outputs from multiple generation runs side-by-side. 34 | - **Modify and Label Data:** Update or annotate samples and save the changes for future reference. 35 | - **Compute Statistics:** Generate both custom and general statistics to summarize your data. 36 | 37 | ### Filtering 38 | 39 | The tool supports two filtering modes: **Filter Files** mode and **Filter Questions** mode. You can define custom filtering functions in Python and run them directly in the UI. 40 | 41 | #### Filter Files Mode 42 | 43 | - In this mode, the filtering function is evaluated for each sample, with that sample's data from all files available at once. 44 | - The input to the filtering function is a dictionary whose keys are generation names and whose values are the JSON objects for that sample. 45 | - The custom function should return a Boolean value (`True` to keep the sample, `False` to filter it out). 46 | 47 | Example of a custom filtering function: 48 | 49 | ```python 50 | def custom_filtering_function(error_message: str) -> bool: 51 | # Implement your logic here 52 | return 'timeout' not in error_message 53 | 54 | # This line will be used for the filtering: 55 | custom_filtering_function(data['generation1']['error_message']) 56 | ``` 57 | 58 | **Note:** Only the last line of the code block is used as the filter expression. All preceding lines only define helper functions or intermediate values. 59 | 60 | To combine conditions on different generations, separate them with `&&`.
For instance: 61 | 62 | ```python 63 | data['generation1']['is_correct'] && not data['generation2']['is_correct'] 64 | ``` 65 | 66 | **Important:** In Filter Files mode, a single condition must not reference more than one generation; write one condition per generation and join them with `&&`. 67 | 68 | #### Filter Questions Mode 69 | 70 | - In this mode, the function is evaluated once per question and decides whether that question is kept across all files; entire files are never filtered out. 71 | - The input is a dictionary that maps generation names to **lists** of JSON records for that question. 72 | 73 | In this mode, you write conditions without the `&&` operator. For example: 74 | 75 | ```python 76 | data['generation1'][0]['is_correct'] and not data['generation2'][0]['is_correct'] 77 | ``` 78 | 79 | This example keeps only the questions where the first generation is correct and the second generation is incorrect. You can also compare fields directly; the following keeps the questions where the two generations disagree: 80 | 81 | ```python 82 | data['generation1'][0]['is_correct'] != data['generation2'][0]['is_correct'] 83 | ``` 84 | 85 | **Note:** These examples cannot be used in Filter Files mode. 86 | 87 | ### Sorting 88 | 89 | Sorting functions are similar to filtering functions, but there are key differences: 90 | 91 | 1. **Scope:** Sorting functions operate on individual data entries (not dictionaries with multiple generations). 92 | 2. **Cross-Generations:** Sorting cannot be applied across multiple generations at once. You must sort one generation at a time. 93 | 94 | A correct sorting function might look like this: 95 | 96 | ```python 97 | def custom_sorting_function(generation: str): 98 | # Sort by the length of the generation text 99 | return len(generation) 100 | 101 | # This line will be used for the sorting: 102 | custom_sorting_function(data['generation']) 103 | ``` 104 | 105 | ### Statistics 106 | 107 | NeMo Inspector supports two types of statistics: 108 | 109 | 1. **Custom Statistics:** Applied to the samples of a single question (for each generation).
110 | 111 | Default custom statistics include: 112 | - `correct_responses` 113 | - `wrong_responses` 114 | - `no_responses` 115 | 116 | 2. **General Custom Statistics:** Applied across all questions and all generations. 117 | 118 | Default general custom statistics include: 119 | - `dataset size` 120 | - `overall number of samples` 121 | - `generations per sample` 122 | 123 | You can change the existing or define your own Custom and General Custom Statistics functions. 124 | 125 | **Custom Statistics Example:** 126 | 127 | ```python 128 | def unique_error_counter(datas): 129 | # `datas` is a list of JSONs (one per file) for a single question 130 | unique_errors = set() 131 | for data in datas: 132 | unique_errors.add(data.get('error_message')) 133 | return len(unique_errors) 134 | 135 | def number_of_runs(datas): 136 | return len(datas) 137 | 138 | # Map function names to functions 139 | {'unique_errors': unique_error_counter, 'number_of_runs': number_of_runs} 140 | ``` 141 | 142 | **General Custom Statistics Example:** 143 | 144 | ```python 145 | def overall_unique_error_counter(datas): 146 | # `datas` is a list of lists of dictionaries, 147 | # where datas[question_index][file_index] is a JSON record 148 | unique_errors = set() 149 | for question_data in datas: 150 | for file_data in question_data: 151 | unique_errors.add(file_data.get('error_message')) 152 | return len(unique_errors) 153 | 154 | # Map function names to functions 155 | {'unique_errors': overall_unique_error_counter} 156 | ``` 157 | 158 | **Note:** The final line in both the Custom and General Custom Statistics code blocks should be a dictionary mapping function names to their corresponding functions. 159 | 160 | ### Modifications 161 | 162 | You can update each sample in the dataset programmatically. 
The last line of the code block must be an expression that evaluates to the updated sample dictionary: 163 | 164 | ```python 165 | # For example, strip leading and trailing whitespace from the "generation" field 166 | {**data, 'generation': data['generation'].strip()} 167 | ``` 168 | -------------------------------------------------------------------------------- /nemo_inspector/layouts/analyze_page_layouts/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | from typing import List 16 | 17 | from dash import html 18 | import dash_ace 19 | 20 | from nemo_inspector.layouts.common_layouts import get_selector_layout 21 | from nemo_inspector.settings.constants import STATS_KEYS 22 | from nemo_inspector.utils.common import get_metrics, get_table_data 23 | from nemo_inspector.settings.constants.common import ( 24 | CUSTOM, 25 | DELETE, 26 | FILES_FILTERING, 27 | FILES_ONLY, 28 | GENERAL_STATS, 29 | INLINE_STATS, 30 | QUESTIONS_FILTERING, 31 | ) 32 | from nemo_inspector.utils.common import ( 33 | get_custom_stats, 34 | get_general_custom_stats, 35 | get_stats_raw, 36 | ) 37 | 38 | 39 | def get_filter_text( 40 | available_filters: List[str] = [], mode: str = FILES_FILTERING 41 | ) -> str: 42 | available_filters = list( 43 | get_table_data()[0][list(get_table_data()[0].keys())[0]][0].keys() 44 | if len(get_table_data()) and not available_filters 45 | else STATS_KEYS + list(get_metrics([]).keys()) + ["+ all fields in json"] 46 | ) 47 | if mode == FILES_ONLY: 48 | return ( 49 | "Write an expression to filter the data\n\n" 50 | + "For example:\ndata['is_correct'] and not data['error_message']\n\n" 51 | + "The expression has to return bool.\n\n" 52 | + "Available parameters to filter data:\n" 53 | + "\n".join( 54 | [ 55 | ", ".join(available_filters[start : start + 5]) 56 | for start in range(0, len(available_filters), 5) 57 | ] 58 | ), 59 | ) 60 | elif mode == FILES_FILTERING: 61 | return ( 62 | "Write an expression to filter the data\n" 63 | + "Separate expressions for different generations with &&\n" 64 | + "You can use base_generation variable to access data from the current generation\n\n" 65 | + "For example:\ndata['generation1']['correct_responses'] > 0.5 && data[base_generation]['no_response'] < 0.2\n\n" 66 | + "The expression has to return bool.\n\n" 67 | + "Available parameters to filter data:\n" 68 | + "\n".join( 69 | [ 70 | ", ".join(available_filters[start : start + 5]) 71 | for start in range(0, 
len(available_filters), 5) 72 | ] 73 | ), 74 | ) 75 | elif mode == QUESTIONS_FILTERING: 76 | return ( 77 | "Write an expression to filter the data\n" 78 | + "You can operate with a dictionary containing keys representing generation names\n" 79 | + "and a list of values as JSON data from your generation from each file.\n" 80 | + "You can use base_generation variable to access data from the current generation\n\n" 81 | + "For example:\ndata['generation1'][0]['is_correct'] != data[base_generation][0]['is_correct']\n\n" 82 | + "The expression has to return bool.\n\n" 83 | + "Available parameters to filter data:\n" 84 | + "\n".join( 85 | [ 86 | ", ".join(available_filters[start : start + 5]) 87 | for start in range(0, len(available_filters), 5) 88 | ] 89 | ), 90 | ) 91 | 92 | 93 | def get_stats_text(general_stats: bool = False, delete: bool = False): 94 | if delete: 95 | return "Choose the name of the statistic you want to delete" 96 | else: 97 | if general_stats: 98 | return ( 99 | "Creating General Custom Statistics:\n\n" 100 | "To introduce new general custom statistics:\n" 101 | "1. Create a dictionary where keys are the names of your custom stats.\n" 102 | "2. Assign functions as values. These functions should accept arrays where first dimension\n" 103 | "is a question index and second is a file number (both sorted and filtered).\n\n" 104 | "Example:\n\n" 105 | "Define a custom function to integrate into your stats:\n\n" 106 | "def my_func(datas):\n" 107 | " correct_responses = 0\n" 108 | " for question_data in datas:\n" 109 | " for file_data in question_data:\n" 110 | " correct_responses += file_data['is_correct']\n" 111 | " return correct_responses\n" 112 | "{'correct_responses': my_func}" 113 | ) 114 | else: 115 | return ( 116 | "Creating Custom Statistics:\n\n" 117 | "To introduce new custom statistics:\n" 118 | "1. Create a dictionary where keys are the names of your custom stats.\n" 119 | "2. Assign functions as values. 
These functions should accept arrays containing data\n" 120 | "from all relevant files.\n\n" 121 | "Note: Do not use names that already exist in the current stats or JSON fields\n" 122 | "to avoid conflicts.\n\n" 123 | "Example:\n\n" 124 | "Define a custom function to integrate into your stats:\n\n" 125 | "def unique_error_counter(datas):\n" 126 | " unique_errors = set()\n" 127 | " for data in datas:\n" 128 | " unique_errors.add(data.get('error_message'))\n" 129 | " return len(unique_errors)\n\n" 130 | "{'unique_error_count': unique_error_counter}" 131 | ) 132 | 133 | 134 | def get_code_text_area_layout(id) -> dash_ace.DashAceEditor: 135 | return dash_ace.DashAceEditor( 136 | id=id, 137 | theme="tomorrow_night", 138 | mode="python", 139 | tabSize=4, 140 | enableBasicAutocompletion=True, 141 | enableLiveAutocompletion=True, 142 | placeholder="Write your code here...", 143 | value="", 144 | style={"width": "100%", "height": "300px"}, 145 | ) 146 | 147 | 148 | def get_stats_input(modes: List[str] = []) -> List: 149 | body = [] 150 | if DELETE in modes: 151 | delete_options = list( 152 | get_general_custom_stats().keys() 153 | if GENERAL_STATS in modes 154 | else get_custom_stats().keys() 155 | ) 156 | body += [ 157 | get_selector_layout( 158 | delete_options, 159 | "stats_input", 160 | delete_options[0] if delete_options else "", 161 | ) 162 | ] 163 | else: 164 | mode = GENERAL_STATS if GENERAL_STATS in modes else INLINE_STATS 165 | extractor_options = list(get_stats_raw()[mode].keys()) 166 | body += [ 167 | get_selector_layout(extractor_options, "stats_extractor", CUSTOM), 168 | get_code_text_area_layout(id="stats_input"), 169 | ] 170 | return [ 171 | html.Pre( 172 | get_stats_text(GENERAL_STATS in modes, DELETE in modes), id="stats_text" 173 | ), 174 | ] + body 175 | -------------------------------------------------------------------------------- /nemo_inspector/callbacks/analyze_page/label_dataset.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | from typing import List, Tuple 17 | 18 | from dash import ALL, callback_context, no_update 19 | from dash.dependencies import Input, Output, State 20 | 21 | from nemo_inspector.callbacks import app 22 | from nemo_inspector.settings.constants import ( 23 | CHOOSE_LABEL, 24 | FILE_NAME, 25 | LABEL, 26 | LABEL_SELECTOR_ID, 27 | ) 28 | from nemo_inspector.utils.common import ( 29 | get_labels, 30 | get_table_data, 31 | ) 32 | 33 | 34 | @app.callback( 35 | Output({"type": "label", "id": ALL}, "is_open"), 36 | [ 37 | Input({"type": "set_file_label_button", "id": ALL}, "n_clicks"), 38 | Input({"type": "apply_label_button", "id": ALL}, "n_clicks"), 39 | Input({"type": "delete_label_button", "id": ALL}, "n_clicks"), 40 | ], 41 | [State({"type": "label", "id": ALL}, "is_open")], 42 | ) 43 | def toggle_modal_label(n1: int, n2: int, n3: int, is_open: bool) -> bool: 44 | ctx = callback_context 45 | if not ctx.triggered: 46 | return [no_update] * len(is_open) 47 | 48 | button_id = json.loads(ctx.triggered[-1]["prop_id"].split(".")[0])["id"] + 1 49 | if not ctx.triggered[0]["value"]: 50 | return [no_update] * len(is_open) 51 | 52 | if n1[button_id] or n2[button_id] or n3[button_id]: 53 | is_open[button_id] = not 
is_open[button_id] 54 | return is_open 55 | return is_open 56 | 57 | 58 | @app.callback( 59 | Output( 60 | "dummy_output", 61 | "children", 62 | allow_duplicate=True, 63 | ), 64 | [ 65 | Input({"type": "apply_label_button", "id": ALL}, "n_clicks"), 66 | Input({"type": "delete_label_button", "id": ALL}, "n_clicks"), 67 | ], 68 | [ 69 | State( 70 | {"type": "aplly_for_all_files", "id": ALL}, 71 | "value", 72 | ), 73 | State({"type": "label_selector", "id": ALL}, "value"), 74 | State({"type": "label_selector", "id": ALL}, "id"), 75 | State("datatable", "page_current"), 76 | State("datatable", "page_size"), 77 | State("datatable", "selected_rows"), 78 | State({"type": "model_selector", "id": ALL}, "value"), 79 | State("base_model_answers_selector", "value"), 80 | State({"type": "file_selector", "id": ALL}, "value"), 81 | State({"type": "file_selector", "id": ALL}, "options"), 82 | State( 83 | "dummy_output", 84 | "children", 85 | ), 86 | ], 87 | prevent_initial_call=True, 88 | ) 89 | def change_label( 90 | n_click_apply: List[int], 91 | n_click_del: List[int], 92 | apply_for_all: List[bool], 93 | labels: List[str], 94 | label_ids: List[int], 95 | current_page: int, 96 | page_size: int, 97 | idx: List[int], 98 | models: List[str], 99 | base_model: str, 100 | file_names: List[str], 101 | file_options: List[str], 102 | dummy_data: str, 103 | ) -> List[List[str]]: 104 | ctx = callback_context 105 | if not ctx.triggered: 106 | return no_update 107 | 108 | button_id = label_ids.index( 109 | json.loads( 110 | LABEL_SELECTOR_ID.format( 111 | json.loads(ctx.triggered[-1]["prop_id"].split(".")[0])["id"] 112 | ) 113 | ) 114 | ) 115 | is_apply = ( 116 | json.loads(ctx.triggered[-1]["prop_id"].split(".")[0])["type"] 117 | == "apply_label_button" 118 | ) 119 | if not ctx.triggered[0]["value"] or labels[button_id] == CHOOSE_LABEL: 120 | return no_update 121 | 122 | ALL_FILES = "ALL_FILES" 123 | if button_id == 0: 124 | files = [ALL_FILES] 125 | file = [ALL_FILES] 126 | 
models_to_process = [(base_model, files, file)] 127 | apply_for_all = [[True] * len(models)] 128 | question_ids = list(range(len(get_table_data()))) 129 | else: 130 | if not idx: 131 | return no_update 132 | models_to_process = [ 133 | ( 134 | models[button_id - 1], 135 | file_options[button_id - 1], 136 | file_names[button_id - 1], 137 | ) 138 | ] 139 | question_ids = [current_page * page_size + idx[0]] 140 | 141 | apply_for_all_files = bool(len(apply_for_all[button_id - 1])) 142 | for question_id in question_ids: 143 | for model, current_file_options, current_file in models_to_process: 144 | options = ( 145 | current_file_options 146 | if button_id != 0 147 | else [ 148 | {"value": file[FILE_NAME]} 149 | for file in get_table_data()[question_id][model] 150 | ] 151 | ) 152 | for file in options: 153 | if not apply_for_all_files and not file["value"] == current_file: 154 | continue 155 | 156 | file_id = 0 157 | for i, model_file in enumerate(get_table_data()[question_id][model]): 158 | if model_file[FILE_NAME] == file["value"]: 159 | file_id = i 160 | break 161 | 162 | if ( 163 | labels[button_id] 164 | not in get_table_data()[question_id][model][file_id][LABEL] 165 | ): 166 | if is_apply: 167 | get_table_data()[question_id][model][file_id][LABEL].append( 168 | labels[button_id] 169 | ) 170 | 171 | elif not is_apply: 172 | get_table_data()[question_id][model][file_id][LABEL].remove( 173 | labels[button_id] 174 | ) 175 | 176 | return dummy_data + "1" 177 | 178 | 179 | @app.callback( 180 | [ 181 | Output({"type": "new_label_input", "id": ALL}, "value"), 182 | Output({"type": "label_selector", "id": ALL}, "options"), 183 | Output({"type": "label_selector", "id": ALL}, "value"), 184 | ], 185 | Input({"type": "add_new_label_button", "id": ALL}, "n_clicks"), 186 | [ 187 | State({"type": "new_label_input", "id": ALL}, "value"), 188 | State({"type": "label_selector", "id": ALL}, "options"), 189 | State({"type": "label_selector", "id": ALL}, "value"), 190 | State({"type": 
"label_selector", "id": ALL}, "id"), 191 | ], 192 | ) 193 | def add_new_label( 194 | n_click: int, 195 | new_labels: List[str], 196 | options: List[List[str]], 197 | values: List[str], 198 | label_ids: List[int], 199 | ) -> Tuple[List[List[str]], List[str]]: 200 | ctx = callback_context 201 | no_updates = [no_update] * len(new_labels) 202 | if not ctx.triggered: 203 | return no_updates, no_updates, no_updates 204 | 205 | button_id = label_ids.index( 206 | json.loads( 207 | LABEL_SELECTOR_ID.format( 208 | json.loads(ctx.triggered[-1]["prop_id"].split(".")[0])["id"] 209 | ) 210 | ) 211 | ) 212 | 213 | if not ctx.triggered[0]["value"]: 214 | return no_updates, no_updates, no_updates 215 | 216 | if new_labels[button_id] and new_labels[button_id] not in options[button_id]: 217 | for i in range(len(options)): 218 | new_label = {"label": new_labels[button_id], "value": new_labels[button_id]} 219 | if new_label not in options[i]: 220 | options[i].append( 221 | {"label": new_labels[button_id], "value": new_labels[button_id]} 222 | ) 223 | values[button_id] = new_labels[button_id] 224 | else: 225 | return no_updates, no_updates, no_updates 226 | 227 | get_labels().append(new_labels[button_id]) 228 | new_labels[button_id] = "" 229 | 230 | return new_labels, options, values 231 | 232 | 233 | @app.callback( 234 | Output({"type": "chosen_label", "id": ALL}, "children"), 235 | Input({"type": "label_selector", "id": ALL}, "value"), 236 | [ 237 | State({"type": "label_selector", "id": ALL}, "id"), 238 | State({"type": "chosen_label", "id": ALL}, "children"), 239 | ], 240 | ) 241 | def choose_label( 242 | label: List[str], label_ids: List[int], chosen_labels: List[str] 243 | ) -> Tuple[List[List[str]], List[str]]: 244 | ctx = callback_context 245 | if not ctx.triggered: 246 | return [no_update] * len(chosen_labels) 247 | 248 | for trigger in ctx.triggered: 249 | button_id = label_ids.index( 250 | json.loads( 251 | LABEL_SELECTOR_ID.format( 252 | 
json.loads(trigger["prop_id"].split(".")[0])["id"] 253 | ) 254 | ) 255 | ) 256 | 257 | if not ctx.triggered[0]["value"] or label[button_id] == CHOOSE_LABEL: 258 | chosen_labels[button_id] = "" 259 | else: 260 | chosen_labels[button_id] = f"chosen label: {label[button_id]}" 261 | 262 | return chosen_labels 263 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /nemo_inspector/layouts/analyze_page_layouts/base_layout.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import logging 16 | from typing import List 17 | 18 | import dash_bootstrap_components as dbc 19 | from dash import dcc, html 20 | 21 | from nemo_inspector.layouts.analyze_page_layouts.table_layouts import ( 22 | get_detailed_info_table_layout, 23 | get_general_stats_layout, 24 | get_short_info_table_layout, 25 | ) 26 | from nemo_inspector.layouts.common_layouts import ( 27 | get_selector_layout, 28 | ) 29 | from nemo_inspector.settings.constants import ( 30 | BASE_GENERATION, 31 | CHOOSE_GENERATION, 32 | ERROR_MESSAGE_TEMPLATE, 33 | FILES_FILTERING, 34 | ) 35 | from nemo_inspector.utils.common import ( 36 | catch_eval_exception, 37 | clear_table_data, 38 | custom_deepcopy, 39 | get_available_models, 40 | get_data_from_files, 41 | get_eval_function, 42 | get_metrics, 43 | get_table_data, 44 | is_detailed_answers_rows_key, 45 | ) 46 | from nemo_inspector.layouts.analyze_page_layouts.modals_layouts import ( 47 | get_add_stats_modal_layout, 48 | get_change_label_modal_layout, 49 | get_filter_modal_layout, 50 | get_sorting_modal_layout, 51 | get_update_dataset_modal_layout, 52 | get_save_dataset_modal_layout, 53 | ) 54 | 55 | 56 | def get_compare_test_layout() -> html.Div: 57 | return html.Div( 58 | [ 59 | dbc.InputGroup( 60 | [ 61 | get_sorting_modal_layout(), 62 | get_filter_modal_layout(), 63 | get_add_stats_modal_layout(), 64 | get_change_label_modal_layout(apply_for_all_files=False), 65 | get_update_dataset_modal_layout(), 66 | get_save_dataset_modal_layout(), 67 | dbc.Button( 68 | "+", 69 | id="add_model", 70 | outline=True, 71 | color="primary", 72 | className="me-1", 73 | class_name="button-class", 74 | style={"margin-left": "1px"}, 75 | ), 76 | get_selector_layout( 77 | get_available_models().keys(), 78 | "base_model_answers_selector", 79 | value=CHOOSE_GENERATION, 80 | ), 81 | ] 82 | ), 83 | html.Pre(id="filtering_container"), 84 | html.Pre(id="sorting_container"), 85 | dcc.Loading( 86 | dbc.Container( 87 | id="loading_container", style={"display": 
"none"}, children="" 88 | ), 89 | type="circle", 90 | style={"margin-top": "50px"}, 91 | ), 92 | html.Div( 93 | children=[], 94 | id="compare_models_rows", 95 | ), 96 | ], 97 | ) 98 | 99 | 100 | def get_updated_tables_layout( 101 | base_model: str, update_function: str, models: List[str] 102 | ) -> List[html.Tr]: 103 | errors_dict = {} 104 | if update_function: 105 | update_eval_function = get_eval_function(update_function.strip()) 106 | available_models = { 107 | model_name: model_info["file_paths"] 108 | for model_name, model_info in get_available_models().items() 109 | } 110 | 111 | for question_id in range(len(get_table_data())): 112 | new_dicts = list( 113 | map( 114 | lambda data: catch_eval_exception( 115 | available_models, 116 | update_eval_function, 117 | data, 118 | data, 119 | errors_dict, 120 | ), 121 | get_table_data()[question_id][base_model], 122 | ) 123 | ) 124 | for i, new_dict in enumerate(new_dicts): 125 | for key, value in new_dict.items(): 126 | get_table_data()[question_id][base_model][i][key] = value 127 | 128 | keys = list(get_table_data()[question_id][base_model][i].keys()) 129 | for key in keys: 130 | if key not in new_dict: 131 | get_table_data()[question_id][base_model][i].pop(key) 132 | 133 | if len(errors_dict): 134 | logging.error(ERROR_MESSAGE_TEMPLATE.format("update_dataset", errors_dict)) 135 | 136 | return ( 137 | get_short_info_table_layout() 138 | + get_general_stats_layout(base_model) 139 | + get_detailed_info_table_layout( 140 | models, 141 | list( 142 | filter( 143 | is_detailed_answers_rows_key, 144 | ( 145 | get_table_data()[0][base_model][0].keys() 146 | if len(get_table_data()) and len(get_table_data()[0][base_model]) 147 | else [] 148 | ), 149 | ) 150 | ), 151 | ) 152 | ) 153 | 154 | 155 | def get_sorted_tables_layout( 156 | base_model: str, sorting_function: str, models: List[str] 157 | ) -> List[html.Tr]: 158 | errors_dict = {} 159 | if sorting_function: 160 | sortting_eval_function = 
get_eval_function(sorting_function.strip()) 161 | available_models = { 162 | model_name: model_info["file_paths"] 163 | for model_name, model_info in get_available_models().items() 164 | } 165 | 166 | for question_id in range(len(get_table_data())): 167 | for model in get_table_data()[question_id].keys(): 168 | get_table_data()[question_id][model].sort( 169 | key=lambda data: catch_eval_exception( 170 | available_models, 171 | sortting_eval_function, 172 | data, 173 | 0, 174 | errors_dict, 175 | ) 176 | ) 177 | 178 | get_table_data().sort( 179 | key=lambda single_question_data: tuple( 180 | map( 181 | lambda data: catch_eval_exception( 182 | available_models, 183 | sortting_eval_function, 184 | data, 185 | 0, 186 | errors_dict, 187 | ), 188 | single_question_data[base_model], 189 | ) 190 | ) 191 | ) 192 | if len(errors_dict): 193 | logging.error(ERROR_MESSAGE_TEMPLATE.format("sorting", errors_dict)) 194 | 195 | return ( 196 | get_short_info_table_layout() 197 | + get_general_stats_layout(base_model) 198 | + get_detailed_info_table_layout( 199 | models, 200 | list( 201 | filter( 202 | is_detailed_answers_rows_key, 203 | ( 204 | get_table_data()[0][base_model][0].keys() 205 | if len(get_table_data()) and len(get_table_data()[0][base_model]) 206 | else [] 207 | ), 208 | ) 209 | ), 210 | ) 211 | ) 212 | 213 | 214 | def get_filtered_tables_layout( 215 | base_model: str, 216 | filtering_function: str, 217 | apply_on_filtered_data: bool, 218 | models: List[str], 219 | filter_mode: str, 220 | ) -> List[html.Tr]: 221 | clean_table_data = [] 222 | if not apply_on_filtered_data: 223 | clear_table_data() 224 | get_table_data().extend(custom_deepcopy(get_data_from_files())) 225 | for question_id in range(len(get_table_data())): 226 | for model_id, files_data in get_table_data()[question_id].items(): 227 | stats = get_metrics(files_data) 228 | get_table_data()[question_id][model_id] = list( 229 | map( 230 | lambda data: {**data, **stats}, 231 | 
get_table_data()[question_id][model_id], 232 | ) 233 | ) 234 | 235 | errors_dict = {} 236 | if filtering_function: 237 | available_models = { 238 | model_name: model_info["file_paths"] 239 | for model_name, model_info in get_available_models().items() 240 | } 241 | filter_lines = filtering_function.strip().split("\n") 242 | common_expressions, splitted_filters = ( 243 | "\n".join(filter_lines[:-1]), 244 | filter_lines[-1], 245 | ) 246 | full_splitted_filters = [ 247 | common_expressions + "\n" + single_filter 248 | for single_filter in splitted_filters.split("&&") 249 | ] 250 | filtering_functions = ( 251 | list( 252 | [ 253 | get_eval_function(f"{BASE_GENERATION} = '{base_model}'\n" + func) 254 | for func in full_splitted_filters 255 | ] 256 | ) 257 | if filtering_function 258 | else [] 259 | ) 260 | 261 | if filter_mode == FILES_FILTERING: 262 | for question_id in range(len(get_table_data())): 263 | good_data = True 264 | for model_id in get_table_data()[question_id].keys(): 265 | 266 | def filtering_key_function(file_dict): 267 | data = {model_id: file_dict} 268 | return all( 269 | [ 270 | catch_eval_exception( 271 | available_models, 272 | filter_function, 273 | data, 274 | True, 275 | errors_dict, 276 | ) 277 | for filter_function in filtering_functions 278 | ], 279 | ) 280 | 281 | get_table_data()[question_id][model_id] = list( 282 | filter( 283 | filtering_key_function, 284 | get_table_data()[question_id][model_id], 285 | ) 286 | ) 287 | stats = get_metrics(get_table_data()[question_id][model_id]) 288 | get_table_data()[question_id][model_id] = list( 289 | map( 290 | lambda data: {**data, **stats}, 291 | get_table_data()[question_id][model_id], 292 | ) 293 | ) 294 | 295 | if get_table_data()[question_id][model_id] == []: 296 | good_data = False 297 | if good_data: 298 | clean_table_data.append(get_table_data()[question_id]) 299 | else: 300 | func = get_eval_function( 301 | f"{BASE_GENERATION} = '{base_model}'\n" + filtering_function.strip() 302 | ) 303 | 
clean_table_data = list( 304 | filter( 305 | lambda data: catch_eval_exception( 306 | available_models=[], 307 | eval_func=func, 308 | data=data, 309 | default_answer=True, 310 | errors_dict=errors_dict, 311 | ), 312 | get_table_data(), 313 | ) 314 | ) 315 | clear_table_data() 316 | get_table_data().extend(clean_table_data) 317 | if len(errors_dict): 318 | logging.error(ERROR_MESSAGE_TEMPLATE.format("filtering", errors_dict)) 319 | 320 | return ( 321 | get_short_info_table_layout() 322 | + get_general_stats_layout(base_model) 323 | + get_detailed_info_table_layout( 324 | models, 325 | list( 326 | filter( 327 | is_detailed_answers_rows_key, 328 | ( 329 | get_table_data()[0][base_model][0].keys() 330 | if len(get_table_data()) and len(get_table_data()[0][base_model]) 331 | else [] 332 | ), 333 | ) 334 | ), 335 | ) 336 | ) 337 | 338 | 339 | def get_tables_layout(base_model: str) -> List: 340 | if get_table_data() == []: 341 | get_table_data().extend(custom_deepcopy(get_data_from_files())) 342 | return ( 343 | get_short_info_table_layout() 344 | + get_general_stats_layout(base_model) 345 | + get_detailed_info_table_layout( 346 | [base_model], 347 | list( 348 | filter( 349 | is_detailed_answers_rows_key, 350 | ( 351 | get_table_data()[0][base_model][0].keys() 352 | if len(get_table_data()) and len(get_table_data()[0][base_model]) 353 | else [] 354 | ), 355 | ) 356 | ), 357 | ) 358 | ) 359 | -------------------------------------------------------------------------------- /nemo_inspector/layouts/analyze_page_layouts/modals_layouts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from typing import List 17 | 18 | import dash_bootstrap_components as dbc 19 | from dash import html 20 | from flask import current_app 21 | 22 | from nemo_inspector.layouts.common_layouts import ( 23 | get_selector_layout, 24 | get_switch_layout, 25 | ) 26 | from nemo_inspector.layouts.analyze_page_layouts.utils import ( 27 | get_code_text_area_layout, 28 | get_filter_text, 29 | get_stats_input, 30 | ) 31 | from nemo_inspector.settings.constants import ( 32 | DELETE, 33 | FILES_ONLY, 34 | FILES_FILTERING, 35 | GENERAL_STATS, 36 | ) 37 | from nemo_inspector.settings.constants.configurations import STATS_KEYS 38 | from nemo_inspector.utils.common import get_labels, get_metrics, get_table_data 39 | 40 | 41 | def get_filter_modal_layout( 42 | id: int = -1, available_filters: List[str] = [], mode: str = FILES_FILTERING 43 | ) -> html.Div: 44 | text = get_filter_text(available_filters, mode) 45 | 46 | filter_mode = ( 47 | [ 48 | get_switch_layout( 49 | id={"type": "filter_mode", "id": id}, 50 | labels=["filter files"], 51 | is_active=True, 52 | additional_params={ 53 | "inline": True, 54 | "style": {"margin-left": "10px"}, 55 | }, 56 | ) 57 | ] 58 | if mode != FILES_ONLY 59 | else [] 60 | ) 61 | 62 | header = dbc.ModalHeader( 63 | ( 64 | [ 65 | dbc.ModalTitle( 66 | "Set Up Your Filter", 67 | ), 68 | ] 69 | + filter_mode 70 | ), 71 | close_button=True, 72 | ) 73 | body = dbc.ModalBody( 74 | html.Div( 75 | [ 76 | html.Pre(text, id={"type": "filter_text", "id": id}), 77 | get_code_text_area_layout( 78 | id={ 
79 | "type": "filter_function_input", 80 | "id": id, 81 | }, 82 | ), 83 | ] 84 | ) 85 | ) 86 | switch = get_switch_layout( 87 | { 88 | "type": "apply_on_filtered_data", 89 | "id": id, 90 | }, 91 | ["Apply for filtered data"], 92 | additional_params={"style": {"margin-left": "10px"}}, 93 | ) 94 | footer = dbc.ModalFooter( 95 | dbc.Button( 96 | "Apply", 97 | id={"type": "apply_filter_button", "id": id}, 98 | className="ms-auto", 99 | n_clicks=0, 100 | ) 101 | ) 102 | return html.Div( 103 | [ 104 | dbc.Button( 105 | "Filters", 106 | id={"type": "set_filter_button", "id": id}, 107 | class_name="button-class", 108 | ), 109 | dbc.Modal( 110 | [ 111 | header, 112 | body, 113 | switch, 114 | footer, 115 | ], 116 | size="lg", 117 | id={"type": "filter", "id": id}, 118 | centered=True, 119 | is_open=False, 120 | ), 121 | ], 122 | style={"display": "inline-block"}, 123 | ) 124 | 125 | 126 | def get_sorting_modal_layout(id: int = -1, available_params: List[str] = []) -> html.Div: 127 | available_params = list( 128 | get_table_data()[0][list(get_table_data()[0].keys())[0]][0].keys() 129 | if len(get_table_data()) and not available_params 130 | else STATS_KEYS + list(get_metrics([]).keys()) + ["+ all fields in json"] 131 | ) 132 | text = ( 133 | "Write an expression to sort the data\n\n" 134 | "For example: len(data['question'])\n\n" 135 | "The function has to return sortable type\n\n" 136 | "Available parameters to sort data:\n" 137 | + "\n".join( 138 | [ 139 | ", ".join(available_params[start : start + 5]) 140 | for start in range(0, len(available_params), 5) 141 | ] 142 | ) 143 | ) 144 | header = dbc.ModalHeader( 145 | dbc.ModalTitle("Set Up Your Sorting Parameters"), 146 | close_button=True, 147 | ) 148 | body = dbc.ModalBody( 149 | html.Div( 150 | [ 151 | html.Pre(text), 152 | get_code_text_area_layout( 153 | id={ 154 | "type": "sorting_function_input", 155 | "id": id, 156 | }, 157 | ), 158 | ], 159 | ) 160 | ) 161 | footer = dbc.ModalFooter( 162 | dbc.Button( 163 | 
"Apply", 164 | id={"type": "apply_sorting_button", "id": id}, 165 | className="ms-auto", 166 | n_clicks=0, 167 | ) 168 | ) 169 | return html.Div( 170 | [ 171 | dbc.Button( 172 | "Sort", 173 | id={"type": "set_sorting_button", "id": id}, 174 | class_name="button-class", 175 | ), 176 | dbc.Modal( 177 | [ 178 | header, 179 | body, 180 | footer, 181 | ], 182 | size="lg", 183 | id={"type": "sorting", "id": id}, 184 | centered=True, 185 | is_open=False, 186 | ), 187 | ], 188 | style={"display": "inline-block"}, 189 | ) 190 | 191 | 192 | def get_update_dataset_modal_layout() -> html.Div: 193 | text = ( 194 | "Write an expression to modify the data\n\n" 195 | "For example: {**data, 'generation': data['generation'].strip()}\n\n" 196 | "The function has to return a new dict" 197 | ) 198 | header = dbc.ModalHeader( 199 | dbc.ModalTitle("Update Dataset"), 200 | close_button=True, 201 | ) 202 | body = dbc.ModalBody( 203 | html.Div( 204 | [ 205 | html.Pre(text), 206 | get_code_text_area_layout( 207 | id="update_dataset_input", 208 | ), 209 | ], 210 | ) 211 | ) 212 | footer = dbc.ModalFooter( 213 | dbc.Button( 214 | "Apply", 215 | id="apply_update_dataset_button", 216 | className="ms-auto", 217 | n_clicks=0, 218 | ) 219 | ) 220 | return html.Div( 221 | [ 222 | dbc.Button( 223 | "Update dataset", 224 | id="update_dataset_button", 225 | class_name="button-class", 226 | ), 227 | dbc.Modal( 228 | [ 229 | header, 230 | body, 231 | footer, 232 | ], 233 | size="lg", 234 | id="update_dataset_modal", 235 | centered=True, 236 | is_open=False, 237 | ), 238 | ], 239 | style={"display": "inline-block"}, 240 | ) 241 | 242 | 243 | def get_change_label_modal_layout( 244 | id: int = -1, apply_for_all_files: bool = True 245 | ) -> html.Div: 246 | header = dbc.ModalHeader( 247 | dbc.ModalTitle("Manage labels"), 248 | close_button=True, 249 | ) 250 | switch_layout = ( 251 | [ 252 | get_switch_layout( 253 | { 254 | "type": "aplly_for_all_files", 255 | "id": id, 256 | }, 257 | ["Apply for all files"], 
258 | additional_params={"style": {"margin-left": "10px"}}, 259 | ) 260 | ] 261 | if apply_for_all_files 262 | else [] 263 | ) 264 | body = dbc.ModalBody( 265 | html.Div( 266 | [ 267 | get_selector_layout( 268 | options=get_labels(), 269 | id={"type": "label_selector", "id": id}, 270 | value="choose label", 271 | ), 272 | dbc.InputGroup( 273 | [ 274 | dbc.Input( 275 | id={ 276 | "type": "new_label_input", 277 | "id": id, 278 | }, 279 | placeholder="Enter new label", 280 | type="text", 281 | ), 282 | dbc.Button( 283 | "Add", 284 | id={ 285 | "type": "add_new_label_button", 286 | "id": id, 287 | }, 288 | ), 289 | ] 290 | ), 291 | *switch_layout, 292 | html.Pre("", id={"type": "chosen_label", "id": id}), 293 | ], 294 | ) 295 | ) 296 | footer = dbc.ModalFooter( 297 | html.Div( 298 | [ 299 | dbc.Button( 300 | children="Delete", 301 | id={ 302 | "type": "delete_label_button", 303 | "id": id, 304 | }, 305 | className="ms-auto", 306 | n_clicks=0, 307 | ), 308 | html.Pre( 309 | " ", 310 | style={"display": "inline-block", "font-size": "5px"}, 311 | ), 312 | dbc.Button( 313 | children="Apply", 314 | id={"type": "apply_label_button", "id": id}, 315 | className="ms-auto", 316 | n_clicks=0, 317 | ), 318 | ], 319 | ), 320 | style={"display": "inline-block"}, 321 | ) 322 | return html.Div( 323 | [ 324 | dbc.Button( 325 | "Labels", 326 | id={"type": "set_file_label_button", "id": id}, 327 | class_name="button-class", 328 | ), 329 | dbc.Modal( 330 | [header, body, footer], 331 | size="lg", 332 | id={"type": "label", "id": id}, 333 | centered=True, 334 | is_open=False, 335 | ), 336 | ], 337 | style={"display": "inline-block"}, 338 | ) 339 | 340 | 341 | def get_save_dataset_modal_layout() -> html.Div: 342 | return html.Div( 343 | [ 344 | dbc.Button("Save dataset", id="save_dataset", class_name="button-class"), 345 | dbc.Modal( 346 | [ 347 | dbc.ModalBody( 348 | [ 349 | dbc.InputGroup( 350 | [ 351 | dbc.InputGroupText("save_path"), 352 | dbc.Input( 353 | value=os.path.join( 354 | 
current_app.config["nemo_inspector"]["save_generations_path"], 355 | "default_name", 356 | ), 357 | id="save_path", 358 | type="text", 359 | ), 360 | ], 361 | className="mb-3", 362 | ), 363 | dbc.Container(id="error_message"), 364 | ] 365 | ), 366 | dbc.ModalFooter( 367 | dbc.Button( 368 | "Save", 369 | id="save_dataset_button", 370 | className="ms-auto", 371 | n_clicks=0, 372 | ) 373 | ), 374 | ], 375 | id="save_dataset_modal", 376 | is_open=False, 377 | style={ 378 | "text-align": "center", 379 | "margin-top": "10px", 380 | "margin-bottom": "10px", 381 | }, 382 | ), 383 | ], 384 | ) 385 | 386 | 387 | def get_add_stats_modal_layout() -> html.Div: 388 | modal_header = dbc.ModalHeader( 389 | [ 390 | dbc.ModalTitle("Set Up Your Stats"), 391 | get_switch_layout( 392 | id="stats_modes", 393 | labels=["general stats", "delete mode"], 394 | values=[GENERAL_STATS, DELETE], 395 | additional_params={"inline": True, "style": {"margin-left": "10px"}}, 396 | ), 397 | ], 398 | close_button=True, 399 | ) 400 | modal_body = dbc.ModalBody( 401 | html.Div( 402 | get_stats_input(), 403 | id="stats_input_container", 404 | ) 405 | ) 406 | modal_footer = dbc.ModalFooter( 407 | dbc.Button( 408 | "Apply", 409 | id="apply_new_stats", 410 | className="ms-auto", 411 | n_clicks=0, 412 | ) 413 | ) 414 | return html.Div( 415 | [ 416 | dbc.Button( 417 | "Stats", 418 | id="set_new_stats_button", 419 | class_name="button-class", 420 | ), 421 | dbc.Modal( 422 | [ 423 | modal_header, 424 | modal_body, 425 | modal_footer, 426 | ], 427 | size="lg", 428 | id="new_stats", 429 | centered=True, 430 | is_open=False, 431 | ), 432 | ], 433 | style={"display": "inline-block"}, 434 | ) 435 | -------------------------------------------------------------------------------- /nemo_inspector/utils/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import functools 16 | import json 17 | import logging 18 | import os 19 | import re 20 | from collections import defaultdict 21 | from typing import Callable, Dict, List, Optional, Set, Tuple, Union, get_origin 22 | 23 | from flask import current_app 24 | from joblib import Parallel, delayed 25 | 26 | from nemo_inspector.settings.constants import ( 27 | EXPECTED_ANSWER_FIELD, 28 | CUSTOM, 29 | ERROR_MESSAGE_TEMPLATE, 30 | FILE_NAME, 31 | GENERAL_STATS, 32 | INLINE_STATS, 33 | QUESTION_FIELD, 34 | STATS_KEYS, 35 | UNDEFINED, 36 | ) 37 | 38 | custom_stats = {} 39 | general_custom_stats = {} 40 | deleted_stats = set() 41 | excluded_rows = set() 42 | editable_rows = set() 43 | compared_rows = set() 44 | stats_raw = {INLINE_STATS: {CUSTOM: ""}, GENERAL_STATS: {CUSTOM: ""}} 45 | 46 | dataset_data = [] 47 | labels = [] 48 | 49 | 50 | def get_editable_rows() -> Set: 51 | return editable_rows 52 | 53 | 54 | def get_excluded_row() -> Set: 55 | return excluded_rows 56 | 57 | 58 | def get_deleted_stats() -> Set: 59 | return deleted_stats 60 | 61 | 62 | def get_custom_stats() -> Dict: 63 | return custom_stats 64 | 65 | 66 | def get_compared_rows() -> Dict: 67 | return compared_rows 68 | 69 | 70 | def get_general_custom_stats() -> Dict: 71 | return general_custom_stats 72 | 73 | 74 | def get_stats_raw() -> Dict: 75 | return stats_raw 76 | 77 | 78 | def clear_table_data() -> None: 79 | 
global dataset_data 80 | dataset_data = [] 81 | 82 | 83 | def get_table_data() -> List: 84 | return dataset_data 85 | 86 | 87 | def get_labels() -> List: 88 | return labels 89 | 90 | 91 | def parse_model_answer(answer: str) -> List[Dict]: 92 | """ 93 | Parses a model answer and extracts code blocks, explanations, and outputs preserving their sequence. 94 | 95 | Args: 96 | answer (str): The model answer to parse. 97 | 98 | Returns: 99 | List[Dict]: A list of dictionaries containing the parsed results. Each dictionary 100 | contains the following keys: 101 | - 'explanation': The explanation text before the code block. 102 | - 'code': The code block. 103 | - 'output': The output of the code block. 104 | 105 | """ 106 | config = current_app.config["nemo_inspector"] 107 | code_start, code_end = map( 108 | re.escape, 109 | config["code_separators"], 110 | ) 111 | output_start, output_end = map( 112 | re.escape, 113 | config["code_output_separators"], 114 | ) 115 | code_pattern = re.compile(rf"{code_start}(.*?){code_end}", re.DOTALL) 116 | code_output_pattern = re.compile( 117 | rf"{code_start}(.*?){code_end}\s*{output_start}(.*?){output_end}", 118 | re.DOTALL, 119 | ) 120 | code_matches = list(code_pattern.finditer(answer)) 121 | code_output_matches = list(code_output_pattern.finditer(answer)) 122 | parsed_results = [] 123 | last_index = 0 124 | for code_match in code_matches: 125 | explanation = answer[last_index : code_match.start()].strip() 126 | code_text = code_match.group(1).strip() 127 | output_text = None 128 | if code_output_matches and code_output_matches[0].start() == code_match.start(): 129 | output_match = code_output_matches.pop(0) 130 | output_text = output_match.group(2).strip() 131 | parsed_results.append( 132 | { 133 | "explanation": explanation, 134 | "code": code_text, 135 | "output": output_text, 136 | } 137 | ) 138 | last_index = code_match.end() 139 | if output_text is not None: 140 | last_index = output_match.end() 141 | if last_index < 
len(answer): 142 | trailing_text = answer[last_index:].strip() 143 | if code_start.replace("\\", "") in trailing_text: 144 | code_start_index = trailing_text.find(code_start.replace("\\", "")) 145 | parsed_results.append( 146 | { 147 | "explanation": trailing_text[0:code_start_index].strip(), 148 | "code": trailing_text[ 149 | code_start_index + len(code_start.replace("\\", "")) : 150 | ], 151 | "output": "code_block was not finished", 152 | "wrong_code_block": True, 153 | } 154 | ) 155 | trailing_text = None 156 | if trailing_text: 157 | parsed_results.append( 158 | {"explanation": trailing_text, "code": None, "output": None} 159 | ) 160 | return parsed_results 161 | 162 | 163 | @functools.lru_cache() 164 | def get_dataset_sample(index: int, dataset: str) -> Tuple[Dict, int]: 165 | if not dataset or dataset == UNDEFINED or os.path.isfile(dataset) is False: 166 | return {QUESTION_FIELD: "", EXPECTED_ANSWER_FIELD: ""}, 0 167 | with open(dataset) as file: 168 | tests = file.readlines() 169 | index = max(min(len(tests), index), 1) 170 | test = ( 171 | json.loads(tests[index - 1]) 172 | if index != 0 173 | else {QUESTION_FIELD: "", EXPECTED_ANSWER_FIELD: ""} 174 | ) 175 | return test, index 176 | 177 | 178 | def get_stats(all_files_data: List[Dict]) -> Tuple[float, float, float]: 179 | """Returns the percentage of correct, wrong, and no response answers in the given data. 180 | 181 | If not data is provided, returns -1 for all values. 
182 | """ 183 | correct = 0 184 | wrong = 0 185 | no_response = 0 186 | for data in all_files_data: 187 | if data.get("predicted_answer") is None: 188 | no_response += 1 189 | elif data.get("is_correct", False): 190 | correct += 1 191 | else: 192 | wrong += 1 193 | 194 | if len(all_files_data): 195 | return ( 196 | correct / len(all_files_data), 197 | wrong / len(all_files_data), 198 | no_response / len(all_files_data), 199 | ) 200 | return -1, -1, -1 201 | 202 | 203 | def get_metrics(all_files_data: List[Dict], errors_dict: Dict = {}) -> Dict: 204 | correct_responses, wrong_responses, no_response = get_stats(all_files_data) 205 | custom_stats = {} 206 | for name, func in get_custom_stats().items(): 207 | if name not in errors_dict: 208 | errors_dict[name] = {} 209 | custom_stats[name] = catch_eval_exception( 210 | [], 211 | func, 212 | all_files_data, 213 | "Got error when applying function", 214 | errors_dict[name], 215 | ) 216 | 217 | stats = { 218 | "correct_responses": round(correct_responses, 2), 219 | "wrong_responses": round(wrong_responses, 2), 220 | "no_response": round(no_response, 2), 221 | **custom_stats, 222 | } 223 | return stats 224 | 225 | 226 | def get_eval_function(text): 227 | template = """ 228 | def eval_function(data): 229 | {} 230 | return {} 231 | """ 232 | code_lines = [""] + text.strip().split("\n") 233 | code = template.format( 234 | "\n ".join(code_lines[:-1]), 235 | code_lines[-1:][0], 236 | ) 237 | namespace = {} 238 | exec(code, namespace) 239 | return namespace["eval_function"] 240 | 241 | 242 | def calculate_metrics_for_whole_data(table_data: List, model_id: str) -> Dict: 243 | errors_dict = {} 244 | for question_id in range(len(table_data)): 245 | stats = get_metrics(table_data[question_id][model_id], errors_dict) 246 | table_data[question_id][model_id] = list( 247 | map( 248 | lambda data: {**data, **stats}, 249 | table_data[question_id][model_id], 250 | ) 251 | ) 252 | if len(errors_dict): 253 | for name, error_dict in 
errors_dict.items(): 254 | if len(error_dict): 255 | logging.error(ERROR_MESSAGE_TEMPLATE.format(name, error_dict)) 256 | 257 | 258 | def catch_eval_exception( 259 | available_models: List[str], 260 | eval_func: Callable[[Dict], bool], 261 | data: Dict, 262 | default_answer: Union[bool, str], 263 | errors_dict: Optional[Dict] = {}, 264 | ) -> bool: 265 | try: 266 | if eval_func is None: 267 | return default_answer 268 | return eval_func(data) 269 | except Exception as e: 270 | if str(e).split(" ")[-1].replace("'", "") not in available_models: 271 | if str(e) not in errors_dict: 272 | errors_dict[str(e)] = 0 273 | errors_dict[str(e)] += 1 274 | return default_answer 275 | 276 | 277 | def custom_deepcopy(data) -> List: 278 | new_data = [] 279 | for item in data: 280 | new_item = {} 281 | for key, value_list in item.items(): 282 | new_item[key] = value_list 283 | new_data.append(new_item) 284 | return new_data 285 | 286 | 287 | @functools.lru_cache(maxsize=1) 288 | def get_data_from_files() -> List: 289 | base_config = current_app.config["nemo_inspector"] 290 | dataset = None 291 | if os.path.isfile(base_config["input_file"]): 292 | with open(base_config["input_file"]) as f: 293 | dataset = [json.loads(line) for line in f] 294 | 295 | available_models = { 296 | model_name: model_info["file_paths"] 297 | for model_name, model_info in get_available_models().items() 298 | } 299 | 300 | all_models_data_array = [] 301 | 302 | def process_model_files(model_id, results_files, dataset): 303 | model_data = defaultdict(list) 304 | file_names = {} 305 | for file_id, path in enumerate(results_files): 306 | file_name = path.split("/")[-1].split(".")[0] 307 | if file_name in file_names: 308 | file_names[file_name] += 1 309 | file_name += f"_{file_names[file_name]}" 310 | else: 311 | file_names[file_name] = 1 312 | with open(path) as f: 313 | answers = map(json.loads, f) 314 | for question_index, answer in enumerate(answers): 315 | result = { 316 | FILE_NAME: file_name, 317 | **( 
318 | dataset[question_index] 319 | if dataset and len(dataset) > question_index 320 | else {} 321 | ), 322 | "question_index": question_index + 1, 323 | "page_index": file_id, 324 | "labels": [], 325 | **answer, 326 | } 327 | model_data[question_index].append(result) 328 | return model_id, model_data 329 | 330 | num_cores = -1 331 | model_data_list = Parallel(n_jobs=num_cores)( 332 | delayed(process_model_files)(model_id, results_files, dataset) 333 | for model_id, results_files in available_models.items() 334 | ) 335 | 336 | for model_id, model_data in model_data_list: 337 | for question_index, results in model_data.items(): 338 | if len(all_models_data_array) <= question_index: 339 | all_models_data_array.append({}) 340 | all_models_data_array[question_index][model_id] = results 341 | stats = get_metrics(all_models_data_array[question_index][model_id]) 342 | all_models_data_array[question_index][model_id] = list( 343 | map( 344 | lambda data: {**data, **stats}, 345 | all_models_data_array[question_index][model_id], 346 | ) 347 | ) 348 | return all_models_data_array 349 | 350 | 351 | def get_filtered_files( 352 | filter_function: str, 353 | sorting_function: str, 354 | array_to_filter: List, 355 | ) -> List: 356 | filter_lambda_functions = [ 357 | get_eval_function(func.strip()) 358 | for func in (filter_function if filter_function else "True").split("&&") 359 | ] 360 | available_models = get_available_models() 361 | filtered_data = [ 362 | list( 363 | filter( 364 | lambda data: catch_eval_exception( 365 | available_models, function, data, False 366 | ), 367 | array_to_filter, 368 | ) 369 | ) 370 | for function in filter_lambda_functions 371 | ] 372 | 373 | filtered_data = list(filter(lambda data: data != [], filtered_data)) 374 | filtered_data = filtered_data[0] if len(filtered_data) > 0 else [{FILE_NAME: ""}] 375 | if sorting_function and filtered_data != [{FILE_NAME: ""}]: 376 | sorting_lambda_function = get_eval_function(sorting_function.strip()) 377 | 
filtered_data.sort( 378 | key=lambda data: catch_eval_exception( 379 | available_models, sorting_lambda_function, data, 0 380 | ) 381 | ) 382 | 383 | return filtered_data 384 | 385 | 386 | def is_detailed_answers_rows_key(key: str) -> bool: 387 | return ( 388 | key not in get_deleted_stats() 389 | and "index" not in key 390 | and key not in STATS_KEYS + list(get_metrics([]).keys()) 391 | or key == QUESTION_FIELD 392 | ) 393 | 394 | 395 | @functools.lru_cache(maxsize=1) 396 | def get_available_models() -> Dict: 397 | config = current_app.config["nemo_inspector"] 398 | runs_storage = {} 399 | for model_name, files in config["model_prediction"].items(): 400 | runs_storage[model_name] = { 401 | "file_paths": files, 402 | } 403 | 404 | return runs_storage 405 | 406 | 407 | def get_file_id(file_names: List[str], files: List[Dict], column_id: str): 408 | file_id = 0 409 | file_name = ( 410 | file_names[column_id]["value"] 411 | if isinstance(file_names[column_id], Dict) 412 | else file_names[column_id] 413 | ) 414 | for i, file_data in enumerate(files): 415 | if file_data[FILE_NAME] == file_name: 416 | file_id = i 417 | break 418 | return file_id 419 | 420 | 421 | def resolve_type(field_type): 422 | origin = get_origin(field_type) 423 | if origin is not None: 424 | # If the type is a generic, use its origin (e.g., list, dict) 425 | return origin 426 | elif isinstance(field_type, type): 427 | return field_type 428 | else: 429 | # For special cases like typing.Any, return str by default 430 | return str 431 | 432 | 433 | def get_type_default(field_type): 434 | try: 435 | return resolve_type(field_type)() 436 | except Exception: 437 | return None 438 | 439 | 440 | def resolve_union_or_any(field_type): 441 | """Resolve Union and Any types to concrete types""" 442 | return resolve_type(field_type) 443 | -------------------------------------------------------------------------------- /nemo_inspector/layouts/analyze_page_layouts/table_layouts.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | import logging 17 | import math 18 | from typing import Dict, List 19 | 20 | 21 | import dash_bootstrap_components as dbc 22 | from dash import dash_table, html 23 | 24 | from nemo_inspector.layouts.analyze_page_layouts.modals_layouts import ( 25 | get_change_label_modal_layout, 26 | get_filter_modal_layout, 27 | get_sorting_modal_layout, 28 | ) 29 | from nemo_inspector.layouts.common_layouts import ( 30 | get_selector_layout, 31 | get_single_prompt_output_layout, 32 | get_text_modes_layout, 33 | ) 34 | from nemo_inspector.settings.constants import ( 35 | ANSI, 36 | CODE, 37 | COMPARE, 38 | COMPARE_ICON_PATH, 39 | DATA_PAGE_SIZE, 40 | EDIT_ICON_PATH, 41 | ERROR_MESSAGE_TEMPLATE, 42 | FILE_NAME, 43 | FILES_ONLY, 44 | LABEL, 45 | LATEX, 46 | MODEL_SELECTOR_ID, 47 | STATS_KEYS, 48 | ) 49 | from nemo_inspector.utils.common import ( 50 | catch_eval_exception, 51 | get_available_models, 52 | get_compared_rows, 53 | get_editable_rows, 54 | get_excluded_row, 55 | get_filtered_files, 56 | get_general_custom_stats, 57 | get_metrics, 58 | get_table_data, 59 | is_detailed_answers_rows_key, 60 | ) 61 | 62 | 63 | def get_short_info_table_layout() -> List[dbc.Row]: 64 | return [ 65 | dbc.Row( 66 | dbc.Col( 67 | 
dash_table.DataTable( 68 | id="datatable", 69 | columns=[ 70 | { 71 | "name": name, 72 | "id": name, 73 | "hideable": True, 74 | } 75 | for name in STATS_KEYS + list(get_metrics([]).keys()) 76 | ], 77 | row_selectable="single", 78 | cell_selectable=False, 79 | page_action="custom", 80 | page_current=0, 81 | page_size=DATA_PAGE_SIZE, 82 | page_count=math.ceil(len(get_table_data()) / DATA_PAGE_SIZE), 83 | style_cell={ 84 | "overflow": "hidden", 85 | "textOverflow": "ellipsis", 86 | "maxWidth": 0, 87 | "textAlign": "center", 88 | }, 89 | style_header={ 90 | "color": "text-primary", 91 | "text_align": "center", 92 | "height": "auto", 93 | "whiteSpace": "normal", 94 | }, 95 | css=[ 96 | { 97 | "selector": ".dash-spreadsheet-menu", 98 | "rule": "position:absolute; bottom: 8px", 99 | }, 100 | { 101 | "selector": ".dash-filter--case", 102 | "rule": "display: none", 103 | }, 104 | { 105 | "selector": ".column-header--hide", 106 | "rule": "display: none", 107 | }, 108 | ], 109 | ), 110 | ) 111 | ), 112 | ] 113 | 114 | 115 | def get_table_column_header( 116 | models: List[str], name: str, id: int, add_del_button: bool = False 117 | ) -> dbc.Col: 118 | del_model_layout = ( 119 | [ 120 | dbc.Button( 121 | "-", 122 | id={"type": "del_model", "id": id}, 123 | outline=True, 124 | color="primary", 125 | className="me-1", 126 | style={"height": "40px"}, 127 | ), 128 | ] 129 | if add_del_button 130 | else [] 131 | ) 132 | return dbc.Col( 133 | html.Div( 134 | [ 135 | html.Div( 136 | get_selector_layout( 137 | models, 138 | json.loads(MODEL_SELECTOR_ID.format(id)), 139 | name, 140 | ), 141 | ), 142 | get_sorting_modal_layout(id), 143 | get_filter_modal_layout(id, mode=FILES_ONLY), 144 | get_change_label_modal_layout(id), 145 | ] 146 | + del_model_layout 147 | + [get_text_modes_layout(id)], 148 | style={"display": "inline-flex"}, 149 | ), 150 | class_name="mt-1 bg-light font-monospace text-break small rounded border", 151 | id={"type": "column_header", "id": id}, 152 | ) 153 | 154 | 
def get_table_header(models: List[str]) -> List[dbc.Row]:
    """Build the header row of the detailed-answers table.

    The first cell is an empty spacer above the row-name column; every
    following cell is a model column header.  Only the first model column is
    created without a delete button (``add_del_button=i != 0``).

    Args:
        models: Model names, one per column, in display order.

    Returns:
        A single-element list holding the header ``dbc.Row`` with id
        ``"detailed_answers_header"``.
    """
    # Fetched once instead of once per column.
    available_models = get_available_models()
    header_cells = [
        dbc.Col(
            html.Div(""),
            width=2,
            class_name="mt-1 bg-light font-monospace text-break small rounded border",
            id="first_column",
        )
    ]
    header_cells.extend(
        get_table_column_header(available_models, name, i, i != 0)
        for i, name in enumerate(models)
    )
    return [dbc.Row(header_cells, id="detailed_answers_header")]


def get_detailed_info_table_column(id: int, file_id=None) -> dbc.Col:
    """Build one answer cell of the detailed-info table.

    Args:
        id: Pattern-matching id of the inner ``detailed_models_answers`` div
            that callbacks later fill with content.
        file_id: When given, the cell starts with an empty file selector whose
            id is ``{"type": "file_selector", "id": file_id}``; otherwise the
            cell starts with an empty string.

    Returns:
        The ``dbc.Col`` wrapping the cell content.
    """
    return dbc.Col(
        html.Div(
            children=(
                get_selector_layout([], {"type": "file_selector", "id": file_id}, "")
                if file_id is not None
                else ""
            ),
            id={
                "type": "detailed_models_answers",
                "id": id,
            },
        ),
        class_name="mt-1 bg-light font-monospace text-break small rounded border",
    )


def _get_row_name_cell(key: str, row_id: int) -> dbc.Col:
    """Build the left-most cell of a row: its name plus edit/compare/delete buttons."""
    edit_button = dbc.Button(
        html.Img(
            src=EDIT_ICON_PATH,
            id={"type": "edit_row_image", "id": row_id},
            style={"height": "15px", "display": "inline-block"},
        ),
        id={"type": "edit_row_button", "id": row_id},
        outline=True,
        color="primary",
        className="me-1",
        style={
            "border": "none",
            "line-height": "1.2",
            # File-name and label rows cannot be edited in place.
            "display": "none" if key in (FILE_NAME, LABEL) else "inline-block",
            "margin-left": "1px",
        },
    )
    compare_button = dbc.Button(
        html.Img(
            src=COMPARE_ICON_PATH,
            id={"type": "compare_texts", "id": row_id},
            style={"height": "15px", "display": "inline-block"},
        ),
        id={"type": "compare_texts_button", "id": row_id},
        outline=True,
        color="primary",
        className="me-1",
        style={
            "border": "none",
            "line-height": "1.2",
            "display": "none" if key == FILE_NAME else "inline-block",
            # Pull the button closer when the (visible) edit button precedes it.
            "margin-left": "-10px" if key != LABEL else "1px",
        },
    )
    delete_button = dbc.Button(
        "-",
        id={"type": "del_row", "id": row_id},
        outline=True,
        color="primary",
        className="me-1",
        style={
            "border": "none",
            "display": "inline-block",
            "margin-left": ("-9px" if key != FILE_NAME else "1px"),
        },
    )
    return dbc.Col(
        html.Div(
            html.Div(
                [
                    html.Div(
                        key,
                        id={"type": "row_name", "id": row_id},
                        style={"display": "inline-block"},
                    ),
                    edit_button,
                    compare_button,
                    delete_button,
                ],
                style={"display": "inline-block"},
            ),
        ),
        width=2,
        class_name="mt-1 bg-light font-monospace text-break small rounded border",
    )


def get_detailed_info_table_rows(keys: List[str], colums_number: int) -> List[dbc.Row]:
    """Build one table row per key: a row-name cell followed by answer cells.

    Args:
        keys: Row names, in display order.
        colums_number: Number of model columns (answer cells) per row.

    Returns:
        Rows with ids ``{"type": "detailed_answers_row", "id": i}``.  The answer
        cell in column ``j`` of row ``i`` gets id ``j * len(keys) + i``.
    """
    return [
        dbc.Row(
            [_get_row_name_cell(key, i)]
            + [
                get_detailed_info_table_column(j * len(keys) + i)
                for j in range(colums_number)
            ],
            id={"type": "detailed_answers_row", "id": i},
        )
        for i, key in enumerate(keys)
    ]


def get_detailed_info_table_layout(
    models: List[str],
    keys: List[str],
) -> List[dbc.Row]:
    """Return the header row followed by one row per key, one column per model."""
    return get_table_header(models) + get_detailed_info_table_rows(keys, len(models))


def get_detailed_info_table_row_content(
    question_id: int,
    model: str,
    rows_names: List[str],
    files_names: List[str],
    file_id: int,
    col_id: int,
    compare_to: Dict = None,
    text_modes: List[str] = None,
) -> List:
    """Render the cells of one model column for the selected question.

    Args:
        question_id: Index of the question in ``get_table_data()``.
        model: Model whose generations are shown.
        rows_names: Row keys; only those accepted by
            ``is_detailed_answers_rows_key`` produce a cell.
        files_names: File names currently offered by the column's selector.
        file_id: Index of the shown file within the model's files.  A negative
            or out-of-range index renders every cell empty.
        col_id: Column index, used as the id of the file selector.
        compare_to: Sample whose values are passed to the compare view
            (defaults to an empty dict).
        text_modes: Text decoration modes (defaults to CODE, LATEX, ANSI).

    Returns:
        One cell (string, layout, or ``dbc.Textarea`` for editable rows) per
        accepted row key.
    """
    if compare_to is None:
        compare_to = {}
    if text_modes is None:
        text_modes = [CODE, LATEX, ANSI]

    table_data = get_table_data()[question_id].get(model, [])
    # Validate the index before touching table_data[file_id]: the model may
    # have no files at all, or the index may be stale.
    has_file = 0 <= file_id < len(table_data)
    file_data = table_data[file_id] if has_file else {}
    # The selected file is no longer among the selector's options.
    empty_list = has_file and file_data.get(FILE_NAME, None) not in files_names

    editable_rows = get_editable_rows()
    excluded_rows = get_excluded_row()
    compared_rows = get_compared_rows()

    row_data = []
    for key in filter(is_detailed_answers_rows_key, rows_names):
        if not has_file or key in excluded_rows:
            value = ""
        elif key == FILE_NAME:
            value = get_selector_layout(
                files_names,
                {"type": "file_selector", "id": col_id},
                "" if empty_list else file_data.get(key, None),
            )
        elif empty_list:
            value = ""
        elif key in editable_rows:
            value = str(file_data.get(key, None))
        else:
            value = get_single_prompt_output_layout(
                str(file_data.get(key, None)),
                text_modes + ([COMPARE] if key in compared_rows else []),
                str(compare_to.get(key, "")),
            )
        if key in editable_rows:
            value = dbc.Textarea(
                id={"type": "editable_row", "id": key, "model_name": model},
                value=value,
            )
        row_data.append(value)
    return row_data


def _get_reference_sample(
    question_id: int, models: List[str], files_id: List[int]
) -> Dict:
    """Return the sample selected in the first column, or {} if there is none."""
    if not models or not files_id:
        return {}
    base_files = get_table_data()[question_id].get(models[0], [])
    base_file_id = files_id[0]
    if 0 <= base_file_id < len(base_files):
        return base_files[base_file_id]
    return {}


def get_detailed_info_table_content(
    question_id: int,
    rows_names: List[str],
    models: List[str],
    files_id: List[int],
    filter_functions: List[str],
    sorting_functions: List[str],
    text_modes: List[List[str]],
) -> List:
    """Render every column of the detailed-info table, column after column.

    Each column's file selector lists the model's files after applying that
    column's filter and sorting functions.  All columns are compared against
    the sample selected in the first column.

    Returns:
        A flat list of cells, one block of rows per column.
    """
    # Same reference for every column, so compute it once.
    compare_to = _get_reference_sample(question_id, models, files_id)
    table = get_table_data()

    table_data = []
    for col_id, (model, file_id, filter_function, sorting_function, modes) in enumerate(
        zip(models, files_id, filter_functions, sorting_functions, text_modes)
    ):
        model_files = table[question_id][model] if len(table) else []
        files_names = [
            file[FILE_NAME]
            for file in get_filtered_files(filter_function, sorting_function, model_files)
        ]
        table_data.extend(
            get_detailed_info_table_row_content(
                question_id=question_id,
                model=model,
                rows_names=rows_names,
                files_names=files_names,
                file_id=file_id,
                col_id=col_id,
                text_modes=modes,
                compare_to=compare_to,
            )
        )
    return table_data


def get_general_stats_layout(
    base_model: str,
) -> List[html.Div]:
    """Render dataset-wide statistics for ``base_model``.

    Includes the built-in counts and every function from
    ``get_general_custom_stats()``.  A custom stat that fails is logged; its
    value is whatever ``catch_eval_exception`` returns.

    Returns:
        A single-element list with a ``html.Div`` of ``name: value`` lines.
    """
    data_for_base_model = [data.get(base_model, []) for data in get_table_data()]

    custom_stats = {}
    for name, func in get_general_custom_stats().items():
        errors_dict = {}
        custom_stats[name] = catch_eval_exception(
            [],
            func,
            data_for_base_model,
            "Got error when applying function",
            errors_dict,
        )
        if errors_dict:
            logging.error(ERROR_MESSAGE_TEMPLATE.format(name, errors_dict))

    overall_samples = sum(len(question_data) for question_data in data_for_base_model)
    # Questions for which the base model has at least one sample.
    dataset_size = sum(1 for question_data in data_for_base_model if question_data)
    stats = {
        "dataset size": dataset_size,
        "overall number of samples": overall_samples,
        "generations per sample": (overall_samples / dataset_size if dataset_size else 0),
        **custom_stats,
    }
    return [html.Div([html.Pre(f"{name}: {value}") for name, value in stats.items()])]


# -----------------------------------------------------------------------------
# /nemo_inspector/callbacks/analyze_page/detailed_sample_info.py:
# -----------------------------------------------------------------------------
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Dict, List, Optional, Tuple

from dash import ALL, callback_context, no_update
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate

from nemo_inspector.callbacks import app
from nemo_inspector.layouts import (
    get_detailed_info_table_column,
    get_table_column_header,
    get_detailed_info_table_row_content,
    get_detailed_info_table_content,
)
from nemo_inspector.settings.constants import (
    EDIT_ICON_PATH,
    FILE_NAME,
    MODEL_SELECTOR_ID,
    SAVE_ICON_PATH,
)
from nemo_inspector.utils.common import (
    get_available_models,
    get_compared_rows,
    get_editable_rows,
    get_excluded_row,
    get_file_id,
    get_filtered_files,
    get_table_data,
)


def _triggered_component_id(trigger: Dict):
    """Return the pattern-matching ``id`` of the component in a ctx.triggered entry.

    ``prop_id`` looks like ``'{"id": 3, "type": "..."}.n_clicks'``.  Splitting
    on the last dot keeps the JSON part intact.  Raises ``ValueError``
    (``json.JSONDecodeError``) when the prop_id carries no JSON id.
    """
    return json.loads(trigger["prop_id"].rsplit(".", 1)[0])["id"]


def _find_button_index(button_ids: List[Dict], button_id) -> Optional[int]:
    """Return the position of the component whose ``id`` equals ``button_id``, or None."""
    return next(
        (
            position
            for position, component_id in enumerate(button_ids)
            if component_id["id"] == button_id
        ),
        None,
    )


def _get_question_id(current_page: int, page_size: int, idx: List[int]) -> int:
    """Map the datatable's page-relative selected row to an index in get_table_data()."""
    return current_page * page_size + idx[0]


@app.callback(
    [
        Output("js_container", "children", allow_duplicate=True),
        Output("js_trigger", "children", allow_duplicate=True),
    ],
    Input({"type": "editable_row", "id": ALL, "model_name": ALL}, "value"),
    [
        State({"type": "model_selector", "id": ALL}, "value"),
        State("datatable", "selected_rows"),
        State("datatable", "page_current"),
        State("datatable", "page_size"),
        State({"type": "editable_row", "id": ALL, "model_name": ALL}, "id"),
        State({"type": "file_selector", "id": ALL}, "value"),
        State("js_trigger", "children"),
    ],
    prevent_initial_call=True,
)
def update_data_table(
    new_rows_values: List[str],
    models: List[str],
    idx: List[int],
    current_page: int,
    page_size: int,
    new_rows_ids: List[Dict],
    file_names: List[str],
    js_trigger: str,
) -> Tuple[str, str]:
    """Write edited textarea values back into the in-memory table data.

    Each model's edits go to the file currently selected in its column.  When
    a model has no matching file, its edits are skipped instead of raising
    ``KeyError``.

    Returns:
        An empty JS container and ``js_trigger`` with one extra space, so that
        the trigger value changes.
    """
    ctx = callback_context
    if not ctx.triggered or not idx:
        return no_update, no_update

    table_data = get_table_data()
    question_id = _get_question_id(current_page, page_size, idx)

    # model name -> index of the file currently selected for it.  If the same
    # model is shown in several columns, the last column wins.
    file_ids = {}
    for model_id, name in enumerate(file_names):
        model = models[model_id]
        files = table_data[question_id][model] if len(table_data) else []
        for file_id, file in enumerate(files):
            if file[FILE_NAME] == name:
                file_ids[model] = file_id

    for new_rows_id, new_rows_value in zip(new_rows_ids, new_rows_values):
        updated_model = new_rows_id["model_name"]
        if updated_model not in file_ids:
            # The selected file was not found (e.g. filtered out): nothing to edit.
            continue
        table_data[question_id][updated_model][file_ids[updated_model]][
            new_rows_id["id"]
        ] = new_rows_value

    return "", js_trigger + " "


@app.callback(
    [
        Output(
            {"type": "detailed_models_answers", "id": ALL},
            "children",
            allow_duplicate=True,
        ),
        Output(
            {"type": "filter_function_input", "id": ALL},
            "value",
            allow_duplicate=True,
        ),
        Output(
            {"type": "sorting_function_input", "id": ALL},
            "value",
            allow_duplicate=True,
        ),
    ],
    [
        Input("datatable", "selected_rows"),
        Input(
            "dummy_output",
            "children",
        ),
    ],
    [
        State({"type": "model_selector", "id": ALL}, "value"),
        State({"type": "sorting_function_input", "id": ALL}, "value"),
        State({"type": "filter_function_input", "id": ALL}, "value"),
        State({"type": "row_name", "id": ALL}, "children"),
        State("datatable", "page_current"),
        State("datatable", "page_size"),
        State({"type": "file_selector", "id": ALL}, "value"),
        State({"type": "text_modes", "id": ALL}, "value"),
    ],
    prevent_initial_call=True,
)
def show_item(
    idx: List[int],
    dummmy_trigger: str,
    models: List[str],
    sorting_functions: List[str],
    filter_functions: List[str],
    rows_names: List[str],
    current_page: int,
    page_size: int,
    file_names: List[str],
    text_modes: List[List[str]],
) -> List:
    """Re-render the whole detailed-info table for the selected question.

    Runs when a datatable row is selected or when ``dummy_output`` changes.
    Index 0 of the filter/sorting inputs is not a model column; entries
    ``1..`` belong to the columns.
    """
    if not idx:
        raise PreventUpdate
    ctx = callback_context
    if not ctx.triggered:
        return [no_update, no_update, no_update]
    if ctx.triggered[0]["prop_id"] == "datatable.selected_rows":
        # A new sample was selected: keep the first entry, reset per-column ones.
        filter_functions = [filter_functions[0]] + [""] * (len(filter_functions) - 1)
        sorting_functions = [sorting_functions[0]] + [None] * (len(sorting_functions) - 1)

    table_data = get_table_data()
    question_id = _get_question_id(current_page, page_size, idx)
    file_ids = [0] * len(models)
    for column_id in range(len(file_names)):
        files = table_data[question_id][models[column_id]] if len(table_data) else []
        file_ids[column_id] = get_file_id(file_names, files, column_id)

    return [
        get_detailed_info_table_content(
            question_id=question_id,
            rows_names=rows_names,
            models=models,
            files_id=file_ids,
            filter_functions=filter_functions[1:],
            sorting_functions=sorting_functions[1:],
            text_modes=text_modes,
        ),
        # NOTE(review): this clears every filter input, including entry 0,
        # while the sorting inputs keep entry 0 -- confirm this is intended.
        [""] * len(filter_functions),
        sorting_functions,
    ]


@app.callback(
    Output(
        "dummy_output",
        "children",
        allow_duplicate=True,
    ),
    Input({"type": "compare_texts_button", "id": ALL}, "n_clicks"),
    [
        State("dummy_output", "children"),
        State({"type": "row_name", "id": ALL}, "children"),
        State({"type": "compare_texts_button", "id": ALL}, "n_clicks"),
    ],
    prevent_initial_call=True,
)
def compare(n_clicks: List[int], dummy_data: str, row_names: List[str], button_ids: List[int]):
    """Toggle the clicked row in the set of compared rows and request a re-render."""
    ctx = callback_context
    if not ctx.triggered or not n_clicks:
        return no_update
    # Pattern-matching inputs also fire for (re)created buttons with
    # n_clicks=None; only react to real clicks.
    if not ctx.triggered[0]["value"]:
        return no_update
    # Row ids are assigned with enumerate(keys), so the id is also the index.
    row_name = row_names[_triggered_component_id(ctx.triggered[0])]
    compared_rows = get_compared_rows()
    if row_name in compared_rows:
        compared_rows.remove(row_name)
    else:
        compared_rows.add(row_name)
    return dummy_data + "1"


@app.callback(
    [
        Output(
            "dummy_output",
            "children",
            allow_duplicate=True,
        ),
        Output({"type": "edit_row_image", "id": ALL}, "src"),
    ],
    Input({"type": "edit_row_button", "id": ALL}, "n_clicks"),
    [
        State({"type": "row_name", "id": ALL}, "children"),
        State({"type": "edit_row_image", "id": ALL}, "id"),
        State({"type": "edit_row_image", "id": ALL}, "src"),
        State({"type": "model_selector", "id": ALL}, "value"),
        State("datatable", "selected_rows"),
        State("datatable", "page_current"),
        State("datatable", "page_size"),
        State({"type": "file_selector", "id": ALL}, "value"),
        State(
            "dummy_output",
            "children",
        ),
    ],
    prevent_initial_call=True,
)
def edit_row(
    n_clicks: List[int],
    rows: List[str],
    button_ids: List[Dict],
    edit_row_labels: List[str],
    models: List[str],
    idx: List[int],
    current_page: int,
    page_size: int,
    file_names: List[str],
    dummy_data: str,
) -> Tuple[str, List[str]]:
    """Toggle edit mode for the clicked row and swap its edit/save icon.

    ``models``, ``current_page``, ``page_size`` and ``file_names`` are part
    of the callback signature but are not needed here.
    """
    no_updates = no_update, [no_update] * len(button_ids)
    ctx = callback_context
    if not ctx.triggered or not n_clicks or not idx:
        return no_updates
    row_index = _find_button_index(button_ids, _triggered_component_id(ctx.triggered[0]))
    if row_index is None or not n_clicks[row_index]:
        return no_updates

    row_name = rows[row_index]
    editable_rows = get_editable_rows()
    if row_name in editable_rows:
        editable_rows.remove(row_name)
        edit_row_labels[row_index] = EDIT_ICON_PATH
    else:
        editable_rows.add(row_name)
        edit_row_labels[row_index] = SAVE_ICON_PATH

    return dummy_data + "1", edit_row_labels


@app.callback(
    [
        Output(
            "dummy_output",
            "children",
            allow_duplicate=True,
        ),
        Output({"type": "del_row", "id": ALL}, "children"),
    ],
    Input({"type": "del_row", "id": ALL}, "n_clicks"),
    [
        State({"type": "row_name", "id": ALL}, "children"),
        State({"type": "del_row", "id": ALL}, "id"),
        State({"type": "del_row", "id": ALL}, "children"),
        State(
            "dummy_output",
            "children",
        ),
    ],
    prevent_initial_call=True,
)
def del_row(
    n_clicks: List[int],
    rows: List[str],
    button_ids: List[Dict],
    del_row_labels: List[str],
    dummy_data: str,
) -> Tuple[str, List[str]]:
    """Toggle exclusion of the clicked row; the button shows "+" while excluded."""
    no_updates = no_update, [no_update] * len(button_ids)
    ctx = callback_context
    if not ctx.triggered or not n_clicks:
        return no_updates
    row_index = _find_button_index(button_ids, _triggered_component_id(ctx.triggered[0]))
    if row_index is None or not n_clicks[row_index]:
        return no_updates

    row_name = rows[row_index]
    excluded_rows = get_excluded_row()
    if row_name in excluded_rows:
        excluded_rows.remove(row_name)
        del_row_labels[row_index] = "-"
    else:
        excluded_rows.add(row_name)
        del_row_labels[row_index] = "+"

    return dummy_data + "1", del_row_labels


@app.callback(
    [
        Output(
            "detailed_answers_header",
            "children",
            allow_duplicate=True,
        ),
        Output(
            {"type": "detailed_answers_row", "id": ALL},
            "children",
            allow_duplicate=True,
        ),
    ],
    Input("add_model", "n_clicks"),
    [
        State("detailed_answers_header", "children"),
        State({"type": "detailed_answers_row", "id": ALL}, "children"),
        State({"type": "model_selector", "id": ALL}, "id"),
        State("datatable", "selected_rows"),
    ],
    prevent_initial_call=True,
)
def add_model(
    n_clicks: int,
    header: List,
    rows: List,
    selectors_ids: List[Dict],
    idx: List[int],
) -> Tuple[List, List]:
    """Append a new model column (header plus one cell per row).

    The new column shows the first available model.  Its first-row cell gets
    a file selector only when a datatable row is selected.
    """
    no_updates = no_update, [no_update] * len(rows)
    if not n_clicks:
        return no_updates
    available_models = list(get_available_models().keys())
    if not available_models:
        return no_updates

    new_header_id = (selectors_ids[-1]["id"] if selectors_ids else -1) + 1
    header.append(
        get_table_column_header(available_models, available_models[0], new_header_id, True)
    )
    # Continue numbering after the id of the last row's last answer cell.
    last_cell_id = (
        rows[-1][-1]["props"]["children"]["props"]["id"]["id"] if rows else -1
    )
    for i, row in enumerate(rows):
        row.append(
            get_detailed_info_table_column(
                last_cell_id + i + 1,
                file_id=new_header_id if i == 0 and idx else None,
            )
        )

    return header, rows


@app.callback(
    [
        Output("detailed_answers_header", "children"),
        Output({"type": "detailed_answers_row", "id": ALL}, "children"),
    ],
    Input({"type": "del_model", "id": ALL}, "n_clicks"),
    [
        State("detailed_answers_header", "children"),
        State({"type": "detailed_answers_row", "id": ALL}, "children"),
        State({"type": "del_model", "id": ALL}, "id"),
    ],
    prevent_initial_call=True,
)
def del_model(
    n_clicks: List[int],
    header: List,
    rows: List,
    id_del: List[Dict],
) -> Tuple[List, List]:
    """Remove the model column whose delete button was clicked."""
    no_updates = no_update, [no_update] * len(rows)
    ctx = callback_context
    if not ctx.triggered or not ctx.triggered[0]["value"]:
        return no_updates

    button_position = _find_button_index(id_del, _triggered_component_id(ctx.triggered[0]))
    if button_position is None:
        return no_updates
    # Header and rows start with the row-name column and the first model
    # column, which has no delete button; delete button k is at position k + 2.
    index = button_position + 2

    header.pop(index)
    for row in rows:
        row.pop(index)

    return header, rows


@app.callback(
    [
        Output(
            {"type": "detailed_models_answers", "id": ALL},
            "children",
            allow_duplicate=True,
        ),
        Output(
            "dummy_output",
            "children",
            allow_duplicate=True,
        ),
    ],
    [
        Input({"type": "file_selector", "id": ALL}, "value"),
        Input({"type": "text_modes", "id": ALL}, "value"),
    ],
    [
        State("datatable", "selected_rows"),
        State({"type": "file_selector", "id": ALL}, "options"),
        State({"type": "model_selector", "id": ALL}, "value"),
        State({"type": "model_selector", "id": ALL}, "id"),
        State({"type": "row_name", "id": ALL}, "children"),
        State("datatable", "page_current"),
        State("datatable", "page_size"),
        State(
            {"type": "detailed_models_answers", "id": ALL},
            "children",
        ),
        State("dummy_output", "children"),
    ],
    prevent_initial_call=True,
)
def change_file(
    file_names: List[str],
    text_modes: List[List[str]],
    idx: List[int],
    file_options: List[List[Dict]],
    models: List[str],
    model_ids: List[Dict],
    rows_names: List[str],
    current_page: int,
    page_size: int,
    table_data: List,
    dummy_data: str,
) -> Tuple[List, str]:
    """Re-render the columns whose file selector or text modes changed.

    If the first column changed, ``dummy_output`` is bumped so every column
    is re-rendered against the new reference sample.  When no trigger maps
    to a column, nothing is updated.
    """
    if not idx:
        raise PreventUpdate

    ctx = callback_context
    if not ctx.triggered:
        return [no_update] * len(table_data), no_update

    question_id = _get_question_id(current_page, page_size, idx)
    rows_count = len(rows_names)
    any_updated = False
    base_changed = False
    for trigger in ctx.triggered:
        try:
            column = model_ids.index(
                json.loads(MODEL_SELECTOR_ID.format(_triggered_component_id(trigger)))
            )
        except ValueError:
            continue

        question_data = get_table_data()[question_id]
        model = models[column]
        file_id = get_file_id(file_names, question_data[model], column)
        base_file_id = get_file_id(file_names, question_data[models[0]], 0)

        table_data[column * rows_count : (column + 1) * rows_count] = (
            get_detailed_info_table_row_content(
                question_id=question_id,
                model=model,
                file_id=file_id,
                rows_names=rows_names,
                files_names=[option["value"] for option in file_options[column]],
                col_id=column,
                text_modes=text_modes[column],
                compare_to=question_data[models[0]][base_file_id],
            )
        )
        any_updated = True
        base_changed = base_changed or column == 0

    if not any_updated:
        return [no_update] * len(table_data), no_update
    return table_data, (dummy_data + "1" if base_changed else dummy_data)


@app.callback(
    [
        Output({"type": "file_selector", "id": ALL}, "options"),
        Output({"type": "file_selector", "id": ALL}, "value"),
    ],
    [
        Input({"type": "apply_filter_button", "id": ALL}, "n_clicks"),
        Input({"type": "apply_sorting_button", "id": ALL}, "n_clicks"),
        Input({"type": "model_selector", "id": ALL}, "value"),
    ],
    [
        State({"type": "model_selector", "id": ALL}, "id"),
        State({"type": "sorting_function_input", "id": ALL}, "value"),
        State({"type": "filter_function_input", "id": ALL}, "value"),
        State({"type": "apply_on_filtered_data", "id": ALL}, "value"),
        State("datatable", "page_current"),
        State("datatable", "page_size"),
        State("datatable", "selected_rows"),
        State({"type": "file_selector", "id": ALL}, "options"),
        State({"type": "file_selector", "id": ALL}, "value"),
    ],
    prevent_initial_call=True,
)
def change_files_order(
    filter_n_click: List[int],
    sorting_n_click: List[int],
    models: List[str],
    model_ids: List[Dict],
    sorting_functions: List[str],
    filter_functions: List[str],
    apply_on_filtered_data: List[int],
    current_page: int,
    page_size: int,
    idx: List[int],
    file_selector_options: List[List[Dict]],
    file_selector_values: List[str],
) -> Tuple[List[List[Dict]], List[str]]:
    """Re-filter and re-sort the file options of the triggering column.

    When ``apply_on_filtered_data`` is set for the column, only files already
    offered by its selector are considered.  The selector value becomes the
    first remaining file name, or None when nothing remains.
    """
    no_updates = [no_update] * len(file_selector_options)
    if not filter_n_click and not sorting_n_click:
        return no_updates, no_updates
    if not idx:
        raise PreventUpdate
    ctx = callback_context
    if not ctx.triggered:
        return no_updates, no_updates

    # Use the same trigger entry for the column lookup and the value check.
    trigger = ctx.triggered[-1]
    try:
        column = model_ids.index(
            json.loads(MODEL_SELECTOR_ID.format(_triggered_component_id(trigger)))
        )
    except ValueError:
        return no_updates, no_updates
    if not trigger["value"]:
        return no_updates, no_updates

    model = models[column]
    question_id = _get_question_id(current_page, page_size, idx)
    array_to_filter = get_table_data()[question_id][model]
    if apply_on_filtered_data and apply_on_filtered_data[column]:
        # Only this column's options matter; a set gives O(1) membership tests.
        shown_files = {option["label"] for option in file_selector_options[column]}
        array_to_filter = [
            data for data in array_to_filter if data[FILE_NAME] in shown_files
        ]

    # Entry 0 of the function inputs is not a model column, hence the +1.
    options = [
        {"label": data[FILE_NAME], "value": data[FILE_NAME]}
        for data in get_filtered_files(
            filter_functions[column + 1],
            sorting_functions[column + 1],
            array_to_filter,
        )
    ]
    file_selector_options[column] = options
    # The selector value is the file name, not the whole option dict.
    file_selector_values[column] = options[0]["value"] if options else None

    return file_selector_options, file_selector_values


# -----------------------------------------------------------------------------