├── autoascend ├── __init__.py ├── visualization │ ├── __init__.py │ ├── glyph2tile.py │ ├── utils.py │ ├── scopes.py │ └── visualizer.py ├── objects │ ├── __init__.py │ └── utils.py ├── combat │ ├── __init__.py │ ├── utils.py │ ├── monster_utils.py │ ├── rl_scoring.py │ ├── movement_priority.py │ └── fight_heur.py ├── monster_tracker │ ├── __init__.py │ ├── kernels.py │ └── monster_tracker.py ├── exceptions.py ├── item │ ├── __init__.py │ ├── utils.py │ ├── item_priority_base.py │ ├── inventory_items.py │ └── item.py ├── glyph │ ├── screen_symbols.py │ ├── monster.py │ ├── __init__.py │ ├── screen_symbol_consts.py │ └── monflag.py ├── soko_solver │ └── __init__.py ├── stats_logger.py ├── level.py ├── rl_utils.py ├── strategy.py ├── utils.py └── env_wrapper.py ├── bin ├── docker-build.sh ├── docker-run.sh ├── filter_for_vis.py ├── solve_sokoban.py ├── get_observations_stats.py ├── summary.py └── main.py ├── requirements.txt ├── LICENSE ├── Dockerfile ├── muzero ├── muzero.patch ├── rl_features_stats.json └── nethack.py └── README.md /autoascend/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /autoascend/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /autoascend/visualization/glyph2tile.py: -------------------------------------------------------------------------------- 1 | /glyph2tile.py -------------------------------------------------------------------------------- /bin/docker-build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | docker build .
-t nethack "$@" 3 | -------------------------------------------------------------------------------- /autoascend/objects/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * 2 | from .utils import * -------------------------------------------------------------------------------- /autoascend/combat/__init__.py: -------------------------------------------------------------------------------- 1 | from . import fight_heur, monster_utils, movement_priority, utils, rl_scoring 2 | -------------------------------------------------------------------------------- /autoascend/monster_tracker/__init__.py: -------------------------------------------------------------------------------- 1 | from .kernels import disappearance_mask 2 | from .monster_tracker import MonsterTracker 3 | -------------------------------------------------------------------------------- /autoascend/exceptions.py: -------------------------------------------------------------------------------- 1 | class AgentFinished(Exception): 2 | pass 3 | 4 | 5 | class AgentPanic(Exception): 6 | pass 7 | 8 | 9 | class AgentChangeStrategy(Exception): 10 | pass 11 | -------------------------------------------------------------------------------- /autoascend/item/__init__.py: -------------------------------------------------------------------------------- 1 | from .item import Item 2 | from .item_manager import ContainerContent, ItemManager 3 | from .utils import flatten_items, find_equivalent_item, check_if_triggered_container_trap 4 | -------------------------------------------------------------------------------- /autoascend/glyph/screen_symbols.py: -------------------------------------------------------------------------------- 1 | from . 
import screen_symbol_consts 2 | from .screen_symbol_consts import * 3 | 4 | 5 | def find(glyph): 6 | for k, v in vars(screen_symbol_const).items(): 7 | if k.startswith('S_') and v == glyph: 8 | return f'SS.{k}' 9 | -------------------------------------------------------------------------------- /bin/docker-run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | devices=0 3 | docker run --gpus='"device='$devices'"' --ipc=host --net=host -it --rm \ 4 | -e DISPLAY=$DISPLAY \ 5 | -e PYTHONPATH=/workspace \ 6 | -v /tmp/.X11-unix:/tmp/.X11-unix \ 7 | -v $HOME/.Xauthority:/root/.Xauthority \ 8 | -v `pwd`:/workspace nethack:latest "$@" 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==0.29.24 2 | gym==0.19.0 3 | matplotlib==3.4.3 4 | nevergrad==0.4.3.post8 5 | nltk==3.6.2 6 | numba==0.52.0 7 | numpy==1.21.2 8 | opencv-python==4.5.4.58 9 | pandas==1.3.4 10 | Pillow==8.4.0 11 | pyinstrument==4.0.4 12 | ray==1.8.0 13 | redis==3.5.3 14 | scikit-learn==0.24.2 15 | scipy==1.6.3 16 | seaborn==0.11.2 17 | toolz==0.11.2 18 | torch==1.10.0a0+3fd9dcf 19 | torchvision==0.11.0a0 -------------------------------------------------------------------------------- /autoascend/item/utils.py: -------------------------------------------------------------------------------- 1 | def flatten_items(iterable): 2 | ret = [] 3 | for item in iterable: 4 | ret.append(item) 5 | if item.is_container(): 6 | ret.extend(flatten_items(item.content)) 7 | return ret 8 | 9 | 10 | def find_equivalent_item(item, iterable): 11 | assert item.text 12 | for i in iterable: 13 | assert i.text 14 | if i.text == item.text: 15 | return i 16 | assert 0, (item, iterable) 17 | 18 | 19 | def check_if_triggered_container_trap(message): 20 | return ('A cloud of ' in message and ' gas billows from ' in message) or \ 21 | 'Suddenly 
you are frozen in place!' in message or \ 22 | 'A tower of flame bursts from ' in message or \ 23 | 'You are jolted by a surge of electricity!' in message or \ 24 | 'But luckily ' in message 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright © 2022 Maciej Sypetkowski, Michał Sypetkowski 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 | -------------------------------------------------------------------------------- /autoascend/combat/utils.py: -------------------------------------------------------------------------------- 1 | def wielding_ranged_weapon(agent): 2 | for item in agent.inventory.items: 3 | if item.is_launcher() and item.equipped: 4 | return True 5 | return False 6 | 7 | 8 | def wielding_melee_weapon(agent): 9 | for item in agent.inventory.items: 10 | if item.is_weapon() and item.equipped: 11 | return True 12 | return False 13 | 14 | 15 | def line_dis_from(agent, y, x): # Chebyshev (king-move) distance from the agent's position to (y, x). 16 | return max(abs(agent.blstats.x - x), abs(agent.blstats.y - y)) 17 | 18 | 19 | def inside(agent, y, x): # True if (y, x) lies within the bounds of agent.glyphs. 20 | return 0 <= y < agent.glyphs.shape[0] and 0 <= x < agent.glyphs.shape[1] 21 | 22 | 23 | def action_str(agent, action): 24 | priority, a = action 25 | if a[0] == 'move': 26 | return f'{priority}m:{a[1]},{a[2]}' 27 | elif a[0] == 'melee': 28 | return f'{priority}me:{a[1]},{a[2]}' 29 | elif a[0] == 'pickup': 30 | return f'{priority}{a[0][0]}:{len(a[1])}' 31 | elif a[0] == 'zap': 32 | wand = a[3] 33 | letter = agent.inventory.items.get_letter(wand) 34 | return f'{priority}z{letter}:{a[1]},{a[2]}' 35 | elif a[0] == 'elbereth': 36 | return f'{priority:.1f}e' 37 | elif a[0] == 'wait': 38 | return f'{priority:.1f}w' 39 | elif a[0] == 'go_to': 40 | return f'{priority}goto:{a[1]},{a[2]}' 41 | else: 42 | return f'{priority}{a[0][0]}:{a[1]},{a[2]}' 43 | -------------------------------------------------------------------------------- /autoascend/glyph/monster.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import nle.nethack as nh 4 | 5 | from .monflag import * 6 | 7 | 8 | def is_monster(glyph): 9 | return nh.glyph_is_monster(glyph) 10 | 11 | 12 | def is_pet(glyph): 13 | return nh.glyph_is_pet(glyph) 14 | 15 | 16 | @functools.lru_cache(nh.NUMMONS * 10) 17 | def permonst(glyph): 18 | if nh.glyph_is_monster(glyph): 19 | return
nh.permonst(nh.glyph_to_mon(glyph)) 20 | elif nh.glyph_is_pet(glyph): 21 | return nh.permonst(nh.glyph_to_pet(glyph)) 22 | elif nh.glyph_is_body(glyph): 23 | return nh.permonst(glyph - nh.GLYPH_BODY_OFF) 24 | else: 25 | assert 0, glyph 26 | 27 | 28 | def find(glyph): 29 | if glyph in ALL_MONS or glyph in MON.ALL_PETS: 30 | return f'MON.fn({repr(permonst(glyph).mname)})' 31 | 32 | 33 | @functools.lru_cache(nh.NUMMONS) 34 | def from_name(name): 35 | return nh.GLYPH_MON_OFF + id_from_name(name) 36 | 37 | 38 | @functools.lru_cache(nh.NUMMONS) 39 | def id_from_name(name): 40 | for i in range(nh.NUMMONS): 41 | if nh.permonst(i).mname == name: 42 | return i 43 | assert 0, name 44 | 45 | 46 | def body_from_name(name): 47 | return id_from_name(name) + nh.GLYPH_BODY_OFF 48 | 49 | 50 | fn = from_name 51 | 52 | ALL_MONS = [nh.GLYPH_MON_OFF + i for i in range(nh.NUMMONS)] 53 | ALL_PETS = [nh.GLYPH_PET_OFF + i for i in range(nh.NUMMONS)] 54 | -------------------------------------------------------------------------------- /bin/filter_for_vis.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pandas as pd 4 | 5 | 6 | def interesting_reason(txt): 7 | return True 8 | txt = txt.lower() 9 | # return 'Error' not in txt 10 | # return 'starved' in txt.lower() 11 | return ('food' in txt or 'fainted' in txt or 'starved' in txt) 12 | # return 'captain' in txt or 'shop' in txt 13 | # return 'rotten' in txt 14 | return ('food' not in txt 15 | and 'shop' not in txt 16 | and 'falling rock' not in txt 17 | and 'Error' not in txt 18 | and 'starved' not in txt 19 | and 'timeout' not in txt 20 | and 'sleeping' not in txt 21 | and 'wand' not in txt 22 | and 'bolt' not in txt 23 | and 'missile' not in txt 24 | and 'rotted' not in txt 25 | and 'guard' not in txt 26 | and 'quit' not in txt) 27 | 28 | 29 | def process(path='/tmp/nh_sim.json'): 30 | ret = dict() 31 | 32 | with open(path, 'r') as f: 33 | df = pd.DataFrame(json.load(f)) 34 | 
df['role'] = [ch[:3] for ch in df.character] 35 | 36 | for row in df.itertuples(): 37 | # if row.score > 2200: 38 | # continue 39 | # if interesting_reason(row.end_reason): 40 | if row.search_diff > 300: 41 | print(row.seed[0], row.steps, row.end_reason) 42 | ret[row.seed[0]] = row.steps - 128 43 | 44 | with open('filtered.json', 'w') as f: 45 | json.dump(ret, f) 46 | 47 | 48 | # process('/tmp/nh_sim.json') 49 | # process('nh_sim_fix_engrave.json') 50 | # process('/workspace/nh_sim_fight2.json') 51 | 52 | 53 | process() 54 | -------------------------------------------------------------------------------- /autoascend/item/item_priority_base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class ItemPriorityBase: 5 | """ 6 | The base class for inventory item priority logic. 7 | """ 8 | 9 | def _split(self, items, forced_items, weight_capacity): 10 | ''' 11 | returns a dict (container_item or None for inventory) -> 12 | (list of counts to take corresponding to `items`) 13 | 14 | Lack of the container in the dict means "don't change the content except for 15 | items wanted by other containers" 16 | 17 | Order of `items` matters. First items are more important. 18 | Otherwise the agent will drop and pickup items repeatedly. 19 | 20 | The function should be monotonic (i.e. removing an item from the argument, 21 | shouldn't decrease counts of other items). Otherwise the agent may 22 | go to the item, not take it, and repeat that indefinitely 23 | 24 | weight capacity can be exceeded.
It's only a hint of what the agent wants 25 | ''' 26 | raise NotImplementedError() 27 | 28 | def split(self, items, forced_items, weight_capacity): 29 | ret = self._split(items, forced_items, weight_capacity) 30 | assert None in ret 31 | counts = np.array(list(ret.values())).sum(0) 32 | assert all((0 <= count <= item.count for count, item in zip(counts, items))) 33 | assert all((0 <= c <= item.count for cs in ret.values() for c, item in zip(cs, items))) 34 | assert all((item not in ret or item.is_container() for item in items)) 35 | assert all((item not in ret or ret[item][i] == 0 for i, item in enumerate(items))) 36 | return ret 37 | -------------------------------------------------------------------------------- /autoascend/combat/monster_utils.py: -------------------------------------------------------------------------------- 1 | # heuristic monster types lists 2 | ONLY_RANGED_SLOW_MONSTERS = ['floating eye', 'blue jelly', 'brown mold', 'gas spore', 'acid blob'] 3 | EXPLODING_MONSTERS = ['yellow light', 'gas spore', 'flaming sphere', 'freezing sphere', 'shocking sphere'] 4 | INSECTS = ['giant ant', 'killer bee', 'soldier ant', 'fire ant', 'giant beetle', 'queen bee'] 5 | WEAK_MONSTERS = ['lichen', 'newt', 'shrieker', 'grid bug'] 6 | WEIRD_MONSTERS = ['leprechaun', 'nymph'] 7 | 8 | 9 | def is_monster_faster(agent, monster): 10 | _, y, x, mon, _ = monster 11 | # TODO: implement properly 12 | return 'bat' in mon.mname or 'dog' in mon.mname or 'cat' in mon.mname \ 13 | or 'kitten' in mon.mname or 'pony' in mon.mname or 'horse' in mon.mname \ 14 | or 'bee' in mon.mname or 'fox' in mon.mname 15 | 16 | 17 | def imminent_death_on_melee(agent, monster): 18 | if is_dangerous_monster(monster): 19 | return agent.blstats.hitpoints <= 16 20 | return agent.blstats.hitpoints <= 8 21 | 22 | 23 | def is_dangerous_monster(monster): 24 | _, y, x, mon, _ = monster 25 | is_pet = 'dog' in mon.mname or 'cat' in mon.mname or 'kitten' in mon.mname or 'pony' in mon.mname \ 26 | or 'horse'
in mon.mname 27 | # 'mumak' in mon.mname or 'orc' in mon.mname or 'rothe' in mon.mname \ 28 | # or 'were' in mon.mname or 'unicorn' in mon.mname or 'elf' in mon.mname or 'leocrotta' in mon.mname \ 29 | # or 'mimic' in mon.mname 30 | return is_pet or mon.mname in INSECTS 31 | 32 | 33 | def consider_melee_only_ranged_if_hp_full(agent, monster): 34 | return monster[3].mname in ('brown mold', 'blue jelly') and agent.blstats.hitpoints == agent.blstats.max_hitpoints 35 | -------------------------------------------------------------------------------- /autoascend/monster_tracker/kernels.py: -------------------------------------------------------------------------------- 1 | import numba as nb 2 | import numpy as np 3 | 4 | 5 | @nb.njit('b1[:,:](i2[:,:],i2[:,:],i4)', cache=True) 6 | def disappearance_mask(old_mons, new_mons, max_radius): # True where old_mons holds a glyph (!= -1) that occurs nowhere in new_mons within max_radius cells (square window) 7 | ret = np.zeros_like(new_mons, dtype=nb.b1) 8 | for y in range(new_mons.shape[0]): 9 | for x in range(new_mons.shape[1]): 10 | glyph = old_mons[y, x] 11 | if glyph == -1: 12 | continue 13 | ret[y, x] = (new_mons[max(0, y - max_radius): min(y + max_radius + 1, new_mons.shape[0]), 14 | max(0, x - max_radius): min(x + max_radius + 1, new_mons.shape[1])] != glyph).all() 15 | return ret 16 | 17 | 18 | @nb.njit('optional(b1[:,:])(i2[:,:],i2[:,:],i2[:,:],i4)', cache=True) 19 | def figure_out_monster_movement(peaceful_mons, aggressive_mons, new_mons, max_radius): # Mark each new_mons glyph peaceful if it matches only a nearby peaceful glyph; return None if any glyph matches both or neither 20 | ret_peaceful_mons = np.zeros_like(peaceful_mons, dtype=nb.b1) 21 | for y in range(new_mons.shape[0]): 22 | for x in range(new_mons.shape[1]): 23 | glyph = new_mons[y, x] 24 | if glyph == -1: 25 | continue 26 | 27 | can_be_peaceful = False 28 | can_be_aggressive = False 29 | for py in range(max(0, y - max_radius), 30 | min(y + max_radius + 1, new_mons.shape[0])): 31 | for px in range(max(0, x - max_radius), 32 | min(x + max_radius + 1, new_mons.shape[1])): 33 | if peaceful_mons[py, px] == glyph: 34 | can_be_peaceful = True 35 | if aggressive_mons[py, px] == glyph: 36 |
can_be_aggressive = True 37 | if can_be_peaceful == can_be_aggressive: 38 | return None 39 | if can_be_peaceful: 40 | ret_peaceful_mons[y, x] = True 41 | 42 | return ret_peaceful_mons 43 | -------------------------------------------------------------------------------- /autoascend/visualization/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | FONT_SIZE = 32 5 | 6 | 7 | def put_text(img, text, pos, scale=FONT_SIZE / 35, thickness=1, color=(255, 255, 0), console=False): 8 | # TODO: figure out how exactly opencv anchors the text 9 | pos = (pos[0] + FONT_SIZE // 2, pos[1] + FONT_SIZE // 2 + 8) 10 | 11 | if console: 12 | # TODO: implement equal characters size font 13 | # scale *= 2 14 | # font = cv2.FONT_HERSHEY_PLAIN 15 | font = cv2.FONT_HERSHEY_SIMPLEX 16 | else: 17 | font = cv2.FONT_HERSHEY_SIMPLEX 18 | return cv2.putText(img, text, pos, font, 19 | scale, color, thickness, cv2.LINE_AA) 20 | 21 | 22 | def draw_frame(img, color=(90, 90, 90), thickness=3): 23 | return cv2.rectangle(img, (0, 0), (img.shape[1] - 1, img.shape[0] - 1), color, thickness) 24 | 25 | 26 | def draw_grid(imgs, ncol): 27 | grid = imgs.reshape((-1, ncol, *imgs[0].shape)) 28 | rows = [] 29 | for row in grid: 30 | rows.append(np.concatenate(row, axis=1)) 31 | return np.concatenate(rows, axis=0) 32 | return img 33 | 34 | 35 | class VideoWriter: 36 | def __init__(self, path, fps, resolution=1080): 37 | self.path = path 38 | self.path.parent.mkdir(exist_ok=True, parents=True) 39 | self.out = None # lazy init 40 | self.fps = fps 41 | self.resolution = (round(resolution * 16 / 9), resolution) 42 | 43 | def _make_writer(self, frame): 44 | h, w = frame.shape[:2] 45 | print(f'Initializing video writer with resolution {w}x{h}: {self.path}') 46 | return cv2.VideoWriter(str(self.path), 47 | cv2.VideoWriter_fourcc(*'mp4v'), 48 | self.fps, (w, h)) 49 | 50 | def write(self, frame): 51 | frame = cv2.resize(frame, 
self.resolution) 52 | frame = frame.astype(np.uint8)[..., ::-1] 53 | if self.out is None: 54 | self.out = self._make_writer(frame) 55 | self.out.write(frame) 56 | 57 | def close(self): 58 | self.out.release() -------------------------------------------------------------------------------- /autoascend/soko_solver/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .maps import maps 4 | from .. import utils 5 | 6 | IGNORE = 0 7 | EMPTY = 1 8 | WALL = 2 9 | BOULDER = 3 10 | TARGET = 4 11 | 12 | 13 | class SokoMap: 14 | def __init__(self, pos, sokomap): 15 | self.sokomap = sokomap 16 | self.pos = pos 17 | 18 | def bfs(self): 19 | return utils.bfs(*self.pos, walkable=self.sokomap == EMPTY, 20 | walkable_diagonally=np.zeros_like(self.sokomap, dtype=bool), 21 | can_squeeze=False) 22 | 23 | def move(self, boulder_y, boulder_x, dy, dx): 24 | assert self.sokomap[boulder_y, boulder_x] == BOULDER 25 | assert self.sokomap[boulder_y - dy, boulder_x - dx] == EMPTY 26 | assert self.bfs()[boulder_y - dy, boulder_x - dx] != -1 27 | assert self.sokomap[boulder_y + dy, boulder_x + dx] in [EMPTY, TARGET] 28 | 29 | self.pos = boulder_y, boulder_x 30 | self.sokomap[boulder_y, boulder_x] = EMPTY 31 | if self.sokomap[boulder_y + dy, boulder_x + dx] == EMPTY: 32 | self.sokomap[boulder_y + dy, boulder_x + dx] = BOULDER 33 | elif self.sokomap[boulder_y + dy, boulder_x + dx] == TARGET: 34 | pass 35 | else: 36 | assert 0 37 | 38 | def print(self, pos=None): 39 | mapping = {IGNORE: ' ', EMPTY: '.', WALL: '#', BOULDER: '0', TARGET: '^'} 40 | for y in range(self.sokomap.shape[0]): 41 | for x in range(self.sokomap.shape[1]): 42 | print((mapping[self.sokomap[y, x]] if (y, x) != (pos or self.pos) else '@'), end='') 43 | print() 44 | print() 45 | 46 | 47 | def convert_map(text): 48 | START = -1 49 | mapping = {'<': EMPTY, '>': START, '.': EMPTY, '?': EMPTY, '+': EMPTY, 50 | '0': BOULDER, '-': WALL, '|': WALL, ' ': IGNORE, 
'^': TARGET} 51 | ret = [] 52 | for line in text.splitlines(): 53 | if not line: 54 | continue 55 | ret.append([mapping[l] for l in line]) 56 | ret = np.array(ret) 57 | assert len(list(zip(*(ret == TARGET).nonzero()))) == 1 58 | start = list(zip(*(ret == START).nonzero())) 59 | assert len(start) == 1 60 | start = start[0] 61 | ret[ret == START] = EMPTY 62 | return SokoMap(start, np.array(ret)) 63 | -------------------------------------------------------------------------------- /autoascend/stats_logger.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | from . import character 6 | 7 | 8 | class StatsLogger: 9 | def __init__(self): 10 | self._values = { 11 | "agent_panic": 0, 12 | "elbereth_write": 0, 13 | "container_untrap_success": 0, 14 | "container_untrap_fail": 0, 15 | "untrap_success": 0, 16 | "triggered_undetected_trap": 0, 17 | "allow_walk_traps": 0, 18 | "allow_attack_all": 0, 19 | "sokoban_dropped": 0, 20 | "wait_in_fight": 0, 21 | "melee_gas_spore": 0, 22 | "ad_aerarium_below_me": 0, 23 | "drop_gold": 0, 24 | **{f"cast_{n}": 0 for n in character.ALL_SPELL_NAMES}, 25 | **{f"cast_fail_{n}": 0 for n in character.ALL_SPELL_NAMES}, 26 | } 27 | self._max_values = { 28 | "search_diff": -float('inf'), 29 | } 30 | 31 | self._cumulative_values = { 32 | "max_turns_on_position": defaultdict(int), 33 | } 34 | 35 | self.gold_stats = ['mean', 'median', 'std', 'min', 'max', 'first', 'last'] 36 | self._keys = list(self._values) + list(self._max_values) + list(self._cumulative_values) + self.gold_stats 37 | 38 | self.gold = [] 39 | 40 | def log_cumulative_value(self, name, key, value): 41 | self._cumulative_values[name][key] += value 42 | 43 | def log_event(self, name): 44 | self._values[name] += 1 45 | 46 | def log_gold(self, amount): 47 | self.gold.append(amount) 48 | 49 | def log_max_value(self, name, value): 50 | self._max_values[name] =
max(self._max_values[name], value) 51 | 52 | def get_stats_dict(self): 53 | ret = dict() 54 | ret.update(self._values) 55 | ret.update(self._max_values) 56 | ret.update({k: max(v.values()) for k, v in self._cumulative_values.items()}) 57 | 58 | for stat in self.gold_stats: 59 | try: 60 | ret['gold_' + stat] = getattr(np, stat)(self.gold) 61 | except AttributeError: 62 | if stat == 'first': 63 | ret['gold_' + stat] = max(self.gold[:20]) 64 | elif stat == 'last': 65 | ret['gold_' + stat] = self.gold[-1] 66 | else: 67 | assert 0, stat 68 | return ret 69 | -------------------------------------------------------------------------------- /bin/solve_sokoban.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import termios 4 | import tty 5 | 6 | from autoascend.soko_solver import convert_map, maps 7 | 8 | 9 | def main(): 10 | termios.tcgetattr(sys.stdin) 11 | tty.setcbreak(sys.stdin.fileno()) 12 | try: 13 | # check answers / manually solve 14 | for text_map, answer in maps.items(): 15 | sokomap = convert_map(text_map) 16 | if answer is None: 17 | answer = [] 18 | 19 | answer = answer.copy() 20 | path = [] 21 | while (sokomap.sokomap == BOULDER).sum() > 0 or len(answer) != 0: 22 | print(path) 23 | sokomap.print() 24 | 25 | if len(answer) == 0: 26 | y, x = sokomap.pos 27 | while 1: 28 | sokomap.print((y, x)) 29 | dir = os.read(sys.stdin.fileno(), 3) 30 | mapping = {b'j': (1, 0), b'k': (-1, 0), b'h': (0, -1), b'l': (0, 1)} 31 | if dir not in mapping: 32 | continue 33 | dy, dx = mapping[dir] 34 | if sokomap.sokomap[y + dy, x + dx] == EMPTY: 35 | y, x = y + dy, x + dx 36 | elif sokomap.sokomap[y + dy, x + dx] == BOULDER: 37 | y, x = y + dy, x + dx 38 | break 39 | else: 40 | (y, x), (dy, dx) = answer[0] 41 | answer = answer[1:] 42 | 43 | if sokomap.sokomap[y, x] != BOULDER: 44 | print('that is not a boulder!') 45 | continue 46 | if sokomap.sokomap[y - dy, x - dx] != EMPTY: 47 | print('you cannot stand to push in 
this direction!') 48 | continue 49 | if sokomap.bfs()[y - dy, x - dx] == -1: 50 | print('you cannot get there!') 51 | continue 52 | if sokomap.sokomap[y + dy, x + dx] not in [EMPTY, TARGET]: 53 | print('you cannot push in this direction!') 54 | continue 55 | 56 | path.append(((y, x), (dy, dx))) 57 | 58 | sokomap.move(y, x, dy, dx) 59 | 60 | print(path) 61 | sokomap.print() 62 | print('OK') 63 | finally: 64 | os.system('stty sane') 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /autoascend/level.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | from . import utils 6 | from .glyph import C, G, SHOP, SS 7 | 8 | 9 | class Level: 10 | DUNGEONS_OF_DOOM = 0 11 | GNOMISH_MINES = 2 12 | QUEST = 3 13 | SOKOBAN = 4 14 | 15 | PLANE = 1000 # TODO: fill with actual value 16 | 17 | dungeon_names = {v: k for k, v in locals().items() if not k.startswith('_')} 18 | 19 | def __init__(self, dungeon_number, level_number): 20 | self.dungeon_number = dungeon_number 21 | self.level_number = level_number 22 | 23 | self.walkable = np.zeros((C.SIZE_Y, C.SIZE_X), bool) 24 | self.seen = np.zeros((C.SIZE_Y, C.SIZE_X), bool) 25 | self.objects = np.zeros((C.SIZE_Y, C.SIZE_X), np.int16) 26 | self.objects[:] = -1 27 | self.was_on = np.zeros((C.SIZE_Y, C.SIZE_X), bool) 28 | 29 | self.shop = np.zeros((C.SIZE_Y, C.SIZE_X), bool) 30 | self.shop_interior = np.zeros((C.SIZE_Y, C.SIZE_X), bool) 31 | self.shop_type = np.zeros((C.SIZE_Y, C.SIZE_X), np.int32) + SHOP.UNKNOWN 32 | 33 | self.search_count = np.zeros((C.SIZE_Y, C.SIZE_X), np.int32) 34 | self.door_open_count = np.zeros((C.SIZE_Y, C.SIZE_X), np.int32) 35 | 36 | self.item_disagreement_counter = np.zeros((C.SIZE_Y, C.SIZE_X), np.int32) 37 | self.items = np.empty((C.SIZE_Y, C.SIZE_X), dtype=object) 38 | self.items.fill([]) # NOTE(review): fill() stores one shared list object in every cell; cells must be reassigned, never mutated in place 39 | self.item_count =
np.zeros((C.SIZE_Y, C.SIZE_X), dtype=np.int32) 40 | 41 | self.stair_destination = {} # {(y, x) -> ((dungeon, level), (y, x))} 42 | self.altars = {} # {(y, x) -> alignment} 43 | 44 | self.corpses_to_eat = defaultdict(lambda: defaultdict(lambda: -10000)) # {(y, x) -> {monster_id -> age_turn}} 45 | 46 | # e.g. ad aerarium -- avoid vault entrance 47 | self.forbidden = np.zeros((C.SIZE_Y, C.SIZE_X), bool) 48 | 49 | def key(self): 50 | return (self.dungeon_number, self.level_number) 51 | 52 | def get_stairs(self, down=False, up=False, portal=False, all=False): 53 | # TODO: add portal 54 | if all: 55 | down = up = portal = True 56 | assert down or up or portal 57 | elems = [] 58 | if down: 59 | elems.append(G.STAIR_DOWN) 60 | if up: 61 | elems.append(G.STAIR_UP) 62 | mask = utils.isin(self.objects, *elems) 63 | return {(y, x): self.stair_destination.get((y, x), None) for y, x in zip(*mask.nonzero())} 64 | 65 | def is_light_level(self): 66 | return np.sum(utils.isin(self.objects, [SS.S_room, SS.S_litcorr])) > 15 67 | -------------------------------------------------------------------------------- /bin/get_observations_stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for extracting features stats for featurization for RL experiment. 3 | Uncomment code fragment in autoascend/combat/rl_scoring.py to generate the observations.txt file.
4 | """ 5 | 6 | import base64 7 | import json 8 | import pickle 9 | from collections import defaultdict 10 | 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import seaborn as sns 14 | 15 | 16 | with open('/tmp/vis/observations.txt') as f: 17 | observations = defaultdict(list) 18 | for line in f.readlines(): 19 | observation = pickle.loads(base64.b64decode(line)) 20 | for k, v in observation.items(): 21 | observations[k].append(v) 22 | 23 | 24 | def plot_column(df, column): 25 | if df[column].dtype == object: 26 | plt.xticks(rotation=45) 27 | try: 28 | sns.histplot(df[column]) 29 | except ValueError: 30 | print(f'ValueError when plotting {column}') 31 | except IndexError: 32 | print(f'IndexError when plotting {column}') 33 | 34 | 35 | def plot_df(df_name, df, max_plots_in_row=10): 36 | nrows = (len(df.columns) + max_plots_in_row - 1) // max_plots_in_row 37 | fig = plt.figure(figsize=(8, 2 * nrows), dpi=80) 38 | fig.suptitle(df_name, fontsize=100) 39 | gridspec = fig.add_gridspec(nrows=nrows, ncols=max_plots_in_row) 40 | for i, c in enumerate(df.columns): 41 | row = i // max_plots_in_row 42 | col = i % max_plots_in_row 43 | ax = fig.add_subplot(gridspec[i:i + 1]) 44 | ax.title.set_text(c) 45 | plt.sca(ax) 46 | plot_column(df, c) 47 | 48 | plt.tight_layout() 49 | plt.show() 50 | 51 | 52 | stats = defaultdict(dict) 53 | 54 | 55 | for k, v in observations.items(): 56 | v = np.array(v) 57 | print('----------------------', k, v.shape) 58 | if len(v.shape) > 2: 59 | v = v.transpose((0, 2, 3, 1)) 60 | v = v.reshape(-1, v.shape[-1]) 61 | print(k, v.shape) 62 | print([np.mean(np.isnan(v[:, i])) for i in range(v.shape[1])]) 63 | # plot_df(k, pd.DataFrame(v).sample(10000), 5) 64 | mean, std = np.nanmean(v, axis=0), np.nanstd(v, axis=0) 65 | v_normalized = (v - mean) / std 66 | minv = np.nanmin(v_normalized, axis=0) 67 | print(mean) 68 | print(std) 69 | print(minv) 70 | 71 | stats[k]['mean'] = mean.tolist() 72 | stats[k]['std'] = mean.tolist() 73 | 
stats[k]['min'] = mean.tolist() 74 | 75 | if k == 'heur_action_priorities': 76 | for i in range(v_normalized.shape[1]): 77 | v_normalized[:, i][np.isnan(v_normalized[:, i])] = minv[i] 78 | else: 79 | v_normalized[np.isnan(v_normalized)] = 0 80 | 81 | # plot_df(k + ' normalized', pd.DataFrame(v_normalized).sample(10000), 5) 82 | print() 83 | 84 | 85 | with open('/workspace/muzero/rl_features_stats.json', 'w') as f: 86 | json.dump(stats, f, indent=4) 87 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:21.08-py3 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | RUN apt-get update 5 | RUN apt-get install -y build-essential autoconf libtool pkg-config python3-dev python3-pip python3-numpy git flex \ 6 | bison libbz2-dev xterm gfortran xdot 7 | 8 | COPY requirements.txt . 9 | RUN pip install -r requirements.txt 10 | 11 | COPY muzero/muzero.patch . 
12 | RUN git clone https://github.com/werner-duvaud/muzero-general /muzero && cd /muzero && git checkout 23a1f691 13 | RUN patch -d /muzero muzero.patch 14 | 15 | RUN git clone https://github.com/facebookresearch/nle.git /nle --recursive \ 16 | && cd /nle && git checkout v0.7.3 \ 17 | && sed '/#define NLE_ALLOW_SEEDING 1/i#define NLE_ALLOW_SEEDING 1' /nle/include/nleobs.h -i \ 18 | && sed '/self\.env\.set_initial_seeds = f/d' /nle/nle/env/tasks.py -i \ 19 | && sed '/self\.env\.set_current_seeds = f/d' /nle/nle/env/tasks.py -i \ 20 | && sed '/self\.env\.get_current_seeds = f/d' /nle/nle/env/tasks.py -i \ 21 | && sed '/def seed(self, core=None, disp=None, reseed=True):/d' /nle/nle/env/tasks.py -i \ 22 | && sed '/raise RuntimeError("NetHackChallenge doesn.t allow seed changes")/d' /nle/nle/env/tasks.py -i 23 | 24 | RUN cd /nle && python setup.py install 25 | 26 | # uncomment for PyPy support 27 | # RUN conda create -n pypy pypy 28 | # RUN printf '#!/bin/bash\nexec conda run --no-capture-output -n pypy pypy "$@"' >/usr/bin/pypy3 \ 29 | # && chmod +x /usr/bin/pypy3 \ 30 | # && ln -s /usr/bin/pypy{3,} 31 | # RUN pypy -m ensurepip && pypy -m pip install --upgrade pip 32 | # RUN pypy -m pip install numpy scipy scikit-build toolz pyinstrument 33 | # RUN git clone https://github.com/opencv/opencv-python /opencv-python 34 | # RUN pypy -m pip install scikit-build 35 | # RUN CMAKE_ARGS="-D PYTHON3_LIBRARY=/opt/conda/envs/pypy/lib/libpypy3-c.so" pypy -m pip install opencv-python 36 | # RUN cd /nle && pypy setup.py install 37 | 38 | RUN cd /nle/sys/unix && ./setup.sh 39 | RUN cd /nle/util && sed -i '1s/^/CC := $(CC) -I\/nle\/include -DNOMAIL /' Makefile && make 40 | RUN cd /nle/src && sed -i '1s/^/CC := $(CC) -I\/nle\/include -DNOMAIL /' Makefile && make tile.c 41 | RUN python -c 'from pathlib import Path ; text = Path("/nle/src/tile.c").read_text() ; \ 42 | print("glyph2tile = [", text[text.find("{") + 1 : text.find("};")], "]")' >/glyph2tile.py 43 | RUN cd / && python -c
'import nle ; from glyph2tile import glyph2tile ; \ 44 | assert isinstance(glyph2tile, list) and len(glyph2tile) == nle.nethack.MAX_GLYPH' 45 | 46 | # download tileset 47 | RUN mkdir /tilesets && wget 'https://nethackwiki.com/mediawiki/images/7/73/3.6.1tiles32.png' -P /tilesets 48 | 49 | # uncomment to install jupyter vim plugin 50 | # RUN apt update && apt install -y npm 51 | # RUN pip install -U jupyterlab==1.2.14 52 | # RUN jupyter labextension uninstall jupyterlab-jupytext jupyterlab_tensorboard 53 | # RUN jupyter labextension install jupyterlab_vim 54 | -------------------------------------------------------------------------------- /autoascend/monster_tracker/monster_tracker.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | from nle.nethack import actions as A 5 | 6 | from .kernels import figure_out_monster_movement 7 | from .. import utils 8 | from ..exceptions import AgentPanic 9 | from ..glyph import C, G 10 | 11 | 12 | class MonsterTracker: 13 | def __init__(self, agent): 14 | self.agent = agent 15 | self.on_panic() 16 | 17 | def on_panic(self): 18 | self._last_glyphs = None 19 | self.peaceful_monster_mask = np.zeros((C.SIZE_Y, C.SIZE_X), bool) 20 | self.monster_mask = np.zeros((C.SIZE_Y, C.SIZE_X), bool) 21 | 22 | def take_all_monsters(self): 23 | if utils.any_in(self.agent.glyphs, G.SWALLOW): 24 | return {} 25 | with self.agent.atom_operation(): 26 | self.agent.step(A.Command.WHATIS, iter(['M'])) 27 | if 'No monsters are currently shown on the map.'
in self.agent.message: 28 | return {} 29 | try: 30 | index = self.agent.popup.index('All monsters currently shown on the map:') 31 | except IndexError: 32 | assert 0, (self.agent.message, self.agent.popup) 33 | regex = re.compile(r"^<(\d+),(\d+)> ([\x00-\x7F]) ([a-zA-z-,' ]+)$") 34 | 35 | monsters = {} 36 | for line in self.agent.popup[index + 1:]: 37 | r = regex.search(line) 38 | assert r is not None, line 39 | x, y, char, name = r.groups() 40 | y, x = int(y), int(x) - 1 41 | 42 | # char_on_map = self.agent.last_observation['chars'][y, x] 43 | # assert ord(char) == char_on_map, (char, chr(char_on_map)) 44 | 45 | monsters[y, x] = name 46 | return monsters 47 | 48 | def _get_current_masks(self): 49 | new_monster_mask = utils.isin(self.agent.glyphs, G.MONS, G.INVISIBLE_MON) 50 | new_monster_mask[self.agent.blstats.y, self.agent.blstats.x] = 0 51 | pet_mask = utils.isin(self.agent.glyphs, G.PETS) 52 | 53 | return new_monster_mask, pet_mask 54 | 55 | def update(self): 56 | new_monster_mask, _ = self._get_current_masks() 57 | 58 | if self._last_glyphs is None: 59 | new_peaceful_mons = None 60 | else: 61 | pea_mon = self._last_glyphs.copy() 62 | pea_mon[~self.peaceful_monster_mask] = -1 63 | agr_mon = self._last_glyphs.copy() 64 | agr_mon[~self.monster_mask | self.peaceful_monster_mask] = -1 65 | new_mon = self.agent.glyphs.copy() 66 | new_mon[~new_monster_mask] = -1 67 | new_peaceful_mons = figure_out_monster_movement(pea_mon, agr_mon, new_mon, max_radius=2) 68 | 69 | self.monster_mask = new_monster_mask 70 | self.peaceful_monster_mask.fill(0) 71 | if not self.agent.character.prop.hallu: 72 | if new_peaceful_mons is None: 73 | all_monsters = self.take_all_monsters() 74 | self.monster_mask, pet_mask = self._get_current_masks() # glyphs can change sometimes after calling `take_all_monsters` 75 | for (y, x), name in all_monsters.items(): 76 | if not (self.monster_mask[y, x] or pet_mask[y, x] or (y, x) == ( 77 | self.agent.blstats.y, self.agent.blstats.x)): 78 | raise 
AgentPanic('monsters differs between list and glyphs') 79 | if 'peaceful' in name and not pet_mask[y, x]: 80 | self.peaceful_monster_mask[y, x] = 1 81 | else: 82 | self.peaceful_monster_mask = new_peaceful_mons 83 | # TODO: on hallu no monsters are peaceful 84 | 85 | assert (~self.peaceful_monster_mask | self.monster_mask).all() 86 | self._last_glyphs = self.agent.glyphs.copy() 87 | -------------------------------------------------------------------------------- /muzero/muzero.patch: -------------------------------------------------------------------------------- 1 | diff -u a/self_play.py b/self_play.py 2 | --- a/self_play.py 2022-02-28 20:13:16.761313975 +0100 3 | +++ b/self_play.py 2022-02-26 12:06:11.769504383 +0100 4 | @@ -8,8 +8,7 @@ 5 | import models 6 | 7 | 8 | -@ray.remote 9 | -class SelfPlay: 10 | +class SelfPlayNoRay: 11 | """ 12 | Class which run in a dedicated thread to play games and save them to the replay-buffer. 13 | """ 14 | @@ -107,18 +106,28 @@ 15 | 16 | self.close_game() 17 | 18 | - def play_game( 19 | + def play_game(self, *args, **kwargs): 20 | + it = self.play_game_generator(*args, **kwargs) 21 | + next(it) 22 | + try: 23 | + action = it.send((self.game.reset(), None, None, self.game.to_play(), self.game.legal_actions())) 24 | + while 1: 25 | + action = it.send((*self.game.step(action), self.game.to_play(), self.game.legal_actions())) 26 | + except StopIteration as e: 27 | + return e.value 28 | + 29 | + def play_game_generator( 30 | self, temperature, temperature_threshold, render, opponent, muzero_player 31 | ): 32 | """ 33 | Play one game with actions based on the Monte Carlo tree search at each moves. 
34 | """ 35 | game_history = GameHistory() 36 | - observation = self.game.reset() 37 | + observation, _, _, to_play, legal_actions = yield None 38 | game_history.action_history.append(0) 39 | game_history.observation_history.append(observation) 40 | game_history.reward_history.append(0) 41 | - game_history.to_play_history.append(self.game.to_play()) 42 | + game_history.to_play_history.append(to_play) 43 | 44 | done = False 45 | 46 | @@ -141,12 +150,12 @@ 47 | ) 48 | 49 | # Choose the action 50 | - if opponent == "self" or muzero_player == self.game.to_play(): 51 | + if opponent == "self" or muzero_player == to_play: 52 | root, mcts_info = MCTS(self.config).run( 53 | self.model, 54 | stacked_observations, 55 | - self.game.legal_actions(), 56 | - self.game.to_play(), 57 | + legal_actions, 58 | + to_play, 59 | True, 60 | ) 61 | action = self.select_action( 62 | @@ -160,14 +169,14 @@ 63 | if render: 64 | print(f'Tree depth: {mcts_info["max_tree_depth"]}') 65 | print( 66 | - f"Root value for player {self.game.to_play()}: {root.value():.2f}" 67 | + f"Root value for player {to_play}: {root.value():.2f}" 68 | ) 69 | else: 70 | action, root = self.select_opponent_action( 71 | opponent, stacked_observations 72 | ) 73 | 74 | - observation, reward, done = self.game.step(action) 75 | + observation, reward, done, to_play, legal_actions = yield action 76 | 77 | if render: 78 | print(f"Played action: {self.game.action_to_string(action)}") 79 | @@ -179,7 +188,7 @@ 80 | game_history.action_history.append(action) 81 | game_history.observation_history.append(observation) 82 | game_history.reward_history.append(reward) 83 | - game_history.to_play_history.append(self.game.to_play()) 84 | + game_history.to_play_history.append(to_play) 85 | 86 | return game_history 87 | 88 | @@ -245,6 +254,8 @@ 89 | 90 | return action 91 | 92 | +SelfPlay = ray.remote(SelfPlayNoRay) 93 | + 94 | 95 | # Game independent 96 | class MCTS: 97 | 
-------------------------------------------------------------------------------- /muzero/rl_features_stats.json: -------------------------------------------------------------------------------- 1 | { 2 | "player_scalar_stats": { 3 | "mean": [ 4 | 48.71262741088867, 5 | 50.99024200439453, 6 | 0.9566950798034668, 7 | 0.0403948575258255, 8 | 0.8801918029785156 9 | ], 10 | "std": [ 11 | 48.71262741088867, 12 | 50.99024200439453, 13 | 0.9566950798034668, 14 | 0.0403948575258255, 15 | 0.8801918029785156 16 | ], 17 | "min": [ 18 | 48.71262741088867, 19 | 50.99024200439453, 20 | 0.9566950798034668, 21 | 0.0403948575258255, 22 | 0.8801918029785156 23 | ] 24 | }, 25 | "semantic_maps": { 26 | "mean": [ 27 | 0.445127010345459, 28 | 0.017212659120559692, 29 | 4.9437055587768555 30 | ], 31 | "std": [ 32 | 0.445127010345459, 33 | 0.017212659120559692, 34 | 4.9437055587768555 35 | ], 36 | "min": [ 37 | 0.445127010345459, 38 | 0.017212659120559692, 39 | 4.9437055587768555 40 | ] 41 | }, 42 | "heur_action_priorities": { 43 | "mean": [ 44 | 0.28280264139175415, 45 | 0.25963684916496277, 46 | 0.2609662115573883, 47 | 0.40300971269607544, 48 | 0.2299221009016037, 49 | 0.08346183598041534, 50 | -0.1237427294254303, 51 | 0.34145572781562805, 52 | -15.16542911529541, 53 | -4.844458103179932, 54 | -4.023793697357178, 55 | -11.228976249694824, 56 | -17.46282196044922, 57 | -1.6620674133300781, 58 | -18.367876052856445, 59 | -11.888944625854492, 60 | -7.3320393562316895, 61 | -5.830069065093994, 62 | -4.87669563293457, 63 | -4.094380855560303, 64 | -18.273988723754883, 65 | -4.975750923156738, 66 | -17.4599666595459, 67 | -3.321195125579834 68 | ], 69 | "std": [ 70 | 0.28280264139175415, 71 | 0.25963684916496277, 72 | 0.2609662115573883, 73 | 0.40300971269607544, 74 | 0.2299221009016037, 75 | 0.08346183598041534, 76 | -0.1237427294254303, 77 | 0.34145572781562805, 78 | -15.16542911529541, 79 | -4.844458103179932, 80 | -4.023793697357178, 81 | -11.228976249694824, 82 | -17.46282196044922, 83 
| -1.6620674133300781, 84 | -18.367876052856445, 85 | -11.888944625854492, 86 | -7.3320393562316895, 87 | -5.830069065093994, 88 | -4.87669563293457, 89 | -4.094380855560303, 90 | -18.273988723754883, 91 | -4.975750923156738, 92 | -17.4599666595459, 93 | -3.321195125579834 94 | ], 95 | "min": [ 96 | 0.28280264139175415, 97 | 0.25963684916496277, 98 | 0.2609662115573883, 99 | 0.40300971269607544, 100 | 0.2299221009016037, 101 | 0.08346183598041534, 102 | -0.1237427294254303, 103 | 0.34145572781562805, 104 | -15.16542911529541, 105 | -4.844458103179932, 106 | -4.023793697357178, 107 | -11.228976249694824, 108 | -17.46282196044922, 109 | -1.6620674133300781, 110 | -18.367876052856445, 111 | -11.888944625854492, 112 | -7.3320393562316895, 113 | -5.830069065093994, 114 | -4.87669563293457, 115 | -4.094380855560303, 116 | -18.273988723754883, 117 | -4.975750923156738, 118 | -17.4599666595459, 119 | -3.321195125579834 120 | ] 121 | } 122 | } -------------------------------------------------------------------------------- /autoascend/visualization/scopes.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from .utils import VideoWriter 5 | 6 | 7 | class DrawTilesScope(): 8 | 9 | def __init__(self, visualizer, tiles, color, is_path=False, is_heatmap=False, mode='fill'): 10 | from ..glyph import C # imported here to allow agent reloading 11 | self.visualizer = visualizer 12 | self.is_heatmap = is_heatmap 13 | self.color = color 14 | self.mode = mode 15 | if self.is_heatmap: 16 | assert not is_path 17 | assert self.mode == 'fill' 18 | assert isinstance(self.color, str) 19 | self.tiles = tiles 20 | else: 21 | if isinstance(tiles, np.ndarray) and tiles.shape == (C.SIZE_Y, C.SIZE_X): 22 | self.tiles = list(zip(*tiles.nonzero())) 23 | else: 24 | self.tiles = tiles 25 | self.is_path = is_path 26 | 27 | def draw_fun(self, rendered): 28 | if self.is_heatmap: 29 | if self.is_heatmap: 30 | grayscale = 
np.zeros(rendered.shape, dtype=float) 31 | mask = np.ones_like(grayscale).astype(bool) 32 | for y in range(self.tiles.shape[0]): 33 | for x in range(self.tiles.shape[1]): 34 | y1 = y * self.visualizer.tile_size 35 | x1 = x * self.visualizer.tile_size 36 | slic = (slice(y1, y1 + self.visualizer.tile_size), 37 | slice(x1, x1 + self.visualizer.tile_size)) 38 | if np.isnan(self.tiles[y, x]): 39 | mask[slic] = False 40 | else: 41 | grayscale[slic] = self.tiles[y, x] 42 | grayscale[mask] -= np.min(grayscale[mask]) 43 | grayscale[mask] /= np.max(grayscale[mask]) 44 | grayscale = (grayscale * 255).astype(np.uint8) 45 | grayscale = cv2.blur(grayscale, (15, 15)) 46 | # https://docs.opencv.org/4.5.2/d3/d50/group__imgproc__colormap.html 47 | heatmap = cv2.applyColorMap(grayscale, cv2.__dict__[f'COLORMAP_{self.color.upper()}'])[..., ::-1] 48 | return (rendered * 0.5 + heatmap * 0.5).astype(np.uint8) * mask + (rendered * ~mask) 49 | 50 | color = self.color 51 | alpha = 1 52 | if len(color) == 4: 53 | alpha = color[-1] / 255 54 | color = color[:-1] 55 | 56 | if alpha != 1: 57 | orig_rendered, rendered = rendered, np.zeros_like(rendered) 58 | 59 | if self.is_path: 60 | for p1, p2 in zip(self.tiles, self.tiles[1:]): 61 | p1 = [round((i + 0.5) * self.visualizer.tile_size) for i in p1][::-1] 62 | p2 = [round((i + 0.5) * self.visualizer.tile_size) for i in p2][::-1] 63 | cv2.line(rendered, p1, p2, color, 2) 64 | else: 65 | for p in self.tiles: 66 | p1 = [round(i * self.visualizer.tile_size) for i in p][::-1] 67 | p2 = [round((i + 1) * self.visualizer.tile_size) for i in p][::-1] 68 | if self.mode == 'fill': 69 | cv2.rectangle(rendered, p1, p2, color, -1) 70 | if self.mode == 'frame': 71 | cv2.rectangle(rendered, p1, p2, color, 3) 72 | 73 | if alpha != 1: 74 | rendered = np.clip(orig_rendered.astype(np.int16) + (rendered * alpha).astype(np.int16), 0, 255).astype( 75 | np.uint8) 76 | 77 | return rendered.copy() 78 | 79 | def __enter__(self): 80 | self.fun_instance = lambda x: 
self.draw_fun(x) 81 | self.visualizer.drawers.append(self.fun_instance) 82 | 83 | def __exit__(self, exc_type, exc_val, exc_tb): 84 | self.visualizer.drawers.remove(self.fun_instance) 85 | 86 | 87 | class DebugLogScope(): 88 | 89 | def __init__(self, visualizer, txt, color): 90 | self.visualizer = visualizer 91 | self.txt = txt 92 | self.color = color 93 | 94 | def __enter__(self): 95 | self.visualizer.log_messages.append(self.txt) 96 | 97 | def __exit__(self, exc_type, exc_val, exc_tb): 98 | self.visualizer.log_messages.remove(self.txt) 99 | -------------------------------------------------------------------------------- /autoascend/rl_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class RLModel: 8 | def __init__(self, observation_def, action_space, train=False, training_comm=(None, None)): 9 | # observation_def -- list (name, tuple (shape, dtype)) 10 | self.observation_def = observation_def 11 | self.action_space = action_space 12 | self.train = train 13 | self.input_queue, self.output_queue = None, None 14 | if self.train: 15 | training_comm[0].put(pickle.loads(pickle.dumps(self))) # HACK 16 | self.input_queue, self.output_queue = training_comm 17 | else: 18 | import self_play 19 | import games.nethack 20 | checkpoint = torch.load('/checkpoints/nethack/2021-10-08--16-13-24/model.checkpoint') 21 | config = games.nethack.MuZeroConfig(rl_model=self) 22 | self.inference_iterator = self_play.SelfPlayNoRay(checkpoint, lambda *a: None, config, 0) \ 23 | .play_game_generator(0, 0, False, config.opponent, 0) 24 | assert next(self.inference_iterator) is None 25 | self.is_first_iteration = True 26 | 27 | # def encode_observation(self, observation): 28 | # assert sorted(observation.keys()) == sorted(self.observation_def.keys()) 29 | # ret = [] 30 | # for key, (shape, dtype) in self.observation_def: 31 | # val = observation[key] 32 | # assert val.shape == 
shape, (val.shape, shape) 33 | # ret.append(np.array(list(val.reshape(-1).astype(dtype).tobytes()), dtype=np.uint8)) 34 | # ret = np.concatenate(ret) 35 | # return ret 36 | 37 | def encode_observation(self, observation): 38 | vals = [] 39 | hw_shape = None 40 | for key, (shape, dtype) in self.observation_def: 41 | vals.append(observation[key]) 42 | if hw_shape is not None and len(shape) > 1: 43 | if len(shape) == 2: 44 | assert hw_shape == shape, (hw_shape, shape) 45 | elif len(shape) == 3: 46 | assert hw_shape == shape[1:], (hw_shape, shape) 47 | else: 48 | assert 0, hw_shape 49 | if len(shape) > 1: 50 | if len(shape) == 2: 51 | hw_shape = shape 52 | elif len(shape) == 3: 53 | hw_shape = shape[1:] 54 | else: 55 | assert 0 56 | 57 | vals = [( 58 | val.reshape(val.shape[0], *hw_shape) if len(val.shape) == 3 else 59 | val.reshape(1, *val.shape) if len(val.shape) == 2 else 60 | val.reshape(val.shape[0], 1, 1).repeat(hw_shape[0], 1).repeat(hw_shape[1], 2) 61 | ).astype(np.float32) for val in vals] 62 | return np.concatenate(vals, 0) 63 | 64 | def zero_observation(self): 65 | ret = {} 66 | for key, (shape, dtype) in self.observation_def: 67 | ret[key] = np.zeros(shape=shape, dtype=dtype) 68 | return ret 69 | 70 | def observation_shape(self): 71 | return self.encode_observation(self.zero_observation()).shape 72 | 73 | # def decode_observation(self, data): 74 | # ret = {} 75 | # for key, (shape, dtype) in self.observation_def: 76 | # arr = np.zeros(shape=shape, dtype=dtype) 77 | # s = len(arr.tobytes()) 78 | # ret[key] = np.frombuffer(bytes(data[:s]), dtype=np.dtype).reshape(shape) 79 | # data = data[s:] 80 | # assert len(data) == 0 81 | # return ret 82 | 83 | def choose_action(self, agent, observation, legal_actions): 84 | assert len(legal_actions) > 0 85 | assert all(map(lambda action: action in self.action_space, legal_actions)) 86 | assert len(legal_actions) > 0 87 | legal_actions = [self.action_space.index(action) for action in legal_actions] 88 | if self.train: 89 | 
self.input_queue.put((observation, legal_actions, agent.score)) 90 | action_id = self.output_queue.get() 91 | if action_id is None: 92 | raise KeyboardInterrupt() 93 | else: 94 | action_id = self.inference_iterator.send((self.encode_observation(observation), 0, False, 0, legal_actions)) 95 | assert action_id in legal_actions 96 | return self.action_space[action_id] 97 | -------------------------------------------------------------------------------- /autoascend/glyph/__init__.py: -------------------------------------------------------------------------------- 1 | import nle.nethack as nh 2 | 3 | from . import monster as MON 4 | from . import screen_symbols as SS 5 | 6 | 7 | class WEA: 8 | @staticmethod 9 | def expected_damage(damage_str): 10 | if '-' in damage_str: 11 | raise NotImplementedError() 12 | ret = 0 13 | for word in damage_str.split('+'): 14 | if 'd' in word: 15 | sides = int(word[word.find('d') + 1:]) 16 | mult = word[:word.find('d')] 17 | if not mult: 18 | mult = 1 19 | else: 20 | mult = int(mult) 21 | else: 22 | sides = 1 23 | mult = int(word) 24 | ret += mult * (1 + sides) / 2 25 | return ret 26 | 27 | 28 | class SHOP: 29 | UNKNOWN = 0 30 | # names from nle/src/shknam.c 31 | name2id = { 32 | 'UNKNOWN': UNKNOWN, 33 | "general store": 1, 34 | "used armor dealership": 2, 35 | "second-hand bookstore": 3, 36 | "liquor emporium": 4, 37 | "antique weapons outlet": 5, 38 | "delicatessen": 6, 39 | "jewelers": 7, 40 | "quality apparel and accessories": 8, 41 | "hardware store": 9, 42 | "rare books": 10, 43 | "health food store": 11, 44 | "lighting store": 12, 45 | } 46 | 47 | 48 | class Hunger: 49 | SATIATED = 0 50 | NOT_HUNGRY = 1 51 | HUNGRY = 2 52 | WEAK = 3 53 | FAINTING = 4 54 | 55 | 56 | class C: 57 | SIZE_X = 79 58 | SIZE_Y = 21 59 | 60 | 61 | class G: # Glyphs 62 | FLOOR: ['.'] = frozenset({SS.S_room, SS.S_ndoor, SS.S_darkroom, SS.S_corr, SS.S_litcorr}) 63 | VISIBLE_FLOOR: ['.'] = frozenset({SS.S_room, SS.S_litcorr}) 64 | STONE: [' '] = 
frozenset({SS.S_stone}) 65 | WALL: ['|', '-'] = frozenset({SS.S_vwall, SS.S_hwall, SS.S_tlcorn, SS.S_trcorn, SS.S_blcorn, SS.S_brcorn, 66 | SS.S_crwall, SS.S_tuwall, SS.S_tdwall, SS.S_tlwall, SS.S_trwall}) 67 | STAIR_UP: ['<'] = frozenset({SS.S_upstair, SS.S_upladder}) 68 | STAIR_DOWN: ['>'] = frozenset({SS.S_dnstair, SS.S_dnladder}) 69 | ALTAR: ['_'] = frozenset({SS.S_altar}) 70 | FOUNTAIN = frozenset({SS.S_fountain}) 71 | 72 | DOOR_CLOSED: ['+'] = frozenset({SS.S_vcdoor, SS.S_hcdoor}) 73 | DOOR_OPENED: ['-', '|'] = frozenset({SS.S_vodoor, SS.S_hodoor}) 74 | DOORS = frozenset.union(DOOR_CLOSED, DOOR_OPENED) 75 | 76 | BARS = frozenset({SS.S_bars}) 77 | 78 | MONS = frozenset(MON.ALL_MONS) 79 | PETS = frozenset(MON.ALL_PETS) 80 | WARNING = frozenset({nh.GLYPH_WARNING_OFF + i for i in range(nh.WARNCOUNT)}) 81 | INVISIBLE_MON = frozenset({nh.GLYPH_INVISIBLE, *WARNING}) 82 | 83 | SHOPKEEPER = frozenset({MON.fn('shopkeeper')}) 84 | ORACLE = frozenset({MON.fn('Oracle')}) 85 | GUARD = frozenset({MON.fn('guard')}) 86 | 87 | STATUES = frozenset({i + nh.GLYPH_STATUE_OFF for i in range(nh.NUMMONS)}) 88 | 89 | BODIES = frozenset({nh.GLYPH_BODY_OFF + i for i in range(nh.NUMMONS)}) 90 | OBJECTS = frozenset({nh.GLYPH_OBJ_OFF + i for i in range(nh.NUM_OBJECTS) 91 | if ord(nh.objclass(i).oc_class) != nh.ROCK_CLASS}) 92 | BOULDER = frozenset({nh.GLYPH_OBJ_OFF + i for i in range(nh.NUM_OBJECTS) 93 | if ord(nh.objclass(i).oc_class) == nh.ROCK_CLASS}) 94 | 95 | NORMAL_OBJECTS = frozenset({i for i in range(nh.MAX_GLYPH) if nh.glyph_is_normal_object(i)}) 96 | FOOD_OBJECTS = frozenset({i for i in NORMAL_OBJECTS 97 | if ord(nh.objclass(nh.glyph_to_obj(i)).oc_class) == nh.FOOD_CLASS}) 98 | 99 | TRAPS = frozenset({SS.S_arrow_trap, SS.S_dart_trap, SS.S_falling_rock_trap, SS.S_squeaky_board, SS.S_bear_trap, 100 | SS.S_land_mine, SS.S_rolling_boulder_trap, SS.S_sleeping_gas_trap, SS.S_rust_trap, 101 | SS.S_fire_trap, SS.S_pit, SS.S_spiked_pit, SS.S_hole, SS.S_trap_door, SS.S_teleportation_trap, 
102 | SS.S_level_teleporter, SS.S_magic_portal, SS.S_web, SS.S_statue_trap, SS.S_magic_trap, 103 | SS.S_anti_magic_trap, SS.S_polymorph_trap}) 104 | 105 | SWALLOW = frozenset(range(nh.GLYPH_SWALLOW_OFF, nh.GLYPH_WARNING_OFF)) 106 | 107 | DICT = {k: v for k, v in locals().items() if not k.startswith('_')} 108 | 109 | @classmethod 110 | def assert_map(cls, glyphs, chars): 111 | for glyph, char in zip(glyphs.reshape(-1), chars.reshape(-1)): 112 | char = bytes([char]).decode() 113 | for k, v in cls.__annotations__.items(): 114 | assert glyph not in cls.DICT[k] or char in v, f'{k} {v} {glyph} {char}' 115 | 116 | 117 | G.INV_DICT = {i: [k for k, v in G.DICT.items() if i in v] 118 | for i in set.union(*map(set, G.DICT.values()))} 119 | -------------------------------------------------------------------------------- /autoascend/combat/rl_scoring.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from . import utils 3 | 4 | RL_CONTEXT_SIZE = 7 5 | 6 | 7 | def fight2_action_space(agent): 8 | directions = [(-1, -1), (-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1)] 9 | return [ 10 | *[('move', dy, dx) for dy, dx in directions], 11 | *[('melee', dy, dx) for dy, dx in directions], 12 | *[('ranged', dy, dx) for dy, dx in directions], 13 | # *[('zap', dy, dx) for dy, dx in directions], 14 | # ('pickup',), 15 | ] 16 | 17 | 18 | def init_fight2_model(agent): 19 | from .. 
import rl_utils 20 | agent._fight2_model = rl_utils.RLModel(( 21 | ('player_scalar_stats', ((5,), np.float32)), 22 | ('semantic_maps', ((3, RL_CONTEXT_SIZE, RL_CONTEXT_SIZE), np.float32)), 23 | ('heur_action_priorities', ((8 * 3,), np.float32)), 24 | ), 25 | action_space=fight2_action_space(agent), 26 | train=agent.rl_model_to_train == 'fight2', 27 | training_comm=agent.rl_model_training_comm, 28 | ) 29 | with open('/workspace/muzero/rl_features_stats.json', 'r') as f: 30 | agent._fight2_features_stats = json.load(f) 31 | 32 | 33 | def fight2_player_scalar_stats(agent): 34 | ret = [agent.blstats.hitpoints, 35 | agent.blstats.max_hitpoints, 36 | agent.blstats.hitpoints / agent.blstats.max_hitpoints, 37 | utils.wielding_ranged_weapon(agent), 38 | utils.wielding_melee_weapon(agent)] 39 | ret = np.array(ret, dtype=np.float32) 40 | assert not np.isnan(ret).any() 41 | return ret 42 | 43 | 44 | def fight2_semantic_maps(agent): 45 | radius_y = radius_x = RL_CONTEXT_SIZE // 2 46 | y1, y2, x1, x2 = agent.blstats.y - radius_y, agent.blstats.y + radius_y + 1, \ 47 | agent.blstats.x - radius_x, agent.blstats.x + radius_x + 1 48 | level = agent.current_level() 49 | walkable = level.walkable & ~utils.isin(agent.glyphs, G.BOULDER) & \ 50 | ~agent.monster_tracker.peaceful_monster_mask & \ 51 | ~utils.isin(level.objects, G.TRAPS) 52 | 53 | mspeed = np.ones((C.SIZE_Y, C.SIZE_X), dtype=int) * np.nan 54 | for _, y, x, mon, _ in agent.get_visible_monsters(): 55 | mspeed[y][x] = mon.mmove 56 | 57 | ret = list(map(lambda q: utils.slice_with_padding(q, y1, y2, x1, x2), ( 58 | walkable, agent.monster_tracker.monster_mask, mspeed, 59 | ))) 60 | return np.stack(ret, axis=0).astype(np.float32) 61 | 62 | 63 | def fight2_encoded_heur_action_priorities(agent, heur_priorities): 64 | ret = [] 65 | for action in agent._fight2_model.action_space: 66 | if action in heur_priorities: 67 | ret.append(heur_priorities[action]) 68 | else: 69 | ret.append(np.nan) 70 | return np.array(ret).astype(np.float32) 
71 | 72 | 73 | def fight2_get_observation(agent, heur_priorities): 74 | def normalize(name, features): 75 | mean, std, minv = [agent._fight2_features_stats[name][k] for k in ['mean', 'std', 'min']] 76 | v_normalized = features.copy() 77 | assert len(mean) == features.shape[0], (len(mean), features.shape[0]) 78 | for i in range(features.shape[0]): 79 | v_normalized[i, ...] = (features[i, ...] - mean[i]) / std[i] 80 | if name == 'heur_action_priorities': 81 | for i in range(v_normalized.shape[0]): 82 | if np.isnan(v_normalized[i]): 83 | v_normalized[i] = minv[i] 84 | else: 85 | v_normalized[np.isnan(v_normalized)] = 0 86 | return v_normalized 87 | 88 | return {k: normalize(k, v) for k, v in 89 | [('player_scalar_stats', fight2_player_scalar_stats(agent)), 90 | ('semantic_maps', fight2_semantic_maps(agent)), 91 | ('heur_action_priorities', fight2_encoded_heur_action_priorities(agent, heur_priorities))]} 92 | 93 | 94 | def rl_communicate(agent, actions): 95 | action_priorities_for_rl = dict() 96 | for pr, action in actions: 97 | if action[0] == 'go_to': 98 | continue 99 | if action[0] == 'pickup': 100 | action = (action[0],) 101 | if action[0] == 'zap': 102 | action = action[:3] 103 | if action[0] not in ('zap', 'pickup'): 104 | assert action in agent._fight2_model.action_space, action 105 | action_priorities_for_rl[action] = pr 106 | observation = agent._fight2_get_observation(action_priorities_for_rl) 107 | 108 | # uncomment to gather features for get_observations_stats.py 109 | # import pickle 110 | # import base64 111 | # encoded = base64.b64encode(pickle.dumps(observation)).decode() 112 | # with open('/tmp/vis/observations.txt', 'a', buffering=1) as f: 113 | # f.writelines([encoded + '\n']) 114 | 115 | priority, best_action = max(actions, key=lambda x: x[0]) if actions else None 116 | rl_action = agent._fight2_model.choose_action(agent, observation, list(action_priorities_for_rl.keys())) 117 | # TODO: use RL 118 | best_action = rl_action 119 | return best_action 
120 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoAscend -- 1st place NetHack agent for [the NetHack Challenge at NeurIPS 2021](https://www.aicrowd.com/challenges/neurips-2021-the-nethack-challenge) 2 | 3 | 4 | ## Description 5 | The general overview of the approach can be find [here](https://youtu.be/fVkXE330Bh0?t=4439) (1:14:00 -- 1:21:21). 6 | For more context about the challenge and NetHack see [the entire video](https://www.youtube.com/watch?v=fVkXE330Bh0). 7 | Some example episode visualizations are rendered in [this playlist](https://www.youtube.com/playlist?list=PLJ92BrynhLbdQVcz6-bUAeTeUo5i901RQ). 8 | 9 | 10 | ## Environment 11 | We supply the repo with `Dockerfile` that contains all necessary dependencies to run the code. 12 | 13 | `./bin/docker-build.sh` and `./bin/docker-run.sh` are convinience scripts for building and running the docker container. 14 | Note that they should be run only from the root of the repository. 15 | `./bin/docker-run.sh` mounts X11 socket and Xauthority within the container to enable visualization. 16 | You may need to tune it depending on your X11 configuration. 17 | 18 | In `Dockerfile`, besides only installing dependencies, 19 | the [NLE](https://github.com/facebookresearch/nle) library is pulled and slightly modified 20 | to enable game seeding, glyph to tile mapping is generated, and tileset is downloaded, 21 | muzero is pulled and custom patch applied (needed only for experimental reinforcement learning workflows). 22 | 23 | We encourage using docker, but if you decide that you don't want to use it, be sure to make sure that the environment is compatible, 24 | e.g. NLE version supports seeding, tileset is downloaded and hardcoded path in the code changed, 25 | `autoascend/visualization/glyph2tile.py` is a proper file instead of a symlink. 
26 | 27 | 28 | ## How to run 29 | `./bin/main.py [PARAMS]` is the main entrypoint. It has three modes: 30 | * `simulate` -- a mode that simulates `--episodes` episodes, and saves results to `--simulation-results` json file. 31 | If the file exists at the beginning, it checks which episodes were already simulated not to simulate episode 32 | with the same seed twice. The script uses [Ray](https://www.ray.io/) to allow running episodes in parallel, 33 | and requires Ray instance to be running (for simpliest setup just run `ray start --head` beforehand). 34 | If you're planning to develop the code and add new features make sure that you understand how Ray works, 35 | because in some cases it may not update code properly if you don't restart the server. 36 | * `run` -- a mode that runs a single episode with visualization. 37 | The visualization supports custom input to override agent action. Just type any letter to pass this input to the environment. 38 | If you type `backspace` key, the agent action will be executed. `delete` key works similary, but fast forward 16 frames. 39 | Be aware that using custom input may confuse the agent, which may result in unexpected behavior and exceptions, 40 | so you may consider using `--panic-on-error` flag to handle unexpected errors gracefully. 41 | * `profile` -- a mode that profiles the code. We implemented two profilers (cProfile and pyinstrument) 42 | that can be set with `--profiler` flag. In pyinstrument we customly process/fake tracebacks to adjust 43 | the summary report to our code to be easier to read and understand (refer to the implementation for details). 44 | 45 | 46 | ## Code structure 47 | The base strategy class with description used for defining strategies is defined in `autoascend/strategy.py`. 48 | Strategy consists of entering condition and agent's behavior. The class contains a few methods for controling the flow 49 | and combine strategies together using functional interface (e.g. 
`repeat`, `until`, `preempt`), however strategies 50 | can be also passed into and run inside other strategies in imperative manner if needed. 51 | 52 | The main strategy is defined in `autoascend/global_logic.py:GlobalLogic.global_strategy()` 53 | 54 | * `autoascend/global_logic.py` -- contains definitions of the main strategy and other high-level strategies 55 | (e.g. altar farming, altar item identification, dipping for the Excalibur) 56 | * `autoascend/agent.py` -- definition of the agent class. The agent class contains logic for updating the state of the game, 57 | wraps NLE actions into atomic actions (e.g. untrap, pray, open_door), more complex actions (e.g. go_to), 58 | and defines some low-level stategies. 59 | * `autoascend/item/item.py` -- an item instance. Contains a list of possible glyphs, a list of possible objects, 60 | amount of items, beatitude, information about being equipped, bonuses, etc. 61 | * `autoascend/item/inventory.py` -- item and inventory handling logic. That included atoms and strategies for 62 | handling items taking into account bag's items, arranging items, selecting gear to wear/wield, etc. 63 | * `autoascend/item/inventory_items.py` -- a class representing items that are in player's inventory. 64 | * `autoascend/item/item_manager.py` -- a class for managing general information about items in the game. 65 | That includes item identification helpers, known glyph to object and object to glyph mapping, content of bags, parsing items. 66 | * `autoascend/combat` -- combat behavior and helpers. 67 | * `autoascend/exploration_logic.py` -- exploration specific strategies, including exploration within the level and across levels. 68 | * `autoascend/env_wrapper.py` -- an NLE environment wrapper. That includes utilities for forking the process and reloading the agent. 69 | * `autoascend/glyph` -- hardcoded glyphs with their meaning and related helpers. 70 | * `autoascend/object` -- hardcoded objects with their meaning and related helpers. 
71 | * `autoascend/soko_solver` -- utilities and methods for solving Sokoban. 72 | * `autoascend/visualization` -- episode visualization tool. 73 | -------------------------------------------------------------------------------- /autoascend/glyph/screen_symbol_consts.py: -------------------------------------------------------------------------------- 1 | import nle.nethack as nh 2 | 3 | 4 | S_stone =nh.GLYPH_CMAP_OFF+ 0# 5 | S_vwall =nh.GLYPH_CMAP_OFF+ 1# 6 | S_hwall =nh.GLYPH_CMAP_OFF+ 2# 7 | S_tlcorn =nh.GLYPH_CMAP_OFF+ 3# 8 | S_trcorn =nh.GLYPH_CMAP_OFF+ 4# 9 | S_blcorn =nh.GLYPH_CMAP_OFF+ 5# 10 | S_brcorn =nh.GLYPH_CMAP_OFF+ 6# 11 | S_crwall =nh.GLYPH_CMAP_OFF+ 7# 12 | S_tuwall =nh.GLYPH_CMAP_OFF+ 8# 13 | S_tdwall =nh.GLYPH_CMAP_OFF+ 9# 14 | S_tlwall =nh.GLYPH_CMAP_OFF+ 10# 15 | S_trwall =nh.GLYPH_CMAP_OFF+ 11# 16 | S_ndoor =nh.GLYPH_CMAP_OFF+ 12# 17 | S_vodoor =nh.GLYPH_CMAP_OFF+ 13# 18 | S_hodoor =nh.GLYPH_CMAP_OFF+ 14# 19 | S_vcdoor =nh.GLYPH_CMAP_OFF+ 15# /* closed door, vertical wall */ 20 | S_hcdoor =nh.GLYPH_CMAP_OFF+ 16# /* closed door, horizontal wall */ 21 | S_bars =nh.GLYPH_CMAP_OFF+ 17# /* KMH -- iron bars */ 22 | S_tree =nh.GLYPH_CMAP_OFF+ 18# /* KMH */ 23 | S_room =nh.GLYPH_CMAP_OFF+ 19# 24 | S_darkroom =nh.GLYPH_CMAP_OFF+ 20# 25 | S_corr =nh.GLYPH_CMAP_OFF+ 21# 26 | S_litcorr =nh.GLYPH_CMAP_OFF+ 22# 27 | S_upstair =nh.GLYPH_CMAP_OFF+ 23# 28 | S_dnstair =nh.GLYPH_CMAP_OFF+ 24# 29 | S_upladder =nh.GLYPH_CMAP_OFF+ 25# 30 | S_dnladder =nh.GLYPH_CMAP_OFF+ 26# 31 | S_altar =nh.GLYPH_CMAP_OFF+ 27# 32 | S_grave =nh.GLYPH_CMAP_OFF+ 28# 33 | S_throne =nh.GLYPH_CMAP_OFF+ 29# 34 | S_sink =nh.GLYPH_CMAP_OFF+ 30# 35 | S_fountain =nh.GLYPH_CMAP_OFF+ 31# 36 | S_pool =nh.GLYPH_CMAP_OFF+ 32# 37 | S_ice =nh.GLYPH_CMAP_OFF+ 33# 38 | S_lava =nh.GLYPH_CMAP_OFF+ 34# 39 | S_vodbridge =nh.GLYPH_CMAP_OFF+ 35# 40 | S_hodbridge =nh.GLYPH_CMAP_OFF+ 36# 41 | S_vcdbridge =nh.GLYPH_CMAP_OFF+ 37# /* closed drawbridge, vertical wall */ 42 | S_hcdbridge =nh.GLYPH_CMAP_OFF+ 38# /*
closed drawbridge, horizontal wall */ 43 | S_air =nh.GLYPH_CMAP_OFF+ 39# 44 | S_cloud =nh.GLYPH_CMAP_OFF+ 40# 45 | S_water =nh.GLYPH_CMAP_OFF+ 41# 46 | 47 | #/* end dungeon characters, begin traps */ 48 | 49 | S_arrow_trap =nh.GLYPH_CMAP_OFF+ 42# 50 | S_dart_trap =nh.GLYPH_CMAP_OFF+ 43# 51 | S_falling_rock_trap =nh.GLYPH_CMAP_OFF+ 44# 52 | S_squeaky_board =nh.GLYPH_CMAP_OFF+ 45# 53 | S_bear_trap =nh.GLYPH_CMAP_OFF+ 46# 54 | S_land_mine =nh.GLYPH_CMAP_OFF+ 47# 55 | S_rolling_boulder_trap =nh.GLYPH_CMAP_OFF+ 48# 56 | S_sleeping_gas_trap =nh.GLYPH_CMAP_OFF+ 49# 57 | S_rust_trap =nh.GLYPH_CMAP_OFF+ 50# 58 | S_fire_trap =nh.GLYPH_CMAP_OFF+ 51# 59 | S_pit =nh.GLYPH_CMAP_OFF+ 52# 60 | S_spiked_pit =nh.GLYPH_CMAP_OFF+ 53# 61 | S_hole =nh.GLYPH_CMAP_OFF+ 54# 62 | S_trap_door =nh.GLYPH_CMAP_OFF+ 55# 63 | S_teleportation_trap =nh.GLYPH_CMAP_OFF+ 56# 64 | S_level_teleporter =nh.GLYPH_CMAP_OFF+ 57# 65 | S_magic_portal =nh.GLYPH_CMAP_OFF+ 58# 66 | S_web =nh.GLYPH_CMAP_OFF+ 59# 67 | S_statue_trap =nh.GLYPH_CMAP_OFF+ 60# 68 | S_magic_trap =nh.GLYPH_CMAP_OFF+ 61# 69 | S_anti_magic_trap =nh.GLYPH_CMAP_OFF+ 62# 70 | S_polymorph_trap =nh.GLYPH_CMAP_OFF+ 63# 71 | S_vibrating_square =nh.GLYPH_CMAP_OFF+ 64# /* for display rather than any trap effect */ 72 | 73 | #/* end traps, begin special effects */ 74 | 75 | S_vbeam =nh.GLYPH_CMAP_OFF+ 65# /* The 4 zap beam symbols. Do NOT separate. */ 76 | S_hbeam =nh.GLYPH_CMAP_OFF+ 66# /* To change order or add, see function */ 77 | S_lslant =nh.GLYPH_CMAP_OFF+ 67# /* zapdir_to_glyph() in display.c. */ 78 | S_rslant =nh.GLYPH_CMAP_OFF+ 68# 79 | S_digbeam =nh.GLYPH_CMAP_OFF+ 69# /* dig beam symbol */ 80 | S_flashbeam =nh.GLYPH_CMAP_OFF+ 70# /* camera flash symbol */ 81 | S_boomleft =nh.GLYPH_CMAP_OFF+ 71# /* thrown boomerang, open left, e.g ')' */ 82 | S_boomright =nh.GLYPH_CMAP_OFF+ 72# /* thrown boomerang, open right, e.g.
'(' */ 83 | S_ss1 =nh.GLYPH_CMAP_OFF+ 73# /* 4 magic shield ("resistance sparkle") glyphs */ 84 | S_ss2 =nh.GLYPH_CMAP_OFF+ 74# 85 | S_ss3 =nh.GLYPH_CMAP_OFF+ 75# 86 | S_ss4 =nh.GLYPH_CMAP_OFF+ 76# 87 | S_poisoncloud =nh.GLYPH_CMAP_OFF+ 77# 88 | S_goodpos =nh.GLYPH_CMAP_OFF+ 78# /* valid position for targeting via getpos() */ 89 | 90 | #/* The 8 swallow symbols. Do NOT separate. To change order or add, */ 91 | #/* see the function swallow_to_glyph() in display.c. */ 92 | S_sw_tl =nh.GLYPH_CMAP_OFF+ 79# /* swallow top left [1] */ 93 | S_sw_tc =nh.GLYPH_CMAP_OFF+ 80# /* swallow top center [2] Order: */ 94 | S_sw_tr =nh.GLYPH_CMAP_OFF+ 81# /* swallow top right [3] */ 95 | S_sw_ml =nh.GLYPH_CMAP_OFF+ 82# /* swallow middle left [4] 1 2 3 */ 96 | S_sw_mr =nh.GLYPH_CMAP_OFF+ 83# /* swallow middle right [6] 4 5 6 */ 97 | S_sw_bl =nh.GLYPH_CMAP_OFF+ 84# /* swallow bottom left [7] 7 8 9 */ 98 | S_sw_bc =nh.GLYPH_CMAP_OFF+ 85# /* swallow bottom center [8] */ 99 | S_sw_br =nh.GLYPH_CMAP_OFF+ 86# /* swallow bottom right [9] */ 100 | 101 | S_explode1 =nh.GLYPH_CMAP_OFF+ 87# /* explosion top left */ 102 | S_explode2 =nh.GLYPH_CMAP_OFF+ 88# /* explosion top center */ 103 | S_explode3 =nh.GLYPH_CMAP_OFF+ 89# /* explosion top right Ex.
*/ 104 | S_explode4 =nh.GLYPH_CMAP_OFF+ 90# /* explosion middle left */ 105 | S_explode5 =nh.GLYPH_CMAP_OFF+ 91# /* explosion middle center /-\ */ 106 | S_explode6 =nh.GLYPH_CMAP_OFF+ 92# /* explosion middle right |@| */ 107 | S_explode7 =nh.GLYPH_CMAP_OFF+ 93# /* explosion bottom left \-/ */ 108 | S_explode8 =nh.GLYPH_CMAP_OFF+ 94# /* explosion bottom center */ 109 | S_explode9 =nh.GLYPH_CMAP_OFF+ 95# /* explosion bottom right */ 110 | 111 | #/* end effects */ 112 | 113 | MAXPCHARS = 96# /* maximum number of mapped characters */ -------------------------------------------------------------------------------- /autoascend/item/inventory_items.py: -------------------------------------------------------------------------------- 1 | import nle.nethack as nh 2 | 3 | from autoascend import objects as O 4 | 5 | 6 | class InventoryItems: 7 | def __init__(self, agent): 8 | self.agent = agent 9 | self._previous_inv_strs = None 10 | 11 | self._clear() 12 | 13 | def _clear(self): 14 | self.main_hand = None 15 | self.off_hand = None 16 | self.suit = None 17 | self.helm = None 18 | self.gloves = None 19 | self.boots = None 20 | self.cloak = None 21 | self.shirt = None 22 | 23 | self.total_weight = 0 24 | 25 | self.all_items = [] 26 | self.all_letters = [] 27 | 28 | self._recheck_containers = True 29 | 30 | def __iter__(self): 31 | return iter(self.all_items) 32 | 33 | def __str__(self): 34 | return ( 35 | f'main_hand: {self.main_hand}\n' 36 | f'off_hand : {self.off_hand}\n' 37 | f'suit : {self.suit}\n' 38 | f'helm : {self.helm}\n' 39 | f'gloves : {self.gloves}\n' 40 | f'boots : {self.boots}\n' 41 | f'cloak : {self.cloak}\n' 42 | f'shirt : {self.shirt}\n' 43 | f'Items:\n' + 44 | '\n'.join([f' {l} - {i}' for l, i in zip(self.all_letters, self.all_items)]) 45 | ) 46 | 47 | def total_nutrition(self): 48 | ret = 0 49 | for item in self: 50 | if item.is_food(): 51 | ret += item.object.nutrition * item.count 52 | return ret 53 | 54 | def free_slots(self): 55 | is_coin =
any(item.objs and isinstance(item.objs[0], O.Coin) for item in self) # all_items holds Item wrappers, so check the underlying object type (as update() does) 56 | return 52 + is_coin - len(self.all_items) # a carried coin item does not count against the 52 letter slots 57 | 58 | def on_panic(self): 59 | self._previous_inv_strs = None 60 | self._clear() 61 | 62 | def update(self, force=False): 63 | if force: 64 | self._recheck_containers = True 65 | 66 | if force or self._previous_inv_strs is None or \ 67 | (self.agent.last_observation['inv_strs'] != self._previous_inv_strs).any(): 68 | self._clear() 69 | self._previous_inv_strs = self.agent.last_observation['inv_strs'] # NOTE(review): keeps a reference, not a copy; assumes the observation array is never mutated in place -- confirm 70 | previous_inv_strs = self._previous_inv_strs 71 | 72 | # For some reasons sometime the inventory entries in last_observation may be duplicated 73 | iterable = set() 74 | for item_name, category, glyph, letter in zip( 75 | self.agent.last_observation['inv_strs'], 76 | self.agent.last_observation['inv_oclasses'], 77 | self.agent.last_observation['inv_glyphs'], 78 | self.agent.last_observation['inv_letters']): 79 | item_name = bytes(item_name).decode().strip('\0') 80 | letter = chr(letter) 81 | if not item_name: 82 | continue 83 | iterable.add((item_name, category, glyph, letter)) 84 | iterable = sorted(iterable, key=lambda x: x[-1]) 85 | 86 | assert len(iterable) == len(set(map(lambda x: x[-1], iterable))), \ 87 | 'letters in inventory are not unique' 88 | 89 | for item_name, category, glyph, letter in iterable: 90 | item = self.agent.inventory.item_manager.get_item_from_text(item_name, category=category, 91 | glyph=glyph if not nh.glyph_is_body( 92 | glyph) and not nh.glyph_is_statue( 93 | glyph) else None, 94 | position=None) 95 | 96 | self.all_items.append(item) 97 | self.all_letters.append(letter) 98 | 99 | if item.equipped: 100 | for types, sub, name in [ 101 | ((O.Weapon, O.WepTool), None, 'main_hand'), 102 | (O.Armor, O.ARM_SHIELD, 'off_hand'), # TODO: twoweapon support 103 | (O.Armor, O.ARM_SUIT, 'suit'), 104 | (O.Armor, O.ARM_HELM, 'helm'), 105 | (O.Armor, O.ARM_GLOVES, 'gloves'), 106 | (O.Armor, O.ARM_BOOTS, 'boots'), 107 | (O.Armor, O.ARM_CLOAK, 'cloak'),
108 | (O.Armor, O.ARM_SHIRT, 'shirt'), 109 | ]: 110 | if isinstance(item.objs[0], types) and (sub is None or sub == item.objs[0].sub): 111 | assert getattr(self, name) is None, ((name, getattr(self, name), item), str(self), iterable) 112 | setattr(self, name, item) 113 | break 114 | 115 | if item.is_possible_container() or (item.is_container() and self._recheck_containers): 116 | self.agent.inventory.check_container_content(item) 117 | 118 | if (self.agent.last_observation['inv_strs'] != previous_inv_strs).any(): 119 | self.update() 120 | return 121 | 122 | self.total_weight += item.weight() 123 | # weight is ambiguous for some unidentified items (same appearance, different weights). All such groups: 124 | # {'helmet': 30, 'helm of brilliance': 50, 'helm of opposite alignment': 50, 'helm of telepathy': 50} 125 | # {'leather gloves': 10, 'gauntlets of fumbling': 10, 'gauntlets of power': 30, 'gauntlets of dexterity': 10} 126 | # {'speed boots': 20, 'water walking boots': 15, 'jumping boots': 20, 'elven boots': 15, 'fumble boots': 20, 'levitation boots': 15} 127 | # {'luckstone': 10, 'loadstone': 500, 'touchstone': 10, 'flint': 10} 128 | 129 | self._recheck_containers = False 130 | 131 | def get_letter(self, item): 132 | assert item in self.all_items, (item, self.all_items) 133 | return self.all_letters[self.all_items.index(item)] 134 | -------------------------------------------------------------------------------- /bin/summary.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from collections import Counter 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | HEADER = '-' * 50 10 | 11 | 12 | def extract_last_part_from_exception(text): 13 | return text[max(text.rfind(' File "'), 0):] # whole text when no frame marker is found (rfind returns -1) 14 | 15 | 16 | def extract_last_place_from_exception(text): 17 | return (text[max(text.rfind(' File "'), 0):].splitlines() or [''])[0] # first line of the last frame; '' for empty text 18 | 19 | 20 | def load_df(filepath): 21 | with Path(filepath).open() as f: 22 | df = json.load(f) 23 |
df = pd.DataFrame.from_dict(df) 24 | for k in df.keys(): 25 | df[k] = [tuple(v) if isinstance(v, list) else v for v in df[k]] 26 | return df 27 | 28 | 29 | def give_examples(df, ref_df): 30 | return f'{len(df)}x ({len(df) / len(ref_df) * 100:.1f}%) ({sorted([(t.seed, t.steps, t.score) for t in df.itertuples()][:5], key=lambda x: x[1])})' 31 | 32 | 33 | def print_exceptions(df, ref_df): 34 | print(HEADER, 'EXCEPTIONS:') 35 | counter = Counter([extract_last_place_from_exception(r) for r in df.end_reason if r.startswith('exception:')]) 36 | for k, v in counter.most_common(): 37 | d = df[[r.startswith('exception:') and k == extract_last_place_from_exception(r) for r in df.end_reason]] 38 | print(k, '\n', extract_last_part_from_exception(d.end_reason.iloc[0]), ':', give_examples(d, df)) 39 | print() 40 | print() 41 | print() 42 | 43 | 44 | def get_group_from_end_reason(text): 45 | if text.startswith('exception:'): 46 | return 'exception' 47 | if 'starved' in text or 'while fainted from lack of food' in text: 48 | return 'food' 49 | if 'the shopkeeper' in text: 50 | return 'peaceful_mon' 51 | if 'was poisoned' in text: 52 | if 'corpse' in text or 'glob' in text: 53 | return 'poisoned_food' 54 | else: 55 | return 'poisoned_other' 56 | if 'turned to stone' in text: 57 | return 'stone' 58 | if 'frozen by a monster' in text: 59 | return 'frozen' 60 | if 'while sleeping' in text: 61 | return 'sleeping' 62 | 63 | return 'other' 64 | 65 | 66 | def print_end_reasons(df, ref_df): 67 | print(HEADER, 'END REASONS:') 68 | for end_reason_group, d in sorted(df.groupby('end_reason_group'), key=lambda x: -len(x[1])): 69 | print(' ', 'GROUP:', end_reason_group, ':', give_examples(d, ref_df)) 70 | counter = Counter([r for r in d.end_reason if not r.startswith('exception:')]) 71 | for k, v in counter.most_common(): 72 | d2 = d[[not r.startswith('exception:') and k == r for r in d.end_reason]] 73 | print(' ', k, ':', give_examples(d2, d)) 74 | print() 75 | print() 76 | print() 77 | 78 | 79
| def print_summary(comment, df, ref_df, indent=0): 80 | indent_chars = ' ' * indent 81 | 82 | print(indent_chars + HEADER, f'SUMMARY ({comment}):') 83 | 84 | print(indent_chars + ' ', '*' * 8, 'stats') 85 | for stat_name, stat_values in [ 86 | ('score ', df.score), 87 | *[(f'score-{role} ', df[df.role == role].score) for role in sorted(df.role.unique())], 88 | *[(f'score-mile-{milestone} ', df[df.milestone == milestone].score) for milestone in 89 | sorted(df.milestone.unique())], 90 | ('exp_level ', df.experience_level), 91 | ('dung_level ', df.level_num), 92 | ('runtime_dur ', df.duration), 93 | ] if len(df) else []: # skip stats for an empty subset (e.g. 'score < median'): np.quantile raises on empty input 94 | mean = np.mean(stat_values) 95 | std = np.std(stat_values) 96 | quantiles = np.quantile(stat_values, [0, 0.05, 0.25, 0.5, 0.75, 0.95, 1]) 97 | quantiles = ' '.join((f'{q:6.0f}' for q in quantiles)) 98 | with np.printoptions(precision=3, suppress=True): 99 | print(indent_chars + ' ', stat_name, ':', 100 | f'{mean:6.0f} +/- {std:6.0f} [{quantiles}] ({len(stat_values)}x)') 101 | 102 | print(indent_chars + ' ', '*' * 8, 'end_reasons', give_examples(df, ref_df)) 103 | for end_reason_group, d in sorted(df.groupby('end_reason_group'), key=lambda x: -len(x[1])): 104 | print(indent_chars + ' ', end_reason_group, ':', give_examples(d, df)) 105 | 106 | print() 107 | 108 | 109 | def main(filepath): 110 | df = load_df(filepath) 111 | df.seed = [s[0] for s in df.seed] 112 | df['end_reason_group'] = [get_group_from_end_reason(e) for e in df.end_reason] 113 | df['role'] = [c[:3] for c in df.character] 114 | median = np.median(df.score) 115 | 116 | print_exceptions(df, df) 117 | print_end_reasons(df, df) 118 | 119 | print(HEADER, 'SORTED BY SCORE:') 120 | print(df.sort_values('score')) 121 | print() 122 | 123 | print_summary('all', df, df) 124 | print_summary('score >= median', df[df.score >= median], df) 125 | print_summary('score < median', df[df.score < median], df) 126 | 127 | print(HEADER, 'BY ROLE:') 128 | for k, d in df.groupby('role'): 129 | print_summary(k, d, df,
indent=1) 130 | 131 | print(HEADER, 'BY MILESTONE:') 132 | for k, d in df.groupby('milestone'): 133 | print_summary(f'milestone-{k}', d, df, indent=1) 134 | 135 | print(HEADER, 'TO PASTE:') 136 | std_median = np.std([np.median(np.random.choice(df.score, size=max(1, len(df) // 2))) for _ in range(1024)]) 137 | print(f'median : {np.median(df.score):.1f} +/- {std_median:.1f}') 138 | print(f'mean : {np.mean(df.score):.1f} +/- {np.std(df.score) / (len(df) ** 0.5):.1f}') 139 | for role in sorted(df.role.unique()): 140 | s = df[df.role == role].score 141 | print(f'{role}:{s.median():.0f},{s.mean():.0f}', end='') 142 | if role != 'wiz': 143 | print('|', end='') 144 | if role == 'mon': 145 | print() 146 | print() 147 | print(f'exceptions: {sum((r.startswith("exception:") for r in df.end_reason)) / len(df) * 100:.1f}%') 148 | print(f'avg_turns : {np.mean(df.turns):.1f} +/- {np.std(df.turns) / (len(df) ** 0.5):.1f}') 149 | print(f'avg_steps : {np.mean(df.steps):.1f} +/- {np.std(df.steps) / (len(df) ** 0.5):.1f}') 150 | 151 | 152 | if __name__ == '__main__': 153 | pd.set_option('display.min_rows', 30) 154 | pd.set_option('display.max_rows', 50) 155 | pd.set_option('display.max_columns', None) 156 | pd.set_option('display.width', None) 157 | pd.set_option('display.max_colwidth', 30) 158 | 159 | filepath = '/workspace/nh_sim.json' if len(sys.argv) <= 1 else sys.argv[1] 160 | main(filepath) 161 | -------------------------------------------------------------------------------- /autoascend/objects/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from .data import * 4 | from ..
import utils 5 | 6 | 7 | @utils.copy_result 8 | @functools.lru_cache(len(objects)) 9 | def possibilities_from_glyph(i): 10 | assert nh.glyph_is_object(i) 11 | obj_id = nh.glyph_to_obj(i) 12 | desc = nh.objdescr.from_idx(obj_id).oc_descr or nh.objdescr.from_idx(obj_id).oc_name 13 | cat = ord(nh.objclass(obj_id).oc_class) 14 | 15 | if cat == nh.WEAPON_CLASS: 16 | if desc == 'runed broadsword': 17 | return [objects[obj_id]] 18 | 19 | ret = [o for o in objects if o is not None and (o.desc or o.name) == desc] 20 | assert len(ret) == 1 21 | return ret 22 | 23 | if cat == nh.ARMOR_CLASS: 24 | # https://nethackwiki.com/wiki/Armor 25 | ambiguous_groups = [ 26 | ('piece of cloth', 'opera cloak', 'ornamental cope', 'tattered cape'), 27 | ('plumed helmet', 'etched helmet', 'crested helmet', 'visored helmet'), 28 | ('old gloves', 'padded gloves', 'riding gloves', 'fencing gloves'), 29 | ( 30 | 'mud boots', 'buckled boots', 'riding boots', 'snow boots', 'hiking boots', 'combat boots', 31 | 'jungle boots'), 32 | ] 33 | for group in ambiguous_groups: 34 | if desc in group: 35 | return [o for o in objects if o is not None and o.desc in group] 36 | 37 | # the item is unambiguous or is 'conical hat' 38 | ret = [o for o in objects if o is not None and (o.desc or o.name) == desc] 39 | if desc != 'conical hat': 40 | assert len(ret) == 1, ret 41 | return ret 42 | 43 | if cat in [nh.TOOL_CLASS, nh.FOOD_CLASS]: 44 | return [o for i, o in enumerate(objects) if o is not None and ord(nh.objclass(i).oc_class) == cat and \ 45 | (o.desc or o.name) == desc] 46 | 47 | if cat == nh.GEM_CLASS: 48 | # https://nethackwiki.com/wiki/Gem 49 | desc2names = { 50 | 'black': ['black opal', 'jet', 'obsidian', 'worthless piece of black glass'], 51 | 'blue': ['sapphire', 'turquoise', 'aquamarine', 'fluorite', 'worthless piece of blue glass'], 52 | 'gray': ['luckstone', 'loadstone', 'touchstone', 'flint'], 53 | 'green': ['emerald', 'turquoise', 'aquamarine', 'fluorite', 'jade', 'worthless piece of green glass'],
54 | 'orange': ['jacinth', 'agate', 'worthless piece of orange glass'], 55 | 'red': ['ruby', 'garnet', 'jasper', 'worthless piece of red glass'], 56 | 'rock': ['rock'], 57 | 'violet': ['amethyst', 'fluorite', 'worthless piece of violet glass'], 58 | 'white': ['dilithium crystal', 'diamond', 'opal', 'fluorite', 'worthless piece of white glass'], 59 | 'yellow': ['citrine', 'chrysoberyl', 'worthless piece of yellow glass'], 60 | 'yellowish brown': ['amber', 'topaz', 'worthless piece of yellowish brown glass'], 61 | } 62 | return [from_name(name, cat) for name in desc2names[desc]] 63 | 64 | if cat == nh.AMULET_CLASS: 65 | if desc == 'Amulet of Yendor': 66 | return [from_name('cheap plastic imitation of the Amulet of Yendor'), from_name('Amulet of Yendor')] 67 | return [o for o in objects if isinstance(o, Amulet) and o.name is not None] 68 | 69 | if cat == nh.RING_CLASS: 70 | return [o for o in objects if isinstance(o, Ring) and o.name is not None] 71 | 72 | if cat == nh.COIN_CLASS: 73 | ret = [o for o in objects if isinstance(o, Coin) and o.name is not None] 74 | assert len(ret) == 1 75 | return ret 76 | 77 | if cat == nh.POTION_CLASS: 78 | if desc == 'clear': 79 | return [from_name('water')] 80 | return [o for o in objects if isinstance(o, Potion) and o.name != 'water' and o.name is not None] 81 | 82 | if cat == nh.SCROLL_CLASS: 83 | ambiguous_desc_name = [('stamped', 'mail'), ('unlabeled', 'blank paper')] 84 | for odesc, oname in ambiguous_desc_name: 85 | if desc == odesc: 86 | return [from_name(oname, nh.SCROLL_CLASS)] 87 | 88 | ambiguous_names = [oname for _, oname in ambiguous_desc_name] # these objects always carry the fixed desc above, so any other desc cannot be them 89 | return [o for o in objects if isinstance(o, Scroll) and o.name not in ambiguous_names and o.name is not None] 90 | 91 | if cat == nh.SPBOOK_CLASS: 92 | ambiguous_desc_name = [('plain', 'blank paper'), ('paperback', 'novel'), ('papyrus', 'Book of the Dead')] 93 | for odesc, oname in ambiguous_desc_name: 94 | if desc == odesc: 95 | return [from_name(oname, nh.SPBOOK_CLASS)] 96
| 97 | ambiguous_names = [oname for _, oname in ambiguous_desc_name] # these objects always carry the fixed desc above, so any other desc cannot be them 98 | return [o for o in objects if isinstance(o, Spell) and o.name not in ambiguous_names and o.name is not None] 99 | 100 | if cat == nh.WAND_CLASS: 101 | return [o for o in objects if isinstance(o, Wand) and o.name is not None] 102 | 103 | if cat in [nh.ROCK_CLASS, nh.BALL_CLASS, 16]: 104 | return [o for o in objects if o is not None and (o.desc or o.name) == desc] 105 | 106 | if objects[obj_id] == objects[-1]: 107 | return [objects[-1]] 108 | 109 | assert 0, (obj_id, objects[obj_id], cat) 110 | 111 | 112 | @utils.copy_result 113 | @functools.lru_cache(len(objects)) 114 | def possible_glyphs_from_object(obj): 115 | return [i for i in range(nh.GLYPH_OBJ_OFF, nh.GLYPH_OBJ_OFF + nh.NUM_OBJECTS) 116 | if objects[i - nh.GLYPH_OBJ_OFF] is not None and obj in possibilities_from_glyph(i)] 117 | 118 | 119 | @utils.copy_result 120 | @functools.lru_cache(len(objects) * 20) 121 | def desc_to_glyphs(desc, category=None): 122 | assert desc is not None 123 | ret = [i + nh.GLYPH_OBJ_OFF for i, o in enumerate(objects) 124 | if o is not None and (category is None or ord(nh.objclass(i).oc_class) == category) and o.desc == desc] # category=None matches any class, as in from_name 125 | assert ret 126 | return ret 127 | 128 | 129 | @functools.lru_cache(len(objects) * 100) 130 | def from_name(name, category=None): 131 | ret = [] 132 | for i, o in enumerate(objects): 133 | if o is not None and o.name == name and \ 134 | (category is None or ord(nh.objclass(i).oc_class) == category): 135 | ret.append(o) 136 | 137 | assert len(ret) == 1, (name, category, ret) 138 | return ret[0] 139 | 140 | 141 | @functools.lru_cache(len(objects)) 142 | def get_category(obj): 143 | return ord(nh.objclass(objects.index(obj)).oc_class) 144 | -------------------------------------------------------------------------------- /autoascend/strategy.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | 4 | class Strategy: 5 | """ 6 | A class
representing strategy together with the condition for entering. 7 | 8 | A strategy is defined as a function returning a generator which yields exactly once. 9 | The yielded value indicates the condition for entering the strategy. Before the first yield no agent actions 10 | should be called! 11 | 12 | For example (pseudocode): 13 | ``` 14 | @Strategy.wrap 15 | def brutal_fight_strategy(agent, max_distance): 16 | # check condition 17 | y, x = find_closest_monster() 18 | if y == -1: # no monster on the map 19 | yield False 20 | if max(abs(y - agent.blstats.y), abs(x - agent.blstats.x)) > max_distance: 21 | yield False 22 | 23 | yield True 24 | # execute action 25 | agent.go_to(y, x, stop_one_before=True) 26 | agent.fight(y, x) 27 | ``` 28 | """ 29 | 30 | @classmethod 31 | def wrap(cls, func): 32 | return lambda *a, **k: Strategy(wraps(func)(lambda: func(*a, **k))) 33 | 34 | def __init__(self, strategy, config=None): 35 | self.strategy = strategy 36 | if config is None: 37 | self.config = str(self.strategy) 38 | else: 39 | self.config = config 40 | 41 | def run(self, return_condition=False): 42 | gen = self.strategy() 43 | if not next(gen): 44 | if return_condition: 45 | return False 46 | return None 47 | try: 48 | next(gen) 49 | assert 0 50 | except StopIteration as e: 51 | if return_condition: 52 | return True 53 | return e.value 54 | 55 | def check_condition(self): 56 | gen = self.strategy() 57 | return next(gen) 58 | 59 | def condition(self, condition): 60 | """ Add additional condition for entering the strategy """ 61 | def f(self=self, condition=condition): 62 | if not condition(): 63 | yield False 64 | assert 0 65 | it = self.strategy() 66 | yield next(it) 67 | try: 68 | next(it) 69 | assert 0 70 | except StopIteration as e: 71 | return e.value 72 | 73 | return Strategy(f, {'strategy': self.config, 'condition': str(condition)}) 74 | 75 | def until(self, agent, condition): 76 | """ Run the strategy until `condition` becomes true """ 77 | def f(): 78 | if not condition():
79 | yield False 80 | assert 0 81 | yield True 82 | 83 | strategy = self.condition(lambda: not condition()).preempt(agent, [Strategy(f)], 84 | continue_after_preemption=False) 85 | strategy.config = {'strategy': self.config, 'until': str(condition)} 86 | return strategy 87 | 88 | def before(self, strategy): 89 | """ Run two strategies sequentially (self first, then `strategy`) """ 90 | def f(self=self, strategy=strategy): 91 | yielded = False 92 | r1, r2 = None, None 93 | 94 | v1 = self.strategy() 95 | if next(v1): 96 | if not yielded: 97 | yielded = True 98 | yield True 99 | try: 100 | next(v1) 101 | assert 0, v1 102 | except StopIteration as e: 103 | r1 = e.value 104 | 105 | v2 = strategy.strategy() 106 | if next(v2): 107 | if not yielded: 108 | yielded = True 109 | yield True 110 | try: 111 | next(v2) 112 | assert 0, v2 113 | except StopIteration as e: 114 | r2 = e.value 115 | 116 | if not yielded: 117 | yield False 118 | 119 | return (r1, r2) 120 | 121 | return Strategy(f, {'1': self.config, '2': strategy.config}) 122 | 123 | def preempt(self, agent, strategies, continue_after_preemption=True): 124 | """ Specify other strategies that may preempt the strategy """ 125 | def f(self=self, agent=agent, strategies=strategies): 126 | gen = self.strategy() 127 | condition_passed = False 128 | with agent.disallow_step_calling(): 129 | condition_passed = next(gen) 130 | 131 | if not condition_passed: 132 | yield False 133 | yield True 134 | 135 | assert not agent._no_step_calls 136 | 137 | def f2(): 138 | try: 139 | next(gen) 140 | assert 0, gen 141 | except StopIteration as e: 142 | return e.value 143 | 144 | return agent.preempt(strategies, self, first_func=f2, continue_after_preemption=continue_after_preemption) 145 | 146 | return Strategy(f, {'strategy': self.config, 'preempt': [s.config for s in strategies]}) 147 | 148 | def repeat(self): 149 | """ Repeat the strategy as long as its entering condition is true """ 150 | def f(self=self): 151 | yielded = False 152 | val = None 153 | while 1: 154 | gen =
self.strategy() 155 | if not next(gen): 156 | if not yielded: 157 | yield False 158 | return val # return the last run's result instead of discarding it 159 | 160 | if not yielded: 161 | yielded = True 162 | yield True 163 | 164 | try: 165 | next(gen) 166 | assert 0, gen 167 | except StopIteration as e: 168 | val = e.value 169 | return val 170 | 171 | return Strategy(f, {'repeat': self.config}) 172 | 173 | def every(self, num_of_iterations): 174 | """ 175 | Check the condition only every `num_of_iterations` iterations. Otherwise assume false. 176 | Used for execution time optimization. 177 | """ 178 | current_num = -1 179 | 180 | def f(): 181 | nonlocal current_num 182 | current_num += 1 183 | if current_num % num_of_iterations != 0: 184 | yield False 185 | assert 0 186 | it = self.strategy() 187 | yield next(it) 188 | current_num = -1 189 | try: 190 | next(it) 191 | assert 0 192 | except StopIteration as e: 193 | return e.value 194 | 195 | return Strategy(f, {'strategy': self.config, 'every': num_of_iterations}) 196 | 197 | def __repr__(self): 198 | return str(self.config) 199 | -------------------------------------------------------------------------------- /autoascend/combat/movement_priority.py: -------------------------------------------------------------------------------- 1 | from ..utils import adjacent 2 | from .
import utils 3 | from .monster_utils import WEAK_MONSTERS, ONLY_RANGED_SLOW_MONSTERS, consider_melee_only_ranged_if_hp_full, \ 4 | imminent_death_on_melee, EXPLODING_MONSTERS, WEIRD_MONSTERS 5 | 6 | 7 | def _draw_around(priority, y, x, value, radius=1, operation='add'): 8 | # TODO: optimize 9 | for y1 in range(y - radius, y + radius + 1): 10 | for x1 in range(x - radius, x + radius + 1): 11 | if max(abs(y1 - y), abs(x1 - x)) != radius: 12 | continue 13 | if 0 <= y1 < priority.shape[0] and 0 <= x1 < priority.shape[1]: 14 | if operation == 'add': 15 | priority[y1, x1] += value 16 | elif operation == 'max': 17 | priority[y1, x1] = max(priority[y1, x1], value) 18 | else: 19 | assert 0, operation 20 | 21 | 22 | def _draw_ranged(priority, y, x, value, walkable, radius=1, operation='add'): 23 | # TODO: optimize 24 | for direction_y in (-1, 0, 1): 25 | for direction_x in (-1, 0, 1): 26 | if direction_y != 0 or direction_x != 0: 27 | for i in range(1, radius + 1): 28 | y1 = y + direction_y * i 29 | x1 = x + direction_x * i 30 | if 0 <= y1 < priority.shape[0] and 0 <= x1 < priority.shape[1]: 31 | if not walkable[y1, x1]: 32 | break 33 | if operation == 'add': 34 | priority[y1, x1] += value 35 | elif operation == 'max': 36 | priority[y1, x1] = max(priority[y1, x1], value) 37 | else: 38 | assert 0, operation 39 | 40 | 41 | def draw_monster_priority_positive(agent, monster, priority, walkable): 42 | _, y, x, mon, _ = monster 43 | 44 | # don't move into the monster 45 | priority[y, x] = float('nan') 46 | 47 | if mon.mname in WEAK_MONSTERS: 48 | # weak monster - freely engage in melee 49 | _draw_around(priority, y, x, 2, radius=1, operation='max') 50 | _draw_around(priority, y, x, 1, radius=2, operation='max') 51 | elif 'mold' in mon.mname and mon.mname not in ONLY_RANGED_SLOW_MONSTERS: 52 | if agent.blstats.hitpoints >= 15 or agent.blstats.hitpoints == agent.blstats.max_hitpoints: 53 | # freely engage in melee 54 | _draw_around(priority, y, x, 2, radius=1, operation='max') 55 | 
_draw_around(priority, y, x, 1, radius=2, operation='max') 56 | if len(agent.inventory.get_ranged_combinations()): 57 | _draw_ranged(priority, y, x, 1, walkable, radius=7, operation='max') 58 | elif mon.mname in ONLY_RANGED_SLOW_MONSTERS: # and agent.inventory.get_ranged_combinations(): 59 | if consider_melee_only_ranged_if_hp_full(agent, monster): 60 | _draw_around(priority, y, x, 2, radius=1, operation='max') 61 | _draw_around(priority, y, x, 1, radius=2, operation='max') 62 | if len(agent.inventory.get_ranged_combinations()): 63 | _draw_ranged(priority, y, x, 1, walkable, radius=7, operation='max') 64 | elif 'unicorn' in mon.mname: 65 | if agent.blstats.hitpoints >= 15 or agent.blstats.hitpoints == agent.blstats.max_hitpoints: 66 | # freely engage in melee 67 | _draw_around(priority, y, x, 2, radius=1, operation='max') 68 | _draw_around(priority, y, x, 1, radius=2, operation='max') 69 | else: 70 | if not imminent_death_on_melee(agent, monster) and not utils.wielding_ranged_weapon(agent): 71 | # engage, but ensure striking first if possible 72 | if mon.mmove <= 12: 73 | _draw_around(priority, y, x, 3, radius=2, operation='max') 74 | else: 75 | _draw_around(priority, y, x, 3, radius=3, operation='max') 76 | if utils.wielding_ranged_weapon(agent): 77 | _draw_ranged(priority, y, x, 4, walkable, radius=7, operation='max') 78 | elif len(agent.inventory.get_ranged_combinations()): 79 | _draw_ranged(priority, y, x, 1, walkable, radius=7, operation='max') 80 | 81 | 82 | def draw_monster_priority_negative(agent, monster, priority, walkable): 83 | _, y, x, mon, _ = monster 84 | 85 | if imminent_death_on_melee(agent, monster) and not mon.mname in WEAK_MONSTERS \ 86 | and not mon.mname in ONLY_RANGED_SLOW_MONSTERS: 87 | if mon.mmove <= 12: 88 | _draw_around(priority, y, x, -10, radius=1) 89 | else: 90 | if adjacent((agent.blstats.y, agent.blstats.x), (y, x)): 91 | # no point in running -- monster is fast 92 | pass 93 | else: 94 | _draw_around(priority, y, x, -10, radius=2) 
95 | _draw_around(priority, y, x, -5, radius=1) 96 | 97 | if not len(agent.inventory.get_ranged_combinations()): 98 | # prefer avoiding being in line of fire 99 | _draw_ranged(priority, y, x, -1, walkable, radius=7) 100 | 101 | # if agent.blstats.hitpoints <= 8 and not is_monster_faster(agent, monster) and not mon.mname in WEAK_MONSTERS \ 102 | # and not mon.mname in ONLY_RANGED_SLOW_MONSTERS: 103 | # # stay out of melee range 104 | # _draw_around(priority, y, x, -10, radius=1) 105 | # if not len(agent.inventory.get_ranged_combinations()): 106 | # # prefer avoiding being in line of fire 107 | # _draw_ranged(priority, y, x, -1, walkable, radius=7) 108 | 109 | if mon.mname in EXPLODING_MONSTERS: 110 | _draw_around(priority, y, x, -10, radius=1) 111 | if mon.mname not in ONLY_RANGED_SLOW_MONSTERS: 112 | _draw_around(priority, y, x, -5, radius=2) 113 | _draw_ranged(priority, y, x, 4, walkable, radius=7) 114 | elif 'mold' in mon.mname and mon.mname not in ONLY_RANGED_SLOW_MONSTERS: 115 | # prioritize staying in ranged weapons line of fire 116 | if len(agent.inventory.get_ranged_combinations()): 117 | _draw_ranged(priority, y, x, 2, walkable, radius=7) 118 | elif mon.mname in WEIRD_MONSTERS: 119 | # stay away 120 | _draw_around(priority, y, x, -10, radius=1) 121 | # prioritize staying in ranged weapons line of fire 122 | if len(agent.inventory.get_ranged_combinations()): 123 | _draw_ranged(priority, y, x, 6, walkable, radius=7) 124 | elif mon.mname in ONLY_RANGED_SLOW_MONSTERS: # and agent.inventory.get_ranged_combinations(): 125 | # ignore 126 | pass 127 | elif 'unicorn' in mon.mname: 128 | pass 129 | else: 130 | if mon.mname not in WEAK_MONSTERS: 131 | # engage, but ensure striking first if possible 132 | _draw_around(priority, y, x, -9, radius=1) 133 | if not len(agent.inventory.get_ranged_combinations()): 134 | _draw_ranged(priority, y, x, -1, walkable, radius=7) 135 | 136 | if mon.mname == 'purple worm' and len(agent.inventory.get_ranged_combinations()): 137 | 
_draw_around(priority, y, x, -10, radius=1) 138 | -------------------------------------------------------------------------------- /autoascend/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from collections import Counter 3 | from functools import partial, wraps 4 | from itertools import chain 5 | 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | import numba as nb 9 | import numpy as np 10 | import seaborn as sns 11 | import toolz 12 | 13 | from .strategy import Strategy 14 | 15 | 16 | @nb.njit(cache=True) 17 | def bfs(y, x, *, walkable, walkable_diagonally, can_squeeze): 18 | dis = np.zeros(walkable.shape, dtype=np.int32) 19 | dis[:] = -1 20 | dis[y, x] = 0 21 | 22 | buf = np.zeros((walkable.shape[0] * walkable.shape[1], 2), dtype=np.uint32) 23 | index = 0 24 | buf[index] = (y, x) 25 | size = 1 26 | while index < size: 27 | y, x = buf[index] 28 | index += 1 29 | 30 | for dy in [-1, 0, 1]: 31 | for dx in [-1, 0, 1]: 32 | py, px = y + dy, x + dx 33 | if 0 <= py < walkable.shape[0] and 0 <= px < walkable.shape[1] and (dy != 0 or dx != 0): 34 | if (walkable[py, px] and 35 | (abs(dy) + abs(dx) <= 1 or 36 | (walkable_diagonally[py, px] and walkable_diagonally[y, x] and 37 | (can_squeeze or walkable[py, x] or walkable[y, px])))): 38 | if dis[py, px] == -1: 39 | dis[py, px] = dis[y, x] + 1 40 | buf[size] = (py, px) 41 | size += 1 42 | 43 | return dis 44 | 45 | 46 | def translate(array, y_offset, x_offset, out=None): 47 | if out is None: 48 | out = np.zeros_like(array) 49 | else: 50 | out.fill(0) 51 | 52 | if y_offset > 0: 53 | array = array[:-y_offset] 54 | elif y_offset < 0: 55 | array = array[-y_offset:] 56 | if x_offset > 0: 57 | array = array[:, :-x_offset] 58 | elif x_offset < 0: 59 | array = array[:, -x_offset:] 60 | 61 | sy, sx = max(y_offset, 0), max(x_offset, 0) 62 | out[sy: sy + array.shape[0], sx: sx + array.shape[1]] = array 63 | return out 64 | 65 | 66 | 
@nb.njit('b1[:,:](i2[:,:],i2,i2,b1[:])', cache=True)
def _isin_kernel(array, mi, ma, mask):
    """Mark cells of a 2-D int16 `array` whose value v satisfies mi <= v <= ma and mask[v - mi]."""
    ret = np.zeros(array.shape, dtype=nb.b1)
    for y in range(array.shape[0]):
        for x in range(array.shape[1]):
            if array[y, x] < mi or array[y, x] > ma:
                continue
            ret[y, x] = mask[array[y, x] - mi]
    return ret


@functools.lru_cache(1024)
def _isin_mask(elems):
    """Build (and memoize) the (min, max, mask) lookup for a hashable tuple of element collections."""
    elems = np.array(list(chain(*elems)), np.int16)
    return _isin_mask_kernel(elems)


@nb.njit('Tuple((i2,i2,b1[:]))(i2[:])', cache=True)
def _isin_mask_kernel(elems):
    """Return (mi, ma, mask) where mask[v - mi] is True exactly for the values v in `elems`.

    NOTE(review): an empty `elems` leaves ma < mi, so the mask size below is negative.
    """
    mi: i2 = 32767
    ma: i2 = -32768
    for i in range(elems.shape[0]):
        if mi > elems[i]:
            mi = elems[i]
        if ma < elems[i]:
            ma = elems[i]
    ret = np.zeros(ma - mi + 1, dtype=nb.b1)
    for i in range(elems.shape[0]):
        ret[elems[i] - mi] = True
    return mi, ma, ret


def isin(array, *elems):
    """Return a bool mask marking cells of the int16 `array` whose value is in any of `elems`.

    Each element of `elems` is a collection of values; it is normalized to a hashable form so the
    lookup mask can be memoized by `_isin_mask`.
    """
    assert array.dtype == np.int16

    # for memoization
    elems = tuple((
        e if isinstance(e, tuple) else
        e if isinstance(e, frozenset) else
        tuple(e) if isinstance(e, list) else
        frozenset(e) if isinstance(e, set) else
        e
        for e in elems))

    mi, ma, mask = _isin_mask(elems)
    return _isin_kernel(array, mi, ma, mask)


def any_in(array, *elems):
    """Return True if any cell of `array` holds a value from `elems` (see `isin`)."""
    # TODO: optimize
    return isin(array, *elems).any()


@toolz.curry
def debug_log(txt, fun, color=(255, 255, 255)):
    """Curried method decorator that runs the call inside `env.debug_log(txt=txt, color=color)`.

    The env is taken from `self.env` when `self` is the Agent, otherwise from `self.agent.env`.
    If the method returns a Strategy, everything the strategy does after its first yield
    (the condition check) is also run inside a debug-log context.
    """
    @wraps(fun)
    def wrapper(self, *args, **kwargs):
        # TODO: make it cleaner
        if type(self).__name__ != 'Agent':
            env = self.agent.env
        else:
            env = self.env

        with env.debug_log(txt=txt, color=color):
            ret = fun(self, *args, **kwargs)
        if isinstance(ret, Strategy):
            def f(strategy=ret.strategy, *a, **k):
                it = strategy(*a, **k)
                yield next(it)
                with env.debug_log(txt=txt, color=color):
                    try:
                        next(it)
                        assert 0
                    except StopIteration as e:
                        return e.value

            ret.strategy = partial(f, ret.strategy)
        return ret

    return wrapper


def adjacent(p1, p2):
    """Return True if (y, x) points p1 and p2 are at Chebyshev distance exactly 1."""
    return max(abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) == 1


def calc_dps(to_hit, damage):
    """Scale `damage` by clamp(to_hit - 1, 0, 20) / 20 (presumably a d20 hit chance -- confirm)."""
    return damage * min(20, max(0, (to_hit - 1))) / 20


@Strategy.wrap
def assert_strategy(error=None):
    """Strategy whose condition always passes and whose body then fails with `error`."""
    yield True
    assert 0, error


def copy_result(func):
    """Decorator returning a shallow copy of `func`'s result.

    Lists are copied; tuples are rebuilt with their list members copied; any other
    result must provide a `.copy()` method.
    """
    @wraps(func)
    def f(*args, **kwargs):
        ret = func(*args, **kwargs)
        if isinstance(ret, list):
            return ret.copy()
        if isinstance(ret, tuple):
            return tuple((x.copy() if isinstance(x, list) else x for x in ret))
        return ret.copy()

    return f


def dilate(mask, radius=1, with_diagonal=True):
    """Binary-dilate `mask` by `radius` cells.

    Uses a full (2r+1)x(2r+1) square kernel, or a cross-shaped kernel (no diagonals)
    when `with_diagonal` is False. Returns a bool array.
    """
    d = radius * 2 + 1
    if with_diagonal:
        kernel = np.ones((d, d), dtype=np.uint8)
    else:
        kernel = np.zeros((d, d), dtype=np.uint8)
        kernel[radius: radius + 1, :] = 1
        kernel[:, radius: radius + 1] = 1
    return cv2.dilate(mask.astype(np.uint8), kernel).astype(bool)


def slice_with_padding(array, a1, a2, b1, b2, pad_value=0):
    """Return array[a1:a2, b1:b2] as a (a2-a1, b2-b1) array, filling out-of-range cells with `pad_value`."""
    ret = np.zeros_like(array, shape=(a2 - a1, b2 - b1)) + pad_value
    off_a1 = -a1 if a1 < 0 else 0
    off_b1 = -b1 if b1 < 0 else 0
    # a negative end offset trims the part of the window that lies past the array's edge
    off_a2 = array.shape[0] - a2 if a2 > array.shape[0] else ret.shape[0]
    off_b2 = array.shape[1] - b2 if b2 > array.shape[1] else ret.shape[1]
    ret[off_a1: off_a2, off_b1: off_b2] = array[max(0, a1): a2, max(0, b1): b2]
    return ret


def slice_square_with_padding(array, center_y, center_x, radius, pad_value=0):
    """Return the padded (2r+1)x(2r+1) window of `array` centred at (center_y, center_x)."""
    return slice_with_padding(array, center_y - radius, center_y + radius + 1,
                              center_x - radius, center_x + radius + 1, pad_value=pad_value)


def plot_dashboard(fig, res):
    """Draw a summary dashboard of run results into `fig`.

    `res` maps column names to per-run lists ('score', 'steps', 'turns', 'level_num',
    'experience_level', 'milestone', 'character'). NOTE: adds the derived columns 'role', 'race',
    'gender', 'alignment' and 'race-alignment' to `res` in place, parsed from the
    dash-separated 'character' strings.
    """
    histogram_keys = ['score', 'steps', 'turns', 'level_num', 'experience_level', 'milestone']
    spec = fig.add_gridspec(len(histogram_keys) + 2, 2)
    for i, k in enumerate(histogram_keys):
        ax = fig.add_subplot(spec[i, 0])
        ax.set_title(k)
        if isinstance(res[k][0], str):
            counter = Counter(res[k])
            most_common = counter.most_common()
            # draw on this subplot explicitly rather than on pyplot's "current" axes
            sns.barplot(x=[name for name, _ in most_common], y=[count for _, count in most_common], ax=ax)
        else:
            if k in ['level_num', 'experience_level', 'milestone']:
                bins = [b + 0.5 for b in range(max(res[k]) + 1)]
            else:
                bins = np.quantile(res[k],
                                   np.linspace(0, 1, min(len(res[k]) // (20 + len(res[k]) // 50) + 2, 50)))
            sns.histplot(res[k], bins=bins, stat='density', ax=ax)
            if k == 'milestone':
                ticks = sorted({(int(m), str(m)) for m in res[k]})
                ax.set_xticks([t[0] for t in ticks])
                ax.set_xticklabels([t[1] for t in ticks])
    ax = fig.add_subplot(spec[:len(histogram_keys) // 2, 1])
    sns.scatterplot(x='turns', y='steps', data=res, ax=ax)
    ax = fig.add_subplot(spec[len(histogram_keys) // 2: -2, 1])
    sns.scatterplot(x='turns', y='score', data=res, ax=ax)
    ax = fig.add_subplot(spec[-2:, :])
    res['role'] = [h.split('-')[0] for h in res['character']]
    res['race'] = [h.split('-')[1] for h in res['character']]
    res['gender'] = [h.split('-')[2] for h in res['character']]
    res['alignment'] = [h.split('-')[3] for h in res['character']]
    res['race-alignment'] = [f'{r}-{a}' for r, a in zip(res['race'], res['alignment'])]
    sns.violinplot(x='role', y='score', color='white', hue='gender',
                   hue_order=sorted(set(res['gender'])), split=len(set(res['gender'])) == 2,
                   order=sorted(set(res['role'])), inner='quartile',
                   data=res, ax=ax)
    palette = ['#ff7043', '#cc3311', '#ee3377', '#0077bb', '#33bbee', '#009988', '#bbbbbb']
    sns.swarmplot(x='role', y='score', hue='race-alignment', hue_order=sorted(set(res['race-alignment'])),
                  order=sorted(set(res['role'])),
                  data=res, ax=ax, palette=palette)

# --------------------------------------------------------------------------------
# /autoascend/item/item.py:
# --------------------------------------------------------------------------------
import nle.nethack as nh

from autoascend import objects as O
from autoascend.glyph import MON, WEA


class Item:
    """An item as perceived by the agent.

    `objs` lists every object type the item may still be (a single entry means the identity is
    unambiguous); `glyphs` are the object glyphs it was displayed with. The remaining
    attributes record what has been observed about it.
    """
    # beatitude
    UNKNOWN = 0
    CURSED = 1
    UNCURSED = 2
    BLESSED = 3

    # shop status
    NOT_SHOP = 0
    FOR_SALE = 1
    UNPAID = 2

    def __init__(self, objs, glyphs, count=1, status=UNKNOWN, modifier=None, equipped=False, at_ready=False,
                 monster_id=None, shop_status=NOT_SHOP, price=0, dmg_bonus=None, to_hit_bonus=None,
                 naming='', comment='', uses=None, text=None):
        assert isinstance(objs, list) and len(objs) >= 1
        assert isinstance(glyphs, list) and len(glyphs) >= 1 and all((nh.glyph_is_object(g) for g in glyphs))
        assert isinstance(count, int)

        self.objs = objs
        self.glyphs = glyphs
        self.count = count
        self.status = status
        self.modifier = modifier
        self.equipped = equipped
        self.uses = uses
        self.at_ready = at_ready
        self.monster_id = monster_id
        self.shop_status = shop_status
        self.price = price
        self.dmg_bonus = dmg_bonus
        self.to_hit_bonus = to_hit_bonus
        self.naming = naming
        self.comment = comment
        self.text = text

        self.content = None  # for checked containers it will be set after the constructor
        self.container_id = None  # for containers and possible containers it will be set after the constructor

        self.category = O.get_category(self.objs[0])
        # every candidate glyph must belong to the same object class as the first candidate object
        assert all((ord(nh.objclass(nh.glyph_to_obj(g)).oc_class) == self.category for g in self.glyphs))

    def display_glyphs(self):
        """Return the glyphs this item is drawn with (monster-specific for corpses and statues)."""
        if self.is_corpse():
            assert self.monster_id is not None
            return [nh.GLYPH_BODY_OFF + self.monster_id]
        if self.is_statue():
            assert self.monster_id is not None
            return [nh.GLYPH_STATUE_OFF + self.monster_id]
        return self.glyphs

    def is_unambiguous(self):
        """Return True if exactly one candidate object type remains."""
        return len(self.objs) == 1

    def can_be_dropped_from_inventory(self):
        """Return False for equipped cursed weapons, worn armor, cursed loadstones and an equipped ball."""
        return not (
            (isinstance(self.objs[0], (O.Weapon, O.WepTool)) and self.status == Item.CURSED and self.equipped) or
            (isinstance(self.objs[0], O.Armor) and self.equipped) or
            (self.is_unambiguous() and self.object == O.from_name('loadstone') and self.status == Item.CURSED) or
            (self.category == nh.BALL_CLASS and self.equipped)
        )

    def weight(self, with_content=True):
        """Return the total weight of the stack (`count` * `unit_weight`)."""
        return self.count * self.unit_weight(with_content=with_content)

    def unit_weight(self, with_content=True):
        """Return the weight of a single item.

        Unchecked possible containers and globs get large placeholder weights. Ambiguous items
        use the heaviest candidate. Checked containers include their content when `with_content`.
        """
        if self.is_corpse():
            return MON.permonst(self.monster_id).cwt

        if self.is_possible_container():
            return 100000

        if self.objs[0] in [
            O.from_name("glob of gray ooze"),
            O.from_name("glob of brown pudding"),
            O.from_name("glob of green slime"),
            O.from_name("glob of black pudding"),
        ]:
            assert self.is_unambiguous()
            return 10000  # weight is unknown

        weight = max((obj.wt for obj in self.objs))

        if self.is_container() and with_content:
            weight += self.content.weight()  # TODO: bag of holding

        return weight

    @property
    def object(self):
        """The single object type of an unambiguous item (asserts unambiguity)."""
        assert self.is_unambiguous()
        return self.objs[0]

    ######## WEAPON

    def is_weapon(self):
        """Return True if the item belongs to the weapon class."""
        return self.category == nh.WEAPON_CLASS

    def get_weapon_bonus(self, large_monster):
        """Return the worst-case (to_hit, damage) over all candidate weapon types.

        Enchantment adds to to-hit and (if positive) to damage; known dmg/to-hit bonuses are added too.
        """
        assert self.is_weapon()

        hits = []
        dmgs = []
        for weapon in self.objs:
            dmg = WEA.expected_damage(weapon.damage_large if large_monster else weapon.damage_small)
            to_hit = 1 + weapon.hitbon
            if self.modifier is not None:
                dmg += max(0, self.modifier)
                to_hit += self.modifier

            dmg += 0 if self.dmg_bonus is None else self.dmg_bonus
            to_hit += 0 if self.to_hit_bonus is None else self.to_hit_bonus

            dmgs.append(dmg)
            hits.append(to_hit)

        # assume the worse
        return min(hits), min(dmgs)

    def is_launcher(self):
        """Return True for an unambiguous bow, crossbow or sling."""
        if not self.is_weapon() or not self.is_unambiguous():
            return False

        return self.object.name in ['bow', 'elven bow', 'orcish bow', 'yumi', 'crossbow', 'sling']

    def is_fired_projectile(self, launcher=None):
        """Return True if this is ammunition (for `launcher` when given; sling ammo not supported)."""
        if not self.is_weapon() or not self.is_unambiguous():
            return False

        arrows = ['arrow', 'elven arrow', 'orcish arrow', 'silver arrow', 'ya']

        if launcher is None:
            return self.object.name in (arrows + ['crossbow bolt'])  # TODO: sling ammo
        else:
            launcher_name = launcher.object.name
            if launcher_name == 'crossbow':
                return self.object.name == 'crossbow bolt'
            elif launcher_name == 'sling':
                # TODO: sling ammo
                return False
            else:  # any bow
                assert launcher_name in ['bow', 'elven bow', 'orcish bow', 'yumi'], launcher_name
                return self.object.name in arrows

    def is_thrown_projectile(self):
        """Return True for an unambiguous weapon meant to be thrown by hand (daggers, knives, darts, shuriken)."""
        if not self.is_weapon() or not self.is_unambiguous():
            return False

        # TODO: boomerang
        # TODO: aklys, Mjollnir
        # names must match the object names exactly ('silver dagger' and 'athame',
        # not 'dagger silver' / 'athame dagger', which never matched anything)
        return self.object.name in \
            ['dagger', 'orcish dagger', 'silver dagger', 'athame', 'elven dagger',
             'worm tooth', 'knife', 'stiletto', 'scalpel', 'crysknife',
             'dart', 'shuriken']

    def __str__(self):
        if self.text is not None:
            return self.text
        return (f'{self.count}_'
                f'{self.status if self.status is not None else ""}_'
                f'{self.modifier if self.modifier is not None else ""}_'
                f'{",".join(list(map(lambda x: x.name, self.objs)))}'
                )

    def __repr__(self):
        return str(self)

    ######## ARMOR

    def is_armor(self):
        """Return True if the item belongs to the armor class."""
        return self.category == nh.ARMOR_CLASS

    def get_ac(self):
        """Return `object.ac` minus the enchantment modifier (treated as 0 when unknown)."""
        assert self.is_armor()
        return self.object.ac - (self.modifier if self.modifier is not None else 0)

    ######## WAND

    def is_wand(self):
        """Return True if the (first candidate) object is a wand."""
        return isinstance(self.objs[0], O.Wand)

    def is_beam_wand(self):
        """Return True if every candidate is one of the listed beam (non-ray) wand types."""
        if not self.is_wand():
            return False
        beam_wand_types = ['cancellation', 'locking', 'make invisible',
                           'nothing', 'opening', 'polymorph', 'probing', 'slow monster',
                           'speed monster', 'striking', 'teleportation', 'undead turning']
        beam_wand_types = [O.from_name(w, nh.WAND_CLASS) for w in beam_wand_types]
        for obj in self.objs:
            if obj not in beam_wand_types:
                return False
        return True

    def is_ray_wand(self):
        """Return True if every candidate is one of the listed ray wand types."""
        if not self.is_wand():
            return False
        ray_wand_types = ['cold', 'death', 'digging', 'fire', 'lightning', 'magic missile', 'sleep']
        ray_wand_types = [O.from_name(w, nh.WAND_CLASS) for w in ray_wand_types]
        for obj in self.objs:
            if obj not in ray_wand_types:
                return False
        return True

    def wand_charges_left(self, item):
        # NOTE(review): unfinished stub -- only asserts that `item` (not `self`) is a wand; returns None.
        assert item.is_wand()

    def is_offensive_usable_wand(self):
        """Return True for an unambiguous ray wand other than sleep/digging not known to be empty."""
        if len(self.objs) != 1:
            return False
        if not self.is_ray_wand():
            return False
        if self.uses == 'no charges':
            # TODO: is it right ?
            return False
        if self.objs[0] == O.from_name('sleep', nh.WAND_CLASS):
            return False
        if self.objs[0] == O.from_name('digging', nh.WAND_CLASS):
            return False
        return True

    ######## FOOD

    def is_food(self):
        """Return True if the (first candidate) object is food; food items must be unambiguous."""
        if isinstance(self.objs[0], O.Food):
            assert self.is_unambiguous()
            return True
        return False

    def nutrition_per_weight(self):
        """Return nutrition divided by unit weight (weight floored at 1)."""
        # TODO: corpses/tins
        assert self.is_food()
        return self.object.nutrition / max(self.unit_weight(), 1)

    def is_corpse(self):
        """Return True if the item is a corpse."""
        if self.objs[0] == O.from_name('corpse'):
            assert self.is_unambiguous()
            return True
        return False

    ######## STATUE

    def is_statue(self):
        """Return True if the item is a statue."""
        if self.objs[0] == O.from_name('statue'):
            assert self.is_unambiguous()
            return True
        return False

    ######## CONTAINER

    def is_chest(self):
        """Return True if the container's description is not 'bag' (bag of tricks is never a chest)."""
        if self.is_unambiguous() and self.object.name == 'bag of tricks':
            return False
        assert self.is_possible_container() or self.is_container(), self.objs
        assert isinstance(self.objs[0], O.Container), self.objs
        return self.objs[0].desc != 'bag'

    def is_container(self):
        """Return True once the item's content has been checked and stored in `self.content`."""
        # don't consider bag of tricks as a container.
        # If the identifier doesn't exist yet, it's not consider a container
        return self.content is not None

    def is_possible_container(self):
        """Return True for an unchecked item whose candidates include a container (not a bag of tricks)."""
        if self.is_container():
            return False

        if self.is_unambiguous() and self.object.name == 'bag of tricks':
            return False
        return any((isinstance(obj, O.Container) for obj in self.objs))

    def content(self):
        # NOTE(review): unreachable through instances -- `__init__` sets the instance attribute
        # `self.content`, which shadows this method; `item.content` always yields the attribute.
        assert self.is_container()
        return self.content

