Details and bibtex
67 |
68 | The paper presents an initial demonstration of Logical Optimal Actions (LOA) on TextWorld (TW) Coin Collector, TW Cooking, TW Commonsense, and Jericho. In this version, the human player can select an action by hand or choose one from the list of actions recommended by LOA, which also visualizes the acquired knowledge to improve the interpretability of the trained rules.
69 |
70 | ```
71 | @inproceedings{kimura-etal-2021-loa,
72 | title = "{LOA}: Logical Optimal Actions for Text-based Interaction Games",
73 | author = "Kimura, Daiki and Chaudhury, Subhajit and Ono, Masaki and Tatsubori, Michiaki and Agravante, Don Joven and Munawar, Asim and Wachi, Akifumi and Kohita, Ryosuke and Gray, Alexander",
74 | booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing: System Demonstrations",
75 | month = aug,
76 | year = "2021",
77 | address = "Online",
78 | publisher = "Association for Computational Linguistics",
79 | url = "https://aclanthology.org/2021.acl-demo.27",
80 | doi = "10.18653/v1/2021.acl-demo.27",
81 | pages = "227--231"
82 | }
83 | ```
84 |
85 |
86 |
87 | ## License
88 |
89 | MIT License
--------------------------------------------------------------------------------
/pages/interact/layout/graphs_layout/interact_layout.py:
--------------------------------------------------------------------------------
1 | import dash_carbon_components as dca
2 | import dash_html_components as html
3 | from components import dialog_row, fired_rule_card, next_actions_row
4 |
# Layout for the "interact" page: a chat card on top, and below it a card
# for choosing the next action plus a card holding the reset button and
# the current game score.
#
# Fix: React (and therefore Dash) style dicts require camelCase CSS
# property names; the previous kebab-case keys ('flex-direction',
# 'margin-top', ...) are dropped by React with a console warning.
interact_layout = dca.Grid(
    style={'padding': '16px',
           # Leave room for the 40px UI-shell header.
           'height': 'calc(100% - 40px)',
           'overflow': 'auto'},
    className='bx--grid--narrow bx--grid--full-width',
    children=[
        # Top row: the agent/player dialog history.
        dca.Row(
            children=[
                dca.Column(columnSizes=['sm-4'], children=[
                    dca.Card(
                        id='interactive_agent_chat',
                        # Populated dynamically by the callbacks.
                        children=[
                        ],
                        style={
                            'overflow': 'auto',
                            'display': 'flex',
                            # column-reverse keeps the newest dialog row
                            # pinned to the bottom of the scroll area.
                            'flexDirection': 'column-reverse',
                            'height': '400px',
                            'width': '100%'
                        }
                    )
                ])
            ],
        ),
        # Bottom row: action selector (left) and reset/score card (right).
        dca.Row(
            children=[
                dca.Column(columnSizes=['sm-3'], children=[
                    dca.Card(
                        id='select_next_action_card',
                        children=[
                            html.H3('',
                                    id='message_for_top_of_action_selector',
                                    style={
                                        'paddingBottom': '0px'
                                    }),

                            html.Div(
                                id='next_action_row',
                                children=[
                                    # Placeholder row; replaced once a game
                                    # is running and agents score actions.
                                    next_actions_row(
                                        [''],
                                        agent_1_actions=[['', 0]],
                                        agent_1_rules=None,
                                        agent_1_facts=None,
                                        agent_2_actions=[['', 0]],
                                        done=False
                                    ),
                                ]
                            ),
                        ],
                        style={
                            'marginTop': '20px',
                            'marginBottom': '20px',
                            'height': 'calc(100% - 50px)'
                        }
                    ),
                ]),
                dca.Column(columnSizes=['sm-1'], children=[
                    dca.Card(
                        id='apply_next_action_card',
                        children=[
                            html.Div(
                                children=[
                                    dca.Button(
                                        id='reset_environment',
                                        size='sm',
                                        children='Reset Environment',
                                        kind='primary',
                                        style={'marginTop': '10%'}
                                    )
                                ],
                                style={'textAlign': 'center'}
                            ),
                            html.P(
                                id='game_score', children='',
                                style={
                                    'textAlign': 'center',
                                    'marginTop': '5%',
                                    'marginBottom': '5%'
                                }
                            ),
                        ],
                        style={
                            'marginTop': '20px',
                            'marginBottom': '0px',
                            'height': 'calc(100% - 50px)'
                        }
                    ),
                ]),
            ],
        ),
    ])
