├── decorators ├── __init__.py └── graph_output.py ├── pages ├── home │ ├── functions │ │ └── __init__.py │ ├── outputs │ │ └── __init__.py │ ├── __init__.py │ ├── callbacks.py │ └── layout │ │ ├── __init__.py │ │ └── landing_layout.py ├── graphs │ ├── functions │ │ ├── __init__.py │ │ ├── filter_dataframe.py │ │ ├── graph_generators.py │ │ └── data_processing.py │ ├── outputs │ │ ├── __init__.py │ │ ├── network_editor.py │ │ └── network_inspector.py │ ├── __init__.py │ ├── data.py │ ├── layout │ │ ├── graphs_layout │ │ │ ├── __init__.py │ │ │ ├── editor_layout.py │ │ │ └── inspector_layout.py │ │ └── __init__.py │ └── callbacks.py ├── interact │ ├── __init__.py │ ├── outputs │ │ ├── __init__.py │ │ └── fired_rules.py │ ├── functions │ │ ├── __init__.py │ │ ├── filter_dataframe.py │ │ ├── graph_generators.py │ │ ├── data_processing.py │ │ └── twc.py │ ├── layout │ │ ├── filters_layout.py │ │ ├── graphs_layout │ │ │ ├── __init__.py │ │ │ ├── inspect_layout.py │ │ │ └── interact_layout.py │ │ ├── configuration_layout.py │ │ └── __init__.py │ ├── data.py │ ├── callbacks.py │ └── twc_agent.py ├── scalars │ ├── __init__.py │ ├── outputs │ │ ├── __init__.py │ │ └── scalars.py │ ├── functions │ │ ├── __init__.py │ │ ├── filter_dataframe.py │ │ └── data_processing.py │ ├── layout │ │ ├── filters_layout.py │ │ ├── graphs_layout │ │ │ ├── __init__.py │ │ │ └── scalars_layout.py │ │ ├── configuration_layout.py │ │ └── __init__.py │ ├── data.py │ └── callbacks.py └── __init__.py ├── utils_demo ├── __init__.py ├── percentage_format.py └── callbacks │ ├── __init__.py │ ├── reset_page_callback.py │ ├── select_all_callback.py │ └── download_callback.py ├── assets ├── globe-icon.png └── robot-icon.jpeg ├── static └── data │ ├── scalars_example.npz │ ├── record_lnn_wts_N4_pos1_neg2.npy │ ├── scalars_lnn_wts_N4_pos1_neg2.npz │ ├── textworld_logo.txt │ ├── network1.json │ └── scalars_example.json ├── .gitmodules ├── components ├── __init__.py ├── carbon_jumbotron.py ├── ui_shell.py 
├── fired_rule_card.py ├── graph_card.py ├── dialog_row.py ├── network_card.py └── next_actions_row.py ├── requirements.txt ├── global_callbacks.py ├── router.py ├── LICENSE ├── app.py ├── .gitignore └── README.md /decorators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pages/home/functions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pages/home/outputs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pages/graphs/functions/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pages/home/__init__.py: -------------------------------------------------------------------------------- 1 | from .layout import layout 2 | -------------------------------------------------------------------------------- /pages/interact/__init__.py: -------------------------------------------------------------------------------- 1 | from .layout import layout 2 | -------------------------------------------------------------------------------- /pages/scalars/__init__.py: -------------------------------------------------------------------------------- 1 | from .layout import layout 2 | -------------------------------------------------------------------------------- /pages/scalars/outputs/__init__.py: -------------------------------------------------------------------------------- 1 | from .scalars import scalars 2 | -------------------------------------------------------------------------------- /pages/interact/outputs/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .fired_rules import fired_rules 2 | -------------------------------------------------------------------------------- /utils_demo/__init__.py: -------------------------------------------------------------------------------- 1 | from .percentage_format import percentage_format 2 | -------------------------------------------------------------------------------- /pages/scalars/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .filter_dataframe import filter_dataframe 2 | -------------------------------------------------------------------------------- /assets/globe-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/nesa-demo/HEAD/assets/globe-icon.png -------------------------------------------------------------------------------- /assets/robot-icon.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/nesa-demo/HEAD/assets/robot-icon.jpeg -------------------------------------------------------------------------------- /static/data/scalars_example.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/nesa-demo/HEAD/static/data/scalars_example.npz -------------------------------------------------------------------------------- /utils_demo/percentage_format.py: -------------------------------------------------------------------------------- 1 | def percentage_format(x: float) -> str: 2 | return f"{(x * 100):.1f}%" 3 | -------------------------------------------------------------------------------- /pages/graphs/outputs/__init__.py: -------------------------------------------------------------------------------- 1 | from .network_editor import network_editor 2 | from .network_inspector import 
network_inspector 3 | -------------------------------------------------------------------------------- /static/data/record_lnn_wts_N4_pos1_neg2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/nesa-demo/HEAD/static/data/record_lnn_wts_N4_pos1_neg2.npy -------------------------------------------------------------------------------- /static/data/scalars_lnn_wts_N4_pos1_neg2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/nesa-demo/HEAD/static/data/scalars_lnn_wts_N4_pos1_neg2.npz -------------------------------------------------------------------------------- /pages/interact/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .filter_dataframe import filter_dataframe 2 | from .twc import load_twc_game 3 | from .twc import get_agent_actions 4 | -------------------------------------------------------------------------------- /pages/graphs/__init__.py: -------------------------------------------------------------------------------- 1 | from .layout import layout 2 | from .outputs.network_editor import network_editor 3 | from .outputs.network_inspector import network_inspector 4 | -------------------------------------------------------------------------------- /utils_demo/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .download_callback import download_callback 2 | from .select_all_callback import select_all_callback 3 | from .reset_page_callback import reset_page_callback 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/loa"] 2 | path = third_party/loa 3 | url = git@github.com:IBM/LOA.git 4 | [submodule "third_party/commonsense_rl"] 5 | path 
= third_party/commonsense_rl 6 | url = git@github.com:daiki-kimura/commonsense-rl.git 7 | -------------------------------------------------------------------------------- /pages/home/callbacks.py: -------------------------------------------------------------------------------- 1 | from dash import Dash 2 | from dash.dependencies import Input, Output, State 3 | 4 | from . import outputs 5 | from utils_demo.callbacks import select_all_callback 6 | 7 | 8 | def register(app: Dash): 9 | # placeholder 10 | x = 1 11 | -------------------------------------------------------------------------------- /components/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph_card import graph_card 2 | from .network_card import network_card 3 | from .fired_rule_card import fired_rule_card 4 | from .ui_shell import ui_shell 5 | from .dialog_row import dialog_row 6 | from .next_actions_row import next_actions_row 7 | 8 | -------------------------------------------------------------------------------- /pages/graphs/functions/filter_dataframe.py: -------------------------------------------------------------------------------- 1 | from pandas import DataFrame 2 | from typing import List 3 | 4 | 5 | def filter_dataframe(data: DataFrame, years: List[str], countries: List[str]): 6 | criteria = (data['year'].isin(years)) & \ 7 | (data['country'].isin(countries)) 8 | return data[criteria] 9 | -------------------------------------------------------------------------------- /pages/interact/functions/filter_dataframe.py: -------------------------------------------------------------------------------- 1 | from pandas import DataFrame 2 | from typing import List 3 | 4 | 5 | def filter_dataframe(data: DataFrame, years: List[str], countries: List[str]): 6 | criteria = (data['year'].isin(years)) & \ 7 | (data['country'].isin(countries)) 8 | return data[criteria] 9 | -------------------------------------------------------------------------------- 
/pages/scalars/outputs/scalars.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import plotly.express as px 3 | import plotly.graph_objs as go 4 | from decorators.graph_output import graph_output 5 | 6 | 7 | @graph_output('scalars') 8 | def scalars(data: pd.DataFrame) -> go.Figure: 9 | return px.line(data, x="epoch", y="loss", color='loss_type') 10 | -------------------------------------------------------------------------------- /pages/scalars/functions/filter_dataframe.py: -------------------------------------------------------------------------------- 1 | from pandas import DataFrame 2 | from typing import List 3 | import pandas as pd 4 | 5 | 6 | def filter_dataframe(data: DataFrame, loss_type: List[str]): 7 | dfs = [] 8 | for l in loss_type: 9 | group, metric = l.split(": ") 10 | dfs += [data[(data['group'] == group) & (data['loss_type'] == metric)]] 11 | 12 | return pd.concat(dfs) 13 | -------------------------------------------------------------------------------- /pages/interact/layout/filters_layout.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | import dash_html_components as html 3 | from pages.interact import data 4 | 5 | filters_layout = [ 6 | html.Div(children=[ 7 | dca.Dropdown( 8 | id='game_level_selection', 9 | options=data.options_game_level, 10 | value='Easy', 11 | label='Select Game Level' 12 | ), 13 | ]), 14 | ] 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dash==1.20.0 2 | dash-carbon-components==0.1.6 3 | dash-cytoscape==0.2.0 4 | dash-extensions==0.0.53 5 | Flask==1.1.2 6 | flask-oidc==1.4.0 7 | numpy==1.21.1 8 | pandas==1.2.4 9 | plotly==4.14.3 10 | urllib3==1.26.7 11 | gunicorn==20.1.0 12 | cdd==0.1.4 13 | numpy==1.21.1 14 | textworld==1.2.0 15 | 
tqdm==4.50.2 16 | requests==2.26.0 17 | matplotlib==3.5.0 18 | jinja2==3.0.3 19 | itsdangerous==2.0.1 20 | werkzeug==2.0.3 21 | -------------------------------------------------------------------------------- /pages/scalars/layout/filters_layout.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | import dash_html_components as html 3 | from pages.scalars import data 4 | 5 | filters_layout = [ 6 | html.Div(children=[ 7 | dca.MultiSelect( 8 | id="loss_type_selection", 9 | options=data.options_loss_checklist, 10 | value=data.value_loss_checklist, 11 | label="Select Metrics" 12 | ) 13 | ]), 14 | ] 15 | -------------------------------------------------------------------------------- /pages/scalars/layout/graphs_layout/__init__.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | from .scalars_layout import scalars_layout 3 | 4 | graphs_layout = [ 5 | dca.Tabs( 6 | style={'width': '100%', 'backgroundColor': 'white'}, 7 | id='scalars_tabs', 8 | headerSizes=['lg-10'], 9 | value='training_tab', 10 | children=[ 11 | dca.Tab(value='training_tab', label='Training', children=scalars_layout), 12 | ]) 13 | ] 14 | -------------------------------------------------------------------------------- /global_callbacks.py: -------------------------------------------------------------------------------- 1 | from dash import Dash 2 | from dash.dependencies import Output, Input, ALL, State 3 | 4 | 5 | def register_global_callbacks(app: Dash): 6 | # Callback to reload the page when the user press the reset button 7 | @app.callback( 8 | Output('url', 'pathname'), 9 | Input({'type': 'reset', 'page': ALL}, 'n_clicks'), 10 | State('url', 'pathname'), 11 | prevent_initial_call=True 12 | ) 13 | def reset_configuration(reset, pathname): 14 | return pathname 15 | -------------------------------------------------------------------------------- 
/pages/__init__.py: -------------------------------------------------------------------------------- 1 | EASY_LEVEL = 'Easy' 2 | MEDIUM_LEVEL = 'Medium' 3 | HARD_LEVEL = 'Hard' 4 | 5 | # LEVELS = [EASY_LEVEL, MEDIUM_LEVEL, HARD_LEVEL] 6 | LEVELS = [EASY_LEVEL] 7 | 8 | DL_AGENT = True 9 | LOA_AGENT = True 10 | 11 | MESSAGE_FOR_SELECT_NEXT_ACTION = 'Select next action' 12 | MESSAGE_FOR_DONE = \ 13 | 'The game is Done. Please reload the page, if you want to retry.' 14 | 15 | REPO_URL = 'https://github.com/IBM/nesa-demo' 16 | LOA_REPO_URL = 'https://github.com/IBM/LOA' 17 | RIGHT_TOP_MESSAGE = 'Repo: ' + REPO_URL 18 | -------------------------------------------------------------------------------- /utils_demo/callbacks/reset_page_callback.py: -------------------------------------------------------------------------------- 1 | from dash import Dash 2 | from dash.dependencies import Output, Input, ALL, State 3 | 4 | 5 | def reset_page_callback(app: Dash): 6 | # Callback to reload the page when the user press the reset button 7 | @app.callback( 8 | Output('url', 'pathname'), 9 | Input({'type': 'reset', 'page': ALL}, 'n_clicks'), 10 | State('url', 'pathname'), 11 | prevent_initial_call=True 12 | ) 13 | def __reset_page_callback(reset, pathname): 14 | return pathname 15 | -------------------------------------------------------------------------------- /utils_demo/callbacks/select_all_callback.py: -------------------------------------------------------------------------------- 1 | from dash import Dash 2 | from dash.dependencies import Input, Output, State 3 | 4 | 5 | def select_all_callback(app: Dash, button_id: str, dropdown_id: str): 6 | @app.callback( 7 | Output(dropdown_id, 'value'), 8 | Input(button_id, 'n_clicks'), 9 | State(dropdown_id, 'options') 10 | ) 11 | def __select_all_callback(click, options): 12 | values = [ 13 | e if type(e) != dict else e['value'] for e in 14 | options 15 | ] 16 | return values 17 | 
-------------------------------------------------------------------------------- /pages/graphs/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .functions.data_processing import * 3 | 4 | home_dir = './' 5 | 6 | network = load_network(os.path.join(home_dir, 'static', 'data', 'network1.json')) 7 | scalars = create_loss_data(load_scalars(os.path.join(home_dir, 'static', 'data', 'scalars_lnn_wts_N4_pos1_neg2.npz'))) 8 | roots_names, nodes, edges = create_network(network_dict=network) 9 | slider_dict = create_slider_dict(network['epochs']) 10 | options_loss_checklist = [{"label": x, "value": x} for x in scalars['loss_type'].unique()] 11 | value_loss_checklist = [x for x in scalars['loss_type'].unique()] -------------------------------------------------------------------------------- /pages/graphs/layout/graphs_layout/__init__.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | from .editor_layout import editor_layout 3 | from .inspector_layout import inspector_layout 4 | 5 | graphs_layout = [ 6 | dca.Tabs( 7 | style={'width': '100%', 'backgroundColor': 'white'}, 8 | id='graphs_tabs', 9 | headerSizes=['lg-10'], 10 | value='editor_tab', 11 | children=[ 12 | dca.Tab(value='editor_tab', label='Edit Network', children=editor_layout), 13 | dca.Tab(value='inspector_tab', label='Inspect Network', children=inspector_layout), 14 | ]) 15 | ] 16 | -------------------------------------------------------------------------------- /pages/interact/layout/graphs_layout/__init__.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | from .interact_layout import interact_layout 3 | from .inspect_layout import inspect_layout 4 | 5 | graphs_layout = [ 6 | dca.Tabs( 7 | style={'width': '100%', 'backgroundColor': 'white'}, 8 | id='interact_tabs', 9 | headerSizes=['lg-10'], 10 | 
value='interact_tab', 11 | children=[ 12 | dca.Tab(value='interact_tab', 13 | label='Agent Interaction', children=interact_layout), 14 | # dca.Tab(value='inspect_tab', 15 | # label='Agent Inspection', children=inspect_layout), 16 | ]) 17 | ] 18 | -------------------------------------------------------------------------------- /router.py: -------------------------------------------------------------------------------- 1 | from dash import Dash 2 | from dash.dependencies import Output, Input 3 | 4 | from pages.home.layout import layout as home_layout 5 | from pages.scalars.layout import layout as scalars_layout 6 | from pages.graphs.layout import layout as graphs_layout 7 | from pages.interact.layout import layout as interact_layout 8 | 9 | pages = { 10 | '/': home_layout, 11 | '/scalars': scalars_layout, 12 | '/graphs': graphs_layout, 13 | '/interact': interact_layout, 14 | } 15 | 16 | 17 | def register_router(app: Dash): 18 | # Router callback 19 | @app.callback( 20 | Output('page-content', 'children'), 21 | Input('url', 'pathname'), 22 | ) 23 | def display_page(pathname): 24 | return pages[pathname] 25 | -------------------------------------------------------------------------------- /pages/scalars/layout/graphs_layout/scalars_layout.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | from components import graph_card 3 | 4 | scalars_layout = dca.Grid( 5 | style={'padding': '16px', 6 | 'height': 'calc(100% - 40px)', 7 | 'overflow': 'auto'}, 8 | className='bx--grid--narrow bx--grid--full-width', 9 | children=[ 10 | dca.Row(children=[ 11 | dca.Column(columnSizes=['sm-4'], children=[ 12 | graph_card( 13 | graph_id='scalars', 14 | graph_name='Logical Neural Network Losses by Epoch', 15 | graph_info='Logical Neural Network Losses by Epoch', 16 | height=500 17 | ), 18 | ]), 19 | ]), 20 | ]) 21 | -------------------------------------------------------------------------------- 
/components/carbon_jumbotron.py: -------------------------------------------------------------------------------- 1 | import dash_bootstrap_components as dbc 2 | import dash_html_components as html 3 | 4 | fluid_jumbotron = dbc.Jumbotron( 5 | [ 6 | dbc.Container( 7 | [ 8 | html.H1("Fluid jumbotron", className="display-3"), 9 | html.P( 10 | "This jumbotron occupies the entire horizontal " 11 | "space of its parent.", 12 | className="lead", 13 | ), 14 | html.P( 15 | "You will need to embed a fluid container in " 16 | "the jumbotron.", 17 | className="lead", 18 | ), 19 | ], 20 | fluid=True, 21 | ) 22 | ], 23 | fluid=True, 24 | ) -------------------------------------------------------------------------------- /pages/graphs/layout/graphs_layout/editor_layout.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | from components import network_card 3 | 4 | editor_layout = dca.Grid( 5 | style={'padding': '16px', 6 | 'height': 'calc(100% - 40px)', 7 | 'overflow': 'auto'}, 8 | className='bx--grid--narrow bx--grid--full-width', 9 | children=[ 10 | dca.Row(children=[ 11 | dca.Column(columnSizes=['sm-4'], children=[ 12 | network_card( 13 | graph_id='network_editor', 14 | graph_name='Edit Logical Neural Network Weights', 15 | graph_info='Edit Logical Neural Network Weights', 16 | table=True, 17 | height=500, 18 | ), 19 | ]), 20 | ]), 21 | ]) 22 | -------------------------------------------------------------------------------- /pages/graphs/layout/graphs_layout/inspector_layout.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | from components import network_card 3 | 4 | inspector_layout = dca.Grid( 5 | style={'padding': '16px', 6 | 'height': 'calc(100% - 40px)', 7 | 'overflow': 'auto'}, 8 | className='bx--grid--narrow bx--grid--full-width', 9 | children=[ 10 | dca.Row(children=[ 11 | dca.Column(columnSizes=['sm-4'], children=[ 
12 | network_card( 13 | graph_id='network_inspector', 14 | graph_name='Inspect Logical Neural Network Weights', 15 | graph_info='Inspect Logical Neural Network Weights', 16 | slider=True, 17 | height=500 18 | ), 19 | ]), 20 | ]), 21 | ]) 22 | -------------------------------------------------------------------------------- /pages/graphs/outputs/network_editor.py: -------------------------------------------------------------------------------- 1 | import dash_cytoscape as cyto 2 | # Load extra layouts 3 | cyto.load_extra_layouts() 4 | 5 | 6 | def network_editor(edges, nodes, roots_names, cytolayout='dagre'): 7 | return cyto.Cytoscape( 8 | id='network_editor_graph', 9 | elements=nodes + edges[-1], 10 | layout={'name': cytolayout, 'roots': roots_names}, 11 | style={'width': '100%', 'height': '100%'}, 12 | stylesheet=[ 13 | { 14 | 'selector': 'node', 15 | 'style': { 16 | 'label': 'data(label)' 17 | } 18 | }, 19 | { 20 | 'selector': 'edge', 21 | 'style': { 22 | 'opacity': 'data(weight)' 23 | } 24 | }, 25 | ] 26 | ) 27 | -------------------------------------------------------------------------------- /pages/graphs/outputs/network_inspector.py: -------------------------------------------------------------------------------- 1 | import dash_cytoscape as cyto 2 | # Load extra layouts 3 | cyto.load_extra_layouts() 4 | 5 | 6 | def network_inspector(edges, nodes, roots_names, cytolayout='dagre'): 7 | return cyto.Cytoscape( 8 | id='network_inspector_graph', 9 | elements=nodes + edges[0], 10 | layout={'name': cytolayout, 'roots': roots_names}, 11 | style={'width': '100%', 'height': '100%'}, 12 | stylesheet=[ 13 | { 14 | 'selector': 'node', 15 | 'style': { 16 | 'label': 'data(label)' 17 | } 18 | }, 19 | { 20 | 'selector': 'edge', 21 | 'style': { 22 | 'opacity': 'data(weight)' 23 | } 24 | }, 25 | ] 26 | ) -------------------------------------------------------------------------------- /pages/interact/outputs/fired_rules.py: 
-------------------------------------------------------------------------------- 1 | import dash_cytoscape as cyto 2 | # Load extra layouts 3 | cyto.load_extra_layouts() 4 | 5 | 6 | def fired_rules(edges, nodes, roots_names, cytolayout='dagre', id='fired_rules_agent_1'): 7 | return cyto.Cytoscape( 8 | id=id, 9 | elements=nodes + edges[-1], 10 | layout={'name': cytolayout, 'roots': roots_names}, 11 | style={'width': '100%', 'height': '100%'}, 12 | stylesheet=[ 13 | { 14 | 'selector': 'node', 15 | 'style': { 16 | 'label': 'data(label)' 17 | } 18 | }, 19 | { 20 | 'selector': 'edge', 21 | 'style': { 22 | 'opacity': 'data(weight)' 23 | } 24 | }, 25 | ] 26 | ) 27 | -------------------------------------------------------------------------------- /components/ui_shell.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | import dash_html_components as html 3 | from dash_core_components import Location 4 | 5 | 6 | def ui_shell(name: str, header, sidebar): 7 | return html.Div([ 8 | Location(id='url', refresh=False), 9 | dca.UIShell( 10 | id='ui-shell', 11 | name=name, 12 | headerItems=header, 13 | sidebarItems=sidebar 14 | ), 15 | html.Div( 16 | id='page-content', 17 | style={ 18 | 'height': 'calc(100vh - 48px)', 19 | 'margin': '0', 20 | 'width': '100%', 21 | 'overflow': 'auto', 22 | 'backgroundColor': '#f4f4f4', 23 | 'marginTop': '48px', 24 | }, 25 | ), 26 | html.Div(id='dummy_div') 27 | ]) 28 | -------------------------------------------------------------------------------- /pages/scalars/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .functions.data_processing import * 3 | 4 | home_dir = './' 5 | 6 | network = load_network(os.path.join(home_dir, 'static', 'data', 'network1.json')) 7 | scalars = create_loss_data(load_scalars(os.path.join(home_dir, 'static', 'data', 'scalars_example.json'))) 8 | roots_names, nodes, edges = 
create_network(network_dict=network) 9 | slider_dict = create_slider_dict(network['epochs']) 10 | options_loss_checklist = np.concatenate([[{"label": group + ": " + x, "value": group + ": " + x} 11 | for x in scalars[scalars['group'] == group]['loss_type'].unique()] 12 | for group in scalars['group'].unique()]) 13 | value_loss_checklist = np.concatenate([[group + ": " + x 14 | for x in scalars[scalars['group'] == group]['loss_type'].unique()] 15 | for group in scalars['group'].unique()]) 16 | -------------------------------------------------------------------------------- /components/fired_rule_card.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | import dash_cytoscape as cyto 3 | import dash_html_components as html 4 | 5 | 6 | def fired_rule_card(graph_id: str, graph_name: str, 7 | graph_info: str = '', height: int = 250) -> dca.Card: 8 | children = [ 9 | html.Div(id=f'{graph_id}', 10 | style={'height': f'{height}px', 'width': '100%'}, 11 | children=cyto.Cytoscape( 12 | id=f'{graph_id}_graph', 13 | elements=[], 14 | style={'height': '100%', 'width': '100%'})), 15 | ] 16 | 17 | return dca.Card( 18 | id=f'{graph_id}_card', 19 | title=graph_name, 20 | info=graph_info, 21 | actions=[{'displayName': 'Download CSV', 22 | 'actionPropName': 'download'}, 23 | {'displayName': 'Download Excel', 24 | 'actionPropName': 'download_excel'}], 25 | children=children) 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 International Business Machines 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pages/scalars/callbacks.py: -------------------------------------------------------------------------------- 1 | from dash import Dash 2 | from dash.dependencies import Input, Output, State 3 | 4 | from . import outputs 5 | from .functions import filter_dataframe 6 | from utils_demo.callbacks import select_all_callback 7 | from .data import scalars 8 | 9 | 10 | def register(app: Dash): 11 | # Function to register the select_all callbacks, receive the checkbox id and the dropdown id. 
12 | 13 | @app.callback( 14 | Output('loss_type_selection', 'value'), 15 | [ 16 | Input('loss_type_selection', 'value') 17 | ] 18 | ) 19 | def countries_selection(loss_types): 20 | values = list(loss_types) 21 | return values 22 | 23 | @app.callback( 24 | [ 25 | Output('scalars', 'figure'), 26 | ], 27 | [ 28 | Input('configuration_apply_scalars', 'n_clicks'), 29 | ], 30 | [ 31 | State('loss_type_selection', 'value'), 32 | ]) 33 | def apply_configuration(apply_click, loss_types): 34 | filtered_data = filter_dataframe(scalars, loss_types) 35 | return [ 36 | outputs.scalars(filtered_data), 37 | ] 38 | -------------------------------------------------------------------------------- /components/graph_card.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | import dash_html_components as html 3 | import dash_table 4 | from dash_core_components import Graph 5 | from dash_extensions import Download 6 | 7 | 8 | def graph_card(graph_id: str, graph_name: str, graph_info: str = '', 9 | radios=None, height: int = 250) -> dca.Card: 10 | if radios is None: 11 | radios = [] 12 | children = [ 13 | Graph(id=f'{graph_id}', style={'height': f'{height}px'}), 14 | Download(id=f'{graph_id}_download'), 15 | ] 16 | i = 0 17 | for radio in radios: 18 | i += 1 19 | children.append( 20 | dca.RadioButtonGroup( 21 | id=f'{graph_id}_radio{i}', 22 | radiosButtons=radio.buttons, 23 | value=radio.value 24 | ) 25 | ) 26 | return dca.Card( 27 | id=f'{graph_id}_card', 28 | title=graph_name, 29 | info=graph_info, 30 | actions=[{'displayName': 'Download CSV', 31 | 'actionPropName': 'download'}, 32 | {'displayName': 'Download Excel', 33 | 'actionPropName': 'download_excel'}], 34 | children=children) 35 | -------------------------------------------------------------------------------- /static/data/textworld_logo.txt: -------------------------------------------------------------------------------- 1 | ________ ________ __ __ 
________ 2 | | \| \| \ | \| \ 3 | \$$$$$$$$| $$$$$$$$| $$ | $$ \$$$$$$$$ 4 | | $$ | $$__ \$$\/ $$ | $$ 5 | | $$ | $$ \ >$$ $$ | $$ 6 | | $$ | $$$$$ / $$$$\ | $$ 7 | | $$ | $$_____ | $$ \$$\ | $$ 8 | | $$ | $$ \| $$ | $$ | $$ 9 | \$$ \$$$$$$$$ \$$ \$$ \$$ 10 | __ __ ______ _______ __ _______ 11 | | \ _ | \ / \ | \ | \ | \ 12 | | $$ / \ | $$| $$$$$$\| $$$$$$$\| $$ | $$$$$$$\ 13 | | $$/ $\| $$| $$ | $$| $$__| $$| $$ | $$ | $$ 14 | | $$ $$$\ $$| $$ | $$| $$ $$| $$ | $$ | $$ 15 | | $$ $$\$$\$$| $$ | $$| $$$$$$$\| $$ | $$ | $$ 16 | | $$$$ \$$$$| $$__/ $$| $$ | $$| $$_____ | $$__/ $$ 17 | | $$$ \$$$ \$$ $$| $$ | $$| $$ \| $$ $$ 18 | \$$ \$$ \$$$$$$ \$$ \$$ \$$$$$$$$ \$$$$$$$ 19 | -------------------------------------------------------------------------------- /pages/interact/layout/graphs_layout/inspect_layout.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | import dash_html_components as html 3 | from components import dialog_row, fired_rule_card, next_actions_row 4 | 5 | inspect_layout = dca.Grid( 6 | style={'padding': '16px', 7 | 'height': 'calc(100% - 40px)', 8 | 'overflow': 'auto'}, 9 | className='bx--grid--narrow bx--grid--full-width', 10 | children=[ 11 | dca.Row(children=[ 12 | dca.Column(columnSizes=['sm-4'], children=[ 13 | dca.Row(children=[ 14 | dca.Column(columnSizes=['sm-2'], children=[ 15 | fired_rule_card( 16 | graph_id='working_memory_agent_1', 17 | graph_name='Working Memory Agent 1', 18 | graph_info='Working Memory for Agent 1', 19 | height=250, 20 | ), 21 | ]), 22 | dca.Column(columnSizes=['sm-2'], children=[ 23 | fired_rule_card( 24 | graph_id='fired_rules_agent_1', 25 | graph_name='Fired Rules for Agent 1', 26 | graph_info='Fired Rules for Agent 1', 27 | height=250, 28 | ), 29 | ]), 30 | ], style={'marginLeft': '0px'}) 31 | ]), 32 | ]), 33 | ]) 34 | -------------------------------------------------------------------------------- 
/pages/graphs/functions/graph_generators.py: -------------------------------------------------------------------------------- 1 | import dash_cytoscape as cyto 2 | 3 | 4 | def create_edit_network_graph(edges, nodes, roots_names, cytolayout='dagre'): 5 | return cyto.Cytoscape( 6 | id='nn', 7 | elements=nodes + edges[-1], 8 | layout={'name': cytolayout, 'roots': roots_names}, 9 | style={'width': '400px', 'height': '500px'}, 10 | stylesheet=[ 11 | { 12 | 'selector': 'node', 13 | 'style': { 14 | 'label': 'data(label)' 15 | } 16 | }, 17 | { 18 | 'selector': 'edge', 19 | 'style': { 20 | 'opacity': 'data(weight)' 21 | } 22 | }, 23 | ] 24 | ) 25 | 26 | 27 | def create_network_training_inspector_graph(edges, nodes, roots_names, cytolayout='dagre'): 28 | return cyto.Cytoscape( 29 | id='nn-slider', 30 | elements=nodes + edges[0], 31 | layout={'name': cytolayout, 'roots': roots_names}, 32 | style={'width': '400px', 'height': '500px'}, 33 | stylesheet=[ 34 | { 35 | 'selector': 'node', 36 | 'style': { 37 | 'label': 'data(label)' 38 | } 39 | }, 40 | { 41 | 'selector': 'edge', 42 | 'style': { 43 | 'opacity': 'data(weight)' 44 | } 45 | }, 46 | ] 47 | ) -------------------------------------------------------------------------------- /pages/scalars/layout/configuration_layout.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | import dash_html_components as html 3 | 4 | from .filters_layout import filters_layout 5 | 6 | configuration_layout = [ 7 | html.Div(style={'display': 'flex', 'height': '40px', 'alignItems': 'center'}, children=[ 8 | html.H5(style={'paddingLeft': '16px'}, children='Configuration'), 9 | ]), 10 | dca.Tabs(children=[ 11 | dca.Tab(value='filters', label='Filters', children=[ 12 | html.Div( 13 | style={ 14 | 'width': '100%', 15 | 'height': 'calc(100% - 130px)', 16 | 'overflow': 'auto', 17 | 'padding': '8px', 18 | 'paddingBottom': '96px' 19 | }, 20 | children=filters_layout 21 | ) 22 | ]) 
23 | ]), 24 | html.Div( 25 | style={ 26 | 'borderTop': '1px solid #ddd', 27 | 'display': 'flex', 28 | 'flexDirection': 'row', 29 | 'justifyContent': 'flex-end', 30 | 'width': '100%', 31 | 'padding': '8px', 32 | }, 33 | children=[ 34 | dca.Button(id='configuration_scalars_filters_reset', 35 | size='sm', children='Reset', kind='secondary', 36 | style={'marginRight': '1px'}), 37 | dca.Button(id='configuration_scalars_filters_apply', 38 | size='sm', children='Apply', kind='primary') 39 | ]) 40 | ] 41 | -------------------------------------------------------------------------------- /pages/interact/layout/configuration_layout.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | import dash_html_components as html 3 | 4 | from .filters_layout import filters_layout 5 | 6 | configuration_layout = [ 7 | html.Div(style={'display': 'flex', 8 | 'height': '40px', 9 | 'alignItems': 'center'}, children=[ 10 | html.H5(style={'paddingLeft': '16px'}, children='Configuration'), 11 | ]), 12 | dca.Tabs(children=[ 13 | dca.Tab(value='filters', label='Games', children=[ 14 | html.Div( 15 | style={ 16 | 'width': '100%', 17 | 'height': 'calc(100% - 130px)', 18 | 'overflow': 'auto', 19 | 'padding': '8px', 20 | 'paddingBottom': '96px' 21 | }, 22 | children=filters_layout 23 | ) 24 | ]) 25 | ]), 26 | html.Div( 27 | style={ 28 | 'borderTop': '1px solid #ddd', 29 | 'display': 'flex', 30 | 'flexDirection': 'row', 31 | 'justifyContent': 'flex-end', 32 | 'width': '100%', 33 | 'padding': '8px', 34 | }, 35 | children=[ 36 | dca.Button(id='configuration_interact_games_reset', 37 | size='sm', children='Reset', kind='secondary', 38 | style={'marginRight': '1px'}), 39 | dca.Button(id='configuration_interact_games_apply', 40 | size='sm', children='Apply', kind='primary') 41 | ]) 42 | ] 43 | -------------------------------------------------------------------------------- /pages/graphs/layout/__init__.py: 
-------------------------------------------------------------------------------- 1 | import dash_html_components as html 2 | 3 | from ... import RIGHT_TOP_MESSAGE 4 | from .graphs_layout import graphs_layout 5 | 6 | layout = [ 7 | html.Div( 8 | style={ 9 | 'width': '100%', 10 | 'display': 'flex', 11 | 'flexDirection': 'row', 12 | 'alignItems': 'flex-end', 13 | 'justifyContent': 'space-between', 14 | 'backgroundColor': 'white', 15 | 'padding': '8px 24px', 16 | 'borderBottom': '1px solid #ddd' 17 | }, 18 | children=[ 19 | html.H4(children=['Explore Logical Nueral Network Graphs']), 20 | html.Div(style={'display': 'flex', 'alignItems': 'flex-end'}, 21 | children=[ 22 | html.Span(id='TCV_value'), 23 | html.Span(style={'marginLeft': '8px'}, 24 | id='Entitled_value'), 25 | html.Span( 26 | style={'marginLeft': '8px'}, 27 | children=[RIGHT_TOP_MESSAGE] 28 | ) 29 | ]) 30 | ] 31 | ), 32 | html.Div( 33 | style={ 34 | 'display': 'flex', 35 | 'flexDirection': 'row', 36 | # UIShell header = 48px. Page Title = 64px 37 | 'height': 'calc(100% - 45px)', 38 | 'width': '100%' 39 | }, 40 | children=[ 41 | html.Div(style={ 42 | 'width': '100%', 43 | 'height': '100%' 44 | }, children=graphs_layout), 45 | ] 46 | ) 47 | ] 48 | -------------------------------------------------------------------------------- /pages/interact/functions/graph_generators.py: -------------------------------------------------------------------------------- 1 | import dash_cytoscape as cyto 2 | 3 | 4 | def create_edit_network_graph(edges, nodes, roots_names, cytolayout='dagre'): 5 | return cyto.Cytoscape( 6 | id='nn', 7 | elements=nodes + edges[-1], 8 | layout={'name': cytolayout, 'roots': roots_names}, 9 | style={'width': '400px', 'height': '500px'}, 10 | stylesheet=[ 11 | { 12 | 'selector': 'node', 13 | 'style': { 14 | 'label': 'data(label)' 15 | } 16 | }, 17 | { 18 | 'selector': 'edge', 19 | 'style': { 20 | 'opacity': 'data(weight)' 21 | } 22 | }, 23 | ] 24 | ) 25 | 26 | 27 | def 
create_network_training_inspector_graph(edges, 28 | nodes, 29 | roots_names, 30 | cytolayout='dagre'): 31 | return cyto.Cytoscape( 32 | id='nn-slider', 33 | elements=nodes + edges[0], 34 | layout={'name': cytolayout, 'roots': roots_names}, 35 | style={'width': '400px', 'height': '500px'}, 36 | stylesheet=[ 37 | { 38 | 'selector': 'node', 39 | 'style': { 40 | 'label': 'data(label)' 41 | } 42 | }, 43 | { 44 | 'selector': 'edge', 45 | 'style': { 46 | 'opacity': 'data(weight)' 47 | } 48 | }, 49 | ] 50 | ) 51 | -------------------------------------------------------------------------------- /components/dialog_row.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | import dash_html_components as html 3 | 4 | 5 | def dialog_row(is_agent: bool, text: str): 6 | """ 7 | 8 | :param is_agent: bool, is the agent otherwise the environment 9 | :param text: text to display 10 | """ 11 | if is_agent: 12 | img = 'assets/robot-icon.jpeg' 13 | float_dir = 'right' 14 | margins = {'marginLeft': 'calc(100% - 40px)'} 15 | else: 16 | img = 'assets/globe-icon.png' 17 | float_dir = 'left' 18 | margins = {} 19 | 20 | if isinstance(text, str): 21 | texts = [t for t in text.split('
')] 22 | text = list() 23 | for t in texts: 24 | text.append(t) 25 | text.append(html.Br()) 26 | text = text[:-1] 27 | 28 | return dca.Row(children=[ 29 | dca.Column( 30 | columnSizes=['sm-4'], 31 | children=[ 32 | html.Div( 33 | children=[ 34 | html.Img(src=img, 35 | style={'height': '50px', 36 | 'width': '50px', 37 | 'float': float_dir, 38 | 'marginRight': '20px', 39 | 'marginLeft': '20px'}), 40 | html.P(text, 41 | style={ 42 | 'float': float_dir, 43 | 'maxWidth': '70%' 44 | }), 45 | ], 46 | style={'display': 'inline-block'}.update(margins)) 47 | ], 48 | ), 49 | ]) 50 | -------------------------------------------------------------------------------- /static/data/network1.json: -------------------------------------------------------------------------------- 1 | { 2 | "epochs": 3, 3 | "input": [ 4 | "x1", 5 | "x2", 6 | "x3", 7 | "x4" 8 | ], 9 | "network": { 10 | "not_x2": { 11 | "gate_type": "NOT", 12 | "parents": [ 13 | "x2" 14 | ], 15 | "weights": [ 16 | [ 17 | 1.0 18 | ], 19 | [ 20 | 1.0 21 | ], 22 | [ 23 | 0.9 24 | ] 25 | ] 26 | }, 27 | "and1": { 28 | "gate_type": "AND", 29 | "parents": [ 30 | "x1", 31 | "not_x2", 32 | "x3" 33 | ], 34 | "weights": [ 35 | [ 36 | 0.2, 37 | 0.5, 38 | 0.8 39 | ], 40 | [ 41 | 0.3, 42 | 0.5, 43 | 0.8 44 | ], 45 | [ 46 | 0.2, 47 | 0.5, 48 | 0.8 49 | ] 50 | ] 51 | }, 52 | "or1": { 53 | "gate_type": "OR", 54 | "parents": [ 55 | "and1", 56 | "x4" 57 | ], 58 | "weights": [ 59 | [ 60 | 0.6, 61 | 0.8 62 | ], 63 | [ 64 | 0.6, 65 | 0.8 66 | ], 67 | [ 68 | 0.6, 69 | 0.8 70 | ] 71 | ] 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /pages/home/layout/__init__.py: -------------------------------------------------------------------------------- 1 | import dash_html_components as html 2 | 3 | from ... 
import RIGHT_TOP_MESSAGE 4 | from .landing_layout import landing_layout 5 | 6 | layout = [ 7 | html.Div( 8 | style={ 9 | 'width': '100%', 10 | 'display': 'flex', 11 | 'flexDirection': 'row', 12 | 'alignItems': 'flex-end', 13 | 'justifyContent': 'space-between', 14 | 'backgroundColor': 'white', 15 | 'padding': '8px 24px', 16 | 'borderBottom': '1px solid #ddd' 17 | }, 18 | children=[ 19 | html.H4(children=['Welcome to the NeSA Demo']), 20 | html.Div(style={ 21 | 'display': 'flex', 22 | 'alignItems': 'flex-end' 23 | }, 24 | children=[ 25 | html.Span(id='TCV_value'), 26 | html.Span( 27 | style={'marginLeft': '8px'}, 28 | id='Entitled_value'), 29 | html.Span( 30 | style={'marginLeft': '8px'}, 31 | children=[RIGHT_TOP_MESSAGE] 32 | ) 33 | ]) 34 | ] 35 | ), 36 | html.Div( 37 | style={ 38 | 'display': 'flex', 39 | 'flexDirection': 'row', 40 | # UIShell header = 48px. Page Title = 64px 41 | 'height': 'calc(100% - 45px)', 42 | 'width': '100%' 43 | }, 44 | children=[ 45 | html.Div(style={ 46 | 'width': '100%', 47 | 'height': '100%' 48 | }, 49 | children=html.Span( 50 | style={'marginLeft': '8px'}, 51 | children=[landing_layout] 52 | )), 53 | ] 54 | ) 55 | ] 56 | -------------------------------------------------------------------------------- /utils_demo/callbacks/download_callback.py: -------------------------------------------------------------------------------- 1 | import plotly.graph_objs as go 2 | from dash import Dash 3 | from dash.dependencies import Input, Output, State 4 | from dash_extensions.snippets import send_data_frame 5 | from pandas import DataFrame 6 | 7 | 8 | def download_callback(app: Dash, card_id: str, graph_id: str, download_id: str) -> None: 9 | @app.callback( 10 | Output(download_id, "data"), 11 | Input(card_id, "action_click"), 12 | State(graph_id, "figure"), 13 | State(card_id, "title"), 14 | prevent_initial_call=True 15 | ) 16 | def __download_callback(action_click: str, figure: go.Figure, title: str): 17 | if figure: 18 | d = figure['data'] 19 | # 
Fell free to extend the download callback to support more graphs types 20 | if d[0]['type'] == 'scattergl': 21 | output = __download_scatter(d) 22 | else: 23 | output = __download_bar(d) 24 | if action_click.startswith('download_excel'): 25 | return send_data_frame(output.to_excel, title + '.xlsx', index=False) 26 | return send_data_frame(output.to_csv, title + '.csv', index=False) 27 | return '' 28 | 29 | def __download_scatter(d): 30 | output = DataFrame() 31 | for element in d: 32 | df = DataFrame() 33 | df[element['xaxis']] = element['x'] 34 | df[element['yaxis']] = element['y'] 35 | df['name'] = element['name'] 36 | df['customdata'] = element['customdata'] 37 | output = output.append(df) 38 | return output 39 | 40 | def __download_bar(d): 41 | output = DataFrame() 42 | output['x'] = d[0]['x'] 43 | for element in d: 44 | output[element['name']] = element['y'] 45 | return output 46 | -------------------------------------------------------------------------------- /pages/scalars/layout/__init__.py: -------------------------------------------------------------------------------- 1 | import dash_html_components as html 2 | 3 | from ... 
import RIGHT_TOP_MESSAGE 4 | from .configuration_layout import configuration_layout 5 | from .filters_layout import filters_layout 6 | from .graphs_layout import graphs_layout 7 | 8 | layout = [ 9 | html.Div( 10 | style={ 11 | 'width': '100%', 12 | 'display': 'flex', 13 | 'flexDirection': 'row', 14 | 'alignItems': 'flex-end', 15 | 'justifyContent': 'space-between', 16 | 'backgroundColor': 'white', 17 | 'padding': '8px 24px', 18 | 'borderBottom': '1px solid #ddd' 19 | }, 20 | children=[ 21 | html.H4(children=['Explore Logical Neural Network Scalars']), 22 | html.Div(style={'display': 'flex', 'alignItems': 'flex-end'}, 23 | children=[ 24 | html.Span(id='TCV_value'), 25 | html.Span(style={'marginLeft': '8px'}, 26 | id='Entitled_value'), 27 | html.Span(style={'marginLeft': '8px'}, 28 | children=[RIGHT_TOP_MESSAGE]) 29 | ]) 30 | ] 31 | ), 32 | html.Div( 33 | style={ 34 | 'display': 'flex', 35 | 'flexDirection': 'row', 36 | # UIShell header = 48px. Page Title = 64px 37 | 'height': 'calc(100% - 45px)', 38 | 'width': '100%' 39 | }, 40 | children=[ 41 | html.Div(style={ 42 | 'width': 'calc(100% - 344px)', 43 | 'height': '100%' 44 | }, children=graphs_layout), 45 | html.Div(style={ 46 | 'width': '344px', 47 | 'height': '100%', 48 | 'borderLeft': '1px solid #ddd', 49 | 'backgroundColor': 'white' 50 | }, children=configuration_layout), 51 | ] 52 | ) 53 | ] 54 | -------------------------------------------------------------------------------- /pages/interact/layout/__init__.py: -------------------------------------------------------------------------------- 1 | import dash_html_components as html 2 | 3 | from ... 
import RIGHT_TOP_MESSAGE 4 | from .configuration_layout import configuration_layout 5 | from .filters_layout import filters_layout 6 | from .graphs_layout import graphs_layout 7 | 8 | layout = [ 9 | html.Div( 10 | style={ 11 | 'width': '100%', 12 | 'display': 'flex', 13 | 'flexDirection': 'row', 14 | 'alignItems': 'flex-end', 15 | 'justifyContent': 'space-between', 16 | 'backgroundColor': 'white', 17 | 'padding': '8px 24px', 18 | 'borderBottom': '1px solid #ddd' 19 | }, 20 | children=[ 21 | html.H4(children=['Interact with Agent']), 22 | html.Div(style={'display': 'flex', 'alignItems': 'flex-end'}, 23 | children=[ 24 | html.Span(id='TCV_value'), 25 | html.Span(style={'marginLeft': '8px'}, 26 | id='Entitled_value'), 27 | html.Span(style={'marginLeft': '8px'}, 28 | children=[RIGHT_TOP_MESSAGE]) 29 | ]) 30 | ] 31 | ), 32 | html.Div( 33 | style={ 34 | 'display': 'flex', 35 | 'flexDirection': 'row', 36 | # UIShell header = 48px. Page Title = 64px 37 | 'height': 'calc(100% - 45px)', 38 | 'width': '100%' 39 | }, 40 | children=[ 41 | html.Div(style={ 42 | 'width': 'calc(100% - 344px)', 43 | 'height': '100%' 44 | }, children=graphs_layout), 45 | html.Div(style={ 46 | 'width': '344px', 47 | 'height': '100%', 48 | 'borderLeft': '1px solid #ddd', 49 | 'backgroundColor': 'white' 50 | }, children=configuration_layout), 51 | ] 52 | ) 53 | ] 54 | -------------------------------------------------------------------------------- /decorators/graph_output.py: -------------------------------------------------------------------------------- 1 | import plotly.graph_objs as go 2 | 3 | function_map = {} 4 | 5 | 6 | # Decorator to handle empty filtered_df and apply the graph application layout 7 | def graph_output(graph_id: str): 8 | def decorator(func): 9 | def wrapper(*args, **kw): 10 | filtered_df = args[0] 11 | if filtered_df is None: 12 | filtered_df = kw['filtered_df'] 13 | if filtered_df.empty: 14 | return go.Figure(layout=go.Layout( 15 | xaxis={"visible": False}, 16 | 
yaxis={"visible": False}, 17 | paper_bgcolor='#fff', 18 | plot_bgcolor='#fff', 19 | annotations=[ 20 | { 21 | "text": "No matching data found. Please change the filters.", 22 | "xref": "paper", 23 | "yref": "paper", 24 | "showarrow": False, 25 | "font": { 26 | "size": 16 27 | } 28 | } 29 | ] 30 | )) 31 | fig = func(*args, **kw) 32 | fig.update_xaxes(showgrid=False, showticklabels=True, zeroline=True, color='#B2B2B2', 33 | tickcolor='#B2B2B2') 34 | fig.update_yaxes(showgrid=False, showticklabels=False, zeroline=False) 35 | fig.layout.update( 36 | paper_bgcolor='#fff', 37 | plot_bgcolor='#fff', 38 | legend=go.layout.Legend( 39 | x=.175, 40 | y=2.00, 41 | traceorder="normal", 42 | orientation='h', 43 | font=dict( 44 | family="sans-serif", 45 | size=14, 46 | color="#B2B2B2" 47 | ) 48 | ), 49 | margin=dict(t=20, b=75), 50 | ) 51 | return fig 52 | 53 | function_map[graph_id] = wrapper 54 | return wrapper 55 | 56 | return decorator 57 | -------------------------------------------------------------------------------- /pages/scalars/functions/data_processing.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import pandas as pd 4 | 5 | 6 | def load_network(network_json_path): 7 | with open(network_json_path, 'r') as f: 8 | return json.load(f) 9 | 10 | 11 | def load_scalars(scalars_npz_path): 12 | with open(scalars_npz_path, 'r') as f: 13 | obj = json.load(f) 14 | return obj 15 | 16 | 17 | def create_network(network_dict): 18 | # generate roots string 19 | roots = network_dict['input'] 20 | roots_names = ', '.join(['[id = "{}"]'.format(x) for x in roots]) 21 | 22 | # generate nodes list 23 | node_ids = network_dict['input'] + list(network_dict['network'].keys()) # + ['output'] 24 | node_labels = network_dict['input'] + [x['gate_type'] for _, x in network_dict['network'].items()] # + ['y'] 25 | nodes = [{'data': {'id': x, 'label': y}} for x, y in zip(node_ids, node_labels)] 26 | 27 | # generate edges 
by epoch lists 28 | edges = [] 29 | for epoch in range(network_dict['epochs']): 30 | epoch_edges = [] 31 | for k, v in network_dict['network'].items(): 32 | for parent, weight in zip(v['parents'], v['weights'][epoch]): 33 | epoch_edges += [{'data': {'source': parent, 'target': k, 'weight': weight}}] 34 | edges += [epoch_edges] 35 | 36 | return roots_names, nodes, edges 37 | 38 | 39 | def create_slider_dict(num_epochs): 40 | if num_epochs <= 10: 41 | return {i: str(i) for i in range(num_epochs)} 42 | elif num_epochs <= 50: 43 | return {i * 5: str(i * 5) for i in range(num_epochs // 5)} 44 | elif num_epochs <= 100: 45 | return {i * 10: str(2 ** i) for i in range(num_epochs // 10)} 46 | else: 47 | return {i * 50: str(2 ** i) for i in range(num_epochs)} 48 | 49 | 50 | def create_loss_data(scalars_dict): 51 | group_col = [] 52 | epoch_col = [] 53 | loss_col = [] 54 | loss_type_col = [] 55 | 56 | for group, group_dict in scalars_dict.items(): 57 | 58 | for i in range(len(group_dict['values'])): 59 | epoch_col += group_dict['epoch'] 60 | 61 | for k, v in group_dict['values'].items(): 62 | loss_col += v 63 | loss_type_col += [k] * len(v) 64 | group_col += [group] * len(v) 65 | 66 | return pd.DataFrame({'group': group_col, 'epoch': epoch_col, 'loss': loss_col, 'loss_type': loss_type_col}) 67 | 68 | -------------------------------------------------------------------------------- /components/network_card.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | import dash_core_components as dcc 3 | import dash_html_components as html 4 | import dash_table 5 | import dash_cytoscape as cyto 6 | 7 | 8 | def network_card(graph_id: str, graph_name: str, graph_info: str = '', slider: bool = False, table: bool = False, height: int = 250) -> dca.Card: 9 | children = [ 10 | html.Div(id=f'{graph_id}', style={'height': f'{height}px', 'width': '100%'}, 11 | children=cyto.Cytoscape( 12 | id=f'{graph_id}_graph', 
elements=[], style={'height': '100%', 'width': '100%'})), 13 | ] 14 | 15 | if slider: 16 | slider_component = html.Div( 17 | id=f'{graph_id}_slider_parent', style={'width': '25%', 'padding': '25px', 'paddingBottom': '50px'}, 18 | children=[html.H4("Select Epoch:"), 19 | dcc.Slider(id=f'{graph_id}_slider', min=0, max=1, value=0, marks={}, step=None)] 20 | ) 21 | children.append(slider_component) 22 | 23 | if table: 24 | table_component = html.Div( 25 | id=f'{graph_id}_table_parent', style={'width': '50%', 'padding': '25px'}, 26 | children=[ 27 | html.H4("Edit Weight:"), 28 | dash_table.DataTable( 29 | id=f'{graph_id}_table', 30 | columns=( 31 | [{'id': 'Parent', 'name': 'Parent'}, {'id': 'Child', 'name': 'Child'}, 32 | {'id': 'Weight', 'name': 'Weight'}] 33 | ), 34 | data=[dict(Parent='', Child='', Weight='')], 35 | editable=True, 36 | style_table={'overflowX': 'auto', 'width': '50%'}, 37 | style_cell={ 38 | # all three widths are needed 39 | 'minWidth': '120px', 'width': '120px', 'maxWidth': '120px', 40 | 'overflow': 'hidden', 41 | 'textOverflow': 'ellipsis', 42 | } 43 | ) 44 | ] 45 | ) 46 | children.append(table_component) 47 | 48 | return dca.Card( 49 | id=f'{graph_id}_card', 50 | title=graph_name, 51 | info=graph_info, 52 | actions=[{'displayName': 'Download CSV', 'actionPropName': 'download'}, 53 | {'displayName': 'Download Excel', 'actionPropName': 'download_excel'}], 54 | children=children) 55 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import dash 5 | import dash_html_components as html 6 | 7 | from components import ui_shell 8 | from decorators.graph_output import function_map 9 | from router import register_router 10 | from utils_demo.callbacks import reset_page_callback, download_callback 11 | 12 | from pages.home.layout import layout as home_layout 13 | from pages.home.callbacks import 
register as register_home 14 | 15 | from pages.scalars.layout import layout as scalars_layout 16 | from pages.scalars.callbacks import register as register_scalars 17 | 18 | from pages.graphs.layout import layout as graphs_layout 19 | from pages.graphs.callbacks import register as register_graphs 20 | 21 | from pages.interact.layout import layout as interact_layout 22 | from pages.interact.callbacks import register as register_interact 23 | 24 | # Dash App instantiation 25 | app = dash.Dash(__name__) 26 | server = app.server 27 | app.title = 'NeSA Demo' 28 | 29 | app.layout = ui_shell( 30 | app.title, 31 | header=[ 32 | {'name': 'Home', 'url': '/'}, 33 | {'name': 'Interact', 'url': '/interact'}, 34 | # {'name': 'Scalars', 'url': '/scalars'}, 35 | # {'name': 'Graphs', 'url': '/graphs'}, 36 | ], 37 | sidebar=[] 38 | ) 39 | app.validation_layout = html.Div([ 40 | app.layout, 41 | home_layout, 42 | scalars_layout, 43 | graphs_layout, 44 | interact_layout, 45 | ]) 46 | 47 | 48 | register_router(app) 49 | 50 | # OPENID CONNECT SSO 51 | enable_sso = os.getenv('DASH_ENABLE_OIDC', 'false') 52 | if enable_sso.lower() == 'true': 53 | raise NotImplemented 54 | 55 | # Callback to reset the page when the user press the reset button 56 | reset_page_callback(app) 57 | 58 | # Automatically register the download callbacks of all graphs 59 | # in the application that using the graph_output decorator 60 | for graph_id in function_map.keys(): 61 | download_callback(app, graph_id + '_card', graph_id, 62 | graph_id + '_download') 63 | 64 | # Pages callbacks 65 | register_home(app) 66 | register_scalars(app) 67 | register_graphs(app) 68 | # register_comparison(app) 69 | register_interact(app) 70 | 71 | # Load static files 72 | html.Img(src=app.get_asset_url('globe-icon.png')) 73 | html.Img(src=app.get_asset_url('robot-icon.png')) 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument('--port', type=int, default=8050) 78 | 
parser.add_argument('--release', action='store_true') 79 | args = parser.parse_args() 80 | 81 | app.run_server(debug=not args.release, host='0.0.0.0', port=args.port) 82 | -------------------------------------------------------------------------------- /pages/graphs/functions/data_processing.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import pandas as pd 4 | 5 | 6 | def load_network(network_json_path): 7 | with open(network_json_path, 'r') as f: 8 | return json.load(f) 9 | 10 | 11 | def load_scalars(scalars_npz_path): 12 | obj = np.load(scalars_npz_path) 13 | return { 14 | 'epoch': obj['epoch'].tolist(), 15 | 'losses': {'supervised_loss': obj['supervised_loss'].tolist(), 16 | 'constraint_loss': obj['constraint_loss'].tolist(), 17 | 'total_loss': obj['total_loss'].tolist()} 18 | } 19 | 20 | 21 | def create_network(network_dict): 22 | # generate roots string 23 | roots = network_dict['input'] 24 | roots_names = ', '.join(['[id = "{}"]'.format(x) for x in roots]) 25 | 26 | # generate nodes list 27 | node_ids = network_dict['input'] + list(network_dict['network'].keys()) # + ['output'] 28 | node_labels = network_dict['input'] + [x['gate_type'] for _, x in network_dict['network'].items()] # + ['y'] 29 | nodes = [{'data': {'id': x, 'label': y}} for x, y in zip(node_ids, node_labels)] 30 | 31 | # generate edges by epoch lists 32 | edges = [] 33 | for epoch in range(network_dict['epochs']): 34 | epoch_edges = [] 35 | for k, v in network_dict['network'].items(): 36 | for parent, weight in zip(v['parents'], v['weights'][epoch]): 37 | epoch_edges += [{'data': {'source': parent, 'target': k, 'weight': weight}}] 38 | edges += [epoch_edges] 39 | 40 | return roots_names, nodes, edges 41 | 42 | 43 | def create_slider_dict(num_epochs): 44 | if num_epochs <= 10: 45 | return {i: str(i) for i in range(num_epochs)} 46 | elif num_epochs <= 50: 47 | return {i * 5: str(i * 5) for i in range(num_epochs // 
5)} 48 | elif num_epochs <= 100: 49 | return {i * 10: str(2 ** i) for i in range(num_epochs // 10)} 50 | else: 51 | return {i * 50: str(2 ** i) for i in range(num_epochs)} 52 | 53 | 54 | def create_loss_data(scalars_dict): 55 | epoch_col = [] 56 | loss_col = [] 57 | loss_type_col = [] 58 | 59 | for i in range(len(scalars_dict['losses'])): 60 | epoch_col += scalars_dict['epoch'] 61 | 62 | for k, v in scalars_dict['losses'].items(): 63 | loss_col += v 64 | loss_type_col += [k] * len(v) 65 | 66 | return pd.DataFrame({'epoch': epoch_col, 'loss': loss_col, 'loss_type': loss_type_col}) 67 | 68 | 69 | def create_slider_transform_function(num_epochs): 70 | if num_epochs <= 10: 71 | return lambda x: x 72 | elif num_epochs <= 50: 73 | return lambda x: x 74 | elif num_epochs <= 100: 75 | return lambda x: int(2 ** (x // 10)) 76 | else: 77 | return lambda x: int(2 ** (x // 50)) 78 | 79 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ 3 | /data/ 4 | .idea/workspace.xml 5 | cache/ 6 | results/ 7 | static/games/ 8 | 9 | # Created by .ignore support plugin (hsz.mobi) 10 | ### Python template 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # pytype static type analyzer 145 | .pytype/ 146 | 147 | # Cython debug symbols 148 | cython_debug/ 149 | 150 | -------------------------------------------------------------------------------- /pages/graphs/callbacks.py: -------------------------------------------------------------------------------- 1 | from dash import Dash 2 | from dash.dependencies import Input, Output, State 3 | import pandas as pd 4 | 5 | from . 
import outputs 6 | from .functions.data_processing import create_slider_transform_function 7 | from .data import * 8 | 9 | 10 | def register(app: Dash): 11 | 12 | @app.callback(Output('network_editor', 'children'), 13 | Input('ui-shell', 'name')) 14 | def edit_network_graph(input): 15 | return outputs.network_editor(edges, nodes, roots_names, cytolayout='dagre') 16 | 17 | @app.callback(Output('network_inspector', 'children'), 18 | [Input('ui-shell', 'name')]) 19 | def edit_network_graph(input): 20 | return outputs.network_inspector(edges, nodes, roots_names, cytolayout='dagre') 21 | 22 | @app.callback( 23 | Output('network_editor_graph', 'elements'), 24 | [ 25 | Input('network_editor_table', 'data'), 26 | Input('network_editor_table', 'columns'), 27 | ], 28 | [ 29 | State('network_editor_graph', 'elements') 30 | ]) 31 | def edit_network_graph_weights(rows, columns, elements): 32 | df = pd.DataFrame(rows, columns=[c['name'] for c in columns]) 33 | for i, data in enumerate(elements): 34 | if 'source' in data['data']: 35 | sub_df = df[(df['Parent'] == data['data']['source']) & (df['Child'] == data['data']['target'])] 36 | if not sub_df.empty: 37 | elements[i]['data']['weight'] = float(sub_df['Weight'].iloc[0]) 38 | 39 | return elements 40 | 41 | @app.callback( 42 | Output('network_editor_table', 'data'), 43 | Input('network_editor_graph', 'tapEdgeData')) 44 | def display_input_table(edge): 45 | if edge is None: 46 | return [dict(Parent='', Child='', Weight='')] 47 | return [dict(Parent=edge['source'], Child=edge['target'], Weight=edge['weight'])] 48 | 49 | @app.callback( 50 | [ 51 | Output('network_inspector_graph', 'elements'), 52 | Output('network_inspector_slider', 'marks'), 53 | Output('network_inspector_slider', 'max'), 54 | ], 55 | [ 56 | Input('network_inspector_slider', 'value'), 57 | ], 58 | [ 59 | State('network_inspector_graph', 'elements') 60 | ]) 61 | def update_network_training_inspector_graph(selected_epoch, elements): 62 | slider_transform_function = 
create_slider_transform_function(network['epochs']) 63 | for i, edge in enumerate(edges[slider_transform_function(selected_epoch)]): 64 | for j, data in enumerate(elements): 65 | if 'source' in data['data']: 66 | if (data['data']['source'] == edge['data']['source']) & ( 67 | data['data']['target'] == edge['data']['target']): 68 | elements[j]['data']['weight'] = float(edge['data']['weight']) 69 | return elements, slider_dict, network['epochs'] 70 | -------------------------------------------------------------------------------- /pages/interact/functions/data_processing.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import pandas as pd 4 | 5 | 6 | def load_network(network_json_path): 7 | with open(network_json_path, 'r') as f: 8 | return json.load(f) 9 | 10 | 11 | def load_scalars(scalars_npz_path): 12 | obj = np.load(scalars_npz_path) 13 | return { 14 | 'epoch': obj['epoch'].tolist(), 15 | 'losses': {'supervised_loss': obj['supervised_loss'].tolist(), 16 | 'constraint_loss': obj['constraint_loss'].tolist(), 17 | 'total_loss': obj['total_loss'].tolist()} 18 | } 19 | 20 | 21 | def create_network(network_dict): 22 | # generate roots string 23 | roots = network_dict['input'] 24 | roots_names = ', '.join(['[id = "{}"]'.format(x) for x in roots]) 25 | 26 | # generate nodes list 27 | node_ids = \ 28 | network_dict['input'] + \ 29 | list(network_dict['network'].keys()) # + ['output'] 30 | node_labels = \ 31 | network_dict['input'] + \ 32 | [x['gate_type'] for _, x in network_dict['network'].items()] # + ['y'] 33 | nodes = [{'data': {'id': x, 'label': y}} 34 | for x, y in zip(node_ids, node_labels)] 35 | 36 | # generate edges by epoch lists 37 | edges = [] 38 | for epoch in range(network_dict['epochs']): 39 | epoch_edges = [] 40 | for k, v in network_dict['network'].items(): 41 | for parent, weight \ 42 | in zip(v['parents'], v['weights'][epoch]): 43 | epoch_edges += [{'data': {'source': parent, 
44 | 'target': k, 45 | 'weight': weight}}] 46 | edges += [epoch_edges] 47 | 48 | return roots_names, nodes, edges 49 | 50 | 51 | def create_slider_dict(num_epochs): 52 | if num_epochs <= 10: 53 | return {i: str(i) for i in range(num_epochs)} 54 | elif num_epochs <= 50: 55 | return {i * 5: str(i * 5) for i in range(num_epochs // 5)} 56 | elif num_epochs <= 100: 57 | return {i * 10: str(2 ** i) for i in range(num_epochs // 10)} 58 | else: 59 | return {i * 50: str(2 ** i) for i in range(num_epochs)} 60 | 61 | 62 | def create_loss_data(scalars_dict): 63 | epoch_col = [] 64 | loss_col = [] 65 | loss_type_col = [] 66 | 67 | for i in range(len(scalars_dict['losses'])): 68 | epoch_col += scalars_dict['epoch'] 69 | 70 | for k, v in scalars_dict['losses'].items(): 71 | loss_col += v 72 | loss_type_col += [k] * len(v) 73 | 74 | return pd.DataFrame({'epoch': epoch_col, 75 | 'loss': loss_col, 76 | 'loss_type': loss_type_col}) 77 | 78 | 79 | def create_slider_transform_function(num_epochs): 80 | if num_epochs <= 10: 81 | return lambda x: x 82 | elif num_epochs <= 50: 83 | return lambda x: x 84 | elif num_epochs <= 100: 85 | return lambda x: int(2 ** (x // 10)) 86 | else: 87 | return lambda x: int(2 ** (x // 50)) 88 | -------------------------------------------------------------------------------- /pages/home/layout/landing_layout.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | import dash_html_components as html 3 | 4 | from ... 
import LOA_REPO_URL, REPO_URL 5 | 6 | landing_layout = dca.Grid( 7 | style={ 8 | 'padding': '16px', 9 | 'height': 'calc(100% - 75px)', 10 | 'overflow': 'auto', 11 | 'width': '75%' 12 | }, 13 | className='bx--grid--narrow bx--grid--full-width', 14 | children=[ 15 | dca.Row(children=[ 16 | dca.Column(columnSizes=['sm-4'], children=[ 17 | dca.Card( 18 | id='landing_card', 19 | children=[ 20 | html.H1("NeSA Demo", 21 | style={ 22 | 'padding-top': '10px', 23 | 'padding-bottom': '10px', 24 | }), 25 | html.P( 26 | "Welcome to Neuro-Sybmolic Agent (NeSA) Demo, " 27 | "where you can explore, understand and interact " 28 | "with NeSA which is Logical Optimal Action (LOA).", 29 | className="lead", 30 | style={ 31 | 'padding-top': '10px', 32 | 'padding-bottom': '10px', 33 | } 34 | ), 35 | html.Hr(className="my-2"), 36 | html.P( 37 | "Click the buttons below to find for each code " 38 | "of NeSA Demo and LOA.", 39 | style={ 40 | 'padding-top': '10px', 41 | 'padding-bottom': '10px', 42 | } 43 | ), 44 | dca.Button( 45 | id='learn_more_button', 46 | size='sm', 47 | children='NeSA Demo Repo', 48 | kind='primary', 49 | href=REPO_URL, 50 | style={ 51 | 'padding': '10px', 52 | 'right': '10px', 53 | 'left': '0px', 54 | } 55 | ), 56 | dca.Button( 57 | id='learn_more_button', 58 | size='sm', 59 | children='LOA Repo', 60 | kind='primary', 61 | href=LOA_REPO_URL, 62 | style={ 63 | 'padding': '10px', 64 | 'right': '0px', 65 | 'left': '10px', 66 | } 67 | ), 68 | ] 69 | ), 70 | ]), 71 | ]), 72 | ]) 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neuro-Symbolic Agent Demo 2 | 3 | This repository contains a demo for Neuro-Symbolic Agent (NeSA), which is specifically [Logical Optimal Action (LOA)](https://github.com/IBM/LOA). 
4 | 5 | ## Setup 6 | 7 | - Anaconda 4.10.3 8 | - Tested on macOS and Linux 9 | 10 | ```bash 11 | git clone --recursive git@github.com:IBM/nesa-demo.git 12 | conda create -n nesa-demo python=3.8 13 | conda activate nesa-demo 14 | conda install pytorch=1.10.0 torchvision torchaudio -c pytorch 15 | conda install gensim==3.8.3 networkx unidecode nltk=3.6.3 16 | pip install -U spacy 17 | python -m spacy download en_core_web_sm 18 | python -m nltk.downloader 'punkt' 19 | pip install -r requirements.txt 20 | cp -r third_party/commonsense_rl/games static/ 21 | 22 | 23 | # Download models 24 | wget -O results.zip https://ibm.box.com/shared/static/chr1vvgb70mmt2gr1yijlsw3g7fq2pgs.zip 25 | unzip results.zip 26 | rm -f results.zip 27 | 28 | 29 | # Download AMR cache file 30 | mkdir -p cache 31 | wget -O cache/amr_cache.pkl https://ibm.box.com/shared/static/klsvx54skc5wlf35qg3klo35ex25dbb0.pkl 32 | ``` 33 | 34 | ## Execute 35 | 36 | ```bash 37 | export AMR_SERVER_IP=localhost 38 | export AMR_SERVER_PORT= 39 | python app.py --release 40 | ``` 41 | 42 | ### If you want to train the models yourself 43 | 44 | - commonsense-rl (DL-only method) 45 | 46 | ```bash 47 | cd third_party/commonsense_rl/ 48 | python -u train_agent.py --agent_type knowledgeaware --game_dir ./games/twc --game_name *.ulx --difficulty_level easy --graph_type world --graph_mode evolve --graph_emb_type glove --world_evolve_type manual --initial_seed 0 --nruns 1 49 | ``` 50 | 51 | - LOA (NeSA method) 52 | 53 | ```bash 54 | cd third_party/loa/ 55 | # follow the setup steps in README.md 56 | python train.py 57 | cp results/loa-twc-dleasy-np2-nt15-ps1-ks6-spboth.pkl ../../results/ 58 | ``` 59 | 60 | ## Citations 61 | 62 | This repository provides code for the following paper. If you find the paper and code useful for your work, please cite the paper and give the repository a star.
63 | 64 | - Daiki Kimura, Subhajit Chaudhury, Masaki Ono, Michiaki Tatsubori, Don Joven Agravante, Asim Munawar, Akifumi Wachi, Ryosuke Kohita, and Alexander Gray, "[LOA: Logical Optimal Actions for Text-based Interaction Games](https://aclanthology.org/2021.acl-demo.27/)", ACL-IJCNLP 2021. 65 | 66 |
Details and bibtex
67 | 68 | The paper presents an initial demonstration of Logical Optimal Actions (LOA) on TextWorld (TW) Coin Collector, TW Cooking, TW Commonsense, and Jericho. In this version, the human player can select an action by hand or from the list of actions recommended by LOA, while the acquired knowledge is visualized to improve the interpretability of the trained rules. 69 | 70 | ``` 71 | @inproceedings{kimura-etal-2021-loa, 72 | title = "{LOA}: Logical Optimal Actions for Text-based Interaction Games", 73 | author = "Kimura, Daiki and Chaudhury, Subhajit and Ono, Masaki and Tatsubori, Michiaki and Agravante, Don Joven and Munawar, Asim and Wachi, Akifumi and Kohita, Ryosuke and Gray, Alexander", 74 | booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing: System Demonstrations", 75 | month = aug, 76 | year = "2021", 77 | address = "Online", 78 | publisher = "Association for Computational Linguistics", 79 | url = "https://aclanthology.org/2021.acl-demo.27", 80 | doi = "10.18653/v1/2021.acl-demo.27", 81 | pages = "227--231" 82 | } 83 | ``` 84 |
85 | 86 | 87 | ## License 88 | 89 | MIT License -------------------------------------------------------------------------------- /pages/interact/layout/graphs_layout/interact_layout.py: -------------------------------------------------------------------------------- 1 | import dash_carbon_components as dca 2 | import dash_html_components as html 3 | from components import dialog_row, fired_rule_card, next_actions_row 4 | 5 | interact_layout = dca.Grid( 6 | style={'padding': '16px', 7 | 'height': 'calc(100% - 40px)', 8 | 'overflow': 'auto'}, 9 | className='bx--grid--narrow bx--grid--full-width', 10 | children=[ 11 | dca.Row( 12 | children=[ 13 | dca.Column(columnSizes=['sm-4'], children=[ 14 | dca.Card( 15 | id='interactive_agent_chat', 16 | children=[ 17 | ], 18 | style={ 19 | 'overflow': 'auto', 20 | 'display': 'flex', 21 | 'flex-direction': 'column-reverse', 22 | 'height': '400px', 23 | 'width': '100%' 24 | } 25 | ) 26 | ]) 27 | ], 28 | ), 29 | dca.Row( 30 | children=[ 31 | dca.Column(columnSizes=['sm-3'], children=[ 32 | dca.Card( 33 | id='select_next_action_card', 34 | children=[ 35 | html.H3('', 36 | id='message_for_top_of_action_selector', 37 | style={ 38 | 'padding-bottom': '0px' 39 | }), 40 | 41 | html.Div( 42 | id='next_action_row', 43 | children=[ 44 | next_actions_row( 45 | [''], 46 | agent_1_actions=[['', 0]], 47 | agent_1_rules=None, 48 | agent_1_facts=None, 49 | agent_2_actions=[['', 0]], 50 | done=False 51 | ), 52 | ] 53 | ), 54 | ], 55 | style={ 56 | 'margin-top': '20px', 57 | 'margin-bottom': '20px', 58 | 'height': 'calc(100% - 50px)' 59 | } 60 | ), 61 | ]), 62 | dca.Column(columnSizes=['sm-1'], children=[ 63 | dca.Card( 64 | id='apply_next_action_card', 65 | children=[ 66 | html.Div( 67 | children=[ 68 | dca.Button( 69 | id='reset_environment', 70 | size='sm', 71 | children='Reset Environment', 72 | kind='primary', 73 | style={'margin-top': '10%'} 74 | ) 75 | ], 76 | style={'text-align': 'center'} 77 | ), 78 | html.P( 79 | id='game_score', 
children='', 80 | style={ 81 | 'text-align': 'center', 82 | 'margin-top': '5%', 83 | 'margin-bottom': '5%' 84 | } 85 | ), 86 | ], 87 | style={ 88 | 'margin-top': '20px', 89 | 'margin-bottom': '0px', 90 | 'height': 'calc(100% - 50px)' 91 | } 92 | ), 93 | ]), 94 | ], 95 | ), 96 | ]) 97 | -------------------------------------------------------------------------------- /pages/interact/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from .. import (DL_AGENT, EASY_LEVEL, HARD_LEVEL, LEVELS, LOA_AGENT, 5 | MEDIUM_LEVEL) 6 | from .functions.data_processing import (create_loss_data, create_network, 7 | create_slider_dict, load_network, 8 | load_scalars) 9 | 10 | if True: 11 | os.environ['DDLNN_HOME'] = 'third_party/loa/third_party/dd_lnn/' 12 | os.environ['TWC_HOME'] = 'static/games/twc/' 13 | 14 | sys_path_backup = sys.path 15 | sys.path.append('third_party/commonsense_rl/') 16 | from .twc_agent import get_twc_agent, kg_graphs 17 | sys.path = sys.path[:-1] 18 | 19 | sys.path.append('third_party/loa/') 20 | from third_party.loa.amr_parser import AMRSemParser 21 | from third_party.loa.loa_agent import LOAAgent, LogicalTWCQuantifier 22 | 23 | 24 | network = \ 25 | load_network(os.path.join('static', 'data', 'network1.json')) 26 | scalars = \ 27 | create_loss_data(load_scalars(os.path.join( 28 | 'static', 'data', 'scalars_lnn_wts_N4_pos1_neg2.npz'))) 29 | 30 | roots_names, nodes, edges = create_network(network_dict=network) 31 | slider_dict = create_slider_dict(network['epochs']) 32 | options_loss_checklist = [{"label": x, "value": x} 33 | for x in scalars['loss_type'].unique()] 34 | value_loss_checklist = [x for x in scalars['loss_type'].unique()] 35 | 36 | 37 | options_game_level = [{"label": x, "value": x} 38 | for x in LEVELS] 39 | value_game_level = [EASY_LEVEL] 40 | options_game = [{"label": x, "value": x} 41 | for x in ['One', 'Two', 'Three', 'Four', 'Five']] 42 | value_game = ['One'] 43 | 44 
| f = open(os.path.join('static', 'data', 'textworld_logo.txt'), 45 | 'r', encoding='UTF-8') 46 | textworld_logo = f.read() 47 | f.close() 48 | 49 | game_no = 0 50 | easy_env = LogicalTWCQuantifier('easy', 51 | split='test', 52 | max_episode_steps=50, 53 | batch_size=None, 54 | game_number=game_no) 55 | medium_env = LogicalTWCQuantifier('medium', 56 | split='test', 57 | max_episode_steps=50, 58 | batch_size=None, 59 | game_number=game_no) 60 | hard_env = LogicalTWCQuantifier('hard', 61 | split='test', 62 | max_episode_steps=50, 63 | batch_size=None, 64 | game_number=game_no) 65 | 66 | env_dict = \ 67 | {EASY_LEVEL: easy_env, MEDIUM_LEVEL: medium_env, HARD_LEVEL: hard_env} 68 | 69 | difficulty_level = 'easy' 70 | loa_pkl_filepath = 'results/loa-twc-dleasy-np2-nt15-ps1-ks6-spboth.pkl' 71 | sem_parser_mode = 'both' 72 | 73 | amr_server_ip = os.environ.get('AMR_SERVER_IP', 'localhost') 74 | amr_server_port_str = os.environ.get('AMR_SERVER_PORT', '') 75 | try: 76 | amr_server_port = int(amr_server_port_str) 77 | except ValueError: 78 | amr_server_port = None 79 | 80 | if LOA_AGENT: 81 | loa_agent = LOAAgent(admissible_verbs=None, 82 | amr_server_ip=amr_server_ip, 83 | amr_server_port=amr_server_port, 84 | prune_by_state_change=True, 85 | sem_parser_mode=sem_parser_mode) 86 | 87 | loa_agent.load_pickel(loa_pkl_filepath) 88 | rest_amr = AMRSemParser(amr_server_ip=amr_server_ip, 89 | amr_server_port=amr_server_port, 90 | cache_folder='cache/') 91 | adm_verbs = loa_agent.admissible_verbs 92 | loa_agent.pi.eval() 93 | loa_rules = loa_agent.extract_rules() 94 | else: 95 | loa_agent = None 96 | rest_amr = None 97 | loa_rules = None 98 | 99 | if DL_AGENT: 100 | twc_agent, twc_agent_goal_graphs, \ 101 | twc_agent_manual_world_graphs = get_twc_agent() 102 | else: 103 | twc_agent = None 104 | twc_agent_goal_graphs = None 105 | twc_agent_manual_world_graphs = None 106 | 107 | scored_action_history = [[] for _ in range(1)] 108 | 
-------------------------------------------------------------------------------- /pages/interact/functions/twc.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | import sys 5 | 6 | import gym 7 | import textworld 8 | import textworld.gym 9 | import torch 10 | from textworld import EnvInfos 11 | 12 | if True: 13 | sys.path.append(os.path.join(os.path.dirname(__file__) + '/../../../')) 14 | 15 | from third_party.loa.amr_parser import get_formatted_obs_text 16 | from third_party.loa.logical_twc import Action2Literal 17 | from third_party.loa.utils import obtain_predicates_logic_vector 18 | 19 | 20 | home_dir = './' 21 | 22 | request_infos = \ 23 | EnvInfos(verbs=True, moves=True, inventory=True, description=True, 24 | objective=True, intermediate_reward=True, 25 | policy_commands=True, max_score=True, admissible_commands=True, 26 | last_action=True, game=True, facts=True, entities=True, 27 | won=True, lost=True, location=True) 28 | 29 | 30 | def load_twc_game(level_str, type_str, index): 31 | game_file_names = sorted(glob.glob( 32 | os.path.join(home_dir, 'static', 'games', 'twc', '%s', '%s', '*.ulx') % 33 | (level_str, type_str))) 34 | game_file_name = [game_file_names[index]] 35 | env_id = \ 36 | textworld.gym.register_games(game_file_name, request_infos, 37 | max_episode_steps=50, 38 | name='twc-%s-%s-%d' % 39 | (level_str, type_str, index), 40 | batch_size=None) 41 | env = gym.make(env_id) 42 | 43 | return env 44 | 45 | 46 | def info2infos(info): 47 | return {k: [v] for k, v in info.items()} 48 | 49 | 50 | def get_agent_actions(env, obs, score, done, info, 51 | action, scored_action_history, 52 | loa_agent=None, dl_agent=None): 53 | from ..data import rest_amr 54 | 55 | action2literal = Action2Literal() 56 | all_actions = info['admissible_commands'] 57 | if loa_agent is None: 58 | ns_agent_actions_list = \ 59 | [a for a in all_actions if not a.startswith('examine')] 60 | 
ns_agent_actions = \ 61 | [[a, random.random()] for a in ns_agent_actions_list] 62 | else: 63 | ns_agent_actions = list() 64 | 65 | facts = env.get_logical_state(info) 66 | obs_text = get_formatted_obs_text(info) 67 | 68 | try: 69 | verbnet_facts, _ = rest_amr.obs2facts(obs_text, 70 | mode='both', 71 | verbose=True) 72 | 73 | rest_amr.save_cache() 74 | 75 | verbnet_facts['atlocation'] = facts['atlocation'] 76 | verbnet_facts['is_instance'] = facts['is_instance'] 77 | 78 | logical_facts = \ 79 | { 80 | 'at_location': [list(x) for x in 81 | verbnet_facts['atlocation']], 82 | 'carry': verbnet_facts['carry'] 83 | if 'carry' in verbnet_facts else [] 84 | } 85 | 86 | for adm_comm in all_actions: 87 | rule, x, y = action2literal(adm_comm) 88 | if rule in loa_agent.admissible_verbs: 89 | rule_arity = loa_agent.admissible_verbs[rule] 90 | 91 | logic_vector, all_preds = \ 92 | obtain_predicates_logic_vector( 93 | rule_arity, x, y, 94 | facts=verbnet_facts, 95 | template=loa_agent.arity_predicate_templates) 96 | logic_vector = logic_vector.unsqueeze(0) 97 | yhat = loa_agent.pi.forward_eval(logic_vector, 98 | lnn_model_name=rule) 99 | # print("{} : {:.2f}".format(adm_comm, yhat.item())) 100 | ns_agent_actions.append([adm_comm, float(yhat.item())]) 101 | else: 102 | ns_agent_actions.append([adm_comm, 0]) 103 | 104 | except Exception: 105 | logical_facts = None 106 | for adm_comm in all_actions: 107 | ns_agent_actions.append([adm_comm, 0]) 108 | 109 | if dl_agent is None or dl_agent[0] is None: 110 | dl_agent_actions_list = \ 111 | [a for a in all_actions if a.startswith('examine')] 112 | dl_agent_actions = \ 113 | [[a, random.random()] for a in dl_agent_actions_list] 114 | else: 115 | dl_agent_actions = list() 116 | 117 | twc_agent, twc_agent_goal_graphs, twc_agent_manual_world_graphs = \ 118 | dl_agent 119 | 120 | infos = info2infos(info) 121 | 122 | game_goal_graphs = [None] * 1 123 | game_manual_world_graph = [None] * 1 124 | 125 | for b, game in enumerate(infos["game"]): 
126 | if "uuid" in game.metadata: 127 | game_id = game.metadata["uuid"].split("-")[-1] 128 | game_goal_graphs[b] = twc_agent_goal_graphs[game_id] 129 | game_manual_world_graph[b] = \ 130 | twc_agent_manual_world_graphs[game_id] 131 | 132 | infos['goal_graph'] = game_goal_graphs 133 | infos['manual_world_graph'] = game_manual_world_graph 134 | 135 | if twc_agent.graph_emb_type and \ 136 | ('local' in twc_agent.graph_type or 137 | 'world' in twc_agent.graph_type): 138 | twc_agent.update_current_graph([obs], [action], 139 | scored_action_history, 140 | infos, 141 | 'evolve') 142 | 143 | action, commands_values = \ 144 | twc_agent.act([obs], [score], [done], infos, 145 | scored_action_history, random_action=False) 146 | 147 | values = torch.nn.functional.softmax(commands_values)[0].tolist() 148 | 149 | for i, a in enumerate(all_actions): 150 | dl_agent_actions.append([a, values[i]]) 151 | 152 | return all_actions, ns_agent_actions, dl_agent_actions, logical_facts 153 | -------------------------------------------------------------------------------- /pages/interact/callbacks.py: -------------------------------------------------------------------------------- 1 | import dash 2 | from components import dialog_row, next_actions_row 3 | from dash import Dash 4 | from dash.dependencies import Input, Output, State 5 | from dash.exceptions import PreventUpdate 6 | from dash_html_components.P import P 7 | 8 | from .. import MESSAGE_FOR_DONE, MESSAGE_FOR_SELECT_NEXT_ACTION 9 | from . import outputs 10 | from .data import (edges, env_dict, kg_graphs, loa_agent, loa_rules, nodes, 11 | roots_names, scored_action_history, twc_agent, 12 | twc_agent_goal_graphs, twc_agent_manual_world_graphs) 13 | from .functions import get_agent_actions 14 | 15 | 16 | def register(app: Dash): 17 | # Function to register the select_all callbacks, 18 | # receive the checkbox id and the dropdown id. 
19 | 20 | @app.callback( 21 | Output('game_level_selection', 'value'), 22 | [ 23 | Input('game_level_selection', 'value') 24 | ] 25 | ) 26 | def game_level_selection(game_level_value): 27 | values = game_level_value 28 | return values 29 | 30 | @app.callback(Output('fired_rules_agent_1', 'children'), 31 | Input('ui-shell', 'name')) 32 | def edit_network_graph(input): 33 | return outputs.fired_rules(edges, nodes, roots_names, 34 | cytolayout='dagre', 35 | id='fired_rules_agent_1_graph') 36 | 37 | @app.callback(Output('working_memory_agent_1', 'children'), 38 | Input('ui-shell', 'name')) 39 | def edit_network_graph(input): 40 | return outputs.fired_rules(edges, nodes, roots_names, 41 | cytolayout='dagre', 42 | id='working_memory_agent_1_graph') 43 | 44 | @app.callback( 45 | [Output('interactive_agent_chat', 'children'), 46 | Output('next_action_row', 'children'), 47 | Output('game_score', 'children'), 48 | Output('message_for_top_of_action_selector', 'children')], 49 | [Input('agent_1_actions_pie_chart', 'clickData'), 50 | Input('agent_2_actions_pie_chart', 'clickData'), 51 | Input('submit_action', 'n_clicks'), 52 | Input('reset_environment', 'n_clicks'), 53 | Input('configuration_interact_games_apply', 'n_clicks'), 54 | Input('configuration_interact_games_reset', 'n_clicks')], 55 | [State('interactive_agent_chat', 'children'), 56 | State('game_level_selection', 'value'), 57 | State('all_actions_dropdown', 'value')] 58 | ) 59 | def add_dialog_row(agent_1_click_data, 60 | agent_2_click_data, 61 | submit_action_n_action, 62 | reset_environment_n_clicks, 63 | configuration_interact_games_apply_n_clicks, 64 | configuration_interact_games_reset_n_clicks, 65 | interactive_agent_chat_children, 66 | game_level, 67 | all_actions_dropdown): 68 | 69 | env = env_dict[game_level] 70 | if twc_agent is not None: 71 | twc_agent.kg_graph = kg_graphs[game_level] 72 | ctx = dash.callback_context 73 | 74 | if not ctx.triggered: 75 | button_id = 'No clicks yet' 76 | else: 77 | button_id 
= ctx.triggered[0]['prop_id'].split('.')[0] 78 | 79 | print(button_id) 80 | 81 | if (button_id != 'agent_1_actions_pie_chart' and 82 | button_id != 'agent_2_actions_pie_chart' and 83 | button_id != 'submit_action' and 84 | reset_environment_n_clicks is None) \ 85 | or button_id == 'reset_environment' \ 86 | or button_id == 'configuration_interact_games_reset' \ 87 | or button_id == 'configuration_interact_games_apply': 88 | obs, info = env.reset() 89 | if twc_agent is not None: 90 | twc_agent.start_episode(1) 91 | 92 | new_interactive_chat_children = \ 93 | [dialog_row(is_agent=False, 94 | text='Objective: ' + info['objective']), 95 | dialog_row(is_agent=False, 96 | text=info['description'].replace('\n\n', '\n') 97 | .replace('\n', '
')), 98 | dialog_row(is_agent=False, text=info['inventory']) 99 | ] 100 | all_actions, ns_agent_actions, dl_agent_actions, loa_facts = \ 101 | get_agent_actions(env, obs, [0], [False], info, 102 | '', scored_action_history, 103 | loa_agent=loa_agent, 104 | dl_agent=[ 105 | twc_agent, 106 | twc_agent_goal_graphs, 107 | twc_agent_manual_world_graphs 108 | ]) 109 | new_action_row = next_actions_row( 110 | all_actions=info['admissible_commands'], 111 | agent_1_actions=ns_agent_actions, 112 | agent_1_rules=loa_rules, 113 | agent_1_facts=loa_facts, 114 | agent_2_actions=dl_agent_actions, 115 | done=False 116 | ) 117 | 118 | score_str = 'Score: %d/%d' % (0, info['max_score']) 119 | 120 | return \ 121 | list(reversed(new_interactive_chat_children)), \ 122 | new_action_row, \ 123 | score_str, \ 124 | MESSAGE_FOR_SELECT_NEXT_ACTION 125 | 126 | elif (button_id.find('pie') != -1 or 127 | (button_id == 'submit_action' 128 | and all_actions_dropdown != 'NONE')): 129 | 130 | print(all_actions_dropdown) 131 | 132 | # recreate old interactions 133 | new_interactive_chat_children = [] 134 | for row in interactive_agent_chat_children[::-1]: 135 | img = \ 136 | row['props']['children'][0]['props']['children'][0][ 137 | 'props']['children'][0]['props']['src'] 138 | is_agent = True if img.find('robot') > -1 else False 139 | text = \ 140 | row['props']['children'][0]['props']['children'][0][ 141 | 'props']['children'][1]['props']['children'] 142 | new_interactive_chat_children.extend( 143 | [dialog_row(is_agent=is_agent, text=text)]) 144 | 145 | if (button_id == 'agent_1_actions_pie_chart'): 146 | action = agent_1_click_data['points'][0]['label'] 147 | elif (button_id == 'agent_2_actions_pie_chart'): 148 | action = agent_2_click_data['points'][0]['label'] 149 | elif (button_id == 'submit_action'): 150 | action = all_actions_dropdown 151 | 152 | obs, score, done, info = env.step(action) 153 | environment_response = obs 154 | 155 | score_str = 'Score: %d/%d' % (score, info['max_score']) 
156 | 157 | # add new interactions 158 | new_interactive_chat_children.extend( 159 | [dialog_row(is_agent=True, text=action), 160 | dialog_row(is_agent=False, 161 | text=environment_response.replace('\n\n', '\n') 162 | .replace('\n\n', '\n').replace('\n', '
'))]) 163 | 164 | all_actions, ns_agent_actions, dl_agent_actions, loa_facts = \ 165 | get_agent_actions(env, obs, [0], [False], info, 166 | action, scored_action_history, 167 | loa_agent=loa_agent, 168 | dl_agent=[ 169 | twc_agent, 170 | twc_agent_goal_graphs, 171 | twc_agent_manual_world_graphs 172 | ]) 173 | new_action_row = next_actions_row( 174 | all_actions=info['admissible_commands'], 175 | agent_1_actions=ns_agent_actions, 176 | agent_1_rules=loa_rules, 177 | agent_1_facts=loa_facts, 178 | agent_2_actions=dl_agent_actions, 179 | done=done 180 | ) 181 | 182 | return \ 183 | list(reversed(new_interactive_chat_children)), \ 184 | new_action_row, \ 185 | score_str, \ 186 | MESSAGE_FOR_DONE if done else MESSAGE_FOR_SELECT_NEXT_ACTION 187 | else: 188 | raise PreventUpdate 189 | -------------------------------------------------------------------------------- /components/next_actions_row.py: -------------------------------------------------------------------------------- 1 | from colorsys import hls_to_rgb 2 | 3 | from numpy import tile, zeros 4 | 5 | import dash_carbon_components as dca 6 | import dash_core_components as dcc 7 | import dash_html_components as html 8 | import plotly.graph_objects as go 9 | 10 | NEXT_ACTION_MIN_CONFIDENCE = 0.05 11 | 12 | 13 | def next_actions_row(all_actions: list, 14 | agent_1_actions: dict, 15 | agent_1_rules: dict, 16 | agent_1_facts: dict, 17 | agent_2_actions: dict, 18 | done: bool): 19 | 20 | hex_array = [] 21 | step = 360 // len(all_actions) 22 | for angle in range(0, 360, step): 23 | rgb = hls_to_rgb(angle / 360, .7, .7) 24 | hex = \ 25 | '#%02x%02x%02x' % \ 26 | (round(rgb[0] * 255), round(rgb[1] * 255), round(rgb[2] * 255)) 27 | hex_array.append(hex) 28 | colors = {'colors': hex_array} 29 | 30 | all_actions_list = [] 31 | for idx, action in enumerate(sorted(all_actions)): 32 | all_actions_list += [html.Li( 33 | children=[ 34 | html.Span( 35 | style={ 36 | 'width': '12px', 37 | 'height': '12px', 38 | 'background': 
colors['colors'][idx], 39 | 'display':'inline-block', 40 | 'margin-right':'6px' 41 | } 42 | ), 43 | html.Span(action), 44 | html.Br() 45 | ] 46 | )] 47 | 48 | first_selection = 'Choose from all possible action list' 49 | dropdown_options = [ 50 | { 51 | 'label': first_selection, 52 | 'value': 'NONE' 53 | } 54 | ] 55 | 56 | for idx, action in enumerate(sorted(all_actions)): 57 | dropdown_options.append({'label': action, 'value': action}) 58 | 59 | actions_dropdown = dca.Dropdown( 60 | id='all_actions_dropdown', 61 | label=first_selection, 62 | options=dropdown_options, 63 | value=dropdown_options[0]['value'], 64 | style={ 65 | 'height': '40px', 66 | } 67 | ) 68 | 69 | action_submit = dca.Button( 70 | 'Perform action', 71 | id='submit_action', 72 | kind='primary', 73 | style={ 74 | 'height': '40px', 75 | }, 76 | size='small' 77 | ) 78 | 79 | all_actions_card = dca.Card( 80 | id='all_actions_card', 81 | title='All Possible Actions', 82 | children=[ 83 | html.Ul( 84 | children=all_actions_list, 85 | className='lead', 86 | style={ 87 | 'paddingTop': '10px', 88 | 'paddingBottom': '10px', 89 | 'listStyleType': 'none' 90 | } 91 | ), 92 | ], 93 | style={ 94 | 'width': '100%', 95 | 'height': 'calc(100% - 50px)', 96 | } 97 | ) 98 | 99 | # Agent 1 Actions Card 100 | agent_1_labels_list = [] 101 | agent_1_values_list = [] 102 | agent_1_text_pos_list = [] 103 | 104 | for action, confidence in agent_1_actions: 105 | agent_1_labels_list.append(action) 106 | agent_1_values_list.append(confidence) 107 | 108 | if confidence > NEXT_ACTION_MIN_CONFIDENCE: 109 | agent_1_text_pos_list.append('outside') 110 | else: 111 | agent_1_text_pos_list.append('none') 112 | 113 | agent_1_max_idx = agent_1_values_list.index(max(agent_1_values_list)) 114 | agent_1_pull = zeros(len(agent_1_values_list)) 115 | agent_1_pull[agent_1_max_idx] = 0.3 116 | 117 | agent_layout = \ 118 | go.Layout(title='Click to perform action', hovermode='closest', 119 | height=330) 120 | 121 | agent_1_rules_children = list() 
122 | if agent_1_rules is not None: 123 | for k, v in agent_1_rules.items(): 124 | if v != '': 125 | agent_1_rules_children.append( 126 | k + ' = ' + v.replace('atlocation', 'at_location')) 127 | agent_1_rules_children.append(html.Br()) 128 | agent_1_rules_children = agent_1_rules_children[:-1] 129 | 130 | agent_1_logical_facts_children = list() 131 | if agent_1_facts is not None: 132 | for k, v in agent_1_facts.items(): 133 | if v != []: 134 | for i in v: 135 | if isinstance(i, list): 136 | agent_1_logical_facts_children.append( 137 | k + '(' + ', '.join(i) + ')') 138 | else: 139 | agent_1_logical_facts_children.append(k + '(%s)' % i) 140 | agent_1_logical_facts_children.append(', ') 141 | agent_1_logical_facts_children = \ 142 | agent_1_logical_facts_children[:-1] 143 | agent_1_logical_facts_children.append(html.Br()) 144 | agent_1_logical_facts_children = agent_1_logical_facts_children[:-1] 145 | else: 146 | agent_1_logical_facts_children.append('AMR does not work') 147 | 148 | agent_1_actions_card = dca.Card( 149 | id='agent_1_actions_card', 150 | title='Recommended actions from NeSA (LOA)', 151 | children=[ 152 | dcc.Graph( 153 | id='agent_1_actions_pie_chart', 154 | figure=go.Figure( 155 | data=[go.Pie( 156 | labels=agent_1_labels_list, 157 | values=agent_1_values_list, 158 | sort=False, 159 | marker=dict(colors, 160 | line=dict(color='#efefef', width=1)), 161 | pull=agent_1_pull, 162 | textposition=agent_1_text_pos_list, 163 | showlegend=False, 164 | )], 165 | layout=agent_layout 166 | )), 167 | html.P('Current Logical Facts: ', 168 | style={'text-align': 'left', 'font-weight': 'bold'}), 169 | html.P( 170 | id='logical_facts_p', 171 | children=agent_1_logical_facts_children, 172 | style={'text-align': 'center'}), 173 | html.P('Trained Rules: ', 174 | style={'text-align': 'left', 'font-weight': 'bold'}), 175 | html.P( 176 | id='loa_rule_p', 177 | children=agent_1_rules_children, 178 | style={'text-align': 'center'}), 179 | 180 | ], 181 | style={'width': 
'100%'} 182 | ) 183 | 184 | # Agent 2 Actions Card 185 | agent_2_labels_list = [] 186 | agent_2_values_list = [] 187 | agent_2_text_pos_list = [] 188 | 189 | for action, confidence in agent_2_actions: 190 | agent_2_labels_list.append(action) 191 | agent_2_values_list.append(confidence) 192 | 193 | if confidence > NEXT_ACTION_MIN_CONFIDENCE: 194 | agent_2_text_pos_list.append('outside') 195 | else: 196 | agent_2_text_pos_list.append('none') 197 | 198 | agent_2_max_idx = agent_2_values_list.index(max(agent_2_values_list)) 199 | agent_2_pull = zeros(len(agent_2_values_list)) 200 | agent_2_pull[agent_2_max_idx] = 0.3 201 | 202 | agent_2_actions_card = dca.Card( 203 | id='agent_2_actions_card', 204 | title='Recommended actions from DL-Agent', 205 | children=[ 206 | dcc.Graph( 207 | id='agent_2_actions_pie_chart', 208 | figure=go.Figure( 209 | data=[go.Pie( 210 | labels=agent_2_labels_list, 211 | values=agent_2_values_list, 212 | sort=False, 213 | marker=dict(colors, 214 | line=dict(color='#efefef', width=1)), 215 | pull=agent_2_pull, 216 | textposition=agent_2_text_pos_list, 217 | showlegend=False 218 | )], 219 | layout=agent_layout 220 | )), 221 | ], 222 | style={'width': '100%'} 223 | ) 224 | 225 | final_row = dca.Row( 226 | children=[ 227 | dca.Row( 228 | children=[ 229 | dca.Column(actions_dropdown, columnSizes=['md-6'], 230 | style={'height': '1.5em'}), 231 | dca.Column(action_submit, columnSizes=['md-2']) 232 | ], 233 | style={ 234 | 'width': '100%', 235 | 'margin-top': '10pt', 236 | } 237 | ), 238 | dca.Row( 239 | children=[ 240 | dca.Column(all_actions_card, columnSizes=['md-2']), 241 | dca.Column(agent_1_actions_card, columnSizes=['md-3']), 242 | dca.Column(agent_2_actions_card, columnSizes=['md-3']) 243 | ], 244 | style={ 245 | 'margin-top': '10pt', 246 | } 247 | ) 248 | ], 249 | style={'paddingLeft': '2em'} 250 | ) if not done else dca.Row(children=[]) 251 | 252 | return final_row 
-------------------------------------------------------------------------------- /pages/interact/twc_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping 2 | 3 | import numpy as np 4 | 5 | import torch 6 | 7 | from .. import EASY_LEVEL, HARD_LEVEL, LEVELS, MEDIUM_LEVEL 8 | 9 | if True: 10 | from third_party.commonsense_rl import agent 11 | from third_party.commonsense_rl.games import dataset 12 | from third_party.commonsense_rl.utils_twc import extractor 13 | from third_party.commonsense_rl.utils_twc.kg \ 14 | import (RelationExtractor, construct_kg, load_manual_graphs) 15 | from third_party.commonsense_rl.utils_twc.nlp import Tokenizer 16 | from third_party.commonsense_rl.utils_twc.textworld_utils \ 17 | import get_goal_graph 18 | from third_party.commonsense_rl.utils_twc.generic import max_len, to_tensor 19 | 20 | kg_graphs = dict() 21 | 22 | 23 | class TestKnowledgeAwareAgent(agent.KnowledgeAwareAgent): 24 | def act(self, obs: str, score: int, 25 | done: bool, infos: Mapping[str, Any], 26 | scored_commands: list, random_action=False): 27 | batch_size = len(obs) 28 | if not self._episode_has_started: 29 | self.start_episode(batch_size) 30 | 31 | just_finished = [done[b] != self.last_done[b] 32 | for b in range(batch_size)] 33 | sel_rand_action_idx = \ 34 | [np.random.choice(len(infos["admissible_commands"][b])) 35 | for b in range(batch_size)] 36 | if random_action: 37 | return \ 38 | [infos["admissible_commands"][b][sel_rand_action_idx[b]] 39 | for b in range(batch_size)] 40 | 41 | torch.autograd.set_detect_anomaly(True) 42 | input_t = [] 43 | # Build agent's observation: feedback + look + inventory. 44 | state = ["{}\n{}\n{}\n{}".format(obs[b], 45 | infos["description"][b], 46 | infos["inventory"][b], 47 | ' \n'.join(scored_commands[b])) 48 | for b in range(batch_size)] 49 | # Tokenize and pad the input and the commands to chose from. 
50 | state_tensor = self._process(state, self.word2id) 51 | 52 | command_list = [] 53 | for b in range(batch_size): 54 | cmd_b = self._process(infos["admissible_commands"][b], 55 | self.word2id) 56 | command_list.append(cmd_b) 57 | max_num_candidate = \ 58 | max_len(infos["admissible_commands"]) 59 | max_num_word = max([cmd.size(1) for cmd in command_list]) 60 | commands_tensor = \ 61 | to_tensor(np.zeros((batch_size, max_num_candidate, max_num_word)), 62 | self.device) 63 | for b in range(batch_size): 64 | commands_tensor[b, :command_list[b].size(0), 65 | :command_list[b].size(1)] = command_list[b] 66 | 67 | localkg_tensor = torch.FloatTensor() 68 | localkg_adj_tensor = torch.FloatTensor() 69 | worldkg_tensor = torch.FloatTensor() 70 | worldkg_adj_tensor = torch.FloatTensor() 71 | localkg_hint_tensor = torch.FloatTensor() 72 | worldkg_hint_tensor = torch.FloatTensor() 73 | if self.graph_emb_type is not None and \ 74 | ('local' in self.graph_type or 'world' in self.graph_type): 75 | 76 | # prepare Local graph and world graph .... 
77 | # Extra empty node (sentinel node) for no attention option 78 | # (Xiong et al ICLR 2017 and https://arxiv.org/pdf/1612.01887.pdf) 79 | if 'world' in self.graph_type: 80 | world_entities = [] 81 | for b in range(batch_size): 82 | world_entities.extend(self.world_graph[b].nodes()) 83 | world_entities = set(world_entities) 84 | wentities2id = dict( 85 | zip(world_entities, range(len(world_entities)))) 86 | max_num_nodes = \ 87 | len(wentities2id) + \ 88 | 1 if self.sentinel_node else len(wentities2id) 89 | worldkg_tensor = \ 90 | self._process(wentities2id, self.node2id, 91 | sentinel=self.sentinel_node) 92 | world_adj_matrix = \ 93 | np.zeros((batch_size, max_num_nodes, max_num_nodes), 94 | dtype="float32") 95 | for b in range(batch_size): 96 | # get adjacentry matrix for each batch based on the 97 | # all_entities 98 | triplets = [list(edges) 99 | for edges 100 | in self.world_graph[b].edges.data('relation')] 101 | for [e1, e2, r] in triplets: 102 | e1 = wentities2id[e1] 103 | e2 = wentities2id[e2] 104 | world_adj_matrix[b][e1][e2] = 1.0 105 | world_adj_matrix[b][e2][e1] = 1.0 # reverse relation 106 | for e1 in list(self.world_graph[b].nodes): 107 | e1 = wentities2id[e1] 108 | world_adj_matrix[b][e1][e1] = 1.0 109 | if self.sentinel_node: # Fully connected sentinel 110 | world_adj_matrix[b][-1, :] = \ 111 | np.ones((max_num_nodes), dtype="float32") 112 | world_adj_matrix[b][:, -1] = \ 113 | np.ones((max_num_nodes), dtype="float32") 114 | worldkg_adj_tensor = \ 115 | to_tensor(world_adj_matrix, self.device, type="float") 116 | 117 | if 'local' in self.graph_type: 118 | local_entities = [] 119 | for b in range(batch_size): 120 | local_entities.extend(self.local_graph[b].nodes()) 121 | local_entities = set(local_entities) 122 | lentities2id = dict( 123 | zip(local_entities, range(len(local_entities)))) 124 | max_num_nodes = \ 125 | len(lentities2id) + \ 126 | 1 if self.sentinel_node else len(lentities2id) 127 | localkg_tensor = \ 128 | self._process(lentities2id, 
self.word2id, 129 | sentinel=self.sentinel_node) 130 | local_adj_matrix = np.zeros( 131 | (batch_size, max_num_nodes, max_num_nodes), 132 | dtype="float32") 133 | for b in range(batch_size): 134 | # get adjacentry matrix for each batch based on the 135 | # all_entities 136 | triplets = [list(edges) 137 | for edges 138 | in self.local_graph[b].edges.data('relation')] 139 | for [e1, e2, r] in triplets: 140 | e1 = lentities2id[e1] 141 | e2 = lentities2id[e2] 142 | local_adj_matrix[b][e1][e2] = 1.0 143 | local_adj_matrix[b][e2][e1] = 1.0 144 | for e1 in list(self.local_graph[b].nodes): 145 | e1 = lentities2id[e1] 146 | local_adj_matrix[b][e1][e1] = 1.0 147 | if self.sentinel_node: 148 | local_adj_matrix[b][-1, :] = np.ones((max_num_nodes), 149 | dtype="float32") 150 | local_adj_matrix[b][:, -1] = np.ones((max_num_nodes), 151 | dtype="float32") 152 | localkg_adj_tensor = to_tensor(local_adj_matrix, self.device, 153 | type="float") 154 | 155 | if len(scored_commands) > 0: 156 | # Get the scored commands as one string 157 | hint_str = \ 158 | [' \n'.join(scored_commands[b][-self.hist_scmds_size:]) 159 | for b in range(batch_size)] 160 | else: 161 | hint_str = [obs[b] + ' \n' + infos["inventory"][b] 162 | for b in range(batch_size)] 163 | localkg_hint_tensor = self._process(hint_str, self.word2id) 164 | worldkg_hint_tensor = self._process(hint_str, self.node2id) 165 | 166 | input_t.append(state_tensor) 167 | input_t.append(commands_tensor) 168 | input_t.append(localkg_tensor) 169 | input_t.append(localkg_hint_tensor) 170 | input_t.append(localkg_adj_tensor) 171 | input_t.append(worldkg_tensor) 172 | input_t.append(worldkg_hint_tensor) 173 | input_t.append(worldkg_adj_tensor) 174 | 175 | outputs, indexes, values = self.model(*input_t) 176 | outputs, indexes, values = \ 177 | outputs, indexes.view(batch_size), values.view(batch_size) 178 | sel_action_idx = [indexes[b] for b in range(batch_size)] 179 | action = \ 180 | [infos["admissible_commands"][b][sel_action_idx[b]] 181 | 
for b in range(batch_size)] 182 | 183 | if any(done): 184 | for b in range(batch_size): 185 | if done[b]: 186 | self.model.reset_hidden_per_batch(b) 187 | action[b] = 'look' 188 | 189 | assert self.mode == "test" 190 | return action, outputs 191 | 192 | 193 | class myDict(dict): 194 | def __init__(self, **arg): 195 | super(myDict, self).__init__(**arg) 196 | 197 | def __getattr__(self, key): 198 | return self.get(key) 199 | 200 | def __setattr__(self, key, value): 201 | self.__setitem__(key, value) 202 | 203 | 204 | def get_twc_agent(): 205 | opt = myDict() 206 | 207 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 208 | 209 | opt.difficulty_level = 'easy' 210 | opt.graph_type = 'world' 211 | opt.graph_mode = 'evolve' 212 | opt.mode = 'test' 213 | opt.eval_max_step_per_episode = 50 214 | opt.local_evolve_type = 'direct' 215 | opt.hidden_size = 300 216 | opt.world_evolve_type = 'manual' 217 | opt.egreedy_epsilon = 0.0 218 | opt.emb_loc = 'embeddings/' 219 | opt.word_emb_type = 'glove' 220 | opt.graph_emb_type = 'glove' 221 | opt.hist_scmds_size = 3 222 | opt.batch_size = 1 223 | opt.no_eval_episodes = 5 224 | opt.verbose = False 225 | 226 | opt.pretrained_model = \ 227 | 'results/' \ 228 | 'knowledgeaware_twc_evolve_world_glove_glove-1runs_' \ 229 | '100episodes_3hsize_0.0eps_easy_direct_manual_0runId.pt' 230 | 231 | tk_extractor = extractor.get_extractor("max") 232 | 233 | graph = None 234 | 235 | print("Testing ...") 236 | tokenizer = Tokenizer(noun_only_tokens=False, 237 | use_stopword=False, 238 | ngram=3, 239 | extractor=tk_extractor) 240 | rel_extractor = RelationExtractor(tokenizer, 241 | openie_url='http://localhost:9000/') 242 | agent = \ 243 | TestKnowledgeAwareAgent(graph, opt, tokenizer, rel_extractor, device) 244 | agent.type = "knowledgeaware" 245 | 246 | print('Loading Pretrained Model ...', end='') 247 | agent.model.load_state_dict( 248 | torch.load(opt.pretrained_model, map_location=device)) 249 | print('DONE') 250 | 251 | 
agent.test(opt.batch_size) 252 | opt.nepisodes = opt.no_eval_episodes # for testing 253 | opt.max_step_per_episode = opt.eval_max_step_per_episode 254 | print("RUN") 255 | 256 | infos_to_request = agent.infos_to_request 257 | infos_to_request.max_score = True 258 | 259 | print("Loading Graph ... ", end='') 260 | manual_world_graphs = dict() 261 | goal_graphs = {} 262 | lower_level_str = {EASY_LEVEL: 'easy', 263 | MEDIUM_LEVEL: 'medium', HARD_LEVEL: 'hard'} 264 | for level in LEVELS: 265 | game_path = 'static/games/twc/%s/test/' % lower_level_str[level] 266 | 267 | agent.kg_graph, _, _ = \ 268 | construct_kg(game_path + '/conceptnet_subgraph.txt') 269 | kg_graphs[level] = agent.kg_graph 270 | 271 | manual_world_graph = \ 272 | load_manual_graphs(game_path + '/manual_subgraph_brief') 273 | manual_world_graphs.update(manual_world_graph) 274 | 275 | game_path = game_path + '/*.ulx' 276 | 277 | env, game_file_names = \ 278 | dataset.get_game_env(game_path, infos_to_request, 279 | opt.max_step_per_episode, opt.batch_size, 280 | mode='test', verbose=False) 281 | 282 | for game_file in env.gamefiles: 283 | goal_graph = get_goal_graph(game_file) 284 | if goal_graph: 285 | game_id = game_file.split('-')[-1].split('.')[0] 286 | goal_graphs[game_id] = goal_graph 287 | print(' DONE') 288 | 289 | return agent, goal_graphs, manual_world_graphs 290 | -------------------------------------------------------------------------------- /static/data/scalars_example.json: -------------------------------------------------------------------------------- 1 | {"group_1_title": {"epoch": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], "values": 
{"metric_1": [0.7263092642955757, 0.650814175746826, 0.9452719254350358, 0.9605748218074432, 0.9907773133942781, 0.8070759387524138, 0.037596551585502014, 0.6375140210503832, 0.05825530898779607, 0.7649976267057701, 0.21416577633538159, 0.3776367172323728, 0.7056007582763498, 0.3008019870474091, 0.328641402457131, 0.8174738102196303, 0.6610996452930827, 0.6744939584540219, 0.6659853912741522, 0.7074776532977316, 0.8268828968948576, 0.18567535643962596, 0.6709408236838508, 0.4659316891933447, 0.6139158862644765, 0.6594406728950818, 0.583728867832163, 0.6303004663945762, 0.32981096725484327, 0.41490140517508745, 0.5761062868211977, 0.9403014434554472, 0.2773285790301915, 0.9251844903212242, 0.2746553895723519, 0.4853315369133774, 0.7551193617483314, 0.9954838563578079, 0.8214176415465565, 0.19288748479971973, 0.5385706552639667, 0.5333290299348397, 0.3118174266746552, 0.3779987572135167, 0.3577437842596277, 0.9867149516472253, 0.3746001241337039, 0.19199208162460601, 0.9301952402055853, 0.950616061197246, 0.17611770593766563, 0.5546114436782651, 0.3237618631897038, 0.934999342728096, 0.2556903955201888, 0.951494728885388, 0.8454217824439942, 0.8432156751022196, 0.5726059836657801, 0.9059890528488673, 0.36360770691177546, 0.17306336684949264, 0.8726367897740441, 0.1369181379811809, 0.48069823351544516, 0.9679866989419651, 0.6883934287655215, 0.4236372984151183, 0.5525177609252383, 0.8576442268669524, 0.36675385954852324, 0.4248199153620257, 0.8774676626859607, 0.9273967260640797, 0.30329479773220513, 0.40373769627209644, 0.10184299415989606, 0.6221018874959273, 0.39838318400858574, 0.8082756489152386, 0.2797154359833389, 0.8834377056218622, 0.8463101549791504, 0.6987100672857748, 0.025980693694148016, 0.6917407768226216, 0.906465314611594, 0.05602487256447042, 0.27091363313688144, 0.9034248510604405, 0.7913341976400963, 0.7957396483295873, 0.5317846915772171, 0.9549186002225954, 0.7211519789345014, 0.09849402549985142, 0.33484713198283955, 0.35312966373111865, 
0.8734570109747126, 0.6775282197029389], "metric_2": [0.2988989225113444, 0.8557776831487817, 0.509054819808299, 0.6831965479663075, 0.16238951900080745, 0.7136989698867923, 0.2556203247119677, 0.5795234516475726, 0.5541312318169648, 0.5536771249924618, 0.3671884206212115, 0.11689290560477505, 0.38519061523941944, 0.9921674857428865, 0.8336488933160848, 0.9213849828182269, 0.2238351422047281, 0.6290009364849366, 0.6314867273900496, 0.30732016546470664, 0.6643300034965155, 0.8260311283317132, 0.6136795733045007, 0.8911276068568877, 0.9725360331556843, 0.24866091531320955, 0.5200892255203482, 0.45859119184737185, 0.582041797863422, 0.27956438325585986, 0.6915438318741579, 0.8401365325970566, 0.24409583940512725, 0.5162463333013708, 0.0633992802507286, 0.5667670986510974, 0.6460277592932612, 0.9112138685144716, 0.9416958367122912, 0.07462267066016137, 0.9737484735140046, 0.9608870169717301, 0.7790596529371325, 0.6159674709776304, 0.20789811896070687, 0.9155898987514287, 0.8681024010220514, 0.1931652521223045, 0.7738277353705666, 0.32324440806382304, 0.7507370759097293, 0.1368370534704243, 0.49797935898265355, 0.5028009001164992, 0.2903075488240352, 0.5014764526201635, 0.9406571997146836, 0.7028739359375553, 0.2296395605323266, 0.6196678954774116, 0.36328384259487023, 0.25406650092356065, 0.15051647842851767, 0.8847600970079634, 0.27359874743071755, 0.9906553273870699, 0.8472274672254829, 0.21398413895917834, 0.9956272008619329, 0.486784084927958, 0.9752058565312457, 0.9441412407227038, 0.7698824305163635, 0.5252650438398864, 0.9504517008022179, 0.12678355178115008, 0.3053589211810116, 0.1992564829478749, 0.19497552814968444, 0.5222807678109807, 0.8208359741138658, 0.33815358403295315, 0.05779694359416998, 0.7529928850610074, 0.3318410666509467, 0.751650832137123, 0.4650170610447715, 0.8616553868631017, 0.8298713063783847, 0.7589922934389705, 0.04339073522046566, 0.6650808040359492, 0.3046213914020355, 0.1980210170903437, 0.10053103983350742, 0.2543799733414147, 
0.4767996982337023, 0.6059217055842477, 0.7103150390389937, 0.6113852410822783], "metric_3": [0.36452811163137244, 0.7398796243675478, 0.4808926803813087, 0.6139811688163317, 0.4998358611463942, 0.7660981172271755, 0.6059939719299644, 0.5828395647968088, 0.8379662556760578, 0.10437376219868932, 0.0972376921757131, 0.5556598150858085, 0.7301789298922184, 0.8140259330733657, 0.7643790329608704, 0.9018658769752959, 0.6040221083348205, 0.047251141633777305, 0.8009817074035986, 0.34378135540393784, 0.7252191645633194, 0.04343888809316321, 0.3792320721905854, 0.6265254817798267, 0.2138004092953777, 0.016345378926308163, 0.05256362337536136, 0.6600720985248092, 0.3412899804715075, 0.23493302554165052, 0.6090609623448537, 0.4122068787822578, 0.40411253398190505, 0.4550682317942557, 0.9395648307639259, 0.14487550324769916, 0.047798917238674354, 0.42493679241836846, 0.9271837539452737, 0.2830453356009647, 0.7635972173349447, 0.5043317015335269, 0.24477944574184052, 0.8946222413819432, 0.9638095642738045, 0.293496032381647, 0.7870625662682522, 0.276263813690405, 0.012222696388410448, 0.5996729582090634, 0.3483519913350136, 0.48913503639033673, 0.06144795919939261, 0.42924891999965675, 0.3228804624439221, 0.6122498493025151, 0.6495289854485388, 0.35088394672994994, 0.17921360963962474, 0.408879423799124, 0.31891033824775183, 0.19079521948298517, 0.3978761917940027, 0.7159060239784119, 0.8409287889211757, 0.3065270597711369, 0.0009575551904349444, 0.19910388065180717, 0.332443738606107, 0.2965757260551386, 0.802508882657789, 0.8553278899590605, 0.7110571677819751, 0.29168415006762716, 0.1443210564628431, 0.8416107001956058, 0.605093102870106, 0.27411918024156035, 0.357379222542488, 0.9183935304739553, 0.42168102547936903, 0.6069773022056989, 0.4413776892640755, 0.7034270974145194, 0.8696141385901581, 0.8921680588304426, 0.07844400587359424, 0.559264053023562, 0.9959694433203716, 0.36790820608793984, 0.4132566615326493, 0.6730952234023334, 0.2697008456569888, 0.5954272017660522, 
0.9046416257082337, 0.5072114502799511, 0.24076483241877356, 0.18421664968624396, 0.5389059766610129, 0.9428590617049787]}}, "group_2_title": {"epoch": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], "values": {"metric_1a": [0.28609746331611463, 0.8667230767602756, 0.6054304590207998, 0.13677762190696408, 0.24211414103465922, 0.8749164193067819, 0.1771714864079451, 0.606235738177549, 0.6520170143268077, 0.8578865845416932, 0.17130162264866777, 0.6500857208288926, 0.051483080657358515, 0.8818006530772926, 0.7270858617350963, 0.32956955231551444, 0.2388462727736046, 0.4494990464756402, 0.43626582023385385, 0.4745321194342699, 0.9127448379984519, 0.3176112564008119, 0.776679112552363, 0.4269457500763617, 0.08817065454345063, 0.986753544899226, 0.4596455926821241, 0.3033605234081628, 0.022451927525360738, 0.902742519264832, 0.6709531064511908, 0.40148367124861883, 0.3212834233511104, 0.4314926258256544, 0.9676634399788752, 0.5200826440166473, 0.34604225291597634, 0.9130649382718549, 0.027543346870844365, 0.002301837869361334, 0.023279120329323333, 0.5264233528681005, 0.5587635098075245, 0.9027270491040046, 0.7129281963516908, 0.28799763452666716, 0.8610765107104382, 0.9908743252807782, 0.598604973565301, 0.6223702920877702, 0.24594099963306038, 0.26029892506175734, 0.1822292685664586, 0.8114304779651426, 0.2467343695671249, 0.9896124593422727, 0.12067288273051602, 0.6921787859193946, 0.5392638301157847, 0.33981126464059586, 0.4291503141794232, 0.6358949496946437, 0.871805830415785, 0.6367618170361738, 0.7756609630677284, 0.9319470247888401, 0.34137703579917855, 0.04204947234700984, 0.3946827931471254, 0.6015244531058325, 
0.7016844892405526, 0.9420978607373333, 0.3276571984458271, 0.859776112406228, 0.7675305947249379, 0.40577229679826177, 0.16024779562695146, 0.024762562621508666, 0.43463915327476577, 0.7845207294044492, 0.13912100719597043, 0.7574932085405249, 0.45128816549199935, 0.30417781937046917, 0.17827134144114898, 0.16252678073019333, 0.5576807191103, 0.1939888643866552, 0.5289844872971621, 0.19409550094506556, 0.9447385230576967, 0.14002356621724943, 0.12240816443442515, 0.5392395673852325, 0.005113920618228129, 0.24181613141542457, 0.18934859344551236, 0.8961090824700195, 0.031193391393572245, 0.2157537109943849], "metric_2a": [0.06465812437571727, 0.5259705912540801, 0.3865814847404889, 0.02633897066406432, 0.719152699061079, 0.123985037949368, 0.4042233243870762, 0.28603586599892883, 0.038665516646805176, 0.46353917459295324, 0.1585162916046572, 0.9951066505806376, 0.8682973469453199, 0.1634756081559524, 0.956676857338513, 0.2759504274121729, 0.018414863447751117, 0.2599074097166184, 0.09231053190579419, 0.5296240594571995, 0.43477404472423387, 0.3199647584150702, 0.06498952890978338, 0.6096337092184928, 0.8777853109057544, 0.9958396134295762, 0.9826497461150956, 0.017610782640402722, 0.9597384026756459, 0.4736007727768883, 0.5485741829092401, 0.8229300519333949, 0.6599195649230792, 0.11849789591678361, 0.2738327187308778, 0.15584324845837028, 0.6691354737949559, 0.544086325204595, 0.5985173553841355, 0.2360864397023108, 0.6930801676922161, 0.43357024810144484, 0.40413903410660146, 0.5618276017040155, 0.802808031545467, 0.9591903506159738, 0.05992510786373484, 0.17504260860140353, 0.8567141837455394, 0.3563837796654393, 0.03659879461194071, 0.6623751490995026, 0.31225354131387595, 0.14602672571775865, 0.31494386478180925, 0.6829631376319457, 0.5885231454839575, 0.7851208424405598, 0.6784911361863933, 0.24889587707405358, 0.439909138090951, 0.9429810038542626, 0.3989380624979445, 0.34398278424555495, 0.74949463925559, 0.767017593498918, 0.12304696731899456, 
0.46134766674093375, 0.3077908168732244, 0.7323802196097593, 0.6696428677233439, 0.2430581380938085, 0.7594183603933583, 0.9757927046220066, 0.5674510446901367, 0.18072069531339718, 0.22242938382835276, 0.7904726911803434, 0.6124749228527778, 0.3224327866070853, 0.08044642389791878, 0.7230486593548343, 0.37796882464496295, 0.8519692212945333, 0.6854641238488012, 0.16276831422240712, 0.44882138905288016, 0.5199424868085954, 0.3981381252472024, 0.9125346962752106, 0.6920554470432193, 0.43821430000210504, 0.04710869245229343, 0.933126493621845, 0.6418528726599054, 0.04740178599926459, 0.14839384731729988, 0.3334102409994021, 0.7086953547093742, 0.25438609326223127], "metric_3a": [0.5651807262774659, 0.2993231650414313, 0.9258538975208918, 0.7434496749241533, 0.2638095151877584, 0.7824412461691287, 0.6494211729498465, 0.6789811114557915, 0.6731868125933886, 0.16626806598491417, 0.6270453503990296, 0.22097835315913272, 0.1489648839629506, 0.2719493950667977, 0.15559618986132218, 0.4404051695876331, 0.06266157696843999, 0.943458277908139, 0.8427447377630385, 0.026143649251483825, 0.04497811959529707, 0.8549387150496773, 0.35827174274290396, 0.49800544017165627, 0.9557392658521625, 0.5527450274441481, 0.8411870958328104, 0.9458676034436142, 0.15513732262895574, 0.9705647888467367, 0.8834815962432742, 0.7829775996064641, 0.6028454936149898, 0.6885425508911416, 0.47988605438792, 0.4863409340860102, 0.5995806601546987, 0.25077918442368674, 0.7755751430177866, 0.8949917284779108, 0.06317750368657371, 0.8721663202583677, 0.45648359206444156, 0.5120626138469158, 0.7294690407972457, 0.6208096142216849, 0.609339313103225, 0.07604139098361107, 0.7323234538472966, 0.0390811137557574, 0.6333817085899165, 0.9690651843973804, 0.0509921078093768, 0.7695149351438397, 0.5260542924869207, 0.44714051141666267, 0.8960957808403007, 0.6857266019195839, 0.6523991776575594, 0.7076515577154874, 0.5166968756965976, 0.004356666812825205, 0.4069208609445836, 0.42395783587185587, 
0.024266751769454475, 0.12555136199270567, 0.27004841637301535, 0.24124861815284637, 0.1329504746215724, 0.8252934809404322, 0.7973772464140926, 0.4356962611740467, 0.9258713685790705, 0.4514508997413963, 0.7888085597584025, 0.13497735063041827, 0.10633388867774951, 0.08540101863607119, 0.2771026612430575, 0.19361290204310921, 0.7758979211732474, 0.591826821164162, 0.9172899952299286, 0.15349899185099114, 0.0040908669189275715, 0.9444985266855123, 0.3244854607048602, 0.18429357573739147, 0.6787510206609458, 0.2667561984482465, 0.48515536702977713, 0.13916844867209344, 0.304777963906657, 0.444124026770498, 0.16011953488973163, 0.6419861852548925, 0.13322307639925313, 0.5480399719269131, 0.4752113036190607, 0.6365005094288997]}}, "group_3_title": {"epoch": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], "values": {"metric_1b": [0.18305856158206812, 0.11230478209532047, 0.048924985214618744, 0.8633907562757337, 0.5988058438822577, 0.0540019624820659, 0.6249787622915396, 0.5770735271802475, 0.13614091188545063, 0.2909864003101479, 0.5602876435147006, 0.34485340084644456, 0.41206728786666336, 0.9223926736597513, 0.008047669306157745, 0.5617958244250182, 0.8829024190951236, 0.10109287968200986, 0.03288722763800056, 0.8297131457578969, 0.8478987354472525, 0.8400866305345334, 0.8090888819763489, 0.07583088270843685, 0.452886769612057, 0.4188198585002759, 0.7007690658198222, 0.8226117480947263, 0.5191097991449249, 0.4632186822885712, 0.6213838405153324, 0.4100672693633042, 0.9133827684501118, 0.5356393045732033, 0.024262235351979577, 0.1047998439859007, 0.03454300099366636, 0.8578190720935649, 0.36751691695882815, 0.6281705276309232, 
0.2315667590538375, 0.9191020571477505, 0.3037256015993952, 0.6897173562383274, 0.38513304971338436, 0.700568572864533, 0.8891891557275073, 0.889989598207128, 0.15978083927117903, 0.009303351418654349, 0.4398732177610969, 0.9384739931043278, 0.9926258124962272, 0.19482001409785388, 0.35547054220734964, 0.7122940718832766, 0.7368718239209021, 0.04163146572929788, 0.7033757097883613, 0.09049893923118313, 0.8551462856987802, 0.47325375444169226, 0.061532637564271964, 0.1551888082194569, 0.4717104136883753, 0.4517258551035066, 0.10405036825017211, 0.9309831441028743, 0.6709902659964543, 0.8122547085846817, 0.5588181086507575, 0.43809785545203506, 0.6618475508600294, 0.12677563519184398, 0.014148719600260429, 0.4472681186672368, 0.46023906164383666, 0.45672459038111246, 0.7379400848818168, 0.343935514235037, 0.5006995021373022, 0.6859292587206353, 0.9878808159939284, 0.027039572170882376, 0.7443128206033511, 0.5723588784662971, 0.6832086299006084, 0.6140257458821505, 0.019298531502040417, 0.8029107890682183, 0.09858808146290221, 0.9306459199015064, 0.1424412458359725, 0.4037772287819633, 0.12861722029320233, 0.8708723094816239, 0.5296766984094904, 0.5348866351094752, 0.5087036066711695, 0.19799053622390084], "metric_2b": [0.19667742626088902, 0.6676018753632206, 0.08774110793458123, 0.72251991244779, 0.6080101257851253, 0.583412323459681, 0.2705046983973056, 0.32852568406840266, 0.01402961865375485, 0.39245729091229287, 0.9386679004572778, 0.09741620889478297, 0.15653697669084354, 0.6367563942484307, 0.4792434886947283, 0.16486944631486355, 0.35022022059531577, 0.6326617917596368, 0.4867928932517773, 0.2223466026644948, 0.3773185815972727, 0.8806421759411144, 0.2863900159535838, 0.8686738966293995, 0.9713204343055553, 0.2536420502349507, 0.14701139772723293, 0.5779908786771067, 0.5254892052502556, 0.4689136420401001, 0.5468343994468052, 0.3165472038928928, 0.20487726903649994, 0.7353351846206654, 0.4369574907027254, 0.0022234584815107317, 0.9174921093869562, 
0.05753153266148414, 0.303911319179872, 0.0755629595481202, 0.8689207431796679, 0.9299623416191993, 0.3003708188430345, 0.7256940503546374, 0.35341208048229755, 0.5677885938903329, 0.8538569666841529, 0.5043951215380752, 0.8939606386330914, 0.43111871365354393, 0.08586577217593472, 0.8451000123577365, 0.3634900722428326, 0.5205551789656291, 0.9040479830581949, 0.0522360046617627, 0.8490257954521038, 0.6303773174549667, 0.017244717841017887, 0.8723112815390648, 0.37536514566320967, 0.20456199756907334, 0.11965999632504098, 0.7476332767904883, 0.5345892245642186, 0.6091883686855478, 0.05685906289447484, 0.6536313429834902, 0.5584984087012688, 0.09963851250622224, 0.6899765722636626, 0.060908030560592796, 0.5833417174795782, 0.1630569941171237, 0.5413420523170934, 0.07558100092706133, 0.9712284871126535, 0.2216303384345919, 0.09328711931231326, 0.6625907898240395, 0.41979812257554316, 0.8758439074212724, 0.17730237556611017, 0.031573388163868565, 0.09784863717061176, 0.05376490028641634, 0.33192348201749966, 0.5500850213374298, 0.4270771767800998, 0.03378382106100053, 0.5669531023116394, 0.442559283355225, 0.06008925475366422, 0.23198942499736963, 0.6703493735831324, 0.40049898075692125, 0.48086853535587504, 0.7835504584483176, 0.46586574246503976, 0.6942807654136312], "metric_3b": [0.2799353615526686, 0.35245618458149197, 0.295820846295927, 0.6329169664104469, 0.6728996542396825, 0.8249847374417317, 0.008241303349783458, 0.7245144194952651, 0.6825048806102164, 0.9636965985974566, 0.24840498674722633, 0.08225262978034154, 0.36468246584805364, 0.32662319138356677, 0.4310096713200283, 0.6767590544574441, 0.16277003649659816, 0.7338539871262406, 0.1728370830087388, 0.2938427430517213, 0.2752784174973113, 0.264000211313819, 0.28767756781830656, 0.6578090911877161, 0.9354245661704084, 0.46914322679254983, 0.7863511344604165, 0.5737009433555612, 0.7753024221433785, 0.7733602155860151, 0.18517287113880732, 0.8634598673649996, 0.3718264877872375, 0.20283514496408916, 
0.2167908492786229, 0.14075772676028553, 0.22935138558040014, 0.34491696946594874, 0.9911591234364501, 0.3999081385708595, 0.35373298226964833, 0.18009054630615107, 0.26048106135508564, 0.3925425340837563, 0.2394334888219214, 0.41138187717916463, 0.6823298459490825, 0.7869688381006704, 0.9955666898869401, 0.34318906604678323, 0.7644671416738077, 0.014019312713014154, 0.5174752651854067, 0.08267187949745336, 0.14310197605097086, 0.18313109032676778, 0.7051408655382979, 0.3396445459597186, 0.868175373758711, 0.8766491069864557, 0.8129734198767626, 0.7973369176852978, 0.6441495896740429, 0.8146576278066356, 0.5413545084919077, 0.8154274426849845, 0.33231778042931537, 0.5104188105513344, 0.37773679072651334, 0.7827374138045159, 0.048948492521863396, 0.1503298872922737, 0.17639654674837535, 0.9085239103931825, 0.5163529004607889, 0.8614388579364515, 0.712781520078112, 0.21435289427798798, 0.6787883368689688, 0.1966916185421298, 0.18611615009801463, 0.40344048036116975, 0.13479024014867136, 0.3702597171033418, 0.7530637556610457, 0.07814422050796022, 0.9844467623765932, 0.25828088838534846, 0.5057933281890507, 0.09303864888060853, 0.5781598035629075, 0.23256044538093645, 0.8774522282048897, 0.5037685345559803, 0.7986760279292796, 0.4517449770193305, 0.08880061405817596, 0.8160275675487458, 0.05795791511748749, 0.4909361660967574]}}} --------------------------------------------------------------------------------