# --------------------------------------------------------------------------------
# /autoascend/glyph/monflag.py:
# --------------------------------------------------------------------------------
# /include/monflag.h
MS_SILENT = 0 #/* makes no sound */
MS_BARK = 1 #/* if full moon, may howl */
MS_MEW = 2 #/* mews or hisses */
MS_ROAR = 3 #/* roars */
MS_GROWL = 4 #/* growls */
MS_SQEEK = 5 #/* squeaks, as a rodent */
MS_SQAWK = 6 #/* squawks, as a bird */
MS_HISS = 7 #/* hisses */
MS_BUZZ = 8 #/* buzzes (killer bee) */
MS_GRUNT = 9 #/* grunts (or speaks own language) */
MS_NEIGH = 10 #/* neighs, as an equine */
MS_WAIL = 11 #/* wails, as a tortured soul */
MS_GURGLE = 12 #/* gurgles, as liquid or through saliva */
MS_BURBLE = 13 #/* burbles (jabberwock) */
# NOTE(review): MS_ANIMAL shares 13 with MS_BURBLE and 14 is unused; presumably mirrors monflag.h -- confirm.
MS_ANIMAL = 13 #/* up to here are animal noises */
MS_SHRIEK = 15 #/* wakes up others */
MS_BONES = 16 #/* rattles bones (skeleton) */
MS_LAUGH = 17 #/* grins, smiles, giggles, and laughs */
MS_MUMBLE = 18 #/* says something or other */
MS_IMITATE = 19 #/* imitates others (leocrotta) */
MS_ORC = MS_GRUNT #/* intelligent brutes */
MS_HUMANOID = 20 #/* generic traveling companion */
MS_ARREST = 21 #/* "Stop in the name of the law!" (Kops) */
MS_SOLDIER = 22 #/* army and watchmen expressions */
MS_GUARD = 23 #/* "Please drop that gold and follow me." */
MS_DJINNI = 24 #/* "Thank you for freeing me!" */
MS_NURSE = 25 #/* "Take off your shirt, please." */
MS_SEDUCE = 26 #/* "Hello, sailor." (Nymphs) */
MS_VAMPIRE = 27 #/* vampiric seduction, Vlad's exclamations */
MS_BRIBE = 28 #/* asks for money, or berates you */
MS_CUSS = 29 #/* berates (demons) or intimidates (Wiz) */
MS_RIDER = 30 #/* astral level special monsters */
MS_LEADER = 31 #/* your class leader */
MS_NEMESIS = 32 #/* your nemesis */
MS_GUARDIAN = 33 #/* your leader's guards */
MS_SELL = 34 #/* demand payment, complain about shoplifters */
MS_ORACLE = 35 #/* do a consultation */
MS_PRIEST = 36 #/* ask for contribution; do cleansing */
MS_SPELL = 37 #/* spellcaster not matching any of the above */
MS_WERE = 38 #/* lycanthrope in human form */
MS_BOAST = 39 #/* giants */