97 |
--------------------------------------------------------------------------------
/pages/interact/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | from .. import (DL_AGENT, EASY_LEVEL, HARD_LEVEL, LEVELS, LOA_AGENT,
5 | MEDIUM_LEVEL)
6 | from .functions.data_processing import (create_loss_data, create_network,
7 | create_slider_dict, load_network,
8 | load_scalars)
9 |
# The dummy `if True:` block keeps formatters/linters from hoisting the
# imports below to the top of the file: the environment variables and
# sys.path entries must be in place BEFORE the agent modules are imported.
if True:
    # Locations of the bundled LNN implementation and the TWC game files.
    os.environ['DDLNN_HOME'] = 'third_party/loa/third_party/dd_lnn/'
    os.environ['TWC_HOME'] = 'static/games/twc/'

    # NOTE(review): sys_path_backup is captured but never restored from;
    # the append is undone via slicing instead — confirm intentional.
    sys_path_backup = sys.path
    sys.path.append('third_party/commonsense_rl/')
    from .twc_agent import get_twc_agent, kg_graphs
    sys.path = sys.path[:-1]

    sys.path.append('third_party/loa/')
    from third_party.loa.amr_parser import AMRSemParser
    from third_party.loa.loa_agent import LOAAgent, LogicalTWCQuantifier
23 |
# --- Static LNN training artifacts and UI option lists ----------------------

# Trained rule network (per-epoch weights) and training-loss curves.
network = \
    load_network(os.path.join('static', 'data', 'network1.json'))
scalars = \
    create_loss_data(load_scalars(os.path.join(
        'static', 'data', 'scalars_lnn_wts_N4_pos1_neg2.npz')))

roots_names, nodes, edges = create_network(network_dict=network)
slider_dict = create_slider_dict(network['epochs'])

# Checklist options/defaults for the loss plot; every loss type starts
# selected.  (.unique() is computed once instead of twice.)
loss_types = scalars['loss_type'].unique()
options_loss_checklist = [{"label": x, "value": x} for x in loss_types]
value_loss_checklist = list(loss_types)

# Dropdown options/defaults for picking a difficulty level and a game.
options_game_level = [{"label": x, "value": x}
                      for x in LEVELS]
value_game_level = [EASY_LEVEL]
options_game = [{"label": x, "value": x}
                for x in ['One', 'Two', 'Three', 'Four', 'Five']]
value_game = ['One']

# ASCII-art TextWorld logo shown in the chat window.  `with` guarantees
# the file handle is closed even if the read fails.
with open(os.path.join('static', 'data', 'textworld_logo.txt'),
          'r', encoding='UTF-8') as f:
    textworld_logo = f.read()
48 |
# Index of the game instance used within each difficulty's test split.
game_no = 0


def _make_env(difficulty):
    """Create a test-split TWC environment for the given difficulty string."""
    return LogicalTWCQuantifier(difficulty,
                                split='test',
                                max_episode_steps=50,
                                batch_size=None,
                                game_number=game_no)


easy_env = _make_env('easy')
medium_env = _make_env('medium')
hard_env = _make_env('hard')

# NOTE(review): the dict keys are the level constants from the parent
# package while the environments are built from literal difficulty
# strings — presumably these agree; verify against LEVELS.
env_dict = \
    {EASY_LEVEL: easy_env, MEDIUM_LEVEL: medium_env, HARD_LEVEL: hard_env}

# Defaults for the LOA agent: difficulty, trained-policy pickle, and the
# semantic-parser mode ('both' presumably = AMR + VerbNet — TODO confirm).
difficulty_level = 'easy'
loa_pkl_filepath = 'results/loa-twc-dleasy-np2-nt15-ps1-ks6-spboth.pkl'
sem_parser_mode = 'both'
72 |
# AMR parsing is delegated to an external service; its address comes from
# the environment, defaulting to localhost with no port configured.
amr_server_ip = os.environ.get('AMR_SERVER_IP', 'localhost')
amr_server_port_str = os.environ.get('AMR_SERVER_PORT', '')
try:
    amr_server_port = int(amr_server_port_str)
except ValueError:
    # Unset or non-numeric port: fall back to None (client default).
    amr_server_port = None