MR_FIRE = 0x01 #/* resists fire */
MR_COLD = 0x02 #/* resists cold */
MR_SLEEP = 0x04 #/* resists sleep */
MR_DISINT = 0x08 #/* resists disintegration */
MR_ELEC = 0x10 #/* resists electricity */
MR_POISON = 0x20 #/* resists poison */
MR_ACID = 0x40 #/* resists acid */
MR_STONE = 0x80 #/* resists petrification */
#/* other resistances: magic, sickness */
#/* other conveyances: teleport, teleport control, telepathy */

#/* individual resistances */
MR2_SEE_INVIS = 0x0100 #/* see invisible */
MR2_LEVITATE = 0x0200 #/* levitation */
MR2_WATERWALK = 0x0400 #/* water walking */
MR2_MAGBREATH = 0x0800 #/* magical breathing */
MR2_DISPLACED = 0x1000 #/* displaced */
MR2_STRENGTH = 0x2000 #/* gauntlets of power */
MR2_FUMBLING = 0x4000 #/* clumsy */

M1_FLY = 0x00000001 #/* can fly or float */
M1_SWIM = 0x00000002 #/* can traverse water */
M1_AMORPHOUS = 0x00000004 #/* can flow under doors */
M1_WALLWALK = 0x00000008 #/* can phase thru rock */
M1_CLING = 0x00000010 #/* can cling to ceiling */
M1_TUNNEL = 0x00000020 #/* can tunnel thru rock */
M1_NEEDPICK = 0x00000040 #/* needs pick to tunnel */
M1_CONCEAL = 0x00000080 #/* hides under objects */
M1_HIDE = 0x00000100 #/* mimics, blends in with ceiling */
M1_AMPHIBIOUS = 0x00000200 #/* can survive underwater */
M1_BREATHLESS = 0x00000400 #/* doesn't need to breathe */
M1_NOTAKE = 0x00000800 #/* cannot pick up objects */
M1_NOEYES = 0x00001000 #/* no eyes to gaze into or blind */
M1_NOHANDS = 0x00002000 #/* no hands to handle things */
# M1_NOLIMBS includes the M1_NOHANDS bit (0x6000 == 0x4000 | 0x2000)
M1_NOLIMBS = 0x00006000 #/* no arms/legs to kick/wear on */
M1_NOHEAD = 0x00008000 #/* no head to behead */
M1_MINDLESS = 0x00010000 #/* has no mind--golem, zombie, mold */
M1_HUMANOID = 0x00020000 #/* has humanoid head/arms/torso */
M1_ANIMAL = 0x00040000 #/* has animal body */
M1_SLITHY = 0x00080000 #/* has serpent body */
M1_UNSOLID = 0x00100000 #/* has no solid or liquid body */
M1_THICK_HIDE = 0x00200000 #/* has thick hide or scales */
M1_OVIPAROUS = 0x00400000 #/* can lay eggs */
M1_REGEN = 0x00800000 #/* regenerates hit points */
M1_SEE_INVIS = 0x01000000 #/* can see invisible creatures */
M1_TPORT = 0x02000000 #/* can teleport */
M1_TPORT_CNTRL = 0x04000000 #/* controls where it teleports to */
M1_ACID = 0x08000000 #/* acidic to eat */
M1_POIS = 0x10000000 #/* poisonous to eat */
M1_CARNIVORE = 0x20000000 #/* eats corpses */
M1_HERBIVORE = 0x40000000 #/* eats fruits */
M1_OMNIVORE = 0x60000000 #/* eats both */
#ifdef NHSTDC
#define M1_METALLIVORE 0x80000000UL /* eats metal */
#else
M1_METALLIVORE = 0x80000000 #/* eats metal */
#endif

M2_NOPOLY = 0x00000001 #/* players mayn't poly into one */
M2_UNDEAD = 0x00000002 #/* is walking dead */
M2_WERE = 0x00000004 #/* is a lycanthrope */
M2_HUMAN = 0x00000008 #/* is a human */
M2_ELF = 0x00000010 #/* is an elf */
M2_DWARF = 0x00000020 #/* is a dwarf */
M2_GNOME = 0x00000040 #/* is a gnome */
M2_ORC = 0x00000080 #/* is an orc */
M2_DEMON = 0x00000100 #/* is a demon */
M2_MERC = 0x00000200 #/* is a guard or soldier */
M2_LORD = 0x00000400 #/* is a lord to its kind */
M2_PRINCE = 0x00000800 #/* is an overlord to its kind */
M2_MINION = 0x00001000 #/* is a minion of a deity */
M2_GIANT = 0x00002000 #/* is a giant */
M2_SHAPESHIFTER = 0x00004000 #/* is a shapeshifting species */
M2_MALE = 0x00010000 #/* always male */
M2_FEMALE = 0x00020000 #/* always female */
M2_NEUTER = 0x00040000 #/* neither male nor female */
M2_PNAME = 0x00080000 #/* monster name is a proper name */
M2_HOSTILE = 0x00100000 #/* always starts hostile */
M2_PEACEFUL = 0x00200000 #/* always starts peaceful */
M2_DOMESTIC = 0x00400000 #/* can be tamed by feeding */
M2_WANDER = 0x00800000 #/* wanders randomly */
M2_STALK = 0x01000000 #/* follows you to other levels */
M2_NASTY = 0x02000000 #/* extra-nasty monster (more xp) */
M2_STRONG = 0x04000000 #/* strong (or big) monster */
M2_ROCKTHROW = 0x08000000 #/* throws boulders */
M2_GREEDY = 0x10000000 #/* likes gold */
M2_JEWELS = 0x20000000 #/* likes gems */
M2_COLLECT = 0x40000000 #/* picks up weapons and food */
#ifdef NHSTDC
#define M2_MAGIC 0x80000000UL /* picks up magic items */
#else
M2_MAGIC = 0x80000000 #/* picks up magic items */
#endif