if LOA_AGENT:
    # Neuro-symbolic agent: load the trained LOA policy and an AMR
    # semantic parser that caches parsed observations on disk.
    loa_agent = LOAAgent(admissible_verbs=None,
                         amr_server_ip=amr_server_ip,
                         amr_server_port=amr_server_port,
                         prune_by_state_change=True,
                         sem_parser_mode=sem_parser_mode)

    loa_agent.load_pickel(loa_pkl_filepath)
    rest_amr = AMRSemParser(amr_server_ip=amr_server_ip,
                            amr_server_port=amr_server_port,
                            cache_folder='cache/')
    # NOTE(review): adm_verbs is only defined on this branch — confirm
    # nothing imports it when LOA_AGENT is False.
    adm_verbs = loa_agent.admissible_verbs
    # Inference only: freeze the policy and extract human-readable rules
    # for the visualization pages.
    loa_agent.pi.eval()
    loa_rules = loa_agent.extract_rules()
else:
    loa_agent = None
    rest_amr = None
    loa_rules = None

if DL_AGENT:
    # Deep-learning baseline agent plus its per-game knowledge graphs.
    twc_agent, twc_agent_goal_graphs, \
        twc_agent_manual_world_graphs = get_twc_agent()
else:
    twc_agent = None
    twc_agent_goal_graphs = None
    twc_agent_manual_world_graphs = None

# Per-episode scored-action history (batch size 1), shared with callbacks.
scored_action_history = [[] for _ in range(1)]
108 |
--------------------------------------------------------------------------------
/pages/interact/functions/twc.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import random
4 | import sys
5 |
6 | import gym
7 | import textworld
8 | import textworld.gym
9 | import torch
10 | from textworld import EnvInfos
11 |
12 | if True:
13 | sys.path.append(os.path.join(os.path.dirname(__file__) + '/../../../'))
14 |
15 | from third_party.loa.amr_parser import get_formatted_obs_text
16 | from third_party.loa.logical_twc import Action2Literal
17 | from third_party.loa.utils import obtain_predicates_logic_vector
18 |
19 |
# Repository root; game files are resolved relative to it.
home_dir = './'

# Request every piece of per-step metadata the agents and the UI consume
# (admissible commands, logical facts, score bookkeeping, etc.).
request_infos = \
    EnvInfos(verbs=True, moves=True, inventory=True, description=True,
             objective=True, intermediate_reward=True,
             policy_commands=True, max_score=True, admissible_commands=True,
             last_action=True, game=True, facts=True, entities=True,
             won=True, lost=True, location=True)
28 |
29 |
def load_twc_game(level_str, type_str, index):
    """Register and return a single TWC game environment.

    Args:
        level_str: difficulty directory name (e.g. 'easy').
        type_str: game-type subdirectory under the level — TODO confirm
            valid values against the on-disk layout.
        index: position of the game within the sorted list of *.ulx files.

    Returns:
        A gym environment wrapping the selected TextWorld game
        (max_episode_steps=50, unbatched).

    Raises:
        IndexError: if *index* does not address an existing game file.
    """
    # Join the concrete path segments directly instead of %-formatting a
    # pattern that contains literal '%s' components after os.path.join.
    game_file_names = sorted(glob.glob(
        os.path.join(home_dir, 'static', 'games', 'twc',
                     level_str, type_str, '*.ulx')))
    try:
        game_file_name = [game_file_names[index]]
    except IndexError:
        # Same exception type as before, but with actionable context.
        raise IndexError(
            'no TWC game %d for level=%s type=%s (found %d games)' %
            (index, level_str, type_str, len(game_file_names))) from None
    env_id = \
        textworld.gym.register_games(game_file_name, request_infos,
                                     max_episode_steps=50,
                                     name='twc-%s-%s-%d' %
                                     (level_str, type_str, index),
                                     batch_size=None)
    env = gym.make(env_id)

    return env
44 |
45 |
def info2infos(info):
    """Convert a single-game info dict into batched form.

    Each value is wrapped in a one-element list, matching the batch-size-1
    convention used throughout this page.
    """
    batched = {}
    for key, value in info.items():
        batched[key] = [value]
    return batched
48 |
49 |
def get_agent_actions(env, obs, score, done, info,
                      action, scored_action_history,
                      loa_agent=None, dl_agent=None):
    """Score every admissible command with the LOA and DL agents.

    When an agent is absent, random scores are produced for its share of
    the commands so the UI still has something to display.

    Returns:
        (all_actions, ns_agent_actions, dl_agent_actions, logical_facts)
        where the action lists contain [command, score] pairs and
        logical_facts is the parsed fact dict, or None when unavailable.
    """
    # Imported lazily to avoid a circular import with ..data.
    from ..data import rest_amr

    action2literal = Action2Literal()
    all_actions = info['admissible_commands']

    # Fix: logical_facts was previously unbound on the `loa_agent is None`
    # path, making the return statement raise NameError.
    logical_facts = None

    if loa_agent is None:
        # No neuro-symbolic agent: random scores for non-'examine' actions.
        ns_agent_actions_list = \
            [a for a in all_actions if not a.startswith('examine')]
        ns_agent_actions = \
            [[a, random.random()] for a in ns_agent_actions_list]
    else:
        ns_agent_actions = list()

        facts = env.get_logical_state(info)
        obs_text = get_formatted_obs_text(info)

        try:
            # Parse the observation text into VerbNet/AMR facts; parses
            # are cached and the cache is persisted to disk below.
            verbnet_facts, _ = rest_amr.obs2facts(obs_text,
                                                  mode='both',
                                                  verbose=True)

            rest_amr.save_cache()

            # Ground-truth location/instance facts from the environment
            # override the parsed ones.
            verbnet_facts['atlocation'] = facts['atlocation']
            verbnet_facts['is_instance'] = facts['is_instance']

            logical_facts = \
                {
                    'at_location': [list(x) for x in
                                    verbnet_facts['atlocation']],
                    'carry': verbnet_facts['carry']
                    if 'carry' in verbnet_facts else []
                }

            # Score each admissible command with the LNN model of its
            # verb; commands with an unknown verb get score 0.
            for adm_comm in all_actions:
                rule, x, y = action2literal(adm_comm)
                if rule in loa_agent.admissible_verbs:
                    rule_arity = loa_agent.admissible_verbs[rule]

                    logic_vector, all_preds = \
                        obtain_predicates_logic_vector(
                            rule_arity, x, y,
                            facts=verbnet_facts,
                            template=loa_agent.arity_predicate_templates)
                    logic_vector = logic_vector.unsqueeze(0)
                    yhat = loa_agent.pi.forward_eval(logic_vector,
                                                     lnn_model_name=rule)
                    ns_agent_actions.append([adm_comm, float(yhat.item())])
                else:
                    ns_agent_actions.append([adm_comm, 0])

        except Exception:
            # Deliberate best-effort fallback: if parsing/scoring fails
            # (e.g. AMR server unreachable), score every command 0.
            logical_facts = None
            for adm_comm in all_actions:
                ns_agent_actions.append([adm_comm, 0])

    if dl_agent is None or dl_agent[0] is None:
        # No deep-learning agent: random scores for 'examine' actions.
        dl_agent_actions_list = \
            [a for a in all_actions if a.startswith('examine')]
        dl_agent_actions = \
            [[a, random.random()] for a in dl_agent_actions_list]
    else:
        dl_agent_actions = list()

        twc_agent, twc_agent_goal_graphs, twc_agent_manual_world_graphs = \
            dl_agent

        infos = info2infos(info)

        # Batch size is fixed at 1 throughout this page.
        game_goal_graphs = [None] * 1
        game_manual_world_graph = [None] * 1

        # Look up the goal / manual-world graphs for this game by its uuid.
        for b, game in enumerate(infos["game"]):
            if "uuid" in game.metadata:
                game_id = game.metadata["uuid"].split("-")[-1]
                game_goal_graphs[b] = twc_agent_goal_graphs[game_id]
                game_manual_world_graph[b] = \
                    twc_agent_manual_world_graphs[game_id]

        infos['goal_graph'] = game_goal_graphs
        infos['manual_world_graph'] = game_manual_world_graph

        if twc_agent.graph_emb_type and \
                ('local' in twc_agent.graph_type or
                 'world' in twc_agent.graph_type):
            twc_agent.update_current_graph([obs], [action],
                                           scored_action_history,
                                           infos,
                                           'evolve')

        action, commands_values = \
            twc_agent.act([obs], [score], [done], infos,
                          scored_action_history, random_action=False)

        # Normalize the agent's command logits into a distribution.
        values = torch.nn.functional.softmax(commands_values)[0].tolist()

        for i, a in enumerate(all_actions):
            dl_agent_actions.append([a, values[i]])

    return all_actions, ns_agent_actions, dl_agent_actions, logical_facts