M3_WANTSAMUL = 0x0001 #/* would like to steal the amulet */
M3_WANTSBELL = 0x0002 #/* wants the bell */
M3_WANTSBOOK = 0x0004 #/* wants the book */
M3_WANTSCAND = 0x0008 #/* wants the candelabrum */
M3_WANTSARTI = 0x0010 #/* wants the quest artifact */
M3_WANTSALL = 0x001f #/* wants any major artifact */
M3_WAITFORU = 0x0040 #/* waits to see you or get attacked */
M3_CLOSE = 0x0080 #/* lets you close unless attacked */

M3_COVETOUS = 0x001f #/* wants something */
M3_WAITMASK = 0x00c0 #/* waiting... */

#/* Infravision is currently implemented for players only */
M3_INFRAVISION = 0x0100 #/* has infravision */
M3_INFRAVISIBLE = 0x0200 #/* visible by infravision */

M3_DISPLACES = 0x0400 #/* moves monsters out of its way */

MZ_TINY = 0 #/* < 2' */
MZ_SMALL = 1 #/* 2-4' */
MZ_MEDIUM = 2 #/* 4-7' */
MZ_HUMAN = MZ_MEDIUM #/* human-sized */
MZ_LARGE = 3 #/* 7-12' */
MZ_HUGE = 4 #/* 12-25' */
MZ_GIGANTIC = 7 #/* off the scale */