153 |
--------------------------------------------------------------------------------
/pages/interact/callbacks.py:
--------------------------------------------------------------------------------
1 | import dash
2 | from components import dialog_row, next_actions_row
3 | from dash import Dash
4 | from dash.dependencies import Input, Output, State
5 | from dash.exceptions import PreventUpdate
6 | from dash_html_components.P import P
7 |
8 | from .. import MESSAGE_FOR_DONE, MESSAGE_FOR_SELECT_NEXT_ACTION
9 | from . import outputs
10 | from .data import (edges, env_dict, kg_graphs, loa_agent, loa_rules, nodes,
11 | roots_names, scored_action_history, twc_agent,
12 | twc_agent_goal_graphs, twc_agent_manual_world_graphs)
13 | from .functions import get_agent_actions
14 |
15 |
16 | def register(app: Dash):
17 | # Function to register the select_all callbacks,
18 | # receive the checkbox id and the dropdown id.
19 |
20 | @app.callback(
21 | Output('game_level_selection', 'value'),
22 | [
23 | Input('game_level_selection', 'value')
24 | ]
25 | )
26 | def game_level_selection(game_level_value):
27 | values = game_level_value
28 | return values
29 |
30 | @app.callback(Output('fired_rules_agent_1', 'children'),
31 | Input('ui-shell', 'name'))
32 | def edit_network_graph(input):
33 | return outputs.fired_rules(edges, nodes, roots_names,
34 | cytolayout='dagre',
35 | id='fired_rules_agent_1_graph')
36 |
37 | @app.callback(Output('working_memory_agent_1', 'children'),
38 | Input('ui-shell', 'name'))
39 | def edit_network_graph(input):
40 | return outputs.fired_rules(edges, nodes, roots_names,
41 | cytolayout='dagre',
42 | id='working_memory_agent_1_graph')
43 |
44 | @app.callback(
45 | [Output('interactive_agent_chat', 'children'),
46 | Output('next_action_row', 'children'),
47 | Output('game_score', 'children'),
48 | Output('message_for_top_of_action_selector', 'children')],
49 | [Input('agent_1_actions_pie_chart', 'clickData'),
50 | Input('agent_2_actions_pie_chart', 'clickData'),
51 | Input('submit_action', 'n_clicks'),
52 | Input('reset_environment', 'n_clicks'),
53 | Input('configuration_interact_games_apply', 'n_clicks'),
54 | Input('configuration_interact_games_reset', 'n_clicks')],
55 | [State('interactive_agent_chat', 'children'),
56 | State('game_level_selection', 'value'),
57 | State('all_actions_dropdown', 'value')]
58 | )
59 | def add_dialog_row(agent_1_click_data,
60 | agent_2_click_data,
61 | submit_action_n_action,
62 | reset_environment_n_clicks,
63 | configuration_interact_games_apply_n_clicks,
64 | configuration_interact_games_reset_n_clicks,
65 | interactive_agent_chat_children,
66 | game_level,
67 | all_actions_dropdown):
68 |
69 | env = env_dict[game_level]
70 | if twc_agent is not None:
71 | twc_agent.kg_graph = kg_graphs[game_level]
72 | ctx = dash.callback_context
73 |
74 | if not ctx.triggered:
75 | button_id = 'No clicks yet'
76 | else:
77 | button_id = ctx.triggered[0]['prop_id'].split('.')[0]
78 |
79 | print(button_id)
80 |
81 | if (button_id != 'agent_1_actions_pie_chart' and
82 | button_id != 'agent_2_actions_pie_chart' and
83 | button_id != 'submit_action' and
84 | reset_environment_n_clicks is None) \
85 | or button_id == 'reset_environment' \
86 | or button_id == 'configuration_interact_games_reset' \
87 | or button_id == 'configuration_interact_games_apply':
88 | obs, info = env.reset()
89 | if twc_agent is not None:
90 | twc_agent.start_episode(1)
91 |
92 | new_interactive_chat_children = \
93 | [dialog_row(is_agent=False,
94 | text='Objective: ' + info['objective']),
95 | dialog_row(is_agent=False,
96 | text=info['description'].replace('\n\n', '\n')
97 | .replace('\n', '