#/* Monster races -- must stay within ROLE_RACEMASK */
#/* Eventually this may become its own field */
MH_HUMAN = M2_HUMAN
MH_ELF = M2_ELF
MH_DWARF = M2_DWARF
MH_GNOME = M2_GNOME
MH_ORC = M2_ORC

#/* for mons[].geno (constant during game) */
G_UNIQ = 0x1000 #/* generated only once */
G_NOHELL = 0x0800 #/* not generated in "hell" */
G_HELL = 0x0400 #/* generated only in "hell" */
G_NOGEN = 0x0200 #/* generated only specially */
G_SGROUP = 0x0080 #/* appear in small groups normally */
G_LGROUP = 0x0040 #/* appear in large groups normally */
G_GENO = 0x0020 #/* can be genocided */
G_NOCORPSE = 0x0010 #/* no corpse left ever */
G_FREQ = 0x0007 #/* creation frequency mask */

#/* for mvitals[].mvflags (variant during game), along with G_NOCORPSE */
G_KNOWN = 0x0004 #/* have been encountered */
G_GENOD = 0x0002 #/* have been genocided */
G_EXTINCT = 0x0001 #/* have been extinguished as population control */
G_GONE = (G_GENOD | G_EXTINCT)
MV_KNOWS_EGG = 0x0008 #/* player recognizes egg of this monster type */

S_ANT = 1#, /* a */
S_BLOB = 2#, /* b */
S_COCKATRICE = 3#, /* c */
S_DOG = 4#, /* d */
S_EYE = 5#, /* e */
S_FELINE = 6#, /* f: cats */
S_GREMLIN = 7#, /* g */
S_HUMANOID = 8#, /* h: small humanoids: hobbit, dwarf */
S_IMP = 9#, /* i: minor demons */
S_JELLY = 10#, /* j */
S_KOBOLD = 11#, /* k */
S_LEPRECHAUN = 12#, /* l */
S_MIMIC = 13#, /* m */
S_NYMPH = 14#, /* n */
S_ORC = 15#, /* o */
S_PIERCER = 16#, /* p */
S_QUADRUPED = 17#, /* q: excludes horses */
S_RODENT = 18#, /* r */
S_SPIDER = 19#, /* s */
S_TRAPPER = 20#, /* t */
S_UNICORN = 21#, /* u: includes horses */
S_VORTEX = 22#, /* v */
S_WORM = 23#, /* w */
S_XAN = 24#, /* x */
S_LIGHT = 25#, /* y: yellow light, black light */
S_ZRUTY = 26#, /* z */
S_ANGEL = 27#, /* A */
S_BAT = 28#, /* B */
S_CENTAUR = 29#, /* C */
S_DRAGON = 30#, /* D */
S_ELEMENTAL = 31#, /* E: includes invisible stalker */
S_FUNGUS = 32#, /* F */
S_GNOME = 33#, /* G */
S_GIANT = 34#, /* H: large humanoid: giant, ettin, minotaur */
S_invisible = 35#, /* I: non-class present in def_monsyms[] */
S_JABBERWOCK = 36#, /* J */
S_KOP = 37#, /* K */
S_LICH = 38#, /* L */
S_MUMMY = 39#, /* M */
S_NAGA = 40#, /* N */
S_OGRE = 41#, /* O */
S_PUDDING = 42#, /* P */
S_QUANTMECH = 43#, /* Q */
S_RUSTMONST = 44#, /* R */
S_SNAKE = 45#, /* S */
S_TROLL = 46#, /* T */
S_UMBER = 47#, /* U: umber hulk */
S_VAMPIRE = 48#, /* V */
S_WRAITH = 49#, /* W */
S_XORN = 50#, /* X */
S_YETI = 51#, /* Y: includes owlbear, monkey */
S_ZOMBIE = 52#, /* Z */
S_HUMAN = 53#, /* @ */
S_GHOST = 54#, /* */
S_GOLEM = 55#, /* ' */
S_DEMON = 56#, /* & */
S_EEL = 57#, /* ; (fish) */
S_LIZARD = 58#, /* : (reptiles) */

S_WORM_TAIL = 59#, /* ~ */
S_MIMIC_DEF = 60#, /* ] */

MAXMCLASSES = 61# /* number of monster classes */

# --------------------------------------------------------------------------------
# /muzero/nethack.py:
-------------------------------------------------------------------------------- 1 | import ctypes 2 | import datetime 3 | import os 4 | import multiprocessing 5 | import traceback 6 | 7 | import gym 8 | 9 | from .abstract_game import AbstractGame 10 | from autoascend.env_wrapper import EnvWrapper 11 | 12 | 13 | class MuZeroConfig: 14 | def __init__(self, rl_model=None): 15 | 16 | # More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization 17 | 18 | self.seed = 0 # Seed for numpy, torch and the game 19 | self.max_num_gpus = None # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use every GPUs available 20 | 21 | 22 | 23 | ### Game 24 | if rl_model is None: 25 | game = Game() 26 | game._kill_thread() 27 | rl_model = game.rl_model 28 | del game 29 | if 'HACKDIR' in os.environ: 30 | del os.environ['HACKDIR'] # nle leave some trashes that need to be cleaned up 31 | 32 | self.observation_shape = rl_model.observation_shape() # Dimensions of the game observation, must be 3D (channel, height, width). For a 1D array, please reshape it to (1, 1, length of array) 33 | self.action_space = list(range(len(rl_model.action_space))) # Fixed list of all possible actions. You should only edit the length 34 | 35 | #self.observation_shape = (98, 1, 1) 36 | #self.action_space = list(range(9)) 37 | self.players = list(range(1)) # List of players. You should only edit the length 38 | self.stacked_observations = 0 # Number of previous observations and previous actions to add to the current observation 39 | 40 | # Evaluate 41 | self.muzero_player = 0 # Turn Muzero begins to play (0: MuZero plays first, 1: MuZero plays second) 42 | self.opponent = None # Hard coded agent that MuZero faces to assess his progress in multiplayer games. It doesn't influence training. 
None, "random" or "expert" if implemented in the Game class 43 | 44 | 45 | 46 | ### Self-Play 47 | self.num_workers = 40 # Number of simultaneous threads/workers self-playing to feed the replay buffer 48 | self.selfplay_on_gpu = False 49 | self.max_moves = 1e6 # Maximum number of moves if game is not finished before 50 | self.num_simulations = 5 # Number of future moves self-simulated 51 | self.discount = 0.997 # Chronological discount of the reward 52 | self.temperature_threshold = None # Number of moves before dropping the temperature given by visit_softmax_temperature_fn to 0 (ie selecting the best action). If None, visit_softmax_temperature_fn is used every time 53 | 54 | # Root prior exploration noise 55 | self.root_dirichlet_alpha = 0.25 56 | self.root_exploration_fraction = 0.25 57 | 58 | # UCB formula 59 | self.pb_c_base = 19652 60 | self.pb_c_init = 1.25 61 | 62 | 63 | 64 | ### Network 65 | self.network = "resnet" # "resnet" / "fullyconnected" 66 | self.support_size = 300 # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size. 
Choose it so that support_size <= sqrt(max(abs(discounted reward))) 67 | 68 | # Residual Network 69 | self.downsample = None # Downsample observations before representation network, False / "CNN" (lighter) / "resnet" (See paper appendix Network Architecture) 70 | self.blocks = 8 # Number of blocks in the ResNet 71 | self.channels = 128 # Number of channels in the ResNet 72 | self.reduced_channels_reward = 256 # Number of channels in reward head 73 | self.reduced_channels_value = 256 # Number of channels in value head 74 | self.reduced_channels_policy = 256 # Number of channels in policy head 75 | self.resnet_fc_reward_layers = [256, 256] # Define the hidden layers in the reward head of the dynamic network 76 | self.resnet_fc_value_layers = [256, 256] # Define the hidden layers in the value head of the prediction network 77 | self.resnet_fc_policy_layers = [256, 256] # Define the hidden layers in the policy head of the prediction network 78 | 79 | # Fully Connected Network 80 | self.encoding_size = 10 81 | self.fc_representation_layers = [] # Define the hidden layers in the representation network 82 | self.fc_dynamics_layers = [16] # Define the hidden layers in the dynamics network 83 | self.fc_reward_layers = [16] # Define the hidden layers in the reward network 84 | self.fc_value_layers = [] # Define the hidden layers in the value network 85 | self.fc_policy_layers = [] # Define the hidden layers in the policy network 86 | 87 | 88 | 89 | ### Training 90 | self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "/checkpoints", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S")) # Path to store the model weights and TensorBoard logs 91 | self.save_model = True # Save the checkpoint in results_path as model.checkpoint 92 | self.training_steps = int(1000e3) # Total number of training steps (ie weights update according to a batch) 93 | self.batch_size = 128 # Number of parts of games to train on at each 
training step 94 | self.checkpoint_interval = int(1e3) # Number of training steps before using the model for self-playing 95 | self.value_loss_weight = 0.25 # Scale the value loss to avoid overfitting of the value function, paper recommends 0.25 (See paper appendix Reanalyze) 96 | self.train_on_gpu = True #torch.cuda.is_available() # Train on GPU if available 97 | 98 | self.optimizer = "Adam" # "Adam" or "SGD". Paper uses SGD 99 | self.weight_decay = 1e-4 # L2 weights regularization 100 | self.momentum = 0.9 # Used only if optimizer is SGD 101 | 102 | # Exponential learning rate schedule 103 | self.lr_init = 0.001 # Initial learning rate 104 | self.lr_decay_rate = 0.1 # Set it to 1 to use a constant learning rate 105 | self.lr_decay_steps = 350e3 106 | 107 | 108 | 109 | ### Replay Buffer 110 | self.replay_buffer_size = 1e6 # Number of self-play games to keep in the replay buffer 111 | self.num_unroll_steps = 5 # Number of game moves to keep for every batch element 112 | self.td_steps = 10 # Number of steps in the future to take into account for calculating the target value 113 | self.PER = True # Prioritized Replay (See paper appendix Training), select in priority the elements in the replay buffer which are unexpected for the network 114 | self.PER_alpha = 1 # How much prioritization is used, 0 corresponding to the uniform case, paper suggests 1 115 | 116 | # Reanalyze (See paper appendix Reanalyse) 117 | self.use_last_model_value = True # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze) 118 | self.reanalyse_on_gpu = False 119 | 120 | 121 | 122 | ### Adjust the self play / training ratio to avoid over/underfitting 123 | self.self_play_delay = 0 # Number of seconds to wait after each played game 124 | self.training_delay = 0 # Number of seconds to wait after each training step 125 | self.ratio = None # Desired training steps per self played step ratio. Equivalent to a synchronous version, training can take much longer. 
Set it to None to disable it 126 | 127 | 128 | def visit_softmax_temperature_fn(self, trained_steps): 129 | """ 130 | Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses. 131 | The smaller it is, the more likely the best action (ie with the highest visit count) is chosen. 132 | 133 | Returns: 134 | Positive float. 135 | """ 136 | if trained_steps < 500e3: 137 | return 1.0 138 | elif trained_steps < 750e3: 139 | return 0.5 140 | else: 141 | return 0.25 142 | 143 | 144 | class Game(AbstractGame): 145 | """ 146 | Game wrapper. 147 | """ 148 | 149 | def __init__(self, seed=None): 150 | self.rl_model = None 151 | self.last_score = 0 152 | self._start_thread() 153 | self.current_legal_actions = list(range(len(self.rl_model.action_space))) 154 | 155 | @staticmethod 156 | def _create_env(seed, output_queue, input_queue, env=None): 157 | if env is None: 158 | env = gym.make('NetHackChallenge-v0') 159 | env = EnvWrapper(env, 160 | agent_args=dict(rl_model_to_train='fight2', 161 | rl_model_training_comm=(output_queue, input_queue))) 162 | if seed is not None: 163 | env.env.seed(seed, seed) 164 | return env 165 | 166 | def _start_thread(self): 167 | output_queue, input_queue = multiprocessing.Queue(), multiprocessing.Queue() 168 | self.output_queue = output_queue 169 | self.input_queue = input_queue 170 | def f(): 171 | try: 172 | env = Game._create_env(None, output_queue, input_queue) 173 | env.main() 174 | except BaseException as e: 175 | print(f'exception: {"".join(traceback.format_exception(None, e, e.__traceback__))}') 176 | finally: 177 | output_queue.put((None, None, None)) 178 | 179 | self.thread = multiprocessing.Process(target=f) 180 | self.thread.start() 181 | self.rl_model = self.output_queue.get() 182 | 183 | def _kill_thread(self): 184 | res = ctypes.pythonapi.PyThreadState_SetAsyncExc(self.thread.ident, 185 | ctypes.py_object(KeyboardInterrupt)) 186 | if res > 1: 187 | 
ctypes.pythonapi.PyThreadState_SetAsyncExc(self.thread.ident, 0) 188 | raise RuntimeError('Exception raise failure') 189 | self.thread.terminate() 190 | self.thread.join() 191 | 192 | def step(self, action): 193 | """ 194 | Apply action to the game. 195 | 196 | Args: 197 | action : action of the action_space to take. 198 | 199 | Returns: 200 | The new observation, the reward and a boolean if the game has ended. 201 | """ 202 | self.input_queue.put(action) 203 | val = self.output_queue.get() 204 | observation, self.current_legal_actions, score = val 205 | if observation is None: 206 | is_done = True 207 | observation = self.rl_model.zero_observation() 208 | score = self.last_score 209 | else: 210 | is_done = False 211 | reward = score - self.last_score 212 | self.last_score = score 213 | return self.rl_model.encode_observation(observation), reward, is_done 214 | 215 | def legal_actions(self): 216 | """ 217 | Should return the legal actions at each turn, if it is not available, it can return 218 | the whole action space. At each turn, the game have to be able to handle one of returned actions. 219 | 220 | For complex game where calculating legal moves is too long, the idea is to define the legal actions 221 | equal to the action space but to return a negative reward if the action is illegal. 222 | 223 | Returns: 224 | An array of integers, subset of the action space. 225 | """ 226 | return self.current_legal_actions 227 | 228 | def reset(self): 229 | """ 230 | Reset the game for a new game. 231 | 232 | Returns: 233 | Initial observation of the game. 234 | """ 235 | while 1: 236 | self._kill_thread() 237 | self.last_score = 0 238 | self._start_thread() 239 | 240 | val = self.output_queue.get() 241 | observation, self.current_legal_actions, score = val 242 | self.last_score = score 243 | if observation is not None: 244 | break 245 | return self.rl_model.encode_observation(observation) 246 | 247 | def close(self): 248 | """ 249 | Properly close the game. 
250 | """ 251 | self._kill_thread() 252 | 253 | def render(self): 254 | """ 255 | Display the game observation. 256 | """ 257 | self.env.render() 258 | input("Press enter to take a step ") 259 | -------------------------------------------------------------------------------- /autoascend/combat/fight_heur.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from itertools import product 3 | 4 | import numpy as np 5 | from scipy import signal 6 | 7 | from ..glyph import G 8 | from ..utils import adjacent 9 | from .monster_utils import is_monster_faster, is_dangerous_monster, \ 10 | ONLY_RANGED_SLOW_MONSTERS, EXPLODING_MONSTERS, WEAK_MONSTERS, consider_melee_only_ranged_if_hp_full 11 | from .movement_priority import draw_monster_priority_positive, draw_monster_priority_negative 12 | from .utils import wielding_ranged_weapon, line_dis_from, inside 13 | 14 | 15 | def melee_monster_priority(agent, monsters, monster): 16 | _, y, x, mon, _ = monster 17 | ret = 1 18 | if agent.blstats.hitpoints > 8 or is_monster_faster(agent, monster): 19 | ret += 15 20 | if wielding_ranged_weapon(agent) and not is_monster_faster(agent, monster): 21 | ret -= 6 22 | if mon.mname in EXPLODING_MONSTERS: 23 | ret -= 17 24 | if 'were' in mon.mname: 25 | ret += 1 26 | # if not wielding_melee_weapon(agent): 27 | # ret -= 5 28 | if mon.mname in ONLY_RANGED_SLOW_MONSTERS: 29 | if not consider_melee_only_ranged_if_hp_full(agent, monster): 30 | ret -= 100 31 | if mon.mname == 'floating eye': 32 | ret -= 10 33 | if mon.mname == 'gas spore': 34 | ret -= 5 35 | 36 | if mon.mname == 'gas spore': 37 | # handle a specific case when you are trapped by a gas spore 38 | if len(agent.get_visible_monsters()) == 1 \ 39 | and agent.blstats.hitpoints / agent.blstats.max_hitpoints: 40 | dis = agent.bfs() 41 | for y2, x2 in zip(*np.nonzero(dis != -1)): 42 | if not adjacent((y, x), (y2, x2)): 43 | return ret 44 | 
agent.stats_logger.log_event('melee_gas_spore') 45 | return 1 # a priority higher than random moving around 46 | 47 | return ret 48 | 49 | 50 | def ranged_priority(agent, dy, dx, monsters): 51 | ret = 11 52 | 53 | closest_mon_dis = float('inf') 54 | for monster in monsters: 55 | _, my, mx, mon, _ = monster 56 | assert my != agent.blstats.y or mx != agent.blstats.x 57 | if mon.mname not in WEAK_MONSTERS + ONLY_RANGED_SLOW_MONSTERS: 58 | closest_mon_dis = min(closest_mon_dis, line_dis_from(agent, my, mx)) 59 | 60 | if closest_mon_dis == 1: 61 | ret -= 11 62 | 63 | launcher, ammo = agent.inventory.get_best_ranged_set() 64 | if ammo is None: 65 | return None 66 | 67 | if launcher is not None and not launcher.equipped: 68 | ret -= 5 69 | 70 | y, x = agent.blstats.y, agent.blstats.x 71 | while True: 72 | y += dy 73 | x += dx 74 | if not 0 <= y < agent.glyphs.shape[0] or not 0 <= x < agent.glyphs.shape[1]: 75 | return None 76 | 77 | if agent.glyphs[y, x] in G.PETS or not agent.current_level().walkable[y, x]: 78 | return None 79 | 80 | if agent.glyphs[y, x] in G.MONS: 81 | monster = [m for m in monsters if m[1] == y and m[2] == x] 82 | if not monster: 83 | # there is a monster that shouldn't be attacked 84 | return None 85 | assert len(monster) == 1 86 | _, _, _, mon, _ = monster[0] 87 | dis = line_dis_from(agent, y, x) 88 | if dis > agent.character.get_range(launcher, ammo): 89 | return None 90 | if dis in (1, 2): 91 | ret -= 5 92 | if dis == 1: 93 | ret -= 6 94 | if mon.mname == 'gas spore': # only gas spore ? 
95 | ret -= 100 96 | return ret, y, x, monster[0] 97 | 98 | 99 | def get_next_states(agent, wand, y, x, dy, dx): 100 | if not inside(agent, y, x) or not agent.current_level().walkable[y, x]: 101 | can_bounce = wand.is_ray_wand() 102 | if not can_bounce: 103 | return [] 104 | if dy == 0 or dx == 0: 105 | return [(y - dy, x - dx, -dy, -dx, 1.0, 1)] 106 | # TODO: diagonal 107 | side1 = (y, x - dx) 108 | side2 = (y - dy, x) 109 | side1_wall = not inside(agent, *side1) or not agent.current_level().walkable[side1] 110 | side2_wall = not inside(agent, *side2) or not agent.current_level().walkable[side2] 111 | dy1, dx1 = side2[0] - side1[0], side2[1] - side1[1] 112 | dy2, dx2 = side1[0] - side2[0], side1[1] - side2[1] 113 | if side1_wall and side2_wall: 114 | return [(y - dy, x - dx, -dy, -dx, 1.0, 1)] 115 | elif not side1_wall and not side2_wall: 116 | return [(y - dy, x - dx, -dy, -dx, 1 / 20, 1), 117 | (y + dy1, x + dx1, dy1, dx1, 19 / 40, 1), 118 | (y + dy2, x + dx2, dy2, dx2, 19 / 40, 1)] 119 | elif side1_wall: 120 | return [(y + dy1, x + dx1, dy1, dx1, 1.0, 1)] 121 | elif side2_wall: 122 | return [(y + dy2, x + dx2, dy2, dx2, 1.0, 1)] 123 | else: 124 | assert 0 125 | return [(y + dy, x + dx, dy, dx, 1.0, 0)] 126 | 127 | 128 | def _simulate_wand_path(agent, wand, monsters, y, x, dy, dx, range_left, hit_targets, probability): 129 | if range_left < 0: 130 | return 131 | 132 | for y, x, dy, dx, next_prob, range_penalty in get_next_states(agent, wand, y, x, dy, dx): 133 | range_left -= range_penalty 134 | monster = [m for m in monsters if m[1] == y and m[2] == x] 135 | if monster: 136 | assert len(monster) == 1 137 | monster = monster[0] 138 | # For each monster hit, range decreases by 2. 139 | range_left -= 2 140 | elif inside(agent, y, x) and agent.glyphs[y, x] in G.PETS: 141 | monster = 'pet' 142 | # For each monster hit, range decreases by 2. 
143 | range_left -= 2 144 | elif agent.blstats.y == y and agent.blstats.x == x: 145 | monster = 'self' 146 | range_left -= 2 147 | else: 148 | monster = None 149 | 150 | hit_targets[(y, x, monster)] += probability * next_prob 151 | 152 | _simulate_wand_path(agent, wand, monsters, y, x, dy, dx, range_left - 1, hit_targets, 1.0) 153 | 154 | 155 | def simulate_wand_path(agent, wand, monsters, dy, dx): 156 | """ Returns list of tuples (y, x, hit_object, expected_hit_count). 157 | """ 158 | y, x = agent.blstats.y, agent.blstats.x 159 | 160 | # TODO: random range left from 6 or 7 to 13 161 | hit_targets = defaultdict(int) 162 | _simulate_wand_path(agent, wand, monsters, y, x, dy, dx, 13, hit_targets, 1.0) 163 | for (y, x, hit_object), expected_hit_count in hit_targets.items(): 164 | yield y, x, hit_object, expected_hit_count 165 | 166 | 167 | def get_potential_wand_usages(agent, monsters, dy, dx): 168 | ret = [] 169 | player_hp_ratio = agent.blstats.hitpoints / agent.blstats.max_hitpoints 170 | # TODO: also get items recursively from bags 171 | for item in agent.inventory.items: 172 | targeted_monsters = set() 173 | if not item.is_offensive_usable_wand(): 174 | continue 175 | priority = 0 176 | # print('--------------', dy, dx) 177 | for y, x, monster, p in simulate_wand_path(agent, item, monsters, dy, dx): 178 | # print(y, x, monster, p) 179 | if monster == 'pet': 180 | priority -= p * 20 181 | elif monster == 'self': 182 | priority -= p * 30 183 | elif monster is not None: 184 | _, y, x, mon, _ = monster 185 | if mon.mname in WEAK_MONSTERS: 186 | priority += min(p, 1) * 1 187 | elif is_dangerous_monster(monster): 188 | priority += p * 25 189 | else: 190 | priority += min(p, 1) * 10 191 | targeted_monsters.add((y, x, monster)) 192 | if targeted_monsters: 193 | # priority = priority * (1 - player_hp_ratio) - 10 194 | priority = priority - 15 195 | if agent.inventory.engraving_below_me.lower() == 'elbereth': 196 | priority -= 100 197 | ret.append((priority, ('zap', dy, 
dx, item, targeted_monsters))) 198 | return ret 199 | 200 | 201 | def elbereth_action(agent, monsters): 202 | if agent.inventory.engraving_below_me.lower() == 'elbereth': 203 | return [] 204 | if not agent.can_engrave(): 205 | return [] 206 | adj_monsters_count = 0 207 | for monster in monsters: 208 | _, my, mx, mon, _ = monster 209 | if mon.mname in ONLY_RANGED_SLOW_MONSTERS: 210 | continue 211 | if not adjacent((my, mx), (agent.blstats.y, agent.blstats.x)): 212 | continue 213 | multiplier = np.clip(20 / agent.blstats.hitpoints, 1.0, 1.5) 214 | if is_monster_faster(agent, monster): 215 | multiplier *= 2 216 | if mon in WEAK_MONSTERS: 217 | adj_monsters_count += 0.1 * multiplier 218 | continue 219 | adj_monsters_count += 1 * multiplier 220 | if is_dangerous_monster(monster): 221 | adj_monsters_count += 2 * multiplier 222 | 223 | player_hp_ratio = (agent.blstats.hitpoints / agent.blstats.max_hitpoints) ** 0.5 224 | if agent.blstats.hitpoints < 30 and adj_monsters_count > 0: 225 | return [(-15 + 20 * adj_monsters_count * (1 - player_hp_ratio), ('elbereth',))] 226 | return [] 227 | 228 | 229 | def wait_action(agent, monsters): 230 | if agent.inventory.engraving_below_me.lower() == 'elbereth': 231 | player_hp_ratio = agent.blstats.hitpoints / agent.blstats.max_hitpoints 232 | priority = 30 - player_hp_ratio * 40 233 | return [(priority, ('wait',))] 234 | return [] 235 | 236 | 237 | def get_available_actions(agent, monsters): 238 | actions = [] 239 | 240 | # melee attack actions 241 | for monster in monsters: 242 | _, y, x, mon, _ = monster 243 | if adjacent((y, x), (agent.blstats.y, agent.blstats.x)): 244 | priority = melee_monster_priority(agent, monsters, monster) 245 | if agent.inventory.engraving_below_me.lower() == 'elbereth': 246 | priority -= 100 247 | dy = y - agent.blstats.y 248 | dx = x - agent.blstats.x 249 | actions.append((priority, ('melee', dy, dx))) 250 | 251 | # ranged attack actions 252 | for dy, dx in product([-1, 0, 1], [-1, 0, 1]): 253 | if dy != 0 
or dx != 0: 254 | ranged_pr = ranged_priority(agent, dy, dx, monsters) 255 | if ranged_pr is not None: 256 | pri, y, x, monster = ranged_pr 257 | if agent.inventory.engraving_below_me.lower() == 'elbereth': 258 | pri -= 100 259 | if all(monster[3].mname in ONLY_RANGED_SLOW_MONSTERS for monster in monsters): 260 | pri += 10 261 | actions.append((pri, ('ranged', dy, dx))) 262 | 263 | actions.extend(get_potential_wand_usages(agent, monsters, dy, dx)) 264 | 265 | to_pickup = decide_what_to_pickup(agent) 266 | if to_pickup: 267 | actions.append((15, ('pickup', to_pickup))) 268 | 269 | actions.extend(elbereth_action(agent, monsters)) 270 | actions.extend(wait_action(agent, monsters)) 271 | 272 | return actions 273 | 274 | 275 | def decide_what_to_pickup(agent): 276 | projectiles_below_me = [i for i in agent.inventory.items_below_me 277 | if i.is_thrown_projectile() or i.is_fired_projectile()] 278 | my_launcher, ammo = agent.inventory.get_best_ranged_set(additional_ammo=[i for i in projectiles_below_me]) 279 | to_pickup = [] 280 | for item in agent.inventory.items_below_me: 281 | if item.is_thrown_projectile() or (my_launcher is not None and item.is_fired_projectile(launcher=my_launcher)): 282 | to_pickup.append(item) 283 | return to_pickup 284 | 285 | 286 | def goto_action(agent, priority, monsters): 287 | values = [] 288 | walkable = agent.current_level().walkable 289 | for dy, dx in [(-1, -1), (-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1)]: 290 | y, x = agent.blstats.y - dy, agent.blstats.x - dx 291 | if not 0 <= y < walkable.shape[0] or not 0 <= x < walkable.shape[1]: 292 | continue 293 | if not np.isnan(priority[y, x]): 294 | values.append(priority[y, x]) 295 | if len(set(values)) > 1: 296 | return [] 297 | 298 | assert monsters 299 | for monster in monsters: 300 | _, my, mx, mon, _ = monster 301 | if not adjacent((agent.blstats.y, agent.blstats.x), (my, mx)): 302 | # and not mon.mname in ONLY_RANGED_SLOW_MONSTERS: 303 | return [(1, ('go_to', my, mx))] 
304 | assert 0, monsters 305 | 306 | 307 | def get_corridors_priority_map(walkable): 308 | k = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) 309 | wall_count = signal.convolve2d((~walkable).astype(int), k, boundary='symm', mode='same') 310 | corridor_mask = (wall_count == 6).astype(int) 311 | corridor_mask[~walkable] = 0 312 | corridor_dilated = signal.convolve2d(corridor_mask.astype(int), k, boundary='symm', mode='same') 313 | return corridor_mask + corridor_dilated >= 1 314 | 315 | 316 | def get_priorities(agent): 317 | """ Returns a pair (move priority heatmap, other actions (with priorities) list) """ 318 | walkable = agent.current_level().walkable 319 | priority = np.zeros(walkable.shape, dtype=float) 320 | monsters = agent.get_visible_monsters() 321 | for m in monsters: 322 | draw_monster_priority_positive(agent, m, priority, walkable) 323 | for m in monsters: 324 | draw_monster_priority_negative(agent, m, priority, walkable) 325 | priority[~walkable] = float('nan') 326 | 327 | # TODO: figure out how to use corridors priority so that it improves the score 328 | # if len([m for m in monsters if m[3].mname not in chain(ONLY_RANGED_SLOW_MONSTERS, WEAK_MONSTERS)]) >= 4: 329 | # priority += get_corridors_priority_map(walkable) 330 | # for _, _, _, mon, _ in monsters: 331 | # if ord(mon.mlet) == MON.S_ANT: 332 | # priority += get_corridors_priority_map(walkable) 333 | # break 334 | 335 | # use relative priority to te current position 336 | priority -= priority[agent.blstats.y, agent.blstats.x] 337 | 338 | actions = get_available_actions(agent, monsters) 339 | if not any(a[1][0] in ('melee', 'ranged') for a in actions): 340 | actions.extend(goto_action(agent, priority, monsters)) 341 | return priority, actions 342 | 343 | 344 | def get_move_actions(agent, dis, move_priority_heatmap): 345 | """ Returns list of tuples (priority, ('move', dy, dx)) """ 346 | ret = [] 347 | for dy, dx in [(-1, -1), (-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1)]: 348 
| y, x = agent.blstats.y + dy, agent.blstats.x + dx 349 | if not 0 <= y < dis.shape[0] or not 0 <= x < dis.shape[1]: 350 | continue 351 | if not dis[y, x] == 1: 352 | continue 353 | 354 | if not np.isnan(move_priority_heatmap[y, x]): 355 | ret.append((move_priority_heatmap[y, x], ('move', dy, dx))) 356 | return ret 357 | -------------------------------------------------------------------------------- /autoascend/env_wrapper.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import contextlib 3 | import gc 4 | import multiprocessing 5 | import os 6 | import shutil 7 | import sys 8 | import tempfile 9 | import termios 10 | import tty 11 | from pathlib import Path 12 | from pprint import pprint 13 | 14 | import nle.nethack as nh 15 | 16 | from autoascend.visualization import visualizer 17 | from autoascend import agent as agent_lib # the library can be reloaded in `reload_agent` function 18 | 19 | 20 | def fork_with_nethack_env(env): 21 | tmpdir = tempfile.mkdtemp(prefix='nlecopy_') 22 | shutil.copytree(env.env._vardir, tmpdir, dirs_exist_ok=True) 23 | env.env._tempdir = None # it has to be done before the fork to avoid removing the same directory two times 24 | gc.collect() 25 | 26 | pid = os.fork() 27 | 28 | env.env._tempdir = tempfile.TemporaryDirectory(prefix='nlefork_') 29 | shutil.copytree(tmpdir, env.env._tempdir.name, dirs_exist_ok=True) 30 | env.env._vardir = env.env._tempdir.name 31 | os.chdir(env.env._vardir) 32 | return pid 33 | 34 | 35 | def reload_agent(base_path=str(Path(__file__).parent.absolute())): 36 | global visualize, agent_lib 37 | visualize = agent_lib = None 38 | modules_to_remove = [] 39 | for k, m in sys.modules.items(): 40 | if hasattr(m, '__file__') and m.__file__ and m.__file__.startswith(base_path): 41 | modules_to_remove.append(k) 42 | del m 43 | 44 | gc.collect() 45 | while modules_to_remove: 46 | for k in modules_to_remove: 47 | assert sys.getrefcount(sys.modules[k]) >= 2 48 | if 
sys.getrefcount(sys.modules[k]) == 2: 49 | sys.modules.pop(k) 50 | modules_to_remove.remove(k) 51 | gc.collect() 52 | break 53 | else: 54 | assert 0, ('cannot unload agent library', 55 | {k: sys.getrefcount(sys.modules[k]) for k in modules_to_remove}) 56 | 57 | 58 | class ReloadAgent(KeyboardInterrupt): 59 | # it inherits from KeyboardInterrupt as the agent doesn't catch that exception 60 | pass 61 | 62 | 63 | class EnvWrapper: 64 | def __init__(self, env, to_skip=0, visualizer_args=dict(enable=False), 65 | step_limit=None, agent_args={}, interactive=False): 66 | self.env = env 67 | self.agent_args = agent_args 68 | self.interactive = interactive 69 | self.to_skip = to_skip 70 | self.step_limit = step_limit 71 | self.visualizer = None 72 | if visualizer_args['enable']: 73 | visualizer_args.pop('enable') 74 | self.visualizer = visualizer.Visualizer(self, **visualizer_args) 75 | self.last_observation = None 76 | self.agent = None 77 | 78 | self.draw_walkable = False 79 | self.draw_seen = False 80 | self.draw_shop = False 81 | 82 | self.is_done = False 83 | 84 | def _init_agent(self): 85 | self.agent = agent_lib.Agent(self, **self.agent_args) 86 | 87 | def main(self): 88 | self.reset() 89 | while 1: 90 | try: 91 | self._init_agent() 92 | self.agent.main() 93 | break 94 | except ReloadAgent: 95 | pass 96 | finally: 97 | self.render() 98 | 99 | self.agent = None 100 | reload_agent() 101 | 102 | def reset(self): 103 | obs = self.env.reset() 104 | self.score = 0 105 | self.step_count = 0 106 | self.end_reason = '' 107 | self.last_observation = obs 108 | self.is_done = False 109 | 110 | if self.agent is not None: 111 | self.render() 112 | 113 | agent_lib.G.assert_map(obs['glyphs'], obs['chars']) 114 | 115 | blstats = agent_lib.BLStats(*obs['blstats']) 116 | assert obs['chars'][blstats.y, blstats.x] == ord('@') 117 | 118 | return obs 119 | 120 | def fork(self): 121 | fork_again = True 122 | while fork_again: 123 | pid = fork_with_nethack_env(self.env) 124 | if pid != 0: 125 
| # parent 126 | print('freezing parent') 127 | while 1: 128 | try: 129 | os.waitpid(pid, 0) 130 | break 131 | except KeyboardInterrupt: 132 | pass 133 | self.visualizer.force_next_frame() 134 | self.visualizer.render() 135 | while 1: 136 | try: 137 | fork_again = input('fork again [yn]: ') 138 | if fork_again == 'y': 139 | fork_again = True 140 | break 141 | elif fork_again == 'n': 142 | fork_again = False 143 | break 144 | except KeyboardInterrupt: 145 | pass 146 | 147 | termios.tcgetattr(sys.stdin) 148 | tty.setcbreak(sys.stdin.fileno()) 149 | else: 150 | # child 151 | atexit.unregister(multiprocessing.util._exit_function) 152 | self.visualizer.force_next_frame() 153 | self.visualizer.render() 154 | break 155 | 156 | def render(self, force=False): 157 | if self.visualizer is not None: 158 | with self.debug_tiles(self.agent.current_level().walkable, color=(0, 255, 0, 128)) \ 159 | if self.draw_walkable else contextlib.suppress(): 160 | with self.debug_tiles(~self.agent.current_level().seen, color=(255, 0, 0, 128)) \ 161 | if self.draw_seen else contextlib.suppress(): 162 | with self.debug_tiles(self.agent.current_level().shop, color=(0, 0, 255, 64)) \ 163 | if self.draw_shop else contextlib.suppress(): 164 | with self.debug_tiles(self.agent.current_level().shop_interior, color=(0, 0, 255, 64)) \ 165 | if self.draw_shop else contextlib.suppress(): 166 | with self.debug_tiles((self.last_observation['specials'] & nh.MG_OBJPILE) > 0, 167 | color=(0, 255, 255, 128)): 168 | with self.debug_tiles([self.agent.cursor_pos], 169 | color=(255, 255, 255, 128)): 170 | if force: 171 | self.visualizer.force_next_frame() 172 | rendered = self.visualizer.render() 173 | 174 | if not force and (not self.interactive or not rendered): 175 | return 176 | 177 | if self.agent is not None: 178 | print('Message:', self.agent.message) 179 | print('Pop-up :', self.agent.popup) 180 | print() 181 | if self.agent is not None and hasattr(self.agent, 'blstats'): 182 | 
print(agent_lib.BLStats(*self.last_observation['blstats'])) 183 | print(f'Carrying: {self.agent.inventory.items.total_weight} / {self.agent.character.carrying_capacity}') 184 | print('Character:', self.agent.character) 185 | print('Misc :', self.last_observation['misc']) 186 | print('Score:', self.score) 187 | print('Steps:', self.env._steps) 188 | print('Turns:', self.env._turns) 189 | print('Seed :', self.env.get_seeds()) 190 | print('Items below me :', self.agent.inventory.items_below_me) 191 | print('Engraving below me:', self.agent.inventory.engraving_below_me) 192 | print() 193 | print(self.agent.inventory.items) 194 | print('-' * 20) 195 | 196 | self.env.render() 197 | print('-' * 20) 198 | print() 199 | 200 | def print_help(self): 201 | scene_glyphs = set(self.env.last_observation[0].reshape(-1)) 202 | obj_classes = {getattr(nh, x): x for x in dir(nh) if x.endswith('_CLASS')} 203 | glyph_classes = sorted((getattr(nh, x), x) for x in dir(nh) if x.endswith('_OFF')) 204 | 205 | texts = [] 206 | for i in range(nh.MAX_GLYPH): 207 | desc = '' 208 | if glyph_classes and i == glyph_classes[0][0]: 209 | cls = glyph_classes.pop(0)[1] 210 | 211 | if nh.glyph_is_monster(i): 212 | desc = f': "{nh.permonst(nh.glyph_to_mon(i)).mname}"' 213 | 214 | if nh.glyph_is_normal_object(i): 215 | obj = nh.objclass(nh.glyph_to_obj(i)) 216 | appearance = nh.OBJ_DESCR(obj) or nh.OBJ_NAME(obj) 217 | oclass = ord(obj.oc_class) 218 | desc = f': {obj_classes[oclass]}: "{appearance}"' 219 | 220 | desc2 = 'Labels: ' 221 | if i in agent_lib.G.INV_DICT: 222 | desc2 += ','.join(agent_lib.G.INV_DICT[i]) 223 | 224 | if i in scene_glyphs: 225 | pos = (self.env.last_observation[0].reshape(-1) == i).nonzero()[0] 226 | count = len(pos) 227 | pos = pos[0] 228 | char = bytes([self.env.last_observation[1].reshape(-1)[pos]]) 229 | texts.append((-count, f'{" " if i in agent_lib.G.INV_DICT else "U"} Glyph {i:4d} -> ' 230 | f'Char: {char} Count: {count:4d} ' 231 | f'Type: {cls.replace("_OFF", ""):11s} 
{desc:30s} ' 232 | f'{agent_lib.ALL.find(i) if agent_lib.ALL.find(i) is not None else "":20} ' 233 | f'{desc2}')) 234 | for _, t in sorted(texts): 235 | print(t) 236 | 237 | def get_action(self): 238 | while 1: 239 | key = os.read(sys.stdin.fileno(), 5) 240 | 241 | if key == b'\x1bOP': # F1 242 | self.draw_walkable = not self.draw_walkable 243 | self.visualizer.force_next_frame() 244 | self.render() 245 | continue 246 | elif key == b'\x1bOQ': # F2 247 | self.draw_seen = not self.draw_seen 248 | self.visualizer.force_next_frame() 249 | self.render() 250 | continue 251 | 252 | elif key == b'\x1bOR': # F3 253 | self.draw_shop = not self.draw_shop 254 | self.visualizer.force_next_frame() 255 | self.render() 256 | continue 257 | 258 | if key == b'\x1bOS': # F4 259 | raise ReloadAgent() 260 | 261 | if key == b'\x1b[15~': # F5 262 | self.fork() 263 | continue 264 | 265 | elif key == b'\x1b[3~': # Delete 266 | self.to_skip = 16 267 | return None 268 | 269 | if len(key) != 1: 270 | print('wrong key', key) 271 | continue 272 | key = key[0] 273 | if key == 10: 274 | key = 13 275 | 276 | if key == 63: # '?" 
277 | self.print_help() 278 | continue 279 | elif key == 127: # Backspace 280 | self.visualizer.force_next_frame() 281 | return None 282 | else: 283 | actions = [a for a in self.env._actions if int(a) == key] 284 | assert len(actions) < 2 285 | if len(actions) == 0: 286 | print('wrong key', key) 287 | continue 288 | 289 | action = actions[0] 290 | return action 291 | 292 | def step(self, agent_action): 293 | if self.visualizer is not None and self.visualizer.video_writer is None: 294 | self.visualizer.step(self.last_observation, repr(chr(int(agent_action)))) 295 | 296 | if self.interactive and self.to_skip <= 1: 297 | self.visualizer.force_next_frame() 298 | self.render() 299 | 300 | if self.interactive: 301 | print() 302 | print('agent_action:', agent_action, repr(chr(int(agent_action)))) 303 | print() 304 | 305 | if self.to_skip > 0: 306 | self.to_skip -= 1 307 | action = None 308 | else: 309 | action = self.get_action() 310 | 311 | if action is None: 312 | action = agent_action 313 | 314 | if self.interactive: 315 | print('action:', action) 316 | print() 317 | else: 318 | if self.visualizer is not None: 319 | self.visualizer.step(self.last_observation, repr(chr(int(agent_action)))) 320 | action = agent_action 321 | 322 | obs, reward, done, info = self.env.step(self.env._actions.index(action)) 323 | self.score += reward 324 | self.step_count += 1 325 | # if not done: 326 | # agent_lib.G.assert_map(obs['glyphs'], obs['chars']) 327 | 328 | # uncomment to debug measure up to assumed median 329 | # if self.score >= 7000: 330 | # done = True 331 | # self.end_reason = 'quit after median' 332 | if done: 333 | if self.visualizer is not None: 334 | self.visualizer.step(self.last_observation, repr(chr(int(agent_action)))) 335 | 336 | end_reason = bytes(obs['tty_chars'].reshape(-1)).decode().replace('You made the top ten list!', '').split() 337 | if end_reason[7].startswith('Agent'): 338 | self.score = int(end_reason[6]) 339 | end_reason = ' '.join(end_reason[8:-2]) 340 | 
else: 341 | assert self.score == 0, end_reason 342 | end_reason = ' '.join(end_reason[7:-2]) 343 | first_sentence = end_reason.split('.')[0].split() 344 | self.end_reason = info['end_status'].name + ': ' + \ 345 | (' '.join(first_sentence[:first_sentence.index('in')]) + '. ' + 346 | '.'.join(end_reason.split('.')[1:]).strip()).strip() 347 | if self.step_limit is not None and self.step_count == self.step_limit + 1: 348 | self.end_reason = self.end_reason or 'steplimit' 349 | done = True 350 | elif self.step_limit is not None and self.step_count > self.step_limit + 1: 351 | assert 0 352 | 353 | self.last_observation = obs 354 | 355 | if done: 356 | self.is_done = True 357 | if self.visualizer is not None: 358 | self.render() 359 | if self.interactive: 360 | print('Summary:') 361 | pprint(self.get_summary()) 362 | 363 | return obs, reward, done, info 364 | 365 | def debug_tiles(self, *args, **kwargs): 366 | if self.visualizer is not None: 367 | return self.visualizer.debug_tiles(*args, **kwargs) 368 | return contextlib.suppress() 369 | 370 | def debug_log(self, txt, color=(255, 255, 255)): 371 | if self.visualizer is not None: 372 | return self.visualizer.debug_log(txt, color) 373 | return contextlib.suppress() 374 | 375 | def get_summary(self): 376 | return { 377 | 'score': self.score, 378 | 'steps': self.env._steps, 379 | 'turns': self.agent.blstats.time, 380 | 'level_num': len(self.agent.levels), 381 | 'experience_level': self.agent.blstats.experience_level, 382 | 'milestone': self.agent.global_logic.milestone, 383 | 'panic_num': len(self.agent.all_panics), 384 | 'character': str(self.agent.character).split()[0], 385 | 'end_reason': self.end_reason, 386 | 'seed': self.env.get_seeds(), 387 | **self.agent.stats_logger.get_stats_dict(), 388 | } 389 | -------------------------------------------------------------------------------- /bin/main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 
| import os 4 | import subprocess 5 | import sys 6 | import termios 7 | import time 8 | import traceback 9 | import tty 10 | import warnings 11 | from argparse import ArgumentParser 12 | from multiprocessing import Process, Queue 13 | from multiprocessing.pool import ThreadPool 14 | from pathlib import Path 15 | from pprint import pprint 16 | 17 | import gym 18 | import nle.nethack as nh 19 | import numpy as np 20 | 21 | from autoascend import agent as agent_lib 22 | from autoascend.env_wrapper import EnvWrapper 23 | from autoascend.utils import plot_dashboard 24 | 25 | 26 | def prepare_env(args, seed): 27 | seed += args.seed 28 | 29 | if args.role: 30 | while 1: 31 | env = gym.make('NetHackChallenge-v0') 32 | env.seed(seed, seed) 33 | obs = env.reset() 34 | blstats = agent_lib.BLStats(*obs['blstats']) 35 | character_glyph = obs['glyphs'][blstats.y, blstats.x] 36 | if any([nh.permonst(nh.glyph_to_mon(character_glyph)).mname.startswith(role) for role in args.role]): 37 | break 38 | seed += 10 ** 9 39 | env.close() 40 | 41 | if args.visualize_ends is not None: 42 | assert args.mode == 'simulate' 43 | args.skip_to = 2 ** 32 44 | 45 | visualize_with_simulate = args.visualize_ends is not None or args.output_video_dir is not None 46 | visualizer_args = dict(enable=args.mode == 'run' or visualize_with_simulate, 47 | start_visualize=args.visualize_ends[seed] if args.visualize_ends is not None else None, 48 | show=args.mode == 'run', 49 | output_dir=Path('/tmp/vis/') / str(seed), 50 | frame_skipping=None if not visualize_with_simulate else 1, 51 | output_video_path=(args.output_video_dir / f'{seed}.mp4' 52 | if args.output_video_dir is not None else None)) 53 | env = EnvWrapper(gym.make('NetHackChallenge-v0', no_progress_timeout=1000), 54 | to_skip=args.skip_to, visualizer_args=visualizer_args, 55 | agent_args=dict(panic_on_errors=args.panic_on_errors, 56 | verbose=args.mode == 'run'), 57 | interactive=args.mode == 'run') 58 | env.env.seed(seed, seed) 59 | return env 60 | 
61 | 62 | def single_simulation(args, seed_offset, timeout=720): 63 | start_time = time.time() 64 | env = prepare_env(args, seed_offset) 65 | 66 | try: 67 | if timeout is not None: 68 | with ThreadPool(1) as pool: 69 | pool.apply_async(env.main).get(timeout) 70 | else: 71 | env.main() 72 | except multiprocessing.context.TimeoutError: 73 | env.end_reason = f'timeout' 74 | except BaseException as e: 75 | env.end_reason = f'exception: {"".join(traceback.format_exception(None, e, e.__traceback__))}' 76 | print(f'Seed {env.env.get_seeds()}, step {env.step_count}:', env.end_reason) 77 | 78 | end_time = time.time() 79 | summary = env.get_summary() 80 | summary['duration'] = end_time - start_time 81 | 82 | if args.visualize_ends is not None: 83 | env.visualizer.save_end_history() 84 | 85 | if env.visualizer is not None and env.visualizer.video_writer is not None: 86 | env.visualizer.video_writer.close() 87 | env.env.close() 88 | 89 | return summary 90 | 91 | 92 | def run_single_interactive_game(args): 93 | termios.tcgetattr(sys.stdin) 94 | tty.setcbreak(sys.stdin.fileno()) 95 | try: 96 | summary = single_simulation(args, 0, timeout=None) 97 | pprint(summary) 98 | finally: 99 | os.system('stty sane') 100 | 101 | 102 | def run_profiling(args): 103 | if args.profiler == 'cProfile': 104 | import cProfile, pstats 105 | elif args.profiler == 'pyinstrument': 106 | from pyinstrument import Profiler 107 | elif args.profiler == 'none': 108 | pass 109 | else: 110 | assert 0 111 | 112 | if args.profiler == 'cProfile': 113 | pr = cProfile.Profile() 114 | elif args.profiler == 'pyinstrument': 115 | profiler = Profiler() 116 | elif args.profiler == 'none': 117 | pass 118 | else: 119 | assert 0 120 | 121 | if args.profiler == 'cProfile': 122 | pr.enable() 123 | elif args.profiler == 'pyinstrument': 124 | profiler.start() 125 | elif args.profiler == 'none': 126 | pass 127 | else: 128 | assert 0 129 | 130 | start_time = time.time() 131 | res = [] 132 | for i in range(args.episodes): 133 | 
print(f'starting {i + 1} game...') 134 | res.append(single_simulation(args, i, timeout=None)) 135 | duration = time.time() - start_time 136 | 137 | if args.profiler == 'cProfile': 138 | pr.disable() 139 | elif args.profiler == 'pyinstrument': 140 | session = profiler.stop() 141 | elif args.profiler == 'none': 142 | pass 143 | else: 144 | assert 0 145 | 146 | print() 147 | print('turns_per_second :', sum([r['turns'] for r in res]) / duration) 148 | print('steps_per_second :', sum([r['steps'] for r in res]) / duration) 149 | print('episodes_per_hour:', len(res) / duration * 3600) 150 | print() 151 | 152 | if args.profiler == 'cProfile': 153 | stats = pstats.Stats(pr).sort_stats(pstats.SortKey.CUMULATIVE) 154 | stats.print_stats(30) 155 | stats = pstats.Stats(pr).sort_stats(pstats.SortKey.TIME) 156 | stats.print_stats(30) 157 | stats.dump_stats('/tmp/nethack_stats.profile') 158 | 159 | subprocess.run('gprof2dot -f pstats /tmp/nethack_stats.profile -o /tmp/calling_graph.dot'.split()) 160 | subprocess.run('xdot /tmp/calling_graph.dot'.split()) 161 | elif args.profiler == 'pyinstrument': 162 | frame_records = session.frame_records 163 | 164 | new_records = [] 165 | for record in frame_records: 166 | ret_frames = [] 167 | for frame in record[0][1:][::-1]: 168 | func, module, line = frame.split('\0') 169 | if func in ['f', 'f2', 'run', 'wrapper']: 170 | continue 171 | ret_frames.append(frame) 172 | if module.endswith('agent.py') and func in ['step', 'preempt', 'call_update_functions']: 173 | break 174 | ret_frames.append(record[0][0]) 175 | new_records.append((ret_frames[::-1], record[1] / session.duration * 100)) 176 | session.frame_records = new_records 177 | session.start_call_stack = [session.start_call_stack[0]] 178 | 179 | print('Cumulative time:') 180 | profiler._last_session = session 181 | print(profiler.output_text(unicode=True, color=True)) 182 | 183 | new_records = [] 184 | for record in frame_records: 185 | ret_frames = [] 186 | for frame in 
record[0][1:][::-1]: 187 | func, module, line = frame.split('\0') 188 | ret_frames.append(frame) 189 | if str(Path(module).absolute()).startswith(str(Path(__file__).parent.absolute())): 190 | break 191 | ret_frames.append(record[0][0]) 192 | new_records.append((ret_frames[::-1], record[1] / session.duration * 100)) 193 | session.frame_records = new_records 194 | session.start_call_stack = [session.start_call_stack[0]] 195 | 196 | print('Total time:') 197 | profiler._last_session = session 198 | print(profiler.output_text(unicode=True, color=True, show_all=True)) 199 | elif args.profiler == 'none': 200 | pass 201 | else: 202 | assert 0 203 | 204 | 205 | def run_simulations(args): 206 | import ray 207 | ray.init(address='auto') 208 | 209 | start_time = time.time() 210 | plot_queue = Queue() 211 | 212 | def plot_thread_func(): 213 | from matplotlib import pyplot as plt 214 | import seaborn as sns 215 | 216 | warnings.filterwarnings('ignore') 217 | sns.set() 218 | 219 | fig = plt.figure() 220 | plt.show(block=False) 221 | while 1: 222 | res = None 223 | try: 224 | while 1: 225 | res = plot_queue.get(block=False) 226 | except: 227 | plt.pause(0.5) 228 | if res is None: 229 | continue 230 | 231 | fig.clear() 232 | plot_dashboard(fig, res) 233 | fig.tight_layout() 234 | plt.show(block=False) 235 | 236 | if not args.no_plot: 237 | plt_process = Process(target=plot_thread_func) 238 | plt_process.start() 239 | 240 | refs = [] 241 | 242 | @ray.remote(num_gpus=1 / 4 if args.with_gpu else 0) 243 | def remote_simulation(args, seed_offset, timeout=500): 244 | # I think there is some nondeterminism in nle environment when playing 245 | # multiple episodes (maybe bones?). 
That should do the trick 246 | q = Queue() 247 | 248 | if args.output_video_dir is not None: 249 | timeout = 4 * 24 * 60 * 60 250 | 251 | def sim(): 252 | q.put(single_simulation(args, seed_offset, timeout=timeout)) 253 | 254 | try: 255 | p = Process(target=sim, daemon=False) 256 | p.start() 257 | return q.get() 258 | finally: 259 | p.terminate() 260 | p.join() 261 | 262 | # uncomment to debug why join doesn't work properly 263 | # from multiprocessing.pool import ThreadPool 264 | # with ThreadPool(1) as thrpool: 265 | # def fun(): 266 | # import time 267 | # while True: 268 | # time.sleep(1) 269 | # print(p.pid, p.is_alive(), p.exitcode, p) 270 | # thrpool.apply_async(fun) 271 | # p.join(timeout=timeout + 1) 272 | # assert not q.empty() 273 | 274 | try: 275 | with args.simulation_results.open('r') as f: 276 | all_res = json.load(f) 277 | print('Continue running: ', (len(all_res['seed']))) 278 | except FileNotFoundError: 279 | all_res = {} 280 | 281 | done_seeds = set() 282 | if 'seed' in all_res: 283 | done_seeds = set(s[0] for s in all_res['seed']) 284 | 285 | # remove runs finished with exceptions if rerunning with --panic-on-errors 286 | if args.panic_on_errors and all_res: 287 | idx_to_repeat = set() 288 | for i, (seed, reason) in enumerate(zip(all_res['seed'], all_res['end_reason'])): 289 | if reason.startswith('exception'): 290 | idx_to_repeat.add(i) 291 | done_seeds.remove(seed[0]) 292 | print('Repeating idx:', idx_to_repeat) 293 | for k, v in all_res.items(): 294 | all_res[k] = [v for i, v in enumerate(v) if i not in idx_to_repeat] 295 | 296 | print('skipping seeds', done_seeds) 297 | for seed_offset in range(args.episodes): 298 | seed = args.seed + seed_offset 299 | if seed in done_seeds: 300 | continue 301 | if args.seeds and seed not in args.seeds: 302 | continue 303 | if args.visualize_ends is None or seed_offset in [k % 10 ** 9 for k in args.visualize_ends]: 304 | refs.append(remote_simulation.remote(args, seed_offset)) 305 | 306 | count = 
len(done_seeds) 307 | initial_count = count 308 | for handle in refs: 309 | ref, refs = ray.wait(refs, num_returns=1, timeout=None) 310 | single_res = ray.get(ref[0]) 311 | 312 | if not all_res: 313 | all_res = {key: [] for key in single_res} 314 | assert all_res.keys() == single_res.keys() 315 | 316 | count += 1 317 | for k, v in single_res.items(): 318 | all_res[k].append(v if not hasattr(v, 'item') else v.item()) 319 | 320 | plot_queue.put(all_res) 321 | 322 | total_duration = time.time() - start_time 323 | 324 | median_score_std = np.std([np.median(np.random.choice(all_res["score"], 325 | size=max(1, len(all_res["score"]) // 2))) 326 | for _ in range(1024)]) 327 | 328 | text = [] 329 | text.append(f'count : {count}') 330 | text.append(f'time_per_simulation : {np.mean(all_res["duration"])}') 331 | text.append(f'simulations_per_hour : {3600 / np.mean(all_res["duration"])}') 332 | text.append(f'simulations_per_hour(multi) : {3600 * (count - initial_count) / total_duration}') 333 | text.append(f'time_per_turn : {np.sum(all_res["duration"]) / np.sum(all_res["turns"])}') 334 | text.append(f'turns_per_second : {np.sum(all_res["turns"]) / np.sum(all_res["duration"])}') 335 | text.append(f'turns_per_second(multi) : {np.sum(all_res["turns"]) / total_duration}') 336 | text.append(f'panic_num_per_game(median) : {np.median(all_res["panic_num"])}') 337 | text.append(f'panic_num_per_game(mean) : {np.sum(all_res["panic_num"]) / count}') 338 | text.append(f'score_median : {np.median(all_res["score"]):.1f} +/- ' 339 | f'{median_score_std:.1f}') 340 | text.append(f'score_mean : {np.mean(all_res["score"]):.1f} +/- ' 341 | f'{np.std(all_res["score"]) / (len(all_res["score"]) ** 0.5):.1f}') 342 | text.append(f'score_05-95 : {np.quantile(all_res["score"], 0.05)} ' 343 | f'{np.quantile(all_res["score"], 0.95)}') 344 | text.append(f'score_25-75 : {np.quantile(all_res["score"], 0.25)} ' 345 | f'{np.quantile(all_res["score"], 0.75)}') 346 | text.append(f'exceptions : ' 347 | 
f'{sum([r.startswith("exception:") for r in all_res["end_reason"]])}') 348 | text.append(f'steplimit : ' 349 | f'{sum([r.startswith("steplimit") or r.startswith("ABORT") for r in all_res["end_reason"]])}') 350 | text.append(f'timeout : ' 351 | f'{sum([r.startswith("timeout") for r in all_res["end_reason"]])}') 352 | print('\n'.join(text) + '\n') 353 | 354 | if args.visualize_ends is None: 355 | with args.simulation_results.open('w') as f: 356 | json.dump(all_res, f) 357 | 358 | print('DONE!') 359 | ray.shutdown() 360 | 361 | 362 | def parse_args(): 363 | parser = ArgumentParser() 364 | parser.add_argument('mode', choices=('simulate', 'run', 'profile')) 365 | parser.add_argument('--seed', type=int, help='Starting random seed') 366 | parser.add_argument('--seeds', nargs="*", type=int, 367 | help='Run only these specific seeds (only relevant in simulate mode)') 368 | parser.add_argument('--skip-to', type=int, default=0) 369 | parser.add_argument('-n', '--episodes', type=int, default=1) 370 | parser.add_argument('--role', choices=('arc', 'bar', 'cav', 'hea', 'kni', 371 | 'mon', 'pri', 'ran', 'rog', 'sam', 372 | 'tou', 'val', 'wiz'), 373 | action='append') 374 | parser.add_argument('--panic-on-errors', action='store_true') 375 | parser.add_argument('--no-plot', action='store_true') 376 | parser.add_argument('--visualize-ends', type=Path, default=None, 377 | help='Path to json file with dict: seed -> visualization_start_step.' 378 | 'THIS IS AN UNMAINTAINED FEATURE.' 
379 | 'It was used to save some visualizer frames before agent deathto conveniently browse them.') 380 | parser.add_argument('--output-video-dir', type=Path, default=None, 381 | help="Episode visualization video directory -- valid only with 'simulate' mode") 382 | parser.add_argument('--profiler', choices=('cProfile', 'pyinstrument', 'none'), default='pyinstrument') 383 | parser.add_argument('--with-gpu', action='store_true') 384 | parser.add_argument('--simulation-results', default='nh_sim.json', type=Path, 385 | help='path to simulation results json. Only for simulation mode') 386 | 387 | args = parser.parse_args() 388 | if args.seed is None: 389 | args.seed = np.random.randint(0, 2 ** 30) 390 | 391 | if args.visualize_ends is not None: 392 | with args.visualize_ends.open('r') as f: 393 | args.visualize_ends = {int(k): int(v) for k, v in json.load(f).items()} 394 | 395 | if args.output_video_dir is not None: 396 | assert args.mode == 'simulate', "Video output only valid in 'simulate' mode" 397 | 398 | print('ARGS:', args) 399 | return args 400 | 401 | 402 | def main(): 403 | args = parse_args() 404 | if args.mode == 'simulate': 405 | run_simulations(args) 406 | elif args.mode == 'profile': 407 | run_profiling(args) 408 | elif args.mode == 'run': 409 | run_single_interactive_game(args) 410 | else: 411 | assert 0 412 | 413 | 414 | if __name__ == '__main__': 415 | main() 416 | -------------------------------------------------------------------------------- /autoascend/visualization/visualizer.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import queue 3 | import time 4 | 5 | import cv2 6 | import nle.nethack as nh 7 | import numpy as np 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | # avoid importing agent modules here, because it makes agent reloading less reliable 11 | from .scopes import DrawTilesScope, DebugLogScope 12 | from .utils import put_text, draw_frame, draw_grid, FONT_SIZE, 
VideoWriter 13 | 14 | HISTORY_SIZE = 13 15 | RENDERS_HISTORY_SIZE = 128 16 | 17 | 18 | class Visualizer: 19 | 20 | def __init__(self, env, tileset_path='/tilesets/3.6.1tiles32.png', tile_size=32, 21 | start_visualize=None, show=False, output_dir=None, frame_skipping=None, output_video_path=None): 22 | self.env = env 23 | self.tile_size = tile_size 24 | self._window_name = 'NetHackVis' 25 | self.show = show 26 | self.start_visualize = start_visualize 27 | self.output_dir = output_dir 28 | 29 | self.last_obs = None 30 | 31 | self.tileset = cv2.imread(tileset_path)[..., ::-1] 32 | if self.tileset is None: 33 | raise FileNotFoundError(f'Tileset {tileset_path} not found') 34 | if self.tileset.shape[0] % tile_size != 0 or self.tileset.shape[1] % tile_size != 0: 35 | raise ValueError("Tileset and tile_size doesn't match modulo") 36 | 37 | h = self.tileset.shape[0] // tile_size 38 | w = self.tileset.shape[1] // tile_size 39 | tiles = [] 40 | for y in range(h): 41 | y *= tile_size 42 | for x in range(w): 43 | x *= tile_size 44 | tiles.append(self.tileset[y:y + tile_size, x:x + tile_size]) 45 | self.tileset = np.array(tiles) 46 | 47 | # note that this file is a symlink (acutall file is in the docker container) 48 | from .glyph2tile import glyph2tile 49 | 50 | self.glyph2tile = np.array(glyph2tile) 51 | 52 | if self.show: 53 | print('Read tileset of size:', self.tileset.shape) 54 | 55 | self.action_history = list() 56 | 57 | self.message_history = list() 58 | self.popup_history = list() 59 | 60 | self.drawers = [] 61 | self.log_messages = list() 62 | self.log_messages_history = list() 63 | 64 | self.frame_skipping = frame_skipping 65 | self.frame_counter = -1 66 | self._force_next_frame = False 67 | self._dynamic_frame_skipping_exp = lambda: min(0.95, 1 - 1 / (self.env.step_count + 1)) 68 | self._dynamic_frame_skipping_render_time = 0 69 | self._dynamic_frame_skipping_agent_time = 1e-6 70 | self._dynamic_frame_skipping_threshold = 0.3 # for render_time / agent_time 71 | 
self._dynamic_frame_skipping_last_end_time = None 72 | self.total_time = 0 73 | 74 | self.renders_history = None 75 | if not self.show and output_video_path is None: 76 | assert output_dir is not None 77 | self.renders_history = queue.deque(maxlen=RENDERS_HISTORY_SIZE) 78 | self.output_dir = output_dir 79 | self.output_dir.mkdir(exist_ok=True, parents=True) 80 | 81 | self._start_display_thread() 82 | 83 | self.last_obs = None 84 | 85 | self.video_writer = None 86 | if output_video_path is not None: 87 | self.video_writer = VideoWriter(output_video_path, fps=10) 88 | 89 | self.tty_downscale = 1.0 # consider changing for better performance 90 | self.font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", 91 | int(26 * self.tty_downscale)) 92 | 93 | def debug_tiles(self, *args, **kwargs): 94 | return DrawTilesScope(self, *args, **kwargs) 95 | 96 | def debug_log(self, txt, color): 97 | return DebugLogScope(self, txt, color) 98 | 99 | def step(self, obs, action): 100 | self.last_obs = obs 101 | self.action_history.append(action) 102 | self._update_log_message_history() 103 | self._update_message_history() 104 | self._update_popup_history() 105 | 106 | if self.video_writer is not None: 107 | frame = self._render() 108 | if frame is not None: 109 | self.video_writer.write(frame) 110 | 111 | def render(self): 112 | if self.video_writer is not None: 113 | return False 114 | 115 | self.frame_counter += 1 116 | render_start_time = None 117 | 118 | try: 119 | t = time.time() 120 | frame = self._render() 121 | if frame is None: 122 | return False 123 | render_start_time = t 124 | 125 | if self.show: 126 | self._display_queue.put(frame[..., ::-1].copy()) 127 | 128 | if self.renders_history is not None: 129 | self.renders_history.append(frame) 130 | 131 | finally: 132 | self._update_dynamic_frame_skipping(render_start_time) 133 | 134 | return True 135 | 136 | def _render(self): 137 | if not self._force_next_frame and self.frame_skipping is not None: 138 | 
# static frame skipping 139 | if self.frame_counter % self.frame_skipping != 0: 140 | return None 141 | 142 | if self.frame_skipping is None: 143 | # dynamic frame skipping 144 | frame_skipping = self._dynamic_frame_skipping_render_time / self._dynamic_frame_skipping_agent_time / \ 145 | self._dynamic_frame_skipping_threshold 146 | if not self._force_next_frame and self.frame_counter <= frame_skipping: 147 | return None 148 | else: 149 | self.frame_counter = 0 150 | 151 | if self.last_obs is None: 152 | return None 153 | 154 | if self.start_visualize is not None: 155 | if self.env.step_count < self.start_visualize: 156 | return None 157 | 158 | if self._force_next_frame: 159 | self.frame_counter = 0 160 | self._force_next_frame = False 161 | 162 | glyphs = self.last_obs['glyphs'] 163 | tiles_idx = self.glyph2tile[glyphs] 164 | tiles = self.tileset[tiles_idx.reshape(-1)] 165 | scene_vis = draw_grid(tiles, glyphs.shape[1]) 166 | for drawer in self.drawers: 167 | scene_vis = drawer(scene_vis) 168 | draw_frame(scene_vis) 169 | topbar = self._draw_topbar(scene_vis.shape[1]) 170 | bottombar = self._draw_bottombar(scene_vis.shape[1]) 171 | 172 | rendered = np.concatenate([topbar, scene_vis, bottombar], axis=0) 173 | inventory = self._draw_inventory(rendered.shape[0]) 174 | return np.concatenate([rendered, inventory], axis=1) 175 | 176 | def save_end_history(self): 177 | print('SAVING', self.output_dir) 178 | for i, render in enumerate(list(self.renders_history)): 179 | render = render[..., ::-1] 180 | out_path = self.output_dir / (str(i).rjust(5, '0') + '.jpg') 181 | cv2.imwrite(str(out_path), render) 182 | 183 | def force_next_frame(self): 184 | self._force_next_frame = True 185 | 186 | def stop_display_thread(self): 187 | if self.show: 188 | self._display_process.terminate() 189 | self._display_process.join() 190 | 191 | def _display_thread(self): 192 | cv2.namedWindow(self._window_name, cv2.WINDOW_NORMAL | cv2.WINDOW_GUI_NORMAL) 193 | 194 | last_size = (None, None) 195 
| image = None 196 | while 1: 197 | is_new_image = False 198 | try: 199 | while 1: 200 | try: 201 | image = self._display_queue.get(timeout=0.03) 202 | is_new_image = True 203 | except queue.Empty: 204 | break 205 | 206 | if image is None: 207 | image = self._display_queue.get() 208 | is_new_image = True 209 | 210 | width = cv2.getWindowImageRect(self._window_name)[2] 211 | height = cv2.getWindowImageRect(self._window_name)[3] 212 | ratio = min(width / image.shape[1], height / image.shape[0]) 213 | width, height = round(image.shape[1] * ratio), round(image.shape[0] * ratio) 214 | 215 | if last_size != (width, height) or is_new_image: 216 | last_size = (width, height) 217 | 218 | resized_image = cv2.resize(image, (width, height), cv2.INTER_AREA) 219 | cv2.imshow(self._window_name, resized_image) 220 | 221 | cv2.waitKey(1) 222 | except KeyboardInterrupt: 223 | pass 224 | except (ConnectionResetError, EOFError): 225 | return 226 | 227 | cv2.destroyWindow(self._window_name) 228 | 229 | def _start_display_thread(self): 230 | if self.show: 231 | self._display_queue = multiprocessing.Manager().Queue() 232 | self._display_process = multiprocessing.Process(target=self._display_thread, daemon=False) 233 | self._display_process.start() 234 | 235 | def _update_dynamic_frame_skipping(self, render_start_time): 236 | if self._dynamic_frame_skipping_last_end_time is not None: 237 | self.total_time += time.time() - self._dynamic_frame_skipping_last_end_time 238 | if render_start_time is not None: 239 | render_time = time.time() - render_start_time 240 | else: 241 | render_time = None 242 | agent_time = time.time() - self._dynamic_frame_skipping_last_end_time - \ 243 | (render_time if render_time is not None else 0) 244 | 245 | if render_start_time is not None: 246 | self._dynamic_frame_skipping_render_time = \ 247 | self._dynamic_frame_skipping_render_time * self._dynamic_frame_skipping_exp() + \ 248 | render_time * (1 - self._dynamic_frame_skipping_exp()) 249 | 
self._dynamic_frame_skipping_agent_time = \ 250 | self._dynamic_frame_skipping_agent_time * self._dynamic_frame_skipping_exp() + \ 251 | agent_time * (1 - self._dynamic_frame_skipping_exp()) 252 | 253 | self._dynamic_frame_skipping_last_end_time = time.time() 254 | 255 | def _draw_bottombar(self, width): 256 | height = FONT_SIZE * len(self.last_obs['tty_chars']) 257 | tty = self._draw_tty(self.last_obs, width - width // 2, height) 258 | stats = self._draw_stats(width // 2, height) 259 | return np.concatenate([tty, stats], axis=1) 260 | 261 | def _draw_stats(self, width, height): 262 | ret = np.zeros((height, width, 3), dtype=np.uint8) 263 | ch = self.env.agent.character 264 | if ch.role is None: 265 | return ret 266 | 267 | # game info 268 | i = 0 269 | txt = [f'Level num: {self.env.agent.current_level().level_number}', 270 | f'Dung num: {self.env.agent.current_level().dungeon_number}', 271 | f'Step: {self.env.step_count}', 272 | f'Turn: {self.env.agent._last_turn}', 273 | f'Score: {self.env.score}', 274 | ] 275 | put_text(ret, ' | '.join(txt), (0, i * FONT_SIZE), color=(255, 255, 255)) 276 | i += 3 277 | 278 | # general character info 279 | txt = [ 280 | {v: k for k, v in ch.name_to_role.items()}[ch.role], 281 | {v: k for k, v in ch.name_to_race.items()}[ch.race], 282 | {v: k for k, v in ch.name_to_alignment.items()}[ch.alignment], 283 | {v: k for k, v in ch.name_to_gender.items()}[ch.gender], 284 | ] 285 | put_text(ret, ' | '.join(txt), (0, i * FONT_SIZE)) 286 | i += 1 287 | txt = [f'HP: {self.env.agent.blstats.hitpoints} / {self.env.agent.blstats.max_hitpoints}', 288 | f'LVL: {self.env.agent.blstats.experience_level}', 289 | f'ENERGY: {self.env.agent.blstats.energy} / {self.env.agent.blstats.max_energy}', 290 | ] 291 | hp_ratio = self.env.agent.blstats.hitpoints / self.env.agent.blstats.max_hitpoints 292 | hp_color = cv2.applyColorMap(np.array([[130 - int((1 - hp_ratio) * 110)]], dtype=np.uint8), 293 | cv2.COLORMAP_TURBO)[0, 0] 294 | put_text(ret, ' | 
'.join(txt), (0, i * FONT_SIZE), color=tuple(map(int, hp_color))) 295 | i += 2 296 | 297 | # proficiency info 298 | colors = { 299 | 'Basic': (100, 100, 255), 300 | 'Skilled': (100, 255, 100), 301 | 'Expert': (100, 255, 255), 302 | 'Master': (255, 255, 100), 303 | 'Grand Master': (255, 100, 100), 304 | } 305 | for line in ch.get_skill_str_list(): 306 | if 'Unskilled' not in line: 307 | put_text(ret, line, (0, i * FONT_SIZE), color=colors[line.split('-')[-1]]) 308 | i += 1 309 | unskilled = [] 310 | for line in ch.get_skill_str_list(): 311 | if 'Unskilled' in line: 312 | unskilled.append(line.split('-')[0]) 313 | put_text(ret, '|'.join(unskilled), (0, i * FONT_SIZE), color=(100, 100, 100)) 314 | i += 2 315 | put_text(ret, 'Unarmed bonus: ' + str(ch.get_melee_bonus(None)), (0, i * FONT_SIZE)) 316 | i += 2 317 | 318 | stats = list(self.env.agent.stats_logger.get_stats_dict().items()) 319 | stats = [(k, v) for k, v in stats if v != 0] 320 | for j in range((len(stats) + 2) // 3): 321 | def format_value(v): 322 | if isinstance(v, float): 323 | return f'{v:.2f}' 324 | return str(v) 325 | 326 | put_text(ret, ' | '.join(f'{k}={format_value(v)}' for k, v in stats[j * 3: (j + 1) * 3]), 327 | (0, i * FONT_SIZE), color=(100, 100, 100)) 328 | i += 1 329 | i += 1 330 | 331 | if hasattr(self.env.agent.character, 'known_spells'): 332 | put_text(ret, 'Known spells: ' + str(list(self.env.agent.character.known_spells)), (0, i * FONT_SIZE)) 333 | i += 1 334 | 335 | monsters = [(dis, y, x, mon.mname) for dis, y, x, mon, _ in self.env.agent.get_visible_monsters()] 336 | put_text(ret, 'Monsters: ' + str(monsters), (0, i * FONT_SIZE)) 337 | 338 | draw_frame(ret) 339 | return ret 340 | 341 | def _draw_topbar(self, width): 342 | actions_vis = self._draw_action_history(width // 25) 343 | messages_vis = self._draw_message_history(width // 4) 344 | popup_vis = self._draw_popup_history(width // 4) 345 | log_messages_vis = self._draw_debug_message_log(width - width // 25 - width // 4 - width // 
4) 346 | ret = np.concatenate([actions_vis, messages_vis, popup_vis, log_messages_vis], axis=1) 347 | assert ret.shape[1] == width 348 | return ret 349 | 350 | def _draw_debug_message_log(self, width): 351 | vis = np.zeros((FONT_SIZE * HISTORY_SIZE, width, 3)).astype(np.uint8) 352 | for i in range(HISTORY_SIZE): 353 | if i >= len(self.log_messages_history): 354 | break 355 | txt = self.log_messages_history[-i - 1] 356 | if i == 0: 357 | put_text(vis, txt, (0, i * FONT_SIZE), color=(255, 255, 255)) 358 | else: 359 | put_text(vis, txt, (0, i * FONT_SIZE), color=(120, 120, 120)) 360 | draw_frame(vis) 361 | return vis 362 | 363 | def _update_log_message_history(self): 364 | txt = '' 365 | if self.env.agent is not None: 366 | txt = ' | '.join(self.log_messages) 367 | # if txt: 368 | self.log_messages_history.append(txt) 369 | 370 | def _draw_action_history(self, width): 371 | vis = np.zeros((FONT_SIZE * HISTORY_SIZE, width, 3)).astype(np.uint8) 372 | for i in range(HISTORY_SIZE): 373 | if i >= len(self.action_history): 374 | break 375 | txt = self.action_history[-i - 1] 376 | if i == 0: 377 | put_text(vis, txt, (0, i * FONT_SIZE), color=(255, 255, 255)) 378 | else: 379 | put_text(vis, txt, (0, i * FONT_SIZE), color=(120, 120, 120)) 380 | draw_frame(vis) 381 | return vis 382 | 383 | def _draw_message_history(self, width): 384 | messages_vis = np.zeros((FONT_SIZE * HISTORY_SIZE, width, 3)).astype(np.uint8) 385 | for i in range(HISTORY_SIZE): 386 | if i >= len(self.message_history): 387 | break 388 | txt = self.message_history[-i - 1] 389 | if i == 0: 390 | put_text(messages_vis, txt, (0, i * FONT_SIZE), color=(255, 255, 255)) 391 | else: 392 | put_text(messages_vis, txt, (0, i * FONT_SIZE), color=(120, 120, 120)) 393 | draw_frame(messages_vis) 394 | return messages_vis 395 | 396 | def _draw_popup_history(self, width): 397 | messages_vis = np.zeros((FONT_SIZE * HISTORY_SIZE, width, 3)).astype(np.uint8) 398 | for i in range(HISTORY_SIZE): 399 | if i >= 
len(self.popup_history): 400 | break 401 | txt = '|'.join(self.popup_history[-i - 1]) 402 | if i == 0: 403 | put_text(messages_vis, txt, (0, i * FONT_SIZE), color=(255, 255, 255)) 404 | else: 405 | put_text(messages_vis, txt, (0, i * FONT_SIZE), color=(120, 120, 120)) 406 | draw_frame(messages_vis) 407 | return messages_vis 408 | 409 | def _update_message_history(self): 410 | txt = '' 411 | if self.env.agent is not None: 412 | txt = self.env.agent.message 413 | # if txt: 414 | self.message_history.append(txt) 415 | 416 | def _update_popup_history(self): 417 | txt = '' 418 | if self.env.agent is not None: 419 | txt = self.env.agent.popup 420 | # if txt: 421 | self.popup_history.append(txt) 422 | 423 | def _draw_tty(self, obs, width, height): 424 | vis = np.zeros((int(height * self.tty_downscale), 425 | int(width * self.tty_downscale), 3)).astype(np.uint8) 426 | 427 | vis = Image.fromarray(vis) 428 | draw = ImageDraw.Draw(vis) 429 | 430 | for i, line in enumerate(obs['tty_chars']): 431 | txt = ''.join([chr(i) for i in line]) 432 | draw.text((int(5 * self.tty_downscale), int((5 + i * 31) * self.tty_downscale)), 433 | txt, (255, 255, 255), font=self.font) 434 | 435 | vis = np.array(vis.resize((width, height), Image.ANTIALIAS)) 436 | draw_frame(vis) 437 | return vis 438 | 439 | def _draw_item(self, letter, item, width, height, indent=0): 440 | from ..item import Item 441 | 442 | bg_color = { 443 | nh.WAND_CLASS: np.array([0, 50, 50], dtype=np.uint8), 444 | nh.FOOD_CLASS: np.array([0, 50, 0], dtype=np.uint8), 445 | nh.ARMOR_CLASS: np.array([50, 50, 0], dtype=np.uint8), 446 | nh.RING_CLASS: np.array([50, 50, 0], dtype=np.uint8), 447 | nh.SCROLL_CLASS: np.array([30, 30, 30], dtype=np.uint8), 448 | nh.POTION_CLASS: np.array([0, 0, 50], dtype=np.uint8), 449 | } 450 | 451 | indent = int((width - 1) * (1 - 0.9 ** indent)) 452 | 453 | vis = np.zeros((round(height * 0.9), width - indent, 3)).astype(np.uint8) 454 | if item.category in bg_color: 455 | vis += 
bg_color[item.category] 456 | if item.is_weapon(): 457 | if item.is_thrown_projectile() or item.is_fired_projectile(): 458 | vis += np.array([50, 0, 50], dtype=np.uint8) 459 | else: 460 | vis += np.array([50, 0, 0], dtype=np.uint8) 461 | if letter is not None: 462 | put_text(vis, str(letter), (0, 0)) 463 | 464 | status_str, status_col = { 465 | Item.UNKNOWN: (' ', (255, 255, 255)), 466 | Item.CURSED: ('C', (255, 0, 0)), 467 | Item.UNCURSED: ('U', (0, 255, 255)), 468 | Item.BLESSED: ('B', (0, 255, 0)), 469 | }[item.status] 470 | put_text(vis, status_str, (FONT_SIZE, 0), color=status_col) 471 | 472 | if item.modifier is not None: 473 | put_text(vis, str(item.modifier), (FONT_SIZE * 2, 0)) 474 | 475 | best_launcher, best_ammo = self.env.agent.inventory.get_best_ranged_set() 476 | best_melee = self.env.agent.inventory.get_best_melee_weapon() 477 | if item == best_launcher: 478 | put_text(vis, 'L', (FONT_SIZE * 3, 0), color=(255, 255, 255)) 479 | if item == best_ammo: 480 | put_text(vis, 'A', (FONT_SIZE * 3, 0), color=(255, 255, 255)) 481 | if item == best_melee: 482 | put_text(vis, 'M', (FONT_SIZE * 3, 0), color=(255, 255, 255)) 483 | 484 | if item.is_weapon(): 485 | put_text(vis, str(self.env.agent.character.get_melee_bonus(item)), (FONT_SIZE * 4, 0)) 486 | 487 | put_text(vis, str(item), (FONT_SIZE * 8, round(FONT_SIZE * -0.1)), scale=FONT_SIZE / 40) 488 | # if len(item.objs) > 1: 489 | vis = np.concatenate([vis, np.zeros((vis.shape[0] // 2, vis.shape[1], 3), dtype=np.uint8)]) 490 | put_text(vis, str(len(item.objs)) + ' | ' + ' | '.join((o.name for o in item.objs)), 491 | (0, round(FONT_SIZE * 0.8)), scale=FONT_SIZE / 40) 492 | 493 | draw_frame(vis, color=(80, 80, 80), thickness=2) 494 | 495 | if item.equipped: 496 | cv2.rectangle(vis, (0, 0), (int(FONT_SIZE * 1.4), vis.shape[0] - 1), (0, 255, 255), 6) 497 | 498 | if indent != 0: 499 | vis = np.concatenate([np.zeros((vis.shape[0], width - vis.shape[1], 3), dtype=np.uint8), vis], 1) 500 | 501 | return vis 502 | 503 | 
def _draw_inventory(self, height): 504 | width = 800 505 | vis = np.zeros((height, width, 3), dtype=np.uint8) 506 | if self.env.agent: 507 | item_h = round(FONT_SIZE * 1.4) 508 | tiles = [] 509 | for i, (letter, item) in enumerate(zip(self.env.agent.inventory.items.all_letters, 510 | self.env.agent.inventory.items.all_items)): 511 | 512 | def rec_draw(item, letter, indent=0): 513 | tiles.append(self._draw_item(letter, item, width, item_h, indent=indent)) 514 | if item.is_container(): 515 | for it in item.content: 516 | rec_draw(it, None, indent + 1) 517 | 518 | rec_draw(item, letter, 0) 519 | if tiles: 520 | vis = np.concatenate(tiles, axis=0) 521 | if vis.shape[0] < height: 522 | vis = np.concatenate([vis, np.zeros((height - vis.shape[0], width, 3), dtype=np.uint8)], axis=0) 523 | else: 524 | vis = cv2.resize(vis, (width, height)) 525 | draw_frame(vis) 526 | return vis 527 | --------------------------------------------------------------------------------